// Copyright 2013-2023 The Cobra Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cobra
import (
"fmt"
"sort"
"strings"
flag "github.com/spf13/pflag"
)
// Flag annotation keys used to record flag-group membership. Each value stored
// under one of these keys is a list of groups, and each group is the member flag
// names joined by a single space.
const (
	// requiredAsGroupAnnotation holds the groups registered via MarkFlagsRequiredTogether.
	requiredAsGroupAnnotation = "cobra_annotation_required_if_others_set"
	// oneRequiredAnnotation holds the groups registered via MarkFlagsOneRequired.
	oneRequiredAnnotation = "cobra_annotation_one_required"
	// mutuallyExclusiveAnnotation holds the groups registered via MarkFlagsMutuallyExclusive.
	mutuallyExclusiveAnnotation = "cobra_annotation_mutually_exclusive"
)
// MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors
2025-04-26 20:38:36 +00:00
// if the command is invoked with a subset (but not all) of the given flags. It ensures
// that all flags provided in flagNames must be used together or none at all.
// Parameters:
// - flagNames: A slice of strings representing the names of the flags to mark.
2022-04-17 16:04:57 -05:00
func ( c * Command ) MarkFlagsRequiredTogether ( flagNames ... string ) {
c . mergePersistentFlags ( )
for _ , v := range flagNames {
f := c . Flags ( ) . Lookup ( v )
if f == nil {
panic ( fmt . Sprintf ( "Failed to find flag %q and mark it as being required in a flag group" , v ) )
}
Consistent annotation names (#2140)
Add `Annotation` suffix to the private annotations to allow nicer code
using the constants.
For example one can use the current annotation names as a temporary
variable instead of unclear shortcut. Instead of this:
rag := flagsFromAnnotation(c, f, requiredAsGroup)
me := flagsFromAnnotation(c, f, mutuallyExclusive)
or := flagsFromAnnotation(c, f, oneRequired)
We can use now:
requiredAsGrop := flagsFromAnnotation(c, f, requiredAsGroupAnnotation)
mutuallyExclusive := flagsFromAnnotation(c, f, mutuallyExclusiveAnnotation)
oneRequired := flagsFromAnnotation(c, f, oneRequiredAnnotation)
Example taken from #2105.
2024-05-18 16:41:31 +03:00
if err := c . Flags ( ) . SetAnnotation ( v , requiredAsGroupAnnotation , append ( f . Annotations [ requiredAsGroupAnnotation ] , strings . Join ( flagNames , " " ) ) ) ; err != nil {
2022-04-17 16:04:57 -05:00
// Only errs if the flag isn't found.
panic ( err )
}
}
}
2023-07-16 18:38:22 +02:00
// MarkFlagsOneRequired marks the given flags with annotations so that Cobra errors
2025-04-26 20:38:36 +00:00
// if the command is invoked without at least one flag from the given set of flags. The
// `flagNames` parameter is a slice of strings containing the names of the flags to be marked.
2023-07-16 18:38:22 +02:00
func ( c * Command ) MarkFlagsOneRequired ( flagNames ... string ) {
c . mergePersistentFlags ( )
for _ , v := range flagNames {
f := c . Flags ( ) . Lookup ( v )
if f == nil {
panic ( fmt . Sprintf ( "Failed to find flag %q and mark it as being in a one-required flag group" , v ) )
}
Consistent annotation names (#2140)
Add `Annotation` suffix to the private annotations to allow nicer code
using the constants.
For example one can use the current annotation names as a temporary
variable instead of unclear shortcut. Instead of this:
rag := flagsFromAnnotation(c, f, requiredAsGroup)
me := flagsFromAnnotation(c, f, mutuallyExclusive)
or := flagsFromAnnotation(c, f, oneRequired)
We can use now:
requiredAsGrop := flagsFromAnnotation(c, f, requiredAsGroupAnnotation)
mutuallyExclusive := flagsFromAnnotation(c, f, mutuallyExclusiveAnnotation)
oneRequired := flagsFromAnnotation(c, f, oneRequiredAnnotation)
Example taken from #2105.
2024-05-18 16:41:31 +03:00
if err := c . Flags ( ) . SetAnnotation ( v , oneRequiredAnnotation , append ( f . Annotations [ oneRequiredAnnotation ] , strings . Join ( flagNames , " " ) ) ) ; err != nil {
2023-07-16 18:38:22 +02:00
// Only errs if the flag isn't found.
panic ( err )
}
}
}
2025-04-26 20:38:36 +00:00
// MarkFlagsMutuallyExclusive marks the given flags with annotations so that Cobra errors if the command is invoked with more than one flag from the given set of flags.
// It takes a variable number of strings representing the names of the flags to be marked as mutually exclusive. Each time this method is called, it adds a new entry to the annotation.
// If any of the specified flags are not found, it panics with an error message.
// The flagNames parameter contains one or more string values that represent the names of the flags to be marked as mutually exclusive.
// There is no return value.
2022-04-17 16:04:57 -05:00
func ( c * Command ) MarkFlagsMutuallyExclusive ( flagNames ... string ) {
c . mergePersistentFlags ( )
for _ , v := range flagNames {
f := c . Flags ( ) . Lookup ( v )
if f == nil {
panic ( fmt . Sprintf ( "Failed to find flag %q and mark it as being in a mutually exclusive flag group" , v ) )
}
// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.
Consistent annotation names (#2140)
Add `Annotation` suffix to the private annotations to allow nicer code
using the constants.
For example one can use the current annotation names as a temporary
variable instead of unclear shortcut. Instead of this:
rag := flagsFromAnnotation(c, f, requiredAsGroup)
me := flagsFromAnnotation(c, f, mutuallyExclusive)
or := flagsFromAnnotation(c, f, oneRequired)
We can use now:
requiredAsGrop := flagsFromAnnotation(c, f, requiredAsGroupAnnotation)
mutuallyExclusive := flagsFromAnnotation(c, f, mutuallyExclusiveAnnotation)
oneRequired := flagsFromAnnotation(c, f, oneRequiredAnnotation)
Example taken from #2105.
2024-05-18 16:41:31 +03:00
if err := c . Flags ( ) . SetAnnotation ( v , mutuallyExclusiveAnnotation , append ( f . Annotations [ mutuallyExclusiveAnnotation ] , strings . Join ( flagNames , " " ) ) ) ; err != nil {
2022-04-17 16:04:57 -05:00
panic ( err )
}
}
}
2025-04-26 20:38:36 +00:00
// ValidateFlagGroups validates the mutuallyExclusive/oneRequired/requiredAsGroup logic and returns the first error encountered.
2022-09-27 18:27:48 +08:00
func ( c * Command ) ValidateFlagGroups ( ) error {
2022-04-17 16:04:57 -05:00
if c . DisableFlagParsing {
return nil
}
flags := c . Flags ( )
// groupStatus format is the list of flags as a unique ID,
// then a map of each flag name and whether it is set or not.
groupStatus := map [ string ] map [ string ] bool { }
2023-07-16 18:38:22 +02:00
oneRequiredGroupStatus := map [ string ] map [ string ] bool { }
2022-04-17 16:04:57 -05:00
mutuallyExclusiveGroupStatus := map [ string ] map [ string ] bool { }
flags . VisitAll ( func ( pflag * flag . Flag ) {
Consistent annotation names (#2140)
Add `Annotation` suffix to the private annotations to allow nicer code
using the constants.
For example one can use the current annotation names as a temporary
variable instead of unclear shortcut. Instead of this:
rag := flagsFromAnnotation(c, f, requiredAsGroup)
me := flagsFromAnnotation(c, f, mutuallyExclusive)
or := flagsFromAnnotation(c, f, oneRequired)
We can use now:
requiredAsGrop := flagsFromAnnotation(c, f, requiredAsGroupAnnotation)
mutuallyExclusive := flagsFromAnnotation(c, f, mutuallyExclusiveAnnotation)
oneRequired := flagsFromAnnotation(c, f, oneRequiredAnnotation)
Example taken from #2105.
2024-05-18 16:41:31 +03:00
processFlagForGroupAnnotation ( flags , pflag , requiredAsGroupAnnotation , groupStatus )
processFlagForGroupAnnotation ( flags , pflag , oneRequiredAnnotation , oneRequiredGroupStatus )
processFlagForGroupAnnotation ( flags , pflag , mutuallyExclusiveAnnotation , mutuallyExclusiveGroupStatus )
2022-04-17 16:04:57 -05:00
} )
if err := validateRequiredFlagGroups ( groupStatus ) ; err != nil {
return err
}
2023-07-16 18:38:22 +02:00
if err := validateOneRequiredFlagGroups ( oneRequiredGroupStatus ) ; err != nil {
return err
}
2022-04-17 16:04:57 -05:00
if err := validateExclusiveFlagGroups ( mutuallyExclusiveGroupStatus ) ; err != nil {
return err
}
return nil
}
2025-04-26 20:38:36 +00:00
// hasAllFlags checks if all flags in the provided flag names exist within the given FlagSet.
// It returns true if all flags are found, otherwise it returns false.
2022-04-17 16:04:57 -05:00
func hasAllFlags ( fs * flag . FlagSet , flagnames ... string ) bool {
for _ , fname := range flagnames {
f := fs . Lookup ( fname )
if f == nil {
return false
}
}
return true
}
2025-04-26 20:38:36 +00:00
// processFlagForGroupAnnotation processes flags for group annotations, updating the group status based on the provided flag set and flag.
//
// Parameters:
// - flags: A pointer to the flag.FlagSet containing all flags.
// - pflag: A pointer to the flag.Flag being processed.
// - annotation: The name of the annotation to process.
// - groupStatus: A map tracking the status of groups, where the outer key is the group name and the inner map tracks individual flags within that group.
//
// This function checks if the provided flag has an annotation matching the specified annotation. If it does, it processes each group associated with the annotation. For each group, it ensures all flags in the group are defined in the provided flag set. If they are, it initializes the group status for any new groups and updates the status of the current flag within its group.
//
2022-04-17 16:04:57 -05:00
func processFlagForGroupAnnotation ( flags * flag . FlagSet , pflag * flag . Flag , annotation string , groupStatus map [ string ] map [ string ] bool ) {
groupInfo , found := pflag . Annotations [ annotation ]
if found {
for _ , group := range groupInfo {
if groupStatus [ group ] == nil {
flagnames := strings . Split ( group , " " )
// Only consider this flag group at all if all the flags are defined.
if ! hasAllFlags ( flags , flagnames ... ) {
continue
}
2023-11-23 19:24:33 +02:00
groupStatus [ group ] = make ( map [ string ] bool , len ( flagnames ) )
2022-04-17 16:04:57 -05:00
for _ , name := range flagnames {
groupStatus [ group ] [ name ] = false
}
}
groupStatus [ group ] [ pflag . Name ] = pflag . Changed
}
}
}
2025-04-26 20:38:36 +00:00
// validateRequiredFlagGroups checks if any flag groups have required flags that are not all set.
// It takes a map where keys represent flag groups and values are maps of flag names to their set status.
// Returns an error if any group contains flags that are either all unset or none unset, which is invalid.
2022-04-17 16:04:57 -05:00
func validateRequiredFlagGroups ( data map [ string ] map [ string ] bool ) error {
keys := sortedKeys ( data )
for _ , flagList := range keys {
flagnameAndStatus := data [ flagList ]
unset := [ ] string { }
for flagname , isSet := range flagnameAndStatus {
if ! isSet {
unset = append ( unset , flagname )
}
}
if len ( unset ) == len ( flagnameAndStatus ) || len ( unset ) == 0 {
continue
}
// Sort values, so they can be tested/scripted against consistently.
sort . Strings ( unset )
return fmt . Errorf ( "if any flags in the group [%v] are set they must all be set; missing %v" , flagList , unset )
}
return nil
}
2025-04-26 20:38:36 +00:00
// validateOneRequiredFlagGroups checks that at least one flag from each group is set. It takes a map where keys are groups and values are maps indicating whether each flag within a group is set.
//
// Parameters:
// - data: A map of string to map of string to bool, representing the flags and their statuses in different groups.
//
// Returns:
// - error: An error if any group does not have at least one required flag set. If all required flags are set, it returns nil.
2023-07-16 18:38:22 +02:00
func validateOneRequiredFlagGroups ( data map [ string ] map [ string ] bool ) error {
keys := sortedKeys ( data )
for _ , flagList := range keys {
flagnameAndStatus := data [ flagList ]
var set [ ] string
for flagname , isSet := range flagnameAndStatus {
if isSet {
set = append ( set , flagname )
}
}
if len ( set ) >= 1 {
continue
}
// Sort values, so they can be tested/scripted against consistently.
sort . Strings ( set )
return fmt . Errorf ( "at least one of the flags in the group [%v] is required" , flagList )
}
return nil
}
2025-04-26 20:38:36 +00:00
// validateExclusiveFlagGroups checks that within a map of flag groups, no more than one flag from each group is set.
//
// Parameters:
// - data: A map where keys are group names and values are maps of flag names to their status (true if set).
//
// Returns:
// - error: If any group contains more than one set flag, an error is returned with details. Otherwise, returns nil.
//
// Example:
// err := validateExclusiveFlagGroups(map[string]map[string]bool{
// "group1": {"flagA": true, "flagB": false},
// "group2": {"flagC": false, "flagD": true},
// })
// if err != nil {
// fmt.Println(err)
// }
2022-04-17 16:04:57 -05:00
func validateExclusiveFlagGroups ( data map [ string ] map [ string ] bool ) error {
keys := sortedKeys ( data )
for _ , flagList := range keys {
flagnameAndStatus := data [ flagList ]
var set [ ] string
for flagname , isSet := range flagnameAndStatus {
if isSet {
set = append ( set , flagname )
}
}
if len ( set ) == 0 || len ( set ) == 1 {
continue
}
// Sort values, so they can be tested/scripted against consistently.
sort . Strings ( set )
return fmt . Errorf ( "if any flags in the group [%v] are set none of the others can be; %v were all set" , flagList , set )
}
return nil
}
// sortedKeys returns the keys of m in ascending lexicographic order without modifying m.
// A nil or empty map yields an empty, non-nil slice.
func sortedKeys(m map[string]map[string]bool) []string {
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	return keys
}
2022-06-20 22:04:28 -04:00
2025-04-26 20:38:36 +00:00
// enforceFlagGroupsForCompletion enforces the completion logic for flag groups in a command.
// It ensures that when a flag in a group is present, other flags in the group are marked required,
// and when none of the flags in a one-required group are present, all flags in the group are marked required.
// Additionally, it hides flags that are mutually exclusive to others. This allows the standard completion logic
// to behave appropriately for flag groups.
2022-06-20 22:04:28 -04:00
func ( c * Command ) enforceFlagGroupsForCompletion ( ) {
if c . DisableFlagParsing {
return
}
flags := c . Flags ( )
groupStatus := map [ string ] map [ string ] bool { }
2023-07-16 18:38:22 +02:00
oneRequiredGroupStatus := map [ string ] map [ string ] bool { }
2022-06-20 22:04:28 -04:00
mutuallyExclusiveGroupStatus := map [ string ] map [ string ] bool { }
c . Flags ( ) . VisitAll ( func ( pflag * flag . Flag ) {
Consistent annotation names (#2140)
Add `Annotation` suffix to the private annotations to allow nicer code
using the constants.
For example one can use the current annotation names as a temporary
variable instead of unclear shortcut. Instead of this:
rag := flagsFromAnnotation(c, f, requiredAsGroup)
me := flagsFromAnnotation(c, f, mutuallyExclusive)
or := flagsFromAnnotation(c, f, oneRequired)
We can use now:
requiredAsGrop := flagsFromAnnotation(c, f, requiredAsGroupAnnotation)
mutuallyExclusive := flagsFromAnnotation(c, f, mutuallyExclusiveAnnotation)
oneRequired := flagsFromAnnotation(c, f, oneRequiredAnnotation)
Example taken from #2105.
2024-05-18 16:41:31 +03:00
processFlagForGroupAnnotation ( flags , pflag , requiredAsGroupAnnotation , groupStatus )
processFlagForGroupAnnotation ( flags , pflag , oneRequiredAnnotation , oneRequiredGroupStatus )
processFlagForGroupAnnotation ( flags , pflag , mutuallyExclusiveAnnotation , mutuallyExclusiveGroupStatus )
2022-06-20 22:04:28 -04:00
} )
// If a flag that is part of a group is present, we make all the other flags
// of that group required so that the shell completion suggests them automatically
for flagList , flagnameAndStatus := range groupStatus {
for _ , isSet := range flagnameAndStatus {
if isSet {
// One of the flags of the group is set, mark the other ones as required
for _ , fName := range strings . Split ( flagList , " " ) {
_ = c . MarkFlagRequired ( fName )
}
}
}
}
2023-07-16 18:38:22 +02:00
// If none of the flags of a one-required group are present, we make all the flags
// of that group required so that the shell completion suggests them automatically
for flagList , flagnameAndStatus := range oneRequiredGroupStatus {
2023-11-23 19:24:33 +02:00
isSet := false
2023-07-16 18:38:22 +02:00
2023-11-23 19:24:33 +02:00
for _ , isSet = range flagnameAndStatus {
2023-07-16 18:38:22 +02:00
if isSet {
2023-11-23 19:24:33 +02:00
break
2023-07-16 18:38:22 +02:00
}
}
// None of the flags of the group are set, mark all flags in the group
// as required
2023-11-23 19:24:33 +02:00
if ! isSet {
2023-07-16 18:38:22 +02:00
for _ , fName := range strings . Split ( flagList , " " ) {
_ = c . MarkFlagRequired ( fName )
}
}
}
2022-06-20 22:04:28 -04:00
// If a flag that is mutually exclusive to others is present, we hide the other
// flags of that group so the shell completion does not suggest them
for flagList , flagnameAndStatus := range mutuallyExclusiveGroupStatus {
for flagName , isSet := range flagnameAndStatus {
if isSet {
// One of the flags of the mutually exclusive group is set, mark the other ones as hidden
// Don't mark the flag that is already set as hidden because it may be an
// array or slice flag and therefore must continue being suggested
for _ , fName := range strings . Split ( flagList , " " ) {
if fName != flagName {
flag := c . Flags ( ) . Lookup ( fName )
flag . Hidden = true
}
}
}
}
}
}