mirror of
https://github.com/opencloud-eu/opencloud.git
synced 2026-05-23 21:38:43 -05:00
Add 'proxy/' from commit '201b9a652685cdfb72ba81c7e7b00ba1c60a0e35'
git-subtree-dir: proxy git-subtree-mainline:571d96e856git-subtree-split:201b9a6526
This commit is contained in:
Vendored
+101
@@ -0,0 +1,101 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Entry represents an entry on the cache. You can type assert on V.
|
||||
type Entry struct {
|
||||
V interface{}
|
||||
Valid bool
|
||||
}
|
||||
|
||||
// Cache is a barebones cache implementation.
|
||||
type Cache struct {
|
||||
entries map[string]map[string]Entry
|
||||
size int
|
||||
m sync.Mutex
|
||||
}
|
||||
|
||||
// NewCache returns a new instance of Cache.
|
||||
func NewCache(o ...Option) Cache {
|
||||
opts := newOptions(o...)
|
||||
|
||||
return Cache{
|
||||
size: opts.size,
|
||||
entries: map[string]map[string]Entry{},
|
||||
}
|
||||
}
|
||||
|
||||
// Get gets an entry on a service `svcKey` by a give `key`.
|
||||
func (c *Cache) Get(svcKey, key string) (*Entry, error) {
|
||||
var value Entry
|
||||
ok := true
|
||||
|
||||
c.m.Lock()
|
||||
defer c.m.Unlock()
|
||||
|
||||
if value, ok = c.entries[svcKey][key]; !ok {
|
||||
return nil, fmt.Errorf("invalid service key: `%v`", key)
|
||||
}
|
||||
|
||||
return &value, nil
|
||||
}
|
||||
|
||||
// Set sets a key / value. It lets a service add entries on a request basis.
|
||||
func (c *Cache) Set(svcKey, key string, val interface{}) error {
|
||||
c.m.Lock()
|
||||
defer c.m.Unlock()
|
||||
|
||||
if !c.fits() {
|
||||
return fmt.Errorf("cache is full")
|
||||
}
|
||||
|
||||
if _, ok := c.entries[svcKey]; !ok {
|
||||
c.entries[svcKey] = map[string]Entry{}
|
||||
}
|
||||
|
||||
if _, ok := c.entries[svcKey][key]; ok {
|
||||
return fmt.Errorf("key `%v` already exists", key)
|
||||
}
|
||||
|
||||
c.entries[svcKey][key] = Entry{
|
||||
V: val,
|
||||
Valid: true,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Invalidate invalidates a cache Entry by key.
|
||||
func (c *Cache) Invalidate(svcKey, key string) error {
|
||||
r, err := c.Get(svcKey, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.Valid = false
|
||||
c.entries[svcKey][key] = *r
|
||||
return nil
|
||||
}
|
||||
|
||||
// Evict frees memory from the cache by removing invalid keys. It is a noop.
|
||||
func (c *Cache) Evict() {
|
||||
for _, v := range c.entries {
|
||||
for k, svcEntry := range v {
|
||||
if !svcEntry.Valid {
|
||||
delete(v, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Length returns the amount of entries per service key.
|
||||
func (c *Cache) Length(k string) int {
|
||||
return len(c.entries[k])
|
||||
}
|
||||
|
||||
func (c *Cache) fits() bool {
|
||||
return c.size >= len(c.entries)
|
||||
}
|
||||
Vendored
+99
@@ -0,0 +1,99 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Prevents from invalid import cycle.
|
||||
type AccountsCacheEntry struct {
|
||||
Email string
|
||||
UUID string
|
||||
}
|
||||
|
||||
func TestSet(t *testing.T) {
|
||||
c := NewCache(
|
||||
Size(256),
|
||||
)
|
||||
|
||||
err := c.Set("accounts", "hello@foo.bar", AccountsCacheEntry{
|
||||
Email: "hello@foo.bar",
|
||||
UUID: "9c31b040-59e2-4a2b-926b-334d9e3fbd05",
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if c.Length("accounts") != 1 {
|
||||
t.Errorf("expected length 1 got `%v`", len(c.entries))
|
||||
}
|
||||
|
||||
item, err := c.Get("accounts", "hello@foo.bar")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if cachedEntry, ok := item.V.(AccountsCacheEntry); !ok {
|
||||
t.Errorf("invalid cached value type")
|
||||
} else {
|
||||
if cachedEntry.Email != "hello@foo.bar" {
|
||||
t.Errorf("invalid value. Expected `hello@foo.bar` got: `%v`", cachedEntry.Email)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvict(t *testing.T) {
|
||||
c := NewCache(
|
||||
Size(256),
|
||||
)
|
||||
|
||||
if err := c.Set("accounts", "hello@foo.bar", AccountsCacheEntry{
|
||||
Email: "hello@foo.bar",
|
||||
UUID: "9c31b040-59e2-4a2b-926b-334d9e3fbd05",
|
||||
}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if err := c.Invalidate("accounts", "hello@foo.bar"); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
v, err := c.Get("accounts", "hello@foo.bar")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if v.Valid {
|
||||
t.Errorf("cache key unexpected valid state")
|
||||
}
|
||||
|
||||
c.Evict()
|
||||
|
||||
if c.Length("accounts") != 0 {
|
||||
t.Errorf("expected length 0 got `%v`", len(c.entries))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGet(t *testing.T) {
|
||||
svcCache := NewCache(
|
||||
Size(256),
|
||||
)
|
||||
|
||||
err := svcCache.Set("accounts", "node", "0.0.0.0:1234")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
raw, err := svcCache.Get("accounts", "node")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
v, ok := raw.V.(string)
|
||||
if !ok {
|
||||
t.Errorf("invalid type on service node key")
|
||||
}
|
||||
|
||||
if v != "0.0.0.0:1234" {
|
||||
t.Errorf("expected `0.0.0.0:1234` got `%v`", v)
|
||||
}
|
||||
}
|
||||
Vendored
+36
@@ -0,0 +1,36 @@
|
||||
package cache
|
||||
|
||||
import "time"
|
||||
|
||||
// Options are all the possible options.
|
||||
type Options struct {
|
||||
size int
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
// Option mutates option
|
||||
type Option func(*Options)
|
||||
|
||||
// Size configures the size of the cache in items.
|
||||
func Size(s int) Option {
|
||||
return func(o *Options) {
|
||||
o.size = s
|
||||
}
|
||||
}
|
||||
|
||||
// TTL rebuilds the cache after the configured duration.
|
||||
func TTL(ttl time.Duration) Option {
|
||||
return func(o *Options) {
|
||||
o.ttl = ttl
|
||||
}
|
||||
}
|
||||
|
||||
func newOptions(opts ...Option) Options {
|
||||
o := Options{}
|
||||
|
||||
for _, v := range opts {
|
||||
v(&o)
|
||||
}
|
||||
|
||||
return o
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/micro/cli/v2"
|
||||
"github.com/owncloud/ocis-proxy/pkg/config"
|
||||
"github.com/owncloud/ocis-proxy/pkg/flagset"
|
||||
)
|
||||
|
||||
// Health is the entrypoint for the health command.
|
||||
func Health(cfg *config.Config) *cli.Command {
|
||||
return &cli.Command{
|
||||
Name: "health",
|
||||
Usage: "Check health status",
|
||||
Flags: flagset.HealthWithConfig(cfg),
|
||||
Action: func(c *cli.Context) error {
|
||||
logger := NewLogger(cfg)
|
||||
|
||||
resp, err := http.Get(
|
||||
fmt.Sprintf(
|
||||
"http://%s/healthz",
|
||||
cfg.Debug.Addr,
|
||||
),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
logger.Fatal().
|
||||
Err(err).
|
||||
Msg("Failed to request health check")
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
logger.Fatal().
|
||||
Int("code", resp.StatusCode).
|
||||
Msg("Health seems to be in bad state")
|
||||
}
|
||||
|
||||
logger.Debug().
|
||||
Int("code", resp.StatusCode).
|
||||
Msg("Health got a good state")
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,108 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/micro/cli/v2"
|
||||
"github.com/owncloud/ocis-pkg/v2/log"
|
||||
"github.com/owncloud/ocis-proxy/pkg/config"
|
||||
"github.com/owncloud/ocis-proxy/pkg/flagset"
|
||||
"github.com/owncloud/ocis-proxy/pkg/version"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// Execute is the entry point for the ocis-proxy command.
|
||||
func Execute() error {
|
||||
cfg := config.New()
|
||||
|
||||
app := &cli.App{
|
||||
Name: "ocis-proxy",
|
||||
Version: version.String,
|
||||
Usage: "proxy for Reva/oCIS",
|
||||
Compiled: version.Compiled(),
|
||||
|
||||
Authors: []*cli.Author{
|
||||
{
|
||||
Name: "ownCloud GmbH",
|
||||
Email: "support@owncloud.com",
|
||||
},
|
||||
},
|
||||
|
||||
Flags: flagset.RootWithConfig(cfg),
|
||||
|
||||
Before: func(c *cli.Context) error {
|
||||
return ParseConfig(c, cfg)
|
||||
},
|
||||
|
||||
Commands: []*cli.Command{
|
||||
Server(cfg),
|
||||
Health(cfg),
|
||||
},
|
||||
}
|
||||
|
||||
cli.HelpFlag = &cli.BoolFlag{
|
||||
Name: "help,h",
|
||||
Usage: "Show the help",
|
||||
}
|
||||
|
||||
cli.VersionFlag = &cli.BoolFlag{
|
||||
Name: "version,v",
|
||||
Usage: "Print the version",
|
||||
}
|
||||
|
||||
return app.Run(os.Args)
|
||||
}
|
||||
|
||||
// NewLogger initializes a service-specific logger instance.
|
||||
func NewLogger(cfg *config.Config) log.Logger {
|
||||
return log.NewLogger(
|
||||
log.Name("proxy"),
|
||||
log.Level(cfg.Log.Level),
|
||||
log.Pretty(cfg.Log.Pretty),
|
||||
log.Color(cfg.Log.Color),
|
||||
)
|
||||
}
|
||||
|
||||
// ParseConfig loads proxy configuration from Viper known paths.
|
||||
func ParseConfig(c *cli.Context, cfg *config.Config) error {
|
||||
logger := NewLogger(cfg)
|
||||
|
||||
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||||
viper.SetEnvPrefix("PROXY")
|
||||
viper.AutomaticEnv()
|
||||
|
||||
if c.IsSet("config-file") {
|
||||
viper.SetConfigFile(c.String("config-file"))
|
||||
} else {
|
||||
viper.SetConfigName("proxy")
|
||||
|
||||
viper.AddConfigPath("/etc/ocis")
|
||||
viper.AddConfigPath("$HOME/.ocis")
|
||||
viper.AddConfigPath("./config")
|
||||
}
|
||||
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
switch err.(type) {
|
||||
case viper.ConfigFileNotFoundError:
|
||||
logger.Info().
|
||||
Msg("Continue without config")
|
||||
case viper.UnsupportedConfigError:
|
||||
logger.Fatal().
|
||||
Err(err).
|
||||
Msg("Unsupported config type")
|
||||
default:
|
||||
logger.Fatal().
|
||||
Err(err).
|
||||
Msg("Failed to read config")
|
||||
}
|
||||
}
|
||||
|
||||
if err := viper.Unmarshal(&cfg); err != nil {
|
||||
logger.Fatal().
|
||||
Err(err).
|
||||
Msg("Failed to parse config")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,319 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"contrib.go.opencensus.io/exporter/jaeger"
|
||||
"contrib.go.opencensus.io/exporter/ocagent"
|
||||
"contrib.go.opencensus.io/exporter/zipkin"
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/justinas/alice"
|
||||
"github.com/micro/cli/v2"
|
||||
mclient "github.com/micro/go-micro/v2/client"
|
||||
"github.com/micro/go-micro/v2/client/grpc"
|
||||
"github.com/oklog/run"
|
||||
openzipkin "github.com/openzipkin/zipkin-go"
|
||||
zipkinhttp "github.com/openzipkin/zipkin-go/reporter/http"
|
||||
acc "github.com/owncloud/ocis-accounts/pkg/proto/v0"
|
||||
"github.com/owncloud/ocis-pkg/v2/log"
|
||||
"github.com/owncloud/ocis-proxy/pkg/config"
|
||||
"github.com/owncloud/ocis-proxy/pkg/cs3"
|
||||
"github.com/owncloud/ocis-proxy/pkg/flagset"
|
||||
"github.com/owncloud/ocis-proxy/pkg/metrics"
|
||||
"github.com/owncloud/ocis-proxy/pkg/middleware"
|
||||
"github.com/owncloud/ocis-proxy/pkg/proxy"
|
||||
"github.com/owncloud/ocis-proxy/pkg/server/debug"
|
||||
proxyHTTP "github.com/owncloud/ocis-proxy/pkg/server/http"
|
||||
settings "github.com/owncloud/ocis-settings/pkg/proto/v0"
|
||||
storepb "github.com/owncloud/ocis-store/pkg/proto/v0"
|
||||
"go.opencensus.io/stats/view"
|
||||
"go.opencensus.io/trace"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// Server is the entrypoint for the server command.
|
||||
func Server(cfg *config.Config) *cli.Command {
|
||||
return &cli.Command{
|
||||
Name: "server",
|
||||
Usage: "Start integrated server",
|
||||
Flags: flagset.ServerWithConfig(cfg),
|
||||
Before: func(ctx *cli.Context) error {
|
||||
l := NewLogger(cfg)
|
||||
l.Debug().Str("tracing", strconv.FormatBool(cfg.Tracing.Enabled)).Msg("init: before")
|
||||
if cfg.HTTP.Root != "/" {
|
||||
cfg.HTTP.Root = strings.TrimSuffix(cfg.HTTP.Root, "/")
|
||||
}
|
||||
cfg.PreSignedURL.AllowedHTTPMethods = ctx.StringSlice("presignedurl-allow-method")
|
||||
|
||||
// When running on single binary mode the before hook from the root command won't get called. We manually
|
||||
// call this before hook from ocis command, so the configuration can be loaded.
|
||||
return ParseConfig(ctx, cfg)
|
||||
},
|
||||
Action: func(c *cli.Context) error {
|
||||
logger := NewLogger(cfg)
|
||||
httpNamespace := c.String("http-namespace")
|
||||
|
||||
if cfg.Tracing.Enabled {
|
||||
switch t := cfg.Tracing.Type; t {
|
||||
case "agent":
|
||||
exporter, err := ocagent.NewExporter(
|
||||
ocagent.WithReconnectionPeriod(5*time.Second),
|
||||
ocagent.WithAddress(cfg.Tracing.Endpoint),
|
||||
ocagent.WithServiceName(cfg.Tracing.Service),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
logger.Error().
|
||||
Err(err).
|
||||
Str("endpoint", cfg.Tracing.Endpoint).
|
||||
Str("collector", cfg.Tracing.Collector).
|
||||
Msg("Failed to create agent tracing")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
trace.RegisterExporter(exporter)
|
||||
view.RegisterExporter(exporter)
|
||||
|
||||
case "jaeger":
|
||||
exporter, err := jaeger.NewExporter(
|
||||
jaeger.Options{
|
||||
AgentEndpoint: cfg.Tracing.Endpoint,
|
||||
CollectorEndpoint: cfg.Tracing.Collector,
|
||||
ServiceName: cfg.Tracing.Service,
|
||||
},
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
logger.Error().
|
||||
Err(err).
|
||||
Str("endpoint", cfg.Tracing.Endpoint).
|
||||
Str("collector", cfg.Tracing.Collector).
|
||||
Msg("Failed to create jaeger tracing")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
trace.RegisterExporter(exporter)
|
||||
|
||||
case "zipkin":
|
||||
endpoint, err := openzipkin.NewEndpoint(
|
||||
cfg.Tracing.Service,
|
||||
cfg.Tracing.Endpoint,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
logger.Error().
|
||||
Err(err).
|
||||
Str("endpoint", cfg.Tracing.Endpoint).
|
||||
Str("collector", cfg.Tracing.Collector).
|
||||
Msg("Failed to create zipkin tracing")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
exporter := zipkin.NewExporter(
|
||||
zipkinhttp.NewReporter(
|
||||
cfg.Tracing.Collector,
|
||||
),
|
||||
endpoint,
|
||||
)
|
||||
|
||||
trace.RegisterExporter(exporter)
|
||||
|
||||
default:
|
||||
logger.Warn().
|
||||
Str("type", t).
|
||||
Msg("Unknown tracing backend")
|
||||
}
|
||||
|
||||
trace.ApplyConfig(
|
||||
trace.Config{
|
||||
DefaultSampler: trace.AlwaysSample(),
|
||||
},
|
||||
)
|
||||
} else {
|
||||
logger.Debug().
|
||||
Msg("Tracing is not enabled")
|
||||
}
|
||||
|
||||
var (
|
||||
gr = run.Group{}
|
||||
ctx, cancel = context.WithCancel(context.Background())
|
||||
metrics = metrics.New()
|
||||
)
|
||||
|
||||
defer cancel()
|
||||
|
||||
rp := proxy.NewMultiHostReverseProxy(
|
||||
proxy.Logger(logger),
|
||||
proxy.Config(cfg),
|
||||
)
|
||||
|
||||
{
|
||||
server, err := proxyHTTP.Server(
|
||||
proxyHTTP.Handler(rp),
|
||||
proxyHTTP.Logger(logger),
|
||||
proxyHTTP.Namespace(httpNamespace),
|
||||
proxyHTTP.Context(ctx),
|
||||
proxyHTTP.Config(cfg),
|
||||
proxyHTTP.Metrics(metrics),
|
||||
proxyHTTP.Flags(flagset.RootWithConfig(config.New())),
|
||||
proxyHTTP.Flags(flagset.ServerWithConfig(config.New())),
|
||||
proxyHTTP.Middlewares(loadMiddlewares(ctx, logger, cfg)),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
logger.Error().
|
||||
Err(err).
|
||||
Str("server", "http").
|
||||
Msg("Failed to initialize server")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
gr.Add(func() error {
|
||||
return server.Run()
|
||||
}, func(_ error) {
|
||||
logger.Info().
|
||||
Str("server", "http").
|
||||
Msg("Shutting down server")
|
||||
|
||||
cancel()
|
||||
})
|
||||
}
|
||||
|
||||
{
|
||||
server, err := debug.Server(
|
||||
debug.Logger(logger),
|
||||
debug.Context(ctx),
|
||||
debug.Config(cfg),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
logger.Error().
|
||||
Err(err).
|
||||
Str("server", "debug").
|
||||
Msg("Failed to initialize server")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
gr.Add(func() error {
|
||||
return server.ListenAndServe()
|
||||
}, func(_ error) {
|
||||
ctx, timeout := context.WithTimeout(ctx, 5*time.Second)
|
||||
|
||||
defer timeout()
|
||||
defer cancel()
|
||||
|
||||
if err := server.Shutdown(ctx); err != nil {
|
||||
logger.Error().
|
||||
Err(err).
|
||||
Str("server", "debug").
|
||||
Msg("Failed to shutdown server")
|
||||
} else {
|
||||
logger.Info().
|
||||
Str("server", "debug").
|
||||
Msg("Shutting down server")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
{
|
||||
stop := make(chan os.Signal, 1)
|
||||
|
||||
gr.Add(func() error {
|
||||
signal.Notify(stop, os.Interrupt)
|
||||
|
||||
<-stop
|
||||
|
||||
return nil
|
||||
}, func(err error) {
|
||||
close(stop)
|
||||
cancel()
|
||||
})
|
||||
}
|
||||
|
||||
return gr.Run()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func loadMiddlewares(ctx context.Context, l log.Logger, cfg *config.Config) alice.Chain {
|
||||
|
||||
psMW := middleware.PresignedURL(
|
||||
middleware.Logger(l),
|
||||
middleware.Store(storepb.NewStoreService("com.owncloud.api.store", grpc.NewClient())),
|
||||
middleware.PreSignedURLConfig(cfg.PreSignedURL),
|
||||
)
|
||||
|
||||
// TODO this won't work with a registry other than mdns. Look into Micro's client initialization.
|
||||
// https://github.com/owncloud/ocis-proxy/issues/38
|
||||
accounts := acc.NewAccountsService("com.owncloud.api.accounts", mclient.DefaultClient)
|
||||
roles := settings.NewRoleService("com.owncloud.api.settings", mclient.DefaultClient)
|
||||
|
||||
uuidMW := middleware.AccountUUID(
|
||||
middleware.Logger(l),
|
||||
middleware.TokenManagerConfig(cfg.TokenManager),
|
||||
middleware.AccountsClient(accounts),
|
||||
middleware.SettingsRoleService(roles),
|
||||
)
|
||||
|
||||
// the connection will be established in a non blocking fashion
|
||||
sc, err := cs3.GetGatewayServiceClient(cfg.Reva.Address)
|
||||
if err != nil {
|
||||
l.Error().Err(err).
|
||||
Str("gateway", cfg.Reva.Address).
|
||||
Msg("Failed to create reva gateway service client")
|
||||
}
|
||||
|
||||
chMW := middleware.CreateHome(
|
||||
middleware.Logger(l),
|
||||
middleware.RevaGatewayClient(sc),
|
||||
middleware.AccountsClient(accounts),
|
||||
middleware.TokenManagerConfig(cfg.TokenManager),
|
||||
)
|
||||
|
||||
if cfg.OIDC.Issuer != "" {
|
||||
l.Info().Msg("Loading OIDC-Middleware")
|
||||
l.Debug().Interface("oidc_config", cfg.OIDC).Msg("OIDC-Config")
|
||||
|
||||
var oidcHTTPClient = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: cfg.OIDC.Insecure,
|
||||
},
|
||||
DisableKeepAlives: true,
|
||||
},
|
||||
Timeout: time.Second * 10,
|
||||
}
|
||||
|
||||
customCtx := context.WithValue(ctx, oauth2.HTTPClient, oidcHTTPClient)
|
||||
|
||||
// Initialize a provider by specifying the issuer URL.
|
||||
// it will fetch the keys from the issuer using the .well-known
|
||||
// endpoint
|
||||
provider := func() (middleware.OIDCProvider, error) {
|
||||
return oidc.NewProvider(customCtx, cfg.OIDC.Issuer)
|
||||
}
|
||||
|
||||
oidcMW := middleware.OpenIDConnect(
|
||||
middleware.Logger(l),
|
||||
middleware.HTTPClient(oidcHTTPClient),
|
||||
middleware.OIDCProviderFunc(provider),
|
||||
middleware.OIDCIss(cfg.OIDC.Issuer),
|
||||
)
|
||||
|
||||
return alice.New(middleware.RedirectToHTTPS, oidcMW, psMW, uuidMW, chMW)
|
||||
}
|
||||
|
||||
return alice.New(middleware.RedirectToHTTPS, psMW, uuidMW, chMW)
|
||||
}
|
||||
@@ -0,0 +1,134 @@
|
||||
package config
|
||||
|
||||
// Log defines the available logging configuration.
|
||||
type Log struct {
|
||||
Level string
|
||||
Pretty bool
|
||||
Color bool
|
||||
}
|
||||
|
||||
// Debug defines the available debug configuration.
|
||||
type Debug struct {
|
||||
Addr string
|
||||
Token string
|
||||
Pprof bool
|
||||
Zpages bool
|
||||
}
|
||||
|
||||
// HTTP defines the available http configuration.
|
||||
type HTTP struct {
|
||||
Addr string
|
||||
Namespace string
|
||||
Root string
|
||||
TLSCert string
|
||||
TLSKey string
|
||||
TLS bool
|
||||
}
|
||||
|
||||
// Tracing defines the available tracing configuration.
|
||||
type Tracing struct {
|
||||
Enabled bool
|
||||
Type string
|
||||
Endpoint string
|
||||
Collector string
|
||||
Service string
|
||||
}
|
||||
|
||||
// Asset defines the available asset configuration.
|
||||
type Asset struct {
|
||||
Path string
|
||||
}
|
||||
|
||||
// Policy enables us to use multiple directors.
|
||||
type Policy struct {
|
||||
Name string
|
||||
Routes []Route
|
||||
}
|
||||
|
||||
// Route define forwarding routes
|
||||
type Route struct {
|
||||
Type RouteType
|
||||
Endpoint string
|
||||
Backend string
|
||||
ApacheVHost bool `mapstructure:"apache-vhost"`
|
||||
}
|
||||
|
||||
// RouteType defines the type of a route
|
||||
type RouteType string
|
||||
|
||||
const (
|
||||
// PrefixRoute are routes matched by a prefix
|
||||
PrefixRoute RouteType = "prefix"
|
||||
// QueryRoute are routes machted by a prefix and query parameters
|
||||
QueryRoute RouteType = "query"
|
||||
// RegexRoute are routes matched by a pattern
|
||||
RegexRoute RouteType = "regex"
|
||||
// DefaultRouteType is the PrefixRoute
|
||||
DefaultRouteType RouteType = PrefixRoute
|
||||
)
|
||||
|
||||
var (
|
||||
// RouteTypes is an array of the available route types
|
||||
RouteTypes []RouteType = []RouteType{QueryRoute, RegexRoute, PrefixRoute}
|
||||
)
|
||||
|
||||
// Reva defines all available REVA configuration.
|
||||
type Reva struct {
|
||||
Address string
|
||||
}
|
||||
|
||||
// Config combines all available configuration parts.
|
||||
type Config struct {
|
||||
File string
|
||||
Log Log
|
||||
Debug Debug
|
||||
HTTP HTTP
|
||||
Tracing Tracing
|
||||
Asset Asset
|
||||
Policies []Policy
|
||||
OIDC OIDC
|
||||
TokenManager TokenManager
|
||||
PolicySelector *PolicySelector `mapstructure:"policy_selector"`
|
||||
Reva Reva
|
||||
PreSignedURL PreSignedURL
|
||||
}
|
||||
|
||||
// OIDC is the config for the OpenID-Connect middleware. If set the proxy will try to authenticate every request
|
||||
// with the configured oidc-provider
|
||||
type OIDC struct {
|
||||
Issuer string
|
||||
Insecure bool
|
||||
}
|
||||
|
||||
// PolicySelector is the toplevel-configuration for different selectors
|
||||
type PolicySelector struct {
|
||||
Static *StaticSelectorConf
|
||||
Migration *MigrationSelectorConf
|
||||
}
|
||||
|
||||
// StaticSelectorConf is the config for the static-policy-selector
|
||||
type StaticSelectorConf struct {
|
||||
Policy string
|
||||
}
|
||||
|
||||
// TokenManager is the config for using the reva token manager
|
||||
type TokenManager struct {
|
||||
JWTSecret string
|
||||
}
|
||||
|
||||
// PreSignedURL is the config for the presigned url middleware
|
||||
type PreSignedURL struct {
|
||||
AllowedHTTPMethods []string
|
||||
}
|
||||
|
||||
// MigrationSelectorConf is the config for the migration-selector
|
||||
type MigrationSelectorConf struct {
|
||||
AccFoundPolicy string `mapstructure:"acc_found_policy"`
|
||||
AccNotFoundPolicy string `mapstructure:"acc_not_found_policy"`
|
||||
UnauthenticatedPolicy string `mapstructure:"unauthenticated_policy"`
|
||||
}
|
||||
|
||||
// New initializes a new configuration
|
||||
func New() *Config {
|
||||
return &Config{}
|
||||
}
|
||||
@@ -0,0 +1,123 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/owncloud/ocis-pkg/v2/log"
|
||||
)
|
||||
|
||||
func publicKey(priv interface{}) interface{} {
|
||||
switch k := priv.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
return &k.PublicKey
|
||||
case *ecdsa.PrivateKey:
|
||||
return &k.PublicKey
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func pemBlockForKey(priv interface{}, l log.Logger) *pem.Block {
|
||||
switch k := priv.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(k)}
|
||||
case *ecdsa.PrivateKey:
|
||||
b, err := x509.MarshalECPrivateKey(k)
|
||||
if err != nil {
|
||||
l.Fatal().Err(err).Msg("Unable to marshal ECDSA private key")
|
||||
}
|
||||
return &pem.Block{Type: "EC PRIVATE KEY", Bytes: b}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// GenCert generates TLS-Certificates
|
||||
func GenCert(l log.Logger) error {
|
||||
var priv interface{}
|
||||
var err error
|
||||
|
||||
priv, err = rsa.GenerateKey(rand.Reader, 2048)
|
||||
|
||||
if err != nil {
|
||||
l.Fatal().Err(err).Msg("Failed to generate private key")
|
||||
}
|
||||
|
||||
notBefore := time.Now()
|
||||
notAfter := notBefore.Add(24 * time.Hour * 365)
|
||||
|
||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||
if err != nil {
|
||||
l.Fatal().Err(err).Msg("Failed to generate serial number")
|
||||
}
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Acme Corp"},
|
||||
CommonName: "OCIS",
|
||||
},
|
||||
NotBefore: notBefore,
|
||||
NotAfter: notAfter,
|
||||
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
hosts := []string{"127.0.0.1", "localhost"}
|
||||
for _, h := range hosts {
|
||||
if ip := net.ParseIP(h); ip != nil {
|
||||
template.IPAddresses = append(template.IPAddresses, ip)
|
||||
} else {
|
||||
template.DNSNames = append(template.DNSNames, h)
|
||||
}
|
||||
}
|
||||
|
||||
//template.IsCA = true
|
||||
//template.KeyUsage |= x509.KeyUsageCertSign
|
||||
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(priv), priv)
|
||||
if err != nil {
|
||||
l.Fatal().Err(err).Msg("Failed to create certificate")
|
||||
}
|
||||
|
||||
certOut, err := os.Create("server.crt")
|
||||
if err != nil {
|
||||
l.Fatal().Err(err).Msg("Failed to open server.crt for writing")
|
||||
}
|
||||
err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
||||
if err != nil {
|
||||
l.Fatal().Err(err).Msg("Failed to encode certificate")
|
||||
}
|
||||
err = certOut.Close()
|
||||
if err != nil {
|
||||
l.Fatal().Err(err).Msg("Failed to write cert")
|
||||
}
|
||||
l.Info().Msg("Written server.crt")
|
||||
|
||||
keyOut, err := os.OpenFile("server.key", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
l.Fatal().Err(err).Msg("Failed to open server.key for writing")
|
||||
}
|
||||
err = pem.Encode(keyOut, pemBlockForKey(priv, l))
|
||||
if err != nil {
|
||||
l.Fatal().Err(err).Msg("Failed to encode key")
|
||||
}
|
||||
err = keyOut.Close()
|
||||
if err != nil {
|
||||
l.Fatal().Err(err).Msg("Failed to write key")
|
||||
}
|
||||
l.Info().Msg("Written server.key")
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package cs3
|
||||
|
||||
import (
|
||||
gateway "github.com/cs3org/go-cs3apis/cs3/gateway/v1beta1"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func newConn(endpoint string) (*grpc.ClientConn, error) {
|
||||
conn, err := grpc.Dial(endpoint, grpc.WithInsecure())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// GetGatewayServiceClient returns a new cs3 gateway client
|
||||
func GetGatewayServiceClient(endpoint string) (gateway.GatewayAPIClient, error) {
|
||||
conn, err := newConn(endpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return gateway.NewGatewayAPIClient(conn), nil
|
||||
}
|
||||
@@ -0,0 +1,206 @@
|
||||
package flagset
|
||||
|
||||
import (
|
||||
"github.com/micro/cli/v2"
|
||||
"github.com/owncloud/ocis-proxy/pkg/config"
|
||||
)
|
||||
|
||||
// RootWithConfig applies cfg to the root flagset
|
||||
func RootWithConfig(cfg *config.Config) []cli.Flag {
|
||||
return []cli.Flag{
|
||||
&cli.StringFlag{
|
||||
Name: "config-file",
|
||||
Value: "",
|
||||
Usage: "Path to config file",
|
||||
EnvVars: []string{"PROXY_CONFIG_FILE"},
|
||||
Destination: &cfg.File,
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "log-level",
|
||||
Value: "info",
|
||||
Usage: "Set logging level",
|
||||
EnvVars: []string{"PROXY_LOG_LEVEL"},
|
||||
Destination: &cfg.Log.Level,
|
||||
},
|
||||
&cli.BoolFlag{
|
||||
Name: "log-pretty",
|
||||
Value: true,
|
||||
Usage: "Enable pretty logging",
|
||||
EnvVars: []string{"PROXY_LOG_PRETTY"},
|
||||
Destination: &cfg.Log.Pretty,
|
||||
},
|
||||
&cli.BoolFlag{
|
||||
Name: "log-color",
|
||||
Value: true,
|
||||
Usage: "Enable colored logging",
|
||||
EnvVars: []string{"PROXY_LOG_COLOR"},
|
||||
Destination: &cfg.Log.Color,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// HealthWithConfig applies cfg to the root flagset
|
||||
func HealthWithConfig(cfg *config.Config) []cli.Flag {
|
||||
return []cli.Flag{
|
||||
&cli.StringFlag{
|
||||
Name: "debug-addr",
|
||||
Value: "0.0.0.0:9109",
|
||||
Usage: "Address to debug endpoint",
|
||||
EnvVars: []string{"PROXY_DEBUG_ADDR"},
|
||||
Destination: &cfg.Debug.Addr,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ServerWithConfig applies cfg to the root flagset
|
||||
func ServerWithConfig(cfg *config.Config) []cli.Flag {
|
||||
return []cli.Flag{
|
||||
&cli.BoolFlag{
|
||||
Name: "tracing-enabled",
|
||||
Usage: "Enable sending traces",
|
||||
EnvVars: []string{"PROXY_TRACING_ENABLED"},
|
||||
Destination: &cfg.Tracing.Enabled,
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "tracing-type",
|
||||
Value: "jaeger",
|
||||
Usage: "Tracing backend type",
|
||||
EnvVars: []string{"PROXY_TRACING_TYPE"},
|
||||
Destination: &cfg.Tracing.Type,
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "tracing-endpoint",
|
||||
Value: "",
|
||||
Usage: "Endpoint for the agent",
|
||||
EnvVars: []string{"PROXY_TRACING_ENDPOINT"},
|
||||
Destination: &cfg.Tracing.Endpoint,
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "tracing-collector",
|
||||
Value: "http://localhost:14268/api/traces",
|
||||
Usage: "Endpoint for the collector",
|
||||
EnvVars: []string{"PROXY_TRACING_COLLECTOR"},
|
||||
Destination: &cfg.Tracing.Collector,
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "tracing-service",
|
||||
Value: "proxy",
|
||||
Usage: "Service name for tracing",
|
||||
EnvVars: []string{"PROXY_TRACING_SERVICE"},
|
||||
Destination: &cfg.Tracing.Service,
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "debug-addr",
|
||||
Value: "0.0.0.0:9205",
|
||||
Usage: "Address to bind debug server",
|
||||
EnvVars: []string{"PROXY_DEBUG_ADDR"},
|
||||
Destination: &cfg.Debug.Addr,
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "debug-token",
|
||||
Value: "",
|
||||
Usage: "Token to grant metrics access",
|
||||
EnvVars: []string{"PROXY_DEBUG_TOKEN"},
|
||||
Destination: &cfg.Debug.Token,
|
||||
},
|
||||
&cli.BoolFlag{
|
||||
Name: "debug-pprof",
|
||||
Usage: "Enable pprof debugging",
|
||||
EnvVars: []string{"PROXY_DEBUG_PPROF"},
|
||||
Destination: &cfg.Debug.Pprof,
|
||||
},
|
||||
&cli.BoolFlag{
|
||||
Name: "debug-zpages",
|
||||
Usage: "Enable zpages debugging",
|
||||
EnvVars: []string{"PROXY_DEBUG_ZPAGES"},
|
||||
Destination: &cfg.Debug.Zpages,
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "http-addr",
|
||||
Value: "0.0.0.0:9200",
|
||||
Usage: "Address to bind http server",
|
||||
EnvVars: []string{"PROXY_HTTP_ADDR"},
|
||||
Destination: &cfg.HTTP.Addr,
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "http-root",
|
||||
Value: "/",
|
||||
Usage: "Root path of http server",
|
||||
EnvVars: []string{"PROXY_HTTP_ROOT"},
|
||||
Destination: &cfg.HTTP.Root,
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "asset-path",
|
||||
Value: "",
|
||||
Usage: "Path to custom assets",
|
||||
EnvVars: []string{"PROXY_ASSET_PATH"},
|
||||
Destination: &cfg.Asset.Path,
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "http-namespace",
|
||||
Value: "com.owncloud",
|
||||
Usage: "Set the base namespace for the http namespace",
|
||||
EnvVars: []string{"PROXY_HTTP_NAMESPACE"},
|
||||
Destination: &cfg.HTTP.Namespace,
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "transport-tls-cert",
|
||||
Value: "",
|
||||
Usage: "Certificate file for transport encryption",
|
||||
EnvVars: []string{"PROXY_TRANSPORT_TLS_CERT"},
|
||||
Destination: &cfg.HTTP.TLSCert,
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "transport-tls-key",
|
||||
Value: "",
|
||||
Usage: "Secret file for transport encryption",
|
||||
EnvVars: []string{"PROXY_TRANSPORT_TLS_KEY"},
|
||||
Destination: &cfg.HTTP.TLSKey,
|
||||
},
|
||||
&cli.BoolFlag{
|
||||
Name: "tls",
|
||||
Usage: "Use TLS (disable only if proxy is behind a TLS-terminating reverse-proxy).",
|
||||
EnvVars: []string{"PROXY_TLS"},
|
||||
Value: true,
|
||||
Destination: &cfg.HTTP.TLS,
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "jwt-secret",
|
||||
Value: "Pive-Fumkiu4",
|
||||
Usage: "Used to create JWT to talk to reva, should equal reva's jwt-secret",
|
||||
EnvVars: []string{"PROXY_JWT_SECRET"},
|
||||
Destination: &cfg.TokenManager.JWTSecret,
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "reva-gateway-addr",
|
||||
Value: "127.0.0.1:9142",
|
||||
Usage: "REVA Gateway Endpoint",
|
||||
EnvVars: []string{"PROXY_REVA_GATEWAY_ADDR"},
|
||||
Destination: &cfg.Reva.Address,
|
||||
},
|
||||
|
||||
// OIDC
|
||||
|
||||
&cli.StringFlag{
|
||||
Name: "oidc-issuer",
|
||||
Value: "https://localhost:9200",
|
||||
Usage: "OIDC issuer",
|
||||
EnvVars: []string{"PROXY_OIDC_ISSUER"},
|
||||
Destination: &cfg.OIDC.Issuer,
|
||||
},
|
||||
&cli.BoolFlag{
|
||||
Name: "oidc-insecure",
|
||||
Value: true,
|
||||
Usage: "OIDC allow insecure communication",
|
||||
EnvVars: []string{"PROXY_OIDC_INSECURE"},
|
||||
Destination: &cfg.OIDC.Insecure,
|
||||
},
|
||||
&cli.StringSliceFlag{
|
||||
Name: "presignedurl-allow-method",
|
||||
Value: cli.NewStringSlice("GET"),
|
||||
Usage: "--presignedurl-allow-method GET [--presignedurl-allow-method POST]",
|
||||
EnvVars: []string{"PRESIGNEDURL_ALLOWED_METHODS"},
|
||||
},
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
var (
|
||||
// Namespace defines the namespace for the defines metrics.
|
||||
Namespace = "ocis"
|
||||
|
||||
// Subsystem defines the subsystem for the defines metrics.
|
||||
Subsystem = "proxy"
|
||||
)
|
||||
|
||||
// Metrics defines the available metrics of this service.
|
||||
type Metrics struct {
|
||||
Counter *prometheus.CounterVec
|
||||
Latency *prometheus.SummaryVec
|
||||
Duration *prometheus.HistogramVec
|
||||
}
|
||||
|
||||
// New initializes the available metrics.
|
||||
func New() *Metrics {
|
||||
m := &Metrics{
|
||||
Counter: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: Namespace,
|
||||
Subsystem: Subsystem,
|
||||
Name: "proxy_total",
|
||||
Help: "How many proxy requests processed",
|
||||
}, []string{}),
|
||||
Latency: prometheus.NewSummaryVec(prometheus.SummaryOpts{
|
||||
Namespace: Namespace,
|
||||
Subsystem: Subsystem,
|
||||
Name: "proxy_latency_microseconds",
|
||||
Help: "proxy request latencies in microseconds",
|
||||
}, []string{}),
|
||||
Duration: prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Namespace: Namespace,
|
||||
Subsystem: Subsystem,
|
||||
Name: "proxy_duration_seconds",
|
||||
Help: "proxy method request time in seconds",
|
||||
}, []string{}),
|
||||
}
|
||||
|
||||
prometheus.Register(
|
||||
m.Counter,
|
||||
)
|
||||
|
||||
prometheus.Register(
|
||||
m.Latency,
|
||||
)
|
||||
|
||||
prometheus.Register(
|
||||
m.Duration,
|
||||
)
|
||||
|
||||
return m
|
||||
}
|
||||
@@ -0,0 +1,189 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
revauser "github.com/cs3org/go-cs3apis/cs3/identity/user/v1beta1"
|
||||
types "github.com/cs3org/go-cs3apis/cs3/types/v1beta1"
|
||||
"github.com/cs3org/reva/pkg/token/manager/jwt"
|
||||
acc "github.com/owncloud/ocis-accounts/pkg/proto/v0"
|
||||
"github.com/owncloud/ocis-pkg/v2/log"
|
||||
"github.com/owncloud/ocis-pkg/v2/oidc"
|
||||
settings "github.com/owncloud/ocis-settings/pkg/proto/v0"
|
||||
)
|
||||
|
||||
func getAccount(l log.Logger, ac acc.AccountsService, query string) (account *acc.Account, status int) {
|
||||
resp, err := ac.ListAccounts(context.Background(), &acc.ListAccountsRequest{
|
||||
Query: query,
|
||||
PageSize: 2,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
l.Error().Err(err).Str("query", query).Msgf("Error fetching from accounts-service")
|
||||
status = http.StatusInternalServerError
|
||||
return
|
||||
}
|
||||
|
||||
if len(resp.Accounts) <= 0 {
|
||||
l.Error().Str("query", query).Msgf("Account not found")
|
||||
status = http.StatusNotFound
|
||||
return
|
||||
}
|
||||
|
||||
if len(resp.Accounts) > 1 {
|
||||
l.Error().Str("query", query).Msgf("More than one account found. Not logging user in.")
|
||||
status = http.StatusForbidden
|
||||
return
|
||||
}
|
||||
|
||||
account = resp.Accounts[0]
|
||||
return
|
||||
}
|
||||
|
||||
func createAccount(l log.Logger, claims *oidc.StandardClaims, ac acc.AccountsService) (*acc.Account, int) {
|
||||
// TODO check if fields are missing.
|
||||
req := &acc.CreateAccountRequest{
|
||||
Account: &acc.Account{
|
||||
DisplayName: claims.DisplayName,
|
||||
PreferredName: claims.PreferredUsername,
|
||||
OnPremisesSamAccountName: claims.PreferredUsername,
|
||||
Mail: claims.Email,
|
||||
CreationType: "LocalAccount",
|
||||
AccountEnabled: true,
|
||||
// TODO assign uidnumber and gidnumber? better do that in ocis-accounts as it can keep track of the next numbers
|
||||
},
|
||||
}
|
||||
created, err := ac.CreateAccount(context.Background(), req)
|
||||
if err != nil {
|
||||
l.Error().Err(err).Interface("account", req.Account).Msg("could not create account")
|
||||
return nil, http.StatusInternalServerError
|
||||
}
|
||||
|
||||
return created, 0
|
||||
}
|
||||
|
||||
// AccountUUID provides a middleware which mints a jwt and adds it to the proxied request based
|
||||
// on the oidc-claims
|
||||
func AccountUUID(opts ...Option) func(next http.Handler) http.Handler {
|
||||
opt := newOptions(opts...)
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
// TODO: handle error
|
||||
tokenManager, err := jwt.New(map[string]interface{}{
|
||||
"secret": opt.TokenManagerConfig.JWTSecret,
|
||||
"expires": int64(60),
|
||||
})
|
||||
if err != nil {
|
||||
opt.Logger.Fatal().Err(err).Msgf("Could not initialize token-manager")
|
||||
}
|
||||
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
l := opt.Logger
|
||||
claims := oidc.FromContext(r.Context())
|
||||
if claims == nil {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
var account *acc.Account
|
||||
var status int
|
||||
if claims.Email != "" {
|
||||
account, status = getAccount(l, opt.AccountsClient, fmt.Sprintf("mail eq '%s'", strings.ReplaceAll(claims.Email, "'", "''")))
|
||||
} else if claims.PreferredUsername != "" {
|
||||
account, status = getAccount(l, opt.AccountsClient, fmt.Sprintf("preferred_name eq '%s'", strings.ReplaceAll(claims.PreferredUsername, "'", "''")))
|
||||
} else if claims.OcisID != "" {
|
||||
account, status = getAccount(l, opt.AccountsClient, fmt.Sprintf("id eq '%s'", strings.ReplaceAll(claims.OcisID, "'", "''")))
|
||||
} else {
|
||||
// TODO allow lookup by custom claim, eg an id ... or sub
|
||||
l.Error().Err(err).Msgf("Could not lookup account, no mail or preferred_username claim set")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
if status != 0 || account == nil {
|
||||
if status == http.StatusNotFound {
|
||||
account, status = createAccount(l, claims, opt.AccountsClient)
|
||||
if status != 0 {
|
||||
w.WriteHeader(status)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
w.WriteHeader(status)
|
||||
return
|
||||
}
|
||||
}
|
||||
if !account.AccountEnabled {
|
||||
l.Debug().Interface("account", account).Msg("account is disabled")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
groups := make([]string, len(account.MemberOf))
|
||||
for i := range account.MemberOf {
|
||||
// reva needs the unix group name
|
||||
groups[i] = account.MemberOf[i].OnPremisesSamAccountName
|
||||
}
|
||||
|
||||
// fetch active roles from ocis-settings
|
||||
assignmentResponse, err := opt.SettingsRoleService.ListRoleAssignments(r.Context(), &settings.ListRoleAssignmentsRequest{AccountUuid: account.Id})
|
||||
roleIDs := make([]string, 0)
|
||||
if err != nil {
|
||||
l.Err(err).Str("accountID", account.Id).Msg("failed to fetch role assignments")
|
||||
} else {
|
||||
for _, assignment := range assignmentResponse.Assignments {
|
||||
roleIDs = append(roleIDs, assignment.RoleId)
|
||||
}
|
||||
}
|
||||
|
||||
l.Debug().Interface("claims", claims).Interface("account", account).Msgf("Associated claims with uuid")
|
||||
user := &revauser.User{
|
||||
Id: &revauser.UserId{
|
||||
OpaqueId: account.Id,
|
||||
Idp: claims.Iss,
|
||||
},
|
||||
Username: account.OnPremisesSamAccountName,
|
||||
DisplayName: account.DisplayName,
|
||||
Mail: account.Mail,
|
||||
MailVerified: account.ExternalUserState == "" || account.ExternalUserState == "Accepted",
|
||||
Groups: groups,
|
||||
Opaque: &types.Opaque{
|
||||
Map: map[string]*types.OpaqueEntry{},
|
||||
},
|
||||
}
|
||||
|
||||
user.Opaque.Map["uid"] = &types.OpaqueEntry{
|
||||
Decoder: "plain",
|
||||
Value: []byte(strconv.FormatInt(account.UidNumber, 10)),
|
||||
}
|
||||
user.Opaque.Map["gid"] = &types.OpaqueEntry{
|
||||
Decoder: "plain",
|
||||
Value: []byte(strconv.FormatInt(account.GidNumber, 10)),
|
||||
}
|
||||
|
||||
// encode roleIDs as json string
|
||||
roleIDsJSON, jsonErr := json.Marshal(roleIDs)
|
||||
if jsonErr != nil {
|
||||
l.Err(jsonErr).Str("accountID", account.Id).Msg("failed to marshal roleIDs into json")
|
||||
} else {
|
||||
user.Opaque.Map["roles"] = &types.OpaqueEntry{
|
||||
Decoder: "json",
|
||||
Value: roleIDsJSON,
|
||||
}
|
||||
}
|
||||
|
||||
token, err := tokenManager.MintToken(r.Context(), user)
|
||||
|
||||
if err != nil {
|
||||
l.Error().Err(err).Msgf("Could not mint token")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
r.Header.Set("x-access-token", token)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,106 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/micro/go-micro/v2/client"
|
||||
"github.com/owncloud/ocis-accounts/pkg/proto/v0"
|
||||
"github.com/owncloud/ocis-pkg/v2/log"
|
||||
"github.com/owncloud/ocis-pkg/v2/oidc"
|
||||
"github.com/owncloud/ocis-proxy/pkg/config"
|
||||
settings "github.com/owncloud/ocis-settings/pkg/proto/v0"
|
||||
)
|
||||
|
||||
// TODO testing the getAccount method should inject a cache
|
||||
func TestGetAccountSuccess(t *testing.T) {
|
||||
svcCache.Invalidate(AccountsKey, "success")
|
||||
if _, status := getAccount(log.NewLogger(), mockAccountUUIDMiddlewareAccSvc(false, true), "mail eq 'success'"); status != 0 {
|
||||
t.Errorf("expected an account")
|
||||
}
|
||||
}
|
||||
func TestGetAccountInternalError(t *testing.T) {
|
||||
svcCache.Invalidate(AccountsKey, "failure")
|
||||
if _, status := getAccount(log.NewLogger(), mockAccountUUIDMiddlewareAccSvc(true, false), "mail eq 'failure'"); status != http.StatusInternalServerError {
|
||||
t.Errorf("expected an internal server error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountUUIDMiddleware(t *testing.T) {
|
||||
svcCache.Invalidate(AccountsKey, "success")
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
m := AccountUUID(
|
||||
Logger(log.NewLogger()),
|
||||
TokenManagerConfig(config.TokenManager{JWTSecret: "secret"}),
|
||||
AccountsClient(mockAccountUUIDMiddlewareAccSvc(false, true)),
|
||||
SettingsRoleService(mockAccountUUIDMiddlewareRolesSvc(false)),
|
||||
)(next)
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "http://www.example.com", nil)
|
||||
w := httptest.NewRecorder()
|
||||
ctx := oidc.NewContext(r.Context(), &oidc.StandardClaims{Email: "success"})
|
||||
r = r.WithContext(ctx)
|
||||
m.ServeHTTP(w, r)
|
||||
|
||||
if r.Header.Get("x-access-token") == "" {
|
||||
t.Errorf("expected a token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountUUIDMiddlewareWithDisabledAccount(t *testing.T) {
|
||||
svcCache.Invalidate(AccountsKey, "failure")
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
m := AccountUUID(
|
||||
Logger(log.NewLogger()),
|
||||
TokenManagerConfig(config.TokenManager{JWTSecret: "secret"}),
|
||||
AccountsClient(mockAccountUUIDMiddlewareAccSvc(false, false)),
|
||||
SettingsRoleService(mockAccountUUIDMiddlewareRolesSvc(false)),
|
||||
)(next)
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "http://www.example.com", nil)
|
||||
w := httptest.NewRecorder()
|
||||
ctx := oidc.NewContext(r.Context(), &oidc.StandardClaims{Email: "failure"})
|
||||
r = r.WithContext(ctx)
|
||||
m.ServeHTTP(w, r)
|
||||
|
||||
rsp := w.Result()
|
||||
defer rsp.Body.Close()
|
||||
|
||||
if rsp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("expected a disabled account to be unauthorized, got: %d", rsp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func mockAccountUUIDMiddlewareAccSvc(retErr, accEnabled bool) proto.AccountsService {
|
||||
return &proto.MockAccountsService{
|
||||
ListFunc: func(ctx context.Context, in *proto.ListAccountsRequest, opts ...client.CallOption) (out *proto.ListAccountsResponse, err error) {
|
||||
if retErr {
|
||||
return nil, fmt.Errorf("error returned by mockAccountsService LIST")
|
||||
}
|
||||
return &proto.ListAccountsResponse{
|
||||
Accounts: []*proto.Account{
|
||||
{
|
||||
Id: "yay",
|
||||
AccountEnabled: accEnabled,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func mockAccountUUIDMiddlewareRolesSvc(returnError bool) settings.RoleService {
|
||||
return &settings.MockRoleService{
|
||||
ListRoleAssignmentsFunc: func(ctx context.Context, req *settings.ListRoleAssignmentsRequest, opts ...client.CallOption) (res *settings.ListRoleAssignmentsResponse, err error) {
|
||||
if returnError {
|
||||
return nil, fmt.Errorf("error returned by mockRoleService.ListRoleAssignments")
|
||||
}
|
||||
return &settings.ListRoleAssignmentsResponse{
|
||||
Assignments: []*settings.UserRoleAssignment{},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
rpc "github.com/cs3org/go-cs3apis/cs3/rpc/v1beta1"
|
||||
provider "github.com/cs3org/go-cs3apis/cs3/storage/provider/v1beta1"
|
||||
"github.com/cs3org/reva/pkg/rgrpc/status"
|
||||
tokenpkg "github.com/cs3org/reva/pkg/token"
|
||||
"github.com/cs3org/reva/pkg/token/manager/jwt"
|
||||
"github.com/micro/go-micro/v2/errors"
|
||||
"github.com/owncloud/ocis-accounts/pkg/proto/v0"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
// CreateHome provides a middleware which sends a CreateHome request to the reva gateway
|
||||
func CreateHome(opts ...Option) func(next http.Handler) http.Handler {
|
||||
opt := newOptions(opts...)
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
accounts := opt.AccountsClient
|
||||
|
||||
tokenManager, err := jwt.New(map[string]interface{}{
|
||||
"secret": opt.TokenManagerConfig.JWTSecret,
|
||||
})
|
||||
if err != nil {
|
||||
opt.Logger.Error().Err(err).Msg("error creating a token manager")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
token := r.Header.Get("x-access-token")
|
||||
if token == "" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := tokenManager.DismantleToken(r.Context(), token)
|
||||
if err != nil {
|
||||
opt.Logger.Err(err).Msg("error getting user from access token")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
_, err = accounts.GetAccount(r.Context(), &proto.GetAccountRequest{
|
||||
Id: user.Id.OpaqueId,
|
||||
})
|
||||
if err != nil {
|
||||
e := errors.Parse(err.Error())
|
||||
if e.Code == http.StatusNotFound {
|
||||
opt.Logger.Debug().Msgf("account with id %s not found", user.Id.OpaqueId)
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
opt.Logger.Err(err).Msgf("error getting user with id %s from accounts service", user.Id.OpaqueId)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// we need to pass the token to authenticate the CreateHome request.
|
||||
//ctx := tokenpkg.ContextSetToken(r.Context(), token)
|
||||
ctx := metadata.AppendToOutgoingContext(r.Context(), tokenpkg.TokenHeader, token)
|
||||
|
||||
createHomeReq := &provider.CreateHomeRequest{}
|
||||
createHomeRes, err := opt.RevaGatewayClient.CreateHome(ctx, createHomeReq)
|
||||
|
||||
if err != nil {
|
||||
opt.Logger.Err(err).Msg("error calling CreateHome")
|
||||
} else if createHomeRes.Status.Code != rpc.Code_CODE_OK {
|
||||
err := status.NewErrorFromCode(createHomeRes.Status.Code, "gateway")
|
||||
opt.Logger.Err(err).Msg("error when calling Createhome")
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// RedirectToHTTPS redirects insecure requests to https
|
||||
func RedirectToHTTPS(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
|
||||
proto := req.Header.Get("x-forwarded-proto")
|
||||
if proto == "http" || proto == "HTTP" {
|
||||
http.Redirect(res, req, fmt.Sprintf("https://%s%s", req.Host, req.URL), http.StatusPermanentRedirect)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(res, req)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
package middleware
|
||||
|
||||
import "net/http"
|
||||
|
||||
// M undocummented
|
||||
type M func(next http.Handler) http.Handler
|
||||
@@ -0,0 +1,119 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
ocisoidc "github.com/owncloud/ocis-pkg/v2/oidc"
|
||||
"github.com/owncloud/ocis-proxy/pkg/cache"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrInvalidToken is returned when the request token is invalid.
|
||||
ErrInvalidToken = errors.New("invalid or missing token")
|
||||
|
||||
// svcCache caches requests for given services to prevent round trips to the service
|
||||
svcCache = cache.NewCache(
|
||||
cache.Size(256),
|
||||
)
|
||||
)
|
||||
|
||||
// OIDCProvider used to mock the oidc provider during tests
|
||||
type OIDCProvider interface {
|
||||
UserInfo(ctx context.Context, ts oauth2.TokenSource) (*oidc.UserInfo, error)
|
||||
}
|
||||
|
||||
// OpenIDConnect provides a middleware to check access secured by a static token.
|
||||
func OpenIDConnect(opts ...Option) func(next http.Handler) http.Handler {
|
||||
opt := newOptions(opts...)
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
|
||||
var oidcProvider OIDCProvider
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
header := r.Header.Get("Authorization")
|
||||
path := r.URL.Path
|
||||
|
||||
// Ignore request to "/konnect/v1/userinfo" as this will cause endless loop when getting userinfo
|
||||
// needs a better idea on how to not hardcode this
|
||||
if header == "" || !strings.HasPrefix(header, "Bearer ") || path == "/konnect/v1/userinfo" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
customCtx := context.WithValue(r.Context(), oauth2.HTTPClient, opt.HTTPClient)
|
||||
|
||||
// check if oidc provider is initialized
|
||||
if oidcProvider == nil {
|
||||
// Lazily initialize a provider
|
||||
|
||||
// provider needs to be cached as when it is created
|
||||
// it will fetch the keys from the issuer using the .well-known
|
||||
// endpoint
|
||||
var err error
|
||||
oidcProvider, err = opt.OIDCProviderFunc()
|
||||
if err != nil {
|
||||
opt.Logger.Error().Err(err).Msg("could not initialize oidc provider")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
token := strings.TrimPrefix(header, "Bearer ")
|
||||
|
||||
// TODO cache userinfo for access token if we can determine the expiry (which works in case it is a jwt based access token)
|
||||
oauth2Token := &oauth2.Token{
|
||||
AccessToken: token,
|
||||
}
|
||||
|
||||
// The claims we want to have
|
||||
var claims ocisoidc.StandardClaims
|
||||
userInfo, err := oidcProvider.UserInfo(customCtx, oauth2.StaticTokenSource(oauth2Token))
|
||||
if err != nil {
|
||||
opt.Logger.Error().Err(err).Str("token", token).Msg("Failed to get userinfo")
|
||||
http.Error(w, ErrInvalidToken.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if err := userInfo.Claims(&claims); err != nil {
|
||||
opt.Logger.Error().Err(err).Interface("userinfo", userInfo).Msg("failed to unmarshal userinfo claims")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
//TODO: This should be read from the token instead of config
|
||||
claims.Iss = opt.OIDCIss
|
||||
|
||||
// inject claims to the request context for the account_uuid middleware.
|
||||
ctxWithClaims := ocisoidc.NewContext(r.Context(), &claims)
|
||||
r = r.WithContext(ctxWithClaims)
|
||||
|
||||
opt.Logger.Debug().Interface("claims", claims).Interface("userInfo", userInfo).Msg("unmarshalled userinfo")
|
||||
// store claims in context
|
||||
// uses the original context, not the one with probably reduced security
|
||||
nr := r.WithContext(ocisoidc.NewContext(r.Context(), &claims))
|
||||
|
||||
next.ServeHTTP(w, nr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// AccountsCacheEntry stores a request to the accounts service on the cache.
|
||||
// this type declaration should be on each respective service.
|
||||
type AccountsCacheEntry struct {
|
||||
Email string
|
||||
UUID string
|
||||
}
|
||||
|
||||
const (
|
||||
// AccountsKey declares the svcKey for the Accounts service.
|
||||
AccountsKey = "accounts"
|
||||
|
||||
// NodeKey declares the key that will be used to store the node address.
|
||||
// It is shared between services.
|
||||
NodeKey = "node"
|
||||
)
|
||||
@@ -0,0 +1,67 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/owncloud/ocis-pkg/v2/log"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func TestOpenIDConnectMiddleware(t *testing.T) {
|
||||
svcCache.Invalidate(AccountsKey, "success")
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
|
||||
m := OpenIDConnect(
|
||||
Logger(log.NewLogger()),
|
||||
OIDCProviderFunc(func() (OIDCProvider, error) {
|
||||
return mockOP(false), nil
|
||||
}),
|
||||
)(next)
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "https://idp.example.com", nil)
|
||||
r.Header.Set("Authorization", "Bearer sometoken")
|
||||
w := httptest.NewRecorder()
|
||||
m.ServeHTTP(w, r)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected an internal server error")
|
||||
}
|
||||
}
|
||||
|
||||
type mockOIDCProvider struct {
|
||||
UserInfoFunc func(ctx context.Context, ts oauth2.TokenSource) (*oidc.UserInfo, error)
|
||||
}
|
||||
|
||||
// UserInfo will panic if the function has been called, but not mocked
|
||||
func (m mockOIDCProvider) UserInfo(ctx context.Context, ts oauth2.TokenSource) (*oidc.UserInfo, error) {
|
||||
if m.UserInfoFunc != nil {
|
||||
return m.UserInfoFunc(ctx, ts)
|
||||
}
|
||||
|
||||
panic("UserInfo was called in test but not mocked")
|
||||
}
|
||||
|
||||
func mockOP(retErr bool) OIDCProvider {
|
||||
if retErr {
|
||||
return &mockOIDCProvider{
|
||||
UserInfoFunc: func(ctx context.Context, ts oauth2.TokenSource) (*oidc.UserInfo, error) {
|
||||
return nil, fmt.Errorf("error returned by mockOIDCProvider UserInfo")
|
||||
},
|
||||
}
|
||||
|
||||
}
|
||||
return &mockOIDCProvider{
|
||||
UserInfoFunc: func(ctx context.Context, ts oauth2.TokenSource) (*oidc.UserInfo, error) {
|
||||
ui := &oidc.UserInfo{
|
||||
// claims: private ...
|
||||
}
|
||||
return ui, nil
|
||||
},
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,120 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
settings "github.com/owncloud/ocis-settings/pkg/proto/v0"
|
||||
"net/http"
|
||||
|
||||
gateway "github.com/cs3org/go-cs3apis/cs3/gateway/v1beta1"
|
||||
acc "github.com/owncloud/ocis-accounts/pkg/proto/v0"
|
||||
"github.com/owncloud/ocis-pkg/v2/log"
|
||||
"github.com/owncloud/ocis-proxy/pkg/config"
|
||||
storepb "github.com/owncloud/ocis-store/pkg/proto/v0"
|
||||
)
|
||||
|
||||
// Option defines a single option function.
|
||||
type Option func(o *Options)
|
||||
|
||||
// Options defines the available options for this package.
|
||||
type Options struct {
|
||||
// Logger to use for logging, must be set
|
||||
Logger log.Logger
|
||||
// TokenManagerConfig for communicating with the reva token manager
|
||||
TokenManagerConfig config.TokenManager
|
||||
// HTTPClient to use for communication with the oidc provider
|
||||
HTTPClient *http.Client
|
||||
// AccountsClient for resolving accounts
|
||||
AccountsClient acc.AccountsService
|
||||
// SettingsRoleService for the roles API in settings
|
||||
SettingsRoleService settings.RoleService
|
||||
// OIDCProviderFunc to lazily initialize a provider, must be set for the oidcProvider middleware
|
||||
OIDCProviderFunc func() (OIDCProvider, error)
|
||||
// OIDCIss is the oidc-issuer
|
||||
OIDCIss string
|
||||
// RevaGatewayClient to send requests to the reva gateway
|
||||
RevaGatewayClient gateway.GatewayAPIClient
|
||||
// Store for persisting data
|
||||
Store storepb.StoreService
|
||||
// PreSignedURLConfig to configure the middleware
|
||||
PreSignedURLConfig config.PreSignedURL
|
||||
}
|
||||
|
||||
// newOptions initializes the available default options.
|
||||
func newOptions(opts ...Option) Options {
|
||||
opt := Options{}
|
||||
|
||||
for _, o := range opts {
|
||||
o(&opt)
|
||||
}
|
||||
|
||||
return opt
|
||||
}
|
||||
|
||||
// Logger provides a function to set the logger option.
|
||||
func Logger(l log.Logger) Option {
|
||||
return func(o *Options) {
|
||||
o.Logger = l
|
||||
}
|
||||
}
|
||||
|
||||
// TokenManagerConfig provides a function to set the token manger config option.
|
||||
func TokenManagerConfig(cfg config.TokenManager) Option {
|
||||
return func(o *Options) {
|
||||
o.TokenManagerConfig = cfg
|
||||
}
|
||||
}
|
||||
|
||||
// HTTPClient provides a function to set the http client config option.
|
||||
func HTTPClient(c *http.Client) Option {
|
||||
return func(o *Options) {
|
||||
o.HTTPClient = c
|
||||
}
|
||||
}
|
||||
|
||||
// AccountsClient provides a function to set the accounts client config option.
|
||||
func AccountsClient(ac acc.AccountsService) Option {
|
||||
return func(o *Options) {
|
||||
o.AccountsClient = ac
|
||||
}
|
||||
}
|
||||
|
||||
// SettingsRoleService provides a function to set the role service option.
|
||||
func SettingsRoleService(rc settings.RoleService) Option {
|
||||
return func(o *Options) {
|
||||
o.SettingsRoleService = rc
|
||||
}
|
||||
}
|
||||
|
||||
// OIDCProviderFunc provides a function to set the the oidc provider function option.
|
||||
func OIDCProviderFunc(f func() (OIDCProvider, error)) Option {
|
||||
return func(o *Options) {
|
||||
o.OIDCProviderFunc = f
|
||||
}
|
||||
}
|
||||
|
||||
// OIDCIss sets the oidc issuer url
|
||||
func OIDCIss(iss string) Option {
|
||||
return func(o *Options) {
|
||||
o.OIDCIss = iss
|
||||
}
|
||||
}
|
||||
|
||||
// RevaGatewayClient provides a function to set the the reva gateway service client option.
|
||||
func RevaGatewayClient(gc gateway.GatewayAPIClient) Option {
|
||||
return func(o *Options) {
|
||||
o.RevaGatewayClient = gc
|
||||
}
|
||||
}
|
||||
|
||||
// Store provides a function to set the store option.
|
||||
func Store(sc storepb.StoreService) Option {
|
||||
return func(o *Options) {
|
||||
o.Store = sc
|
||||
}
|
||||
}
|
||||
|
||||
// PreSignedURLConfig provides a function to set the PreSignedURL config
|
||||
func PreSignedURLConfig(cfg config.PreSignedURL) Option {
|
||||
return func(o *Options) {
|
||||
o.PreSignedURLConfig = cfg
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,151 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha512"
|
||||
"encoding/hex"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/owncloud/ocis-pkg/v2/log"
|
||||
ocisoidc "github.com/owncloud/ocis-pkg/v2/oidc"
|
||||
"github.com/owncloud/ocis-proxy/pkg/config"
|
||||
storepb "github.com/owncloud/ocis-store/pkg/proto/v0"
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
)
|
||||
|
||||
const (
|
||||
iterations = 10000
|
||||
keyLen = 32
|
||||
)
|
||||
|
||||
// PresignedURL provides a middleware to check access secured by a presigned URL.
|
||||
func PresignedURL(opts ...Option) func(next http.Handler) http.Handler {
|
||||
opt := newOptions(opts...)
|
||||
l := opt.Logger
|
||||
cfg := opt.PreSignedURLConfig
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if isSignedRequest(r) {
|
||||
if signedRequestIsValid(l, r, opt.Store, cfg) {
|
||||
// use openid claims to let the account_uuid middleware do a lookup by username
|
||||
claims := ocisoidc.StandardClaims{
|
||||
OcisID: r.URL.Query().Get("OC-Credential"),
|
||||
}
|
||||
|
||||
// inject claims to the request context for the account_uuid middleware
|
||||
ctxWithClaims := ocisoidc.NewContext(r.Context(), &claims)
|
||||
r = r.WithContext(ctxWithClaims)
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
} else {
|
||||
http.Error(w, "Invalid url signature", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func isSignedRequest(r *http.Request) bool {
|
||||
return r.URL.Query().Get("OC-Signature") != ""
|
||||
}
|
||||
|
||||
func signedRequestIsValid(l log.Logger, r *http.Request, s storepb.StoreService, cfg config.PreSignedURL) bool {
|
||||
// TODO OC-Algorithm - defined the used algo (e.g. sha256 or sha512 - we should agree on one default algo and make this parameter optional)
|
||||
// TODO OC-Verb - defines for which http verb the request is valid - defaults to GET OPTIONAL
|
||||
|
||||
return allRequiredParametersArePresent(r) &&
|
||||
requestMethodMatches(r) &&
|
||||
requestMethodIsAllowed(r.Method, cfg.AllowedHTTPMethods) &&
|
||||
!urlIsExpired(r, time.Now) &&
|
||||
signatureIsValid(l, r, s)
|
||||
}
|
||||
|
||||
func allRequiredParametersArePresent(r *http.Request) bool {
|
||||
// OC-Credential - defines the user scope (shall we use the owncloud user id here - this might leak internal data ....) REQUIRED
|
||||
// OC-Date - defined the date the url was signed (ISO 8601 UTC) REQUIRED
|
||||
// OC-Expires - defines the expiry interval in seconds (between 1 and 604800 = 7 days) REQUIRED
|
||||
// OC-Signature - the computed signature - server will verify the request upon this REQUIRED
|
||||
return r.URL.Query().Get("OC-Signature") != "" &&
|
||||
r.URL.Query().Get("OC-Credential") != "" &&
|
||||
r.URL.Query().Get("OC-Date") != "" &&
|
||||
r.URL.Query().Get("OC-Expires") != "" &&
|
||||
r.URL.Query().Get("OC-Verb") != ""
|
||||
}
|
||||
|
||||
func requestMethodMatches(r *http.Request) bool {
|
||||
return strings.EqualFold(r.Method, r.URL.Query().Get("OC-Verb"))
|
||||
}
|
||||
|
||||
func requestMethodIsAllowed(m string, allowedMethods []string) bool {
|
||||
for _, allowed := range allowedMethods {
|
||||
if strings.EqualFold(m, allowed) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func urlIsExpired(r *http.Request, now func() time.Time) bool {
|
||||
t, err := time.Parse(time.RFC3339, r.URL.Query().Get("OC-Date"))
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
expires, err := time.ParseDuration(r.URL.Query().Get("OC-Expires") + "s")
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
t.Add(expires)
|
||||
return t.After(now())
|
||||
}
|
||||
|
||||
func signatureIsValid(l log.Logger, r *http.Request, s storepb.StoreService) bool {
|
||||
signingKey, err := getSigningKey(r.Context(), s, r.URL.Query().Get("OC-Credential"))
|
||||
if err != nil {
|
||||
l.Error().Err(err).Msg("could not retrieve signing key")
|
||||
return false
|
||||
}
|
||||
if len(signingKey) == 0 {
|
||||
l.Error().Err(err).Msg("signing key empty")
|
||||
return false
|
||||
}
|
||||
|
||||
q := r.URL.Query()
|
||||
signature := q.Get("OC-Signature")
|
||||
q.Del("OC-Signature")
|
||||
r.URL.RawQuery = q.Encode()
|
||||
url := r.URL.String()
|
||||
if !r.URL.IsAbs() {
|
||||
url = "https://" + r.Host + url // TODO where do we get the scheme from
|
||||
}
|
||||
return createSignature(url, signingKey) == signature
|
||||
}
|
||||
|
||||
func createSignature(url string, signingKey []byte) string {
|
||||
// the oc10 signature check: $hash = \hash_pbkdf2("sha512", $url, $signingKey, 10000, 64, false);
|
||||
// - sets the length of the output string to 64
|
||||
// - sets raw output to false -> if raw_output is FALSE length corresponds to twice the byte-length of the derived key (as every byte of the key is returned as two hexits).
|
||||
// TODO change to length 128 in oc10?
|
||||
// fo golangs pbkdf2.Key we need to use 32 because it will be encoded into 64 hexits later
|
||||
hash := pbkdf2.Key([]byte(url), signingKey, iterations, keyLen, sha512.New)
|
||||
return hex.EncodeToString(hash)
|
||||
}
|
||||
|
||||
func getSigningKey(ctx context.Context, s storepb.StoreService, credential string) ([]byte, error) {
|
||||
res, err := s.Read(ctx, &storepb.ReadRequest{
|
||||
Options: &storepb.ReadOptions{
|
||||
Database: "proxy",
|
||||
Table: "signing-keys",
|
||||
},
|
||||
Key: credential,
|
||||
})
|
||||
if err != nil || len(res.Records) < 1 {
|
||||
return []byte{}, err
|
||||
}
|
||||
|
||||
return res.Records[0].Value, nil
|
||||
}
|
||||
@@ -0,0 +1,121 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestIsSignedRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
url string
|
||||
expected bool
|
||||
}{
|
||||
{"https://example.com/example.jpg", false},
|
||||
{"https://example.com/example.jpg?OC-Signature=something", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
r := httptest.NewRequest("", tt.url, nil)
|
||||
result := isSignedRequest(r)
|
||||
if result != tt.expected {
|
||||
t.Errorf("with %s expected %t got %t", tt.url, tt.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllRequiredParametersPresent(t *testing.T) {
|
||||
baseURL := "https://example.com/example.jpg?"
|
||||
tests := []struct {
|
||||
params string
|
||||
expected bool
|
||||
}{
|
||||
{"OC-Signature=something&OC-Credential=something&OC-Date=something&OC-Expires=something&OC-Verb=something", true},
|
||||
{"OC-Credential=something&OC-Date=something&OC-Expires=something&OC-Verb=something", false},
|
||||
{"OC-Signature=something&OC-Date=something&OC-Expires=something&OC-Verb=something", false},
|
||||
{"OC-Signature=something&OC-Credential=something&OC-Expires=something&OC-Verb=something", false},
|
||||
{"OC-Signature=something&OC-Credential=something&OC-Date=something&OC-Verb=something", false},
|
||||
{"OC-Signature=something&OC-Credential=something&OC-Date=something&OC-Expires=something", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
r := httptest.NewRequest("", baseURL+tt.params, nil)
|
||||
result := allRequiredParametersArePresent(r)
|
||||
if result != tt.expected {
|
||||
t.Errorf("with %s expected %t got %t", tt.params, tt.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestMethodMatches(t *testing.T) {
|
||||
tests := []struct {
|
||||
method string
|
||||
url string
|
||||
expected bool
|
||||
}{
|
||||
{"GET", "https://example.com/example.jpg?OC-Verb=GET", true},
|
||||
{"GET", "https://example.com/example.jpg?OC-Verb=get", true},
|
||||
{"POST", "https://example.com/example.jpg?OC-Verb=GET", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
r := httptest.NewRequest(tt.method, tt.url, nil)
|
||||
result := requestMethodMatches(r)
|
||||
if result != tt.expected {
|
||||
t.Errorf("with method %s and url %s expected %t got %t", tt.method, tt.url, tt.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestMethodIsAllowed(t *testing.T) {
|
||||
tests := []struct {
|
||||
method string
|
||||
allowed []string
|
||||
expected bool
|
||||
}{
|
||||
{"GET", []string{}, false},
|
||||
{"GET", []string{"POST"}, false},
|
||||
{"GET", []string{"GET"}, true},
|
||||
{"GET", []string{"get"}, true},
|
||||
{"GET", []string{"POST", "GET"}, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := requestMethodIsAllowed(tt.method, tt.allowed)
|
||||
if result != tt.expected {
|
||||
t.Errorf("with method %s and allowed methods %s expected %t got %t", tt.method, tt.allowed, tt.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUrlIsExpired(t *testing.T) {
|
||||
nowFunc := func() time.Time {
|
||||
t, _ := time.Parse(time.RFC3339, "2020-08-19T15:12:43.478Z")
|
||||
return t
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
url string
|
||||
expected bool
|
||||
}{
|
||||
{"http://example.com/example.jpg?OC-Date=2020-08-19T15:02:43.478Z&OC-Expires=1200", false},
|
||||
{"http://example.com/example.jpg?OC-Date=invalid&OC-Expires=1200", true},
|
||||
{"http://example.com/example.jpg?OC-Date=2020-08-19T15:02:43.478Z&OC-Expires=invalid", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
r := httptest.NewRequest("", tt.url, nil)
|
||||
result := urlIsExpired(r, nowFunc)
|
||||
if result != tt.expected {
|
||||
t.Errorf("with %s expected %t got %t", tt.url, tt.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSignature(t *testing.T) {
|
||||
expected := "27d2ebea381384af3179235114801dcd00f91e46f99fca72575301cf3948101d"
|
||||
s := createSignature("something", []byte("somerandomkey"))
|
||||
|
||||
if s != expected {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"github.com/owncloud/ocis-pkg/v2/log"
|
||||
"github.com/owncloud/ocis-proxy/pkg/config"
|
||||
)
|
||||
|
||||
// Option defines a single option function.
|
||||
type Option func(o *Options)
|
||||
|
||||
// Options defines the available options for this package.
|
||||
type Options struct {
|
||||
Logger log.Logger
|
||||
Config *config.Config
|
||||
}
|
||||
|
||||
// newOptions initializes the available default options.
|
||||
func newOptions(opts ...Option) Options {
|
||||
opt := Options{}
|
||||
|
||||
for _, o := range opts {
|
||||
o(&opt)
|
||||
}
|
||||
|
||||
return opt
|
||||
}
|
||||
|
||||
// Logger provides a function to set the logger option.
|
||||
func Logger(val log.Logger) Option {
|
||||
return func(o *Options) {
|
||||
o.Logger = val
|
||||
}
|
||||
}
|
||||
|
||||
// Config provides a function to set the config option.
|
||||
func Config(val *config.Config) Option {
|
||||
return func(o *Options) {
|
||||
o.Config = val
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,113 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/micro/go-micro/v2/client/grpc"
|
||||
accounts "github.com/owncloud/ocis-accounts/pkg/proto/v0"
|
||||
"github.com/owncloud/ocis-pkg/v2/oidc"
|
||||
"github.com/owncloud/ocis-proxy/pkg/config"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrMultipleSelectors in case there is more then one selector configured.
|
||||
ErrMultipleSelectors = fmt.Errorf("only one type of policy-selector (static or migration) can be configured")
|
||||
// ErrSelectorConfigIncomplete if policy_selector conf is missing
|
||||
ErrSelectorConfigIncomplete = fmt.Errorf("missing either \"static\" or \"migration\" configuration in policy_selector config ")
|
||||
// ErrUnexpectedConfigError unexpected config error
|
||||
ErrUnexpectedConfigError = fmt.Errorf("could not initialize policy-selector for given config")
|
||||
)
|
||||
|
||||
// Selector is a function which selects a proxy-policy based on the request.
|
||||
//
|
||||
// A policy is a random name which identifies a set of proxy-routes:
|
||||
//{
|
||||
// "policies": [
|
||||
// {
|
||||
// "name": "us-east-1",
|
||||
// "routes": [
|
||||
// {
|
||||
// "endpoint": "/",
|
||||
// "backend": "https://backend.us.example.com:8080/app"
|
||||
// }
|
||||
// ]
|
||||
// },
|
||||
// {
|
||||
// "name": "eu-ams-1",
|
||||
// "routes": [
|
||||
// {
|
||||
// "endpoint": "/",
|
||||
// "backend": "https://backend.eu.example.com:8080/app"
|
||||
// }
|
||||
// ]
|
||||
// }
|
||||
// ]
|
||||
//}
|
||||
type Selector func(ctx context.Context, r *http.Request) (string, error)
|
||||
|
||||
// LoadSelector constructs a specific policy-selector from a given configuration
|
||||
func LoadSelector(cfg *config.PolicySelector) (Selector, error) {
|
||||
if cfg.Migration != nil && cfg.Static != nil {
|
||||
return nil, ErrMultipleSelectors
|
||||
}
|
||||
|
||||
if cfg.Migration == nil && cfg.Static == nil {
|
||||
return nil, ErrSelectorConfigIncomplete
|
||||
}
|
||||
|
||||
if cfg.Static != nil {
|
||||
return NewStaticSelector(cfg.Static), nil
|
||||
}
|
||||
|
||||
if cfg.Migration != nil {
|
||||
return NewMigrationSelector(
|
||||
cfg.Migration,
|
||||
accounts.NewAccountsService("com.owncloud.accounts", grpc.NewClient())), nil
|
||||
}
|
||||
|
||||
return nil, ErrUnexpectedConfigError
|
||||
}
|
||||
|
||||
// NewStaticSelector returns a selector which uses a pre-configured policy.
|
||||
//
|
||||
// Configuration:
|
||||
//
|
||||
// "policy_selector": {
|
||||
// "static": {"policy" : "reva"}
|
||||
// },
|
||||
func NewStaticSelector(cfg *config.StaticSelectorConf) Selector {
|
||||
return func(ctx context.Context, r *http.Request) (s string, err error) {
|
||||
return cfg.Policy, nil
|
||||
}
|
||||
}
|
||||
|
||||
// NewMigrationSelector selects the policy based on the existence of the oidc "preferred_username" claim in the accounts-service.
|
||||
// The policy for each case is configurable:
|
||||
// "policy_selector": {
|
||||
// "migration": {
|
||||
// "acc_found_policy" : "reva",
|
||||
// "acc_not_found_policy": "oc10",
|
||||
// "unauthenticated_policy": "oc10"
|
||||
// }
|
||||
// },
|
||||
//
|
||||
// This selector can be used in migration-scenarios where some users have already migrated from ownCloud10 to OCIS and
|
||||
// thus have an entry in ocis-accounts. All users without accounts entry are routed to the legacy ownCloud10 instance.
|
||||
func NewMigrationSelector(cfg *config.MigrationSelectorConf, ss accounts.AccountsService) Selector {
|
||||
var acc = ss
|
||||
return func(ctx context.Context, r *http.Request) (s string, err error) {
|
||||
var userID string
|
||||
if claims := oidc.FromContext(r.Context()); claims != nil {
|
||||
userID = claims.PreferredUsername
|
||||
if _, err := acc.GetAccount(ctx, &accounts.GetAccountRequest{Id: userID}); err != nil {
|
||||
return cfg.AccNotFoundPolicy, nil
|
||||
}
|
||||
|
||||
return cfg.AccFoundPolicy, nil
|
||||
}
|
||||
|
||||
return cfg.UnauthenticatedPolicy, nil
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/micro/go-micro/v2/client"
|
||||
"github.com/owncloud/ocis-accounts/pkg/proto/v0"
|
||||
"github.com/owncloud/ocis-pkg/v2/oidc"
|
||||
"github.com/owncloud/ocis-proxy/pkg/config"
|
||||
)
|
||||
|
||||
func TestStaticSelector(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
req := httptest.NewRequest("GET", "https://example.org/foo", nil)
|
||||
sel := NewStaticSelector(&config.StaticSelectorConf{Policy: "reva"})
|
||||
|
||||
want := "reva"
|
||||
got, err := sel(ctx, req)
|
||||
if got != want {
|
||||
t.Errorf("Expected policy %v got %v", want, got)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error %v", err)
|
||||
}
|
||||
|
||||
sel = NewStaticSelector(&config.StaticSelectorConf{Policy: "foo"})
|
||||
|
||||
want = "foo"
|
||||
got, err = sel(ctx, req)
|
||||
if got != want {
|
||||
t.Errorf("Expected policy %v got %v", want, got)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
type testCase struct {
|
||||
AccSvcShouldReturnError bool
|
||||
Claims *oidc.StandardClaims
|
||||
Expected string
|
||||
}
|
||||
|
||||
func TestMigrationSelector(t *testing.T) {
|
||||
cfg := config.MigrationSelectorConf{
|
||||
AccFoundPolicy: "found",
|
||||
AccNotFoundPolicy: "not_found",
|
||||
UnauthenticatedPolicy: "unauth",
|
||||
}
|
||||
var tests = []testCase{
|
||||
{true, &oidc.StandardClaims{PreferredUsername: "Hans"}, "not_found"},
|
||||
{false, &oidc.StandardClaims{PreferredUsername: "Hans"}, "found"},
|
||||
{false, nil, "unauth"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
//t.Run(fmt.Sprintf("#%v", k), func(t *testing.T) {
|
||||
// t.Parallel()
|
||||
tc := tc
|
||||
sut := NewMigrationSelector(&cfg, mockAccSvc(tc.AccSvcShouldReturnError))
|
||||
r := httptest.NewRequest("GET", "https://example.com", nil)
|
||||
ctx := oidc.NewContext(r.Context(), tc.Claims)
|
||||
nr := r.WithContext(ctx)
|
||||
|
||||
got, err := sut(ctx, nr)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if got != tc.Expected {
|
||||
t.Errorf("Expected Policy %v got %v", tc.Expected, got)
|
||||
}
|
||||
//})
|
||||
}
|
||||
}
|
||||
|
||||
func mockAccSvc(retErr bool) proto.AccountsService {
|
||||
if retErr {
|
||||
return &proto.MockAccountsService{
|
||||
GetFunc: func(ctx context.Context, in *proto.GetAccountRequest, opts ...client.CallOption) (record *proto.Account, err error) {
|
||||
return nil, fmt.Errorf("error returned by mockAccountsService GET")
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return &proto.MockAccountsService{
|
||||
GetFunc: func(ctx context.Context, in *proto.GetAccountRequest, opts ...client.CallOption) (record *proto.Account, err error) {
|
||||
return &proto.Account{}, nil
|
||||
},
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,385 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/owncloud/ocis-proxy/pkg/proxy/policy"
|
||||
"go.opencensus.io/plugin/ochttp/propagation/tracecontext"
|
||||
"go.opencensus.io/trace"
|
||||
|
||||
"github.com/owncloud/ocis-pkg/v2/log"
|
||||
"github.com/owncloud/ocis-proxy/pkg/config"
|
||||
)
|
||||
|
||||
// MultiHostReverseProxy extends httputil to support multiple hosts with diffent policies
|
||||
type MultiHostReverseProxy struct {
|
||||
httputil.ReverseProxy
|
||||
Directors map[string]map[config.RouteType]map[string]func(req *http.Request)
|
||||
PolicySelector policy.Selector
|
||||
logger log.Logger
|
||||
propagator tracecontext.HTTPFormat
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
// NewMultiHostReverseProxy undocummented
|
||||
func NewMultiHostReverseProxy(opts ...Option) *MultiHostReverseProxy {
|
||||
options := newOptions(opts...)
|
||||
|
||||
rp := &MultiHostReverseProxy{
|
||||
Directors: make(map[string]map[config.RouteType]map[string]func(req *http.Request)),
|
||||
logger: options.Logger,
|
||||
config: options.Config,
|
||||
}
|
||||
rp.Director = rp.directorSelectionDirector
|
||||
|
||||
if options.Config.Policies == nil {
|
||||
rp.logger.Info().Str("source", "runtime").Msg("Policies")
|
||||
options.Config.Policies = defaultPolicies()
|
||||
} else {
|
||||
rp.logger.Info().Str("source", "file").Msg("Policies")
|
||||
}
|
||||
|
||||
if options.Config.PolicySelector == nil {
|
||||
firstPolicy := options.Config.Policies[0].Name
|
||||
rp.logger.Warn().Msgf("policy-selector not configured. Will always use first policy: '%v'", firstPolicy)
|
||||
options.Config.PolicySelector = &config.PolicySelector{
|
||||
Static: &config.StaticSelectorConf{
|
||||
Policy: firstPolicy,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
rp.logger.Debug().
|
||||
Interface("selector_config", options.Config.PolicySelector).
|
||||
Msg("loading policy-selector")
|
||||
|
||||
policySelector, err := policy.LoadSelector(options.Config.PolicySelector)
|
||||
if err != nil {
|
||||
rp.logger.Fatal().Err(err).Msg("Could not load policy-selector")
|
||||
}
|
||||
|
||||
rp.PolicySelector = policySelector
|
||||
|
||||
for _, pol := range options.Config.Policies {
|
||||
for _, route := range pol.Routes {
|
||||
rp.logger.Debug().Str("fwd: ", route.Endpoint)
|
||||
uri, err := url.Parse(route.Backend)
|
||||
if err != nil {
|
||||
rp.logger.
|
||||
Fatal().
|
||||
Err(err).
|
||||
Msgf("malformed url: %v", route.Backend)
|
||||
}
|
||||
|
||||
rp.logger.
|
||||
Debug().
|
||||
Interface("route", route).
|
||||
Msg("adding route")
|
||||
|
||||
rp.AddHost(pol.Name, uri, route)
|
||||
}
|
||||
}
|
||||
|
||||
return rp
|
||||
}
|
||||
|
||||
func (p *MultiHostReverseProxy) directorSelectionDirector(r *http.Request) {
|
||||
pol, err := p.PolicySelector(r.Context(), r)
|
||||
if err != nil {
|
||||
p.logger.Error().Msgf("Error while selecting pol %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := p.Directors[pol]; !ok {
|
||||
p.logger.
|
||||
Error().
|
||||
Msgf("policy %v is not configured", pol)
|
||||
return
|
||||
}
|
||||
|
||||
// find matching director
|
||||
for _, rt := range config.RouteTypes {
|
||||
var handler func(string, url.URL) bool
|
||||
switch rt {
|
||||
case config.QueryRoute:
|
||||
handler = p.queryRouteMatcher
|
||||
case config.RegexRoute:
|
||||
handler = p.regexRouteMatcher
|
||||
case config.PrefixRoute:
|
||||
fallthrough
|
||||
default:
|
||||
handler = p.prefixRouteMatcher
|
||||
}
|
||||
for endpoint := range p.Directors[pol][rt] {
|
||||
if handler(endpoint, *r.URL) {
|
||||
p.logger.
|
||||
Debug().
|
||||
Str("policy", pol).
|
||||
Str("prefix", endpoint).
|
||||
Str("path", r.URL.Path).
|
||||
Str("routeType", string(rt)).
|
||||
Msg("director found")
|
||||
p.Directors[pol][rt][endpoint](r)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// override default director with root. If any
|
||||
if p.Directors[pol][config.PrefixRoute]["/"] != nil {
|
||||
p.Directors[pol][config.PrefixRoute]["/"](r)
|
||||
return
|
||||
}
|
||||
|
||||
p.logger.
|
||||
Warn().
|
||||
Str("policy", pol).
|
||||
Str("path", r.URL.Path).
|
||||
Msg("no director found")
|
||||
}
|
||||
|
||||
func singleJoiningSlash(a, b string) string {
|
||||
aslash := strings.HasSuffix(a, "/")
|
||||
bslash := strings.HasPrefix(b, "/")
|
||||
switch {
|
||||
case aslash && bslash:
|
||||
return a + b[1:]
|
||||
case !aslash && !bslash:
|
||||
return a + "/" + b
|
||||
}
|
||||
return a + b
|
||||
}
|
||||
|
||||
// AddHost undocumented
|
||||
func (p *MultiHostReverseProxy) AddHost(policy string, target *url.URL, rt config.Route) {
|
||||
targetQuery := target.RawQuery
|
||||
if p.Directors[policy] == nil {
|
||||
p.Directors[policy] = make(map[config.RouteType]map[string]func(req *http.Request))
|
||||
}
|
||||
routeType := config.DefaultRouteType
|
||||
if rt.Type != "" {
|
||||
routeType = rt.Type
|
||||
}
|
||||
if p.Directors[policy][routeType] == nil {
|
||||
p.Directors[policy][routeType] = make(map[string]func(req *http.Request))
|
||||
}
|
||||
p.Directors[policy][routeType][rt.Endpoint] = func(req *http.Request) {
|
||||
req.URL.Scheme = target.Scheme
|
||||
req.URL.Host = target.Host
|
||||
// Apache deployments host addresses need to match on req.Host and req.URL.Host
|
||||
// see https://stackoverflow.com/questions/34745654/golang-reverseproxy-with-apache2-sni-hostname-error
|
||||
if rt.ApacheVHost {
|
||||
req.Host = target.Host
|
||||
}
|
||||
|
||||
req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
|
||||
if targetQuery == "" || req.URL.RawQuery == "" {
|
||||
req.URL.RawQuery = targetQuery + req.URL.RawQuery
|
||||
} else {
|
||||
req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
|
||||
}
|
||||
if _, ok := req.Header["User-Agent"]; !ok {
|
||||
// explicitly disable User-Agent so it's not set to default value
|
||||
req.Header.Set("User-Agent", "")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *MultiHostReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := context.Background()
|
||||
var span *trace.Span
|
||||
|
||||
// Start root span.
|
||||
if p.config.Tracing.Enabled {
|
||||
ctx, span = trace.StartSpan(context.Background(), r.URL.String())
|
||||
defer span.End()
|
||||
p.propagator.SpanContextToRequest(span.SpanContext(), r)
|
||||
}
|
||||
|
||||
// Call upstream ServeHTTP
|
||||
p.ReverseProxy.ServeHTTP(w, r.WithContext(ctx))
|
||||
}
|
||||
|
||||
func (p MultiHostReverseProxy) queryRouteMatcher(endpoint string, target url.URL) bool {
|
||||
u, _ := url.Parse(endpoint)
|
||||
if strings.HasPrefix(target.Path, u.Path) && endpoint != "/" {
|
||||
query := u.Query()
|
||||
if len(query) != 0 {
|
||||
rQuery := target.Query()
|
||||
match := true
|
||||
for k := range query {
|
||||
v := query.Get(k)
|
||||
rv := rQuery.Get(k)
|
||||
if rv != v {
|
||||
match = false
|
||||
break
|
||||
}
|
||||
}
|
||||
return match
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *MultiHostReverseProxy) regexRouteMatcher(endpoint string, target url.URL) bool {
|
||||
matched, err := regexp.MatchString(endpoint, target.String())
|
||||
if err != nil {
|
||||
p.logger.Warn().Err(err).Msgf("regex with pattern %s failed", endpoint)
|
||||
}
|
||||
return matched
|
||||
}
|
||||
|
||||
func (p *MultiHostReverseProxy) prefixRouteMatcher(endpoint string, target url.URL) bool {
|
||||
return strings.HasPrefix(target.Path, endpoint) && endpoint != "/"
|
||||
}
|
||||
|
||||
func defaultPolicies() []config.Policy {
|
||||
return []config.Policy{
|
||||
{
|
||||
Name: "reva",
|
||||
Routes: []config.Route{
|
||||
{
|
||||
Endpoint: "/",
|
||||
Backend: "http://localhost:9100",
|
||||
},
|
||||
{
|
||||
Endpoint: "/.well-known/",
|
||||
Backend: "http://localhost:9130",
|
||||
},
|
||||
{
|
||||
Endpoint: "/konnect/",
|
||||
Backend: "http://localhost:9130",
|
||||
},
|
||||
{
|
||||
Endpoint: "/signin/",
|
||||
Backend: "http://localhost:9130",
|
||||
},
|
||||
{
|
||||
Type: config.RegexRoute,
|
||||
Endpoint: "/ocs/v[12].php/cloud/user", // we have `user` and `users` in ocis-ocs
|
||||
Backend: "http://localhost:9110",
|
||||
},
|
||||
{
|
||||
Endpoint: "/ocs/",
|
||||
Backend: "http://localhost:9140",
|
||||
},
|
||||
{
|
||||
Type: config.QueryRoute,
|
||||
Endpoint: "/remote.php/?preview=1",
|
||||
Backend: "http://localhost:9115",
|
||||
},
|
||||
{
|
||||
Endpoint: "/remote.php/",
|
||||
Backend: "http://localhost:9140",
|
||||
},
|
||||
{
|
||||
Endpoint: "/dav/",
|
||||
Backend: "http://localhost:9140",
|
||||
},
|
||||
{
|
||||
Endpoint: "/webdav/",
|
||||
Backend: "http://localhost:9140",
|
||||
},
|
||||
{
|
||||
Endpoint: "/status.php",
|
||||
Backend: "http://localhost:9140",
|
||||
},
|
||||
{
|
||||
Endpoint: "/index.php/",
|
||||
Backend: "http://localhost:9140",
|
||||
},
|
||||
{
|
||||
Endpoint: "/data",
|
||||
Backend: "http://localhost:9140",
|
||||
},
|
||||
// if we were using the go micro api gateway we could look up the endpoint in the registry dynamically
|
||||
{
|
||||
Endpoint: "/api/v0/accounts",
|
||||
Backend: "http://localhost:9181",
|
||||
},
|
||||
// TODO the lookup needs a better mechanism
|
||||
{
|
||||
Endpoint: "/accounts.js",
|
||||
Backend: "http://localhost:9181",
|
||||
},
|
||||
{
|
||||
Endpoint: "/api/v0/settings",
|
||||
Backend: "http://localhost:9190",
|
||||
},
|
||||
{
|
||||
Endpoint: "/settings.js",
|
||||
Backend: "http://localhost:9190",
|
||||
},
|
||||
{
|
||||
Endpoint: "/api/v0/greet",
|
||||
Backend: "http://localhost:9105",
|
||||
},
|
||||
{
|
||||
Endpoint: "/hello.js",
|
||||
Backend: "http://localhost:9105",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "oc10",
|
||||
Routes: []config.Route{
|
||||
{
|
||||
Endpoint: "/",
|
||||
Backend: "http://localhost:9100",
|
||||
},
|
||||
{
|
||||
Endpoint: "/.well-known/",
|
||||
Backend: "http://localhost:9130",
|
||||
},
|
||||
{
|
||||
Endpoint: "/konnect/",
|
||||
Backend: "http://localhost:9130",
|
||||
},
|
||||
{
|
||||
Endpoint: "/signin/",
|
||||
Backend: "http://localhost:9130",
|
||||
},
|
||||
{
|
||||
Endpoint: "/ocs/",
|
||||
Backend: "https://demo.owncloud.com",
|
||||
ApacheVHost: true,
|
||||
},
|
||||
{
|
||||
Endpoint: "/remote.php/",
|
||||
Backend: "https://demo.owncloud.com",
|
||||
ApacheVHost: true,
|
||||
},
|
||||
{
|
||||
Endpoint: "/dav/",
|
||||
Backend: "https://demo.owncloud.com",
|
||||
ApacheVHost: true,
|
||||
},
|
||||
{
|
||||
Endpoint: "/webdav/",
|
||||
Backend: "https://demo.owncloud.com",
|
||||
ApacheVHost: true,
|
||||
},
|
||||
{
|
||||
Endpoint: "/status.php",
|
||||
Backend: "https://demo.owncloud.com",
|
||||
ApacheVHost: true,
|
||||
},
|
||||
{
|
||||
Endpoint: "/index.php/",
|
||||
Backend: "https://demo.owncloud.com",
|
||||
ApacheVHost: true,
|
||||
},
|
||||
{
|
||||
Endpoint: "/data",
|
||||
Backend: "https://demo.owncloud.com",
|
||||
ApacheVHost: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,222 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/owncloud/ocis-proxy/pkg/config"
|
||||
)
|
||||
|
||||
func TestProxyIntegration(t *testing.T) {
|
||||
var tests = []testCase{
|
||||
// Simple prefix route
|
||||
test("simple_prefix", withPolicy("reva", withRoutes{{
|
||||
Type: config.PrefixRoute,
|
||||
Endpoint: "/api",
|
||||
Backend: "http://api.example.com"},
|
||||
})).withRequest("GET", "https://example.com/api", nil).
|
||||
expectProxyTo("http://api.example.com/api"),
|
||||
|
||||
// Complex prefix route, different method
|
||||
test("complex_prefix_post", withPolicy("reva", withRoutes{{
|
||||
Type: config.PrefixRoute,
|
||||
Endpoint: "/api",
|
||||
Backend: "http://api.example.com/service1/"},
|
||||
})).withRequest("POST", "https://example.com/api", nil).
|
||||
expectProxyTo("http://api.example.com/service1/api"),
|
||||
|
||||
// Query route
|
||||
test("query_route", withPolicy("reva", withRoutes{{
|
||||
Type: config.QueryRoute,
|
||||
Endpoint: "/api?format=json",
|
||||
Backend: "http://backend/"},
|
||||
})).withRequest("GET", "https://example.com/api?format=json", nil).
|
||||
expectProxyTo("http://backend/api?format=json"),
|
||||
|
||||
// Regex route
|
||||
test("regex_route", withPolicy("reva", withRoutes{{
|
||||
Type: config.RegexRoute,
|
||||
Endpoint: `\/user\/(\d+)`,
|
||||
Backend: "http://backend/"},
|
||||
})).withRequest("POST", "https://example.com/user/1234", nil).
|
||||
expectProxyTo("http://backend/user/1234"),
|
||||
|
||||
// Multiple prefix routes 1
|
||||
test("multiple_prefix", withPolicy("reva", withRoutes{
|
||||
{
|
||||
Type: config.PrefixRoute,
|
||||
Endpoint: "/api",
|
||||
Backend: "http://api.example.com",
|
||||
},
|
||||
{
|
||||
Type: config.PrefixRoute,
|
||||
Endpoint: "/payment",
|
||||
Backend: "http://payment.example.com",
|
||||
},
|
||||
})).withRequest("GET", "https://example.com/payment", nil).
|
||||
expectProxyTo("http://payment.example.com/payment"),
|
||||
|
||||
// Multiple prefix routes 2
|
||||
test("multiple_prefix", withPolicy("reva", withRoutes{
|
||||
{
|
||||
Type: config.PrefixRoute,
|
||||
Endpoint: "/api",
|
||||
Backend: "http://api.example.com",
|
||||
},
|
||||
{
|
||||
Type: config.PrefixRoute,
|
||||
Endpoint: "/payment",
|
||||
Backend: "http://payment.example.com",
|
||||
},
|
||||
})).withRequest("GET", "https://example.com/api", nil).
|
||||
expectProxyTo("http://api.example.com/api"),
|
||||
|
||||
// Mixed route types
|
||||
test("mixed_types", withPolicy("reva", withRoutes{
|
||||
{
|
||||
Type: config.PrefixRoute,
|
||||
Endpoint: "/api",
|
||||
Backend: "http://api.example.com",
|
||||
},
|
||||
{
|
||||
Type: config.RegexRoute,
|
||||
Endpoint: `\/user\/(\d+)`,
|
||||
Backend: "http://users.example.com",
|
||||
ApacheVHost: false,
|
||||
},
|
||||
})).withRequest("GET", "https://example.com/api", nil).
|
||||
expectProxyTo("http://api.example.com/api"),
|
||||
|
||||
// Mixed route types
|
||||
test("mixed_types", withPolicy("reva", withRoutes{
|
||||
{
|
||||
Type: config.PrefixRoute,
|
||||
Endpoint: "/api",
|
||||
Backend: "http://api.example.com",
|
||||
},
|
||||
{
|
||||
Type: config.RegexRoute,
|
||||
Endpoint: `\/user\/(\d+)`,
|
||||
Backend: "http://users.example.com",
|
||||
ApacheVHost: false,
|
||||
},
|
||||
})).withRequest("GET", "https://example.com/user/1234", nil).
|
||||
expectProxyTo("http://users.example.com/user/1234"),
|
||||
}
|
||||
|
||||
for k := range tests {
|
||||
t.Run(tests[k].id, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := tests[k]
|
||||
rp := newTestProxy(testConfig(tc.conf), func(req *http.Request) *http.Response {
|
||||
if got, want := req.URL.String(), tc.expect.String(); got != want {
|
||||
t.Errorf("Proxied url should be %v got %v", want, got)
|
||||
}
|
||||
|
||||
if got, want := req.Method, tc.input.Method; got != want {
|
||||
t.Errorf("Proxied request method should be %v got %v", want, got)
|
||||
}
|
||||
|
||||
if got, want := req.Proto, tc.input.Proto; got != want {
|
||||
t.Errorf("Proxied request proto should be %v got %v", want, got)
|
||||
}
|
||||
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(bytes.NewBufferString(`OK`)),
|
||||
Header: make(http.Header),
|
||||
}
|
||||
})
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
rp.ServeHTTP(rr, tc.input)
|
||||
|
||||
if rr.Result().StatusCode != 200 {
|
||||
t.Errorf("Expected status 200 from proxy-response got %v", rr.Result().StatusCode)
|
||||
}
|
||||
|
||||
resultBody, err := ioutil.ReadAll(rr.Result().Body)
|
||||
if err != nil {
|
||||
t.Fatal("Error reading result body")
|
||||
}
|
||||
|
||||
bodyString := string(resultBody)
|
||||
if bodyString != `OK` {
|
||||
t.Errorf("Result body of proxied response should be OK, got %v", bodyString)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newTestProxy(cfg *config.Config, fn RoundTripFunc) *MultiHostReverseProxy {
|
||||
rp := NewMultiHostReverseProxy(Config(cfg))
|
||||
rp.Transport = fn
|
||||
return rp
|
||||
}
|
||||
|
||||
type RoundTripFunc func(req *http.Request) *http.Response
|
||||
|
||||
// RoundTrip .
|
||||
func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req), nil
|
||||
}
|
||||
|
||||
type withRoutes []config.Route
|
||||
|
||||
type testCase struct {
|
||||
id string
|
||||
input *http.Request
|
||||
expect *url.URL
|
||||
conf []config.Policy
|
||||
}
|
||||
|
||||
func test(id string, policies ...config.Policy) *testCase {
|
||||
tc := &testCase{
|
||||
id: id,
|
||||
}
|
||||
for k := range policies {
|
||||
tc.conf = append(tc.conf, policies[k])
|
||||
}
|
||||
|
||||
return tc
|
||||
}
|
||||
|
||||
func withPolicy(name string, r withRoutes) config.Policy {
|
||||
return config.Policy{Name: name, Routes: r}
|
||||
}
|
||||
|
||||
func (tc *testCase) withRequest(method string, target string, body io.Reader) *testCase {
|
||||
tc.input = httptest.NewRequest(method, target, body)
|
||||
return tc
|
||||
}
|
||||
|
||||
func (tc *testCase) expectProxyTo(strURL string) testCase {
|
||||
pu, err := url.Parse(strURL)
|
||||
if err != nil {
|
||||
log.Fatalf("Error parsing %v", strURL)
|
||||
}
|
||||
|
||||
tc.expect = pu
|
||||
return *tc
|
||||
}
|
||||
|
||||
func testConfig(policy []config.Policy) *config.Config {
|
||||
return &config.Config{
|
||||
File: "",
|
||||
Log: config.Log{},
|
||||
Debug: config.Debug{},
|
||||
HTTP: config.HTTP{},
|
||||
Tracing: config.Tracing{},
|
||||
Asset: config.Asset{},
|
||||
Policies: policy,
|
||||
OIDC: config.OIDC{},
|
||||
PolicySelector: nil,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,112 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/owncloud/ocis-proxy/pkg/config"
|
||||
)
|
||||
|
||||
func TestPrefixRouteMatcher(t *testing.T) {
|
||||
cfg := config.New()
|
||||
p := NewMultiHostReverseProxy(Config(cfg))
|
||||
|
||||
endpoint := "/foobar"
|
||||
u, _ := url.Parse("/foobar/baz/some/url")
|
||||
|
||||
matched := p.prefixRouteMatcher(endpoint, *u)
|
||||
if !matched {
|
||||
t.Errorf("Endpoint %s and URL %s should match", endpoint, u.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryRouteMatcher(t *testing.T) {
|
||||
cfg := config.New()
|
||||
p := NewMultiHostReverseProxy(Config(cfg))
|
||||
|
||||
endpoint := "/foobar?parameter=true"
|
||||
u, _ := url.Parse("/foobar/baz/some/url?parameter=true")
|
||||
|
||||
matched := p.queryRouteMatcher(endpoint, *u)
|
||||
if !matched {
|
||||
t.Errorf("Endpoint %s and URL %s should match", endpoint, u.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryRouteMatcherWithoutParameters(t *testing.T) {
|
||||
cfg := config.New()
|
||||
p := NewMultiHostReverseProxy(Config(cfg))
|
||||
|
||||
endpoint := "/foobar"
|
||||
u, _ := url.Parse("/foobar/baz/some/url?parameter=true")
|
||||
|
||||
matched := p.queryRouteMatcher(endpoint, *u)
|
||||
if matched {
|
||||
t.Errorf("Endpoint %s and URL %s should not match", endpoint, u.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryRouteMatcherWithDifferingParameters(t *testing.T) {
|
||||
cfg := config.New()
|
||||
p := NewMultiHostReverseProxy(Config(cfg))
|
||||
|
||||
endpoint := "/foobar?parameter=false"
|
||||
u, _ := url.Parse("/foobar/baz/some/url?parameter=true")
|
||||
|
||||
matched := p.queryRouteMatcher(endpoint, *u)
|
||||
if matched {
|
||||
t.Errorf("Endpoint %s and URL %s should not match", endpoint, u.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryRouteMatcherWithMultipleDifferingParameters(t *testing.T) {
|
||||
cfg := config.New()
|
||||
p := NewMultiHostReverseProxy(Config(cfg))
|
||||
|
||||
endpoint := "/foobar?parameter=false&other=true"
|
||||
u, _ := url.Parse("/foobar/baz/some/url?parameter=true")
|
||||
|
||||
matched := p.queryRouteMatcher(endpoint, *u)
|
||||
if matched {
|
||||
t.Errorf("Endpoint %s and URL %s should not match", endpoint, u.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryRouteMatcherWithMultipleParameters(t *testing.T) {
|
||||
cfg := config.New()
|
||||
p := NewMultiHostReverseProxy(Config(cfg))
|
||||
|
||||
endpoint := "/foobar?parameter=false&other=true"
|
||||
u, _ := url.Parse("/foobar/baz/some/url?parameter=false&other=true")
|
||||
|
||||
matched := p.queryRouteMatcher(endpoint, *u)
|
||||
if !matched {
|
||||
t.Errorf("Endpoint %s and URL %s should match", endpoint, u.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegexRouteMatcher(t *testing.T) {
|
||||
cfg := config.New()
|
||||
p := NewMultiHostReverseProxy(Config(cfg))
|
||||
|
||||
endpoint := ".*some\\/url.*parameter=true"
|
||||
u, _ := url.Parse("/foobar/baz/some/url?parameter=true")
|
||||
|
||||
matched := p.regexRouteMatcher(endpoint, *u)
|
||||
if !matched {
|
||||
t.Errorf("Endpoint %s and URL %s should match", endpoint, u.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegexRouteMatcherWithInvalidPattern(t *testing.T) {
|
||||
cfg := config.New()
|
||||
p := NewMultiHostReverseProxy(Config(cfg))
|
||||
|
||||
endpoint := "([\\])\\w+"
|
||||
u, _ := url.Parse("/foobar/baz/some/url?parameter=true")
|
||||
|
||||
matched := p.regexRouteMatcher(endpoint, *u)
|
||||
if matched {
|
||||
t.Errorf("Endpoint %s and URL %s should not match", endpoint, u.String())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
package debug
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/owncloud/ocis-pkg/v2/log"
|
||||
"github.com/owncloud/ocis-proxy/pkg/config"
|
||||
)
|
||||
|
||||
// Option defines a single option function.
|
||||
type Option func(o *Options)
|
||||
|
||||
// Options defines the available options for this package.
|
||||
type Options struct {
|
||||
Logger log.Logger
|
||||
Context context.Context
|
||||
Config *config.Config
|
||||
}
|
||||
|
||||
// newOptions initializes the available default options.
|
||||
func newOptions(opts ...Option) Options {
|
||||
opt := Options{}
|
||||
|
||||
for _, o := range opts {
|
||||
o(&opt)
|
||||
}
|
||||
|
||||
return opt
|
||||
}
|
||||
|
||||
// Logger provides a function to set the logger option.
|
||||
func Logger(val log.Logger) Option {
|
||||
return func(o *Options) {
|
||||
o.Logger = val
|
||||
}
|
||||
}
|
||||
|
||||
// Context provides a function to set the context option.
|
||||
func Context(val context.Context) Option {
|
||||
return func(o *Options) {
|
||||
o.Context = val
|
||||
}
|
||||
}
|
||||
|
||||
// Config provides a function to set the config option.
|
||||
func Config(val *config.Config) Option {
|
||||
return func(o *Options) {
|
||||
o.Config = val
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package debug
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/owncloud/ocis-pkg/v2/service/debug"
|
||||
"github.com/owncloud/ocis-proxy/pkg/config"
|
||||
"github.com/owncloud/ocis-proxy/pkg/version"
|
||||
)
|
||||
|
||||
// Server initializes the debug service and server.
|
||||
func Server(opts ...Option) (*http.Server, error) {
|
||||
options := newOptions(opts...)
|
||||
|
||||
return debug.NewService(
|
||||
debug.Logger(options.Logger),
|
||||
debug.Name("proxy"),
|
||||
debug.Version(version.String),
|
||||
debug.Address(options.Config.Debug.Addr),
|
||||
debug.Token(options.Config.Debug.Token),
|
||||
debug.Pprof(options.Config.Debug.Pprof),
|
||||
debug.Zpages(options.Config.Debug.Zpages),
|
||||
debug.Health(health(options.Config)),
|
||||
debug.Ready(ready(options.Config)),
|
||||
), nil
|
||||
}
|
||||
|
||||
// health implements the health check.
|
||||
func health(cfg *config.Config) func(http.ResponseWriter, *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
// TODO(tboerger): check if services are up and running
|
||||
|
||||
io.WriteString(w, http.StatusText(http.StatusOK))
|
||||
}
|
||||
}
|
||||
|
||||
// ready implements the ready check.
|
||||
func ready(cfg *config.Config) func(http.ResponseWriter, *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
// TODO(tboerger): check if services are up and running
|
||||
|
||||
io.WriteString(w, http.StatusText(http.StatusOK))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/justinas/alice"
|
||||
"github.com/micro/cli/v2"
|
||||
"github.com/owncloud/ocis-pkg/v2/log"
|
||||
"github.com/owncloud/ocis-proxy/pkg/config"
|
||||
"github.com/owncloud/ocis-proxy/pkg/metrics"
|
||||
)
|
||||
|
||||
// Option defines a single option function.
|
||||
type Option func(o *Options)
|
||||
|
||||
// Options defines the available options for this package.
|
||||
type Options struct {
|
||||
Logger log.Logger
|
||||
Context context.Context
|
||||
Config *config.Config
|
||||
Handler http.Handler
|
||||
Metrics *metrics.Metrics
|
||||
Flags []cli.Flag
|
||||
Namespace string
|
||||
Middlewares alice.Chain
|
||||
}
|
||||
|
||||
// newOptions initializes the available default options.
|
||||
func newOptions(opts ...Option) Options {
|
||||
opt := Options{}
|
||||
|
||||
for _, o := range opts {
|
||||
o(&opt)
|
||||
}
|
||||
|
||||
return opt
|
||||
}
|
||||
|
||||
// Logger provides a function to set the logger option.
|
||||
func Logger(val log.Logger) Option {
|
||||
return func(o *Options) {
|
||||
o.Logger = val
|
||||
}
|
||||
}
|
||||
|
||||
// Context provides a function to set the context option.
|
||||
func Context(val context.Context) Option {
|
||||
return func(o *Options) {
|
||||
o.Context = val
|
||||
}
|
||||
}
|
||||
|
||||
// Config provides a function to set the config option.
|
||||
func Config(val *config.Config) Option {
|
||||
return func(o *Options) {
|
||||
o.Config = val
|
||||
}
|
||||
}
|
||||
|
||||
// Metrics provides a function to set the metrics option.
|
||||
func Metrics(val *metrics.Metrics) Option {
|
||||
return func(o *Options) {
|
||||
o.Metrics = val
|
||||
}
|
||||
}
|
||||
|
||||
// Flags provides a function to set the flags option.
|
||||
func Flags(val []cli.Flag) Option {
|
||||
return func(o *Options) {
|
||||
o.Flags = append(o.Flags, val...)
|
||||
}
|
||||
}
|
||||
|
||||
// Namespace provides a function to set the namespace option.
|
||||
func Namespace(val string) Option {
|
||||
return func(o *Options) {
|
||||
o.Namespace = val
|
||||
}
|
||||
}
|
||||
|
||||
// Handler provides a function to set the Handler option.
|
||||
func Handler(h http.Handler) Option {
|
||||
return func(o *Options) {
|
||||
o.Handler = h
|
||||
}
|
||||
}
|
||||
|
||||
// Middlewares provides a function to register middlewares
|
||||
func Middlewares(val alice.Chain) Option {
|
||||
return func(o *Options) {
|
||||
o.Middlewares = val
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"os"
|
||||
|
||||
svc "github.com/owncloud/ocis-pkg/v2/service/http"
|
||||
"github.com/owncloud/ocis-proxy/pkg/crypto"
|
||||
"github.com/owncloud/ocis-proxy/pkg/version"
|
||||
)
|
||||
|
||||
// Server initializes the http service and server.
|
||||
func Server(opts ...Option) (svc.Service, error) {
|
||||
options := newOptions(opts...)
|
||||
l := options.Logger
|
||||
httpCfg := options.Config.HTTP
|
||||
|
||||
var cer tls.Certificate
|
||||
var certErr error
|
||||
|
||||
var tlsConfig *tls.Config
|
||||
if options.Config.HTTP.TLS {
|
||||
if httpCfg.TLSCert == "" || httpCfg.TLSKey == "" {
|
||||
l.Warn().Msgf("No tls certificate provided, using a generated one")
|
||||
_, certErr := os.Stat("./server.crt")
|
||||
_, keyErr := os.Stat("./server.key")
|
||||
|
||||
if os.IsNotExist(certErr) || os.IsNotExist(keyErr) {
|
||||
// GenCert has side effects as it writes 2 files to the binary running location
|
||||
if err := crypto.GenCert(l); err != nil {
|
||||
l.Fatal().Err(err).Msgf("Could not generate test-certificate")
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
httpCfg.TLSCert = "server.crt"
|
||||
httpCfg.TLSKey = "server.key"
|
||||
}
|
||||
|
||||
cer, certErr = tls.LoadX509KeyPair(httpCfg.TLSCert, httpCfg.TLSKey)
|
||||
if certErr != nil {
|
||||
options.Logger.Fatal().Err(certErr).Msg("Could not setup TLS")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
tlsConfig = &tls.Config{Certificates: []tls.Certificate{cer}}
|
||||
}
|
||||
chain := options.Middlewares.Then(options.Handler)
|
||||
|
||||
service := svc.NewService(
|
||||
svc.Name("web.proxy"),
|
||||
svc.TLSConfig(tlsConfig),
|
||||
svc.Logger(options.Logger),
|
||||
svc.Namespace(options.Namespace),
|
||||
svc.Version(version.String),
|
||||
svc.Address(options.Config.HTTP.Addr),
|
||||
svc.Context(options.Context),
|
||||
svc.Flags(options.Flags...),
|
||||
svc.Handler(chain),
|
||||
)
|
||||
|
||||
if err := service.Init(); err != nil {
|
||||
l.Fatal().Err(err).Msgf("Error initializing")
|
||||
}
|
||||
|
||||
return service, nil
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
package version
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
// String gets defined by the build system.
|
||||
String = "0.0.0"
|
||||
|
||||
// Date indicates the build date.
|
||||
Date = "00000000"
|
||||
)
|
||||
|
||||
// Compiled returns the compile time of this service.
|
||||
func Compiled() time.Time {
|
||||
t, _ := time.Parse("20060102", Date)
|
||||
return t
|
||||
}
|
||||
Reference in New Issue
Block a user