From 84f29467123ad09f28696ed373d22e83f39e9a60 Mon Sep 17 00:00:00 2001 From: laxman vallandas Date: Tue, 24 Apr 2018 17:47:31 +0530 Subject: [PATCH] viper patch to support cloudconfig lib --- viper.go | 401 +++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 288 insertions(+), 113 deletions(-) diff --git a/viper.go b/viper.go index f80570b..c04f8bf 100644 --- a/viper.go +++ b/viper.go @@ -22,24 +22,41 @@ package viper import ( "bytes" "encoding/csv" + "encoding/json" "fmt" "io" "log" "os" "path/filepath" "reflect" - "strconv" "strings" "time" + "errors" + + yaml "gopkg.in/yaml.v2" "github.com/fsnotify/fsnotify" + "github.com/hashicorp/hcl" + "github.com/hashicorp/hcl/hcl/printer" + "github.com/magiconair/properties" "github.com/mitchellh/mapstructure" + toml "github.com/pelletier/go-toml" "github.com/spf13/afero" "github.com/spf13/cast" jww "github.com/spf13/jwalterweatherman" "github.com/spf13/pflag" ) +// ConfigMarshalError happens when failing to marshal the configuration. +type ConfigMarshalError struct { + err error +} + +// Error returns the formatted configuration error. +func (e ConfigMarshalError) Error() string { + return fmt.Sprintf("While marshaling config: %s", e.err.Error()) +} + var v *Viper type RemoteResponse struct { @@ -70,8 +87,7 @@ func (str UnsupportedConfigError) Error() string { } // UnsupportedRemoteProviderError denotes encountering an unsupported remote -// provider. Currently only etcd and Consul are -// supported. +// provider. Currently only etcd and Consul are supported. type UnsupportedRemoteProviderError string // Error returns the formatted remote provider error. @@ -164,6 +180,10 @@ type Viper struct { aliases map[string]string typeByDefValue bool + // Store read properties on the object so that we can write back in order with comments. + // This will only be used if the configuration read is a properties file. + properties *properties.Properties + onConfigChange func(fsnotify.Event) } @@ -190,7 +210,7 @@ func New() *Viper { // can use it in their testing as well. func Reset() { v = New() - SupportedExts = []string{"json", "toml", "yaml", "yml", "hcl"} + SupportedExts = []string{"json", "toml", "yaml", "yml", "properties", "props", "prop", "hcl"} SupportedRemoteProviders = []string{"etcd", "consul"} } @@ -284,8 +304,8 @@ func (v *Viper) WatchConfig() { }() } -// SetConfigFile explicitly defines the path, name and extension of the config file -// Viper will use this and not check any of the config paths +// SetConfigFile explicitly defines the path, name and extension of the config file. +// Viper will use this and not check any of the config paths. func SetConfigFile(in string) { v.SetConfigFile(in) } func (v *Viper) SetConfigFile(in string) { if in != "" { @@ -294,8 +314,8 @@ func (v *Viper) SetConfigFile(in string) { } // SetEnvPrefix defines a prefix that ENVIRONMENT variables will use. -// E.g. if your prefix is "spf", the env registry -// will look for env. variables that start with "SPF_" +// E.g. if your prefix is "spf", the env registry will look for env +// variables that start with "SPF_". func SetEnvPrefix(in string) { v.SetEnvPrefix(in) } func (v *Viper) SetEnvPrefix(in string) { if in != "" { @@ -313,11 +333,11 @@ func (v *Viper) mergeWithEnvPrefix(in string) string { // TODO: should getEnv logic be moved into find(). Can generalize the use of // rewriting keys many things, Ex: Get('someKey') -> some_key -// (cammel case to snake case for JSON keys perhaps) +// (camel case to snake case for JSON keys perhaps) // getEnv is a wrapper around os.Getenv which replaces characters in the original -// key. This allows env vars which have different keys then the config object -// keys +// key. This allows env vars which have different keys than the config object +// keys. func (v *Viper) getEnv(key string) string { if v.envKeyReplacer != nil { key = v.envKeyReplacer.Replace(key) @@ -325,7 +345,7 @@ func (v *Viper) getEnv(key string) string { return os.Getenv(key) } -// ConfigFileUsed returns the file used to populate the config registry +// ConfigFileUsed returns the file used to populate the config registry. func ConfigFileUsed() string { return v.ConfigFileUsed() } func (v *Viper) ConfigFileUsed() string { return v.configFile } @@ -493,56 +513,6 @@ func (v *Viper) searchMapWithPathPrefixes(source map[string]interface{}, path [] return nil } -func (v *Viper) searchMapWithArrayPrefix(source map[string]interface{}, path []string) interface{} { - if len(path) == 0 { - return source - } - - next, ok := source[path[0]] - if ok { - // Immediate Key Value - if len(path) == 1 { - return next - } - - // Value from Nested key - switch next.(type) { - case map[interface{}]interface{}: - return v.searchMapWithArrayPrefix(cast.ToStringMap(next), path[1:]) - case map[string]interface{}: - // Type assertion is safe here since it is only reached - // if the type of `next` is the same as the type being asserted - return v.searchMapWithArrayPrefix(next.(map[string]interface{}), path[1:]) - case []interface{}: - v1 := cast.ToSlice(next) - for k2, v2 := range v1 { - if reflect.TypeOf(v2).Kind() == reflect.Map { - if _, err := strconv.ParseInt(path[1], 10, 64); err == nil { - if strconv.Itoa(k2) == path[1] { - return v.searchMapWithArrayPrefix(cast.ToStringMap(v2), path[1:]) - } - } else { - return v.searchMapWithArrayPrefix(cast.ToStringMap(v2), path[1:]) - } - } else { - if _, err := strconv.ParseInt(path[1], 10, 64); err == nil { - if strconv.Itoa(k2) == path[1] { - return v2 - } - } - } - } - default: - return nil - } - } else { - if _, err := strconv.ParseInt(path[0], 10, 64); err == nil { - return v.searchMapWithArrayPrefix(source, path[1:]) - } - } - return nil -} - // isPathShadowedInDeepMap makes sure the given path is not shadowed somewhere // on its path in the map. // e.g., if "foo.bar" has a value in the given map, it “shadows” @@ -648,32 +618,33 @@ func (v *Viper) Get(key string) interface{} { return nil } - valType := val if v.typeByDefValue { // TODO(bep) this branch isn't covered by a single test. + valType := val path := strings.Split(lcaseKey, v.keyDelim) defVal := v.searchMap(v.defaults, path) if defVal != nil { valType = defVal } + + switch valType.(type) { + case bool: + return cast.ToBool(val) + case string: + return cast.ToString(val) + case int64, int32, int16, int8, int: + return cast.ToInt(val) + case float64, float32: + return cast.ToFloat64(val) + case time.Time: + return cast.ToTime(val) + case time.Duration: + return cast.ToDuration(val) + case []string: + return cast.ToStringSlice(val) + } } - switch valType.(type) { - case bool: - return cast.ToBool(val) - case string: - return cast.ToString(val) - case int64, int32, int16, int8, int: - return cast.ToInt(val) - case float64, float32: - return cast.ToFloat64(val) - case time.Time: - return cast.ToTime(val) - case time.Duration: - return cast.ToDuration(val) - case []string: - return cast.ToStringSlice(val) - } return val } @@ -798,13 +769,16 @@ func (v *Viper) Unmarshal(rawVal interface{}) error { } // defaultDecoderConfig returns default mapsstructure.DecoderConfig with suppot -// of time.Duration values +// of time.Duration values & string slices func defaultDecoderConfig(output interface{}) *mapstructure.DecoderConfig { return &mapstructure.DecoderConfig{ Metadata: nil, Result: output, WeaklyTypedInput: true, - DecodeHook: mapstructure.StringToTimeDurationHookFunc(), + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + ), } } @@ -865,7 +839,7 @@ func (v *Viper) BindFlagValues(flags FlagValueSet) (err error) { } // BindFlagValue binds a specific key to a FlagValue. -// Example(where serverCmd is a Cobra instance): +// Example (where serverCmd is a Cobra instance): // // serverCmd.Flags().Int("port", 1138, "Port to run Application server on") // Viper.BindFlagValue("port", serverCmd.Flags().Lookup("port")) @@ -909,6 +883,7 @@ func (v *Viper) BindEnv(input ...string) error { // Viper will check to see if an alias exists first. // Note: this assumes a lower-cased key given. func (v *Viper) find(lcaseKey string) interface{} { + var ( val interface{} exists bool @@ -982,11 +957,6 @@ func (v *Viper) find(lcaseKey string) interface{} { if val != nil { return val } - val = v.searchMapWithArrayPrefix(v.config, path) - if val != nil { - return val - } - if nested && v.isPathShadowedInDeepMap(path, v.config) != "" { return nil } @@ -1171,6 +1141,7 @@ func (v *Viper) ReadInConfig() error { return UnsupportedConfigError(v.getConfigType()) } + jww.DEBUG.Println("Reading file: ", filename) file, err := afero.ReadFile(v.fs, filename) if err != nil { return err @@ -1230,6 +1201,195 @@ func (v *Viper) MergeConfig(in io.Reader) error { return nil } +// WriteConfig writes the current configuration to a file. +func WriteConfig() error { return v.WriteConfig() } +func (v *Viper) WriteConfig() error { + filename, err := v.getConfigFile() + if err != nil { + return err + } + return v.writeConfig(filename, true) +} + +// SafeWriteConfig writes current configuration to file only if the file does not exist. +func SafeWriteConfig() error { return v.SafeWriteConfig() } +func (v *Viper) SafeWriteConfig() error { + filename, err := v.getConfigFile() + if err != nil { + return err + } + return v.writeConfig(filename, false) +} + +// WriteConfigAs writes current configuration to a given filename. +func WriteConfigAs(filename string) error { return v.WriteConfigAs(filename) } +func (v *Viper) WriteConfigAs(filename string) error { + return v.writeConfig(filename, true) +} + +// SafeWriteConfigAs writes current configuration to a given filename if it does not exist. +func SafeWriteConfigAs(filename string) error { return v.SafeWriteConfigAs(filename) } +func (v *Viper) SafeWriteConfigAs(filename string) error { + return v.writeConfig(filename, false) +} + +func writeConfig(filename string, force bool) error { return v.writeConfig(filename, force) } +func (v *Viper) writeConfig(filename string, force bool) error { + jww.INFO.Println("Attempting to write configuration to file.") + ext := filepath.Ext(filename) + if len(ext) <= 1 { + return fmt.Errorf("Filename: %s requires valid extension.", filename) + } + configType := ext[1:] + if !stringInSlice(configType, SupportedExts) { + return UnsupportedConfigError(configType) + } + if v.config == nil { + v.config = make(map[string]interface{}) + } + var flags int + if force == true { + flags = os.O_CREATE | os.O_TRUNC | os.O_WRONLY + } else { + if _, err := os.Stat(filename); os.IsNotExist(err) { + flags = os.O_WRONLY + } else { + return fmt.Errorf("File: %s exists. Use WriteConfig to overwrite.", filename) + } + } + f, err := v.fs.OpenFile(filename, flags, os.FileMode(0644)) + if err != nil { + return err + } + return v.marshalWriter(f, configType) +} + +// Unmarshal a Reader into a map. +// Should probably be an unexported function. +func unmarshalReader(in io.Reader, c map[string]interface{}) error { + return v.unmarshalReader(in, c) +} +func (v *Viper) unmarshalReader(in io.Reader, c map[string]interface{}) error { + buf := new(bytes.Buffer) + buf.ReadFrom(in) + + switch strings.ToLower(v.getConfigType()) { + case "yaml", "yml": + if err := yaml.Unmarshal(buf.Bytes(), &c); err != nil { + return ConfigParseError{err} + } + + case "json": + if err := json.Unmarshal(buf.Bytes(), &c); err != nil { + return ConfigParseError{err} + } + + case "hcl": + obj, err := hcl.Parse(string(buf.Bytes())) + if err != nil { + return ConfigParseError{err} + } + if err = hcl.DecodeObject(&c, obj); err != nil { + return ConfigParseError{err} + } + + case "toml": + tree, err := toml.LoadReader(buf) + if err != nil { + return ConfigParseError{err} + } + tmap := tree.ToMap() + for k, v := range tmap { + c[k] = v + } + + case "properties", "props", "prop": + v.properties = properties.NewProperties() + var err error + if v.properties, err = properties.Load(buf.Bytes(), properties.UTF8); err != nil { + return ConfigParseError{err} + } + for _, key := range v.properties.Keys() { + value, _ := v.properties.Get(key) + // recursively build nested maps + path := strings.Split(key, ".") + lastKey := strings.ToLower(path[len(path)-1]) + deepestMap := deepSearch(c, path[0:len(path)-1]) + // set innermost value + deepestMap[lastKey] = value + } + } + + insensitiviseMap(c) + return nil +} + +// Marshal a map into Writer. +func marshalWriter(f afero.File, configType string) error { + return v.marshalWriter(f, configType) +} +func (v *Viper) marshalWriter(f afero.File, configType string) error { + c := v.AllSettings() + switch configType { + case "json": + b, err := json.MarshalIndent(c, "", " ") + if err != nil { + return ConfigMarshalError{err} + } + _, err = f.WriteString(string(b)) + if err != nil { + return ConfigMarshalError{err} + } + + case "hcl": + b, err := json.Marshal(c) + ast, err := hcl.Parse(string(b)) + if err != nil { + return ConfigMarshalError{err} + } + err = printer.Fprint(f, ast.Node) + if err != nil { + return ConfigMarshalError{err} + } + + case "prop", "props", "properties": + if v.properties == nil { + v.properties = properties.NewProperties() + } + p := v.properties + for _, key := range v.AllKeys() { + _, _, err := p.Set(key, v.GetString(key)) + if err != nil { + return ConfigMarshalError{err} + } + } + _, err := p.WriteComment(f, "#", properties.UTF8) + if err != nil { + return ConfigMarshalError{err} + } + + case "toml": + t, err := toml.TreeFromMap(c) + if err != nil { + return ConfigMarshalError{err} + } + s := t.String() + if _, err := f.WriteString(s); err != nil { + return ConfigMarshalError{err} + } + + case "yaml", "yml": + b, err := yaml.Marshal(c) + if err != nil { + return ConfigMarshalError{err} + } + if _, err = f.WriteString(string(b)); err != nil { + return ConfigMarshalError{err} + } + } + return nil +} + func keyExists(k string, m map[string]interface{}) string { lk := strings.ToLower(k) for mk := range m { @@ -1338,18 +1498,8 @@ func (v *Viper) WatchRemoteConfig() error { return v.watchKeyValueConfig() } -func (v *Viper) WatchRemoteConfigOnChannel() error { - return v.watchKeyValueConfigOnChannel() -} - -// Unmarshall a Reader into a map. -// Should probably be an unexported function. -func unmarshalReader(in io.Reader, c map[string]interface{}) error { - return v.unmarshalReader(in, c) -} - -func (v *Viper) unmarshalReader(in io.Reader, c map[string]interface{}) error { - return unmarshallConfigReader(in, c, v.getConfigType()) +func (v *Viper) WatchRemoteConfigOnChannel(configChanged chan<- bool) error { + return v.watchKeyValueConfigOnChannel(configChanged) } func (v *Viper) insensitiviseMaps() { @@ -1385,8 +1535,20 @@ func (v *Viper) getRemoteConfig(provider RemoteProvider) (map[string]interface{} return v.kvstore, err } +// Retrieve the First Found remote configuration +func (v *Viper)GetRemoteConf() (io.Reader, error) { + for _, rp := range v.remoteProviders { + reader, err := RemoteConfig.Get(rp) + if err != nil { + continue + } + return reader, nil + } + return nil, errors.New("Did not find Remote config") +} + // Retrieve the first found remote configuration. -func (v *Viper) watchKeyValueConfigOnChannel() error { +func (v *Viper) watchKeyValueConfigOnChannel(configChanged chan<- bool) error { for _, rp := range v.remoteProviders { respc, _ := RemoteConfig.WatchChannel(rp) //Todo: Add quit channel @@ -1395,6 +1557,7 @@ func (v *Viper) watchKeyValueConfigOnChannel() error { b := <-rc reader := bytes.NewReader(b.Value) v.unmarshalReader(reader, v.kvstore) + configChanged <- true } }(respc) return nil @@ -1551,6 +1714,11 @@ func (v *Viper) SetConfigType(in string) { } } +// GetConfigType gets the type of configuration file used by current viper object +func (v *Viper) GetConfigType() string { + return v.getConfigType() +} + func (v *Viper) getConfigType() string { if v.configType != "" { return v.configType @@ -1570,26 +1738,34 @@ func (v *Viper) getConfigType() string { return "" } +// GetAppConfigFile gets the filename of local configuration file used +func (v *Viper) GetAppConfigFile() (string, error) { + if v.configFile == "" { + cf, err := v.findConfigFile() + if err != nil { + return "", err + } + v.configFile = cf + } + return v.configFile, nil +} + func (v *Viper) getConfigFile() (string, error) { - // if explicitly set, then use it - if v.configFile != "" { - return v.configFile, nil + if v.configFile == "" { + cf, err := v.findConfigFile() + if err != nil { + return "", err + } + v.configFile = cf } - - cf, err := v.findConfigFile() - if err != nil { - return "", err - } - - v.configFile = cf - return v.getConfigFile() + return v.configFile, nil } func (v *Viper) searchInPath(in string) (filename string) { jww.DEBUG.Println("Searching for config in ", in) for _, ext := range SupportedExts { jww.DEBUG.Println("Checking for", filepath.Join(in, v.configName+"."+ext)) - if b, _ := exists(filepath.Join(in, v.configName+"."+ext)); b { + if b, _ := exists(v.fs, filepath.Join(in, v.configName+"."+ext)); b { jww.DEBUG.Println("Found: ", filepath.Join(in, v.configName+"."+ext)) return filepath.Join(in, v.configName+"."+ext) } @@ -1601,7 +1777,6 @@ func (v *Viper) searchInPath(in string) (filename string) { // Search all configPaths for any config file. // Returns the first path that exists (and is a config file). func (v *Viper) findConfigFile() (string, error) { - jww.INFO.Println("Searching for config in ", v.configPaths) for _, cp := range v.configPaths {