From a15a7ab4300c9b7f077d5b48d3cf1e6375ee8eb9 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Fri, 3 May 2024 18:21:41 +0200 Subject: [PATCH] add ServerConn.ValidateCredentials() --- README.md | 1 + .../main.go | 44 +++-- examples/server-auth/main.go | 177 ++++++++++++++++++ pkg/liberrors/server.go | 16 ++ server.go | 4 + server_conn.go | 82 +++++++- server_test.go | 91 +++++++-- 7 files changed, 373 insertions(+), 42 deletions(-) create mode 100644 examples/server-auth/main.go diff --git a/README.md b/README.md index c03f7cfc..bfbd4af4 100644 --- a/README.md +++ b/README.md @@ -92,6 +92,7 @@ Features: * [client-record-format-vp9](examples/client-record-format-vp9/main.go) * [server](examples/server/main.go) * [server-tls](examples/server-tls/main.go) +* [server-auth](examples/server-auth/main.go) * [server-h264-save-to-disk](examples/server-h264-save-to-disk/main.go) * [proxy](examples/proxy/main.go) diff --git a/examples/client-record-format-mjpeg-from-image/main.go b/examples/client-record-format-mjpeg-from-image/main.go index c22e5ffe..3879bb0a 100644 --- a/examples/client-record-format-mjpeg-from-image/main.go +++ b/examples/client-record-format-mjpeg-from-image/main.go @@ -20,6 +20,28 @@ import ( // 4. generate RTP/M-JPEG packets from the JPEG image // 5. write packets to the server +func createRandomImage(i int) *image.RGBA { + img := image.NewRGBA(image.Rect(0, 0, 640, 480)) + + var cl color.RGBA + switch i { + case 0: + cl = color.RGBA{255, 0, 0, 0} + case 1: + cl = color.RGBA{0, 255, 0, 0} + case 2: + cl = color.RGBA{0, 0, 255, 0} + } + + for y := 0; y < img.Rect.Dy(); y++ { + for x := 0; x < img.Rect.Dx(); x++ { + img.SetRGBA(x, y, cl) + } + } + + return img +} + func main() { // create a description that contains a M-JPEG format forma := &format.MJPEG{} @@ -59,29 +81,13 @@ func main() { i := 0 for range ticker.C { - // create a RGBA image - image := image.NewRGBA(image.Rect(0, 0, 640, 480)) - - // fill the image - var cl color.RGBA - switch i { - case 0: - cl = color.RGBA{255, 0, 0, 0} - case 1: - cl = color.RGBA{0, 255, 0, 0} - case 2: - cl = color.RGBA{0, 0, 255, 0} - } - for y := 0; y < image.Rect.Dy(); y++ { - for x := 0; x < image.Rect.Dx(); x++ { - image.SetRGBA(x, y, cl) - } - } + // create a random image + img := createRandomImage(i) i = (i + 1) % 3 // encode the image with JPEG var buf bytes.Buffer - err := jpeg.Encode(&buf, image, &jpeg.Options{Quality: 80}) + err := jpeg.Encode(&buf, img, &jpeg.Options{Quality: 80}) if err != nil { panic(err) } diff --git a/examples/server-auth/main.go b/examples/server-auth/main.go new file mode 100644 index 00000000..132e9a7c --- /dev/null +++ b/examples/server-auth/main.go @@ -0,0 +1,177 @@ +package main + +import ( + "log" + "sync" + + "github.com/pion/rtp" + + "github.com/bluenviron/gortsplib/v4" + "github.com/bluenviron/gortsplib/v4/pkg/base" + "github.com/bluenviron/gortsplib/v4/pkg/description" + "github.com/bluenviron/gortsplib/v4/pkg/format" +) + +// This example shows how to +// 1. create a RTSP server which accepts plain connections +// 2. allow a single client to authenticate and publish a stream with TCP or UDP +// 3. allow multiple clients to authenticate and read that stream with TCP, UDP or UDP-multicast + +const ( + readUser = "readuser" + readPass = "readpass" + publishUser = "publishuser" + publishPass = "publishpass" +) + +type serverHandler struct { + s *gortsplib.Server + mutex sync.Mutex + stream *gortsplib.ServerStream + publisher *gortsplib.ServerSession +} + +// called when a connection is opened. +func (sh *serverHandler) OnConnOpen(ctx *gortsplib.ServerHandlerOnConnOpenCtx) { + log.Printf("conn opened") +} + +// called when a connection is closed. +func (sh *serverHandler) OnConnClose(ctx *gortsplib.ServerHandlerOnConnCloseCtx) { + log.Printf("conn closed (%v)", ctx.Error) +} + +// called when a session is opened. +func (sh *serverHandler) OnSessionOpen(ctx *gortsplib.ServerHandlerOnSessionOpenCtx) { + log.Printf("session opened") +} + +// called when a session is closed. +func (sh *serverHandler) OnSessionClose(ctx *gortsplib.ServerHandlerOnSessionCloseCtx) { + log.Printf("session closed") + + sh.mutex.Lock() + defer sh.mutex.Unlock() + + // if the session is the publisher, + // close the stream and disconnect any reader. + if sh.stream != nil && ctx.Session == sh.publisher { + sh.stream.Close() + sh.stream = nil + } +} + +// called when receiving a DESCRIBE request. +func (sh *serverHandler) OnDescribe(ctx *gortsplib.ServerHandlerOnDescribeCtx) (*base.Response, *gortsplib.ServerStream, error) { + log.Printf("describe request") + + res, err := ctx.Conn.ValidateCredentials(ctx.Request, readUser, readPass, nil, nil) + if err != nil { + return res, nil, err + } + + sh.mutex.Lock() + defer sh.mutex.Unlock() + + // no one is publishing yet + if sh.stream == nil { + return &base.Response{ + StatusCode: base.StatusNotFound, + }, nil, nil + } + + // send medias that are being published to the client + return &base.Response{ + StatusCode: base.StatusOK, + }, sh.stream, nil +} + +// called when receiving an ANNOUNCE request. +func (sh *serverHandler) OnAnnounce(ctx *gortsplib.ServerHandlerOnAnnounceCtx) (*base.Response, error) { + log.Printf("announce request") + + res, err := ctx.Conn.ValidateCredentials(ctx.Request, publishUser, publishPass, nil, nil) + if err != nil { + return res, err + } + + sh.mutex.Lock() + defer sh.mutex.Unlock() + + // disconnect existing publisher + if sh.stream != nil { + sh.stream.Close() + sh.publisher.Close() + } + + // create the stream and save the publisher + sh.stream = gortsplib.NewServerStream(sh.s, ctx.Description) + sh.publisher = ctx.Session + + return &base.Response{ + StatusCode: base.StatusOK, + }, nil +} + +// called when receiving a SETUP request. +func (sh *serverHandler) OnSetup(ctx *gortsplib.ServerHandlerOnSetupCtx) (*base.Response, *gortsplib.ServerStream, error) { + log.Printf("setup request") + + res, err := ctx.Conn.ValidateCredentials(ctx.Request, readUser, readPass, nil, nil) + if err != nil { + return res, nil, err + } + + // no one is publishing yet + if sh.stream == nil { + return &base.Response{ + StatusCode: base.StatusNotFound, + }, nil, nil + } + + return &base.Response{ + StatusCode: base.StatusOK, + }, sh.stream, nil +} + +// called when receiving a PLAY request. +func (sh *serverHandler) OnPlay(ctx *gortsplib.ServerHandlerOnPlayCtx) (*base.Response, error) { + log.Printf("play request") + + return &base.Response{ + StatusCode: base.StatusOK, + }, nil +} + +// called when receiving a RECORD request. +func (sh *serverHandler) OnRecord(ctx *gortsplib.ServerHandlerOnRecordCtx) (*base.Response, error) { + log.Printf("record request") + + // called when receiving a RTP packet + ctx.Session.OnPacketRTPAny(func(medi *description.Media, forma format.Format, pkt *rtp.Packet) { + // route the RTP packet to all readers + sh.stream.WritePacketRTP(medi, pkt) + }) + + return &base.Response{ + StatusCode: base.StatusOK, + }, nil +} + +func main() { + // configure the server + h := &serverHandler{} + h.s = &gortsplib.Server{ + Handler: h, + RTSPAddress: ":8554", + UDPRTPAddress: ":8000", + UDPRTCPAddress: ":8001", + MulticastIPRange: "224.1.0.0/16", + MulticastRTPPort: 8002, + MulticastRTCPPort: 8003, + } + + // start server and wait until a fatal error + log.Printf("server is ready") + panic(h.s.StartAndWait()) +} diff --git a/pkg/liberrors/server.go b/pkg/liberrors/server.go index bcd345d4..e39a3850 100644 --- a/pkg/liberrors/server.go +++ b/pkg/liberrors/server.go @@ -266,3 +266,19 @@ func (ErrServerPathNoSlash) Error() string { "This typically happens when VLC fails a request, and then switches to an " + "unsupported RTSP dialect" } + +// ErrServerFatalAuth is an error that can be returned by a server. +type ErrServerFatalAuth struct{} + +// Error implements the error interface. +func (e ErrServerFatalAuth) Error() string { + return "authentication error" +} + +// ErrServerNonFatalAuth is an error that can be returned by a server. +type ErrServerNonFatalAuth struct{} + +// Error implements the error interface. +func (e ErrServerNonFatalAuth) Error() string { + return "non-fatal authentication error" +} diff --git a/server.go b/server.go index cf4885d4..e79f5748 100644 --- a/server.go +++ b/server.go @@ -115,6 +115,7 @@ type Server struct { receiverReportPeriod time.Duration sessionTimeout time.Duration checkStreamPeriod time.Duration + authRealm string ctx context.Context ctxCancel func() @@ -181,6 +182,9 @@ func (s *Server) Start() error { if s.checkStreamPeriod == 0 { s.checkStreamPeriod = 1 * time.Second } + if s.authRealm == "" { + s.authRealm = "ipcam" + } if s.TLSConfig != nil && s.UDPRTPAddress != "" { return fmt.Errorf("TLS can't be used with UDP") diff --git a/server_conn.go b/server_conn.go index 43494321..ae05dd01 100644 --- a/server_conn.go +++ b/server_conn.go @@ -4,16 +4,19 @@ import ( "context" "crypto/tls" "errors" + "fmt" "net" gourl "net/url" "strconv" "strings" "time" + "github.com/bluenviron/gortsplib/v4/pkg/auth" "github.com/bluenviron/gortsplib/v4/pkg/base" "github.com/bluenviron/gortsplib/v4/pkg/bytecounter" "github.com/bluenviron/gortsplib/v4/pkg/conn" "github.com/bluenviron/gortsplib/v4/pkg/description" + "github.com/bluenviron/gortsplib/v4/pkg/headers" "github.com/bluenviron/gortsplib/v4/pkg/liberrors" ) @@ -65,13 +68,15 @@ type ServerConn struct { s *Server nconn net.Conn - ctx context.Context - ctxCancel func() - userData interface{} - remoteAddr *net.TCPAddr - bc *bytecounter.ByteCounter - conn *conn.Conn - session *ServerSession + ctx context.Context + ctxCancel func() + userData interface{} + remoteAddr *net.TCPAddr + bc *bytecounter.ByteCounter + conn *conn.Conn + session *ServerSession + authNonce string + authFailures int // in chReadRequest chan readReq @@ -132,6 +137,60 @@ func (sc *ServerConn) UserData() interface{} { return sc.userData } +// ValidateCredentials validates credentials provided by the user. +func (sc *ServerConn) ValidateCredentials( + req *base.Request, + expectedUser string, + expectedPass string, + baseURL *base.URL, + methods []headers.AuthMethod, +) (*base.Response, error) { + if sc.authNonce == "" { + n, err := auth.GenerateNonce() + if err != nil { + return &base.Response{ + StatusCode: base.StatusInternalServerError, + }, fmt.Errorf("unable to generate nonce") + } + sc.authNonce = n + } + + err := auth.Validate( + req, + expectedUser, + expectedPass, + baseURL, + methods, + sc.s.authRealm, + sc.authNonce) + if err != nil { + sc.authFailures++ + + // VLC with login prompt sends 4 requests: + // 1) without credentials + // 2) with password but without username + // 3) without credentials + // 4) with password and username + // therefore we must allow up to 3 failures + if sc.authFailures > 3 { + return &base.Response{ + StatusCode: base.StatusUnauthorized, + }, liberrors.ErrServerFatalAuth{} + } + + return &base.Response{ + StatusCode: base.StatusUnauthorized, + Header: base.Header{ + "WWW-Authenticate": auth.GenerateWWWAuthenticate(methods, sc.s.authRealm, sc.authNonce), + }, + }, liberrors.ErrServerNonFatalAuth{} + } + + sc.authFailures = 0 + + return nil, nil +} + func (sc *ServerConn) ip() net.IP { return sc.remoteAddr.IP } @@ -380,13 +439,18 @@ func (sc *ServerConn) handleRequestOuter(req *base.Request) error { res, err := sc.handleRequestInner(req) + var eerr1 liberrors.ErrServerNonFatalAuth + if errors.As(err, &eerr1) { + err = nil + } + if res.Header == nil { res.Header = make(base.Header) } // add cseq - var eerr liberrors.ErrServerCSeqMissing - if !errors.As(err, &eerr) { + var eerr2 liberrors.ErrServerCSeqMissing + if !errors.As(err, &eerr2) { res.Header["CSeq"] = req.Header["CSeq"] } diff --git a/server_test.go b/server_test.go index 8bb768cf..da3c6355 100644 --- a/server_test.go +++ b/server_test.go @@ -1035,20 +1035,77 @@ func TestServerSessionTeardown(t *testing.T) { } func TestServerAuth(t *testing.T) { - nonce, err := auth.GenerateNonce() - require.NoError(t, err) + for _, method := range []string{"all", "basic", "digest"} { + t.Run(method, func(t *testing.T) { + s := &Server{ + Handler: &testServerHandler{ + onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { + var methods []headers.AuthMethod + if method == "basic" { + methods = []headers.AuthMethod{headers.AuthBasic} + } else if method == "digest" { + methods = []headers.AuthMethod{headers.AuthDigestMD5} + } + + res, err := ctx.Conn.ValidateCredentials(ctx.Request, "myuser", "mypass", nil, methods) + if err != nil { + return res, err + } + + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + }, + RTSPAddress: "localhost:8554", + } + + err := s.Start() + require.NoError(t, err) + defer s.Close() + + nconn, err := net.Dial("tcp", "localhost:8554") + require.NoError(t, err) + defer nconn.Close() + conn := conn.NewConn(nconn) + + medias := []*description.Media{testH264Media} + req := base.Request{ + Method: base.Announce, + URL: mustParseURL("rtsp://localhost:8554/teststream"), + Header: base.Header{ + "CSeq": base.HeaderValue{"1"}, + "Content-Type": base.HeaderValue{"application/sdp"}, + }, + Body: mediasToSDP(medias), + } + + res, err := writeReqReadRes(conn, req) + require.NoError(t, err) + require.Equal(t, base.StatusUnauthorized, res.StatusCode) + + sender, err := auth.NewSender(res.Header["WWW-Authenticate"], "myuser", "mypass") + require.NoError(t, err) + + sender.AddAuthorization(&req) + res, err = writeReqReadRes(conn, req) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + }) + } +} + +func TestServerAuthFail(t *testing.T) { s := &Server{ Handler: &testServerHandler{ + onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { + require.EqualError(t, ctx.Error, "authentication error") + }, onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { - err2 := auth.Validate(ctx.Request, "myuser", "mypass", nil, nil, "IPCAM", nonce) - if err2 != nil { - return &base.Response{ //nolint:nilerr - StatusCode: base.StatusUnauthorized, - Header: base.Header{ - "WWW-Authenticate": auth.GenerateWWWAuthenticate(nil, "IPCAM", nonce), - }, - }, nil + res, err := ctx.Conn.ValidateCredentials(ctx.Request, "myuser2", "mypass2", nil, nil) + if err != nil { + return res, err } return &base.Response{ @@ -1059,7 +1116,7 @@ func TestServerAuth(t *testing.T) { RTSPAddress: "localhost:8554", } - err = s.Start() + err := s.Start() require.NoError(t, err) defer s.Close() @@ -1088,7 +1145,13 @@ func TestServerAuth(t *testing.T) { require.NoError(t, err) sender.AddAuthorization(&req) - res, err = writeReqReadRes(conn, req) - require.NoError(t, err) - require.Equal(t, base.StatusOK, res.StatusCode) + + for i := 0; i < 3; i++ { + res, err = writeReqReadRes(conn, req) + require.NoError(t, err) + require.Equal(t, base.StatusUnauthorized, res.StatusCode) + } + + _, err = writeReqReadRes(conn, req) + require.Error(t, err) }