From 245d68f8ea72e32600359f30bf525dea07817dac Mon Sep 17 00:00:00 2001 From: Bart de Boer Date: Mon, 22 Jun 2020 17:13:49 +0200 Subject: [PATCH] Extend Persistent*Run behavior to allow multiple hooks throughout the execution chain --- cobra.go | 3 + command.go | 174 ++++++++++++++++++++++++++++++++++++++---------- command_test.go | 126 +++++++++++++++++++++++++++++++---- 3 files changed, 254 insertions(+), 49 deletions(-) diff --git a/cobra.go b/cobra.go index d01becc8..45f679e1 100644 --- a/cobra.go +++ b/cobra.go @@ -39,6 +39,9 @@ var templateFuncs = template.FuncMap{ var initializers []func() +// EnablePersistentRunOverride ensures Persistent*Run* functions in childs override their parents +var EnablePersistentRunOverride = true + // EnablePrefixMatching allows to set automatic prefix matching. Automatic prefix matching can be a dangerous thing // to automatically enable in CLI tools. // Set this to true to enable it. diff --git a/command.go b/command.go index 5f1caccc..6f1653c5 100644 --- a/command.go +++ b/command.go @@ -118,6 +118,17 @@ type Command struct { // PersistentPostRunE: PersistentPostRun but returns an error. PersistentPostRunE func(cmd *Command, args []string) error + // persistentPreRunHooks are executed before the command or one of its children are executed + persistentPreRunHooks []func(cmd *Command, args []string) error + // preRunHooks are executed before the command is executed + preRunHooks []func(cmd *Command, args []string) error + // runHooks are executed when the command is executed + runHooks []func(cmd *Command, args []string) error + // postRunHooks are executed after the command has executed + postRunHooks []func(cmd *Command, args []string) error + // persistentPostRunHooks are executed after the command or one of its children have executed + persistentPostRunHooks []func(cmd *Command, args []string) error + // SilenceErrors is an option to quiet errors down stream. SilenceErrors bool @@ -816,52 +827,104 @@ func (c *Command) execute(a []string) (err error) { return err } - for p := c; p != nil; p = p.Parent() { - if p.PersistentPreRunE != nil { - if err := p.PersistentPreRunE(c, argWoFlags); err != nil { - return err - } - break - } else if p.PersistentPreRun != nil { - p.PersistentPreRun(c, argWoFlags) - break - } - } + var persistentPreRunHooks []func(cmd *Command, args []string) error + preRunHooks := c.preRunHooks + runHooks := c.runHooks + postRunHooks := c.postRunHooks + var persistentPostRunHooks []func(cmd *Command, args []string) error + + // Merge the PreRun functions into the preRunHooks slice if c.PreRunE != nil { - if err := c.PreRunE(c, argWoFlags); err != nil { - return err - } + preRunHooks = append(preRunHooks, c.PreRunE) } else if c.PreRun != nil { - c.PreRun(c, argWoFlags) + preRunHook := c.PreRun + preRunHooks = append(preRunHooks, func(cmd *Command, args []string) error { + preRunHook(cmd, args) + return nil + }) } + // Merge the Run functions into the runHooks slice + if c.RunE != nil { + runHooks = append(runHooks, c.RunE) + } else if c.Run != nil { + runHook := c.Run + runHooks = append(runHooks, func(cmd *Command, args []string) error { + runHook(cmd, args) + return nil + }) + } + + // Merge the PostRun functions into the runHooks slice + if c.PostRunE != nil { + postRunHooks = append(postRunHooks, c.PostRunE) + } else if c.PostRun != nil { + postRunHook := c.PostRun + postRunHooks = append(postRunHooks, func(cmd *Command, args []string) error { + postRunHook(cmd, args) + return nil + }) + } + + // Find and merge the Persistent*Run functions into the persistent*Run slices. + // If EnablePersistentRunOverride is set Persistent*Run from childs will override their parents. + // Any hooks registered through OnPersistent*Run will always be executed and cannot be overriden. + hasLegacyPersistentPreRun := false + hasLegacyPersistentPostRun := false + for p := c; p != nil; p = p.Parent() { + if !hasLegacyPersistentPreRun || !EnablePersistentRunOverride { + if p.PersistentPreRunE != nil { + persistentPreRunHooks = append([]func(cmd *Command, args []string) error{ + p.PersistentPreRunE, + }, persistentPreRunHooks...) + hasLegacyPersistentPreRun = true + } else if p.PersistentPreRun != nil { + persistentPreRunHook := p.PersistentPreRun + persistentPreRunHooks = append([]func(cmd *Command, args []string) error{ + func(cmd *Command, args []string) error { + persistentPreRunHook(cmd, args) + return nil + }, + }, persistentPreRunHooks...) + hasLegacyPersistentPreRun = true + } + } + if !hasLegacyPersistentPostRun || !EnablePersistentRunOverride { + if p.PersistentPostRunE != nil { + persistentPostRunHooks = append(persistentPostRunHooks, p.PersistentPostRunE) + hasLegacyPersistentPostRun = true + } else if p.PersistentPostRun != nil { + persistentPostRunHook := p.PersistentPostRun + persistentPostRunHooks = append(persistentPostRunHooks, func(cmd *Command, args []string) error { + persistentPostRunHook(cmd, args) + return nil + }) + hasLegacyPersistentPostRun = true + } + } + + persistentPreRunHooks = append(p.persistentPreRunHooks, persistentPreRunHooks...) + persistentPostRunHooks = append(persistentPostRunHooks, p.persistentPostRunHooks...) + } + + // Execute the hooks: + if err := c.executeHooks(&persistentPreRunHooks, argWoFlags); err != nil { + return err + } + if err := c.executeHooks(&preRunHooks, argWoFlags); err != nil { + return err + } if err := c.validateRequiredFlags(); err != nil { return err } - if c.RunE != nil { - if err := c.RunE(c, argWoFlags); err != nil { - return err - } - } else { - c.Run(c, argWoFlags) + if err := c.executeHooks(&runHooks, argWoFlags); err != nil { + return err } - if c.PostRunE != nil { - if err := c.PostRunE(c, argWoFlags); err != nil { - return err - } - } else if c.PostRun != nil { - c.PostRun(c, argWoFlags) + if err := c.executeHooks(&postRunHooks, argWoFlags); err != nil { + return err } - for p := c; p != nil; p = p.Parent() { - if p.PersistentPostRunE != nil { - if err := p.PersistentPostRunE(c, argWoFlags); err != nil { - return err - } - break - } else if p.PersistentPostRun != nil { - p.PersistentPostRun(c, argWoFlags) - break - } + if err := c.executeHooks(&persistentPostRunHooks, argWoFlags); err != nil { + return err } return nil @@ -873,6 +936,43 @@ func (c *Command) preRun() { } } +// executeHooks executes a slice of hooks +func (c *Command) executeHooks(hooks *[]func(cmd *Command, args []string) error, args []string) error { + for _, x := range *hooks { + if err := x(c, args); err != nil { + return err + } + } + return nil +} + +// OnPersistentPreRun registers one or more hooks on the command to be executed +// before the command or one of its children are executed +func (c *Command) OnPersistentPreRun(f ...func(cmd *Command, args []string) error) { + c.persistentPreRunHooks = append(c.persistentPreRunHooks, f...) +} + +// OnPreRun registers one or more hooks on the command to be executed before the command is executed +func (c *Command) OnPreRun(f ...func(cmd *Command, args []string) error) { + c.preRunHooks = append(c.preRunHooks, f...) +} + +// OnRun registers one or more hooks on the command to be executed when the command is executed +func (c *Command) OnRun(f ...func(cmd *Command, args []string) error) { + c.runHooks = append(c.runHooks, f...) +} + +// OnPostRun registers one or more hooks on the command to be executed after the command has executed +func (c *Command) OnPostRun(f ...func(cmd *Command, args []string) error) { + c.postRunHooks = append(c.postRunHooks, f...) +} + +// OnPersistentPostRun register one or more hooks on the command to be executed +// after the command or one of its children have executed +func (c *Command) OnPersistentPostRun(f ...func(cmd *Command, args []string) error) { + c.persistentPostRunHooks = append(c.persistentPostRunHooks, f...) +} + // ExecuteContext is the same as Execute(), but sets the ctx on the command. // Retrieve ctx by calling cmd.Context() inside your *Run lifecycle functions. func (c *Command) ExecuteContext(ctx context.Context) error { diff --git a/command_test.go b/command_test.go index 16cc41b4..71999b64 100644 --- a/command_test.go +++ b/command_test.go @@ -1332,6 +1332,23 @@ func TestPersistentHooks(t *testing.T) { childPersPostArgs string ) + var ( + persParentPersPreArgs string + persParentPreArgs string + persParentRunArgs string + persParentPostArgs string + persParentPersPostArgs string + ) + + var ( + persChildPersPreArgs string + persChildPreArgs string + persChildPreArgs2 string + persChildRunArgs string + persChildPostArgs string + persChildPersPostArgs string + ) + parentCmd := &Command{ Use: "parent", PersistentPreRun: func(_ *Command, args []string) { @@ -1371,6 +1388,52 @@ func TestPersistentHooks(t *testing.T) { } parentCmd.AddCommand(childCmd) + parentCmd.OnPersistentPreRun(func(_ *Command, args []string) error { + persParentPersPreArgs = strings.Join(args, " ") + return nil + }) + parentCmd.OnPreRun(func(_ *Command, args []string) error { + persParentPreArgs = strings.Join(args, " ") + return nil + }) + parentCmd.OnRun(func(_ *Command, args []string) error { + persParentRunArgs = strings.Join(args, " ") + return nil + }) + parentCmd.OnPostRun(func(_ *Command, args []string) error { + persParentPostArgs = strings.Join(args, " ") + return nil + }) + parentCmd.OnPersistentPostRun(func(_ *Command, args []string) error { + persParentPersPostArgs = strings.Join(args, " ") + return nil + }) + + childCmd.OnPersistentPreRun(func(_ *Command, args []string) error { + persChildPersPreArgs = strings.Join(args, " ") + return nil + }) + childCmd.OnPreRun(func(_ *Command, args []string) error { + persChildPreArgs = strings.Join(args, " ") + return nil + }) + childCmd.OnPreRun(func(_ *Command, args []string) error { + persChildPreArgs2 = strings.Join(args, " ") + " three" + return nil + }) + childCmd.OnRun(func(_ *Command, args []string) error { + persChildRunArgs = strings.Join(args, " ") + return nil + }) + childCmd.OnPostRun(func(_ *Command, args []string) error { + persChildPostArgs = strings.Join(args, " ") + return nil + }) + childCmd.OnPersistentPostRun(func(_ *Command, args []string) error { + persChildPersPostArgs = strings.Join(args, " ") + return nil + }) + output, err := executeCommand(parentCmd, "child", "one", "two") if output != "" { t.Errorf("Unexpected output: %v", output) @@ -1378,14 +1441,12 @@ func TestPersistentHooks(t *testing.T) { if err != nil { t.Errorf("Unexpected error: %v", err) } - - // TODO: currently PersistenPreRun* defined in parent does not - // run if the matchin child subcommand has PersistenPreRun. - // If the behavior changes (https://github.com/spf13/cobra/issues/252) - // this test must be fixed. - if parentPersPreArgs != "" { + if EnablePersistentRunOverride && parentPersPreArgs != "" { t.Errorf("Expected blank parentPersPreArgs, got %q", parentPersPreArgs) } + if !EnablePersistentRunOverride && parentPersPreArgs != "one two" { + t.Errorf("Expected parentPersPreArgs %q, got %q", "one two", parentPersPreArgs) + } if parentPreArgs != "" { t.Errorf("Expected blank parentPreArgs, got %q", parentPreArgs) } @@ -1395,14 +1456,12 @@ func TestPersistentHooks(t *testing.T) { if parentPostArgs != "" { t.Errorf("Expected blank parentPostArgs, got %q", parentPostArgs) } - // TODO: currently PersistenPostRun* defined in parent does not - // run if the matchin child subcommand has PersistenPostRun. - // If the behavior changes (https://github.com/spf13/cobra/issues/252) - // this test must be fixed. - if parentPersPostArgs != "" { + if EnablePersistentRunOverride && parentPersPostArgs != "" { t.Errorf("Expected blank parentPersPostArgs, got %q", parentPersPostArgs) } - + if !EnablePersistentRunOverride && parentPersPostArgs != "one two" { + t.Errorf("Expected parentPersPostArgs %q, got %q", "one two", parentPersPostArgs) + } if childPersPreArgs != "one two" { t.Errorf("Expected childPersPreArgs %q, got %q", "one two", childPersPreArgs) } @@ -1418,6 +1477,49 @@ func TestPersistentHooks(t *testing.T) { if childPersPostArgs != "one two" { t.Errorf("Expected childPersPostArgs %q, got %q", "one two", childPersPostArgs) } + + // Test On*Run hooks + + if persParentPersPreArgs != "one two" { + t.Errorf("Expected persParentPersPreArgs %q, got %q", "one two", persParentPersPreArgs) + } + if persParentPreArgs != "" { + t.Errorf("Expected persParentPreArgs %q, got %q", "one two", persParentPreArgs) + } + if persParentRunArgs != "" { + t.Errorf("Expected persParentRunArgs %q, got %q", "one two", persParentRunArgs) + } + if persParentPostArgs != "" { + t.Errorf("Expected persParentPostArgs %q, got %q", "one two", persParentPostArgs) + } + if persParentPersPostArgs != "one two" { + t.Errorf("Expected persParentPersPostArgs %q, got %q", "one two", persParentPersPostArgs) + } + + if persChildPersPreArgs != "one two" { + t.Errorf("Expected persChildPersPreArgs %q, got %q", "one two", persChildPersPreArgs) + } + if persChildPreArgs != "one two" { + t.Errorf("Expected persChildPreArgs %q, got %q", "one two", persChildPreArgs) + } + if persChildPreArgs2 != "one two three" { + t.Errorf("Expected persChildPreArgs %q, got %q", "one two three", persChildPreArgs2) + } + if persChildRunArgs != "one two" { + t.Errorf("Expected persChildRunArgs %q, got %q", "one two", persChildRunArgs) + } + if persChildPostArgs != "one two" { + t.Errorf("Expected persChildPostArgs %q, got %q", "one two", persChildPostArgs) + } + if persChildPersPostArgs != "one two" { + t.Errorf("Expected persChildPersPostArgs %q, got %q", "one two", persChildPersPostArgs) + } +} + +func TestPersistentHooksWoOverride(t *testing.T) { + EnablePersistentRunOverride = false + TestPersistentHooks(t) + EnablePersistentRunOverride = true } // Related to https://github.com/spf13/cobra/issues/521.