From bde4ac488aa8a0c3fd3468c3d2606ba1bb85b99f Mon Sep 17 00:00:00 2001
From: Dmitry Irtegov <fat@nsu.ru>
Date: Tue, 3 Mar 2020 00:14:43 +0700
Subject: [PATCH] Add options KeyPreserveCase() and KeyNormalizer(func (string)
 string)

---
 util.go       | 40 +++++++++++-----------
 util_test.go  |  3 +-
 viper.go      | 94 +++++++++++++++++++++++++++++++++++----------------
 viper_test.go | 70 ++++++++++++++++++++++++++++++++++++++
 4 files changed, 157 insertions(+), 50 deletions(-)

diff --git a/util.go b/util.go
index 64e6575..f5272b0 100644
--- a/util.go
+++ b/util.go
@@ -31,31 +31,31 @@ func (pe ConfigParseError) Error() string {
 	return fmt.Sprintf("While parsing config: %s", pe.err.Error())
 }
 
-// toCaseInsensitiveValue checks if the value is a  map;
-// if so, create a copy and lower-case the keys recursively.
-func toCaseInsensitiveValue(value interface{}) interface{} {
+// toCaseInsensitiveValue checks if the value is a map;
+// if so, create a copy and lower-case (normalize) the keys recursively.
+func toCaseInsensitiveValue(value interface{}, normalize keyNormalizeHookType) interface{} {
 	switch v := value.(type) {
 	case map[interface{}]interface{}:
-		value = copyAndInsensitiviseMap(cast.ToStringMap(v))
+		value = copyAndInsensitiviseMap(cast.ToStringMap(v), normalize)
 	case map[string]interface{}:
-		value = copyAndInsensitiviseMap(v)
+		value = copyAndInsensitiviseMap(v, normalize)
 	}
 
 	return value
 }
 
 // copyAndInsensitiviseMap behaves like insensitiviseMap, but creates a copy of
-// any map it makes case insensitive.
-func copyAndInsensitiviseMap(m map[string]interface{}) map[string]interface{} {
+// any map it makes case insensitive (normalized).
+func copyAndInsensitiviseMap(m map[string]interface{}, normalize keyNormalizeHookType) map[string]interface{} {
 	nm := make(map[string]interface{})
 
 	for key, val := range m {
-		lkey := strings.ToLower(key)
+		lkey := normalize(key)
 		switch v := val.(type) {
 		case map[interface{}]interface{}:
-			nm[lkey] = copyAndInsensitiviseMap(cast.ToStringMap(v))
+			nm[lkey] = copyAndInsensitiviseMap(cast.ToStringMap(v), normalize)
 		case map[string]interface{}:
-			nm[lkey] = copyAndInsensitiviseMap(v)
+			nm[lkey] = copyAndInsensitiviseMap(v, normalize)
 		default:
 			nm[lkey] = v
 		}
@@ -64,26 +64,26 @@ func copyAndInsensitiviseMap(m map[string]interface{}) map[string]interface{} {
 	return nm
 }
 
-func insensitiviseVal(val interface{}) interface{} {
-	switch val.(type) {
+func insensitiviseVal(val interface{}, normalize keyNormalizeHookType) interface{} {
+	switch valT := val.(type) {
 	case map[interface{}]interface{}:
 		// nested map: cast and recursively insensitivise
 		val = cast.ToStringMap(val)
-		insensitiviseMap(val.(map[string]interface{}))
+		insensitiviseMap(val.(map[string]interface{}), normalize)
 	case map[string]interface{}:
 		// nested map: recursively insensitivise
-		insensitiviseMap(val.(map[string]interface{}))
+		insensitiviseMap(valT, normalize)
 	case []interface{}:
 		// nested array: recursively insensitivise
-		insensitiveArray(val.([]interface{}))
+		insensitiveArray(valT, normalize)
 	}
 	return val
 }
 
-func insensitiviseMap(m map[string]interface{}) {
+func insensitiviseMap(m map[string]interface{}, normalize keyNormalizeHookType) {
 	for key, val := range m {
-		val = insensitiviseVal(val)
-		lower := strings.ToLower(key)
+		val = insensitiviseVal(val, normalize)
+		lower := normalize(key)
 		if key != lower {
 			// remove old key (not lower-cased)
 			delete(m, key)
@@ -93,9 +93,9 @@ func insensitiviseMap(m map[string]interface{}) {
 	}
 }
 
-func insensitiveArray(a []interface{}) {
+func insensitiveArray(a []interface{}, normalize keyNormalizeHookType) {
 	for i, val := range a {
-		a[i] = insensitiviseVal(val)
+		a[i] = insensitiviseVal(val, normalize)
 	}
 }
 
diff --git a/util_test.go b/util_test.go
index cb4e620..889bdf7 100644
--- a/util_test.go
+++ b/util_test.go
@@ -14,6 +14,7 @@ import (
 	"os"
 	"path/filepath"
 	"reflect"
+	"strings"
 	"testing"
 
 	"github.com/spf13/viper/internal/testutil"
@@ -37,7 +38,7 @@ func TestCopyAndInsensitiviseMap(t *testing.T) {
 		}
 	)
 
-	got := copyAndInsensitiviseMap(given)
+	got := copyAndInsensitiviseMap(given, strings.ToLower)
 
 	if !reflect.DeepEqual(got, expected) {
 		t.Fatalf("Got %q\nexpected\n%q", got, expected)
diff --git a/viper.go b/viper.go
index fa6f3e3..2f20ce1 100644
--- a/viper.go
+++ b/viper.go
@@ -142,6 +142,10 @@ func DecodeHook(hook mapstructure.DecodeHookFunc) DecoderConfigOption {
 	}
 }
 
+type keyNormalizeHookType func(string) string
+
+var defaultKeyNormalizer = strings.ToLower
+
 // Viper is a prioritized configuration registry. It
 // maintains a set of configuration sources, fetches
 // values to populate those, and provides them according
@@ -183,6 +187,10 @@ type Viper struct {
 	// used to access a nested value in one go
 	keyDelim string
 
+	// Function to normalize keys
+	// by default, strings.ToLower
+	keyNormalizeHook keyNormalizeHookType
+
 	// A set of paths to look for the config file in
 	configPaths []string
 
@@ -229,6 +237,7 @@ type Viper struct {
 func New() *Viper {
 	v := new(Viper)
 	v.keyDelim = "."
+	v.keyNormalizeHook = defaultKeyNormalizer
 	v.configName = "config"
 	v.configPermissions = os.FileMode(0o644)
 	v.fs = afero.NewOsFs()
@@ -270,6 +279,23 @@ func KeyDelimiter(d string) Option {
 	})
 }
 
+// KeyNormalizer is option to set arbitrary function for key normalization
+// This function will be applied to all keys after unmarshal, during merge, search for duplicates, etc
+// Default normalizer is strings.ToLower
+func KeyNormalizer(n keyNormalizeHookType) Option {
+	return optionFunc(func(v *Viper) {
+		v.keyNormalizeHook = n
+	})
+}
+
+// KeyPreserveCase is option to disable key lowercasing
+// By default, Viper converts all keys to lovercase
+func KeyPreserveCase() Option {
+	return optionFunc(func(v *Viper) {
+		v.keyNormalizeHook = func(key string) string { return key }
+	})
+}
+
 // StringReplacer applies a set of replacements to a string.
 type StringReplacer interface {
 	// Replace returns a copy of s with all replacements performed.
@@ -523,6 +549,13 @@ func (v *Viper) SetEnvPrefix(in string) {
 	}
 }
 
+func (v *Viper) keyNormalize(k string) string {
+	if v.keyNormalizeHook != nil {
+		return v.keyNormalizeHook(k)
+	}
+	return defaultKeyNormalizer(k)
+}
+
 func (v *Viper) mergeWithEnvPrefix(in string) string {
 	if v.envPrefix != "" {
 		return strings.ToUpper(v.envPrefix + "_" + in)
@@ -652,7 +685,7 @@ func (v *Viper) providerPathExists(p *defaultRemoteProvider) bool {
 
 // searchMap recursively searches for a value for path in source map.
 // Returns nil if not found.
-// Note: This assumes that the path entries and map keys are lower cased.
+// Note: This assumes that the path entries and map keys are normalized (by default, lowercased).
 func (v *Viper) searchMap(source map[string]interface{}, path []string) interface{} {
 	if len(path) == 0 {
 		return source
@@ -691,7 +724,7 @@ func (v *Viper) searchMap(source map[string]interface{}, path []string) interfac
 // This should be useful only at config level (other maps may not contain dots
 // in their keys).
 //
-// Note: This assumes that the path entries and map keys are lower cased.
+// Note: This assumes that the path entries and map keys are lower cased (normalized).
 func (v *Viper) searchIndexableWithPathPrefixes(source interface{}, path []string) interface{} {
 	if len(path) == 0 {
 		return source
@@ -699,7 +732,7 @@ func (v *Viper) searchIndexableWithPathPrefixes(source interface{}, path []strin
 
 	// search for path prefixes, starting from the longest one
 	for i := len(path); i > 0; i-- {
-		prefixKey := strings.ToLower(strings.Join(path[0:i], v.keyDelim))
+		prefixKey := v.keyNormalize(strings.Join(path[0:i], v.keyDelim))
 
 		var val interface{}
 		switch sourceIndexable := source.(type) {
@@ -890,7 +923,7 @@ func GetViper() *Viper {
 func Get(key string) interface{} { return v.Get(key) }
 
 func (v *Viper) Get(key string) interface{} {
-	lcaseKey := strings.ToLower(key)
+	lcaseKey := v.keyNormalize(key)
 	val := v.find(lcaseKey, true)
 	if val == nil {
 		return nil
@@ -950,7 +983,7 @@ func (v *Viper) Sub(key string) *Viper {
 	}
 
 	if reflect.TypeOf(data).Kind() == reflect.Map {
-		subv.parents = append(v.parents, strings.ToLower(key))
+		subv.parents = append(v.parents, v.keyNormalize(key))
 		subv.automaticEnvApplied = v.automaticEnvApplied
 		subv.envPrefix = v.envPrefix
 		subv.envKeyReplacer = v.envKeyReplacer
@@ -1189,7 +1222,7 @@ func (v *Viper) BindFlagValue(key string, flag FlagValue) error {
 	if flag == nil {
 		return fmt.Errorf("flag for %q is nil", key)
 	}
-	v.pflags[strings.ToLower(key)] = flag
+	v.pflags[v.keyNormalize(key)] = flag
 	return nil
 }
 
@@ -1206,7 +1239,7 @@ func (v *Viper) BindEnv(input ...string) error {
 		return fmt.Errorf("missing key to bind to")
 	}
 
-	key := strings.ToLower(input[0])
+	key := v.keyNormalize(input[0])
 
 	if len(input) == 1 {
 		v.env[key] = append(v.env[key], v.mergeWithEnvPrefix(key))
@@ -1236,7 +1269,7 @@ func (v *Viper) MustBindEnv(input ...string) {
 // Lastly, if no value was found and flagDefault is true, and if the key
 // corresponds to a flag, the flag's default value is returned.
 //
-// Note: this assumes a lower-cased key given.
+// Note: this assumes a normalized key given.
 func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} {
 	var (
 		val    interface{}
@@ -1419,12 +1452,12 @@ func stringToStringConv(val string) interface{} {
 }
 
 // IsSet checks to see if the key has been set in any of the data locations.
-// IsSet is case-insensitive for a key.
+// IsSet normalizes the key.
 func IsSet(key string) bool { return v.IsSet(key) }
 
 func (v *Viper) IsSet(key string) bool {
-	lcaseKey := strings.ToLower(key)
-	val := v.find(lcaseKey, false)
+	normKey := v.keyNormalize(key)
+	val := v.find(normKey, false)
 	return val != nil
 }
 
@@ -1450,11 +1483,11 @@ func (v *Viper) SetEnvKeyReplacer(r *strings.Replacer) {
 func RegisterAlias(alias string, key string) { v.RegisterAlias(alias, key) }
 
 func (v *Viper) RegisterAlias(alias string, key string) {
-	v.registerAlias(alias, strings.ToLower(key))
+	v.registerAlias(alias, v.keyNormalize(key))
 }
 
 func (v *Viper) registerAlias(alias string, key string) {
-	alias = strings.ToLower(alias)
+	alias = v.keyNormalize(alias)
 	if alias != key && alias != v.realKey(key) {
 		_, exists := v.aliases[alias]
 
@@ -1499,7 +1532,7 @@ func (v *Viper) realKey(key string) string {
 func InConfig(key string) bool { return v.InConfig(key) }
 
 func (v *Viper) InConfig(key string) bool {
-	lcaseKey := strings.ToLower(key)
+	lcaseKey := v.keyNormalize(key)
 
 	// if the requested key is an alias, then return the proper key
 	lcaseKey = v.realKey(lcaseKey)
@@ -1509,17 +1542,18 @@ func (v *Viper) InConfig(key string) bool {
 }
 
 // SetDefault sets the default value for this key.
-// SetDefault is case-insensitive for a key.
+// SetDefault applies normalization (by default, lowercases) a key.
 // Default only used when no value is provided by the user via flag, config or ENV.
 func SetDefault(key string, value interface{}) { v.SetDefault(key, value) }
 
 func (v *Viper) SetDefault(key string, value interface{}) {
 	// If alias passed in, then set the proper default
-	key = v.realKey(strings.ToLower(key))
-	value = toCaseInsensitiveValue(value)
+	key = v.keyNormalize(key)
+	value = toCaseInsensitiveValue(value, v.keyNormalize)
+	key = v.realKey(key)
 
 	path := strings.Split(key, v.keyDelim)
-	lastKey := strings.ToLower(path[len(path)-1])
+	lastKey := v.keyNormalize(path[len(path)-1])
 	deepestMap := deepSearch(v.defaults, path[0:len(path)-1])
 
 	// set innermost value
@@ -1527,18 +1561,19 @@ func (v *Viper) SetDefault(key string, value interface{}) {
 }
 
 // Set sets the value for the key in the override register.
-// Set is case-insensitive for a key.
+// Set normalizes a key.
 // Will be used instead of values obtained via
 // flags, config file, ENV, default, or key/value store.
 func Set(key string, value interface{}) { v.Set(key, value) }
 
 func (v *Viper) Set(key string, value interface{}) {
 	// If alias passed in, then set the proper override
-	key = v.realKey(strings.ToLower(key))
-	value = toCaseInsensitiveValue(value)
+	key = v.keyNormalize(key)
+	value = toCaseInsensitiveValue(value, v.keyNormalize)
+	key = v.realKey(key)
 
 	path := strings.Split(key, v.keyDelim)
-	lastKey := strings.ToLower(path[len(path)-1])
+	lastKey := v.keyNormalize(path[len(path)-1])
 	deepestMap := deepSearch(v.override, path[0:len(path)-1])
 
 	// set innermost value
@@ -1627,7 +1662,7 @@ func (v *Viper) MergeConfigMap(cfg map[string]interface{}) error {
 	if v.config == nil {
 		v.config = make(map[string]interface{})
 	}
-	insensitiviseMap(cfg)
+	insensitiviseMap(cfg, v.keyNormalize)
 	mergeMaps(cfg, v.config, nil)
 	return nil
 }
@@ -1727,7 +1762,8 @@ func (v *Viper) unmarshalReader(in io.Reader, c map[string]interface{}) error {
 		}
 	}
 
-	insensitiviseMap(c)
+	insensitiviseMap(c, v.keyNormalize)
+
 	return nil
 }
 
@@ -1750,9 +1786,9 @@ func (v *Viper) marshalWriter(f afero.File, configType string) error {
 }
 
 func keyExists(k string, m map[string]interface{}) string {
-	lk := strings.ToLower(k)
+	lk := v.keyNormalize(k)
 	for mk := range m {
-		lmk := strings.ToLower(mk)
+		lmk := v.keyNormalize(mk)
 		if lmk == lk {
 			return mk
 		}
@@ -2031,7 +2067,7 @@ func (v *Viper) flattenAndMergeMap(shadow map[string]bool, m map[string]interfac
 			m2 = cast.ToStringMap(val)
 		default:
 			// immediate value
-			shadow[strings.ToLower(fullKey)] = true
+			shadow[v.keyNormalize(fullKey)] = true
 			continue
 		}
 		// recursively merge to shadow map
@@ -2057,7 +2093,7 @@ outer:
 			}
 		}
 		// add key
-		shadow[strings.ToLower(k)] = true
+		shadow[v.keyNormalize(k)] = true
 	}
 	return shadow
 }
@@ -2076,7 +2112,7 @@ func (v *Viper) AllSettings() map[string]interface{} {
 			continue
 		}
 		path := strings.Split(k, v.keyDelim)
-		lastKey := strings.ToLower(path[len(path)-1])
+		lastKey := v.keyNormalize(path[len(path)-1])
 		deepestMap := deepSearch(m, path[0:len(path)-1])
 		// set innermost value
 		deepestMap[lastKey] = value
diff --git a/viper_test.go b/viper_test.go
index b867337..871ac83 100644
--- a/viper_test.go
+++ b/viper_test.go
@@ -2334,6 +2334,76 @@ func TestCaseInsensitiveSet(t *testing.T) {
 	}
 }
 
+func TestCaseSensitive(t *testing.T) {
+	for _, config := range []struct {
+		typ     string
+		content string
+	}{
+		{"yaml", `
+aBcD: 1
+eF:
+  gH: 2
+  iJk: 3
+  Lm:
+    nO: 4
+    P:
+      Q: 5
+      R: 6
+`},
+		{"json", `{
+  "aBcD": 1,
+  "eF": {
+    "iJk": 3,
+    "Lm": {
+      "P": {
+        "Q": 5,
+        "R": 6
+      },
+      "nO": 4
+    },
+    "gH": 2
+  }
+}`},
+		{"toml", `aBcD = 1
+[eF]
+gH = 2
+iJk = 3
+[eF.Lm]
+nO = 4
+[eF.Lm.P]
+Q = 5
+R = 6
+`},
+	} {
+		doTestCaseSensitive(t, config.typ, config.content)
+	}
+}
+
+func doTestCaseSensitive(t *testing.T, typ, config string) {
+	// Create case-sensitive instance
+	v := NewWithOptions(KeyPreserveCase())
+	v.SetConfigType(typ)
+
+	r := strings.NewReader(config)
+	if err := v.unmarshalReader(r, v.config); err != nil {
+		panic(err)
+	}
+
+	v.Set("RfD", true)
+	assert.Equal(t, nil, v.Get("rfd"))
+	assert.Equal(t, true, v.Get("RfD"))
+	assert.Equal(t, 0, cast.ToInt(v.Get("abcd")))
+	assert.Equal(t, 1, cast.ToInt(v.Get("aBcD")))
+	assert.Equal(t, 0, cast.ToInt(v.Get("ef.gh")))
+	assert.Equal(t, 2, cast.ToInt(v.Get("eF.gH")))
+	assert.Equal(t, 0, cast.ToInt(v.Get("ef.ijk")))
+	assert.Equal(t, 3, cast.ToInt(v.Get("eF.iJk")))
+	assert.Equal(t, 0, cast.ToInt(v.Get("ef.lm.no")))
+	assert.Equal(t, 4, cast.ToInt(v.Get("eF.Lm.nO")))
+	assert.Equal(t, 0, cast.ToInt(v.Get("ef.lm.p.q")))
+	assert.Equal(t, 5, cast.ToInt(v.Get("eF.Lm.P.Q")))
+}
+
 func TestParseNested(t *testing.T) {
 	type duration struct {
 		Delay time.Duration