From 07a0f11324923d6fe3a38e44090a3f6528aef5b1 Mon Sep 17 00:00:00 2001 From: Xavier Lucas Date: Mon, 25 Feb 2019 17:38:41 +0100 Subject: [PATCH] Expose type casting errors --- util.go | 22 ++++++++----- viper.go | 85 +++++++++++++++++++++++++++++++++++++++++++++++++++ viper_test.go | 25 +++++++++------ 3 files changed, 114 insertions(+), 18 deletions(-) diff --git a/util.go b/util.go index 952cad4..1da1827 100644 --- a/util.go +++ b/util.go @@ -146,16 +146,17 @@ func userHomeDir() string { return os.Getenv("HOME") } -func safeMul(a, b uint) uint { +func safeMul(a, b uint) (uint, error) { c := a * b if a > 1 && b > 1 && c/b != a { - return 0 + return 0, fmt.Errorf("multiplication overflows uint: %d*%d", a, b) } - return c + return c, nil } // parseSizeInBytes converts strings like 1GB or 12 mb into an unsigned integer number of bytes -func parseSizeInBytes(sizeStr string) uint { +func parseSizeInBytes(sizeStr string) (uint, error) { + rawStr := sizeStr sizeStr = strings.TrimSpace(sizeStr) lastChar := len(sizeStr) - 1 multiplier := uint(1) @@ -181,12 +182,17 @@ func parseSizeInBytes(sizeStr string) uint { } } - size := cast.ToInt(sizeStr) - if size < 0 { - size = 0 + num, err := cast.ToUintE(sizeStr) + if err != nil { + return 0, err } - return safeMul(uint(size), multiplier) + size, err := safeMul(num, multiplier) + if err != nil { + return 0, fmt.Errorf("unable to cast %q to uint: %s", rawStr, err) + } + + return size, nil } // deepSearch scans deep maps, following the key indexes listed in the diff --git a/viper.go b/viper.go index 247eaa1..69cbfb6 100644 --- a/viper.go +++ b/viper.go @@ -726,77 +726,162 @@ func (v *Viper) GetString(key string) string { return cast.ToString(v.Get(key)) } +// GetStringE is the same than GetString but also returns parsing errors. +func GetStringE(key string) (string, error) { return v.GetStringE(key) } +func (v *Viper) GetStringE(key string) (string, error) { + return cast.ToStringE(v.Get(key)) +} + // GetBool returns the value associated with the key as a boolean. func GetBool(key string) bool { return v.GetBool(key) } func (v *Viper) GetBool(key string) bool { return cast.ToBool(v.Get(key)) } +// GetBoolE is the same than GetBool but also returns parsing errors. +func GetBoolE(key string) (bool, error) { return v.GetBoolE(key) } +func (v *Viper) GetBoolE(key string) (bool, error) { + return cast.ToBoolE(v.Get(key)) +} + // GetInt returns the value associated with the key as an integer. func GetInt(key string) int { return v.GetInt(key) } func (v *Viper) GetInt(key string) int { return cast.ToInt(v.Get(key)) } +// GetIntE is the same than GetInt but also returns parsing errors. +func GetIntE(key string) (int, error) { return v.GetIntE(key) } +func (v *Viper) GetIntE(key string) (int, error) { + return cast.ToIntE(v.Get(key)) +} + // GetInt32 returns the value associated with the key as an integer. func GetInt32(key string) int32 { return v.GetInt32(key) } func (v *Viper) GetInt32(key string) int32 { return cast.ToInt32(v.Get(key)) } +// GetInt32E is the same than GetInt32 but also returns parsing errors. +func GetInt32E(key string) (int32, error) { return v.GetInt32E(key) } +func (v *Viper) GetInt32E(key string) (int32, error) { + return cast.ToInt32E(v.Get(key)) +} + // GetInt64 returns the value associated with the key as an integer. func GetInt64(key string) int64 { return v.GetInt64(key) } func (v *Viper) GetInt64(key string) int64 { return cast.ToInt64(v.Get(key)) } +// GetInt64E is the same than GetInt64 but also returns parsing errors. +func GetInt64E(key string) (int64, error) { return v.GetInt64E(key) } +func (v *Viper) GetInt64E(key string) (int64, error) { + return cast.ToInt64E(v.Get(key)) +} + // GetFloat64 returns the value associated with the key as a float64. func GetFloat64(key string) float64 { return v.GetFloat64(key) } func (v *Viper) GetFloat64(key string) float64 { return cast.ToFloat64(v.Get(key)) } +// GetFloat64E is the same than GetFloat64 but also returns parsing errors. +func GetFloat64E(key string) (float64, error) { return v.GetFloat64E(key) } +func (v *Viper) GetFloat64E(key string) (float64, error) { + return cast.ToFloat64E(v.Get(key)) +} + // GetTime returns the value associated with the key as time. func GetTime(key string) time.Time { return v.GetTime(key) } func (v *Viper) GetTime(key string) time.Time { return cast.ToTime(v.Get(key)) } +// GetTimeE is the same than GetTime but also returns parsing errors. +func GetTimeE(key string) (time.Time, error) { return v.GetTimeE(key) } +func (v *Viper) GetTimeE(key string) (time.Time, error) { + return cast.ToTimeE(v.Get(key)) +} + // GetDuration returns the value associated with the key as a duration. func GetDuration(key string) time.Duration { return v.GetDuration(key) } func (v *Viper) GetDuration(key string) time.Duration { return cast.ToDuration(v.Get(key)) } +// GetDurationE is the same than GetDuration but also returns parsing errors. +func GetDurationE(key string) (time.Duration, error) { return v.GetDurationE(key) } +func (v *Viper) GetDurationE(key string) (time.Duration, error) { + return cast.ToDurationE(v.Get(key)) +} + // GetStringSlice returns the value associated with the key as a slice of strings. func GetStringSlice(key string) []string { return v.GetStringSlice(key) } func (v *Viper) GetStringSlice(key string) []string { return cast.ToStringSlice(v.Get(key)) } +// GetStringSliceE is the same than GetStringSlice but also returns parsing errors. +func GetStringSliceE(key string) ([]string, error) { return v.GetStringSliceE(key) } +func (v *Viper) GetStringSliceE(key string) ([]string, error) { + return cast.ToStringSliceE(v.Get(key)) +} + // GetStringMap returns the value associated with the key as a map of interfaces. func GetStringMap(key string) map[string]interface{} { return v.GetStringMap(key) } func (v *Viper) GetStringMap(key string) map[string]interface{} { return cast.ToStringMap(v.Get(key)) } +// GetStringMapE is the same than GetStringMap but also returns parsing errors. +func GetStringMapE(key string) (map[string]interface{}, error) { return v.GetStringMapE(key) } +func (v *Viper) GetStringMapE(key string) (map[string]interface{}, error) { + return cast.ToStringMapE(v.Get(key)) +} + // GetStringMapString returns the value associated with the key as a map of strings. func GetStringMapString(key string) map[string]string { return v.GetStringMapString(key) } func (v *Viper) GetStringMapString(key string) map[string]string { return cast.ToStringMapString(v.Get(key)) } +// GetStringMapStringE is the same than GetStringMapString but also returns parsing errors. +func GetStringMapStringE(key string) (map[string]string, error) { return v.GetStringMapStringE(key) } +func (v *Viper) GetStringMapStringE(key string) (map[string]string, error) { + return cast.ToStringMapStringE(v.Get(key)) +} + // GetStringMapStringSlice returns the value associated with the key as a map to a slice of strings. func GetStringMapStringSlice(key string) map[string][]string { return v.GetStringMapStringSlice(key) } func (v *Viper) GetStringMapStringSlice(key string) map[string][]string { return cast.ToStringMapStringSlice(v.Get(key)) } +// GetStringMapStringSliceE is the same than GetStringMapStringSlice but also returns parsing errors. +func GetStringMapStringSliceE(key string) (map[string][]string, error) { + return v.GetStringMapStringSliceE(key) +} +func (v *Viper) GetStringMapStringSliceE(key string) (map[string][]string, error) { + return cast.ToStringMapStringSliceE(v.Get(key)) +} + // GetSizeInBytes returns the size of the value associated with the given key // in bytes. func GetSizeInBytes(key string) uint { return v.GetSizeInBytes(key) } func (v *Viper) GetSizeInBytes(key string) uint { sizeStr := cast.ToString(v.Get(key)) + size, _ := parseSizeInBytes(sizeStr) + return size +} + +// GetSizeInBytesE is the same than GetSizeInBytes but also returns parsing errors. +func GetSizeInBytesE(key string) (uint, error) { return v.GetSizeInBytesE(key) } +func (v *Viper) GetSizeInBytesE(key string) (uint, error) { + sizeStr, err := cast.ToStringE(v.Get(key)) + if err != nil { + return 0, err + } return parseSizeInBytes(sizeStr) } diff --git a/viper_test.go b/viper_test.go index 42691be..d5296df 100644 --- a/viper_test.go +++ b/viper_test.go @@ -735,19 +735,24 @@ func TestBoundCaseSensitivity(t *testing.T) { } func TestSizeInBytes(t *testing.T) { - input := map[string]uint{ - "": 0, - "b": 0, - "12 bytes": 0, - "200000000000gb": 0, - "12 b": 12, - "43 MB": 43 * (1 << 20), - "10mb": 10 * (1 << 20), - "1gb": 1 << 30, + input := map[string]struct { + Size uint + Error bool + }{ + "": {0, true}, + "b": {0, true}, + "12 bytes": {0, true}, + "200000000000gb": {0, true}, + "12 b": {12, false}, + "43 MB": {43 * (1 << 20), false}, + "10mb": {10 * (1 << 20), false}, + "1gb": {1 << 30, false}, } for str, expected := range input { - assert.Equal(t, expected, parseSizeInBytes(str), str) + size, err := parseSizeInBytes(str) + assert.Equal(t, expected.Size, size, str) + assert.Equal(t, expected.Error, err != nil, str) } }