Properly handle StringSlice flag escaped values.

This commit is contained in:
Paweł Szczur 2017-01-04 11:18:13 +01:00
parent 5d9cb36f40
commit 61c45d73da
2 changed files with 47 additions and 23 deletions

View file

@ -21,6 +21,7 @@ package viper
import (
"bytes"
"encoding/csv"
"fmt"
"io"
"log"
@ -880,7 +881,9 @@ func (v *Viper) find(lcaseKey string) interface{} {
return cast.ToBool(flag.ValueString())
case "stringSlice":
s := strings.TrimPrefix(flag.ValueString(), "[")
return strings.Split(strings.TrimSuffix(s, "]"), ",")
s = strings.TrimSuffix(s, "]")
res, _ := readAsCSV(s)
return res
default:
return flag.ValueString()
}
@ -947,7 +950,9 @@ func (v *Viper) find(lcaseKey string) interface{} {
return cast.ToBool(flag.ValueString())
case "stringSlice":
s := strings.TrimPrefix(flag.ValueString(), "[")
return strings.Split(strings.TrimSuffix(s, "]"), ",")
s = strings.TrimSuffix(s, "]")
res, _ := readAsCSV(s)
return res
default:
return flag.ValueString()
}
@ -957,6 +962,15 @@ func (v *Viper) find(lcaseKey string) interface{} {
return nil
}
func readAsCSV(val string) ([]string, error) {
if val == "" {
return []string{}, nil
}
stringReader := strings.NewReader(val)
csvReader := csv.NewReader(stringReader)
return csvReader.Read()
}
// IsSet checks to see if the key has been set in any of the data locations.
// IsSet is case-insensitive for a key.
func IsSet(key string) bool { return v.IsSet(key) }

View file

@ -539,14 +539,23 @@ func TestBindPFlags(t *testing.T) {
}
func TestBindPFlagsStringSlice(t *testing.T) {
for _, testValue := range [][]string{nil, []string{}, []string{"jeden"}, []string{"dwa", "trzy"}} {
for _, testValue := range []struct {
Expected []string
Value string
}{
{[]string{}, ""},
{[]string{"jeden"}, "jeden"},
{[]string{"dwa", "trzy"}, "dwa,trzy"},
{[]string{"cztery", "piec , szesc"}, "cztery,\"piec , szesc\""}} {
for _, changed := range []bool{true, false} {
v := New() // create independent Viper object
flagSet := pflag.NewFlagSet("test", pflag.ContinueOnError)
flagSet.StringSlice("stringslice", testValue, "test")
flagSet.StringSlice("stringslice", testValue.Expected, "test")
flagSet.Visit(func(f *pflag.Flag) {
if len(testValue) > 0 {
f.Value.Set(strings.Join(testValue, ","))
f.Changed = true
if len(testValue.Value) > 0 {
f.Value.Set(testValue.Value)
f.Changed = changed
}
})
@ -560,8 +569,9 @@ func TestBindPFlagsStringSlice(t *testing.T) {
}
val := &TestStr{}
if err := v.Unmarshal(val); err != nil {
t.Fatalf("%+#v cannot unmarshal: %s", testValue, err)
assert.Equal(t, val.StringSlice, testValue)
t.Fatalf("%+#v cannot unmarshal: %s", testValue.Value, err)
}
assert.Equal(t, testValue.Expected, val.StringSlice)
}
}
}