
refactor(share/availability): simplify light availability #3895

Merged 2 commits on Oct 31, 2024
8 changes: 0 additions & 8 deletions share/availability/full/availability.go
@@ -50,14 +50,6 @@ func (fa *ShareAvailability) SharesAvailable(ctx context.Context, header *header
 		return nil
 	}
 
-	// we assume the caller of this method has already performed basic validation on the
-	// given roots. If for some reason this has not happened, the node should panic.
-	if err := dah.ValidateBasic(); err != nil {
-		log.Errorw("Availability validation cannot be performed on a malformed DataAvailabilityHeader",
-			"err", err)
-		panic(err)
-	}
-
 	// a hack to avoid loading the whole EDS in mem if we store it already.
 	if ok, _ := fa.store.HasByHeight(ctx, header.Height()); ok {
 		return nil
38 changes: 16 additions & 22 deletions share/availability/light/availability.go
@@ -2,6 +2,7 @@ package light
 
 import (
 	"context"
+	"encoding/json"
 	"errors"
 	"fmt"
 	"sync"
@@ -17,9 +18,9 @@ import (
 )
 
 var (
-	log                     = logging.Logger("share/light")
-	cacheAvailabilityPrefix = datastore.NewKey("sampling_result")
-	writeBatchSize          = 2048
+	log                   = logging.Logger("share/light")
+	samplingResultsPrefix = datastore.NewKey("sampling_result")
+	writeBatchSize        = 2048
 )
 
 // ShareAvailability implements share.Availability using Data Availability Sampling technique.
@@ -30,9 +31,6 @@ type ShareAvailability struct {
 	getter shwap.Getter
 	params Parameters
 
-	// TODO(@Wondertan): Once we come to parallelized DASer, this lock becomes a contention point
-	// Related to #483
-	// TODO: Striped locks? :D
 	dsLk sync.RWMutex
 	ds   *autobatch.Datastore
 }
@@ -44,7 +42,7 @@ func NewShareAvailability(
 	opts ...Option,
 ) *ShareAvailability {
 	params := *DefaultParameters()
-	ds = namespace.Wrap(ds, cacheAvailabilityPrefix)
+	ds = namespace.Wrap(ds, samplingResultsPrefix)
 	autoDS := autobatch.NewAutoBatching(ds, writeBatchSize)
 
 	for _, opt := range opts {
@@ -68,7 +66,7 @@ func (la *ShareAvailability) SharesAvailable(ctx context.Context, header *header
 	}
 
 	// load snapshot of the last sampling errors from disk
-	key := rootKey(dah)
+	key := datastoreKeyForRoot(dah)
 	la.dsLk.RLock()
 	last, err := la.ds.Get(ctx, key)
 	la.dsLk.RUnlock()
@@ -84,37 +82,30 @@ func (la *ShareAvailability) SharesAvailable(ctx context.Context, header *header
 		return err
 	case errors.Is(err, datastore.ErrNotFound):
 		// No sampling result found, select new samples
-		samples, err = SampleSquare(len(dah.RowRoots), int(la.params.SampleAmount))
-		if err != nil {
-			return err
-		}
+		samples = selectRandomSamples(len(dah.RowRoots), int(la.params.SampleAmount))
 	default:
 		// Sampling result found, unmarshal it
-		samples, err = decodeSamples(last)
+		err = json.Unmarshal(last, &samples)
 		if err != nil {
 			return err
 		}
 	}
 
-	if err := dah.ValidateBasic(); err != nil {
-		return err
-	}
-
 	var (
 		failedSamplesLock sync.Mutex
 		failedSamples     []Sample
 	)
 
-	log.Debugw("starting sampling session", "root", dah.String())
+	log.Debugw("starting sampling session", "height", header.Height())
 	var wg sync.WaitGroup
 	for _, s := range samples {
 		wg.Add(1)
 		go func(s Sample) {
 			defer wg.Done()
 			// check if the sample is available
-			_, err := la.getter.GetShare(ctx, header, int(s.Row), int(s.Col))
+			_, err := la.getter.GetShare(ctx, header, s.Row, s.Col)
 			if err != nil {
-				log.Debugw("error fetching share", "root", dah.String(), "row", s.Row, "col", s.Col)
+				log.Debugw("error fetching share", "height", header.Height(), "row", s.Row, "col", s.Col)
 				failedSamplesLock.Lock()
 				failedSamples = append(failedSamples, s)
 				failedSamplesLock.Unlock()
@@ -124,7 +115,10 @@ func (la *ShareAvailability) SharesAvailable(ctx context.Context, header *header
 	wg.Wait()
 
 	// store the result of the sampling session
-	bs := encodeSamples(failedSamples)
+	bs, err := json.Marshal(failedSamples)
+	if err != nil {
+		return fmt.Errorf("failed to marshal sampling result: %w", err)
+	}
 	la.dsLk.Lock()
 	err = la.ds.Put(ctx, key, bs)
 	la.dsLk.Unlock()
@@ -145,7 +139,7 @@ func (la *ShareAvailability) SharesAvailable(ctx context.Context, header *header
 	return nil
 }
 
-func rootKey(root *share.AxisRoots) datastore.Key {
+func datastoreKeyForRoot(root *share.AxisRoots) datastore.Key {
 	return datastore.NewKey(root.String())
 }
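For reference, the persistence flow that the new SharesAvailable implements above can be exercised in isolation. The sketch below is illustrative only: the key string and sample values are made up, and the real code derives the key from the AxisRoots string. It marshals failed samples to JSON, stores them in a go-datastore, and reloads them the way a retry session would:

package main

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"

	"github.com/ipfs/go-datastore"
)

// Sample mirrors the JSON-tagged struct from the diff.
type Sample struct {
	Row int `json:"row"`
	Col int `json:"col"`
}

func main() {
	ctx := context.Background()
	ds := datastore.NewMapDatastore()
	key := datastore.NewKey("example-root") // real code: datastore.NewKey(root.String())

	// First session: nothing stored yet, so ErrNotFound triggers fresh sample selection.
	if _, err := ds.Get(ctx, key); errors.Is(err, datastore.ErrNotFound) {
		fmt.Println("no prior sampling result, selecting new samples")
	}

	// The session ends with some failed samples; persist them as JSON.
	failed := []Sample{{Row: 3, Col: 7}, {Row: 12, Col: 0}}
	bs, err := json.Marshal(failed)
	if err != nil {
		panic(err)
	}
	if err := ds.Put(ctx, key, bs); err != nil {
		panic(err)
	}

	// A retry session loads the snapshot and samples only those coordinates.
	last, err := ds.Get(ctx, key)
	if err != nil {
		panic(err)
	}
	var samples []Sample
	if err := json.Unmarshal(last, &samples); err != nil {
		panic(err)
	}
	fmt.Printf("retrying %d previously failed samples: %+v\n", len(samples), samples)
}

One design note: JSON is larger on disk than the old 4-byte-per-sample binary encoding, but a stored result holds at most SampleAmount entries, and the format is self-describing and tolerant of added fields.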
56 changes: 30 additions & 26 deletions share/availability/light/availability_test.go
@@ -3,6 +3,7 @@ package light
 import (
 	"context"
 	_ "embed"
+	"encoding/json"
 	"sync"
 	"testing"
 
@@ -22,7 +23,7 @@ import (
 	"github.com/celestiaorg/celestia-node/share/shwap/p2p/shrex"
 )
 
-func TestSharesAvailableCaches(t *testing.T) {
+func TestSharesAvailableSuccess(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 
@@ -48,27 +49,29 @@ func TestSharesAvailableCaches(t *testing.T) {
 	ds := datastore.NewMapDatastore()
 	avail := NewShareAvailability(getter, ds)
 
-	// cache doesn't have eds yet
-	has, err := avail.ds.Has(ctx, rootKey(roots))
+	// Ensure the datastore doesn't have the sampling result yet
+	has, err := avail.ds.Has(ctx, datastoreKeyForRoot(roots))
 	require.NoError(t, err)
 	require.False(t, has)
 
 	err = avail.SharesAvailable(ctx, eh)
 	require.NoError(t, err)
 
-	// is now stored success result
-	result, err := avail.ds.Get(ctx, rootKey(roots))
+	// Verify that the sampling result is stored with all samples marked as available
+	result, err := avail.ds.Get(ctx, datastoreKeyForRoot(roots))
 	require.NoError(t, err)
-	failed, err := decodeSamples(result)
+
+	var failed []Sample
+	err = json.Unmarshal(result, &failed)
 	require.NoError(t, err)
 	require.Empty(t, failed)
 }
 
-func TestSharesAvailableHitsCache(t *testing.T) {
+func TestSharesAvailableSkipSampled(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 
-	// create getter that always return ErrNotFound
+	// Create a getter that always returns ErrNotFound
 	getter := mock.NewMockGetter(gomock.NewController(t))
 	getter.EXPECT().
 		GetShare(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
@@ -86,16 +89,19 @@ func TestSharesAvailableHitsCache(t *testing.T) {
 	err := avail.SharesAvailable(ctx, eh)
 	require.ErrorIs(t, err, share.ErrNotAvailable)
 
-	// put success result in cache
-	err = avail.ds.Put(ctx, rootKey(roots), []byte{})
+	// Store a successful sampling result in the datastore
+	failed := []Sample{}
+	data, err := json.Marshal(failed)
+	require.NoError(t, err)
+	err = avail.ds.Put(ctx, datastoreKeyForRoot(roots), data)
 	require.NoError(t, err)
 
-	// should hit cache after putting
+	// SharesAvailable should now return no error since the success sampling result is stored
 	err = avail.SharesAvailable(ctx, eh)
 	require.NoError(t, err)
 }
 
-func TestSharesAvailableEmptyRoot(t *testing.T) {
+func TestSharesAvailableEmptyEDS(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 
@@ -117,42 +123,40 @@ func TestSharesAvailableFailed(t *testing.T) {
 	ds := datastore.NewMapDatastore()
 	avail := NewShareAvailability(getter, ds)
 
-	// create new eds, that is not available by getter
+	// Create new eds, that is not available by getter
 	eds := edstest.RandEDS(t, 16)
 	roots, err := share.NewAxisRoots(eds)
 	require.NoError(t, err)
 	eh := headertest.RandExtendedHeaderWithRoot(t, roots)
 
-	// getter doesn't have the eds, so it should fail
+	// Getter doesn't have the eds, so it should fail for all samples
 	getter.EXPECT().
 		GetShare(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
 		Return(libshare.Share{}, shrex.ErrNotFound).
 		AnyTimes()
 	err = avail.SharesAvailable(ctx, eh)
 	require.ErrorIs(t, err, share.ErrNotAvailable)
 
-	// cache should have failed results now
-	result, err := avail.ds.Get(ctx, rootKey(roots))
+	// The datastore should now contain the sampling result with all samples in Remaining
+	result, err := avail.ds.Get(ctx, datastoreKeyForRoot(roots))
 	require.NoError(t, err)
 
-	failed, err := decodeSamples(result)
+	var failed []Sample
+	err = json.Unmarshal(result, &failed)
 	require.NoError(t, err)
 	require.Len(t, failed, int(avail.params.SampleAmount))
 
-	// ensure that retry persists the failed samples selection
-	// create new getter with only the failed samples available, and add them to the onceGetter
-	onceGetter := newOnceGetter()
-	onceGetter.AddSamples(failed)
-
-	// replace getter with the new one
-	avail.getter = onceGetter
+	// Simulate a getter that now returns shares successfully
+	successfulGetter := newOnceGetter()
+	successfulGetter.AddSamples(failed)
+	avail.getter = successfulGetter
 
 	// should be able to retrieve all the failed samples now
 	err = avail.SharesAvailable(ctx, eh)
 	require.NoError(t, err)
 
 	// onceGetter should have no more samples stored after the call
-	require.Empty(t, onceGetter.available)
+	require.Empty(t, successfulGetter.available)
 }
 
 type onceGetter struct {
@@ -178,7 +182,7 @@ func (m onceGetter) AddSamples(samples []Sample) {
 func (m onceGetter) GetShare(_ context.Context, _ *header.ExtendedHeader, row, col int) (libshare.Share, error) {
 	m.Lock()
 	defer m.Unlock()
-	s := Sample{Row: uint16(row), Col: uint16(col)}
+	s := Sample{Row: row, Col: col}
 	if _, ok := m.available[s]; ok {
 		delete(m.available, s)
 		return libshare.Share{}, nil
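The onceGetter mock is only partially visible in the diff above, so here is a self-contained sketch of the same pattern with a deliberately simplified GetShare signature (the header.ExtendedHeader argument and libshare.Share return type from the real test are dropped): every registered sample can be fetched exactly once, so an empty available map after the retry run proves that exactly the persisted failed samples were re-fetched, and nothing else.

package main

import (
	"context"
	"errors"
	"fmt"
	"sync"
)

type Sample struct {
	Row int `json:"row"`
	Col int `json:"col"`
}

// onceGetter hands out each registered sample exactly once.
type onceGetter struct {
	*sync.Mutex
	available map[Sample]struct{}
}

func newOnceGetter() onceGetter {
	return onceGetter{
		Mutex:     &sync.Mutex{},
		available: make(map[Sample]struct{}),
	}
}

func (m onceGetter) AddSamples(samples []Sample) {
	m.Lock()
	defer m.Unlock()
	for _, s := range samples {
		m.available[s] = struct{}{}
	}
}

// GetShare succeeds for a coordinate once, then forgets it.
func (m onceGetter) GetShare(_ context.Context, row, col int) (string, error) {
	m.Lock()
	defer m.Unlock()
	s := Sample{Row: row, Col: col}
	if _, ok := m.available[s]; ok {
		delete(m.available, s)
		return "share", nil
	}
	return "", errors.New("share not found")
}

func main() {
	g := newOnceGetter()
	g.AddSamples([]Sample{{Row: 1, Col: 2}})

	_, err := g.GetShare(context.Background(), 1, 2)
	fmt.Println("first fetch:", err) // <nil>
	_, err = g.GetShare(context.Background(), 1, 2)
	fmt.Println("second fetch:", err) // share not found
	fmt.Println("samples left:", len(g.available))
}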
100 changes: 19 additions & 81 deletions share/availability/light/sample.go
@@ -1,104 +1,42 @@
-// TODO(@Wondertan): Instead of doing sampling over the coordinates do a random walk over NMT trees.
 package light
 
 import (
 	crand "crypto/rand"
-	"encoding/binary"
-	"errors"
 	"math/big"
+
+	"golang.org/x/exp/maps"
 )
 
-// Sample is a point in 2D space over square.
+// Sample represents a coordinate in a 2D data square.
type Sample struct {
-	Row, Col uint16
+	Row int `json:"row"`
+	Col int `json:"col"`
 }
 
-// SampleSquare randomly picks *num* unique points from the given *width* square
-// and returns them as samples.
-func SampleSquare(squareWidth, num int) ([]Sample, error) {
-	ss := newSquareSampler(squareWidth, num)
-	err := ss.generateSample(num)
-	if err != nil {
-		return nil, err
+// selectRandomSamples randomly picks unique coordinates from a square of given size.
+func selectRandomSamples(squareSize, sampleCount int) []Sample {
+	total := squareSize * squareSize
+	if sampleCount > total {
+		sampleCount = total
 	}
-	return ss.samples(), nil
-}
-
-type squareSampler struct {
-	squareWidth int
-	smpls       map[Sample]struct{}
-}
-
-func newSquareSampler(squareWidth, expectedSamples int) *squareSampler {
-	return &squareSampler{
-		squareWidth: squareWidth,
-		smpls:       make(map[Sample]struct{}, expectedSamples),
-	}
-}
-
-// generateSample randomly picks unique point on a 2D spaces.
-func (ss *squareSampler) generateSample(num int) error {
-	if num > ss.squareWidth*ss.squareWidth {
-		num = ss.squareWidth
-	}
-
-	done := 0
-	for done < num {
+	samples := make(map[Sample]struct{}, sampleCount)
+	for len(samples) < sampleCount {
 		s := Sample{
-			Row: randInt(ss.squareWidth),
-			Col: randInt(ss.squareWidth),
+			Row: randInt(squareSize),
+			Col: randInt(squareSize),
 		}
-
-		if _, ok := ss.smpls[s]; ok {
-			continue
-		}
-
-		done++
-		ss.smpls[s] = struct{}{}
+		samples[s] = struct{}{}
 	}
-
-	return nil
-}
-
-func (ss *squareSampler) samples() []Sample {
-	samples := make([]Sample, 0, len(ss.smpls))
-	for s := range ss.smpls {
-		samples = append(samples, s)
-	}
-	return samples
+	return maps.Keys(samples)
 }
 
-func randInt(max int) uint16 {
+func randInt(max int) int {
 	n, err := crand.Int(crand.Reader, big.NewInt(int64(max)))
 	if err != nil {
 		panic(err) // won't panic as rand.Reader is endless
 	}
 
-	return uint16(n.Uint64())
-}
-
-// encodeSamples encodes a slice of samples into a byte slice using little endian encoding.
-func encodeSamples(samples []Sample) []byte {
-	bs := make([]byte, 0, len(samples)*4)
-	for _, s := range samples {
-		bs = binary.LittleEndian.AppendUint16(bs, s.Row)
-		bs = binary.LittleEndian.AppendUint16(bs, s.Col)
-	}
-	return bs
-}
-
-// decodeSamples decodes a byte slice into a slice of samples.
-func decodeSamples(bs []byte) ([]Sample, error) {
-	if len(bs)%4 != 0 {
-		return nil, errors.New("invalid byte slice length")
-	}
-
-	samples := make([]Sample, 0, len(bs)/4)
-	for i := 0; i < len(bs); i += 4 {
-		samples = append(samples, Sample{
-			Row: binary.LittleEndian.Uint16(bs[i : i+2]),
-			Col: binary.LittleEndian.Uint16(bs[i+2 : i+4]),
-		})
-	}
-	return samples, nil
+	// n.Uint64() is safe as max is int
+	return int(n.Uint64())
 }
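The rewritten sampler is small enough to lift out and run standalone. In the sketch below, the map keyed by Sample deduplicates coordinates, the loop runs until enough unique points are collected, and maps.Keys flattens the set in arbitrary order; the 32x32 square and 16-sample count in main are illustrative values, not defaults taken from this PR.

package main

import (
	crand "crypto/rand"
	"fmt"
	"math/big"

	"golang.org/x/exp/maps"
)

type Sample struct {
	Row int `json:"row"`
	Col int `json:"col"`
}

// selectRandomSamples mirrors the new implementation: uniqueness comes for free
// from the map, and clamping sampleCount to the square area keeps the loop finite.
func selectRandomSamples(squareSize, sampleCount int) []Sample {
	total := squareSize * squareSize
	if sampleCount > total {
		sampleCount = total
	}
	samples := make(map[Sample]struct{}, sampleCount)
	for len(samples) < sampleCount {
		s := Sample{Row: randInt(squareSize), Col: randInt(squareSize)}
		samples[s] = struct{}{}
	}
	return maps.Keys(samples)
}

// randInt draws a uniform int in [0, max) from crypto/rand; max must be positive.
func randInt(max int) int {
	n, err := crand.Int(crand.Reader, big.NewInt(int64(max)))
	if err != nil {
		panic(err) // rand.Reader does not fail in practice
	}
	return int(n.Uint64())
}

func main() {
	// 16 unique coordinates from a 32x32 extended data square.
	for _, s := range selectRandomSamples(32, 16) {
		fmt.Printf("row=%d col=%d\n", s.Row, s.Col)
	}
}

Worth noting: the clamp to squareSize*squareSize also fixes a subtle bug in the old generateSample, which capped num at squareWidth rather than the full square area.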