Skip to content

Commit

Permalink
feat: hook, timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
樊金东 committed Sep 14, 2021
1 parent 23e68c2 commit 16d1ddb
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 7 deletions.
40 changes: 33 additions & 7 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,13 @@ func Head(url string, opts ...ReqOption) (*Response, error) {

type Client struct {
*http.Client
hooks []Hook
}

var DefaultClient = &Client{http.DefaultClient}
var DefaultClient = &Client{Client: http.DefaultClient}

func NewClient(opts ...ClientOption) *Client {
c := &Client{&http.Client{}}
c := &Client{Client: &http.Client{}}
for _, opt := range opts {
opt(c)
}
Expand Down Expand Up @@ -87,12 +88,33 @@ func (s *Client) Request(method, url string, opts ...ReqOption) (*Response, erro
return nil, err
}

resp, err := s.Do(req.Request)
if err != nil {
return nil, err
for _, h := range s.hooks {
h.BeforeProcess(req)
}

return NewResponse(resp)
var result *http.Response
var resp *Response
success := make(chan struct{})
done := req.Context().Done()
if done != nil {
go func() {
result, err = s.Do(req.Request)
close(success)
}()
select {
case <-done:
err = ErrTimeout
case <-success:
}
} else {
result, err = s.Do(req.Request)
}
if err == nil {
resp, err = NewResponse(result)
}
for _, h := range s.hooks {
h.AfterProcess(req, resp, err)
}
return resp, err
}

func (s *Client) Get(url string, opts ...ReqOption) (*Response, error) {
Expand Down Expand Up @@ -123,6 +145,10 @@ func (s *Client) Head(url string, opts ...ReqOption) (*Response, error) {
return s.Request(HEAD, url, opts...)
}

func (s *Client) AddHook(h Hook) {
s.hooks = append(s.hooks, h)
}

var unmarshal = json.Unmarshal
var marshal = json.Marshal

Expand Down
2 changes: 2 additions & 0 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,6 @@ var (
ErrInvalidFile = errors.New("go-requests: Invalid file content")

ErrInvalidBodyType = errors.New("go-requests: Invalid Body Type")

ErrTimeout = errors.New("go-requests: timeout")
)
8 changes: 8 additions & 0 deletions hook.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package requests

type Hook interface {
// BeforeProcess Before the HTTP request is executed
BeforeProcess(req *Request)
// AfterProcess After the HTTP request is executed
AfterProcess(req *Request, resp *Response, err error)
}
23 changes: 23 additions & 0 deletions hook_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package requests

import (
"fmt"
"testing"
)

type mockHook struct {
}

func (m mockHook) BeforeProcess(req *Request) {
fmt.Println("before process")
}

func (m mockHook) AfterProcess(req *Request, resp *Response, err error) {
fmt.Println("after process")
}

func TestHook(t *testing.T) {
client := NewClient()
client.AddHook(mockHook{})
t.Log(client.Get(testUrl))
}
21 changes: 21 additions & 0 deletions option.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package requests

import (
"bytes"
"context"
"fmt"
"github.com/pkg/errors"
"io"
Expand Down Expand Up @@ -143,3 +144,23 @@ func (f file) Do(req *Request) error {
req.files = append(req.files, &f)
return nil
}

type Ctx struct {
context.Context
}

func (c Ctx) Do(req *Request) error {
req.Request = req.WithContext(c)
return nil
}

type Timeout time.Duration

func (t Timeout) Do(req *Request) error {
if time.Duration(t) == 0 {
return nil
}
ctx, _ := context.WithTimeout(req.Context(), time.Duration(t))
req.Request = req.WithContext(ctx)
return nil
}
25 changes: 25 additions & 0 deletions option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,31 @@ func TestWithTimeout(t *testing.T) {
}
}

func TestTimeout(t *testing.T) {
url := testUrl + "/timeout"
type args struct {
timeout time.Duration
}
tests := []struct {
name string
args args
wantError bool
}{
{args: args{}},
{args: args{timeout: 2 * time.Second}},
{args: args{timeout: 1100 * time.Millisecond}},
{args: args{timeout: 1000 * time.Millisecond}, wantError: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := Get(url, Timeout(tt.args.timeout))
if !reflect.DeepEqual(err != nil, tt.wantError) {
t.Errorf("WithTimeout() err = %v, wantError %v", err, tt.wantError)
}
})
}
}

func TestHeaders(t *testing.T) {
url := testUrl + "/header"
type args struct {
Expand Down

0 comments on commit 16d1ddb

Please sign in to comment.