From e3cde127e7242960aaafb72e91d6c1bed1612582 Mon Sep 17 00:00:00 2001 From: Lucas Manuel Rodriguez Date: Wed, 8 Nov 2023 13:35:17 -0300 Subject: [PATCH] Make `ExtensionManagerServer.Shutdown` idempotent (#117) * Make Shutdown idempotent * Protect access to s.serverClient * Add sleep to make retry effective --- server.go | 21 ++++++++---- server_test.go | 90 +++++++++++++++++++++++++++++++++----------------- 2 files changed, 74 insertions(+), 37 deletions(-) diff --git a/server.go b/server.go index 513a418..29d0007 100644 --- a/server.go +++ b/server.go @@ -256,12 +256,16 @@ func (s *ExtensionManagerServer) Run() error { for { time.Sleep(s.pingInterval) + s.mutex.Lock() + serverClient := s.serverClient + s.mutex.Unlock() + // can't ping if s.Shutdown has already happened - if s.serverClient == nil { + if serverClient == nil { break } - status, err := s.serverClient.Ping() + status, err := serverClient.Ping() if err != nil { errc <- errors.Wrap(err, "extension ping failed") break @@ -323,12 +327,15 @@ func (s *ExtensionManagerServer) Shutdown(ctx context.Context) (err error) { s.mutex.Lock() defer s.mutex.Unlock() - stat, err := s.serverClient.DeregisterExtension(s.uuid) - err = errors.Wrap(err, "deregistering extension") - if err == nil && stat.Code != 0 { - err = errors.Errorf("status %d deregistering extension: %s", stat.Code, stat.Message) + if s.serverClient != nil { + var stat *osquery.ExtensionStatus + stat, err = s.serverClient.DeregisterExtension(s.uuid) + err = errors.Wrap(err, "deregistering extension") + if err == nil && stat.Code != 0 { + err = errors.Errorf("status %d deregistering extension: %s", stat.Code, stat.Message) + } } - s.serverClient.Close() + if s.server != nil { server := s.server s.server = nil diff --git a/server_test.go b/server_test.go index 2693d4e..cebc3d6 100644 --- a/server_test.go +++ b/server_test.go @@ -24,7 +24,7 @@ import ( // Verify that an error in server.Start will return an error instead of deadlock. func TestNoDeadlockOnError(t *testing.T) { registry := make(map[string](map[string]OsqueryPlugin)) - for reg, _ := range validRegistryNames { + for reg := range validRegistryNames { registry[reg] = make(map[string]OsqueryPlugin) } mut := sync.Mutex{} @@ -43,8 +43,9 @@ func TestNoDeadlockOnError(t *testing.T) { CloseFunc: func() {}, } server := &ExtensionManagerServer{ - serverClient: mock, - registry: registry, + serverClient: mock, + registry: registry, + serverClientShouldShutdown: true, } log := func(ctx context.Context, typ logger.LogType, logText string) error { @@ -63,8 +64,12 @@ func TestNoDeadlockOnError(t *testing.T) { // Ensure that the extension server will shutdown and return if the osquery // instance it is talking to stops responding to pings. func TestShutdownWhenPingFails(t *testing.T) { + tempPath, err := ioutil.TempFile("", "") + require.Nil(t, err) + defer os.Remove(tempPath.Name()) + registry := make(map[string](map[string]OsqueryPlugin)) - for reg, _ := range validRegistryNames { + for reg := range validRegistryNames { registry[reg] = make(map[string]OsqueryPlugin) } mock := &MockExtensionManager{ @@ -81,11 +86,14 @@ func TestShutdownWhenPingFails(t *testing.T) { CloseFunc: func() {}, } server := &ExtensionManagerServer{ - serverClient: mock, - registry: registry, + serverClient: mock, + registry: registry, + serverClientShouldShutdown: true, + pingInterval: 1 * time.Second, + sockPath: tempPath.Name(), } - err := server.Run() + err = server.Run() assert.Error(t, err) assert.Contains(t, err.Error(), "broken pipe") assert.True(t, mock.DeRegisterExtensionFuncInvoked) @@ -106,6 +114,7 @@ func TestShutdownDeadlock(t *testing.T) { }) } } + func testShutdownDeadlock(t *testing.T, uuid int) { tempPath, err := ioutil.TempFile("", "") require.Nil(t, err) @@ -122,9 +131,10 @@ func testShutdownDeadlock(t *testing.T, uuid int) { CloseFunc: func() {}, } server := ExtensionManagerServer{ - serverClient: mock, - sockPath: tempPath.Name(), - timeout: defaultTimeout, + serverClient: mock, + sockPath: tempPath.Name(), + timeout: defaultTimeout, + serverClientShouldShutdown: true, } var wait sync.WaitGroup @@ -152,8 +162,12 @@ func testShutdownDeadlock(t *testing.T, uuid int) { for !opened && attempt < 10 { transport = thrift.NewTSocketFromAddrTimeout(addr, timeout, timeout) err = transport.Open() - opened = err == nil attempt++ + if err != nil { + time.Sleep(1 * time.Second) + } else { + opened = true + } } require.NoError(t, err) client := osquery.NewExtensionManagerClientFactory(transport, @@ -193,9 +207,13 @@ func testShutdownDeadlock(t *testing.T, uuid int) { } func TestShutdownBasic(t *testing.T) { - tempPath, err := ioutil.TempFile("", "") - require.Nil(t, err) - defer os.Remove(tempPath.Name()) + dir := t.TempDir() + + tempPath := func() string { + tmp, err := os.CreateTemp(dir, "") + require.NoError(t, err) + return tmp.Name() + } retUUID := osquery.ExtensionRouteUUID(0) mock := &MockExtensionManager{ @@ -207,26 +225,38 @@ func TestShutdownBasic(t *testing.T) { }, CloseFunc: func() {}, } - server := ExtensionManagerServer{serverClient: mock, sockPath: tempPath.Name()} - completed := make(chan struct{}) - go func() { - err := server.Start() + for _, server := range []*ExtensionManagerServer{ + // Create the extension manager without using NewExtensionManagerServer. + {serverClient: mock, sockPath: tempPath()}, + // Create the extension manager using ExtensionManagerServer. + {serverClient: mock, sockPath: tempPath(), serverClientShouldShutdown: true}, + } { + completed := make(chan struct{}) + go func() { + err := server.Start() + require.NoError(t, err) + close(completed) + }() + + server.waitStarted() + + err := server.Shutdown(context.Background()) require.NoError(t, err) - close(completed) - }() - server.waitStarted() - err = server.Shutdown(context.Background()) - require.NoError(t, err) + // Test that server.Shutdown is idempotent. + err = server.Shutdown(context.Background()) + require.NoError(t, err) + + // Either indicate successful shutdown, or fatal the test because it + // hung + select { + case <-completed: + // Success. Do nothing. + case <-time.After(5 * time.Second): + t.Fatal("hung on shutdown") + } - // Either indicate successful shutdown, or fatal the test because it - // hung - select { - case <-completed: - // Success. Do nothing. - case <-time.After(5 * time.Second): - t.Fatal("hung on shutdown") } }