diff --git a/cobra.go b/cobra.go index d6cbfd71..651d3843 100644 --- a/cobra.go +++ b/cobra.go @@ -40,6 +40,8 @@ var templateFuncs = template.FuncMap{ var initializers []func() +var initializersE []func() error + // EnablePrefixMatching allows to set automatic prefix matching. Automatic prefix matching can be a dangerous thing // to automatically enable in CLI tools. // Set this to true to enable it. @@ -84,6 +86,12 @@ func OnInitialize(y ...func()) { initializers = append(initializers, y...) } +// OnInitializeE sets the passed functions to be run when each command's +// Execute method is called. +func OnInitializeE(y ...func() error) { + initializersE = append(initializersE, y...) +} + // FIXME Gt is unused by cobra and should be removed in a version 2. It exists only for compatibility with users of cobra. // Gt takes two types and checks whether the first type is greater than the second. In case of types Arrays, Chans, diff --git a/cobra_test.go b/cobra_test.go index 1219cc07..6419fd22 100644 --- a/cobra_test.go +++ b/cobra_test.go @@ -1,6 +1,7 @@ package cobra import ( + "errors" "testing" "text/template" ) @@ -26,3 +27,30 @@ func TestAddTemplateFunctions(t *testing.T) { t.Errorf("Expected UsageString: %v\nGot: %v", expected, got) } } + +func Test_OnInitialize(t *testing.T) { + call := false + c := &Command{Use: "c", Run: emptyRun} + OnInitialize(func() { + call = true + }) + _, err := executeCommand(c) + if err != nil { + t.Error(err) + } + if !call { + t.Error("expected OnInitialize func to be called") + } +} + +func Test_OnInitializeE(t *testing.T) { + c := &Command{Use: "c", Run: emptyRun} + e := errors.New("test error") + OnInitializeE(func() error { + return e + }) + _, err := executeCommand(c) + if err != e { + t.Error("expected error: %w", e) + } +} diff --git a/command.go b/command.go index 0f4511f3..3238afce 100644 --- a/command.go +++ b/command.go @@ -830,6 +830,10 @@ func (c *Command) execute(a []string) (err error) { return flag.ErrHelp } + if err := c.preRunE(); err != nil { + return err + } + c.preRun() argWoFlags := c.Flags().Args() @@ -902,6 +906,17 @@ func (c *Command) preRun() { } } +func (c *Command) preRunE() error { + for _, x := range initializersE { + err := x() + if err != nil { + return err + } + } + + return nil +} + // ExecuteContext is the same as Execute(), but sets the ctx on the command. // Retrieve ctx by calling cmd.Context() inside your *Run lifecycle or ValidArgs // functions.