Files
mantrae/server/internal/api/service/router_ops.go
2025-08-04 16:38:54 +02:00

641 lines
16 KiB
Go

package service
import (
"context"
"log/slog"
"github.com/google/uuid"
mantraev1 "github.com/mizuchilabs/mantrae/proto/gen/mantrae/v1"
"github.com/mizuchilabs/mantrae/server/internal/config"
"github.com/mizuchilabs/mantrae/server/internal/store/db"
"github.com/mizuchilabs/mantrae/server/internal/store/schema"
)
// RouterOps is the set of CRUD operations shared by the HTTP, TCP and UDP
// router services. Each protocol has its own implementation in this file.
type RouterOps interface {
	// Get returns the router identified by req.Id.
	Get(ctx context.Context, req *mantraev1.GetRouterRequest) (*mantraev1.GetRouterResponse, error)
	// Create stores a new router and returns it.
	Create(ctx context.Context, req *mantraev1.CreateRouterRequest) (*mantraev1.CreateRouterResponse, error)
	// Update modifies an existing router and returns the stored result.
	Update(ctx context.Context, req *mantraev1.UpdateRouterRequest) (*mantraev1.UpdateRouterResponse, error)
	// Delete removes the router identified by req.Id.
	Delete(ctx context.Context, req *mantraev1.DeleteRouterRequest) (*mantraev1.DeleteRouterResponse, error)
	// List returns a page of routers together with the total matching count.
	List(ctx context.Context, req *mantraev1.ListRoutersRequest) (*mantraev1.ListRoutersResponse, error)
}
type (
	// HTTPRouterOps implements RouterOps for HTTP routers.
	HTTPRouterOps struct {
		app *config.App
	}

	// TCPRouterOps implements RouterOps for TCP routers.
	TCPRouterOps struct {
		app *config.App
	}

	// UDPRouterOps implements RouterOps for UDP routers.
	UDPRouterOps struct {
		app *config.App
	}
)
// NewHTTPRouterOps returns an HTTPRouterOps that operates through app.
func NewHTTPRouterOps(app *config.App) *HTTPRouterOps {
	ops := &HTTPRouterOps{}
	ops.app = app
	return ops
}

// NewTCPRouterOps returns a TCPRouterOps that operates through app.
func NewTCPRouterOps(app *config.App) *TCPRouterOps {
	ops := &TCPRouterOps{}
	ops.app = app
	return ops
}

// NewUDPRouterOps returns a UDPRouterOps that operates through app.
func NewUDPRouterOps(app *config.App) *UDPRouterOps {
	ops := &UDPRouterOps{}
	ops.app = app
	return ops
}
// HTTP Router Operations -----------------------------------------------------
// Get loads the HTTP router identified by req.Id and attaches the DNS
// providers linked to it.
func (s *HTTPRouterOps) Get(
	ctx context.Context,
	req *mantraev1.GetRouterRequest,
) (*mantraev1.GetRouterResponse, error) {
	q := s.app.Conn.GetQuery()

	row, err := q.GetHttpRouter(ctx, req.Id)
	if err != nil {
		return nil, err
	}
	providers, err := q.GetDnsProvidersByHttpRouter(ctx, row.ID)
	if err != nil {
		return nil, err
	}

	out := row.ToProto()
	for i := range providers {
		out.DnsProviders = append(out.DnsProviders, providers[i].ToProto())
	}
	return &mantraev1.GetRouterResponse{Router: out}, nil
}
// Create inserts a new HTTP router under a freshly generated UUID. The
// config's Service field is forced to the router name. The stored router,
// with whatever DNS providers are linked to it, is broadcast as a CREATED
// event and returned.
func (s *HTTPRouterOps) Create(
	ctx context.Context,
	req *mantraev1.CreateRouterRequest,
) (*mantraev1.CreateRouterResponse, error) {
	cfg, err := db.UnmarshalStruct[schema.HTTPRouter](req.Config)
	if err != nil {
		return nil, err
	}
	cfg.Service = req.Name

	q := s.app.Conn.GetQuery()
	created, err := q.CreateHttpRouter(ctx, db.CreateHttpRouterParams{
		ID:        uuid.New().String(),
		ProfileID: req.ProfileId,
		Name:      req.Name,
		AgentID:   req.AgentId,
		Config:    cfg,
	})
	if err != nil {
		return nil, err
	}

	out := created.ToProto()
	providers, err := q.GetDnsProvidersByHttpRouter(ctx, created.ID)
	if err != nil {
		return nil, err
	}
	out.DnsProviders = make([]*mantraev1.DnsProvider, len(providers))
	for i, p := range providers {
		out.DnsProviders[i] = p.ToProto()
	}

	s.app.Event.Broadcast(&mantraev1.EventStreamResponse{
		Action: mantraev1.EventAction_EVENT_ACTION_CREATED,
		Data:   &mantraev1.EventStreamResponse_Router{Router: out},
	})
	return &mantraev1.CreateRouterResponse{Router: out}, nil
}
// Update overwrites an HTTP router's name, enabled flag and config. The
// config's Service field is forced to the router name. It also reconciles
// the router's DNS provider links with req.DnsProviders: links not yet
// present are created and links absent from the request are deleted.
// Duplicate provider IDs in the request are linked only once. The stored
// router, with its current DNS providers, is broadcast as an UPDATED event
// and returned.
//
// NOTE(review): link changes and the router update are separate queries with
// no visible transaction, so a failure part-way leaves earlier changes applied.
func (s *HTTPRouterOps) Update(
	ctx context.Context,
	req *mantraev1.UpdateRouterRequest,
) (*mantraev1.UpdateRouterResponse, error) {
	params := db.UpdateHttpRouterParams{
		ID:      req.Id,
		Name:    req.Name,
		Enabled: req.Enabled,
	}
	var err error
	params.Config, err = db.UnmarshalStruct[schema.HTTPRouter](req.Config)
	if err != nil {
		return nil, err
	}
	params.Config.Service = params.Name

	q := s.app.Conn.GetQuery()

	// Reconcile DNS provider links.
	existing, err := q.GetDnsProvidersByHttpRouter(ctx, params.ID)
	if err != nil {
		return nil, err
	}
	current := make(map[string]struct{}, len(existing))
	for _, p := range existing {
		current[p.ID] = struct{}{}
	}

	desired := make(map[string]struct{}, len(req.DnsProviders))
	for _, p := range req.DnsProviders {
		if _, seen := desired[p.Id]; seen {
			// A repeated ID would otherwise trigger a second insert of the
			// same router/provider link.
			continue
		}
		desired[p.Id] = struct{}{}
		if _, linked := current[p.Id]; linked {
			continue
		}
		if err := q.CreateHttpRouterDNSProvider(ctx, db.CreateHttpRouterDNSProviderParams{
			HttpRouterID:  params.ID,
			DnsProviderID: p.Id,
		}); err != nil {
			return nil, err
		}
	}
	for id := range current {
		if _, keep := desired[id]; keep {
			continue
		}
		if err := q.DeleteHttpRouterDNSProvider(ctx, db.DeleteHttpRouterDNSProviderParams{
			HttpRouterID:  params.ID,
			DnsProviderID: id,
		}); err != nil {
			return nil, err
		}
	}

	result, err := q.UpdateHttpRouter(ctx, params)
	if err != nil {
		return nil, err
	}

	// Re-read the links so the response reflects what is actually stored.
	router := result.ToProto()
	dnsProviders, err := q.GetDnsProvidersByHttpRouter(ctx, result.ID)
	if err != nil {
		return nil, err
	}
	router.DnsProviders = make([]*mantraev1.DnsProvider, 0, len(dnsProviders))
	for _, p := range dnsProviders {
		router.DnsProviders = append(router.DnsProviders, p.ToProto())
	}

	s.app.Event.Broadcast(&mantraev1.EventStreamResponse{
		Action: mantraev1.EventAction_EVENT_ACTION_UPDATED,
		Data: &mantraev1.EventStreamResponse_Router{
			Router: router,
		},
	})
	return &mantraev1.UpdateRouterResponse{
		Router: router,
	}, nil
}
// Delete removes the HTTP router identified by req.Id. If the router's config
// names a service, that HTTP service (looked up by profile and name) is
// deleted first, best-effort: lookup and delete failures are logged, not
// returned. The deleted router is then broadcast as a DELETED event.
func (s *HTTPRouterOps) Delete(
	ctx context.Context,
	req *mantraev1.DeleteRouterRequest,
) (*mantraev1.DeleteRouterResponse, error) {
	q := s.app.Conn.GetQuery()

	router, err := q.GetHttpRouter(ctx, req.Id)
	if err != nil {
		return nil, err
	}
	if name := router.Config.Service; name != "" {
		svc, err := q.GetHttpServiceByName(ctx, db.GetHttpServiceByNameParams{
			ProfileID: router.ProfileID,
			Name:      name,
		})
		if err != nil {
			// No service row means no ID to delete; previously the zero-value
			// ID was passed to DeleteHttpService here.
			slog.Error("failed to get http service", "service", name, "err", err)
		} else if err := q.DeleteHttpService(ctx, svc.ID); err != nil {
			slog.Error("failed to delete http service", "service", name, "err", err)
		}
	}

	if err := q.DeleteHttpRouter(ctx, req.Id); err != nil {
		return nil, err
	}
	s.app.Event.Broadcast(&mantraev1.EventStreamResponse{
		Action: mantraev1.EventAction_EVENT_ACTION_DELETED,
		Data: &mantraev1.EventStreamResponse_Router{
			Router: router.ToProto(),
		},
	})
	return &mantraev1.DeleteRouterResponse{}, nil
}
// List returns one page (Limit/Offset) of HTTP routers filtered by profile
// and agent, each with its linked DNS providers, plus the total number of
// routers matching the same filter.
func (s *HTTPRouterOps) List(
	ctx context.Context,
	req *mantraev1.ListRoutersRequest,
) (*mantraev1.ListRoutersResponse, error) {
	q := s.app.Conn.GetQuery()

	rows, err := q.ListHttpRouters(ctx, db.ListHttpRoutersParams{
		ProfileID: req.ProfileId,
		AgentID:   req.AgentId,
		Limit:     req.Limit,
		Offset:    req.Offset,
	})
	if err != nil {
		return nil, err
	}
	total, err := q.CountHttpRouters(ctx, db.CountHttpRoutersParams{
		ProfileID: req.ProfileId,
		AgentID:   req.AgentId,
	})
	if err != nil {
		return nil, err
	}

	resp := &mantraev1.ListRoutersResponse{
		Routers:    make([]*mantraev1.Router, 0, len(rows)),
		TotalCount: total,
	}
	for _, row := range rows {
		providers, err := q.GetDnsProvidersByHttpRouter(ctx, row.ID)
		if err != nil {
			return nil, err
		}
		item := row.ToProto()
		for _, p := range providers {
			item.DnsProviders = append(item.DnsProviders, p.ToProto())
		}
		resp.Routers = append(resp.Routers, item)
	}
	return resp, nil
}
// TCP Router Operations ------------------------------------------------------
// Get loads the TCP router identified by req.Id and attaches the DNS
// providers linked to it.
//
// Fixed: the providers were previously queried through the HTTP-router link
// table, and the response used a fresh ToProto() that dropped them anyway.
func (s *TCPRouterOps) Get(
	ctx context.Context,
	req *mantraev1.GetRouterRequest,
) (*mantraev1.GetRouterResponse, error) {
	q := s.app.Conn.GetQuery()

	result, err := q.GetTcpRouter(ctx, req.Id)
	if err != nil {
		return nil, err
	}
	dnsProviders, err := q.GetDnsProvidersByTcpRouter(ctx, result.ID)
	if err != nil {
		return nil, err
	}

	router := result.ToProto()
	for _, p := range dnsProviders {
		router.DnsProviders = append(router.DnsProviders, p.ToProto())
	}
	return &mantraev1.GetRouterResponse{
		Router: router,
	}, nil
}
// Create inserts a new TCP router under a freshly generated UUID. The
// config's Service field is forced to the router name. The stored router is
// broadcast as a CREATED event and returned.
func (s *TCPRouterOps) Create(
	ctx context.Context,
	req *mantraev1.CreateRouterRequest,
) (*mantraev1.CreateRouterResponse, error) {
	cfg, err := db.UnmarshalStruct[schema.TCPRouter](req.Config)
	if err != nil {
		return nil, err
	}
	cfg.Service = req.Name

	created, err := s.app.Conn.GetQuery().CreateTcpRouter(ctx, db.CreateTcpRouterParams{
		ID:        uuid.New().String(),
		ProfileID: req.ProfileId,
		Name:      req.Name,
		AgentID:   req.AgentId,
		Config:    cfg,
	})
	if err != nil {
		return nil, err
	}

	s.app.Event.Broadcast(&mantraev1.EventStreamResponse{
		Action: mantraev1.EventAction_EVENT_ACTION_CREATED,
		Data:   &mantraev1.EventStreamResponse_Router{Router: created.ToProto()},
	})
	return &mantraev1.CreateRouterResponse{Router: created.ToProto()}, nil
}
// Update overwrites a TCP router's name, enabled flag and config. The
// config's Service field is forced to the router name. It also reconciles
// the router's DNS provider links with req.DnsProviders: links not yet
// present are created and links absent from the request are deleted.
// Duplicate provider IDs in the request are linked only once. The stored
// router, with its current DNS providers (matching HTTPRouterOps.Update), is
// broadcast as an UPDATED event and returned.
//
// NOTE(review): link changes and the router update are separate queries with
// no visible transaction, so a failure part-way leaves earlier changes applied.
func (s *TCPRouterOps) Update(
	ctx context.Context,
	req *mantraev1.UpdateRouterRequest,
) (*mantraev1.UpdateRouterResponse, error) {
	params := db.UpdateTcpRouterParams{
		ID:      req.Id,
		Name:    req.Name,
		Enabled: req.Enabled,
	}
	var err error
	params.Config, err = db.UnmarshalStruct[schema.TCPRouter](req.Config)
	if err != nil {
		return nil, err
	}
	params.Config.Service = params.Name

	q := s.app.Conn.GetQuery()

	// Reconcile DNS provider links.
	existing, err := q.GetDnsProvidersByTcpRouter(ctx, params.ID)
	if err != nil {
		return nil, err
	}
	current := make(map[string]struct{}, len(existing))
	for _, p := range existing {
		current[p.ID] = struct{}{}
	}

	desired := make(map[string]struct{}, len(req.DnsProviders))
	for _, p := range req.DnsProviders {
		if _, seen := desired[p.Id]; seen {
			// A repeated ID would otherwise trigger a second insert of the
			// same router/provider link.
			continue
		}
		desired[p.Id] = struct{}{}
		if _, linked := current[p.Id]; linked {
			continue
		}
		if err := q.CreateTcpRouterDNSProvider(ctx, db.CreateTcpRouterDNSProviderParams{
			TcpRouterID:   params.ID,
			DnsProviderID: p.Id,
		}); err != nil {
			return nil, err
		}
	}
	for id := range current {
		if _, keep := desired[id]; keep {
			continue
		}
		if err := q.DeleteTcpRouterDNSProvider(ctx, db.DeleteTcpRouterDNSProviderParams{
			TcpRouterID:   params.ID,
			DnsProviderID: id,
		}); err != nil {
			return nil, err
		}
	}

	result, err := q.UpdateTcpRouter(ctx, params)
	if err != nil {
		return nil, err
	}

	// Re-read the links so the response and event carry the providers that
	// were just reconciled instead of an empty list.
	router := result.ToProto()
	dnsProviders, err := q.GetDnsProvidersByTcpRouter(ctx, result.ID)
	if err != nil {
		return nil, err
	}
	router.DnsProviders = make([]*mantraev1.DnsProvider, 0, len(dnsProviders))
	for _, p := range dnsProviders {
		router.DnsProviders = append(router.DnsProviders, p.ToProto())
	}

	s.app.Event.Broadcast(&mantraev1.EventStreamResponse{
		Action: mantraev1.EventAction_EVENT_ACTION_UPDATED,
		Data: &mantraev1.EventStreamResponse_Router{
			Router: router,
		},
	})
	return &mantraev1.UpdateRouterResponse{
		Router: router,
	}, nil
}
// Delete removes the TCP router identified by req.Id. If the router's config
// names a service, that TCP service (looked up by profile and name) is
// deleted first, best-effort: lookup and delete failures are logged, not
// returned. The deleted router is then broadcast as a DELETED event.
func (s *TCPRouterOps) Delete(
	ctx context.Context,
	req *mantraev1.DeleteRouterRequest,
) (*mantraev1.DeleteRouterResponse, error) {
	q := s.app.Conn.GetQuery()

	router, err := q.GetTcpRouter(ctx, req.Id)
	if err != nil {
		return nil, err
	}
	if name := router.Config.Service; name != "" {
		svc, err := q.GetTcpServiceByName(ctx, db.GetTcpServiceByNameParams{
			ProfileID: router.ProfileID,
			Name:      name,
		})
		if err != nil {
			// No service row means no ID to delete; previously the zero-value
			// ID was passed to DeleteTcpService here.
			slog.Error("failed to get tcp service", "service", name, "err", err)
		} else if err := q.DeleteTcpService(ctx, svc.ID); err != nil {
			slog.Error("failed to delete tcp service", "service", name, "err", err)
		}
	}

	if err := q.DeleteTcpRouter(ctx, req.Id); err != nil {
		return nil, err
	}
	s.app.Event.Broadcast(&mantraev1.EventStreamResponse{
		Action: mantraev1.EventAction_EVENT_ACTION_DELETED,
		Data: &mantraev1.EventStreamResponse_Router{
			Router: router.ToProto(),
		},
	})
	return &mantraev1.DeleteRouterResponse{}, nil
}
// List returns one page (Limit/Offset) of TCP routers filtered by profile
// and agent, each with its linked DNS providers, plus the total number of
// routers matching the same filter.
func (s *TCPRouterOps) List(
	ctx context.Context,
	req *mantraev1.ListRoutersRequest,
) (*mantraev1.ListRoutersResponse, error) {
	q := s.app.Conn.GetQuery()

	rows, err := q.ListTcpRouters(ctx, db.ListTcpRoutersParams{
		ProfileID: req.ProfileId,
		AgentID:   req.AgentId,
		Limit:     req.Limit,
		Offset:    req.Offset,
	})
	if err != nil {
		return nil, err
	}
	total, err := q.CountTcpRouters(ctx, db.CountTcpRoutersParams{
		ProfileID: req.ProfileId,
		AgentID:   req.AgentId,
	})
	if err != nil {
		return nil, err
	}

	resp := &mantraev1.ListRoutersResponse{
		Routers:    make([]*mantraev1.Router, 0, len(rows)),
		TotalCount: total,
	}
	for _, row := range rows {
		providers, err := q.GetDnsProvidersByTcpRouter(ctx, row.ID)
		if err != nil {
			return nil, err
		}
		item := row.ToProto()
		for _, p := range providers {
			item.DnsProviders = append(item.DnsProviders, p.ToProto())
		}
		resp.Routers = append(resp.Routers, item)
	}
	return resp, nil
}
// UDP Router Operations ------------------------------------------------------
// Get loads the UDP router identified by req.Id.
func (s *UDPRouterOps) Get(
	ctx context.Context,
	req *mantraev1.GetRouterRequest,
) (*mantraev1.GetRouterResponse, error) {
	row, err := s.app.Conn.GetQuery().GetUdpRouter(ctx, req.Id)
	if err != nil {
		return nil, err
	}
	resp := &mantraev1.GetRouterResponse{}
	resp.Router = row.ToProto()
	return resp, nil
}
// Create inserts a new UDP router under a freshly generated UUID. The
// config's Service field is forced to the router name. The stored router is
// broadcast as a CREATED event and returned.
func (s *UDPRouterOps) Create(
	ctx context.Context,
	req *mantraev1.CreateRouterRequest,
) (*mantraev1.CreateRouterResponse, error) {
	cfg, err := db.UnmarshalStruct[schema.UDPRouter](req.Config)
	if err != nil {
		return nil, err
	}
	cfg.Service = req.Name

	created, err := s.app.Conn.GetQuery().CreateUdpRouter(ctx, db.CreateUdpRouterParams{
		ID:        uuid.New().String(),
		ProfileID: req.ProfileId,
		Name:      req.Name,
		AgentID:   req.AgentId,
		Config:    cfg,
	})
	if err != nil {
		return nil, err
	}

	s.app.Event.Broadcast(&mantraev1.EventStreamResponse{
		Action: mantraev1.EventAction_EVENT_ACTION_CREATED,
		Data:   &mantraev1.EventStreamResponse_Router{Router: created.ToProto()},
	})
	return &mantraev1.CreateRouterResponse{Router: created.ToProto()}, nil
}
// Update overwrites a UDP router's name, enabled flag and config. The
// config's Service field is forced to the router name. The stored router is
// broadcast as an UPDATED event and returned.
func (s *UDPRouterOps) Update(
	ctx context.Context,
	req *mantraev1.UpdateRouterRequest,
) (*mantraev1.UpdateRouterResponse, error) {
	cfg, err := db.UnmarshalStruct[schema.UDPRouter](req.Config)
	if err != nil {
		return nil, err
	}
	cfg.Service = req.Name

	updated, err := s.app.Conn.GetQuery().UpdateUdpRouter(ctx, db.UpdateUdpRouterParams{
		ID:      req.Id,
		Name:    req.Name,
		Enabled: req.Enabled,
		Config:  cfg,
	})
	if err != nil {
		return nil, err
	}

	s.app.Event.Broadcast(&mantraev1.EventStreamResponse{
		Action: mantraev1.EventAction_EVENT_ACTION_UPDATED,
		Data:   &mantraev1.EventStreamResponse_Router{Router: updated.ToProto()},
	})
	return &mantraev1.UpdateRouterResponse{Router: updated.ToProto()}, nil
}
// Delete removes the UDP router identified by req.Id. If the router's config
// names a service, that UDP service (looked up by profile and name) is
// deleted first, best-effort: lookup and delete failures are logged, not
// returned. The deleted router is then broadcast as a DELETED event.
func (s *UDPRouterOps) Delete(
	ctx context.Context,
	req *mantraev1.DeleteRouterRequest,
) (*mantraev1.DeleteRouterResponse, error) {
	q := s.app.Conn.GetQuery()

	router, err := q.GetUdpRouter(ctx, req.Id)
	if err != nil {
		return nil, err
	}
	if name := router.Config.Service; name != "" {
		svc, err := q.GetUdpServiceByName(ctx, db.GetUdpServiceByNameParams{
			ProfileID: router.ProfileID,
			Name:      name,
		})
		if err != nil {
			// No service row means no ID to delete; previously the zero-value
			// ID was passed to DeleteUdpService here.
			slog.Error("failed to get udp service", "service", name, "err", err)
		} else if err := q.DeleteUdpService(ctx, svc.ID); err != nil {
			slog.Error("failed to delete udp service", "service", name, "err", err)
		}
	}

	if err := q.DeleteUdpRouter(ctx, req.Id); err != nil {
		return nil, err
	}
	s.app.Event.Broadcast(&mantraev1.EventStreamResponse{
		Action: mantraev1.EventAction_EVENT_ACTION_DELETED,
		Data: &mantraev1.EventStreamResponse_Router{
			Router: router.ToProto(),
		},
	})
	return &mantraev1.DeleteRouterResponse{}, nil
}
// List returns one page (Limit/Offset) of UDP routers filtered by profile
// and agent, plus the total number of routers matching the same filter.
func (s *UDPRouterOps) List(
	ctx context.Context,
	req *mantraev1.ListRoutersRequest,
) (*mantraev1.ListRoutersResponse, error) {
	q := s.app.Conn.GetQuery()

	rows, err := q.ListUdpRouters(ctx, db.ListUdpRoutersParams{
		ProfileID: req.ProfileId,
		AgentID:   req.AgentId,
		Limit:     req.Limit,
		Offset:    req.Offset,
	})
	if err != nil {
		return nil, err
	}
	total, err := q.CountUdpRouters(ctx, db.CountUdpRoutersParams{
		ProfileID: req.ProfileId,
		AgentID:   req.AgentId,
	})
	if err != nil {
		return nil, err
	}

	routers := make([]*mantraev1.Router, len(rows))
	for i, row := range rows {
		routers[i] = row.ToProto()
	}
	return &mantraev1.ListRoutersResponse{Routers: routers, TotalCount: total}, nil
}