add RequiredFlagsErrorFunc and FlagGroupsErrorFunc

This commit is contained in:
Sergey Vilgelm 2022-06-09 15:03:02 -07:00
parent 5c62c6a057
commit 823517c128
No known key found for this signature in database
GPG key ID: 08D0E2FF778887E6
2 changed files with 124 additions and 9 deletions

View file

@ -149,6 +149,12 @@ type Command struct {
// flagErrorFunc is func defined by user and it's called when the parsing of // flagErrorFunc is func defined by user and it's called when the parsing of
// flags returns an error. // flags returns an error.
flagErrorFunc func(*Command, error) error flagErrorFunc func(*Command, error) error
// requiredFlagErrorFunc is func defined by user and it's called when the parsing of
// required flags returns an error.
requiredFlagErrorFunc func(*Command, error) error
// flagGroupsErrorFunc is func defined by user and it's called when the parsing of
// required flags returns an error.
flagGroupsErrorFunc func(*Command, error) error
// helpTemplate is help template defined by user. // helpTemplate is help template defined by user.
helpTemplate string helpTemplate string
// helpFunc is help func defined by user. // helpFunc is help func defined by user.
@ -283,12 +289,6 @@ func (c *Command) SetUsageTemplate(s string) {
c.usageTemplate = s c.usageTemplate = s
} }
// SetFlagErrorFunc sets a function to generate an error when flag parsing
// fails.
func (c *Command) SetFlagErrorFunc(f func(*Command, error) error) {
c.flagErrorFunc = f
}
// SetHelpFunc sets help function. Can be defined by Application. // SetHelpFunc sets help function. Can be defined by Application.
func (c *Command) SetHelpFunc(f func(*Command, []string)) { func (c *Command) SetHelpFunc(f func(*Command, []string)) {
c.helpFunc = f c.helpFunc = f
@ -444,6 +444,12 @@ func (c *Command) UsageString() string {
return bb.String() return bb.String()
} }
// SetFlagErrorFunc sets a function to generate an error when flag parsing
// fails.
func (c *Command) SetFlagErrorFunc(f func(*Command, error) error) {
c.flagErrorFunc = f
}
// FlagErrorFunc returns either the function set by SetFlagErrorFunc for this // FlagErrorFunc returns either the function set by SetFlagErrorFunc for this
// command or a parent, or it returns a function which returns the original // command or a parent, or it returns a function which returns the original
// error. // error.
@ -451,7 +457,6 @@ func (c *Command) FlagErrorFunc() (f func(*Command, error) error) {
if c.flagErrorFunc != nil { if c.flagErrorFunc != nil {
return c.flagErrorFunc return c.flagErrorFunc
} }
if c.HasParent() { if c.HasParent() {
return c.parent.FlagErrorFunc() return c.parent.FlagErrorFunc()
} }
@ -460,6 +465,48 @@ func (c *Command) FlagErrorFunc() (f func(*Command, error) error) {
} }
} }
// SetRequiredFlagsErrorFunc sets a function to generate an error when
// validating of required flags fails.
func (c *Command) SetRequiredFlagsErrorFunc(f func(*Command, error) error) {
c.requiredFlagErrorFunc = f
}
// RequiredFlagsErrorFunc returns either the function set by
// SetRequiredFlagsErrorFunc for this command or a parent, or it returns a
// function which returns the original error.
func (c *Command) RequiredFlagsErrorFunc() (f func(*Command, error) error) {
if c.requiredFlagErrorFunc != nil {
return c.requiredFlagErrorFunc
}
if c.HasParent() {
return c.parent.RequiredFlagsErrorFunc()
}
return func(c *Command, err error) error {
return err
}
}
// SetFlagGroupsErrorFunc sets a function to generate an error when validating
// of flag groups fails.
func (c *Command) SetFlagGroupsErrorFunc(f func(*Command, error) error) {
c.flagGroupsErrorFunc = f
}
// FlagGroupsErrorFunc returns either the function set by
// SetFlagGroupsErrorFunc for this command or a parent, or it returns a
// function which returns the original error.
func (c *Command) FlagGroupsErrorFunc() (f func(*Command, error) error) {
if c.flagGroupsErrorFunc != nil {
return c.flagGroupsErrorFunc
}
if c.HasParent() {
return c.parent.FlagGroupsErrorFunc()
}
return func(c *Command, err error) error {
return err
}
}
var minUsagePadding = 25 var minUsagePadding = 25
// UsagePadding return padding for the usage. // UsagePadding return padding for the usage.
@ -861,10 +908,10 @@ func (c *Command) execute(a []string) (err error) {
} }
if err := c.validateRequiredFlags(); err != nil { if err := c.validateRequiredFlags(); err != nil {
return c.FlagErrorFunc()(c, err) return c.RequiredFlagsErrorFunc()(c, err)
} }
if err := c.validateFlagGroups(); err != nil { if err := c.validateFlagGroups(); err != nil {
return c.FlagErrorFunc()(c, err) return c.FlagGroupsErrorFunc()(c, err)
} }
if c.RunE != nil { if c.RunE != nil {

View file

@ -1723,6 +1723,74 @@ func TestFlagErrorFunc(t *testing.T) {
} }
} }
func TestRequiredFlagsErrorFunc(t *testing.T) {
c := &Command{Use: "c", Run: emptyRun}
ff := c.Flags()
ff.Bool("required-flag-1", false, "required flag #1")
ff.Bool("required-flag-2", false, "required flag #1")
if err := c.MarkFlagRequired("required-flag-1"); err != nil {
t.Errorf("unexpected error %v", err)
}
if err := c.MarkFlagRequired("required-flag-2"); err != nil {
t.Errorf("unexpected error %v", err)
}
expectedFmt := "This is expected: %v"
c.SetRequiredFlagsErrorFunc(func(_ *Command, err error) error {
return fmt.Errorf(expectedFmt, err)
})
_, err := executeCommand(c)
got := err.Error()
expected := fmt.Sprintf(expectedFmt, `required flag(s) "required-flag-1", "required-flag-2" not set`)
if got != expected {
t.Errorf("Expected %v, got %v", expected, got)
}
}
func TestFlagGroupsErrorFunc_RequiredTogether(t *testing.T) {
c := &Command{Use: "c", Run: emptyRun}
ff := c.Flags()
ff.Bool("required-flag-1", false, "required flag #1")
ff.Bool("required-flag-2", false, "required flag #1")
c.MarkFlagsRequiredTogether("required-flag-1", "required-flag-2")
expectedFmt := "This is expected: %v"
c.SetFlagGroupsErrorFunc(func(_ *Command, err error) error {
return fmt.Errorf(expectedFmt, err)
})
_, err := executeCommand(c, "--required-flag-1")
got := err.Error()
expected := fmt.Sprintf(expectedFmt, `if any flags in the group [required-flag-1 required-flag-2] are set they must all be set; missing [required-flag-2]`)
if got != expected {
t.Errorf("Expected %v, got %v", expected, got)
}
}
func TestFlagGroupsErrorFunc_MutuallyExclusive(t *testing.T) {
c := &Command{Use: "c", Run: emptyRun}
ff := c.Flags()
ff.Bool("required-flag-1", false, "required flag #1")
ff.Bool("required-flag-2", false, "required flag #1")
c.MarkFlagsMutuallyExclusive("required-flag-1", "required-flag-2")
expectedFmt := "This is expected: %v"
c.SetFlagGroupsErrorFunc(func(_ *Command, err error) error {
return fmt.Errorf(expectedFmt, err)
})
_, err := executeCommand(c, "--required-flag-1", "--required-flag-2")
got := err.Error()
expected := fmt.Sprintf(expectedFmt, `if any flags in the group [required-flag-1 required-flag-2] are set none of the others can be; [required-flag-1 required-flag-2] were all set`)
if got != expected {
t.Errorf("Expected %v, got %v", expected, got)
}
}
// TestSortedFlags checks, // TestSortedFlags checks,
// if cmd.LocalFlags() is unsorted when cmd.Flags().SortFlags set to false. // if cmd.LocalFlags() is unsorted when cmd.Flags().SortFlags set to false.
// Related to https://github.com/spf13/cobra/issues/404. // Related to https://github.com/spf13/cobra/issues/404.