diff --git a/.travis.yml b/.travis.yml index cb2bf0d5..68efa136 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,8 +2,8 @@ language: go matrix: include: - - go: 1.7.5 - - go: 1.8.1 + - go: 1.7.6 + - go: 1.8.3 - go: tip allow_failures: - go: tip diff --git a/README.md b/README.md index e249c1bc..a38137b2 100644 --- a/README.md +++ b/README.md @@ -485,6 +485,36 @@ when the `--author` flag is not provided by user. More in [viper documentation](https://github.com/spf13/viper#working-with-flags). +## Positional and Custom Arguments + +Validation of positional arguments can be specified using the `Args` field. + +The follow validators are built in: + +- `NoArgs` - the command will report an error if there are any positional args. +- `ArbitraryArgs` - the command will accept any args. +- `OnlyValidArgs` - the command will report an error if there are any positional args that are not in the ValidArgs list. +- `MinimumNArgs(int)` - the command will report an error if there are not at least N positional args. +- `MaximumNArgs(int)` - the command will report an error if there are more than N positional args. +- `ExactArgs(int)` - the command will report an error if there are not exactly N positional args. +- `RangeArgs(min, max)` - the command will report an error if the number of args is not between the minimum and maximum number of expected args. + +A custom validator can be provided like this: + +```go + +Args: func validColorArgs(cmd *cobra.Command, args []string) error { + if err := cli.RequiresMinArgs(1)(cmd, args); err != nil { + return err + } + if myapp.IsValidColor(args[0]) { + return nil + } + return fmt.Errorf("Invalid color specified: %s", args[0]) +} + +``` + ## Example In the example below, we have defined three commands. Two are at the top level diff --git a/args.go b/args.go new file mode 100644 index 00000000..94a6ca27 --- /dev/null +++ b/args.go @@ -0,0 +1,98 @@ +package cobra + +import ( + "fmt" +) + +type PositionalArgs func(cmd *Command, args []string) error + +// Legacy arg validation has the following behaviour: +// - root commands with no subcommands can take arbitrary arguments +// - root commands with subcommands will do subcommand validity checking +// - subcommands will always accept arbitrary arguments +func legacyArgs(cmd *Command, args []string) error { + // no subcommand, always take args + if !cmd.HasSubCommands() { + return nil + } + + // root command with subcommands, do subcommand checking + if !cmd.HasParent() && len(args) > 0 { + return fmt.Errorf("unknown command %q for %q%s", args[0], cmd.CommandPath(), cmd.findSuggestions(args[0])) + } + return nil +} + +// NoArgs returns an error if any args are included +func NoArgs(cmd *Command, args []string) error { + if len(args) > 0 { + return fmt.Errorf("unknown command %q for %q", args[0], cmd.CommandPath()) + } + return nil +} + +// OnlyValidArgs returns an error if any args are not in the list of ValidArgs +func OnlyValidArgs(cmd *Command, args []string) error { + if len(cmd.ValidArgs) > 0 { + for _, v := range args { + if !stringInSlice(v, cmd.ValidArgs) { + return fmt.Errorf("invalid argument %q for %q%s", v, cmd.CommandPath(), cmd.findSuggestions(args[0])) + } + } + } + return nil +} + +func stringInSlice(a string, list []string) bool { + for _, b := range list { + if b == a { + return true + } + } + return false +} + +// ArbitraryArgs never returns an error +func ArbitraryArgs(cmd *Command, args []string) error { + return nil +} + +// MinimumNArgs returns an error if there is not at least N args +func MinimumNArgs(n int) PositionalArgs { + return func(cmd *Command, args []string) error { + if len(args) < n { + return fmt.Errorf("requires at least %d arg(s), only received %d", n, len(args)) + } + return nil + } +} + +// MaximumNArgs returns an error if there are more than N args +func MaximumNArgs(n int) PositionalArgs { + return func(cmd *Command, args []string) error { + if len(args) > n { + return fmt.Errorf("accepts at most %d arg(s), received %d", n, len(args)) + } + return nil + } +} + +// ExactArgs returns an error if there are not exactly n args +func ExactArgs(n int) PositionalArgs { + return func(cmd *Command, args []string) error { + if len(args) != n { + return fmt.Errorf("accepts %d arg(s), received %d", n, len(args)) + } + return nil + } +} + +// RangeArgs returns an error if the number of args is not within the expected range +func RangeArgs(min int, max int) PositionalArgs { + return func(cmd *Command, args []string) error { + if len(args) < min || len(args) > max { + return fmt.Errorf("accepts between %d and %d arg(s), received %d", min, max, len(args)) + } + return nil + } +} diff --git a/bash_completions_test.go b/bash_completions_test.go index 7511376a..071a6a2a 100644 --- a/bash_completions_test.go +++ b/bash_completions_test.go @@ -117,6 +117,8 @@ func TestBashCompletions(t *testing.T) { // check for filename extension flags check(t, str, `flags_completion+=("_filedir")`) // check for filename extension flags + check(t, str, `must_have_one_noun+=("three")`) + // check for filename extention flags check(t, str, `flags_completion+=("__handle_filename_extension_flag json|yaml|yml")`) // check for custom flags check(t, str, `flags_completion+=("__complete_custom")`) diff --git a/cobra.go b/cobra.go index 2726d19e..8928cefc 100644 --- a/cobra.go +++ b/cobra.go @@ -47,6 +47,15 @@ var EnablePrefixMatching = false // To disable sorting, set it to false. var EnableCommandSorting = true +// MousetrapHelpText enables an information splash screen on Windows +// if the CLI is started from explorer.exe. +// To disable the mousetrap, just set this variable to blank string (""). +// Works only on Microsoft Windows. +var MousetrapHelpText string = `This is a command line tool. + +You need to open cmd.exe and run it from there. +` + // AddTemplateFunc adds a template function that's available to Usage and Help // template generation. func AddTemplateFunc(name string, tmplFunc interface{}) { diff --git a/cobra/cmd/add.go b/cobra/cmd/add.go index 45f00bb5..30f83662 100644 --- a/cobra/cmd/add.go +++ b/cobra/cmd/add.go @@ -121,7 +121,7 @@ func validateCmdName(source string) string { func createCmdFile(license License, path, cmdName string) { template := `{{comment .copyright}} -{{comment .license}} +{{if .license}}{{comment .license}}{{end}} package {{.cmdPackage}} diff --git a/cobra/cmd/add_test.go b/cobra/cmd/add_test.go index dacbe838..1aa5f537 100644 --- a/cobra/cmd/add_test.go +++ b/cobra/cmd/add_test.go @@ -6,6 +6,8 @@ import ( "os" "path/filepath" "testing" + + "github.com/spf13/viper" ) // TestGoldenAddCmd initializes the project "github.com/spf13/testproject" @@ -20,6 +22,9 @@ func TestGoldenAddCmd(t *testing.T) { // Initialize the project at first. initializeProject(project) defer os.RemoveAll(project.AbsPath()) + defer viper.Set("year", nil) + + viper.Set("year", 2017) // For reproducible builds // Then add the "test" command. cmdName := "test" @@ -48,7 +53,7 @@ func TestGoldenAddCmd(t *testing.T) { goldenPath := filepath.Join("testdata", filepath.Base(path)+".golden") switch relPath { - // Know directories. + // Known directories. case ".": return nil // Known files. diff --git a/cobra/cmd/golden_test.go b/cobra/cmd/golden_test.go index 0ac7e893..59a5a1c9 100644 --- a/cobra/cmd/golden_test.go +++ b/cobra/cmd/golden_test.go @@ -39,11 +39,11 @@ func compareFiles(pathA, pathB string) error { // Don't execute diff if it can't be found. return nil } - diffCmd := exec.Command(diffPath, pathA, pathB) + diffCmd := exec.Command(diffPath, "-u", pathA, pathB) diffCmd.Stdout = output diffCmd.Stderr = output - output.WriteString("$ diff " + pathA + " " + pathB + "\n") + output.WriteString("$ diff -u " + pathA + " " + pathB + "\n") if err := diffCmd.Run(); err != nil { output.WriteString("\n" + err.Error()) } diff --git a/cobra/cmd/helpers.go b/cobra/cmd/helpers.go index 6114227d..c5e261ce 100644 --- a/cobra/cmd/helpers.go +++ b/cobra/cmd/helpers.go @@ -45,24 +45,34 @@ func er(msg interface{}) { } // isEmpty checks if a given path is empty. +// Hidden files in path are ignored. func isEmpty(path string) bool { fi, err := os.Stat(path) if err != nil { er(err) } - if fi.IsDir() { - f, err := os.Open(path) - if err != nil { - er(err) - } - defer f.Close() - dirs, err := f.Readdirnames(1) - if err != nil && err != io.EOF { - er(err) - } - return len(dirs) == 0 + + if !fi.IsDir() { + return fi.Size() == 0 } - return fi.Size() == 0 + + f, err := os.Open(path) + if err != nil { + er(err) + } + defer f.Close() + + names, err := f.Readdirnames(-1) + if err != nil && err != io.EOF { + er(err) + } + + for _, name := range names { + if len(name) > 0 && name[0] != '.' { + return false + } + } + return true } // exists checks if a file or directory exists. diff --git a/cobra/cmd/init_test.go b/cobra/cmd/init_test.go index 9a918b9b..1ba9a54b 100644 --- a/cobra/cmd/init_test.go +++ b/cobra/cmd/init_test.go @@ -6,6 +6,8 @@ import ( "os" "path/filepath" "testing" + + "github.com/spf13/viper" ) // TestGoldenInitCmd initializes the project "github.com/spf13/testproject" @@ -16,6 +18,9 @@ func TestGoldenInitCmd(t *testing.T) { projectName := "github.com/spf13/testproject" project := NewProject(projectName) defer os.RemoveAll(project.AbsPath()) + defer viper.Set("year", nil) + + viper.Set("year", 2017) // For reproducible builds os.Args = []string{"cobra", "init", projectName} if err := rootCmd.Execute(); err != nil { @@ -44,7 +49,7 @@ func TestGoldenInitCmd(t *testing.T) { goldenPath := filepath.Join("testdata", filepath.Base(path)+".golden") switch relPath { - // Know directories. + // Known directories. case ".", "cmd": return nil // Known files. diff --git a/cobra/cmd/license_agpl.go b/cobra/cmd/license_agpl.go index 4ea036ed..bc22e973 100644 --- a/cobra/cmd/license_agpl.go +++ b/cobra/cmd/license_agpl.go @@ -4,8 +4,7 @@ func initAgpl() { Licenses["agpl"] = License{ Name: "GNU Affero General Public License", PossibleMatches: []string{"agpl", "affero gpl", "gnu agpl"}, - Header: `{{.copyright}} - + Header: ` This program is free software: you can redistribute it and/or modify it under the terms of the GNU Affero General Public License as published by the Free Software Foundation, either version 3 of the License, or diff --git a/cobra/cmd/license_apache_2.go b/cobra/cmd/license_apache_2.go index 3f330867..38393d54 100644 --- a/cobra/cmd/license_apache_2.go +++ b/cobra/cmd/license_apache_2.go @@ -19,7 +19,8 @@ func initApache2() { Licenses["apache"] = License{ Name: "Apache 2.0", PossibleMatches: []string{"apache", "apache20", "apache 2.0", "apache2.0", "apache-2.0"}, - Header: `Licensed under the Apache License, Version 2.0 (the "License"); + Header: ` +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/cobra/cmd/license_bsd_clause_2.go b/cobra/cmd/license_bsd_clause_2.go index f2982dab..4a847e04 100644 --- a/cobra/cmd/license_bsd_clause_2.go +++ b/cobra/cmd/license_bsd_clause_2.go @@ -20,8 +20,7 @@ func initBsdClause2() { Name: "Simplified BSD License", PossibleMatches: []string{"freebsd", "simpbsd", "simple bsd", "2-clause bsd", "2 clause bsd", "simplified bsd license"}, - Header: ` -All rights reserved. + Header: `All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/cobra/cmd/license_bsd_clause_3.go b/cobra/cmd/license_bsd_clause_3.go index 39c9571f..c7476b31 100644 --- a/cobra/cmd/license_bsd_clause_3.go +++ b/cobra/cmd/license_bsd_clause_3.go @@ -19,8 +19,7 @@ func initBsdClause3() { Licenses["bsd"] = License{ Name: "NewBSD", PossibleMatches: []string{"bsd", "newbsd", "3 clause bsd", "3-clause bsd"}, - Header: ` -All rights reserved. + Header: `All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/cobra/cmd/license_gpl_2.go b/cobra/cmd/license_gpl_2.go index 054b470f..03e05b3a 100644 --- a/cobra/cmd/license_gpl_2.go +++ b/cobra/cmd/license_gpl_2.go @@ -19,20 +19,19 @@ func initGpl2() { Licenses["gpl2"] = License{ Name: "GNU General Public License 2.0", PossibleMatches: []string{"gpl2", "gnu gpl2", "gplv2"}, - Header: `{{.copyright}} + Header: ` +This program is free software; you can redistribute it and/or +modify it under the terms of the GNU General Public License +as published by the Free Software Foundation; either version 2 +of the License, or (at your option) any later version. - This program is free software; you can redistribute it and/or - modify it under the terms of the GNU General Public License - as published by the Free Software Foundation; either version 2 - of the License, or (at your option) any later version. +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU Lesser General Public License - along with this program. If not, see .`, +You should have received a copy of the GNU Lesser General Public License +along with this program. If not, see .`, Text: ` GNU GENERAL PUBLIC LICENSE Version 2, June 1991 diff --git a/cobra/cmd/license_gpl_3.go b/cobra/cmd/license_gpl_3.go index d1ef656a..ce07679c 100644 --- a/cobra/cmd/license_gpl_3.go +++ b/cobra/cmd/license_gpl_3.go @@ -19,8 +19,7 @@ func initGpl3() { Licenses["gpl3"] = License{ Name: "GNU General Public License 3.0", PossibleMatches: []string{"gpl3", "gplv3", "gpl", "gnu gpl3", "gnu gpl"}, - Header: `{{.copyright}} - + Header: ` This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or diff --git a/cobra/cmd/license_lgpl.go b/cobra/cmd/license_lgpl.go index 75fd043b..0f8b96ca 100644 --- a/cobra/cmd/license_lgpl.go +++ b/cobra/cmd/license_lgpl.go @@ -4,8 +4,7 @@ func initLgpl() { Licenses["lgpl"] = License{ Name: "GNU Lesser General Public License", PossibleMatches: []string{"lgpl", "lesser gpl", "gnu lgpl"}, - Header: `{{.copyright}} - + Header: ` This program is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation, either version 3 of the License, or diff --git a/cobra/cmd/license_mit.go b/cobra/cmd/license_mit.go index 4ff5a4cd..bd2d0c4f 100644 --- a/cobra/cmd/license_mit.go +++ b/cobra/cmd/license_mit.go @@ -17,7 +17,7 @@ package cmd func initMit() { Licenses["mit"] = License{ - Name: "Mit", + Name: "MIT License", PossibleMatches: []string{"mit"}, Header: ` Permission is hereby granted, free of charge, to any person obtaining a copy diff --git a/cobra/cmd/licenses.go b/cobra/cmd/licenses.go index d73e6fb3..61444f66 100644 --- a/cobra/cmd/licenses.go +++ b/cobra/cmd/licenses.go @@ -77,7 +77,11 @@ func getLicense() License { func copyrightLine() string { author := viper.GetString("author") - year := time.Now().Format("2006") + + year := viper.GetString("year") // For reproducible builds + if year == "" { + year = time.Now().Format("2006") + } return "Copyright © " + year + " " + author } diff --git a/cobra/cmd/root.go b/cobra/cmd/root.go index 1c5e6907..19568f98 100644 --- a/cobra/cmd/root.go +++ b/cobra/cmd/root.go @@ -40,7 +40,7 @@ func Execute() { } func init() { - initViper() + cobra.OnInitialize(initConfig) rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.cobra.yaml)") rootCmd.PersistentFlags().StringP("author", "a", "YOUR NAME", "author name for copyright attribution") @@ -55,7 +55,7 @@ func init() { rootCmd.AddCommand(initCmd) } -func initViper() { +func initConfig() { if cfgFile != "" { // Use config file from the flag. viper.SetConfigFile(cfgFile) diff --git a/cobra/cmd/testdata/main.go.golden b/cobra/cmd/testdata/main.go.golden index 69ecbd48..cdbe38d7 100644 --- a/cobra/cmd/testdata/main.go.golden +++ b/cobra/cmd/testdata/main.go.golden @@ -1,4 +1,5 @@ // Copyright © 2017 NAME HERE +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at diff --git a/cobra/cmd/testdata/root.go.golden b/cobra/cmd/testdata/root.go.golden index ecc87601..8eeeae89 100644 --- a/cobra/cmd/testdata/root.go.golden +++ b/cobra/cmd/testdata/root.go.golden @@ -1,4 +1,5 @@ // Copyright © 2017 NAME HERE +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at diff --git a/cobra/cmd/testdata/test.go.golden b/cobra/cmd/testdata/test.go.golden index c8319d1d..58405680 100644 --- a/cobra/cmd/testdata/test.go.golden +++ b/cobra/cmd/testdata/test.go.golden @@ -1,4 +1,5 @@ // Copyright © 2017 NAME HERE +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at diff --git a/cobra_test.go b/cobra_test.go index 576c97d3..d5df951e 100644 --- a/cobra_test.go +++ b/cobra_test.go @@ -36,6 +36,7 @@ var cmdHidden = &Command{ var cmdPrint = &Command{ Use: "print [string to print]", + Args: MinimumNArgs(1), Short: "Print anything to the screen", Long: `an absolutely utterly useless command for testing.`, Run: func(cmd *Command, args []string) { @@ -75,6 +76,7 @@ var cmdDeprecated = &Command{ Deprecated: "Please use echo instead", Run: func(cmd *Command, args []string) { }, + Args: NoArgs, } var cmdTimes = &Command{ @@ -88,6 +90,8 @@ var cmdTimes = &Command{ Run: func(cmd *Command, args []string) { tt = args }, + Args: OnlyValidArgs, + ValidArgs: []string{"one", "two", "three", "four"}, } var cmdRootNoRun = &Command{ @@ -105,6 +109,16 @@ var cmdRootSameName = &Command{ Long: "The root description for help", } +var cmdRootTakesArgs = &Command{ + Use: "root-with-args [random args]", + Short: "The root can run it's own function and takes args!", + Long: "The root description for help, and some args", + Run: func(cmd *Command, args []string) { + tr = args + }, + Args: ArbitraryArgs, +} + var cmdRootWithRun = &Command{ Use: "cobra-test", Short: "The root can run its own function", @@ -458,6 +472,63 @@ func TestUsage(t *testing.T) { checkResultOmits(t, x, cmdCustomFlags.Use+" [flags]") } +func TestRootTakesNoArgs(t *testing.T) { + c := initializeWithSameName() + c.AddCommand(cmdPrint, cmdEcho) + result := simpleTester(c, "illegal") + + if result.Error == nil { + t.Fatal("Expected an error") + } + + expectedError := `unknown command "illegal" for "print"` + if !strings.Contains(result.Error.Error(), expectedError) { + t.Errorf("exptected %v, got %v", expectedError, result.Error.Error()) + } +} + +func TestRootTakesArgs(t *testing.T) { + c := cmdRootTakesArgs + result := simpleTester(c, "legal") + + if result.Error != nil { + t.Errorf("expected no error, but got %v", result.Error) + } +} + +func TestSubCmdTakesNoArgs(t *testing.T) { + result := fullSetupTest("deprecated", "illegal") + + if result.Error == nil { + t.Fatal("Expected an error") + } + + expectedError := `unknown command "illegal" for "cobra-test deprecated"` + if !strings.Contains(result.Error.Error(), expectedError) { + t.Errorf("expected %v, got %v", expectedError, result.Error.Error()) + } +} + +func TestSubCmdTakesArgs(t *testing.T) { + noRRSetupTest("echo", "times", "one", "two") + if strings.Join(tt, " ") != "one two" { + t.Error("Command didn't parse correctly") + } +} + +func TestCmdOnlyValidArgs(t *testing.T) { + result := noRRSetupTest("echo", "times", "one", "two", "five") + + if result.Error == nil { + t.Fatal("Expected an error") + } + + expectedError := `invalid argument "five"` + if !strings.Contains(result.Error.Error(), expectedError) { + t.Errorf("expected %v, got %v", expectedError, result.Error.Error()) + } +} + func TestFlagLong(t *testing.T) { noRRSetupTest("echo", "--intone=13", "something", "--", "here") @@ -672,9 +743,9 @@ func TestPersistentFlags(t *testing.T) { } // persistentFlag should act like normal flag on its own command - fullSetupTest("echo", "times", "-s", "again", "-c", "-p", "test", "here") + fullSetupTest("echo", "times", "-s", "again", "-c", "-p", "one", "two") - if strings.Join(tt, " ") != "test here" { + if strings.Join(tt, " ") != "one two" { t.Errorf("flags didn't leave proper args remaining. %s given", tt) } @@ -1095,7 +1166,7 @@ func TestGlobalNormFuncPropagation(t *testing.T) { rootCmd := initialize() rootCmd.SetGlobalNormalizationFunc(normFunc) - if reflect.ValueOf(normFunc) != reflect.ValueOf(rootCmd.GlobalNormalizationFunc()) { + if reflect.ValueOf(normFunc).Pointer() != reflect.ValueOf(rootCmd.GlobalNormalizationFunc()).Pointer() { t.Error("rootCmd seems to have a wrong normalization function") } diff --git a/command.go b/command.go index 0ae8be21..4f65d770 100644 --- a/command.go +++ b/command.go @@ -59,6 +59,8 @@ type Command struct { // but accepted if entered manually. ArgAliases []string + // Expected arguments + Args PositionalArgs // BashCompletionFunction is custom functions used by the bash autocompletion generator. BashCompletionFunction string @@ -513,33 +515,29 @@ func (c *Command) Find(args []string) (*Command, []string, error) { } commandFound, a := innerfind(c, args) - argsWOflags := stripFlags(a, commandFound) - - // no subcommand, always take args - if !commandFound.HasSubCommands() { - return commandFound, a, nil + if commandFound.Args == nil { + return commandFound, a, legacyArgs(commandFound, stripFlags(a, commandFound)) } - - // root command with subcommands, do subcommand checking - if commandFound == c && len(argsWOflags) > 0 { - suggestionsString := "" - if !c.DisableSuggestions { - if c.SuggestionsMinimumDistance <= 0 { - c.SuggestionsMinimumDistance = 2 - } - if suggestions := c.SuggestionsFor(argsWOflags[0]); len(suggestions) > 0 { - suggestionsString += "\n\nDid you mean this?\n" - for _, s := range suggestions { - suggestionsString += fmt.Sprintf("\t%v\n", s) - } - } - } - return commandFound, a, fmt.Errorf("unknown command %q for %q%s", argsWOflags[0], commandFound.CommandPath(), suggestionsString) - } - return commandFound, a, nil } +func (c *Command) findSuggestions(arg string) string { + if c.DisableSuggestions { + return "" + } + if c.SuggestionsMinimumDistance <= 0 { + c.SuggestionsMinimumDistance = 2 + } + suggestionsString := "" + if suggestions := c.SuggestionsFor(arg); len(suggestions) > 0 { + suggestionsString += "\n\nDid you mean this?\n" + for _, s := range suggestions { + suggestionsString += fmt.Sprintf("\t%v\n", s) + } + } + return suggestionsString +} + // SuggestionsFor provides suggestions for the typedName. func (c *Command) SuggestionsFor(typedName string) []string { suggestions := []string{} @@ -624,6 +622,10 @@ func (c *Command) execute(a []string) (err error) { argWoFlags = a } + if err := c.ValidateArgs(argWoFlags); err != nil { + return err + } + for p := c; p != nil; p = p.Parent() { if p.PersistentPreRunE != nil { if err := p.PersistentPreRunE(c, argWoFlags); err != nil { @@ -747,6 +749,13 @@ func (c *Command) ExecuteC() (cmd *Command, err error) { return cmd, err } +func (c *Command) ValidateArgs(args []string) error { + if c.Args == nil { + return nil + } + return c.Args(c, args) +} + // InitDefaultHelpFlag adds default help flag to c. // It is called automatically by executing the c or by calling help and usage. // If c already has help flag, it will do nothing. @@ -776,7 +785,7 @@ func (c *Command) InitDefaultHelpCmd() { Use: "help [command]", Short: "Help about any command", Long: `Help provides help for any command in the application. - Simply type ` + c.Name() + ` help [path to command] for full details.`, +Simply type ` + c.Name() + ` help [path to command] for full details.`, Run: func(c *Command, args []string) { cmd, _, e := c.Root().Find(args) diff --git a/command_win.go b/command_win.go index 4b0eaa1b..edec728e 100644 --- a/command_win.go +++ b/command_win.go @@ -11,14 +11,8 @@ import ( var preExecHookFn = preExecHook -// enables an information splash screen on Windows if the CLI is started from explorer.exe. -var MousetrapHelpText string = `This is a command line tool - -You need to open cmd.exe and run it from there. -` - func preExecHook(c *Command) { - if mousetrap.StartedByExplorer() { + if MousetrapHelpText != "" && mousetrap.StartedByExplorer() { c.Print(MousetrapHelpText) time.Sleep(5 * time.Second) os.Exit(1) diff --git a/zsh_completions.go b/zsh_completions.go new file mode 100644 index 00000000..b350aeec --- /dev/null +++ b/zsh_completions.go @@ -0,0 +1,114 @@ +package cobra + +import ( + "bytes" + "fmt" + "io" + "strings" +) + +// GenZshCompletion generates a zsh completion file and writes to the passed writer. +func (cmd *Command) GenZshCompletion(w io.Writer) error { + buf := new(bytes.Buffer) + + writeHeader(buf, cmd) + maxDepth := maxDepth(cmd) + writeLevelMapping(buf, maxDepth) + writeLevelCases(buf, maxDepth, cmd) + + _, err := buf.WriteTo(w) + return err +} + +func writeHeader(w io.Writer, cmd *Command) { + fmt.Fprintf(w, "#compdef %s\n\n", cmd.Name()) +} + +func maxDepth(c *Command) int { + if len(c.Commands()) == 0 { + return 0 + } + maxDepthSub := 0 + for _, s := range c.Commands() { + subDepth := maxDepth(s) + if subDepth > maxDepthSub { + maxDepthSub = subDepth + } + } + return 1 + maxDepthSub +} + +func writeLevelMapping(w io.Writer, numLevels int) { + fmt.Fprintln(w, `_arguments \`) + for i := 1; i <= numLevels; i++ { + fmt.Fprintf(w, ` '%d: :->level%d' \`, i, i) + fmt.Fprintln(w) + } + fmt.Fprintf(w, ` '%d: :%s'`, numLevels+1, "_files") + fmt.Fprintln(w) +} + +func writeLevelCases(w io.Writer, maxDepth int, root *Command) { + fmt.Fprintln(w, "case $state in") + defer fmt.Fprintln(w, "esac") + + for i := 1; i <= maxDepth; i++ { + fmt.Fprintf(w, " level%d)\n", i) + writeLevel(w, root, i) + fmt.Fprintln(w, " ;;") + } + fmt.Fprintln(w, " *)") + fmt.Fprintln(w, " _arguments '*: :_files'") + fmt.Fprintln(w, " ;;") +} + +func writeLevel(w io.Writer, root *Command, i int) { + fmt.Fprintf(w, " case $words[%d] in\n", i) + defer fmt.Fprintln(w, " esac") + + commands := filterByLevel(root, i) + byParent := groupByParent(commands) + + for p, c := range byParent { + names := names(c) + fmt.Fprintf(w, " %s)\n", p) + fmt.Fprintf(w, " _arguments '%d: :(%s)'\n", i, strings.Join(names, " ")) + fmt.Fprintln(w, " ;;") + } + fmt.Fprintln(w, " *)") + fmt.Fprintln(w, " _arguments '*: :_files'") + fmt.Fprintln(w, " ;;") + +} + +func filterByLevel(c *Command, l int) []*Command { + cs := make([]*Command, 0) + if l == 0 { + cs = append(cs, c) + return cs + } + for _, s := range c.Commands() { + cs = append(cs, filterByLevel(s, l-1)...) + } + return cs +} + +func groupByParent(commands []*Command) map[string][]*Command { + m := make(map[string][]*Command) + for _, c := range commands { + parent := c.Parent() + if parent == nil { + continue + } + m[parent.Name()] = append(m[parent.Name()], c) + } + return m +} + +func names(commands []*Command) []string { + ns := make([]string, len(commands)) + for i, c := range commands { + ns[i] = c.Name() + } + return ns +} diff --git a/zsh_completions_test.go b/zsh_completions_test.go new file mode 100644 index 00000000..08b85159 --- /dev/null +++ b/zsh_completions_test.go @@ -0,0 +1,88 @@ +package cobra + +import ( + "bytes" + "strings" + "testing" +) + +func TestZshCompletion(t *testing.T) { + tcs := []struct { + name string + root *Command + expectedExpressions []string + }{ + { + name: "trivial", + root: &Command{Use: "trivialapp"}, + expectedExpressions: []string{"#compdef trivial"}, + }, + { + name: "linear", + root: func() *Command { + r := &Command{Use: "linear"} + + sub1 := &Command{Use: "sub1"} + r.AddCommand(sub1) + + sub2 := &Command{Use: "sub2"} + sub1.AddCommand(sub2) + + sub3 := &Command{Use: "sub3"} + sub2.AddCommand(sub3) + return r + }(), + expectedExpressions: []string{"sub1", "sub2", "sub3"}, + }, + { + name: "flat", + root: func() *Command { + r := &Command{Use: "flat"} + r.AddCommand(&Command{Use: "c1"}) + r.AddCommand(&Command{Use: "c2"}) + return r + }(), + expectedExpressions: []string{"(c1 c2)"}, + }, + { + name: "tree", + root: func() *Command { + r := &Command{Use: "tree"} + + sub1 := &Command{Use: "sub1"} + r.AddCommand(sub1) + + sub11 := &Command{Use: "sub11"} + sub12 := &Command{Use: "sub12"} + + sub1.AddCommand(sub11) + sub1.AddCommand(sub12) + + sub2 := &Command{Use: "sub2"} + r.AddCommand(sub2) + + sub21 := &Command{Use: "sub21"} + sub22 := &Command{Use: "sub22"} + + sub2.AddCommand(sub21) + sub2.AddCommand(sub22) + + return r + }(), + expectedExpressions: []string{"(sub11 sub12)", "(sub21 sub22)"}, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + buf := new(bytes.Buffer) + tc.root.GenZshCompletion(buf) + completion := buf.String() + for _, expectedExpression := range tc.expectedExpressions { + if !strings.Contains(completion, expectedExpression) { + t.Errorf("expected completion to contain '%v' somewhere; got '%v'", expectedExpression, completion) + } + } + }) + } +}