Skip to content

Commit

Permalink
feat: add multipart live upload progress callback #287, #379
Browse files Browse the repository at this point in the history
  • Loading branch information
jeevatkm committed Oct 5, 2024
1 parent 0d0bc1b commit acd1953
Show file tree
Hide file tree
Showing 4 changed files with 235 additions and 31 deletions.
2 changes: 2 additions & 0 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,8 @@ func createMultipart(w *multipart.Writer, r *Request) error {
return err
}

partWriter = mf.wrapProgressCallbackIfPresent(partWriter)

if _, err = partWriter.Write(p[:size]); err != nil {
return err
}
Expand Down
133 changes: 112 additions & 21 deletions multipart.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,52 +21,143 @@ func escapeQuotes(s string) string {
// MultipartField struct represents the multipart field to compose
// all [io.Reader] capable input for multipart form request
type MultipartField struct {
io.Reader
Name string
FileName string
// Name of the multipart field name that the server expects it
Name string

// FileName is used to set the file name we have to send to the server
FileName string

// ContentType is a multipart file content-type value. It is highly
// recommended setting it if you know the content-type so that Resty
// don't have to do additional computing to auto-detect (Optional)
ContentType string

filePath string
// Reader is an input of [io.Reader] for multipart upload. It
// is optional if you set the FilePath value
Reader io.Reader

// FilePath is a file path for multipart upload. It
// is optional if you set the Reader value
FilePath string

// FileSize in bytes is used just for the information purpose of
// sharing via [MultipartFieldCallbackFunc] (Optional)
FileSize int64

// ProgressCallback function is used to provide live progress details
// during a multipart upload (Optional)
//
// NOTE: It is recommended to set the FileSize value when using
// ProgressCallback feature so that Resty sends the FileSize
// value via [MultipartFieldProgress]
ProgressCallback MultipartFieldCallbackFunc
}

// Clone method returns the deep copy of m except [io.Reader].
func (m *MultipartField) Clone() *MultipartField {
mm := new(MultipartField)
*mm = *m
return mm
func (mf *MultipartField) Clone() *MultipartField {
mf2 := new(MultipartField)
*mf2 = *mf
return mf2
}

func (m *MultipartField) resetReader() error {
if rs, ok := m.Reader.(io.ReadSeeker); ok {
func (mf *MultipartField) resetReader() error {
if rs, ok := mf.Reader.(io.ReadSeeker); ok {
_, err := rs.Seek(0, io.SeekStart)
return err
}
return nil
}

func (m *MultipartField) close() {
closeq(m.Reader)
func (mf *MultipartField) close() {
closeq(mf.Reader)
}

func (m *MultipartField) createHeader() textproto.MIMEHeader {
func (mf *MultipartField) createHeader() textproto.MIMEHeader {
h := make(textproto.MIMEHeader)
if isStringEmpty(m.FileName) {
if isStringEmpty(mf.FileName) {
h.Set(hdrContentDisposition,
fmt.Sprintf(`form-data; name="%s"`, escapeQuotes(m.Name)))
fmt.Sprintf(`form-data; name="%s"`, escapeQuotes(mf.Name)))
} else {
h.Set(hdrContentDisposition,
fmt.Sprintf(`form-data; name="%s"; filename="%s"`,
escapeQuotes(m.Name), escapeQuotes(m.FileName)))
escapeQuotes(mf.Name), escapeQuotes(mf.FileName)))
}
if !isStringEmpty(m.ContentType) {
h.Set(hdrContentTypeKey, m.ContentType)
if !isStringEmpty(mf.ContentType) {
h.Set(hdrContentTypeKey, mf.ContentType)
}
return h
}

func (m *MultipartField) openFileIfRequired() (err error) {
if m.Reader == nil && !isStringEmpty(m.filePath) {
m.Reader, err = os.Open(m.filePath)
func (mf *MultipartField) openFileIfRequired() error {
if isStringEmpty(mf.FilePath) || mf.Reader != nil {
return nil
}

file, err := os.Open(mf.FilePath)
if err != nil {
return err
}

fileStat, err := file.Stat()
if err != nil {
return err

Check warning on line 103 in multipart.go

View check run for this annotation

Codecov / codecov/patch

multipart.go#L103

Added line #L103 was not covered by tests
}

mf.Reader = file
mf.FileSize = fileStat.Size()

return nil
}

func (mf *MultipartField) wrapProgressCallbackIfPresent(pw io.Writer) io.Writer {
if mf.ProgressCallback == nil {
return pw
}

return &multipartProgressWriter{
w: pw,
f: func(pb int64) {
mf.ProgressCallback(MultipartFieldProgress{
Name: mf.Name,
FileName: mf.FileName,
FileSize: mf.FileSize,
Written: pb,
})
},
}
}

// MultipartFieldCallbackFunc function used to transmit live multipart upload
// progress in bytes count
type MultipartFieldCallbackFunc func(MultipartFieldProgress)

// MultipartFieldProgress struct used to provide multipart field upload progress
// details via callback function
type MultipartFieldProgress struct {
Name string
FileName string
FileSize int64
Written int64
}

// String method creates the string representation of [MultipartFieldProgress]
func (mfp MultipartFieldProgress) String() string {
return fmt.Sprintf("FieldName: %s, FileName: %s, FileSize: %v, Written: %v",
mfp.Name, mfp.FileName, mfp.FileSize, mfp.Written)
}

type multipartProgressWriter struct {
w io.Writer
pb int64
f func(int64)
}

func (mpw *multipartProgressWriter) Write(p []byte) (n int, err error) {
n, err = mpw.w.Write(p)
if n <= 0 {
return
}
mpw.pb += int64(n)
mpw.f(mpw.pb)
return
}
88 changes: 83 additions & 5 deletions multipart_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,10 +358,77 @@ func TestMultipartLargeFile(t *testing.T) {
})
}

func TestMultipartFieldProgressCallback(t *testing.T) {
ts := createFormPostServer(t)
defer ts.Close()
defer cleanupFiles(".testdata/upload")

file1, _ := os.Open(filepath.Join(getTestDataPath(), "test-img.png"))
file1Stat, _ := file1.Stat()

fileName2 := "50mbfile.bin"
filePath2 := createBinFile(fileName2, 50<<20)
defer cleanupFiles(filePath2)
file2, _ := os.Open(filePath2)
file2Stat, _ := file2.Stat()

fileName3 := "100mbfile.bin"
filePath3 := createBinFile(fileName3, 100<<20)
defer cleanupFiles(filePath3)
file3, _ := os.Open(filePath3)
file3Stat, _ := file3.Stat()

progressCallback := func(mp MultipartFieldProgress) {
t.Logf("%s\n", mp)
}

fields := []*MultipartField{
{
Name: "test-image-1",
FileName: "test-image-1.png",
ContentType: "image/png",
Reader: file1,
FileSize: file1Stat.Size(),
ProgressCallback: progressCallback,
},
{
Name: "50mbfile",
FileName: fileName2,
Reader: file2,
FileSize: file2Stat.Size(),
ProgressCallback: progressCallback,
},
{
Name: "100mbfile",
FileName: fileName3,
Reader: file3,
FileSize: file3Stat.Size(),
ProgressCallback: progressCallback,
},
}

c := dcnld()

r := c.R().
SetFormData(map[string]string{"first_name": "Jeevanandam", "last_name": "M"}).
SetMultipartFields(fields...)
resp, err := r.Post(ts.URL + "/upload")

responseStr := resp.String()

assertError(t, err)
assertEqual(t, http.StatusOK, resp.StatusCode())
assertEqual(t, true, strings.Contains(responseStr, "test-image-1.png"))
assertEqual(t, true, strings.Contains(responseStr, "50mbfile.bin"))
assertEqual(t, true, strings.Contains(responseStr, "100mbfile.bin"))
}

var errTestErrorReader = errors.New("fake")

type errorReader struct{}

func (errorReader) Read(p []byte) (n int, err error) {
return 0, errors.New("fake")
return 0, errTestErrorReader
}

func TestMultipartReaderErrors(t *testing.T) {
Expand All @@ -370,8 +437,6 @@ func TestMultipartReaderErrors(t *testing.T) {

c := dcnl().SetBaseURL(ts.URL)

testErr := errors.New("fake")

t.Run("multipart fields with errorReader", func(t *testing.T) {
resp, err := c.R().
SetMultipartFields(&MultipartField{
Expand All @@ -382,7 +447,7 @@ func TestMultipartReaderErrors(t *testing.T) {
Post("/upload")

assertNotNil(t, err)
assertEqual(t, testErr, err)
assertEqual(t, errTestErrorReader, err)
assertNotNil(t, resp)
assertEqual(t, nil, resp.Body)
})
Expand All @@ -393,7 +458,7 @@ func TestMultipartReaderErrors(t *testing.T) {
Post("/upload")

assertNotNil(t, err)
assertEqual(t, testErr, err)
assertEqual(t, errTestErrorReader, err)
assertNotNil(t, resp)
assertEqual(t, nil, resp.Body)
})
Expand All @@ -410,11 +475,24 @@ func TestMultipartReaderErrors(t *testing.T) {
})
}

type returnValueTestWriter struct {
}

func (z *returnValueTestWriter) Write(p []byte) (n int, err error) {
return 0, nil
}

func TestMultipartCornerCoverage(t *testing.T) {
mf := &MultipartField{
Name: "foo",
Reader: bytes.NewBufferString("I have no seek capability"),
}
err := mf.resetReader()
assertNil(t, err)

// wrap test writer to return 0 written value
mpw := multipartProgressWriter{w: &returnValueTestWriter{}}
n, err := mpw.Write([]byte("test return value"))
assertNil(t, err)
assertEqual(t, 0, n)
}
Loading

0 comments on commit acd1953

Please sign in to comment.