Add a new ValidateRequiredFlagsFunc to Command to customize the required flag validation behavior. This is useful to customize the validation and messaging when you have to set, say, either flag1 or flag2, but not both.

This commit is contained in:
Burak Serdar 2018-11-02 09:12:17 -06:00
parent fe5e611709
commit 28c4487f58
2 changed files with 65 additions and 1 deletions

View file

@ -143,6 +143,9 @@ type Command struct {
//FParseErrWhitelist flag parse errors to be ignored //FParseErrWhitelist flag parse errors to be ignored
FParseErrWhitelist FParseErrWhitelist FParseErrWhitelist FParseErrWhitelist
// ValidateRequiredFlagsFunc is called to validate if all required flags are set
ValidateRequiredFlagsFunc func(*Command) error
// commands is the list of commands supported by this program. // commands is the list of commands supported by this program.
commands []*Command commands []*Command
// parent is a parent command for this command. // parent is a parent command for this command.
@ -880,7 +883,8 @@ func (c *Command) ValidateArgs(args []string) error {
return c.Args(c, args) return c.Args(c, args)
} }
func (c *Command) validateRequiredFlags() error { // GetMissingRequiredFlags returns the names of the missing required flags
func (c *Command) GetMissingRequiredFlags() []string {
flags := c.Flags() flags := c.Flags()
missingFlagNames := []string{} missingFlagNames := []string{}
flags.VisitAll(func(pflag *flag.Flag) { flags.VisitAll(func(pflag *flag.Flag) {
@ -892,7 +896,18 @@ func (c *Command) validateRequiredFlags() error {
missingFlagNames = append(missingFlagNames, pflag.Name) missingFlagNames = append(missingFlagNames, pflag.Name)
} }
}) })
return missingFlagNames
}
func (c *Command) validateRequiredFlags() error {
if c.ValidateRequiredFlagsFunc != nil {
return c.ValidateRequiredFlagsFunc(c)
}
return c.defaultValidateRequiredFlags()
}
func (c *Command) defaultValidateRequiredFlags() error {
missingFlagNames := c.GetMissingRequiredFlags()
if len(missingFlagNames) > 0 { if len(missingFlagNames) > 0 {
return fmt.Errorf(`required flag(s) "%s" not set`, strings.Join(missingFlagNames, `", "`)) return fmt.Errorf(`required flag(s) "%s" not set`, strings.Join(missingFlagNames, `", "`))
} }

View file

@ -691,6 +691,27 @@ func TestRequiredFlags(t *testing.T) {
} }
} }
func TestRequiredFlagsFunc(t *testing.T) {
c := &Command{Use: "c", Run: emptyRun, ValidateRequiredFlagsFunc: func(c *Command) error {
missing := c.GetMissingRequiredFlags()
return fmt.Errorf(strings.Join(missing, "|"))
}}
c.Flags().String("foo1", "", "")
c.MarkFlagRequired("foo1")
c.Flags().String("foo2", "", "")
c.MarkFlagRequired("foo2")
c.Flags().String("bar", "", "")
expected := "foo1|foo2"
_, err := executeCommand(c)
got := err.Error()
if got != expected {
t.Errorf("Expected error: %q, got: %q", expected, got)
}
}
func TestPersistentRequiredFlags(t *testing.T) { func TestPersistentRequiredFlags(t *testing.T) {
parent := &Command{Use: "parent", Run: emptyRun} parent := &Command{Use: "parent", Run: emptyRun}
parent.PersistentFlags().String("foo1", "", "") parent.PersistentFlags().String("foo1", "", "")
@ -716,6 +737,34 @@ func TestPersistentRequiredFlags(t *testing.T) {
} }
} }
func TestPersistentRequiredFlagsFunc(t *testing.T) {
parent := &Command{Use: "parent", Run: emptyRun}
parent.PersistentFlags().String("foo1", "", "")
parent.MarkPersistentFlagRequired("foo1")
parent.PersistentFlags().String("foo2", "", "")
parent.MarkPersistentFlagRequired("foo2")
parent.Flags().String("foo3", "", "")
child := &Command{Use: "child", Run: emptyRun, ValidateRequiredFlagsFunc: func(c *Command) error {
missing := c.GetMissingRequiredFlags()
return fmt.Errorf(strings.Join(missing, "|"))
}}
child.Flags().String("bar1", "", "")
child.MarkFlagRequired("bar1")
child.Flags().String("bar2", "", "")
child.MarkFlagRequired("bar2")
child.Flags().String("bar3", "", "")
parent.AddCommand(child)
expected := "bar1|bar2|foo1|foo2"
_, err := executeCommand(parent, "child")
if err.Error() != expected {
t.Errorf("Expected %q, got %q", expected, err.Error())
}
}
func TestInitHelpFlagMergesFlags(t *testing.T) { func TestInitHelpFlagMergesFlags(t *testing.T) {
usage := "custom flag" usage := "custom flag"
rootCmd := &Command{Use: "root"} rootCmd := &Command{Use: "root"}