-
Notifications
You must be signed in to change notification settings - Fork 143
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(util): file downloader with verify sha256 hash (#1422)
- Loading branch information
Showing
3 changed files
with
395 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,284 @@ | ||
package downloader | ||
|
||
import ( | ||
"context" | ||
"crypto/sha256" | ||
"encoding/hex" | ||
"errors" | ||
"fmt" | ||
"hash" | ||
"io" | ||
"net/http" | ||
"os" | ||
"path/filepath" | ||
) | ||
|
||
// Sentinel errors delivered on the Downloader's error channel.
var (
	// ErrHeaderRequest is returned when the initial HEAD request cannot be
	// created or executed.
	ErrHeaderRequest = errors.New("request header error")

	// ErrSHA256Mismatch is returned when the hash of the downloaded data
	// differs from the expected SHA-256 sum.
	ErrSHA256Mismatch = errors.New("sha256 mismatch")

	// ErrCreateDir is returned when the destination directory cannot be created.
	ErrCreateDir = errors.New("create dir error")

	// ErrInvalidFilePath is returned when the destination path is an existing directory.
	ErrInvalidFilePath = errors.New("file path is a directory, not a file")

	// ErrGetFileInfo is returned when the destination file cannot be stat'ed.
	ErrGetFileInfo = errors.New("get file info error")

	// ErrCopyExistsFileData is returned when previously downloaded bytes
	// cannot be fed into the hasher.
	ErrCopyExistsFileData = errors.New("error copying existing file data")

	// ErrDoRequest is returned when the GET request fails.
	ErrDoRequest = errors.New("error doing request")

	// ErrFileWriting is returned when writing a chunk to the file or the
	// hasher fails.
	ErrFileWriting = errors.New("error writing file")

	// ErrNewRequest is returned when the GET request cannot be constructed.
	ErrNewRequest = errors.New("error creating request")

	// ErrOpenFileExists is returned when the partially downloaded file
	// cannot be reopened for hashing.
	ErrOpenFileExists = errors.New("error opening existing file")
)
|
||
type Downloader struct { | ||
client *http.Client | ||
url string | ||
filePath string | ||
sha256Sum string | ||
fileType string | ||
fileName string | ||
statsCh chan Stats | ||
errCh chan error | ||
} | ||
|
||
// Stats is a snapshot of download progress.
type Stats struct {
	// Downloaded is the number of bytes present in the destination file so far.
	Downloaded int64
	// TotalSize is the content length reported by the HEAD response.
	TotalSize int64
	// Percent is Downloaded relative to TotalSize, scaled to 0-100.
	Percent float64
	// Completed is set on the final snapshot of a finished download.
	Completed bool
}
|
||
func New(url, filePath, sha256Sum string, opts ...Option) *Downloader { | ||
opt := defaultOptions() | ||
|
||
for _, o := range opts { | ||
o(opt) | ||
} | ||
|
||
return &Downloader{ | ||
client: opt.client, | ||
url: url, | ||
filePath: filePath, | ||
sha256Sum: sha256Sum, | ||
statsCh: make(chan Stats), | ||
errCh: make(chan error, 1), | ||
} | ||
} | ||
|
||
func (d *Downloader) Start(ctx context.Context) { | ||
go d.download(ctx) | ||
} | ||
|
||
func (d *Downloader) Stats() <-chan Stats { | ||
return d.statsCh | ||
} | ||
|
||
func (d *Downloader) FileType() string { | ||
return d.fileType | ||
} | ||
|
||
func (d *Downloader) FileName() string { | ||
return d.fileName | ||
} | ||
|
||
func (d *Downloader) Errors() <-chan error { | ||
return d.errCh | ||
} | ||
|
||
func (d *Downloader) download(ctx context.Context) { | ||
stats, err := d.getHeader(ctx) | ||
if err != nil { | ||
d.handleError(err) | ||
|
||
return | ||
} | ||
|
||
d.fileName = filepath.Base(d.filePath) | ||
if err := d.createDir(); err != nil { | ||
d.handleError(err) | ||
|
||
return | ||
} | ||
|
||
out, err := d.openFile() | ||
if err != nil { | ||
d.handleError(err) | ||
|
||
return | ||
} | ||
defer func() { | ||
_ = out.Close() | ||
}() | ||
|
||
if err := d.validateExistingFile(out, &stats); err != nil { | ||
d.handleError(err) | ||
|
||
return | ||
} | ||
|
||
if err := d.downloadFile(ctx, out, &stats); err != nil { | ||
d.handleError(err) | ||
} | ||
} | ||
|
||
func (d *Downloader) getHeader(ctx context.Context) (Stats, error) { | ||
req, err := http.NewRequestWithContext(ctx, http.MethodHead, d.url, http.NoBody) | ||
if err != nil { | ||
return Stats{}, ErrHeaderRequest | ||
} | ||
|
||
resp, err := d.client.Do(req) | ||
if err != nil { | ||
return Stats{}, ErrHeaderRequest | ||
} | ||
|
||
defer func() { | ||
_ = resp.Body.Close() | ||
}() | ||
|
||
d.fileType = resp.Header.Get("Content-Type") | ||
|
||
return Stats{ | ||
TotalSize: resp.ContentLength, | ||
}, nil | ||
} | ||
|
||
func (d *Downloader) createDir() error { | ||
dir := filepath.Dir(d.filePath) | ||
if err := os.MkdirAll(dir, 0o750); err != nil { | ||
return ErrCreateDir | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func (d *Downloader) openFile() (*os.File, error) { | ||
fileInfo, err := os.Stat(d.filePath) | ||
if err == nil && fileInfo.IsDir() { | ||
return nil, ErrInvalidFilePath | ||
} | ||
|
||
return os.OpenFile(d.filePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o600) | ||
} | ||
|
||
func (*Downloader) validateExistingFile(out *os.File, stats *Stats) error { | ||
fileInfo, err := out.Stat() | ||
if err != nil { | ||
return ErrGetFileInfo | ||
} | ||
stats.Downloaded = fileInfo.Size() | ||
|
||
return nil | ||
} | ||
|
||
func (d *Downloader) downloadFile(ctx context.Context, out *os.File, stats *Stats) error { | ||
req, err := d.createRequest(ctx, stats.Downloaded) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
resp, err := d.client.Do(req) | ||
if err != nil { | ||
return ErrDoRequest | ||
} | ||
|
||
defer func() { | ||
_ = resp.Body.Close() | ||
}() | ||
|
||
buffer := make([]byte, 32*1024) | ||
hasher := sha256.New() | ||
|
||
if err := d.updateHasherWithExistingData(stats.Downloaded, hasher); err != nil { | ||
return err | ||
} | ||
|
||
return d.writeToFile(ctx, resp, out, buffer, hasher, stats) | ||
} | ||
|
||
func (d *Downloader) createRequest(ctx context.Context, downloaded int64) (*http.Request, error) { | ||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, d.url, http.NoBody) | ||
if err != nil { | ||
return nil, ErrNewRequest | ||
} | ||
if downloaded > 0 { | ||
req.Header.Set("Range", fmt.Sprintf("bytes=%d-", downloaded)) | ||
} | ||
|
||
return req, nil | ||
} | ||
|
||
func (d *Downloader) updateHasherWithExistingData(downloaded int64, hasher io.Writer) error { | ||
if downloaded > 0 { | ||
existingFile, err := os.Open(d.filePath) | ||
if err != nil { | ||
return ErrOpenFileExists | ||
} | ||
defer func() { | ||
_ = existingFile.Close() | ||
}() | ||
|
||
if _, err := io.CopyN(hasher, existingFile, downloaded); err != nil { | ||
return ErrCopyExistsFileData | ||
} | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func (d *Downloader) writeToFile(ctx context.Context, resp *http.Response, out *os.File, buffer []byte, | ||
hasher hash.Hash, stats *Stats, | ||
) error { | ||
for { | ||
select { | ||
case <-ctx.Done(): | ||
d.stop() | ||
|
||
return ctx.Err() | ||
default: | ||
n, err := resp.Body.Read(buffer) | ||
if n > 0 { | ||
if _, err := out.Write(buffer[:n]); err != nil { | ||
return ErrFileWriting | ||
} | ||
|
||
if _, err := hasher.Write(buffer[:n]); err != nil { | ||
return ErrFileWriting | ||
} | ||
|
||
stats.Downloaded += int64(n) | ||
stats.Percent = float64(stats.Downloaded) / float64(stats.TotalSize) * 100 | ||
d.statsCh <- *stats | ||
} | ||
if err != nil { | ||
if err == io.EOF { | ||
return d.finalizeDownload(hasher, stats) | ||
} | ||
|
||
return fmt.Errorf("error reading response body: %w", err) | ||
} | ||
} | ||
} | ||
} | ||
|
||
func (d *Downloader) finalizeDownload(hasher hash.Hash, stats *Stats) error { | ||
stats.Completed = true | ||
sum := hex.EncodeToString(hasher.Sum(nil)) | ||
if sum != d.sha256Sum { | ||
return ErrSHA256Mismatch | ||
} | ||
d.statsCh <- *stats | ||
|
||
d.stop() | ||
|
||
return nil | ||
} | ||
|
||
func (d *Downloader) stop() { | ||
close(d.statsCh) | ||
close(d.errCh) | ||
} | ||
|
||
func (d *Downloader) handleError(err error) { | ||
select { | ||
case d.errCh <- err: | ||
default: | ||
d.stop() | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
package downloader | ||
|
||
import ( | ||
"context" | ||
"crypto/sha256" | ||
"encoding/hex" | ||
"log" | ||
"net/http" | ||
"net/http/httptest" | ||
"os" | ||
"testing" | ||
"time" | ||
|
||
"github.com/pactus-project/pactus/util" | ||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestDownloader(t *testing.T) { | ||
fileContent := []byte("This is a test file content") | ||
fileURL := "/testfile" | ||
expectedSHA256 := sha256.Sum256(fileContent) | ||
expectedSHA256Hex := hex.EncodeToString(expectedSHA256[:]) | ||
|
||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
if r.URL.Path == fileURL { | ||
_, err := w.Write(fileContent) | ||
assert.NoError(t, err) | ||
} else { | ||
http.NotFound(w, r) | ||
} | ||
})) | ||
defer server.Close() | ||
|
||
filePath := util.TempFilePath() | ||
|
||
defer func() { | ||
assert.NoError(t, os.RemoveAll("./testdata")) | ||
}() | ||
|
||
dl := New(server.URL+fileURL, filePath, expectedSHA256Hex, WithCustomClient(server.Client())) | ||
|
||
assrt := assert.New(t) | ||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) | ||
defer cancel() | ||
|
||
go func() { | ||
dl.Start(ctx) | ||
}() | ||
|
||
done := make(chan bool) | ||
|
||
go func() { | ||
for stat := range dl.Stats() { | ||
log.Printf("Downloaded: %d / %d (%.2f%%)\n", stat.Downloaded, stat.TotalSize, stat.Percent) | ||
assrt.True(stat.Downloaded <= stat.TotalSize, "Downloaded size should not exceed total size") | ||
assrt.True(stat.Percent <= 100, "Download percentage should not exceed 100") | ||
|
||
if stat.Completed { | ||
log.Println("Download completed successfully") | ||
assrt.Equal(float64(100), stat.Percent, "Download should be 100% complete") | ||
done <- true | ||
|
||
return | ||
} | ||
} | ||
}() | ||
|
||
go func() { | ||
for err := range dl.Errors() { | ||
assrt.Fail("Download encountered an error", err) | ||
done <- true | ||
|
||
return | ||
} | ||
}() | ||
|
||
select { | ||
case <-done: | ||
case <-time.After(2 * time.Minute): | ||
cancel() | ||
assrt.Fail("Download test timed out") | ||
} | ||
|
||
t.Log(dl.FileName()) | ||
t.Log(dl.FileType()) | ||
|
||
downloadedContent, err := os.ReadFile(filePath) | ||
assrt.NoError(err, "Failed to read the downloaded file") | ||
assrt.Equal(fileContent, downloadedContent, "Downloaded file content does not match expected content") | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
package downloader | ||
|
||
import "net/http" | ||
|
||
// options holds the configurable dependencies of a Downloader.
type options struct {
	// client performs all HTTP requests issued by the Downloader.
	client *http.Client
}

// Option customises the options used to build a Downloader.
type Option func(*options)

// defaultOptions returns the baseline configuration, which uses
// http.DefaultClient.
func defaultOptions() *options {
	return &options{
		client: http.DefaultClient,
	}
}

// WithCustomClient makes the Downloader use client for its HTTP requests.
// A nil client is ignored and the previously configured client is kept, so
// the Downloader never ends up with a nil client that would panic on the
// first request.
func WithCustomClient(client *http.Client) Option {
	return func(o *options) {
		if client == nil {
			return
		}

		o.client = client
	}
}