mirror of
https://github.com/spf13/cobra
synced 2025-05-03 11:57:21 +00:00
336 lines
13 KiB
Go
336 lines
13 KiB
Go
// Copyright 2013-2023 The Cobra Authors
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package cobra
|
|
|
|
import (
|
|
"fmt"
|
|
"sort"
|
|
"strings"
|
|
|
|
flag "github.com/spf13/pflag"
|
|
)
|
|
|
|
// Annotation keys stored on pflag flags to record flag-group membership.
// The value stored under each key is a list of groups, where each group is the
// space-separated list of flag names passed to one Mark* call.
const (
	// requiredAsGroupAnnotation tags flags that must be set together or not at all.
	requiredAsGroupAnnotation = "cobra_annotation_required_if_others_set"

	// oneRequiredAnnotation tags flags of which at least one must be set.
	oneRequiredAnnotation = "cobra_annotation_one_required"

	// mutuallyExclusiveAnnotation tags flags of which at most one may be set.
	mutuallyExclusiveAnnotation = "cobra_annotation_mutually_exclusive"
)
|
|
|
|
// MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors
|
|
// if the command is invoked with a subset (but not all) of the given flags. It ensures
|
|
// that all flags provided in flagNames must be used together or none at all.
|
|
// Parameters:
|
|
// - flagNames: A slice of strings representing the names of the flags to mark.
|
|
func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) {
|
|
c.mergePersistentFlags()
|
|
for _, v := range flagNames {
|
|
f := c.Flags().Lookup(v)
|
|
if f == nil {
|
|
panic(fmt.Sprintf("Failed to find flag %q and mark it as being required in a flag group", v))
|
|
}
|
|
if err := c.Flags().SetAnnotation(v, requiredAsGroupAnnotation, append(f.Annotations[requiredAsGroupAnnotation], strings.Join(flagNames, " "))); err != nil {
|
|
// Only errs if the flag isn't found.
|
|
panic(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// MarkFlagsOneRequired marks the given flags with annotations so that Cobra errors
|
|
// if the command is invoked without at least one flag from the given set of flags. The
|
|
// `flagNames` parameter is a slice of strings containing the names of the flags to be marked.
|
|
func (c *Command) MarkFlagsOneRequired(flagNames ...string) {
|
|
c.mergePersistentFlags()
|
|
for _, v := range flagNames {
|
|
f := c.Flags().Lookup(v)
|
|
if f == nil {
|
|
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a one-required flag group", v))
|
|
}
|
|
if err := c.Flags().SetAnnotation(v, oneRequiredAnnotation, append(f.Annotations[oneRequiredAnnotation], strings.Join(flagNames, " "))); err != nil {
|
|
// Only errs if the flag isn't found.
|
|
panic(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// MarkFlagsMutuallyExclusive marks the given flags with annotations so that Cobra errors if the command is invoked with more than one flag from the given set of flags.
|
|
// It takes a variable number of strings representing the names of the flags to be marked as mutually exclusive. Each time this method is called, it adds a new entry to the annotation.
|
|
// If any of the specified flags are not found, it panics with an error message.
|
|
// The flagNames parameter contains one or more string values that represent the names of the flags to be marked as mutually exclusive.
|
|
// There is no return value.
|
|
func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
|
|
c.mergePersistentFlags()
|
|
for _, v := range flagNames {
|
|
f := c.Flags().Lookup(v)
|
|
if f == nil {
|
|
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v))
|
|
}
|
|
// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.
|
|
if err := c.Flags().SetAnnotation(v, mutuallyExclusiveAnnotation, append(f.Annotations[mutuallyExclusiveAnnotation], strings.Join(flagNames, " "))); err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// ValidateFlagGroups validates the mutuallyExclusive/oneRequired/requiredAsGroup logic and returns the first error encountered.
|
|
func (c *Command) ValidateFlagGroups() error {
|
|
if c.DisableFlagParsing {
|
|
return nil
|
|
}
|
|
|
|
flags := c.Flags()
|
|
|
|
// groupStatus format is the list of flags as a unique ID,
|
|
// then a map of each flag name and whether it is set or not.
|
|
groupStatus := map[string]map[string]bool{}
|
|
oneRequiredGroupStatus := map[string]map[string]bool{}
|
|
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
|
|
flags.VisitAll(func(pflag *flag.Flag) {
|
|
processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
|
|
processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
|
|
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
|
|
})
|
|
|
|
if err := validateRequiredFlagGroups(groupStatus); err != nil {
|
|
return err
|
|
}
|
|
if err := validateOneRequiredFlagGroups(oneRequiredGroupStatus); err != nil {
|
|
return err
|
|
}
|
|
if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// hasAllFlags checks if all flags in the provided flag names exist within the given FlagSet.
|
|
// It returns true if all flags are found, otherwise it returns false.
|
|
func hasAllFlags(fs *flag.FlagSet, flagnames ...string) bool {
|
|
for _, fname := range flagnames {
|
|
f := fs.Lookup(fname)
|
|
if f == nil {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
// processFlagForGroupAnnotation processes flags for group annotations, updating the group status based on the provided flag set and flag.
|
|
//
|
|
// Parameters:
|
|
// - flags: A pointer to the flag.FlagSet containing all flags.
|
|
// - pflag: A pointer to the flag.Flag being processed.
|
|
// - annotation: The name of the annotation to process.
|
|
// - groupStatus: A map tracking the status of groups, where the outer key is the group name and the inner map tracks individual flags within that group.
|
|
//
|
|
// This function checks if the provided flag has an annotation matching the specified annotation. If it does, it processes each group associated with the annotation. For each group, it ensures all flags in the group are defined in the provided flag set. If they are, it initializes the group status for any new groups and updates the status of the current flag within its group.
|
|
//
|
|
func processFlagForGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag, annotation string, groupStatus map[string]map[string]bool) {
|
|
groupInfo, found := pflag.Annotations[annotation]
|
|
if found {
|
|
for _, group := range groupInfo {
|
|
if groupStatus[group] == nil {
|
|
flagnames := strings.Split(group, " ")
|
|
|
|
// Only consider this flag group at all if all the flags are defined.
|
|
if !hasAllFlags(flags, flagnames...) {
|
|
continue
|
|
}
|
|
|
|
groupStatus[group] = make(map[string]bool, len(flagnames))
|
|
for _, name := range flagnames {
|
|
groupStatus[group][name] = false
|
|
}
|
|
}
|
|
|
|
groupStatus[group][pflag.Name] = pflag.Changed
|
|
}
|
|
}
|
|
}
|
|
|
|
// validateRequiredFlagGroups checks if any flag groups have required flags that are not all set.
|
|
// It takes a map where keys represent flag groups and values are maps of flag names to their set status.
|
|
// Returns an error if any group contains flags that are either all unset or none unset, which is invalid.
|
|
func validateRequiredFlagGroups(data map[string]map[string]bool) error {
|
|
keys := sortedKeys(data)
|
|
for _, flagList := range keys {
|
|
flagnameAndStatus := data[flagList]
|
|
|
|
unset := []string{}
|
|
for flagname, isSet := range flagnameAndStatus {
|
|
if !isSet {
|
|
unset = append(unset, flagname)
|
|
}
|
|
}
|
|
if len(unset) == len(flagnameAndStatus) || len(unset) == 0 {
|
|
continue
|
|
}
|
|
|
|
// Sort values, so they can be tested/scripted against consistently.
|
|
sort.Strings(unset)
|
|
return fmt.Errorf("if any flags in the group [%v] are set they must all be set; missing %v", flagList, unset)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// validateOneRequiredFlagGroups checks that at least one flag from each group is set. It takes a map where keys are groups and values are maps indicating whether each flag within a group is set.
|
|
//
|
|
// Parameters:
|
|
// - data: A map of string to map of string to bool, representing the flags and their statuses in different groups.
|
|
//
|
|
// Returns:
|
|
// - error: An error if any group does not have at least one required flag set. If all required flags are set, it returns nil.
|
|
func validateOneRequiredFlagGroups(data map[string]map[string]bool) error {
|
|
keys := sortedKeys(data)
|
|
for _, flagList := range keys {
|
|
flagnameAndStatus := data[flagList]
|
|
var set []string
|
|
for flagname, isSet := range flagnameAndStatus {
|
|
if isSet {
|
|
set = append(set, flagname)
|
|
}
|
|
}
|
|
if len(set) >= 1 {
|
|
continue
|
|
}
|
|
|
|
// Sort values, so they can be tested/scripted against consistently.
|
|
sort.Strings(set)
|
|
return fmt.Errorf("at least one of the flags in the group [%v] is required", flagList)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// validateExclusiveFlagGroups checks that within a map of flag groups, no more than one flag from each group is set.
|
|
//
|
|
// Parameters:
|
|
// - data: A map where keys are group names and values are maps of flag names to their status (true if set).
|
|
//
|
|
// Returns:
|
|
// - error: If any group contains more than one set flag, an error is returned with details. Otherwise, returns nil.
|
|
//
|
|
// Example:
|
|
// err := validateExclusiveFlagGroups(map[string]map[string]bool{
|
|
// "group1": {"flagA": true, "flagB": false},
|
|
// "group2": {"flagC": false, "flagD": true},
|
|
// })
|
|
// if err != nil {
|
|
// fmt.Println(err)
|
|
// }
|
|
func validateExclusiveFlagGroups(data map[string]map[string]bool) error {
|
|
keys := sortedKeys(data)
|
|
for _, flagList := range keys {
|
|
flagnameAndStatus := data[flagList]
|
|
var set []string
|
|
for flagname, isSet := range flagnameAndStatus {
|
|
if isSet {
|
|
set = append(set, flagname)
|
|
}
|
|
}
|
|
if len(set) == 0 || len(set) == 1 {
|
|
continue
|
|
}
|
|
|
|
// Sort values, so they can be tested/scripted against consistently.
|
|
sort.Strings(set)
|
|
return fmt.Errorf("if any flags in the group [%v] are set none of the others can be; %v were all set", flagList, set)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// sortedKeys returns the keys of m in ascending lexicographic order.
// m is not modified. The result is always a freshly allocated slice, which is empty
// but non-nil when m is empty.
func sortedKeys(m map[string]map[string]bool) []string {
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	return keys
}
|
|
|
|
// enforceFlagGroupsForCompletion enforces the completion logic for flag groups in a command.
|
|
// It ensures that when a flag in a group is present, other flags in the group are marked required,
|
|
// and when none of the flags in a one-required group are present, all flags in the group are marked required.
|
|
// Additionally, it hides flags that are mutually exclusive to others. This allows the standard completion logic
|
|
// to behave appropriately for flag groups.
|
|
func (c *Command) enforceFlagGroupsForCompletion() {
|
|
if c.DisableFlagParsing {
|
|
return
|
|
}
|
|
|
|
flags := c.Flags()
|
|
groupStatus := map[string]map[string]bool{}
|
|
oneRequiredGroupStatus := map[string]map[string]bool{}
|
|
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
|
|
c.Flags().VisitAll(func(pflag *flag.Flag) {
|
|
processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
|
|
processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
|
|
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
|
|
})
|
|
|
|
// If a flag that is part of a group is present, we make all the other flags
|
|
// of that group required so that the shell completion suggests them automatically
|
|
for flagList, flagnameAndStatus := range groupStatus {
|
|
for _, isSet := range flagnameAndStatus {
|
|
if isSet {
|
|
// One of the flags of the group is set, mark the other ones as required
|
|
for _, fName := range strings.Split(flagList, " ") {
|
|
_ = c.MarkFlagRequired(fName)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// If none of the flags of a one-required group are present, we make all the flags
|
|
// of that group required so that the shell completion suggests them automatically
|
|
for flagList, flagnameAndStatus := range oneRequiredGroupStatus {
|
|
isSet := false
|
|
|
|
for _, isSet = range flagnameAndStatus {
|
|
if isSet {
|
|
break
|
|
}
|
|
}
|
|
|
|
// None of the flags of the group are set, mark all flags in the group
|
|
// as required
|
|
if !isSet {
|
|
for _, fName := range strings.Split(flagList, " ") {
|
|
_ = c.MarkFlagRequired(fName)
|
|
}
|
|
}
|
|
}
|
|
|
|
// If a flag that is mutually exclusive to others is present, we hide the other
|
|
// flags of that group so the shell completion does not suggest them
|
|
for flagList, flagnameAndStatus := range mutuallyExclusiveGroupStatus {
|
|
for flagName, isSet := range flagnameAndStatus {
|
|
if isSet {
|
|
// One of the flags of the mutually exclusive group is set, mark the other ones as hidden
|
|
// Don't mark the flag that is already set as hidden because it may be an
|
|
// array or slice flag and therefore must continue being suggested
|
|
for _, fName := range strings.Split(flagList, " ") {
|
|
if fName != flagName {
|
|
flag := c.Flags().Lookup(fName)
|
|
flag.Hidden = true
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|