mirror of
https://github.com/spf13/cobra
synced 2025-05-05 04:47:22 +00:00
Merge aa02e412ac
into a281c8b47b
This commit is contained in:
commit
7de41f90ff
2 changed files with 54 additions and 2 deletions
|
@ -863,7 +863,7 @@ func (c *Command) execute(a []string) (err error) {
|
|||
}
|
||||
|
||||
if err := c.validateRequiredFlags(); err != nil {
|
||||
return err
|
||||
return c.FlagErrorFunc()(c, err)
|
||||
}
|
||||
if err := c.validateFlagGroups(); err != nil {
|
||||
return err
|
||||
|
|
|
@ -17,6 +17,7 @@ package cobra
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
|
@ -792,7 +793,6 @@ func TestRequiredFlags(t *testing.T) {
|
|||
c.Flags().String("foo2", "", "")
|
||||
assertNoErr(t, c.MarkFlagRequired("foo2"))
|
||||
c.Flags().String("bar", "", "")
|
||||
|
||||
expected := fmt.Sprintf("required flag(s) %q, %q not set", "foo1", "foo2")
|
||||
|
||||
_, err := executeCommand(c)
|
||||
|
@ -803,6 +803,58 @@ func TestRequiredFlags(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestRequiredFlagsWithCustomFlagErrorFunc(t *testing.T) {
|
||||
usageFunc := func(c *Command) error {
|
||||
c.Println("usage string")
|
||||
return nil
|
||||
}
|
||||
c := &Command{Use: "c", Run: emptyRun, SilenceUsage: true, usageFunc: usageFunc}
|
||||
c.Flags().String("foo1", "", "")
|
||||
assertNoErr(t, c.MarkFlagRequired("foo1"))
|
||||
silentError := "failed flag parsing"
|
||||
c.SetFlagErrorFunc(func(c *Command, err error) error {
|
||||
c.Println(err)
|
||||
c.Println(c.UsageString())
|
||||
return errors.New(silentError)
|
||||
})
|
||||
requiredFlagErrorMessage := fmt.Sprintf("required flag(s) %q not set", "foo1")
|
||||
|
||||
output, err := executeCommand(c)
|
||||
|
||||
got := err.Error()
|
||||
checkStringContains(t, output, requiredFlagErrorMessage)
|
||||
checkStringContains(t, output, c.UsageString())
|
||||
if got != silentError {
|
||||
t.Errorf("Expected error %s but got %s", silentError, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnexistingFlagsWithCustomFlagErrorFunc(t *testing.T) {
|
||||
usageFunc := func(c *Command) error {
|
||||
c.Println("usage string")
|
||||
return nil
|
||||
}
|
||||
c := &Command{Use: "c", Run: emptyRun, SilenceUsage: true, usageFunc: usageFunc}
|
||||
c.Flags().String("foo1", "", "")
|
||||
assertNoErr(t, c.MarkFlagRequired("foo1"))
|
||||
silentError := "failed flag parsing"
|
||||
c.SetFlagErrorFunc(func(c *Command, err error) error {
|
||||
c.Println(err)
|
||||
c.Println(c.UsageString())
|
||||
return errors.New(silentError)
|
||||
})
|
||||
unknownFlagErrorMessage := fmt.Sprintf("unknown flag: %s", "--unknownflag")
|
||||
|
||||
output, err := executeCommand(c, "--unknownflag")
|
||||
|
||||
got := err.Error()
|
||||
checkStringContains(t, output, unknownFlagErrorMessage)
|
||||
checkStringContains(t, output, c.UsageString())
|
||||
if got != silentError {
|
||||
t.Errorf("Expected error %s but got %s", silentError, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPersistentRequiredFlags(t *testing.T) {
|
||||
parent := &Command{Use: "parent", Run: emptyRun}
|
||||
parent.PersistentFlags().String("foo1", "", "")
|
||||
|
|
Loading…
Add table
Reference in a new issue