From dd1f999a4be949602c4d4e57e51855fa4393fd29 Mon Sep 17 00:00:00 2001 From: "Sean E. Russell" Date: Wed, 10 Aug 2022 15:32:26 -0500 Subject: [PATCH 1/5] Minimally addresses #71: don't swallow errors. Also return the right exit error code when errors are encountered. --- cmd/root.go | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index 528d2ec..8020746 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -3,6 +3,7 @@ package main import ( "errors" "fmt" + "os" "strings" "time" @@ -15,9 +16,10 @@ import ( var ( // Flags. - owner string - description string - confirm bool + owner string + description string + confirm bool + error_encountered bool // Commands. rootCmd = &cobra.Command{} @@ -39,8 +41,14 @@ var ( Run: func(cmd *cobra.Command, args []string) { config := cli.MustLoadConfigFile() server := cli.GetServer(config) - server.Up() - utils.ShellOut(config.PostUp, "PostUp") + if e := server.Up(); e != nil { + fmt.Printf("error bringing up the network: %s\n", e) + error_encountered = true + } + if e := utils.ShellOut(config.PostUp, "PostUp"); e != nil { + fmt.Printf("error bringing up the network: %s\n", e) + error_encountered = true + } }, } @@ -50,8 +58,14 @@ var ( Run: func(cmd *cobra.Command, args []string) { config := cli.MustLoadConfigFile() server := cli.GetServer(config) - server.DeleteLink() - utils.ShellOut(config.PostDown, "PostDown") + if e := server.DeleteLink(); e != nil { + fmt.Printf("error bringing up the network: %s\n", e) + error_encountered = true + } + if e := utils.ShellOut(config.PostDown, "PostDown"); e != nil { + fmt.Printf("error bringing up the network: %s\n", e) + error_encountered = true + } }, } @@ -170,4 +184,8 @@ func main() { if err := rootCmd.Execute(); err != nil { cli.ExitFail(err.Error()) } + if error_encountered { + os.Exit(1) + } + os.Exit(0) } From da76ddbbdabd2035846d73fb4d66fd8ad1263dc3 Mon Sep 17 00:00:00 2001 From: "Sean E. Russell" Date: Thu, 11 Aug 2022 04:44:16 -0500 Subject: [PATCH 2/5] Return errors immediately, rather than cache-and-return-later --- cmd/root.go | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index 8020746..c172907 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -16,10 +16,9 @@ import ( var ( // Flags. - owner string - description string - confirm bool - error_encountered bool + owner string + description string + confirm bool // Commands. rootCmd = &cobra.Command{} @@ -38,34 +37,32 @@ var ( upCmd = &cobra.Command{ Use: "up", Short: "Create the interface, run pre/post up, sync", - Run: func(cmd *cobra.Command, args []string) { + RunE: func(cmd *cobra.Command, args []string) error { config := cli.MustLoadConfigFile() server := cli.GetServer(config) if e := server.Up(); e != nil { - fmt.Printf("error bringing up the network: %s\n", e) - error_encountered = true + return e } if e := utils.ShellOut(config.PostUp, "PostUp"); e != nil { - fmt.Printf("error bringing up the network: %s\n", e) - error_encountered = true + return e } + return nil }, } downCmd = &cobra.Command{ Use: "down", Short: "Destroy the interface, run pre/post down", - Run: func(cmd *cobra.Command, args []string) { + RunE: func(cmd *cobra.Command, args []string) error { config := cli.MustLoadConfigFile() server := cli.GetServer(config) if e := server.DeleteLink(); e != nil { - fmt.Printf("error bringing up the network: %s\n", e) - error_encountered = true + return e } if e := utils.ShellOut(config.PostDown, "PostDown"); e != nil { - fmt.Printf("error bringing up the network: %s\n", e) - error_encountered = true + return e } + return nil }, } From 9bf5693cd9a98fb50ad643b91221a1b84ae7cb9f Mon Sep 17 00:00:00 2001 From: "Sean E. Russell" Date: Thu, 11 Aug 2022 04:56:33 -0500 Subject: [PATCH 3/5] Cascade errors back up to main(), rather than exiting deep in the stack. --- cmd/cli/add.go | 19 ++++++++++++++----- cmd/cli/regenerate.go | 21 +++++++++++++++------ cmd/cli/remove.go | 11 ++++++++--- cmd/cli/report.go | 6 ++++-- cmd/cli/sync.go | 11 ++++++++--- cmd/cli/util.go | 4 ++++ cmd/root.go | 23 ++++++++++------------- 7 files changed, 63 insertions(+), 32 deletions(-) diff --git a/cmd/cli/add.go b/cmd/cli/add.go index d922844..9006bd3 100644 --- a/cmd/cli/add.go +++ b/cmd/cli/add.go @@ -9,10 +9,12 @@ import ( ) // Add prompts for the required information and creates a new peer -func Add(hostname, owner, description string, confirm bool) { +func Add(hostname, owner, description string, confirm bool) error { // TODO accept existing pubkey config, err := LoadConfigFile() - check(err, "failed to load configuration file") + if err != nil { + return wrapError(err, "failed to load configuration file") + } server := GetServer(config) if owner == "" { @@ -31,7 +33,9 @@ func Add(hostname, owner, description string, confirm bool) { fmt.Fprintln(os.Stderr) peer, err := lib.NewPeer(server, owner, hostname, description) - check(err, "failed to get new peer") + if err != nil { + return wrapError(err, "failed to get new peer") + } // TODO Some kind of recovery here would be nice, to avoid // leaving things in a potential broken state @@ -41,12 +45,17 @@ func Add(hostname, owner, description string, confirm bool) { peerType := viper.GetString("output") peerConfigBytes, err := lib.AsciiPeerConfig(peer, peerType, *server) - check(err, "failed to get peer configuration") + if err != nil { + return wrapError(err, "failed to get peer configuration") + } os.Stdout.Write(peerConfigBytes.Bytes()) config.MustSave() server = GetServer(config) err = server.ConfigureDevice() - check(err, "failed to configure device") + if err != nil { + return wrapError(err, "failed to configure device") + } + return nil } diff --git a/cmd/cli/regenerate.go b/cmd/cli/regenerate.go index b1c6676..27fb522 100644 --- a/cmd/cli/regenerate.go +++ b/cmd/cli/regenerate.go @@ -8,7 +8,7 @@ import ( "github.com/spf13/viper" ) -func Regenerate(hostname string, confirm bool) { +func Regenerate(hostname string, confirm bool) error { config := MustLoadConfigFile() server := GetServer(config) @@ -21,22 +21,30 @@ func Regenerate(hostname string, confirm bool) { for _, peer := range server.Peers { if peer.Hostname == hostname { privateKey, err := lib.GenerateJSONPrivateKey() - check(err, "failed to generate private key") + if err != nil { + return wrapError(err, "failed to generate private key") + } preshareKey, err := lib.GenerateJSONKey() - check(err, "failed to generate preshared key") + if err != nil { + return wrapError(err, "failed to generate preshared key") + } peer.PrivateKey = privateKey peer.PublicKey = privateKey.PublicKey() peer.PresharedKey = preshareKey err = config.RemovePeer(hostname) - check(err, "failed to regenerate peer") + if err != nil { + return wrapError(err, "failed to regenerate peer") + } peerType := viper.GetString("output") peerConfigBytes, err := lib.AsciiPeerConfig(peer, peerType, *server) - check(err, "failed to get peer configuration") + if err != nil { + return wrapError(err, "failed to get peer configuration") + } os.Stdout.Write(peerConfigBytes.Bytes()) found = true config.MustAddPeer(peer) @@ -46,11 +54,12 @@ func Regenerate(hostname string, confirm bool) { } if !found { - ExitFail(fmt.Sprintf("unknown hostname: %s", hostname)) + return fmt.Errorf("unknown hostname: %s", hostname) } // Get a new server configuration so we can update the wg interface with the new peer details server = GetServer(config) config.MustSave() server.ConfigureDevice() + return nil } diff --git a/cmd/cli/remove.go b/cmd/cli/remove.go index 8eba3e9..efdff0b 100644 --- a/cmd/cli/remove.go +++ b/cmd/cli/remove.go @@ -2,11 +2,13 @@ package cli import "fmt" -func Remove(hostname string, confirm bool) { +func Remove(hostname string, confirm bool) error { conf := MustLoadConfigFile() err := conf.RemovePeer(hostname) - check(err, "failed to update config") + if err != nil { + return wrapError(err, "failed to update config") + } if !confirm { ConfirmOrAbort("Do you really want to remove %s?", hostname) @@ -16,5 +18,8 @@ func Remove(hostname string, confirm bool) { server := GetServer(conf) err = server.ConfigureDevice() - check(err, fmt.Sprintf("failed to sync server config to wg interface: %s", server.InterfaceName)) + if err != nil { + return wrapError(err, fmt.Sprintf("failed to sync server config to wg interface: %s", server.InterfaceName)) + } + return nil } diff --git a/cmd/cli/report.go b/cmd/cli/report.go index f6cad60..6a185ea 100644 --- a/cmd/cli/report.go +++ b/cmd/cli/report.go @@ -2,6 +2,7 @@ package cli import ( "encoding/json" + "fmt" "io/ioutil" "net" "os" @@ -69,7 +70,7 @@ type PeerReport struct { TransmitBytesSI string } -func GenerateReport() { +func GenerateReport() error { conf := MustLoadConfigFile() wg, err := wgctrl.New() @@ -79,12 +80,13 @@ func GenerateReport() { dev, err := wg.Device(conf.InterfaceName) if err != nil { - ExitFail("Could not retrieve device '%s' (%v)", conf.InterfaceName, err) + return wrapError(err, fmt.Sprintf("Could not retrieve device '%s'", conf.InterfaceName)) } oldReport := MustLoadDsnetReport() report := GetReport(dev, conf, oldReport) report.MustSave() + return nil } func GetReport(dev *wgtypes.Device, conf *DsnetConfig, oldReport *DsnetReport) DsnetReport { diff --git a/cmd/cli/sync.go b/cmd/cli/sync.go index 50c5bc9..f523d87 100644 --- a/cmd/cli/sync.go +++ b/cmd/cli/sync.go @@ -1,10 +1,15 @@ package cli -func Sync() { +func Sync() error { // TODO check device settings first conf, err := LoadConfigFile() - check(err, "failed to load configuration file") + if err != nil { + return wrapError(err, "failed to load configuration file") + } server := GetServer(conf) err = server.ConfigureDevice() - check(err, "failed to sync device configuration") + if err != nil { + return wrapError(err, "failed to sync device configuration") + } + return nil } diff --git a/cmd/cli/util.go b/cmd/cli/util.go index 461f3bb..c9e1c93 100644 --- a/cmd/cli/util.go +++ b/cmd/cli/util.go @@ -42,6 +42,10 @@ func ExitFail(format string, a ...interface{}) { os.Exit(1) } +func wrapError(err error, s string) error { + return fmt.Errorf("\033[31m%s - %s\033[0m\n", err, s) +} + func MustPromptString(prompt string, required bool) string { reader := bufio.NewReader(os.Stdin) var text string diff --git a/cmd/root.go b/cmd/root.go index c172907..cde5137 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -76,8 +76,8 @@ var ( } return nil }, - Run: func(cmd *cobra.Command, args []string) { - cli.Add(args[0], owner, description, confirm) + RunE: func(cmd *cobra.Command, args []string) error { + return cli.Add(args[0], owner, description, confirm) }, } @@ -90,24 +90,24 @@ var ( } return nil }, - Run: func(cmd *cobra.Command, args []string) { - cli.Regenerate(args[0], confirm) + RunE: func(cmd *cobra.Command, args []string) error { + return cli.Regenerate(args[0], confirm) }, } syncCmd = &cobra.Command{ Use: "sync", Short: fmt.Sprintf("Update wireguard configuration from %s after validating", viper.GetString("config_file")), - Run: func(cmd *cobra.Command, args []string) { - cli.Sync() + RunE: func(cmd *cobra.Command, args []string) error { + return cli.Sync() }, } reportCmd = &cobra.Command{ Use: "report", Short: fmt.Sprintf("Generate a JSON status report to the location configured in %s.", viper.GetString("config_file")), - Run: func(cmd *cobra.Command, args []string) { - cli.GenerateReport() + RunE: func(cmd *cobra.Command, args []string) error { + return cli.GenerateReport() }, } @@ -122,8 +122,8 @@ var ( return nil }, - Run: func(cmd *cobra.Command, args []string) { - cli.Remove(args[0], confirm) + RunE: func(cmd *cobra.Command, args []string) error { + return cli.Remove(args[0], confirm) }, } @@ -181,8 +181,5 @@ func main() { if err := rootCmd.Execute(); err != nil { cli.ExitFail(err.Error()) } - if error_encountered { - os.Exit(1) - } os.Exit(0) } From 94a36f68c919da632902ce975398d8368df742af Mon Sep 17 00:00:00 2001 From: "Sean E. Russell" Date: Thu, 11 Aug 2022 05:15:26 -0500 Subject: [PATCH 4/5] Eggbeater missed one. --- cmd/cli/init.go | 50 ++++++++++++++++++++++++++++++++++--------------- cmd/root.go | 4 ++-- 2 files changed, 37 insertions(+), 17 deletions(-) diff --git a/cmd/cli/init.go b/cmd/cli/init.go index 14448aa..71ef6ef 100644 --- a/cmd/cli/init.go +++ b/cmd/cli/init.go @@ -15,7 +15,7 @@ import ( "github.com/spf13/viper" ) -func Init() { +func Init() error { reportFile := viper.GetString("report_file") listenPort := viper.GetInt("listen_port") configFile := viper.GetString("config_file") @@ -24,14 +24,23 @@ func Init() { _, err := os.Stat(configFile) if !os.IsNotExist(err) { - ExitFail("Refusing to overwrite existing %s", configFile) + return wrapError(err, fmt.Sprintf("Refusing to overwrite existing %s", configFile)) } privateKey, err := lib.GenerateJSONPrivateKey() - check(err, "failed to generate private key") + if err != nil { + return wrapError(err, "failed to generate private key") + } externalIPV4, err := getExternalIP() - check(err) + if err != nil { + return err + } + + externalIPV6, err := getExternalIP6() + if err != nil { + return err + } conf := &DsnetConfig{ PrivateKey: privateKey, @@ -42,7 +51,7 @@ func Init() { Domain: "dsnet", ReportFile: reportFile, ExternalIP: externalIPV4, - ExternalIP6: getExternalIP6(), + ExternalIP6: externalIPV6, InterfaceName: interfaceName, Networks: []lib.JSONIPNet{}, } @@ -50,21 +59,26 @@ func Init() { server := GetServer(conf) ipv4, err := server.AllocateIP() - check(err, "failed to allocate ipv4 address") + if err != nil { + return wrapError(err, "failed to allocate ipv4 address") + } ipv6, err := server.AllocateIP6() - check(err, "failed to allocate ipv6 address") + if err != nil { + return wrapError(err, "failed to allocate ipv6 address") + } conf.IP = ipv4 conf.IP6 = ipv6 if len(conf.ExternalIP) == 0 && len(conf.ExternalIP6) == 0 { - ExitFail("Could not determine any external IP, v4 or v6") + return fmt.Errorf("Could not determine any external IP, v4 or v6") } conf.MustSave() fmt.Printf("Config written to %s. Please check/edit.\n", configFile) + return nil } // get a random IPv4 /22 subnet on 10.0.0.0 (1023 hosts) (or /24?) @@ -120,12 +134,16 @@ func getExternalIP() (net.IP, error) { Timeout: 5 * time.Second, } resp, err := client.Get("https://ipv4.icanhazip.com/") - check(err) + if err != nil { + return nil, err + } defer resp.Body.Close() if resp.StatusCode == http.StatusOK { body, err := ioutil.ReadAll(resp.Body) - check(err) + if err != nil { + return nil, err + } IP = net.ParseIP(strings.TrimSpace(string(body))) return IP.To4(), nil } @@ -133,7 +151,7 @@ func getExternalIP() (net.IP, error) { return nil, errors.New("failed to determine external ip") } -func getExternalIP6() net.IP { +func getExternalIP6() (net.IP, error) { var IP net.IP conn, err := net.Dial("udp", "2001:4860:4860::8888:53") if err == nil { @@ -144,7 +162,7 @@ func getExternalIP6() net.IP { // check is not a ULA if IP[0] != 0xfd && IP[0] != 0xfc { - return IP + return IP, nil } } @@ -157,11 +175,13 @@ func getExternalIP6() net.IP { if resp.StatusCode == http.StatusOK { body, err := ioutil.ReadAll(resp.Body) - check(err) + if err != nil { + return nil, err + } IP = net.ParseIP(strings.TrimSpace(string(body))) - return IP + return IP, nil } } - return net.IP{} + return net.IP{}, nil } diff --git a/cmd/root.go b/cmd/root.go index cde5137..720700b 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -29,8 +29,8 @@ var ( "Create %s containing default configuration + new keys without loading. Edit to taste.", viper.GetString("config_file"), ), - Run: func(cmd *cobra.Command, args []string) { - cli.Init() + RunE: func(cmd *cobra.Command, args []string) error { + return cli.Init() }, } From ae67bf4311d3c5ae870c5a2d703655555eafd758 Mon Sep 17 00:00:00 2001 From: "Sean E. Russell" Date: Thu, 11 Aug 2022 06:22:57 -0500 Subject: [PATCH 5/5] Expands error cascading. This eliminates deeply nested terminations, allowing for the opportunity of better state clean-up. --- cmd/cli/add.go | 29 ++++++++++++++++--------- cmd/cli/config.go | 24 --------------------- cmd/cli/init.go | 12 ++++++----- cmd/cli/regenerate.go | 21 ++++++++++++------ cmd/cli/remove.go | 19 +++++++++------- cmd/cli/report.go | 50 +++++++++++++++++++++++++++---------------- cmd/cli/sync.go | 6 ++++-- cmd/cli/types.go | 17 +++++++++------ cmd/cli/util.go | 32 +++++++++------------------ cmd/root.go | 17 +++++++++++---- 10 files changed, 119 insertions(+), 108 deletions(-) diff --git a/cmd/cli/add.go b/cmd/cli/add.go index 9006bd3..dabcd68 100644 --- a/cmd/cli/add.go +++ b/cmd/cli/add.go @@ -13,15 +13,21 @@ func Add(hostname, owner, description string, confirm bool) error { // TODO accept existing pubkey config, err := LoadConfigFile() if err != nil { - return wrapError(err, "failed to load configuration file") + return fmt.Errorf("%w - failed to load configuration file", err) } server := GetServer(config) if owner == "" { - owner = MustPromptString("owner", true) + owner, err = PromptString("owner", true) + if err != nil { + return fmt.Errorf("%w - invalid input for owner", err) + } } if description == "" { - description = MustPromptString("Description", true) + description, err = PromptString("Description", true) + if err != nil { + return fmt.Errorf("%w - invalid input for Description", err) + } } // publicKey := MustPromptString("PublicKey (optional)", false) @@ -34,28 +40,31 @@ func Add(hostname, owner, description string, confirm bool) error { peer, err := lib.NewPeer(server, owner, hostname, description) if err != nil { - return wrapError(err, "failed to get new peer") + return fmt.Errorf("%w - failed to get new peer", err) } // TODO Some kind of recovery here would be nice, to avoid // leaving things in a potential broken state - config.MustAddPeer(peer) + if err = config.AddPeer(peer); err != nil { + return fmt.Errorf("%w - failed to add new peer", err) + } peerType := viper.GetString("output") peerConfigBytes, err := lib.AsciiPeerConfig(peer, peerType, *server) if err != nil { - return wrapError(err, "failed to get peer configuration") + return fmt.Errorf("%w - failed to get peer configuration", err) } os.Stdout.Write(peerConfigBytes.Bytes()) - config.MustSave() + if err = config.Save(); err != nil { + return fmt.Errorf("%w - failed to save config file", err) + } server = GetServer(config) - err = server.ConfigureDevice() - if err != nil { - return wrapError(err, "failed to configure device") + if err = server.ConfigureDevice(); err != nil { + return fmt.Errorf("%w - failed to configure device", err) } return nil } diff --git a/cmd/cli/config.go b/cmd/cli/config.go index ddace1f..e9d9570 100644 --- a/cmd/cli/config.go +++ b/cmd/cli/config.go @@ -99,12 +99,6 @@ func LoadConfigFile() (*DsnetConfig, error) { return &conf, nil } -func MustLoadConfigFile() *DsnetConfig { - config, err := LoadConfigFile() - check(err, "failed to load configuration file") - return config -} - // Save writes the configuration to disk func (conf *DsnetConfig) Save() error { configFile := viper.GetString("config_file") @@ -117,12 +111,6 @@ func (conf *DsnetConfig) Save() error { return nil } -// MustSave is like Save except it exits on error -func (conf *DsnetConfig) MustSave() { - err := conf.Save() - check(err, "failed to save config file") -} - // AddPeer adds a provided peer to the Peers list in the conf func (conf *DsnetConfig) AddPeer(peer lib.Peer) error { // TODO validate all PeerConfig (keys etc) @@ -162,12 +150,6 @@ func (conf *DsnetConfig) AddPeer(peer lib.Peer) error { return nil } -// MustAddPeer is like AddPeer, except it exist on error -func (conf *DsnetConfig) MustAddPeer(peer lib.Peer) { - err := conf.AddPeer(peer) - check(err) -} - // RemovePeer removes a peer from the peer list based on hostname func (conf *DsnetConfig) RemovePeer(hostname string) error { peerIndex := -1 @@ -188,12 +170,6 @@ func (conf *DsnetConfig) RemovePeer(hostname string) error { return nil } -// MustRemovePeer is like RemovePeer, except it exits on error -func (conf *DsnetConfig) MustRemovePeer(hostname string) { - err := conf.RemovePeer(hostname) - check(err) -} - func (conf DsnetConfig) GetWgPeerConfigs() []wgtypes.PeerConfig { wgPeers := make([]wgtypes.PeerConfig, 0, len(conf.Peers)) diff --git a/cmd/cli/init.go b/cmd/cli/init.go index 71ef6ef..a26a3f3 100644 --- a/cmd/cli/init.go +++ b/cmd/cli/init.go @@ -24,12 +24,12 @@ func Init() error { _, err := os.Stat(configFile) if !os.IsNotExist(err) { - return wrapError(err, fmt.Sprintf("Refusing to overwrite existing %s", configFile)) + return fmt.Errorf("%w - Refusing to overwrite existing %s", err, configFile) } privateKey, err := lib.GenerateJSONPrivateKey() if err != nil { - return wrapError(err, "failed to generate private key") + return fmt.Errorf("%w - failed to generate private key", err) } externalIPV4, err := getExternalIP() @@ -60,12 +60,12 @@ func Init() error { ipv4, err := server.AllocateIP() if err != nil { - return wrapError(err, "failed to allocate ipv4 address") + return fmt.Errorf("%w - failed to allocate ipv4 address", err) } ipv6, err := server.AllocateIP6() if err != nil { - return wrapError(err, "failed to allocate ipv6 address") + return fmt.Errorf("%w - failed to allocate ipv6 address", err) } conf.IP = ipv4 @@ -75,7 +75,9 @@ func Init() error { return fmt.Errorf("Could not determine any external IP, v4 or v6") } - conf.MustSave() + if err := conf.Save(); err != nil { + return fmt.Errorf("%w - failed to save config file", err) + } fmt.Printf("Config written to %s. Please check/edit.\n", configFile) return nil diff --git a/cmd/cli/regenerate.go b/cmd/cli/regenerate.go index 27fb522..8da8e17 100644 --- a/cmd/cli/regenerate.go +++ b/cmd/cli/regenerate.go @@ -9,7 +9,10 @@ import ( ) func Regenerate(hostname string, confirm bool) error { - config := MustLoadConfigFile() + config, err := LoadConfigFile() + if err != nil { + return fmt.Errorf("%w - failure to load config file", err) + } server := GetServer(config) found := false @@ -22,12 +25,12 @@ func Regenerate(hostname string, confirm bool) error { if peer.Hostname == hostname { privateKey, err := lib.GenerateJSONPrivateKey() if err != nil { - return wrapError(err, "failed to generate private key") + return fmt.Errorf("%w - failed to generate private key", err) } preshareKey, err := lib.GenerateJSONKey() if err != nil { - return wrapError(err, "failed to generate preshared key") + return fmt.Errorf("%w - failed to generate preshared key", err) } peer.PrivateKey = privateKey @@ -36,18 +39,20 @@ func Regenerate(hostname string, confirm bool) error { err = config.RemovePeer(hostname) if err != nil { - return wrapError(err, "failed to regenerate peer") + return fmt.Errorf("%w - failed to regenerate peer", err) } peerType := viper.GetString("output") peerConfigBytes, err := lib.AsciiPeerConfig(peer, peerType, *server) if err != nil { - return wrapError(err, "failed to get peer configuration") + return fmt.Errorf("%w - failed to get peer configuration", err) } os.Stdout.Write(peerConfigBytes.Bytes()) found = true - config.MustAddPeer(peer) + if err = config.AddPeer(peer); err != nil { + return fmt.Errorf("%w - failure to add peer", err) + } break } @@ -59,7 +64,9 @@ func Regenerate(hostname string, confirm bool) error { // Get a new server configuration so we can update the wg interface with the new peer details server = GetServer(config) - config.MustSave() + if err = config.Save(); err != nil { + return fmt.Errorf("%w - failure saving config", err) + } server.ConfigureDevice() return nil } diff --git a/cmd/cli/remove.go b/cmd/cli/remove.go index efdff0b..d986b9f 100644 --- a/cmd/cli/remove.go +++ b/cmd/cli/remove.go @@ -3,23 +3,26 @@ package cli import "fmt" func Remove(hostname string, confirm bool) error { - conf := MustLoadConfigFile() - - err := conf.RemovePeer(hostname) + conf, err := LoadConfigFile() if err != nil { - return wrapError(err, "failed to update config") + return fmt.Errorf("%w - failed to load config", err) + } + + if err = conf.RemovePeer(hostname); err != nil { + return fmt.Errorf("%w - failed to update config", err) } if !confirm { ConfirmOrAbort("Do you really want to remove %s?", hostname) } - conf.MustSave() + if err = conf.Save(); err != nil { + return fmt.Errorf("%w - failure to save config", err) + } server := GetServer(conf) - err = server.ConfigureDevice() - if err != nil { - return wrapError(err, fmt.Sprintf("failed to sync server config to wg interface: %s", server.InterfaceName)) + if err = server.ConfigureDevice(); err != nil { + return fmt.Errorf("%w - failed to sync server config to wg interface: %s", err, server.InterfaceName) } return nil } diff --git a/cmd/cli/report.go b/cmd/cli/report.go index 6a185ea..1023f93 100644 --- a/cmd/cli/report.go +++ b/cmd/cli/report.go @@ -71,25 +71,35 @@ type PeerReport struct { } func GenerateReport() error { - conf := MustLoadConfigFile() + conf, err := LoadConfigFile() + if err != nil { + return fmt.Errorf("%w - failure to load config", err) + } wg, err := wgctrl.New() - check(err) + if err != nil { + return fmt.Errorf("%w - failure to create new client", err) + } defer wg.Close() dev, err := wg.Device(conf.InterfaceName) if err != nil { - return wrapError(err, fmt.Sprintf("Could not retrieve device '%s'", conf.InterfaceName)) + return fmt.Errorf("%w - Could not retrieve device '%s'", err, conf.InterfaceName) } - oldReport := MustLoadDsnetReport() - report := GetReport(dev, conf, oldReport) - report.MustSave() - return nil + oldReport, err := LoadDsnetReport() + if err != nil { + return err + } + report, err := GetReport(dev, conf, oldReport) + if err != nil { + return err + } + return report.Save() } -func GetReport(dev *wgtypes.Device, conf *DsnetConfig, oldReport *DsnetReport) DsnetReport { +func GetReport(dev *wgtypes.Device, conf *DsnetConfig, oldReport *DsnetReport) (DsnetReport, error) { peerTimeout := viper.GetDuration("peer_timeout") peerExpiry := viper.GetDuration("peer_expiry") wgPeerIndex := make(map[wgtypes.Key]wgtypes.Peer) @@ -98,7 +108,9 @@ func GetReport(dev *wgtypes.Device, conf *DsnetConfig, oldReport *DsnetReport) D peersOnline := 0 linkDev, err := netlink.LinkByName(conf.InterfaceName) - check(err) + if err != nil { + return DsnetReport{}, fmt.Errorf("%w - error getting link", err) + } stats := linkDev.Attrs().Statistics @@ -173,37 +185,37 @@ func GetReport(dev *wgtypes.Device, conf *DsnetConfig, oldReport *DsnetReport) D ReceiveBytesSI: BytesToSI(stats.RxBytes), TransmitBytesSI: BytesToSI(stats.TxBytes), Timestamp: time.Now(), - } + }, nil } -func (report *DsnetReport) MustSave() { +func (report *DsnetReport) Save() error { reportFilePath := viper.GetString("report_file") _json, _ := json.MarshalIndent(report, "", " ") _json = append(_json, '\n') err := ioutil.WriteFile(reportFilePath, _json, 0644) - check(err) + return err } -func MustLoadDsnetReport() *DsnetReport { +func LoadDsnetReport() (*DsnetReport, error) { reportFilePath := viper.GetString("report_file_path") raw, err := ioutil.ReadFile(reportFilePath) if os.IsNotExist(err) { - return nil + return nil, err } else if os.IsPermission(err) { - ExitFail("%s cannot be accessed. Check read permissions.", reportFilePath) + return nil, fmt.Errorf("%s cannot be accessed. Check read permissions.", reportFilePath) } else { - check(err) + return nil, err } report := DsnetReport{} err = json.Unmarshal(raw, &report) - check(err) + return nil, err err = validator.New().Struct(report) - check(err) + return nil, err - return &report + return &report, nil } diff --git a/cmd/cli/sync.go b/cmd/cli/sync.go index f523d87..ff637c2 100644 --- a/cmd/cli/sync.go +++ b/cmd/cli/sync.go @@ -1,15 +1,17 @@ package cli +import "fmt" + func Sync() error { // TODO check device settings first conf, err := LoadConfigFile() if err != nil { - return wrapError(err, "failed to load configuration file") + return fmt.Errorf("%w - failed to load configuration file", err) } server := GetServer(conf) err = server.ConfigureDevice() if err != nil { - return wrapError(err, "failed to sync device configuration") + return fmt.Errorf("%w - failed to sync device configuration", err) } return nil } diff --git a/cmd/cli/types.go b/cmd/cli/types.go index cb84166..dee25f4 100644 --- a/cmd/cli/types.go +++ b/cmd/cli/types.go @@ -27,22 +27,25 @@ func (k *JSONKey) UnmarshalJSON(b []byte) error { return err } -func GenerateJSONPrivateKey() JSONKey { +func GenerateJSONPrivateKey() (JSONKey, error) { privateKey, err := wgtypes.GeneratePrivateKey() - - check(err) + if err != nil { + return JSONKey{}, err + } return JSONKey{ Key: privateKey, - } + }, nil } -func GenerateJSONKey() JSONKey { +func GenerateJSONKey() (JSONKey, error) { privateKey, err := wgtypes.GenerateKey() - check(err) + if err != nil { + return JSONKey{}, err + } return JSONKey{ Key: privateKey, - } + }, err } diff --git a/cmd/cli/util.go b/cmd/cli/util.go index c9e1c93..4dea73c 100644 --- a/cmd/cli/util.go +++ b/cmd/cli/util.go @@ -1,5 +1,7 @@ package cli +// FIXME every function in this file has public scope, but only private references + import ( "bufio" "fmt" @@ -9,15 +11,6 @@ import ( "github.com/naggie/dsnet/lib" ) -func check(e error, optMsg ...string) { - if e != nil { - if len(optMsg) > 0 { - ExitFail("%s - %s", e, strings.Join(optMsg, " ")) - } - ExitFail("%s", e) - } -} - func jsonPeerToDsnetPeer(peers []PeerConfig) []lib.Peer { libPeers := make([]lib.Peer, 0, len(peers)) for _, p := range peers { @@ -37,16 +30,7 @@ func jsonPeerToDsnetPeer(peers []PeerConfig) []lib.Peer { return libPeers } -func ExitFail(format string, a ...interface{}) { - fmt.Fprintf(os.Stderr, "\033[31m"+format+"\033[0m\n", a...) - os.Exit(1) -} - -func wrapError(err error, s string) error { - return fmt.Errorf("\033[31m%s - %s\033[0m\n", err, s) -} - -func MustPromptString(prompt string, required bool) string { +func PromptString(prompt string, required bool) (string, error) { reader := bufio.NewReader(os.Stdin) var text string var err error @@ -54,12 +38,15 @@ func MustPromptString(prompt string, required bool) string { for text == "" { fmt.Fprintf(os.Stderr, "%s: ", prompt) text, err = reader.ReadString('\n') - check(err) + if err != nil { + return "", fmt.Errorf("%w - error getting input", err) + } text = strings.TrimSpace(text) } - return text + return text, nil } +// FIXME is it critical for this to panic, or can we cascade the errors? func ConfirmOrAbort(format string, a ...interface{}) { fmt.Fprintf(os.Stderr, format+" [y/n] ", a...) @@ -73,7 +60,8 @@ func ConfirmOrAbort(format string, a ...interface{}) { if input == "y\n" { return } else { - ExitFail("Aborted.") + fmt.Fprintf(os.Stderr, "\033[31mAborted.\033[0m\n") + os.Exit(1) } } diff --git a/cmd/root.go b/cmd/root.go index 720700b..42f874a 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -38,7 +38,10 @@ var ( Use: "up", Short: "Create the interface, run pre/post up, sync", RunE: func(cmd *cobra.Command, args []string) error { - config := cli.MustLoadConfigFile() + config, err := cli.LoadConfigFile() + if err != nil { + return fmt.Errorf("%w - failure to load config file", err) + } server := cli.GetServer(config) if e := server.Up(); e != nil { return e @@ -54,7 +57,10 @@ var ( Use: "down", Short: "Destroy the interface, run pre/post down", RunE: func(cmd *cobra.Command, args []string) error { - config := cli.MustLoadConfigFile() + config, err := cli.LoadConfigFile() + if err != nil { + return fmt.Errorf("%w - failure to load config file", err) + } server := cli.GetServer(config) if e := server.DeleteLink(); e != nil { return e @@ -150,7 +156,8 @@ func init() { viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) if err := viper.BindPFlag("output", rootCmd.PersistentFlags().Lookup("output")); err != nil { - cli.ExitFail(err.Error()) + fmt.Fprintf(os.Stderr, "\033[31m%s\033[0m\n", err.Error()) + os.Exit(1) } viper.SetDefault("config_file", "/etc/dsnetconfig.json") @@ -179,7 +186,9 @@ func init() { func main() { if err := rootCmd.Execute(); err != nil { - cli.ExitFail(err.Error()) + // Because of side effects in viper, this gets printed twice + fmt.Fprintf(os.Stderr, "\033[31m%s\033[0m\n", err.Error()) + os.Exit(1) } os.Exit(0) }