From 0fd2fcdab53dbcffc4579c9df50c8b62d8aa9449 Mon Sep 17 00:00:00 2001 From: "Ethan P." Date: Mon, 14 Apr 2025 22:44:55 -0700 Subject: [PATCH] feat: Add getters to error structs --- command.go | 1 + errors.go | 112 +++++++++++++++++++++++++++++++++++++++++++++---- flag_groups.go | 15 ++++--- 3 files changed, 113 insertions(+), 15 deletions(-) diff --git a/command.go b/command.go index 3ffe6cc2..aae5094d 100644 --- a/command.go +++ b/command.go @@ -1196,6 +1196,7 @@ func (c *Command) ValidateRequiredFlags() error { if len(missingFlagNames) > 0 { return &RequiredFlagError{ + cmd: c, missingFlagNames: missingFlagNames, } } diff --git a/errors.go b/errors.go index 8c470199..d5a33358 100644 --- a/errors.go +++ b/errors.go @@ -46,6 +46,28 @@ func (e *InvalidArgCountError) Error() string { return fmt.Sprintf("accepts between %d and %d arg(s), received %d", e.atLeast, e.atMost, len(e.args)) } +// GetCommand returns the Command that the error occurred in. +func (e *InvalidArgCountError) GetCommand() *Command { + return e.cmd +} + +// GetArguments returns the invalid arguments. +func (e *InvalidArgCountError) GetArguments() []string { + return e.args +} + +// GetMinArgumentCount returns the minimum (inclusive) number of arguments +// that the command requires. If there is no minimum, a value of -1 is returned. +func (e *InvalidArgCountError) GetMinArgumentCount() int { + return e.atLeast +} + +// GetMaxArgumentCount returns the maximum (inclusive) number of arguments +// that the command requires. If there is no maximum, a value of -1 is returned. +func (e *InvalidArgCountError) GetMaxArgumentCount() int { + return e.atMost +} + // InvalidArgCountError is the error returned an invalid argument is present. type InvalidArgValueError struct { cmd *Command @@ -58,6 +80,21 @@ func (e *InvalidArgValueError) Error() string { return fmt.Sprintf("invalid argument %q for %q%s", e.arg, e.cmd.CommandPath(), e.suggestions) } +// GetCommand returns the Command that the error occurred in. +func (e *InvalidArgValueError) GetCommand() *Command { + return e.cmd +} + +// GetArgument returns the invalid argument. +func (e *InvalidArgValueError) GetArgument() string { + return e.arg +} + +// GetSuggestions returns suggestions, if there are any. +func (e *InvalidArgValueError) GetSuggestions() string { + return e.suggestions +} + // UnknownSubcommandError is the error returned when a subcommand can not be // found. type UnknownSubcommandError struct { @@ -66,6 +103,21 @@ type UnknownSubcommandError struct { suggestions string } +// GetCommand returns the Command that the error occurred in. +func (e *UnknownSubcommandError) GetCommand() *Command { + return e.cmd +} + +// GetSubcommand returns the unknown subcommand name. +func (e *UnknownSubcommandError) GetSubcommand() string { + return e.subcmd +} + +// GetSuggestions returns suggestions, if there are any. +func (e *UnknownSubcommandError) GetSuggestions() string { + return e.suggestions +} + // Error implements error. func (e *UnknownSubcommandError) Error() string { return fmt.Sprintf("unknown command %q for %q%s", e.subcmd, e.cmd.CommandPath(), e.suggestions) @@ -73,6 +125,7 @@ func (e *UnknownSubcommandError) Error() string { // RequiredFlagError is the error returned when a required flag is not set. type RequiredFlagError struct { + cmd *Command missingFlagNames []string } @@ -81,31 +134,72 @@ func (e *RequiredFlagError) Error() string { return fmt.Sprintf(`required flag(s) "%s" not set`, strings.Join(e.missingFlagNames, `", "`)) } +// GetCommand returns the Command that the error occurred in. +func (e *RequiredFlagError) GetCommand() *Command { + return e.cmd +} + +// GetFlags returns the names of the missing flags. +func (e *RequiredFlagError) GetFlags() []string { + return e.missingFlagNames +} + // FlagGroupError is the error returned when mutually-required or // mutually-exclusive flags are not properly specified. type FlagGroupError struct { + cmd *Command flagList string - flagGroupType flagGroupType + flagGroupType FlagGroupType problemFlags []string } -// flagGroupType identifies which failed validation caused a FlagGroupError. -type flagGroupType string +// GetCommand returns the Command that the error occurred in. +func (e *FlagGroupError) GetCommand() *Command { + return e.cmd +} + +// GetFlagList returns the flags in the group. +func (e *FlagGroupError) GetFlags() []string { + return strings.Split(e.flagList, " ") +} + +// GetFlagGroupType returns the type of flag group causing the error. +// +// Valid types are: +// - FlagsAreMutuallyExclusive for mutually-exclusive flags. +// - FlagsAreRequiredTogether for mutually-required flags. +// - FlagsAreOneRequired for flags where at least one must be present. +func (e *FlagGroupError) GetFlagGroupType() FlagGroupType { + return e.flagGroupType +} + +// GetProblemFlags returns the flags causing the error. +// +// For flag groups where: +// - FlagsAreMutuallyExclusive, these are all the flags set. +// - FlagsAreRequiredTogether, these are the missing flags. +// - FlagsAreOneRequired, this is empty. +func (e *FlagGroupError) GetProblemFlags() []string { + return e.problemFlags +} + +// FlagGroupType identifies which failed validation caused a FlagGroupError. +type FlagGroupType string const ( - flagGroupIsExclusive flagGroupType = "if any is set, none of the others can be" - flagGroupIsRequired flagGroupType = "if any is set, they must all be set" - flagGroupIsOneRequired flagGroupType = "at least one of the flags is required" + FlagsAreMutuallyExclusive FlagGroupType = "if any is set, none of the others can be" + FlagsAreRequiredTogether FlagGroupType = "if any is set, they must all be set" + FlagsAreOneRequired FlagGroupType = "at least one of the flags is required" ) // Error implements error. func (e *FlagGroupError) Error() string { switch e.flagGroupType { - case flagGroupIsRequired: + case FlagsAreRequiredTogether: return fmt.Sprintf("if any flags in the group [%v] are set they must all be set; missing %v", e.flagList, e.problemFlags) - case flagGroupIsOneRequired: + case FlagsAreOneRequired: return fmt.Sprintf("at least one of the flags in the group [%v] is required", e.flagList) - case flagGroupIsExclusive: + case FlagsAreMutuallyExclusive: return fmt.Sprintf("if any flags in the group [%v] are set none of the others can be; %v were all set", e.flagList, e.problemFlags) } diff --git a/flag_groups.go b/flag_groups.go index 401a618a..9bf4ee9a 100644 --- a/flag_groups.go +++ b/flag_groups.go @@ -97,12 +97,15 @@ func (c *Command) ValidateFlagGroups() error { }) if err := validateRequiredFlagGroups(groupStatus); err != nil { + err.cmd = c return err } if err := validateOneRequiredFlagGroups(oneRequiredGroupStatus); err != nil { + err.cmd = c return err } if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil { + err.cmd = c return err } return nil @@ -141,7 +144,7 @@ func processFlagForGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag, annota } } -func validateRequiredFlagGroups(data map[string]map[string]bool) error { +func validateRequiredFlagGroups(data map[string]map[string]bool) *FlagGroupError { keys := sortedKeys(data) for _, flagList := range keys { flagnameAndStatus := data[flagList] @@ -160,7 +163,7 @@ func validateRequiredFlagGroups(data map[string]map[string]bool) error { sort.Strings(unset) return &FlagGroupError{ flagList: flagList, - flagGroupType: flagGroupIsRequired, + flagGroupType: FlagsAreRequiredTogether, problemFlags: unset, } } @@ -168,7 +171,7 @@ func validateRequiredFlagGroups(data map[string]map[string]bool) error { return nil } -func validateOneRequiredFlagGroups(data map[string]map[string]bool) error { +func validateOneRequiredFlagGroups(data map[string]map[string]bool) *FlagGroupError { keys := sortedKeys(data) for _, flagList := range keys { flagnameAndStatus := data[flagList] @@ -186,13 +189,13 @@ func validateOneRequiredFlagGroups(data map[string]map[string]bool) error { sort.Strings(set) return &FlagGroupError{ flagList: flagList, - flagGroupType: flagGroupIsOneRequired, + flagGroupType: FlagsAreOneRequired, } } return nil } -func validateExclusiveFlagGroups(data map[string]map[string]bool) error { +func validateExclusiveFlagGroups(data map[string]map[string]bool) *FlagGroupError { keys := sortedKeys(data) for _, flagList := range keys { flagnameAndStatus := data[flagList] @@ -210,7 +213,7 @@ func validateExclusiveFlagGroups(data map[string]map[string]bool) error { sort.Strings(set) return &FlagGroupError{ flagList: flagList, - flagGroupType: flagGroupIsExclusive, + flagGroupType: FlagsAreMutuallyExclusive, problemFlags: set, } }