Skip to content

Commit

Permalink
refactor(stamper): extract common functions and improve code
Browse files Browse the repository at this point in the history
  • Loading branch information
gacevicljubisa committed Jan 21, 2025
1 parent 5fe77b3 commit 077d60a
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 131 deletions.
5 changes: 5 additions & 0 deletions pkg/bee/api/postage.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,8 @@ func (p *PostageService) GetChainState(ctx context.Context) (ChainStateResponse,
}
return resp, nil
}

func (batch *PostageStampResponse) BatchUsage() float64 {
maxUtilization := 1 << (batch.Depth - batch.BucketDepth) // 2^(depth - bucketDepth)
return (float64(batch.Utilization) / float64(maxUtilization)) * 100 // batch utilization between 0 and 100 percent
}
218 changes: 87 additions & 131 deletions pkg/stamper/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,188 +35,110 @@ func (n *node) Create(ctx context.Context, amount uint64, depth uint16) error {
}

func (n *node) Dilute(ctx context.Context, threshold float64, depthIncrement uint16, batchIds []string) error {
batches, err := n.client.Postage.PostageBatches(ctx)
batches, _, err := n.getPostageBatches(ctx, false)
if err != nil {
return fmt.Errorf("node %s: get postage batches: %w", n.name, err)
return err
}

for _, batch := range batches {
if len(batchIds) > 0 && !contains(batchIds, batch.BatchID) {
if !isValidBatch(&batch, batchIds) {
continue
}

if !batch.Usable || batch.Utilization == 0 {
continue
}

usageFactor := batch.Depth - batch.BucketDepth // depth - bucketDepth
divisor := float64(int(1) << usageFactor) // 2^(depth - bucketDepth)
stampsUsage := (float64(batch.Utilization) / divisor) * 100 // (utilization / 2^(depth - bucketDepth)) * 100

if stampsUsage >= threshold {
newDepth := uint16(batch.Depth) + depthIncrement

n.log.Tracef("node %s: batch %s: stamps usage %.2f%%, diluting to depth %d", n.name, batch.BatchID, stampsUsage, newDepth)

if err := n.client.Postage.DilutePostageBatch(ctx, batch.BatchID, uint64(newDepth), ""); err != nil {
return fmt.Errorf("node %s: dilute batch %s: %w", n.name, batch.BatchID, err)
}

n.log.Infof("node %s: diluted batch %s to depth %d", n.name, batch.BatchID, newDepth)
if batch.BatchUsage() >= threshold {
return n.handleDilution(ctx, batch, depthIncrement)
}
}

return nil
}

// Set performs both Topup and Dilute operations on postage batches in one function.
// The order of operations is critical because Dilute increases the batch depth,
// which directly affects the calculations for Topup by reducing the effective batch TTL.
// Therefore, Topup is handled first, considering the original depth, followed by Dilute
// which accounts for the new depth and utilization threshold.
func (n *node) Set(ctx context.Context, ttlThreshold time.Duration, topUpFinalTTL time.Duration, utilizationThreshold float64, extraDepth uint16, secondsPerBlock int64, batchIds []string) error {

chainState, err := n.client.Postage.GetChainState(ctx)
// Set performs Topup and Dilute operations on postage batches.
// Topup is handled first based on the original depth, followed by Dilute
// which considers the new depth and utilization threshold.
func (n *node) Set(
ctx context.Context,
ttlThreshold time.Duration,
topUpFinalTTL time.Duration,
utilizationThreshold float64,
extraDepth uint16,
secondsPerBlock int64,
batchIds []string,
) error {
batches, price, err := n.getPostageBatches(ctx, true)
if err != nil {
return fmt.Errorf("node %s: get chain state: %w", n.name, err)
}

price := chainState.CurrentPrice.Int64()
if price <= 0 {
return fmt.Errorf("node %s: invalid chain price: %d", n.name, price)
}

batches, err := n.client.Postage.PostageBatches(ctx)
if err != nil {
return fmt.Errorf("node %s: get postage batches: %w", n.name, err)
return err
}

for _, batch := range batches {
if len(batchIds) > 0 && !contains(batchIds, batch.BatchID) {
continue
}

if !batch.Usable || batch.Utilization == 0 {
if !isValidBatch(&batch, batchIds) {
continue
}

batchTTL := time.Duration(batch.BatchTTL) * time.Second
var topUpTTL time.Duration

if batchTTL <= 0 {
continue
}
needsDilution := batch.BatchUsage() >= utilizationThreshold

// calculate batch usage
maxUtilization := 1 << (batch.Depth - batch.BucketDepth) // 2^(depth - bucketDepth)
batchUsage := (float64(batch.Utilization) / float64(maxUtilization)) * 100 // batch utilization between 0 and 100 percent
needsDilution := false

// dilution needed
if batchUsage >= utilizationThreshold {
needsDilution = true
if needsDilution {
batchTTL = batchTTL / (1 << extraDepth) // reduce batch TTL by 2^extraDepth
}

if batchTTL > ttlThreshold {

if needsDilution {
goto DILUTE
}

if batchTTL > ttlThreshold && !needsDilution {
continue
}

// dilution needed
if batchUsage >= utilizationThreshold {
needsDilution = true
batchTTL = batchTTL / (1 << extraDepth) // reduce batch TTL by 2^extraDepth
}

// Topup
if secondsPerBlock <= 0 {
secondsPerBlock = 1 // avoid division by zero
}

topUpTTL = topUpFinalTTL - batchTTL

if topUpTTL > 0 {
amount := (int64(topUpTTL.Seconds()) / secondsPerBlock) * price // number of blocks * price per block

n.log.Tracef("node %s: batch %s: required duration %d, amount %d", n.name, batch.BatchID, topUpTTL, amount)

if err := n.client.Postage.TopUpPostageBatch(ctx, batch.BatchID, amount, ""); err != nil {
return fmt.Errorf("node %s: top-up batch %s: %w", n.name, batch.BatchID, err)
}

n.log.Infof("node %s: topped up batch %s with amount %d", n.name, batch.BatchID, amount)
if err := n.handleTopup(ctx, batch, ttlThreshold, topUpFinalTTL, batchTTL, secondsPerBlock, price); err != nil {
return err
}

DILUTE:
if needsDilution {
newDepth := uint16(batch.Depth) + extraDepth

n.log.Tracef("node %s: batch %s: stamps usage %.2f%%, diluting to depth %d", n.name, batch.BatchID, batchUsage, newDepth)

if err := n.client.Postage.DilutePostageBatch(ctx, batch.BatchID, uint64(newDepth), ""); err != nil {
return fmt.Errorf("node %s: dilute batch %s: %w", n.name, batch.BatchID, err)
}

n.log.Infof("node %s: diluted batch %s to depth %d", n.name, batch.BatchID, newDepth)
return n.handleDilution(ctx, batch, extraDepth)
}
}

return nil
}

func (n *node) Topup(ctx context.Context, ttlThreshold time.Duration, topupDuration time.Duration, blockTime int64, batchIds []string) error {
chainState, err := n.client.Postage.GetChainState(ctx)
if err != nil {
return fmt.Errorf("node %s: get chain state: %w", n.name, err)
}

price := chainState.CurrentPrice.Int64()
if price <= 0 {
return fmt.Errorf("node %s: invalid chain price: %d", n.name, price)
}

batches, err := n.client.Postage.PostageBatches(ctx)
func (n *node) Topup(ctx context.Context, ttlThreshold time.Duration, topUpFinalTTL time.Duration, secondsPerBlock int64, batchIds []string) error {
batches, price, err := n.getPostageBatches(ctx, true)
if err != nil {
return fmt.Errorf("node %s: get postage batches: %w", n.name, err)
return err
}

for _, batch := range batches {
if len(batchIds) > 0 && !contains(batchIds, batch.BatchID) {
if !isValidBatch(&batch, batchIds) {
continue
}

if !batch.Usable || batch.Utilization == 0 {
continue
}
batchTTL := time.Duration(batch.BatchTTL) * time.Second

batchTTL := time.Unix(batch.BatchTTL, 0)
if time.Until(batchTTL) <= ttlThreshold {
depth := batch.Depth - batch.BucketDepth
multiplier := int64(1 << depth)
return n.handleTopup(ctx, batch, ttlThreshold, topUpFinalTTL, batchTTL, secondsPerBlock, price)
}

if blockTime <= 0 {
blockTime = 1 // avoid division by zero
}
return nil
}

secondsToTopup := int64(topupDuration.Seconds())
timeLeft := batchTTL.Unix() - time.Now().Unix()
if timeLeft < 0 {
timeLeft = 0
}
func (n *node) handleDilution(ctx context.Context, batch api.PostageStampResponse, extraDepth uint16) error {
newDepth := uint16(batch.Depth) + extraDepth

requiredDuration := secondsToTopup - timeLeft
if requiredDuration <= 0 {
continue
}
n.log.Tracef("node %s: batch %s: usage %.2f%%, diluting to depth %d", n.name, batch.BatchID, batch.BatchUsage(), newDepth)

if err := n.client.Postage.DilutePostageBatch(ctx, batch.BatchID, uint64(newDepth), ""); err != nil {
return fmt.Errorf("node %s: dilute batch %s: %w", n.name, batch.BatchID, err)
}

n.log.Infof("node %s: diluted batch %s to depth %d", n.name, batch.BatchID, newDepth)

return nil
}

amount := (requiredDuration / blockTime) * multiplier * price
func (n *node) handleTopup(ctx context.Context, batch api.PostageStampResponse, ttlThreshold, topUpFinalTTL, batchTTL time.Duration, secondsPerBlock, price int64) error {
if batchTTL <= ttlThreshold {
topUpTTL := topUpFinalTTL - batchTTL
if topUpTTL > 0 {
amount := (int64(topUpTTL.Seconds()) / secondsPerBlock) * price

n.log.Tracef("node %s: batch %s: required duration %d, amount %d", n.name, batch.BatchID, requiredDuration, amount)
n.log.Tracef("node %s: batch %s: required duration %d, amount %d", n.name, batch.BatchID, topUpTTL, amount)

if err := n.client.Postage.TopUpPostageBatch(ctx, batch.BatchID, amount, ""); err != nil {
return fmt.Errorf("node %s: top-up batch %s: %w", n.name, batch.BatchID, err)
Expand All @@ -229,6 +151,40 @@ func (n *node) Topup(ctx context.Context, ttlThreshold time.Duration, topupDurat
return nil
}

func (n *node) getPostageBatches(ctx context.Context, needPrice bool) (batches []api.PostageStampResponse, price int64, err error) {
if needPrice {
chainState, err := n.client.Postage.GetChainState(ctx)
if err != nil {
return nil, 0, fmt.Errorf("node %s: get chain state: %w", n.name, err)
}

price = chainState.CurrentPrice.Int64()
if price <= 0 {
return nil, 0, fmt.Errorf("node %s: invalid chain price: %d", n.name, price)
}
}

batches, err = n.client.Postage.PostageBatches(ctx)
if err != nil {
return nil, 0, fmt.Errorf("node %s: get postage batches: %w", n.name, err)
}

return batches, price, nil
}

// isValidBatch checks if a batch should be processed
func isValidBatch(batch *api.PostageStampResponse, batchIDs []string) bool {
if !batch.Usable || batch.Utilization == 0 || batch.BatchTTL <= 0 {
return false
}

if len(batchIDs) > 0 && !contains(batchIDs, batch.BatchID) {
return false
}

return true
}

func contains(slice []string, value string) bool {
for _, v := range slice {
if v == value {
Expand Down

0 comments on commit 077d60a

Please sign in to comment.