diff --git a/interp/builtin.go b/interp/builtin.go index fcfcea94..0605f8f7 100644 --- a/interp/builtin.go +++ b/interp/builtin.go @@ -15,7 +15,6 @@ import ( "path/filepath" "strconv" "strings" - "sync" "syscall" "github.com/muesli/cancelreader" @@ -954,25 +953,21 @@ func (r *Runner) readLine(ctx context.Context, raw bool) ([]byte, error) { // but still fail on other errors, which may be unexpected or hide bugs. // See the upstream issue: https://github.com/muesli/cancelreader/issues/23 if cr, err := cancelreader.NewReader(r.stdin); err == nil { - done := make(chan struct{}) - var wg sync.WaitGroup - wg.Add(1) - go func() { - select { - case <-ctx.Done(): - cr.Cancel() - case <-done: - } - wg.Done() - }() + stopc := make(chan struct{}) + stop := context.AfterFunc(ctx, func() { + cr.Cancel() + close(stopc) + }) defer func() { - close(done) - wg.Wait() - // Could put the Close in the above goroutine, but if "read" is - // immediately called again, the Close might overlap with creating a - // new cancelreader. Want this cancelreader to be completely closed - // by the time readLine returns. - cr.Close() + if !stop() { + // The AfterFunc was started; wait for it to complete and close the cancel reader. + // Could put the Close in the above goroutine, but if "read" is + // immediately called again, the Close might overlap with creating a + // new cancelreader. Want this cancelreader to be completely closed + // by the time readLine returns. + <-stopc + cr.Close() + } }() stdin = cr } diff --git a/interp/handler.go b/interp/handler.go index c617328e..edac099a 100644 --- a/interp/handler.go +++ b/interp/handler.go @@ -109,25 +109,18 @@ func DefaultExecHandler(killTimeout time.Duration) ExecHandlerFunc { err = cmd.Start() if err == nil { - if done := ctx.Done(); done != nil { - go func() { - <-done - - if killTimeout <= 0 || runtime.GOOS == "windows" { - _ = cmd.Process.Signal(os.Kill) - return - } - - // TODO: don't temporarily leak this goroutine - // if the program stops itself with the - // interrupt. - go func() { - time.Sleep(killTimeout) - _ = cmd.Process.Signal(os.Kill) - }() - _ = cmd.Process.Signal(os.Interrupt) - }() - } + stopf := context.AfterFunc(ctx, func() { + if killTimeout <= 0 || runtime.GOOS == "windows" { + _ = cmd.Process.Signal(os.Kill) + return + } + _ = cmd.Process.Signal(os.Interrupt) + // TODO: don't sleep in this goroutine if the program + // stops itself with the interrupt above. + time.Sleep(killTimeout) + _ = cmd.Process.Signal(os.Kill) + }) + defer stopf() err = cmd.Wait() }