Skip to content

Commit

Permalink
Backoff and Retry Client
Browse files Browse the repository at this point in the history
  • Loading branch information
bbengfort committed Jan 3, 2024
1 parent 8a8d8b5 commit 0664688
Show file tree
Hide file tree
Showing 18 changed files with 340 additions and 65 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: 1.20.x
go-version: 1.21.x

- name: Install Staticcheck
run: go install honnef.co/go/tools/cmd/[email protected]
Expand Down
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2023 trisacrypto
Copyright (c) 2023 TRISA

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
19 changes: 17 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,17 @@
# courier
A standalone certificate delivery service
# Courier

A stand-alone service that allows the GDS to deliver TRISA certificates via a webhook
rather than email. The service accepts PCKS12 passwords and encrypted certificates from
TRISA as HTTP `POST` requests and stores the certificates and passwords in either
Google Secret Manager or on the local disk (other secret management backends such as
Vault or Postgres may be available in the future).

This tool is mostly used by TRISA Service Providers (TSPs) who have to handle many
TRISA certificate deliveries at a time. VASPs who want to automate certificate delivery
may also use this service.

## Deploying with Docker

The simplest way to run the courier service is to use the docker image
`trisa/courier:latest` and to configure it from the environment. This allows the
courier service to be easily run on a Kubernetes cluster.
2 changes: 1 addition & 1 deletion cmd/courier/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ func storeCertificate(c *cli.Context) (err error) {

// Get a secret from the secret manager.
func getSecret(c *cli.Context) (err error) {
conf := config.SecretsConfig{
conf := config.GCPSecretsConfig{
Enabled: true,
Project: c.String("project"),
Credentials: c.String("credentials"),
Expand Down
2 changes: 1 addition & 1 deletion containers/courier/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Dynamic Builds
ARG BUILDER_IMAGE=golang:1.20-buster
ARG BUILDER_IMAGE=golang:1.21-buster
ARG FINAL_IMAGE=debian:buster-slim

# Build Stage
Expand Down
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
module github.com/trisacrypto/courier

go 1.20
go 1.21

require (
cloud.google.com/go/secretmanager v1.11.2
github.com/cenkalti/backoff/v4 v4.2.1
github.com/gin-gonic/gin v1.9.1
github.com/googleapis/gax-go v1.0.3
github.com/joho/godotenv v1.5.1
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.110.6 h1:8uYAkj3YHTP/1iwReuHPxLSbdcyc+dSBbzFMrVwDR6Q=
cloud.google.com/go v0.110.6/go.mod h1:+EYjdK8e5RME/VY/qLCAtuyALQ9q67dvuum8i+H5xsI=
cloud.google.com/go/compute v1.23.0 h1:tP41Zoavr8ptEqaW6j+LQOnyBBhO7OkOMAGrgLopTwY=
cloud.google.com/go/compute v1.23.0/go.mod h1:4tCnrn48xsqlwSAiLf1HXMQk8CONslYbdiEZc9FEIbM=
cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY=
Expand All @@ -16,6 +17,8 @@ github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kd
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM=
github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
Expand Down Expand Up @@ -48,6 +51,7 @@ github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
Expand Down
74 changes: 67 additions & 7 deletions pkg/api/v1/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,31 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"time"

"github.com/cenkalti/backoff/v4"
)

const DefaultRetries = 3

func DefaultBackoff() BackoffFactory {
return func() backoff.BackOff {
return backoff.NewExponentialBackOff()
}
}

// New creates a new API client that implements the CourierClient interface.
func New(endpoint string, opts ...ClientOption) (_ CourierClient, err error) {
if endpoint == "" {
return nil, ErrEndpointRequired
}

// Create a client with the parsed endpoint.
c := &APIv1{}
c := &APIv1{retries: -1}
if c.url, err = url.Parse(endpoint); err != nil {
return nil, err
}
Expand All @@ -39,13 +48,26 @@ func New(endpoint string, opts ...ClientOption) (_ CourierClient, err error) {
Timeout: 30 * time.Second,
}
}

// If backoff hasn't been specified add the default backoff factory
if c.backoff == nil {
c.backoff = DefaultBackoff()
}

// If retries haven't been specified add the default number of retries
if c.retries < 0 {
c.retries = DefaultRetries
}

return c, nil
}

// APIv1 implements the CourierClient interface.
type APIv1 struct {
url *url.URL
client *http.Client
url *url.URL
client *http.Client
backoff BackoffFactory
retries int
}

var _ CourierClient = &APIv1{}
Expand Down Expand Up @@ -172,8 +194,47 @@ func (c *APIv1) NewRequest(ctx context.Context, method, path string, data interf
}

// Do executes an http request against the server, performs error checking, and
// deserializes response data into the specified struct.
// deserializes response data into the specified struct. This function also manages
// retries using a backoff strategy.
func (s *APIv1) Do(req *http.Request, data interface{}, checkStatus bool) (rep *http.Response, err error) {
attempts := 0
start := time.Now()
ctx := req.Context()
delay := s.backoff()
errs := make([]error, 0, s.retries+1)

for attempts <= s.retries {
attempts++
if rep, err = s.do(req, data, checkStatus); err == nil {
// Success!
return rep, nil
}

// Failure! Retry as needed.
errs = append(errs, err)

// Compute the backoff delay before the next request
dur := delay.NextBackOff()
if dur == backoff.Stop {
// Stop indicates no more retries should be allowed.
return rep, JoinStatusErrors(attempts, time.Since(start), errs...)
}

// Wait for backoff delay or until context is canceled
wait := time.After(dur)
select {
case <-ctx.Done():
errs = append(errs, ctx.Err())
return rep, JoinStatusErrors(attempts, time.Since(start), errs...)
case <-wait:
continue
}
}

return rep, JoinStatusErrors(attempts, time.Since(start), errs...)
}

func (s *APIv1) do(req *http.Request, data interface{}, checkStatus bool) (rep *http.Response, err error) {
if rep, err = s.client.Do(req); err != nil {
return rep, err
}
Expand All @@ -189,8 +250,7 @@ func (s *APIv1) Do(req *http.Request, data interface{}, checkStatus bool) (rep *
return rep, NewStatusError(rep.StatusCode, reply.Error)
}
}

return rep, errors.New(rep.Status)
return rep, NewStatusError(rep.StatusCode, rep.Status)
}
}

Expand Down
82 changes: 69 additions & 13 deletions pkg/api/v1/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"net/http"
"time"

"github.com/gin-gonic/gin"
)
Expand All @@ -15,21 +16,9 @@ var (
notAllowed = Reply{Success: false, Error: "method not allowed"}
ErrEndpointRequired = errors.New("endpoint is required")
ErrIDRequired = errors.New("missing ID in request")
ErrInvalidRetries = errors.New("number of retries must be zero or more")
)

func NewStatusError(code int, err string) error {
return &StatusError{Code: code, Err: err}
}

type StatusError struct {
Code int
Err string
}

func (e StatusError) Error() string {
return fmt.Sprintf("[%d]: %s", e.Code, e.Err)
}

// ErrorResponse constructs an new response from the error or returns a success: false.
func ErrorResponse(err interface{}) Reply {
if err == nil {
Expand Down Expand Up @@ -57,6 +46,73 @@ func ErrorResponse(err interface{}) Reply {
return rep
}

func NewStatusError(code int, err string) error {
return &StatusError{Code: code, Err: err}
}

type StatusError struct {
Code int
Err string
}

func (e StatusError) Error() string {
return fmt.Sprintf("[%d]: %s", e.Code, e.Err)
}

// Deduplicates status errors and creates a multi-status error to return. Removes nil
// errors and returns nil if all errs are nil. If only one errors is returned, return
// that error instead of a multierror (e.g. if all responses have the same status code).
func JoinStatusErrors(attempts int, delay time.Duration, errs ...error) error {
err := &MultiStatusError{
Errs: make([]error, 0),
Attempts: attempts,
}

seen := make(map[string]struct{})
for _, e := range errs {
if e == nil {
continue
}

if _, ok := seen[e.Error()]; ok {
continue
}

err.Errs = append(err.Errs, e)
seen[e.Error()] = struct{}{}
}

switch len(err.Errs) {
case 0:
return nil
case 1:
return err.Errs[0]
default:
return err
}
}

type MultiStatusError struct {
Errs []error
Attempts int
Delay time.Duration
}

func (e *MultiStatusError) Error() string {
return fmt.Sprintf("after %d attempts: %s", e.Attempts, e.Last())
}

func (e *MultiStatusError) Last() error {
if len(e.Errs) > 0 {
return e.Errs[len(e.Errs)-1]
}
return nil
}

func (e *MultiStatusError) Unwrap() []error {
return e.Errs
}

// NotFound returns a standard 404 response.
func NotFound(c *gin.Context) {
c.JSON(http.StatusNotFound, notFound)
Expand Down
74 changes: 74 additions & 0 deletions pkg/api/v1/errors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package api_test

import (
"errors"
"net/http"
"testing"
"time"

"github.com/stretchr/testify/require"
"github.com/trisacrypto/courier/pkg/api/v1"
)

func TestJoinStatusErrors(t *testing.T) {
t.Run("Empty", func(t *testing.T) {
err := api.JoinStatusErrors(0, 0, nil)
require.NoError(t, err, "expected a nil error returned")

err = api.JoinStatusErrors(0, 0, nil, nil, nil, nil, nil, nil)
require.NoError(t, err, "expected a nil error returned for multiple nil errors")
})

t.Run("SingleStatusError", func(t *testing.T) {
err := api.JoinStatusErrors(1, 421*time.Millisecond, api.NewStatusError(http.StatusServiceUnavailable, "could not reach specified service"))
require.Error(t, err, "expected error to be returned")

serr, ok := err.(*api.StatusError)
require.True(t, ok, "expected error to be a status error, not a multi status error")
require.Equal(t, 503, serr.Code)
})

t.Run("SingleError", func(t *testing.T) {
err := api.JoinStatusErrors(1, 421*time.Millisecond, errors.New("something went wrong"))
require.Error(t, err, "expected error to be returned")

_, ok := err.(*api.StatusError)
require.False(t, ok, "expected error to not be a status error")
require.EqualError(t, err, "something went wrong")
})

t.Run("MultiStatusErrors", func(t *testing.T) {})

t.Run("MultiErrors", func(t *testing.T) {})

t.Run("Mixed", func(t *testing.T) {})

t.Run("Deduplication", func(t *testing.T) {})

t.Run("MultiDeduplication", func(t *testing.T) {})
}

func TestMultiStatusError(t *testing.T) {
testCases := []struct {
err *api.MultiStatusError
expected string
}{
{
&api.MultiStatusError{
Attempts: 1,
Delay: 585 * time.Millisecond,
Errs: []error{
&api.StatusError{
Code: http.StatusInternalServerError,
Err: http.StatusText(http.StatusInternalServerError),
},
},
},
"after 1 attempts: [500]: Internal Server Error",
},
}

for i, tc := range testCases {
require.EqualError(t, tc.err, tc.expected, "test case %d failed", i)
}
}
Loading

0 comments on commit 0664688

Please sign in to comment.