mirror of
https://github.com/spf13/viper
synced 2025-05-05 19:57:18 +00:00
Fix Viper 'UnmarshalKey' behavior to include all sources (#32)
Fix Viper 'UnmarshalKey' behavior to include all sources Viper 'UnmarshalKey' seems to have been designed to work on leaf settings only. When being called on a group of settings, it would not merge all sources (default, env vars, config, override, ...). This means that using the `Set()` method could change the returned value from 'UnmarshalKey' in unexpected ways (i.e. would no longer return the value from the configuration).
This commit is contained in:
parent
3e7837fd38
commit
c3db1c77f1
2 changed files with 188 additions and 88 deletions
243
viper.go
243
viper.go
|
@ -864,12 +864,20 @@ func (v *Viper) GetSizeInBytesE(key string) (uint, error) {
|
||||||
func UnmarshalKey(key string, rawVal interface{}, opts ...DecoderConfigOption) error {
|
func UnmarshalKey(key string, rawVal interface{}, opts ...DecoderConfigOption) error {
|
||||||
return v.UnmarshalKey(key, rawVal, opts...)
|
return v.UnmarshalKey(key, rawVal, opts...)
|
||||||
}
|
}
|
||||||
func (v *Viper) UnmarshalKey(key string, rawVal interface{}, opts ...DecoderConfigOption) error {
|
|
||||||
err := decode(v.Get(key), defaultDecoderConfig(rawVal, opts...))
|
|
||||||
|
|
||||||
|
func (v *Viper) UnmarshalKey(key string, rawVal interface{}, opts ...DecoderConfigOption) error {
|
||||||
|
// We first get all value we have for a key (from default to override)
|
||||||
|
lcaseKey := strings.ToLower(key)
|
||||||
|
values := v.findAll(lcaseKey)
|
||||||
|
|
||||||
|
// findAll returns all the value for a settings from highest priority to lowest. So when aggregating them we want
|
||||||
|
// to start at the lowest priority and override with higher ones.
|
||||||
|
for idx := len(values) - 1; idx >= 0; idx-- {
|
||||||
|
err := decode(values[idx], defaultDecoderConfig(rawVal, opts...))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -1001,20 +1009,106 @@ func (v *Viper) BindEnv(input ...string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Given a key, find the value.
|
// findFromOverride finds a key from override (ie: from 'Set()')
|
||||||
// Viper will check in the following order:
|
func (v *Viper) findFromOverride(_ string, path []string, nested bool) (interface{}, bool) {
|
||||||
// flag, env, config file, key/value store, default.
|
if val := v.searchMap(v.override, path); val != nil {
|
||||||
// If skipDefault is set to true, find will ignore default values.
|
return val, true
|
||||||
// Viper will check to see if an alias exists first.
|
}
|
||||||
// Note: this assumes a lower-cased key given.
|
if nested && v.isPathShadowedInDeepMap(path, v.override) != "" {
|
||||||
func (v *Viper) find(lcaseKey string, skipDefault bool) interface{} {
|
return nil, true
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
func (v *Viper) findFromPFlag(lcaseKey string, path []string, nested bool) (interface{}, bool) {
|
||||||
val interface{}
|
flag, exists := v.pflags[lcaseKey]
|
||||||
exists bool
|
if exists && flag.HasChanged() {
|
||||||
path = strings.Split(lcaseKey, v.keyDelim)
|
switch flag.ValueType() {
|
||||||
nested = len(path) > 1
|
case "int", "int8", "int16", "int32", "int64":
|
||||||
)
|
return cast.ToInt(flag.ValueString()), true
|
||||||
|
case "bool":
|
||||||
|
return cast.ToBool(flag.ValueString()), true
|
||||||
|
case "stringSlice":
|
||||||
|
s := strings.TrimPrefix(flag.ValueString(), "[")
|
||||||
|
s = strings.TrimSuffix(s, "]")
|
||||||
|
res, _ := readAsCSV(s)
|
||||||
|
return res, true
|
||||||
|
default:
|
||||||
|
return flag.ValueString(), true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if nested && v.isPathShadowedInFlatMap(path, v.pflags) != "" {
|
||||||
|
return nil, true
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *Viper) findFromEnv(lcaseKey string, path []string, nested bool) (interface{}, bool) {
|
||||||
|
if v.automaticEnvApplied {
|
||||||
|
// even if it hasn't been registered, if automaticEnv is used,
|
||||||
|
// check any Get request
|
||||||
|
if val, ok := v.getEnv(v.mergeWithEnvPrefix(lcaseKey)); ok {
|
||||||
|
return val, true
|
||||||
|
}
|
||||||
|
if nested && v.isPathShadowedInAutoEnv(path) != "" {
|
||||||
|
return nil, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
envkeys, exists := v.env[lcaseKey]
|
||||||
|
if exists {
|
||||||
|
for _, key := range envkeys {
|
||||||
|
if val, ok := v.getEnv(key); ok {
|
||||||
|
if fn, ok := v.envTransform[lcaseKey]; ok {
|
||||||
|
return fn(val), true
|
||||||
|
}
|
||||||
|
return val, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if nested && v.isPathShadowedInFlatMap(path, v.env) != "" {
|
||||||
|
return nil, true
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *Viper) findFromConfig(_ string, path []string, nested bool) (interface{}, bool) {
|
||||||
|
val := v.searchMapWithPathPrefixes(v.config, path)
|
||||||
|
if val != nil {
|
||||||
|
return val, true
|
||||||
|
}
|
||||||
|
if nested && v.isPathShadowedInDeepMap(path, v.config) != "" {
|
||||||
|
return nil, true
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *Viper) findFromKVStore(_ string, path []string, nested bool) (interface{}, bool) {
|
||||||
|
val := v.searchMap(v.kvstore, path)
|
||||||
|
if val != nil {
|
||||||
|
return val, true
|
||||||
|
}
|
||||||
|
if nested && v.isPathShadowedInDeepMap(path, v.kvstore) != "" {
|
||||||
|
return nil, true
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *Viper) findFromDefault(_ string, path []string, nested bool) (interface{}, bool) {
|
||||||
|
val := v.searchMap(v.defaults, path)
|
||||||
|
if val != nil {
|
||||||
|
return val, true
|
||||||
|
}
|
||||||
|
if nested && v.isPathShadowedInDeepMap(path, v.defaults) != "" {
|
||||||
|
return nil, true
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// findAll returns all the value for a settings from highest priority to lowest
|
||||||
|
func (v *Viper) findAll(lcaseKey string) []interface{} {
|
||||||
|
var values []interface{}
|
||||||
|
path := strings.Split(lcaseKey, v.keyDelim)
|
||||||
|
nested := len(path) > 1
|
||||||
|
|
||||||
// compute the path through the nested maps to the nested value
|
// compute the path through the nested maps to the nested value
|
||||||
if nested && v.isPathShadowedInDeepMap(path, castMapStringToMapInterface(v.aliases)) != "" {
|
if nested && v.isPathShadowedInDeepMap(path, castMapStringToMapInterface(v.aliases)) != "" {
|
||||||
|
@ -1026,88 +1120,59 @@ func (v *Viper) find(lcaseKey string, skipDefault bool) interface{} {
|
||||||
path = strings.Split(lcaseKey, v.keyDelim)
|
path = strings.Split(lcaseKey, v.keyDelim)
|
||||||
nested = len(path) > 1
|
nested = len(path) > 1
|
||||||
|
|
||||||
// Set() override first
|
getters := []func(string, []string, bool) (interface{}, bool){
|
||||||
val = v.searchMap(v.override, path)
|
v.findFromOverride, // override == value from Set()
|
||||||
if val != nil {
|
v.findFromPFlag,
|
||||||
return val
|
v.findFromEnv,
|
||||||
|
v.findFromConfig,
|
||||||
|
v.findFromKVStore,
|
||||||
|
v.findFromDefault,
|
||||||
}
|
}
|
||||||
if nested && v.isPathShadowedInDeepMap(path, v.override) != "" {
|
|
||||||
|
for _, getter := range getters {
|
||||||
|
if val, found := getter(lcaseKey, path, nested); found {
|
||||||
|
values = append(values, val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return values
|
||||||
|
}
|
||||||
|
|
||||||
|
// Given a key, find the value.
|
||||||
|
// Viper will check in the following order:
|
||||||
|
// flag, env, config file, key/value store, default.
|
||||||
|
// If skipDefault is set to true, find will ignore default values.
|
||||||
|
// Viper will check to see if an alias exists first.
|
||||||
|
// Note: this assumes a lower-cased key given.
|
||||||
|
func (v *Viper) find(lcaseKey string, skipDefault bool) interface{} {
|
||||||
|
path := strings.Split(lcaseKey, v.keyDelim)
|
||||||
|
nested := len(path) > 1
|
||||||
|
|
||||||
|
// compute the path through the nested maps to the nested value
|
||||||
|
if nested && v.isPathShadowedInDeepMap(path, castMapStringToMapInterface(v.aliases)) != "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// PFlag override next
|
// if the requested key is an alias, then return the proper key
|
||||||
flag, exists := v.pflags[lcaseKey]
|
lcaseKey = v.realKey(lcaseKey)
|
||||||
if exists && flag.HasChanged() {
|
path = strings.Split(lcaseKey, v.keyDelim)
|
||||||
switch flag.ValueType() {
|
nested = len(path) > 1
|
||||||
case "int", "int8", "int16", "int32", "int64":
|
|
||||||
return cast.ToInt(flag.ValueString())
|
getters := []func(string, []string, bool) (interface{}, bool){
|
||||||
case "bool":
|
v.findFromOverride, // override == value from Set()
|
||||||
return cast.ToBool(flag.ValueString())
|
v.findFromPFlag,
|
||||||
case "stringSlice":
|
v.findFromEnv,
|
||||||
s := strings.TrimPrefix(flag.ValueString(), "[")
|
v.findFromConfig,
|
||||||
s = strings.TrimSuffix(s, "]")
|
v.findFromKVStore,
|
||||||
res, _ := readAsCSV(s)
|
|
||||||
return res
|
|
||||||
default:
|
|
||||||
return flag.ValueString()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if nested && v.isPathShadowedInFlatMap(path, v.pflags) != "" {
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Env override next
|
|
||||||
if v.automaticEnvApplied {
|
|
||||||
// even if it hasn't been registered, if automaticEnv is used,
|
|
||||||
// check any Get request
|
|
||||||
if val, ok := v.getEnv(v.mergeWithEnvPrefix(lcaseKey)); ok {
|
|
||||||
return val
|
|
||||||
}
|
|
||||||
if nested && v.isPathShadowedInAutoEnv(path) != "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
envkeys, exists := v.env[lcaseKey]
|
|
||||||
if exists {
|
|
||||||
for _, key := range envkeys {
|
|
||||||
if val, ok := v.getEnv(key); ok {
|
|
||||||
if fn, ok := v.envTransform[lcaseKey]; ok {
|
|
||||||
return fn(val)
|
|
||||||
}
|
|
||||||
return val
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if nested && v.isPathShadowedInFlatMap(path, v.env) != "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Config file next
|
|
||||||
val = v.searchMapWithPathPrefixes(v.config, path)
|
|
||||||
if val != nil {
|
|
||||||
return val
|
|
||||||
}
|
|
||||||
if nested && v.isPathShadowedInDeepMap(path, v.config) != "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// K/V store next
|
|
||||||
val = v.searchMap(v.kvstore, path)
|
|
||||||
if val != nil {
|
|
||||||
return val
|
|
||||||
}
|
|
||||||
if nested && v.isPathShadowedInDeepMap(path, v.kvstore) != "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Default next
|
|
||||||
if !skipDefault {
|
if !skipDefault {
|
||||||
val = v.searchMap(v.defaults, path)
|
getters = append(getters, v.findFromDefault)
|
||||||
if val != nil {
|
|
||||||
return val
|
|
||||||
}
|
}
|
||||||
if nested && v.isPathShadowedInDeepMap(path, v.defaults) != "" {
|
|
||||||
return nil
|
for _, getter := range getters {
|
||||||
|
if val, found := getter(lcaseKey, path, nested); found {
|
||||||
|
return val
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1703,6 +1703,41 @@ func TestParseNested(t *testing.T) {
|
||||||
assert.Equal(t, 200*time.Millisecond, items[0].Nested.Delay)
|
assert.Equal(t, 200*time.Millisecond, items[0].Nested.Delay)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalKey(t *testing.T) {
|
||||||
|
type testStruct struct {
|
||||||
|
Delay int `mapstructure:"delay" yaml:"delay"`
|
||||||
|
Port int `mapstructure:"port" yaml:"port"`
|
||||||
|
Items []string `mapstructure:"items" yaml:"items"`
|
||||||
|
}
|
||||||
|
|
||||||
|
config := `
|
||||||
|
level_1:
|
||||||
|
level_2:
|
||||||
|
port: 1234
|
||||||
|
items:
|
||||||
|
- "test 1"
|
||||||
|
`
|
||||||
|
v := New()
|
||||||
|
v.SetDefault("level_1.level_2.delay", 50)
|
||||||
|
v.SetDefault("level_1.level_2.port", 9999)
|
||||||
|
v.SetDefault("level_1.level_2.items", []string{})
|
||||||
|
|
||||||
|
initConfig(v, "yaml", config)
|
||||||
|
|
||||||
|
// manually overwrite some settings
|
||||||
|
v.Set("level_1.level_2.items", []string{"test_2", "test_3"})
|
||||||
|
|
||||||
|
data := testStruct{}
|
||||||
|
err := v.UnmarshalKey("level_1.level_2", &data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to decode into struct, %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, 50, data.Delay) // value from defaults
|
||||||
|
assert.Equal(t, 1234, data.Port) // value from config
|
||||||
|
assert.Equal(t, []string{"test_2", "test_3"}, data.Items) // value from Set()
|
||||||
|
}
|
||||||
|
|
||||||
func doTestCaseInsensitive(t *testing.T, typ, config string) {
|
func doTestCaseInsensitive(t *testing.T, typ, config string) {
|
||||||
v := New()
|
v := New()
|
||||||
initConfig(v, typ, config)
|
initConfig(v, typ, config)
|
||||||
|
|
Loading…
Add table
Reference in a new issue