diff --git a/command_test.go b/command_test.go index 16cc41b4..b2ec5dcd 100644 --- a/command_test.go +++ b/command_test.go @@ -148,6 +148,38 @@ func TestSubcommandExecuteC(t *testing.T) { } } +func TestSetContext(t *testing.T) { + ctx := context.TODO() + + aKey := "akey" + aVal := "aval" + + parentRun := func(cmd *Command, args []string) { + ctx := context.WithValue(cmd.Context(), aKey, aVal) + cmd.SetContext(ctx) + } + + childRun := func(cmd *Command, args []string) { + if val := cmd.Context().Value(aKey); val != aVal { + t.Errorf(`Context attribute not found in child command. Expected: "%+v". Have: "%+v"`, aVal, val) + } + } + + rootCmd := &Command{Use: "root", Run: parentRun, PersistentPreRun: parentRun} + childCmd := &Command{Use: "child", Run: childRun} + + rootCmd.AddCommand(childCmd) + + if _, err := executeCommandWithContext(ctx, rootCmd, ""); err != nil { + t.Errorf("Root command must not fail: %+v", err) + } + + if _, err := executeCommandWithContext(ctx, rootCmd, "child"); err != nil { + t.Errorf("Subcommand must not fail: %+v", err) + } + +} + func TestExecuteContext(t *testing.T) { ctx := context.TODO()