Merge branch 'master' into master

This commit is contained in:
Alex 2019-06-07 17:11:37 +02:00 committed by GitHub
commit 6eb960b3cb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 1689 additions and 902 deletions

2
.gitignore vendored
View file

@ -34,3 +34,5 @@ tags
*.exe
cobra.test
.idea/*

View file

@ -23,6 +23,7 @@ Many of the most widely used Go projects are built using Cobra, such as:
[Istio](https://istio.io),
[Prototool](https://github.com/uber/prototool),
[mattermost-server](https://github.com/mattermost/mattermost-server),
[Gardener](https://github.com/gardener/gardenctl),
etc.
[![Build Status](https://travis-ci.org/spf13/cobra.svg "Travis CI status")](https://travis-ci.org/spf13/cobra)
@ -48,6 +49,7 @@ etc.
* [Suggestions when "unknown command" happens](#suggestions-when-unknown-command-happens)
* [Generating documentation for your command](#generating-documentation-for-your-command)
* [Generating bash completions](#generating-bash-completions)
* [Generating zsh completions](#generating-zsh-completions)
- [Contributing](#contributing)
- [License](#license)
@ -336,7 +338,7 @@ rootCmd.PersistentFlags().BoolVarP(&Verbose, "verbose", "v", false, "verbose out
A flag can also be assigned locally which will only apply to that specific command.
```go
rootCmd.Flags().StringVarP(&Source, "source", "s", "", "Source directory to read from")
localCmd.Flags().StringVarP(&Source, "source", "s", "", "Source directory to read from")
```
### Local Flag on Parent Commands
@ -719,6 +721,11 @@ Cobra can generate documentation based on subcommands, flags, etc. in the follow
Cobra can generate a bash-completion file. If you add more information to your command, these completions can be amazingly powerful and flexible. Read more about it in [Bash Completions](bash_completions.md).
## Generating zsh completions
Cobra can generate zsh-completion file. Read more about it in
[Zsh Completions](zsh_completions.md).
# Contributing
1. Fork it

View file

@ -545,51 +545,3 @@ func (c *Command) GenBashCompletionFile(filename string) error {
return c.GenBashCompletion(outFile)
}
// MarkFlagRequired adds the BashCompOneRequiredFlag annotation to the named flag if it exists,
// and causes your command to report an error if invoked without the flag.
func (c *Command) MarkFlagRequired(name string) error {
return MarkFlagRequired(c.Flags(), name)
}
// MarkPersistentFlagRequired adds the BashCompOneRequiredFlag annotation to the named persistent flag if it exists,
// and causes your command to report an error if invoked without the flag.
func (c *Command) MarkPersistentFlagRequired(name string) error {
return MarkFlagRequired(c.PersistentFlags(), name)
}
// MarkFlagRequired adds the BashCompOneRequiredFlag annotation to the named flag if it exists,
// and causes your command to report an error if invoked without the flag.
func MarkFlagRequired(flags *pflag.FlagSet, name string) error {
return flags.SetAnnotation(name, BashCompOneRequiredFlag, []string{"true"})
}
// MarkFlagFilename adds the BashCompFilenameExt annotation to the named flag, if it exists.
// Generated bash autocompletion will select filenames for the flag, limiting to named extensions if provided.
func (c *Command) MarkFlagFilename(name string, extensions ...string) error {
return MarkFlagFilename(c.Flags(), name, extensions...)
}
// MarkFlagCustom adds the BashCompCustom annotation to the named flag, if it exists.
// Generated bash autocompletion will call the bash function f for the flag.
func (c *Command) MarkFlagCustom(name string, f string) error {
return MarkFlagCustom(c.Flags(), name, f)
}
// MarkPersistentFlagFilename adds the BashCompFilenameExt annotation to the named persistent flag, if it exists.
// Generated bash autocompletion will select filenames for the flag, limiting to named extensions if provided.
func (c *Command) MarkPersistentFlagFilename(name string, extensions ...string) error {
return MarkFlagFilename(c.PersistentFlags(), name, extensions...)
}
// MarkFlagFilename adds the BashCompFilenameExt annotation to the named flag in the flag set, if it exists.
// Generated bash autocompletion will select filenames for the flag, limiting to named extensions if provided.
func MarkFlagFilename(flags *pflag.FlagSet, name string, extensions ...string) error {
return flags.SetAnnotation(name, BashCompFilenameExt, extensions)
}
// MarkFlagCustom adds the BashCompCustom annotation to the named flag in the flag set, if it exists.
// Generated bash autocompletion will call the bash function f for the flag.
func MarkFlagCustom(flags *pflag.FlagSet, name string, f string) error {
return flags.SetAnnotation(name, BashCompCustom, []string{f})
}

View file

@ -16,20 +16,16 @@ package cmd
import (
"fmt"
"os"
"path/filepath"
"unicode"
"github.com/spf13/cobra"
)
func init() {
addCmd.Flags().StringVarP(&packageName, "package", "t", "", "target package name (e.g. github.com/spf13/hugo)")
addCmd.Flags().StringVarP(&parentName, "parent", "p", "rootCmd", "variable name of parent command for this command (e.g. xyCmd)")
}
var (
packageName string
parentName string
var packageName, parentName string
var addCmd = &cobra.Command{
addCmd = &cobra.Command{
Use: "add [command name]",
Aliases: []string{"command"},
Short: "Add a command to a Cobra Application",
@ -47,24 +43,37 @@ Example: cobra add server -> resulting in a new cmd/server.go`,
er("add needs a name for the command")
}
var project *Project
if packageName != "" {
project = NewProject(packageName)
} else {
wd, err := os.Getwd()
if err != nil {
er(err)
}
project = NewProjectFromPath(wd)
commandName := validateCmdName(args[0])
command := &Command{
CmdName: commandName,
CmdParent: parentName,
Project: &Project{
AbsolutePath: wd,
Legal: getLicense(),
Copyright: copyrightLine(),
},
}
cmdName := validateCmdName(args[0])
cmdPath := filepath.Join(project.CmdPath(), cmdName+".go")
createCmdFile(project.License(), cmdPath, cmdName)
err = command.Create()
if err != nil {
er(err)
}
fmt.Fprintln(cmd.OutOrStdout(), cmdName, "created at", cmdPath)
fmt.Printf("%s created at %s", command.CmdName, command.AbsolutePath)
},
}
)
func init() {
addCmd.Flags().StringVarP(&packageName, "package", "t", "", "target package name (e.g. github.com/spf13/hugo)")
addCmd.Flags().StringVarP(&parentName, "parent", "p", "rootCmd", "variable name of parent command for this command (e.g. xyCmd)")
addCmd.Flags().MarkDeprecated("package", "this operation has been removed.")
}
// validateCmdName returns source without any dashes and underscore.
// If there will be dash or underscore, next letter will be uppered.
@ -118,62 +127,3 @@ func validateCmdName(source string) string {
}
return output
}
func createCmdFile(license License, path, cmdName string) {
template := `{{comment .copyright}}
{{if .license}}{{comment .license}}{{end}}
package {{.cmdPackage}}
import (
"fmt"
"github.com/spf13/cobra"
)
// {{.cmdName}}Cmd represents the {{.cmdName}} command
var {{.cmdName}}Cmd = &cobra.Command{
Use: "{{.cmdName}}",
Short: "A brief description of your command",
Long: ` + "`" + `A longer description that spans multiple lines and likely contains examples
and usage of using your command. For example:
Cobra is a CLI library for Go that empowers applications.
This application is a tool to generate the needed files
to quickly create a Cobra application.` + "`" + `,
Run: func(cmd *cobra.Command, args []string) {
fmt.Println("{{.cmdName}} called")
},
}
func init() {
{{.parentName}}.AddCommand({{.cmdName}}Cmd)
// Here you will define your flags and configuration settings.
// Cobra supports Persistent Flags which will work for this command
// and all subcommands, e.g.:
// {{.cmdName}}Cmd.PersistentFlags().String("foo", "", "A help for foo")
// Cobra supports local flags which will only run when this command
// is called directly, e.g.:
// {{.cmdName}}Cmd.Flags().BoolP("toggle", "t", false, "Help message for toggle")
}
`
data := make(map[string]interface{})
data["copyright"] = copyrightLine()
data["license"] = license.Header
data["cmdPackage"] = filepath.Base(filepath.Dir(path)) // last dir of path
data["parentName"] = parentName
data["cmdName"] = cmdName
cmdScript, err := executeTemplate(template, data)
if err != nil {
er(err)
}
err = writeStringToFile(path, cmdScript)
if err != nil {
er(err)
}
}

View file

@ -1,85 +1,45 @@
package cmd
import (
"errors"
"io/ioutil"
"fmt"
"os"
"path/filepath"
"testing"
"github.com/spf13/viper"
)
// TestGoldenAddCmd initializes the project "github.com/spf13/testproject"
// in GOPATH, adds "test" command
// and compares the content of all files in cmd directory of testproject
// with appropriate golden files.
// Use -update to update existing golden files.
func TestGoldenAddCmd(t *testing.T) {
projectName := "github.com/spf13/testproject"
project := NewProject(projectName)
defer os.RemoveAll(project.AbsPath())
viper.Set("author", "NAME HERE <EMAIL ADDRESS>")
viper.Set("license", "apache")
viper.Set("year", 2017)
defer viper.Set("author", nil)
defer viper.Set("license", nil)
defer viper.Set("year", nil)
wd, _ := os.Getwd()
command := &Command{
CmdName: "test",
CmdParent: parentName,
Project: &Project{
AbsolutePath: fmt.Sprintf("%s/testproject", wd),
Legal: getLicense(),
Copyright: copyrightLine(),
// Initialize the project first.
initializeProject(project)
// Then add the "test" command.
cmdName := "test"
cmdPath := filepath.Join(project.CmdPath(), cmdName+".go")
createCmdFile(project.License(), cmdPath, cmdName)
expectedFiles := []string{".", "root.go", "test.go"}
gotFiles := []string{}
// Check project file hierarchy and compare the content of every single file
// with appropriate golden file.
err := filepath.Walk(project.CmdPath(), func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
// required to init
AppName: "testproject",
PkgName: "github.com/spf13/testproject",
Viper: true,
},
}
// Make path relative to project.CmdPath().
// E.g. path = "/home/user/go/src/github.com/spf13/testproject/cmd/root.go"
// then it returns just "root.go".
relPath, err := filepath.Rel(project.CmdPath(), path)
if err != nil {
return err
// init project first
command.Project.Create()
defer func() {
if _, err := os.Stat(command.AbsolutePath); err == nil {
os.RemoveAll(command.AbsolutePath)
}
relPath = filepath.ToSlash(relPath)
gotFiles = append(gotFiles, relPath)
goldenPath := filepath.Join("testdata", filepath.Base(path)+".golden")
}()
switch relPath {
// Known directories.
case ".":
return nil
// Known files.
case "root.go", "test.go":
if *update {
got, err := ioutil.ReadFile(path)
if err != nil {
return err
}
ioutil.WriteFile(goldenPath, got, 0644)
}
return compareFiles(path, goldenPath)
}
// Unknown file.
return errors.New("unknown file: " + path)
})
if err != nil {
if err := command.Create(); err != nil {
t.Fatal(err)
}
// Check if some files lack.
if err := checkLackFiles(expectedFiles, gotFiles); err != nil {
generatedFile := fmt.Sprintf("%s/cmd/%s.go", command.AbsolutePath, command.CmdName)
goldenFile := fmt.Sprintf("testdata/%s.go.golden", command.CmdName)
err := compareFiles(generatedFile, goldenFile)
if err != nil {
t.Fatal(err)
}
}

View file

@ -15,15 +15,16 @@ package cmd
import (
"fmt"
"os"
"path"
"path/filepath"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"os"
"path"
)
var initCmd = &cobra.Command{
var (
pkgName string
initCmd = &cobra.Command{
Use: "init [name]",
Aliases: []string{"initialize", "initialise", "create"},
Short: "Initialize a Cobra Application",
@ -40,195 +41,37 @@ and the appropriate structure for a Cobra-based CLI application.
Init will not use an existing directory with contents.`,
Run: func(cmd *cobra.Command, args []string) {
wd, err := os.Getwd()
if err != nil {
er(err)
}
var project *Project
if len(args) == 0 {
project = NewProjectFromPath(wd)
} else if len(args) == 1 {
arg := args[0]
if arg[0] == '.' {
arg = filepath.Join(wd, arg)
if len(args) > 0 {
if args[0] != "." {
wd = fmt.Sprintf("%s/%s", wd, args[0])
}
if filepath.IsAbs(arg) {
project = NewProjectFromPath(arg)
} else {
project = NewProject(arg)
}
} else {
er("please provide only one argument")
}
initializeProject(project)
project := &Project{
AbsolutePath: wd,
PkgName: pkgName,
Legal: getLicense(),
Copyright: copyrightLine(),
Viper: viper.GetBool("useViper"),
AppName: path.Base(pkgName),
}
fmt.Fprintln(cmd.OutOrStdout(), `Your Cobra application is ready at
`+project.AbsPath()+`
if err := project.Create(); err != nil {
er(err)
}
Give it a try by going there and running `+"`go run main.go`."+`
Add commands to it by running `+"`cobra add [cmdname]`.")
fmt.Printf("Your Cobra applicaton is ready at\n%s\n", project.AbsolutePath)
},
}
)
func initializeProject(project *Project) {
if !exists(project.AbsPath()) { // If path doesn't yet exist, create it
err := os.MkdirAll(project.AbsPath(), os.ModePerm)
if err != nil {
er(err)
}
} else if !isEmpty(project.AbsPath()) { // If path exists and is not empty don't use it
er("Cobra will not create a new project in a non empty directory: " + project.AbsPath())
}
// We have a directory and it's empty. Time to initialize it.
createLicenseFile(project.License(), project.AbsPath())
createMainFile(project)
createRootCmdFile(project)
}
func createLicenseFile(license License, path string) {
data := make(map[string]interface{})
data["copyright"] = copyrightLine()
// Generate license template from text and data.
text, err := executeTemplate(license.Text, data)
if err != nil {
er(err)
}
// Write license text to LICENSE file.
err = writeStringToFile(filepath.Join(path, "LICENSE"), text)
if err != nil {
er(err)
}
}
func createMainFile(project *Project) {
mainTemplate := `{{ comment .copyright }}
{{if .license}}{{ comment .license }}{{end}}
package main
import "{{ .importpath }}"
func main() {
cmd.Execute()
}
`
data := make(map[string]interface{})
data["copyright"] = copyrightLine()
data["license"] = project.License().Header
data["importpath"] = path.Join(project.Name(), filepath.Base(project.CmdPath()))
mainScript, err := executeTemplate(mainTemplate, data)
if err != nil {
er(err)
}
err = writeStringToFile(filepath.Join(project.AbsPath(), "main.go"), mainScript)
if err != nil {
er(err)
}
}
func createRootCmdFile(project *Project) {
template := `{{comment .copyright}}
{{if .license}}{{comment .license}}{{end}}
package cmd
import (
"fmt"
"os"
{{if .viper}}
homedir "github.com/mitchellh/go-homedir"{{end}}
"github.com/spf13/cobra"{{if .viper}}
"github.com/spf13/viper"{{end}}
){{if .viper}}
var cfgFile string{{end}}
// rootCmd represents the base command when called without any subcommands
var rootCmd = &cobra.Command{
Use: "{{.appName}}",
Short: "A brief description of your application",
Long: ` + "`" + `A longer description that spans multiple lines and likely contains
examples and usage of using your application. For example:
Cobra is a CLI library for Go that empowers applications.
This application is a tool to generate the needed files
to quickly create a Cobra application.` + "`" + `,
// Uncomment the following line if your bare application
// has an action associated with it:
// Run: func(cmd *cobra.Command, args []string) { },
}
// Execute adds all child commands to the root command and sets flags appropriately.
// This is called by main.main(). It only needs to happen once to the rootCmd.
func Execute() {
if err := rootCmd.Execute(); err != nil {
fmt.Println(err)
os.Exit(1)
}
}
func init() { {{- if .viper}}
cobra.OnInitialize(initConfig)
{{end}}
// Here you will define your flags and configuration settings.
// Cobra supports persistent flags, which, if defined here,
// will be global for your application.{{ if .viper }}
rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.{{ .appName }}.yaml)"){{ else }}
// rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.{{ .appName }}.yaml)"){{ end }}
// Cobra also supports local flags, which will only run
// when this action is called directly.
rootCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle")
}{{ if .viper }}
// initConfig reads in config file and ENV variables if set.
func initConfig() {
if cfgFile != "" {
// Use config file from the flag.
viper.SetConfigFile(cfgFile)
} else {
// Find home directory.
home, err := homedir.Dir()
if err != nil {
fmt.Println(err)
os.Exit(1)
}
// Search config in home directory with name ".{{ .appName }}" (without extension).
viper.AddConfigPath(home)
viper.SetConfigName(".{{ .appName }}")
}
viper.AutomaticEnv() // read in environment variables that match
// If a config file is found, read it in.
if err := viper.ReadInConfig(); err == nil {
fmt.Println("Using config file:", viper.ConfigFileUsed())
}
}{{ end }}
`
data := make(map[string]interface{})
data["copyright"] = copyrightLine()
data["viper"] = viper.GetBool("useViper")
data["license"] = project.License().Header
data["appName"] = path.Base(project.Name())
rootCmdScript, err := executeTemplate(template, data)
if err != nil {
er(err)
}
err = writeStringToFile(filepath.Join(project.CmdPath(), "root.go"), rootCmdScript)
if err != nil {
er(err)
}
func init() {
initCmd.Flags().StringVar(&pkgName, "pkg-name", "", "fully qualified pkg name")
initCmd.MarkFlagRequired("pkg-name")
}

View file

@ -1,83 +1,42 @@
package cmd
import (
"errors"
"io/ioutil"
"fmt"
"os"
"path/filepath"
"testing"
"github.com/spf13/viper"
)
// TestGoldenInitCmd initializes the project "github.com/spf13/testproject"
// in GOPATH and compares the content of files in initialized project with
// appropriate golden files ("testdata/*.golden").
// Use -update to update existing golden files.
func TestGoldenInitCmd(t *testing.T) {
projectName := "github.com/spf13/testproject"
project := NewProject(projectName)
defer os.RemoveAll(project.AbsPath())
viper.Set("author", "NAME HERE <EMAIL ADDRESS>")
viper.Set("license", "apache")
viper.Set("year", 2017)
defer viper.Set("author", nil)
defer viper.Set("license", nil)
defer viper.Set("year", nil)
os.Args = []string{"cobra", "init", projectName}
if err := rootCmd.Execute(); err != nil {
t.Fatal("Error by execution:", err)
wd, _ := os.Getwd()
project := &Project{
AbsolutePath: fmt.Sprintf("%s/testproject", wd),
PkgName: "github.com/spf13/testproject",
Legal: getLicense(),
Copyright: copyrightLine(),
Viper: true,
AppName: "testproject",
}
expectedFiles := []string{".", "cmd", "LICENSE", "main.go", "cmd/root.go"}
gotFiles := []string{}
// Check project file hierarchy and compare the content of every single file
// with appropriate golden file.
err := filepath.Walk(project.AbsPath(), func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
// Make path relative to project.AbsPath().
// E.g. path = "/home/user/go/src/github.com/spf13/testproject/cmd/root.go"
// then it returns just "cmd/root.go".
relPath, err := filepath.Rel(project.AbsPath(), path)
if err != nil {
return err
}
relPath = filepath.ToSlash(relPath)
gotFiles = append(gotFiles, relPath)
goldenPath := filepath.Join("testdata", filepath.Base(path)+".golden")
switch relPath {
// Known directories.
case ".", "cmd":
return nil
// Known files.
case "LICENSE", "main.go", "cmd/root.go":
if *update {
got, err := ioutil.ReadFile(path)
if err != nil {
return err
}
if err := ioutil.WriteFile(goldenPath, got, 0644); err != nil {
t.Fatal("Error while updating file:", err)
}
}
return compareFiles(path, goldenPath)
}
// Unknown file.
return errors.New("unknown file: " + path)
})
err := project.Create()
if err != nil {
t.Fatal(err)
}
// Check if some files lack.
if err := checkLackFiles(expectedFiles, gotFiles); err != nil {
defer func() {
if _, err := os.Stat(project.AbsolutePath); err == nil {
os.RemoveAll(project.AbsolutePath)
}
}()
expectedFiles := []string{"LICENSE", "main.go", "cmd/root.go"}
for _, f := range expectedFiles {
generatedFile := fmt.Sprintf("%s/%s", project.AbsolutePath, f)
goldenFile := fmt.Sprintf("testdata/%s.golden", filepath.Base(f))
err := compareFiles(generatedFile, goldenFile)
if err != nil {
t.Fatal(err)
}
}
}

View file

@ -1,200 +1,96 @@
package cmd
import (
"fmt"
"github.com/spf13/cobra/cobra/tpl"
"os"
"path/filepath"
"runtime"
"strings"
"text/template"
)
// Project contains name, license and paths to projects.
type Project struct {
absPath string
cmdPath string
srcPath string
license License
name string
// v2
PkgName string
Copyright string
AbsolutePath string
Legal License
Viper bool
AppName string
}
// NewProject returns Project with specified project name.
func NewProject(projectName string) *Project {
if projectName == "" {
er("can't create project with blank name")
type Command struct {
CmdName string
CmdParent string
*Project
}
p := new(Project)
p.name = projectName
func (p *Project) Create() error {
// 1. Find already created protect.
p.absPath = findPackage(projectName)
// check if AbsolutePath exists
if _, err := os.Stat(p.AbsolutePath); os.IsNotExist(err) {
// create directory
if err := os.Mkdir(p.AbsolutePath, 0754); err != nil {
return err
}
}
// 2. If there are no created project with this path, and user is in GOPATH,
// then use GOPATH/src/projectName.
if p.absPath == "" {
wd, err := os.Getwd()
// create main.go
mainFile, err := os.Create(fmt.Sprintf("%s/main.go", p.AbsolutePath))
if err != nil {
er(err)
}
for _, srcPath := range srcPaths {
goPath := filepath.Dir(srcPath)
if filepathHasPrefix(wd, goPath) {
p.absPath = filepath.Join(srcPath, projectName)
break
}
}
return err
}
defer mainFile.Close()
// 3. If user is not in GOPATH, then use (first GOPATH)/src/projectName.
if p.absPath == "" {
p.absPath = filepath.Join(srcPaths[0], projectName)
}
return p
}
// findPackage returns full path to existing go package in GOPATHs.
func findPackage(packageName string) string {
if packageName == "" {
return ""
}
for _, srcPath := range srcPaths {
packagePath := filepath.Join(srcPath, packageName)
if exists(packagePath) {
return packagePath
}
}
return ""
}
// NewProjectFromPath returns Project with specified absolute path to
// package.
func NewProjectFromPath(absPath string) *Project {
if absPath == "" {
er("can't create project: absPath can't be blank")
}
if !filepath.IsAbs(absPath) {
er("can't create project: absPath is not absolute")
}
// If absPath is symlink, use its destination.
fi, err := os.Lstat(absPath)
mainTemplate := template.Must(template.New("main").Parse(string(tpl.MainTemplate())))
err = mainTemplate.Execute(mainFile, p)
if err != nil {
er("can't read path info: " + err.Error())
return err
}
if fi.Mode()&os.ModeSymlink != 0 {
path, err := os.Readlink(absPath)
// create cmd/root.go
if _, err = os.Stat(fmt.Sprintf("%s/cmd", p.AbsolutePath)); os.IsNotExist(err) {
os.Mkdir(fmt.Sprintf("%s/cmd", p.AbsolutePath), 0751)
}
rootFile, err := os.Create(fmt.Sprintf("%s/cmd/root.go", p.AbsolutePath))
if err != nil {
er("can't read the destination of symlink: " + err.Error())
}
absPath = path
return err
}
defer rootFile.Close()
p := new(Project)
p.absPath = strings.TrimSuffix(absPath, findCmdDir(absPath))
p.name = filepath.ToSlash(trimSrcPath(p.absPath, p.SrcPath()))
return p
}
// trimSrcPath trims at the beginning of absPath the srcPath.
func trimSrcPath(absPath, srcPath string) string {
relPath, err := filepath.Rel(srcPath, absPath)
rootTemplate := template.Must(template.New("root").Parse(string(tpl.RootTemplate())))
err = rootTemplate.Execute(rootFile, p)
if err != nil {
er(err)
}
return relPath
return err
}
// License returns the License object of project.
func (p *Project) License() License {
if p.license.Text == "" && p.license.Name != "None" {
p.license = getLicense()
}
return p.license
// create license
return p.createLicenseFile()
}
// Name returns the name of project, e.g. "github.com/spf13/cobra"
func (p Project) Name() string {
return p.name
func (p *Project) createLicenseFile() error {
data := map[string]interface{}{
"copyright": copyrightLine(),
}
licenseFile, err := os.Create(fmt.Sprintf("%s/LICENSE", p.AbsolutePath))
if err != nil {
return err
}
// CmdPath returns absolute path to directory, where all commands are located.
func (p *Project) CmdPath() string {
if p.absPath == "" {
return ""
}
if p.cmdPath == "" {
p.cmdPath = filepath.Join(p.absPath, findCmdDir(p.absPath))
}
return p.cmdPath
licenseTemplate := template.Must(template.New("license").Parse(p.Legal.Text))
return licenseTemplate.Execute(licenseFile, data)
}
// findCmdDir checks if base of absPath is cmd dir and returns it or
// looks for existing cmd dir in absPath.
func findCmdDir(absPath string) string {
if !exists(absPath) || isEmpty(absPath) {
return "cmd"
func (c *Command) Create() error {
cmdFile, err := os.Create(fmt.Sprintf("%s/cmd/%s.go", c.AbsolutePath, c.CmdName))
if err != nil {
return err
}
defer cmdFile.Close()
if isCmdDir(absPath) {
return filepath.Base(absPath)
commandTemplate := template.Must(template.New("sub").Parse(string(tpl.AddCommandTemplate())))
err = commandTemplate.Execute(cmdFile, c)
if err != nil {
return err
}
files, _ := filepath.Glob(filepath.Join(absPath, "c*"))
for _, file := range files {
if isCmdDir(file) {
return filepath.Base(file)
}
}
return "cmd"
}
// isCmdDir checks if base of name is one of cmdDir.
func isCmdDir(name string) bool {
name = filepath.Base(name)
for _, cmdDir := range []string{"cmd", "cmds", "command", "commands"} {
if name == cmdDir {
return true
}
}
return false
}
// AbsPath returns absolute path of project.
func (p Project) AbsPath() string {
return p.absPath
}
// SrcPath returns absolute path to $GOPATH/src where project is located.
func (p *Project) SrcPath() string {
if p.srcPath != "" {
return p.srcPath
}
if p.absPath == "" {
p.srcPath = srcPaths[0]
return p.srcPath
}
for _, srcPath := range srcPaths {
if filepathHasPrefix(p.absPath, srcPath) {
p.srcPath = srcPath
break
}
}
return p.srcPath
}
func filepathHasPrefix(path string, prefix string) bool {
if len(path) <= len(prefix) {
return false
}
if runtime.GOOS == "windows" {
// Paths in windows are case-insensitive.
return strings.EqualFold(path[0:len(prefix)], prefix)
}
return path[0:len(prefix)] == prefix
return nil
}

View file

@ -1,24 +1,3 @@
package cmd
import (
"testing"
)
func TestFindExistingPackage(t *testing.T) {
path := findPackage("github.com/spf13/cobra")
if path == "" {
t.Fatal("findPackage didn't find the existing package")
}
if !hasGoPathPrefix(path) {
t.Fatalf("%q is not in GOPATH, but must be", path)
}
}
func hasGoPathPrefix(path string) bool {
for _, srcPath := range srcPaths {
if filepathHasPrefix(path, srcPath) {
return true
}
}
return false
}
/* todo: write tests */

View file

@ -23,7 +23,8 @@ import (
var (
// Used for flags.
cfgFile, userLicense string
cfgFile string
userLicense string
rootCmd = &cobra.Command{
Use: "cobra",

View file

@ -1,17 +1,18 @@
// Copyright © 2017 NAME HERE <EMAIL ADDRESS>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
Copyright © 2019 NAME HERE <EMAIL ADDRESS>
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package main
import "github.com/spf13/testproject/cmd"

View file

@ -1,30 +1,34 @@
// Copyright © 2017 NAME HERE <EMAIL ADDRESS>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
Copyright © 2019 NAME HERE <EMAIL ADDRESS>
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cmd
import (
"fmt"
"os"
"github.com/spf13/cobra"
homedir "github.com/mitchellh/go-homedir"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)
var cfgFile string
// rootCmd represents the base command when called without any subcommands
var rootCmd = &cobra.Command{
Use: "testproject",
@ -55,13 +59,16 @@ func init() {
// Here you will define your flags and configuration settings.
// Cobra supports persistent flags, which, if defined here,
// will be global for your application.
rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.testproject.yaml)")
// Cobra also supports local flags, which will only run
// when this action is called directly.
rootCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle")
}
// initConfig reads in config file and ENV variables if set.
func initConfig() {
if cfgFile != "" {
@ -87,3 +94,4 @@ func initConfig() {
fmt.Println("Using config file:", viper.ConfigFileUsed())
}
}

View file

@ -1,17 +1,18 @@
// Copyright © 2017 NAME HERE <EMAIL ADDRESS>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
Copyright © 2019 NAME HERE <EMAIL ADDRESS>
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cmd
import (

153
cobra/tpl/main.go Normal file
View file

@ -0,0 +1,153 @@
package tpl
func MainTemplate() []byte {
return []byte(`/*
{{ .Copyright }}
{{ if .Legal.Header }}{{ .Legal.Header }}{{ end }}
*/
package main
import "{{ .PkgName }}/cmd"
func main() {
cmd.Execute()
}
`)
}
func RootTemplate() []byte {
return []byte(`/*
{{ .Copyright }}
{{ if .Legal.Header }}{{ .Legal.Header }}{{ end }}
*/
package cmd
import (
"fmt"
"os"
"github.com/spf13/cobra"
{{ if .Viper }}
homedir "github.com/mitchellh/go-homedir"
"github.com/spf13/viper"
{{ end }}
)
{{ if .Viper }}
var cfgFile string
{{ end }}
// rootCmd represents the base command when called without any subcommands
var rootCmd = &cobra.Command{
Use: "{{ .AppName }}",
Short: "A brief description of your application",
Long: ` + "`" + `A longer description that spans multiple lines and likely contains
examples and usage of using your application. For example:
Cobra is a CLI library for Go that empowers applications.
This application is a tool to generate the needed files
to quickly create a Cobra application.` + "`" + `,
// Uncomment the following line if your bare application
// has an action associated with it:
// Run: func(cmd *cobra.Command, args []string) { },
}
// Execute adds all child commands to the root command and sets flags appropriately.
// This is called by main.main(). It only needs to happen once to the rootCmd.
func Execute() {
if err := rootCmd.Execute(); err != nil {
fmt.Println(err)
os.Exit(1)
}
}
func init() {
{{- if .Viper }}
cobra.OnInitialize(initConfig)
{{ end }}
// Here you will define your flags and configuration settings.
// Cobra supports persistent flags, which, if defined here,
// will be global for your application.
{{ if .Viper }}
rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.{{ .AppName }}.yaml)")
{{ else }}
// rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.{{ .AppName }}.yaml)")
{{ end }}
// Cobra also supports local flags, which will only run
// when this action is called directly.
rootCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle")
}
{{ if .Viper }}
// initConfig reads in config file and ENV variables if set.
func initConfig() {
if cfgFile != "" {
// Use config file from the flag.
viper.SetConfigFile(cfgFile)
} else {
// Find home directory.
home, err := homedir.Dir()
if err != nil {
fmt.Println(err)
os.Exit(1)
}
// Search config in home directory with name ".{{ .AppName }}" (without extension).
viper.AddConfigPath(home)
viper.SetConfigName(".{{ .AppName }}")
}
viper.AutomaticEnv() // read in environment variables that match
// If a config file is found, read it in.
if err := viper.ReadInConfig(); err == nil {
fmt.Println("Using config file:", viper.ConfigFileUsed())
}
}
{{ end }}
`)
}
func AddCommandTemplate() []byte {
return []byte(`/*
{{ .Project.Copyright }}
{{ if .Legal.Header }}{{ .Legal.Header }}{{ end }}
*/
package cmd
import (
"fmt"
"github.com/spf13/cobra"
)
// {{ .CmdName }}Cmd represents the {{ .CmdName }} command
var {{ .CmdName }}Cmd = &cobra.Command{
Use: "{{ .CmdName }}",
Short: "A brief description of your command",
Long: ` + "`" + `A longer description that spans multiple lines and likely contains examples
and usage of using your command. For example:
Cobra is a CLI library for Go that empowers applications.
This application is a tool to generate the needed files
to quickly create a Cobra application.` + "`" + `,
Run: func(cmd *cobra.Command, args []string) {
fmt.Println("{{ .CmdName }} called")
},
}
func init() {
{{ .CmdParent }}.AddCommand({{ .CmdName }}Cmd)
// Here you will define your flags and configuration settings.
// Cobra supports Persistent Flags which will work for this command
// and all subcommands, e.g.:
// {{ .CmdName }}Cmd.PersistentFlags().String("foo", "", "A help for foo")
// Cobra supports local flags which will only run when this command
// is called directly, e.g.:
// {{ .CmdName }}Cmd.Flags().BoolP("toggle", "t", false, "Help message for toggle")
}
`)
}

View file

@ -177,8 +177,6 @@ type Command struct {
// that we can use on every pflag set and children commands
globNormFunc func(f *flag.FlagSet, name string) flag.NormalizedName
// output is an output writer defined by user.
output io.Writer
// usageFunc is usage func defined by user.
usageFunc func(*Command) error
// usageTemplate is usage template defined by user.
@ -195,6 +193,13 @@ type Command struct {
helpCommand *Command
// versionTemplate is the version template defined by user.
versionTemplate string
// inReader is a reader defined by the user that replaces stdin
inReader io.Reader
// outWriter is a writer defined by the user that replaces stdout
outWriter io.Writer
// errWriter is a writer defined by the user that replaces stderr
errWriter io.Writer
}
// SetArgs sets arguments for the command. It is set to os.Args[1:] by default, if desired, can be overridden
@ -205,8 +210,28 @@ func (c *Command) SetArgs(a []string) {
// SetOutput sets the destination for usage and error messages.
// If output is nil, os.Stderr is used.
// Deprecated: Use SetOut and/or SetErr instead
func (c *Command) SetOutput(output io.Writer) {
c.output = output
c.outWriter = output
c.errWriter = output
}
// SetOut sets the destination for usage messages.
// If newOut is nil, os.Stdout is used.
func (c *Command) SetOut(newOut io.Writer) {
c.outWriter = newOut
}
// SetErr sets the destination for error messages.
// If newErr is nil, os.Stderr is used.
func (c *Command) SetErr(newErr io.Writer) {
c.errWriter = newErr
}
// SetOut sets the source for input data
// If newIn is nil, os.Stdin is used.
func (c *Command) SetIn(newIn io.Reader) {
c.inReader = newIn
}
// SetUsageFunc sets usage function. Usage can be defined by application.
@ -267,9 +292,19 @@ func (c *Command) OutOrStderr() io.Writer {
return c.getOut(os.Stderr)
}
// ErrOrStderr returns output to stderr
func (c *Command) ErrOrStderr() io.Writer {
return c.getErr(os.Stderr)
}
// ErrOrStderr returns output to stderr
func (c *Command) InOrStdin() io.Reader {
return c.getIn(os.Stdin)
}
func (c *Command) getOut(def io.Writer) io.Writer {
if c.output != nil {
return c.output
if c.outWriter != nil {
return c.outWriter
}
if c.HasParent() {
return c.parent.getOut(def)
@ -277,6 +312,26 @@ func (c *Command) getOut(def io.Writer) io.Writer {
return def
}
func (c *Command) getErr(def io.Writer) io.Writer {
if c.errWriter != nil {
return c.errWriter
}
if c.HasParent() {
return c.parent.getErr(def)
}
return def
}
func (c *Command) getIn(def io.Reader) io.Reader {
if c.inReader != nil {
return c.inReader
}
if c.HasParent() {
return c.parent.getIn(def)
}
return def
}
// UsageFunc returns either the function set by SetUsageFunc for this command
// or a parent, or it returns a default usage function.
func (c *Command) UsageFunc() (f func(*Command) error) {
@ -329,13 +384,22 @@ func (c *Command) Help() error {
return nil
}
// UsageString return usage string.
// UsageString returns usage string.
func (c *Command) UsageString() string {
tmpOutput := c.output
// Storing normal writers
tmpOutput := c.outWriter
tmpErr := c.errWriter
bb := new(bytes.Buffer)
c.SetOutput(bb)
c.outWriter = bb
c.errWriter = bb
c.Usage()
c.output = tmpOutput
// Setting things back to normal
c.outWriter = tmpOutput
c.errWriter = tmpErr
return bb.String()
}
@ -1068,6 +1132,21 @@ func (c *Command) Printf(format string, i ...interface{}) {
c.Print(fmt.Sprintf(format, i...))
}
// PrintErr is a convenience method to Print to the defined Err output, fallback to Stderr if not set.
func (c *Command) PrintErr(i ...interface{}) {
fmt.Fprint(c.ErrOrStderr(), i...)
}
// PrintErrln is a convenience method to Println to the defined Err output, fallback to Stderr if not set.
func (c *Command) PrintErrln(i ...interface{}) {
c.Print(fmt.Sprintln(i...))
}
// PrintErrf is a convenience method to Printf to the defined Err output, fallback to Stderr if not set.
func (c *Command) PrintErrf(format string, i ...interface{}) {
c.Print(fmt.Sprintf(format, i...))
}
// CommandPath returns the full path to this command.
func (c *Command) CommandPath() string {
if c.HasParent() {

View file

@ -1381,6 +1381,46 @@ func TestSetOutput(t *testing.T) {
}
}
func TestSetOut(t *testing.T) {
c := &Command{}
c.SetOut(nil)
if out := c.OutOrStdout(); out != os.Stdout {
t.Errorf("Expected setting output to nil to revert back to stdout")
}
}
func TestSetErr(t *testing.T) {
c := &Command{}
c.SetErr(nil)
if out := c.ErrOrStderr(); out != os.Stderr {
t.Errorf("Expected setting error to nil to revert back to stderr")
}
}
func TestSetIn(t *testing.T) {
c := &Command{}
c.SetIn(nil)
if out := c.InOrStdin(); out != os.Stdin {
t.Errorf("Expected setting input to nil to revert back to stdin")
}
}
func TestUsageStringRedirected(t *testing.T) {
c := &Command{}
c.usageFunc = func(cmd *Command) error {
cmd.Print("[stdout1]")
cmd.PrintErr("[stderr2]")
cmd.Print("[stdout3]")
return nil
}
expected := "[stdout1][stderr2][stdout3]"
if got := c.UsageString(); got != expected {
t.Errorf("Expected usage string to consider both stdout and stderr")
}
}
func TestFlagErrorFunc(t *testing.T) {
c := &Command{Use: "c", Run: emptyRun}

100
powershell_completions.go Normal file
View file

@ -0,0 +1,100 @@
// PowerShell completions are based on the amazing work from clap:
// https://github.com/clap-rs/clap/blob/3294d18efe5f264d12c9035f404c7d189d4824e1/src/completions/powershell.rs
//
// The generated scripts require PowerShell v5.0+ (which comes Windows 10, but
// can be downloaded separately for windows 7 or 8.1).
package cobra
import (
"bytes"
"fmt"
"io"
"os"
"strings"
"github.com/spf13/pflag"
)
var powerShellCompletionTemplate = `using namespace System.Management.Automation
using namespace System.Management.Automation.Language
Register-ArgumentCompleter -Native -CommandName '%s' -ScriptBlock {
param($wordToComplete, $commandAst, $cursorPosition)
$commandElements = $commandAst.CommandElements
$command = @(
'%s'
for ($i = 1; $i -lt $commandElements.Count; $i++) {
$element = $commandElements[$i]
if ($element -isnot [StringConstantExpressionAst] -or
$element.StringConstantType -ne [StringConstantType]::BareWord -or
$element.Value.StartsWith('-')) {
break
}
$element.Value
}
) -join ';'
$completions = @(switch ($command) {%s
})
$completions.Where{ $_.CompletionText -like "$wordToComplete*" } |
Sort-Object -Property ListItemText
}`
func generatePowerShellSubcommandCases(out io.Writer, cmd *Command, previousCommandName string) {
var cmdName string
if previousCommandName == "" {
cmdName = cmd.Name()
} else {
cmdName = fmt.Sprintf("%s;%s", previousCommandName, cmd.Name())
}
fmt.Fprintf(out, "\n '%s' {", cmdName)
cmd.Flags().VisitAll(func(flag *pflag.Flag) {
if nonCompletableFlag(flag) {
return
}
usage := escapeStringForPowerShell(flag.Usage)
if len(flag.Shorthand) > 0 {
fmt.Fprintf(out, "\n [CompletionResult]::new('-%s', '%s', [CompletionResultType]::ParameterName, '%s')", flag.Shorthand, flag.Shorthand, usage)
}
fmt.Fprintf(out, "\n [CompletionResult]::new('--%s', '%s', [CompletionResultType]::ParameterName, '%s')", flag.Name, flag.Name, usage)
})
for _, subCmd := range cmd.Commands() {
usage := escapeStringForPowerShell(subCmd.Short)
fmt.Fprintf(out, "\n [CompletionResult]::new('%s', '%s', [CompletionResultType]::ParameterValue, '%s')", subCmd.Name(), subCmd.Name(), usage)
}
fmt.Fprint(out, "\n break\n }")
for _, subCmd := range cmd.Commands() {
generatePowerShellSubcommandCases(out, subCmd, cmdName)
}
}
func escapeStringForPowerShell(s string) string {
return strings.Replace(s, "'", "''", -1)
}
// GenPowerShellCompletion generates PowerShell completion file and writes to the passed writer.
func (c *Command) GenPowerShellCompletion(w io.Writer) error {
buf := new(bytes.Buffer)
var subCommandCases bytes.Buffer
generatePowerShellSubcommandCases(&subCommandCases, c, "")
fmt.Fprintf(buf, powerShellCompletionTemplate, c.Name(), c.Name(), subCommandCases.String())
_, err := buf.WriteTo(w)
return err
}
// GenPowerShellCompletionFile generates PowerShell completion file.
func (c *Command) GenPowerShellCompletionFile(filename string) error {
outFile, err := os.Create(filename)
if err != nil {
return err
}
defer outFile.Close()
return c.GenPowerShellCompletion(outFile)
}

14
powershell_completions.md Normal file
View file

@ -0,0 +1,14 @@
# Generating PowerShell Completions For Your Own cobra.Command
Cobra can generate PowerShell completion scripts. Users need PowerShell version 5.0 or above, which comes with Windows 10 and can be downloaded separately for Windows 7 or 8.1. They can then write the completions to a file and source this file from their PowerShell profile, which is referenced by the `$Profile` environment variable. See `Get-Help about_Profiles` for more info about PowerShell profiles.
# What's supported
- Completion for subcommands using their `.Short` description
- Completion for non-hidden flags using their `.Name` and `.Shorthand`
# What's not yet supported
- Command aliases
- Required, filename or custom flags (they will work like normal flags)
- Custom completion scripts

View file

@ -0,0 +1,122 @@
package cobra
import (
"bytes"
"strings"
"testing"
)
func TestPowerShellCompletion(t *testing.T) {
tcs := []struct {
name string
root *Command
expectedExpressions []string
}{
{
name: "trivial",
root: &Command{Use: "trivialapp"},
expectedExpressions: []string{
"Register-ArgumentCompleter -Native -CommandName 'trivialapp' -ScriptBlock",
"$command = @(\n 'trivialapp'\n",
},
},
{
name: "tree",
root: func() *Command {
r := &Command{Use: "tree"}
sub1 := &Command{Use: "sub1"}
r.AddCommand(sub1)
sub11 := &Command{Use: "sub11"}
sub12 := &Command{Use: "sub12"}
sub1.AddCommand(sub11)
sub1.AddCommand(sub12)
sub2 := &Command{Use: "sub2"}
r.AddCommand(sub2)
sub21 := &Command{Use: "sub21"}
sub22 := &Command{Use: "sub22"}
sub2.AddCommand(sub21)
sub2.AddCommand(sub22)
return r
}(),
expectedExpressions: []string{
"'tree'",
"[CompletionResult]::new('sub1', 'sub1', [CompletionResultType]::ParameterValue, '')",
"[CompletionResult]::new('sub2', 'sub2', [CompletionResultType]::ParameterValue, '')",
"'tree;sub1'",
"[CompletionResult]::new('sub11', 'sub11', [CompletionResultType]::ParameterValue, '')",
"[CompletionResult]::new('sub12', 'sub12', [CompletionResultType]::ParameterValue, '')",
"'tree;sub1;sub11'",
"'tree;sub1;sub12'",
"'tree;sub2'",
"[CompletionResult]::new('sub21', 'sub21', [CompletionResultType]::ParameterValue, '')",
"[CompletionResult]::new('sub22', 'sub22', [CompletionResultType]::ParameterValue, '')",
"'tree;sub2;sub21'",
"'tree;sub2;sub22'",
},
},
{
name: "flags",
root: func() *Command {
r := &Command{Use: "flags"}
r.Flags().StringP("flag1", "a", "", "")
r.Flags().String("flag2", "", "")
sub1 := &Command{Use: "sub1"}
sub1.Flags().StringP("flag3", "c", "", "")
r.AddCommand(sub1)
return r
}(),
expectedExpressions: []string{
"'flags'",
"[CompletionResult]::new('-a', 'a', [CompletionResultType]::ParameterName, '')",
"[CompletionResult]::new('--flag1', 'flag1', [CompletionResultType]::ParameterName, '')",
"[CompletionResult]::new('--flag2', 'flag2', [CompletionResultType]::ParameterName, '')",
"[CompletionResult]::new('sub1', 'sub1', [CompletionResultType]::ParameterValue, '')",
"'flags;sub1'",
"[CompletionResult]::new('-c', 'c', [CompletionResultType]::ParameterName, '')",
"[CompletionResult]::new('--flag3', 'flag3', [CompletionResultType]::ParameterName, '')",
},
},
{
name: "usage",
root: func() *Command {
r := &Command{Use: "usage"}
r.Flags().String("flag", "", "this describes the usage of the 'flag' flag")
sub1 := &Command{
Use: "sub1",
Short: "short describes 'sub1'",
}
r.AddCommand(sub1)
return r
}(),
expectedExpressions: []string{
"[CompletionResult]::new('--flag', 'flag', [CompletionResultType]::ParameterName, 'this describes the usage of the ''flag'' flag')",
"[CompletionResult]::new('sub1', 'sub1', [CompletionResultType]::ParameterValue, 'short describes ''sub1''')",
},
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
buf := new(bytes.Buffer)
tc.root.GenPowerShellCompletion(buf)
output := buf.String()
for _, expectedExpression := range tc.expectedExpressions {
if !strings.Contains(output, expectedExpression) {
t.Errorf("Expected completion to contain %q somewhere; got %q", expectedExpression, output)
}
}
})
}
}

85
shell_completions.go Normal file
View file

@ -0,0 +1,85 @@
package cobra
import (
"github.com/spf13/pflag"
)
// MarkFlagRequired adds the BashCompOneRequiredFlag annotation to the named flag if it exists,
// and causes your command to report an error if invoked without the flag.
func (c *Command) MarkFlagRequired(name string) error {
return MarkFlagRequired(c.Flags(), name)
}
// MarkPersistentFlagRequired adds the BashCompOneRequiredFlag annotation to the named persistent flag if it exists,
// and causes your command to report an error if invoked without the flag.
func (c *Command) MarkPersistentFlagRequired(name string) error {
return MarkFlagRequired(c.PersistentFlags(), name)
}
// MarkFlagRequired adds the BashCompOneRequiredFlag annotation to the named flag if it exists,
// and causes your command to report an error if invoked without the flag.
func MarkFlagRequired(flags *pflag.FlagSet, name string) error {
return flags.SetAnnotation(name, BashCompOneRequiredFlag, []string{"true"})
}
// MarkFlagFilename adds the BashCompFilenameExt annotation to the named flag, if it exists.
// Generated bash autocompletion will select filenames for the flag, limiting to named extensions if provided.
func (c *Command) MarkFlagFilename(name string, extensions ...string) error {
return MarkFlagFilename(c.Flags(), name, extensions...)
}
// MarkFlagCustom adds the BashCompCustom annotation to the named flag, if it exists.
// Generated bash autocompletion will call the bash function f for the flag.
func (c *Command) MarkFlagCustom(name string, f string) error {
return MarkFlagCustom(c.Flags(), name, f)
}
// MarkPersistentFlagFilename instructs the various shell completion
// implementations to limit completions for this persistent flag to the
// specified extensions (patterns).
//
// Shell Completion compatibility matrix: bash, zsh
func (c *Command) MarkPersistentFlagFilename(name string, extensions ...string) error {
return MarkFlagFilename(c.PersistentFlags(), name, extensions...)
}
// MarkFlagFilename instructs the various shell completion implementations to
// limit completions for this flag to the specified extensions (patterns).
//
// Shell Completion compatibility matrix: bash, zsh
func MarkFlagFilename(flags *pflag.FlagSet, name string, extensions ...string) error {
return flags.SetAnnotation(name, BashCompFilenameExt, extensions)
}
// MarkFlagCustom instructs the various shell completion implementations to
// limit completions for this flag to the specified extensions (patterns).
//
// Shell Completion compatibility matrix: bash, zsh
func MarkFlagCustom(flags *pflag.FlagSet, name string, f string) error {
return flags.SetAnnotation(name, BashCompCustom, []string{f})
}
// MarkFlagDirname instructs the various shell completion implementations to
// complete only directories with this named flag.
//
// Shell Completion compatibility matrix: zsh
func (c *Command) MarkFlagDirname(name string) error {
return MarkFlagDirname(c.Flags(), name)
}
// MarkPersistentFlagDirname instructs the various shell completion
// implementations to complete only directories with this persistent named flag.
//
// Shell Completion compatibility matrix: zsh
func (c *Command) MarkPersistentFlagDirname(name string) error {
return MarkFlagDirname(c.PersistentFlags(), name)
}
// MarkFlagDirname instructs the various shell completion implementations to
// complete only directories with this specified flag.
//
// Shell Completion compatibility matrix: zsh
func MarkFlagDirname(flags *pflag.FlagSet, name string) error {
zshPattern := "-(/)"
return flags.SetAnnotation(name, zshCompDirname, []string{zshPattern})
}

View file

@ -1,13 +1,102 @@
package cobra
import (
"bytes"
"encoding/json"
"fmt"
"io"
"os"
"sort"
"strings"
"text/template"
"github.com/spf13/pflag"
)
const (
zshCompArgumentAnnotation = "cobra_annotations_zsh_completion_argument_annotation"
zshCompArgumentFilenameComp = "cobra_annotations_zsh_completion_argument_file_completion"
zshCompArgumentWordComp = "cobra_annotations_zsh_completion_argument_word_completion"
zshCompDirname = "cobra_annotations_zsh_dirname"
)
var (
zshCompFuncMap = template.FuncMap{
"genZshFuncName": zshCompGenFuncName,
"extractFlags": zshCompExtractFlag,
"genFlagEntryForZshArguments": zshCompGenFlagEntryForArguments,
"extractArgsCompletions": zshCompExtractArgumentCompletionHintsForRendering,
}
zshCompletionText = `
{{/* should accept Command (that contains subcommands) as parameter */}}
{{define "argumentsC" -}}
{{ $cmdPath := genZshFuncName .}}
function {{$cmdPath}} {
local -a commands
_arguments -C \{{- range extractFlags .}}
{{genFlagEntryForZshArguments .}} \{{- end}}
"1: :->cmnds" \
"*::arg:->args"
case $state in
cmnds)
commands=({{range .Commands}}{{if not .Hidden}}
"{{.Name}}:{{.Short}}"{{end}}{{end}}
)
_describe "command" commands
;;
esac
case "$words[1]" in {{- range .Commands}}{{if not .Hidden}}
{{.Name}})
{{$cmdPath}}_{{.Name}}
;;{{end}}{{end}}
esac
}
{{range .Commands}}{{if not .Hidden}}
{{template "selectCmdTemplate" .}}
{{- end}}{{end}}
{{- end}}
{{/* should accept Command without subcommands as parameter */}}
{{define "arguments" -}}
function {{genZshFuncName .}} {
{{" _arguments"}}{{range extractFlags .}} \
{{genFlagEntryForZshArguments . -}}
{{end}}{{range extractArgsCompletions .}} \
{{.}}{{end}}
}
{{end}}
{{/* dispatcher for commands with or without subcommands */}}
{{define "selectCmdTemplate" -}}
{{if .Hidden}}{{/* ignore hidden*/}}{{else -}}
{{if .Commands}}{{template "argumentsC" .}}{{else}}{{template "arguments" .}}{{end}}
{{- end}}
{{- end}}
{{/* template entry point */}}
{{define "Main" -}}
#compdef _{{.Name}} {{.Name}}
{{template "selectCmdTemplate" .}}
{{end}}
`
)
// zshCompArgsAnnotation is used to encode/decode zsh completion for
// arguments to/from Command.Annotations.
type zshCompArgsAnnotation map[int]zshCompArgHint
type zshCompArgHint struct {
// Indicates the type of the completion to use. One of:
// zshCompArgumentFilenameComp or zshCompArgumentWordComp
Tipe string `json:"type"`
// A value for the type above (globs for file completion or words)
Options []string `json:"options"`
}
// GenZshCompletionFile generates zsh completion file.
func (c *Command) GenZshCompletionFile(filename string) error {
outFile, err := os.Create(filename)
@ -19,108 +108,229 @@ func (c *Command) GenZshCompletionFile(filename string) error {
return c.GenZshCompletion(outFile)
}
// GenZshCompletion generates a zsh completion file and writes to the passed writer.
// GenZshCompletion generates a zsh completion file and writes to the passed
// writer. The completion always run on the root command regardless of the
// command it was called from.
func (c *Command) GenZshCompletion(w io.Writer) error {
buf := new(bytes.Buffer)
tmpl, err := template.New("Main").Funcs(zshCompFuncMap).Parse(zshCompletionText)
if err != nil {
return fmt.Errorf("error creating zsh completion template: %v", err)
}
return tmpl.Execute(w, c.Root())
}
writeHeader(buf, c)
maxDepth := maxDepth(c)
writeLevelMapping(buf, maxDepth)
writeLevelCases(buf, maxDepth, c)
_, err := buf.WriteTo(w)
// MarkZshCompPositionalArgumentFile marks the specified argument (first
// argument is 1) as completed by file selection. patterns (e.g. "*.txt") are
// optional - if not provided the completion will search for all files.
func (c *Command) MarkZshCompPositionalArgumentFile(argPosition int, patterns ...string) error {
if argPosition < 1 {
return fmt.Errorf("Invalid argument position (%d)", argPosition)
}
annotation, err := c.zshCompGetArgsAnnotations()
if err != nil {
return err
}
func writeHeader(w io.Writer, cmd *Command) {
fmt.Fprintf(w, "#compdef %s\n\n", cmd.Name())
if c.zshcompArgsAnnotationnIsDuplicatePosition(annotation, argPosition) {
return fmt.Errorf("Duplicate annotation for positional argument at index %d", argPosition)
}
annotation[argPosition] = zshCompArgHint{
Tipe: zshCompArgumentFilenameComp,
Options: patterns,
}
return c.zshCompSetArgsAnnotations(annotation)
}
func maxDepth(c *Command) int {
if len(c.Commands()) == 0 {
return 0
// MarkZshCompPositionalArgumentWords marks the specified positional argument
// (first argument is 1) as completed by the provided words. At east one word
// must be provided, spaces within words will be offered completion with
// "word\ word".
func (c *Command) MarkZshCompPositionalArgumentWords(argPosition int, words ...string) error {
if argPosition < 1 {
return fmt.Errorf("Invalid argument position (%d)", argPosition)
}
maxDepthSub := 0
for _, s := range c.Commands() {
subDepth := maxDepth(s)
if subDepth > maxDepthSub {
maxDepthSub = subDepth
if len(words) == 0 {
return fmt.Errorf("Trying to set empty word list for positional argument %d", argPosition)
}
annotation, err := c.zshCompGetArgsAnnotations()
if err != nil {
return err
}
return 1 + maxDepthSub
if c.zshcompArgsAnnotationnIsDuplicatePosition(annotation, argPosition) {
return fmt.Errorf("Duplicate annotation for positional argument at index %d", argPosition)
}
annotation[argPosition] = zshCompArgHint{
Tipe: zshCompArgumentWordComp,
Options: words,
}
return c.zshCompSetArgsAnnotations(annotation)
}
func writeLevelMapping(w io.Writer, numLevels int) {
fmt.Fprintln(w, `_arguments \`)
for i := 1; i <= numLevels; i++ {
fmt.Fprintf(w, ` '%d: :->level%d' \`, i, i)
fmt.Fprintln(w)
func zshCompExtractArgumentCompletionHintsForRendering(c *Command) ([]string, error) {
var result []string
annotation, err := c.zshCompGetArgsAnnotations()
if err != nil {
return nil, err
}
fmt.Fprintf(w, ` '%d: :%s'`, numLevels+1, "_files")
fmt.Fprintln(w)
for k, v := range annotation {
s, err := zshCompRenderZshCompArgHint(k, v)
if err != nil {
return nil, err
}
result = append(result, s)
}
if len(c.ValidArgs) > 0 {
if _, positionOneExists := annotation[1]; !positionOneExists {
s, err := zshCompRenderZshCompArgHint(1, zshCompArgHint{
Tipe: zshCompArgumentWordComp,
Options: c.ValidArgs,
})
if err != nil {
return nil, err
}
result = append(result, s)
}
}
sort.Strings(result)
return result, nil
}
func writeLevelCases(w io.Writer, maxDepth int, root *Command) {
fmt.Fprintln(w, "case $state in")
defer fmt.Fprintln(w, "esac")
for i := 1; i <= maxDepth; i++ {
fmt.Fprintf(w, " level%d)\n", i)
writeLevel(w, root, i)
fmt.Fprintln(w, " ;;")
func zshCompRenderZshCompArgHint(i int, z zshCompArgHint) (string, error) {
switch t := z.Tipe; t {
case zshCompArgumentFilenameComp:
var globs []string
for _, g := range z.Options {
globs = append(globs, fmt.Sprintf(`-g "%s"`, g))
}
return fmt.Sprintf(`'%d: :_files %s'`, i, strings.Join(globs, " ")), nil
case zshCompArgumentWordComp:
var words []string
for _, w := range z.Options {
words = append(words, fmt.Sprintf("%q", w))
}
return fmt.Sprintf(`'%d: :(%s)'`, i, strings.Join(words, " ")), nil
default:
return "", fmt.Errorf("Invalid zsh argument completion annotation: %s", t)
}
fmt.Fprintln(w, " *)")
fmt.Fprintln(w, " _arguments '*: :_files'")
fmt.Fprintln(w, " ;;")
}
func writeLevel(w io.Writer, root *Command, i int) {
fmt.Fprintf(w, " case $words[%d] in\n", i)
defer fmt.Fprintln(w, " esac")
commands := filterByLevel(root, i)
byParent := groupByParent(commands)
for p, c := range byParent {
names := names(c)
fmt.Fprintf(w, " %s)\n", p)
fmt.Fprintf(w, " _arguments '%d: :(%s)'\n", i, strings.Join(names, " "))
fmt.Fprintln(w, " ;;")
}
fmt.Fprintln(w, " *)")
fmt.Fprintln(w, " _arguments '*: :_files'")
fmt.Fprintln(w, " ;;")
func (c *Command) zshcompArgsAnnotationnIsDuplicatePosition(annotation zshCompArgsAnnotation, position int) bool {
_, dup := annotation[position]
return dup
}
func filterByLevel(c *Command, l int) []*Command {
cs := make([]*Command, 0)
if l == 0 {
cs = append(cs, c)
return cs
func (c *Command) zshCompGetArgsAnnotations() (zshCompArgsAnnotation, error) {
annotation := make(zshCompArgsAnnotation)
annotationString, ok := c.Annotations[zshCompArgumentAnnotation]
if !ok {
return annotation, nil
}
for _, s := range c.Commands() {
cs = append(cs, filterByLevel(s, l-1)...)
err := json.Unmarshal([]byte(annotationString), &annotation)
if err != nil {
return annotation, fmt.Errorf("Error unmarshaling zsh argument annotation: %v", err)
}
return cs
return annotation, nil
}
func groupByParent(commands []*Command) map[string][]*Command {
m := make(map[string][]*Command)
for _, c := range commands {
parent := c.Parent()
if parent == nil {
continue
func (c *Command) zshCompSetArgsAnnotations(annotation zshCompArgsAnnotation) error {
jsn, err := json.Marshal(annotation)
if err != nil {
return fmt.Errorf("Error marshaling zsh argument annotation: %v", err)
}
m[parent.Name()] = append(m[parent.Name()], c)
if c.Annotations == nil {
c.Annotations = make(map[string]string)
}
return m
c.Annotations[zshCompArgumentAnnotation] = string(jsn)
return nil
}
func names(commands []*Command) []string {
ns := make([]string, len(commands))
for i, c := range commands {
ns[i] = c.Name()
func zshCompGenFuncName(c *Command) string {
if c.HasParent() {
return zshCompGenFuncName(c.Parent()) + "_" + c.Name()
}
return ns
return "_" + c.Name()
}
func zshCompExtractFlag(c *Command) []*pflag.Flag {
var flags []*pflag.Flag
c.LocalFlags().VisitAll(func(f *pflag.Flag) {
if !f.Hidden {
flags = append(flags, f)
}
})
c.InheritedFlags().VisitAll(func(f *pflag.Flag) {
if !f.Hidden {
flags = append(flags, f)
}
})
return flags
}
// zshCompGenFlagEntryForArguments returns an entry that matches _arguments
// zsh-completion parameters. It's too complicated to generate in a template.
func zshCompGenFlagEntryForArguments(f *pflag.Flag) string {
if f.Name == "" || f.Shorthand == "" {
return zshCompGenFlagEntryForSingleOptionFlag(f)
}
return zshCompGenFlagEntryForMultiOptionFlag(f)
}
func zshCompGenFlagEntryForSingleOptionFlag(f *pflag.Flag) string {
var option, multiMark, extras string
if zshCompFlagCouldBeSpecifiedMoreThenOnce(f) {
multiMark = "*"
}
option = "--" + f.Name
if option == "--" {
option = "-" + f.Shorthand
}
extras = zshCompGenFlagEntryExtras(f)
return fmt.Sprintf(`'%s%s[%s]%s'`, multiMark, option, zshCompQuoteFlagDescription(f.Usage), extras)
}
func zshCompGenFlagEntryForMultiOptionFlag(f *pflag.Flag) string {
var options, parenMultiMark, curlyMultiMark, extras string
if zshCompFlagCouldBeSpecifiedMoreThenOnce(f) {
parenMultiMark = "*"
curlyMultiMark = "\\*"
}
options = fmt.Sprintf(`'(%s-%s %s--%s)'{%s-%s,%s--%s}`,
parenMultiMark, f.Shorthand, parenMultiMark, f.Name, curlyMultiMark, f.Shorthand, curlyMultiMark, f.Name)
extras = zshCompGenFlagEntryExtras(f)
return fmt.Sprintf(`%s'[%s]%s'`, options, zshCompQuoteFlagDescription(f.Usage), extras)
}
func zshCompGenFlagEntryExtras(f *pflag.Flag) string {
if f.NoOptDefVal != "" {
return ""
}
extras := ":" // allow options for flag (even without assistance)
for key, values := range f.Annotations {
switch key {
case zshCompDirname:
extras = fmt.Sprintf(":filename:_files -g %q", values[0])
case BashCompFilenameExt:
extras = ":filename:_files"
for _, pattern := range values {
extras = extras + fmt.Sprintf(` -g "%s"`, pattern)
}
}
}
return extras
}
func zshCompFlagCouldBeSpecifiedMoreThenOnce(f *pflag.Flag) bool {
return strings.Contains(f.Value.Type(), "Slice") ||
strings.Contains(f.Value.Type(), "Array")
}
func zshCompQuoteFlagDescription(s string) string {
return strings.Replace(s, "'", `'\''`, -1)
}

39
zsh_completions.md Normal file
View file

@ -0,0 +1,39 @@
## Generating Zsh Completion for your cobra.Command
Cobra supports native Zsh completion generated from the root `cobra.Command`.
The generated completion script should be put somewhere in your `$fpath` named
`_<YOUR COMMAND>`.
### What's Supported
* Completion for all non-hidden subcommands using their `.Short` description.
* Completion for all non-hidden flags using the following rules:
* Filename completion works by marking the flag with `cmd.MarkFlagFilename...`
family of commands.
* The requirement for argument to the flag is decided by the `.NoOptDefVal`
flag value - if it's empty then completion will expect an argument.
* Flags of one of the various `*Array` and `*Slice` types supports multiple
specifications (with or without argument depending on the specific type).
* Completion of positional arguments using the following rules:
* Argument position for all options below starts at `1`. If argument position
`0` is requested it will raise an error.
* Use `command.MarkZshCompPositionalArgumentFile` to complete filenames. Glob
patterns (e.g. `"*.log"`) are optional - if not specified it will offer to
complete all file types.
* Use `command.MarkZshCompPositionalArgumentWords` to offer specific words for
completion. At least one word is required.
* It's possible to specify completion for some arguments and leave some
unspecified (e.g. offer words for second argument but nothing for first
argument). This will cause no completion for first argument but words
completion for second argument.
* If no argument completion was specified for 1st argument (but optionally was
specified for 2nd) and the command has `ValidArgs` it will be used as
completion options for 1st argument.
* Argument completions only offered for commands with no subcommands.
### What's not yet Supported
* Custom completion scripts are not supported yet (We should probably create zsh
specific one, doesn't make sense to re-use the bash one as the functions will
be different).
* Whatever other feature you're looking for and doesn't exist :)

View file

@ -2,88 +2,474 @@ package cobra
import (
"bytes"
"regexp"
"strings"
"testing"
)
func TestZshCompletion(t *testing.T) {
func TestGenZshCompletion(t *testing.T) {
var debug bool
var option string
tcs := []struct {
name string
root *Command
expectedExpressions []string
invocationArgs []string
skip string
}{
{
name: "simple command",
root: func() *Command {
r := &Command{
Use: "mycommand",
Long: "My Command long description",
Run: emptyRun,
}
r.Flags().BoolVar(&debug, "debug", debug, "description")
return r
}(),
expectedExpressions: []string{
`(?s)function _mycommand {\s+_arguments \\\s+'--debug\[description\]'.*--help.*}`,
"#compdef _mycommand mycommand",
},
},
{
name: "flags with both long and short flags",
root: func() *Command {
r := &Command{
Use: "testcmd",
Long: "long description",
Run: emptyRun,
}
r.Flags().BoolVarP(&debug, "debug", "d", debug, "debug description")
return r
}(),
expectedExpressions: []string{
`'\(-d --debug\)'{-d,--debug}'\[debug description\]'`,
},
},
{
name: "command with subcommands and flags with values",
root: func() *Command {
r := &Command{
Use: "rootcmd",
Long: "Long rootcmd description",
}
d := &Command{
Use: "subcmd1",
Short: "Subcmd1 short description",
Run: emptyRun,
}
e := &Command{
Use: "subcmd2",
Long: "Subcmd2 short description",
Run: emptyRun,
}
r.PersistentFlags().BoolVar(&debug, "debug", debug, "description")
d.Flags().StringVarP(&option, "option", "o", option, "option description")
r.AddCommand(d, e)
return r
}(),
expectedExpressions: []string{
`commands=\(\n\s+"help:.*\n\s+"subcmd1:.*\n\s+"subcmd2:.*\n\s+\)`,
`_arguments \\\n.*'--debug\[description]'`,
`_arguments -C \\\n.*'--debug\[description]'`,
`function _rootcmd_subcmd1 {`,
`function _rootcmd_subcmd1 {`,
`_arguments \\\n.*'\(-o --option\)'{-o,--option}'\[option description]:' \\\n`,
},
},
{
name: "filename completion with and without globs",
root: func() *Command {
var file string
r := &Command{
Use: "mycmd",
Short: "my command short description",
Run: emptyRun,
}
r.Flags().StringVarP(&file, "config", "c", file, "config file")
r.MarkFlagFilename("config")
r.Flags().String("output", "", "output file")
r.MarkFlagFilename("output", "*.log", "*.txt")
return r
}(),
expectedExpressions: []string{
`\n +'\(-c --config\)'{-c,--config}'\[config file]:filename:_files'`,
`:_files -g "\*.log" -g "\*.txt"`,
},
},
{
name: "repeated variables both with and without value",
root: func() *Command {
r := genTestCommand("mycmd", true)
_ = r.Flags().BoolSliceP("debug", "d", []bool{}, "debug usage")
_ = r.Flags().StringArray("option", []string{}, "options")
return r
}(),
expectedExpressions: []string{
`'\*--option\[options]`,
`'\(\*-d \*--debug\)'{\\\*-d,\\\*--debug}`,
},
},
{
name: "generated flags --help and --version should be created even when not executing root cmd",
root: func() *Command {
r := &Command{
Use: "mycmd",
Short: "mycmd short description",
Version: "myversion",
}
s := genTestCommand("sub1", true)
r.AddCommand(s)
return s
}(),
expectedExpressions: []string{
"--version",
"--help",
},
invocationArgs: []string{
"sub1",
},
skip: "--version and --help are currently not generated when not running on root command",
},
{
name: "zsh generation should run on root command",
root: func() *Command {
r := genTestCommand("root", false)
s := genTestCommand("sub1", true)
r.AddCommand(s)
return s
}(),
expectedExpressions: []string{
"function _root {",
},
},
{
name: "flag description with single quote (') shouldn't break quotes in completion file",
root: func() *Command {
r := genTestCommand("root", true)
r.Flags().Bool("private", false, "Don't show public info")
return r
}(),
expectedExpressions: []string{
`--private\[Don'\\''t show public info]`,
},
},
{
name: "argument completion for file with and without patterns",
root: func() *Command {
r := genTestCommand("root", true)
r.MarkZshCompPositionalArgumentFile(1, "*.log")
r.MarkZshCompPositionalArgumentFile(2)
return r
}(),
expectedExpressions: []string{
`'1: :_files -g "\*.log"' \\\n\s+'2: :_files`,
},
},
{
name: "argument zsh completion for words",
root: func() *Command {
r := genTestCommand("root", true)
r.MarkZshCompPositionalArgumentWords(1, "word1", "word2")
return r
}(),
expectedExpressions: []string{
`'1: :\("word1" "word2"\)`,
},
},
{
name: "argument completion for words with spaces",
root: func() *Command {
r := genTestCommand("root", true)
r.MarkZshCompPositionalArgumentWords(1, "single", "multiple words")
return r
}(),
expectedExpressions: []string{
`'1: :\("single" "multiple words"\)'`,
},
},
{
name: "argument completion when command has ValidArgs and no annotation for argument completion",
root: func() *Command {
r := genTestCommand("root", true)
r.ValidArgs = []string{"word1", "word2"}
return r
}(),
expectedExpressions: []string{
`'1: :\("word1" "word2"\)'`,
},
},
{
name: "argument completion when command has ValidArgs and no annotation for argument at argPosition 1",
root: func() *Command {
r := genTestCommand("root", true)
r.ValidArgs = []string{"word1", "word2"}
r.MarkZshCompPositionalArgumentFile(2)
return r
}(),
expectedExpressions: []string{
`'1: :\("word1" "word2"\)' \\`,
},
},
{
name: "directory completion for flag",
root: func() *Command {
r := genTestCommand("root", true)
r.Flags().String("test", "", "test")
r.PersistentFlags().String("ptest", "", "ptest")
r.MarkFlagDirname("test")
r.MarkPersistentFlagDirname("ptest")
return r
}(),
expectedExpressions: []string{
`--test\[test]:filename:_files -g "-\(/\)"`,
`--ptest\[ptest]:filename:_files -g "-\(/\)"`,
},
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
if tc.skip != "" {
t.Skip(tc.skip)
}
tc.root.Root().SetArgs(tc.invocationArgs)
tc.root.Execute()
buf := new(bytes.Buffer)
if err := tc.root.GenZshCompletion(buf); err != nil {
t.Error(err)
}
output := buf.Bytes()
for _, expr := range tc.expectedExpressions {
rgx, err := regexp.Compile(expr)
if err != nil {
t.Errorf("error compiling expression (%s): %v", expr, err)
}
if !rgx.Match(output) {
t.Errorf("expected completion (%s) to match '%s'", buf.String(), expr)
}
}
})
}
}
func TestGenZshCompletionHidden(t *testing.T) {
tcs := []struct {
name string
root *Command
expectedExpressions []string
}{
{
name: "trivial",
root: &Command{Use: "trivialapp"},
expectedExpressions: []string{"#compdef trivial"},
},
{
name: "linear",
name: "hidden commands",
root: func() *Command {
r := &Command{Use: "linear"}
sub1 := &Command{Use: "sub1"}
r.AddCommand(sub1)
sub2 := &Command{Use: "sub2"}
sub1.AddCommand(sub2)
sub3 := &Command{Use: "sub3"}
sub2.AddCommand(sub3)
return r
}(),
expectedExpressions: []string{"sub1", "sub2", "sub3"},
},
{
name: "flat",
root: func() *Command {
r := &Command{Use: "flat"}
r.AddCommand(&Command{Use: "c1"})
r.AddCommand(&Command{Use: "c2"})
return r
}(),
expectedExpressions: []string{"(c1 c2)"},
},
{
name: "tree",
root: func() *Command {
r := &Command{Use: "tree"}
sub1 := &Command{Use: "sub1"}
r.AddCommand(sub1)
sub11 := &Command{Use: "sub11"}
sub12 := &Command{Use: "sub12"}
sub1.AddCommand(sub11)
sub1.AddCommand(sub12)
sub2 := &Command{Use: "sub2"}
r.AddCommand(sub2)
sub21 := &Command{Use: "sub21"}
sub22 := &Command{Use: "sub22"}
sub2.AddCommand(sub21)
sub2.AddCommand(sub22)
r := &Command{
Use: "main",
Short: "main short description",
}
s1 := &Command{
Use: "sub1",
Hidden: true,
Run: emptyRun,
}
s2 := &Command{
Use: "sub2",
Short: "short sub2 description",
Run: emptyRun,
}
r.AddCommand(s1, s2)
return r
}(),
expectedExpressions: []string{"(sub11 sub12)", "(sub21 sub22)"},
expectedExpressions: []string{
"sub1",
},
},
{
name: "hidden flags",
root: func() *Command {
var hidden string
r := &Command{
Use: "root",
Short: "root short description",
Run: emptyRun,
}
r.Flags().StringVarP(&hidden, "hidden", "H", hidden, "hidden usage")
if err := r.Flags().MarkHidden("hidden"); err != nil {
t.Errorf("Error setting flag hidden: %v\n", err)
}
return r
}(),
expectedExpressions: []string{
"--hidden",
},
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
tc.root.Execute()
buf := new(bytes.Buffer)
tc.root.GenZshCompletion(buf)
if err := tc.root.GenZshCompletion(buf); err != nil {
t.Error(err)
}
output := buf.String()
for _, expectedExpression := range tc.expectedExpressions {
if !strings.Contains(output, expectedExpression) {
t.Errorf("Expected completion to contain %q somewhere; got %q", expectedExpression, output)
for _, expr := range tc.expectedExpressions {
if strings.Contains(output, expr) {
t.Errorf("Expected completion (%s) not to contain '%s' but it does", output, expr)
}
}
})
}
}
func TestMarkZshCompPositionalArgumentFile(t *testing.T) {
t.Run("Doesn't allow overwriting existing positional argument", func(t *testing.T) {
c := &Command{}
if err := c.MarkZshCompPositionalArgumentFile(1, "*.log"); err != nil {
t.Errorf("Received error when we shouldn't have: %v\n", err)
}
if err := c.MarkZshCompPositionalArgumentFile(1); err == nil {
t.Error("Didn't receive an error when trying to overwrite argument position")
}
})
t.Run("Refuses to accept argPosition less then 1", func(t *testing.T) {
c := &Command{}
err := c.MarkZshCompPositionalArgumentFile(0, "*")
if err == nil {
t.Fatal("Error was not thrown when indicating argument position 0")
}
if !strings.Contains(err.Error(), "position") {
t.Errorf("expected error message '%s' to contain 'position'", err.Error())
}
})
}
func TestMarkZshCompPositionalArgumentWords(t *testing.T) {
t.Run("Doesn't allow overwriting existing positional argument", func(t *testing.T) {
c := &Command{}
if err := c.MarkZshCompPositionalArgumentFile(1, "*.log"); err != nil {
t.Errorf("Received error when we shouldn't have: %v\n", err)
}
if err := c.MarkZshCompPositionalArgumentWords(1, "hello"); err == nil {
t.Error("Didn't receive an error when trying to overwrite argument position")
}
})
t.Run("Doesn't allow calling without words", func(t *testing.T) {
c := &Command{}
if err := c.MarkZshCompPositionalArgumentWords(0); err == nil {
t.Error("Should not allow saving empty word list for annotation")
}
})
t.Run("Refuses to accept argPosition less then 1", func(t *testing.T) {
c := &Command{}
err := c.MarkZshCompPositionalArgumentWords(0, "word")
if err == nil {
t.Fatal("Should not allow setting argument position less then 1")
}
if !strings.Contains(err.Error(), "position") {
t.Errorf("Expected error '%s' to contain 'position' but didn't", err.Error())
}
})
}
func BenchmarkMediumSizeConstruct(b *testing.B) {
root := constructLargeCommandHierarchy()
// if err := root.GenZshCompletionFile("_mycmd"); err != nil {
// b.Error(err)
// }
for i := 0; i < b.N; i++ {
buf := new(bytes.Buffer)
err := root.GenZshCompletion(buf)
if err != nil {
b.Error(err)
}
}
}
func TestExtractFlags(t *testing.T) {
var debug, cmdc, cmdd bool
c := &Command{
Use: "cmdC",
Long: "Command C",
}
c.PersistentFlags().BoolVarP(&debug, "debug", "d", debug, "debug mode")
c.Flags().BoolVar(&cmdc, "cmd-c", cmdc, "Command C")
d := &Command{
Use: "CmdD",
Long: "Command D",
}
d.Flags().BoolVar(&cmdd, "cmd-d", cmdd, "Command D")
c.AddCommand(d)
resC := zshCompExtractFlag(c)
resD := zshCompExtractFlag(d)
if len(resC) != 2 {
t.Errorf("expected Command C to return 2 flags, got %d", len(resC))
}
if len(resD) != 2 {
t.Errorf("expected Command D to return 2 flags, got %d", len(resD))
}
}
func constructLargeCommandHierarchy() *Command {
var config, st1, st2 string
var long, debug bool
var in1, in2 int
var verbose []bool
r := genTestCommand("mycmd", false)
r.PersistentFlags().StringVarP(&config, "config", "c", config, "config usage")
if err := r.MarkPersistentFlagFilename("config", "*"); err != nil {
panic(err)
}
s1 := genTestCommand("sub1", true)
s1.Flags().BoolVar(&long, "long", long, "long description")
s1.Flags().BoolSliceVar(&verbose, "verbose", verbose, "verbose description")
s1.Flags().StringArray("option", []string{}, "various options")
s2 := genTestCommand("sub2", true)
s2.PersistentFlags().BoolVar(&debug, "debug", debug, "debug description")
s3 := genTestCommand("sub3", true)
s3.Hidden = true
s1_1 := genTestCommand("sub1sub1", true)
s1_1.Flags().StringVar(&st1, "st1", st1, "st1 description")
s1_1.Flags().StringVar(&st2, "st2", st2, "st2 description")
s1_2 := genTestCommand("sub1sub2", true)
s1_3 := genTestCommand("sub1sub3", true)
s1_3.Flags().IntVar(&in1, "int1", in1, "int1 description")
s1_3.Flags().IntVar(&in2, "int2", in2, "int2 description")
s1_3.Flags().StringArrayP("option", "O", []string{}, "more options")
s2_1 := genTestCommand("sub2sub1", true)
s2_2 := genTestCommand("sub2sub2", true)
s2_3 := genTestCommand("sub2sub3", true)
s2_4 := genTestCommand("sub2sub4", true)
s2_5 := genTestCommand("sub2sub5", true)
s1.AddCommand(s1_1, s1_2, s1_3)
s2.AddCommand(s2_1, s2_2, s2_3, s2_4, s2_5)
r.AddCommand(s1, s2, s3)
r.Execute()
return r
}
func genTestCommand(name string, withRun bool) *Command {
r := &Command{
Use: name,
Short: name + " short description",
Long: "Long description for " + name,
}
if withRun {
r.Run = emptyRun
}
return r
}