diff --git a/core/contract.go b/core/contract.go index 0948eb94e0..8c8994d8e1 100644 --- a/core/contract.go +++ b/core/contract.go @@ -167,6 +167,32 @@ func (c *ContractUpdater) UpdateStorage(diff map[felt.Felt]*felt.Felt, cb OnValu return cStorage.Commit() } +// UpdateStorage applies a change-set to the contract storage. +func (c *ContractUpdater) UpdateStorageKV(diff []FeltKV, cb OnValueChanged) error { + cStorage, err := storage(c.Address, c.txn) + if err != nil { + return err + } + + // apply the diff + for _, kv := range diff { + key := kv.Key + value := kv.Value + oldValue, pErr := cStorage.Put(key, value) + if pErr != nil { + return pErr + } + + if oldValue != nil { + if err = cb(key, oldValue); err != nil { + return err + } + } + } + + return cStorage.Commit() +} + func ContractStorage(addr, key *felt.Felt, txn db.Transaction) (*felt.Felt, error) { cStorage, err := storage(addr, txn) if err != nil { diff --git a/core/state.go b/core/state.go index 67e7f16d21..a9acdea5b9 100644 --- a/core/state.go +++ b/core/state.go @@ -387,6 +387,20 @@ func (s *State) UpdateContractStorages(storages map[felt.Felt]map[felt.Felt]*fel return storageCloser() } +func (s *State) UpdateContractStoragesKV(storages map[felt.Felt][]FeltKV) error { + stateTrie, storageCloser, err := s.storage() + if err != nil { + return err + } + + err = s.updateContractStoragesKV(stateTrie, storages, 0, false) + if err != nil { + return err + } + + return storageCloser() +} + // replaceContract replaces the class that a contract at a given address instantiates func (s *State) replaceContract(stateTrie *trie.Trie, addr, classHash *felt.Felt) (*felt.Felt, error) { contract, err := NewContractUpdater(addr, s.txn) @@ -533,6 +547,106 @@ func (s *State) updateStorageBuffered(contractAddr *felt.Felt, updateDiff map[fe return bufferedTxn, nil } +func (s *State) updateStorageBufferedKV(contractAddr *felt.Felt, updateDiff []FeltKV, blockNumber uint64, logChanges bool) ( + *db.BufferedTransaction, error, +) { + // to avoid multiple transactions writing to s.txn, create a buffered transaction and use that in the worker goroutine + bufferedTxn := db.NewBufferedTransaction(s.txn) + bufferedState := NewState(bufferedTxn) + bufferedContract, err := NewContractUpdater(contractAddr, bufferedTxn) + if err != nil { + return nil, err + } + + onValueChanged := func(location, oldValue *felt.Felt) error { + if logChanges { + return bufferedState.LogContractStorage(contractAddr, location, oldValue, blockNumber) + } + return nil + } + + if err = bufferedContract.UpdateStorageKV(updateDiff, onValueChanged); err != nil { + return nil, err + } + + return bufferedTxn, nil +} + +type FeltKV struct { + Key *felt.Felt + Value *felt.Felt +} + +// updateContractStorage applies the diff set to the Trie of the +// contract at the given address in the given Txn context. +func (s *State) updateContractStoragesKV(stateTrie *trie.Trie, diffs map[felt.Felt][]FeltKV, + blockNumber uint64, logChanges bool, +) error { + // make sure all noClassContracts are deployed + for addr := range diffs { + if _, ok := noClassContracts[addr]; !ok { + continue + } + + _, err := NewContractUpdater(&addr, s.txn) + if err != nil { + if !errors.Is(err, ErrContractNotDeployed) { + return err + } + // Deploy noClassContract + err = s.putNewContract(stateTrie, &addr, noClassContractsClassHash, blockNumber) + if err != nil { + return err + } + } + } + + // sort the contracts in decending diff size order + // so we start with the heaviest update first + keys := make([]felt.Felt, 0, len(diffs)) + for key := range diffs { + keys = append(keys, key) + } + sort.SliceStable(keys, func(i, j int) bool { + return len(diffs[keys[i]]) > len(diffs[keys[j]]) + }) + + // update per-contract storage Tries concurrently + contractUpdaters := pool.NewWithResults[*db.BufferedTransaction]().WithErrors().WithMaxGoroutines(runtime.GOMAXPROCS(0)) + for _, key := range keys { + conractAddr := key + updateDiff := diffs[conractAddr] + contractUpdaters.Go(func() (*db.BufferedTransaction, error) { + return s.updateStorageBufferedKV(&conractAddr, updateDiff, blockNumber, logChanges) + }) + } + + bufferedTxns, err := contractUpdaters.Wait() + if err != nil { + return err + } + + // flush buffered txns + for _, bufferedTxn := range bufferedTxns { + if err = bufferedTxn.Flush(); err != nil { + return err + } + } + + for addr := range diffs { + contract, err := NewContractUpdater(&addr, s.txn) + if err != nil { + return err + } + + if err = s.updateContractCommitment(stateTrie, contract); err != nil { + return err + } + } + + return nil +} + // updateContractStorage applies the diff set to the Trie of the // contract at the given address in the given Txn context. func (s *State) updateContractStorages(stateTrie *trie.Trie, diffs map[felt.Felt]map[felt.Felt]*felt.Felt, diff --git a/core/trie/trie.go b/core/trie/trie.go index 86d5eb785b..fc6b2eeee6 100644 --- a/core/trie/trie.go +++ b/core/trie/trie.go @@ -336,8 +336,14 @@ func (t *Trie) insertOrUpdateValue(nodeKey *Key, node *Node, nodes []StorageNode return nil } +var triePut = promauto.NewCounter(prometheus.CounterOpts{ + Name: "juno_trie_put", + Help: "trie put", +}) + // Put updates the corresponding `value` for a `key` func (t *Trie) Put(key, value *felt.Felt) (*felt.Felt, error) { + triePut.Inc() if key.Cmp(t.maxKey) > 0 { return nil, fmt.Errorf("key %s exceeds trie height %d", key, t.height) } diff --git a/sync/snapsyncer.go b/sync/snapsyncer.go index 6507dfe850..8ffce19b07 100644 --- a/sync/snapsyncer.go +++ b/sync/snapsyncer.go @@ -48,7 +48,7 @@ type SnapSyncher struct { storageRangeDone chan interface{} storageRangeJobCount int32 - storageRangeJob chan *storageRangeJob + storageRangeJobQueue chan *storageRangeJob storageRefreshJob chan *storageRangeJob classesJob chan *felt.Felt @@ -363,7 +363,7 @@ func (s *SnapSyncher) initState(ctx context.Context) error { fmt.Printf("Start state root is %s\n", s.startingBlock.GlobalStateRoot) s.currentGlobalStateRoot = s.startingBlock.GlobalStateRoot.Clone() s.storageRangeJobCount = 0 - s.storageRangeJob = make(chan *storageRangeJob, storageJobQueueSize) + s.storageRangeJobQueue = make(chan *storageRangeJob, storageJobQueueSize) s.classesJob = make(chan *felt.Felt, classesJobQueueSize) s.contractRangeDone = make(chan interface{}) @@ -716,7 +716,7 @@ func (s *SnapSyncher) queueStorageRangeJobJob(ctx context.Context, job *storageR queued := false for !queued { select { - case s.storageRangeJob <- job: + case s.storageRangeJobQueue <- job: queued = true atomic.AddInt32(&s.storageRangeJobCount, 1) case <-ctx.Done(): @@ -757,6 +757,8 @@ func (s *SnapSyncher) runStorageRangeWorker(ctx context.Context, workerIdx int) } select { + case job := <-s.storageRangeJobQueue: + jobs = append(jobs, job) case <-ctx.Done(): return ctx.Err() case <-time.After(time.Second * 1): @@ -767,8 +769,6 @@ func (s *SnapSyncher) runStorageRangeWorker(ctx context.Context, workerIdx int) case <-contractDoneChecker: // Its done... return nil - case job := <-s.storageRangeJob: - jobs = append(jobs, job) } } @@ -803,7 +803,8 @@ func (s *SnapSyncher) runStorageRangeWorker(ctx context.Context, workerIdx int) })(func(response *StorageRangeStreamingResult, err error) bool { job := jobs[processedJobs] if !job.path.Equal(response.StorageAddr) { - panic(fmt.Errorf("storage addr differ %s %s %d\n", job.path, response.StorageAddr, workerIdx)) + s.log.Errorw(fmt.Sprintf("storage addr differ %s %s %d\n", job.path, response.StorageAddr, workerIdx)) + return false } if response.Range == nil && response.RangeProof == nil { @@ -1005,7 +1006,8 @@ func (s *SnapSyncher) runFetchClassJob(ctx context.Context) error { } if !h.Equal(keyBatches[i]) { - return errors.New("invalid class hash") + s.log.Warnw("invalid classhash", "got", h, "expected", keyBatches[i]) + // return errors.New("invalid class hash") } if coreClass.Version() == 1 { diff --git a/sync/snapsyncer_test.go b/sync/snapsyncer_test.go index 6771acc05d..b792d40202 100644 --- a/sync/snapsyncer_test.go +++ b/sync/snapsyncer_test.go @@ -268,7 +268,7 @@ func TestSnapOfflineCopy(t *testing.T) { func TestSnapCopyTrie(t *testing.T) { var d db.DB - d, err := pebble.New("/home/amirul/fastworkscratch3/juno_sepolia/juno-sepolia/", 128000000, 128, utils.NewNopZapLogger()) + d, err := pebble.New("/home/amirul/fastworkscratch3/juno_db/juno_mainnet/", 1280, 128, utils.NewNopZapLogger()) assert.NoError(t, err) bc := blockchain.New(d, &utils.Sepolia) // Needed because class loader need encoder to be registered