diff --git a/.gitignore b/.gitignore index 8365624..9466b19 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +*.iml +.idea/ # Compiled Object files, Static and Dynamic libs (Shared Objects) *.o *.a diff --git a/README.md b/README.md index cf17560..6b48fdb 100644 --- a/README.md +++ b/README.md @@ -406,14 +406,23 @@ The following functions and methods exist: * `Get(key string) : interface{}` * `GetBool(key string) : bool` + * `GetDefaultBool(key string, value bool) : bool` * `GetFloat64(key string) : float64` + * `GetDefaultFloat64(key string, value float64) : float64` * `GetInt(key string) : int` + * `GetDefaultInt(key string, value int) : int` * `GetString(key string) : string` + * `GetDefaultString(key string, value string) : string` * `GetStringMap(key string) : map[string]interface{}` + * `GetDefaultStringMap(key string, value map[string]interface{}) : map[string]interface{}` * `GetStringMapString(key string) : map[string]string` + * `GetDefaultStringMapString(key string, value map[string]string) : map[string]string` * `GetStringSlice(key string) : []string` + * `GetDefaultStringSlice(key string, value []string) : []string` * `GetTime(key string) : time.Time` + * `GetDefaultTime(key string, value time.Time) : time.Time` * `GetDuration(key string) : time.Duration` + * `GetDefaultDuration(key string, value time.Duration) : time.Duration` * `IsSet(key string) : bool` One important thing to recognize is that each Get function will return a zero diff --git a/viper.go b/viper.go index f17790e..b42b8bb 100644 --- a/viper.go +++ b/viper.go @@ -249,7 +249,7 @@ func (v *Viper) WatchConfig() { for { select { case event := <-watcher.Events: - // we only care about the config file + // we only care about the config file if filepath.Clean(event.Name) == configFile { if event.Op&fsnotify.Write == fsnotify.Write || event.Op&fsnotify.Create == fsnotify.Create { err := v.ReadInConfig() @@ -548,66 +548,165 @@ func (v *Viper) GetString(key string) string { return cast.ToString(v.Get(key)) } +// Returns the value associated with the key as a string or the passed default value +func GetDefaultString(key string, defaultValue string) string { return v.GetDefaultString(key, defaultValue) } +func (v *Viper) GetDefaultString(key string, defaultValue string) string { + if value := v.Get(key); value != nil { + return cast.ToString(value) + } + return defaultValue +} + // Returns the value associated with the key as a boolean func GetBool(key string) bool { return v.GetBool(key) } func (v *Viper) GetBool(key string) bool { return cast.ToBool(v.Get(key)) } +// Returns the value associated with the key as a boolean or the passed default value +func GetDefaultBoolean(key string, defaultValue bool) bool { return v.GetDefaultBoolean(key, defaultValue) } +func (v *Viper) GetDefaultBoolean(key string, defaultValue bool) bool { + if value := v.Get(key); value != nil { + return cast.ToBool(value) + } + return defaultValue +} + // Returns the value associated with the key as an integer func GetInt(key string) int { return v.GetInt(key) } func (v *Viper) GetInt(key string) int { return cast.ToInt(v.Get(key)) } +// Returns the value associated with the key as an int or the passed default value +func GetDefaultInt(key string, defaultValue int) int { return v.GetDefaultInt(key, defaultValue) } +func (v *Viper) GetDefaultInt(key string, defaultValue int) int { + if value := v.Get(key); value != nil { + return cast.ToInt(value) + } + return defaultValue +} + // Returns the value associated with the key as an integer func GetInt64(key string) int64 { return v.GetInt64(key) } func (v *Viper) GetInt64(key string) int64 { return cast.ToInt64(v.Get(key)) } +// Returns the value associated with the key as an int64 or the passed default value +func GetDefaultInt64(key string, defaultValue int64) int64 { return v.GetDefaultInt64(key, defaultValue) } +func (v *Viper) GetDefaultInt64(key string, defaultValue int64) int64 { + if value := v.Get(key); value != nil { + return cast.ToInt64(value) + } + return defaultValue +} + // Returns the value associated with the key as a float64 func GetFloat64(key string) float64 { return v.GetFloat64(key) } func (v *Viper) GetFloat64(key string) float64 { return cast.ToFloat64(v.Get(key)) } +// Returns the value associated with the key as a float64 or the passed default value +func GetDefaultFloat64(key string, defaultValue float64) float64 { return v.GetDefaultFloat64(key, defaultValue) } +func (v *Viper) GetDefaultFloat64(key string, defaultValue float64) float64 { + if value := v.Get(key); value != nil { + return cast.ToFloat64(value) + } + return defaultValue +} + // Returns the value associated with the key as time func GetTime(key string) time.Time { return v.GetTime(key) } func (v *Viper) GetTime(key string) time.Time { return cast.ToTime(v.Get(key)) } +// Returns the value associated with the key as time or the passed default value +func GetDefaultTime(key string, defaultValue time.Time) time.Time { return v.GetDefaultTime(key, defaultValue) } +func (v *Viper) GetDefaultTime(key string, defaultValue time.Time) time.Time { + if value := v.Get(key); value != nil { + return cast.ToTime(value) + } + return defaultValue +} + // Returns the value associated with the key as a duration func GetDuration(key string) time.Duration { return v.GetDuration(key) } func (v *Viper) GetDuration(key string) time.Duration { return cast.ToDuration(v.Get(key)) } +// Returns the value associated with the key as a duration or the passed default value +func GetDefaultDuration(key string, defaultValue time.Duration) time.Duration { return v.GetDefaultDuration(key, defaultValue) } +func (v *Viper) GetDefaultDuration(key string, defaultValue time.Duration) time.Duration { + if value := v.Get(key); value != nil { + return cast.ToDuration(value) + } + return defaultValue +} + // Returns the value associated with the key as a slice of strings func GetStringSlice(key string) []string { return v.GetStringSlice(key) } func (v *Viper) GetStringSlice(key string) []string { return cast.ToStringSlice(v.Get(key)) } +// Returns the value associated with the key as a duration or the passed default value +func GetDefaultStringSlice(key string, defaultValue []string) []string { return v.GetDefaultStringSlice(key, defaultValue) } +func (v *Viper) GetDefaultStringSlice(key string, defaultValue []string) []string { + if value := v.Get(key); value != nil { + return cast.ToStringSlice(value) + } + return defaultValue +} + // Returns the value associated with the key as a map of interfaces func GetStringMap(key string) map[string]interface{} { return v.GetStringMap(key) } func (v *Viper) GetStringMap(key string) map[string]interface{} { return cast.ToStringMap(v.Get(key)) } +// Returns the value associated with the key as a map or the passed default value +func GetDefaultStringMap(key string, defaultValue map[string]interface{}) map[string]interface{} { return v.GetDefaultStringMap(key, defaultValue) } +func (v *Viper) GetDefaultStringMap(key string, defaultValue map[string]interface{}) map[string]interface{} { + if value := v.Get(key); value != nil { + return cast.ToStringMap(value) + } + return defaultValue +} + // Returns the value associated with the key as a map of strings func GetStringMapString(key string) map[string]string { return v.GetStringMapString(key) } func (v *Viper) GetStringMapString(key string) map[string]string { return cast.ToStringMapString(v.Get(key)) } +// Returns the value associated with the key as a map of strings or the passed default value +func GetDefaultStringMapString(key string, defaultValue map[string]string) map[string]string { return v.GetDefaultStringMapString(key, defaultValue) } +func (v *Viper) GetDefaultStringMapString(key string, defaultValue map[string]string) map[string]string { + if value := v.Get(key); value != nil { + return cast.ToStringMapString(value) + } + return defaultValue +} + // Returns the value associated with the key as a map to a slice of strings. func GetStringMapStringSlice(key string) map[string][]string { return v.GetStringMapStringSlice(key) } func (v *Viper) GetStringMapStringSlice(key string) map[string][]string { return cast.ToStringMapStringSlice(v.Get(key)) } +// Returns the value associated with the key as a map to a slice of string or the passed default value +func GetDefaultStringMapStringSlice(key string, defaultValue map[string][]string) map[string][]string { return v.GetDefaultStringMapStringSlice(key, defaultValue) } +func (v *Viper) GetDefaultStringMapStringSlice(key string, defaultValue map[string][]string) map[string][]string { + if value := v.Get(key); value != nil { + return cast.ToStringMapStringSlice(value) + } + return defaultValue +} + // Returns the size of the value associated with the given key // in bytes. func GetSizeInBytes(key string) uint { return v.GetSizeInBytes(key) } @@ -1005,7 +1104,7 @@ func keyExists(k string, m map[string]interface{}) string { } func castToMapStringInterface( - src map[interface{}]interface{}) map[string]interface{} { +src map[interface{}]interface{}) map[string]interface{} { tgt := map[string]interface{}{} for k, v := range src { tgt[fmt.Sprintf("%v", k)] = v @@ -1019,7 +1118,7 @@ func castToMapStringInterface( // deep. Both map types are supported as there is a go-yaml fork that uses // `map[string]interface{}` instead. func mergeMaps( - src, tgt map[string]interface{}, itgt map[interface{}]interface{}) { +src, tgt map[string]interface{}, itgt map[interface{}]interface{}) { for sk, sv := range src { tk := keyExists(sk, tgt) if tk == "" { diff --git a/viper_test.go b/viper_test.go index 0c0c7e5..c174246 100644 --- a/viper_test.go +++ b/viper_test.go @@ -250,6 +250,29 @@ func TestDefault(t *testing.T) { assert.Equal(t, "leather", Get("clothing.jacket")) } +func TestDefaultValueOnGetters(t *testing.T) { + SetConfigType("yaml") + r := bytes.NewReader(yamlExample) + + unmarshalReader(r, v.config) + assert.False(t, InConfig("surname")) + assert.Equal(t, "doe", GetDefaultString("surname", "doe")) + assert.False(t, InConfig("month")) + assert.Equal(t, 1, GetDefaultInt("month", 1)) + assert.False(t, InConfig("year")) + assert.Equal(t, int64(20), GetDefaultInt64("year", 20)) + assert.False(t, InConfig("false-bool")) + assert.Equal(t, true, GetDefaultBoolean("false-bool", true)) + assert.False(t, InConfig("float-64")) + assert.Equal(t, float64(2.4), GetDefaultFloat64("float-64", 2.4)) + localTime := time.Now() + assert.False(t, InConfig("no-time")) + assert.Equal(t, localTime, GetDefaultTime("no-time", localTime)) + secDuration := time.Second + assert.False(t, InConfig("no-duration")) + assert.Equal(t, secDuration, GetDefaultDuration("no-duration", secDuration)) +} + func TestUnmarshalling(t *testing.T) { SetConfigType("yaml") r := bytes.NewReader(yamlExample)