diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d36436b..db9f113 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -89,3 +89,15 @@ jobs: tags: ghcr.io/${{ steps.meta.outputs.tags }} labels: ${{ steps.meta.outputs.labels }} platforms: linux/amd64,linux/arm64 + + test-windows: + runs-on: windows-latest + steps: + - name: Checkout + uses: actions/checkout@v3 + - name: Set up Go + uses: actions/setup-go@v3 + with: + go-version: 1.21 + - name: Run Tests + run: go test ./... diff --git a/.gitignore b/.gitignore index 2630087..fc4c0b4 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,5 @@ ca/ es-gencert-cli certs.yml .DS_Store +*.crt +*.key diff --git a/README.md b/README.md index fac3410..7202be1 100644 --- a/README.md +++ b/README.md @@ -94,6 +94,8 @@ certificates: dns-names: "localhost,eventstore-node2.localhost.com" ``` +If you want to specify the name of the certificates from the config file, you can add the name field to the certificate definition. You can see an example of this in the [example configuration](references/named_certs.yml). + ## Development Building or working on `es-gencert-cli` requires a Go environment, version 1.14 or higher. diff --git a/certificates/boring_linux.go b/certificates/boring_linux.go index 3f17c1b..ab8cd55 100644 --- a/certificates/boring_linux.go +++ b/certificates/boring_linux.go @@ -8,4 +8,4 @@ import ( func isBoringEnabled() bool { return boring.Enabled() -} \ No newline at end of file +} diff --git a/certificates/certificates.go b/certificates/certificates.go deleted file mode 100644 index c91262e..0000000 --- a/certificates/certificates.go +++ /dev/null @@ -1,60 +0,0 @@ -package certificates - -import ( - "log" - "os" - "strings" - - "github.com/mitchellh/cli" -) - -type Certificates struct { - Ui cli.Ui -} - -func (command *Certificates) Run(args []string) int { - ui := &cli.BasicUi{ - Reader: os.Stdin, - Writer: os.Stdout, - ErrorWriter: os.Stderr, - } - c := cli.NewCLI("Event Store CLI certificates", "") - c.Args = args - c.Commands = map[string]cli.CommandFactory{ - "create-ca": func() (cli.Command, error) { - return &CreateCA{ - Ui: &cli.ColoredUi{ - Ui: ui, - OutputColor: cli.UiColorBlue, - }, - }, nil - }, - "create-node": func() (cli.Command, error) { - return &CreateNode{ - Ui: &cli.ColoredUi{ - Ui: ui, - OutputColor: cli.UiColorBlue, - }, - }, nil - }, - } - exitStatus, err := c.Run() - if err != nil { - log.Println(err) - } - return exitStatus -} - -func (c *Certificates) Help() string { - helpText := ` -usage: certificates [--help] [] - -Available commands: -` - helpText += c.Synopsis() - return strings.TrimSpace(helpText) -} - -func (c *Certificates) Synopsis() string { - return "certificates (create_ca, create_node)" -} diff --git a/certificates/common.go b/certificates/common.go index 1c5d498..ab5e054 100644 --- a/certificates/common.go +++ b/certificates/common.go @@ -11,16 +11,18 @@ import ( "math/big" "os" "path" - "text/tabwriter" + "path/filepath" ) -const defaultKeySize = 2048 - -const forceOption = "Force overwrite of existing files without prompting" - const ( - ErrFileExists = "Error: Existing files would be overwritten. Use -force to proceed" + ForceFlagUsage = "Force overwrite of existing files without prompting" + NameFlagUsage = "The name of the CA certificate and key file" + OutDirFlagUsage = "The output directory" + DayFlagUsage = "the validity period of the certificate in days" + CaKeyFlagUsage = "the path to the CA key file" + CaCertFlagUsage = "the path to the CA certificate file" ) +const defaultKeySize = 2048 func generateSerialNumber(bits uint) (*big.Int, error) { maxValue := new(big.Int).Lsh(big.NewInt(1), bits) @@ -48,13 +50,9 @@ func writeFileWithDir(filePath string, data []byte, perm os.FileMode) error { return os.WriteFile(filePath, data, perm) } -func writeHelpOption(w *tabwriter.Writer, title string, description string) { - fmt.Fprintf(w, "\t-%s\t%s\n", title, description) -} - func writeCertAndKey(outputDir string, fileName string, certPem, privateKeyPem *bytes.Buffer, force bool) error { - certFile := path.Join(outputDir, fileName+".crt") - keyFile := path.Join(outputDir, fileName+".key") + certFile := filepath.ToSlash(fmt.Sprintf("%s/%s.crt", outputDir, fileName)) + keyFile := filepath.ToSlash(fmt.Sprintf("%s/%s.key", outputDir, fileName)) if force { if _, err := os.Stat(certFile); err == nil { @@ -76,19 +74,12 @@ func writeCertAndKey(outputDir string, fileName string, certPem, privateKeyPem * err = writeFileWithDir(keyFile, privateKeyPem.Bytes(), 0400) if err != nil { - return fmt.Errorf("error writing certificate private key to %s: %s", keyFile, err.Error()) + return fmt.Errorf("error writing private key to %s: %s", keyFile, err.Error()) } return nil } -func fileExists(path string, force bool) bool { - if _, err := os.Stat(path); !os.IsNotExist(err) && !force { - return true - } - return false -} - func readCertificateFromFile(path string) (*x509.Certificate, error) { pemBytes, err := os.ReadFile(path) if err != nil { @@ -124,3 +115,19 @@ func readRSAKeyFromFile(path string) (*rsa.PrivateKey, error) { } return key, nil } + +func checkCertificatesLocationWithForce(dir, certificateName string, force bool) error { + // Throw an error if the path for the CA and key certificates already + // exists and the 'force' flag is not set. + + checkFile := func(ext string) bool { + _, err := os.Stat(filepath.Join(dir, certificateName+ext)) + return !os.IsNotExist(err) + } + + if !force && (checkFile(".key") || checkFile(".crt")) { + return fmt.Errorf("existing files would be overwritten. Use -force to proceed") + } + + return nil +} diff --git a/certificates/common_test.go b/certificates/common_test.go index 1c4d9f9..c70d87a 100644 --- a/certificates/common_test.go +++ b/certificates/common_test.go @@ -3,42 +3,57 @@ package certificates import ( "crypto/rsa" "crypto/x509" + "fmt" "github.com/stretchr/testify/assert" - "os" - "path" + "path/filepath" + "regexp" + "strings" "testing" ) -func assertFilesExist(t *testing.T, files ...string) { - for _, file := range files { - _, err := os.Stat(file) - assert.False(t, os.IsNotExist(err)) - } -} +func extractErrors(errorMessage string) []string { + // Sometimes errors are shown in a multi-line format (multierror.Append), so we need to extract them and return them + // as a list. However, this method can be used with single line errors as well and will return a list with a single + // element. Also perform some basic cleanup of the error message (TrimSpace). -func generateAndAssertCACert(t *testing.T, years int, days int, outputDirCa string, force bool) (*x509.Certificate, *rsa.PrivateKey) { - certificateError := generateCACertificate(years, days, outputDirCa, nil, nil, force) - assert.NoError(t, certificateError) + var errors []string - certFilePath := path.Join(outputDirCa, "ca.crt") - keyFilePath := path.Join(outputDirCa, "ca.key") - assertFilesExist(t, certFilePath, keyFilePath) + // Pattern for multi-line errors + multiLinePattern := regexp.MustCompile(`\* (.+)`) + multiLineMatches := multiLinePattern.FindAllStringSubmatch(errorMessage, -1) - caCertificate, err := readCertificateFromFile(certFilePath) - assert.NoError(t, err) - caPrivateKey, err := readRSAKeyFromFile(keyFilePath) - assert.NoError(t, err) + ansiCodePattern := regexp.MustCompile(`\x1b\[[0-9;]*m`) + errorMessage = ansiCodePattern.ReplaceAllString(errorMessage, "") - return caCertificate, caPrivateKey -} - -func cleanupDirsForTest(t *testing.T, dirs ...string) { - cleanupDirs := func() { - for _, dir := range dirs { - os.RemoveAll(dir) + if len(multiLineMatches) > 0 { + for _, match := range multiLineMatches { + if len(match) > 1 { + errors = append(errors, strings.TrimSpace(match[1])) + } } + } else { + // Single line error + cleanedError := strings.TrimSpace(errorMessage) + errors = append(errors, cleanedError) } - cleanupDirs() - t.Cleanup(cleanupDirs) + return errors +} + +func readAndDecodeCertificateAndKey(t *testing.T, dir, name string) (*x509.Certificate, *rsa.PrivateKey) { + // In the test suite, we often need to verify that a certificate and key pair exist in a given directory. + // This is usually carried out after a call to the create_ca or create_node commands. This method reads the certificate + // and key from the given directory and returns them. It will throw an error if the certificate or key cannot be + // read from the given directory. + + certPath := filepath.Join(dir, fmt.Sprintf("%s.crt", name)) + keyPath := filepath.Join(dir, fmt.Sprintf("%s.key", name)) + + ca, caErr := readCertificateFromFile(certPath) + assert.NoError(t, caErr) + + key, keyErr := readRSAKeyFromFile(keyPath) + assert.NoError(t, keyErr) + + return ca, key } diff --git a/certificates/create_ca.go b/certificates/create_ca.go index a7d3337..52287cc 100644 --- a/certificates/create_ca.go +++ b/certificates/create_ca.go @@ -10,9 +10,6 @@ import ( "errors" "flag" "fmt" - "path" - "strings" - "text/tabwriter" "time" "github.com/hashicorp/go-multierror" @@ -20,7 +17,9 @@ import ( ) type CreateCA struct { - Ui cli.Ui + Ui cli.Ui + Config CreateCAArguments + Flags *flag.FlagSet } type CreateCAArguments struct { @@ -28,33 +27,38 @@ type CreateCAArguments struct { OutputDir string `yaml:"out"` CACertificatePath string `yaml:"ca-certificate"` CAKeyPath string `yaml:"ca-key"` + Name string `yaml:"name"` Force bool `yaml:"force"` } -func (c *CreateCA) Run(args []string) int { - var config CreateCAArguments - - flags := flag.NewFlagSet("create_ca", flag.ContinueOnError) - flags.Usage = func() { c.Ui.Info(c.Help()) } - flags.IntVar(&config.Days, "days", 0, "the validity period of the certificate in days") - flags.StringVar(&config.OutputDir, "out", "./ca", "The output directory") - flags.StringVar(&config.CACertificatePath, "ca-certificate", "", "the path to a CA certificate file") - flags.StringVar(&config.CAKeyPath, "ca-key", "", "the path to a CA key file") - flags.BoolVar(&config.Force, "force", false, forceOption) +func NewCreateCA(ui cli.Ui) *CreateCA { + c := &CreateCA{Ui: ui} + + c.Flags = flag.NewFlagSet("create_ca", flag.ContinueOnError) + c.Flags.IntVar(&c.Config.Days, "days", 0, DayFlagUsage) + c.Flags.StringVar(&c.Config.OutputDir, "out", "./ca", OutDirFlagUsage) + c.Flags.StringVar(&c.Config.CACertificatePath, "ca-certificate", "", CaCertFlagUsage) + c.Flags.StringVar(&c.Config.CAKeyPath, "ca-key", "", CaKeyFlagUsage) + c.Flags.StringVar(&c.Config.Name, "name", "ca", NameFlagUsage) + c.Flags.BoolVar(&c.Config.Force, "force", false, ForceFlagUsage) + return c +} - if err := flags.Parse(args); err != nil { +func (c *CreateCA) Run(args []string) int { + if err := c.Flags.Parse(args); err != nil { + c.Ui.Error(err.Error()) return 1 } validationErrors := new(multierror.Error) - if config.Days < 0 { - multierror.Append(validationErrors, errors.New("days must be positive")) + if c.Config.Days < 0 { + _ = multierror.Append(validationErrors, errors.New("days must be positive")) } - caCertPathLen := len(config.CACertificatePath) - caKeyPathLen := len(config.CAKeyPath) + caCertPathLen := len(c.Config.CACertificatePath) + caKeyPathLen := len(c.Config.CAKeyPath) if (caCertPathLen > 0 && caKeyPathLen == 0) || (caKeyPathLen > 0 && caCertPathLen == 0) { - multierror.Append(validationErrors, errors.New("both -ca-certificate and -ca-key options are required")) + _ = multierror.Append(validationErrors, errors.New("both -ca-certificate and -ca-key options are required")) } if validationErrors.ErrorOrNil() != nil { @@ -62,14 +66,9 @@ func (c *CreateCA) Run(args []string) int { return 1 } - // check if certificates already exist - if fileExists(path.Join(config.OutputDir, "ca.key"), config.Force) { - c.Ui.Error(ErrFileExists) - return 1 - } - - if fileExists(path.Join(config.OutputDir, "ca.crt"), config.Force) { - c.Ui.Error(ErrFileExists) + certErr := checkCertificatesLocationWithForce(c.Config.OutputDir, c.Config.Name, c.Config.Force) + if certErr != nil { + c.Ui.Error(certErr.Error()) return 1 } @@ -77,8 +76,8 @@ func (c *CreateCA) Run(args []string) int { years := 5 days := 0 - if config.Days != 0 { - days = config.Days + if c.Config.Days != 0 { + days = c.Config.Days years = 0 } @@ -86,13 +85,13 @@ func (c *CreateCA) Run(args []string) int { var caKey *rsa.PrivateKey var err error if caCertPathLen > 0 { - caCert, err = readCertificateFromFile(config.CACertificatePath) + caCert, err = readCertificateFromFile(c.Config.CACertificatePath) if err != nil { c.Ui.Error(err.Error()) return 1 } - caKey, err = readRSAKeyFromFile(config.CAKeyPath) + caKey, err = readRSAKeyFromFile(c.Config.CAKeyPath) if err != nil { err := fmt.Errorf("error: %s. please note that only RSA keys are currently supported", err.Error()) c.Ui.Error(err.Error()) @@ -100,10 +99,11 @@ func (c *CreateCA) Run(args []string) int { } } - outputDir := config.OutputDir - err = generateCACertificate(years, days, outputDir, caCert, caKey, config.Force) + outputDir := c.Config.OutputDir + err = generateCACertificate(years, days, outputDir, c.Config.Name, caCert, caKey, c.Config.Force) if err != nil { c.Ui.Error(err.Error()) + return 1 } else { if isBoringEnabled() { c.Ui.Output(fmt.Sprintf("A CA certificate & key file have been generated in the '%s' directory (FIPS mode enabled).", outputDir)) @@ -111,10 +111,11 @@ func (c *CreateCA) Run(args []string) int { c.Ui.Output(fmt.Sprintf("A CA certificate & key file have been generated in the '%s' directory.", outputDir)) } } + return 0 } -func generateCACertificate(years int, days int, outputDir string, caCert *x509.Certificate, caPrivateKey *rsa.PrivateKey, force bool) error { +func generateCACertificate(years int, days int, outputDir string, name string, caCert *x509.Certificate, caPrivateKey *rsa.PrivateKey, force bool) error { serialNumber, err := generateSerialNumber(128) if err != nil { return fmt.Errorf("could not generate 128-bit serial number: %s", err.Error()) @@ -184,29 +185,16 @@ func generateCACertificate(years int, days int, outputDir string, caCert *x509.C return fmt.Errorf("could not encode certificate to PEM format: %s", err.Error()) } - err = writeCertAndKey(outputDir, "ca", certPem, privateKeyPem, force) + err = writeCertAndKey(outputDir, name, certPem, privateKeyPem, force) return err } func (c *CreateCA) Help() string { - var buffer bytes.Buffer - - w := tabwriter.NewWriter(&buffer, 0, 0, 2, ' ', 0) - - fmt.Fprintln(w, "Usage: create_ca [options]") - fmt.Fprintln(w, c.Synopsis()) - fmt.Fprintln(w, "Options:") - - writeHelpOption(w, "days", "The validity period of the certificate in days (default: 5 years).") - writeHelpOption(w, "out", "The output directory (default: ./ca).") - writeHelpOption(w, "ca-certificate", "The path to a CA certificate file for creating an intermediate CA certificate.") - writeHelpOption(w, "ca-key", "The path to a CA key file for creating an intermediate CA certificate.") - writeHelpOption(w, "force", forceOption) - - w.Flush() - - return strings.TrimSpace(buffer.String()) + var helpText bytes.Buffer + c.Flags.SetOutput(&helpText) + c.Flags.PrintDefaults() + return helpText.String() } func (c *CreateCA) Synopsis() string { diff --git a/certificates/create_ca_test.go b/certificates/create_ca_test.go index 36ab422..35ccb23 100644 --- a/certificates/create_ca_test.go +++ b/certificates/create_ca_test.go @@ -1,77 +1,169 @@ package certificates import ( - "crypto/rsa" - "crypto/x509" - "path" - "testing" - + "bytes" + "fmt" + "github.com/mitchellh/cli" "github.com/stretchr/testify/assert" + "os" + "path/filepath" + "testing" + "time" ) -func setupTestEnvironmentForCaTests(t *testing.T) (years int, days int, outputDir string, caCert *x509.Certificate, caKey *rsa.PrivateKey) { - years = 1 - days = 0 - outputDir = "./ca" - caCert = nil - caKey = nil +func TestCreateCACertificate(t *testing.T) { + t.Run("TestCreateCACertificate_NominalCase_ShouldSucceed", TestCreateCACertificate_NominalCase_ShouldSucceed) + t.Run("TestCreateCACertificate_DifferentOut_ShouldSucceed", TestCreateCACertificate_DifferentOut_ShouldSucceed) + t.Run("TestCreateCACertificate_WithNameFlag_ShouldCreateNamedCertificates", TestCreateCACertificate_WithNameFlag_ShouldCreateNamedCertificates) + t.Run("TestCreateCACertificate_WithForceFlag_ShouldRegenerate", TestCreateCACertificate_WithForceFlag_ShouldRegenerate) + t.Run("TestCreateIntermediateCertificate_WithoutRootCertificate_ShouldFail", TestCreateIntermediateCertificate_WithoutRootCertificate_ShouldFail) +} + +func TestCreateCACertificate_NominalCase_ShouldSucceed(t *testing.T) { + // Create CA certificate and key without any additional parameters. + + t.Parallel() + + cleanup, tempDir, _, _, createCa := setupCreateCaTestEnvironment(t, &TestCreateCAParams{ + OutputDir: "./ca", + }) + defer cleanup() + + var args []string + + result := createCa.Run(args) + assert.Equal(t, 0, result, "create-ca should pass without any additional parameters") + + assert.FileExists(t, filepath.Join("./ca", "ca.crt"), "CA certificate should exist") + assert.FileExists(t, filepath.Join("./ca", "ca.key"), "CA key should exist") - cleanupDirsForTest(t, outputDir) - return + cert, err := readCertificateFromFile(filepath.Join(tempDir, "ca.crt")) + assert.NoError(t, err, "Failed to read and parse certificate file") + + // The certificate should be valid for 5 year + now := time.Now().Truncate(time.Second) + expectedNotAfter := now.AddDate(5, 0, 0) + assert.WithinDuration(t, expectedNotAfter, cert.NotAfter, time.Second, "Certificate validity period does not match expected default") } -func testGenerateCACertificate(t *testing.T, years int, days int, outputDir string, caCert *x509.Certificate, caKey *rsa.PrivateKey, force bool) { - err := generateCACertificate(years, days, outputDir, caCert, caKey, force) - assert.NoError(t, err, "Expected no error in nominal case") +func TestCreateCACertificate_DifferentOut_ShouldSucceed(t *testing.T) { + // Create certificate with a different output directory. - certFilePath := path.Join(outputDir, "ca.crt") - keyFilePath := path.Join(outputDir, "ca.key") + t.Parallel() - certFile, err := readCertificateFromFile(certFilePath) - assert.NoError(t, err) - keyFile, err := readRSAKeyFromFile(keyFilePath) - assert.NoError(t, err) + cleanup, tempCaDir, _, _, createCa := setupCreateCaTestEnvironment(t, &TestCreateCAParams{}) + defer cleanup() - err = generateCACertificate(years, days, outputDir, caCert, caKey, force) - if !force { - assert.Error(t, err, "Expected an error when directory exists and override is false") - } else { - assert.NoError(t, err, "Expected no error when directory exists and override is true") - } + args := []string{"-out", filepath.Join(tempCaDir, "my-custom-dir")} + + result := createCa.Run(args) + assert.Equal(t, 0, result, "creat-ca should pass with a different output") + + assert.FileExists(t, filepath.Join(tempCaDir, "my-custom-dir", "ca.crt"), "CA certificate should exist") + assert.FileExists(t, filepath.Join(tempCaDir, "my-custom-dir", "ca.key"), "CA key should exist") +} + +func TestCreateCACertificate_WithNameFlag_ShouldCreateNamedCertificates(t *testing.T) { + // Create CA certificate and key with the name parameter. + // 1. It creates a certificate with the name parameter + // 2. The CA certificate and key should be named with the name parameter + + t.Parallel() + + cleanup, tempCaDir, _, _, createCa := setupCreateCaTestEnvironment(t, &TestCreateCAParams{}) + defer cleanup() + + args := []string{"-out", tempCaDir, "-name", "my-custom-name"} + + result := createCa.Run(args) + assert.Equal(t, 0, result, "creat-ca should create a certificate with a different name") - certFileAfter, err := readCertificateFromFile(certFilePath) - assert.NoError(t, err) - keyFileAfter, err := readRSAKeyFromFile(keyFilePath) - assert.NoError(t, err) - - if !force { - assert.Equal(t, certFile, certFileAfter, "Expected CA certificate to be the same") - assert.Equal(t, keyFile, keyFileAfter, "Expected CA key to be the same") - } else { - assert.NotEqual(t, certFile, certFileAfter, "Expected CA certificate to be different") - assert.NotEqual(t, keyFile, keyFileAfter, "Expected CA key to be different") + assert.FileExists(t, filepath.Join(tempCaDir, "my-custom-name.crt"), "CA certificate should exist") + assert.FileExists(t, filepath.Join(tempCaDir, "my-custom-name.key"), "CA certificate should exist") +} + +func TestCreateCACertificate_WithForceFlag_ShouldRegenerate(t *testing.T) { + // Creation of a CA certificate with the force flag. + // 1. It first creates a certificate + // 2. Attempt to recreate the certificate with the force flag + // 3. Check that the content of the files are different + + t.Parallel() + + cleanup, tempCaDir, _, _, createCa := setupCreateCaTestEnvironment(t, &TestCreateCAParams{}) + defer cleanup() + + // Create a CA certificate + result := createCa.Run([]string{"-out", tempCaDir}) + assert.Equal(t, 0, result, fmt.Sprintf("creat-ca should pass and create a certificate at %s", tempCaDir)) + + // Read the content of the key and crt files + originalCaCert, originalKeyCert := readAndDecodeCertificateAndKey(t, tempCaDir, "ca") + + // Try to create a CA certificate again with the force flag and override the existing one + args := []string{"-out", tempCaDir, "-force"} + result = createCa.Run(args) + assert.Equal(t, 0, result, fmt.Sprintf("creat-ca should pass and override certificate at %s", tempCaDir)) + + // Read the content of the key and crt files generated from the config file + newCaCert, newKeyCert := readAndDecodeCertificateAndKey(t, tempCaDir, "ca") + + // Check that the content of the files are different + assert.NotEqual(t, originalCaCert, newCaCert, "The content of the CA certificate should be different") + assert.NotEqual(t, originalKeyCert, newKeyCert, "The content of the CA key should be different") +} + +func TestCreateIntermediateCertificate_WithoutRootCertificate_ShouldFail(t *testing.T) { + // Create intermediate certificate without root certificate. + // 1. It creates an intermediate certificate without root certificate + // 2. It should return an error + + t.Parallel() + + cleanup, tempCaDir, _, errorBuffer, createCa := setupCreateCaTestEnvironment(t, &TestCreateCAParams{}) + defer cleanup() + + args := []string{ + "-out", tempCaDir, + "-ca-certificate", "unknown", + "-ca-key", "unknown", } + result := createCa.Run(args) + assert.Equal(t, 1, result, "create-ca should fail without a root certificate") + + errors := extractErrors(errorBuffer.String()) + assert.Equal(t, 1, len(errors), "Expected 1 error") + assert.Contains(t, errors[0], "error reading file") +} + +type TestCreateCAParams struct { + OutputDir string } -func TestGenerateCACertificate(t *testing.T) { - t.Run("nominal-case", func(t *testing.T) { - years, days, outputDir, caCert, caKey := setupTestEnvironmentForCaTests(t) +func setupCreateCaTestEnvironment(t *testing.T, params *TestCreateCAParams) (cleanupFunc func(), tempDir string, outputBuffer *bytes.Buffer, errorBuffer *bytes.Buffer, createCa *CreateCA) { + tempDir = params.OutputDir - err := generateCACertificate(years, days, outputDir, caCert, caKey, false) + if tempDir == "" { + var err error + tempDir, err = os.MkdirTemp(os.TempDir(), "ca-*") + if err != nil { + t.Fatalf("Failed to create temp directory: %s", err) + } + } - assert.NoError(t, err, "Expected no error in nominal case") + outputBuffer = new(bytes.Buffer) + errorBuffer = new(bytes.Buffer) - assert.FileExists(t, path.Join(outputDir, "ca.crt"), "CA certificate should exist") - assert.FileExists(t, path.Join(outputDir, "ca.key"), "CA key should exist") + createCa = NewCreateCA(&cli.BasicUi{ + Writer: outputBuffer, + ErrorWriter: errorBuffer, }) - t.Run("directory-exists", func(t *testing.T) { - years, days, outputDir, caCert, caKey := setupTestEnvironmentForCaTests(t) - testGenerateCACertificate(t, years, days, outputDir, caCert, caKey, false) - }) + cleanupFunc = func() { + if err := os.RemoveAll(tempDir); err != nil { + t.Logf("Failed to remove temp directory (%s): %s", tempDir, err) + } + } - t.Run("directory-exists-force", func(t *testing.T) { - years, days, outputDir, caCert, caKey := setupTestEnvironmentForCaTests(t) - testGenerateCACertificate(t, years, days, outputDir, caCert, caKey, true) - }) + return cleanupFunc, tempDir, outputBuffer, errorBuffer, createCa } diff --git a/certificates/create_certs.go b/certificates/create_certs.go index 4594eeb..ccc6eea 100644 --- a/certificates/create_certs.go +++ b/certificates/create_certs.go @@ -4,19 +4,18 @@ import ( "bytes" "flag" "fmt" + "github.com/mitchellh/cli" + "gopkg.in/yaml.v3" "os" - "path" "reflect" "strings" "sync" - "text/tabwriter" - - "github.com/mitchellh/cli" - "gopkg.in/yaml.v3" ) type CreateCertificates struct { - Ui cli.Ui + Ui cli.Ui + Config CreateCertificateArguments + Flags *flag.FlagSet } type CreateCertificateArguments struct { @@ -28,20 +27,26 @@ type Config struct { Certificates struct { CaCerts []CreateCAArguments `yaml:"ca-certs"` Nodes []CreateNodeArguments `yaml:"node-certs"` + Users []CreateUserArguments `yaml:"user-certs"` } `yaml:"certificates"` } -func (c *CreateCertificates) Run(args []string) int { - var arguments CreateCertificateArguments - flags := flag.NewFlagSet("create_certs", flag.ContinueOnError) - flags.Usage = func() { c.Ui.Info(c.Help()) } - flags.StringVar(&arguments.ConfigPath, "config-file", "./certs.yml", "The config yml file") - flags.BoolVar(&arguments.Force, "force", false, forceOption) +func NewCreateCerts(ui cli.Ui) *CreateCertificates { + c := &CreateCertificates{Ui: ui} + + c.Flags = flag.NewFlagSet("create_certs", flag.ContinueOnError) + c.Flags.StringVar(&c.Config.ConfigPath, "config-file", "./certs.yml", "The config yml file") + c.Flags.BoolVar(&c.Config.Force, "force", false, ForceFlagUsage) + return c +} - if err := flags.Parse(args); err != nil { +func (c *CreateCertificates) Run(args []string) int { + if err := c.Flags.Parse(args); err != nil { + c.Ui.Error(err.Error()) return 1 } - configData, err := os.ReadFile(arguments.ConfigPath) + + configData, err := os.ReadFile(c.Config.ConfigPath) if err != nil { c.Ui.Error(err.Error()) return 1 @@ -53,65 +58,70 @@ func (c *CreateCertificates) Run(args []string) int { return 1 } - if err := c.checkPaths(config, arguments.Force); err { - c.Ui.Error(ErrFileExists) + certErr := c.checkPaths(config, c.Config.Force) + if certErr != nil { + c.Ui.Error(certErr.Error()) return 1 } - if c.generateCaCerts(config, arguments.Force) != 0 || c.generateNodes(config, arguments.Force) != 0 { + if c.generateCaCerts(config, c.Config.Force) != 0 || c.generateNodes(config, c.Config.Force) != 0 || c.generateUsers(config, c.Config.Force) != 0 { return 1 } return 0 } -func (c *CreateCertificates) checkPaths(config Config, force bool) bool { - // If any certs file exists and the force flag isn't provided, it returns an - // error. Otherwise, it returns false, indicating that certificate generation - // can proceed safely. - - var errorMutex sync.Mutex - var error bool +func (c *CreateCertificates) checkPaths(config Config, force bool) error { + var once sync.Once + var certError error var wg sync.WaitGroup - checkFile := func(filePath string) { + checkCertFiles := func(certificateName, dir string) { defer wg.Done() - if fileExists(filePath, force) { - errorMutex.Lock() - error = true - errorMutex.Unlock() + if err := checkCertificatesLocationWithForce(dir, certificateName, force); err != nil { + once.Do(func() { + certError = err + }) } } // Check CA certificate and key paths for _, caCert := range config.Certificates.CaCerts { - wg.Add(2) - go checkFile(caCert.CACertificatePath) - go checkFile(caCert.CAKeyPath) + wg.Add(1) + go checkCertFiles(caCert.Name, caCert.OutputDir) } // Check Node certificate and key paths for _, node := range config.Certificates.Nodes { - wg.Add(4) - go checkFile(node.CACertificatePath) - go checkFile(node.CAKeyPath) - go checkFile(path.Join(node.OutputDir, "node.crt")) - go checkFile(path.Join(node.OutputDir, "node.key")) + wg.Add(1) + go checkCertFiles(node.Name, node.OutputDir) } wg.Wait() - return error + return certError +} + +func (c *CreateCertificates) generateUsers(config Config, force bool) int { + for _, user := range config.Certificates.Users { + user.Force = force + createUser := NewCreateUser(&cli.ColoredUi{ + Ui: c.Ui, + OutputColor: cli.UiColorBlue, + }) + if createUser.Run(toArguments(user)) != 0 { + return 1 + } + } + return 0 } func (c *CreateCertificates) generateNodes(config Config, force bool) int { for _, node := range config.Certificates.Nodes { node.Force = force - createNode := CreateNode{ - Ui: &cli.ColoredUi{ - Ui: c.Ui, - OutputColor: cli.UiColorBlue, - }, - } + createNode := NewCreateNode(&cli.ColoredUi{ + Ui: c.Ui, + OutputColor: cli.UiColorBlue, + }) if createNode.Run(toArguments(node)) != 0 { return 1 } @@ -120,19 +130,16 @@ func (c *CreateCertificates) generateNodes(config Config, force bool) int { } func (c *CreateCertificates) generateCaCerts(config Config, force bool) int { - coloredUI := &cli.ColoredUi{ - Ui: c.Ui, - OutputColor: cli.UiColorBlue, - } - for _, caCert := range config.Certificates.CaCerts { caCert.Force = force - caCreator := CreateCA{Ui: coloredUI} + caCreator := NewCreateCA(&cli.ColoredUi{ + Ui: c.Ui, + OutputColor: cli.UiColorBlue, + }) if caCreator.Run(toArguments(caCert)) != 0 { return 1 } } - return 0 } @@ -157,20 +164,10 @@ func toArguments(config interface{}) []string { } func (c *CreateCertificates) Help() string { - var buffer bytes.Buffer - - w := tabwriter.NewWriter(&buffer, 0, 0, 2, ' ', 0) - - fmt.Fprintln(w, "Usage: create_certs [options]") - fmt.Fprintln(w, c.Synopsis()) - fmt.Fprintln(w, "Options:") - - writeHelpOption(w, "config-file", "The path to the yml config file.") - writeHelpOption(w, "force", forceOption) - - w.Flush() - - return strings.TrimSpace(buffer.String()) + var helpText bytes.Buffer + c.Flags.SetOutput(&helpText) + c.Flags.PrintDefaults() + return helpText.String() } func (c *CreateCertificates) Synopsis() string { diff --git a/certificates/create_certs_test.go b/certificates/create_certs_test.go new file mode 100644 index 0000000..c34f8b7 --- /dev/null +++ b/certificates/create_certs_test.go @@ -0,0 +1,379 @@ +package certificates + +import ( + "bytes" + "fmt" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/assert" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestCreateCertificates(t *testing.T) { + t.Run("TestCreateCertificates_ValidConfigFile_ShouldSucceed", TestCreateCertificates_ValidConfigFile_ShouldSucceed) + t.Run("TestCreateCertificates_ExistingCertificatesWithoutForceFlag_ShouldFail", TestCreateCertificates_ExistingCertificatesWithoutForceFlag_ShouldFail) + t.Run("TestCreateCertificates_ForceFlagWithExistingCertificates_ShouldRegenerate", TestCreateCertificates_ForceFlagWithExistingCertificates_ShouldRegenerate) + t.Run("TestCreateCertificates_ValidConfigWithCustomNames_ShouldCreateNamedCertificates", TestCreateCertificates_ValidConfigWithCustomNames_ShouldCreateNamedCertificates) + t.Run("TestCreateCertificates_InvalidPathInConfig_ShouldFailWithError", TestCreateCertificates_InvalidPathInConfig_ShouldFailWithError) +} + +func TestCreateCertificates_ValidConfigFile_ShouldSucceed(t *testing.T) { + // Create certificates from a certs.yml file + + t.Parallel() + + cleanup, tempCertsDir, _, _, createCerts := setupCertificateTestEnvironment(t) + defer cleanup() + + certsFileWithName := "certs.yml" + + // Create a certs.yml file + createConfigFile(t, tempCertsDir, certsFileWithName, validCertificatesYaml, tempCertsDir) + + args := []string{ + "-config-file", filepath.Join(tempCertsDir, certsFileWithName), + } + + result := createCerts.Run(args) + assert.Equal(t, 0, result, "The create-certs command should succeed") + + assert.FileExists(t, filepath.Join(tempCertsDir, "root_ca", "ca.crt"), "Root CA certificate should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "root_ca", "ca.key"), "Root CA key should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "intermediate_ca", "ca.crt"), "Intermediate certificate should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "intermediate_ca", "ca.key"), "Intermediate certificate key should exist") + + nodes := []string{"node1", "node2", "node3"} + for _, node := range nodes { + assert.FileExists(t, filepath.Join(tempCertsDir, node, "node.crt"), fmt.Sprintf("%s certificate should exist", node)) + assert.FileExists(t, filepath.Join(tempCertsDir, node, "node.key"), fmt.Sprintf("%s certificate key should exist", node)) + } + + assert.FileExists(t, filepath.Join(tempCertsDir, "user-admin", "user-admin.crt"), "User admin certificate should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "user-admin", "user-admin.key"), "User admin private key should exist") +} + +func TestCreateCertificates_ExistingCertificatesWithoutForceFlag_ShouldFail(t *testing.T) { + // Create certificates from config file and should fail because the certificates already exist + // 1. Successfully create certificates from config file + // 2. Run create-certs again without the force flag + // 3. Expect an error suggesting that the certificates already exist and that the force flag should be used + + t.Parallel() + + cleanup, tempCertsDir, _, errorBuffer, createCerts := setupCertificateTestEnvironment(t) + defer cleanup() + + createConfigFile(t, tempCertsDir, "certs.yml", validCertificatesYaml, tempCertsDir) + + args := []string{ + "-config-file", tempCertsDir + "/certs.yml", + } + + result := createCerts.Run(args) + assert.Equal(t, 0, result, "The create-certs command should succeed the first time it is run since the certificates do not exist") + + assert.FileExists(t, filepath.Join(tempCertsDir, "root_ca", "ca.crt"), "Root CA certificate should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "root_ca", "ca.key"), "Root CA key should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "intermediate_ca", "ca.crt"), "Intermediate certificate should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "intermediate_ca", "ca.key"), "Intermediate certificate key should exist") + + nodes := []string{"node1", "node2", "node3"} + for _, node := range nodes { + assert.FileExists(t, filepath.Join(tempCertsDir, node, "node.crt"), fmt.Sprintf("%s certificate should exist", node)) + assert.FileExists(t, filepath.Join(tempCertsDir, node, "node.key"), fmt.Sprintf("%s certificate key should exist", node)) + } + + assert.FileExists(t, filepath.Join(tempCertsDir, "user-admin", "user-admin.crt"), "User admin certificate should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "user-admin", "user-admin.key"), "User admin private key should exist") + + // Try to generate the certificates again and expect and error + result = createCerts.Run(args) + assert.Equal(t, 1, result, "The create-certs command should fail the second time it is run since the certificates already exist") + errors := extractErrors(errorBuffer.String()) + + assert.Equal(t, 1, len(errors), "Expected 1 error") + assert.Equal(t, "existing files would be overwritten. Use -force to proceed", errors[0]) +} + +func TestCreateCertificates_ForceFlagWithExistingCertificates_ShouldRegenerate(t *testing.T) { + // Create certificates from a certs.yml file with the force flag + // Expect all certificates to be regenerated and different from the original ones + + t.Parallel() + + cleanup, tempCertsDir, _, _, createCerts := setupCertificateTestEnvironment(t) + defer cleanup() + + certsFileWithName := "certs.yml" + + // Create a certs.yml file + createConfigFile(t, tempCertsDir, certsFileWithName, validCertificatesYaml, tempCertsDir) + + args := []string{ + "-config-file", filepath.Join(tempCertsDir, certsFileWithName), + } + + result := createCerts.Run(args) + assert.Equal(t, 0, result, "The create-certs command should succeed the first time it is run since the certificates do not exist") + + assert.FileExists(t, filepath.Join(tempCertsDir, "root_ca", "ca.crt"), "Root CA certificate should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "root_ca", "ca.key"), "Root CA key should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "intermediate_ca", "ca.crt"), "Intermediate certificate should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "intermediate_ca", "ca.key"), "Intermediate certificate key should exist") + + nodes := []string{"node1", "node2", "node3"} + for _, node := range nodes { + assert.FileExists(t, filepath.Join(tempCertsDir, node, "node.crt"), fmt.Sprintf("%s certificate should exist", node)) + assert.FileExists(t, filepath.Join(tempCertsDir, node, "node.key"), fmt.Sprintf("%s certificate key should exist", node)) + } + + assert.FileExists(t, filepath.Join(tempCertsDir, "user-admin", "user-admin.crt"), "User admin certificate should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "user-admin", "user-admin.key"), "User admin private key should exist") + + // Read the content of the key and crt files generated from the config file + originalCaCert, originalKeyCert := readAndDecodeCertificateAndKey(t, filepath.Join(tempCertsDir, "root_ca"), "ca") + originalIntermediateCaCert, originalIntermediateKeyCert := readAndDecodeCertificateAndKey(t, filepath.Join(tempCertsDir, "intermediate_ca"), "ca") + originalUserCert, originalUserCertKey := readAndDecodeCertificateAndKey(t, filepath.Join(tempCertsDir, "user-admin"), "user-admin") + + originalCerts := make(map[string][2]interface{}) + + for _, node := range nodes { + originalCaCert, originalKeyCert := readAndDecodeCertificateAndKey(t, filepath.Join(tempCertsDir, node), "node") + originalCerts[node] = [2]interface{}{originalCaCert, originalKeyCert} + } + + args = []string{ + "-config-file", filepath.Join(tempCertsDir, certsFileWithName), + "-force", + } + + result = createCerts.Run(args) + assert.Equal(t, 0, result, "The create-certs command should succeed with the force flag and "+ + "override the existing certificates defined in the config file") + + newRootCaCert, newRootCaKey := readAndDecodeCertificateAndKey(t, filepath.Join(tempCertsDir, "root_ca"), "ca") + newIntermediateCaCert, newIntermediateKeyCert := readAndDecodeCertificateAndKey(t, filepath.Join(tempCertsDir, "intermediate_ca"), "ca") + newUserCert, newUserCertKey := readAndDecodeCertificateAndKey(t, filepath.Join(tempCertsDir, "user-admin"), "user-admin") + + assert.NotEqual(t, originalCaCert, newRootCaCert, "Root CA certificate should be regenerated") + assert.NotEqual(t, originalKeyCert, newRootCaKey, "Root CA key should be regenerated") + + assert.NotEqual(t, originalIntermediateCaCert, newIntermediateCaCert, "Intermediate CA certificate should be regenerated") + assert.NotEqual(t, originalIntermediateKeyCert, newIntermediateKeyCert, "Intermediate CA key should be regenerated") + + assert.NotEqual(t, originalUserCert, newUserCert, "User certificate should be regenerated") + assert.NotEqual(t, originalUserCertKey, newUserCertKey, "User certificate key should be regenerated") + + for _, node := range nodes { + newCAHash, newKeyHash := readAndDecodeCertificateAndKey(t, filepath.Join(tempCertsDir, node), "node") + assert.NotEqual(t, originalCerts[node][0], newCAHash, fmt.Sprintf("%s certificate should be regenerated", node)) + assert.NotEqual(t, originalCerts[node][1], newKeyHash, fmt.Sprintf("%s certificate key should be regenerated", node)) + } +} + +func TestCreateCertificates_ValidConfigWithCustomNames_ShouldCreateNamedCertificates(t *testing.T) { + // Create certificates from a certs.yml file with the name parameter + // Expect all certificates to be named with the name parameter + + t.Parallel() + + cleanup, tempCertsDir, _, _, createCerts := setupCertificateTestEnvironment(t) + defer cleanup() + + certsFileName := "certs-with-name.yml" + + createConfigFile(t, tempCertsDir, certsFileName, certificatesYamlWithOverrideName, tempCertsDir) + + args := []string{ + "-config-file", filepath.Join(tempCertsDir, certsFileName), + } + + result := createCerts.Run(args) + assert.Equal(t, 0, result, "The create-certs command should create certificates with custom names") + + assert.FileExists(t, filepath.Join(tempCertsDir, "custom_root", "custom_root.crt"), "Root CA certificate should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "custom_root", "custom_root.key"), "Root CA key should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "custom_intermediate", "custom_intermediate.crt"), "Intermediate certificate should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "custom_intermediate", "custom_intermediate.key"), "Intermediate certificate key should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "user-admin", "renamed.crt"), "User admin certificate should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "user-admin", "renamed.key"), "Intermediate certificate key should exist") + + nodes := []string{"custom_node1", "custom_node2", "custom_node3"} + for _, node := range nodes { + assert.FileExists(t, filepath.Join(tempCertsDir, node, fmt.Sprintf("%s.crt", node)), fmt.Sprintf("%s certificate should exist", node)) + assert.FileExists(t, filepath.Join(tempCertsDir, node, fmt.Sprintf("%s.key", node)), fmt.Sprintf("%s certificate key should exist", node)) + } +} + +func TestCreateCertificates_InvalidPathInConfig_ShouldFailWithError(t *testing.T) { + // An invalid path is defined at ca-certificate in the config. + // The intermediate certificate uses an invalid path for the root CA certificate. + // This should result in an error suggesting that ca.crt is not found. + + t.Parallel() + + cleanup, tempCertsDir, _, errorBuffer, createCerts := setupCertificateTestEnvironment(t) + defer cleanup() + + certsFileName := "certs.yml" + + createConfigFile(t, tempCertsDir, certsFileName, certificatesYamlWithInvalidPath, tempCertsDir) + + args := []string{ + "-config-file", filepath.Join(tempCertsDir, certsFileName), + } + + result := createCerts.Run(args) + assert.Equal(t, 1, result, "The create-certs command should fail with code 1 when an invalid path is defined in the config") + + errors := extractErrors(errorBuffer.String()) + + assert.Equal(t, 1, len(errors), "Expected 1 error") + + assert.Contains(t, errors[0], "error reading file") + assert.Contains(t, errors[0], filepath.ToSlash(fmt.Sprintf("%s/invalid_root_ca/ca.crt", tempCertsDir))) + + // The root CA will be created + assert.DirExists(t, filepath.Join(tempCertsDir, "root_ca")) + + // Intermediate and node1 will not be created + assert.NoDirExists(t, filepath.Join(tempCertsDir, "intermediate_ca"), "Intermediate certificate should not exist") + assert.NoDirExists(t, filepath.Join(tempCertsDir, "node1"), "Intermediate certificate key should not exist") +} + +// Valid certificate file +var validCertificatesYaml = `certificates: + ca-certs: + - out: "./root_ca" + - out: "./intermediate_ca" + ca-certificate: "./root_ca/ca.crt" + ca-key: "./root_ca/ca.key" + days: 5 + node-certs: + - out: "./node1" + ca-certificate: "./intermediate_ca/ca.crt" + ca-key: "./intermediate_ca/ca.key" + ip-addresses: "127.0.0.1,172.20.240.1" + dns-names: "localhost,eventstore-node1.localhost.com" + - out: "./node2" + ca-certificate: "./intermediate_ca/ca.crt" + ca-key: "./intermediate_ca/ca.key" + ip-addresses: "127.0.0.2,172.20.240.2" + dns-names: "localhost,eventstore-node2.localhost.com" + - out: "./node3" + ca-certificate: "./intermediate_ca/ca.crt" + ca-key: "./intermediate_ca/ca.key" + ip-addresses: "127.0.0.3,172.20.240.3" + dns-names: "localhost,eventstore-node2.localhost.com" + user-certs: + - out: "./user-admin" + username: "admin" + ca-certificate: "./root_ca/ca.crt" + ca-key: "./root_ca/ca.key"` + +// Invalid path defined at ca-certificate in the config +var certificatesYamlWithInvalidPath = `certificates: + ca-certs: + - out: "./root_ca" + - out: "./intermediate_ca" + ca-certificate: "./invalid_root_ca/ca.crt" + ca-key: "./root_ca/ca.key" + days: 5 + node-certs: + - out: "./node1" + ca-certificate: "./intermediate_ca/ca.crt" + ca-key: "./intermediate_ca/ca.key" + ip-addresses: "127.0.0.1,172.20.240.1" + dns-names: "localhost,eventstore-node1.localhost.com"` + +// Each certificate have a name parameter +var certificatesYamlWithOverrideName = `certificates: + ca-certs: + - out: "./custom_root" + name: "custom_root" + - out: "./custom_intermediate" + name: "custom_intermediate" + ca-certificate: "./custom_root/custom_root.crt" + ca-key: "./custom_root/custom_root.key" + days: 5 + node-certs: + - out: "./custom_node1" + name: "custom_node1" + ca-certificate: "./custom_intermediate/custom_intermediate.crt" + ca-key: "./custom_intermediate/custom_intermediate.key" + ip-addresses: "127.0.0.1,172.20.240.1" + dns-names: "localhost,eventstore-node1.localhost.com" + - out: "./custom_node2" + name: "custom_node2" + ca-certificate: "./custom_intermediate/custom_intermediate.crt" + ca-key: "./custom_intermediate/custom_intermediate.key" + ip-addresses: "127.0.0.2,172.20.240.2" + dns-names: "localhost,eventstore-node2.localhost.com" + - out: "./custom_node3" + name: "custom_node3" + ca-certificate: "./custom_intermediate/custom_intermediate.crt" + ca-key: "./custom_intermediate/custom_intermediate.key" + ip-addresses: "127.0.0.3,172.20.240.3" + dns-names: "localhost,eventstore-node2.localhost.com" + user-certs: + - out: "./user-admin" + username: "admin" + name: "renamed" + ca-certificate: "./custom_root/custom_root.crt" + ca-key: "./custom_root/custom_root.key"` + +func setupCertificateTestEnvironment(t *testing.T) (cleanupFunc func(), tempCertsDir string, outputBuffer *bytes.Buffer, errorBuffer *bytes.Buffer, createCerts *CreateCertificates) { + tempCertsDir, err := os.MkdirTemp(os.TempDir(), "certs-*") + if err != nil { + t.Fatalf("Failed to create temp directory: %s", err) + } + + outputBuffer = new(bytes.Buffer) + errorBuffer = new(bytes.Buffer) + + createCerts = NewCreateCerts(&cli.BasicUi{ + Writer: outputBuffer, + ErrorWriter: errorBuffer, + }) + + cleanupFunc = func() { + if err := os.RemoveAll(tempCertsDir); err != nil { + t.Logf("Failed to remove temp directory (%s): %s", tempCertsDir, err) + } + } + + return cleanupFunc, tempCertsDir, outputBuffer, errorBuffer, createCerts +} + +func createConfigFile(t *testing.T, dirPath string, fileName string, content string, newParentDir string) { + updatedContent := strings.ReplaceAll(content, "./", fmt.Sprintf("%s/", filepath.ToSlash(newParentDir))) + + filePath := filepath.Join(dirPath, fileName) + + // Create the directory if it does not exist + if _, err := os.Stat(dirPath); os.IsNotExist(err) { + err := os.MkdirAll(dirPath, 0755) + if err != nil { + t.Errorf("Error creating directory: %s", err) + } + } + + f, err := os.Create(filePath) + if err != nil { + panic(err) + } + + defer func(f *os.File) { + err := f.Close() + if err != nil { + t.Error(err) + } + }(f) + + _, err = f.WriteString(updatedContent) + if err != nil { + panic(err) + } +} diff --git a/certificates/create_node.go b/certificates/create_node.go index bf05ea6..103a7e6 100644 --- a/certificates/create_node.go +++ b/certificates/create_node.go @@ -12,10 +12,8 @@ import ( "fmt" "net" "os" - "path" "strconv" "strings" - "text/tabwriter" "time" multierror "github.com/hashicorp/go-multierror" @@ -23,7 +21,9 @@ import ( ) type CreateNode struct { - Ui cli.Ui + Ui cli.Ui + Flags *flag.FlagSet + Config CreateNodeArguments } type CreateNodeArguments struct { @@ -34,9 +34,26 @@ type CreateNodeArguments struct { Days int `yaml:"days"` OutputDir string `yaml:"out"` CommonName string `yaml:"common-name"` + Name string `yaml:"name"` Force bool `yaml:"force"` } +func NewCreateNode(ui cli.Ui) *CreateNode { + c := &CreateNode{Ui: ui} + + c.Flags = flag.NewFlagSet("create_node", flag.ContinueOnError) + c.Flags.StringVar(&c.Config.CACertificatePath, "ca-certificate", "./ca/ca.crt", CaCertFlagUsage) + c.Flags.StringVar(&c.Config.CommonName, "common-name", "eventstoredb-node", "the certificate subject common name") + c.Flags.StringVar(&c.Config.CAKeyPath, "ca-key", "./ca/ca.key", CaKeyFlagUsage) + c.Flags.StringVar(&c.Config.IPAddresses, "ip-addresses", "", "comma-separated list of IP addresses of the node") + c.Flags.StringVar(&c.Config.DNSNames, "dns-names", "", "comma-separated list of DNS names of the node") + c.Flags.IntVar(&c.Config.Days, "days", 0, DayFlagUsage) + c.Flags.StringVar(&c.Config.OutputDir, "out", "", OutDirFlagUsage) + c.Flags.StringVar(&c.Config.Name, "name", "node", NameFlagUsage) + c.Flags.BoolVar(&c.Config.Force, "force", false, ForceFlagUsage) + return c +} + func parseIPAddresses(ipAddresses string) ([]net.IP, error) { if len(ipAddresses) == 0 { return []net.IP{}, nil @@ -62,7 +79,7 @@ func parseDNSNames(dnsNames string) ([]string, error) { return dns, nil } -func getNodeOutputDirectory() (string, error) { +func getOutputDirectory() (string, error) { for i := 1; i <= 100; i++ { dir := "node" + strconv.Itoa(i) if _, err := os.Stat(dir); os.IsNotExist(err) { @@ -73,38 +90,26 @@ func getNodeOutputDirectory() (string, error) { } func (c *CreateNode) Run(args []string) int { - var config CreateNodeArguments - - flags := flag.NewFlagSet("create_node", flag.ContinueOnError) - flags.Usage = func() { c.Ui.Info(c.Help()) } - flags.StringVar(&config.CACertificatePath, "ca-certificate", "./ca/ca.crt", "the path to the CA certificate file") - flags.StringVar(&config.CommonName, "common-name", "eventstoredb-node", "the certificate subject common name") - flags.StringVar(&config.CAKeyPath, "ca-key", "./ca/ca.key", "the path to the CA key file") - flags.StringVar(&config.IPAddresses, "ip-addresses", "", "comma-separated list of IP addresses of the node") - flags.StringVar(&config.DNSNames, "dns-names", "", "comma-separated list of DNS names of the node") - flags.IntVar(&config.Days, "days", 0, "the validity period of the certificate in days") - flags.StringVar(&config.OutputDir, "out", "", "The output directory") - flags.BoolVar(&config.Force, "force", false, forceOption) - - if err := flags.Parse(args); err != nil { + if err := c.Flags.Parse(args); err != nil { + c.Ui.Error(err.Error()) return 1 } validationErrors := new(multierror.Error) - if len(config.CACertificatePath) == 0 { - multierror.Append(validationErrors, errors.New("ca-certificate is a required field")) + if len(c.Config.CACertificatePath) == 0 { + _ = multierror.Append(validationErrors, errors.New("ca-certificate is a required field")) } - if len(config.CAKeyPath) == 0 { - multierror.Append(validationErrors, errors.New("ca-key is a required field")) + if len(c.Config.CAKeyPath) == 0 { + _ = multierror.Append(validationErrors, errors.New("ca-key is a required field")) } - if len(config.IPAddresses) == 0 && len(config.DNSNames) == 0 { - multierror.Append(validationErrors, errors.New("at least one IP address or DNS name needs to be specified with --ip-addresses or --dns-names")) + if len(c.Config.IPAddresses) == 0 && len(c.Config.DNSNames) == 0 { + _ = multierror.Append(validationErrors, errors.New("at least one IP address or DNS name needs to be specified with --ip-addresses or --dns-names")) } - if config.Days < 0 { - multierror.Append(validationErrors, errors.New("days must be positive")) + if c.Config.Days < 0 { + _ = multierror.Append(validationErrors, errors.New("days must be positive")) } if validationErrors.ErrorOrNil() != nil { @@ -112,54 +117,49 @@ func (c *CreateNode) Run(args []string) int { return 1 } - caCert, err := readCertificateFromFile(config.CACertificatePath) + caCert, err := readCertificateFromFile(c.Config.CACertificatePath) if err != nil { c.Ui.Error(err.Error()) return 1 } - caKey, err := readRSAKeyFromFile(config.CAKeyPath) + caKey, err := readRSAKeyFromFile(c.Config.CAKeyPath) if err != nil { err := fmt.Errorf("error: %s. please note that only RSA keys are currently supported", err.Error()) c.Ui.Error(err.Error()) return 1 } - ips, err := parseIPAddresses(config.IPAddresses) + ips, err := parseIPAddresses(c.Config.IPAddresses) if err != nil { c.Ui.Error(err.Error()) return 1 } - dnsNames, err := parseDNSNames(config.DNSNames) + dnsNames, err := parseDNSNames(c.Config.DNSNames) if err != nil { c.Ui.Error(err.Error()) return 1 } - outputDir := config.OutputDir - outputBaseFileName := "node" + outputDir := c.Config.OutputDir + outputBaseFileName := c.Config.Name if len(outputDir) == 0 { - outputDir, err = getNodeOutputDirectory() + outputDir, err = getOutputDirectory() if err != nil { c.Ui.Error(err.Error()) return 1 } - outputBaseFileName = outputDir } - // check if certificates already exist - keyPath := path.Join(config.OutputDir, fmt.Sprintf("%s.key", outputBaseFileName)) - crtPath := path.Join(config.OutputDir, fmt.Sprintf("%s.crt", outputBaseFileName)) - - if fileExists(keyPath, config.Force) { - c.Ui.Error(ErrFileExists) - return 1 + if len(outputBaseFileName) == 0 { + outputBaseFileName = outputDir } - if fileExists(crtPath, config.Force) { - c.Ui.Error(ErrFileExists) + certErr := checkCertificatesLocationWithForce(outputDir, outputBaseFileName, c.Config.Force) + if certErr != nil { + c.Ui.Error(certErr.Error()) return 1 } @@ -167,12 +167,12 @@ func (c *CreateNode) Run(args []string) int { years := 1 days := 0 - if config.Days != 0 { - days = config.Days + if c.Config.Days != 0 { + days = c.Config.Days years = 0 } - err = generateNodeCertificate(caCert, caKey, ips, dnsNames, years, days, outputDir, outputBaseFileName, config.CommonName, config.Force) + err = generateNodeCertificate(caCert, caKey, ips, dnsNames, years, days, outputDir, outputBaseFileName, c.Config.CommonName, c.Config.Force) if err != nil { c.Ui.Error(err.Error()) return 1 @@ -187,7 +187,18 @@ func (c *CreateNode) Run(args []string) int { return 0 } -func generateNodeCertificate(caCert *x509.Certificate, caPrivateKey *rsa.PrivateKey, ips []net.IP, dnsNames []string, years int, days int, outputDir string, outputBaseFileName string, commonName string, force bool) error { +func generateNodeCertificate( + caCert *x509.Certificate, + caPrivateKey *rsa.PrivateKey, + ips []net.IP, + dnsNames []string, + years int, + days int, + outputDir string, + outputBaseFileName string, + commonName string, + force bool, +) error { serialNumber, err := generateSerialNumber(128) if err != nil { return fmt.Errorf("could not generate 128-bit serial number: %s", err.Error()) @@ -219,7 +230,7 @@ func generateNodeCertificate(caCert *x509.Certificate, caPrivateKey *rsa.Private } privateKeyPem := new(bytes.Buffer) - pem.Encode(privateKeyPem, &pem.Block{ + err = pem.Encode(privateKeyPem, &pem.Block{ Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey), }) @@ -247,28 +258,10 @@ func generateNodeCertificate(caCert *x509.Certificate, caPrivateKey *rsa.Private } func (c *CreateNode) Help() string { - var buffer bytes.Buffer - - w := tabwriter.NewWriter(&buffer, 0, 0, 2, ' ', 0) // 2 spaces minimum gap between columns - - fmt.Fprintln(w, "Usage: create_node [options]") - fmt.Fprintln(w, c.Synopsis()) - fmt.Fprintln(w, "Options:") - - writeHelpOption(w, "ca-certificate", "The path to the CA certificate file (default: ./ca/ca.crt).") - writeHelpOption(w, "ca-key", "The path to the CA key file (default: ./ca/ca.key).") - writeHelpOption(w, "days", "The validity period of the certificates in days (default: 1 year).") - writeHelpOption(w, "out", "The output directory (default: ./nodeX where X is an auto-generated number).") - writeHelpOption(w, "ip-addresses", "Comma-separated list of IP addresses of the node.") - writeHelpOption(w, "dns-names", "Comma-separated list of DNS names of the node.") - writeHelpOption(w, "common-name", "The certificate subject common name.") - writeHelpOption(w, "force", forceOption) - - fmt.Fprintln(w, "\nAt least one IP address or DNS name needs to be specified.") - - w.Flush() - - return strings.TrimSpace(buffer.String()) + var helpText bytes.Buffer + c.Flags.SetOutput(&helpText) + c.Flags.PrintDefaults() + return helpText.String() } func (c *CreateNode) Synopsis() string { diff --git a/certificates/create_node_test.go b/certificates/create_node_test.go index 278cc01..b8f2970 100644 --- a/certificates/create_node_test.go +++ b/certificates/create_node_test.go @@ -1,101 +1,200 @@ package certificates import ( + "bytes" "crypto/x509" - "path" - "testing" - + "fmt" + "github.com/mitchellh/cli" "github.com/stretchr/testify/assert" + "os" + "path/filepath" + "testing" + "time" ) -func setupTestEnvironmentForNodeTests(t *testing.T) (years int, days int, outputDirCa string, outputDirNode string, nodeCertFileName string, ipAddresses string, commonName string, dnsNames []string) { - years = 1 - days = 0 - outputDirCa = "./ca" - outputDirNode = "./node" - nodeCertFileName = "node" - ipAddresses = "127.0.0.1" - commonName = "EventStoreDB" - dnsNames = []string{"localhost"} - - cleanupDirsForTest(t, outputDirCa, outputDirNode) - return +func TestCreateNodeCertificate(t *testing.T) { + t.Run("TestCreateNodeCertificate_WithoutParams_ShouldFail", TestCreateNodeCertificate_WithoutParams_ShouldFail) + t.Run("TestCreateNodeCertificate_WithAllRequiredParams_ShouldSucceed", TestCreateNodeCertificate_WithAllRequiredParams_ShouldSucceed) + t.Run("TestCreateNodeCertificate_WithNameFlagAndOutput_ShouldCreateNamedCertificate", TestCreateNodeCertificate_WithNameFlagAndOutput_ShouldCreateNamedCertificate) + t.Run("TestCreateNodeCertificate_WithNameFlagWithoutOutput_ShouldCreateNamedCertificate", TestCreateNodeCertificate_WithNameFlagWithoutOutput_ShouldCreateNamedCertificate) + t.Run("TestCreateNodeCertificate_WithForceFlag_ShouldRegenerate", TestCreateNodeCertificate_WithForceFlag_ShouldRegenerate) } -func TestGenerateNodeCertificate(t *testing.T) { +func TestCreateNodeCertificate_WithoutParams_ShouldFail(t *testing.T) { + t.Parallel() + + cleanup, _, _, _, errorBuffer, createNode := setupCreateNodeTestEnvironment(t) + defer cleanup() - t.Run("nominal-case", func(t *testing.T) { - years, days, outputDirCa, outputDirNode, nodeCertFileName, ipAddresses, commonName, dnsNames := setupTestEnvironmentForNodeTests(t) + var args []string + result := createNode.Run(args) + assert.Equal(t, 1, result, "The 'create-node' operation should fail due to the absence of required parameters.") - caCertificate, caPrivateKey := generateAndAssertCACert(t, years, days, outputDirCa, false) - ips, err := parseIPAddresses(ipAddresses) - assert.NoError(t, err) + errors := extractErrors(errorBuffer.String()) + assert.Equal(t, 1, len(errors)) + assert.Equal(t, "at least one IP address or DNS name needs to be specified with --ip-addresses or --dns-names", errors[0]) +} + +func TestCreateNodeCertificate_WithAllRequiredParams_ShouldSucceed(t *testing.T) { + t.Parallel() - certificateError := generateNodeCertificate(caCertificate, caPrivateKey, ips, dnsNames, years, days, outputDirNode, nodeCertFileName, commonName, false) - assert.NoError(t, certificateError) + cleanup, tempNodeDir, tempCaDir, _, _, createNode := setupCreateNodeTestEnvironment(t) + defer cleanup() - nodeCertPath := path.Join(outputDirNode, nodeCertFileName+".crt") - nodeKeyPath := path.Join(outputDirNode, nodeCertFileName+".key") - assertFilesExist(t, nodeCertPath, nodeKeyPath) + args := []string{ + "-ca-certificate", filepath.Join(tempCaDir, "ca.crt"), + "-ca-key", filepath.Join(tempCaDir, "ca.key"), + "-out", tempNodeDir, + "-ip-addresses", "127.0.0.1", + "-dns-names", "localhost", + } + if result := createNode.Run(args); result != 0 { + t.Fatalf("Expected 0, got %d", result) + } - nodeCertificate, err := readCertificateFromFile(nodeCertPath) - assert.NoError(t, err) + assert.FileExists(t, filepath.Join(tempNodeDir, "node.crt"), "Node certificate should exist") + assert.FileExists(t, filepath.Join(tempNodeDir, "node.key"), "Node key should exist") - // verify the subject - assert.Equal(t, "CN=EventStoreDB", nodeCertificate.Subject.String()) + cert, err := readCertificateFromFile(filepath.Join(tempNodeDir, "node.crt")) + assert.NoError(t, err, "Failed to read and parse certificate file") - // verify the issuer - assert.Equal(t, caCertificate.Issuer.String(), nodeCertificate.Issuer.String()) + expectedNotAfter := time.Now().AddDate(1, 0, 0) + assert.WithinDuration(t, expectedNotAfter, cert.NotAfter, time.Second, "Certificate validity period does not match expected default") - // verify the EKUs - assert.Equal(t, 2, len(nodeCertificate.ExtKeyUsage)) - assert.Equal(t, x509.ExtKeyUsageClientAuth, nodeCertificate.ExtKeyUsage[0]) - assert.Equal(t, x509.ExtKeyUsageServerAuth, nodeCertificate.ExtKeyUsage[1]) - assert.Equal(t, 0, len(nodeCertificate.UnknownExtKeyUsage)) + caCert, err := readCertificateFromFile(filepath.Join(tempCaDir, "ca.crt")) + assert.NoError(t, err, "Failed to read and parse CA certificate file") - // verify the IP SANs - assert.Equal(t, 1, len(nodeCertificate.IPAddresses)) - assert.Equal(t, "127.0.0.1", nodeCertificate.IPAddresses[0].String()) + roots := x509.NewCertPool() + roots.AddCert(caCert) + + _, err = cert.Verify(x509.VerifyOptions{Roots: roots}) + assert.NoError(t, err, "Node certificate should be signed by the provided root CA") +} - // verify the DNS SANs - assert.Equal(t, 1, len(nodeCertificate.DNSNames)) - assert.Equal(t, "localhost", nodeCertificate.DNSNames[0]) +func TestCreateNodeCertificate_WithNameFlagAndOutput_ShouldCreateNamedCertificate(t *testing.T) { + t.Parallel() + + cleanup, tempNodeDir, tempCaDir, _, _, createNode := setupCreateNodeTestEnvironment(t) + defer cleanup() + + args := []string{ + "-ca-certificate", fmt.Sprintf("%s/ca.crt", tempCaDir), + "-ca-key", fmt.Sprintf("%s/ca.key", tempCaDir), + "-out", tempNodeDir, + "-ip-addresses", "127.0.0.1", + "-dns-names", "localhost", + "-name", "renamed", + } + + result := createNode.Run(args) + assert.Equal(t, 0, result, "The 'create-node' operation should succeed with the --name flag") + + assert.FileExists(t, filepath.Join(tempNodeDir, "renamed.crt"), "Renamed certificate should exist") + assert.FileExists(t, filepath.Join(tempNodeDir, "renamed.key"), "Renamed key should exist") +} + +func TestCreateNodeCertificate_WithNameFlagWithoutOutput_ShouldCreateNamedCertificate(t *testing.T) { + t.Parallel() + + cleanup, tempNodeDir, tempCaDir, _, _, createNode := setupCreateNodeTestEnvironment(t) + defer cleanup() + + originalDir, err := os.Getwd() + if err != nil { + t.Fatalf("Failed to get current directory: %s", err) + } + defer func(dir string) { + err := os.Chdir(dir) + if err != nil { + t.Fatalf("Failed to change to orignal directory: %s", err) + } + }(originalDir) + + if err := os.Chdir(tempNodeDir); err != nil { + t.Fatalf("Failed to change current directory: %s", err) + } + + args := []string{ + "-ca-certificate", fmt.Sprintf("%s/ca.crt", tempCaDir), + "-ca-key", fmt.Sprintf("%s/ca.key", tempCaDir), + "-ip-addresses", "127.0.0.1", + "-name", "renamed_without_output", + } + + result := createNode.Run(args) + assert.Equal(t, 0, result, "The 'create-node' operation should succeed with the --name flag") + + assert.FileExists(t, filepath.Join(tempNodeDir, "node1", "renamed_without_output.crt"), "Renamed certificate should exist") + assert.FileExists(t, filepath.Join(tempNodeDir, "node1", "renamed_without_output.key"), "Renamed key should exist") +} + +func TestCreateNodeCertificate_WithForceFlag_ShouldRegenerate(t *testing.T) { + t.Parallel() + + cleanup, tempNodeDir, tempCaDir, _, _, createNode := setupCreateNodeTestEnvironment(t) + defer cleanup() + + args := []string{ + "-ca-certificate", fmt.Sprintf("%s/ca.crt", tempCaDir), + "-ca-key", fmt.Sprintf("%s/ca.key", tempCaDir), + "-out", tempNodeDir, + "-ip-addresses", "127.0.0.1", + "-dns-names", "localhost", + } + + result := createNode.Run(args) + originalNodeCert, originalNodeKey := readAndDecodeCertificateAndKey(t, tempNodeDir, "node") + + updatedArgs := append(args, "-force") + result = createNode.Run(updatedArgs) + assert.Equal(t, 0, result, "The 'create-node' should override the existing certificate with the --force flag") + + newNodeCert, newNodeKey := readAndDecodeCertificateAndKey(t, tempNodeDir, "node") + + assert.NotEqual(t, originalNodeCert, newNodeCert, "The Node certificate should be different") + assert.NotEqual(t, originalNodeKey, newNodeKey, "The Node key should be different") +} + +func setupCreateNodeTestEnvironment(t *testing.T) (cleanupFunc func(), tempNodeDir, tempCaDir string, outputBuffer *bytes.Buffer, errorBuffer *bytes.Buffer, createNode *CreateNode) { + var err error + + tempNodeDir, err = os.MkdirTemp(os.TempDir(), "node-*") + if err != nil { + t.Fatalf("Failed to create temp directory: %s", err) + } + + tempCaDir, err = os.MkdirTemp(os.TempDir(), "ca-*") + if err != nil { + t.Fatalf("Failed to create temp directory: %s", err) + } + + outputBuffer = new(bytes.Buffer) + errorBuffer = new(bytes.Buffer) + + createNode = NewCreateNode(&cli.BasicUi{ + Writer: outputBuffer, + ErrorWriter: errorBuffer, }) - t.Run("force-flag", func(t *testing.T) { - years, days, outputDirCa, outputDirNode, nodeCertFileName, ipAddresses, commonName, dnsNames := setupTestEnvironmentForNodeTests(t) - - caCertificate, caPrivateKey := generateAndAssertCACert(t, years, days, outputDirCa, false) - ips, err := parseIPAddresses(ipAddresses) - assert.NoError(t, err) - - nodeCertFilePath := path.Join(outputDirNode, nodeCertFileName+".crt") - nodeKeyFilePath := path.Join(outputDirNode, nodeCertFileName+".key") - - generateNodeCertificate(caCertificate, caPrivateKey, ips, dnsNames, years, days, outputDirNode, nodeCertFileName, commonName, false) - nodeCertFile, err := readCertificateFromFile(nodeCertFilePath) - assert.NoError(t, err) - nodeKeyFile, err := readRSAKeyFromFile(nodeKeyFilePath) - assert.NoError(t, err) - - // try to generate again without force - err = generateNodeCertificate(caCertificate, caPrivateKey, ips, dnsNames, years, days, outputDirNode, nodeCertFileName, commonName, false) - assert.Error(t, err) - nodeCertFileAfter, err := readCertificateFromFile(nodeCertFilePath) - assert.NoError(t, err) - nodeKeyFileAfter, err := readRSAKeyFromFile(nodeKeyFilePath) - assert.NoError(t, err) - assert.Equal(t, nodeCertFile, nodeCertFileAfter, "Expected node certificate to be the same") - assert.Equal(t, nodeKeyFile, nodeKeyFileAfter, "Expected node key to be the same") - - // try to generate again with force - err = generateNodeCertificate(caCertificate, caPrivateKey, ips, dnsNames, years, days, outputDirNode, nodeCertFileName, commonName, true) - assert.NoError(t, err) - nodeCertFileAfterWithForce, err := readCertificateFromFile(nodeCertFilePath) - assert.NoError(t, err) - nodeKeyFileAfterWithForce, err := readRSAKeyFromFile(nodeKeyFilePath) - assert.NoError(t, err) - assert.NotEqual(t, nodeCertFileAfter, nodeCertFileAfterWithForce, "Expected node certificate to be different") - assert.NotEqual(t, nodeKeyFileAfter, nodeKeyFileAfterWithForce, "Expected node key to be different") + // We need to create a root CA file to be able to create a node certificate + createCa := NewCreateCA(&cli.BasicUi{ + Writer: new(bytes.Buffer), + ErrorWriter: new(bytes.Buffer), }) + + args := []string{"-out", tempCaDir} + if result := createCa.Run(args); result != 0 { + t.Fatalf("Expected 0, got %d", result) + } + + cleanupFunc = func() { + if err := os.RemoveAll(tempNodeDir); err != nil { + t.Logf("Failed to remove temp node directory (%s): %s", tempNodeDir, err) + } + if err := os.RemoveAll(tempCaDir); err != nil { + t.Logf("Failed to remove temp ca directory (%s): %s", tempCaDir, err) + } + } + + return cleanupFunc, tempNodeDir, tempCaDir, outputBuffer, errorBuffer, createNode } diff --git a/certificates/create_user.go b/certificates/create_user.go index 903e915..d4063bd 100644 --- a/certificates/create_user.go +++ b/certificates/create_user.go @@ -10,9 +10,7 @@ import ( "errors" "flag" "fmt" - "path" - "strings" - "text/tabwriter" + "path/filepath" "time" multierror "github.com/hashicorp/go-multierror" @@ -20,7 +18,9 @@ import ( ) type CreateUser struct { - Ui cli.Ui + Ui cli.Ui + Config CreateUserArguments + Flags *flag.FlagSet } type CreateUserArguments struct { @@ -29,41 +29,48 @@ type CreateUserArguments struct { CAKeyPath string `yaml:"ca-key"` Days int `yaml:"days"` OutputDir string `yaml:"out"` + Name string `yaml:"name"` Force bool `yaml:"force"` } +func NewCreateUser(ui cli.Ui) *CreateUser { + c := &CreateUser{Ui: ui} + + c.Flags = flag.NewFlagSet("create_user", flag.ContinueOnError) + c.Flags.Usage = func() { c.Ui.Info(c.Help()) } + c.Flags.StringVar(&c.Config.Username, "username", "", "the EventStoreDB user") + c.Flags.StringVar(&c.Config.CACertificatePath, "ca-certificate", "./ca/ca.crt", CaCertFlagUsage) + c.Flags.StringVar(&c.Config.CAKeyPath, "ca-key", "./ca/ca.key", CaKeyFlagUsage) + c.Flags.IntVar(&c.Config.Days, "days", 0, DayFlagUsage) + c.Flags.StringVar(&c.Config.OutputDir, "out", "", OutDirFlagUsage) + c.Flags.StringVar(&c.Config.Name, "name", "", NameFlagUsage) + c.Flags.BoolVar(&c.Config.Force, "force", false, ForceFlagUsage) + + return c +} + func (c *CreateUser) Run(args []string) int { - var config CreateUserArguments - - flags := flag.NewFlagSet("create_user", flag.ContinueOnError) - flags.Usage = func() { c.Ui.Info(c.Help()) } - flags.StringVar(&config.Username, "username", "", "the EventStoreDB user") - flags.StringVar(&config.CACertificatePath, "ca-certificate", "./ca/ca.crt", "the path to the CA certificate file") - flags.StringVar(&config.CAKeyPath, "ca-key", "./ca/ca.key", "the path to the CA key file") - flags.IntVar(&config.Days, "days", 0, "the validity period of the certificate in days") - flags.StringVar(&config.OutputDir, "out", "", "The output directory") - flags.BoolVar(&config.Force, "force", false, forceOption) - - if err := flags.Parse(args); err != nil { + if err := c.Flags.Parse(args); err != nil { + c.Ui.Error(err.Error()) return 1 } validationErrors := new(multierror.Error) - if len(config.Username) == 0 { - multierror.Append(validationErrors, errors.New("username is a required field")) + if len(c.Config.Username) == 0 { + _ = multierror.Append(validationErrors, errors.New("username is a required field")) } - if len(config.CACertificatePath) == 0 { - multierror.Append(validationErrors, errors.New("ca-certificate is a required field")) + if len(c.Config.CACertificatePath) == 0 { + _ = multierror.Append(validationErrors, errors.New("ca-certificate is a required field")) } - if len(config.CAKeyPath) == 0 { - multierror.Append(validationErrors, errors.New("ca-key is a required field")) + if len(c.Config.CAKeyPath) == 0 { + _ = multierror.Append(validationErrors, errors.New("ca-key is a required field")) } - if config.Days < 0 { - multierror.Append(validationErrors, errors.New("days must be positive")) + if c.Config.Days < 0 { + _ = multierror.Append(validationErrors, errors.New("days must be positive")) } if validationErrors.ErrorOrNil() != nil { @@ -71,34 +78,32 @@ func (c *CreateUser) Run(args []string) int { return 1 } - caCert, err := readCertificateFromFile(config.CACertificatePath) + caCert, err := readCertificateFromFile(c.Config.CACertificatePath) if err != nil { c.Ui.Error(err.Error()) return 1 } - caKey, err := readRSAKeyFromFile(config.CAKeyPath) + caKey, err := readRSAKeyFromFile(c.Config.CAKeyPath) if err != nil { err := fmt.Errorf("error: %s. please note that only RSA keys are currently supported", err.Error()) c.Ui.Error(err.Error()) return 1 } - outputDir := config.OutputDir - outputBaseFileName := "user-" + config.Username - - if len(outputDir) == 0 { - outputDir = outputBaseFileName + outputDir := c.Config.OutputDir + outputBaseFileName := c.Config.Name + if outputBaseFileName == "" { + outputBaseFileName = "user-" + c.Config.Username } - // check if user certificates already exist - if fileExists(path.Join(outputDir, outputBaseFileName+".crt"), config.Force) { - c.Ui.Error(ErrFileExists) - return 1 + if outputDir == "" { + outputDir = filepath.Dir(outputBaseFileName) } - if fileExists(path.Join(outputDir, outputBaseFileName+".key"), config.Force) { - c.Ui.Error(ErrFileExists) + certErr := checkCertificatesLocationWithForce(outputDir, outputBaseFileName, c.Config.Force) + if certErr != nil { + c.Ui.Error(certErr.Error()) return 1 } @@ -106,12 +111,12 @@ func (c *CreateUser) Run(args []string) int { years := 1 days := 0 - if config.Days != 0 { - days = config.Days + if c.Config.Days != 0 { + days = c.Config.Days years = 0 } - err = generateUserCertificate(config.Username, outputBaseFileName, caCert, caKey, years, days, outputDir, config.Force) + err = generateUserCertificate(c.Config.Username, outputBaseFileName, caCert, caKey, years, days, outputDir, c.Config.Force) if err != nil { c.Ui.Error(err.Error()) return 1 @@ -156,7 +161,7 @@ func generateUserCertificate(username string, outputBaseFileName string, caCert } privateKeyPem := new(bytes.Buffer) - pem.Encode(privateKeyPem, &pem.Block{ + err = pem.Encode(privateKeyPem, &pem.Block{ Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey), }) @@ -184,24 +189,10 @@ func generateUserCertificate(username string, outputBaseFileName string, caCert } func (c *CreateUser) Help() string { - var buffer bytes.Buffer - - w := tabwriter.NewWriter(&buffer, 0, 0, 2, ' ', 0) - - fmt.Fprintln(w, "Usage: create_user [options]") - fmt.Fprintln(w, c.Synopsis()) - fmt.Fprintln(w, "Options:") - - writeHelpOption(w, "username", "The name of the EventStoreDB user to generate a certificate for.") - writeHelpOption(w, "ca-certificate", "The path to the CA certificate file (default: ./ca/ca.crt).") - writeHelpOption(w, "ca-key", "The path to the CA key file (default: ./ca/ca.key).") - writeHelpOption(w, "days", "The validity period of the certificates in days (default: 1 year).") - writeHelpOption(w, "out", "The output directory (default: ./user-).") - writeHelpOption(w, "force", forceOption) - - w.Flush() - - return strings.TrimSpace(buffer.String()) + var helpText bytes.Buffer + c.Flags.SetOutput(&helpText) + c.Flags.PrintDefaults() + return helpText.String() } func (c *CreateUser) Synopsis() string { diff --git a/certificates/create_user_test.go b/certificates/create_user_test.go index f015fdc..0cf1160 100644 --- a/certificates/create_user_test.go +++ b/certificates/create_user_test.go @@ -1,86 +1,202 @@ package certificates import ( + "bytes" "crypto/x509" - "path" + "github.com/mitchellh/cli" + "os" + "path/filepath" "testing" + "time" "github.com/stretchr/testify/assert" ) -func setupTestEnvironmentForUserTests(t *testing.T) (years int, days int, username string, userCertFileName string, outputDirCa string, outputDirUser string) { - years = 1 - days = 0 - username = "bob" - userCertFileName = "user-" + username - outputDirCa = "./ca" - outputDirUser = "./" + userCertFileName +func TestCreateUserCertificate(t *testing.T) { + t.Run("TestCreateUserCertificate_WithoutParams_ShouldFail", TestCreateUserCertificate_WithoutParams_ShouldFail) + t.Run("TestCreateUserCertificate_WithAllRequiredParams_ShouldSucceed", TestCreateUserCertificate_WithAllRequiredParams_ShouldSucceed) + t.Run("TestCreateUserCertificate_WithNegativeDays_ShouldFail", TestCreateUserCertificate_WithNegativeDays_ShouldFail) + t.Run("TestCreateUserCertificate_WithForceFlag_ShouldRegenerate", TestCreateUserCertificate_WithForceFlag_ShouldRegenerate) + t.Run("TestCreateUserCertificate_WithNameFlag_ShouldSucceed", TestCreateUserCertificate_WithNameFlag_ShouldSucceed) +} + +func TestCreateUserCertificate_WithoutParams_ShouldFail(t *testing.T) { + t.Parallel() + + cleanup, _, _, _, errorBuffer, createUser := setupCreateUserTestEnvironment(t) + defer cleanup() + + var args []string + result := createUser.Run(args) + + assert.Equal(t, 1, result, "The 'create-user' operation should fail due to the absence of required parameters.") + + errors := extractErrors(errorBuffer.String()) + assert.Equal(t, 1, len(errors)) + assert.Equal(t, "username is a required field", errors[0]) +} + +func TestCreateUserCertificate_WithAllRequiredParams_ShouldSucceed(t *testing.T) { + t.Parallel() + + cleanup, tempUserDir, tempCaDir, _, _, createUser := setupCreateUserTestEnvironment(t) + defer cleanup() + + username := "ouro" + args := []string{ + "-username", username, + "-ca-certificate", filepath.Join(tempCaDir, "ca.crt"), + "-ca-key", filepath.Join(tempCaDir, "ca.key"), + "-out", tempUserDir, + } + + if result := createUser.Run(args); result != 0 { + t.Fatalf("Expected 0, got %d", result) + } + + userFmt := "user-" + username + userCertPath := filepath.Join(tempUserDir, userFmt+".crt") + userKeyPath := filepath.Join(tempUserDir, userFmt+".key") + + assert.FileExists(t, userCertPath, "User certificate should exist") + assert.FileExists(t, userKeyPath, "User key should exist") + + cert, err := readCertificateFromFile(userCertPath) + assert.NoError(t, err, "Failed to read and parse certificate file") + + expectedNotAfter := time.Now().AddDate(1, 0, 0) + assert.WithinDuration(t, expectedNotAfter, cert.NotAfter, time.Second, "Certificate validity period does not match expected default") + + caCert, err := readCertificateFromFile(filepath.Join(tempCaDir, "ca.crt")) + assert.NoError(t, err, "Failed to read and parse CA certificate file") + + roots := x509.NewCertPool() + roots.AddCert(caCert) + + _, err = cert.Verify(x509.VerifyOptions{Roots: roots, KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}}) + assert.NoError(t, err, "User certificate should be signed by the provided root CA") + assert.Equal(t, username, cert.Subject.CommonName, "The common name of the certificate should be the same as the provided username") +} - cleanupDirsForTest(t, outputDirCa, outputDirUser) - return +func TestCreateUserCertificate_WithNegativeDays_ShouldFail(t *testing.T) { + t.Parallel() + + cleanup, _, tempCaDir, _, errorBuffer, createUser := setupCreateUserTestEnvironment(t) + defer cleanup() + + args := []string{ + "-username", "ouro", + "-ca-certificate", filepath.Join(tempCaDir, "ca.crt"), + "-ca-key", filepath.Join(tempCaDir, "ca.key"), + "-days", "-1", + } + result := createUser.Run(args) + + assert.Equal(t, 1, result, "The 'create-user' operation should fail when days is negative.") + + errors := extractErrors(errorBuffer.String()) + assert.Equal(t, 1, len(errors)) + assert.Equal(t, "days must be positive", errors[0]) } -func TestGenerateUserCertificate(t *testing.T) { +func TestCreateUserCertificate_WithForceFlag_ShouldRegenerate(t *testing.T) { + t.Parallel() + + cleanup, tempUserDir, tempCaDir, _, _, createUser := setupCreateUserTestEnvironment(t) + defer cleanup() - t.Run("nominal-case", func(t *testing.T) { - years, days, username, userCertFileName, outputDirCa, outputDirUser := setupTestEnvironmentForUserTests(t) + username := "ouro" + args := []string{ + "-username", username, + "-ca-certificate", filepath.Join(tempCaDir, "ca.crt"), + "-ca-key", filepath.Join(tempCaDir, "ca.key"), + "-out", tempUserDir, + } - caCertificate, caPrivateKey := generateAndAssertCACert(t, years, days, outputDirCa, false) + result := createUser.Run(args) + + userFmt := "user-" + username + originalUserCert, originalUserKey := readAndDecodeCertificateAndKey(t, tempUserDir, userFmt) + + updatedArgs := append(args, "-force") + result = createUser.Run(updatedArgs) + assert.Equal(t, 0, result, "The 'create-user' should override the existing certificate with the --force flag") + + newUserCert, newUserKey := readAndDecodeCertificateAndKey(t, tempUserDir, userFmt) + + assert.NotEqual(t, originalUserCert, newUserCert, "The User certificate should be different") + assert.NotEqual(t, originalUserKey, newUserKey, "The User key should be different") +} - err := generateUserCertificate(username, userCertFileName, caCertificate, caPrivateKey, years, days, outputDirUser, false) - assert.NoError(t, err) +func TestCreateUserCertificate_WithNameFlag_ShouldSucceed(t *testing.T) { + t.Parallel() - userCertPath := path.Join(outputDirUser, userCertFileName+".crt") - userKeyPath := path.Join(outputDirUser, userCertFileName+".key") - assertFilesExist(t, userCertPath, userKeyPath) + cleanup, tempUserDir, tempCaDir, _, _, createUser := setupCreateUserTestEnvironment(t) + defer cleanup() - userCertificate, _ := readCertificateFromFile(userCertPath) + username := "ouro" + name := "testing" + args := []string{ + "-username", username, + "-name", name, + "-ca-certificate", filepath.Join(tempCaDir, "ca.crt"), + "-ca-key", filepath.Join(tempCaDir, "ca.key"), + "-out", tempUserDir, + } - // verify the subject - assert.Equal(t, "CN="+username, userCertificate.Subject.String()) + result := createUser.Run(args) - // verify the issuer - assert.Equal(t, caCertificate.Issuer.String(), userCertificate.Issuer.String()) + assert.Equal(t, 0, result, "The 'create-user' create the certificates with the provided name") - // verify the EKUs - assert.Equal(t, 1, len(userCertificate.ExtKeyUsage)) - assert.Equal(t, x509.ExtKeyUsageClientAuth, userCertificate.ExtKeyUsage[0]) - assert.Equal(t, 0, len(userCertificate.UnknownExtKeyUsage)) + assert.FileExists(t, filepath.Join(tempUserDir, name+".crt"), "User certificate should exist") + assert.FileExists(t, filepath.Join(tempUserDir, name+".key"), "User key should exist") + + cert, err := readCertificateFromFile(filepath.Join(tempUserDir, name+".crt")) + assert.NoError(t, err, "Failed to read and parse certificate file") + + assert.Equal(t, username, cert.Subject.CommonName, "The common name of the certificate should be the same as the provided username") +} + +func setupCreateUserTestEnvironment(t *testing.T) (cleanupFunc func(), tempUserDir string, tempCaDir string, outputBuffer *bytes.Buffer, errorBuffer *bytes.Buffer, createUser *CreateUser) { + var err error + + tempUserDir, err = os.MkdirTemp(os.TempDir(), "user-*") + if err != nil { + t.Fatalf("Failed to create temp directory: %s", err) + } + + tempCaDir, err = os.MkdirTemp(os.TempDir(), "ca-*") + if err != nil { + t.Fatalf("Failed to create temp directory: %s", err) + } + + outputBuffer = new(bytes.Buffer) + errorBuffer = new(bytes.Buffer) + + createUser = NewCreateUser(&cli.BasicUi{ + Writer: outputBuffer, + ErrorWriter: errorBuffer, }) - t.Run("force-flag", func(t *testing.T) { - years, days, username, userCertFileName, outputDirCa, outputDirUser := setupTestEnvironmentForUserTests(t) - - caCertificate, caPrivateKey := generateAndAssertCACert(t, years, days, outputDirCa, false) - - err := generateUserCertificate(username, userCertFileName, caCertificate, caPrivateKey, years, days, outputDirUser, false) - assert.NoError(t, err) - - userCertPath := path.Join(outputDirUser, userCertFileName+".crt") - userKeyPath := path.Join(outputDirUser, userCertFileName+".key") - assertFilesExist(t, userCertPath, userKeyPath) - - userCertificate, _ := readCertificateFromFile(userCertPath) - userCertificateKey, _ := readRSAKeyFromFile(userKeyPath) - - // try to generate again without force - err = generateUserCertificate(username, userCertFileName, caCertificate, caPrivateKey, years, days, outputDirUser, false) - assert.Error(t, err) - userCertificateAfter, err := readCertificateFromFile(userCertPath) - assert.NoError(t, err) - userCertificateKeyAfter, err := readRSAKeyFromFile(userKeyPath) - assert.NoError(t, err) - assert.Equal(t, userCertificate, userCertificateAfter, "Expected user certificate to be the same") - assert.Equal(t, userCertificateKey, userCertificateKeyAfter, "Expected user key to be the same") - - // try to generate again with force - err = generateUserCertificate(username, userCertFileName, caCertificate, caPrivateKey, years, days, outputDirUser, true) - assert.NoError(t, err) - userCertificateAfterWithForce, err := readCertificateFromFile(userCertPath) - assert.NoError(t, err) - userCertificateKeyAfterWithForce, err := readRSAKeyFromFile(userKeyPath) - assert.NoError(t, err) - assert.NotEqual(t, userCertificate, userCertificateAfterWithForce, "Expected user certificate to be different") - assert.NotEqual(t, userCertificateKey, userCertificateKeyAfterWithForce, "Expected user key to be different") + // We need to create a root CA file to be able to create a user certificate + createCa := NewCreateCA(&cli.BasicUi{ + Writer: new(bytes.Buffer), + ErrorWriter: new(bytes.Buffer), }) + + args := []string{"-out", tempCaDir} + if result := createCa.Run(args); result != 0 { + t.Fatalf("Expected 0, got %d", result) + } + + cleanupFunc = func() { + if err := os.RemoveAll(tempUserDir); err != nil { + t.Logf("Failed to remove temp user directory (%s): %s", tempUserDir, err) + } + if err := os.RemoveAll(tempCaDir); err != nil { + t.Logf("Failed to remove temp ca directory (%s): %s", tempCaDir, err) + } + } + + return cleanupFunc, tempUserDir, tempCaDir, outputBuffer, errorBuffer, createUser } diff --git a/main.go b/main.go index 54b3e7a..b9c5d0c 100644 --- a/main.go +++ b/main.go @@ -29,7 +29,10 @@ func main() { flags := flag.NewFlagSet("config", flag.ContinueOnError) if !c.IsVersion() && !c.IsHelp() { - flags.Parse(os.Args[1:]) + err := flags.Parse(os.Args[1:]) + if err != nil { + ui.Error(err.Error()) + } args = flags.Args() } @@ -38,36 +41,32 @@ func main() { c.Commands = map[string]cli.CommandFactory{ "create-ca": func() (cli.Command, error) { - return &certificates.CreateCA{ - Ui: &cli.ColoredUi{ - Ui: ui, - OutputColor: cli.UiColorBlue, - }, - }, nil + return certificates.NewCreateCA(&cli.ColoredUi{ + Ui: ui, + OutputColor: cli.UiColorBlue, + }), nil }, "create-node": func() (cli.Command, error) { - return &certificates.CreateNode{ - Ui: &cli.ColoredUi{ + return certificates.NewCreateNode( + &cli.ColoredUi{ Ui: ui, OutputColor: cli.UiColorBlue, }, - }, nil + ), nil }, "create-certs": func() (cli.Command, error) { - return &certificates.CreateCertificates{ - Ui: &cli.ColoredUi{ - Ui: ui, - OutputColor: cli.UiColorBlue, - }, - }, nil + return certificates.NewCreateCerts(&cli.ColoredUi{ + Ui: ui, + OutputColor: cli.UiColorBlue, + }, + ), nil }, "create-user": func() (cli.Command, error) { - return &certificates.CreateUser{ - Ui: &cli.ColoredUi{ - Ui: ui, - OutputColor: cli.UiColorBlue, - }, - }, nil + return certificates.NewCreateUser(&cli.ColoredUi{ + Ui: ui, + OutputColor: cli.UiColorBlue, + }, + ), nil }, } c.HelpFunc = createGeneralHelpFunc(appName, flags) diff --git a/references/certs.yml b/references/certs.yml new file mode 100644 index 0000000..8b19725 --- /dev/null +++ b/references/certs.yml @@ -0,0 +1,28 @@ +certificates: + ca-certs: + - out: "./root_ca" + - out: "./intermediate_ca" + ca-certificate: "./root_ca/ca.crt" + ca-key: "./root_ca/ca.key" + days: 5 + node-certs: + - out: "./node1" + ca-certificate: "./intermediate_ca/ca.crt" + ca-key: "./intermediate_ca/ca.key" + ip-addresses: "127.0.0.1,172.20.240.1" + dns-names: "localhost,eventstore-node1.localhost.com" + - out: "./node2" + ca-certificate: "./intermediate_ca/ca.crt" + ca-key: "./intermediate_ca/ca.key" + ip-addresses: "127.0.0.2,172.20.240.2" + dns-names: "localhost,eventstore-node2.localhost.com" + - out: "./node3" + ca-certificate: "./intermediate_ca/ca.crt" + ca-key: "./intermediate_ca/ca.key" + ip-addresses: "127.0.0.3,172.20.240.3" + dns-names: "localhost,eventstore-node2.localhost.com" + user-certs: + - out: "./user-admin" + username: "admin" + ca-certificate: "./root_ca/ca.crt" + ca-key: "./root_ca/ca.key" diff --git a/references/named_certs.yml b/references/named_certs.yml new file mode 100644 index 0000000..af06aa8 --- /dev/null +++ b/references/named_certs.yml @@ -0,0 +1,34 @@ +certificates: + ca-certs: + - out: "./root_ca" + name: "root" + - out: "./intermediate_ca" + name: "intermediate" + ca-certificate: "./root_ca/root.crt" + ca-key: "./root_ca/root.key" + days: 5 + node-certs: + - out: "./node1" + name: "node1" + ca-certificate: "./intermediate_ca/intermediate.crt" + ca-key: "./intermediate_ca/intermediate.key" + ip-addresses: "127.0.0.1,172.20.240.1" + dns-names: "localhost,eventstore-node1.localhost.com" + - out: "./node2" + name: "node2" + ca-certificate: "./intermediate_ca/intermediate.crt" + ca-key: "./intermediate_ca/intermediate.key" + ip-addresses: "127.0.0.2,172.20.240.2" + dns-names: "localhost,eventstore-node2.localhost.com" + - out: "./node3" + name: "node3" + ca-certificate: "./intermediate_ca/intermediate.crt" + ca-key: "./intermediate_ca/intermediate.key" + ip-addresses: "127.0.0.3,172.20.240.3" + dns-names: "localhost,eventstore-node2.localhost.com" + user-certs: + - out: "./user-admin" + username: "admin" + name: "admin" + ca-certificate: "./root_ca/root.crt" + ca-key: "./root_ca/root.key" \ No newline at end of file