This commit is contained in:
Vlad Didenko 2015-06-22 02:50:30 +00:00
commit 489cd6f1c8
3 changed files with 204 additions and 23 deletions

21
util.go
View file

@ -28,6 +28,16 @@ import (
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
// Denotes failing to parse configuration file.
type ConfigFileParseError struct {
err error
}
// Returns the formatted configuration error.
func (pe ConfigFileParseError) Error() string {
return fmt.Sprintf("While parsing config: %s", pe.err.Error())
}
func insensitiviseMap(m map[string]interface{}) { func insensitiviseMap(m map[string]interface{}) {
for key, val := range m { for key, val := range m {
lower := strings.ToLower(key) lower := strings.ToLower(key)
@ -119,31 +129,31 @@ func findCWD() (string, error) {
return path, nil return path, nil
} }
func marshallConfigReader(in io.Reader, c map[string]interface{}, configType string) { func marshallConfigReader(in io.Reader, c map[string]interface{}, configType string) error {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
buf.ReadFrom(in) buf.ReadFrom(in)
switch strings.ToLower(configType) { switch strings.ToLower(configType) {
case "yaml", "yml": case "yaml", "yml":
if err := yaml.Unmarshal(buf.Bytes(), &c); err != nil { if err := yaml.Unmarshal(buf.Bytes(), &c); err != nil {
jww.ERROR.Fatalf("Error parsing config: %s", err) return ConfigFileParseError{err}
} }
case "json": case "json":
if err := json.Unmarshal(buf.Bytes(), &c); err != nil { if err := json.Unmarshal(buf.Bytes(), &c); err != nil {
jww.ERROR.Fatalf("Error parsing config: %s", err) return ConfigFileParseError{err}
} }
case "toml": case "toml":
if _, err := toml.Decode(buf.String(), &c); err != nil { if _, err := toml.Decode(buf.String(), &c); err != nil {
jww.ERROR.Fatalf("Error parsing config: %s", err) return ConfigFileParseError{err}
} }
case "properties", "props", "prop": case "properties", "props", "prop":
var p *properties.Properties var p *properties.Properties
var err error var err error
if p, err = properties.Load(buf.Bytes(), properties.UTF8); err != nil { if p, err = properties.Load(buf.Bytes(), properties.UTF8); err != nil {
jww.ERROR.Fatalf("Error parsing config: %s", err) return ConfigFileParseError{err}
} }
for _, key := range p.Keys() { for _, key := range p.Keys() {
value, _ := p.Get(key) value, _ := p.Get(key)
@ -152,6 +162,7 @@ func marshallConfigReader(in io.Reader, c map[string]interface{}, configType str
} }
insensitiviseMap(c) insensitiviseMap(c)
return nil
} }
func safeMul(a, b uint) uint { func safeMul(a, b uint) uint {

View file

@ -79,6 +79,16 @@ func (rce RemoteConfigError) Error() string {
return fmt.Sprintf("Remote Configurations Error: %s", string(rce)) return fmt.Sprintf("Remote Configurations Error: %s", string(rce))
} }
// Denotes failing to find configuration file.
type ConfigFileNotFoundError struct {
name, locations string
}
// Returns the formatted configuration error.
func (fnfe ConfigFileNotFoundError) Error() string {
return fmt.Sprintf("Config File %q Not Found in %q", fnfe.name, fnfe.locations)
}
// Viper is a prioritized configuration registry. It // Viper is a prioritized configuration registry. It
// maintains a set of configuration sources, fetches // maintains a set of configuration sources, fetches
// values to populate those, and provides them according // values to populate those, and provides them according
@ -522,11 +532,11 @@ func (v *Viper) BindPFlag(key string, flag *pflag.Flag) (err error) {
switch flag.Value.Type() { switch flag.Value.Type() {
case "int", "int8", "int16", "int32", "int64": case "int", "int8", "int16", "int32", "int64":
SetDefault(key, cast.ToInt(flag.Value.String())) v.SetDefault(key, cast.ToInt(flag.Value.String()))
case "bool": case "bool":
SetDefault(key, cast.ToBool(flag.Value.String())) v.SetDefault(key, cast.ToBool(flag.Value.String()))
default: default:
SetDefault(key, flag.Value.String()) v.SetDefault(key, flag.Value.String())
} }
return nil return nil
} }
@ -738,22 +748,19 @@ func (v *Viper) ReadInConfig() error {
v.config = make(map[string]interface{}) v.config = make(map[string]interface{})
v.marshalReader(bytes.NewReader(file), v.config) return v.marshalReader(bytes.NewReader(file), v.config)
return nil
} }
func ReadConfig(in io.Reader) error { return v.ReadConfig(in) } func ReadConfig(in io.Reader) error { return v.ReadConfig(in) }
func (v *Viper) ReadConfig(in io.Reader) error { func (v *Viper) ReadConfig(in io.Reader) error {
v.config = make(map[string]interface{}) v.config = make(map[string]interface{})
v.marshalReader(in, v.config) return v.marshalReader(in, v.config)
return nil
} }
// func ReadBufConfig(buf *bytes.Buffer) error { return v.ReadBufConfig(buf) } // func ReadBufConfig(buf *bytes.Buffer) error { return v.ReadBufConfig(buf) }
// func (v *Viper) ReadBufConfig(buf *bytes.Buffer) error { // func (v *Viper) ReadBufConfig(buf *bytes.Buffer) error {
// v.config = make(map[string]interface{}) // v.config = make(map[string]interface{})
// v.marshalReader(buf, v.config) // return v.marshalReader(buf, v.config)
// return nil
// } // }
// Attempts to get configuration from a remote source // Attempts to get configuration from a remote source
@ -778,9 +785,12 @@ func (v *Viper) WatchRemoteConfig() error {
// Marshall a Reader into a map // Marshall a Reader into a map
// Should probably be an unexported function // Should probably be an unexported function
func marshalReader(in io.Reader, c map[string]interface{}) { v.marshalReader(in, c) } func marshalReader(in io.Reader, c map[string]interface{}) error {
func (v *Viper) marshalReader(in io.Reader, c map[string]interface{}) { return v.marshalReader(in, c)
marshallConfigReader(in, c, v.getConfigType()) }
func (v *Viper) marshalReader(in io.Reader, c map[string]interface{}) error {
return marshallConfigReader(in, c, v.getConfigType())
} }
func (v *Viper) insensitiviseMaps() { func (v *Viper) insensitiviseMaps() {
@ -813,7 +823,7 @@ func (v *Viper) getRemoteConfig(provider *defaultRemoteProvider) (map[string]int
if err != nil { if err != nil {
return nil, err return nil, err
} }
v.marshalReader(reader, v.kvstore) err = v.marshalReader(reader, v.kvstore)
return v.kvstore, err return v.kvstore, err
} }
@ -835,7 +845,7 @@ func (v *Viper) watchRemoteConfig(provider *defaultRemoteProvider) (map[string]i
if err != nil { if err != nil {
return nil, err return nil, err
} }
v.marshalReader(reader, v.kvstore) err = v.marshalReader(reader, v.kvstore)
return v.kvstore, err return v.kvstore, err
} }
@ -940,9 +950,22 @@ func (v *Viper) searchInPath(in string) (filename string) {
return "" return ""
} }
// search all configPaths for any config file. // Choose where to look for a config file: either
// Returns the first path that exists (and is a config file) // in provided directories or in the working directory
func (v *Viper) findConfigFile() (string, error) { func (v *Viper) findConfigFile() (string, error) {
if len(v.configPaths) > 0 {
return v.findConfigInPaths()
} else {
return v.findConfigInCWD()
}
}
// Search all configPaths for any config file.
// Returns the first path that exists (and has a config file)
func (v *Viper) findConfigInPaths() (string, error) {
jww.INFO.Println("Searching for config in ", v.configPaths) jww.INFO.Println("Searching for config in ", v.configPaths)
for _, cp := range v.configPaths { for _, cp := range v.configPaths {
@ -951,14 +974,20 @@ func (v *Viper) findConfigFile() (string, error) {
return file, nil return file, nil
} }
} }
return "", ConfigFileNotFoundError{v.configName, fmt.Sprintf("%s", v.configPaths)}
}
// Search the current working directory for any config file.
func (v *Viper) findConfigInCWD() (string, error) {
// try the current working directory
wd, _ := os.Getwd() wd, _ := os.Getwd()
jww.INFO.Println("Searching for config in ", wd)
file := v.searchInPath(wd) file := v.searchInPath(wd)
if file != "" { if file != "" {
return file, nil return file, nil
} }
return "", fmt.Errorf("config file not found in: %s", v.configPaths) return "", ConfigFileNotFoundError{v.configName, wd}
} }
// Prints all configuration registries for debugging // Prints all configuration registries for debugging

View file

@ -8,7 +8,10 @@ package viper
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"io/ioutil"
"os" "os"
"path"
"reflect"
"sort" "sort"
"strings" "strings"
"testing" "testing"
@ -126,6 +129,47 @@ func initTOML() {
marshalReader(r, v.config) marshalReader(r, v.config)
} }
// make directories for testing
func initDirs(t *testing.T) (string, string, func()) {
var (
testDirs = []string{`a a`, `b`, `c\c`, `D:`}
config = `improbable`
)
root, err := ioutil.TempDir("", "")
cleanup := true
defer func() {
if cleanup {
os.Chdir("..")
os.RemoveAll(root)
}
}()
assert.Nil(t, err)
err = os.Chdir(root)
assert.Nil(t, err)
err = ioutil.WriteFile(path.Join(root, config+".toml"), []byte("key = \"root\"\n"), 0640)
assert.Nil(t, err)
for _, dir := range testDirs {
err = os.Mkdir(dir, 0750)
assert.Nil(t, err)
err = ioutil.WriteFile(path.Join(dir, config+".toml"), []byte("key = \"value is "+dir+"\"\n"), 0640)
assert.Nil(t, err)
}
cleanup = false
return root, config, func() {
os.Chdir("..")
os.RemoveAll(root)
}
}
//stubs for PFlag Values //stubs for PFlag Values
type stringValue string type stringValue string
@ -557,3 +601,100 @@ func TestReadBufConfig(t *testing.T) {
assert.Equal(t, map[interface{}]interface{}{"jacket": "leather", "trousers": "denim", "pants": map[interface{}]interface{}{"size": "large"}}, v.Get("clothing")) assert.Equal(t, map[interface{}]interface{}{"jacket": "leather", "trousers": "denim", "pants": map[interface{}]interface{}{"size": "large"}}, v.Get("clothing"))
assert.Equal(t, 35, v.Get("age")) assert.Equal(t, 35, v.Get("age"))
} }
func TestCWDSearch(t *testing.T) {
_, config, cleanup := initDirs(t)
defer cleanup()
v := New()
v.SetConfigName(config)
v.SetDefault(`key`, `default`)
err := v.ReadInConfig()
assert.Nil(t, err)
assert.Equal(t, `root`, v.GetString(`key`))
}
func TestCWDSearchNoConfig(t *testing.T) {
_, config, cleanup := initDirs(t)
defer cleanup()
// Remove the config file in CWD
os.Remove(config + ".toml")
v := New()
v.SetConfigName(config)
v.SetDefault(`key`, `default`)
err := v.ReadInConfig()
assert.Equal(t, reflect.TypeOf(UnsupportedConfigError("")), reflect.TypeOf(err))
assert.Equal(t, `default`, v.GetString(`key`))
}
func TestDirsSearch(t *testing.T) {
root, config, cleanup := initDirs(t)
defer cleanup()
v := New()
v.SetConfigName(config)
v.SetDefault(`key`, `default`)
entries, err := ioutil.ReadDir(root)
for _, e := range entries {
if e.IsDir() {
v.AddConfigPath(e.Name())
}
}
err = v.ReadInConfig()
assert.Nil(t, err)
assert.Equal(t, `value is `+path.Base(v.configPaths[0]), v.GetString(`key`))
}
func TestWrongDirsSearchNotFoundHasCWDConfig(t *testing.T) {
_, config, cleanup := initDirs(t)
defer cleanup()
v := New()
v.SetConfigName(config)
v.SetDefault(`key`, `default`)
v.AddConfigPath(`whattayoutalkingbout`)
v.AddConfigPath(`thispathaintthere`)
err := v.ReadInConfig()
assert.Equal(t, reflect.TypeOf(UnsupportedConfigError("")), reflect.TypeOf(err))
// Should not see the value "root" which comes from config in CWD
assert.Equal(t, `default`, v.GetString(`key`))
}
func TestWrongDirsSearchNotFoundNoCWDConfig(t *testing.T) {
_, config, cleanup := initDirs(t)
defer cleanup()
// Remove the config file in CWD
os.Remove(config + ".toml")
v := New()
v.SetConfigName(config)
v.SetDefault(`key`, `default`)
v.AddConfigPath(`whattayoutalkingbout`)
v.AddConfigPath(`thispathaintthere`)
err := v.ReadInConfig()
assert.Equal(t, reflect.TypeOf(UnsupportedConfigError("")), reflect.TypeOf(err))
// Even though config did not load and the error might have
// been ignored by the client, the default still loads
assert.Equal(t, `default`, v.GetString(`key`))
}