From 61c45d73da9eb53ed6be3b3eacf00d1bdf8d8f76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Szczur?= Date: Wed, 4 Jan 2017 11:18:13 +0100 Subject: [PATCH] Properly handle StringSlice flag escaped values. --- viper.go | 18 ++++++++++++++++-- viper_test.go | 52 ++++++++++++++++++++++++++++++--------------------- 2 files changed, 47 insertions(+), 23 deletions(-) diff --git a/viper.go b/viper.go index 137d015..63a66dd 100644 --- a/viper.go +++ b/viper.go @@ -21,6 +21,7 @@ package viper import ( "bytes" + "encoding/csv" "fmt" "io" "log" @@ -880,7 +881,9 @@ func (v *Viper) find(lcaseKey string) interface{} { return cast.ToBool(flag.ValueString()) case "stringSlice": s := strings.TrimPrefix(flag.ValueString(), "[") - return strings.Split(strings.TrimSuffix(s, "]"), ",") + s = strings.TrimSuffix(s, "]") + res, _ := readAsCSV(s) + return res default: return flag.ValueString() } @@ -947,7 +950,9 @@ func (v *Viper) find(lcaseKey string) interface{} { return cast.ToBool(flag.ValueString()) case "stringSlice": s := strings.TrimPrefix(flag.ValueString(), "[") - return strings.Split(strings.TrimSuffix(s, "]"), ",") + s = strings.TrimSuffix(s, "]") + res, _ := readAsCSV(s) + return res default: return flag.ValueString() } @@ -957,6 +962,15 @@ func (v *Viper) find(lcaseKey string) interface{} { return nil } +func readAsCSV(val string) ([]string, error) { + if val == "" { + return []string{}, nil + } + stringReader := strings.NewReader(val) + csvReader := csv.NewReader(stringReader) + return csvReader.Read() +} + // IsSet checks to see if the key has been set in any of the data locations. // IsSet is case-insensitive for a key. func IsSet(key string) bool { return v.IsSet(key) } diff --git a/viper_test.go b/viper_test.go index e928af5..8311f35 100644 --- a/viper_test.go +++ b/viper_test.go @@ -539,29 +539,39 @@ func TestBindPFlags(t *testing.T) { } func TestBindPFlagsStringSlice(t *testing.T) { - for _, testValue := range [][]string{nil, []string{}, []string{"jeden"}, []string{"dwa", "trzy"}} { - v := New() // create independent Viper object - flagSet := pflag.NewFlagSet("test", pflag.ContinueOnError) - flagSet.StringSlice("stringslice", testValue, "test") - flagSet.Visit(func(f *pflag.Flag) { - if len(testValue) > 0 { - f.Value.Set(strings.Join(testValue, ",")) - f.Changed = true + for _, testValue := range []struct { + Expected []string + Value string + }{ + {[]string{}, ""}, + {[]string{"jeden"}, "jeden"}, + {[]string{"dwa", "trzy"}, "dwa,trzy"}, + {[]string{"cztery", "piec , szesc"}, "cztery,\"piec , szesc\""}} { + + for _, changed := range []bool{true, false} { + v := New() // create independent Viper object + flagSet := pflag.NewFlagSet("test", pflag.ContinueOnError) + flagSet.StringSlice("stringslice", testValue.Expected, "test") + flagSet.Visit(func(f *pflag.Flag) { + if len(testValue.Value) > 0 { + f.Value.Set(testValue.Value) + f.Changed = changed + } + }) + + err := v.BindPFlags(flagSet) + if err != nil { + t.Fatalf("error binding flag set, %v", err) } - }) - err := v.BindPFlags(flagSet) - if err != nil { - t.Fatalf("error binding flag set, %v", err) - } - - type TestStr struct { - StringSlice []string - } - val := &TestStr{} - if err := v.Unmarshal(val); err != nil { - t.Fatalf("%+#v cannot unmarshal: %s", testValue, err) - assert.Equal(t, val.StringSlice, testValue) + type TestStr struct { + StringSlice []string + } + val := &TestStr{} + if err := v.Unmarshal(val); err != nil { + t.Fatalf("%+#v cannot unmarshal: %s", testValue.Value, err) + } + assert.Equal(t, testValue.Expected, val.StringSlice) } } }