This commit is contained in:
Brandon Roehl 2018-06-27 17:48:42 +00:00 committed by GitHub
commit 7fc0d096d2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 82 additions and 46 deletions

View file

@ -6,6 +6,8 @@ import (
"io" "io"
"os" "os"
"strings" "strings"
flag "github.com/spf13/pflag"
) )
// GenZshCompletionFile generates zsh completion file. // GenZshCompletionFile generates zsh completion file.
@ -19,14 +21,23 @@ func (c *Command) GenZshCompletionFile(filename string) error {
return c.GenZshCompletion(outFile) return c.GenZshCompletion(outFile)
} }
func argName(cmd *Command) string {
for cmd.HasParent() {
cmd = cmd.Parent()
}
name := fmt.Sprintf("%s_cmd_args", cmd.Name())
return strings.Replace(name, "-", "_",-1)
}
// GenZshCompletion generates a zsh completion file and writes to the passed writer. // GenZshCompletion generates a zsh completion file and writes to the passed writer.
func (c *Command) GenZshCompletion(w io.Writer) error { func (c *Command) GenZshCompletion(w io.Writer) error {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
writeHeader(buf, c) writeHeader(buf, c)
maxDepth := maxDepth(c) maxDepth := maxDepth(c)
writeLevelMapping(buf, maxDepth) fmt.Fprintf(buf, "_%s() {\n", c.Name())
writeLevelMapping(buf, maxDepth, c)
writeLevelCases(buf, maxDepth, c) writeLevelCases(buf, maxDepth, c)
fmt.Fprintf(buf, "}\n_%s \"$@\"\n", c.Name())
_, err := buf.WriteTo(w) _, err := buf.WriteTo(w)
return err return err
@ -50,77 +61,102 @@ func maxDepth(c *Command) int {
return 1 + maxDepthSub return 1 + maxDepthSub
} }
func writeLevelMapping(w io.Writer, numLevels int) { func writeLevelMapping(w io.Writer, numLevels int, root *Command) {
fmt.Fprintln(w, `_arguments \`) fmt.Fprintln(w, `local context curcontext="$curcontext" state line`)
fmt.Fprintln(w, `typeset -A opt_args`)
fmt.Fprintln(w, `_arguments -C \`)
for i := 1; i <= numLevels; i++ { for i := 1; i <= numLevels; i++ {
fmt.Fprintf(w, ` '%d: :->level%d' \`, i, i) fmt.Fprintf(w, " '%d: :->level%d' \\\n", i, i)
fmt.Fprintln(w)
} }
fmt.Fprintf(w, ` '%d: :%s'`, numLevels+1, "_files") fmt.Fprintf(w, " $%s \\\n", argName(root))
fmt.Fprintln(w) fmt.Fprintln(w, ` '*: :_files'`)
} }
func writeLevelCases(w io.Writer, maxDepth int, root *Command) { func writeLevelCases(w io.Writer, maxDepth int, root *Command) {
fmt.Fprintln(w, "case $state in") fmt.Fprintln(w, "case $state in")
defer fmt.Fprintln(w, "esac")
for i := 1; i <= maxDepth; i++ { for i := 1; i <= maxDepth; i++ {
fmt.Fprintf(w, " level%d)\n", i)
writeLevel(w, root, i) writeLevel(w, root, i)
fmt.Fprintln(w, " ;;")
} }
fmt.Fprintln(w, " *)") fmt.Fprintln(w, " *)")
fmt.Fprintln(w, " _arguments '*: :_files'") fmt.Fprintln(w, " _arguments '*: :_files'")
fmt.Fprintln(w, " ;;") fmt.Fprintln(w, " ;;")
fmt.Fprintln(w, "esac")
} }
func writeLevel(w io.Writer, root *Command, i int) { func writeLevel(w io.Writer, root *Command, l int) {
fmt.Fprintf(w, " case $words[%d] in\n", i) fmt.Fprintf(w, " level%d)\n", l)
defer fmt.Fprintln(w, " esac") fmt.Fprintf(w, " case $words[%d] in\n", l)
for _, c := range filterByLevel(root, l) {
commands := filterByLevel(root, i) writeCommandArgsBlock(w, c)
byParent := groupByParent(commands)
for p, c := range byParent {
names := names(c)
fmt.Fprintf(w, " %s)\n", p)
fmt.Fprintf(w, " _arguments '%d: :(%s)'\n", i, strings.Join(names, " "))
fmt.Fprintln(w, " ;;")
} }
fmt.Fprintln(w, " *)") fmt.Fprintln(w, " *)")
fmt.Fprintln(w, " _arguments '*: :_files'") fmt.Fprintln(w, " _arguments '*: :_files'")
fmt.Fprintln(w, " ;;") fmt.Fprintln(w, " ;;")
fmt.Fprintln(w, " esac")
fmt.Fprintln(w, " ;;")
}
func writeCommandArgsBlock(w io.Writer, c *Command) {
names := commandNames(c)
flags := commandFlags(c)
if len(names) > 0 || len(flags) > 0 {
fmt.Fprintf(w, " %s)\n", c.Name())
defer fmt.Fprintln(w, " ;;")
}
if len(flags) > 0 {
fmt.Fprintf(w, " %s=(\n", argName(c))
for _, flag := range flags {
fmt.Fprintf(w, " %s\n", flag)
}
fmt.Fprintln(w, " )")
}
if len(names) > 0 {
fmt.Fprintf(w, " _values 'command' '%s'\n", strings.Join(names, "' '"))
}
} }
func filterByLevel(c *Command, l int) []*Command { func filterByLevel(c *Command, l int) []*Command {
cs := make([]*Command, 0) commands := []*Command{c}
if l == 0 { for i := 1; i < l; i++ {
cs = append(cs, c) var nextLevel []*Command
return cs
}
for _, s := range c.Commands() {
cs = append(cs, filterByLevel(s, l-1)...)
}
return cs
}
func groupByParent(commands []*Command) map[string][]*Command {
m := make(map[string][]*Command)
for _, c := range commands { for _, c := range commands {
parent := c.Parent() if c.HasSubCommands() {
if parent == nil { nextLevel = append(nextLevel, c.Commands()...)
continue
} }
m[parent.Name()] = append(m[parent.Name()], c)
} }
return m commands = nextLevel
} }
func names(commands []*Command) []string { return commands
}
func commandNames(command *Command) []string {
commands := command.Commands()
ns := make([]string, len(commands)) ns := make([]string, len(commands))
for i, c := range commands { for i, c := range commands {
ns[i] = c.Name() commandMsg := c.Name()
if len(c.Short) > 0 {
commandMsg += fmt.Sprintf("[%s]", c.Short)
}
ns[i] = commandMsg
} }
return ns return ns
} }
func commandFlags(command *Command) []string {
flags := command.Flags()
ns := make([]string, 0)
flags.VisitAll(func(flag *flag.Flag) {
var flagMsg string
if len(flag.Shorthand) > 0 {
flagMsg = fmt.Sprintf("{-%s,--%s}", flag.Shorthand, flag.Name)
} else {
flagMsg = fmt.Sprintf("--%s", flag.Name)
}
if len(flag.Usage) > 0 {
flagMsg += fmt.Sprintf("'[%s]'", flag.Usage)
}
ns = append(ns, flagMsg)
})
return ns
}

View file

@ -42,7 +42,7 @@ func TestZshCompletion(t *testing.T) {
r.AddCommand(&Command{Use: "c2"}) r.AddCommand(&Command{Use: "c2"})
return r return r
}(), }(),
expectedExpressions: []string{"(c1 c2)"}, expectedExpressions: []string{"'c1' 'c2'"},
}, },
{ {
name: "tree", name: "tree",
@ -69,7 +69,7 @@ func TestZshCompletion(t *testing.T) {
return r return r
}(), }(),
expectedExpressions: []string{"(sub11 sub12)", "(sub21 sub22)"}, expectedExpressions: []string{"'sub11' 'sub12'", "'sub21' 'sub22'"},
}, },
} }