mirror of
https://github.com/spf13/cobra
synced 2025-05-05 12:57:22 +00:00
Merge aa02e412ac
into a281c8b47b
This commit is contained in:
commit
7de41f90ff
2 changed files with 54 additions and 2 deletions
|
@ -863,7 +863,7 @@ func (c *Command) execute(a []string) (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.validateRequiredFlags(); err != nil {
|
if err := c.validateRequiredFlags(); err != nil {
|
||||||
return err
|
return c.FlagErrorFunc()(c, err)
|
||||||
}
|
}
|
||||||
if err := c.validateFlagGroups(); err != nil {
|
if err := c.validateFlagGroups(); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -17,6 +17,7 @@ package cobra
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"os"
|
"os"
|
||||||
|
@ -792,7 +793,6 @@ func TestRequiredFlags(t *testing.T) {
|
||||||
c.Flags().String("foo2", "", "")
|
c.Flags().String("foo2", "", "")
|
||||||
assertNoErr(t, c.MarkFlagRequired("foo2"))
|
assertNoErr(t, c.MarkFlagRequired("foo2"))
|
||||||
c.Flags().String("bar", "", "")
|
c.Flags().String("bar", "", "")
|
||||||
|
|
||||||
expected := fmt.Sprintf("required flag(s) %q, %q not set", "foo1", "foo2")
|
expected := fmt.Sprintf("required flag(s) %q, %q not set", "foo1", "foo2")
|
||||||
|
|
||||||
_, err := executeCommand(c)
|
_, err := executeCommand(c)
|
||||||
|
@ -803,6 +803,58 @@ func TestRequiredFlags(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRequiredFlagsWithCustomFlagErrorFunc(t *testing.T) {
|
||||||
|
usageFunc := func(c *Command) error {
|
||||||
|
c.Println("usage string")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
c := &Command{Use: "c", Run: emptyRun, SilenceUsage: true, usageFunc: usageFunc}
|
||||||
|
c.Flags().String("foo1", "", "")
|
||||||
|
assertNoErr(t, c.MarkFlagRequired("foo1"))
|
||||||
|
silentError := "failed flag parsing"
|
||||||
|
c.SetFlagErrorFunc(func(c *Command, err error) error {
|
||||||
|
c.Println(err)
|
||||||
|
c.Println(c.UsageString())
|
||||||
|
return errors.New(silentError)
|
||||||
|
})
|
||||||
|
requiredFlagErrorMessage := fmt.Sprintf("required flag(s) %q not set", "foo1")
|
||||||
|
|
||||||
|
output, err := executeCommand(c)
|
||||||
|
|
||||||
|
got := err.Error()
|
||||||
|
checkStringContains(t, output, requiredFlagErrorMessage)
|
||||||
|
checkStringContains(t, output, c.UsageString())
|
||||||
|
if got != silentError {
|
||||||
|
t.Errorf("Expected error %s but got %s", silentError, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnexistingFlagsWithCustomFlagErrorFunc(t *testing.T) {
|
||||||
|
usageFunc := func(c *Command) error {
|
||||||
|
c.Println("usage string")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
c := &Command{Use: "c", Run: emptyRun, SilenceUsage: true, usageFunc: usageFunc}
|
||||||
|
c.Flags().String("foo1", "", "")
|
||||||
|
assertNoErr(t, c.MarkFlagRequired("foo1"))
|
||||||
|
silentError := "failed flag parsing"
|
||||||
|
c.SetFlagErrorFunc(func(c *Command, err error) error {
|
||||||
|
c.Println(err)
|
||||||
|
c.Println(c.UsageString())
|
||||||
|
return errors.New(silentError)
|
||||||
|
})
|
||||||
|
unknownFlagErrorMessage := fmt.Sprintf("unknown flag: %s", "--unknownflag")
|
||||||
|
|
||||||
|
output, err := executeCommand(c, "--unknownflag")
|
||||||
|
|
||||||
|
got := err.Error()
|
||||||
|
checkStringContains(t, output, unknownFlagErrorMessage)
|
||||||
|
checkStringContains(t, output, c.UsageString())
|
||||||
|
if got != silentError {
|
||||||
|
t.Errorf("Expected error %s but got %s", silentError, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestPersistentRequiredFlags(t *testing.T) {
|
func TestPersistentRequiredFlags(t *testing.T) {
|
||||||
parent := &Command{Use: "parent", Run: emptyRun}
|
parent := &Command{Use: "parent", Run: emptyRun}
|
||||||
parent.PersistentFlags().String("foo1", "", "")
|
parent.PersistentFlags().String("foo1", "", "")
|
||||||
|
|
Loading…
Add table
Reference in a new issue