From da4885ce3d4c4ce2cdda088827c9d2eb4cb96de9 Mon Sep 17 00:00:00 2001 From: Roy Razon Date: Tue, 16 Jan 2018 12:59:06 +0200 Subject: [PATCH] Added functionality to mock the environment variable store, for tests --- util.go | 14 +++++++------- viper.go | 45 +++++++++++++++++++++++++++++++++++++++------ viper_test.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 90 insertions(+), 13 deletions(-) diff --git a/util.go b/util.go index 3ebada9..e255fe3 100644 --- a/util.go +++ b/util.go @@ -94,16 +94,16 @@ func insensitiviseMap(m map[string]interface{}) { } } -func absPathify(inPath string) string { +func absPathify(inPath string, envStore EnvStore) string { jww.INFO.Println("Trying to resolve absolute path to", inPath) if strings.HasPrefix(inPath, "$HOME") { - inPath = userHomeDir() + inPath[5:] + inPath = userHomeDir(envStore) + inPath[5:] } if strings.HasPrefix(inPath, "$") { end := strings.Index(inPath, string(os.PathSeparator)) - inPath = os.Getenv(inPath[1:end]) + inPath[end:] + inPath = envStore.Get(inPath[1:end]) + inPath[end:] } if filepath.IsAbs(inPath) { @@ -141,15 +141,15 @@ func stringInSlice(a string, list []string) bool { return false } -func userHomeDir() string { +func userHomeDir(store EnvStore) string { if runtime.GOOS == "windows" { - home := os.Getenv("HOMEDRIVE") + os.Getenv("HOMEPATH") + home := store.Get("HOMEDRIVE") + store.Get("HOMEPATH") if home == "" { - home = os.Getenv("USERPROFILE") + home = store.Get("USERPROFILE") } return home } - return os.Getenv("HOME") + return store.Get("HOME") } func unmarshallConfigReader(in io.Reader, c map[string]interface{}, configType string) error { diff --git a/viper.go b/viper.go index 963861a..7ddfd53 100644 --- a/viper.go +++ b/viper.go @@ -96,6 +96,16 @@ func (fnfe ConfigFileNotFoundError) Error() string { return fmt.Sprintf("Config File %q Not Found in %q", fnfe.name, fnfe.locations) } +type EnvStore interface { + Get(key string) string +} + +type RealEnvStore struct {} + +func (e *RealEnvStore) Get(key string) string { + return os.Getenv(key) +} + // Viper is a prioritized configuration registry. It // maintains a set of configuration sources, fetches // values to populate those, and provides them according @@ -163,10 +173,11 @@ type Viper struct { typeByDefValue bool onConfigChange func(fsnotify.Event) + + envStore EnvStore } -// New returns an initialized Viper instance. -func New() *Viper { +func baseNew() *Viper { v := new(Viper) v.keyDelim = "." v.configName = "config" @@ -179,10 +190,27 @@ func New() *Viper { v.env = make(map[string]string) v.aliases = make(map[string]string) v.typeByDefValue = false + v.envStore = new(RealEnvStore) return v } +// New returns an initialized Viper instance. +func New() *Viper { + v := baseNew() + v.fs = afero.NewOsFs() + v.envStore = new(RealEnvStore) + + return v +} + +func newFrom(from *Viper) *Viper { + v := baseNew() + v.fs = from.fs + v.envStore = from.envStore + return v +} + // Intended for testing, will reset all to default settings. // In the public interface for the viper package so applications // can use it in their testing as well. @@ -313,14 +341,14 @@ func (v *Viper) mergeWithEnvPrefix(in string) string { // rewriting keys many things, Ex: Get('someKey') -> some_key // (camel case to snake case for JSON keys perhaps) -// getEnv is a wrapper around os.Getenv which replaces characters in the original +// getEnv is a wrapper around EnvStore.Get which replaces characters in the original // key. This allows env vars which have different keys than the config object // keys. func (v *Viper) getEnv(key string) string { if v.envKeyReplacer != nil { key = v.envKeyReplacer.Replace(key) } - return os.Getenv(key) + return v.envStore.Get(key) } // ConfigFileUsed returns the file used to populate the config registry. @@ -332,7 +360,7 @@ func (v *Viper) ConfigFileUsed() string { return v.configFile } func AddConfigPath(in string) { v.AddConfigPath(in) } func (v *Viper) AddConfigPath(in string) { if in != "" { - absin := absPathify(in) + absin := absPathify(in, v.envStore) jww.INFO.Println("adding", absin, "to paths to search") if !stringInSlice(absin, v.configPaths) { v.configPaths = append(v.configPaths, absin) @@ -630,7 +658,7 @@ func (v *Viper) Get(key string) interface{} { // Sub is case-insensitive for a key. func Sub(key string) *Viper { return v.Sub(key) } func (v *Viper) Sub(key string) *Viper { - subv := New() + subv := newFrom(v) data := v.Get(key) if data == nil { return nil @@ -1480,6 +1508,11 @@ func (v *Viper) SetFs(fs afero.Fs) { v.fs = fs } +func SetEnvStore(store EnvStore) { v.SetEnvStore(store) } +func (v *Viper) SetEnvStore(store EnvStore) { + v.envStore = store +} + // SetConfigName sets name for the config file. // Does not include extension. func SetConfigName(in string) { v.SetConfigName(in) } diff --git a/viper_test.go b/viper_test.go index 774ca11..79a2e15 100644 --- a/viper_test.go +++ b/viper_test.go @@ -360,6 +360,50 @@ func TestRemotePrecedence(t *testing.T) { Set("newkey", "remote") } +type mapEnv struct { + m map[string]string +} + +func (e *mapEnv) Get(key string) string { + if val, ok := e.m[key]; ok { + return val + } else { + return "" + } +} + +func (e *mapEnv) Set(key, value string) { + e.m[key] = value +} + +func newMapEnv() *mapEnv { + mapEnv := new(mapEnv) + mapEnv.m = make(map[string]string) + return mapEnv +} + +func TestMockEnv(t *testing.T) { + e := newMapEnv() + + initJSON() + v.SetEnvStore(e) + + BindEnv("id") + BindEnv("f", "FOOD") + + e.Set("ID", "13") + e.Set("FOOD", "apple") + e.Set("NAME", "crunk") + + assert.Equal(t, "13", Get("id")) + assert.Equal(t, "apple", Get("f")) + assert.Equal(t, "Cake", Get("name")) + + AutomaticEnv() + + assert.Equal(t, "crunk", Get("name")) +} + func TestEnv(t *testing.T) { initJSON()