diff --git a/cmd/strelaysrv/listener.go b/cmd/strelaysrv/listener.go index d9df57b12..d31f9b84a 100644 --- a/cmd/strelaysrv/listener.go +++ b/cmd/strelaysrv/listener.go @@ -23,7 +23,7 @@ var ( numConnections int64 ) -func listener(_, addr string, config *tls.Config) { +func listener(_, addr string, config *tls.Config, token string) { tcpListener, err := net.Listen("tcp", addr) if err != nil { log.Fatalln(err) @@ -49,7 +49,7 @@ func listener(_, addr string, config *tls.Config) { } if isTLS { - go protocolConnectionHandler(conn, config) + go protocolConnectionHandler(conn, config, token) } else { go sessionConnectionHandler(conn) } @@ -57,7 +57,7 @@ func listener(_, addr string, config *tls.Config) { } } -func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { +func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config, token string) { conn := tls.Server(tcpConn, config) if err := conn.SetDeadline(time.Now().Add(messageTimeout)); err != nil { if debug { @@ -119,6 +119,15 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { switch msg := message.(type) { case protocol.JoinRelayRequest: + if token != "" && msg.Token != token { + if debug { + log.Printf("invalid token %s\n", msg.Token) + } + protocol.WriteMessage(conn, protocol.ResponseWrongToken) + conn.Close() + continue + } + if atomic.LoadInt32(&overLimit) > 0 { protocol.WriteMessage(conn, protocol.RelayFull{}) if debug { diff --git a/cmd/strelaysrv/main.go b/cmd/strelaysrv/main.go index 4633cd2ea..bdf317875 100644 --- a/cmd/strelaysrv/main.go +++ b/cmd/strelaysrv/main.go @@ -56,6 +56,7 @@ var ( networkBufferSize int statusAddr string + token string poolAddrs string pools []string providedBy string @@ -89,6 +90,7 @@ func main() { flag.IntVar(&globalLimitBps, "global-rate", globalLimitBps, "Global rate limit, in bytes/s") flag.BoolVar(&debug, "debug", debug, "Enable debug output") flag.StringVar(&statusAddr, "status-srv", ":22070", "Listen address for status service (blank to disable)") + flag.StringVar(&token, "token", "", "Token to restrict access to the relay (optional). Disables joining any pools.") flag.StringVar(&poolAddrs, "pools", defaultPoolAddrs, "Comma separated list of relay pool addresses to join") flag.StringVar(&providedBy, "provided-by", "", "An optional description about who provides the relay") flag.StringVar(&extAddress, "ext-address", "", "An optional address to advertise as being available on.\n\tAllows listening on an unprivileged port with port forwarding from e.g. 443, and be connected to on port 443.") @@ -256,6 +258,10 @@ func main() { log.Println("URI:", uri.String()) + if token != "" { + poolAddrs = "" + } + if poolAddrs == defaultPoolAddrs { log.Println("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") log.Println("!! Joining default relay pools, this relay will be available for public use. !!") @@ -271,7 +277,7 @@ func main() { } } - go listener(proto, listen, tlsCfg) + go listener(proto, listen, tlsCfg, token) sigs := make(chan os.Signal, 1) signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) diff --git a/lib/relay/client/static.go b/lib/relay/client/static.go index 1be039fbe..11f50083b 100644 --- a/lib/relay/client/static.go +++ b/lib/relay/client/static.go @@ -27,7 +27,8 @@ type staticClient struct { messageTimeout time.Duration connectTimeout time.Duration - conn *tls.Conn + conn *tls.Conn + token string } func newStaticClient(uri *url.URL, certs []tls.Certificate, invitations chan protocol.SessionInvitation, timeout time.Duration) *staticClient { @@ -38,6 +39,8 @@ func newStaticClient(uri *url.URL, certs []tls.Certificate, invitations chan pro messageTimeout: time.Minute * 2, connectTimeout: timeout, + + token: uri.Query().Get("token"), } c.commonClient = newCommonClient(invitations, c.serve, c.String()) return c @@ -173,7 +176,7 @@ func (c *staticClient) disconnect() { } func (c *staticClient) join() error { - if err := protocol.WriteMessage(c.conn, protocol.JoinRelayRequest{}); err != nil { + if err := protocol.WriteMessage(c.conn, protocol.JoinRelayRequest{Token: c.token}); err != nil { return err } diff --git a/lib/relay/protocol/packets.go b/lib/relay/protocol/packets.go index 7f06dca86..1e85e2da6 100644 --- a/lib/relay/protocol/packets.go +++ b/lib/relay/protocol/packets.go @@ -31,9 +31,12 @@ type header struct { type Ping struct{} type Pong struct{} -type JoinRelayRequest struct{} type RelayFull struct{} +type JoinRelayRequest struct { + Token string +} + type JoinSessionRequest struct { Key []byte // max:32 } diff --git a/lib/relay/protocol/packets_xdr.go b/lib/relay/protocol/packets_xdr.go index 51c20350c..81200aa05 100644 --- a/lib/relay/protocol/packets_xdr.go +++ b/lib/relay/protocol/packets_xdr.go @@ -137,40 +137,6 @@ func (*Pong) UnmarshalXDRFrom(_ *xdr.Unmarshaller) error { /* -JoinRelayRequest Structure: -(contains no fields) - - -struct JoinRelayRequest { -} - -*/ - -func (JoinRelayRequest) XDRSize() int { - return 0 -} -func (JoinRelayRequest) MarshalXDR() ([]byte, error) { - return nil, nil -} - -func (JoinRelayRequest) MustMarshalXDR() []byte { - return nil -} - -func (JoinRelayRequest) MarshalXDRInto(_ *xdr.Marshaller) error { - return nil -} - -func (*JoinRelayRequest) UnmarshalXDR(_ []byte) error { - return nil -} - -func (*JoinRelayRequest) UnmarshalXDRFrom(_ *xdr.Unmarshaller) error { - return nil -} - -/* - RelayFull Structure: (contains no fields) @@ -205,6 +171,57 @@ func (*RelayFull) UnmarshalXDRFrom(_ *xdr.Unmarshaller) error { /* +JoinRelayRequest Structure: + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/ / +\ Token (length + padded data) \ +/ / ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + +struct JoinRelayRequest { + string Token<>; +} + +*/ + +func (o JoinRelayRequest) XDRSize() int { + return 4 + len(o.Token) + xdr.Padding(len(o.Token)) +} + +func (o JoinRelayRequest) MarshalXDR() ([]byte, error) { + buf := make([]byte, o.XDRSize()) + m := &xdr.Marshaller{Data: buf} + return buf, o.MarshalXDRInto(m) +} + +func (o JoinRelayRequest) MustMarshalXDR() []byte { + bs, err := o.MarshalXDR() + if err != nil { + panic(err) + } + return bs +} + +func (o JoinRelayRequest) MarshalXDRInto(m *xdr.Marshaller) error { + m.MarshalString(o.Token) + return m.Error +} + +func (o *JoinRelayRequest) UnmarshalXDR(bs []byte) error { + u := &xdr.Unmarshaller{Data: bs} + return o.UnmarshalXDRFrom(u) +} +func (o *JoinRelayRequest) UnmarshalXDRFrom(u *xdr.Unmarshaller) error { + o.Token = u.UnmarshalString() + return u.Error +} + +/* + JoinSessionRequest Structure: 0 1 2 3 diff --git a/lib/relay/protocol/protocol.go b/lib/relay/protocol/protocol.go index 0bc079ab6..a4eb4d8d5 100644 --- a/lib/relay/protocol/protocol.go +++ b/lib/relay/protocol/protocol.go @@ -17,6 +17,7 @@ var ( ResponseSuccess = Response{0, "success"} ResponseNotFound = Response{1, "not found"} ResponseAlreadyConnected = Response{2, "already connected"} + ResponseWrongToken = Response{3, "wrong token"} ResponseUnexpectedMessage = Response{100, "unexpected message"} ) @@ -107,6 +108,14 @@ func ReadMessage(r io.Reader) (interface{}, error) { return msg, err case messageTypeJoinRelayRequest: var msg JoinRelayRequest + + // In prior versions of the protocol JoinRelayRequest did not have a + // token field. Trying to unmarshal such a request will result in + // an error, return msg with an empty token instead. + if header.messageLength == 0 { + return msg, nil + } + err := msg.UnmarshalXDR(buf) return msg, err case messageTypeJoinSessionRequest: