Allow BindEnv to register multiple environment variables.

This change modifies BindEnv to permit a list of environment variable
names in order to support multiple env. vars. for the same config key.
When this form is used, env. keys take precedence in the written order.

Closes #971
This commit is contained in:
Gabriel Aszalos 2020-09-10 13:08:26 +03:00
parent 387404d518
commit e5d7915cac
No known key found for this signature in database
GPG key ID: C03768B4604E14A5
2 changed files with 32 additions and 13 deletions

View file

@ -205,7 +205,7 @@ type Viper struct {
defaults map[string]interface{} defaults map[string]interface{}
kvstore map[string]interface{} kvstore map[string]interface{}
pflags map[string]FlagValue pflags map[string]FlagValue
env map[string]string env map[string][]string
aliases map[string]string aliases map[string]string
typeByDefValue bool typeByDefValue bool
@ -228,7 +228,7 @@ func New() *Viper {
v.defaults = make(map[string]interface{}) v.defaults = make(map[string]interface{})
v.kvstore = make(map[string]interface{}) v.kvstore = make(map[string]interface{})
v.pflags = make(map[string]FlagValue) v.pflags = make(map[string]FlagValue)
v.env = make(map[string]string) v.env = make(map[string][]string)
v.aliases = make(map[string]string) v.aliases = make(map[string]string)
v.typeByDefValue = false v.typeByDefValue = false
@ -993,21 +993,18 @@ func (v *Viper) BindFlagValue(key string, flag FlagValue) error {
// EnvPrefix will be used when set when env name is not provided. // EnvPrefix will be used when set when env name is not provided.
func BindEnv(input ...string) error { return v.BindEnv(input...) } func BindEnv(input ...string) error { return v.BindEnv(input...) }
func (v *Viper) BindEnv(input ...string) error { func (v *Viper) BindEnv(input ...string) error {
var key, envkey string
if len(input) == 0 { if len(input) == 0 {
return fmt.Errorf("missing key to bind to") return fmt.Errorf("missing key to bind to")
} }
key = strings.ToLower(input[0]) key := strings.ToLower(input[0])
if len(input) == 1 { if len(input) == 1 {
envkey = v.mergeWithEnvPrefix(key) v.env[key] = append(v.env[key], v.mergeWithEnvPrefix(key))
} else { } else {
envkey = input[1] v.env[key] = append(v.env[key], input[1:]...)
} }
v.env[key] = envkey
return nil return nil
} }
@ -1086,10 +1083,12 @@ func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} {
return nil return nil
} }
} }
envkey, exists := v.env[lcaseKey] envkeys, exists := v.env[lcaseKey]
if exists { if exists {
if val, ok := v.getEnv(envkey); ok { for _, envkey := range envkeys {
return val if val, ok := v.getEnv(envkey); ok {
return val
}
} }
} }
if nested && v.isPathShadowedInFlatMap(path, v.env) != "" { if nested && v.isPathShadowedInFlatMap(path, v.env) != "" {
@ -1658,6 +1657,14 @@ func castToMapStringInterface(
return tgt return tgt
} }
func castMapStringSliceToMapInterface(src map[string][]string) map[string]interface{} {
tgt := map[string]interface{}{}
for k, v := range src {
tgt[k] = v
}
return tgt
}
func castMapStringToMapInterface(src map[string]string) map[string]interface{} { func castMapStringToMapInterface(src map[string]string) map[string]interface{} {
tgt := map[string]interface{}{} tgt := map[string]interface{}{}
for k, v := range src { for k, v := range src {
@ -1828,7 +1835,7 @@ func (v *Viper) AllKeys() []string {
m = v.flattenAndMergeMap(m, castMapStringToMapInterface(v.aliases), "") m = v.flattenAndMergeMap(m, castMapStringToMapInterface(v.aliases), "")
m = v.flattenAndMergeMap(m, v.override, "") m = v.flattenAndMergeMap(m, v.override, "")
m = v.mergeFlatMap(m, castMapFlagToMapInterface(v.pflags)) m = v.mergeFlatMap(m, castMapFlagToMapInterface(v.pflags))
m = v.mergeFlatMap(m, castMapStringToMapInterface(v.env)) m = v.mergeFlatMap(m, castMapStringSliceToMapInterface(v.env))
m = v.flattenAndMergeMap(m, v.config, "") m = v.flattenAndMergeMap(m, v.config, "")
m = v.flattenAndMergeMap(m, v.kvstore, "") m = v.flattenAndMergeMap(m, v.kvstore, "")
m = v.flattenAndMergeMap(m, v.defaults, "") m = v.flattenAndMergeMap(m, v.defaults, "")

View file

@ -487,10 +487,11 @@ func TestEnv(t *testing.T) {
initJSON() initJSON()
BindEnv("id") BindEnv("id")
BindEnv("f", "FOOD") BindEnv("f", "FOOD", "OLD_FOOD")
os.Setenv("ID", "13") os.Setenv("ID", "13")
os.Setenv("FOOD", "apple") os.Setenv("FOOD", "apple")
os.Setenv("OLD_FOOD", "banana")
os.Setenv("NAME", "crunk") os.Setenv("NAME", "crunk")
assert.Equal(t, "13", Get("id")) assert.Equal(t, "13", Get("id"))
@ -502,6 +503,17 @@ func TestEnv(t *testing.T) {
assert.Equal(t, "crunk", Get("name")) assert.Equal(t, "crunk", Get("name"))
} }
func TestMultipleEnv(t *testing.T) {
initJSON()
BindEnv("f", "FOOD", "OLD_FOOD")
os.Unsetenv("FOOD")
os.Setenv("OLD_FOOD", "banana")
assert.Equal(t, "banana", Get("f"))
}
func TestEmptyEnv(t *testing.T) { func TestEmptyEnv(t *testing.T) {
initJSON() initJSON()