Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add multipart live upload progress callback #287, #379 #880

Merged
merged 1 commit into from
Oct 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,8 @@ func createMultipart(w *multipart.Writer, r *Request) error {
return err
}

partWriter = mf.wrapProgressCallbackIfPresent(partWriter)

if _, err = partWriter.Write(p[:size]); err != nil {
return err
}
Expand Down
133 changes: 112 additions & 21 deletions multipart.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,52 +21,143 @@
// MultipartField struct represents the multipart field to compose
// all [io.Reader] capable input for multipart form request
type MultipartField struct {
io.Reader
Name string
FileName string
// Name of the multipart field name that the server expects it
Name string

// FileName is used to set the file name we have to send to the server
FileName string

// ContentType is a multipart file content-type value. It is highly
// recommended setting it if you know the content-type so that Resty
// don't have to do additional computing to auto-detect (Optional)
ContentType string

filePath string
// Reader is an input of [io.Reader] for multipart upload. It
// is optional if you set the FilePath value
Reader io.Reader

// FilePath is a file path for multipart upload. It
// is optional if you set the Reader value
FilePath string

// FileSize in bytes is used just for the information purpose of
// sharing via [MultipartFieldCallbackFunc] (Optional)
FileSize int64

// ProgressCallback function is used to provide live progress details
// during a multipart upload (Optional)
//
// NOTE: It is recommended to set the FileSize value when using
// ProgressCallback feature so that Resty sends the FileSize
// value via [MultipartFieldProgress]
ProgressCallback MultipartFieldCallbackFunc
}

// Clone method returns the deep copy of m except [io.Reader].
func (m *MultipartField) Clone() *MultipartField {
mm := new(MultipartField)
*mm = *m
return mm
func (mf *MultipartField) Clone() *MultipartField {
mf2 := new(MultipartField)
*mf2 = *mf
return mf2
}

func (m *MultipartField) resetReader() error {
if rs, ok := m.Reader.(io.ReadSeeker); ok {
func (mf *MultipartField) resetReader() error {
if rs, ok := mf.Reader.(io.ReadSeeker); ok {
_, err := rs.Seek(0, io.SeekStart)
return err
}
return nil
}

func (m *MultipartField) close() {
closeq(m.Reader)
func (mf *MultipartField) close() {
closeq(mf.Reader)
}

func (m *MultipartField) createHeader() textproto.MIMEHeader {
func (mf *MultipartField) createHeader() textproto.MIMEHeader {
h := make(textproto.MIMEHeader)
if isStringEmpty(m.FileName) {
if isStringEmpty(mf.FileName) {
h.Set(hdrContentDisposition,
fmt.Sprintf(`form-data; name="%s"`, escapeQuotes(m.Name)))
fmt.Sprintf(`form-data; name="%s"`, escapeQuotes(mf.Name)))
} else {
h.Set(hdrContentDisposition,
fmt.Sprintf(`form-data; name="%s"; filename="%s"`,
escapeQuotes(m.Name), escapeQuotes(m.FileName)))
escapeQuotes(mf.Name), escapeQuotes(mf.FileName)))
}
if !isStringEmpty(m.ContentType) {
h.Set(hdrContentTypeKey, m.ContentType)
if !isStringEmpty(mf.ContentType) {
h.Set(hdrContentTypeKey, mf.ContentType)
}
return h
}

func (m *MultipartField) openFileIfRequired() (err error) {
if m.Reader == nil && !isStringEmpty(m.filePath) {
m.Reader, err = os.Open(m.filePath)
func (mf *MultipartField) openFileIfRequired() error {
if isStringEmpty(mf.FilePath) || mf.Reader != nil {
return nil
}

file, err := os.Open(mf.FilePath)
if err != nil {
return err
}

fileStat, err := file.Stat()
if err != nil {
return err

Check warning on line 103 in multipart.go

View check run for this annotation

Codecov / codecov/patch

multipart.go#L103

Added line #L103 was not covered by tests
}

mf.Reader = file
mf.FileSize = fileStat.Size()

return nil
}

func (mf *MultipartField) wrapProgressCallbackIfPresent(pw io.Writer) io.Writer {
if mf.ProgressCallback == nil {
return pw
}

return &multipartProgressWriter{
w: pw,
f: func(pb int64) {
mf.ProgressCallback(MultipartFieldProgress{
Name: mf.Name,
FileName: mf.FileName,
FileSize: mf.FileSize,
Written: pb,
})
},
}
}

// MultipartFieldCallbackFunc function used to transmit live multipart upload
// progress in bytes count
type MultipartFieldCallbackFunc func(MultipartFieldProgress)

// MultipartFieldProgress struct used to provide multipart field upload progress
// details via callback function
type MultipartFieldProgress struct {
Name string
FileName string
FileSize int64
Written int64
}

// String method creates the string representation of [MultipartFieldProgress]
func (mfp MultipartFieldProgress) String() string {
return fmt.Sprintf("FieldName: %s, FileName: %s, FileSize: %v, Written: %v",
mfp.Name, mfp.FileName, mfp.FileSize, mfp.Written)
}

type multipartProgressWriter struct {
w io.Writer
pb int64
f func(int64)
}

func (mpw *multipartProgressWriter) Write(p []byte) (n int, err error) {
n, err = mpw.w.Write(p)
if n <= 0 {
return
}
mpw.pb += int64(n)
mpw.f(mpw.pb)
return
}
88 changes: 83 additions & 5 deletions multipart_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,10 +358,77 @@ func TestMultipartLargeFile(t *testing.T) {
})
}

func TestMultipartFieldProgressCallback(t *testing.T) {
ts := createFormPostServer(t)
defer ts.Close()
defer cleanupFiles(".testdata/upload")

file1, _ := os.Open(filepath.Join(getTestDataPath(), "test-img.png"))
file1Stat, _ := file1.Stat()

fileName2 := "50mbfile.bin"
filePath2 := createBinFile(fileName2, 50<<20)
defer cleanupFiles(filePath2)
file2, _ := os.Open(filePath2)
file2Stat, _ := file2.Stat()

fileName3 := "100mbfile.bin"
filePath3 := createBinFile(fileName3, 100<<20)
defer cleanupFiles(filePath3)
file3, _ := os.Open(filePath3)
file3Stat, _ := file3.Stat()

progressCallback := func(mp MultipartFieldProgress) {
t.Logf("%s\n", mp)
}

fields := []*MultipartField{
{
Name: "test-image-1",
FileName: "test-image-1.png",
ContentType: "image/png",
Reader: file1,
FileSize: file1Stat.Size(),
ProgressCallback: progressCallback,
},
{
Name: "50mbfile",
FileName: fileName2,
Reader: file2,
FileSize: file2Stat.Size(),
ProgressCallback: progressCallback,
},
{
Name: "100mbfile",
FileName: fileName3,
Reader: file3,
FileSize: file3Stat.Size(),
ProgressCallback: progressCallback,
},
}

c := dcnld()

r := c.R().
SetFormData(map[string]string{"first_name": "Jeevanandam", "last_name": "M"}).
SetMultipartFields(fields...)
resp, err := r.Post(ts.URL + "/upload")

responseStr := resp.String()

assertError(t, err)
assertEqual(t, http.StatusOK, resp.StatusCode())
assertEqual(t, true, strings.Contains(responseStr, "test-image-1.png"))
assertEqual(t, true, strings.Contains(responseStr, "50mbfile.bin"))
assertEqual(t, true, strings.Contains(responseStr, "100mbfile.bin"))
}

var errTestErrorReader = errors.New("fake")

type errorReader struct{}

func (errorReader) Read(p []byte) (n int, err error) {
return 0, errors.New("fake")
return 0, errTestErrorReader
}

func TestMultipartReaderErrors(t *testing.T) {
Expand All @@ -370,8 +437,6 @@ func TestMultipartReaderErrors(t *testing.T) {

c := dcnl().SetBaseURL(ts.URL)

testErr := errors.New("fake")

t.Run("multipart fields with errorReader", func(t *testing.T) {
resp, err := c.R().
SetMultipartFields(&MultipartField{
Expand All @@ -382,7 +447,7 @@ func TestMultipartReaderErrors(t *testing.T) {
Post("/upload")

assertNotNil(t, err)
assertEqual(t, testErr, err)
assertEqual(t, errTestErrorReader, err)
assertNotNil(t, resp)
assertEqual(t, nil, resp.Body)
})
Expand All @@ -393,7 +458,7 @@ func TestMultipartReaderErrors(t *testing.T) {
Post("/upload")

assertNotNil(t, err)
assertEqual(t, testErr, err)
assertEqual(t, errTestErrorReader, err)
assertNotNil(t, resp)
assertEqual(t, nil, resp.Body)
})
Expand All @@ -410,11 +475,24 @@ func TestMultipartReaderErrors(t *testing.T) {
})
}

type returnValueTestWriter struct {
}

func (z *returnValueTestWriter) Write(p []byte) (n int, err error) {
return 0, nil
}

func TestMultipartCornerCoverage(t *testing.T) {
mf := &MultipartField{
Name: "foo",
Reader: bytes.NewBufferString("I have no seek capability"),
}
err := mf.resetReader()
assertNil(t, err)

// wrap test writer to return 0 written value
mpw := multipartProgressWriter{w: &returnValueTestWriter{}}
n, err := mpw.Write([]byte("test return value"))
assertNil(t, err)
assertEqual(t, 0, n)
}
Loading