Extend Persistent*Run behavior to allow multiple hooks throughout the execution chain

This commit is contained in:
Bart de Boer 2020-06-22 17:13:49 +02:00
parent 04318720db
commit 245d68f8ea
3 changed files with 254 additions and 49 deletions

View file

@ -39,6 +39,9 @@ var templateFuncs = template.FuncMap{
var initializers []func()
// EnablePersistentRunOverride ensures Persistent*Run* functions in childs override their parents
var EnablePersistentRunOverride = true
// EnablePrefixMatching allows to set automatic prefix matching. Automatic prefix matching can be a dangerous thing
// to automatically enable in CLI tools.
// Set this to true to enable it.

View file

@ -118,6 +118,17 @@ type Command struct {
// PersistentPostRunE: PersistentPostRun but returns an error.
PersistentPostRunE func(cmd *Command, args []string) error
// persistentPreRunHooks are executed before the command or one of its children are executed
persistentPreRunHooks []func(cmd *Command, args []string) error
// preRunHooks are executed before the command is executed
preRunHooks []func(cmd *Command, args []string) error
// runHooks are executed when the command is executed
runHooks []func(cmd *Command, args []string) error
// postRunHooks are executed after the command has executed
postRunHooks []func(cmd *Command, args []string) error
// persistentPostRunHooks are executed after the command or one of its children have executed
persistentPostRunHooks []func(cmd *Command, args []string) error
// SilenceErrors is an option to quiet errors down stream.
SilenceErrors bool
@ -816,52 +827,104 @@ func (c *Command) execute(a []string) (err error) {
return err
}
for p := c; p != nil; p = p.Parent() {
if p.PersistentPreRunE != nil {
if err := p.PersistentPreRunE(c, argWoFlags); err != nil {
return err
}
break
} else if p.PersistentPreRun != nil {
p.PersistentPreRun(c, argWoFlags)
break
}
}
var persistentPreRunHooks []func(cmd *Command, args []string) error
preRunHooks := c.preRunHooks
runHooks := c.runHooks
postRunHooks := c.postRunHooks
var persistentPostRunHooks []func(cmd *Command, args []string) error
// Merge the PreRun functions into the preRunHooks slice
if c.PreRunE != nil {
if err := c.PreRunE(c, argWoFlags); err != nil {
return err
}
preRunHooks = append(preRunHooks, c.PreRunE)
} else if c.PreRun != nil {
c.PreRun(c, argWoFlags)
preRunHook := c.PreRun
preRunHooks = append(preRunHooks, func(cmd *Command, args []string) error {
preRunHook(cmd, args)
return nil
})
}
// Merge the Run functions into the runHooks slice
if c.RunE != nil {
runHooks = append(runHooks, c.RunE)
} else if c.Run != nil {
runHook := c.Run
runHooks = append(runHooks, func(cmd *Command, args []string) error {
runHook(cmd, args)
return nil
})
}
// Merge the PostRun functions into the runHooks slice
if c.PostRunE != nil {
postRunHooks = append(postRunHooks, c.PostRunE)
} else if c.PostRun != nil {
postRunHook := c.PostRun
postRunHooks = append(postRunHooks, func(cmd *Command, args []string) error {
postRunHook(cmd, args)
return nil
})
}
// Find and merge the Persistent*Run functions into the persistent*Run slices.
// If EnablePersistentRunOverride is set Persistent*Run from childs will override their parents.
// Any hooks registered through OnPersistent*Run will always be executed and cannot be overriden.
hasLegacyPersistentPreRun := false
hasLegacyPersistentPostRun := false
for p := c; p != nil; p = p.Parent() {
if !hasLegacyPersistentPreRun || !EnablePersistentRunOverride {
if p.PersistentPreRunE != nil {
persistentPreRunHooks = append([]func(cmd *Command, args []string) error{
p.PersistentPreRunE,
}, persistentPreRunHooks...)
hasLegacyPersistentPreRun = true
} else if p.PersistentPreRun != nil {
persistentPreRunHook := p.PersistentPreRun
persistentPreRunHooks = append([]func(cmd *Command, args []string) error{
func(cmd *Command, args []string) error {
persistentPreRunHook(cmd, args)
return nil
},
}, persistentPreRunHooks...)
hasLegacyPersistentPreRun = true
}
}
if !hasLegacyPersistentPostRun || !EnablePersistentRunOverride {
if p.PersistentPostRunE != nil {
persistentPostRunHooks = append(persistentPostRunHooks, p.PersistentPostRunE)
hasLegacyPersistentPostRun = true
} else if p.PersistentPostRun != nil {
persistentPostRunHook := p.PersistentPostRun
persistentPostRunHooks = append(persistentPostRunHooks, func(cmd *Command, args []string) error {
persistentPostRunHook(cmd, args)
return nil
})
hasLegacyPersistentPostRun = true
}
}
persistentPreRunHooks = append(p.persistentPreRunHooks, persistentPreRunHooks...)
persistentPostRunHooks = append(persistentPostRunHooks, p.persistentPostRunHooks...)
}
// Execute the hooks:
if err := c.executeHooks(&persistentPreRunHooks, argWoFlags); err != nil {
return err
}
if err := c.executeHooks(&preRunHooks, argWoFlags); err != nil {
return err
}
if err := c.validateRequiredFlags(); err != nil {
return err
}
if c.RunE != nil {
if err := c.RunE(c, argWoFlags); err != nil {
return err
}
} else {
c.Run(c, argWoFlags)
if err := c.executeHooks(&runHooks, argWoFlags); err != nil {
return err
}
if c.PostRunE != nil {
if err := c.PostRunE(c, argWoFlags); err != nil {
return err
}
} else if c.PostRun != nil {
c.PostRun(c, argWoFlags)
if err := c.executeHooks(&postRunHooks, argWoFlags); err != nil {
return err
}
for p := c; p != nil; p = p.Parent() {
if p.PersistentPostRunE != nil {
if err := p.PersistentPostRunE(c, argWoFlags); err != nil {
return err
}
break
} else if p.PersistentPostRun != nil {
p.PersistentPostRun(c, argWoFlags)
break
}
if err := c.executeHooks(&persistentPostRunHooks, argWoFlags); err != nil {
return err
}
return nil
@ -873,6 +936,43 @@ func (c *Command) preRun() {
}
}
// executeHooks executes a slice of hooks
func (c *Command) executeHooks(hooks *[]func(cmd *Command, args []string) error, args []string) error {
for _, x := range *hooks {
if err := x(c, args); err != nil {
return err
}
}
return nil
}
// OnPersistentPreRun registers one or more hooks on the command to be executed
// before the command or one of its children are executed
func (c *Command) OnPersistentPreRun(f ...func(cmd *Command, args []string) error) {
c.persistentPreRunHooks = append(c.persistentPreRunHooks, f...)
}
// OnPreRun registers one or more hooks on the command to be executed before the command is executed
func (c *Command) OnPreRun(f ...func(cmd *Command, args []string) error) {
c.preRunHooks = append(c.preRunHooks, f...)
}
// OnRun registers one or more hooks on the command to be executed when the command is executed
func (c *Command) OnRun(f ...func(cmd *Command, args []string) error) {
c.runHooks = append(c.runHooks, f...)
}
// OnPostRun registers one or more hooks on the command to be executed after the command has executed
func (c *Command) OnPostRun(f ...func(cmd *Command, args []string) error) {
c.postRunHooks = append(c.postRunHooks, f...)
}
// OnPersistentPostRun register one or more hooks on the command to be executed
// after the command or one of its children have executed
func (c *Command) OnPersistentPostRun(f ...func(cmd *Command, args []string) error) {
c.persistentPostRunHooks = append(c.persistentPostRunHooks, f...)
}
// ExecuteContext is the same as Execute(), but sets the ctx on the command.
// Retrieve ctx by calling cmd.Context() inside your *Run lifecycle functions.
func (c *Command) ExecuteContext(ctx context.Context) error {

View file

@ -1332,6 +1332,23 @@ func TestPersistentHooks(t *testing.T) {
childPersPostArgs string
)
var (
persParentPersPreArgs string
persParentPreArgs string
persParentRunArgs string
persParentPostArgs string
persParentPersPostArgs string
)
var (
persChildPersPreArgs string
persChildPreArgs string
persChildPreArgs2 string
persChildRunArgs string
persChildPostArgs string
persChildPersPostArgs string
)
parentCmd := &Command{
Use: "parent",
PersistentPreRun: func(_ *Command, args []string) {
@ -1371,6 +1388,52 @@ func TestPersistentHooks(t *testing.T) {
}
parentCmd.AddCommand(childCmd)
parentCmd.OnPersistentPreRun(func(_ *Command, args []string) error {
persParentPersPreArgs = strings.Join(args, " ")
return nil
})
parentCmd.OnPreRun(func(_ *Command, args []string) error {
persParentPreArgs = strings.Join(args, " ")
return nil
})
parentCmd.OnRun(func(_ *Command, args []string) error {
persParentRunArgs = strings.Join(args, " ")
return nil
})
parentCmd.OnPostRun(func(_ *Command, args []string) error {
persParentPostArgs = strings.Join(args, " ")
return nil
})
parentCmd.OnPersistentPostRun(func(_ *Command, args []string) error {
persParentPersPostArgs = strings.Join(args, " ")
return nil
})
childCmd.OnPersistentPreRun(func(_ *Command, args []string) error {
persChildPersPreArgs = strings.Join(args, " ")
return nil
})
childCmd.OnPreRun(func(_ *Command, args []string) error {
persChildPreArgs = strings.Join(args, " ")
return nil
})
childCmd.OnPreRun(func(_ *Command, args []string) error {
persChildPreArgs2 = strings.Join(args, " ") + " three"
return nil
})
childCmd.OnRun(func(_ *Command, args []string) error {
persChildRunArgs = strings.Join(args, " ")
return nil
})
childCmd.OnPostRun(func(_ *Command, args []string) error {
persChildPostArgs = strings.Join(args, " ")
return nil
})
childCmd.OnPersistentPostRun(func(_ *Command, args []string) error {
persChildPersPostArgs = strings.Join(args, " ")
return nil
})
output, err := executeCommand(parentCmd, "child", "one", "two")
if output != "" {
t.Errorf("Unexpected output: %v", output)
@ -1378,14 +1441,12 @@ func TestPersistentHooks(t *testing.T) {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
// TODO: currently PersistenPreRun* defined in parent does not
// run if the matchin child subcommand has PersistenPreRun.
// If the behavior changes (https://github.com/spf13/cobra/issues/252)
// this test must be fixed.
if parentPersPreArgs != "" {
if EnablePersistentRunOverride && parentPersPreArgs != "" {
t.Errorf("Expected blank parentPersPreArgs, got %q", parentPersPreArgs)
}
if !EnablePersistentRunOverride && parentPersPreArgs != "one two" {
t.Errorf("Expected parentPersPreArgs %q, got %q", "one two", parentPersPreArgs)
}
if parentPreArgs != "" {
t.Errorf("Expected blank parentPreArgs, got %q", parentPreArgs)
}
@ -1395,14 +1456,12 @@ func TestPersistentHooks(t *testing.T) {
if parentPostArgs != "" {
t.Errorf("Expected blank parentPostArgs, got %q", parentPostArgs)
}
// TODO: currently PersistenPostRun* defined in parent does not
// run if the matchin child subcommand has PersistenPostRun.
// If the behavior changes (https://github.com/spf13/cobra/issues/252)
// this test must be fixed.
if parentPersPostArgs != "" {
if EnablePersistentRunOverride && parentPersPostArgs != "" {
t.Errorf("Expected blank parentPersPostArgs, got %q", parentPersPostArgs)
}
if !EnablePersistentRunOverride && parentPersPostArgs != "one two" {
t.Errorf("Expected parentPersPostArgs %q, got %q", "one two", parentPersPostArgs)
}
if childPersPreArgs != "one two" {
t.Errorf("Expected childPersPreArgs %q, got %q", "one two", childPersPreArgs)
}
@ -1418,6 +1477,49 @@ func TestPersistentHooks(t *testing.T) {
if childPersPostArgs != "one two" {
t.Errorf("Expected childPersPostArgs %q, got %q", "one two", childPersPostArgs)
}
// Test On*Run hooks
if persParentPersPreArgs != "one two" {
t.Errorf("Expected persParentPersPreArgs %q, got %q", "one two", persParentPersPreArgs)
}
if persParentPreArgs != "" {
t.Errorf("Expected persParentPreArgs %q, got %q", "one two", persParentPreArgs)
}
if persParentRunArgs != "" {
t.Errorf("Expected persParentRunArgs %q, got %q", "one two", persParentRunArgs)
}
if persParentPostArgs != "" {
t.Errorf("Expected persParentPostArgs %q, got %q", "one two", persParentPostArgs)
}
if persParentPersPostArgs != "one two" {
t.Errorf("Expected persParentPersPostArgs %q, got %q", "one two", persParentPersPostArgs)
}
if persChildPersPreArgs != "one two" {
t.Errorf("Expected persChildPersPreArgs %q, got %q", "one two", persChildPersPreArgs)
}
if persChildPreArgs != "one two" {
t.Errorf("Expected persChildPreArgs %q, got %q", "one two", persChildPreArgs)
}
if persChildPreArgs2 != "one two three" {
t.Errorf("Expected persChildPreArgs %q, got %q", "one two three", persChildPreArgs2)
}
if persChildRunArgs != "one two" {
t.Errorf("Expected persChildRunArgs %q, got %q", "one two", persChildRunArgs)
}
if persChildPostArgs != "one two" {
t.Errorf("Expected persChildPostArgs %q, got %q", "one two", persChildPostArgs)
}
if persChildPersPostArgs != "one two" {
t.Errorf("Expected persChildPersPostArgs %q, got %q", "one two", persChildPersPostArgs)
}
}
func TestPersistentHooksWoOverride(t *testing.T) {
EnablePersistentRunOverride = false
TestPersistentHooks(t)
EnablePersistentRunOverride = true
}
// Related to https://github.com/spf13/cobra/issues/521.