From b52b215be2463dcc6e4823ae55395b7b99b2f60c Mon Sep 17 00:00:00 2001 From: Mark Sagi-Kazar Date: Sun, 27 Jan 2019 03:23:33 +0100 Subject: [PATCH] Add support for int slice flags --- viper.go | 12 ++++++++++++ viper_test.go | 45 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/viper.go b/viper.go index cee37b2..74c6ed4 100644 --- a/viper.go +++ b/viper.go @@ -697,6 +697,8 @@ func (v *Viper) Get(key string) interface{} { return cast.ToDuration(val) case []string: return cast.ToStringSlice(val) + case []int: + return cast.ToIntSlice(val) } } @@ -992,6 +994,11 @@ func (v *Viper) find(lcaseKey string) interface{} { s = strings.TrimSuffix(s, "]") res, _ := readAsCSV(s) return res + case "intSlice": + s := strings.TrimPrefix(flag.ValueString(), "[") + s = strings.TrimSuffix(s, "]") + res, _ := readAsCSV(s) + return cast.ToIntSlice(res) default: return flag.ValueString() } @@ -1061,6 +1068,11 @@ func (v *Viper) find(lcaseKey string) interface{} { s = strings.TrimSuffix(s, "]") res, _ := readAsCSV(s) return res + case "intSlice": + s := strings.TrimPrefix(flag.ValueString(), "[") + s = strings.TrimSuffix(s, "]") + res, _ := readAsCSV(s) + return cast.ToIntSlice(res) default: return flag.ValueString() } diff --git a/viper_test.go b/viper_test.go index f4263d3..f325b82 100644 --- a/viper_test.go +++ b/viper_test.go @@ -658,6 +658,51 @@ func TestBindPFlagsStringSlice(t *testing.T) { } } +func TestBindPFlagsIntSlice(t *testing.T) { + tests := []struct { + Expected []int + Value string + }{ + {nil, ""}, + {[]int{1}, "1"}, + {[]int{2, 3}, "2,3"}, + } + + v := New() // create independent Viper object + defaultVal := []int{0} + v.SetDefault("intslice", defaultVal) + + for _, testValue := range tests { + flagSet := pflag.NewFlagSet("test", pflag.ContinueOnError) + flagSet.IntSlice("intslice", testValue.Expected, "test") + + for _, changed := range []bool{true, false} { + flagSet.VisitAll(func(f *pflag.Flag) { + f.Value.Set(testValue.Value) + f.Changed = changed + }) + + err := v.BindPFlags(flagSet) + if err != nil { + t.Fatalf("error binding flag set, %v", err) + } + + type TestInt struct { + IntSlice []int + } + val := &TestInt{} + if err := v.Unmarshal(val); err != nil { + t.Fatalf("%+#v cannot unmarshal: %s", testValue.Value, err) + } + if changed { + assert.Equal(t, testValue.Expected, val.IntSlice) + } else { + assert.Equal(t, defaultVal, val.IntSlice) + } + } + } +} + func TestBindPFlag(t *testing.T) { var testString = "testing" var testValue = newStringValue(testString, &testString)