mirror of
https://github.com/spf13/cobra
synced 2025-05-05 04:47:22 +00:00
add RequiredFlagsErrorFunc and FlagGroupsErrorFunc
This commit is contained in:
parent
5c62c6a057
commit
823517c128
2 changed files with 124 additions and 9 deletions
65
command.go
65
command.go
|
@ -149,6 +149,12 @@ type Command struct {
|
||||||
// flagErrorFunc is func defined by user and it's called when the parsing of
|
// flagErrorFunc is func defined by user and it's called when the parsing of
|
||||||
// flags returns an error.
|
// flags returns an error.
|
||||||
flagErrorFunc func(*Command, error) error
|
flagErrorFunc func(*Command, error) error
|
||||||
|
// requiredFlagErrorFunc is func defined by user and it's called when the parsing of
|
||||||
|
// required flags returns an error.
|
||||||
|
requiredFlagErrorFunc func(*Command, error) error
|
||||||
|
// flagGroupsErrorFunc is func defined by user and it's called when the parsing of
|
||||||
|
// required flags returns an error.
|
||||||
|
flagGroupsErrorFunc func(*Command, error) error
|
||||||
// helpTemplate is help template defined by user.
|
// helpTemplate is help template defined by user.
|
||||||
helpTemplate string
|
helpTemplate string
|
||||||
// helpFunc is help func defined by user.
|
// helpFunc is help func defined by user.
|
||||||
|
@ -283,12 +289,6 @@ func (c *Command) SetUsageTemplate(s string) {
|
||||||
c.usageTemplate = s
|
c.usageTemplate = s
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetFlagErrorFunc sets a function to generate an error when flag parsing
|
|
||||||
// fails.
|
|
||||||
func (c *Command) SetFlagErrorFunc(f func(*Command, error) error) {
|
|
||||||
c.flagErrorFunc = f
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetHelpFunc sets help function. Can be defined by Application.
|
// SetHelpFunc sets help function. Can be defined by Application.
|
||||||
func (c *Command) SetHelpFunc(f func(*Command, []string)) {
|
func (c *Command) SetHelpFunc(f func(*Command, []string)) {
|
||||||
c.helpFunc = f
|
c.helpFunc = f
|
||||||
|
@ -444,6 +444,12 @@ func (c *Command) UsageString() string {
|
||||||
return bb.String()
|
return bb.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetFlagErrorFunc sets a function to generate an error when flag parsing
|
||||||
|
// fails.
|
||||||
|
func (c *Command) SetFlagErrorFunc(f func(*Command, error) error) {
|
||||||
|
c.flagErrorFunc = f
|
||||||
|
}
|
||||||
|
|
||||||
// FlagErrorFunc returns either the function set by SetFlagErrorFunc for this
|
// FlagErrorFunc returns either the function set by SetFlagErrorFunc for this
|
||||||
// command or a parent, or it returns a function which returns the original
|
// command or a parent, or it returns a function which returns the original
|
||||||
// error.
|
// error.
|
||||||
|
@ -451,7 +457,6 @@ func (c *Command) FlagErrorFunc() (f func(*Command, error) error) {
|
||||||
if c.flagErrorFunc != nil {
|
if c.flagErrorFunc != nil {
|
||||||
return c.flagErrorFunc
|
return c.flagErrorFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.HasParent() {
|
if c.HasParent() {
|
||||||
return c.parent.FlagErrorFunc()
|
return c.parent.FlagErrorFunc()
|
||||||
}
|
}
|
||||||
|
@ -460,6 +465,48 @@ func (c *Command) FlagErrorFunc() (f func(*Command, error) error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRequiredFlagsErrorFunc sets a function to generate an error when
|
||||||
|
// validating of required flags fails.
|
||||||
|
func (c *Command) SetRequiredFlagsErrorFunc(f func(*Command, error) error) {
|
||||||
|
c.requiredFlagErrorFunc = f
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequiredFlagsErrorFunc returns either the function set by
|
||||||
|
// SetRequiredFlagsErrorFunc for this command or a parent, or it returns a
|
||||||
|
// function which returns the original error.
|
||||||
|
func (c *Command) RequiredFlagsErrorFunc() (f func(*Command, error) error) {
|
||||||
|
if c.requiredFlagErrorFunc != nil {
|
||||||
|
return c.requiredFlagErrorFunc
|
||||||
|
}
|
||||||
|
if c.HasParent() {
|
||||||
|
return c.parent.RequiredFlagsErrorFunc()
|
||||||
|
}
|
||||||
|
return func(c *Command, err error) error {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetFlagGroupsErrorFunc sets a function to generate an error when validating
|
||||||
|
// of flag groups fails.
|
||||||
|
func (c *Command) SetFlagGroupsErrorFunc(f func(*Command, error) error) {
|
||||||
|
c.flagGroupsErrorFunc = f
|
||||||
|
}
|
||||||
|
|
||||||
|
// FlagGroupsErrorFunc returns either the function set by
|
||||||
|
// SetFlagGroupsErrorFunc for this command or a parent, or it returns a
|
||||||
|
// function which returns the original error.
|
||||||
|
func (c *Command) FlagGroupsErrorFunc() (f func(*Command, error) error) {
|
||||||
|
if c.flagGroupsErrorFunc != nil {
|
||||||
|
return c.flagGroupsErrorFunc
|
||||||
|
}
|
||||||
|
if c.HasParent() {
|
||||||
|
return c.parent.FlagGroupsErrorFunc()
|
||||||
|
}
|
||||||
|
return func(c *Command, err error) error {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var minUsagePadding = 25
|
var minUsagePadding = 25
|
||||||
|
|
||||||
// UsagePadding return padding for the usage.
|
// UsagePadding return padding for the usage.
|
||||||
|
@ -861,10 +908,10 @@ func (c *Command) execute(a []string) (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.validateRequiredFlags(); err != nil {
|
if err := c.validateRequiredFlags(); err != nil {
|
||||||
return c.FlagErrorFunc()(c, err)
|
return c.RequiredFlagsErrorFunc()(c, err)
|
||||||
}
|
}
|
||||||
if err := c.validateFlagGroups(); err != nil {
|
if err := c.validateFlagGroups(); err != nil {
|
||||||
return c.FlagErrorFunc()(c, err)
|
return c.FlagGroupsErrorFunc()(c, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.RunE != nil {
|
if c.RunE != nil {
|
||||||
|
|
|
@ -1723,6 +1723,74 @@ func TestFlagErrorFunc(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRequiredFlagsErrorFunc(t *testing.T) {
|
||||||
|
c := &Command{Use: "c", Run: emptyRun}
|
||||||
|
ff := c.Flags()
|
||||||
|
ff.Bool("required-flag-1", false, "required flag #1")
|
||||||
|
ff.Bool("required-flag-2", false, "required flag #1")
|
||||||
|
if err := c.MarkFlagRequired("required-flag-1"); err != nil {
|
||||||
|
t.Errorf("unexpected error %v", err)
|
||||||
|
}
|
||||||
|
if err := c.MarkFlagRequired("required-flag-2"); err != nil {
|
||||||
|
t.Errorf("unexpected error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedFmt := "This is expected: %v"
|
||||||
|
c.SetRequiredFlagsErrorFunc(func(_ *Command, err error) error {
|
||||||
|
return fmt.Errorf(expectedFmt, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := executeCommand(c)
|
||||||
|
|
||||||
|
got := err.Error()
|
||||||
|
expected := fmt.Sprintf(expectedFmt, `required flag(s) "required-flag-1", "required-flag-2" not set`)
|
||||||
|
if got != expected {
|
||||||
|
t.Errorf("Expected %v, got %v", expected, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFlagGroupsErrorFunc_RequiredTogether(t *testing.T) {
|
||||||
|
c := &Command{Use: "c", Run: emptyRun}
|
||||||
|
ff := c.Flags()
|
||||||
|
ff.Bool("required-flag-1", false, "required flag #1")
|
||||||
|
ff.Bool("required-flag-2", false, "required flag #1")
|
||||||
|
c.MarkFlagsRequiredTogether("required-flag-1", "required-flag-2")
|
||||||
|
|
||||||
|
expectedFmt := "This is expected: %v"
|
||||||
|
c.SetFlagGroupsErrorFunc(func(_ *Command, err error) error {
|
||||||
|
return fmt.Errorf(expectedFmt, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := executeCommand(c, "--required-flag-1")
|
||||||
|
|
||||||
|
got := err.Error()
|
||||||
|
expected := fmt.Sprintf(expectedFmt, `if any flags in the group [required-flag-1 required-flag-2] are set they must all be set; missing [required-flag-2]`)
|
||||||
|
if got != expected {
|
||||||
|
t.Errorf("Expected %v, got %v", expected, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFlagGroupsErrorFunc_MutuallyExclusive(t *testing.T) {
|
||||||
|
c := &Command{Use: "c", Run: emptyRun}
|
||||||
|
ff := c.Flags()
|
||||||
|
ff.Bool("required-flag-1", false, "required flag #1")
|
||||||
|
ff.Bool("required-flag-2", false, "required flag #1")
|
||||||
|
c.MarkFlagsMutuallyExclusive("required-flag-1", "required-flag-2")
|
||||||
|
|
||||||
|
expectedFmt := "This is expected: %v"
|
||||||
|
c.SetFlagGroupsErrorFunc(func(_ *Command, err error) error {
|
||||||
|
return fmt.Errorf(expectedFmt, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := executeCommand(c, "--required-flag-1", "--required-flag-2")
|
||||||
|
|
||||||
|
got := err.Error()
|
||||||
|
expected := fmt.Sprintf(expectedFmt, `if any flags in the group [required-flag-1 required-flag-2] are set none of the others can be; [required-flag-1 required-flag-2] were all set`)
|
||||||
|
if got != expected {
|
||||||
|
t.Errorf("Expected %v, got %v", expected, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestSortedFlags checks,
|
// TestSortedFlags checks,
|
||||||
// if cmd.LocalFlags() is unsorted when cmd.Flags().SortFlags set to false.
|
// if cmd.LocalFlags() is unsorted when cmd.Flags().SortFlags set to false.
|
||||||
// Related to https://github.com/spf13/cobra/issues/404.
|
// Related to https://github.com/spf13/cobra/issues/404.
|
||||||
|
|
Loading…
Add table
Reference in a new issue