This commit is contained in:
nestoroprysk 2022-08-29 10:12:56 +09:00 committed by GitHub
commit d0d5d68eef
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 3 deletions

View file

@ -1036,7 +1036,10 @@ func (c *Command) validateRequiredFlags() error {
}) })
if len(missingFlagNames) > 0 { if len(missingFlagNames) > 0 {
return fmt.Errorf(`required flag(s) "%s" not set`, strings.Join(missingFlagNames, `", "`)) return fmt.Errorf(`required %s "%s" not set`,
pluralize("flag", len(missingFlagNames)),
strings.Join(missingFlagNames, `", "`),
)
} }
return nil return nil
} }
@ -1695,3 +1698,11 @@ func (c *Command) updateParentsPflags() {
c.parentsPflags.AddFlagSet(parent.PersistentFlags()) c.parentsPflags.AddFlagSet(parent.PersistentFlags())
}) })
} }
func pluralize(name string, length int) string {
if length == 1 {
return name
}
return name + "s"
}

View file

@ -771,6 +771,21 @@ func TestPersistentFlagsOnChild(t *testing.T) {
} }
} }
func TestRequiredFlag(t *testing.T) {
c := &Command{Use: "c", Run: emptyRun}
c.Flags().String("foo1", "", "")
c.MarkFlagRequired("foo1")
expected := fmt.Sprintf("required flag %q not set", "foo1")
_, err := executeCommand(c)
got := err.Error()
if got != expected {
t.Errorf("Expected error: %q, got: %q", expected, got)
}
}
func TestRequiredFlags(t *testing.T) { func TestRequiredFlags(t *testing.T) {
c := &Command{Use: "c", Run: emptyRun} c := &Command{Use: "c", Run: emptyRun}
c.Flags().String("foo1", "", "") c.Flags().String("foo1", "", "")
@ -779,7 +794,7 @@ func TestRequiredFlags(t *testing.T) {
assertNoErr(t, c.MarkFlagRequired("foo2")) assertNoErr(t, c.MarkFlagRequired("foo2"))
c.Flags().String("bar", "", "") c.Flags().String("bar", "", "")
expected := fmt.Sprintf("required flag(s) %q, %q not set", "foo1", "foo2") expected := fmt.Sprintf("required flags %q, %q not set", "foo1", "foo2")
_, err := executeCommand(c) _, err := executeCommand(c)
got := err.Error() got := err.Error()
@ -806,7 +821,7 @@ func TestPersistentRequiredFlags(t *testing.T) {
parent.AddCommand(child) parent.AddCommand(child)
expected := fmt.Sprintf("required flag(s) %q, %q, %q, %q not set", "bar1", "bar2", "foo1", "foo2") expected := fmt.Sprintf("required flags %q, %q, %q, %q not set", "bar1", "bar2", "foo1", "foo2")
_, err := executeCommand(parent, "child") _, err := executeCommand(parent, "child")
if err.Error() != expected { if err.Error() != expected {