Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix TestShutdownDeadlock timeout #119

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ func (s *ExtensionManagerServer) Call(ctx context.Context, registry string, item
func (s *ExtensionManagerServer) Shutdown(ctx context.Context) (err error) {
s.mutex.Lock()
defer s.mutex.Unlock()

stat, err := s.serverClient.DeregisterExtension(s.uuid)
err = errors.Wrap(err, "deregistering extension")
if err == nil && stat.Code != 0 {
Expand All @@ -333,7 +334,7 @@ func (s *ExtensionManagerServer) Shutdown(ctx context.Context) (err error) {
s.server = nil
// Stop the server asynchronously so that the current request
// can complete. Otherwise, this is vulnerable to deadlock if a
// shutdown request is being processed when shutdown is
// shutdown request is being processed when Shutdown is
// explicitly called.
go func() {
server.Stop()
Expand Down
44 changes: 30 additions & 14 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io/ioutil"
"net"
"os"
"runtime/pprof"
"strings"
"sync"
"syscall"
Expand Down Expand Up @@ -98,18 +99,19 @@ const parallelTestShutdownDeadlock = 20

func TestShutdownDeadlock(t *testing.T) {
for i := 0; i < parallelTestShutdownDeadlock; i++ {
i := i
t.Run("", func(t *testing.T) {
t.Parallel()
testShutdownDeadlock(t)
testShutdownDeadlock(t, i)
})
}
}
func testShutdownDeadlock(t *testing.T) {
func testShutdownDeadlock(t *testing.T, uuid int) {
tempPath, err := ioutil.TempFile("", "")
require.Nil(t, err)
defer os.Remove(tempPath.Name())

retUUID := osquery.ExtensionRouteUUID(0)
retUUID := osquery.ExtensionRouteUUID(uuid)
mock := &MockExtensionManager{
RegisterExtensionFunc: func(info *osquery.InternalExtensionInfo, registry osquery.ExtensionRegistry) (*osquery.ExtensionStatus, error) {
return &osquery.ExtensionStatus{Code: 0, UUID: retUUID}, nil
Expand All @@ -119,16 +121,22 @@ func testShutdownDeadlock(t *testing.T) {
},
CloseFunc: func() {},
}
server := ExtensionManagerServer{serverClient: mock, sockPath: tempPath.Name()}
server := ExtensionManagerServer{
serverClient: mock,
sockPath: tempPath.Name(),
timeout: defaultTimeout,
}

wait := sync.WaitGroup{}
var wait sync.WaitGroup

wait.Add(1)
go func() {
// We do not wait for this routine to finish because thrift.TServer.Serve
// sometimes seems to hang after a shutdown. (This test only checks that
// Shutdown itself doesn't hang.)
err := server.Start()
require.Nil(t, err)
wait.Done()
require.NoError(t, err)
}()

// Wait for server to be set up
server.waitStarted()

Expand All @@ -138,10 +146,17 @@ func testShutdownDeadlock(t *testing.T) {
addr, err := net.ResolveUnixAddr("unix", listenPath)
require.Nil(t, err)
timeout := 500 * time.Millisecond
trans := thrift.NewTSocketFromAddrTimeout(addr, timeout, timeout)
err = trans.Open()
require.Nil(t, err)
client := osquery.NewExtensionManagerClientFactory(trans,
opened := false
attempt := 0
var transport *thrift.TSocket
for !opened && attempt < 10 {
transport = thrift.NewTSocketFromAddrTimeout(addr, timeout, timeout)
err = transport.Open()
opened = err == nil
attempt++
}
require.NoError(t, err)
client := osquery.NewExtensionManagerClientFactory(transport,
thrift.NewTBinaryProtocolFactoryDefault())

// Simultaneously call shutdown through a request from the client and
Expand All @@ -156,7 +171,7 @@ func testShutdownDeadlock(t *testing.T) {
go func() {
defer wait.Done()
err = server.Shutdown(context.Background())
require.Nil(t, err)
require.NoError(t, err)
}()

// Track whether shutdown completed
Expand All @@ -171,7 +186,8 @@ func testShutdownDeadlock(t *testing.T) {
select {
case <-completed:
// Success. Do nothing.
case <-time.After(5 * time.Second):
case <-time.After(10 * time.Second):
pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
t.Fatal("hung on shutdown")
}
}
Expand Down