mirror of
https://github.com/spf13/viper
synced 2025-05-11 22:57:21 +00:00
Implement support for config includes/imports
This adds support for "importing" or "including" configuration files by specifying a config key which, when encountered, is to be treated as a list of config files to include and be merged into the source config file, creating a config hierarchy. This allows developers to have e.g., a "base" configuration which is then imported by their "beta" config or "production" config to be overridden or extended. One can call `SetConfigIncludeKey("includes")` and then include config files in their e.g., YAML config as follows: ```yaml includes: - "base" ```
This commit is contained in:
parent
a7cfd8b8e0
commit
81b4c39489
2 changed files with 173 additions and 23 deletions
141
viper.go
141
viper.go
|
@ -51,6 +51,10 @@ import (
|
|||
"github.com/spf13/viper/internal/encoding/yaml"
|
||||
)
|
||||
|
||||
// includeMaxDepth is the maximum depth of recursion we permit for
|
||||
// traversing included filepaths.
|
||||
const includeMaxDepth = 20
|
||||
|
||||
// ConfigMarshalError happens when failing to marshal the configuration.
|
||||
type ConfigMarshalError struct {
|
||||
err error
|
||||
|
@ -234,6 +238,7 @@ type Viper struct {
|
|||
configFile string
|
||||
configType string
|
||||
configPermissions os.FileMode
|
||||
configIncludeKey string
|
||||
envPrefix string
|
||||
|
||||
// Specific commands for ini parsing
|
||||
|
@ -264,6 +269,7 @@ func New() *Viper {
|
|||
v := new(Viper)
|
||||
v.keyDelim = "."
|
||||
v.configName = "config"
|
||||
v.configIncludeKey = ""
|
||||
v.configPermissions = os.FileMode(0644)
|
||||
v.fs = afero.NewOsFs()
|
||||
v.config = make(map[string]interface{})
|
||||
|
@ -449,6 +455,16 @@ func (v *Viper) WatchConfig() {
|
|||
initWG.Wait() // make sure that the go routine above fully ended before returning
|
||||
}
|
||||
|
||||
// SetConfigIncludeKey defines the config key value used for including parent config
|
||||
// files. If the value is not an absolute path, Viper will scan all of the config
|
||||
// paths trying to locate the imported file, and then begin resolving that config
|
||||
// file recursively.
|
||||
func SetConfigIncludeKey(in string) { v.SetConfigIncludeKey(in) }
|
||||
|
||||
func (v *Viper) SetConfigIncludeKey(in string) {
|
||||
v.configIncludeKey = in
|
||||
}
|
||||
|
||||
// SetConfigFile explicitly defines the path, name and extension of the config file.
|
||||
// Viper will use this and not check any of the config paths.
|
||||
func SetConfigFile(in string) { v.SetConfigFile(in) }
|
||||
|
@ -1463,20 +1479,98 @@ func (v *Viper) ReadInConfig() error {
|
|||
return UnsupportedConfigError(v.getConfigType())
|
||||
}
|
||||
|
||||
jww.DEBUG.Println("Reading file: ", filename)
|
||||
file, err := afero.ReadFile(v.fs, filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config := make(map[string]interface{})
|
||||
|
||||
err = v.unmarshalReader(bytes.NewReader(file), config)
|
||||
config, err := v.resolveIncludes(filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
v.config = config
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveIncludes will read the config file at the provided path, and if the
|
||||
// configIncludeKey is set, will attempt to resolve all included configuration
|
||||
// files.
|
||||
func (v *Viper) resolveIncludes(in string) (map[string]interface{}, error) {
|
||||
includes := []string{}
|
||||
visited := make(map[string]map[string]interface{})
|
||||
if err := v.resolveIncludesHelper(in, &includes, visited, 0); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resolved := New()
|
||||
for _, include := range includes {
|
||||
if err := resolved.MergeConfigMap(visited[include]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// we remove the include key, to ensure that writing the config works
|
||||
// as desired
|
||||
if v.configIncludeKey != "" {
|
||||
delete(resolved.config, v.configIncludeKey)
|
||||
}
|
||||
|
||||
return resolved.config, nil
|
||||
}
|
||||
|
||||
// resolveIncludesHelper recursively attempts to resolve the included config
|
||||
// files. The includes slice keeps track of the included config files in post
|
||||
// order traversal, to preserve the import order. The visited map retains the
|
||||
// visited paths and the included config data (to avoid unnecessary filesystem
|
||||
// reads).
|
||||
func (v *Viper) resolveIncludesHelper(in string, includes *[]string, visited map[string]map[string]interface{}, depth int) error {
|
||||
if depth > includeMaxDepth {
|
||||
return errors.New("maximum recursion depth exceeded")
|
||||
}
|
||||
|
||||
// if this isn't an absolute path
|
||||
if !filepath.IsAbs(in) {
|
||||
// and doesn't match the config file specified, try to find the
|
||||
// referenced file
|
||||
if v.configFile != "" && in != v.configFile {
|
||||
var err error
|
||||
in, err = v.findConfigFile(in)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := visited[in]; ok {
|
||||
jww.DEBUG.Println("Already visited: ", in)
|
||||
return nil
|
||||
}
|
||||
|
||||
jww.DEBUG.Println("Reading file: ", in)
|
||||
file, err := afero.ReadFile(v.fs, in)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config := make(map[string]interface{})
|
||||
err = v.unmarshalReader(bytes.NewReader(file), config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
visited[in] = config
|
||||
|
||||
if v.configIncludeKey != "" {
|
||||
parent := New()
|
||||
parent.MergeConfigMap(config)
|
||||
children := parent.GetStringSlice(v.configIncludeKey)
|
||||
for _, child := range children {
|
||||
if err = v.resolveIncludesHelper(child, includes, visited, depth+1); err != nil {
|
||||
return err
|
||||
}
|
||||
// Post-order traversal in order to preserve the import order.
|
||||
*includes = append(*includes, child)
|
||||
}
|
||||
}
|
||||
|
||||
*includes = append(*includes, in)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -1494,12 +1588,12 @@ func (v *Viper) MergeInConfig() error {
|
|||
return UnsupportedConfigError(v.getConfigType())
|
||||
}
|
||||
|
||||
file, err := afero.ReadFile(v.fs, filename)
|
||||
config, err := v.resolveIncludes(filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return v.MergeConfig(bytes.NewReader(file))
|
||||
return v.MergeConfigMap(config)
|
||||
}
|
||||
|
||||
// ReadConfig will read a configuration file, setting existing keys to nil if the
|
||||
|
@ -2094,7 +2188,7 @@ func (v *Viper) getConfigType() string {
|
|||
|
||||
func (v *Viper) getConfigFile() (string, error) {
|
||||
if v.configFile == "" {
|
||||
cf, err := v.findConfigFile()
|
||||
cf, err := v.findConfigFile(v.configName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -2103,32 +2197,33 @@ func (v *Viper) getConfigFile() (string, error) {
|
|||
return v.configFile, nil
|
||||
}
|
||||
|
||||
func (v *Viper) searchInPath(in string) (filename string) {
|
||||
jww.DEBUG.Println("Searching for config in ", in)
|
||||
// Searches for configName in path.
|
||||
func (v *Viper) searchInPath(path string, configName string) (filename string) {
|
||||
jww.DEBUG.Println("Searching for config in ", path)
|
||||
for _, ext := range SupportedExts {
|
||||
jww.DEBUG.Println("Checking for", filepath.Join(in, v.configName+"."+ext))
|
||||
if b, _ := exists(v.fs, filepath.Join(in, v.configName+"."+ext)); b {
|
||||
jww.DEBUG.Println("Found: ", filepath.Join(in, v.configName+"."+ext))
|
||||
return filepath.Join(in, v.configName+"."+ext)
|
||||
jww.DEBUG.Println("Checking for", filepath.Join(path, configName+"."+ext))
|
||||
if b, _ := exists(v.fs, filepath.Join(path, configName+"."+ext)); b {
|
||||
jww.DEBUG.Println("Found: ", filepath.Join(path, configName+"."+ext))
|
||||
return filepath.Join(path, configName+"."+ext)
|
||||
}
|
||||
}
|
||||
|
||||
if v.configType != "" {
|
||||
if b, _ := exists(v.fs, filepath.Join(in, v.configName)); b {
|
||||
return filepath.Join(in, v.configName)
|
||||
if b, _ := exists(v.fs, filepath.Join(path, configName)); b {
|
||||
return filepath.Join(path, configName)
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// Search all configPaths for any config file.
|
||||
// Search all configPaths for any config file with the provided config name.
|
||||
// Returns the first path that exists (and is a config file).
|
||||
func (v *Viper) findConfigFile() (string, error) {
|
||||
func (v *Viper) findConfigFile(configName string) (string, error) {
|
||||
jww.INFO.Println("Searching for config in ", v.configPaths)
|
||||
|
||||
for _, cp := range v.configPaths {
|
||||
file := v.searchInPath(cp)
|
||||
file := v.searchInPath(cp, configName)
|
||||
if file != "" {
|
||||
return file, nil
|
||||
}
|
||||
|
|
|
@ -2369,6 +2369,61 @@ func TestSliceIndexAccess(t *testing.T) {
|
|||
assert.Equal(t, "Static", v.GetString("tv.0.episodes.1.2"))
|
||||
}
|
||||
|
||||
var yamlIncludeRoot = []byte(`
|
||||
hello:
|
||||
pop: 37890
|
||||
largenum: 765432101234567
|
||||
`)
|
||||
|
||||
var yamlIncludeIntermediate = []byte(`
|
||||
includes:
|
||||
- root
|
||||
|
||||
hello:
|
||||
pop: 45000
|
||||
ints:
|
||||
- 1
|
||||
- 2
|
||||
fu: bar
|
||||
`)
|
||||
|
||||
var yamlIncludeOther = []byte(`
|
||||
foo: bar
|
||||
`)
|
||||
|
||||
var yamlIncludeLeaf = []byte(`
|
||||
includes:
|
||||
- intermediate
|
||||
- other
|
||||
|
||||
fu: baz
|
||||
`)
|
||||
|
||||
func TestSetConfigIncludeKey(t *testing.T) {
|
||||
root, _, cleanup := initDirs(t)
|
||||
defer cleanup()
|
||||
|
||||
configFile := filepath.Join(root, "leaf.yaml")
|
||||
|
||||
ioutil.WriteFile(filepath.Join(root, "root.yaml"), yamlIncludeRoot, 0640)
|
||||
ioutil.WriteFile(filepath.Join(root, "intermediate.yaml"), yamlIncludeIntermediate, 0640)
|
||||
ioutil.WriteFile(filepath.Join(root, "other.yaml"), yamlIncludeOther, 0640)
|
||||
ioutil.WriteFile(configFile, yamlIncludeLeaf, 0640)
|
||||
|
||||
v := New()
|
||||
v.SetConfigIncludeKey("includes")
|
||||
v.SetConfigFile(configFile)
|
||||
v.AddConfigPath(root)
|
||||
|
||||
assert.NoError(t, v.ReadInConfig())
|
||||
assert.Equal(t, "baz", v.GetString("fu"))
|
||||
assert.Equal(t, "bar", v.GetString("foo"))
|
||||
assert.Equal(t, 45000, v.GetInt("hello.pop"))
|
||||
assert.Equal(t, 765432101234567, v.GetInt("hello.largenum"))
|
||||
assert.Equal(t, []int{1, 2}, v.GetIntSlice("hello.ints"))
|
||||
assert.False(t, v.InConfig("includes"))
|
||||
}
|
||||
|
||||
func BenchmarkGetBool(b *testing.B) {
|
||||
key := "BenchmarkGetBool"
|
||||
v = New()
|
||||
|
|
Loading…
Add table
Reference in a new issue