diff --git a/flags_test.go b/flags_test.go index 0b976b6..1dd7b0d 100644 --- a/flags_test.go +++ b/flags_test.go @@ -45,7 +45,7 @@ func TestBindFlagValueSet(t *testing.T) { func TestBindFlagValue(t *testing.T) { var testString = "testing" - var testValue = newStringValue(testString, &testString) + var testValue = newStringValue(testString) flag := &pflag.Flag{ Name: "testflag", diff --git a/viper.go b/viper.go index c64094a..6bc0ef1 100644 --- a/viper.go +++ b/viper.go @@ -670,7 +670,7 @@ func GetViper() *Viper { func Get(key string) interface{} { return v.Get(key) } func (v *Viper) Get(key string) interface{} { lcaseKey := strings.ToLower(key) - val := v.find(lcaseKey) + val := v.find(lcaseKey, true) if val == nil { return nil } @@ -911,6 +911,9 @@ func (v *Viper) UnmarshalExact(rawVal interface{}) error { // name as the config key. func BindPFlags(flags *pflag.FlagSet) error { return v.BindPFlags(flags) } func (v *Viper) BindPFlags(flags *pflag.FlagSet) error { + if flags == nil { + return fmt.Errorf("FlagSet cannot be nil") + } return v.BindFlagValues(pflagValueSet{flags}) } @@ -922,6 +925,9 @@ func (v *Viper) BindPFlags(flags *pflag.FlagSet) error { // func BindPFlag(key string, flag *pflag.Flag) error { return v.BindPFlag(key, flag) } func (v *Viper) BindPFlag(key string, flag *pflag.Flag) error { + if flag == nil { + return fmt.Errorf("flag for %q is nil", key) + } return v.BindFlagValue(key, pflagValue{flag}) } @@ -979,9 +985,12 @@ func (v *Viper) BindEnv(input ...string) error { // Given a key, find the value. // Viper will check in the following order: // flag, env, config file, key/value store, default. -// Viper will check to see if an alias exists first. -// Note: this assumes a lower-cased key given. -func (v *Viper) find(lcaseKey string) interface{} { +// Viper will then check in the following order: +// flag, env, config file, key/value store. +// Lastly, if no value was found and flagDefault is true, and if the key +// corresponds to a flag, the flag's default value is returned. +// +func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} { var ( val interface{} @@ -1107,6 +1116,24 @@ func (v *Viper) find(lcaseKey string) interface{} { } // last item, no need to check shadowing + // it could also be a key prefix, search for that prefix to get the values from + // pflags that match it + sub := make(map[string]interface{}) + for _, key := range v.AllKeys() { + if strings.HasPrefix(key, lcaseKey) { + value := v.Get(key) + keypath := strings.Split(lcaseKey, v.keyDelim) + path := strings.Split(key, v.keyDelim)[len(keypath)-1:] + lastKey := strings.ToLower(path[len(path)-1]) + deepestMap := deepSearch(sub, path[1:len(path)-1]) + // set innermost value + deepestMap[lastKey] = value + } + } + if len(sub) != 0 { + return sub + } + return nil } @@ -1124,7 +1151,7 @@ func readAsCSV(val string) ([]string, error) { func IsSet(key string) bool { return v.IsSet(key) } func (v *Viper) IsSet(key string) bool { lcaseKey := strings.ToLower(key) - val := v.find(lcaseKey) + val := v.find(lcaseKey, false) return val != nil } diff --git a/viper_test.go b/viper_test.go index 8dcfd1a..bdcd5a4 100644 --- a/viper_test.go +++ b/viper_test.go @@ -253,9 +253,8 @@ func initDirs(t *testing.T) (string, string, func()) { // stubs for PFlag Values type stringValue string -func newStringValue(val string, p *string) *stringValue { - *p = val - return (*stringValue)(p) +func newStringValue(val string) *stringValue { + return (*stringValue)(&val) } func (s *stringValue) Set(val string) error { @@ -853,9 +852,17 @@ func TestBindPFlagsIntSlice(t *testing.T) { } } +func TestBindPFlagsNil(t *testing.T) { + v := New() + err := v.BindPFlags(nil) + if err == nil { + t.Fatalf("expected error when passing nil to BindPFlags") + } +} + func TestBindPFlag(t *testing.T) { var testString = "testing" - var testValue = newStringValue(testString, &testString) + var testValue = newStringValue(testString) flag := &pflag.Flag{ Name: "testflag", @@ -874,6 +881,14 @@ func TestBindPFlag(t *testing.T) { } +func TestBindPFlagNil(t *testing.T) { + v := New() + err := v.BindPFlag("any", nil) + if err == nil { + t.Fatalf("expected error when passing nil to BindPFlag") + } +} + func TestBoundCaseSensitivity(t *testing.T) { assert.Equal(t, "brown", Get("eyes")) @@ -883,7 +898,7 @@ func TestBoundCaseSensitivity(t *testing.T) { assert.Equal(t, "blue", Get("eyes")) var testString = "green" - var testValue = newStringValue(testString, &testString) + var testValue = newStringValue(testString) flag := &pflag.Flag{ Name: "eyeballs", @@ -1128,6 +1143,73 @@ func TestSub(t *testing.T) { assert.Equal(t, (*Viper)(nil), subv) } +func TestSubPflags(t *testing.T) { + v := New() + + // same as yamlExample, without hobbies + v.BindPFlag("name", &pflag.Flag{Value: newStringValue("steve"), Changed: true}) + v.BindPFlag("clothing.jacket", &pflag.Flag{Value: newStringValue("leather"), Changed: true}) + v.BindPFlag("clothing.trousers", &pflag.Flag{Value: newStringValue("denim"), Changed: true}) + v.BindPFlag("clothing.pants.size", &pflag.Flag{Value: newStringValue("large"), Changed: true}) + v.BindPFlag("age", &pflag.Flag{Value: newStringValue("35"), Changed: true}) + v.BindPFlag("eyes", &pflag.Flag{Value: newStringValue("brown"), Changed: true}) + v.BindPFlag("beard", &pflag.Flag{Value: newStringValue("yes"), Changed: true}) + + type pants struct { + Size string + } + + type clothing struct { + Jacket string + Trousers string + Pants pants + } + + type cfg struct { + Name string + Clothing clothing + Age int + Eyes string + Beard bool + } + + var c cfg + v.Unmarshal(&c) + assert.Equal(t, v.Get("name"), c.Name) + assert.Equal(t, v.Get("clothing.jacket"), c.Clothing.Jacket) + assert.Equal(t, v.Get("clothing.trousers"), c.Clothing.Trousers) + assert.Equal(t, v.Get("clothing.pants.size"), c.Clothing.Pants.Size) + assert.Equal(t, v.GetInt("age"), c.Age) + assert.Equal(t, v.Get("eyes"), c.Eyes) + assert.Equal(t, v.GetBool("beard"), c.Beard) + + var cloth clothing + v.UnmarshalKey("clothing", &cloth) + assert.Equal(t, c.Clothing, cloth) + + var p pants + v.UnmarshalKey("clothing.pants", &p) + assert.Equal(t, c.Clothing.Pants, p) + + var size string + v.UnmarshalKey("clothing.pants.size", &size) + assert.Equal(t, c.Clothing.Pants.Size, size) + + subv := v.Sub("clothing") + assert.Equal(t, v.Get("clothing.jacket"), subv.Get("jacket")) + assert.Equal(t, v.Get("clothing.trousers"), subv.Get("trousers")) + assert.Equal(t, v.Get("clothing.pants.size"), subv.Get("pants.size")) + + subv = v.Sub("clothing.pants") + assert.Equal(t, v.Get("clothing.pants.size"), subv.Get("size")) + + subv = v.Sub("clothing.pants.size") + assert.Equal(t, (*Viper)(nil), subv) + + subv = v.Sub("missing.key") + assert.Equal(t, (*Viper)(nil), subv) +} + var hclWriteExpected = []byte(`"foos" = { "foo" = { "key" = 1