From d3bdf712421ca101a82b5a0e57d86fca32dfc1b5 Mon Sep 17 00:00:00 2001 From: Maxime mouial Date: Wed, 24 Apr 2024 15:44:06 +0200 Subject: [PATCH] Fix Viper 'UnmarshalKey' behavior to include all sources Viper 'UnmarshalKey' seems to have been designed to work on leaf settings only. When being called on a group of settings, it would not merge all sources (default, env vars, config, override, ...). This means that using the `Set()` method could change the returned value from 'UnmarshalKey' in unexpected ways (i.e. would no longer return the value from the configuration).% The second version of this fix correctly handles env vars override. --- viper.go | 24 +++++++++++++++++++----- viper_test.go | 42 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 5 deletions(-) diff --git a/viper.go b/viper.go index b77e682..4f1580c 100644 --- a/viper.go +++ b/viper.go @@ -865,13 +865,27 @@ func UnmarshalKey(key string, rawVal interface{}, opts ...DecoderConfigOption) e return v.UnmarshalKey(key, rawVal, opts...) } func (v *Viper) UnmarshalKey(key string, rawVal interface{}, opts ...DecoderConfigOption) error { - err := decode(v.Get(key), defaultDecoderConfig(rawVal, opts...)) + lcaseKey := strings.ToLower(key) - if err != nil { - return err + // AllSettings returns settings from every sources merged into one tree + settings := v.AllSettings() + + keyParts := strings.Split(lcaseKey, v.keyDelim) + for i := 0; i < len(keyParts)-1; i++ { + if value, found := settings[keyParts[i]]; found { + if valueMap, ok := value.(map[string]interface{}); ok { + settings = valueMap + continue + } + // if the current value is not a map[string]interface{} we most likely reach a + // leaf and the key/path is wrong + return fmt.Errorf("unknown key %s", lcaseKey) + } else { + return fmt.Errorf("unknown key %s", lcaseKey) + } } - - return nil + finalSetting := settings[keyParts[len(keyParts)-1]] + return decode(finalSetting, defaultDecoderConfig(rawVal, opts...)) } // Unmarshal unmarshals the config into a Struct. Make sure that the tags diff --git a/viper_test.go b/viper_test.go index 5507a6d..f41b413 100644 --- a/viper_test.go +++ b/viper_test.go @@ -1703,6 +1703,48 @@ func TestParseNested(t *testing.T) { assert.Equal(t, 200*time.Millisecond, items[0].Nested.Delay) } +func TestUnmarshalKey(t *testing.T) { + type testStruct struct { + Delay int `mapstructure:"delay" yaml:"delay"` + Port int `mapstructure:"port" yaml:"port"` + Host string `mapstructure:"host" yaml:"host"` + Items []string `mapstructure:"items" yaml:"items"` + } + + config := ` +level_1: + level_2: + port: 1234 + items: + - "test 1" +` + v := New() + v.SetDefault("level_1.level_2.delay", 50) + v.SetDefault("level_1.level_2.port", 9999) + v.SetDefault("level_1.level_2.host", "default") + v.SetDefault("level_1.level_2.items", []string{}) + + // Use env vars for some settings + v.BindEnv("level_1.level_2.host", "DD_TEST_HOST") + t.Setenv("DD_TEST_HOST", "dd.com") + + initConfig(v, "yaml", config) + + // manually overwrite some settings + v.Set("level_1.level_2.items", []string{"test_2", "test_3"}) + + data := testStruct{} + err := v.UnmarshalKey("level_1.level_2", &data) + if err != nil { + t.Fatalf("unable to decode into struct, %v", err) + } + + assert.Equal(t, 50, data.Delay) // value from defaults + assert.Equal(t, 1234, data.Port) // value from config + assert.Equal(t, "dd.com", data.Host) // value from env + assert.Equal(t, []string{"test_2", "test_3"}, data.Items) // value from Set() +} + func doTestCaseInsensitive(t *testing.T, typ, config string) { v := New() initConfig(v, typ, config)