// Mirror of https://github.com/opencloud-eu/opencloud.git (synced 2026-01-09 05:39:52 -06:00).
// Original listing: Go, 108 lines, 3.0 KiB.

package middleware
|
|
|
|
import (
	"context"
	"crypto/tls"
	"fmt"
	"net/http"
	"strings"
	"sync"
	"time"

	oidc "github.com/coreos/go-oidc/v3/oidc"
	ocisoidc "github.com/owncloud/ocis/ocis-pkg/oidc"
	"golang.org/x/oauth2"
)
|
|
|
|
// newOIDCOptions initializes the available default options.
|
|
func newOIDCOptions(opts ...ocisoidc.Option) ocisoidc.Options {
|
|
opt := ocisoidc.Options{}
|
|
|
|
for _, o := range opts {
|
|
o(&opt)
|
|
}
|
|
|
|
return opt
|
|
}
|
|
|
|
// OpenIDConnect provides a middleware to check access secured by a static token.
|
|
func OpenIDConnect(opts ...ocisoidc.Option) func(http.Handler) http.Handler {
|
|
opt := newOIDCOptions(opts...)
|
|
|
|
// set defaults
|
|
if opt.Realm == "" {
|
|
opt.Realm = opt.Endpoint
|
|
}
|
|
if len(opt.SigningAlgs) < 1 {
|
|
opt.SigningAlgs = []string{"RS256", "PS256"}
|
|
}
|
|
|
|
var oidcProvider *oidc.Provider
|
|
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
header := r.Header.Get("Authorization")
|
|
|
|
if header == "" || !strings.HasPrefix(header, "Bearer ") {
|
|
w.Header().Add("WWW-Authenticate", fmt.Sprintf(`Bearer realm="%s"`, opt.Realm))
|
|
http.Error(w, ErrInvalidToken.Error(), http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
token := header[7:]
|
|
|
|
tr := &http.Transport{
|
|
TLSClientConfig: &tls.Config{
|
|
InsecureSkipVerify: opt.Insecure, //nolint:gosec
|
|
},
|
|
}
|
|
customHTTPClient := &http.Client{
|
|
Transport: tr,
|
|
Timeout: time.Second * 10,
|
|
}
|
|
customCtx := context.WithValue(r.Context(), oauth2.HTTPClient, customHTTPClient)
|
|
|
|
// use cached provider
|
|
if oidcProvider == nil {
|
|
// Initialize a provider by specifying the issuer URL.
|
|
// provider needs to be cached as when it is created
|
|
// it will fetch the keys from the issuer using the .well-known
|
|
// endpoint
|
|
provider, err := oidc.NewProvider(customCtx, opt.Endpoint)
|
|
if err != nil {
|
|
opt.Logger.Error().Err(err).Msg("could not initialize oidc provider")
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
return
|
|
}
|
|
oidcProvider = provider
|
|
}
|
|
|
|
// The claims we want to have
|
|
var claims map[string]interface{}
|
|
|
|
// TODO cache userinfo for access token if we can determine the expiry (which works in case it is a jwt based access token)
|
|
oauth2Token := &oauth2.Token{
|
|
AccessToken: token,
|
|
}
|
|
userInfo, err := oidcProvider.UserInfo(customCtx, oauth2.StaticTokenSource(oauth2Token))
|
|
if err != nil {
|
|
opt.Logger.Error().Err(err).Msg("Failed to get userinfo")
|
|
http.Error(w, ErrInvalidToken.Error(), http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// parse claims
|
|
if err := userInfo.Claims(&claims); err != nil {
|
|
opt.Logger.Error().Err(err).Interface("userinfo", userInfo).Msg("failed to unmarshal userinfo claims")
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
return
|
|
}
|
|
opt.Logger.Debug().Interface("claims", claims).Interface("userInfo", userInfo).Msg("unmarshalled userinfo")
|
|
// store claims in context
|
|
// uses the original context, not the one with probably reduced security
|
|
nr := r.WithContext(ocisoidc.NewContext(r.Context(), claims))
|
|
|
|
next.ServeHTTP(w, nr)
|
|
})
|
|
}
|
|
}
|