Skip to content

Commit

Permalink
create and use CSRSignerContext in the server and middleware (#220)
Browse files Browse the repository at this point in the history
  • Loading branch information
jessepeterson authored Aug 24, 2023
1 parent 699e8df commit 9f27f76
Show file tree
Hide file tree
Showing 9 changed files with 63 additions and 26 deletions.
7 changes: 4 additions & 3 deletions challenge/challenge.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package challenge

import (
"context"
"crypto/x509"
"errors"

Expand All @@ -16,8 +17,8 @@ type Store interface {
}

// Middleware wraps next in a CSRSigner that verifies and invalidates the challenge
func Middleware(store Store, next scepserver.CSRSigner) scepserver.CSRSignerFunc {
return func(m *scep.CSRReqMessage) (*x509.Certificate, error) {
func Middleware(store Store, next scepserver.CSRSignerContext) scepserver.CSRSignerContextFunc {
return func(ctx context.Context, m *scep.CSRReqMessage) (*x509.Certificate, error) {
// TODO: compare challenge only for PKCSReq?
valid, err := store.HasChallenge(m.ChallengePassword)
if err != nil {
Expand All @@ -26,6 +27,6 @@ func Middleware(store Store, next scepserver.CSRSigner) scepserver.CSRSignerFunc
if !valid {
return nil, errors.New("invalid challenge")
}
return next.SignCSR(m)
return next.SignCSRContext(ctx, m)
}
}
7 changes: 5 additions & 2 deletions challenge/challenge_bolt_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package challenge

import (
"context"
"io/ioutil"
"os"
"testing"
Expand Down Expand Up @@ -69,12 +70,14 @@ func TestDynamicChallenge(t *testing.T) {
ChallengePassword: challengePassword,
}

_, err = signer.SignCSR(csrReq)
ctx := context.Background()

_, err = signer.SignCSRContext(ctx, csrReq)
if err != nil {
t.Error(err)
}

_, err = signer.SignCSR(csrReq)
_, err = signer.SignCSRContext(ctx, csrReq)
if err == nil {
t.Error("challenge should not be valid twice")
}
Expand Down
4 changes: 2 additions & 2 deletions cmd/scepserver/scepserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,9 @@ func main() {
if *flSignServerAttrs {
signerOpts = append(signerOpts, scepdepot.WithSeverAttrs())
}
var signer scepserver.CSRSigner = scepdepot.NewSigner(depot, signerOpts...)
var signer scepserver.CSRSignerContext = scepserver.SignCSRAdapter(scepdepot.NewSigner(depot, signerOpts...))
if *flChallengePassword != "" {
signer = scepserver.ChallengeMiddleware(*flChallengePassword, signer)
signer = scepserver.StaticChallengeMiddleware(*flChallengePassword, signer)
}
if csrVerifier != nil {
signer = csrverifier.Middleware(csrVerifier, signer)
Expand Down
7 changes: 6 additions & 1 deletion cryptoutil/cryptoutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,23 @@ import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"math/big"
"testing"
)

func TestGenerateSubjectKeyID(t *testing.T) {
ecKey, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader)
if err != nil {
t.Fatal(err)
}
for _, test := range []struct {
testName string
pub crypto.PublicKey
}{
{"RSA", &rsa.PublicKey{N: big.NewInt(123), E: 65537}},
{"ECDSA", &ecdsa.PublicKey{X: big.NewInt(123), Y: big.NewInt(123), Curve: elliptic.P224()}},
{"ECDSA", ecKey.Public()},
} {
test := test
t.Run(test.testName, func(t *testing.T) {
Expand Down
7 changes: 4 additions & 3 deletions csrverifier/csrverifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package csrverifier

import (
"context"
"crypto/x509"
"errors"

Expand All @@ -15,15 +16,15 @@ type CSRVerifier interface {
}

// Middleware wraps next in a CSRSigner that runs verifier
func Middleware(verifier CSRVerifier, next scepserver.CSRSigner) scepserver.CSRSignerFunc {
return func(m *scep.CSRReqMessage) (*x509.Certificate, error) {
func Middleware(verifier CSRVerifier, next scepserver.CSRSignerContext) scepserver.CSRSignerContextFunc {
return func(ctx context.Context, m *scep.CSRReqMessage) (*x509.Certificate, error) {
ok, err := verifier.Verify(m.RawDecrypted)
if err != nil {
return nil, err
}
if !ok {
return nil, errors.New("CSR verify failed")
}
return next.SignCSR(m)
return next.SignCSRContext(ctx, m)
}
}
40 changes: 32 additions & 8 deletions server/csrsigner.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,30 @@
package scepserver

import (
"context"
"crypto/subtle"
"crypto/x509"
"errors"

"github.com/micromdm/scep/v2/scep"
)

// CSRSignerContext is a handler for signing CSRs by a CA/RA.
//
// SignCSRContext should take the CSR in the CSRReqMessage and return a
// Certificate signed by the CA.
type CSRSignerContext interface {
SignCSRContext(context.Context, *scep.CSRReqMessage) (*x509.Certificate, error)
}

// CSRSignerContextFunc is an adapter for CSR signing by the CA/RA.
type CSRSignerContextFunc func(context.Context, *scep.CSRReqMessage) (*x509.Certificate, error)

// SignCSR calls f(ctx, m).
func (f CSRSignerContextFunc) SignCSRContext(ctx context.Context, m *scep.CSRReqMessage) (*x509.Certificate, error) {
return f(ctx, m)
}

// CSRSigner is a handler for CSR signing by the CA/RA
//
// SignCSR should take the CSR in the CSRReqMessage and return a
Expand All @@ -16,29 +33,36 @@ type CSRSigner interface {
SignCSR(*scep.CSRReqMessage) (*x509.Certificate, error)
}

// CSRSignerFunc is an adapter for CSR signing by the CA/RA
// CSRSignerFunc is an adapter for CSR signing by the CA/RA.
type CSRSignerFunc func(*scep.CSRReqMessage) (*x509.Certificate, error)

// SignCSR calls f(m)
// SignCSR calls f(m).
func (f CSRSignerFunc) SignCSR(m *scep.CSRReqMessage) (*x509.Certificate, error) {
return f(m)
}

// NopCSRSigner does nothing
func NopCSRSigner() CSRSignerFunc {
return func(m *scep.CSRReqMessage) (*x509.Certificate, error) {
// NopCSRSigner does nothing.
func NopCSRSigner() CSRSignerContextFunc {
return func(_ context.Context, _ *scep.CSRReqMessage) (*x509.Certificate, error) {
return nil, nil
}
}

// ChallengeMiddleware wraps next in a CSRSigner that validates the challenge from the CSR
func ChallengeMiddleware(challenge string, next CSRSigner) CSRSignerFunc {
// StaticChallengeMiddleware wraps next and validates the challenge from the CSR.
func StaticChallengeMiddleware(challenge string, next CSRSignerContext) CSRSignerContextFunc {
challengeBytes := []byte(challenge)
return func(m *scep.CSRReqMessage) (*x509.Certificate, error) {
return func(ctx context.Context, m *scep.CSRReqMessage) (*x509.Certificate, error) {
// TODO: compare challenge only for PKCSReq?
if subtle.ConstantTimeCompare(challengeBytes, []byte(m.ChallengePassword)) != 1 {
return nil, errors.New("invalid challenge")
}
return next.SignCSRContext(ctx, m)
}
}

// SignCSRAdapter adapts a next (i.e. no context) to a context signer.
func SignCSRAdapter(next CSRSigner) CSRSignerContextFunc {
return func(_ context.Context, m *scep.CSRReqMessage) (*x509.Certificate, error) {
return next.SignCSR(m)
}
}
9 changes: 6 additions & 3 deletions server/csrsigner_test.go
Original file line number Diff line number Diff line change
@@ -1,25 +1,28 @@
package scepserver

import (
"context"
"testing"

"github.com/micromdm/scep/v2/scep"
)

func TestChallengeMiddleware(t *testing.T) {
testPW := "RIGHT"
signer := ChallengeMiddleware(testPW, NopCSRSigner())
signer := StaticChallengeMiddleware(testPW, NopCSRSigner())

csrReq := &scep.CSRReqMessage{ChallengePassword: testPW}

_, err := signer.SignCSR(csrReq)
ctx := context.Background()

_, err := signer.SignCSRContext(ctx, csrReq)
if err != nil {
t.Error(err)
}

csrReq.ChallengePassword = "WRONG"

_, err = signer.SignCSR(csrReq)
_, err = signer.SignCSRContext(ctx, csrReq)
if err == nil {
t.Error("invalid challenge should generate an error")
}
Expand Down
6 changes: 3 additions & 3 deletions server/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ type service struct {
// The (chainable) CSR signing function. Intended to handle all
// SCEP request functionality such as CSR & challenge checking, CA
// issuance, RA proxying, etc.
signer CSRSigner
signer CSRSignerContext

/// info logging is implemented in the service middleware layer.
debugLogger log.Logger
Expand Down Expand Up @@ -80,7 +80,7 @@ func (svc *service) PKIOperation(ctx context.Context, data []byte) ([]byte, erro
return nil, err
}

crt, err := svc.signer.SignCSR(msg.CSRReqMessage)
crt, err := svc.signer.SignCSRContext(ctx, msg.CSRReqMessage)
if err == nil && crt == nil {
err = errors.New("no signed certificate")
}
Expand Down Expand Up @@ -119,7 +119,7 @@ func WithAddlCA(ca *x509.Certificate) ServiceOption {
}

// NewService creates a new scep service
func NewService(crt *x509.Certificate, key *rsa.PrivateKey, signer CSRSigner, opts ...ServiceOption) (Service, error) {
func NewService(crt *x509.Certificate, key *rsa.PrivateKey, signer CSRSignerContext, opts ...ServiceOption) (Service, error) {
s := &service{
crt: crt,
key: key,
Expand Down
2 changes: 1 addition & 1 deletion server/service_bolt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func TestCaCert(t *testing.T) {
caCert := certs[0]

// SCEP service
svc, err := scepserver.NewService(caCert, key, scepdepot.NewSigner(depot))
svc, err := scepserver.NewService(caCert, key, scepserver.SignCSRAdapter(scepdepot.NewSigner(depot)))
if err != nil {
t.Fatal(err)
}
Expand Down

0 comments on commit 9f27f76

Please sign in to comment.