mirror of
https://github.com/spf13/cobra
synced 2025-05-06 13:27:26 +00:00
Add a new ValidateRequiredFlagsFunc to Command to customize the required flag validation behavior. This is useful to customize the validation and messaging when you have to set, say, either flag1 or flag2, but not both.
This commit is contained in:
parent
fe5e611709
commit
28c4487f58
2 changed files with 65 additions and 1 deletions
17
command.go
17
command.go
|
@ -143,6 +143,9 @@ type Command struct {
|
||||||
//FParseErrWhitelist flag parse errors to be ignored
|
//FParseErrWhitelist flag parse errors to be ignored
|
||||||
FParseErrWhitelist FParseErrWhitelist
|
FParseErrWhitelist FParseErrWhitelist
|
||||||
|
|
||||||
|
// ValidateRequiredFlagsFunc is called to validate if all required flags are set
|
||||||
|
ValidateRequiredFlagsFunc func(*Command) error
|
||||||
|
|
||||||
// commands is the list of commands supported by this program.
|
// commands is the list of commands supported by this program.
|
||||||
commands []*Command
|
commands []*Command
|
||||||
// parent is a parent command for this command.
|
// parent is a parent command for this command.
|
||||||
|
@ -880,7 +883,8 @@ func (c *Command) ValidateArgs(args []string) error {
|
||||||
return c.Args(c, args)
|
return c.Args(c, args)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Command) validateRequiredFlags() error {
|
// GetMissingRequiredFlags returns the names of the missing required flags
|
||||||
|
func (c *Command) GetMissingRequiredFlags() []string {
|
||||||
flags := c.Flags()
|
flags := c.Flags()
|
||||||
missingFlagNames := []string{}
|
missingFlagNames := []string{}
|
||||||
flags.VisitAll(func(pflag *flag.Flag) {
|
flags.VisitAll(func(pflag *flag.Flag) {
|
||||||
|
@ -892,7 +896,18 @@ func (c *Command) validateRequiredFlags() error {
|
||||||
missingFlagNames = append(missingFlagNames, pflag.Name)
|
missingFlagNames = append(missingFlagNames, pflag.Name)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
return missingFlagNames
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Command) validateRequiredFlags() error {
|
||||||
|
if c.ValidateRequiredFlagsFunc != nil {
|
||||||
|
return c.ValidateRequiredFlagsFunc(c)
|
||||||
|
}
|
||||||
|
return c.defaultValidateRequiredFlags()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Command) defaultValidateRequiredFlags() error {
|
||||||
|
missingFlagNames := c.GetMissingRequiredFlags()
|
||||||
if len(missingFlagNames) > 0 {
|
if len(missingFlagNames) > 0 {
|
||||||
return fmt.Errorf(`required flag(s) "%s" not set`, strings.Join(missingFlagNames, `", "`))
|
return fmt.Errorf(`required flag(s) "%s" not set`, strings.Join(missingFlagNames, `", "`))
|
||||||
}
|
}
|
||||||
|
|
|
@ -691,6 +691,27 @@ func TestRequiredFlags(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRequiredFlagsFunc(t *testing.T) {
|
||||||
|
c := &Command{Use: "c", Run: emptyRun, ValidateRequiredFlagsFunc: func(c *Command) error {
|
||||||
|
missing := c.GetMissingRequiredFlags()
|
||||||
|
return fmt.Errorf(strings.Join(missing, "|"))
|
||||||
|
}}
|
||||||
|
c.Flags().String("foo1", "", "")
|
||||||
|
c.MarkFlagRequired("foo1")
|
||||||
|
c.Flags().String("foo2", "", "")
|
||||||
|
c.MarkFlagRequired("foo2")
|
||||||
|
c.Flags().String("bar", "", "")
|
||||||
|
|
||||||
|
expected := "foo1|foo2"
|
||||||
|
|
||||||
|
_, err := executeCommand(c)
|
||||||
|
got := err.Error()
|
||||||
|
|
||||||
|
if got != expected {
|
||||||
|
t.Errorf("Expected error: %q, got: %q", expected, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestPersistentRequiredFlags(t *testing.T) {
|
func TestPersistentRequiredFlags(t *testing.T) {
|
||||||
parent := &Command{Use: "parent", Run: emptyRun}
|
parent := &Command{Use: "parent", Run: emptyRun}
|
||||||
parent.PersistentFlags().String("foo1", "", "")
|
parent.PersistentFlags().String("foo1", "", "")
|
||||||
|
@ -716,6 +737,34 @@ func TestPersistentRequiredFlags(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPersistentRequiredFlagsFunc(t *testing.T) {
|
||||||
|
parent := &Command{Use: "parent", Run: emptyRun}
|
||||||
|
parent.PersistentFlags().String("foo1", "", "")
|
||||||
|
parent.MarkPersistentFlagRequired("foo1")
|
||||||
|
parent.PersistentFlags().String("foo2", "", "")
|
||||||
|
parent.MarkPersistentFlagRequired("foo2")
|
||||||
|
parent.Flags().String("foo3", "", "")
|
||||||
|
|
||||||
|
child := &Command{Use: "child", Run: emptyRun, ValidateRequiredFlagsFunc: func(c *Command) error {
|
||||||
|
missing := c.GetMissingRequiredFlags()
|
||||||
|
return fmt.Errorf(strings.Join(missing, "|"))
|
||||||
|
}}
|
||||||
|
child.Flags().String("bar1", "", "")
|
||||||
|
child.MarkFlagRequired("bar1")
|
||||||
|
child.Flags().String("bar2", "", "")
|
||||||
|
child.MarkFlagRequired("bar2")
|
||||||
|
child.Flags().String("bar3", "", "")
|
||||||
|
|
||||||
|
parent.AddCommand(child)
|
||||||
|
|
||||||
|
expected := "bar1|bar2|foo1|foo2"
|
||||||
|
|
||||||
|
_, err := executeCommand(parent, "child")
|
||||||
|
if err.Error() != expected {
|
||||||
|
t.Errorf("Expected %q, got %q", expected, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestInitHelpFlagMergesFlags(t *testing.T) {
|
func TestInitHelpFlagMergesFlags(t *testing.T) {
|
||||||
usage := "custom flag"
|
usage := "custom flag"
|
||||||
rootCmd := &Command{Use: "root"}
|
rootCmd := &Command{Use: "root"}
|
||||||
|
|
Loading…
Add table
Reference in a new issue