mirror of
https://github.com/spf13/cobra
synced 2025-04-27 00:57:23 +00:00
Merge 9a4aaf6d9c
into 236f3c0418
This commit is contained in:
commit
08e49f5d41
2 changed files with 268 additions and 10 deletions
138
command.go
138
command.go
|
@ -155,6 +155,8 @@ type Command struct {
|
|||
pflags *flag.FlagSet
|
||||
// lflags contains local flags.
|
||||
lflags *flag.FlagSet
|
||||
// lnpflags contains local non persistent flags
|
||||
lnpflags *flag.FlagSet
|
||||
// iflags contains inherited flags.
|
||||
iflags *flag.FlagSet
|
||||
// parentsPflags is all persistent flags of cmd's parents.
|
||||
|
@ -1074,7 +1076,6 @@ func (c *Command) ExecuteC() (cmd *Command, err error) {
|
|||
c.checkCommandGroups()
|
||||
|
||||
args := c.args
|
||||
|
||||
// Workaround FAIL with "go test -v" or "cobra.test -test.v", see #155
|
||||
if c.args == nil && filepath.Base(os.Args[0]) != "cobra.test" {
|
||||
args = os.Args[1:]
|
||||
|
@ -1654,15 +1655,19 @@ func (c *Command) Flags() *flag.FlagSet {
|
|||
|
||||
// LocalNonPersistentFlags are flags specific to this command which will NOT persist to subcommands.
|
||||
func (c *Command) LocalNonPersistentFlags() *flag.FlagSet {
|
||||
persistentFlags := c.PersistentFlags()
|
||||
if c.lnpflags == nil {
|
||||
persistentFlags := c.PersistentFlags()
|
||||
|
||||
out := flag.NewFlagSet(c.Name(), flag.ContinueOnError)
|
||||
c.LocalFlags().VisitAll(func(f *flag.Flag) {
|
||||
if persistentFlags.Lookup(f.Name) == nil {
|
||||
out.AddFlag(f)
|
||||
}
|
||||
})
|
||||
return out
|
||||
c.lnpflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError)
|
||||
c.LocalFlags().VisitAll(func(f *flag.Flag) {
|
||||
if persistentFlags.Lookup(f.Name) == nil {
|
||||
f.Changed = false
|
||||
c.lnpflags.AddFlag(f)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return c.lnpflags
|
||||
}
|
||||
|
||||
// LocalFlags returns the local FlagSet specifically set in the current command.
|
||||
|
@ -1684,6 +1689,7 @@ func (c *Command) LocalFlags() *flag.FlagSet {
|
|||
addToLocal := func(f *flag.Flag) {
|
||||
// Add the flag if it is not a parent PFlag, or it shadows a parent PFlag
|
||||
if c.lflags.Lookup(f.Name) == nil && f != c.parentsPflags.Lookup(f.Name) {
|
||||
f.Changed = false
|
||||
c.lflags.AddFlag(f)
|
||||
}
|
||||
}
|
||||
|
@ -1815,12 +1821,97 @@ func (c *Command) persistentFlag(name string) (flag *flag.Flag) {
|
|||
return
|
||||
}
|
||||
|
||||
func (c *Command) parseLongArgs(s string, args []string, flags *flag.FlagSet) (passedArgs, restArgs []string) {
|
||||
restArgs = args
|
||||
name := s[2:]
|
||||
if len(name) == 0 {
|
||||
passedArgs = append(passedArgs, s)
|
||||
return
|
||||
}
|
||||
|
||||
split := strings.SplitN(s[2:], "=", 2)
|
||||
name = split[0]
|
||||
searchedFlag := flags.Lookup(name)
|
||||
if searchedFlag == nil {
|
||||
// ignore the flag that is not registered in passed flags but is registered in c.parentsPflags
|
||||
c.parentsPflags.VisitAll(func(f *flag.Flag) {
|
||||
if name == f.Name {
|
||||
if len(split) == 1 && f.NoOptDefVal == "" && len(args) > 0 {
|
||||
// '--flag arg'
|
||||
restArgs = args[1:]
|
||||
}
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
passedArgs = append(passedArgs, fmt.Sprintf("--%s", s[2:]))
|
||||
if len(split) == 1 && searchedFlag.NoOptDefVal == "" && len(args) > 0 {
|
||||
passedArgs = append(passedArgs, args[0])
|
||||
restArgs = args[1:]
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Command) parseShortArgs(s string, args []string, flags *flag.FlagSet) (passedArgs []string, restArgs []string) {
|
||||
restArgs = args
|
||||
|
||||
shorthands := s[1:]
|
||||
shorthand := string(s[1])
|
||||
|
||||
searchedFlag := flags.ShorthandLookup(shorthand)
|
||||
if searchedFlag == nil {
|
||||
// ignore the flag that is not registered in passed flags but is registered in c.parentsPflags
|
||||
c.parentsPflags.VisitAll(func(f *flag.Flag) {
|
||||
if shorthand == f.Shorthand {
|
||||
if len(shorthands) == 1 && f.NoOptDefVal == "" && len(args) > 0 {
|
||||
// '-f arg'
|
||||
restArgs = args[1:]
|
||||
}
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
passedArgs = append(passedArgs, s)
|
||||
if len(shorthands) == 1 && searchedFlag.NoOptDefVal == "" && len(args) > 0 {
|
||||
// '-f arg'
|
||||
passedArgs = append(passedArgs, args[0])
|
||||
restArgs = args[1:]
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Command) removeParentPersistentArgs(args []string, flags *flag.FlagSet) (newArgs []string) {
|
||||
for len(args) > 0 {
|
||||
s := args[0]
|
||||
args = args[1:]
|
||||
if len(s) == 0 || s[0] != '-' {
|
||||
newArgs = append(newArgs, s)
|
||||
continue
|
||||
}
|
||||
|
||||
var passedArgs, restArgs []string
|
||||
if s[1] == '-' {
|
||||
passedArgs, restArgs = c.parseLongArgs(s, args, flags)
|
||||
} else {
|
||||
passedArgs, restArgs = c.parseShortArgs(s, args, flags)
|
||||
}
|
||||
if len(passedArgs) > 0 {
|
||||
newArgs = append(newArgs, passedArgs...)
|
||||
}
|
||||
args = restArgs
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ParseFlags parses persistent flag tree and local flags.
|
||||
func (c *Command) ParseFlags(args []string) error {
|
||||
if c.DisableFlagParsing {
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.flagErrorBuf == nil {
|
||||
c.flagErrorBuf = new(bytes.Buffer)
|
||||
}
|
||||
|
@ -1830,11 +1921,38 @@ func (c *Command) ParseFlags(args []string) error {
|
|||
// do it here after merging all flags and just before parse
|
||||
c.Flags().ParseErrorsWhitelist = flag.ParseErrorsWhitelist(c.FParseErrWhitelist)
|
||||
|
||||
// parse Flags
|
||||
err := c.Flags().Parse(args)
|
||||
// Print warnings if they occurred (e.g. deprecated flag messages).
|
||||
if c.flagErrorBuf.Len()-beforeErrorBufLen > 0 && err == nil {
|
||||
c.Print(c.flagErrorBuf.String())
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// parse Local Flags
|
||||
c.LocalFlags() // need to execute LocalFlags() to set the value in c.lflags before executing removeParentPersistentArgs
|
||||
c.lflags.ParseErrorsWhitelist = flag.ParseErrorsWhitelist(c.FParseErrWhitelist)
|
||||
localArgs := c.removeParentPersistentArgs(args, c.lflags) // get only arguments related to c.lflags
|
||||
err = c.lflags.Parse(localArgs)
|
||||
// Print warnings if they occurred (e.g. deprecated flag messages).
|
||||
if c.flagErrorBuf.Len()-beforeErrorBufLen > 0 && err == nil {
|
||||
c.Print(c.flagErrorBuf.String())
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// parse local non persistent flags
|
||||
c.LocalNonPersistentFlags() // need to execute LocalNonPersistentFlags() to set the value in c.lnpflags before executing removeParentPersistentArgs
|
||||
c.lnpflags.ParseErrorsWhitelist = flag.ParseErrorsWhitelist(c.FParseErrWhitelist)
|
||||
localNonPersistentArgs := c.removeParentPersistentArgs(args, c.lnpflags)
|
||||
err = c.lnpflags.Parse(localNonPersistentArgs)
|
||||
// Print warnings if they occurred (e.g. deprecated flag messages).
|
||||
if c.flagErrorBuf.Len()-beforeErrorBufLen > 0 && err == nil {
|
||||
c.Print(c.flagErrorBuf.String())
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
|
140
command_test.go
140
command_test.go
|
@ -2786,3 +2786,143 @@ func TestUnknownFlagShouldReturnSameErrorRegardlessOfArgPosition(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNFlagForFlags(t *testing.T) {
|
||||
var rootNFlag, childNFlag int
|
||||
rootCmd := &Command{
|
||||
Use: "root",
|
||||
Run: func(cmd *Command, _ []string) {
|
||||
rootNFlag = cmd.Flags().NFlag()
|
||||
},
|
||||
}
|
||||
childCmd := &Command{
|
||||
Use: "child",
|
||||
Run: func(cmd *Command, args []string) {
|
||||
childNFlag = cmd.Flags().NFlag()
|
||||
},
|
||||
}
|
||||
rootCmd.AddCommand(childCmd)
|
||||
|
||||
rootCmd.PersistentFlags().Bool("rp", false, "")
|
||||
rpFlag := rootCmd.PersistentFlags().Lookup("rp")
|
||||
childCmd.PersistentFlags().Bool("cp", false, "")
|
||||
childCmd.Flags().Int("int", 0, "")
|
||||
|
||||
output, err := executeCommand(rootCmd, "--rp")
|
||||
if output != "" {
|
||||
t.Errorf("Unexpected output: %v", output)
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if rootNFlag != 1 {
|
||||
t.Errorf("Expected NFlag: %v, got %v", 1, rootNFlag)
|
||||
}
|
||||
// set Changed false for the next test
|
||||
rpFlag.Changed = false
|
||||
|
||||
output, err = executeCommand(rootCmd, "child", "--rp", "--cp", "--int", "10")
|
||||
if output != "" {
|
||||
t.Errorf("Unexpected output: %v", output)
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if childNFlag != 3 {
|
||||
t.Errorf("Expected NFlag: %v, got %v", 3, childNFlag)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNFlagForLocalFlags(t *testing.T) {
|
||||
var localNFlag int
|
||||
rootCmd := &Command{
|
||||
Use: "root",
|
||||
Run: emptyRun,
|
||||
}
|
||||
childCmd := &Command{
|
||||
Use: "child",
|
||||
Run: func(cmd *Command, args []string) {
|
||||
localNFlag = cmd.LocalFlags().NFlag()
|
||||
},
|
||||
}
|
||||
rootCmd.AddCommand(childCmd)
|
||||
|
||||
rootCmd.PersistentFlags().Bool("rp", false, "")
|
||||
childCmd.PersistentFlags().Bool("cp", false, "")
|
||||
childCmd.Flags().Int("int", 0, "")
|
||||
|
||||
output, err := executeCommand(rootCmd, "child", "--rp", "--cp", "--int", "10")
|
||||
if output != "" {
|
||||
t.Errorf("Unexpected output: %v", output)
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if localNFlag != 2 { // LocalFlags().NFlag() ignores '--rp'
|
||||
t.Errorf("Expected NFlag: %v, got %v", 2, localNFlag)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNFlagForLocalNonPersistentFlags(t *testing.T) {
|
||||
var localNonPNFlag int
|
||||
rootCmd := &Command{
|
||||
Use: "root",
|
||||
Run: emptyRun,
|
||||
}
|
||||
childCmd := &Command{
|
||||
Use: "child",
|
||||
Run: func(cmd *Command, args []string) {
|
||||
localNonPNFlag = cmd.LocalNonPersistentFlags().NFlag()
|
||||
},
|
||||
}
|
||||
rootCmd.AddCommand(childCmd)
|
||||
|
||||
rootCmd.PersistentFlags().Bool("rp", false, "")
|
||||
childCmd.PersistentFlags().Bool("cp", false, "")
|
||||
childCmd.Flags().Int("int", 0, "")
|
||||
|
||||
output, err := executeCommand(rootCmd, "child", "--rp", "--cp", "--int", "10")
|
||||
if output != "" {
|
||||
t.Errorf("Unexpected output: %v", output)
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if localNonPNFlag != 1 { // LocalNonPersistentFlags().NFlag() ignores '--rp' and '--cp'
|
||||
t.Errorf("Expected NFlag: %v, got %v", 1, localNonPNFlag)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveParentPersistentArgs(t *testing.T) {
|
||||
rootCmd := &Command{Use: "root", Run: emptyRun}
|
||||
childCmd := &Command{Use: "child", Run: emptyRun}
|
||||
rootCmd.AddCommand(childCmd)
|
||||
|
||||
rootCmd.PersistentFlags().BoolP("rp", "r", false, "")
|
||||
rootCmd.PersistentFlags().Int("ri", 0, "")
|
||||
childCmd.PersistentFlags().Bool("cp", false, "")
|
||||
childCmd.Flags().Int("int", 0, "")
|
||||
|
||||
output, err := executeCommand(rootCmd, "child", "-r", "--ri", "10", "--cp", "--int", "10")
|
||||
if output != "" {
|
||||
t.Errorf("Unexpected output: %v", output)
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
args := rootCmd.args
|
||||
_, args, _ = rootCmd.Find(args)
|
||||
|
||||
gotLocalArgs := childCmd.removeParentPersistentArgs(args, childCmd.lflags)
|
||||
expectedLocalArgs := []string{"--cp", "--int", "10"}
|
||||
if !reflect.DeepEqual(gotLocalArgs, expectedLocalArgs) {
|
||||
t.Errorf("Expected localArgs: %v, got %v", expectedLocalArgs, gotLocalArgs)
|
||||
}
|
||||
|
||||
gotLocalNonPersistentArgs := childCmd.removeParentPersistentArgs(args, childCmd.lnpflags)
|
||||
expectedLocalNonPersistentArgs := []string{"--int", "10"}
|
||||
if !reflect.DeepEqual(gotLocalNonPersistentArgs, expectedLocalNonPersistentArgs) {
|
||||
t.Errorf("Expected localArgs: %v, got %v", expectedLocalNonPersistentArgs, gotLocalNonPersistentArgs)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue