Implement support for config includes/imports

This adds support for "importing" or "including" configuration
files by specifying a config key which, when encountered, is to be
treated as a list of config files to include and be merged into the
source config file, creating a config hierarchy.

This allows developers to have e.g., a "base" configuration which
is then imported by their "beta" config or "production" config to
be overridden or extended.

One can call `SetConfigIncludeKey("includes")` and then
include config files in their e.g., YAML config as follows:

```yaml
includes:
  - "base"
```
This commit is contained in:
Thom Dixon 2020-04-12 16:19:23 -07:00
parent a7cfd8b8e0
commit 81b4c39489
2 changed files with 173 additions and 23 deletions

141
viper.go
View file

@ -51,6 +51,10 @@ import (
"github.com/spf13/viper/internal/encoding/yaml"
)
// includeMaxDepth is the maximum depth of recursion we permit for
// traversing included filepaths.
const includeMaxDepth = 20
// ConfigMarshalError happens when failing to marshal the configuration.
type ConfigMarshalError struct {
err error
@ -234,6 +238,7 @@ type Viper struct {
configFile string
configType string
configPermissions os.FileMode
configIncludeKey string
envPrefix string
// Specific commands for ini parsing
@ -264,6 +269,7 @@ func New() *Viper {
v := new(Viper)
v.keyDelim = "."
v.configName = "config"
v.configIncludeKey = ""
v.configPermissions = os.FileMode(0644)
v.fs = afero.NewOsFs()
v.config = make(map[string]interface{})
@ -449,6 +455,16 @@ func (v *Viper) WatchConfig() {
initWG.Wait() // make sure that the go routine above fully ended before returning
}
// SetConfigIncludeKey defines the config key value used for including parent config
// files. If the value is not an absolute path, Viper will scan all of the config
// paths trying to locate the imported file, and then begin resolving that config
// file recursively.
func SetConfigIncludeKey(in string) { v.SetConfigIncludeKey(in) }
func (v *Viper) SetConfigIncludeKey(in string) {
v.configIncludeKey = in
}
// SetConfigFile explicitly defines the path, name and extension of the config file.
// Viper will use this and not check any of the config paths.
func SetConfigFile(in string) { v.SetConfigFile(in) }
@ -1463,20 +1479,98 @@ func (v *Viper) ReadInConfig() error {
return UnsupportedConfigError(v.getConfigType())
}
jww.DEBUG.Println("Reading file: ", filename)
file, err := afero.ReadFile(v.fs, filename)
if err != nil {
return err
}
config := make(map[string]interface{})
err = v.unmarshalReader(bytes.NewReader(file), config)
config, err := v.resolveIncludes(filename)
if err != nil {
return err
}
v.config = config
return nil
}
// resolveIncludes will read the config file at the provided path, and if the
// configIncludeKey is set, will attempt to resolve all included configuration
// files.
func (v *Viper) resolveIncludes(in string) (map[string]interface{}, error) {
includes := []string{}
visited := make(map[string]map[string]interface{})
if err := v.resolveIncludesHelper(in, &includes, visited, 0); err != nil {
return nil, err
}
resolved := New()
for _, include := range includes {
if err := resolved.MergeConfigMap(visited[include]); err != nil {
return nil, err
}
}
// we remove the include key, to ensure that writing the config works
// as desired
if v.configIncludeKey != "" {
delete(resolved.config, v.configIncludeKey)
}
return resolved.config, nil
}
// resolveIncludesHelper recursively attempts to resolve the included config
// files. The includes slice keeps track of the included config files in post
// order traversal, to preserve the import order. The visited map retains the
// visited paths and the included config data (to avoid unnecessary filesystem
// reads).
func (v *Viper) resolveIncludesHelper(in string, includes *[]string, visited map[string]map[string]interface{}, depth int) error {
if depth > includeMaxDepth {
return errors.New("maximum recursion depth exceeded")
}
// if this isn't an absolute path
if !filepath.IsAbs(in) {
// and doesn't match the config file specified, try to find the
// referenced file
if v.configFile != "" && in != v.configFile {
var err error
in, err = v.findConfigFile(in)
if err != nil {
return err
}
}
}
if _, ok := visited[in]; ok {
jww.DEBUG.Println("Already visited: ", in)
return nil
}
jww.DEBUG.Println("Reading file: ", in)
file, err := afero.ReadFile(v.fs, in)
if err != nil {
return err
}
config := make(map[string]interface{})
err = v.unmarshalReader(bytes.NewReader(file), config)
if err != nil {
return err
}
visited[in] = config
if v.configIncludeKey != "" {
parent := New()
parent.MergeConfigMap(config)
children := parent.GetStringSlice(v.configIncludeKey)
for _, child := range children {
if err = v.resolveIncludesHelper(child, includes, visited, depth+1); err != nil {
return err
}
// Post-order traversal in order to preserve the import order.
*includes = append(*includes, child)
}
}
*includes = append(*includes, in)
return nil
}
@ -1494,12 +1588,12 @@ func (v *Viper) MergeInConfig() error {
return UnsupportedConfigError(v.getConfigType())
}
file, err := afero.ReadFile(v.fs, filename)
config, err := v.resolveIncludes(filename)
if err != nil {
return err
}
return v.MergeConfig(bytes.NewReader(file))
return v.MergeConfigMap(config)
}
// ReadConfig will read a configuration file, setting existing keys to nil if the
@ -2094,7 +2188,7 @@ func (v *Viper) getConfigType() string {
func (v *Viper) getConfigFile() (string, error) {
if v.configFile == "" {
cf, err := v.findConfigFile()
cf, err := v.findConfigFile(v.configName)
if err != nil {
return "", err
}
@ -2103,32 +2197,33 @@ func (v *Viper) getConfigFile() (string, error) {
return v.configFile, nil
}
func (v *Viper) searchInPath(in string) (filename string) {
jww.DEBUG.Println("Searching for config in ", in)
// Searches for configName in path.
func (v *Viper) searchInPath(path string, configName string) (filename string) {
jww.DEBUG.Println("Searching for config in ", path)
for _, ext := range SupportedExts {
jww.DEBUG.Println("Checking for", filepath.Join(in, v.configName+"."+ext))
if b, _ := exists(v.fs, filepath.Join(in, v.configName+"."+ext)); b {
jww.DEBUG.Println("Found: ", filepath.Join(in, v.configName+"."+ext))
return filepath.Join(in, v.configName+"."+ext)
jww.DEBUG.Println("Checking for", filepath.Join(path, configName+"."+ext))
if b, _ := exists(v.fs, filepath.Join(path, configName+"."+ext)); b {
jww.DEBUG.Println("Found: ", filepath.Join(path, configName+"."+ext))
return filepath.Join(path, configName+"."+ext)
}
}
if v.configType != "" {
if b, _ := exists(v.fs, filepath.Join(in, v.configName)); b {
return filepath.Join(in, v.configName)
if b, _ := exists(v.fs, filepath.Join(path, configName)); b {
return filepath.Join(path, configName)
}
}
return ""
}
// Search all configPaths for any config file.
// Search all configPaths for any config file with the provided config name.
// Returns the first path that exists (and is a config file).
func (v *Viper) findConfigFile() (string, error) {
func (v *Viper) findConfigFile(configName string) (string, error) {
jww.INFO.Println("Searching for config in ", v.configPaths)
for _, cp := range v.configPaths {
file := v.searchInPath(cp)
file := v.searchInPath(cp, configName)
if file != "" {
return file, nil
}

View file

@ -2369,6 +2369,61 @@ func TestSliceIndexAccess(t *testing.T) {
assert.Equal(t, "Static", v.GetString("tv.0.episodes.1.2"))
}
var yamlIncludeRoot = []byte(`
hello:
pop: 37890
largenum: 765432101234567
`)
var yamlIncludeIntermediate = []byte(`
includes:
- root
hello:
pop: 45000
ints:
- 1
- 2
fu: bar
`)
var yamlIncludeOther = []byte(`
foo: bar
`)
var yamlIncludeLeaf = []byte(`
includes:
- intermediate
- other
fu: baz
`)
func TestSetConfigIncludeKey(t *testing.T) {
root, _, cleanup := initDirs(t)
defer cleanup()
configFile := filepath.Join(root, "leaf.yaml")
ioutil.WriteFile(filepath.Join(root, "root.yaml"), yamlIncludeRoot, 0640)
ioutil.WriteFile(filepath.Join(root, "intermediate.yaml"), yamlIncludeIntermediate, 0640)
ioutil.WriteFile(filepath.Join(root, "other.yaml"), yamlIncludeOther, 0640)
ioutil.WriteFile(configFile, yamlIncludeLeaf, 0640)
v := New()
v.SetConfigIncludeKey("includes")
v.SetConfigFile(configFile)
v.AddConfigPath(root)
assert.NoError(t, v.ReadInConfig())
assert.Equal(t, "baz", v.GetString("fu"))
assert.Equal(t, "bar", v.GetString("foo"))
assert.Equal(t, 45000, v.GetInt("hello.pop"))
assert.Equal(t, 765432101234567, v.GetInt("hello.largenum"))
assert.Equal(t, []int{1, 2}, v.GetIntSlice("hello.ints"))
assert.False(t, v.InConfig("includes"))
}
func BenchmarkGetBool(b *testing.B) {
key := "BenchmarkGetBool"
v = New()