diff --git a/README.md b/README.md index f451881..a42c6f1 100644 --- a/README.md +++ b/README.md @@ -143,6 +143,42 @@ if err := viper.ReadInConfig(); err != nil { *NOTE [since 1.6]:* You can also have a file without an extension and specify the format programmatically. For those configuration files that lie in the home of the user without any extension like `.bashrc` +### MapTo +- source file +```yaml +service: + port: 1234 + ip: "127.0.0.1" +version: 1.0.01 +``` +- use MapTo + +```go +type Service struct { + Port int `viper:"port"` + IP string `viper:"ip"` +} +//your prepare code ... + +var service Service +var version string + +if err := viper.MapTo("service",&service); err != nil { + //error handler... +} +if err := viper.MapTo("version",&version); err != nil { + //error handler... +} + + + +log.Println(service,version) + +//.... + +``` + + ### Writing Config Files Reading from config files is useful, but at times you want to store all modifications made at run time. @@ -875,7 +911,16 @@ application foundation needs. Is there a better name for a [commander](http://en.wikipedia.org/wiki/Cobra_Commander)? ### Does Viper support case sensitive keys? +#### [FEATURE] surport case sensitive +```go +// if you want to keep case insensitive, you can do nothing +// but if you want to make it case sensitive, please do the following step +func main(){ + viper.SetCaseSensitive() + // your code next... +} +``` **tl;dr:** No. Viper merges configuration from various sources, many of which are either case insensitive or uses different casing than the rest of the sources (eg. env vars). diff --git a/internal/convert/convert.go b/internal/convert/convert.go new file mode 100644 index 0000000..53d5314 --- /dev/null +++ b/internal/convert/convert.go @@ -0,0 +1,247 @@ +package convert + +import ( + "fmt" + "reflect" + "strings" +) + +var convertUtils = map[reflect.Kind]func(reflect.Value, reflect.Value) error{ + reflect.String: converNormal, + reflect.Int: converNormal, + reflect.Int16: converNormal, + reflect.Int32: converNormal, + reflect.Int64: converNormal, + reflect.Uint: converNormal, + reflect.Uint16: converNormal, + reflect.Uint32: converNormal, + reflect.Uint64: converNormal, + reflect.Float32: converNormal, + reflect.Float64: converNormal, + reflect.Uint8: converNormal, + reflect.Int8: converNormal, + reflect.Bool: converNormal, +} + +//Convert 类型强制转换 +//示例 +/* + type Target struct { + A int `json:"aint"` + B string `json:"bstr"` + } + src :=map[string]interface{}{ + "aint":1224, + "bstr":"124132" + } + + var t Target + Convert(src,&t) + +*/ +//fix循环引用的问题 +var _ = func() struct{} { + convertUtils[reflect.Map] = convertMap + convertUtils[reflect.Array] = convertSlice + convertUtils[reflect.Slice] = convertSlice + return struct{}{} +}() + +func Convert(src interface{}, dst interface{}) (err error) { + defer func() { + if v := recover(); v != nil { + err = fmt.Errorf("panic recover:%v", v) + } + }() + + dstRef := reflect.ValueOf(dst) + if dstRef.Kind() != reflect.Ptr { + return fmt.Errorf("dst is not ptr") + } + + dstRef = reflect.Indirect(dstRef) + + srcRef := reflect.ValueOf(src) + if srcRef.Kind() == reflect.Ptr || srcRef.Kind() == reflect.Interface { + srcRef = srcRef.Elem() + } + if f, ok := convertUtils[srcRef.Kind()]; ok { + return f(srcRef, dstRef) + } + + return fmt.Errorf("no implemented:%s", srcRef.Type()) +} + +func converNormal(src reflect.Value, dst reflect.Value) error { + if dst.CanSet() { + if src.Type() == dst.Type() { + dst.Set(src) + } else if src.CanConvert(dst.Type()) { + dst.Set(src.Convert(dst.Type())) + } else { + return fmt.Errorf("can not convert:%s:%s", src.Type().String(), dst.Type().String()) + } + } + return nil +} + +func convertSlice(src reflect.Value, dst reflect.Value) error { + if dst.Kind() != reflect.Array && dst.Kind() != reflect.Slice { + return fmt.Errorf("error type:%s", dst.Type().String()) + } else if !src.IsValid() { + return nil + } + + l := src.Len() + target := reflect.MakeSlice(dst.Type(), l, l) + if dst.CanSet() { + dst.Set(target) + } + for i := 0; i < l; i++ { + srcValue := src.Index(i) + if srcValue.Kind() == reflect.Ptr || srcValue.Kind() == reflect.Interface { + srcValue = srcValue.Elem() + } + if f, ok := convertUtils[srcValue.Kind()]; ok { + err := f(srcValue, dst.Index(i)) + if err != nil { + return err + } + } + } + + return nil +} + +func convertMap(src reflect.Value, dst reflect.Value) error { + // + if src.Kind() == reflect.Ptr || src.Kind() == reflect.Interface { + src = src.Elem() + } + if src.Kind() != reflect.Map || dst.Kind() != reflect.Struct { + if dst.Kind() == reflect.Map { + return converMapToMap(src, dst) + } + if !(dst.Kind() == reflect.Ptr && dst.Type().Elem().Kind() == reflect.Struct) { + if dst.Kind() == reflect.Interface && dst.CanSet() { + dst.Set(src) + return nil + } + return fmt.Errorf("src or dst type error:%s,%s", src.Kind().String(), dst.Type().String()) + } + if !reflect.Indirect(dst).IsValid() { + v := reflect.New(dst.Type().Elem()) + dst.Set(v) + } + dst = reflect.Indirect(dst) + } + dstType := dst.Type() + num := dstType.NumField() + exist := map[string]int{} + for i := 0; i < num; i++ { + k := dstType.Field(i).Tag.Get("viper") + if k == "" { + k = dstType.Field(i).Name + } + if strings.Contains(k, ",") { + taglist := strings.Split(k, ",") + if taglist[0] == "" { + if len(taglist) == 2 && + taglist[1] == "inline" { + v := dst.Field(i) + + err := convertMap(src, v) + if err != nil { + return err + } + dst.Field(i).Set(v) + continue + } else { + k = dstType.Field(i).Name + } + } else { + k = taglist[0] + + } + + } + exist[k] = i + } + + keys := src.MapKeys() + for _, key := range keys { + if index, ok := exist[key.String()]; ok { + v := dst.Field(index) + if v.Kind() == reflect.Struct { + err := convertMap(src.MapIndex(key), v) + if err != nil { + return err + } + } else if v.Kind() == reflect.Slice { + err := convertSlice(src.MapIndex(key).Elem(), v) + if err != nil { + return err + } + + } else { + if v.CanSet() && src.MapIndex(key).IsValid() && !src.MapIndex(key).IsZero() { + if v.Type() == src.MapIndex(key).Elem().Type() { + v.Set(src.MapIndex(key).Elem()) + } else if src.MapIndex(key).Elem().CanConvert(v.Type()) { + v.Set(src.MapIndex(key).Elem().Convert(v.Type())) + } else if f, ok := convertUtils[src.MapIndex(key).Elem().Kind()]; ok && f != nil { + err := f(src.MapIndex(key).Elem(), v) + if err != nil { + return err + } + } else { + return fmt.Errorf("error type:d(%s)s(%s)", v.Type(), src.MapIndex(key).Elem().Type()) + } + } + } + } + } + + return nil +} + +func converMapToMap(src reflect.Value, dst reflect.Value) error { + if src.Kind() != reflect.Map || dst.Kind() != reflect.Map { + return fmt.Errorf("type error: src(%v),dst(%v)", src.Kind(), src.Kind()) + } + mv := reflect.MakeMap(dst.Type()) + keys := src.MapKeys() + dt := dst.Type().Elem().Kind() + for _, key := range keys { + if dt == reflect.Struct { + me := reflect.New(dst.Type().Elem()) + me = reflect.Indirect(me) + convertMap(src.MapIndex(key).Elem(), me) + mv.SetMapIndex(key, me) + } else if dt == reflect.Ptr { + me := reflect.New(dst.Type().Elem().Elem()) + me = reflect.Indirect(me) + convertMap(src.MapIndex(key).Elem(), me) + mv.SetMapIndex(key, me.Addr()) + } else if dt == reflect.Slice { + l := src.MapIndex(key).Elem().Len() + v := reflect.MakeSlice(dst.Type().Elem(), l, l) + err := convertSlice(src.MapIndex(key).Elem(), v) + if err != nil { + return err + } + mv.SetMapIndex(key, v) + } else { + if src.MapIndex(key).Elem().Kind() != dst.Type().Elem().Kind() && + src.MapIndex(key).Elem().CanConvert(dst.Type().Elem()) { + v := src.MapIndex(key).Elem().Convert(dst.Type().Elem()) + mv.SetMapIndex(key, v) + continue + } + + mv.SetMapIndex(key, src.MapIndex(key).Elem()) + } + } + dst.Set(mv) + return nil +} diff --git a/internal/convert/convert_test.go b/internal/convert/convert_test.go new file mode 100644 index 0000000..ae98129 --- /dev/null +++ b/internal/convert/convert_test.go @@ -0,0 +1,108 @@ +package convert + +import ( + "testing" +) + +func TestConvert(t *testing.T) { + type Tmp1 struct { + Str string `viper:"str"` + I8 int8 `viper:"i8"` + Int16 int16 `viper:"i16"` + Int32 int32 `viper:"i32"` + Int64 int64 `viper:"i64"` + I int `viper:"i"` + U8 int8 `viper:"u8"` + Uint16 int16 `viper:"u16"` + Uint32 int32 `viper:"u32"` + Uint64 int64 `viper:"u64"` + U int `viper:"u"` + F32 float32 `viper:"f32"` + F64 float64 `viper:"f64"` + TF bool `viper:"tf"` + M map[string]interface{} `viper:"m"` + S []interface{} `viper:"s"` + } + tc := map[string]interface{}{ + "str": "Hello world", + "i8": -8, + "i16": -16, + "i32": -32, + "i64": -64, + "i": -1, + "u8": 8, + "u16": 16, + "u32": 32, + "u64": 64, + "u": 1, + "f32": 3.32, + "f64": 3.64, + "tf": true, + "m": map[string]interface{}{ + "im": 123, + }, + "s": []interface{}{ + "1234", + 1.23, + }, + } + + var tmp Tmp1 + err := Convert(tc, &tmp) + if err != nil { + t.Error(err) + } + t.Error(tmp) + +} + +func BenchmarkConvert(b *testing.B) { + type Tmp1 struct { + Str string `viper:"str"` + I8 int8 `viper:"i8"` + Int16 int16 `viper:"i16"` + Int32 int32 `viper:"i32"` + Int64 int64 `viper:"i64"` + I int `viper:"i"` + U8 int8 `viper:"u8"` + Uint16 int16 `viper:"u16"` + Uint32 int32 `viper:"u32"` + Uint64 int64 `viper:"u64"` + U int `viper:"u"` + F32 float32 `viper:"f32"` + F64 float64 `viper:"f64"` + TF bool `viper:"tf"` + M map[string]interface{} `viper:"m"` + S []interface{} `viper:"s"` + } + tc := map[string]interface{}{ + "str": "Hello world", + "i8": -8, + "i16": -16, + "i32": -32, + "i64": -64, + "i": -1, + "u8": 8, + "u16": 16, + "u32": 32, + "u64": 64, + "u": 1, + "f32": 3.32, + "f64": 3.64, + "tf": true, + "m": map[string]interface{}{ + "im": 123, + }, + "s": []interface{}{ + "1234", + 1.23, + }, + } + for i := 0; i < b.N; i++ { + var tmp Tmp1 + err := Convert(tc, &tmp) + if err != nil { + b.Error(err) + } + } +} diff --git a/internal/encoding/dotenv/map_utils.go b/internal/encoding/dotenv/map_utils.go index aeb6b87..e153652 100644 --- a/internal/encoding/dotenv/map_utils.go +++ b/internal/encoding/dotenv/map_utils.go @@ -1,9 +1,8 @@ package dotenv import ( - "strings" - "github.com/spf13/cast" + insensitiveopt "github.com/spf13/viper/internal/insensitiveOpt" ) // flattenAndMergeMap recursively flattens the given map into a new map @@ -31,7 +30,7 @@ func flattenAndMergeMap(shadow map[string]interface{}, m map[string]interface{}, m2 = cast.ToStringMap(val) default: // immediate value - shadow[strings.ToLower(fullKey)] = val + shadow[insensitiveopt.ToLower(fullKey)] = val continue } // recursively merge to shadow map diff --git a/internal/encoding/ini/map_utils.go b/internal/encoding/ini/map_utils.go index 4fb9eb1..6e97eb0 100644 --- a/internal/encoding/ini/map_utils.go +++ b/internal/encoding/ini/map_utils.go @@ -1,9 +1,8 @@ package ini import ( - "strings" - "github.com/spf13/cast" + insensitiveopt "github.com/spf13/viper/internal/insensitiveOpt" ) // THIS CODE IS COPIED HERE: IT SHOULD NOT BE MODIFIED @@ -64,7 +63,7 @@ func flattenAndMergeMap(shadow map[string]interface{}, m map[string]interface{}, m2 = cast.ToStringMap(val) default: // immediate value - shadow[strings.ToLower(fullKey)] = val + shadow[insensitiveopt.ToLower(fullKey)] = val continue } // recursively merge to shadow map diff --git a/internal/encoding/javaproperties/codec.go b/internal/encoding/javaproperties/codec.go index b8a2251..a2723d4 100644 --- a/internal/encoding/javaproperties/codec.go +++ b/internal/encoding/javaproperties/codec.go @@ -7,6 +7,7 @@ import ( "github.com/magiconair/properties" "github.com/spf13/cast" + insensitiveopt "github.com/spf13/viper/internal/insensitiveOpt" ) // Codec implements the encoding.Encoder and encoding.Decoder interfaces for Java properties encoding. @@ -67,7 +68,7 @@ func (c *Codec) Decode(b []byte, v map[string]interface{}) error { // recursively build nested maps path := strings.Split(key, c.keyDelimiter()) - lastKey := strings.ToLower(path[len(path)-1]) + lastKey := insensitiveopt.ToLower(path[len(path)-1]) deepestMap := deepSearch(v, path[0:len(path)-1]) // set innermost value diff --git a/internal/encoding/javaproperties/map_utils.go b/internal/encoding/javaproperties/map_utils.go index eb53790..c736fa5 100644 --- a/internal/encoding/javaproperties/map_utils.go +++ b/internal/encoding/javaproperties/map_utils.go @@ -1,9 +1,8 @@ package javaproperties import ( - "strings" - "github.com/spf13/cast" + insensitiveopt "github.com/spf13/viper/internal/insensitiveOpt" ) // THIS CODE IS COPIED HERE: IT SHOULD NOT BE MODIFIED @@ -64,7 +63,7 @@ func flattenAndMergeMap(shadow map[string]interface{}, m map[string]interface{}, m2 = cast.ToStringMap(val) default: // immediate value - shadow[strings.ToLower(fullKey)] = val + shadow[insensitiveopt.ToLower(fullKey)] = val continue } // recursively merge to shadow map diff --git a/internal/insensitiveOpt/fix.go b/internal/insensitiveOpt/fix.go new file mode 100644 index 0000000..3648ca2 --- /dev/null +++ b/internal/insensitiveOpt/fix.go @@ -0,0 +1,28 @@ +package insensitiveopt + +import ( + "strings" + "unicode" +) + +var insensitive = true + +func Insensitive(f bool) { + insensitive = f +} + +func ToLower(s string) string { + if insensitive { + return strings.ToLower(s) + } + + return s +} + +func ToLowerRune(s rune) rune { + if insensitive { + return unicode.ToLower(s) + } + + return s +} diff --git a/util.go b/util.go index 25c832c..051e9dd 100644 --- a/util.go +++ b/util.go @@ -16,10 +16,10 @@ import ( "path/filepath" "runtime" "strings" - "unicode" slog "github.com/sagikazarmark/slog-shim" "github.com/spf13/cast" + insensitiveopt "github.com/spf13/viper/internal/insensitiveOpt" ) // ConfigParseError denotes failing to parse configuration file. @@ -56,7 +56,7 @@ func copyAndInsensitiviseMap(m map[string]interface{}) map[string]interface{} { nm := make(map[string]interface{}) for key, val := range m { - lkey := strings.ToLower(key) + lkey := insensitiveopt.ToLower(key) switch v := val.(type) { case map[interface{}]interface{}: nm[lkey] = copyAndInsensitiviseMap(cast.ToStringMap(v)) @@ -89,7 +89,7 @@ func insensitiviseVal(val interface{}) interface{} { func insensitiviseMap(m map[string]interface{}) { for key, val := range m { val = insensitiviseVal(val) - lower := strings.ToLower(key) + lower := insensitiveopt.ToLower(key) if key != lower { // remove old key (not lower-cased) delete(m, key) @@ -165,7 +165,7 @@ func parseSizeInBytes(sizeStr string) uint { if lastChar > 0 { if sizeStr[lastChar] == 'b' || sizeStr[lastChar] == 'B' { if lastChar > 1 { - switch unicode.ToLower(rune(sizeStr[lastChar-1])) { + switch insensitiveopt.ToLowerRune(rune(sizeStr[lastChar-1])) { case 'k': multiplier = 1 << 10 sizeStr = strings.TrimSpace(sizeStr[:lastChar-1]) diff --git a/viper.go b/viper.go index 097483b..bc69f55 100644 --- a/viper.go +++ b/viper.go @@ -48,6 +48,7 @@ import ( "github.com/spf13/viper/internal/encoding/json" "github.com/spf13/viper/internal/encoding/toml" "github.com/spf13/viper/internal/encoding/yaml" + insensitiveopt "github.com/spf13/viper/internal/insensitiveOpt" ) // ConfigMarshalError happens when failing to marshal the configuration. @@ -305,6 +306,14 @@ func Reset() { SupportedRemoteProviders = []string{"etcd", "etcd3", "consul", "firestore", "nats"} } +func SetCaseSensitive() { + insensitiveopt.Insensitive(false) +} + +func SetCaseInsensitive() { + insensitiveopt.Insensitive(true) +} + // TODO: make this lazy initialization instead func (v *Viper) resetEncoding() { encoderRegistry := encoding.NewEncoderRegistry() @@ -706,7 +715,7 @@ func (v *Viper) searchIndexableWithPathPrefixes(source interface{}, path []strin // search for path prefixes, starting from the longest one for i := len(path); i > 0; i-- { - prefixKey := strings.ToLower(strings.Join(path[0:i], v.keyDelim)) + prefixKey := insensitiveopt.ToLower(strings.Join(path[0:i], v.keyDelim)) var val interface{} switch sourceIndexable := source.(type) { @@ -897,7 +906,7 @@ func GetViper() *Viper { func Get(key string) interface{} { return v.Get(key) } func (v *Viper) Get(key string) interface{} { - lcaseKey := strings.ToLower(key) + lcaseKey := insensitiveopt.ToLower(key) val := v.find(lcaseKey, true) if val == nil { return nil @@ -957,7 +966,7 @@ func (v *Viper) Sub(key string) *Viper { } if reflect.TypeOf(data).Kind() == reflect.Map { - subv.parents = append(v.parents, strings.ToLower(key)) + subv.parents = append(v.parents, insensitiveopt.ToLower(key)) subv.automaticEnvApplied = v.automaticEnvApplied subv.envPrefix = v.envPrefix subv.envKeyReplacer = v.envKeyReplacer @@ -1196,7 +1205,7 @@ func (v *Viper) BindFlagValue(key string, flag FlagValue) error { if flag == nil { return fmt.Errorf("flag for %q is nil", key) } - v.pflags[strings.ToLower(key)] = flag + v.pflags[insensitiveopt.ToLower(key)] = flag return nil } @@ -1213,7 +1222,7 @@ func (v *Viper) BindEnv(input ...string) error { return fmt.Errorf("missing key to bind to") } - key := strings.ToLower(input[0]) + key := insensitiveopt.ToLower(input[0]) if len(input) == 1 { v.env[key] = append(v.env[key], v.mergeWithEnvPrefix(key)) @@ -1457,7 +1466,7 @@ func stringToIntConv(val string) interface{} { func IsSet(key string) bool { return v.IsSet(key) } func (v *Viper) IsSet(key string) bool { - lcaseKey := strings.ToLower(key) + lcaseKey := insensitiveopt.ToLower(key) val := v.find(lcaseKey, false) return val != nil } @@ -1484,11 +1493,11 @@ func (v *Viper) SetEnvKeyReplacer(r *strings.Replacer) { func RegisterAlias(alias string, key string) { v.RegisterAlias(alias, key) } func (v *Viper) RegisterAlias(alias string, key string) { - v.registerAlias(alias, strings.ToLower(key)) + v.registerAlias(alias, insensitiveopt.ToLower(key)) } func (v *Viper) registerAlias(alias string, key string) { - alias = strings.ToLower(alias) + alias = insensitiveopt.ToLower(alias) if alias != key && alias != v.realKey(key) { _, exists := v.aliases[alias] @@ -1533,7 +1542,7 @@ func (v *Viper) realKey(key string) string { func InConfig(key string) bool { return v.InConfig(key) } func (v *Viper) InConfig(key string) bool { - lcaseKey := strings.ToLower(key) + lcaseKey := insensitiveopt.ToLower(key) // if the requested key is an alias, then return the proper key lcaseKey = v.realKey(lcaseKey) @@ -1549,11 +1558,11 @@ func SetDefault(key string, value interface{}) { v.SetDefault(key, value) } func (v *Viper) SetDefault(key string, value interface{}) { // If alias passed in, then set the proper default - key = v.realKey(strings.ToLower(key)) + key = v.realKey(insensitiveopt.ToLower(key)) value = toCaseInsensitiveValue(value) path := strings.Split(key, v.keyDelim) - lastKey := strings.ToLower(path[len(path)-1]) + lastKey := insensitiveopt.ToLower(path[len(path)-1]) deepestMap := deepSearch(v.defaults, path[0:len(path)-1]) // set innermost value @@ -1568,11 +1577,11 @@ func Set(key string, value interface{}) { v.Set(key, value) } func (v *Viper) Set(key string, value interface{}) { // If alias passed in, then set the proper override - key = v.realKey(strings.ToLower(key)) + key = v.realKey(insensitiveopt.ToLower(key)) value = toCaseInsensitiveValue(value) path := strings.Split(key, v.keyDelim) - lastKey := strings.ToLower(path[len(path)-1]) + lastKey := insensitiveopt.ToLower(path[len(path)-1]) deepestMap := deepSearch(v.override, path[0:len(path)-1]) // set innermost value @@ -1753,7 +1762,7 @@ func (v *Viper) unmarshalReader(in io.Reader, c map[string]interface{}) error { buf := new(bytes.Buffer) buf.ReadFrom(in) - switch format := strings.ToLower(v.getConfigType()); format { + switch format := insensitiveopt.ToLower(v.getConfigType()); format { case "yaml", "yml", "json", "toml", "hcl", "tfvars", "ini", "properties", "props", "prop", "dotenv", "env": err := v.decoderRegistry.Decode(format, buf.Bytes(), c) if err != nil { @@ -1784,9 +1793,9 @@ func (v *Viper) marshalWriter(f afero.File, configType string) error { } func keyExists(k string, m map[string]interface{}) string { - lk := strings.ToLower(k) + lk := insensitiveopt.ToLower(k) for mk := range m { - lmk := strings.ToLower(mk) + lmk := insensitiveopt.ToLower(mk) if lmk == lk { return mk } @@ -2065,7 +2074,7 @@ func (v *Viper) flattenAndMergeMap(shadow map[string]bool, m map[string]interfac m2 = cast.ToStringMap(val) default: // immediate value - shadow[strings.ToLower(fullKey)] = true + shadow[insensitiveopt.ToLower(fullKey)] = true continue } // recursively merge to shadow map @@ -2091,7 +2100,7 @@ outer: } } // add key - shadow[strings.ToLower(k)] = true + shadow[insensitiveopt.ToLower(k)] = true } return shadow } @@ -2110,7 +2119,7 @@ func (v *Viper) AllSettings() map[string]interface{} { continue } path := strings.Split(k, v.keyDelim) - lastKey := strings.ToLower(path[len(path)-1]) + lastKey := insensitiveopt.ToLower(path[len(path)-1]) deepestMap := deepSearch(m, path[0:len(path)-1]) // set innermost value deepestMap[lastKey] = value diff --git a/viper_convert.go b/viper_convert.go new file mode 100644 index 0000000..71b2197 --- /dev/null +++ b/viper_convert.go @@ -0,0 +1,31 @@ +package viper + +import "github.com/spf13/viper/internal/convert" + +//MapTo quick map to struct if know what the value carries +//using `viper:"key"`` tag to specify keys +/* + EG: + type Service struct { + Port int `viper:"port"` + IP string `viper:"ip"` + } + + SetDefault("service", map[string]interface{}{ + "ip": "127.0.0.1", + "port": 1234, + }) + + var service Service + err := MapTo("service", &service) + assert.NoError(t, err) + assert.Equal(t, Get("service.port"), service.Port) + assert.Equal(t, Get("service.ip"), service.IP) +*/ +func MapTo(key string, target interface{}) error { + return v.MapTo(key, target) +} + +func (v *Viper) MapTo(key string, target interface{}) error { + return convert.Convert(v.Get(key), target) +} diff --git a/viper_test.go b/viper_test.go index 23f8b14..0f42e53 100644 --- a/viper_test.go +++ b/viper_test.go @@ -486,6 +486,30 @@ func TestDefault(t *testing.T) { assert.Equal(t, "leather", Get("clothing.jacket")) } +func TestMapTo(t *testing.T) { + type Service struct { + Port int `viper:"port"` + IP string `viper:"ip"` + } + + SetDefault("service", map[string]interface{}{ + "ip": "127.0.0.1", + "port": 1234, + }) + SetDefault("version", "1.0.01") + + var service Service + var version string + err := MapTo("service", &service) + assert.NoError(t, err) + assert.Equal(t, Get("service.port"), service.Port) + assert.Equal(t, Get("service.ip"), service.IP) + err = MapTo("version", &version) + assert.NoError(t, err) + assert.Equal(t, Get("version"), version) + +} + func TestUnmarshaling(t *testing.T) { SetConfigType("yaml") r := bytes.NewReader(yamlExample)