diff --git a/efi/efivario/fscontext.go b/efi/efivario/fscontext.go index 04fa481..0b462e9 100644 --- a/efi/efivario/fscontext.go +++ b/efi/efivario/fscontext.go @@ -109,28 +109,30 @@ func (c FsContext) VariableNames() (VariableNameIterator, error) { func (c FsContext) readEfiVarFileName(name string, out []byte) (a Attributes, n int, err error) { f, err := c.fs.Open(name) if err != nil { - var pathErr *os.PathError - if errors.As(err, &pathErr) && pathErr.Err == syscall.ENOENT { + if errors.Is(err, fs.ErrNotExist) || errors.Is(err, syscall.ENOENT) { + // Overwrite error with a custom error type. err = ErrNotFound - } else { - err = fmt.Errorf("efivario/get: %w", err) } return } defer multierr.AppendInvoke(&err, multierr.Close(f)) if err = binary.Read(f, binary.LittleEndian, &a); err != nil { - err = fmt.Errorf("efivario/get: %w", err) return } - n, rerr := io.ReadFull(f, out) - switch rerr { + if n, err = f.Read(out); err != nil { + return + } + + // Ensure to return ErrInsufficientSpace if there is more data + // available to read than out can hold. + var tmp [1]byte + switch _, err = f.Read(tmp[:]); err { + case io.EOF: + err = nil case nil: err = ErrInsufficientSpace - case io.ErrUnexpectedEOF: - err = nil - default: } return } @@ -139,21 +141,21 @@ func (c FsContext) writeEfiVarFileName(name string, value []byte, attrs Attribut var buf bytes.Buffer if err := binary.Write(&buf, binary.LittleEndian, attrs); err != nil { - return fmt.Errorf("efivario/set: write attr: %w", err) + return fmt.Errorf("write attr: %w", err) } if _, err := buf.Write(value); err != nil { - return fmt.Errorf("efivario/set: write value: %w", err) + return fmt.Errorf("write value: %w", err) } guard, err := openSafeguard(c.fs, name) if err != nil { - return fmt.Errorf("efivario/set: guard open: %w", err) + return fmt.Errorf("guard open: %w", err) } if guard != nil { defer multierr.AppendInvoke(&err, multierr.Invoke(func() error { if err := guard.Close(); err != nil { - return fmt.Errorf("efivario/set: guard close: %w", err) + return fmt.Errorf("guard close: %w", err) } return nil })) @@ -161,12 +163,12 @@ func (c FsContext) writeEfiVarFileName(name string, value []byte, attrs Attribut wasProtected, err := guard.disable() if err != nil { - return fmt.Errorf("efivario/set: disable protection: %w", err) + return fmt.Errorf("disable protection: %w", err) } if wasProtected { defer multierr.AppendInvoke(&err, multierr.Invoke(func() error { if err := guard.enable(); err != nil { - return fmt.Errorf("efivario/set: enable protection: %w", err) + return fmt.Errorf("enable protection: %w", err) } return nil })) @@ -179,20 +181,19 @@ func (c FsContext) writeEfiVarFileName(name string, value []byte, attrs Attribut f, err := c.fs.OpenFile(name, flags, 0644) if err != nil { - return fmt.Errorf("efivario/set: %w", err) + return err } defer multierr.AppendInvoke(&err, multierr.Close(f)) if _, err := buf.WriteTo(f); err != nil { - return fmt.Errorf("efivario/set: %w", err) + return err } if err := f.Sync(); err != nil { - var errno syscall.Errno switch { case errors.Is(err, fs.ErrInvalid): fallthrough - case errors.As(err, &errno) && errno == syscall.EINVAL: + case errors.Is(err, syscall.EINVAL): // fsync is not implemented by efivarfs yet so // calling it here might sound silly, which it // actually is. Lets just ignore it for now. @@ -207,23 +208,27 @@ func (c FsContext) writeEfiVarFileName(name string, value []byte, attrs Attribut func (c FsContext) deleteEfiFile(name string) error { guard, err := openSafeguard(c.fs, name) if err != nil { - return fmt.Errorf("efivario/delete: guard open: %w", err) + return fmt.Errorf("guard open: %w", err) } if guard != nil { defer multierr.AppendInvoke(&err, multierr.Invoke(func() error { if err := guard.Close(); err != nil { - return fmt.Errorf("efivario/set: guard close: %w", err) + return fmt.Errorf("guard close: %w", err) } return nil })) } if _, err := guard.disable(); err != nil { - return fmt.Errorf("efivario/delete: guard disable: %w", err) + return fmt.Errorf("guard disable: %w", err) } if err := c.fs.Remove(name); err != nil { - return fmt.Errorf("efivario/delete: remove: %w", err) + if errors.Is(err, fs.ErrNotExist) || errors.Is(err, syscall.ENOENT) { + // Overwrite error with a custom error type. + err = ErrNotFound + } + return fmt.Errorf("remove: %w", err) } return nil } @@ -237,15 +242,25 @@ func (c FsContext) GetSizeHint(name string, guid efiguid.GUID) (int64, error) { } func (c FsContext) Get(name string, guid efiguid.GUID, out []byte) (a Attributes, n int, err error) { - return c.readEfiVarFileName(getFileName(name, guid), out) + a, n, err = c.readEfiVarFileName(getFileName(name, guid), out) + if err != nil { + err = fmt.Errorf("efivario/get: %w", err) + } + return } -func (c FsContext) Set(name string, guid efiguid.GUID, attributes Attributes, value []byte) error { - return c.writeEfiVarFileName(getFileName(name, guid), value, attributes) +func (c FsContext) Set(name string, guid efiguid.GUID, attributes Attributes, value []byte) (err error) { + if err = c.writeEfiVarFileName(getFileName(name, guid), value, attributes); err != nil { + err = fmt.Errorf("efivario/set: %w", err) + } + return } -func (c FsContext) Delete(name string, guid efiguid.GUID) error { - return c.deleteEfiFile(getFileName(name, guid)) +func (c FsContext) Delete(name string, guid efiguid.GUID) (err error) { + if err = c.deleteEfiFile(getFileName(name, guid)); err != nil { + err = fmt.Errorf("efivario/delete: %w", err) + } + return } func NewFileSystemContext(fs afero.Fs) *FsContext { diff --git a/efi/efivario/fscontext_test.go b/efi/efivario/fscontext_test.go index 4defdc2..f896108 100644 --- a/efi/efivario/fscontext_test.go +++ b/efi/efivario/fscontext_test.go @@ -21,6 +21,7 @@ import ( "testing" "github.com/spf13/afero" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -58,6 +59,63 @@ func (s *FsContextTestSuite) TearDownTest() { require.NoError(s.T(), os.RemoveAll(s.tmpDir)) } +// TestGetNonExistent tests reading a non-existing variable. +func (s *FsContextTestSuite) TestGetNonExistent() { + // given that ... + _, err := s.context.fs.Stat("TestVar-3CD99F3F-4B2B-43EB-AC29-F0890A4772B7") + require.ErrorIs(s.T(), err, afero.ErrFileNotFound) + + buf := make([]byte, 4096) + + // when ... + _, _, err = s.context.Get("TestVar", testGuid, buf) + + // then ... + require.ErrorIs(s.T(), err, ErrNotFound) +} + +// TestGet tests reading an existing variable. +func (s *FsContextTestSuite) TestGet() { + // given that ... + f, err := s.context.fs.Create("TestVar-3CD99F3F-4B2B-43EB-AC29-F0890A4772B7") + require.NoError(s.T(), err) + defer func() { require.NoError(s.T(), f.Close()) }() + + _, err = f.Write([]byte{0x07, 0x00, 0x00, 0x00, 0x65, 0x6e, 0x2d, 0x55, 0x53, 0x00}) + require.NoError(s.T(), err) + require.NoError(s.T(), f.Sync()) + + buf := make([]byte, 6) + + // when ... + attrs, length, err := s.context.Get("TestVar", testGuid, buf) + + // then ... + require.NoError(s.T(), err) + assert.Equal(s.T(), RuntimeAccess|BootServiceAccess|NonVolatile, attrs) + assert.Equal(s.T(), 6, length) +} + +// TestGetTooShort tests reading with a buffer too short. +func (s *FsContextTestSuite) TestGetTooShort() { + // given that ... + f, err := s.context.fs.Create("TestVar-3CD99F3F-4B2B-43EB-AC29-F0890A4772B7") + require.NoError(s.T(), err) + defer func() { require.NoError(s.T(), f.Close()) }() + + _, err = f.Write([]byte{0x07, 0x00, 0x00, 0x00, 0x65, 0x6e, 0x2d, 0x55, 0x53, 0x00}) + require.NoError(s.T(), err) + require.NoError(s.T(), f.Sync()) + + buf := make([]byte, 5) + + // when ... + _, _, err = s.context.Get("TestVar", testGuid, buf) + + // then ... + require.ErrorIs(s.T(), err, ErrInsufficientSpace) +} + // TestSetNewVariable tests setting a new variable. func (s *FsContextTestSuite) TestSetNewVariable() { // given that ... @@ -105,6 +163,19 @@ func (s *FsContextTestSuite) TestDelete() { require.True(s.T(), errors.Is(err, afero.ErrFileNotFound)) } +// TestDeleteNonExistent tests deleting a non-existing variable. +func (s *FsContextTestSuite) TestDeleteNonExistent() { + // given that ... + _, err := s.context.fs.Stat("TestVar-3CD99F3F-4B2B-43EB-AC29-F0890A4772B7") + require.True(s.T(), errors.Is(err, afero.ErrFileNotFound)) + + // when ... + err = s.context.Delete("TestVar", testGuid) + + // then ... + require.True(s.T(), errors.Is(err, ErrNotFound)) +} + // In order for 'go test' to run this suite, we need to create // a normal test function and pass our suite to suite.Run func TestFsContextTestSuite(t *testing.T) {