add parser

This commit is contained in:
liuguoyang1 2020-08-19 17:47:40 +08:00
parent 3826be3135
commit d41059169e
2 changed files with 277 additions and 172 deletions

263
parsers.go Normal file
View file

@ -0,0 +1,263 @@
package viper
import (
"bytes"
"encoding/json"
"fmt"
"github.com/hashicorp/hcl"
"github.com/hashicorp/hcl/hcl/printer"
"github.com/magiconair/properties"
"github.com/pelletier/go-toml"
"github.com/spf13/afero"
"github.com/subosito/gotenv"
"gopkg.in/ini.v1"
"gopkg.in/yaml.v2"
"io"
"strings"
)
var SupportedParsers map[string]Parser
type Parser interface {
UnmarshalReader(v *Viper, in io.Reader, c map[string]interface{}) error // Unmarshal a Reader into a map.
MarshalWriter(v *Viper, f afero.File, c map[string]interface{}) error // Marshal a map into Writer.
}
// add support parser
func AddParser(parser Parser, names ...string) {
for _, n := range names {
SupportedParsers[n] = parser
}
}
func parserInit() {
SupportedParsers = make(map[string]Parser, 32)
AddParser(&JsonParser{}, "json")
AddParser(&TomlParser{}, "toml")
AddParser(&YamlParser{}, "yaml", "yml")
AddParser(&PropsParser{}, "properties", "props", "prop")
AddParser(&HclParser{}, "hcl")
AddParser(&DotenvParser{}, "dotenv", "env")
AddParser(&IniParser{}, "ini")
}
func parserReset() {
parserInit()
}
// json parser
type JsonParser struct {
}
func (pp *JsonParser) UnmarshalReader(v *Viper, in io.Reader, c map[string]interface{}) error {
buf := new(bytes.Buffer)
buf.ReadFrom(in)
return json.Unmarshal(buf.Bytes(), &c)
}
func (pp *JsonParser) MarshalWriter(v *Viper, f afero.File, c map[string]interface{}) error {
b, err := json.MarshalIndent(c, "", " ")
if err != nil {
return err
}
_, err = f.WriteString(string(b))
return err
}
// toml parser
type TomlParser struct {
}
func (pp *TomlParser) UnmarshalReader(v *Viper, in io.Reader, c map[string]interface{}) error {
buf := new(bytes.Buffer)
buf.ReadFrom(in)
tree, err := toml.LoadReader(buf)
if err != nil {
return err
}
tmap := tree.ToMap()
for k, v := range tmap {
c[k] = v
}
return nil
}
func (pp *TomlParser) MarshalWriter(v *Viper, f afero.File, c map[string]interface{}) error {
t, err := toml.TreeFromMap(c)
if err != nil {
return err
}
s := t.String()
if _, err := f.WriteString(s); err != nil {
return err
}
return nil
}
// yaml parser
type YamlParser struct {
}
func (pp *YamlParser) UnmarshalReader(v *Viper, in io.Reader, c map[string]interface{}) error {
buf := new(bytes.Buffer)
buf.ReadFrom(in)
return yaml.Unmarshal(buf.Bytes(), &c)
}
func (pp *YamlParser) MarshalWriter(v *Viper, f afero.File, c map[string]interface{}) error {
b, err := yaml.Marshal(c)
if err != nil {
return err
}
if _, err = f.WriteString(string(b)); err != nil {
return err
}
return nil
}
// ini parser
type IniParser struct {
}
func (pp *IniParser) UnmarshalReader(v *Viper, in io.Reader, c map[string]interface{}) error {
buf := new(bytes.Buffer)
buf.ReadFrom(in)
cfg := ini.Empty()
err := cfg.Append(buf.Bytes())
if err != nil {
return err
}
sections := cfg.Sections()
for i := 0; i < len(sections); i++ {
section := sections[i]
keys := section.Keys()
for j := 0; j < len(keys); j++ {
key := keys[j]
value := cfg.Section(section.Name()).Key(key.Name()).String()
c[section.Name()+"."+key.Name()] = value
}
}
return nil
}
func (pp *IniParser) MarshalWriter(v *Viper, f afero.File, c map[string]interface{}) error {
keys := v.AllKeys()
cfg := ini.Empty()
ini.PrettyFormat = false
for i := 0; i < len(keys); i++ {
key := keys[i]
lastSep := strings.LastIndex(key, ".")
sectionName := key[:(lastSep)]
keyName := key[(lastSep + 1):]
if sectionName == "default" {
sectionName = ""
}
cfg.Section(sectionName).Key(keyName).SetValue(v.Get(key).(string))
}
cfg.WriteTo(f)
return nil
}
// hcl parser
type HclParser struct {
}
func (pp *HclParser) UnmarshalReader(v *Viper, in io.Reader, c map[string]interface{}) error {
buf := new(bytes.Buffer)
buf.ReadFrom(in)
obj, err := hcl.Parse(buf.String())
if err != nil {
return err
}
if err = hcl.DecodeObject(&c, obj); err != nil {
return err
}
return nil
}
func (pp *HclParser) MarshalWriter(v *Viper, f afero.File, c map[string]interface{}) error {
b, err := json.Marshal(c)
if err != nil {
return err
}
ast, err := hcl.Parse(string(b))
if err != nil {
return err
}
err = printer.Fprint(f, ast.Node)
if err != nil {
return err
}
return nil
}
// dot env parser
type DotenvParser struct {
}
func (pp *DotenvParser) UnmarshalReader(v *Viper, in io.Reader, c map[string]interface{}) error {
buf := new(bytes.Buffer)
buf.ReadFrom(in)
env, err := gotenv.StrictParse(buf)
if err != nil {
return err
}
for k, v := range env {
c[k] = v
}
return nil
}
func (pp *DotenvParser) MarshalWriter(v *Viper, f afero.File, c map[string]interface{}) error {
lines := []string{}
for _, key := range v.AllKeys() {
envName := strings.ToUpper(strings.Replace(key, ".", "_", -1))
val := v.Get(key)
lines = append(lines, fmt.Sprintf("%v=%v", envName, val))
}
s := strings.Join(lines, "\n")
if _, err := f.WriteString(s); err != nil {
return err
}
return nil
}
// props parser
type PropsParser struct {
}
func (pp *PropsParser) UnmarshalReader(v *Viper, in io.Reader, c map[string]interface{}) error {
buf := new(bytes.Buffer)
buf.ReadFrom(in)
v.properties = properties.NewProperties()
var err error
if v.properties, err = properties.Load(buf.Bytes(), properties.UTF8); err != nil {
return ConfigParseError{err}
}
for _, key := range v.properties.Keys() {
value, _ := v.properties.Get(key)
// recursively build nested maps
path := strings.Split(key, ".")
lastKey := strings.ToLower(path[len(path)-1])
deepestMap := deepSearch(c, path[0:len(path)-1])
// set innermost value
deepestMap[lastKey] = value
}
return nil
}
func (pp *PropsParser) MarshalWriter(v *Viper, f afero.File, c map[string]interface{}) error {
if v.properties == nil {
v.properties = properties.NewProperties()
}
p := v.properties
for _, key := range v.AllKeys() {
_, _, err := p.Set(key, v.GetString(key))
if err != nil {
return err
}
}
_, err := p.WriteComment(f, "#", properties.UTF8)
if err != nil {
return err
}
return nil
}

186
viper.go
View file

@ -22,7 +22,6 @@ package viper
import (
"bytes"
"encoding/csv"
"encoding/json"
"errors"
"fmt"
"io"
@ -35,18 +34,12 @@ import (
"time"
"github.com/fsnotify/fsnotify"
"github.com/hashicorp/hcl"
"github.com/hashicorp/hcl/hcl/printer"
"github.com/magiconair/properties"
"github.com/mitchellh/mapstructure"
"github.com/pelletier/go-toml"
"github.com/spf13/afero"
"github.com/spf13/cast"
jww "github.com/spf13/jwalterweatherman"
"github.com/spf13/pflag"
"github.com/subosito/gotenv"
"gopkg.in/ini.v1"
"gopkg.in/yaml.v2"
)
// ConfigMarshalError happens when failing to marshal the configuration.
@ -68,6 +61,7 @@ type RemoteResponse struct {
func init() {
v = New()
parserInit()
}
type remoteConfigFactory interface {
@ -286,7 +280,7 @@ func NewWithOptions(opts ...Option) *Viper {
// can use it in their testing as well.
func Reset() {
v = New()
SupportedExts = []string{"json", "toml", "yaml", "yml", "properties", "props", "prop", "hcl", "dotenv", "env", "ini"}
parserReset()
SupportedRemoteProviders = []string{"etcd", "consul", "firestore"}
}
@ -324,9 +318,6 @@ type RemoteProvider interface {
SecretKeyring() string
}
// SupportedExts are universally supported extensions.
var SupportedExts = []string{"json", "toml", "yaml", "yml", "properties", "props", "prop", "hcl", "dotenv", "env", "ini"}
// SupportedRemoteProviders are universally supported remote providers.
var SupportedRemoteProviders = []string{"etcd", "consul", "firestore"}
@ -905,6 +896,7 @@ func Unmarshal(rawVal interface{}, opts ...DecoderConfigOption) error {
return v.Unmarshal(rawVal, opts...)
}
func (v *Viper) Unmarshal(rawVal interface{}, opts ...DecoderConfigOption) error {
fmt.Println(v.AllSettings())
return decode(v.AllSettings(), defaultDecoderConfig(rawVal, opts...))
}
@ -1314,7 +1306,7 @@ func (v *Viper) ReadInConfig() error {
return err
}
if !stringInSlice(v.getConfigType(), SupportedExts) {
if _, ok := SupportedParsers[strings.ToLower(v.getConfigType())]; !ok {
return UnsupportedConfigError(v.getConfigType())
}
@ -1344,7 +1336,7 @@ func (v *Viper) MergeInConfig() error {
return err
}
if !stringInSlice(v.getConfigType(), SupportedExts) {
if _, ok := SupportedParsers[strings.ToLower(v.getConfigType())]; !ok {
return UnsupportedConfigError(v.getConfigType())
}
@ -1435,7 +1427,7 @@ func (v *Viper) writeConfig(filename string, force bool) error {
return fmt.Errorf("config type could not be determined for %s", filename)
}
if !stringInSlice(configType, SupportedExts) {
if _, ok := SupportedParsers[strings.ToLower(configType)]; !ok {
return UnsupportedConfigError(configType)
}
if v.config == nil {
@ -1464,82 +1456,13 @@ func unmarshalReader(in io.Reader, c map[string]interface{}) error {
return v.unmarshalReader(in, c)
}
func (v *Viper) unmarshalReader(in io.Reader, c map[string]interface{}) error {
buf := new(bytes.Buffer)
buf.ReadFrom(in)
switch strings.ToLower(v.getConfigType()) {
case "yaml", "yml":
if err := yaml.Unmarshal(buf.Bytes(), &c); err != nil {
return ConfigParseError{err}
}
case "json":
if err := json.Unmarshal(buf.Bytes(), &c); err != nil {
return ConfigParseError{err}
}
case "hcl":
obj, err := hcl.Parse(buf.String())
p, ok := SupportedParsers[strings.ToLower(v.getConfigType())]
if ok {
err := p.UnmarshalReader(v, in, c)
if err != nil {
return ConfigParseError{err}
}
if err = hcl.DecodeObject(&c, obj); err != nil {
return ConfigParseError{err}
}
case "toml":
tree, err := toml.LoadReader(buf)
if err != nil {
return ConfigParseError{err}
}
tmap := tree.ToMap()
for k, v := range tmap {
c[k] = v
}
case "dotenv", "env":
env, err := gotenv.StrictParse(buf)
if err != nil {
return ConfigParseError{err}
}
for k, v := range env {
c[k] = v
}
case "properties", "props", "prop":
v.properties = properties.NewProperties()
var err error
if v.properties, err = properties.Load(buf.Bytes(), properties.UTF8); err != nil {
return ConfigParseError{err}
}
for _, key := range v.properties.Keys() {
value, _ := v.properties.Get(key)
// recursively build nested maps
path := strings.Split(key, ".")
lastKey := strings.ToLower(path[len(path)-1])
deepestMap := deepSearch(c, path[0:len(path)-1])
// set innermost value
deepestMap[lastKey] = value
}
case "ini":
cfg := ini.Empty()
err := cfg.Append(buf.Bytes())
if err != nil {
return ConfigParseError{err}
}
sections := cfg.Sections()
for i := 0; i < len(sections); i++ {
section := sections[i]
keys := section.Keys()
for j := 0; j < len(keys); j++ {
key := keys[j]
value := cfg.Section(section.Name()).Key(key.Name()).String()
c[section.Name()+"."+key.Name()] = value
}
}
}
insensitiviseMap(c)
return nil
}
@ -1547,93 +1470,12 @@ func (v *Viper) unmarshalReader(in io.Reader, c map[string]interface{}) error {
// Marshal a map into Writer.
func (v *Viper) marshalWriter(f afero.File, configType string) error {
c := v.AllSettings()
switch configType {
case "json":
b, err := json.MarshalIndent(c, "", " ")
p, ok := SupportedParsers[strings.ToLower(configType)]
if ok {
err := p.MarshalWriter(v, f, c)
if err != nil {
return ConfigMarshalError{err}
return ConfigParseError{err}
}
_, err = f.WriteString(string(b))
if err != nil {
return ConfigMarshalError{err}
}
case "hcl":
b, err := json.Marshal(c)
if err != nil {
return ConfigMarshalError{err}
}
ast, err := hcl.Parse(string(b))
if err != nil {
return ConfigMarshalError{err}
}
err = printer.Fprint(f, ast.Node)
if err != nil {
return ConfigMarshalError{err}
}
case "prop", "props", "properties":
if v.properties == nil {
v.properties = properties.NewProperties()
}
p := v.properties
for _, key := range v.AllKeys() {
_, _, err := p.Set(key, v.GetString(key))
if err != nil {
return ConfigMarshalError{err}
}
}
_, err := p.WriteComment(f, "#", properties.UTF8)
if err != nil {
return ConfigMarshalError{err}
}
case "dotenv", "env":
lines := []string{}
for _, key := range v.AllKeys() {
envName := strings.ToUpper(strings.Replace(key, ".", "_", -1))
val := v.Get(key)
lines = append(lines, fmt.Sprintf("%v=%v", envName, val))
}
s := strings.Join(lines, "\n")
if _, err := f.WriteString(s); err != nil {
return ConfigMarshalError{err}
}
case "toml":
t, err := toml.TreeFromMap(c)
if err != nil {
return ConfigMarshalError{err}
}
s := t.String()
if _, err := f.WriteString(s); err != nil {
return ConfigMarshalError{err}
}
case "yaml", "yml":
b, err := yaml.Marshal(c)
if err != nil {
return ConfigMarshalError{err}
}
if _, err = f.WriteString(string(b)); err != nil {
return ConfigMarshalError{err}
}
case "ini":
keys := v.AllKeys()
cfg := ini.Empty()
ini.PrettyFormat = false
for i := 0; i < len(keys); i++ {
key := keys[i]
lastSep := strings.LastIndex(key, ".")
sectionName := key[:(lastSep)]
keyName := key[(lastSep + 1):]
if sectionName == "default" {
sectionName = ""
}
cfg.Section(sectionName).Key(keyName).SetValue(v.Get(key).(string))
}
cfg.WriteTo(f)
}
return nil
}
@ -1980,7 +1822,7 @@ func (v *Viper) getConfigFile() (string, error) {
func (v *Viper) searchInPath(in string) (filename string) {
jww.DEBUG.Println("Searching for config in ", in)
for _, ext := range SupportedExts {
for ext := range SupportedParsers {
jww.DEBUG.Println("Checking for", filepath.Join(in, v.configName+"."+ext))
if b, _ := exists(v.fs, filepath.Join(in, v.configName+"."+ext)); b {
jww.DEBUG.Println("Found: ", filepath.Join(in, v.configName+"."+ext))