From 28c4487f58a8a6420c1ba1587290559735507878 Mon Sep 17 00:00:00 2001 From: Burak Serdar Date: Fri, 2 Nov 2018 09:12:17 -0600 Subject: [PATCH] Add a new ValidateRequiredFlagsFunc to Command to customize the required flag validation behavior. This is useful to customize the validation and messaging when you have to set, say, either flag1 or flag2, but not both. --- command.go | 17 ++++++++++++++++- command_test.go | 49 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 1 deletion(-) diff --git a/command.go b/command.go index 34d1bf36..c729600b 100644 --- a/command.go +++ b/command.go @@ -143,6 +143,9 @@ type Command struct { //FParseErrWhitelist flag parse errors to be ignored FParseErrWhitelist FParseErrWhitelist + // ValidateRequiredFlagsFunc is called to validate if all required flags are set + ValidateRequiredFlagsFunc func(*Command) error + // commands is the list of commands supported by this program. commands []*Command // parent is a parent command for this command. @@ -880,7 +883,8 @@ func (c *Command) ValidateArgs(args []string) error { return c.Args(c, args) } -func (c *Command) validateRequiredFlags() error { +// GetMissingRequiredFlags returns the names of the missing required flags +func (c *Command) GetMissingRequiredFlags() []string { flags := c.Flags() missingFlagNames := []string{} flags.VisitAll(func(pflag *flag.Flag) { @@ -892,7 +896,18 @@ func (c *Command) validateRequiredFlags() error { missingFlagNames = append(missingFlagNames, pflag.Name) } }) + return missingFlagNames +} +func (c *Command) validateRequiredFlags() error { + if c.ValidateRequiredFlagsFunc != nil { + return c.ValidateRequiredFlagsFunc(c) + } + return c.defaultValidateRequiredFlags() +} + +func (c *Command) defaultValidateRequiredFlags() error { + missingFlagNames := c.GetMissingRequiredFlags() if len(missingFlagNames) > 0 { return fmt.Errorf(`required flag(s) "%s" not set`, strings.Join(missingFlagNames, `", "`)) } diff --git a/command_test.go b/command_test.go index 6e483a3e..b0701015 100644 --- a/command_test.go +++ b/command_test.go @@ -691,6 +691,27 @@ func TestRequiredFlags(t *testing.T) { } } +func TestRequiredFlagsFunc(t *testing.T) { + c := &Command{Use: "c", Run: emptyRun, ValidateRequiredFlagsFunc: func(c *Command) error { + missing := c.GetMissingRequiredFlags() + return fmt.Errorf(strings.Join(missing, "|")) + }} + c.Flags().String("foo1", "", "") + c.MarkFlagRequired("foo1") + c.Flags().String("foo2", "", "") + c.MarkFlagRequired("foo2") + c.Flags().String("bar", "", "") + + expected := "foo1|foo2" + + _, err := executeCommand(c) + got := err.Error() + + if got != expected { + t.Errorf("Expected error: %q, got: %q", expected, got) + } +} + func TestPersistentRequiredFlags(t *testing.T) { parent := &Command{Use: "parent", Run: emptyRun} parent.PersistentFlags().String("foo1", "", "") @@ -716,6 +737,34 @@ func TestPersistentRequiredFlags(t *testing.T) { } } +func TestPersistentRequiredFlagsFunc(t *testing.T) { + parent := &Command{Use: "parent", Run: emptyRun} + parent.PersistentFlags().String("foo1", "", "") + parent.MarkPersistentFlagRequired("foo1") + parent.PersistentFlags().String("foo2", "", "") + parent.MarkPersistentFlagRequired("foo2") + parent.Flags().String("foo3", "", "") + + child := &Command{Use: "child", Run: emptyRun, ValidateRequiredFlagsFunc: func(c *Command) error { + missing := c.GetMissingRequiredFlags() + return fmt.Errorf(strings.Join(missing, "|")) + }} + child.Flags().String("bar1", "", "") + child.MarkFlagRequired("bar1") + child.Flags().String("bar2", "", "") + child.MarkFlagRequired("bar2") + child.Flags().String("bar3", "", "") + + parent.AddCommand(child) + + expected := "bar1|bar2|foo1|foo2" + + _, err := executeCommand(parent, "child") + if err.Error() != expected { + t.Errorf("Expected %q, got %q", expected, err.Error()) + } +} + func TestInitHelpFlagMergesFlags(t *testing.T) { usage := "custom flag" rootCmd := &Command{Use: "root"}