Skip to content

Commit

Permalink
Merge pull request #9 from trisacrypto/sc-22752
Browse files Browse the repository at this point in the history
Backoff and Retry Client
  • Loading branch information
bbengfort authored Jan 3, 2024
2 parents 8a8d8b5 + 3dc820a commit f38ddfb
Show file tree
Hide file tree
Showing 19 changed files with 428 additions and 65 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: 1.20.x
go-version: 1.21.x

- name: Install Staticcheck
run: go install honnef.co/go/tools/cmd/[email protected]
Expand Down
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2023 trisacrypto
Copyright (c) 2023 TRISA

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
19 changes: 17 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,17 @@
# courier
A standalone certificate delivery service
# Courier

A stand-alone service that allows the GDS to deliver TRISA certificates via a webhook
rather than email. The service accepts PCKS12 passwords and encrypted certificates from
TRISA as HTTP `POST` requests and stores the certificates and passwords in either
Google Secret Manager or on the local disk (other secret management backends such as
Vault or Postgres may be available in the future).

This tool is mostly used by TRISA Service Providers (TSPs) who have to handle many
TRISA certificate deliveries at a time. VASPs who want to automate certificate delivery
may also use this service.

## Deploying with Docker

The simplest way to run the courier service is to use the docker image
`trisa/courier:latest` and to configure it from the environment. This allows the
courier service to be easily run on a Kubernetes cluster.
2 changes: 1 addition & 1 deletion cmd/courier/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ func storeCertificate(c *cli.Context) (err error) {

// Get a secret from the secret manager.
func getSecret(c *cli.Context) (err error) {
conf := config.SecretsConfig{
conf := config.GCPSecretsConfig{
Enabled: true,
Project: c.String("project"),
Credentials: c.String("credentials"),
Expand Down
2 changes: 1 addition & 1 deletion containers/courier/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Dynamic Builds
ARG BUILDER_IMAGE=golang:1.20-buster
ARG BUILDER_IMAGE=golang:1.21-buster
ARG FINAL_IMAGE=debian:buster-slim

# Build Stage
Expand Down
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
module github.com/trisacrypto/courier

go 1.20
go 1.21

require (
cloud.google.com/go/secretmanager v1.11.2
github.com/cenkalti/backoff/v4 v4.2.1
github.com/gin-gonic/gin v1.9.1
github.com/googleapis/gax-go v1.0.3
github.com/joho/godotenv v1.5.1
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.110.6 h1:8uYAkj3YHTP/1iwReuHPxLSbdcyc+dSBbzFMrVwDR6Q=
cloud.google.com/go v0.110.6/go.mod h1:+EYjdK8e5RME/VY/qLCAtuyALQ9q67dvuum8i+H5xsI=
cloud.google.com/go/compute v1.23.0 h1:tP41Zoavr8ptEqaW6j+LQOnyBBhO7OkOMAGrgLopTwY=
cloud.google.com/go/compute v1.23.0/go.mod h1:4tCnrn48xsqlwSAiLf1HXMQk8CONslYbdiEZc9FEIbM=
cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY=
Expand All @@ -16,6 +17,8 @@ github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kd
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM=
github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
Expand Down Expand Up @@ -48,6 +51,7 @@ github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
Expand Down
74 changes: 67 additions & 7 deletions pkg/api/v1/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,31 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"time"

"github.com/cenkalti/backoff/v4"
)

const DefaultRetries = 3

func DefaultBackoff() BackoffFactory {
return func() backoff.BackOff {
return backoff.NewExponentialBackOff()
}
}

// New creates a new API client that implements the CourierClient interface.
func New(endpoint string, opts ...ClientOption) (_ CourierClient, err error) {
if endpoint == "" {
return nil, ErrEndpointRequired
}

// Create a client with the parsed endpoint.
c := &APIv1{}
c := &APIv1{retries: -1}
if c.url, err = url.Parse(endpoint); err != nil {
return nil, err
}
Expand All @@ -39,13 +48,26 @@ func New(endpoint string, opts ...ClientOption) (_ CourierClient, err error) {
Timeout: 30 * time.Second,
}
}

// If backoff hasn't been specified add the default backoff factory
if c.backoff == nil {
c.backoff = DefaultBackoff()
}

// If retries haven't been specified add the default number of retries
if c.retries < 0 {
c.retries = DefaultRetries
}

return c, nil
}

// APIv1 implements the CourierClient interface.
type APIv1 struct {
url *url.URL
client *http.Client
url *url.URL
client *http.Client
backoff BackoffFactory
retries int
}

var _ CourierClient = &APIv1{}
Expand Down Expand Up @@ -172,8 +194,47 @@ func (c *APIv1) NewRequest(ctx context.Context, method, path string, data interf
}

// Do executes an http request against the server, performs error checking, and
// deserializes response data into the specified struct.
// deserializes response data into the specified struct. This function also manages
// retries using a backoff strategy.
func (s *APIv1) Do(req *http.Request, data interface{}, checkStatus bool) (rep *http.Response, err error) {
attempts := 0
start := time.Now()
ctx := req.Context()
delay := s.backoff()
errs := make([]error, 0, s.retries+1)

for attempts <= s.retries {
attempts++
if rep, err = s.do(req, data, checkStatus); err == nil {
// Success!
return rep, nil
}

// Failure! Retry as needed.
errs = append(errs, err)

// Compute the backoff delay before the next request
dur := delay.NextBackOff()
if dur == backoff.Stop {
// Stop indicates no more retries should be allowed.
return rep, JoinStatusErrors(attempts, time.Since(start), errs...)
}

// Wait for backoff delay or until context is canceled
wait := time.After(dur)
select {
case <-ctx.Done():
errs = append(errs, ctx.Err())
return rep, JoinStatusErrors(attempts, time.Since(start), errs...)
case <-wait:
continue
}
}

return rep, JoinStatusErrors(attempts, time.Since(start), errs...)
}

func (s *APIv1) do(req *http.Request, data interface{}, checkStatus bool) (rep *http.Response, err error) {
if rep, err = s.client.Do(req); err != nil {
return rep, err
}
Expand All @@ -189,8 +250,7 @@ func (s *APIv1) Do(req *http.Request, data interface{}, checkStatus bool) (rep *
return rep, NewStatusError(rep.StatusCode, reply.Error)
}
}

return rep, errors.New(rep.Status)
return rep, NewStatusError(rep.StatusCode, rep.Status)
}
}

Expand Down
31 changes: 31 additions & 0 deletions pkg/api/v1/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@ import (
"context"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"

"github.com/cenkalti/backoff/v4"
"github.com/stretchr/testify/require"
"github.com/trisacrypto/courier/pkg/api/v1"
)
Expand Down Expand Up @@ -63,3 +66,31 @@ func TestStoreCertificatePassword(t *testing.T) {
err = client.StoreCertificatePassword(context.Background(), req)
require.ErrorIs(t, err, api.ErrIDRequired, "client should error if no ID is provided")
}

func TestRetriesWithBackoff(t *testing.T) {
// Create a test server
var attempts uint32
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddUint32(&attempts, 1)
http.Error(w, http.StatusText(http.StatusTooEarly), http.StatusTooEarly)
}))
defer ts.Close()

// Create a client to test the client method
client, err := api.New(ts.URL, api.WithRetries(10), api.WithBackoff(func() backoff.BackOff {
return backoff.NewConstantBackOff(100 * time.Millisecond)
}))
require.NoError(t, err, "could not create client")

rawClient, ok := client.(*api.APIv1)
require.True(t, ok, "expected client to be an APIv1 client")

req, err := rawClient.NewRequest(context.Background(), http.MethodGet, "/", nil, nil)
require.NoError(t, err, "could not create request")

start := time.Now()
_, err = rawClient.Do(req, nil, true)
require.Error(t, err, "expected an error to be returned")
require.Equal(t, uint32(11), attempts, "expected 10 retry attempts")
require.Greater(t, time.Since(start), 950*time.Millisecond, "expected backoff delay")
}
85 changes: 72 additions & 13 deletions pkg/api/v1/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"net/http"
"time"

"github.com/gin-gonic/gin"
)
Expand All @@ -15,21 +16,9 @@ var (
notAllowed = Reply{Success: false, Error: "method not allowed"}
ErrEndpointRequired = errors.New("endpoint is required")
ErrIDRequired = errors.New("missing ID in request")
ErrInvalidRetries = errors.New("number of retries must be zero or more")
)

func NewStatusError(code int, err string) error {
return &StatusError{Code: code, Err: err}
}

type StatusError struct {
Code int
Err string
}

func (e StatusError) Error() string {
return fmt.Sprintf("[%d]: %s", e.Code, e.Err)
}

// ErrorResponse constructs an new response from the error or returns a success: false.
func ErrorResponse(err interface{}) Reply {
if err == nil {
Expand Down Expand Up @@ -57,6 +46,76 @@ func ErrorResponse(err interface{}) Reply {
return rep
}

func NewStatusError(code int, err string) error {
if err == "" {
err = http.StatusText(code)
}
return &StatusError{Code: code, Err: err}
}

type StatusError struct {
Code int
Err string
}

func (e StatusError) Error() string {
return fmt.Sprintf("[%d]: %s", e.Code, e.Err)
}

// Deduplicates status errors and creates a multi-status error to return. Removes nil
// errors and returns nil if all errs are nil. If only one errors is returned, return
// that error instead of a multierror (e.g. if all responses have the same status code).
func JoinStatusErrors(attempts int, delay time.Duration, errs ...error) error {
err := &MultiStatusError{
Errs: make([]error, 0),
Attempts: attempts,
}

seen := make(map[string]struct{})
for _, e := range errs {
if e == nil {
continue
}

if _, ok := seen[e.Error()]; ok {
continue
}

err.Errs = append(err.Errs, e)
seen[e.Error()] = struct{}{}
}

switch len(err.Errs) {
case 0:
return nil
case 1:
return err.Errs[0]
default:
return err
}
}

type MultiStatusError struct {
Errs []error
Attempts int
Delay time.Duration
}

func (e *MultiStatusError) Error() string {
return fmt.Sprintf("after %d attempts: %s", e.Attempts, e.Last())
}

func (e *MultiStatusError) Last() error {
if len(e.Errs) > 0 {
return e.Errs[len(e.Errs)-1]
}
return nil
}

func (e *MultiStatusError) Unwrap() []error {
return e.Errs
}

// NotFound returns a standard 404 response.
func NotFound(c *gin.Context) {
c.JSON(http.StatusNotFound, notFound)
Expand Down
Loading

0 comments on commit f38ddfb

Please sign in to comment.