Skip to content

Commit

Permalink
Did something
Browse files Browse the repository at this point in the history
  • Loading branch information
asdacap committed Jun 24, 2024
1 parent bb52a09 commit 26786c0
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 8 deletions.
26 changes: 26 additions & 0 deletions core/contract.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,32 @@ func (c *ContractUpdater) UpdateStorage(diff map[felt.Felt]*felt.Felt, cb OnValu
return cStorage.Commit()
}

// UpdateStorage applies a change-set to the contract storage.
func (c *ContractUpdater) UpdateStorageKV(diff []FeltKV, cb OnValueChanged) error {
cStorage, err := storage(c.Address, c.txn)
if err != nil {
return err
}

// apply the diff
for _, kv := range diff {
key := kv.Key
value := kv.Value
oldValue, pErr := cStorage.Put(key, value)
if pErr != nil {
return pErr
}

if oldValue != nil {
if err = cb(key, oldValue); err != nil {
return err
}
}
}

return cStorage.Commit()
}

func ContractStorage(addr, key *felt.Felt, txn db.Transaction) (*felt.Felt, error) {
cStorage, err := storage(addr, txn)
if err != nil {
Expand Down
114 changes: 114 additions & 0 deletions core/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,20 @@ func (s *State) UpdateContractStorages(storages map[felt.Felt]map[felt.Felt]*fel
return storageCloser()
}

func (s *State) UpdateContractStoragesKV(storages map[felt.Felt][]FeltKV) error {
stateTrie, storageCloser, err := s.storage()
if err != nil {
return err
}

err = s.updateContractStoragesKV(stateTrie, storages, 0, false)
if err != nil {
return err
}

return storageCloser()
}

// replaceContract replaces the class that a contract at a given address instantiates
func (s *State) replaceContract(stateTrie *trie.Trie, addr, classHash *felt.Felt) (*felt.Felt, error) {
contract, err := NewContractUpdater(addr, s.txn)
Expand Down Expand Up @@ -533,6 +547,106 @@ func (s *State) updateStorageBuffered(contractAddr *felt.Felt, updateDiff map[fe
return bufferedTxn, nil
}

func (s *State) updateStorageBufferedKV(contractAddr *felt.Felt, updateDiff []FeltKV, blockNumber uint64, logChanges bool) (
*db.BufferedTransaction, error,
) {
// to avoid multiple transactions writing to s.txn, create a buffered transaction and use that in the worker goroutine
bufferedTxn := db.NewBufferedTransaction(s.txn)
bufferedState := NewState(bufferedTxn)
bufferedContract, err := NewContractUpdater(contractAddr, bufferedTxn)
if err != nil {
return nil, err
}

onValueChanged := func(location, oldValue *felt.Felt) error {
if logChanges {
return bufferedState.LogContractStorage(contractAddr, location, oldValue, blockNumber)
}
return nil
}

if err = bufferedContract.UpdateStorageKV(updateDiff, onValueChanged); err != nil {
return nil, err
}

return bufferedTxn, nil
}

type FeltKV struct {
Key *felt.Felt
Value *felt.Felt
}

// updateContractStorage applies the diff set to the Trie of the
// contract at the given address in the given Txn context.
func (s *State) updateContractStoragesKV(stateTrie *trie.Trie, diffs map[felt.Felt][]FeltKV,
blockNumber uint64, logChanges bool,
) error {
// make sure all noClassContracts are deployed
for addr := range diffs {
if _, ok := noClassContracts[addr]; !ok {
continue
}

_, err := NewContractUpdater(&addr, s.txn)
if err != nil {
if !errors.Is(err, ErrContractNotDeployed) {
return err
}
// Deploy noClassContract
err = s.putNewContract(stateTrie, &addr, noClassContractsClassHash, blockNumber)
if err != nil {
return err
}
}
}

// sort the contracts in decending diff size order
// so we start with the heaviest update first
keys := make([]felt.Felt, 0, len(diffs))
for key := range diffs {
keys = append(keys, key)
}
sort.SliceStable(keys, func(i, j int) bool {
return len(diffs[keys[i]]) > len(diffs[keys[j]])
})

// update per-contract storage Tries concurrently
contractUpdaters := pool.NewWithResults[*db.BufferedTransaction]().WithErrors().WithMaxGoroutines(runtime.GOMAXPROCS(0))
for _, key := range keys {
conractAddr := key
updateDiff := diffs[conractAddr]
contractUpdaters.Go(func() (*db.BufferedTransaction, error) {
return s.updateStorageBufferedKV(&conractAddr, updateDiff, blockNumber, logChanges)
})
}

bufferedTxns, err := contractUpdaters.Wait()
if err != nil {
return err
}

// flush buffered txns
for _, bufferedTxn := range bufferedTxns {
if err = bufferedTxn.Flush(); err != nil {
return err
}
}

for addr := range diffs {
contract, err := NewContractUpdater(&addr, s.txn)
if err != nil {
return err
}

if err = s.updateContractCommitment(stateTrie, contract); err != nil {
return err
}
}

return nil
}

// updateContractStorage applies the diff set to the Trie of the
// contract at the given address in the given Txn context.
func (s *State) updateContractStorages(stateTrie *trie.Trie, diffs map[felt.Felt]map[felt.Felt]*felt.Felt,
Expand Down
6 changes: 6 additions & 0 deletions core/trie/trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,8 +336,14 @@ func (t *Trie) insertOrUpdateValue(nodeKey *Key, node *Node, nodes []StorageNode
return nil
}

var triePut = promauto.NewCounter(prometheus.CounterOpts{
Name: "juno_trie_put",
Help: "trie put",
})

// Put updates the corresponding `value` for a `key`
func (t *Trie) Put(key, value *felt.Felt) (*felt.Felt, error) {
triePut.Inc()
if key.Cmp(t.maxKey) > 0 {
return nil, fmt.Errorf("key %s exceeds trie height %d", key, t.height)
}
Expand Down
16 changes: 9 additions & 7 deletions sync/snapsyncer.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ type SnapSyncher struct {
storageRangeDone chan interface{}

storageRangeJobCount int32
storageRangeJob chan *storageRangeJob
storageRangeJobQueue chan *storageRangeJob
storageRefreshJob chan *storageRangeJob

classesJob chan *felt.Felt
Expand Down Expand Up @@ -363,7 +363,7 @@ func (s *SnapSyncher) initState(ctx context.Context) error {
fmt.Printf("Start state root is %s\n", s.startingBlock.GlobalStateRoot)
s.currentGlobalStateRoot = s.startingBlock.GlobalStateRoot.Clone()
s.storageRangeJobCount = 0
s.storageRangeJob = make(chan *storageRangeJob, storageJobQueueSize)
s.storageRangeJobQueue = make(chan *storageRangeJob, storageJobQueueSize)
s.classesJob = make(chan *felt.Felt, classesJobQueueSize)

s.contractRangeDone = make(chan interface{})
Expand Down Expand Up @@ -716,7 +716,7 @@ func (s *SnapSyncher) queueStorageRangeJobJob(ctx context.Context, job *storageR
queued := false
for !queued {
select {
case s.storageRangeJob <- job:
case s.storageRangeJobQueue <- job:
queued = true
atomic.AddInt32(&s.storageRangeJobCount, 1)
case <-ctx.Done():
Expand Down Expand Up @@ -757,6 +757,8 @@ func (s *SnapSyncher) runStorageRangeWorker(ctx context.Context, workerIdx int)
}

select {
case job := <-s.storageRangeJobQueue:
jobs = append(jobs, job)
case <-ctx.Done():
return ctx.Err()
case <-time.After(time.Second * 1):
Expand All @@ -767,8 +769,6 @@ func (s *SnapSyncher) runStorageRangeWorker(ctx context.Context, workerIdx int)
case <-contractDoneChecker:
// Its done...
return nil
case job := <-s.storageRangeJob:
jobs = append(jobs, job)
}
}

Expand Down Expand Up @@ -803,7 +803,8 @@ func (s *SnapSyncher) runStorageRangeWorker(ctx context.Context, workerIdx int)
})(func(response *StorageRangeStreamingResult, err error) bool {
job := jobs[processedJobs]
if !job.path.Equal(response.StorageAddr) {
panic(fmt.Errorf("storage addr differ %s %s %d\n", job.path, response.StorageAddr, workerIdx))
s.log.Errorw(fmt.Sprintf("storage addr differ %s %s %d\n", job.path, response.StorageAddr, workerIdx))
return false
}

if response.Range == nil && response.RangeProof == nil {
Expand Down Expand Up @@ -1005,7 +1006,8 @@ func (s *SnapSyncher) runFetchClassJob(ctx context.Context) error {
}

if !h.Equal(keyBatches[i]) {
return errors.New("invalid class hash")
s.log.Warnw("invalid classhash", "got", h, "expected", keyBatches[i])
// return errors.New("invalid class hash")
}

if coreClass.Version() == 1 {
Expand Down
2 changes: 1 addition & 1 deletion sync/snapsyncer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ func TestSnapOfflineCopy(t *testing.T) {

func TestSnapCopyTrie(t *testing.T) {
var d db.DB
d, err := pebble.New("/home/amirul/fastworkscratch3/juno_sepolia/juno-sepolia/", 128000000, 128, utils.NewNopZapLogger())
d, err := pebble.New("/home/amirul/fastworkscratch3/juno_db/juno_mainnet/", 1280, 128, utils.NewNopZapLogger())
assert.NoError(t, err)

bc := blockchain.New(d, &utils.Sepolia) // Needed because class loader need encoder to be registered
Expand Down

0 comments on commit 26786c0

Please sign in to comment.