diff --git a/command_test.go b/command_test.go index 583cb023..2730f82e 100644 --- a/command_test.go +++ b/command_test.go @@ -2058,3 +2058,101 @@ func TestFParseErrWhitelistSiblingCommand(t *testing.T) { } checkStringContains(t, output, "unknown flag: --unknown") } + +func TestSetContext(t *testing.T) { + key, val := "foo", "bar" + root := &Command{ + Use: "root", + Run: func(cmd *Command, args []string) { + key := cmd.Context().Value(key) + got, ok := key.(string) + if !ok { + t.Error("key not found in context") + } + if got != val { + t.Errorf("Expected value: \n %v\nGot:\n %v\n", val, got) + } + }, + } + ctx := context.WithValue(context.Background(), key, val) + root.SetContext(ctx) + err := root.Execute() + if err != nil { + t.Error(err) + } +} + +func TestSetContextPreRun(t *testing.T) { + key, val := "foo", "bar" + root := &Command{ + Use: "root", + PreRun: func(cmd *Command, args []string) { + ctx := context.WithValue(cmd.Context(), key, val) + cmd.SetContext(ctx) + }, + Run: func(cmd *Command, args []string) { + key := cmd.Context().Value(key) + got, ok := key.(string) + if !ok { + t.Error("key not found in context") + } + if got != val { + t.Errorf("Expected value: \n %v\nGot:\n %v\n", val, got) + } + }, + } + err := root.Execute() + if err != nil { + t.Error(err) + } +} + +func TestSetContextPreRunOverwrite(t *testing.T) { + key, val := "foo", "bar" + root := &Command{ + Use: "root", + Run: func(cmd *Command, args []string) { + key := cmd.Context().Value(key) + _, ok := key.(string) + if ok { + t.Error("key found in context when not expected") + } + }, + } + ctx := context.WithValue(context.Background(), key, val) + root.SetContext(ctx) + err := root.ExecuteContext(context.Background()) + if err != nil { + t.Error(err) + } +} + +func TestSetContextPersistentPreRun(t *testing.T) { + key, val := "foo", "bar" + root := &Command{ + Use: "root", + PersistentPreRun: func(cmd *Command, args []string) { + ctx := context.WithValue(cmd.Context(), key, val) + cmd.SetContext(ctx) + }, + } + child := &Command{ + Use: "child", + Run: func(cmd *Command, args []string) { + key := cmd.Context().Value(key) + got, ok := key.(string) + if !ok { + t.Error("key not found in context") + } + if got != val { + t.Errorf("Expected value: \n %v\nGot:\n %v\n", val, got) + } + }, + } + root.AddCommand(child) + root.SetArgs([]string{"child"}) + err := root.Execute() + if err != nil { + t.Error(err) + } +}