Merge branch 'master' into master

This commit is contained in:
Alex 2019-06-07 17:11:37 +02:00 committed by GitHub
commit 6eb960b3cb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 1689 additions and 902 deletions

2
.gitignore vendored
View file

@ -34,3 +34,5 @@ tags
*.exe *.exe
cobra.test cobra.test
.idea/*

View file

@ -23,6 +23,7 @@ Many of the most widely used Go projects are built using Cobra, such as:
[Istio](https://istio.io), [Istio](https://istio.io),
[Prototool](https://github.com/uber/prototool), [Prototool](https://github.com/uber/prototool),
[mattermost-server](https://github.com/mattermost/mattermost-server), [mattermost-server](https://github.com/mattermost/mattermost-server),
[Gardener](https://github.com/gardener/gardenctl),
etc. etc.
[![Build Status](https://travis-ci.org/spf13/cobra.svg "Travis CI status")](https://travis-ci.org/spf13/cobra) [![Build Status](https://travis-ci.org/spf13/cobra.svg "Travis CI status")](https://travis-ci.org/spf13/cobra)
@ -48,6 +49,7 @@ etc.
* [Suggestions when "unknown command" happens](#suggestions-when-unknown-command-happens) * [Suggestions when "unknown command" happens](#suggestions-when-unknown-command-happens)
* [Generating documentation for your command](#generating-documentation-for-your-command) * [Generating documentation for your command](#generating-documentation-for-your-command)
* [Generating bash completions](#generating-bash-completions) * [Generating bash completions](#generating-bash-completions)
* [Generating zsh completions](#generating-zsh-completions)
- [Contributing](#contributing) - [Contributing](#contributing)
- [License](#license) - [License](#license)
@ -336,7 +338,7 @@ rootCmd.PersistentFlags().BoolVarP(&Verbose, "verbose", "v", false, "verbose out
A flag can also be assigned locally which will only apply to that specific command. A flag can also be assigned locally which will only apply to that specific command.
```go ```go
rootCmd.Flags().StringVarP(&Source, "source", "s", "", "Source directory to read from") localCmd.Flags().StringVarP(&Source, "source", "s", "", "Source directory to read from")
``` ```
### Local Flag on Parent Commands ### Local Flag on Parent Commands
@ -719,6 +721,11 @@ Cobra can generate documentation based on subcommands, flags, etc. in the follow
Cobra can generate a bash-completion file. If you add more information to your command, these completions can be amazingly powerful and flexible. Read more about it in [Bash Completions](bash_completions.md). Cobra can generate a bash-completion file. If you add more information to your command, these completions can be amazingly powerful and flexible. Read more about it in [Bash Completions](bash_completions.md).
## Generating zsh completions
Cobra can generate zsh-completion file. Read more about it in
[Zsh Completions](zsh_completions.md).
# Contributing # Contributing
1. Fork it 1. Fork it

View file

@ -545,51 +545,3 @@ func (c *Command) GenBashCompletionFile(filename string) error {
return c.GenBashCompletion(outFile) return c.GenBashCompletion(outFile)
} }
// MarkFlagRequired adds the BashCompOneRequiredFlag annotation to the named flag if it exists,
// and causes your command to report an error if invoked without the flag.
func (c *Command) MarkFlagRequired(name string) error {
return MarkFlagRequired(c.Flags(), name)
}
// MarkPersistentFlagRequired adds the BashCompOneRequiredFlag annotation to the named persistent flag if it exists,
// and causes your command to report an error if invoked without the flag.
func (c *Command) MarkPersistentFlagRequired(name string) error {
return MarkFlagRequired(c.PersistentFlags(), name)
}
// MarkFlagRequired adds the BashCompOneRequiredFlag annotation to the named flag if it exists,
// and causes your command to report an error if invoked without the flag.
func MarkFlagRequired(flags *pflag.FlagSet, name string) error {
return flags.SetAnnotation(name, BashCompOneRequiredFlag, []string{"true"})
}
// MarkFlagFilename adds the BashCompFilenameExt annotation to the named flag, if it exists.
// Generated bash autocompletion will select filenames for the flag, limiting to named extensions if provided.
func (c *Command) MarkFlagFilename(name string, extensions ...string) error {
return MarkFlagFilename(c.Flags(), name, extensions...)
}
// MarkFlagCustom adds the BashCompCustom annotation to the named flag, if it exists.
// Generated bash autocompletion will call the bash function f for the flag.
func (c *Command) MarkFlagCustom(name string, f string) error {
return MarkFlagCustom(c.Flags(), name, f)
}
// MarkPersistentFlagFilename adds the BashCompFilenameExt annotation to the named persistent flag, if it exists.
// Generated bash autocompletion will select filenames for the flag, limiting to named extensions if provided.
func (c *Command) MarkPersistentFlagFilename(name string, extensions ...string) error {
return MarkFlagFilename(c.PersistentFlags(), name, extensions...)
}
// MarkFlagFilename adds the BashCompFilenameExt annotation to the named flag in the flag set, if it exists.
// Generated bash autocompletion will select filenames for the flag, limiting to named extensions if provided.
func MarkFlagFilename(flags *pflag.FlagSet, name string, extensions ...string) error {
return flags.SetAnnotation(name, BashCompFilenameExt, extensions)
}
// MarkFlagCustom adds the BashCompCustom annotation to the named flag in the flag set, if it exists.
// Generated bash autocompletion will call the bash function f for the flag.
func MarkFlagCustom(flags *pflag.FlagSet, name string, f string) error {
return flags.SetAnnotation(name, BashCompCustom, []string{f})
}

View file

@ -16,24 +16,20 @@ package cmd
import ( import (
"fmt" "fmt"
"os" "os"
"path/filepath"
"unicode" "unicode"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
func init() { var (
addCmd.Flags().StringVarP(&packageName, "package", "t", "", "target package name (e.g. github.com/spf13/hugo)") packageName string
addCmd.Flags().StringVarP(&parentName, "parent", "p", "rootCmd", "variable name of parent command for this command (e.g. xyCmd)") parentName string
}
var packageName, parentName string addCmd = &cobra.Command{
Use: "add [command name]",
var addCmd = &cobra.Command{ Aliases: []string{"command"},
Use: "add [command name]", Short: "Add a command to a Cobra Application",
Aliases: []string{"command"}, Long: `Add (cobra add) will create a new command, with a license and
Short: "Add a command to a Cobra Application",
Long: `Add (cobra add) will create a new command, with a license and
the appropriate structure for a Cobra-based CLI application, the appropriate structure for a Cobra-based CLI application,
and register it to its parent (default rootCmd). and register it to its parent (default rootCmd).
@ -42,28 +38,41 @@ with an initial uppercase letter.
Example: cobra add server -> resulting in a new cmd/server.go`, Example: cobra add server -> resulting in a new cmd/server.go`,
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
if len(args) < 1 { if len(args) < 1 {
er("add needs a name for the command") er("add needs a name for the command")
} }
var project *Project
if packageName != "" {
project = NewProject(packageName)
} else {
wd, err := os.Getwd() wd, err := os.Getwd()
if err != nil { if err != nil {
er(err) er(err)
} }
project = NewProjectFromPath(wd)
}
cmdName := validateCmdName(args[0]) commandName := validateCmdName(args[0])
cmdPath := filepath.Join(project.CmdPath(), cmdName+".go") command := &Command{
createCmdFile(project.License(), cmdPath, cmdName) CmdName: commandName,
CmdParent: parentName,
Project: &Project{
AbsolutePath: wd,
Legal: getLicense(),
Copyright: copyrightLine(),
},
}
fmt.Fprintln(cmd.OutOrStdout(), cmdName, "created at", cmdPath) err = command.Create()
}, if err != nil {
er(err)
}
fmt.Printf("%s created at %s", command.CmdName, command.AbsolutePath)
},
}
)
func init() {
addCmd.Flags().StringVarP(&packageName, "package", "t", "", "target package name (e.g. github.com/spf13/hugo)")
addCmd.Flags().StringVarP(&parentName, "parent", "p", "rootCmd", "variable name of parent command for this command (e.g. xyCmd)")
addCmd.Flags().MarkDeprecated("package", "this operation has been removed.")
} }
// validateCmdName returns source without any dashes and underscore. // validateCmdName returns source without any dashes and underscore.
@ -118,62 +127,3 @@ func validateCmdName(source string) string {
} }
return output return output
} }
func createCmdFile(license License, path, cmdName string) {
template := `{{comment .copyright}}
{{if .license}}{{comment .license}}{{end}}
package {{.cmdPackage}}
import (
"fmt"
"github.com/spf13/cobra"
)
// {{.cmdName}}Cmd represents the {{.cmdName}} command
var {{.cmdName}}Cmd = &cobra.Command{
Use: "{{.cmdName}}",
Short: "A brief description of your command",
Long: ` + "`" + `A longer description that spans multiple lines and likely contains examples
and usage of using your command. For example:
Cobra is a CLI library for Go that empowers applications.
This application is a tool to generate the needed files
to quickly create a Cobra application.` + "`" + `,
Run: func(cmd *cobra.Command, args []string) {
fmt.Println("{{.cmdName}} called")
},
}
func init() {
{{.parentName}}.AddCommand({{.cmdName}}Cmd)
// Here you will define your flags and configuration settings.
// Cobra supports Persistent Flags which will work for this command
// and all subcommands, e.g.:
// {{.cmdName}}Cmd.PersistentFlags().String("foo", "", "A help for foo")
// Cobra supports local flags which will only run when this command
// is called directly, e.g.:
// {{.cmdName}}Cmd.Flags().BoolP("toggle", "t", false, "Help message for toggle")
}
`
data := make(map[string]interface{})
data["copyright"] = copyrightLine()
data["license"] = license.Header
data["cmdPackage"] = filepath.Base(filepath.Dir(path)) // last dir of path
data["parentName"] = parentName
data["cmdName"] = cmdName
cmdScript, err := executeTemplate(template, data)
if err != nil {
er(err)
}
err = writeStringToFile(path, cmdScript)
if err != nil {
er(err)
}
}

View file

@ -1,85 +1,45 @@
package cmd package cmd
import ( import (
"errors" "fmt"
"io/ioutil"
"os" "os"
"path/filepath"
"testing" "testing"
"github.com/spf13/viper"
) )
// TestGoldenAddCmd initializes the project "github.com/spf13/testproject"
// in GOPATH, adds "test" command
// and compares the content of all files in cmd directory of testproject
// with appropriate golden files.
// Use -update to update existing golden files.
func TestGoldenAddCmd(t *testing.T) { func TestGoldenAddCmd(t *testing.T) {
projectName := "github.com/spf13/testproject"
project := NewProject(projectName)
defer os.RemoveAll(project.AbsPath())
viper.Set("author", "NAME HERE <EMAIL ADDRESS>") wd, _ := os.Getwd()
viper.Set("license", "apache") command := &Command{
viper.Set("year", 2017) CmdName: "test",
defer viper.Set("author", nil) CmdParent: parentName,
defer viper.Set("license", nil) Project: &Project{
defer viper.Set("year", nil) AbsolutePath: fmt.Sprintf("%s/testproject", wd),
Legal: getLicense(),
Copyright: copyrightLine(),
// Initialize the project first. // required to init
initializeProject(project) AppName: "testproject",
PkgName: "github.com/spf13/testproject",
Viper: true,
},
}
// Then add the "test" command. // init project first
cmdName := "test" command.Project.Create()
cmdPath := filepath.Join(project.CmdPath(), cmdName+".go") defer func() {
createCmdFile(project.License(), cmdPath, cmdName) if _, err := os.Stat(command.AbsolutePath); err == nil {
os.RemoveAll(command.AbsolutePath)
expectedFiles := []string{".", "root.go", "test.go"}
gotFiles := []string{}
// Check project file hierarchy and compare the content of every single file
// with appropriate golden file.
err := filepath.Walk(project.CmdPath(), func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
} }
}()
// Make path relative to project.CmdPath(). if err := command.Create(); err != nil {
// E.g. path = "/home/user/go/src/github.com/spf13/testproject/cmd/root.go"
// then it returns just "root.go".
relPath, err := filepath.Rel(project.CmdPath(), path)
if err != nil {
return err
}
relPath = filepath.ToSlash(relPath)
gotFiles = append(gotFiles, relPath)
goldenPath := filepath.Join("testdata", filepath.Base(path)+".golden")
switch relPath {
// Known directories.
case ".":
return nil
// Known files.
case "root.go", "test.go":
if *update {
got, err := ioutil.ReadFile(path)
if err != nil {
return err
}
ioutil.WriteFile(goldenPath, got, 0644)
}
return compareFiles(path, goldenPath)
}
// Unknown file.
return errors.New("unknown file: " + path)
})
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// Check if some files lack. generatedFile := fmt.Sprintf("%s/cmd/%s.go", command.AbsolutePath, command.CmdName)
if err := checkLackFiles(expectedFiles, gotFiles); err != nil { goldenFile := fmt.Sprintf("testdata/%s.go.golden", command.CmdName)
err := compareFiles(generatedFile, goldenFile)
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }

View file

@ -15,19 +15,20 @@ package cmd
import ( import (
"fmt" "fmt"
"os"
"path"
"path/filepath"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/viper" "github.com/spf13/viper"
"os"
"path"
) )
var initCmd = &cobra.Command{ var (
Use: "init [name]", pkgName string
Aliases: []string{"initialize", "initialise", "create"},
Short: "Initialize a Cobra Application", initCmd = &cobra.Command{
Long: `Initialize (cobra init) will create a new application, with a license Use: "init [name]",
Aliases: []string{"initialize", "initialise", "create"},
Short: "Initialize a Cobra Application",
Long: `Initialize (cobra init) will create a new application, with a license
and the appropriate structure for a Cobra-based CLI application. and the appropriate structure for a Cobra-based CLI application.
* If a name is provided, it will be created in the current directory; * If a name is provided, it will be created in the current directory;
@ -39,196 +40,38 @@ and the appropriate structure for a Cobra-based CLI application.
Init will not use an existing directory with contents.`, Init will not use an existing directory with contents.`,
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
wd, err := os.Getwd()
if err != nil {
er(err)
}
var project *Project wd, err := os.Getwd()
if len(args) == 0 { if err != nil {
project = NewProjectFromPath(wd) er(err)
} else if len(args) == 1 {
arg := args[0]
if arg[0] == '.' {
arg = filepath.Join(wd, arg)
} }
if filepath.IsAbs(arg) {
project = NewProjectFromPath(arg) if len(args) > 0 {
} else { if args[0] != "." {
project = NewProject(arg) wd = fmt.Sprintf("%s/%s", wd, args[0])
}
} }
} else {
er("please provide only one argument")
}
initializeProject(project) project := &Project{
AbsolutePath: wd,
PkgName: pkgName,
Legal: getLicense(),
Copyright: copyrightLine(),
Viper: viper.GetBool("useViper"),
AppName: path.Base(pkgName),
}
fmt.Fprintln(cmd.OutOrStdout(), `Your Cobra application is ready at if err := project.Create(); err != nil {
`+project.AbsPath()+` er(err)
}
Give it a try by going there and running `+"`go run main.go`."+`
Add commands to it by running `+"`cobra add [cmdname]`.")
},
}
func initializeProject(project *Project) {
if !exists(project.AbsPath()) { // If path doesn't yet exist, create it
err := os.MkdirAll(project.AbsPath(), os.ModePerm)
if err != nil {
er(err)
}
} else if !isEmpty(project.AbsPath()) { // If path exists and is not empty don't use it
er("Cobra will not create a new project in a non empty directory: " + project.AbsPath())
}
// We have a directory and it's empty. Time to initialize it.
createLicenseFile(project.License(), project.AbsPath())
createMainFile(project)
createRootCmdFile(project)
}
func createLicenseFile(license License, path string) {
data := make(map[string]interface{})
data["copyright"] = copyrightLine()
// Generate license template from text and data.
text, err := executeTemplate(license.Text, data)
if err != nil {
er(err)
}
// Write license text to LICENSE file.
err = writeStringToFile(filepath.Join(path, "LICENSE"), text)
if err != nil {
er(err)
}
}
func createMainFile(project *Project) {
mainTemplate := `{{ comment .copyright }}
{{if .license}}{{ comment .license }}{{end}}
package main
import "{{ .importpath }}"
func main() {
cmd.Execute()
}
`
data := make(map[string]interface{})
data["copyright"] = copyrightLine()
data["license"] = project.License().Header
data["importpath"] = path.Join(project.Name(), filepath.Base(project.CmdPath()))
mainScript, err := executeTemplate(mainTemplate, data)
if err != nil {
er(err)
}
err = writeStringToFile(filepath.Join(project.AbsPath(), "main.go"), mainScript)
if err != nil {
er(err)
}
}
func createRootCmdFile(project *Project) {
template := `{{comment .copyright}}
{{if .license}}{{comment .license}}{{end}}
package cmd
import (
"fmt"
"os"
{{if .viper}}
homedir "github.com/mitchellh/go-homedir"{{end}}
"github.com/spf13/cobra"{{if .viper}}
"github.com/spf13/viper"{{end}}
){{if .viper}}
var cfgFile string{{end}}
// rootCmd represents the base command when called without any subcommands
var rootCmd = &cobra.Command{
Use: "{{.appName}}",
Short: "A brief description of your application",
Long: ` + "`" + `A longer description that spans multiple lines and likely contains
examples and usage of using your application. For example:
Cobra is a CLI library for Go that empowers applications.
This application is a tool to generate the needed files
to quickly create a Cobra application.` + "`" + `,
// Uncomment the following line if your bare application
// has an action associated with it:
// Run: func(cmd *cobra.Command, args []string) { },
}
// Execute adds all child commands to the root command and sets flags appropriately.
// This is called by main.main(). It only needs to happen once to the rootCmd.
func Execute() {
if err := rootCmd.Execute(); err != nil {
fmt.Println(err)
os.Exit(1)
}
}
func init() { {{- if .viper}}
cobra.OnInitialize(initConfig)
{{end}}
// Here you will define your flags and configuration settings.
// Cobra supports persistent flags, which, if defined here,
// will be global for your application.{{ if .viper }}
rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.{{ .appName }}.yaml)"){{ else }}
// rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.{{ .appName }}.yaml)"){{ end }}
// Cobra also supports local flags, which will only run
// when this action is called directly.
rootCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle")
}{{ if .viper }}
// initConfig reads in config file and ENV variables if set.
func initConfig() {
if cfgFile != "" {
// Use config file from the flag.
viper.SetConfigFile(cfgFile)
} else {
// Find home directory.
home, err := homedir.Dir()
if err != nil {
fmt.Println(err)
os.Exit(1)
}
// Search config in home directory with name ".{{ .appName }}" (without extension).
viper.AddConfigPath(home)
viper.SetConfigName(".{{ .appName }}")
}
viper.AutomaticEnv() // read in environment variables that match
// If a config file is found, read it in.
if err := viper.ReadInConfig(); err == nil {
fmt.Println("Using config file:", viper.ConfigFileUsed())
}
}{{ end }}
`
data := make(map[string]interface{})
data["copyright"] = copyrightLine()
data["viper"] = viper.GetBool("useViper")
data["license"] = project.License().Header
data["appName"] = path.Base(project.Name())
rootCmdScript, err := executeTemplate(template, data)
if err != nil {
er(err)
}
err = writeStringToFile(filepath.Join(project.CmdPath(), "root.go"), rootCmdScript)
if err != nil {
er(err)
}
fmt.Printf("Your Cobra applicaton is ready at\n%s\n", project.AbsolutePath)
},
}
)
func init() {
initCmd.Flags().StringVar(&pkgName, "pkg-name", "", "fully qualified pkg name")
initCmd.MarkFlagRequired("pkg-name")
} }

View file

@ -1,83 +1,42 @@
package cmd package cmd
import ( import (
"errors" "fmt"
"io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/spf13/viper"
) )
// TestGoldenInitCmd initializes the project "github.com/spf13/testproject"
// in GOPATH and compares the content of files in initialized project with
// appropriate golden files ("testdata/*.golden").
// Use -update to update existing golden files.
func TestGoldenInitCmd(t *testing.T) { func TestGoldenInitCmd(t *testing.T) {
projectName := "github.com/spf13/testproject"
project := NewProject(projectName)
defer os.RemoveAll(project.AbsPath())
viper.Set("author", "NAME HERE <EMAIL ADDRESS>") wd, _ := os.Getwd()
viper.Set("license", "apache") project := &Project{
viper.Set("year", 2017) AbsolutePath: fmt.Sprintf("%s/testproject", wd),
defer viper.Set("author", nil) PkgName: "github.com/spf13/testproject",
defer viper.Set("license", nil) Legal: getLicense(),
defer viper.Set("year", nil) Copyright: copyrightLine(),
Viper: true,
os.Args = []string{"cobra", "init", projectName} AppName: "testproject",
if err := rootCmd.Execute(); err != nil {
t.Fatal("Error by execution:", err)
} }
expectedFiles := []string{".", "cmd", "LICENSE", "main.go", "cmd/root.go"} err := project.Create()
gotFiles := []string{}
// Check project file hierarchy and compare the content of every single file
// with appropriate golden file.
err := filepath.Walk(project.AbsPath(), func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
// Make path relative to project.AbsPath().
// E.g. path = "/home/user/go/src/github.com/spf13/testproject/cmd/root.go"
// then it returns just "cmd/root.go".
relPath, err := filepath.Rel(project.AbsPath(), path)
if err != nil {
return err
}
relPath = filepath.ToSlash(relPath)
gotFiles = append(gotFiles, relPath)
goldenPath := filepath.Join("testdata", filepath.Base(path)+".golden")
switch relPath {
// Known directories.
case ".", "cmd":
return nil
// Known files.
case "LICENSE", "main.go", "cmd/root.go":
if *update {
got, err := ioutil.ReadFile(path)
if err != nil {
return err
}
if err := ioutil.WriteFile(goldenPath, got, 0644); err != nil {
t.Fatal("Error while updating file:", err)
}
}
return compareFiles(path, goldenPath)
}
// Unknown file.
return errors.New("unknown file: " + path)
})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// Check if some files lack. defer func() {
if err := checkLackFiles(expectedFiles, gotFiles); err != nil { if _, err := os.Stat(project.AbsolutePath); err == nil {
t.Fatal(err) os.RemoveAll(project.AbsolutePath)
}
}()
expectedFiles := []string{"LICENSE", "main.go", "cmd/root.go"}
for _, f := range expectedFiles {
generatedFile := fmt.Sprintf("%s/%s", project.AbsolutePath, f)
goldenFile := fmt.Sprintf("testdata/%s.golden", filepath.Base(f))
err := compareFiles(generatedFile, goldenFile)
if err != nil {
t.Fatal(err)
}
} }
} }

View file

@ -1,200 +1,96 @@
package cmd package cmd
import ( import (
"fmt"
"github.com/spf13/cobra/cobra/tpl"
"os" "os"
"path/filepath" "text/template"
"runtime"
"strings"
) )
// Project contains name, license and paths to projects. // Project contains name, license and paths to projects.
type Project struct { type Project struct {
absPath string // v2
cmdPath string PkgName string
srcPath string Copyright string
license License AbsolutePath string
name string Legal License
Viper bool
AppName string
} }
// NewProject returns Project with specified project name. type Command struct {
func NewProject(projectName string) *Project { CmdName string
if projectName == "" { CmdParent string
er("can't create project with blank name") *Project
} }
p := new(Project) func (p *Project) Create() error {
p.name = projectName
// 1. Find already created protect. // check if AbsolutePath exists
p.absPath = findPackage(projectName) if _, err := os.Stat(p.AbsolutePath); os.IsNotExist(err) {
// create directory
// 2. If there are no created project with this path, and user is in GOPATH, if err := os.Mkdir(p.AbsolutePath, 0754); err != nil {
// then use GOPATH/src/projectName. return err
if p.absPath == "" {
wd, err := os.Getwd()
if err != nil {
er(err)
}
for _, srcPath := range srcPaths {
goPath := filepath.Dir(srcPath)
if filepathHasPrefix(wd, goPath) {
p.absPath = filepath.Join(srcPath, projectName)
break
}
} }
} }
// 3. If user is not in GOPATH, then use (first GOPATH)/src/projectName. // create main.go
if p.absPath == "" { mainFile, err := os.Create(fmt.Sprintf("%s/main.go", p.AbsolutePath))
p.absPath = filepath.Join(srcPaths[0], projectName)
}
return p
}
// findPackage returns full path to existing go package in GOPATHs.
func findPackage(packageName string) string {
if packageName == "" {
return ""
}
for _, srcPath := range srcPaths {
packagePath := filepath.Join(srcPath, packageName)
if exists(packagePath) {
return packagePath
}
}
return ""
}
// NewProjectFromPath returns Project with specified absolute path to
// package.
func NewProjectFromPath(absPath string) *Project {
if absPath == "" {
er("can't create project: absPath can't be blank")
}
if !filepath.IsAbs(absPath) {
er("can't create project: absPath is not absolute")
}
// If absPath is symlink, use its destination.
fi, err := os.Lstat(absPath)
if err != nil { if err != nil {
er("can't read path info: " + err.Error()) return err
}
if fi.Mode()&os.ModeSymlink != 0 {
path, err := os.Readlink(absPath)
if err != nil {
er("can't read the destination of symlink: " + err.Error())
}
absPath = path
} }
defer mainFile.Close()
p := new(Project) mainTemplate := template.Must(template.New("main").Parse(string(tpl.MainTemplate())))
p.absPath = strings.TrimSuffix(absPath, findCmdDir(absPath)) err = mainTemplate.Execute(mainFile, p)
p.name = filepath.ToSlash(trimSrcPath(p.absPath, p.SrcPath()))
return p
}
// trimSrcPath trims at the beginning of absPath the srcPath.
func trimSrcPath(absPath, srcPath string) string {
relPath, err := filepath.Rel(srcPath, absPath)
if err != nil { if err != nil {
er(err) return err
} }
return relPath
// create cmd/root.go
if _, err = os.Stat(fmt.Sprintf("%s/cmd", p.AbsolutePath)); os.IsNotExist(err) {
os.Mkdir(fmt.Sprintf("%s/cmd", p.AbsolutePath), 0751)
}
rootFile, err := os.Create(fmt.Sprintf("%s/cmd/root.go", p.AbsolutePath))
if err != nil {
return err
}
defer rootFile.Close()
rootTemplate := template.Must(template.New("root").Parse(string(tpl.RootTemplate())))
err = rootTemplate.Execute(rootFile, p)
if err != nil {
return err
}
// create license
return p.createLicenseFile()
} }
// License returns the License object of project. func (p *Project) createLicenseFile() error {
func (p *Project) License() License { data := map[string]interface{}{
if p.license.Text == "" && p.license.Name != "None" { "copyright": copyrightLine(),
p.license = getLicense()
} }
return p.license licenseFile, err := os.Create(fmt.Sprintf("%s/LICENSE", p.AbsolutePath))
if err != nil {
return err
}
licenseTemplate := template.Must(template.New("license").Parse(p.Legal.Text))
return licenseTemplate.Execute(licenseFile, data)
} }
// Name returns the name of project, e.g. "github.com/spf13/cobra" func (c *Command) Create() error {
func (p Project) Name() string { cmdFile, err := os.Create(fmt.Sprintf("%s/cmd/%s.go", c.AbsolutePath, c.CmdName))
return p.name if err != nil {
} return err
}
// CmdPath returns absolute path to directory, where all commands are located. defer cmdFile.Close()
func (p *Project) CmdPath() string {
if p.absPath == "" { commandTemplate := template.Must(template.New("sub").Parse(string(tpl.AddCommandTemplate())))
return "" err = commandTemplate.Execute(cmdFile, c)
} if err != nil {
if p.cmdPath == "" { return err
p.cmdPath = filepath.Join(p.absPath, findCmdDir(p.absPath)) }
} return nil
return p.cmdPath
}
// findCmdDir checks if base of absPath is cmd dir and returns it or
// looks for existing cmd dir in absPath.
func findCmdDir(absPath string) string {
if !exists(absPath) || isEmpty(absPath) {
return "cmd"
}
if isCmdDir(absPath) {
return filepath.Base(absPath)
}
files, _ := filepath.Glob(filepath.Join(absPath, "c*"))
for _, file := range files {
if isCmdDir(file) {
return filepath.Base(file)
}
}
return "cmd"
}
// isCmdDir checks if base of name is one of cmdDir.
func isCmdDir(name string) bool {
name = filepath.Base(name)
for _, cmdDir := range []string{"cmd", "cmds", "command", "commands"} {
if name == cmdDir {
return true
}
}
return false
}
// AbsPath returns absolute path of project.
func (p Project) AbsPath() string {
return p.absPath
}
// SrcPath returns absolute path to $GOPATH/src where project is located.
func (p *Project) SrcPath() string {
if p.srcPath != "" {
return p.srcPath
}
if p.absPath == "" {
p.srcPath = srcPaths[0]
return p.srcPath
}
for _, srcPath := range srcPaths {
if filepathHasPrefix(p.absPath, srcPath) {
p.srcPath = srcPath
break
}
}
return p.srcPath
}
func filepathHasPrefix(path string, prefix string) bool {
if len(path) <= len(prefix) {
return false
}
if runtime.GOOS == "windows" {
// Paths in windows are case-insensitive.
return strings.EqualFold(path[0:len(prefix)], prefix)
}
return path[0:len(prefix)] == prefix
} }

View file

@ -1,24 +1,3 @@
package cmd package cmd
import ( /* todo: write tests */
"testing"
)
func TestFindExistingPackage(t *testing.T) {
path := findPackage("github.com/spf13/cobra")
if path == "" {
t.Fatal("findPackage didn't find the existing package")
}
if !hasGoPathPrefix(path) {
t.Fatalf("%q is not in GOPATH, but must be", path)
}
}
func hasGoPathPrefix(path string) bool {
for _, srcPath := range srcPaths {
if filepathHasPrefix(path, srcPath) {
return true
}
}
return false
}

View file

@ -23,7 +23,8 @@ import (
var ( var (
// Used for flags. // Used for flags.
cfgFile, userLicense string cfgFile string
userLicense string
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
Use: "cobra", Use: "cobra",

View file

@ -1,21 +1,22 @@
// Copyright © 2017 NAME HERE <EMAIL ADDRESS> /*
// Copyright © 2019 NAME HERE <EMAIL ADDRESS>
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package main package main
import "github.com/spf13/testproject/cmd" import "github.com/spf13/testproject/cmd"
func main() { func main() {
cmd.Execute() cmd.Execute()
} }

View file

@ -1,89 +1,97 @@
// Copyright © 2017 NAME HERE <EMAIL ADDRESS> /*
// Copyright © 2019 NAME HERE <EMAIL ADDRESS>
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cmd package cmd
import ( import (
"fmt" "fmt"
"os" "os"
"github.com/spf13/cobra"
homedir "github.com/mitchellh/go-homedir"
"github.com/spf13/viper"
homedir "github.com/mitchellh/go-homedir"
"github.com/spf13/cobra"
"github.com/spf13/viper"
) )
var cfgFile string var cfgFile string
// rootCmd represents the base command when called without any subcommands // rootCmd represents the base command when called without any subcommands
var rootCmd = &cobra.Command{ var rootCmd = &cobra.Command{
Use: "testproject", Use: "testproject",
Short: "A brief description of your application", Short: "A brief description of your application",
Long: `A longer description that spans multiple lines and likely contains Long: `A longer description that spans multiple lines and likely contains
examples and usage of using your application. For example: examples and usage of using your application. For example:
Cobra is a CLI library for Go that empowers applications. Cobra is a CLI library for Go that empowers applications.
This application is a tool to generate the needed files This application is a tool to generate the needed files
to quickly create a Cobra application.`, to quickly create a Cobra application.`,
// Uncomment the following line if your bare application // Uncomment the following line if your bare application
// has an action associated with it: // has an action associated with it:
// Run: func(cmd *cobra.Command, args []string) { }, // Run: func(cmd *cobra.Command, args []string) { },
} }
// Execute adds all child commands to the root command and sets flags appropriately. // Execute adds all child commands to the root command and sets flags appropriately.
// This is called by main.main(). It only needs to happen once to the rootCmd. // This is called by main.main(). It only needs to happen once to the rootCmd.
func Execute() { func Execute() {
if err := rootCmd.Execute(); err != nil { if err := rootCmd.Execute(); err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)
} }
} }
func init() { func init() {
cobra.OnInitialize(initConfig) cobra.OnInitialize(initConfig)
// Here you will define your flags and configuration settings. // Here you will define your flags and configuration settings.
// Cobra supports persistent flags, which, if defined here, // Cobra supports persistent flags, which, if defined here,
// will be global for your application. // will be global for your application.
rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.testproject.yaml)")
// Cobra also supports local flags, which will only run rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.testproject.yaml)")
// when this action is called directly.
rootCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle")
// Cobra also supports local flags, which will only run
// when this action is called directly.
rootCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle")
} }
// initConfig reads in config file and ENV variables if set. // initConfig reads in config file and ENV variables if set.
func initConfig() { func initConfig() {
if cfgFile != "" { if cfgFile != "" {
// Use config file from the flag. // Use config file from the flag.
viper.SetConfigFile(cfgFile) viper.SetConfigFile(cfgFile)
} else { } else {
// Find home directory. // Find home directory.
home, err := homedir.Dir() home, err := homedir.Dir()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)
} }
// Search config in home directory with name ".testproject" (without extension). // Search config in home directory with name ".testproject" (without extension).
viper.AddConfigPath(home) viper.AddConfigPath(home)
viper.SetConfigName(".testproject") viper.SetConfigName(".testproject")
} }
viper.AutomaticEnv() // read in environment variables that match viper.AutomaticEnv() // read in environment variables that match
// If a config file is found, read it in. // If a config file is found, read it in.
if err := viper.ReadInConfig(); err == nil { if err := viper.ReadInConfig(); err == nil {
fmt.Println("Using config file:", viper.ConfigFileUsed()) fmt.Println("Using config file:", viper.ConfigFileUsed())
} }
} }

View file

@ -1,17 +1,18 @@
// Copyright © 2017 NAME HERE <EMAIL ADDRESS> /*
// Copyright © 2019 NAME HERE <EMAIL ADDRESS>
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cmd package cmd
import ( import (

153
cobra/tpl/main.go Normal file
View file

@ -0,0 +1,153 @@
package tpl
func MainTemplate() []byte {
return []byte(`/*
{{ .Copyright }}
{{ if .Legal.Header }}{{ .Legal.Header }}{{ end }}
*/
package main
import "{{ .PkgName }}/cmd"
func main() {
cmd.Execute()
}
`)
}
func RootTemplate() []byte {
return []byte(`/*
{{ .Copyright }}
{{ if .Legal.Header }}{{ .Legal.Header }}{{ end }}
*/
package cmd
import (
"fmt"
"os"
"github.com/spf13/cobra"
{{ if .Viper }}
homedir "github.com/mitchellh/go-homedir"
"github.com/spf13/viper"
{{ end }}
)
{{ if .Viper }}
var cfgFile string
{{ end }}
// rootCmd represents the base command when called without any subcommands
var rootCmd = &cobra.Command{
Use: "{{ .AppName }}",
Short: "A brief description of your application",
Long: ` + "`" + `A longer description that spans multiple lines and likely contains
examples and usage of using your application. For example:
Cobra is a CLI library for Go that empowers applications.
This application is a tool to generate the needed files
to quickly create a Cobra application.` + "`" + `,
// Uncomment the following line if your bare application
// has an action associated with it:
// Run: func(cmd *cobra.Command, args []string) { },
}
// Execute adds all child commands to the root command and sets flags appropriately.
// This is called by main.main(). It only needs to happen once to the rootCmd.
func Execute() {
if err := rootCmd.Execute(); err != nil {
fmt.Println(err)
os.Exit(1)
}
}
func init() {
{{- if .Viper }}
cobra.OnInitialize(initConfig)
{{ end }}
// Here you will define your flags and configuration settings.
// Cobra supports persistent flags, which, if defined here,
// will be global for your application.
{{ if .Viper }}
rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.{{ .AppName }}.yaml)")
{{ else }}
// rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.{{ .AppName }}.yaml)")
{{ end }}
// Cobra also supports local flags, which will only run
// when this action is called directly.
rootCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle")
}
{{ if .Viper }}
// initConfig reads in config file and ENV variables if set.
func initConfig() {
if cfgFile != "" {
// Use config file from the flag.
viper.SetConfigFile(cfgFile)
} else {
// Find home directory.
home, err := homedir.Dir()
if err != nil {
fmt.Println(err)
os.Exit(1)
}
// Search config in home directory with name ".{{ .AppName }}" (without extension).
viper.AddConfigPath(home)
viper.SetConfigName(".{{ .AppName }}")
}
viper.AutomaticEnv() // read in environment variables that match
// If a config file is found, read it in.
if err := viper.ReadInConfig(); err == nil {
fmt.Println("Using config file:", viper.ConfigFileUsed())
}
}
{{ end }}
`)
}
func AddCommandTemplate() []byte {
return []byte(`/*
{{ .Project.Copyright }}
{{ if .Legal.Header }}{{ .Legal.Header }}{{ end }}
*/
package cmd
import (
"fmt"
"github.com/spf13/cobra"
)
// {{ .CmdName }}Cmd represents the {{ .CmdName }} command
var {{ .CmdName }}Cmd = &cobra.Command{
Use: "{{ .CmdName }}",
Short: "A brief description of your command",
Long: ` + "`" + `A longer description that spans multiple lines and likely contains examples
and usage of using your command. For example:
Cobra is a CLI library for Go that empowers applications.
This application is a tool to generate the needed files
to quickly create a Cobra application.` + "`" + `,
Run: func(cmd *cobra.Command, args []string) {
fmt.Println("{{ .CmdName }} called")
},
}
func init() {
{{ .CmdParent }}.AddCommand({{ .CmdName }}Cmd)
// Here you will define your flags and configuration settings.
// Cobra supports Persistent Flags which will work for this command
// and all subcommands, e.g.:
// {{ .CmdName }}Cmd.PersistentFlags().String("foo", "", "A help for foo")
// Cobra supports local flags which will only run when this command
// is called directly, e.g.:
// {{ .CmdName }}Cmd.Flags().BoolP("toggle", "t", false, "Help message for toggle")
}
`)
}

View file

@ -177,8 +177,6 @@ type Command struct {
// that we can use on every pflag set and children commands // that we can use on every pflag set and children commands
globNormFunc func(f *flag.FlagSet, name string) flag.NormalizedName globNormFunc func(f *flag.FlagSet, name string) flag.NormalizedName
// output is an output writer defined by user.
output io.Writer
// usageFunc is usage func defined by user. // usageFunc is usage func defined by user.
usageFunc func(*Command) error usageFunc func(*Command) error
// usageTemplate is usage template defined by user. // usageTemplate is usage template defined by user.
@ -195,6 +193,13 @@ type Command struct {
helpCommand *Command helpCommand *Command
// versionTemplate is the version template defined by user. // versionTemplate is the version template defined by user.
versionTemplate string versionTemplate string
// inReader is a reader defined by the user that replaces stdin
inReader io.Reader
// outWriter is a writer defined by the user that replaces stdout
outWriter io.Writer
// errWriter is a writer defined by the user that replaces stderr
errWriter io.Writer
} }
// SetArgs sets arguments for the command. It is set to os.Args[1:] by default, if desired, can be overridden // SetArgs sets arguments for the command. It is set to os.Args[1:] by default, if desired, can be overridden
@ -205,8 +210,28 @@ func (c *Command) SetArgs(a []string) {
// SetOutput sets the destination for usage and error messages. // SetOutput sets the destination for usage and error messages.
// If output is nil, os.Stderr is used. // If output is nil, os.Stderr is used.
// Deprecated: Use SetOut and/or SetErr instead
func (c *Command) SetOutput(output io.Writer) { func (c *Command) SetOutput(output io.Writer) {
c.output = output c.outWriter = output
c.errWriter = output
}
// SetOut sets the destination for usage messages.
// If newOut is nil, os.Stdout is used.
func (c *Command) SetOut(newOut io.Writer) {
c.outWriter = newOut
}
// SetErr sets the destination for error messages.
// If newErr is nil, os.Stderr is used.
func (c *Command) SetErr(newErr io.Writer) {
c.errWriter = newErr
}
// SetOut sets the source for input data
// If newIn is nil, os.Stdin is used.
func (c *Command) SetIn(newIn io.Reader) {
c.inReader = newIn
} }
// SetUsageFunc sets usage function. Usage can be defined by application. // SetUsageFunc sets usage function. Usage can be defined by application.
@ -267,9 +292,19 @@ func (c *Command) OutOrStderr() io.Writer {
return c.getOut(os.Stderr) return c.getOut(os.Stderr)
} }
// ErrOrStderr returns output to stderr
func (c *Command) ErrOrStderr() io.Writer {
return c.getErr(os.Stderr)
}
// ErrOrStderr returns output to stderr
func (c *Command) InOrStdin() io.Reader {
return c.getIn(os.Stdin)
}
func (c *Command) getOut(def io.Writer) io.Writer { func (c *Command) getOut(def io.Writer) io.Writer {
if c.output != nil { if c.outWriter != nil {
return c.output return c.outWriter
} }
if c.HasParent() { if c.HasParent() {
return c.parent.getOut(def) return c.parent.getOut(def)
@ -277,6 +312,26 @@ func (c *Command) getOut(def io.Writer) io.Writer {
return def return def
} }
func (c *Command) getErr(def io.Writer) io.Writer {
if c.errWriter != nil {
return c.errWriter
}
if c.HasParent() {
return c.parent.getErr(def)
}
return def
}
func (c *Command) getIn(def io.Reader) io.Reader {
if c.inReader != nil {
return c.inReader
}
if c.HasParent() {
return c.parent.getIn(def)
}
return def
}
// UsageFunc returns either the function set by SetUsageFunc for this command // UsageFunc returns either the function set by SetUsageFunc for this command
// or a parent, or it returns a default usage function. // or a parent, or it returns a default usage function.
func (c *Command) UsageFunc() (f func(*Command) error) { func (c *Command) UsageFunc() (f func(*Command) error) {
@ -329,13 +384,22 @@ func (c *Command) Help() error {
return nil return nil
} }
// UsageString return usage string. // UsageString returns usage string.
func (c *Command) UsageString() string { func (c *Command) UsageString() string {
tmpOutput := c.output // Storing normal writers
tmpOutput := c.outWriter
tmpErr := c.errWriter
bb := new(bytes.Buffer) bb := new(bytes.Buffer)
c.SetOutput(bb) c.outWriter = bb
c.errWriter = bb
c.Usage() c.Usage()
c.output = tmpOutput
// Setting things back to normal
c.outWriter = tmpOutput
c.errWriter = tmpErr
return bb.String() return bb.String()
} }
@ -1068,6 +1132,21 @@ func (c *Command) Printf(format string, i ...interface{}) {
c.Print(fmt.Sprintf(format, i...)) c.Print(fmt.Sprintf(format, i...))
} }
// PrintErr is a convenience method to Print to the defined Err output, fallback to Stderr if not set.
func (c *Command) PrintErr(i ...interface{}) {
fmt.Fprint(c.ErrOrStderr(), i...)
}
// PrintErrln is a convenience method to Println to the defined Err output, fallback to Stderr if not set.
func (c *Command) PrintErrln(i ...interface{}) {
c.Print(fmt.Sprintln(i...))
}
// PrintErrf is a convenience method to Printf to the defined Err output, fallback to Stderr if not set.
func (c *Command) PrintErrf(format string, i ...interface{}) {
c.Print(fmt.Sprintf(format, i...))
}
// CommandPath returns the full path to this command. // CommandPath returns the full path to this command.
func (c *Command) CommandPath() string { func (c *Command) CommandPath() string {
if c.HasParent() { if c.HasParent() {

View file

@ -1381,6 +1381,46 @@ func TestSetOutput(t *testing.T) {
} }
} }
func TestSetOut(t *testing.T) {
c := &Command{}
c.SetOut(nil)
if out := c.OutOrStdout(); out != os.Stdout {
t.Errorf("Expected setting output to nil to revert back to stdout")
}
}
func TestSetErr(t *testing.T) {
c := &Command{}
c.SetErr(nil)
if out := c.ErrOrStderr(); out != os.Stderr {
t.Errorf("Expected setting error to nil to revert back to stderr")
}
}
func TestSetIn(t *testing.T) {
c := &Command{}
c.SetIn(nil)
if out := c.InOrStdin(); out != os.Stdin {
t.Errorf("Expected setting input to nil to revert back to stdin")
}
}
func TestUsageStringRedirected(t *testing.T) {
c := &Command{}
c.usageFunc = func(cmd *Command) error {
cmd.Print("[stdout1]")
cmd.PrintErr("[stderr2]")
cmd.Print("[stdout3]")
return nil
}
expected := "[stdout1][stderr2][stdout3]"
if got := c.UsageString(); got != expected {
t.Errorf("Expected usage string to consider both stdout and stderr")
}
}
func TestFlagErrorFunc(t *testing.T) { func TestFlagErrorFunc(t *testing.T) {
c := &Command{Use: "c", Run: emptyRun} c := &Command{Use: "c", Run: emptyRun}

100
powershell_completions.go Normal file
View file

@ -0,0 +1,100 @@
// PowerShell completions are based on the amazing work from clap:
// https://github.com/clap-rs/clap/blob/3294d18efe5f264d12c9035f404c7d189d4824e1/src/completions/powershell.rs
//
// The generated scripts require PowerShell v5.0+ (which comes Windows 10, but
// can be downloaded separately for windows 7 or 8.1).
package cobra
import (
"bytes"
"fmt"
"io"
"os"
"strings"
"github.com/spf13/pflag"
)
var powerShellCompletionTemplate = `using namespace System.Management.Automation
using namespace System.Management.Automation.Language
Register-ArgumentCompleter -Native -CommandName '%s' -ScriptBlock {
param($wordToComplete, $commandAst, $cursorPosition)
$commandElements = $commandAst.CommandElements
$command = @(
'%s'
for ($i = 1; $i -lt $commandElements.Count; $i++) {
$element = $commandElements[$i]
if ($element -isnot [StringConstantExpressionAst] -or
$element.StringConstantType -ne [StringConstantType]::BareWord -or
$element.Value.StartsWith('-')) {
break
}
$element.Value
}
) -join ';'
$completions = @(switch ($command) {%s
})
$completions.Where{ $_.CompletionText -like "$wordToComplete*" } |
Sort-Object -Property ListItemText
}`
func generatePowerShellSubcommandCases(out io.Writer, cmd *Command, previousCommandName string) {
var cmdName string
if previousCommandName == "" {
cmdName = cmd.Name()
} else {
cmdName = fmt.Sprintf("%s;%s", previousCommandName, cmd.Name())
}
fmt.Fprintf(out, "\n '%s' {", cmdName)
cmd.Flags().VisitAll(func(flag *pflag.Flag) {
if nonCompletableFlag(flag) {
return
}
usage := escapeStringForPowerShell(flag.Usage)
if len(flag.Shorthand) > 0 {
fmt.Fprintf(out, "\n [CompletionResult]::new('-%s', '%s', [CompletionResultType]::ParameterName, '%s')", flag.Shorthand, flag.Shorthand, usage)
}
fmt.Fprintf(out, "\n [CompletionResult]::new('--%s', '%s', [CompletionResultType]::ParameterName, '%s')", flag.Name, flag.Name, usage)
})
for _, subCmd := range cmd.Commands() {
usage := escapeStringForPowerShell(subCmd.Short)
fmt.Fprintf(out, "\n [CompletionResult]::new('%s', '%s', [CompletionResultType]::ParameterValue, '%s')", subCmd.Name(), subCmd.Name(), usage)
}
fmt.Fprint(out, "\n break\n }")
for _, subCmd := range cmd.Commands() {
generatePowerShellSubcommandCases(out, subCmd, cmdName)
}
}
func escapeStringForPowerShell(s string) string {
return strings.Replace(s, "'", "''", -1)
}
// GenPowerShellCompletion generates PowerShell completion file and writes to the passed writer.
func (c *Command) GenPowerShellCompletion(w io.Writer) error {
buf := new(bytes.Buffer)
var subCommandCases bytes.Buffer
generatePowerShellSubcommandCases(&subCommandCases, c, "")
fmt.Fprintf(buf, powerShellCompletionTemplate, c.Name(), c.Name(), subCommandCases.String())
_, err := buf.WriteTo(w)
return err
}
// GenPowerShellCompletionFile generates PowerShell completion file.
func (c *Command) GenPowerShellCompletionFile(filename string) error {
outFile, err := os.Create(filename)
if err != nil {
return err
}
defer outFile.Close()
return c.GenPowerShellCompletion(outFile)
}

14
powershell_completions.md Normal file
View file

@ -0,0 +1,14 @@
# Generating PowerShell Completions For Your Own cobra.Command
Cobra can generate PowerShell completion scripts. Users need PowerShell version 5.0 or above, which comes with Windows 10 and can be downloaded separately for Windows 7 or 8.1. They can then write the completions to a file and source this file from their PowerShell profile, which is referenced by the `$Profile` environment variable. See `Get-Help about_Profiles` for more info about PowerShell profiles.
# What's supported
- Completion for subcommands using their `.Short` description
- Completion for non-hidden flags using their `.Name` and `.Shorthand`
# What's not yet supported
- Command aliases
- Required, filename or custom flags (they will work like normal flags)
- Custom completion scripts

View file

@ -0,0 +1,122 @@
package cobra
import (
"bytes"
"strings"
"testing"
)
func TestPowerShellCompletion(t *testing.T) {
tcs := []struct {
name string
root *Command
expectedExpressions []string
}{
{
name: "trivial",
root: &Command{Use: "trivialapp"},
expectedExpressions: []string{
"Register-ArgumentCompleter -Native -CommandName 'trivialapp' -ScriptBlock",
"$command = @(\n 'trivialapp'\n",
},
},
{
name: "tree",
root: func() *Command {
r := &Command{Use: "tree"}
sub1 := &Command{Use: "sub1"}
r.AddCommand(sub1)
sub11 := &Command{Use: "sub11"}
sub12 := &Command{Use: "sub12"}
sub1.AddCommand(sub11)
sub1.AddCommand(sub12)
sub2 := &Command{Use: "sub2"}
r.AddCommand(sub2)
sub21 := &Command{Use: "sub21"}
sub22 := &Command{Use: "sub22"}
sub2.AddCommand(sub21)
sub2.AddCommand(sub22)
return r
}(),
expectedExpressions: []string{
"'tree'",
"[CompletionResult]::new('sub1', 'sub1', [CompletionResultType]::ParameterValue, '')",
"[CompletionResult]::new('sub2', 'sub2', [CompletionResultType]::ParameterValue, '')",
"'tree;sub1'",
"[CompletionResult]::new('sub11', 'sub11', [CompletionResultType]::ParameterValue, '')",
"[CompletionResult]::new('sub12', 'sub12', [CompletionResultType]::ParameterValue, '')",
"'tree;sub1;sub11'",
"'tree;sub1;sub12'",
"'tree;sub2'",
"[CompletionResult]::new('sub21', 'sub21', [CompletionResultType]::ParameterValue, '')",
"[CompletionResult]::new('sub22', 'sub22', [CompletionResultType]::ParameterValue, '')",
"'tree;sub2;sub21'",
"'tree;sub2;sub22'",
},
},
{
name: "flags",
root: func() *Command {
r := &Command{Use: "flags"}
r.Flags().StringP("flag1", "a", "", "")
r.Flags().String("flag2", "", "")
sub1 := &Command{Use: "sub1"}
sub1.Flags().StringP("flag3", "c", "", "")
r.AddCommand(sub1)
return r
}(),
expectedExpressions: []string{
"'flags'",
"[CompletionResult]::new('-a', 'a', [CompletionResultType]::ParameterName, '')",
"[CompletionResult]::new('--flag1', 'flag1', [CompletionResultType]::ParameterName, '')",
"[CompletionResult]::new('--flag2', 'flag2', [CompletionResultType]::ParameterName, '')",
"[CompletionResult]::new('sub1', 'sub1', [CompletionResultType]::ParameterValue, '')",
"'flags;sub1'",
"[CompletionResult]::new('-c', 'c', [CompletionResultType]::ParameterName, '')",
"[CompletionResult]::new('--flag3', 'flag3', [CompletionResultType]::ParameterName, '')",
},
},
{
name: "usage",
root: func() *Command {
r := &Command{Use: "usage"}
r.Flags().String("flag", "", "this describes the usage of the 'flag' flag")
sub1 := &Command{
Use: "sub1",
Short: "short describes 'sub1'",
}
r.AddCommand(sub1)
return r
}(),
expectedExpressions: []string{
"[CompletionResult]::new('--flag', 'flag', [CompletionResultType]::ParameterName, 'this describes the usage of the ''flag'' flag')",
"[CompletionResult]::new('sub1', 'sub1', [CompletionResultType]::ParameterValue, 'short describes ''sub1''')",
},
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
buf := new(bytes.Buffer)
tc.root.GenPowerShellCompletion(buf)
output := buf.String()
for _, expectedExpression := range tc.expectedExpressions {
if !strings.Contains(output, expectedExpression) {
t.Errorf("Expected completion to contain %q somewhere; got %q", expectedExpression, output)
}
}
})
}
}

85
shell_completions.go Normal file
View file

@ -0,0 +1,85 @@
package cobra
import (
"github.com/spf13/pflag"
)
// MarkFlagRequired adds the BashCompOneRequiredFlag annotation to the named flag if it exists,
// and causes your command to report an error if invoked without the flag.
func (c *Command) MarkFlagRequired(name string) error {
return MarkFlagRequired(c.Flags(), name)
}
// MarkPersistentFlagRequired adds the BashCompOneRequiredFlag annotation to the named persistent flag if it exists,
// and causes your command to report an error if invoked without the flag.
func (c *Command) MarkPersistentFlagRequired(name string) error {
return MarkFlagRequired(c.PersistentFlags(), name)
}
// MarkFlagRequired adds the BashCompOneRequiredFlag annotation to the named flag if it exists,
// and causes your command to report an error if invoked without the flag.
func MarkFlagRequired(flags *pflag.FlagSet, name string) error {
return flags.SetAnnotation(name, BashCompOneRequiredFlag, []string{"true"})
}
// MarkFlagFilename adds the BashCompFilenameExt annotation to the named flag, if it exists.
// Generated bash autocompletion will select filenames for the flag, limiting to named extensions if provided.
func (c *Command) MarkFlagFilename(name string, extensions ...string) error {
return MarkFlagFilename(c.Flags(), name, extensions...)
}
// MarkFlagCustom adds the BashCompCustom annotation to the named flag, if it exists.
// Generated bash autocompletion will call the bash function f for the flag.
func (c *Command) MarkFlagCustom(name string, f string) error {
return MarkFlagCustom(c.Flags(), name, f)
}
// MarkPersistentFlagFilename instructs the various shell completion
// implementations to limit completions for this persistent flag to the
// specified extensions (patterns).
//
// Shell Completion compatibility matrix: bash, zsh
func (c *Command) MarkPersistentFlagFilename(name string, extensions ...string) error {
return MarkFlagFilename(c.PersistentFlags(), name, extensions...)
}
// MarkFlagFilename instructs the various shell completion implementations to
// limit completions for this flag to the specified extensions (patterns).
//
// Shell Completion compatibility matrix: bash, zsh
func MarkFlagFilename(flags *pflag.FlagSet, name string, extensions ...string) error {
return flags.SetAnnotation(name, BashCompFilenameExt, extensions)
}
// MarkFlagCustom instructs the various shell completion implementations to
// limit completions for this flag to the specified extensions (patterns).
//
// Shell Completion compatibility matrix: bash, zsh
func MarkFlagCustom(flags *pflag.FlagSet, name string, f string) error {
return flags.SetAnnotation(name, BashCompCustom, []string{f})
}
// MarkFlagDirname instructs the various shell completion implementations to
// complete only directories with this named flag.
//
// Shell Completion compatibility matrix: zsh
func (c *Command) MarkFlagDirname(name string) error {
return MarkFlagDirname(c.Flags(), name)
}
// MarkPersistentFlagDirname instructs the various shell completion
// implementations to complete only directories with this persistent named flag.
//
// Shell Completion compatibility matrix: zsh
func (c *Command) MarkPersistentFlagDirname(name string) error {
return MarkFlagDirname(c.PersistentFlags(), name)
}
// MarkFlagDirname instructs the various shell completion implementations to
// complete only directories with this specified flag.
//
// Shell Completion compatibility matrix: zsh
func MarkFlagDirname(flags *pflag.FlagSet, name string) error {
zshPattern := "-(/)"
return flags.SetAnnotation(name, zshCompDirname, []string{zshPattern})
}

View file

@ -1,13 +1,102 @@
package cobra package cobra
import ( import (
"bytes" "encoding/json"
"fmt" "fmt"
"io" "io"
"os" "os"
"sort"
"strings" "strings"
"text/template"
"github.com/spf13/pflag"
) )
const (
zshCompArgumentAnnotation = "cobra_annotations_zsh_completion_argument_annotation"
zshCompArgumentFilenameComp = "cobra_annotations_zsh_completion_argument_file_completion"
zshCompArgumentWordComp = "cobra_annotations_zsh_completion_argument_word_completion"
zshCompDirname = "cobra_annotations_zsh_dirname"
)
var (
zshCompFuncMap = template.FuncMap{
"genZshFuncName": zshCompGenFuncName,
"extractFlags": zshCompExtractFlag,
"genFlagEntryForZshArguments": zshCompGenFlagEntryForArguments,
"extractArgsCompletions": zshCompExtractArgumentCompletionHintsForRendering,
}
zshCompletionText = `
{{/* should accept Command (that contains subcommands) as parameter */}}
{{define "argumentsC" -}}
{{ $cmdPath := genZshFuncName .}}
function {{$cmdPath}} {
local -a commands
_arguments -C \{{- range extractFlags .}}
{{genFlagEntryForZshArguments .}} \{{- end}}
"1: :->cmnds" \
"*::arg:->args"
case $state in
cmnds)
commands=({{range .Commands}}{{if not .Hidden}}
"{{.Name}}:{{.Short}}"{{end}}{{end}}
)
_describe "command" commands
;;
esac
case "$words[1]" in {{- range .Commands}}{{if not .Hidden}}
{{.Name}})
{{$cmdPath}}_{{.Name}}
;;{{end}}{{end}}
esac
}
{{range .Commands}}{{if not .Hidden}}
{{template "selectCmdTemplate" .}}
{{- end}}{{end}}
{{- end}}
{{/* should accept Command without subcommands as parameter */}}
{{define "arguments" -}}
function {{genZshFuncName .}} {
{{" _arguments"}}{{range extractFlags .}} \
{{genFlagEntryForZshArguments . -}}
{{end}}{{range extractArgsCompletions .}} \
{{.}}{{end}}
}
{{end}}
{{/* dispatcher for commands with or without subcommands */}}
{{define "selectCmdTemplate" -}}
{{if .Hidden}}{{/* ignore hidden*/}}{{else -}}
{{if .Commands}}{{template "argumentsC" .}}{{else}}{{template "arguments" .}}{{end}}
{{- end}}
{{- end}}
{{/* template entry point */}}
{{define "Main" -}}
#compdef _{{.Name}} {{.Name}}
{{template "selectCmdTemplate" .}}
{{end}}
`
)
// zshCompArgsAnnotation is used to encode/decode zsh completion for
// arguments to/from Command.Annotations.
type zshCompArgsAnnotation map[int]zshCompArgHint
type zshCompArgHint struct {
// Indicates the type of the completion to use. One of:
// zshCompArgumentFilenameComp or zshCompArgumentWordComp
Tipe string `json:"type"`
// A value for the type above (globs for file completion or words)
Options []string `json:"options"`
}
// GenZshCompletionFile generates zsh completion file. // GenZshCompletionFile generates zsh completion file.
func (c *Command) GenZshCompletionFile(filename string) error { func (c *Command) GenZshCompletionFile(filename string) error {
outFile, err := os.Create(filename) outFile, err := os.Create(filename)
@ -19,108 +108,229 @@ func (c *Command) GenZshCompletionFile(filename string) error {
return c.GenZshCompletion(outFile) return c.GenZshCompletion(outFile)
} }
// GenZshCompletion generates a zsh completion file and writes to the passed writer. // GenZshCompletion generates a zsh completion file and writes to the passed
// writer. The completion always run on the root command regardless of the
// command it was called from.
func (c *Command) GenZshCompletion(w io.Writer) error { func (c *Command) GenZshCompletion(w io.Writer) error {
buf := new(bytes.Buffer) tmpl, err := template.New("Main").Funcs(zshCompFuncMap).Parse(zshCompletionText)
if err != nil {
writeHeader(buf, c) return fmt.Errorf("error creating zsh completion template: %v", err)
maxDepth := maxDepth(c)
writeLevelMapping(buf, maxDepth)
writeLevelCases(buf, maxDepth, c)
_, err := buf.WriteTo(w)
return err
}
func writeHeader(w io.Writer, cmd *Command) {
fmt.Fprintf(w, "#compdef %s\n\n", cmd.Name())
}
func maxDepth(c *Command) int {
if len(c.Commands()) == 0 {
return 0
} }
maxDepthSub := 0 return tmpl.Execute(w, c.Root())
for _, s := range c.Commands() { }
subDepth := maxDepth(s)
if subDepth > maxDepthSub { // MarkZshCompPositionalArgumentFile marks the specified argument (first
maxDepthSub = subDepth // argument is 1) as completed by file selection. patterns (e.g. "*.txt") are
// optional - if not provided the completion will search for all files.
func (c *Command) MarkZshCompPositionalArgumentFile(argPosition int, patterns ...string) error {
if argPosition < 1 {
return fmt.Errorf("Invalid argument position (%d)", argPosition)
}
annotation, err := c.zshCompGetArgsAnnotations()
if err != nil {
return err
}
if c.zshcompArgsAnnotationnIsDuplicatePosition(annotation, argPosition) {
return fmt.Errorf("Duplicate annotation for positional argument at index %d", argPosition)
}
annotation[argPosition] = zshCompArgHint{
Tipe: zshCompArgumentFilenameComp,
Options: patterns,
}
return c.zshCompSetArgsAnnotations(annotation)
}
// MarkZshCompPositionalArgumentWords marks the specified positional argument
// (first argument is 1) as completed by the provided words. At east one word
// must be provided, spaces within words will be offered completion with
// "word\ word".
func (c *Command) MarkZshCompPositionalArgumentWords(argPosition int, words ...string) error {
if argPosition < 1 {
return fmt.Errorf("Invalid argument position (%d)", argPosition)
}
if len(words) == 0 {
return fmt.Errorf("Trying to set empty word list for positional argument %d", argPosition)
}
annotation, err := c.zshCompGetArgsAnnotations()
if err != nil {
return err
}
if c.zshcompArgsAnnotationnIsDuplicatePosition(annotation, argPosition) {
return fmt.Errorf("Duplicate annotation for positional argument at index %d", argPosition)
}
annotation[argPosition] = zshCompArgHint{
Tipe: zshCompArgumentWordComp,
Options: words,
}
return c.zshCompSetArgsAnnotations(annotation)
}
func zshCompExtractArgumentCompletionHintsForRendering(c *Command) ([]string, error) {
var result []string
annotation, err := c.zshCompGetArgsAnnotations()
if err != nil {
return nil, err
}
for k, v := range annotation {
s, err := zshCompRenderZshCompArgHint(k, v)
if err != nil {
return nil, err
}
result = append(result, s)
}
if len(c.ValidArgs) > 0 {
if _, positionOneExists := annotation[1]; !positionOneExists {
s, err := zshCompRenderZshCompArgHint(1, zshCompArgHint{
Tipe: zshCompArgumentWordComp,
Options: c.ValidArgs,
})
if err != nil {
return nil, err
}
result = append(result, s)
} }
} }
return 1 + maxDepthSub sort.Strings(result)
return result, nil
} }
func writeLevelMapping(w io.Writer, numLevels int) { func zshCompRenderZshCompArgHint(i int, z zshCompArgHint) (string, error) {
fmt.Fprintln(w, `_arguments \`) switch t := z.Tipe; t {
for i := 1; i <= numLevels; i++ { case zshCompArgumentFilenameComp:
fmt.Fprintf(w, ` '%d: :->level%d' \`, i, i) var globs []string
fmt.Fprintln(w) for _, g := range z.Options {
} globs = append(globs, fmt.Sprintf(`-g "%s"`, g))
fmt.Fprintf(w, ` '%d: :%s'`, numLevels+1, "_files")
fmt.Fprintln(w)
}
func writeLevelCases(w io.Writer, maxDepth int, root *Command) {
fmt.Fprintln(w, "case $state in")
defer fmt.Fprintln(w, "esac")
for i := 1; i <= maxDepth; i++ {
fmt.Fprintf(w, " level%d)\n", i)
writeLevel(w, root, i)
fmt.Fprintln(w, " ;;")
}
fmt.Fprintln(w, " *)")
fmt.Fprintln(w, " _arguments '*: :_files'")
fmt.Fprintln(w, " ;;")
}
func writeLevel(w io.Writer, root *Command, i int) {
fmt.Fprintf(w, " case $words[%d] in\n", i)
defer fmt.Fprintln(w, " esac")
commands := filterByLevel(root, i)
byParent := groupByParent(commands)
for p, c := range byParent {
names := names(c)
fmt.Fprintf(w, " %s)\n", p)
fmt.Fprintf(w, " _arguments '%d: :(%s)'\n", i, strings.Join(names, " "))
fmt.Fprintln(w, " ;;")
}
fmt.Fprintln(w, " *)")
fmt.Fprintln(w, " _arguments '*: :_files'")
fmt.Fprintln(w, " ;;")
}
func filterByLevel(c *Command, l int) []*Command {
cs := make([]*Command, 0)
if l == 0 {
cs = append(cs, c)
return cs
}
for _, s := range c.Commands() {
cs = append(cs, filterByLevel(s, l-1)...)
}
return cs
}
func groupByParent(commands []*Command) map[string][]*Command {
m := make(map[string][]*Command)
for _, c := range commands {
parent := c.Parent()
if parent == nil {
continue
} }
m[parent.Name()] = append(m[parent.Name()], c) return fmt.Sprintf(`'%d: :_files %s'`, i, strings.Join(globs, " ")), nil
case zshCompArgumentWordComp:
var words []string
for _, w := range z.Options {
words = append(words, fmt.Sprintf("%q", w))
}
return fmt.Sprintf(`'%d: :(%s)'`, i, strings.Join(words, " ")), nil
default:
return "", fmt.Errorf("Invalid zsh argument completion annotation: %s", t)
} }
return m
} }
func names(commands []*Command) []string { func (c *Command) zshcompArgsAnnotationnIsDuplicatePosition(annotation zshCompArgsAnnotation, position int) bool {
ns := make([]string, len(commands)) _, dup := annotation[position]
for i, c := range commands { return dup
ns[i] = c.Name() }
}
return ns func (c *Command) zshCompGetArgsAnnotations() (zshCompArgsAnnotation, error) {
annotation := make(zshCompArgsAnnotation)
annotationString, ok := c.Annotations[zshCompArgumentAnnotation]
if !ok {
return annotation, nil
}
err := json.Unmarshal([]byte(annotationString), &annotation)
if err != nil {
return annotation, fmt.Errorf("Error unmarshaling zsh argument annotation: %v", err)
}
return annotation, nil
}
func (c *Command) zshCompSetArgsAnnotations(annotation zshCompArgsAnnotation) error {
jsn, err := json.Marshal(annotation)
if err != nil {
return fmt.Errorf("Error marshaling zsh argument annotation: %v", err)
}
if c.Annotations == nil {
c.Annotations = make(map[string]string)
}
c.Annotations[zshCompArgumentAnnotation] = string(jsn)
return nil
}
func zshCompGenFuncName(c *Command) string {
if c.HasParent() {
return zshCompGenFuncName(c.Parent()) + "_" + c.Name()
}
return "_" + c.Name()
}
func zshCompExtractFlag(c *Command) []*pflag.Flag {
var flags []*pflag.Flag
c.LocalFlags().VisitAll(func(f *pflag.Flag) {
if !f.Hidden {
flags = append(flags, f)
}
})
c.InheritedFlags().VisitAll(func(f *pflag.Flag) {
if !f.Hidden {
flags = append(flags, f)
}
})
return flags
}
// zshCompGenFlagEntryForArguments returns an entry that matches _arguments
// zsh-completion parameters. It's too complicated to generate in a template.
func zshCompGenFlagEntryForArguments(f *pflag.Flag) string {
if f.Name == "" || f.Shorthand == "" {
return zshCompGenFlagEntryForSingleOptionFlag(f)
}
return zshCompGenFlagEntryForMultiOptionFlag(f)
}
func zshCompGenFlagEntryForSingleOptionFlag(f *pflag.Flag) string {
var option, multiMark, extras string
if zshCompFlagCouldBeSpecifiedMoreThenOnce(f) {
multiMark = "*"
}
option = "--" + f.Name
if option == "--" {
option = "-" + f.Shorthand
}
extras = zshCompGenFlagEntryExtras(f)
return fmt.Sprintf(`'%s%s[%s]%s'`, multiMark, option, zshCompQuoteFlagDescription(f.Usage), extras)
}
func zshCompGenFlagEntryForMultiOptionFlag(f *pflag.Flag) string {
var options, parenMultiMark, curlyMultiMark, extras string
if zshCompFlagCouldBeSpecifiedMoreThenOnce(f) {
parenMultiMark = "*"
curlyMultiMark = "\\*"
}
options = fmt.Sprintf(`'(%s-%s %s--%s)'{%s-%s,%s--%s}`,
parenMultiMark, f.Shorthand, parenMultiMark, f.Name, curlyMultiMark, f.Shorthand, curlyMultiMark, f.Name)
extras = zshCompGenFlagEntryExtras(f)
return fmt.Sprintf(`%s'[%s]%s'`, options, zshCompQuoteFlagDescription(f.Usage), extras)
}
func zshCompGenFlagEntryExtras(f *pflag.Flag) string {
if f.NoOptDefVal != "" {
return ""
}
extras := ":" // allow options for flag (even without assistance)
for key, values := range f.Annotations {
switch key {
case zshCompDirname:
extras = fmt.Sprintf(":filename:_files -g %q", values[0])
case BashCompFilenameExt:
extras = ":filename:_files"
for _, pattern := range values {
extras = extras + fmt.Sprintf(` -g "%s"`, pattern)
}
}
}
return extras
}
func zshCompFlagCouldBeSpecifiedMoreThenOnce(f *pflag.Flag) bool {
return strings.Contains(f.Value.Type(), "Slice") ||
strings.Contains(f.Value.Type(), "Array")
}
func zshCompQuoteFlagDescription(s string) string {
return strings.Replace(s, "'", `'\''`, -1)
} }

39
zsh_completions.md Normal file
View file

@ -0,0 +1,39 @@
## Generating Zsh Completion for your cobra.Command
Cobra supports native Zsh completion generated from the root `cobra.Command`.
The generated completion script should be put somewhere in your `$fpath` named
`_<YOUR COMMAND>`.
### What's Supported
* Completion for all non-hidden subcommands using their `.Short` description.
* Completion for all non-hidden flags using the following rules:
* Filename completion works by marking the flag with `cmd.MarkFlagFilename...`
family of commands.
* The requirement for argument to the flag is decided by the `.NoOptDefVal`
flag value - if it's empty then completion will expect an argument.
* Flags of one of the various `*Array` and `*Slice` types supports multiple
specifications (with or without argument depending on the specific type).
* Completion of positional arguments using the following rules:
* Argument position for all options below starts at `1`. If argument position
`0` is requested it will raise an error.
* Use `command.MarkZshCompPositionalArgumentFile` to complete filenames. Glob
patterns (e.g. `"*.log"`) are optional - if not specified it will offer to
complete all file types.
* Use `command.MarkZshCompPositionalArgumentWords` to offer specific words for
completion. At least one word is required.
* It's possible to specify completion for some arguments and leave some
unspecified (e.g. offer words for second argument but nothing for first
argument). This will cause no completion for first argument but words
completion for second argument.
* If no argument completion was specified for 1st argument (but optionally was
specified for 2nd) and the command has `ValidArgs` it will be used as
completion options for 1st argument.
* Argument completions only offered for commands with no subcommands.
### What's not yet Supported
* Custom completion scripts are not supported yet (We should probably create zsh
specific one, doesn't make sense to re-use the bash one as the functions will
be different).
* Whatever other feature you're looking for and doesn't exist :)

View file

@ -2,88 +2,474 @@ package cobra
import ( import (
"bytes" "bytes"
"regexp"
"strings" "strings"
"testing" "testing"
) )
func TestZshCompletion(t *testing.T) { func TestGenZshCompletion(t *testing.T) {
var debug bool
var option string
tcs := []struct {
name string
root *Command
expectedExpressions []string
invocationArgs []string
skip string
}{
{
name: "simple command",
root: func() *Command {
r := &Command{
Use: "mycommand",
Long: "My Command long description",
Run: emptyRun,
}
r.Flags().BoolVar(&debug, "debug", debug, "description")
return r
}(),
expectedExpressions: []string{
`(?s)function _mycommand {\s+_arguments \\\s+'--debug\[description\]'.*--help.*}`,
"#compdef _mycommand mycommand",
},
},
{
name: "flags with both long and short flags",
root: func() *Command {
r := &Command{
Use: "testcmd",
Long: "long description",
Run: emptyRun,
}
r.Flags().BoolVarP(&debug, "debug", "d", debug, "debug description")
return r
}(),
expectedExpressions: []string{
`'\(-d --debug\)'{-d,--debug}'\[debug description\]'`,
},
},
{
name: "command with subcommands and flags with values",
root: func() *Command {
r := &Command{
Use: "rootcmd",
Long: "Long rootcmd description",
}
d := &Command{
Use: "subcmd1",
Short: "Subcmd1 short description",
Run: emptyRun,
}
e := &Command{
Use: "subcmd2",
Long: "Subcmd2 short description",
Run: emptyRun,
}
r.PersistentFlags().BoolVar(&debug, "debug", debug, "description")
d.Flags().StringVarP(&option, "option", "o", option, "option description")
r.AddCommand(d, e)
return r
}(),
expectedExpressions: []string{
`commands=\(\n\s+"help:.*\n\s+"subcmd1:.*\n\s+"subcmd2:.*\n\s+\)`,
`_arguments \\\n.*'--debug\[description]'`,
`_arguments -C \\\n.*'--debug\[description]'`,
`function _rootcmd_subcmd1 {`,
`function _rootcmd_subcmd1 {`,
`_arguments \\\n.*'\(-o --option\)'{-o,--option}'\[option description]:' \\\n`,
},
},
{
name: "filename completion with and without globs",
root: func() *Command {
var file string
r := &Command{
Use: "mycmd",
Short: "my command short description",
Run: emptyRun,
}
r.Flags().StringVarP(&file, "config", "c", file, "config file")
r.MarkFlagFilename("config")
r.Flags().String("output", "", "output file")
r.MarkFlagFilename("output", "*.log", "*.txt")
return r
}(),
expectedExpressions: []string{
`\n +'\(-c --config\)'{-c,--config}'\[config file]:filename:_files'`,
`:_files -g "\*.log" -g "\*.txt"`,
},
},
{
name: "repeated variables both with and without value",
root: func() *Command {
r := genTestCommand("mycmd", true)
_ = r.Flags().BoolSliceP("debug", "d", []bool{}, "debug usage")
_ = r.Flags().StringArray("option", []string{}, "options")
return r
}(),
expectedExpressions: []string{
`'\*--option\[options]`,
`'\(\*-d \*--debug\)'{\\\*-d,\\\*--debug}`,
},
},
{
name: "generated flags --help and --version should be created even when not executing root cmd",
root: func() *Command {
r := &Command{
Use: "mycmd",
Short: "mycmd short description",
Version: "myversion",
}
s := genTestCommand("sub1", true)
r.AddCommand(s)
return s
}(),
expectedExpressions: []string{
"--version",
"--help",
},
invocationArgs: []string{
"sub1",
},
skip: "--version and --help are currently not generated when not running on root command",
},
{
name: "zsh generation should run on root command",
root: func() *Command {
r := genTestCommand("root", false)
s := genTestCommand("sub1", true)
r.AddCommand(s)
return s
}(),
expectedExpressions: []string{
"function _root {",
},
},
{
name: "flag description with single quote (') shouldn't break quotes in completion file",
root: func() *Command {
r := genTestCommand("root", true)
r.Flags().Bool("private", false, "Don't show public info")
return r
}(),
expectedExpressions: []string{
`--private\[Don'\\''t show public info]`,
},
},
{
name: "argument completion for file with and without patterns",
root: func() *Command {
r := genTestCommand("root", true)
r.MarkZshCompPositionalArgumentFile(1, "*.log")
r.MarkZshCompPositionalArgumentFile(2)
return r
}(),
expectedExpressions: []string{
`'1: :_files -g "\*.log"' \\\n\s+'2: :_files`,
},
},
{
name: "argument zsh completion for words",
root: func() *Command {
r := genTestCommand("root", true)
r.MarkZshCompPositionalArgumentWords(1, "word1", "word2")
return r
}(),
expectedExpressions: []string{
`'1: :\("word1" "word2"\)`,
},
},
{
name: "argument completion for words with spaces",
root: func() *Command {
r := genTestCommand("root", true)
r.MarkZshCompPositionalArgumentWords(1, "single", "multiple words")
return r
}(),
expectedExpressions: []string{
`'1: :\("single" "multiple words"\)'`,
},
},
{
name: "argument completion when command has ValidArgs and no annotation for argument completion",
root: func() *Command {
r := genTestCommand("root", true)
r.ValidArgs = []string{"word1", "word2"}
return r
}(),
expectedExpressions: []string{
`'1: :\("word1" "word2"\)'`,
},
},
{
name: "argument completion when command has ValidArgs and no annotation for argument at argPosition 1",
root: func() *Command {
r := genTestCommand("root", true)
r.ValidArgs = []string{"word1", "word2"}
r.MarkZshCompPositionalArgumentFile(2)
return r
}(),
expectedExpressions: []string{
`'1: :\("word1" "word2"\)' \\`,
},
},
{
name: "directory completion for flag",
root: func() *Command {
r := genTestCommand("root", true)
r.Flags().String("test", "", "test")
r.PersistentFlags().String("ptest", "", "ptest")
r.MarkFlagDirname("test")
r.MarkPersistentFlagDirname("ptest")
return r
}(),
expectedExpressions: []string{
`--test\[test]:filename:_files -g "-\(/\)"`,
`--ptest\[ptest]:filename:_files -g "-\(/\)"`,
},
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
if tc.skip != "" {
t.Skip(tc.skip)
}
tc.root.Root().SetArgs(tc.invocationArgs)
tc.root.Execute()
buf := new(bytes.Buffer)
if err := tc.root.GenZshCompletion(buf); err != nil {
t.Error(err)
}
output := buf.Bytes()
for _, expr := range tc.expectedExpressions {
rgx, err := regexp.Compile(expr)
if err != nil {
t.Errorf("error compiling expression (%s): %v", expr, err)
}
if !rgx.Match(output) {
t.Errorf("expected completion (%s) to match '%s'", buf.String(), expr)
}
}
})
}
}
func TestGenZshCompletionHidden(t *testing.T) {
tcs := []struct { tcs := []struct {
name string name string
root *Command root *Command
expectedExpressions []string expectedExpressions []string
}{ }{
{ {
name: "trivial", name: "hidden commands",
root: &Command{Use: "trivialapp"},
expectedExpressions: []string{"#compdef trivial"},
},
{
name: "linear",
root: func() *Command { root: func() *Command {
r := &Command{Use: "linear"} r := &Command{
Use: "main",
sub1 := &Command{Use: "sub1"} Short: "main short description",
r.AddCommand(sub1) }
s1 := &Command{
sub2 := &Command{Use: "sub2"} Use: "sub1",
sub1.AddCommand(sub2) Hidden: true,
Run: emptyRun,
sub3 := &Command{Use: "sub3"} }
sub2.AddCommand(sub3) s2 := &Command{
return r Use: "sub2",
}(), Short: "short sub2 description",
expectedExpressions: []string{"sub1", "sub2", "sub3"}, Run: emptyRun,
}, }
{ r.AddCommand(s1, s2)
name: "flat",
root: func() *Command {
r := &Command{Use: "flat"}
r.AddCommand(&Command{Use: "c1"})
r.AddCommand(&Command{Use: "c2"})
return r
}(),
expectedExpressions: []string{"(c1 c2)"},
},
{
name: "tree",
root: func() *Command {
r := &Command{Use: "tree"}
sub1 := &Command{Use: "sub1"}
r.AddCommand(sub1)
sub11 := &Command{Use: "sub11"}
sub12 := &Command{Use: "sub12"}
sub1.AddCommand(sub11)
sub1.AddCommand(sub12)
sub2 := &Command{Use: "sub2"}
r.AddCommand(sub2)
sub21 := &Command{Use: "sub21"}
sub22 := &Command{Use: "sub22"}
sub2.AddCommand(sub21)
sub2.AddCommand(sub22)
return r return r
}(), }(),
expectedExpressions: []string{"(sub11 sub12)", "(sub21 sub22)"}, expectedExpressions: []string{
"sub1",
},
},
{
name: "hidden flags",
root: func() *Command {
var hidden string
r := &Command{
Use: "root",
Short: "root short description",
Run: emptyRun,
}
r.Flags().StringVarP(&hidden, "hidden", "H", hidden, "hidden usage")
if err := r.Flags().MarkHidden("hidden"); err != nil {
t.Errorf("Error setting flag hidden: %v\n", err)
}
return r
}(),
expectedExpressions: []string{
"--hidden",
},
}, },
} }
for _, tc := range tcs { for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
tc.root.Execute()
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
tc.root.GenZshCompletion(buf) if err := tc.root.GenZshCompletion(buf); err != nil {
t.Error(err)
}
output := buf.String() output := buf.String()
for _, expectedExpression := range tc.expectedExpressions { for _, expr := range tc.expectedExpressions {
if !strings.Contains(output, expectedExpression) { if strings.Contains(output, expr) {
t.Errorf("Expected completion to contain %q somewhere; got %q", expectedExpression, output) t.Errorf("Expected completion (%s) not to contain '%s' but it does", output, expr)
} }
} }
}) })
} }
} }
func TestMarkZshCompPositionalArgumentFile(t *testing.T) {
t.Run("Doesn't allow overwriting existing positional argument", func(t *testing.T) {
c := &Command{}
if err := c.MarkZshCompPositionalArgumentFile(1, "*.log"); err != nil {
t.Errorf("Received error when we shouldn't have: %v\n", err)
}
if err := c.MarkZshCompPositionalArgumentFile(1); err == nil {
t.Error("Didn't receive an error when trying to overwrite argument position")
}
})
t.Run("Refuses to accept argPosition less then 1", func(t *testing.T) {
c := &Command{}
err := c.MarkZshCompPositionalArgumentFile(0, "*")
if err == nil {
t.Fatal("Error was not thrown when indicating argument position 0")
}
if !strings.Contains(err.Error(), "position") {
t.Errorf("expected error message '%s' to contain 'position'", err.Error())
}
})
}
func TestMarkZshCompPositionalArgumentWords(t *testing.T) {
t.Run("Doesn't allow overwriting existing positional argument", func(t *testing.T) {
c := &Command{}
if err := c.MarkZshCompPositionalArgumentFile(1, "*.log"); err != nil {
t.Errorf("Received error when we shouldn't have: %v\n", err)
}
if err := c.MarkZshCompPositionalArgumentWords(1, "hello"); err == nil {
t.Error("Didn't receive an error when trying to overwrite argument position")
}
})
t.Run("Doesn't allow calling without words", func(t *testing.T) {
c := &Command{}
if err := c.MarkZshCompPositionalArgumentWords(0); err == nil {
t.Error("Should not allow saving empty word list for annotation")
}
})
t.Run("Refuses to accept argPosition less then 1", func(t *testing.T) {
c := &Command{}
err := c.MarkZshCompPositionalArgumentWords(0, "word")
if err == nil {
t.Fatal("Should not allow setting argument position less then 1")
}
if !strings.Contains(err.Error(), "position") {
t.Errorf("Expected error '%s' to contain 'position' but didn't", err.Error())
}
})
}
func BenchmarkMediumSizeConstruct(b *testing.B) {
root := constructLargeCommandHierarchy()
// if err := root.GenZshCompletionFile("_mycmd"); err != nil {
// b.Error(err)
// }
for i := 0; i < b.N; i++ {
buf := new(bytes.Buffer)
err := root.GenZshCompletion(buf)
if err != nil {
b.Error(err)
}
}
}
func TestExtractFlags(t *testing.T) {
var debug, cmdc, cmdd bool
c := &Command{
Use: "cmdC",
Long: "Command C",
}
c.PersistentFlags().BoolVarP(&debug, "debug", "d", debug, "debug mode")
c.Flags().BoolVar(&cmdc, "cmd-c", cmdc, "Command C")
d := &Command{
Use: "CmdD",
Long: "Command D",
}
d.Flags().BoolVar(&cmdd, "cmd-d", cmdd, "Command D")
c.AddCommand(d)
resC := zshCompExtractFlag(c)
resD := zshCompExtractFlag(d)
if len(resC) != 2 {
t.Errorf("expected Command C to return 2 flags, got %d", len(resC))
}
if len(resD) != 2 {
t.Errorf("expected Command D to return 2 flags, got %d", len(resD))
}
}
func constructLargeCommandHierarchy() *Command {
var config, st1, st2 string
var long, debug bool
var in1, in2 int
var verbose []bool
r := genTestCommand("mycmd", false)
r.PersistentFlags().StringVarP(&config, "config", "c", config, "config usage")
if err := r.MarkPersistentFlagFilename("config", "*"); err != nil {
panic(err)
}
s1 := genTestCommand("sub1", true)
s1.Flags().BoolVar(&long, "long", long, "long description")
s1.Flags().BoolSliceVar(&verbose, "verbose", verbose, "verbose description")
s1.Flags().StringArray("option", []string{}, "various options")
s2 := genTestCommand("sub2", true)
s2.PersistentFlags().BoolVar(&debug, "debug", debug, "debug description")
s3 := genTestCommand("sub3", true)
s3.Hidden = true
s1_1 := genTestCommand("sub1sub1", true)
s1_1.Flags().StringVar(&st1, "st1", st1, "st1 description")
s1_1.Flags().StringVar(&st2, "st2", st2, "st2 description")
s1_2 := genTestCommand("sub1sub2", true)
s1_3 := genTestCommand("sub1sub3", true)
s1_3.Flags().IntVar(&in1, "int1", in1, "int1 description")
s1_3.Flags().IntVar(&in2, "int2", in2, "int2 description")
s1_3.Flags().StringArrayP("option", "O", []string{}, "more options")
s2_1 := genTestCommand("sub2sub1", true)
s2_2 := genTestCommand("sub2sub2", true)
s2_3 := genTestCommand("sub2sub3", true)
s2_4 := genTestCommand("sub2sub4", true)
s2_5 := genTestCommand("sub2sub5", true)
s1.AddCommand(s1_1, s1_2, s1_3)
s2.AddCommand(s2_1, s2_2, s2_3, s2_4, s2_5)
r.AddCommand(s1, s2, s3)
r.Execute()
return r
}
func genTestCommand(name string, withRun bool) *Command {
r := &Command{
Use: name,
Short: name + " short description",
Long: "Long description for " + name,
}
if withRun {
r.Run = emptyRun
}
return r
}