mirror of
https://github.com/spf13/cobra
synced 2025-05-02 11:27:21 +00:00
Merge f1caa7513d
into ceb39aba25
This commit is contained in:
commit
d6e35efe80
2 changed files with 39 additions and 1 deletions
28
command.go
28
command.go
|
@ -1176,6 +1176,30 @@ func (c *Command) ValidateArgs(args []string) error {
|
|||
return c.Args(c, args)
|
||||
}
|
||||
|
||||
// RequiredFlagError represents a failure to validate required flags.
|
||||
type RequiredFlagError struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
// Error satisfies the error interface.
|
||||
func (r *RequiredFlagError) Error() string {
|
||||
return r.Err.Error()
|
||||
}
|
||||
|
||||
// Is satisfies the Is error interface.
|
||||
func (r *RequiredFlagError) Is(target error) bool {
|
||||
err, ok := target.(*RequiredFlagError)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return r.Err == err
|
||||
}
|
||||
|
||||
// Unwrap satisfies Unwrap error interface.
|
||||
func (r *RequiredFlagError) Unwrap() error {
|
||||
return r.Err
|
||||
}
|
||||
|
||||
// ValidateRequiredFlags validates all required flags are present and returns an error otherwise
|
||||
func (c *Command) ValidateRequiredFlags() error {
|
||||
if c.DisableFlagParsing {
|
||||
|
@ -1195,7 +1219,9 @@ func (c *Command) ValidateRequiredFlags() error {
|
|||
})
|
||||
|
||||
if len(missingFlagNames) > 0 {
|
||||
return fmt.Errorf(`required flag(s) "%s" not set`, strings.Join(missingFlagNames, `", "`))
|
||||
return &RequiredFlagError{
|
||||
Err: fmt.Errorf(`required flag(s) "%s" not set`, strings.Join(missingFlagNames, `", "`)),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@ package cobra
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
|
@ -2952,3 +2953,14 @@ func TestHelpFuncExecuted(t *testing.T) {
|
|||
|
||||
checkStringContains(t, output, helpText)
|
||||
}
|
||||
|
||||
func TestValidateRequiredFlags(t *testing.T) {
|
||||
c := &Command{Use: "c", Run: emptyRun}
|
||||
c.Flags().BoolP("boola", "a", false, "a boolean flag")
|
||||
c.MarkFlagRequired("boola")
|
||||
if err := c.ValidateRequiredFlags(); !errors.Is(err, &RequiredFlagError{
|
||||
Err: errors.New("required flag(s) \"boola\" not set"),
|
||||
}) {
|
||||
t.Fatalf("Expected error: %v, got: %v", "boola", err)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue