mirror of
https://github.com/opencloud-eu/opencloud.git
synced 2026-04-24 04:58:31 -05:00
use plain pkg module
Signed-off-by: Jörn Friedrich Dreyer <jfd@butonic.de>
This commit is contained in:
committed by
Florian Schade
parent
259cbc2e56
commit
b07b5a1149
@@ -0,0 +1 @@
|
||||
!config
|
||||
@@ -0,0 +1,9 @@
|
||||
with-expecter: true
|
||||
filename: "{{.InterfaceName | snakecase }}.go"
|
||||
dir: "{{.PackageName}}/mocks"
|
||||
mockname: "{{.InterfaceName}}"
|
||||
outpkg: "mocks"
|
||||
packages:
|
||||
github.com/opencloud-eu/opencloud/ocis-pkg/oidc:
|
||||
interfaces:
|
||||
OIDCClient:
|
||||
@@ -0,0 +1,40 @@
|
||||
SHELL := bash
|
||||
NAME := ocis-pkg
|
||||
|
||||
include ../.make/recursion.mk
|
||||
|
||||
############ tooling ############
|
||||
ifneq (, $(shell command -v go 2> /dev/null)) # suppress `command not found warnings` for non go targets in CI
|
||||
include ../.bingo/Variables.mk
|
||||
endif
|
||||
|
||||
############ go tooling ############
|
||||
include ../.make/go.mk
|
||||
|
||||
############ release ############
|
||||
include ../.make/release.mk
|
||||
|
||||
############ docs generate ############
|
||||
SKIP_CONFIG_DOCS_GENERATE = 1
|
||||
|
||||
include ../.make/docs.mk
|
||||
|
||||
.PHONY: docs-generate
|
||||
docs-generate:
|
||||
|
||||
############ generate ############
|
||||
include ../.make/generate.mk
|
||||
|
||||
.PHONY: ci-go-generate
|
||||
ci-go-generate: $(MOCKERY) # CI runs ci-node-generate automatically before this target
|
||||
$(MOCKERY)
|
||||
|
||||
.PHONY: ci-node-generate
|
||||
ci-node-generate:
|
||||
|
||||
############ licenses ############
|
||||
.PHONY: ci-node-check-licenses
|
||||
ci-node-check-licenses:
|
||||
|
||||
.PHONY: ci-node-save-licenses
|
||||
ci-node-save-licenses:
|
||||
@@ -0,0 +1,30 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"github.com/opencloud-eu/opencloud/pkg/log"
|
||||
)
|
||||
|
||||
// Option defines a single option function.
|
||||
type Option func(o *Options)
|
||||
|
||||
// Options defines the available options for this package.
|
||||
type Options struct {
|
||||
// Logger to use for logging, must be set
|
||||
Logger log.Logger
|
||||
// JWTSecret is the jwt secret for the reva token manager
|
||||
JWTSecret string
|
||||
}
|
||||
|
||||
// Logger provides a function to set the logger option.
|
||||
func Logger(l log.Logger) Option {
|
||||
return func(o *Options) {
|
||||
o.Logger = l
|
||||
}
|
||||
}
|
||||
|
||||
// JWTSecret provides a function to set the jwt secret option.
|
||||
func JWTSecret(s string) Option {
|
||||
return func(o *Options) {
|
||||
o.JWTSecret = s
|
||||
}
|
||||
}
|
||||
+107
@@ -0,0 +1,107 @@
|
||||
// Package ast provides available ast nodes.
|
||||
package ast
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Node represents abstract syntax tree node
|
||||
type Node interface {
|
||||
Location() *Location
|
||||
}
|
||||
|
||||
// Position represents a specific location in the source
|
||||
type Position struct {
|
||||
Line int
|
||||
Column int
|
||||
}
|
||||
|
||||
// Location represents the location of a node in the AST
|
||||
type Location struct {
|
||||
Start Position `json:"start"`
|
||||
End Position `json:"end"`
|
||||
Source *string `json:"source,omitempty"`
|
||||
}
|
||||
|
||||
// Base contains shared node attributes
|
||||
// each node should inherit from this
|
||||
type Base struct {
|
||||
Loc *Location
|
||||
}
|
||||
|
||||
// Location is the source location of the Node
|
||||
func (b *Base) Location() *Location { return b.Loc }
|
||||
|
||||
// Ast represents the query - node structure as abstract syntax tree
|
||||
type Ast struct {
|
||||
*Base
|
||||
Nodes []Node `json:"body"`
|
||||
}
|
||||
|
||||
// StringNode represents a string value
|
||||
type StringNode struct {
|
||||
*Base
|
||||
Key string
|
||||
Value string
|
||||
}
|
||||
|
||||
// BooleanNode represents a bool value
|
||||
type BooleanNode struct {
|
||||
*Base
|
||||
Key string
|
||||
Value bool
|
||||
}
|
||||
|
||||
// DateTimeNode represents a time.Time value
|
||||
type DateTimeNode struct {
|
||||
*Base
|
||||
Key string
|
||||
Operator *OperatorNode
|
||||
Value time.Time
|
||||
}
|
||||
|
||||
// OperatorNode represents an operator value like
|
||||
// AND, OR, NOT, =, <= ... and so on
|
||||
type OperatorNode struct {
|
||||
*Base
|
||||
Value string
|
||||
}
|
||||
|
||||
// GroupNode represents a collection of many grouped nodes
|
||||
type GroupNode struct {
|
||||
*Base
|
||||
Key string
|
||||
Nodes []Node
|
||||
}
|
||||
|
||||
// NodeKey tries to return the node key
|
||||
func NodeKey(n Node) string {
|
||||
switch node := n.(type) {
|
||||
case *StringNode:
|
||||
return node.Key
|
||||
case *DateTimeNode:
|
||||
return node.Key
|
||||
case *BooleanNode:
|
||||
return node.Key
|
||||
case *GroupNode:
|
||||
return node.Key
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// NodeValue tries to return the node key
|
||||
func NodeValue(n Node) interface{} {
|
||||
switch node := n.(type) {
|
||||
case *StringNode:
|
||||
return node.Value
|
||||
case *DateTimeNode:
|
||||
return node.Value
|
||||
case *BooleanNode:
|
||||
return node.Value
|
||||
case *GroupNode:
|
||||
return node.Nodes
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
// Package test provides shared test primitives for ast testing.
|
||||
package test
|
||||
|
||||
import (
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/opencloud-eu/opencloud/pkg/ast"
|
||||
)
|
||||
|
||||
// DiffAst returns a human-readable report of the differences between two values
|
||||
// by default it ignores every ast node Base field.
|
||||
func DiffAst(x, y interface{}, opts ...cmp.Option) string {
|
||||
return cmp.Diff(
|
||||
x,
|
||||
y,
|
||||
append(
|
||||
opts,
|
||||
cmpopts.IgnoreFields(ast.Ast{}, "Base"),
|
||||
cmpopts.IgnoreFields(ast.StringNode{}, "Base"),
|
||||
cmpopts.IgnoreFields(ast.OperatorNode{}, "Base"),
|
||||
cmpopts.IgnoreFields(ast.GroupNode{}, "Base"),
|
||||
cmpopts.IgnoreFields(ast.BooleanNode{}, "Base"),
|
||||
cmpopts.IgnoreFields(ast.DateTimeNode{}, "Base"),
|
||||
)...,
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
package broker
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"go-micro.dev/v4/broker"
|
||||
)
|
||||
|
||||
type NoOp struct{}
|
||||
|
||||
func (n NoOp) Init(_ ...broker.Option) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n NoOp) Options() broker.Options {
|
||||
return broker.Options{}
|
||||
}
|
||||
|
||||
func (n NoOp) Address() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (n NoOp) Connect() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n NoOp) Disconnect() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n NoOp) Publish(topic string, m *broker.Message, opts ...broker.PublishOption) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n NoOp) Subscribe(topic string, h broker.Handler, opts ...broker.SubscribeOption) (broker.Subscriber, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (n NoOp) String() string {
|
||||
return "NoOp"
|
||||
}
|
||||
|
||||
func NewBroker(_ ...broker.Option) broker.Broker {
|
||||
return &NoOp{}
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
package capabilities
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/cs3org/reva/v2/pkg/owncloud/ocs"
|
||||
)
|
||||
|
||||
// allow the consuming part to change defaults, e.g., tests
|
||||
var defaultCapabilities atomic.Pointer[ocs.Capabilities]
|
||||
|
||||
func init() { //nolint:gochecknoinits
|
||||
ResetDefault()
|
||||
}
|
||||
|
||||
// ResetDefault resets the default [Capabilities] to the default values.
|
||||
func ResetDefault() {
|
||||
defaultCapabilities.Store(
|
||||
&ocs.Capabilities{
|
||||
Theme: &ocs.CapabilitiesTheme{
|
||||
Logo: &ocs.CapabilitiesThemeLogo{
|
||||
PermittedFileTypes: map[string]string{
|
||||
".jpg": "image/jpeg",
|
||||
".jpeg": "image/jpeg",
|
||||
".png": "image/png",
|
||||
".gif": "image/gif",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// Default returns the default [Capabilities].
|
||||
func Default() *ocs.Capabilities { return defaultCapabilities.Load() }
|
||||
@@ -0,0 +1,27 @@
|
||||
package checks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/opencloud-eu/opencloud/pkg/handlers"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
)
|
||||
|
||||
// NewGRPCCheck checks the reachability of a grpc server.
|
||||
func NewGRPCCheck(address string) func(context.Context) error {
|
||||
return func(_ context.Context) error {
|
||||
address, err := handlers.FailSaveAddress(address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn, err := grpc.NewClient(address, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not connect to grpc server: %v", err)
|
||||
}
|
||||
_ = conn.Close()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
package checks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/opencloud-eu/opencloud/pkg/handlers"
|
||||
)
|
||||
|
||||
// NewHTTPCheck checks the reachability of a http server.
|
||||
func NewHTTPCheck(url string) func(context.Context) error {
|
||||
return func(_ context.Context) error {
|
||||
url, err := handlers.FailSaveAddress(url)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") {
|
||||
url = "http://" + url
|
||||
}
|
||||
|
||||
c := http.Client{
|
||||
Timeout: 3 * time.Second,
|
||||
}
|
||||
resp, err := c.Get(url)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not connect to http server: %v", err)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
package checks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/nats-io/nats.go"
|
||||
)
|
||||
|
||||
// NewNatsCheck checks the reachability of a nats server.
|
||||
func NewNatsCheck(natsCluster string, options ...nats.Option) func(context.Context) error {
|
||||
return func(_ context.Context) error {
|
||||
n, err := nats.Connect(natsCluster, options...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not connect to nats server: %v", err)
|
||||
}
|
||||
defer n.Close()
|
||||
if n.Status() != nats.CONNECTED {
|
||||
return fmt.Errorf("nats server not connected")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
package checks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/opencloud-eu/opencloud/pkg/handlers"
|
||||
)
|
||||
|
||||
// NewTCPCheck returns a check that connects to a given tcp endpoint.
|
||||
func NewTCPCheck(address string) func(context.Context) error {
|
||||
return func(_ context.Context) error {
|
||||
address, err := handlers.FailSaveAddress(address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn, err := net.DialTimeout("tcp", address, 3*time.Second)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package clihelper
|
||||
|
||||
import (
|
||||
"github.com/opencloud-eu/opencloud/pkg/version"
|
||||
"github.com/urfave/cli/v2"
|
||||
)
|
||||
|
||||
func DefaultApp(app *cli.App) *cli.App {
|
||||
// version info
|
||||
app.Version = version.String
|
||||
app.Compiled = version.Compiled()
|
||||
|
||||
// author info
|
||||
app.Authors = []*cli.Author{
|
||||
{
|
||||
Name: "ownCloud GmbH",
|
||||
Email: "support@owncloud.com",
|
||||
},
|
||||
}
|
||||
|
||||
// disable global version flag
|
||||
// instead we provide the version command
|
||||
app.HideVersion = true
|
||||
|
||||
return app
|
||||
}
|
||||
@@ -0,0 +1,127 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"github.com/opencloud-eu/opencloud/pkg/shared"
|
||||
activitylog "github.com/opencloud-eu/opencloud/services/activitylog/pkg/config"
|
||||
antivirus "github.com/opencloud-eu/opencloud/services/antivirus/pkg/config"
|
||||
appProvider "github.com/opencloud-eu/opencloud/services/app-provider/pkg/config"
|
||||
appRegistry "github.com/opencloud-eu/opencloud/services/app-registry/pkg/config"
|
||||
audit "github.com/opencloud-eu/opencloud/services/audit/pkg/config"
|
||||
authapp "github.com/opencloud-eu/opencloud/services/auth-app/pkg/config"
|
||||
authbasic "github.com/opencloud-eu/opencloud/services/auth-basic/pkg/config"
|
||||
authbearer "github.com/opencloud-eu/opencloud/services/auth-bearer/pkg/config"
|
||||
authmachine "github.com/opencloud-eu/opencloud/services/auth-machine/pkg/config"
|
||||
authservice "github.com/opencloud-eu/opencloud/services/auth-service/pkg/config"
|
||||
clientlog "github.com/opencloud-eu/opencloud/services/clientlog/pkg/config"
|
||||
collaboration "github.com/opencloud-eu/opencloud/services/collaboration/pkg/config"
|
||||
eventhistory "github.com/opencloud-eu/opencloud/services/eventhistory/pkg/config"
|
||||
frontend "github.com/opencloud-eu/opencloud/services/frontend/pkg/config"
|
||||
gateway "github.com/opencloud-eu/opencloud/services/gateway/pkg/config"
|
||||
graph "github.com/opencloud-eu/opencloud/services/graph/pkg/config"
|
||||
groups "github.com/opencloud-eu/opencloud/services/groups/pkg/config"
|
||||
idm "github.com/opencloud-eu/opencloud/services/idm/pkg/config"
|
||||
idp "github.com/opencloud-eu/opencloud/services/idp/pkg/config"
|
||||
invitations "github.com/opencloud-eu/opencloud/services/invitations/pkg/config"
|
||||
nats "github.com/opencloud-eu/opencloud/services/nats/pkg/config"
|
||||
notifications "github.com/opencloud-eu/opencloud/services/notifications/pkg/config"
|
||||
ocdav "github.com/opencloud-eu/opencloud/services/ocdav/pkg/config"
|
||||
ocm "github.com/opencloud-eu/opencloud/services/ocm/pkg/config"
|
||||
ocs "github.com/opencloud-eu/opencloud/services/ocs/pkg/config"
|
||||
policies "github.com/opencloud-eu/opencloud/services/policies/pkg/config"
|
||||
postprocessing "github.com/opencloud-eu/opencloud/services/postprocessing/pkg/config"
|
||||
proxy "github.com/opencloud-eu/opencloud/services/proxy/pkg/config"
|
||||
search "github.com/opencloud-eu/opencloud/services/search/pkg/config"
|
||||
settings "github.com/opencloud-eu/opencloud/services/settings/pkg/config"
|
||||
sharing "github.com/opencloud-eu/opencloud/services/sharing/pkg/config"
|
||||
sse "github.com/opencloud-eu/opencloud/services/sse/pkg/config"
|
||||
storagepublic "github.com/opencloud-eu/opencloud/services/storage-publiclink/pkg/config"
|
||||
storageshares "github.com/opencloud-eu/opencloud/services/storage-shares/pkg/config"
|
||||
storagesystem "github.com/opencloud-eu/opencloud/services/storage-system/pkg/config"
|
||||
storageusers "github.com/opencloud-eu/opencloud/services/storage-users/pkg/config"
|
||||
thumbnails "github.com/opencloud-eu/opencloud/services/thumbnails/pkg/config"
|
||||
userlog "github.com/opencloud-eu/opencloud/services/userlog/pkg/config"
|
||||
users "github.com/opencloud-eu/opencloud/services/users/pkg/config"
|
||||
web "github.com/opencloud-eu/opencloud/services/web/pkg/config"
|
||||
webdav "github.com/opencloud-eu/opencloud/services/webdav/pkg/config"
|
||||
webfinger "github.com/opencloud-eu/opencloud/services/webfinger/pkg/config"
|
||||
)
|
||||
|
||||
type Mode int
|
||||
|
||||
// Runtime configures the oCIS runtime when running in supervised mode.
|
||||
type Runtime struct {
|
||||
Port string `yaml:"port" env:"OC_RUNTIME_PORT" desc:"The TCP port at which oCIS will be available" introductionVersion:"pre5.0"`
|
||||
Host string `yaml:"host" env:"OC_RUNTIME_HOST" desc:"The host at which oCIS will be available" introductionVersion:"pre5.0"`
|
||||
Services []string `yaml:"services" env:"OC_RUN_EXTENSIONS;OC_RUN_SERVICES" desc:"A comma-separated list of service names. Will start only the listed services." introductionVersion:"pre5.0"`
|
||||
Disabled []string `yaml:"disabled_services" env:"OC_EXCLUDE_RUN_SERVICES" desc:"A comma-separated list of service names. Will start all default services except of the ones listed. Has no effect when OC_RUN_SERVICES is set." introductionVersion:"pre5.0"`
|
||||
Additional []string `yaml:"add_services" env:"OC_ADD_RUN_SERVICES" desc:"A comma-separated list of service names. Will add the listed services to the default configuration. Has no effect when OC_RUN_SERVICES is set. Note that one can add services not started by the default list and exclude services from the default list by using both envvars at the same time." introductionVersion:"pre5.0"`
|
||||
}
|
||||
|
||||
// Config combines all available configuration parts.
|
||||
type Config struct {
|
||||
*shared.Commons `yaml:"shared"`
|
||||
|
||||
Tracing *shared.Tracing `yaml:"tracing"`
|
||||
Log *shared.Log `yaml:"log"`
|
||||
Cache *shared.Cache `yaml:"cache"`
|
||||
GRPCClientTLS *shared.GRPCClientTLS `yaml:"grpc_client_tls"`
|
||||
GRPCServiceTLS *shared.GRPCServiceTLS `yaml:"grpc_service_tls"`
|
||||
HTTPServiceTLS shared.HTTPServiceTLS `yaml:"http_service_tls"`
|
||||
Reva *shared.Reva `yaml:"reva"`
|
||||
|
||||
Mode Mode // DEPRECATED
|
||||
File string
|
||||
OcisURL string `yaml:"ocis_url" env:"OC_URL" desc:"URL, where oCIS is reachable for users." introductionVersion:"pre5.0"`
|
||||
|
||||
Registry string `yaml:"registry"`
|
||||
TokenManager *shared.TokenManager `yaml:"token_manager"`
|
||||
MachineAuthAPIKey string `yaml:"machine_auth_api_key" env:"OC_MACHINE_AUTH_API_KEY" desc:"Machine auth API key used to validate internal requests necessary for the access to resources from other services." introductionVersion:"pre5.0"`
|
||||
TransferSecret string `yaml:"transfer_secret" env:"OC_TRANSFER_SECRET" desc:"Transfer secret for signing file up- and download requests." introductionVersion:"pre5.0"`
|
||||
SystemUserID string `yaml:"system_user_id" env:"OC_SYSTEM_USER_ID" desc:"ID of the oCIS storage-system system user. Admins need to set the ID for the storage-system system user in this config option which is then used to reference the user. Any reasonable long string is possible, preferably this would be an UUIDv4 format." introductionVersion:"pre5.0"`
|
||||
SystemUserAPIKey string `yaml:"system_user_api_key" env:"OC_SYSTEM_USER_API_KEY" desc:"API key for the storage-system system user." introductionVersion:"pre5.0"`
|
||||
AdminUserID string `yaml:"admin_user_id" env:"OC_ADMIN_USER_ID" desc:"ID of a user, that should receive admin privileges. Consider that the UUID can be encoded in some LDAP deployment configurations like in .ldif files. These need to be decoded beforehand." introductionVersion:"pre5.0"`
|
||||
Runtime Runtime `yaml:"runtime"`
|
||||
|
||||
Activitylog *activitylog.Config `yaml:"activitylog"`
|
||||
Antivirus *antivirus.Config `yaml:"antivirus"`
|
||||
AppProvider *appProvider.Config `yaml:"app_provider"`
|
||||
AppRegistry *appRegistry.Config `yaml:"app_registry"`
|
||||
Audit *audit.Config `yaml:"audit"`
|
||||
AuthApp *authapp.Config `yaml:"auth_app"`
|
||||
AuthBasic *authbasic.Config `yaml:"auth_basic"`
|
||||
AuthBearer *authbearer.Config `yaml:"auth_bearer"`
|
||||
AuthMachine *authmachine.Config `yaml:"auth_machine"`
|
||||
AuthService *authservice.Config `yaml:"auth_service"`
|
||||
Clientlog *clientlog.Config `yaml:"clientlog"`
|
||||
Collaboration *collaboration.Config `yaml:"collaboration"`
|
||||
EventHistory *eventhistory.Config `yaml:"eventhistory"`
|
||||
Frontend *frontend.Config `yaml:"frontend"`
|
||||
Gateway *gateway.Config `yaml:"gateway"`
|
||||
Graph *graph.Config `yaml:"graph"`
|
||||
Groups *groups.Config `yaml:"groups"`
|
||||
IDM *idm.Config `yaml:"idm"`
|
||||
IDP *idp.Config `yaml:"idp"`
|
||||
Invitations *invitations.Config `yaml:"invitations"`
|
||||
Nats *nats.Config `yaml:"nats"`
|
||||
Notifications *notifications.Config `yaml:"notifications"`
|
||||
OCDav *ocdav.Config `yaml:"ocdav"`
|
||||
OCM *ocm.Config `yaml:"ocm"`
|
||||
OCS *ocs.Config `yaml:"ocs"`
|
||||
Postprocessing *postprocessing.Config `yaml:"postprocessing"`
|
||||
Policies *policies.Config `yaml:"policies"`
|
||||
Proxy *proxy.Config `yaml:"proxy"`
|
||||
Settings *settings.Config `yaml:"settings"`
|
||||
Sharing *sharing.Config `yaml:"sharing"`
|
||||
SSE *sse.Config `yaml:"sse"`
|
||||
StorageSystem *storagesystem.Config `yaml:"storage_system"`
|
||||
StoragePublicLink *storagepublic.Config `yaml:"storage_public"`
|
||||
StorageShares *storageshares.Config `yaml:"storage_shares"`
|
||||
StorageUsers *storageusers.Config `yaml:"storage_users"`
|
||||
Thumbnails *thumbnails.Config `yaml:"thumbnails"`
|
||||
Userlog *userlog.Config `yaml:"userlog"`
|
||||
Users *users.Config `yaml:"users"`
|
||||
Web *web.Config `yaml:"web"`
|
||||
WebDAV *webdav.Config `yaml:"webdav"`
|
||||
Webfinger *webfinger.Config `yaml:"webfinger"`
|
||||
Search *search.Config `yaml:"search"`
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package config_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestConfig(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Config Suite")
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package config_test
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/opencloud-eu/opencloud/pkg/config"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
var _ = Describe("Config", func() {
|
||||
It("Success generating the default config", func() {
|
||||
cfg := config.DefaultConfig()
|
||||
_, err := yaml.Marshal(cfg)
|
||||
Expect(err).To(BeNil())
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,30 @@
|
||||
package configlog
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
// Error logs the error
|
||||
func Error(err error) {
|
||||
if err != nil {
|
||||
fmt.Printf("%v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ReturnError logs the error and returns it unchanged
|
||||
func ReturnError(err error) error {
|
||||
if err != nil {
|
||||
fmt.Printf("%v\n", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// ReturnFatal logs the error and calls os.Exit(1) and returns nil if no error is passed
|
||||
func ReturnFatal(err error) error {
|
||||
if err != nil {
|
||||
fmt.Printf("%v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"github.com/opencloud-eu/opencloud/pkg/shared"
|
||||
activitylog "github.com/opencloud-eu/opencloud/services/activitylog/pkg/config/defaults"
|
||||
antivirus "github.com/opencloud-eu/opencloud/services/antivirus/pkg/config/defaults"
|
||||
appProvider "github.com/opencloud-eu/opencloud/services/app-provider/pkg/config/defaults"
|
||||
appRegistry "github.com/opencloud-eu/opencloud/services/app-registry/pkg/config/defaults"
|
||||
audit "github.com/opencloud-eu/opencloud/services/audit/pkg/config/defaults"
|
||||
authapp "github.com/opencloud-eu/opencloud/services/auth-app/pkg/config/defaults"
|
||||
authbasic "github.com/opencloud-eu/opencloud/services/auth-basic/pkg/config/defaults"
|
||||
authbearer "github.com/opencloud-eu/opencloud/services/auth-bearer/pkg/config/defaults"
|
||||
authmachine "github.com/opencloud-eu/opencloud/services/auth-machine/pkg/config/defaults"
|
||||
authservice "github.com/opencloud-eu/opencloud/services/auth-service/pkg/config/defaults"
|
||||
clientlog "github.com/opencloud-eu/opencloud/services/clientlog/pkg/config/defaults"
|
||||
collaboration "github.com/opencloud-eu/opencloud/services/collaboration/pkg/config/defaults"
|
||||
eventhistory "github.com/opencloud-eu/opencloud/services/eventhistory/pkg/config/defaults"
|
||||
frontend "github.com/opencloud-eu/opencloud/services/frontend/pkg/config/defaults"
|
||||
gateway "github.com/opencloud-eu/opencloud/services/gateway/pkg/config/defaults"
|
||||
graph "github.com/opencloud-eu/opencloud/services/graph/pkg/config/defaults"
|
||||
groups "github.com/opencloud-eu/opencloud/services/groups/pkg/config/defaults"
|
||||
idm "github.com/opencloud-eu/opencloud/services/idm/pkg/config/defaults"
|
||||
idp "github.com/opencloud-eu/opencloud/services/idp/pkg/config/defaults"
|
||||
invitations "github.com/opencloud-eu/opencloud/services/invitations/pkg/config/defaults"
|
||||
nats "github.com/opencloud-eu/opencloud/services/nats/pkg/config/defaults"
|
||||
notifications "github.com/opencloud-eu/opencloud/services/notifications/pkg/config/defaults"
|
||||
ocdav "github.com/opencloud-eu/opencloud/services/ocdav/pkg/config/defaults"
|
||||
ocm "github.com/opencloud-eu/opencloud/services/ocm/pkg/config/defaults"
|
||||
ocs "github.com/opencloud-eu/opencloud/services/ocs/pkg/config/defaults"
|
||||
policies "github.com/opencloud-eu/opencloud/services/policies/pkg/config/defaults"
|
||||
postprocessing "github.com/opencloud-eu/opencloud/services/postprocessing/pkg/config/defaults"
|
||||
proxy "github.com/opencloud-eu/opencloud/services/proxy/pkg/config/defaults"
|
||||
search "github.com/opencloud-eu/opencloud/services/search/pkg/config/defaults"
|
||||
settings "github.com/opencloud-eu/opencloud/services/settings/pkg/config/defaults"
|
||||
sharing "github.com/opencloud-eu/opencloud/services/sharing/pkg/config/defaults"
|
||||
sse "github.com/opencloud-eu/opencloud/services/sse/pkg/config/defaults"
|
||||
storagepublic "github.com/opencloud-eu/opencloud/services/storage-publiclink/pkg/config/defaults"
|
||||
storageshares "github.com/opencloud-eu/opencloud/services/storage-shares/pkg/config/defaults"
|
||||
storageSystem "github.com/opencloud-eu/opencloud/services/storage-system/pkg/config/defaults"
|
||||
storageusers "github.com/opencloud-eu/opencloud/services/storage-users/pkg/config/defaults"
|
||||
thumbnails "github.com/opencloud-eu/opencloud/services/thumbnails/pkg/config/defaults"
|
||||
userlog "github.com/opencloud-eu/opencloud/services/userlog/pkg/config/defaults"
|
||||
users "github.com/opencloud-eu/opencloud/services/users/pkg/config/defaults"
|
||||
web "github.com/opencloud-eu/opencloud/services/web/pkg/config/defaults"
|
||||
webdav "github.com/opencloud-eu/opencloud/services/webdav/pkg/config/defaults"
|
||||
webfinger "github.com/opencloud-eu/opencloud/services/webfinger/pkg/config/defaults"
|
||||
)
|
||||
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
OcisURL: "https://localhost:9200",
|
||||
Runtime: Runtime{
|
||||
Port: "9250",
|
||||
Host: "localhost",
|
||||
},
|
||||
Reva: &shared.Reva{
|
||||
Address: "com.owncloud.api.gateway",
|
||||
},
|
||||
|
||||
Activitylog: activitylog.DefaultConfig(),
|
||||
Antivirus: antivirus.DefaultConfig(),
|
||||
AppProvider: appProvider.DefaultConfig(),
|
||||
AppRegistry: appRegistry.DefaultConfig(),
|
||||
Audit: audit.DefaultConfig(),
|
||||
AuthApp: authapp.DefaultConfig(),
|
||||
AuthBasic: authbasic.DefaultConfig(),
|
||||
AuthBearer: authbearer.DefaultConfig(),
|
||||
AuthMachine: authmachine.DefaultConfig(),
|
||||
AuthService: authservice.DefaultConfig(),
|
||||
Clientlog: clientlog.DefaultConfig(),
|
||||
Collaboration: collaboration.DefaultConfig(),
|
||||
EventHistory: eventhistory.DefaultConfig(),
|
||||
Frontend: frontend.DefaultConfig(),
|
||||
Gateway: gateway.DefaultConfig(),
|
||||
Graph: graph.DefaultConfig(),
|
||||
Groups: groups.DefaultConfig(),
|
||||
IDM: idm.DefaultConfig(),
|
||||
IDP: idp.DefaultConfig(),
|
||||
Invitations: invitations.DefaultConfig(),
|
||||
Nats: nats.DefaultConfig(),
|
||||
Notifications: notifications.DefaultConfig(),
|
||||
OCDav: ocdav.DefaultConfig(),
|
||||
OCM: ocm.DefaultConfig(),
|
||||
OCS: ocs.DefaultConfig(),
|
||||
Postprocessing: postprocessing.DefaultConfig(),
|
||||
Policies: policies.DefaultConfig(),
|
||||
Proxy: proxy.DefaultConfig(),
|
||||
Search: search.DefaultConfig(),
|
||||
Settings: settings.DefaultConfig(),
|
||||
Sharing: sharing.DefaultConfig(),
|
||||
SSE: sse.DefaultConfig(),
|
||||
StoragePublicLink: storagepublic.DefaultConfig(),
|
||||
StorageShares: storageshares.DefaultConfig(),
|
||||
StorageSystem: storageSystem.DefaultConfig(),
|
||||
StorageUsers: storageusers.DefaultConfig(),
|
||||
Thumbnails: thumbnails.DefaultConfig(),
|
||||
Userlog: userlog.DefaultConfig(),
|
||||
Users: users.DefaultConfig(),
|
||||
Web: web.DefaultConfig(),
|
||||
WebDAV: webdav.DefaultConfig(),
|
||||
Webfinger: webfinger.DefaultConfig(),
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
package defaults
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"path"
|
||||
)
|
||||
|
||||
const ()
|
||||
|
||||
var (
|
||||
// switch between modes
|
||||
BaseDataPathType = "homedir" // or "path"
|
||||
// default data path
|
||||
BaseDataPathValue = "/var/lib/ocis"
|
||||
)
|
||||
|
||||
func BaseDataPath() string {
|
||||
|
||||
// It is not nice to have hidden / secrete configuration options
|
||||
// But how can we update the base path for every occurrence with a flagset option?
|
||||
// This is currently not possible and needs a new configuration concept
|
||||
p := os.Getenv("OC_BASE_DATA_PATH")
|
||||
if p != "" {
|
||||
return p
|
||||
}
|
||||
|
||||
switch BaseDataPathType {
|
||||
case "homedir":
|
||||
dir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
// fallback to BaseDatapathValue for users without home
|
||||
return BaseDataPathValue
|
||||
}
|
||||
return path.Join(dir, ".ocis")
|
||||
case "path":
|
||||
return BaseDataPathValue
|
||||
default:
|
||||
log.Fatalf("BaseDataPathType %s not found", BaseDataPathType)
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
// switch between modes
|
||||
BaseConfigPathType = "homedir" // or "path"
|
||||
// default config path
|
||||
BaseConfigPathValue = "/etc/ocis"
|
||||
)
|
||||
|
||||
func BaseConfigPath() string {
|
||||
|
||||
// It is not nice to have hidden / secrete configuration options
|
||||
// But how can we update the base path for every occurrence with a flagset option?
|
||||
// This is currently not possible and needs a new configuration concept
|
||||
p := os.Getenv("OC_CONFIG_DIR")
|
||||
if p != "" {
|
||||
return p
|
||||
}
|
||||
|
||||
switch BaseConfigPathType {
|
||||
case "homedir":
|
||||
dir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
// fallback to BaseConfigPathValue for users without home
|
||||
return BaseConfigPathValue
|
||||
}
|
||||
return path.Join(dir, ".ocis", "config")
|
||||
case "path":
|
||||
return BaseConfigPathValue
|
||||
default:
|
||||
log.Fatalf("BaseConfigPathType %s not found", BaseConfigPathType)
|
||||
return ""
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2014 Joe Shaw
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
||||
@@ -0,0 +1,88 @@
|
||||
`envdecode` is a Go package for populating structs from environment
|
||||
variables. It's basically a fork of https://github.com/joeshaw/envdecode,
|
||||
but changed to support multiple environment variables (precedence).
|
||||
|
||||
`envdecode` uses struct tags to map environment variables to fields,
|
||||
allowing you to use any names you want for environment variables.
|
||||
`envdecode` will recurse into nested structs, including pointers to
|
||||
nested structs, but it will not allocate new pointers to structs.
|
||||
|
||||
## API
|
||||
|
||||
Full API docs are available on
|
||||
[godoc.org](https://godoc.org/github.com/owncloud/ocis/ocis-pkg/config/envdecode).
|
||||
|
||||
Define a struct with `env` struct tags:
|
||||
|
||||
```go
|
||||
type Config struct {
|
||||
Hostname string `env:"SERVER_HOSTNAME,default=localhost"`
|
||||
Port uint16 `env:"HTTP_PORT;SERVER_PORT,default=8080"`
|
||||
|
||||
AWS struct {
|
||||
ID string `env:"AWS_ACCESS_KEY_ID"`
|
||||
Secret string `env:"AWS_SECRET_ACCESS_KEY,required"`
|
||||
SnsTopics []string `env:"AWS_SNS_TOPICS"`
|
||||
}
|
||||
|
||||
Timeout time.Duration `env:"TIMEOUT,default=1m,strict"`
|
||||
}
|
||||
```
|
||||
|
||||
Fields _must be exported_ (i.e. begin with a capital letter) in order
|
||||
for `envdecode` to work with them. An error will be returned if a
|
||||
struct with no exported fields is decoded (including one that contains
|
||||
no `env` tags at all).
|
||||
Default values may be provided by appending ",default=value" to the
|
||||
struct tag. Required values may be marked by appending ",required" to the
|
||||
struct tag. Strict values may be marked by appending ",strict" which will
|
||||
return an error on Decode if there is an error while parsing.
|
||||
|
||||
Then call `envdecode.Decode`:
|
||||
|
||||
```go
|
||||
var cfg Config
|
||||
err := envdecode.Decode(&cfg)
|
||||
```
|
||||
|
||||
If you want all fields to act `strict`, you may use `envdecode.StrictDecode`:
|
||||
|
||||
```go
|
||||
var cfg Config
|
||||
err := envdecode.StrictDecode(&cfg)
|
||||
```
|
||||
|
||||
All parse errors will fail fast and return an error in this mode.
|
||||
|
||||
## Supported types
|
||||
|
||||
- Structs (and pointer to structs)
|
||||
- Slices of defined types below, separated by semicolon
|
||||
- `bool`
|
||||
- `float32`, `float64`
|
||||
- `int`, `int8`, `int16`, `int32`, `int64`
|
||||
- `uint`, `uint8`, `uint16`, `uint32`, `uint64`
|
||||
- `string`
|
||||
- `time.Duration`, using the [`time.ParseDuration()` format](http://golang.org/pkg/time/#ParseDuration)
|
||||
- `*url.URL`, using [`url.Parse()`](https://godoc.org/net/url#Parse)
|
||||
- Types those implement a `Decoder` interface
|
||||
|
||||
## Custom `Decoder`
|
||||
|
||||
If you want a field to be decoded with custom behavior, you may implement the interface `Decoder` for the filed type.
|
||||
|
||||
```go
|
||||
type Config struct {
|
||||
IPAddr IP `env:"IP_ADDR"`
|
||||
}
|
||||
|
||||
type IP net.IP
|
||||
|
||||
// Decode implements the interface `envdecode.Decoder`
|
||||
func (i *IP) Decode(repl string) error {
|
||||
*i = net.ParseIP(repl)
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
`Decoder` is the interface implemented by an object that can decode an environment variable string representation of itself.
|
||||
@@ -0,0 +1,436 @@
|
||||
// Package envdecode is a package for populating structs from environment
|
||||
// variables, using struct tags.
|
||||
package envdecode
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrInvalidTarget indicates that the target value passed to
|
||||
// Decode is invalid. Target must be a non-nil pointer to a struct.
|
||||
var ErrInvalidTarget = errors.New("target must be non-nil pointer to struct that has at least one exported field with a valid env tag")
|
||||
var ErrNoTargetFieldsAreSet = errors.New("none of the target fields were set from environment variables")
|
||||
|
||||
// FailureFunc is called when an error is encountered during a MustDecode
|
||||
// operation. It prints the error and terminates the process.
|
||||
//
|
||||
// This variable can be assigned to another function of the user-programmer's
|
||||
// design, allowing for graceful recovery of the problem, such as loading
|
||||
// from a backup configuration file.
|
||||
var FailureFunc = func(err error) {
|
||||
log.Fatalf("envdecode: an error was encountered while decoding: %v\n", err)
|
||||
}
|
||||
|
||||
// Decoder is the interface implemented by an object that can decode an
|
||||
// environment variable string representation of itself.
|
||||
type Decoder interface {
|
||||
Decode(string) error
|
||||
}
|
||||
|
||||
// Decode environment variables into the provided target. The target
|
||||
// must be a non-nil pointer to a struct. Fields in the struct must
|
||||
// be exported, and tagged with an "env" struct tag with a value
|
||||
// containing the name of the environment variable. An error is
|
||||
// returned if there are no exported members tagged.
|
||||
//
|
||||
// Default values may be provided by appending ",default=value" to the
|
||||
// struct tag. Required values may be marked by appending ",required"
|
||||
// to the struct tag. It is an error to provide both "default" and
|
||||
// "required". Strict values may be marked by appending ",strict" which
|
||||
// will return an error on Decode if there is an error while parsing.
|
||||
// If everything must be strict, consider using StrictDecode instead.
|
||||
//
|
||||
// All primitive types are supported, including bool, floating point,
|
||||
// signed and unsigned integers, and string. Boolean and numeric
|
||||
// types are decoded using the standard strconv Parse functions for
|
||||
// those types. Structs and pointers to structs are decoded
|
||||
// recursively. time.Duration is supported via the
|
||||
// time.ParseDuration() function and *url.URL is supported via the
|
||||
// url.Parse() function. Slices are supported for all above mentioned
|
||||
// primitive types. Semicolon is used as delimiter in environment variables.
|
||||
func Decode(target interface{}) error {
|
||||
nFields, err := decode(target, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// if we didn't do anything - the user probably did something
|
||||
// wrong like leave all fields unexported.
|
||||
if nFields == 0 {
|
||||
return ErrNoTargetFieldsAreSet
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StrictDecode is similar to Decode except all fields will have an implicit
|
||||
// ",strict" on all fields.
|
||||
func StrictDecode(target interface{}) error {
|
||||
nFields, err := decode(target, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// if we didn't do anything - the user probably did something
|
||||
// wrong like leave all fields unexported.
|
||||
if nFields == 0 {
|
||||
return ErrInvalidTarget
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func decode(target interface{}, strict bool) (int, error) {
|
||||
s := reflect.ValueOf(target)
|
||||
if s.Kind() != reflect.Ptr || s.IsNil() {
|
||||
return 0, ErrInvalidTarget
|
||||
}
|
||||
|
||||
s = s.Elem()
|
||||
if s.Kind() != reflect.Struct {
|
||||
return 0, ErrInvalidTarget
|
||||
}
|
||||
|
||||
t := s.Type()
|
||||
setFieldCount := 0
|
||||
for i := 0; i < s.NumField(); i++ {
|
||||
// Localize the umbrella `strict` value to the specific field.
|
||||
strict := strict
|
||||
|
||||
f := s.Field(i)
|
||||
|
||||
switch f.Kind() {
|
||||
case reflect.Ptr:
|
||||
if f.Elem().Kind() != reflect.Struct {
|
||||
break
|
||||
}
|
||||
|
||||
f = f.Elem()
|
||||
fallthrough
|
||||
|
||||
case reflect.Struct:
|
||||
if !f.Addr().CanInterface() {
|
||||
continue
|
||||
}
|
||||
|
||||
ss := f.Addr().Interface()
|
||||
_, custom := ss.(Decoder)
|
||||
if custom {
|
||||
break
|
||||
}
|
||||
|
||||
n, err := decode(ss, strict)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
setFieldCount += n
|
||||
}
|
||||
|
||||
if !f.CanSet() {
|
||||
continue
|
||||
}
|
||||
|
||||
tag := t.Field(i).Tag.Get("env")
|
||||
if tag == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.Split(tag, ",")
|
||||
overrides := strings.Split(parts[0], `;`)
|
||||
|
||||
var env string
|
||||
var envSet bool
|
||||
for _, override := range overrides {
|
||||
if v, set := os.LookupEnv(override); set {
|
||||
env = v
|
||||
envSet = true
|
||||
}
|
||||
}
|
||||
|
||||
required := false
|
||||
hasDefault := false
|
||||
defaultValue := ""
|
||||
|
||||
for _, o := range parts[1:] {
|
||||
if !required {
|
||||
required = strings.HasPrefix(o, "required")
|
||||
}
|
||||
if strings.HasPrefix(o, "default=") {
|
||||
hasDefault = true
|
||||
defaultValue = o[8:]
|
||||
}
|
||||
if !strict {
|
||||
strict = strings.HasPrefix(o, "strict")
|
||||
}
|
||||
}
|
||||
|
||||
if required && hasDefault {
|
||||
panic(`envdecode: "default" and "required" may not be specified in the same annotation`)
|
||||
}
|
||||
if !envSet && required {
|
||||
return 0, fmt.Errorf("the environment variable \"%s\" is missing", parts[0])
|
||||
}
|
||||
if !envSet {
|
||||
env = defaultValue
|
||||
}
|
||||
if !envSet && env == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
setFieldCount++
|
||||
|
||||
unmarshaler, implementsUnmarshaler := f.Addr().Interface().(encoding.TextUnmarshaler)
|
||||
decoder, implmentsDecoder := f.Addr().Interface().(Decoder)
|
||||
if implmentsDecoder {
|
||||
if err := decoder.Decode(env); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else if implementsUnmarshaler {
|
||||
if err := unmarshaler.UnmarshalText([]byte(env)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else if f.Kind() == reflect.Slice {
|
||||
if err := decodeSlice(&f, env); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
if err := decodePrimitiveType(&f, env); err != nil && strict {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return setFieldCount, nil
|
||||
}
|
||||
|
||||
func decodeSlice(f *reflect.Value, env string) error {
|
||||
parts := strings.Split(env, ",")
|
||||
|
||||
values := parts[:0]
|
||||
for _, x := range parts {
|
||||
if x != "" {
|
||||
values = append(values, strings.TrimSpace(x))
|
||||
}
|
||||
}
|
||||
|
||||
valuesCount := len(values)
|
||||
slice := reflect.MakeSlice(f.Type(), valuesCount, valuesCount)
|
||||
if valuesCount > 0 {
|
||||
for i := 0; i < valuesCount; i++ {
|
||||
e := slice.Index(i)
|
||||
err := decodePrimitiveType(&e, values[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
f.Set(slice)
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodePrimitiveType(f *reflect.Value, env string) error {
|
||||
switch f.Kind() {
|
||||
case reflect.Bool:
|
||||
v, err := strconv.ParseBool(env)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f.SetBool(v)
|
||||
|
||||
case reflect.Float32, reflect.Float64:
|
||||
bits := f.Type().Bits()
|
||||
v, err := strconv.ParseFloat(env, bits)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f.SetFloat(v)
|
||||
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
if t := f.Type(); t.PkgPath() == "time" && t.Name() == "Duration" {
|
||||
v, err := time.ParseDuration(env)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f.SetInt(int64(v))
|
||||
} else {
|
||||
bits := f.Type().Bits()
|
||||
v, err := strconv.ParseInt(env, 0, bits)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f.SetInt(v)
|
||||
}
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
bits := f.Type().Bits()
|
||||
v, err := strconv.ParseUint(env, 0, bits)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f.SetUint(v)
|
||||
|
||||
case reflect.String:
|
||||
f.SetString(env)
|
||||
|
||||
case reflect.Ptr:
|
||||
if t := f.Type().Elem(); t.Kind() == reflect.Struct && t.PkgPath() == "net/url" && t.Name() == "URL" {
|
||||
v, err := url.Parse(env)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f.Set(reflect.ValueOf(v))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MustDecode calls Decode and terminates the process if any errors
|
||||
// are encountered.
|
||||
func MustDecode(target interface{}) {
|
||||
if err := Decode(target); err != nil {
|
||||
FailureFunc(err)
|
||||
}
|
||||
}
|
||||
|
||||
// MustStrictDecode calls StrictDecode and terminates the process if any errors
|
||||
// are encountered.
|
||||
func MustStrictDecode(target interface{}) {
|
||||
if err := StrictDecode(target); err != nil {
|
||||
FailureFunc(err)
|
||||
}
|
||||
}
|
||||
|
||||
//// Configuration info for Export
|
||||
|
||||
type ConfigInfo struct {
|
||||
Field string
|
||||
EnvVar string
|
||||
Value string
|
||||
DefaultValue string
|
||||
HasDefault bool
|
||||
Required bool
|
||||
UsesEnv bool
|
||||
}
|
||||
|
||||
type ConfigInfoSlice []*ConfigInfo
|
||||
|
||||
func (c ConfigInfoSlice) Less(i, j int) bool {
|
||||
return c[i].EnvVar < c[j].EnvVar
|
||||
}
|
||||
func (c ConfigInfoSlice) Len() int {
|
||||
return len(c)
|
||||
}
|
||||
func (c ConfigInfoSlice) Swap(i, j int) {
|
||||
c[i], c[j] = c[j], c[i]
|
||||
}
|
||||
|
||||
// Returns a list of final configuration metadata sorted by envvar name
|
||||
func Export(target interface{}) ([]*ConfigInfo, error) {
|
||||
s := reflect.ValueOf(target)
|
||||
if s.Kind() != reflect.Ptr || s.IsNil() {
|
||||
return nil, ErrInvalidTarget
|
||||
}
|
||||
|
||||
cfg := []*ConfigInfo{}
|
||||
|
||||
s = s.Elem()
|
||||
if s.Kind() != reflect.Struct {
|
||||
return nil, ErrInvalidTarget
|
||||
}
|
||||
|
||||
t := s.Type()
|
||||
for i := 0; i < s.NumField(); i++ {
|
||||
f := s.Field(i)
|
||||
fName := t.Field(i).Name
|
||||
|
||||
fElem := f
|
||||
if f.Kind() == reflect.Ptr {
|
||||
fElem = f.Elem()
|
||||
}
|
||||
if fElem.Kind() == reflect.Struct {
|
||||
ss := fElem.Addr().Interface()
|
||||
subCfg, err := Export(ss)
|
||||
if err != ErrInvalidTarget {
|
||||
f = fElem
|
||||
for _, v := range subCfg {
|
||||
v.Field = fmt.Sprintf("%s.%s", fName, v.Field)
|
||||
cfg = append(cfg, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tag := t.Field(i).Tag.Get("env")
|
||||
if tag == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.Split(tag, ",")
|
||||
|
||||
ci := &ConfigInfo{
|
||||
Field: fName,
|
||||
EnvVar: parts[0],
|
||||
UsesEnv: os.Getenv(parts[0]) != "",
|
||||
}
|
||||
|
||||
for _, o := range parts[1:] {
|
||||
if strings.HasPrefix(o, "default=") {
|
||||
ci.HasDefault = true
|
||||
ci.DefaultValue = o[8:]
|
||||
} else if strings.HasPrefix(o, "required") {
|
||||
ci.Required = true
|
||||
}
|
||||
}
|
||||
|
||||
if f.Kind() == reflect.Ptr && f.IsNil() {
|
||||
ci.Value = ""
|
||||
} else if stringer, ok := f.Interface().(fmt.Stringer); ok {
|
||||
ci.Value = stringer.String()
|
||||
} else {
|
||||
switch f.Kind() {
|
||||
case reflect.Bool:
|
||||
ci.Value = strconv.FormatBool(f.Bool())
|
||||
|
||||
case reflect.Float32, reflect.Float64:
|
||||
bits := f.Type().Bits()
|
||||
ci.Value = strconv.FormatFloat(f.Float(), 'f', -1, bits)
|
||||
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
ci.Value = strconv.FormatInt(f.Int(), 10)
|
||||
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
ci.Value = strconv.FormatUint(f.Uint(), 10)
|
||||
|
||||
case reflect.String:
|
||||
ci.Value = f.String()
|
||||
|
||||
case reflect.Slice:
|
||||
ci.Value = fmt.Sprintf("%v", f.Interface())
|
||||
|
||||
default:
|
||||
// Unable to determine string format for value
|
||||
return nil, ErrInvalidTarget
|
||||
}
|
||||
}
|
||||
|
||||
cfg = append(cfg, ci)
|
||||
}
|
||||
|
||||
// No configuration tags found, assume invalid input
|
||||
if len(cfg) == 0 {
|
||||
return nil, ErrInvalidTarget
|
||||
}
|
||||
|
||||
sort.Sort(ConfigInfoSlice(cfg))
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
@@ -0,0 +1,790 @@
|
||||
package envdecode
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type nested struct {
|
||||
String string `env:"TEST_STRING"`
|
||||
}
|
||||
|
||||
type testConfig struct {
|
||||
String string `env:"TEST_STRING"`
|
||||
Int64 int64 `env:"TEST_INT64"`
|
||||
Uint16 uint16 `env:"TEST_UINT16"`
|
||||
Float64 float64 `env:"TEST_FLOAT64"`
|
||||
Bool bool `env:"TEST_BOOL"`
|
||||
Duration time.Duration `env:"TEST_DURATION"`
|
||||
URL *url.URL `env:"TEST_URL"`
|
||||
|
||||
StringSlice []string `env:"TEST_STRING_SLICE"`
|
||||
Int64Slice []int64 `env:"TEST_INT64_SLICE"`
|
||||
Uint16Slice []uint16 `env:"TEST_UINT16_SLICE"`
|
||||
Float64Slice []float64 `env:"TEST_FLOAT64_SLICE"`
|
||||
BoolSlice []bool `env:"TEST_BOOL_SLICE"`
|
||||
DurationSlice []time.Duration `env:"TEST_DURATION_SLICE"`
|
||||
URLSlice []*url.URL `env:"TEST_URL_SLICE"`
|
||||
|
||||
UnsetString string `env:"TEST_UNSET_STRING"`
|
||||
UnsetInt64 int64 `env:"TEST_UNSET_INT64"`
|
||||
UnsetDuration time.Duration `env:"TEST_UNSET_DURATION"`
|
||||
UnsetURL *url.URL `env:"TEST_UNSET_URL"`
|
||||
UnsetSlice []string `env:"TEST_UNSET_SLICE"`
|
||||
|
||||
InvalidInt64 int64 `env:"TEST_INVALID_INT64"`
|
||||
|
||||
UnusedField string
|
||||
unexportedField string
|
||||
|
||||
IgnoredPtr *bool `env:"TEST_BOOL"`
|
||||
|
||||
Nested nested
|
||||
NestedPtr *nested
|
||||
|
||||
DecoderStruct decoderStruct `env:"TEST_DECODER_STRUCT"`
|
||||
DecoderStructPtr *decoderStruct `env:"TEST_DECODER_STRUCT_PTR"`
|
||||
|
||||
DecoderString decoderString `env:"TEST_DECODER_STRING"`
|
||||
|
||||
UnmarshalerNumber unmarshalerNumber `env:"TEST_UNMARSHALER_NUMBER"`
|
||||
|
||||
DefaultInt int `env:"TEST_UNSET,asdf=asdf,default=1234"`
|
||||
DefaultSliceInt []int `env:"TEST_UNSET,asdf=asdf,default=1"`
|
||||
DefaultDuration time.Duration `env:"TEST_UNSET,asdf=asdf,default=24h"`
|
||||
DefaultURL *url.URL `env:"TEST_UNSET,default=http://example.com"`
|
||||
}
|
||||
|
||||
type testConfigNoSet struct {
|
||||
Some string `env:"TEST_THIS_ENV_WILL_NOT_BE_SET"`
|
||||
}
|
||||
|
||||
type testConfigRequired struct {
|
||||
Required string `env:"TEST_REQUIRED,required"`
|
||||
}
|
||||
|
||||
type testConfigRequiredDefault struct {
|
||||
RequiredDefault string `env:"TEST_REQUIRED_DEFAULT,required,default=test"`
|
||||
}
|
||||
|
||||
type testConfigOverride struct {
|
||||
OverrideString string `env:"TEST_OVERRIDE_A;TEST_OVERRIDE_B,default=override_default"`
|
||||
}
|
||||
|
||||
type testNoExportedFields struct {
|
||||
// following unexported fields are used for tests
|
||||
aString string `env:"TEST_STRING"` //nolint:structcheck,unused
|
||||
anInt64 int64 `env:"TEST_INT64"` //nolint:structcheck,unused
|
||||
aUint16 uint16 `env:"TEST_UINT16"` //nolint:structcheck,unused
|
||||
aFloat64 float64 `env:"TEST_FLOAT64"` //nolint:structcheck,unused
|
||||
aBool bool `env:"TEST_BOOL"` //nolint:structcheck,unused
|
||||
}
|
||||
|
||||
type testNoTags struct {
|
||||
String string
|
||||
}
|
||||
|
||||
type decoderStruct struct {
|
||||
String string
|
||||
}
|
||||
|
||||
func (d *decoderStruct) Decode(env string) error {
|
||||
return json.Unmarshal([]byte(env), &d)
|
||||
}
|
||||
|
||||
type decoderString string
|
||||
|
||||
func (d *decoderString) Decode(env string) error {
|
||||
r, l := []rune(env), len(env)
|
||||
|
||||
for i := 0; i < l/2; i++ {
|
||||
r[i], r[l-1-i] = r[l-1-i], r[i]
|
||||
}
|
||||
|
||||
*d = decoderString(r)
|
||||
return nil
|
||||
}
|
||||
|
||||
type unmarshalerNumber uint8
|
||||
|
||||
func (o *unmarshalerNumber) UnmarshalText(raw []byte) error {
|
||||
n, err := strconv.ParseUint(string(raw), 8, 8) // parse text as octal number
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*o = unmarshalerNumber(n)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestDecode(t *testing.T) {
|
||||
int64Val := int64(-(1 << 50))
|
||||
int64AsString := fmt.Sprintf("%d", int64Val)
|
||||
piAsString := fmt.Sprintf("%.48f", math.Pi)
|
||||
|
||||
os.Setenv("TEST_STRING", "foo")
|
||||
os.Setenv("TEST_INT64", int64AsString)
|
||||
os.Setenv("TEST_UINT16", "60000")
|
||||
os.Setenv("TEST_FLOAT64", piAsString)
|
||||
os.Setenv("TEST_BOOL", "true")
|
||||
os.Setenv("TEST_DURATION", "10m")
|
||||
os.Setenv("TEST_URL", "https://example.com")
|
||||
os.Setenv("TEST_INVALID_INT64", "asdf")
|
||||
os.Setenv("TEST_STRING_SLICE", "foo,bar")
|
||||
os.Setenv("TEST_INT64_SLICE", int64AsString+","+int64AsString)
|
||||
os.Setenv("TEST_UINT16_SLICE", "60000,50000")
|
||||
os.Setenv("TEST_FLOAT64_SLICE", piAsString+","+piAsString)
|
||||
os.Setenv("TEST_BOOL_SLICE", "true, false, true")
|
||||
os.Setenv("TEST_DURATION_SLICE", "10m, 20m")
|
||||
os.Setenv("TEST_URL_SLICE", "https://example.com")
|
||||
os.Setenv("TEST_DECODER_STRUCT", "{\"string\":\"foo\"}")
|
||||
os.Setenv("TEST_DECODER_STRUCT_PTR", "{\"string\":\"foo\"}")
|
||||
os.Setenv("TEST_DECODER_STRING", "oof")
|
||||
os.Setenv("TEST_UNMARSHALER_NUMBER", "07")
|
||||
|
||||
var tc testConfig
|
||||
tc.NestedPtr = &nested{}
|
||||
tc.DecoderStructPtr = &decoderStruct{}
|
||||
|
||||
err := Decode(&tc)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if tc.String != "foo" {
|
||||
t.Fatalf(`Expected "foo", got "%s"`, tc.String)
|
||||
}
|
||||
|
||||
if tc.Int64 != -(1 << 50) {
|
||||
t.Fatalf("Expected %d, got %d", -(1 << 50), tc.Int64)
|
||||
}
|
||||
|
||||
if tc.Uint16 != 60000 {
|
||||
t.Fatalf("Expected 60000, got %d", tc.Uint16)
|
||||
}
|
||||
|
||||
if tc.Float64 != math.Pi {
|
||||
t.Fatalf("Expected %.48f, got %.48f", math.Pi, tc.Float64)
|
||||
}
|
||||
|
||||
if !tc.Bool {
|
||||
t.Fatal("Expected true, got false")
|
||||
}
|
||||
|
||||
duration, _ := time.ParseDuration("10m")
|
||||
if tc.Duration != duration {
|
||||
t.Fatalf("Expected %d, got %d", duration, tc.Duration)
|
||||
}
|
||||
|
||||
if tc.URL == nil {
|
||||
t.Fatalf("Expected https://example.com, got nil")
|
||||
} else if tc.URL.String() != "https://example.com" {
|
||||
t.Fatalf("Expected https://example.com, got %s", tc.URL.String())
|
||||
}
|
||||
|
||||
expectedStringSlice := []string{"foo", "bar"}
|
||||
if !reflect.DeepEqual(tc.StringSlice, expectedStringSlice) {
|
||||
t.Fatalf("Expected %s, got %s", expectedStringSlice, tc.StringSlice)
|
||||
}
|
||||
|
||||
expectedInt64Slice := []int64{int64Val, int64Val}
|
||||
if !reflect.DeepEqual(tc.Int64Slice, expectedInt64Slice) {
|
||||
t.Fatalf("Expected %#v, got %#v", expectedInt64Slice, tc.Int64Slice)
|
||||
}
|
||||
|
||||
expectedUint16Slice := []uint16{60000, 50000}
|
||||
if !reflect.DeepEqual(tc.Uint16Slice, expectedUint16Slice) {
|
||||
t.Fatalf("Expected %#v, got %#v", expectedUint16Slice, tc.Uint16Slice)
|
||||
}
|
||||
|
||||
expectedFloat64Slice := []float64{math.Pi, math.Pi}
|
||||
if !reflect.DeepEqual(tc.Float64Slice, expectedFloat64Slice) {
|
||||
t.Fatalf("Expected %#v, got %#v", expectedFloat64Slice, tc.Float64Slice)
|
||||
}
|
||||
|
||||
expectedBoolSlice := []bool{true, false, true}
|
||||
if !reflect.DeepEqual(tc.BoolSlice, expectedBoolSlice) {
|
||||
t.Fatalf("Expected %#v, got %#v", expectedBoolSlice, tc.BoolSlice)
|
||||
}
|
||||
|
||||
duration2, _ := time.ParseDuration("20m")
|
||||
expectedDurationSlice := []time.Duration{duration, duration2}
|
||||
if !reflect.DeepEqual(tc.DurationSlice, expectedDurationSlice) {
|
||||
t.Fatalf("Expected %s, got %s", expectedDurationSlice, tc.DurationSlice)
|
||||
}
|
||||
|
||||
urlVal, _ := url.Parse("https://example.com")
|
||||
expectedURLSlice := []*url.URL{urlVal}
|
||||
if !reflect.DeepEqual(tc.URLSlice, expectedURLSlice) {
|
||||
t.Fatalf("Expected %s, got %s", expectedURLSlice, tc.URLSlice)
|
||||
}
|
||||
|
||||
if tc.UnsetString != "" {
|
||||
t.Fatal("Got non-empty string unexpectedly")
|
||||
}
|
||||
|
||||
if tc.UnsetInt64 != 0 {
|
||||
t.Fatal("Got non-zero int unexpectedly")
|
||||
}
|
||||
|
||||
if tc.UnsetDuration != time.Duration(0) {
|
||||
t.Fatal("Got non-zero time.Duration unexpectedly")
|
||||
}
|
||||
|
||||
if tc.UnsetURL != nil {
|
||||
t.Fatal("Got non-zero *url.URL unexpectedly")
|
||||
}
|
||||
|
||||
if len(tc.UnsetSlice) > 0 {
|
||||
t.Fatal("Got not-empty string slice unexpectedly")
|
||||
}
|
||||
|
||||
if tc.InvalidInt64 != 0 {
|
||||
t.Fatal("Got non-zero int unexpectedly")
|
||||
}
|
||||
|
||||
if tc.UnusedField != "" {
|
||||
t.Fatal("Expected empty field")
|
||||
}
|
||||
|
||||
if tc.unexportedField != "" {
|
||||
t.Fatal("Expected empty field")
|
||||
}
|
||||
|
||||
if tc.IgnoredPtr != nil {
|
||||
t.Fatal("Expected nil pointer")
|
||||
}
|
||||
|
||||
if tc.Nested.String != "foo" {
|
||||
t.Fatalf(`Expected "foo", got "%s"`, tc.Nested.String)
|
||||
}
|
||||
|
||||
if tc.NestedPtr.String != "foo" {
|
||||
t.Fatalf(`Expected "foo", got "%s"`, tc.NestedPtr.String)
|
||||
}
|
||||
|
||||
if tc.DefaultInt != 1234 {
|
||||
t.Fatalf("Expected 1234, got %d", tc.DefaultInt)
|
||||
}
|
||||
|
||||
expectedDefaultSlice := []int{1}
|
||||
if !reflect.DeepEqual(tc.DefaultSliceInt, expectedDefaultSlice) {
|
||||
t.Fatalf("Expected %d, got %d", expectedDefaultSlice, tc.DefaultSliceInt)
|
||||
}
|
||||
|
||||
defaultDuration, _ := time.ParseDuration("24h")
|
||||
if tc.DefaultDuration != defaultDuration {
|
||||
t.Fatalf("Expected %d, got %d", defaultDuration, tc.DefaultInt)
|
||||
}
|
||||
|
||||
if tc.DefaultURL.String() != "http://example.com" {
|
||||
t.Fatalf("Expected http://example.com, got %s", tc.DefaultURL.String())
|
||||
}
|
||||
|
||||
if tc.DecoderStruct.String != "foo" {
|
||||
t.Fatalf("Expected foo, got %s", tc.DecoderStruct.String)
|
||||
}
|
||||
|
||||
if tc.DecoderStructPtr.String != "foo" {
|
||||
t.Fatalf("Expected foo, got %s", tc.DecoderStructPtr.String)
|
||||
}
|
||||
|
||||
if tc.DecoderString != "foo" {
|
||||
t.Fatalf("Expected foo, got %s", tc.DecoderString)
|
||||
}
|
||||
|
||||
if tc.UnmarshalerNumber != 07 {
|
||||
t.Fatalf("Expected 07, got %04o", tc.UnmarshalerNumber)
|
||||
}
|
||||
|
||||
os.Setenv("TEST_REQUIRED", "required")
|
||||
var tcr testConfigRequired
|
||||
|
||||
err = Decode(&tcr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if tcr.Required != "required" {
|
||||
t.Fatalf("Expected \"required\", got %s", tcr.Required)
|
||||
}
|
||||
|
||||
_, err = Export(&tcr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var tco testConfigOverride
|
||||
err = Decode(&tco)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if tco.OverrideString != "override_default" {
|
||||
t.Fatalf(`Expected "override_default" but got %s`, tco.OverrideString)
|
||||
}
|
||||
|
||||
os.Setenv("TEST_OVERRIDE_A", "override_a")
|
||||
|
||||
tco = testConfigOverride{}
|
||||
err = Decode(&tco)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if tco.OverrideString != "override_a" {
|
||||
t.Fatalf(`Expected "override_a" but got %s`, tco.OverrideString)
|
||||
}
|
||||
|
||||
os.Setenv("TEST_OVERRIDE_B", "override_b")
|
||||
|
||||
tco = testConfigOverride{}
|
||||
err = Decode(&tco)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if tco.OverrideString != "override_b" {
|
||||
t.Fatalf(`Expected "override_b" but got %s`, tco.OverrideString)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeErrors(t *testing.T) {
|
||||
var b bool
|
||||
err := Decode(&b)
|
||||
if err != ErrInvalidTarget {
|
||||
t.Fatal("Should have gotten an error decoding into a bool")
|
||||
}
|
||||
|
||||
var tc testConfig
|
||||
err = Decode(tc) //nolint:govet
|
||||
if err != ErrInvalidTarget {
|
||||
t.Fatal("Should have gotten an error decoding into a non-pointer")
|
||||
}
|
||||
|
||||
var tcp *testConfig
|
||||
err = Decode(tcp)
|
||||
if err != ErrInvalidTarget {
|
||||
t.Fatal("Should have gotten an error decoding to a nil pointer")
|
||||
}
|
||||
|
||||
var tnt testNoTags
|
||||
err = Decode(&tnt)
|
||||
if err != ErrNoTargetFieldsAreSet {
|
||||
t.Fatal("Should have gotten an error decoding a struct with no tags")
|
||||
}
|
||||
|
||||
var tcni testNoExportedFields
|
||||
err = Decode(&tcni)
|
||||
if err != ErrNoTargetFieldsAreSet {
|
||||
t.Fatal("Should have gotten an error decoding a struct with no unexported fields")
|
||||
}
|
||||
|
||||
var tcr testConfigRequired
|
||||
os.Clearenv()
|
||||
err = Decode(&tcr)
|
||||
if err == nil {
|
||||
t.Fatal("An error was expected but recieved:", err)
|
||||
}
|
||||
|
||||
var tcns testConfigNoSet
|
||||
err = Decode(&tcns)
|
||||
if err != ErrNoTargetFieldsAreSet {
|
||||
t.Fatal("Should have gotten an error decoding when no env variables are set")
|
||||
}
|
||||
|
||||
missing := false
|
||||
FailureFunc = func(err error) {
|
||||
missing = true
|
||||
}
|
||||
MustDecode(&tcr)
|
||||
if !missing {
|
||||
t.Fatal("The FailureFunc should have been called but it was not")
|
||||
}
|
||||
|
||||
var tcrd testConfigRequiredDefault
|
||||
defer func() {
|
||||
_ = recover()
|
||||
}()
|
||||
_ = Decode(&tcrd)
|
||||
t.Fatal("This should not have been reached. A panic should have occured.")
|
||||
}
|
||||
|
||||
func TestOnlyNested(t *testing.T) {
|
||||
os.Setenv("TEST_STRING", "foo")
|
||||
|
||||
// No env vars in the outer level are ok, as long as they're
|
||||
// in the inner struct.
|
||||
var o struct {
|
||||
Inner nested
|
||||
}
|
||||
if err := Decode(&o); err != nil {
|
||||
t.Fatalf("Expected no error, got %s", err)
|
||||
}
|
||||
|
||||
// No env vars in the inner levels are ok, as long as they're
|
||||
// in the outer struct.
|
||||
var o2 struct {
|
||||
Inner noConfig
|
||||
X string `env:"TEST_STRING"`
|
||||
}
|
||||
if err := Decode(&o2); err != nil {
|
||||
t.Fatalf("Expected no error, got %s", err)
|
||||
}
|
||||
|
||||
// No env vars in either outer or inner levels should result
|
||||
// in error
|
||||
var o3 struct {
|
||||
Inner noConfig
|
||||
}
|
||||
if err := Decode(&o3); err != ErrNoTargetFieldsAreSet {
|
||||
t.Fatalf("Expected ErrInvalidTarget, got %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleDecode() {
|
||||
type Example struct {
|
||||
// A string field, without any default
|
||||
String string `env:"EXAMPLE_STRING"`
|
||||
|
||||
// A uint16 field, with a default value of 100
|
||||
Uint16 uint16 `env:"EXAMPLE_UINT16,default=100"`
|
||||
}
|
||||
|
||||
os.Setenv("EXAMPLE_STRING", "an example!")
|
||||
|
||||
var e Example
|
||||
if err := Decode(&e); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// If TEST_STRING is set, e.String will contain its value
|
||||
fmt.Println(e.String)
|
||||
|
||||
// If TEST_UINT16 is set, e.Uint16 will contain its value.
|
||||
// Otherwise, it will contain the default value, 100.
|
||||
fmt.Println(e.Uint16)
|
||||
|
||||
// Output:
|
||||
// an example!
|
||||
// 100
|
||||
}
|
||||
|
||||
//// Export tests
|
||||
|
||||
type testConfigExport struct {
|
||||
String string `env:"TEST_STRING"`
|
||||
Int64 int64 `env:"TEST_INT64"`
|
||||
Uint16 uint16 `env:"TEST_UINT16"`
|
||||
Float64 float64 `env:"TEST_FLOAT64"`
|
||||
Bool bool `env:"TEST_BOOL"`
|
||||
Duration time.Duration `env:"TEST_DURATION"`
|
||||
URL *url.URL `env:"TEST_URL"`
|
||||
|
||||
StringSlice []string `env:"TEST_STRING_SLICE"`
|
||||
|
||||
UnsetString string `env:"TEST_UNSET_STRING"`
|
||||
UnsetInt64 int64 `env:"TEST_UNSET_INT64"`
|
||||
UnsetDuration time.Duration `env:"TEST_UNSET_DURATION"`
|
||||
UnsetURL *url.URL `env:"TEST_UNSET_URL"`
|
||||
|
||||
UnusedField string
|
||||
unexportedField string //nolint:structcheck,unused
|
||||
|
||||
IgnoredPtr *bool `env:"TEST_IGNORED_POINTER"`
|
||||
|
||||
Nested nestedConfigExport
|
||||
NestedPtr *nestedConfigExportPointer
|
||||
NestedPtrUnset *nestedConfigExportPointer
|
||||
|
||||
NestedTwice nestedTwiceConfig
|
||||
|
||||
NoConfig noConfig
|
||||
NoConfigPtr *noConfig
|
||||
NoConfigPtrSet *noConfig
|
||||
|
||||
RequiredInt int `env:"TEST_REQUIRED_INT,required"`
|
||||
|
||||
DefaultBool bool `env:"TEST_DEFAULT_BOOL,default=true"`
|
||||
DefaultInt int `env:"TEST_DEFAULT_INT,default=1234"`
|
||||
DefaultDuration time.Duration `env:"TEST_DEFAULT_DURATION,default=24h"`
|
||||
DefaultURL *url.URL `env:"TEST_DEFAULT_URL,default=http://example.com"`
|
||||
DefaultIntSet int `env:"TEST_DEFAULT_INT_SET,default=99"`
|
||||
DefaultIntSlice []int `env:"TEST_DEFAULT_INT_SLICE,default=99"`
|
||||
}
|
||||
|
||||
type nestedConfigExport struct {
|
||||
String string `env:"TEST_NESTED_STRING"`
|
||||
}
|
||||
|
||||
type nestedConfigExportPointer struct {
|
||||
String string `env:"TEST_NESTED_STRING_POINTER"`
|
||||
}
|
||||
|
||||
type noConfig struct {
|
||||
Int int
|
||||
}
|
||||
|
||||
type nestedTwiceConfig struct {
|
||||
Nested nestedConfigInner
|
||||
}
|
||||
|
||||
type nestedConfigInner struct {
|
||||
String string `env:"TEST_NESTED_TWICE_STRING"`
|
||||
}
|
||||
|
||||
type testConfigStrict struct {
|
||||
InvalidInt64Strict int64 `env:"TEST_INVALID_INT64,strict,default=1"`
|
||||
InvalidInt64Implicit int64 `env:"TEST_INVALID_INT64_IMPLICIT,default=1"`
|
||||
|
||||
Nested struct {
|
||||
InvalidInt64Strict int64 `env:"TEST_INVALID_INT64_NESTED,strict,required"`
|
||||
InvalidInt64Implicit int64 `env:"TEST_INVALID_INT64_NESTED_IMPLICIT,required"`
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidStrict(t *testing.T) {
|
||||
cases := []struct {
|
||||
decoder func(interface{}) error
|
||||
rootValue string
|
||||
nestedValue string
|
||||
rootValueImplicit string
|
||||
nestedValueImplicit string
|
||||
pass bool
|
||||
}{
|
||||
{Decode, "1", "1", "1", "1", true},
|
||||
{Decode, "1", "1", "1", "asdf", true},
|
||||
{Decode, "1", "1", "asdf", "1", true},
|
||||
{Decode, "1", "1", "asdf", "asdf", true},
|
||||
{Decode, "1", "asdf", "1", "1", false},
|
||||
{Decode, "asdf", "1", "1", "1", false},
|
||||
{Decode, "asdf", "asdf", "1", "1", false},
|
||||
{StrictDecode, "1", "1", "1", "1", true},
|
||||
{StrictDecode, "asdf", "1", "1", "1", false},
|
||||
{StrictDecode, "1", "asdf", "1", "1", false},
|
||||
{StrictDecode, "1", "1", "asdf", "1", false},
|
||||
{StrictDecode, "1", "1", "1", "asdf", false},
|
||||
{StrictDecode, "asdf", "asdf", "1", "1", false},
|
||||
{StrictDecode, "1", "asdf", "asdf", "1", false},
|
||||
{StrictDecode, "1", "1", "asdf", "asdf", false},
|
||||
{StrictDecode, "1", "asdf", "asdf", "asdf", false},
|
||||
{StrictDecode, "asdf", "asdf", "asdf", "asdf", false},
|
||||
}
|
||||
|
||||
for _, test := range cases {
|
||||
os.Setenv("TEST_INVALID_INT64", test.rootValue)
|
||||
os.Setenv("TEST_INVALID_INT64_NESTED", test.nestedValue)
|
||||
os.Setenv("TEST_INVALID_INT64_IMPLICIT", test.rootValueImplicit)
|
||||
os.Setenv("TEST_INVALID_INT64_NESTED_IMPLICIT", test.nestedValueImplicit)
|
||||
|
||||
var tc testConfigStrict
|
||||
if err := test.decoder(&tc); test.pass != (err == nil) {
|
||||
t.Fatalf("Have err=%s wanted pass=%v", err, test.pass)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExport(t *testing.T) {
|
||||
testFloat64 := fmt.Sprintf("%.48f", math.Pi)
|
||||
testFloat64Output := strconv.FormatFloat(math.Pi, 'f', -1, 64)
|
||||
testInt64 := fmt.Sprintf("%d", -(1 << 50))
|
||||
|
||||
os.Setenv("TEST_STRING", "foo")
|
||||
os.Setenv("TEST_INT64", testInt64)
|
||||
os.Setenv("TEST_UINT16", "60000")
|
||||
os.Setenv("TEST_FLOAT64", testFloat64)
|
||||
os.Setenv("TEST_BOOL", "true")
|
||||
os.Setenv("TEST_DURATION", "10m")
|
||||
os.Setenv("TEST_URL", "https://example.com")
|
||||
os.Setenv("TEST_STRING_SLICE", "foo,bar")
|
||||
os.Setenv("TEST_NESTED_STRING", "nest_foo")
|
||||
os.Setenv("TEST_NESTED_STRING_POINTER", "nest_foo_ptr")
|
||||
os.Setenv("TEST_NESTED_TWICE_STRING", "nest_twice_foo")
|
||||
os.Setenv("TEST_REQUIRED_INT", "101")
|
||||
os.Setenv("TEST_DEFAULT_INT_SET", "102")
|
||||
os.Setenv("TEST_DEFAULT_INT_SLICE", "1,2,3")
|
||||
|
||||
var tc testConfigExport
|
||||
tc.NestedPtr = &nestedConfigExportPointer{}
|
||||
tc.NoConfigPtrSet = &noConfig{}
|
||||
|
||||
if err := Decode(&tc); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rc, err := Export(&tc)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected := []*ConfigInfo{
|
||||
&ConfigInfo{
|
||||
Field: "String",
|
||||
EnvVar: "TEST_STRING",
|
||||
Value: "foo",
|
||||
UsesEnv: true,
|
||||
},
|
||||
&ConfigInfo{
|
||||
Field: "Int64",
|
||||
EnvVar: "TEST_INT64",
|
||||
Value: testInt64,
|
||||
UsesEnv: true,
|
||||
},
|
||||
&ConfigInfo{
|
||||
Field: "Uint16",
|
||||
EnvVar: "TEST_UINT16",
|
||||
Value: "60000",
|
||||
UsesEnv: true,
|
||||
},
|
||||
&ConfigInfo{
|
||||
Field: "Float64",
|
||||
EnvVar: "TEST_FLOAT64",
|
||||
Value: testFloat64Output,
|
||||
UsesEnv: true,
|
||||
},
|
||||
&ConfigInfo{
|
||||
Field: "Bool",
|
||||
EnvVar: "TEST_BOOL",
|
||||
Value: "true",
|
||||
UsesEnv: true,
|
||||
},
|
||||
&ConfigInfo{
|
||||
Field: "Duration",
|
||||
EnvVar: "TEST_DURATION",
|
||||
Value: "10m0s",
|
||||
UsesEnv: true,
|
||||
},
|
||||
&ConfigInfo{
|
||||
Field: "URL",
|
||||
EnvVar: "TEST_URL",
|
||||
Value: "https://example.com",
|
||||
UsesEnv: true,
|
||||
},
|
||||
&ConfigInfo{
|
||||
Field: "StringSlice",
|
||||
EnvVar: "TEST_STRING_SLICE",
|
||||
Value: "[foo bar]",
|
||||
UsesEnv: true,
|
||||
},
|
||||
|
||||
&ConfigInfo{
|
||||
Field: "UnsetString",
|
||||
EnvVar: "TEST_UNSET_STRING",
|
||||
Value: "",
|
||||
},
|
||||
&ConfigInfo{
|
||||
Field: "UnsetInt64",
|
||||
EnvVar: "TEST_UNSET_INT64",
|
||||
Value: "0",
|
||||
},
|
||||
&ConfigInfo{
|
||||
Field: "UnsetDuration",
|
||||
EnvVar: "TEST_UNSET_DURATION",
|
||||
Value: "0s",
|
||||
},
|
||||
&ConfigInfo{
|
||||
Field: "UnsetURL",
|
||||
EnvVar: "TEST_UNSET_URL",
|
||||
Value: "",
|
||||
},
|
||||
|
||||
&ConfigInfo{
|
||||
Field: "IgnoredPtr",
|
||||
EnvVar: "TEST_IGNORED_POINTER",
|
||||
Value: "",
|
||||
},
|
||||
|
||||
&ConfigInfo{
|
||||
Field: "Nested.String",
|
||||
EnvVar: "TEST_NESTED_STRING",
|
||||
Value: "nest_foo",
|
||||
UsesEnv: true,
|
||||
},
|
||||
&ConfigInfo{
|
||||
Field: "NestedPtr.String",
|
||||
EnvVar: "TEST_NESTED_STRING_POINTER",
|
||||
Value: "nest_foo_ptr",
|
||||
UsesEnv: true,
|
||||
},
|
||||
|
||||
&ConfigInfo{
|
||||
Field: "NestedTwice.Nested.String",
|
||||
EnvVar: "TEST_NESTED_TWICE_STRING",
|
||||
Value: "nest_twice_foo",
|
||||
UsesEnv: true,
|
||||
},
|
||||
|
||||
&ConfigInfo{
|
||||
Field: "RequiredInt",
|
||||
EnvVar: "TEST_REQUIRED_INT",
|
||||
Value: "101",
|
||||
UsesEnv: true,
|
||||
Required: true,
|
||||
},
|
||||
|
||||
&ConfigInfo{
|
||||
Field: "DefaultBool",
|
||||
EnvVar: "TEST_DEFAULT_BOOL",
|
||||
Value: "true",
|
||||
DefaultValue: "true",
|
||||
HasDefault: true,
|
||||
},
|
||||
&ConfigInfo{
|
||||
Field: "DefaultInt",
|
||||
EnvVar: "TEST_DEFAULT_INT",
|
||||
Value: "1234",
|
||||
DefaultValue: "1234",
|
||||
HasDefault: true,
|
||||
},
|
||||
&ConfigInfo{
|
||||
Field: "DefaultDuration",
|
||||
EnvVar: "TEST_DEFAULT_DURATION",
|
||||
Value: "24h0m0s",
|
||||
DefaultValue: "24h",
|
||||
HasDefault: true,
|
||||
},
|
||||
&ConfigInfo{
|
||||
Field: "DefaultURL",
|
||||
EnvVar: "TEST_DEFAULT_URL",
|
||||
Value: "http://example.com",
|
||||
DefaultValue: "http://example.com",
|
||||
HasDefault: true,
|
||||
},
|
||||
&ConfigInfo{
|
||||
Field: "DefaultIntSet",
|
||||
EnvVar: "TEST_DEFAULT_INT_SET",
|
||||
Value: "102",
|
||||
DefaultValue: "99",
|
||||
HasDefault: true,
|
||||
UsesEnv: true,
|
||||
},
|
||||
&ConfigInfo{
|
||||
Field: "DefaultIntSlice",
|
||||
EnvVar: "TEST_DEFAULT_INT_SLICE",
|
||||
Value: "[1 2 3]",
|
||||
DefaultValue: "99",
|
||||
HasDefault: true,
|
||||
UsesEnv: true,
|
||||
},
|
||||
}
|
||||
|
||||
sort.Sort(ConfigInfoSlice(expected))
|
||||
|
||||
if len(rc) != len(expected) {
|
||||
t.Fatalf("Have %d results, expected %d", len(rc), len(expected))
|
||||
}
|
||||
|
||||
for n, v := range rc {
|
||||
ci := expected[n]
|
||||
if *ci != *v {
|
||||
t.Fatalf("have %+v, expected %+v", v, ci)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
gofig "github.com/gookit/config/v2"
|
||||
gooyaml "github.com/gookit/config/v2/yaml"
|
||||
"github.com/opencloud-eu/opencloud/pkg/config/defaults"
|
||||
)
|
||||
|
||||
var (
|
||||
// decoderConfigTagName sets the tag name to be used from the config structs
|
||||
// currently we only support "yaml" because we only support config loading
|
||||
// from yaml files and the yaml parser has no simple way to set a custom tag name to use
|
||||
decoderConfigTagName = "yaml"
|
||||
)
|
||||
|
||||
// BindSourcesToStructs assigns any config value from a config file / env variable to struct `dst`.
|
||||
func BindSourcesToStructs(service string, dst interface{}) error {
|
||||
fileSystem := os.DirFS("/")
|
||||
filePath := strings.TrimLeft(path.Join(defaults.BaseConfigPath(), service+".yaml"), "/")
|
||||
return bindSourcesToStructs(fileSystem, filePath, service, dst)
|
||||
}
|
||||
|
||||
func bindSourcesToStructs(fileSystem fs.FS, filePath, service string, dst interface{}) error {
|
||||
cnf := gofig.NewWithOptions(service)
|
||||
cnf.WithOptions(func(options *gofig.Options) {
|
||||
options.ParseEnv = true
|
||||
options.DecoderConfig.TagName = decoderConfigTagName
|
||||
})
|
||||
cnf.AddDriver(gooyaml.Driver)
|
||||
|
||||
yamlContent, err := fs.ReadFile(fileSystem, filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
_ = cnf.LoadSources("yaml", yamlContent)
|
||||
|
||||
err = cnf.BindStruct("", &dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LocalEndpoint returns the local endpoint for a given protocol and address.
|
||||
// Use it when configuring the reva runtime to get a service endpoint in the same
|
||||
// runtime, e.g. a gateway talking to an authregistry service.
|
||||
func LocalEndpoint(protocol, addr string) string {
|
||||
localEndpoint := addr
|
||||
switch protocol {
|
||||
case "tcp":
|
||||
parts := strings.SplitN(addr, ":", 2)
|
||||
if len(parts) == 2 {
|
||||
localEndpoint = "dns:127.0.0.1:" + parts[1]
|
||||
}
|
||||
case "unix":
|
||||
localEndpoint = "unix:" + addr
|
||||
}
|
||||
return localEndpoint
|
||||
}
|
||||
@@ -0,0 +1,188 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"gotest.tools/v3/assert"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
)
|
||||
|
||||
type TestConfig struct {
|
||||
A string `yaml:"a"`
|
||||
B string `yaml:"b"`
|
||||
C string `yaml:"c"`
|
||||
}
|
||||
|
||||
func TestBindSourcesToStructs(t *testing.T) {
|
||||
// setup test env
|
||||
yaml := `
|
||||
a: "${FOO_VAR|no-foo}"
|
||||
b: "${BAR_VAR|no-bar}"
|
||||
c: "${CODE_VAR|code}"
|
||||
`
|
||||
filePath := "etc/ocis/foo.yaml"
|
||||
fs := fstest.MapFS{
|
||||
filePath: {Data: []byte(yaml)},
|
||||
}
|
||||
// perform test
|
||||
c := TestConfig{}
|
||||
err := bindSourcesToStructs(fs, filePath, "foo", &c)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
assert.Equal(t, c.A, "no-foo")
|
||||
assert.Equal(t, c.B, "no-bar")
|
||||
assert.Equal(t, c.C, "code")
|
||||
}
|
||||
|
||||
func TestBindSourcesToStructs_UnknownFile(t *testing.T) {
|
||||
// setup test env
|
||||
filePath := "etc/ocis/foo.yaml"
|
||||
fs := fstest.MapFS{}
|
||||
// perform test
|
||||
c := TestConfig{}
|
||||
err := bindSourcesToStructs(fs, filePath, "foo", &c)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
assert.Equal(t, c.A, "")
|
||||
assert.Equal(t, c.B, "")
|
||||
assert.Equal(t, c.C, "")
|
||||
}
|
||||
|
||||
func TestBindSourcesToStructs_NoEnvVar(t *testing.T) {
|
||||
// setup test env
|
||||
yaml := `
|
||||
token_manager:
|
||||
jwt_secret: f%LovwC6xnKkHhc.!.Lp4ZYpQDIO7=d@
|
||||
machine_auth_api_key: jG&%ZCmCSYqT#Yi$9y28o5u84ZMo2UBf
|
||||
system_user_api_key: wqxH7FZHv5gifuLIzxqdyaZOCo2s^yl1
|
||||
transfer_secret: $1^2xspR1WHussV16knaJ$x@X*XLPL%y
|
||||
system_user_id: 4d0bf32c-83ee-4703-bd43-5e0d6b78215b
|
||||
admin_user_id: e2fca2b3-992b-47d5-8ecd-3312418ed3d7
|
||||
graph:
|
||||
application:
|
||||
id: 4fdff90c-d13c-47ab-8227-bbd3e6dbee3c
|
||||
events:
|
||||
tls_insecure: true
|
||||
spaces:
|
||||
insecure: true
|
||||
identity:
|
||||
ldap:
|
||||
bind_password: $ZZ8fSJR&YA02jBBPx6IRCzW0kVZ#cBO
|
||||
service_account:
|
||||
service_account_id: c05389b2-d94c-4d01-a9b5-a2f97952cc14
|
||||
service_account_secret: GW5.x1vDM&+NPRi++eV@.P7Tms4vj!=s
|
||||
idp:
|
||||
ldap:
|
||||
bind_password: kWJGC6WRY1wQ+e8Bmt--=-3r6gp0CNVS
|
||||
idm:
|
||||
service_user_passwords:
|
||||
admin_password: admin
|
||||
idm_password: $ZZ8fSJR&YA02jBBPx6IRCzW0kVZ#cBO
|
||||
reva_password: c68JL=V$c@0GHs!%eSb8r&Ps3rgzKnXJ
|
||||
idp_password: kWJGC6WRY1wQ+e8Bmt--=-3r6gp0CNVS
|
||||
proxy:
|
||||
oidc:
|
||||
insecure: true
|
||||
insecure_backends: true
|
||||
service_account:
|
||||
service_account_id: c05389b2-d94c-4d01-a9b5-a2f97952cc14
|
||||
service_account_secret: GW5.x1vDM&+NPRi++eV@.P7Tms4vj!=s
|
||||
frontend:
|
||||
app_handler:
|
||||
insecure: true
|
||||
archiver:
|
||||
insecure: true
|
||||
service_account:
|
||||
service_account_id: c05389b2-d94c-4d01-a9b5-a2f97952cc14
|
||||
service_account_secret: GW5.x1vDM&+NPRi++eV@.P7Tms4vj!=s
|
||||
auth_basic:
|
||||
auth_providers:
|
||||
ldap:
|
||||
bind_password: c68JL=V$c@0GHs!%eSb8r&Ps3rgzKnXJ
|
||||
auth_bearer:
|
||||
auth_providers:
|
||||
oidc:
|
||||
insecure: true
|
||||
users:
|
||||
drivers:
|
||||
ldap:
|
||||
bind_password: c68JL=V$c@0GHs!%eSb8r&Ps3rgzKnXJ
|
||||
groups:
|
||||
drivers:
|
||||
ldap:
|
||||
bind_password: c68JL=V$c@0GHs!%eSb8r&Ps3rgzKnXJ
|
||||
ocdav:
|
||||
insecure: true
|
||||
ocm:
|
||||
service_account:
|
||||
service_account_id: c05389b2-d94c-4d01-a9b5-a2f97952cc14
|
||||
service_account_secret: GW5.x1vDM&+NPRi++eV@.P7Tms4vj!=s
|
||||
thumbnails:
|
||||
thumbnail:
|
||||
transfer_secret: 0N05@YXB.h3e@lsVfksL4YxwQC9aE5A.
|
||||
webdav_allow_insecure: true
|
||||
cs3_allow_insecure: true
|
||||
search:
|
||||
events:
|
||||
tls_insecure: true
|
||||
service_account:
|
||||
service_account_id: c05389b2-d94c-4d01-a9b5-a2f97952cc14
|
||||
service_account_secret: GW5.x1vDM&+NPRi++eV@.P7Tms4vj!=s
|
||||
audit:
|
||||
events:
|
||||
tls_insecure: true
|
||||
settings:
|
||||
service_account_ids:
|
||||
- c05389b2-d94c-4d01-a9b5-a2f97952cc14
|
||||
sharing:
|
||||
events:
|
||||
tls_insecure: true
|
||||
storage_users:
|
||||
events:
|
||||
tls_insecure: true
|
||||
mount_id: 64fdfb03-22ff-4788-be4d-d7731a475683
|
||||
service_account:
|
||||
service_account_id: c05389b2-d94c-4d01-a9b5-a2f97952cc14
|
||||
service_account_secret: GW5.x1vDM&+NPRi++eV@.P7Tms4vj!=s
|
||||
notifications:
|
||||
notifications:
|
||||
events:
|
||||
tls_insecure: true
|
||||
service_account:
|
||||
service_account_id: c05389b2-d94c-4d01-a9b5-a2f97952cc14
|
||||
service_account_secret: GW5.x1vDM&+NPRi++eV@.P7Tms4vj!=s
|
||||
nats:
|
||||
nats:
|
||||
tls_skip_verify_client_cert: true
|
||||
gateway:
|
||||
storage_registry:
|
||||
storage_users_mount_id: 64fdfb03-22ff-4788-be4d-d7731a475683
|
||||
userlog:
|
||||
service_account:
|
||||
service_account_id: c05389b2-d94c-4d01-a9b5-a2f97952cc14
|
||||
service_account_secret: GW5.x1vDM&+NPRi++eV@.P7Tms4vj!=s
|
||||
auth_service:
|
||||
service_account:
|
||||
service_account_id: c05389b2-d94c-4d01-a9b5-a2f97952cc14
|
||||
service_account_secret: GW5.x1vDM&+NPRi++eV@.P7Tms4vj!=s
|
||||
clientlog:
|
||||
service_account:
|
||||
service_account_id: c05389b2-d94c-4d01-a9b5-a2f97952cc14
|
||||
service_account_secret: GW5.x1vDM&+NPRi++eV@.P7Tms4vj!=s
|
||||
`
|
||||
filePath := "etc/ocis/foo.yaml"
|
||||
fs := fstest.MapFS{
|
||||
filePath: {Data: []byte(yaml)},
|
||||
}
|
||||
// perform test
|
||||
c := Config{}
|
||||
err := bindSourcesToStructs(fs, filePath, "foo", &c)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
assert.Equal(t, c.Graph.Identity.LDAP.BindPassword, "$ZZ8fSJR&YA02jBBPx6IRCzW0kVZ#cBO")
|
||||
}
|
||||
@@ -0,0 +1,140 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/opencloud-eu/opencloud/pkg/config"
|
||||
"github.com/opencloud-eu/opencloud/pkg/config/envdecode"
|
||||
"github.com/opencloud-eu/opencloud/pkg/shared"
|
||||
"github.com/opencloud-eu/opencloud/pkg/structs"
|
||||
)
|
||||
|
||||
// ParseConfig loads the ocis configuration and
|
||||
// copies applicable parts into the commons part, from
|
||||
// where the services can copy it into their own config
|
||||
func ParseConfig(cfg *config.Config, skipValidate bool) error {
|
||||
err := config.BindSourcesToStructs("ocis", cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
EnsureDefaults(cfg)
|
||||
|
||||
// load all env variables relevant to the config in the current context.
|
||||
if err := envdecode.Decode(cfg); err != nil {
|
||||
// no environment variable set for this config is an expected "error"
|
||||
if !errors.Is(err, envdecode.ErrNoTargetFieldsAreSet) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
EnsureCommons(cfg)
|
||||
|
||||
if skipValidate {
|
||||
return nil
|
||||
}
|
||||
|
||||
return Validate(cfg)
|
||||
}
|
||||
|
||||
// EnsureDefaults ensures that all pointers in the
|
||||
// oCIS config (not the services configs) are initialized
|
||||
func EnsureDefaults(cfg *config.Config) {
|
||||
if cfg.Tracing == nil {
|
||||
cfg.Tracing = &shared.Tracing{}
|
||||
}
|
||||
if cfg.Log == nil {
|
||||
cfg.Log = &shared.Log{}
|
||||
}
|
||||
if cfg.TokenManager == nil {
|
||||
cfg.TokenManager = &shared.TokenManager{}
|
||||
}
|
||||
if cfg.Cache == nil {
|
||||
cfg.Cache = &shared.Cache{}
|
||||
}
|
||||
if cfg.GRPCClientTLS == nil {
|
||||
cfg.GRPCClientTLS = &shared.GRPCClientTLS{}
|
||||
}
|
||||
if cfg.GRPCServiceTLS == nil {
|
||||
cfg.GRPCServiceTLS = &shared.GRPCServiceTLS{}
|
||||
}
|
||||
if cfg.Reva == nil {
|
||||
cfg.Reva = &shared.Reva{}
|
||||
}
|
||||
}
|
||||
|
||||
// EnsureCommons copies applicable parts of the oCIS config into the commons part
|
||||
func EnsureCommons(cfg *config.Config) {
|
||||
// ensure the commons part is initialized
|
||||
if cfg.Commons == nil {
|
||||
cfg.Commons = &shared.Commons{}
|
||||
}
|
||||
|
||||
cfg.Commons.Log = structs.CopyOrZeroValue(cfg.Log)
|
||||
cfg.Commons.Tracing = structs.CopyOrZeroValue(cfg.Tracing)
|
||||
cfg.Commons.Cache = structs.CopyOrZeroValue(cfg.Cache)
|
||||
|
||||
if cfg.GRPCClientTLS != nil {
|
||||
cfg.Commons.GRPCClientTLS = cfg.GRPCClientTLS
|
||||
}
|
||||
|
||||
if cfg.GRPCServiceTLS != nil {
|
||||
cfg.Commons.GRPCServiceTLS = cfg.GRPCServiceTLS
|
||||
}
|
||||
|
||||
cfg.Commons.HTTPServiceTLS = cfg.HTTPServiceTLS
|
||||
|
||||
cfg.Commons.TokenManager = structs.CopyOrZeroValue(cfg.TokenManager)
|
||||
|
||||
// copy machine auth api key to the commons part if set
|
||||
if cfg.MachineAuthAPIKey != "" {
|
||||
cfg.Commons.MachineAuthAPIKey = cfg.MachineAuthAPIKey
|
||||
}
|
||||
|
||||
if cfg.SystemUserAPIKey != "" {
|
||||
cfg.Commons.SystemUserAPIKey = cfg.SystemUserAPIKey
|
||||
}
|
||||
|
||||
// copy transfer secret to the commons part if set
|
||||
if cfg.TransferSecret != "" {
|
||||
cfg.Commons.TransferSecret = cfg.TransferSecret
|
||||
}
|
||||
|
||||
// copy metadata user id to the commons part if set
|
||||
if cfg.SystemUserID != "" {
|
||||
cfg.Commons.SystemUserID = cfg.SystemUserID
|
||||
}
|
||||
|
||||
// copy admin user id to the commons part if set
|
||||
if cfg.AdminUserID != "" {
|
||||
cfg.Commons.AdminUserID = cfg.AdminUserID
|
||||
}
|
||||
|
||||
if cfg.OcisURL != "" {
|
||||
cfg.Commons.OcisURL = cfg.OcisURL
|
||||
}
|
||||
|
||||
cfg.Commons.Reva = structs.CopyOrZeroValue(cfg.Reva)
|
||||
}
|
||||
|
||||
// Validate checks that all required configs are set. If a required config value
|
||||
// is missing an error will be returned.
|
||||
func Validate(cfg *config.Config) error {
|
||||
if cfg.TokenManager.JWTSecret == "" {
|
||||
return shared.MissingJWTTokenError("ocis")
|
||||
}
|
||||
|
||||
if cfg.TransferSecret == "" {
|
||||
return shared.MissingRevaTransferSecretError("ocis")
|
||||
}
|
||||
|
||||
if cfg.MachineAuthAPIKey == "" {
|
||||
return shared.MissingMachineAuthApiKeyError("ocis")
|
||||
}
|
||||
|
||||
if cfg.SystemUserID == "" {
|
||||
return shared.MissingSystemUserID("ocis")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
package conversions
|
||||
|
||||
// ToPointer converts a value to a pointer
|
||||
func ToPointer[T any](val T) *T {
|
||||
return &val
|
||||
}
|
||||
|
||||
// ToValue converts a pointer to a value
|
||||
func ToValue[T any](ptr *T) T {
|
||||
if ptr == nil {
|
||||
var t T
|
||||
return t
|
||||
}
|
||||
|
||||
return *ptr
|
||||
}
|
||||
|
||||
// ToPointerSlice converts a slice of values to a slice of pointers
|
||||
func ToPointerSlice[E any](s []E) []*E {
|
||||
rs := make([]*E, len(s))
|
||||
|
||||
for i, v := range s {
|
||||
rs[i] = ToPointer(v)
|
||||
}
|
||||
|
||||
return rs
|
||||
}
|
||||
|
||||
// ToValueSlice converts a slice of pointers to a slice of values
|
||||
func ToValueSlice[E any](s []*E) []E {
|
||||
rs := make([]E, len(s))
|
||||
|
||||
for i, v := range s {
|
||||
rs[i] = ToValue(v)
|
||||
}
|
||||
|
||||
return rs
|
||||
}
|
||||
@@ -0,0 +1,113 @@
|
||||
package conversions_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
libregraph "github.com/owncloud/libre-graph-api-go"
|
||||
|
||||
"github.com/opencloud-eu/opencloud/pkg/conversions"
|
||||
)
|
||||
|
||||
func checkIdentical[T any](t *testing.T, p T, want string) {
|
||||
t.Helper()
|
||||
got := fmt.Sprintf("%T", p)
|
||||
if got != want {
|
||||
t.Errorf("want:%q got:%q", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToPointer2(t *testing.T) {
|
||||
checkIdentical(t, conversions.ToPointer("a"), "*string")
|
||||
checkIdentical(t, conversions.ToPointer(1), "*int")
|
||||
checkIdentical(t, conversions.ToPointer(-1), "*int")
|
||||
checkIdentical(t, conversions.ToPointer(float64(1)), "*float64")
|
||||
checkIdentical(t, conversions.ToPointer(float64(-1)), "*float64")
|
||||
checkIdentical(t, conversions.ToPointer(libregraph.UnifiedRoleDefinition{}), "*libregraph.UnifiedRoleDefinition")
|
||||
|
||||
checkIdentical(t, conversions.ToPointer([]string{"a"}), "*[]string")
|
||||
checkIdentical(t, conversions.ToPointer([]int{1}), "*[]int")
|
||||
checkIdentical(t, conversions.ToPointer([]float64{1}), "*[]float64")
|
||||
checkIdentical(t, conversions.ToPointer([]libregraph.UnifiedRoleDefinition{{}}), "*[]libregraph.UnifiedRoleDefinition")
|
||||
|
||||
checkIdentical(t, conversions.ToPointer(conversions.ToPointer("a")), "**string")
|
||||
checkIdentical(t, conversions.ToPointer(conversions.ToPointer(1)), "**int")
|
||||
checkIdentical(t, conversions.ToPointer(conversions.ToPointer(-1)), "**int")
|
||||
checkIdentical(t, conversions.ToPointer(conversions.ToPointer(float64(1))), "**float64")
|
||||
checkIdentical(t, conversions.ToPointer(conversions.ToPointer(float64(-1))), "**float64")
|
||||
checkIdentical(t, conversions.ToPointer(conversions.ToPointer(libregraph.UnifiedRoleDefinition{})), "**libregraph.UnifiedRoleDefinition")
|
||||
|
||||
checkIdentical(t, conversions.ToPointer(conversions.ToPointer([]string{"a"})), "**[]string")
|
||||
checkIdentical(t, conversions.ToPointer(conversions.ToPointer([]int{1})), "**[]int")
|
||||
checkIdentical(t, conversions.ToPointer(conversions.ToPointer([]float64{1})), "**[]float64")
|
||||
checkIdentical(t, conversions.ToPointer(conversions.ToPointer([]libregraph.UnifiedRoleDefinition{{}})), "**[]libregraph.UnifiedRoleDefinition")
|
||||
}
|
||||
|
||||
func TestToValue(t *testing.T) {
|
||||
checkIdentical(t, conversions.ToValue((*int)(nil)), "int")
|
||||
checkIdentical(t, conversions.ToValue((*string)(nil)), "string")
|
||||
checkIdentical(t, conversions.ToValue((*float64)(nil)), "float64")
|
||||
checkIdentical(t, conversions.ToValue((*libregraph.UnifiedRoleDefinition)(nil)), "libregraph.UnifiedRoleDefinition")
|
||||
checkIdentical(t, conversions.ToValue((*[]string)(nil)), "[]string")
|
||||
checkIdentical(t, conversions.ToValue((*[]libregraph.UnifiedRoleDefinition)(nil)), "[]libregraph.UnifiedRoleDefinition")
|
||||
|
||||
checkIdentical(t, conversions.ToValue(conversions.ToPointer("a")), "string")
|
||||
checkIdentical(t, conversions.ToValue(conversions.ToPointer(1)), "int")
|
||||
checkIdentical(t, conversions.ToValue(conversions.ToPointer(-1)), "int")
|
||||
checkIdentical(t, conversions.ToValue(conversions.ToPointer(float64(1))), "float64")
|
||||
checkIdentical(t, conversions.ToValue(conversions.ToPointer(float64(-1))), "float64")
|
||||
checkIdentical(t, conversions.ToValue(conversions.ToPointer(libregraph.UnifiedRoleDefinition{})), "libregraph.UnifiedRoleDefinition")
|
||||
|
||||
checkIdentical(t, conversions.ToValue(conversions.ToPointer([]string{"a"})), "[]string")
|
||||
checkIdentical(t, conversions.ToValue(conversions.ToPointer([]int{1})), "[]int")
|
||||
checkIdentical(t, conversions.ToValue(conversions.ToPointer([]float64{1})), "[]float64")
|
||||
checkIdentical(t, conversions.ToValue(conversions.ToPointer([]libregraph.UnifiedRoleDefinition{{}})), "[]libregraph.UnifiedRoleDefinition")
|
||||
|
||||
checkIdentical(t, conversions.ToValue(conversions.ToPointer(conversions.ToPointer("a"))), "*string")
|
||||
checkIdentical(t, conversions.ToValue(conversions.ToPointer(conversions.ToPointer(1))), "*int")
|
||||
checkIdentical(t, conversions.ToValue(conversions.ToPointer(conversions.ToPointer(-1))), "*int")
|
||||
checkIdentical(t, conversions.ToValue(conversions.ToPointer(conversions.ToPointer(float64(1)))), "*float64")
|
||||
checkIdentical(t, conversions.ToValue(conversions.ToPointer(conversions.ToPointer(float64(-1)))), "*float64")
|
||||
checkIdentical(t, conversions.ToValue(conversions.ToPointer(conversions.ToPointer(libregraph.UnifiedRoleDefinition{}))), "*libregraph.UnifiedRoleDefinition")
|
||||
|
||||
checkIdentical(t, conversions.ToValue(conversions.ToPointer(conversions.ToPointer([]string{"a"}))), "*[]string")
|
||||
checkIdentical(t, conversions.ToValue(conversions.ToPointer(conversions.ToPointer([]int{1}))), "*[]int")
|
||||
checkIdentical(t, conversions.ToValue(conversions.ToPointer(conversions.ToPointer([]float64{1}))), "*[]float64")
|
||||
checkIdentical(t, conversions.ToValue(conversions.ToPointer(conversions.ToPointer([]libregraph.UnifiedRoleDefinition{{}}))), "*[]libregraph.UnifiedRoleDefinition")
|
||||
}
|
||||
|
||||
func TestToPointerSlice(t *testing.T) {
|
||||
checkIdentical(t, conversions.ToPointerSlice([]string{"a"}), "[]*string")
|
||||
checkIdentical(t, conversions.ToPointerSlice([]int{1}), "[]*int")
|
||||
checkIdentical(t, conversions.ToPointerSlice([]libregraph.UnifiedRoleDefinition{{}}), "[]*libregraph.UnifiedRoleDefinition")
|
||||
|
||||
checkIdentical(t, conversions.ToPointerSlice(([]string)(nil)), "[]*string")
|
||||
checkIdentical(t, conversions.ToPointerSlice(([]int)(nil)), "[]*int")
|
||||
checkIdentical(t, conversions.ToPointerSlice(([]libregraph.UnifiedRoleDefinition)(nil)), "[]*libregraph.UnifiedRoleDefinition")
|
||||
|
||||
checkIdentical(t, conversions.ToPointerSlice([]*string{conversions.ToPointer("a")}), "[]**string")
|
||||
checkIdentical(t, conversions.ToPointerSlice([]*int{conversions.ToPointer(1)}), "[]**int")
|
||||
checkIdentical(t, conversions.ToPointerSlice(([]*libregraph.UnifiedRoleDefinition)(nil)), "[]**libregraph.UnifiedRoleDefinition")
|
||||
|
||||
checkIdentical(t, conversions.ToPointerSlice(([]*string)(nil)), "[]**string")
|
||||
checkIdentical(t, conversions.ToPointerSlice(([]*int)(nil)), "[]**int")
|
||||
checkIdentical(t, conversions.ToPointerSlice(([]*libregraph.UnifiedRoleDefinition)(nil)), "[]**libregraph.UnifiedRoleDefinition")
|
||||
}
|
||||
|
||||
func TestToValueSlice(t *testing.T) {
|
||||
checkIdentical(t, conversions.ToValueSlice(conversions.ToPointerSlice([]string{"a"})), "[]string")
|
||||
checkIdentical(t, conversions.ToValueSlice(conversions.ToPointerSlice([]int{1})), "[]int")
|
||||
checkIdentical(t, conversions.ToValueSlice(conversions.ToPointerSlice([]libregraph.UnifiedRoleDefinition{{}})), "[]libregraph.UnifiedRoleDefinition")
|
||||
|
||||
checkIdentical(t, conversions.ToValueSlice(conversions.ToPointerSlice(([]string)(nil))), "[]string")
|
||||
checkIdentical(t, conversions.ToValueSlice(conversions.ToPointerSlice(([]int)(nil))), "[]int")
|
||||
checkIdentical(t, conversions.ToValueSlice(conversions.ToPointerSlice(([]libregraph.UnifiedRoleDefinition)(nil))), "[]libregraph.UnifiedRoleDefinition")
|
||||
|
||||
checkIdentical(t, conversions.ToValueSlice(conversions.ToPointerSlice([]*string{conversions.ToPointer("a")})), "[]*string")
|
||||
checkIdentical(t, conversions.ToValueSlice(conversions.ToPointerSlice([]*int{conversions.ToPointer(1)})), "[]*int")
|
||||
checkIdentical(t, conversions.ToValueSlice(conversions.ToPointerSlice([]*libregraph.UnifiedRoleDefinition{conversions.ToPointer(libregraph.UnifiedRoleDefinition{})})), "[]*libregraph.UnifiedRoleDefinition")
|
||||
|
||||
checkIdentical(t, conversions.ToValueSlice(conversions.ToPointerSlice(([]*string)(nil))), "[]*string")
|
||||
checkIdentical(t, conversions.ToValueSlice(conversions.ToPointerSlice(([]*int)(nil))), "[]*int")
|
||||
checkIdentical(t, conversions.ToValueSlice(conversions.ToPointerSlice(([]*libregraph.UnifiedRoleDefinition)(nil))), "[]*libregraph.UnifiedRoleDefinition")
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package conversions
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// StringToSliceString splits a string into a slice string according to separator
|
||||
func StringToSliceString(src string, sep string) []string {
|
||||
parsed := strings.Split(src, sep)
|
||||
parts := make([]string, 0, len(parsed))
|
||||
for _, v := range parsed {
|
||||
parts = append(parts, strings.TrimSpace(v))
|
||||
}
|
||||
|
||||
return parts
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
package conversions
|
||||
|
||||
import "testing"
|
||||
|
||||
var scenarios = []struct {
|
||||
name string
|
||||
input string
|
||||
separator string
|
||||
out []string
|
||||
}{
|
||||
{
|
||||
"comma separated input",
|
||||
"a, b, c, d",
|
||||
",",
|
||||
[]string{"a", "b", "c", "d"},
|
||||
}, {
|
||||
"space separated input",
|
||||
"a b c d",
|
||||
" ",
|
||||
[]string{"a", "b", "c", "d"},
|
||||
},
|
||||
}
|
||||
|
||||
func TestStringToSliceString(t *testing.T) {
|
||||
for _, tt := range scenarios {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := StringToSliceString(tt.input, tt.separator)
|
||||
for i, v := range tt.out {
|
||||
if s[i] != v {
|
||||
t.Errorf("got %q, want %q", s, tt.out)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
package cors
|
||||
|
||||
import (
|
||||
"github.com/opencloud-eu/opencloud/pkg/log"
|
||||
)
|
||||
|
||||
// Option defines a single option function.
|
||||
type Option func(o *Options)
|
||||
|
||||
// Options defines the available options for this package.
|
||||
type Options struct {
|
||||
// Logger to use for logging, must be set
|
||||
Logger log.Logger
|
||||
// AllowedOrigins represents the allowed CORS origins
|
||||
AllowedOrigins []string
|
||||
// AllowedMethods represents the allowed CORS methods
|
||||
AllowedMethods []string
|
||||
// AllowedHeaders represents the allowed CORS headers
|
||||
AllowedHeaders []string
|
||||
// AllowCredentials represents the AllowCredentials CORS option
|
||||
AllowCredentials bool
|
||||
}
|
||||
|
||||
// newAccountOptions initializes the available default options.
|
||||
func NewOptions(opts ...Option) Options {
|
||||
opt := Options{}
|
||||
|
||||
for _, o := range opts {
|
||||
o(&opt)
|
||||
}
|
||||
|
||||
return opt
|
||||
}
|
||||
|
||||
// Logger provides a function to set the logger option.
|
||||
func Logger(l log.Logger) Option {
|
||||
return func(o *Options) {
|
||||
o.Logger = l
|
||||
}
|
||||
}
|
||||
|
||||
// AllowedOrigins provides a function to set the AllowedOrigins option.
|
||||
func AllowedOrigins(origins []string) Option {
|
||||
return func(o *Options) {
|
||||
o.AllowedOrigins = origins
|
||||
}
|
||||
}
|
||||
|
||||
// AllowedMethods provides a function to set the AllowedMethods option.
|
||||
func AllowedMethods(methods []string) Option {
|
||||
return func(o *Options) {
|
||||
o.AllowedMethods = methods
|
||||
}
|
||||
}
|
||||
|
||||
// AllowedHeaders provides a function to set the AllowedHeaders option.
|
||||
func AllowedHeaders(headers []string) Option {
|
||||
return func(o *Options) {
|
||||
o.AllowedHeaders = headers
|
||||
}
|
||||
}
|
||||
|
||||
// AlloweCredentials provides a function to set the AllowCredentials option.
|
||||
func AllowCredentials(allow bool) Option {
|
||||
return func(o *Options) {
|
||||
o.AllowCredentials = allow
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
// Package crypto implements utility functions for handling crypto related files.
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
// NewCertPoolFromPEM reads certificates from io.Reader and returns a x509.CertPool
|
||||
// containing those certificates.
|
||||
func NewCertPoolFromPEM(crts ...io.Reader) (*x509.CertPool, error) {
|
||||
certPool := x509.NewCertPool()
|
||||
|
||||
var buf bytes.Buffer
|
||||
for _, c := range crts {
|
||||
if _, err := io.Copy(&buf, c); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !certPool.AppendCertsFromPEM(buf.Bytes()) {
|
||||
return nil, errors.New("failed to append cert from PEM")
|
||||
}
|
||||
buf.Reset()
|
||||
}
|
||||
|
||||
return certPool, nil
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package crypto_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestCrypto(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Crypto Suite")
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
package crypto_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/opencloud-eu/opencloud/pkg/crypto"
|
||||
"github.com/opencloud-eu/opencloud/pkg/log"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
cfg "github.com/opencloud-eu/opencloud/pkg/config"
|
||||
)
|
||||
|
||||
var _ = Describe("Crypto", func() {
|
||||
var (
|
||||
userConfigDir string
|
||||
err error
|
||||
config = cfg.DefaultConfig()
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
userConfigDir, err = os.UserConfigDir()
|
||||
if err != nil {
|
||||
Fail(err.Error())
|
||||
}
|
||||
config.Proxy.HTTP.TLSKey = filepath.Join(userConfigDir, "ocis", "server.key")
|
||||
config.Proxy.HTTP.TLSCert = filepath.Join(userConfigDir, "ocis", "server.cert")
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
if err := os.RemoveAll(filepath.Join(userConfigDir, "ocis")); err != nil {
|
||||
Fail(err.Error())
|
||||
}
|
||||
})
|
||||
|
||||
// This little test should nail down the main functionality of this package, which is providing with a default location
|
||||
// for the key / certificate pair in case none is configured. Regardless of how the values ended in the configuration,
|
||||
// the side effects of GenCert is what we want to test.
|
||||
Describe("Creating key / certificate pair", func() {
|
||||
Context("For ocis-proxy in the location of the user config directory", func() {
|
||||
It(fmt.Sprintf("Creates the cert / key tuple in: %s", filepath.Join(userConfigDir, "ocis")), func() {
|
||||
if err := crypto.GenCert(config.Proxy.HTTP.TLSCert, config.Proxy.HTTP.TLSKey, log.NopLogger()); err != nil {
|
||||
Fail(err.Error())
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filepath.Join(userConfigDir, "ocis", "server.key")); err != nil {
|
||||
Fail("key not found at the expected location")
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filepath.Join(userConfigDir, "ocis", "server.cert")); err != nil {
|
||||
Fail("certificate not found at the expected location")
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
Describe("Creating a new cert pool", func() {
|
||||
var (
|
||||
crtOne string
|
||||
keyOne string
|
||||
crtTwo string
|
||||
keyTwo string
|
||||
)
|
||||
BeforeEach(func() {
|
||||
crtOne = filepath.Join(userConfigDir, "ocis/one.cert")
|
||||
keyOne = filepath.Join(userConfigDir, "ocis/one.key")
|
||||
crtTwo = filepath.Join(userConfigDir, "ocis/two.cert")
|
||||
keyTwo = filepath.Join(userConfigDir, "ocis/two.key")
|
||||
if err := crypto.GenCert(crtOne, keyOne, log.NopLogger()); err != nil {
|
||||
Fail(err.Error())
|
||||
}
|
||||
if err := crypto.GenCert(crtTwo, keyTwo, log.NopLogger()); err != nil {
|
||||
Fail(err.Error())
|
||||
}
|
||||
})
|
||||
It("handles one certificate", func() {
|
||||
f1, _ := os.Open(crtOne)
|
||||
defer f1.Close()
|
||||
|
||||
c, err := crypto.NewCertPoolFromPEM(f1)
|
||||
if err != nil {
|
||||
Fail(err.Error())
|
||||
}
|
||||
if len(c.Subjects()) != 1 {
|
||||
Fail("expected 1 certificate in the cert pool")
|
||||
}
|
||||
})
|
||||
It("handles multiple certificates", func() {
|
||||
f1, _ := os.Open(crtOne)
|
||||
f2, _ := os.Open(crtTwo)
|
||||
defer f1.Close()
|
||||
defer f2.Close()
|
||||
|
||||
c, err := crypto.NewCertPoolFromPEM(f1, f2)
|
||||
if err != nil {
|
||||
Fail(err.Error())
|
||||
}
|
||||
if len(c.Subjects()) != 2 {
|
||||
Fail("expected 2 certificates in the cert pool")
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,167 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/opencloud-eu/opencloud/pkg/log"
|
||||
mtls "go-micro.dev/v4/util/tls"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultHosts = []string{"127.0.0.1", "localhost"}
|
||||
)
|
||||
|
||||
// GenCert generates TLS-Certificates. This function has side effects: it creates the respective certificate / key pair at
|
||||
// the destination locations unless the tuple already exists, if that is the case, this is a noop.
|
||||
func GenCert(certName string, keyName string, l log.Logger) error {
|
||||
var pk *rsa.PrivateKey
|
||||
var err error
|
||||
|
||||
pk, err = rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, certErr := os.Stat(certName)
|
||||
_, keyErr := os.Stat(keyName)
|
||||
|
||||
if certErr == nil || keyErr == nil {
|
||||
l.Info().Msg(
|
||||
fmt.Sprintf("%v certificate / key pair already present. skipping acme certificate generation",
|
||||
filepath.Base(certName)))
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := persistCertificate(certName, l, pk); err != nil {
|
||||
l.Fatal().Err(err).Msg("failed to store certificate")
|
||||
}
|
||||
|
||||
if err := persistKey(keyName, l, pk); err != nil {
|
||||
l.Fatal().Err(err).Msg("failed to store key")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenTempCertForAddr generates temporary TLS-Certificates in memory.
|
||||
func GenTempCertForAddr(addr string) (tls.Certificate, error) {
|
||||
subjects := defaultHosts
|
||||
|
||||
if host, _, err := net.SplitHostPort(addr); err == nil && host != "" {
|
||||
subjects = []string{host}
|
||||
}
|
||||
return mtls.Certificate(subjects...)
|
||||
}
|
||||
|
||||
// persistCertificate generates a certificate using pk as private key and proceeds to store it into a file named certName.
|
||||
func persistCertificate(certName string, l log.Logger, pk interface{}) error {
|
||||
if err := ensureExistsDir(certName); err != nil {
|
||||
return fmt.Errorf("creating certificate destination: " + certName)
|
||||
}
|
||||
|
||||
certificate, err := generateCertificate(pk)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating certificate: " + filepath.Dir(certName))
|
||||
}
|
||||
|
||||
certOut, err := os.Create(certName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open `%v` for writing", certName)
|
||||
}
|
||||
|
||||
err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certificate})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encode certificate")
|
||||
}
|
||||
|
||||
err = certOut.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write cert")
|
||||
}
|
||||
l.Info().Msg(fmt.Sprintf("written certificate to %v", certName))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// genCert generates a self signed certificate using a random rsa key.
|
||||
func generateCertificate(pk interface{}) ([]byte, error) {
|
||||
for _, h := range defaultHosts {
|
||||
if ip := net.ParseIP(h); ip != nil {
|
||||
acmeTemplate.IPAddresses = append(acmeTemplate.IPAddresses, ip)
|
||||
} else {
|
||||
acmeTemplate.DNSNames = append(acmeTemplate.DNSNames, h)
|
||||
}
|
||||
}
|
||||
|
||||
return x509.CreateCertificate(rand.Reader, &acmeTemplate, &acmeTemplate, publicKey(pk), pk)
|
||||
}
|
||||
|
||||
// persistKey persists the private key used to generate the certificate at the configured location.
|
||||
func persistKey(destination string, l log.Logger, pk interface{}) error {
|
||||
if err := ensureExistsDir(destination); err != nil {
|
||||
return fmt.Errorf("creating key destination: " + destination)
|
||||
}
|
||||
|
||||
keyOut, err := os.OpenFile(destination, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open %v for writing", destination)
|
||||
}
|
||||
err = pem.Encode(keyOut, pemBlockForKey(pk, l))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encode key")
|
||||
}
|
||||
|
||||
err = keyOut.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write key")
|
||||
}
|
||||
l.Info().Msg(fmt.Sprintf("written key to %v", destination))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func publicKey(pk interface{}) interface{} {
|
||||
switch k := pk.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
return &k.PublicKey
|
||||
case *ecdsa.PrivateKey:
|
||||
return &k.PublicKey
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func pemBlockForKey(pk interface{}, l log.Logger) *pem.Block {
|
||||
switch k := pk.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(k)}
|
||||
case *ecdsa.PrivateKey:
|
||||
b, err := x509.MarshalECPrivateKey(k)
|
||||
if err != nil {
|
||||
l.Fatal().Err(err).Msg("Unable to marshal ECDSA private key")
|
||||
}
|
||||
return &pem.Block{Type: "EC PRIVATE KEY", Bytes: b}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func ensureExistsDir(uri string) error {
|
||||
certPath := filepath.Dir(uri)
|
||||
if _, err := os.Stat(certPath); os.IsNotExist(err) {
|
||||
err = os.MkdirAll(certPath, 0700)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,155 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/opencloud-eu/opencloud/pkg/log"
|
||||
)
|
||||
|
||||
func TestEnsureExistsDir(t *testing.T) {
|
||||
var tmpDir = t.TempDir()
|
||||
|
||||
type args struct {
|
||||
uri string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "creates a dir if it does not exist",
|
||||
args: args{
|
||||
uri: filepath.Join(tmpDir, "example"),
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "noop if the target directory exists",
|
||||
args: args{
|
||||
uri: filepath.Join(tmpDir, "example"),
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := ensureExistsDir(tt.args.uri); (err != nil) != tt.wantErr {
|
||||
t.Errorf("ensureExistsDir() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPersistKey(t *testing.T) {
|
||||
p256 := elliptic.P256()
|
||||
var (
|
||||
tmpDir = t.TempDir()
|
||||
keyPath = filepath.Join(tmpDir, "ocis", "testKey")
|
||||
rsaPk, _ = rsa.GenerateKey(rand.Reader, 2048)
|
||||
ecdsaPk, _ = ecdsa.GenerateKey(p256, rand.Reader)
|
||||
)
|
||||
|
||||
type args struct {
|
||||
keyName string
|
||||
pk interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
}{
|
||||
{
|
||||
name: "writes a private key (rsa) to the specified location",
|
||||
args: args{
|
||||
keyName: keyPath,
|
||||
pk: rsaPk,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "writes a private key (ecdsa) to the specified location",
|
||||
args: args{
|
||||
keyName: keyPath,
|
||||
pk: ecdsaPk,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := persistKey(tt.args.keyName, log.NopLogger(), tt.args.pk); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
})
|
||||
|
||||
// side effect: tt.args.keyName is created
|
||||
if _, err := os.Stat(tt.args.keyName); err != nil {
|
||||
t.Errorf("persistKey() error = %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPersistCertificate(t *testing.T) {
|
||||
p256 := elliptic.P256()
|
||||
var (
|
||||
tmpDir = t.TempDir()
|
||||
certPath = filepath.Join(tmpDir, "ocis", "testCert")
|
||||
rsaPk, _ = rsa.GenerateKey(rand.Reader, 2048)
|
||||
ecdsaPk, _ = ecdsa.GenerateKey(p256, rand.Reader)
|
||||
)
|
||||
|
||||
type args struct {
|
||||
certName string
|
||||
pk interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "store a certificate with an rsa private key",
|
||||
args: args{
|
||||
certName: certPath,
|
||||
pk: rsaPk,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "store a certificate with an ecdsa private key",
|
||||
args: args{
|
||||
certName: certPath,
|
||||
pk: ecdsaPk,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "should fail",
|
||||
args: args{
|
||||
certName: certPath,
|
||||
pk: 42,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := persistCertificate(tt.args.certName, log.NopLogger(), tt.args.pk); err != nil {
|
||||
if !tt.wantErr {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// side effect: tt.args.keyName is created
|
||||
if _, err := os.Stat(tt.args.certName); err != nil {
|
||||
t.Errorf("persistCertificate() error = %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"math/big"
|
||||
"time"
|
||||
)
|
||||
|
||||
var serialNumber, _ = rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||
|
||||
var acmeTemplate = x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Acme Corp"},
|
||||
CommonName: "OCIS",
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(24 * time.Hour * 365),
|
||||
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
package flags
|
||||
|
||||
// OverrideDefaultString checks whether the default value of v is the zero value, if so, ensure the flag has a correct
|
||||
// value by providing one. A value different than zero would mean that it was read from a config file either from an
|
||||
// service or from a higher source (i.e: ocis command).
|
||||
func OverrideDefaultString(v, def string) string {
|
||||
if v != "" {
|
||||
return v
|
||||
}
|
||||
|
||||
return def
|
||||
}
|
||||
|
||||
// OverrideDefaultBool checks whether the default value of v is the zero value, if so, ensure the flag has a correct
|
||||
// value by providing one. A value different than zero would mean that it was read from a config file either from an
|
||||
// service or from a higher source (i.e: ocis command).
|
||||
func OverrideDefaultBool(v, def bool) bool {
|
||||
if v {
|
||||
return v
|
||||
}
|
||||
|
||||
return def
|
||||
}
|
||||
|
||||
// OverrideDefaultInt checks whether the default value of v is the zero value, if so, ensure the flag has a correct
|
||||
// value by providing one. A value different than zero would mean that it was read from a config file either from an
|
||||
// service or from a higher source (i.e: ocis command).
|
||||
func OverrideDefaultInt(v, def int) int {
|
||||
if v != 0 {
|
||||
return v
|
||||
}
|
||||
|
||||
return def
|
||||
}
|
||||
|
||||
// OverrideDefaultInt64 checks whether the default value of v is the zero value, if so, ensure the flag has a correct
|
||||
// value by providing one. A value different than zero would mean that it was read from a config file either from an
|
||||
// service or from a higher source (i.e: ocis command).
|
||||
func OverrideDefaultInt64(v, def int64) int64 {
|
||||
if v != 0 {
|
||||
return v
|
||||
}
|
||||
|
||||
return def
|
||||
}
|
||||
|
||||
// OverrideDefaultUint64 checks whether the default value of v is the zero value, if so, ensure the flag has a correct
|
||||
// value by providing one. A value different than zero would mean that it was read from a config file either from an
|
||||
// service or from a higher source (i.e: ocis command).
|
||||
func OverrideDefaultUint64(v, def uint64) uint64 {
|
||||
if v != 0 {
|
||||
return v
|
||||
}
|
||||
|
||||
return def
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package generators
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"math/big"
|
||||
)
|
||||
|
||||
const (
|
||||
// PasswordChars contains alphanumeric chars (0-9, A-Z, a-z), plus "-=+!@#$%^&*."
|
||||
PasswordChars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz-=+!@#$%^&*."
|
||||
// AlphaNumChars contains alphanumeric chars (0-9, A-Z, a-z)
|
||||
AlphaNumChars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
||||
)
|
||||
|
||||
// GenerateRandomPassword generates a random password with the given length.
|
||||
// The password will contain chars picked from the `PasswordChars` constant.
|
||||
// If an error happens, the string will be empty and the error will be non-nil.
|
||||
//
|
||||
// This is equivalent to `GenerateRandomString(PasswordChars, length)`
|
||||
func GenerateRandomPassword(length int) (string, error) {
|
||||
return generateString(PasswordChars, length)
|
||||
}
|
||||
|
||||
// GenerateRandomString generates a random string with the given length
|
||||
// based on the chars provided. You can use `PasswordChars` or `AlphaNumChars`
|
||||
// constants, or even any other string.
|
||||
//
|
||||
// Chars from the provided string will be picked uniformly. The provided
|
||||
// constants have unique chars, which means that all the chars will have the
|
||||
// same probability of being picked.
|
||||
// You can use your own strings to change that probability. For example, using
|
||||
// "AAAB" you'll have a 75% of probability of getting "A" and 25% of "B"
|
||||
func GenerateRandomString(chars string, length int) (string, error) {
|
||||
return generateString(chars, length)
|
||||
}
|
||||
|
||||
func generateString(chars string, length int) (string, error) {
|
||||
ret := make([]byte, length)
|
||||
for i := 0; i < length; i++ {
|
||||
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(chars))))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
ret[i] = chars[num.Int64()]
|
||||
}
|
||||
|
||||
return string(ret), nil
|
||||
}
|
||||
@@ -0,0 +1,148 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"maps"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/opencloud-eu/opencloud/pkg/log"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// check is a function that performs a check.
|
||||
type checker func(ctx context.Context) error
|
||||
|
||||
// CheckHandlerConfiguration defines the configuration for the CheckHandler.
|
||||
type CheckHandlerConfiguration struct {
|
||||
checks map[string]checker
|
||||
logger log.Logger
|
||||
limit int
|
||||
statusFailed int
|
||||
statusSuccess int
|
||||
}
|
||||
|
||||
// NewCheckHandlerConfiguration initializes a new CheckHandlerConfiguration.
|
||||
func NewCheckHandlerConfiguration() CheckHandlerConfiguration {
|
||||
return CheckHandlerConfiguration{
|
||||
checks: make(map[string]checker),
|
||||
|
||||
limit: -1,
|
||||
statusFailed: http.StatusInternalServerError,
|
||||
statusSuccess: http.StatusOK,
|
||||
}
|
||||
}
|
||||
|
||||
// WithLogger sets the logger for the CheckHandlerConfiguration.
|
||||
func (c CheckHandlerConfiguration) WithLogger(l log.Logger) CheckHandlerConfiguration {
|
||||
c.logger = l
|
||||
return c
|
||||
}
|
||||
|
||||
// WithCheck sets a check for the CheckHandlerConfiguration.
|
||||
func (c CheckHandlerConfiguration) WithCheck(name string, check checker) CheckHandlerConfiguration {
|
||||
if _, ok := c.checks[name]; ok {
|
||||
c.logger.Panic().Str("check", name).Msg("check already exists")
|
||||
}
|
||||
|
||||
c.checks = maps.Clone(c.checks) // prevent propagated check duplication, maps are references;
|
||||
c.checks[name] = check
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// WithLimit limits the number of active goroutines for the checks to at most n
|
||||
func (c CheckHandlerConfiguration) WithLimit(n int) CheckHandlerConfiguration {
|
||||
c.limit = n
|
||||
return c
|
||||
}
|
||||
|
||||
// WithStatusFailed sets the status code for the failed checks.
|
||||
func (c CheckHandlerConfiguration) WithStatusFailed(status int) CheckHandlerConfiguration {
|
||||
c.statusFailed = status
|
||||
return c
|
||||
}
|
||||
|
||||
// WithStatusSuccess sets the status code for the successful checks.
|
||||
func (c CheckHandlerConfiguration) WithStatusSuccess(status int) CheckHandlerConfiguration {
|
||||
c.statusSuccess = status
|
||||
return c
|
||||
}
|
||||
|
||||
// CheckHandler is a http Handler that performs different checks.
|
||||
type CheckHandler struct {
|
||||
conf CheckHandlerConfiguration
|
||||
}
|
||||
|
||||
// NewCheckHandler initializes a new CheckHandler.
|
||||
func NewCheckHandler(c CheckHandlerConfiguration) *CheckHandler {
|
||||
c.checks = maps.Clone(c.checks) // prevent check duplication after initialization
|
||||
return &CheckHandler{
|
||||
conf: c,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *CheckHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
g, ctx := errgroup.WithContext(r.Context())
|
||||
g.SetLimit(h.conf.limit)
|
||||
|
||||
for name, check := range h.conf.checks {
|
||||
checker := check
|
||||
checkerName := name
|
||||
g.Go(func() error { // https://go.dev/blog/loopvar-preview per iteration scope since go 1.22
|
||||
if err := checker(ctx); err != nil { // since go 1.22 for loops have a per-iteration scope instead of per-loop scope, no need to pin the check...
|
||||
return fmt.Errorf("'%s': %w", checkerName, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
status := h.conf.statusSuccess
|
||||
if err := g.Wait(); err != nil {
|
||||
status = h.conf.statusFailed
|
||||
h.conf.logger.Error().Err(err).Msg("check failed")
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(status)
|
||||
|
||||
if _, err := io.WriteString(w, http.StatusText(status)); err != nil { // io.WriteString should not fail, but if it does, we want to know.
|
||||
h.conf.logger.Panic().Err(err).Msg("failed to write response")
|
||||
}
|
||||
}
|
||||
|
||||
// FailSaveAddress replaces wildcard addresses with the outbound IP.
|
||||
func FailSaveAddress(address string) (string, error) {
|
||||
if strings.Contains(address, "0.0.0.0") || strings.Contains(address, "::") {
|
||||
outboundIp, err := getOutBoundIP()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
address = strings.Replace(address, "0.0.0.0", outboundIp, 1)
|
||||
address = strings.Replace(address, "::", "["+outboundIp+"]", 1)
|
||||
address = strings.Replace(address, "[::]", "["+outboundIp+"]", 1)
|
||||
}
|
||||
return address, nil
|
||||
}
|
||||
|
||||
// getOutBoundIP returns the outbound IP address.
|
||||
func getOutBoundIP() (string, error) {
|
||||
interfacesAddresses, err := net.InterfaceAddrs()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for _, address := range interfacesAddresses {
|
||||
if ipNet, ok := address.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
|
||||
if ipNet.IP.To4() != nil {
|
||||
return ipNet.IP.String(), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("no IP found")
|
||||
}
|
||||
@@ -0,0 +1,127 @@
|
||||
package handlers_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/test-go/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/opencloud-eu/opencloud/pkg/handlers"
|
||||
"github.com/opencloud-eu/opencloud/pkg/log"
|
||||
)
|
||||
|
||||
func TestCheckHandlerConfiguration(t *testing.T) {
|
||||
nopCheckCounter := 0
|
||||
nopCheck := func(_ context.Context) error { nopCheckCounter++; return nil }
|
||||
handlerConfiguration := handlers.NewCheckHandlerConfiguration().WithCheck("check-1", nopCheck)
|
||||
|
||||
t.Run("add check", func(t *testing.T) {
|
||||
localCounter := 0
|
||||
handlers.NewCheckHandler(handlerConfiguration).ServeHTTP(httptest.NewRecorder(), httptest.NewRequest("GET", "/", nil))
|
||||
require.Equal(t, 1, nopCheckCounter)
|
||||
|
||||
handlers.NewCheckHandler(handlerConfiguration.WithCheck("check-2", func(_ context.Context) error { localCounter++; return nil })).ServeHTTP(httptest.NewRecorder(), httptest.NewRequest("GET", "/", nil))
|
||||
require.Equal(t, 2, nopCheckCounter)
|
||||
require.Equal(t, 1, localCounter)
|
||||
})
|
||||
|
||||
t.Run("checks are unique", func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Error("checks should be unique")
|
||||
}
|
||||
}()
|
||||
|
||||
handlerConfiguration.WithCheck("check-1", nopCheck)
|
||||
require.Equal(t, 3, nopCheckCounter)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCheckHandler(t *testing.T) {
|
||||
checkFactory := func(err error) func(ctx context.Context) error {
|
||||
return func(ctx context.Context) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
<-ctx.Done()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("passes with custom status", func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
handler := handlers.NewCheckHandler(
|
||||
handlers.
|
||||
NewCheckHandlerConfiguration().
|
||||
WithStatusSuccess(http.StatusCreated),
|
||||
)
|
||||
|
||||
handler.ServeHTTP(rec, httptest.NewRequest("GET", "/", nil))
|
||||
require.Equal(t, http.StatusCreated, rec.Code)
|
||||
require.Equal(t, http.StatusText(http.StatusCreated), rec.Body.String())
|
||||
})
|
||||
|
||||
t.Run("is not ok if any check fails", func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
handler := handlers.NewCheckHandler(
|
||||
handlers.
|
||||
NewCheckHandlerConfiguration().
|
||||
WithCheck("check-1", checkFactory(errors.New("failed"))),
|
||||
)
|
||||
handler.ServeHTTP(rec, httptest.NewRequest("GET", "/", nil))
|
||||
require.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
require.Equal(t, http.StatusText(http.StatusInternalServerError), rec.Body.String())
|
||||
})
|
||||
|
||||
t.Run("fails with custom status", func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
handler := handlers.NewCheckHandler(
|
||||
handlers.
|
||||
NewCheckHandlerConfiguration().
|
||||
WithCheck("check-1", checkFactory(errors.New("failed"))).
|
||||
WithStatusFailed(http.StatusTeapot),
|
||||
)
|
||||
handler.ServeHTTP(rec, httptest.NewRequest("GET", "/", nil))
|
||||
require.Equal(t, http.StatusTeapot, rec.Code)
|
||||
require.Equal(t, http.StatusText(http.StatusTeapot), rec.Body.String())
|
||||
})
|
||||
|
||||
t.Run("exits all other running tests on failure", func(t *testing.T) {
|
||||
var errs []error
|
||||
rec := httptest.NewRecorder()
|
||||
buffer := &bytes.Buffer{}
|
||||
logger := log.Logger{Logger: log.NewLogger().Output(buffer)}
|
||||
handler := handlers.NewCheckHandler(
|
||||
handlers.
|
||||
NewCheckHandlerConfiguration().
|
||||
WithLogger(logger).
|
||||
WithCheck("check-1", func(ctx context.Context) error {
|
||||
err := checkFactory(nil)(ctx)
|
||||
errs = append(errs, err)
|
||||
return err
|
||||
}).
|
||||
WithCheck("check-2", func(ctx context.Context) error {
|
||||
err := checkFactory(errors.New("failed"))(ctx)
|
||||
errs = append(errs, err)
|
||||
return err
|
||||
}).
|
||||
WithCheck("check-3", func(ctx context.Context) error {
|
||||
err := checkFactory(nil)(ctx)
|
||||
errs = append(errs, err)
|
||||
return err
|
||||
}),
|
||||
)
|
||||
handler.ServeHTTP(rec, httptest.NewRequest("GET", "/", nil))
|
||||
|
||||
require.Equal(t, "'check-2': failed", gjson.Get(buffer.String(), "error").String())
|
||||
require.Equal(t, 1, len(slices.DeleteFunc(errs, func(err error) bool { return err == nil })))
|
||||
require.Equal(t, 2, len(slices.DeleteFunc(errs, func(err error) bool { return err != nil })))
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,219 @@
|
||||
// Package keycloak is a package for keycloak utility functions.
|
||||
package keycloak
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
|
||||
"github.com/Nerzal/gocloak/v13"
|
||||
libregraph "github.com/owncloud/libre-graph-api-go"
|
||||
)
|
||||
|
||||
// Some attribute constants.
|
||||
// TODO: Make these configurable in the future.
|
||||
const (
|
||||
_idAttr = "OWNCLOUD_ID"
|
||||
_userTypeAttr = "OWNCLOUD_USER_TYPE"
|
||||
)
|
||||
|
||||
// ConcreteClient represents a concrete implementation of a keycloak client
|
||||
type ConcreteClient struct {
|
||||
keycloak GoCloak
|
||||
clientID string
|
||||
clientSecret string
|
||||
realm string
|
||||
baseURL string
|
||||
}
|
||||
|
||||
// New instantiates a new keycloak.Backend with a default gocloak client.
|
||||
func New(
|
||||
baseURL, clientID, clientSecret, realm string,
|
||||
insecureSkipVerify bool,
|
||||
) *ConcreteClient {
|
||||
gc := gocloak.NewClient(baseURL)
|
||||
restyClient := gc.RestyClient()
|
||||
restyClient.SetTLSClientConfig(&tls.Config{InsecureSkipVerify: insecureSkipVerify}) //nolint:gosec
|
||||
return NewWithClient(gc, baseURL, clientID, clientSecret, realm)
|
||||
}
|
||||
|
||||
// NewWithClient instantiates a new keycloak.Backend with a custom
|
||||
func NewWithClient(
|
||||
gocloakClient GoCloak,
|
||||
baseURL, clientID, clientSecret, realm string,
|
||||
) *ConcreteClient {
|
||||
return &ConcreteClient{
|
||||
keycloak: gocloakClient,
|
||||
baseURL: baseURL,
|
||||
clientID: clientID,
|
||||
clientSecret: clientSecret,
|
||||
realm: realm,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateUser creates a user from a libregraph user and returns its *keycloak* ID.
|
||||
// TODO: For now we only call this from the invitation service where all the attributes are set correctly.
|
||||
//
|
||||
// For more wider use, do some sanity checking on the user instance.
|
||||
func (c *ConcreteClient) CreateUser(ctx context.Context, realm string, user *libregraph.User, userActions []UserAction) (string, error) {
|
||||
token, err := c.getToken(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
req := gocloak.User{
|
||||
Email: user.Mail,
|
||||
Enabled: user.AccountEnabled,
|
||||
Username: &user.OnPremisesSamAccountName,
|
||||
FirstName: user.GivenName,
|
||||
LastName: user.Surname,
|
||||
Attributes: &map[string][]string{
|
||||
_idAttr: {user.GetId()},
|
||||
_userTypeAttr: {user.GetUserType()},
|
||||
},
|
||||
RequiredActions: convertUserActions(userActions),
|
||||
}
|
||||
return c.keycloak.CreateUser(ctx, token.AccessToken, realm, req)
|
||||
}
|
||||
|
||||
// SendActionsMail sends a mail to the user with userID instructing them to do the actions defined in userActions.
|
||||
func (c *ConcreteClient) SendActionsMail(ctx context.Context, realm, userID string, userActions []UserAction) error {
|
||||
token, err := c.getToken(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
params := gocloak.ExecuteActionsEmail{
|
||||
UserID: &userID,
|
||||
Actions: convertUserActions(userActions),
|
||||
}
|
||||
|
||||
return c.keycloak.ExecuteActionsEmail(ctx, token.AccessToken, realm, params)
|
||||
}
|
||||
|
||||
// getUserByParams looks up a user by the given parameters.
|
||||
func (c *ConcreteClient) getUserByParams(ctx context.Context, realm string, params gocloak.GetUsersParams) (*libregraph.User, error) {
|
||||
token, err := c.getToken(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
users, err := c.keycloak.GetUsers(ctx, token.AccessToken, realm, params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(users) == 0 {
|
||||
return nil, fmt.Errorf("no users found")
|
||||
}
|
||||
|
||||
if len(users) > 1 {
|
||||
return nil, fmt.Errorf("%d users found", len(users))
|
||||
}
|
||||
|
||||
return c.keycloakUserToLibregraph(users[0]), nil
|
||||
}
|
||||
|
||||
// GetUserByUsername looks up a user by username.
|
||||
func (c *ConcreteClient) GetUserByUsername(ctx context.Context, realm, username string) (*libregraph.User, error) {
|
||||
return c.getUserByParams(ctx, realm, gocloak.GetUsersParams{
|
||||
Username: &username,
|
||||
})
|
||||
}
|
||||
|
||||
// GetPIIReport returns a structure with all the PII for the user.
|
||||
func (c *ConcreteClient) GetPIIReport(ctx context.Context, realm, username string) (*PIIReport, error) {
|
||||
u, err := c.GetUserByUsername(ctx, realm, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
token, err := c.getToken(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keycloakID, err := c.getKeyCloakID(u)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sessions, err := c.keycloak.GetUserSessions(ctx, token.AccessToken, realm, keycloakID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &PIIReport{
|
||||
UserData: u,
|
||||
Sessions: sessions,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getToken gets a fresh token for the request.
|
||||
// TODO: set a token on the struct and check if it's still valid before requesting a new one.
|
||||
func (c *ConcreteClient) getToken(ctx context.Context) (*gocloak.JWT, error) {
|
||||
token, err := c.keycloak.LoginClient(ctx, c.clientID, c.clientSecret, c.realm)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get token: %w", err)
|
||||
}
|
||||
|
||||
rRes, err := c.keycloak.RetrospectToken(ctx, token.AccessToken, c.clientID, c.clientSecret, c.realm)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrospect token: %w", err)
|
||||
}
|
||||
|
||||
if !*rRes.Active {
|
||||
return nil, fmt.Errorf("token is not active")
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (c *ConcreteClient) keycloakUserToLibregraph(u *gocloak.User) *libregraph.User {
|
||||
var ldapID string
|
||||
var userType *string
|
||||
|
||||
if u.Attributes != nil {
|
||||
attrs := *u.Attributes
|
||||
ldapIDs, ok := attrs[_idAttr]
|
||||
if ok {
|
||||
ldapID = ldapIDs[0]
|
||||
}
|
||||
|
||||
userTypes, ok := attrs[_userTypeAttr]
|
||||
if ok {
|
||||
userType = &userTypes[0]
|
||||
}
|
||||
}
|
||||
|
||||
return &libregraph.User{
|
||||
Id: &ldapID,
|
||||
Mail: u.Email,
|
||||
GivenName: u.FirstName,
|
||||
Surname: u.LastName,
|
||||
AccountEnabled: u.Enabled,
|
||||
UserType: userType,
|
||||
Identities: []libregraph.ObjectIdentity{
|
||||
{
|
||||
Issuer: &c.baseURL,
|
||||
IssuerAssignedId: u.ID,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ConcreteClient) getKeyCloakID(u *libregraph.User) (string, error) {
|
||||
for _, i := range u.Identities {
|
||||
if *i.Issuer == c.baseURL {
|
||||
return *i.IssuerAssignedId, nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("could not find identity for issuer: %s", c.baseURL)
|
||||
}
|
||||
|
||||
func convertUserActions(userActions []UserAction) *[]string {
|
||||
stringActions := make([]string, len(userActions))
|
||||
for i, a := range userActions {
|
||||
stringActions[i] = userActionsToString[a]
|
||||
}
|
||||
return &stringActions
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package keycloak
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Nerzal/gocloak/v13"
|
||||
)
|
||||
|
||||
// GoCloak represents the parts of gocloak.GoCloak that we use, mainly here for mockery.
|
||||
type GoCloak interface {
|
||||
CreateUser(ctx context.Context, token, realm string, user gocloak.User) (string, error)
|
||||
GetUsers(ctx context.Context, token, realm string, params gocloak.GetUsersParams) ([]*gocloak.User, error)
|
||||
ExecuteActionsEmail(ctx context.Context, token, realm string, params gocloak.ExecuteActionsEmail) error
|
||||
LoginClient(ctx context.Context, clientID, clientSecret, realm string, scopes ...string) (*gocloak.JWT, error)
|
||||
RetrospectToken(ctx context.Context, accessToken, clientID, clientSecret, realm string) (*gocloak.IntroSpectTokenResult, error)
|
||||
GetCredentials(ctx context.Context, accessToken, realm, userID string) ([]*gocloak.CredentialRepresentation, error)
|
||||
GetUserSessions(ctx context.Context, token, realm, userID string) ([]*gocloak.UserSessionRepresentation, error)
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package keycloak
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Nerzal/gocloak/v13"
|
||||
libregraph "github.com/owncloud/libre-graph-api-go"
|
||||
)
|
||||
|
||||
// UserAction defines a type for user actions
|
||||
type UserAction int8
|
||||
|
||||
// An incomplete list of UserActions
|
||||
const (
|
||||
// UserActionUpdatePassword sets it that the user needs to change their password.
|
||||
UserActionUpdatePassword UserAction = iota
|
||||
// UserActionVerifyEmail sets it that the user needs to verify their email address.
|
||||
UserActionVerifyEmail
|
||||
)
|
||||
|
||||
// A lookup table to translate user actions to their string equivalents
|
||||
var userActionsToString = map[UserAction]string{
|
||||
UserActionUpdatePassword: "UPDATE_PASSWORD",
|
||||
UserActionVerifyEmail: "VERIFY_EMAIL",
|
||||
}
|
||||
|
||||
// PIIReport is a structure of all the PersonalIdentifiableInformation contained in keycloak.
|
||||
type PIIReport struct {
|
||||
UserData *libregraph.User
|
||||
Sessions []*gocloak.UserSessionRepresentation
|
||||
}
|
||||
|
||||
// Client represents a keycloak client.
|
||||
type Client interface {
|
||||
CreateUser(ctx context.Context, realm string, user *libregraph.User, userActions []UserAction) (string, error)
|
||||
SendActionsMail(ctx context.Context, realm, userID string, userActions []UserAction) error
|
||||
GetUserByUsername(ctx context.Context, realm, username string) (*libregraph.User, error)
|
||||
GetPIIReport(ctx context.Context, realm, username string) (*PIIReport, error)
|
||||
}
|
||||
+138
@@ -0,0 +1,138 @@
|
||||
package kql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jinzhu/now"
|
||||
"github.com/opencloud-eu/opencloud/pkg/ast"
|
||||
"github.com/opencloud-eu/opencloud/services/search/pkg/query"
|
||||
)
|
||||
|
||||
func toNode[T ast.Node](in interface{}) (T, error) {
|
||||
var t T
|
||||
out, ok := in.(T)
|
||||
if !ok {
|
||||
return t, fmt.Errorf("can't convert '%T' to '%T'", in, t)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func toNodes[T ast.Node](in interface{}) ([]T, error) {
|
||||
switch v := in.(type) {
|
||||
case T:
|
||||
return []T{v}, nil
|
||||
case []T:
|
||||
return v, nil
|
||||
case []*ast.OperatorNode, []*ast.DateTimeNode:
|
||||
return toNodes[T](v)
|
||||
case []interface{}:
|
||||
var nodes []T
|
||||
for _, el := range v {
|
||||
node, err := toNodes[T](el)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nodes = append(nodes, node...)
|
||||
}
|
||||
return nodes, nil
|
||||
case nil:
|
||||
return nil, nil
|
||||
default:
|
||||
var t T
|
||||
return nil, fmt.Errorf("can't convert '%T' to '%T'", in, t)
|
||||
}
|
||||
}
|
||||
|
||||
func toString(in interface{}) (string, error) {
|
||||
switch v := in.(type) {
|
||||
case []byte:
|
||||
return string(v), nil
|
||||
case []interface{}:
|
||||
var str string
|
||||
|
||||
for i := range v {
|
||||
sv, err := toString(v[i])
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
str += sv
|
||||
}
|
||||
|
||||
return str, nil
|
||||
case string:
|
||||
return v, nil
|
||||
default:
|
||||
return "", fmt.Errorf("can't convert '%T' to string", v)
|
||||
}
|
||||
}
|
||||
|
||||
func toTime(in interface{}) (time.Time, error) {
|
||||
ts, err := toString(in)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
|
||||
return now.Parse(ts)
|
||||
}
|
||||
|
||||
func toTimeRange(in interface{}) (*time.Time, *time.Time, error) {
|
||||
var from, to time.Time
|
||||
|
||||
value, err := toString(in)
|
||||
if err != nil {
|
||||
return &from, &to, &query.UnsupportedTimeRangeError{}
|
||||
}
|
||||
|
||||
c := &now.Config{
|
||||
WeekStartDay: time.Monday,
|
||||
}
|
||||
|
||||
n := c.With(timeNow())
|
||||
|
||||
switch value {
|
||||
case "today":
|
||||
from = n.BeginningOfDay()
|
||||
to = n.EndOfDay()
|
||||
case "yesterday":
|
||||
yesterday := n.With(n.AddDate(0, 0, -1))
|
||||
from = yesterday.BeginningOfDay()
|
||||
to = yesterday.EndOfDay()
|
||||
case "this week":
|
||||
from = n.BeginningOfWeek()
|
||||
to = n.EndOfWeek()
|
||||
case "last week":
|
||||
lastWeek := n.With(n.AddDate(0, 0, -7))
|
||||
from = lastWeek.BeginningOfWeek()
|
||||
to = lastWeek.EndOfWeek()
|
||||
case "last 7 days":
|
||||
from = n.With(n.AddDate(0, 0, -6)).BeginningOfDay()
|
||||
to = n.EndOfDay()
|
||||
case "this month":
|
||||
from = n.BeginningOfMonth()
|
||||
to = n.EndOfMonth()
|
||||
case "last month":
|
||||
lastMonth := n.With(n.BeginningOfMonth().AddDate(0, 0, -1))
|
||||
from = lastMonth.BeginningOfMonth()
|
||||
to = lastMonth.EndOfMonth()
|
||||
case "last 30 days":
|
||||
from = n.With(n.AddDate(0, 0, -29)).BeginningOfDay()
|
||||
to = n.EndOfDay()
|
||||
case "this year":
|
||||
from = n.BeginningOfYear()
|
||||
to = n.EndOfYear()
|
||||
case "last year":
|
||||
lastYear := n.With(n.AddDate(-1, 0, 0))
|
||||
from = lastYear.BeginningOfYear()
|
||||
to = lastYear.EndOfYear()
|
||||
}
|
||||
|
||||
if from.IsZero() || to.IsZero() {
|
||||
return nil, nil, &query.UnsupportedTimeRangeError{}
|
||||
}
|
||||
|
||||
return &from, &to, nil
|
||||
}
|
||||
@@ -0,0 +1,129 @@
|
||||
package kql
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/opencloud-eu/opencloud/pkg/ast"
|
||||
)
|
||||
|
||||
// connectNodes connects given nodes
|
||||
func connectNodes(c Connector, nodes ...ast.Node) []ast.Node {
|
||||
var connectedNodes []ast.Node
|
||||
|
||||
for i := range nodes {
|
||||
ri := len(nodes) - 1 - i
|
||||
head := nodes[ri]
|
||||
|
||||
if connectionNodes := connectNode(c, head, connectedNodes...); len(connectionNodes) > 0 {
|
||||
connectedNodes = append(connectionNodes, connectedNodes...)
|
||||
}
|
||||
|
||||
connectedNodes = append([]ast.Node{head}, connectedNodes...)
|
||||
}
|
||||
|
||||
return connectedNodes
|
||||
}
|
||||
|
||||
// connectNode connects a tip node with the rest
|
||||
func connectNode(c Connector, headNode ast.Node, tailNodes ...ast.Node) []ast.Node {
|
||||
var nearestNeighborNode ast.Node
|
||||
var nearestNeighborOperators []*ast.OperatorNode
|
||||
|
||||
l:
|
||||
for _, tailNode := range tailNodes {
|
||||
switch node := tailNode.(type) {
|
||||
case *ast.OperatorNode:
|
||||
nearestNeighborOperators = append(nearestNeighborOperators, node)
|
||||
default:
|
||||
nearestNeighborNode = node
|
||||
break l
|
||||
}
|
||||
}
|
||||
|
||||
if nearestNeighborNode == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return c.Connect(headNode, nearestNeighborNode, nearestNeighborOperators)
|
||||
}
|
||||
|
||||
// Connector is responsible to decide what node connections are needed
|
||||
type Connector interface {
|
||||
Connect(head ast.Node, neighbor ast.Node, connections []*ast.OperatorNode) []ast.Node
|
||||
}
|
||||
|
||||
// DefaultConnector is the default node connector
|
||||
type DefaultConnector struct {
|
||||
sameKeyOPValue string
|
||||
}
|
||||
|
||||
// Connect implements the Connector interface and is used to connect the nodes using
|
||||
// the default logic defined by the kql spec.
|
||||
func (c DefaultConnector) Connect(head ast.Node, neighbor ast.Node, connections []*ast.OperatorNode) []ast.Node {
|
||||
switch head.(type) {
|
||||
case *ast.OperatorNode:
|
||||
return nil
|
||||
}
|
||||
|
||||
headKey := strings.ToLower(ast.NodeKey(head))
|
||||
neighborKey := strings.ToLower(ast.NodeKey(neighbor))
|
||||
|
||||
connection := &ast.OperatorNode{
|
||||
Base: &ast.Base{Loc: &ast.Location{Source: &[]string{"implicitly operator"}[0]}},
|
||||
Value: BoolAND,
|
||||
}
|
||||
|
||||
// if the current node and the neighbor node have the same key
|
||||
// the connection is of type OR
|
||||
//
|
||||
// spec: same
|
||||
// author:"John Smith" author:"Jane Smith"
|
||||
// author:"John Smith" OR author:"Jane Smith"
|
||||
//
|
||||
// if the nodes have NO key, the edge is a AND connection
|
||||
//
|
||||
// spec: same
|
||||
// cat dog
|
||||
// cat AND dog
|
||||
// from the spec:
|
||||
// To construct complex queries, you can combine multiple
|
||||
// free-text expressions with KQL query operators.
|
||||
// If there are multiple free-text expressions without any
|
||||
// operators in between them, the query behavior is the same
|
||||
// as using the AND operator.
|
||||
//
|
||||
// nodes inside of group node are handled differently,
|
||||
// if no explicit operator given, it uses AND
|
||||
//
|
||||
// spec: same
|
||||
// author:"John Smith" AND author:"Jane Smith"
|
||||
// author:("John Smith" "Jane Smith")
|
||||
if headKey == neighborKey && headKey != "" && neighborKey != "" {
|
||||
connection.Value = c.sameKeyOPValue
|
||||
}
|
||||
|
||||
// decisions based on nearest neighbor operators
|
||||
for i, node := range connections {
|
||||
// consider direct neighbor operator only
|
||||
if i == 0 {
|
||||
// no connection is necessary here because an `AND` or `OR` edge is already present
|
||||
// exit
|
||||
for _, skipValue := range []string{BoolOR, BoolAND} {
|
||||
if node.Value == skipValue {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// if neighbor node negotiates, an AND edge is needed
|
||||
//
|
||||
// spec: same
|
||||
// cat -dog
|
||||
// cat AND NOT dog
|
||||
if node.Value == BoolNOT {
|
||||
connection.Value = BoolAND
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return []ast.Node{connection}
|
||||
}
|
||||
@@ -0,0 +1,235 @@
|
||||
{
|
||||
package kql
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////
|
||||
// ast
|
||||
////////////////////////////////////////////////////////
|
||||
|
||||
AST <-
|
||||
n:Nodes {
|
||||
return buildAST(n, c.text, c.pos)
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////
|
||||
// nodes
|
||||
////////////////////////////////////////////////////////
|
||||
|
||||
Nodes <-
|
||||
(_ Node)+
|
||||
|
||||
Node <-
|
||||
GroupNode /
|
||||
PropertyRestrictionNodes /
|
||||
OperatorBooleanNodes /
|
||||
FreeTextKeywordNodes
|
||||
|
||||
////////////////////////////////////////////////////////
|
||||
// nesting
|
||||
////////////////////////////////////////////////////////
|
||||
|
||||
GroupNode <-
|
||||
k:(Char+)? (OperatorColonNode / OperatorEqualNode)? "(" v:Nodes ")" {
|
||||
return buildGroupNode(k, v, c.text, c.pos)
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////
|
||||
// property restrictions
|
||||
////////////////////////////////////////////////////////
|
||||
|
||||
PropertyRestrictionNodes <-
|
||||
YesNoPropertyRestrictionNode /
|
||||
DateTimeRestrictionNode /
|
||||
TextPropertyRestrictionNode
|
||||
|
||||
YesNoPropertyRestrictionNode <-
|
||||
k:Char+ (OperatorColonNode / OperatorEqualNode) v:("true" / "false"){
|
||||
return buildBooleanNode(k, v, c.text, c.pos)
|
||||
}
|
||||
|
||||
DateTimeRestrictionNode <-
|
||||
k:Char+ o:(
|
||||
OperatorGreaterOrEqualNode /
|
||||
OperatorLessOrEqualNode /
|
||||
OperatorGreaterNode /
|
||||
OperatorLessNode /
|
||||
OperatorEqualNode /
|
||||
OperatorColonNode
|
||||
) '"'? v:(
|
||||
DateTime /
|
||||
FullDate /
|
||||
FullTime
|
||||
) '"'? {
|
||||
return buildDateTimeNode(k, o, v, c.text, c.pos)
|
||||
} /
|
||||
k:Char+ (
|
||||
OperatorEqualNode /
|
||||
OperatorColonNode
|
||||
) '"'? v:NaturalLanguageDateTime '"'? {
|
||||
return buildNaturalLanguageDateTimeNodes(k, v, c.text, c.pos)
|
||||
}
|
||||
|
||||
TextPropertyRestrictionNode <-
|
||||
k:Char+ (OperatorColonNode / OperatorEqualNode) v:(String / [^ ()]+){
|
||||
return buildStringNode(k, v, c.text, c.pos)
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////
|
||||
// free text-keywords
|
||||
////////////////////////////////////////////////////////
|
||||
|
||||
FreeTextKeywordNodes <-
|
||||
PhraseNode /
|
||||
WordNode
|
||||
|
||||
PhraseNode <-
|
||||
OperatorColonNode? _ v:String _ OperatorColonNode? {
|
||||
return buildStringNode("", v, c.text, c.pos)
|
||||
}
|
||||
|
||||
WordNode <-
|
||||
OperatorColonNode? _ v:[^ :()]+ _ OperatorColonNode? {
|
||||
return buildStringNode("", v, c.text, c.pos)
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////
|
||||
// operators
|
||||
////////////////////////////////////////////////////////
|
||||
|
||||
OperatorBooleanNodes <-
|
||||
OperatorBooleanAndNode /
|
||||
OperatorBooleanNotNode /
|
||||
OperatorBooleanOrNode
|
||||
|
||||
OperatorBooleanAndNode <-
|
||||
("AND" / "+") {
|
||||
return buildOperatorNode(c.text, c.pos)
|
||||
}
|
||||
|
||||
OperatorBooleanNotNode <-
|
||||
("NOT" / "-") {
|
||||
return buildOperatorNode(c.text, c.pos)
|
||||
}
|
||||
|
||||
OperatorBooleanOrNode <-
|
||||
("OR") {
|
||||
return buildOperatorNode(c.text, c.pos)
|
||||
}
|
||||
|
||||
OperatorColonNode <-
|
||||
":" {
|
||||
return buildOperatorNode(c.text, c.pos)
|
||||
}
|
||||
|
||||
OperatorEqualNode <-
|
||||
"=" {
|
||||
return buildOperatorNode(c.text, c.pos)
|
||||
}
|
||||
|
||||
OperatorLessNode <-
|
||||
"<" {
|
||||
return buildOperatorNode(c.text, c.pos)
|
||||
}
|
||||
|
||||
OperatorLessOrEqualNode <-
|
||||
"<=" {
|
||||
return buildOperatorNode(c.text, c.pos)
|
||||
}
|
||||
|
||||
OperatorGreaterNode <-
|
||||
">" {
|
||||
return buildOperatorNode(c.text, c.pos)
|
||||
}
|
||||
|
||||
OperatorGreaterOrEqualNode <-
|
||||
">=" {
|
||||
return buildOperatorNode(c.text, c.pos)
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////
|
||||
// time
|
||||
////////////////////////////////////////////////////////
|
||||
|
||||
TimeYear <-
|
||||
Digit Digit Digit Digit {
|
||||
return c.text, nil
|
||||
}
|
||||
|
||||
TimeMonth <-
|
||||
Digit Digit {
|
||||
return c.text, nil
|
||||
}
|
||||
|
||||
TimeDay <-
|
||||
Digit Digit {
|
||||
return c.text, nil
|
||||
}
|
||||
|
||||
TimeHour <-
|
||||
Digit Digit {
|
||||
return c.text, nil
|
||||
}
|
||||
|
||||
TimeMinute <-
|
||||
Digit Digit {
|
||||
return c.text, nil
|
||||
}
|
||||
|
||||
TimeSecond <-
|
||||
Digit Digit {
|
||||
return c.text, nil
|
||||
}
|
||||
|
||||
FullDate <-
|
||||
TimeYear "-" TimeMonth "-" TimeDay {
|
||||
return c.text, nil
|
||||
}
|
||||
|
||||
FullTime <-
|
||||
TimeHour ":" TimeMinute ":" TimeSecond ("." Digit+)? ("Z" / ("+" / "-") TimeHour ":" TimeMinute) {
|
||||
return c.text, nil
|
||||
}
|
||||
|
||||
DateTime <-
|
||||
FullDate "T" FullTime {
|
||||
return c.text, nil
|
||||
}
|
||||
|
||||
NaturalLanguageDateTime <-
|
||||
"today" /
|
||||
"yesterday" /
|
||||
"this week" /
|
||||
"last week" /
|
||||
"last 7 days" /
|
||||
"this month" /
|
||||
"last month" /
|
||||
"last 30 days" /
|
||||
"this year" /
|
||||
"last year" {
|
||||
return c.text, nil
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////
|
||||
// misc
|
||||
////////////////////////////////////////////////////////
|
||||
|
||||
Char <-
|
||||
[A-Za-z] {
|
||||
return c.text, nil
|
||||
}
|
||||
|
||||
String <-
|
||||
'"' v:[^"]* '"' {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
Digit <-
|
||||
[0-9] {
|
||||
return c.text, nil
|
||||
}
|
||||
|
||||
_ <-
|
||||
[ \t]* {
|
||||
return nil, nil
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,27 @@
|
||||
/*
|
||||
Package kql provides the ability to work with kql queries.
|
||||
|
||||
Not every aspect of the spec is implemented yet.
|
||||
The language support will grow over time if needed.
|
||||
|
||||
The following spec parts are supported and tested:
|
||||
- 2.1.2 AND Operator
|
||||
- 2.1.6 NOT Operator
|
||||
- 2.1.8 OR Operator
|
||||
- 2.1.12 Parentheses
|
||||
- 2.3.5 Date Tokens
|
||||
- 3.1.11 Implicit Operator
|
||||
- 3.1.12 Parentheses
|
||||
- 3.1.2 AND Operator
|
||||
- 3.1.6 NOT Operator
|
||||
- 3.1.8 OR Operator
|
||||
- 3.2.3 Implicit Operator for Property Restriction
|
||||
- 3.3.1.1.1 Implicit AND Operator
|
||||
- 3.3.5 Date Tokens
|
||||
|
||||
References:
|
||||
- https://learn.microsoft.com/en-us/sharepoint/dev/general-development/keyword-query-language-kql-syntax-reference
|
||||
- https://learn.microsoft.com/en-us/openspecs/sharepoint_protocols/ms-kql/3bbf06cd-8fc1-4277-bd92-8661ccd3c9b0
|
||||
- https://msopenspecs.azureedge.net/files/MS-KQL/%5bMS-KQL%5d.pdf
|
||||
*/
|
||||
package kql
|
||||
@@ -0,0 +1,11 @@
|
||||
package kql
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// PatchTimeNow is here to patch the package time now func,
|
||||
// which is used in the test suite
|
||||
func PatchTimeNow(t func() time.Time) {
|
||||
timeNow = t
|
||||
}
|
||||
@@ -0,0 +1,209 @@
|
||||
package kql
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/opencloud-eu/opencloud/pkg/ast"
|
||||
)
|
||||
|
||||
func base(text []byte, pos position) (*ast.Base, error) {
|
||||
source, err := toString(text)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ast.Base{
|
||||
Loc: &ast.Location{
|
||||
Start: ast.Position{
|
||||
Line: pos.line,
|
||||
Column: pos.col,
|
||||
},
|
||||
End: ast.Position{
|
||||
Line: pos.line,
|
||||
Column: pos.col + len(text),
|
||||
},
|
||||
Source: &source,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildAST(n interface{}, text []byte, pos position) (*ast.Ast, error) {
|
||||
b, err := base(text, pos)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nodes, err := toNodes[ast.Node](n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
a := &ast.Ast{
|
||||
Base: b,
|
||||
Nodes: connectNodes(DefaultConnector{sameKeyOPValue: BoolOR}, nodes...),
|
||||
}
|
||||
|
||||
if err := validateAst(a); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func buildStringNode(k, v interface{}, text []byte, pos position) (*ast.StringNode, error) {
|
||||
b, err := base(text, pos)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key, err := toString(k)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
value, err := toString(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ast.StringNode{
|
||||
Base: b,
|
||||
Key: key,
|
||||
Value: value,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildDateTimeNode(k, o, v interface{}, text []byte, pos position) (*ast.DateTimeNode, error) {
|
||||
b, err := base(text, pos)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
operator, err := toNode[*ast.OperatorNode](o)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key, err := toString(k)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
value, err := toTime(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ast.DateTimeNode{
|
||||
Base: b,
|
||||
Key: key,
|
||||
Operator: operator,
|
||||
Value: value,
|
||||
}, nil
|
||||
}
|
||||
func buildNaturalLanguageDateTimeNodes(k, v interface{}, text []byte, pos position) ([]ast.Node, error) {
|
||||
b, err := base(text, pos)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key, err := toString(k)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
from, to, err := toTimeRange(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return []ast.Node{
|
||||
&ast.DateTimeNode{
|
||||
Base: b,
|
||||
Value: *from,
|
||||
Key: key,
|
||||
Operator: &ast.OperatorNode{Value: ">="},
|
||||
},
|
||||
&ast.OperatorNode{Value: BoolAND},
|
||||
&ast.DateTimeNode{
|
||||
Base: b,
|
||||
Value: *to,
|
||||
Key: key,
|
||||
Operator: &ast.OperatorNode{Value: "<="},
|
||||
},
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
func buildBooleanNode(k, v interface{}, text []byte, pos position) (*ast.BooleanNode, error) {
|
||||
b, err := base(text, pos)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key, err := toString(k)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
value, err := toString(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ast.BooleanNode{
|
||||
Base: b,
|
||||
Key: key,
|
||||
Value: strings.ToLower(value) == "true",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildOperatorNode(text []byte, pos position) (*ast.OperatorNode, error) {
|
||||
b, err := base(text, pos)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
value, err := toString(text)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch value {
|
||||
case "+":
|
||||
value = BoolAND
|
||||
case "-":
|
||||
value = BoolNOT
|
||||
}
|
||||
|
||||
return &ast.OperatorNode{
|
||||
Base: b,
|
||||
Value: value,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildGroupNode(k, n interface{}, text []byte, pos position) (*ast.GroupNode, error) {
|
||||
b, err := base(text, pos)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key, _ := toString(k)
|
||||
|
||||
nodes, err := toNodes[ast.Node](n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
gn := &ast.GroupNode{
|
||||
Base: b,
|
||||
Key: key,
|
||||
Nodes: connectNodes(DefaultConnector{sameKeyOPValue: BoolOR}, nodes...),
|
||||
}
|
||||
|
||||
if err := validateGroupNode(gn); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return gn, nil
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
package kql
|
||||
|
||||
//go:generate go run github.com/mna/pigeon -optimize-grammar -optimize-parser -o dictionary_gen.go dictionary.peg
|
||||
@@ -0,0 +1,49 @@
|
||||
// Package kql provides the ability to work with kql queries.
|
||||
package kql
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/opencloud-eu/opencloud/pkg/ast"
|
||||
)
|
||||
|
||||
// The operator node value definition
|
||||
const (
|
||||
// BoolAND connect two nodes with "AND"
|
||||
BoolAND = "AND"
|
||||
// BoolOR connect two nodes with "OR"
|
||||
BoolOR = "OR"
|
||||
// BoolNOT connect two nodes with "NOT"
|
||||
BoolNOT = "NOT"
|
||||
)
|
||||
|
||||
// Builder implements kql Builder interface
|
||||
type Builder struct{}
|
||||
|
||||
// Build creates an ast.Ast based on a kql query
|
||||
func (b Builder) Build(q string) (*ast.Ast, error) {
|
||||
f, err := Parse("", []byte(q))
|
||||
if err != nil {
|
||||
var list errList
|
||||
errors.As(err, &list)
|
||||
|
||||
for _, listError := range list {
|
||||
var parserError *parserError
|
||||
switch {
|
||||
case errors.As(listError, &parserError):
|
||||
if parserError.Inner != nil {
|
||||
return nil, parserError.Inner
|
||||
}
|
||||
|
||||
return nil, listError
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return f.(*ast.Ast), nil
|
||||
}
|
||||
|
||||
// timeNow mirrors time.Now by default, the only reason why this exists
|
||||
// is to monkey patch it from the tests. See PatchTimeNow
|
||||
var timeNow = time.Now
|
||||
@@ -0,0 +1,54 @@
|
||||
package kql_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/opencloud-eu/opencloud/pkg/ast"
|
||||
"github.com/opencloud-eu/opencloud/pkg/kql"
|
||||
"github.com/opencloud-eu/opencloud/services/search/pkg/query"
|
||||
tAssert "github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewAST(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
givenQuery string
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
givenQuery: "foo:bar",
|
||||
},
|
||||
{
|
||||
name: "error",
|
||||
givenQuery: kql.BoolAND,
|
||||
expectedError: query.StartsWithBinaryOperatorError{
|
||||
Node: &ast.OperatorNode{Value: kql.BoolAND},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
assert := tAssert.New(t)
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := kql.Builder{}.Build(tt.givenQuery)
|
||||
|
||||
if tt.expectedError != nil {
|
||||
if tt.expectedError.Error() != "" {
|
||||
assert.Equal(err.Error(), tt.expectedError.Error())
|
||||
} else {
|
||||
assert.NotNil(err)
|
||||
}
|
||||
|
||||
assert.Nil(got)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
assert.Nil(err)
|
||||
assert.NotNil(got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
package kql
|
||||
|
||||
import (
|
||||
"github.com/opencloud-eu/opencloud/pkg/ast"
|
||||
"github.com/opencloud-eu/opencloud/services/search/pkg/query"
|
||||
)
|
||||
|
||||
func validateAst(a *ast.Ast) error {
|
||||
switch node := a.Nodes[0].(type) {
|
||||
case *ast.OperatorNode:
|
||||
switch node.Value {
|
||||
case BoolAND, BoolOR:
|
||||
return &query.StartsWithBinaryOperatorError{Node: node}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateGroupNode(n *ast.GroupNode) error {
|
||||
switch node := n.Nodes[0].(type) {
|
||||
case *ast.OperatorNode:
|
||||
switch node.Value {
|
||||
case BoolAND, BoolOR:
|
||||
return &query.StartsWithBinaryOperatorError{Node: node}
|
||||
}
|
||||
}
|
||||
|
||||
if n.Key != "" {
|
||||
for _, node := range n.Nodes {
|
||||
if ast.NodeKey(node) != "" {
|
||||
return &query.NamedGroupInvalidNodesError{Node: node}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,345 @@
|
||||
// package l10n holds translation mechanics that are used by user facing services (notifications, userlog, graph)
|
||||
package l10n
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io/fs"
|
||||
"os"
|
||||
"reflect"
|
||||
|
||||
"github.com/leonelquinteros/gotext"
|
||||
"github.com/opencloud-eu/opencloud/pkg/middleware"
|
||||
settingssvc "github.com/opencloud-eu/opencloud/protogen/gen/ocis/services/settings/v0"
|
||||
micrometadata "go-micro.dev/v4/metadata"
|
||||
)
|
||||
|
||||
var (
|
||||
// HeaderAcceptLanguage is the header key for the accept-language header
|
||||
HeaderAcceptLanguage = "Accept-Language"
|
||||
|
||||
// ErrUnsupportedType is returned when the type is not supported
|
||||
ErrUnsupportedType = errors.New("unsupported type")
|
||||
)
|
||||
|
||||
// Template marks a string as translatable
|
||||
func Template(s string) string { return s }
|
||||
|
||||
// Translator is able to translate strings
|
||||
type Translator struct {
|
||||
fs fs.FS
|
||||
defaultLocale string
|
||||
domain string
|
||||
}
|
||||
|
||||
// NewTranslator creates a Translator with library path and language code and load default domain
|
||||
func NewTranslator(defaultLocale string, domain string, fsys fs.FS) Translator {
|
||||
return Translator{
|
||||
fs: fsys,
|
||||
defaultLocale: defaultLocale,
|
||||
domain: domain,
|
||||
}
|
||||
}
|
||||
|
||||
// NewTranslatorFromCommonConfig creates a new Translator from legacy config
|
||||
func NewTranslatorFromCommonConfig(defaultLocale string, domain string, path string, fsys fs.FS, fsSubPath string) Translator {
|
||||
var filesystem fs.FS
|
||||
if path == "" {
|
||||
filesystem, _ = fs.Sub(fsys, fsSubPath)
|
||||
} else { // use custom path instead
|
||||
filesystem = os.DirFS(path)
|
||||
}
|
||||
return NewTranslator(defaultLocale, domain, filesystem)
|
||||
}
|
||||
|
||||
// Translate translates a string to the locale
|
||||
func (t Translator) Translate(str, locale string) string {
|
||||
return t.Locale(locale).Get(str)
|
||||
}
|
||||
|
||||
// Locale returns the gotext.Locale, use `.Get` method to translate strings
|
||||
func (t Translator) Locale(locale string) *gotext.Locale {
|
||||
l := gotext.NewLocaleFS(locale, t.fs)
|
||||
l.AddDomain(t.domain) // make domain configurable only if needed
|
||||
if locale != "en" && len(l.GetTranslations()) == 0 {
|
||||
l = gotext.NewLocaleFS(t.defaultLocale, t.fs)
|
||||
l.AddDomain(t.domain) // make domain configurable only if needed
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
// TranslateEntity function provides the generic way to translate a struct, array or slice.
|
||||
// Support for maps is also provided, but non-pointer values will not work.
|
||||
// The function also takes the entity with fields to translate.
|
||||
// The function supports nested structs and slices of structs.
|
||||
/*
|
||||
tr := NewTranslator("en", _domain, _fsys)
|
||||
|
||||
// a slice of translatables can be passed directly
|
||||
val := []string{"description", "display name"}
|
||||
err := tr.TranslateEntity(tr, s, val)
|
||||
|
||||
// string maps work the same way
|
||||
val := map[string]string{
|
||||
"entryOne": "description",
|
||||
"entryTwo": "display name",
|
||||
}
|
||||
err := TranslateEntity(tr, val)
|
||||
|
||||
// struct fields need to be specified
|
||||
type Struct struct {
|
||||
Description string
|
||||
DisplayName string
|
||||
MetaInformation string
|
||||
}
|
||||
val := Struct{}
|
||||
err := TranslateEntity(tr, val,
|
||||
l10n.TranslateField("Description"),
|
||||
l10n.TranslateField("DisplayName"),
|
||||
)
|
||||
|
||||
// nested structures are supported
|
||||
type InnerStruct struct {
|
||||
Description string
|
||||
Roles []string
|
||||
}
|
||||
type OuterStruct struct {
|
||||
DisplayName string
|
||||
First InnerStruct
|
||||
Others map[string]InnerStruct
|
||||
}
|
||||
val := OuterStruct{}
|
||||
err := TranslateEntity(tr, val,
|
||||
l10n.TranslateField("DisplayName"),
|
||||
l10n.TranslateStruct("First",
|
||||
l10n.TranslateField("Description"),
|
||||
l10n.TranslateEach("Roles"),
|
||||
),
|
||||
l10n.TranslateMap("Others",
|
||||
l10n.TranslateField("Description"),
|
||||
},
|
||||
*/
|
||||
func (t Translator) TranslateEntity(locale string, entity any, opts ...TranslateOption) error {
|
||||
return TranslateEntity(t.Locale(locale).Get, entity, opts...)
|
||||
}
|
||||
|
||||
// MustGetUserLocale returns the locale the user wants to use, omitting errors
|
||||
func MustGetUserLocale(ctx context.Context, userID string, preferedLang string, vc settingssvc.ValueService) string {
|
||||
if preferedLang != "" {
|
||||
return preferedLang
|
||||
}
|
||||
|
||||
locale, _ := GetUserLocale(ctx, userID, vc)
|
||||
return locale
|
||||
}
|
||||
|
||||
// GetUserLocale returns the locale of the user
|
||||
func GetUserLocale(ctx context.Context, userID string, vc settingssvc.ValueService) (string, error) {
|
||||
resp, err := vc.GetValueByUniqueIdentifiers(
|
||||
micrometadata.Set(ctx, middleware.AccountID, userID),
|
||||
&settingssvc.GetValueByUniqueIdentifiersRequest{
|
||||
AccountUuid: userID,
|
||||
// this defaults.SettingUUIDProfileLanguage. Copied here to avoid import cycles.
|
||||
SettingId: "aa8cfbe5-95d4-4f7e-a032-c3c01f5f062f",
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
val := resp.GetValue().GetValue().GetListValue().GetValues()
|
||||
if len(val) == 0 {
|
||||
return "", errors.New("no language setting found")
|
||||
}
|
||||
return val[0].GetStringValue(), nil
|
||||
}
|
||||
|
||||
// TranslateOption is used to specify fields in structs to translate
|
||||
type TranslateOption func() (string, FieldType, []TranslateOption)
|
||||
|
||||
// FieldType is used to specify the type of field to translate
|
||||
type FieldType int
|
||||
|
||||
const (
|
||||
// FieldTypeString is a string field
|
||||
FieldTypeString FieldType = iota
|
||||
// FieldTypeStruct is a struct field
|
||||
FieldTypeStruct
|
||||
// FieldTypeIterable is a slice or array field
|
||||
FieldTypeIterable
|
||||
// FieldTypeMap is a map field
|
||||
FieldTypeMap
|
||||
)
|
||||
|
||||
// TranslateField function provides the generic way to translate the necessary field in composite entities.
|
||||
func TranslateField(fieldName string) TranslateOption {
|
||||
return func() (string, FieldType, []TranslateOption) {
|
||||
return fieldName, FieldTypeString, nil
|
||||
}
|
||||
}
|
||||
|
||||
// TranslateStruct function provides the generic way to translate the nested fields in composite entities.
|
||||
func TranslateStruct(fieldName string, args ...TranslateOption) TranslateOption {
|
||||
return func() (string, FieldType, []TranslateOption) {
|
||||
return fieldName, FieldTypeStruct, args
|
||||
}
|
||||
}
|
||||
|
||||
// TranslateEach function provides the generic way to translate the necessary fields in slices or nested entities.
|
||||
func TranslateEach(fieldName string, args ...TranslateOption) TranslateOption {
|
||||
return func() (string, FieldType, []TranslateOption) {
|
||||
return fieldName, FieldTypeIterable, args
|
||||
}
|
||||
}
|
||||
|
||||
// TranslateMap function provides the generic way to translate the necessary fields in maps.
|
||||
func TranslateMap(fieldName string, args ...TranslateOption) TranslateOption {
|
||||
return func() (string, FieldType, []TranslateOption) {
|
||||
return fieldName, FieldTypeMap, args
|
||||
}
|
||||
}
|
||||
|
||||
// TranslateEntity translates a slice, array or struct
|
||||
// See Translator.TranslateEntity for more information
|
||||
func TranslateEntity(tr func(string, ...any) string, entity any, opts ...TranslateOption) error {
|
||||
value := reflect.ValueOf(entity)
|
||||
|
||||
value, ok := cleanValue(value)
|
||||
if !ok {
|
||||
return errors.New("entity is not valid")
|
||||
}
|
||||
|
||||
switch value.Kind() {
|
||||
case reflect.Struct:
|
||||
rangeOverArgs(tr, value, opts...)
|
||||
case reflect.Slice, reflect.Array, reflect.Map:
|
||||
translateEach(tr, value, opts...)
|
||||
case reflect.String:
|
||||
translateField(tr, value)
|
||||
default:
|
||||
return ErrUnsupportedType
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func translateEach(tr func(string, ...any) string, value reflect.Value, args ...TranslateOption) {
|
||||
value, ok := cleanValue(value)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
switch value.Kind() {
|
||||
case reflect.Array, reflect.Slice:
|
||||
for i := 0; i < value.Len(); i++ {
|
||||
v := value.Index(i)
|
||||
switch v.Kind() {
|
||||
case reflect.Struct, reflect.Ptr:
|
||||
rangeOverArgs(tr, v, args...)
|
||||
case reflect.String:
|
||||
translateField(tr, v)
|
||||
case reflect.Slice, reflect.Array, reflect.Map:
|
||||
translateEach(tr, v, args...)
|
||||
}
|
||||
}
|
||||
case reflect.Map:
|
||||
for _, k := range value.MapKeys() {
|
||||
v := value.MapIndex(k)
|
||||
switch v.Kind() {
|
||||
case reflect.Struct:
|
||||
// FIXME: add support for non-pointer values
|
||||
case reflect.Pointer:
|
||||
rangeOverArgs(tr, v, args...)
|
||||
case reflect.String:
|
||||
if nv := tr(v.String()); nv != "" {
|
||||
value.SetMapIndex(k, reflect.ValueOf(nv))
|
||||
}
|
||||
case reflect.Slice, reflect.Array, reflect.Map:
|
||||
translateEach(tr, v, args...)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func rangeOverArgs(tr func(string, ...any) string, value reflect.Value, args ...TranslateOption) {
|
||||
value, ok := cleanValue(value)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
for _, arg := range args {
|
||||
fieldName, fieldType, opts := arg()
|
||||
|
||||
switch fieldType {
|
||||
case FieldTypeString:
|
||||
f := value.FieldByName(fieldName)
|
||||
translateField(tr, f)
|
||||
case FieldTypeStruct:
|
||||
innerValue := value.FieldByName(fieldName)
|
||||
if !innerValue.IsValid() || !isStruct(innerValue) {
|
||||
return
|
||||
}
|
||||
rangeOverArgs(tr, innerValue, opts...)
|
||||
case FieldTypeIterable:
|
||||
innerValue := value.FieldByName(fieldName)
|
||||
if !innerValue.IsValid() {
|
||||
return
|
||||
}
|
||||
if kind := innerValue.Kind(); kind != reflect.Array && kind != reflect.Slice {
|
||||
return
|
||||
}
|
||||
translateEach(tr, innerValue, opts...)
|
||||
case FieldTypeMap:
|
||||
innerValue := value.FieldByName(fieldName)
|
||||
if !innerValue.IsValid() {
|
||||
return
|
||||
}
|
||||
if kind := innerValue.Kind(); kind != reflect.Map {
|
||||
return
|
||||
}
|
||||
translateEach(tr, innerValue, opts...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func translateField(tr func(string, ...any) string, f reflect.Value) {
|
||||
if f.IsValid() {
|
||||
if f.Kind() == reflect.Ptr {
|
||||
if f.IsNil() {
|
||||
return
|
||||
}
|
||||
f = f.Elem()
|
||||
}
|
||||
// A Value can be changed only if it is
|
||||
// addressable and was not obtained by
|
||||
// the use of unexported struct fields.
|
||||
if f.CanSet() {
|
||||
// change value
|
||||
if f.Kind() == reflect.String {
|
||||
val := tr(f.String())
|
||||
if val == "" {
|
||||
return
|
||||
}
|
||||
f.SetString(val)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isStruct(r reflect.Value) bool {
|
||||
if r.Kind() == reflect.Ptr {
|
||||
r = r.Elem()
|
||||
}
|
||||
return r.Kind() == reflect.Struct
|
||||
}
|
||||
|
||||
func cleanValue(v reflect.Value) (reflect.Value, bool) {
|
||||
if v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface {
|
||||
if v.IsNil() {
|
||||
return v, false
|
||||
}
|
||||
v = v.Elem()
|
||||
}
|
||||
if !v.IsValid() {
|
||||
return v, false
|
||||
}
|
||||
return v, true
|
||||
}
|
||||
@@ -0,0 +1,414 @@
|
||||
package l10n
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestTranslateStruct(t *testing.T) {
|
||||
|
||||
type InnerStruct struct {
|
||||
Description string
|
||||
DisplayName *string
|
||||
}
|
||||
|
||||
type TopLevelStruct struct {
|
||||
Description string
|
||||
DisplayName *string
|
||||
SubStruct *InnerStruct
|
||||
}
|
||||
|
||||
type WrapperStruct struct {
|
||||
Description string
|
||||
StructList []*InnerStruct
|
||||
}
|
||||
|
||||
toStrPointer := func(str string) *string {
|
||||
return &str
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
entity any
|
||||
args []TranslateOption
|
||||
expected any
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "top level slice of struct",
|
||||
entity: []*InnerStruct{
|
||||
{
|
||||
Description: "inner 1",
|
||||
DisplayName: toStrPointer("innerDisplayName 1"),
|
||||
},
|
||||
{
|
||||
Description: "inner 2",
|
||||
DisplayName: toStrPointer("innerDisplayName 2"),
|
||||
},
|
||||
},
|
||||
args: []TranslateOption{
|
||||
TranslateField("Description"),
|
||||
TranslateField("DisplayName"),
|
||||
},
|
||||
expected: []*InnerStruct{
|
||||
{
|
||||
Description: "new Inner 1",
|
||||
DisplayName: toStrPointer("new InnerDisplayName 1"),
|
||||
},
|
||||
{
|
||||
Description: "new Inner 2",
|
||||
DisplayName: toStrPointer("new InnerDisplayName 2"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "top level slice of string",
|
||||
entity: []string{
|
||||
"inner 1",
|
||||
"inner 2",
|
||||
},
|
||||
expected: []string{
|
||||
"new Inner 1",
|
||||
"new Inner 2",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "top level slice of struct",
|
||||
entity: []*TopLevelStruct{
|
||||
{
|
||||
Description: "inner 1",
|
||||
DisplayName: toStrPointer("innerDisplayName 1"),
|
||||
SubStruct: &InnerStruct{
|
||||
Description: "inner",
|
||||
DisplayName: toStrPointer("innerDisplayName"),
|
||||
},
|
||||
},
|
||||
{
|
||||
Description: "inner 2",
|
||||
DisplayName: toStrPointer("innerDisplayName 2"),
|
||||
},
|
||||
},
|
||||
args: []TranslateOption{
|
||||
TranslateField("Description"),
|
||||
TranslateField("DisplayName"),
|
||||
TranslateStruct("SubStruct",
|
||||
TranslateField("Description"),
|
||||
TranslateField("DisplayName"),
|
||||
),
|
||||
},
|
||||
expected: []*TopLevelStruct{
|
||||
{
|
||||
Description: "new Inner 1",
|
||||
DisplayName: toStrPointer("new InnerDisplayName 1"),
|
||||
SubStruct: &InnerStruct{
|
||||
Description: "new Inner",
|
||||
DisplayName: toStrPointer("new InnerDisplayName"),
|
||||
},
|
||||
},
|
||||
{
|
||||
Description: "new Inner 2",
|
||||
DisplayName: toStrPointer("new InnerDisplayName 2"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "wrapped struct full",
|
||||
entity: &WrapperStruct{
|
||||
StructList: []*InnerStruct{
|
||||
{
|
||||
Description: "inner 1",
|
||||
DisplayName: toStrPointer("innerDisplayName 1"),
|
||||
},
|
||||
{
|
||||
Description: "inner 2",
|
||||
DisplayName: toStrPointer("innerDisplayName 2"),
|
||||
},
|
||||
},
|
||||
},
|
||||
args: []TranslateOption{
|
||||
TranslateEach("StructList",
|
||||
TranslateField("Description"),
|
||||
TranslateField("DisplayName"),
|
||||
),
|
||||
},
|
||||
expected: &WrapperStruct{
|
||||
StructList: []*InnerStruct{
|
||||
{
|
||||
Description: "new Inner 1",
|
||||
DisplayName: toStrPointer("new InnerDisplayName 1"),
|
||||
},
|
||||
{
|
||||
Description: "new Inner 2",
|
||||
DisplayName: toStrPointer("new InnerDisplayName 2"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty struct, NotExistingSubStructName",
|
||||
entity: &TopLevelStruct{},
|
||||
args: []TranslateOption{
|
||||
TranslateField("Description"),
|
||||
TranslateField("DisplayName"),
|
||||
TranslateStruct("NotExistingSubStructName",
|
||||
TranslateField("Description"),
|
||||
TranslateField("DisplayName"),
|
||||
),
|
||||
},
|
||||
expected: &TopLevelStruct{},
|
||||
},
|
||||
{
|
||||
name: "empty struct",
|
||||
entity: &TopLevelStruct{},
|
||||
args: []TranslateOption{
|
||||
TranslateField("Description"),
|
||||
TranslateField("DisplayName"),
|
||||
TranslateStruct("SubStruct",
|
||||
TranslateField("Description"),
|
||||
TranslateField("DisplayName"),
|
||||
),
|
||||
},
|
||||
expected: &TopLevelStruct{},
|
||||
},
|
||||
{
|
||||
name: "empty struct, not existing field",
|
||||
entity: &TopLevelStruct{
|
||||
Description: "description",
|
||||
DisplayName: toStrPointer("displayName"),
|
||||
},
|
||||
args: []TranslateOption{
|
||||
TranslateField("NotExistingFieldName"),
|
||||
TranslateStruct("SubStruct",
|
||||
TranslateField("NotExistingFieldName"),
|
||||
),
|
||||
},
|
||||
expected: &TopLevelStruct{
|
||||
Description: "description",
|
||||
DisplayName: toStrPointer("displayName"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "inner struct DisplayName empy",
|
||||
entity: &TopLevelStruct{
|
||||
Description: "description",
|
||||
DisplayName: toStrPointer("displayName"),
|
||||
},
|
||||
args: []TranslateOption{
|
||||
TranslateField("Description"),
|
||||
TranslateField("DisplayName"),
|
||||
TranslateStruct("SubStruct",
|
||||
TranslateField("Description"),
|
||||
TranslateField("DisplayName"),
|
||||
),
|
||||
},
|
||||
expected: &TopLevelStruct{
|
||||
Description: "new Description",
|
||||
DisplayName: toStrPointer("new DisplayName"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "inner struct full",
|
||||
entity: &TopLevelStruct{
|
||||
Description: "description",
|
||||
DisplayName: toStrPointer("displayName"),
|
||||
},
|
||||
args: []TranslateOption{
|
||||
TranslateField("Description"),
|
||||
TranslateField("DisplayName"),
|
||||
TranslateStruct("SubStruct",
|
||||
TranslateField("Description"),
|
||||
TranslateField("DisplayName"),
|
||||
),
|
||||
},
|
||||
expected: &TopLevelStruct{
|
||||
Description: "new Description",
|
||||
DisplayName: toStrPointer("new DisplayName"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "full struct",
|
||||
entity: &TopLevelStruct{
|
||||
Description: "description",
|
||||
DisplayName: toStrPointer("displayName"),
|
||||
SubStruct: &InnerStruct{
|
||||
Description: "inner",
|
||||
DisplayName: toStrPointer("innerDisplayName"),
|
||||
},
|
||||
},
|
||||
args: []TranslateOption{
|
||||
TranslateField("Description"),
|
||||
TranslateField("DisplayName"),
|
||||
TranslateStruct("SubStruct",
|
||||
TranslateField("Description"),
|
||||
TranslateField("DisplayName"),
|
||||
),
|
||||
},
|
||||
expected: &TopLevelStruct{
|
||||
Description: "new Description",
|
||||
DisplayName: toStrPointer("new DisplayName"),
|
||||
SubStruct: &InnerStruct{
|
||||
Description: "new Inner",
|
||||
DisplayName: toStrPointer("new InnerDisplayName"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nil",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty slice",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "string slice",
|
||||
entity: []string{"description", "inner"},
|
||||
expected: []string{"new Description", "new Inner"},
|
||||
},
|
||||
{
|
||||
name: "string map",
|
||||
entity: map[string]string{
|
||||
"entryOne": "description",
|
||||
"entryTwo": "inner",
|
||||
},
|
||||
expected: map[string]string{
|
||||
"entryOne": "new Description",
|
||||
"entryTwo": "new Inner",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "pointer struct map",
|
||||
entity: map[string]*InnerStruct{
|
||||
"entryOne": {Description: "description", DisplayName: toStrPointer("displayName")},
|
||||
"entryTwo": {Description: "inner", DisplayName: toStrPointer("innerDisplayName")},
|
||||
},
|
||||
args: []TranslateOption{
|
||||
TranslateField("Description"),
|
||||
TranslateField("DisplayName"),
|
||||
},
|
||||
expected: map[string]*InnerStruct{
|
||||
"entryOne": {Description: "new Description", DisplayName: toStrPointer("new DisplayName")},
|
||||
"entryTwo": {Description: "new Inner", DisplayName: toStrPointer("new InnerDisplayName")},
|
||||
},
|
||||
},
|
||||
/* FIXME: non pointer maps are currently not working
|
||||
{
|
||||
name: "struct map",
|
||||
entity: map[string]InnerStruct{
|
||||
"entryOne": {Description: "description", DisplayName: toStrPointer("displayName")},
|
||||
"entryTwo": {Description: "inner", DisplayName: toStrPointer("innerDisplayName")},
|
||||
},
|
||||
args: []TranslateOption{
|
||||
TranslateField("Description"),
|
||||
TranslateField("DisplayName"),
|
||||
},
|
||||
expected: map[string]InnerStruct{
|
||||
"entryOne": {Description: "new Description", DisplayName: toStrPointer("new DisplayName")},
|
||||
"entryTwo": {Description: "new Inner", DisplayName: toStrPointer("new InnerDisplayName")},
|
||||
},
|
||||
},
|
||||
*/
|
||||
{
|
||||
name: "slice map",
|
||||
entity: map[string][]string{
|
||||
"entryOne": {"description", "inner"},
|
||||
"entryTwo": {"inner 2", "innerDisplayName 2"},
|
||||
},
|
||||
expected: map[string][]string{
|
||||
"entryOne": {"new Description", "new Inner"},
|
||||
"entryTwo": {"new Inner 2", "new InnerDisplayName 2"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "double slice",
|
||||
entity: [][]string{
|
||||
{"description", "inner"},
|
||||
{"inner 2", "innerDisplayName 2"},
|
||||
},
|
||||
expected: [][]string{
|
||||
{"new Description", "new Inner"},
|
||||
{"new Inner 2", "new InnerDisplayName 2"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nested structs",
|
||||
entity: [][]*InnerStruct{
|
||||
{
|
||||
&InnerStruct{Description: "description", DisplayName: toStrPointer("displayName")},
|
||||
&InnerStruct{Description: "inner", DisplayName: toStrPointer("innerDisplayName")},
|
||||
},
|
||||
{
|
||||
&InnerStruct{Description: "inner 2", DisplayName: toStrPointer("innerDisplayName 2")},
|
||||
},
|
||||
},
|
||||
args: []TranslateOption{
|
||||
TranslateField("Description"),
|
||||
TranslateField("DisplayName"),
|
||||
},
|
||||
expected: [][]*InnerStruct{
|
||||
{
|
||||
&InnerStruct{Description: "new Description", DisplayName: toStrPointer("new DisplayName")},
|
||||
&InnerStruct{Description: "new Inner", DisplayName: toStrPointer("new InnerDisplayName")},
|
||||
},
|
||||
{
|
||||
&InnerStruct{Description: "new Inner 2", DisplayName: toStrPointer("new InnerDisplayName 2")},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "double mapslices",
|
||||
entity: []map[string][]string{
|
||||
{
|
||||
"entryOne": {"inner 1", "innerDisplayName 1"},
|
||||
"entryTwo": {"inner 2", "innerDisplayName 2"},
|
||||
},
|
||||
{
|
||||
"entryOne": {"description", "displayName"},
|
||||
},
|
||||
},
|
||||
expected: []map[string][]string{
|
||||
{
|
||||
"entryOne": {"new Inner 1", "new InnerDisplayName 1"},
|
||||
"entryTwo": {"new Inner 2", "new InnerDisplayName 2"},
|
||||
},
|
||||
{
|
||||
"entryOne": {"new Description", "new DisplayName"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := TranslateEntity(mock(), tt.entity, tt.args...)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("TranslateEntity() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
assert.Equal(t, tt.expected, tt.entity)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func mock() func(string, ...interface{}) string {
|
||||
return func(s string, i ...interface{}) string {
|
||||
switch s {
|
||||
case "description":
|
||||
return "new Description"
|
||||
case "displayName":
|
||||
return "new DisplayName"
|
||||
case "inner":
|
||||
return "new Inner"
|
||||
case "innerDisplayName":
|
||||
return "new InnerDisplayName"
|
||||
case "inner 1":
|
||||
return "new Inner 1"
|
||||
case "innerDisplayName 1":
|
||||
return "new InnerDisplayName 1"
|
||||
case "inner 2":
|
||||
return "new Inner 2"
|
||||
case "innerDisplayName 2":
|
||||
return "new InnerDisplayName 2"
|
||||
}
|
||||
return s
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package ldap
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/opencloud-eu/opencloud/pkg/log"
|
||||
)
|
||||
|
||||
const (
|
||||
caCheckRetries = 3
|
||||
caCheckSleep = 2
|
||||
)
|
||||
|
||||
func WaitForCA(log log.Logger, insecure bool, caCert string) error {
|
||||
if !insecure && caCert != "" {
|
||||
for i := 0; i < caCheckRetries; i++ {
|
||||
if _, err := os.Stat(caCert); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return err
|
||||
}
|
||||
// Check if this actually is a CA cert. We need to retry here as well
|
||||
// as the file might exist already, but have no contents yet.
|
||||
certs := x509.NewCertPool()
|
||||
pemData, err := os.ReadFile(caCert)
|
||||
if err != nil {
|
||||
log.Debug().Err(err).Str("LDAP CACert", caCert).Msg("Error reading CA")
|
||||
} else if !certs.AppendCertsFromPEM(pemData) {
|
||||
log.Debug().Str("LDAP CAcert", caCert).Msg("Failed to append CA to pool")
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
time.Sleep(caCheckSleep * time.Second)
|
||||
log.Warn().Str("LDAP CACert", caCert).Msgf("CA cert file is not ready yet. Waiting %d seconds for it to appear.", caCheckSleep)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package ldap_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestLdap(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Ldap Suite")
|
||||
}
|
||||
+145
@@ -0,0 +1,145 @@
|
||||
package log
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
chimiddleware "github.com/go-chi/chi/v5/middleware"
|
||||
mzlog "github.com/go-micro/plugins/v4/logger/zerolog"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"go-micro.dev/v4/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
RequestIDString = "request-id"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// this is ugly, but "logger.DefaultLogger" is a global variable, and we need to set it _before_ anybody uses it
|
||||
setMicroLogger()
|
||||
}
|
||||
|
||||
// for logging reasons we don't want the same logging level on both oCIS and micro. As a framework builder we do not
|
||||
// want to expose to the end user the internal framework logs unless explicitly specified.
|
||||
func setMicroLogger() {
|
||||
if os.Getenv("MICRO_LOG_LEVEL") == "" {
|
||||
_ = os.Setenv("MICRO_LOG_LEVEL", "error")
|
||||
}
|
||||
|
||||
lev, err := zerolog.ParseLevel(os.Getenv("MICRO_LOG_LEVEL"))
|
||||
if err != nil {
|
||||
lev = zerolog.ErrorLevel
|
||||
}
|
||||
logger.DefaultLogger = mzlog.NewLogger(
|
||||
logger.WithLevel(logger.Level(lev)),
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"system": "go-micro",
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
// Logger simply wraps the zerolog logger.
|
||||
type Logger struct {
|
||||
zerolog.Logger
|
||||
}
|
||||
|
||||
// NopLogger initializes a no-operation logger.
|
||||
func NopLogger() Logger {
|
||||
return Logger{zerolog.Nop()}
|
||||
}
|
||||
|
||||
type LineInfoHook struct{}
|
||||
|
||||
// Run is a hook to add line info to log messages.
|
||||
// I found the zerolog example for this here:
|
||||
// https://github.com/rs/zerolog/issues/22#issuecomment-1127295489
|
||||
func (h LineInfoHook) Run(e *zerolog.Event, _ zerolog.Level, _ string) {
|
||||
_, file, line, ok := runtime.Caller(3)
|
||||
if ok {
|
||||
e.Str("line", fmt.Sprintf("%s:%d", file, line))
|
||||
}
|
||||
}
|
||||
|
||||
// NewLogger initializes a new logger instance.
|
||||
func NewLogger(opts ...Option) Logger {
|
||||
options := newOptions(opts...)
|
||||
|
||||
// set GlobalLevel() to the minimum value -1 = TraceLevel, so that only the services' log level matter
|
||||
zerolog.SetGlobalLevel(zerolog.TraceLevel)
|
||||
|
||||
var logLevel zerolog.Level
|
||||
switch strings.ToLower(options.Level) {
|
||||
case "panic":
|
||||
logLevel = zerolog.PanicLevel
|
||||
case "fatal":
|
||||
logLevel = zerolog.FatalLevel
|
||||
case "error":
|
||||
logLevel = zerolog.ErrorLevel
|
||||
case "warn":
|
||||
logLevel = zerolog.WarnLevel
|
||||
case "info":
|
||||
logLevel = zerolog.InfoLevel
|
||||
case "debug":
|
||||
logLevel = zerolog.DebugLevel
|
||||
case "trace":
|
||||
logLevel = zerolog.TraceLevel
|
||||
default:
|
||||
logLevel = zerolog.ErrorLevel
|
||||
}
|
||||
|
||||
var l zerolog.Logger
|
||||
|
||||
if options.Pretty {
|
||||
l = log.Output(
|
||||
zerolog.NewConsoleWriter(
|
||||
func(w *zerolog.ConsoleWriter) {
|
||||
w.TimeFormat = time.RFC3339
|
||||
w.Out = os.Stderr
|
||||
w.NoColor = !options.Color
|
||||
},
|
||||
),
|
||||
)
|
||||
} else if options.File != "" {
|
||||
f, err := os.OpenFile(options.File, os.O_CREATE|os.O_APPEND|os.O_RDWR, 0644)
|
||||
if err != nil {
|
||||
print(fmt.Sprintf("file could not be opened for writing: %s. error: %v", options.File, err))
|
||||
os.Exit(1)
|
||||
}
|
||||
l = l.Output(f)
|
||||
} else {
|
||||
l = zerolog.New(os.Stderr)
|
||||
}
|
||||
|
||||
l = l.With().
|
||||
Str("service", options.Name).
|
||||
Timestamp().
|
||||
Logger().Level(logLevel)
|
||||
|
||||
if logLevel <= zerolog.InfoLevel {
|
||||
var lineInfoHook LineInfoHook
|
||||
l = l.Hook(lineInfoHook)
|
||||
}
|
||||
|
||||
return Logger{
|
||||
l,
|
||||
}
|
||||
}
|
||||
|
||||
// SubloggerWithRequestID returns a sub-logger with the x-request-id added to all events
|
||||
func (l Logger) SubloggerWithRequestID(c context.Context) Logger {
|
||||
return Logger{
|
||||
l.With().Str(RequestIDString, chimiddleware.GetReqID(c)).Logger(),
|
||||
}
|
||||
}
|
||||
|
||||
// Deprecation logs a deprecation message,
|
||||
// it is used to inform the user that a certain feature is deprecated and will be removed in the future.
|
||||
// Do not use a logger here because the message MUST be visible independent of the log level.
|
||||
func Deprecation(a ...any) {
|
||||
fmt.Printf("\033[1;31mDEPRECATION: %s\033[0m\n", a...)
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
package log_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/onsi/gomega"
|
||||
|
||||
"github.com/opencloud-eu/opencloud/internal/testenv"
|
||||
"github.com/opencloud-eu/opencloud/pkg/log"
|
||||
)
|
||||
|
||||
func TestDeprecation(t *testing.T) {
|
||||
cmdTest := testenv.NewCMDTest(t.Name())
|
||||
if cmdTest.ShouldRun() {
|
||||
log.Deprecation("this is a deprecation")
|
||||
return
|
||||
}
|
||||
|
||||
out, err := cmdTest.Run()
|
||||
|
||||
g := gomega.NewWithT(t)
|
||||
g.Expect(err).ToNot(gomega.HaveOccurred())
|
||||
g.Expect(string(out)).To(gomega.HavePrefix("\033[1;31mDEPRECATION: this is a deprecation"))
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package log
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type levelMap map[logrus.Level]zerolog.Level
|
||||
|
||||
var levelMapping = levelMap{
|
||||
logrus.PanicLevel: zerolog.PanicLevel,
|
||||
logrus.ErrorLevel: zerolog.ErrorLevel,
|
||||
logrus.TraceLevel: zerolog.TraceLevel,
|
||||
logrus.DebugLevel: zerolog.DebugLevel,
|
||||
logrus.WarnLevel: zerolog.WarnLevel,
|
||||
logrus.InfoLevel: zerolog.InfoLevel,
|
||||
}
|
||||
|
||||
// LogrusWrapper around zerolog. Required because idp uses logrus internally.
|
||||
type LogrusWrapper struct {
|
||||
zeroLog *zerolog.Logger
|
||||
levelMap levelMap
|
||||
}
|
||||
|
||||
// LogrusWrap returns a logrus logger which internally logs to /dev/null. Messages are passed to the
|
||||
// underlying zerolog via hooks.
|
||||
func LogrusWrap(zr zerolog.Logger) *logrus.Logger {
|
||||
lr := logrus.New()
|
||||
lr.SetOutput(io.Discard)
|
||||
lr.SetLevel(logrusLevel(zr.GetLevel()))
|
||||
lr.AddHook(&LogrusWrapper{
|
||||
zeroLog: &zr,
|
||||
levelMap: levelMapping,
|
||||
})
|
||||
|
||||
return lr
|
||||
}
|
||||
|
||||
// Levels on which logrus hooks should be triggered
|
||||
func (h *LogrusWrapper) Levels() []logrus.Level {
|
||||
return logrus.AllLevels
|
||||
}
|
||||
|
||||
// Fire called by logrus on new message
|
||||
func (h *LogrusWrapper) Fire(entry *logrus.Entry) error {
|
||||
h.zeroLog.WithLevel(h.levelMap[entry.Level]).
|
||||
Fields(zeroLogFields(entry.Data)).
|
||||
Msg(entry.Message)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Convert logrus fields to zerolog
|
||||
func zeroLogFields(fields logrus.Fields) map[string]interface{} {
|
||||
fm := make(map[string]interface{})
|
||||
for k, v := range fields {
|
||||
fm[k] = v
|
||||
}
|
||||
|
||||
return fm
|
||||
}
|
||||
|
||||
// Convert logrus level to zerolog
|
||||
func logrusLevel(level zerolog.Level) logrus.Level {
|
||||
for lrLvl, zrLvl := range levelMapping {
|
||||
if zrLvl == level {
|
||||
return lrLvl
|
||||
}
|
||||
}
|
||||
|
||||
panic("Unexpected loglevel")
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
package log
|
||||
|
||||
// Option defines a single option function.
|
||||
type Option func(o *Options)
|
||||
|
||||
// Options defines the available options for this package.
|
||||
type Options struct {
|
||||
Name string
|
||||
Level string
|
||||
Pretty bool
|
||||
Color bool
|
||||
File string
|
||||
}
|
||||
|
||||
// newOptions initializes the available default options.
|
||||
func newOptions(opts ...Option) Options {
|
||||
opt := Options{
|
||||
Name: "ocis",
|
||||
Level: "info",
|
||||
Pretty: true,
|
||||
Color: true,
|
||||
}
|
||||
|
||||
for _, o := range opts {
|
||||
o(&opt)
|
||||
}
|
||||
|
||||
return opt
|
||||
}
|
||||
|
||||
// Name provides a function to set the name option.
|
||||
func Name(val string) Option {
|
||||
return func(o *Options) {
|
||||
o.Name = val
|
||||
}
|
||||
}
|
||||
|
||||
// Level provides a function to set the level option.
|
||||
func Level(val string) Option {
|
||||
return func(o *Options) {
|
||||
o.Level = val
|
||||
}
|
||||
}
|
||||
|
||||
// Pretty provides a function to set the pretty option.
|
||||
func Pretty(val bool) Option {
|
||||
return func(o *Options) {
|
||||
o.Pretty = val
|
||||
}
|
||||
}
|
||||
|
||||
// Color provides a function to set the color option.
|
||||
func Color(val bool) Option {
|
||||
return func(o *Options) {
|
||||
o.Color = val
|
||||
}
|
||||
}
|
||||
|
||||
// File provides a function to set the file option.
|
||||
func File(val string) Option {
|
||||
return func(o *Options) {
|
||||
o.File = val
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,139 @@
|
||||
// Package markdown allows reading and editing Markdown files
|
||||
package markdown
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Heading represents a markdown Heading
|
||||
type Heading struct {
|
||||
Level int // the level of the heading. 1 means it's the H1
|
||||
Content string // the text of the heading
|
||||
Header string // the heading itself
|
||||
}
|
||||
|
||||
// MD represents a markdown file
|
||||
type MD struct {
|
||||
Headings []Heading
|
||||
}
|
||||
|
||||
// Bytes returns the markdown as []bytes, ignoring errors
|
||||
func (md MD) Bytes() []byte {
|
||||
var b bytes.Buffer
|
||||
_, _ = md.WriteContent(&b)
|
||||
return b.Bytes()
|
||||
}
|
||||
|
||||
// String returns the markdown as string, ignoring errors
|
||||
func (md MD) String() string {
|
||||
var b strings.Builder
|
||||
_, _ = md.WriteContent(&b)
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// TocBytes returns the table of contents as []byte, ignoring errors
|
||||
func (md MD) TocBytes() []byte {
|
||||
var b bytes.Buffer
|
||||
_, _ = md.WriteToc(&b)
|
||||
return b.Bytes()
|
||||
}
|
||||
|
||||
// TocString returns the table of contents as string, ignoring errors
|
||||
func (md MD) TocString() string {
|
||||
var b strings.Builder
|
||||
_, _ = md.WriteToc(&b)
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// WriteContent writes the MDs content to the given writer
|
||||
func (md MD) WriteContent(w io.Writer) (int64, error) {
|
||||
written := int64(0)
|
||||
write := func(s string) error {
|
||||
n, err := w.Write([]byte(s))
|
||||
written += int64(n)
|
||||
return err
|
||||
}
|
||||
for _, h := range md.Headings {
|
||||
if err := write(strings.Repeat("#", h.Level) + " " + h.Header + "\n"); err != nil {
|
||||
return written, err
|
||||
}
|
||||
if len(h.Content) > 0 {
|
||||
if err := write(h.Content); err != nil {
|
||||
return written, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return written, nil
|
||||
}
|
||||
|
||||
// WriteToc writes the table of contents to the given writer
|
||||
func (md MD) WriteToc(w io.Writer) (int64, error) {
|
||||
var written int64
|
||||
for _, h := range md.Headings {
|
||||
if h.Level == 1 {
|
||||
// main title not in toc
|
||||
continue
|
||||
}
|
||||
link := fmt.Sprintf("#%s", strings.ToLower(strings.Replace(h.Header, " ", "-", -1)))
|
||||
s := fmt.Sprintf("%s* [%s](%s)\n", strings.Repeat(" ", h.Level-2), h.Header, link)
|
||||
n, err := w.Write([]byte(s))
|
||||
if err != nil {
|
||||
return written, err
|
||||
}
|
||||
written += int64(n)
|
||||
}
|
||||
return written, nil
|
||||
}
|
||||
|
||||
// NewMD parses a new Markdown
|
||||
func NewMD(b []byte) MD {
|
||||
var (
|
||||
md MD
|
||||
heading Heading
|
||||
content strings.Builder
|
||||
)
|
||||
sendHeading := func() {
|
||||
if heading.Header != "" {
|
||||
heading.Content = content.String()
|
||||
md.Headings = append(md.Headings, heading)
|
||||
content = strings.Builder{}
|
||||
}
|
||||
}
|
||||
parts := strings.Split("\n"+string(b), "\n#")
|
||||
numparts := len(parts) - 1
|
||||
for i, p := range parts {
|
||||
if i == 0 {
|
||||
// omit part before first heading
|
||||
continue
|
||||
}
|
||||
|
||||
all := strings.SplitN(p, "\n", 2)
|
||||
if len(all) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
head, con := all[0], all[1]
|
||||
// readd lost "#"
|
||||
heading = headingFromString("#" + head)
|
||||
_, _ = content.WriteString(con)
|
||||
// readd lost "\n" - omit for last part
|
||||
if i < numparts {
|
||||
_, _ = content.WriteString("\n")
|
||||
}
|
||||
// add heading
|
||||
sendHeading()
|
||||
}
|
||||
return md
|
||||
}
|
||||
|
||||
func headingFromString(s string) Heading {
|
||||
i := strings.LastIndex(s, "#")
|
||||
levs, con := s[:i+1], s[i+1:]
|
||||
return Heading{
|
||||
Level: len(levs),
|
||||
Header: strings.TrimPrefix(con, " "),
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package markdown
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestSearch(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Markdown Suite")
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package markdown
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var (
|
||||
SmallMarkdown = `# Title
|
||||
|
||||
some abstract description
|
||||
|
||||
## SubTitle 1
|
||||
|
||||
subtitle one description
|
||||
|
||||
## SubTitle 2
|
||||
subtitle two description
|
||||
### Subpoint to SubTitle 2
|
||||
|
||||
description to subpoint
|
||||
|
||||
more text
|
||||
`
|
||||
SmallMD = MD{
|
||||
Headings: []Heading{
|
||||
{Level: 1, Header: "Title", Content: "\nsome abstract description\n\n"},
|
||||
{Level: 2, Header: "SubTitle 1", Content: "\nsubtitle one description\n\n"},
|
||||
{Level: 2, Header: "SubTitle 2", Content: "subtitle two description\n"},
|
||||
{Level: 3, Header: "Subpoint to SubTitle 2", Content: "\ndescription to subpoint\n\nmore text\n"},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
var _ = Describe("TestMarkdown", func() {
|
||||
DescribeTable("Conversion works both ways",
|
||||
func(mdfile string, expectedMD MD) {
|
||||
md := NewMD([]byte(mdfile))
|
||||
|
||||
Expect(len(md.Headings)).To(Equal(len(expectedMD.Headings)))
|
||||
for i, h := range md.Headings {
|
||||
Expect(h).To(Equal(expectedMD.Headings[i]))
|
||||
}
|
||||
Expect(md.String()).To(Equal(mdfile))
|
||||
},
|
||||
Entry("converts a small markdown", SmallMarkdown, SmallMD),
|
||||
)
|
||||
})
|
||||
@@ -0,0 +1,78 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/cs3org/reva/v2/pkg/auth/scope"
|
||||
|
||||
revactx "github.com/cs3org/reva/v2/pkg/ctx"
|
||||
"github.com/cs3org/reva/v2/pkg/token/manager/jwt"
|
||||
"github.com/opencloud-eu/opencloud/pkg/account"
|
||||
"go-micro.dev/v4/metadata"
|
||||
)
|
||||
|
||||
// newAccountOptions initializes the available default options.
|
||||
func newAccountOptions(opts ...account.Option) account.Options {
|
||||
opt := account.Options{}
|
||||
|
||||
for _, o := range opts {
|
||||
o(&opt)
|
||||
}
|
||||
|
||||
return opt
|
||||
}
|
||||
|
||||
// AccountID serves as key for the account uuid in the context
|
||||
const AccountID string = "Account-Id"
|
||||
|
||||
// RoleIDs serves as key for the roles in the context
|
||||
const RoleIDs string = "Role-Ids"
|
||||
|
||||
// ExtractAccountUUID provides a middleware to extract the account uuid from the x-access-token header value
|
||||
// and write it to the context. If there is no x-access-token the middleware is omitted.
|
||||
func ExtractAccountUUID(opts ...account.Option) func(http.Handler) http.Handler {
|
||||
opt := newAccountOptions(opts...)
|
||||
tokenManager, err := jwt.New(map[string]interface{}{
|
||||
"secret": opt.JWTSecret,
|
||||
"expires": int64(24 * 60 * 60),
|
||||
})
|
||||
if err != nil {
|
||||
opt.Logger.Fatal().Err(err).Msgf("Could not initialize token-manager")
|
||||
}
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token := r.Header.Get("x-access-token")
|
||||
if len(token) == 0 {
|
||||
roleIDsJSON, _ := json.Marshal([]string{})
|
||||
ctx := metadata.Set(r.Context(), RoleIDs, string(roleIDsJSON))
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
|
||||
u, tokenScope, err := tokenManager.DismantleToken(r.Context(), token)
|
||||
if err != nil {
|
||||
opt.Logger.Error().Err(err)
|
||||
return
|
||||
}
|
||||
if ok, err := scope.VerifyScope(r.Context(), tokenScope, r); err != nil || !ok {
|
||||
opt.Logger.Error().Err(err).Msg("verifying scope failed")
|
||||
return
|
||||
}
|
||||
|
||||
// store user in context for request
|
||||
ctx := revactx.ContextSetUser(r.Context(), u)
|
||||
|
||||
// Important: user.Id.OpaqueId is the AccountUUID. Set this way in the account uuid middleware in ocis-proxy.
|
||||
// https://github.com/opencloud-eu/opencloud-proxy/blob/ea254d6036592cf9469d757d1295e0c4309d1e63/pkg/middleware/account_uuid.go#L109
|
||||
// TODO: implement token manager in cs3org/reva that uses generic metadata instead of access token from header.
|
||||
ctx = metadata.Set(ctx, AccountID, u.Id.OpaqueId)
|
||||
if u.Opaque != nil {
|
||||
if roles, ok := u.Opaque.Map["roles"]; ok {
|
||||
ctx = metadata.Set(ctx, RoleIDs, string(roles.Value))
|
||||
}
|
||||
}
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/opencloud-eu/opencloud/pkg/cors"
|
||||
|
||||
rscors "github.com/rs/cors"
|
||||
)
|
||||
|
||||
// NoCache writes required cache headers to all requests.
|
||||
func NoCache(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate, value")
|
||||
w.Header().Set("Expires", "Thu, 01 Jan 1970 00:00:00 GMT")
|
||||
w.Header().Set("Last-Modified", time.Now().UTC().Format(http.TimeFormat))
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// Cors writes required cors headers to all requests.
|
||||
func Cors(opts ...cors.Option) func(http.Handler) http.Handler {
|
||||
options := cors.NewOptions(opts...)
|
||||
logger := options.Logger
|
||||
logger.Debug().
|
||||
Str("allowed_origins", strings.Join(options.AllowedOrigins, ", ")).
|
||||
Str("allowed_methods", strings.Join(options.AllowedMethods, ", ")).
|
||||
Str("allowed_headers", strings.Join(options.AllowedHeaders, ", ")).
|
||||
Bool("allow_credentials", options.AllowCredentials).
|
||||
Msg("setup cors middleware")
|
||||
c := rscors.New(rscors.Options{
|
||||
AllowedOrigins: options.AllowedOrigins,
|
||||
AllowedMethods: options.AllowedMethods,
|
||||
AllowedHeaders: options.AllowedHeaders,
|
||||
AllowCredentials: options.AllowCredentials,
|
||||
})
|
||||
return c.Handler
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/opencloud-eu/opencloud/pkg/log"
|
||||
)
|
||||
|
||||
// Logger is a middleware to log http requests. It uses debug level logging and should be used by all services save the proxy (which uses info level logging).
|
||||
func Logger(logger log.Logger) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
wrap := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
|
||||
next.ServeHTTP(wrap, r)
|
||||
|
||||
logger.Debug().
|
||||
Str(log.RequestIDString, r.Header.Get("X-Request-ID")).
|
||||
Str("proto", r.Proto).
|
||||
Str("method", r.Method).
|
||||
Int("status", wrap.Status()).
|
||||
Str("path", r.URL.Path).
|
||||
Dur("duration", time.Since(start)).
|
||||
Int("bytes", wrap.BytesWritten()).
|
||||
Msg("")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
goidc "github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/opencloud-eu/opencloud/pkg/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// newOidcOptions initializes the available default options.
|
||||
func newOidcOptions(opts ...Option) Options {
|
||||
opt := Options{}
|
||||
|
||||
for _, o := range opts {
|
||||
o(&opt)
|
||||
}
|
||||
|
||||
return opt
|
||||
}
|
||||
|
||||
// OIDCProvider used to mock the oidc provider during tests
|
||||
type OIDCProvider interface {
|
||||
UserInfo(ctx context.Context, ts oauth2.TokenSource) (*goidc.UserInfo, error)
|
||||
}
|
||||
|
||||
// OidcAuth provides a middleware to authenticate a bearer auth with an OpenID Connect identity provider
|
||||
// It will put all claims provided by the userinfo endpoint in the context
|
||||
func OidcAuth(opts ...Option) func(http.Handler) http.Handler {
|
||||
opt := newOidcOptions(opts...)
|
||||
|
||||
// TODO use a micro store cache option
|
||||
|
||||
providerFunc := func() (OIDCProvider, error) {
|
||||
// Initialize a provider by specifying the issuer URL.
|
||||
// it will fetch the keys from the issuer using the .well-known
|
||||
// endpoint
|
||||
return goidc.NewProvider(
|
||||
context.WithValue(context.Background(), oauth2.HTTPClient, &opt.HttpClient),
|
||||
opt.OidcIssuer,
|
||||
)
|
||||
}
|
||||
var provider OIDCProvider
|
||||
initializeProviderLock := sync.Mutex{}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
switch {
|
||||
case strings.HasPrefix(authHeader, "Bearer "):
|
||||
if provider == nil {
|
||||
// lazy initialize provider
|
||||
initializeProviderLock.Lock()
|
||||
var err error
|
||||
// ensure no other request initialized the provider
|
||||
if provider == nil {
|
||||
provider, err = providerFunc()
|
||||
}
|
||||
initializeProviderLock.Unlock()
|
||||
if err != nil {
|
||||
opt.Logger.Error().Err(err).Msg("could not initialize OIDC provider")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
opt.Logger.Debug().Msg("initialized OIDC provider")
|
||||
}
|
||||
|
||||
oauth2Token := &oauth2.Token{
|
||||
AccessToken: strings.TrimPrefix(authHeader, "Bearer "),
|
||||
}
|
||||
|
||||
userInfo, err := provider.UserInfo(
|
||||
context.WithValue(ctx, oauth2.HTTPClient, &opt.HttpClient),
|
||||
oauth2.StaticTokenSource(oauth2Token),
|
||||
)
|
||||
if err != nil {
|
||||
w.Header().Add("WWW-Authenticate", `Bearer`)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
claims := map[string]interface{}{}
|
||||
err = userInfo.Claims(&claims)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
ctx = oidc.NewContext(ctx, claims)
|
||||
|
||||
default:
|
||||
// do nothing
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
gatewayv1beta1 "github.com/cs3org/go-cs3apis/cs3/gateway/v1beta1"
|
||||
"github.com/opencloud-eu/opencloud/pkg/log"
|
||||
)
|
||||
|
||||
// Option defines a single option function.
|
||||
type Option func(o *Options)
|
||||
|
||||
// Options defines the available options for this package.
|
||||
type Options struct {
|
||||
// Logger to use for logging, must be set
|
||||
Logger log.Logger
|
||||
// The OpenID Connect Issuer URL
|
||||
OidcIssuer string
|
||||
// GatewayAPIClient is a reva gateway client
|
||||
GatewayAPIClient gatewayv1beta1.GatewayAPIClient
|
||||
// HttpClient is a http client
|
||||
HttpClient http.Client
|
||||
}
|
||||
|
||||
// WithLogger provides a function to set the openid connect issuer option.
|
||||
func WithOidcIssuer(val string) Option {
|
||||
return func(o *Options) {
|
||||
o.OidcIssuer = val
|
||||
}
|
||||
}
|
||||
|
||||
// WithLogger provides a function to set the logger option.
|
||||
func WithLogger(val log.Logger) Option {
|
||||
return func(o *Options) {
|
||||
o.Logger = val
|
||||
}
|
||||
}
|
||||
|
||||
// WithGatewayAPIClient provides a function to set the reva gateway client option.
|
||||
func WithGatewayAPIClient(val gatewayv1beta1.GatewayAPIClient) Option {
|
||||
return func(o *Options) {
|
||||
o.GatewayAPIClient = val
|
||||
}
|
||||
}
|
||||
|
||||
// HttpClient provides a function to set the http client option.
|
||||
func WithHttpClient(val http.Client) Option {
|
||||
return func(o *Options) {
|
||||
o.HttpClient = val
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Static is a middleware that serves static assets.
|
||||
func Static(root string, fs http.FileSystem, ttl int) func(http.Handler) http.Handler {
|
||||
if !strings.HasSuffix(root, "/") {
|
||||
root = root + "/"
|
||||
}
|
||||
|
||||
static := http.StripPrefix(
|
||||
root,
|
||||
http.FileServer(
|
||||
fs,
|
||||
),
|
||||
)
|
||||
|
||||
// TODO: investigate broken caching - https://github.com/opencloud-eu/opencloud/issues/1094
|
||||
// we don't have a last modification date of the static assets, so we use the service start date
|
||||
//lastModified := time.Now().UTC().Format(http.TimeFormat)
|
||||
//expires := time.Now().Add(time.Second * time.Duration(ttl)).UTC().Format(http.TimeFormat)
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.HasPrefix(r.URL.Path, path.Join(root, "api")) {
|
||||
next.ServeHTTP(w, r)
|
||||
} else {
|
||||
// TODO: investigate broken caching - https://github.com/opencloud-eu/opencloud/issues/1094
|
||||
//w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%s, must-revalidate", strconv.Itoa(ttl)))
|
||||
//w.Header().Set("Expires", expires)
|
||||
//w.Header().Set("Last-Modified", lastModified)
|
||||
w.Header().Set("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate, value")
|
||||
w.Header().Set("Expires", "Thu, 01 Jan 1970 00:00:00 GMT")
|
||||
w.Header().Set("Last-Modified", time.Now().UTC().Format(http.TimeFormat))
|
||||
static.ServeHTTP(w, r)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
)
|
||||
|
||||
// Throttle limits the number of concurrent requests.
|
||||
func Throttle(limit int) func(http.Handler) http.Handler {
|
||||
if limit > 0 {
|
||||
return middleware.Throttle(limit)
|
||||
}
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"go.opentelemetry.io/otel/propagation"
|
||||
)
|
||||
|
||||
var propagator = propagation.NewCompositeTextMapPropagator(
|
||||
propagation.Baggage{},
|
||||
propagation.TraceContext{},
|
||||
)
|
||||
|
||||
// TraceContext unpacks the request context looking for an existing trace id.
|
||||
func TraceContext(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := propagator.Extract(r.Context(), propagation.HeaderCarrier(r.Header))
|
||||
propagator.Inject(ctx, propagation.HeaderCarrier(r.Header))
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Version writes the current version to the headers.
|
||||
func Version(name, version string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set(
|
||||
fmt.Sprintf("X-%s-VERSION", strings.ToUpper(name)),
|
||||
version,
|
||||
)
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
package natsjsregistry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"go-micro.dev/v4/registry"
|
||||
"go-micro.dev/v4/store"
|
||||
)
|
||||
|
||||
type storeOptionsKey struct{}
|
||||
type defaultTTLKey struct{}
|
||||
|
||||
// StoreOptions sets the options for the underlying store
|
||||
func StoreOptions(opts []store.Option) registry.Option {
|
||||
return func(o *registry.Options) {
|
||||
if o.Context == nil {
|
||||
o.Context = context.Background()
|
||||
}
|
||||
o.Context = context.WithValue(o.Context, storeOptionsKey{}, opts)
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultTTL allows setting a default register TTL for services
|
||||
func DefaultTTL(t time.Duration) registry.Option {
|
||||
return func(o *registry.Options) {
|
||||
if o.Context == nil {
|
||||
o.Context = context.Background()
|
||||
}
|
||||
o.Context = context.WithValue(o.Context, defaultTTLKey{}, t)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,221 @@
|
||||
// Package natsjsregistry implements a registry using natsjs kv store
|
||||
package natsjsregistry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
natsjskv "github.com/go-micro/plugins/v4/store/nats-js-kv"
|
||||
"github.com/nats-io/nats.go"
|
||||
"go-micro.dev/v4/registry"
|
||||
"go-micro.dev/v4/server"
|
||||
"go-micro.dev/v4/store"
|
||||
"go-micro.dev/v4/util/cmd"
|
||||
)
|
||||
|
||||
var (
|
||||
_registryName = "nats-js-kv"
|
||||
_registryAddressEnv = "MICRO_REGISTRY_ADDRESS"
|
||||
_registryUsernameEnv = "MICRO_REGISTRY_AUTH_USERNAME"
|
||||
_registryPasswordEnv = "MICRO_REGISTRY_AUTH_PASSWORD"
|
||||
|
||||
_serviceDelimiter = "@"
|
||||
)
|
||||
|
||||
func init() {
|
||||
cmd.DefaultRegistries[_registryName] = NewRegistry
|
||||
}
|
||||
|
||||
// NewRegistry returns a new natsjs registry
|
||||
func NewRegistry(opts ...registry.Option) registry.Registry {
|
||||
options := registry.Options{
|
||||
Context: context.Background(),
|
||||
}
|
||||
for _, o := range opts {
|
||||
o(&options)
|
||||
}
|
||||
defaultTTL, _ := options.Context.Value(defaultTTLKey{}).(time.Duration)
|
||||
n := &storeregistry{
|
||||
opts: options,
|
||||
typ: _registryName,
|
||||
defaultTTL: defaultTTL,
|
||||
}
|
||||
n.store = natsjskv.NewStore(n.storeOptions(options)...)
|
||||
return n
|
||||
}
|
||||
|
||||
type storeregistry struct {
|
||||
opts registry.Options
|
||||
store store.Store
|
||||
typ string
|
||||
defaultTTL time.Duration
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
// Init inits the registry
|
||||
func (n *storeregistry) Init(opts ...registry.Option) error {
|
||||
n.lock.Lock()
|
||||
defer n.lock.Unlock()
|
||||
|
||||
for _, o := range opts {
|
||||
o(&n.opts)
|
||||
}
|
||||
n.store = natsjskv.NewStore(n.storeOptions(n.opts)...)
|
||||
return n.store.Init(n.storeOptions(n.opts)...)
|
||||
}
|
||||
|
||||
// Options returns the configured options
|
||||
func (n *storeregistry) Options() registry.Options {
|
||||
return n.opts
|
||||
}
|
||||
|
||||
// Register adds a service to the registry
|
||||
func (n *storeregistry) Register(s *registry.Service, opts ...registry.RegisterOption) error {
|
||||
n.lock.RLock()
|
||||
defer n.lock.RUnlock()
|
||||
|
||||
if s == nil {
|
||||
return errors.New("wont store nil service")
|
||||
}
|
||||
|
||||
var options registry.RegisterOptions
|
||||
options.TTL = n.defaultTTL
|
||||
for _, o := range opts {
|
||||
o(&options)
|
||||
}
|
||||
|
||||
b, err := json.Marshal(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return n.store.Write(&store.Record{
|
||||
Key: s.Name + _serviceDelimiter + server.DefaultId + _serviceDelimiter + s.Version,
|
||||
Value: b,
|
||||
Expiry: options.TTL,
|
||||
})
|
||||
}
|
||||
|
||||
// Deregister removes a service from the registry.
|
||||
func (n *storeregistry) Deregister(s *registry.Service, _ ...registry.DeregisterOption) error {
|
||||
n.lock.RLock()
|
||||
defer n.lock.RUnlock()
|
||||
return n.store.Delete(s.Name + _serviceDelimiter + server.DefaultId + _serviceDelimiter + s.Version)
|
||||
}
|
||||
|
||||
// GetService gets a specific service from the registry
|
||||
func (n *storeregistry) GetService(s string, _ ...registry.GetOption) ([]*registry.Service, error) {
|
||||
// avoid listing e.g. `webfinger` when requesting `web` by adding the delimiter to the service name
|
||||
return n.listServices(store.ListPrefix(s + _serviceDelimiter))
|
||||
}
|
||||
|
||||
// ListServices lists all registered services
|
||||
func (n *storeregistry) ListServices(...registry.ListOption) ([]*registry.Service, error) {
|
||||
return n.listServices()
|
||||
}
|
||||
|
||||
// Watch allowes following the changes in the registry if it would be implemented
|
||||
func (n *storeregistry) Watch(...registry.WatchOption) (registry.Watcher, error) {
|
||||
return NewWatcher(n)
|
||||
}
|
||||
|
||||
// String returns the name of the registry
|
||||
func (n *storeregistry) String() string {
|
||||
return n.typ
|
||||
}
|
||||
|
||||
func (n *storeregistry) listServices(opts ...store.ListOption) ([]*registry.Service, error) {
|
||||
n.lock.RLock()
|
||||
defer n.lock.RUnlock()
|
||||
|
||||
keys, err := n.store.List(opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
versions := map[string]*registry.Service{}
|
||||
for _, k := range keys {
|
||||
s, err := n.getNode(k)
|
||||
if err != nil {
|
||||
// TODO: continue ?
|
||||
return nil, err
|
||||
}
|
||||
if versions[s.Version] == nil {
|
||||
versions[s.Version] = s
|
||||
} else {
|
||||
versions[s.Version].Nodes = append(versions[s.Version].Nodes, s.Nodes...)
|
||||
}
|
||||
}
|
||||
svcs := make([]*registry.Service, 0, len(versions))
|
||||
for _, s := range versions {
|
||||
svcs = append(svcs, s)
|
||||
}
|
||||
return svcs, nil
|
||||
}
|
||||
|
||||
// getNode retrieves a node from the store. It returns a service to also keep track of the version.
|
||||
func (n *storeregistry) getNode(s string) (*registry.Service, error) {
|
||||
recs, err := n.store.Read(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(recs) == 0 {
|
||||
return nil, registry.ErrNotFound
|
||||
}
|
||||
var svc registry.Service
|
||||
if err := json.Unmarshal(recs[0].Value, &svc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &svc, nil
|
||||
}
|
||||
|
||||
func (n *storeregistry) storeOptions(opts registry.Options) []store.Option {
|
||||
storeoptions := []store.Option{
|
||||
store.Database("service-registry"),
|
||||
store.Table("service-registry"),
|
||||
natsjskv.DefaultMemory(),
|
||||
natsjskv.EncodeKeys(),
|
||||
}
|
||||
|
||||
if defaultTTL, ok := opts.Context.Value(defaultTTLKey{}).(time.Duration); ok {
|
||||
storeoptions = append(storeoptions, natsjskv.DefaultTTL(defaultTTL))
|
||||
}
|
||||
|
||||
addr := []string{"127.0.0.1:9233"}
|
||||
if len(opts.Addrs) > 0 {
|
||||
addr = opts.Addrs
|
||||
} else if a := strings.Split(os.Getenv(_registryAddressEnv), ","); len(a) > 0 && a[0] != "" {
|
||||
addr = a
|
||||
}
|
||||
storeoptions = append(storeoptions, store.Nodes(addr...))
|
||||
|
||||
natsOptions := nats.GetDefaultOptions()
|
||||
natsOptions.Name = "nats-js-kv-registry"
|
||||
natsOptions.User, natsOptions.Password = getAuth()
|
||||
natsOptions.ReconnectedCB = func(_ *nats.Conn) {
|
||||
if err := n.Init(); err != nil {
|
||||
fmt.Println("cannot reconnect to nats")
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
natsOptions.ClosedCB = func(_ *nats.Conn) {
|
||||
fmt.Println("nats connection closed")
|
||||
os.Exit(1)
|
||||
}
|
||||
storeoptions = append(storeoptions, natsjskv.NatsOptions(natsOptions))
|
||||
|
||||
if so, ok := opts.Context.Value(storeOptionsKey{}).([]store.Option); ok {
|
||||
storeoptions = append(storeoptions, so...)
|
||||
}
|
||||
|
||||
return storeoptions
|
||||
}
|
||||
|
||||
func getAuth() (string, string) {
|
||||
return os.Getenv(_registryUsernameEnv), os.Getenv(_registryPasswordEnv)
|
||||
}
|
||||
@@ -0,0 +1,78 @@
|
||||
package natsjsregistry
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
natsjskv "github.com/go-micro/plugins/v4/store/nats-js-kv"
|
||||
"github.com/nats-io/nats.go"
|
||||
"go-micro.dev/v4/registry"
|
||||
)
|
||||
|
||||
// NatsWatcher is the watcher of the nats interface
|
||||
type NatsWatcher interface {
|
||||
WatchAll(bucket string, opts ...nats.WatchOpt) (<-chan *natsjskv.StoreUpdate, func() error, error)
|
||||
}
|
||||
|
||||
// Watcher is used to keep track of changes in the registry
|
||||
type Watcher struct {
|
||||
updates <-chan *natsjskv.StoreUpdate
|
||||
stop func() error
|
||||
reg *storeregistry
|
||||
}
|
||||
|
||||
// NewWatcher returns a new watcher
|
||||
func NewWatcher(s *storeregistry) (*Watcher, error) {
|
||||
w, ok := s.store.(NatsWatcher)
|
||||
if !ok {
|
||||
return nil, errors.New("store does not implement watcher interface")
|
||||
}
|
||||
|
||||
watcher, stop, err := w.WatchAll("service-registry")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Watcher{
|
||||
updates: watcher,
|
||||
stop: stop,
|
||||
reg: s,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Next returns the next result. It is a blocking call
|
||||
func (w *Watcher) Next() (*registry.Result, error) {
|
||||
kve := <-w.updates
|
||||
if kve == nil {
|
||||
return nil, errors.New("watcher stopped")
|
||||
}
|
||||
|
||||
var svc registry.Service
|
||||
if kve.Value.Data == nil {
|
||||
// fake a service
|
||||
parts := strings.SplitN(kve.Value.Key, _serviceDelimiter, 3)
|
||||
if len(parts) != 3 {
|
||||
return nil, errors.New("invalid service key")
|
||||
}
|
||||
svc.Name = parts[0]
|
||||
// ocis registers nodes with a - separator
|
||||
svc.Nodes = []*registry.Node{{Id: parts[0] + "-" + parts[1]}}
|
||||
svc.Version = parts[2]
|
||||
} else {
|
||||
if err := json.Unmarshal(kve.Value.Data, &svc); err != nil {
|
||||
_ = w.stop()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return ®istry.Result{
|
||||
Service: &svc,
|
||||
Action: kve.Action,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Stop stops the watcher
|
||||
func (w *Watcher) Stop() {
|
||||
_ = w.stop()
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
Iss = "iss"
|
||||
Sub = "sub"
|
||||
Email = "email"
|
||||
Name = "name"
|
||||
PreferredUsername = "preferred_username"
|
||||
UIDNumber = "uidnumber"
|
||||
GIDNumber = "gidnumber"
|
||||
Groups = "groups"
|
||||
OwncloudUUID = "ownclouduuid"
|
||||
OcisRoutingPolicy = "ocis.routing.policy"
|
||||
)
|
||||
|
||||
// SplitWithEscaping splits s into segments using separator which can be escaped using the escape string
|
||||
// See https://codereview.stackexchange.com/a/280193
|
||||
func SplitWithEscaping(s string, separator string, escapeString string) []string {
|
||||
a := strings.Split(s, separator)
|
||||
|
||||
for i := len(a) - 2; i >= 0; i-- {
|
||||
if strings.HasSuffix(a[i], escapeString) {
|
||||
a[i] = a[i][:len(a[i])-len(escapeString)] + separator + a[i+1]
|
||||
a = append(a[:i+1], a[i+2:]...)
|
||||
}
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
// WalkSegments uses the given array of segments to walk the claims and return whatever interface was found
|
||||
func WalkSegments(segments []string, claims map[string]interface{}) (interface{}, error) {
|
||||
i := 0
|
||||
for ; i < len(segments)-1; i++ {
|
||||
switch castedClaims := claims[segments[i]].(type) {
|
||||
case map[string]interface{}:
|
||||
claims = castedClaims
|
||||
case map[interface{}]interface{}:
|
||||
claims = make(map[string]interface{}, len(castedClaims))
|
||||
for k, v := range castedClaims {
|
||||
if s, ok := k.(string); ok {
|
||||
claims[s] = v
|
||||
} else {
|
||||
return nil, fmt.Errorf("could not walk claims path, key '%v' is not a string", k)
|
||||
}
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported type '%v'", castedClaims)
|
||||
}
|
||||
}
|
||||
return claims[segments[i]], nil
|
||||
}
|
||||
|
||||
// ReadStringClaim returns the string obtained by following the . seperated path in the claims
|
||||
func ReadStringClaim(path string, claims map[string]interface{}) (string, error) {
|
||||
// check the simple case first
|
||||
value, _ := claims[path].(string)
|
||||
if value != "" {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
claim, err := WalkSegments(SplitWithEscaping(path, ".", "\\"), claims)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if value, _ = claim.(string); value != "" {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
return value, fmt.Errorf("claim path '%s' not set or empty", path)
|
||||
}
|
||||
@@ -0,0 +1,182 @@
|
||||
package oidc_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/opencloud-eu/opencloud/pkg/oidc"
|
||||
)
|
||||
|
||||
type splitWithEscapingTest struct {
|
||||
// Name of the subtest.
|
||||
name string
|
||||
|
||||
// string to split
|
||||
s string
|
||||
|
||||
// seperator to use
|
||||
seperator string
|
||||
|
||||
// escape character to use for escaping
|
||||
escape string
|
||||
|
||||
expectedParts []string
|
||||
}
|
||||
|
||||
func (swet splitWithEscapingTest) run(t *testing.T) {
|
||||
parts := oidc.SplitWithEscaping(swet.s, swet.seperator, swet.escape)
|
||||
if len(swet.expectedParts) != len(parts) {
|
||||
t.Errorf("mismatching length")
|
||||
}
|
||||
for i, v := range swet.expectedParts {
|
||||
if parts[i] != v {
|
||||
t.Errorf("expected part %d to be '%s', got '%s'", i, v, parts[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitWithEscaping(t *testing.T) {
|
||||
tests := []splitWithEscapingTest{
|
||||
{
|
||||
name: "plain claim name",
|
||||
s: "roles",
|
||||
seperator: ".",
|
||||
escape: "\\",
|
||||
expectedParts: []string{"roles"},
|
||||
},
|
||||
{
|
||||
name: "claim with .",
|
||||
s: "my.roles",
|
||||
seperator: ".",
|
||||
escape: "\\",
|
||||
expectedParts: []string{"my", "roles"},
|
||||
},
|
||||
{
|
||||
name: "claim with escaped .",
|
||||
s: "my\\.roles",
|
||||
seperator: ".",
|
||||
escape: "\\",
|
||||
expectedParts: []string{"my.roles"},
|
||||
},
|
||||
{
|
||||
name: "claim with escaped . left",
|
||||
s: "my\\.other.roles",
|
||||
seperator: ".",
|
||||
escape: "\\",
|
||||
expectedParts: []string{"my.other", "roles"},
|
||||
},
|
||||
{
|
||||
name: "claim with escaped . right",
|
||||
s: "my.other\\.roles",
|
||||
seperator: ".",
|
||||
escape: "\\",
|
||||
expectedParts: []string{"my", "other.roles"},
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, test.run)
|
||||
}
|
||||
}
|
||||
|
||||
type walkSegmentsTest struct {
|
||||
// Name of the subtest.
|
||||
name string
|
||||
|
||||
// path segments to walk
|
||||
segments []string
|
||||
|
||||
// seperator to use
|
||||
claims map[string]interface{}
|
||||
|
||||
expected interface{}
|
||||
|
||||
wantErr bool
|
||||
}
|
||||
|
||||
func (wst walkSegmentsTest) run(t *testing.T) {
|
||||
v, err := oidc.WalkSegments(wst.segments, wst.claims)
|
||||
if err != nil && !wst.wantErr {
|
||||
t.Errorf("%v", err)
|
||||
}
|
||||
if err == nil && wst.wantErr {
|
||||
t.Errorf("expected error")
|
||||
}
|
||||
if !reflect.DeepEqual(v, wst.expected) {
|
||||
t.Errorf("expected %v got %v", wst.expected, v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWalkSegments(t *testing.T) {
|
||||
byt := []byte(`{"first":{"second":{"third":["value1","value2"]},"foo":"bar"},"fizz":"buzz"}`)
|
||||
var dat map[string]interface{}
|
||||
if err := json.Unmarshal(byt, &dat); err != nil {
|
||||
t.Errorf("%v", err)
|
||||
}
|
||||
|
||||
tests := []walkSegmentsTest{
|
||||
{
|
||||
name: "one segment, single value",
|
||||
segments: []string{"first"},
|
||||
claims: map[string]interface{}{
|
||||
"first": "value",
|
||||
},
|
||||
expected: "value",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "one segment, array value",
|
||||
segments: []string{"first"},
|
||||
claims: map[string]interface{}{
|
||||
"first": []string{"value1", "value2"},
|
||||
},
|
||||
expected: []string{"value1", "value2"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "two segments, single value",
|
||||
segments: []string{"first", "second"},
|
||||
claims: map[string]interface{}{
|
||||
"first": map[string]interface{}{
|
||||
"second": "value",
|
||||
},
|
||||
},
|
||||
expected: "value",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "two segments, array value",
|
||||
segments: []string{"first", "second"},
|
||||
claims: map[string]interface{}{
|
||||
"first": map[string]interface{}{
|
||||
"second": []string{"value1", "value2"},
|
||||
},
|
||||
},
|
||||
expected: []string{"value1", "value2"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "three segments, array value from json",
|
||||
segments: []string{"first", "second", "third"},
|
||||
claims: dat,
|
||||
expected: []interface{}{"value1", "value2"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "three segments, array value with interface key",
|
||||
segments: []string{"first", "second", "third"},
|
||||
claims: map[string]interface{}{
|
||||
"first": map[interface{}]interface{}{
|
||||
"second": map[interface{}]interface{}{
|
||||
"third": []string{"value1", "value2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: []string{"value1", "value2"},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, test.run)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,373 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/MicahParks/keyfunc/v2"
|
||||
goidc "github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/opencloud-eu/opencloud/pkg/log"
|
||||
"github.com/opencloud-eu/opencloud/services/proxy/pkg/config"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// OIDCClient used to mock the oidc client during tests
|
||||
type OIDCClient interface {
|
||||
UserInfo(ctx context.Context, ts oauth2.TokenSource) (*UserInfo, error)
|
||||
VerifyAccessToken(ctx context.Context, token string) (RegClaimsWithSID, jwt.MapClaims, error)
|
||||
VerifyLogoutToken(ctx context.Context, token string) (*LogoutToken, error)
|
||||
}
|
||||
|
||||
// KeySet is a set of public JSON Web Keys that can be used to validate the signature
|
||||
// of JSON web tokens. This is expected to be backed by a remote key set through
|
||||
// provider metadata discovery or an in-memory set of keys delivered out-of-band.
|
||||
type KeySet interface {
|
||||
// VerifySignature parses the JSON web token, verifies the signature, and returns
|
||||
// the raw payload. Header and claim fields are validated by other parts of the
|
||||
// package. For example, the KeySet does not need to check values such as signature
|
||||
// algorithm, issuer, and audience since the IDTokenVerifier validates these values
|
||||
// independently.
|
||||
//
|
||||
// If VerifySignature makes HTTP requests to verify the token, it's expected to
|
||||
// use any HTTP client associated with the context through ClientContext.
|
||||
VerifySignature(ctx context.Context, jwt string) (payload []byte, err error)
|
||||
}
|
||||
|
||||
type RegClaimsWithSID struct {
|
||||
SessionID string `json:"sid"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
type oidcClient struct {
|
||||
// Logger to use for logging, must be set
|
||||
Logger log.Logger
|
||||
|
||||
issuer string
|
||||
provider *ProviderMetadata
|
||||
providerLock *sync.Mutex
|
||||
skipIssuerValidation bool
|
||||
accessTokenVerifyMethod string
|
||||
remoteKeySet KeySet
|
||||
algorithms []string
|
||||
|
||||
JWKSOptions config.JWKS
|
||||
JWKS *keyfunc.JWKS
|
||||
jwksLock *sync.Mutex
|
||||
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// _supportedAlgorithms is a list of algorithms explicitly supported by this
|
||||
// package. If a provider supports other algorithms, such as HS256 or none,
|
||||
// those values won't be passed to the IDTokenVerifier.
|
||||
var _supportedAlgorithms = map[string]bool{
|
||||
RS256: true,
|
||||
RS384: true,
|
||||
RS512: true,
|
||||
ES256: true,
|
||||
ES384: true,
|
||||
ES512: true,
|
||||
PS256: true,
|
||||
PS384: true,
|
||||
PS512: true,
|
||||
}
|
||||
|
||||
// NewOIDCClient returns an OIDClient instance for the given issuer
|
||||
func NewOIDCClient(opts ...Option) OIDCClient {
|
||||
options := newOptions(opts...)
|
||||
|
||||
return &oidcClient{
|
||||
Logger: options.Logger,
|
||||
issuer: options.OIDCIssuer,
|
||||
httpClient: options.HTTPClient,
|
||||
accessTokenVerifyMethod: options.AccessTokenVerifyMethod,
|
||||
JWKSOptions: options.JWKSOptions, // TODO I don't like that we pass down config options ...
|
||||
JWKS: options.JWKS,
|
||||
providerLock: &sync.Mutex{},
|
||||
jwksLock: &sync.Mutex{},
|
||||
remoteKeySet: options.KeySet,
|
||||
provider: options.ProviderMetadata,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *oidcClient) lookupWellKnownOpenidConfiguration(ctx context.Context) error {
|
||||
c.providerLock.Lock()
|
||||
defer c.providerLock.Unlock()
|
||||
if c.provider == nil {
|
||||
wellKnown := strings.TrimSuffix(c.issuer, "/") + wellknownPath
|
||||
req, err := http.NewRequest("GET", wellKnown, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := c.httpClient.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to read response body: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("%s: %s", resp.Status, body)
|
||||
}
|
||||
|
||||
var p ProviderMetadata
|
||||
err = unmarshalResp(resp, body, &p)
|
||||
if err != nil {
|
||||
return fmt.Errorf("oidc: failed to decode provider discovery object: %v", err)
|
||||
}
|
||||
|
||||
if !c.skipIssuerValidation && p.Issuer != c.issuer {
|
||||
return fmt.Errorf("oidc: issuer did not match the issuer returned by provider, expected %q got %q", c.issuer, p.Issuer)
|
||||
}
|
||||
var algs []string
|
||||
for _, a := range p.IDTokenSigningAlgValuesSupported {
|
||||
if _supportedAlgorithms[a] {
|
||||
algs = append(algs, a)
|
||||
}
|
||||
}
|
||||
c.provider = &p
|
||||
c.algorithms = algs
|
||||
c.remoteKeySet = goidc.NewRemoteKeySet(goidc.ClientContext(ctx, c.httpClient), p.JwksURI)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *oidcClient) getKeyfunc() *keyfunc.JWKS {
|
||||
c.jwksLock.Lock()
|
||||
defer c.jwksLock.Unlock()
|
||||
if c.JWKS == nil {
|
||||
var err error
|
||||
c.Logger.Debug().Str("jwks", c.provider.JwksURI).Msg("discovered jwks endpoint")
|
||||
options := keyfunc.Options{
|
||||
Client: c.httpClient,
|
||||
RefreshErrorHandler: func(err error) {
|
||||
c.Logger.Error().Err(err).Msg("There was an error with the jwt.Keyfunc")
|
||||
},
|
||||
RefreshInterval: time.Minute * time.Duration(c.JWKSOptions.RefreshInterval),
|
||||
RefreshRateLimit: time.Second * time.Duration(c.JWKSOptions.RefreshRateLimit),
|
||||
RefreshTimeout: time.Second * time.Duration(c.JWKSOptions.RefreshTimeout),
|
||||
RefreshUnknownKID: c.JWKSOptions.RefreshUnknownKID,
|
||||
}
|
||||
c.JWKS, err = keyfunc.Get(c.provider.JwksURI, options)
|
||||
if err != nil {
|
||||
c.JWKS = nil
|
||||
c.Logger.Error().Err(err).Msg("Failed to create JWKS from resource at the given URL.")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return c.JWKS
|
||||
}
|
||||
|
||||
type stringAsBool bool
|
||||
|
||||
// Claims unmarshals the raw JSON string into a bool.
|
||||
func (sb *stringAsBool) UnmarshalJSON(b []byte) error {
|
||||
v, err := strconv.ParseBool(string(b))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*sb = stringAsBool(v)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UserInfo represents the OpenID Connect userinfo claims.
|
||||
type UserInfo struct {
|
||||
Subject string `json:"sub"`
|
||||
Profile string `json:"profile"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
|
||||
claims []byte
|
||||
}
|
||||
|
||||
type userInfoRaw struct {
|
||||
Subject string `json:"sub"`
|
||||
Profile string `json:"profile"`
|
||||
Email string `json:"email"`
|
||||
// Handle providers that return email_verified as a string
|
||||
// https://forums.aws.amazon.com/thread.jspa?messageID=949441󧳁 and
|
||||
// https://discuss.elastic.co/t/openid-error-after-authenticating-against-aws-cognito/206018/11
|
||||
EmailVerified stringAsBool `json:"email_verified"`
|
||||
}
|
||||
|
||||
// Claims unmarshals the raw JSON object claims into the provided object.
|
||||
func (u *UserInfo) Claims(v interface{}) error {
|
||||
if u.claims == nil {
|
||||
return errors.New("oidc: claims not set")
|
||||
}
|
||||
return json.Unmarshal(u.claims, v)
|
||||
}
|
||||
|
||||
// UserInfo retrieves the userinfo from a Token
|
||||
func (c *oidcClient) UserInfo(ctx context.Context, tokenSource oauth2.TokenSource) (*UserInfo, error) {
|
||||
if err := c.lookupWellKnownOpenidConfiguration(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if c.provider.UserinfoEndpoint == "" {
|
||||
return nil, errors.New("oidc: user info endpoint is not supported by this provider")
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", c.provider.UserinfoEndpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oidc: create GET request: %v", err)
|
||||
}
|
||||
|
||||
token, err := tokenSource.Token()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oidc: get access token: %v", err)
|
||||
}
|
||||
token.SetAuthHeader(req)
|
||||
|
||||
resp, err := c.httpClient.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("%s: %s", resp.Status, body)
|
||||
}
|
||||
|
||||
ct := resp.Header.Get("Content-Type")
|
||||
mediaType, _, err := mime.ParseMediaType(ct)
|
||||
if err == nil && mediaType == "application/jwt" {
|
||||
payload, err := c.remoteKeySet.VerifySignature(goidc.ClientContext(ctx, c.httpClient), string(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oidc: invalid userinfo jwt signature %v", err)
|
||||
}
|
||||
body = payload
|
||||
}
|
||||
|
||||
var userInfo userInfoRaw
|
||||
if err := json.Unmarshal(body, &userInfo); err != nil {
|
||||
return nil, fmt.Errorf("oidc: failed to decode userinfo: %v", err)
|
||||
}
|
||||
return &UserInfo{
|
||||
Subject: userInfo.Subject,
|
||||
Profile: userInfo.Profile,
|
||||
Email: userInfo.Email,
|
||||
EmailVerified: bool(userInfo.EmailVerified),
|
||||
claims: body,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *oidcClient) VerifyAccessToken(ctx context.Context, token string) (RegClaimsWithSID, jwt.MapClaims, error) {
|
||||
if err := c.lookupWellKnownOpenidConfiguration(ctx); err != nil {
|
||||
return RegClaimsWithSID{}, jwt.MapClaims{}, err
|
||||
}
|
||||
switch c.accessTokenVerifyMethod {
|
||||
case config.AccessTokenVerificationJWT:
|
||||
return c.verifyAccessTokenJWT(token)
|
||||
case config.AccessTokenVerificationNone:
|
||||
c.Logger.Debug().Msg("Access Token verification disabled")
|
||||
return RegClaimsWithSID{}, jwt.MapClaims{}, nil
|
||||
default:
|
||||
c.Logger.Error().Str("access_token_verify_method", c.accessTokenVerifyMethod).Msg("Unknown Access Token verification setting")
|
||||
return RegClaimsWithSID{}, jwt.MapClaims{}, errors.New("unknown Access Token Verification method")
|
||||
}
|
||||
}
|
||||
|
||||
// verifyAccessTokenJWT tries to parse and verify the access token as a JWT.
|
||||
func (c *oidcClient) verifyAccessTokenJWT(token string) (RegClaimsWithSID, jwt.MapClaims, error) {
|
||||
var claims RegClaimsWithSID
|
||||
mapClaims := jwt.MapClaims{}
|
||||
jwks := c.getKeyfunc()
|
||||
if jwks == nil {
|
||||
return claims, mapClaims, errors.New("error initializing jwks keyfunc")
|
||||
}
|
||||
|
||||
issuer := c.issuer
|
||||
if c.provider.AccessTokenIssuer != "" {
|
||||
// AD FS .well-known/openid-configuration has an optional `access_token_issuer` which takes precedence over `issuer`
|
||||
// See https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-oidce/586de7dd-3385-47c7-93a2-935d9e90441c
|
||||
issuer = c.provider.AccessTokenIssuer
|
||||
}
|
||||
|
||||
_, err := jwt.ParseWithClaims(token, &claims, jwks.Keyfunc, jwt.WithIssuer(issuer))
|
||||
if err != nil {
|
||||
return claims, mapClaims, err
|
||||
}
|
||||
_, _, err = new(jwt.Parser).ParseUnverified(token, mapClaims)
|
||||
// TODO: decode mapClaims to sth readable
|
||||
c.Logger.Debug().Interface("access token", &claims).Msg("parsed access token")
|
||||
if err != nil {
|
||||
c.Logger.Info().Err(err).Msg("Failed to parse/verify the access token.")
|
||||
return claims, mapClaims, err
|
||||
}
|
||||
|
||||
return claims, mapClaims, nil
|
||||
}
|
||||
|
||||
func (c *oidcClient) VerifyLogoutToken(ctx context.Context, rawToken string) (*LogoutToken, error) {
|
||||
var claims LogoutToken
|
||||
if err := c.lookupWellKnownOpenidConfiguration(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jwks := c.getKeyfunc()
|
||||
if jwks == nil {
|
||||
return nil, errors.New("error initializing jwks keyfunc")
|
||||
}
|
||||
|
||||
// From the backchannel-logout spec: Like ID Tokens, selection of the
|
||||
// algorithm used is governed by the id_token_signing_alg_values_supported
|
||||
// Discovery parameter and the id_token_signed_response_alg Registration
|
||||
// parameter when they are used; otherwise, the value SHOULD be the default
|
||||
// of RS256
|
||||
supportedSigAlgs := c.algorithms
|
||||
if len(supportedSigAlgs) == 0 {
|
||||
supportedSigAlgs = []string{RS256}
|
||||
}
|
||||
|
||||
_, err := jwt.ParseWithClaims(rawToken, &claims, jwks.Keyfunc, jwt.WithValidMethods(supportedSigAlgs), jwt.WithIssuer(c.issuer))
|
||||
if err != nil {
|
||||
c.Logger.Debug().Err(err).Msg("Failed to parse logout token")
|
||||
return nil, err
|
||||
}
|
||||
// Basic token validation has happened in ParseWithClaims (signature,
|
||||
// issuer, audience, ...). Now for some logout token specific checks.
|
||||
// 1. Verify that the Logout Token contains a sub Claim, a sid Claim, or both.
|
||||
if claims.Subject == "" && claims.SessionId == "" {
|
||||
return nil, fmt.Errorf("oidc: logout token must contain either sub or sid and MAY contain both")
|
||||
}
|
||||
// 2. Verify that the Logout Token contains an events Claim whose value is JSON object containing the member name http://schemas.openid.net/event/backchannel-logout.
|
||||
if claims.Events.Event == nil {
|
||||
return nil, fmt.Errorf("oidc: logout token must contain logout event")
|
||||
}
|
||||
// 3. Verify that the Logout Token does not contain a nonce Claim.
|
||||
if claims.Nonce != nil {
|
||||
return nil, fmt.Errorf("oidc: nonce on logout token MUST NOT be present")
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
func unmarshalResp(r *http.Response, body []byte, v interface{}) error {
|
||||
err := json.Unmarshal(body, &v)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
ct := r.Header.Get("Content-Type")
|
||||
mediaType, _, err := mime.ParseMediaType(ct)
|
||||
if err == nil && mediaType == "application/json" {
|
||||
return fmt.Errorf("got Content-Type = application/json, but could not unmarshal as JSON: %v", err)
|
||||
}
|
||||
return fmt.Errorf("expected Content-Type = application/json, got %q: %v", ct, err)
|
||||
}
|
||||
@@ -0,0 +1,205 @@
|
||||
package oidc_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"testing"
|
||||
|
||||
"github.com/MicahParks/keyfunc/v2"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/opencloud-eu/opencloud/pkg/oidc"
|
||||
)
|
||||
|
||||
type signingKey struct {
|
||||
priv interface{}
|
||||
jwks *keyfunc.JWKS
|
||||
}
|
||||
|
||||
func TestLogoutVerify(t *testing.T) {
|
||||
tests := []logoutVerificationTest{
|
||||
{
|
||||
name: "good token",
|
||||
logoutToken: jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
|
||||
"iss": "https://foo",
|
||||
"sub": "248289761001",
|
||||
"aud": "s6BhdRkqt3",
|
||||
"iat": 1471566154,
|
||||
"jti": "bWJq",
|
||||
"sid": "08a5019c-17e1-4977-8f42-65a12843ea02",
|
||||
"events": map[string]interface{}{
|
||||
"http://schemas.openid.net/event/backchannel-logout": struct{}{},
|
||||
},
|
||||
}),
|
||||
signKey: newRSAKey(t),
|
||||
},
|
||||
{
|
||||
name: "invalid issuer",
|
||||
issuer: "https://bar",
|
||||
logoutToken: jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
|
||||
"iss": "https://foo1",
|
||||
"sub": "248289761001",
|
||||
"events": map[string]interface{}{
|
||||
"http://schemas.openid.net/event/backchannel-logout": struct{}{},
|
||||
},
|
||||
}),
|
||||
signKey: newRSAKey(t),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid sig",
|
||||
logoutToken: jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
|
||||
"iss": "https://foo",
|
||||
"sub": "248289761001",
|
||||
"aud": "s6BhdRkqt3",
|
||||
"iat": 1471566154,
|
||||
"jti": "bWJq",
|
||||
"sid": "08a5019c-17e1-4977-8f42-65a12843ea02",
|
||||
"events": map[string]interface{}{
|
||||
"http://schemas.openid.net/event/backchannel-logout": struct{}{},
|
||||
},
|
||||
}),
|
||||
signKey: newRSAKey(t),
|
||||
verificationKey: newRSAKey(t),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "no sid and no sub",
|
||||
logoutToken: jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
|
||||
"iss": "https://foo",
|
||||
"aud": "s6BhdRkqt3",
|
||||
"iat": 1471566154,
|
||||
"jti": "bWJq",
|
||||
"events": map[string]interface{}{
|
||||
"http://schemas.openid.net/event/backchannel-logout": struct{}{},
|
||||
},
|
||||
}),
|
||||
signKey: newRSAKey(t),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Prohibited nonce present",
|
||||
logoutToken: jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
|
||||
"iss": "https://foo",
|
||||
"sub": "248289761001",
|
||||
"aud": "s6BhdRkqt3",
|
||||
"iat": 1471566154,
|
||||
"jti": "bWJq",
|
||||
"sid": "08a5019c-17e1-4977-8f42-65a12843ea02",
|
||||
"nonce": "123",
|
||||
"events": map[string]interface{}{
|
||||
"http://schemas.openid.net/event/backchannel-logout": struct{}{},
|
||||
},
|
||||
}),
|
||||
signKey: newRSAKey(t),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Wrong Event string",
|
||||
logoutToken: jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
|
||||
"iss": "https://foo",
|
||||
"sub": "248289761001",
|
||||
"aud": "s6BhdRkqt3",
|
||||
"iat": 1471566154,
|
||||
"jti": "bWJq",
|
||||
"sid": "08a5019c-17e1-4977-8f42-65a12843ea02",
|
||||
"events": map[string]interface{}{
|
||||
"http://blah.blah.blash/event/backchannel-logout": struct{}{},
|
||||
},
|
||||
}),
|
||||
signKey: newRSAKey(t),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "No Event string",
|
||||
logoutToken: jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
|
||||
"iss": "https://foo",
|
||||
"sub": "248289761001",
|
||||
"aud": "s6BhdRkqt3",
|
||||
"iat": 1471566154,
|
||||
"jti": "bWJq",
|
||||
"sid": "08a5019c-17e1-4977-8f42-65a12843ea02",
|
||||
}),
|
||||
signKey: newRSAKey(t),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, test.run)
|
||||
}
|
||||
}
|
||||
|
||||
type logoutVerificationTest struct {
|
||||
// Name of the subtest.
|
||||
name string
|
||||
|
||||
// If not provided defaults to "https://foo"
|
||||
issuer string
|
||||
|
||||
// JWT payload (just the claims).
|
||||
logoutToken *jwt.Token
|
||||
|
||||
// Key to sign the ID Token with.
|
||||
signKey *signingKey
|
||||
// If not provided defaults to signKey. Only useful when
|
||||
// testing invalid signatures.
|
||||
verificationKey *signingKey
|
||||
|
||||
wantErr bool
|
||||
}
|
||||
|
||||
func (v logoutVerificationTest) runGetToken(t *testing.T) (*oidc.LogoutToken, error) {
|
||||
// token := v.signKey.sign(t, []byte(v.logoutToken))
|
||||
v.logoutToken.Header["kid"] = "1"
|
||||
token, err := v.logoutToken.SignedString(v.signKey.priv)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
issuer := "https://foo"
|
||||
var jwks *keyfunc.JWKS
|
||||
if v.verificationKey == nil {
|
||||
jwks = v.signKey.jwks
|
||||
} else {
|
||||
jwks = v.verificationKey.jwks
|
||||
}
|
||||
|
||||
pm := oidc.ProviderMetadata{}
|
||||
verifier := oidc.NewOIDCClient(
|
||||
oidc.WithOidcIssuer(issuer),
|
||||
oidc.WithJWKS(jwks),
|
||||
oidc.WithProviderMetadata(&pm),
|
||||
)
|
||||
|
||||
return verifier.VerifyLogoutToken(ctx, token)
|
||||
}
|
||||
|
||||
func (l logoutVerificationTest) run(t *testing.T) {
|
||||
_, err := l.runGetToken(t)
|
||||
if err != nil && !l.wantErr {
|
||||
t.Errorf("%v", err)
|
||||
}
|
||||
if err == nil && l.wantErr {
|
||||
t.Errorf("expected error")
|
||||
}
|
||||
}
|
||||
|
||||
func newRSAKey(t testing.TB) *signingKey {
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 1028)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
givenKey := keyfunc.NewGivenRSA(
|
||||
&priv.PublicKey,
|
||||
keyfunc.GivenKeyOptions{Algorithm: jwt.SigningMethodRS256.Alg()},
|
||||
)
|
||||
jwks := keyfunc.NewGiven(
|
||||
map[string]keyfunc.GivenKey{
|
||||
"1": givenKey,
|
||||
},
|
||||
)
|
||||
|
||||
return &signingKey{priv, jwks}
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
package oidc
|
||||
|
||||
import "context"
|
||||
|
||||
// contextKey is the key for oidc claims in a context
|
||||
type contextKey struct{}
|
||||
|
||||
// newSessionFlagKey is the key for the new session flag in a context
|
||||
type newSessionFlagKey struct{}
|
||||
|
||||
// NewContext makes a new context that contains the OpenID connect claims in a map.
|
||||
func NewContext(parent context.Context, c map[string]interface{}) context.Context {
|
||||
return context.WithValue(parent, contextKey{}, c)
|
||||
}
|
||||
|
||||
// FromContext returns the claims map stored in a context, or nil if there isn't one.
|
||||
func FromContext(ctx context.Context) map[string]interface{} {
|
||||
s, _ := ctx.Value(contextKey{}).(map[string]interface{})
|
||||
return s
|
||||
}
|
||||
|
||||
// NewContextSessionFlag makes a new context that contains the new session flag.
|
||||
func NewContextSessionFlag(ctx context.Context, flag bool) context.Context {
|
||||
return context.WithValue(ctx, newSessionFlagKey{}, flag)
|
||||
}
|
||||
|
||||
// NewSessionFlagFromContext returns the new session flag stored in a context.
|
||||
func NewSessionFlagFromContext(ctx context.Context) bool {
|
||||
s, _ := ctx.Value(newSessionFlagKey{}).(bool)
|
||||
return s
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package oidc
|
||||
|
||||
// JOSE asymmetric signing algorithm values as defined by RFC 7518
|
||||
//
|
||||
// see: https://tools.ietf.org/html/rfc7518#section-3.1
|
||||
const (
|
||||
RS256 = "RS256" // RSASSA-PKCS-v1.5 using SHA-256
|
||||
RS384 = "RS384" // RSASSA-PKCS-v1.5 using SHA-384
|
||||
RS512 = "RS512" // RSASSA-PKCS-v1.5 using SHA-512
|
||||
ES256 = "ES256" // ECDSA using P-256 and SHA-256
|
||||
ES384 = "ES384" // ECDSA using P-384 and SHA-384
|
||||
ES512 = "ES512" // ECDSA using P-521 and SHA-512
|
||||
PS256 = "PS256" // RSASSA-PSS using SHA256 and MGF1-SHA256
|
||||
PS384 = "PS384" // RSASSA-PSS using SHA384 and MGF1-SHA384
|
||||
PS512 = "PS512" // RSASSA-PSS using SHA512 and MGF1-SHA512
|
||||
)
|
||||
@@ -0,0 +1,103 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/opencloud-eu/opencloud/pkg/log"
|
||||
)
|
||||
|
||||
const wellknownPath = "/.well-known/openid-configuration"
|
||||
|
||||
// The ProviderMetadata describes an idp.
|
||||
// see https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
|
||||
type ProviderMetadata struct {
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint,omitempty"`
|
||||
//claims_parameter_supported
|
||||
ClaimsSupported []string `json:"claims_supported,omitempty"`
|
||||
//grant_types_supported
|
||||
IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported,omitempty"`
|
||||
Issuer string `json:"issuer,omitempty"`
|
||||
// AccessTokenIssuer is only used by AD FS and needs to be used when validating the iss of its access tokens
|
||||
// See https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-oidce/586de7dd-3385-47c7-93a2-935d9e90441c
|
||||
AccessTokenIssuer string `json:"access_token_issuer,omitempty"`
|
||||
JwksURI string `json:"jwks_uri,omitempty"`
|
||||
//registration_endpoint
|
||||
//request_object_signing_alg_values_supported
|
||||
//request_parameter_supported
|
||||
//request_uri_parameter_supported
|
||||
//require_request_uri_registration
|
||||
//response_modes_supported
|
||||
ResponseTypesSupported []string `json:"response_types_supported,omitempty"`
|
||||
ScopesSupported []string `json:"scopes_supported,omitempty"`
|
||||
SubjectTypesSupported []string `json:"subject_types_supported,omitempty"`
|
||||
TokenEndpoint string `json:"token_endpoint,omitempty"`
|
||||
//token_endpoint_auth_methods_supported
|
||||
//token_endpoint_auth_signing_alg_values_supported
|
||||
UserinfoEndpoint string `json:"userinfo_endpoint,omitempty"`
|
||||
//userinfo_signing_alg_values_supported
|
||||
//code_challenge_methods_supported
|
||||
IntrospectionEndpoint string `json:"introspection_endpoint,omitempty"`
|
||||
//introspection_endpoint_auth_methods_supported
|
||||
//introspection_endpoint_auth_signing_alg_values_supported
|
||||
RevocationEndpoint string `json:"revocation_endpoint,omitempty"`
|
||||
//revocation_endpoint_auth_methods_supported
|
||||
//revocation_endpoint_auth_signing_alg_values_supported
|
||||
//id_token_encryption_alg_values_supported
|
||||
//id_token_encryption_enc_values_supported
|
||||
//userinfo_encryption_alg_values_supported
|
||||
//userinfo_encryption_enc_values_supported
|
||||
//request_object_encryption_alg_values_supported
|
||||
//request_object_encryption_enc_values_supported
|
||||
CheckSessionIframe string `json:"check_session_iframe,omitempty"`
|
||||
EndSessionEndpoint string `json:"end_session_endpoint,omitempty"`
|
||||
//claim_types_supported
|
||||
}
|
||||
|
||||
// Logout Token defines an logout Token
|
||||
type LogoutToken struct {
|
||||
jwt.RegisteredClaims
|
||||
// The Session Id
|
||||
SessionId string `json:"sid"`
|
||||
Events LogoutEvent `json:"events"`
|
||||
// Note: This is just here to be able to check for nonce being absent
|
||||
Nonce *string `json:"nonce"`
|
||||
}
|
||||
|
||||
// LogoutEvent defines a logout Event
|
||||
type LogoutEvent struct {
|
||||
Event *struct{} `json:"http://schemas.openid.net/event/backchannel-logout"`
|
||||
}
|
||||
|
||||
func GetIDPMetadata(logger log.Logger, client *http.Client, idpURI string) (ProviderMetadata, error) {
|
||||
wellknownURI := strings.TrimSuffix(idpURI, "/") + wellknownPath
|
||||
|
||||
resp, err := client.Get(wellknownURI)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("Failed to set request for .well-known/openid-configuration")
|
||||
return ProviderMetadata{}, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("unable to read discovery response body")
|
||||
return ProviderMetadata{}, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
logger.Error().Str("status", resp.Status).Str("body", string(body)).Msg("error requesting openid-configuration")
|
||||
return ProviderMetadata{}, err
|
||||
}
|
||||
|
||||
var oidcMetadata ProviderMetadata
|
||||
err = json.Unmarshal(body, &oidcMetadata)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("failed to decode provider openid-configuration")
|
||||
return ProviderMetadata{}, err
|
||||
}
|
||||
return oidcMetadata, nil
|
||||
}
|
||||
@@ -0,0 +1,225 @@
|
||||
// Code generated by mockery v2.43.2. DO NOT EDIT.
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
jwt "github.com/golang-jwt/jwt/v5"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
|
||||
oauth2 "golang.org/x/oauth2"
|
||||
|
||||
oidc "github.com/opencloud-eu/opencloud/pkg/oidc"
|
||||
)
|
||||
|
||||
// OIDCClient is an autogenerated mock type for the OIDCClient type
|
||||
type OIDCClient struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type OIDCClient_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *OIDCClient) EXPECT() *OIDCClient_Expecter {
|
||||
return &OIDCClient_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// UserInfo provides a mock function with given fields: ctx, ts
|
||||
func (_m *OIDCClient) UserInfo(ctx context.Context, ts oauth2.TokenSource) (*oidc.UserInfo, error) {
|
||||
ret := _m.Called(ctx, ts)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for UserInfo")
|
||||
}
|
||||
|
||||
var r0 *oidc.UserInfo
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, oauth2.TokenSource) (*oidc.UserInfo, error)); ok {
|
||||
return rf(ctx, ts)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, oauth2.TokenSource) *oidc.UserInfo); ok {
|
||||
r0 = rf(ctx, ts)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*oidc.UserInfo)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, oauth2.TokenSource) error); ok {
|
||||
r1 = rf(ctx, ts)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// OIDCClient_UserInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UserInfo'
|
||||
type OIDCClient_UserInfo_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// UserInfo is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - ts oauth2.TokenSource
|
||||
func (_e *OIDCClient_Expecter) UserInfo(ctx interface{}, ts interface{}) *OIDCClient_UserInfo_Call {
|
||||
return &OIDCClient_UserInfo_Call{Call: _e.mock.On("UserInfo", ctx, ts)}
|
||||
}
|
||||
|
||||
func (_c *OIDCClient_UserInfo_Call) Run(run func(ctx context.Context, ts oauth2.TokenSource)) *OIDCClient_UserInfo_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(oauth2.TokenSource))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *OIDCClient_UserInfo_Call) Return(_a0 *oidc.UserInfo, _a1 error) *OIDCClient_UserInfo_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *OIDCClient_UserInfo_Call) RunAndReturn(run func(context.Context, oauth2.TokenSource) (*oidc.UserInfo, error)) *OIDCClient_UserInfo_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// VerifyAccessToken provides a mock function with given fields: ctx, token
|
||||
func (_m *OIDCClient) VerifyAccessToken(ctx context.Context, token string) (oidc.RegClaimsWithSID, jwt.MapClaims, error) {
|
||||
ret := _m.Called(ctx, token)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for VerifyAccessToken")
|
||||
}
|
||||
|
||||
var r0 oidc.RegClaimsWithSID
|
||||
var r1 jwt.MapClaims
|
||||
var r2 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) (oidc.RegClaimsWithSID, jwt.MapClaims, error)); ok {
|
||||
return rf(ctx, token)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) oidc.RegClaimsWithSID); ok {
|
||||
r0 = rf(ctx, token)
|
||||
} else {
|
||||
r0 = ret.Get(0).(oidc.RegClaimsWithSID)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string) jwt.MapClaims); ok {
|
||||
r1 = rf(ctx, token)
|
||||
} else {
|
||||
if ret.Get(1) != nil {
|
||||
r1 = ret.Get(1).(jwt.MapClaims)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(2).(func(context.Context, string) error); ok {
|
||||
r2 = rf(ctx, token)
|
||||
} else {
|
||||
r2 = ret.Error(2)
|
||||
}
|
||||
|
||||
return r0, r1, r2
|
||||
}
|
||||
|
||||
// OIDCClient_VerifyAccessToken_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifyAccessToken'
|
||||
type OIDCClient_VerifyAccessToken_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// VerifyAccessToken is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - token string
|
||||
func (_e *OIDCClient_Expecter) VerifyAccessToken(ctx interface{}, token interface{}) *OIDCClient_VerifyAccessToken_Call {
|
||||
return &OIDCClient_VerifyAccessToken_Call{Call: _e.mock.On("VerifyAccessToken", ctx, token)}
|
||||
}
|
||||
|
||||
func (_c *OIDCClient_VerifyAccessToken_Call) Run(run func(ctx context.Context, token string)) *OIDCClient_VerifyAccessToken_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *OIDCClient_VerifyAccessToken_Call) Return(_a0 oidc.RegClaimsWithSID, _a1 jwt.MapClaims, _a2 error) *OIDCClient_VerifyAccessToken_Call {
|
||||
_c.Call.Return(_a0, _a1, _a2)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *OIDCClient_VerifyAccessToken_Call) RunAndReturn(run func(context.Context, string) (oidc.RegClaimsWithSID, jwt.MapClaims, error)) *OIDCClient_VerifyAccessToken_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// VerifyLogoutToken provides a mock function with given fields: ctx, token
|
||||
func (_m *OIDCClient) VerifyLogoutToken(ctx context.Context, token string) (*oidc.LogoutToken, error) {
|
||||
ret := _m.Called(ctx, token)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for VerifyLogoutToken")
|
||||
}
|
||||
|
||||
var r0 *oidc.LogoutToken
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) (*oidc.LogoutToken, error)); ok {
|
||||
return rf(ctx, token)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) *oidc.LogoutToken); ok {
|
||||
r0 = rf(ctx, token)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*oidc.LogoutToken)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
|
||||
r1 = rf(ctx, token)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// OIDCClient_VerifyLogoutToken_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifyLogoutToken'
|
||||
type OIDCClient_VerifyLogoutToken_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// VerifyLogoutToken is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - token string
|
||||
func (_e *OIDCClient_Expecter) VerifyLogoutToken(ctx interface{}, token interface{}) *OIDCClient_VerifyLogoutToken_Call {
|
||||
return &OIDCClient_VerifyLogoutToken_Call{Call: _e.mock.On("VerifyLogoutToken", ctx, token)}
|
||||
}
|
||||
|
||||
func (_c *OIDCClient_VerifyLogoutToken_Call) Run(run func(ctx context.Context, token string)) *OIDCClient_VerifyLogoutToken_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *OIDCClient_VerifyLogoutToken_Call) Return(_a0 *oidc.LogoutToken, _a1 error) *OIDCClient_VerifyLogoutToken_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *OIDCClient_VerifyLogoutToken_Call) RunAndReturn(run func(context.Context, string) (*oidc.LogoutToken, error)) *OIDCClient_VerifyLogoutToken_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// NewOIDCClient creates a new instance of OIDCClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewOIDCClient(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *OIDCClient {
|
||||
mock := &OIDCClient{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
@@ -0,0 +1,117 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/MicahParks/keyfunc/v2"
|
||||
"github.com/opencloud-eu/opencloud/pkg/log"
|
||||
"github.com/opencloud-eu/opencloud/services/proxy/pkg/config"
|
||||
|
||||
goidc "github.com/coreos/go-oidc/v3/oidc"
|
||||
)
|
||||
|
||||
// Option defines a single option function.
|
||||
type Option func(o *Options)
|
||||
|
||||
// Options defines the available options for this package.
|
||||
type Options struct {
|
||||
// HTTPClient to use for requests
|
||||
HTTPClient *http.Client
|
||||
// Logger to use for logging, must be set
|
||||
Logger log.Logger
|
||||
// The OpenID Connect Issuer URL
|
||||
OIDCIssuer string
|
||||
// JWKSOptions to use when retrieving keys
|
||||
JWKSOptions config.JWKS
|
||||
// the JWKS keyset to use for verifying signatures of Access- and
|
||||
// Logout-Tokens
|
||||
// this option is mostly needed for unit test. To avoid fetching the keys
|
||||
// from the issuer
|
||||
JWKS *keyfunc.JWKS
|
||||
// KeySet to use when verifiing signatures of jwt encoded
|
||||
// user info responses
|
||||
// TODO move userinfo verification to use jwt/keyfunc as well
|
||||
KeySet KeySet
|
||||
// AccessTokenVerifyMethod to use when verifying access tokens
|
||||
// TODO pass a function or interface to verify? an AccessTokenVerifier?
|
||||
AccessTokenVerifyMethod string
|
||||
// Config to use
|
||||
Config *goidc.Config
|
||||
|
||||
// ProviderMetadata to use
|
||||
ProviderMetadata *ProviderMetadata
|
||||
}
|
||||
|
||||
// newOptions initializes the available default options.
|
||||
func newOptions(opts ...Option) Options {
|
||||
opt := Options{}
|
||||
|
||||
for _, o := range opts {
|
||||
o(&opt)
|
||||
}
|
||||
|
||||
return opt
|
||||
}
|
||||
|
||||
// WithOidcIssuer provides a function to set the openid connect issuer option.
|
||||
func WithOidcIssuer(val string) Option {
|
||||
return func(o *Options) {
|
||||
o.OIDCIssuer = val
|
||||
}
|
||||
}
|
||||
|
||||
// WithLogger provides a function to set the logger option.
|
||||
func WithLogger(val log.Logger) Option {
|
||||
return func(o *Options) {
|
||||
o.Logger = val
|
||||
}
|
||||
}
|
||||
|
||||
// WithAccessTokenVerifyMethod provides a function to set the accessTokenVerifyMethod option.
|
||||
func WithAccessTokenVerifyMethod(val string) Option {
|
||||
return func(o *Options) {
|
||||
o.AccessTokenVerifyMethod = val
|
||||
}
|
||||
}
|
||||
|
||||
// WithHTTPClient provides a function to set the httpClient option.
|
||||
func WithHTTPClient(val *http.Client) Option {
|
||||
return func(o *Options) {
|
||||
o.HTTPClient = val
|
||||
}
|
||||
}
|
||||
|
||||
// WithJWKSOptions provides a function to set the jwksOptions option.
|
||||
func WithJWKSOptions(val config.JWKS) Option {
|
||||
return func(o *Options) {
|
||||
o.JWKSOptions = val
|
||||
}
|
||||
}
|
||||
|
||||
// WithJWKS provides a function to set the JWKS option (mainly useful for testing).
|
||||
func WithJWKS(val *keyfunc.JWKS) Option {
|
||||
return func(o *Options) {
|
||||
o.JWKS = val
|
||||
}
|
||||
}
|
||||
|
||||
// WithKeySet provides a function to set the KeySet option.
|
||||
func WithKeySet(val KeySet) Option {
|
||||
return func(o *Options) {
|
||||
o.KeySet = val
|
||||
}
|
||||
}
|
||||
|
||||
// WithConfig provides a function to set the Config option.
|
||||
func WithConfig(val *goidc.Config) Option {
|
||||
return func(o *Options) {
|
||||
o.Config = val
|
||||
}
|
||||
}
|
||||
|
||||
// WithProviderMetadata provides a function to set the provider option.
|
||||
func WithProviderMetadata(val *ProviderMetadata) Option {
|
||||
return func(o *Options) {
|
||||
o.ProviderMetadata = val
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
_registryRegisterIntervalEnv = "EXPERIMENTAL_REGISTER_INTERVAL"
|
||||
_registryRegisterTTLEnv = "EXPERIMENTAL_REGISTER_TTL"
|
||||
|
||||
// Note: _defaultRegisterInterval should always be lower than _defaultRegisterTTL
|
||||
_defaultRegisterInterval = time.Second * 25
|
||||
_defaultRegisterTTL = time.Second * 30
|
||||
)
|
||||
|
||||
// GetRegisterInterval returns the register interval from the environment.
|
||||
func GetRegisterInterval() time.Duration {
|
||||
d, err := time.ParseDuration(os.Getenv(_registryRegisterIntervalEnv))
|
||||
if err != nil {
|
||||
return _defaultRegisterInterval
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
// GetRegisterTTL returns the register TTL from the environment.
|
||||
func GetRegisterTTL() time.Duration {
|
||||
d, err := time.ParseDuration(os.Getenv(_registryRegisterTTLEnv))
|
||||
if err != nil {
|
||||
return _defaultRegisterTTL
|
||||
}
|
||||
return d
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
mRegistry "go-micro.dev/v4/registry"
|
||||
|
||||
"github.com/opencloud-eu/opencloud/pkg/log"
|
||||
)
|
||||
|
||||
// RegisterService publishes an arbitrary endpoint to the service-registry. This allows querying nodes of
|
||||
// non-micro services like reva. No health-checks are done, thus the caller is responsible for canceling.
|
||||
func RegisterService(ctx context.Context, logger log.Logger, service *mRegistry.Service, debugAddr string) error {
|
||||
registry := GetRegistry()
|
||||
node := service.Nodes[0]
|
||||
|
||||
logger.Info().Msgf("registering external service %v@%v", node.Id, node.Address)
|
||||
|
||||
rOpts := []mRegistry.RegisterOption{mRegistry.RegisterTTL(GetRegisterTTL())}
|
||||
if err := registry.Register(service, rOpts...); err != nil {
|
||||
logger.Fatal().Err(err).Msgf("Registration error for external service %v", service.Name)
|
||||
}
|
||||
|
||||
t := time.NewTicker(GetRegisterInterval())
|
||||
|
||||
go func() {
|
||||
// check if the service is ready
|
||||
delay := 500 * time.Millisecond
|
||||
for {
|
||||
resp, err := http.DefaultClient.Get("http://" + debugAddr + "/readyz")
|
||||
if err == nil && resp.StatusCode == http.StatusOK {
|
||||
resp.Body.Close()
|
||||
break
|
||||
}
|
||||
time.Sleep(delay)
|
||||
delay *= 2
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-t.C:
|
||||
logger.Debug().Interface("service", service).Msg("refreshing external service-registration")
|
||||
err := registry.Register(service, rOpts...)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msgf("registration error for external service %v", service.Name)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
logger.Debug().Interface("service", service).Msg("unregistering")
|
||||
t.Stop()
|
||||
err := registry.Deregister(service)
|
||||
if err != nil {
|
||||
logger.Err(err).Msgf("Error unregistering external service %v", service.Name)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
package registry
|
||||
|
||||
//
|
||||
//import (
|
||||
// "context"
|
||||
// "testing"
|
||||
//
|
||||
// "github.com/micro/go-micro/v2/registry"
|
||||
// "github.com/opencloud-eu/opencloud/pkg/log"
|
||||
//)
|
||||
//
|
||||
//func TestRegisterGRPCEndpoint(t *testing.T) {
|
||||
// ctx, cancel := context.WithCancel(context.Background())
|
||||
// err := RegisterGRPCEndpoint(ctx, "test", "1234", "192.168.0.1:777", log.Logger{})
|
||||
// if err != nil {
|
||||
// t.Errorf("Unexpected error: %v", err)
|
||||
// }
|
||||
//
|
||||
// s, err := registry.GetService("test")
|
||||
// if err != nil {
|
||||
// t.Errorf("Unexpected error: %v", err)
|
||||
// }
|
||||
//
|
||||
// if len(s) != 1 {
|
||||
// t.Errorf("Expected exactly one service to be returned got %v", len(s))
|
||||
// }
|
||||
//
|
||||
// if len(s[0].Nodes) != 1 {
|
||||
// t.Errorf("Expected exactly one node to be returned got %v", len(s[0].Nodes))
|
||||
// }
|
||||
//
|
||||
// testSvc := s[0]
|
||||
// if testSvc.Name != "test" {
|
||||
// t.Errorf("Expected service name to be 'test' got %v", s[0].Name)
|
||||
// }
|
||||
//
|
||||
// testNode := testSvc.Nodes[0]
|
||||
//
|
||||
// if testNode.Address != "192.168.0.1:777" {
|
||||
// t.Errorf("Expected node address to be '192.168.0.1:777' got %v", testNode.Address)
|
||||
// }
|
||||
//
|
||||
// if testNode.Id != "test-1234" {
|
||||
// t.Errorf("Expected node id to be 'test-1234' got %v", testNode.Id)
|
||||
// }
|
||||
//
|
||||
// cancel()
|
||||
//
|
||||
// // When switching over to monorepo this little test fails. We're unsure of what the cause is, but since this test
|
||||
// // is testing a framework specific behavior, we're better off letting it commented out. There is also no use of
|
||||
// // com.owncloud.reva anywhere in the codebase, so we're effectively only registering reva as a go-micro service,
|
||||
// // but not sending any message.
|
||||
// s, err = registry.GetService("test")
|
||||
// if err != nil {
|
||||
// t.Errorf("Unexpected error: %v", err)
|
||||
// }
|
||||
//
|
||||
// if len(s) != 0 {
|
||||
// t.Errorf("Deregister on cancelation failed. Result-length should be zero, got %v", len(s))
|
||||
// }
|
||||
//}
|
||||
@@ -0,0 +1,102 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
rRegistry "github.com/cs3org/reva/v2/pkg/registry"
|
||||
memr "github.com/go-micro/plugins/v4/registry/memory"
|
||||
"github.com/opencloud-eu/opencloud/pkg/natsjsregistry"
|
||||
mRegistry "go-micro.dev/v4/registry"
|
||||
"go-micro.dev/v4/registry/cache"
|
||||
)
|
||||
|
||||
const (
|
||||
_registryEnv = "MICRO_REGISTRY"
|
||||
_registryAddressEnv = "MICRO_REGISTRY_ADDRESS"
|
||||
)
|
||||
|
||||
var (
|
||||
_once sync.Once
|
||||
_reg mRegistry.Registry
|
||||
)
|
||||
|
||||
// Config is the config for a registry
|
||||
type Config struct {
|
||||
Type string `mapstructure:"type"`
|
||||
Addresses []string `mapstructure:"addresses"`
|
||||
Username string `mapstructure:"username"`
|
||||
Password string `mapstructure:"password"`
|
||||
DisableCache bool `mapstructure:"disable_cache"`
|
||||
RegisterTTL time.Duration `mapstructure:"register_ttl"`
|
||||
}
|
||||
|
||||
// Option allows configuring the registry
|
||||
type Option func(*Config)
|
||||
|
||||
// Inmemory overrides env values to use an in-memory registry
|
||||
func Inmemory() Option {
|
||||
return func(c *Config) {
|
||||
c.Type = "memory"
|
||||
}
|
||||
}
|
||||
|
||||
// GetRegistry returns a configured micro registry based on Micro env vars.
|
||||
// It defaults to mDNS, so mind that systems with mDNS disabled by default (i.e SUSE) will have a hard time
|
||||
// and it needs to explicitly use etcd. Os awareness for providing a working registry out of the box should be done.
|
||||
func GetRegistry(opts ...Option) mRegistry.Registry {
|
||||
_once.Do(func() {
|
||||
cfg := getEnvs(opts...)
|
||||
|
||||
switch cfg.Type {
|
||||
default:
|
||||
fmt.Println("Attention: unknown registry type, using default nats-js-kv")
|
||||
fallthrough
|
||||
case "natsjs", "nats-js", "nats-js-kv": // for backwards compatibility - we will stick with one of those
|
||||
_reg = natsjsregistry.NewRegistry(
|
||||
mRegistry.Addrs(cfg.Addresses...),
|
||||
natsjsregistry.DefaultTTL(cfg.RegisterTTL),
|
||||
)
|
||||
case "memory":
|
||||
_reg = memr.NewRegistry()
|
||||
cfg.DisableCache = true // no cache needed for in-memory registry
|
||||
}
|
||||
|
||||
// Disable cache if wanted
|
||||
if !cfg.DisableCache {
|
||||
_reg = cache.New(_reg, cache.WithTTL(30*time.Second))
|
||||
}
|
||||
|
||||
// fixme: lazy initialization of reva registry, needs refactor to a explicit call per service
|
||||
_ = rRegistry.Init(_reg)
|
||||
})
|
||||
// always use cached registry to prevent registry
|
||||
// lookup for every request
|
||||
return _reg
|
||||
}
|
||||
|
||||
func getEnvs(opts ...Option) *Config {
|
||||
cfg := &Config{
|
||||
Type: "nats-js-kv",
|
||||
Addresses: []string{"127.0.0.1:9233"},
|
||||
}
|
||||
|
||||
if s := os.Getenv(_registryEnv); s != "" {
|
||||
cfg.Type = s
|
||||
}
|
||||
|
||||
if s := strings.Split(os.Getenv(_registryAddressEnv), ","); len(s) > 0 && s[0] != "" {
|
||||
cfg.Addresses = s
|
||||
}
|
||||
|
||||
cfg.RegisterTTL = GetRegisterTTL()
|
||||
|
||||
for _, o := range opts {
|
||||
o(cfg)
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
mRegistry "go-micro.dev/v4/registry"
|
||||
"go-micro.dev/v4/server"
|
||||
mAddr "go-micro.dev/v4/util/addr"
|
||||
)
|
||||
|
||||
func BuildGRPCService(serviceID, transport, address, version string) *mRegistry.Service {
|
||||
var host string
|
||||
var port int
|
||||
|
||||
parts := strings.Split(address, ":")
|
||||
if len(parts) > 1 {
|
||||
host = strings.Join(parts[:len(parts)-1], ":")
|
||||
port, _ = strconv.Atoi(parts[len(parts)-1])
|
||||
} else {
|
||||
host = parts[0]
|
||||
}
|
||||
|
||||
addr := host
|
||||
if transport != "unix" {
|
||||
var err error
|
||||
addr, err = mAddr.Extract(host)
|
||||
if err != nil {
|
||||
addr = host
|
||||
}
|
||||
addr = net.JoinHostPort(addr, strconv.Itoa(port))
|
||||
}
|
||||
|
||||
node := &mRegistry.Node{
|
||||
Id: serviceID + "-" + server.DefaultId,
|
||||
Address: addr,
|
||||
Metadata: make(map[string]string),
|
||||
}
|
||||
|
||||
node.Metadata["registry"] = GetRegistry().String()
|
||||
node.Metadata["server"] = "grpc"
|
||||
node.Metadata["transport"] = transport
|
||||
node.Metadata["protocol"] = "grpc"
|
||||
|
||||
return &mRegistry.Service{
|
||||
Name: serviceID,
|
||||
Version: version,
|
||||
Nodes: []*mRegistry.Node{node},
|
||||
Endpoints: make([]*mRegistry.Endpoint, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func BuildHTTPService(serviceID, address string, version string) *mRegistry.Service {
|
||||
var host string
|
||||
var port int
|
||||
|
||||
parts := strings.Split(address, ":")
|
||||
if len(parts) > 1 {
|
||||
host = strings.Join(parts[:len(parts)-1], ":")
|
||||
port, _ = strconv.Atoi(parts[len(parts)-1])
|
||||
} else {
|
||||
host = parts[0]
|
||||
}
|
||||
|
||||
addr, err := mAddr.Extract(host)
|
||||
if err != nil {
|
||||
addr = host
|
||||
}
|
||||
|
||||
node := &mRegistry.Node{
|
||||
// This id is read by the registry watcher
|
||||
Id: serviceID + "-" + server.DefaultId,
|
||||
Address: net.JoinHostPort(addr, fmt.Sprint(port)),
|
||||
Metadata: make(map[string]string),
|
||||
}
|
||||
|
||||
node.Metadata["registry"] = GetRegistry().String()
|
||||
node.Metadata["server"] = "http"
|
||||
node.Metadata["transport"] = "http"
|
||||
node.Metadata["protocol"] = "http"
|
||||
|
||||
return &mRegistry.Service{
|
||||
Name: serviceID,
|
||||
Version: version,
|
||||
Nodes: []*mRegistry.Node{node},
|
||||
Endpoints: make([]*mRegistry.Endpoint, 0),
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,128 @@
|
||||
package roles
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/cs3org/reva/v2/pkg/store"
|
||||
"github.com/opencloud-eu/opencloud/pkg/log"
|
||||
settingsmsg "github.com/opencloud-eu/opencloud/protogen/gen/ocis/messages/settings/v0"
|
||||
settingssvc "github.com/opencloud-eu/opencloud/protogen/gen/ocis/services/settings/v0"
|
||||
microstore "go-micro.dev/v4/store"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
)
|
||||
|
||||
const (
|
||||
cacheDatabase = "ocis-pkg"
|
||||
cacheTableName = "roles"
|
||||
cacheTTL = time.Hour
|
||||
)
|
||||
|
||||
// Manager manages a cache of roles by fetching unknown roles from the settings.RoleService.
|
||||
type Manager struct {
|
||||
logger log.Logger
|
||||
roleCache microstore.Store
|
||||
roleService settingssvc.RoleService
|
||||
}
|
||||
|
||||
// NewManager returns a new instance of Manager.
|
||||
func NewManager(o ...Option) Manager {
|
||||
opts := newOptions(o...)
|
||||
|
||||
nStore := store.Create(opts.storeOptions...)
|
||||
return Manager{
|
||||
roleCache: nStore,
|
||||
roleService: opts.roleService,
|
||||
}
|
||||
}
|
||||
|
||||
// List returns all roles that match the given roleIDs.
|
||||
func (m *Manager) List(ctx context.Context, roleIDs []string) []*settingsmsg.Bundle {
|
||||
// get from cache
|
||||
result := make([]*settingsmsg.Bundle, 0)
|
||||
lookup := make([]string, 0)
|
||||
for _, roleID := range roleIDs {
|
||||
if records, err := m.roleCache.Read(roleID, microstore.ReadFrom(cacheDatabase, cacheTableName)); err != nil {
|
||||
lookup = append(lookup, roleID)
|
||||
} else {
|
||||
role := &settingsmsg.Bundle{}
|
||||
found := false
|
||||
for _, record := range records {
|
||||
if record.Key == roleID {
|
||||
if err := protojson.Unmarshal(record.Value, role); err == nil {
|
||||
// if we can unmarshal the role, append it to the result
|
||||
// otherwise assume the role wasn't found (data was damaged and
|
||||
// we need to get the role again)
|
||||
result = append(result, role)
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
lookup = append(lookup, roleID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// if there are roles missing, fetch them from the RoleService
|
||||
if len(lookup) > 0 {
|
||||
request := &settingssvc.ListBundlesRequest{
|
||||
BundleIds: lookup,
|
||||
}
|
||||
res, err := m.roleService.ListRoles(ctx, request)
|
||||
if err != nil {
|
||||
m.logger.Debug().Err(err).Msg("failed to fetch roles by roleIDs")
|
||||
return nil
|
||||
}
|
||||
for _, role := range res.Bundles {
|
||||
jsonbytes, _ := protojson.Marshal(role)
|
||||
record := µstore.Record{
|
||||
Key: role.Id,
|
||||
Value: jsonbytes,
|
||||
Expiry: cacheTTL,
|
||||
}
|
||||
err := m.roleCache.Write(
|
||||
record,
|
||||
microstore.WriteTo(cacheDatabase, cacheTableName),
|
||||
microstore.WriteTTL(cacheTTL),
|
||||
)
|
||||
if err != nil {
|
||||
m.logger.Debug().Err(err).Msg("failed to cache roles")
|
||||
}
|
||||
result = append(result, role)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// FindPermissionByID searches for a permission-setting by the permissionID, but limited to the given roleIDs
|
||||
func (m *Manager) FindPermissionByID(ctx context.Context, roleIDs []string, permissionID string) *settingsmsg.Setting {
|
||||
for _, role := range m.List(ctx, roleIDs) {
|
||||
for _, setting := range role.Settings {
|
||||
if setting.Id == permissionID {
|
||||
return setting
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindRoleIdsForUser returns all roles that are assigned to the supplied userid
|
||||
func (m *Manager) FindRoleIDsForUser(ctx context.Context, userID string) ([]string, error) {
|
||||
req := &settingssvc.ListRoleAssignmentsRequest{AccountUuid: userID}
|
||||
assignmentResponse, err := m.roleService.ListRoleAssignments(ctx, req)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
roleIDs := make([]string, 0, len(assignmentResponse.Assignments))
|
||||
|
||||
for _, assignment := range assignmentResponse.Assignments {
|
||||
roleIDs = append(roleIDs, assignment.RoleId)
|
||||
}
|
||||
|
||||
return roleIDs, nil
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package roles
|
||||
|
||||
import (
|
||||
"github.com/opencloud-eu/opencloud/pkg/log"
|
||||
settingssvc "github.com/opencloud-eu/opencloud/protogen/gen/ocis/services/settings/v0"
|
||||
"go-micro.dev/v4/store"
|
||||
)
|
||||
|
||||
// Options are all the possible options.
|
||||
type Options struct {
|
||||
storeOptions []store.Option
|
||||
logger log.Logger
|
||||
roleService settingssvc.RoleService
|
||||
}
|
||||
|
||||
// Option mutates option
|
||||
type Option func(*Options)
|
||||
|
||||
// Logger sets a preconfigured logger
|
||||
func Logger(logger log.Logger) Option {
|
||||
return func(o *Options) {
|
||||
o.logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
// RoleService provides endpoints for fetching roles.
|
||||
func RoleService(rs settingssvc.RoleService) Option {
|
||||
return func(o *Options) {
|
||||
o.roleService = rs
|
||||
}
|
||||
}
|
||||
|
||||
// StoreOptions are the options for the store
|
||||
func StoreOptions(storeOpts []store.Option) Option {
|
||||
return func(o *Options) {
|
||||
o.storeOptions = storeOpts
|
||||
}
|
||||
}
|
||||
|
||||
func newOptions(opts ...Option) Options {
|
||||
o := Options{}
|
||||
|
||||
for _, v := range opts {
|
||||
v(&o)
|
||||
}
|
||||
|
||||
return o
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
package roles
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/opencloud-eu/opencloud/pkg/middleware"
|
||||
"go-micro.dev/v4/metadata"
|
||||
)
|
||||
|
||||
// ReadRoleIDsFromContext extracts roleIDs from the metadata context and returns them as []string
|
||||
func ReadRoleIDsFromContext(ctx context.Context) (roleIDs []string, ok bool) {
|
||||
roleIDsJSON, ok := metadata.Get(ctx, middleware.RoleIDs)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
err := json.Unmarshal([]byte(roleIDsJSON), &roleIDs)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
return roleIDs, true
|
||||
}
|
||||
@@ -0,0 +1,262 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// GroupRunner represent a group of tasks that need to run together.
|
||||
// The expectation is that all the tasks will run at the same time, and when
|
||||
// one of them stops, the rest will also stop.
|
||||
//
|
||||
// The GroupRunner is intended to be used to run multiple services, which are
|
||||
// more or less independent from eachother, but at the same time it doesn't
|
||||
// make sense to have any of them stopped while the rest are running.
|
||||
// Basically, either all of them run, or none of them.
|
||||
// For example, you can have a GRPC and HTTP servers running, each of them
|
||||
// providing a piece of functionality, however, if any of them fails, the
|
||||
// feature provided by them would be incomplete or broken.
|
||||
//
|
||||
// The interrupt duration for the group can be set through the
|
||||
// `WithInterruptDuration` option. If the option isn't supplied, the default
|
||||
// value (15 secs) will be used.
|
||||
//
|
||||
// It's recommended that the timeouts are handled by each runner individually,
|
||||
// meaning that each runner's timeout should be less than the group runner's
|
||||
// timeout. This way, we can know which runner timed out.
|
||||
// If the group timeout is reached, the remaining results will have the
|
||||
// runner's id as "_unknown_".
|
||||
//
|
||||
// Note that, as services, the task aren't expected to stop by default.
|
||||
// This means that, if a task finishes naturally, the rest of the task will
|
||||
// asked to stop as well.
|
||||
type GroupRunner struct {
|
||||
runners sync.Map
|
||||
runnersCount int
|
||||
isRunning bool
|
||||
interruptDur time.Duration
|
||||
interrupted atomic.Bool
|
||||
interruptedCh chan time.Duration
|
||||
runningMutex sync.Mutex
|
||||
}
|
||||
|
||||
// NewGroup will create a GroupRunner
|
||||
func NewGroup(opts ...Option) *GroupRunner {
|
||||
options := Options{
|
||||
InterruptDuration: DefaultGroupInterruptDuration,
|
||||
}
|
||||
|
||||
for _, o := range opts {
|
||||
o(&options)
|
||||
}
|
||||
|
||||
return &GroupRunner{
|
||||
runners: sync.Map{},
|
||||
runningMutex: sync.Mutex{},
|
||||
interruptDur: options.InterruptDuration,
|
||||
interruptedCh: make(chan time.Duration, 1),
|
||||
}
|
||||
}
|
||||
|
||||
// Add will add a runner to the group.
|
||||
//
|
||||
// It's mandatory that each runner in the group has an unique id, otherwise
|
||||
// there will be issues
|
||||
// Adding new runners once the group starts will cause a panic
|
||||
func (gr *GroupRunner) Add(r *Runner) {
|
||||
gr.runningMutex.Lock()
|
||||
defer gr.runningMutex.Unlock()
|
||||
|
||||
if gr.isRunning {
|
||||
panic("Adding a new runner after the group starts is forbidden")
|
||||
}
|
||||
|
||||
// LoadOrStore will try to store the runner
|
||||
if _, loaded := gr.runners.LoadOrStore(r.ID, r); loaded {
|
||||
// there is already a runner with the same id, which is forbidden
|
||||
panic("Trying to add a runner with an existing Id in the group")
|
||||
}
|
||||
// Only increase the count if a runner is stored.
|
||||
// Currently panicking if the runner exists and is loaded
|
||||
gr.runnersCount++
|
||||
}
|
||||
|
||||
// Run will execute all the tasks in the group at the same time.
|
||||
//
|
||||
// Similarly to the "regular" runner's `Run` method, the execution thread
|
||||
// will be blocked here until all tasks are completed, and their results
|
||||
// will be available (each result will have the runner's id so it's easy to
|
||||
// find which one failed). Note that there is no guarantee about the result's
|
||||
// order, so the first result in the slice might or might not be the first
|
||||
// result to be obtained.
|
||||
//
|
||||
// When the context is marked as done, the groupRunner will call all the
|
||||
// stoppers for each runner to notify each task to stop. Note that the tasks
|
||||
// might still take a while to complete.
|
||||
//
|
||||
// If a task finishes naturally (with the context still "alive"), it will also
|
||||
// cause the groupRunner to call the stoppers of the rest of the tasks. So if
|
||||
// a task finishes, the rest will also finish.
|
||||
// Note that it is NOT expected for the finished task's stopper to be called
|
||||
// in this case.
|
||||
func (gr *GroupRunner) Run(ctx context.Context) []*Result {
|
||||
// Set the flag inside the runningMutex to ensure we don't read the old value
|
||||
// in the `Add` method and add a new runner when this method is being executed
|
||||
// Note that if multiple `Run` or `RunAsync` happens, the underlying runners
|
||||
// will panic
|
||||
gr.runningMutex.Lock()
|
||||
gr.isRunning = true
|
||||
gr.runningMutex.Unlock()
|
||||
|
||||
results := make([]*Result, 0, gr.runnersCount)
|
||||
|
||||
ch := make(chan *Result, gr.runnersCount) // no need to block writing results
|
||||
gr.runners.Range(func(_, value any) bool {
|
||||
r := value.(*Runner)
|
||||
r.RunAsync(ch)
|
||||
return true
|
||||
})
|
||||
|
||||
var d time.Duration
|
||||
// wait for a result or for the context to be done
|
||||
select {
|
||||
case result := <-ch:
|
||||
results = append(results, result)
|
||||
case d = <-gr.interruptedCh:
|
||||
results = append(results, &Result{
|
||||
RunnerID: "_unknown_",
|
||||
RunnerError: NewGroupTimeoutError(d),
|
||||
})
|
||||
case <-ctx.Done():
|
||||
// Do nothing
|
||||
}
|
||||
|
||||
// interrupt the rest of the runners
|
||||
gr.Interrupt()
|
||||
|
||||
// Having notified that the context has been finished, we still need to
|
||||
// wait for the rest of the results
|
||||
for i := len(results); i < gr.runnersCount; i++ {
|
||||
select {
|
||||
case result := <-ch:
|
||||
results = append(results, result)
|
||||
case d2, ok := <-gr.interruptedCh:
|
||||
if ok {
|
||||
d = d2
|
||||
}
|
||||
results = append(results, &Result{
|
||||
RunnerID: "_unknown_",
|
||||
RunnerError: NewGroupTimeoutError(d),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Even if we reach the group time out and bail out early, tasks might
|
||||
// be running and eventually deliver the result through the channel.
|
||||
// We'll rely on the buffered channel so the tasks won't block and the
|
||||
// data can be eventually garbage-collected along with the unused
|
||||
// channel, so we won't close the channel here.
|
||||
return results
|
||||
}
|
||||
|
||||
// RunAsync will execute the tasks in the group asynchronously.
|
||||
// The result of each task will be placed in the provided channel as soon
|
||||
// as it's available.
|
||||
// Note that this method will finish as soon as all the tasks are running.
|
||||
func (gr *GroupRunner) RunAsync(ch chan<- *Result) {
|
||||
// Set the flag inside the runningMutex to ensure we don't read the old value
|
||||
// in the `Add` method and add a new runner when this method is being executed
|
||||
// Note that if multiple `Run` or `RunAsync` happens, the underlying runners
|
||||
// will panic
|
||||
gr.runningMutex.Lock()
|
||||
gr.isRunning = true
|
||||
gr.runningMutex.Unlock()
|
||||
|
||||
// we need a secondary channel to receive the first result so we can
|
||||
// interrupt the rest of the tasks
|
||||
interCh := make(chan *Result, gr.runnersCount)
|
||||
gr.runners.Range(func(_, value any) bool {
|
||||
r := value.(*Runner)
|
||||
r.RunAsync(interCh)
|
||||
return true
|
||||
})
|
||||
|
||||
go func() {
|
||||
var result *Result
|
||||
var d time.Duration
|
||||
|
||||
select {
|
||||
case result = <-interCh:
|
||||
// result already assigned, so do nothing
|
||||
case d = <-gr.interruptedCh:
|
||||
// we aren't tracking which runners have finished and which are still
|
||||
// running, so we'll use "_unknown_" as runner id
|
||||
result = &Result{
|
||||
RunnerID: "_unknown_",
|
||||
RunnerError: NewGroupTimeoutError(d),
|
||||
}
|
||||
}
|
||||
gr.Interrupt()
|
||||
|
||||
ch <- result
|
||||
for i := 1; i < gr.runnersCount; i++ {
|
||||
select {
|
||||
case result = <-interCh:
|
||||
// result already assigned, so do nothing
|
||||
case d2, ok := <-gr.interruptedCh:
|
||||
// if ok is true, d2 will have a good value; if false, the channel
|
||||
// is closed and we get a default value
|
||||
if ok {
|
||||
d = d2
|
||||
}
|
||||
result = &Result{
|
||||
RunnerID: "_unknown_",
|
||||
RunnerError: NewGroupTimeoutError(d),
|
||||
}
|
||||
}
|
||||
ch <- result
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Interrupt will execute the stopper function of ALL the tasks, which should
|
||||
// notify the tasks in order for them to finish.
|
||||
// The stoppers will be called immediately but sequentially. This means that
|
||||
// the second stopper won't be called until the first one has returned. This
|
||||
// usually isn't a problem because the service `Stop`'s methods either don't
|
||||
// take a long time to return, or they run asynchronously in another goroutine.
|
||||
//
|
||||
// As said, this will affect ALL the tasks in the group. It isn't possible to
|
||||
// try to stop just one task.
|
||||
// If a task has finished, the corresponding stopper won't be called
|
||||
//
|
||||
// The interrupt timeout for the group will start after all the runners in the
|
||||
// group have been notified. Note that, if the task's stopper for a runner
|
||||
// takes a lot of time to return, it will delay the timeout's start, so it's
|
||||
// advised that the stopper either returns fast or is run asynchronously.
|
||||
func (gr *GroupRunner) Interrupt() {
|
||||
if gr.interrupted.CompareAndSwap(false, true) {
|
||||
gr.runners.Range(func(_, value any) bool {
|
||||
r := value.(*Runner)
|
||||
select {
|
||||
case <-r.Finished():
|
||||
// No data should be sent through the channel, so we'd be
|
||||
// here only if the channel is closed. This means the task
|
||||
// has finished and we don't need to interrupt. We do
|
||||
// nothing in this case
|
||||
default:
|
||||
r.Interrupt()
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
_ = time.AfterFunc(gr.interruptDur, func() {
|
||||
// timeout reached -> send it through the channel so our runner
|
||||
// can abort
|
||||
gr.interruptedCh <- gr.interruptDur
|
||||
close(gr.interruptedCh)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,293 @@
|
||||
package runner_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/opencloud-eu/opencloud/pkg/runner"
|
||||
)
|
||||
|
||||
var _ = Describe("GroupRunner", func() {
|
||||
var (
|
||||
gr *runner.GroupRunner
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
gr = runner.NewGroup()
|
||||
|
||||
task1Ch := make(chan error)
|
||||
task1 := TimedTask(task1Ch, 30*time.Second)
|
||||
gr.Add(runner.New("task1", task1, func() {
|
||||
task1Ch <- nil
|
||||
close(task1Ch)
|
||||
}))
|
||||
|
||||
task2Ch := make(chan error)
|
||||
task2 := TimedTask(task2Ch, 20*time.Second)
|
||||
gr.Add(runner.New("task2", task2, func() {
|
||||
task2Ch <- nil
|
||||
close(task2Ch)
|
||||
}))
|
||||
})
|
||||
|
||||
Describe("Add", func() {
|
||||
It("Duplicated runner id panics", func() {
|
||||
Expect(func() {
|
||||
gr.Add(runner.New("task1", func() error {
|
||||
time.Sleep(6 * time.Second)
|
||||
return nil
|
||||
}, func() {
|
||||
}))
|
||||
}).To(Panic())
|
||||
})
|
||||
|
||||
It("Add after run panics", func(ctx SpecContext) {
|
||||
// context will be done in 1 second
|
||||
myCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// spawn a new goroutine and return the result in the channel
|
||||
ch2 := make(chan []*runner.Result)
|
||||
go func(ch2 chan []*runner.Result) {
|
||||
ch2 <- gr.Run(myCtx)
|
||||
close(ch2)
|
||||
}(ch2)
|
||||
|
||||
// context is done in 1 sec, so all task should be interrupted and finish
|
||||
Eventually(ctx, ch2).Should(Receive(ContainElements(
|
||||
&runner.Result{RunnerID: "task1", RunnerError: nil},
|
||||
&runner.Result{RunnerID: "task2", RunnerError: nil},
|
||||
)))
|
||||
|
||||
task3Ch := make(chan error)
|
||||
task3 := TimedTask(task3Ch, 6*time.Second)
|
||||
Expect(func() {
|
||||
gr.Add(runner.New("task3", task3, func() {
|
||||
task3Ch <- nil
|
||||
close(task3Ch)
|
||||
}))
|
||||
}).To(Panic())
|
||||
}, SpecTimeout(5*time.Second))
|
||||
|
||||
It("Add after runAsync panics", func(ctx SpecContext) {
|
||||
ch2 := make(chan *runner.Result)
|
||||
gr.RunAsync(ch2)
|
||||
|
||||
Expect(func() {
|
||||
task3Ch := make(chan error)
|
||||
task3 := TimedTask(task3Ch, 6*time.Second)
|
||||
gr.Add(runner.New("task3", task3, func() {
|
||||
task3Ch <- nil
|
||||
close(task3Ch)
|
||||
}))
|
||||
}).To(Panic())
|
||||
}, SpecTimeout(5*time.Second))
|
||||
})
|
||||
|
||||
Describe("Run", func() {
|
||||
It("Context is done", func(ctx SpecContext) {
|
||||
task3Ch := make(chan error)
|
||||
task3 := TimedTask(task3Ch, 6*time.Second)
|
||||
gr.Add(runner.New("task3", task3, func() {
|
||||
task3Ch <- nil
|
||||
close(task3Ch)
|
||||
}))
|
||||
|
||||
// context will be done in 1 second
|
||||
myCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// spawn a new goroutine and return the result in the channel
|
||||
ch2 := make(chan []*runner.Result)
|
||||
go func(ch2 chan []*runner.Result) {
|
||||
ch2 <- gr.Run(myCtx)
|
||||
close(ch2)
|
||||
}(ch2)
|
||||
|
||||
// context is done in 1 sec, so all task should be interrupted and finish
|
||||
Eventually(ctx, ch2).Should(Receive(ContainElements(
|
||||
&runner.Result{RunnerID: "task1", RunnerError: nil},
|
||||
&runner.Result{RunnerID: "task2", RunnerError: nil},
|
||||
&runner.Result{RunnerID: "task3", RunnerError: nil},
|
||||
)))
|
||||
}, SpecTimeout(5*time.Second))
|
||||
|
||||
It("One task finishes early", func(ctx SpecContext) {
|
||||
task3Ch := make(chan error)
|
||||
task3 := TimedTask(task3Ch, 1*time.Second)
|
||||
gr.Add(runner.New("task3", task3, func() {
|
||||
task3Ch <- nil
|
||||
close(task3Ch)
|
||||
}))
|
||||
|
||||
// context will be done in 10 second
|
||||
myCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// spawn a new goroutine and return the result in the channel
|
||||
ch2 := make(chan []*runner.Result)
|
||||
go func(ch2 chan []*runner.Result) {
|
||||
ch2 <- gr.Run(myCtx)
|
||||
close(ch2)
|
||||
}(ch2)
|
||||
|
||||
// task3 finishes in 1 sec, so the rest should also be interrupted
|
||||
Eventually(ctx, ch2).Should(Receive(ContainElements(
|
||||
&runner.Result{RunnerID: "task1", RunnerError: nil},
|
||||
&runner.Result{RunnerID: "task2", RunnerError: nil},
|
||||
&runner.Result{RunnerID: "task3", RunnerError: nil},
|
||||
)))
|
||||
}, SpecTimeout(5*time.Second))
|
||||
|
||||
It("Context done and group timeout reached", func(ctx SpecContext) {
|
||||
gr := runner.NewGroup(runner.WithInterruptDuration(2 * time.Second))
|
||||
|
||||
gr.Add(runner.New("task1", func() error {
|
||||
time.Sleep(6 * time.Second)
|
||||
return nil
|
||||
}, func() {
|
||||
}))
|
||||
|
||||
gr.Add(runner.New("task2", func() error {
|
||||
time.Sleep(6 * time.Second)
|
||||
return nil
|
||||
}, func() {
|
||||
}))
|
||||
|
||||
// context will be done in 1 second
|
||||
myCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// spawn a new goroutine and return the result in the channel
|
||||
ch2 := make(chan []*runner.Result)
|
||||
go func(ch2 chan []*runner.Result) {
|
||||
ch2 <- gr.Run(myCtx)
|
||||
close(ch2)
|
||||
}(ch2)
|
||||
|
||||
// context finishes in 1 sec, tasks will be interrupted
|
||||
// group timeout will be reached after 2 extra seconds
|
||||
Eventually(ctx, ch2).Should(Receive(ContainElements(
|
||||
&runner.Result{RunnerID: "_unknown_", RunnerError: runner.NewGroupTimeoutError(2 * time.Second)},
|
||||
&runner.Result{RunnerID: "_unknown_", RunnerError: runner.NewGroupTimeoutError(2 * time.Second)},
|
||||
)))
|
||||
}, SpecTimeout(5*time.Second))
|
||||
|
||||
It("Interrupted and group timeout reached", func(ctx SpecContext) {
|
||||
gr := runner.NewGroup(runner.WithInterruptDuration(2 * time.Second))
|
||||
|
||||
gr.Add(runner.New("task1", func() error {
|
||||
time.Sleep(6 * time.Second)
|
||||
return nil
|
||||
}, func() {
|
||||
}))
|
||||
|
||||
gr.Add(runner.New("task2", func() error {
|
||||
time.Sleep(6 * time.Second)
|
||||
return nil
|
||||
}, func() {
|
||||
}))
|
||||
|
||||
// context will be done in 10 second
|
||||
myCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// spawn a new goroutine and return the result in the channel
|
||||
ch2 := make(chan []*runner.Result)
|
||||
go func(ch2 chan []*runner.Result) {
|
||||
ch2 <- gr.Run(myCtx)
|
||||
close(ch2)
|
||||
}(ch2)
|
||||
gr.Interrupt()
|
||||
|
||||
// tasks will be interrupted
|
||||
// group timeout will be reached after 2 extra seconds
|
||||
Eventually(ctx, ch2).Should(Receive(ContainElements(
|
||||
&runner.Result{RunnerID: "_unknown_", RunnerError: runner.NewGroupTimeoutError(2 * time.Second)},
|
||||
&runner.Result{RunnerID: "_unknown_", RunnerError: runner.NewGroupTimeoutError(2 * time.Second)},
|
||||
)))
|
||||
}, SpecTimeout(5*time.Second))
|
||||
|
||||
It("Doble run panics", func(ctx SpecContext) {
|
||||
// context will be done in 1 second
|
||||
myCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
Expect(func() {
|
||||
gr.Run(myCtx)
|
||||
gr.Run(myCtx)
|
||||
}).To(Panic())
|
||||
}, SpecTimeout(5*time.Second))
|
||||
})
|
||||
|
||||
Describe("RunAsync", func() {
|
||||
It("Wait in channel", func(ctx SpecContext) {
|
||||
task3Ch := make(chan error)
|
||||
task3 := TimedTask(task3Ch, 1*time.Second)
|
||||
gr.Add(runner.New("task3", task3, func() {
|
||||
task3Ch <- nil
|
||||
close(task3Ch)
|
||||
}))
|
||||
|
||||
ch2 := make(chan *runner.Result)
|
||||
gr.RunAsync(ch2)
|
||||
|
||||
// task3 finishes in 1 sec, so the rest should also be interrupted
|
||||
Eventually(ctx, ch2).Should(Receive())
|
||||
Eventually(ctx, ch2).Should(Receive())
|
||||
Eventually(ctx, ch2).Should(Receive())
|
||||
}, SpecTimeout(5*time.Second))
|
||||
|
||||
It("Double runAsync panics", func(ctx SpecContext) {
|
||||
ch2 := make(chan *runner.Result)
|
||||
Expect(func() {
|
||||
gr.RunAsync(ch2)
|
||||
gr.RunAsync(ch2)
|
||||
}).To(Panic())
|
||||
}, SpecTimeout(5*time.Second))
|
||||
|
||||
It("Interrupt async", func(ctx SpecContext) {
|
||||
task3Ch := make(chan error)
|
||||
task3 := TimedTask(task3Ch, 6*time.Second)
|
||||
gr.Add(runner.New("task3", task3, func() {
|
||||
task3Ch <- nil
|
||||
close(task3Ch)
|
||||
}))
|
||||
|
||||
ch2 := make(chan *runner.Result)
|
||||
gr.RunAsync(ch2)
|
||||
gr.Interrupt()
|
||||
|
||||
// tasks will be interrupted
|
||||
Eventually(ctx, ch2).Should(Receive())
|
||||
Eventually(ctx, ch2).Should(Receive())
|
||||
Eventually(ctx, ch2).Should(Receive())
|
||||
}, SpecTimeout(5*time.Second))
|
||||
|
||||
It("Interrupt async group timeout reached", func(ctx SpecContext) {
|
||||
gr := runner.NewGroup(runner.WithInterruptDuration(2 * time.Second))
|
||||
|
||||
gr.Add(runner.New("task1", func() error {
|
||||
time.Sleep(6 * time.Second)
|
||||
return nil
|
||||
}, func() {
|
||||
}))
|
||||
|
||||
gr.Add(runner.New("task2", func() error {
|
||||
time.Sleep(6 * time.Second)
|
||||
return nil
|
||||
}, func() {
|
||||
}))
|
||||
|
||||
ch2 := make(chan *runner.Result)
|
||||
gr.RunAsync(ch2)
|
||||
gr.Interrupt()
|
||||
|
||||
// group timeout will be reached after 2 extra seconds
|
||||
Eventually(ctx, ch2).Should(Receive(Equal(&runner.Result{RunnerID: "_unknown_", RunnerError: runner.NewGroupTimeoutError(2 * time.Second)})))
|
||||
Eventually(ctx, ch2).Should(Receive(Equal(&runner.Result{RunnerID: "_unknown_", RunnerError: runner.NewGroupTimeoutError(2 * time.Second)})))
|
||||
}, SpecTimeout(5*time.Second))
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,30 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultInterruptDuration is the default value for the `WithInterruptDuration`
|
||||
// for the "regular" runners. This global value can be adjusted if needed.
|
||||
DefaultInterruptDuration = 10 * time.Second
|
||||
// DefaultGroupInterruptDuration is the default value for the `WithInterruptDuration`
|
||||
// for the group runners. This global value can be adjusted if needed.
|
||||
DefaultGroupInterruptDuration = 15 * time.Second
|
||||
)
|
||||
|
||||
// Option defines a single option function.
|
||||
type Option func(o *Options)
|
||||
|
||||
// Options defines the available options for this package.
|
||||
type Options struct {
|
||||
InterruptDuration time.Duration
|
||||
}
|
||||
|
||||
// WithInterruptDuration provides a function to set the interrupt
|
||||
// duration option.
|
||||
func WithInterruptDuration(val time.Duration) Option {
|
||||
return func(o *Options) {
|
||||
o.InterruptDuration = val
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,201 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Runner represents the one executing a long running task, such as a server
|
||||
// or a service.
|
||||
// The ID of the runner is public to make identification easier, and the
|
||||
// Result that it will generated will contain the same ID, so we can
|
||||
// know which runner provided which result.
|
||||
//
|
||||
// Runners are intended to be used only once. Reusing them isn't possible.
|
||||
// You'd need to create a new runner if you want to rerun the same task.
|
||||
type Runner struct {
|
||||
ID string
|
||||
interruptDur time.Duration
|
||||
fn Runable
|
||||
interrupt Stopper
|
||||
running atomic.Bool
|
||||
interrupted atomic.Bool
|
||||
interruptedCh chan time.Duration
|
||||
finished chan struct{}
|
||||
}
|
||||
|
||||
// New will create a new runner.
|
||||
// The runner will be created with the provided id (the id must be unique,
|
||||
// otherwise undefined behavior might occur), and will run the provided
|
||||
// runable task, using the "interrupt" function to stop that task if needed.
|
||||
//
|
||||
// The interrupt duration, which can be set through the `WithInterruptDuration`
|
||||
// option, will be used to ensure the runner doesn't block forever. If the
|
||||
// option isn't supplied, the default value (10 secs) will be used.
|
||||
// The interrupt duration will be used to start a timeout when the
|
||||
// runner gets interrupted (either the context of the `Run` method is done
|
||||
// or this runner's `Interrupt` method is called). If the timeout is reached,
|
||||
// a timeout result will be returned instead of whatever result the task should
|
||||
// be returning.
|
||||
//
|
||||
// Note that it's your responsibility to provide a proper stopper for the task.
|
||||
// The runner will just call that method assuming it will be enough to
|
||||
// eventually stop the task at some point.
|
||||
func New(id string, fn Runable, interrupt Stopper, opts ...Option) *Runner {
|
||||
options := Options{
|
||||
InterruptDuration: DefaultInterruptDuration,
|
||||
}
|
||||
|
||||
for _, o := range opts {
|
||||
o(&options)
|
||||
}
|
||||
|
||||
return &Runner{
|
||||
ID: id,
|
||||
interruptDur: options.InterruptDuration,
|
||||
fn: fn,
|
||||
interrupt: interrupt,
|
||||
interruptedCh: make(chan time.Duration, 1),
|
||||
finished: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Run will execute the task associated to this runner in a synchronous way.
|
||||
// The task will be spawned in a new goroutine, and the current thread will
|
||||
// wait until the task finishes.
|
||||
//
|
||||
// The task will finish "naturally". The stopper will be called in the
|
||||
// following ways:
|
||||
// - Manually calling this runner's `Interrupt` method
|
||||
// - When the provided context is done
|
||||
// As said, it's expected that calling the provided stopper will be enough to
|
||||
// make the task to eventually complete.
|
||||
//
|
||||
// Once the task finishes, the result will be returned.
|
||||
// When the context is done, or if the runner is interrupted, a timeout will
|
||||
// start using the provided "interrupt duration". If this timeout is reached,
|
||||
// a timeout result will be returned instead of the one from the task. This is
|
||||
// intended to prevent blocking the main thread indefinitely. A suitable
|
||||
// duration should be used depending on the task, usually 5, 10 or 30 secs
|
||||
//
|
||||
// Some nice things you can do:
|
||||
// - Use signal.NotifyContext(...) to call the stopper and provide a clean
|
||||
// shutdown procedure when an OS signal is received
|
||||
// - Use context.WithDeadline(...) or context.WithTimeout(...) to run the task
|
||||
// for a limited time
|
||||
func (r *Runner) Run(ctx context.Context) *Result {
|
||||
if !r.running.CompareAndSwap(false, true) {
|
||||
// If not swapped, the task is already running.
|
||||
// Running the same task multiple times is a bug, so we panic
|
||||
panic("Runner with id " + r.ID + " was running twice")
|
||||
}
|
||||
|
||||
ch := make(chan *Result)
|
||||
|
||||
go r.doTask(ch, true)
|
||||
|
||||
select {
|
||||
case result := <-ch:
|
||||
return result
|
||||
case <-ctx.Done():
|
||||
r.Interrupt()
|
||||
return <-ch
|
||||
}
|
||||
}
|
||||
|
||||
// RunAsync will execute the task associated to this runner asynchronously.
|
||||
// The task will be spawned in a new goroutine and this method will finish.
|
||||
// The task's result will be written in the provided channel when it's
|
||||
// available, so you can wait for it if needed. It's up to you to decide
|
||||
// to use a blocking or non-blocking channel, but the task will always finish
|
||||
// before writing in the channel.
|
||||
//
|
||||
// To interrupt the running task, the only option is to call the `Interrupt`
|
||||
// method at some point.
|
||||
func (r *Runner) RunAsync(ch chan<- *Result) {
|
||||
if !r.running.CompareAndSwap(false, true) {
|
||||
// If not swapped, the task is already running.
|
||||
// Running the same task multiple times is a bug, so we panic
|
||||
panic("Runner with id " + r.ID + " was running twice")
|
||||
}
|
||||
|
||||
go r.doTask(ch, false)
|
||||
}
|
||||
|
||||
// Interrupt will execute the stopper function, which should notify the task
|
||||
// in order for it to finish.
|
||||
// The stopper will be called immediately, although it's expected the
|
||||
// consequences to take a while (task might need a while to stop)
|
||||
// A timeout will start using the provided "interrupt duration". Once that
|
||||
// timeout is reached, the task must provide a result with a timeout error.
|
||||
// Note that, even after returning the timeout result, the task could still
|
||||
// be being executed and consuming resource.
|
||||
// This method will be called only once. Further calls won't do anything
|
||||
func (r *Runner) Interrupt() {
|
||||
if r.interrupted.CompareAndSwap(false, true) {
|
||||
go func() {
|
||||
select {
|
||||
case <-r.Finished():
|
||||
// Task finished -> runner should be delivering the result
|
||||
case <-time.After(r.interruptDur):
|
||||
// timeout reached -> send it through the channel so our runner
|
||||
// can abort
|
||||
r.interruptedCh <- r.interruptDur
|
||||
close(r.interruptedCh)
|
||||
}
|
||||
}()
|
||||
r.interrupt()
|
||||
}
|
||||
}
|
||||
|
||||
// Finished will return a receive-only channel that can be used to know when
|
||||
// the task has finished but the result hasn't been made available yet. The
|
||||
// channel will be closed (without sending any message) when the task has finished.
|
||||
// This can be used specially with the `RunAsync` method when multiple runners
|
||||
// use the same channel: results could be waiting on your side of the channel
|
||||
func (r *Runner) Finished() <-chan struct{} {
|
||||
return r.finished
|
||||
}
|
||||
|
||||
// doTask will perform this runner's task and write the result in the provided
|
||||
// channel. The channel will be closed if requested.
|
||||
// A result will be provided when either the task finishes naturally or we
|
||||
// reach the timeout after being interrupted
|
||||
func (r *Runner) doTask(ch chan<- *Result, closeChan bool) {
|
||||
tmpCh := make(chan *Result, 1)
|
||||
|
||||
// spawn the task and return the result in a temporary channel
|
||||
go func(tmpCh chan *Result) {
|
||||
err := r.fn()
|
||||
|
||||
close(r.finished)
|
||||
|
||||
result := &Result{
|
||||
RunnerID: r.ID,
|
||||
RunnerError: err,
|
||||
}
|
||||
tmpCh <- result
|
||||
|
||||
close(tmpCh)
|
||||
}(tmpCh)
|
||||
|
||||
// wait for the result in the temporary channel or until we get the
|
||||
// interrupted signal
|
||||
var result *Result
|
||||
select {
|
||||
case d := <-r.interruptedCh:
|
||||
result = &Result{
|
||||
RunnerID: r.ID,
|
||||
RunnerError: NewTimeoutError(r.ID, d),
|
||||
}
|
||||
case result = <-tmpCh:
|
||||
// Just assign the received value, nothing else to do
|
||||
}
|
||||
|
||||
// send the result
|
||||
ch <- result
|
||||
if closeChan {
|
||||
close(ch)
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user