diff --git a/viper.go b/viper.go index cee37b2..f38be4c 100644 --- a/viper.go +++ b/viper.go @@ -131,6 +131,45 @@ func DecodeHook(hook mapstructure.DecodeHookFunc) DecoderConfigOption { } } +type MergeStrategy interface { + Can(interface{}) bool + Merge(src, tgt interface{}) interface{} +} + +type mergeStrategy struct { + CanFn func(interface{}) bool + MergeFn func(src, tgt interface{}) interface{} +} + +func (m *mergeStrategy) Can(val interface{}) bool { + return m.CanFn(val) +} + +func (m *mergeStrategy) Merge(src, tgt interface{}) interface{} { + return m.MergeFn(src, tgt) +} + +func newMergeStrategy(canFn func(interface{}) bool, + mergeFn func(src, tgt interface{}) interface{}) *mergeStrategy { + return &mergeStrategy{ + CanFn: canFn, + MergeFn: mergeFn, + } +} + +func SliceAppendStrategy() MergeStrategy { + return newMergeStrategy(func(i interface{}) bool { + val := reflect.ValueOf(i) + if val.Kind() != reflect.Slice { + return false + } + return true + }, func(src, tgt interface{}) interface{} { + return reflect.AppendSlice(reflect.ValueOf(tgt), reflect.ValueOf(src)). + Interface() + }) +} + // Viper is a prioritized configuration registry. It // maintains a set of configuration sources, fetches // values to populate those, and provides them according @@ -189,14 +228,15 @@ type Viper struct { envKeyReplacer *strings.Replacer allowEmptyEnv bool - config map[string]interface{} - override map[string]interface{} - defaults map[string]interface{} - kvstore map[string]interface{} - pflags map[string]FlagValue - env map[string]string - aliases map[string]string - typeByDefValue bool + config map[string]interface{} + override map[string]interface{} + defaults map[string]interface{} + kvstore map[string]interface{} + pflags map[string]FlagValue + env map[string]string + aliases map[string]string + typeByDefValue bool + mergeStrategies []MergeStrategy // Store read properties on the object so that we can write back in order with comments. // This will only be used if the configuration read is a properties file. @@ -1274,7 +1314,7 @@ func (v *Viper) MergeConfigMap(cfg map[string]interface{}) error { v.config = make(map[string]interface{}) } insensitiviseMap(cfg) - mergeMaps(cfg, v.config, nil) + mergeMaps(cfg, v.config, nil, v.mergeStrategies) return nil } @@ -1509,7 +1549,8 @@ func castMapFlagToMapInterface(src map[string]FlagValue) map[string]interface{} // deep. Both map types are supported as there is a go-yaml fork that uses // `map[string]interface{}` instead. func mergeMaps( - src, tgt map[string]interface{}, itgt map[interface{}]interface{}) { + src, tgt map[string]interface{}, itgt map[interface{}]interface{}, + strategies []MergeStrategy) { for sk, sv := range src { tk := keyExists(sk, tgt) if tk == "" { @@ -1549,15 +1590,25 @@ func mergeMaps( tsv := sv.(map[interface{}]interface{}) ssv := castToMapStringInterface(tsv) stv := castToMapStringInterface(ttv) - mergeMaps(ssv, stv, ttv) + mergeMaps(ssv, stv, ttv, strategies) + continue case map[string]interface{}: jww.TRACE.Printf("merging maps") - mergeMaps(sv.(map[string]interface{}), ttv, nil) + mergeMaps(sv.(map[string]interface{}), ttv, nil, strategies) + continue default: + var val interface{} = sv + for _, strat := range strategies { + if !strat.Can(tv) { + continue + } + val = strat.Merge(sv, tv) + break + } jww.TRACE.Printf("setting value") - tgt[tk] = sv + tgt[tk] = val if itgt != nil { - itgt[tk] = sv + itgt[tk] = val } } } diff --git a/viper_test.go b/viper_test.go index f4263d3..39e114c 100644 --- a/viper_test.go +++ b/viper_test.go @@ -1272,6 +1272,66 @@ func TestMergeConfigMap(t *testing.T) { } +func TestMergeMapsSliceWithoutStategy(t *testing.T) { + srcSlc := []string{"one", "two", "three"} + tgtSlc := []string{"four", "five", "six"} + src := map[string]interface{}{ + "testKey": srcSlc, + } + tgt := map[string]interface{}{ + "testKey": tgtSlc, + } + + mergeMaps(src, tgt, nil, []MergeStrategy{}) + val, ok := tgt["testKey"] + if !ok { + t.Fatal("unable to get key in merged target for map") + } + if val, ok := val.([]string); !ok { + t.Fatal("unexpected type in merged target for test key") + } else if !reflect.DeepEqual(val, srcSlc) { + t.Fatalf("unexpected key value, wanted %s got %s", srcSlc, val) + } +} + +func TestMergeMapsSliceWithStategy(t *testing.T) { + tests := []struct { + srcSlc []interface{} + tgtSlc []interface{} + result []interface{} + }{ + { + srcSlc: []interface{}{"one", "two", "three"}, + tgtSlc: []interface{}{"four", "five", "six"}, + result: []interface{}{"four", "five", "six", "one", "two", "three"}, + }, + { + srcSlc: []interface{}{1, 2, 3}, + tgtSlc: []interface{}{4, 5, 6}, + result: []interface{}{4, 5, 6, 1, 2, 3}, + }, + } + + for _, test := range tests { + src := map[string]interface{}{ + "testKey": test.srcSlc, + } + tgt := map[string]interface{}{ + "testKey": test.tgtSlc, + } + + mergeMaps(src, tgt, nil, []MergeStrategy{SliceAppendStrategy()}) + val, ok := tgt["testKey"] + if !ok { + t.Fatal("unable to get key in merged target for map") + } + if !reflect.DeepEqual(val, test.result) { + t.Fatalf("unexpected key value, wanted %s got %s", test.result, + val) + } + } +} + func TestUnmarshalingWithAliases(t *testing.T) { v := New() v.SetDefault("ID", 1)