Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make ExtensionManagerServer.Shutdown idempotent #117

Merged
merged 4 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,12 +315,16 @@ func (s *ExtensionManagerServer) Call(ctx context.Context, registry string, item
func (s *ExtensionManagerServer) Shutdown(ctx context.Context) (err error) {
s.mutex.Lock()
defer s.mutex.Unlock()
stat, err := s.serverClient.DeregisterExtension(s.uuid)
err = errors.Wrap(err, "deregistering extension")
if err == nil && stat.Code != 0 {
err = errors.Errorf("status %d deregistering extension: %s", stat.Code, stat.Message)

if s.serverClient != nil {
var stat *osquery.ExtensionStatus
stat, err = s.serverClient.DeregisterExtension(s.uuid)
err = errors.Wrap(err, "deregistering extension")
if err == nil && stat.Code != 0 {
err = errors.Errorf("status %d deregistering extension: %s", stat.Code, stat.Message)
}
}
s.serverClient.Close()
Copy link

@sharon-fdm sharon-fdm Nov 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we need to close ?
(s.serverClient.Close())

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or is it closed here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Huh. This line looks like I should have removed it in #112


if s.server != nil {
server := s.server
s.server = nil
Expand Down
79 changes: 52 additions & 27 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (
// Verify that an error in server.Start will return an error instead of deadlock.
func TestNoDeadlockOnError(t *testing.T) {
registry := make(map[string](map[string]OsqueryPlugin))
for reg, _ := range validRegistryNames {
for reg := range validRegistryNames {
registry[reg] = make(map[string]OsqueryPlugin)
}
mut := sync.Mutex{}
Expand All @@ -42,8 +42,9 @@ func TestNoDeadlockOnError(t *testing.T) {
CloseFunc: func() {},
}
server := &ExtensionManagerServer{
serverClient: mock,
registry: registry,
serverClient: mock,
registry: registry,
serverClientShouldShutdown: true,
}

log := func(ctx context.Context, typ logger.LogType, logText string) error {
Expand All @@ -62,8 +63,12 @@ func TestNoDeadlockOnError(t *testing.T) {
// Ensure that the extension server will shutdown and return if the osquery
// instance it is talking to stops responding to pings.
func TestShutdownWhenPingFails(t *testing.T) {
tempPath, err := ioutil.TempFile("", "")
require.Nil(t, err)
defer os.Remove(tempPath.Name())

registry := make(map[string](map[string]OsqueryPlugin))
for reg, _ := range validRegistryNames {
for reg := range validRegistryNames {
registry[reg] = make(map[string]OsqueryPlugin)
}
mock := &MockExtensionManager{
Expand All @@ -80,11 +85,14 @@ func TestShutdownWhenPingFails(t *testing.T) {
CloseFunc: func() {},
}
server := &ExtensionManagerServer{
serverClient: mock,
registry: registry,
serverClient: mock,
registry: registry,
serverClientShouldShutdown: true,
pingInterval: 1 * time.Second,
sockPath: tempPath.Name(),
}

err := server.Run()
err = server.Run()
assert.Error(t, err)
assert.Contains(t, err.Error(), "broken pipe")
assert.True(t, mock.DeRegisterExtensionFuncInvoked)
Expand All @@ -104,6 +112,7 @@ func TestShutdownDeadlock(t *testing.T) {
})
}
}

func testShutdownDeadlock(t *testing.T) {
tempPath, err := ioutil.TempFile("", "")
require.Nil(t, err)
Expand All @@ -119,7 +128,7 @@ func testShutdownDeadlock(t *testing.T) {
},
CloseFunc: func() {},
}
server := ExtensionManagerServer{serverClient: mock, sockPath: tempPath.Name()}
server := ExtensionManagerServer{serverClient: mock, sockPath: tempPath.Name(), serverClientShouldShutdown: true}

wait := sync.WaitGroup{}

Expand Down Expand Up @@ -177,9 +186,13 @@ func testShutdownDeadlock(t *testing.T) {
}

func TestShutdownBasic(t *testing.T) {
tempPath, err := ioutil.TempFile("", "")
require.Nil(t, err)
defer os.Remove(tempPath.Name())
dir := t.TempDir()

tempPath := func() string {
tmp, err := os.CreateTemp(dir, "")
require.NoError(t, err)
return tmp.Name()
}

retUUID := osquery.ExtensionRouteUUID(0)
mock := &MockExtensionManager{
Expand All @@ -191,26 +204,38 @@ func TestShutdownBasic(t *testing.T) {
},
CloseFunc: func() {},
}
server := ExtensionManagerServer{serverClient: mock, sockPath: tempPath.Name()}

completed := make(chan struct{})
go func() {
err := server.Start()
for _, server := range []*ExtensionManagerServer{
// Create the extension manager without using NewExtensionManagerServer.
{serverClient: mock, sockPath: tempPath()},
// Create the extension manager using ExtensionManagerServer.
{serverClient: mock, sockPath: tempPath(), serverClientShouldShutdown: true},
} {
completed := make(chan struct{})
go func() {
err := server.Start()
require.NoError(t, err)
close(completed)
}()

server.waitStarted()

err := server.Shutdown(context.Background())
require.NoError(t, err)
close(completed)
}()

server.waitStarted()
err = server.Shutdown(context.Background())
require.NoError(t, err)
// Test that server.Shutdown is idempotent.
err = server.Shutdown(context.Background())
require.NoError(t, err)

// Either indicate successful shutdown, or fatal the test because it
// hung
select {
case <-completed:
// Success. Do nothing.
case <-time.After(5 * time.Second):
t.Fatal("hung on shutdown")
}

// Either indicate successful shutdown, or fatal the test because it
// hung
select {
case <-completed:
// Success. Do nothing.
case <-time.After(5 * time.Second):
t.Fatal("hung on shutdown")
}
}

Expand Down