Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve signal handling - Make the state machine context aware #231

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions cmd/ubuntu-image/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"fmt"
"io"
"os"
"os/signal"
"syscall"

"github.com/jessevdk/go-flags"

Expand Down Expand Up @@ -191,6 +193,16 @@ func main() { //nolint: gocyclo
imageType = parser.Command.Active.Name
}

// Properly handle signals to execute defered functions and make sure
// mounted dir are unmounted
ch := make(chan os.Signal, 2)
signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM)

go func() {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you already have plans for the signal handler? Like, if the cleanup process blocks or takes too long for some reason (like trying to call umount on a problematic mounting point (I don't know, imagine a bad disk or NFS mount that got stuck)) it should be possible to still interrupt the tool (you could kill -9 it anyway).

If you ctrl+c again while the signal handler is running nothing will happen because the signal is being caught. You could handle subsequent ctrl+c's and inform the user that the tool is shutting down but if they want to stop it anyway they can hit ctrl+c again.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes! That is the plan in the following PR #235 (still very much WIP)

<-ch
osExit(1)
}()

// init the state machine
sm, err := initStateMachine(imageType, commonOpts, stateMachineOpts, ubuntuImageCommand)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions internal/statemachine/classic.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (classicStateMachine *ClassicStateMachine) Setup() error {
// set the parent pointer of the embedded struct
classicStateMachine.parent = classicStateMachine

classicStateMachine.states = make([]stateFunc, 0)
classicStateMachine.stateFuncs = make([]stateFunc, 0)

if err := classicStateMachine.setConfDefDir(classicStateMachine.parent.(*ClassicStateMachine).Args.ImageDefinition); err != nil {
return err
Expand Down Expand Up @@ -379,7 +379,7 @@ func (s *StateMachine) calculateStates() error {
s.addArtifactsStates(c, &rootfsCreationStates)

// Append the newly calculated states to the slice of funcs in the parent struct
s.states = append(s.states, rootfsCreationStates...)
s.stateFuncs = append(s.stateFuncs, rootfsCreationStates...)

return nil
}
Expand Down
2 changes: 1 addition & 1 deletion internal/statemachine/classic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ func TestClassicStateMachine_calculateStates(t *testing.T) {
asserter.AssertErrNil(err, true)

stateNames := make([]string, 0)
for _, f := range stateMachine.states {
for _, f := range stateMachine.stateFuncs {
stateNames = append(stateNames, f.name)
}

Expand Down
2 changes: 1 addition & 1 deletion internal/statemachine/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func (stateMachine *StateMachine) validateUntilThru() error {
}

if searchState != "" {
for _, state := range stateMachine.states {
for _, state := range stateMachine.stateFuncs {
if state.name == searchState {
stateFound = true
break
Expand Down
2 changes: 1 addition & 1 deletion internal/statemachine/pack.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (packStateMachine *PackStateMachine) Setup() error {
packStateMachine.parent = packStateMachine

// set the beginning states that will be used by all pack image builds
packStateMachine.states = packStates
packStateMachine.stateFuncs = packStates

// do the validation common to all image types
if err := packStateMachine.validateInput(); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/statemachine/snap.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (snapStateMachine *SnapStateMachine) Setup() error {
snapStateMachine.parent = snapStateMachine

// set the states that will be used for this image type
snapStateMachine.states = snapStates
snapStateMachine.stateFuncs = snapStates

if err := snapStateMachine.setConfDefDir(snapStateMachine.parent.(*SnapStateMachine).Args.ModelAssertion); err != nil {
return err
Expand Down
14 changes: 7 additions & 7 deletions internal/statemachine/state_machine.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ type StateMachine struct {
commonFlags *commands.CommonOpts
stateMachineFlags *commands.StateMachineOpts

states []stateFunc // the state functions
stateFuncs []stateFunc // the state functions

// used to access image type specific variables from state functions
parent SmInterface
Expand Down Expand Up @@ -442,12 +442,12 @@ func (stateMachine *StateMachine) readMetadata(metadataFile string) error {
func (stateMachine *StateMachine) loadState(partialStateMachine *StateMachine) error {
stateMachine.StepsTaken = partialStateMachine.StepsTaken

if stateMachine.StepsTaken > len(stateMachine.states) {
return fmt.Errorf("invalid steps taken count (%d). The state machine only have %d steps", stateMachine.StepsTaken, len(stateMachine.states))
if stateMachine.StepsTaken > len(stateMachine.stateFuncs) {
return fmt.Errorf("invalid steps taken count (%d). The state machine only have %d steps", stateMachine.StepsTaken, len(stateMachine.stateFuncs))
}

// delete all of the stateFuncs that have already run
stateMachine.states = stateMachine.states[stateMachine.StepsTaken:]
stateMachine.stateFuncs = stateMachine.stateFuncs[stateMachine.StepsTaken:]

stateMachine.CurrentStep = partialStateMachine.CurrentStep
stateMachine.YamlFilePath = partialStateMachine.YamlFilePath
Expand Down Expand Up @@ -509,7 +509,7 @@ func (s *StateMachine) displayStates() {
}
fmt.Printf("\nFollowing states %s be executed:\n", verb)

for i, state := range s.states {
for i, state := range s.stateFuncs {
if state.name == s.stateMachineFlags.Until {
break
}
Expand Down Expand Up @@ -609,8 +609,8 @@ func (stateMachine *StateMachine) Run() error {
return nil
}
// iterate through the states
for i := 0; i < len(stateMachine.states); i++ {
stateFunc := stateMachine.states[i]
for i := 0; i < len(stateMachine.stateFuncs); i++ {
stateFunc := stateMachine.stateFuncs[i]
stateMachine.CurrentStep = stateFunc.name
if stateFunc.name == stateMachine.stateMachineFlags.Until {
break
Expand Down
30 changes: 15 additions & 15 deletions internal/statemachine/state_machine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ type testStateMachine struct {
// testStateMachine needs its own setup
func (TestStateMachine *testStateMachine) Setup() error {
// set the states that will be used for this image type
TestStateMachine.states = allTestStates
TestStateMachine.stateFuncs = allTestStates

// do the validation common to all image types
if err := TestStateMachine.validateInput(); err != nil {
Expand Down Expand Up @@ -340,7 +340,7 @@ func TestDebug(t *testing.T) {
asserter.AssertErrNil(err, true)

// just use the one state
stateMachine.states = testStates
stateMachine.stateFuncs = testStates
stdout, restoreStdout, err := helper.CaptureStd(&os.Stdout)
asserter.AssertErrNil(err, true)

Expand All @@ -352,8 +352,8 @@ func TestDebug(t *testing.T) {
readStdout, err := io.ReadAll(stdout)
asserter.AssertErrNil(err, true)

if !strings.Contains(string(readStdout), stateMachine.states[0].name) {
t.Errorf("Expected state name \"%s\" to appear in output \"%s\"\n", stateMachine.states[0].name, string(readStdout))
if !strings.Contains(string(readStdout), stateMachine.stateFuncs[0].name) {
t.Errorf("Expected state name \"%s\" to appear in output \"%s\"\n", stateMachine.stateFuncs[0].name, string(readStdout))
}
}

Expand All @@ -376,7 +376,7 @@ func TestDryRun(t *testing.T) {
asserter.AssertErrNil(err, true)

// just use the one state
stateMachine.states = testStates
stateMachine.stateFuncs = testStates
stdout, restoreStdout, err := helper.CaptureStd(&os.Stdout)
asserter.AssertErrNil(err, true)

Expand All @@ -387,8 +387,8 @@ func TestDryRun(t *testing.T) {
readStdout, err := io.ReadAll(stdout)
asserter.AssertErrNil(err, true)

if strings.Contains(string(readStdout), stateMachine.states[0].name) {
t.Errorf("Expected state name \"%s\" to not appear in output \"%s\"\n", stateMachine.states[0].name, string(readStdout))
if strings.Contains(string(readStdout), stateMachine.stateFuncs[0].name) {
t.Errorf("Expected state name \"%s\" to not appear in output \"%s\"\n", stateMachine.stateFuncs[0].name, string(readStdout))
}
}

Expand Down Expand Up @@ -421,10 +421,10 @@ func TestFunctionErrors(t *testing.T) {
asserter.AssertErrNil(err, true)

// override the function, but save the old one
oldStateFunc := stateMachine.states[tc.overrideState]
stateMachine.states[tc.overrideState] = tc.newStateFunc
oldStateFunc := stateMachine.stateFuncs[tc.overrideState]
stateMachine.stateFuncs[tc.overrideState] = tc.newStateFunc
defer func() {
stateMachine.states[tc.overrideState] = oldStateFunc
stateMachine.stateFuncs[tc.overrideState] = oldStateFunc
}()
if err := stateMachine.Run(); err == nil {
if err := stateMachine.Teardown(); err == nil {
Expand Down Expand Up @@ -1159,7 +1159,7 @@ func TestStateMachine_readMetadata(t *testing.T) {
IsSeeded: true,
SectorSize: quantity.Size(512),
RootfsSize: quantity.Size(775915520),
states: allTestStates[2:],
stateFuncs: allTestStates[2:],
GadgetInfo: &gadget.Info{
Volumes: map[string]*gadget.Volume{
"pc": {
Expand Down Expand Up @@ -1246,7 +1246,7 @@ func TestStateMachine_readMetadata(t *testing.T) {
Resume: false,
WorkDir: filepath.Join(testDataDir, "metadata"),
},
states: allTestStates,
stateFuncs: allTestStates,
},
shouldPass: true,
expectedError: "error reading metadata file",
Expand All @@ -1270,7 +1270,7 @@ func TestStateMachine_readMetadata(t *testing.T) {
Resume: tc.args.resume,
WorkDir: filepath.Join(testDataDir, "metadata"),
},
states: allTestStates,
stateFuncs: allTestStates,
}

err := gotStateMachine.readMetadata(tc.args.metadataFile)
Expand Down Expand Up @@ -1304,7 +1304,7 @@ func TestStateMachine_writeMetadata(t *testing.T) {
IsSeeded: true,
SectorSize: quantity.Size(512),
RootfsSize: quantity.Size(775915520),
states: allTestStates[2:],
stateFuncs: allTestStates[2:],
GadgetInfo: &gadget.Info{
Volumes: map[string]*gadget.Volume{
"pc": {
Expand Down Expand Up @@ -1597,7 +1597,7 @@ Continuing
s := &StateMachine{
commonFlags: tt.fields.commonFlags,
stateMachineFlags: tt.fields.stateMachineFlags,
states: tt.fields.states,
stateFuncs: tt.fields.states,
}
s.displayStates()

Expand Down