diff --git a/bash_completions.go b/bash_completions.go index 488c14c2..5ca20155 100644 --- a/bash_completions.go +++ b/bash_completions.go @@ -19,9 +19,9 @@ const ( BashCompSubdirsInDir = "cobra_annotation_bash_completion_subdirs_in_dir" ) -func writePreamble(buf *bytes.Buffer, name string) { - buf.WriteString(fmt.Sprintf("# bash completion for %-36s -*- shell-script -*-\n", name)) - buf.WriteString(fmt.Sprintf(` +func writePreamble(buf io.StringWriter, name string) { + WrStringAndCheck(buf, fmt.Sprintf("# bash completion for %-36s -*- shell-script -*-\n", name)) + WrStringAndCheck(buf, fmt.Sprintf(` __%[1]s_debug() { if [[ -n ${BASH_COMP_DEBUG_FILE} ]]; then @@ -282,10 +282,10 @@ __%[1]s_handle_word() `, name)) } -func writePostscript(buf *bytes.Buffer, name string) { +func writePostscript(buf io.StringWriter, name string) { name = strings.Replace(name, ":", "__", -1) - buf.WriteString(fmt.Sprintf("__start_%s()\n", name)) - buf.WriteString(fmt.Sprintf(`{ + WrStringAndCheck(buf, fmt.Sprintf("__start_%s()\n", name)) + WrStringAndCheck(buf, fmt.Sprintf(`{ local cur prev words cword declare -A flaghash 2>/dev/null || : declare -A aliashash 2>/dev/null || : @@ -311,33 +311,33 @@ func writePostscript(buf *bytes.Buffer, name string) { } `, name)) - buf.WriteString(fmt.Sprintf(`if [[ $(type -t compopt) = "builtin" ]]; then + WrStringAndCheck(buf, fmt.Sprintf(`if [[ $(type -t compopt) = "builtin" ]]; then complete -o default -F __start_%s %s else complete -o default -o nospace -F __start_%s %s fi `, name, name, name, name)) - buf.WriteString("# ex: ts=4 sw=4 et filetype=sh\n") + WrStringAndCheck(buf, "# ex: ts=4 sw=4 et filetype=sh\n") } -func writeCommands(buf *bytes.Buffer, cmd *Command) { - buf.WriteString(" commands=()\n") +func writeCommands(buf io.StringWriter, cmd *Command) { + WrStringAndCheck(buf, " commands=()\n") for _, c := range cmd.Commands() { if !c.IsAvailableCommand() || c == cmd.helpCommand { continue } - buf.WriteString(fmt.Sprintf(" commands+=(%q)\n", c.Name())) + WrStringAndCheck(buf, fmt.Sprintf(" commands+=(%q)\n", c.Name())) writeCmdAliases(buf, c) } - buf.WriteString("\n") + WrStringAndCheck(buf, "\n") } -func writeFlagHandler(buf *bytes.Buffer, name string, annotations map[string][]string, cmd *Command) { +func writeFlagHandler(buf io.StringWriter, name string, annotations map[string][]string, cmd *Command) { for key, value := range annotations { switch key { case BashCompFilenameExt: - buf.WriteString(fmt.Sprintf(" flags_with_completion+=(%q)\n", name)) + WrStringAndCheck(buf, fmt.Sprintf(" flags_with_completion+=(%q)\n", name)) var ext string if len(value) > 0 { @@ -345,17 +345,18 @@ func writeFlagHandler(buf *bytes.Buffer, name string, annotations map[string][]s } else { ext = "_filedir" } - buf.WriteString(fmt.Sprintf(" flags_completion+=(%q)\n", ext)) + WrStringAndCheck(buf, fmt.Sprintf(" flags_completion+=(%q)\n", ext)) case BashCompCustom: - buf.WriteString(fmt.Sprintf(" flags_with_completion+=(%q)\n", name)) + WrStringAndCheck(buf, fmt.Sprintf(" flags_with_completion+=(%q)\n", name)) + if len(value) > 0 { handlers := strings.Join(value, "; ") - buf.WriteString(fmt.Sprintf(" flags_completion+=(%q)\n", handlers)) + WrStringAndCheck(buf, fmt.Sprintf(" flags_completion+=(%q)\n", handlers)) } else { - buf.WriteString(" flags_completion+=(:)\n") + WrStringAndCheck(buf, " flags_completion+=(:)\n") } case BashCompSubdirsInDir: - buf.WriteString(fmt.Sprintf(" flags_with_completion+=(%q)\n", name)) + WrStringAndCheck(buf, fmt.Sprintf(" flags_with_completion+=(%q)\n", name)) var ext string if len(value) == 1 { @@ -363,49 +364,51 @@ func writeFlagHandler(buf *bytes.Buffer, name string, annotations map[string][]s } else { ext = "_filedir -d" } - buf.WriteString(fmt.Sprintf(" flags_completion+=(%q)\n", ext)) + WrStringAndCheck(buf, fmt.Sprintf(" flags_completion+=(%q)\n", ext)) } } } -func writeShortFlag(buf *bytes.Buffer, flag *pflag.Flag, cmd *Command) { +const cbn = "\")\n" + +func writeShortFlag(buf io.StringWriter, flag *pflag.Flag, cmd *Command) { name := flag.Shorthand format := " " if len(flag.NoOptDefVal) == 0 { format += "two_word_" } - format += "flags+=(\"-%s\")\n" - buf.WriteString(fmt.Sprintf(format, name)) + format += "flags+=(\"-%s" + cbn + WrStringAndCheck(buf, fmt.Sprintf(format, name)) writeFlagHandler(buf, "-"+name, flag.Annotations, cmd) } -func writeFlag(buf *bytes.Buffer, flag *pflag.Flag, cmd *Command) { +func writeFlag(buf io.StringWriter, flag *pflag.Flag, cmd *Command) { name := flag.Name format := " flags+=(\"--%s" if len(flag.NoOptDefVal) == 0 { format += "=" } - format += "\")\n" - buf.WriteString(fmt.Sprintf(format, name)) + format += cbn + WrStringAndCheck(buf, fmt.Sprintf(format, name)) if len(flag.NoOptDefVal) == 0 { - format = " two_word_flags+=(\"--%s\")\n" - buf.WriteString(fmt.Sprintf(format, name)) + format = " two_word_flags+=(\"--%s" + cbn + WrStringAndCheck(buf, fmt.Sprintf(format, name)) } writeFlagHandler(buf, "--"+name, flag.Annotations, cmd) } -func writeLocalNonPersistentFlag(buf *bytes.Buffer, flag *pflag.Flag) { +func writeLocalNonPersistentFlag(buf io.StringWriter, flag *pflag.Flag) { name := flag.Name format := " local_nonpersistent_flags+=(\"--%s" if len(flag.NoOptDefVal) == 0 { format += "=" } - format += "\")\n" - buf.WriteString(fmt.Sprintf(format, name)) + format += cbn + WrStringAndCheck(buf, fmt.Sprintf(format, name)) } -func writeFlags(buf *bytes.Buffer, cmd *Command) { - buf.WriteString(` flags=() +func writeFlags(buf io.StringWriter, cmd *Command) { + WrStringAndCheck(buf, ` flags=() two_word_flags=() local_nonpersistent_flags=() flags_with_completion=() @@ -435,11 +438,11 @@ func writeFlags(buf *bytes.Buffer, cmd *Command) { } }) - buf.WriteString("\n") + WrStringAndCheck(buf, "\n") } -func writeRequiredFlag(buf *bytes.Buffer, cmd *Command) { - buf.WriteString(" must_have_one_flag=()\n") +func writeRequiredFlag(buf io.StringWriter, cmd *Command) { + WrStringAndCheck(buf, " must_have_one_flag=()\n") flags := cmd.NonInheritedFlags() flags.VisitAll(func(flag *pflag.Flag) { if nonCompletableFlag(flag) { @@ -452,49 +455,49 @@ func writeRequiredFlag(buf *bytes.Buffer, cmd *Command) { if flag.Value.Type() != "bool" { format += "=" } - format += "\")\n" - buf.WriteString(fmt.Sprintf(format, flag.Name)) + format += cbn + WrStringAndCheck(buf, fmt.Sprintf(format, flag.Name)) if len(flag.Shorthand) > 0 { - buf.WriteString(fmt.Sprintf(" must_have_one_flag+=(\"-%s\")\n", flag.Shorthand)) + WrStringAndCheck(buf, fmt.Sprintf(" must_have_one_flag+=(\"-%s"+cbn, flag.Shorthand)) } } } }) } -func writeRequiredNouns(buf *bytes.Buffer, cmd *Command) { - buf.WriteString(" must_have_one_noun=()\n") +func writeRequiredNouns(buf io.StringWriter, cmd *Command) { + WrStringAndCheck(buf, " must_have_one_noun=()\n") sort.Strings(cmd.ValidArgs) for _, value := range cmd.ValidArgs { - buf.WriteString(fmt.Sprintf(" must_have_one_noun+=(%q)\n", value)) + WrStringAndCheck(buf, fmt.Sprintf(" must_have_one_noun+=(%q)\n", value)) } } -func writeCmdAliases(buf *bytes.Buffer, cmd *Command) { +func writeCmdAliases(buf io.StringWriter, cmd *Command) { if len(cmd.Aliases) == 0 { return } sort.Strings(cmd.Aliases) - buf.WriteString(fmt.Sprint(` if [[ -z "${BASH_VERSION}" || "${BASH_VERSINFO[0]}" -gt 3 ]]; then`, "\n")) + WrStringAndCheck(buf, fmt.Sprint(` if [[ -z "${BASH_VERSION}" || "${BASH_VERSINFO[0]}" -gt 3 ]]; then`, "\n")) for _, value := range cmd.Aliases { - buf.WriteString(fmt.Sprintf(" command_aliases+=(%q)\n", value)) - buf.WriteString(fmt.Sprintf(" aliashash[%q]=%q\n", value, cmd.Name())) + WrStringAndCheck(buf, fmt.Sprintf(" command_aliases+=(%q)\n", value)) + WrStringAndCheck(buf, fmt.Sprintf(" aliashash[%q]=%q\n", value, cmd.Name())) } - buf.WriteString(` fi`) - buf.WriteString("\n") + WrStringAndCheck(buf, ` fi`) + WrStringAndCheck(buf, "\n") } -func writeArgAliases(buf *bytes.Buffer, cmd *Command) { - buf.WriteString(" noun_aliases=()\n") +func writeArgAliases(buf io.StringWriter, cmd *Command) { + WrStringAndCheck(buf, " noun_aliases=()\n") sort.Strings(cmd.ArgAliases) for _, value := range cmd.ArgAliases { - buf.WriteString(fmt.Sprintf(" noun_aliases+=(%q)\n", value)) + WrStringAndCheck(buf, fmt.Sprintf(" noun_aliases+=(%q)\n", value)) } } -func gen(buf *bytes.Buffer, cmd *Command) { +func gen(buf io.StringWriter, cmd *Command) { for _, c := range cmd.Commands() { if !c.IsAvailableCommand() || c == cmd.helpCommand { continue @@ -506,22 +509,22 @@ func gen(buf *bytes.Buffer, cmd *Command) { commandName = strings.Replace(commandName, ":", "__", -1) if cmd.Root() == cmd { - buf.WriteString(fmt.Sprintf("_%s_root_command()\n{\n", commandName)) + WrStringAndCheck(buf, fmt.Sprintf("_%s_root_command()\n{\n", commandName)) } else { - buf.WriteString(fmt.Sprintf("_%s()\n{\n", commandName)) + WrStringAndCheck(buf, fmt.Sprintf("_%s()\n{\n", commandName)) } - buf.WriteString(fmt.Sprintf(" last_command=%q\n", commandName)) - buf.WriteString("\n") - buf.WriteString(" command_aliases=()\n") - buf.WriteString("\n") + WrStringAndCheck(buf, fmt.Sprintf(" last_command=%q\n", commandName)) + WrStringAndCheck(buf, "\n") + WrStringAndCheck(buf, " command_aliases=()\n") + WrStringAndCheck(buf, "\n") writeCommands(buf, cmd) writeFlags(buf, cmd) writeRequiredFlag(buf, cmd) writeRequiredNouns(buf, cmd) writeArgAliases(buf, cmd) - buf.WriteString("}\n\n") + WrStringAndCheck(buf, "}\n\n") } // GenBashCompletion generates bash completion file and writes to the passed writer. diff --git a/command_test.go b/command_test.go index 1f8c3c1c..378034b2 100644 --- a/command_test.go +++ b/command_test.go @@ -676,9 +676,9 @@ func TestPersistentFlagsOnChild(t *testing.T) { func TestRequiredFlags(t *testing.T) { c := &Command{Use: "c", Run: emptyRun} c.Flags().String("foo1", "", "") - c.MarkFlagRequired("foo1") + er(c.MarkFlagRequired("foo1")) c.Flags().String("foo2", "", "") - c.MarkFlagRequired("foo2") + er(c.MarkFlagRequired("foo2")) c.Flags().String("bar", "", "") expected := fmt.Sprintf("required flag(s) %q, %q not set", "foo1", "foo2") @@ -694,16 +694,16 @@ func TestRequiredFlags(t *testing.T) { func TestPersistentRequiredFlags(t *testing.T) { parent := &Command{Use: "parent", Run: emptyRun} parent.PersistentFlags().String("foo1", "", "") - parent.MarkPersistentFlagRequired("foo1") + er(parent.MarkPersistentFlagRequired("foo1")) parent.PersistentFlags().String("foo2", "", "") - parent.MarkPersistentFlagRequired("foo2") + er(parent.MarkPersistentFlagRequired("foo2")) parent.Flags().String("foo3", "", "") child := &Command{Use: "child", Run: emptyRun} child.Flags().String("bar1", "", "") - child.MarkFlagRequired("bar1") + er(child.MarkFlagRequired("bar1")) child.Flags().String("bar2", "", "") - child.MarkFlagRequired("bar2") + er(child.MarkFlagRequired("bar2")) child.Flags().String("bar3", "", "") parent.AddCommand(child) @@ -1484,7 +1484,7 @@ func TestMergeCommandLineToFlags(t *testing.T) { func TestUseDeprecatedFlags(t *testing.T) { c := &Command{Use: "c", Run: emptyRun} c.Flags().BoolP("deprecated", "d", false, "deprecated flag") - c.Flags().MarkDeprecated("deprecated", "This flag is deprecated") + er(c.Flags().MarkDeprecated("deprecated", "This flag is deprecated")) output, err := executeCommand(c, "c", "-d") if err != nil {