diff --git a/snapmodel_policy.go b/snapmodel_policy.go index afceecf1..92b08622 100644 --- a/snapmodel_policy.go +++ b/snapmodel_policy.go @@ -30,6 +30,14 @@ import ( "golang.org/x/xerrors" ) +const zeroSnapSystemEpoch uint32 = 0 + +func computeSnapSystemEpochDigest(alg tpm2.HashAlgorithmId, epoch uint32) tpm2.Digest { + h := alg.NewHash() + binary.Write(h, binary.LittleEndian, epoch) + return h.Sum(nil) +} + func computeSnapModelDigest(alg tpm2.HashAlgorithmId, model *asserts.Model) (tpm2.Digest, error) { signKeyId, err := base64.RawURLEncoding.DecodeString(model.SignKeyID()) if err != nil { @@ -73,11 +81,16 @@ type SnapModelProfileParams struct { // a PCR policy that is bound to a specific set of device models. It is the responsibility of snap-bootstrap to verify the integrity // of the model that it has measured. // -// The profile consists of 2 measurements (where H is the digest algorithm supplied via params.PCRAlgorithm): -// H(uint32(0)) +// The profile consists of 2 measurements: +// digestEpoch // digestModel // -// digestModel is computed as follows: +// digestEpoch is currently hardcoded as (where H is the digest algorithm supplied via params.PCRAlgorithm): +// digestEpoch = H(uint32(0)) +// +// A future version of this package may allow another epoch to be supplied. +// +// digestModel is computed as follows (where H is the digest algorithm supplied via params.PCRAlgorithm): // digest1 = H(tpm2.HashAlgorithmSHA384 || sign-key-sha3-384 || brand-id) // digest2 = H(digest1 || model) // digestModel = H(digest2 || series || grade) @@ -98,9 +111,7 @@ func AddSnapModelProfile(profile *PCRProtectionProfile, params *SnapModelProfile return errors.New("no models provided") } - h := params.PCRAlgorithm.NewHash() - binary.Write(h, binary.LittleEndian, uint32(0)) - profile.ExtendPCR(params.PCRAlgorithm, params.PCRIndex, h.Sum(nil)) + profile.ExtendPCR(params.PCRAlgorithm, params.PCRIndex, computeSnapSystemEpochDigest(params.PCRAlgorithm, zeroSnapSystemEpoch)) var subProfiles []*PCRProtectionProfile for _, model := range params.Models { @@ -119,9 +130,7 @@ func AddSnapModelProfile(profile *PCRProtectionProfile, params *SnapModelProfile return nil } -// MeasureSnapModelToTPM measures a digest of the supplied model assertion to the specified PCR for all supported PCR banks. -// See the documentation for AddSnapModelProfile for details of how the digest of the model is computed. -func MeasureSnapModelToTPM(tpm *TPMConnection, pcrIndex int, model *asserts.Model) error { +func measureSnapPropertyToTPM(tpm *TPMConnection, pcrIndex int, computeDigest func(tpm2.HashAlgorithmId) (tpm2.Digest, error)) error { pcrSelection, err := tpm.GetCapabilityPCRs(tpm.HmacSession().IncludeAttrs(tpm2.AttrAudit)) if err != nil { return xerrors.Errorf("cannot determine supported PCR banks: %w", err) @@ -136,9 +145,9 @@ func MeasureSnapModelToTPM(tpm *TPMConnection, pcrIndex int, model *asserts.Mode continue } - digest, err := computeSnapModelDigest(s.Hash, model) + digest, err := computeDigest(s.Hash) if err != nil { - return xerrors.Errorf("cannot compute snap mode digest for algorithm %v: %w", s.Hash, err) + return xerrors.Errorf("cannot compute digest for algorithm %v: %w", s.Hash, err) } digests = append(digests, tpm2.TaggedHash{HashAlg: s.Hash, Digest: digest}) @@ -146,3 +155,19 @@ func MeasureSnapModelToTPM(tpm *TPMConnection, pcrIndex int, model *asserts.Mode return tpm.PCRExtend(tpm.PCRHandleContext(pcrIndex), digests, tpm.HmacSession()) } + +// MeasureSnapSystemEpochToTPM measures a digest of uint32(0) to the specified PCR for all supported PCR banks. See the documentation +// for AddSnapModelProfile for more details. +func MeasureSnapSystemEpochToTPM(tpm *TPMConnection, pcrIndex int) error { + return measureSnapPropertyToTPM(tpm, pcrIndex, func(alg tpm2.HashAlgorithmId) (tpm2.Digest, error) { + return computeSnapSystemEpochDigest(alg, zeroSnapSystemEpoch), nil + }) +} + +// MeasureSnapModelToTPM measures a digest of the supplied model assertion to the specified PCR for all supported PCR banks. +// See the documentation for AddSnapModelProfile for details of how the digest of the model is computed. +func MeasureSnapModelToTPM(tpm *TPMConnection, pcrIndex int, model *asserts.Model) error { + return measureSnapPropertyToTPM(tpm, pcrIndex, func(alg tpm2.HashAlgorithmId) (tpm2.Digest, error) { + return computeSnapModelDigest(alg, model) + }) +} diff --git a/snapmodel_policy_test.go b/snapmodel_policy_test.go index b903dceb..a5d91e08 100644 --- a/snapmodel_policy_test.go +++ b/snapmodel_policy_test.go @@ -20,6 +20,7 @@ package secboot_test import ( + "encoding/binary" "time" "github.com/canonical/go-tpm2" @@ -558,3 +559,52 @@ func (s *snapModelMeasureSuite) TestMeasureSnapModelToTPMTest7(c *C) { }, "Jv8_JiHiIzJVcO9M55pPdqSDWUvuhfDIBJUS-3VW7F_idjix7Ffn5qMxB21ZQuij"), }) } + +func (s *snapModelMeasureSuite) testMeasureSnapSystemEpochToTPM(c *C, pcrIndex int) { + pcrSelection, err := s.tpm.GetCapabilityPCRs() + c.Assert(err, IsNil) + + var pcrs []int + for i := 0; i < 24; i++ { + pcrs = append(pcrs, i) + } + var readPcrSelection tpm2.PCRSelectionList + for _, s := range pcrSelection { + readPcrSelection = append(readPcrSelection, tpm2.PCRSelection{Hash: s.Hash, Select: pcrs}) + } + + _, origPcrValues, err := s.tpm.PCRRead(readPcrSelection) + c.Assert(err, IsNil) + + c.Check(MeasureSnapSystemEpochToTPM(s.tpm, pcrIndex), IsNil) + + _, pcrValues, err := s.tpm.PCRRead(readPcrSelection) + c.Assert(err, IsNil) + + for _, s := range pcrSelection { + h := s.Hash.NewHash() + binary.Write(h, binary.LittleEndian, uint32(0)) + digest := h.Sum(nil) + + h = s.Hash.NewHash() + h.Write(origPcrValues[s.Hash][pcrIndex]) + h.Write(digest) + + c.Check(pcrValues[s.Hash][pcrIndex], DeepEquals, tpm2.Digest(h.Sum(nil))) + + for _, p := range pcrs { + if p == pcrIndex { + continue + } + c.Check(pcrValues[s.Hash][p], DeepEquals, origPcrValues[s.Hash][p]) + } + } +} + +func (s *snapModelMeasureSuite) TestMeasureSnapSystemEpochToTPM1(c *C) { + s.testMeasureSnapSystemEpochToTPM(c, 12) +} + +func (s *snapModelMeasureSuite) TestMeasureSnapSystemEpochToTPM2(c *C) { + s.testMeasureSnapSystemEpochToTPM(c, 14) +}