Skip to content

Commit

Permalink
channeldb: optimize memory usage of initial cache fill
Browse files Browse the repository at this point in the history
With this commit we use an optimized version of the node iteration that
causes fewer memory allocations by only loading the part of the graph
node that we actually need to know for the cache.
  • Loading branch information
guggero committed Sep 29, 2021
1 parent a5202a8 commit 6240851
Show file tree
Hide file tree
Showing 2 changed files with 215 additions and 13 deletions.
125 changes: 114 additions & 11 deletions channeldb/graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,8 @@ func NewChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int,
startTime := time.Now()
log.Debugf("Populating in-memory channel graph, this might take a " +
"while...")
err := g.ForEachNode(func(tx kvdb.RTx, node *LightningNode) error {
return g.graphCache.AddNode(tx, &graphCacheNode{node})
err := g.ForEachNodeCacheable(func(tx kvdb.RTx, node GraphCacheNode) error {
return g.graphCache.AddNode(tx, node)
})
if err != nil {
return nil, err
Expand Down Expand Up @@ -468,6 +468,47 @@ func (c *ChannelGraph) ForEachNode(cb func(kvdb.RTx, *LightningNode) error) erro
return kvdb.View(c.db, traversal, func() {})
}

// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
// graph, executing the passed callback with each node encountered. If the
// callback returns an error, then the transaction is aborted and the iteration
// stops early.
func (c *ChannelGraph) ForEachNodeCacheable(cb func(kvdb.RTx,
GraphCacheNode) error) error {

traversal := func(tx kvdb.RTx) error {
// First grab the nodes bucket which stores the mapping from
// pubKey to node information.
nodes := tx.ReadBucket(nodeBucket)
if nodes == nil {
return ErrGraphNotFound
}

cacheableNode := newGraphCacheNode(route.Vertex{}, nil)
return nodes.ForEach(func(pubKey, nodeBytes []byte) error {
// If this is the source key, then we skip this
// iteration as the value for this key is a pubKey
// rather than raw node information.
if bytes.Equal(pubKey, sourceKey) || len(pubKey) != 33 {
return nil
}

nodeReader := bytes.NewReader(nodeBytes)
err := deserializeLightningNodeCacheable(
nodeReader, cacheableNode,
)
if err != nil {
return err
}

// Execute the callback, the transaction will abort if
// this returns an error.
return cb(tx, cacheableNode)
})
}

return kvdb.View(c.db, traversal, func() {})
}

// SourceNode returns the source node of the graph. The source node is treated
// as the center node within a star-graph. This method may be used to kick off
// a path finding algorithm in order to explore the reachability of another
Expand Down Expand Up @@ -559,8 +600,10 @@ func (c *ChannelGraph) AddLightningNode(node *LightningNode,

r := &batch.Request{
Update: func(tx kvdb.RwTx) error {
wNode := &graphCacheNode{node}
if err := c.graphCache.AddNode(tx, wNode); err != nil {
cNode := newGraphCacheNode(
node.PubKeyBytes, node.Features,
)
if err := c.graphCache.AddNode(tx, cNode); err != nil {
return err
}

Expand Down Expand Up @@ -2532,17 +2575,30 @@ func (c *ChannelGraph) FetchLightningNode(nodePub route.Vertex) (
// graphCacheNode is a struct that wraps a LightningNode in a way that it can be
// cached in the graph cache.
type graphCacheNode struct {
lnNode *LightningNode
pubKeyBytes route.Vertex
features *lnwire.FeatureVector

nodeScratch [8]byte
}

// newGraphCacheNode returns a new cache optimized node.
func newGraphCacheNode(pubKey route.Vertex,
features *lnwire.FeatureVector) *graphCacheNode {

return &graphCacheNode{
pubKeyBytes: pubKey,
features: features,
}
}

// PubKey returns the node's public identity key.
func (w *graphCacheNode) PubKey() route.Vertex {
return w.lnNode.PubKeyBytes
func (n *graphCacheNode) PubKey() route.Vertex {
return n.pubKeyBytes
}

// Features returns the node's features.
func (w *graphCacheNode) Features() *lnwire.FeatureVector {
return w.lnNode.Features
func (n *graphCacheNode) Features() *lnwire.FeatureVector {
return n.features
}

// ForEachChannel iterates through all channels of this node, executing the
Expand All @@ -2553,11 +2609,11 @@ func (w *graphCacheNode) Features() *lnwire.FeatureVector {
// halted with the error propagated back up to the caller.
//
// Unknown policies are passed into the callback as nil values.
func (w *graphCacheNode) ForEachChannel(tx kvdb.RTx,
func (n *graphCacheNode) ForEachChannel(tx kvdb.RTx,
cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy,
*ChannelEdgePolicy) error) error {

return w.lnNode.ForEachChannel(tx, cb)
return nodeTraversal(tx, n.pubKeyBytes[:], nil, cb)
}

var _ GraphCacheNode = (*graphCacheNode)(nil)
Expand Down Expand Up @@ -3865,6 +3921,53 @@ func fetchLightningNode(nodeBucket kvdb.RBucket,
return deserializeLightningNode(nodeReader)
}

func deserializeLightningNodeCacheable(r io.Reader, node *graphCacheNode) error {
// Always populate a feature vector, even if we don't have a node
// announcement and short circuit below.
node.features = lnwire.EmptyFeatureVector()

// Skip ahead:
// - LastUpdate (8 bytes)
if _, err := r.Read(node.nodeScratch[:]); err != nil {
return err
}

if _, err := io.ReadFull(r, node.pubKeyBytes[:]); err != nil {
return err
}

// Read the node announcement flag.
if _, err := r.Read(node.nodeScratch[:2]); err != nil {
return err
}
hasNodeAnn := byteOrder.Uint16(node.nodeScratch[:2])

// The rest of the data is optional, and will only be there if we got a
// node announcement for this node.
if hasNodeAnn == 0 {
return nil
}

// We did get a node announcement for this node, so we'll have the rest
// of the data available.
var rgb uint8
if err := binary.Read(r, byteOrder, &rgb); err != nil {
return err
}
if err := binary.Read(r, byteOrder, &rgb); err != nil {
return err
}
if err := binary.Read(r, byteOrder, &rgb); err != nil {
return err
}

if _, err := wire.ReadVarString(r, 0); err != nil {
return err
}

return node.features.Decode(r)
}

func deserializeLightningNode(r io.Reader) (LightningNode, error) {
var (
node LightningNode
Expand Down
103 changes: 101 additions & 2 deletions channeldb/graph_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
Expand Down Expand Up @@ -1145,6 +1146,58 @@ func TestGraphTraversal(t *testing.T) {
require.Equal(t, numChannels, numNodeChans)
}

// TestGraphTraversalCacheable tests that the memory optimized node traversal is
// working correctly.
func TestGraphTraversalCacheable(t *testing.T) {
t.Parallel()

graph, cleanUp, err := MakeTestGraph()
defer cleanUp()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}

// We'd like to test some of the graph traversal capabilities within
// the DB, so we'll create a series of fake nodes to insert into the
// graph. And we'll create 5 channels between the first two nodes.
const numNodes = 20
const numChannels = 5
chanIndex, _ := fillTestGraph(t, graph, numNodes, numChannels)

// Create a map of all nodes with the iteration we know works (because
// it is tested in another test).
nodeMap := make(map[route.Vertex]struct{})
err = graph.ForEachNode(func(tx kvdb.RTx, n *LightningNode) error {
nodeMap[n.PubKeyBytes] = struct{}{}

return nil
})
require.NoError(t, err)
require.Len(t, nodeMap, numNodes)

// Iterate through all the known channels within the graph DB by
// iterating over each node, once again if the map is empty that
// indicates that all edges have properly been reached.
err = graph.ForEachNodeCacheable(
func(tx kvdb.RTx, node GraphCacheNode) error {
delete(nodeMap, node.PubKey())

return node.ForEachChannel(
tx, func(tx kvdb.RTx, info *ChannelEdgeInfo,
policy *ChannelEdgePolicy,
policy2 *ChannelEdgePolicy) error {

delete(chanIndex, info.ChannelID)
return nil
},
)
},
)
require.NoError(t, err)
require.Len(t, nodeMap, 0)
require.Len(t, chanIndex, 0)
}

func TestGraphCacheTraversal(t *testing.T) {
t.Parallel()

Expand All @@ -1164,6 +1217,8 @@ func TestGraphCacheTraversal(t *testing.T) {
// properly been reached.
numNodeChans := 0
for _, node := range nodeList {
node := node

err = graph.graphCache.ForEachChannel(
node.PubKeyBytes, func(d *DirectedChannel) error {
delete(chanIndex, d.ChannelID)
Expand Down Expand Up @@ -1197,7 +1252,7 @@ func TestGraphCacheTraversal(t *testing.T) {
require.Equal(t, numChannels*2*(numNodes-1), numNodeChans)
}

func fillTestGraph(t *testing.T, graph *ChannelGraph, numNodes,
func fillTestGraph(t require.TestingT, graph *ChannelGraph, numNodes,
numChannels int) (map[uint64]struct{}, []*LightningNode) {

nodes := make([]*LightningNode, numNodes)
Expand Down Expand Up @@ -1237,7 +1292,7 @@ func fillTestGraph(t *testing.T, graph *ChannelGraph, numNodes,

for i := 0; i < numChannels; i++ {
txHash := sha256.Sum256([]byte{byte(i)})
chanID := uint64((n << 4) + i + 1)
chanID := uint64((n << 8) + i + 1)
op := wire.OutPoint{
Hash: txHash,
Index: 0,
Expand Down Expand Up @@ -3592,3 +3647,47 @@ func TestBatchedUpdateEdgePolicy(t *testing.T) {
require.Nil(t, err)
}
}

// BenchmarkForEachChannel is a benchmark test that measures the number of
// allocations and the total memory consumed by the full graph traversal.
func BenchmarkForEachChannel(b *testing.B) {
graph, cleanUp, err := MakeTestGraph()
require.Nil(b, err)
defer cleanUp()

const numNodes = 100
const numChannels = 4
_, _ = fillTestGraph(b, graph, numNodes, numChannels)

b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
var (
totalCapacity btcutil.Amount
maxHTLCs lnwire.MilliSatoshi
)
err := graph.ForEachNodeCacheable(
func(tx kvdb.RTx, n GraphCacheNode) error {
return n.ForEachChannel(
tx, func(tx kvdb.RTx,
info *ChannelEdgeInfo,
policy *ChannelEdgePolicy,
policy2 *ChannelEdgePolicy) error {

// We need to do something with
// the data here, otherwise the
// compiler is going to optimize
// this away, and we get bogus
// results.
totalCapacity += info.Capacity
maxHTLCs += policy.MaxHTLC
maxHTLCs += policy2.MaxHTLC

return nil
},
)
},
)
require.NoError(b, err)
}
}

0 comments on commit 6240851

Please sign in to comment.