Files
opencloud/proxy/pkg/middleware/openidconnect.go
2020-09-18 15:18:41 +02:00

120 lines
3.7 KiB
Go

package middleware
import (
	"context"
	"errors"
	"net/http"
	"strings"
	"sync"

	"github.com/coreos/go-oidc"
	ocisoidc "github.com/owncloud/ocis/ocis-pkg/oidc"
	"github.com/owncloud/ocis/proxy/pkg/cache"
	"golang.org/x/oauth2"
)
// ErrInvalidToken is the error reported when the token of a request is
// missing or not valid.
var ErrInvalidToken = errors.New("invalid or missing token")

// svcCache keeps results of requests to other services so that repeated round
// trips to those services can be avoided. It is created with a size of 256.
var svcCache = cache.NewCache(cache.Size(256))
// OIDCProvider abstracts the userinfo lookup of an OpenID Connect provider.
// It exists so that the provider can be replaced by a mock in tests.
type OIDCProvider interface {
	// UserInfo requests the userinfo for the token supplied by ts.
	UserInfo(ctx context.Context, ts oauth2.TokenSource) (*oidc.UserInfo, error)
}
// OpenIDConnect returns a middleware that authenticates requests carrying an
// OpenID Connect bearer access token.
//
// Requests without an "Authorization: Bearer ..." header and requests to
// "/konnect/v1/userinfo" are handed to next unchanged. For every other request
// the token is used to query the provider's userinfo; the resulting claims are
// stored in the request context with ocisoidc.NewContext before next is
// called.
//
// Responses written by the middleware itself:
//   - 500 Internal Server Error when the provider cannot be initialized or the
//     userinfo claims cannot be unmarshalled.
//   - 401 Unauthorized (body: ErrInvalidToken) when the userinfo lookup fails.
func OpenIDConnect(opts ...Option) func(next http.Handler) http.Handler {
	opt := newOptions(opts...)

	return func(next http.Handler) http.Handler {
		// The provider is created lazily on the first request that needs it and
		// is reused afterwards, because creating it fetches the keys from the
		// issuer using the .well-known endpoint. net/http runs handlers
		// concurrently, so access to the cached value is guarded by providerMu.
		var (
			providerMu   sync.Mutex
			oidcProvider OIDCProvider
		)

		// getProvider returns the cached provider, creating it on first use.
		// A failed initialization is not cached, so the next request retries.
		getProvider := func() (OIDCProvider, error) {
			providerMu.Lock()
			defer providerMu.Unlock()

			if oidcProvider != nil {
				return oidcProvider, nil
			}

			p, err := opt.OIDCProviderFunc()
			if err != nil {
				return nil, err
			}
			oidcProvider = p
			return oidcProvider, nil
		}

		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			header := r.Header.Get("Authorization")

			// Ignore requests to "/konnect/v1/userinfo" as this would cause an
			// endless loop when getting userinfo.
			// TODO: find a better way that does not hardcode this path.
			// An empty header has no "Bearer " prefix, so no separate check is needed.
			if !strings.HasPrefix(header, "Bearer ") || r.URL.Path == "/konnect/v1/userinfo" {
				next.ServeHTTP(w, r)
				return
			}

			provider, err := getProvider()
			if err != nil {
				opt.Logger.Error().Err(err).Msg("could not initialize oidc provider")
				w.WriteHeader(http.StatusInternalServerError)
				return
			}

			token := strings.TrimPrefix(header, "Bearer ")

			// customCtx carries the configured HTTP client for the userinfo
			// request only; it is deliberately not passed on to next.
			customCtx := context.WithValue(r.Context(), oauth2.HTTPClient, opt.HTTPClient)

			// TODO cache userinfo for access token if we can determine the expiry
			// (which works in case it is a jwt based access token)
			userInfo, err := provider.UserInfo(customCtx, oauth2.StaticTokenSource(&oauth2.Token{
				AccessToken: token,
			}))
			if err != nil {
				// The access token is a credential and is intentionally not logged.
				opt.Logger.Error().Err(err).Msg("Failed to get userinfo")
				http.Error(w, ErrInvalidToken.Error(), http.StatusUnauthorized)
				return
			}

			// The claims we want to have.
			var claims ocisoidc.StandardClaims
			if err := userInfo.Claims(&claims); err != nil {
				opt.Logger.Error().Err(err).Interface("userinfo", userInfo).Msg("failed to unmarshal userinfo claims")
				w.WriteHeader(http.StatusInternalServerError)
				return
			}

			// TODO: This should be read from the token instead of config
			claims.Iss = opt.OIDCIss

			opt.Logger.Debug().Interface("claims", claims).Interface("userInfo", userInfo).Msg("unmarshalled userinfo")

			// Inject the claims into the request context for the account_uuid
			// middleware. This derives from the original request context, not
			// from customCtx with its probably reduced security.
			next.ServeHTTP(w, r.WithContext(ocisoidc.NewContext(r.Context(), &claims)))
		})
	}
}
// AccountsCacheEntry is the value kept in the cache for a request to the
// accounts service.
// NOTE: this type declaration should live in the respective service.
type AccountsCacheEntry struct {
	// Email and UUID are presumably the mail address and the unique
	// identifier of the cached account; verify against the callers.
	Email, UUID string
}
// AccountsKey is the svcKey under which the Accounts service is addressed.
const AccountsKey = "accounts"

// NodeKey is the key used to store the node address. The same key is shared
// between services.
const NodeKey = "node"