diff --git a/command.go b/command.go index ab3cf69a..7e02b2dd 100644 --- a/command.go +++ b/command.go @@ -16,6 +16,7 @@ package cobra import ( + "bufio" "bytes" "errors" "fmt" @@ -41,6 +42,9 @@ type Command struct { // Use is the one-line usage message. Use string + // Toggle ask prompt. + Ask bool + // Aliases is an array of aliases that can be used instead of the first word in Use. Aliases []string @@ -54,6 +58,9 @@ type Command struct { // Long is the long message shown in the 'help ' output. Long string + // Custom prompt. + Question string + // Example is examples of how to use the command. Example string @@ -794,6 +801,13 @@ func (c *Command) execute(a []string) (err error) { return ErrSubCommandRequired } + if c.Ask { + if !c.ask() { + c.Println("canceled...") + return nil + } + } + c.preRun() argWoFlags := c.Flags().Args() @@ -1605,3 +1619,29 @@ func (c *Command) updateParentsPflags() { c.parentsPflags.AddFlagSet(parent.PersistentFlags()) }) } + +func (c *Command) ask() bool { + if c.Question == "" { + c.Question = "continue? [Y/N]" + } else { + c.Question = strings.TrimRight(c.Question, "[Y/N]") + "[Y/N]" + } + c.Println(c.Question) + reader := bufio.NewReader(c.getIn(os.Stdin)) + for { + answer, err := reader.ReadString('\n') + if err != nil { + c.PrintErrf("Read string failed, err: %v\n", err) + break + } + answer = strings.ToLower(strings.TrimSpace(answer)) + if answer == "y" || answer == "yes" { + return true + } else if answer == "n" || answer == "no" { + break + } else { + continue + } + } + return false +} \ No newline at end of file diff --git a/command_test.go b/command_test.go index b26bd4ab..3a15f54c 100644 --- a/command_test.go +++ b/command_test.go @@ -3,12 +3,11 @@ package cobra import ( "bytes" "fmt" + "github.com/spf13/pflag" "os" "reflect" "strings" "testing" - - "github.com/spf13/pflag" ) func emptyRun(*Command, []string) {} @@ -20,7 +19,7 @@ func executeCommand(root *Command, args ...string) (output string, err error) { func executeCommandC(root *Command, args ...string) (c *Command, output string, err error) { buf := new(bytes.Buffer) - root.SetOutput(buf) + root.SetOut(buf) root.SetArgs(args) c, err = root.ExecuteC() @@ -96,6 +95,63 @@ func TestChildCommand(t *testing.T) { } } + +func TestCommandWithAskQuestionAndAnswerNo(t *testing.T) { + msg := "hello, world" + rootCmd := &Command{ + Use: "root", + Ask: true, + Run: func(c *Command, args []string) { + c.Print(msg) + }, + } + + outBuf,inBuf := new(bytes.Buffer),new(bytes.Buffer) + rootCmd.SetOut(outBuf) + rootCmd.SetIn(inBuf) + + go func() { + inBuf.WriteString("n\n") + }() + _, err := rootCmd.ExecuteC() + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + got := outBuf.String() + expected := "continue? [Y/N]\ncanceled...\n" + if got != expected { + t.Errorf("expected: %q, got: %q", expected, got) + } +} + +func TestCommandWithAskQuestionAndAnswerYes(t *testing.T) { + msg := "hello,world" + rootCmd := &Command{ + Use: "root", + Ask: true, + Run: func(c *Command, args []string) { + c.Print(msg) + }, + } + + outBuf,inBuf := new(bytes.Buffer),new(bytes.Buffer) + rootCmd.SetOut(outBuf) + rootCmd.SetIn(inBuf) + + go func() { + inBuf.WriteString("y\n") + }() + _, err := rootCmd.ExecuteC() + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + got := outBuf.String() + expected := "continue? [Y/N]\nhello,world" + if got != expected { + t.Errorf("expected: %q, got: %q", expected, got) + } +} + func TestCallCommandWithoutSubcommands(t *testing.T) { rootCmd := &Command{Use: "root", Args: NoArgs, Run: emptyRun} _, err := executeCommand(rootCmd)