diff --git a/pkg/storage/expose/nbd_dispatch.go b/pkg/storage/expose/nbd_dispatch.go index dc5b59c..6e9d429 100644 --- a/pkg/storage/expose/nbd_dispatch.go +++ b/pkg/storage/expose/nbd_dispatch.go @@ -3,6 +3,7 @@ package expose import ( "context" "encoding/binary" + "errors" "fmt" "io" "sync" @@ -13,6 +14,8 @@ import ( "github.com/loopholelabs/silo/pkg/storage" ) +var ErrShuttingDown = errors.New("shutting down. Cannot serve any new requests.") + const dispatchBufferSize = 4 * 1024 * 1024 /** @@ -76,6 +79,8 @@ type Dispatch struct { prov storage.Provider fatal chan error pendingResponses sync.WaitGroup + shuttingDown bool + shuttingDownLock sync.Mutex metricPacketsIn uint64 metricPacketsOut uint64 metricReadAt uint64 @@ -140,6 +145,11 @@ func (d *Dispatch) GetMetrics() *DispatchMetrics { } func (d *Dispatch) Wait() { + d.shuttingDownLock.Lock() + d.shuttingDown = true + defer d.shuttingDownLock.Unlock() + // Stop accepting any new requests... + if d.logger != nil { d.logger.Trace().Str("device", d.dev).Msg("nbd waiting for pending responses") } @@ -342,16 +352,22 @@ func (d *Dispatch) cmdRead(cmdHandle uint64, cmdFrom uint64, cmdLength uint32) e case e = <-errchan: } - errorValue := uint32(0) if e != nil { - errorValue = 1 - data = make([]byte, 0) // If there was an error, don't send data + return d.writeResponse(1, handle, []byte{}) } - return d.writeResponse(errorValue, handle, data) + return d.writeResponse(0, handle, data) } - if d.asyncReads { + d.shuttingDownLock.Lock() + if !d.shuttingDown { d.pendingResponses.Add(1) + } else { + d.shuttingDownLock.Unlock() + return ErrShuttingDown + } + d.shuttingDownLock.Unlock() + + if d.asyncReads { go func() { ctime := time.Now() err := performRead(cmdHandle, cmdFrom, cmdLength) @@ -368,7 +384,6 @@ func (d *Dispatch) cmdRead(cmdHandle uint64, cmdFrom uint64, cmdLength uint32) e d.pendingResponses.Done() }() } else { - d.pendingResponses.Add(1) ctime := time.Now() err := performRead(cmdHandle, cmdFrom, cmdLength) if err == nil { @@ -418,8 +433,16 @@ func (d *Dispatch) cmdWrite(cmdHandle uint64, cmdFrom uint64, cmdLength uint32, return d.writeResponse(errorValue, handle, []byte{}) } - if d.asyncWrites { + d.shuttingDownLock.Lock() + if !d.shuttingDown { d.pendingResponses.Add(1) + } else { + d.shuttingDownLock.Unlock() + return ErrShuttingDown + } + d.shuttingDownLock.Unlock() + + if d.asyncWrites { go func() { ctime := time.Now() err := performWrite(cmdHandle, cmdFrom, cmdLength, cmdData) @@ -436,7 +459,6 @@ func (d *Dispatch) cmdWrite(cmdHandle uint64, cmdFrom uint64, cmdLength uint32, d.pendingResponses.Done() }() } else { - d.pendingResponses.Add(1) ctime := time.Now() err := performWrite(cmdHandle, cmdFrom, cmdLength, cmdData) if err == nil {