diff --git a/command.go b/command.go index 34d1bf36..c729600b 100644 --- a/command.go +++ b/command.go @@ -143,6 +143,9 @@ type Command struct { //FParseErrWhitelist flag parse errors to be ignored FParseErrWhitelist FParseErrWhitelist + // ValidateRequiredFlagsFunc is called to validate if all required flags are set + ValidateRequiredFlagsFunc func(*Command) error + // commands is the list of commands supported by this program. commands []*Command // parent is a parent command for this command. @@ -880,7 +883,8 @@ func (c *Command) ValidateArgs(args []string) error { return c.Args(c, args) } -func (c *Command) validateRequiredFlags() error { +// GetMissingRequiredFlags returns the names of the missing required flags +func (c *Command) GetMissingRequiredFlags() []string { flags := c.Flags() missingFlagNames := []string{} flags.VisitAll(func(pflag *flag.Flag) { @@ -892,7 +896,18 @@ func (c *Command) validateRequiredFlags() error { missingFlagNames = append(missingFlagNames, pflag.Name) } }) + return missingFlagNames +} +func (c *Command) validateRequiredFlags() error { + if c.ValidateRequiredFlagsFunc != nil { + return c.ValidateRequiredFlagsFunc(c) + } + return c.defaultValidateRequiredFlags() +} + +func (c *Command) defaultValidateRequiredFlags() error { + missingFlagNames := c.GetMissingRequiredFlags() if len(missingFlagNames) > 0 { return fmt.Errorf(`required flag(s) "%s" not set`, strings.Join(missingFlagNames, `", "`)) } diff --git a/command_test.go b/command_test.go index 6e483a3e..b0701015 100644 --- a/command_test.go +++ b/command_test.go @@ -691,6 +691,27 @@ func TestRequiredFlags(t *testing.T) { } } +func TestRequiredFlagsFunc(t *testing.T) { + c := &Command{Use: "c", Run: emptyRun, ValidateRequiredFlagsFunc: func(c *Command) error { + missing := c.GetMissingRequiredFlags() + return fmt.Errorf(strings.Join(missing, "|")) + }} + c.Flags().String("foo1", "", "") + c.MarkFlagRequired("foo1") + c.Flags().String("foo2", "", "") + c.MarkFlagRequired("foo2") + c.Flags().String("bar", "", "") + + expected := "foo1|foo2" + + _, err := executeCommand(c) + got := err.Error() + + if got != expected { + t.Errorf("Expected error: %q, got: %q", expected, got) + } +} + func TestPersistentRequiredFlags(t *testing.T) { parent := &Command{Use: "parent", Run: emptyRun} parent.PersistentFlags().String("foo1", "", "") @@ -716,6 +737,34 @@ func TestPersistentRequiredFlags(t *testing.T) { } } +func TestPersistentRequiredFlagsFunc(t *testing.T) { + parent := &Command{Use: "parent", Run: emptyRun} + parent.PersistentFlags().String("foo1", "", "") + parent.MarkPersistentFlagRequired("foo1") + parent.PersistentFlags().String("foo2", "", "") + parent.MarkPersistentFlagRequired("foo2") + parent.Flags().String("foo3", "", "") + + child := &Command{Use: "child", Run: emptyRun, ValidateRequiredFlagsFunc: func(c *Command) error { + missing := c.GetMissingRequiredFlags() + return fmt.Errorf(strings.Join(missing, "|")) + }} + child.Flags().String("bar1", "", "") + child.MarkFlagRequired("bar1") + child.Flags().String("bar2", "", "") + child.MarkFlagRequired("bar2") + child.Flags().String("bar3", "", "") + + parent.AddCommand(child) + + expected := "bar1|bar2|foo1|foo2" + + _, err := executeCommand(parent, "child") + if err.Error() != expected { + t.Errorf("Expected %q, got %q", expected, err.Error()) + } +} + func TestInitHelpFlagMergesFlags(t *testing.T) { usage := "custom flag" rootCmd := &Command{Use: "root"}