feat: added custom error handler function

This commit is contained in:
alexraileanu 2022-10-04 11:58:41 +02:00
parent d4040ad8db
commit 5feb403e5e
2 changed files with 40 additions and 2 deletions

View file

@ -161,6 +161,9 @@ type Command struct {
// versionTemplate is the version template defined by user. // versionTemplate is the version template defined by user.
versionTemplate string versionTemplate string
// errorHandlerFunc allows setting a custom error handler by the user.
errorHandlerFunc func(error)
// inReader is a reader defined by the user that replaces stdin // inReader is a reader defined by the user that replaces stdin
inReader io.Reader inReader io.Reader
// outWriter is a writer defined by the user that replaces stdout // outWriter is a writer defined by the user that replaces stdout
@ -323,6 +326,12 @@ func (c *Command) SetGlobalNormalizationFunc(n func(f *flag.FlagSet, name string
} }
} }
// SetErrorHandlerFunc is the function that will be called, if set, when there is any kind of error in the
// execution of the command.
func (c *Command) SetErrorHandlerFunc(f func(error)) {
c.errorHandlerFunc = f
}
// OutOrStdout returns output to stdout. // OutOrStdout returns output to stdout.
func (c *Command) OutOrStdout() io.Writer { func (c *Command) OutOrStdout() io.Writer {
return c.getOut(os.Stdout) return c.getOut(os.Stdout)
@ -979,7 +988,11 @@ func (c *Command) ExecuteC() (cmd *Command, err error) {
c = cmd c = cmd
} }
if !c.SilenceErrors { if !c.SilenceErrors {
c.PrintErrln("Error:", err.Error()) if c.errorHandlerFunc != nil {
c.errorHandlerFunc(err)
} else {
c.PrintErrln("Error:", err.Error())
}
c.PrintErrf("Run '%v --help' for usage.\n", c.CommandPath()) c.PrintErrf("Run '%v --help' for usage.\n", c.CommandPath())
} }
return c, err return c, err
@ -1008,7 +1021,11 @@ func (c *Command) ExecuteC() (cmd *Command, err error) {
// If root command has SilenceErrors flagged, // If root command has SilenceErrors flagged,
// all subcommands should respect it // all subcommands should respect it
if !cmd.SilenceErrors && !c.SilenceErrors { if !cmd.SilenceErrors && !c.SilenceErrors {
c.PrintErrln("Error:", err.Error()) if c.errorHandlerFunc != nil {
c.errorHandlerFunc(err)
} else {
c.PrintErrln("Error:", err.Error())
}
} }
// If root command has SilenceUsage flagged, // If root command has SilenceUsage flagged,

View file

@ -17,6 +17,7 @@ package cobra
import ( import (
"bytes" "bytes"
"context" "context"
"errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os" "os"
@ -2430,3 +2431,23 @@ func TestHelpflagCommandExecutedWithoutVersionSet(t *testing.T) {
checkStringContains(t, output, HelpFlag) checkStringContains(t, output, HelpFlag)
checkStringOmits(t, output, VersionFlag) checkStringOmits(t, output, VersionFlag)
} }
func TestSetCustomErrorHandler(t *testing.T) {
var writer bytes.Buffer
handler := func(err error) {
writer.Write([]byte(err.Error()))
}
root := &Command{
Use: "root",
RunE: func(cmd *Command, args []string) error {
return errors.New("test error handler function")
},
SilenceUsage: true,
}
root.SetErrorHandlerFunc(handler)
_ = root.Execute()
if writer.String() != "test error handler function" {
t.Errorf("Expected error handler to contain [%s] instead it contains [%s]", "test error handler function", writer.String())
}
}