diff --git a/client.go b/client.go index 595ee93..6c4ca34 100644 --- a/client.go +++ b/client.go @@ -52,12 +52,13 @@ func Head(url string, opts ...ReqOption) (*Response, error) { type Client struct { *http.Client + hooks []Hook } -var DefaultClient = &Client{http.DefaultClient} +var DefaultClient = &Client{Client: http.DefaultClient} func NewClient(opts ...ClientOption) *Client { - c := &Client{&http.Client{}} + c := &Client{Client: &http.Client{}} for _, opt := range opts { opt(c) } @@ -87,12 +88,33 @@ func (s *Client) Request(method, url string, opts ...ReqOption) (*Response, erro return nil, err } - resp, err := s.Do(req.Request) - if err != nil { - return nil, err + for _, h := range s.hooks { + h.BeforeProcess(req) } - - return NewResponse(resp) + var result *http.Response + var resp *Response + success := make(chan struct{}) + done := req.Context().Done() + if done != nil { + go func() { + result, err = s.Do(req.Request) + close(success) + }() + select { + case <-done: + err = ErrTimeout + case <-success: + } + } else { + result, err = s.Do(req.Request) + } + if err == nil { + resp, err = NewResponse(result) + } + for _, h := range s.hooks { + h.AfterProcess(req, resp, err) + } + return resp, err } func (s *Client) Get(url string, opts ...ReqOption) (*Response, error) { @@ -123,6 +145,10 @@ func (s *Client) Head(url string, opts ...ReqOption) (*Response, error) { return s.Request(HEAD, url, opts...) } +func (s *Client) AddHook(h Hook) { + s.hooks = append(s.hooks, h) +} + var unmarshal = json.Unmarshal var marshal = json.Marshal diff --git a/error.go b/error.go index 3a419e5..314dced 100644 --- a/error.go +++ b/error.go @@ -21,4 +21,6 @@ var ( ErrInvalidFile = errors.New("go-requests: Invalid file content") ErrInvalidBodyType = errors.New("go-requests: Invalid Body Type") + + ErrTimeout = errors.New("go-requests: timeout") ) diff --git a/hook.go b/hook.go new file mode 100644 index 0000000..8548c10 --- /dev/null +++ b/hook.go @@ -0,0 +1,8 @@ +package requests + +type Hook interface { + // BeforeProcess Before the HTTP request is executed + BeforeProcess(req *Request) + // AfterProcess After the HTTP request is executed + AfterProcess(req *Request, resp *Response, err error) +} diff --git a/hook_test.go b/hook_test.go new file mode 100644 index 0000000..6cc5a8b --- /dev/null +++ b/hook_test.go @@ -0,0 +1,23 @@ +package requests + +import ( + "fmt" + "testing" +) + +type mockHook struct { +} + +func (m mockHook) BeforeProcess(req *Request) { + fmt.Println("before process") +} + +func (m mockHook) AfterProcess(req *Request, resp *Response, err error) { + fmt.Println("after process") +} + +func TestHook(t *testing.T) { + client := NewClient() + client.AddHook(mockHook{}) + t.Log(client.Get(testUrl)) +} diff --git a/option.go b/option.go index 01cc2b2..6837060 100644 --- a/option.go +++ b/option.go @@ -2,6 +2,7 @@ package requests import ( "bytes" + "context" "fmt" "github.com/pkg/errors" "io" @@ -143,3 +144,23 @@ func (f file) Do(req *Request) error { req.files = append(req.files, &f) return nil } + +type Ctx struct { + context.Context +} + +func (c Ctx) Do(req *Request) error { + req.Request = req.WithContext(c) + return nil +} + +type Timeout time.Duration + +func (t Timeout) Do(req *Request) error { + if time.Duration(t) == 0 { + return nil + } + ctx, _ := context.WithTimeout(req.Context(), time.Duration(t)) + req.Request = req.WithContext(ctx) + return nil +} diff --git a/option_test.go b/option_test.go index da72e57..23845c7 100644 --- a/option_test.go +++ b/option_test.go @@ -34,6 +34,31 @@ func TestWithTimeout(t *testing.T) { } } +func TestTimeout(t *testing.T) { + url := testUrl + "/timeout" + type args struct { + timeout time.Duration + } + tests := []struct { + name string + args args + wantError bool + }{ + {args: args{}}, + {args: args{timeout: 2 * time.Second}}, + {args: args{timeout: 1100 * time.Millisecond}}, + {args: args{timeout: 1000 * time.Millisecond}, wantError: true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := Get(url, Timeout(tt.args.timeout)) + if !reflect.DeepEqual(err != nil, tt.wantError) { + t.Errorf("WithTimeout() err = %v, wantError %v", err, tt.wantError) + } + }) + } +} + func TestHeaders(t *testing.T) { url := testUrl + "/header" type args struct {