From c6f79328ddf90f25424ecb6e874c109718883daa Mon Sep 17 00:00:00 2001 From: lcforges <90981702+lcforges@users.noreply.github.com> Date: Tue, 1 Aug 2023 23:17:18 +0000 Subject: [PATCH] feat: add support for decryption --- pkcs11/pkcs11.go | 139 +++++++++++++++++++++++++++++++++--------- pkcs11/pkcs11_test.go | 75 +++++++++++++++++++++++ 2 files changed, 186 insertions(+), 28 deletions(-) diff --git a/pkcs11/pkcs11.go b/pkcs11/pkcs11.go index 6fd4ceb..d3aebf6 100644 --- a/pkcs11/pkcs11.go +++ b/pkcs11/pkcs11.go @@ -232,6 +232,26 @@ CK_RV ck_sign( ) { return (*fl->C_Sign)(hSession, pData, ulDataLen, pSignature, pulSignatureLen); } + +CK_RV ck_decrypt_init( + CK_FUNCTION_LIST_PTR fl, + CK_SESSION_HANDLE hSession, + CK_MECHANISM_PTR pMechanism, + CK_OBJECT_HANDLE hKey +) { + return (*fl->C_DecryptInit)(hSession, pMechanism, hKey); +} + +CK_RV ck_decrypt( + CK_FUNCTION_LIST_PTR fl, + CK_SESSION_HANDLE hSession, + CK_BYTE_PTR pEncryptedData, + CK_ULONG ulEncryptedDataLen, + CK_BYTE_PTR pData, + CK_ULONG_PTR pulDataLen +) { + return (*fl->C_Decrypt)(hSession, pEncryptedData, ulEncryptedDataLen, pData, pulDataLen); +} */ // #cgo linux LDFLAGS: -ldl import "C" @@ -1226,20 +1246,15 @@ func (r *rsaPrivateKey) Sign(_ io.Reader, digest []byte, opts crypto.SignerOpts) // http://docs.oasis-open.org/pkcs11/pkcs11-curr/v2.40/cs01/pkcs11-curr-v2.40-cs01.html#_Toc399398842 size := opts.HashFunc().Size() if size != len(digest) { - return nil, fmt.Errorf("input mush be hashed") + return nil, fmt.Errorf("input must be hashed") } prefix, ok := hashPrefixes[opts.HashFunc()] if !ok { return nil, fmt.Errorf("unsupported hash function: %s", opts.HashFunc()) } - cBytes := make([]C.CK_BYTE, len(prefix)+len(digest)) - for i, b := range prefix { - cBytes[i] = C.CK_BYTE(b) - } - for i, b := range digest { - cBytes[len(prefix)+i] = C.CK_BYTE(b) - } + preAndDigest := append(prefix, digest...) + cBytes := toCBytes(preAndDigest) cSig := make([]C.CK_BYTE, r.pub.Size()) cSigLen := C.CK_ULONG(len(cSig)) @@ -1257,10 +1272,7 @@ func (r *rsaPrivateKey) Sign(_ io.Reader, digest []byte, opts crypto.SignerOpts) if int(cSigLen) != len(cSig) { return nil, fmt.Errorf("expected signature of length %d, got %d", len(cSig), cSigLen) } - sig := make([]byte, len(cSig)) - for i, b := range cSig { - sig[i] = byte(b) - } + sig := toBytes(cSig) return sig, nil } @@ -1295,10 +1307,7 @@ func (r *rsaPrivateKey) signPSS(digest []byte, opts *rsa.PSSOptions) ([]byte, er cParam.sLen = C.CK_ULONG(opts.SaltLength) } - cBytes := make([]C.CK_BYTE, len(digest)) - for i, b := range digest { - cBytes[i] = C.CK_BYTE(b) - } + cBytes := toCBytes(digest) cSig := make([]C.CK_BYTE, r.pub.Size()) cSigLen := C.CK_ULONG(len(cSig)) @@ -1321,10 +1330,7 @@ func (r *rsaPrivateKey) signPSS(digest []byte, opts *rsa.PSSOptions) ([]byte, er if int(cSigLen) != len(cSig) { return nil, fmt.Errorf("expected signature of length %d, got %d", len(cSig), cSigLen) } - sig := make([]byte, len(cSig)) - for i, b := range cSig { - sig[i] = byte(b) - } + sig := toBytes(cSig) return sig, nil } @@ -1353,10 +1359,7 @@ func (e *ecdsaPrivateKey) Sign(_ io.Reader, digest []byte, opts crypto.SignerOpt cSig := make([]C.CK_BYTE, byteLen*2) cSigLen := C.CK_ULONG(len(cSig)) - cBytes := make([]C.CK_BYTE, len(digest)) - for i, b := range digest { - cBytes[i] = C.CK_BYTE(b) - } + cBytes := toCBytes(digest) rv = C.ck_sign(e.o.fl, e.o.h, &cBytes[0], C.CK_ULONG(len(digest)), &cSig[0], &cSigLen) if err := isOk("C_Sign", rv); err != nil { @@ -1366,10 +1369,7 @@ func (e *ecdsaPrivateKey) Sign(_ io.Reader, digest []byte, opts crypto.SignerOpt if int(cSigLen) != len(cSig) { return nil, fmt.Errorf("expected signature of length %d, got %d", len(cSig), cSigLen) } - sig := make([]byte, len(cSig)) - for i, b := range cSig { - sig[i] = byte(b) - } + sig := toBytes(cSig) var ( r = big.NewInt(0) @@ -1687,3 +1687,86 @@ func (s *Slot) generateECDSA(o keyOptions) (crypto.PrivateKey, error) { } return priv, nil } + +func (r *rsaPrivateKey) Decrypt(_ io.Reader, encryptedData []byte, opts crypto.DecrypterOpts) ([]byte, error) { + var m C.CK_MECHANISM + + if o, ok := opts.(*rsa.OAEPOptions); ok { + cParam := (C.CK_RSA_PKCS_OAEP_PARAMS_PTR)(C.malloc(C.sizeof_CK_RSA_PKCS_OAEP_PARAMS)) + defer C.free(unsafe.Pointer(cParam)) + + switch o.Hash { + case crypto.SHA256: + cParam.hashAlg = C.CKM_SHA256 + cParam.mgf = C.CKG_MGF1_SHA256 + case crypto.SHA384: + cParam.hashAlg = C.CKM_SHA384 + cParam.mgf = C.CKG_MGF1_SHA384 + case crypto.SHA512: + cParam.hashAlg = C.CKM_SHA512 + cParam.mgf = C.CKG_MGF1_SHA512 + case crypto.SHA1: + cParam.hashAlg = C.CKM_SHA_1 + cParam.mgf = C.CKG_MGF1_SHA1 + default: + return nil, fmt.Errorf("decryptOAEP error, unsupported hash algorithm: %s", o.Hash) + } + + cParam.source = C.CKZ_DATA_SPECIFIED + cParam.pSourceData = nil + cParam.ulSourceDataLen = 0 + + m = C.CK_MECHANISM{ + mechanism: C.CKM_RSA_PKCS_OAEP, + pParameter: C.CK_VOID_PTR(cParam), + ulParameterLen: C.CK_ULONG(C.sizeof_CK_RSA_PKCS_OAEP_PARAMS), + } + } else { + m = C.CK_MECHANISM{C.CKM_RSA_PKCS, nil, 0} + } + + cEncDataBytes := toCBytes(encryptedData) + + rv := C.ck_decrypt_init(r.o.fl, r.o.h, &m, r.o.o) + if err := isOk("C_DecryptInit", rv); err != nil { + return nil, err + } + + var cDecryptedLen C.CK_ULONG + + // First call is used to determine length necessary to hold decrypted data (PKCS #11 5.2) + rv = C.ck_decrypt(r.o.fl, r.o.h, &cEncDataBytes[0], C.CK_ULONG(len(cEncDataBytes)), nil, &cDecryptedLen) + if err := isOk("C_Decrypt", rv); err != nil { + return nil, err + } + + cDecrypted := make([]C.CK_BYTE, cDecryptedLen) + + rv = C.ck_decrypt(r.o.fl, r.o.h, &cEncDataBytes[0], C.CK_ULONG(len(cEncDataBytes)), &cDecrypted[0], &cDecryptedLen) + if err := isOk("C_Decrypt", rv); err != nil { + return nil, err + } + + decrypted := toBytes(cDecrypted) + + // Removes null padding (PKCS#11 5.2): http://docs.oasis-open.org/pkcs11/pkcs11-base/v2.40/os/pkcs11-base-v2.40-os.html#_Toc416959738 + decrypted = bytes.Trim(decrypted, "\x00") + + return decrypted, nil +} + +func toBytes(data []C.CK_BYTE) []byte { + goBytes := make([]byte, len(data)) + for i, b := range data { + goBytes[i] = byte(b) + } + return goBytes +} + +func toCBytes(data []byte) []C.CK_BYTE { + cBytes := make([]C.CK_BYTE, len(data)) + for i, b := range data { + cBytes[i] = C.CK_BYTE(b) + } + return cBytes +} diff --git a/pkcs11/pkcs11_test.go b/pkcs11/pkcs11_test.go index 8f23406..995dc7f 100644 --- a/pkcs11/pkcs11_test.go +++ b/pkcs11/pkcs11_test.go @@ -21,6 +21,7 @@ import ( "crypto/elliptic" "crypto/rand" "crypto/rsa" + "crypto/sha1" "crypto/sha256" "crypto/x509" "encoding/pem" @@ -594,3 +595,77 @@ func TestCreateCertificate(t *testing.T) { t.Errorf("Returned certificate did not match loaded certificate") } } + +func TestDecryptOAEP(t *testing.T) { + msg := "Plain text to encrypt" + b := []byte(msg) + tests := []struct { + name string + bits int + }{ + {"2048", 2048}, + {"4096", 4096}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := newTestSlot(t) + o := keyOptions{RSABits: test.bits} + priv, err := s.generate(o) + if err != nil { + t.Fatalf("generate(%#v) failed: %v", o, err) + } + rsaPub := priv.(*rsaPrivateKey).pub + // SHA1 is the only hash function supported by softhsm + cipher, err := rsa.EncryptOAEP(sha1.New(), rand.Reader, rsaPub, b, nil) + if err != nil { + t.Fatalf("EncryptOAEP Error: %v", err) + } + opts := &rsa.OAEPOptions{Hash: crypto.SHA1} + rsaDecrypter := priv.(crypto.Decrypter) + decrypted, err := rsaDecrypter.Decrypt(nil, cipher, opts) + if err != nil { + t.Fatalf("Decrypt Error: %v", err) + } + if string(decrypted) != msg { + t.Errorf("Decrypt Error: expected %q, got %q", msg, string(decrypted)) + } + }) + } +} + +func TestDecryptPKCS(t *testing.T) { + msg := "Plain text to encrypt" + b := []byte(msg) + tests := []struct { + name string + bits int + }{ + {"2048", 2048}, + {"4096", 4096}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := newTestSlot(t) + o := keyOptions{RSABits: test.bits} + priv, err := s.generate(o) + if err != nil { + t.Fatalf("generate(%#v) failed: %v", o, err) + } + rsaPub := priv.(*rsaPrivateKey).pub + cipher, err := rsa.EncryptPKCS1v15(rand.Reader, rsaPub, b) + if err != nil { + t.Fatalf("EncryptPKCS1v15 Error: %v", err) + } + rsaDecrypter := priv.(crypto.Decrypter) + + // nil opts for decrypting using PKCS #1 v 1.5 + decrypted, err := rsaDecrypter.Decrypt(nil, cipher, nil) + if err != nil { + t.Fatalf("Decrypt Error: %v", err) + } + if string(decrypted) != msg { + t.Errorf("Decrypt Error: expected %q, got %q", msg, string(decrypted)) + } + }) + } +}