From c3db1c77f147e9f69929d7de23ee600d270d40f7 Mon Sep 17 00:00:00 2001 From: maxime mouial Date: Mon, 22 Apr 2024 14:54:57 +0200 Subject: [PATCH] Fix Viper 'UnmarshalKey' behavior to include all sources (#32) Fix Viper 'UnmarshalKey' behavior to include all sources Viper 'UnmarshalKey' seems to have been designed to work on leaf settings only. When being called on a group of settings, it would not merge all sources (default, env vars, config, override, ...). This means that using the `Set()` method could change the returned value from 'UnmarshalKey' in unexpected ways (i.e. would no longer return the value from the configuration). --- viper.go | 241 ++++++++++++++++++++++++++++++++------------------ viper_test.go | 35 ++++++++ 2 files changed, 188 insertions(+), 88 deletions(-) diff --git a/viper.go b/viper.go index b77e682..fb4be60 100644 --- a/viper.go +++ b/viper.go @@ -864,11 +864,19 @@ func (v *Viper) GetSizeInBytesE(key string) (uint, error) { func UnmarshalKey(key string, rawVal interface{}, opts ...DecoderConfigOption) error { return v.UnmarshalKey(key, rawVal, opts...) } -func (v *Viper) UnmarshalKey(key string, rawVal interface{}, opts ...DecoderConfigOption) error { - err := decode(v.Get(key), defaultDecoderConfig(rawVal, opts...)) - if err != nil { - return err +func (v *Viper) UnmarshalKey(key string, rawVal interface{}, opts ...DecoderConfigOption) error { + // We first get all value we have for a key (from default to override) + lcaseKey := strings.ToLower(key) + values := v.findAll(lcaseKey) + + // findAll returns all the value for a settings from highest priority to lowest. So when aggregating them we want + // to start at the lowest priority and override with higher ones. + for idx := len(values) - 1; idx >= 0; idx-- { + err := decode(values[idx], defaultDecoderConfig(rawVal, opts...)) + if err != nil { + return err + } } return nil @@ -1001,20 +1009,106 @@ func (v *Viper) BindEnv(input ...string) error { return nil } -// Given a key, find the value. -// Viper will check in the following order: -// flag, env, config file, key/value store, default. -// If skipDefault is set to true, find will ignore default values. -// Viper will check to see if an alias exists first. -// Note: this assumes a lower-cased key given. -func (v *Viper) find(lcaseKey string, skipDefault bool) interface{} { +// findFromOverride finds a key from override (ie: from 'Set()') +func (v *Viper) findFromOverride(_ string, path []string, nested bool) (interface{}, bool) { + if val := v.searchMap(v.override, path); val != nil { + return val, true + } + if nested && v.isPathShadowedInDeepMap(path, v.override) != "" { + return nil, true + } + return nil, false +} - var ( - val interface{} - exists bool - path = strings.Split(lcaseKey, v.keyDelim) - nested = len(path) > 1 - ) +func (v *Viper) findFromPFlag(lcaseKey string, path []string, nested bool) (interface{}, bool) { + flag, exists := v.pflags[lcaseKey] + if exists && flag.HasChanged() { + switch flag.ValueType() { + case "int", "int8", "int16", "int32", "int64": + return cast.ToInt(flag.ValueString()), true + case "bool": + return cast.ToBool(flag.ValueString()), true + case "stringSlice": + s := strings.TrimPrefix(flag.ValueString(), "[") + s = strings.TrimSuffix(s, "]") + res, _ := readAsCSV(s) + return res, true + default: + return flag.ValueString(), true + } + } + if nested && v.isPathShadowedInFlatMap(path, v.pflags) != "" { + return nil, true + } + return nil, false +} + +func (v *Viper) findFromEnv(lcaseKey string, path []string, nested bool) (interface{}, bool) { + if v.automaticEnvApplied { + // even if it hasn't been registered, if automaticEnv is used, + // check any Get request + if val, ok := v.getEnv(v.mergeWithEnvPrefix(lcaseKey)); ok { + return val, true + } + if nested && v.isPathShadowedInAutoEnv(path) != "" { + return nil, true + } + } + envkeys, exists := v.env[lcaseKey] + if exists { + for _, key := range envkeys { + if val, ok := v.getEnv(key); ok { + if fn, ok := v.envTransform[lcaseKey]; ok { + return fn(val), true + } + return val, true + } + } + } + if nested && v.isPathShadowedInFlatMap(path, v.env) != "" { + return nil, true + } + return nil, false +} + +func (v *Viper) findFromConfig(_ string, path []string, nested bool) (interface{}, bool) { + val := v.searchMapWithPathPrefixes(v.config, path) + if val != nil { + return val, true + } + if nested && v.isPathShadowedInDeepMap(path, v.config) != "" { + return nil, true + } + return nil, false +} + +func (v *Viper) findFromKVStore(_ string, path []string, nested bool) (interface{}, bool) { + val := v.searchMap(v.kvstore, path) + if val != nil { + return val, true + } + if nested && v.isPathShadowedInDeepMap(path, v.kvstore) != "" { + return nil, true + } + return nil, false +} + +func (v *Viper) findFromDefault(_ string, path []string, nested bool) (interface{}, bool) { + val := v.searchMap(v.defaults, path) + if val != nil { + return val, true + } + if nested && v.isPathShadowedInDeepMap(path, v.defaults) != "" { + return nil, true + } + return nil, false +} + +// findAll returns all the value for a settings from highest priority to lowest +func (v *Viper) findAll(lcaseKey string) []interface{} { + var values []interface{} + path := strings.Split(lcaseKey, v.keyDelim) + nested := len(path) > 1 // compute the path through the nested maps to the nested value if nested && v.isPathShadowedInDeepMap(path, castMapStringToMapInterface(v.aliases)) != "" { @@ -1026,89 +1120,60 @@ func (v *Viper) find(lcaseKey string, skipDefault bool) interface{} { path = strings.Split(lcaseKey, v.keyDelim) nested = len(path) > 1 - // Set() override first - val = v.searchMap(v.override, path) - if val != nil { - return val - } - if nested && v.isPathShadowedInDeepMap(path, v.override) != "" { - return nil + getters := []func(string, []string, bool) (interface{}, bool){ + v.findFromOverride, // override == value from Set() + v.findFromPFlag, + v.findFromEnv, + v.findFromConfig, + v.findFromKVStore, + v.findFromDefault, } - // PFlag override next - flag, exists := v.pflags[lcaseKey] - if exists && flag.HasChanged() { - switch flag.ValueType() { - case "int", "int8", "int16", "int32", "int64": - return cast.ToInt(flag.ValueString()) - case "bool": - return cast.ToBool(flag.ValueString()) - case "stringSlice": - s := strings.TrimPrefix(flag.ValueString(), "[") - s = strings.TrimSuffix(s, "]") - res, _ := readAsCSV(s) - return res - default: - return flag.ValueString() + for _, getter := range getters { + if val, found := getter(lcaseKey, path, nested); found { + values = append(values, val) } } - if nested && v.isPathShadowedInFlatMap(path, v.pflags) != "" { + + return values +} + +// Given a key, find the value. +// Viper will check in the following order: +// flag, env, config file, key/value store, default. +// If skipDefault is set to true, find will ignore default values. +// Viper will check to see if an alias exists first. +// Note: this assumes a lower-cased key given. +func (v *Viper) find(lcaseKey string, skipDefault bool) interface{} { + path := strings.Split(lcaseKey, v.keyDelim) + nested := len(path) > 1 + + // compute the path through the nested maps to the nested value + if nested && v.isPathShadowedInDeepMap(path, castMapStringToMapInterface(v.aliases)) != "" { return nil } - // Env override next - if v.automaticEnvApplied { - // even if it hasn't been registered, if automaticEnv is used, - // check any Get request - if val, ok := v.getEnv(v.mergeWithEnvPrefix(lcaseKey)); ok { - return val - } - if nested && v.isPathShadowedInAutoEnv(path) != "" { - return nil - } - } - envkeys, exists := v.env[lcaseKey] - if exists { - for _, key := range envkeys { - if val, ok := v.getEnv(key); ok { - if fn, ok := v.envTransform[lcaseKey]; ok { - return fn(val) - } - return val - } - } - } - if nested && v.isPathShadowedInFlatMap(path, v.env) != "" { - return nil + // if the requested key is an alias, then return the proper key + lcaseKey = v.realKey(lcaseKey) + path = strings.Split(lcaseKey, v.keyDelim) + nested = len(path) > 1 + + getters := []func(string, []string, bool) (interface{}, bool){ + v.findFromOverride, // override == value from Set() + v.findFromPFlag, + v.findFromEnv, + v.findFromConfig, + v.findFromKVStore, } - // Config file next - val = v.searchMapWithPathPrefixes(v.config, path) - if val != nil { - return val - } - if nested && v.isPathShadowedInDeepMap(path, v.config) != "" { - return nil - } - - // K/V store next - val = v.searchMap(v.kvstore, path) - if val != nil { - return val - } - if nested && v.isPathShadowedInDeepMap(path, v.kvstore) != "" { - return nil - } - - // Default next if !skipDefault { - val = v.searchMap(v.defaults, path) - if val != nil { + getters = append(getters, v.findFromDefault) + } + + for _, getter := range getters { + if val, found := getter(lcaseKey, path, nested); found { return val } - if nested && v.isPathShadowedInDeepMap(path, v.defaults) != "" { - return nil - } } // last chance: if no other value is returned and a flag does exist for the value, diff --git a/viper_test.go b/viper_test.go index 5507a6d..9d3195c 100644 --- a/viper_test.go +++ b/viper_test.go @@ -1703,6 +1703,41 @@ func TestParseNested(t *testing.T) { assert.Equal(t, 200*time.Millisecond, items[0].Nested.Delay) } +func TestUnmarshalKey(t *testing.T) { + type testStruct struct { + Delay int `mapstructure:"delay" yaml:"delay"` + Port int `mapstructure:"port" yaml:"port"` + Items []string `mapstructure:"items" yaml:"items"` + } + + config := ` +level_1: + level_2: + port: 1234 + items: + - "test 1" +` + v := New() + v.SetDefault("level_1.level_2.delay", 50) + v.SetDefault("level_1.level_2.port", 9999) + v.SetDefault("level_1.level_2.items", []string{}) + + initConfig(v, "yaml", config) + + // manually overwrite some settings + v.Set("level_1.level_2.items", []string{"test_2", "test_3"}) + + data := testStruct{} + err := v.UnmarshalKey("level_1.level_2", &data) + if err != nil { + t.Fatalf("unable to decode into struct, %v", err) + } + + assert.Equal(t, 50, data.Delay) // value from defaults + assert.Equal(t, 1234, data.Port) // value from config + assert.Equal(t, []string{"test_2", "test_3"}, data.Items) // value from Set() +} + func doTestCaseInsensitive(t *testing.T, typ, config string) { v := New() initConfig(v, typ, config)