added a setting to make viper case sensitive

This commit is contained in:
Emil Nikolov 2018-11-07 03:06:30 +01:00
parent b7a3b95476
commit 13b1becc90
No known key found for this signature in database
GPG key ID: 8066CB35FB614A33
2 changed files with 52 additions and 20 deletions

1
.gitignore vendored
View file

@ -27,3 +27,4 @@ _testmain.go
# exclude dependencies in the `/vendor` folder # exclude dependencies in the `/vendor` folder
vendor vendor
.idea/

View file

@ -203,6 +203,7 @@ type Viper struct {
properties *properties.Properties properties *properties.Properties
onConfigChange func(fsnotify.Event) onConfigChange func(fsnotify.Event)
caseSensitive bool
} }
// New returns an initialized Viper instance. // New returns an initialized Viper instance.
@ -219,6 +220,7 @@ func New() *Viper {
v.env = make(map[string]string) v.env = make(map[string]string)
v.aliases = make(map[string]string) v.aliases = make(map[string]string)
v.typeByDefValue = false v.typeByDefValue = false
v.caseSensitive = false
return v return v
} }
@ -277,6 +279,13 @@ func (v *Viper) OnConfigChange(run func(in fsnotify.Event)) {
v.onConfigChange = run v.onConfigChange = run
} }
func SetCaseSensitive(flag bool) {
v.SetCaseSensitive(flag)
}
func (v *Viper) SetCaseSensitive(flag bool) {
v.caseSensitive = flag
}
func WatchConfig() { v.WatchConfig() } func WatchConfig() { v.WatchConfig() }
func (v *Viper) WatchConfig() { func (v *Viper) WatchConfig() {
@ -536,7 +545,7 @@ func (v *Viper) searchMapWithPathPrefixes(source map[string]interface{}, path []
// search for path prefixes, starting from the longest one // search for path prefixes, starting from the longest one
for i := len(path); i > 0; i-- { for i := len(path); i > 0; i-- {
prefixKey := strings.ToLower(strings.Join(path[0:i], v.keyDelim)) prefixKey := v.properCase(strings.Join(path[0:i], v.keyDelim))
next, ok := source[prefixKey] next, ok := source[prefixKey]
if ok { if ok {
@ -665,7 +674,7 @@ func GetViper() *Viper {
// Get returns an interface. For a specific value use one of the Get____ methods. // Get returns an interface. For a specific value use one of the Get____ methods.
func Get(key string) interface{} { return v.Get(key) } func Get(key string) interface{} { return v.Get(key) }
func (v *Viper) Get(key string) interface{} { func (v *Viper) Get(key string) interface{} {
lcaseKey := strings.ToLower(key) lcaseKey := v.properCase(key)
val := v.find(lcaseKey) val := v.find(lcaseKey)
if val == nil { if val == nil {
return nil return nil
@ -918,7 +927,7 @@ func (v *Viper) BindFlagValue(key string, flag FlagValue) error {
if flag == nil { if flag == nil {
return fmt.Errorf("flag for %q is nil", key) return fmt.Errorf("flag for %q is nil", key)
} }
v.pflags[strings.ToLower(key)] = flag v.pflags[v.properCase(key)] = flag
return nil return nil
} }
@ -933,7 +942,7 @@ func (v *Viper) BindEnv(input ...string) error {
return fmt.Errorf("BindEnv missing key to bind to") return fmt.Errorf("BindEnv missing key to bind to")
} }
key = strings.ToLower(input[0]) key = v.properCase(input[0])
if len(input) == 1 { if len(input) == 1 {
envkey = v.mergeWithEnvPrefix(key) envkey = v.mergeWithEnvPrefix(key)
@ -1083,7 +1092,7 @@ func readAsCSV(val string) ([]string, error) {
// IsSet is case-insensitive for a key. // IsSet is case-insensitive for a key.
func IsSet(key string) bool { return v.IsSet(key) } func IsSet(key string) bool { return v.IsSet(key) }
func (v *Viper) IsSet(key string) bool { func (v *Viper) IsSet(key string) bool {
lcaseKey := strings.ToLower(key) lcaseKey := v.properCase(key)
val := v.find(lcaseKey) val := v.find(lcaseKey)
return val != nil return val != nil
} }
@ -1107,11 +1116,11 @@ func (v *Viper) SetEnvKeyReplacer(r *strings.Replacer) {
// This enables one to change a name without breaking the application // This enables one to change a name without breaking the application
func RegisterAlias(alias string, key string) { v.RegisterAlias(alias, key) } func RegisterAlias(alias string, key string) { v.RegisterAlias(alias, key) }
func (v *Viper) RegisterAlias(alias string, key string) { func (v *Viper) RegisterAlias(alias string, key string) {
v.registerAlias(alias, strings.ToLower(key)) v.registerAlias(alias, v.properCase(key))
} }
func (v *Viper) registerAlias(alias string, key string) { func (v *Viper) registerAlias(alias string, key string) {
alias = strings.ToLower(alias) alias = v.properCase(alias)
if alias != key && alias != v.realKey(key) { if alias != key && alias != v.realKey(key) {
_, exists := v.aliases[alias] _, exists := v.aliases[alias]
@ -1167,11 +1176,14 @@ func (v *Viper) InConfig(key string) bool {
func SetDefault(key string, value interface{}) { v.SetDefault(key, value) } func SetDefault(key string, value interface{}) { v.SetDefault(key, value) }
func (v *Viper) SetDefault(key string, value interface{}) { func (v *Viper) SetDefault(key string, value interface{}) {
// If alias passed in, then set the proper default // If alias passed in, then set the proper default
key = v.realKey(strings.ToLower(key)) key = v.realKey(v.properCase(key))
value = toCaseInsensitiveValue(value)
if !v.caseSensitive {
value = toCaseInsensitiveValue(value)
}
path := strings.Split(key, v.keyDelim) path := strings.Split(key, v.keyDelim)
lastKey := strings.ToLower(path[len(path)-1]) lastKey := v.properCase(path[len(path)-1])
deepestMap := deepSearch(v.defaults, path[0:len(path)-1]) deepestMap := deepSearch(v.defaults, path[0:len(path)-1])
// set innermost value // set innermost value
@ -1185,11 +1197,14 @@ func (v *Viper) SetDefault(key string, value interface{}) {
func Set(key string, value interface{}) { v.Set(key, value) } func Set(key string, value interface{}) { v.Set(key, value) }
func (v *Viper) Set(key string, value interface{}) { func (v *Viper) Set(key string, value interface{}) {
// If alias passed in, then set the proper override // If alias passed in, then set the proper override
key = v.realKey(strings.ToLower(key)) key = v.realKey(v.properCase(key))
value = toCaseInsensitiveValue(value)
if !v.caseSensitive {
value = toCaseInsensitiveValue(value)
}
path := strings.Split(key, v.keyDelim) path := strings.Split(key, v.keyDelim)
lastKey := strings.ToLower(path[len(path)-1]) lastKey := v.properCase(path[len(path)-1])
deepestMap := deepSearch(v.override, path[0:len(path)-1]) deepestMap := deepSearch(v.override, path[0:len(path)-1])
// set innermost value // set innermost value
@ -1342,7 +1357,7 @@ func (v *Viper) unmarshalReader(in io.Reader, c map[string]interface{}) error {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
buf.ReadFrom(in) buf.ReadFrom(in)
switch strings.ToLower(v.getConfigType()) { switch v.properCase(v.getConfigType()) {
case "yaml", "yml": case "yaml", "yml":
if err := yaml.Unmarshal(buf.Bytes(), &c); err != nil { if err := yaml.Unmarshal(buf.Bytes(), &c); err != nil {
return ConfigParseError{err} return ConfigParseError{err}
@ -1382,13 +1397,17 @@ func (v *Viper) unmarshalReader(in io.Reader, c map[string]interface{}) error {
value, _ := v.properties.Get(key) value, _ := v.properties.Get(key)
// recursively build nested maps // recursively build nested maps
path := strings.Split(key, ".") path := strings.Split(key, ".")
lastKey := strings.ToLower(path[len(path)-1]) lastKey := v.properCase(path[len(path)-1])
deepestMap := deepSearch(c, path[0:len(path)-1]) deepestMap := deepSearch(c, path[0:len(path)-1])
// set innermost value // set innermost value
deepestMap[lastKey] = value deepestMap[lastKey] = value
} }
} }
if v.caseSensitive {
return nil
}
insensitiviseMap(c) insensitiviseMap(c)
return nil return nil
} }
@ -1460,9 +1479,9 @@ func (v *Viper) marshalWriter(f afero.File, configType string) error {
} }
func keyExists(k string, m map[string]interface{}) string { func keyExists(k string, m map[string]interface{}) string {
lk := strings.ToLower(k) lk := v.properCase(k)
for mk := range m { for mk := range m {
lmk := strings.ToLower(mk) lmk := v.properCase(mk)
if lmk == lk { if lmk == lk {
return mk return mk
} }
@ -1572,6 +1591,10 @@ func (v *Viper) WatchRemoteConfigOnChannel() error {
} }
func (v *Viper) insensitiviseMaps() { func (v *Viper) insensitiviseMaps() {
if v.caseSensitive {
return
}
insensitiviseMap(v.config) insensitiviseMap(v.config)
insensitiviseMap(v.defaults) insensitiviseMap(v.defaults)
insensitiviseMap(v.override) insensitiviseMap(v.override)
@ -1693,7 +1716,7 @@ func (v *Viper) flattenAndMergeMap(shadow map[string]bool, m map[string]interfac
m2 = cast.ToStringMap(val) m2 = cast.ToStringMap(val)
default: default:
// immediate value // immediate value
shadow[strings.ToLower(fullKey)] = true shadow[v.properCase(fullKey)] = true
continue continue
} }
// recursively merge to shadow map // recursively merge to shadow map
@ -1719,7 +1742,7 @@ outer:
} }
} }
// add key // add key
shadow[strings.ToLower(k)] = true shadow[v.properCase(k)] = true
} }
return shadow return shadow
} }
@ -1737,7 +1760,7 @@ func (v *Viper) AllSettings() map[string]interface{} {
continue continue
} }
path := strings.Split(k, v.keyDelim) path := strings.Split(k, v.keyDelim)
lastKey := strings.ToLower(path[len(path)-1]) lastKey := v.properCase(path[len(path)-1])
deepestMap := deepSearch(m, path[0:len(path)-1]) deepestMap := deepSearch(m, path[0:len(path)-1])
// set innermost value // set innermost value
deepestMap[lastKey] = value deepestMap[lastKey] = value
@ -1827,6 +1850,14 @@ func (v *Viper) findConfigFile() (string, error) {
return "", ConfigFileNotFoundError{v.configName, fmt.Sprintf("%s", v.configPaths)} return "", ConfigFileNotFoundError{v.configName, fmt.Sprintf("%s", v.configPaths)}
} }
func (v *Viper) properCase(s string) string {
if !v.caseSensitive {
s = strings.ToLower(s)
}
return s
}
// Debug prints all configuration registries for debugging // Debug prints all configuration registries for debugging
// purposes. // purposes.
func Debug() { v.Debug() } func Debug() { v.Debug() }