mirror of
https://github.com/opencloud-eu/opencloud.git
synced 2026-01-04 11:19:39 -06:00
120 lines
3.7 KiB
Go
120 lines
3.7 KiB
Go
package middleware
|
|
|
|
import (
	"context"
	"errors"
	"net/http"
	"strings"
	"sync"

	"github.com/coreos/go-oidc"
	ocisoidc "github.com/owncloud/ocis/ocis-pkg/oidc"
	"github.com/owncloud/ocis/proxy/pkg/cache"
	"golang.org/x/oauth2"
)
|
|
|
|
var (
|
|
// ErrInvalidToken is returned when the request token is invalid.
|
|
ErrInvalidToken = errors.New("invalid or missing token")
|
|
|
|
// svcCache caches requests for given services to prevent round trips to the service
|
|
svcCache = cache.NewCache(
|
|
cache.Size(256),
|
|
)
|
|
)
|
|
|
|
// OIDCProvider used to mock the oidc provider during tests
|
|
type OIDCProvider interface {
|
|
UserInfo(ctx context.Context, ts oauth2.TokenSource) (*oidc.UserInfo, error)
|
|
}
|
|
|
|
// OpenIDConnect provides a middleware to check access secured by a static token.
|
|
func OpenIDConnect(opts ...Option) func(next http.Handler) http.Handler {
|
|
opt := newOptions(opts...)
|
|
|
|
return func(next http.Handler) http.Handler {
|
|
|
|
var oidcProvider OIDCProvider
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
header := r.Header.Get("Authorization")
|
|
path := r.URL.Path
|
|
|
|
// Ignore request to "/konnect/v1/userinfo" as this will cause endless loop when getting userinfo
|
|
// needs a better idea on how to not hardcode this
|
|
if header == "" || !strings.HasPrefix(header, "Bearer ") || path == "/konnect/v1/userinfo" {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
customCtx := context.WithValue(r.Context(), oauth2.HTTPClient, opt.HTTPClient)
|
|
|
|
// check if oidc provider is initialized
|
|
if oidcProvider == nil {
|
|
// Lazily initialize a provider
|
|
|
|
// provider needs to be cached as when it is created
|
|
// it will fetch the keys from the issuer using the .well-known
|
|
// endpoint
|
|
var err error
|
|
oidcProvider, err = opt.OIDCProviderFunc()
|
|
if err != nil {
|
|
opt.Logger.Error().Err(err).Msg("could not initialize oidc provider")
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
return
|
|
}
|
|
}
|
|
|
|
token := strings.TrimPrefix(header, "Bearer ")
|
|
|
|
// TODO cache userinfo for access token if we can determine the expiry (which works in case it is a jwt based access token)
|
|
oauth2Token := &oauth2.Token{
|
|
AccessToken: token,
|
|
}
|
|
|
|
// The claims we want to have
|
|
var claims ocisoidc.StandardClaims
|
|
userInfo, err := oidcProvider.UserInfo(customCtx, oauth2.StaticTokenSource(oauth2Token))
|
|
if err != nil {
|
|
opt.Logger.Error().Err(err).Str("token", token).Msg("Failed to get userinfo")
|
|
http.Error(w, ErrInvalidToken.Error(), http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
if err := userInfo.Claims(&claims); err != nil {
|
|
opt.Logger.Error().Err(err).Interface("userinfo", userInfo).Msg("failed to unmarshal userinfo claims")
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
//TODO: This should be read from the token instead of config
|
|
claims.Iss = opt.OIDCIss
|
|
|
|
// inject claims to the request context for the account_uuid middleware.
|
|
ctxWithClaims := ocisoidc.NewContext(r.Context(), &claims)
|
|
r = r.WithContext(ctxWithClaims)
|
|
|
|
opt.Logger.Debug().Interface("claims", claims).Interface("userInfo", userInfo).Msg("unmarshalled userinfo")
|
|
// store claims in context
|
|
// uses the original context, not the one with probably reduced security
|
|
nr := r.WithContext(ocisoidc.NewContext(r.Context(), &claims))
|
|
|
|
next.ServeHTTP(w, nr)
|
|
})
|
|
}
|
|
}
|
|
|
|
// AccountsCacheEntry stores a request to the accounts service on the cache.
// This type declaration should be on each respective service.
type AccountsCacheEntry struct {
	Email, UUID string
}
|
|
|
|
// AccountsKey declares the svcKey for the Accounts service.
const AccountsKey = "accounts"

// NodeKey declares the key that will be used to store the node address.
// It is shared between services.
const NodeKey = "node"
|