diff --git a/docs/release-notes/release-notes-0.16.0.md b/docs/release-notes/release-notes-0.16.0.md index 62cdb85792..e75af137b7 100644 --- a/docs/release-notes/release-notes-0.16.0.md +++ b/docs/release-notes/release-notes-0.16.0.md @@ -159,6 +159,11 @@ all unconfirmed transactions are rebroadcast on start up. `ForwardingHistory`rpcs](https://github.com/lightningnetwork/lnd/pull/7471) to the rpc response. +* HTLC interceptor clients can now [supply unencrypted failure + reasons](https://github.com/lightningnetwork/lnd/pull/7067) when failing + HTLCs. This allows returning any type of failure to the sender as if it + originates from the LND node. + ## Wallet * [Allows Taproot public keys and tap scripts to be imported as watch-only diff --git a/htlcswitch/hop/error_encryptor.go b/htlcswitch/hop/error_encryptor.go index 7b6a3dd1a5..dd5b71e903 100644 --- a/htlcswitch/hop/error_encryptor.go +++ b/htlcswitch/hop/error_encryptor.go @@ -54,6 +54,13 @@ type ErrorEncrypter interface { // until the error arrives at the source of the payment. IntermediateEncrypt(lnwire.OpaqueReason) lnwire.OpaqueReason + // EncryptEncodedFirstHop encrypts an already encoded failure message. + // This method will be used at the source that the error occurs. It + // differs from IntermediateEncrypt slightly, in that it computes a + // proper MAC over the error. + EncryptEncodedFirstHop( + encodedFailureMessage []byte) (lnwire.OpaqueReason, error) + // Type returns an enum indicating the underlying concrete instance // backing this interface. Type() EncrypterType @@ -117,6 +124,20 @@ func (s *SphinxErrorEncrypter) EncryptFirstHop( return s.EncryptError(true, b.Bytes()), nil } +// EncryptEncodedFirstHop encrypts an already encoded failure message. This +// method will be used at the source that the error occurs. +func (s *SphinxErrorEncrypter) EncryptEncodedFirstHop( + encodedFailureMessage []byte) (lnwire.OpaqueReason, error) { + + var w bytes.Buffer + err := lnwire.EncodeFailureHeader(&w, encodedFailureMessage, 0) + if err != nil { + return nil, err + } + + return s.EncryptError(true, w.Bytes()), nil +} + // EncryptMalformedError is similar to EncryptFirstHop (it adds the MAC), but // it accepts an opaque failure reason rather than a failure message. This // method is used when we receive an UpdateFailMalformedHTLC from the remote diff --git a/htlcswitch/interceptable_switch.go b/htlcswitch/interceptable_switch.go index 09646bf2d0..776b8bd5d2 100644 --- a/htlcswitch/interceptable_switch.go +++ b/htlcswitch/interceptable_switch.go @@ -1,10 +1,12 @@ package htlcswitch import ( + "bytes" "crypto/sha256" "fmt" "sync" + "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb/models" @@ -83,6 +85,11 @@ type InterceptableSwitch struct { // currentHeight is the currently best known height. currentHeight int32 + // signChannelUpdate is used when an intercepting application includes + // an unsigned channel update to be signed by us. + signChannelUpdate func(u *lnwire.ChannelUpdate) (*ecdsa.Signature, + error) + wg sync.WaitGroup quit chan struct{} } @@ -119,9 +126,15 @@ type FwdResolution struct { // FwdActionSettle. Preimage lntypes.Preimage - // FailureMessage is the encrypted failure message that is to be passed - // back to the sender if action is FwdActionFail. - FailureMessage []byte + // EncryptedFailureMessage is the encrypted failure message that is to + // be passed back to the sender if action is FwdActionFail. This field + // is mutually exclusive with FailureMessage. + EncryptedFailureMessage []byte + + // FailureMessage is the decoded failure message that is to be encrypted + // for the first hop. This field is mutually exclusive with + // EncryptedFailureMessage. + FailureMessage lnwire.FailureMessage // FailureCode is the failure code that is to be passed back to the // sender if action is FwdActionFail. @@ -158,6 +171,11 @@ type InterceptableSwitchConfig struct { // RequireInterceptor indicates whether processing should block if no // interceptor is connected. RequireInterceptor bool + + // SignChannelUpdate is used when an intercepting application includes + // an unsigned channel update to be signed by us. + SignChannelUpdate func(u *lnwire.ChannelUpdate) (*ecdsa.Signature, + error) } // NewInterceptableSwitch returns an instance of InterceptableSwitch. @@ -181,6 +199,7 @@ func NewInterceptableSwitch(cfg *InterceptableSwitchConfig) ( cltvRejectDelta: cfg.CltvRejectDelta, cltvInterceptDelta: cfg.CltvInterceptDelta, notifier: cfg.Notifier, + signChannelUpdate: cfg.SignChannelUpdate, quit: make(chan struct{}), }, nil @@ -377,17 +396,115 @@ func (s *InterceptableSwitch) resolve(res *FwdResolution) error { return intercepted.Settle(res.Preimage) case FwdActionFail: - if len(res.FailureMessage) > 0 { - return intercepted.Fail(res.FailureMessage) - } + switch { + // Fail with encrypted failure message. + case len(res.EncryptedFailureMessage) > 0: + return intercepted.Fail( + res.EncryptedFailureMessage, false, + ) + + // Fail with known failure message that is to be encoded and + // encrypted. + case res.FailureMessage != nil: + msg := res.FailureMessage + + // Re-sign the channel update if present. Note that this + // changes the passed in FwdResolution. + update := getChannelUpdateRef(msg) + if update != nil { + err := s.validateChannelUpdate(update) + if err != nil { + return err + } + + err = s.resignChannelUpdate(update) + if err != nil { + return err + } + } - return intercepted.FailWithCode(res.FailureCode) + var encodedMsg bytes.Buffer + err := lnwire.EncodeFailureMessage( + &encodedMsg, msg, 0, + ) + if err != nil { + return err + } + + return intercepted.Fail( + encodedMsg.Bytes(), true, + ) + + // Fail with failure code. + default: + return intercepted.FailWithCode(res.FailureCode) + } default: return fmt.Errorf("unrecognized action %v", res.Action) } } +func (s *InterceptableSwitch) validateChannelUpdate( + update *lnwire.ChannelUpdate) error { + + // The maxHTLC flag is mandatory. + if !update.MessageFlags.HasMaxHtlc() { + return errors.Errorf("max htlc flag not set for channel") + } + + // Check that max htlc is at least min htlc. + maxHtlc := update.HtlcMaximumMsat + if maxHtlc == 0 || maxHtlc < update.HtlcMinimumMsat { + return errors.Errorf("invalid max htlc for channel update ") + } + + return nil +} + +// resignChannelUpdate signs the provided channel update with our node key. +func (s *InterceptableSwitch) resignChannelUpdate( + update *lnwire.ChannelUpdate) error { + + sig, err := s.signChannelUpdate(update) + if err != nil { + return err + } + + update.Signature, err = lnwire.NewSigFromSignature(sig) + if err != nil { + return nil + } + + return nil +} + +// getChannelUpdateRef returns a reference to an embedded channel update if +// present in the failure message. +func getChannelUpdateRef(msg lnwire.FailureMessage) *lnwire.ChannelUpdate { + switch m := msg.(type) { + case *lnwire.FailFeeInsufficient: + return &m.Update + + case *lnwire.FailIncorrectCltvExpiry: + return &m.Update + + case *lnwire.FailTemporaryChannelFailure: + return m.Update + + case *lnwire.FailAmountBelowMinimum: + return &m.Update + + case *lnwire.FailExpiryTooSoon: + return &m.Update + + case *lnwire.FailChannelDisabled: + return &m.Update + } + + return nil +} + // Resolve resolves an intercepted packet. func (s *InterceptableSwitch) Resolve(res *FwdResolution) error { internalRes := &fwdResolution{ @@ -615,10 +732,24 @@ func (f *interceptedForward) Resume() error { return f.htlcSwitch.ForwardPackets(nil, f.packet) } -// Fail notifies the intention to Fail an existing hold forward with an -// encrypted failure reason. -func (f *interceptedForward) Fail(reason []byte) error { - obfuscatedReason := f.packet.obfuscator.IntermediateEncrypt(reason) +// Fail notifies the intention to Fail an existing hold forward. +func (f *interceptedForward) Fail(reason []byte, encryptFirstHop bool) error { + var ( + obfuscatedReason []byte + obfuscator = f.packet.obfuscator + ) + + if encryptFirstHop { + var err error + obfuscatedReason, err = obfuscator.EncryptEncodedFirstHop( + reason, + ) + if err != nil { + return err + } + } else { + obfuscatedReason = obfuscator.IntermediateEncrypt(reason) + } return f.resolve(&lnwire.UpdateFailHTLC{ Reason: obfuscatedReason, diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index 8f418edfed..4ea34cc871 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -332,9 +332,10 @@ type InterceptedForward interface { // forward with a given preimage. Settle(lntypes.Preimage) error - // Fail notifies the intention to fail an existing hold forward with an - // encrypted failure reason. - Fail(reason []byte) error + // Fail notifies the intention to fail an existing hold forward with a + // failure reason. The encryptFirstHop bool indicates whether the + // failure reason still needs to be encrypted for the first hop. + Fail(reason []byte, encryptFirstHop bool) error // FailWithCode notifies the intention to fail an existing hold forward // with the specified failure code. diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index bb36eda20b..03bd7f0d48 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -436,6 +436,21 @@ func (o *mockObfuscator) EncryptFirstHop(failure lnwire.FailureMessage) ( return b.Bytes(), nil } +func (o *mockObfuscator) EncryptEncodedFirstHop( + reason []byte) (lnwire.OpaqueReason, error) { + + var b bytes.Buffer + if _, err := b.Write(fakeHmac); err != nil { + return nil, err + } + + if err := lnwire.EncodeFailureHeader(&b, reason, 0); err != nil { + return nil, err + } + + return b.Bytes(), nil +} + func (o *mockObfuscator) IntermediateEncrypt(reason lnwire.OpaqueReason) lnwire.OpaqueReason { return reason } diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index 28f3d9e2c8..e37da2101d 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -1,6 +1,7 @@ package htlcswitch import ( + "bytes" "crypto/rand" "crypto/sha256" "fmt" @@ -10,6 +11,8 @@ import ( "testing" "time" + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/btcutil" "github.com/davecgh/go-spew/spew" "github.com/go-errors/errors" @@ -26,6 +29,11 @@ import ( "github.com/stretchr/testify/require" ) +const ( + cltvRejectDelta = 10 + cltvInterceptDelta = 13 +) + var zeroCircuit = models.CircuitKey{} var emptyScid = lnwire.ShortChannelID{} @@ -3803,10 +3811,13 @@ type interceptableSwitchTestContext struct { aliceChannelLink *mockChannelLink bobChannelLink *mockChannelLink s *Switch + + notifier *mock.ChainNotifier + interceptableSwitch *InterceptableSwitch } -func newInterceptableSwitchTestContext( - t *testing.T) *interceptableSwitchTestContext { +func newInterceptableSwitchTestContext(t *testing.T, //nolint: thelper + requireInterceptor bool) *interceptableSwitchTestContext { chanID1, chanID2, aliceChanID, bobChanID := genIDs() @@ -3854,8 +3865,8 @@ func newInterceptableSwitchTestContext( rhash: sha256.Sum256(preimage[:]), onionBlob: [1366]byte{4, 5, 6}, incomingHtlcID: uint64(0), - cltvRejectDelta: 10, - cltvInterceptDelta: 13, + cltvRejectDelta: cltvRejectDelta, + cltvInterceptDelta: cltvInterceptDelta, forwardInterceptor: &mockForwardInterceptor{ t: t, interceptedChan: make(chan InterceptedPacket), @@ -3865,9 +3876,48 @@ func newInterceptableSwitchTestContext( s: s, } + ctx.instantiateInterceptableSwitch(requireInterceptor) + return ctx } +func (c *interceptableSwitchTestContext) signChannelUpdate( + u *lnwire.ChannelUpdate) (*ecdsa.Signature, error) { + + data, err := u.DataToSign() + require.NoError(c.t, err) + + key, _ := btcec.PrivKeyFromBytes(alicePrivKey) + sig := ecdsa.Sign(key, data) + + return sig, nil +} + +func (c *interceptableSwitchTestContext) instantiateInterceptableSwitch( + requireInterceptor bool) { + + c.notifier = &mock.ChainNotifier{ + EpochChan: make(chan *chainntnfs.BlockEpoch, 1), + } + c.notifier.EpochChan <- &chainntnfs.BlockEpoch{ + Height: testStartingHeight, + } + + switchForwardInterceptor, err := NewInterceptableSwitch( + &InterceptableSwitchConfig{ + Switch: c.s, + CltvRejectDelta: cltvRejectDelta, + CltvInterceptDelta: cltvInterceptDelta, + Notifier: c.notifier, + RequireInterceptor: requireInterceptor, + SignChannelUpdate: c.signChannelUpdate, + }, + ) + require.NoError(c.t, err) + + c.interceptableSwitch = switchForwardInterceptor +} + func (c *interceptableSwitchTestContext) createTestPacket() *htlcPacket { c.incomingHtlcID++ @@ -3907,33 +3957,23 @@ func (c *interceptableSwitchTestContext) createSettlePacket( func TestSwitchHoldForward(t *testing.T) { t.Parallel() - c := newInterceptableSwitchTestContext(t) + c := newInterceptableSwitchTestContext(t, false) defer c.finish() - notifier := &mock.ChainNotifier{ - EpochChan: make(chan *chainntnfs.BlockEpoch, 1), - } - notifier.EpochChan <- &chainntnfs.BlockEpoch{Height: testStartingHeight} + require.NoError(c.t, c.interceptableSwitch.Start()) - switchForwardInterceptor, err := NewInterceptableSwitch( - &InterceptableSwitchConfig{ - Switch: c.s, - CltvRejectDelta: c.cltvRejectDelta, - CltvInterceptDelta: c.cltvInterceptDelta, - Notifier: notifier, - }, + // Set interceptor. + c.interceptableSwitch.SetInterceptor( + c.forwardInterceptor.InterceptForwardHtlc, ) - require.NoError(t, err) - require.NoError(t, switchForwardInterceptor.Start()) - switchForwardInterceptor.SetInterceptor(c.forwardInterceptor.InterceptForwardHtlc) linkQuit := make(chan struct{}) // Test a forward that expires too soon. packet := c.createTestPacket() packet.incomingTimeout = testStartingHeight + c.cltvRejectDelta - 1 - err = switchForwardInterceptor.ForwardPackets(linkQuit, false, packet) + err := c.interceptableSwitch.ForwardPackets(linkQuit, false, packet) require.NoError(t, err, "can't forward htlc packet") assertOutgoingLinkReceive(t, c.bobChannelLink, false) assertOutgoingLinkReceiveIntercepted(t, c.aliceChannelLink) @@ -3951,12 +3991,12 @@ func TestSwitchHoldForward(t *testing.T) { return nil, errors.New("cannot fetch update") } - err = switchForwardInterceptor.ForwardPackets(linkQuit, false, packet) + err = c.interceptableSwitch.ForwardPackets(linkQuit, false, packet) require.NoError(t, err, "can't forward htlc packet") receivedPkt := assertOutgoingLinkReceive(t, c.bobChannelLink, true) assertNumCircuits(t, c.s, 1, 1) - require.NoError(t, switchForwardInterceptor.ForwardPackets( + require.NoError(t, c.interceptableSwitch.ForwardPackets( linkQuit, false, c.createSettlePacket(receivedPkt.outgoingHTLCID), )) @@ -3968,7 +4008,7 @@ func TestSwitchHoldForward(t *testing.T) { // Test resume a hold forward. assertNumCircuits(t, c.s, 0, 0) - err = switchForwardInterceptor.ForwardPackets( + err = c.interceptableSwitch.ForwardPackets( linkQuit, false, c.createTestPacket(), ) require.NoError(t, err) @@ -3976,7 +4016,7 @@ func TestSwitchHoldForward(t *testing.T) { assertNumCircuits(t, c.s, 0, 0) assertOutgoingLinkReceive(t, c.bobChannelLink, false) - require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{ + require.NoError(t, c.interceptableSwitch.Resolve(&FwdResolution{ Action: FwdActionResume, Key: c.forwardInterceptor.getIntercepted().IncomingCircuit, })) @@ -3984,7 +4024,7 @@ func TestSwitchHoldForward(t *testing.T) { assertNumCircuits(t, c.s, 1, 1) // settling the htlc to close the circuit. - err = switchForwardInterceptor.ForwardPackets( + err = c.interceptableSwitch.ForwardPackets( linkQuit, false, c.createSettlePacket(receivedPkt.outgoingHTLCID), ) @@ -3994,7 +4034,7 @@ func TestSwitchHoldForward(t *testing.T) { assertNumCircuits(t, c.s, 0, 0) // Test resume a hold forward after disconnection. - require.NoError(t, switchForwardInterceptor.ForwardPackets( + require.NoError(t, c.interceptableSwitch.ForwardPackets( linkQuit, false, c.createTestPacket(), )) @@ -4006,13 +4046,13 @@ func TestSwitchHoldForward(t *testing.T) { assertOutgoingLinkReceive(t, c.bobChannelLink, false) // Disconnect should resume the forwarding. - switchForwardInterceptor.SetInterceptor(nil) + c.interceptableSwitch.SetInterceptor(nil) receivedPkt = assertOutgoingLinkReceive(t, c.bobChannelLink, true) assertNumCircuits(t, c.s, 1, 1) // Settle the htlc to close the circuit. - require.NoError(t, switchForwardInterceptor.ForwardPackets( + require.NoError(t, c.interceptableSwitch.ForwardPackets( linkQuit, false, c.createSettlePacket(receivedPkt.outgoingHTLCID), )) @@ -4021,17 +4061,17 @@ func TestSwitchHoldForward(t *testing.T) { assertNumCircuits(t, c.s, 0, 0) // Test failing a hold forward - switchForwardInterceptor.SetInterceptor( + c.interceptableSwitch.SetInterceptor( c.forwardInterceptor.InterceptForwardHtlc, ) - require.NoError(t, switchForwardInterceptor.ForwardPackets( + require.NoError(t, c.interceptableSwitch.ForwardPackets( linkQuit, false, c.createTestPacket(), )) assertNumCircuits(t, c.s, 0, 0) assertOutgoingLinkReceive(t, c.bobChannelLink, false) - require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{ + require.NoError(t, c.interceptableSwitch.Resolve(&FwdResolution{ Action: FwdActionFail, Key: c.forwardInterceptor.getIntercepted().IncomingCircuit, FailureCode: lnwire.CodeTemporaryChannelFailure, @@ -4042,18 +4082,20 @@ func TestSwitchHoldForward(t *testing.T) { // Test failing a hold forward with a failure message. require.NoError(t, - switchForwardInterceptor.ForwardPackets( + c.interceptableSwitch.ForwardPackets( linkQuit, false, c.createTestPacket(), ), ) assertNumCircuits(t, c.s, 0, 0) assertOutgoingLinkReceive(t, c.bobChannelLink, false) - reason := lnwire.OpaqueReason([]byte{1, 2, 3}) - require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{ - Action: FwdActionFail, - Key: c.forwardInterceptor.getIntercepted().IncomingCircuit, - FailureMessage: reason, + reason := lnwire.OpaqueReason(make([]byte, 292)) + copy(reason, []byte{1, 2, 3}) + intercepted := c.forwardInterceptor.getIntercepted() + require.NoError(t, c.interceptableSwitch.Resolve(&FwdResolution{ + Action: FwdActionFail, + Key: intercepted.IncomingCircuit, + EncryptedFailureMessage: reason, })) assertOutgoingLinkReceive(t, c.bobChannelLink, false) @@ -4064,7 +4106,7 @@ func TestSwitchHoldForward(t *testing.T) { assertNumCircuits(t, c.s, 0, 0) // Test failing a hold forward with a malformed htlc failure. - err = switchForwardInterceptor.ForwardPackets( + err = c.interceptableSwitch.ForwardPackets( linkQuit, false, c.createTestPacket(), ) require.NoError(t, err) @@ -4073,7 +4115,7 @@ func TestSwitchHoldForward(t *testing.T) { assertOutgoingLinkReceive(t, c.bobChannelLink, false) code := lnwire.CodeInvalidOnionKey - require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{ + require.NoError(t, c.interceptableSwitch.Resolve(&FwdResolution{ Action: FwdActionFail, Key: c.forwardInterceptor.getIntercepted().IncomingCircuit, FailureCode: code, @@ -4095,13 +4137,13 @@ func TestSwitchHoldForward(t *testing.T) { assertNumCircuits(t, c.s, 0, 0) // Test settling a hold forward - require.NoError(t, switchForwardInterceptor.ForwardPackets( + require.NoError(t, c.interceptableSwitch.ForwardPackets( linkQuit, false, c.createTestPacket(), )) assertNumCircuits(t, c.s, 0, 0) assertOutgoingLinkReceive(t, c.bobChannelLink, false) - require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{ + require.NoError(t, c.interceptableSwitch.Resolve(&FwdResolution{ Key: c.forwardInterceptor.getIntercepted().IncomingCircuit, Action: FwdActionSettle, Preimage: c.preimage, @@ -4110,29 +4152,16 @@ func TestSwitchHoldForward(t *testing.T) { assertOutgoingLinkReceive(t, c.aliceChannelLink, true) assertNumCircuits(t, c.s, 0, 0) - require.NoError(t, switchForwardInterceptor.Stop()) + require.NoError(t, c.interceptableSwitch.Stop()) // Test always-on interception. - notifier = &mock.ChainNotifier{ - EpochChan: make(chan *chainntnfs.BlockEpoch, 1), - } - notifier.EpochChan <- &chainntnfs.BlockEpoch{Height: testStartingHeight} + c.instantiateInterceptableSwitch(true) - switchForwardInterceptor, err = NewInterceptableSwitch( - &InterceptableSwitchConfig{ - Switch: c.s, - CltvRejectDelta: c.cltvRejectDelta, - CltvInterceptDelta: c.cltvInterceptDelta, - RequireInterceptor: true, - Notifier: notifier, - }, - ) - require.NoError(t, err) - require.NoError(t, switchForwardInterceptor.Start()) + require.NoError(t, c.interceptableSwitch.Start()) // Forward a fresh packet. It is expected to be failed immediately, // because there is no interceptor registered. - require.NoError(t, switchForwardInterceptor.ForwardPackets( + require.NoError(t, c.interceptableSwitch.ForwardPackets( linkQuit, false, c.createTestPacket(), )) @@ -4145,7 +4174,7 @@ func TestSwitchHoldForward(t *testing.T) { // goroutine. errChan := make(chan error) go func() { - errChan <- switchForwardInterceptor.ForwardPackets( + errChan <- c.interceptableSwitch.ForwardPackets( linkQuit, true, c.createTestPacket(), ) }() @@ -4155,7 +4184,7 @@ func TestSwitchHoldForward(t *testing.T) { assertNumCircuits(t, c.s, 0, 0) // Register an interceptor. - switchForwardInterceptor.SetInterceptor( + c.interceptableSwitch.SetInterceptor( c.forwardInterceptor.InterceptForwardHtlc, ) @@ -4166,16 +4195,16 @@ func TestSwitchHoldForward(t *testing.T) { c.forwardInterceptor.getIntercepted() // Disconnect and reconnect interceptor. - switchForwardInterceptor.SetInterceptor(nil) - switchForwardInterceptor.SetInterceptor( + c.interceptableSwitch.SetInterceptor(nil) + c.interceptableSwitch.SetInterceptor( c.forwardInterceptor.InterceptForwardHtlc, ) // A replay of the held packet is expected. - intercepted := c.forwardInterceptor.getIntercepted() + intercepted = c.forwardInterceptor.getIntercepted() // Settle the packet. - require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{ + require.NoError(t, c.interceptableSwitch.Resolve(&FwdResolution{ Key: intercepted.IncomingCircuit, Action: FwdActionSettle, Preimage: c.preimage, @@ -4184,7 +4213,7 @@ func TestSwitchHoldForward(t *testing.T) { assertOutgoingLinkReceive(t, c.aliceChannelLink, true) assertNumCircuits(t, c.s, 0, 0) - require.NoError(t, switchForwardInterceptor.Stop()) + require.NoError(t, c.interceptableSwitch.Stop()) select { case <-c.forwardInterceptor.interceptedChan: @@ -4197,28 +4226,12 @@ func TestSwitchHoldForward(t *testing.T) { func TestInterceptableSwitchWatchDog(t *testing.T) { t.Parallel() - c := newInterceptableSwitchTestContext(t) + c := newInterceptableSwitchTestContext(t, false) defer c.finish() - // Start interceptable switch. - notifier := &mock.ChainNotifier{ - EpochChan: make(chan *chainntnfs.BlockEpoch, 1), - } - notifier.EpochChan <- &chainntnfs.BlockEpoch{Height: testStartingHeight} - - switchForwardInterceptor, err := NewInterceptableSwitch( - &InterceptableSwitchConfig{ - Switch: c.s, - CltvRejectDelta: c.cltvRejectDelta, - CltvInterceptDelta: c.cltvInterceptDelta, - Notifier: notifier, - }, - ) - require.NoError(t, err) - require.NoError(t, switchForwardInterceptor.Start()) - + require.NoError(t, c.interceptableSwitch.Start()) // Set interceptor. - switchForwardInterceptor.SetInterceptor( + c.interceptableSwitch.SetInterceptor( c.forwardInterceptor.InterceptForwardHtlc, ) @@ -4227,7 +4240,7 @@ func TestInterceptableSwitchWatchDog(t *testing.T) { packet := c.createTestPacket() - err = switchForwardInterceptor.ForwardPackets(linkQuit, false, packet) + err := c.interceptableSwitch.ForwardPackets(linkQuit, false, packet) require.NoError(t, err, "can't forward htlc packet") // Intercept the packet. @@ -4239,7 +4252,7 @@ func TestInterceptableSwitchWatchDog(t *testing.T) { ) // Htlc expires before a resolution from the interceptor. - notifier.EpochChan <- &chainntnfs.BlockEpoch{ + c.notifier.EpochChan <- &chainntnfs.BlockEpoch{ Height: int32(packet.incomingTimeout) - int32(c.cltvRejectDelta), } @@ -4248,13 +4261,65 @@ func TestInterceptableSwitchWatchDog(t *testing.T) { assertOutgoingLinkReceive(t, c.aliceChannelLink, true) // It is too late now to resolve. Expect an error. - require.Error(t, switchForwardInterceptor.Resolve(&FwdResolution{ + require.Error(t, c.interceptableSwitch.Resolve(&FwdResolution{ Action: FwdActionSettle, Key: intercepted.IncomingCircuit, Preimage: c.preimage, })) } +func TestInterceptableSwitchUnencryptedFailure(t *testing.T) { + t.Parallel() + + c := newInterceptableSwitchTestContext(t, false) + defer c.finish() + + require.NoError(t, c.interceptableSwitch.Start()) + c.interceptableSwitch.SetInterceptor( + c.forwardInterceptor.InterceptForwardHtlc, + ) + + // Receive a packet. + linkQuit := make(chan struct{}) + + packet := c.createTestPacket() + + err := c.interceptableSwitch.ForwardPackets(linkQuit, false, packet) + require.NoError(t, err, "can't forward htlc packet") + + // Intercept the packet. + intercepted := c.forwardInterceptor.getIntercepted() + + // Fail the htlc with an unencrypted failure message. + msg := lnwire.NewFeeInsufficient(10, lnwire.ChannelUpdate{ + MessageFlags: lnwire.ChanUpdateRequiredMaxHtlc, + HtlcMaximumMsat: 10000, + }) + var msgBuffer bytes.Buffer + require.NoError(t, lnwire.EncodeFailureMessage(&msgBuffer, msg, 0)) + + require.NoError(t, c.interceptableSwitch.Resolve(&FwdResolution{ + Action: FwdActionFail, + Key: intercepted.IncomingCircuit, + FailureMessage: msg, + })) + assertOutgoingLinkReceive(t, c.bobChannelLink, false) + recvPkt := assertOutgoingLinkReceive(t, c.aliceChannelLink, true) + assertNumCircuits(t, c.s, 0, 0) + + // Assert that the sender receives the expected failure message type. + deobfuscator := newMockDeobfuscator() + + fwdErr, err := deobfuscator.DecryptError( + recvPkt.htlc.(*lnwire.UpdateFailHTLC).Reason, + ) + require.NoError(t, err) + + feeInsufficientErr, ok := fwdErr.msg.(*lnwire.FailFeeInsufficient) + require.True(t, ok) + require.NotZero(t, feeInsufficientErr.Update.Signature) +} + // TestSwitchDustForwarding tests that the switch properly fails HTLC's which // have incoming or outgoing links that breach their dust thresholds. func TestSwitchDustForwarding(t *testing.T) { diff --git a/intercepted_forward.go b/intercepted_forward.go index 70590d0e41..431d064e23 100644 --- a/intercepted_forward.go +++ b/intercepted_forward.go @@ -51,9 +51,10 @@ func (f *interceptedForward) Resume() error { return ErrCannotResume } -// Fail notifies the intention to fail an existing hold forward with an -// encrypted failure reason. -func (f *interceptedForward) Fail(_ []byte) error { +// Fail notifies the intention to fail an existing hold forward with a failure +// reason. The encryptFirstHop bool indicates whether the failure reason still +// needs to be encrypted for the first hop. +func (f *interceptedForward) Fail(_ []byte, _ bool) error { // We can't actively fail an htlc. The best we could do is abandon the // resolver, but this wouldn't be a safe operation. There may be a race // with the preimage beacon supplying a preimage. Therefore we don't diff --git a/lnrpc/routerrpc/forward_interceptor.go b/lnrpc/routerrpc/forward_interceptor.go index 8af1a21f6c..fdf129762f 100644 --- a/lnrpc/routerrpc/forward_interceptor.go +++ b/lnrpc/routerrpc/forward_interceptor.go @@ -1,6 +1,8 @@ package routerrpc import ( + "bytes" + "crypto/sha256" "errors" "github.com/lightningnetwork/lnd/channeldb/models" @@ -120,62 +122,18 @@ func (r *forwardInterceptor) resolveFromClient( }) case ResolveHoldForwardAction_FAIL: - // Fail with an encrypted reason. - if in.FailureMessage != nil { - if in.FailureCode != 0 { - return status.Errorf( - codes.InvalidArgument, - "failure message and failure code "+ - "are mutually exclusive", - ) - } - - // Verify that the size is equal to the fixed failure - // message size + hmac + two uint16 lengths. See BOLT - // #4. - if len(in.FailureMessage) != - lnwire.FailureMessageLength+32+2+2 { - - return status.Errorf( - codes.InvalidArgument, - "failure message length invalid", - ) - } - - return r.htlcSwitch.Resolve(&htlcswitch.FwdResolution{ - Key: circuitKey, - Action: htlcswitch.FwdActionFail, - FailureMessage: in.FailureMessage, - }) + // Instantiate the fwdRes with base fields set. + fwdRes := &htlcswitch.FwdResolution{ + Key: circuitKey, + Action: htlcswitch.FwdActionFail, } - var code lnwire.FailCode - switch in.FailureCode { - case lnrpc.Failure_INVALID_ONION_HMAC: - code = lnwire.CodeInvalidOnionHmac - - case lnrpc.Failure_INVALID_ONION_KEY: - code = lnwire.CodeInvalidOnionKey - - case lnrpc.Failure_INVALID_ONION_VERSION: - code = lnwire.CodeInvalidOnionVersion - - // Default to TemporaryChannelFailure. - case 0, lnrpc.Failure_TEMPORARY_CHANNEL_FAILURE: - code = lnwire.CodeTemporaryChannelFailure - - default: - return status.Errorf( - codes.InvalidArgument, - "unsupported failure code: %v", in.FailureCode, - ) + err := unmarshallFailureResolution(in, fwdRes) + if err != nil { + return err } - return r.htlcSwitch.Resolve(&htlcswitch.FwdResolution{ - Key: circuitKey, - Action: htlcswitch.FwdActionFail, - FailureCode: code, - }) + return r.htlcSwitch.Resolve(fwdRes) case ResolveHoldForwardAction_SETTLE: if in.Preimage == nil { @@ -199,3 +157,86 @@ func (r *forwardInterceptor) resolveFromClient( ) } } + +// unmarshallFailureResolution unmarshalls the rpc failure code and message into +// a resolution struct. +func unmarshallFailureResolution(in *ForwardHtlcInterceptResponse, + fwdRes *htlcswitch.FwdResolution) error { + + if in.FailureMessage != nil { + if in.FailureCode != 0 { + return status.Errorf( + codes.InvalidArgument, + "failure message and failure code "+ + "are mutually exclusive", + ) + } + + switch { + // Verify that for encrypted messages the size is equal to the + // fixed failure message size + hmac + two uint16 lengths. See + // BOLT #4. + case !in.FailureMessageUnencrypted: + if len(in.FailureMessage) != + lnwire.FailureMessageLength+sha256.Size+2+2 { + + return status.Errorf( + codes.InvalidArgument, + "failure message length invalid", + ) + } + + fwdRes.EncryptedFailureMessage = in.FailureMessage + + // For unencrypted messages, verify that they are parseable. + case in.FailureMessageUnencrypted: + r := bytes.NewReader(in.FailureMessage) + msg, err := lnwire.DecodeFailureMessage(r, 0) + if err != nil { + return status.Errorf( + codes.InvalidArgument, + "failure message unparseable", + ) + } + + fwdRes.FailureMessage = msg + } + } else { + code, err := unmarshalFailCode(in.FailureCode) + if err != nil { + return err + } + + fwdRes.FailureCode = code + } + + return nil +} + +func unmarshalFailCode(failureCode lnrpc.Failure_FailureCode) (lnwire.FailCode, + error) { + + var code lnwire.FailCode + switch failureCode { + case lnrpc.Failure_INVALID_ONION_HMAC: + code = lnwire.CodeInvalidOnionHmac + + case lnrpc.Failure_INVALID_ONION_KEY: + code = lnwire.CodeInvalidOnionKey + + case lnrpc.Failure_INVALID_ONION_VERSION: + code = lnwire.CodeInvalidOnionVersion + + // Default to TemporaryChannelFailure. + case 0, lnrpc.Failure_TEMPORARY_CHANNEL_FAILURE: + code = lnwire.CodeTemporaryChannelFailure + + default: + return 0, status.Errorf( + codes.InvalidArgument, + "unsupported failure code: %v", failureCode, + ) + } + + return code, nil +} diff --git a/lnrpc/routerrpc/router.pb.go b/lnrpc/routerrpc/router.pb.go index e02035a29a..eae4e4ba42 100644 --- a/lnrpc/routerrpc/router.pb.go +++ b/lnrpc/routerrpc/router.pb.go @@ -3113,11 +3113,18 @@ type ForwardHtlcInterceptResponse struct { Action ResolveHoldForwardAction `protobuf:"varint,2,opt,name=action,proto3,enum=routerrpc.ResolveHoldForwardAction" json:"action,omitempty"` // The preimage in case the resolve action is Settle. Preimage []byte `protobuf:"bytes,3,opt,name=preimage,proto3" json:"preimage,omitempty"` - // Encrypted failure message in case the resolve action is Fail. + // Failure message in case the resolve action is Fail. The field + // failure_message_unencrypted indicates whether this message is already + // encrypted for the first hop. // - // If failure_message is specified, the failure_code field must be set - // to zero. + // If failure_message is specified, the failure_code field must be set to + // zero. FailureMessage []byte `protobuf:"bytes,4,opt,name=failure_message,json=failureMessage,proto3" json:"failure_message,omitempty"` + // Indicates whether the failure message still needs to be encrypted for the + // first hop. + FailureMessageUnencrypted bool `protobuf:"varint,6,opt,name=failure_message_unencrypted,json=failureMessageUnencrypted,proto3" json:"failure_message_unencrypted,omitempty"` + // Deprecated: use failure_message with failure_message_unencrypted. + // // Return the specified failure code in case the resolve action is Fail. The // message data fields are populated automatically. // @@ -3188,6 +3195,13 @@ func (x *ForwardHtlcInterceptResponse) GetFailureMessage() []byte { return nil } +func (x *ForwardHtlcInterceptResponse) GetFailureMessageUnencrypted() bool { + if x != nil { + return x.FailureMessageUnencrypted + } + return false +} + func (x *ForwardHtlcInterceptResponse) GetFailureCode() lnrpc.Failure_FailureCode { if x != nil { return x.FailureCode @@ -3673,7 +3687,7 @@ var file_routerrpc_router_proto_rawDesc = []byte{ 0x73, 0x74, 0x6f, 0x6d, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x0c, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0xa8, 0x02, 0x0a, + 0x0c, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0xe8, 0x02, 0x0a, 0x1c, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x48, 0x74, 0x6c, 0x63, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x63, 0x65, 0x70, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x47, 0x0a, 0x14, 0x69, 0x6e, 0x63, 0x6f, 0x6d, 0x69, 0x6e, 0x67, 0x5f, 0x63, 0x69, 0x72, 0x63, 0x75, 0x69, @@ -3688,7 +3702,11 @@ var file_routerrpc_router_proto_rawDesc = []byte{ 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x70, 0x72, 0x65, 0x69, 0x6d, 0x61, 0x67, 0x65, 0x12, 0x27, 0x0a, 0x0f, 0x66, 0x61, 0x69, 0x6c, 0x75, 0x72, 0x65, 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0e, 0x66, 0x61, 0x69, 0x6c, 0x75, 0x72, - 0x65, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x3d, 0x0a, 0x0c, 0x66, 0x61, 0x69, 0x6c, + 0x65, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x3e, 0x0a, 0x1b, 0x66, 0x61, 0x69, 0x6c, + 0x75, 0x72, 0x65, 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x5f, 0x75, 0x6e, 0x65, 0x6e, + 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x19, 0x66, + 0x61, 0x69, 0x6c, 0x75, 0x72, 0x65, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x55, 0x6e, 0x65, + 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x12, 0x3d, 0x0a, 0x0c, 0x66, 0x61, 0x69, 0x6c, 0x75, 0x72, 0x65, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1a, 0x2e, 0x6c, 0x6e, 0x72, 0x70, 0x63, 0x2e, 0x46, 0x61, 0x69, 0x6c, 0x75, 0x72, 0x65, 0x2e, 0x46, 0x61, 0x69, 0x6c, 0x75, 0x72, 0x65, 0x43, 0x6f, 0x64, 0x65, 0x52, 0x0b, 0x66, 0x61, 0x69, 0x6c, diff --git a/lnrpc/routerrpc/router.proto b/lnrpc/routerrpc/router.proto index d591cc485c..8f3ebd301e 100644 --- a/lnrpc/routerrpc/router.proto +++ b/lnrpc/routerrpc/router.proto @@ -914,12 +914,20 @@ message ForwardHtlcInterceptResponse { // The preimage in case the resolve action is Settle. bytes preimage = 3; - // Encrypted failure message in case the resolve action is Fail. + // Failure message in case the resolve action is Fail. The field + // failure_message_unencrypted indicates whether this message is already + // encrypted for the first hop. // - // If failure_message is specified, the failure_code field must be set - // to zero. + // If failure_message is specified, the failure_code field must be set to + // zero. bytes failure_message = 4; + // Indicates whether the failure message still needs to be encrypted for the + // first hop. + bool failure_message_unencrypted = 6; + + // Deprecated: use failure_message with failure_message_unencrypted. + // // Return the specified failure code in case the resolve action is Fail. The // message data fields are populated automatically. // diff --git a/lnrpc/routerrpc/router.swagger.json b/lnrpc/routerrpc/router.swagger.json index 5858ff24f3..15ccb88005 100644 --- a/lnrpc/routerrpc/router.swagger.json +++ b/lnrpc/routerrpc/router.swagger.json @@ -1320,11 +1320,15 @@ "failure_message": { "type": "string", "format": "byte", - "description": "Encrypted failure message in case the resolve action is Fail.\n\nIf failure_message is specified, the failure_code field must be set\nto zero." + "description": "Failure message in case the resolve action is Fail. The field\nfailure_message_unencrypted indicates whether this message is already\nencrypted for the first hop.\n\nIf failure_message is specified, the failure_code field must be set to\nzero." + }, + "failure_message_unencrypted": { + "type": "boolean", + "description": "Indicates whether the failure message still needs to be encrypted for the\nfirst hop." }, "failure_code": { "$ref": "#/definitions/FailureFailureCode", - "description": "Return the specified failure code in case the resolve action is Fail. The\nmessage data fields are populated automatically.\n\nIf a non-zero failure_code is specified, failure_message must not be set.\n\nFor backwards-compatibility reasons, TEMPORARY_CHANNEL_FAILURE is the\ndefault value for this field." + "description": "Deprecated: use failure_message with failure_message_unencrypted.\n\nReturn the specified failure code in case the resolve action is Fail. The\nmessage data fields are populated automatically.\n\nIf a non-zero failure_code is specified, failure_message must not be set.\n\nFor backwards-compatibility reasons, TEMPORARY_CHANNEL_FAILURE is the\ndefault value for this field." } }, "description": "*\nForwardHtlcInterceptResponse enables the caller to resolve a previously hold\nforward. The caller can choose either to:\n- `Resume`: Execute the default behavior (usually forward).\n- `Reject`: Fail the htlc backwards.\n- `Settle`: Settle this htlc with a given preimage." diff --git a/lnwire/onion_error.go b/lnwire/onion_error.go index e6ae7437f5..e017189110 100644 --- a/lnwire/onion_error.go +++ b/lnwire/onion_error.go @@ -1324,9 +1324,18 @@ func EncodeFailure(w *bytes.Buffer, failure FailureMessage, pver uint32) error { return err } + failureMessage := failureMessageBuffer.Bytes() + + return EncodeFailureHeader(w, failureMessage, pver) +} + +// EncodeFailureHeader adds the necessary onion failure header information to an +// encoded failure. +func EncodeFailureHeader(w *bytes.Buffer, failureMessage []byte, + _ uint32) error { + // The combined size of this message must be below the max allowed // failure message length. - failureMessage := failureMessageBuffer.Bytes() if len(failureMessage) > FailureMessageLength { return fmt.Errorf("failure message exceed max "+ "available size: %v", len(failureMessage)) diff --git a/server.go b/server.go index 5b421ae2f5..164c3b1893 100644 --- a/server.go +++ b/server.go @@ -679,6 +679,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, CltvInterceptDelta: lncfg.DefaultCltvInterceptDelta, RequireInterceptor: s.cfg.RequireInterceptor, Notifier: s.cc.ChainNotifier, + SignChannelUpdate: s.signAliasUpdate, }, ) if err != nil {