feat: Add getters to error structs

This commit is contained in:
Ethan P. 2025-04-14 22:44:55 -07:00
parent 6d0ee6b071
commit 0fd2fcdab5
No known key found for this signature in database
GPG key ID: B29B90B1B228FEBC
3 changed files with 113 additions and 15 deletions

View file

@ -1196,6 +1196,7 @@ func (c *Command) ValidateRequiredFlags() error {
if len(missingFlagNames) > 0 { if len(missingFlagNames) > 0 {
return &RequiredFlagError{ return &RequiredFlagError{
cmd: c,
missingFlagNames: missingFlagNames, missingFlagNames: missingFlagNames,
} }
} }

112
errors.go
View file

@ -46,6 +46,28 @@ func (e *InvalidArgCountError) Error() string {
return fmt.Sprintf("accepts between %d and %d arg(s), received %d", e.atLeast, e.atMost, len(e.args)) return fmt.Sprintf("accepts between %d and %d arg(s), received %d", e.atLeast, e.atMost, len(e.args))
} }
// GetCommand returns the Command that the error occurred in.
func (e *InvalidArgCountError) GetCommand() *Command {
return e.cmd
}
// GetArguments returns the invalid arguments.
func (e *InvalidArgCountError) GetArguments() []string {
return e.args
}
// GetMinArgumentCount returns the minimum (inclusive) number of arguments
// that the command requires. If there is no minimum, a value of -1 is returned.
func (e *InvalidArgCountError) GetMinArgumentCount() int {
return e.atLeast
}
// GetMaxArgumentCount returns the maximum (inclusive) number of arguments
// that the command requires. If there is no maximum, a value of -1 is returned.
func (e *InvalidArgCountError) GetMaxArgumentCount() int {
return e.atMost
}
// InvalidArgCountError is the error returned an invalid argument is present. // InvalidArgCountError is the error returned an invalid argument is present.
type InvalidArgValueError struct { type InvalidArgValueError struct {
cmd *Command cmd *Command
@ -58,6 +80,21 @@ func (e *InvalidArgValueError) Error() string {
return fmt.Sprintf("invalid argument %q for %q%s", e.arg, e.cmd.CommandPath(), e.suggestions) return fmt.Sprintf("invalid argument %q for %q%s", e.arg, e.cmd.CommandPath(), e.suggestions)
} }
// GetCommand returns the Command that the error occurred in.
func (e *InvalidArgValueError) GetCommand() *Command {
return e.cmd
}
// GetArgument returns the invalid argument.
func (e *InvalidArgValueError) GetArgument() string {
return e.arg
}
// GetSuggestions returns suggestions, if there are any.
func (e *InvalidArgValueError) GetSuggestions() string {
return e.suggestions
}
// UnknownSubcommandError is the error returned when a subcommand can not be // UnknownSubcommandError is the error returned when a subcommand can not be
// found. // found.
type UnknownSubcommandError struct { type UnknownSubcommandError struct {
@ -66,6 +103,21 @@ type UnknownSubcommandError struct {
suggestions string suggestions string
} }
// GetCommand returns the Command that the error occurred in.
func (e *UnknownSubcommandError) GetCommand() *Command {
return e.cmd
}
// GetSubcommand returns the unknown subcommand name.
func (e *UnknownSubcommandError) GetSubcommand() string {
return e.subcmd
}
// GetSuggestions returns suggestions, if there are any.
func (e *UnknownSubcommandError) GetSuggestions() string {
return e.suggestions
}
// Error implements error. // Error implements error.
func (e *UnknownSubcommandError) Error() string { func (e *UnknownSubcommandError) Error() string {
return fmt.Sprintf("unknown command %q for %q%s", e.subcmd, e.cmd.CommandPath(), e.suggestions) return fmt.Sprintf("unknown command %q for %q%s", e.subcmd, e.cmd.CommandPath(), e.suggestions)
@ -73,6 +125,7 @@ func (e *UnknownSubcommandError) Error() string {
// RequiredFlagError is the error returned when a required flag is not set. // RequiredFlagError is the error returned when a required flag is not set.
type RequiredFlagError struct { type RequiredFlagError struct {
cmd *Command
missingFlagNames []string missingFlagNames []string
} }
@ -81,31 +134,72 @@ func (e *RequiredFlagError) Error() string {
return fmt.Sprintf(`required flag(s) "%s" not set`, strings.Join(e.missingFlagNames, `", "`)) return fmt.Sprintf(`required flag(s) "%s" not set`, strings.Join(e.missingFlagNames, `", "`))
} }
// GetCommand returns the Command that the error occurred in.
func (e *RequiredFlagError) GetCommand() *Command {
return e.cmd
}
// GetFlags returns the names of the missing flags.
func (e *RequiredFlagError) GetFlags() []string {
return e.missingFlagNames
}
// FlagGroupError is the error returned when mutually-required or // FlagGroupError is the error returned when mutually-required or
// mutually-exclusive flags are not properly specified. // mutually-exclusive flags are not properly specified.
type FlagGroupError struct { type FlagGroupError struct {
cmd *Command
flagList string flagList string
flagGroupType flagGroupType flagGroupType FlagGroupType
problemFlags []string problemFlags []string
} }
// flagGroupType identifies which failed validation caused a FlagGroupError. // GetCommand returns the Command that the error occurred in.
type flagGroupType string func (e *FlagGroupError) GetCommand() *Command {
return e.cmd
}
// GetFlagList returns the flags in the group.
func (e *FlagGroupError) GetFlags() []string {
return strings.Split(e.flagList, " ")
}
// GetFlagGroupType returns the type of flag group causing the error.
//
// Valid types are:
// - FlagsAreMutuallyExclusive for mutually-exclusive flags.
// - FlagsAreRequiredTogether for mutually-required flags.
// - FlagsAreOneRequired for flags where at least one must be present.
func (e *FlagGroupError) GetFlagGroupType() FlagGroupType {
return e.flagGroupType
}
// GetProblemFlags returns the flags causing the error.
//
// For flag groups where:
// - FlagsAreMutuallyExclusive, these are all the flags set.
// - FlagsAreRequiredTogether, these are the missing flags.
// - FlagsAreOneRequired, this is empty.
func (e *FlagGroupError) GetProblemFlags() []string {
return e.problemFlags
}
// FlagGroupType identifies which failed validation caused a FlagGroupError.
type FlagGroupType string
const ( const (
flagGroupIsExclusive flagGroupType = "if any is set, none of the others can be" FlagsAreMutuallyExclusive FlagGroupType = "if any is set, none of the others can be"
flagGroupIsRequired flagGroupType = "if any is set, they must all be set" FlagsAreRequiredTogether FlagGroupType = "if any is set, they must all be set"
flagGroupIsOneRequired flagGroupType = "at least one of the flags is required" FlagsAreOneRequired FlagGroupType = "at least one of the flags is required"
) )
// Error implements error. // Error implements error.
func (e *FlagGroupError) Error() string { func (e *FlagGroupError) Error() string {
switch e.flagGroupType { switch e.flagGroupType {
case flagGroupIsRequired: case FlagsAreRequiredTogether:
return fmt.Sprintf("if any flags in the group [%v] are set they must all be set; missing %v", e.flagList, e.problemFlags) return fmt.Sprintf("if any flags in the group [%v] are set they must all be set; missing %v", e.flagList, e.problemFlags)
case flagGroupIsOneRequired: case FlagsAreOneRequired:
return fmt.Sprintf("at least one of the flags in the group [%v] is required", e.flagList) return fmt.Sprintf("at least one of the flags in the group [%v] is required", e.flagList)
case flagGroupIsExclusive: case FlagsAreMutuallyExclusive:
return fmt.Sprintf("if any flags in the group [%v] are set none of the others can be; %v were all set", e.flagList, e.problemFlags) return fmt.Sprintf("if any flags in the group [%v] are set none of the others can be; %v were all set", e.flagList, e.problemFlags)
} }

View file

@ -97,12 +97,15 @@ func (c *Command) ValidateFlagGroups() error {
}) })
if err := validateRequiredFlagGroups(groupStatus); err != nil { if err := validateRequiredFlagGroups(groupStatus); err != nil {
err.cmd = c
return err return err
} }
if err := validateOneRequiredFlagGroups(oneRequiredGroupStatus); err != nil { if err := validateOneRequiredFlagGroups(oneRequiredGroupStatus); err != nil {
err.cmd = c
return err return err
} }
if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil { if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {
err.cmd = c
return err return err
} }
return nil return nil
@ -141,7 +144,7 @@ func processFlagForGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag, annota
} }
} }
func validateRequiredFlagGroups(data map[string]map[string]bool) error { func validateRequiredFlagGroups(data map[string]map[string]bool) *FlagGroupError {
keys := sortedKeys(data) keys := sortedKeys(data)
for _, flagList := range keys { for _, flagList := range keys {
flagnameAndStatus := data[flagList] flagnameAndStatus := data[flagList]
@ -160,7 +163,7 @@ func validateRequiredFlagGroups(data map[string]map[string]bool) error {
sort.Strings(unset) sort.Strings(unset)
return &FlagGroupError{ return &FlagGroupError{
flagList: flagList, flagList: flagList,
flagGroupType: flagGroupIsRequired, flagGroupType: FlagsAreRequiredTogether,
problemFlags: unset, problemFlags: unset,
} }
} }
@ -168,7 +171,7 @@ func validateRequiredFlagGroups(data map[string]map[string]bool) error {
return nil return nil
} }
func validateOneRequiredFlagGroups(data map[string]map[string]bool) error { func validateOneRequiredFlagGroups(data map[string]map[string]bool) *FlagGroupError {
keys := sortedKeys(data) keys := sortedKeys(data)
for _, flagList := range keys { for _, flagList := range keys {
flagnameAndStatus := data[flagList] flagnameAndStatus := data[flagList]
@ -186,13 +189,13 @@ func validateOneRequiredFlagGroups(data map[string]map[string]bool) error {
sort.Strings(set) sort.Strings(set)
return &FlagGroupError{ return &FlagGroupError{
flagList: flagList, flagList: flagList,
flagGroupType: flagGroupIsOneRequired, flagGroupType: FlagsAreOneRequired,
} }
} }
return nil return nil
} }
func validateExclusiveFlagGroups(data map[string]map[string]bool) error { func validateExclusiveFlagGroups(data map[string]map[string]bool) *FlagGroupError {
keys := sortedKeys(data) keys := sortedKeys(data)
for _, flagList := range keys { for _, flagList := range keys {
flagnameAndStatus := data[flagList] flagnameAndStatus := data[flagList]
@ -210,7 +213,7 @@ func validateExclusiveFlagGroups(data map[string]map[string]bool) error {
sort.Strings(set) sort.Strings(set)
return &FlagGroupError{ return &FlagGroupError{
flagList: flagList, flagList: flagList,
flagGroupType: flagGroupIsExclusive, flagGroupType: FlagsAreMutuallyExclusive,
problemFlags: set, problemFlags: set,
} }
} }