From 750785d1cca58a8226c2bc578d000355b1dfdcd4 Mon Sep 17 00:00:00 2001 From: "Ethan P." Date: Mon, 14 Apr 2025 21:51:45 -0700 Subject: [PATCH 1/5] feat: Use error structs for errors returned in arg validation --- args.go | 47 ++++++++++++++++++++++++++++++------- errors.go | 69 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 108 insertions(+), 8 deletions(-) create mode 100644 errors.go diff --git a/args.go b/args.go index ed1e70ce..78cf02e7 100644 --- a/args.go +++ b/args.go @@ -15,7 +15,6 @@ package cobra import ( - "fmt" "strings" ) @@ -33,7 +32,11 @@ func legacyArgs(cmd *Command, args []string) error { // root command with subcommands, do subcommand checking. if !cmd.HasParent() && len(args) > 0 { - return fmt.Errorf("unknown command %q for %q%s", args[0], cmd.CommandPath(), cmd.findSuggestions(args[0])) + return &UnknownSubcommandError{ + cmd: cmd, + subcmd: args[0], + suggestions: cmd.findSuggestions(args[0]), + } } return nil } @@ -41,7 +44,11 @@ func legacyArgs(cmd *Command, args []string) error { // NoArgs returns an error if any args are included. func NoArgs(cmd *Command, args []string) error { if len(args) > 0 { - return fmt.Errorf("unknown command %q for %q", args[0], cmd.CommandPath()) + return &UnknownSubcommandError{ + cmd: cmd, + subcmd: args[0], + suggestions: "", + } } return nil } @@ -58,7 +65,11 @@ func OnlyValidArgs(cmd *Command, args []string) error { } for _, v := range args { if !stringInSlice(v, validArgs) { - return fmt.Errorf("invalid argument %q for %q%s", v, cmd.CommandPath(), cmd.findSuggestions(args[0])) + return &InvalidArgValueError{ + cmd: cmd, + arg: v, + suggestions: cmd.findSuggestions(args[0]), + } } } } @@ -74,7 +85,12 @@ func ArbitraryArgs(cmd *Command, args []string) error { func MinimumNArgs(n int) PositionalArgs { return func(cmd *Command, args []string) error { if len(args) < n { - return fmt.Errorf("requires at least %d arg(s), only received %d", n, len(args)) + return &InvalidArgCountError{ + cmd: cmd, + args: args, + atLeast: n, + atMost: -1, + } } return nil } @@ -84,7 +100,12 @@ func MinimumNArgs(n int) PositionalArgs { func MaximumNArgs(n int) PositionalArgs { return func(cmd *Command, args []string) error { if len(args) > n { - return fmt.Errorf("accepts at most %d arg(s), received %d", n, len(args)) + return &InvalidArgCountError{ + cmd: cmd, + args: args, + atLeast: -1, + atMost: n, + } } return nil } @@ -94,7 +115,12 @@ func MaximumNArgs(n int) PositionalArgs { func ExactArgs(n int) PositionalArgs { return func(cmd *Command, args []string) error { if len(args) != n { - return fmt.Errorf("accepts %d arg(s), received %d", n, len(args)) + return &InvalidArgCountError{ + cmd: cmd, + args: args, + atLeast: n, + atMost: n, + } } return nil } @@ -104,7 +130,12 @@ func ExactArgs(n int) PositionalArgs { func RangeArgs(min int, max int) PositionalArgs { return func(cmd *Command, args []string) error { if len(args) < min || len(args) > max { - return fmt.Errorf("accepts between %d and %d arg(s), received %d", min, max, len(args)) + return &InvalidArgCountError{ + cmd: cmd, + args: args, + atLeast: min, + atMost: max, + } } return nil } diff --git a/errors.go b/errors.go new file mode 100644 index 00000000..7d8fb167 --- /dev/null +++ b/errors.go @@ -0,0 +1,69 @@ +// Copyright 2013-2023 The Cobra Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cobra + +import "fmt" + +// InvalidArgCountError is the error returned when the wrong number of arguments +// are supplied to a command. +type InvalidArgCountError struct { + cmd *Command + args []string + atLeast int + atMost int +} + +// Error implements error. +func (e *InvalidArgCountError) Error() string { + if e.atMost == -1 && e.atLeast >= 0 { // MinimumNArgs + return fmt.Sprintf("requires at least %d arg(s), only received %d", e.atLeast, len(e.args)) + } + + if e.atLeast == -1 && e.atMost >= 0 { // MaximumNArgs + return fmt.Sprintf("accepts at most %d arg(s), received %d", e.atMost, len(e.args)) + } + + if e.atLeast == e.atMost && e.atLeast != -1 { // ExactArgs + return fmt.Sprintf("accepts %d arg(s), received %d", e.atLeast, len(e.args)) + } + + // RangeArgs + return fmt.Sprintf("accepts between %d and %d arg(s), received %d", e.atLeast, e.atMost, len(e.args)) +} + +// InvalidArgCountError is the error returned an invalid argument is present. +type InvalidArgValueError struct { + cmd *Command + arg string + suggestions string +} + +// Error implements error. +func (e *InvalidArgValueError) Error() string { + return fmt.Sprintf("invalid argument %q for %q%s", e.arg, e.cmd.CommandPath(), e.suggestions) +} + +// UnknownSubcommandError is the error returned when a subcommand can not be +// found. +type UnknownSubcommandError struct { + cmd *Command + subcmd string + suggestions string +} + +// Error implements error. +func (e *UnknownSubcommandError) Error() string { + return fmt.Sprintf("unknown command %q for %q%s", e.subcmd, e.cmd.CommandPath(), e.suggestions) +} From 6d0ee6b0714ce20782ecc5ce57817862067fd1f8 Mon Sep 17 00:00:00 2001 From: "Ethan P." Date: Mon, 14 Apr 2025 22:24:17 -0700 Subject: [PATCH 2/5] feat: Use struct error for errors returned in flag validation --- command.go | 4 +++- errors.go | 46 +++++++++++++++++++++++++++++++++++++++++++++- flag_groups.go | 17 ++++++++++++++--- 3 files changed, 62 insertions(+), 5 deletions(-) diff --git a/command.go b/command.go index 4794e5eb..3ffe6cc2 100644 --- a/command.go +++ b/command.go @@ -1195,7 +1195,9 @@ func (c *Command) ValidateRequiredFlags() error { }) if len(missingFlagNames) > 0 { - return fmt.Errorf(`required flag(s) "%s" not set`, strings.Join(missingFlagNames, `", "`)) + return &RequiredFlagError{ + missingFlagNames: missingFlagNames, + } } return nil } diff --git a/errors.go b/errors.go index 7d8fb167..8c470199 100644 --- a/errors.go +++ b/errors.go @@ -14,7 +14,10 @@ package cobra -import "fmt" +import ( + "fmt" + "strings" +) // InvalidArgCountError is the error returned when the wrong number of arguments // are supplied to a command. @@ -67,3 +70,44 @@ type UnknownSubcommandError struct { func (e *UnknownSubcommandError) Error() string { return fmt.Sprintf("unknown command %q for %q%s", e.subcmd, e.cmd.CommandPath(), e.suggestions) } + +// RequiredFlagError is the error returned when a required flag is not set. +type RequiredFlagError struct { + missingFlagNames []string +} + +// Error implements error. +func (e *RequiredFlagError) Error() string { + return fmt.Sprintf(`required flag(s) "%s" not set`, strings.Join(e.missingFlagNames, `", "`)) +} + +// FlagGroupError is the error returned when mutually-required or +// mutually-exclusive flags are not properly specified. +type FlagGroupError struct { + flagList string + flagGroupType flagGroupType + problemFlags []string +} + +// flagGroupType identifies which failed validation caused a FlagGroupError. +type flagGroupType string + +const ( + flagGroupIsExclusive flagGroupType = "if any is set, none of the others can be" + flagGroupIsRequired flagGroupType = "if any is set, they must all be set" + flagGroupIsOneRequired flagGroupType = "at least one of the flags is required" +) + +// Error implements error. +func (e *FlagGroupError) Error() string { + switch e.flagGroupType { + case flagGroupIsRequired: + return fmt.Sprintf("if any flags in the group [%v] are set they must all be set; missing %v", e.flagList, e.problemFlags) + case flagGroupIsOneRequired: + return fmt.Sprintf("at least one of the flags in the group [%v] is required", e.flagList) + case flagGroupIsExclusive: + return fmt.Sprintf("if any flags in the group [%v] are set none of the others can be; %v were all set", e.flagList, e.problemFlags) + } + + panic("invalid flagGroupType") +} diff --git a/flag_groups.go b/flag_groups.go index 560612fd..401a618a 100644 --- a/flag_groups.go +++ b/flag_groups.go @@ -158,7 +158,11 @@ func validateRequiredFlagGroups(data map[string]map[string]bool) error { // Sort values, so they can be tested/scripted against consistently. sort.Strings(unset) - return fmt.Errorf("if any flags in the group [%v] are set they must all be set; missing %v", flagList, unset) + return &FlagGroupError{ + flagList: flagList, + flagGroupType: flagGroupIsRequired, + problemFlags: unset, + } } return nil @@ -180,7 +184,10 @@ func validateOneRequiredFlagGroups(data map[string]map[string]bool) error { // Sort values, so they can be tested/scripted against consistently. sort.Strings(set) - return fmt.Errorf("at least one of the flags in the group [%v] is required", flagList) + return &FlagGroupError{ + flagList: flagList, + flagGroupType: flagGroupIsOneRequired, + } } return nil } @@ -201,7 +208,11 @@ func validateExclusiveFlagGroups(data map[string]map[string]bool) error { // Sort values, so they can be tested/scripted against consistently. sort.Strings(set) - return fmt.Errorf("if any flags in the group [%v] are set none of the others can be; %v were all set", flagList, set) + return &FlagGroupError{ + flagList: flagList, + flagGroupType: flagGroupIsExclusive, + problemFlags: set, + } } return nil } From 0fd2fcdab53dbcffc4579c9df50c8b62d8aa9449 Mon Sep 17 00:00:00 2001 From: "Ethan P." Date: Mon, 14 Apr 2025 22:44:55 -0700 Subject: [PATCH 3/5] feat: Add getters to error structs --- command.go | 1 + errors.go | 112 +++++++++++++++++++++++++++++++++++++++++++++---- flag_groups.go | 15 ++++--- 3 files changed, 113 insertions(+), 15 deletions(-) diff --git a/command.go b/command.go index 3ffe6cc2..aae5094d 100644 --- a/command.go +++ b/command.go @@ -1196,6 +1196,7 @@ func (c *Command) ValidateRequiredFlags() error { if len(missingFlagNames) > 0 { return &RequiredFlagError{ + cmd: c, missingFlagNames: missingFlagNames, } } diff --git a/errors.go b/errors.go index 8c470199..d5a33358 100644 --- a/errors.go +++ b/errors.go @@ -46,6 +46,28 @@ func (e *InvalidArgCountError) Error() string { return fmt.Sprintf("accepts between %d and %d arg(s), received %d", e.atLeast, e.atMost, len(e.args)) } +// GetCommand returns the Command that the error occurred in. +func (e *InvalidArgCountError) GetCommand() *Command { + return e.cmd +} + +// GetArguments returns the invalid arguments. +func (e *InvalidArgCountError) GetArguments() []string { + return e.args +} + +// GetMinArgumentCount returns the minimum (inclusive) number of arguments +// that the command requires. If there is no minimum, a value of -1 is returned. +func (e *InvalidArgCountError) GetMinArgumentCount() int { + return e.atLeast +} + +// GetMaxArgumentCount returns the maximum (inclusive) number of arguments +// that the command requires. If there is no maximum, a value of -1 is returned. +func (e *InvalidArgCountError) GetMaxArgumentCount() int { + return e.atMost +} + // InvalidArgCountError is the error returned an invalid argument is present. type InvalidArgValueError struct { cmd *Command @@ -58,6 +80,21 @@ func (e *InvalidArgValueError) Error() string { return fmt.Sprintf("invalid argument %q for %q%s", e.arg, e.cmd.CommandPath(), e.suggestions) } +// GetCommand returns the Command that the error occurred in. +func (e *InvalidArgValueError) GetCommand() *Command { + return e.cmd +} + +// GetArgument returns the invalid argument. +func (e *InvalidArgValueError) GetArgument() string { + return e.arg +} + +// GetSuggestions returns suggestions, if there are any. +func (e *InvalidArgValueError) GetSuggestions() string { + return e.suggestions +} + // UnknownSubcommandError is the error returned when a subcommand can not be // found. type UnknownSubcommandError struct { @@ -66,6 +103,21 @@ type UnknownSubcommandError struct { suggestions string } +// GetCommand returns the Command that the error occurred in. +func (e *UnknownSubcommandError) GetCommand() *Command { + return e.cmd +} + +// GetSubcommand returns the unknown subcommand name. +func (e *UnknownSubcommandError) GetSubcommand() string { + return e.subcmd +} + +// GetSuggestions returns suggestions, if there are any. +func (e *UnknownSubcommandError) GetSuggestions() string { + return e.suggestions +} + // Error implements error. func (e *UnknownSubcommandError) Error() string { return fmt.Sprintf("unknown command %q for %q%s", e.subcmd, e.cmd.CommandPath(), e.suggestions) @@ -73,6 +125,7 @@ func (e *UnknownSubcommandError) Error() string { // RequiredFlagError is the error returned when a required flag is not set. type RequiredFlagError struct { + cmd *Command missingFlagNames []string } @@ -81,31 +134,72 @@ func (e *RequiredFlagError) Error() string { return fmt.Sprintf(`required flag(s) "%s" not set`, strings.Join(e.missingFlagNames, `", "`)) } +// GetCommand returns the Command that the error occurred in. +func (e *RequiredFlagError) GetCommand() *Command { + return e.cmd +} + +// GetFlags returns the names of the missing flags. +func (e *RequiredFlagError) GetFlags() []string { + return e.missingFlagNames +} + // FlagGroupError is the error returned when mutually-required or // mutually-exclusive flags are not properly specified. type FlagGroupError struct { + cmd *Command flagList string - flagGroupType flagGroupType + flagGroupType FlagGroupType problemFlags []string } -// flagGroupType identifies which failed validation caused a FlagGroupError. -type flagGroupType string +// GetCommand returns the Command that the error occurred in. +func (e *FlagGroupError) GetCommand() *Command { + return e.cmd +} + +// GetFlagList returns the flags in the group. +func (e *FlagGroupError) GetFlags() []string { + return strings.Split(e.flagList, " ") +} + +// GetFlagGroupType returns the type of flag group causing the error. +// +// Valid types are: +// - FlagsAreMutuallyExclusive for mutually-exclusive flags. +// - FlagsAreRequiredTogether for mutually-required flags. +// - FlagsAreOneRequired for flags where at least one must be present. +func (e *FlagGroupError) GetFlagGroupType() FlagGroupType { + return e.flagGroupType +} + +// GetProblemFlags returns the flags causing the error. +// +// For flag groups where: +// - FlagsAreMutuallyExclusive, these are all the flags set. +// - FlagsAreRequiredTogether, these are the missing flags. +// - FlagsAreOneRequired, this is empty. +func (e *FlagGroupError) GetProblemFlags() []string { + return e.problemFlags +} + +// FlagGroupType identifies which failed validation caused a FlagGroupError. +type FlagGroupType string const ( - flagGroupIsExclusive flagGroupType = "if any is set, none of the others can be" - flagGroupIsRequired flagGroupType = "if any is set, they must all be set" - flagGroupIsOneRequired flagGroupType = "at least one of the flags is required" + FlagsAreMutuallyExclusive FlagGroupType = "if any is set, none of the others can be" + FlagsAreRequiredTogether FlagGroupType = "if any is set, they must all be set" + FlagsAreOneRequired FlagGroupType = "at least one of the flags is required" ) // Error implements error. func (e *FlagGroupError) Error() string { switch e.flagGroupType { - case flagGroupIsRequired: + case FlagsAreRequiredTogether: return fmt.Sprintf("if any flags in the group [%v] are set they must all be set; missing %v", e.flagList, e.problemFlags) - case flagGroupIsOneRequired: + case FlagsAreOneRequired: return fmt.Sprintf("at least one of the flags in the group [%v] is required", e.flagList) - case flagGroupIsExclusive: + case FlagsAreMutuallyExclusive: return fmt.Sprintf("if any flags in the group [%v] are set none of the others can be; %v were all set", e.flagList, e.problemFlags) } diff --git a/flag_groups.go b/flag_groups.go index 401a618a..9bf4ee9a 100644 --- a/flag_groups.go +++ b/flag_groups.go @@ -97,12 +97,15 @@ func (c *Command) ValidateFlagGroups() error { }) if err := validateRequiredFlagGroups(groupStatus); err != nil { + err.cmd = c return err } if err := validateOneRequiredFlagGroups(oneRequiredGroupStatus); err != nil { + err.cmd = c return err } if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil { + err.cmd = c return err } return nil @@ -141,7 +144,7 @@ func processFlagForGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag, annota } } -func validateRequiredFlagGroups(data map[string]map[string]bool) error { +func validateRequiredFlagGroups(data map[string]map[string]bool) *FlagGroupError { keys := sortedKeys(data) for _, flagList := range keys { flagnameAndStatus := data[flagList] @@ -160,7 +163,7 @@ func validateRequiredFlagGroups(data map[string]map[string]bool) error { sort.Strings(unset) return &FlagGroupError{ flagList: flagList, - flagGroupType: flagGroupIsRequired, + flagGroupType: FlagsAreRequiredTogether, problemFlags: unset, } } @@ -168,7 +171,7 @@ func validateRequiredFlagGroups(data map[string]map[string]bool) error { return nil } -func validateOneRequiredFlagGroups(data map[string]map[string]bool) error { +func validateOneRequiredFlagGroups(data map[string]map[string]bool) *FlagGroupError { keys := sortedKeys(data) for _, flagList := range keys { flagnameAndStatus := data[flagList] @@ -186,13 +189,13 @@ func validateOneRequiredFlagGroups(data map[string]map[string]bool) error { sort.Strings(set) return &FlagGroupError{ flagList: flagList, - flagGroupType: flagGroupIsOneRequired, + flagGroupType: FlagsAreOneRequired, } } return nil } -func validateExclusiveFlagGroups(data map[string]map[string]bool) error { +func validateExclusiveFlagGroups(data map[string]map[string]bool) *FlagGroupError { keys := sortedKeys(data) for _, flagList := range keys { flagnameAndStatus := data[flagList] @@ -210,7 +213,7 @@ func validateExclusiveFlagGroups(data map[string]map[string]bool) error { sort.Strings(set) return &FlagGroupError{ flagList: flagList, - flagGroupType: flagGroupIsExclusive, + flagGroupType: FlagsAreMutuallyExclusive, problemFlags: set, } } From e37797b5edcf55842a2b1426c171678453c411cb Mon Sep 17 00:00:00 2001 From: "Ethan P." Date: Mon, 14 Apr 2025 23:39:06 -0700 Subject: [PATCH 4/5] test: Add tests for error structs --- args_test.go | 77 +++++++++++++++++ command_test.go | 16 ++++ errors_test.go | 195 ++++++++++++++++++++++++++++++++++++++++++++ flag_groups_test.go | 44 +++++++++- 4 files changed, 331 insertions(+), 1 deletion(-) create mode 100644 errors_test.go diff --git a/args_test.go b/args_test.go index 90d174cc..b1eb798a 100644 --- a/args_test.go +++ b/args_test.go @@ -15,7 +15,9 @@ package cobra import ( + "errors" "fmt" + "reflect" "strings" "testing" ) @@ -32,6 +34,14 @@ func getCommand(args PositionalArgs, withValid bool) *Command { return c } +func getCommandName(c *Command) string { + if c == nil { + return "" + } else { + return c.Name() + } +} + func expectSuccess(output string, err error, t *testing.T) { if output != "" { t.Errorf("Unexpected output: %v", output) @@ -41,6 +51,31 @@ func expectSuccess(output string, err error, t *testing.T) { } } +func expectErrorAs(err error, target error, t *testing.T) { + if err == nil { + t.Fatalf("Expected error, got nil") + } + + targetType := reflect.TypeOf(target) + targetPtr := reflect.New(targetType).Interface() // *SomeError + if !errors.As(err, targetPtr) { + t.Fatalf("Expected error to be %T, got %T", target, err) + } +} + +func expectErrorHasCommand(err error, cmd *Command, t *testing.T) { + getCommand, ok := err.(interface{ GetCommand() *Command }) + if !ok { + t.Fatalf("Expected error to have GetCommand method, but did not") + } + + got := getCommand.GetCommand() + if cmd != got { + t.Errorf("Expected err.GetCommand to return %v, got %v", + getCommandName(cmd), getCommandName(got)) + } +} + func validOnlyWithInvalidArgs(err error, t *testing.T) { if err == nil { t.Fatal("Expected an error") @@ -139,6 +174,13 @@ func TestNoArgs_WithValidOnly_WithInvalidArgs(t *testing.T) { validOnlyWithInvalidArgs(err, t) } +func TestNoArgs_ReturnsUnknownSubcommandError(t *testing.T) { + c := getCommand(NoArgs, false) + _, err := executeCommand(c, "a") + expectErrorAs(err, &UnknownSubcommandError{}, t) + expectErrorHasCommand(err, c, t) +} + // OnlyValidArgs func TestOnlyValidArgs(t *testing.T) { @@ -153,6 +195,13 @@ func TestOnlyValidArgs_WithInvalidArgs(t *testing.T) { validOnlyWithInvalidArgs(err, t) } +func TestOnlyValidArgs_ReturnsInvalidArgValueError(t *testing.T) { + c := getCommand(OnlyValidArgs, true) + _, err := executeCommand(c, "a") + expectErrorAs(err, &InvalidArgValueError{}, t) + expectErrorHasCommand(err, c, t) +} + // ArbitraryArgs func TestArbitraryArgs(t *testing.T) { @@ -229,6 +278,13 @@ func TestMinimumNArgs_WithLessArgs_WithValidOnly_WithInvalidArgs(t *testing.T) { validOnlyWithInvalidArgs(err, t) } +func TestMinimumNArgs_ReturnsInvalidArgCountError(t *testing.T) { + c := getCommand(MinimumNArgs(2), true) + _, err := executeCommand(c, "a") + expectErrorAs(err, &InvalidArgCountError{}, t) + expectErrorHasCommand(err, c, t) +} + // MaximumNArgs func TestMaximumNArgs(t *testing.T) { @@ -279,6 +335,13 @@ func TestMaximumNArgs_WithMoreArgs_WithValidOnly_WithInvalidArgs(t *testing.T) { validOnlyWithInvalidArgs(err, t) } +func TestMaximumNArgs_ReturnsInvalidArgCountError(t *testing.T) { + c := getCommand(MaximumNArgs(2), true) + _, err := executeCommand(c, "a", "b", "c") + expectErrorAs(err, &InvalidArgCountError{}, t) + expectErrorHasCommand(err, c, t) +} + // ExactArgs func TestExactArgs(t *testing.T) { @@ -329,6 +392,13 @@ func TestExactArgs_WithInvalidCount_WithValidOnly_WithInvalidArgs(t *testing.T) validOnlyWithInvalidArgs(err, t) } +func TestExactArgs_ReturnsInvalidArgCountError(t *testing.T) { + c := getCommand(ExactArgs(2), true) + _, err := executeCommand(c, "a") + expectErrorAs(err, &InvalidArgCountError{}, t) + expectErrorHasCommand(err, c, t) +} + // RangeArgs func TestRangeArgs(t *testing.T) { @@ -379,6 +449,13 @@ func TestRangeArgs_WithInvalidCount_WithValidOnly_WithInvalidArgs(t *testing.T) validOnlyWithInvalidArgs(err, t) } +func TestRangeArgs_ReturnsInvalidArgCountError(t *testing.T) { + c := getCommand(RangeArgs(2, 4), true) + _, err := executeCommand(c, "a") + expectErrorAs(err, &InvalidArgCountError{}, t) + expectErrorHasCommand(err, c, t) +} + // Takes(No)Args func TestRootTakesNoArgs(t *testing.T) { diff --git a/command_test.go b/command_test.go index 156df9eb..e7bd2f26 100644 --- a/command_test.go +++ b/command_test.go @@ -17,6 +17,7 @@ package cobra import ( "bytes" "context" + "errors" "fmt" "io" "os" @@ -866,6 +867,21 @@ func TestRequiredFlags(t *testing.T) { if got != expected { t.Errorf("Expected error: %q, got: %q", expected, got) } + + // Test it returns valid RequiredFlagError. + var requiredFlagErr *RequiredFlagError + if !errors.As(err, &requiredFlagErr) { + t.Fatalf("Expected error to be RequiredFlagError, got %T", err) + } + + expectedMissingFlagNames := "foo1 foo2" + gotMissingFlagNames := strings.Join(requiredFlagErr.missingFlagNames, " ") + if expectedMissingFlagNames != gotMissingFlagNames { + t.Errorf("Expected error missingFlagNames to be %q, got %q", + expectedMissingFlagNames, gotMissingFlagNames) + } + + expectErrorHasCommand(err, c, t) } func TestPersistentRequiredFlags(t *testing.T) { diff --git a/errors_test.go b/errors_test.go new file mode 100644 index 00000000..eb92b015 --- /dev/null +++ b/errors_test.go @@ -0,0 +1,195 @@ +// Copyright 2013-2023 The Cobra Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cobra + +import ( + "strings" + "testing" +) + +// InvalidArgCountError + +func TestInvalidArgCountError_GetCommand(t *testing.T) { + expected := &Command{} + err := &InvalidArgCountError{cmd: expected} + + got := err.GetCommand() + if got != expected { + t.Errorf("expected %v, got %v", + getCommandName(expected), getCommandName(got)) + } +} + +func TestInvalidArgCountError_GetArgs(t *testing.T) { + expected := []string{"a", "b", "c"} + err := &InvalidArgCountError{args: expected} + + got := err.GetArguments() + if strings.Join(expected, " ") != strings.Join(got, " ") { + t.Fatalf("expected %v, got %v", expected, got) + } +} + +func TestInvalidArgCountError_GetMinArgumentCount(t *testing.T) { + expected := 1 + err := &InvalidArgCountError{atLeast: expected} + + got := err.GetMinArgumentCount() + if got != expected { + t.Fatalf("expected %v, got %v", expected, got) + } +} + +func TestInvalidArgCountError_GetMaxArgumentCount(t *testing.T) { + expected := 1 + err := &InvalidArgCountError{atMost: expected} + + got := err.GetMaxArgumentCount() + if got != expected { + t.Fatalf("expected %v, got %v", expected, got) + } +} + +// InvalidArgValueError + +func TestInvalidArgValueError_GetCommand(t *testing.T) { + expected := &Command{} + err := &InvalidArgValueError{cmd: expected} + + got := err.GetCommand() + if got != expected { + t.Errorf("expected %v, got %v", + getCommandName(expected), getCommandName(got)) + } +} + +func TestInvalidArgValueError_GetArgument(t *testing.T) { + expected := "a" + err := &InvalidArgValueError{arg: expected} + + got := err.GetArgument() + if got != expected { + t.Fatalf("expected %v, got %v", expected, got) + } +} + +func TestInvalidArgValueError_GetSuggestions(t *testing.T) { + expected := "a" + err := &InvalidArgValueError{suggestions: expected} + + got := err.GetSuggestions() + if got != expected { + t.Fatalf("expected %v, got %v", expected, got) + } +} + +// UnknownSubcommandError + +func TestUnknownSubcommandError_GetCommand(t *testing.T) { + expected := &Command{} + err := &UnknownSubcommandError{cmd: expected} + + got := err.GetCommand() + if got != expected { + t.Errorf("expected %v, got %v", + getCommandName(expected), getCommandName(got)) + } +} + +func TestUnknownSubcommandError_GetSubcommand(t *testing.T) { + expected := "a" + err := &UnknownSubcommandError{subcmd: expected} + + got := err.GetSubcommand() + if got != expected { + t.Fatalf("expected %v, got %v", expected, got) + } +} + +func TestUnknownSubcommandError_GetSuggestions(t *testing.T) { + expected := "a" + err := &UnknownSubcommandError{suggestions: expected} + + got := err.GetSuggestions() + if got != expected { + t.Fatalf("expected %v, got %v", expected, got) + } +} + +// RequiredFlagError + +func TestRequiredFlagError_GetCommand(t *testing.T) { + expected := &Command{} + err := &UnknownSubcommandError{cmd: expected} + + got := err.GetCommand() + if got != expected { + t.Errorf("expected %v, got %v", + getCommandName(expected), getCommandName(got)) + } +} + +func TestRequiredFlagError_GetFlags(t *testing.T) { + expected := []string{"a", "b", "c"} + err := &RequiredFlagError{missingFlagNames: expected} + + got := err.GetFlags() + if strings.Join(expected, " ") != strings.Join(got, " ") { + t.Fatalf("expected %v, got %v", expected, got) + } +} + +// FlagGroupError + +func TestFlagGroupError_GetCommand(t *testing.T) { + expected := &Command{} + err := &FlagGroupError{cmd: expected} + + got := err.GetCommand() + if got != expected { + t.Errorf("expected %v, got %v", + getCommandName(expected), getCommandName(got)) + } +} + +func TestFlagGroupError_GetFlags(t *testing.T) { + expected := []string{"a", "b", "c"} + err := &FlagGroupError{flagList: "a b c"} + + got := err.GetFlags() + if strings.Join(expected, " ") != strings.Join(got, " ") { + t.Fatalf("expected %v, got %v", expected, got) + } +} + +func TestFlagGroupError_GetProblemFlags(t *testing.T) { + expected := []string{"a", "b", "c"} + err := &FlagGroupError{problemFlags: expected} + + got := err.GetProblemFlags() + if strings.Join(expected, " ") != strings.Join(got, " ") { + t.Fatalf("expected %v, got %v", expected, got) + } +} + +func TestFlagGroupError_GetFlagGroupType(t *testing.T) { + expected := FlagsAreMutuallyExclusive + err := &FlagGroupError{flagGroupType: expected} + + got := err.GetFlagGroupType() + if got != expected { + t.Fatalf("expected %v, got %v", expected, got) + } +} diff --git a/flag_groups_test.go b/flag_groups_test.go index cffa8552..acc62e3a 100644 --- a/flag_groups_test.go +++ b/flag_groups_test.go @@ -15,6 +15,7 @@ package cobra import ( + "errors" "strings" "testing" ) @@ -52,6 +53,7 @@ func TestValidateFlagGroups(t *testing.T) { subCmdFlagGroupsExclusive []string args []string expectErr string + expectErrGroupType FlagGroupType }{ { desc: "No flags no problem", @@ -64,64 +66,76 @@ func TestValidateFlagGroups(t *testing.T) { flagGroupsRequired: []string{"a b c"}, args: []string{"--a=foo"}, expectErr: "if any flags in the group [a b c] are set they must all be set; missing [b c]", + expectErrGroupType: FlagsAreRequiredTogether, }, { desc: "One-required flag group not satisfied", flagGroupsOneRequired: []string{"a b"}, args: []string{"--c=foo"}, expectErr: "at least one of the flags in the group [a b] is required", + expectErrGroupType: FlagsAreOneRequired, }, { desc: "Exclusive flag group not satisfied", flagGroupsExclusive: []string{"a b c"}, args: []string{"--a=foo", "--b=foo"}, expectErr: "if any flags in the group [a b c] are set none of the others can be; [a b] were all set", + expectErrGroupType: FlagsAreMutuallyExclusive, }, { desc: "Multiple required flag group not satisfied returns first error", flagGroupsRequired: []string{"a b c", "a d"}, args: []string{"--c=foo", "--d=foo"}, expectErr: `if any flags in the group [a b c] are set they must all be set; missing [a b]`, + expectErrGroupType: FlagsAreRequiredTogether, }, { desc: "Multiple one-required flag group not satisfied returns first error", flagGroupsOneRequired: []string{"a b", "d e"}, args: []string{"--c=foo", "--f=foo"}, expectErr: `at least one of the flags in the group [a b] is required`, + expectErrGroupType: FlagsAreOneRequired, }, { desc: "Multiple exclusive flag group not satisfied returns first error", flagGroupsExclusive: []string{"a b c", "a d"}, args: []string{"--a=foo", "--c=foo", "--d=foo"}, expectErr: `if any flags in the group [a b c] are set none of the others can be; [a c] were all set`, + expectErrGroupType: FlagsAreMutuallyExclusive, }, { desc: "Validation of required groups occurs on groups in sorted order", flagGroupsRequired: []string{"a d", "a b", "a c"}, args: []string{"--a=foo"}, expectErr: `if any flags in the group [a b] are set they must all be set; missing [b]`, + expectErrGroupType: FlagsAreRequiredTogether, }, { desc: "Validation of one-required groups occurs on groups in sorted order", flagGroupsOneRequired: []string{"d e", "a b", "f g"}, args: []string{"--c=foo"}, expectErr: `at least one of the flags in the group [a b] is required`, + expectErrGroupType: FlagsAreOneRequired, }, { desc: "Validation of exclusive groups occurs on groups in sorted order", flagGroupsExclusive: []string{"a d", "a b", "a c"}, args: []string{"--a=foo", "--b=foo", "--c=foo"}, expectErr: `if any flags in the group [a b] are set none of the others can be; [a b] were all set`, + expectErrGroupType: FlagsAreMutuallyExclusive, }, { desc: "Persistent flags utilize required and exclusive groups and can fail required groups", flagGroupsRequired: []string{"a e", "e f"}, flagGroupsExclusive: []string{"f g"}, args: []string{"--a=foo", "--f=foo", "--g=foo"}, expectErr: `if any flags in the group [a e] are set they must all be set; missing [e]`, + expectErrGroupType: FlagsAreRequiredTogether, }, { desc: "Persistent flags utilize one-required and exclusive groups and can fail one-required groups", flagGroupsOneRequired: []string{"a b", "e f"}, flagGroupsExclusive: []string{"e f"}, args: []string{"--e=foo"}, expectErr: `at least one of the flags in the group [a b] is required`, + expectErrGroupType: FlagsAreOneRequired, }, { desc: "Persistent flags utilize required and exclusive groups and can fail mutually exclusive groups", flagGroupsRequired: []string{"a e", "e f"}, flagGroupsExclusive: []string{"f g"}, args: []string{"--a=foo", "--e=foo", "--f=foo", "--g=foo"}, expectErr: `if any flags in the group [f g] are set none of the others can be; [f g] were all set`, + expectErrGroupType: FlagsAreMutuallyExclusive, }, { desc: "Persistent flags utilize required and exclusive groups and can pass", flagGroupsRequired: []string{"a e", "e f"}, @@ -145,11 +159,13 @@ func TestValidateFlagGroups(t *testing.T) { subCmdFlagGroupsOneRequired: []string{"e subonly"}, args: []string{"subcmd"}, expectErr: "at least one of the flags in the group [e subonly] is required", + expectErrGroupType: FlagsAreOneRequired, }, { desc: "Subcmds can use exclusive groups using inherited flags", subCmdFlagGroupsExclusive: []string{"e subonly"}, args: []string{"subcmd", "--e=foo", "--subonly=foo"}, expectErr: "if any flags in the group [e subonly] are set none of the others can be; [e subonly] were all set", + expectErrGroupType: FlagsAreMutuallyExclusive, }, { desc: "Subcmds can use exclusive groups using inherited flags and pass", subCmdFlagGroupsExclusive: []string{"e subonly"}, @@ -183,13 +199,39 @@ func TestValidateFlagGroups(t *testing.T) { sub.MarkFlagsMutuallyExclusive(strings.Split(flagGroup, " ")...) } c.SetArgs(tc.args) - err := c.Execute() + executedCmd, err := c.ExecuteC() switch { case err == nil && len(tc.expectErr) > 0: t.Errorf("Expected error %q but got nil", tc.expectErr) case err != nil && err.Error() != tc.expectErr: t.Errorf("Expected error %q but got %q", tc.expectErr, err) } + + if len(tc.expectErr) > 0 { + var flagGroupErr *FlagGroupError + if !errors.As(err, &flagGroupErr) { + t.Fatalf("Expected error to be FlagGroupError, got %T", err) + } + + gotGroupType := flagGroupErr.flagGroupType + if gotGroupType != tc.expectErrGroupType { + t.Errorf("Expected FlagGroupError flag group type to be %q, got %q", + tc.expectErrGroupType, gotGroupType) + } + + if flagGroupErr.cmd != executedCmd { + t.Errorf("Expected FlagGroupError to have command %v, got %v", + getCommandName(executedCmd), getCommandName(flagGroupErr.cmd)) + } + + if flagGroupErr.flagList == "" { + t.Errorf("Expected FlagGroupError to have flagList, but was empty") + } + + if gotGroupType != FlagsAreOneRequired && flagGroupErr.problemFlags == nil { + t.Errorf("Expected FlagGroupError to have problemFlags, but was nil") + } + } }) } } From a980b4ae5d3b87160b5807c7ddbada52d13385d1 Mon Sep 17 00:00:00 2001 From: "Ethan P." Date: Tue, 15 Apr 2025 00:11:09 -0700 Subject: [PATCH 5/5] doc: Add user guide section on error structs --- site/content/user_guide.md | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/site/content/user_guide.md b/site/content/user_guide.md index 3f618825..109235c5 100644 --- a/site/content/user_guide.md +++ b/site/content/user_guide.md @@ -817,3 +817,41 @@ Flags: Use "kubectl myplugin [command] --help" for more information about a command. ``` + +## Error Structs + +Cobra uses structs for errors related to command-line argument validation. +If you need fine-grained details on why a command failed to execute, you can +use [errors.As](https://pkg.go.dev/errors#As) to unwrap the `error` as an +error struct. + +```go +package main + +import ( + "errors" + "fmt" + + "github.com/spf13/cobra" +) + +func main() { + var rootCmd = &cobra.Command{ + Use: "echo [value]", + Short: "Echo the first argument back", + Args: cobra.MinimumNArgs(1), + Run: func(cmd *cobra.Command, args []string) { + fmt.Println(args[0]) + }, + } + + err := rootCmd.Execute() + + var invalidArgCountErr *cobra.InvalidArgCountError + if errors.As(err, &invalidArgCountErr) { + fmt.Printf("At least %d arg(s) were needed for %q\n", + invalidArgCountErr.GetMinArgumentCount(), + invalidArgCountErr.GetCommand().Name()) + } +} +```