diff --git a/viper.go b/viper.go index 11f3a77..ea7d084 100644 --- a/viper.go +++ b/viper.go @@ -50,7 +50,7 @@ type UnsupportedConfigError string // Returns the formatted configuration error. func (str UnsupportedConfigError) Error() string { - return fmt.Sprintf("Unsupported Config Type %q", string(str)) + return fmt.Sprintf("Unsupported Config Type %q or Config File Not Found", string(str)) } // Denotes encountering an unsupported remote @@ -694,33 +694,42 @@ func (v *Viper) Set(key string, value interface{}) { func ReadInConfig() error { return v.ReadInConfig() } func (v *Viper) ReadInConfig() error { jww.INFO.Println("Attempting to read in config file") - if !stringInSlice(v.getConfigType(), SupportedExts) { - return UnsupportedConfigError(v.getConfigType()) + + cfgType, err := v.getConfigType() + if err != nil { + return err } - file, err := ioutil.ReadFile(v.getConfigFile()) + if !stringInSlice(cfgType, SupportedExts) { + return UnsupportedConfigError(cfgType) + } + + fname, err := v.getConfigFile() + if err != nil { + return err + } + + file, err := ioutil.ReadFile(fname) if err != nil { return err } v.config = make(map[string]interface{}) - v.marshalReader(bytes.NewReader(file), v.config) - return nil + return v.marshalReader(bytes.NewReader(file), v.config) } func ReadConfig(in io.Reader) error { return v.ReadConfig(in) } + func (v *Viper) ReadConfig(in io.Reader) error { v.config = make(map[string]interface{}) - v.marshalReader(in, v.config) - return nil + return v.marshalReader(in, v.config) } // func ReadBufConfig(buf *bytes.Buffer) error { return v.ReadBufConfig(buf) } // func (v *Viper) ReadBufConfig(buf *bytes.Buffer) error { // v.config = make(map[string]interface{}) -// v.marshalReader(buf, v.config) -// return nil +// return v.marshalReader(buf, v.config) // } // Attempts to get configuration from a remote source @@ -745,9 +754,17 @@ func (v *Viper) WatchRemoteConfig() error { // Marshall a Reader into a map // Should probably be an unexported function -func marshalReader(in io.Reader, c map[string]interface{}) { v.marshalReader(in, c) } -func (v *Viper) marshalReader(in io.Reader, c map[string]interface{}) { - marshallConfigReader(in, c, v.getConfigType()) +func marshalReader(in io.Reader, c map[string]interface{}) error { + return v.marshalReader(in, c) +} + +func (v *Viper) marshalReader(in io.Reader, c map[string]interface{}) error { + cfgType, err := v.getConfigType() + if err != nil { + return err + } + marshallConfigReader(in, c, cfgType) + return nil } func (v *Viper) insensitiviseMaps() { @@ -800,7 +817,7 @@ func (v *Viper) getRemoteConfig(provider *remoteProvider) (map[string]interface{ return nil, err } reader := bytes.NewReader(b) - v.marshalReader(reader, v.kvstore) + err = v.marshalReader(reader, v.kvstore) return v.kvstore, err } @@ -850,7 +867,7 @@ func (v *Viper) watchRemoteConfig(provider *remoteProvider) (map[string]interfac } reader := bytes.NewReader(resp.Value) - v.marshalReader(reader, v.kvstore) + err = v.marshalReader(reader, v.kvstore) return v.kvstore, err } @@ -912,34 +929,35 @@ func (v *Viper) SetConfigType(in string) { } } -func (v *Viper) getConfigType() string { +func (v *Viper) getConfigType() (string, error) { if v.configType != "" { - return v.configType + return v.configType, nil + } + + cf, err := v.getConfigFile() + if err != nil { + return "", err } - cf := v.getConfigFile() ext := filepath.Ext(cf) if len(ext) > 1 { - return ext[1:] + return ext[1:], nil } else { - return "" + return "", fmt.Errorf("Missing config file extension in %q", cf) } } -func (v *Viper) getConfigFile() string { +func (v *Viper) getConfigFile() (string, error) { + var err error + // if explicitly set, then use it if v.configFile != "" { - return v.configFile + return v.configFile, nil } - cf, err := v.findConfigFile() - if err != nil { - return "" - } - - v.configFile = cf - return v.getConfigFile() + v.configFile, err = v.findConfigFile() + return v.configFile, err } func (v *Viper) searchInPath(in string) (filename string) { @@ -955,9 +973,22 @@ func (v *Viper) searchInPath(in string) (filename string) { return "" } -// search all configPaths for any config file. -// Returns the first path that exists (and is a config file) +// Choose where to look for a config file: either +// in provided directories or in the working directory func (v *Viper) findConfigFile() (string, error) { + + if len(v.configPaths) > 0 { + return v.findConfigInPaths() + } else { + return v.findConfigInWD() + } + +} + +// Search all configPaths for any config file. +// Returns the first path that exists (and has a config file) +func (v *Viper) findConfigInPaths() (string, error) { + jww.INFO.Println("Searching for config in ", v.configPaths) for _, cp := range v.configPaths { @@ -966,14 +997,20 @@ func (v *Viper) findConfigFile() (string, error) { return file, nil } } + return "", fmt.Errorf("config file not found in: %s", v.configPaths) +} + +// Search the current working directory for any config file. +func (v *Viper) findConfigInWD() (string, error) { - // try the current working directory wd, _ := os.Getwd() + jww.INFO.Println("Searching for config in ", wd) + file := v.searchInPath(wd) if file != "" { return file, nil } - return "", fmt.Errorf("config file not found in: %s", v.configPaths) + return "", fmt.Errorf("config file not found in current working directory", v.configPaths) } // Prints all configuration registries for debugging diff --git a/viper_test.go b/viper_test.go index 7ad0245..a03b14e 100644 --- a/viper_test.go +++ b/viper_test.go @@ -147,7 +147,9 @@ func (s *stringValue) String() string { func TestBasics(t *testing.T) { SetConfigFile("/tmp/config.yaml") - assert.Equal(t, "/tmp/config.yaml", v.getConfigFile()) + cf, err := v.getConfigFile() + assert.Nil(t, err) + assert.Equal(t, "/tmp/config.yaml", cf) } func TestDefault(t *testing.T) {