This commit is contained in:
Niels Claeys 2022-09-21 11:53:06 +02:00 committed by GitHub
commit 7de41f90ff
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 54 additions and 2 deletions

View file

@ -863,7 +863,7 @@ func (c *Command) execute(a []string) (err error) {
} }
if err := c.validateRequiredFlags(); err != nil { if err := c.validateRequiredFlags(); err != nil {
return err return c.FlagErrorFunc()(c, err)
} }
if err := c.validateFlagGroups(); err != nil { if err := c.validateFlagGroups(); err != nil {
return err return err

View file

@ -17,6 +17,7 @@ package cobra
import ( import (
"bytes" "bytes"
"context" "context"
"errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os" "os"
@ -792,7 +793,6 @@ func TestRequiredFlags(t *testing.T) {
c.Flags().String("foo2", "", "") c.Flags().String("foo2", "", "")
assertNoErr(t, c.MarkFlagRequired("foo2")) assertNoErr(t, c.MarkFlagRequired("foo2"))
c.Flags().String("bar", "", "") c.Flags().String("bar", "", "")
expected := fmt.Sprintf("required flag(s) %q, %q not set", "foo1", "foo2") expected := fmt.Sprintf("required flag(s) %q, %q not set", "foo1", "foo2")
_, err := executeCommand(c) _, err := executeCommand(c)
@ -803,6 +803,58 @@ func TestRequiredFlags(t *testing.T) {
} }
} }
func TestRequiredFlagsWithCustomFlagErrorFunc(t *testing.T) {
usageFunc := func(c *Command) error {
c.Println("usage string")
return nil
}
c := &Command{Use: "c", Run: emptyRun, SilenceUsage: true, usageFunc: usageFunc}
c.Flags().String("foo1", "", "")
assertNoErr(t, c.MarkFlagRequired("foo1"))
silentError := "failed flag parsing"
c.SetFlagErrorFunc(func(c *Command, err error) error {
c.Println(err)
c.Println(c.UsageString())
return errors.New(silentError)
})
requiredFlagErrorMessage := fmt.Sprintf("required flag(s) %q not set", "foo1")
output, err := executeCommand(c)
got := err.Error()
checkStringContains(t, output, requiredFlagErrorMessage)
checkStringContains(t, output, c.UsageString())
if got != silentError {
t.Errorf("Expected error %s but got %s", silentError, got)
}
}
func TestUnexistingFlagsWithCustomFlagErrorFunc(t *testing.T) {
usageFunc := func(c *Command) error {
c.Println("usage string")
return nil
}
c := &Command{Use: "c", Run: emptyRun, SilenceUsage: true, usageFunc: usageFunc}
c.Flags().String("foo1", "", "")
assertNoErr(t, c.MarkFlagRequired("foo1"))
silentError := "failed flag parsing"
c.SetFlagErrorFunc(func(c *Command, err error) error {
c.Println(err)
c.Println(c.UsageString())
return errors.New(silentError)
})
unknownFlagErrorMessage := fmt.Sprintf("unknown flag: %s", "--unknownflag")
output, err := executeCommand(c, "--unknownflag")
got := err.Error()
checkStringContains(t, output, unknownFlagErrorMessage)
checkStringContains(t, output, c.UsageString())
if got != silentError {
t.Errorf("Expected error %s but got %s", silentError, got)
}
}
func TestPersistentRequiredFlags(t *testing.T) { func TestPersistentRequiredFlags(t *testing.T) {
parent := &Command{Use: "parent", Run: emptyRun} parent := &Command{Use: "parent", Run: emptyRun}
parent.PersistentFlags().String("foo1", "", "") parent.PersistentFlags().String("foo1", "", "")