From 43dcd1b4eab042d0dbdb1441152a35a446a4297c Mon Sep 17 00:00:00 2001 From: Javad Rajabzadeh Date: Wed, 24 Jul 2024 18:47:47 +0330 Subject: [PATCH] feat(cmd): add node type page to the startup assistant (#1431) --- cmd/cmd.go | 6 +- cmd/daemon/import.go | 87 ++++----- cmd/daemon/prune.go | 7 +- cmd/daemon/start.go | 5 +- cmd/downlaod_mgr.go | 225 ------------------------ cmd/gtk/main.go | 7 +- cmd/gtk/startup_assistant.go | 331 ++++++++++++++++++++++++++++++----- cmd/gtk/widget_node.go | 2 +- cmd/importer.go | 239 +++++++++++++++++++++++++ scripts/snapshot.py | 106 +++++++---- util/io.go | 30 ++++ util/io_test.go | 82 +++++++++ util/utils.go | 4 +- 13 files changed, 749 insertions(+), 382 deletions(-) delete mode 100644 cmd/downlaod_mgr.go create mode 100644 cmd/importer.go diff --git a/cmd/cmd.go b/cmd/cmd.go index 9e873a1c1..e1cf0fbf9 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -634,7 +634,7 @@ func MakeValidatorKey(walletInstance *wallet.Wallet, valAddrsInfo []vault.Addres return valKeys, nil } -func TerminalProgressBar(totalSize, barWidth int, showBytes bool) *progressbar.ProgressBar { +func TerminalProgressBar(totalSize int64, barWidth int) *progressbar.ProgressBar { if barWidth < 15 { barWidth = 15 } @@ -642,10 +642,10 @@ func TerminalProgressBar(totalSize, barWidth int, showBytes bool) *progressbar.P opts := []progressbar.Option{ progressbar.OptionSetWriter(ansi.NewAnsiStdout()), progressbar.OptionEnableColorCodes(true), - progressbar.OptionShowBytes(showBytes), progressbar.OptionSetWidth(barWidth), progressbar.OptionSetElapsedTime(false), progressbar.OptionSetPredictTime(false), + progressbar.OptionShowDescriptionAtLineEnd(), progressbar.OptionSetTheme(progressbar.Theme{ Saucer: "[green]=[reset]", SaucerHead: "[green]>[reset]", @@ -655,5 +655,5 @@ func TerminalProgressBar(totalSize, barWidth int, showBytes bool) *progressbar.P }), } - return progressbar.NewOptions(totalSize, opts...) + return progressbar.NewOptions64(totalSize, opts...) } diff --git a/cmd/daemon/import.go b/cmd/daemon/import.go index b9eebe7f1..c52a6cc7b 100644 --- a/cmd/daemon/import.go +++ b/cmd/daemon/import.go @@ -4,11 +4,9 @@ import ( "fmt" "os" "path/filepath" - "time" "github.com/gofrs/flock" "github.com/pactus-project/pactus/cmd" - "github.com/pactus-project/pactus/genesis" "github.com/pactus-project/pactus/util" "github.com/spf13/cobra" ) @@ -21,7 +19,7 @@ func buildImportCmd(parentCmd *cobra.Command) { parentCmd.AddCommand(importCmd) workingDirOpt := addWorkingDirOption(importCmd) - serverAddrOpt := importCmd.Flags().String("server-addr", "https://download.pactus.org", + serverAddrOpt := importCmd.Flags().String("server-addr", cmd.DefaultSnapshotURL, "import server address") importCmd.Run = func(c *cobra.Command, _ []string) { @@ -46,39 +44,25 @@ func buildImportCmd(parentCmd *cobra.Command) { return } - storeDir, _ := filepath.Abs(conf.Store.StorePath()) - if !util.IsDirNotExistsOrEmpty(storeDir) { - cmd.PrintErrorMsgf("The data directory is not empty: %s", conf.Store.StorePath()) - - return - } + cmd.PrintLine() snapshotURL := *serverAddrOpt + importer, err := cmd.NewImporter( + gen.ChainType(), + snapshotURL, + conf.Store.DataPath(), + ) + cmd.FatalErrorCheck(err) - switch gen.ChainType() { - case genesis.Mainnet: - snapshotURL += "/mainnet/" - case genesis.Testnet: - snapshotURL += "/testnet/" - case genesis.Localnet: - cmd.PrintErrorMsgf("Unsupported chain type: %s", gen.ChainType()) - - return - } - - metadata, err := cmd.GetSnapshotMetadata(c.Context(), snapshotURL) - if err != nil { - cmd.PrintErrorMsgf("Failed to get snapshot metadata: %s", err) - - return - } + metadata, err := importer.GetMetadata(c.Context()) + cmd.FatalErrorCheck(err) snapshots := make([]string, 0, len(metadata)) for _, m := range metadata { item := fmt.Sprintf("snapshot %s (%s)", - parseTime(m.CreatedAt).Format("2006-01-02"), - util.FormatBytesToHumanReadable(m.TotalSize), + m.CreatedAtTime().Format("2006-01-02"), + util.FormatBytesToHumanReadable(m.Data.Size), ) snapshots = append(snapshots, item) @@ -89,34 +73,32 @@ func buildImportCmd(parentCmd *cobra.Command) { choice := cmd.PromptSelect("Please select a snapshot", snapshots) selected := metadata[choice] - tmpDir := util.TempDirPath() - extractPath := fmt.Sprintf("%s/data", tmpDir) - err = os.MkdirAll(extractPath, 0o750) - cmd.FatalErrorCheck(err) + cmd.TrapSignal(func() { + _ = fileLock.Unlock() + _ = importer.Cleanup() + }) cmd.PrintLine() - zipFileList := cmd.DownloadManager( + importer.Download( c.Context(), &selected, - snapshotURL, - tmpDir, downloadProgressBar, ) - for _, zFile := range zipFileList { - err := cmd.ExtractAndStoreFile(zFile, extractPath) - cmd.FatalErrorCheck(err) - } + cmd.PrintLine() + cmd.PrintLine() + cmd.PrintInfoMsgf("Extracting files...") - err = os.MkdirAll(filepath.Dir(conf.Store.StorePath()), 0o750) + err = importer.ExtractAndStoreFiles() cmd.FatalErrorCheck(err) - err = cmd.CopyAllFiles(extractPath, conf.Store.StorePath()) + cmd.PrintInfoMsgf("Moving data...") + err = importer.MoveStore() cmd.FatalErrorCheck(err) - err = os.RemoveAll(tmpDir) + err = importer.Cleanup() cmd.FatalErrorCheck(err) _ = fileLock.Unlock() @@ -131,19 +113,12 @@ func buildImportCmd(parentCmd *cobra.Command) { } func downloadProgressBar(fileName string, totalSize, downloaded int64, _ float64) { - bar := cmd.TerminalProgressBar(int(totalSize), 30, true) - bar.Describe(fileName) - err := bar.Add(int(downloaded)) + bar := cmd.TerminalProgressBar(totalSize, 30) + bar.Describe(fmt.Sprintf("%s (%s/%s)", + fileName, + util.FormatBytesToHumanReadable(uint64(downloaded)), + util.FormatBytesToHumanReadable(uint64(totalSize)), + )) + err := bar.Add64(downloaded) cmd.FatalErrorCheck(err) } - -func parseTime(dateString string) time.Time { - const layout = "2006-01-02T15:04:05.000000" - - parsedTime, err := time.Parse(layout, dateString) - if err != nil { - return time.Time{} - } - - return parsedTime -} diff --git a/cmd/daemon/prune.go b/cmd/daemon/prune.go index 3fd49ade9..a58c2ab89 100644 --- a/cmd/daemon/prune.go +++ b/cmd/daemon/prune.go @@ -34,10 +34,7 @@ func buildPruneCmd(parentCmd *cobra.Command) { fileLock := flock.New(lockFilePath) locked, err := fileLock.TryLock() - if err != nil { - // handle unable to attempt to acquire lock - cmd.FatalErrorCheck(err) - } + cmd.FatalErrorCheck(err) if !locked { cmd.PrintWarnMsgf("Could not lock '%s', another instance is running?", lockFilePath) @@ -114,7 +111,7 @@ func buildPruneCmd(parentCmd *cobra.Command) { } func pruningProgressBar(prunedCount, skippedCount, totalCount uint32) { - bar := cmd.TerminalProgressBar(int(totalCount), 30, false) + bar := cmd.TerminalProgressBar(int64(totalCount), 30) bar.Describe(fmt.Sprintf("Pruned: %d | Skipped: %d", prunedCount, skippedCount)) err := bar.Add(int(prunedCount + skippedCount)) cmd.FatalErrorCheck(err) diff --git a/cmd/daemon/start.go b/cmd/daemon/start.go index b707cb57d..3df4b2f32 100644 --- a/cmd/daemon/start.go +++ b/cmd/daemon/start.go @@ -41,10 +41,7 @@ func buildStartCmd(parentCmd *cobra.Command) { fileLock := flock.New(lockFilePath) locked, err := fileLock.TryLock() - if err != nil { - // handle unable to attempt to acquire lock - cmd.FatalErrorCheck(err) - } + cmd.FatalErrorCheck(err) if !locked { cmd.PrintWarnMsgf("Could not lock '%s', another instance is running?", lockFilePath) diff --git a/cmd/downlaod_mgr.go b/cmd/downlaod_mgr.go deleted file mode 100644 index ed00b6bba..000000000 --- a/cmd/downlaod_mgr.go +++ /dev/null @@ -1,225 +0,0 @@ -package cmd - -import ( - "archive/zip" - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "net/url" - "os" - "path/filepath" - - "github.com/pactus-project/pactus/util/downloader" -) - -const maxDecompressedSize = 10 << 20 // 10 MB - -type Metadata struct { - Name string `json:"name"` - CreatedAt string `json:"created_at"` - Compress string `json:"compress"` - TotalSize uint64 `json:"total_size"` - Data []*SnapshotData `json:"data"` -} - -type SnapshotData struct { - Name string `json:"name"` - Path string `json:"path"` - Sha string `json:"sha"` - Size uint64 `json:"size"` -} - -func GetSnapshotMetadata(ctx context.Context, snapshotURL string) ([]Metadata, error) { - cli := http.DefaultClient - metaURL, err := url.JoinPath(snapshotURL, "metadata.json") - if err != nil { - return nil, err - } - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, metaURL, http.NoBody) - if err != nil { - return nil, err - } - - resp, err := cli.Do(req) - if err != nil { - return nil, err - } - defer func() { - _ = resp.Body.Close() - }() - - if resp.StatusCode != http.StatusOK { - return nil, errors.New(resp.Status) - } - - metadata := make([]Metadata, 0) - - dec := json.NewDecoder(resp.Body) - - if err := dec.Decode(&metadata); err != nil { - return nil, err - } - - return metadata, nil -} - -func DownloadManager( - ctx context.Context, - metadata *Metadata, - baseURL, tempPath string, - stateFunc func(fileName string, totalSize, downloaded int64, percentage float64), -) []string { - zipFileListPath := make([]string, 0) - - for _, data := range metadata.Data { - done := make(chan struct{}) - dlLink, err := url.JoinPath(baseURL, data.Path) - FatalErrorCheck(err) - - fileName := filepath.Base(dlLink) - - filePath := fmt.Sprintf("%s/%s", tempPath, fileName) - - dl := downloader.New( - dlLink, - filePath, - data.Sha, - ) - - dl.Start(ctx) - - go func() { - err := <-dl.Errors() - FatalErrorCheck(err) - }() - - go func() { - for state := range dl.Stats() { - stateFunc(fileName, state.TotalSize, state.Downloaded, state.Percent) - if state.Completed { - done <- struct{}{} - close(done) - - return - } - } - }() - - <-done - zipFileListPath = append(zipFileListPath, filePath) - } - - return zipFileListPath -} - -func ExtractAndStoreFile(zipFilePath, extractPath string) error { - r, err := zip.OpenReader(zipFilePath) - if err != nil { - return fmt.Errorf("failed to open zip file: %w", err) - } - defer func() { - _ = r.Close() - }() - - for _, f := range r.File { - rc, err := f.Open() - if err != nil { - return fmt.Errorf("failed to open file in zip archive: %w", err) - } - - fpath := fmt.Sprintf("%s/%s", extractPath, f.Name) - - outFile, err := os.Create(fpath) - if err != nil { - return fmt.Errorf("failed to create file: %w", err) - } - - // fixed potential DoS vulnerability via decompression bomb - lr := io.LimitedReader{R: rc, N: maxDecompressedSize} - _, err = io.Copy(outFile, &lr) - if err != nil { - return fmt.Errorf("failed to copy file contents: %w", err) - } - - // check if the file size exceeds the limit - if lr.N <= 0 { - return fmt.Errorf("file exceeds maximum decompressed size limit: %s", fpath) - } - - _ = rc.Close() - _ = outFile.Close() - } - - return nil -} - -// CopyAllFiles copies all files from srcDir to dstDir. -func CopyAllFiles(srcDir, dstDir string) error { - err := os.MkdirAll(dstDir, 0o750) - if err != nil { - return fmt.Errorf("failed to create destination directory: %w", err) - } - - return filepath.Walk(srcDir, func(path string, info os.FileInfo, err error) error { - if err != nil { - return err - } - - if info.IsDir() { - return nil // Skip directories - } - - relativePath, err := filepath.Rel(srcDir, path) - if err != nil { - return err - } - - dstPath := filepath.Join(dstDir, relativePath) - - err = os.MkdirAll(filepath.Dir(dstPath), 0o750) - if err != nil { - return fmt.Errorf("failed to create directory: %w", err) - } - - err = copyFile(path, dstPath) - if err != nil { - return fmt.Errorf("failed to copy file from %s to %s: %w", path, dstPath, err) - } - - return nil - }) -} - -func copyFile(src, dst string) error { - sourceFile, err := os.Open(src) - if err != nil { - return fmt.Errorf("failed to open source file: %w", err) - } - defer func() { - _ = sourceFile.Close() - }() - - destinationFile, err := os.Create(dst) - if err != nil { - return fmt.Errorf("failed to create destination file: %w", err) - } - defer func() { - _ = destinationFile.Close() - }() - - _, err = io.Copy(destinationFile, sourceFile) - if err != nil { - return fmt.Errorf("failed to copy file contents: %w", err) - } - - err = destinationFile.Sync() - if err != nil { - return fmt.Errorf("failed to sync destination file: %w", err) - } - - return nil -} diff --git a/cmd/gtk/main.go b/cmd/gtk/main.go index 451466828..c9bf57945 100644 --- a/cmd/gtk/main.go +++ b/cmd/gtk/main.go @@ -48,7 +48,7 @@ func main() { } // If node is not initialized yet - if !util.PathExists(workingDir) { + if util.IsDirNotExistsOrEmpty(workingDir) { network := genesis.Mainnet if *testnetOpt { network = genesis.Testnet @@ -63,10 +63,7 @@ func main() { fileLock := flock.New(lockFilePath) locked, err := fileLock.TryLock() - if err != nil { - // handle unable to attempt to acquire lock - fatalErrorCheck(err) - } + fatalErrorCheck(err) if !locked { cmd.PrintWarnMsgf("Could not lock '%s', another instance is running?", lockFilePath) diff --git a/cmd/gtk/startup_assistant.go b/cmd/gtk/startup_assistant.go index cfa910b51..a0d0112cc 100644 --- a/cmd/gtk/startup_assistant.go +++ b/cmd/gtk/startup_assistant.go @@ -3,15 +3,19 @@ package main import ( + "context" "fmt" "log" + "path/filepath" "regexp" "strings" + "time" "github.com/gotk3/gotk3/glib" "github.com/gotk3/gotk3/gtk" "github.com/pactus-project/pactus/cmd" "github.com/pactus-project/pactus/genesis" + "github.com/pactus-project/pactus/util" "github.com/pactus-project/pactus/wallet" ) @@ -25,8 +29,8 @@ func setMargin(widget gtk.IWidget, top, bottom, start, end int) { widget.ToWidget().SetMarginEnd(end) } -//nolint:gocognit // complexity can't be reduced more. -func startupAssistant(workingDir string, chain genesis.ChainType) bool { +//nolint:all // complexity can't be reduced more. It needs to refactor. +func startupAssistant(workingDir string, chainType genesis.ChainType) bool { successful := false assistant, err := gtk.AssistantNew() fatalErrorCheck(err) @@ -37,25 +41,29 @@ func startupAssistant(workingDir string, chain genesis.ChainType) bool { assistFunc := pageAssistant() // --- page_mode - mode, restoreRadio, pageModeName := pageMode(assistant, assistFunc) + wgtWalletMode, radioRestoreWallet, pageModeName := pageWalletMode(assistant, assistFunc) // --- page_seed_generate - seedGenerate, textViewSeed, pageSeedGenerateName := pageSeedGenerate(assistant, assistFunc) + wgtSeedGenerate, txtSeed, pageSeedGenerateName := pageSeedGenerate(assistant, assistFunc) // --- page_seed_confirm - seedConfirm, pageSeedConfirmName := pageSeedConfirm(assistant, assistFunc, textViewSeed) + wgtSeedConfirm, pageSeedConfirmName := pageSeedConfirm(assistant, assistFunc, txtSeed) // -- page_seed_restore - seedRestore, textViewRestoreSeed, pageSeedRestoreName := pageSeedRestore(assistant, assistFunc) + wgtSeedRestore, textRestoreSeed, pageSeedRestoreName := pageSeedRestore(assistant, assistFunc) // --- page_password - password, entryPassword, pagePasswordName := pagePassword(assistant, assistFunc) + wgtPassword, entryPassword, pagePasswordName := pagePassword(assistant, assistFunc) // --- page_num_validators - numValidators, lsNumValidators, comboNumValidators, pageNumValidatorsName := pageNumValidators(assistant, assistFunc) + wgtNumValidators, lsNumValidators, comboNumValidators, + pageNumValidatorsName := pageNumValidators(assistant, assistFunc) - // --- page_final - final, textViewNodeInfo, pageFinalName := pageFinal(assistant, assistFunc) + // -- page_node_type + wgtNodeType, gridImport, radioImport, pageNodeTypeName := pageNodeType(assistant, assistFunc) + + // --- page_summary + wgtSummary, txtNodeInfo, pageSummaryName := pageSummary(assistant, assistFunc) assistant.Connect("cancel", func() { assistant.Close() @@ -68,19 +76,20 @@ func startupAssistant(workingDir string, chain genesis.ChainType) bool { gtk.MainQuit() }) - assistant.SetPageType(mode, gtk.ASSISTANT_PAGE_INTRO) // page 0 - assistant.SetPageType(seedGenerate, gtk.ASSISTANT_PAGE_CONTENT) // page 1 - assistant.SetPageType(seedConfirm, gtk.ASSISTANT_PAGE_CONTENT) // page 2 - assistant.SetPageType(seedRestore, gtk.ASSISTANT_PAGE_CONTENT) // page 3 - assistant.SetPageType(password, gtk.ASSISTANT_PAGE_CONTENT) // page 4 - assistant.SetPageType(numValidators, gtk.ASSISTANT_PAGE_CONTENT) // page 5 - assistant.SetPageType(final, gtk.ASSISTANT_PAGE_SUMMARY) // page 6 + assistant.SetPageType(wgtWalletMode, gtk.ASSISTANT_PAGE_INTRO) // page 0 + assistant.SetPageType(wgtSeedGenerate, gtk.ASSISTANT_PAGE_CONTENT) // page 1 + assistant.SetPageType(wgtSeedConfirm, gtk.ASSISTANT_PAGE_CONTENT) // page 2 + assistant.SetPageType(wgtSeedRestore, gtk.ASSISTANT_PAGE_CONTENT) // page 3 + assistant.SetPageType(wgtPassword, gtk.ASSISTANT_PAGE_CONTENT) // page 4 + assistant.SetPageType(wgtNumValidators, gtk.ASSISTANT_PAGE_CONTENT) // page 5 + assistant.SetPageType(wgtNodeType, gtk.ASSISTANT_PAGE_CONTENT) // page 6 + assistant.SetPageType(wgtSummary, gtk.ASSISTANT_PAGE_SUMMARY) // page 7 mnemonic := "" prevPageIndex := -1 prevPageAdjust := 0 assistant.Connect("prepare", func(assistant *gtk.Assistant, page *gtk.Widget) { - isRestoreMode := restoreRadio.GetActive() + isRestoreMode := radioRestoreWallet.GetActive() curPageName, err := page.GetName() curPageIndex := assistant.GetCurrentPage() fatalErrorCheck(err) @@ -94,7 +103,7 @@ func startupAssistant(workingDir string, chain genesis.ChainType) bool { curPageName, isRestoreMode, prevPageIndex, curPageIndex) switch curPageName { case pageModeName: - assistantPageComplete(assistant, mode, true) + assistantPageComplete(assistant, wgtWalletMode, true) case pageSeedGenerateName: if isRestoreMode { @@ -109,11 +118,11 @@ func startupAssistant(workingDir string, chain genesis.ChainType) bool { assistant.PreviousPage() prevPageAdjust = -1 } - assistantPageComplete(assistant, seedGenerate, false) + assistantPageComplete(assistant, wgtSeedGenerate, false) } else { mnemonic, _ = wallet.GenerateMnemonic(128) - setTextViewContent(textViewSeed, mnemonic) - assistantPageComplete(assistant, seedGenerate, true) + setTextViewContent(txtSeed, mnemonic) + assistantPageComplete(assistant, wgtSeedGenerate, true) } case pageSeedConfirmName: if isRestoreMode { @@ -128,9 +137,9 @@ func startupAssistant(workingDir string, chain genesis.ChainType) bool { assistant.PreviousPage() prevPageAdjust = -1 } - assistantPageComplete(assistant, seedConfirm, false) + assistantPageComplete(assistant, wgtSeedConfirm, false) } else { - assistantPageComplete(assistant, seedConfirm, false) + assistantPageComplete(assistant, wgtSeedConfirm, false) } case pageSeedRestoreName: if !isRestoreMode { @@ -145,24 +154,161 @@ func startupAssistant(workingDir string, chain genesis.ChainType) bool { assistant.PreviousPage() prevPageAdjust = -1 } - assistantPageComplete(assistant, seedConfirm, false) + assistantPageComplete(assistant, wgtSeedConfirm, false) } else { - assistantPageComplete(assistant, seedRestore, true) + assistantPageComplete(assistant, wgtSeedRestore, true) } case pagePasswordName: if isRestoreMode { - mnemonic = getTextViewContent(textViewRestoreSeed) + mnemonic = getTextViewContent(textRestoreSeed) if err := wallet.CheckMnemonic(mnemonic); err != nil { showErrorDialog(assistant, "mnemonic is invalid") assistant.PreviousPage() } } - assistantPageComplete(assistant, password, true) + assistantPageComplete(assistant, wgtPassword, true) case pageNumValidatorsName: - assistantPageComplete(assistant, numValidators, true) + assistantPageComplete(assistant, wgtNumValidators, true) + + case pageNodeTypeName: + assistantPageComplete(assistant, wgtNodeType, true) + ssLabel, err := gtk.LabelNew("") + fatalErrorCheck(err) + setMargin(ssLabel, 5, 5, 1, 1) + ssLabel.SetHAlign(gtk.ALIGN_START) + + listBox, err := gtk.ListBoxNew() + fatalErrorCheck(err) + setMargin(listBox, 5, 5, 1, 1) + listBox.SetHAlign(gtk.ALIGN_CENTER) + listBox.SetSizeRequest(600, -1) + + ssDLBtn, err := gtk.ButtonNewWithLabel("⏬ Download") + fatalErrorCheck(err) + setMargin(ssDLBtn, 10, 5, 1, 1) + ssDLBtn.SetHAlign(gtk.ALIGN_CENTER) + ssDLBtn.SetSizeRequest(600, -1) - case pageFinalName: + ssPBLabel, err := gtk.LabelNew("") + fatalErrorCheck(err) + setMargin(ssPBLabel, 5, 10, 1, 1) + ssPBLabel.SetHAlign(gtk.ALIGN_START) + + gridImport.Attach(ssLabel, 0, 1, 1, 1) + gridImport.Attach(listBox, 0, 2, 1, 1) + gridImport.Attach(ssDLBtn, 0, 3, 1, 1) + gridImport.Attach(ssPBLabel, 0, 5, 1, 1) + ssLabel.SetVisible(false) + listBox.SetVisible(false) + ssDLBtn.SetVisible(false) + ssPBLabel.SetVisible(false) + + snapshotIndex := 0 + + radioImport.Connect("toggled", func() { + if radioImport.GetActive() { + assistantPageComplete(assistant, wgtNodeType, false) + + ssLabel.SetVisible(true) + ssLabel.SetText(" ♻️ Please wait, loading snapshot list...") + + go func() { + time.Sleep(1 * time.Second) + + glib.IdleAdd(func() { + snapshotURL := cmd.DefaultSnapshotURL // TODO: make me optional... + + storeDir := filepath.Join(workingDir, "data") + importer, err := cmd.NewImporter( + chainType, + snapshotURL, + storeDir, + ) + fatalErrorCheck(err) + + ctx := context.Background() + mdCh := getMetadata(ctx, importer, listBox) + + if md := <-mdCh; md == nil { + ssLabel.SetText(" ❌ Failed to get snapshot list, please try again later.") + } else { + ssLabel.SetText(" πŸ”½ Please select a snapshot to download:") + listBox.SetVisible(true) + + listBox.Connect("row-selected", func(_ *gtk.ListBox, row *gtk.ListBoxRow) { + if row != nil { + snapshotIndex = row.GetIndex() + ssDLBtn.SetVisible(true) + } + }) + + ssDLBtn.Connect("clicked", func() { + radioGroup, _ := radioImport.GetParent() + radioImport.SetSensitive(false) + radioGroup.ToWidget().SetSensitive(false) + ssLabel.SetSensitive(false) + listBox.SetSensitive(false) + ssDLBtn.SetSensitive(false) + + ssDLBtn.SetVisible(false) + ssPBLabel.SetVisible(true) + listBox.SetSelectionMode(gtk.SELECTION_NONE) + + go func() { + log.Printf("start downloading...\n") + + importer.Download( + ctx, + &md[snapshotIndex], + func(fileName string, totalSize, downloaded int64, + percentage float64, + ) { + percent := int(percentage) + glib.IdleAdd(func() { + dlMessage := fmt.Sprintf("🌐 Downloading %s | %d%% (%s / %s)", + fileName, + percent, + util.FormatBytesToHumanReadable(uint64(downloaded)), + util.FormatBytesToHumanReadable(uint64(totalSize)), + ) + ssPBLabel.SetText(" " + dlMessage) + }) + }, + ) + + glib.IdleAdd(func() { + log.Printf("extracting data...\n") + ssPBLabel.SetText(" " + "πŸ“‚ Extracting downloaded files...") + err := importer.ExtractAndStoreFiles() + fatalErrorCheck(err) + + log.Printf("moving data...\n") + ssPBLabel.SetText(" " + "πŸ“‘ Moving data...") + err = importer.MoveStore() + fatalErrorCheck(err) + + log.Printf("cleanup...\n") + err = importer.Cleanup() + fatalErrorCheck(err) + + ssPBLabel.SetText(" " + "βœ… Import completed.") + assistantPageComplete(assistant, wgtNodeType, true) + }) + }() + }) + } + }) + }() + } else { + assistantPageComplete(assistant, wgtNodeType, true) + ssLabel.SetVisible(false) + listBox.SetVisible(false) + ssDLBtn.SetVisible(false) + ssPBLabel.SetVisible(false) + } + }) + case pageSummaryName: iter, err := comboNumValidators.GetActiveIter() fatalErrorCheck(err) @@ -176,13 +322,13 @@ func startupAssistant(workingDir string, chain genesis.ChainType) bool { walletPassword, err := entryPassword.GetText() fatalErrorCheck(err) - validatorAddrs, rewardAddrs, err := cmd.CreateNode(numValidators, chain, workingDir, mnemonic, walletPassword) + validatorAddrs, rewardAddrs, err := cmd.CreateNode(numValidators, chainType, workingDir, mnemonic, walletPassword) fatalErrorCheck(err) // Done! showing the node information successful = true nodeInfo := fmt.Sprintf("Working directory: %s\n", workingDir) - nodeInfo += fmt.Sprintf("Network: %s\n", chain.String()) + nodeInfo += fmt.Sprintf("Network: %s\n", chainType.String()) nodeInfo += "\nValidator addresses:\n" for i, addr := range validatorAddrs { nodeInfo += fmt.Sprintf("%v- %s\n", i+1, addr) @@ -193,7 +339,7 @@ func startupAssistant(workingDir string, chain genesis.ChainType) bool { nodeInfo += fmt.Sprintf("%v- %s\n", i+1, addr) } - setTextViewContent(textViewNodeInfo, nodeInfo) + setTextViewContent(txtNodeInfo, nodeInfo) } prevPageIndex = curPageIndex + prevPageAdjust }) @@ -243,7 +389,7 @@ func pageAssistant() assistantFunc { } } -func pageMode(assistant *gtk.Assistant, assistFunc assistantFunc) (*gtk.Widget, *gtk.RadioButton, string) { +func pageWalletMode(assistant *gtk.Assistant, assistFunc assistantFunc) (*gtk.Widget, *gtk.RadioButton, string) { var mode *gtk.Widget newWalletRadio, err := gtk.RadioButtonNewWithLabel(nil, "Create a new wallet from the scratch") fatalErrorCheck(err) @@ -260,8 +406,8 @@ func pageMode(assistant *gtk.Assistant, assistFunc assistantFunc) (*gtk.Widget, radioBox.Add(restoreWalletRadio) setMargin(restoreWalletRadio, 6, 6, 6, 6) - pageModeName := "page_mode" - pageModeTitle := "Initialize mode" + pageModeName := "page_wallet_mode" + pageModeTitle := "Wallet Mode" pageModeSubject := "How to create your wallet?" pageModeDesc := "If you are running the node for the first time, choose the first option." mode = assistFunc( @@ -287,7 +433,7 @@ func pageSeedGenerate(assistant *gtk.Assistant, assistFunc assistantFunc) (*gtk. textViewSeed.SetSizeRequest(0, 80) pageSeedName := "page_seed_generate" - pageSeedTitle := "Wallet seed" + pageSeedTitle := "Wallet Seed" pageSeedSubject := "Your wallet generation seed is:" pageSeedDesc := `Please write these 12 words on paper. This seed will allow you to recover your wallet in case of computer failure. @@ -319,7 +465,7 @@ func pageSeedRestore(assistant *gtk.Assistant, assistFunc assistantFunc) (*gtk.W textViewRestoreSeed.SetSizeRequest(0, 80) pageSeedName := "page_seed_restore" - pageSeedTitle := "Wallet seed restore" + pageSeedTitle := "Wallet Seed Restore" pageSeedSubject := "Enter your wallet seed:" pageSeedDesc := "Please enter your 12 words mnemonics backup to restore your wallet." @@ -369,7 +515,7 @@ func pageSeedConfirm(assistant *gtk.Assistant, assistFunc assistantFunc, }) pageSeedConfirmName := "page_seed_confirm" - pageSeedConfirmTitle := "Confirm seed" + pageSeedConfirmTitle := "Confirm Seed" pageSeedConfirmSubject := "What was your seed?" pageSeedConfirmDesc := `Your seed is important! To make sure that you have properly saved your seed, please retype it here.` @@ -385,6 +531,59 @@ To make sure that you have properly saved your seed, please retype it here.` return pageWidget, pageSeedConfirmName } +func pageNodeType(assistant *gtk.Assistant, assistFunc assistantFunc) ( + *gtk.Widget, + *gtk.Grid, + *gtk.RadioButton, + string, +) { + var pageWidget *gtk.Widget + + vbox, err := gtk.BoxNew(gtk.ORIENTATION_VERTICAL, 0) + fatalErrorCheck(err) + + grid, err := gtk.GridNew() + fatalErrorCheck(err) + + btnFullNode, err := gtk.RadioButtonNewWithLabel(nil, "Full node") + fatalErrorCheck(err) + btnFullNode.SetActive(true) + + btnPruneNode, err := gtk.RadioButtonNewWithLabelFromWidget(btnFullNode, "Pruned node") + fatalErrorCheck(err) + + radioBox, err := gtk.BoxNew(gtk.ORIENTATION_VERTICAL, 0) + fatalErrorCheck(err) + + radioBox.Add(btnFullNode) + setMargin(btnFullNode, 6, 6, 6, 6) + radioBox.Add(btnPruneNode) + setMargin(btnPruneNode, 6, 10, 6, 6) + + grid.Attach(radioBox, 0, 0, 1, 1) + + vbox.PackStart(grid, true, true, 0) + + pageName := "page_node_type" + pageTitle := "Node Type" + pageSubject := "How do you want to start your node?" + pageDesc := `A pruned node doesn’t keep all the historical data. +Instead, it only retains the most recent part of the blockchain, deleting older data to save disk space. +Offline data is available at: https://snapshot.pactus.org/.` + + // Create and return the page widget using assistFunc + pageWidget = assistFunc( + assistant, + vbox, + pageName, + pageTitle, + pageSubject, + pageDesc, + ) + + return pageWidget, grid, btnPruneNode, pageName +} + func pagePassword(assistant *gtk.Assistant, assistFunc assistantFunc) (*gtk.Widget, *gtk.Entry, string) { pageWidget := new(gtk.Widget) entryPassword, err := gtk.EntryNew() @@ -439,7 +638,7 @@ func pagePassword(assistant *gtk.Assistant, assistFunc assistantFunc) (*gtk.Widg }) pagePasswordName := "page_password" - pagePasswordTitle := "Wallet password" + pagePasswordTitle := "Wallet Password" pagePasswordSubject := "Enter password for your wallet:" pagePsswrdDesc := "Please choose a strong password for your wallet." @@ -492,7 +691,7 @@ func pageNumValidators(assistant *gtk.Assistant, grid.Attach(comboNumValidators, 1, 0, 1, 1) pageNumValidatorsName := "page_num_validators" - pageNumValidatorsTitle := "Number of validators" + pageNumValidatorsTitle := "Number of Validators" pageNumValidatorsSubject := "How many validators do you want to create?" pageNumValidatorsDesc := `Each node can run up to 32 validators, and each validator can hold up to 1000 staked coins. You can define validators based on the amount of coins you want to stake. @@ -509,7 +708,7 @@ For more information, look 0 { + child := children.Data().(*gtk.Widget) + listBox.Remove(child) + children = children.Next() + } + + metadata, err := dm.GetMetadata(ctx) + if err != nil { + mdCh <- nil + + return + } + + for _, md := range metadata { + listBoxRow, err := gtk.ListBoxRowNew() + fatalErrorCheck(err) + + label, err := gtk.LabelNew(fmt.Sprintf("snapshot %s (%s)", + md.CreatedAtTime().Format("2006-01-02"), + util.FormatBytesToHumanReadable(md.Data.Size), + )) + fatalErrorCheck(err) + + listBoxRow.Add(label) + listBox.Add(listBoxRow) + } + listBox.ShowAll() + mdCh <- metadata + }() + + return mdCh +} diff --git a/cmd/gtk/widget_node.go b/cmd/gtk/widget_node.go index e340d3cbd..66b1599ea 100644 --- a/cmd/gtk/widget_node.go +++ b/cmd/gtk/widget_node.go @@ -143,7 +143,7 @@ func (wn *widgetNode) timeout10() bool { fatalErrorCheck(err) wn.labelClockOffset.SetTooltipText("Difference between time of your machine and " + - "network time( (NTP) for synchronization.") + "network time (NTP) for synchronization.") if offsetErr != nil { styleContext.AddClass("warning") diff --git a/cmd/importer.go b/cmd/importer.go new file mode 100644 index 000000000..ce8405709 --- /dev/null +++ b/cmd/importer.go @@ -0,0 +1,239 @@ +package cmd + +import ( + "archive/zip" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "sort" + "time" + + "github.com/pactus-project/pactus/genesis" + "github.com/pactus-project/pactus/util" + "github.com/pactus-project/pactus/util/downloader" +) + +const DefaultSnapshotURL = "https://snapshot.pactus.org" + +const maxDecompressedSize = 10 << 20 // 10 MB + +type DMStateFunc func( + fileName string, + totalSize, downloaded int64, + percentage float64, +) + +type Metadata struct { + Name string `json:"name"` + CreatedAt string `json:"created_at"` + Compress string `json:"compress"` + Data SnapshotData `json:"data"` +} + +type SnapshotData struct { + Name string `json:"name"` + Path string `json:"path"` + Sha string `json:"sha"` + Size uint64 `json:"size"` +} + +func (md *Metadata) CreatedAtTime() time.Time { + const layout = "2006-01-02T15:04:05.000000" + + parsedTime, err := time.Parse(layout, md.CreatedAt) + if err != nil { + return time.Time{} + } + + return parsedTime +} + +// Importer downloads and imports the pruned data from a centralized server. +type Importer struct { + snapshotURL string + tempDir string + storeDir string + dataFileName string +} + +func NewImporter(chainType genesis.ChainType, snapshotURL, storeDir string) (*Importer, error) { + if util.PathExists(storeDir) { + return nil, fmt.Errorf("data directory is not empty: %s", storeDir) + } + + switch chainType { + case genesis.Mainnet: + snapshotURL += "/mainnet/" + case genesis.Testnet: + snapshotURL += "/testnet/" + case genesis.Localnet: + return nil, fmt.Errorf("unsupported chain type: %s", chainType) + } + + tempDir := util.TempDirPath() + + return &Importer{ + snapshotURL: snapshotURL, + tempDir: tempDir, + storeDir: storeDir, + }, nil +} + +func (i *Importer) GetMetadata(ctx context.Context) ([]Metadata, error) { + cli := http.DefaultClient + metaURL, err := url.JoinPath(i.snapshotURL, "metadata.json") + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, metaURL, http.NoBody) + if err != nil { + return nil, err + } + + resp, err := cli.Do(req) + if err != nil { + return nil, err + } + defer func() { + _ = resp.Body.Close() + }() + + if resp.StatusCode != http.StatusOK { + return nil, errors.New(resp.Status) + } + + metadata := make([]Metadata, 0) + + dec := json.NewDecoder(resp.Body) + + if err := dec.Decode(&metadata); err != nil { + return nil, err + } + + sort.SliceStable(metadata, func(i, j int) bool { + return metadata[i].CreatedAtTime().After(metadata[j].CreatedAtTime()) + }) + + return metadata, nil +} + +func (i *Importer) Download( + ctx context.Context, + metadata *Metadata, + stateFunc DMStateFunc, +) { + done := make(chan struct{}) + dlLink, err := url.JoinPath(i.snapshotURL, metadata.Data.Path) + FatalErrorCheck(err) + + fileName := filepath.Base(dlLink) + + i.dataFileName = fileName + + filePath := fmt.Sprintf("%s/%s", i.tempDir, fileName) + + d := downloader.New( + dlLink, + filePath, + metadata.Data.Sha, + ) + + d.Start(ctx) + + go func() { + err := <-d.Errors() + FatalErrorCheck(err) + }() + + go func() { + for state := range d.Stats() { + stateFunc(fileName, state.TotalSize, state.Downloaded, state.Percent) + if state.Completed { + done <- struct{}{} + close(done) + + return + } + } + }() + + <-done +} + +func (i *Importer) Cleanup() error { + return os.RemoveAll(i.tempDir) +} + +func (i *Importer) ExtractAndStoreFiles() error { + zipPath := filepath.Join(i.tempDir, i.dataFileName) + r, err := zip.OpenReader(zipPath) + if err != nil { + return fmt.Errorf("failed to open zip file: %w", err) + } + defer func() { + _ = r.Close() + }() + + for _, f := range r.File { + if err := i.extractAndWriteFile(f); err != nil { + return err + } + } + + return nil +} + +func (i *Importer) extractAndWriteFile(f *zip.File) error { + rc, err := f.Open() + if err != nil { + return fmt.Errorf("failed to open file in zip archive: %w", err) + } + defer func() { + _ = rc.Close() + }() + + fPath, err := util.SanitizeArchivePath(i.tempDir, f.Name) + if err != nil { + return fmt.Errorf("failed to make archive path: %w", err) + } + + if f.FileInfo().IsDir() { + return util.Mkdir(fPath) + } + + if err := util.Mkdir(filepath.Dir(fPath)); err != nil { + return fmt.Errorf("failed to create directory: %w", err) + } + + outFile, err := os.Create(fPath) + if err != nil { + return fmt.Errorf("failed to create file: %w", err) + } + defer func() { + _ = outFile.Close() + }() + + // Use a limited reader to prevent DoS attacks via decompression bomb + lr := &io.LimitedReader{R: rc, N: maxDecompressedSize} + written, err := io.Copy(outFile, lr) + if err != nil { + return fmt.Errorf("failed to copy file contents: %w", err) + } + + if written >= maxDecompressedSize { + return fmt.Errorf("file exceeds maximum decompressed size limit: %s", fPath) + } + + return nil +} + +func (i *Importer) MoveStore() error { + return util.MoveDirectory(filepath.Join(i.tempDir, "data"), i.storeDir) +} diff --git a/scripts/snapshot.py b/scripts/snapshot.py index cfcb70cd0..e84de29be 100644 --- a/scripts/snapshot.py +++ b/scripts/snapshot.py @@ -12,8 +12,8 @@ # - `--data_path`: This argument specifies the path to the Pactus data folder to create snapshots. # - Windows: `C:\Users\{user}\pactus\data` # - Linux or Mac: `/home/{user}/pactus/data` -# - `--compress`: This argument specifies the compression method based on your choice ['zip', 'tar'], -# default is zip. +# - `--compress`: This argument specifies the compression method based on your choice ['none', 'zip', 'tar'], +# with 'none' being without compression. # - `--retention`: This argument sets the number of snapshots to keep. # - `--snapshot_path`: This argument sets a custom path for snapshots, with the default being the current # working directory of the script. @@ -25,6 +25,7 @@ # sudo python3 snapshot.py --service_path /etc/systemd/system/pactus.service --data_path /home/{user}/pactus/data # --compress zip --retention 3 + import argparse import os import shutil @@ -72,7 +73,14 @@ def update_metadata_file(snapshot_path, snapshot_metadata): logging.info(f"Creating new metadata file '{metadata_file}'") metadata = [] - metadata.append(snapshot_metadata) + formatted_metadata = { + "name": snapshot_metadata["name"], + "created_at": snapshot_metadata["created_at"], + "compress": snapshot_metadata["compress"], + "data": snapshot_metadata["data"] + } + + metadata.append(formatted_metadata) with open(metadata_file, 'w') as f: json.dump(metadata, f, indent=4) @@ -92,6 +100,35 @@ def update_metadata_after_removal(snapshots_dir, removed_snapshots): with open(metadata_file, 'w') as f: json.dump(updated_metadata, f, indent=4) + @staticmethod + def create_snapshot_json(data_dir, snapshot_subdir): + files = [] + for root, _, filenames in os.walk(data_dir): + for filename in filenames: + file_path = os.path.join(root, filename) + rel_path = os.path.relpath(file_path, data_dir) + snapshot_rel_path = os.path.join(snapshot_subdir, rel_path).replace('\\', '/') + file_info = { + "name": filename, + "path": snapshot_rel_path, + "sha": Metadata.sha256(file_path) + } + files.append(file_info) + + return {"data": files} + + @staticmethod + def create_compressed_snapshot_json(compressed_file, rel_path): + compressed_file_size = os.path.getsize(compressed_file) + file_info = { + "name": os.path.basename(compressed_file), + "path": rel_path, + "sha": Metadata.sha256(compressed_file), + "size": compressed_file_size, + } + + return {"data": file_info} + def run_command(command): logging.info(f"Running command: {' '.join(command)}") @@ -160,39 +197,34 @@ def create_snapshot(self): logging.info(f"Creating snapshot directory '{snapshot_dir}'") os.makedirs(snapshot_dir, exist_ok=True) - data_dir = self.args.data_path - snapshot_metadata = {"name": timestamp_str, "created_at": get_current_time_iso(), - "compress": self.args.compress, "total_size": 0, "data": []} - - for root, _, files in os.walk(data_dir): - for file in files: - file_path = os.path.join(root, file) - file_name, file_ext = os.path.splitext(file) - compressed_file_name = f"{file_name}{file_ext}.{self.args.compress}" - compressed_file_path = os.path.join(snapshot_dir, compressed_file_name) - rel_path = os.path.relpath(compressed_file_path, self.args.snapshot_path) - - if rel_path.startswith('snapshots' + os.path.sep): - rel_path = rel_path[len('snapshots' + os.path.sep):] - - if self.args.compress == 'zip': - logging.info(f"Creating ZIP archive '{compressed_file_path}'") - with zipfile.ZipFile(compressed_file_path, 'w', zipfile.ZIP_DEFLATED) as zipf: - zipf.write(file_path, file) - elif self.args.compress == 'tar': - logging.info(f"Creating TAR archive '{compressed_file_path}'") - subprocess.run(['tar', '-cvf', compressed_file_path, '-C', os.path.dirname(file_path), file]) - - compressed_file_size = os.path.getsize(compressed_file_path) - snapshot_metadata["total_size"] += compressed_file_size - - file_info = { - "name": file_name, - "path": rel_path, - "sha": Metadata.sha256(compressed_file_path), - "size": compressed_file_size - } - snapshot_metadata["data"].append(file_info) + data_dir = os.path.join(snapshot_dir, 'data') + if self.args.compress == 'none': + logging.info(f"Copying data from '{self.args.data_path}' to '{data_dir}'") + shutil.copytree(self.args.data_path, data_dir) + snapshot_metadata = Metadata.create_snapshot_json(data_dir, timestamp_str) + elif self.args.compress == 'zip': + zip_file = os.path.join(snapshot_dir, 'data.zip') + rel = os.path.relpath(zip_file, snapshot_dir) + meta_path = os.path.join(timestamp_str, rel) + logging.info(f"Creating ZIP archive '{zip_file}'") + with zipfile.ZipFile(zip_file, 'w', zipfile.ZIP_DEFLATED) as zipf: + for root, _, files in os.walk(self.args.data_path): + for file in files: + full_path = os.path.join(root, file) + rel_path = os.path.relpath(full_path, self.args.data_path) + zipf.write(full_path, os.path.join('data', rel_path)) + snapshot_metadata = Metadata.create_compressed_snapshot_json(zip_file, meta_path) + elif self.args.compress == 'tar': + tar_file = os.path.join(snapshot_dir, 'data.tar.gz') + rel = os.path.relpath(tar_file, snapshot_dir) + meta_path = os.path.join(timestamp_str, rel) + logging.info(f"Creating TAR.GZ archive '{tar_file}'") + subprocess.run(['tar', '-czvf', tar_file, '-C', self.args.data_path, '.']) + snapshot_metadata = Metadata.create_compressed_snapshot_json(tar_file, meta_path) + + snapshot_metadata["name"] = timestamp_str + snapshot_metadata["created_at"] = get_current_time_iso() + snapshot_metadata["compress"] = self.args.compress Metadata.update_metadata_file(self.args.snapshot_path, snapshot_metadata) @@ -268,7 +300,7 @@ def parse_args(): parser = argparse.ArgumentParser(description='Pactus Blockchain Snapshot Tool') parser.add_argument('--service_path', required=True, help='Path to pactus systemctl service') parser.add_argument('--data_path', default=default_data_path, help='Path to data directory') - parser.add_argument('--compress', choices=['zip', 'tar'], default='zip', help='Compression type') + parser.add_argument('--compress', choices=['none', 'zip', 'tar'], default='none', help='Compression type') parser.add_argument('--retention', type=int, default=3, help='Number of snapshots to retain') parser.add_argument('--snapshot_path', default=os.getcwd(), help='Path to store snapshots') diff --git a/util/io.go b/util/io.go index c32ce1975..9b257d997 100644 --- a/util/io.go +++ b/util/io.go @@ -7,6 +7,7 @@ import ( "io" "os" "path/filepath" + "strings" ) func IsAbsPath(p string) bool { @@ -72,6 +73,7 @@ func TempFilePath() string { return filepath.Join(TempDirPath(), "file") } +// IsDirEmpty checks if a directory is empty. func IsDirEmpty(name string) bool { f, err := os.Open(name) if err != nil { @@ -88,6 +90,8 @@ func IsDirEmpty(name string) bool { return errors.Is(err, io.EOF) } +// IsDirNotExistsOrEmpty returns true if a directory does not exist or is empty. +// It checks if the path exists and, if so, whether the directory is empty. func IsDirNotExistsOrEmpty(name string) bool { if !PathExists(name) { return true @@ -196,3 +200,29 @@ func NewFixedReader(max int, buf []byte) *FixedReader { return &fr } + +// MoveDirectory moves a directory from srcDir to dstDir, including all its contents. +// If dstDir already exists and is not empty, it returns an error. +func MoveDirectory(srcDir, dstDir string) error { + if !IsDirNotExistsOrEmpty(dstDir) { + return fmt.Errorf("destination directory %s already exists", dstDir) + } + + if err := os.Rename(srcDir, dstDir); err != nil { + return fmt.Errorf("failed to move directory from %s to %s: %w", srcDir, dstDir, err) + } + + return nil +} + +// SanitizeArchivePath mitigates the "Zip Slip" vulnerability by sanitizing archive file paths. +// It ensures that the file path is contained within the specified base directory to prevent directory +// traversal attacks. For more details on the vulnerability, see https://snyk.io/research/zip-slip-vulnerability. +func SanitizeArchivePath(baseDir, archivePath string) (fullPath string, err error) { + fullPath = filepath.Join(baseDir, archivePath) + if strings.HasPrefix(fullPath, filepath.Clean(baseDir)) { + return fullPath, nil + } + + return "", fmt.Errorf("%s: %s", "content filepath is tainted", archivePath) +} diff --git a/util/io_test.go b/util/io_test.go index fec6ea052..aec40cd69 100644 --- a/util/io_test.go +++ b/util/io_test.go @@ -2,7 +2,9 @@ package util import ( "fmt" + "os" "os/exec" + "path/filepath" "runtime" "strconv" "testing" @@ -92,3 +94,83 @@ func TestIsValidPath(t *testing.T) { assert.True(t, IsValidDirPath("/tmp")) assert.True(t, IsValidDirPath("/tmp/pactus")) } + +func TestMoveDirectory(t *testing.T) { + // Create temporary directories + srcDir := TempDirPath() + dstDir := TempDirPath() + defer func() { _ = os.RemoveAll(srcDir) }() + defer func() { _ = os.RemoveAll(dstDir) }() + + // Create a subdirectory in the source directory + subDir := filepath.Join(srcDir, "subdir") + err := Mkdir(subDir) + assert.NoError(t, err) + + // Create multiple files in the subdirectory + files := []struct { + name string + content string + }{ + {"file1.txt", "content 1"}, + {"file2.txt", "content 2"}, + } + + for _, file := range files { + filePath := filepath.Join(subDir, file.name) + err = WriteFile(filePath, []byte(file.content)) + assert.NoError(t, err) + } + + // Move the directory + dstDirPath := filepath.Join(dstDir, "movedir") + err = MoveDirectory(srcDir, dstDirPath) + assert.NoError(t, err) + + // Assert the source directory no longer exists + assert.False(t, PathExists(srcDir)) + + // Assert the destination directory exists + assert.True(t, PathExists(dstDirPath)) + + // Verify that all files have been moved and their contents are correct + for _, file := range files { + movedFilePath := filepath.Join(dstDirPath, "subdir", file.name) + data, err := ReadFile(movedFilePath) + assert.NoError(t, err) + assert.Equal(t, file.content, string(data)) + } +} + +func TestSanitizeArchivePath(t *testing.T) { + if runtime.GOOS == "windows" { + return + } + + baseDir := "/safe/directory" + + tests := []struct { + name string + inputPath string + expected string + expectErr bool + }{ + {"Valid path", "file.txt", "/safe/directory/file.txt", false}, + {"Valid path in subdirectory", "subdir/file.txt", "/safe/directory/subdir/file.txt", false}, + {"Path with parent directory traversal", "../outside/file.txt", "", true}, + {"Absolute path outside base directory", "/etc/passwd", "/safe/directory/etc/passwd", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := SanitizeArchivePath(baseDir, tt.inputPath) + if tt.expectErr { + assert.Error(t, err, "Expected error but got none") + assert.Empty(t, result, "Expected empty result due to error") + } else { + assert.NoError(t, err, "Unexpected error occurred") + assert.Equal(t, tt.expected, result, "Sanitized path did not match expected") + } + }) + } +} diff --git a/util/utils.go b/util/utils.go index c84c5bc88..94ec20bd1 100644 --- a/util/utils.go +++ b/util/utils.go @@ -109,8 +109,8 @@ func IsFlagSet[T constraints.Integer](flags, mask T) bool { // OS2IP converts an octet string to a nonnegative integer. // OS2IP: https://datatracker.ietf.org/doc/html/rfc8017#section-4.2 -func OS2IP(os []byte) *big.Int { - return new(big.Int).SetBytes(os) +func OS2IP(x []byte) *big.Int { + return new(big.Int).SetBytes(x) } // I2OSP converts a nonnegative integer to an octet string of a specified length.