From 750785d1cca58a8226c2bc578d000355b1dfdcd4 Mon Sep 17 00:00:00 2001 From: "Ethan P." Date: Mon, 14 Apr 2025 21:51:45 -0700 Subject: [PATCH] feat: Use error structs for errors returned in arg validation --- args.go | 47 ++++++++++++++++++++++++++++++------- errors.go | 69 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 108 insertions(+), 8 deletions(-) create mode 100644 errors.go diff --git a/args.go b/args.go index ed1e70ce..78cf02e7 100644 --- a/args.go +++ b/args.go @@ -15,7 +15,6 @@ package cobra import ( - "fmt" "strings" ) @@ -33,7 +32,11 @@ func legacyArgs(cmd *Command, args []string) error { // root command with subcommands, do subcommand checking. if !cmd.HasParent() && len(args) > 0 { - return fmt.Errorf("unknown command %q for %q%s", args[0], cmd.CommandPath(), cmd.findSuggestions(args[0])) + return &UnknownSubcommandError{ + cmd: cmd, + subcmd: args[0], + suggestions: cmd.findSuggestions(args[0]), + } } return nil } @@ -41,7 +44,11 @@ func legacyArgs(cmd *Command, args []string) error { // NoArgs returns an error if any args are included. func NoArgs(cmd *Command, args []string) error { if len(args) > 0 { - return fmt.Errorf("unknown command %q for %q", args[0], cmd.CommandPath()) + return &UnknownSubcommandError{ + cmd: cmd, + subcmd: args[0], + suggestions: "", + } } return nil } @@ -58,7 +65,11 @@ func OnlyValidArgs(cmd *Command, args []string) error { } for _, v := range args { if !stringInSlice(v, validArgs) { - return fmt.Errorf("invalid argument %q for %q%s", v, cmd.CommandPath(), cmd.findSuggestions(args[0])) + return &InvalidArgValueError{ + cmd: cmd, + arg: v, + suggestions: cmd.findSuggestions(args[0]), + } } } } @@ -74,7 +85,12 @@ func ArbitraryArgs(cmd *Command, args []string) error { func MinimumNArgs(n int) PositionalArgs { return func(cmd *Command, args []string) error { if len(args) < n { - return fmt.Errorf("requires at least %d arg(s), only received %d", n, len(args)) + return &InvalidArgCountError{ + cmd: cmd, + args: args, + atLeast: n, + atMost: -1, + } } return nil } @@ -84,7 +100,12 @@ func MinimumNArgs(n int) PositionalArgs { func MaximumNArgs(n int) PositionalArgs { return func(cmd *Command, args []string) error { if len(args) > n { - return fmt.Errorf("accepts at most %d arg(s), received %d", n, len(args)) + return &InvalidArgCountError{ + cmd: cmd, + args: args, + atLeast: -1, + atMost: n, + } } return nil } @@ -94,7 +115,12 @@ func MaximumNArgs(n int) PositionalArgs { func ExactArgs(n int) PositionalArgs { return func(cmd *Command, args []string) error { if len(args) != n { - return fmt.Errorf("accepts %d arg(s), received %d", n, len(args)) + return &InvalidArgCountError{ + cmd: cmd, + args: args, + atLeast: n, + atMost: n, + } } return nil } @@ -104,7 +130,12 @@ func ExactArgs(n int) PositionalArgs { func RangeArgs(min int, max int) PositionalArgs { return func(cmd *Command, args []string) error { if len(args) < min || len(args) > max { - return fmt.Errorf("accepts between %d and %d arg(s), received %d", min, max, len(args)) + return &InvalidArgCountError{ + cmd: cmd, + args: args, + atLeast: min, + atMost: max, + } } return nil } diff --git a/errors.go b/errors.go new file mode 100644 index 00000000..7d8fb167 --- /dev/null +++ b/errors.go @@ -0,0 +1,69 @@ +// Copyright 2013-2023 The Cobra Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cobra + +import "fmt" + +// InvalidArgCountError is the error returned when the wrong number of arguments +// are supplied to a command. +type InvalidArgCountError struct { + cmd *Command + args []string + atLeast int + atMost int +} + +// Error implements error. +func (e *InvalidArgCountError) Error() string { + if e.atMost == -1 && e.atLeast >= 0 { // MinimumNArgs + return fmt.Sprintf("requires at least %d arg(s), only received %d", e.atLeast, len(e.args)) + } + + if e.atLeast == -1 && e.atMost >= 0 { // MaximumNArgs + return fmt.Sprintf("accepts at most %d arg(s), received %d", e.atMost, len(e.args)) + } + + if e.atLeast == e.atMost && e.atLeast != -1 { // ExactArgs + return fmt.Sprintf("accepts %d arg(s), received %d", e.atLeast, len(e.args)) + } + + // RangeArgs + return fmt.Sprintf("accepts between %d and %d arg(s), received %d", e.atLeast, e.atMost, len(e.args)) +} + +// InvalidArgCountError is the error returned an invalid argument is present. +type InvalidArgValueError struct { + cmd *Command + arg string + suggestions string +} + +// Error implements error. +func (e *InvalidArgValueError) Error() string { + return fmt.Sprintf("invalid argument %q for %q%s", e.arg, e.cmd.CommandPath(), e.suggestions) +} + +// UnknownSubcommandError is the error returned when a subcommand can not be +// found. +type UnknownSubcommandError struct { + cmd *Command + subcmd string + suggestions string +} + +// Error implements error. +func (e *UnknownSubcommandError) Error() string { + return fmt.Sprintf("unknown command %q for %q%s", e.subcmd, e.cmd.CommandPath(), e.suggestions) +}