From 0967fc9aceab2ce9da34061253ac10fb99bba5b2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pawe=C5=82=20Szczur?= <orian@users.noreply.github.com>
Date: Mon, 17 Apr 2017 10:08:15 +0200
Subject: [PATCH] Properly handle string slice values

---
 viper.go      | 18 ++++++++++++++++--
 viper_test.go | 38 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 54 insertions(+), 2 deletions(-)

diff --git a/viper.go b/viper.go
index 5ca66ae..31b41a6 100644
--- a/viper.go
+++ b/viper.go
@@ -21,6 +21,7 @@ package viper
 
 import (
 	"bytes"
+	"encoding/csv"
 	"fmt"
 	"io"
 	"log"
@@ -894,7 +895,9 @@ func (v *Viper) find(lcaseKey string) interface{} {
 			return cast.ToBool(flag.ValueString())
 		case "stringSlice":
 			s := strings.TrimPrefix(flag.ValueString(), "[")
-			return strings.TrimSuffix(s, "]")
+			s = strings.TrimSuffix(s, "]")
+			res, _ := readAsCSV(s)
+			return res
 		default:
 			return flag.ValueString()
 		}
@@ -961,7 +964,9 @@ func (v *Viper) find(lcaseKey string) interface{} {
 			return cast.ToBool(flag.ValueString())
 		case "stringSlice":
 			s := strings.TrimPrefix(flag.ValueString(), "[")
-			return strings.TrimSuffix(s, "]")
+			s = strings.TrimSuffix(s, "]")
+			res, _ := readAsCSV(s)
+			return res
 		default:
 			return flag.ValueString()
 		}
@@ -971,6 +976,15 @@ func (v *Viper) find(lcaseKey string) interface{} {
 	return nil
 }
 
+func readAsCSV(val string) ([]string, error) {
+	if val == "" {
+		return []string{}, nil
+	}
+	stringReader := strings.NewReader(val)
+	csvReader := csv.NewReader(stringReader)
+	return csvReader.Read()
+}
+
 // IsSet checks to see if the key has been set in any of the data locations.
 // IsSet is case-insensitive for a key.
 func IsSet(key string) bool { return v.IsSet(key) }
diff --git a/viper_test.go b/viper_test.go
index cd10603..774ca11 100644
--- a/viper_test.go
+++ b/viper_test.go
@@ -538,6 +538,44 @@ func TestBindPFlags(t *testing.T) {
 
 }
 
+func TestBindPFlagsStringSlice(t *testing.T) {
+	for _, testValue := range []struct {
+		Expected []string
+		Value    string
+	}{
+		{[]string{}, ""},
+		{[]string{"jeden"}, "jeden"},
+		{[]string{"dwa", "trzy"}, "dwa,trzy"},
+		{[]string{"cztery", "piec , szesc"}, "cztery,\"piec , szesc\""}} {
+
+		for _, changed := range []bool{true, false} {
+			v := New() // create independent Viper object
+			flagSet := pflag.NewFlagSet("test", pflag.ContinueOnError)
+			flagSet.StringSlice("stringslice", testValue.Expected, "test")
+			flagSet.Visit(func(f *pflag.Flag) {
+				if len(testValue.Value) > 0 {
+					f.Value.Set(testValue.Value)
+					f.Changed = changed
+				}
+			})
+
+			err := v.BindPFlags(flagSet)
+			if err != nil {
+				t.Fatalf("error binding flag set, %v", err)
+			}
+
+			type TestStr struct {
+				StringSlice []string
+			}
+			val := &TestStr{}
+			if err := v.Unmarshal(val); err != nil {
+				t.Fatalf("%+#v cannot unmarshal: %s", testValue.Value, err)
+			}
+			assert.Equal(t, testValue.Expected, val.StringSlice)
+		}
+	}
+}
+
 func TestBindPFlag(t *testing.T) {
 	var testString = "testing"
 	var testValue = newStringValue(testString, &testString)