mirror of
https://github.com/spf13/cobra
synced 2025-05-05 04:47:22 +00:00
add RequiredFlagsErrorFunc and FlagGroupsErrorFunc
This commit is contained in:
parent
5c62c6a057
commit
823517c128
2 changed files with 124 additions and 9 deletions
65
command.go
65
command.go
|
@ -149,6 +149,12 @@ type Command struct {
|
|||
// flagErrorFunc is func defined by user and it's called when the parsing of
|
||||
// flags returns an error.
|
||||
flagErrorFunc func(*Command, error) error
|
||||
// requiredFlagErrorFunc is func defined by user and it's called when the parsing of
|
||||
// required flags returns an error.
|
||||
requiredFlagErrorFunc func(*Command, error) error
|
||||
// flagGroupsErrorFunc is func defined by user and it's called when the parsing of
|
||||
// required flags returns an error.
|
||||
flagGroupsErrorFunc func(*Command, error) error
|
||||
// helpTemplate is help template defined by user.
|
||||
helpTemplate string
|
||||
// helpFunc is help func defined by user.
|
||||
|
@ -283,12 +289,6 @@ func (c *Command) SetUsageTemplate(s string) {
|
|||
c.usageTemplate = s
|
||||
}
|
||||
|
||||
// SetFlagErrorFunc sets a function to generate an error when flag parsing
|
||||
// fails.
|
||||
func (c *Command) SetFlagErrorFunc(f func(*Command, error) error) {
|
||||
c.flagErrorFunc = f
|
||||
}
|
||||
|
||||
// SetHelpFunc sets help function. Can be defined by Application.
|
||||
func (c *Command) SetHelpFunc(f func(*Command, []string)) {
|
||||
c.helpFunc = f
|
||||
|
@ -444,6 +444,12 @@ func (c *Command) UsageString() string {
|
|||
return bb.String()
|
||||
}
|
||||
|
||||
// SetFlagErrorFunc sets a function to generate an error when flag parsing
|
||||
// fails.
|
||||
func (c *Command) SetFlagErrorFunc(f func(*Command, error) error) {
|
||||
c.flagErrorFunc = f
|
||||
}
|
||||
|
||||
// FlagErrorFunc returns either the function set by SetFlagErrorFunc for this
|
||||
// command or a parent, or it returns a function which returns the original
|
||||
// error.
|
||||
|
@ -451,7 +457,6 @@ func (c *Command) FlagErrorFunc() (f func(*Command, error) error) {
|
|||
if c.flagErrorFunc != nil {
|
||||
return c.flagErrorFunc
|
||||
}
|
||||
|
||||
if c.HasParent() {
|
||||
return c.parent.FlagErrorFunc()
|
||||
}
|
||||
|
@ -460,6 +465,48 @@ func (c *Command) FlagErrorFunc() (f func(*Command, error) error) {
|
|||
}
|
||||
}
|
||||
|
||||
// SetRequiredFlagsErrorFunc sets a function to generate an error when
|
||||
// validating of required flags fails.
|
||||
func (c *Command) SetRequiredFlagsErrorFunc(f func(*Command, error) error) {
|
||||
c.requiredFlagErrorFunc = f
|
||||
}
|
||||
|
||||
// RequiredFlagsErrorFunc returns either the function set by
|
||||
// SetRequiredFlagsErrorFunc for this command or a parent, or it returns a
|
||||
// function which returns the original error.
|
||||
func (c *Command) RequiredFlagsErrorFunc() (f func(*Command, error) error) {
|
||||
if c.requiredFlagErrorFunc != nil {
|
||||
return c.requiredFlagErrorFunc
|
||||
}
|
||||
if c.HasParent() {
|
||||
return c.parent.RequiredFlagsErrorFunc()
|
||||
}
|
||||
return func(c *Command, err error) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// SetFlagGroupsErrorFunc sets a function to generate an error when validating
|
||||
// of flag groups fails.
|
||||
func (c *Command) SetFlagGroupsErrorFunc(f func(*Command, error) error) {
|
||||
c.flagGroupsErrorFunc = f
|
||||
}
|
||||
|
||||
// FlagGroupsErrorFunc returns either the function set by
|
||||
// SetFlagGroupsErrorFunc for this command or a parent, or it returns a
|
||||
// function which returns the original error.
|
||||
func (c *Command) FlagGroupsErrorFunc() (f func(*Command, error) error) {
|
||||
if c.flagGroupsErrorFunc != nil {
|
||||
return c.flagGroupsErrorFunc
|
||||
}
|
||||
if c.HasParent() {
|
||||
return c.parent.FlagGroupsErrorFunc()
|
||||
}
|
||||
return func(c *Command, err error) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var minUsagePadding = 25
|
||||
|
||||
// UsagePadding return padding for the usage.
|
||||
|
@ -861,10 +908,10 @@ func (c *Command) execute(a []string) (err error) {
|
|||
}
|
||||
|
||||
if err := c.validateRequiredFlags(); err != nil {
|
||||
return c.FlagErrorFunc()(c, err)
|
||||
return c.RequiredFlagsErrorFunc()(c, err)
|
||||
}
|
||||
if err := c.validateFlagGroups(); err != nil {
|
||||
return c.FlagErrorFunc()(c, err)
|
||||
return c.FlagGroupsErrorFunc()(c, err)
|
||||
}
|
||||
|
||||
if c.RunE != nil {
|
||||
|
|
|
@ -1723,6 +1723,74 @@ func TestFlagErrorFunc(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestRequiredFlagsErrorFunc(t *testing.T) {
|
||||
c := &Command{Use: "c", Run: emptyRun}
|
||||
ff := c.Flags()
|
||||
ff.Bool("required-flag-1", false, "required flag #1")
|
||||
ff.Bool("required-flag-2", false, "required flag #1")
|
||||
if err := c.MarkFlagRequired("required-flag-1"); err != nil {
|
||||
t.Errorf("unexpected error %v", err)
|
||||
}
|
||||
if err := c.MarkFlagRequired("required-flag-2"); err != nil {
|
||||
t.Errorf("unexpected error %v", err)
|
||||
}
|
||||
|
||||
expectedFmt := "This is expected: %v"
|
||||
c.SetRequiredFlagsErrorFunc(func(_ *Command, err error) error {
|
||||
return fmt.Errorf(expectedFmt, err)
|
||||
})
|
||||
|
||||
_, err := executeCommand(c)
|
||||
|
||||
got := err.Error()
|
||||
expected := fmt.Sprintf(expectedFmt, `required flag(s) "required-flag-1", "required-flag-2" not set`)
|
||||
if got != expected {
|
||||
t.Errorf("Expected %v, got %v", expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlagGroupsErrorFunc_RequiredTogether(t *testing.T) {
|
||||
c := &Command{Use: "c", Run: emptyRun}
|
||||
ff := c.Flags()
|
||||
ff.Bool("required-flag-1", false, "required flag #1")
|
||||
ff.Bool("required-flag-2", false, "required flag #1")
|
||||
c.MarkFlagsRequiredTogether("required-flag-1", "required-flag-2")
|
||||
|
||||
expectedFmt := "This is expected: %v"
|
||||
c.SetFlagGroupsErrorFunc(func(_ *Command, err error) error {
|
||||
return fmt.Errorf(expectedFmt, err)
|
||||
})
|
||||
|
||||
_, err := executeCommand(c, "--required-flag-1")
|
||||
|
||||
got := err.Error()
|
||||
expected := fmt.Sprintf(expectedFmt, `if any flags in the group [required-flag-1 required-flag-2] are set they must all be set; missing [required-flag-2]`)
|
||||
if got != expected {
|
||||
t.Errorf("Expected %v, got %v", expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlagGroupsErrorFunc_MutuallyExclusive(t *testing.T) {
|
||||
c := &Command{Use: "c", Run: emptyRun}
|
||||
ff := c.Flags()
|
||||
ff.Bool("required-flag-1", false, "required flag #1")
|
||||
ff.Bool("required-flag-2", false, "required flag #1")
|
||||
c.MarkFlagsMutuallyExclusive("required-flag-1", "required-flag-2")
|
||||
|
||||
expectedFmt := "This is expected: %v"
|
||||
c.SetFlagGroupsErrorFunc(func(_ *Command, err error) error {
|
||||
return fmt.Errorf(expectedFmt, err)
|
||||
})
|
||||
|
||||
_, err := executeCommand(c, "--required-flag-1", "--required-flag-2")
|
||||
|
||||
got := err.Error()
|
||||
expected := fmt.Sprintf(expectedFmt, `if any flags in the group [required-flag-1 required-flag-2] are set none of the others can be; [required-flag-1 required-flag-2] were all set`)
|
||||
if got != expected {
|
||||
t.Errorf("Expected %v, got %v", expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSortedFlags checks,
|
||||
// if cmd.LocalFlags() is unsorted when cmd.Flags().SortFlags set to false.
|
||||
// Related to https://github.com/spf13/cobra/issues/404.
|
||||
|
|
Loading…
Add table
Reference in a new issue