Skip to content

Commit

Permalink
Make ExtensionManagerServer.Shutdown idempotent (#117)
Browse files Browse the repository at this point in the history
* Make Shutdown idempotent

* Protect access to s.serverClient

* Add sleep to make retry effective
  • Loading branch information
lucasmrod authored Nov 8, 2023
1 parent b411f54 commit e3cde12
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 37 deletions.
21 changes: 14 additions & 7 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,12 +256,16 @@ func (s *ExtensionManagerServer) Run() error {
for {
time.Sleep(s.pingInterval)

s.mutex.Lock()
serverClient := s.serverClient
s.mutex.Unlock()

// can't ping if s.Shutdown has already happened
if s.serverClient == nil {
if serverClient == nil {
break
}

status, err := s.serverClient.Ping()
status, err := serverClient.Ping()
if err != nil {
errc <- errors.Wrap(err, "extension ping failed")
break
Expand Down Expand Up @@ -323,12 +327,15 @@ func (s *ExtensionManagerServer) Shutdown(ctx context.Context) (err error) {
s.mutex.Lock()
defer s.mutex.Unlock()

stat, err := s.serverClient.DeregisterExtension(s.uuid)
err = errors.Wrap(err, "deregistering extension")
if err == nil && stat.Code != 0 {
err = errors.Errorf("status %d deregistering extension: %s", stat.Code, stat.Message)
if s.serverClient != nil {
var stat *osquery.ExtensionStatus
stat, err = s.serverClient.DeregisterExtension(s.uuid)
err = errors.Wrap(err, "deregistering extension")
if err == nil && stat.Code != 0 {
err = errors.Errorf("status %d deregistering extension: %s", stat.Code, stat.Message)
}
}
s.serverClient.Close()

if s.server != nil {
server := s.server
s.server = nil
Expand Down
90 changes: 60 additions & 30 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (
// Verify that an error in server.Start returns an error instead of deadlocking.
func TestNoDeadlockOnError(t *testing.T) {
registry := make(map[string](map[string]OsqueryPlugin))
for reg, _ := range validRegistryNames {
for reg := range validRegistryNames {
registry[reg] = make(map[string]OsqueryPlugin)
}
mut := sync.Mutex{}
Expand All @@ -43,8 +43,9 @@ func TestNoDeadlockOnError(t *testing.T) {
CloseFunc: func() {},
}
server := &ExtensionManagerServer{
serverClient: mock,
registry: registry,
serverClient: mock,
registry: registry,
serverClientShouldShutdown: true,
}

log := func(ctx context.Context, typ logger.LogType, logText string) error {
Expand All @@ -63,8 +64,12 @@ func TestNoDeadlockOnError(t *testing.T) {
// Ensure that the extension server will shut down and return if the osquery
// instance it is talking to stops responding to pings.
func TestShutdownWhenPingFails(t *testing.T) {
tempPath, err := ioutil.TempFile("", "")
require.Nil(t, err)
defer os.Remove(tempPath.Name())

registry := make(map[string](map[string]OsqueryPlugin))
for reg, _ := range validRegistryNames {
for reg := range validRegistryNames {
registry[reg] = make(map[string]OsqueryPlugin)
}
mock := &MockExtensionManager{
Expand All @@ -81,11 +86,14 @@ func TestShutdownWhenPingFails(t *testing.T) {
CloseFunc: func() {},
}
server := &ExtensionManagerServer{
serverClient: mock,
registry: registry,
serverClient: mock,
registry: registry,
serverClientShouldShutdown: true,
pingInterval: 1 * time.Second,
sockPath: tempPath.Name(),
}

err := server.Run()
err = server.Run()
assert.Error(t, err)
assert.Contains(t, err.Error(), "broken pipe")
assert.True(t, mock.DeRegisterExtensionFuncInvoked)
Expand All @@ -106,6 +114,7 @@ func TestShutdownDeadlock(t *testing.T) {
})
}
}

func testShutdownDeadlock(t *testing.T, uuid int) {
tempPath, err := ioutil.TempFile("", "")
require.Nil(t, err)
Expand All @@ -122,9 +131,10 @@ func testShutdownDeadlock(t *testing.T, uuid int) {
CloseFunc: func() {},
}
server := ExtensionManagerServer{
serverClient: mock,
sockPath: tempPath.Name(),
timeout: defaultTimeout,
serverClient: mock,
sockPath: tempPath.Name(),
timeout: defaultTimeout,
serverClientShouldShutdown: true,
}

var wait sync.WaitGroup
Expand Down Expand Up @@ -152,8 +162,12 @@ func testShutdownDeadlock(t *testing.T, uuid int) {
for !opened && attempt < 10 {
transport = thrift.NewTSocketFromAddrTimeout(addr, timeout, timeout)
err = transport.Open()
opened = err == nil
attempt++
if err != nil {
time.Sleep(1 * time.Second)
} else {
opened = true
}
}
require.NoError(t, err)
client := osquery.NewExtensionManagerClientFactory(transport,
Expand Down Expand Up @@ -193,9 +207,13 @@ func testShutdownDeadlock(t *testing.T, uuid int) {
}

func TestShutdownBasic(t *testing.T) {
tempPath, err := ioutil.TempFile("", "")
require.Nil(t, err)
defer os.Remove(tempPath.Name())
dir := t.TempDir()

tempPath := func() string {
tmp, err := os.CreateTemp(dir, "")
require.NoError(t, err)
return tmp.Name()
}

retUUID := osquery.ExtensionRouteUUID(0)
mock := &MockExtensionManager{
Expand All @@ -207,26 +225,38 @@ func TestShutdownBasic(t *testing.T) {
},
CloseFunc: func() {},
}
server := ExtensionManagerServer{serverClient: mock, sockPath: tempPath.Name()}

completed := make(chan struct{})
go func() {
err := server.Start()
for _, server := range []*ExtensionManagerServer{
// Create the extension manager without using NewExtensionManagerServer.
{serverClient: mock, sockPath: tempPath()},
// Create the extension manager as if it came from NewExtensionManagerServer (serverClientShouldShutdown set).
{serverClient: mock, sockPath: tempPath(), serverClientShouldShutdown: true},
} {
completed := make(chan struct{})
go func() {
err := server.Start()
require.NoError(t, err)
close(completed)
}()

server.waitStarted()

err := server.Shutdown(context.Background())
require.NoError(t, err)
close(completed)
}()

server.waitStarted()
err = server.Shutdown(context.Background())
require.NoError(t, err)
// Test that server.Shutdown is idempotent.
err = server.Shutdown(context.Background())
require.NoError(t, err)

// Either indicate successful shutdown, or fatal the test because it
// hung
select {
case <-completed:
// Success. Do nothing.
case <-time.After(5 * time.Second):
t.Fatal("hung on shutdown")
}

// Either indicate successful shutdown, or fatal the test because it
// hung
select {
case <-completed:
// Success. Do nothing.
case <-time.After(5 * time.Second):
t.Fatal("hung on shutdown")
}
}

Expand Down

0 comments on commit e3cde12

Please sign in to comment.