From 85c9f3065c93a6f733fd4b5423837a514349a20b Mon Sep 17 00:00:00 2001 From: Paul Chesnais Date: Tue, 11 Jun 2019 18:46:51 -0700 Subject: [PATCH] Add framework for dynamic tab completions By setting the COBRA_FLAG_COMPLETION environment variable, the normal execution path of the command is short circuited, and instead the function registered by `MarkCustomFlagCompletion` is executed. All flags other than the one being completed get parsed according to whatever type they are defined as, but the flag being completed is parsed as a raw string and passed into the custom compeltion. --- command.go | 104 +++++++++++++++++++++++++++++++++++++++++++ shell_completions.go | 30 +++++++++++++ zsh_completions.go | 47 ++++++++++++++++--- 3 files changed, 176 insertions(+), 5 deletions(-) diff --git a/command.go b/command.go index c7e89830..ba27f896 100644 --- a/command.go +++ b/command.go @@ -143,6 +143,10 @@ type Command struct { //FParseErrWhitelist flag parse errors to be ignored FParseErrWhitelist FParseErrWhitelist + // RunPreRunsDuringCompletion defines if the (Persistent)PreRun functions should be run before calling the + // completion functions + RunPreRunsDuringCompletion bool + // commands is the list of commands supported by this program. commands []*Command // parent is a parent command for this command. @@ -200,6 +204,10 @@ type Command struct { outWriter io.Writer // errWriter is a writer defined by the user that replaces stderr errWriter io.Writer + + // flagCompletions is a map of flag to a function that returns a list of values to suggest during tab completion for + // this flag + flagCompletions map[*flag.Flag]DynamicFlagCompletion } // SetArgs sets arguments for the command. It is set to os.Args[1:] by default, if desired, can be overridden @@ -736,6 +744,8 @@ func (c *Command) ArgsLenAtDash() int { return c.Flags().ArgsLenAtDash() } +const FlagCompletionEnvVar = "COBRA_FLAG_COMPLETION" + func (c *Command) execute(a []string) (err error) { if c == nil { return fmt.Errorf("Called Execute() on a nil Command") @@ -851,6 +861,90 @@ func (c *Command) execute(a []string) (err error) { return nil } +func (c *Command) complete(flagName string, a []string) (err error) { + if c == nil { + return fmt.Errorf("Called Execute() on a nil Command") + } + + // initialize help and version flag at the last point possible to allow for user + // overriding + c.InitDefaultHelpFlag() + c.InitDefaultVersionFlag() + + var flagToComplete *flag.Flag + var currentCompletionValue string + + oldFlags := c.Flags() + c.flags = nil + oldFlags.VisitAll(func(f *flag.Flag) { + if f.Name == flagName { + flagToComplete = f + } else { + c.Flags().AddFlag(f) + } + }) + if flagToComplete.Shorthand != "" { + c.Flags().StringVarP(¤tCompletionValue, flagName, flagToComplete.Shorthand, "", "") + } else { + c.Flags().StringVar(¤tCompletionValue, flagName, "", "") + } + c.Flag(flagName).NoOptDefVal = "_hack_" + + err = c.ParseFlags(a) + if err != nil { + return c.FlagErrorFunc()(c, err) + } + + c.preRun() + + currentCommand := c + completionFunc := currentCommand.flagCompletions[flagToComplete] + for completionFunc == nil && currentCommand.HasParent() { + currentCommand = currentCommand.Parent() + completionFunc = currentCommand.flagCompletions[flagToComplete] + } + if completionFunc == nil { + return fmt.Errorf("%s does not have completions enabled", flagName) + } + + if c.RunPreRunsDuringCompletion { + argWoFlags := c.Flags().Args() + if c.DisableFlagParsing { + argWoFlags = a + } + + for p := c; p != nil; p = p.Parent() { + if p.PersistentPreRunE != nil { + if err := p.PersistentPreRunE(c, argWoFlags); err != nil { + return err + } + break + } else if p.PersistentPreRun != nil { + p.PersistentPreRun(c, argWoFlags) + break + } + } + if c.PreRunE != nil { + if err := c.PreRunE(c, argWoFlags); err != nil { + return err + } + } else if c.PreRun != nil { + c.PreRun(c, argWoFlags) + } + } + + values, err := completionFunc(currentCompletionValue) + if err != nil { + return err + } + + for _, v := range values { + c.Print(v + "\x00") + } + + return nil +} + func (c *Command) preRun() { for _, x := range initializers { x() @@ -911,6 +1005,16 @@ func (c *Command) ExecuteC() (cmd *Command, err error) { cmd.commandCalledAs.name = cmd.Name() } + flagName, flagCompletionEnabled := os.LookupEnv(FlagCompletionEnvVar) + if flagCompletionEnabled { + err = cmd.complete(flagName, flags) + if err != nil { + c.Println("Error:", err.Error()) + } + + return cmd, err + } + err = cmd.execute(flags) if err != nil { // Always show help if requested, even if SilenceErrors is in diff --git a/shell_completions.go b/shell_completions.go index ba0af9cb..918ef3a4 100644 --- a/shell_completions.go +++ b/shell_completions.go @@ -1,6 +1,8 @@ package cobra import ( + "fmt" + "github.com/spf13/pflag" ) @@ -83,3 +85,31 @@ func MarkFlagDirname(flags *pflag.FlagSet, name string) error { zshPattern := "-(/)" return flags.SetAnnotation(name, zshCompDirname, []string{zshPattern}) } + +type DynamicFlagCompletion func(currentValue string) (suggestedValues []string, err error) + +// MarkDynamicFlagCompletion provides cobra a function to dynamically suggest values to the user during tab completion +// for this flag. All (Persistent)PreRun(E) functions will be run accordingly before the provided function is called if +// RunPreRunsDuringCompletion is set to true. All flags other than the flag currently being completed will be parsed +// according to their type. The flag being completed is parsed as a raw string with no format requirements +// +// Shell Completion compatibility matrix: zsh +func (c *Command) MarkDynamicFlagCompletion(name string, completion DynamicFlagCompletion) error { + flag := c.Flag(name) + if flag == nil { + return fmt.Errorf("no such flag %s", name) + } + if flag.NoOptDefVal != "" { + return fmt.Errorf("%s takes no parameters", name) + } + + if c.flagCompletions == nil { + c.flagCompletions = make(map[*pflag.Flag]DynamicFlagCompletion) + } + c.flagCompletions[flag] = completion + if flag.Annotations == nil { + flag.Annotations = map[string][]string{} + } + flag.Annotations[zshCompDynamicCompletion] = []string{zshCompGenFlagCompletionFuncName(c)} + return nil +} diff --git a/zsh_completions.go b/zsh_completions.go index 12755482..2255de6c 100644 --- a/zsh_completions.go +++ b/zsh_completions.go @@ -17,14 +17,18 @@ const ( zshCompArgumentFilenameComp = "cobra_annotations_zsh_completion_argument_file_completion" zshCompArgumentWordComp = "cobra_annotations_zsh_completion_argument_word_completion" zshCompDirname = "cobra_annotations_zsh_dirname" + zshCompDynamicCompletion = "cobra_annotations_zsh_completion_dynamic_completion" ) var ( zshCompFuncMap = template.FuncMap{ - "genZshFuncName": zshCompGenFuncName, - "extractFlags": zshCompExtractFlag, - "genFlagEntryForZshArguments": zshCompGenFlagEntryForArguments, - "extractArgsCompletions": zshCompExtractArgumentCompletionHintsForRendering, + "genZshFuncName": zshCompGenFuncName, + "extractFlags": zshCompExtractFlag, + "genFlagEntryForZshArguments": zshCompGenFlagEntryForArguments, + "extractArgsCompletions": zshCompExtractArgumentCompletionHintsForRendering, + "genZshFlagDynamicCompletionFuncName": zshCompGenFlagCompletionFuncName, + "hasDynamicCompletions": zshCompHasDynamicCompletions, + "flagCompletionsEnvVar": func() string { return FlagCompletionEnvVar }, } zshCompletionText = ` {{/* should accept Command (that contains subcommands) as parameter */}} @@ -79,6 +83,21 @@ function {{genZshFuncName .}} { {{define "Main" -}} #compdef _{{.Name}} {{.Name}} +{{if hasDynamicCompletions . -}} +function {{genZshFlagDynamicCompletionFuncName .}} { + export COBRA_FLAG_COMPLETION="$1" + if suggestions="$("$words[@]" 2>&1)" ; then + local -a args + while read -d $'\0' line ; do + args+="$line" + done <<< "$suggestions" + _values "$1" "$args[@]" + else + _message "Exception occurred during completion: $suggestions" + fi + unset COBRA_FLAG_COMPLETION +}{{- end}} + {{template "selectCmdTemplate" .}} {{end}} ` @@ -250,6 +269,10 @@ func zshCompGenFuncName(c *Command) string { return "_" + c.Name() } +func zshCompGenFlagCompletionFuncName(c *Command) string { + return "_" + c.Root().Name() + "-flag-completion" +} + func zshCompExtractFlag(c *Command) []*pflag.Flag { var flags []*pflag.Flag c.LocalFlags().VisitAll(func(f *pflag.Flag) { @@ -310,7 +333,7 @@ func zshCompGenFlagEntryExtras(f *pflag.Flag) string { return "" } - extras := ":" // allow options for flag (even without assistance) + extras := ":" + f.Name // allow options for flag (even without assistance) for key, values := range f.Annotations { switch key { case zshCompDirname: @@ -320,6 +343,8 @@ func zshCompGenFlagEntryExtras(f *pflag.Flag) string { for _, pattern := range values { extras = extras + fmt.Sprintf(` -g "%s"`, pattern) } + case zshCompDynamicCompletion: + extras += fmt.Sprintf(":{%s %s}", values[0], f.Name) } } @@ -334,3 +359,15 @@ func zshCompFlagCouldBeSpecifiedMoreThenOnce(f *pflag.Flag) bool { func zshCompQuoteFlagDescription(s string) string { return strings.Replace(s, "'", `'\''`, -1) } + +func zshCompHasDynamicCompletions(c *Command) bool { + if len(c.flagCompletions) > 0 { + return true + } + for _, subcommand := range c.Commands() { + if zshCompHasDynamicCompletions(subcommand) { + return true + } + } + return false +}