use plain pkg module

Signed-off-by: Jörn Friedrich Dreyer <jfd@butonic.de>
This commit is contained in:
Jörn Friedrich Dreyer
2025-01-13 15:54:00 +01:00
committed by Florian Schade
parent 259cbc2e56
commit b07b5a1149
841 changed files with 1383 additions and 1366 deletions
+1
View File
@@ -0,0 +1 @@
!config
+9
View File
@@ -0,0 +1,9 @@
# mockery configuration: generates expecter-style mocks for the listed interfaces.
with-expecter: true
# one snake_cased file per mocked interface
filename: "{{.InterfaceName | snakecase }}.go"
# mocks are placed next to the mocked package in a "mocks" subdirectory
dir: "{{.PackageName}}/mocks"
mockname: "{{.InterfaceName}}"
outpkg: "mocks"
packages:
  github.com/opencloud-eu/opencloud/ocis-pkg/oidc:
    interfaces:
      OIDCClient:
+40
View File
@@ -0,0 +1,40 @@
SHELL := bash
NAME := ocis-pkg

# guard against infinite recursion when make descends into subdirectories
include ../.make/recursion.mk

############ tooling ############
ifneq (, $(shell command -v go 2> /dev/null)) # suppress `command not found warnings` for non go targets in CI
include ../.bingo/Variables.mk
endif

############ go tooling ############
include ../.make/go.mk

############ release ############
include ../.make/release.mk

############ docs generate ############
# this package has no service config, so config docs generation is skipped
SKIP_CONFIG_DOCS_GENERATE = 1
include ../.make/docs.mk

# intentionally empty: nothing to generate for docs in this package
.PHONY: docs-generate
docs-generate:

############ generate ############
include ../.make/generate.mk

.PHONY: ci-go-generate
ci-go-generate: $(MOCKERY) # CI runs ci-node-generate automatically before this target
	$(MOCKERY)

# intentionally empty: this package contains no node/frontend code
.PHONY: ci-node-generate
ci-node-generate:

############ licenses ############
# intentionally empty: no node dependencies to check or save licenses for
.PHONY: ci-node-check-licenses
ci-node-check-licenses:

.PHONY: ci-node-save-licenses
ci-node-save-licenses:
+30
View File
@@ -0,0 +1,30 @@
package account
import (
"github.com/opencloud-eu/opencloud/pkg/log"
)
// Option defines a single option function.
type Option func(o *Options)

// Options defines the available options for this package.
type Options struct {
	// Logger to use for logging, must be set
	Logger log.Logger
	// JWTSecret is the jwt secret for the reva token manager
	JWTSecret string
}

// Logger provides a function to set the logger option.
func Logger(logger log.Logger) Option {
	return func(opts *Options) {
		opts.Logger = logger
	}
}

// JWTSecret provides a function to set the jwt secret option.
func JWTSecret(secret string) Option {
	return func(opts *Options) {
		opts.JWTSecret = secret
	}
}
+107
View File
@@ -0,0 +1,107 @@
// Package ast provides available ast nodes.
package ast
import (
"time"
)
// Node represents abstract syntax tree node
type Node interface {
	// Location returns the source location of the node, may be nil.
	Location() *Location
}

// Position represents a specific location in the source
type Position struct {
	// Line is the 1-based line number
	Line int
	// Column is the column within the line
	Column int
}

// Location represents the location of a node in the AST
type Location struct {
	Start  Position `json:"start"`
	End    Position `json:"end"`
	Source *string  `json:"source,omitempty"`
}

// Base contains shared node attributes
// each node should inherit from this
type Base struct {
	Loc *Location
}

// Location is the source location of the Node
func (b *Base) Location() *Location { return b.Loc }

// Ast represents the query - node structure as abstract syntax tree
type Ast struct {
	*Base
	// Nodes holds the top-level nodes of the parsed query
	Nodes []Node `json:"body"`
}

// StringNode represents a string value
type StringNode struct {
	*Base
	Key   string
	Value string
}

// BooleanNode represents a bool value
type BooleanNode struct {
	*Base
	Key   string
	Value bool
}

// DateTimeNode represents a time.Time value
type DateTimeNode struct {
	*Base
	Key      string
	// Operator qualifies how the timestamp is compared (e.g. <=, >=)
	Operator *OperatorNode
	Value    time.Time
}

// OperatorNode represents an operator value like
// AND, OR, NOT, =, <= ... and so on
type OperatorNode struct {
	*Base
	Value string
}

// GroupNode represents a collection of many grouped nodes
type GroupNode struct {
	*Base
	Key   string
	Nodes []Node
}
// NodeKey tries to return the node key; the empty string is
// returned for node types that do not carry a key.
func NodeKey(n Node) string {
	switch v := n.(type) {
	case *GroupNode:
		return v.Key
	case *StringNode:
		return v.Key
	case *BooleanNode:
		return v.Key
	case *DateTimeNode:
		return v.Key
	default:
		return ""
	}
}
// NodeValue tries to return the node value; for a GroupNode the
// contained child nodes are returned, for unknown node types the
// empty string is returned.
func NodeValue(n Node) interface{} {
	switch node := n.(type) {
	case *StringNode:
		return node.Value
	case *DateTimeNode:
		return node.Value
	case *BooleanNode:
		return node.Value
	case *GroupNode:
		return node.Nodes
	default:
		return ""
	}
}
+26
View File
@@ -0,0 +1,26 @@
// Package test provides shared test primitives for ast testing.
package test
import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/opencloud-eu/opencloud/pkg/ast"
)
// DiffAst returns a human-readable report of the differences between two values.
// By default it ignores the Base field of every ast node type; additional
// cmp options can be supplied by the caller.
func DiffAst(x, y interface{}, opts ...cmp.Option) string {
	ignoreBase := []cmp.Option{
		cmpopts.IgnoreFields(ast.Ast{}, "Base"),
		cmpopts.IgnoreFields(ast.StringNode{}, "Base"),
		cmpopts.IgnoreFields(ast.OperatorNode{}, "Base"),
		cmpopts.IgnoreFields(ast.GroupNode{}, "Base"),
		cmpopts.IgnoreFields(ast.BooleanNode{}, "Base"),
		cmpopts.IgnoreFields(ast.DateTimeNode{}, "Base"),
	}
	return cmp.Diff(x, y, append(opts, ignoreBase...)...)
}
+45
View File
@@ -0,0 +1,45 @@
package broker
import (
"errors"
"go-micro.dev/v4/broker"
)
// NoOp is a broker.Broker implementation where every operation is a no-op.
type NoOp struct{}

// Init does nothing and always succeeds.
func (NoOp) Init(_ ...broker.Option) error { return nil }

// Options returns empty broker options.
func (NoOp) Options() broker.Options { return broker.Options{} }

// Address returns an empty address.
func (NoOp) Address() string { return "" }

// Connect does nothing and always succeeds.
func (NoOp) Connect() error { return nil }

// Disconnect does nothing and always succeeds.
func (NoOp) Disconnect() error { return nil }

// Publish silently drops the message.
func (NoOp) Publish(_ string, _ *broker.Message, _ ...broker.PublishOption) error {
	return nil
}

// Subscribe is not supported and always returns an error.
func (NoOp) Subscribe(_ string, _ broker.Handler, _ ...broker.SubscribeOption) (broker.Subscriber, error) {
	return nil, errors.New("not implemented")
}

// String returns the broker name.
func (NoOp) String() string { return "NoOp" }

// NewBroker returns a no-op broker; all options are ignored.
func NewBroker(_ ...broker.Option) broker.Broker {
	return &NoOp{}
}
+35
View File
@@ -0,0 +1,35 @@
package capabilities
import (
"sync/atomic"
"github.com/cs3org/reva/v2/pkg/owncloud/ocs"
)
// allow the consuming part to change defaults, e.g., tests
var defaultCapabilities atomic.Pointer[ocs.Capabilities]

func init() { //nolint:gochecknoinits
	ResetDefault()
}

// ResetDefault resets the default [Capabilities] to the default values.
func ResetDefault() {
	logo := &ocs.CapabilitiesThemeLogo{
		// image types accepted for custom logo uploads
		PermittedFileTypes: map[string]string{
			".jpg":  "image/jpeg",
			".jpeg": "image/jpeg",
			".png":  "image/png",
			".gif":  "image/gif",
		},
	}
	defaultCapabilities.Store(&ocs.Capabilities{
		Theme: &ocs.CapabilitiesTheme{Logo: logo},
	})
}

// Default returns the default [Capabilities].
func Default() *ocs.Capabilities { return defaultCapabilities.Load() }
+27
View File
@@ -0,0 +1,27 @@
package checks
import (
"context"
"fmt"
"github.com/opencloud-eu/opencloud/pkg/handlers"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
// NewGRPCCheck checks the reachability of a grpc server.
//
// NOTE(review): grpc.NewClient creates the client lazily and does not
// dial the target, so this check currently validates the address and
// client construction only — it may not detect an unreachable server.
// Confirm whether an explicit Connect/state-wait is wanted here.
func NewGRPCCheck(address string) func(context.Context) error {
	return func(_ context.Context) error {
		// normalize/validate the configured address first
		address, err := handlers.FailSaveAddress(address)
		if err != nil {
			return err
		}
		conn, err := grpc.NewClient(address, grpc.WithTransportCredentials(insecure.NewCredentials()))
		if err != nil {
			return fmt.Errorf("could not connect to grpc server: %v", err)
		}
		// the connection is only needed for the check, release it immediately
		_ = conn.Close()
		return nil
	}
}
+35
View File
@@ -0,0 +1,35 @@
package checks
import (
"context"
"fmt"
"net/http"
"strings"
"time"
"github.com/opencloud-eu/opencloud/pkg/handlers"
)
// NewHTTPCheck checks the reachability of a http server.
//
// The returned check normalizes the configured URL (prefixing "http://"
// when no scheme is given) and issues a GET request with a 3 second
// timeout. Only transport-level reachability is verified; the HTTP
// status code is intentionally not inspected.
func NewHTTPCheck(url string) func(context.Context) error {
	return func(ctx context.Context) error {
		url, err := handlers.FailSaveAddress(url)
		if err != nil {
			return err
		}
		if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") {
			url = "http://" + url
		}
		c := http.Client{
			Timeout: 3 * time.Second,
		}
		// honor caller cancellation and deadlines in addition to the
		// client timeout (the context was previously ignored)
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
		if err != nil {
			return fmt.Errorf("could not build http request: %v", err)
		}
		resp, err := c.Do(req)
		if err != nil {
			return fmt.Errorf("could not connect to http server: %v", err)
		}
		_ = resp.Body.Close()
		return nil
	}
}
+23
View File
@@ -0,0 +1,23 @@
package checks
import (
"context"
"fmt"
"github.com/nats-io/nats.go"
)
// NewNatsCheck checks the reachability of a nats server.
//
// The returned check establishes a connection, verifies the connection
// reached the CONNECTED state and closes it again.
func NewNatsCheck(natsCluster string, options ...nats.Option) func(context.Context) error {
	return func(_ context.Context) error {
		n, err := nats.Connect(natsCluster, options...)
		if err != nil {
			return fmt.Errorf("could not connect to nats server: %v", err)
		}
		defer n.Close()
		// fmt.Errorf with a constant string trips go vet; include the
		// actual status for better diagnostics while we are at it
		if status := n.Status(); status != nats.CONNECTED {
			return fmt.Errorf("nats server not connected, status: %v", status)
		}
		return nil
	}
}
+31
View File
@@ -0,0 +1,31 @@
package checks
import (
"context"
"net"
"time"
"github.com/opencloud-eu/opencloud/pkg/handlers"
)
// NewTCPCheck returns a check that connects to a given tcp endpoint.
func NewTCPCheck(address string) func(context.Context) error {
	return func(_ context.Context) error {
		addr, err := handlers.FailSaveAddress(address)
		if err != nil {
			return err
		}
		conn, err := net.DialTimeout("tcp", addr, 3*time.Second)
		if err != nil {
			return err
		}
		// propagate a close failure so the check does not report success
		// on a broken connection teardown
		return conn.Close()
	}
}
+26
View File
@@ -0,0 +1,26 @@
package clihelper
import (
"github.com/opencloud-eu/opencloud/pkg/version"
"github.com/urfave/cli/v2"
)
// DefaultApp decorates the given cli.App with shared defaults:
// version metadata, author information and a hidden version flag.
func DefaultApp(app *cli.App) *cli.App {
	// version info
	app.Version = version.String
	app.Compiled = version.Compiled()

	// author info
	// NOTE(review): author/support info still references ownCloud while
	// the module lives under opencloud-eu — confirm whether this is
	// intentional or should be updated.
	app.Authors = []*cli.Author{
		{
			Name:  "ownCloud GmbH",
			Email: "support@owncloud.com",
		},
	}

	// disable global version flag
	// instead we provide the version command
	app.HideVersion = true

	return app
}
+127
View File
@@ -0,0 +1,127 @@
package config
import (
"github.com/opencloud-eu/opencloud/pkg/shared"
activitylog "github.com/opencloud-eu/opencloud/services/activitylog/pkg/config"
antivirus "github.com/opencloud-eu/opencloud/services/antivirus/pkg/config"
appProvider "github.com/opencloud-eu/opencloud/services/app-provider/pkg/config"
appRegistry "github.com/opencloud-eu/opencloud/services/app-registry/pkg/config"
audit "github.com/opencloud-eu/opencloud/services/audit/pkg/config"
authapp "github.com/opencloud-eu/opencloud/services/auth-app/pkg/config"
authbasic "github.com/opencloud-eu/opencloud/services/auth-basic/pkg/config"
authbearer "github.com/opencloud-eu/opencloud/services/auth-bearer/pkg/config"
authmachine "github.com/opencloud-eu/opencloud/services/auth-machine/pkg/config"
authservice "github.com/opencloud-eu/opencloud/services/auth-service/pkg/config"
clientlog "github.com/opencloud-eu/opencloud/services/clientlog/pkg/config"
collaboration "github.com/opencloud-eu/opencloud/services/collaboration/pkg/config"
eventhistory "github.com/opencloud-eu/opencloud/services/eventhistory/pkg/config"
frontend "github.com/opencloud-eu/opencloud/services/frontend/pkg/config"
gateway "github.com/opencloud-eu/opencloud/services/gateway/pkg/config"
graph "github.com/opencloud-eu/opencloud/services/graph/pkg/config"
groups "github.com/opencloud-eu/opencloud/services/groups/pkg/config"
idm "github.com/opencloud-eu/opencloud/services/idm/pkg/config"
idp "github.com/opencloud-eu/opencloud/services/idp/pkg/config"
invitations "github.com/opencloud-eu/opencloud/services/invitations/pkg/config"
nats "github.com/opencloud-eu/opencloud/services/nats/pkg/config"
notifications "github.com/opencloud-eu/opencloud/services/notifications/pkg/config"
ocdav "github.com/opencloud-eu/opencloud/services/ocdav/pkg/config"
ocm "github.com/opencloud-eu/opencloud/services/ocm/pkg/config"
ocs "github.com/opencloud-eu/opencloud/services/ocs/pkg/config"
policies "github.com/opencloud-eu/opencloud/services/policies/pkg/config"
postprocessing "github.com/opencloud-eu/opencloud/services/postprocessing/pkg/config"
proxy "github.com/opencloud-eu/opencloud/services/proxy/pkg/config"
search "github.com/opencloud-eu/opencloud/services/search/pkg/config"
settings "github.com/opencloud-eu/opencloud/services/settings/pkg/config"
sharing "github.com/opencloud-eu/opencloud/services/sharing/pkg/config"
sse "github.com/opencloud-eu/opencloud/services/sse/pkg/config"
storagepublic "github.com/opencloud-eu/opencloud/services/storage-publiclink/pkg/config"
storageshares "github.com/opencloud-eu/opencloud/services/storage-shares/pkg/config"
storagesystem "github.com/opencloud-eu/opencloud/services/storage-system/pkg/config"
storageusers "github.com/opencloud-eu/opencloud/services/storage-users/pkg/config"
thumbnails "github.com/opencloud-eu/opencloud/services/thumbnails/pkg/config"
userlog "github.com/opencloud-eu/opencloud/services/userlog/pkg/config"
users "github.com/opencloud-eu/opencloud/services/users/pkg/config"
web "github.com/opencloud-eu/opencloud/services/web/pkg/config"
webdav "github.com/opencloud-eu/opencloud/services/webdav/pkg/config"
webfinger "github.com/opencloud-eu/opencloud/services/webfinger/pkg/config"
)
// Mode selects how oCIS is run. Deprecated, kept for configuration
// backwards compatibility only (see Config.Mode).
type Mode int

// Runtime configures the oCIS runtime when running in supervised mode.
type Runtime struct {
	Port       string   `yaml:"port" env:"OC_RUNTIME_PORT" desc:"The TCP port at which oCIS will be available" introductionVersion:"pre5.0"`
	Host       string   `yaml:"host" env:"OC_RUNTIME_HOST" desc:"The host at which oCIS will be available" introductionVersion:"pre5.0"`
	Services   []string `yaml:"services" env:"OC_RUN_EXTENSIONS;OC_RUN_SERVICES" desc:"A comma-separated list of service names. Will start only the listed services." introductionVersion:"pre5.0"`
	Disabled   []string `yaml:"disabled_services" env:"OC_EXCLUDE_RUN_SERVICES" desc:"A comma-separated list of service names. Will start all default services except of the ones listed. Has no effect when OC_RUN_SERVICES is set." introductionVersion:"pre5.0"`
	Additional []string `yaml:"add_services" env:"OC_ADD_RUN_SERVICES" desc:"A comma-separated list of service names. Will add the listed services to the default configuration. Has no effect when OC_RUN_SERVICES is set. Note that one can add services not started by the default list and exclude services from the default list by using both envvars at the same time." introductionVersion:"pre5.0"`
}

// Config combines all available configuration parts.
type Config struct {
	// shared settings inherited by every service
	*shared.Commons `yaml:"shared"`

	Tracing        *shared.Tracing       `yaml:"tracing"`
	Log            *shared.Log           `yaml:"log"`
	Cache          *shared.Cache         `yaml:"cache"`
	GRPCClientTLS  *shared.GRPCClientTLS `yaml:"grpc_client_tls"`
	GRPCServiceTLS *shared.GRPCServiceTLS `yaml:"grpc_service_tls"`
	HTTPServiceTLS shared.HTTPServiceTLS `yaml:"http_service_tls"`

	Reva *shared.Reva `yaml:"reva"`

	Mode Mode // DEPRECATED
	// File is the path of the loaded configuration file, if any
	File string
	OcisURL string `yaml:"ocis_url" env:"OC_URL" desc:"URL, where oCIS is reachable for users." introductionVersion:"pre5.0"`

	Registry          string               `yaml:"registry"`
	TokenManager      *shared.TokenManager `yaml:"token_manager"`
	MachineAuthAPIKey string               `yaml:"machine_auth_api_key" env:"OC_MACHINE_AUTH_API_KEY" desc:"Machine auth API key used to validate internal requests necessary for the access to resources from other services." introductionVersion:"pre5.0"`
	TransferSecret    string               `yaml:"transfer_secret" env:"OC_TRANSFER_SECRET" desc:"Transfer secret for signing file up- and download requests." introductionVersion:"pre5.0"`
	SystemUserID      string               `yaml:"system_user_id" env:"OC_SYSTEM_USER_ID" desc:"ID of the oCIS storage-system system user. Admins need to set the ID for the storage-system system user in this config option which is then used to reference the user. Any reasonable long string is possible, preferably this would be an UUIDv4 format." introductionVersion:"pre5.0"`
	SystemUserAPIKey  string               `yaml:"system_user_api_key" env:"OC_SYSTEM_USER_API_KEY" desc:"API key for the storage-system system user." introductionVersion:"pre5.0"`
	AdminUserID       string               `yaml:"admin_user_id" env:"OC_ADMIN_USER_ID" desc:"ID of a user, that should receive admin privileges. Consider that the UUID can be encoded in some LDAP deployment configurations like in .ldif files. These need to be decoded beforehand." introductionVersion:"pre5.0"`
	Runtime           Runtime              `yaml:"runtime"`

	// per-service configuration sections
	Activitylog       *activitylog.Config    `yaml:"activitylog"`
	Antivirus         *antivirus.Config      `yaml:"antivirus"`
	AppProvider       *appProvider.Config    `yaml:"app_provider"`
	AppRegistry       *appRegistry.Config    `yaml:"app_registry"`
	Audit             *audit.Config          `yaml:"audit"`
	AuthApp           *authapp.Config        `yaml:"auth_app"`
	AuthBasic         *authbasic.Config      `yaml:"auth_basic"`
	AuthBearer        *authbearer.Config     `yaml:"auth_bearer"`
	AuthMachine       *authmachine.Config    `yaml:"auth_machine"`
	AuthService       *authservice.Config    `yaml:"auth_service"`
	Clientlog         *clientlog.Config      `yaml:"clientlog"`
	Collaboration     *collaboration.Config  `yaml:"collaboration"`
	EventHistory      *eventhistory.Config   `yaml:"eventhistory"`
	Frontend          *frontend.Config       `yaml:"frontend"`
	Gateway           *gateway.Config        `yaml:"gateway"`
	Graph             *graph.Config          `yaml:"graph"`
	Groups            *groups.Config         `yaml:"groups"`
	IDM               *idm.Config            `yaml:"idm"`
	IDP               *idp.Config            `yaml:"idp"`
	Invitations       *invitations.Config    `yaml:"invitations"`
	Nats              *nats.Config           `yaml:"nats"`
	Notifications     *notifications.Config  `yaml:"notifications"`
	OCDav             *ocdav.Config          `yaml:"ocdav"`
	OCM               *ocm.Config            `yaml:"ocm"`
	OCS               *ocs.Config            `yaml:"ocs"`
	Postprocessing    *postprocessing.Config `yaml:"postprocessing"`
	Policies          *policies.Config       `yaml:"policies"`
	Proxy             *proxy.Config          `yaml:"proxy"`
	Settings          *settings.Config       `yaml:"settings"`
	Sharing           *sharing.Config        `yaml:"sharing"`
	SSE               *sse.Config            `yaml:"sse"`
	StorageSystem     *storagesystem.Config  `yaml:"storage_system"`
	StoragePublicLink *storagepublic.Config  `yaml:"storage_public"`
	StorageShares     *storageshares.Config  `yaml:"storage_shares"`
	StorageUsers      *storageusers.Config   `yaml:"storage_users"`
	Thumbnails        *thumbnails.Config     `yaml:"thumbnails"`
	Userlog           *userlog.Config        `yaml:"userlog"`
	Users             *users.Config          `yaml:"users"`
	Web               *web.Config            `yaml:"web"`
	WebDAV            *webdav.Config         `yaml:"webdav"`
	Webfinger         *webfinger.Config      `yaml:"webfinger"`
	Search            *search.Config         `yaml:"search"`
}
+13
View File
@@ -0,0 +1,13 @@
package config_test
import (
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// TestConfig wires the Ginkgo suite into the standard go test runner.
func TestConfig(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "Config Suite")
}
+16
View File
@@ -0,0 +1,16 @@
package config_test
import (
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/opencloud-eu/opencloud/pkg/config"
"gopkg.in/yaml.v2"
)
// Smoke test: the default configuration must be marshalable to YAML
// without errors (guards against unserializable fields creeping in).
var _ = Describe("Config", func() {
	It("Success generating the default config", func() {
		cfg := config.DefaultConfig()

		_, err := yaml.Marshal(cfg)
		Expect(err).To(BeNil())
	})
})
+30
View File
@@ -0,0 +1,30 @@
package configlog
import (
"fmt"
"os"
)
// Error logs the error to stdout; a nil error is silently ignored.
func Error(err error) {
	if err == nil {
		return
	}
	fmt.Println(err)
}
// ReturnError logs the error and returns it unchanged; nil passes
// through without output.
func ReturnError(err error) error {
	if err == nil {
		return nil
	}
	fmt.Println(err)
	return err
}
// ReturnFatal logs the error and terminates the process via os.Exit(1);
// it returns nil when no error is passed.
func ReturnFatal(err error) error {
	if err == nil {
		return nil
	}
	fmt.Println(err)
	os.Exit(1)
	return nil
}
+103
View File
@@ -0,0 +1,103 @@
package config
import (
"github.com/opencloud-eu/opencloud/pkg/shared"
activitylog "github.com/opencloud-eu/opencloud/services/activitylog/pkg/config/defaults"
antivirus "github.com/opencloud-eu/opencloud/services/antivirus/pkg/config/defaults"
appProvider "github.com/opencloud-eu/opencloud/services/app-provider/pkg/config/defaults"
appRegistry "github.com/opencloud-eu/opencloud/services/app-registry/pkg/config/defaults"
audit "github.com/opencloud-eu/opencloud/services/audit/pkg/config/defaults"
authapp "github.com/opencloud-eu/opencloud/services/auth-app/pkg/config/defaults"
authbasic "github.com/opencloud-eu/opencloud/services/auth-basic/pkg/config/defaults"
authbearer "github.com/opencloud-eu/opencloud/services/auth-bearer/pkg/config/defaults"
authmachine "github.com/opencloud-eu/opencloud/services/auth-machine/pkg/config/defaults"
authservice "github.com/opencloud-eu/opencloud/services/auth-service/pkg/config/defaults"
clientlog "github.com/opencloud-eu/opencloud/services/clientlog/pkg/config/defaults"
collaboration "github.com/opencloud-eu/opencloud/services/collaboration/pkg/config/defaults"
eventhistory "github.com/opencloud-eu/opencloud/services/eventhistory/pkg/config/defaults"
frontend "github.com/opencloud-eu/opencloud/services/frontend/pkg/config/defaults"
gateway "github.com/opencloud-eu/opencloud/services/gateway/pkg/config/defaults"
graph "github.com/opencloud-eu/opencloud/services/graph/pkg/config/defaults"
groups "github.com/opencloud-eu/opencloud/services/groups/pkg/config/defaults"
idm "github.com/opencloud-eu/opencloud/services/idm/pkg/config/defaults"
idp "github.com/opencloud-eu/opencloud/services/idp/pkg/config/defaults"
invitations "github.com/opencloud-eu/opencloud/services/invitations/pkg/config/defaults"
nats "github.com/opencloud-eu/opencloud/services/nats/pkg/config/defaults"
notifications "github.com/opencloud-eu/opencloud/services/notifications/pkg/config/defaults"
ocdav "github.com/opencloud-eu/opencloud/services/ocdav/pkg/config/defaults"
ocm "github.com/opencloud-eu/opencloud/services/ocm/pkg/config/defaults"
ocs "github.com/opencloud-eu/opencloud/services/ocs/pkg/config/defaults"
policies "github.com/opencloud-eu/opencloud/services/policies/pkg/config/defaults"
postprocessing "github.com/opencloud-eu/opencloud/services/postprocessing/pkg/config/defaults"
proxy "github.com/opencloud-eu/opencloud/services/proxy/pkg/config/defaults"
search "github.com/opencloud-eu/opencloud/services/search/pkg/config/defaults"
settings "github.com/opencloud-eu/opencloud/services/settings/pkg/config/defaults"
sharing "github.com/opencloud-eu/opencloud/services/sharing/pkg/config/defaults"
sse "github.com/opencloud-eu/opencloud/services/sse/pkg/config/defaults"
storagepublic "github.com/opencloud-eu/opencloud/services/storage-publiclink/pkg/config/defaults"
storageshares "github.com/opencloud-eu/opencloud/services/storage-shares/pkg/config/defaults"
storageSystem "github.com/opencloud-eu/opencloud/services/storage-system/pkg/config/defaults"
storageusers "github.com/opencloud-eu/opencloud/services/storage-users/pkg/config/defaults"
thumbnails "github.com/opencloud-eu/opencloud/services/thumbnails/pkg/config/defaults"
userlog "github.com/opencloud-eu/opencloud/services/userlog/pkg/config/defaults"
users "github.com/opencloud-eu/opencloud/services/users/pkg/config/defaults"
web "github.com/opencloud-eu/opencloud/services/web/pkg/config/defaults"
webdav "github.com/opencloud-eu/opencloud/services/webdav/pkg/config/defaults"
webfinger "github.com/opencloud-eu/opencloud/services/webfinger/pkg/config/defaults"
)
// DefaultConfig returns a Config populated with the built-in defaults
// for the runtime and every bundled service.
func DefaultConfig() *Config {
	return &Config{
		OcisURL: "https://localhost:9200",

		// supervised-mode runtime endpoint
		Runtime: Runtime{
			Port: "9250",
			Host: "localhost",
		},

		Reva: &shared.Reva{
			Address: "com.owncloud.api.gateway",
		},

		// each service ships its own defaults package
		Activitylog:       activitylog.DefaultConfig(),
		Antivirus:         antivirus.DefaultConfig(),
		AppProvider:       appProvider.DefaultConfig(),
		AppRegistry:       appRegistry.DefaultConfig(),
		Audit:             audit.DefaultConfig(),
		AuthApp:           authapp.DefaultConfig(),
		AuthBasic:         authbasic.DefaultConfig(),
		AuthBearer:        authbearer.DefaultConfig(),
		AuthMachine:       authmachine.DefaultConfig(),
		AuthService:       authservice.DefaultConfig(),
		Clientlog:         clientlog.DefaultConfig(),
		Collaboration:     collaboration.DefaultConfig(),
		EventHistory:      eventhistory.DefaultConfig(),
		Frontend:          frontend.DefaultConfig(),
		Gateway:           gateway.DefaultConfig(),
		Graph:             graph.DefaultConfig(),
		Groups:            groups.DefaultConfig(),
		IDM:               idm.DefaultConfig(),
		IDP:               idp.DefaultConfig(),
		Invitations:       invitations.DefaultConfig(),
		Nats:              nats.DefaultConfig(),
		Notifications:     notifications.DefaultConfig(),
		OCDav:             ocdav.DefaultConfig(),
		OCM:               ocm.DefaultConfig(),
		OCS:               ocs.DefaultConfig(),
		Postprocessing:    postprocessing.DefaultConfig(),
		Policies:          policies.DefaultConfig(),
		Proxy:             proxy.DefaultConfig(),
		Search:            search.DefaultConfig(),
		Settings:          settings.DefaultConfig(),
		Sharing:           sharing.DefaultConfig(),
		SSE:               sse.DefaultConfig(),
		StoragePublicLink: storagepublic.DefaultConfig(),
		StorageShares:     storageshares.DefaultConfig(),
		StorageSystem:     storageSystem.DefaultConfig(),
		StorageUsers:      storageusers.DefaultConfig(),
		Thumbnails:        thumbnails.DefaultConfig(),
		Userlog:           userlog.DefaultConfig(),
		Users:             users.DefaultConfig(),
		Web:               web.DefaultConfig(),
		WebDAV:            webdav.DefaultConfig(),
		Webfinger:         webfinger.DefaultConfig(),
	}
}
+75
View File
@@ -0,0 +1,75 @@
package defaults
import (
"log"
"os"
"path"
)
var (
	// switch between modes
	BaseDataPathType = "homedir" // or "path"
	// default data path
	BaseDataPathValue = "/var/lib/ocis"
)

// BaseDataPath returns the effective base data directory. The
// OC_BASE_DATA_PATH environment variable always takes precedence;
// otherwise the result depends on BaseDataPathType.
func BaseDataPath() string {
	// It is not nice to have hidden / secret configuration options
	// But how can we update the base path for every occurrence with a flagset option?
	// This is currently not possible and needs a new configuration concept
	p := os.Getenv("OC_BASE_DATA_PATH")
	if p != "" {
		return p
	}
	switch BaseDataPathType {
	case "homedir":
		dir, err := os.UserHomeDir()
		if err != nil {
			// fallback to BaseDataPathValue for users without home
			return BaseDataPathValue
		}
		return path.Join(dir, ".ocis")
	case "path":
		return BaseDataPathValue
	default:
		log.Fatalf("BaseDataPathType %s not found", BaseDataPathType)
		return ""
	}
}
var (
	// switch between modes
	BaseConfigPathType = "homedir" // or "path"
	// default config path
	BaseConfigPathValue = "/etc/ocis"
)

// BaseConfigPath resolves the base configuration directory. The
// OC_CONFIG_DIR environment variable always takes precedence;
// otherwise the result depends on BaseConfigPathType.
func BaseConfigPath() string {
	// It is not nice to have hidden / secret configuration options
	// But how can we update the base path for every occurrence with a flagset option?
	// This is currently not possible and needs a new configuration concept
	if p := os.Getenv("OC_CONFIG_DIR"); p != "" {
		return p
	}
	switch BaseConfigPathType {
	case "path":
		return BaseConfigPathValue
	case "homedir":
		home, err := os.UserHomeDir()
		if err != nil {
			// fallback to BaseConfigPathValue for users without home
			return BaseConfigPathValue
		}
		return path.Join(home, ".ocis", "config")
	default:
		log.Fatalf("BaseConfigPathType %s not found", BaseConfigPathType)
		return ""
	}
}
+21
View File
@@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2014 Joe Shaw
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
+88
View File
@@ -0,0 +1,88 @@
`envdecode` is a Go package for populating structs from environment
variables. It's basically a fork of https://github.com/joeshaw/envdecode,
but changed to support multiple environment variables (precedence).
`envdecode` uses struct tags to map environment variables to fields,
allowing you to use any names you want for environment variables.
`envdecode` will recurse into nested structs, including pointers to
nested structs, but it will not allocate new pointers to structs.
## API
Full API docs are available on
[godoc.org](https://godoc.org/github.com/opencloud-eu/opencloud/pkg/config/envdecode).
Define a struct with `env` struct tags:
```go
type Config struct {
Hostname string `env:"SERVER_HOSTNAME,default=localhost"`
Port uint16 `env:"HTTP_PORT;SERVER_PORT,default=8080"`
AWS struct {
ID string `env:"AWS_ACCESS_KEY_ID"`
Secret string `env:"AWS_SECRET_ACCESS_KEY,required"`
SnsTopics []string `env:"AWS_SNS_TOPICS"`
}
Timeout time.Duration `env:"TIMEOUT,default=1m,strict"`
}
```
Fields _must be exported_ (i.e. begin with a capital letter) in order
for `envdecode` to work with them. An error will be returned if a
struct with no exported fields is decoded (including one that contains
no `env` tags at all).
Default values may be provided by appending ",default=value" to the
struct tag. Required values may be marked by appending ",required" to the
struct tag. Strict values may be marked by appending ",strict" which will
return an error on Decode if there is an error while parsing.
Then call `envdecode.Decode`:
```go
var cfg Config
err := envdecode.Decode(&cfg)
```
If you want all fields to act `strict`, you may use `envdecode.StrictDecode`:
```go
var cfg Config
err := envdecode.StrictDecode(&cfg)
```
All parse errors will fail fast and return an error in this mode.
## Supported types
- Structs (and pointer to structs)
- Slices of defined types below, separated by semicolon
- `bool`
- `float32`, `float64`
- `int`, `int8`, `int16`, `int32`, `int64`
- `uint`, `uint8`, `uint16`, `uint32`, `uint64`
- `string`
- `time.Duration`, using the [`time.ParseDuration()` format](http://golang.org/pkg/time/#ParseDuration)
- `*url.URL`, using [`url.Parse()`](https://godoc.org/net/url#Parse)
- Types that implement a `Decoder` interface
## Custom `Decoder`
If you want a field to be decoded with custom behavior, you may implement the interface `Decoder` for the field type.
```go
type Config struct {
IPAddr IP `env:"IP_ADDR"`
}
type IP net.IP
// Decode implements the interface `envdecode.Decoder`
func (i *IP) Decode(repl string) error {
*i = net.ParseIP(repl)
return nil
}
```
`Decoder` is the interface implemented by an object that can decode an environment variable string representation of itself.
+436
View File
@@ -0,0 +1,436 @@
// Package envdecode is a package for populating structs from environment
// variables, using struct tags.
package envdecode
import (
"encoding"
"errors"
"fmt"
"log"
"net/url"
"os"
"reflect"
"sort"
"strconv"
"strings"
"time"
)
// ErrInvalidTarget indicates that the target value passed to
// Decode is invalid. Target must be a non-nil pointer to a struct.
var ErrInvalidTarget = errors.New("target must be non-nil pointer to struct that has at least one exported field with a valid env tag")

// ErrNoTargetFieldsAreSet indicates that decoding completed but not a
// single struct field was populated from the environment (typically
// because all fields are unexported or none carry an env tag).
var ErrNoTargetFieldsAreSet = errors.New("none of the target fields were set from environment variables")

// FailureFunc is called when an error is encountered during a MustDecode
// operation. It prints the error and terminates the process.
//
// This variable can be assigned to another function of the user-programmer's
// design, allowing for graceful recovery of the problem, such as loading
// from a backup configuration file.
var FailureFunc = func(err error) {
	log.Fatalf("envdecode: an error was encountered while decoding: %v\n", err)
}

// Decoder is the interface implemented by an object that can decode an
// environment variable string representation of itself.
type Decoder interface {
	Decode(string) error
}
// Decode environment variables into the provided target. The target
// must be a non-nil pointer to a struct. Fields in the struct must
// be exported, and tagged with an "env" struct tag with a value
// containing the name of the environment variable. An error is
// returned if there are no exported members tagged.
//
// Default values may be provided by appending ",default=value" to the
// struct tag. Required values may be marked by appending ",required"
// to the struct tag. It is an error to provide both "default" and
// "required". Strict values may be marked by appending ",strict" which
// will return an error on Decode if there is an error while parsing.
// If everything must be strict, consider using StrictDecode instead.
//
// All primitive types are supported, including bool, floating point,
// signed and unsigned integers, and string. Boolean and numeric
// types are decoded using the standard strconv Parse functions for
// those types. Structs and pointers to structs are decoded
// recursively. time.Duration is supported via the
// time.ParseDuration() function and *url.URL is supported via the
// url.Parse() function. Slices are supported for all above mentioned
// primitive types. Semicolon is used as delimiter in environment variables.
func Decode(target interface{}) error {
	n, err := decode(target, false)
	if err != nil {
		return err
	}
	if n == 0 {
		// nothing was set - the user probably did something wrong,
		// like leaving every field unexported or untagged
		return ErrNoTargetFieldsAreSet
	}
	return nil
}
// StrictDecode behaves like Decode but treats every field as if it were
// tagged ",strict": any parse failure aborts decoding with an error.
// Unlike Decode, it returns ErrInvalidTarget when no fields could be
// set at all.
func StrictDecode(target interface{}) error {
	n, err := decode(target, true)
	switch {
	case err != nil:
		return err
	case n == 0:
		// Nothing was decoded — most likely every field is unexported
		// or untagged.
		return ErrInvalidTarget
	}
	return nil
}
// decode does the actual work for Decode and StrictDecode. It walks the
// exported fields of the struct pointed to by target, fills them from
// the environment, and reports how many fields were set. The strict
// flag is the function-wide default; individual fields can additionally
// opt in via a ",strict" tag option.
func decode(target interface{}, strict bool) (int, error) {
	s := reflect.ValueOf(target)
	if s.Kind() != reflect.Ptr || s.IsNil() {
		return 0, ErrInvalidTarget
	}
	s = s.Elem()
	if s.Kind() != reflect.Struct {
		return 0, ErrInvalidTarget
	}
	t := s.Type()
	setFieldCount := 0
	for i := 0; i < s.NumField(); i++ {
		// Localize the umbrella `strict` value to the specific field.
		strict := strict
		f := s.Field(i)
		switch f.Kind() {
		case reflect.Ptr:
			// Only descend into pointers that point at structs. A nil
			// pointer has an invalid Elem (Kind Invalid), so it falls
			// through the break and is skipped.
			if f.Elem().Kind() != reflect.Struct {
				break
			}
			f = f.Elem()
			fallthrough
		case reflect.Struct:
			if !f.Addr().CanInterface() {
				continue
			}
			ss := f.Addr().Interface()
			_, custom := ss.(Decoder)
			// Structs implementing Decoder are handled below via their
			// own Decode method instead of recursive field decoding.
			if custom {
				break
			}
			n, err := decode(ss, strict)
			if err != nil {
				return 0, err
			}
			setFieldCount += n
		}
		if !f.CanSet() {
			continue
		}
		tag := t.Field(i).Tag.Get("env")
		if tag == "" {
			continue
		}
		// parts[0] may name several alternative variables separated by
		// ';'; the last one that is set in the environment wins.
		parts := strings.Split(tag, ",")
		overrides := strings.Split(parts[0], `;`)
		var env string
		var envSet bool
		for _, override := range overrides {
			if v, set := os.LookupEnv(override); set {
				env = v
				envSet = true
			}
		}
		required := false
		hasDefault := false
		defaultValue := ""
		for _, o := range parts[1:] {
			if !required {
				required = strings.HasPrefix(o, "required")
			}
			if strings.HasPrefix(o, "default=") {
				hasDefault = true
				defaultValue = o[8:]
			}
			if !strict {
				strict = strings.HasPrefix(o, "strict")
			}
		}
		if required && hasDefault {
			panic(`envdecode: "default" and "required" may not be specified in the same annotation`)
		}
		if !envSet && required {
			return 0, fmt.Errorf("the environment variable \"%s\" is missing", parts[0])
		}
		if !envSet {
			env = defaultValue
		}
		// Unset variable and empty default: leave the field untouched
		// and do not count it as set.
		if !envSet && env == "" {
			continue
		}
		setFieldCount++
		// Custom decoding takes priority: Decoder first, then
		// encoding.TextUnmarshaler, then slices, then primitives.
		unmarshaler, implementsUnmarshaler := f.Addr().Interface().(encoding.TextUnmarshaler)
		decoder, implmentsDecoder := f.Addr().Interface().(Decoder)
		if implmentsDecoder {
			if err := decoder.Decode(env); err != nil {
				return 0, err
			}
		} else if implementsUnmarshaler {
			if err := unmarshaler.UnmarshalText([]byte(env)); err != nil {
				return 0, err
			}
		} else if f.Kind() == reflect.Slice {
			if err := decodeSlice(&f, env); err != nil {
				return 0, err
			}
		} else {
			// Primitive parse errors are only fatal in strict mode;
			// otherwise the field silently keeps its zero value.
			if err := decodePrimitiveType(&f, env); err != nil && strict {
				return 0, err
			}
		}
	}
	return setFieldCount, nil
}
// decodeSlice fills the slice value f from env. Elements are separated
// by commas; empty elements are dropped and the remaining ones are
// whitespace-trimmed before being decoded as primitives.
func decodeSlice(f *reflect.Value, env string) error {
	var elems []string
	for _, raw := range strings.Split(env, ",") {
		// Filter on the raw part (before trimming), matching the
		// historic behavior for whitespace-only elements.
		if raw == "" {
			continue
		}
		elems = append(elems, strings.TrimSpace(raw))
	}

	out := reflect.MakeSlice(f.Type(), len(elems), len(elems))
	for i, s := range elems {
		item := out.Index(i)
		if err := decodePrimitiveType(&item, s); err != nil {
			return err
		}
	}

	f.Set(out)
	return nil
}
func decodePrimitiveType(f *reflect.Value, env string) error {
switch f.Kind() {
case reflect.Bool:
v, err := strconv.ParseBool(env)
if err != nil {
return err
}
f.SetBool(v)
case reflect.Float32, reflect.Float64:
bits := f.Type().Bits()
v, err := strconv.ParseFloat(env, bits)
if err != nil {
return err
}
f.SetFloat(v)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if t := f.Type(); t.PkgPath() == "time" && t.Name() == "Duration" {
v, err := time.ParseDuration(env)
if err != nil {
return err
}
f.SetInt(int64(v))
} else {
bits := f.Type().Bits()
v, err := strconv.ParseInt(env, 0, bits)
if err != nil {
return err
}
f.SetInt(v)
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
bits := f.Type().Bits()
v, err := strconv.ParseUint(env, 0, bits)
if err != nil {
return err
}
f.SetUint(v)
case reflect.String:
f.SetString(env)
case reflect.Ptr:
if t := f.Type().Elem(); t.Kind() == reflect.Struct && t.PkgPath() == "net/url" && t.Name() == "URL" {
v, err := url.Parse(env)
if err != nil {
return err
}
f.Set(reflect.ValueOf(v))
}
}
return nil
}
// MustDecode calls Decode and terminates the process if any errors
// are encountered.
//
// Termination happens through FailureFunc, which by default calls
// log.Fatalf; assign FailureFunc to recover gracefully instead.
func MustDecode(target interface{}) {
	if err := Decode(target); err != nil {
		FailureFunc(err)
	}
}
// MustStrictDecode calls StrictDecode and terminates the process if any
// errors are encountered.
//
// Termination happens through FailureFunc, which by default calls
// log.Fatalf; assign FailureFunc to recover gracefully instead.
func MustStrictDecode(target interface{}) {
	if err := StrictDecode(target); err != nil {
		FailureFunc(err)
	}
}
// ConfigInfo describes a single "env"-tagged struct field as resolved
// by Export: which field it is, which environment variable feeds it,
// and the final value after decoding.
type ConfigInfo struct {
	Field        string // dotted path of the struct field (e.g. "Nested.String")
	EnvVar       string // primary environment variable name from the tag
	Value        string // final field value rendered as a string
	DefaultValue string // value of the "default=" tag option, if any
	HasDefault   bool   // whether a "default=" option was present
	Required     bool   // whether the "required" option was present
	UsesEnv      bool   // whether the variable is non-empty in the environment
}
// ConfigInfoSlice implements sort.Interface over []*ConfigInfo,
// ordering entries by environment variable name.
type ConfigInfoSlice []*ConfigInfo

// Less reports whether entry i sorts before entry j by EnvVar.
func (c ConfigInfoSlice) Less(i, j int) bool {
	return c[i].EnvVar < c[j].EnvVar
}

// Len returns the number of entries.
func (c ConfigInfoSlice) Len() int {
	return len(c)
}

// Swap exchanges the entries at i and j.
func (c ConfigInfoSlice) Swap(i, j int) {
	c[i], c[j] = c[j], c[i]
}
// Export returns the final configuration metadata for target — one
// ConfigInfo per "env"-tagged field — sorted by environment variable
// name. target must be a non-nil pointer to a struct; ErrInvalidTarget
// is returned for other targets and for structs without any tags.
func Export(target interface{}) ([]*ConfigInfo, error) {
	s := reflect.ValueOf(target)
	if s.Kind() != reflect.Ptr || s.IsNil() {
		return nil, ErrInvalidTarget
	}
	cfg := []*ConfigInfo{}
	s = s.Elem()
	if s.Kind() != reflect.Struct {
		return nil, ErrInvalidTarget
	}
	t := s.Type()
	for i := 0; i < s.NumField(); i++ {
		f := s.Field(i)
		fName := t.Field(i).Name
		// For pointer fields inspect the pointed-to value; a nil
		// pointer yields an invalid (non-struct) fElem and is skipped.
		fElem := f
		if f.Kind() == reflect.Ptr {
			fElem = f.Elem()
		}
		if fElem.Kind() == reflect.Struct {
			ss := fElem.Addr().Interface()
			subCfg, err := Export(ss)
			// ErrInvalidTarget from the recursive call means the
			// nested struct had no tagged fields and is skipped.
			// NOTE(review): any other error is also swallowed here
			// (subCfg is nil, so the loop is a no-op) — confirm this
			// is intentional.
			if err != ErrInvalidTarget {
				f = fElem
				for _, v := range subCfg {
					// Prefix nested entries with the parent field name.
					v.Field = fmt.Sprintf("%s.%s", fName, v.Field)
					cfg = append(cfg, v)
				}
			}
		}
		tag := t.Field(i).Tag.Get("env")
		if tag == "" {
			continue
		}
		parts := strings.Split(tag, ",")
		ci := &ConfigInfo{
			Field:   fName,
			EnvVar:  parts[0],
			UsesEnv: os.Getenv(parts[0]) != "",
		}
		for _, o := range parts[1:] {
			if strings.HasPrefix(o, "default=") {
				ci.HasDefault = true
				ci.DefaultValue = o[8:]
			} else if strings.HasPrefix(o, "required") {
				ci.Required = true
			}
		}
		// Render the current field value as a string, preferring the
		// type's own Stringer implementation when available.
		if f.Kind() == reflect.Ptr && f.IsNil() {
			ci.Value = ""
		} else if stringer, ok := f.Interface().(fmt.Stringer); ok {
			ci.Value = stringer.String()
		} else {
			switch f.Kind() {
			case reflect.Bool:
				ci.Value = strconv.FormatBool(f.Bool())
			case reflect.Float32, reflect.Float64:
				bits := f.Type().Bits()
				ci.Value = strconv.FormatFloat(f.Float(), 'f', -1, bits)
			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
				ci.Value = strconv.FormatInt(f.Int(), 10)
			case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
				ci.Value = strconv.FormatUint(f.Uint(), 10)
			case reflect.String:
				ci.Value = f.String()
			case reflect.Slice:
				ci.Value = fmt.Sprintf("%v", f.Interface())
			default:
				// Unable to determine string format for value
				return nil, ErrInvalidTarget
			}
		}
		cfg = append(cfg, ci)
	}
	// No configuration tags found, assume invalid input
	if len(cfg) == 0 {
		return nil, ErrInvalidTarget
	}
	sort.Sort(ConfigInfoSlice(cfg))
	return cfg, nil
}
+790
View File
@@ -0,0 +1,790 @@
package envdecode
import (
"encoding/json"
"fmt"
"math"
"net/url"
"os"
"reflect"
"sort"
"strconv"
"testing"
"time"
)
// nested is a minimal struct used to verify that Decode recurses into
// struct-typed fields.
type nested struct {
	String string `env:"TEST_STRING"`
}
// testConfig is the kitchen-sink fixture for TestDecode, covering every
// supported field kind, slices, unset/invalid variables, custom
// Decoder/TextUnmarshaler implementations, and default values.
type testConfig struct {
	String   string        `env:"TEST_STRING"`
	Int64    int64         `env:"TEST_INT64"`
	Uint16   uint16        `env:"TEST_UINT16"`
	Float64  float64       `env:"TEST_FLOAT64"`
	Bool     bool          `env:"TEST_BOOL"`
	Duration time.Duration `env:"TEST_DURATION"`
	URL      *url.URL      `env:"TEST_URL"`

	StringSlice   []string        `env:"TEST_STRING_SLICE"`
	Int64Slice    []int64         `env:"TEST_INT64_SLICE"`
	Uint16Slice   []uint16        `env:"TEST_UINT16_SLICE"`
	Float64Slice  []float64       `env:"TEST_FLOAT64_SLICE"`
	BoolSlice     []bool          `env:"TEST_BOOL_SLICE"`
	DurationSlice []time.Duration `env:"TEST_DURATION_SLICE"`
	URLSlice      []*url.URL      `env:"TEST_URL_SLICE"`

	UnsetString   string        `env:"TEST_UNSET_STRING"`
	UnsetInt64    int64         `env:"TEST_UNSET_INT64"`
	UnsetDuration time.Duration `env:"TEST_UNSET_DURATION"`
	UnsetURL      *url.URL      `env:"TEST_UNSET_URL"`
	UnsetSlice    []string      `env:"TEST_UNSET_SLICE"`

	InvalidInt64 int64 `env:"TEST_INVALID_INT64"`

	UnusedField     string
	unexportedField string

	// Pointer to a non-struct primitive: decode ignores these.
	IgnoredPtr *bool `env:"TEST_BOOL"`

	Nested    nested
	NestedPtr *nested

	DecoderStruct    decoderStruct  `env:"TEST_DECODER_STRUCT"`
	DecoderStructPtr *decoderStruct `env:"TEST_DECODER_STRUCT_PTR"`
	DecoderString    decoderString  `env:"TEST_DECODER_STRING"`

	UnmarshalerNumber unmarshalerNumber `env:"TEST_UNMARSHALER_NUMBER"`

	// Unknown tag options (asdf=asdf) must be tolerated.
	DefaultInt      int           `env:"TEST_UNSET,asdf=asdf,default=1234"`
	DefaultSliceInt []int         `env:"TEST_UNSET,asdf=asdf,default=1"`
	DefaultDuration time.Duration `env:"TEST_UNSET,asdf=asdf,default=24h"`
	DefaultURL      *url.URL      `env:"TEST_UNSET,default=http://example.com"`
}
// testConfigNoSet has a single tagged field whose variable is never set.
type testConfigNoSet struct {
	Some string `env:"TEST_THIS_ENV_WILL_NOT_BE_SET"`
}

// testConfigRequired exercises the ",required" tag option.
type testConfigRequired struct {
	Required string `env:"TEST_REQUIRED,required"`
}

// testConfigRequiredDefault combines "required" and "default", which
// must panic when decoded.
type testConfigRequiredDefault struct {
	RequiredDefault string `env:"TEST_REQUIRED_DEFAULT,required,default=test"`
}

// testConfigOverride uses two alternative variable names separated by
// ';' plus a default.
type testConfigOverride struct {
	OverrideString string `env:"TEST_OVERRIDE_A;TEST_OVERRIDE_B,default=override_default"`
}

// testNoExportedFields has only unexported (and therefore undecodable)
// fields.
type testNoExportedFields struct {
	// following unexported fields are used for tests
	aString  string  `env:"TEST_STRING"`  //nolint:structcheck,unused
	anInt64  int64   `env:"TEST_INT64"`   //nolint:structcheck,unused
	aUint16  uint16  `env:"TEST_UINT16"`  //nolint:structcheck,unused
	aFloat64 float64 `env:"TEST_FLOAT64"` //nolint:structcheck,unused
	aBool    bool    `env:"TEST_BOOL"`    //nolint:structcheck,unused
}

// testNoTags has an exported field but no "env" tag.
type testNoTags struct {
	String string
}
// decoderStruct implements Decoder by unmarshalling the raw environment
// value as JSON into itself.
type decoderStruct struct {
	String string
}

// Decode implements the Decoder interface.
func (d *decoderStruct) Decode(env string) error {
	// Unmarshal directly into the receiver: the previous code passed
	// &d (a **decoderStruct), which json.Unmarshal tolerates but is a
	// needless extra level of indirection.
	return json.Unmarshal([]byte(env), d)
}
// decoderString implements Decoder by storing the rune-wise reversal of
// its input.
type decoderString string

// Decode reverses env rune by rune and stores the result.
//
// Fix: the loop bound must be the rune count len(r), not the byte count
// len(env). With multi-byte UTF-8 input the old byte-count bound indexed
// past the end of the rune slice (panic) or reversed incorrectly.
func (d *decoderString) Decode(env string) error {
	r := []rune(env)
	for i, j := 0, len(r)-1; i < j; i, j = i+1, j-1 {
		r[i], r[j] = r[j], r[i]
	}
	*d = decoderString(r)
	return nil
}
// unmarshalerNumber implements encoding.TextUnmarshaler by interpreting
// its textual representation as an octal number.
type unmarshalerNumber uint8

// UnmarshalText parses raw as a base-8 value that must fit in 8 bits.
func (o *unmarshalerNumber) UnmarshalText(raw []byte) error {
	n, err := strconv.ParseUint(string(raw), 8, 8) // octal, bounded to uint8
	if err == nil {
		*o = unmarshalerNumber(n)
	}
	return err
}
// TestDecode exercises the happy path of Decode across every supported
// field kind, slices, unset fields, custom decoders, defaults, and the
// semicolon-separated variable-name override mechanism.
func TestDecode(t *testing.T) {
	int64Val := int64(-(1 << 50))
	int64AsString := fmt.Sprintf("%d", int64Val)
	piAsString := fmt.Sprintf("%.48f", math.Pi)

	// Seed the environment with one variable per field under test.
	os.Setenv("TEST_STRING", "foo")
	os.Setenv("TEST_INT64", int64AsString)
	os.Setenv("TEST_UINT16", "60000")
	os.Setenv("TEST_FLOAT64", piAsString)
	os.Setenv("TEST_BOOL", "true")
	os.Setenv("TEST_DURATION", "10m")
	os.Setenv("TEST_URL", "https://example.com")
	os.Setenv("TEST_INVALID_INT64", "asdf")
	os.Setenv("TEST_STRING_SLICE", "foo,bar")
	os.Setenv("TEST_INT64_SLICE", int64AsString+","+int64AsString)
	os.Setenv("TEST_UINT16_SLICE", "60000,50000")
	os.Setenv("TEST_FLOAT64_SLICE", piAsString+","+piAsString)
	os.Setenv("TEST_BOOL_SLICE", "true, false, true")
	os.Setenv("TEST_DURATION_SLICE", "10m, 20m")
	os.Setenv("TEST_URL_SLICE", "https://example.com")
	os.Setenv("TEST_DECODER_STRUCT", "{\"string\":\"foo\"}")
	os.Setenv("TEST_DECODER_STRUCT_PTR", "{\"string\":\"foo\"}")
	os.Setenv("TEST_DECODER_STRING", "oof")
	os.Setenv("TEST_UNMARSHALER_NUMBER", "07")

	var tc testConfig
	tc.NestedPtr = &nested{}
	tc.DecoderStructPtr = &decoderStruct{}

	err := Decode(&tc)
	if err != nil {
		t.Fatal(err)
	}

	if tc.String != "foo" {
		t.Fatalf(`Expected "foo", got "%s"`, tc.String)
	}
	if tc.Int64 != -(1 << 50) {
		t.Fatalf("Expected %d, got %d", -(1 << 50), tc.Int64)
	}
	if tc.Uint16 != 60000 {
		t.Fatalf("Expected 60000, got %d", tc.Uint16)
	}
	if tc.Float64 != math.Pi {
		t.Fatalf("Expected %.48f, got %.48f", math.Pi, tc.Float64)
	}
	if !tc.Bool {
		t.Fatal("Expected true, got false")
	}
	duration, _ := time.ParseDuration("10m")
	if tc.Duration != duration {
		t.Fatalf("Expected %d, got %d", duration, tc.Duration)
	}
	if tc.URL == nil {
		t.Fatalf("Expected https://example.com, got nil")
	} else if tc.URL.String() != "https://example.com" {
		t.Fatalf("Expected https://example.com, got %s", tc.URL.String())
	}

	// Slice fields: values are comma-separated and whitespace-trimmed.
	expectedStringSlice := []string{"foo", "bar"}
	if !reflect.DeepEqual(tc.StringSlice, expectedStringSlice) {
		t.Fatalf("Expected %s, got %s", expectedStringSlice, tc.StringSlice)
	}
	expectedInt64Slice := []int64{int64Val, int64Val}
	if !reflect.DeepEqual(tc.Int64Slice, expectedInt64Slice) {
		t.Fatalf("Expected %#v, got %#v", expectedInt64Slice, tc.Int64Slice)
	}
	expectedUint16Slice := []uint16{60000, 50000}
	if !reflect.DeepEqual(tc.Uint16Slice, expectedUint16Slice) {
		t.Fatalf("Expected %#v, got %#v", expectedUint16Slice, tc.Uint16Slice)
	}
	expectedFloat64Slice := []float64{math.Pi, math.Pi}
	if !reflect.DeepEqual(tc.Float64Slice, expectedFloat64Slice) {
		t.Fatalf("Expected %#v, got %#v", expectedFloat64Slice, tc.Float64Slice)
	}
	expectedBoolSlice := []bool{true, false, true}
	if !reflect.DeepEqual(tc.BoolSlice, expectedBoolSlice) {
		t.Fatalf("Expected %#v, got %#v", expectedBoolSlice, tc.BoolSlice)
	}
	duration2, _ := time.ParseDuration("20m")
	expectedDurationSlice := []time.Duration{duration, duration2}
	if !reflect.DeepEqual(tc.DurationSlice, expectedDurationSlice) {
		t.Fatalf("Expected %s, got %s", expectedDurationSlice, tc.DurationSlice)
	}
	urlVal, _ := url.Parse("https://example.com")
	expectedURLSlice := []*url.URL{urlVal}
	if !reflect.DeepEqual(tc.URLSlice, expectedURLSlice) {
		t.Fatalf("Expected %s, got %s", expectedURLSlice, tc.URLSlice)
	}

	// Unset variables must leave zero values behind.
	if tc.UnsetString != "" {
		t.Fatal("Got non-empty string unexpectedly")
	}
	if tc.UnsetInt64 != 0 {
		t.Fatal("Got non-zero int unexpectedly")
	}
	if tc.UnsetDuration != time.Duration(0) {
		t.Fatal("Got non-zero time.Duration unexpectedly")
	}
	if tc.UnsetURL != nil {
		t.Fatal("Got non-zero *url.URL unexpectedly")
	}
	if len(tc.UnsetSlice) > 0 {
		t.Fatal("Got not-empty string slice unexpectedly")
	}
	// Non-strict decoding ignores the unparsable value.
	if tc.InvalidInt64 != 0 {
		t.Fatal("Got non-zero int unexpectedly")
	}
	if tc.UnusedField != "" {
		t.Fatal("Expected empty field")
	}
	if tc.unexportedField != "" {
		t.Fatal("Expected empty field")
	}
	if tc.IgnoredPtr != nil {
		t.Fatal("Expected nil pointer")
	}
	if tc.Nested.String != "foo" {
		t.Fatalf(`Expected "foo", got "%s"`, tc.Nested.String)
	}
	if tc.NestedPtr.String != "foo" {
		t.Fatalf(`Expected "foo", got "%s"`, tc.NestedPtr.String)
	}

	// Default values kick in for the TEST_UNSET-tagged fields.
	if tc.DefaultInt != 1234 {
		t.Fatalf("Expected 1234, got %d", tc.DefaultInt)
	}
	expectedDefaultSlice := []int{1}
	if !reflect.DeepEqual(tc.DefaultSliceInt, expectedDefaultSlice) {
		t.Fatalf("Expected %d, got %d", expectedDefaultSlice, tc.DefaultSliceInt)
	}
	defaultDuration, _ := time.ParseDuration("24h")
	if tc.DefaultDuration != defaultDuration {
		t.Fatalf("Expected %d, got %d", defaultDuration, tc.DefaultInt)
	}
	if tc.DefaultURL.String() != "http://example.com" {
		t.Fatalf("Expected http://example.com, got %s", tc.DefaultURL.String())
	}

	// Custom Decoder / TextUnmarshaler implementations.
	if tc.DecoderStruct.String != "foo" {
		t.Fatalf("Expected foo, got %s", tc.DecoderStruct.String)
	}
	if tc.DecoderStructPtr.String != "foo" {
		t.Fatalf("Expected foo, got %s", tc.DecoderStructPtr.String)
	}
	if tc.DecoderString != "foo" {
		t.Fatalf("Expected foo, got %s", tc.DecoderString)
	}
	if tc.UnmarshalerNumber != 07 {
		t.Fatalf("Expected 07, got %04o", tc.UnmarshalerNumber)
	}

	os.Setenv("TEST_REQUIRED", "required")
	var tcr testConfigRequired
	err = Decode(&tcr)
	if err != nil {
		t.Fatal(err)
	}
	if tcr.Required != "required" {
		t.Fatalf("Expected \"required\", got %s", tcr.Required)
	}
	_, err = Export(&tcr)
	if err != nil {
		t.Fatal(err)
	}

	// Override chain: default, then TEST_OVERRIDE_A, then
	// TEST_OVERRIDE_B (the last set variable wins).
	var tco testConfigOverride
	err = Decode(&tco)
	if err != nil {
		t.Fatal(err)
	}
	if tco.OverrideString != "override_default" {
		t.Fatalf(`Expected "override_default" but got %s`, tco.OverrideString)
	}

	os.Setenv("TEST_OVERRIDE_A", "override_a")
	tco = testConfigOverride{}
	err = Decode(&tco)
	if err != nil {
		t.Fatal(err)
	}
	if tco.OverrideString != "override_a" {
		t.Fatalf(`Expected "override_a" but got %s`, tco.OverrideString)
	}

	os.Setenv("TEST_OVERRIDE_B", "override_b")
	tco = testConfigOverride{}
	err = Decode(&tco)
	if err != nil {
		t.Fatal(err)
	}
	if tco.OverrideString != "override_b" {
		t.Fatalf(`Expected "override_b" but got %s`, tco.OverrideString)
	}
}
// TestDecodeErrors verifies the error contract of Decode: invalid
// targets, structs without usable fields, missing required variables,
// the FailureFunc hook, and the default+required panic.
//
// Fixes: the failure messages previously contained typos ("recieved",
// "occured") and one said "no unexported fields" where the case under
// test is a struct with ONLY unexported fields.
func TestDecodeErrors(t *testing.T) {
	var b bool
	err := Decode(&b)
	if err != ErrInvalidTarget {
		t.Fatal("Should have gotten an error decoding into a bool")
	}

	var tc testConfig
	err = Decode(tc) //nolint:govet
	if err != ErrInvalidTarget {
		t.Fatal("Should have gotten an error decoding into a non-pointer")
	}

	var tcp *testConfig
	err = Decode(tcp)
	if err != ErrInvalidTarget {
		t.Fatal("Should have gotten an error decoding to a nil pointer")
	}

	var tnt testNoTags
	err = Decode(&tnt)
	if err != ErrNoTargetFieldsAreSet {
		t.Fatal("Should have gotten an error decoding a struct with no tags")
	}

	var tcni testNoExportedFields
	err = Decode(&tcni)
	if err != ErrNoTargetFieldsAreSet {
		t.Fatal("Should have gotten an error decoding a struct with only unexported fields")
	}

	// A required variable that is absent must produce an error.
	var tcr testConfigRequired
	os.Clearenv()
	err = Decode(&tcr)
	if err == nil {
		t.Fatal("expected an error for a missing required variable, received nil")
	}

	var tcns testConfigNoSet
	err = Decode(&tcns)
	if err != ErrNoTargetFieldsAreSet {
		t.Fatal("Should have gotten an error decoding when no env variables are set")
	}

	// MustDecode routes failures through the overridable FailureFunc.
	missing := false
	FailureFunc = func(err error) {
		missing = true
	}
	MustDecode(&tcr)
	if !missing {
		t.Fatal("The FailureFunc should have been called but it was not")
	}

	// Combining "required" and "default" must panic; the deferred
	// recover keeps the test alive so reaching the Fatal below means
	// the panic never happened.
	var tcrd testConfigRequiredDefault
	defer func() {
		_ = recover()
	}()
	_ = Decode(&tcrd)
	t.Fatal("This should not have been reached. A panic should have occurred.")
}
// TestOnlyNested checks that "env" tags are sufficient at any nesting
// level: tags only in the inner struct, only in the outer struct, and
// the error case with none at all.
//
// Fix: the final failure message previously claimed ErrInvalidTarget
// although the assertion checks for ErrNoTargetFieldsAreSet.
func TestOnlyNested(t *testing.T) {
	os.Setenv("TEST_STRING", "foo")

	// No env vars in the outer level are ok, as long as they're
	// in the inner struct.
	var o struct {
		Inner nested
	}
	if err := Decode(&o); err != nil {
		t.Fatalf("Expected no error, got %s", err)
	}

	// No env vars in the inner levels are ok, as long as they're
	// in the outer struct.
	var o2 struct {
		Inner noConfig
		X     string `env:"TEST_STRING"`
	}
	if err := Decode(&o2); err != nil {
		t.Fatalf("Expected no error, got %s", err)
	}

	// No env vars in either outer or inner levels should result
	// in error
	var o3 struct {
		Inner noConfig
	}
	if err := Decode(&o3); err != ErrNoTargetFieldsAreSet {
		t.Fatalf("Expected ErrNoTargetFieldsAreSet, got %s", err)
	}
}
// ExampleDecode demonstrates basic usage with a plain field and a
// defaulted field. (Comments previously referenced TEST_* variables;
// the tags actually use EXAMPLE_*.)
func ExampleDecode() {
	type Example struct {
		// A string field, without any default
		String string `env:"EXAMPLE_STRING"`

		// A uint16 field, with a default value of 100
		Uint16 uint16 `env:"EXAMPLE_UINT16,default=100"`
	}

	os.Setenv("EXAMPLE_STRING", "an example!")

	var e Example
	if err := Decode(&e); err != nil {
		panic(err)
	}

	// If EXAMPLE_STRING is set, e.String will contain its value
	fmt.Println(e.String)

	// If EXAMPLE_UINT16 is set, e.Uint16 will contain its value.
	// Otherwise, it will contain the default value, 100.
	fmt.Println(e.Uint16)

	// Output:
	// an example!
	// 100
}
// ---- Export tests ----

// testConfigExport is the fixture for TestExport, covering plain
// fields, nested structs (by value and pointer), structs without tags,
// and required/default tag options.
type testConfigExport struct {
	String          string        `env:"TEST_STRING"`
	Int64           int64         `env:"TEST_INT64"`
	Uint16          uint16        `env:"TEST_UINT16"`
	Float64         float64       `env:"TEST_FLOAT64"`
	Bool            bool          `env:"TEST_BOOL"`
	Duration        time.Duration `env:"TEST_DURATION"`
	URL             *url.URL      `env:"TEST_URL"`
	StringSlice     []string      `env:"TEST_STRING_SLICE"`
	UnsetString     string        `env:"TEST_UNSET_STRING"`
	UnsetInt64      int64         `env:"TEST_UNSET_INT64"`
	UnsetDuration   time.Duration `env:"TEST_UNSET_DURATION"`
	UnsetURL        *url.URL      `env:"TEST_UNSET_URL"`
	UnusedField     string
	unexportedField string //nolint:structcheck,unused
	IgnoredPtr      *bool `env:"TEST_IGNORED_POINTER"`
	Nested          nestedConfigExport
	NestedPtr       *nestedConfigExportPointer
	NestedPtrUnset  *nestedConfigExportPointer
	NestedTwice     nestedTwiceConfig
	NoConfig        noConfig
	NoConfigPtr     *noConfig
	NoConfigPtrSet  *noConfig
	RequiredInt     int           `env:"TEST_REQUIRED_INT,required"`
	DefaultBool     bool          `env:"TEST_DEFAULT_BOOL,default=true"`
	DefaultInt      int           `env:"TEST_DEFAULT_INT,default=1234"`
	DefaultDuration time.Duration `env:"TEST_DEFAULT_DURATION,default=24h"`
	DefaultURL      *url.URL      `env:"TEST_DEFAULT_URL,default=http://example.com"`
	DefaultIntSet   int           `env:"TEST_DEFAULT_INT_SET,default=99"`
	DefaultIntSlice []int         `env:"TEST_DEFAULT_INT_SLICE,default=99"`
}
// nestedConfigExport is a tagged struct nested by value.
type nestedConfigExport struct {
	String string `env:"TEST_NESTED_STRING"`
}

// nestedConfigExportPointer is a tagged struct nested via pointer.
type nestedConfigExportPointer struct {
	String string `env:"TEST_NESTED_STRING_POINTER"`
}

// noConfig has no "env" tags at all and must be skipped by Export.
type noConfig struct {
	Int int
}

// nestedTwiceConfig nests a tagged struct two levels deep.
type nestedTwiceConfig struct {
	Nested nestedConfigInner
}

type nestedConfigInner struct {
	String string `env:"TEST_NESTED_TWICE_STRING"`
}

// testConfigStrict mixes per-field ",strict" options with implicit
// (non-strict) fields at the root and nested level.
type testConfigStrict struct {
	InvalidInt64Strict   int64 `env:"TEST_INVALID_INT64,strict,default=1"`
	InvalidInt64Implicit int64 `env:"TEST_INVALID_INT64_IMPLICIT,default=1"`
	Nested               struct {
		InvalidInt64Strict   int64 `env:"TEST_INVALID_INT64_NESTED,strict,required"`
		InvalidInt64Implicit int64 `env:"TEST_INVALID_INT64_NESTED_IMPLICIT,required"`
	}
}
// TestInvalidStrict tables through combinations of valid/invalid values
// for strict-tagged and implicit fields, under both Decode (only
// ",strict" fields fail) and StrictDecode (every field fails on a
// parse error).
func TestInvalidStrict(t *testing.T) {
	cases := []struct {
		decoder             func(interface{}) error
		rootValue           string
		nestedValue         string
		rootValueImplicit   string
		nestedValueImplicit string
		pass                bool
	}{
		{Decode, "1", "1", "1", "1", true},
		{Decode, "1", "1", "1", "asdf", true},
		{Decode, "1", "1", "asdf", "1", true},
		{Decode, "1", "1", "asdf", "asdf", true},
		{Decode, "1", "asdf", "1", "1", false},
		{Decode, "asdf", "1", "1", "1", false},
		{Decode, "asdf", "asdf", "1", "1", false},
		{StrictDecode, "1", "1", "1", "1", true},
		{StrictDecode, "asdf", "1", "1", "1", false},
		{StrictDecode, "1", "asdf", "1", "1", false},
		{StrictDecode, "1", "1", "asdf", "1", false},
		{StrictDecode, "1", "1", "1", "asdf", false},
		{StrictDecode, "asdf", "asdf", "1", "1", false},
		{StrictDecode, "1", "asdf", "asdf", "1", false},
		{StrictDecode, "1", "1", "asdf", "asdf", false},
		{StrictDecode, "1", "asdf", "asdf", "asdf", false},
		{StrictDecode, "asdf", "asdf", "asdf", "asdf", false},
	}

	for _, test := range cases {
		os.Setenv("TEST_INVALID_INT64", test.rootValue)
		os.Setenv("TEST_INVALID_INT64_NESTED", test.nestedValue)
		os.Setenv("TEST_INVALID_INT64_IMPLICIT", test.rootValueImplicit)
		os.Setenv("TEST_INVALID_INT64_NESTED_IMPLICIT", test.nestedValueImplicit)

		var tc testConfigStrict
		if err := test.decoder(&tc); test.pass != (err == nil) {
			t.Fatalf("Have err=%s wanted pass=%v", err, test.pass)
		}
	}
}
// TestExport decodes the export fixture and compares the metadata that
// Export produces against a hand-written expectation list, sorted by
// environment variable name like the real output.
func TestExport(t *testing.T) {
	testFloat64 := fmt.Sprintf("%.48f", math.Pi)
	testFloat64Output := strconv.FormatFloat(math.Pi, 'f', -1, 64)
	testInt64 := fmt.Sprintf("%d", -(1 << 50))
	os.Setenv("TEST_STRING", "foo")
	os.Setenv("TEST_INT64", testInt64)
	os.Setenv("TEST_UINT16", "60000")
	os.Setenv("TEST_FLOAT64", testFloat64)
	os.Setenv("TEST_BOOL", "true")
	os.Setenv("TEST_DURATION", "10m")
	os.Setenv("TEST_URL", "https://example.com")
	os.Setenv("TEST_STRING_SLICE", "foo,bar")
	os.Setenv("TEST_NESTED_STRING", "nest_foo")
	os.Setenv("TEST_NESTED_STRING_POINTER", "nest_foo_ptr")
	os.Setenv("TEST_NESTED_TWICE_STRING", "nest_twice_foo")
	os.Setenv("TEST_REQUIRED_INT", "101")
	os.Setenv("TEST_DEFAULT_INT_SET", "102")
	os.Setenv("TEST_DEFAULT_INT_SLICE", "1,2,3")

	var tc testConfigExport
	tc.NestedPtr = &nestedConfigExportPointer{}
	tc.NoConfigPtrSet = &noConfig{}

	if err := Decode(&tc); err != nil {
		t.Fatal(err)
	}

	rc, err := Export(&tc)
	if err != nil {
		t.Fatal(err)
	}

	// One entry per tagged field; untagged structs (NoConfig*) and the
	// unset nested pointer contribute nothing.
	expected := []*ConfigInfo{
		&ConfigInfo{
			Field:   "String",
			EnvVar:  "TEST_STRING",
			Value:   "foo",
			UsesEnv: true,
		},
		&ConfigInfo{
			Field:   "Int64",
			EnvVar:  "TEST_INT64",
			Value:   testInt64,
			UsesEnv: true,
		},
		&ConfigInfo{
			Field:   "Uint16",
			EnvVar:  "TEST_UINT16",
			Value:   "60000",
			UsesEnv: true,
		},
		&ConfigInfo{
			Field:   "Float64",
			EnvVar:  "TEST_FLOAT64",
			Value:   testFloat64Output,
			UsesEnv: true,
		},
		&ConfigInfo{
			Field:   "Bool",
			EnvVar:  "TEST_BOOL",
			Value:   "true",
			UsesEnv: true,
		},
		&ConfigInfo{
			Field:   "Duration",
			EnvVar:  "TEST_DURATION",
			Value:   "10m0s",
			UsesEnv: true,
		},
		&ConfigInfo{
			Field:   "URL",
			EnvVar:  "TEST_URL",
			Value:   "https://example.com",
			UsesEnv: true,
		},
		&ConfigInfo{
			Field:   "StringSlice",
			EnvVar:  "TEST_STRING_SLICE",
			Value:   "[foo bar]",
			UsesEnv: true,
		},
		&ConfigInfo{
			Field:  "UnsetString",
			EnvVar: "TEST_UNSET_STRING",
			Value:  "",
		},
		&ConfigInfo{
			Field:  "UnsetInt64",
			EnvVar: "TEST_UNSET_INT64",
			Value:  "0",
		},
		&ConfigInfo{
			Field:  "UnsetDuration",
			EnvVar: "TEST_UNSET_DURATION",
			Value:  "0s",
		},
		&ConfigInfo{
			Field:  "UnsetURL",
			EnvVar: "TEST_UNSET_URL",
			Value:  "",
		},
		&ConfigInfo{
			Field:  "IgnoredPtr",
			EnvVar: "TEST_IGNORED_POINTER",
			Value:  "",
		},
		&ConfigInfo{
			Field:   "Nested.String",
			EnvVar:  "TEST_NESTED_STRING",
			Value:   "nest_foo",
			UsesEnv: true,
		},
		&ConfigInfo{
			Field:   "NestedPtr.String",
			EnvVar:  "TEST_NESTED_STRING_POINTER",
			Value:   "nest_foo_ptr",
			UsesEnv: true,
		},
		&ConfigInfo{
			Field:   "NestedTwice.Nested.String",
			EnvVar:  "TEST_NESTED_TWICE_STRING",
			Value:   "nest_twice_foo",
			UsesEnv: true,
		},
		&ConfigInfo{
			Field:    "RequiredInt",
			EnvVar:   "TEST_REQUIRED_INT",
			Value:    "101",
			UsesEnv:  true,
			Required: true,
		},
		&ConfigInfo{
			Field:        "DefaultBool",
			EnvVar:       "TEST_DEFAULT_BOOL",
			Value:        "true",
			DefaultValue: "true",
			HasDefault:   true,
		},
		&ConfigInfo{
			Field:        "DefaultInt",
			EnvVar:       "TEST_DEFAULT_INT",
			Value:        "1234",
			DefaultValue: "1234",
			HasDefault:   true,
		},
		&ConfigInfo{
			Field:        "DefaultDuration",
			EnvVar:       "TEST_DEFAULT_DURATION",
			Value:        "24h0m0s",
			DefaultValue: "24h",
			HasDefault:   true,
		},
		&ConfigInfo{
			Field:        "DefaultURL",
			EnvVar:       "TEST_DEFAULT_URL",
			Value:        "http://example.com",
			DefaultValue: "http://example.com",
			HasDefault:   true,
		},
		&ConfigInfo{
			Field:        "DefaultIntSet",
			EnvVar:       "TEST_DEFAULT_INT_SET",
			Value:        "102",
			DefaultValue: "99",
			HasDefault:   true,
			UsesEnv:      true,
		},
		&ConfigInfo{
			Field:        "DefaultIntSlice",
			EnvVar:       "TEST_DEFAULT_INT_SLICE",
			Value:        "[1 2 3]",
			DefaultValue: "99",
			HasDefault:   true,
			UsesEnv:      true,
		},
	}
	sort.Sort(ConfigInfoSlice(expected))

	if len(rc) != len(expected) {
		t.Fatalf("Have %d results, expected %d", len(rc), len(expected))
	}
	for n, v := range rc {
		ci := expected[n]
		if *ci != *v {
			t.Fatalf("have %+v, expected %+v", v, ci)
		}
	}
}
+69
View File
@@ -0,0 +1,69 @@
package config
import (
"io/fs"
"os"
"path"
"strings"
gofig "github.com/gookit/config/v2"
gooyaml "github.com/gookit/config/v2/yaml"
"github.com/opencloud-eu/opencloud/pkg/config/defaults"
)
var (
	// decoderConfigTagName selects which struct tag the config decoder
	// reads. Only "yaml" is supported: configuration is loaded
	// exclusively from YAML files, and the YAML parser has no simple
	// way to use a custom tag name.
	decoderConfigTagName = "yaml"
)
// BindSourcesToStructs assigns any config value from a config file /
// env variable to struct `dst`. The file is looked up as
// <BaseConfigPath>/<service>.yaml on the root filesystem.
func BindSourcesToStructs(service string, dst interface{}) error {
	configFile := strings.TrimLeft(path.Join(defaults.BaseConfigPath(), service+".yaml"), "/")
	return bindSourcesToStructs(os.DirFS("/"), configFile, service, dst)
}
// bindSourcesToStructs reads the YAML config file at filePath from
// fileSystem and binds its values onto dst. A missing file is not an
// error: dst is simply left unchanged.
func bindSourcesToStructs(fileSystem fs.FS, filePath, service string, dst interface{}) error {
	cnf := gofig.NewWithOptions(service)
	cnf.WithOptions(func(options *gofig.Options) {
		// ParseEnv expands environment references in values (the tests
		// use the ${VAR|default} form); fields are matched via the
		// "yaml" struct tag.
		options.ParseEnv = true
		options.DecoderConfig.TagName = decoderConfigTagName
	})
	cnf.AddDriver(gooyaml.Driver)
	yamlContent, err := fs.ReadFile(fileSystem, filePath)
	if err != nil {
		if os.IsNotExist(err) {
			// No config file present — nothing to bind.
			return nil
		}
		return err
	}
	// NOTE(review): a LoadSources failure (e.g. malformed YAML) is
	// silently ignored here; confirm whether that is intentional.
	_ = cnf.LoadSources("yaml", yamlContent)
	err = cnf.BindStruct("", &dst)
	if err != nil {
		return err
	}
	return nil
}
// LocalEndpoint returns the local endpoint for a given protocol and address.
// Use it when configuring the reva runtime to get a service endpoint in the same
// runtime, e.g. a gateway talking to an authregistry service.
//
// For "tcp" the host part of addr is replaced with a dns:127.0.0.1
// endpoint (keeping the port), for "unix" the address is prefixed with
// "unix:", and any other protocol returns addr unchanged.
func LocalEndpoint(protocol, addr string) string {
	switch protocol {
	case "tcp":
		if _, port, ok := strings.Cut(addr, ":"); ok {
			return "dns:127.0.0.1:" + port
		}
	case "unix":
		return "unix:" + addr
	}
	return addr
}
+188
View File
@@ -0,0 +1,188 @@
package config
import (
"gotest.tools/v3/assert"
"testing"
"testing/fstest"
)
// TestConfig is a minimal binding target for the tests below; fields
// are matched through their "yaml" tags.
type TestConfig struct {
	A string `yaml:"a"`
	B string `yaml:"b"`
	C string `yaml:"c"`
}
// TestBindSourcesToStructs binds a YAML file whose values use the
// ${VAR|default} syntax; with no variables exported, the defaults must
// land in the struct.
func TestBindSourcesToStructs(t *testing.T) {
	// setup test env
	yaml := `
a: "${FOO_VAR|no-foo}"
b: "${BAR_VAR|no-bar}"
c: "${CODE_VAR|code}"
`
	filePath := "etc/ocis/foo.yaml"
	fs := fstest.MapFS{
		filePath: {Data: []byte(yaml)},
	}

	// perform test
	c := TestConfig{}
	err := bindSourcesToStructs(fs, filePath, "foo", &c)
	if err != nil {
		t.Error(err)
	}

	assert.Equal(t, c.A, "no-foo")
	assert.Equal(t, c.B, "no-bar")
	assert.Equal(t, c.C, "code")
}
// TestBindSourcesToStructs_UnknownFile verifies that a missing config
// file is not an error and leaves the target struct untouched.
func TestBindSourcesToStructs_UnknownFile(t *testing.T) {
	// setup test env
	filePath := "etc/ocis/foo.yaml"
	fs := fstest.MapFS{}

	// perform test
	c := TestConfig{}
	err := bindSourcesToStructs(fs, filePath, "foo", &c)
	if err != nil {
		t.Error(err)
	}

	assert.Equal(t, c.A, "")
	assert.Equal(t, c.B, "")
	assert.Equal(t, c.C, "")
}
// TestBindSourcesToStructs_NoEnvVar binds a realistic full ocis config
// file whose secret values contain shell-special characters ($, %, &,
// ^ …) and asserts that one representative literal survives binding
// unchanged — presumably guarding against unintended env expansion of
// non-${VAR} values (TODO confirm against the gookit ParseEnv rules).
func TestBindSourcesToStructs_NoEnvVar(t *testing.T) {
	// setup test env
	yaml := `
token_manager:
  jwt_secret: f%LovwC6xnKkHhc.!.Lp4ZYpQDIO7=d@
machine_auth_api_key: jG&%ZCmCSYqT#Yi$9y28o5u84ZMo2UBf
system_user_api_key: wqxH7FZHv5gifuLIzxqdyaZOCo2s^yl1
transfer_secret: $1^2xspR1WHussV16knaJ$x@X*XLPL%y
system_user_id: 4d0bf32c-83ee-4703-bd43-5e0d6b78215b
admin_user_id: e2fca2b3-992b-47d5-8ecd-3312418ed3d7
graph:
  application:
    id: 4fdff90c-d13c-47ab-8227-bbd3e6dbee3c
  events:
    tls_insecure: true
  spaces:
    insecure: true
  identity:
    ldap:
      bind_password: $ZZ8fSJR&YA02jBBPx6IRCzW0kVZ#cBO
  service_account:
    service_account_id: c05389b2-d94c-4d01-a9b5-a2f97952cc14
    service_account_secret: GW5.x1vDM&+NPRi++eV@.P7Tms4vj!=s
idp:
  ldap:
    bind_password: kWJGC6WRY1wQ+e8Bmt--=-3r6gp0CNVS
idm:
  service_user_passwords:
    admin_password: admin
    idm_password: $ZZ8fSJR&YA02jBBPx6IRCzW0kVZ#cBO
    reva_password: c68JL=V$c@0GHs!%eSb8r&Ps3rgzKnXJ
    idp_password: kWJGC6WRY1wQ+e8Bmt--=-3r6gp0CNVS
proxy:
  oidc:
    insecure: true
  insecure_backends: true
  service_account:
    service_account_id: c05389b2-d94c-4d01-a9b5-a2f97952cc14
    service_account_secret: GW5.x1vDM&+NPRi++eV@.P7Tms4vj!=s
frontend:
  app_handler:
    insecure: true
  archiver:
    insecure: true
  service_account:
    service_account_id: c05389b2-d94c-4d01-a9b5-a2f97952cc14
    service_account_secret: GW5.x1vDM&+NPRi++eV@.P7Tms4vj!=s
auth_basic:
  auth_providers:
    ldap:
      bind_password: c68JL=V$c@0GHs!%eSb8r&Ps3rgzKnXJ
auth_bearer:
  auth_providers:
    oidc:
      insecure: true
users:
  drivers:
    ldap:
      bind_password: c68JL=V$c@0GHs!%eSb8r&Ps3rgzKnXJ
groups:
  drivers:
    ldap:
      bind_password: c68JL=V$c@0GHs!%eSb8r&Ps3rgzKnXJ
ocdav:
  insecure: true
ocm:
  service_account:
    service_account_id: c05389b2-d94c-4d01-a9b5-a2f97952cc14
    service_account_secret: GW5.x1vDM&+NPRi++eV@.P7Tms4vj!=s
thumbnails:
  thumbnail:
    transfer_secret: 0N05@YXB.h3e@lsVfksL4YxwQC9aE5A.
    webdav_allow_insecure: true
    cs3_allow_insecure: true
search:
  events:
    tls_insecure: true
  service_account:
    service_account_id: c05389b2-d94c-4d01-a9b5-a2f97952cc14
    service_account_secret: GW5.x1vDM&+NPRi++eV@.P7Tms4vj!=s
audit:
  events:
    tls_insecure: true
settings:
  service_account_ids:
    - c05389b2-d94c-4d01-a9b5-a2f97952cc14
sharing:
  events:
    tls_insecure: true
storage_users:
  events:
    tls_insecure: true
  mount_id: 64fdfb03-22ff-4788-be4d-d7731a475683
  service_account:
    service_account_id: c05389b2-d94c-4d01-a9b5-a2f97952cc14
    service_account_secret: GW5.x1vDM&+NPRi++eV@.P7Tms4vj!=s
notifications:
  notifications:
    events:
      tls_insecure: true
  service_account:
    service_account_id: c05389b2-d94c-4d01-a9b5-a2f97952cc14
    service_account_secret: GW5.x1vDM&+NPRi++eV@.P7Tms4vj!=s
nats:
  nats:
    tls_skip_verify_client_cert: true
gateway:
  storage_registry:
    storage_users_mount_id: 64fdfb03-22ff-4788-be4d-d7731a475683
userlog:
  service_account:
    service_account_id: c05389b2-d94c-4d01-a9b5-a2f97952cc14
    service_account_secret: GW5.x1vDM&+NPRi++eV@.P7Tms4vj!=s
auth_service:
  service_account:
    service_account_id: c05389b2-d94c-4d01-a9b5-a2f97952cc14
    service_account_secret: GW5.x1vDM&+NPRi++eV@.P7Tms4vj!=s
clientlog:
  service_account:
    service_account_id: c05389b2-d94c-4d01-a9b5-a2f97952cc14
    service_account_secret: GW5.x1vDM&+NPRi++eV@.P7Tms4vj!=s
`
	filePath := "etc/ocis/foo.yaml"
	fs := fstest.MapFS{
		filePath: {Data: []byte(yaml)},
	}

	// perform test
	c := Config{}
	err := bindSourcesToStructs(fs, filePath, "foo", &c)
	if err != nil {
		t.Error(err)
	}

	assert.Equal(t, c.Graph.Identity.LDAP.BindPassword, "$ZZ8fSJR&YA02jBBPx6IRCzW0kVZ#cBO")
}
+140
View File
@@ -0,0 +1,140 @@
package parser
import (
"errors"
"github.com/opencloud-eu/opencloud/pkg/config"
"github.com/opencloud-eu/opencloud/pkg/config/envdecode"
"github.com/opencloud-eu/opencloud/pkg/shared"
"github.com/opencloud-eu/opencloud/pkg/structs"
)
// ParseConfig loads the ocis configuration and
// copies applicable parts into the commons part, from
// where the services can copy it into their own config.
//
// The steps are order-dependent: file sources are bound first, defaults
// fill in nil pointers, environment variables override, and only then is
// the result copied into Commons and (optionally) validated.
func ParseConfig(cfg *config.Config, skipValidate bool) error {
	err := config.BindSourcesToStructs("ocis", cfg)
	if err != nil {
		return err
	}
	// initialize nil pointers before envdecode dereferences them
	EnsureDefaults(cfg)
	// load all env variables relevant to the config in the current context.
	if err := envdecode.Decode(cfg); err != nil {
		// no environment variable set for this config is an expected "error"
		if !errors.Is(err, envdecode.ErrNoTargetFieldsAreSet) {
			return err
		}
	}
	EnsureCommons(cfg)
	// callers may opt out of validation, e.g. for commands that only need defaults
	if skipValidate {
		return nil
	}
	return Validate(cfg)
}
// EnsureDefaults ensures that all pointers in the
// oCIS config (not the services configs) are initialized.
// Each shared sub-config is replaced by an empty instance when unset so
// later stages can dereference it without nil checks.
func EnsureDefaults(cfg *config.Config) {
	if cfg.Tracing == nil {
		cfg.Tracing = &shared.Tracing{}
	}
	if cfg.Log == nil {
		cfg.Log = &shared.Log{}
	}
	if cfg.TokenManager == nil {
		cfg.TokenManager = &shared.TokenManager{}
	}
	if cfg.Cache == nil {
		cfg.Cache = &shared.Cache{}
	}
	if cfg.GRPCClientTLS == nil {
		cfg.GRPCClientTLS = &shared.GRPCClientTLS{}
	}
	if cfg.GRPCServiceTLS == nil {
		cfg.GRPCServiceTLS = &shared.GRPCServiceTLS{}
	}
	if cfg.Reva == nil {
		cfg.Reva = &shared.Reva{}
	}
}
// EnsureCommons copies applicable parts of the oCIS config into the commons part.
// Values guarded by an emptiness check are only propagated when explicitly set;
// the others are always copied (falling back to zero values).
func EnsureCommons(cfg *config.Config) {
	// ensure the commons part is initialized
	if cfg.Commons == nil {
		cfg.Commons = &shared.Commons{}
	}
	cfg.Commons.Log = structs.CopyOrZeroValue(cfg.Log)
	cfg.Commons.Tracing = structs.CopyOrZeroValue(cfg.Tracing)
	cfg.Commons.Cache = structs.CopyOrZeroValue(cfg.Cache)
	// NOTE(review): the TLS configs are shared by reference, not copied —
	// presumably intentional; confirm downstream mutation is safe
	if cfg.GRPCClientTLS != nil {
		cfg.Commons.GRPCClientTLS = cfg.GRPCClientTLS
	}
	if cfg.GRPCServiceTLS != nil {
		cfg.Commons.GRPCServiceTLS = cfg.GRPCServiceTLS
	}
	cfg.Commons.HTTPServiceTLS = cfg.HTTPServiceTLS
	cfg.Commons.TokenManager = structs.CopyOrZeroValue(cfg.TokenManager)
	// copy machine auth api key to the commons part if set
	if cfg.MachineAuthAPIKey != "" {
		cfg.Commons.MachineAuthAPIKey = cfg.MachineAuthAPIKey
	}
	if cfg.SystemUserAPIKey != "" {
		cfg.Commons.SystemUserAPIKey = cfg.SystemUserAPIKey
	}
	// copy transfer secret to the commons part if set
	if cfg.TransferSecret != "" {
		cfg.Commons.TransferSecret = cfg.TransferSecret
	}
	// copy metadata user id to the commons part if set
	if cfg.SystemUserID != "" {
		cfg.Commons.SystemUserID = cfg.SystemUserID
	}
	// copy admin user id to the commons part if set
	if cfg.AdminUserID != "" {
		cfg.Commons.AdminUserID = cfg.AdminUserID
	}
	if cfg.OcisURL != "" {
		cfg.Commons.OcisURL = cfg.OcisURL
	}
	cfg.Commons.Reva = structs.CopyOrZeroValue(cfg.Reva)
}
// Validate checks that all required configs are set. If a required config value
// is missing, an error describing the first missing value is returned.
func Validate(cfg *config.Config) error {
	switch {
	case cfg.TokenManager.JWTSecret == "":
		return shared.MissingJWTTokenError("ocis")
	case cfg.TransferSecret == "":
		return shared.MissingRevaTransferSecretError("ocis")
	case cfg.MachineAuthAPIKey == "":
		return shared.MissingMachineAuthApiKeyError("ocis")
	case cfg.SystemUserID == "":
		return shared.MissingSystemUserID("ocis")
	}
	return nil
}
+38
View File
@@ -0,0 +1,38 @@
package conversions
// ToPointer returns a pointer to a copy of val.
func ToPointer[T any](val T) *T {
	v := val
	return &v
}

// ToValue dereferences ptr, returning the zero value of T when ptr is nil.
func ToValue[T any](ptr *T) T {
	var out T
	if ptr != nil {
		out = *ptr
	}
	return out
}

// ToPointerSlice maps every element of s to a pointer to a copy of it.
func ToPointerSlice[E any](s []E) []*E {
	out := make([]*E, len(s))
	for i := range s {
		out[i] = ToPointer(s[i])
	}
	return out
}

// ToValueSlice maps every element of s to its dereferenced value,
// substituting the zero value of E for nil entries.
func ToValueSlice[E any](s []*E) []E {
	out := make([]E, len(s))
	for i := range s {
		out[i] = ToValue(s[i])
	}
	return out
}
+113
View File
@@ -0,0 +1,113 @@
package conversions_test
import (
"fmt"
"testing"
libregraph "github.com/owncloud/libre-graph-api-go"
"github.com/opencloud-eu/opencloud/pkg/conversions"
)
// checkIdentical asserts that the dynamic type of p, as rendered by
// fmt's %T verb, matches want.
func checkIdentical[T any](t *testing.T, p T, want string) {
	t.Helper()
	if typeName := fmt.Sprintf("%T", p); typeName != want {
		t.Errorf("want:%q got:%q", want, typeName)
	}
}
// TestToPointer2 pins the dynamic type produced by ToPointer for scalar,
// struct, slice and already-pointered inputs.
func TestToPointer2(t *testing.T) {
	checkIdentical(t, conversions.ToPointer("a"), "*string")
	checkIdentical(t, conversions.ToPointer(1), "*int")
	checkIdentical(t, conversions.ToPointer(-1), "*int")
	checkIdentical(t, conversions.ToPointer(float64(1)), "*float64")
	checkIdentical(t, conversions.ToPointer(float64(-1)), "*float64")
	checkIdentical(t, conversions.ToPointer(libregraph.UnifiedRoleDefinition{}), "*libregraph.UnifiedRoleDefinition")
	checkIdentical(t, conversions.ToPointer([]string{"a"}), "*[]string")
	checkIdentical(t, conversions.ToPointer([]int{1}), "*[]int")
	checkIdentical(t, conversions.ToPointer([]float64{1}), "*[]float64")
	checkIdentical(t, conversions.ToPointer([]libregraph.UnifiedRoleDefinition{{}}), "*[]libregraph.UnifiedRoleDefinition")
	// nesting ToPointer yields a pointer-to-pointer
	checkIdentical(t, conversions.ToPointer(conversions.ToPointer("a")), "**string")
	checkIdentical(t, conversions.ToPointer(conversions.ToPointer(1)), "**int")
	checkIdentical(t, conversions.ToPointer(conversions.ToPointer(-1)), "**int")
	checkIdentical(t, conversions.ToPointer(conversions.ToPointer(float64(1))), "**float64")
	checkIdentical(t, conversions.ToPointer(conversions.ToPointer(float64(-1))), "**float64")
	checkIdentical(t, conversions.ToPointer(conversions.ToPointer(libregraph.UnifiedRoleDefinition{})), "**libregraph.UnifiedRoleDefinition")
	checkIdentical(t, conversions.ToPointer(conversions.ToPointer([]string{"a"})), "**[]string")
	checkIdentical(t, conversions.ToPointer(conversions.ToPointer([]int{1})), "**[]int")
	checkIdentical(t, conversions.ToPointer(conversions.ToPointer([]float64{1})), "**[]float64")
	checkIdentical(t, conversions.ToPointer(conversions.ToPointer([]libregraph.UnifiedRoleDefinition{{}})), "**[]libregraph.UnifiedRoleDefinition")
}

// TestToValue pins the dynamic type produced by ToValue for nil typed
// pointers (zero-value result) and for pointers produced by ToPointer.
func TestToValue(t *testing.T) {
	checkIdentical(t, conversions.ToValue((*int)(nil)), "int")
	checkIdentical(t, conversions.ToValue((*string)(nil)), "string")
	checkIdentical(t, conversions.ToValue((*float64)(nil)), "float64")
	checkIdentical(t, conversions.ToValue((*libregraph.UnifiedRoleDefinition)(nil)), "libregraph.UnifiedRoleDefinition")
	checkIdentical(t, conversions.ToValue((*[]string)(nil)), "[]string")
	checkIdentical(t, conversions.ToValue((*[]libregraph.UnifiedRoleDefinition)(nil)), "[]libregraph.UnifiedRoleDefinition")
	checkIdentical(t, conversions.ToValue(conversions.ToPointer("a")), "string")
	checkIdentical(t, conversions.ToValue(conversions.ToPointer(1)), "int")
	checkIdentical(t, conversions.ToValue(conversions.ToPointer(-1)), "int")
	checkIdentical(t, conversions.ToValue(conversions.ToPointer(float64(1))), "float64")
	checkIdentical(t, conversions.ToValue(conversions.ToPointer(float64(-1))), "float64")
	checkIdentical(t, conversions.ToValue(conversions.ToPointer(libregraph.UnifiedRoleDefinition{})), "libregraph.UnifiedRoleDefinition")
	checkIdentical(t, conversions.ToValue(conversions.ToPointer([]string{"a"})), "[]string")
	checkIdentical(t, conversions.ToValue(conversions.ToPointer([]int{1})), "[]int")
	checkIdentical(t, conversions.ToValue(conversions.ToPointer([]float64{1})), "[]float64")
	checkIdentical(t, conversions.ToValue(conversions.ToPointer([]libregraph.UnifiedRoleDefinition{{}})), "[]libregraph.UnifiedRoleDefinition")
	// ToValue only strips one level of indirection
	checkIdentical(t, conversions.ToValue(conversions.ToPointer(conversions.ToPointer("a"))), "*string")
	checkIdentical(t, conversions.ToValue(conversions.ToPointer(conversions.ToPointer(1))), "*int")
	checkIdentical(t, conversions.ToValue(conversions.ToPointer(conversions.ToPointer(-1))), "*int")
	checkIdentical(t, conversions.ToValue(conversions.ToPointer(conversions.ToPointer(float64(1)))), "*float64")
	checkIdentical(t, conversions.ToValue(conversions.ToPointer(conversions.ToPointer(float64(-1)))), "*float64")
	checkIdentical(t, conversions.ToValue(conversions.ToPointer(conversions.ToPointer(libregraph.UnifiedRoleDefinition{}))), "*libregraph.UnifiedRoleDefinition")
	checkIdentical(t, conversions.ToValue(conversions.ToPointer(conversions.ToPointer([]string{"a"}))), "*[]string")
	checkIdentical(t, conversions.ToValue(conversions.ToPointer(conversions.ToPointer([]int{1}))), "*[]int")
	checkIdentical(t, conversions.ToValue(conversions.ToPointer(conversions.ToPointer([]float64{1}))), "*[]float64")
	checkIdentical(t, conversions.ToValue(conversions.ToPointer(conversions.ToPointer([]libregraph.UnifiedRoleDefinition{{}}))), "*[]libregraph.UnifiedRoleDefinition")
}

// TestToPointerSlice pins the element type produced by ToPointerSlice for
// non-nil and nil (typed) input slices.
func TestToPointerSlice(t *testing.T) {
	checkIdentical(t, conversions.ToPointerSlice([]string{"a"}), "[]*string")
	checkIdentical(t, conversions.ToPointerSlice([]int{1}), "[]*int")
	checkIdentical(t, conversions.ToPointerSlice([]libregraph.UnifiedRoleDefinition{{}}), "[]*libregraph.UnifiedRoleDefinition")
	checkIdentical(t, conversions.ToPointerSlice(([]string)(nil)), "[]*string")
	checkIdentical(t, conversions.ToPointerSlice(([]int)(nil)), "[]*int")
	checkIdentical(t, conversions.ToPointerSlice(([]libregraph.UnifiedRoleDefinition)(nil)), "[]*libregraph.UnifiedRoleDefinition")
	checkIdentical(t, conversions.ToPointerSlice([]*string{conversions.ToPointer("a")}), "[]**string")
	checkIdentical(t, conversions.ToPointerSlice([]*int{conversions.ToPointer(1)}), "[]**int")
	// NOTE(review): duplicated with the nil case three lines below —
	// possibly a non-nil []*libregraph.UnifiedRoleDefinition was intended here
	checkIdentical(t, conversions.ToPointerSlice(([]*libregraph.UnifiedRoleDefinition)(nil)), "[]**libregraph.UnifiedRoleDefinition")
	checkIdentical(t, conversions.ToPointerSlice(([]*string)(nil)), "[]**string")
	checkIdentical(t, conversions.ToPointerSlice(([]*int)(nil)), "[]**int")
	checkIdentical(t, conversions.ToPointerSlice(([]*libregraph.UnifiedRoleDefinition)(nil)), "[]**libregraph.UnifiedRoleDefinition")
}

// TestToValueSlice pins that ToValueSlice inverts ToPointerSlice for
// value and pointer element types, including nil inputs.
func TestToValueSlice(t *testing.T) {
	checkIdentical(t, conversions.ToValueSlice(conversions.ToPointerSlice([]string{"a"})), "[]string")
	checkIdentical(t, conversions.ToValueSlice(conversions.ToPointerSlice([]int{1})), "[]int")
	checkIdentical(t, conversions.ToValueSlice(conversions.ToPointerSlice([]libregraph.UnifiedRoleDefinition{{}})), "[]libregraph.UnifiedRoleDefinition")
	checkIdentical(t, conversions.ToValueSlice(conversions.ToPointerSlice(([]string)(nil))), "[]string")
	checkIdentical(t, conversions.ToValueSlice(conversions.ToPointerSlice(([]int)(nil))), "[]int")
	checkIdentical(t, conversions.ToValueSlice(conversions.ToPointerSlice(([]libregraph.UnifiedRoleDefinition)(nil))), "[]libregraph.UnifiedRoleDefinition")
	checkIdentical(t, conversions.ToValueSlice(conversions.ToPointerSlice([]*string{conversions.ToPointer("a")})), "[]*string")
	checkIdentical(t, conversions.ToValueSlice(conversions.ToPointerSlice([]*int{conversions.ToPointer(1)})), "[]*int")
	checkIdentical(t, conversions.ToValueSlice(conversions.ToPointerSlice([]*libregraph.UnifiedRoleDefinition{conversions.ToPointer(libregraph.UnifiedRoleDefinition{})})), "[]*libregraph.UnifiedRoleDefinition")
	checkIdentical(t, conversions.ToValueSlice(conversions.ToPointerSlice(([]*string)(nil))), "[]*string")
	checkIdentical(t, conversions.ToValueSlice(conversions.ToPointerSlice(([]*int)(nil))), "[]*int")
	checkIdentical(t, conversions.ToValueSlice(conversions.ToPointerSlice(([]*libregraph.UnifiedRoleDefinition)(nil))), "[]*libregraph.UnifiedRoleDefinition")
}
+16
View File
@@ -0,0 +1,16 @@
package conversions
import (
"strings"
)
// StringToSliceString splits src on sep and trims surrounding whitespace
// from every resulting part.
func StringToSliceString(src string, sep string) []string {
	fields := strings.Split(src, sep)
	out := make([]string, len(fields))
	for i, f := range fields {
		out[i] = strings.TrimSpace(f)
	}
	return out
}
+35
View File
@@ -0,0 +1,35 @@
package conversions
import "testing"
// scenarios are the table-driven cases for TestStringToSliceString.
var scenarios = []struct {
	name      string
	input     string
	separator string
	out       []string
}{
	{
		"comma separated input",
		"a, b, c, d",
		",",
		[]string{"a", "b", "c", "d"},
	}, {
		"space separated input",
		"a b c d",
		" ",
		[]string{"a", "b", "c", "d"},
	},
}

// TestStringToSliceString verifies splitting and trimming for each scenario.
// It compares lengths first: the original loop indexed the result with the
// expected slice's indices, which panicked on short results and silently
// accepted results with extra trailing parts.
func TestStringToSliceString(t *testing.T) {
	for _, tt := range scenarios {
		t.Run(tt.name, func(t *testing.T) {
			s := StringToSliceString(tt.input, tt.separator)
			if len(s) != len(tt.out) {
				t.Fatalf("got %d parts %q, want %d parts %q", len(s), s, len(tt.out), tt.out)
			}
			for i, v := range tt.out {
				if s[i] != v {
					t.Errorf("got %q, want %q", s, tt.out)
				}
			}
		})
	}
}
+68
View File
@@ -0,0 +1,68 @@
package cors
import (
"github.com/opencloud-eu/opencloud/pkg/log"
)
// Option defines a single option function.
type Option func(o *Options)

// Options defines the available options for this package.
type Options struct {
	// Logger to use for logging, must be set
	Logger log.Logger
	// AllowedOrigins represents the allowed CORS origins
	AllowedOrigins []string
	// AllowedMethods represents the allowed CORS methods
	AllowedMethods []string
	// AllowedHeaders represents the allowed CORS headers
	AllowedHeaders []string
	// AllowCredentials represents the AllowCredentials CORS option
	AllowCredentials bool
}

// NewOptions initializes the available default options.
func NewOptions(opts ...Option) Options {
	opt := Options{}
	for _, o := range opts {
		o(&opt)
	}
	return opt
}

// Logger provides a function to set the logger option.
func Logger(l log.Logger) Option {
	return func(o *Options) {
		o.Logger = l
	}
}

// AllowedOrigins provides a function to set the AllowedOrigins option.
func AllowedOrigins(origins []string) Option {
	return func(o *Options) {
		o.AllowedOrigins = origins
	}
}

// AllowedMethods provides a function to set the AllowedMethods option.
func AllowedMethods(methods []string) Option {
	return func(o *Options) {
		o.AllowedMethods = methods
	}
}

// AllowedHeaders provides a function to set the AllowedHeaders option.
func AllowedHeaders(headers []string) Option {
	return func(o *Options) {
		o.AllowedHeaders = headers
	}
}

// AllowCredentials provides a function to set the AllowCredentials option.
func AllowCredentials(allow bool) Option {
	return func(o *Options) {
		o.AllowCredentials = allow
	}
}
+28
View File
@@ -0,0 +1,28 @@
// Package crypto implements utility functions for handling crypto related files.
package crypto
import (
"bytes"
"crypto/x509"
"errors"
"io"
)
// NewCertPoolFromPEM reads certificates from io.Reader and returns a x509.CertPool
// containing those certificates.
func NewCertPoolFromPEM(crts ...io.Reader) (*x509.CertPool, error) {
certPool := x509.NewCertPool()
var buf bytes.Buffer
for _, c := range crts {
if _, err := io.Copy(&buf, c); err != nil {
return nil, err
}
if !certPool.AppendCertsFromPEM(buf.Bytes()) {
return nil, errors.New("failed to append cert from PEM")
}
buf.Reset()
}
return certPool, nil
}
+13
View File
@@ -0,0 +1,13 @@
package crypto_test
import (
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// TestCrypto wires ginkgo/gomega into `go test` and runs the Crypto suite.
func TestCrypto(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "Crypto Suite")
}
+103
View File
@@ -0,0 +1,103 @@
package crypto_test
import (
"fmt"
"os"
"path/filepath"
"github.com/opencloud-eu/opencloud/pkg/crypto"
"github.com/opencloud-eu/opencloud/pkg/log"
. "github.com/onsi/ginkgo/v2"
cfg "github.com/opencloud-eu/opencloud/pkg/config"
)
// Crypto spec: exercises GenCert's side effects (creating a key/cert pair on
// disk under the user config dir) and NewCertPoolFromPEM.
var _ = Describe("Crypto", func() {
	var (
		userConfigDir string
		err           error
		config        = cfg.DefaultConfig()
	)
	BeforeEach(func() {
		userConfigDir, err = os.UserConfigDir()
		if err != nil {
			Fail(err.Error())
		}
		config.Proxy.HTTP.TLSKey = filepath.Join(userConfigDir, "ocis", "server.key")
		config.Proxy.HTTP.TLSCert = filepath.Join(userConfigDir, "ocis", "server.cert")
	})
	AfterEach(func() {
		// remove everything GenCert wrote so runs stay independent
		if err := os.RemoveAll(filepath.Join(userConfigDir, "ocis")); err != nil {
			Fail(err.Error())
		}
	})
	// This little test should nail down the main functionality of this package, which is providing with a default location
	// for the key / certificate pair in case none is configured. Regardless of how the values ended in the configuration,
	// the side effects of GenCert is what we want to test.
	Describe("Creating key / certificate pair", func() {
		Context("For ocis-proxy in the location of the user config directory", func() {
			// NOTE(review): userConfigDir is still empty while the spec tree is
			// being built (BeforeEach runs later), so this description renders
			// just "ocis" — confirm whether that is intended.
			It(fmt.Sprintf("Creates the cert / key tuple in: %s", filepath.Join(userConfigDir, "ocis")), func() {
				if err := crypto.GenCert(config.Proxy.HTTP.TLSCert, config.Proxy.HTTP.TLSKey, log.NopLogger()); err != nil {
					Fail(err.Error())
				}
				if _, err := os.Stat(filepath.Join(userConfigDir, "ocis", "server.key")); err != nil {
					Fail("key not found at the expected location")
				}
				if _, err := os.Stat(filepath.Join(userConfigDir, "ocis", "server.cert")); err != nil {
					Fail("certificate not found at the expected location")
				}
			})
		})
	})
	Describe("Creating a new cert pool", func() {
		var (
			crtOne string
			keyOne string
			crtTwo string
			keyTwo string
		)
		BeforeEach(func() {
			crtOne = filepath.Join(userConfigDir, "ocis/one.cert")
			keyOne = filepath.Join(userConfigDir, "ocis/one.key")
			crtTwo = filepath.Join(userConfigDir, "ocis/two.cert")
			keyTwo = filepath.Join(userConfigDir, "ocis/two.key")
			if err := crypto.GenCert(crtOne, keyOne, log.NopLogger()); err != nil {
				Fail(err.Error())
			}
			if err := crypto.GenCert(crtTwo, keyTwo, log.NopLogger()); err != nil {
				Fail(err.Error())
			}
		})
		It("handles one certificate", func() {
			f1, _ := os.Open(crtOne)
			defer f1.Close()
			c, err := crypto.NewCertPoolFromPEM(f1)
			if err != nil {
				Fail(err.Error())
			}
			// NOTE(review): CertPool.Subjects is deprecated since Go 1.18;
			// still works for pools built from parsed certs, as here
			if len(c.Subjects()) != 1 {
				Fail("expected 1 certificate in the cert pool")
			}
		})
		It("handles multiple certificates", func() {
			f1, _ := os.Open(crtOne)
			f2, _ := os.Open(crtTwo)
			defer f1.Close()
			defer f2.Close()
			c, err := crypto.NewCertPoolFromPEM(f1, f2)
			if err != nil {
				Fail(err.Error())
			}
			if len(c.Subjects()) != 2 {
				Fail("expected 2 certificates in the cert pool")
			}
		})
	})
})
+167
View File
@@ -0,0 +1,167 @@
package crypto
import (
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"net"
"os"
"path/filepath"
"github.com/opencloud-eu/opencloud/pkg/log"
mtls "go-micro.dev/v4/util/tls"
)
var (
	// defaultHosts are the subject alternative names used for generated
	// certificates.
	defaultHosts = []string{"127.0.0.1", "localhost"}
)

// GenCert generates TLS-Certificates. This function has side effects: it creates the respective certificate / key pair at
// the destination locations unless the tuple already exists, if that is the case, this is a noop.
func GenCert(certName string, keyName string, l log.Logger) error {
	// check for an existing pair first: the original generated a 2048-bit
	// RSA key unconditionally, wasting the (expensive) keygen in the noop case
	_, certErr := os.Stat(certName)
	_, keyErr := os.Stat(keyName)
	if certErr == nil || keyErr == nil {
		l.Info().Msg(
			fmt.Sprintf("%v certificate / key pair already present. skipping acme certificate generation",
				filepath.Base(certName)))
		return nil
	}
	pk, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return err
	}
	// persist failures are fatal via the logger (process exit), mirroring the
	// original behavior
	if err := persistCertificate(certName, l, pk); err != nil {
		l.Fatal().Err(err).Msg("failed to store certificate")
	}
	if err := persistKey(keyName, l, pk); err != nil {
		l.Fatal().Err(err).Msg("failed to store key")
	}
	return nil
}
// GenTempCertForAddr generates a temporary, in-memory TLS certificate whose
// subject is derived from addr, falling back to defaultHosts when addr has
// no usable host part.
func GenTempCertForAddr(addr string) (tls.Certificate, error) {
	hosts := defaultHosts
	host, _, err := net.SplitHostPort(addr)
	if err == nil && host != "" {
		hosts = []string{host}
	}
	return mtls.Certificate(hosts...)
}
// persistCertificate generates a certificate using pk as the private key and
// stores it, PEM-encoded, in the file named certName. All failures wrap the
// underlying error; the original built messages by concatenating paths into
// the format string (a vet printf violation that also mangled paths
// containing '%') and discarded the root causes.
func persistCertificate(certName string, l log.Logger, pk interface{}) error {
	if err := ensureExistsDir(certName); err != nil {
		return fmt.Errorf("creating certificate destination %q: %w", certName, err)
	}
	certificate, err := generateCertificate(pk)
	if err != nil {
		return fmt.Errorf("creating certificate %q: %w", certName, err)
	}
	certOut, err := os.Create(certName)
	if err != nil {
		return fmt.Errorf("failed to open %q for writing: %w", certName, err)
	}
	if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certificate}); err != nil {
		// close explicitly: the original leaked the handle on encode failure
		certOut.Close()
		return fmt.Errorf("failed to encode certificate: %w", err)
	}
	if err := certOut.Close(); err != nil {
		return fmt.Errorf("failed to write cert: %w", err)
	}
	l.Info().Msg(fmt.Sprintf("written certificate to %v", certName))
	return nil
}
// genCert generates a self signed certificate using a random rsa key.
func generateCertificate(pk interface{}) ([]byte, error) {
for _, h := range defaultHosts {
if ip := net.ParseIP(h); ip != nil {
acmeTemplate.IPAddresses = append(acmeTemplate.IPAddresses, ip)
} else {
acmeTemplate.DNSNames = append(acmeTemplate.DNSNames, h)
}
}
return x509.CreateCertificate(rand.Reader, &acmeTemplate, &acmeTemplate, publicKey(pk), pk)
}
// persistKey stores pk, PEM-encoded, in the file named destination (mode
// 0600). Errors wrap their causes; the original concatenated the path into
// the format string (vet printf violation) and dropped the root causes.
// NOTE(review): pemBlockForKey returns nil for unsupported key types, which
// pem.Encode dereferences — same as the original; confirm callers only pass
// RSA/ECDSA keys.
func persistKey(destination string, l log.Logger, pk interface{}) error {
	if err := ensureExistsDir(destination); err != nil {
		return fmt.Errorf("creating key destination %q: %w", destination, err)
	}
	keyOut, err := os.OpenFile(destination, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
	if err != nil {
		return fmt.Errorf("failed to open %q for writing: %w", destination, err)
	}
	if err := pem.Encode(keyOut, pemBlockForKey(pk, l)); err != nil {
		// close explicitly: the original leaked the handle on encode failure
		keyOut.Close()
		return fmt.Errorf("failed to encode key: %w", err)
	}
	if err := keyOut.Close(); err != nil {
		return fmt.Errorf("failed to write key: %w", err)
	}
	l.Info().Msg(fmt.Sprintf("written key to %v", destination))
	return nil
}
// publicKey extracts the public half of a supported private key
// (*rsa.PrivateKey or *ecdsa.PrivateKey); any other input yields nil.
func publicKey(pk interface{}) interface{} {
	if k, ok := pk.(*rsa.PrivateKey); ok {
		return &k.PublicKey
	}
	if k, ok := pk.(*ecdsa.PrivateKey); ok {
		return &k.PublicKey
	}
	return nil
}
// pemBlockForKey wraps a private key in the matching PEM block type.
// Unsupported key types yield nil; an ECDSA marshalling failure is fatal
// via the supplied logger.
func pemBlockForKey(pk interface{}, l log.Logger) *pem.Block {
	switch key := pk.(type) {
	case *rsa.PrivateKey:
		return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}
	case *ecdsa.PrivateKey:
		der, err := x509.MarshalECPrivateKey(key)
		if err != nil {
			l.Fatal().Err(err).Msg("Unable to marshal ECDSA private key")
		}
		return &pem.Block{Type: "EC PRIVATE KEY", Bytes: der}
	}
	return nil
}
func ensureExistsDir(uri string) error {
certPath := filepath.Dir(uri)
if _, err := os.Stat(certPath); os.IsNotExist(err) {
err = os.MkdirAll(certPath, 0700)
if err != nil {
return err
}
}
return nil
}
+155
View File
@@ -0,0 +1,155 @@
package crypto
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"os"
"path/filepath"
"testing"
"github.com/opencloud-eu/opencloud/pkg/log"
)
// TestEnsureExistsDir covers both the create case and the noop case for an
// already-existing directory (the second scenario reuses the same path).
func TestEnsureExistsDir(t *testing.T) {
	var tmpDir = t.TempDir()
	type args struct {
		uri string
	}
	tests := []struct {
		name    string
		args    args
		wantErr bool
	}{
		{
			name: "creates a dir if it does not exist",
			args: args{
				uri: filepath.Join(tmpDir, "example"),
			},
			wantErr: false,
		},
		{
			name: "noop if the target directory exists",
			args: args{
				uri: filepath.Join(tmpDir, "example"),
			},
			wantErr: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if err := ensureExistsDir(tt.args.uri); (err != nil) != tt.wantErr {
				t.Errorf("ensureExistsDir() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}
// TestPersistKey writes RSA and ECDSA private keys to the same path and
// checks the file exists afterwards.
func TestPersistKey(t *testing.T) {
	p256 := elliptic.P256()
	var (
		tmpDir     = t.TempDir()
		keyPath    = filepath.Join(tmpDir, "ocis", "testKey")
		rsaPk, _   = rsa.GenerateKey(rand.Reader, 2048)
		ecdsaPk, _ = ecdsa.GenerateKey(p256, rand.Reader)
	)
	type args struct {
		keyName string
		pk      interface{}
	}
	tests := []struct {
		name string
		args args
	}{
		{
			name: "writes a private key (rsa) to the specified location",
			args: args{
				keyName: keyPath,
				pk:      rsaPk,
			},
		},
		{
			name: "writes a private key (ecdsa) to the specified location",
			args: args{
				keyName: keyPath,
				pk:      ecdsaPk,
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if err := persistKey(tt.args.keyName, log.NopLogger(), tt.args.pk); err != nil {
				t.Error(err)
			}
		})
		// side effect: tt.args.keyName is created
		// NOTE(review): this check sits outside the subtest, so a failure is
		// reported against the parent test rather than the named case
		if _, err := os.Stat(tt.args.keyName); err != nil {
			t.Errorf("persistKey() error = %v", err)
		}
	}
}
// TestPersistCertificate stores certificates with RSA and ECDSA keys and
// verifies the expected file appears; an unsupported key type must error.
// Fixes over the original: the accidentally duplicated nested t.Run wrapper
// is removed; a missing-but-expected error is now reported (the original
// only flagged err != nil && !wantErr, so wantErr cases that silently
// succeeded passed); and the file-existence check no longer runs for the
// error case, where it only passed because an earlier subtest had already
// created the same file.
func TestPersistCertificate(t *testing.T) {
	p256 := elliptic.P256()
	var (
		tmpDir     = t.TempDir()
		certPath   = filepath.Join(tmpDir, "ocis", "testCert")
		rsaPk, _   = rsa.GenerateKey(rand.Reader, 2048)
		ecdsaPk, _ = ecdsa.GenerateKey(p256, rand.Reader)
	)
	type args struct {
		certName string
		pk       interface{}
	}
	tests := []struct {
		name    string
		args    args
		wantErr bool
	}{
		{
			name: "store a certificate with an rsa private key",
			args: args{
				certName: certPath,
				pk:       rsaPk,
			},
			wantErr: false,
		},
		{
			name: "store a certificate with an ecdsa private key",
			args: args{
				certName: certPath,
				pk:       ecdsaPk,
			},
			wantErr: false,
		},
		{
			name: "should fail",
			args: args{
				certName: certPath,
				pk:       42,
			},
			wantErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := persistCertificate(tt.args.certName, log.NopLogger(), tt.args.pk)
			if (err != nil) != tt.wantErr {
				t.Errorf("persistCertificate() error = %v, wantErr %v", err, tt.wantErr)
			}
			if tt.wantErr {
				return
			}
			// side effect: tt.args.certName is created
			if _, err := os.Stat(tt.args.certName); err != nil {
				t.Errorf("persistCertificate() error = %v", err)
			}
		})
	}
}
+25
View File
@@ -0,0 +1,25 @@
package crypto
import (
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"math/big"
"time"
)
// serialNumber is a random 128-bit certificate serial number.
// NOTE(review): the error from rand.Int is discarded — reads from
// rand.Reader are assumed never to fail here; confirm that is acceptable.
var serialNumber, _ = rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))

// acmeTemplate is the base template for self-signed server certificates,
// valid from now for roughly one year.
// NOTE(review): this package-level value is shared and mutable — anything
// appending SANs to it mutates the template for every caller.
var acmeTemplate = x509.Certificate{
	SerialNumber: serialNumber,
	Subject: pkix.Name{
		Organization: []string{"Acme Corp"},
		CommonName:   "OCIS",
	},
	NotBefore:             time.Now(),
	NotAfter:              time.Now().Add(24 * time.Hour * 365),
	KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
	ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
	BasicConstraintsValid: true,
}
+56
View File
@@ -0,0 +1,56 @@
package flags
// overrideDefault returns v unless it is the zero value of T, in which case
// def is returned. A non-zero v means the value was already read from a
// config file either from a service or from a higher source (i.e: ocis
// command).
func overrideDefault[T comparable](v, def T) T {
	var zero T
	if v == zero {
		return def
	}
	return v
}

// OverrideDefaultString returns def when v is the empty string, v otherwise.
func OverrideDefaultString(v, def string) string {
	return overrideDefault(v, def)
}

// OverrideDefaultBool returns def when v is false, v otherwise.
func OverrideDefaultBool(v, def bool) bool {
	return overrideDefault(v, def)
}

// OverrideDefaultInt returns def when v is 0, v otherwise.
func OverrideDefaultInt(v, def int) int {
	return overrideDefault(v, def)
}

// OverrideDefaultInt64 returns def when v is 0, v otherwise.
func OverrideDefaultInt64(v, def int64) int64 {
	return overrideDefault(v, def)
}

// OverrideDefaultUint64 returns def when v is 0, v otherwise.
func OverrideDefaultUint64(v, def uint64) uint64 {
	return overrideDefault(v, def)
}
+48
View File
@@ -0,0 +1,48 @@
package generators
import (
"crypto/rand"
"math/big"
)
const (
	// PasswordChars contains alphanumeric chars (0-9, A-Z, a-z), plus "-=+!@#$%^&*."
	PasswordChars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz-=+!@#$%^&*."
	// AlphaNumChars contains alphanumeric chars (0-9, A-Z, a-z)
	AlphaNumChars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
)

// GenerateRandomPassword generates a random password with the given length.
// The password will contain chars picked from the `PasswordChars` constant.
// If an error happens, the string will be empty and the error will be non-nil.
//
// This is equivalent to `GenerateRandomString(PasswordChars, length)`
func GenerateRandomPassword(length int) (string, error) {
	return generateString(PasswordChars, length)
}

// GenerateRandomString generates a random string with the given length based
// on the chars provided. You can use the `PasswordChars` or `AlphaNumChars`
// constants, or any other string.
//
// Chars from the provided string are picked uniformly (via crypto/rand), so
// constants with unique chars give every char the same probability. Repeating
// a char skews the distribution: with "AAAB" you get "A" 75% of the time.
func GenerateRandomString(chars string, length int) (string, error) {
	return generateString(chars, length)
}

// generateString fills a buffer of the requested length with uniformly
// chosen chars from the alphabet, using crypto/rand as the entropy source.
func generateString(chars string, length int) (string, error) {
	alphabetSize := big.NewInt(int64(len(chars)))
	out := make([]byte, length)
	for i := range out {
		idx, err := rand.Int(rand.Reader, alphabetSize)
		if err != nil {
			return "", err
		}
		out[i] = chars[idx.Int64()]
	}
	return string(out), nil
}
+148
View File
@@ -0,0 +1,148 @@
package handlers
import (
"context"
"fmt"
"io"
"maps"
"net"
"net/http"
"strings"
"github.com/opencloud-eu/opencloud/pkg/log"
"golang.org/x/sync/errgroup"
)
// checker is a function that performs a single named check.
type checker func(ctx context.Context) error

// CheckHandlerConfiguration defines the configuration for the CheckHandler.
type CheckHandlerConfiguration struct {
	checks map[string]checker // named checks, run concurrently per request
	logger log.Logger
	limit         int // max concurrently running checks; -1 means unlimited
	statusFailed  int // HTTP status written when any check fails
	statusSuccess int // HTTP status written when all checks pass
}

// NewCheckHandlerConfiguration initializes a new CheckHandlerConfiguration
// with no checks, no concurrency limit, and 500/200 as failure/success codes.
func NewCheckHandlerConfiguration() CheckHandlerConfiguration {
	return CheckHandlerConfiguration{
		checks: make(map[string]checker),
		limit:         -1,
		statusFailed:  http.StatusInternalServerError,
		statusSuccess: http.StatusOK,
	}
}
// WithLogger sets the logger for the CheckHandlerConfiguration.
// Like all With* helpers it uses a value receiver and returns the modified
// copy, builder-style.
func (c CheckHandlerConfiguration) WithLogger(l log.Logger) CheckHandlerConfiguration {
	c.logger = l
	return c
}

// WithCheck registers a named check; registering the same name twice panics
// via the logger.
func (c CheckHandlerConfiguration) WithCheck(name string, check checker) CheckHandlerConfiguration {
	if _, ok := c.checks[name]; ok {
		c.logger.Panic().Str("check", name).Msg("check already exists")
	}
	c.checks = maps.Clone(c.checks) // prevent propagated check duplication, maps are references;
	c.checks[name] = check
	return c
}

// WithLimit limits the number of active goroutines for the checks to at most n.
func (c CheckHandlerConfiguration) WithLimit(n int) CheckHandlerConfiguration {
	c.limit = n
	return c
}

// WithStatusFailed sets the status code written when any check fails.
func (c CheckHandlerConfiguration) WithStatusFailed(status int) CheckHandlerConfiguration {
	c.statusFailed = status
	return c
}

// WithStatusSuccess sets the status code written when all checks succeed.
func (c CheckHandlerConfiguration) WithStatusSuccess(status int) CheckHandlerConfiguration {
	c.statusSuccess = status
	return c
}
// CheckHandler is a http Handler that performs the configured checks
// concurrently on every request.
type CheckHandler struct {
	conf CheckHandlerConfiguration
}

// NewCheckHandler initializes a new CheckHandler.
func NewCheckHandler(c CheckHandlerConfiguration) *CheckHandler {
	c.checks = maps.Clone(c.checks) // prevent check duplication after initialization
	return &CheckHandler{
		conf: c,
	}
}
// ServeHTTP runs all configured checks concurrently (bounded by the
// configured limit and cancelled together via the request context) and
// reports an aggregate plain-text status.
func (h *CheckHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	g, ctx := errgroup.WithContext(r.Context())
	g.SetLimit(h.conf.limit)
	for name, check := range h.conf.checks {
		// go >= 1.22 scopes loop variables per iteration, so name and check
		// can be captured directly without pinning copies
		g.Go(func() error {
			if err := check(ctx); err != nil {
				return fmt.Errorf("'%s': %w", name, err)
			}
			return nil
		})
	}
	status := h.conf.statusSuccess
	if err := g.Wait(); err != nil {
		status = h.conf.statusFailed
		h.conf.logger.Error().Err(err).Msg("check failed")
	}
	w.Header().Set("Content-Type", "text/plain")
	w.WriteHeader(status)
	// io.WriteString should not fail here; if it does the ResponseWriter is
	// broken, which is worth a loud crash
	if _, err := io.WriteString(w, http.StatusText(status)); err != nil {
		h.conf.logger.Panic().Err(err).Msg("failed to write response")
	}
}
// FailSaveAddress replaces wildcard listen addresses ("0.0.0.0", "::" and
// "[::]") with the machine's outbound IPv4 address so the result is dialable.
// Addresses without a wildcard are returned unchanged.
func FailSaveAddress(address string) (string, error) {
	if !strings.Contains(address, "0.0.0.0") && !strings.Contains(address, "::") {
		return address, nil
	}
	outboundIP, err := getOutBoundIP()
	if err != nil {
		return "", err
	}
	address = strings.Replace(address, "0.0.0.0", outboundIP, 1)
	// replace the bracketed wildcard first: replacing the bare "::" first
	// would turn "[::]" into "[[ip]]" with doubled brackets and leave the
	// "[::]" replacement unreachable
	address = strings.Replace(address, "[::]", "["+outboundIP+"]", 1)
	address = strings.Replace(address, "::", "["+outboundIP+"]", 1)
	return address, nil
}

// getOutBoundIP returns the first non-loopback IPv4 address found on any
// local interface, or an error when none exists.
func getOutBoundIP() (string, error) {
	interfacesAddresses, err := net.InterfaceAddrs()
	if err != nil {
		return "", err
	}
	for _, address := range interfacesAddresses {
		ipNet, ok := address.(*net.IPNet)
		if !ok || ipNet.IP.IsLoopback() {
			continue
		}
		// only IPv4 candidates are considered
		if ipNet.IP.To4() != nil {
			return ipNet.IP.String(), nil
		}
	}
	return "", fmt.Errorf("no IP found")
}
+127
View File
@@ -0,0 +1,127 @@
package handlers_test
import (
"bytes"
"context"
"errors"
"net/http"
"net/http/httptest"
"slices"
"testing"
"github.com/test-go/testify/require"
"github.com/tidwall/gjson"
"github.com/opencloud-eu/opencloud/pkg/handlers"
"github.com/opencloud-eu/opencloud/pkg/log"
)
// TestCheckHandlerConfiguration verifies that WithCheck clones the check map
// (derived configurations must not leak checks back into the original) and
// that registering a duplicate check name panics.
//
// NOTE(review): the subtests share nopCheckCounter and rely on running in
// order; they are not independent of each other.
// NOTE(review): this file imports the test-go fork of testify while other
// packages in this change use stretchr/testify — consider unifying.
func TestCheckHandlerConfiguration(t *testing.T) {
	nopCheckCounter := 0
	nopCheck := func(_ context.Context) error { nopCheckCounter++; return nil }
	handlerConfiguration := handlers.NewCheckHandlerConfiguration().WithCheck("check-1", nopCheck)
	t.Run("add check", func(t *testing.T) {
		localCounter := 0
		// serving runs every configured check exactly once
		handlers.NewCheckHandler(handlerConfiguration).ServeHTTP(httptest.NewRecorder(), httptest.NewRequest("GET", "/", nil))
		require.Equal(t, 1, nopCheckCounter)
		// the derived configuration runs both check-1 and check-2, the
		// original configuration is untouched
		handlers.NewCheckHandler(handlerConfiguration.WithCheck("check-2", func(_ context.Context) error { localCounter++; return nil })).ServeHTTP(httptest.NewRecorder(), httptest.NewRequest("GET", "/", nil))
		require.Equal(t, 2, nopCheckCounter)
		require.Equal(t, 1, localCounter)
	})
	t.Run("checks are unique", func(t *testing.T) {
		defer func() {
			if r := recover(); r == nil {
				t.Error("checks should be unique")
			}
		}()
		// registering the same check name twice is expected to panic
		handlerConfiguration.WithCheck("check-1", nopCheck)
		// NOTE(review): unreachable — WithCheck panics on the line above,
		// so this assertion never executes
		require.Equal(t, 3, nopCheckCounter)
	})
}
// TestCheckHandler covers status-code handling and the fail-fast behavior of
// CheckHandler: one failing check cancels the shared request context, which
// unblocks all still-running checks.
func TestCheckHandler(t *testing.T) {
	// checkFactory returns a check that fails immediately with err, or — when
	// err is nil — blocks until the shared context is canceled and then
	// returns nil.
	checkFactory := func(err error) func(ctx context.Context) error {
		return func(ctx context.Context) error {
			if err != nil {
				return err
			}
			<-ctx.Done()
			return nil
		}
	}
	t.Run("passes with custom status", func(t *testing.T) {
		rec := httptest.NewRecorder()
		handler := handlers.NewCheckHandler(
			handlers.
				NewCheckHandlerConfiguration().
				WithStatusSuccess(http.StatusCreated),
		)
		handler.ServeHTTP(rec, httptest.NewRequest("GET", "/", nil))
		require.Equal(t, http.StatusCreated, rec.Code)
		// the body carries the plain status text
		require.Equal(t, http.StatusText(http.StatusCreated), rec.Body.String())
	})
	t.Run("is not ok if any check fails", func(t *testing.T) {
		rec := httptest.NewRecorder()
		handler := handlers.NewCheckHandler(
			handlers.
				NewCheckHandlerConfiguration().
				WithCheck("check-1", checkFactory(errors.New("failed"))),
		)
		handler.ServeHTTP(rec, httptest.NewRequest("GET", "/", nil))
		// default failure status is 500
		require.Equal(t, http.StatusInternalServerError, rec.Code)
		require.Equal(t, http.StatusText(http.StatusInternalServerError), rec.Body.String())
	})
	t.Run("fails with custom status", func(t *testing.T) {
		rec := httptest.NewRecorder()
		handler := handlers.NewCheckHandler(
			handlers.
				NewCheckHandlerConfiguration().
				WithCheck("check-1", checkFactory(errors.New("failed"))).
				WithStatusFailed(http.StatusTeapot),
		)
		handler.ServeHTTP(rec, httptest.NewRequest("GET", "/", nil))
		require.Equal(t, http.StatusTeapot, rec.Code)
		require.Equal(t, http.StatusText(http.StatusTeapot), rec.Body.String())
	})
	t.Run("exits all other running tests on failure", func(t *testing.T) {
		// NOTE(review): the check closures append to errs concurrently
		// without synchronization — `go test -race` would flag this.
		var errs []error
		rec := httptest.NewRecorder()
		buffer := &bytes.Buffer{}
		logger := log.Logger{Logger: log.NewLogger().Output(buffer)}
		handler := handlers.NewCheckHandler(
			handlers.
				NewCheckHandlerConfiguration().
				WithLogger(logger).
				WithCheck("check-1", func(ctx context.Context) error {
					err := checkFactory(nil)(ctx)
					errs = append(errs, err)
					return err
				}).
				WithCheck("check-2", func(ctx context.Context) error {
					err := checkFactory(errors.New("failed"))(ctx)
					errs = append(errs, err)
					return err
				}).
				WithCheck("check-3", func(ctx context.Context) error {
					err := checkFactory(nil)(ctx)
					errs = append(errs, err)
					return err
				}),
		)
		handler.ServeHTTP(rec, httptest.NewRequest("GET", "/", nil))
		// only the failing check shows up in the structured log output
		require.Equal(t, "'check-2': failed", gjson.Get(buffer.String(), "error").String())
		// exactly one of the three checks returned an error
		require.Equal(t, 1, len(slices.DeleteFunc(errs, func(err error) bool { return err == nil })))
		// NOTE(review): the first DeleteFunc already mutated errs in place
		// (shifted elements, zeroed the tail), so this counts the two nil
		// slots of the mutated slice, not the original successful checks
		require.Equal(t, 2, len(slices.DeleteFunc(errs, func(err error) bool { return err != nil })))
	})
}
+219
View File
@@ -0,0 +1,219 @@
// Package keycloak is a package for keycloak utility functions.
package keycloak
import (
"context"
"crypto/tls"
"fmt"
"github.com/Nerzal/gocloak/v13"
libregraph "github.com/owncloud/libre-graph-api-go"
)
// Some attribute constants.
// TODO: Make these configurable in the future.
const (
	// _idAttr is the keycloak user attribute carrying the opencloud/LDAP user id
	_idAttr = "OWNCLOUD_ID"
	// _userTypeAttr is the keycloak user attribute carrying the user type
	_userTypeAttr = "OWNCLOUD_USER_TYPE"
)
// ConcreteClient represents a concrete implementation of a keycloak client
type ConcreteClient struct {
	// keycloak is the gocloak client used for all keycloak API calls
	keycloak GoCloak
	// clientID and clientSecret are the client credentials used to log in
	clientID string
	clientSecret string
	// realm is the realm tokens are requested in (not necessarily the realm
	// a user operation targets — those take an explicit realm parameter)
	realm string
	// baseURL is the keycloak base URL; it doubles as the identity issuer
	baseURL string
}
// New instantiates a new keycloak.Backend with a default gocloak client.
// When insecureSkipVerify is true, TLS certificate verification is disabled
// for all requests against keycloak.
func New(
	baseURL, clientID, clientSecret, realm string,
	insecureSkipVerify bool,
) *ConcreteClient {
	gocloakClient := gocloak.NewClient(baseURL)
	gocloakClient.RestyClient().SetTLSClientConfig(&tls.Config{InsecureSkipVerify: insecureSkipVerify}) //nolint:gosec
	return NewWithClient(gocloakClient, baseURL, clientID, clientSecret, realm)
}
// NewWithClient instantiates a new keycloak.Backend with a custom
// (caller-supplied) gocloak client, e.g. a generated mock in tests.
func NewWithClient(
	gocloakClient GoCloak,
	baseURL, clientID, clientSecret, realm string,
) *ConcreteClient {
	return &ConcreteClient{
		keycloak: gocloakClient,
		baseURL: baseURL,
		clientID: clientID,
		clientSecret: clientSecret,
		realm: realm,
	}
}
// CreateUser creates a user from a libregraph user and returns its *keycloak* ID.
// TODO: For now we only call this from the invitation service where all the attributes are set correctly.
//
// For more wider use, do some sanity checking on the user instance.
func (c *ConcreteClient) CreateUser(ctx context.Context, realm string, user *libregraph.User, userActions []UserAction) (string, error) {
	token, err := c.getToken(ctx)
	if err != nil {
		return "", err
	}
	// carry the opencloud id and user type as custom keycloak attributes
	attributes := map[string][]string{
		_idAttr:       {user.GetId()},
		_userTypeAttr: {user.GetUserType()},
	}
	kcUser := gocloak.User{
		Email:           user.Mail,
		Enabled:         user.AccountEnabled,
		Username:        &user.OnPremisesSamAccountName,
		FirstName:       user.GivenName,
		LastName:        user.Surname,
		Attributes:      &attributes,
		RequiredActions: convertUserActions(userActions),
	}
	return c.keycloak.CreateUser(ctx, token.AccessToken, realm, kcUser)
}
// SendActionsMail sends a mail to the user with userID instructing them to do the actions defined in userActions.
func (c *ConcreteClient) SendActionsMail(ctx context.Context, realm, userID string, userActions []UserAction) error {
	token, err := c.getToken(ctx)
	if err != nil {
		return err
	}
	return c.keycloak.ExecuteActionsEmail(ctx, token.AccessToken, realm, gocloak.ExecuteActionsEmail{
		UserID:  &userID,
		Actions: convertUserActions(userActions),
	})
}
// getUserByParams looks up exactly one user matching the given search
// parameters; zero or multiple matches yield an error.
func (c *ConcreteClient) getUserByParams(ctx context.Context, realm string, params gocloak.GetUsersParams) (*libregraph.User, error) {
	token, err := c.getToken(ctx)
	if err != nil {
		return nil, err
	}
	found, err := c.keycloak.GetUsers(ctx, token.AccessToken, realm, params)
	if err != nil {
		return nil, err
	}
	switch len(found) {
	case 0:
		return nil, fmt.Errorf("no users found")
	case 1:
		return c.keycloakUserToLibregraph(found[0]), nil
	default:
		return nil, fmt.Errorf("%d users found", len(found))
	}
}
// GetUserByUsername looks up a user by username.
func (c *ConcreteClient) GetUserByUsername(ctx context.Context, realm, username string) (*libregraph.User, error) {
	params := gocloak.GetUsersParams{Username: &username}
	return c.getUserByParams(ctx, realm, params)
}
// GetPIIReport returns a structure with all the PII (personally identifiable
// information) keycloak holds for the user: profile data plus the user's
// active sessions.
func (c *ConcreteClient) GetPIIReport(ctx context.Context, realm, username string) (*PIIReport, error) {
	userData, err := c.GetUserByUsername(ctx, realm, username)
	if err != nil {
		return nil, err
	}
	token, err := c.getToken(ctx)
	if err != nil {
		return nil, err
	}
	keycloakID, err := c.getKeyCloakID(userData)
	if err != nil {
		return nil, err
	}
	sessions, err := c.keycloak.GetUserSessions(ctx, token.AccessToken, realm, keycloakID)
	if err != nil {
		return nil, err
	}
	return &PIIReport{
		UserData: userData,
		Sessions: sessions,
	}, nil
}
// getToken logs in with the configured client credentials and returns a
// fresh, introspection-verified token.
// TODO: set a token on the struct and check if it's still valid before requesting a new one.
func (c *ConcreteClient) getToken(ctx context.Context) (*gocloak.JWT, error) {
	token, err := c.keycloak.LoginClient(ctx, c.clientID, c.clientSecret, c.realm)
	if err != nil {
		return nil, fmt.Errorf("failed to get token: %w", err)
	}
	// "RetrospectToken" is gocloak's (sic) name for token introspection
	introspection, err := c.keycloak.RetrospectToken(ctx, token.AccessToken, c.clientID, c.clientSecret, c.realm)
	if err != nil {
		return nil, fmt.Errorf("failed to retrospect token: %w", err)
	}
	if !*introspection.Active {
		return nil, fmt.Errorf("token is not active")
	}
	return token, nil
}
// keycloakUserToLibregraph maps a keycloak user onto a libregraph user.
// The id and user type are read from the custom keycloak attributes
// (_idAttr, _userTypeAttr); the keycloak id itself is exposed as an
// ObjectIdentity issued by this client's base URL.
func (c *ConcreteClient) keycloakUserToLibregraph(u *gocloak.User) *libregraph.User {
	var ldapID string
	var userType *string
	if u.Attributes != nil {
		attrs := *u.Attributes
		// guard against attributes that are present but empty — indexing
		// element 0 unconditionally would panic
		if ldapIDs, ok := attrs[_idAttr]; ok && len(ldapIDs) > 0 {
			ldapID = ldapIDs[0]
		}
		if userTypes, ok := attrs[_userTypeAttr]; ok && len(userTypes) > 0 {
			userType = &userTypes[0]
		}
	}
	return &libregraph.User{
		Id:             &ldapID,
		Mail:           u.Email,
		GivenName:      u.FirstName,
		Surname:        u.LastName,
		AccountEnabled: u.Enabled,
		UserType:       userType,
		Identities: []libregraph.ObjectIdentity{
			{
				Issuer:           &c.baseURL,
				IssuerAssignedId: u.ID,
			},
		},
	}
}
// getKeyCloakID extracts the keycloak user id from the identity whose issuer
// matches this client's base URL.
func (c *ConcreteClient) getKeyCloakID(u *libregraph.User) (string, error) {
	for _, i := range u.Identities {
		// Issuer and IssuerAssignedId are optional pointers; skip identities
		// that do not carry both to avoid a nil-pointer dereference
		if i.Issuer == nil || i.IssuerAssignedId == nil {
			continue
		}
		if *i.Issuer == c.baseURL {
			return *i.IssuerAssignedId, nil
		}
	}
	return "", fmt.Errorf("could not find identity for issuer: %s", c.baseURL)
}
// convertUserActions translates UserAction values into the string form the
// keycloak API expects (see userActionsToString).
func convertUserActions(userActions []UserAction) *[]string {
	stringActions := make([]string, 0, len(userActions))
	for _, action := range userActions {
		stringActions = append(stringActions, userActionsToString[action])
	}
	return &stringActions
}
+18
View File
@@ -0,0 +1,18 @@
package keycloak
import (
"context"
"github.com/Nerzal/gocloak/v13"
)
// GoCloak represents the parts of gocloak.GoCloak that we use, mainly here for mockery.
// The method signatures mirror the upstream gocloak interface one to one so
// the generated mock can stand in for the real client.
type GoCloak interface {
	CreateUser(ctx context.Context, token, realm string, user gocloak.User) (string, error)
	GetUsers(ctx context.Context, token, realm string, params gocloak.GetUsersParams) ([]*gocloak.User, error)
	ExecuteActionsEmail(ctx context.Context, token, realm string, params gocloak.ExecuteActionsEmail) error
	LoginClient(ctx context.Context, clientID, clientSecret, realm string, scopes ...string) (*gocloak.JWT, error)
	RetrospectToken(ctx context.Context, accessToken, clientID, clientSecret, realm string) (*gocloak.IntroSpectTokenResult, error)
	GetCredentials(ctx context.Context, accessToken, realm, userID string) ([]*gocloak.CredentialRepresentation, error)
	GetUserSessions(ctx context.Context, token, realm, userID string) ([]*gocloak.UserSessionRepresentation, error)
}
+39
View File
@@ -0,0 +1,39 @@
package keycloak
import (
"context"
"github.com/Nerzal/gocloak/v13"
libregraph "github.com/owncloud/libre-graph-api-go"
)
// UserAction defines a type for user actions
type UserAction int8

// An incomplete list of UserActions
const (
	// UserActionUpdatePassword sets it that the user needs to change their password.
	UserActionUpdatePassword UserAction = iota
	// UserActionVerifyEmail sets it that the user needs to verify their email address.
	UserActionVerifyEmail
)

// userActionsToString translates user actions to the string constants the
// keycloak "required actions" API expects
var userActionsToString = map[UserAction]string{
	UserActionUpdatePassword: "UPDATE_PASSWORD",
	UserActionVerifyEmail: "VERIFY_EMAIL",
}
// PIIReport is a structure of all the PersonalIdentifiableInformation contained in keycloak.
type PIIReport struct {
	// UserData is the user's profile data in libregraph form
	UserData *libregraph.User
	// Sessions are the user's keycloak sessions
	Sessions []*gocloak.UserSessionRepresentation
}
// Client represents a keycloak client.
type Client interface {
	// CreateUser creates the given user in realm and returns the keycloak user id
	CreateUser(ctx context.Context, realm string, user *libregraph.User, userActions []UserAction) (string, error)
	// SendActionsMail mails the user instructions to perform the given actions
	SendActionsMail(ctx context.Context, realm, userID string, userActions []UserAction) error
	// GetUserByUsername looks up a user by username
	GetUserByUsername(ctx context.Context, realm, username string) (*libregraph.User, error)
	// GetPIIReport returns the user's profile data and sessions
	GetPIIReport(ctx context.Context, realm, username string) (*PIIReport, error)
}
+138
View File
@@ -0,0 +1,138 @@
package kql
import (
	"fmt"
	"strings"
	"time"

	"github.com/jinzhu/now"

	"github.com/opencloud-eu/opencloud/pkg/ast"
	"github.com/opencloud-eu/opencloud/services/search/pkg/query"
)
// toNode converts parser output to a single node of type T, failing when the
// dynamic type does not match.
func toNode[T ast.Node](in interface{}) (T, error) {
	if out, ok := in.(T); ok {
		return out, nil
	}
	var zero T
	return zero, fmt.Errorf("can't convert '%T' to '%T'", in, zero)
}
// toNodes normalizes parser output into a flat slice of T nodes.
//
// Accepted inputs: a single T, a []T, concrete node-pointer slices
// ([]*ast.OperatorNode, []*ast.DateTimeNode) whose elements are converted
// one by one, an arbitrarily nested []interface{} of any of the above, or
// nil (which yields nil).
func toNodes[T ast.Node](in interface{}) ([]T, error) {
	switch v := in.(type) {
	case T:
		return []T{v}, nil
	case []T:
		return v, nil
	// NOTE: these concrete slice cases must be converted element-wise; the
	// previous combined case recursed with the untouched interface value,
	// which re-entered the same case and never terminated
	case []*ast.OperatorNode:
		return castNodes[T](v)
	case []*ast.DateTimeNode:
		return castNodes[T](v)
	case []interface{}:
		var nodes []T
		for _, el := range v {
			node, err := toNodes[T](el)
			if err != nil {
				return nil, err
			}
			nodes = append(nodes, node...)
		}
		return nodes, nil
	case nil:
		return nil, nil
	default:
		var t T
		return nil, fmt.Errorf("can't convert '%T' to '%T'", in, t)
	}
}

// castNodes converts every element of a concrete node slice to T.
func castNodes[T ast.Node, E ast.Node](in []E) ([]T, error) {
	nodes := make([]T, 0, len(in))
	for _, el := range in {
		node, ok := any(el).(T)
		if !ok {
			var t T
			return nil, fmt.Errorf("can't convert '%T' to '%T'", el, t)
		}
		nodes = append(nodes, node)
	}
	return nodes, nil
}
// toString flattens parser output into a string.
//
// Accepted inputs: []byte, string, or an arbitrarily nested []interface{} of
// those; anything else yields an error.
func toString(in interface{}) (string, error) {
	switch v := in.(type) {
	case []byte:
		return string(v), nil
	case string:
		return v, nil
	case []interface{}:
		// use a strings.Builder instead of += concatenation to avoid
		// quadratic copying when many fragments are joined
		var b strings.Builder
		for i := range v {
			sv, err := toString(v[i])
			if err != nil {
				return "", err
			}
			b.WriteString(sv)
		}
		return b.String(), nil
	default:
		return "", fmt.Errorf("can't convert '%T' to string", v)
	}
}
// toTime flattens parser output to a string and parses it via now.Parse,
// which accepts a set of common date/time layouts (see jinzhu/now docs).
func toTime(in interface{}) (time.Time, error) {
	ts, err := toString(in)
	if err != nil {
		return time.Time{}, err
	}
	return now.Parse(ts)
}
// toTimeRange maps a natural-language range keyword onto an inclusive
// [from, to] time range relative to timeNow; weeks start on Monday.
// Supported keywords: today, yesterday, this week, last week, last 7 days,
// this month, last month, last 30 days, this year, last year.
// Unconvertible input and unknown keywords yield an
// UnsupportedTimeRangeError with nil bounds.
func toTimeRange(in interface{}) (*time.Time, *time.Time, error) {
	value, err := toString(in)
	if err != nil {
		// return nil bounds on error, consistent with the unknown-keyword
		// path below (previously this returned pointers to zero times)
		return nil, nil, &query.UnsupportedTimeRangeError{}
	}
	c := &now.Config{
		WeekStartDay: time.Monday,
	}
	n := c.With(timeNow())
	var from, to time.Time
	switch value {
	case "today":
		from = n.BeginningOfDay()
		to = n.EndOfDay()
	case "yesterday":
		yesterday := n.With(n.AddDate(0, 0, -1))
		from = yesterday.BeginningOfDay()
		to = yesterday.EndOfDay()
	case "this week":
		from = n.BeginningOfWeek()
		to = n.EndOfWeek()
	case "last week":
		lastWeek := n.With(n.AddDate(0, 0, -7))
		from = lastWeek.BeginningOfWeek()
		to = lastWeek.EndOfWeek()
	case "last 7 days":
		from = n.With(n.AddDate(0, 0, -6)).BeginningOfDay()
		to = n.EndOfDay()
	case "this month":
		from = n.BeginningOfMonth()
		to = n.EndOfMonth()
	case "last month":
		// step onto the last day of the previous month, then take its bounds
		lastMonth := n.With(n.BeginningOfMonth().AddDate(0, 0, -1))
		from = lastMonth.BeginningOfMonth()
		to = lastMonth.EndOfMonth()
	case "last 30 days":
		from = n.With(n.AddDate(0, 0, -29)).BeginningOfDay()
		to = n.EndOfDay()
	case "this year":
		from = n.BeginningOfYear()
		to = n.EndOfYear()
	case "last year":
		lastYear := n.With(n.AddDate(-1, 0, 0))
		from = lastYear.BeginningOfYear()
		to = lastYear.EndOfYear()
	}
	if from.IsZero() || to.IsZero() {
		return nil, nil, &query.UnsupportedTimeRangeError{}
	}
	return &from, &to, nil
}
+129
View File
@@ -0,0 +1,129 @@
package kql
import (
"strings"
"github.com/opencloud-eu/opencloud/pkg/ast"
)
// connectNodes walks the given nodes from right to left and weaves in the
// implicit operator nodes the Connector decides are needed, returning the
// fully connected list in the original node order.
func connectNodes(c Connector, nodes ...ast.Node) []ast.Node {
	var connectedNodes []ast.Node
	for i := range nodes {
		// iterate in reverse: ri indexes the current "head" node
		ri := len(nodes) - 1 - i
		head := nodes[ri]
		// ask the connector which operators must sit between head and the
		// already-connected tail; prepend them if any
		if connectionNodes := connectNode(c, head, connectedNodes...); len(connectionNodes) > 0 {
			connectedNodes = append(connectionNodes, connectedNodes...)
		}
		connectedNodes = append([]ast.Node{head}, connectedNodes...)
	}
	return connectedNodes
}
// connectNode decides which operator nodes are needed between headNode and
// the nearest non-operator node in tailNodes. Explicit operators that
// precede that neighbor are collected and handed to the Connector so it can
// take them into account.
func connectNode(c Connector, headNode ast.Node, tailNodes ...ast.Node) []ast.Node {
	var nearestNeighborNode ast.Node
	var nearestNeighborOperators []*ast.OperatorNode
	// scan forward until the first non-operator node
l:
	for _, tailNode := range tailNodes {
		switch node := tailNode.(type) {
		case *ast.OperatorNode:
			nearestNeighborOperators = append(nearestNeighborOperators, node)
		default:
			nearestNeighborNode = node
			break l
		}
	}
	// no non-operator neighbor found: nothing to connect to
	if nearestNeighborNode == nil {
		return nil
	}
	return c.Connect(headNode, nearestNeighborNode, nearestNeighborOperators)
}
// Connector is responsible to decide what node connections are needed
type Connector interface {
	// Connect returns the operator nodes to insert between head and its
	// nearest neighbor; connections holds the explicit operators already
	// present between the two.
	Connect(head ast.Node, neighbor ast.Node, connections []*ast.OperatorNode) []ast.Node
}

// DefaultConnector is the default node connector
type DefaultConnector struct {
	// sameKeyOPValue is the operator value used when both nodes carry the
	// same non-empty key (the kql spec uses OR here)
	sameKeyOPValue string
}
// Connect implements the Connector interface and is used to connect the nodes using
// the default logic defined by the kql spec.
func (c DefaultConnector) Connect(head ast.Node, neighbor ast.Node, connections []*ast.OperatorNode) []ast.Node {
	// operator nodes never get an implicit connection of their own
	if _, ok := head.(*ast.OperatorNode); ok {
		return nil
	}

	headKey := strings.ToLower(ast.NodeKey(head))
	neighborKey := strings.ToLower(ast.NodeKey(neighbor))

	connection := &ast.OperatorNode{
		Base:  &ast.Base{Loc: &ast.Location{Source: &[]string{"implicitly operator"}[0]}},
		Value: BoolAND,
	}

	// if the current node and the neighbor node have the same non-empty key
	// the connection is of type OR:
	//
	//	author:"John Smith" author:"Jane Smith"
	//	author:"John Smith" OR author:"Jane Smith"
	//
	// nodes without keys are joined with AND; from the spec: "If there are
	// multiple free-text expressions without any operators in between them,
	// the query behavior is the same as using the AND operator":
	//
	//	cat dog
	//	cat AND dog
	//
	// nodes inside a group node without an explicit operator:
	//
	//	author:"John Smith" AND author:"Jane Smith"
	//	author:("John Smith" "Jane Smith")
	if headKey == neighborKey && headKey != "" && neighborKey != "" {
		connection.Value = c.sameKeyOPValue
	}

	// only the operator directly adjacent to the head matters; the previous
	// version looped over all connections but used index 0 exclusively
	if len(connections) > 0 {
		directNeighborOperator := connections[0]

		// an explicit `AND` or `OR` edge is already present — nothing to add
		if directNeighborOperator.Value == BoolOR || directNeighborOperator.Value == BoolAND {
			return nil
		}

		// a negating neighbor needs an AND edge:
		//
		//	cat -dog
		//	cat AND NOT dog
		if directNeighborOperator.Value == BoolNOT {
			connection.Value = BoolAND
		}
	}

	return []ast.Node{connection}
}
+235
View File
@@ -0,0 +1,235 @@
{
// This grammar is the source for the generated parser (dictionary_gen.go);
// regenerate it via `go generate` (see generate.go).
package kql
}

////////////////////////////////////////////////////////
// ast
////////////////////////////////////////////////////////

// AST is the entry rule: a query is a sequence of nodes.
AST <-
    n:Nodes {
        return buildAST(n, c.text, c.pos)
    }

////////////////////////////////////////////////////////
// nodes
////////////////////////////////////////////////////////

// Nodes matches one or more nodes, each optionally preceded by whitespace.
Nodes <-
    (_ Node)+

// Node matches a single query building block; ordered choice — the first
// matching alternative wins.
Node <-
    GroupNode /
    PropertyRestrictionNodes /
    OperatorBooleanNodes /
    FreeTextKeywordNodes

////////////////////////////////////////////////////////
// nesting
////////////////////////////////////////////////////////

// GroupNode matches a parenthesized node list, optionally prefixed by a key
// and a ':' or '=' operator (a "named" group).
GroupNode <-
    k:(Char+)? (OperatorColonNode / OperatorEqualNode)? "(" v:Nodes ")" {
        return buildGroupNode(k, v, c.text, c.pos)
    }
////////////////////////////////////////////////////////
// property restrictions
////////////////////////////////////////////////////////
PropertyRestrictionNodes <-
YesNoPropertyRestrictionNode /
DateTimeRestrictionNode /
TextPropertyRestrictionNode
YesNoPropertyRestrictionNode <-
k:Char+ (OperatorColonNode / OperatorEqualNode) v:("true" / "false"){
return buildBooleanNode(k, v, c.text, c.pos)
}
DateTimeRestrictionNode <-
k:Char+ o:(
OperatorGreaterOrEqualNode /
OperatorLessOrEqualNode /
OperatorGreaterNode /
OperatorLessNode /
OperatorEqualNode /
OperatorColonNode
) '"'? v:(
DateTime /
FullDate /
FullTime
) '"'? {
return buildDateTimeNode(k, o, v, c.text, c.pos)
} /
k:Char+ (
OperatorEqualNode /
OperatorColonNode
) '"'? v:NaturalLanguageDateTime '"'? {
return buildNaturalLanguageDateTimeNodes(k, v, c.text, c.pos)
}
TextPropertyRestrictionNode <-
k:Char+ (OperatorColonNode / OperatorEqualNode) v:(String / [^ ()]+){
return buildStringNode(k, v, c.text, c.pos)
}
////////////////////////////////////////////////////////
// free text-keywords
////////////////////////////////////////////////////////
FreeTextKeywordNodes <-
PhraseNode /
WordNode
PhraseNode <-
OperatorColonNode? _ v:String _ OperatorColonNode? {
return buildStringNode("", v, c.text, c.pos)
}
WordNode <-
OperatorColonNode? _ v:[^ :()]+ _ OperatorColonNode? {
return buildStringNode("", v, c.text, c.pos)
}
////////////////////////////////////////////////////////
// operators
////////////////////////////////////////////////////////
OperatorBooleanNodes <-
OperatorBooleanAndNode /
OperatorBooleanNotNode /
OperatorBooleanOrNode
OperatorBooleanAndNode <-
("AND" / "+") {
return buildOperatorNode(c.text, c.pos)
}
OperatorBooleanNotNode <-
("NOT" / "-") {
return buildOperatorNode(c.text, c.pos)
}
OperatorBooleanOrNode <-
("OR") {
return buildOperatorNode(c.text, c.pos)
}
OperatorColonNode <-
":" {
return buildOperatorNode(c.text, c.pos)
}
OperatorEqualNode <-
"=" {
return buildOperatorNode(c.text, c.pos)
}
OperatorLessNode <-
"<" {
return buildOperatorNode(c.text, c.pos)
}
OperatorLessOrEqualNode <-
"<=" {
return buildOperatorNode(c.text, c.pos)
}
OperatorGreaterNode <-
">" {
return buildOperatorNode(c.text, c.pos)
}
OperatorGreaterOrEqualNode <-
">=" {
return buildOperatorNode(c.text, c.pos)
}
////////////////////////////////////////////////////////
// time
////////////////////////////////////////////////////////
TimeYear <-
Digit Digit Digit Digit {
return c.text, nil
}
TimeMonth <-
Digit Digit {
return c.text, nil
}
TimeDay <-
Digit Digit {
return c.text, nil
}
TimeHour <-
Digit Digit {
return c.text, nil
}
TimeMinute <-
Digit Digit {
return c.text, nil
}
TimeSecond <-
Digit Digit {
return c.text, nil
}
FullDate <-
TimeYear "-" TimeMonth "-" TimeDay {
return c.text, nil
}
FullTime <-
TimeHour ":" TimeMinute ":" TimeSecond ("." Digit+)? ("Z" / ("+" / "-") TimeHour ":" TimeMinute) {
return c.text, nil
}
DateTime <-
FullDate "T" FullTime {
return c.text, nil
}
NaturalLanguageDateTime <-
"today" /
"yesterday" /
"this week" /
"last week" /
"last 7 days" /
"this month" /
"last month" /
"last 30 days" /
"this year" /
"last year" {
return c.text, nil
}
////////////////////////////////////////////////////////
// misc
////////////////////////////////////////////////////////
Char <-
[A-Za-z] {
return c.text, nil
}
String <-
'"' v:[^"]* '"' {
return v, nil
}
Digit <-
[0-9] {
return c.text, nil
}
_ <-
[ \t]* {
return nil, nil
}
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+27
View File
@@ -0,0 +1,27 @@
/*
Package kql provides the ability to work with kql queries.
Not every aspect of the spec is implemented yet.
The language support will grow over time if needed.
The following spec parts are supported and tested:
- 2.1.2 AND Operator
- 2.1.6 NOT Operator
- 2.1.8 OR Operator
- 2.1.12 Parentheses
- 2.3.5 Date Tokens
- 3.1.11 Implicit Operator
- 3.1.12 Parentheses
- 3.1.2 AND Operator
- 3.1.6 NOT Operator
- 3.1.8 OR Operator
- 3.2.3 Implicit Operator for Property Restriction
- 3.3.1.1.1 Implicit AND Operator
- 3.3.5 Date Tokens
References:
- https://learn.microsoft.com/en-us/sharepoint/dev/general-development/keyword-query-language-kql-syntax-reference
- https://learn.microsoft.com/en-us/openspecs/sharepoint_protocols/ms-kql/3bbf06cd-8fc1-4277-bd92-8661ccd3c9b0
- https://msopenspecs.azureedge.net/files/MS-KQL/%5bMS-KQL%5d.pdf
*/
package kql
+11
View File
@@ -0,0 +1,11 @@
package kql
import (
"time"
)
// PatchTimeNow is here to patch the package time now func,
// which is used in the test suite. It swaps the package-level timeNow
// variable (normally time.Now) so tests can pin "now" to a fixed instant,
// e.g. for toTimeRange.
func PatchTimeNow(t func() time.Time) {
	timeNow = t
}
+209
View File
@@ -0,0 +1,209 @@
package kql
import (
"strings"
"github.com/opencloud-eu/opencloud/pkg/ast"
)
func base(text []byte, pos position) (*ast.Base, error) {
source, err := toString(text)
if err != nil {
return nil, err
}
return &ast.Base{
Loc: &ast.Location{
Start: ast.Position{
Line: pos.line,
Column: pos.col,
},
End: ast.Position{
Line: pos.line,
Column: pos.col + len(text),
},
Source: &source,
},
}, nil
}
func buildAST(n interface{}, text []byte, pos position) (*ast.Ast, error) {
b, err := base(text, pos)
if err != nil {
return nil, err
}
nodes, err := toNodes[ast.Node](n)
if err != nil {
return nil, err
}
a := &ast.Ast{
Base: b,
Nodes: connectNodes(DefaultConnector{sameKeyOPValue: BoolOR}, nodes...),
}
if err := validateAst(a); err != nil {
return nil, err
}
return a, nil
}
func buildStringNode(k, v interface{}, text []byte, pos position) (*ast.StringNode, error) {
b, err := base(text, pos)
if err != nil {
return nil, err
}
key, err := toString(k)
if err != nil {
return nil, err
}
value, err := toString(v)
if err != nil {
return nil, err
}
return &ast.StringNode{
Base: b,
Key: key,
Value: value,
}, nil
}
func buildDateTimeNode(k, o, v interface{}, text []byte, pos position) (*ast.DateTimeNode, error) {
b, err := base(text, pos)
if err != nil {
return nil, err
}
operator, err := toNode[*ast.OperatorNode](o)
if err != nil {
return nil, err
}
key, err := toString(k)
if err != nil {
return nil, err
}
value, err := toTime(v)
if err != nil {
return nil, err
}
return &ast.DateTimeNode{
Base: b,
Key: key,
Operator: operator,
Value: value,
}, nil
}
func buildNaturalLanguageDateTimeNodes(k, v interface{}, text []byte, pos position) ([]ast.Node, error) {
b, err := base(text, pos)
if err != nil {
return nil, err
}
key, err := toString(k)
if err != nil {
return nil, err
}
from, to, err := toTimeRange(v)
if err != nil {
return nil, err
}
return []ast.Node{
&ast.DateTimeNode{
Base: b,
Value: *from,
Key: key,
Operator: &ast.OperatorNode{Value: ">="},
},
&ast.OperatorNode{Value: BoolAND},
&ast.DateTimeNode{
Base: b,
Value: *to,
Key: key,
Operator: &ast.OperatorNode{Value: "<="},
},
}, nil
}
func buildBooleanNode(k, v interface{}, text []byte, pos position) (*ast.BooleanNode, error) {
b, err := base(text, pos)
if err != nil {
return nil, err
}
key, err := toString(k)
if err != nil {
return nil, err
}
value, err := toString(v)
if err != nil {
return nil, err
}
return &ast.BooleanNode{
Base: b,
Key: key,
Value: strings.ToLower(value) == "true",
}, nil
}
func buildOperatorNode(text []byte, pos position) (*ast.OperatorNode, error) {
b, err := base(text, pos)
if err != nil {
return nil, err
}
value, err := toString(text)
if err != nil {
return nil, err
}
switch value {
case "+":
value = BoolAND
case "-":
value = BoolNOT
}
return &ast.OperatorNode{
Base: b,
Value: value,
}, nil
}
func buildGroupNode(k, n interface{}, text []byte, pos position) (*ast.GroupNode, error) {
b, err := base(text, pos)
if err != nil {
return nil, err
}
key, _ := toString(k)
nodes, err := toNodes[ast.Node](n)
if err != nil {
return nil, err
}
gn := &ast.GroupNode{
Base: b,
Key: key,
Nodes: connectNodes(DefaultConnector{sameKeyOPValue: BoolOR}, nodes...),
}
if err := validateGroupNode(gn); err != nil {
return nil, err
}
return gn, nil
}
+3
View File
@@ -0,0 +1,3 @@
package kql
//go:generate go run github.com/mna/pigeon -optimize-grammar -optimize-parser -o dictionary_gen.go dictionary.peg
+49
View File
@@ -0,0 +1,49 @@
// Package kql provides the ability to work with kql queries.
package kql
import (
"errors"
"time"
"github.com/opencloud-eu/opencloud/pkg/ast"
)
// The operator node value definition
const (
	// BoolAND connects two nodes with a logical AND
	BoolAND = "AND"
	// BoolOR connects two nodes with a logical OR
	BoolOR = "OR"
	// BoolNOT negates the node that follows it
	BoolNOT = "NOT"
)
// Builder implements kql Builder interface
type Builder struct{}

// Build parses a kql query into an ast.Ast. Parse failures caused by a
// domain error (e.g. validation) surface that inner error; all other parse
// failures are returned as-is.
func (b Builder) Build(q string) (*ast.Ast, error) {
	f, err := Parse("", []byte(q))
	if err != nil {
		var list errList
		if errors.As(err, &list) {
			for _, listError := range list {
				var parserError *parserError
				if errors.As(listError, &parserError) {
					// prefer the wrapped domain error when present
					if parserError.Inner != nil {
						return nil, parserError.Inner
					}
					return nil, listError
				}
			}
		}
		// not an errList / no parserError inside: return the raw error
		// instead of falling through to a type assertion on a nil f,
		// which would panic
		return nil, err
	}
	return f.(*ast.Ast), nil
}
// timeNow mirrors time.Now by default, the only reason why this exists
// is to monkey patch it from the tests. See PatchTimeNow
var timeNow = time.Now
+54
View File
@@ -0,0 +1,54 @@
package kql_test
import (
"testing"
"github.com/opencloud-eu/opencloud/pkg/ast"
"github.com/opencloud-eu/opencloud/pkg/kql"
"github.com/opencloud-eu/opencloud/services/search/pkg/query"
tAssert "github.com/stretchr/testify/assert"
)
// TestNewAST exercises Builder.Build for a successful parse and for a query
// that illegally starts with a binary operator.
func TestNewAST(t *testing.T) {
	tests := []struct {
		name          string
		givenQuery    string
		expectedError error
	}{
		{
			name:       "success",
			givenQuery: "foo:bar",
		},
		{
			name:       "error",
			givenQuery: kql.BoolAND,
			expectedError: query.StartsWithBinaryOperatorError{
				Node: &ast.OperatorNode{Value: kql.BoolAND},
			},
		},
	}
	assert := tAssert.New(t)
	for _, tt := range tests {
		tt := tt // pin for pre-1.22 loopvar semantics
		t.Run(tt.name, func(t *testing.T) {
			got, err := kql.Builder{}.Build(tt.givenQuery)
			if tt.expectedError != nil {
				if tt.expectedError.Error() != "" {
					// testify's Equal takes (expected, actual); the previous
					// order was swapped, which garbles failure messages
					assert.Equal(tt.expectedError.Error(), err.Error())
				} else {
					assert.NotNil(err)
				}
				assert.Nil(got)
				return
			}
			assert.Nil(err)
			assert.NotNil(got)
		})
	}
}
+37
View File
@@ -0,0 +1,37 @@
package kql
import (
"github.com/opencloud-eu/opencloud/pkg/ast"
"github.com/opencloud-eu/opencloud/services/search/pkg/query"
)
// validateAst checks ast-level semantics, currently that the query does not
// start with a binary operator (AND/OR).
func validateAst(a *ast.Ast) error {
	// an empty node list has nothing to validate; indexing Nodes[0]
	// unconditionally would panic
	if len(a.Nodes) == 0 {
		return nil
	}
	switch node := a.Nodes[0].(type) {
	case *ast.OperatorNode:
		switch node.Value {
		case BoolAND, BoolOR:
			return &query.StartsWithBinaryOperatorError{Node: node}
		}
	}
	return nil
}
// validateGroupNode checks group-level semantics: a group must not start
// with a binary operator (AND/OR), and a named group must not contain nodes
// that carry their own key.
func validateGroupNode(n *ast.GroupNode) error {
	// guard against an empty group; indexing Nodes[0] would panic
	if len(n.Nodes) == 0 {
		return nil
	}
	switch node := n.Nodes[0].(type) {
	case *ast.OperatorNode:
		switch node.Value {
		case BoolAND, BoolOR:
			return &query.StartsWithBinaryOperatorError{Node: node}
		}
	}
	if n.Key != "" {
		for _, node := range n.Nodes {
			if ast.NodeKey(node) != "" {
				return &query.NamedGroupInvalidNodesError{Node: node}
			}
		}
	}
	return nil
}
+345
View File
@@ -0,0 +1,345 @@
// package l10n holds translation mechanics that are used by user facing services (notifications, userlog, graph)
package l10n
import (
"context"
"errors"
"io/fs"
"os"
"reflect"
"github.com/leonelquinteros/gotext"
"github.com/opencloud-eu/opencloud/pkg/middleware"
settingssvc "github.com/opencloud-eu/opencloud/protogen/gen/ocis/services/settings/v0"
micrometadata "go-micro.dev/v4/metadata"
)
var (
	// HeaderAcceptLanguage is the header key for the accept-language header
	HeaderAcceptLanguage = "Accept-Language"
	// ErrUnsupportedType is returned when a type is not supported
	ErrUnsupportedType = errors.New("unsupported type")
)
// Template marks a string as translatable; it returns its argument
// unchanged.
func Template(s string) string {
	return s
}
// Translator is able to translate strings
type Translator struct {
	// fs holds the translation files gotext loads locales from
	fs fs.FS
	// defaultLocale is used when the requested locale has no translations
	defaultLocale string
	// domain is the gettext domain to load
	domain string
}
// NewTranslator creates a Translator for the given default locale and
// gettext domain, reading translation files from fsys.
func NewTranslator(defaultLocale string, domain string, fsys fs.FS) Translator {
	t := Translator{}
	t.fs = fsys
	t.defaultLocale = defaultLocale
	t.domain = domain
	return t
}
// NewTranslatorFromCommonConfig creates a new Translator from legacy config:
// when path is set, translations are read from that directory on disk,
// otherwise from the embedded fsys below fsSubPath.
func NewTranslatorFromCommonConfig(defaultLocale string, domain string, path string, fsys fs.FS, fsSubPath string) Translator {
	var filesystem fs.FS
	if path == "" {
		// NOTE(review): the fs.Sub error is silently discarded — filesystem
		// stays nil when fsSubPath is invalid; confirm this is intended
		filesystem, _ = fs.Sub(fsys, fsSubPath)
	} else { // use custom path instead
		filesystem = os.DirFS(path)
	}
	return NewTranslator(defaultLocale, domain, filesystem)
}
// Translate translates str into the given locale, subject to the default
// locale fallback implemented by Locale.
func (t Translator) Translate(str, locale string) string {
	l := t.Locale(locale)
	return l.Get(str)
}
// Locale returns the gotext.Locale, use `.Get` method to translate strings.
// When the requested locale has no translations loaded, the translator's
// default locale is used instead; "en" is exempt from the fallback —
// presumably because the untranslated source strings are English already.
func (t Translator) Locale(locale string) *gotext.Locale {
	l := gotext.NewLocaleFS(locale, t.fs)
	l.AddDomain(t.domain) // make domain configurable only if needed
	if locale != "en" && len(l.GetTranslations()) == 0 {
		l = gotext.NewLocaleFS(t.defaultLocale, t.fs)
		l.AddDomain(t.domain) // make domain configurable only if needed
	}
	return l
}
// TranslateEntity function provides the generic way to translate a struct, array or slice.
// Support for maps is also provided, but non-pointer values will not work.
// The function also takes the entity with fields to translate.
// The function supports nested structs and slices of structs.
/*
	tr := NewTranslator("en", _domain, _fsys)

	// a slice of translatables can be passed directly
	val := []string{"description", "display name"}
	err := tr.TranslateEntity(locale, val)

	// string maps work the same way
	val := map[string]string{
		"entryOne": "description",
		"entryTwo": "display name",
	}
	err := tr.TranslateEntity(locale, val)

	// struct fields need to be specified
	type Struct struct {
		Description string
		DisplayName string
		MetaInformation string
	}
	val := Struct{}
	err := tr.TranslateEntity(locale, &val,
		l10n.TranslateField("Description"),
		l10n.TranslateField("DisplayName"),
	)

	// nested structures are supported
	type InnerStruct struct {
		Description string
		Roles []string
	}
	type OuterStruct struct {
		DisplayName string
		First InnerStruct
		Others map[string]InnerStruct
	}
	val := OuterStruct{}
	err := tr.TranslateEntity(locale, &val,
		l10n.TranslateField("DisplayName"),
		l10n.TranslateStruct("First",
			l10n.TranslateField("Description"),
			l10n.TranslateEach("Roles"),
		),
		l10n.TranslateMap("Others",
			l10n.TranslateField("Description"),
		),
	)
*/
func (t Translator) TranslateEntity(locale string, entity any, opts ...TranslateOption) error {
	return TranslateEntity(t.Locale(locale).Get, entity, opts...)
}
// MustGetUserLocale returns the locale the user wants to use, omitting errors
func MustGetUserLocale(ctx context.Context, userID string, preferedLang string, vc settingssvc.ValueService) string {
	if preferedLang != "" {
		return preferedLang
	}
	// errors are deliberately swallowed; an empty locale is acceptable here
	loc, _ := GetUserLocale(ctx, userID, vc)
	return loc
}
// GetUserLocale returns the locale of the user by asking the settings
// service for the profile-language setting of the given account.
func GetUserLocale(ctx context.Context, userID string, vc settingssvc.ValueService) (string, error) {
	resp, err := vc.GetValueByUniqueIdentifiers(
		// the settings service reads the account id from the request metadata
		micrometadata.Set(ctx, middleware.AccountID, userID),
		&settingssvc.GetValueByUniqueIdentifiersRequest{
			AccountUuid: userID,
			// this defaults.SettingUUIDProfileLanguage. Copied here to avoid import cycles.
			SettingId: "aa8cfbe5-95d4-4f7e-a032-c3c01f5f062f",
		},
	)
	if err != nil {
		return "", err
	}
	val := resp.GetValue().GetValue().GetListValue().GetValues()
	if len(val) == 0 {
		return "", errors.New("no language setting found")
	}
	// the setting is stored as a list; the first entry holds the locale
	return val[0].GetStringValue(), nil
}
// TranslateOption is used to specify fields in structs to translate.
// Invoking the option yields the field name, how the field is handled
// (FieldType) and nested options for compound fields.
type TranslateOption func() (string, FieldType, []TranslateOption)

// FieldType is used to specify the type of field to translate
type FieldType int

const (
	// FieldTypeString is a string field
	FieldTypeString FieldType = iota
	// FieldTypeStruct is a struct field
	FieldTypeStruct
	// FieldTypeIterable is a slice or array field
	FieldTypeIterable
	// FieldTypeMap is a map field
	FieldTypeMap
)
// TranslateField function provides the generic way to translate the necessary field in composite entities.
// The named field should be a string (or *string) on the target struct.
func TranslateField(fieldName string) TranslateOption {
	return func() (string, FieldType, []TranslateOption) {
		return fieldName, FieldTypeString, nil
	}
}

// TranslateStruct function provides the generic way to translate the nested fields in composite entities.
// args describe which fields of the nested struct to translate.
func TranslateStruct(fieldName string, args ...TranslateOption) TranslateOption {
	return func() (string, FieldType, []TranslateOption) {
		return fieldName, FieldTypeStruct, args
	}
}

// TranslateEach function provides the generic way to translate the necessary fields in slices or nested entities.
// args are applied to every element of the slice or array.
func TranslateEach(fieldName string, args ...TranslateOption) TranslateOption {
	return func() (string, FieldType, []TranslateOption) {
		return fieldName, FieldTypeIterable, args
	}
}

// TranslateMap function provides the generic way to translate the necessary fields in maps.
// args are applied to every map value.
func TranslateMap(fieldName string, args ...TranslateOption) TranslateOption {
	return func() (string, FieldType, []TranslateOption) {
		return fieldName, FieldTypeMap, args
	}
}
// TranslateEntity translates a slice, array, map or struct in place using tr.
// See Translator.TranslateEntity for usage examples.
func TranslateEntity(tr func(string, ...any) string, entity any, opts ...TranslateOption) error {
	v, ok := cleanValue(reflect.ValueOf(entity))
	if !ok {
		return errors.New("entity is not valid")
	}
	switch v.Kind() {
	case reflect.Struct:
		rangeOverArgs(tr, v, opts...)
		return nil
	case reflect.Slice, reflect.Array, reflect.Map:
		translateEach(tr, v, opts...)
		return nil
	case reflect.String:
		translateField(tr, v)
		return nil
	default:
		return ErrUnsupportedType
	}
}
// translateEach walks a slice, array or map and translates its elements:
// struct (and pointer) elements are handled via rangeOverArgs with the given
// options, strings are translated directly, and nested slices/arrays/maps
// recurse. Elements of any other kind are ignored.
func translateEach(tr func(string, ...any) string, value reflect.Value, args ...TranslateOption) {
	value, ok := cleanValue(value)
	if !ok {
		return
	}
	switch value.Kind() {
	case reflect.Array, reflect.Slice:
		for i := 0; i < value.Len(); i++ {
			v := value.Index(i)
			switch v.Kind() {
			case reflect.Struct, reflect.Ptr:
				rangeOverArgs(tr, v, args...)
			case reflect.String:
				translateField(tr, v)
			case reflect.Slice, reflect.Array, reflect.Map:
				translateEach(tr, v, args...)
			}
		}
	case reflect.Map:
		for _, k := range value.MapKeys() {
			v := value.MapIndex(k)
			switch v.Kind() {
			case reflect.Struct:
				// FIXME: add support for non-pointer values
			case reflect.Pointer:
				rangeOverArgs(tr, v, args...)
			case reflect.String:
				// map values are not addressable: replace the entry instead
				if nv := tr(v.String()); nv != "" {
					value.SetMapIndex(k, reflect.ValueOf(nv))
				}
			case reflect.Slice, reflect.Array, reflect.Map:
				translateEach(tr, v, args...)
			}
		}
	}
}
// rangeOverArgs applies the given TranslateOptions to the fields of a struct
// value. Each option names a field and says how it is handled (plain string,
// nested struct, slice/array, or map).
//
// Fix: previously a missing or mis-typed field caused an early `return`,
// silently skipping every remaining option; options are now independent and
// a bad one is skipped with `continue`. A non-struct value is also rejected
// up front instead of panicking in FieldByName.
func rangeOverArgs(tr func(string, ...any) string, value reflect.Value, args ...TranslateOption) {
	value, ok := cleanValue(value)
	if !ok {
		return
	}
	if value.Kind() != reflect.Struct {
		// FieldByName below would panic on anything but a struct
		return
	}
	for _, arg := range args {
		fieldName, fieldType, opts := arg()
		switch fieldType {
		case FieldTypeString:
			f := value.FieldByName(fieldName)
			translateField(tr, f)
		case FieldTypeStruct:
			innerValue := value.FieldByName(fieldName)
			if !innerValue.IsValid() || !isStruct(innerValue) {
				// skip this option but keep processing the remaining ones
				continue
			}
			rangeOverArgs(tr, innerValue, opts...)
		case FieldTypeIterable:
			innerValue := value.FieldByName(fieldName)
			if !innerValue.IsValid() {
				continue
			}
			if kind := innerValue.Kind(); kind != reflect.Array && kind != reflect.Slice {
				continue
			}
			translateEach(tr, innerValue, opts...)
		case FieldTypeMap:
			innerValue := value.FieldByName(fieldName)
			if !innerValue.IsValid() {
				continue
			}
			if kind := innerValue.Kind(); kind != reflect.Map {
				continue
			}
			translateEach(tr, innerValue, opts...)
		}
	}
}
func translateField(tr func(string, ...any) string, f reflect.Value) {
if f.IsValid() {
if f.Kind() == reflect.Ptr {
if f.IsNil() {
return
}
f = f.Elem()
}
// A Value can be changed only if it is
// addressable and was not obtained by
// the use of unexported struct fields.
if f.CanSet() {
// change value
if f.Kind() == reflect.String {
val := tr(f.String())
if val == "" {
return
}
f.SetString(val)
}
}
}
}
func isStruct(r reflect.Value) bool {
if r.Kind() == reflect.Ptr {
r = r.Elem()
}
return r.Kind() == reflect.Struct
}
func cleanValue(v reflect.Value) (reflect.Value, bool) {
if v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface {
if v.IsNil() {
return v, false
}
v = v.Elem()
}
if !v.IsValid() {
return v, false
}
return v, true
}
+414
View File
@@ -0,0 +1,414 @@
package l10n
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestTranslateStruct(t *testing.T) {
type InnerStruct struct {
Description string
DisplayName *string
}
type TopLevelStruct struct {
Description string
DisplayName *string
SubStruct *InnerStruct
}
type WrapperStruct struct {
Description string
StructList []*InnerStruct
}
toStrPointer := func(str string) *string {
return &str
}
tests := []struct {
name string
entity any
args []TranslateOption
expected any
wantErr bool
}{
{
name: "top level slice of struct",
entity: []*InnerStruct{
{
Description: "inner 1",
DisplayName: toStrPointer("innerDisplayName 1"),
},
{
Description: "inner 2",
DisplayName: toStrPointer("innerDisplayName 2"),
},
},
args: []TranslateOption{
TranslateField("Description"),
TranslateField("DisplayName"),
},
expected: []*InnerStruct{
{
Description: "new Inner 1",
DisplayName: toStrPointer("new InnerDisplayName 1"),
},
{
Description: "new Inner 2",
DisplayName: toStrPointer("new InnerDisplayName 2"),
},
},
},
{
name: "top level slice of string",
entity: []string{
"inner 1",
"inner 2",
},
expected: []string{
"new Inner 1",
"new Inner 2",
},
},
{
name: "top level slice of struct",
entity: []*TopLevelStruct{
{
Description: "inner 1",
DisplayName: toStrPointer("innerDisplayName 1"),
SubStruct: &InnerStruct{
Description: "inner",
DisplayName: toStrPointer("innerDisplayName"),
},
},
{
Description: "inner 2",
DisplayName: toStrPointer("innerDisplayName 2"),
},
},
args: []TranslateOption{
TranslateField("Description"),
TranslateField("DisplayName"),
TranslateStruct("SubStruct",
TranslateField("Description"),
TranslateField("DisplayName"),
),
},
expected: []*TopLevelStruct{
{
Description: "new Inner 1",
DisplayName: toStrPointer("new InnerDisplayName 1"),
SubStruct: &InnerStruct{
Description: "new Inner",
DisplayName: toStrPointer("new InnerDisplayName"),
},
},
{
Description: "new Inner 2",
DisplayName: toStrPointer("new InnerDisplayName 2"),
},
},
},
{
name: "wrapped struct full",
entity: &WrapperStruct{
StructList: []*InnerStruct{
{
Description: "inner 1",
DisplayName: toStrPointer("innerDisplayName 1"),
},
{
Description: "inner 2",
DisplayName: toStrPointer("innerDisplayName 2"),
},
},
},
args: []TranslateOption{
TranslateEach("StructList",
TranslateField("Description"),
TranslateField("DisplayName"),
),
},
expected: &WrapperStruct{
StructList: []*InnerStruct{
{
Description: "new Inner 1",
DisplayName: toStrPointer("new InnerDisplayName 1"),
},
{
Description: "new Inner 2",
DisplayName: toStrPointer("new InnerDisplayName 2"),
},
},
},
},
{
name: "empty struct, NotExistingSubStructName",
entity: &TopLevelStruct{},
args: []TranslateOption{
TranslateField("Description"),
TranslateField("DisplayName"),
TranslateStruct("NotExistingSubStructName",
TranslateField("Description"),
TranslateField("DisplayName"),
),
},
expected: &TopLevelStruct{},
},
{
name: "empty struct",
entity: &TopLevelStruct{},
args: []TranslateOption{
TranslateField("Description"),
TranslateField("DisplayName"),
TranslateStruct("SubStruct",
TranslateField("Description"),
TranslateField("DisplayName"),
),
},
expected: &TopLevelStruct{},
},
{
name: "empty struct, not existing field",
entity: &TopLevelStruct{
Description: "description",
DisplayName: toStrPointer("displayName"),
},
args: []TranslateOption{
TranslateField("NotExistingFieldName"),
TranslateStruct("SubStruct",
TranslateField("NotExistingFieldName"),
),
},
expected: &TopLevelStruct{
Description: "description",
DisplayName: toStrPointer("displayName"),
},
},
{
name: "inner struct DisplayName empy",
entity: &TopLevelStruct{
Description: "description",
DisplayName: toStrPointer("displayName"),
},
args: []TranslateOption{
TranslateField("Description"),
TranslateField("DisplayName"),
TranslateStruct("SubStruct",
TranslateField("Description"),
TranslateField("DisplayName"),
),
},
expected: &TopLevelStruct{
Description: "new Description",
DisplayName: toStrPointer("new DisplayName"),
},
},
{
name: "inner struct full",
entity: &TopLevelStruct{
Description: "description",
DisplayName: toStrPointer("displayName"),
},
args: []TranslateOption{
TranslateField("Description"),
TranslateField("DisplayName"),
TranslateStruct("SubStruct",
TranslateField("Description"),
TranslateField("DisplayName"),
),
},
expected: &TopLevelStruct{
Description: "new Description",
DisplayName: toStrPointer("new DisplayName"),
},
},
{
name: "full struct",
entity: &TopLevelStruct{
Description: "description",
DisplayName: toStrPointer("displayName"),
SubStruct: &InnerStruct{
Description: "inner",
DisplayName: toStrPointer("innerDisplayName"),
},
},
args: []TranslateOption{
TranslateField("Description"),
TranslateField("DisplayName"),
TranslateStruct("SubStruct",
TranslateField("Description"),
TranslateField("DisplayName"),
),
},
expected: &TopLevelStruct{
Description: "new Description",
DisplayName: toStrPointer("new DisplayName"),
SubStruct: &InnerStruct{
Description: "new Inner",
DisplayName: toStrPointer("new InnerDisplayName"),
},
},
},
{
name: "nil",
wantErr: true,
},
{
name: "empty slice",
wantErr: true,
},
{
name: "string slice",
entity: []string{"description", "inner"},
expected: []string{"new Description", "new Inner"},
},
{
name: "string map",
entity: map[string]string{
"entryOne": "description",
"entryTwo": "inner",
},
expected: map[string]string{
"entryOne": "new Description",
"entryTwo": "new Inner",
},
},
{
name: "pointer struct map",
entity: map[string]*InnerStruct{
"entryOne": {Description: "description", DisplayName: toStrPointer("displayName")},
"entryTwo": {Description: "inner", DisplayName: toStrPointer("innerDisplayName")},
},
args: []TranslateOption{
TranslateField("Description"),
TranslateField("DisplayName"),
},
expected: map[string]*InnerStruct{
"entryOne": {Description: "new Description", DisplayName: toStrPointer("new DisplayName")},
"entryTwo": {Description: "new Inner", DisplayName: toStrPointer("new InnerDisplayName")},
},
},
/* FIXME: non pointer maps are currently not working
{
name: "struct map",
entity: map[string]InnerStruct{
"entryOne": {Description: "description", DisplayName: toStrPointer("displayName")},
"entryTwo": {Description: "inner", DisplayName: toStrPointer("innerDisplayName")},
},
args: []TranslateOption{
TranslateField("Description"),
TranslateField("DisplayName"),
},
expected: map[string]InnerStruct{
"entryOne": {Description: "new Description", DisplayName: toStrPointer("new DisplayName")},
"entryTwo": {Description: "new Inner", DisplayName: toStrPointer("new InnerDisplayName")},
},
},
*/
{
name: "slice map",
entity: map[string][]string{
"entryOne": {"description", "inner"},
"entryTwo": {"inner 2", "innerDisplayName 2"},
},
expected: map[string][]string{
"entryOne": {"new Description", "new Inner"},
"entryTwo": {"new Inner 2", "new InnerDisplayName 2"},
},
},
{
name: "double slice",
entity: [][]string{
{"description", "inner"},
{"inner 2", "innerDisplayName 2"},
},
expected: [][]string{
{"new Description", "new Inner"},
{"new Inner 2", "new InnerDisplayName 2"},
},
},
{
name: "nested structs",
entity: [][]*InnerStruct{
{
&InnerStruct{Description: "description", DisplayName: toStrPointer("displayName")},
&InnerStruct{Description: "inner", DisplayName: toStrPointer("innerDisplayName")},
},
{
&InnerStruct{Description: "inner 2", DisplayName: toStrPointer("innerDisplayName 2")},
},
},
args: []TranslateOption{
TranslateField("Description"),
TranslateField("DisplayName"),
},
expected: [][]*InnerStruct{
{
&InnerStruct{Description: "new Description", DisplayName: toStrPointer("new DisplayName")},
&InnerStruct{Description: "new Inner", DisplayName: toStrPointer("new InnerDisplayName")},
},
{
&InnerStruct{Description: "new Inner 2", DisplayName: toStrPointer("new InnerDisplayName 2")},
},
},
},
{
name: "double mapslices",
entity: []map[string][]string{
{
"entryOne": {"inner 1", "innerDisplayName 1"},
"entryTwo": {"inner 2", "innerDisplayName 2"},
},
{
"entryOne": {"description", "displayName"},
},
},
expected: []map[string][]string{
{
"entryOne": {"new Inner 1", "new InnerDisplayName 1"},
"entryTwo": {"new Inner 2", "new InnerDisplayName 2"},
},
{
"entryOne": {"new Description", "new DisplayName"},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := TranslateEntity(mock(), tt.entity, tt.args...)
if (err != nil) != tt.wantErr {
t.Errorf("TranslateEntity() error = %v, wantErr %v", err, tt.wantErr)
}
assert.Equal(t, tt.expected, tt.entity)
})
}
}
// mock returns a translation stub that maps the fixture strings used by the
// table tests to their "translated" counterparts; unknown strings are
// returned unchanged.
func mock() func(string, ...interface{}) string {
	translations := map[string]string{
		"description":        "new Description",
		"displayName":        "new DisplayName",
		"inner":              "new Inner",
		"innerDisplayName":   "new InnerDisplayName",
		"inner 1":            "new Inner 1",
		"innerDisplayName 1": "new InnerDisplayName 1",
		"inner 2":            "new Inner 2",
		"innerDisplayName 2": "new InnerDisplayName 2",
	}
	return func(s string, _ ...interface{}) string {
		if translated, ok := translations[s]; ok {
			return translated
		}
		return s
	}
}
+39
View File
@@ -0,0 +1,39 @@
package ldap
import (
"crypto/x509"
"errors"
"os"
"time"
"github.com/opencloud-eu/opencloud/pkg/log"
)
const (
	// caCheckRetries is how often WaitForCA re-checks the CA cert file
	caCheckRetries = 3
	// caCheckSleep is the pause between checks, in seconds
	caCheckSleep = 2
)
// WaitForCA waits until the file named by caCert exists and contains at
// least one PEM certificate that can be added to a cert pool. It retries a
// few times because the file may be created empty before its contents are
// written. With insecure == true or an empty caCert it is a no-op. After
// exhausting all retries it still returns nil (best effort, matching the
// previous behavior); only unexpected stat errors are returned.
//
// Fix: the warning is now logged *before* the sleep it announces, and no
// sleep happens after the final attempt (previously 2s were wasted).
func WaitForCA(log log.Logger, insecure bool, caCert string) error {
	if insecure || caCert == "" {
		return nil
	}
	for i := 0; i < caCheckRetries; i++ {
		// any stat error other than "does not exist" is fatal
		if _, err := os.Stat(caCert); err != nil && !errors.Is(err, os.ErrNotExist) {
			return err
		}
		// Check if this actually is a CA cert. We need to retry here as well
		// as the file might exist already, but have no contents yet.
		certs := x509.NewCertPool()
		pemData, err := os.ReadFile(caCert)
		switch {
		case err != nil:
			log.Debug().Err(err).Str("LDAP CACert", caCert).Msg("Error reading CA")
		case !certs.AppendCertsFromPEM(pemData):
			log.Debug().Str("LDAP CAcert", caCert).Msg("Failed to append CA to pool")
		default:
			return nil
		}
		if i < caCheckRetries-1 {
			log.Warn().Str("LDAP CACert", caCert).Msgf("CA cert file is not ready yet. Waiting %d seconds for it to appear.", caCheckSleep)
			time.Sleep(caCheckSleep * time.Second)
		}
	}
	return nil
}
+13
View File
@@ -0,0 +1,13 @@
package ldap_test
import (
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// TestLdap wires the ginkgo specs of this package into `go test`.
func TestLdap(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "Ldap Suite")
}
+145
View File
@@ -0,0 +1,145 @@
package log
import (
"context"
"fmt"
"os"
"runtime"
"strings"
"time"
chimiddleware "github.com/go-chi/chi/v5/middleware"
mzlog "github.com/go-micro/plugins/v4/logger/zerolog"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"go-micro.dev/v4/logger"
)
var (
	// RequestIDString is the log field key under which the request id is emitted
	RequestIDString = "request-id"
)
func init() {
	// this is ugly, but "logger.DefaultLogger" is a global variable, and we need to set it _before_ anybody uses it;
	// importing this package is enough to tame go-micro's default log level
	setMicroLogger()
}
// for logging reasons we don't want the same logging level on both oCIS and micro. As a framework builder we do not
// want to expose to the end user the internal framework logs unless explicitly specified.
func setMicroLogger() {
	// default the framework's own log level to "error" unless overridden
	if os.Getenv("MICRO_LOG_LEVEL") == "" {
		_ = os.Setenv("MICRO_LOG_LEVEL", "error")
	}
	lev, err := zerolog.ParseLevel(os.Getenv("MICRO_LOG_LEVEL"))
	if err != nil {
		// unparsable level strings fall back to error
		lev = zerolog.ErrorLevel
	}
	logger.DefaultLogger = mzlog.NewLogger(
		logger.WithLevel(logger.Level(lev)),
		logger.WithFields(map[string]interface{}{
			"system": "go-micro",
		}),
	)
}
// Logger simply wraps the zerolog logger.
type Logger struct {
	zerolog.Logger
}

// NopLogger initializes a no-operation logger that discards all events.
func NopLogger() Logger {
	return Logger{zerolog.Nop()}
}
// LineInfoHook adds the caller's file:line to every log event.
type LineInfoHook struct{}

// Run is a hook to add line info to log messages.
// I found the zerolog example for this here:
// https://github.com/rs/zerolog/issues/22#issuecomment-1127295489
func (h LineInfoHook) Run(e *zerolog.Event, _ zerolog.Level, _ string) {
	// 3 frames up is expected to be the log call site (see the linked issue)
	_, file, line, ok := runtime.Caller(3)
	if ok {
		e.Str("line", fmt.Sprintf("%s:%d", file, line))
	}
}
// NewLogger initializes a new logger instance.
// Output is chosen in order of precedence: pretty console writer, a log
// file, or plain JSON on stderr. Unknown level strings default to "error".
func NewLogger(opts ...Option) Logger {
	options := newOptions(opts...)
	// set GlobalLevel() to the minimum value -1 = TraceLevel, so that only the services' log level matter
	zerolog.SetGlobalLevel(zerolog.TraceLevel)
	var logLevel zerolog.Level
	switch strings.ToLower(options.Level) {
	case "panic":
		logLevel = zerolog.PanicLevel
	case "fatal":
		logLevel = zerolog.FatalLevel
	case "error":
		logLevel = zerolog.ErrorLevel
	case "warn":
		logLevel = zerolog.WarnLevel
	case "info":
		logLevel = zerolog.InfoLevel
	case "debug":
		logLevel = zerolog.DebugLevel
	case "trace":
		logLevel = zerolog.TraceLevel
	default:
		// unrecognized level strings are treated as "error"
		logLevel = zerolog.ErrorLevel
	}
	var l zerolog.Logger
	if options.Pretty {
		// human readable console output (optionally colorized)
		l = log.Output(
			zerolog.NewConsoleWriter(
				func(w *zerolog.ConsoleWriter) {
					w.TimeFormat = time.RFC3339
					w.Out = os.Stderr
					w.NoColor = !options.Color
				},
			),
		)
	} else if options.File != "" {
		// append to the configured log file; dying here is deliberate since
		// a requested-but-unwritable log target must not be ignored
		f, err := os.OpenFile(options.File, os.O_CREATE|os.O_APPEND|os.O_RDWR, 0644)
		if err != nil {
			print(fmt.Sprintf("file could not be opened for writing: %s. error: %v", options.File, err))
			os.Exit(1)
		}
		l = l.Output(f)
	} else {
		l = zerolog.New(os.Stderr)
	}
	l = l.With().
		Str("service", options.Name).
		Timestamp().
		Logger().Level(logLevel)
	if logLevel <= zerolog.InfoLevel {
		// only add the (comparatively expensive) caller lookup at info and below
		var lineInfoHook LineInfoHook
		l = l.Hook(lineInfoHook)
	}
	return Logger{
		l,
	}
}
// SubloggerWithRequestID returns a sub-logger with the x-request-id added to all events
func (l Logger) SubloggerWithRequestID(c context.Context) Logger {
	return Logger{
		// chi's middleware stores the request id in the request context
		l.With().Str(RequestIDString, chimiddleware.GetReqID(c)).Logger(),
	}
}
// Deprecation logs a deprecation message,
// it is used to inform the user that a certain feature is deprecated and will be removed in the future.
// Do not use a logger here because the message MUST be visible independent of the log level.
func Deprecation(a ...any) {
	// fmt.Sprint joins all arguments; the previous single %s verb only
	// consumed the first argument and rendered any extras as %!(EXTRA ...).
	fmt.Printf("\033[1;31mDEPRECATION: %s\033[0m\n", fmt.Sprint(a...))
}
+24
View File
@@ -0,0 +1,24 @@
package log_test
import (
"testing"
"github.com/onsi/gomega"
"github.com/opencloud-eu/opencloud/internal/testenv"
"github.com/opencloud-eu/opencloud/pkg/log"
)
// TestDeprecation checks that Deprecation prints the red ANSI-colored
// prefix. It appears to use testenv to re-run the test in a subprocess so
// stdout can be captured — confirm against testenv.NewCMDTest semantics.
func TestDeprecation(t *testing.T) {
	cmdTest := testenv.NewCMDTest(t.Name())
	if cmdTest.ShouldRun() {
		// subprocess branch: produce the output under test and exit
		log.Deprecation("this is a deprecation")
		return
	}
	out, err := cmdTest.Run()
	g := gomega.NewWithT(t)
	g.Expect(err).ToNot(gomega.HaveOccurred())
	g.Expect(string(out)).To(gomega.HavePrefix("\033[1;31mDEPRECATION: this is a deprecation"))
}
+74
View File
@@ -0,0 +1,74 @@
package log
import (
"io"
"github.com/rs/zerolog"
"github.com/sirupsen/logrus"
)
// levelMap converts logrus levels to their zerolog equivalents.
type levelMap map[logrus.Level]zerolog.Level

// levelMapping must cover every logrus level: a missing key maps to the zero
// value (zerolog.DebugLevel) in Fire, and makes logrusLevel panic when the
// wrapped zerolog logger runs at that level.
var levelMapping = levelMap{
	logrus.PanicLevel: zerolog.PanicLevel,
	logrus.FatalLevel: zerolog.FatalLevel, // was missing: fatal entries were forwarded at debug level
	logrus.ErrorLevel: zerolog.ErrorLevel,
	logrus.TraceLevel: zerolog.TraceLevel,
	logrus.DebugLevel: zerolog.DebugLevel,
	logrus.WarnLevel:  zerolog.WarnLevel,
	logrus.InfoLevel:  zerolog.InfoLevel,
}
// LogrusWrapper around zerolog. Required because idp uses logrus internally.
type LogrusWrapper struct {
	// zeroLog receives all messages fired by the logrus hook
	zeroLog *zerolog.Logger
	// levelMap translates logrus entry levels to zerolog levels
	levelMap levelMap
}

// LogrusWrap returns a logrus logger which internally logs to /dev/null. Messages are passed to the
// underlying zerolog via hooks.
func LogrusWrap(zr zerolog.Logger) *logrus.Logger {
	lr := logrus.New()
	// discard logrus' own output; the hook below forwards everything
	lr.SetOutput(io.Discard)
	// mirror the zerolog level so logrus drops the same messages
	lr.SetLevel(logrusLevel(zr.GetLevel()))
	lr.AddHook(&LogrusWrapper{
		zeroLog: &zr,
		levelMap: levelMapping,
	})
	return lr
}
// Levels on which logrus hooks should be triggered; every level is forwarded
// and the zerolog side decides what to drop.
func (h *LogrusWrapper) Levels() []logrus.Level {
	return logrus.AllLevels
}

// Fire called by logrus on new message; it forwards level, fields and
// message to the wrapped zerolog logger.
func (h *LogrusWrapper) Fire(entry *logrus.Entry) error {
	h.zeroLog.WithLevel(h.levelMap[entry.Level]).
		Fields(zeroLogFields(entry.Data)).
		Msg(entry.Message)
	return nil
}
// Convert logrus fields to a plain map for zerolog.
func zeroLogFields(fields logrus.Fields) map[string]interface{} {
	// pre-size: the exact number of entries is known
	fm := make(map[string]interface{}, len(fields))
	for k, v := range fields {
		fm[k] = v
	}
	return fm
}
// Convert logrus level to zerolog by reverse lookup in levelMapping.
// Panics when the given zerolog level has no logrus counterpart in the map.
func logrusLevel(level zerolog.Level) logrus.Level {
	for lrLvl, zrLvl := range levelMapping {
		if zrLvl == level {
			return lrLvl
		}
	}
	panic("Unexpected loglevel")
}
+64
View File
@@ -0,0 +1,64 @@
package log
// Option defines a single option function.
type Option func(o *Options)

// Options defines the available options for this package.
type Options struct {
	Name string   // service name attached to every log event
	Level string  // minimum level ("panic" ... "trace"); unknown values mean "error"
	Pretty bool   // human readable console output instead of JSON
	Color bool    // colorize pretty output
	File string   // if set, log to this file instead of stderr
}
// newOptions builds the Options struct, starting from sane defaults and
// applying each supplied Option in order.
func newOptions(opts ...Option) Options {
	options := Options{
		Name:   "ocis",
		Level:  "info",
		Pretty: true,
		Color:  true,
	}
	for _, apply := range opts {
		apply(&options)
	}
	return options
}
// Name provides a function to set the name option (the service name
// attached to every log event).
func Name(val string) Option {
	return func(o *Options) {
		o.Name = val
	}
}

// Level provides a function to set the level option (e.g. "debug").
func Level(val string) Option {
	return func(o *Options) {
		o.Level = val
	}
}

// Pretty provides a function to set the pretty option (console output
// instead of JSON).
func Pretty(val bool) Option {
	return func(o *Options) {
		o.Pretty = val
	}
}

// Color provides a function to set the color option for pretty output.
func Color(val bool) Option {
	return func(o *Options) {
		o.Color = val
	}
}

// File provides a function to set the file option (log target path).
func File(val string) Option {
	return func(o *Options) {
		o.File = val
	}
}
+139
View File
@@ -0,0 +1,139 @@
// Package markdown allows reading and editing Markdown files
package markdown
import (
"bytes"
"fmt"
"io"
"strings"
)
// Heading represents a markdown Heading
type Heading struct {
	Level int // the level of the heading. 1 means it's the H1
	Content string // the text of the heading
	Header string // the heading itself
}

// MD represents a markdown file as an ordered list of its headings,
// each carrying the content that follows it.
type MD struct {
	Headings []Heading
}
// Bytes returns the markdown as []bytes, ignoring errors
func (md MD) Bytes() []byte {
	buf := &bytes.Buffer{}
	_, _ = md.WriteContent(buf)
	return buf.Bytes()
}

// String returns the markdown as string, ignoring errors
func (md MD) String() string {
	return string(md.Bytes())
}

// TocBytes returns the table of contents as []byte, ignoring errors
func (md MD) TocBytes() []byte {
	buf := &bytes.Buffer{}
	_, _ = md.WriteToc(buf)
	return buf.Bytes()
}

// TocString returns the table of contents as string, ignoring errors
func (md MD) TocString() string {
	return string(md.TocBytes())
}
// WriteContent writes the MDs content to the given writer, emitting each
// heading line ("### Header\n") followed by its content. It returns the
// number of bytes written and the first write error encountered.
func (md MD) WriteContent(w io.Writer) (int64, error) {
	var total int64
	emit := func(s string) error {
		n, err := io.WriteString(w, s)
		total += int64(n)
		return err
	}
	for _, h := range md.Headings {
		if err := emit(strings.Repeat("#", h.Level) + " " + h.Header + "\n"); err != nil {
			return total, err
		}
		if h.Content == "" {
			continue
		}
		if err := emit(h.Content); err != nil {
			return total, err
		}
	}
	return total, nil
}
// WriteToc writes the table of contents to the given writer as a nested
// markdown bullet list. H1 headings are skipped; deeper levels are indented
// relative to H2. Returns bytes written and the first write error.
func (md MD) WriteToc(w io.Writer) (int64, error) {
	var written int64
	for _, h := range md.Headings {
		if h.Level == 1 {
			// main title not in toc
			continue
		}
		// anchor links are the lowercased header with spaces replaced by dashes
		link := "#" + strings.ToLower(strings.ReplaceAll(h.Header, " ", "-"))
		s := fmt.Sprintf("%s* [%s](%s)\n", strings.Repeat(" ", h.Level-2), h.Header, link)
		n, err := w.Write([]byte(s))
		if err != nil {
			return written, err
		}
		written += int64(n)
	}
	return written, nil
}
// NewMD parses raw markdown bytes into an MD structure.
// The input is split on "\n#" so every chunk starts with (the remainder of)
// a heading marker; any text before the first heading is discarded.
func NewMD(b []byte) MD {
	var (
		md MD
		heading Heading
		content strings.Builder
	)
	// sendHeading finalizes the current heading with the collected content
	// and resets the content buffer
	sendHeading := func() {
		if heading.Header != "" {
			heading.Content = content.String()
			md.Headings = append(md.Headings, heading)
			content = strings.Builder{}
		}
	}
	// the leading "\n" makes a heading on the very first line split correctly
	parts := strings.Split("\n"+string(b), "\n#")
	numparts := len(parts) - 1
	for i, p := range parts {
		if i == 0 {
			// omit part before first heading
			continue
		}
		all := strings.SplitN(p, "\n", 2)
		if len(all) != 2 {
			continue
		}
		head, con := all[0], all[1]
		// readd lost "#"
		heading = headingFromString("#" + head)
		_, _ = content.WriteString(con)
		// readd lost "\n" - omit for last part
		if i < numparts {
			_, _ = content.WriteString("\n")
		}
		// add heading
		sendHeading()
	}
	return md
}
// headingFromString parses a heading line like "### Title" into a Heading.
// The level is the number of leading '#' characters; one leading space after
// them is stripped from the header text.
func headingFromString(s string) Heading {
	last := strings.LastIndex(s, "#")
	hashes, text := s[:last+1], s[last+1:]
	return Heading{
		Level:  len(hashes),
		Header: strings.TrimPrefix(text, " "),
	}
}
+13
View File
@@ -0,0 +1,13 @@
package markdown
import (
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// TestSearch wires the ginkgo specs of this package into `go test`.
func TestSearch(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "Markdown Suite")
}
+48
View File
@@ -0,0 +1,48 @@
package markdown
import (
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var (
SmallMarkdown = `# Title
some abstract description
## SubTitle 1
subtitle one description
## SubTitle 2
subtitle two description
### Subpoint to SubTitle 2
description to subpoint
more text
`
SmallMD = MD{
Headings: []Heading{
{Level: 1, Header: "Title", Content: "\nsome abstract description\n\n"},
{Level: 2, Header: "SubTitle 1", Content: "\nsubtitle one description\n\n"},
{Level: 2, Header: "SubTitle 2", Content: "subtitle two description\n"},
{Level: 3, Header: "Subpoint to SubTitle 2", Content: "\ndescription to subpoint\n\nmore text\n"},
},
}
)
var _ = Describe("TestMarkdown", func() {
DescribeTable("Conversion works both ways",
func(mdfile string, expectedMD MD) {
md := NewMD([]byte(mdfile))
Expect(len(md.Headings)).To(Equal(len(expectedMD.Headings)))
for i, h := range md.Headings {
Expect(h).To(Equal(expectedMD.Headings[i]))
}
Expect(md.String()).To(Equal(mdfile))
},
Entry("converts a small markdown", SmallMarkdown, SmallMD),
)
})
+78
View File
@@ -0,0 +1,78 @@
package middleware
import (
"encoding/json"
"net/http"
"github.com/cs3org/reva/v2/pkg/auth/scope"
revactx "github.com/cs3org/reva/v2/pkg/ctx"
"github.com/cs3org/reva/v2/pkg/token/manager/jwt"
"github.com/opencloud-eu/opencloud/pkg/account"
"go-micro.dev/v4/metadata"
)
// newAccountOptions initializes the available default options.
func newAccountOptions(opts ...account.Option) account.Options {
	var options account.Options
	for _, apply := range opts {
		apply(&options)
	}
	return options
}
// AccountID serves as key for the account uuid in the context
const AccountID string = "Account-Id"

// RoleIDs serves as key for the roles in the context (stored as a JSON-encoded list)
const RoleIDs string = "Role-Ids"
// ExtractAccountUUID provides a middleware to extract the account uuid from the x-access-token header value
// and write it to the context. If there is no x-access-token the middleware is omitted.
//
// Fix: the token-dismantle error log previously ended at `.Err(err)` without
// `.Msg()`, so the zerolog event was never emitted and failures were silent.
func ExtractAccountUUID(opts ...account.Option) func(http.Handler) http.Handler {
	opt := newAccountOptions(opts...)
	tokenManager, err := jwt.New(map[string]interface{}{
		"secret": opt.JWTSecret,
		"expires": int64(24 * 60 * 60),
	})
	if err != nil {
		opt.Logger.Fatal().Err(err).Msgf("Could not initialize token-manager")
	}
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			token := r.Header.Get("x-access-token")
			if len(token) == 0 {
				// no token: pass the request on with an empty role list
				roleIDsJSON, _ := json.Marshal([]string{})
				ctx := metadata.Set(r.Context(), RoleIDs, string(roleIDsJSON))
				next.ServeHTTP(w, r.WithContext(ctx))
				return
			}
			u, tokenScope, err := tokenManager.DismantleToken(r.Context(), token)
			if err != nil {
				// Msg is required to actually emit the zerolog event
				opt.Logger.Error().Err(err).Msg("could not dismantle token")
				return
			}
			if ok, err := scope.VerifyScope(r.Context(), tokenScope, r); err != nil || !ok {
				opt.Logger.Error().Err(err).Msg("verifying scope failed")
				return
			}
			// store user in context for request
			ctx := revactx.ContextSetUser(r.Context(), u)
			// Important: user.Id.OpaqueId is the AccountUUID. Set this way in the account uuid middleware in ocis-proxy.
			// https://github.com/opencloud-eu/opencloud-proxy/blob/ea254d6036592cf9469d757d1295e0c4309d1e63/pkg/middleware/account_uuid.go#L109
			// TODO: implement token manager in cs3org/reva that uses generic metadata instead of access token from header.
			ctx = metadata.Set(ctx, AccountID, u.Id.OpaqueId)
			if u.Opaque != nil {
				if roles, ok := u.Opaque.Map["roles"]; ok {
					ctx = metadata.Set(ctx, RoleIDs, string(roles.Value))
				}
			}
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
+41
View File
@@ -0,0 +1,41 @@
package middleware
import (
"net/http"
"strings"
"time"
"github.com/opencloud-eu/opencloud/pkg/cors"
rscors "github.com/rs/cors"
)
// NoCache writes required cache headers to all requests so responses are
// never cached by clients or intermediaries.
func NoCache(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// NOTE(review): "value" is not a registered Cache-Control directive —
		// looks like a leftover; confirm whether it can be dropped.
		w.Header().Set("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate, value")
		// Expires in the past for HTTP/1.0 agents that ignore Cache-Control
		w.Header().Set("Expires", "Thu, 01 Jan 1970 00:00:00 GMT")
		w.Header().Set("Last-Modified", time.Now().UTC().Format(http.TimeFormat))
		next.ServeHTTP(w, r)
	})
}
// Cors writes required cors headers to all requests. The actual header
// handling is delegated to github.com/rs/cors, configured from this
// package's cors options.
func Cors(opts ...cors.Option) func(http.Handler) http.Handler {
	options := cors.NewOptions(opts...)
	logger := options.Logger
	// log the effective configuration once at setup time
	logger.Debug().
		Str("allowed_origins", strings.Join(options.AllowedOrigins, ", ")).
		Str("allowed_methods", strings.Join(options.AllowedMethods, ", ")).
		Str("allowed_headers", strings.Join(options.AllowedHeaders, ", ")).
		Bool("allow_credentials", options.AllowCredentials).
		Msg("setup cors middleware")
	c := rscors.New(rscors.Options{
		AllowedOrigins: options.AllowedOrigins,
		AllowedMethods: options.AllowedMethods,
		AllowedHeaders: options.AllowedHeaders,
		AllowCredentials: options.AllowCredentials,
	})
	return c.Handler
}
+30
View File
@@ -0,0 +1,30 @@
package middleware
import (
"net/http"
"time"
"github.com/go-chi/chi/v5/middleware"
"github.com/opencloud-eu/opencloud/pkg/log"
)
// Logger is a middleware to log http requests. It uses debug level logging and
// should be used by all services save the proxy (which uses info level logging).
func Logger(logger log.Logger) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		fn := func(w http.ResponseWriter, r *http.Request) {
			begin := time.Now()
			// Wrap the writer so status code and byte count can be observed.
			ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
			next.ServeHTTP(ww, r)
			logger.Debug().
				Str(log.RequestIDString, r.Header.Get("X-Request-ID")).
				Str("proto", r.Proto).
				Str("method", r.Method).
				Int("status", ww.Status()).
				Str("path", r.URL.Path).
				Dur("duration", time.Since(begin)).
				Int("bytes", ww.BytesWritten()).
				Msg("")
		}
		return http.HandlerFunc(fn)
	}
}
+102
View File
@@ -0,0 +1,102 @@
package middleware
import (
"context"
"net/http"
"strings"
"sync"
goidc "github.com/coreos/go-oidc/v3/oidc"
"github.com/opencloud-eu/opencloud/pkg/oidc"
"golang.org/x/oauth2"
)
// newOidcOptions applies the given option setters to a zero Options value
// and returns the result.
func newOidcOptions(opts ...Option) Options {
	var options Options
	for _, setter := range opts {
		setter(&options)
	}
	return options
}
// OIDCProvider used to mock the oidc provider during tests.
// Satisfied by *goidc.Provider (see OidcAuth's providerFunc).
type OIDCProvider interface {
	// UserInfo fetches the userinfo claims for the token supplied by ts.
	UserInfo(ctx context.Context, ts oauth2.TokenSource) (*goidc.UserInfo, error)
}
// OidcAuth provides a middleware to authenticate a bearer auth with an OpenID Connect identity provider
// It will put all claims provided by the userinfo endpoint in the context.
//
// Requests without a "Bearer " Authorization header pass through untouched.
// Bearer requests that fail userinfo validation get a 401; failure to reach
// the issuer for provider discovery yields a 500.
func OidcAuth(opts ...Option) func(http.Handler) http.Handler {
	opt := newOidcOptions(opts...)

	// TODO use a micro store cache option
	providerFunc := func() (OIDCProvider, error) {
		// Initialize a provider by specifying the issuer URL.
		// It will fetch the keys from the issuer using the .well-known endpoint.
		return goidc.NewProvider(
			context.WithValue(context.Background(), oauth2.HTTPClient, &opt.HttpClient),
			opt.OidcIssuer,
		)
	}

	// The provider is initialized lazily on the first bearer request.
	// The original code read `provider` outside the mutex (broken
	// double-checked locking), which is a data race in Go; all access now
	// happens under the lock.
	var (
		provider     OIDCProvider
		providerLock sync.Mutex
	)
	getProvider := func() (OIDCProvider, error) {
		providerLock.Lock()
		defer providerLock.Unlock()
		if provider == nil {
			p, err := providerFunc()
			if err != nil {
				return nil, err
			}
			provider = p
			opt.Logger.Debug().Msg("initialized OIDC provider")
		}
		return provider, nil
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()
			authHeader := r.Header.Get("Authorization")
			switch {
			case strings.HasPrefix(authHeader, "Bearer "):
				p, err := getProvider()
				if err != nil {
					opt.Logger.Error().Err(err).Msg("could not initialize OIDC provider")
					w.WriteHeader(http.StatusInternalServerError)
					return
				}
				oauth2Token := &oauth2.Token{
					AccessToken: strings.TrimPrefix(authHeader, "Bearer "),
				}
				userInfo, err := p.UserInfo(
					context.WithValue(ctx, oauth2.HTTPClient, &opt.HttpClient),
					oauth2.StaticTokenSource(oauth2Token),
				)
				if err != nil {
					w.Header().Add("WWW-Authenticate", `Bearer`)
					w.WriteHeader(http.StatusUnauthorized)
					return
				}
				claims := map[string]interface{}{}
				if err = userInfo.Claims(&claims); err != nil {
					// Claims could not be decoded; serve the request without
					// claims in the context (preserves original behavior).
					break
				}
				ctx = oidc.NewContext(ctx, claims)
			default:
				// Not a bearer request: pass through untouched.
				next.ServeHTTP(w, r.WithContext(ctx))
				return
			}
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
+51
View File
@@ -0,0 +1,51 @@
package middleware
import (
"net/http"
gatewayv1beta1 "github.com/cs3org/go-cs3apis/cs3/gateway/v1beta1"
"github.com/opencloud-eu/opencloud/pkg/log"
)
// Option defines a single option function.
type Option func(o *Options)

// Options defines the available options for this package.
type Options struct {
	// Logger to use for logging, must be set
	Logger log.Logger
	// OidcIssuer is the OpenID Connect issuer URL
	OidcIssuer string
	// GatewayAPIClient is a reva gateway client
	GatewayAPIClient gatewayv1beta1.GatewayAPIClient
	// HttpClient is the http client used for requests to the identity provider
	HttpClient http.Client
}
// WithOidcIssuer provides a function to set the OpenID Connect issuer option.
// (The previous comment named the wrong function.)
func WithOidcIssuer(val string) Option {
	return func(o *Options) {
		o.OidcIssuer = val
	}
}
// WithLogger provides a function to set the logger option.
func WithLogger(val log.Logger) Option {
	return func(o *Options) {
		o.Logger = val
	}
}
// WithGatewayAPIClient provides a function to set the reva gateway client option.
func WithGatewayAPIClient(val gatewayv1beta1.GatewayAPIClient) Option {
	return func(o *Options) {
		o.GatewayAPIClient = val
	}
}
// WithHttpClient provides a function to set the http client option.
// (The previous comment named the wrong function.)
func WithHttpClient(val http.Client) Option {
	return func(o *Options) {
		o.HttpClient = val
	}
}
+44
View File
@@ -0,0 +1,44 @@
package middleware
import (
"net/http"
"path"
"strings"
"time"
)
// Static is a middleware that serves static assets.
func Static(root string, fs http.FileSystem, ttl int) func(http.Handler) http.Handler {
if !strings.HasSuffix(root, "/") {
root = root + "/"
}
static := http.StripPrefix(
root,
http.FileServer(
fs,
),
)
// TODO: investigate broken caching - https://github.com/opencloud-eu/opencloud/issues/1094
// we don't have a last modification date of the static assets, so we use the service start date
//lastModified := time.Now().UTC().Format(http.TimeFormat)
//expires := time.Now().Add(time.Second * time.Duration(ttl)).UTC().Format(http.TimeFormat)
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasPrefix(r.URL.Path, path.Join(root, "api")) {
next.ServeHTTP(w, r)
} else {
// TODO: investigate broken caching - https://github.com/opencloud-eu/opencloud/issues/1094
//w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%s, must-revalidate", strconv.Itoa(ttl)))
//w.Header().Set("Expires", expires)
//w.Header().Set("Last-Modified", lastModified)
w.Header().Set("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate, value")
w.Header().Set("Expires", "Thu, 01 Jan 1970 00:00:00 GMT")
w.Header().Set("Last-Modified", time.Now().UTC().Format(http.TimeFormat))
static.ServeHTTP(w, r)
}
})
}
}
+19
View File
@@ -0,0 +1,19 @@
package middleware
import (
"net/http"
"github.com/go-chi/chi/v5/middleware"
)
// Throttle limits the number of concurrent requests.
// A limit of zero or less disables throttling entirely.
func Throttle(limit int) func(http.Handler) http.Handler {
	if limit <= 0 {
		// No-op middleware: hand every request straight to the next handler.
		return func(next http.Handler) http.Handler {
			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				next.ServeHTTP(w, r)
			})
		}
	}
	return middleware.Throttle(limit)
}
+21
View File
@@ -0,0 +1,21 @@
package middleware
import (
"net/http"
"go.opentelemetry.io/otel/propagation"
)
// propagator extracts and injects both W3C baggage and trace-context headers.
var propagator = propagation.NewCompositeTextMapPropagator(
	propagation.Baggage{},
	propagation.TraceContext{},
)

// TraceContext unpacks the request context looking for an existing trace id.
func TraceContext(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		carrier := propagation.HeaderCarrier(r.Header)
		ctx := propagator.Extract(r.Context(), carrier)
		propagator.Inject(ctx, carrier)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
+21
View File
@@ -0,0 +1,21 @@
package middleware
import (
"fmt"
"net/http"
"strings"
)
// Version writes the current version to the headers.
func Version(name, version string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(
fmt.Sprintf("X-%s-VERSION", strings.ToUpper(name)),
version,
)
next.ServeHTTP(w, r)
})
}
}
+32
View File
@@ -0,0 +1,32 @@
package natsjsregistry
import (
"context"
"time"
"go-micro.dev/v4/registry"
"go-micro.dev/v4/store"
)
// storeOptionsKey is the context key under which StoreOptions stashes its value.
type storeOptionsKey struct{}

// defaultTTLKey is the context key under which DefaultTTL stashes its value.
type defaultTTLKey struct{}
// StoreOptions sets the options for the underlying store.
// The options are carried in the registry options' context.
func StoreOptions(opts []store.Option) registry.Option {
	return func(o *registry.Options) {
		ctx := o.Context
		if ctx == nil {
			ctx = context.Background()
		}
		o.Context = context.WithValue(ctx, storeOptionsKey{}, opts)
	}
}
// DefaultTTL allows setting a default register TTL for services.
// The duration is carried in the registry options' context.
func DefaultTTL(t time.Duration) registry.Option {
	return func(o *registry.Options) {
		ctx := o.Context
		if ctx == nil {
			ctx = context.Background()
		}
		o.Context = context.WithValue(ctx, defaultTTLKey{}, t)
	}
}
+221
View File
@@ -0,0 +1,221 @@
// Package natsjsregistry implements a registry using natsjs kv store
package natsjsregistry
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"strings"
"sync"
"time"
natsjskv "github.com/go-micro/plugins/v4/store/nats-js-kv"
"github.com/nats-io/nats.go"
"go-micro.dev/v4/registry"
"go-micro.dev/v4/server"
"go-micro.dev/v4/store"
"go-micro.dev/v4/util/cmd"
)
var (
	// _registryName is the plugin name registered with go-micro.
	_registryName = "nats-js-kv"
	// _registryAddressEnv holds comma-separated nats server addresses.
	_registryAddressEnv = "MICRO_REGISTRY_ADDRESS"
	// _registryUsernameEnv / _registryPasswordEnv hold nats credentials.
	_registryUsernameEnv = "MICRO_REGISTRY_AUTH_USERNAME"
	_registryPasswordEnv = "MICRO_REGISTRY_AUTH_PASSWORD"
	// _serviceDelimiter separates name, node id and version in store keys.
	_serviceDelimiter = "@"
)
// init makes this registry available to go-micro under its plugin name.
func init() {
	cmd.DefaultRegistries[_registryName] = NewRegistry
}
// NewRegistry returns a new natsjs registry.
func NewRegistry(opts ...registry.Option) registry.Registry {
	options := registry.Options{Context: context.Background()}
	for _, apply := range opts {
		apply(&options)
	}

	// A zero duration is used when no DefaultTTL option was given.
	ttl, _ := options.Context.Value(defaultTTLKey{}).(time.Duration)

	reg := &storeregistry{
		opts:       options,
		typ:        _registryName,
		defaultTTL: ttl,
	}
	reg.store = natsjskv.NewStore(reg.storeOptions(options)...)
	return reg
}
// storeregistry implements registry.Registry on top of a natsjs kv store.
type storeregistry struct {
	// opts are the registry options as applied by NewRegistry/Init.
	opts registry.Options
	// store is the backing kv store; it is replaced on Init, guarded by lock.
	store store.Store
	// typ is the registry name reported by String().
	typ string
	// defaultTTL is applied to Register calls that set no explicit TTL.
	defaultTTL time.Duration
	// lock guards the store field (write on Init, read everywhere else).
	lock sync.RWMutex
}
// Init inits the registry. It re-applies the given options and recreates the
// underlying store from scratch before initializing it. Takes the write lock
// so concurrent readers of n.store are excluded during the swap.
func (n *storeregistry) Init(opts ...registry.Option) error {
	n.lock.Lock()
	defer n.lock.Unlock()
	for _, o := range opts {
		o(&n.opts)
	}
	// NOTE(review): the store is both recreated and Init-ed with the same
	// option set; confirm the double application is intended.
	n.store = natsjskv.NewStore(n.storeOptions(n.opts)...)
	return n.store.Init(n.storeOptions(n.opts)...)
}
// Options returns the configured options.
func (n *storeregistry) Options() registry.Options {
	return n.opts
}
// Register adds a service to the registry.
// The record key is "<name>@<node id>@<version>" and it expires after the
// TTL from the register options (falling back to the registry default).
func (n *storeregistry) Register(s *registry.Service, opts ...registry.RegisterOption) error {
	n.lock.RLock()
	defer n.lock.RUnlock()

	if s == nil {
		return errors.New("wont store nil service")
	}

	options := registry.RegisterOptions{TTL: n.defaultTTL}
	for _, apply := range opts {
		apply(&options)
	}

	payload, err := json.Marshal(s)
	if err != nil {
		return err
	}

	rec := &store.Record{
		Key:    s.Name + _serviceDelimiter + server.DefaultId + _serviceDelimiter + s.Version,
		Value:  payload,
		Expiry: options.TTL,
	}
	return n.store.Write(rec)
}
// Deregister removes a service from the registry by deleting the record
// written by Register (same "<name>@<node id>@<version>" key).
func (n *storeregistry) Deregister(s *registry.Service, _ ...registry.DeregisterOption) error {
	n.lock.RLock()
	defer n.lock.RUnlock()
	return n.store.Delete(s.Name + _serviceDelimiter + server.DefaultId + _serviceDelimiter + s.Version)
}
// GetService gets a specific service from the registry.
func (n *storeregistry) GetService(s string, _ ...registry.GetOption) ([]*registry.Service, error) {
	// avoid listing e.g. `webfinger` when requesting `web` by adding the delimiter to the service name
	return n.listServices(store.ListPrefix(s + _serviceDelimiter))
}
// ListServices lists all registered services (no prefix filter).
func (n *storeregistry) ListServices(...registry.ListOption) ([]*registry.Service, error) {
	return n.listServices()
}
// Watch allows following changes in the registry. It returns an error when
// the backing store does not support watching (see NewWatcher).
func (n *storeregistry) Watch(...registry.WatchOption) (registry.Watcher, error) {
	return NewWatcher(n)
}
// String returns the name of the registry.
func (n *storeregistry) String() string {
	return n.typ
}
// listServices reads all records matching opts and merges nodes of records
// that share a version into a single service entry.
func (n *storeregistry) listServices(opts ...store.ListOption) ([]*registry.Service, error) {
	n.lock.RLock()
	defer n.lock.RUnlock()

	keys, err := n.store.List(opts...)
	if err != nil {
		return nil, err
	}

	byVersion := map[string]*registry.Service{}
	for _, key := range keys {
		svc, err := n.getNode(key)
		if err != nil {
			// TODO: continue ?
			return nil, err
		}
		if existing, ok := byVersion[svc.Version]; ok {
			existing.Nodes = append(existing.Nodes, svc.Nodes...)
		} else {
			byVersion[svc.Version] = svc
		}
	}

	services := make([]*registry.Service, 0, len(byVersion))
	for _, svc := range byVersion {
		services = append(services, svc)
	}
	return services, nil
}
// getNode retrieves a node from the store. It returns a service to also keep
// track of the version.
func (n *storeregistry) getNode(s string) (*registry.Service, error) {
	recs, err := n.store.Read(s)
	switch {
	case err != nil:
		return nil, err
	case len(recs) == 0:
		return nil, registry.ErrNotFound
	}
	svc := &registry.Service{}
	if err := json.Unmarshal(recs[0].Value, svc); err != nil {
		return nil, err
	}
	return svc, nil
}
// storeOptions assembles the nats-js-kv store options from the registry
// options, the environment, and hard defaults. Explicit StoreOptions from the
// context are appended last so they can override everything else.
func (n *storeregistry) storeOptions(opts registry.Options) []store.Option {
	storeoptions := []store.Option{
		store.Database("service-registry"),
		store.Table("service-registry"),
		natsjskv.DefaultMemory(),
		natsjskv.EncodeKeys(),
	}

	if defaultTTL, ok := opts.Context.Value(defaultTTLKey{}).(time.Duration); ok {
		storeoptions = append(storeoptions, natsjskv.DefaultTTL(defaultTTL))
	}

	// Address resolution: explicit Addrs win, then the environment variable,
	// then the localhost default.
	addr := []string{"127.0.0.1:9233"}
	if len(opts.Addrs) > 0 {
		addr = opts.Addrs
	} else if a := strings.Split(os.Getenv(_registryAddressEnv), ","); len(a) > 0 && a[0] != "" {
		addr = a
	}
	storeoptions = append(storeoptions, store.Nodes(addr...))

	natsOptions := nats.GetDefaultOptions()
	natsOptions.Name = "nats-js-kv-registry"
	natsOptions.User, natsOptions.Password = getAuth()
	// On reconnect, reinitialize the whole registry; if that fails the
	// process exits. NOTE(review): os.Exit in a callback is drastic —
	// confirm this fail-fast behavior is intended.
	natsOptions.ReconnectedCB = func(_ *nats.Conn) {
		if err := n.Init(); err != nil {
			fmt.Println("cannot reconnect to nats")
			os.Exit(1)
		}
	}
	natsOptions.ClosedCB = func(_ *nats.Conn) {
		fmt.Println("nats connection closed")
		os.Exit(1)
	}
	storeoptions = append(storeoptions, natsjskv.NatsOptions(natsOptions))

	if so, ok := opts.Context.Value(storeOptionsKey{}).([]store.Option); ok {
		storeoptions = append(storeoptions, so...)
	}
	return storeoptions
}
// getAuth reads the nats registry credentials (username, password) from the
// environment; both are empty strings when unset.
func getAuth() (string, string) {
	return os.Getenv(_registryUsernameEnv), os.Getenv(_registryPasswordEnv)
}
+78
View File
@@ -0,0 +1,78 @@
package natsjsregistry
import (
"encoding/json"
"errors"
"strings"
natsjskv "github.com/go-micro/plugins/v4/store/nats-js-kv"
"github.com/nats-io/nats.go"
"go-micro.dev/v4/registry"
)
// NatsWatcher is the watcher interface the backing store must implement for
// registry watching to work (satisfied by the nats-js-kv store).
type NatsWatcher interface {
	// WatchAll watches all keys of the given bucket and returns an update
	// channel plus a function that stops the watch.
	WatchAll(bucket string, opts ...nats.WatchOpt) (<-chan *natsjskv.StoreUpdate, func() error, error)
}
// Watcher is used to keep track of changes in the registry.
type Watcher struct {
	// updates delivers store change events; closed when the watch stops.
	updates <-chan *natsjskv.StoreUpdate
	// stop terminates the underlying store watch.
	stop func() error
	// reg is the registry this watcher was created from.
	reg *storeregistry
}
// NewWatcher returns a new watcher for the given registry. It fails when the
// registry's store does not support watching.
func NewWatcher(s *storeregistry) (*Watcher, error) {
	watchable, ok := s.store.(NatsWatcher)
	if !ok {
		return nil, errors.New("store does not implement watcher interface")
	}

	updates, stop, err := watchable.WatchAll("service-registry")
	if err != nil {
		return nil, err
	}

	return &Watcher{
		updates: updates,
		stop:    stop,
		reg:     s,
	}, nil
}
// Next returns the next result. It is a blocking call.
// A nil event means the update channel was closed, i.e. the watcher stopped.
func (w *Watcher) Next() (*registry.Result, error) {
	kve := <-w.updates
	if kve == nil {
		return nil, errors.New("watcher stopped")
	}

	var svc registry.Service
	if kve.Value.Data == nil {
		// Deletions carry no payload, so reconstruct a minimal service from
		// the "<name>@<node id>@<version>" key written by Register.
		parts := strings.SplitN(kve.Value.Key, _serviceDelimiter, 3)
		if len(parts) != 3 {
			return nil, errors.New("invalid service key")
		}

		svc.Name = parts[0]
		// ocis registers nodes with a - separator
		svc.Nodes = []*registry.Node{{Id: parts[0] + "-" + parts[1]}}
		svc.Version = parts[2]
	} else {
		// Updates carry the full JSON-encoded service; stop the watch if it
		// cannot be decoded.
		if err := json.Unmarshal(kve.Value.Data, &svc); err != nil {
			_ = w.stop()
			return nil, err
		}
	}

	return &registry.Result{
		Service: &svc,
		Action:  kve.Action,
	}, nil
}
// Stop stops the watcher. Any error from the underlying stop function is
// deliberately ignored.
func (w *Watcher) Stop() {
	_ = w.stop()
}
+76
View File
@@ -0,0 +1,76 @@
package oidc
import (
"fmt"
"strings"
)
// Well-known OIDC claim names used when reading values from token/userinfo
// claims.
const (
	Iss               = "iss"
	Sub               = "sub"
	Email             = "email"
	Name              = "name"
	PreferredUsername = "preferred_username"
	UIDNumber         = "uidnumber"
	GIDNumber         = "gidnumber"
	Groups            = "groups"
	OwncloudUUID      = "ownclouduuid"
	// OcisRoutingPolicy is a custom claim used for proxy routing decisions.
	OcisRoutingPolicy = "ocis.routing.policy"
)
// SplitWithEscaping splits s into segments using separator which can be escaped using the escape string
// See https://codereview.stackexchange.com/a/280193
func SplitWithEscaping(s string, separator string, escapeString string) []string {
	segments := strings.Split(s, separator)
	// Walk right to left, re-joining any segment whose predecessor ended
	// with the escape string (the escape itself is dropped).
	for i := len(segments) - 2; i >= 0; i-- {
		if !strings.HasSuffix(segments[i], escapeString) {
			continue
		}
		segments[i] = strings.TrimSuffix(segments[i], escapeString) + separator + segments[i+1]
		segments = append(segments[:i+1], segments[i+2:]...)
	}
	return segments
}
// WalkSegments uses the given array of segments to walk the claims and return whatever interface was found.
// Intermediate segments must resolve to maps; maps with interface{} keys are
// converted to string-keyed maps on the way down.
func WalkSegments(segments []string, claims map[string]interface{}) (interface{}, error) {
	last := len(segments) - 1
	for i := 0; i < last; i++ {
		switch next := claims[segments[i]].(type) {
		case map[string]interface{}:
			claims = next
		case map[interface{}]interface{}:
			converted := make(map[string]interface{}, len(next))
			for k, v := range next {
				key, ok := k.(string)
				if !ok {
					return nil, fmt.Errorf("could not walk claims path, key '%v' is not a string", k)
				}
				converted[key] = v
			}
			claims = converted
		default:
			return nil, fmt.Errorf("unsupported type '%v'", next)
		}
	}
	return claims[segments[last]], nil
}
// ReadStringClaim returns the string obtained by following the "." separated
// path in the claims. A "." inside a claim name can be escaped with "\".
func ReadStringClaim(path string, claims map[string]interface{}) (string, error) {
	// Fast path: the whole path is a single top-level claim.
	if value, _ := claims[path].(string); value != "" {
		return value, nil
	}
	claim, err := WalkSegments(SplitWithEscaping(path, ".", "\\"), claims)
	if err != nil {
		return "", err
	}
	if value, _ := claim.(string); value != "" {
		return value, nil
	}
	return "", fmt.Errorf("claim path '%s' not set or empty", path)
}
+182
View File
@@ -0,0 +1,182 @@
package oidc_test
import (
"encoding/json"
"reflect"
"testing"
"github.com/opencloud-eu/opencloud/pkg/oidc"
)
// splitWithEscapingTest describes one table-driven case for
// oidc.SplitWithEscaping.
type splitWithEscapingTest struct {
	// Name of the subtest.
	name string
	// string to split
	s string
	// separator to use
	seperator string
	// escape character to use for escaping
	escape string
	// expectedParts is the expected split result, in order
	expectedParts []string
}
// run executes one splitWithEscapingTest case and reports mismatches.
// NOTE(review): if parts comes back shorter than expectedParts, the index
// below panics — consider t.Fatalf on the length mismatch instead.
func (swet splitWithEscapingTest) run(t *testing.T) {
	parts := oidc.SplitWithEscaping(swet.s, swet.seperator, swet.escape)
	if len(swet.expectedParts) != len(parts) {
		t.Errorf("mismatching length")
	}
	for i, v := range swet.expectedParts {
		if parts[i] != v {
			t.Errorf("expected part %d to be '%s', got '%s'", i, v, parts[i])
		}
	}
}
// TestSplitWithEscaping covers plain claim names plus escaped and unescaped
// separators on the left, right and middle of the string.
func TestSplitWithEscaping(t *testing.T) {
	tests := []splitWithEscapingTest{
		{
			name:          "plain claim name",
			s:             "roles",
			seperator:     ".",
			escape:        "\\",
			expectedParts: []string{"roles"},
		},
		{
			name:          "claim with .",
			s:             "my.roles",
			seperator:     ".",
			escape:        "\\",
			expectedParts: []string{"my", "roles"},
		},
		{
			name:          "claim with escaped .",
			s:             "my\\.roles",
			seperator:     ".",
			escape:        "\\",
			expectedParts: []string{"my.roles"},
		},
		{
			name:          "claim with escaped . left",
			s:             "my\\.other.roles",
			seperator:     ".",
			escape:        "\\",
			expectedParts: []string{"my.other", "roles"},
		},
		{
			name:          "claim with escaped . right",
			s:             "my.other\\.roles",
			seperator:     ".",
			escape:        "\\",
			expectedParts: []string{"my", "other.roles"},
		},
	}
	for _, test := range tests {
		t.Run(test.name, test.run)
	}
}
// walkSegmentsTest describes one table-driven case for oidc.WalkSegments.
type walkSegmentsTest struct {
	// Name of the subtest.
	name string
	// path segments to walk
	segments []string
	// claims map to walk (previous comment was a copy-paste error)
	claims map[string]interface{}
	// expected is the value WalkSegments should return
	expected interface{}
	// wantErr indicates whether an error is expected
	wantErr bool
}
// run executes one walkSegmentsTest case, checking both the error expectation
// and deep equality of the returned value.
func (wst walkSegmentsTest) run(t *testing.T) {
	v, err := oidc.WalkSegments(wst.segments, wst.claims)
	if err != nil && !wst.wantErr {
		t.Errorf("%v", err)
	}
	if err == nil && wst.wantErr {
		t.Errorf("expected error")
	}
	if !reflect.DeepEqual(v, wst.expected) {
		t.Errorf("expected %v got %v", wst.expected, v)
	}
}
// TestWalkSegments covers single and multi-segment paths, string-keyed and
// interface-keyed maps, and values decoded from JSON.
func TestWalkSegments(t *testing.T) {
	byt := []byte(`{"first":{"second":{"third":["value1","value2"]},"foo":"bar"},"fizz":"buzz"}`)
	var dat map[string]interface{}
	if err := json.Unmarshal(byt, &dat); err != nil {
		t.Errorf("%v", err)
	}
	tests := []walkSegmentsTest{
		{
			name:     "one segment, single value",
			segments: []string{"first"},
			claims: map[string]interface{}{
				"first": "value",
			},
			expected: "value",
			wantErr:  false,
		},
		{
			name:     "one segment, array value",
			segments: []string{"first"},
			claims: map[string]interface{}{
				"first": []string{"value1", "value2"},
			},
			expected: []string{"value1", "value2"},
			wantErr:  false,
		},
		{
			name:     "two segments, single value",
			segments: []string{"first", "second"},
			claims: map[string]interface{}{
				"first": map[string]interface{}{
					"second": "value",
				},
			},
			expected: "value",
			wantErr:  false,
		},
		{
			name:     "two segments, array value",
			segments: []string{"first", "second"},
			claims: map[string]interface{}{
				"first": map[string]interface{}{
					"second": []string{"value1", "value2"},
				},
			},
			expected: []string{"value1", "value2"},
			wantErr:  false,
		},
		{
			name:     "three segments, array value from json",
			segments: []string{"first", "second", "third"},
			claims:   dat,
			expected: []interface{}{"value1", "value2"},
			wantErr:  false,
		},
		{
			name:     "three segments, array value with interface key",
			segments: []string{"first", "second", "third"},
			claims: map[string]interface{}{
				"first": map[interface{}]interface{}{
					"second": map[interface{}]interface{}{
						"third": []string{"value1", "value2"},
					},
				},
			},
			expected: []string{"value1", "value2"},
			wantErr:  false,
		},
	}
	for _, test := range tests {
		t.Run(test.name, test.run)
	}
}
+373
View File
@@ -0,0 +1,373 @@
package oidc
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"mime"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/MicahParks/keyfunc/v2"
goidc "github.com/coreos/go-oidc/v3/oidc"
"github.com/golang-jwt/jwt/v5"
"github.com/opencloud-eu/opencloud/pkg/log"
"github.com/opencloud-eu/opencloud/services/proxy/pkg/config"
"golang.org/x/oauth2"
)
// OIDCClient used to mock the oidc client during tests.
// The concrete implementation is oidcClient, created via NewOIDCClient.
type OIDCClient interface {
	// UserInfo retrieves the userinfo claims for the token supplied by ts.
	UserInfo(ctx context.Context, ts oauth2.TokenSource) (*UserInfo, error)
	// VerifyAccessToken verifies the given access token according to the
	// configured verification method.
	VerifyAccessToken(ctx context.Context, token string) (RegClaimsWithSID, jwt.MapClaims, error)
	// VerifyLogoutToken parses and validates a back-channel logout token.
	VerifyLogoutToken(ctx context.Context, token string) (*LogoutToken, error)
}
// KeySet is a set of public JSON Web Keys that can be used to validate the signature
// of JSON web tokens. This is expected to be backed by a remote key set through
// provider metadata discovery or an in-memory set of keys delivered out-of-band.
// In this package it is typically a goidc remote key set created during
// provider discovery (see lookupWellKnownOpenidConfiguration).
type KeySet interface {
	// VerifySignature parses the JSON web token, verifies the signature, and returns
	// the raw payload. Header and claim fields are validated by other parts of the
	// package. For example, the KeySet does not need to check values such as signature
	// algorithm, issuer, and audience since the IDTokenVerifier validates these values
	// independently.
	//
	// If VerifySignature makes HTTP requests to verify the token, it's expected to
	// use any HTTP client associated with the context through ClientContext.
	VerifySignature(ctx context.Context, jwt string) (payload []byte, err error)
}
// RegClaimsWithSID extends the standard JWT registered claims with the OIDC
// session id ("sid") claim.
type RegClaimsWithSID struct {
	SessionID string `json:"sid"`
	jwt.RegisteredClaims
}
// oidcClient is the default OIDCClient implementation. Provider metadata and
// the JWKS keyfunc are fetched lazily and cached; the two mutexes guard that
// lazy initialization.
type oidcClient struct {
	// Logger to use for logging, must be set
	Logger log.Logger
	// issuer is the OpenID Connect issuer URL
	issuer string
	// provider caches the discovered provider metadata (nil until discovery)
	provider     *ProviderMetadata
	providerLock *sync.Mutex
	// skipIssuerValidation disables the issuer match check during discovery
	skipIssuerValidation bool
	// accessTokenVerifyMethod selects how VerifyAccessToken validates tokens
	accessTokenVerifyMethod string
	// remoteKeySet verifies JWT-encoded userinfo responses
	remoteKeySet KeySet
	// algorithms are the supported id_token signing algorithms from discovery
	algorithms []string
	JWKSOptions config.JWKS
	// JWKS caches the keyfunc built from the provider's jwks_uri
	JWKS       *keyfunc.JWKS
	jwksLock   *sync.Mutex
	httpClient *http.Client
}
// _supportedAlgorithms is a list of algorithms explicitly supported by this
// package. If a provider supports other algorithms, such as HS256 or none,
// those values won't be passed to the IDTokenVerifier.
// The keys are the RS*/ES*/PS* algorithm name constants declared in this package.
var _supportedAlgorithms = map[string]bool{
	RS256: true,
	RS384: true,
	RS512: true,
	ES256: true,
	ES384: true,
	ES512: true,
	PS256: true,
	PS384: true,
	PS512: true,
}
// NewOIDCClient returns an OIDClient instance for the given issuer.
// Provider metadata and key set may be pre-seeded via options; otherwise they
// are discovered lazily on first use.
func NewOIDCClient(opts ...Option) OIDCClient {
	options := newOptions(opts...)

	client := &oidcClient{
		Logger:                  options.Logger,
		issuer:                  options.OIDCIssuer,
		httpClient:              options.HTTPClient,
		accessTokenVerifyMethod: options.AccessTokenVerifyMethod,
		JWKSOptions:             options.JWKSOptions, // TODO I don't like that we pass down config options ...
		JWKS:                    options.JWKS,
		providerLock:            &sync.Mutex{},
		jwksLock:                &sync.Mutex{},
		remoteKeySet:            options.KeySet,
		provider:                options.ProviderMetadata,
	}
	return client
}
// lookupWellKnownOpenidConfiguration lazily fetches the provider metadata
// from the issuer's .well-known endpoint and caches it on c.provider, along
// with the supported signing algorithms and a remote key set. All access is
// serialized through providerLock; subsequent calls are no-ops.
func (c *oidcClient) lookupWellKnownOpenidConfiguration(ctx context.Context) error {
	c.providerLock.Lock()
	defer c.providerLock.Unlock()
	if c.provider == nil {
		wellKnown := strings.TrimSuffix(c.issuer, "/") + wellknownPath
		req, err := http.NewRequest("GET", wellKnown, nil)
		if err != nil {
			return err
		}
		resp, err := c.httpClient.Do(req.WithContext(ctx))
		if err != nil {
			return err
		}
		defer resp.Body.Close()
		body, err := io.ReadAll(resp.Body)
		if err != nil {
			return fmt.Errorf("unable to read response body: %v", err)
		}
		if resp.StatusCode != http.StatusOK {
			return fmt.Errorf("%s: %s", resp.Status, body)
		}
		var p ProviderMetadata
		err = unmarshalResp(resp, body, &p)
		if err != nil {
			return fmt.Errorf("oidc: failed to decode provider discovery object: %v", err)
		}
		// The discovered issuer must match the configured one unless
		// validation was explicitly skipped.
		if !c.skipIssuerValidation && p.Issuer != c.issuer {
			return fmt.Errorf("oidc: issuer did not match the issuer returned by provider, expected %q got %q", c.issuer, p.Issuer)
		}
		// Keep only signing algorithms this package supports.
		var algs []string
		for _, a := range p.IDTokenSigningAlgValuesSupported {
			if _supportedAlgorithms[a] {
				algs = append(algs, a)
			}
		}
		c.provider = &p
		c.algorithms = algs
		c.remoteKeySet = goidc.NewRemoteKeySet(goidc.ClientContext(ctx, c.httpClient), p.JwksURI)
	}
	return nil
}
// getKeyfunc lazily builds (and caches) the JWKS keyfunc from the provider's
// jwks_uri, guarded by jwksLock. It returns nil when the JWKS could not be
// fetched; callers must handle a nil result. Requires c.provider to be set
// (callers run lookupWellKnownOpenidConfiguration first).
func (c *oidcClient) getKeyfunc() *keyfunc.JWKS {
	c.jwksLock.Lock()
	defer c.jwksLock.Unlock()
	if c.JWKS == nil {
		var err error
		c.Logger.Debug().Str("jwks", c.provider.JwksURI).Msg("discovered jwks endpoint")
		options := keyfunc.Options{
			Client: c.httpClient,
			RefreshErrorHandler: func(err error) {
				c.Logger.Error().Err(err).Msg("There was an error with the jwt.Keyfunc")
			},
			// Refresh knobs come from the service configuration (values are
			// interpreted as minutes / seconds respectively).
			RefreshInterval:   time.Minute * time.Duration(c.JWKSOptions.RefreshInterval),
			RefreshRateLimit:  time.Second * time.Duration(c.JWKSOptions.RefreshRateLimit),
			RefreshTimeout:    time.Second * time.Duration(c.JWKSOptions.RefreshTimeout),
			RefreshUnknownKID: c.JWKSOptions.RefreshUnknownKID,
		}
		c.JWKS, err = keyfunc.Get(c.provider.JwksURI, options)
		if err != nil {
			c.JWKS = nil
			c.Logger.Error().Err(err).Msg("Failed to create JWKS from resource at the given URL.")
			return nil
		}
	}
	return c.JWKS
}
// stringAsBool unmarshals a JSON value that may be either a bool (true) or a
// string ("true") into a bool. Some identity providers return email_verified
// as a string; see the references on userInfoRaw.
type stringAsBool bool

// UnmarshalJSON parses the raw JSON bytes as a bool, accepting both the bare
// literals true/false and the quoted strings "true"/"false".
func (sb *stringAsBool) UnmarshalJSON(b []byte) error {
	// Strip surrounding quotes before parsing: the raw bytes of a JSON
	// string include the quote characters, which strconv.ParseBool rejects.
	// (Previously the raw bytes were parsed directly, so string-typed
	// email_verified claims failed to decode — defeating the purpose.)
	v, err := strconv.ParseBool(strings.Trim(string(b), `"`))
	if err != nil {
		return err
	}
	*sb = stringAsBool(v)
	return nil
}
// UserInfo represents the OpenID Connect userinfo claims.
type UserInfo struct {
	Subject       string `json:"sub"`
	Profile       string `json:"profile"`
	Email         string `json:"email"`
	EmailVerified bool   `json:"email_verified"`

	// claims holds the raw userinfo response body for decoding additional
	// claims via the Claims method.
	claims []byte
}
// userInfoRaw is the wire format used to decode the userinfo response before
// it is converted into a UserInfo.
type userInfoRaw struct {
	Subject string `json:"sub"`
	Profile string `json:"profile"`
	Email   string `json:"email"`
	// Handle providers that return email_verified as a string
	// https://forums.aws.amazon.com/thread.jspa?messageID=949441&#949441 and
	// https://discuss.elastic.co/t/openid-error-after-authenticating-against-aws-cognito/206018/11
	EmailVerified stringAsBool `json:"email_verified"`
}
// Claims unmarshals the raw JSON object claims into the provided object.
// It fails when the UserInfo was not populated from a userinfo response.
func (u *UserInfo) Claims(v interface{}) error {
	if u.claims != nil {
		return json.Unmarshal(u.claims, v)
	}
	return errors.New("oidc: claims not set")
}
// UserInfo retrieves the userinfo from a Token. It triggers provider
// discovery if needed, calls the provider's userinfo endpoint with the
// token's access token, and supports both plain JSON and signed JWT
// (application/jwt) userinfo responses.
func (c *oidcClient) UserInfo(ctx context.Context, tokenSource oauth2.TokenSource) (*UserInfo, error) {
	if err := c.lookupWellKnownOpenidConfiguration(ctx); err != nil {
		return nil, err
	}
	if c.provider.UserinfoEndpoint == "" {
		return nil, errors.New("oidc: user info endpoint is not supported by this provider")
	}

	req, err := http.NewRequest("GET", c.provider.UserinfoEndpoint, nil)
	if err != nil {
		return nil, fmt.Errorf("oidc: create GET request: %v", err)
	}

	token, err := tokenSource.Token()
	if err != nil {
		return nil, fmt.Errorf("oidc: get access token: %v", err)
	}
	token.SetAuthHeader(req)

	resp, err := c.httpClient.Do(req.WithContext(ctx))
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("%s: %s", resp.Status, body)
	}

	// A JWT-encoded userinfo response must have its signature verified
	// against the provider's key set; the verified payload replaces body.
	ct := resp.Header.Get("Content-Type")
	mediaType, _, err := mime.ParseMediaType(ct)
	if err == nil && mediaType == "application/jwt" {
		payload, err := c.remoteKeySet.VerifySignature(goidc.ClientContext(ctx, c.httpClient), string(body))
		if err != nil {
			return nil, fmt.Errorf("oidc: invalid userinfo jwt signature %v", err)
		}
		body = payload
	}

	var userInfo userInfoRaw
	if err := json.Unmarshal(body, &userInfo); err != nil {
		return nil, fmt.Errorf("oidc: failed to decode userinfo: %v", err)
	}
	// Keep the raw body so callers can decode additional claims later.
	return &UserInfo{
		Subject:       userInfo.Subject,
		Profile:       userInfo.Profile,
		Email:         userInfo.Email,
		EmailVerified: bool(userInfo.EmailVerified),
		claims:        body,
	}, nil
}
// VerifyAccessToken validates the access token according to the configured
// verification method: full JWT verification, none (verification disabled),
// or an error for unknown settings.
func (c *oidcClient) VerifyAccessToken(ctx context.Context, token string) (RegClaimsWithSID, jwt.MapClaims, error) {
	if err := c.lookupWellKnownOpenidConfiguration(ctx); err != nil {
		return RegClaimsWithSID{}, jwt.MapClaims{}, err
	}
	switch c.accessTokenVerifyMethod {
	case config.AccessTokenVerificationJWT:
		return c.verifyAccessTokenJWT(token)
	case config.AccessTokenVerificationNone:
		c.Logger.Debug().Msg("Access Token verification disabled")
		return RegClaimsWithSID{}, jwt.MapClaims{}, nil
	default:
		c.Logger.Error().Str("access_token_verify_method", c.accessTokenVerifyMethod).Msg("Unknown Access Token verification setting")
		return RegClaimsWithSID{}, jwt.MapClaims{}, errors.New("unknown Access Token Verification method")
	}
}
// verifyAccessTokenJWT tries to parse and verify the access token as a JWT.
// The signed claims are verified against the issuer (or the AD FS
// access_token_issuer when present); the unverified map claims are returned
// alongside for callers that need the full claim set.
func (c *oidcClient) verifyAccessTokenJWT(token string) (RegClaimsWithSID, jwt.MapClaims, error) {
	var claims RegClaimsWithSID
	mapClaims := jwt.MapClaims{}
	jwks := c.getKeyfunc()
	if jwks == nil {
		return claims, mapClaims, errors.New("error initializing jwks keyfunc")
	}

	issuer := c.issuer
	if c.provider.AccessTokenIssuer != "" {
		// AD FS .well-known/openid-configuration has an optional `access_token_issuer` which takes precedence over `issuer`
		// See https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-oidce/586de7dd-3385-47c7-93a2-935d9e90441c
		issuer = c.provider.AccessTokenIssuer
	}
	_, err := jwt.ParseWithClaims(token, &claims, jwks.Keyfunc, jwt.WithIssuer(issuer))
	if err != nil {
		return claims, mapClaims, err
	}

	_, _, err = new(jwt.Parser).ParseUnverified(token, mapClaims)
	// TODO: decode mapClaims to sth readable
	// NOTE(review): this debug log runs before the ParseUnverified error is
	// checked — consider moving it after the check.
	c.Logger.Debug().Interface("access token", &claims).Msg("parsed access token")
	if err != nil {
		c.Logger.Info().Err(err).Msg("Failed to parse/verify the access token.")
		return claims, mapClaims, err
	}

	return claims, mapClaims, nil
}
// VerifyLogoutToken parses and validates an OIDC back-channel logout token:
// signature, issuer and standard claims via ParseWithClaims, followed by the
// logout-token-specific checks from the back-channel logout specification.
func (c *oidcClient) VerifyLogoutToken(ctx context.Context, rawToken string) (*LogoutToken, error) {
	var claims LogoutToken
	if err := c.lookupWellKnownOpenidConfiguration(ctx); err != nil {
		return nil, err
	}
	jwks := c.getKeyfunc()
	if jwks == nil {
		return nil, errors.New("error initializing jwks keyfunc")
	}
	// From the backchannel-logout spec: Like ID Tokens, selection of the
	// algorithm used is governed by the id_token_signing_alg_values_supported
	// Discovery parameter and the id_token_signed_response_alg Registration
	// parameter when they are used; otherwise, the value SHOULD be the default
	// of RS256
	supportedSigAlgs := c.algorithms
	if len(supportedSigAlgs) == 0 {
		supportedSigAlgs = []string{RS256}
	}
	_, err := jwt.ParseWithClaims(rawToken, &claims, jwks.Keyfunc, jwt.WithValidMethods(supportedSigAlgs), jwt.WithIssuer(c.issuer))
	if err != nil {
		c.Logger.Debug().Err(err).Msg("Failed to parse logout token")
		return nil, err
	}
	// Basic token validation has happened in ParseWithClaims (signature,
	// issuer, audience, ...). Now for some logout token specific checks.

	// 1. Verify that the Logout Token contains a sub Claim, a sid Claim, or both.
	if claims.Subject == "" && claims.SessionId == "" {
		return nil, fmt.Errorf("oidc: logout token must contain either sub or sid and MAY contain both")
	}

	// 2. Verify that the Logout Token contains an events Claim whose value is JSON object containing the member name http://schemas.openid.net/event/backchannel-logout.
	if claims.Events.Event == nil {
		return nil, fmt.Errorf("oidc: logout token must contain logout event")
	}

	// 3. Verify that the Logout Token does not contain a nonce Claim.
	if claims.Nonce != nil {
		return nil, fmt.Errorf("oidc: nonce on logout token MUST NOT be present")
	}

	return &claims, nil
}
func unmarshalResp(r *http.Response, body []byte, v interface{}) error {
err := json.Unmarshal(body, &v)
if err == nil {
return nil
}
ct := r.Header.Get("Content-Type")
mediaType, _, err := mime.ParseMediaType(ct)
if err == nil && mediaType == "application/json" {
return fmt.Errorf("got Content-Type = application/json, but could not unmarshal as JSON: %v", err)
}
return fmt.Errorf("expected Content-Type = application/json, got %q: %v", ct, err)
}
+205
View File
@@ -0,0 +1,205 @@
package oidc_test
import (
"context"
"crypto/rand"
"crypto/rsa"
"testing"
"github.com/MicahParks/keyfunc/v2"
"github.com/golang-jwt/jwt/v5"
"github.com/opencloud-eu/opencloud/pkg/oidc"
)
// signingKey bundles a private key with a JWKS that exposes the matching
// public key, so tests can both sign tokens and verify them.
type signingKey struct {
	priv interface{}   // private key used for signing (e.g. *rsa.PrivateKey)
	jwks *keyfunc.JWKS // JWKS holding the corresponding public key under kid "1"
}
// TestLogoutVerify runs table-driven tests for VerifyLogoutToken, covering
// the validation rules of the OIDC back-channel logout spec: signature,
// issuer, the required sub/sid claims, the required logout event claim and
// the prohibited nonce claim.
func TestLogoutVerify(t *testing.T) {
	tests := []logoutVerificationTest{
		{
			name: "good token",
			logoutToken: jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
				"iss": "https://foo",
				"sub": "248289761001",
				"aud": "s6BhdRkqt3",
				"iat": 1471566154,
				"jti": "bWJq",
				"sid": "08a5019c-17e1-4977-8f42-65a12843ea02",
				"events": map[string]interface{}{
					"http://schemas.openid.net/event/backchannel-logout": struct{}{},
				},
			}),
			signKey: newRSAKey(t),
		},
		{
			name:   "invalid issuer",
			issuer: "https://bar",
			logoutToken: jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
				"iss": "https://foo1",
				"sub": "248289761001",
				"events": map[string]interface{}{
					"http://schemas.openid.net/event/backchannel-logout": struct{}{},
				},
			}),
			signKey: newRSAKey(t),
			wantErr: true,
		},
		{
			// token signed with one key but verified against a different one
			name: "invalid sig",
			logoutToken: jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
				"iss": "https://foo",
				"sub": "248289761001",
				"aud": "s6BhdRkqt3",
				"iat": 1471566154,
				"jti": "bWJq",
				"sid": "08a5019c-17e1-4977-8f42-65a12843ea02",
				"events": map[string]interface{}{
					"http://schemas.openid.net/event/backchannel-logout": struct{}{},
				},
			}),
			signKey:         newRSAKey(t),
			verificationKey: newRSAKey(t),
			wantErr:         true,
		},
		{
			name: "no sid and no sub",
			logoutToken: jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
				"iss": "https://foo",
				"aud": "s6BhdRkqt3",
				"iat": 1471566154,
				"jti": "bWJq",
				"events": map[string]interface{}{
					"http://schemas.openid.net/event/backchannel-logout": struct{}{},
				},
			}),
			signKey: newRSAKey(t),
			wantErr: true,
		},
		{
			name: "Prohibited nonce present",
			logoutToken: jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
				"iss": "https://foo",
				"sub": "248289761001",
				"aud": "s6BhdRkqt3",
				"iat": 1471566154,
				"jti": "bWJq",
				"sid": "08a5019c-17e1-4977-8f42-65a12843ea02",
				"nonce": "123",
				"events": map[string]interface{}{
					"http://schemas.openid.net/event/backchannel-logout": struct{}{},
				},
			}),
			signKey: newRSAKey(t),
			wantErr: true,
		},
		{
			name: "Wrong Event string",
			logoutToken: jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
				"iss": "https://foo",
				"sub": "248289761001",
				"aud": "s6BhdRkqt3",
				"iat": 1471566154,
				"jti": "bWJq",
				"sid": "08a5019c-17e1-4977-8f42-65a12843ea02",
				"events": map[string]interface{}{
					"http://blah.blah.blash/event/backchannel-logout": struct{}{},
				},
			}),
			signKey: newRSAKey(t),
			wantErr: true,
		},
		{
			name: "No Event string",
			logoutToken: jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
				"iss": "https://foo",
				"sub": "248289761001",
				"aud": "s6BhdRkqt3",
				"iat": 1471566154,
				"jti": "bWJq",
				"sid": "08a5019c-17e1-4977-8f42-65a12843ea02",
			}),
			signKey: newRSAKey(t),
			wantErr: true,
		},
	}
	for _, test := range tests {
		t.Run(test.name, test.run)
	}
}
// logoutVerificationTest describes a single logout-token verification test
// case: the token to sign, the key material to sign and verify it with, and
// the expected outcome.
type logoutVerificationTest struct {
	// Name of the subtest.
	name string
	// Issuer the verifier expects. If not provided defaults to "https://foo".
	issuer string
	// JWT payload (just the claims).
	logoutToken *jwt.Token
	// Key to sign the ID Token with.
	signKey *signingKey
	// If not provided defaults to signKey. Only useful when
	// testing invalid signatures.
	verificationKey *signingKey
	// wantErr is true when verification is expected to fail.
	wantErr bool
}
// runGetToken signs the test case's logout token with the test key, builds an
// OIDC client whose JWKS is a static key set, and runs VerifyLogoutToken,
// returning its result.
func (v logoutVerificationTest) runGetToken(t *testing.T) (*oidc.LogoutToken, error) {
	// kid "1" matches the single key published in the test JWKS
	v.logoutToken.Header["kid"] = "1"
	token, err := v.logoutToken.SignedString(v.signKey.priv)
	if err != nil {
		t.Fatal(err)
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// honor the test case's issuer override; previously the field was
	// silently ignored and the verifier always expected "https://foo",
	// contradicting the documented default on logoutVerificationTest
	issuer := v.issuer
	if issuer == "" {
		issuer = "https://foo"
	}

	// verify with the signing key unless the test provides a separate
	// (mismatching) verification key
	jwks := v.signKey.jwks
	if v.verificationKey != nil {
		jwks = v.verificationKey.jwks
	}

	pm := oidc.ProviderMetadata{}
	verifier := oidc.NewOIDCClient(
		oidc.WithOidcIssuer(issuer),
		oidc.WithJWKS(jwks),
		oidc.WithProviderMetadata(&pm),
	)
	return verifier.VerifyLogoutToken(ctx, token)
}
// run executes a single verification test case and checks the outcome
// against wantErr.
// Receiver renamed from l to v for consistency with runGetToken.
func (v logoutVerificationTest) run(t *testing.T) {
	_, err := v.runGetToken(t)
	if err != nil && !v.wantErr {
		t.Errorf("%v", err)
	}
	if err == nil && v.wantErr {
		t.Errorf("expected error")
	}
}
// newRSAKey generates a fresh RSA signing key together with a single-entry
// JWKS (kid "1") exposing its public half, for use in verification tests.
func newRSAKey(t testing.TB) *signingKey {
	// 2048 bit is the commonly accepted minimum RSA key size; the previous
	// value of 1028 bit was non-standard and cryptographically weak.
	priv, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatal(err)
	}
	givenKey := keyfunc.NewGivenRSA(
		&priv.PublicKey,
		keyfunc.GivenKeyOptions{Algorithm: jwt.SigningMethodRS256.Alg()},
	)
	jwks := keyfunc.NewGiven(
		map[string]keyfunc.GivenKey{
			"1": givenKey,
		},
	)
	return &signingKey{priv, jwks}
}
+31
View File
@@ -0,0 +1,31 @@
package oidc
import "context"
// contextKey is the key for oidc claims in a context
type contextKey struct{}
// newSessionFlagKey is the key for the new session flag in a context
type newSessionFlagKey struct{}
// NewContext makes a new context that contains the OpenID connect claims in a map.
func NewContext(parent context.Context, c map[string]interface{}) context.Context {
return context.WithValue(parent, contextKey{}, c)
}
// FromContext returns the claims map stored in a context, or nil if there isn't one.
func FromContext(ctx context.Context) map[string]interface{} {
s, _ := ctx.Value(contextKey{}).(map[string]interface{})
return s
}
// NewContextSessionFlag makes a new context that contains the new session flag.
func NewContextSessionFlag(ctx context.Context, flag bool) context.Context {
return context.WithValue(ctx, newSessionFlagKey{}, flag)
}
// NewSessionFlagFromContext returns the new session flag stored in a context.
func NewSessionFlagFromContext(ctx context.Context) bool {
s, _ := ctx.Value(newSessionFlagKey{}).(bool)
return s
}
+16
View File
@@ -0,0 +1,16 @@
package oidc
// JOSE asymmetric signing algorithm values as defined by RFC 7518.
// These match the "alg" header values of JWTs and are used when restricting
// the accepted signing algorithms during token verification.
//
// see: https://tools.ietf.org/html/rfc7518#section-3.1
const (
	RS256 = "RS256" // RSASSA-PKCS-v1.5 using SHA-256
	RS384 = "RS384" // RSASSA-PKCS-v1.5 using SHA-384
	RS512 = "RS512" // RSASSA-PKCS-v1.5 using SHA-512
	ES256 = "ES256" // ECDSA using P-256 and SHA-256
	ES384 = "ES384" // ECDSA using P-384 and SHA-384
	ES512 = "ES512" // ECDSA using P-521 and SHA-512
	PS256 = "PS256" // RSASSA-PSS using SHA256 and MGF1-SHA256
	PS384 = "PS384" // RSASSA-PSS using SHA384 and MGF1-SHA384
	PS512 = "PS512" // RSASSA-PSS using SHA512 and MGF1-SHA512
)
+103
View File
@@ -0,0 +1,103 @@
package oidc
import (
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"

	"github.com/golang-jwt/jwt/v5"
	"github.com/opencloud-eu/opencloud/pkg/log"
)
const wellknownPath = "/.well-known/openid-configuration"
// The ProviderMetadata describes an idp. Only the fields this codebase reads
// are mapped; the commented-out names document the remaining keys of the
// discovery document for future use.
// see https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
type ProviderMetadata struct {
	AuthorizationEndpoint string `json:"authorization_endpoint,omitempty"`
	//claims_parameter_supported
	ClaimsSupported []string `json:"claims_supported,omitempty"`
	//grant_types_supported
	IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported,omitempty"`
	Issuer                           string   `json:"issuer,omitempty"`
	// AccessTokenIssuer is only used by AD FS and needs to be used when validating the iss of its access tokens
	// See https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-oidce/586de7dd-3385-47c7-93a2-935d9e90441c
	AccessTokenIssuer string `json:"access_token_issuer,omitempty"`
	JwksURI           string `json:"jwks_uri,omitempty"`
	//registration_endpoint
	//request_object_signing_alg_values_supported
	//request_parameter_supported
	//request_uri_parameter_supported
	//require_request_uri_registration
	//response_modes_supported
	ResponseTypesSupported []string `json:"response_types_supported,omitempty"`
	ScopesSupported        []string `json:"scopes_supported,omitempty"`
	SubjectTypesSupported  []string `json:"subject_types_supported,omitempty"`
	TokenEndpoint          string   `json:"token_endpoint,omitempty"`
	//token_endpoint_auth_methods_supported
	//token_endpoint_auth_signing_alg_values_supported
	UserinfoEndpoint string `json:"userinfo_endpoint,omitempty"`
	//userinfo_signing_alg_values_supported
	//code_challenge_methods_supported
	IntrospectionEndpoint string `json:"introspection_endpoint,omitempty"`
	//introspection_endpoint_auth_methods_supported
	//introspection_endpoint_auth_signing_alg_values_supported
	RevocationEndpoint string `json:"revocation_endpoint,omitempty"`
	//revocation_endpoint_auth_methods_supported
	//revocation_endpoint_auth_signing_alg_values_supported
	//id_token_encryption_alg_values_supported
	//id_token_encryption_enc_values_supported
	//userinfo_encryption_alg_values_supported
	//userinfo_encryption_enc_values_supported
	//request_object_encryption_alg_values_supported
	//request_object_encryption_enc_values_supported
	CheckSessionIframe string `json:"check_session_iframe,omitempty"`
	EndSessionEndpoint string `json:"end_session_endpoint,omitempty"`
	//claim_types_supported
}
// LogoutToken describes the claims of an OIDC back-channel logout token.
// see https://openid.net/specs/openid-connect-backchannel-1_0.html#LogoutToken
type LogoutToken struct {
	jwt.RegisteredClaims
	// SessionId carries the "sid" claim identifying the session to terminate.
	SessionId string `json:"sid"`
	// Events must contain the backchannel-logout event member.
	Events LogoutEvent `json:"events"`
	// Note: This is just here to be able to check for nonce being absent
	// (the spec prohibits a nonce on logout tokens).
	Nonce *string `json:"nonce"`
}
// LogoutEvent models the "events" claim of a logout token. Event is non-nil
// exactly when the backchannel-logout member is present in the claim.
type LogoutEvent struct {
	Event *struct{} `json:"http://schemas.openid.net/event/backchannel-logout"`
}
// GetIDPMetadata fetches and decodes the OpenID Connect provider metadata
// from the idp's /.well-known/openid-configuration endpoint. It returns a
// zero ProviderMetadata together with a non-nil error on any failure.
func GetIDPMetadata(logger log.Logger, client *http.Client, idpURI string) (ProviderMetadata, error) {
	wellknownURI := strings.TrimSuffix(idpURI, "/") + wellknownPath
	resp, err := client.Get(wellknownURI)
	if err != nil {
		logger.Error().Err(err).Msg("Failed to set request for .well-known/openid-configuration")
		return ProviderMetadata{}, err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		logger.Error().Err(err).Msg("unable to read discovery response body")
		return ProviderMetadata{}, err
	}
	if resp.StatusCode != http.StatusOK {
		logger.Error().Str("status", resp.Status).Str("body", string(body)).Msg("error requesting openid-configuration")
		// previously this returned the (nil) io.ReadAll error, silently
		// reporting success with empty metadata on non-200 responses
		return ProviderMetadata{}, fmt.Errorf("unexpected status %q requesting openid-configuration", resp.Status)
	}

	var oidcMetadata ProviderMetadata
	if err := json.Unmarshal(body, &oidcMetadata); err != nil {
		logger.Error().Err(err).Msg("failed to decode provider openid-configuration")
		return ProviderMetadata{}, err
	}
	return oidcMetadata, nil
}
+225
View File
@@ -0,0 +1,225 @@
// Code generated by mockery v2.43.2. DO NOT EDIT.
package mocks
import (
context "context"
jwt "github.com/golang-jwt/jwt/v5"
mock "github.com/stretchr/testify/mock"
oauth2 "golang.org/x/oauth2"
oidc "github.com/opencloud-eu/opencloud/pkg/oidc"
)
// OIDCClient is an autogenerated mock type for the OIDCClient type
type OIDCClient struct {
mock.Mock
}
type OIDCClient_Expecter struct {
mock *mock.Mock
}
func (_m *OIDCClient) EXPECT() *OIDCClient_Expecter {
return &OIDCClient_Expecter{mock: &_m.Mock}
}
// UserInfo provides a mock function with given fields: ctx, ts
func (_m *OIDCClient) UserInfo(ctx context.Context, ts oauth2.TokenSource) (*oidc.UserInfo, error) {
ret := _m.Called(ctx, ts)
if len(ret) == 0 {
panic("no return value specified for UserInfo")
}
var r0 *oidc.UserInfo
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, oauth2.TokenSource) (*oidc.UserInfo, error)); ok {
return rf(ctx, ts)
}
if rf, ok := ret.Get(0).(func(context.Context, oauth2.TokenSource) *oidc.UserInfo); ok {
r0 = rf(ctx, ts)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*oidc.UserInfo)
}
}
if rf, ok := ret.Get(1).(func(context.Context, oauth2.TokenSource) error); ok {
r1 = rf(ctx, ts)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// OIDCClient_UserInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UserInfo'
type OIDCClient_UserInfo_Call struct {
*mock.Call
}
// UserInfo is a helper method to define mock.On call
// - ctx context.Context
// - ts oauth2.TokenSource
func (_e *OIDCClient_Expecter) UserInfo(ctx interface{}, ts interface{}) *OIDCClient_UserInfo_Call {
return &OIDCClient_UserInfo_Call{Call: _e.mock.On("UserInfo", ctx, ts)}
}
func (_c *OIDCClient_UserInfo_Call) Run(run func(ctx context.Context, ts oauth2.TokenSource)) *OIDCClient_UserInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(oauth2.TokenSource))
})
return _c
}
func (_c *OIDCClient_UserInfo_Call) Return(_a0 *oidc.UserInfo, _a1 error) *OIDCClient_UserInfo_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *OIDCClient_UserInfo_Call) RunAndReturn(run func(context.Context, oauth2.TokenSource) (*oidc.UserInfo, error)) *OIDCClient_UserInfo_Call {
_c.Call.Return(run)
return _c
}
// VerifyAccessToken provides a mock function with given fields: ctx, token
func (_m *OIDCClient) VerifyAccessToken(ctx context.Context, token string) (oidc.RegClaimsWithSID, jwt.MapClaims, error) {
ret := _m.Called(ctx, token)
if len(ret) == 0 {
panic("no return value specified for VerifyAccessToken")
}
var r0 oidc.RegClaimsWithSID
var r1 jwt.MapClaims
var r2 error
if rf, ok := ret.Get(0).(func(context.Context, string) (oidc.RegClaimsWithSID, jwt.MapClaims, error)); ok {
return rf(ctx, token)
}
if rf, ok := ret.Get(0).(func(context.Context, string) oidc.RegClaimsWithSID); ok {
r0 = rf(ctx, token)
} else {
r0 = ret.Get(0).(oidc.RegClaimsWithSID)
}
if rf, ok := ret.Get(1).(func(context.Context, string) jwt.MapClaims); ok {
r1 = rf(ctx, token)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(jwt.MapClaims)
}
}
if rf, ok := ret.Get(2).(func(context.Context, string) error); ok {
r2 = rf(ctx, token)
} else {
r2 = ret.Error(2)
}
return r0, r1, r2
}
// OIDCClient_VerifyAccessToken_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifyAccessToken'
type OIDCClient_VerifyAccessToken_Call struct {
*mock.Call
}
// VerifyAccessToken is a helper method to define mock.On call
// - ctx context.Context
// - token string
func (_e *OIDCClient_Expecter) VerifyAccessToken(ctx interface{}, token interface{}) *OIDCClient_VerifyAccessToken_Call {
return &OIDCClient_VerifyAccessToken_Call{Call: _e.mock.On("VerifyAccessToken", ctx, token)}
}
func (_c *OIDCClient_VerifyAccessToken_Call) Run(run func(ctx context.Context, token string)) *OIDCClient_VerifyAccessToken_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string))
})
return _c
}
func (_c *OIDCClient_VerifyAccessToken_Call) Return(_a0 oidc.RegClaimsWithSID, _a1 jwt.MapClaims, _a2 error) *OIDCClient_VerifyAccessToken_Call {
_c.Call.Return(_a0, _a1, _a2)
return _c
}
func (_c *OIDCClient_VerifyAccessToken_Call) RunAndReturn(run func(context.Context, string) (oidc.RegClaimsWithSID, jwt.MapClaims, error)) *OIDCClient_VerifyAccessToken_Call {
_c.Call.Return(run)
return _c
}
// VerifyLogoutToken provides a mock function with given fields: ctx, token
func (_m *OIDCClient) VerifyLogoutToken(ctx context.Context, token string) (*oidc.LogoutToken, error) {
ret := _m.Called(ctx, token)
if len(ret) == 0 {
panic("no return value specified for VerifyLogoutToken")
}
var r0 *oidc.LogoutToken
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string) (*oidc.LogoutToken, error)); ok {
return rf(ctx, token)
}
if rf, ok := ret.Get(0).(func(context.Context, string) *oidc.LogoutToken); ok {
r0 = rf(ctx, token)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*oidc.LogoutToken)
}
}
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = rf(ctx, token)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// OIDCClient_VerifyLogoutToken_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifyLogoutToken'
type OIDCClient_VerifyLogoutToken_Call struct {
*mock.Call
}
// VerifyLogoutToken is a helper method to define mock.On call
// - ctx context.Context
// - token string
func (_e *OIDCClient_Expecter) VerifyLogoutToken(ctx interface{}, token interface{}) *OIDCClient_VerifyLogoutToken_Call {
return &OIDCClient_VerifyLogoutToken_Call{Call: _e.mock.On("VerifyLogoutToken", ctx, token)}
}
func (_c *OIDCClient_VerifyLogoutToken_Call) Run(run func(ctx context.Context, token string)) *OIDCClient_VerifyLogoutToken_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string))
})
return _c
}
func (_c *OIDCClient_VerifyLogoutToken_Call) Return(_a0 *oidc.LogoutToken, _a1 error) *OIDCClient_VerifyLogoutToken_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *OIDCClient_VerifyLogoutToken_Call) RunAndReturn(run func(context.Context, string) (*oidc.LogoutToken, error)) *OIDCClient_VerifyLogoutToken_Call {
_c.Call.Return(run)
return _c
}
// NewOIDCClient creates a new instance of OIDCClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewOIDCClient(t interface {
mock.TestingT
Cleanup(func())
}) *OIDCClient {
mock := &OIDCClient{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
+117
View File
@@ -0,0 +1,117 @@
package oidc
import (
"net/http"
"github.com/MicahParks/keyfunc/v2"
"github.com/opencloud-eu/opencloud/pkg/log"
"github.com/opencloud-eu/opencloud/services/proxy/pkg/config"
goidc "github.com/coreos/go-oidc/v3/oidc"
)
// Option defines a single option function.
type Option func(o *Options)

// Options defines the available options for this package.
type Options struct {
	// HTTPClient to use for requests
	HTTPClient *http.Client
	// Logger to use for logging, must be set
	Logger log.Logger
	// The OpenID Connect Issuer URL
	OIDCIssuer string
	// JWKSOptions to use when retrieving keys
	JWKSOptions config.JWKS
	// the JWKS keyset to use for verifying signatures of Access- and
	// Logout-Tokens
	// this option is mostly needed for unit tests, to avoid fetching the keys
	// from the issuer
	JWKS *keyfunc.JWKS
	// KeySet to use when verifying signatures of jwt encoded
	// user info responses
	// TODO move userinfo verification to use jwt/keyfunc as well
	KeySet KeySet
	// AccessTokenVerifyMethod to use when verifying access tokens
	// TODO pass a function or interface to verify? an AccessTokenVerifier?
	AccessTokenVerifyMethod string
	// Config to use
	Config *goidc.Config
	// ProviderMetadata to use
	ProviderMetadata *ProviderMetadata
}
// newOptions applies the given option mutators to a zero Options value and
// returns the result.
func newOptions(opts ...Option) Options {
	var result Options
	for _, apply := range opts {
		apply(&result)
	}
	return result
}
// WithOidcIssuer sets the OpenID Connect issuer URL.
func WithOidcIssuer(issuer string) Option {
	return func(o *Options) {
		o.OIDCIssuer = issuer
	}
}

// WithLogger sets the logger to use.
func WithLogger(logger log.Logger) Option {
	return func(o *Options) {
		o.Logger = logger
	}
}

// WithAccessTokenVerifyMethod sets the method used to verify access tokens.
func WithAccessTokenVerifyMethod(method string) Option {
	return func(o *Options) {
		o.AccessTokenVerifyMethod = method
	}
}

// WithHTTPClient sets the HTTP client used for requests.
func WithHTTPClient(client *http.Client) Option {
	return func(o *Options) {
		o.HTTPClient = client
	}
}

// WithJWKSOptions sets the options used when retrieving keys.
func WithJWKSOptions(jwksOpts config.JWKS) Option {
	return func(o *Options) {
		o.JWKSOptions = jwksOpts
	}
}

// WithJWKS sets a pre-built JWKS key set (mainly useful for testing).
func WithJWKS(jwks *keyfunc.JWKS) Option {
	return func(o *Options) {
		o.JWKS = jwks
	}
}

// WithKeySet sets the key set used to verify jwt encoded userinfo responses.
func WithKeySet(keySet KeySet) Option {
	return func(o *Options) {
		o.KeySet = keySet
	}
}

// WithConfig sets the go-oidc configuration to use.
func WithConfig(cfg *goidc.Config) Option {
	return func(o *Options) {
		o.Config = cfg
	}
}

// WithProviderMetadata sets the provider metadata to use.
func WithProviderMetadata(metadata *ProviderMetadata) Option {
	return func(o *Options) {
		o.ProviderMetadata = metadata
	}
}
+33
View File
@@ -0,0 +1,33 @@
package registry
import (
"os"
"time"
)
const (
_registryRegisterIntervalEnv = "EXPERIMENTAL_REGISTER_INTERVAL"
_registryRegisterTTLEnv = "EXPERIMENTAL_REGISTER_TTL"
// Note: _defaultRegisterInterval should always be lower than _defaultRegisterTTL
_defaultRegisterInterval = time.Second * 25
_defaultRegisterTTL = time.Second * 30
)
// GetRegisterInterval returns the register interval from the environment.
func GetRegisterInterval() time.Duration {
d, err := time.ParseDuration(os.Getenv(_registryRegisterIntervalEnv))
if err != nil {
return _defaultRegisterInterval
}
return d
}
// GetRegisterTTL returns the register TTL from the environment.
func GetRegisterTTL() time.Duration {
d, err := time.ParseDuration(os.Getenv(_registryRegisterTTLEnv))
if err != nil {
return _defaultRegisterTTL
}
return d
}
+61
View File
@@ -0,0 +1,61 @@
package registry
import (
"context"
"net/http"
"time"
mRegistry "go-micro.dev/v4/registry"
"github.com/opencloud-eu/opencloud/pkg/log"
)
// RegisterService publishes an arbitrary endpoint to the service-registry. This allows querying nodes of
// non-micro services like reva. No health-checks are done, thus the caller is responsible for canceling
// the passed context to deregister the service and stop the refresh goroutine.
// Note: a failure of the initial registration is fatal and terminates the process.
func RegisterService(ctx context.Context, logger log.Logger, service *mRegistry.Service, debugAddr string) error {
	registry := GetRegistry()
	node := service.Nodes[0]

	logger.Info().Msgf("registering external service %v@%v", node.Id, node.Address)

	rOpts := []mRegistry.RegisterOption{mRegistry.RegisterTTL(GetRegisterTTL())}
	if err := registry.Register(service, rOpts...); err != nil {
		logger.Fatal().Err(err).Msgf("Registration error for external service %v", service.Name)
	}

	t := time.NewTicker(GetRegisterInterval())

	go func() {
		// wait until the service reports readiness, backing off exponentially
		// between probes; honor context cancellation so this goroutine does
		// not leak for services that never become ready
		delay := 500 * time.Millisecond
		for {
			resp, err := http.DefaultClient.Get("http://" + debugAddr + "/readyz")
			if err == nil {
				ready := resp.StatusCode == http.StatusOK
				// always close the body — also for non-200 responses — so the
				// transport connection is not leaked
				resp.Body.Close()
				if ready {
					break
				}
			}
			select {
			case <-ctx.Done():
				t.Stop()
				return
			case <-time.After(delay):
			}
			delay *= 2
		}

		for {
			select {
			case <-t.C:
				logger.Debug().Interface("service", service).Msg("refreshing external service-registration")
				err := registry.Register(service, rOpts...)
				if err != nil {
					logger.Error().Err(err).Msgf("registration error for external service %v", service.Name)
				}
			case <-ctx.Done():
				logger.Debug().Interface("service", service).Msg("unregistering")
				t.Stop()
				err := registry.Deregister(service)
				if err != nil {
					logger.Err(err).Msgf("Error unregistering external service %v", service.Name)
				}
				return
			}
		}
	}()
	return nil
}
+61
View File
@@ -0,0 +1,61 @@
package registry
//
//import (
// "context"
// "testing"
//
// "github.com/micro/go-micro/v2/registry"
// "github.com/opencloud-eu/opencloud/pkg/log"
//)
//
//func TestRegisterGRPCEndpoint(t *testing.T) {
// ctx, cancel := context.WithCancel(context.Background())
// err := RegisterGRPCEndpoint(ctx, "test", "1234", "192.168.0.1:777", log.Logger{})
// if err != nil {
// t.Errorf("Unexpected error: %v", err)
// }
//
// s, err := registry.GetService("test")
// if err != nil {
// t.Errorf("Unexpected error: %v", err)
// }
//
// if len(s) != 1 {
// t.Errorf("Expected exactly one service to be returned got %v", len(s))
// }
//
// if len(s[0].Nodes) != 1 {
// t.Errorf("Expected exactly one node to be returned got %v", len(s[0].Nodes))
// }
//
// testSvc := s[0]
// if testSvc.Name != "test" {
// t.Errorf("Expected service name to be 'test' got %v", s[0].Name)
// }
//
// testNode := testSvc.Nodes[0]
//
// if testNode.Address != "192.168.0.1:777" {
// t.Errorf("Expected node address to be '192.168.0.1:777' got %v", testNode.Address)
// }
//
// if testNode.Id != "test-1234" {
// t.Errorf("Expected node id to be 'test-1234' got %v", testNode.Id)
// }
//
// cancel()
//
// // When switching over to monorepo this little test fails. We're unsure of what the cause is, but since this test
// // is testing a framework specific behavior, we're better off letting it commented out. There is also no use of
// // com.owncloud.reva anywhere in the codebase, so we're effectively only registering reva as a go-micro service,
// // but not sending any message.
// s, err = registry.GetService("test")
// if err != nil {
// t.Errorf("Unexpected error: %v", err)
// }
//
// if len(s) != 0 {
// t.Errorf("Deregister on cancelation failed. Result-length should be zero, got %v", len(s))
// }
//}
+102
View File
@@ -0,0 +1,102 @@
package registry
import (
"fmt"
"os"
"strings"
"sync"
"time"
rRegistry "github.com/cs3org/reva/v2/pkg/registry"
memr "github.com/go-micro/plugins/v4/registry/memory"
"github.com/opencloud-eu/opencloud/pkg/natsjsregistry"
mRegistry "go-micro.dev/v4/registry"
"go-micro.dev/v4/registry/cache"
)
const (
	_registryEnv        = "MICRO_REGISTRY"         // selects the registry implementation
	_registryAddressEnv = "MICRO_REGISTRY_ADDRESS" // comma-separated registry endpoints
)

var (
	_once sync.Once          // guards the one-time construction of the package-level registry
	_reg  mRegistry.Registry // singleton registry instance, set inside _once
)
// Config is the config for a registry
type Config struct {
	Type         string        `mapstructure:"type"`          // registry implementation, e.g. "nats-js-kv" (default) or "memory"
	Addresses    []string      `mapstructure:"addresses"`     // registry endpoints
	Username     string        `mapstructure:"username"`      // optional credentials (not read in this file)
	Password     string        `mapstructure:"password"`      // optional credentials (not read in this file)
	DisableCache bool          `mapstructure:"disable_cache"` // when true, skip wrapping the registry in a 30s cache
	RegisterTTL  time.Duration `mapstructure:"register_ttl"`  // TTL passed to the registry for service registrations
}
// Option allows configuring the registry
type Option func(*Config)

// Inmemory overrides env values to use an in-memory registry.
// Like all options it is only honored on the first GetRegistry call.
func Inmemory() Option {
	return func(c *Config) {
		c.Type = "memory"
	}
}
// GetRegistry returns a configured micro registry based on Micro env vars
// (MICRO_REGISTRY / MICRO_REGISTRY_ADDRESS), defaulting to the nats-js-kv
// registry. The registry is constructed exactly once; opts and env changes
// on subsequent calls are ignored.
func GetRegistry(opts ...Option) mRegistry.Registry {
	_once.Do(func() {
		cfg := getEnvs(opts...)

		switch cfg.Type {
		default:
			fmt.Println("Attention: unknown registry type, using default nats-js-kv")
			fallthrough
		case "natsjs", "nats-js", "nats-js-kv": // for backwards compatibility - we will stick with one of those
			_reg = natsjsregistry.NewRegistry(
				mRegistry.Addrs(cfg.Addresses...),
				natsjsregistry.DefaultTTL(cfg.RegisterTTL),
			)
		case "memory":
			_reg = memr.NewRegistry()
			cfg.DisableCache = true // no cache needed for in-memory registry
		}

		// wrap the registry in a cache unless caching is explicitly disabled
		if !cfg.DisableCache {
			_reg = cache.New(_reg, cache.WithTTL(30*time.Second))
		}

		// fixme: lazy initialization of reva registry, needs refactor to a explicit call per service
		_ = rRegistry.Init(_reg)
	})
	// return the (possibly cache-wrapped) singleton to prevent a registry
	// lookup for every request
	return _reg
}
// getEnvs assembles the registry configuration from built-in defaults, the
// MICRO_REGISTRY* environment variables and the given option mutators
// (applied last, so they win).
func getEnvs(opts ...Option) *Config {
	cfg := &Config{
		Type:      "nats-js-kv",
		Addresses: []string{"127.0.0.1:9233"},
	}

	if regType := os.Getenv(_registryEnv); regType != "" {
		cfg.Type = regType
	}
	if addrs := strings.Split(os.Getenv(_registryAddressEnv), ","); len(addrs) > 0 && addrs[0] != "" {
		cfg.Addresses = addrs
	}
	cfg.RegisterTTL = GetRegisterTTL()

	for _, apply := range opts {
		apply(cfg)
	}
	return cfg
}
+90
View File
@@ -0,0 +1,90 @@
package registry
import (
"fmt"
"net"
"strconv"
"strings"
mRegistry "go-micro.dev/v4/registry"
"go-micro.dev/v4/server"
mAddr "go-micro.dev/v4/util/addr"
)
// splitHostPort splits address into host and numeric port. In contrast to
// net.SplitHostPort it tolerates addresses without a port (port is then 0)
// and keeps the previous inline behavior of both Build* functions:
// everything up to the last colon is the host, the rest is the port.
func splitHostPort(address string) (string, int) {
	parts := strings.Split(address, ":")
	if len(parts) == 1 {
		return parts[0], 0
	}
	port, _ := strconv.Atoi(parts[len(parts)-1])
	return strings.Join(parts[:len(parts)-1], ":"), port
}

// BuildGRPCService returns a go-micro service definition for a grpc endpoint
// at address, suitable for RegisterService. For non-unix transports the host
// is resolved to an advertisable address via mAddr.Extract.
func BuildGRPCService(serviceID, transport, address, version string) *mRegistry.Service {
	host, port := splitHostPort(address)

	addr := host
	if transport != "unix" {
		var err error
		addr, err = mAddr.Extract(host)
		if err != nil {
			// fall back to the configured host when extraction fails
			addr = host
		}
		addr = net.JoinHostPort(addr, strconv.Itoa(port))
	}

	node := &mRegistry.Node{
		Id:       serviceID + "-" + server.DefaultId,
		Address:  addr,
		Metadata: make(map[string]string),
	}
	node.Metadata["registry"] = GetRegistry().String()
	node.Metadata["server"] = "grpc"
	node.Metadata["transport"] = transport
	node.Metadata["protocol"] = "grpc"

	return &mRegistry.Service{
		Name:      serviceID,
		Version:   version,
		Nodes:     []*mRegistry.Node{node},
		Endpoints: make([]*mRegistry.Endpoint, 0),
	}
}

// BuildHTTPService returns a go-micro service definition for an http endpoint
// at address, suitable for RegisterService.
func BuildHTTPService(serviceID, address string, version string) *mRegistry.Service {
	host, port := splitHostPort(address)

	addr, err := mAddr.Extract(host)
	if err != nil {
		// fall back to the configured host when extraction fails
		addr = host
	}

	node := &mRegistry.Node{
		// This id is read by the registry watcher
		Id:       serviceID + "-" + server.DefaultId,
		Address:  net.JoinHostPort(addr, fmt.Sprint(port)),
		Metadata: make(map[string]string),
	}
	node.Metadata["registry"] = GetRegistry().String()
	node.Metadata["server"] = "http"
	node.Metadata["transport"] = "http"
	node.Metadata["protocol"] = "http"

	return &mRegistry.Service{
		Name:      serviceID,
		Version:   version,
		Nodes:     []*mRegistry.Node{node},
		Endpoints: make([]*mRegistry.Endpoint, 0),
	}
}
+128
View File
@@ -0,0 +1,128 @@
package roles
import (
"context"
"time"
"github.com/cs3org/reva/v2/pkg/store"
"github.com/opencloud-eu/opencloud/pkg/log"
settingsmsg "github.com/opencloud-eu/opencloud/protogen/gen/ocis/messages/settings/v0"
settingssvc "github.com/opencloud-eu/opencloud/protogen/gen/ocis/services/settings/v0"
microstore "go-micro.dev/v4/store"
"google.golang.org/protobuf/encoding/protojson"
)
// Cache location and lifetime for role bundles fetched from the role service.
const (
	cacheDatabase  = "ocis-pkg"
	cacheTableName = "roles"
	cacheTTL       = time.Hour // a cached role is re-fetched at most once per hour
)
// Manager manages a cache of roles by fetching unknown roles from the settings.RoleService.
type Manager struct {
	// logger for debug output. NOTE(review): NewManager never populates this
	// field, so it stays at its zero value — verify whether that is intended.
	logger log.Logger
	// roleCache caches role bundles, keyed by role ID.
	roleCache microstore.Store
	// roleService is queried for roles missing from the cache.
	roleService settingssvc.RoleService
}
// NewManager returns a new instance of Manager, configured via the given
// options (see Logger, RoleService and StoreOptions).
func NewManager(o ...Option) Manager {
	opts := newOptions(o...)
	nStore := store.Create(opts.storeOptions...)
	return Manager{
		// previously opts.logger was dropped, leaving the manager with a
		// zero-value logger for all debug output
		logger:      opts.logger,
		roleCache:   nStore,
		roleService: opts.roleService,
	}
}
// List returns all roles that match the given roleIDs. Roles are served from
// the cache where possible; the remainder is fetched from the RoleService and
// cached for cacheTTL. When the RoleService call fails, nil is returned
// (not the partial result from the cache).
func (m *Manager) List(ctx context.Context, roleIDs []string) []*settingsmsg.Bundle {
	// get from cache
	result := make([]*settingsmsg.Bundle, 0)
	lookup := make([]string, 0)
	for _, roleID := range roleIDs {
		if records, err := m.roleCache.Read(roleID, microstore.ReadFrom(cacheDatabase, cacheTableName)); err != nil {
			// cache miss (or read error) — fetch this role from the service
			lookup = append(lookup, roleID)
		} else {
			role := &settingsmsg.Bundle{}
			found := false
			for _, record := range records {
				if record.Key == roleID {
					if err := protojson.Unmarshal(record.Value, role); err == nil {
						// if we can unmarshal the role, append it to the result
						// otherwise assume the role wasn't found (data was damaged and
						// we need to get the role again)
						result = append(result, role)
						found = true
						break
					}
				}
			}
			if !found {
				lookup = append(lookup, roleID)
			}
		}
	}

	// if there are roles missing, fetch them from the RoleService
	if len(lookup) > 0 {
		request := &settingssvc.ListBundlesRequest{
			BundleIds: lookup,
		}
		res, err := m.roleService.ListRoles(ctx, request)
		if err != nil {
			m.logger.Debug().Err(err).Msg("failed to fetch roles by roleIDs")
			return nil
		}
		for _, role := range res.Bundles {
			// marshal errors are deliberately ignored: a role that cannot be
			// cached is still returned to the caller
			jsonbytes, _ := protojson.Marshal(role)
			record := &microstore.Record{
				Key:    role.Id,
				Value:  jsonbytes,
				Expiry: cacheTTL,
			}
			err := m.roleCache.Write(
				record,
				microstore.WriteTo(cacheDatabase, cacheTableName),
				microstore.WriteTTL(cacheTTL),
			)
			if err != nil {
				// caching is best effort; failures only reduce cache hit rate
				m.logger.Debug().Err(err).Msg("failed to cache roles")
			}
			result = append(result, role)
		}
	}
	return result
}
// FindPermissionByID searches for a permission-setting by the permissionID, but limited to the given roleIDs
func (m *Manager) FindPermissionByID(ctx context.Context, roleIDs []string, permissionID string) *settingsmsg.Setting {
	for _, bundle := range m.List(ctx, roleIDs) {
		for _, permission := range bundle.Settings {
			if permission.Id != permissionID {
				continue
			}
			// first match wins
			return permission
		}
	}
	return nil
}
// FindRoleIDsForUser returns the IDs of all roles that are assigned to the supplied userID.
func (m *Manager) FindRoleIDsForUser(ctx context.Context, userID string) ([]string, error) {
	req := &settingssvc.ListRoleAssignmentsRequest{AccountUuid: userID}
	assignmentResponse, err := m.roleService.ListRoleAssignments(ctx, req)
	if err != nil {
		return nil, err
	}
	// collect the role ID of every assignment
	ids := make([]string, len(assignmentResponse.Assignments))
	for i, assignment := range assignmentResponse.Assignments {
		ids[i] = assignment.RoleId
	}
	return ids, nil
}
+48
View File
@@ -0,0 +1,48 @@
package roles
import (
"github.com/opencloud-eu/opencloud/pkg/log"
settingssvc "github.com/opencloud-eu/opencloud/protogen/gen/ocis/services/settings/v0"
"go-micro.dev/v4/store"
)
// Options are all the possible options.
type Options struct {
	storeOptions []store.Option          // options passed to the underlying go-micro store
	logger       log.Logger              // logger used by the manager
	roleService  settingssvc.RoleService // service used to fetch roles on cache misses
}
// Option is a functional option that mutates an Options struct.
type Option func(*Options)
// Logger sets a preconfigured logger on the options.
func Logger(logger log.Logger) Option {
	return func(o *Options) { o.logger = logger }
}
// RoleService sets the service used to fetch roles on cache misses.
func RoleService(rs settingssvc.RoleService) Option {
	return func(o *Options) { o.roleService = rs }
}
// StoreOptions sets the options forwarded to the underlying store.
func StoreOptions(storeOpts []store.Option) Option {
	return func(o *Options) { o.storeOptions = storeOpts }
}
// newOptions applies every supplied option function to a zero-valued
// Options struct and returns the result.
func newOptions(opts ...Option) Options {
	var o Options
	for _, apply := range opts {
		apply(&o)
	}
	return o
}
+22
View File
@@ -0,0 +1,22 @@
package roles
import (
"context"
"encoding/json"
"github.com/opencloud-eu/opencloud/pkg/middleware"
"go-micro.dev/v4/metadata"
)
// ReadRoleIDsFromContext extracts roleIDs from the metadata context and returns them as []string
func ReadRoleIDsFromContext(ctx context.Context) (roleIDs []string, ok bool) {
	raw, found := metadata.Get(ctx, middleware.RoleIDs)
	if !found {
		return nil, false
	}
	if err := json.Unmarshal([]byte(raw), &roleIDs); err != nil {
		// metadata present but not valid JSON -> treat it as absent
		return nil, false
	}
	return roleIDs, true
}
+262
View File
@@ -0,0 +1,262 @@
package runner
import (
"context"
"sync"
"sync/atomic"
"time"
)
// GroupRunner represent a group of tasks that need to run together.
// The expectation is that all the tasks will run at the same time, and when
// one of them stops, the rest will also stop.
//
// The GroupRunner is intended to be used to run multiple services, which are
// more or less independent from each other, but at the same time it doesn't
// make sense to have any of them stopped while the rest are running.
// Basically, either all of them run, or none of them.
// For example, you can have a GRPC and HTTP servers running, each of them
// providing a piece of functionality, however, if any of them fails, the
// feature provided by them would be incomplete or broken.
//
// The interrupt duration for the group can be set through the
// `WithInterruptDuration` option. If the option isn't supplied, the default
// value (15 secs) will be used.
//
// It's recommended that the timeouts are handled by each runner individually,
// meaning that each runner's timeout should be less than the group runner's
// timeout. This way, we can know which runner timed out.
// If the group timeout is reached, the remaining results will have the
// runner's id as "_unknown_".
//
// Note that, as services, the tasks aren't expected to stop by default.
// This means that, if a task finishes naturally, the rest of the tasks will
// be asked to stop as well.
type GroupRunner struct {
	runners       sync.Map           // map of runner ID -> *Runner
	runnersCount  int                // number of runners added to the group
	isRunning     bool               // set once Run or RunAsync has been called; guarded by runningMutex
	interruptDur  time.Duration      // how long to wait for results after the group is interrupted
	interrupted   atomic.Bool        // ensures Interrupt only acts once
	interruptedCh chan time.Duration // receives interruptDur when the group timeout fires, then is closed
	runningMutex  sync.Mutex         // coordinates Add with Run/RunAsync
}
// NewGroup will create a GroupRunner configured through the given options.
// If no interrupt duration is supplied, DefaultGroupInterruptDuration is used.
func NewGroup(opts ...Option) *GroupRunner {
	options := Options{InterruptDuration: DefaultGroupInterruptDuration}
	for _, apply := range opts {
		apply(&options)
	}
	gr := &GroupRunner{
		runners:       sync.Map{},
		runningMutex:  sync.Mutex{},
		interruptDur:  options.InterruptDuration,
		interruptedCh: make(chan time.Duration, 1),
	}
	return gr
}
// Add will add a runner to the group.
//
// It's mandatory that each runner in the group has an unique id, otherwise
// there will be issues
// Adding new runners once the group starts will cause a panic
func (gr *GroupRunner) Add(r *Runner) {
	gr.runningMutex.Lock()
	defer gr.runningMutex.Unlock()
	if gr.isRunning {
		panic("Adding a new runner after the group starts is forbidden")
	}
	// LoadOrStore stores the runner unless the id is already taken
	_, loaded := gr.runners.LoadOrStore(r.ID, r)
	if loaded {
		// there is already a runner with the same id, which is forbidden
		panic("Trying to add a runner with an existing Id in the group")
	}
	// the count only grows when a runner was actually stored
	gr.runnersCount++
}
// Run will execute all the tasks in the group at the same time.
//
// Similarly to the "regular" runner's `Run` method, the execution thread
// will be blocked here until all tasks are completed, and their results
// will be available (each result will have the runner's id so it's easy to
// find which one failed). Note that there is no guarantee about the result's
// order, so the first result in the slice might or might not be the first
// result to be obtained.
//
// When the context is marked as done, the groupRunner will call all the
// stoppers for each runner to notify each task to stop. Note that the tasks
// might still take a while to complete.
//
// If a task finishes naturally (with the context still "alive"), it will also
// cause the groupRunner to call the stoppers of the rest of the tasks. So if
// a task finishes, the rest will also finish.
// Note that it is NOT expected for the finished task's stopper to be called
// in this case.
func (gr *GroupRunner) Run(ctx context.Context) []*Result {
	// Set the flag inside the runningMutex to ensure we don't read the old value
	// in the `Add` method and add a new runner when this method is being executed
	// Note that if multiple `Run` or `RunAsync` happens, the underlying runners
	// will panic
	gr.runningMutex.Lock()
	gr.isRunning = true
	gr.runningMutex.Unlock()
	results := make([]*Result, 0, gr.runnersCount)
	ch := make(chan *Result, gr.runnersCount) // no need to block writing results
	// start every runner; each delivers its result into the shared channel
	gr.runners.Range(func(_, value any) bool {
		r := value.(*Runner)
		r.RunAsync(ch)
		return true
	})
	var d time.Duration
	// wait for a result or for the context to be done
	select {
	case result := <-ch:
		results = append(results, result)
	case d = <-gr.interruptedCh:
		// the group interrupt timeout fired before any result arrived
		results = append(results, &Result{
			RunnerID:    "_unknown_",
			RunnerError: NewGroupTimeoutError(d),
		})
	case <-ctx.Done():
		// Do nothing
	}
	// interrupt the rest of the runners
	gr.Interrupt()
	// Having notified that the context has been finished, we still need to
	// wait for the rest of the results
	for i := len(results); i < gr.runnersCount; i++ {
		select {
		case result := <-ch:
			results = append(results, result)
		case d2, ok := <-gr.interruptedCh:
			// once the channel is closed (ok == false) we keep the last
			// duration we saw; otherwise take the freshly delivered one
			if ok {
				d = d2
			}
			results = append(results, &Result{
				RunnerID:    "_unknown_",
				RunnerError: NewGroupTimeoutError(d),
			})
		}
	}
	// Even if we reach the group time out and bail out early, tasks might
	// be running and eventually deliver the result through the channel.
	// We'll rely on the buffered channel so the tasks won't block and the
	// data can be eventually garbage-collected along with the unused
	// channel, so we won't close the channel here.
	return results
}
// RunAsync will execute the tasks in the group asynchronously.
// The result of each task will be placed in the provided channel as soon
// as it's available.
// Note that this method will finish as soon as all the tasks are running.
func (gr *GroupRunner) RunAsync(ch chan<- *Result) {
	// Set the flag inside the runningMutex to ensure we don't read the old value
	// in the `Add` method and add a new runner when this method is being executed
	// Note that if multiple `Run` or `RunAsync` happens, the underlying runners
	// will panic
	gr.runningMutex.Lock()
	gr.isRunning = true
	gr.runningMutex.Unlock()
	// we need a secondary channel to receive the first result so we can
	// interrupt the rest of the tasks
	interCh := make(chan *Result, gr.runnersCount)
	gr.runners.Range(func(_, value any) bool {
		r := value.(*Runner)
		r.RunAsync(interCh)
		return true
	})
	// forward results from the internal channel to the caller's channel,
	// interrupting the whole group once the first result (or the group
	// timeout) arrives
	go func() {
		var result *Result
		var d time.Duration
		select {
		case result = <-interCh:
			// result already assigned, so do nothing
		case d = <-gr.interruptedCh:
			// we aren't tracking which runners have finished and which are still
			// running, so we'll use "_unknown_" as runner id
			result = &Result{
				RunnerID:    "_unknown_",
				RunnerError: NewGroupTimeoutError(d),
			}
		}
		gr.Interrupt()
		ch <- result
		// one result already forwarded above; collect the remaining ones
		for i := 1; i < gr.runnersCount; i++ {
			select {
			case result = <-interCh:
				// result already assigned, so do nothing
			case d2, ok := <-gr.interruptedCh:
				// if ok is true, d2 will have a good value; if false, the channel
				// is closed and we get a default value
				if ok {
					d = d2
				}
				result = &Result{
					RunnerID:    "_unknown_",
					RunnerError: NewGroupTimeoutError(d),
				}
			}
			ch <- result
		}
	}()
}
// Interrupt will execute the stopper function of ALL the tasks, which should
// notify the tasks in order for them to finish.
// The stoppers will be called immediately but sequentially. This means that
// the second stopper won't be called until the first one has returned. This
// usually isn't a problem because the service `Stop`'s methods either don't
// take a long time to return, or they run asynchronously in another goroutine.
//
// As said, this will affect ALL the tasks in the group. It isn't possible to
// try to stop just one task.
// If a task has finished, the corresponding stopper won't be called
//
// The interrupt timeout for the group will start after all the runners in the
// group have been notified. Note that, if the task's stopper for a runner
// takes a lot of time to return, it will delay the timeout's start, so it's
// advised that the stopper either returns fast or is run asynchronously.
func (gr *GroupRunner) Interrupt() {
	// CompareAndSwap guarantees the interruption sequence runs only once
	if gr.interrupted.CompareAndSwap(false, true) {
		gr.runners.Range(func(_, value any) bool {
			r := value.(*Runner)
			select {
			case <-r.Finished():
				// No data should be sent through the channel, so we'd be
				// here only if the channel is closed. This means the task
				// has finished and we don't need to interrupt. We do
				// nothing in this case
			default:
				r.Interrupt()
			}
			return true
		})
		// start the group timeout only after every runner has been notified
		_ = time.AfterFunc(gr.interruptDur, func() {
			// timeout reached -> send it through the channel so our runner
			// can abort
			gr.interruptedCh <- gr.interruptDur
			close(gr.interruptedCh)
		})
	}
}
+293
View File
@@ -0,0 +1,293 @@
package runner_test
import (
"context"
"time"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/opencloud-eu/opencloud/pkg/runner"
)
// Specs for the GroupRunner: adding runners, synchronous Run, asynchronous
// RunAsync, and the group interrupt timeout behavior.
var _ = Describe("GroupRunner", func() {
	var (
		gr *runner.GroupRunner
	)

	// every spec starts with a group containing two long-running timed tasks
	BeforeEach(func() {
		gr = runner.NewGroup()

		task1Ch := make(chan error)
		task1 := TimedTask(task1Ch, 30*time.Second)
		gr.Add(runner.New("task1", task1, func() {
			task1Ch <- nil
			close(task1Ch)
		}))

		task2Ch := make(chan error)
		task2 := TimedTask(task2Ch, 20*time.Second)
		gr.Add(runner.New("task2", task2, func() {
			task2Ch <- nil
			close(task2Ch)
		}))
	})

	Describe("Add", func() {
		It("Duplicated runner id panics", func() {
			Expect(func() {
				gr.Add(runner.New("task1", func() error {
					time.Sleep(6 * time.Second)
					return nil
				}, func() {
				}))
			}).To(Panic())
		})

		It("Add after run panics", func(ctx SpecContext) {
			// context will be done in 1 second
			myCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
			defer cancel()

			// spawn a new goroutine and return the result in the channel
			ch2 := make(chan []*runner.Result)
			go func(ch2 chan []*runner.Result) {
				ch2 <- gr.Run(myCtx)
				close(ch2)
			}(ch2)

			// context is done in 1 sec, so all tasks should be interrupted and finish
			Eventually(ctx, ch2).Should(Receive(ContainElements(
				&runner.Result{RunnerID: "task1", RunnerError: nil},
				&runner.Result{RunnerID: "task2", RunnerError: nil},
			)))

			task3Ch := make(chan error)
			task3 := TimedTask(task3Ch, 6*time.Second)
			Expect(func() {
				gr.Add(runner.New("task3", task3, func() {
					task3Ch <- nil
					close(task3Ch)
				}))
			}).To(Panic())
		}, SpecTimeout(5*time.Second))

		It("Add after runAsync panics", func(ctx SpecContext) {
			ch2 := make(chan *runner.Result)
			gr.RunAsync(ch2)

			Expect(func() {
				task3Ch := make(chan error)
				task3 := TimedTask(task3Ch, 6*time.Second)
				gr.Add(runner.New("task3", task3, func() {
					task3Ch <- nil
					close(task3Ch)
				}))
			}).To(Panic())
		}, SpecTimeout(5*time.Second))
	})

	Describe("Run", func() {
		It("Context is done", func(ctx SpecContext) {
			task3Ch := make(chan error)
			task3 := TimedTask(task3Ch, 6*time.Second)
			gr.Add(runner.New("task3", task3, func() {
				task3Ch <- nil
				close(task3Ch)
			}))

			// context will be done in 1 second
			myCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
			defer cancel()

			// spawn a new goroutine and return the result in the channel
			ch2 := make(chan []*runner.Result)
			go func(ch2 chan []*runner.Result) {
				ch2 <- gr.Run(myCtx)
				close(ch2)
			}(ch2)

			// context is done in 1 sec, so all tasks should be interrupted and finish
			Eventually(ctx, ch2).Should(Receive(ContainElements(
				&runner.Result{RunnerID: "task1", RunnerError: nil},
				&runner.Result{RunnerID: "task2", RunnerError: nil},
				&runner.Result{RunnerID: "task3", RunnerError: nil},
			)))
		}, SpecTimeout(5*time.Second))

		It("One task finishes early", func(ctx SpecContext) {
			task3Ch := make(chan error)
			task3 := TimedTask(task3Ch, 1*time.Second)
			gr.Add(runner.New("task3", task3, func() {
				task3Ch <- nil
				close(task3Ch)
			}))

			// context will be done in 10 seconds
			myCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
			defer cancel()

			// spawn a new goroutine and return the result in the channel
			ch2 := make(chan []*runner.Result)
			go func(ch2 chan []*runner.Result) {
				ch2 <- gr.Run(myCtx)
				close(ch2)
			}(ch2)

			// task3 finishes in 1 sec, so the rest should also be interrupted
			Eventually(ctx, ch2).Should(Receive(ContainElements(
				&runner.Result{RunnerID: "task1", RunnerError: nil},
				&runner.Result{RunnerID: "task2", RunnerError: nil},
				&runner.Result{RunnerID: "task3", RunnerError: nil},
			)))
		}, SpecTimeout(5*time.Second))

		It("Context done and group timeout reached", func(ctx SpecContext) {
			gr := runner.NewGroup(runner.WithInterruptDuration(2 * time.Second))
			gr.Add(runner.New("task1", func() error {
				time.Sleep(6 * time.Second)
				return nil
			}, func() {
			}))
			gr.Add(runner.New("task2", func() error {
				time.Sleep(6 * time.Second)
				return nil
			}, func() {
			}))

			// context will be done in 1 second
			myCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
			defer cancel()

			// spawn a new goroutine and return the result in the channel
			ch2 := make(chan []*runner.Result)
			go func(ch2 chan []*runner.Result) {
				ch2 <- gr.Run(myCtx)
				close(ch2)
			}(ch2)

			// context finishes in 1 sec, tasks will be interrupted
			// group timeout will be reached after 2 extra seconds
			Eventually(ctx, ch2).Should(Receive(ContainElements(
				&runner.Result{RunnerID: "_unknown_", RunnerError: runner.NewGroupTimeoutError(2 * time.Second)},
				&runner.Result{RunnerID: "_unknown_", RunnerError: runner.NewGroupTimeoutError(2 * time.Second)},
			)))
		}, SpecTimeout(5*time.Second))

		It("Interrupted and group timeout reached", func(ctx SpecContext) {
			gr := runner.NewGroup(runner.WithInterruptDuration(2 * time.Second))
			gr.Add(runner.New("task1", func() error {
				time.Sleep(6 * time.Second)
				return nil
			}, func() {
			}))
			gr.Add(runner.New("task2", func() error {
				time.Sleep(6 * time.Second)
				return nil
			}, func() {
			}))

			// context will be done in 10 seconds
			myCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
			defer cancel()

			// spawn a new goroutine and return the result in the channel
			ch2 := make(chan []*runner.Result)
			go func(ch2 chan []*runner.Result) {
				ch2 <- gr.Run(myCtx)
				close(ch2)
			}(ch2)

			gr.Interrupt()

			// tasks will be interrupted
			// group timeout will be reached after 2 extra seconds
			Eventually(ctx, ch2).Should(Receive(ContainElements(
				&runner.Result{RunnerID: "_unknown_", RunnerError: runner.NewGroupTimeoutError(2 * time.Second)},
				&runner.Result{RunnerID: "_unknown_", RunnerError: runner.NewGroupTimeoutError(2 * time.Second)},
			)))
		}, SpecTimeout(5*time.Second))

		It("Double run panics", func(ctx SpecContext) {
			// context will be done in 1 second
			myCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
			defer cancel()

			Expect(func() {
				gr.Run(myCtx)
				gr.Run(myCtx)
			}).To(Panic())
		}, SpecTimeout(5*time.Second))
	})

	Describe("RunAsync", func() {
		It("Wait in channel", func(ctx SpecContext) {
			task3Ch := make(chan error)
			task3 := TimedTask(task3Ch, 1*time.Second)
			gr.Add(runner.New("task3", task3, func() {
				task3Ch <- nil
				close(task3Ch)
			}))

			ch2 := make(chan *runner.Result)
			gr.RunAsync(ch2)

			// task3 finishes in 1 sec, so the rest should also be interrupted
			Eventually(ctx, ch2).Should(Receive())
			Eventually(ctx, ch2).Should(Receive())
			Eventually(ctx, ch2).Should(Receive())
		}, SpecTimeout(5*time.Second))

		It("Double runAsync panics", func(ctx SpecContext) {
			ch2 := make(chan *runner.Result)
			Expect(func() {
				gr.RunAsync(ch2)
				gr.RunAsync(ch2)
			}).To(Panic())
		}, SpecTimeout(5*time.Second))

		It("Interrupt async", func(ctx SpecContext) {
			task3Ch := make(chan error)
			task3 := TimedTask(task3Ch, 6*time.Second)
			gr.Add(runner.New("task3", task3, func() {
				task3Ch <- nil
				close(task3Ch)
			}))

			ch2 := make(chan *runner.Result)
			gr.RunAsync(ch2)
			gr.Interrupt()

			// tasks will be interrupted
			Eventually(ctx, ch2).Should(Receive())
			Eventually(ctx, ch2).Should(Receive())
			Eventually(ctx, ch2).Should(Receive())
		}, SpecTimeout(5*time.Second))

		It("Interrupt async group timeout reached", func(ctx SpecContext) {
			gr := runner.NewGroup(runner.WithInterruptDuration(2 * time.Second))
			gr.Add(runner.New("task1", func() error {
				time.Sleep(6 * time.Second)
				return nil
			}, func() {
			}))
			gr.Add(runner.New("task2", func() error {
				time.Sleep(6 * time.Second)
				return nil
			}, func() {
			}))

			ch2 := make(chan *runner.Result)
			gr.RunAsync(ch2)
			gr.Interrupt()

			// group timeout will be reached after 2 extra seconds
			Eventually(ctx, ch2).Should(Receive(Equal(&runner.Result{RunnerID: "_unknown_", RunnerError: runner.NewGroupTimeoutError(2 * time.Second)})))
			Eventually(ctx, ch2).Should(Receive(Equal(&runner.Result{RunnerID: "_unknown_", RunnerError: runner.NewGroupTimeoutError(2 * time.Second)})))
		}, SpecTimeout(5*time.Second))
	})
})
+30
View File
@@ -0,0 +1,30 @@
package runner
import (
"time"
)
var (
	// DefaultInterruptDuration is the default value for the `WithInterruptDuration`
	// for the "regular" runners. This global value can be adjusted if needed.
	DefaultInterruptDuration = 10 * time.Second

	// DefaultGroupInterruptDuration is the default value for the `WithInterruptDuration`
	// for the group runners. This global value can be adjusted if needed.
	DefaultGroupInterruptDuration = 15 * time.Second
)
// Option defines a single option function.
type Option func(o *Options)
// Options defines the available options for this package.
type Options struct {
InterruptDuration time.Duration
}
// WithInterruptDuration provides a function to set the interrupt
// duration option.
func WithInterruptDuration(val time.Duration) Option {
return func(o *Options) {
o.InterruptDuration = val
}
}
+201
View File
@@ -0,0 +1,201 @@
package runner
import (
"context"
"sync/atomic"
"time"
)
// Runner represents the one executing a long running task, such as a server
// or a service.
// The ID of the runner is public to make identification easier, and the
// Result that it will generated will contain the same ID, so we can
// know which runner provided which result.
//
// Runners are intended to be used only once. Reusing them isn't possible.
// You'd need to create a new runner if you want to rerun the same task.
type Runner struct {
	ID            string             // public identifier; also carried in the generated Result
	interruptDur  time.Duration      // timeout started when the runner is interrupted
	fn            Runable            // the long running task to execute
	interrupt     Stopper            // function used to ask the task to finish
	running       atomic.Bool        // guards against running the same runner twice
	interrupted   atomic.Bool        // ensures Interrupt only acts once
	interruptedCh chan time.Duration // receives interruptDur when the interrupt timeout fires, then is closed
	finished      chan struct{}      // closed when the task function has returned
}
// New will create a new runner with the provided id (the id must be unique,
// otherwise undefined behavior might occur). The runner will run the
// provided runable task and use the "interrupt" function to ask that task
// to stop when needed.
//
// The interrupt duration can be set through the `WithInterruptDuration`
// option and defaults to 10 secs. It bounds how long the runner waits once
// it gets interrupted (either via the `Run` context being done or via the
// `Interrupt` method): when the timeout is reached, a timeout result is
// returned instead of whatever result the task should be returning.
//
// Note that it's your responsibility to provide a proper stopper for the
// task. The runner will just call it assuming it will be enough to
// eventually stop the task at some point.
func New(id string, fn Runable, interrupt Stopper, opts ...Option) *Runner {
	options := Options{InterruptDuration: DefaultInterruptDuration}
	for _, apply := range opts {
		apply(&options)
	}
	r := &Runner{
		ID:            id,
		interruptDur:  options.InterruptDuration,
		fn:            fn,
		interrupt:     interrupt,
		interruptedCh: make(chan time.Duration, 1),
		finished:      make(chan struct{}),
	}
	return r
}
// Run will execute the task associated to this runner in a synchronous way.
// The task will be spawned in a new goroutine, and the current thread will
// wait until the task finishes.
//
// The task will finish "naturally". The stopper will be called in the
// following ways:
// - Manually calling this runner's `Interrupt` method
// - When the provided context is done
// As said, it's expected that calling the provided stopper will be enough to
// make the task to eventually complete.
//
// Once the task finishes, the result will be returned.
// When the context is done, or if the runner is interrupted, a timeout will
// start using the provided "interrupt duration". If this timeout is reached,
// a timeout result will be returned instead of the one from the task. This is
// intended to prevent blocking the main thread indefinitely. A suitable
// duration should be used depending on the task, usually 5, 10 or 30 secs
//
// Some nice things you can do:
// - Use signal.NotifyContext(...) to call the stopper and provide a clean
// shutdown procedure when an OS signal is received
// - Use context.WithDeadline(...) or context.WithTimeout(...) to run the task
// for a limited time
func (r *Runner) Run(ctx context.Context) *Result {
	if !r.running.CompareAndSwap(false, true) {
		// If not swapped, the task is already running.
		// Running the same task multiple times is a bug, so we panic
		panic("Runner with id " + r.ID + " was running twice")
	}
	ch := make(chan *Result)
	go r.doTask(ch, true)
	// either the task delivers a result, or the context is done and we
	// interrupt the task and then wait for its (possibly timeout) result
	select {
	case result := <-ch:
		return result
	case <-ctx.Done():
		r.Interrupt()
		return <-ch
	}
}
// RunAsync will execute the task associated to this runner asynchronously.
// The task is spawned in a new goroutine and this method returns right away.
// The task's result will be written in the provided channel when it's
// available, so you can wait for it if needed. It's up to you to decide
// to use a blocking or non-blocking channel, but the task will always finish
// before writing in the channel.
//
// To interrupt the running task, the only option is to call the `Interrupt`
// method at some point.
func (r *Runner) RunAsync(ch chan<- *Result) {
	if swapped := r.running.CompareAndSwap(false, true); !swapped {
		// already running; running the same task multiple times is a bug
		panic("Runner with id " + r.ID + " was running twice")
	}
	go r.doTask(ch, false)
}
// Interrupt will execute the stopper function, which should notify the task
// in order for it to finish.
// The stopper will be called immediately, although it's expected the
// consequences to take a while (task might need a while to stop)
// A timeout will start using the provided "interrupt duration". Once that
// timeout is reached, the task must provide a result with a timeout error.
// Note that, even after returning the timeout result, the task could still
// be being executed and consuming resource.
// This method will be called only once. Further calls won't do anything
func (r *Runner) Interrupt() {
	// CompareAndSwap guarantees the stopper and the timeout goroutine
	// are started only on the first call
	if r.interrupted.CompareAndSwap(false, true) {
		go func() {
			select {
			case <-r.Finished():
				// Task finished -> runner should be delivering the result
			case <-time.After(r.interruptDur):
				// timeout reached -> send it through the channel so our runner
				// can abort
				r.interruptedCh <- r.interruptDur
				close(r.interruptedCh)
			}
		}()
		r.interrupt()
	}
}
// Finished will return a receive-only channel that can be used to know when
// the task has finished but the result hasn't been made available yet. The
// channel will be closed (without sending any message) when the task has finished.
// This can be used specially with the `RunAsync` method when multiple runners
// use the same channel: results could be waiting on your side of the channel
func (r *Runner) Finished() <-chan struct{} {
	return r.finished
}
// doTask will perform this runner's task and write the result in the provided
// channel. The channel will be closed if requested.
// A result will be provided when either the task finishes naturally or we
// reach the timeout after being interrupted
func (r *Runner) doTask(ch chan<- *Result, closeChan bool) {
	tmpCh := make(chan *Result, 1)
	// spawn the task and return the result in a temporary channel
	go func(tmpCh chan *Result) {
		err := r.fn()
		// closing `finished` before delivering the result lets Interrupt's
		// timeout goroutine know the task has completed
		close(r.finished)
		result := &Result{
			RunnerID:    r.ID,
			RunnerError: err,
		}
		tmpCh <- result
		close(tmpCh)
	}(tmpCh)
	// wait for the result in the temporary channel or until we get the
	// interrupted signal
	var result *Result
	select {
	case d := <-r.interruptedCh:
		// interrupt timeout fired before the task finished
		result = &Result{
			RunnerID:    r.ID,
			RunnerError: NewTimeoutError(r.ID, d),
		}
	case result = <-tmpCh:
		// Just assign the received value, nothing else to do
	}
	// send the result
	ch <- result
	if closeChan {
		close(ch)
	}
}

Some files were not shown because too many files have changed in this diff Show More