This commit is contained in:
Jun Nishimura 2023-12-16 21:21:02 +09:00 committed by GitHub
commit 08e49f5d41
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 268 additions and 10 deletions

View file

@ -155,6 +155,8 @@ type Command struct {
pflags *flag.FlagSet pflags *flag.FlagSet
// lflags contains local flags. // lflags contains local flags.
lflags *flag.FlagSet lflags *flag.FlagSet
// lnpflags contains local non persistent flags
lnpflags *flag.FlagSet
// iflags contains inherited flags. // iflags contains inherited flags.
iflags *flag.FlagSet iflags *flag.FlagSet
// parentsPflags is all persistent flags of cmd's parents. // parentsPflags is all persistent flags of cmd's parents.
@ -1074,7 +1076,6 @@ func (c *Command) ExecuteC() (cmd *Command, err error) {
c.checkCommandGroups() c.checkCommandGroups()
args := c.args args := c.args
// Workaround FAIL with "go test -v" or "cobra.test -test.v", see #155 // Workaround FAIL with "go test -v" or "cobra.test -test.v", see #155
if c.args == nil && filepath.Base(os.Args[0]) != "cobra.test" { if c.args == nil && filepath.Base(os.Args[0]) != "cobra.test" {
args = os.Args[1:] args = os.Args[1:]
@ -1654,15 +1655,19 @@ func (c *Command) Flags() *flag.FlagSet {
// LocalNonPersistentFlags are flags specific to this command which will NOT persist to subcommands. // LocalNonPersistentFlags are flags specific to this command which will NOT persist to subcommands.
func (c *Command) LocalNonPersistentFlags() *flag.FlagSet { func (c *Command) LocalNonPersistentFlags() *flag.FlagSet {
persistentFlags := c.PersistentFlags() if c.lnpflags == nil {
persistentFlags := c.PersistentFlags()
out := flag.NewFlagSet(c.Name(), flag.ContinueOnError) c.lnpflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError)
c.LocalFlags().VisitAll(func(f *flag.Flag) { c.LocalFlags().VisitAll(func(f *flag.Flag) {
if persistentFlags.Lookup(f.Name) == nil { if persistentFlags.Lookup(f.Name) == nil {
out.AddFlag(f) f.Changed = false
} c.lnpflags.AddFlag(f)
}) }
return out })
}
return c.lnpflags
} }
// LocalFlags returns the local FlagSet specifically set in the current command. // LocalFlags returns the local FlagSet specifically set in the current command.
@ -1684,6 +1689,7 @@ func (c *Command) LocalFlags() *flag.FlagSet {
addToLocal := func(f *flag.Flag) { addToLocal := func(f *flag.Flag) {
// Add the flag if it is not a parent PFlag, or it shadows a parent PFlag // Add the flag if it is not a parent PFlag, or it shadows a parent PFlag
if c.lflags.Lookup(f.Name) == nil && f != c.parentsPflags.Lookup(f.Name) { if c.lflags.Lookup(f.Name) == nil && f != c.parentsPflags.Lookup(f.Name) {
f.Changed = false
c.lflags.AddFlag(f) c.lflags.AddFlag(f)
} }
} }
@ -1815,12 +1821,97 @@ func (c *Command) persistentFlag(name string) (flag *flag.Flag) {
return return
} }
func (c *Command) parseLongArgs(s string, args []string, flags *flag.FlagSet) (passedArgs, restArgs []string) {
restArgs = args
name := s[2:]
if len(name) == 0 {
passedArgs = append(passedArgs, s)
return
}
split := strings.SplitN(s[2:], "=", 2)
name = split[0]
searchedFlag := flags.Lookup(name)
if searchedFlag == nil {
// ignore the flag that is not registered in passed flags but is registered in c.parentsPflags
c.parentsPflags.VisitAll(func(f *flag.Flag) {
if name == f.Name {
if len(split) == 1 && f.NoOptDefVal == "" && len(args) > 0 {
// '--flag arg'
restArgs = args[1:]
}
}
})
return
}
passedArgs = append(passedArgs, fmt.Sprintf("--%s", s[2:]))
if len(split) == 1 && searchedFlag.NoOptDefVal == "" && len(args) > 0 {
passedArgs = append(passedArgs, args[0])
restArgs = args[1:]
}
return
}
func (c *Command) parseShortArgs(s string, args []string, flags *flag.FlagSet) (passedArgs []string, restArgs []string) {
restArgs = args
shorthands := s[1:]
shorthand := string(s[1])
searchedFlag := flags.ShorthandLookup(shorthand)
if searchedFlag == nil {
// ignore the flag that is not registered in passed flags but is registered in c.parentsPflags
c.parentsPflags.VisitAll(func(f *flag.Flag) {
if shorthand == f.Shorthand {
if len(shorthands) == 1 && f.NoOptDefVal == "" && len(args) > 0 {
// '-f arg'
restArgs = args[1:]
}
}
})
return
}
passedArgs = append(passedArgs, s)
if len(shorthands) == 1 && searchedFlag.NoOptDefVal == "" && len(args) > 0 {
// '-f arg'
passedArgs = append(passedArgs, args[0])
restArgs = args[1:]
}
return
}
func (c *Command) removeParentPersistentArgs(args []string, flags *flag.FlagSet) (newArgs []string) {
for len(args) > 0 {
s := args[0]
args = args[1:]
if len(s) == 0 || s[0] != '-' {
newArgs = append(newArgs, s)
continue
}
var passedArgs, restArgs []string
if s[1] == '-' {
passedArgs, restArgs = c.parseLongArgs(s, args, flags)
} else {
passedArgs, restArgs = c.parseShortArgs(s, args, flags)
}
if len(passedArgs) > 0 {
newArgs = append(newArgs, passedArgs...)
}
args = restArgs
}
return
}
// ParseFlags parses persistent flag tree and local flags. // ParseFlags parses persistent flag tree and local flags.
func (c *Command) ParseFlags(args []string) error { func (c *Command) ParseFlags(args []string) error {
if c.DisableFlagParsing { if c.DisableFlagParsing {
return nil return nil
} }
if c.flagErrorBuf == nil { if c.flagErrorBuf == nil {
c.flagErrorBuf = new(bytes.Buffer) c.flagErrorBuf = new(bytes.Buffer)
} }
@ -1830,11 +1921,38 @@ func (c *Command) ParseFlags(args []string) error {
// do it here after merging all flags and just before parse // do it here after merging all flags and just before parse
c.Flags().ParseErrorsWhitelist = flag.ParseErrorsWhitelist(c.FParseErrWhitelist) c.Flags().ParseErrorsWhitelist = flag.ParseErrorsWhitelist(c.FParseErrWhitelist)
// parse Flags
err := c.Flags().Parse(args) err := c.Flags().Parse(args)
// Print warnings if they occurred (e.g. deprecated flag messages). // Print warnings if they occurred (e.g. deprecated flag messages).
if c.flagErrorBuf.Len()-beforeErrorBufLen > 0 && err == nil { if c.flagErrorBuf.Len()-beforeErrorBufLen > 0 && err == nil {
c.Print(c.flagErrorBuf.String()) c.Print(c.flagErrorBuf.String())
} }
if err != nil {
return err
}
// parse Local Flags
c.LocalFlags() // need to execute LocalFlags() to set the value in c.lflags before executing removeParentPersistentArgs
c.lflags.ParseErrorsWhitelist = flag.ParseErrorsWhitelist(c.FParseErrWhitelist)
localArgs := c.removeParentPersistentArgs(args, c.lflags) // get only arguments related to c.lflags
err = c.lflags.Parse(localArgs)
// Print warnings if they occurred (e.g. deprecated flag messages).
if c.flagErrorBuf.Len()-beforeErrorBufLen > 0 && err == nil {
c.Print(c.flagErrorBuf.String())
}
if err != nil {
return err
}
// parse local non persistent flags
c.LocalNonPersistentFlags() // need to execute LocalNonPersistentFlags() to set the value in c.lnpflags before executing removeParentPersistentArgs
c.lnpflags.ParseErrorsWhitelist = flag.ParseErrorsWhitelist(c.FParseErrWhitelist)
localNonPersistentArgs := c.removeParentPersistentArgs(args, c.lnpflags)
err = c.lnpflags.Parse(localNonPersistentArgs)
// Print warnings if they occurred (e.g. deprecated flag messages).
if c.flagErrorBuf.Len()-beforeErrorBufLen > 0 && err == nil {
c.Print(c.flagErrorBuf.String())
}
return err return err
} }

View file

@ -2786,3 +2786,143 @@ func TestUnknownFlagShouldReturnSameErrorRegardlessOfArgPosition(t *testing.T) {
}) })
} }
} }
func TestNFlagForFlags(t *testing.T) {
var rootNFlag, childNFlag int
rootCmd := &Command{
Use: "root",
Run: func(cmd *Command, _ []string) {
rootNFlag = cmd.Flags().NFlag()
},
}
childCmd := &Command{
Use: "child",
Run: func(cmd *Command, args []string) {
childNFlag = cmd.Flags().NFlag()
},
}
rootCmd.AddCommand(childCmd)
rootCmd.PersistentFlags().Bool("rp", false, "")
rpFlag := rootCmd.PersistentFlags().Lookup("rp")
childCmd.PersistentFlags().Bool("cp", false, "")
childCmd.Flags().Int("int", 0, "")
output, err := executeCommand(rootCmd, "--rp")
if output != "" {
t.Errorf("Unexpected output: %v", output)
}
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if rootNFlag != 1 {
t.Errorf("Expected NFlag: %v, got %v", 1, rootNFlag)
}
// set Changed false for the next test
rpFlag.Changed = false
output, err = executeCommand(rootCmd, "child", "--rp", "--cp", "--int", "10")
if output != "" {
t.Errorf("Unexpected output: %v", output)
}
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if childNFlag != 3 {
t.Errorf("Expected NFlag: %v, got %v", 3, childNFlag)
}
}
func TestNFlagForLocalFlags(t *testing.T) {
var localNFlag int
rootCmd := &Command{
Use: "root",
Run: emptyRun,
}
childCmd := &Command{
Use: "child",
Run: func(cmd *Command, args []string) {
localNFlag = cmd.LocalFlags().NFlag()
},
}
rootCmd.AddCommand(childCmd)
rootCmd.PersistentFlags().Bool("rp", false, "")
childCmd.PersistentFlags().Bool("cp", false, "")
childCmd.Flags().Int("int", 0, "")
output, err := executeCommand(rootCmd, "child", "--rp", "--cp", "--int", "10")
if output != "" {
t.Errorf("Unexpected output: %v", output)
}
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if localNFlag != 2 { // LocalFlags().NFlag() ignores '--rp'
t.Errorf("Expected NFlag: %v, got %v", 2, localNFlag)
}
}
func TestNFlagForLocalNonPersistentFlags(t *testing.T) {
var localNonPNFlag int
rootCmd := &Command{
Use: "root",
Run: emptyRun,
}
childCmd := &Command{
Use: "child",
Run: func(cmd *Command, args []string) {
localNonPNFlag = cmd.LocalNonPersistentFlags().NFlag()
},
}
rootCmd.AddCommand(childCmd)
rootCmd.PersistentFlags().Bool("rp", false, "")
childCmd.PersistentFlags().Bool("cp", false, "")
childCmd.Flags().Int("int", 0, "")
output, err := executeCommand(rootCmd, "child", "--rp", "--cp", "--int", "10")
if output != "" {
t.Errorf("Unexpected output: %v", output)
}
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if localNonPNFlag != 1 { // LocalNonPersistentFlags().NFlag() ignores '--rp' and '--cp'
t.Errorf("Expected NFlag: %v, got %v", 1, localNonPNFlag)
}
}
func TestRemoveParentPersistentArgs(t *testing.T) {
rootCmd := &Command{Use: "root", Run: emptyRun}
childCmd := &Command{Use: "child", Run: emptyRun}
rootCmd.AddCommand(childCmd)
rootCmd.PersistentFlags().BoolP("rp", "r", false, "")
rootCmd.PersistentFlags().Int("ri", 0, "")
childCmd.PersistentFlags().Bool("cp", false, "")
childCmd.Flags().Int("int", 0, "")
output, err := executeCommand(rootCmd, "child", "-r", "--ri", "10", "--cp", "--int", "10")
if output != "" {
t.Errorf("Unexpected output: %v", output)
}
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
args := rootCmd.args
_, args, _ = rootCmd.Find(args)
gotLocalArgs := childCmd.removeParentPersistentArgs(args, childCmd.lflags)
expectedLocalArgs := []string{"--cp", "--int", "10"}
if !reflect.DeepEqual(gotLocalArgs, expectedLocalArgs) {
t.Errorf("Expected localArgs: %v, got %v", expectedLocalArgs, gotLocalArgs)
}
gotLocalNonPersistentArgs := childCmd.removeParentPersistentArgs(args, childCmd.lnpflags)
expectedLocalNonPersistentArgs := []string{"--int", "10"}
if !reflect.DeepEqual(gotLocalNonPersistentArgs, expectedLocalNonPersistentArgs) {
t.Errorf("Expected localArgs: %v, got %v", expectedLocalNonPersistentArgs, gotLocalNonPersistentArgs)
}
}