Skip to content

Commit

Permalink
Fix the Middleware type (#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
spyzhov authored Nov 11, 2022
1 parent fb88c58 commit ec6e3f2
Show file tree
Hide file tree
Showing 10 changed files with 99 additions and 38 deletions.
45 changes: 21 additions & 24 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,8 @@ logs, and headers transparently into all requests through the cHTTP clients.
Current interface of the middleware is based on 2 types of functions:

```go
// RoundTripper is a RoundTrip function implementation of the http.RoundTripper interface.
type RoundTripper func(request *http.Request) (*http.Response, error)

// Middleware is an extended interface to the RoundTrip function of the http.RoundTripper interface.
type Middleware func(request *http.Request, next RoundTripper) (*http.Response, error)
type Middleware func(request *http.Request, next func(request *http.Request) (*http.Response, error)) (*http.Response, error)
```

Usage example:
Expand All @@ -134,7 +131,7 @@ func main() {
}
client := chttp.NewJSON(nil)
client.With(middleware.JSON(), middleware.Debug(true, nil))
client.With(func(request *http.Request, next chttp.RoundTripper) (*http.Response, error) {
client.With(func(request *http.Request, next func(request *http.Request) (*http.Response, error)) (*http.Response, error) {
fmt.Println("Before the request")
resp, err := next(request)
fmt.Println("After the request")
Expand All @@ -155,15 +152,15 @@ Adds a custom headers based on the request.
**Example:**

```go
chttp.NewClient(nil).
With(middleware.CustomHeaders(func(request *http.Request) map[string]string {
if request.Method == http.MethodPost {
return map[string]string{
"Accept": "*/*",
}
}
return nil
}))
client := chttp.NewClient(nil)
client.With(middleware.CustomHeaders(func(request *http.Request) map[string]string {
if request.Method == http.MethodPost {
return map[string]string{
"Accept": "*/*",
}
}
return nil
}))
```

#### Debug
Expand All @@ -175,8 +172,8 @@ Dumps requests and responses in the logs.
**Example:**

```go
chttp.NewClient(nil).
With(middleware.Debug(true, nil))
client := chttp.NewClient(nil)
client.With(middleware.Debug(true, nil))
```

#### Headers
Expand All @@ -186,10 +183,10 @@ Adds a static headers.
**Example:**

```go
chttp.NewClient(nil).
With(middleware.Headers(map[string]string{
"Accept": "*/*",
}))
client := chttp.NewClient(nil)
client.With(middleware.Headers(map[string]string{
"Accept": "*/*",
}))
```

#### JSON
Expand All @@ -199,8 +196,8 @@ Adds a `Content-Type` and `Accept` headers with the `application/json` value.
**Example:**

```go
chttp.NewClient(nil).
With(middleware.JSON())
client := chttp.NewClient(nil)
client.With(middleware.JSON())
```

#### Trace
Expand All @@ -210,8 +207,8 @@ Adds short logs on each request.
**Example:**

```go
chttp.NewClient(nil).
With(middleware.Trace(nil))
client := chttp.NewClient(nil)
client.With(middleware.Trace(nil))
```

# License
Expand Down
4 changes: 2 additions & 2 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func TestClient_Method(t *testing.T) {
{
name: "error",
middlewares: []Middleware{
func(request *http.Request, next RoundTripper) (*http.Response, error) {
func(request *http.Request, next func(request *http.Request) (*http.Response, error)) (*http.Response, error) {
return nil, fmt.Errorf("test error")
},
},
Expand Down Expand Up @@ -134,7 +134,7 @@ func TestClient_With(t *testing.T) {
var index int
for i := 0; i < 10; i++ {
c.With((func(i int) Middleware {
return func(request *http.Request, next RoundTripper) (*http.Response, error) {
return func(request *http.Request, next func(request *http.Request) (*http.Response, error)) (*http.Response, error) {
if index != i {
t.Errorf("middleware called on wrong position: expected %d, actual %d", i, index)
}
Expand Down
2 changes: 1 addition & 1 deletion json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func TestJSONClient_Method(t *testing.T) {
{
name: "error",
middlewares: []Middleware{
func(request *http.Request, next RoundTripper) (*http.Response, error) {
func(request *http.Request, next func(request *http.Request) (*http.Response, error)) (*http.Response, error) {
return nil, fmt.Errorf("test error")
},
},
Expand Down
67 changes: 67 additions & 0 deletions middleware/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Middlewares

## CustomHeaders

Adds a custom headers based on the request.

**Example:**

```go
client := chttp.NewClient(nil)
client.With(middleware.CustomHeaders(func(request *http.Request) map[string]string {
if request.Method == http.MethodPost {
return map[string]string{
"Accept": "*/*",
}
}
return nil
}))
```

## Debug

**NB!** Don't use it in production!

Dumps requests and responses in the logs.

**Example:**

```go
client := chttp.NewClient(nil)
client.With(middleware.Debug(true, nil))
```

## Headers

Adds a static headers.

**Example:**

```go
client := chttp.NewClient(nil)
client.With(middleware.Headers(map[string]string{
"Accept": "*/*",
}))
```

## JSON

Adds a `Content-Type` and `Accept` headers with the `application/json` value.

**Example:**

```go
client := chttp.NewClient(nil)
client.With(middleware.JSON())
```

## Trace

Adds short logs on each request.

**Example:**

```go
client := chttp.NewClient(nil)
client.With(middleware.Trace(nil))
```
2 changes: 1 addition & 1 deletion middleware/custom_header.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (

// CustomHeaders is a chttp.Middleware constructor to add custom header based on the request to request.
func CustomHeaders(headers func(request *http.Request) map[string]string) chttp.Middleware {
return func(request *http.Request, next chttp.RoundTripper) (*http.Response, error) {
return func(request *http.Request, next func(request *http.Request) (*http.Response, error)) (*http.Response, error) {
for name, value := range headers(request) {
request.Header.Set(name, value)
}
Expand Down
2 changes: 1 addition & 1 deletion middleware/debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (

// Debug is a constructor for Debug, that provides default transport
func Debug(active bool, logger Logger) chttp.Middleware {
return func(request *http.Request, next chttp.RoundTripper) (*http.Response, error) {
return func(request *http.Request, next func(request *http.Request) (*http.Response, error)) (*http.Response, error) {
if !active {
return next(request)
}
Expand Down
2 changes: 1 addition & 1 deletion middleware/header.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (

// Headers is a chttp.Middleware constructor to add static headers to request.
func Headers(headers map[string]string, force bool) chttp.Middleware {
return func(request *http.Request, next chttp.RoundTripper) (*http.Response, error) {
return func(request *http.Request, next func(request *http.Request) (*http.Response, error)) (*http.Response, error) {
for name, value := range headers {
if force || request.Header.Get(name) == "" {
request.Header.Set(name, value)
Expand Down
2 changes: 1 addition & 1 deletion middleware/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (

// Trace middleware adds short logs on each request.
func Trace(logger Logger) chttp.Middleware {
return func(request *http.Request, next chttp.RoundTripper) (response *http.Response, err error) {
return func(request *http.Request, next func(request *http.Request) (*http.Response, error)) (response *http.Response, err error) {
defer func(start time.Time) {
var path string
if request.URL != nil {
Expand Down
6 changes: 3 additions & 3 deletions transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ func (c *Client) transport(Default http.RoundTripper) http.RoundTripper {

type transport struct {
Client *Client
Default RoundTripper
Default func(request *http.Request) (*http.Response, error)
}

func (t *transport) RoundTrip(request *http.Request) (*http.Response, error) {
middlewares := t.Client.getMiddlewares()
var next RoundTripper
var next func(request *http.Request) (*http.Response, error)
next = func(request *http.Request) (*http.Response, error) {
var middleware Middleware
if len(middlewares) == 0 {
Expand All @@ -36,7 +36,7 @@ func (t *transport) RoundTrip(request *http.Request) (*http.Response, error) {
}

func (t *transport) getDefault() Middleware {
return func(request *http.Request, _ RoundTripper) (*http.Response, error) {
return func(request *http.Request, _ func(request *http.Request) (*http.Response, error)) (*http.Response, error) {
return t.Default(request)
}
}
5 changes: 1 addition & 4 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,5 @@ import (
"net/http"
)

// RoundTripper is a RoundTrip function implementation of the http.RoundTripper interface.
type RoundTripper func(request *http.Request) (*http.Response, error)

// Middleware is an extended interface to the RoundTrip function of the http.RoundTripper interface.
type Middleware func(request *http.Request, next RoundTripper) (*http.Response, error)
type Middleware func(request *http.Request, next func(request *http.Request) (*http.Response, error)) (*http.Response, error)

0 comments on commit ec6e3f2

Please sign in to comment.