diff --git a/nomad/host_volume_endpoint.go b/nomad/host_volume_endpoint.go index 854dce48369..193d141ab59 100644 --- a/nomad/host_volume_endpoint.go +++ b/nomad/host_volume_endpoint.go @@ -269,7 +269,7 @@ func (v *HostVolume) Create(args *structs.HostVolumeCreateRequest, reply *struct } // serialize client RPC and raft write per volume ID - index, err := v.serializeCall(vol.ID, func() (uint64, error) { + index, err := v.serializeCall(vol.ID, "create", func() (uint64, error) { // Attempt to create the volume on the client. // // NOTE: creating the volume on the client via the plugin can't be made @@ -368,7 +368,7 @@ func (v *HostVolume) Register(args *structs.HostVolumeRegisterRequest, reply *st } // serialize client RPC and raft write per volume ID - index, err := v.serializeCall(vol.ID, func() (uint64, error) { + index, err := v.serializeCall(vol.ID, "register", func() (uint64, error) { // Attempt to register the volume on the client. // // NOTE: registering the volume on the client via the plugin can't be made @@ -647,7 +647,7 @@ func (v *HostVolume) Delete(args *structs.HostVolumeDeleteRequest, reply *struct } // serialize client RPC and raft write per volume ID - index, err := v.serializeCall(vol.ID, func() (uint64, error) { + index, err := v.serializeCall(vol.ID, "delete", func() (uint64, error) { if err := v.deleteVolume(vol); err != nil { return 0, err } @@ -692,7 +692,7 @@ func (v *HostVolume) deleteVolume(vol *structs.HostVolume) error { // Concurrent calls should all run eventually (or timeout, or server shutdown), // but there is no guarantee that they will run in the order received. // The passed fn is expected to return a raft index and error. 
-func (v *HostVolume) serializeCall(volumeID string, fn func() (uint64, error)) (uint64, error) { +func (v *HostVolume) serializeCall(volumeID, op string, fn func() (uint64, error)) (uint64, error) { timeout := 2 * time.Minute // 2x the client RPC timeout for { ctx, done := context.WithTimeout(v.srv.shutdownCtx, timeout) @@ -700,6 +700,7 @@ func (v *HostVolume) serializeCall(volumeID string, fn func() (uint64, error)) ( loaded, occupied := v.volOps.LoadOrStore(volumeID, ctx) if !occupied { + v.logger.Trace("HostVolume RPC running ", "operation", op) // run the fn! index, err := fn() @@ -712,6 +713,7 @@ func (v *HostVolume) serializeCall(volumeID string, fn func() (uint64, error)) ( } // another one is running; wait for it to finish. + v.logger.Trace("HostVolume RPC waiting", "operation", op) // cancel the tentative context; we'll use the one we pulled from // volOps (set by another RPC call) instead. diff --git a/nomad/host_volume_endpoint_test.go b/nomad/host_volume_endpoint_test.go index 912bade32de..cccd86027ed 100644 --- a/nomad/host_volume_endpoint_test.go +++ b/nomad/host_volume_endpoint_test.go @@ -11,6 +11,8 @@ import ( "testing" "time" + "github.com/hashicorp/go-multierror" + "github.com/hashicorp/go-set/v3" msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc/v2" "github.com/hashicorp/nomad/ci" "github.com/hashicorp/nomad/client" @@ -295,7 +297,7 @@ func TestHostVolumeEndpoint_CreateRegisterGetDelete(t *testing.T) { // win a race with the get RPC goroutine time.AfterFunc(200*time.Millisecond, func() { codec := rpcClient(t, srv) - var registerResp structs.HostVolumeCreateResponse + var registerResp structs.HostVolumeRegisterResponse err := msgpackrpc.CallWithCodec(codec, "HostVolume.Register", registerReq, &registerResp) must.NoError(t, err) }) @@ -790,85 +792,159 @@ func TestHostVolumeEndpoint_concurrency(t *testing.T) { } wr := structs.WriteRequest{Region: srv.Region()} - // create the volume for us to attempt concurrent operations on + // tell the mock 
client how it should respond to create calls c.setCreate(&cstructs.ClientHostVolumeCreateResponse{ VolumeName: "test-vol", HostPath: "/pretend/path", }, nil) + + // create the volume for us to attempt concurrent operations on + cVol := vol.Copy() // copy because HostPath gets mutated + cVol.Parameters = map[string]string{"created": "initial"} createReq := &structs.HostVolumeCreateRequest{ - Volume: vol.Copy(), // copy because HostPath gets mutated + Volume: cVol, WriteRequest: wr, } var createResp structs.HostVolumeCreateResponse must.NoError(t, srv.RPC("HostVolume.Create", createReq, &createResp)) - volumeID := createResp.Volume.ID // used by delete + got, err := srv.State().HostVolumeByID(nil, vol.Namespace, createResp.Volume.ID, false) + must.NoError(t, err) + must.Eq(t, map[string]string{"created": "initial"}, got.Parameters) + + // warning: below here be (concurrency) dragons. if this test fails, + // it is rather difficult to troubleshoot. sorry! - // prepare blocking channels in the mock client. - // sending a struct{} (or closing) the channel will unblock the operation - cCh, rCh, dCh := make(chan struct{}), make(chan struct{}), make(chan struct{}) - c.setBlockChan("create", cCh) - c.setBlockChan("register", rCh) - c.setBlockChan("delete", dCh) + // this is critical -- everything needs to use the same volume ID, + // because that's what the serialization is based on. + vol.ID = createResp.Volume.ID - // each RPC call that finishes will put its name here - opCh := make(chan string) + // "create" volume #2 (same vol except for parameters) + cVol2 := vol.Copy() + cVol2.Parameters = map[string]string{"created": "again"} + // "register" volume + rVol := vol.Copy() + rVol.Parameters = map[string]string{"registered": "yup"} + + // prepare the mock client to block its calls + must.NoError(t, c.setBlockChan()) + + // each operation goroutine will put its name in here when it completes, + // so we can wait until the whole RPC completes before checking state. 
+ rpcDoneCh := make(chan string) + rpcDone := func(op string) { + select { + case rpcDoneCh <- op: + case <-time.After(time.Second): + t.Errorf("timed out writing %q to rpcDoneCh", op) + } + } // start all the RPCs concurrently - go func() { + var funcs multierror.Group + // create + funcs.Go(func() error { createReq = &structs.HostVolumeCreateRequest{ - Volume: vol.Copy(), // copy because HostPath gets mutated + Volume: cVol2, WriteRequest: wr, } createResp = structs.HostVolumeCreateResponse{} - test.NoError(t, srv.RPC("HostVolume.Create", createReq, &createResp), - test.Sprint("Create method should not error")) - opCh <- "create" - }() - go func() { + err := srv.RPC("HostVolume.Create", createReq, &createResp) + rpcDone("create") + return err + }) + // register + funcs.Go(func() error { registerReq := &structs.HostVolumeRegisterRequest{ - Volume: vol.Copy(), + Volume: rVol, WriteRequest: wr, } var registerResp structs.HostVolumeRegisterResponse - test.NoError(t, srv.RPC("HostVolume.Register", registerReq, &registerResp), - test.Sprint("Register method should not error")) - opCh <- "register" - }() - go func() { + err := srv.RPC("HostVolume.Register", registerReq, &registerResp) + rpcDone("register") + return err + }) + // delete + funcs.Go(func() error { deleteReq := &structs.HostVolumeDeleteRequest{ - VolumeID: volumeID, + VolumeID: vol.ID, WriteRequest: wr, } var deleteResp structs.HostVolumeDeleteResponse - test.NoError(t, srv.RPC("HostVolume.Delete", deleteReq, &deleteResp), - test.Sprint("Delete method should not error")) - opCh <- "delete" - }() + err := srv.RPC("HostVolume.Delete", deleteReq, &deleteResp) + rpcDone("delete") + return err + }) + + // NOTE: below here, we avoid `must` methods, because a t.Fatal causes all + // the above goroutines to halt with confusing errors. + + // keep track of which operations have completed + opSet := set.From([]string{"create", "register", "delete"}) + +LOOP: + for { + if opSet.Empty() { + break // all done! 
+ } - // helper pulls an operation from the channel, or timeout - pullOp := func() string { + // unblock a client RPC; it will tell us which one it let through. + op, err := c.unblockCurrent() + if err != nil { + t.Errorf("error unblocking client RPC: %v", err) + break + } + + if !opSet.Remove(op) { + t.Errorf("mystery unblocked RPC operation: %q", op) + break + } + + // make sure the server RPC has totally completed (and written state), + // and that the server RPC matches the unblocked client RPC. select { - case op := <-opCh: - return op - case <-time.After(200 * time.Millisecond): // generous headroom - return "timeout" + case serverOp := <-rpcDoneCh: + if serverOp != op { + t.Errorf("client RPC says %q; server RPC says %q", op, serverOp) + continue + } + case <-time.After(time.Second): + t.Error("timeout waiting for an RPC to finish") + break LOOP + } + + // get the volume to check + got, err := srv.State().HostVolumeByID(nil, vol.Namespace, vol.ID, false) + if err != nil { + t.Errorf("error reading state: %v", err) + break } - } - must.Eq(t, "timeout", pullOp(), must.Sprint("nothing should be unblocked yet")) + switch op { - close(rCh) - must.Eq(t, "register", pullOp(), must.Sprint("closing register channel should unblock Register")) + case "create": + if got == nil { + t.Error("volume should not be nil after create RPC") + continue + } + test.Eq(t, cVol2.Parameters, got.Parameters) - must.Eq(t, "timeout", pullOp(), must.Sprint("again blocked RPCs should remain so")) + case "register": + if got == nil { + t.Error("volume should not be nil after register RPC") + continue + } + test.Eq(t, rVol.Parameters, got.Parameters) - close(cCh) - must.Eq(t, "create", pullOp(), must.Sprint("closing create channel should unblock Create")) + case "delete": + test.Nil(t, got, test.Sprint("")) + } + } - must.Eq(t, "timeout", pullOp(), must.Sprint("last RPC should still be blocked")) + mErr := funcs.Wait() // ensure all the goroutines are done + test.NoError(t, 
helper.FlattenMultierror(mErr)) - close(dCh) - must.Eq(t, "delete", pullOp(), must.Sprint("closing delete channel should unblock Delete")) + // all of 'em should have happened! + test.Eq(t, []string{}, opSet.Slice()) } // mockHostVolumeClient models client RPCs that have side-effects on the @@ -879,10 +955,9 @@ type mockHostVolumeClient struct { nextCreateErr error nextRegisterErr error nextDeleteErr error - // blockChans are used to test server->client RPC serialization. - // this is separate from lock because no single method should block while - // holding the lock for the whole client. - blockChans sync.Map + // blockChan is used to test server->client RPC serialization. + // do not block on this channel while the main lock is held. + blockChan chan string } // newMockHostVolumeClient configures a RPC-only Nomad test agent and returns a @@ -931,17 +1006,13 @@ func (v *mockHostVolumeClient) setDelete(errMsg string) { v.nextDeleteErr = errors.New(errMsg) } -func (v *mockHostVolumeClient) setBlockChan(operation string, ch chan struct{}) { - v.blockChans.Store(operation, ch) -} - func (v *mockHostVolumeClient) Create( req *cstructs.ClientHostVolumeCreateRequest, resp *cstructs.ClientHostVolumeCreateResponse) error { - // block until the concurrency test closes the channel - if ch, ok := v.blockChans.Load("create"); ok { - <-ch.(chan struct{}) + // block until something runs unblockCurrent() + if bc := v.getBlockChan(); bc != nil { + bc <- "create" } v.lock.Lock() @@ -957,9 +1028,9 @@ func (v *mockHostVolumeClient) Register( req *cstructs.ClientHostVolumeRegisterRequest, resp *cstructs.ClientHostVolumeRegisterResponse) error { - // block until the concurrency test closes the channel - if ch, ok := v.blockChans.Load("register"); ok { - <-ch.(chan struct{}) + // block until something runs unblockCurrent() + if bc := v.getBlockChan(); bc != nil { + bc <- "register" } v.lock.Lock() @@ -972,12 +1043,41 @@ func (v *mockHostVolumeClient) Delete( req 
*cstructs.ClientHostVolumeDeleteRequest, resp *cstructs.ClientHostVolumeDeleteResponse) error { - // block until the concurrency test closes the channel - if ch, ok := v.blockChans.Load("delete"); ok { - <-ch.(chan struct{}) + // block until something runs unblockCurrent() + if bc := v.getBlockChan(); bc != nil { + bc <- "delete" } v.lock.Lock() defer v.lock.Unlock() return v.nextDeleteErr } + +func (v *mockHostVolumeClient) setBlockChan() error { + v.lock.Lock() + defer v.lock.Unlock() + if v.blockChan != nil { + return errors.New("blockChan already set") + } + v.blockChan = make(chan string) // no buffer to ensure blockage + return nil +} + +func (v *mockHostVolumeClient) getBlockChan() chan string { + v.lock.Lock() + defer v.lock.Unlock() + return v.blockChan +} + +func (v *mockHostVolumeClient) unblockCurrent() (string, error) { + bc := v.getBlockChan() + if bc == nil { + return "", errors.New("no blockChan") + } + select { + case current := <-bc: + return current, nil + case <-time.After(time.Second): + return "", errors.New("unblockCurrent timeout") + } +}