diff --git a/command.go b/command.go index 42e500de..a6195716 100644 --- a/command.go +++ b/command.go @@ -17,6 +17,7 @@ package cobra import ( "bytes" + "context" "errors" "fmt" "io" @@ -143,9 +144,11 @@ type Command struct { // TraverseChildren parses flags on all parents before executing child command. TraverseChildren bool - //FParseErrWhitelist flag parse errors to be ignored + // FParseErrWhitelist flag parse errors to be ignored FParseErrWhitelist FParseErrWhitelist + ctx context.Context + // commands is the list of commands supported by this program. commands []*Command // parent is a parent command for this command. @@ -205,6 +208,12 @@ type Command struct { errWriter io.Writer } +// Context returns underlying command context. If command wasn't +// executed with ExecuteContext the returned context will be nil. +func (c *Command) Context() context.Context { + return c.ctx +} + // SetArgs sets arguments for the command. It is set to os.Args[1:] by default, if desired, can be overridden // particularly useful when testing. func (c *Command) SetArgs(a []string) { @@ -860,6 +869,13 @@ func (c *Command) preRun() { } } +// ExecuteContext is the same as Execute(), but sets the ctx on the command. +// Retrieve ctx by calling cmd.Context() inside your *Run lifecycle functions. +func (c *Command) ExecuteContext(ctx context.Context) error { + c.ctx = ctx + return c.Execute() +} + // Execute uses the args (os.Args[1:] by default) // and run through the command tree finding appropriate matches // for commands and then corresponding flags. @@ -914,6 +930,12 @@ func (c *Command) ExecuteC() (cmd *Command, err error) { cmd.commandCalledAs.name = cmd.Name() } + // We have to pass global context to children command + // if context is present on the parent command. + if cmd.ctx == nil { + cmd.ctx = c.ctx + } + err = cmd.execute(flags) if err != nil { // Always show help if requested, even if SilenceErrors is in @@ -1558,7 +1580,7 @@ func (c *Command) ParseFlags(args []string) error { beforeErrorBufLen := c.flagErrorBuf.Len() c.mergePersistentFlags() - //do it here after merging all flags and just before parse + // do it here after merging all flags and just before parse c.Flags().ParseErrorsWhitelist = flag.ParseErrorsWhitelist(c.FParseErrWhitelist) err := c.Flags().Parse(args) diff --git a/command_test.go b/command_test.go index b26bd4ab..2d9b5847 100644 --- a/command_test.go +++ b/command_test.go @@ -2,6 +2,7 @@ package cobra import ( "bytes" + "context" "fmt" "os" "reflect" @@ -18,6 +19,16 @@ func executeCommand(root *Command, args ...string) (output string, err error) { return output, err } +func executeCommandWithContext(ctx context.Context, root *Command, args ...string) (output string, err error) { + buf := new(bytes.Buffer) + root.SetOutput(buf) + root.SetArgs(args) + + err = root.ExecuteContext(ctx) + + return buf.String(), err +} + func executeCommandC(root *Command, args ...string) (c *Command, output string, err error) { buf := new(bytes.Buffer) root.SetOutput(buf) @@ -135,6 +146,35 @@ func TestSubcommandExecuteC(t *testing.T) { } } +func TestExecuteContext(t *testing.T) { + ctx := context.Background() + + ctxRun := func(cmd *Command, args []string) { + if cmd.Context() != ctx { + t.Errorf("Command %q must have context when called with ExecuteContext", cmd.Use) + } + } + + rootCmd := &Command{Use: "root", Run: ctxRun, PreRun: ctxRun} + childCmd := &Command{Use: "child", Run: ctxRun, PreRun: ctxRun} + granchildCmd := &Command{Use: "grandchild", Run: ctxRun, PreRun: ctxRun} + + childCmd.AddCommand(granchildCmd) + rootCmd.AddCommand(childCmd) + + if _, err := executeCommandWithContext(ctx, rootCmd, ""); err != nil { + t.Errorf("Root command must not fail: %+v", err) + } + + if _, err := executeCommandWithContext(ctx, rootCmd, "child"); err != nil { + t.Errorf("Subcommand must not fail: %+v", err) + } + + if _, err := executeCommandWithContext(ctx, rootCmd, "child", "grandchild"); err != nil { + t.Errorf("Command child must not fail: %+v", err) + } +} + func TestRootUnknownCommandSilenced(t *testing.T) { rootCmd := &Command{Use: "root", Run: emptyRun} rootCmd.SilenceErrors = true