From fd8c792f55cd4e62179b434434ae4779ee3ac460 Mon Sep 17 00:00:00 2001 From: fanjindong <765912710@qq.com> Date: Fri, 28 Oct 2022 11:00:05 +0800 Subject: [PATCH] feat: support gzip --- option.go | 8 ++++++++ option_test.go | 32 ++++++++++++++++++++++++++++++++ request.go | 40 ++++++++++++++++++++++++++++++++++++++-- response.go | 12 ++++++++++++ server_test.go | 9 +++++++++ 5 files changed, 99 insertions(+), 2 deletions(-) diff --git a/option.go b/option.go index 6837060..051362a 100644 --- a/option.go +++ b/option.go @@ -161,6 +161,14 @@ func (t Timeout) Do(req *Request) error { return nil } ctx, _ := context.WithTimeout(req.Context(), time.Duration(t)) + // todo call cancel req.Request = req.WithContext(ctx) return nil } + +type Gzip struct{} + +func (Gzip) Do(req *Request) error { + req.gzip = true + return nil +} diff --git a/option_test.go b/option_test.go index 23845c7..d9c03ef 100644 --- a/option_test.go +++ b/option_test.go @@ -228,3 +228,35 @@ func TestFile(t *testing.T) { }) } } + +func TestGzip(t *testing.T) { + url := testUrl + "/post" + type args struct { + opts []ReqOption + } + tests := []struct { + name string + args args + want string + }{ + {args: args{[]ReqOption{Gzip{}}}, want: ""}, + {args: args{opts: []ReqOption{Gzip{}, Json{"a": "1"}}}, want: `{"a":"1"}`}, + {args: args{opts: []ReqOption{Json{"a": "1", "b": 2}, Gzip{}}}, want: `{"a":"1","b":2}`}, + {args: args{opts: []ReqOption{Json{"a": "1"}, Json{"b": "2"}, Gzip{}}}, want: `{"a":"1","b":"2"}`}, + {args: args{opts: []ReqOption{Gzip{}, Form{"a": "1"}}}, want: `{"a":"1"}`}, + {args: args{opts: []ReqOption{Gzip{}, Jsons{{"a": "1", "b": 2}}}}, want: `[{"a":"1","b":2}]`}, + {args: args{opts: []ReqOption{Gzip{}, Jsons{{"a": "1", "b": 2}, {"c": 0}}}}, want: `[{"a":"1","b":2},{"c":0}]`}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resp, err := Post(url, tt.args.opts...) + if err != nil { + t.Errorf("Json() got err = %v", err) + } + got := resp.Text() + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Json() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/request.go b/request.go index 11cd0a7..afb1aa9 100644 --- a/request.go +++ b/request.go @@ -2,6 +2,7 @@ package requests import ( "bytes" + "compress/gzip" "context" "fmt" "github.com/ajg/form" @@ -19,6 +20,7 @@ type Request struct { form Form json Json jsons Jsons + gzip bool } // NewRequest wraps NewRequestWithContext using the background context. @@ -48,8 +50,13 @@ func (req *Request) loadBody() error { if err != nil { return errors.Wrap(ErrInvalidJson, err.Error()) } - jsonBuffer := bytes.NewBuffer(jsonBytes) - req.Body = ioutil.NopCloser(jsonBuffer) + if req.gzip { + if jsonBytes, err = req.compressed(jsonBytes); err != nil { + return err + } + } + buf := bytes.NewBuffer(jsonBytes) + req.Body = ioutil.NopCloser(buf) return nil } // application/x-www-form-urlencoded @@ -59,6 +66,13 @@ func (req *Request) loadBody() error { if err != nil { return errors.Wrap(ErrInvalidForm, err.Error()) } + if req.gzip { + if compressedData, err := req.compressed([]byte(data)); err != nil { + return err + } else { + data = string(compressedData) + } + } dataReader := strings.NewReader(data) req.Body = ioutil.NopCloser(dataReader) return nil @@ -86,7 +100,29 @@ func (req *Request) loadBody() error { if err := multipartWriter.Close(); err != nil { return err } + if req.gzip { + if compressedData, err := req.compressed(buffer.Bytes()); err != nil { + return err + } else { + buffer = bytes.NewBuffer(compressedData) + } + } req.Body = ioutil.NopCloser(buffer) req.Header.Add("content-Type", multipartWriter.FormDataContentType()) return nil } + +func (req *Request) compressed(data []byte) ([]byte, error) { + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + if _, err := gw.Write(data); err != nil { + return nil, err + } + if err := gw.Close(); err != nil { + return nil, err + } + + req.Header.Set("Content-Encoding", "gzip") + req.Header.Set("Accept-Encoding", "gzip") + return buf.Bytes(), nil +} diff --git a/response.go b/response.go index 7370012..7a5521c 100644 --- a/response.go +++ b/response.go @@ -1,6 +1,8 @@ package requests import ( + "compress/gzip" + "io" "io/ioutil" "net/http" "os" @@ -25,6 +27,12 @@ func (r *Response) Text() string { func (r *Response) Bytes() ([]byte, error) { if r.bytes == nil { + var err error + if r.Header.Get("Content-Encoding") == "gzip" { + if r.Body, err = r.decompressed(r.Body); err != nil { + return nil, err + } + } data, err := ioutil.ReadAll(r.Body) if err != nil { return nil, err @@ -56,3 +64,7 @@ func (r Response) SaveFile(filename string) error { return err } } + +func (r *Response) decompressed(reader io.Reader) (io.ReadCloser, error) { + return gzip.NewReader(reader) +} diff --git a/server_test.go b/server_test.go index 73bbc20..0a0cd2d 100644 --- a/server_test.go +++ b/server_test.go @@ -1,6 +1,7 @@ package requests import ( + "compress/gzip" "fmt" "io/ioutil" "net/http" @@ -27,6 +28,14 @@ func getHandler(w http.ResponseWriter, r *http.Request) { func postHandler(w http.ResponseWriter, r *http.Request) { contentType := r.Header.Get("content-Type") + if r.Header.Get("Content-Encoding") == "gzip" { + var err error + if r.Body, err = gzip.NewReader(r.Body); err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("gzip body: " + err.Error())) + return + } + } switch contentType { case "application/x-www-form-urlencoded": if err := r.ParseForm(); err != nil {