diff --git a/.gitignore b/.gitignore index bc1061e..2cf3874 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,7 @@ srtrelay config.toml -/.vscode \ No newline at end of file +# VSCode +/.vscode +# Jetbrains +/.idea +*.iml \ No newline at end of file diff --git a/api/server.go b/api/server.go index 789dec7..0395b42 100644 --- a/api/server.go +++ b/api/server.go @@ -3,6 +3,7 @@ package api import ( "context" "encoding/json" + "errors" "log" "net/http" "sync" @@ -49,7 +50,7 @@ func (s *Server) Listen(ctx context.Context) error { go func() { defer s.done.Done() err := serv.ListenAndServe() - if err != nil && err != http.ErrServerClosed { + if err != nil && !errors.Is(err, http.ErrServerClosed) { log.Println(err) } }() diff --git a/auth/http.go b/auth/http.go index fe868c4..d36914d 100644 --- a/auth/http.go +++ b/auth/http.go @@ -14,18 +14,16 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" ) -var ( - requestDurations = promauto.NewHistogramVec( - prometheus.HistogramOpts{ - Namespace: metrics.Namespace, - Subsystem: "auth", - Name: "request_duration_seconds", - Help: "A histogram of auth http request latencies.", - Buckets: prometheus.DefBuckets, - NativeHistogramBucketFactor: 1.1, - }, - []string{"url", "application"}, - ) +var requestDurations = promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: metrics.Namespace, + Subsystem: "auth", + Name: "request_duration_seconds", + Help: "A histogram of auth http request latencies.", + Buckets: prometheus.DefBuckets, + NativeHistogramBucketFactor: 1.1, + }, + []string{"url", "application"}, ) type httpAuth struct { @@ -52,7 +50,7 @@ type HTTPAuthConfig struct { } // NewHttpAuth creates an Authenticator with a HTTP backend -func NewHTTPAuth(authConfig HTTPAuthConfig) *httpAuth { +func NewHTTPAuth(authConfig HTTPAuthConfig) Authenticator { m := requestDurations.MustCurryWith(prometheus.Labels{"url": authConfig.URL, "application": authConfig.Application}) return &httpAuth{ config: authConfig, @@ -74,6 +72,7 @@ func (h *httpAuth) Authenticate(streamid stream.StreamID) bool { "call": {streamid.Mode().String()}, "app": {h.config.Application}, "name": {streamid.Name()}, + "username": {streamid.Username()}, h.config.PasswordParam: {streamid.Password()}, }) if err != nil { diff --git a/srt/server.go b/srt/server.go index 3afb3e8..8b61971 100644 --- a/srt/server.go +++ b/srt/server.go @@ -331,8 +331,8 @@ func (s *ServerImpl) registerForStats(ctx context.Context, conn *srtConn) { func (s *ServerImpl) GetStatistics() []*relay.StreamStatistics { streams := s.relay.GetStatistics() - for _, stream := range streams { - stream.URL = fmt.Sprintf("srt://%s?streamid=play/%s", s.config.PublicAddress, stream.Name) + for _, st := range streams { + st.URL = fmt.Sprintf("srt://%s?streamid=#!::m=request,r=%s", s.config.PublicAddress, st.Name) // New format } return streams } diff --git a/srt/server_test.go b/srt/server_test.go index d7d3187..5217b44 100644 --- a/srt/server_test.go +++ b/srt/server_test.go @@ -46,7 +46,7 @@ func TestServerImpl_GetStatistics(t *testing.T) { streams := s.GetStatistics() expected := []*relay.StreamStatistics{ - {Name: "s1", URL: "srt://testserver.de:1337?streamid=play/s1", Clients: 2, Created: streams[0].Created}, + {Name: "s1", URL: "srt://testserver.de:1337?streamid=#!::m=request,r=s1", Clients: 2, Created: streams[0].Created}, // New Format } if err := compareStats(streams, expected); err != nil { t.Error(err) @@ -65,8 +65,8 @@ func (s *testSocket) Read(b []byte) (int, error) { if !ok { return 0, io.EOF } - len := copy(b, buf) - return len, nil + length := copy(b, buf) + return length, nil } func (s *testSocket) Write(b []byte) (int, error) { diff --git a/stream/streamid.go b/stream/streamid.go index d83fbc8..17c61e2 100644 --- a/stream/streamid.go +++ b/stream/streamid.go @@ -8,11 +8,14 @@ import ( "github.com/IGLOU-EU/go-wildcard/v2" ) +const IDPrefix = "#!::" + var ( InvalidSlashes = errors.New("Invalid number of slashes, must be 1 or 2") InvalidMode = errors.New("Invalid mode") MissingName = errors.New("Missing name after slash") InvalidNamePassword = errors.New("Name/Password is not allowed to contain slashes") + InvalidValue = fmt.Errorf("Invalid value") ) // Mode - client mode @@ -41,9 +44,10 @@ type StreamID struct { mode Mode name string password string + username string } -// Creates new StreamID +// NewStreamID creates new StreamID // returns error if mode is invalid. // id is nil on error func NewStreamID(name string, password string, mode Mode) (*StreamID, error) { @@ -61,39 +65,77 @@ func NewStreamID(name string, password string, mode Mode) (*StreamID, error) { } // FromString reads a streamid from a string. -// The accepted stream id format is //. -// The second slash and password is optional and defaults to empty. +// The accepted old stream id format is //. The second slash and password is +// optional and defaults to empty. The new format is `#!::m=(request|publish),r=(stream-key),u=(username),s=(password)` // If error is not nil then StreamID will remain unchanged. func (s *StreamID) FromString(src string) error { - split := strings.Split(src, "/") - password := "" - if len(split) == 3 { - password = split[2] - } else if len(split) != 2 { - return InvalidSlashes - } - modeStr := split[0] - name := split[1] + if strings.HasPrefix(src, IDPrefix) { + for _, kv := range strings.Split(src[len(IDPrefix):], ",") { + kv2 := strings.SplitN(kv, "=", 2) + if len(kv2) != 2 { + return InvalidValue + } - if len(name) == 0 { - return MissingName + key, value := kv2[0], kv2[1] + + switch key { + case "u": + s.username = value + + case "r": + s.name = value + + case "h": + + case "s": + s.password = value + + case "t": + + case "m": + switch value { + case "request": + s.mode = ModePlay + + case "publish": + s.mode = ModePublish + + default: + return InvalidMode + } + + default: + return fmt.Errorf("unsupported key '%s'", key) + } + } + } else { + split := strings.Split(src, "/") + + s.password = "" + if len(split) == 3 { + s.password = split[2] + } else if len(split) != 2 { + return InvalidSlashes + } + modeStr := split[0] + s.name = split[1] + + switch modeStr { + case "play": + s.mode = ModePlay + case "publish": + s.mode = ModePublish + default: + return InvalidMode + } } - var mode Mode - switch modeStr { - case "play": - mode = ModePlay - case "publish": - mode = ModePublish - default: - return InvalidMode + if len(s.name) == 0 { + return MissingName } s.str = src - s.mode = mode - s.name = name - s.password = password return nil } @@ -140,3 +182,7 @@ func (s StreamID) Name() string { func (s StreamID) Password() string { return s.password } + +func (s StreamID) Username() string { + return s.username +} diff --git a/stream/streamid_test.go b/stream/streamid_test.go index 5d4e204..d9fa88a 100644 --- a/stream/streamid_test.go +++ b/stream/streamid_test.go @@ -1,36 +1,72 @@ package stream import ( + "errors" + "fmt" "testing" ) func TestParseStreamID(t *testing.T) { tests := []struct { - name string - streamID string - wantMode Mode - wantName string - wantPass string - wantErr error + name string + streamID string + wantMode Mode + wantName string + wantPass string + wantUsername string + wantErr error }{ - {"MissingSlash", "s1", 0, "", "", InvalidSlashes}, - {"InvalidName", "play//s1", 0, "", "", MissingName}, - {"InvalidMode", "foobar/bla", 0, "", "", InvalidMode}, - {"InvalidSlash", "foobar/bla//", 0, "", "", InvalidSlashes}, - {"EmptyPass", "play/s1/", ModePlay, "s1", "", nil}, - {"ValidPass", "play/s1/#![äöü", ModePlay, "s1", "#![äöü", nil}, - {"ValidPlay", "play/s1", ModePlay, "s1", "", nil}, - {"ValidPublish", "publish/abcdef", ModePublish, "abcdef", "", nil}, - {"ValidPlaySpace", "play/bla fasel", ModePlay, "bla fasel", "", nil}, + // Old school + {"MissingSlash", "s1", 0, "", "", "", InvalidSlashes}, + {"InvalidName", "play//s1", 0, "", "", "", MissingName}, + {"InvalidMode", "foobar/bla", 0, "", "", "", InvalidMode}, + {"InvalidSlash", "foobar/bla//", 0, "", "", "", InvalidSlashes}, + {"EmptyPass", "play/s1/", ModePlay, "s1", "", "", nil}, + {"ValidPass", "play/s1/#![äöü", ModePlay, "s1", "#![äöü", "", nil}, + {"ValidPlay", "play/s1", ModePlay, "s1", "", "", nil}, + {"ValidPublish", "publish/abcdef", ModePublish, "abcdef", "", "", nil}, + {"ValidPlaySpace", "play/bla fasel", ModePlay, "bla fasel", "", "", nil}, + // New hotness - Bad + {"NewInvalidPubEmptyName", "#!::m=publish", ModePublish, "", "", "", MissingName}, + {"NewInvalidPlayEmptyName", "#!::m=request", ModePlay, "", "", "", MissingName}, + {"NewInvalidPubBadKey", "#!::m=publish,y=bar", ModePublish, "", "", "", fmt.Errorf("unsupported key '%s'", "y")}, + {"NewInvalidPlayBadKey", "#!::m=request,x=foo", ModePlay, "", "", "", fmt.Errorf("unsupported key '%s'", "x")}, + {"NewInvalidPubNoEquals", "#!::m=publish,r", ModePublish, "abc", "", "", InvalidValue}, + {"NewInvalidPlayNoEquals", "#!::m=request,r", ModePlay, "abc", "", "", InvalidValue}, + {"NewInvalidPubNoValue", "#!::m=publish,r=", ModePublish, "abc", "", "", MissingName}, + {"NewInvalidPlayNoValue", "#!::m=request,s=", ModePlay, "abc", "", "", MissingName}, + {"NewInvalidPubBadKey", "#!::m=publish,x=", ModePublish, "abc", "", "", fmt.Errorf("unsupported key '%s'", "x")}, + {"NewInvalidPlayBadKey", "#!::m=request,y=", ModePlay, "abc", "", "", fmt.Errorf("unsupported key '%s'", "y")}, + // New hotness - Standard + {"NewValidNameRequest", "#!::m=publish,r=abc", ModePublish, "abc", "", "", nil}, + {"NewValidPlay", "#!::m=request,r=abc", ModePlay, "abc", "", "", nil}, + {"NewValidNameRequestRev", "#!::r=abc,m=publish", ModePublish, "abc", "", "", nil}, + {"NewValidPlayRev", "#!::r=abc,m=request", ModePlay, "abc", "", "", nil}, + {"NewValidPassPub", "#!::m=publish,r=abc,s=bob", ModePublish, "abc", "bob", "", nil}, + {"NewValidPassPlay", "#!::m=request,r=abc,s=alice", ModePlay, "abc", "alice", "", nil}, + {"NewValidPassPubOrder", "#!::s=bob,m=publish,r=abc123", ModePublish, "abc123", "bob", "", nil}, + {"NewValidPassPlayOrder", "#!::m=request,s=alice,r=def", ModePlay, "def", "alice", "", nil}, + {"NewValidPubUsername", "#!::s=bob,m=publish,r=abc123,u=eve", ModePublish, "abc123", "bob", "eve", nil}, + {"NewValidPlayUsername", "#!::m=request,s=alice,r=def,u=bar", ModePlay, "def", "alice", "bar", nil}, + {"NewValidPubUsernameOrder", "#!::s=bob,m=publish,u=eve,r=abc123", ModePublish, "abc123", "bob", "eve", nil}, + {"NewValidPlayUsernameOrder", "#!::m=request,u=bar,s=alice,r=def", ModePlay, "def", "alice", "bar", nil}, + // New Hotness - Unicode + {"NewValidUnicodePub", "#!::m=publish,r=#![äöü,s=bob", ModePublish, "#![äöü", "bob", "", nil}, + {"NewValidUnicodePlay", "#!::m=request,r=#![äöü,s=alice", ModePlay, "#![äöü", "alice", "", nil}, + {"NewValidUnicodePassPub", "#!::m=publish,s=#![äöü,r=bob", ModePublish, "bob", "#![äöü", "", nil}, + {"NewValidUnicodePassPlay", "#!::m=request,s=#![äöü,r=alice", ModePlay, "alice", "#![äöü", "", nil}, + {"NewValidUnicodeUserPub", "#!::s=bye,m=publish,u=#![äöü,r=art", ModePublish, "art", "bye", "#![äöü", nil}, + {"NewValidUnicodeUserPlay", "#!::m=request,u=#![äöü,r=eve,s=hai", ModePlay, "eve", "hai", "#![äöü", nil}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var streamid StreamID err := streamid.FromString(tt.streamID) - if err != tt.wantErr { - t.Errorf("ParseStreamID() error = %v, wantErr %v", err, tt.wantErr) - } + if err != nil { + if err.Error() != tt.wantErr.Error() { // Only really care about str value for this, otherwise: if !errors.Is(err, tt.wantErr) { + t.Errorf("ParseStreamID() error = %v, wantErr %v", err, tt.wantErr) + } if streamid.String() != "" { t.Error("str should be empty on failed parse") } @@ -48,6 +84,9 @@ func TestParseStreamID(t *testing.T) { if str := streamid.String(); str != tt.streamID { t.Errorf("String() got String = %v, want %v", str, tt.streamID) } + if str := streamid.Username(); str != tt.wantUsername { + t.Errorf("Username() got String = %v, want %v", str, tt.wantUsername) + } }) } } @@ -72,7 +111,7 @@ func TestNewStreamID(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { id, err := NewStreamID(tt.argName, tt.argPassword, tt.argMode) - if err != tt.wantErr { + if !errors.Is(err, tt.wantErr) { t.Errorf("ParseStreamID() error = %v, wantErr %v", err, tt.wantErr) } if err != nil {