Files
PrivateCaptcha/pkg/api/server_enterprise.go

185 lines
7.2 KiB
Go

//go:build enterprise
package api
import (
"context"
"fmt"
"log/slog"
"net/http"
"time"
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/common"
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/db"
dbgen "github.com/PrivateCaptcha/PrivateCaptcha/pkg/db/generated"
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/monitoring"
"github.com/jackc/pgx/v5/pgtype"
"github.com/justinas/alice"
)
const (
maxAPIPostBodySize = 128 * 1024
maxPostPropertiesBodySize = 1024 * 1024
maxDeletePropertiesBodySize = 128 * 1024
maxUpdatePropertiesBodySize = 1024 * 1024
)
func (s *Server) setupEnterprise(rg *common.RouteGenerator, publicChain alice.Chain, apiRateLimiter func(next http.Handler) http.Handler) {
arg := func(s string) string {
return fmt.Sprintf("{%s}", s)
}
// "portal" API
portalAPIChain := publicChain.Append(s.Metrics.HandlerIDFunc(rg.LastPath), apiRateLimiter, monitoring.Traced, common.TimeoutHandler(5*time.Second), s.Auth.APIKey(headerAPIKey, dbgen.ApiKeyScopePortal))
// tasks
rg.Handle(rg.Get(common.AsyncTaskEndpoint, arg(common.ParamID)), portalAPIChain, http.HandlerFunc(s.getAsyncTask))
// orgs
rg.Handle(rg.Get(common.OrganizationsEndpoint), portalAPIChain, http.HandlerFunc(s.getUserOrgs))
rg.Handle(rg.Post(common.OrgEndpoint), portalAPIChain, http.MaxBytesHandler(http.HandlerFunc(s.postNewOrg), maxAPIPostBodySize))
rg.Handle(rg.Put(common.OrgEndpoint), portalAPIChain, http.MaxBytesHandler(http.HandlerFunc(s.updateOrg), maxAPIPostBodySize))
rg.Handle(rg.Delete(common.OrgEndpoint), portalAPIChain, http.MaxBytesHandler(http.HandlerFunc(s.deleteOrg), maxAPIPostBodySize))
// properties
rg.Handle(rg.Get(common.OrgEndpoint, arg(common.ParamOrg), common.PropertiesEndpoint), portalAPIChain, http.HandlerFunc(s.getOrgProperties))
rg.Handle(rg.Post(common.OrgEndpoint, arg(common.ParamOrg), common.PropertiesEndpoint), portalAPIChain, http.MaxBytesHandler(http.HandlerFunc(s.postNewProperties), maxPostPropertiesBodySize))
rg.Handle(rg.Delete(common.PropertiesEndpoint), portalAPIChain, http.MaxBytesHandler(http.HandlerFunc(s.deleteProperties), maxDeletePropertiesBodySize))
rg.Handle(rg.Put(common.PropertiesEndpoint), portalAPIChain, http.MaxBytesHandler(http.HandlerFunc(s.updateProperties), maxUpdatePropertiesBodySize))
rg.Handle(rg.Get(common.OrgEndpoint, arg(common.ParamOrg), common.PropertyEndpoint, arg(common.ParamProperty)), portalAPIChain, http.HandlerFunc(s.getOrgProperty))
}
func (s *Server) RegisterTaskHandlers(ctx context.Context) {
if ok := s.AsyncTasks.Register(createPropertiesHandlerID, s.handleCreateProperties); !ok {
slog.ErrorContext(ctx, "Failed to register async task handler", "handler", createPropertiesHandlerID)
}
if ok := s.AsyncTasks.Register(deletePropertiesHandlerID, s.handleDeleteProperties); !ok {
slog.ErrorContext(ctx, "Failed to register async task handler", "handler", deletePropertiesHandlerID)
}
if ok := s.AsyncTasks.Register(updatePropertiesHandlerID, s.handleUpdateProperties); !ok {
slog.ErrorContext(ctx, "Failed to register async task handler", "handler", updatePropertiesHandlerID)
}
}
func (s *Server) requestUser(ctx context.Context, readOnly bool) (*dbgen.User, *dbgen.APIKey, error) {
portalOwnerSource := &apiKeyOwnerSource{Store: s.BusinessDB, scope: dbgen.ApiKeyScopePortal}
id, _, err := portalOwnerSource.OwnerID(ctx, time.Now().UTC())
if err != nil {
return nil, nil, err
}
if portalOwnerSource.cachedKey != nil && (portalOwnerSource.cachedKey.Readonly && !readOnly) {
slog.WarnContext(ctx, "API key read-write attribute does not match", "expected", readOnly,
"actual", portalOwnerSource.cachedKey.Readonly)
return nil, nil, errAPIKeyReadOnly
}
user, err := s.BusinessDB.Impl().RetrieveUser(ctx, id)
if err != nil {
return nil, nil, err
}
if !user.SubscriptionID.Valid {
return nil, nil, db.ErrNoActiveSubscription
}
return user, portalOwnerSource.cachedKey, nil
}
func (s *Server) requestOrg(user *dbgen.User, r *http.Request, onlyOwner bool, allowedOrgID *pgtype.Int4) (*dbgen.Organization, error) {
ctx := r.Context()
orgID, value, err := common.IntPathArg(r, common.ParamOrg, s.IDHasher)
if err != nil {
slog.ErrorContext(ctx, "Failed to parse org path parameter", "value", value, common.ErrAttr(err))
return nil, db.ErrInvalidInput
}
if (allowedOrgID != nil) && allowedOrgID.Valid && (allowedOrgID.Int32 != orgID) {
slog.WarnContext(ctx, "Requested organization is not allowed for this requester", "allowedOrgID", allowedOrgID.Int32, "requestedOrgID", orgID)
return nil, db.ErrPermissions
}
org, err := s.BusinessDB.Impl().RetrieveUserOrganization(ctx, user, orgID)
if err != nil {
return nil, err
}
if onlyOwner {
if !org.UserID.Valid || (org.UserID.Int32 != user.ID) {
return nil, db.ErrPermissions
}
}
return org, nil
}
func (s *Server) requestProperty(org *dbgen.Organization, r *http.Request) (*dbgen.Property, error) {
ctx := r.Context()
propertyID, value, err := common.IntPathArg(r, common.ParamProperty, s.IDHasher)
if err != nil {
slog.ErrorContext(ctx, "Failed to parse property path parameter", "value", value, common.ErrAttr(err))
return nil, db.ErrInvalidInput
}
property, err := s.BusinessDB.Impl().RetrieveOrgProperty(ctx, org, propertyID)
if err != nil {
return nil, err
}
return property, nil
}
func (s *Server) sendHTTPErrorResponse(err error, w http.ResponseWriter) {
switch err {
case db.ErrRecordNotFound:
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
case db.ErrInvalidInput:
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
case db.ErrNoActiveSubscription:
http.Error(w, http.StatusText(http.StatusPaymentRequired), http.StatusPaymentRequired)
case db.ErrMaintenance:
http.Error(w, http.StatusText(http.StatusServiceUnavailable), http.StatusServiceUnavailable)
case errAPIKeyScope, errInvalidAPIKey, errAPIKeyNotSet, errAPIKeyReadOnly, db.ErrPermissions:
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
case db.ErrSoftDeleted:
http.Error(w, http.StatusText(http.StatusConflict), http.StatusConflict)
default:
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
}
func (s *Server) sendAPISuccessResponse(ctx context.Context, data interface{}, w http.ResponseWriter) {
s.sendAPISuccessResponseEx(ctx, &APIResponse{Data: data}, w, common.NoCacheHeaders)
}
func (s *Server) sendAPISuccessResponseEx(ctx context.Context, response *APIResponse, w http.ResponseWriter, headers ...map[string][]string) {
response.Meta = ResponseMetadata{
Code: common.StatusOK,
Description: common.StatusOK.String(),
}
if tid, ok := ctx.Value(common.TraceIDContextKey).(string); ok {
response.Meta.RequestID = tid
}
common.SendJSONResponse(ctx, w, response, headers...)
}
func (s *Server) sendAPIErrorResponse(ctx context.Context, code common.StatusCode, r *http.Request, w http.ResponseWriter) {
response := &APIResponse{
Meta: ResponseMetadata{
Code: code,
Description: code.String(),
},
}
if tid, ok := ctx.Value(common.TraceIDContextKey).(string); ok {
response.Meta.RequestID = tid
}
common.SendJSONResponse(ctx, w, response, common.NoCacheHeaders)
slog.WarnContext(ctx, "Returned API error response", "code", int(code))
s.Metrics.ObserveApiError(r.URL.Path, r.Method, int(code))
}