diff --git a/viper.go b/viper.go index a32ab73..4af8c2a 100644 --- a/viper.go +++ b/viper.go @@ -789,6 +789,165 @@ func (v *Viper) GetSizeInBytes(key string) uint { return parseSizeInBytes(sizeStr) } +// Require returns the value associated with the given key or +// an error if the key is not set. +func Require(key string) (interface{}, error) { return v.Require(key) } +func (v *Viper) Require(key string) (interface{}, error) { + value := v.Get(key) + if value == nil { + return nil, fmt.Errorf("key %s is not set", key) + } + return value, nil +} + +// RequireString returns the value associated with the key as a string or +// an error if the key is not set. +func RequireString(key string) (string, error) { return v.RequireString(key) } +func (v *Viper) RequireString(key string) (string, error) { + value, err := v.Require(key) + if err != nil { + return "", err + } + return cast.ToString(value), nil +} + +// RequireBool returns the value associated with the key as a boolean or +// an error if the key is not set. +func RequireBool(key string) (bool, error) { return v.RequireBool(key) } +func (v *Viper) RequireBool(key string) (bool, error) { + value, err := v.Require(key) + if err != nil { + return false, err + } + return cast.ToBool(value), nil +} + +// RequireInt returns the value associated with the key as an integer or +// an error if the key is not set. +func RequireInt(key string) (int, error) { return v.RequireInt(key) } +func (v *Viper) RequireInt(key string) (int, error) { + value, err := v.Require(key) + if err != nil { + return 0, err + } + return cast.ToInt(value), nil +} + +// RequireInt32 returns the value associated with the key as an integer or +// an error if the key is not set. +func RequireInt32(key string) (int32, error) { return v.RequireInt32(key) } +func (v *Viper) RequireInt32(key string) (int32, error) { + value, err := v.Require(key) + if err != nil { + return 0, err + } + return cast.ToInt32(value), nil +} + +// RequireInt64 returns the value associated with the key as an integer or +// an error if the key is not set. +func RequireInt64(key string) (int64, error) { return v.RequireInt64(key) } +func (v *Viper) RequireInt64(key string) (int64, error) { + value, err := v.Require(key) + if err != nil { + return 0, err + } + return cast.ToInt64(value), nil +} + +// RequireFloat64 returns the value associated with the key as a float64 or +// an error if the key is not set. +func RequireFloat64(key string) (float64, error) { return v.RequireFloat64(key) } +func (v *Viper) RequireFloat64(key string) (float64, error) { + value, err := v.Require(key) + if err != nil { + return 0, err + } + return cast.ToFloat64(value), nil +} + +// RequireTime returns the value associated with the key as time or +// an error if the key is not set. +func RequireTime(key string) (time.Time, error) { return v.RequireTime(key) } +func (v *Viper) RequireTime(key string) (time.Time, error) { + value, err := v.Require(key) + if err != nil { + return time.Time{}, err + } + return cast.ToTime(value), nil +} + +// RequireDuration returns the value associated with the key as a duration or +// an error if the key is not set. +func RequireDuration(key string) (time.Duration, error) { return v.RequireDuration(key) } +func (v *Viper) RequireDuration(key string) (time.Duration, error) { + value, err := v.Require(key) + if err != nil { + return time.Duration(0), err + } + return cast.ToDuration(value), nil +} + +// RequireStringSlice returns the value associated with the key as a slice of strings or +// an error if the key is not set. +func RequireStringSlice(key string) ([]string, error) { return v.RequireStringSlice(key) } +func (v *Viper) RequireStringSlice(key string) ([]string, error) { + value, err := v.Require(key) + if err != nil { + return nil, err + } + return cast.ToStringSlice(value), nil +} + +// RequireStringMap returns the value associated with the key as a map of interfaces or +// an error if the key is not set. +func RequireStringMap(key string) (map[string]interface{}, error) { return v.RequireStringMap(key) } +func (v *Viper) RequireStringMap(key string) (map[string]interface{}, error) { + value, err := v.Require(key) + if err != nil { + return nil, err + } + return cast.ToStringMap(value), nil +} + +// RequireStringMapString returns the value associated with the key as a map of strings or +// an error if the key is not set. +func RequireStringMapString(key string) (map[string]string, error) { + return v.RequireStringMapString(key) +} +func (v *Viper) RequireStringMapString(key string) (map[string]string, error) { + value, err := v.Require(key) + if err != nil { + return nil, err + } + return cast.ToStringMapString(value), nil +} + +// RequireStringMapStringSlice returns the value associated with the key as a map to a slice of strings or +// an error if the key is not set. +func RequireStringMapStringSlice(key string) (map[string][]string, error) { + return v.RequireStringMapStringSlice(key) +} +func (v *Viper) RequireStringMapStringSlice(key string) (map[string][]string, error) { + value, err := v.Require(key) + if err != nil { + return nil, err + } + return cast.ToStringMapStringSlice(value), nil +} + +// RequireSizeInBytes returns the size of the value associated with the given key in bytes or +// an error if the key is not set. +func RequireSizeInBytes(key string) (uint, error) { return v.RequireSizeInBytes(key) } +func (v *Viper) RequireSizeInBytes(key string) (uint, error) { + value, err := v.Require(key) + if err != nil { + return 0, err + } + sizeStr := cast.ToString(value) + return parseSizeInBytes(sizeStr), nil +} + // UnmarshalKey takes a single key and unmarshals it into a Struct. func UnmarshalKey(key string, rawVal interface{}, opts ...DecoderConfigOption) error { return v.UnmarshalKey(key, rawVal, opts...) diff --git a/viper_test.go b/viper_test.go index c8fa1f4..49cf4d4 100644 --- a/viper_test.go +++ b/viper_test.go @@ -453,6 +453,17 @@ func TestAllKeys(t *testing.T) { assert.Equal(t, all, AllSettings()) } +func TestRequireKeys(t *testing.T) { + initJSON() + presentKeys := []string{"id", "type", "name", "ppu", "batters"} + for _, key := range presentKeys { + _, err := Require(key) + assert.NoError(t, err) + } + _, err := Require("nope") + assert.EqualError(t, err, "key nope is not set") +} + func TestAllKeysWithEnv(t *testing.T) { v := New() @@ -642,6 +653,7 @@ func TestBindPFlag(t *testing.T) { } func TestBoundCaseSensitivity(t *testing.T) { + initConfigs() assert.Equal(t, "brown", Get("eyes")) BindEnv("eYEs", "TURTLE_EYES")