From 81b4c394890c683afeeb8fd02cf2e38c311ac2ee Mon Sep 17 00:00:00 2001 From: Thom Dixon Date: Sun, 12 Apr 2020 16:19:23 -0700 Subject: [PATCH] Implement support for config includes/imports This adds support for "importing" or "including" configuration files by specifying a config key which, when encountered, is to be treated as a list of config files to include and be merged into the source config file, creating a config hierarchy. This allows developers to have e.g., a "base" configuration which is then imported by their "beta" config or "production" config to be overridden or extended. One can call `SetConfigIncludeKey("includes")` and then include config files in their e.g., YAML config as follows: ```yaml includes: - "base" ``` --- viper.go | 141 ++++++++++++++++++++++++++++++++++++++++++-------- viper_test.go | 55 ++++++++++++++++++++ 2 files changed, 173 insertions(+), 23 deletions(-) diff --git a/viper.go b/viper.go index 46b1a85..8a8d7a3 100644 --- a/viper.go +++ b/viper.go @@ -51,6 +51,10 @@ import ( "github.com/spf13/viper/internal/encoding/yaml" ) +// includeMaxDepth is the maximum depth of recursion we permit for +// traversing included filepaths. +const includeMaxDepth = 20 + // ConfigMarshalError happens when failing to marshal the configuration. type ConfigMarshalError struct { err error @@ -234,6 +238,7 @@ type Viper struct { configFile string configType string configPermissions os.FileMode + configIncludeKey string envPrefix string // Specific commands for ini parsing @@ -264,6 +269,7 @@ func New() *Viper { v := new(Viper) v.keyDelim = "." v.configName = "config" + v.configIncludeKey = "" v.configPermissions = os.FileMode(0644) v.fs = afero.NewOsFs() v.config = make(map[string]interface{}) @@ -449,6 +455,16 @@ func (v *Viper) WatchConfig() { initWG.Wait() // make sure that the go routine above fully ended before returning } +// SetConfigIncludeKey defines the config key value used for including parent config +// files. If the value is not an absolute path, Viper will scan all of the config +// paths trying to locate the imported file, and then begin resolving that config +// file recursively. +func SetConfigIncludeKey(in string) { v.SetConfigIncludeKey(in) } + +func (v *Viper) SetConfigIncludeKey(in string) { + v.configIncludeKey = in +} + // SetConfigFile explicitly defines the path, name and extension of the config file. // Viper will use this and not check any of the config paths. func SetConfigFile(in string) { v.SetConfigFile(in) } @@ -1463,20 +1479,98 @@ func (v *Viper) ReadInConfig() error { return UnsupportedConfigError(v.getConfigType()) } - jww.DEBUG.Println("Reading file: ", filename) - file, err := afero.ReadFile(v.fs, filename) - if err != nil { - return err - } - - config := make(map[string]interface{}) - - err = v.unmarshalReader(bytes.NewReader(file), config) + config, err := v.resolveIncludes(filename) if err != nil { return err } v.config = config + + return nil +} + +// resolveIncludes will read the config file at the provided path, and if the +// configIncludeKey is set, will attempt to resolve all included configuration +// files. +func (v *Viper) resolveIncludes(in string) (map[string]interface{}, error) { + includes := []string{} + visited := make(map[string]map[string]interface{}) + if err := v.resolveIncludesHelper(in, &includes, visited, 0); err != nil { + return nil, err + } + + resolved := New() + for _, include := range includes { + if err := resolved.MergeConfigMap(visited[include]); err != nil { + return nil, err + } + } + + // we remove the include key, to ensure that writing the config works + // as desired + if v.configIncludeKey != "" { + delete(resolved.config, v.configIncludeKey) + } + + return resolved.config, nil +} + +// resolveIncludesHelper recursively attempts to resolve the included config +// files. The includes slice keeps track of the included config files in post +// order traversal, to preserve the import order. The visited map retains the +// visited paths and the included config data (to avoid unnecessary filesystem +// reads). +func (v *Viper) resolveIncludesHelper(in string, includes *[]string, visited map[string]map[string]interface{}, depth int) error { + if depth > includeMaxDepth { + return errors.New("maximum recursion depth exceeded") + } + + // if this isn't an absolute path + if !filepath.IsAbs(in) { + // and doesn't match the config file specified, try to find the + // referenced file + if v.configFile != "" && in != v.configFile { + var err error + in, err = v.findConfigFile(in) + if err != nil { + return err + } + } + } + + if _, ok := visited[in]; ok { + jww.DEBUG.Println("Already visited: ", in) + return nil + } + + jww.DEBUG.Println("Reading file: ", in) + file, err := afero.ReadFile(v.fs, in) + if err != nil { + return err + } + + config := make(map[string]interface{}) + err = v.unmarshalReader(bytes.NewReader(file), config) + if err != nil { + return err + } + + visited[in] = config + + if v.configIncludeKey != "" { + parent := New() + parent.MergeConfigMap(config) + children := parent.GetStringSlice(v.configIncludeKey) + for _, child := range children { + if err = v.resolveIncludesHelper(child, includes, visited, depth+1); err != nil { + return err + } + // Post-order traversal in order to preserve the import order. + *includes = append(*includes, child) + } + } + + *includes = append(*includes, in) return nil } @@ -1494,12 +1588,12 @@ func (v *Viper) MergeInConfig() error { return UnsupportedConfigError(v.getConfigType()) } - file, err := afero.ReadFile(v.fs, filename) + config, err := v.resolveIncludes(filename) if err != nil { return err } - return v.MergeConfig(bytes.NewReader(file)) + return v.MergeConfigMap(config) } // ReadConfig will read a configuration file, setting existing keys to nil if the @@ -2094,7 +2188,7 @@ func (v *Viper) getConfigType() string { func (v *Viper) getConfigFile() (string, error) { if v.configFile == "" { - cf, err := v.findConfigFile() + cf, err := v.findConfigFile(v.configName) if err != nil { return "", err } @@ -2103,32 +2197,33 @@ func (v *Viper) getConfigFile() (string, error) { return v.configFile, nil } -func (v *Viper) searchInPath(in string) (filename string) { - jww.DEBUG.Println("Searching for config in ", in) +// Searches for configName in path. +func (v *Viper) searchInPath(path string, configName string) (filename string) { + jww.DEBUG.Println("Searching for config in ", path) for _, ext := range SupportedExts { - jww.DEBUG.Println("Checking for", filepath.Join(in, v.configName+"."+ext)) - if b, _ := exists(v.fs, filepath.Join(in, v.configName+"."+ext)); b { - jww.DEBUG.Println("Found: ", filepath.Join(in, v.configName+"."+ext)) - return filepath.Join(in, v.configName+"."+ext) + jww.DEBUG.Println("Checking for", filepath.Join(path, configName+"."+ext)) + if b, _ := exists(v.fs, filepath.Join(path, configName+"."+ext)); b { + jww.DEBUG.Println("Found: ", filepath.Join(path, configName+"."+ext)) + return filepath.Join(path, configName+"."+ext) } } if v.configType != "" { - if b, _ := exists(v.fs, filepath.Join(in, v.configName)); b { - return filepath.Join(in, v.configName) + if b, _ := exists(v.fs, filepath.Join(path, configName)); b { + return filepath.Join(path, configName) } } return "" } -// Search all configPaths for any config file. +// Search all configPaths for any config file with the provided config name. // Returns the first path that exists (and is a config file). -func (v *Viper) findConfigFile() (string, error) { +func (v *Viper) findConfigFile(configName string) (string, error) { jww.INFO.Println("Searching for config in ", v.configPaths) for _, cp := range v.configPaths { - file := v.searchInPath(cp) + file := v.searchInPath(cp, configName) if file != "" { return file, nil } diff --git a/viper_test.go b/viper_test.go index 4192748..cbf99c1 100644 --- a/viper_test.go +++ b/viper_test.go @@ -2369,6 +2369,61 @@ func TestSliceIndexAccess(t *testing.T) { assert.Equal(t, "Static", v.GetString("tv.0.episodes.1.2")) } +var yamlIncludeRoot = []byte(` +hello: + pop: 37890 + largenum: 765432101234567 +`) + +var yamlIncludeIntermediate = []byte(` +includes: +- root + +hello: + pop: 45000 + ints: + - 1 + - 2 +fu: bar +`) + +var yamlIncludeOther = []byte(` +foo: bar +`) + +var yamlIncludeLeaf = []byte(` +includes: +- intermediate +- other + +fu: baz +`) + +func TestSetConfigIncludeKey(t *testing.T) { + root, _, cleanup := initDirs(t) + defer cleanup() + + configFile := filepath.Join(root, "leaf.yaml") + + ioutil.WriteFile(filepath.Join(root, "root.yaml"), yamlIncludeRoot, 0640) + ioutil.WriteFile(filepath.Join(root, "intermediate.yaml"), yamlIncludeIntermediate, 0640) + ioutil.WriteFile(filepath.Join(root, "other.yaml"), yamlIncludeOther, 0640) + ioutil.WriteFile(configFile, yamlIncludeLeaf, 0640) + + v := New() + v.SetConfigIncludeKey("includes") + v.SetConfigFile(configFile) + v.AddConfigPath(root) + + assert.NoError(t, v.ReadInConfig()) + assert.Equal(t, "baz", v.GetString("fu")) + assert.Equal(t, "bar", v.GetString("foo")) + assert.Equal(t, 45000, v.GetInt("hello.pop")) + assert.Equal(t, 765432101234567, v.GetInt("hello.largenum")) + assert.Equal(t, []int{1, 2}, v.GetIntSlice("hello.ints")) + assert.False(t, v.InConfig("includes")) +} + func BenchmarkGetBool(b *testing.B) { key := "BenchmarkGetBool" v = New()