Skip to content

Commit

Permalink
hashmail: update to user structured logs
Browse files Browse the repository at this point in the history
This will make querying by stream IDs more uniform (and hence, way
easier).
  • Loading branch information
ellemouton committed Sep 12, 2024
1 parent efc8340 commit 550bede
Showing 1 changed file with 61 additions and 51 deletions.
112 changes: 61 additions & 51 deletions hashmail_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"fmt"
"github.com/btcsuite/btclog"
"io"
"sync"
"time"
Expand Down Expand Up @@ -104,8 +105,8 @@ func (r *readStream) ReadNextMsg(ctx context.Context) ([]byte, error) {

// ReturnStream gives up the read stream by passing it back up through the
// payment stream.
func (r *readStream) ReturnStream() {
log.Debugf("Returning read stream %x", r.parentStream.id[:])
func (r *readStream) ReturnStream(ctx context.Context) {
log.DebugS(ctx, "Returning read stream")
r.parentStream.ReturnReadStream(r)
}

Expand Down Expand Up @@ -193,7 +194,7 @@ type stream struct {
}

// newStream creates a new stream independent of any given stream ID.
func newStream(id streamID, limiter *rate.Limiter,
func newStream(ctx context.Context, id streamID, limiter *rate.Limiter,
equivAuth func(auth *hashmailrpc.CipherBoxAuth) error,
onStale func() error, staleTimeout time.Duration) *stream {

Expand All @@ -210,7 +211,7 @@ func newStream(id streamID, limiter *rate.Limiter,
id: id,
equivAuth: equivAuth,
limiter: limiter,
status: newStreamStatus(onStale, staleTimeout),
status: newStreamStatus(ctx, onStale, staleTimeout),
readBytesChan: make(chan []byte),
readErrChan: make(chan error, 1),
quit: make(chan struct{}),
Expand Down Expand Up @@ -305,8 +306,8 @@ func (s *stream) ReturnWriteStream(w *writeStream) {
// RequestReadStream attempts to request the read stream from the main backing
// stream. If we're unable to obtain it before the timeout, then an error is
// returned.
func (s *stream) RequestReadStream() (*readStream, error) {
log.Tracef("HashMailStream(%x): requesting read stream", s.id[:])
func (s *stream) RequestReadStream(ctx context.Context) (*readStream, error) {
log.TraceS(ctx, "Requested read stream")

select {
case r := <-s.readStreamChan:
Expand All @@ -320,8 +321,8 @@ func (s *stream) RequestReadStream() (*readStream, error) {
// RequestWriteStream attempts to request the read stream from the main backing
// stream. If we're unable to obtain it before the timeout, then an error is
// returned.
func (s *stream) RequestWriteStream() (*writeStream, error) {
log.Tracef("HashMailStream(%x): requesting write stream", s.id[:])
func (s *stream) RequestWriteStream(ctx context.Context) (*writeStream, error) {
log.TraceS(ctx, "Requesting write stream")

select {
case w := <-s.writeStreamChan:
Expand Down Expand Up @@ -389,8 +390,10 @@ func (h *hashMailServer) Stop() {
}

// tearDownStaleStream can be used to tear down a stale mailbox stream.
func (h *hashMailServer) tearDownStaleStream(id streamID) error {
log.Debugf("Tearing down stale HashMail stream: id=%x", id)
func (h *hashMailServer) tearDownStaleStream(ctx context.Context,
id streamID) error {

log.DebugS(ctx, "Tearing down stale HashMail stream")

h.Lock()
defer h.Unlock()
Expand Down Expand Up @@ -428,15 +431,15 @@ func (h *hashMailServer) ValidateStreamAuth(ctx context.Context,
}

// InitStream attempts to initialize a new stream given a valid descriptor.
func (h *hashMailServer) InitStream(
func (h *hashMailServer) InitStream(ctx context.Context,
init *hashmailrpc.CipherBoxAuth) (*hashmailrpc.CipherInitResp, error) {

h.Lock()
defer h.Unlock()

streamID := newStreamID(init.Desc.StreamId)

log.Debugf("Creating new HashMail Stream: %x", streamID)
log.DebugS(ctx, "Creating new HashMail Stream")

// The stream is already active, and we only allow a single session for
// a given stream to exist.
Expand All @@ -452,10 +455,11 @@ func (h *hashMailServer) InitStream(
rate.Every(h.cfg.msgRate), h.cfg.msgBurstAllowance,
)
freshStream := newStream(
streamID, limiter, func(auth *hashmailrpc.CipherBoxAuth) error {
ctx, streamID, limiter,
func(auth *hashmailrpc.CipherBoxAuth) error {
return nil
}, func() error {
return h.tearDownStaleStream(streamID)
return h.tearDownStaleStream(ctx, streamID)
}, h.cfg.staleTimeout,
)

Expand All @@ -470,7 +474,9 @@ func (h *hashMailServer) InitStream(

// LookUpReadStream attempts to loop up a new stream. If the stream is found, then
// the stream is marked as being active. Otherwise, an error is returned.
func (h *hashMailServer) LookUpReadStream(streamID []byte) (*readStream, error) {
func (h *hashMailServer) LookUpReadStream(ctx context.Context,
streamID []byte) (*readStream, error) {

h.RLock()
defer h.RUnlock()

Expand All @@ -479,12 +485,13 @@ func (h *hashMailServer) LookUpReadStream(streamID []byte) (*readStream, error)
return nil, fmt.Errorf("stream not found")
}

return stream.RequestReadStream()
return stream.RequestReadStream(ctx)
}

// LookUpWriteStream attempts to loop up a new stream. If the stream is found,
// then the stream is marked as being active. Otherwise, an error is returned.
func (h *hashMailServer) LookUpWriteStream(streamID []byte) (*writeStream, error) {
func (h *hashMailServer) LookUpWriteStream(ctx context.Context,
streamID []byte) (*writeStream, error) {

h.RLock()
defer h.RUnlock()
Expand All @@ -494,7 +501,7 @@ func (h *hashMailServer) LookUpWriteStream(streamID []byte) (*writeStream, error
return nil, fmt.Errorf("stream not found")
}

return stream.RequestWriteStream()
return stream.RequestWriteStream(ctx)
}

// TearDownStream attempts to tear down a stream which renders both sides of
Expand Down Expand Up @@ -523,8 +530,7 @@ func (h *hashMailServer) TearDownStream(ctx context.Context, streamID []byte,
return err
}

log.Debugf("Tearing down HashMail stream: id=%x, auth=%v",
auth.Desc.StreamId, auth.Auth)
log.DebugS(ctx, "Tearing down HashMail stream", "auth", auth.Auth)

// At this point we know the auth was valid, so we'll tear down the
// stream.
Expand Down Expand Up @@ -568,16 +574,16 @@ func (h *hashMailServer) NewCipherBox(ctx context.Context,
return nil, err
}

log.Debugf("New HashMail stream init: id=%x, auth=%v",
init.Desc.StreamId, init.Auth)
ctxl := btclog.WithCtx(ctx, btclog.Hex("stream_id", init.Desc.StreamId))

if err := h.ValidateStreamAuth(ctx, init); err != nil {
log.Debugf("Stream creation validation failed (id=%x): %v",
init.Desc.StreamId, err)
log.DebugS(ctxl, "New HashMail stream init", "auth", init.Auth)

if err := h.ValidateStreamAuth(ctxl, init); err != nil {
log.DebugS(ctxl, "Stream creation validation failed", err)
return nil, err
}

resp, err := h.InitStream(init)
resp, err := h.InitStream(ctxl, init)
if err != nil {
return nil, err
}
Expand All @@ -597,8 +603,9 @@ func (h *hashMailServer) DelCipherBox(ctx context.Context,
return nil, err
}

log.Debugf("New HashMail stream deletion: id=%x, auth=%v",
auth.Desc.StreamId, auth.Auth)
ctxl := btclog.WithCtx(ctx, btclog.Hex("stream_id", auth.Desc.StreamId))

log.DebugS(ctxl, "New HashMail stream deletion", "auth", auth.Auth)

if err := h.TearDownStream(ctx, auth.Desc.StreamId, auth); err != nil {
return nil, err
Expand All @@ -610,7 +617,7 @@ func (h *hashMailServer) DelCipherBox(ctx context.Context,
// SendStream implements the client streaming call to utilize the write end of
// a stream to send a message to the read end.
func (h *hashMailServer) SendStream(readStream hashmailrpc.HashMail_SendStreamServer) error {
log.Debugf("New HashMail write stream pending...")
log.Debug("New HashMail write stream pending...")

// We'll need to receive the first message in order to determine if
// this stream exists or not
Expand All @@ -621,6 +628,9 @@ func (h *hashMailServer) SendStream(readStream hashmailrpc.HashMail_SendStreamSe
return err
}

ctx := btclog.WithCtx(readStream.Context(),
btclog.Hex("stream_id", cipherBox.Desc.StreamId))

switch {
case cipherBox.Desc == nil:
return fmt.Errorf("cipher box descriptor required")
Expand All @@ -629,12 +639,11 @@ func (h *hashMailServer) SendStream(readStream hashmailrpc.HashMail_SendStreamSe
return fmt.Errorf("stream_id required")
}

log.Debugf("New HashMail write stream: id=%x",
cipherBox.Desc.StreamId)
log.DebugS(ctx, "New HashMail write stream")

// Now that we have the first message, we can attempt to look up the
// given stream.
writeStream, err := h.LookUpWriteStream(cipherBox.Desc.StreamId)
writeStream, err := h.LookUpWriteStream(ctx, cipherBox.Desc.StreamId)
if err != nil {
return err
}
Expand All @@ -643,13 +652,12 @@ func (h *hashMailServer) SendStream(readStream hashmailrpc.HashMail_SendStreamSe
// write inactive if the client hangs up on their end.
defer writeStream.ReturnStream()

log.Tracef("Sending msg_len=%v to stream_id=%x", len(cipherBox.Msg),
cipherBox.Desc.StreamId)
log.TraceS(ctx, "Sending message to stream",
"msg_len", len(cipherBox.Msg))

// We'll send the first message into the stream, then enter our loop
// below to continue to read from the stream and send it to the read
// end.
ctx := readStream.Context()
if err := writeStream.WriteMsg(ctx, cipherBox.Msg); err != nil {
return err
}
Expand All @@ -659,7 +667,7 @@ func (h *hashMailServer) SendStream(readStream hashmailrpc.HashMail_SendStreamSe
// exit before shutting down.
select {
case <-ctx.Done():
log.Debugf("SendStream: Context done, exiting")
log.DebugS(ctx, "SendStream: Context done, exiting")
return nil
case <-h.quit:
return fmt.Errorf("server shutting down")
Expand All @@ -669,13 +677,13 @@ func (h *hashMailServer) SendStream(readStream hashmailrpc.HashMail_SendStreamSe

cipherBox, err := readStream.Recv()
if err != nil {
log.Debugf("SendStream: Exiting write stream RPC "+
"stream read: %v", err)
log.DebugS(ctx, "SendStream: Exiting write stream RPC "+
"stream read", err)
return err
}

log.Tracef("Sending msg_len=%v to stream_id=%x",
len(cipherBox.Msg), cipherBox.Desc.StreamId)
log.TraceS(ctx, "Sending message to stream",
"msg_len", len(cipherBox.Msg))

if err := writeStream.WriteMsg(ctx, cipherBox.Msg); err != nil {
return err
Expand All @@ -689,25 +697,28 @@ func (h *hashMailServer) SendStream(readStream hashmailrpc.HashMail_SendStreamSe
func (h *hashMailServer) RecvStream(desc *hashmailrpc.CipherBoxDesc,
reader hashmailrpc.HashMail_RecvStreamServer) error {

ctx := btclog.WithCtx(reader.Context(),
btclog.Hex("stream_id", desc.StreamId))

// First, we'll attempt to locate the stream. We allow any single
// entity that knows of the full stream ID to access the read end.
readStream, err := h.LookUpReadStream(desc.StreamId)
readStream, err := h.LookUpReadStream(ctx, desc.StreamId)
if err != nil {
return err
}

log.Debugf("New HashMail read stream: id=%x", desc.StreamId)
log.DebugS(ctx, "New HashMail read stream")

// If the reader hangs up, then we'll mark the stream as inactive so
// another can take its place.
defer readStream.ReturnStream()
defer readStream.ReturnStream(ctx)

for {
// Check to see if the stream has been closed or if we need to
// exit before shutting down.
// exit before shutting d[own.
select {
case <-reader.Context().Done():
log.Debugf("Read stream context done.")
log.DebugS(ctx, "Read stream context done.")
return nil
case <-h.quit:
return fmt.Errorf("server shutting down")
Expand All @@ -717,12 +728,11 @@ func (h *hashMailServer) RecvStream(desc *hashmailrpc.CipherBoxDesc,

nextMsg, err := readStream.ReadNextMsg(reader.Context())
if err != nil {
log.Debugf("Got error an read stream read: %v", err)
log.ErrorS(ctx, "Got error on read stream read", err)
return err
}

log.Tracef("Read %v bytes for HashMail stream_id=%x",
len(nextMsg), desc.StreamId)
log.TraceS(ctx, "Read bytes", "msg_len", len(nextMsg))

// In order not to duplicate metric data, we only record this
// read if its streamID is odd. We use the base stream ID as the
Expand All @@ -742,7 +752,7 @@ func (h *hashMailServer) RecvStream(desc *hashmailrpc.CipherBoxDesc,
Msg: nextMsg,
})
if err != nil {
log.Debugf("Got error when sending on read stream: %v",
log.DebugS(ctx, "Got error when sending on read stream",
err)
return err
}
Expand All @@ -767,7 +777,7 @@ type streamStatus struct {
}

// newStreamStatus constructs a new streamStatus instance.
func newStreamStatus(onStale func() error,
func newStreamStatus(ctx context.Context, onStale func() error,
staleTimeout time.Duration) *streamStatus {

if staleTimeout < 0 {
Expand All @@ -778,7 +788,7 @@ func newStreamStatus(onStale func() error,

staleTimer := time.AfterFunc(staleTimeout, func() {
if err := onStale(); err != nil {
log.Errorf("error in onStale callback: %v", err)
log.ErrorS(ctx, "Error from onStale callback", err)
}
})

Expand Down

0 comments on commit 550bede

Please sign in to comment.