diff --git a/command.go b/command.go index a7d90886..c1928825 100644 --- a/command.go +++ b/command.go @@ -51,6 +51,16 @@ type Command struct { // Run runs the command. // The args are the arguments after the command name. Run func(cmd *Command, args []string) + // PreRun runs the command after the flags are parsed and before run. + // The args are the arguments after the command name. + PreRun func(cmd *Command, args []string) + // PostRun runs the command after run. + // The args are the arguments after the command name. + PostRun func(cmd *Command, args []string) + // PreRun which children of this command will inherit. + PersistentPreRun func(cmd *Command, args []string) + // PostRun which children of this command will inherit. + PersistentPostRun func(cmd *Command, args []string) // Commands is the list of commands supported by this program. commands []*Command // Parent Command for this command @@ -441,7 +451,23 @@ func (c *Command) execute(a []string) (err error) { c.preRun() argWoFlags := c.Flags().Args() + + if c.PersistentPreRun != nil { + c.PersistentPreRun(c, argWoFlags) + } + if c.PreRun != nil { + c.PreRun(c, argWoFlags) + } + c.Run(c, argWoFlags) + + if c.PostRun != nil { + c.PostRun(c, argWoFlags) + } + if c.PersistentPostRun != nil { + c.PersistentPostRun(c, argWoFlags) + } + return nil } } @@ -528,10 +554,25 @@ func (c *Command) Execute() (err error) { // print the usage c.Usage() } else { - // Only flags left... Call root.Run c.preRun() + + if c.PersistentPreRun != nil { + c.PersistentPreRun(c, argWoFlags) + } + if c.PreRun != nil { + c.PreRun(c, argWoFlags) + } + + // Only flags left... Call root.Run c.Run(c, argWoFlags) err = nil + + if c.PostRun != nil { + c.PostRun(c, argWoFlags) + } + if c.PersistentPostRun != nil { + c.PersistentPostRun(c, argWoFlags) + } } } } @@ -560,7 +601,9 @@ func (c *Command) initHelp() { Short: "Help about any command", Long: `Help provides help for any command in the application. Simply type ` + c.Name() + ` help [path to command] for full details.`, - Run: c.HelpFunc(), + Run: c.HelpFunc(), + PersistentPreRun: func(cmd *Command, args []string) {}, + PersistentPostRun: func(cmd *Command, args []string) {}, } } c.AddCommand(c.helpCommand) @@ -600,6 +643,14 @@ func (c *Command) AddCommand(cmds ...*Command) { c.commandsMaxNameLen = nameLen } c.commands = append(c.commands, x) + + // Pass on peristent pre/post functions to children + if x.PersistentPreRun == nil { + x.PersistentPreRun = c.PersistentPreRun + } + if x.PersistentPostRun == nil { + x.PersistentPostRun = c.PersistentPostRun + } } }