Skip to content
This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

Commit

Permalink
feat: add support for decryption
Browse files Browse the repository at this point in the history
  • Loading branch information
lcforges authored and ericchiang committed Sep 7, 2023
1 parent 3bb1db4 commit c6f7932
Show file tree
Hide file tree
Showing 2 changed files with 186 additions and 28 deletions.
139 changes: 111 additions & 28 deletions pkcs11/pkcs11.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,26 @@ CK_RV ck_sign(
) {
return (*fl->C_Sign)(hSession, pData, ulDataLen, pSignature, pulSignatureLen);
}
CK_RV ck_decrypt_init(
CK_FUNCTION_LIST_PTR fl,
CK_SESSION_HANDLE hSession,
CK_MECHANISM_PTR pMechanism,
CK_OBJECT_HANDLE hKey
) {
return (*fl->C_DecryptInit)(hSession, pMechanism, hKey);
}
CK_RV ck_decrypt(
CK_FUNCTION_LIST_PTR fl,
CK_SESSION_HANDLE hSession,
CK_BYTE_PTR pEncryptedData,
CK_ULONG ulEncryptedDataLen,
CK_BYTE_PTR pData,
CK_ULONG_PTR pulDataLen
) {
return (*fl->C_Decrypt)(hSession, pEncryptedData, ulEncryptedDataLen, pData, pulDataLen);
}
*/
// #cgo linux LDFLAGS: -ldl
import "C"
Expand Down Expand Up @@ -1226,20 +1246,15 @@ func (r *rsaPrivateKey) Sign(_ io.Reader, digest []byte, opts crypto.SignerOpts)
// http://docs.oasis-open.org/pkcs11/pkcs11-curr/v2.40/cs01/pkcs11-curr-v2.40-cs01.html#_Toc399398842
size := opts.HashFunc().Size()
if size != len(digest) {
return nil, fmt.Errorf("input mush be hashed")
return nil, fmt.Errorf("input must be hashed")
}
prefix, ok := hashPrefixes[opts.HashFunc()]
if !ok {
return nil, fmt.Errorf("unsupported hash function: %s", opts.HashFunc())
}

cBytes := make([]C.CK_BYTE, len(prefix)+len(digest))
for i, b := range prefix {
cBytes[i] = C.CK_BYTE(b)
}
for i, b := range digest {
cBytes[len(prefix)+i] = C.CK_BYTE(b)
}
preAndDigest := append(prefix, digest...)
cBytes := toCBytes(preAndDigest)

cSig := make([]C.CK_BYTE, r.pub.Size())
cSigLen := C.CK_ULONG(len(cSig))
Expand All @@ -1257,10 +1272,7 @@ func (r *rsaPrivateKey) Sign(_ io.Reader, digest []byte, opts crypto.SignerOpts)
if int(cSigLen) != len(cSig) {
return nil, fmt.Errorf("expected signature of length %d, got %d", len(cSig), cSigLen)
}
sig := make([]byte, len(cSig))
for i, b := range cSig {
sig[i] = byte(b)
}
sig := toBytes(cSig)
return sig, nil
}

Expand Down Expand Up @@ -1295,10 +1307,7 @@ func (r *rsaPrivateKey) signPSS(digest []byte, opts *rsa.PSSOptions) ([]byte, er
cParam.sLen = C.CK_ULONG(opts.SaltLength)
}

cBytes := make([]C.CK_BYTE, len(digest))
for i, b := range digest {
cBytes[i] = C.CK_BYTE(b)
}
cBytes := toCBytes(digest)

cSig := make([]C.CK_BYTE, r.pub.Size())
cSigLen := C.CK_ULONG(len(cSig))
Expand All @@ -1321,10 +1330,7 @@ func (r *rsaPrivateKey) signPSS(digest []byte, opts *rsa.PSSOptions) ([]byte, er
if int(cSigLen) != len(cSig) {
return nil, fmt.Errorf("expected signature of length %d, got %d", len(cSig), cSigLen)
}
sig := make([]byte, len(cSig))
for i, b := range cSig {
sig[i] = byte(b)
}
sig := toBytes(cSig)
return sig, nil
}

Expand Down Expand Up @@ -1353,10 +1359,7 @@ func (e *ecdsaPrivateKey) Sign(_ io.Reader, digest []byte, opts crypto.SignerOpt
cSig := make([]C.CK_BYTE, byteLen*2)
cSigLen := C.CK_ULONG(len(cSig))

cBytes := make([]C.CK_BYTE, len(digest))
for i, b := range digest {
cBytes[i] = C.CK_BYTE(b)
}
cBytes := toCBytes(digest)

rv = C.ck_sign(e.o.fl, e.o.h, &cBytes[0], C.CK_ULONG(len(digest)), &cSig[0], &cSigLen)
if err := isOk("C_Sign", rv); err != nil {
Expand All @@ -1366,10 +1369,7 @@ func (e *ecdsaPrivateKey) Sign(_ io.Reader, digest []byte, opts crypto.SignerOpt
if int(cSigLen) != len(cSig) {
return nil, fmt.Errorf("expected signature of length %d, got %d", len(cSig), cSigLen)
}
sig := make([]byte, len(cSig))
for i, b := range cSig {
sig[i] = byte(b)
}
sig := toBytes(cSig)

var (
r = big.NewInt(0)
Expand Down Expand Up @@ -1687,3 +1687,86 @@ func (s *Slot) generateECDSA(o keyOptions) (crypto.PrivateKey, error) {
}
return priv, nil
}

func (r *rsaPrivateKey) Decrypt(_ io.Reader, encryptedData []byte, opts crypto.DecrypterOpts) ([]byte, error) {
var m C.CK_MECHANISM

if o, ok := opts.(*rsa.OAEPOptions); ok {
cParam := (C.CK_RSA_PKCS_OAEP_PARAMS_PTR)(C.malloc(C.sizeof_CK_RSA_PKCS_OAEP_PARAMS))
defer C.free(unsafe.Pointer(cParam))

switch o.Hash {
case crypto.SHA256:
cParam.hashAlg = C.CKM_SHA256
cParam.mgf = C.CKG_MGF1_SHA256
case crypto.SHA384:
cParam.hashAlg = C.CKM_SHA384
cParam.mgf = C.CKG_MGF1_SHA384
case crypto.SHA512:
cParam.hashAlg = C.CKM_SHA512
cParam.mgf = C.CKG_MGF1_SHA512
case crypto.SHA1:
cParam.hashAlg = C.CKM_SHA_1
cParam.mgf = C.CKG_MGF1_SHA1
default:
return nil, fmt.Errorf("decryptOAEP error, unsupported hash algorithm: %s", o.Hash)
}

cParam.source = C.CKZ_DATA_SPECIFIED
cParam.pSourceData = nil
cParam.ulSourceDataLen = 0

m = C.CK_MECHANISM{
mechanism: C.CKM_RSA_PKCS_OAEP,
pParameter: C.CK_VOID_PTR(cParam),
ulParameterLen: C.CK_ULONG(C.sizeof_CK_RSA_PKCS_OAEP_PARAMS),
}
} else {
m = C.CK_MECHANISM{C.CKM_RSA_PKCS, nil, 0}
}

cEncDataBytes := toCBytes(encryptedData)

rv := C.ck_decrypt_init(r.o.fl, r.o.h, &m, r.o.o)
if err := isOk("C_DecryptInit", rv); err != nil {
return nil, err
}

var cDecryptedLen C.CK_ULONG

// First call is used to determine length necessary to hold decrypted data (PKCS #11 5.2)
rv = C.ck_decrypt(r.o.fl, r.o.h, &cEncDataBytes[0], C.CK_ULONG(len(cEncDataBytes)), nil, &cDecryptedLen)
if err := isOk("C_Decrypt", rv); err != nil {
return nil, err
}

cDecrypted := make([]C.CK_BYTE, cDecryptedLen)

rv = C.ck_decrypt(r.o.fl, r.o.h, &cEncDataBytes[0], C.CK_ULONG(len(cEncDataBytes)), &cDecrypted[0], &cDecryptedLen)
if err := isOk("C_Decrypt", rv); err != nil {
return nil, err
}

decrypted := toBytes(cDecrypted)

// Removes null padding (PKCS#11 5.2): http://docs.oasis-open.org/pkcs11/pkcs11-base/v2.40/os/pkcs11-base-v2.40-os.html#_Toc416959738
decrypted = bytes.Trim(decrypted, "\x00")

return decrypted, nil
}

func toBytes(data []C.CK_BYTE) []byte {
goBytes := make([]byte, len(data))
for i, b := range data {
goBytes[i] = byte(b)
}
return goBytes
}

func toCBytes(data []byte) []C.CK_BYTE {
cBytes := make([]C.CK_BYTE, len(data))
for i, b := range data {
cBytes[i] = C.CK_BYTE(b)
}
return cBytes
}
75 changes: 75 additions & 0 deletions pkcs11/pkcs11_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/sha1"
"crypto/sha256"
"crypto/x509"
"encoding/pem"
Expand Down Expand Up @@ -594,3 +595,77 @@ func TestCreateCertificate(t *testing.T) {
t.Errorf("Returned certificate did not match loaded certificate")
}
}

func TestDecryptOAEP(t *testing.T) {
msg := "Plain text to encrypt"
b := []byte(msg)
tests := []struct {
name string
bits int
}{
{"2048", 2048},
{"4096", 4096},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := newTestSlot(t)
o := keyOptions{RSABits: test.bits}
priv, err := s.generate(o)
if err != nil {
t.Fatalf("generate(%#v) failed: %v", o, err)
}
rsaPub := priv.(*rsaPrivateKey).pub
// SHA1 is the only hash function supported by softhsm
cipher, err := rsa.EncryptOAEP(sha1.New(), rand.Reader, rsaPub, b, nil)
if err != nil {
t.Fatalf("EncryptOAEP Error: %v", err)
}
opts := &rsa.OAEPOptions{Hash: crypto.SHA1}
rsaDecrypter := priv.(crypto.Decrypter)
decrypted, err := rsaDecrypter.Decrypt(nil, cipher, opts)
if err != nil {
t.Fatalf("Decrypt Error: %v", err)
}
if string(decrypted) != msg {
t.Errorf("Decrypt Error: expected %q, got %q", msg, string(decrypted))
}
})
}
}

func TestDecryptPKCS(t *testing.T) {
msg := "Plain text to encrypt"
b := []byte(msg)
tests := []struct {
name string
bits int
}{
{"2048", 2048},
{"4096", 4096},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := newTestSlot(t)
o := keyOptions{RSABits: test.bits}
priv, err := s.generate(o)
if err != nil {
t.Fatalf("generate(%#v) failed: %v", o, err)
}
rsaPub := priv.(*rsaPrivateKey).pub
cipher, err := rsa.EncryptPKCS1v15(rand.Reader, rsaPub, b)
if err != nil {
t.Fatalf("EncryptPKCS1v15 Error: %v", err)
}
rsaDecrypter := priv.(crypto.Decrypter)

// nil opts for decrypting using PKCS #1 v 1.5
decrypted, err := rsaDecrypter.Decrypt(nil, cipher, nil)
if err != nil {
t.Fatalf("Decrypt Error: %v", err)
}
if string(decrypted) != msg {
t.Errorf("Decrypt Error: expected %q, got %q", msg, string(decrypted))
}
})
}
}

0 comments on commit c6f7932

Please sign in to comment.