Merge pull request #213 from mayanayza/feat/rate-limiting

feat: rate limiting
This commit is contained in:
Maya
2025-11-28 16:47:14 -05:00
committed by GitHub
57 changed files with 837 additions and 449 deletions
+95 -1
View File
@@ -1200,6 +1200,20 @@ dependencies = [
"syn 2.0.110",
]
[[package]]
name = "dashmap"
version = "6.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf"
dependencies = [
"cfg-if",
"crossbeam-utils",
"hashbrown 0.14.5",
"lock_api",
"once_cell",
"parking_lot_core",
]
[[package]]
name = "data-encoding"
version = "2.9.0"
@@ -1661,6 +1675,12 @@ version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
[[package]]
name = "foldhash"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb"
[[package]]
name = "form_urlencoded"
version = "1.2.2"
@@ -1775,6 +1795,12 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
[[package]]
name = "futures-timer"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
[[package]]
name = "futures-util"
version = "0.3.31"
@@ -1837,6 +1863,29 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280"
[[package]]
name = "governor"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e23d5986fd4364c2fb7498523540618b4b8d92eec6c36a02e565f66748e2f79"
dependencies = [
"cfg-if",
"dashmap",
"futures-sink",
"futures-timer",
"futures-util",
"getrandom 0.3.4",
"hashbrown 0.16.0",
"nonzero_ext",
"parking_lot",
"portable-atomic",
"quanta",
"rand 0.9.2",
"smallvec",
"spinning_top",
"web-time",
]
[[package]]
name = "group"
version = "0.13.0"
@@ -1891,7 +1940,7 @@ checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1"
dependencies = [
"allocator-api2",
"equivalent",
"foldhash",
"foldhash 0.1.5",
]
[[package]]
@@ -1899,6 +1948,11 @@ name = "hashbrown"
version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5419bdc4f6a9207fbeba6d11b604d481addf78ecd10c11ad51e76c2f6482748d"
dependencies = [
"allocator-api2",
"equivalent",
"foldhash 0.2.0",
]
[[package]]
name = "hashlink"
@@ -2998,6 +3052,7 @@ dependencies = [
"fastrand",
"figment",
"futures",
"governor",
"hex",
"hostname",
"html2text",
@@ -3123,6 +3178,12 @@ dependencies = [
"memchr",
]
[[package]]
name = "nonzero_ext"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21"
[[package]]
name = "nu-ansi-term"
version = "0.50.3"
@@ -3979,6 +4040,21 @@ dependencies = [
"psl-types",
]
[[package]]
name = "quanta"
version = "0.12.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3ab5a9d756f0d97bdc89019bd2e4ea098cf9cde50ee7564dde6b81ccc8f06c7"
dependencies = [
"crossbeam-utils",
"libc",
"once_cell",
"raw-cpuid",
"wasi",
"web-sys",
"winapi",
]
[[package]]
name = "quinn"
version = "0.11.9"
@@ -4120,6 +4196,15 @@ version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acbbbbea733ec66275512d0b9694f34102e7d5406fdbe2ad8d21b28dce92887c"
[[package]]
name = "raw-cpuid"
version = "11.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186"
dependencies = [
"bitflags",
]
[[package]]
name = "redox_syscall"
version = "0.5.18"
@@ -4927,6 +5012,15 @@ dependencies = [
"lock_api",
]
[[package]]
name = "spinning_top"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300"
dependencies = [
"lock_api",
]
[[package]]
name = "spki"
version = "0.7.3"
+2 -1
View File
@@ -158,8 +158,9 @@ serde_with = "3.15.1"
lettre = { version = "0.11.19", default-features = false, features = ["smtp-transport", "builder", "tokio1", "tokio1-rustls", "ring", "webpki-roots"] }
html2text = "0.16.4"
json_value_merge = "2.0.1"
bad_email = "0.1.1"
axum-client-ip = "1.1.3"
governor = "0.10.2"
bad_email = "0.1.1"
# === Platform-specific Dependencies ===
[target.'cfg(target_os = "linux")'.dependencies]
+23 -13
View File
@@ -3,11 +3,15 @@ use std::{net::SocketAddr, str::FromStr, sync::Arc, time::Duration};
use axum::{
Extension, Router,
http::{HeaderValue, Method},
middleware,
};
use axum_client_ip::ClientIpSource;
use clap::Parser;
use netvisor::server::{
auth::middleware::AuthenticatedEntity,
auth::middleware::{
auth::AuthenticatedEntity, logging::request_logging_middleware,
rate_limit::rate_limit_middleware,
},
billing::types::base::{BillingPlan, BillingRate, PlanConfig},
config::{AppState, ServerCli, ServerConfig},
organizations::r#impl::base::{Organization, OrganizationBase},
@@ -43,8 +47,8 @@ async fn main() -> anyhow::Result<()> {
// Initialize tracing
tracing_subscriber::registry()
.with(tracing_subscriber::EnvFilter::new(format!(
"netvisor={},server={}",
config.log_level, config.log_level
"netvisor={},server={},request_log={}",
config.log_level, config.log_level, config.log_level
)))
.with(tracing_subscriber::fmt::layer())
.init();
@@ -109,14 +113,10 @@ async fn main() -> anyhow::Result<()> {
}
});
let session_store = state.storage.sessions.clone();
let base_router = create_router().with_state(state.clone());
let api_router = if let Some(static_path) = &web_external_path {
// First create the API router
let router = create_router().layer(session_store).with_state(state);
// Then add static file serving with SPA fallback
router.fallback_service(
base_router.fallback_service(
ServeDir::new(static_path)
.append_index_html_on_directories(true)
.fallback(ServeFile::new(format!(
@@ -126,9 +126,11 @@ async fn main() -> anyhow::Result<()> {
)
} else {
tracing::info!("Server is not serving web assets due to no web_external_path");
create_router().layer(session_store).with_state(state)
base_router
};
let session_store = state.storage.sessions.clone();
let cors = if cfg!(debug_assertions) {
// Development: Allow localhost with credentials
CorsLayer::new()
@@ -165,13 +167,21 @@ async fn main() -> anyhow::Result<()> {
// Create main app
let app = Router::new().merge(api_router).layer(
ServiceBuilder::new()
.layer(client_ip_source.into_extension())
.layer(TraceLayer::new_for_http())
.layer(cors)
.layer(session_store)
.layer(middleware::from_fn_with_state(
state.clone(),
request_logging_middleware,
))
.layer(middleware::from_fn_with_state(
state.clone(),
rate_limit_middleware,
))
.layer(Extension(app_cache))
.layer(cache_headers)
.layer(client_ip_source.into_extension()),
.layer(cache_headers),
);
let listener = tokio::net::TcpListener::bind(&listen_addr).await?;
let actual_port = listener.local_addr()?.port();
+11 -3
View File
@@ -144,6 +144,7 @@ pub trait RunsDiscovery: AsRef<DaemonDiscoveryService> + Send + Sync {
let server_target = self.as_ref().config_store.get_server_url().await?;
let session = self.as_ref().get_session().await?;
let discovery_type = self.discovery_type();
let daemon_id = self.as_ref().config_store.get_id().await?;
let api_key = self
.as_ref()
@@ -165,6 +166,7 @@ pub trait RunsDiscovery: AsRef<DaemonDiscoveryService> + Send + Sync {
"{}/api/discovery/{}/update",
server_target, session.info.session_id
))
.header("X-Daemon-ID", daemon_id.to_string())
.header("Authorization", format!("Bearer {}", api_key))
.json(&payload)
.send()
@@ -411,8 +413,7 @@ pub trait DiscoversNetworkedEntities:
ip = %interface.base.ip_address,
host_name = %host.base.name,
service_count = %services.len(),
"Processed host for ip {}",
interface.base.ip_address
"Processed host",
);
Ok(Some((host, services)))
}
@@ -596,7 +597,7 @@ pub trait CreatesDiscoveredEntities:
services: Vec<Service>,
) -> Result<(Host, Vec<Service>), Error> {
let server_target = self.as_ref().config_store.get_server_url().await?;
let daemon_id = self.as_ref().config_store.get_id().await?;
tracing::info!("Creating host {}", host.base.name);
let api_key = self
@@ -610,6 +611,7 @@ pub trait CreatesDiscoveredEntities:
.as_ref()
.client
.post(format!("{}/api/hosts", server_target))
.header("X-Daemon-ID", daemon_id.to_string())
.header("Authorization", format!("Bearer {}", api_key))
.json(&HostWithServicesRequest {
host,
@@ -645,6 +647,7 @@ pub trait CreatesDiscoveredEntities:
async fn create_subnet(&self, subnet: &Subnet) -> Result<Subnet, Error> {
let server_target = self.as_ref().config_store.get_server_url().await?;
let daemon_id = self.as_ref().config_store.get_id().await?;
let api_key = self
.as_ref()
@@ -657,6 +660,7 @@ pub trait CreatesDiscoveredEntities:
.as_ref()
.client
.post(format!("{}/api/subnets", server_target))
.header("X-Daemon-ID", daemon_id.to_string())
.header("Authorization", format!("Bearer {}", api_key))
.json(&subnet)
.send()
@@ -687,6 +691,7 @@ pub trait CreatesDiscoveredEntities:
async fn create_service(&self, service: &Service) -> Result<Service, Error> {
let server_target = self.as_ref().config_store.get_server_url().await?;
let daemon_id = self.as_ref().config_store.get_id().await?;
let api_key = self
.as_ref()
@@ -699,6 +704,7 @@ pub trait CreatesDiscoveredEntities:
.as_ref()
.client
.post(format!("{}/api/services", server_target))
.header("X-Daemon-ID", daemon_id.to_string())
.header("Authorization", format!("Bearer {}", api_key))
.json(&service)
.send()
@@ -729,6 +735,7 @@ pub trait CreatesDiscoveredEntities:
async fn create_group(&self, group: &Group) -> Result<Group, Error> {
let server_target = self.as_ref().config_store.get_server_url().await?;
let daemon_id = self.as_ref().config_store.get_id().await?;
let api_key = self
.as_ref()
@@ -741,6 +748,7 @@ pub trait CreatesDiscoveredEntities:
.as_ref()
.client
.post(format!("{}/api/groups", server_target))
.header("X-Daemon-ID", daemon_id.to_string())
.header("Authorization", format!("Bearer {}", api_key))
.json(&group)
.send()
@@ -295,6 +295,7 @@ impl DiscoveryRunner<SelfReportDiscovery> {
"{}/api/daemons/{}/update-capabilities",
server_target, daemon_id
))
.header("X-Daemon-ID", daemon_id.to_string())
.header("Authorization", format!("Bearer {}", api_key))
.json(&capabilities)
.send()
+3
View File
@@ -59,6 +59,7 @@ impl DaemonRuntimeService {
server_target, daemon_id
))
.json(&daemon_id)
.header("X-Daemon-ID", daemon_id.to_string())
.header("Authorization", format!("Bearer {}", api_key))
.send()
.await?;
@@ -156,6 +157,7 @@ impl DaemonRuntimeService {
"{}/api/daemons/{}/heartbeat",
server_target, daemon_id
))
.header("X-Daemon-ID", daemon_id.to_string())
.header("Authorization", format!("Bearer {}", api_key))
.send()
.await?;
@@ -275,6 +277,7 @@ impl DaemonRuntimeService {
let response = self
.client
.post(format!("{}/api/daemons/register", server_target))
.header("X-Daemon-ID", daemon_id.to_string())
.header("Authorization", format!("Bearer {}", api_key))
.json(&registration_request)
.send()
+1 -1
View File
@@ -1,6 +1,6 @@
use crate::server::{
api_keys::r#impl::{api::ApiKeyResponse, base::ApiKey},
auth::middleware::RequireMember,
auth::middleware::permissions::RequireMember,
config::AppState,
shared::{
events::types::{TelemetryEvent, TelemetryOperation},
+1 -1
View File
@@ -7,7 +7,7 @@ use uuid::Uuid;
use crate::server::{
api_keys::r#impl::base::ApiKey,
auth::middleware::{AuthenticatedEntity, AuthenticatedUser},
auth::middleware::auth::{AuthenticatedEntity, AuthenticatedUser},
shared::{
entities::ChangeTriggersTopologyStaleness,
events::{
+1 -1
View File
@@ -9,7 +9,7 @@ use crate::server::{
base::LoginRegisterParams,
oidc::{OidcFlow, OidcPendingAuth, OidcProviderMetadata},
},
middleware::AuthenticatedUser,
middleware::auth::AuthenticatedUser,
oidc::OidcService,
},
config::AppState,
@@ -1,13 +1,10 @@
use std::fmt::Display;
use crate::server::{
billing::types::base::BillingPlan,
config::AppState,
organizations::r#impl::base::Organization,
shared::{services::traits::CrudService, storage::filter::EntityFilter, types::api::ApiError},
users::r#impl::{base::User, permissions::UserOrgPermissions},
};
use async_trait::async_trait;
use axum::{
extract::FromRequestParts,
http::request::Parts,
@@ -20,7 +17,7 @@ use serde::Serialize;
use tower_sessions::Session;
use uuid::Uuid;
pub struct AuthError(ApiError);
pub struct AuthError(pub ApiError);
impl IntoResponse for AuthError {
fn into_response(self) -> Response {
@@ -41,6 +38,7 @@ pub enum AuthenticatedEntity {
Daemon {
network_id: Uuid,
api_key_id: Uuid,
daemon_id: Uuid,
}, // network_id
System,
Anonymous,
@@ -77,13 +75,7 @@ impl AuthenticatedEntity {
pub fn entity_id(&self) -> String {
match self {
AuthenticatedEntity::User { user_id, .. } => user_id.to_string(),
AuthenticatedEntity::Daemon {
network_id,
api_key_id,
} => format!(
"Daemon for network {} using API key {}",
network_id, api_key_id
),
AuthenticatedEntity::Daemon { daemon_id, .. } => daemon_id.to_string(),
AuthenticatedEntity::System => "System".to_string(),
AuthenticatedEntity::Anonymous => "Anonymous".to_string(),
}
@@ -136,6 +128,11 @@ where
if let Some(auth_header) = parts.headers.get(axum::http::header::AUTHORIZATION)
&& let Ok(auth_str) = auth_header.to_str()
&& let Some(api_key) = auth_str.strip_prefix("Bearer ")
&& let Some(daemon_id) = parts
.headers
.get("X-Daemon-ID")
.and_then(|h| h.to_str().ok())
.and_then(|s| Uuid::parse_str(s).ok())
{
let api_key_filter = EntityFilter::unfiltered().api_key(api_key.to_owned());
// Get API key record by key
@@ -181,6 +178,7 @@ where
return Ok(AuthenticatedEntity::Daemon {
network_id,
api_key_id,
daemon_id,
});
}
// Invalid API key
@@ -294,6 +292,7 @@ where
pub struct AuthenticatedDaemon {
pub network_id: Uuid,
pub api_key_id: Uuid,
pub daemon_id: Uuid,
}
impl From<AuthenticatedDaemon> for AuthenticatedEntity {
@@ -301,6 +300,7 @@ impl From<AuthenticatedDaemon> for AuthenticatedEntity {
AuthenticatedEntity::Daemon {
network_id: value.network_id,
api_key_id: value.api_key_id,
daemon_id: value.daemon_id,
}
}
}
@@ -318,9 +318,11 @@ where
AuthenticatedEntity::Daemon {
network_id,
api_key_id,
daemon_id,
} => Ok(AuthenticatedDaemon {
network_id,
api_key_id,
daemon_id,
}),
_ => Err(AuthError(ApiError::unauthorized(
"Daemon authentication required".to_string(),
@@ -328,315 +330,3 @@ where
}
}
}
/// Extractor that accepts either a Member+ user OR a daemon
/// Returns the network IDs the authenticated entity has access to
pub struct MemberOrDaemon {
pub network_ids: Vec<Uuid>,
pub entity: AuthenticatedEntity,
}
impl<S> FromRequestParts<S> for MemberOrDaemon
where
S: Send + Sync + AsRef<AppState>,
{
type Rejection = AuthError;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
// Get the authenticated entity (works for both users and daemons)
let entity = AuthenticatedEntity::from_request_parts(parts, state).await?;
match entity {
AuthenticatedEntity::User { .. } => {
// For users, check they're at least Member level
let member = RequireMember::from_request_parts(parts, state).await?;
let user: AuthenticatedUser = member.into();
Ok(MemberOrDaemon {
network_ids: user.network_ids.clone(),
entity: user.into(),
})
}
AuthenticatedEntity::Daemon { network_id, .. } => {
// Daemons only have access to their single network
Ok(MemberOrDaemon {
network_ids: vec![network_id],
entity,
})
}
_ => Err(AuthError(ApiError::forbidden(
"Member or Daemon permission required",
))),
}
}
}
/// Extractor that requires the user to be at least an Owner
pub struct RequireOwner(pub AuthenticatedUser);
impl From<RequireOwner> for AuthenticatedUser {
fn from(value: RequireOwner) -> Self {
value.0
}
}
impl<S> FromRequestParts<S> for RequireOwner
where
S: Send + Sync + AsRef<AppState>,
{
type Rejection = AuthError;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let user = AuthenticatedUser::from_request_parts(parts, state).await?;
if user.permissions < UserOrgPermissions::Owner {
return Err(AuthError(ApiError::forbidden("Owner permission required")));
}
Ok(RequireOwner(user))
}
}
/// Extractor that requires the user to be at least an Admin
pub struct RequireAdmin(pub AuthenticatedUser);
impl From<RequireAdmin> for AuthenticatedUser {
fn from(value: RequireAdmin) -> Self {
value.0
}
}
impl From<RequireOwner> for RequireAdmin {
fn from(value: RequireOwner) -> Self {
RequireAdmin(value.0)
}
}
impl<S> FromRequestParts<S> for RequireAdmin
where
S: Send + Sync + AsRef<AppState>,
{
type Rejection = AuthError;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let user = AuthenticatedUser::from_request_parts(parts, state).await?;
if user.permissions < UserOrgPermissions::Admin {
return Err(AuthError(ApiError::forbidden("Admin permission required")));
}
Ok(RequireAdmin(user))
}
}
/// Extractor that requires the user to be at least a Member
pub struct RequireMember(pub AuthenticatedUser);
impl From<RequireMember> for AuthenticatedUser {
fn from(value: RequireMember) -> Self {
value.0
}
}
impl From<RequireOwner> for RequireMember {
fn from(value: RequireOwner) -> Self {
RequireMember(value.0)
}
}
impl From<RequireAdmin> for RequireMember {
fn from(value: RequireAdmin) -> Self {
RequireMember(value.0)
}
}
impl<S> FromRequestParts<S> for RequireMember
where
S: Send + Sync + AsRef<AppState>,
{
type Rejection = AuthError;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let user = AuthenticatedUser::from_request_parts(parts, state).await?;
if user.permissions < UserOrgPermissions::Member {
return Err(AuthError(ApiError::forbidden("Member permission required")));
}
Ok(RequireMember(user))
}
}
/// Context available for feature/quota checks
pub struct FeatureCheckContext<'a> {
pub organization: &'a Organization,
pub plan: BillingPlan,
pub app_state: &'a AppState,
}
pub enum FeatureCheckResult {
Allowed,
Denied { message: String },
}
impl FeatureCheckResult {
pub fn denied(msg: impl Into<String>) -> Self {
Self::Denied {
message: msg.into(),
}
}
pub fn is_allowed(&self) -> bool {
matches!(self, Self::Allowed)
}
}
#[async_trait]
pub trait FeatureCheck: Send + Sync + Default {
async fn check(&self, ctx: &FeatureCheckContext<'_>) -> FeatureCheckResult;
}
// ============ Extractor ============
pub struct RequireFeature<T: FeatureCheck> {
pub permissions: UserOrgPermissions,
pub plan: BillingPlan,
pub organization: Organization,
pub _phantom: std::marker::PhantomData<T>,
}
impl<S, T> FromRequestParts<S> for RequireFeature<T>
where
S: Send + Sync + AsRef<AppState>,
T: FeatureCheck + Default,
{
type Rejection = AuthError;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let AuthenticatedUser {
permissions,
organization_id,
..
} = AuthenticatedUser::from_request_parts(parts, state).await?;
let app_state = state.as_ref();
let organization = app_state
.services
.organization_service
.get_by_id(&organization_id)
.await
.map_err(|_| AuthError(ApiError::internal_error("Failed to load organization")))?
.ok_or_else(|| AuthError(ApiError::forbidden("Organization not found")))?;
let plan = organization.base.plan.unwrap_or_default();
let ctx = FeatureCheckContext {
organization: &organization,
plan,
app_state,
};
let checker = T::default();
match checker.check(&ctx).await {
FeatureCheckResult::Allowed => Ok(RequireFeature {
permissions,
plan,
organization,
_phantom: std::marker::PhantomData,
}),
FeatureCheckResult::Denied { message } => Err(AuthError(ApiError::forbidden(&message))),
}
}
}
// ============ Concrete Checkers ============
#[derive(Default)]
pub struct InviteUsersFeature;
#[async_trait]
impl FeatureCheck for InviteUsersFeature {
async fn check(&self, ctx: &FeatureCheckContext<'_>) -> FeatureCheckResult {
let features = ctx.plan.features();
if !features.share_views {
return FeatureCheckResult::denied(
"Your plan does not include team collaboration features",
);
}
// Check seat quota if there's a limit and user doesn't have a plan that lets them buy more seats
if let Some(max_seats) = ctx.plan.config().included_seats
&& ctx.plan.config().seat_cents.is_none()
{
let org_filter = EntityFilter::unfiltered().organization_id(&ctx.organization.id);
let current_members = ctx
.app_state
.services
.user_service
.get_all(org_filter)
.await
.unwrap_or_default()
.iter()
.filter(|u| u.base.permissions.counts_towards_seats())
.count();
let pending_invites = ctx
.app_state
.services
.organization_service
.get_org_invites(&ctx.organization.id)
.await
.unwrap_or_default()
.iter()
.filter(|i| i.permissions.counts_towards_seats())
.count();
let total_seats_used = current_members + pending_invites;
if total_seats_used >= max_seats as usize {
return FeatureCheckResult::denied(format!(
"Seat limit reached ({}/{}). Upgrade your plan for more seats, or delete any unused pending invites.",
total_seats_used, max_seats
));
}
}
FeatureCheckResult::Allowed
}
}
#[derive(Default)]
pub struct CreateNetworkFeature;
#[async_trait]
impl FeatureCheck for CreateNetworkFeature {
async fn check(&self, ctx: &FeatureCheckContext<'_>) -> FeatureCheckResult {
// Check networks quota if there's a limit and user doesn't have a plan that lets them buy more networks
if let Some(max_networks) = ctx.plan.config().included_networks
&& ctx.plan.config().network_cents.is_none()
{
let org_filter = EntityFilter::unfiltered().organization_id(&ctx.organization.id);
let current_networks = ctx
.app_state
.services
.network_service
.get_all(org_filter)
.await
.map(|o| o.len())
.unwrap_or(0);
if current_networks >= max_networks as usize {
return FeatureCheckResult::denied(format!(
"Network limit reached ({}/{}). Upgrade your plan for more networks.",
current_networks, max_networks
));
}
}
FeatureCheckResult::Allowed
}
}
@@ -0,0 +1,145 @@
use crate::server::{
auth::middleware::auth::{AuthError, AuthenticatedUser},
billing::types::base::BillingPlan,
config::AppState,
organizations::r#impl::base::Organization,
shared::{services::traits::CrudService, storage::filter::EntityFilter, types::api::ApiError},
users::r#impl::permissions::UserOrgPermissions,
};
use async_trait::async_trait;
use axum::{extract::FromRequestParts, http::request::Parts};
/// Context available for feature/quota checks
pub struct FeatureCheckContext<'a> {
pub organization: &'a Organization,
pub plan: BillingPlan,
pub app_state: &'a AppState,
}
pub enum FeatureCheckResult {
Allowed,
Denied { message: String },
}
impl FeatureCheckResult {
pub fn denied(msg: impl Into<String>) -> Self {
Self::Denied {
message: msg.into(),
}
}
pub fn is_allowed(&self) -> bool {
matches!(self, Self::Allowed)
}
}
#[async_trait]
pub trait FeatureCheck: Send + Sync + Default {
async fn check(&self, ctx: &FeatureCheckContext<'_>) -> FeatureCheckResult;
}
// ============ Extractor ============
pub struct RequireFeature<T: FeatureCheck> {
pub permissions: UserOrgPermissions,
pub plan: BillingPlan,
pub organization: Organization,
pub _phantom: std::marker::PhantomData<T>,
}
impl<S, T> FromRequestParts<S> for RequireFeature<T>
where
S: Send + Sync + AsRef<AppState>,
T: FeatureCheck + Default,
{
type Rejection = AuthError;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let AuthenticatedUser {
permissions,
organization_id,
..
} = AuthenticatedUser::from_request_parts(parts, state).await?;
let app_state = state.as_ref();
let organization = app_state
.services
.organization_service
.get_by_id(&organization_id)
.await
.map_err(|_| AuthError(ApiError::internal_error("Failed to load organization")))?
.ok_or_else(|| AuthError(ApiError::forbidden("Organization not found")))?;
let plan = organization.base.plan.unwrap_or_default();
let ctx = FeatureCheckContext {
organization: &organization,
plan,
app_state,
};
let checker = T::default();
match checker.check(&ctx).await {
FeatureCheckResult::Allowed => Ok(RequireFeature {
permissions,
plan,
organization,
_phantom: std::marker::PhantomData,
}),
FeatureCheckResult::Denied { message } => Err(AuthError(ApiError::forbidden(&message))),
}
}
}
// ============ Concrete Checkers ============
#[derive(Default)]
pub struct InviteUsersFeature;
#[async_trait]
impl FeatureCheck for InviteUsersFeature {
async fn check(&self, ctx: &FeatureCheckContext<'_>) -> FeatureCheckResult {
let features = ctx.plan.features();
if !features.share_views {
return FeatureCheckResult::denied("Your plan does not include inviting users");
}
// Seat check happens in the handler where we have access to the request body
FeatureCheckResult::Allowed
}
}
#[derive(Default)]
pub struct CreateNetworkFeature;
#[async_trait]
impl FeatureCheck for CreateNetworkFeature {
async fn check(&self, ctx: &FeatureCheckContext<'_>) -> FeatureCheckResult {
// Check networks quota if there's a limit and user doesn't have a plan that lets them buy more networks
if let Some(max_networks) = ctx.plan.config().included_networks
&& ctx.plan.config().network_cents.is_none()
{
let org_filter = EntityFilter::unfiltered().organization_id(&ctx.organization.id);
let current_networks = ctx
.app_state
.services
.network_service
.get_all(org_filter)
.await
.map(|o| o.len())
.unwrap_or(0);
if current_networks >= max_networks as usize {
return FeatureCheckResult::denied(format!(
"Network limit reached ({}/{}). Upgrade your plan for more networks.",
current_networks, max_networks
));
}
}
FeatureCheckResult::Allowed
}
}
@@ -0,0 +1,66 @@
use axum::{
extract::{FromRequestParts, MatchedPath, Request, State},
middleware::Next,
response::Response,
};
use axum_client_ip::ClientIp;
use std::{sync::Arc, time::Instant};
use crate::server::{auth::middleware::auth::AuthenticatedEntity, config::AppState};
pub async fn request_logging_middleware(
State(state): State<Arc<AppState>>,
ClientIp(ip): ClientIp,
request: Request,
next: Next,
) -> Response {
let start = Instant::now();
// Extract info before consuming request
let method = request.method().clone();
let uri = request.uri().clone();
let path = request
.extensions()
.get::<MatchedPath>()
.map(|p| p.as_str().to_owned())
.unwrap_or_else(|| uri.path().to_owned());
// Extract auth info
let (mut parts, body) = request.into_parts();
let entity = AuthenticatedEntity::from_request_parts(&mut parts, &state)
.await
.ok();
let (entity_type, entity_id) = match &entity {
Some(AuthenticatedEntity::User { user_id, .. }) => ("user", Some(user_id.to_string())),
Some(AuthenticatedEntity::Daemon { daemon_id, .. }) => {
("daemon", Some(daemon_id.to_string()))
}
Some(AuthenticatedEntity::System) => ("system", None),
Some(AuthenticatedEntity::Anonymous) | None => ("anonymous", None),
};
let request = Request::from_parts(parts, body);
// Process request
let response = next.run(request).await;
// Capture response info
let duration = start.elapsed();
let status = response.status().as_u16();
// Log the request
tracing::debug!(
target: "request_log",
method = %method,
path = %path,
status = status,
duration_ms = duration.as_millis() as u64,
ip = %ip,
entity_type = entity_type,
entity_id = entity_id,
"request completed"
);
response
}
@@ -0,0 +1,5 @@
pub mod auth;
pub mod features;
pub mod logging;
pub mod permissions;
pub mod rate_limit;
@@ -0,0 +1,146 @@
use crate::server::auth::middleware::auth::AuthError;
use crate::server::auth::middleware::auth::AuthenticatedEntity;
use crate::server::auth::middleware::auth::AuthenticatedUser;
use crate::server::{
config::AppState, shared::types::api::ApiError, users::r#impl::permissions::UserOrgPermissions,
};
use axum::{extract::FromRequestParts, http::request::Parts};
use uuid::Uuid;
/// Extractor that accepts either a Member+ user OR a daemon
/// Returns the network IDs the authenticated entity has access to
pub struct MemberOrDaemon {
pub network_ids: Vec<Uuid>,
pub entity: AuthenticatedEntity,
}
impl<S> FromRequestParts<S> for MemberOrDaemon
where
S: Send + Sync + AsRef<AppState>,
{
type Rejection = AuthError;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
// Get the authenticated entity (works for both users and daemons)
let entity = AuthenticatedEntity::from_request_parts(parts, state).await?;
match entity {
AuthenticatedEntity::User { .. } => {
// For users, check they're at least Member level
let member = RequireMember::from_request_parts(parts, state).await?;
let user: AuthenticatedUser = member.into();
Ok(MemberOrDaemon {
network_ids: user.network_ids.clone(),
entity: user.into(),
})
}
AuthenticatedEntity::Daemon { network_id, .. } => {
// Daemons only have access to their single network
Ok(MemberOrDaemon {
network_ids: vec![network_id],
entity,
})
}
_ => Err(AuthError(ApiError::forbidden(
"Member or Daemon permission required",
))),
}
}
}
/// Extractor that requires the user to be at least an Owner
pub struct RequireOwner(pub AuthenticatedUser);
impl From<RequireOwner> for AuthenticatedUser {
fn from(value: RequireOwner) -> Self {
value.0
}
}
impl<S> FromRequestParts<S> for RequireOwner
where
S: Send + Sync + AsRef<AppState>,
{
type Rejection = AuthError;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let user = AuthenticatedUser::from_request_parts(parts, state).await?;
if user.permissions < UserOrgPermissions::Owner {
return Err(AuthError(ApiError::forbidden("Owner permission required")));
}
Ok(RequireOwner(user))
}
}
/// Extractor that requires the user to be at least an Admin
pub struct RequireAdmin(pub AuthenticatedUser);
impl From<RequireAdmin> for AuthenticatedUser {
fn from(value: RequireAdmin) -> Self {
value.0
}
}
impl From<RequireOwner> for RequireAdmin {
fn from(value: RequireOwner) -> Self {
RequireAdmin(value.0)
}
}
impl<S> FromRequestParts<S> for RequireAdmin
where
S: Send + Sync + AsRef<AppState>,
{
type Rejection = AuthError;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let user = AuthenticatedUser::from_request_parts(parts, state).await?;
if user.permissions < UserOrgPermissions::Admin {
return Err(AuthError(ApiError::forbidden("Admin permission required")));
}
Ok(RequireAdmin(user))
}
}
/// Extractor that requires the user to be at least a Member
pub struct RequireMember(pub AuthenticatedUser);
impl From<RequireMember> for AuthenticatedUser {
fn from(value: RequireMember) -> Self {
value.0
}
}
impl From<RequireOwner> for RequireMember {
fn from(value: RequireOwner) -> Self {
RequireMember(value.0)
}
}
impl From<RequireAdmin> for RequireMember {
fn from(value: RequireAdmin) -> Self {
RequireMember(value.0)
}
}
impl<S> FromRequestParts<S> for RequireMember
where
S: Send + Sync + AsRef<AppState>,
{
type Rejection = AuthError;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let user = AuthenticatedUser::from_request_parts(parts, state).await?;
if user.permissions < UserOrgPermissions::Member {
return Err(AuthError(ApiError::forbidden("Member permission required")));
}
Ok(RequireMember(user))
}
}
@@ -0,0 +1,197 @@
use crate::server::{
auth::middleware::auth::AuthenticatedEntity, config::AppState, shared::types::api::ApiError,
};
use axum::{
extract::{FromRequestParts, Request, State},
middleware::Next,
response::{IntoResponse, Response},
};
use axum_client_ip::ClientIp;
use governor::{
Quota, RateLimiter,
clock::{Clock, DefaultClock},
state::keyed::DashMapStateStore,
};
use std::{
net::IpAddr,
num::NonZeroU32,
sync::{Arc, OnceLock},
time::Duration,
};
use uuid::Uuid;
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub enum RateLimitKey {
User(Uuid),
Ip(IpAddr),
}
type KeyedRateLimiter =
Arc<RateLimiter<RateLimitKey, DashMapStateStore<RateLimitKey>, DefaultClock>>;
struct RateLimiters {
user: KeyedRateLimiter,
anonymous: KeyedRateLimiter,
}
static RATE_LIMITERS: OnceLock<RateLimiters> = OnceLock::new();
fn get_limiters() -> &'static RateLimiters {
RATE_LIMITERS.get_or_init(|| {
let limiters = RateLimiters {
// Users: 5000 requests per hour
user: Arc::new(RateLimiter::keyed(
Quota::per_minute(NonZeroU32::new(300).unwrap())
.allow_burst(NonZeroU32::new(150).unwrap()),
)),
// Anonymous: 20 requests per minute
anonymous: Arc::new(RateLimiter::keyed(
Quota::per_minute(NonZeroU32::new(20).unwrap())
.allow_burst(NonZeroU32::new(5).unwrap()),
)),
};
// Spawn cleanup task
let user_limiter = Arc::clone(&limiters.user);
let anonymous_limiter = Arc::clone(&limiters.anonymous);
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(60));
loop {
interval.tick().await;
user_limiter.retain_recent();
anonymous_limiter.retain_recent();
tracing::debug!(
"Rate limiter cleanup: user keys={}, anonymous keys={}",
user_limiter.len(),
anonymous_limiter.len()
);
}
});
limiters
})
}
#[derive(Debug, Clone)]
struct RateLimitInfo {
limit: u32,
remaining: u32,
reset_in_secs: u64,
}
impl RateLimitInfo {
fn apply_headers(&self, response: &mut Response) {
let headers = response.headers_mut();
if let Ok(v) = self.limit.to_string().parse() {
headers.insert("X-RateLimit-Limit", v);
}
if let Ok(v) = self.remaining.to_string().parse() {
headers.insert("X-RateLimit-Remaining", v);
}
if let Ok(v) = self.reset_in_secs.to_string().parse() {
headers.insert("X-RateLimit-Reset", v);
}
}
fn to_error_response(&self) -> Response {
let mut response = ApiError::too_many_requests(format!(
"Rate limit exceeded. Try again in {} seconds.",
self.reset_in_secs
))
.into_response();
self.apply_headers(&mut response);
if let Ok(v) = self.reset_in_secs.to_string().parse() {
response.headers_mut().insert("Retry-After", v);
}
response
}
}
fn check_user(user_id: Uuid) -> Result<RateLimitInfo, RateLimitInfo> {
let limiters = get_limiters();
let key = RateLimitKey::User(user_id);
match limiters.user.check_key(&key) {
Ok(_) => Ok(RateLimitInfo {
limit: 100,
remaining: 99,
reset_in_secs: 60,
}),
Err(not_until) => {
let wait_time = not_until
.wait_time_from(DefaultClock::default().now())
.as_secs();
Err(RateLimitInfo {
limit: 100,
remaining: 0,
reset_in_secs: wait_time,
})
}
}
}
fn check_anonymous(ip: IpAddr) -> Result<RateLimitInfo, RateLimitInfo> {
let limiters = get_limiters();
let key = RateLimitKey::Ip(ip);
match limiters.anonymous.check_key(&key) {
Ok(_) => Ok(RateLimitInfo {
limit: 20,
remaining: 19,
reset_in_secs: 60,
}),
Err(not_until) => {
let wait_time = not_until
.wait_time_from(DefaultClock::default().now())
.as_secs();
Err(RateLimitInfo {
limit: 20,
remaining: 0,
reset_in_secs: wait_time,
})
}
}
}
pub async fn rate_limit_middleware(
State(state): State<Arc<AppState>>,
ClientIp(ip): ClientIp,
request: Request,
next: Next,
) -> Result<Response, Response> {
let (mut parts, body) = request.into_parts();
let entity = AuthenticatedEntity::from_request_parts(&mut parts, &state)
.await
.ok();
// Daemons and System are exempt from rate limiting
if let Some(ref e) = entity
&& matches!(
e,
AuthenticatedEntity::Daemon { .. } | AuthenticatedEntity::System
)
{
let request = Request::from_parts(parts, body);
return Ok(next.run(request).await);
}
let check_result = match entity {
Some(AuthenticatedEntity::User { user_id, .. }) => check_user(user_id),
_ => check_anonymous(ip),
};
match check_result {
Ok(info) => {
let request = Request::from_parts(parts, body);
let mut response = next.run(request).await;
info.apply_headers(&mut response);
Ok(response)
}
Err(info) => Err(info.to_error_response()),
}
}
+1 -1
View File
@@ -11,7 +11,7 @@ use crate::server::{
base::{LoginRegisterParams, ProvisionUserParams},
oidc::{OidcPendingAuth, OidcProvider, OidcProviderConfig, OidcProviderMetadata},
},
middleware::AuthenticatedEntity,
middleware::auth::AuthenticatedEntity,
service::AuthService,
},
shared::{
+1 -1
View File
@@ -5,7 +5,7 @@ use crate::server::{
api::{LoginRequest, RegisterRequest},
base::{LoginRegisterParams, ProvisionUserParams},
},
middleware::{AuthenticatedEntity, AuthenticatedUser},
middleware::auth::{AuthenticatedEntity, AuthenticatedUser},
},
email::traits::EmailService,
organizations::{
+2 -1
View File
@@ -1,4 +1,5 @@
use crate::server::auth::middleware::{AuthenticatedUser, RequireOwner};
use crate::server::auth::middleware::auth::AuthenticatedUser;
use crate::server::auth::middleware::permissions::RequireOwner;
use crate::server::billing::types::api::CreateCheckoutRequest;
use crate::server::billing::types::base::BillingPlan;
use crate::server::config::AppState;
+1 -1
View File
@@ -1,4 +1,4 @@
use crate::server::auth::middleware::AuthenticatedEntity;
use crate::server::auth::middleware::auth::AuthenticatedEntity;
use crate::server::billing::types::base::BillingPlan;
use crate::server::billing::types::features::Feature;
use crate::server::networks::service::NetworkService;
+5 -5
View File
@@ -74,7 +74,7 @@ pub struct ServerCli {
public_url: Option<String>,
#[arg(long)]
pub plunk_api_key: Option<String>,
pub plunk_secret: Option<String>,
/// Configure what proxy (if any) is providing IP address for requests, ie in a reverse proxy setup, for accurate IP in auth event logging
#[arg(long)]
@@ -105,7 +105,7 @@ pub struct ServerConfig {
pub oidc_providers: Option<Vec<OidcProviderConfig>>,
// Used in SaaS deployment
pub plunk_api_key: Option<String>,
pub plunk_secret: Option<String>,
pub stripe_key: Option<String>,
pub stripe_secret: Option<String>,
pub stripe_webhook_secret: Option<String>,
@@ -142,7 +142,7 @@ impl Default for ServerConfig {
smtp_password: None,
smtp_email: None,
smtp_relay: None,
plunk_api_key: None,
plunk_secret: None,
client_ip_source: None,
oidc_providers: None,
}
@@ -196,8 +196,8 @@ impl ServerConfig {
if let Some(public_url) = cli_args.public_url {
figment = figment.merge(("public_url", public_url));
}
if let Some(plunk_api_key) = cli_args.plunk_api_key {
figment = figment.merge(("plunk_api_key", plunk_api_key));
if let Some(plunk_secret) = cli_args.plunk_secret {
figment = figment.merge(("plunk_secret", plunk_secret));
}
if let Some(client_ip_source) = cli_args.client_ip_source {
figment = figment.merge(("client_ip_source", client_ip_source));
+1 -1
View File
@@ -1,6 +1,6 @@
use crate::server::shared::events::types::TelemetryOperation;
use crate::server::{
auth::middleware::{AuthenticatedDaemon, AuthenticatedEntity},
auth::middleware::auth::{AuthenticatedDaemon, AuthenticatedEntity},
config::AppState,
daemons::r#impl::{
api::{
+1 -1
View File
@@ -1,7 +1,7 @@
use crate::{
daemon::runtime::types::InitializeDaemonRequest,
server::{
auth::middleware::AuthenticatedEntity,
auth::middleware::auth::AuthenticatedEntity,
daemons::r#impl::{
api::{DaemonDiscoveryRequest, DaemonDiscoveryResponse, DiscoveryUpdatePayload},
base::Daemon,
+4 -1
View File
@@ -1,5 +1,8 @@
use crate::server::{
auth::middleware::{AuthenticatedDaemon, AuthenticatedUser, RequireMember},
auth::middleware::{
auth::{AuthenticatedDaemon, AuthenticatedUser},
permissions::RequireMember,
},
config::AppState,
daemons::r#impl::api::DiscoveryUpdatePayload,
discovery::r#impl::{base::Discovery, types::RunType},
+1 -1
View File
@@ -1,4 +1,4 @@
use crate::server::auth::middleware::AuthenticatedEntity;
use crate::server::auth::middleware::auth::AuthenticatedEntity;
use crate::server::daemons::r#impl::base::DaemonMode;
use crate::server::discovery::r#impl::types::RunType;
use crate::server::shared::entities::ChangeTriggersTopologyStaleness;
+1 -1
View File
@@ -1,5 +1,5 @@
use crate::server::{
auth::middleware::AuthenticatedEntity,
auth::middleware::auth::AuthenticatedEntity,
email::traits::EmailService,
shared::events::{
bus::{EventFilter, EventSubscriber},
+1 -1
View File
@@ -4,7 +4,7 @@ use std::sync::Arc;
use uuid::Uuid;
use crate::server::{
auth::middleware::AuthenticatedEntity,
auth::middleware::auth::AuthenticatedEntity,
groups::r#impl::base::Group,
shared::{
entities::ChangeTriggersTopologyStaleness,
+1 -1
View File
@@ -1,4 +1,4 @@
use crate::server::auth::middleware::{MemberOrDaemon, RequireMember};
use crate::server::auth::middleware::permissions::{MemberOrDaemon, RequireMember};
use crate::server::shared::handlers::traits::{
CrudHandlers, bulk_delete_handler, get_all_handler, get_by_id_handler,
};
+1 -1
View File
@@ -1,5 +1,5 @@
use crate::server::{
auth::middleware::AuthenticatedEntity,
auth::middleware::auth::AuthenticatedEntity,
daemons::service::DaemonService,
hosts::r#impl::base::Host,
services::{r#impl::base::Service, service::ServiceService},
+1 -1
View File
@@ -4,7 +4,7 @@ use anyhow::Error;
use async_trait::async_trait;
use crate::server::{
auth::middleware::AuthenticatedEntity,
auth::middleware::auth::AuthenticatedEntity,
hosts::service::HostService,
shared::{
entities::EntityDiscriminants,
+1 -1
View File
@@ -2,7 +2,7 @@ use serial_test::serial;
use crate::{
server::{
auth::middleware::AuthenticatedEntity,
auth::middleware::auth::AuthenticatedEntity,
services::r#impl::bindings::Binding,
shared::{
services::traits::CrudService,
+2 -1
View File
@@ -1,5 +1,6 @@
use crate::server::auth::middleware::{
AuthenticatedUser, CreateNetworkFeature, RequireAdmin, RequireFeature,
auth::AuthenticatedUser, features::CreateNetworkFeature, features::RequireFeature,
permissions::RequireAdmin,
};
use crate::server::shared::handlers::traits::{
BulkDeleteResponse, CrudHandlers, bulk_delete_handler, delete_handler, get_by_id_handler,
+1 -1
View File
@@ -1,5 +1,5 @@
use crate::server::{
auth::middleware::AuthenticatedEntity,
auth::middleware::auth::AuthenticatedEntity,
hosts::service::HostService,
networks::r#impl::Network,
shared::{
+38 -9
View File
@@ -1,6 +1,7 @@
use crate::server::auth::middleware::AuthenticatedEntity;
use crate::server::auth::middleware::auth::AuthenticatedEntity;
use crate::server::auth::middleware::{
AuthenticatedUser, InviteUsersFeature, RequireFeature, RequireMember,
auth::AuthenticatedUser, features::InviteUsersFeature, features::RequireFeature,
permissions::RequireMember,
};
use crate::server::config::AppState;
use crate::server::organizations::r#impl::api::CreateInviteRequest;
@@ -8,6 +9,7 @@ use crate::server::organizations::r#impl::base::Organization;
use crate::server::organizations::r#impl::invites::Invite;
use crate::server::shared::handlers::traits::{CrudHandlers, update_handler};
use crate::server::shared::services::traits::CrudService;
use crate::server::shared::storage::filter::EntityFilter;
use crate::server::shared::types::api::ApiError;
use crate::server::shared::types::api::ApiResponse;
use crate::server::shared::types::api::ApiResult;
@@ -57,14 +59,41 @@ async fn create_invite(
RequireFeature { plan, .. }: RequireFeature<InviteUsersFeature>,
Json(request): Json<CreateInviteRequest>,
) -> ApiResult<Json<ApiResponse<Invite>>> {
if let Some(s) = plan.config().included_seats
&& s == 1
&& plan.features().share_views
&& request.permissions > UserOrgPermissions::Visualizer
// Seat limit check - only applies if permissions count towards seats
if request.permissions.counts_towards_seats()
&& let Some(max_seats) = plan.config().included_seats
&& plan.config().seat_cents.is_none()
{
return Err(ApiError::forbidden(
"You can only invite users with Visualizer permissions on your plan. Please upgrade to invite Members, Admins, and Owners.",
));
let org_filter = EntityFilter::unfiltered().organization_id(&user.organization_id);
let current_members = state
.services
.user_service
.get_all(org_filter)
.await
.unwrap_or_default()
.iter()
.filter(|u| u.base.permissions.counts_towards_seats())
.count();
let pending_invites = state
.services
.organization_service
.get_org_invites(&user.organization_id)
.await
.unwrap_or_default()
.iter()
.filter(|i| i.permissions.counts_towards_seats())
.count();
let total_seats_used = current_members + pending_invites;
if total_seats_used >= max_seats as usize {
return Err(ApiError::forbidden(&format!(
"Seat limit reached ({}/{}). Upgrade your plan for more seats, or delete any unused pending invites.",
total_seats_used, max_seats
)));
}
}
if user.permissions < request.permissions {
+1 -1
View File
@@ -1,4 +1,4 @@
use crate::server::auth::middleware::AuthenticatedEntity;
use crate::server::auth::middleware::auth::AuthenticatedEntity;
use crate::server::organizations::r#impl::invites::Invite;
use crate::server::shared::entities::ChangeTriggersTopologyStaleness;
use crate::server::shared::events::bus::EventBus;
@@ -2,7 +2,7 @@ use anyhow::Error;
use async_trait::async_trait;
use crate::server::{
auth::middleware::AuthenticatedEntity,
auth::middleware::auth::AuthenticatedEntity,
organizations::service::OrganizationService,
shared::{
events::{
+1 -1
View File
@@ -1,5 +1,5 @@
use crate::server::{
auth::middleware::AuthenticatedEntity,
auth::middleware::auth::AuthenticatedEntity,
groups::{
r#impl::{base::Group, types::GroupType},
service::GroupService,
+1 -1
View File
@@ -2,7 +2,7 @@ use serial_test::serial;
use crate::{
server::{
auth::middleware::AuthenticatedEntity,
auth::middleware::auth::AuthenticatedEntity,
groups::r#impl::types::GroupType,
services::r#impl::{bindings::Binding, patterns::MatchDetails},
shared::{
+1 -1
View File
@@ -1,4 +1,4 @@
use crate::server::{auth::middleware::AuthenticatedEntity, shared::entities::Entity};
use crate::server::{auth::middleware::auth::AuthenticatedEntity, shared::entities::Entity};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::{fmt::Display, net::IpAddr};
@@ -1,5 +1,6 @@
use crate::server::api_keys::r#impl::base::{ApiKey, ApiKeyBase};
use crate::server::auth::middleware::{AuthenticatedEntity, AuthenticatedUser, RequireOwner};
use crate::server::auth::middleware::auth::{AuthenticatedEntity, AuthenticatedUser};
use crate::server::auth::middleware::permissions::RequireOwner;
use crate::server::billing::types::base::BillingPlan;
use crate::server::billing::types::features::Feature;
use crate::server::config::PublicConfigResponse;
@@ -123,9 +124,9 @@ pub async fn get_public_config(
&& state.config.smtp_username.is_some()
&& state.config.smtp_email.is_some()
&& state.config.smtp_relay.is_some())
|| state.config.plunk_api_key.is_some(),
|| state.config.plunk_secret.is_some(),
public_url: state.config.public_url.clone(),
has_email_opt_in: state.config.plunk_api_key.is_some(),
has_email_opt_in: state.config.plunk_secret.is_some(),
}))
}
+2 -43
View File
@@ -1,5 +1,5 @@
use crate::server::{
auth::middleware::{AuthenticatedUser, RequireMember},
auth::middleware::{auth::AuthenticatedUser, permissions::RequireMember},
config::AppState,
shared::{
entities::{ChangeTriggersTopologyStaleness, Entity},
@@ -79,12 +79,6 @@ where
)));
}
tracing::debug!(
entity_type = T::table_name(),
user_id = %user.user_id,
"Create request received"
);
let service = T::get_service(&state);
let created = service
.create(request, user.clone().into())
@@ -110,13 +104,6 @@ where
T: CrudHandlers + 'static + ChangeTriggersTopologyStaleness<T>,
Entity: From<T>,
{
tracing::debug!(
entity_type = T::table_name(),
user_id = %user.user_id,
network_count = %user.network_ids.len(),
"Get all request received"
);
let network_filter = EntityFilter::unfiltered().network_ids(&user.network_ids);
let service = T::get_service(&state);
@@ -142,13 +129,6 @@ where
T: CrudHandlers + 'static + ChangeTriggersTopologyStaleness<T>,
Entity: From<T>,
{
tracing::debug!(
entity_type = T::table_name(),
entity_id = %id,
user_id = %user.user_id,
"Get by ID request received"
);
let service = T::get_service(&state);
let entity = service
.get_by_id(&id)
@@ -186,13 +166,6 @@ where
T: CrudHandlers + 'static + ChangeTriggersTopologyStaleness<T>,
Entity: From<T>,
{
tracing::debug!(
entity_type = T::table_name(),
entity_id = %id,
user_id = %user.user_id,
"Update request received"
);
let service = T::get_service(&state);
// Verify entity exists
@@ -248,7 +221,7 @@ where
let service = T::get_service(&state);
// Verify entity exists and log the deletion attempt
let entity = service
service
.get_by_id(&id)
.await
.map_err(|e| {
@@ -269,13 +242,6 @@ where
ApiError::not_found(format!("{} '{}' not found", T::entity_name(), id))
})?;
tracing::debug!(
entity_type = T::table_name(),
entity_id = %id,
entity_name = %entity,
"Delete request received"
);
service.delete(&id, user.into()).await.map_err(|e| {
tracing::error!(
entity_type = T::table_name(),
@@ -302,13 +268,6 @@ where
return Err(ApiError::bad_request("No IDs provided for bulk delete"));
}
tracing::debug!(
entity_type = T::table_name(),
user_id = %user.user_id,
count = ids.len(),
"Bulk delete request received"
);
let service = T::get_service(&state);
let deleted_count = service
.delete_many(&ids, user.clone().into())
@@ -108,8 +108,8 @@ impl ServiceFactory {
let email_service = config.clone().and_then(|c| {
// Prefer Plunk if API key is provided
if let Some(plunk_api_key) = c.plunk_api_key {
let provider = Box::new(PlunkEmailProvider::new(plunk_api_key));
if let Some(plunk_secret) = c.plunk_secret {
let provider = Box::new(PlunkEmailProvider::new(plunk_secret));
return Some(Arc::new(EmailService::new(provider, user_service.clone())));
}
+1 -1
View File
@@ -5,7 +5,7 @@ use std::{fmt::Display, sync::Arc};
use uuid::Uuid;
use crate::server::{
auth::middleware::AuthenticatedEntity,
auth::middleware::auth::AuthenticatedEntity,
shared::{
entities::{ChangeTriggersTopologyStaleness, Entity},
events::{
+4
View File
@@ -66,6 +66,10 @@ impl ApiError {
pub fn bad_gateway(message: String) -> Self {
Self::new(StatusCode::BAD_GATEWAY, message.to_string())
}
pub fn too_many_requests(message: String) -> Self {
Self::new(StatusCode::TOO_MANY_REQUESTS, message.to_string())
}
}
impl axum::response::IntoResponse for ApiError {
+2 -1
View File
@@ -1,4 +1,5 @@
use crate::server::auth::middleware::{AuthenticatedEntity, MemberOrDaemon};
use crate::server::auth::middleware::auth::AuthenticatedEntity;
use crate::server::auth::middleware::permissions::MemberOrDaemon;
use crate::server::shared::handlers::traits::{
CrudHandlers, bulk_delete_handler, delete_handler, get_by_id_handler, update_handler,
};
+1 -1
View File
@@ -1,5 +1,5 @@
use crate::server::{
auth::middleware::AuthenticatedEntity,
auth::middleware::auth::AuthenticatedEntity,
discovery::r#impl::types::DiscoveryType,
shared::{
entities::ChangeTriggersTopologyStaleness,
+1 -1
View File
@@ -1,5 +1,5 @@
use crate::server::{
auth::middleware::{AuthenticatedUser, RequireMember},
auth::middleware::{auth::AuthenticatedUser, permissions::RequireMember},
config::AppState,
shared::{
events::types::{TelemetryEvent, TelemetryOperation},
+1 -1
View File
@@ -8,7 +8,7 @@ use tokio::sync::broadcast;
use uuid::Uuid;
use crate::server::{
auth::middleware::AuthenticatedEntity,
auth::middleware::auth::AuthenticatedEntity,
groups::{r#impl::base::Group, service::GroupService},
hosts::{r#impl::base::Host, service::HostService},
services::{r#impl::base::Service, service::ServiceService},
@@ -5,7 +5,7 @@ use async_trait::async_trait;
use uuid::Uuid;
use crate::server::{
auth::middleware::AuthenticatedEntity,
auth::middleware::auth::AuthenticatedEntity,
shared::{
entities::{Entity, EntityDiscriminants},
events::{
+2 -1
View File
@@ -1,4 +1,5 @@
use crate::server::auth::middleware::{AuthenticatedUser, RequireAdmin, RequireMember};
use crate::server::auth::middleware::auth::AuthenticatedUser;
use crate::server::auth::middleware::permissions::{RequireAdmin, RequireMember};
use crate::server::shared::handlers::traits::{
CrudHandlers, bulk_delete_handler, delete_handler, get_by_id_handler,
};
+6 -2
View File
@@ -109,8 +109,12 @@ impl TypeMetadataProvider for UserOrgPermissions {
UserOrgPermissions::Member => {
"Create and modify hosts, services, run discovery scans, and invite Visualizers to networks they have access to"
}
UserOrgPermissions::Visualizer => "Read-only access: view network topology",
UserOrgPermissions::None => "No permissions assigned",
UserOrgPermissions::Visualizer => {
"Read-only access: view network topology. Does not count towards seat usage."
}
UserOrgPermissions::None => {
"No permissions assigned. Does not count towards seat usage."
}
}
}
+1 -1
View File
@@ -1,5 +1,5 @@
use crate::server::{
auth::middleware::AuthenticatedEntity,
auth::middleware::auth::AuthenticatedEntity,
shared::{
entities::ChangeTriggersTopologyStaleness,
events::{
+13
View File
@@ -11,6 +11,7 @@
"@tailwindcss/forms": "^0.5.10",
"@types/jquery": "^3.5.33",
"@xyflow/svelte": "^1.2.4",
"company-email-validator": "^1.2.0",
"deepmerge": "^4.3.1",
"elkjs": "^0.10.0",
"email-validator": "^2.0.4",
@@ -2001,6 +2002,18 @@
"node": ">= 6"
}
},
"node_modules/company-email-validator": {
"version": "1.2.0",
"resolved": "https://registry.npmjs.org/company-email-validator/-/company-email-validator-1.2.0.tgz",
"integrity": "sha512-fToFL0/RaOBKAuqjpHy6ZG76k4OSJCcYCKqFmEQOZcb2pJdWC3lkcJve0VPHapFEODlx3aimuoaCyHfJThZIbg==",
"license": "Unlicense",
"dependencies": {
"email-validator": "^2.0.4"
},
"engines": {
"node": ">4.0"
}
},
"node_modules/concat-map": {
"version": "0.0.1",
"resolved": "https://registry.npmjs.org/concat-map/-/concat-map-0.0.1.tgz",
+1
View File
@@ -40,6 +40,7 @@
"@tailwindcss/forms": "^0.5.10",
"@types/jquery": "^3.5.33",
"@xyflow/svelte": "^1.2.4",
"company-email-validator": "^1.2.0",
"deepmerge": "^4.3.1",
"elkjs": "^0.10.0",
"email-validator": "^2.0.4",
@@ -8,7 +8,7 @@
import ToggleGroup from './ToggleGroup.svelte';
import { SvelteMap } from 'svelte/reactivity';
import { currentUser } from '../auth/store';
import { isFreeEmail } from 'free-email-domains-list';
import { isCompanyEmail } from 'company-email-validator';
$effect(() => {
void $currentPlans;
@@ -20,8 +20,8 @@
// Plan filter state
type PlanFilter = 'all' | 'personal' | 'commercial';
let freeEmail = $currentUser ? isFreeEmail($currentUser.email) : false;
let planFilter = $state<PlanFilter>(freeEmail ? 'personal' : 'commercial');
let companyEmail = $currentUser ? isCompanyEmail($currentUser.email) : false;
let planFilter = $state<PlanFilter>(companyEmail ? 'commercial' : 'personal');
// Billing period filter state
type BillingPeriod = 'monthly' | 'yearly';
+9 -5
View File
@@ -7,6 +7,7 @@ import { getSubnets } from '../subnets/store';
import { getServices } from '../services/store';
import { BaseSSEManager, type SSEConfig } from '$lib/shared/utils/sse';
import { getDaemons } from '../daemons/store';
import { getDiscoveries } from './store';
// session_id to latest update
export const sessions = writable<DiscoveryUpdatePayload[]>([]);
@@ -30,7 +31,7 @@ class DiscoverySSEManager extends BaseSSEManager<DiscoveryUpdatePayload> {
protected createConfig(): SSEConfig<DiscoveryUpdatePayload> {
return {
url: '/api/discovery/stream',
onMessage: (update) => {
onMessage: async (update) => {
// Check if discovered_count increased
const lastCount = lastProcessedCount.get(update.session_id) || 0;
const currentCount = update.processed || 0;
@@ -48,10 +49,13 @@ class DiscoverySSEManager extends BaseSSEManager<DiscoveryUpdatePayload> {
if (update.phase === 'Complete') {
pushSuccess(`${update.discovery_type.type} discovery completed`);
// Final refresh on completion
getHosts();
getServices();
getSubnets();
getDaemons();
await Promise.all([
getHosts(),
getServices(),
getSubnets(),
getDaemons(),
getDiscoveries()
]);
} else if (update.phase === 'Cancelled') {
pushWarning(`Discovery cancelled`);
} else if (update.phase === 'Failed' && update.error) {
+2 -2
View File
@@ -30,10 +30,10 @@ export class SSEClient<T> {
this.config.onOpen?.();
};
this.eventSource.onmessage = (event) => {
this.eventSource.onmessage = async (event) => {
try {
const data = JSON.parse(event.data) as T;
this.config.onMessage(data);
await this.config.onMessage(data);
} catch (error) {
console.error('Failed to parse SSE message:', error);
}
+2 -2
View File
@@ -23,8 +23,8 @@
</div>
<!-- Content (sits above background) -->
<div class="flex justify-center">
<div class="relative z-10 mt-6">
<div class="flex min-h-screen items-center justify-center">
<div class="relative z-10">
<BillingPlanForm />
</div>
</div>