diff --git a/README.md b/README.md index 853ea31..c2c4544 100644 --- a/README.md +++ b/README.md @@ -43,3 +43,4 @@ func main() { - Simple API - Reading individual Boot options - Setting next Boot option +- Managing Boot order diff --git a/cmd/example/example.go b/cmd/example/example.go index b61e3a8..c03442a 100644 --- a/cmd/example/example.go +++ b/cmd/example/example.go @@ -20,6 +20,29 @@ var ( bootRe = regexp.MustCompile(`^Boot([\da-fA-F]{4})$`) ) +func ListBootOrder(c efivario.Context) error { + _, value, err := efivars.BootOrder.Get(c) + if err != nil { + return err + } + + for i, index := range value { + _, lo, err := efivars.Boot(index).Get(c) + if err != nil { + return fmt.Errorf("entry %d (%04[1]X): %w", index, err) + } + + pp.Println(map[string]any{ + "Order": i, + "Index": index, + "Description": efireader.UTF16NullBytesToString(lo.Description), + "Path": lo.FilePathList.AllText(), + }) + } + + return nil +} + func ListAllVariables(c efivario.Context) error { iter, err := c.VariableNames() if err != nil { @@ -76,7 +99,7 @@ func ReadBootEntries(c efivario.Context) error { fmt.Printf("\nEntry Boot%04X(%[1]d):\n", value) - attrs, lo, err := efivars.Boot(int(value)).Get(c) + attrs, lo, err := efivars.Boot(uint16(int(value))).Get(c) if err != nil { return err } @@ -105,6 +128,9 @@ func Run(args []string) error { var listBootEntries bool fset.BoolVar(&listBootEntries, "list-boot", false, "list boot entries") + var listBootOrder bool + fset.BoolVar(&listBootOrder, "list-boot-order", false, "list boot order") + var setNextBoot bool fset.BoolVar(&setNextBoot, "set-next", false, "set next boot option") @@ -119,6 +145,8 @@ func Run(args []string) error { var err error switch { + case listBootOrder: + err = ListBootOrder(c) case listBootEntries: err = ReadBootEntries(c) case listAllVariables: diff --git a/efi/efivario/readall.go b/efi/efivario/readall.go index 9738576..765da7e 100644 --- a/efi/efivario/readall.go +++ b/efi/efivario/readall.go @@ -25,8 +25,8 @@ func ReadAll(c Context, name string, guid efiguid.GUID) ( out []byte, err error, ) { - var hint int64 - if hint, err = c.GetSizeHint(name, guid); err != nil { + hint, err := c.GetSizeHint(name, guid) + if err != nil || hint < 0 { hint = 8 } diff --git a/efi/efivars/bootvariables.go b/efi/efivars/bootvariables.go index 444a08b..7facb94 100644 --- a/efi/efivars/bootvariables.go +++ b/efi/efivars/bootvariables.go @@ -18,12 +18,12 @@ import ( "fmt" "github.com/0x5a17ed/uefi/efi/efitypes" - "github.com/0x5a17ed/uefi/efi/efivario" ) const ( BootNextName = "BootNext" BootCurrentName = "BootCurrent" + BootOrderName = "BootOrder" ) var ( @@ -33,7 +33,9 @@ var ( BootNext = Variable[uint16]{ name: BootNextName, guid: GlobalVariable, - defaultAttrs: efivario.NonVolatile | efivario.BootServiceAccess | efivario.RuntimeAccess, + defaultAttrs: defaultAttrs, + marshal: primitiveMarshaller[uint16], + unmarshal: primitiveUnmarshaller[uint16], } // BootCurrent defines the Boot#### option that was selected @@ -43,7 +45,25 @@ var ( BootCurrent = Variable[uint16]{ name: BootCurrentName, guid: GlobalVariable, - defaultAttrs: efivario.NonVolatile | efivario.BootServiceAccess | efivario.RuntimeAccess, + defaultAttrs: defaultAttrs, + marshal: primitiveMarshaller[uint16], + unmarshal: primitiveUnmarshaller[uint16], + } + + // BootOrder is an ordered list of the Boot#### options. + // + // The first element in the array is the value for the first + // logical boot option, the second element is the value for + // the second logical boot option, etc. The BootOrder order + // list is used by the firmware’s boot manager as the default + // boot order. + // + // + BootOrder = Variable[[]uint16]{ + name: BootOrderName, + guid: GlobalVariable, + defaultAttrs: defaultAttrs, + unmarshal: sliceUnmarshaller[uint16], } ) @@ -51,9 +71,10 @@ var ( // for the given index. // // -func Boot(i int) Variable[efitypes.LoadOption] { - return Variable[efitypes.LoadOption]{ - name: fmt.Sprintf("Boot%04X", i), - guid: GlobalVariable, +func Boot(i uint16) Variable[*efitypes.LoadOption] { + return Variable[*efitypes.LoadOption]{ + name: fmt.Sprintf("Boot%04X", i), + guid: GlobalVariable, + unmarshal: structUnmarshaller[efitypes.LoadOption], } } diff --git a/efi/efivars/marshalling.go b/efi/efivars/marshalling.go new file mode 100644 index 0000000..f25b0a6 --- /dev/null +++ b/efi/efivars/marshalling.go @@ -0,0 +1,81 @@ +// Copyright (c) 2022 Arthur Skowronek <0x5a17ed@tuta.io> and contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package efivars + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" +) + +func primitiveUnmarshaller[T any](r io.Reader) (out T, err error) { + err = binary.Read(r, binary.LittleEndian, &out) + return +} + +func primitiveMarshaller[T any](w io.Writer, inp T) error { + return binary.Write(w, binary.LittleEndian, inp) +} + +func sliceUnmarshaller[T any](r io.Reader) (out []T, err error) { + for i := 0; ; i += 1 { + var item T + err = binary.Read(r, binary.LittleEndian, &item) + if err != nil { + if errors.Is(err, io.EOF) { + err = nil + } else { + err = fmt.Errorf("item #%d: %w", i, err) + } + return + } + out = append(out, item) + } +} + +func sliceMarshaller[T any](w io.Writer, inp []T) (err error) { + var buf bytes.Buffer + + for i, item := range inp { + err := binary.Write(&buf, binary.LittleEndian, item) + if err != nil { + return fmt.Errorf("item #%d: %w", i, err) + } + } + + _, err = buf.WriteTo(w) + return +} + +type readerFrom[T any] interface { + io.ReaderFrom + *T +} + +func structUnmarshaller[T any, PT readerFrom[T]](r io.Reader) (out *T, err error) { + var value T + _, err = PT(&value).ReadFrom(r) + if err == nil { + out = &value + } + return +} + +func structMarshaller[T io.WriterTo](w io.Writer, inp T) (err error) { + _, err = inp.WriteTo(w) + return +} diff --git a/efi/efivars/variable.go b/efi/efivars/variable.go index fe774c1..51627fa 100644 --- a/efi/efivars/variable.go +++ b/efi/efivars/variable.go @@ -16,7 +16,6 @@ package efivars import ( "bytes" - "encoding/binary" "fmt" "io" @@ -24,48 +23,52 @@ import ( "github.com/0x5a17ed/uefi/efi/efivario" ) +const ( + globalAccess = efivario.BootServiceAccess | efivario.RuntimeAccess + + defaultAttrs = efivario.NonVolatile | globalAccess +) + +type MarshalFn[T any] func(w io.Writer, inp T) error +type UnmarshalFn[T any] func(r io.Reader) (T, error) + type Variable[T any] struct { name string guid efiguid.GUID defaultAttrs efivario.Attributes + + marshal MarshalFn[T] + unmarshal UnmarshalFn[T] } func (e Variable[T]) Get(c efivario.Context) (attrs efivario.Attributes, value T, err error) { + if e.unmarshal == nil { + err = fmt.Errorf("efivars/get(%s): unsupported", e.name) + return + } + attrs, data, err := efivario.ReadAll(c, e.name, e.guid) if err != nil { err = fmt.Errorf("efivars/get(%s): load: %w", e.name, err) return } - buf := bytes.NewReader(data) - - var valueInterface any = &value - if reader, ok := (valueInterface).(io.ReaderFrom); ok { - _, err = reader.ReadFrom(buf) - } else { - err = binary.Read(buf, binary.LittleEndian, &value) - } + value, err = e.unmarshal(bytes.NewReader(data)) if err != nil { err = fmt.Errorf("efivars/get(%s): parse: %w", e.name, err) - return } - return } -func (e Variable[T]) SetWithAttributes(c efivario.Context, attrs efivario.Attributes, value T) (err error) { - var buf bytes.Buffer - - var valueInterface any = &value - if writer, ok := (valueInterface).(io.WriterTo); ok { - _, err = writer.WriteTo(&buf) - } else { - err = binary.Write(&buf, binary.LittleEndian, value) - } - if err != nil { - return fmt.Errorf("efivars/set(%s): %w", e.name, err) +func (e Variable[T]) SetWithAttributes(c efivario.Context, attrs efivario.Attributes, value T) error { + if e.marshal == nil { + return fmt.Errorf("efivars/set(%s): unsupported", e.name) } + var buf bytes.Buffer + if err := e.marshal(&buf, value); err != nil { + return fmt.Errorf("efivars/set(%s): write: %w", e.name, err) + } return c.Set(e.name, e.guid, attrs, buf.Bytes()) } diff --git a/efi/efivars/variable_test.go b/efi/efivars/variable_test.go index 2e54e15..84ccee4 100644 --- a/efi/efivars/variable_test.go +++ b/efi/efivars/variable_test.go @@ -15,31 +15,255 @@ package efivars import ( + "bytes" + "encoding/binary" "encoding/hex" + "fmt" + "io" + "strings" "testing" "github.com/spf13/afero" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + "go.uber.org/multierr" + "github.com/0x5a17ed/uefi/efi/efiguid" "github.com/0x5a17ed/uefi/efi/efivario" ) -func TestBoot(t *testing.T) { - // objective of the test is to ensure that Boot(n) correctly - // converts n to hexadecimal. +var testGuid = efiguid.MustFromString("3cd99f3f-4b2b-43eb-ac29-f0890a4772b7") + +func readFile(fs afero.Fs, name string) (string, error) { + f, err := fs.Open(name) + if err != nil { + return "", fmt.Errorf("readFile/open: %w", err) + } + defer f.Close() + + var buf bytes.Buffer + if _, err := io.Copy(hex.NewEncoder(&buf), f); err != nil { + return "", fmt.Errorf("readFile/read: %w", err) + } + + return buf.String(), nil +} + +func writeFile(fs afero.Fs, name, data string) (err error) { + f, _ := fs.Create(name) + defer multierr.AppendInvoke(&err, multierr.Close(f)) + _, err = io.Copy(f, hex.NewDecoder(strings.NewReader(data))) + return +} + +func wantedError(target error) assert.ErrorAssertionFunc { + return func(t assert.TestingT, err error, msgAndArgs ...interface{}) bool { + return assert.ErrorIs(t, err, target, msgAndArgs...) + } +} + +type testRow[T any] struct { + variable Variable[T] + name string + data string + wanted T + wantErr assert.ErrorAssertionFunc +} + +func (r *testRow[T]) fileName() string { + return fmt.Sprintf("%s-%s", r.variable.name, r.variable.guid) +} + +type testEnv[T any] struct { + t *testing.T + fs afero.Fs + ctx *efivario.FsContext +} + +func (te *testEnv[T]) testGet(row *testRow[T]) (ok bool) { + gotAttrs, gotValue, err := row.variable.Get(te.ctx) + if !row.wantErr(te.t, err, "efivars/get") { + return + } + + if !assert.Equal(te.t, row.variable.defaultAttrs, gotAttrs) { + return + } + if !assert.Equal(te.t, row.wanted, gotValue) { + return + } + return true +} + +func (te testEnv[T]) testSet(row *testRow[T]) (ok bool) { + err := row.variable.Set(te.ctx, row.wanted) + if !row.wantErr(te.t, err, "efivars/set") { + return + } + + content, err := readFile(te.fs, row.fileName()) + if !assert.NoError(te.t, err) { + return + } + + if !assert.Equal(te.t, row.data, content[8:]) { + return + } + + return true +} + +func newTestEnv[T any](t *testing.T) *testEnv[T] { fs := afero.NewMemMapFs() + return &testEnv[T]{ + t: t, + fs: fs, + ctx: efivario.NewFileSystemContext(fs), + } +} - const s = "000000000000000000000000" +func (te testEnv[T]) setupVarFile(row *testRow[T]) error { + return writeFile(te.fs, row.fileName(), "07000000"+row.data) +} - decoded, err := hex.DecodeString(s) - assert.NoError(t, err) +type testRunner[T any] func(*testing.T, *testRow[T], *testEnv[T]) - f, _ := fs.Create("Boot000F-8BE4DF61-93CA-11D2-AA0D-00E098032B8C") - f.Write(decoded) - f.Close() +func runTests[T any](t *testing.T, tests []*testRow[T], fn testRunner[T]) { + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := newTestEnv[T](t) + fn(t, tt, e) + }) + } +} + +type VariableTestSuite struct{ suite.Suite } + +func (s *VariableTestSuite) TestPrimitive() { + s.Run("uint16", func() { + v := Variable[uint16]{ + name: "TestVar", + guid: testGuid, + defaultAttrs: defaultAttrs, + unmarshal: primitiveUnmarshaller[uint16], + marshal: primitiveMarshaller[uint16], + } + + row := &testRow[uint16]{ + v, "", "1234", uint16(0x3412), assert.NoError, + } + + tfn := newTestEnv[uint16](s.T()) + if !tfn.testSet(row) { + return + } + tfn.testGet(row) + }) + + s.Run("SetGet/bool", func() { + v := Variable[bool]{ + name: "TestVar", + guid: testGuid, + defaultAttrs: defaultAttrs, + unmarshal: primitiveUnmarshaller[bool], + marshal: primitiveMarshaller[bool], + } + + row := &testRow[bool]{ + v, "", "01", true, assert.NoError, + } + + tfn := newTestEnv[bool](s.T()) + if !tfn.testSet(row) { + return + } + tfn.testGet(row) + }) +} + +func (s *VariableTestSuite) TestSlice() { + v := Variable[[]uint16]{ + name: "TestVar", + guid: testGuid, + defaultAttrs: defaultAttrs, + marshal: sliceMarshaller[uint16], + unmarshal: sliceUnmarshaller[uint16], + } + + s.Run("SetGet", func() { + var tests = []*testRow[[]uint16]{ + {v, "Empty", "", []uint16(nil), assert.NoError}, + {v, "One", "1234", []uint16{0x3412}, assert.NoError}, + {v, "Many", "12345678", []uint16{0x3412, 0x7856}, assert.NoError}, + } + + runTests(s.T(), tests, func(t *testing.T, row *testRow[[]uint16], env *testEnv[[]uint16]) { + if !env.testSet(row) { + return + } + env.testGet(row) + }) + }) - c := efivario.NewFileSystemContext(fs) + s.Run("Get/ErrShort", func() { + rows := []*testRow[[]uint16]{ + {v, "", "1234", []uint16{0x3412}, assert.NoError}, + {v, "", "123456", []uint16{0x3412}, wantedError(io.ErrUnexpectedEOF)}, + } + + runTests(s.T(), rows, func(t *testing.T, row *testRow[[]uint16], env *testEnv[[]uint16]) { + if !assert.NoError(s.T(), env.setupVarFile(row)) { + return + } + env.testGet(row) + }) + }) +} + +type data struct { + A uint16 + B uint32 +} + +func (d *data) WriteTo(w io.Writer) (n int64, err error) { + return 0, binary.Write(w, binary.LittleEndian, d) +} + +func (d *data) ReadFrom(r io.Reader) (n int64, err error) { + return 0, binary.Read(r, binary.LittleEndian, d) +} + +func (s *VariableTestSuite) TestGetStruct() { + v := Variable[*data]{ + name: "TestVar", + guid: testGuid, + defaultAttrs: defaultAttrs, + marshal: structMarshaller[*data], + unmarshal: structUnmarshaller[data], + } + + s.Run("SetGet", func() { + var tests = []*testRow[*data]{ + {v, "Empty", "", nil, wantedError(io.EOF)}, + {v, "One", "012345678901", &data{0x2301, 0x01896745}, assert.NoError}, + {v, "Many", "12345678", nil, wantedError(io.ErrUnexpectedEOF)}, + } + + runTests(s.T(), tests, func(t *testing.T, row *testRow[*data], env *testEnv[*data]) { + if row.wanted == nil { + if !assert.NoError(t, env.setupVarFile(row)) { + return + } + } else { + if !env.testSet(row) { + return + } + } + env.testGet(row) + }) + }) +} - _, _, err = Boot(15).Get(c) - assert.NoError(t, err) +func TestVariablesTestSuite(t *testing.T) { + suite.Run(t, &VariableTestSuite{}) }