refactor the test entirely, again
feelin pretty good about this one
gulducat committed Jan 16, 2025
1 parent a792c0d commit 2a771c1
Showing 2 changed files with 169 additions and 67 deletions.
10 changes: 6 additions & 4 deletions nomad/host_volume_endpoint.go
@@ -269,7 +269,7 @@ func (v *HostVolume) Create(args *structs.HostVolumeCreateRequest, reply *struct
}

// serialize client RPC and raft write per volume ID
index, err := v.serializeCall(vol.ID, func() (uint64, error) {
index, err := v.serializeCall(vol.ID, "create", func() (uint64, error) {
// Attempt to create the volume on the client.
//
// NOTE: creating the volume on the client via the plugin can't be made
@@ -368,7 +368,7 @@ func (v *HostVolume) Register(args *structs.HostVolumeRegisterRequest, reply *st
}

// serialize client RPC and raft write per volume ID
index, err := v.serializeCall(vol.ID, func() (uint64, error) {
index, err := v.serializeCall(vol.ID, "register", func() (uint64, error) {
// Attempt to register the volume on the client.
//
// NOTE: registering the volume on the client via the plugin can't be made
@@ -647,7 +647,7 @@ func (v *HostVolume) Delete(args *structs.HostVolumeDeleteRequest, reply *struct
}

// serialize client RPC and raft write per volume ID
index, err := v.serializeCall(vol.ID, func() (uint64, error) {
index, err := v.serializeCall(vol.ID, "delete", func() (uint64, error) {
if err := v.deleteVolume(vol); err != nil {
return 0, err
}
@@ -692,14 +692,15 @@ func (v *HostVolume) deleteVolume(vol *structs.HostVolume) error {
// Concurrent calls should all run eventually (or timeout, or server shutdown),
// but there is no guarantee that they will run in the order received.
// The passed fn is expected to return a raft index and error.
func (v *HostVolume) serializeCall(volumeID string, fn func() (uint64, error)) (uint64, error) {
func (v *HostVolume) serializeCall(volumeID, op string, fn func() (uint64, error)) (uint64, error) {
timeout := 2 * time.Minute // 2x the client RPC timeout
for {
ctx, done := context.WithTimeout(v.srv.shutdownCtx, timeout)

loaded, occupied := v.volOps.LoadOrStore(volumeID, ctx)

if !occupied {
v.logger.Trace("HostVolume RPC running ", "operation", op)
// run the fn!
index, err := fn()

@@ -712,6 +713,7 @@ func (v *HostVolume) serializeCall(volumeID string, fn func() (uint64, error)) (
}

// another one is running; wait for it to finish.
v.logger.Trace("HostVolume RPC waiting", "operation", op)

// cancel the tentative context; we'll use the one we pulled from
// volOps (set by another RPC call) instead.
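The pattern above serializes all writes for a given volume ID: each call tries to claim a slot in v.volOps with LoadOrStore, the winner runs its fn, and everyone else waits and retries. The hunk as shown hides the cleanup and waiting code, so the sketch below fills those parts in as assumptions (releasing the slot with Delete, waiting on the other call's context) and may not match the real implementation exactly.

// Standalone sketch of the per-volume serialization pattern. The visible
// parts of the hunk (LoadOrStore of a context keyed by volume ID, running fn
// when the slot is free) follow the diff; the hidden cleanup and waiting
// logic below is an assumption.
package hostvol

import (
	"context"
	"errors"
	"sync"
	"time"
)

type volSerializer struct {
	shutdownCtx context.Context
	volOps      sync.Map // volume ID -> context.Context of the in-flight call
}

func (s *volSerializer) serializeCall(volumeID, op string, fn func() (uint64, error)) (uint64, error) {
	timeout := 2 * time.Minute
	for {
		ctx, done := context.WithTimeout(s.shutdownCtx, timeout)

		loaded, occupied := s.volOps.LoadOrStore(volumeID, ctx)
		if !occupied {
			// we claimed the slot for this volume: run the operation, then
			// release the slot and cancel the context to wake any waiters.
			index, err := fn()
			s.volOps.Delete(volumeID)
			done()
			return index, err
		}

		// another call owns the slot; drop our tentative context and wait
		// for the owner's context to finish, then retry the claim.
		done()
		otherCtx := loaded.(context.Context)
		select {
		case <-otherCtx.Done():
			continue
		case <-s.shutdownCtx.Done():
			return 0, errors.New("server shutting down")
		}
	}
}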
226 changes: 163 additions & 63 deletions nomad/host_volume_endpoint_test.go
@@ -11,6 +11,8 @@ import (
"testing"
"time"

"github.com/hashicorp/go-multierror"
"github.com/hashicorp/go-set/v3"
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc/v2"
"github.com/hashicorp/nomad/ci"
"github.com/hashicorp/nomad/client"
@@ -295,7 +297,7 @@ func TestHostVolumeEndpoint_CreateRegisterGetDelete(t *testing.T) {
// win a race with the get RPC goroutine
time.AfterFunc(200*time.Millisecond, func() {
codec := rpcClient(t, srv)
var registerResp structs.HostVolumeCreateResponse
var registerResp structs.HostVolumeRegisterResponse
err := msgpackrpc.CallWithCodec(codec, "HostVolume.Register", registerReq, &registerResp)
must.NoError(t, err)
})
@@ -790,85 +792,159 @@ func TestHostVolumeEndpoint_concurrency(t *testing.T) {
}
wr := structs.WriteRequest{Region: srv.Region()}

// create the volume for us to attempt concurrent operations on
// tell the mock client how it should respond to create calls
c.setCreate(&cstructs.ClientHostVolumeCreateResponse{
VolumeName: "test-vol",
HostPath: "/pretend/path",
}, nil)

// create the volume for us to attempt concurrent operations on
cVol := vol.Copy() // copy because HostPath gets mutated
cVol.Parameters = map[string]string{"created": "initial"}
createReq := &structs.HostVolumeCreateRequest{
Volume: vol.Copy(), // copy because HostPath gets mutated
Volume: cVol,
WriteRequest: wr,
}
var createResp structs.HostVolumeCreateResponse
must.NoError(t, srv.RPC("HostVolume.Create", createReq, &createResp))
volumeID := createResp.Volume.ID // used by delete
got, err := srv.State().HostVolumeByID(nil, vol.Namespace, createResp.Volume.ID, false)
must.NoError(t, err)
must.Eq(t, map[string]string{"created": "initial"}, got.Parameters)

// warning: below here be (concurrency) dragons. if this test fails,
// it is rather difficult to troubleshoot. sorry!

// prepare blocking channels in the mock client.
// sending a struct{} (or closing) the channel will unblock the operation
cCh, rCh, dCh := make(chan struct{}), make(chan struct{}), make(chan struct{})
c.setBlockChan("create", cCh)
c.setBlockChan("register", rCh)
c.setBlockChan("delete", dCh)
// this is critical -- everything needs to use the same volume ID,
// because that's what the serialization is based on.
vol.ID = createResp.Volume.ID

// each RPC call that finishes will put its name here
opCh := make(chan string)
// "create" volume #2 (same vol except for parameters)
cVol2 := vol.Copy()
cVol2.Parameters = map[string]string{"created": "again"}
// "register" volume
rVol := vol.Copy()
rVol.Parameters = map[string]string{"registered": "yup"}

// prepare the mock client to block its calls
must.NoError(t, c.setBlockChan())

// each operation goroutine will put its name in here when it completes,
// so we can wait until the whole RPC completes before checking state.
rpcDoneCh := make(chan string)
rpcDone := func(op string) {
select {
case rpcDoneCh <- op:
case <-time.After(time.Second):
t.Errorf("timed out writing %q to rpcDoneCh", op)
}
}

// start all the RPCs concurrently
go func() {
var funcs multierror.Group
// create
funcs.Go(func() error {
createReq = &structs.HostVolumeCreateRequest{
Volume: vol.Copy(), // copy because HostPath gets mutated
Volume: cVol2,
WriteRequest: wr,
}
createResp = structs.HostVolumeCreateResponse{}
test.NoError(t, srv.RPC("HostVolume.Create", createReq, &createResp),
test.Sprint("Create method should not error"))
opCh <- "create"
}()
go func() {
err := srv.RPC("HostVolume.Create", createReq, &createResp)
rpcDone("create")
return err
})
// register
funcs.Go(func() error {
registerReq := &structs.HostVolumeRegisterRequest{
Volume: vol.Copy(),
Volume: rVol,
WriteRequest: wr,
}
var registerResp structs.HostVolumeRegisterResponse
test.NoError(t, srv.RPC("HostVolume.Register", registerReq, &registerResp),
test.Sprint("Register method should not error"))
opCh <- "register"
}()
go func() {
err := srv.RPC("HostVolume.Register", registerReq, &registerResp)
rpcDone("register")
return err
})
// delete
funcs.Go(func() error {
deleteReq := &structs.HostVolumeDeleteRequest{
VolumeID: volumeID,
VolumeID: vol.ID,
WriteRequest: wr,
}
var deleteResp structs.HostVolumeDeleteResponse
test.NoError(t, srv.RPC("HostVolume.Delete", deleteReq, &deleteResp),
test.Sprint("Delete method should not error"))
opCh <- "delete"
}()
err := srv.RPC("HostVolume.Delete", deleteReq, &deleteResp)
rpcDone("delete")
return err
})

// NOTE: below here, we avoid `must` methods, because a t.Fatal causes all
// the above goroutines to halt with confusing errors.

// keep track of which operations have completed
opSet := set.From([]string{"create", "register", "delete"})

LOOP:
for {
if opSet.Empty() {
break // all done!
}

// helper pulls an operation from the channel, or timeout
pullOp := func() string {
// unblock a client RPC; it will tell us which one it let through.
op, err := c.unblockCurrent()
if err != nil {
t.Errorf("error unblocking client RPC: %v", err)
break
}

if !opSet.Remove(op) {
t.Errorf("mystery unblocked RPC operation: %q", op)
break
}

// make sure the server RPC has totally completed (and written state),
// and that the server RPC matches the unblocked client RPC.
select {
case op := <-opCh:
return op
case <-time.After(200 * time.Millisecond): // generous headroom
return "timeout"
case serverOp := <-rpcDoneCh:
if serverOp != op {
t.Errorf("client RPC says %q; server RPC says %q", op, serverOp)
continue
}
case <-time.After(time.Second):
t.Error("timeout waiting for an RPC to finish")
break LOOP
}

// get the volume to check
got, err := srv.State().HostVolumeByID(nil, vol.Namespace, vol.ID, false)
if err != nil {
t.Errorf("error reading state: %v", err)
break
}
}

must.Eq(t, "timeout", pullOp(), must.Sprint("nothing should be unblocked yet"))
switch op {

close(rCh)
must.Eq(t, "register", pullOp(), must.Sprint("closing register channel should unblock Register"))
case "create":
if got == nil {
t.Error("volume should not be nil after create RPC")
continue
}
test.Eq(t, cVol2.Parameters, got.Parameters)

must.Eq(t, "timeout", pullOp(), must.Sprint("again blocked RPCs should remain so"))
case "register":
if got == nil {
t.Error("volume should not be nil after register RPC")
continue
}
test.Eq(t, rVol.Parameters, got.Parameters)

close(cCh)
must.Eq(t, "create", pullOp(), must.Sprint("closing create channel should unblock Create"))
case "delete":
test.Nil(t, got, test.Sprint(""))
}
}

must.Eq(t, "timeout", pullOp(), must.Sprint("last RPC should still be blocked"))
mErr := funcs.Wait() // ensure all the goroutines are done
test.NoError(t, helper.FlattenMultierror(mErr))

close(dCh)
must.Eq(t, "delete", pullOp(), must.Sprint("closing delete channel should unblock Delete"))
// all of 'em should have happened!
test.Eq(t, []string{}, opSet.Slice())
}
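A quick aside on the multierror.Group used above: each Go() call runs its function in its own goroutine, and Wait() blocks until all of them return, aggregating any non-nil errors. A minimal sketch of that usage, assuming the documented go-multierror Group API:

// Minimal sketch of the go-multierror Group pattern used by the test above:
// Go() launches each function in its own goroutine, Wait() blocks for all of
// them, and ErrorOrNil() yields nil when nothing failed.
package hostvol

import (
	"errors"

	"github.com/hashicorp/go-multierror"
)

func runConcurrently() error {
	var funcs multierror.Group
	funcs.Go(func() error { return nil })
	funcs.Go(func() error { return errors.New("register failed") })
	funcs.Go(func() error { return nil })
	return funcs.Wait().ErrorOrNil() // one error wrapping "register failed"
}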

// mockHostVolumeClient models client RPCs that have side-effects on the
Expand All @@ -879,10 +955,9 @@ type mockHostVolumeClient struct {
nextCreateErr error
nextRegisterErr error
nextDeleteErr error
// blockChans are used to test server->client RPC serialization.
// this is separate from lock because no single method should block while
// holding the lock for the whole client.
blockChans sync.Map
// blockChan is used to test server->client RPC serialization.
// do not block on this channel while the main lock is held.
blockChan chan string
}

// newMockHostVolumeClient configures a RPC-only Nomad test agent and returns a
@@ -931,17 +1006,13 @@ func (v *mockHostVolumeClient) setDelete(errMsg string) {
v.nextDeleteErr = errors.New(errMsg)
}

func (v *mockHostVolumeClient) setBlockChan(operation string, ch chan struct{}) {
v.blockChans.Store(operation, ch)
}

func (v *mockHostVolumeClient) Create(
req *cstructs.ClientHostVolumeCreateRequest,
resp *cstructs.ClientHostVolumeCreateResponse) error {

// block until the concurrency test closes the channel
if ch, ok := v.blockChans.Load("create"); ok {
<-ch.(chan struct{})
// block until something runs unblockCurrent()
if bc := v.getBlockChan(); bc != nil {
bc <- "create"
}

v.lock.Lock()
@@ -957,9 +1028,9 @@ func (v *mockHostVolumeClient) Register(
req *cstructs.ClientHostVolumeRegisterRequest,
resp *cstructs.ClientHostVolumeRegisterResponse) error {

// block until the concurrency test closes the channel
if ch, ok := v.blockChans.Load("register"); ok {
<-ch.(chan struct{})
// block until something runs unblockCurrent()
if bc := v.getBlockChan(); bc != nil {
bc <- "register"
}

v.lock.Lock()
@@ -972,12 +1043,41 @@ func (v *mockHostVolumeClient) Delete(
req *cstructs.ClientHostVolumeDeleteRequest,
resp *cstructs.ClientHostVolumeDeleteResponse) error {

// block until the concurrency test closes the channel
if ch, ok := v.blockChans.Load("delete"); ok {
<-ch.(chan struct{})
// block until something runs unblockCurrent()
if bc := v.getBlockChan(); bc != nil {
bc <- "delete"
}

v.lock.Lock()
defer v.lock.Unlock()
return v.nextDeleteErr
}

func (v *mockHostVolumeClient) setBlockChan() error {
v.lock.Lock()
defer v.lock.Unlock()
if v.blockChan != nil {
return errors.New("blockChan already set")
}
v.blockChan = make(chan string) // no buffer to ensure blockage
return nil
}

func (v *mockHostVolumeClient) getBlockChan() chan string {
v.lock.Lock()
defer v.lock.Unlock()
return v.blockChan
}

func (v *mockHostVolumeClient) unblockCurrent() (string, error) {
bc := v.getBlockChan()
if bc == nil {
return "", errors.New("no blockChan")
}
select {
case current := <-bc:
return current, nil
case <-time.After(time.Second):
return "", errors.New("unblockCurrent timeout")
}
}
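The unbuffered blockChan gives the test a one-at-a-time rendezvous: each mock RPC parks on a send until unblockCurrent performs exactly one receive. The sketch below illustrates that behaviour in isolation; the function and printed output are illustrative only, not part of the test.

// Illustrative sketch of the rendezvous behaviour of the unbuffered blockChan
// above: three goroutines park on sends to the same channel, and each receive
// releases exactly one of them, in whatever order the runtime picks.
package hostvol

import (
	"fmt"
	"time"
)

func demoRendezvous() {
	blockChan := make(chan string) // unbuffered, as in setBlockChan above

	for _, op := range []string{"create", "register", "delete"} {
		op := op
		go func() {
			blockChan <- op // parks here until someone receives
			// the real mock client would fall through to its RPC body here
		}()
	}

	for i := 0; i < 3; i++ {
		select {
		case op := <-blockChan:
			fmt.Println("unblocked:", op) // one operation released per receive
		case <-time.After(time.Second):
			fmt.Println("timed out waiting for a blocked operation")
		}
	}
}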
