diff --git a/command.go b/command.go index eca65809..f74237d4 100644 --- a/command.go +++ b/command.go @@ -149,6 +149,12 @@ type Command struct { // flagErrorFunc is func defined by user and it's called when the parsing of // flags returns an error. flagErrorFunc func(*Command, error) error + // requiredFlagErrorFunc is func defined by user and it's called when the parsing of + // required flags returns an error. + requiredFlagErrorFunc func(*Command, error) error + // flagGroupsErrorFunc is func defined by user and it's called when the parsing of + // required flags returns an error. + flagGroupsErrorFunc func(*Command, error) error // helpTemplate is help template defined by user. helpTemplate string // helpFunc is help func defined by user. @@ -283,12 +289,6 @@ func (c *Command) SetUsageTemplate(s string) { c.usageTemplate = s } -// SetFlagErrorFunc sets a function to generate an error when flag parsing -// fails. -func (c *Command) SetFlagErrorFunc(f func(*Command, error) error) { - c.flagErrorFunc = f -} - // SetHelpFunc sets help function. Can be defined by Application. func (c *Command) SetHelpFunc(f func(*Command, []string)) { c.helpFunc = f @@ -444,6 +444,12 @@ func (c *Command) UsageString() string { return bb.String() } +// SetFlagErrorFunc sets a function to generate an error when flag parsing +// fails. +func (c *Command) SetFlagErrorFunc(f func(*Command, error) error) { + c.flagErrorFunc = f +} + // FlagErrorFunc returns either the function set by SetFlagErrorFunc for this // command or a parent, or it returns a function which returns the original // error. @@ -451,7 +457,6 @@ func (c *Command) FlagErrorFunc() (f func(*Command, error) error) { if c.flagErrorFunc != nil { return c.flagErrorFunc } - if c.HasParent() { return c.parent.FlagErrorFunc() } @@ -460,6 +465,48 @@ func (c *Command) FlagErrorFunc() (f func(*Command, error) error) { } } +// SetRequiredFlagsErrorFunc sets a function to generate an error when +// validating of required flags fails. +func (c *Command) SetRequiredFlagsErrorFunc(f func(*Command, error) error) { + c.requiredFlagErrorFunc = f +} + +// RequiredFlagsErrorFunc returns either the function set by +// SetRequiredFlagsErrorFunc for this command or a parent, or it returns a +// function which returns the original error. +func (c *Command) RequiredFlagsErrorFunc() (f func(*Command, error) error) { + if c.requiredFlagErrorFunc != nil { + return c.requiredFlagErrorFunc + } + if c.HasParent() { + return c.parent.RequiredFlagsErrorFunc() + } + return func(c *Command, err error) error { + return err + } +} + +// SetFlagGroupsErrorFunc sets a function to generate an error when validating +// of flag groups fails. +func (c *Command) SetFlagGroupsErrorFunc(f func(*Command, error) error) { + c.flagGroupsErrorFunc = f +} + +// FlagGroupsErrorFunc returns either the function set by +// SetFlagGroupsErrorFunc for this command or a parent, or it returns a +// function which returns the original error. +func (c *Command) FlagGroupsErrorFunc() (f func(*Command, error) error) { + if c.flagGroupsErrorFunc != nil { + return c.flagGroupsErrorFunc + } + if c.HasParent() { + return c.parent.FlagGroupsErrorFunc() + } + return func(c *Command, err error) error { + return err + } +} + var minUsagePadding = 25 // UsagePadding return padding for the usage. @@ -861,10 +908,10 @@ func (c *Command) execute(a []string) (err error) { } if err := c.validateRequiredFlags(); err != nil { - return c.FlagErrorFunc()(c, err) + return c.RequiredFlagsErrorFunc()(c, err) } if err := c.validateFlagGroups(); err != nil { - return c.FlagErrorFunc()(c, err) + return c.FlagGroupsErrorFunc()(c, err) } if c.RunE != nil { diff --git a/command_test.go b/command_test.go index d48fef1a..b6d346c7 100644 --- a/command_test.go +++ b/command_test.go @@ -1723,6 +1723,74 @@ func TestFlagErrorFunc(t *testing.T) { } } +func TestRequiredFlagsErrorFunc(t *testing.T) { + c := &Command{Use: "c", Run: emptyRun} + ff := c.Flags() + ff.Bool("required-flag-1", false, "required flag #1") + ff.Bool("required-flag-2", false, "required flag #1") + if err := c.MarkFlagRequired("required-flag-1"); err != nil { + t.Errorf("unexpected error %v", err) + } + if err := c.MarkFlagRequired("required-flag-2"); err != nil { + t.Errorf("unexpected error %v", err) + } + + expectedFmt := "This is expected: %v" + c.SetRequiredFlagsErrorFunc(func(_ *Command, err error) error { + return fmt.Errorf(expectedFmt, err) + }) + + _, err := executeCommand(c) + + got := err.Error() + expected := fmt.Sprintf(expectedFmt, `required flag(s) "required-flag-1", "required-flag-2" not set`) + if got != expected { + t.Errorf("Expected %v, got %v", expected, got) + } +} + +func TestFlagGroupsErrorFunc_RequiredTogether(t *testing.T) { + c := &Command{Use: "c", Run: emptyRun} + ff := c.Flags() + ff.Bool("required-flag-1", false, "required flag #1") + ff.Bool("required-flag-2", false, "required flag #1") + c.MarkFlagsRequiredTogether("required-flag-1", "required-flag-2") + + expectedFmt := "This is expected: %v" + c.SetFlagGroupsErrorFunc(func(_ *Command, err error) error { + return fmt.Errorf(expectedFmt, err) + }) + + _, err := executeCommand(c, "--required-flag-1") + + got := err.Error() + expected := fmt.Sprintf(expectedFmt, `if any flags in the group [required-flag-1 required-flag-2] are set they must all be set; missing [required-flag-2]`) + if got != expected { + t.Errorf("Expected %v, got %v", expected, got) + } +} + +func TestFlagGroupsErrorFunc_MutuallyExclusive(t *testing.T) { + c := &Command{Use: "c", Run: emptyRun} + ff := c.Flags() + ff.Bool("required-flag-1", false, "required flag #1") + ff.Bool("required-flag-2", false, "required flag #1") + c.MarkFlagsMutuallyExclusive("required-flag-1", "required-flag-2") + + expectedFmt := "This is expected: %v" + c.SetFlagGroupsErrorFunc(func(_ *Command, err error) error { + return fmt.Errorf(expectedFmt, err) + }) + + _, err := executeCommand(c, "--required-flag-1", "--required-flag-2") + + got := err.Error() + expected := fmt.Sprintf(expectedFmt, `if any flags in the group [required-flag-1 required-flag-2] are set none of the others can be; [required-flag-1 required-flag-2] were all set`) + if got != expected { + t.Errorf("Expected %v, got %v", expected, got) + } +} + // TestSortedFlags checks, // if cmd.LocalFlags() is unsorted when cmd.Flags().SortFlags set to false. // Related to https://github.com/spf13/cobra/issues/404.