From bf1cf5d900d4d0e35562ffb8877e4c34f75fbb3c Mon Sep 17 00:00:00 2001 From: John Stevens <98547427+jstevens-vorto@users.noreply.github.com> Date: Mon, 13 Mar 2023 09:23:46 -0600 Subject: [PATCH] Viper map lock (#8) * Add RWMutex protecting viper maps * Update README --- README.md | 6 +++++- viper.go | 28 +++++++++++++++++++++++++++- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index cd39290..093fdc9 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,8 @@ -> ## Viper v2 feedback +# Fork of spf13/viper that makes it safe for concurrent access + +Forked from version 1.9.0 of `spf13/viper`. Adds an `RWMutex` to the `viper` struct that is used to protect reads and writes to the struct's internal maps. This should stop some our servers' intermittent crashing. + +> > ## Viper v2 feedback > Viper is heading towards v2 and we would love to hear what _**you**_ would like to see in it. Share your thoughts here: https://forms.gle/R6faU74qPRPAzchZ9 > > **Thank you!** diff --git a/viper.go b/viper.go index fa6f3e3..199b6b8 100644 --- a/viper.go +++ b/viper.go @@ -223,6 +223,8 @@ type Viper struct { // TODO: should probably be protected with a mutex encoderRegistry *encoding.EncoderRegistry decoderRegistry *encoding.DecoderRegistry + + mapLock sync.RWMutex } // New returns an initialized Viper instance. @@ -243,6 +245,8 @@ func New() *Viper { v.typeByDefValue = false v.logger = jwwLogger{} + v.mapLock = sync.RWMutex{} + v.resetEncoding() return v @@ -658,7 +662,9 @@ func (v *Viper) searchMap(source map[string]interface{}, path []string) interfac return source } + v.mapLock.RLock() next, ok := source[path[0]] + v.mapLock.RUnlock() if ok { // Fast path if len(path) == 1 { @@ -763,7 +769,9 @@ func (v *Viper) searchMapWithPathPrefixes( pathIndex int, path []string, ) interface{} { + v.mapLock.RLock() next, ok := sourceMap[prefixKey] + v.mapLock.RUnlock() if !ok { return nil } @@ -832,7 +840,10 @@ func (v *Viper) isPathShadowedInFlatMap(path []string, mi interface{}) string { var parentKey string for i := 1; i < len(path); i++ { parentKey = strings.Join(path[0:i], v.keyDelim) - if _, ok := m[parentKey]; ok { + v.mapLock.RLock() + _, ok := m[parentKey] + v.mapLock.RUnlock() + if ok { return parentKey } } @@ -1456,9 +1467,13 @@ func (v *Viper) RegisterAlias(alias string, key string) { func (v *Viper) registerAlias(alias string, key string) { alias = strings.ToLower(alias) if alias != key && alias != v.realKey(key) { + v.mapLock.RLock() _, exists := v.aliases[alias] + v.mapLock.RUnlock() if !exists { + v.mapLock.Lock() + // if we alias something that exists in one of the maps to another // name, we'll never be able to get that value using the original // name, so move the config value to the new realkey. @@ -1479,6 +1494,8 @@ func (v *Viper) registerAlias(alias string, key string) { v.override[key] = val } v.aliases[alias] = key + + v.mapLock.Unlock() } } else { v.logger.Warn("creating circular reference alias", "alias", alias, "key", key, "real_key", v.realKey(key)) @@ -1523,7 +1540,9 @@ func (v *Viper) SetDefault(key string, value interface{}) { deepestMap := deepSearch(v.defaults, path[0:len(path)-1]) // set innermost value + v.mapLock.Lock() deepestMap[lastKey] = value + v.mapLock.Unlock() } // Set sets the value for the key in the override register. @@ -1542,7 +1561,9 @@ func (v *Viper) Set(key string, value interface{}) { deepestMap := deepSearch(v.override, path[0:len(path)-1]) // set innermost value + v.mapLock.Lock() deepestMap[lastKey] = value + v.mapLock.Unlock() } // ReadInConfig will discover and load the configuration file from disk @@ -1627,8 +1648,10 @@ func (v *Viper) MergeConfigMap(cfg map[string]interface{}) error { if v.config == nil { v.config = make(map[string]interface{}) } + v.mapLock.Lock() insensitiviseMap(cfg) mergeMaps(cfg, v.config, nil) + v.mapLock.Unlock() return nil } @@ -1716,6 +1739,9 @@ func unmarshalReader(in io.Reader, c map[string]interface{}) error { } func (v *Viper) unmarshalReader(in io.Reader, c map[string]interface{}) error { + v.mapLock.Lock() + defer v.mapLock.Unlock() + buf := new(bytes.Buffer) buf.ReadFrom(in)