From 1750e0a9d30f43164393d9699de5302e6393f445 Mon Sep 17 00:00:00 2001 From: Harley Laue Date: Mon, 23 Apr 2018 10:24:03 -0700 Subject: [PATCH] Check for nil before binding pflag(s) * When passing nil to BindPFlag or BindPFlags, the value is set to a struct and passed as an interface. That struct never checks for the flag(set) being nil. Thus, it makes sense to check before it's set to the struct. * fixes #422 --- viper.go | 6 ++++++ viper_test.go | 16 ++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/viper.go b/viper.go index 405dc20..61d0b01 100644 --- a/viper.go +++ b/viper.go @@ -951,6 +951,9 @@ func (v *Viper) UnmarshalExact(rawVal interface{}, opts ...DecoderConfigOption) // name as the config key. func BindPFlags(flags *pflag.FlagSet) error { return v.BindPFlags(flags) } func (v *Viper) BindPFlags(flags *pflag.FlagSet) error { + if flags == nil { + return fmt.Errorf("FlagSet cannot be nil") + } return v.BindFlagValues(pflagValueSet{flags}) } @@ -962,6 +965,9 @@ func (v *Viper) BindPFlags(flags *pflag.FlagSet) error { // func BindPFlag(key string, flag *pflag.Flag) error { return v.BindPFlag(key, flag) } func (v *Viper) BindPFlag(key string, flag *pflag.Flag) error { + if flag == nil { + return fmt.Errorf("flag for %q is nil", key) + } return v.BindFlagValue(key, pflagValue{flag}) } diff --git a/viper_test.go b/viper_test.go index fe942de..8427fc4 100644 --- a/viper_test.go +++ b/viper_test.go @@ -950,6 +950,14 @@ func TestBindPFlagsIntSlice(t *testing.T) { } } +func TestBindPFlagsNil(t *testing.T) { + v := New() + err := v.BindPFlags(nil) + if err == nil { + t.Fatalf("expected error when passing nil to BindPFlags") + } +} + func TestBindPFlag(t *testing.T) { var testString = "testing" var testValue = newStringValue(testString, &testString) @@ -1017,6 +1025,14 @@ func TestBindPFlagStringToString(t *testing.T) { } } +func TestBindPFlagNil(t *testing.T) { + v := New() + err := v.BindPFlag("any", nil) + if err == nil { + t.Fatalf("expected error when passing nil to BindPFlag") + } +} + func TestBoundCaseSensitivity(t *testing.T) { assert.Equal(t, "brown", Get("eyes"))