Properly handle StringSlice flag escaped values.

This commit is contained in:
Paweł Szczur 2017-01-04 11:18:13 +01:00
parent 5d9cb36f40
commit 61c45d73da
2 changed files with 47 additions and 23 deletions

View file

@ -21,6 +21,7 @@ package viper
import ( import (
"bytes" "bytes"
"encoding/csv"
"fmt" "fmt"
"io" "io"
"log" "log"
@ -880,7 +881,9 @@ func (v *Viper) find(lcaseKey string) interface{} {
return cast.ToBool(flag.ValueString()) return cast.ToBool(flag.ValueString())
case "stringSlice": case "stringSlice":
s := strings.TrimPrefix(flag.ValueString(), "[") s := strings.TrimPrefix(flag.ValueString(), "[")
return strings.Split(strings.TrimSuffix(s, "]"), ",") s = strings.TrimSuffix(s, "]")
res, _ := readAsCSV(s)
return res
default: default:
return flag.ValueString() return flag.ValueString()
} }
@ -947,7 +950,9 @@ func (v *Viper) find(lcaseKey string) interface{} {
return cast.ToBool(flag.ValueString()) return cast.ToBool(flag.ValueString())
case "stringSlice": case "stringSlice":
s := strings.TrimPrefix(flag.ValueString(), "[") s := strings.TrimPrefix(flag.ValueString(), "[")
return strings.Split(strings.TrimSuffix(s, "]"), ",") s = strings.TrimSuffix(s, "]")
res, _ := readAsCSV(s)
return res
default: default:
return flag.ValueString() return flag.ValueString()
} }
@ -957,6 +962,15 @@ func (v *Viper) find(lcaseKey string) interface{} {
return nil return nil
} }
func readAsCSV(val string) ([]string, error) {
if val == "" {
return []string{}, nil
}
stringReader := strings.NewReader(val)
csvReader := csv.NewReader(stringReader)
return csvReader.Read()
}
// IsSet checks to see if the key has been set in any of the data locations. // IsSet checks to see if the key has been set in any of the data locations.
// IsSet is case-insensitive for a key. // IsSet is case-insensitive for a key.
func IsSet(key string) bool { return v.IsSet(key) } func IsSet(key string) bool { return v.IsSet(key) }

View file

@ -539,29 +539,39 @@ func TestBindPFlags(t *testing.T) {
} }
func TestBindPFlagsStringSlice(t *testing.T) { func TestBindPFlagsStringSlice(t *testing.T) {
for _, testValue := range [][]string{nil, []string{}, []string{"jeden"}, []string{"dwa", "trzy"}} { for _, testValue := range []struct {
v := New() // create independent Viper object Expected []string
flagSet := pflag.NewFlagSet("test", pflag.ContinueOnError) Value string
flagSet.StringSlice("stringslice", testValue, "test") }{
flagSet.Visit(func(f *pflag.Flag) { {[]string{}, ""},
if len(testValue) > 0 { {[]string{"jeden"}, "jeden"},
f.Value.Set(strings.Join(testValue, ",")) {[]string{"dwa", "trzy"}, "dwa,trzy"},
f.Changed = true {[]string{"cztery", "piec , szesc"}, "cztery,\"piec , szesc\""}} {
for _, changed := range []bool{true, false} {
v := New() // create independent Viper object
flagSet := pflag.NewFlagSet("test", pflag.ContinueOnError)
flagSet.StringSlice("stringslice", testValue.Expected, "test")
flagSet.Visit(func(f *pflag.Flag) {
if len(testValue.Value) > 0 {
f.Value.Set(testValue.Value)
f.Changed = changed
}
})
err := v.BindPFlags(flagSet)
if err != nil {
t.Fatalf("error binding flag set, %v", err)
} }
})
err := v.BindPFlags(flagSet) type TestStr struct {
if err != nil { StringSlice []string
t.Fatalf("error binding flag set, %v", err) }
} val := &TestStr{}
if err := v.Unmarshal(val); err != nil {
type TestStr struct { t.Fatalf("%+#v cannot unmarshal: %s", testValue.Value, err)
StringSlice []string }
} assert.Equal(t, testValue.Expected, val.StringSlice)
val := &TestStr{}
if err := v.Unmarshal(val); err != nil {
t.Fatalf("%+#v cannot unmarshal: %s", testValue, err)
assert.Equal(t, val.StringSlice, testValue)
} }
} }
} }