Files
opencloud/pkg/config/envdecode/envdecode.go
Jörn Friedrich Dreyer b07b5a1149 use plain pkg module
Signed-off-by: Jörn Friedrich Dreyer <jfd@butonic.de>
2025-01-13 16:42:19 +01:00

437 lines
10 KiB
Go

// Package envdecode is a package for populating structs from environment
// variables, using struct tags.
package envdecode
import (
"encoding"
"errors"
"fmt"
"log"
"net/url"
"os"
"reflect"
"sort"
"strconv"
"strings"
"time"
)
var (
	// ErrInvalidTarget reports that the target handed to Decode,
	// StrictDecode or Export cannot be used: it must be a non-nil pointer
	// to a struct.
	ErrInvalidTarget = errors.New("target must be non-nil pointer to struct that has at least one exported field with a valid env tag")

	// ErrNoTargetFieldsAreSet is returned by Decode when the target is a
	// valid struct pointer but none of its fields received a value.
	ErrNoTargetFieldsAreSet = errors.New("none of the target fields were set from environment variables")
)
// FailureFunc receives the error whenever MustDecode or MustStrictDecode
// fails. By default it logs the error with log.Fatalf, which terminates the
// process.
//
// Callers may assign their own function to recover gracefully instead, for
// example by loading configuration from a backup file.
var FailureFunc = fatalOnFailure

// fatalOnFailure is the default FailureFunc: it logs err and exits.
func fatalOnFailure(err error) {
	log.Fatalf("envdecode: an error was encountered while decoding: %v\n", err)
}
// Decoder is implemented by types that can populate themselves from the raw
// string value of an environment variable. When a tagged field's address
// implements Decoder, its Decode method is used instead of the built-in
// conversions, and a non-nil error aborts decoding.
type Decoder interface {
	Decode(value string) error
}
// Decode populates the struct pointed to by target from environment
// variables, as directed by "env" struct tags. The target must be a non-nil
// pointer to a struct; otherwise ErrInvalidTarget is returned. Only exported
// fields are considered.
//
// The first element of a tag names the environment variable. Several names
// may be listed separated by semicolons (e.g. `env:"NEW_NAME;OLD_NAME"`);
// every listed variable is consulted and the last one that is set wins.
//
// Further comma-separated options may follow the name:
//
//   - ",default=value" supplies a value used when none of the variables is
//     set (the value itself cannot contain a comma);
//   - ",required" makes it an error if none of the variables is set;
//   - ",strict" makes a parse error for a primitive field an error instead
//     of being silently ignored. StrictDecode enables this for every field.
//
// Combining "default" and "required" on the same field panics.
//
// Supported field types are bool, all floating point and signed and
// unsigned integer types, string, time.Duration (via time.ParseDuration) and
// *url.URL (via url.Parse). Types whose pointer implements Decoder or
// encoding.TextUnmarshaler decode themselves. Structs and non-nil pointers
// to structs are decoded recursively. Slices of the supported primitive
// types are read from a comma-separated list, e.g. "a,b,c".
//
// ErrNoTargetFieldsAreSet is returned when no field received a value from
// the environment or from a non-empty default.
func Decode(target interface{}) error {
	nFields, err := decode(target, false)
	switch {
	case err != nil:
		return err
	case nFields == 0:
		// Nothing was set: the fields may be unexported, or none of the
		// referenced variables exist and no defaults were given.
		return ErrNoTargetFieldsAreSet
	default:
		return nil
	}
}
// StrictDecode behaves like Decode, but every field is treated as if its
// tag carried the ",strict" option, so a value that cannot be parsed into
// its field produces an error instead of being ignored.
//
// Unlike Decode, StrictDecode reports ErrInvalidTarget (rather than
// ErrNoTargetFieldsAreSet) when no field received a value.
func StrictDecode(target interface{}) error {
	count, err := decode(target, true)
	if err != nil {
		return err
	}
	if count > 0 {
		return nil
	}
	// Nothing was decoded; the target most likely has no exported, tagged
	// fields.
	return ErrInvalidTarget
}
// decode populates the env-tagged fields of the struct that target points to
// and returns how many fields were assigned. It is the shared implementation
// of Decode and StrictDecode.
//
// For every field:
//   - non-nil pointers to structs are dereferenced, and structs are decoded
//     recursively unless their address implements Decoder; the nested count
//     is added to the total;
//   - unexported (unsettable) and untagged fields are skipped;
//   - the semicolon-separated variable names of the tag are looked up and
//     the last one that is set wins; otherwise a non-empty ",default=" value
//     is used; fields with neither are left untouched;
//   - the value is decoded via Decoder, encoding.TextUnmarshaler,
//     decodeSlice or decodePrimitiveType, in that order of preference.
//
// Errors from Decoder, TextUnmarshaler and slice decoding are always
// returned. Parse errors for primitive fields are returned only when strict
// is true or the field's tag has the ",strict" option. A missing ",required"
// variable is an error, and combining ",required" with ",default=" panics.
// ErrInvalidTarget is returned when target is not a non-nil pointer to a
// struct.
func decode(target interface{}, strict bool) (int, error) {
	s := reflect.ValueOf(target)
	if s.Kind() != reflect.Ptr || s.IsNil() {
		return 0, ErrInvalidTarget
	}
	s = s.Elem()
	if s.Kind() != reflect.Struct {
		return 0, ErrInvalidTarget
	}
	t := s.Type()
	setFieldCount := 0
	for i := 0; i < s.NumField(); i++ {
		// Localize the umbrella `strict` value to the specific field.
		strict := strict
		// field is the struct field itself. f starts out as the same value but
		// is redirected to the pointee for non-nil pointers to structs.
		field := s.Field(i)
		f := field
		switch f.Kind() {
		case reflect.Ptr:
			if f.Elem().Kind() != reflect.Struct {
				break
			}
			f = f.Elem()
			fallthrough
		case reflect.Struct:
			if !f.Addr().CanInterface() {
				continue
			}
			ss := f.Addr().Interface()
			if _, custom := ss.(Decoder); custom {
				// The struct decodes itself from its own env tag below.
				break
			}
			n, err := decode(ss, strict)
			if err != nil {
				return 0, err
			}
			setFieldCount += n
		}
		if !f.CanSet() {
			continue
		}
		tag := t.Field(i).Tag.Get("env")
		if tag == "" {
			continue
		}
		parts := strings.Split(tag, ",")
		// Later names take precedence over earlier ones when several are set.
		var env string
		var envSet bool
		for _, name := range strings.Split(parts[0], ";") {
			if v, set := os.LookupEnv(name); set {
				env = v
				envSet = true
			}
		}
		required := false
		hasDefault := false
		defaultValue := ""
		for _, o := range parts[1:] {
			if !required {
				required = strings.HasPrefix(o, "required")
			}
			if strings.HasPrefix(o, "default=") {
				hasDefault = true
				defaultValue = o[len("default="):]
			}
			if !strict {
				strict = strings.HasPrefix(o, "strict")
			}
		}
		if required && hasDefault {
			panic(`envdecode: "default" and "required" may not be specified in the same annotation`)
		}
		if !envSet && required {
			return 0, fmt.Errorf("the environment variable \"%s\" is missing", parts[0])
		}
		if !envSet {
			env = defaultValue
		}
		if !envSet && env == "" {
			continue
		}
		setFieldCount++
		addr := f.Addr().Interface()
		unmarshaler, implementsUnmarshaler := addr.(encoding.TextUnmarshaler)
		decoder, implementsDecoder := addr.(Decoder)
		switch {
		case implementsDecoder:
			if err := decoder.Decode(env); err != nil {
				return 0, err
			}
		case implementsUnmarshaler:
			if err := unmarshaler.UnmarshalText([]byte(env)); err != nil {
				return 0, err
			}
		case f.Kind() == reflect.Slice:
			if err := decodeSlice(&f, env); err != nil {
				return 0, err
			}
		default:
			// Hand decodePrimitiveType the field itself, not the dereferenced
			// struct. It replaces *url.URL fields wholesale; given the struct
			// it would silently do nothing, so a pre-populated pointer would
			// keep its old value.
			if err := decodePrimitiveType(&field, env); err != nil && strict {
				return 0, err
			}
		}
	}
	return setFieldCount, nil
}
// decodeSlice fills the slice value f from env, a comma-separated list.
// Each item is trimmed of surrounding whitespace. Items that are empty after
// trimming are skipped, so "a, ,b," yields two elements. Every remaining item
// is parsed by decodePrimitiveType according to the slice's element type.
//
// The first parse error is returned and f is left unchanged. If env contains
// no non-empty items, f is set to an empty (non-nil) slice.
func decodeSlice(f *reflect.Value, env string) error {
	var values []string
	for _, part := range strings.Split(env, ",") {
		// Trim before testing for emptiness; otherwise whitespace-only items
		// (e.g. a trailing ", ") would survive as "" and fail numeric parsing.
		if v := strings.TrimSpace(part); v != "" {
			values = append(values, v)
		}
	}
	slice := reflect.MakeSlice(f.Type(), len(values), len(values))
	for i, v := range values {
		e := slice.Index(i)
		if err := decodePrimitiveType(&e, v); err != nil {
			return err
		}
	}
	f.Set(slice)
	return nil
}
// decodePrimitiveType parses env according to the kind of f and stores the
// result in f, which must be settable. Handled kinds:
//
//   - string: env is stored verbatim;
//   - bool: strconv.ParseBool;
//   - signed integers: time.ParseDuration for time.Duration, otherwise
//     strconv.ParseInt with base 0 (so 0x, 0o and 0b prefixes are accepted);
//   - unsigned integers: strconv.ParseUint with base 0;
//   - floats: strconv.ParseFloat at the value's bit size;
//   - *url.URL: url.Parse, replacing the pointer.
//
// A parse error is returned and leaves f unchanged. Any other kind,
// including pointers to other types, is ignored without error.
func decodePrimitiveType(f *reflect.Value, env string) error {
	switch f.Kind() {
	case reflect.String:
		f.SetString(env)
		return nil
	case reflect.Bool:
		parsed, err := strconv.ParseBool(env)
		if err == nil {
			f.SetBool(parsed)
		}
		return err
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		typ := f.Type()
		if typ.PkgPath() == "time" && typ.Name() == "Duration" {
			parsed, err := time.ParseDuration(env)
			if err == nil {
				f.SetInt(int64(parsed))
			}
			return err
		}
		parsed, err := strconv.ParseInt(env, 0, typ.Bits())
		if err == nil {
			f.SetInt(parsed)
		}
		return err
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		parsed, err := strconv.ParseUint(env, 0, f.Type().Bits())
		if err == nil {
			f.SetUint(parsed)
		}
		return err
	case reflect.Float32, reflect.Float64:
		parsed, err := strconv.ParseFloat(env, f.Type().Bits())
		if err == nil {
			f.SetFloat(parsed)
		}
		return err
	case reflect.Ptr:
		elem := f.Type().Elem()
		if elem.Kind() != reflect.Struct || elem.PkgPath() != "net/url" || elem.Name() != "URL" {
			return nil
		}
		parsed, err := url.Parse(env)
		if err == nil {
			f.Set(reflect.ValueOf(parsed))
		}
		return err
	}
	return nil
}
// MustDecode is like Decode but does not return an error. Any failure is
// handed to FailureFunc, which by default terminates the process.
func MustDecode(target interface{}) {
	err := Decode(target)
	if err == nil {
		return
	}
	FailureFunc(err)
}
// MustStrictDecode is like StrictDecode but does not return an error. Any
// failure is handed to FailureFunc, which by default terminates the process.
func MustStrictDecode(target interface{}) {
	err := StrictDecode(target)
	if err == nil {
		return
	}
	FailureFunc(err)
}
// ConfigInfo describes one env-tagged configuration field as reported by
// Export.
type ConfigInfo struct {
	// Field is the Go field name; fields of nested structs are prefixed with
	// their parent's name, e.g. "Parent.Child".
	Field string
	// EnvVar is the variable-name part of the env tag, which may list
	// several semicolon-separated names.
	EnvVar string
	// Value is the string form of the field's current value.
	Value string
	// DefaultValue is the value given by the tag's ",default=" option.
	DefaultValue string
	// HasDefault reports whether the tag has a ",default=" option.
	HasDefault bool
	// Required reports whether the tag has a ",required" option.
	Required bool
	// UsesEnv reports whether the environment provides a non-empty value
	// for the variable.
	UsesEnv bool
}

// ConfigInfoSlice implements sort.Interface for a list of ConfigInfo
// entries, ordering them by EnvVar.
type ConfigInfoSlice []*ConfigInfo

// Len returns the number of entries.
func (s ConfigInfoSlice) Len() int { return len(s) }

// Less orders entries lexically by EnvVar.
func (s ConfigInfoSlice) Less(i, j int) bool { return s[i].EnvVar < s[j].EnvVar }

// Swap exchanges the entries at positions i and j.
func (s ConfigInfoSlice) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
// Export reports, for every field of the struct pointed to by target that
// carries an "env" tag, the tag's settings and the field's current value.
// Nested structs and non-nil pointers to structs are inspected recursively,
// and their entries carry the parent field name as a prefix, e.g.
// "Parent.Child". The result is sorted by EnvVar.
//
// Unexported fields are skipped because reflection cannot read their
// values. UsesEnv is true when any of the tag's semicolon-separated variable
// names is set to a non-empty value.
//
// ErrInvalidTarget is returned if target is not a non-nil pointer to a
// struct, if no tagged fields are found, or if a tagged field has a kind
// whose value cannot be rendered as a string.
func Export(target interface{}) ([]*ConfigInfo, error) {
	s := reflect.ValueOf(target)
	if s.Kind() != reflect.Ptr || s.IsNil() {
		return nil, ErrInvalidTarget
	}
	cfg := []*ConfigInfo{}
	s = s.Elem()
	if s.Kind() != reflect.Struct {
		return nil, ErrInvalidTarget
	}
	t := s.Type()
	for i := 0; i < s.NumField(); i++ {
		f := s.Field(i)
		// Interface() on an unexported field panics. decode skips such fields
		// too, so there is nothing meaningful to report for them.
		if !f.CanInterface() {
			continue
		}
		fName := t.Field(i).Name
		fElem := f
		if f.Kind() == reflect.Ptr {
			fElem = f.Elem()
		}
		if fElem.Kind() == reflect.Struct {
			subCfg, err := Export(fElem.Addr().Interface())
			// A nested struct that yields ErrInvalidTarget (e.g. because it has
			// no tagged fields) is treated as a plain value below.
			if !errors.Is(err, ErrInvalidTarget) {
				f = fElem
				for _, v := range subCfg {
					v.Field = fmt.Sprintf("%s.%s", fName, v.Field)
					cfg = append(cfg, v)
				}
			}
		}
		tag := t.Field(i).Tag.Get("env")
		if tag == "" {
			continue
		}
		parts := strings.Split(tag, ",")
		// The name part may list several variables ("NEW;OLD"), mirroring
		// decode; check each one rather than the raw joined string.
		usesEnv := false
		for _, name := range strings.Split(parts[0], ";") {
			if os.Getenv(name) != "" {
				usesEnv = true
				break
			}
		}
		ci := &ConfigInfo{
			Field:   fName,
			EnvVar:  parts[0],
			UsesEnv: usesEnv,
		}
		for _, o := range parts[1:] {
			if strings.HasPrefix(o, "default=") {
				ci.HasDefault = true
				ci.DefaultValue = o[len("default="):]
			} else if strings.HasPrefix(o, "required") {
				ci.Required = true
			}
		}
		if f.Kind() == reflect.Ptr && f.IsNil() {
			ci.Value = ""
		} else if stringer, ok := f.Interface().(fmt.Stringer); ok {
			ci.Value = stringer.String()
		} else {
			switch f.Kind() {
			case reflect.Bool:
				ci.Value = strconv.FormatBool(f.Bool())
			case reflect.Float32, reflect.Float64:
				bits := f.Type().Bits()
				ci.Value = strconv.FormatFloat(f.Float(), 'f', -1, bits)
			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
				ci.Value = strconv.FormatInt(f.Int(), 10)
			case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
				ci.Value = strconv.FormatUint(f.Uint(), 10)
			case reflect.String:
				ci.Value = f.String()
			case reflect.Slice:
				ci.Value = fmt.Sprintf("%v", f.Interface())
			default:
				// Unable to determine string format for value.
				return nil, ErrInvalidTarget
			}
		}
		cfg = append(cfg, ci)
	}
	// No configuration tags found, assume invalid input.
	if len(cfg) == 0 {
		return nil, ErrInvalidTarget
	}
	sort.Sort(ConfigInfoSlice(cfg))
	return cfg, nil
}