mirror of
https://github.com/spf13/viper
synced 2025-05-05 19:57:18 +00:00
Fix Viper 'UnmarshalKey' behavior to include all sources (#32)
Fix Viper 'UnmarshalKey' behavior to include all sources Viper 'UnmarshalKey' seems to have been designed to work on leaf settings only. When being called on a group of settings, it would not merge all sources (default, env vars, config, override, ...). This means that using the `Set()` method could change the returned value from 'UnmarshalKey' in unexpected ways (i.e. would no longer return the value from the configuration).
This commit is contained in:
parent
3e7837fd38
commit
c3db1c77f1
2 changed files with 188 additions and 88 deletions
241
viper.go
241
viper.go
|
@ -864,11 +864,19 @@ func (v *Viper) GetSizeInBytesE(key string) (uint, error) {
|
|||
func UnmarshalKey(key string, rawVal interface{}, opts ...DecoderConfigOption) error {
|
||||
return v.UnmarshalKey(key, rawVal, opts...)
|
||||
}
|
||||
func (v *Viper) UnmarshalKey(key string, rawVal interface{}, opts ...DecoderConfigOption) error {
|
||||
err := decode(v.Get(key), defaultDecoderConfig(rawVal, opts...))
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
func (v *Viper) UnmarshalKey(key string, rawVal interface{}, opts ...DecoderConfigOption) error {
|
||||
// We first get all value we have for a key (from default to override)
|
||||
lcaseKey := strings.ToLower(key)
|
||||
values := v.findAll(lcaseKey)
|
||||
|
||||
// findAll returns all the value for a settings from highest priority to lowest. So when aggregating them we want
|
||||
// to start at the lowest priority and override with higher ones.
|
||||
for idx := len(values) - 1; idx >= 0; idx-- {
|
||||
err := decode(values[idx], defaultDecoderConfig(rawVal, opts...))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -1001,20 +1009,106 @@ func (v *Viper) BindEnv(input ...string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Given a key, find the value.
|
||||
// Viper will check in the following order:
|
||||
// flag, env, config file, key/value store, default.
|
||||
// If skipDefault is set to true, find will ignore default values.
|
||||
// Viper will check to see if an alias exists first.
|
||||
// Note: this assumes a lower-cased key given.
|
||||
func (v *Viper) find(lcaseKey string, skipDefault bool) interface{} {
|
||||
// findFromOverride finds a key from override (ie: from 'Set()')
|
||||
func (v *Viper) findFromOverride(_ string, path []string, nested bool) (interface{}, bool) {
|
||||
if val := v.searchMap(v.override, path); val != nil {
|
||||
return val, true
|
||||
}
|
||||
if nested && v.isPathShadowedInDeepMap(path, v.override) != "" {
|
||||
return nil, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
var (
|
||||
val interface{}
|
||||
exists bool
|
||||
path = strings.Split(lcaseKey, v.keyDelim)
|
||||
nested = len(path) > 1
|
||||
)
|
||||
func (v *Viper) findFromPFlag(lcaseKey string, path []string, nested bool) (interface{}, bool) {
|
||||
flag, exists := v.pflags[lcaseKey]
|
||||
if exists && flag.HasChanged() {
|
||||
switch flag.ValueType() {
|
||||
case "int", "int8", "int16", "int32", "int64":
|
||||
return cast.ToInt(flag.ValueString()), true
|
||||
case "bool":
|
||||
return cast.ToBool(flag.ValueString()), true
|
||||
case "stringSlice":
|
||||
s := strings.TrimPrefix(flag.ValueString(), "[")
|
||||
s = strings.TrimSuffix(s, "]")
|
||||
res, _ := readAsCSV(s)
|
||||
return res, true
|
||||
default:
|
||||
return flag.ValueString(), true
|
||||
}
|
||||
}
|
||||
if nested && v.isPathShadowedInFlatMap(path, v.pflags) != "" {
|
||||
return nil, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (v *Viper) findFromEnv(lcaseKey string, path []string, nested bool) (interface{}, bool) {
|
||||
if v.automaticEnvApplied {
|
||||
// even if it hasn't been registered, if automaticEnv is used,
|
||||
// check any Get request
|
||||
if val, ok := v.getEnv(v.mergeWithEnvPrefix(lcaseKey)); ok {
|
||||
return val, true
|
||||
}
|
||||
if nested && v.isPathShadowedInAutoEnv(path) != "" {
|
||||
return nil, true
|
||||
}
|
||||
}
|
||||
envkeys, exists := v.env[lcaseKey]
|
||||
if exists {
|
||||
for _, key := range envkeys {
|
||||
if val, ok := v.getEnv(key); ok {
|
||||
if fn, ok := v.envTransform[lcaseKey]; ok {
|
||||
return fn(val), true
|
||||
}
|
||||
return val, true
|
||||
}
|
||||
}
|
||||
}
|
||||
if nested && v.isPathShadowedInFlatMap(path, v.env) != "" {
|
||||
return nil, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (v *Viper) findFromConfig(_ string, path []string, nested bool) (interface{}, bool) {
|
||||
val := v.searchMapWithPathPrefixes(v.config, path)
|
||||
if val != nil {
|
||||
return val, true
|
||||
}
|
||||
if nested && v.isPathShadowedInDeepMap(path, v.config) != "" {
|
||||
return nil, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (v *Viper) findFromKVStore(_ string, path []string, nested bool) (interface{}, bool) {
|
||||
val := v.searchMap(v.kvstore, path)
|
||||
if val != nil {
|
||||
return val, true
|
||||
}
|
||||
if nested && v.isPathShadowedInDeepMap(path, v.kvstore) != "" {
|
||||
return nil, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (v *Viper) findFromDefault(_ string, path []string, nested bool) (interface{}, bool) {
|
||||
val := v.searchMap(v.defaults, path)
|
||||
if val != nil {
|
||||
return val, true
|
||||
}
|
||||
if nested && v.isPathShadowedInDeepMap(path, v.defaults) != "" {
|
||||
return nil, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// findAll returns all the value for a settings from highest priority to lowest
|
||||
func (v *Viper) findAll(lcaseKey string) []interface{} {
|
||||
var values []interface{}
|
||||
path := strings.Split(lcaseKey, v.keyDelim)
|
||||
nested := len(path) > 1
|
||||
|
||||
// compute the path through the nested maps to the nested value
|
||||
if nested && v.isPathShadowedInDeepMap(path, castMapStringToMapInterface(v.aliases)) != "" {
|
||||
|
@ -1026,89 +1120,60 @@ func (v *Viper) find(lcaseKey string, skipDefault bool) interface{} {
|
|||
path = strings.Split(lcaseKey, v.keyDelim)
|
||||
nested = len(path) > 1
|
||||
|
||||
// Set() override first
|
||||
val = v.searchMap(v.override, path)
|
||||
if val != nil {
|
||||
return val
|
||||
}
|
||||
if nested && v.isPathShadowedInDeepMap(path, v.override) != "" {
|
||||
return nil
|
||||
getters := []func(string, []string, bool) (interface{}, bool){
|
||||
v.findFromOverride, // override == value from Set()
|
||||
v.findFromPFlag,
|
||||
v.findFromEnv,
|
||||
v.findFromConfig,
|
||||
v.findFromKVStore,
|
||||
v.findFromDefault,
|
||||
}
|
||||
|
||||
// PFlag override next
|
||||
flag, exists := v.pflags[lcaseKey]
|
||||
if exists && flag.HasChanged() {
|
||||
switch flag.ValueType() {
|
||||
case "int", "int8", "int16", "int32", "int64":
|
||||
return cast.ToInt(flag.ValueString())
|
||||
case "bool":
|
||||
return cast.ToBool(flag.ValueString())
|
||||
case "stringSlice":
|
||||
s := strings.TrimPrefix(flag.ValueString(), "[")
|
||||
s = strings.TrimSuffix(s, "]")
|
||||
res, _ := readAsCSV(s)
|
||||
return res
|
||||
default:
|
||||
return flag.ValueString()
|
||||
for _, getter := range getters {
|
||||
if val, found := getter(lcaseKey, path, nested); found {
|
||||
values = append(values, val)
|
||||
}
|
||||
}
|
||||
if nested && v.isPathShadowedInFlatMap(path, v.pflags) != "" {
|
||||
|
||||
return values
|
||||
}
|
||||
|
||||
// Given a key, find the value.
|
||||
// Viper will check in the following order:
|
||||
// flag, env, config file, key/value store, default.
|
||||
// If skipDefault is set to true, find will ignore default values.
|
||||
// Viper will check to see if an alias exists first.
|
||||
// Note: this assumes a lower-cased key given.
|
||||
func (v *Viper) find(lcaseKey string, skipDefault bool) interface{} {
|
||||
path := strings.Split(lcaseKey, v.keyDelim)
|
||||
nested := len(path) > 1
|
||||
|
||||
// compute the path through the nested maps to the nested value
|
||||
if nested && v.isPathShadowedInDeepMap(path, castMapStringToMapInterface(v.aliases)) != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Env override next
|
||||
if v.automaticEnvApplied {
|
||||
// even if it hasn't been registered, if automaticEnv is used,
|
||||
// check any Get request
|
||||
if val, ok := v.getEnv(v.mergeWithEnvPrefix(lcaseKey)); ok {
|
||||
return val
|
||||
}
|
||||
if nested && v.isPathShadowedInAutoEnv(path) != "" {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
envkeys, exists := v.env[lcaseKey]
|
||||
if exists {
|
||||
for _, key := range envkeys {
|
||||
if val, ok := v.getEnv(key); ok {
|
||||
if fn, ok := v.envTransform[lcaseKey]; ok {
|
||||
return fn(val)
|
||||
}
|
||||
return val
|
||||
}
|
||||
}
|
||||
}
|
||||
if nested && v.isPathShadowedInFlatMap(path, v.env) != "" {
|
||||
return nil
|
||||
// if the requested key is an alias, then return the proper key
|
||||
lcaseKey = v.realKey(lcaseKey)
|
||||
path = strings.Split(lcaseKey, v.keyDelim)
|
||||
nested = len(path) > 1
|
||||
|
||||
getters := []func(string, []string, bool) (interface{}, bool){
|
||||
v.findFromOverride, // override == value from Set()
|
||||
v.findFromPFlag,
|
||||
v.findFromEnv,
|
||||
v.findFromConfig,
|
||||
v.findFromKVStore,
|
||||
}
|
||||
|
||||
// Config file next
|
||||
val = v.searchMapWithPathPrefixes(v.config, path)
|
||||
if val != nil {
|
||||
return val
|
||||
}
|
||||
if nested && v.isPathShadowedInDeepMap(path, v.config) != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// K/V store next
|
||||
val = v.searchMap(v.kvstore, path)
|
||||
if val != nil {
|
||||
return val
|
||||
}
|
||||
if nested && v.isPathShadowedInDeepMap(path, v.kvstore) != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Default next
|
||||
if !skipDefault {
|
||||
val = v.searchMap(v.defaults, path)
|
||||
if val != nil {
|
||||
getters = append(getters, v.findFromDefault)
|
||||
}
|
||||
|
||||
for _, getter := range getters {
|
||||
if val, found := getter(lcaseKey, path, nested); found {
|
||||
return val
|
||||
}
|
||||
if nested && v.isPathShadowedInDeepMap(path, v.defaults) != "" {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// last chance: if no other value is returned and a flag does exist for the value,
|
||||
|
|
|
@ -1703,6 +1703,41 @@ func TestParseNested(t *testing.T) {
|
|||
assert.Equal(t, 200*time.Millisecond, items[0].Nested.Delay)
|
||||
}
|
||||
|
||||
func TestUnmarshalKey(t *testing.T) {
|
||||
type testStruct struct {
|
||||
Delay int `mapstructure:"delay" yaml:"delay"`
|
||||
Port int `mapstructure:"port" yaml:"port"`
|
||||
Items []string `mapstructure:"items" yaml:"items"`
|
||||
}
|
||||
|
||||
config := `
|
||||
level_1:
|
||||
level_2:
|
||||
port: 1234
|
||||
items:
|
||||
- "test 1"
|
||||
`
|
||||
v := New()
|
||||
v.SetDefault("level_1.level_2.delay", 50)
|
||||
v.SetDefault("level_1.level_2.port", 9999)
|
||||
v.SetDefault("level_1.level_2.items", []string{})
|
||||
|
||||
initConfig(v, "yaml", config)
|
||||
|
||||
// manually overwrite some settings
|
||||
v.Set("level_1.level_2.items", []string{"test_2", "test_3"})
|
||||
|
||||
data := testStruct{}
|
||||
err := v.UnmarshalKey("level_1.level_2", &data)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to decode into struct, %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, 50, data.Delay) // value from defaults
|
||||
assert.Equal(t, 1234, data.Port) // value from config
|
||||
assert.Equal(t, []string{"test_2", "test_3"}, data.Items) // value from Set()
|
||||
}
|
||||
|
||||
func doTestCaseInsensitive(t *testing.T, typ, config string) {
|
||||
v := New()
|
||||
initConfig(v, typ, config)
|
||||
|
|
Loading…
Add table
Reference in a new issue