diff --git a/command.go b/command.go index 64f1d5f4..bc249c27 100644 --- a/command.go +++ b/command.go @@ -35,6 +35,18 @@ const FlagSetByCobraAnnotation = "cobra_annotation_flag_set_by_cobra" // FParseErrWhitelist configures Flag parse errors to be ignored type FParseErrWhitelist flag.ParseErrorsWhitelist +// FErrorHandling defines how to handle flag parsing errors +type FErrorHandling flag.ErrorHandling + +const ( + // ContinueOnError will return an err from Parse() if an error is found + ContinueOnError FErrorHandling = iota + // ExitOnError will call os.Exit(2) if an error is found when parsing + ExitOnError + // PanicOnError will panic() if an error is found when parsing flags + PanicOnError +) + // Command is just that, a command for your application. // E.g. 'go run ...' - 'run' is the command. Cobra requires // you to define the usage and description as part of your command @@ -226,6 +238,9 @@ type Command struct { // SuggestionsMinimumDistance defines minimum levenshtein distance to display suggestions. // Must be > 0. SuggestionsMinimumDistance int + + // FlagErrorHandling defines how to handle flag parsing errors. Defaults to flag.ContinueOnError. + FlagErrorHandling FErrorHandling } // Context returns underlying command context. If command was executed @@ -1480,7 +1495,7 @@ func (c *Command) GlobalNormalizationFunc() func(f *flag.FlagSet, name string) f // to this command (local and persistent declared here and by all parents). func (c *Command) Flags() *flag.FlagSet { if c.flags == nil { - c.flags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) + c.flags = flag.NewFlagSet(c.Name(), (flag.ErrorHandling)(c.FlagErrorHandling)) if c.flagErrorBuf == nil { c.flagErrorBuf = new(bytes.Buffer) } @@ -1494,7 +1509,7 @@ func (c *Command) Flags() *flag.FlagSet { func (c *Command) LocalNonPersistentFlags() *flag.FlagSet { persistentFlags := c.PersistentFlags() - out := flag.NewFlagSet(c.Name(), flag.ContinueOnError) + out := flag.NewFlagSet(c.Name(), (flag.ErrorHandling)(c.FlagErrorHandling)) c.LocalFlags().VisitAll(func(f *flag.Flag) { if persistentFlags.Lookup(f.Name) == nil { out.AddFlag(f) @@ -1508,7 +1523,7 @@ func (c *Command) LocalFlags() *flag.FlagSet { c.mergePersistentFlags() if c.lflags == nil { - c.lflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) + c.lflags = flag.NewFlagSet(c.Name(), (flag.ErrorHandling)(c.FlagErrorHandling)) if c.flagErrorBuf == nil { c.flagErrorBuf = new(bytes.Buffer) } @@ -1535,7 +1550,7 @@ func (c *Command) InheritedFlags() *flag.FlagSet { c.mergePersistentFlags() if c.iflags == nil { - c.iflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) + c.iflags = flag.NewFlagSet(c.Name(), (flag.ErrorHandling)(c.FlagErrorHandling)) if c.flagErrorBuf == nil { c.flagErrorBuf = new(bytes.Buffer) } @@ -1563,7 +1578,7 @@ func (c *Command) NonInheritedFlags() *flag.FlagSet { // PersistentFlags returns the persistent FlagSet specifically set in the current command. func (c *Command) PersistentFlags() *flag.FlagSet { if c.pflags == nil { - c.pflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) + c.pflags = flag.NewFlagSet(c.Name(), (flag.ErrorHandling)(c.FlagErrorHandling)) if c.flagErrorBuf == nil { c.flagErrorBuf = new(bytes.Buffer) } @@ -1576,9 +1591,9 @@ func (c *Command) PersistentFlags() *flag.FlagSet { func (c *Command) ResetFlags() { c.flagErrorBuf = new(bytes.Buffer) c.flagErrorBuf.Reset() - c.flags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) + c.flags = flag.NewFlagSet(c.Name(), (flag.ErrorHandling)(c.FlagErrorHandling)) c.flags.SetOutput(c.flagErrorBuf) - c.pflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) + c.pflags = flag.NewFlagSet(c.Name(), (flag.ErrorHandling)(c.FlagErrorHandling)) c.pflags.SetOutput(c.flagErrorBuf) c.lflags = nil @@ -1695,7 +1710,7 @@ func (c *Command) mergePersistentFlags() { // If c.parentsPflags == nil, it makes new. func (c *Command) updateParentsPflags() { if c.parentsPflags == nil { - c.parentsPflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError) + c.parentsPflags = flag.NewFlagSet(c.Name(), (flag.ErrorHandling)(c.FlagErrorHandling)) c.parentsPflags.SetOutput(c.flagErrorBuf) c.parentsPflags.SortFlags = false }