diff --git a/command.go b/command.go index 4794e5eb..3ffe6cc2 100644 --- a/command.go +++ b/command.go @@ -1195,7 +1195,9 @@ func (c *Command) ValidateRequiredFlags() error { }) if len(missingFlagNames) > 0 { - return fmt.Errorf(`required flag(s) "%s" not set`, strings.Join(missingFlagNames, `", "`)) + return &RequiredFlagError{ + missingFlagNames: missingFlagNames, + } } return nil } diff --git a/errors.go b/errors.go index 7d8fb167..8c470199 100644 --- a/errors.go +++ b/errors.go @@ -14,7 +14,10 @@ package cobra -import "fmt" +import ( + "fmt" + "strings" +) // InvalidArgCountError is the error returned when the wrong number of arguments // are supplied to a command. @@ -67,3 +70,44 @@ type UnknownSubcommandError struct { func (e *UnknownSubcommandError) Error() string { return fmt.Sprintf("unknown command %q for %q%s", e.subcmd, e.cmd.CommandPath(), e.suggestions) } + +// RequiredFlagError is the error returned when a required flag is not set. +type RequiredFlagError struct { + missingFlagNames []string +} + +// Error implements error. +func (e *RequiredFlagError) Error() string { + return fmt.Sprintf(`required flag(s) "%s" not set`, strings.Join(e.missingFlagNames, `", "`)) +} + +// FlagGroupError is the error returned when mutually-required or +// mutually-exclusive flags are not properly specified. +type FlagGroupError struct { + flagList string + flagGroupType flagGroupType + problemFlags []string +} + +// flagGroupType identifies which failed validation caused a FlagGroupError. +type flagGroupType string + +const ( + flagGroupIsExclusive flagGroupType = "if any is set, none of the others can be" + flagGroupIsRequired flagGroupType = "if any is set, they must all be set" + flagGroupIsOneRequired flagGroupType = "at least one of the flags is required" +) + +// Error implements error. +func (e *FlagGroupError) Error() string { + switch e.flagGroupType { + case flagGroupIsRequired: + return fmt.Sprintf("if any flags in the group [%v] are set they must all be set; missing %v", e.flagList, e.problemFlags) + case flagGroupIsOneRequired: + return fmt.Sprintf("at least one of the flags in the group [%v] is required", e.flagList) + case flagGroupIsExclusive: + return fmt.Sprintf("if any flags in the group [%v] are set none of the others can be; %v were all set", e.flagList, e.problemFlags) + } + + panic("invalid flagGroupType") +} diff --git a/flag_groups.go b/flag_groups.go index 560612fd..401a618a 100644 --- a/flag_groups.go +++ b/flag_groups.go @@ -158,7 +158,11 @@ func validateRequiredFlagGroups(data map[string]map[string]bool) error { // Sort values, so they can be tested/scripted against consistently. sort.Strings(unset) - return fmt.Errorf("if any flags in the group [%v] are set they must all be set; missing %v", flagList, unset) + return &FlagGroupError{ + flagList: flagList, + flagGroupType: flagGroupIsRequired, + problemFlags: unset, + } } return nil @@ -180,7 +184,10 @@ func validateOneRequiredFlagGroups(data map[string]map[string]bool) error { // Sort values, so they can be tested/scripted against consistently. sort.Strings(set) - return fmt.Errorf("at least one of the flags in the group [%v] is required", flagList) + return &FlagGroupError{ + flagList: flagList, + flagGroupType: flagGroupIsOneRequired, + } } return nil } @@ -201,7 +208,11 @@ func validateExclusiveFlagGroups(data map[string]map[string]bool) error { // Sort values, so they can be tested/scripted against consistently. sort.Strings(set) - return fmt.Errorf("if any flags in the group [%v] are set none of the others can be; %v were all set", flagList, set) + return &FlagGroupError{ + flagList: flagList, + flagGroupType: flagGroupIsExclusive, + problemFlags: set, + } } return nil }