Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make ExtensionManagerServer.Shutdown idempotent #117

Merged
merged 4 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,12 +255,16 @@ func (s *ExtensionManagerServer) Run() error {
for {
time.Sleep(s.pingInterval)

s.mutex.Lock()
serverClient := s.serverClient
s.mutex.Unlock()
Comment on lines +259 to +261
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I understand why this helps. I think it's so that this goroutine can keep running safely while another goroutine calls Shutdown (which sets `s.serverClient` to nil), but I don't fully see it.

Is this about a narrow race where Shutdown happens between `if serverClient == nil` and `s.serverClient.Ping()`?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's narrow.

Here's the race condition we get if we don't add the locks (it's between the read `serverClient := s.serverClient` in Run and the write `s.serverClient = nil` in Shutdown):

=== RUN   TestNoDeadlockOnError
==================
WARNING: DATA RACE
Write at 0x00c00010f610 by goroutine 14:
  github.com/osquery/osquery-go.(*ExtensionManagerServer).Shutdown()
      /Users/luk/fleetdm/git/forks/osquery-go/server.go:352 +0x3d2
  github.com/osquery/osquery-go.(*ExtensionManagerServer).Run()
      /Users/luk/fleetdm/git/forks/osquery-go/server.go:279 +0x1d0
  github.com/osquery/osquery-go.TestNoDeadlockOnError()
      /Users/luk/fleetdm/git/forks/osquery-go/server_test.go:57 +0x55c
  testing.tRunner()
      /Users/luk/go/src/testing/testing.go:1595 +0x238
  testing.(*T).Run.func1()
      /Users/luk/go/src/testing/testing.go:1648 +0x44

Previous read at 0x00c00010f610 by goroutine 16:
  github.com/osquery/osquery-go.(*ExtensionManagerServer).Run.func2()
      /Users/luk/fleetdm/git/forks/osquery-go/server.go:259 +0x64

Goroutine 14 (running) created at:
  testing.(*T).Run()
      /Users/luk/go/src/testing/testing.go:1648 +0x82a
  testing.runTests.func1()
      /Users/luk/go/src/testing/testing.go:2054 +0x84
  testing.tRunner()
      /Users/luk/go/src/testing/testing.go:1595 +0x238
  testing.runTests()
      /Users/luk/go/src/testing/testing.go:2052 +0x896
  testing.(*M).Run()
      /Users/luk/go/src/testing/testing.go:1925 +0xb57
  main.main()
      _testmain.go:69 +0x2bd

Goroutine 16 (running) created at:
  github.com/osquery/osquery-go.(*ExtensionManagerServer).Run()
      /Users/luk/fleetdm/git/forks/osquery-go/server.go:255 +0x18c
  github.com/osquery/osquery-go.TestNoDeadlockOnError()
      /Users/luk/fleetdm/git/forks/osquery-go/server_test.go:57 +0x55c
  testing.tRunner()
      /Users/luk/go/src/testing/testing.go:1595 +0x238
  testing.(*T).Run.func1()
      /Users/luk/go/src/testing/testing.go:1648 +0x44
==================
    testing.go:1465: race detected during execution of test
--- FAIL: TestNoDeadlockOnError (0.00s)


// can't ping if s.Shutdown has already happened
if s.serverClient == nil {
if serverClient == nil {
break
}

status, err := s.serverClient.Ping()
status, err := serverClient.Ping()
if err != nil {
errc <- errors.Wrap(err, "extension ping failed")
break
Expand Down Expand Up @@ -315,12 +319,16 @@ func (s *ExtensionManagerServer) Call(ctx context.Context, registry string, item
func (s *ExtensionManagerServer) Shutdown(ctx context.Context) (err error) {
s.mutex.Lock()
defer s.mutex.Unlock()
stat, err := s.serverClient.DeregisterExtension(s.uuid)
err = errors.Wrap(err, "deregistering extension")
if err == nil && stat.Code != 0 {
err = errors.Errorf("status %d deregistering extension: %s", stat.Code, stat.Message)

if s.serverClient != nil {
var stat *osquery.ExtensionStatus
stat, err = s.serverClient.DeregisterExtension(s.uuid)
err = errors.Wrap(err, "deregistering extension")
if err == nil && stat.Code != 0 {
err = errors.Errorf("status %d deregistering extension: %s", stat.Code, stat.Message)
}
}
s.serverClient.Close()
Copy link

@sharon-fdm sharon-fdm Nov 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we need to close the client here (`s.serverClient.Close()`)?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or is it closed here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Huh. It looks like I should have removed this line in #112.


if s.server != nil {
server := s.server
s.server = nil
Expand Down
79 changes: 52 additions & 27 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (
// Verify that an error in server.Start will return an error instead of deadlock.
func TestNoDeadlockOnError(t *testing.T) {
registry := make(map[string](map[string]OsqueryPlugin))
for reg, _ := range validRegistryNames {
for reg := range validRegistryNames {
registry[reg] = make(map[string]OsqueryPlugin)
}
mut := sync.Mutex{}
Expand All @@ -42,8 +42,9 @@ func TestNoDeadlockOnError(t *testing.T) {
CloseFunc: func() {},
}
server := &ExtensionManagerServer{
serverClient: mock,
registry: registry,
serverClient: mock,
registry: registry,
serverClientShouldShutdown: true,
}

log := func(ctx context.Context, typ logger.LogType, logText string) error {
Expand All @@ -62,8 +63,12 @@ func TestNoDeadlockOnError(t *testing.T) {
// Ensure that the extension server will shutdown and return if the osquery
// instance it is talking to stops responding to pings.
func TestShutdownWhenPingFails(t *testing.T) {
tempPath, err := ioutil.TempFile("", "")
require.Nil(t, err)
defer os.Remove(tempPath.Name())

registry := make(map[string](map[string]OsqueryPlugin))
for reg, _ := range validRegistryNames {
for reg := range validRegistryNames {
registry[reg] = make(map[string]OsqueryPlugin)
}
mock := &MockExtensionManager{
Expand All @@ -80,11 +85,14 @@ func TestShutdownWhenPingFails(t *testing.T) {
CloseFunc: func() {},
}
server := &ExtensionManagerServer{
serverClient: mock,
registry: registry,
serverClient: mock,
registry: registry,
serverClientShouldShutdown: true,
pingInterval: 1 * time.Second,
sockPath: tempPath.Name(),
}

err := server.Run()
err = server.Run()
assert.Error(t, err)
assert.Contains(t, err.Error(), "broken pipe")
assert.True(t, mock.DeRegisterExtensionFuncInvoked)
Expand All @@ -104,6 +112,7 @@ func TestShutdownDeadlock(t *testing.T) {
})
}
}

func testShutdownDeadlock(t *testing.T) {
tempPath, err := ioutil.TempFile("", "")
require.Nil(t, err)
Expand All @@ -119,7 +128,7 @@ func testShutdownDeadlock(t *testing.T) {
},
CloseFunc: func() {},
}
server := ExtensionManagerServer{serverClient: mock, sockPath: tempPath.Name()}
server := ExtensionManagerServer{serverClient: mock, sockPath: tempPath.Name(), serverClientShouldShutdown: true}

wait := sync.WaitGroup{}

Expand Down Expand Up @@ -177,9 +186,13 @@ func testShutdownDeadlock(t *testing.T) {
}

func TestShutdownBasic(t *testing.T) {
tempPath, err := ioutil.TempFile("", "")
require.Nil(t, err)
defer os.Remove(tempPath.Name())
dir := t.TempDir()

tempPath := func() string {
tmp, err := os.CreateTemp(dir, "")
require.NoError(t, err)
return tmp.Name()
}

retUUID := osquery.ExtensionRouteUUID(0)
mock := &MockExtensionManager{
Expand All @@ -191,26 +204,38 @@ func TestShutdownBasic(t *testing.T) {
},
CloseFunc: func() {},
}
server := ExtensionManagerServer{serverClient: mock, sockPath: tempPath.Name()}

completed := make(chan struct{})
go func() {
err := server.Start()
for _, server := range []*ExtensionManagerServer{
// Create the extension manager without using NewExtensionManagerServer.
{serverClient: mock, sockPath: tempPath()},
// Create the extension manager using ExtensionManagerServer.
{serverClient: mock, sockPath: tempPath(), serverClientShouldShutdown: true},
} {
completed := make(chan struct{})
go func() {
err := server.Start()
require.NoError(t, err)
close(completed)
}()

server.waitStarted()

err := server.Shutdown(context.Background())
require.NoError(t, err)
close(completed)
}()

server.waitStarted()
err = server.Shutdown(context.Background())
require.NoError(t, err)
// Test that server.Shutdown is idempotent.
err = server.Shutdown(context.Background())
require.NoError(t, err)

// Either indicate successful shutdown, or fatal the test because it
// hung
select {
case <-completed:
// Success. Do nothing.
case <-time.After(5 * time.Second):
t.Fatal("hung on shutdown")
}

// Either indicate successful shutdown, or fatal the test because it
// hung
select {
case <-completed:
// Success. Do nothing.
case <-time.After(5 * time.Second):
t.Fatal("hung on shutdown")
}
}

Expand Down