From 856e87ed6d923ba341cd34f608703d5ac66708fe Mon Sep 17 00:00:00 2001 From: "zhijian.chen" Date: Sat, 15 Apr 2023 13:25:37 +0800 Subject: [PATCH 1/4] [FEATURE] it's option to set case sensitive --- README.md | 9 ++++ internal/encoding/dotenv/map_utils.go | 5 +- internal/encoding/ini/map_utils.go | 5 +- internal/encoding/javaproperties/codec.go | 3 +- internal/encoding/javaproperties/map_utils.go | 5 +- internal/insensitiveOpt/fix.go | 28 +++++++++++ util.go | 8 ++-- viper.go | 47 +++++++++++-------- 8 files changed, 77 insertions(+), 33 deletions(-) create mode 100644 internal/insensitiveOpt/fix.go diff --git a/README.md b/README.md index cd39290..c86b9b7 100644 --- a/README.md +++ b/README.md @@ -862,7 +862,16 @@ application foundation needs. Is there a better name for a [commander](http://en.wikipedia.org/wiki/Cobra_Commander)? ### Does Viper support case sensitive keys? +#### [FEATURE] surport case sensitive +```go +// if you want to keep case insensitive, you can do nothing +// but if you want to make it case sensitive, please do the following step +func main(){ + viper.SetCaseSensitive() + // your code next... +} +``` **tl;dr:** No. Viper merges configuration from various sources, many of which are either case insensitive or uses different casing than the rest of the sources (eg. env vars). diff --git a/internal/encoding/dotenv/map_utils.go b/internal/encoding/dotenv/map_utils.go index ce6e6ef..7f98abf 100644 --- a/internal/encoding/dotenv/map_utils.go +++ b/internal/encoding/dotenv/map_utils.go @@ -1,9 +1,8 @@ package dotenv import ( - "strings" - "github.com/spf13/cast" + insensitiveopt "github.com/spf13/viper/internal/insensitiveOpt" ) // flattenAndMergeMap recursively flattens the given map into a new map @@ -31,7 +30,7 @@ func flattenAndMergeMap(shadow map[string]interface{}, m map[string]interface{}, m2 = cast.ToStringMap(val) default: // immediate value - shadow[strings.ToLower(fullKey)] = val + shadow[insensitiveopt.ToLower(fullKey)] = val continue } // recursively merge to shadow map diff --git a/internal/encoding/ini/map_utils.go b/internal/encoding/ini/map_utils.go index 8329856..aa02ca6 100644 --- a/internal/encoding/ini/map_utils.go +++ b/internal/encoding/ini/map_utils.go @@ -1,9 +1,8 @@ package ini import ( - "strings" - "github.com/spf13/cast" + insensitiveopt "github.com/spf13/viper/internal/insensitiveOpt" ) // THIS CODE IS COPIED HERE: IT SHOULD NOT BE MODIFIED @@ -64,7 +63,7 @@ func flattenAndMergeMap(shadow map[string]interface{}, m map[string]interface{}, m2 = cast.ToStringMap(val) default: // immediate value - shadow[strings.ToLower(fullKey)] = val + shadow[insensitiveopt.ToLower(fullKey)] = val continue } // recursively merge to shadow map diff --git a/internal/encoding/javaproperties/codec.go b/internal/encoding/javaproperties/codec.go index b8a2251..a2723d4 100644 --- a/internal/encoding/javaproperties/codec.go +++ b/internal/encoding/javaproperties/codec.go @@ -7,6 +7,7 @@ import ( "github.com/magiconair/properties" "github.com/spf13/cast" + insensitiveopt "github.com/spf13/viper/internal/insensitiveOpt" ) // Codec implements the encoding.Encoder and encoding.Decoder interfaces for Java properties encoding. @@ -67,7 +68,7 @@ func (c *Codec) Decode(b []byte, v map[string]interface{}) error { // recursively build nested maps path := strings.Split(key, c.keyDelimiter()) - lastKey := strings.ToLower(path[len(path)-1]) + lastKey := insensitiveopt.ToLower(path[len(path)-1]) deepestMap := deepSearch(v, path[0:len(path)-1]) // set innermost value diff --git a/internal/encoding/javaproperties/map_utils.go b/internal/encoding/javaproperties/map_utils.go index 93755ca..6801be4 100644 --- a/internal/encoding/javaproperties/map_utils.go +++ b/internal/encoding/javaproperties/map_utils.go @@ -1,9 +1,8 @@ package javaproperties import ( - "strings" - "github.com/spf13/cast" + insensitiveopt "github.com/spf13/viper/internal/insensitiveOpt" ) // THIS CODE IS COPIED HERE: IT SHOULD NOT BE MODIFIED @@ -64,7 +63,7 @@ func flattenAndMergeMap(shadow map[string]interface{}, m map[string]interface{}, m2 = cast.ToStringMap(val) default: // immediate value - shadow[strings.ToLower(fullKey)] = val + shadow[insensitiveopt.ToLower(fullKey)] = val continue } // recursively merge to shadow map diff --git a/internal/insensitiveOpt/fix.go b/internal/insensitiveOpt/fix.go new file mode 100644 index 0000000..3648ca2 --- /dev/null +++ b/internal/insensitiveOpt/fix.go @@ -0,0 +1,28 @@ +package insensitiveopt + +import ( + "strings" + "unicode" +) + +var insensitive = true + +func Insensitive(f bool) { + insensitive = f +} + +func ToLower(s string) string { + if insensitive { + return strings.ToLower(s) + } + + return s +} + +func ToLowerRune(s rune) rune { + if insensitive { + return unicode.ToLower(s) + } + + return s +} diff --git a/util.go b/util.go index 64e6575..8cb2e17 100644 --- a/util.go +++ b/util.go @@ -16,9 +16,9 @@ import ( "path/filepath" "runtime" "strings" - "unicode" "github.com/spf13/cast" + insensitiveopt "github.com/spf13/viper/internal/insensitiveOpt" ) // ConfigParseError denotes failing to parse configuration file. @@ -50,7 +50,7 @@ func copyAndInsensitiviseMap(m map[string]interface{}) map[string]interface{} { nm := make(map[string]interface{}) for key, val := range m { - lkey := strings.ToLower(key) + lkey := insensitiveopt.ToLower(key) switch v := val.(type) { case map[interface{}]interface{}: nm[lkey] = copyAndInsensitiviseMap(cast.ToStringMap(v)) @@ -83,7 +83,7 @@ func insensitiviseVal(val interface{}) interface{} { func insensitiviseMap(m map[string]interface{}) { for key, val := range m { val = insensitiviseVal(val) - lower := strings.ToLower(key) + lower := insensitiveopt.ToLower(key) if key != lower { // remove old key (not lower-cased) delete(m, key) @@ -159,7 +159,7 @@ func parseSizeInBytes(sizeStr string) uint { if lastChar > 0 { if sizeStr[lastChar] == 'b' || sizeStr[lastChar] == 'B' { if lastChar > 1 { - switch unicode.ToLower(rune(sizeStr[lastChar-1])) { + switch insensitiveopt.ToLowerRune(rune(sizeStr[lastChar-1])) { case 'k': multiplier = 1 << 10 sizeStr = strings.TrimSpace(sizeStr[:lastChar-1]) diff --git a/viper.go b/viper.go index 7eac4b7..7c28f21 100644 --- a/viper.go +++ b/viper.go @@ -48,6 +48,7 @@ import ( "github.com/spf13/viper/internal/encoding/json" "github.com/spf13/viper/internal/encoding/toml" "github.com/spf13/viper/internal/encoding/yaml" + insensitiveopt "github.com/spf13/viper/internal/insensitiveOpt" ) // ConfigMarshalError happens when failing to marshal the configuration. @@ -305,6 +306,14 @@ func Reset() { SupportedRemoteProviders = []string{"etcd", "etcd3", "consul", "firestore"} } +func SetCaseSensitive() { + insensitiveopt.Insensitive(false) +} + +func SetCaseInsensitive() { + insensitiveopt.Insensitive(true) +} + // TODO: make this lazy initialization instead func (v *Viper) resetEncoding() { encoderRegistry := encoding.NewEncoderRegistry() @@ -699,7 +708,7 @@ func (v *Viper) searchIndexableWithPathPrefixes(source interface{}, path []strin // search for path prefixes, starting from the longest one for i := len(path); i > 0; i-- { - prefixKey := strings.ToLower(strings.Join(path[0:i], v.keyDelim)) + prefixKey := insensitiveopt.ToLower(strings.Join(path[0:i], v.keyDelim)) var val interface{} switch sourceIndexable := source.(type) { @@ -890,7 +899,7 @@ func GetViper() *Viper { func Get(key string) interface{} { return v.Get(key) } func (v *Viper) Get(key string) interface{} { - lcaseKey := strings.ToLower(key) + lcaseKey := insensitiveopt.ToLower(key) val := v.find(lcaseKey, true) if val == nil { return nil @@ -950,7 +959,7 @@ func (v *Viper) Sub(key string) *Viper { } if reflect.TypeOf(data).Kind() == reflect.Map { - subv.parents = append(v.parents, strings.ToLower(key)) + subv.parents = append(v.parents, insensitiveopt.ToLower(key)) subv.automaticEnvApplied = v.automaticEnvApplied subv.envPrefix = v.envPrefix subv.envKeyReplacer = v.envKeyReplacer @@ -1189,7 +1198,7 @@ func (v *Viper) BindFlagValue(key string, flag FlagValue) error { if flag == nil { return fmt.Errorf("flag for %q is nil", key) } - v.pflags[strings.ToLower(key)] = flag + v.pflags[insensitiveopt.ToLower(key)] = flag return nil } @@ -1206,7 +1215,7 @@ func (v *Viper) BindEnv(input ...string) error { return fmt.Errorf("missing key to bind to") } - key := strings.ToLower(input[0]) + key := insensitiveopt.ToLower(input[0]) if len(input) == 1 { v.env[key] = append(v.env[key], v.mergeWithEnvPrefix(key)) @@ -1423,7 +1432,7 @@ func stringToStringConv(val string) interface{} { func IsSet(key string) bool { return v.IsSet(key) } func (v *Viper) IsSet(key string) bool { - lcaseKey := strings.ToLower(key) + lcaseKey := insensitiveopt.ToLower(key) val := v.find(lcaseKey, false) return val != nil } @@ -1450,11 +1459,11 @@ func (v *Viper) SetEnvKeyReplacer(r *strings.Replacer) { func RegisterAlias(alias string, key string) { v.RegisterAlias(alias, key) } func (v *Viper) RegisterAlias(alias string, key string) { - v.registerAlias(alias, strings.ToLower(key)) + v.registerAlias(alias, insensitiveopt.ToLower(key)) } func (v *Viper) registerAlias(alias string, key string) { - alias = strings.ToLower(alias) + alias = insensitiveopt.ToLower(alias) if alias != key && alias != v.realKey(key) { _, exists := v.aliases[alias] @@ -1499,7 +1508,7 @@ func (v *Viper) realKey(key string) string { func InConfig(key string) bool { return v.InConfig(key) } func (v *Viper) InConfig(key string) bool { - lcaseKey := strings.ToLower(key) + lcaseKey := insensitiveopt.ToLower(key) // if the requested key is an alias, then return the proper key lcaseKey = v.realKey(lcaseKey) @@ -1515,11 +1524,11 @@ func SetDefault(key string, value interface{}) { v.SetDefault(key, value) } func (v *Viper) SetDefault(key string, value interface{}) { // If alias passed in, then set the proper default - key = v.realKey(strings.ToLower(key)) + key = v.realKey(insensitiveopt.ToLower(key)) value = toCaseInsensitiveValue(value) path := strings.Split(key, v.keyDelim) - lastKey := strings.ToLower(path[len(path)-1]) + lastKey := insensitiveopt.ToLower(path[len(path)-1]) deepestMap := deepSearch(v.defaults, path[0:len(path)-1]) // set innermost value @@ -1534,11 +1543,11 @@ func Set(key string, value interface{}) { v.Set(key, value) } func (v *Viper) Set(key string, value interface{}) { // If alias passed in, then set the proper override - key = v.realKey(strings.ToLower(key)) + key = v.realKey(insensitiveopt.ToLower(key)) value = toCaseInsensitiveValue(value) path := strings.Split(key, v.keyDelim) - lastKey := strings.ToLower(path[len(path)-1]) + lastKey := insensitiveopt.ToLower(path[len(path)-1]) deepestMap := deepSearch(v.override, path[0:len(path)-1]) // set innermost value @@ -1719,7 +1728,7 @@ func (v *Viper) unmarshalReader(in io.Reader, c map[string]interface{}) error { buf := new(bytes.Buffer) buf.ReadFrom(in) - switch format := strings.ToLower(v.getConfigType()); format { + switch format := insensitiveopt.ToLower(v.getConfigType()); format { case "yaml", "yml", "json", "toml", "hcl", "tfvars", "ini", "properties", "props", "prop", "dotenv", "env": err := v.decoderRegistry.Decode(format, buf.Bytes(), c) if err != nil { @@ -1750,9 +1759,9 @@ func (v *Viper) marshalWriter(f afero.File, configType string) error { } func keyExists(k string, m map[string]interface{}) string { - lk := strings.ToLower(k) + lk := insensitiveopt.ToLower(k) for mk := range m { - lmk := strings.ToLower(mk) + lmk := insensitiveopt.ToLower(mk) if lmk == lk { return mk } @@ -2031,7 +2040,7 @@ func (v *Viper) flattenAndMergeMap(shadow map[string]bool, m map[string]interfac m2 = cast.ToStringMap(val) default: // immediate value - shadow[strings.ToLower(fullKey)] = true + shadow[insensitiveopt.ToLower(fullKey)] = true continue } // recursively merge to shadow map @@ -2057,7 +2066,7 @@ outer: } } // add key - shadow[strings.ToLower(k)] = true + shadow[insensitiveopt.ToLower(k)] = true } return shadow } @@ -2076,7 +2085,7 @@ func (v *Viper) AllSettings() map[string]interface{} { continue } path := strings.Split(k, v.keyDelim) - lastKey := strings.ToLower(path[len(path)-1]) + lastKey := insensitiveopt.ToLower(path[len(path)-1]) deepestMap := deepSearch(m, path[0:len(path)-1]) // set innermost value deepestMap[lastKey] = value From a62e20b9ad407a2330e74404770eeb81fce8080b Mon Sep 17 00:00:00 2001 From: Nothin Date: Sat, 15 Apr 2023 14:58:25 +0800 Subject: [PATCH 2/4] [FEATURE] viper add MapTo to quick map to struct or base type --- README.md | 36 +++++++ internal/convert/convert.go | 166 +++++++++++++++++++++++++++++++ internal/convert/convert_test.go | 108 ++++++++++++++++++++ viper_convert.go | 31 ++++++ viper_test.go | 19 ++++ 5 files changed, 360 insertions(+) create mode 100644 internal/convert/convert.go create mode 100644 internal/convert/convert_test.go create mode 100644 viper_convert.go diff --git a/README.md b/README.md index c86b9b7..e398dab 100644 --- a/README.md +++ b/README.md @@ -139,6 +139,42 @@ if err := viper.ReadInConfig(); err != nil { *NOTE [since 1.6]:* You can also have a file without an extension and specify the format programmaticaly. For those configuration files that lie in the home of the user without any extension like `.bashrc` +### MapTo +- source file +```yaml +service: + port: 1234 + ip: "127.0.0.1" +version: 1.0.01 +``` +- use MapTo + +```go +type Service struct { + Port int `viper:"port"` + IP string `viper:"ip"` +} +//your prepare code ... + +var service Service +var version string + +if err := viper.MapTo("service",&service); err != nil { + //error handler... +} +if err := viper.MapTo("version",&version); err != nil { + //error handler... +} + + + +log.Println(service,version) + +//.... + +``` + + ### Writing Config Files Reading from config files is useful, but at times you want to store all modifications made at run time. diff --git a/internal/convert/convert.go b/internal/convert/convert.go new file mode 100644 index 0000000..415c158 --- /dev/null +++ b/internal/convert/convert.go @@ -0,0 +1,166 @@ +package convert + +import ( + "fmt" + "reflect" + "strings" +) + +var convertUtils = map[reflect.Kind]func(reflect.Value, reflect.Value) error{ + reflect.String: converNormal, + reflect.Int: converNormal, + reflect.Int16: converNormal, + reflect.Int32: converNormal, + reflect.Int64: converNormal, + reflect.Uint: converNormal, + reflect.Uint16: converNormal, + reflect.Uint32: converNormal, + reflect.Uint64: converNormal, + reflect.Float32: converNormal, + reflect.Float64: converNormal, + reflect.Uint8: converNormal, + reflect.Int8: converNormal, +} + +//Convert +//示例 +/* + type Target struct { + A int `viper:"aint"` + B string `viper:"bstr"` + } + src :=map[string]interface{}{ + "aint":1224, + "bstr":"124132" + } + + var t Target + Convert(src,&t) + +*/ +//fix循环引用的问题 +var _ = func() struct{} { + convertUtils[reflect.Map] = convertMap + convertUtils[reflect.Array] = convertSlice + convertUtils[reflect.Slice] = convertSlice + return struct{}{} +}() + +func Convert(src interface{}, dst interface{}) (err error) { + + dstRef := reflect.ValueOf(dst) + if dstRef.Kind() != reflect.Ptr { + return fmt.Errorf("dst is not ptr") + } + + dstRef = reflect.Indirect(dstRef) + + srcRef := reflect.ValueOf(src) + if srcRef.Kind() == reflect.Ptr || srcRef.Kind() == reflect.Interface { + srcRef = srcRef.Elem() + } + if f, ok := convertUtils[srcRef.Kind()]; ok { + return f(srcRef, dstRef) + } + + return fmt.Errorf("no implemented:%s", srcRef.Type()) +} + +func converNormal(src reflect.Value, dst reflect.Value) error { + if dst.CanSet() { + if src.Type() == dst.Type() { + dst.Set(src) + } else if src.CanConvert(dst.Type()) { + dst.Set(src.Convert(dst.Type())) + } else { + return fmt.Errorf("can not convert:%s:%s", src.Type().String(), dst.Type().String()) + } + } + return nil +} + +func convertSlice(src reflect.Value, dst reflect.Value) error { + if dst.Kind() != reflect.Array && dst.Kind() != reflect.Slice { + return fmt.Errorf("error type:%s", dst.Type().String()) + } + l := src.Len() + target := reflect.MakeSlice(dst.Type(), l, l) + if dst.CanSet() { + dst.Set(target) + } + for i := 0; i < l; i++ { + srcValue := src.Index(i) + if srcValue.Kind() == reflect.Ptr || srcValue.Kind() == reflect.Interface { + srcValue = srcValue.Elem() + } + if f, ok := convertUtils[srcValue.Kind()]; ok { + err := f(srcValue, dst.Index(i)) + if err != nil { + return err + } + } + } + + return nil +} + +func convertMap(src reflect.Value, dst reflect.Value) error { + if src.Kind() != reflect.Map || dst.Kind() != reflect.Struct { + if src.Kind() == reflect.Interface { + return convertMap(src.Elem(), dst) + } else { + return fmt.Errorf("src or dst type error,%s,%s", src.Type().String(), dst.Type().String()) + } + } + dstType := dst.Type() + num := dstType.NumField() + exist := map[string]int{} + for i := 0; i < num; i++ { + k := dstType.Field(i).Tag.Get("viper") + if k == "" { + k = dstType.Field(i).Name + } + if strings.Contains(k, ",") { + taglist := strings.Split(k, ",") + if taglist[0] == "" { + + k = dstType.Field(i).Name + } else { + k = taglist[0] + + } + + } + exist[k] = i + } + + keys := src.MapKeys() + for _, key := range keys { + if index, ok := exist[key.String()]; ok { + v := dst.Field(index) + if v.Kind() == reflect.Struct { + err := convertMap(src.MapIndex(key), v) + if err != nil { + return err + } + } else { + if v.CanSet() { + if v.Type() == src.MapIndex(key).Elem().Type() { + v.Set(src.MapIndex(key).Elem()) + } else if src.MapIndex(key).Elem().CanConvert(v.Type()) { + v.Set(src.MapIndex(key).Elem().Convert(v.Type())) + } else if f, ok := convertUtils[src.MapIndex(key).Elem().Kind()]; ok && f != nil { + err := f(src.MapIndex(key).Elem(), v) + if err != nil { + return err + } + } else { + return fmt.Errorf("error type:d(%s)s(%s)", v.Type(), src.Type()) + } + } + } + } + } + + return nil +} diff --git a/internal/convert/convert_test.go b/internal/convert/convert_test.go new file mode 100644 index 0000000..05c1ea3 --- /dev/null +++ b/internal/convert/convert_test.go @@ -0,0 +1,108 @@ +package convert + +import ( + "testing" +) + +func TestConvert(t *testing.T) { + type Tmp1 struct { + Str string `viper:"str"` + I8 int8 `viper:"i8"` + Int16 int16 `viper:"i16"` + Int32 int32 `viper:"i32"` + Int64 int64 `viper:"i64"` + I int `viper:"i"` + U8 int8 `viper:"u8"` + Uint16 int16 `viper:"u16"` + Uint32 int32 `viper:"u32"` + Uint64 int64 `viper:"u64"` + U int `viper:"u"` + F32 float32 `viper:"f32"` + F64 float64 `viper:"f64"` + TF bool `viper:"tf"` + M map[string]interface{} `viper:"m"` + S []interface{} `viper:"s"` + } + tc := map[string]interface{}{ + "str": "Hello world", + "i8": -8, + "i16": -16, + "i32": -32, + "i64": -64, + "i": -1, + "u8": 8, + "u16": 16, + "u32": 32, + "u64": 64, + "u": 1, + "f32": 3.32, + "f64": 3.64, + "tf": true, + "m": map[string]interface{}{ + "im": 123, + }, + "s": []interface{}{ + "1234", + 1.23, + }, + } + + var tmp Tmp1 + err := Convert(tc, &tmp) + if err != nil { + t.Error(err) + } + // t.Error(tmp) + +} + +func BenchmarkConvert(b *testing.B) { + type Tmp1 struct { + Str string `viper:"str"` + I8 int8 `viper:"i8"` + Int16 int16 `viper:"i16"` + Int32 int32 `viper:"i32"` + Int64 int64 `viper:"i64"` + I int `viper:"i"` + U8 int8 `viper:"u8"` + Uint16 int16 `viper:"u16"` + Uint32 int32 `viper:"u32"` + Uint64 int64 `viper:"u64"` + U int `viper:"u"` + F32 float32 `viper:"f32"` + F64 float64 `viper:"f64"` + TF bool `viper:"tf"` + M map[string]interface{} `viper:"m"` + S []interface{} `viper:"s"` + } + tc := map[string]interface{}{ + "str": "Hello world", + "i8": -8, + "i16": -16, + "i32": -32, + "i64": -64, + "i": -1, + "u8": 8, + "u16": 16, + "u32": 32, + "u64": 64, + "u": 1, + "f32": 3.32, + "f64": 3.64, + "tf": true, + "m": map[string]interface{}{ + "im": 123, + }, + "s": []interface{}{ + "1234", + 1.23, + }, + } + for i := 0; i < b.N; i++ { + var tmp Tmp1 + err := Convert(tc, &tmp) + if err != nil { + b.Error(err) + } + } +} diff --git a/viper_convert.go b/viper_convert.go new file mode 100644 index 0000000..71b2197 --- /dev/null +++ b/viper_convert.go @@ -0,0 +1,31 @@ +package viper + +import "github.com/spf13/viper/internal/convert" + +//MapTo quick map to struct if know what the value carries +//using `viper:"key"`` tag to specify keys +/* + EG: + type Service struct { + Port int `viper:"port"` + IP string `viper:"ip"` + } + + SetDefault("service", map[string]interface{}{ + "ip": "127.0.0.1", + "port": 1234, + }) + + var service Service + err := MapTo("service", &service) + assert.NoError(t, err) + assert.Equal(t, Get("service.port"), service.Port) + assert.Equal(t, Get("service.ip"), service.IP) +*/ +func MapTo(key string, target interface{}) error { + return v.MapTo(key, target) +} + +func (v *Viper) MapTo(key string, target interface{}) error { + return convert.Convert(v.Get(key), target) +} diff --git a/viper_test.go b/viper_test.go index 8283b5c..5596724 100644 --- a/viper_test.go +++ b/viper_test.go @@ -503,6 +503,25 @@ func TestDefault(t *testing.T) { assert.Equal(t, "leather", Get("clothing.jacket")) } +func TestMapTo(t *testing.T) { + type Service struct { + Port int `viper:"port"` + IP string `viper:"ip"` + } + + SetDefault("service", map[string]interface{}{ + "ip": "127.0.0.1", + "port": 1234, + }) + + var service Service + err := MapTo("service", &service) + assert.NoError(t, err) + assert.Equal(t, Get("service.port"), service.Port) + assert.Equal(t, Get("service.ip"), service.IP) + +} + func TestUnmarshaling(t *testing.T) { SetConfigType("yaml") r := bytes.NewReader(yamlExample) From 3390ac2d915f95eb2c8b0f5a53f16f413068c5a8 Mon Sep 17 00:00:00 2001 From: Nothin Date: Sat, 15 Apr 2023 15:00:27 +0800 Subject: [PATCH 3/4] [DOC] add MapTo test --- viper_test.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/viper_test.go b/viper_test.go index 5596724..1e7e266 100644 --- a/viper_test.go +++ b/viper_test.go @@ -513,12 +513,17 @@ func TestMapTo(t *testing.T) { "ip": "127.0.0.1", "port": 1234, }) + SetDefault("version", "1.0.01") var service Service + var version string err := MapTo("service", &service) assert.NoError(t, err) assert.Equal(t, Get("service.port"), service.Port) assert.Equal(t, Get("service.ip"), service.IP) + err = MapTo("version", &version) + assert.NoError(t, err) + assert.Equal(t, Get("version"), version) } From b69795b0aa4f38d7b748b6f6efe28122e484d48b Mon Sep 17 00:00:00 2001 From: Nothin Date: Fri, 22 Sep 2023 14:36:34 +0800 Subject: [PATCH 4/4] [DOC] add comvert --- internal/convert/convert.go | 101 ++++++++++++++++++++++++++++--- internal/convert/convert_test.go | 2 +- 2 files changed, 92 insertions(+), 11 deletions(-) diff --git a/internal/convert/convert.go b/internal/convert/convert.go index 415c158..53d5314 100644 --- a/internal/convert/convert.go +++ b/internal/convert/convert.go @@ -20,14 +20,15 @@ var convertUtils = map[reflect.Kind]func(reflect.Value, reflect.Value) error{ reflect.Float64: converNormal, reflect.Uint8: converNormal, reflect.Int8: converNormal, + reflect.Bool: converNormal, } -//Convert +//Convert 类型强制转换 //示例 /* type Target struct { - A int `viper:"aint"` - B string `viper:"bstr"` + A int `json:"aint"` + B string `json:"bstr"` } src :=map[string]interface{}{ "aint":1224, @@ -47,6 +48,11 @@ var _ = func() struct{} { }() func Convert(src interface{}, dst interface{}) (err error) { + defer func() { + if v := recover(); v != nil { + err = fmt.Errorf("panic recover:%v", v) + } + }() dstRef := reflect.ValueOf(dst) if dstRef.Kind() != reflect.Ptr { @@ -82,7 +88,10 @@ func converNormal(src reflect.Value, dst reflect.Value) error { func convertSlice(src reflect.Value, dst reflect.Value) error { if dst.Kind() != reflect.Array && dst.Kind() != reflect.Slice { return fmt.Errorf("error type:%s", dst.Type().String()) + } else if !src.IsValid() { + return nil } + l := src.Len() target := reflect.MakeSlice(dst.Type(), l, l) if dst.CanSet() { @@ -105,12 +114,26 @@ func convertSlice(src reflect.Value, dst reflect.Value) error { } func convertMap(src reflect.Value, dst reflect.Value) error { + // + if src.Kind() == reflect.Ptr || src.Kind() == reflect.Interface { + src = src.Elem() + } if src.Kind() != reflect.Map || dst.Kind() != reflect.Struct { - if src.Kind() == reflect.Interface { - return convertMap(src.Elem(), dst) - } else { - return fmt.Errorf("src or dst type error,%s,%s", src.Type().String(), dst.Type().String()) + if dst.Kind() == reflect.Map { + return converMapToMap(src, dst) } + if !(dst.Kind() == reflect.Ptr && dst.Type().Elem().Kind() == reflect.Struct) { + if dst.Kind() == reflect.Interface && dst.CanSet() { + dst.Set(src) + return nil + } + return fmt.Errorf("src or dst type error:%s,%s", src.Kind().String(), dst.Type().String()) + } + if !reflect.Indirect(dst).IsValid() { + v := reflect.New(dst.Type().Elem()) + dst.Set(v) + } + dst = reflect.Indirect(dst) } dstType := dst.Type() num := dstType.NumField() @@ -123,8 +146,19 @@ func convertMap(src reflect.Value, dst reflect.Value) error { if strings.Contains(k, ",") { taglist := strings.Split(k, ",") if taglist[0] == "" { + if len(taglist) == 2 && + taglist[1] == "inline" { + v := dst.Field(i) - k = dstType.Field(i).Name + err := convertMap(src, v) + if err != nil { + return err + } + dst.Field(i).Set(v) + continue + } else { + k = dstType.Field(i).Name + } } else { k = taglist[0] @@ -143,8 +177,14 @@ func convertMap(src reflect.Value, dst reflect.Value) error { if err != nil { return err } + } else if v.Kind() == reflect.Slice { + err := convertSlice(src.MapIndex(key).Elem(), v) + if err != nil { + return err + } + } else { - if v.CanSet() { + if v.CanSet() && src.MapIndex(key).IsValid() && !src.MapIndex(key).IsZero() { if v.Type() == src.MapIndex(key).Elem().Type() { v.Set(src.MapIndex(key).Elem()) } else if src.MapIndex(key).Elem().CanConvert(v.Type()) { @@ -155,7 +195,7 @@ func convertMap(src reflect.Value, dst reflect.Value) error { return err } } else { - return fmt.Errorf("error type:d(%s)s(%s)", v.Type(), src.Type()) + return fmt.Errorf("error type:d(%s)s(%s)", v.Type(), src.MapIndex(key).Elem().Type()) } } } @@ -164,3 +204,44 @@ func convertMap(src reflect.Value, dst reflect.Value) error { return nil } + +func converMapToMap(src reflect.Value, dst reflect.Value) error { + if src.Kind() != reflect.Map || dst.Kind() != reflect.Map { + return fmt.Errorf("type error: src(%v),dst(%v)", src.Kind(), src.Kind()) + } + mv := reflect.MakeMap(dst.Type()) + keys := src.MapKeys() + dt := dst.Type().Elem().Kind() + for _, key := range keys { + if dt == reflect.Struct { + me := reflect.New(dst.Type().Elem()) + me = reflect.Indirect(me) + convertMap(src.MapIndex(key).Elem(), me) + mv.SetMapIndex(key, me) + } else if dt == reflect.Ptr { + me := reflect.New(dst.Type().Elem().Elem()) + me = reflect.Indirect(me) + convertMap(src.MapIndex(key).Elem(), me) + mv.SetMapIndex(key, me.Addr()) + } else if dt == reflect.Slice { + l := src.MapIndex(key).Elem().Len() + v := reflect.MakeSlice(dst.Type().Elem(), l, l) + err := convertSlice(src.MapIndex(key).Elem(), v) + if err != nil { + return err + } + mv.SetMapIndex(key, v) + } else { + if src.MapIndex(key).Elem().Kind() != dst.Type().Elem().Kind() && + src.MapIndex(key).Elem().CanConvert(dst.Type().Elem()) { + v := src.MapIndex(key).Elem().Convert(dst.Type().Elem()) + mv.SetMapIndex(key, v) + continue + } + + mv.SetMapIndex(key, src.MapIndex(key).Elem()) + } + } + dst.Set(mv) + return nil +} diff --git a/internal/convert/convert_test.go b/internal/convert/convert_test.go index 05c1ea3..ae98129 100644 --- a/internal/convert/convert_test.go +++ b/internal/convert/convert_test.go @@ -52,7 +52,7 @@ func TestConvert(t *testing.T) { if err != nil { t.Error(err) } - // t.Error(tmp) + t.Error(tmp) }