diff --git a/README.md b/README.md index c14e892..d87b3fc 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,9 @@ -> ## Viper v2 feedback -> Viper is heading towards v2 and we would love to hear what _**you**_ would like to see in it. Share your thoughts here: https://forms.gle/R6faU74qPRPAzchZ9 -> -> **Thank you!** +> ## Thread Safe Viper Fork +> This is a thread safe fork of viper. It uses a `sync.RWLock` to +> implement basic thread synchronization. Priority was given to basic +> `Get` and `Set` routines over features like configuration watching and remote +> config sources, so not *every* feature is covered -- though it should +> cover most simple use cases. ![Viper](.github/logo.png?raw=true) diff --git a/go.mod b/go.mod index 35ef234..745ad5d 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/spf13/viper +module github.com/everactive/viper go 1.17 diff --git a/remote/remote.go b/remote/remote.go index 0177288..0bceb94 100644 --- a/remote/remote.go +++ b/remote/remote.go @@ -11,9 +11,8 @@ import ( "io" "os" + "github.com/everactive/viper" crypt "github.com/sagikazarmark/crypt/config" - - "github.com/spf13/viper" ) type remoteConfigProvider struct{} diff --git a/util_test.go b/util_test.go index cb4e620..f85fc13 100644 --- a/util_test.go +++ b/util_test.go @@ -16,7 +16,7 @@ import ( "reflect" "testing" - "github.com/spf13/viper/internal/testutil" + "github.com/everactive/viper/internal/testutil" ) func TestCopyAndInsensitiviseMap(t *testing.T) { diff --git a/viper.go b/viper.go index a3812e9..655e257 100644 --- a/viper.go +++ b/viper.go @@ -34,20 +34,19 @@ import ( "sync" "time" + "github.com/everactive/viper/internal/encoding" + "github.com/everactive/viper/internal/encoding/dotenv" + "github.com/everactive/viper/internal/encoding/hcl" + "github.com/everactive/viper/internal/encoding/ini" + "github.com/everactive/viper/internal/encoding/javaproperties" + "github.com/everactive/viper/internal/encoding/json" + "github.com/everactive/viper/internal/encoding/toml" + "github.com/everactive/viper/internal/encoding/yaml" "github.com/fsnotify/fsnotify" "github.com/mitchellh/mapstructure" "github.com/spf13/afero" "github.com/spf13/cast" "github.com/spf13/pflag" - - "github.com/spf13/viper/internal/encoding" - "github.com/spf13/viper/internal/encoding/dotenv" - "github.com/spf13/viper/internal/encoding/hcl" - "github.com/spf13/viper/internal/encoding/ini" - "github.com/spf13/viper/internal/encoding/javaproperties" - "github.com/spf13/viper/internal/encoding/json" - "github.com/spf13/viper/internal/encoding/toml" - "github.com/spf13/viper/internal/encoding/yaml" ) // ConfigMarshalError happens when failing to marshal the configuration. @@ -179,6 +178,7 @@ func DecodeHook(hook mapstructure.DecodeHookFunc) DecoderConfigOption { // // Note: Vipers are not safe for concurrent Get() and Set() operations. type Viper struct { + mutex sync.RWMutex // Delimiter that separates a list of keys // used to access a nested value in one go keyDelim string @@ -864,6 +864,8 @@ func (v *Viper) isPathShadowedInAutoEnv(path []string) string { func SetTypeByDefaultValue(enable bool) { v.SetTypeByDefaultValue(enable) } func (v *Viper) SetTypeByDefaultValue(enable bool) { + v.mutex.Lock() + defer v.mutex.Unlock() v.typeByDefValue = enable } @@ -882,6 +884,8 @@ func GetViper() *Viper { func Get(key string) interface{} { return v.Get(key) } func (v *Viper) Get(key string) interface{} { + v.mutex.RLock() + defer v.mutex.RUnlock() lcaseKey := strings.ToLower(key) val := v.find(lcaseKey, true) if val == nil { @@ -1166,6 +1170,8 @@ func (v *Viper) BindFlagValues(flags FlagValueSet) (err error) { func BindFlagValue(key string, flag FlagValue) error { return v.BindFlagValue(key, flag) } func (v *Viper) BindFlagValue(key string, flag FlagValue) error { + v.mutex.Lock() + defer v.mutex.Unlock() if flag == nil { return fmt.Errorf("flag for %q is nil", key) } @@ -1182,6 +1188,8 @@ func (v *Viper) BindFlagValue(key string, flag FlagValue) error { func BindEnv(input ...string) error { return v.BindEnv(input...) } func (v *Viper) BindEnv(input ...string) error { + v.mutex.Lock() + defer v.mutex.Unlock() if len(input) == 0 { return fmt.Errorf("missing key to bind to") } @@ -1391,6 +1399,8 @@ func stringToStringConv(val string) interface{} { func IsSet(key string) bool { return v.IsSet(key) } func (v *Viper) IsSet(key string) bool { + v.mutex.RLock() + defer v.mutex.RUnlock() lcaseKey := strings.ToLower(key) val := v.find(lcaseKey, false) return val != nil @@ -1401,6 +1411,8 @@ func (v *Viper) IsSet(key string) bool { func AutomaticEnv() { v.AutomaticEnv() } func (v *Viper) AutomaticEnv() { + v.mutex.Lock() + defer v.mutex.Unlock() v.automaticEnvApplied = true } @@ -1410,6 +1422,8 @@ func (v *Viper) AutomaticEnv() { func SetEnvKeyReplacer(r *strings.Replacer) { v.SetEnvKeyReplacer(r) } func (v *Viper) SetEnvKeyReplacer(r *strings.Replacer) { + v.mutex.Lock() + defer v.mutex.Unlock() v.envKeyReplacer = r } @@ -1418,6 +1432,8 @@ func (v *Viper) SetEnvKeyReplacer(r *strings.Replacer) { func RegisterAlias(alias string, key string) { v.RegisterAlias(alias, key) } func (v *Viper) RegisterAlias(alias string, key string) { + v.mutex.Lock() + defer v.mutex.Unlock() v.registerAlias(alias, strings.ToLower(key)) } @@ -1467,8 +1483,10 @@ func (v *Viper) realKey(key string) string { func InConfig(key string) bool { return v.InConfig(key) } func (v *Viper) InConfig(key string) bool { - lcaseKey := strings.ToLower(key) + v.mutex.RLock() + defer v.mutex.RUnlock() + lcaseKey := strings.ToLower(key) // if the requested key is an alias, then return the proper key lcaseKey = v.realKey(lcaseKey) path := strings.Split(lcaseKey, v.keyDelim) @@ -1482,6 +1500,8 @@ func (v *Viper) InConfig(key string) bool { func SetDefault(key string, value interface{}) { v.SetDefault(key, value) } func (v *Viper) SetDefault(key string, value interface{}) { + v.mutex.Lock() + defer v.mutex.Unlock() // If alias passed in, then set the proper default key = v.realKey(strings.ToLower(key)) value = toCaseInsensitiveValue(value) @@ -1501,6 +1521,8 @@ func (v *Viper) SetDefault(key string, value interface{}) { func Set(key string, value interface{}) { v.Set(key, value) } func (v *Viper) Set(key string, value interface{}) { + v.mutex.Lock() + defer v.mutex.Unlock() // If alias passed in, then set the proper override key = v.realKey(strings.ToLower(key)) value = toCaseInsensitiveValue(value) @@ -1518,6 +1540,8 @@ func (v *Viper) Set(key string, value interface{}) { func ReadInConfig() error { return v.ReadInConfig() } func (v *Viper) ReadInConfig() error { + v.mutex.Lock() + defer v.mutex.Unlock() v.logger.Info("attempting to read in config file") filename, err := v.getConfigFile() if err != nil { @@ -1549,6 +1573,8 @@ func (v *Viper) ReadInConfig() error { func MergeInConfig() error { return v.MergeInConfig() } func (v *Viper) MergeInConfig() error { + v.mutex.Lock() + defer v.mutex.Unlock() v.logger.Info("attempting to merge in config file") filename, err := v.getConfigFile() if err != nil { @@ -1572,6 +1598,8 @@ func (v *Viper) MergeInConfig() error { func ReadConfig(in io.Reader) error { return v.ReadConfig(in) } func (v *Viper) ReadConfig(in io.Reader) error { + v.mutex.Lock() + defer v.mutex.Unlock() v.config = make(map[string]interface{}) return v.unmarshalReader(in, v.config) } @@ -1592,6 +1620,8 @@ func (v *Viper) MergeConfig(in io.Reader) error { func MergeConfigMap(cfg map[string]interface{}) error { return v.MergeConfigMap(cfg) } func (v *Viper) MergeConfigMap(cfg map[string]interface{}) error { + v.mutex.Lock() + defer v.mutex.Unlock() if v.config == nil { v.config = make(map[string]interface{}) } @@ -1604,6 +1634,8 @@ func (v *Viper) MergeConfigMap(cfg map[string]interface{}) error { func WriteConfig() error { return v.WriteConfig() } func (v *Viper) WriteConfig() error { + v.mutex.RLock() + defer v.mutex.RUnlock() filename, err := v.getConfigFile() if err != nil { return err @@ -1615,6 +1647,8 @@ func (v *Viper) WriteConfig() error { func SafeWriteConfig() error { return v.SafeWriteConfig() } func (v *Viper) SafeWriteConfig() error { + v.mutex.RLock() + defer v.mutex.RUnlock() if len(v.configPaths) < 1 { return errors.New("missing configuration for 'configPath'") } @@ -1625,6 +1659,8 @@ func (v *Viper) SafeWriteConfig() error { func WriteConfigAs(filename string) error { return v.WriteConfigAs(filename) } func (v *Viper) WriteConfigAs(filename string) error { + v.mutex.RLock() + defer v.mutex.RUnlock() return v.writeConfig(filename, true) } @@ -1632,6 +1668,8 @@ func (v *Viper) WriteConfigAs(filename string) error { func SafeWriteConfigAs(filename string) error { return v.SafeWriteConfigAs(filename) } func (v *Viper) SafeWriteConfigAs(filename string) error { + v.mutex.RLock() + defer v.mutex.RUnlock() alreadyExists, err := afero.Exists(v.fs, filename) if alreadyExists && err == nil { return ConfigFileAlreadyExistsError(filename) @@ -1938,6 +1976,8 @@ func (v *Viper) watchRemoteConfig(provider RemoteProvider) (map[string]interface func AllKeys() []string { return v.AllKeys() } func (v *Viper) AllKeys() []string { + v.mutex.RLock() + defer v.mutex.RUnlock() m := map[string]bool{} // add all paths, by order of descending priority to ensure correct shadowing m = v.flattenAndMergeMap(m, castMapStringToMapInterface(v.aliases), "") @@ -2019,6 +2059,8 @@ outer: func AllSettings() map[string]interface{} { return v.AllSettings() } func (v *Viper) AllSettings() map[string]interface{} { + v.mutex.RLock() + defer v.mutex.RUnlock() m := map[string]interface{}{} // start from the list of keys, and construct the map one value at a time for _, k := range v.AllKeys() { @@ -2114,6 +2156,8 @@ func (v *Viper) getConfigFile() (string, error) { func Debug() { v.Debug() } func (v *Viper) Debug() { + v.mutex.RLock() + defer v.mutex.RUnlock() fmt.Printf("Aliases:\n%#v\n", v.aliases) fmt.Printf("Override:\n%#v\n", v.override) fmt.Printf("PFlags:\n%#v\n", v.pflags) diff --git a/viper_test.go b/viper_test.go index c41a1e7..654b648 100644 --- a/viper_test.go +++ b/viper_test.go @@ -22,6 +22,7 @@ import ( "testing" "time" + "github.com/everactive/viper/internal/testutil" "github.com/fsnotify/fsnotify" "github.com/mitchellh/mapstructure" "github.com/spf13/afero" @@ -29,8 +30,6 @@ import ( "github.com/spf13/pflag" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - "github.com/spf13/viper/internal/testutil" ) // var yamlExample = []byte(`Hacker: true