From f1caa7513d99fc13ffd67c236248ca0c49f82bb8 Mon Sep 17 00:00:00 2001 From: sivchari Date: Tue, 29 Apr 2025 14:53:56 +0900 Subject: [PATCH] Add new type to represent validating required flag error Signed-off-by: sivchari --- command.go | 28 +++++++++++++++++++++++++++- command_test.go | 12 ++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/command.go b/command.go index 4794e5eb..ab46580f 100644 --- a/command.go +++ b/command.go @@ -1176,6 +1176,30 @@ func (c *Command) ValidateArgs(args []string) error { return c.Args(c, args) } +// RequiredFlagError represents a failure to validate required flags. +type RequiredFlagError struct { + Err error +} + +// Error satisfies the error interface. +func (r *RequiredFlagError) Error() string { + return r.Err.Error() +} + +// Is satisfies the Is error interface. +func (r *RequiredFlagError) Is(target error) bool { + err, ok := target.(*RequiredFlagError) + if !ok { + return false + } + return r.Err == err +} + +// Unwrap satisfies Unwrap error interface. +func (r *RequiredFlagError) Unwrap() error { + return r.Err +} + // ValidateRequiredFlags validates all required flags are present and returns an error otherwise func (c *Command) ValidateRequiredFlags() error { if c.DisableFlagParsing { @@ -1195,7 +1219,9 @@ func (c *Command) ValidateRequiredFlags() error { }) if len(missingFlagNames) > 0 { - return fmt.Errorf(`required flag(s) "%s" not set`, strings.Join(missingFlagNames, `", "`)) + return &RequiredFlagError{ + Err: fmt.Errorf(`required flag(s) "%s" not set`, strings.Join(missingFlagNames, `", "`)), + } } return nil } diff --git a/command_test.go b/command_test.go index 156df9eb..97fb6557 100644 --- a/command_test.go +++ b/command_test.go @@ -17,6 +17,7 @@ package cobra import ( "bytes" "context" + "errors" "fmt" "io" "os" @@ -2952,3 +2953,14 @@ func TestHelpFuncExecuted(t *testing.T) { checkStringContains(t, output, helpText) } + +func TestValidateRequiredFlags(t *testing.T) { + c := &Command{Use: "c", Run: emptyRun} + c.Flags().BoolP("boola", "a", false, "a boolean flag") + c.MarkFlagRequired("boola") + if err := c.ValidateRequiredFlags(); !errors.Is(err, &RequiredFlagError{ + Err: errors.New("required flag(s) \"boola\" not set"), + }) { + t.Fatalf("Expected error: %v, got: %v", "boola", err) + } +}