From 361864a3a8841246ced928fe81df1c6436991ab6 Mon Sep 17 00:00:00 2001 From: Tyler Davis <676253+phinnaeus@users.noreply.github.com> Date: Fri, 27 Aug 2021 17:56:23 +1000 Subject: [PATCH] Add ConfigExtension This forces Viper to look for config files with a given extension which short circuits the supported extensions checks. See https://github.com/spf13/viper/issues/1163 --- viper.go | 44 +++++++++++++++++++++++++++++++++++--------- viper_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 9 deletions(-) diff --git a/viper.go b/viper.go index 46b1a85..8e93967 100644 --- a/viper.go +++ b/viper.go @@ -233,6 +233,7 @@ type Viper struct { configName string configFile string configType string + configExtension string configPermissions os.FileMode envPrefix string @@ -2059,6 +2060,17 @@ func (v *Viper) SetConfigType(in string) { } } +// SetConfigExtension forces viper to look for a specific file extension +// rather than checking all supported types +func SetConfigExtension(in string) { v.SetConfigExtension(in) } + +func (v *Viper) SetConfigExtension(in string) { + if in != "" { + // normalize extensions, for example ".yml" and "yml" + v.configExtension = strings.TrimPrefix(in, ".") + } +} + // SetConfigPermissions sets the permissions for the config file. func SetConfigPermissions(perm os.FileMode) { v.SetConfigPermissions(perm) } @@ -2104,18 +2116,32 @@ func (v *Viper) getConfigFile() (string, error) { } func (v *Viper) searchInPath(in string) (filename string) { - jww.DEBUG.Println("Searching for config in ", in) - for _, ext := range SupportedExts { - jww.DEBUG.Println("Checking for", filepath.Join(in, v.configName+"."+ext)) - if b, _ := exists(v.fs, filepath.Join(in, v.configName+"."+ext)); b { - jww.DEBUG.Println("Found: ", filepath.Join(in, v.configName+"."+ext)) - return filepath.Join(in, v.configName+"."+ext) + jww.DEBUG.Println("Searching for config in", in) + + basePath := filepath.Join(in, v.configName) + if v.configType != "" { + jww.DEBUG.Println("configType specified; checking for", basePath) + if b, _ := exists(v.fs, basePath); b { + jww.DEBUG.Println("Found: ", basePath) + return basePath } } - if v.configType != "" { - if b, _ := exists(v.fs, filepath.Join(in, v.configName)); b { - return filepath.Join(in, v.configName) + if v.configExtension != "" { + path := basePath + "." + v.configExtension + jww.DEBUG.Println("configExtension specified; checking for", path) + if b, _ := exists(v.fs, path); b { + jww.DEBUG.Println("Found: ", path) + return path + } + } + + for _, ext := range SupportedExts { + path := basePath + "." + ext + jww.DEBUG.Println("Checking for", path) + if b, _ := exists(v.fs, path); b { + jww.DEBUG.Println("Found: ", path) + return path } } diff --git a/viper_test.go b/viper_test.go index 4192748..c4f53ae 100644 --- a/viper_test.go +++ b/viper_test.go @@ -342,6 +342,30 @@ func TestSearchInPath(t *testing.T) { assert.NoError(t, err) } +func TestSearchInPath_extension(t *testing.T) { + filename := "config" + path := "/tmp" + file := filepath.Join(path, filename) + SetConfigName(filename) + SetConfigExtension("yaml") + AddConfigPath(path) + _, createErr := v.fs.Create(file+".yaml") + assert.NoError(t, createErr) + + // also create a second file which we want to ignore + _, createErr2 := v.fs.Create(file+".json") + assert.NoError(t, createErr2) + + defer func() { + _ = v.fs.Remove(file+".yaml") + _ = v.fs.Remove(file+".json") + }() + + filename, err := v.getConfigFile() + assert.Equal(t, file+".yaml", filename) + assert.NoError(t, err) +} + func TestSearchInPath_FilesOnly(t *testing.T) { fs := afero.NewMemMapFs()