From aee4482f35c93e9c7bc79e188afb157a99257209 Mon Sep 17 00:00:00 2001 From: Javad Rajabzadeh Date: Sun, 11 Aug 2024 18:15:53 +0330 Subject: [PATCH] fix(util): chunked download to improve download speed (#1459) --- util/downloader/chunk.go | 31 ++++++ util/downloader/chunk_test.go | 112 +++++++++++++++++++ util/downloader/downloader.go | 202 ++++++++++++++++------------------ util/downloader/errors.go | 18 +++ 4 files changed, 257 insertions(+), 106 deletions(-) create mode 100644 util/downloader/chunk.go create mode 100644 util/downloader/chunk_test.go create mode 100644 util/downloader/errors.go diff --git a/util/downloader/chunk.go b/util/downloader/chunk.go new file mode 100644 index 000000000..1d9446c95 --- /dev/null +++ b/util/downloader/chunk.go @@ -0,0 +1,31 @@ +package downloader + +import "fmt" + +type chunk struct { + start, end int64 +} + +func createChunks(contentLength, totalChunks int64) []*chunk { + chunks := make([]*chunk, 0, totalChunks) + chunkSize := contentLength / totalChunks + for i := int64(0); i < totalChunks; i++ { + start := i * chunkSize + end := start + chunkSize - 1 + // adjust the end for the last chunk + if i == totalChunks-1 { + end = contentLength - 1 + } + chunks = append(chunks, &chunk{start: start, end: end}) + } + + return chunks +} + +func (c *chunk) rangeHeader() string { + return fmt.Sprintf("bytes=%d-%d", c.start, c.end) +} + +func (c *chunk) size() int64 { + return (c.end + 1) - c.start +} diff --git a/util/downloader/chunk_test.go b/util/downloader/chunk_test.go new file mode 100644 index 000000000..efdb7f85d --- /dev/null +++ b/util/downloader/chunk_test.go @@ -0,0 +1,112 @@ +package downloader + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCreateChunks(t *testing.T) { + tests := []struct { + contentLength int64 + totalChunks int64 + expected []*chunk + }{ + { + contentLength: 181403648, + totalChunks: 16, + expected: []*chunk{ + {start: 0, end: 11337727}, + {start: 11337728, end: 22675455}, + {start: 22675456, end: 34013183}, + {start: 34013184, end: 45350911}, + {start: 45350912, end: 56688639}, + {start: 56688640, end: 68026367}, + {start: 68026368, end: 79364095}, + {start: 79364096, end: 90701823}, + {start: 90701824, end: 102039551}, + {start: 102039552, end: 113377279}, + {start: 113377280, end: 124715007}, + {start: 124715008, end: 136052735}, + {start: 136052736, end: 147390463}, + {start: 147390464, end: 158728191}, + {start: 158728192, end: 170065919}, + {start: 170065920, end: 181403647}, + }, + }, + { + contentLength: 10, + totalChunks: 3, + expected: []*chunk{ + {start: 0, end: 2}, + {start: 3, end: 5}, + {start: 6, end: 9}, + }, + }, + { + contentLength: 10, + totalChunks: 1, + expected: []*chunk{ + {start: 0, end: 9}, + }, + }, + { + contentLength: 0, + totalChunks: 1, + expected: []*chunk{ + {start: 0, end: -1}, + }, + }, + } + + for _, tt := range tests { + actual := createChunks(tt.contentLength, tt.totalChunks) + assert.Equal(t, tt.expected, actual) + } +} + +func TestChunkRangeHeader(t *testing.T) { + tests := []struct { + chunk chunk + expected string + }{ + { + chunk: chunk{start: 0, end: 499}, + expected: "bytes=0-499", + }, + { + chunk: chunk{start: 500, end: 999}, + expected: "bytes=500-999", + }, + } + + for _, tt := range tests { + actual := tt.chunk.rangeHeader() + assert.Equal(t, tt.expected, actual) + } +} + +func TestChunkSize(t *testing.T) { + tests := []struct { + chunk chunk + expected int64 + }{ + { + chunk: chunk{start: 0, end: 499}, + expected: 500, + }, + { + chunk: chunk{start: 500, end: 999}, + expected: 500, + }, + { + chunk: chunk{start: 0, end: 0}, + expected: 1, + }, + } + + for _, tt := range tests { + actual := tt.chunk.size() + assert.Equal(t, tt.expected, actual) + } +} diff --git a/util/downloader/downloader.go b/util/downloader/downloader.go index ee842bd89..06e2d2567 100644 --- a/util/downloader/downloader.go +++ b/util/downloader/downloader.go @@ -6,24 +6,16 @@ import ( "encoding/hex" "errors" "fmt" - "hash" "io" "net/http" "os" "path/filepath" + "sync" ) -var ( - ErrHeaderRequest = errors.New("request header error") - ErrSHA256Mismatch = errors.New("sha256 mismatch") - ErrCreateDir = errors.New("create dir error") - ErrInvalidFilePath = errors.New("file path is a directory, not a file") - ErrGetFileInfo = errors.New("get file info error") - ErrCopyExistsFileData = errors.New("error copying existing file data") - ErrDoRequest = errors.New("error doing request") - ErrFileWriting = errors.New("error writing file") - ErrNewRequest = errors.New("error creating request") - ErrOpenFileExists = errors.New("error opening existing file") +const ( + _defaultConcurrencyPerChunk = 16 + _defaultMinSizeForChunk = 1 << 20 ) type Downloader struct { @@ -35,6 +27,11 @@ type Downloader struct { fileName string statsCh chan Stats errCh chan error + + chunks []*chunk + + mu sync.Mutex + downloaded int64 } type Stats struct { @@ -56,6 +53,7 @@ func New(url, filePath, sha256Sum string, opts ...Option) *Downloader { url: url, filePath: filePath, sha256Sum: sha256Sum, + chunks: make([]*chunk, 0, _defaultConcurrencyPerChunk), statsCh: make(chan Stats), errCh: make(chan error, 1), } @@ -96,7 +94,7 @@ func (d *Downloader) download(ctx context.Context) { return } - out, err := d.openFile() + out, err := os.OpenFile(d.filePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o600) if err != nil { d.handleError(err) @@ -106,13 +104,25 @@ func (d *Downloader) download(ctx context.Context) { _ = out.Close() }() - if err := d.validateExistingFile(out, &stats); err != nil { - d.handleError(err) + d.statsCh <- stats - return + var wg sync.WaitGroup + for _, c := range d.chunks { + wg.Add(1) + go func(c *chunk) { + defer wg.Done() + err := d.downloadChunkWithContext(ctx, out, c, stats.TotalSize) + if err != nil { + d.handleError(err) + + return + } + }(c) } - if err := d.downloadFile(ctx, out, &stats); err != nil { + wg.Wait() + + if err := d.finalizeDownload(&stats); err != nil { d.handleError(err) } } @@ -120,12 +130,12 @@ func (d *Downloader) download(ctx context.Context) { func (d *Downloader) getHeader(ctx context.Context) (Stats, error) { req, err := http.NewRequestWithContext(ctx, http.MethodHead, d.url, http.NoBody) if err != nil { - return Stats{}, ErrHeaderRequest + return Stats{}, &Error{Message: "failed to create new request for get header", Reason: err} } resp, err := d.client.Do(req) if err != nil { - return Stats{}, ErrHeaderRequest + return Stats{}, &Error{Message: "failed to do request get header", Reason: err} } defer func() { @@ -134,6 +144,15 @@ func (d *Downloader) getHeader(ctx context.Context) (Stats, error) { d.fileType = resp.Header.Get("Content-Type") + if resp.ContentLength > _defaultMinSizeForChunk { + d.chunks = createChunks(resp.ContentLength, _defaultConcurrencyPerChunk) + } else { + d.chunks = append(d.chunks, &chunk{ + start: 0, + end: resp.ContentLength, + }) + } + return Stats{ TotalSize: resp.ContentLength, }, nil @@ -142,126 +161,97 @@ func (d *Downloader) getHeader(ctx context.Context) (Stats, error) { func (d *Downloader) createDir() error { dir := filepath.Dir(d.filePath) if err := os.MkdirAll(dir, 0o750); err != nil { - return ErrCreateDir + return &Error{Message: "failed to create file path directory", Reason: err} } return nil } -func (d *Downloader) openFile() (*os.File, error) { - fileInfo, err := os.Stat(d.filePath) - if err == nil && fileInfo.IsDir() { - return nil, ErrInvalidFilePath - } - - return os.OpenFile(d.filePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o600) -} - -func (*Downloader) validateExistingFile(out *os.File, stats *Stats) error { - fileInfo, err := out.Stat() - if err != nil { - return ErrGetFileInfo - } - stats.Downloaded = fileInfo.Size() - - return nil -} - -func (d *Downloader) downloadFile(ctx context.Context, out *os.File, stats *Stats) error { - req, err := d.createRequest(ctx, stats.Downloaded) +func (d *Downloader) downloadChunkWithContext(ctx context.Context, out *os.File, c *chunk, totalSize int64) error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, d.url, http.NoBody) if err != nil { - return err + return &Error{Message: "failed to create new request for download chunk", Reason: err} } + req.Header.Set("Range", c.rangeHeader()) resp, err := d.client.Do(req) if err != nil { - return ErrDoRequest + return &Error{Message: "failed to do request download chunk", Reason: err} } defer func() { _ = resp.Body.Close() }() - buffer := make([]byte, 32*1024) - hasher := sha256.New() - - if err := d.updateHasherWithExistingData(stats.Downloaded, hasher); err != nil { - return err - } - - return d.writeToFile(ctx, resp, out, buffer, hasher, stats) -} - -func (d *Downloader) createRequest(ctx context.Context, downloaded int64) (*http.Request, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, d.url, http.NoBody) - if err != nil { - return nil, ErrNewRequest - } - if downloaded > 0 { - req.Header.Set("Range", fmt.Sprintf("bytes=%d-", downloaded)) + if resp.StatusCode != http.StatusPartialContent && resp.StatusCode != http.StatusOK { + return &Error{ + Message: "response has invalid status code", + Reason: fmt.Errorf("got http response %s from %s: %w", resp.Status, d.url, err), + } } - return req, nil -} - -func (d *Downloader) updateHasherWithExistingData(downloaded int64, hasher io.Writer) error { - if downloaded > 0 { - existingFile, err := os.Open(d.filePath) - if err != nil { - return ErrOpenFileExists + buf := make([]byte, 32*1024) // 32KB buffer for reading the response body + offset := c.start + for { + n, err := resp.Body.Read(buf) + if n > 0 { + d.mu.Lock() + for written := 0; written < n; { + w, err := out.WriteAt(buf[written:n], offset+int64(written)) + if err != nil { + d.mu.Unlock() + + return &Error{Message: "failed write data into file", Reason: err} + } + written += w + } + offset += int64(n) + d.downloaded += int64(n) + d.updateStats(d.downloaded, totalSize) + d.mu.Unlock() } - defer func() { - _ = existingFile.Close() - }() + if err != nil { + // if error is io.EOF stop write for loop response body. + if errors.Is(err, io.EOF) { + break + } - if _, err := io.CopyN(hasher, existingFile, downloaded); err != nil { - return ErrCopyExistsFileData + return &Error{Message: "error read body download chunk", Reason: err} } } return nil } -func (d *Downloader) writeToFile(ctx context.Context, resp *http.Response, out *os.File, buffer []byte, - hasher hash.Hash, stats *Stats, -) error { - for { - select { - case <-ctx.Done(): - d.stop() - - return ctx.Err() - default: - n, err := resp.Body.Read(buffer) - if n > 0 { - if _, err := out.Write(buffer[:n]); err != nil { - return ErrFileWriting - } - - if _, err := hasher.Write(buffer[:n]); err != nil { - return ErrFileWriting - } +func (d *Downloader) updateStats(downloaded, totalSize int64) { + stats := Stats{ + Downloaded: downloaded, + TotalSize: totalSize, + Percent: float64(downloaded) / float64(totalSize) * 100, + } + d.statsCh <- stats +} - stats.Downloaded += int64(n) - stats.Percent = float64(stats.Downloaded) / float64(stats.TotalSize) * 100 - d.statsCh <- *stats - } - if err != nil { - if err == io.EOF { - return d.finalizeDownload(hasher, stats) - } +func (d *Downloader) finalizeDownload(stats *Stats) error { + // Recalculate the hash by re-reading the entire file + out, err := os.Open(d.filePath) + if err != nil { + return &Error{Message: "failed to open file", Reason: err} + } + defer func() { + _ = out.Close() + }() - return fmt.Errorf("error reading response body: %w", err) - } - } + hasher := sha256.New() + if _, err := io.Copy(hasher, out); err != nil { + return &Error{Message: "failed copy file data to hasher for calculate hash", Reason: err} } -} -func (d *Downloader) finalizeDownload(hasher hash.Hash, stats *Stats) error { stats.Completed = true + stats.Percent = 100 sum := hex.EncodeToString(hasher.Sum(nil)) if sum != d.sha256Sum { - return ErrSHA256Mismatch + return &Error{Message: "sha256 mismatch", Reason: err} } d.statsCh <- *stats diff --git a/util/downloader/errors.go b/util/downloader/errors.go new file mode 100644 index 000000000..b11669bf5 --- /dev/null +++ b/util/downloader/errors.go @@ -0,0 +1,18 @@ +package downloader + +import ( + "fmt" +) + +type Error struct { + Message string + Reason error +} + +func (e *Error) Error() string { + return fmt.Sprintf("%s: %s", e.Message, e.Reason.Error()) +} + +func (e *Error) Unwrap() error { + return e.Reason +}