Fix Viper 'UnmarshalKey' behavior to include all sources (#32)

Fix Viper 'UnmarshalKey' behavior to include all sources

Viper 'UnmarshalKey' seems to have been designed to work on leaf
settings only. When being called on a group of settings, it would not
merge all sources (default, env vars, config, override, ...).

This means that using the `Set()` method could change the returned value
from 'UnmarshalKey' in unexpected ways (i.e. would no longer return the
value from the configuration).
This commit is contained in:
maxime mouial 2024-04-22 14:54:57 +02:00 committed by GitHub
parent 3e7837fd38
commit c3db1c77f1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 188 additions and 88 deletions

241
viper.go
View file

@ -864,11 +864,19 @@ func (v *Viper) GetSizeInBytesE(key string) (uint, error) {
func UnmarshalKey(key string, rawVal interface{}, opts ...DecoderConfigOption) error {
return v.UnmarshalKey(key, rawVal, opts...)
}
func (v *Viper) UnmarshalKey(key string, rawVal interface{}, opts ...DecoderConfigOption) error {
err := decode(v.Get(key), defaultDecoderConfig(rawVal, opts...))
if err != nil {
return err
func (v *Viper) UnmarshalKey(key string, rawVal interface{}, opts ...DecoderConfigOption) error {
// We first get all value we have for a key (from default to override)
lcaseKey := strings.ToLower(key)
values := v.findAll(lcaseKey)
// findAll returns all the value for a settings from highest priority to lowest. So when aggregating them we want
// to start at the lowest priority and override with higher ones.
for idx := len(values) - 1; idx >= 0; idx-- {
err := decode(values[idx], defaultDecoderConfig(rawVal, opts...))
if err != nil {
return err
}
}
return nil
@ -1001,20 +1009,106 @@ func (v *Viper) BindEnv(input ...string) error {
return nil
}
// Given a key, find the value.
// Viper will check in the following order:
// flag, env, config file, key/value store, default.
// If skipDefault is set to true, find will ignore default values.
// Viper will check to see if an alias exists first.
// Note: this assumes a lower-cased key given.
func (v *Viper) find(lcaseKey string, skipDefault bool) interface{} {
// findFromOverride finds a key from override (ie: from 'Set()')
func (v *Viper) findFromOverride(_ string, path []string, nested bool) (interface{}, bool) {
if val := v.searchMap(v.override, path); val != nil {
return val, true
}
if nested && v.isPathShadowedInDeepMap(path, v.override) != "" {
return nil, true
}
return nil, false
}
var (
val interface{}
exists bool
path = strings.Split(lcaseKey, v.keyDelim)
nested = len(path) > 1
)
func (v *Viper) findFromPFlag(lcaseKey string, path []string, nested bool) (interface{}, bool) {
flag, exists := v.pflags[lcaseKey]
if exists && flag.HasChanged() {
switch flag.ValueType() {
case "int", "int8", "int16", "int32", "int64":
return cast.ToInt(flag.ValueString()), true
case "bool":
return cast.ToBool(flag.ValueString()), true
case "stringSlice":
s := strings.TrimPrefix(flag.ValueString(), "[")
s = strings.TrimSuffix(s, "]")
res, _ := readAsCSV(s)
return res, true
default:
return flag.ValueString(), true
}
}
if nested && v.isPathShadowedInFlatMap(path, v.pflags) != "" {
return nil, true
}
return nil, false
}
func (v *Viper) findFromEnv(lcaseKey string, path []string, nested bool) (interface{}, bool) {
if v.automaticEnvApplied {
// even if it hasn't been registered, if automaticEnv is used,
// check any Get request
if val, ok := v.getEnv(v.mergeWithEnvPrefix(lcaseKey)); ok {
return val, true
}
if nested && v.isPathShadowedInAutoEnv(path) != "" {
return nil, true
}
}
envkeys, exists := v.env[lcaseKey]
if exists {
for _, key := range envkeys {
if val, ok := v.getEnv(key); ok {
if fn, ok := v.envTransform[lcaseKey]; ok {
return fn(val), true
}
return val, true
}
}
}
if nested && v.isPathShadowedInFlatMap(path, v.env) != "" {
return nil, true
}
return nil, false
}
func (v *Viper) findFromConfig(_ string, path []string, nested bool) (interface{}, bool) {
val := v.searchMapWithPathPrefixes(v.config, path)
if val != nil {
return val, true
}
if nested && v.isPathShadowedInDeepMap(path, v.config) != "" {
return nil, true
}
return nil, false
}
func (v *Viper) findFromKVStore(_ string, path []string, nested bool) (interface{}, bool) {
val := v.searchMap(v.kvstore, path)
if val != nil {
return val, true
}
if nested && v.isPathShadowedInDeepMap(path, v.kvstore) != "" {
return nil, true
}
return nil, false
}
func (v *Viper) findFromDefault(_ string, path []string, nested bool) (interface{}, bool) {
val := v.searchMap(v.defaults, path)
if val != nil {
return val, true
}
if nested && v.isPathShadowedInDeepMap(path, v.defaults) != "" {
return nil, true
}
return nil, false
}
// findAll returns all the value for a settings from highest priority to lowest
func (v *Viper) findAll(lcaseKey string) []interface{} {
var values []interface{}
path := strings.Split(lcaseKey, v.keyDelim)
nested := len(path) > 1
// compute the path through the nested maps to the nested value
if nested && v.isPathShadowedInDeepMap(path, castMapStringToMapInterface(v.aliases)) != "" {
@ -1026,89 +1120,60 @@ func (v *Viper) find(lcaseKey string, skipDefault bool) interface{} {
path = strings.Split(lcaseKey, v.keyDelim)
nested = len(path) > 1
// Set() override first
val = v.searchMap(v.override, path)
if val != nil {
return val
}
if nested && v.isPathShadowedInDeepMap(path, v.override) != "" {
return nil
getters := []func(string, []string, bool) (interface{}, bool){
v.findFromOverride, // override == value from Set()
v.findFromPFlag,
v.findFromEnv,
v.findFromConfig,
v.findFromKVStore,
v.findFromDefault,
}
// PFlag override next
flag, exists := v.pflags[lcaseKey]
if exists && flag.HasChanged() {
switch flag.ValueType() {
case "int", "int8", "int16", "int32", "int64":
return cast.ToInt(flag.ValueString())
case "bool":
return cast.ToBool(flag.ValueString())
case "stringSlice":
s := strings.TrimPrefix(flag.ValueString(), "[")
s = strings.TrimSuffix(s, "]")
res, _ := readAsCSV(s)
return res
default:
return flag.ValueString()
for _, getter := range getters {
if val, found := getter(lcaseKey, path, nested); found {
values = append(values, val)
}
}
if nested && v.isPathShadowedInFlatMap(path, v.pflags) != "" {
return values
}
// Given a key, find the value.
// Viper will check in the following order:
// flag, env, config file, key/value store, default.
// If skipDefault is set to true, find will ignore default values.
// Viper will check to see if an alias exists first.
// Note: this assumes a lower-cased key given.
func (v *Viper) find(lcaseKey string, skipDefault bool) interface{} {
path := strings.Split(lcaseKey, v.keyDelim)
nested := len(path) > 1
// compute the path through the nested maps to the nested value
if nested && v.isPathShadowedInDeepMap(path, castMapStringToMapInterface(v.aliases)) != "" {
return nil
}
// Env override next
if v.automaticEnvApplied {
// even if it hasn't been registered, if automaticEnv is used,
// check any Get request
if val, ok := v.getEnv(v.mergeWithEnvPrefix(lcaseKey)); ok {
return val
}
if nested && v.isPathShadowedInAutoEnv(path) != "" {
return nil
}
}
envkeys, exists := v.env[lcaseKey]
if exists {
for _, key := range envkeys {
if val, ok := v.getEnv(key); ok {
if fn, ok := v.envTransform[lcaseKey]; ok {
return fn(val)
}
return val
}
}
}
if nested && v.isPathShadowedInFlatMap(path, v.env) != "" {
return nil
// if the requested key is an alias, then return the proper key
lcaseKey = v.realKey(lcaseKey)
path = strings.Split(lcaseKey, v.keyDelim)
nested = len(path) > 1
getters := []func(string, []string, bool) (interface{}, bool){
v.findFromOverride, // override == value from Set()
v.findFromPFlag,
v.findFromEnv,
v.findFromConfig,
v.findFromKVStore,
}
// Config file next
val = v.searchMapWithPathPrefixes(v.config, path)
if val != nil {
return val
}
if nested && v.isPathShadowedInDeepMap(path, v.config) != "" {
return nil
}
// K/V store next
val = v.searchMap(v.kvstore, path)
if val != nil {
return val
}
if nested && v.isPathShadowedInDeepMap(path, v.kvstore) != "" {
return nil
}
// Default next
if !skipDefault {
val = v.searchMap(v.defaults, path)
if val != nil {
getters = append(getters, v.findFromDefault)
}
for _, getter := range getters {
if val, found := getter(lcaseKey, path, nested); found {
return val
}
if nested && v.isPathShadowedInDeepMap(path, v.defaults) != "" {
return nil
}
}
// last chance: if no other value is returned and a flag does exist for the value,

View file

@ -1703,6 +1703,41 @@ func TestParseNested(t *testing.T) {
assert.Equal(t, 200*time.Millisecond, items[0].Nested.Delay)
}
func TestUnmarshalKey(t *testing.T) {
type testStruct struct {
Delay int `mapstructure:"delay" yaml:"delay"`
Port int `mapstructure:"port" yaml:"port"`
Items []string `mapstructure:"items" yaml:"items"`
}
config := `
level_1:
level_2:
port: 1234
items:
- "test 1"
`
v := New()
v.SetDefault("level_1.level_2.delay", 50)
v.SetDefault("level_1.level_2.port", 9999)
v.SetDefault("level_1.level_2.items", []string{})
initConfig(v, "yaml", config)
// manually overwrite some settings
v.Set("level_1.level_2.items", []string{"test_2", "test_3"})
data := testStruct{}
err := v.UnmarshalKey("level_1.level_2", &data)
if err != nil {
t.Fatalf("unable to decode into struct, %v", err)
}
assert.Equal(t, 50, data.Delay) // value from defaults
assert.Equal(t, 1234, data.Port) // value from config
assert.Equal(t, []string{"test_2", "test_3"}, data.Items) // value from Set()
}
func doTestCaseInsensitive(t *testing.T, typ, config string) {
v := New()
initConfig(v, typ, config)