diff --git a/command.go b/command.go index 4794e5eb..e0151dab 100644 --- a/command.go +++ b/command.go @@ -671,14 +671,33 @@ func shortHasNoOptDefVal(name string, fs *flag.FlagSet) bool { return flag.NoOptDefVal != "" } +func addFlagSet(dest *flag.FlagSet, src *flag.FlagSet) { + if src == nil { + return + } + src.VisitAll(func(f *flag.Flag) { + if dest.Lookup(f.Name) == nil && dest.ShorthandLookup(f.Shorthand) == nil { + dest.AddFlag(f) + } + }) +} + +func mergeChildrenFlags(fs *flag.FlagSet, c *Command) { + addFlagSet(fs, c.flags) + addFlagSet(fs, c.pflags) + for _, subc := range c.commands { + mergeChildrenFlags(fs, subc) + } +} + func stripFlags(args []string, c *Command) []string { if len(args) == 0 { return args } - c.mergePersistentFlags() commands := []string{} - flags := c.Flags() + flags := flag.NewFlagSet("", flag.ContinueOnError) + mergeChildrenFlags(flags, c) Loop: for len(args) > 0 { @@ -688,18 +707,23 @@ Loop: case s == "--": // "--" terminates the flags break Loop - case strings.HasPrefix(s, "--") && !strings.Contains(s, "=") && !hasNoOptDefVal(s[2:], flags): - // If '--flag arg' then - // delete arg from args. - fallthrough // (do the same as below) - case strings.HasPrefix(s, "-") && !strings.Contains(s, "=") && len(s) == 2 && !shortHasNoOptDefVal(s[1:], flags): - // If '-f arg' then - // delete 'arg' from args or break the loop if len(args) <= 1. - if len(args) <= 1 { - break Loop - } else { + case strings.HasPrefix(s, "-") && !strings.Contains(s, "=") && len(args) > 0: + // If "--flag" or "-f" then strip leading dashes. + s = s[1:] + hnodv := shortHasNoOptDefVal + if len(s) > 1 { + if !strings.HasPrefix(s, "-") { + // "-i1" + continue + } + // Long flag. + s = s[1:] + hnodv = hasNoOptDefVal + } + f := flags.Lookup(s) + if (f == nil || f.Value.Type() != "bool") && !hnodv(s, flags) { + // Delete the argument if '--flag arg'. args = args[1:] - continue } case s != "" && !strings.HasPrefix(s, "-"): commands = append(commands, s) diff --git a/command_test.go b/command_test.go index 156df9eb..b1898f8e 100644 --- a/command_test.go +++ b/command_test.go @@ -640,6 +640,34 @@ func TestFlagBeforeCommand(t *testing.T) { } } +func TestBooleanFlagBeforeCommand(t *testing.T) { + rootCommand := &Command{SilenceUsage: true, Args: MaximumNArgs(0)} + var flagValue bool + rootCommand.PersistentFlags().BoolVar(&flagValue, "rootflag", false, "root boolean flag") + + cmd := &Command{Use: "command", SilenceUsage: true, Args: MaximumNArgs(0)} + cmd.PersistentFlags().BoolVar(&flagValue, "cmdflag", false, "command boolean flag") + rootCommand.AddCommand(cmd) + + for name, args := range map[string][]string{ + "root flag before command": {"--rootflag", "command"}, + "command flag before command": {"--cmdflag", "command"}, + "command flag with =value before command": {"--cmdflag=true", "command"}, + // Not supported: + // "command flag with value before command": {"--cmdflag", "true", "command"}, + } { + t.Run(name, func(t *testing.T) { + flagValue = false + rootCommand.SetArgs(args) + if err := rootCommand.Execute(); err != nil { + t.Error(err) + } else if !flagValue { + t.Errorf("flag is false") + } + }) + } +} + func TestStripFlags(t *testing.T) { tests := []struct { input []string