diff --git a/flag_groups.go b/flag_groups.go index dc784311..0e5e181c 100644 --- a/flag_groups.go +++ b/flag_groups.go @@ -24,6 +24,8 @@ import ( const ( requiredAsGroup = "cobra_annotation_required_if_others_set" mutuallyExclusive = "cobra_annotation_mutually_exclusive" + dependsOn = "cobra_annotation_depends_on" + dependsOnAny = "cobra_annotation_depends_on_any" ) // MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors @@ -58,6 +60,76 @@ func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) { } } +// MarkFlagsDependsOn marks the given flags with annotations so that Cobra errors +// if the command is invoked with 1 or more flags that are dependent on a specified +// other. +func (c *Command) MarkFlagsDependsOn(flagNames ...string) { + const format = "Failed to find flag %q and mark it as being part of depends on group" + c.markAnnotation(dependsOn, format, flagNames...) +} + +// MarkFlagDependsOnAny marks the given flags with annotations so that Cobra errors +// if the command is invoked with a flag that is dependent on any 1 of a group of others. +func (c *Command) MarkFlagDependsOnAny(flagNames ...string) { + const format = "Failed to find flag %q and mark it as being part of depends on any group" + c.markAnnotation(dependsOnAny, format, flagNames...) +} + +// markAnnotation currently only used by MarkFlagsDependsOn and MarkFlagDependsOnAny, +// but is generic enough and should be used by MarkFlagsRequiredTogether and +// MarkFlagsMutuallyExclusive. +// - format must contain a single place holder +func (c *Command) markAnnotation(annotation, format string, flagNames ...string) { + c.mergePersistentFlags() + for _, name := range flagNames { + c.setFlagAnnotation(name, annotation, + fmt.Sprintf(format, name), + flagNames..., + ) + } +} + +func (c *Command) setFlagAnnotation(flag string, annotation string, message string, flagNames ...string) { + f := c.Flags().Lookup(flag) + if f == nil { + panic(message) + } + ordered := strings.Join(flagNames, " ") + if err := c.Flags().SetAnnotation( + flag, annotation, + append(f.Annotations[annotation], ordered), + ); err != nil { + panic(err) + } +} + +// The 'special-ness' of a group means that the first member of the group carries +// special meaning. In contrast to the other group types, where all members are equal. +type specialStatusInfo struct { + isSet bool + isSpecial bool +} +type specialStatusInfoData map[string]*specialStatusInfo + +type specialGroupInfo struct { + special string + others []string + // maps the flag name to special status info + data specialStatusInfoData +} +type specialGroupInfoCollection map[string]*specialGroupInfo + +func newSpecialGroup(specialName string, others []string) *specialGroupInfo { + size := len(others) + 1 + result := specialGroupInfo{ + special: specialName, + others: others, + data: make(specialStatusInfoData, size), + } + + return &result +} + // validateFlagGroups validates the mutuallyExclusive/requiredAsGroup logic and returns the // first error encountered. func (c *Command) validateFlagGroups() error { @@ -71,9 +143,13 @@ func (c *Command) validateFlagGroups() error { // then a map of each flag name and whether it is set or not. groupStatus := map[string]map[string]bool{} mutuallyExclusiveGroupStatus := map[string]map[string]bool{} + dependsOnSpecialGroupStatus := specialGroupInfoCollection{} + dependsOnAnySpecialGroupStatus := specialGroupInfoCollection{} flags.VisitAll(func(pflag *flag.Flag) { processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus) processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus) + processFlagForSpecialGroupAnnotation(flags, pflag, dependsOn, dependsOnSpecialGroupStatus) + processFlagForSpecialGroupAnnotation(flags, pflag, dependsOnAny, dependsOnAnySpecialGroupStatus) }) if err := validateRequiredFlagGroups(groupStatus); err != nil { @@ -82,6 +158,12 @@ func (c *Command) validateFlagGroups() error { if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil { return err } + if err := validateDependsOnFlagGroups(dependsOnSpecialGroupStatus); err != nil { + return err + } + if err := validateDependsOnAnyFlagGroups(dependsOnAnySpecialGroupStatus); err != nil { + return err + } return nil } @@ -95,6 +177,16 @@ func hasAllFlags(fs *flag.FlagSet, flagnames ...string) bool { return true } +func hasAnyOfFlags(fs *flag.FlagSet, flagnames ...string) bool { + for _, fname := range flagnames { + f := fs.Lookup(fname) + if f != nil { + return true + } + } + return false +} + func processFlagForGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag, annotation string, groupStatus map[string]map[string]bool) { groupInfo, found := pflag.Annotations[annotation] if found { @@ -118,6 +210,52 @@ func processFlagForGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag, annota } } +func processFlagForSpecialGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag, + annotation string, groupStatus specialGroupInfoCollection) { + + if groupInfo, found := pflag.Annotations[annotation]; found { + for _, group := range groupInfo { + if groupStatus[group] == nil { + + flagnames := strings.Split(group, " ") + // it's important to know that the order of the flags is established + // in setFlagAnnotation, which makes the assumption of the first + // item being sepcial, being valid + special := flagnames[0] + others := flagnames[1:] + isFlagSpecial := pflag.Name == special + + // Only consider this flag group at all if the first flag (Special) + // is set and at least 1 of the others is + if isFlagSpecial && flags.Lookup(special) == nil { + continue + } + + if !isFlagSpecial && !hasAnyOfFlags(flags, others...) { + continue + } + + groupStatus[group] = newSpecialGroup(special, others) + for _, name := range flagnames { + groupStatus[group].data[name] = &specialStatusInfo{} + + if name == special { + groupStatus[group].data[special].isSpecial = true + break // short circuit after finding special + } + } + } + + // group exists, but we still need to check if the flag exists in the group, + // because the previous loop is short circuited as soon as we find the special. + if _, found := groupStatus[group].data[pflag.Name]; !found { + groupStatus[group].data[pflag.Name] = &specialStatusInfo{} + } + groupStatus[group].data[pflag.Name].isSet = pflag.Changed + } + } +} + func validateRequiredFlagGroups(data map[string]map[string]bool) error { keys := sortedKeys(data) for _, flagList := range keys { @@ -162,6 +300,66 @@ func validateExclusiveFlagGroups(data map[string]map[string]bool) error { return nil } +func validateDependsOnFlagGroups(data specialGroupInfoCollection) error { + keys := sortedKeysSpecial(data) + + for _, flagList := range keys { + flagnameAndStatus := data[flagList] + + if flagnameAndStatus.data[flagnameAndStatus.special].isSet { + // rule is satisfied, because the special flag is present, regardless of + // the presence of the other members in the group + return nil + } + + // we have a problem if at least one of present is set, because special + // is not set + present := []string{} + for _, o := range flagnameAndStatus.others { + if flagnameAndStatus.data[o].isSet { + present = append(present, o) + } + } + if len(present) == 0 { + continue + } + sort.Strings(present) + + return fmt.Errorf( + "if any flags in the group %v are set then [%v] must be present; only %v were set", + flagnameAndStatus.others, flagnameAndStatus.special, present, + ) + } + return nil +} + +func validateDependsOnAnyFlagGroups(data specialGroupInfoCollection) error { + keys := sortedKeysSpecial(data) + + for _, flagList := range keys { + flagnameAndStatus := data[flagList] + if !flagnameAndStatus.data[flagnameAndStatus.special].isSet { + return nil + } + + present := []string{} + for _, o := range flagnameAndStatus.others { + if flagnameAndStatus.data[o].isSet { + present = append(present, o) + } + } + if len(present) > 0 { + continue + } + + return fmt.Errorf( + "if [%v] is present, then at least one of the flags in %v must be; none were set", + flagnameAndStatus.special, flagnameAndStatus.others, + ) + } + return nil +} + func sortedKeys(m map[string]map[string]bool) []string { keys := make([]string, len(m)) i := 0 @@ -173,6 +371,18 @@ func sortedKeys(m map[string]map[string]bool) []string { return keys } +// implemented as a duplicate of sortedKeys as generics can't be used yet +func sortedKeysSpecial(m specialGroupInfoCollection) []string { + keys := make([]string, len(m)) + i := 0 + for k := range m { + keys[i] = k + i++ + } + sort.Strings(keys) + return keys +} + // enforceFlagGroupsForCompletion will do the following: // - when a flag in a group is present, other flags in the group will be marked required // - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden @@ -185,9 +395,11 @@ func (c *Command) enforceFlagGroupsForCompletion() { flags := c.Flags() groupStatus := map[string]map[string]bool{} mutuallyExclusiveGroupStatus := map[string]map[string]bool{} + dependsOnSpecialGroupStatus := specialGroupInfoCollection{} c.Flags().VisitAll(func(pflag *flag.Flag) { processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus) processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus) + processFlagForSpecialGroupAnnotation(flags, pflag, dependsOn, dependsOnSpecialGroupStatus) }) // If a flag that is part of a group is present, we make all the other flags @@ -220,4 +432,15 @@ func (c *Command) enforceFlagGroupsForCompletion() { } } } + + // if any of others is set, then mark special as required + for _, flagnameAndStatus := range dependsOnSpecialGroupStatus { + for _, o := range flagnameAndStatus.others { + if flagnameAndStatus.data[o].isSet { + c.MarkFlagRequired(flagnameAndStatus.special) + break + } + } + } + // we can't aid the completion process for dependsOnAny } diff --git a/flag_groups_test.go b/flag_groups_test.go index 404ede56..f225c3ce 100644 --- a/flag_groups_test.go +++ b/flag_groups_test.go @@ -42,13 +42,17 @@ func TestValidateFlagGroups(t *testing.T) { // Each test case uses a unique command from the function above. testcases := []struct { - desc string - flagGroupsRequired []string - flagGroupsExclusive []string - subCmdFlagGroupsRequired []string - subCmdFlagGroupsExclusive []string - args []string - expectErr string + desc string + flagGroupsRequired []string + flagGroupsExclusive []string + flagGroupsDependsOn []string + flagGroupsDependsOnAny []string + subCmdFlagGroupsRequired []string + subCmdFlagGroupsExclusive []string + subCmdFlagGroupsDependsOn []string + subCmdFlagGroupsDependsOnAny []string + args []string + expectErr string }{ { desc: "No flags no problem", @@ -121,6 +125,71 @@ func TestValidateFlagGroups(t *testing.T) { subCmdFlagGroupsRequired: []string{"e subonly"}, args: []string{"--e=foo"}, }, + // DependsOn + { + desc: "The dependee 'a' is set so Depends On group is satisfied", + flagGroupsDependsOn: []string{"a b c d"}, + args: []string{"--a=foo", "--b=foo"}, + }, { + desc: "Depends On flag group not satisfied, a is missing, required by b", + flagGroupsDependsOn: []string{"a b c d"}, + args: []string{"--b=foo"}, + expectErr: "if any flags in the group [b c d] are set then [a] must be present; only [b] were set", + }, { + desc: "Depends On flag group not satisfied, a is missing, required by b and c", + flagGroupsDependsOn: []string{"a b c d"}, + args: []string{"--b=foo", "--c=foo"}, + expectErr: "if any flags in the group [b c d] are set then [a] must be present; only [b c] were set", + }, { + desc: "The inherited dependee 'e' is set so Depends On group is satisfied", + subCmdFlagGroupsDependsOn: []string{"e subonly"}, + args: []string{"subcmd", "--e=foo", "--subonly=foo"}, + }, { + desc: "The inherited dependee 'e' is not set so Depends On group not is satisfied", + subCmdFlagGroupsDependsOn: []string{"e subonly"}, + args: []string{"subcmd", "--subonly=foo"}, + expectErr: "if any flags in the group [subonly] are set then [e] must be present; only [subonly] were set", + }, { + desc: "Depends On Multiple exclusive flag group not satisfied returns still returns error", + flagGroupsDependsOn: []string{"a b c d"}, + flagGroupsExclusive: []string{"a b"}, + args: []string{"--a=foo", "--b=foo"}, + expectErr: "if any flags in the group [a b] are set none of the others can be; [a b] were all set", + }, + // DependsOnAny + { + desc: "At least 1 of the dependees are present so Depends On Any is satisfied", + flagGroupsDependsOnAny: []string{"a b c d"}, + args: []string{"--a=foo", "--b=foo"}, + }, { + desc: "All of the dependees are present so Depends On Any is satisfied", + flagGroupsDependsOnAny: []string{"a b c d"}, + args: []string{"--a=foo", "--b=foo", "--c=foo", "--d=foo"}, + }, { + desc: "None of the dependees are present so Depends On Any is not satisfied", + flagGroupsDependsOnAny: []string{"a b c d"}, + args: []string{"--a=foo"}, + expectErr: "if [a] is present, then at least one of the flags in [b c d] must be; none were set", + }, { + desc: "At least 1 of the inherited dependees are present so Depends On Any is satisfied", + subCmdFlagGroupsDependsOnAny: []string{"subonly e f g"}, + args: []string{"subcmd", "--subonly=foo", "--e=foo"}, + }, { + desc: "All of the inherited dependees are present so Depends On Any is satisfied", + subCmdFlagGroupsDependsOnAny: []string{"subonly e f g"}, + args: []string{"subcmd", "--subonly=foo", "--e=foo", "--f=foo", "--g=foo"}, + }, { + desc: "None of the inherited dependees are present so Depends On Any is not satisfied", + subCmdFlagGroupsDependsOnAny: []string{"subonly e f g"}, + args: []string{"subcmd", "--subonly=foo"}, + expectErr: "if [subonly] is present, then at least one of the flags in [e f g] must be; none were set", + }, { + desc: "Depends On Any Multiple exclusive flag group not satisfied returns still returns error", + flagGroupsDependsOnAny: []string{"a b c d"}, + flagGroupsExclusive: []string{"a b"}, + args: []string{"--a=foo", "--b=foo"}, + expectErr: "if any flags in the group [a b] are set none of the others can be; [a b] were all set", + }, } for _, tc := range testcases { t.Run(tc.desc, func(t *testing.T) { @@ -132,12 +201,25 @@ func TestValidateFlagGroups(t *testing.T) { for _, flagGroup := range tc.flagGroupsExclusive { c.MarkFlagsMutuallyExclusive(strings.Split(flagGroup, " ")...) } + for _, flagGroup := range tc.flagGroupsDependsOn { + c.MarkFlagsDependsOn(strings.Split(flagGroup, " ")...) + } + for _, flagGroup := range tc.flagGroupsDependsOnAny { + c.MarkFlagDependsOnAny(strings.Split(flagGroup, " ")...) + } for _, flagGroup := range tc.subCmdFlagGroupsRequired { sub.MarkFlagsRequiredTogether(strings.Split(flagGroup, " ")...) } for _, flagGroup := range tc.subCmdFlagGroupsExclusive { sub.MarkFlagsMutuallyExclusive(strings.Split(flagGroup, " ")...) } + for _, flagGroup := range tc.subCmdFlagGroupsDependsOn { + sub.MarkFlagsDependsOn(strings.Split(flagGroup, " ")...) + } + for _, flagGroup := range tc.subCmdFlagGroupsDependsOnAny { + sub.MarkFlagDependsOnAny(strings.Split(flagGroup, " ")...) + } + c.SetArgs(tc.args) err := c.Execute() switch { diff --git a/user_guide.md b/user_guide.md index 5a7acf88..1f4585e4 100644 --- a/user_guide.md +++ b/user_guide.md @@ -318,7 +318,37 @@ rootCmd.Flags().BoolVar(&pw, "yaml", false, "Output in YAML") rootCmd.MarkFlagsMutuallyExclusive("json", "yaml") ``` -In both of these cases: +If you need 1 way dependency groups, as opposed to all flags in a group being required together (like `MarkFlagsRequiredTogether`), then you have a further 2 options. + +You can specify that 1 or more flags be dependent upon another using `MarkFlagsDependsOn` eg, let's say you have an app that performs filtering and you want to support regex and glob filter types, but by default the _filter_ is regex. The user should be able to define the _filter_ and omit the boolean _glob_ flag. But _glob_ can't be specified without the _filter_ also being present. So _glob_ is ___dependent___ on _filter_ (the ___dependee___), but not vice-versa. In general terms this equates to + +> cmd.MarkFlagsDependsOn(dependee, dependent-1, dependent-2 ...) + +eg: + +```go +rootCmd.Flags().StringVarP(&f, "filter", "f", "", "Filter") +rootCmd.Flags().BoolVarP(&t, "glob", "t", false, "Glob") +rootCmd.MarkFlagsDependsOn("filter", "glob") +``` + +A variation on this theme would be to specify that a particular flag is dependent upon on any 1 of another set of flags using `MarkFlagDependsOnAny`. Continuing our filtering theme; if we have a _filter_ flag which may be applied to a set of other entities, let's say _genres_, _albums_ and _artists_, you can do this by marking _filter_ to be dependent on any of these. At least one of the these dependees must be present. The general form would be + +> cmd.MarkFlagDependsOnAny(dependent, dependee-1, dependee-2 ...) + +eg: + +```go +rootCmd.Flags().StringVarP(&filter, "filter", "f", "", "Filter") +rootCmd.Flags().StringVarP(&genre, "genre", "g", "", "Genre") +rootCmd.Flags().StringVarP(&album, "album", "a", "", "Album") +rootCmd.Flags().StringVarP(&artist, "artist", "r", "", "Artist") +rootCmd.MarkFlagDependsOnAny("filter", "genre", "album", "artist") +``` + +So for `MarkFlagsDependsOn` and `MarkFlagDependsOnAny` the order of flag specification is significant. The first flag is denoted as being special as it carries different semantics than the remaining flags as opposed to `MarkFlagsRequiredTogether` and `MarkFlagsMutuallyExclusive` where order is insignificent as all flags in those groups are equal. + +In all of these cases: - both local and persistent flags can be used - **NOTE:** the group is only enforced on commands where every flag is defined - a flag may appear in multiple groups