Skip to content

Commit

Permalink
internal/argon2: return errors
Browse files Browse the repository at this point in the history
We're currently wrapping x/crypto/argon2 which panics on some invalid
arguments. Instead, move the checking of these arguments in order to
avoid some panics.
  • Loading branch information
chrisccoulson committed Sep 10, 2024
1 parent ceee2eb commit 96d8049
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 30 deletions.
25 changes: 7 additions & 18 deletions argon2.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,9 @@ type Argon2CostParams struct {
}

func (p *Argon2CostParams) internalParams() *argon2.CostParams {
if p == nil {
return nil
}
return &argon2.CostParams{
Time: p.Time,
MemoryKiB: p.MemoryKiB,
Expand All @@ -228,33 +231,19 @@ type Argon2KDF interface {
type inProcessArgon2KDFImpl struct{}

func (_ inProcessArgon2KDFImpl) Derive(passphrase string, salt []byte, mode Argon2Mode, params *Argon2CostParams, keyLen uint32) ([]byte, error) {
switch {
case mode != Argon2i && mode != Argon2id:
if mode != Argon2i && mode != Argon2id {
return nil, errors.New("invalid mode")
case params == nil:
return nil, errors.New("nil params")
case params.Time == 0:
return nil, errors.New("invalid time cost")
case params.Threads == 0:
return nil, errors.New("invalid number of threads")
}

return argon2.Key(passphrase, salt, argon2.Mode(mode), params.internalParams(), keyLen), nil
return argon2.Key(passphrase, salt, argon2.Mode(mode), params.internalParams(), keyLen)
}

func (_ inProcessArgon2KDFImpl) Time(mode Argon2Mode, params *Argon2CostParams) (time.Duration, error) {
switch {
case mode != Argon2i && mode != Argon2id:
if mode != Argon2i && mode != Argon2id {
return 0, errors.New("invalid mode")
case params == nil:
return 0, errors.New("nil params")
case params.Time == 0:
return 0, errors.New("invalid time cost")
case params.Threads == 0:
return 0, errors.New("invalid number of threads")
}

return argon2.KeyDuration(argon2.Mode(mode), params.internalParams()), nil
return argon2.KeyDuration(argon2.Mode(mode), params.internalParams())
}

// InProcessArgon2KDF is the in-process implementation of the Argon2 KDF. This
Expand Down
8 changes: 7 additions & 1 deletion argon2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,11 @@ func (s *argon2Suite) TestInProcessKDFTimeInvalidThreads(c *C) {
c.Check(err, ErrorMatches, `invalid number of threads`)
}

func (s *argon2Suite) TestModeConstants(c *C) {
c.Check(Argon2i, Equals, Argon2Mode(argon2.ModeI))
c.Check(Argon2id, Equals, Argon2Mode(argon2.ModeID))
}

type argon2SuiteExpensive struct{}

func (s *argon2SuiteExpensive) SetUpSuite(c *C) {
Expand All @@ -300,10 +305,11 @@ func (s *argon2SuiteExpensive) testInProcessKDFDerive(c *C, data *testInProcessA
c.Check(err, IsNil)
runtime.GC()

expected := argon2.Key(data.passphrase, data.salt, argon2.Mode(data.mode), &argon2.CostParams{
expected, err := argon2.Key(data.passphrase, data.salt, argon2.Mode(data.mode), &argon2.CostParams{
Time: data.params.Time,
MemoryKiB: data.params.MemoryKiB,
Threads: data.params.Threads}, data.keyLen)
c.Check(err, IsNil)
runtime.GC()

c.Check(key, DeepEquals, expected)
Expand Down
22 changes: 16 additions & 6 deletions internal/argon2/argon2.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,10 +301,10 @@ func (c *benchmarkContext) run(params *BenchmarkParams, keyFn KeyDurationFunc, s
//
// By design, this function consumes a lot of memory depending on the supplied
// parameters. It may be desirable to execute it in a short-lived utility process.
func KeyDuration(mode Mode, params *CostParams) time.Duration {
func KeyDuration(mode Mode, params *CostParams) (time.Duration, error) {
start := time.Now()
Key(benchmarkPassword, benchmarkSalt, mode, params, benchmarkKeyLen)
return time.Now().Sub(start)
_, err := Key(benchmarkPassword, benchmarkSalt, mode, params, benchmarkKeyLen)
return time.Now().Sub(start), err
}

// KeyDurationFunc provides a mechanism to delegate key derivation measurements
Expand Down Expand Up @@ -349,7 +349,17 @@ func Benchmark(params *BenchmarkParams, keyFn KeyDurationFunc) (*CostParams, err
// By design, this function consumes a lot of memory depending on the supplied parameters.
// It may be desirable to execute it in a short-lived utility process.
//
// This will panic if the time or threads cost parameter are zero.
func Key(passphrase string, salt []byte, mode Mode, params *CostParams, keyLen uint32) []byte {
return mode.keyFn()([]byte(passphrase), salt, params.Time, params.MemoryKiB, params.Threads, keyLen)
// This will return an error if the time or threads cost parameter are zero. If the memory
// cost is less than the minimum (8KiB per thread) for the specified number of threads, it
// will be rounded up to the minimum accordingly.
func Key(passphrase string, salt []byte, mode Mode, params *CostParams, keyLen uint32) ([]byte, error) {
switch {
case params == nil:
return nil, errors.New("nil params")
case params.Time == 0:
return nil, errors.New("invalid time cost")
case params.Threads == 0:
return nil, errors.New("invalid number of threads")
}
return mode.keyFn()([]byte(passphrase), salt, params.Time, params.MemoryKiB, params.Threads, keyLen), nil
}
57 changes: 52 additions & 5 deletions internal/argon2/argon2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,36 @@ func (s *argon2Suite) TestBenchmarkNoProgress(c *C) {
c.Check(err, ErrorMatches, "not making sufficient progress")
}

func (s *argon2Suite) TestKeyDurationNilParams(c *C) {
_, err := KeyDuration(ModeID, nil)
c.Check(err, ErrorMatches, `nil params`)
}

func (s *argon2Suite) TestKeyDurationInvalidTime(c *C) {
_, err := KeyDuration(ModeID, &CostParams{Threads: 1})
c.Check(err, ErrorMatches, `invalid time cost`)
}

func (s *argon2Suite) TestKeyDurationInvalidThreads(c *C) {
_, err := KeyDuration(ModeID, &CostParams{Time: 1})
c.Check(err, ErrorMatches, `invalid number of threads`)
}

func (s *argon2Suite) TestKeyNilParams(c *C) {
_, err := Key("foo", nil, ModeID, nil, 32)
c.Check(err, ErrorMatches, `nil params`)
}

func (s *argon2Suite) TestKeyInvalidTime(c *C) {
_, err := Key("foo", nil, ModeID, &CostParams{Threads: 1}, 32)
c.Check(err, ErrorMatches, `invalid time cost`)
}

func (s *argon2Suite) TestKeyInvalidThreads(c *C) {
_, err := Key("foo", nil, ModeID, &CostParams{Time: 1}, 32)
c.Check(err, ErrorMatches, `invalid number of threads`)
}

type argon2SuiteExpensive struct{}

var _ = Suite(&argon2SuiteExpensive{})
Expand Down Expand Up @@ -360,7 +390,8 @@ func (s *argon2SuiteExpensive) testKey(c *C, data *testKeyData) {
data.params.Threads = maxThreads
}

key := Key(data.passphrase, salt, data.mode, data.params, data.keyLen)
key, err := Key(data.passphrase, salt, data.mode, data.params, data.keyLen)
c.Check(err, IsNil)

var expectedKey []byte
switch data.mode {
Expand Down Expand Up @@ -456,23 +487,39 @@ func (s *argon2SuiteExpensive) TestKey7(c *C) {
keyLen: 32})
}

func (s *argon2SuiteExpensive) TestKey8(c *C) {
s.testKey(c, &testKeyData{
mode: ModeI,
passphrase: "ubuntu",
saltLen: 16,
params: &CostParams{
Time: 4,
MemoryKiB: 0,
Threads: 4},
keyLen: 32})
}

func (s *argon2SuiteExpensive) TestKeyDuration(c *C) {
time1 := KeyDuration(ModeID, &CostParams{Time: 4, MemoryKiB: 32 * 1024, Threads: 4})
time1, err := KeyDuration(ModeID, &CostParams{Time: 4, MemoryKiB: 32 * 1024, Threads: 4})
c.Check(err, IsNil)
runtime.GC()

time2 := KeyDuration(ModeID, &CostParams{Time: 16, MemoryKiB: 32 * 1024, Threads: 4})
time2, err := KeyDuration(ModeID, &CostParams{Time: 16, MemoryKiB: 32 * 1024, Threads: 4})
c.Check(err, IsNil)
runtime.GC()
// XXX: this needs a checker like go-tpm2/testutil's IntGreater, which copes with
// types of int64 kind
c.Check(time2 > time1, testutil.IsTrue)

time2 = KeyDuration(ModeID, &CostParams{Time: 4, MemoryKiB: 128 * 1024, Threads: 4})
time2, err = KeyDuration(ModeID, &CostParams{Time: 4, MemoryKiB: 128 * 1024, Threads: 4})
c.Check(err, IsNil)
runtime.GC()
// XXX: this needs a checker like go-tpm2/testutil's IntGreater, which copes with
// types of int64 kind
c.Check(time2 > time1, testutil.IsTrue)

time2 = KeyDuration(ModeID, &CostParams{Time: 4, MemoryKiB: 32 * 1024, Threads: 1})
time2, err = KeyDuration(ModeID, &CostParams{Time: 4, MemoryKiB: 32 * 1024, Threads: 1})
c.Check(err, IsNil)
runtime.GC()
// XXX: this needs a checker like go-tpm2/testutil's IntGreater, which copes with
// types of int64 kind
Expand Down

0 comments on commit 96d8049

Please sign in to comment.