Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework grub setup #246

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
4 changes: 4 additions & 0 deletions cmd/ubuntu-image/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"github.com/jessevdk/go-flags"

"github.com/canonical/ubuntu-image/internal/arch"
"github.com/canonical/ubuntu-image/internal/commands"
"github.com/canonical/ubuntu-image/internal/helper"
"github.com/canonical/ubuntu-image/internal/statemachine"
Expand Down Expand Up @@ -38,6 +39,9 @@ func initStateMachine(imageType string, commonOpts *commands.CommonOpts, stateMa
Args: ubuntuImageCommand.Classic.ClassicArgsPassed,
}
case "pack":
if ubuntuImageCommand.Pack.PackOptsPassed.Architecture == "" {
ubuntuImageCommand.Pack.PackOptsPassed.Architecture = arch.GetHostArch()
}
stateMachine = &statemachine.PackStateMachine{
Opts: ubuntuImageCommand.Pack.PackOptsPassed,
}
Expand Down
40 changes: 40 additions & 0 deletions cmd/ubuntu-image/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ func (mockSM *mockedStateMachine) SetSeries() error {
return nil
}

func (mockSM *mockedStateMachine) Architecture() (string, error) {
return "", nil
}

// TestValidCommands tests that certain valid commands are parsed correctly
func TestValidCommands(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -412,6 +416,42 @@ func Test_initStateMachine(t *testing.T) {
Args: commands.ClassicArgs{},
},
},
{
name: "init a pack state machine, no arch",
args: args{
imageType: "pack",
commonOpts: &commands.CommonOpts{},
stateMachineOpts: &commands.StateMachineOpts{},
ubuntuImageCommand: &commands.UbuntuImageCommand{
Pack: commands.PackCommand{
PackOptsPassed: commands.PackOpts{},
},
},
},
want: &statemachine.PackStateMachine{
StateMachine: statemachine.StateMachine{},
Opts: commands.PackOpts{Architecture: "amd64"},
},
},
{
name: "init a pack state machine, with arch",
args: args{
imageType: "pack",
commonOpts: &commands.CommonOpts{},
stateMachineOpts: &commands.StateMachineOpts{},
ubuntuImageCommand: &commands.UbuntuImageCommand{
Pack: commands.PackCommand{
PackOptsPassed: commands.PackOpts{
Architecture: "arm64",
},
},
},
},
want: &statemachine.PackStateMachine{
StateMachine: statemachine.StateMachine{},
Opts: commands.PackOpts{Architecture: "arm64"},
},
},
{
name: "fail to init an unknown statemachine",
args: args{
Expand Down
23 changes: 23 additions & 0 deletions internal/arch/arch.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package arch

import (
"os/exec"
"strings"
)

const (
AMD64 = "amd64"
ARM64 = "arm64"
ARMHF = "armhf"
I386 = "i386"
PPC64EL = "ppc64el"
S390X = "s390x"
RISCV64 = "riscv64"
)

// GetHostArch uses dpkg to return the host architecture of the current system
func GetHostArch() string {
cmd := exec.Command("dpkg", "--print-architecture")
outputBytes, _ := cmd.Output() // nolint: errcheck
return strings.TrimSpace(string(outputBytes))
}
34 changes: 34 additions & 0 deletions internal/arch/arch_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package arch

import (
"runtime"
"testing"
)

// TestGetHostArch unit tests the getHostArch function
func TestGetHostArch(t *testing.T) {
t.Parallel()

var expected string
switch runtime.GOARCH {
case "amd64":
expected = "amd64"
case "arm":
expected = "armhf"
case "arm64":
expected = "arm64"
case "ppc64le":
expected = "ppc64el"
case "s390x":
expected = "s390x"
case "riscv64":
expected = "riscv64"
default:
t.Skipf("Test not supported on architecture %s", runtime.GOARCH)
}

hostArch := GetHostArch()
if hostArch != expected {
t.Errorf("Wrong value of getHostArch. Expected %s, got %s", expected, hostArch)
}
}
1 change: 1 addition & 0 deletions internal/commands/pack.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ type PackOpts struct {
ArtifactType string `long:"artifact-type" description:"Type of the resulting disk image file." required:"true" default:"raw" choice:"raw"`
GadgetDir string `long:"gadget-dir" description:"Directory containing the gadget tree. The gadget.yaml file is expected to be in a meta subdirectory." required:"true"`
RootfsDir string `long:"rootfs-dir" description:"Directory containing the rootfs" required:"true"`
Architecture string `long:"architecture" description:"CPU architecture of the image. Default to the host architecture executing the tool." required:"false"`
}

type PackCommand struct {
Expand Down
4 changes: 3 additions & 1 deletion internal/imagedefinition/image_definition.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"strings"

"github.com/xeipuuv/gojsonschema"

"github.com/canonical/ubuntu-image/internal/arch"
)

// ImageDefinition is the parent struct for the data
Expand Down Expand Up @@ -309,7 +311,7 @@ type DependentKeyError struct {
}

func (i ImageDefinition) securityMirror() string {
if i.Architecture == "amd64" || i.Architecture == "i386" {
if i.Architecture == arch.AMD64 || i.Architecture == arch.I386 {
return "http://security.ubuntu.com/ubuntu/"
}
return i.Rootfs.Mirror
Expand Down
3 changes: 2 additions & 1 deletion internal/imagedefinition/image_definition_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/xeipuuv/gojsonschema"

"github.com/canonical/ubuntu-image/internal/arch"
"github.com/canonical/ubuntu-image/internal/helper"
)

Expand Down Expand Up @@ -307,7 +308,7 @@ func TestImageDefinition_securityMirror(t *testing.T) {
{
name: "amd64",
fields: fields{
Architecture: "amd64",
Architecture: arch.AMD64,
Rootfs: &Rootfs{
Mirror: "http://archive.ubuntu.com/ubuntu/",
},
Expand Down
22 changes: 16 additions & 6 deletions internal/partition/partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,30 @@ func newPartitionTable(volume *gadget.Volume, sectorSize uint64, imgSize uint64)

// GeneratePartitionTable prepares the partition table for structures in a volume and
// returns it with the partition number of the root partition.
func GeneratePartitionTable(volume *gadget.Volume, sectorSize uint64, imgSize uint64, isSeeded bool) (partition.Table, int, error) {
partitionNumber, rootfsPartitionNumber := 1, -1
func GeneratePartitionTable(volume *gadget.Volume, sectorSize uint64, imgSize uint64, isSeeded bool) (partition.Table, int, int, error) {
partitionNumber, rootfsPartitionNumber, bootPartitionNumber := 1, -1, -1
partitionTable := newPartitionTable(volume, sectorSize, imgSize)
onDisk := gadget.OnDiskStructsFromGadget(volume)

for i := range volume.Structure {
structure := &volume.Structure[i]
if !structure.IsPartition() || helper.ShouldSkipStructure(structure, isSeeded) {
if !structure.IsPartition() {
continue
}

// Record the actual partition number of the boot partition, as it
// might be useful for certain operations (like updating the bootloader)
if helper.IsSystemBootStructure(structure) {
bootPartitionNumber = partitionNumber
}

if helper.ShouldSkipStructure(structure, isSeeded) {
continue
}

// Record the actual partition number of the root partition, as it
// might be useful for certain operations (like updating the bootloader)
if helper.IsRootfsStructure(structure) { //nolint:gosec,G301
if helper.IsRootfsStructure(structure) {
rootfsPartitionNumber = partitionNumber
}

Expand All @@ -85,13 +95,13 @@ func GeneratePartitionTable(volume *gadget.Volume, sectorSize uint64, imgSize ui
structureType := getStructureType(structure, volume.Schema)
err := partitionTable.AddPartition(structurePair, structureType)
if err != nil {
return nil, rootfsPartitionNumber, err
return nil, rootfsPartitionNumber, bootPartitionNumber, err
}

partitionNumber++
}

return partitionTable.GetConcreteTable(), rootfsPartitionNumber, nil
return partitionTable.GetConcreteTable(), rootfsPartitionNumber, bootPartitionNumber, nil
}

// getStructureType extracts the structure type from the structure.Type considering
Expand Down
9 changes: 8 additions & 1 deletion internal/partition/partition_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ func TestGeneratePartitionTable(t *testing.T) {
args args
wantPartitionTable partition.Table
wantRootfsPartNumber int
wantBootPartNumber int
expectedError string
}{
{
Expand All @@ -178,6 +179,7 @@ func TestGeneratePartitionTable(t *testing.T) {
imgSize: uint64(4 * quantity.SizeKiB),
},
wantRootfsPartNumber: 2,
wantBootPartNumber: -1,
wantPartitionTable: &gpt.Table{
LogicalSectorSize: int(sectorSize512),
PhysicalSectorSize: int(sectorSize512),
Expand Down Expand Up @@ -206,6 +208,7 @@ func TestGeneratePartitionTable(t *testing.T) {
imgSize: uint64(4 * quantity.SizeKiB),
},
wantRootfsPartNumber: 2,
wantBootPartNumber: -1,
wantPartitionTable: &gpt.Table{
LogicalSectorSize: int(sectorSize4k),
PhysicalSectorSize: int(sectorSize4k),
Expand Down Expand Up @@ -234,6 +237,7 @@ func TestGeneratePartitionTable(t *testing.T) {
imgSize: uint64(4 * quantity.SizeKiB),
},
wantRootfsPartNumber: 2,
wantBootPartNumber: -1,
wantPartitionTable: &gpt.Table{
LogicalSectorSize: int(sectorSize512),
PhysicalSectorSize: int(sectorSize512),
Expand Down Expand Up @@ -262,6 +266,7 @@ func TestGeneratePartitionTable(t *testing.T) {
imgSize: uint64(4 * quantity.SizeKiB),
},
wantRootfsPartNumber: 2,
wantBootPartNumber: -1,
expectedError: `The structure "writable" overlaps GPT header or GPT partition table`,
},
{
Expand All @@ -272,6 +277,7 @@ func TestGeneratePartitionTable(t *testing.T) {
imgSize: uint64(4 * quantity.SizeKiB),
},
wantRootfsPartNumber: -1,
wantBootPartNumber: -1,
wantPartitionTable: &mbr.Table{
Partitions: []*mbr.Partition{
{
Expand All @@ -288,11 +294,12 @@ func TestGeneratePartitionTable(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
asserter := &helper.Asserter{T: t}
gotPartitionTable, gotRootfsPartNumber, gotErr := GeneratePartitionTable(tt.args.volume, tt.args.sectorSize, tt.args.imgSize, tt.args.isSeeded)
gotPartitionTable, gotRootfsPartNumber, gotBootPartNumber, gotErr := GeneratePartitionTable(tt.args.volume, tt.args.sectorSize, tt.args.imgSize, tt.args.isSeeded)

if len(tt.expectedError) == 0 {
asserter.AssertErrNil(gotErr, true)
asserter.AssertEqual(tt.wantRootfsPartNumber, gotRootfsPartNumber)
asserter.AssertEqual(tt.wantBootPartNumber, gotBootPartNumber)
asserter.AssertEqual(tt.wantPartitionTable, gotPartitionTable)
} else {
asserter.AssertErrContains(gotErr, tt.expectedError)
Expand Down
Loading