diff --git a/.gitignore b/.gitignore index 1b8c7c26..3b053c59 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,5 @@ tags *.exe cobra.test + +.idea/* diff --git a/README.md b/README.md index ff16e3f6..60c5a425 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,7 @@ Many of the most widely used Go projects are built using Cobra, such as: [Istio](https://istio.io), [Prototool](https://github.com/uber/prototool), [mattermost-server](https://github.com/mattermost/mattermost-server), +[Gardener](https://github.com/gardener/gardenctl), etc. [![Build Status](https://travis-ci.org/spf13/cobra.svg "Travis CI status")](https://travis-ci.org/spf13/cobra) @@ -48,6 +49,7 @@ etc. * [Suggestions when "unknown command" happens](#suggestions-when-unknown-command-happens) * [Generating documentation for your command](#generating-documentation-for-your-command) * [Generating bash completions](#generating-bash-completions) + * [Generating zsh completions](#generating-zsh-completions) - [Contributing](#contributing) - [License](#license) @@ -336,7 +338,7 @@ rootCmd.PersistentFlags().BoolVarP(&Verbose, "verbose", "v", false, "verbose out A flag can also be assigned locally which will only apply to that specific command. ```go -rootCmd.Flags().StringVarP(&Source, "source", "s", "", "Source directory to read from") +localCmd.Flags().StringVarP(&Source, "source", "s", "", "Source directory to read from") ``` ### Local Flag on Parent Commands @@ -719,6 +721,11 @@ Cobra can generate documentation based on subcommands, flags, etc. in the follow Cobra can generate a bash-completion file. If you add more information to your command, these completions can be amazingly powerful and flexible. Read more about it in [Bash Completions](bash_completions.md). +## Generating zsh completions + +Cobra can generate zsh-completion file. Read more about it in +[Zsh Completions](zsh_completions.md). + # Contributing 1. Fork it diff --git a/bash_completions.go b/bash_completions.go index c3c1e501..57bb8e1b 100644 --- a/bash_completions.go +++ b/bash_completions.go @@ -545,51 +545,3 @@ func (c *Command) GenBashCompletionFile(filename string) error { return c.GenBashCompletion(outFile) } - -// MarkFlagRequired adds the BashCompOneRequiredFlag annotation to the named flag if it exists, -// and causes your command to report an error if invoked without the flag. -func (c *Command) MarkFlagRequired(name string) error { - return MarkFlagRequired(c.Flags(), name) -} - -// MarkPersistentFlagRequired adds the BashCompOneRequiredFlag annotation to the named persistent flag if it exists, -// and causes your command to report an error if invoked without the flag. -func (c *Command) MarkPersistentFlagRequired(name string) error { - return MarkFlagRequired(c.PersistentFlags(), name) -} - -// MarkFlagRequired adds the BashCompOneRequiredFlag annotation to the named flag if it exists, -// and causes your command to report an error if invoked without the flag. -func MarkFlagRequired(flags *pflag.FlagSet, name string) error { - return flags.SetAnnotation(name, BashCompOneRequiredFlag, []string{"true"}) -} - -// MarkFlagFilename adds the BashCompFilenameExt annotation to the named flag, if it exists. -// Generated bash autocompletion will select filenames for the flag, limiting to named extensions if provided. -func (c *Command) MarkFlagFilename(name string, extensions ...string) error { - return MarkFlagFilename(c.Flags(), name, extensions...) -} - -// MarkFlagCustom adds the BashCompCustom annotation to the named flag, if it exists. -// Generated bash autocompletion will call the bash function f for the flag. -func (c *Command) MarkFlagCustom(name string, f string) error { - return MarkFlagCustom(c.Flags(), name, f) -} - -// MarkPersistentFlagFilename adds the BashCompFilenameExt annotation to the named persistent flag, if it exists. -// Generated bash autocompletion will select filenames for the flag, limiting to named extensions if provided. -func (c *Command) MarkPersistentFlagFilename(name string, extensions ...string) error { - return MarkFlagFilename(c.PersistentFlags(), name, extensions...) -} - -// MarkFlagFilename adds the BashCompFilenameExt annotation to the named flag in the flag set, if it exists. -// Generated bash autocompletion will select filenames for the flag, limiting to named extensions if provided. -func MarkFlagFilename(flags *pflag.FlagSet, name string, extensions ...string) error { - return flags.SetAnnotation(name, BashCompFilenameExt, extensions) -} - -// MarkFlagCustom adds the BashCompCustom annotation to the named flag in the flag set, if it exists. -// Generated bash autocompletion will call the bash function f for the flag. -func MarkFlagCustom(flags *pflag.FlagSet, name string, f string) error { - return flags.SetAnnotation(name, BashCompCustom, []string{f}) -} diff --git a/cobra/cmd/add.go b/cobra/cmd/add.go index 54650d29..e3330e33 100644 --- a/cobra/cmd/add.go +++ b/cobra/cmd/add.go @@ -16,24 +16,20 @@ package cmd import ( "fmt" "os" - "path/filepath" "unicode" "github.com/spf13/cobra" ) -func init() { - addCmd.Flags().StringVarP(&packageName, "package", "t", "", "target package name (e.g. github.com/spf13/hugo)") - addCmd.Flags().StringVarP(&parentName, "parent", "p", "rootCmd", "variable name of parent command for this command (e.g. xyCmd)") -} +var ( + packageName string + parentName string -var packageName, parentName string - -var addCmd = &cobra.Command{ - Use: "add [command name]", - Aliases: []string{"command"}, - Short: "Add a command to a Cobra Application", - Long: `Add (cobra add) will create a new command, with a license and + addCmd = &cobra.Command{ + Use: "add [command name]", + Aliases: []string{"command"}, + Short: "Add a command to a Cobra Application", + Long: `Add (cobra add) will create a new command, with a license and the appropriate structure for a Cobra-based CLI application, and register it to its parent (default rootCmd). @@ -42,28 +38,41 @@ with an initial uppercase letter. Example: cobra add server -> resulting in a new cmd/server.go`, - Run: func(cmd *cobra.Command, args []string) { - if len(args) < 1 { - er("add needs a name for the command") - } + Run: func(cmd *cobra.Command, args []string) { + if len(args) < 1 { + er("add needs a name for the command") + } - var project *Project - if packageName != "" { - project = NewProject(packageName) - } else { wd, err := os.Getwd() if err != nil { er(err) } - project = NewProjectFromPath(wd) - } - cmdName := validateCmdName(args[0]) - cmdPath := filepath.Join(project.CmdPath(), cmdName+".go") - createCmdFile(project.License(), cmdPath, cmdName) + commandName := validateCmdName(args[0]) + command := &Command{ + CmdName: commandName, + CmdParent: parentName, + Project: &Project{ + AbsolutePath: wd, + Legal: getLicense(), + Copyright: copyrightLine(), + }, + } - fmt.Fprintln(cmd.OutOrStdout(), cmdName, "created at", cmdPath) - }, + err = command.Create() + if err != nil { + er(err) + } + + fmt.Printf("%s created at %s", command.CmdName, command.AbsolutePath) + }, + } +) + +func init() { + addCmd.Flags().StringVarP(&packageName, "package", "t", "", "target package name (e.g. github.com/spf13/hugo)") + addCmd.Flags().StringVarP(&parentName, "parent", "p", "rootCmd", "variable name of parent command for this command (e.g. xyCmd)") + addCmd.Flags().MarkDeprecated("package", "this operation has been removed.") } // validateCmdName returns source without any dashes and underscore. @@ -118,62 +127,3 @@ func validateCmdName(source string) string { } return output } - -func createCmdFile(license License, path, cmdName string) { - template := `{{comment .copyright}} -{{if .license}}{{comment .license}}{{end}} - -package {{.cmdPackage}} - -import ( - "fmt" - - "github.com/spf13/cobra" -) - -// {{.cmdName}}Cmd represents the {{.cmdName}} command -var {{.cmdName}}Cmd = &cobra.Command{ - Use: "{{.cmdName}}", - Short: "A brief description of your command", - Long: ` + "`" + `A longer description that spans multiple lines and likely contains examples -and usage of using your command. For example: - -Cobra is a CLI library for Go that empowers applications. -This application is a tool to generate the needed files -to quickly create a Cobra application.` + "`" + `, - Run: func(cmd *cobra.Command, args []string) { - fmt.Println("{{.cmdName}} called") - }, -} - -func init() { - {{.parentName}}.AddCommand({{.cmdName}}Cmd) - - // Here you will define your flags and configuration settings. - - // Cobra supports Persistent Flags which will work for this command - // and all subcommands, e.g.: - // {{.cmdName}}Cmd.PersistentFlags().String("foo", "", "A help for foo") - - // Cobra supports local flags which will only run when this command - // is called directly, e.g.: - // {{.cmdName}}Cmd.Flags().BoolP("toggle", "t", false, "Help message for toggle") -} -` - - data := make(map[string]interface{}) - data["copyright"] = copyrightLine() - data["license"] = license.Header - data["cmdPackage"] = filepath.Base(filepath.Dir(path)) // last dir of path - data["parentName"] = parentName - data["cmdName"] = cmdName - - cmdScript, err := executeTemplate(template, data) - if err != nil { - er(err) - } - err = writeStringToFile(path, cmdScript) - if err != nil { - er(err) - } -} diff --git a/cobra/cmd/add_test.go b/cobra/cmd/add_test.go index b920e2b9..0de1d221 100644 --- a/cobra/cmd/add_test.go +++ b/cobra/cmd/add_test.go @@ -1,85 +1,45 @@ package cmd import ( - "errors" - "io/ioutil" + "fmt" "os" - "path/filepath" "testing" - - "github.com/spf13/viper" ) -// TestGoldenAddCmd initializes the project "github.com/spf13/testproject" -// in GOPATH, adds "test" command -// and compares the content of all files in cmd directory of testproject -// with appropriate golden files. -// Use -update to update existing golden files. func TestGoldenAddCmd(t *testing.T) { - projectName := "github.com/spf13/testproject" - project := NewProject(projectName) - defer os.RemoveAll(project.AbsPath()) - viper.Set("author", "NAME HERE ") - viper.Set("license", "apache") - viper.Set("year", 2017) - defer viper.Set("author", nil) - defer viper.Set("license", nil) - defer viper.Set("year", nil) + wd, _ := os.Getwd() + command := &Command{ + CmdName: "test", + CmdParent: parentName, + Project: &Project{ + AbsolutePath: fmt.Sprintf("%s/testproject", wd), + Legal: getLicense(), + Copyright: copyrightLine(), - // Initialize the project first. - initializeProject(project) + // required to init + AppName: "testproject", + PkgName: "github.com/spf13/testproject", + Viper: true, + }, + } - // Then add the "test" command. - cmdName := "test" - cmdPath := filepath.Join(project.CmdPath(), cmdName+".go") - createCmdFile(project.License(), cmdPath, cmdName) - - expectedFiles := []string{".", "root.go", "test.go"} - gotFiles := []string{} - - // Check project file hierarchy and compare the content of every single file - // with appropriate golden file. - err := filepath.Walk(project.CmdPath(), func(path string, info os.FileInfo, err error) error { - if err != nil { - return err + // init project first + command.Project.Create() + defer func() { + if _, err := os.Stat(command.AbsolutePath); err == nil { + os.RemoveAll(command.AbsolutePath) } + }() - // Make path relative to project.CmdPath(). - // E.g. path = "/home/user/go/src/github.com/spf13/testproject/cmd/root.go" - // then it returns just "root.go". - relPath, err := filepath.Rel(project.CmdPath(), path) - if err != nil { - return err - } - relPath = filepath.ToSlash(relPath) - gotFiles = append(gotFiles, relPath) - goldenPath := filepath.Join("testdata", filepath.Base(path)+".golden") - - switch relPath { - // Known directories. - case ".": - return nil - // Known files. - case "root.go", "test.go": - if *update { - got, err := ioutil.ReadFile(path) - if err != nil { - return err - } - ioutil.WriteFile(goldenPath, got, 0644) - } - return compareFiles(path, goldenPath) - } - // Unknown file. - return errors.New("unknown file: " + path) - }) - if err != nil { + if err := command.Create(); err != nil { t.Fatal(err) } - // Check if some files lack. - if err := checkLackFiles(expectedFiles, gotFiles); err != nil { + generatedFile := fmt.Sprintf("%s/cmd/%s.go", command.AbsolutePath, command.CmdName) + goldenFile := fmt.Sprintf("testdata/%s.go.golden", command.CmdName) + err := compareFiles(generatedFile, goldenFile) + if err != nil { t.Fatal(err) } } diff --git a/cobra/cmd/init.go b/cobra/cmd/init.go index d65e6c8c..63397d11 100644 --- a/cobra/cmd/init.go +++ b/cobra/cmd/init.go @@ -15,19 +15,20 @@ package cmd import ( "fmt" - "os" - "path" - "path/filepath" - "github.com/spf13/cobra" "github.com/spf13/viper" + "os" + "path" ) -var initCmd = &cobra.Command{ - Use: "init [name]", - Aliases: []string{"initialize", "initialise", "create"}, - Short: "Initialize a Cobra Application", - Long: `Initialize (cobra init) will create a new application, with a license +var ( + pkgName string + + initCmd = &cobra.Command{ + Use: "init [name]", + Aliases: []string{"initialize", "initialise", "create"}, + Short: "Initialize a Cobra Application", + Long: `Initialize (cobra init) will create a new application, with a license and the appropriate structure for a Cobra-based CLI application. * If a name is provided, it will be created in the current directory; @@ -39,196 +40,38 @@ and the appropriate structure for a Cobra-based CLI application. Init will not use an existing directory with contents.`, - Run: func(cmd *cobra.Command, args []string) { - wd, err := os.Getwd() - if err != nil { - er(err) - } + Run: func(cmd *cobra.Command, args []string) { - var project *Project - if len(args) == 0 { - project = NewProjectFromPath(wd) - } else if len(args) == 1 { - arg := args[0] - if arg[0] == '.' { - arg = filepath.Join(wd, arg) + wd, err := os.Getwd() + if err != nil { + er(err) } - if filepath.IsAbs(arg) { - project = NewProjectFromPath(arg) - } else { - project = NewProject(arg) + + if len(args) > 0 { + if args[0] != "." { + wd = fmt.Sprintf("%s/%s", wd, args[0]) + } } - } else { - er("please provide only one argument") - } - initializeProject(project) + project := &Project{ + AbsolutePath: wd, + PkgName: pkgName, + Legal: getLicense(), + Copyright: copyrightLine(), + Viper: viper.GetBool("useViper"), + AppName: path.Base(pkgName), + } - fmt.Fprintln(cmd.OutOrStdout(), `Your Cobra application is ready at -`+project.AbsPath()+` - -Give it a try by going there and running `+"`go run main.go`."+` -Add commands to it by running `+"`cobra add [cmdname]`.") - }, -} - -func initializeProject(project *Project) { - if !exists(project.AbsPath()) { // If path doesn't yet exist, create it - err := os.MkdirAll(project.AbsPath(), os.ModePerm) - if err != nil { - er(err) - } - } else if !isEmpty(project.AbsPath()) { // If path exists and is not empty don't use it - er("Cobra will not create a new project in a non empty directory: " + project.AbsPath()) - } - - // We have a directory and it's empty. Time to initialize it. - createLicenseFile(project.License(), project.AbsPath()) - createMainFile(project) - createRootCmdFile(project) -} - -func createLicenseFile(license License, path string) { - data := make(map[string]interface{}) - data["copyright"] = copyrightLine() - - // Generate license template from text and data. - text, err := executeTemplate(license.Text, data) - if err != nil { - er(err) - } - - // Write license text to LICENSE file. - err = writeStringToFile(filepath.Join(path, "LICENSE"), text) - if err != nil { - er(err) - } -} - -func createMainFile(project *Project) { - mainTemplate := `{{ comment .copyright }} -{{if .license}}{{ comment .license }}{{end}} - -package main - -import "{{ .importpath }}" - -func main() { - cmd.Execute() -} -` - data := make(map[string]interface{}) - data["copyright"] = copyrightLine() - data["license"] = project.License().Header - data["importpath"] = path.Join(project.Name(), filepath.Base(project.CmdPath())) - - mainScript, err := executeTemplate(mainTemplate, data) - if err != nil { - er(err) - } - - err = writeStringToFile(filepath.Join(project.AbsPath(), "main.go"), mainScript) - if err != nil { - er(err) - } -} - -func createRootCmdFile(project *Project) { - template := `{{comment .copyright}} -{{if .license}}{{comment .license}}{{end}} - -package cmd - -import ( - "fmt" - "os" -{{if .viper}} - homedir "github.com/mitchellh/go-homedir"{{end}} - "github.com/spf13/cobra"{{if .viper}} - "github.com/spf13/viper"{{end}} -){{if .viper}} - -var cfgFile string{{end}} - -// rootCmd represents the base command when called without any subcommands -var rootCmd = &cobra.Command{ - Use: "{{.appName}}", - Short: "A brief description of your application", - Long: ` + "`" + `A longer description that spans multiple lines and likely contains -examples and usage of using your application. For example: - -Cobra is a CLI library for Go that empowers applications. -This application is a tool to generate the needed files -to quickly create a Cobra application.` + "`" + `, - // Uncomment the following line if your bare application - // has an action associated with it: - // Run: func(cmd *cobra.Command, args []string) { }, -} - -// Execute adds all child commands to the root command and sets flags appropriately. -// This is called by main.main(). It only needs to happen once to the rootCmd. -func Execute() { - if err := rootCmd.Execute(); err != nil { - fmt.Println(err) - os.Exit(1) - } -} - -func init() { {{- if .viper}} - cobra.OnInitialize(initConfig) -{{end}} - // Here you will define your flags and configuration settings. - // Cobra supports persistent flags, which, if defined here, - // will be global for your application.{{ if .viper }} - rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.{{ .appName }}.yaml)"){{ else }} - // rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.{{ .appName }}.yaml)"){{ end }} - - // Cobra also supports local flags, which will only run - // when this action is called directly. - rootCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle") -}{{ if .viper }} - -// initConfig reads in config file and ENV variables if set. -func initConfig() { - if cfgFile != "" { - // Use config file from the flag. - viper.SetConfigFile(cfgFile) - } else { - // Find home directory. - home, err := homedir.Dir() - if err != nil { - fmt.Println(err) - os.Exit(1) - } - - // Search config in home directory with name ".{{ .appName }}" (without extension). - viper.AddConfigPath(home) - viper.SetConfigName(".{{ .appName }}") - } - - viper.AutomaticEnv() // read in environment variables that match - - // If a config file is found, read it in. - if err := viper.ReadInConfig(); err == nil { - fmt.Println("Using config file:", viper.ConfigFileUsed()) - } -}{{ end }} -` - - data := make(map[string]interface{}) - data["copyright"] = copyrightLine() - data["viper"] = viper.GetBool("useViper") - data["license"] = project.License().Header - data["appName"] = path.Base(project.Name()) - - rootCmdScript, err := executeTemplate(template, data) - if err != nil { - er(err) - } - - err = writeStringToFile(filepath.Join(project.CmdPath(), "root.go"), rootCmdScript) - if err != nil { - er(err) - } + if err := project.Create(); err != nil { + er(err) + } + fmt.Printf("Your Cobra applicaton is ready at\n%s\n", project.AbsolutePath) + }, + } +) + +func init() { + initCmd.Flags().StringVar(&pkgName, "pkg-name", "", "fully qualified pkg name") + initCmd.MarkFlagRequired("pkg-name") } diff --git a/cobra/cmd/init_test.go b/cobra/cmd/init_test.go index 40eb4038..9540b2d3 100644 --- a/cobra/cmd/init_test.go +++ b/cobra/cmd/init_test.go @@ -1,83 +1,42 @@ package cmd import ( - "errors" - "io/ioutil" + "fmt" "os" "path/filepath" "testing" - - "github.com/spf13/viper" ) -// TestGoldenInitCmd initializes the project "github.com/spf13/testproject" -// in GOPATH and compares the content of files in initialized project with -// appropriate golden files ("testdata/*.golden"). -// Use -update to update existing golden files. func TestGoldenInitCmd(t *testing.T) { - projectName := "github.com/spf13/testproject" - project := NewProject(projectName) - defer os.RemoveAll(project.AbsPath()) - viper.Set("author", "NAME HERE ") - viper.Set("license", "apache") - viper.Set("year", 2017) - defer viper.Set("author", nil) - defer viper.Set("license", nil) - defer viper.Set("year", nil) - - os.Args = []string{"cobra", "init", projectName} - if err := rootCmd.Execute(); err != nil { - t.Fatal("Error by execution:", err) + wd, _ := os.Getwd() + project := &Project{ + AbsolutePath: fmt.Sprintf("%s/testproject", wd), + PkgName: "github.com/spf13/testproject", + Legal: getLicense(), + Copyright: copyrightLine(), + Viper: true, + AppName: "testproject", } - expectedFiles := []string{".", "cmd", "LICENSE", "main.go", "cmd/root.go"} - gotFiles := []string{} - - // Check project file hierarchy and compare the content of every single file - // with appropriate golden file. - err := filepath.Walk(project.AbsPath(), func(path string, info os.FileInfo, err error) error { - if err != nil { - return err - } - - // Make path relative to project.AbsPath(). - // E.g. path = "/home/user/go/src/github.com/spf13/testproject/cmd/root.go" - // then it returns just "cmd/root.go". - relPath, err := filepath.Rel(project.AbsPath(), path) - if err != nil { - return err - } - relPath = filepath.ToSlash(relPath) - gotFiles = append(gotFiles, relPath) - goldenPath := filepath.Join("testdata", filepath.Base(path)+".golden") - - switch relPath { - // Known directories. - case ".", "cmd": - return nil - // Known files. - case "LICENSE", "main.go", "cmd/root.go": - if *update { - got, err := ioutil.ReadFile(path) - if err != nil { - return err - } - if err := ioutil.WriteFile(goldenPath, got, 0644); err != nil { - t.Fatal("Error while updating file:", err) - } - } - return compareFiles(path, goldenPath) - } - // Unknown file. - return errors.New("unknown file: " + path) - }) + err := project.Create() if err != nil { t.Fatal(err) } - // Check if some files lack. - if err := checkLackFiles(expectedFiles, gotFiles); err != nil { - t.Fatal(err) + defer func() { + if _, err := os.Stat(project.AbsolutePath); err == nil { + os.RemoveAll(project.AbsolutePath) + } + }() + + expectedFiles := []string{"LICENSE", "main.go", "cmd/root.go"} + for _, f := range expectedFiles { + generatedFile := fmt.Sprintf("%s/%s", project.AbsolutePath, f) + goldenFile := fmt.Sprintf("testdata/%s.golden", filepath.Base(f)) + err := compareFiles(generatedFile, goldenFile) + if err != nil { + t.Fatal(err) + } } } diff --git a/cobra/cmd/project.go b/cobra/cmd/project.go index 7ddb8258..dd2f7ea2 100644 --- a/cobra/cmd/project.go +++ b/cobra/cmd/project.go @@ -1,200 +1,96 @@ package cmd import ( + "fmt" + "github.com/spf13/cobra/cobra/tpl" "os" - "path/filepath" - "runtime" - "strings" + "text/template" ) // Project contains name, license and paths to projects. type Project struct { - absPath string - cmdPath string - srcPath string - license License - name string + // v2 + PkgName string + Copyright string + AbsolutePath string + Legal License + Viper bool + AppName string } -// NewProject returns Project with specified project name. -func NewProject(projectName string) *Project { - if projectName == "" { - er("can't create project with blank name") - } +type Command struct { + CmdName string + CmdParent string + *Project +} - p := new(Project) - p.name = projectName +func (p *Project) Create() error { - // 1. Find already created protect. - p.absPath = findPackage(projectName) - - // 2. If there are no created project with this path, and user is in GOPATH, - // then use GOPATH/src/projectName. - if p.absPath == "" { - wd, err := os.Getwd() - if err != nil { - er(err) - } - for _, srcPath := range srcPaths { - goPath := filepath.Dir(srcPath) - if filepathHasPrefix(wd, goPath) { - p.absPath = filepath.Join(srcPath, projectName) - break - } + // check if AbsolutePath exists + if _, err := os.Stat(p.AbsolutePath); os.IsNotExist(err) { + // create directory + if err := os.Mkdir(p.AbsolutePath, 0754); err != nil { + return err } } - // 3. If user is not in GOPATH, then use (first GOPATH)/src/projectName. - if p.absPath == "" { - p.absPath = filepath.Join(srcPaths[0], projectName) - } - - return p -} - -// findPackage returns full path to existing go package in GOPATHs. -func findPackage(packageName string) string { - if packageName == "" { - return "" - } - - for _, srcPath := range srcPaths { - packagePath := filepath.Join(srcPath, packageName) - if exists(packagePath) { - return packagePath - } - } - - return "" -} - -// NewProjectFromPath returns Project with specified absolute path to -// package. -func NewProjectFromPath(absPath string) *Project { - if absPath == "" { - er("can't create project: absPath can't be blank") - } - if !filepath.IsAbs(absPath) { - er("can't create project: absPath is not absolute") - } - - // If absPath is symlink, use its destination. - fi, err := os.Lstat(absPath) + // create main.go + mainFile, err := os.Create(fmt.Sprintf("%s/main.go", p.AbsolutePath)) if err != nil { - er("can't read path info: " + err.Error()) - } - if fi.Mode()&os.ModeSymlink != 0 { - path, err := os.Readlink(absPath) - if err != nil { - er("can't read the destination of symlink: " + err.Error()) - } - absPath = path + return err } + defer mainFile.Close() - p := new(Project) - p.absPath = strings.TrimSuffix(absPath, findCmdDir(absPath)) - p.name = filepath.ToSlash(trimSrcPath(p.absPath, p.SrcPath())) - return p -} - -// trimSrcPath trims at the beginning of absPath the srcPath. -func trimSrcPath(absPath, srcPath string) string { - relPath, err := filepath.Rel(srcPath, absPath) + mainTemplate := template.Must(template.New("main").Parse(string(tpl.MainTemplate()))) + err = mainTemplate.Execute(mainFile, p) if err != nil { - er(err) + return err } - return relPath + + // create cmd/root.go + if _, err = os.Stat(fmt.Sprintf("%s/cmd", p.AbsolutePath)); os.IsNotExist(err) { + os.Mkdir(fmt.Sprintf("%s/cmd", p.AbsolutePath), 0751) + } + rootFile, err := os.Create(fmt.Sprintf("%s/cmd/root.go", p.AbsolutePath)) + if err != nil { + return err + } + defer rootFile.Close() + + rootTemplate := template.Must(template.New("root").Parse(string(tpl.RootTemplate()))) + err = rootTemplate.Execute(rootFile, p) + if err != nil { + return err + } + + // create license + return p.createLicenseFile() } -// License returns the License object of project. -func (p *Project) License() License { - if p.license.Text == "" && p.license.Name != "None" { - p.license = getLicense() +func (p *Project) createLicenseFile() error { + data := map[string]interface{}{ + "copyright": copyrightLine(), } - return p.license + licenseFile, err := os.Create(fmt.Sprintf("%s/LICENSE", p.AbsolutePath)) + if err != nil { + return err + } + + licenseTemplate := template.Must(template.New("license").Parse(p.Legal.Text)) + return licenseTemplate.Execute(licenseFile, data) } -// Name returns the name of project, e.g. "github.com/spf13/cobra" -func (p Project) Name() string { - return p.name -} - -// CmdPath returns absolute path to directory, where all commands are located. -func (p *Project) CmdPath() string { - if p.absPath == "" { - return "" - } - if p.cmdPath == "" { - p.cmdPath = filepath.Join(p.absPath, findCmdDir(p.absPath)) - } - return p.cmdPath -} - -// findCmdDir checks if base of absPath is cmd dir and returns it or -// looks for existing cmd dir in absPath. -func findCmdDir(absPath string) string { - if !exists(absPath) || isEmpty(absPath) { - return "cmd" - } - - if isCmdDir(absPath) { - return filepath.Base(absPath) - } - - files, _ := filepath.Glob(filepath.Join(absPath, "c*")) - for _, file := range files { - if isCmdDir(file) { - return filepath.Base(file) - } - } - - return "cmd" -} - -// isCmdDir checks if base of name is one of cmdDir. -func isCmdDir(name string) bool { - name = filepath.Base(name) - for _, cmdDir := range []string{"cmd", "cmds", "command", "commands"} { - if name == cmdDir { - return true - } - } - return false -} - -// AbsPath returns absolute path of project. -func (p Project) AbsPath() string { - return p.absPath -} - -// SrcPath returns absolute path to $GOPATH/src where project is located. -func (p *Project) SrcPath() string { - if p.srcPath != "" { - return p.srcPath - } - if p.absPath == "" { - p.srcPath = srcPaths[0] - return p.srcPath - } - - for _, srcPath := range srcPaths { - if filepathHasPrefix(p.absPath, srcPath) { - p.srcPath = srcPath - break - } - } - - return p.srcPath -} - -func filepathHasPrefix(path string, prefix string) bool { - if len(path) <= len(prefix) { - return false - } - if runtime.GOOS == "windows" { - // Paths in windows are case-insensitive. - return strings.EqualFold(path[0:len(prefix)], prefix) - } - return path[0:len(prefix)] == prefix - +func (c *Command) Create() error { + cmdFile, err := os.Create(fmt.Sprintf("%s/cmd/%s.go", c.AbsolutePath, c.CmdName)) + if err != nil { + return err + } + defer cmdFile.Close() + + commandTemplate := template.Must(template.New("sub").Parse(string(tpl.AddCommandTemplate()))) + err = commandTemplate.Execute(cmdFile, c) + if err != nil { + return err + } + return nil } diff --git a/cobra/cmd/project_test.go b/cobra/cmd/project_test.go index 037f7c55..ed5b054a 100644 --- a/cobra/cmd/project_test.go +++ b/cobra/cmd/project_test.go @@ -1,24 +1,3 @@ package cmd -import ( - "testing" -) - -func TestFindExistingPackage(t *testing.T) { - path := findPackage("github.com/spf13/cobra") - if path == "" { - t.Fatal("findPackage didn't find the existing package") - } - if !hasGoPathPrefix(path) { - t.Fatalf("%q is not in GOPATH, but must be", path) - } -} - -func hasGoPathPrefix(path string) bool { - for _, srcPath := range srcPaths { - if filepathHasPrefix(path, srcPath) { - return true - } - } - return false -} +/* todo: write tests */ diff --git a/cobra/cmd/root.go b/cobra/cmd/root.go index 19568f98..624c717c 100644 --- a/cobra/cmd/root.go +++ b/cobra/cmd/root.go @@ -23,7 +23,8 @@ import ( var ( // Used for flags. - cfgFile, userLicense string + cfgFile string + userLicense string rootCmd = &cobra.Command{ Use: "cobra", diff --git a/cobra/cmd/testdata/main.go.golden b/cobra/cmd/testdata/main.go.golden index cdbe38d7..4ad570c5 100644 --- a/cobra/cmd/testdata/main.go.golden +++ b/cobra/cmd/testdata/main.go.golden @@ -1,21 +1,22 @@ -// Copyright © 2017 NAME HERE -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* +Copyright © 2019 NAME HERE +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ package main import "github.com/spf13/testproject/cmd" func main() { - cmd.Execute() + cmd.Execute() } diff --git a/cobra/cmd/testdata/root.go.golden b/cobra/cmd/testdata/root.go.golden index d74f4cd4..d3b889ba 100644 --- a/cobra/cmd/testdata/root.go.golden +++ b/cobra/cmd/testdata/root.go.golden @@ -1,89 +1,97 @@ -// Copyright © 2017 NAME HERE -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* +Copyright © 2019 NAME HERE +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ package cmd import ( - "fmt" - "os" + "fmt" + "os" + "github.com/spf13/cobra" + + homedir "github.com/mitchellh/go-homedir" + "github.com/spf13/viper" - homedir "github.com/mitchellh/go-homedir" - "github.com/spf13/cobra" - "github.com/spf13/viper" ) + var cfgFile string + // rootCmd represents the base command when called without any subcommands var rootCmd = &cobra.Command{ - Use: "testproject", - Short: "A brief description of your application", - Long: `A longer description that spans multiple lines and likely contains + Use: "testproject", + Short: "A brief description of your application", + Long: `A longer description that spans multiple lines and likely contains examples and usage of using your application. For example: Cobra is a CLI library for Go that empowers applications. This application is a tool to generate the needed files to quickly create a Cobra application.`, - // Uncomment the following line if your bare application - // has an action associated with it: - // Run: func(cmd *cobra.Command, args []string) { }, + // Uncomment the following line if your bare application + // has an action associated with it: + // Run: func(cmd *cobra.Command, args []string) { }, } // Execute adds all child commands to the root command and sets flags appropriately. // This is called by main.main(). It only needs to happen once to the rootCmd. func Execute() { - if err := rootCmd.Execute(); err != nil { - fmt.Println(err) - os.Exit(1) - } + if err := rootCmd.Execute(); err != nil { + fmt.Println(err) + os.Exit(1) + } } func init() { - cobra.OnInitialize(initConfig) + cobra.OnInitialize(initConfig) - // Here you will define your flags and configuration settings. - // Cobra supports persistent flags, which, if defined here, - // will be global for your application. - rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.testproject.yaml)") + // Here you will define your flags and configuration settings. + // Cobra supports persistent flags, which, if defined here, + // will be global for your application. - // Cobra also supports local flags, which will only run - // when this action is called directly. - rootCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle") + rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.testproject.yaml)") + + + // Cobra also supports local flags, which will only run + // when this action is called directly. + rootCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle") } + // initConfig reads in config file and ENV variables if set. func initConfig() { - if cfgFile != "" { - // Use config file from the flag. - viper.SetConfigFile(cfgFile) - } else { - // Find home directory. - home, err := homedir.Dir() - if err != nil { - fmt.Println(err) - os.Exit(1) - } + if cfgFile != "" { + // Use config file from the flag. + viper.SetConfigFile(cfgFile) + } else { + // Find home directory. + home, err := homedir.Dir() + if err != nil { + fmt.Println(err) + os.Exit(1) + } - // Search config in home directory with name ".testproject" (without extension). - viper.AddConfigPath(home) - viper.SetConfigName(".testproject") - } + // Search config in home directory with name ".testproject" (without extension). + viper.AddConfigPath(home) + viper.SetConfigName(".testproject") + } - viper.AutomaticEnv() // read in environment variables that match + viper.AutomaticEnv() // read in environment variables that match - // If a config file is found, read it in. - if err := viper.ReadInConfig(); err == nil { - fmt.Println("Using config file:", viper.ConfigFileUsed()) - } + // If a config file is found, read it in. + if err := viper.ReadInConfig(); err == nil { + fmt.Println("Using config file:", viper.ConfigFileUsed()) + } } + diff --git a/cobra/cmd/testdata/test.go.golden b/cobra/cmd/testdata/test.go.golden index ed644275..fb8e0fa9 100644 --- a/cobra/cmd/testdata/test.go.golden +++ b/cobra/cmd/testdata/test.go.golden @@ -1,17 +1,18 @@ -// Copyright © 2017 NAME HERE -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* +Copyright © 2019 NAME HERE +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ package cmd import ( diff --git a/cobra/tpl/main.go b/cobra/tpl/main.go new file mode 100644 index 00000000..5e5a0fae --- /dev/null +++ b/cobra/tpl/main.go @@ -0,0 +1,153 @@ +package tpl + +func MainTemplate() []byte { + return []byte(`/* +{{ .Copyright }} +{{ if .Legal.Header }}{{ .Legal.Header }}{{ end }} +*/ +package main + +import "{{ .PkgName }}/cmd" + +func main() { + cmd.Execute() +} +`) +} + +func RootTemplate() []byte { + return []byte(`/* +{{ .Copyright }} +{{ if .Legal.Header }}{{ .Legal.Header }}{{ end }} +*/ +package cmd + +import ( + "fmt" + "os" + "github.com/spf13/cobra" +{{ if .Viper }} + homedir "github.com/mitchellh/go-homedir" + "github.com/spf13/viper" +{{ end }} +) + +{{ if .Viper }} +var cfgFile string +{{ end }} + +// rootCmd represents the base command when called without any subcommands +var rootCmd = &cobra.Command{ + Use: "{{ .AppName }}", + Short: "A brief description of your application", + Long: ` + "`" + `A longer description that spans multiple lines and likely contains +examples and usage of using your application. For example: + +Cobra is a CLI library for Go that empowers applications. +This application is a tool to generate the needed files +to quickly create a Cobra application.` + "`" + `, + // Uncomment the following line if your bare application + // has an action associated with it: + // Run: func(cmd *cobra.Command, args []string) { }, +} + +// Execute adds all child commands to the root command and sets flags appropriately. +// This is called by main.main(). It only needs to happen once to the rootCmd. +func Execute() { + if err := rootCmd.Execute(); err != nil { + fmt.Println(err) + os.Exit(1) + } +} + +func init() { +{{- if .Viper }} + cobra.OnInitialize(initConfig) +{{ end }} + // Here you will define your flags and configuration settings. + // Cobra supports persistent flags, which, if defined here, + // will be global for your application. +{{ if .Viper }} + rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.{{ .AppName }}.yaml)") +{{ else }} + // rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.{{ .AppName }}.yaml)") +{{ end }} + + // Cobra also supports local flags, which will only run + // when this action is called directly. + rootCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle") +} + +{{ if .Viper }} +// initConfig reads in config file and ENV variables if set. +func initConfig() { + if cfgFile != "" { + // Use config file from the flag. + viper.SetConfigFile(cfgFile) + } else { + // Find home directory. + home, err := homedir.Dir() + if err != nil { + fmt.Println(err) + os.Exit(1) + } + + // Search config in home directory with name ".{{ .AppName }}" (without extension). + viper.AddConfigPath(home) + viper.SetConfigName(".{{ .AppName }}") + } + + viper.AutomaticEnv() // read in environment variables that match + + // If a config file is found, read it in. + if err := viper.ReadInConfig(); err == nil { + fmt.Println("Using config file:", viper.ConfigFileUsed()) + } +} +{{ end }} +`) +} + +func AddCommandTemplate() []byte { + return []byte(`/* +{{ .Project.Copyright }} +{{ if .Legal.Header }}{{ .Legal.Header }}{{ end }} +*/ +package cmd + +import ( + "fmt" + + "github.com/spf13/cobra" +) + +// {{ .CmdName }}Cmd represents the {{ .CmdName }} command +var {{ .CmdName }}Cmd = &cobra.Command{ + Use: "{{ .CmdName }}", + Short: "A brief description of your command", + Long: ` + "`" + `A longer description that spans multiple lines and likely contains examples +and usage of using your command. For example: + +Cobra is a CLI library for Go that empowers applications. +This application is a tool to generate the needed files +to quickly create a Cobra application.` + "`" + `, + Run: func(cmd *cobra.Command, args []string) { + fmt.Println("{{ .CmdName }} called") + }, +} + +func init() { + {{ .CmdParent }}.AddCommand({{ .CmdName }}Cmd) + + // Here you will define your flags and configuration settings. + + // Cobra supports Persistent Flags which will work for this command + // and all subcommands, e.g.: + // {{ .CmdName }}Cmd.PersistentFlags().String("foo", "", "A help for foo") + + // Cobra supports local flags which will only run when this command + // is called directly, e.g.: + // {{ .CmdName }}Cmd.Flags().BoolP("toggle", "t", false, "Help message for toggle") +} +`) +} diff --git a/command.go b/command.go index b257f91b..c7e89830 100644 --- a/command.go +++ b/command.go @@ -177,8 +177,6 @@ type Command struct { // that we can use on every pflag set and children commands globNormFunc func(f *flag.FlagSet, name string) flag.NormalizedName - // output is an output writer defined by user. - output io.Writer // usageFunc is usage func defined by user. usageFunc func(*Command) error // usageTemplate is usage template defined by user. @@ -195,6 +193,13 @@ type Command struct { helpCommand *Command // versionTemplate is the version template defined by user. versionTemplate string + + // inReader is a reader defined by the user that replaces stdin + inReader io.Reader + // outWriter is a writer defined by the user that replaces stdout + outWriter io.Writer + // errWriter is a writer defined by the user that replaces stderr + errWriter io.Writer } // SetArgs sets arguments for the command. It is set to os.Args[1:] by default, if desired, can be overridden @@ -205,8 +210,28 @@ func (c *Command) SetArgs(a []string) { // SetOutput sets the destination for usage and error messages. // If output is nil, os.Stderr is used. +// Deprecated: Use SetOut and/or SetErr instead func (c *Command) SetOutput(output io.Writer) { - c.output = output + c.outWriter = output + c.errWriter = output +} + +// SetOut sets the destination for usage messages. +// If newOut is nil, os.Stdout is used. +func (c *Command) SetOut(newOut io.Writer) { + c.outWriter = newOut +} + +// SetErr sets the destination for error messages. +// If newErr is nil, os.Stderr is used. +func (c *Command) SetErr(newErr io.Writer) { + c.errWriter = newErr +} + +// SetOut sets the source for input data +// If newIn is nil, os.Stdin is used. +func (c *Command) SetIn(newIn io.Reader) { + c.inReader = newIn } // SetUsageFunc sets usage function. Usage can be defined by application. @@ -267,9 +292,19 @@ func (c *Command) OutOrStderr() io.Writer { return c.getOut(os.Stderr) } +// ErrOrStderr returns output to stderr +func (c *Command) ErrOrStderr() io.Writer { + return c.getErr(os.Stderr) +} + +// ErrOrStderr returns output to stderr +func (c *Command) InOrStdin() io.Reader { + return c.getIn(os.Stdin) +} + func (c *Command) getOut(def io.Writer) io.Writer { - if c.output != nil { - return c.output + if c.outWriter != nil { + return c.outWriter } if c.HasParent() { return c.parent.getOut(def) @@ -277,6 +312,26 @@ func (c *Command) getOut(def io.Writer) io.Writer { return def } +func (c *Command) getErr(def io.Writer) io.Writer { + if c.errWriter != nil { + return c.errWriter + } + if c.HasParent() { + return c.parent.getErr(def) + } + return def +} + +func (c *Command) getIn(def io.Reader) io.Reader { + if c.inReader != nil { + return c.inReader + } + if c.HasParent() { + return c.parent.getIn(def) + } + return def +} + // UsageFunc returns either the function set by SetUsageFunc for this command // or a parent, or it returns a default usage function. func (c *Command) UsageFunc() (f func(*Command) error) { @@ -329,13 +384,22 @@ func (c *Command) Help() error { return nil } -// UsageString return usage string. +// UsageString returns usage string. func (c *Command) UsageString() string { - tmpOutput := c.output + // Storing normal writers + tmpOutput := c.outWriter + tmpErr := c.errWriter + bb := new(bytes.Buffer) - c.SetOutput(bb) + c.outWriter = bb + c.errWriter = bb + c.Usage() - c.output = tmpOutput + + // Setting things back to normal + c.outWriter = tmpOutput + c.errWriter = tmpErr + return bb.String() } @@ -1068,6 +1132,21 @@ func (c *Command) Printf(format string, i ...interface{}) { c.Print(fmt.Sprintf(format, i...)) } +// PrintErr is a convenience method to Print to the defined Err output, fallback to Stderr if not set. +func (c *Command) PrintErr(i ...interface{}) { + fmt.Fprint(c.ErrOrStderr(), i...) +} + +// PrintErrln is a convenience method to Println to the defined Err output, fallback to Stderr if not set. +func (c *Command) PrintErrln(i ...interface{}) { + c.Print(fmt.Sprintln(i...)) +} + +// PrintErrf is a convenience method to Printf to the defined Err output, fallback to Stderr if not set. +func (c *Command) PrintErrf(format string, i ...interface{}) { + c.Print(fmt.Sprintf(format, i...)) +} + // CommandPath returns the full path to this command. func (c *Command) CommandPath() string { if c.HasParent() { diff --git a/command_test.go b/command_test.go index 6e483a3e..2fa2003c 100644 --- a/command_test.go +++ b/command_test.go @@ -1381,6 +1381,46 @@ func TestSetOutput(t *testing.T) { } } +func TestSetOut(t *testing.T) { + c := &Command{} + c.SetOut(nil) + if out := c.OutOrStdout(); out != os.Stdout { + t.Errorf("Expected setting output to nil to revert back to stdout") + } +} + +func TestSetErr(t *testing.T) { + c := &Command{} + c.SetErr(nil) + if out := c.ErrOrStderr(); out != os.Stderr { + t.Errorf("Expected setting error to nil to revert back to stderr") + } +} + +func TestSetIn(t *testing.T) { + c := &Command{} + c.SetIn(nil) + if out := c.InOrStdin(); out != os.Stdin { + t.Errorf("Expected setting input to nil to revert back to stdin") + } +} + +func TestUsageStringRedirected(t *testing.T) { + c := &Command{} + + c.usageFunc = func(cmd *Command) error { + cmd.Print("[stdout1]") + cmd.PrintErr("[stderr2]") + cmd.Print("[stdout3]") + return nil + } + + expected := "[stdout1][stderr2][stdout3]" + if got := c.UsageString(); got != expected { + t.Errorf("Expected usage string to consider both stdout and stderr") + } +} + func TestFlagErrorFunc(t *testing.T) { c := &Command{Use: "c", Run: emptyRun} diff --git a/powershell_completions.go b/powershell_completions.go new file mode 100644 index 00000000..756c61b9 --- /dev/null +++ b/powershell_completions.go @@ -0,0 +1,100 @@ +// PowerShell completions are based on the amazing work from clap: +// https://github.com/clap-rs/clap/blob/3294d18efe5f264d12c9035f404c7d189d4824e1/src/completions/powershell.rs +// +// The generated scripts require PowerShell v5.0+ (which comes Windows 10, but +// can be downloaded separately for windows 7 or 8.1). + +package cobra + +import ( + "bytes" + "fmt" + "io" + "os" + "strings" + + "github.com/spf13/pflag" +) + +var powerShellCompletionTemplate = `using namespace System.Management.Automation +using namespace System.Management.Automation.Language +Register-ArgumentCompleter -Native -CommandName '%s' -ScriptBlock { + param($wordToComplete, $commandAst, $cursorPosition) + $commandElements = $commandAst.CommandElements + $command = @( + '%s' + for ($i = 1; $i -lt $commandElements.Count; $i++) { + $element = $commandElements[$i] + if ($element -isnot [StringConstantExpressionAst] -or + $element.StringConstantType -ne [StringConstantType]::BareWord -or + $element.Value.StartsWith('-')) { + break + } + $element.Value + } + ) -join ';' + $completions = @(switch ($command) {%s + }) + $completions.Where{ $_.CompletionText -like "$wordToComplete*" } | + Sort-Object -Property ListItemText +}` + +func generatePowerShellSubcommandCases(out io.Writer, cmd *Command, previousCommandName string) { + var cmdName string + if previousCommandName == "" { + cmdName = cmd.Name() + } else { + cmdName = fmt.Sprintf("%s;%s", previousCommandName, cmd.Name()) + } + + fmt.Fprintf(out, "\n '%s' {", cmdName) + + cmd.Flags().VisitAll(func(flag *pflag.Flag) { + if nonCompletableFlag(flag) { + return + } + usage := escapeStringForPowerShell(flag.Usage) + if len(flag.Shorthand) > 0 { + fmt.Fprintf(out, "\n [CompletionResult]::new('-%s', '%s', [CompletionResultType]::ParameterName, '%s')", flag.Shorthand, flag.Shorthand, usage) + } + fmt.Fprintf(out, "\n [CompletionResult]::new('--%s', '%s', [CompletionResultType]::ParameterName, '%s')", flag.Name, flag.Name, usage) + }) + + for _, subCmd := range cmd.Commands() { + usage := escapeStringForPowerShell(subCmd.Short) + fmt.Fprintf(out, "\n [CompletionResult]::new('%s', '%s', [CompletionResultType]::ParameterValue, '%s')", subCmd.Name(), subCmd.Name(), usage) + } + + fmt.Fprint(out, "\n break\n }") + + for _, subCmd := range cmd.Commands() { + generatePowerShellSubcommandCases(out, subCmd, cmdName) + } +} + +func escapeStringForPowerShell(s string) string { + return strings.Replace(s, "'", "''", -1) +} + +// GenPowerShellCompletion generates PowerShell completion file and writes to the passed writer. +func (c *Command) GenPowerShellCompletion(w io.Writer) error { + buf := new(bytes.Buffer) + + var subCommandCases bytes.Buffer + generatePowerShellSubcommandCases(&subCommandCases, c, "") + fmt.Fprintf(buf, powerShellCompletionTemplate, c.Name(), c.Name(), subCommandCases.String()) + + _, err := buf.WriteTo(w) + return err +} + +// GenPowerShellCompletionFile generates PowerShell completion file. +func (c *Command) GenPowerShellCompletionFile(filename string) error { + outFile, err := os.Create(filename) + if err != nil { + return err + } + defer outFile.Close() + + return c.GenPowerShellCompletion(outFile) +} diff --git a/powershell_completions.md b/powershell_completions.md new file mode 100644 index 00000000..afed8024 --- /dev/null +++ b/powershell_completions.md @@ -0,0 +1,14 @@ +# Generating PowerShell Completions For Your Own cobra.Command + +Cobra can generate PowerShell completion scripts. Users need PowerShell version 5.0 or above, which comes with Windows 10 and can be downloaded separately for Windows 7 or 8.1. They can then write the completions to a file and source this file from their PowerShell profile, which is referenced by the `$Profile` environment variable. See `Get-Help about_Profiles` for more info about PowerShell profiles. + +# What's supported + +- Completion for subcommands using their `.Short` description +- Completion for non-hidden flags using their `.Name` and `.Shorthand` + +# What's not yet supported + +- Command aliases +- Required, filename or custom flags (they will work like normal flags) +- Custom completion scripts diff --git a/powershell_completions_test.go b/powershell_completions_test.go new file mode 100644 index 00000000..29b609de --- /dev/null +++ b/powershell_completions_test.go @@ -0,0 +1,122 @@ +package cobra + +import ( + "bytes" + "strings" + "testing" +) + +func TestPowerShellCompletion(t *testing.T) { + tcs := []struct { + name string + root *Command + expectedExpressions []string + }{ + { + name: "trivial", + root: &Command{Use: "trivialapp"}, + expectedExpressions: []string{ + "Register-ArgumentCompleter -Native -CommandName 'trivialapp' -ScriptBlock", + "$command = @(\n 'trivialapp'\n", + }, + }, + { + name: "tree", + root: func() *Command { + r := &Command{Use: "tree"} + + sub1 := &Command{Use: "sub1"} + r.AddCommand(sub1) + + sub11 := &Command{Use: "sub11"} + sub12 := &Command{Use: "sub12"} + + sub1.AddCommand(sub11) + sub1.AddCommand(sub12) + + sub2 := &Command{Use: "sub2"} + r.AddCommand(sub2) + + sub21 := &Command{Use: "sub21"} + sub22 := &Command{Use: "sub22"} + + sub2.AddCommand(sub21) + sub2.AddCommand(sub22) + + return r + }(), + expectedExpressions: []string{ + "'tree'", + "[CompletionResult]::new('sub1', 'sub1', [CompletionResultType]::ParameterValue, '')", + "[CompletionResult]::new('sub2', 'sub2', [CompletionResultType]::ParameterValue, '')", + "'tree;sub1'", + "[CompletionResult]::new('sub11', 'sub11', [CompletionResultType]::ParameterValue, '')", + "[CompletionResult]::new('sub12', 'sub12', [CompletionResultType]::ParameterValue, '')", + "'tree;sub1;sub11'", + "'tree;sub1;sub12'", + "'tree;sub2'", + "[CompletionResult]::new('sub21', 'sub21', [CompletionResultType]::ParameterValue, '')", + "[CompletionResult]::new('sub22', 'sub22', [CompletionResultType]::ParameterValue, '')", + "'tree;sub2;sub21'", + "'tree;sub2;sub22'", + }, + }, + { + name: "flags", + root: func() *Command { + r := &Command{Use: "flags"} + r.Flags().StringP("flag1", "a", "", "") + r.Flags().String("flag2", "", "") + + sub1 := &Command{Use: "sub1"} + sub1.Flags().StringP("flag3", "c", "", "") + r.AddCommand(sub1) + + return r + }(), + expectedExpressions: []string{ + "'flags'", + "[CompletionResult]::new('-a', 'a', [CompletionResultType]::ParameterName, '')", + "[CompletionResult]::new('--flag1', 'flag1', [CompletionResultType]::ParameterName, '')", + "[CompletionResult]::new('--flag2', 'flag2', [CompletionResultType]::ParameterName, '')", + "[CompletionResult]::new('sub1', 'sub1', [CompletionResultType]::ParameterValue, '')", + "'flags;sub1'", + "[CompletionResult]::new('-c', 'c', [CompletionResultType]::ParameterName, '')", + "[CompletionResult]::new('--flag3', 'flag3', [CompletionResultType]::ParameterName, '')", + }, + }, + { + name: "usage", + root: func() *Command { + r := &Command{Use: "usage"} + r.Flags().String("flag", "", "this describes the usage of the 'flag' flag") + + sub1 := &Command{ + Use: "sub1", + Short: "short describes 'sub1'", + } + r.AddCommand(sub1) + + return r + }(), + expectedExpressions: []string{ + "[CompletionResult]::new('--flag', 'flag', [CompletionResultType]::ParameterName, 'this describes the usage of the ''flag'' flag')", + "[CompletionResult]::new('sub1', 'sub1', [CompletionResultType]::ParameterValue, 'short describes ''sub1''')", + }, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + buf := new(bytes.Buffer) + tc.root.GenPowerShellCompletion(buf) + output := buf.String() + + for _, expectedExpression := range tc.expectedExpressions { + if !strings.Contains(output, expectedExpression) { + t.Errorf("Expected completion to contain %q somewhere; got %q", expectedExpression, output) + } + } + }) + } +} diff --git a/shell_completions.go b/shell_completions.go new file mode 100644 index 00000000..ba0af9cb --- /dev/null +++ b/shell_completions.go @@ -0,0 +1,85 @@ +package cobra + +import ( + "github.com/spf13/pflag" +) + +// MarkFlagRequired adds the BashCompOneRequiredFlag annotation to the named flag if it exists, +// and causes your command to report an error if invoked without the flag. +func (c *Command) MarkFlagRequired(name string) error { + return MarkFlagRequired(c.Flags(), name) +} + +// MarkPersistentFlagRequired adds the BashCompOneRequiredFlag annotation to the named persistent flag if it exists, +// and causes your command to report an error if invoked without the flag. +func (c *Command) MarkPersistentFlagRequired(name string) error { + return MarkFlagRequired(c.PersistentFlags(), name) +} + +// MarkFlagRequired adds the BashCompOneRequiredFlag annotation to the named flag if it exists, +// and causes your command to report an error if invoked without the flag. +func MarkFlagRequired(flags *pflag.FlagSet, name string) error { + return flags.SetAnnotation(name, BashCompOneRequiredFlag, []string{"true"}) +} + +// MarkFlagFilename adds the BashCompFilenameExt annotation to the named flag, if it exists. +// Generated bash autocompletion will select filenames for the flag, limiting to named extensions if provided. +func (c *Command) MarkFlagFilename(name string, extensions ...string) error { + return MarkFlagFilename(c.Flags(), name, extensions...) +} + +// MarkFlagCustom adds the BashCompCustom annotation to the named flag, if it exists. +// Generated bash autocompletion will call the bash function f for the flag. +func (c *Command) MarkFlagCustom(name string, f string) error { + return MarkFlagCustom(c.Flags(), name, f) +} + +// MarkPersistentFlagFilename instructs the various shell completion +// implementations to limit completions for this persistent flag to the +// specified extensions (patterns). +// +// Shell Completion compatibility matrix: bash, zsh +func (c *Command) MarkPersistentFlagFilename(name string, extensions ...string) error { + return MarkFlagFilename(c.PersistentFlags(), name, extensions...) +} + +// MarkFlagFilename instructs the various shell completion implementations to +// limit completions for this flag to the specified extensions (patterns). +// +// Shell Completion compatibility matrix: bash, zsh +func MarkFlagFilename(flags *pflag.FlagSet, name string, extensions ...string) error { + return flags.SetAnnotation(name, BashCompFilenameExt, extensions) +} + +// MarkFlagCustom instructs the various shell completion implementations to +// limit completions for this flag to the specified extensions (patterns). +// +// Shell Completion compatibility matrix: bash, zsh +func MarkFlagCustom(flags *pflag.FlagSet, name string, f string) error { + return flags.SetAnnotation(name, BashCompCustom, []string{f}) +} + +// MarkFlagDirname instructs the various shell completion implementations to +// complete only directories with this named flag. +// +// Shell Completion compatibility matrix: zsh +func (c *Command) MarkFlagDirname(name string) error { + return MarkFlagDirname(c.Flags(), name) +} + +// MarkPersistentFlagDirname instructs the various shell completion +// implementations to complete only directories with this persistent named flag. +// +// Shell Completion compatibility matrix: zsh +func (c *Command) MarkPersistentFlagDirname(name string) error { + return MarkFlagDirname(c.PersistentFlags(), name) +} + +// MarkFlagDirname instructs the various shell completion implementations to +// complete only directories with this specified flag. +// +// Shell Completion compatibility matrix: zsh +func MarkFlagDirname(flags *pflag.FlagSet, name string) error { + zshPattern := "-(/)" + return flags.SetAnnotation(name, zshCompDirname, []string{zshPattern}) +} diff --git a/zsh_completions.go b/zsh_completions.go index 889c22e2..12755482 100644 --- a/zsh_completions.go +++ b/zsh_completions.go @@ -1,13 +1,102 @@ package cobra import ( - "bytes" + "encoding/json" "fmt" "io" "os" + "sort" "strings" + "text/template" + + "github.com/spf13/pflag" ) +const ( + zshCompArgumentAnnotation = "cobra_annotations_zsh_completion_argument_annotation" + zshCompArgumentFilenameComp = "cobra_annotations_zsh_completion_argument_file_completion" + zshCompArgumentWordComp = "cobra_annotations_zsh_completion_argument_word_completion" + zshCompDirname = "cobra_annotations_zsh_dirname" +) + +var ( + zshCompFuncMap = template.FuncMap{ + "genZshFuncName": zshCompGenFuncName, + "extractFlags": zshCompExtractFlag, + "genFlagEntryForZshArguments": zshCompGenFlagEntryForArguments, + "extractArgsCompletions": zshCompExtractArgumentCompletionHintsForRendering, + } + zshCompletionText = ` +{{/* should accept Command (that contains subcommands) as parameter */}} +{{define "argumentsC" -}} +{{ $cmdPath := genZshFuncName .}} +function {{$cmdPath}} { + local -a commands + + _arguments -C \{{- range extractFlags .}} + {{genFlagEntryForZshArguments .}} \{{- end}} + "1: :->cmnds" \ + "*::arg:->args" + + case $state in + cmnds) + commands=({{range .Commands}}{{if not .Hidden}} + "{{.Name}}:{{.Short}}"{{end}}{{end}} + ) + _describe "command" commands + ;; + esac + + case "$words[1]" in {{- range .Commands}}{{if not .Hidden}} + {{.Name}}) + {{$cmdPath}}_{{.Name}} + ;;{{end}}{{end}} + esac +} +{{range .Commands}}{{if not .Hidden}} +{{template "selectCmdTemplate" .}} +{{- end}}{{end}} +{{- end}} + +{{/* should accept Command without subcommands as parameter */}} +{{define "arguments" -}} +function {{genZshFuncName .}} { +{{" _arguments"}}{{range extractFlags .}} \ + {{genFlagEntryForZshArguments . -}} +{{end}}{{range extractArgsCompletions .}} \ + {{.}}{{end}} +} +{{end}} + +{{/* dispatcher for commands with or without subcommands */}} +{{define "selectCmdTemplate" -}} +{{if .Hidden}}{{/* ignore hidden*/}}{{else -}} +{{if .Commands}}{{template "argumentsC" .}}{{else}}{{template "arguments" .}}{{end}} +{{- end}} +{{- end}} + +{{/* template entry point */}} +{{define "Main" -}} +#compdef _{{.Name}} {{.Name}} + +{{template "selectCmdTemplate" .}} +{{end}} +` +) + +// zshCompArgsAnnotation is used to encode/decode zsh completion for +// arguments to/from Command.Annotations. +type zshCompArgsAnnotation map[int]zshCompArgHint + +type zshCompArgHint struct { + // Indicates the type of the completion to use. One of: + // zshCompArgumentFilenameComp or zshCompArgumentWordComp + Tipe string `json:"type"` + + // A value for the type above (globs for file completion or words) + Options []string `json:"options"` +} + // GenZshCompletionFile generates zsh completion file. func (c *Command) GenZshCompletionFile(filename string) error { outFile, err := os.Create(filename) @@ -19,108 +108,229 @@ func (c *Command) GenZshCompletionFile(filename string) error { return c.GenZshCompletion(outFile) } -// GenZshCompletion generates a zsh completion file and writes to the passed writer. +// GenZshCompletion generates a zsh completion file and writes to the passed +// writer. The completion always run on the root command regardless of the +// command it was called from. func (c *Command) GenZshCompletion(w io.Writer) error { - buf := new(bytes.Buffer) - - writeHeader(buf, c) - maxDepth := maxDepth(c) - writeLevelMapping(buf, maxDepth) - writeLevelCases(buf, maxDepth, c) - - _, err := buf.WriteTo(w) - return err -} - -func writeHeader(w io.Writer, cmd *Command) { - fmt.Fprintf(w, "#compdef %s\n\n", cmd.Name()) -} - -func maxDepth(c *Command) int { - if len(c.Commands()) == 0 { - return 0 + tmpl, err := template.New("Main").Funcs(zshCompFuncMap).Parse(zshCompletionText) + if err != nil { + return fmt.Errorf("error creating zsh completion template: %v", err) } - maxDepthSub := 0 - for _, s := range c.Commands() { - subDepth := maxDepth(s) - if subDepth > maxDepthSub { - maxDepthSub = subDepth + return tmpl.Execute(w, c.Root()) +} + +// MarkZshCompPositionalArgumentFile marks the specified argument (first +// argument is 1) as completed by file selection. patterns (e.g. "*.txt") are +// optional - if not provided the completion will search for all files. +func (c *Command) MarkZshCompPositionalArgumentFile(argPosition int, patterns ...string) error { + if argPosition < 1 { + return fmt.Errorf("Invalid argument position (%d)", argPosition) + } + annotation, err := c.zshCompGetArgsAnnotations() + if err != nil { + return err + } + if c.zshcompArgsAnnotationnIsDuplicatePosition(annotation, argPosition) { + return fmt.Errorf("Duplicate annotation for positional argument at index %d", argPosition) + } + annotation[argPosition] = zshCompArgHint{ + Tipe: zshCompArgumentFilenameComp, + Options: patterns, + } + return c.zshCompSetArgsAnnotations(annotation) +} + +// MarkZshCompPositionalArgumentWords marks the specified positional argument +// (first argument is 1) as completed by the provided words. At east one word +// must be provided, spaces within words will be offered completion with +// "word\ word". +func (c *Command) MarkZshCompPositionalArgumentWords(argPosition int, words ...string) error { + if argPosition < 1 { + return fmt.Errorf("Invalid argument position (%d)", argPosition) + } + if len(words) == 0 { + return fmt.Errorf("Trying to set empty word list for positional argument %d", argPosition) + } + annotation, err := c.zshCompGetArgsAnnotations() + if err != nil { + return err + } + if c.zshcompArgsAnnotationnIsDuplicatePosition(annotation, argPosition) { + return fmt.Errorf("Duplicate annotation for positional argument at index %d", argPosition) + } + annotation[argPosition] = zshCompArgHint{ + Tipe: zshCompArgumentWordComp, + Options: words, + } + return c.zshCompSetArgsAnnotations(annotation) +} + +func zshCompExtractArgumentCompletionHintsForRendering(c *Command) ([]string, error) { + var result []string + annotation, err := c.zshCompGetArgsAnnotations() + if err != nil { + return nil, err + } + for k, v := range annotation { + s, err := zshCompRenderZshCompArgHint(k, v) + if err != nil { + return nil, err + } + result = append(result, s) + } + if len(c.ValidArgs) > 0 { + if _, positionOneExists := annotation[1]; !positionOneExists { + s, err := zshCompRenderZshCompArgHint(1, zshCompArgHint{ + Tipe: zshCompArgumentWordComp, + Options: c.ValidArgs, + }) + if err != nil { + return nil, err + } + result = append(result, s) } } - return 1 + maxDepthSub + sort.Strings(result) + return result, nil } -func writeLevelMapping(w io.Writer, numLevels int) { - fmt.Fprintln(w, `_arguments \`) - for i := 1; i <= numLevels; i++ { - fmt.Fprintf(w, ` '%d: :->level%d' \`, i, i) - fmt.Fprintln(w) - } - fmt.Fprintf(w, ` '%d: :%s'`, numLevels+1, "_files") - fmt.Fprintln(w) -} - -func writeLevelCases(w io.Writer, maxDepth int, root *Command) { - fmt.Fprintln(w, "case $state in") - defer fmt.Fprintln(w, "esac") - - for i := 1; i <= maxDepth; i++ { - fmt.Fprintf(w, " level%d)\n", i) - writeLevel(w, root, i) - fmt.Fprintln(w, " ;;") - } - fmt.Fprintln(w, " *)") - fmt.Fprintln(w, " _arguments '*: :_files'") - fmt.Fprintln(w, " ;;") -} - -func writeLevel(w io.Writer, root *Command, i int) { - fmt.Fprintf(w, " case $words[%d] in\n", i) - defer fmt.Fprintln(w, " esac") - - commands := filterByLevel(root, i) - byParent := groupByParent(commands) - - for p, c := range byParent { - names := names(c) - fmt.Fprintf(w, " %s)\n", p) - fmt.Fprintf(w, " _arguments '%d: :(%s)'\n", i, strings.Join(names, " ")) - fmt.Fprintln(w, " ;;") - } - fmt.Fprintln(w, " *)") - fmt.Fprintln(w, " _arguments '*: :_files'") - fmt.Fprintln(w, " ;;") - -} - -func filterByLevel(c *Command, l int) []*Command { - cs := make([]*Command, 0) - if l == 0 { - cs = append(cs, c) - return cs - } - for _, s := range c.Commands() { - cs = append(cs, filterByLevel(s, l-1)...) - } - return cs -} - -func groupByParent(commands []*Command) map[string][]*Command { - m := make(map[string][]*Command) - for _, c := range commands { - parent := c.Parent() - if parent == nil { - continue +func zshCompRenderZshCompArgHint(i int, z zshCompArgHint) (string, error) { + switch t := z.Tipe; t { + case zshCompArgumentFilenameComp: + var globs []string + for _, g := range z.Options { + globs = append(globs, fmt.Sprintf(`-g "%s"`, g)) } - m[parent.Name()] = append(m[parent.Name()], c) + return fmt.Sprintf(`'%d: :_files %s'`, i, strings.Join(globs, " ")), nil + case zshCompArgumentWordComp: + var words []string + for _, w := range z.Options { + words = append(words, fmt.Sprintf("%q", w)) + } + return fmt.Sprintf(`'%d: :(%s)'`, i, strings.Join(words, " ")), nil + default: + return "", fmt.Errorf("Invalid zsh argument completion annotation: %s", t) } - return m } -func names(commands []*Command) []string { - ns := make([]string, len(commands)) - for i, c := range commands { - ns[i] = c.Name() - } - return ns +func (c *Command) zshcompArgsAnnotationnIsDuplicatePosition(annotation zshCompArgsAnnotation, position int) bool { + _, dup := annotation[position] + return dup +} + +func (c *Command) zshCompGetArgsAnnotations() (zshCompArgsAnnotation, error) { + annotation := make(zshCompArgsAnnotation) + annotationString, ok := c.Annotations[zshCompArgumentAnnotation] + if !ok { + return annotation, nil + } + err := json.Unmarshal([]byte(annotationString), &annotation) + if err != nil { + return annotation, fmt.Errorf("Error unmarshaling zsh argument annotation: %v", err) + } + return annotation, nil +} + +func (c *Command) zshCompSetArgsAnnotations(annotation zshCompArgsAnnotation) error { + jsn, err := json.Marshal(annotation) + if err != nil { + return fmt.Errorf("Error marshaling zsh argument annotation: %v", err) + } + if c.Annotations == nil { + c.Annotations = make(map[string]string) + } + c.Annotations[zshCompArgumentAnnotation] = string(jsn) + return nil +} + +func zshCompGenFuncName(c *Command) string { + if c.HasParent() { + return zshCompGenFuncName(c.Parent()) + "_" + c.Name() + } + return "_" + c.Name() +} + +func zshCompExtractFlag(c *Command) []*pflag.Flag { + var flags []*pflag.Flag + c.LocalFlags().VisitAll(func(f *pflag.Flag) { + if !f.Hidden { + flags = append(flags, f) + } + }) + c.InheritedFlags().VisitAll(func(f *pflag.Flag) { + if !f.Hidden { + flags = append(flags, f) + } + }) + return flags +} + +// zshCompGenFlagEntryForArguments returns an entry that matches _arguments +// zsh-completion parameters. It's too complicated to generate in a template. +func zshCompGenFlagEntryForArguments(f *pflag.Flag) string { + if f.Name == "" || f.Shorthand == "" { + return zshCompGenFlagEntryForSingleOptionFlag(f) + } + return zshCompGenFlagEntryForMultiOptionFlag(f) +} + +func zshCompGenFlagEntryForSingleOptionFlag(f *pflag.Flag) string { + var option, multiMark, extras string + + if zshCompFlagCouldBeSpecifiedMoreThenOnce(f) { + multiMark = "*" + } + + option = "--" + f.Name + if option == "--" { + option = "-" + f.Shorthand + } + extras = zshCompGenFlagEntryExtras(f) + + return fmt.Sprintf(`'%s%s[%s]%s'`, multiMark, option, zshCompQuoteFlagDescription(f.Usage), extras) +} + +func zshCompGenFlagEntryForMultiOptionFlag(f *pflag.Flag) string { + var options, parenMultiMark, curlyMultiMark, extras string + + if zshCompFlagCouldBeSpecifiedMoreThenOnce(f) { + parenMultiMark = "*" + curlyMultiMark = "\\*" + } + + options = fmt.Sprintf(`'(%s-%s %s--%s)'{%s-%s,%s--%s}`, + parenMultiMark, f.Shorthand, parenMultiMark, f.Name, curlyMultiMark, f.Shorthand, curlyMultiMark, f.Name) + extras = zshCompGenFlagEntryExtras(f) + + return fmt.Sprintf(`%s'[%s]%s'`, options, zshCompQuoteFlagDescription(f.Usage), extras) +} + +func zshCompGenFlagEntryExtras(f *pflag.Flag) string { + if f.NoOptDefVal != "" { + return "" + } + + extras := ":" // allow options for flag (even without assistance) + for key, values := range f.Annotations { + switch key { + case zshCompDirname: + extras = fmt.Sprintf(":filename:_files -g %q", values[0]) + case BashCompFilenameExt: + extras = ":filename:_files" + for _, pattern := range values { + extras = extras + fmt.Sprintf(` -g "%s"`, pattern) + } + } + } + + return extras +} + +func zshCompFlagCouldBeSpecifiedMoreThenOnce(f *pflag.Flag) bool { + return strings.Contains(f.Value.Type(), "Slice") || + strings.Contains(f.Value.Type(), "Array") +} + +func zshCompQuoteFlagDescription(s string) string { + return strings.Replace(s, "'", `'\''`, -1) } diff --git a/zsh_completions.md b/zsh_completions.md new file mode 100644 index 00000000..df9c2eac --- /dev/null +++ b/zsh_completions.md @@ -0,0 +1,39 @@ +## Generating Zsh Completion for your cobra.Command + +Cobra supports native Zsh completion generated from the root `cobra.Command`. +The generated completion script should be put somewhere in your `$fpath` named +`_`. + +### What's Supported + +* Completion for all non-hidden subcommands using their `.Short` description. +* Completion for all non-hidden flags using the following rules: + * Filename completion works by marking the flag with `cmd.MarkFlagFilename...` + family of commands. + * The requirement for argument to the flag is decided by the `.NoOptDefVal` + flag value - if it's empty then completion will expect an argument. + * Flags of one of the various `*Array` and `*Slice` types supports multiple + specifications (with or without argument depending on the specific type). +* Completion of positional arguments using the following rules: + * Argument position for all options below starts at `1`. If argument position + `0` is requested it will raise an error. + * Use `command.MarkZshCompPositionalArgumentFile` to complete filenames. Glob + patterns (e.g. `"*.log"`) are optional - if not specified it will offer to + complete all file types. + * Use `command.MarkZshCompPositionalArgumentWords` to offer specific words for + completion. At least one word is required. + * It's possible to specify completion for some arguments and leave some + unspecified (e.g. offer words for second argument but nothing for first + argument). This will cause no completion for first argument but words + completion for second argument. + * If no argument completion was specified for 1st argument (but optionally was + specified for 2nd) and the command has `ValidArgs` it will be used as + completion options for 1st argument. + * Argument completions only offered for commands with no subcommands. + +### What's not yet Supported + +* Custom completion scripts are not supported yet (We should probably create zsh + specific one, doesn't make sense to re-use the bash one as the functions will + be different). +* Whatever other feature you're looking for and doesn't exist :) diff --git a/zsh_completions_test.go b/zsh_completions_test.go index 34e69496..e53fa886 100644 --- a/zsh_completions_test.go +++ b/zsh_completions_test.go @@ -2,88 +2,474 @@ package cobra import ( "bytes" + "regexp" "strings" "testing" ) -func TestZshCompletion(t *testing.T) { +func TestGenZshCompletion(t *testing.T) { + var debug bool + var option string + + tcs := []struct { + name string + root *Command + expectedExpressions []string + invocationArgs []string + skip string + }{ + { + name: "simple command", + root: func() *Command { + r := &Command{ + Use: "mycommand", + Long: "My Command long description", + Run: emptyRun, + } + r.Flags().BoolVar(&debug, "debug", debug, "description") + return r + }(), + expectedExpressions: []string{ + `(?s)function _mycommand {\s+_arguments \\\s+'--debug\[description\]'.*--help.*}`, + "#compdef _mycommand mycommand", + }, + }, + { + name: "flags with both long and short flags", + root: func() *Command { + r := &Command{ + Use: "testcmd", + Long: "long description", + Run: emptyRun, + } + r.Flags().BoolVarP(&debug, "debug", "d", debug, "debug description") + return r + }(), + expectedExpressions: []string{ + `'\(-d --debug\)'{-d,--debug}'\[debug description\]'`, + }, + }, + { + name: "command with subcommands and flags with values", + root: func() *Command { + r := &Command{ + Use: "rootcmd", + Long: "Long rootcmd description", + } + d := &Command{ + Use: "subcmd1", + Short: "Subcmd1 short description", + Run: emptyRun, + } + e := &Command{ + Use: "subcmd2", + Long: "Subcmd2 short description", + Run: emptyRun, + } + r.PersistentFlags().BoolVar(&debug, "debug", debug, "description") + d.Flags().StringVarP(&option, "option", "o", option, "option description") + r.AddCommand(d, e) + return r + }(), + expectedExpressions: []string{ + `commands=\(\n\s+"help:.*\n\s+"subcmd1:.*\n\s+"subcmd2:.*\n\s+\)`, + `_arguments \\\n.*'--debug\[description]'`, + `_arguments -C \\\n.*'--debug\[description]'`, + `function _rootcmd_subcmd1 {`, + `function _rootcmd_subcmd1 {`, + `_arguments \\\n.*'\(-o --option\)'{-o,--option}'\[option description]:' \\\n`, + }, + }, + { + name: "filename completion with and without globs", + root: func() *Command { + var file string + r := &Command{ + Use: "mycmd", + Short: "my command short description", + Run: emptyRun, + } + r.Flags().StringVarP(&file, "config", "c", file, "config file") + r.MarkFlagFilename("config") + r.Flags().String("output", "", "output file") + r.MarkFlagFilename("output", "*.log", "*.txt") + return r + }(), + expectedExpressions: []string{ + `\n +'\(-c --config\)'{-c,--config}'\[config file]:filename:_files'`, + `:_files -g "\*.log" -g "\*.txt"`, + }, + }, + { + name: "repeated variables both with and without value", + root: func() *Command { + r := genTestCommand("mycmd", true) + _ = r.Flags().BoolSliceP("debug", "d", []bool{}, "debug usage") + _ = r.Flags().StringArray("option", []string{}, "options") + return r + }(), + expectedExpressions: []string{ + `'\*--option\[options]`, + `'\(\*-d \*--debug\)'{\\\*-d,\\\*--debug}`, + }, + }, + { + name: "generated flags --help and --version should be created even when not executing root cmd", + root: func() *Command { + r := &Command{ + Use: "mycmd", + Short: "mycmd short description", + Version: "myversion", + } + s := genTestCommand("sub1", true) + r.AddCommand(s) + return s + }(), + expectedExpressions: []string{ + "--version", + "--help", + }, + invocationArgs: []string{ + "sub1", + }, + skip: "--version and --help are currently not generated when not running on root command", + }, + { + name: "zsh generation should run on root command", + root: func() *Command { + r := genTestCommand("root", false) + s := genTestCommand("sub1", true) + r.AddCommand(s) + return s + }(), + expectedExpressions: []string{ + "function _root {", + }, + }, + { + name: "flag description with single quote (') shouldn't break quotes in completion file", + root: func() *Command { + r := genTestCommand("root", true) + r.Flags().Bool("private", false, "Don't show public info") + return r + }(), + expectedExpressions: []string{ + `--private\[Don'\\''t show public info]`, + }, + }, + { + name: "argument completion for file with and without patterns", + root: func() *Command { + r := genTestCommand("root", true) + r.MarkZshCompPositionalArgumentFile(1, "*.log") + r.MarkZshCompPositionalArgumentFile(2) + return r + }(), + expectedExpressions: []string{ + `'1: :_files -g "\*.log"' \\\n\s+'2: :_files`, + }, + }, + { + name: "argument zsh completion for words", + root: func() *Command { + r := genTestCommand("root", true) + r.MarkZshCompPositionalArgumentWords(1, "word1", "word2") + return r + }(), + expectedExpressions: []string{ + `'1: :\("word1" "word2"\)`, + }, + }, + { + name: "argument completion for words with spaces", + root: func() *Command { + r := genTestCommand("root", true) + r.MarkZshCompPositionalArgumentWords(1, "single", "multiple words") + return r + }(), + expectedExpressions: []string{ + `'1: :\("single" "multiple words"\)'`, + }, + }, + { + name: "argument completion when command has ValidArgs and no annotation for argument completion", + root: func() *Command { + r := genTestCommand("root", true) + r.ValidArgs = []string{"word1", "word2"} + return r + }(), + expectedExpressions: []string{ + `'1: :\("word1" "word2"\)'`, + }, + }, + { + name: "argument completion when command has ValidArgs and no annotation for argument at argPosition 1", + root: func() *Command { + r := genTestCommand("root", true) + r.ValidArgs = []string{"word1", "word2"} + r.MarkZshCompPositionalArgumentFile(2) + return r + }(), + expectedExpressions: []string{ + `'1: :\("word1" "word2"\)' \\`, + }, + }, + { + name: "directory completion for flag", + root: func() *Command { + r := genTestCommand("root", true) + r.Flags().String("test", "", "test") + r.PersistentFlags().String("ptest", "", "ptest") + r.MarkFlagDirname("test") + r.MarkPersistentFlagDirname("ptest") + return r + }(), + expectedExpressions: []string{ + `--test\[test]:filename:_files -g "-\(/\)"`, + `--ptest\[ptest]:filename:_files -g "-\(/\)"`, + }, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + if tc.skip != "" { + t.Skip(tc.skip) + } + tc.root.Root().SetArgs(tc.invocationArgs) + tc.root.Execute() + buf := new(bytes.Buffer) + if err := tc.root.GenZshCompletion(buf); err != nil { + t.Error(err) + } + output := buf.Bytes() + + for _, expr := range tc.expectedExpressions { + rgx, err := regexp.Compile(expr) + if err != nil { + t.Errorf("error compiling expression (%s): %v", expr, err) + } + if !rgx.Match(output) { + t.Errorf("expected completion (%s) to match '%s'", buf.String(), expr) + } + } + }) + } +} + +func TestGenZshCompletionHidden(t *testing.T) { tcs := []struct { name string root *Command expectedExpressions []string }{ { - name: "trivial", - root: &Command{Use: "trivialapp"}, - expectedExpressions: []string{"#compdef trivial"}, - }, - { - name: "linear", + name: "hidden commands", root: func() *Command { - r := &Command{Use: "linear"} - - sub1 := &Command{Use: "sub1"} - r.AddCommand(sub1) - - sub2 := &Command{Use: "sub2"} - sub1.AddCommand(sub2) - - sub3 := &Command{Use: "sub3"} - sub2.AddCommand(sub3) - return r - }(), - expectedExpressions: []string{"sub1", "sub2", "sub3"}, - }, - { - name: "flat", - root: func() *Command { - r := &Command{Use: "flat"} - r.AddCommand(&Command{Use: "c1"}) - r.AddCommand(&Command{Use: "c2"}) - return r - }(), - expectedExpressions: []string{"(c1 c2)"}, - }, - { - name: "tree", - root: func() *Command { - r := &Command{Use: "tree"} - - sub1 := &Command{Use: "sub1"} - r.AddCommand(sub1) - - sub11 := &Command{Use: "sub11"} - sub12 := &Command{Use: "sub12"} - - sub1.AddCommand(sub11) - sub1.AddCommand(sub12) - - sub2 := &Command{Use: "sub2"} - r.AddCommand(sub2) - - sub21 := &Command{Use: "sub21"} - sub22 := &Command{Use: "sub22"} - - sub2.AddCommand(sub21) - sub2.AddCommand(sub22) + r := &Command{ + Use: "main", + Short: "main short description", + } + s1 := &Command{ + Use: "sub1", + Hidden: true, + Run: emptyRun, + } + s2 := &Command{ + Use: "sub2", + Short: "short sub2 description", + Run: emptyRun, + } + r.AddCommand(s1, s2) return r }(), - expectedExpressions: []string{"(sub11 sub12)", "(sub21 sub22)"}, + expectedExpressions: []string{ + "sub1", + }, + }, + { + name: "hidden flags", + root: func() *Command { + var hidden string + r := &Command{ + Use: "root", + Short: "root short description", + Run: emptyRun, + } + r.Flags().StringVarP(&hidden, "hidden", "H", hidden, "hidden usage") + if err := r.Flags().MarkHidden("hidden"); err != nil { + t.Errorf("Error setting flag hidden: %v\n", err) + } + return r + }(), + expectedExpressions: []string{ + "--hidden", + }, }, } for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { + tc.root.Execute() buf := new(bytes.Buffer) - tc.root.GenZshCompletion(buf) + if err := tc.root.GenZshCompletion(buf); err != nil { + t.Error(err) + } output := buf.String() - for _, expectedExpression := range tc.expectedExpressions { - if !strings.Contains(output, expectedExpression) { - t.Errorf("Expected completion to contain %q somewhere; got %q", expectedExpression, output) + for _, expr := range tc.expectedExpressions { + if strings.Contains(output, expr) { + t.Errorf("Expected completion (%s) not to contain '%s' but it does", output, expr) } } }) } } + +func TestMarkZshCompPositionalArgumentFile(t *testing.T) { + t.Run("Doesn't allow overwriting existing positional argument", func(t *testing.T) { + c := &Command{} + if err := c.MarkZshCompPositionalArgumentFile(1, "*.log"); err != nil { + t.Errorf("Received error when we shouldn't have: %v\n", err) + } + if err := c.MarkZshCompPositionalArgumentFile(1); err == nil { + t.Error("Didn't receive an error when trying to overwrite argument position") + } + }) + + t.Run("Refuses to accept argPosition less then 1", func(t *testing.T) { + c := &Command{} + err := c.MarkZshCompPositionalArgumentFile(0, "*") + if err == nil { + t.Fatal("Error was not thrown when indicating argument position 0") + } + if !strings.Contains(err.Error(), "position") { + t.Errorf("expected error message '%s' to contain 'position'", err.Error()) + } + }) +} + +func TestMarkZshCompPositionalArgumentWords(t *testing.T) { + t.Run("Doesn't allow overwriting existing positional argument", func(t *testing.T) { + c := &Command{} + if err := c.MarkZshCompPositionalArgumentFile(1, "*.log"); err != nil { + t.Errorf("Received error when we shouldn't have: %v\n", err) + } + if err := c.MarkZshCompPositionalArgumentWords(1, "hello"); err == nil { + t.Error("Didn't receive an error when trying to overwrite argument position") + } + }) + + t.Run("Doesn't allow calling without words", func(t *testing.T) { + c := &Command{} + if err := c.MarkZshCompPositionalArgumentWords(0); err == nil { + t.Error("Should not allow saving empty word list for annotation") + } + }) + + t.Run("Refuses to accept argPosition less then 1", func(t *testing.T) { + c := &Command{} + err := c.MarkZshCompPositionalArgumentWords(0, "word") + if err == nil { + t.Fatal("Should not allow setting argument position less then 1") + } + if !strings.Contains(err.Error(), "position") { + t.Errorf("Expected error '%s' to contain 'position' but didn't", err.Error()) + } + }) +} + +func BenchmarkMediumSizeConstruct(b *testing.B) { + root := constructLargeCommandHierarchy() + // if err := root.GenZshCompletionFile("_mycmd"); err != nil { + // b.Error(err) + // } + + for i := 0; i < b.N; i++ { + buf := new(bytes.Buffer) + err := root.GenZshCompletion(buf) + if err != nil { + b.Error(err) + } + } +} + +func TestExtractFlags(t *testing.T) { + var debug, cmdc, cmdd bool + c := &Command{ + Use: "cmdC", + Long: "Command C", + } + c.PersistentFlags().BoolVarP(&debug, "debug", "d", debug, "debug mode") + c.Flags().BoolVar(&cmdc, "cmd-c", cmdc, "Command C") + d := &Command{ + Use: "CmdD", + Long: "Command D", + } + d.Flags().BoolVar(&cmdd, "cmd-d", cmdd, "Command D") + c.AddCommand(d) + + resC := zshCompExtractFlag(c) + resD := zshCompExtractFlag(d) + + if len(resC) != 2 { + t.Errorf("expected Command C to return 2 flags, got %d", len(resC)) + } + if len(resD) != 2 { + t.Errorf("expected Command D to return 2 flags, got %d", len(resD)) + } +} + +func constructLargeCommandHierarchy() *Command { + var config, st1, st2 string + var long, debug bool + var in1, in2 int + var verbose []bool + + r := genTestCommand("mycmd", false) + r.PersistentFlags().StringVarP(&config, "config", "c", config, "config usage") + if err := r.MarkPersistentFlagFilename("config", "*"); err != nil { + panic(err) + } + s1 := genTestCommand("sub1", true) + s1.Flags().BoolVar(&long, "long", long, "long description") + s1.Flags().BoolSliceVar(&verbose, "verbose", verbose, "verbose description") + s1.Flags().StringArray("option", []string{}, "various options") + s2 := genTestCommand("sub2", true) + s2.PersistentFlags().BoolVar(&debug, "debug", debug, "debug description") + s3 := genTestCommand("sub3", true) + s3.Hidden = true + s1_1 := genTestCommand("sub1sub1", true) + s1_1.Flags().StringVar(&st1, "st1", st1, "st1 description") + s1_1.Flags().StringVar(&st2, "st2", st2, "st2 description") + s1_2 := genTestCommand("sub1sub2", true) + s1_3 := genTestCommand("sub1sub3", true) + s1_3.Flags().IntVar(&in1, "int1", in1, "int1 description") + s1_3.Flags().IntVar(&in2, "int2", in2, "int2 description") + s1_3.Flags().StringArrayP("option", "O", []string{}, "more options") + s2_1 := genTestCommand("sub2sub1", true) + s2_2 := genTestCommand("sub2sub2", true) + s2_3 := genTestCommand("sub2sub3", true) + s2_4 := genTestCommand("sub2sub4", true) + s2_5 := genTestCommand("sub2sub5", true) + + s1.AddCommand(s1_1, s1_2, s1_3) + s2.AddCommand(s2_1, s2_2, s2_3, s2_4, s2_5) + r.AddCommand(s1, s2, s3) + r.Execute() + return r +} + +func genTestCommand(name string, withRun bool) *Command { + r := &Command{ + Use: name, + Short: name + " short description", + Long: "Long description for " + name, + } + if withRun { + r.Run = emptyRun + } + + return r +}