Add more "realtime" tests especially for ACL checking

This commit is contained in:
Sebastian Jeltsch
2025-01-13 13:02:32 +01:00
parent dfedb76342
commit acdc888653

View File

@@ -222,6 +222,7 @@ impl SubscriptionManager {
};
}
#[cfg(test)]
pub fn num_record_subscriptions(&self) -> usize {
let mut count: usize = 0;
for table in self.state.record_subscriptions.read().values() {
@@ -232,10 +233,20 @@ impl SubscriptionManager {
return count;
}
#[cfg(test)]
pub fn num_table_subscriptions(&self) -> usize {
let mut count: usize = 0;
for table in self.state.table_subscriptions.read().values() {
count += table.len();
}
return count;
}
fn broker_subscriptions(
s: &ManagerState,
conn: &rusqlite::Connection,
subs: &[Subscription],
record_subscriptions: bool,
record: &[(&str, rusqlite::types::ValueRef<'_>)],
event: &Event,
) -> Vec<usize> {
@@ -250,13 +261,15 @@ impl SubscriptionManager {
if let Err(_err) =
api.check_record_level_read_access(conn, Permission::Read, record, sub.user.as_ref())
{
// This can happen if the record api configuration has changed since originally
// subscribed. In this case we just send and error and cancel the subscription.
if let Ok(ev) = Event::default().json_data(DbEvent::Error("Access denied".into())) {
let _ = sub.sender.try_send(ev);
if record_subscriptions {
// This can happen if the record api configuration has changed since originally
// subscribed. In this case we just send and error and cancel the subscription.
if let Ok(ev) = Event::default().json_data(DbEvent::Error("Access denied".into())) {
let _ = sub.sender.try_send(ev);
}
dead_subscriptions.push(idx);
sub.sender.close();
}
dead_subscriptions.push(idx);
sub.sender.close();
continue;
}
@@ -346,7 +359,7 @@ impl SubscriptionManager {
break 'record_subs;
};
let dead_subscriptions = Self::broker_subscriptions(s, conn, subs, &record, &event);
let dead_subscriptions = Self::broker_subscriptions(s, conn, subs, true, &record, &event);
if dead_subscriptions.is_empty() && action != RecordAction::Delete {
// No cleanup needed.
break 'record_subs;
@@ -396,7 +409,7 @@ impl SubscriptionManager {
break 'table_subs;
};
let dead_subscriptions = Self::broker_subscriptions(s, conn, subs, &record, &event);
let dead_subscriptions = Self::broker_subscriptions(s, conn, subs, false, &record, &event);
if dead_subscriptions.is_empty() && action != RecordAction::Delete {
// No cleanup needed.
break 'table_subs;
@@ -646,12 +659,18 @@ async fn decode_sse_json_event(event: Event) -> serde_json::Value {
#[cfg(test)]
mod tests {
use super::DbEvent;
use super::*;
use async_channel::TryRecvError;
use futures::StreamExt;
use trailbase_sqlite::params;
use super::DbEvent;
use super::*;
use crate::admin::user::*;
use crate::app_state::test_state;
use crate::auth::api::login::login_with_password;
use crate::records::{add_record_api, AccessRules, Acls, PermissionFlag};
use crate::util::uuid_to_b64;
async fn decode_db_event(event: Event) -> DbEvent {
let json = decode_sse_json_event(event).await;
@@ -737,6 +756,15 @@ mod tests {
assert_eq!(1, manager.num_record_subscriptions());
// This should do nothing since nobody is subscribed to id = 5.
let _ = conn
.query_row(
"INSERT INTO test (id, text) VALUES ($1, 'baz')",
[trailbase_sqlite::Value::Integer(5)],
)
.await
.unwrap();
conn
.execute(
"UPDATE test SET text = $1 WHERE _rowid_ = $2",
@@ -782,68 +810,78 @@ mod tests {
let manager = state.subscription_manager();
let api = state.lookup_record_api("api_name").unwrap();
let cleanup = manager
.add_table_subscription(state.clone(), api, None)
.await
.unwrap();
let receiver = &cleanup.stream;
let record_id_raw = 0;
conn
.query_row(
"INSERT INTO test (id, text) VALUES ($1, 'foo')",
params!(record_id_raw),
)
.await
.unwrap();
{
let cleanup = manager
.add_table_subscription(state.clone(), api, None)
.await
.unwrap();
let receiver = &cleanup.stream;
conn
.execute(
"UPDATE test SET text = $1 WHERE id = $2",
params!("bar", record_id_raw),
)
.await
.unwrap();
assert_eq!(1, manager.num_table_subscriptions());
let expected = serde_json::json!({
"id": record_id_raw,
"text": "foo",
});
match decode_db_event(receiver.recv().await.unwrap()).await {
DbEvent::Insert(Some(value)) => {
assert_eq!(value, expected);
}
x => {
assert!(false, "Expected update, got: {x:?}");
}
};
let record_id_raw = 0;
conn
.query_row(
"INSERT INTO test (id, text) VALUES ($1, 'foo')",
params!(record_id_raw),
)
.await
.unwrap();
let expected = serde_json::json!({
"id": record_id_raw,
"text": "bar",
});
match decode_db_event(receiver.recv().await.unwrap()).await {
DbEvent::Update(Some(value)) => {
assert_eq!(value, expected);
}
x => {
assert!(false, "Expected update, got: {x:?}");
}
};
conn
.execute(
"UPDATE test SET text = $1 WHERE id = $2",
params!("bar", record_id_raw),
)
.await
.unwrap();
conn
.execute("DELETE FROM test WHERE id = $1", params!(record_id_raw))
.await
.unwrap();
let expected = serde_json::json!({
"id": record_id_raw,
"text": "foo",
});
match decode_db_event(receiver.recv().await.unwrap()).await {
DbEvent::Insert(Some(value)) => {
assert_eq!(value, expected);
}
x => {
assert!(false, "Expected update, got: {x:?}");
}
};
match decode_db_event(receiver.recv().await.unwrap()).await {
DbEvent::Delete(Some(value)) => {
assert_eq!(value, expected);
}
x => {
assert!(false, "Expected update, got: {x:?}");
let expected = serde_json::json!({
"id": record_id_raw,
"text": "bar",
});
match decode_db_event(receiver.recv().await.unwrap()).await {
DbEvent::Update(Some(value)) => {
assert_eq!(value, expected);
}
x => {
assert!(false, "Expected update, got: {x:?}");
}
};
conn
.execute("DELETE FROM test WHERE id = $1", params!(record_id_raw))
.await
.unwrap();
match decode_db_event(receiver.recv().await.unwrap()).await {
DbEvent::Delete(Some(value)) => {
assert_eq!(value, expected);
}
x => {
assert!(false, "Expected update, got: {x:?}");
}
}
}
// Implicitly await for scheduled cleanups to go through.
conn.query("SELECT 1", ()).await.unwrap();
assert_eq!(0, manager.num_table_subscriptions());
}
#[tokio::test]
@@ -883,6 +921,247 @@ mod tests {
assert_eq!(0, manager.num_record_subscriptions());
}
#[tokio::test]
async fn subscription_acl_test() {
let state = test_state(None).await.unwrap();
let conn = state.conn().clone();
conn
.execute(
"CREATE TABLE test (
id INTEGER PRIMARY KEY,
user BLOB NOT NULL,
text TEXT
) STRICT",
(),
)
.await
.unwrap();
state.table_metadata().invalidate_all().await.unwrap();
// Register message table as record api with moderator read access.
add_record_api(
&state,
"api_name",
"test",
Acls {
authenticated: vec![PermissionFlag::Read],
..Default::default()
},
AccessRules {
read: Some("EXISTS(SELECT 1 FROM test AS m WHERE _USER_.id = _ROW_.user)".to_string()),
..Default::default()
},
)
.await
.unwrap();
let user_x_email = "user_x@bar.com";
let password = "Secret!1!!";
let sse_or = add_subscription_sse_handler(
State(state.clone()),
Path(("api_name".to_string(), "*".to_string())),
None,
)
.await;
assert!(matches!(sse_or, Err(RecordError::Forbidden)));
let user_x = create_user_for_test(&state, user_x_email, password)
.await
.unwrap()
.into_bytes();
let user_x_token = login_with_password(&state, user_x_email, password)
.await
.unwrap();
// Check that we can subscribe to table wide changes.
{
let _ = add_subscription_sse_handler(
State(state.clone()),
Path(("api_name".to_string(), "*".to_string())),
User::from_auth_token(&state, &user_x_token.auth_token),
)
.await
.unwrap();
}
let record_id_raw = 0;
let record_id = trailbase_sqlite::Value::Integer(record_id_raw);
let _rowid: i64 = conn
.query_row(
"INSERT INTO test (id, user, text) VALUES ($1, $2, 'foo') RETURNING _rowid_",
[
record_id.clone(),
trailbase_sqlite::Value::Blob(user_x.to_vec()),
],
)
.await
.unwrap()
.unwrap()
.get(0)
.unwrap();
// Assert user_x can subscribe to their record.
{
let _ = add_subscription_sse_handler(
State(state.clone()),
Path(("api_name".to_string(), record_id_raw.to_string())),
User::from_auth_token(&state, &user_x_token.auth_token),
)
.await
.unwrap();
}
// Assert user_y cannot subscribe to user_x's record.
{
let user_y_email = "user_y@bar.com";
let _user_y = create_user_for_test(&state, user_y_email, password)
.await
.unwrap()
.into_bytes();
let user_y_token = login_with_password(&state, user_y_email, password)
.await
.unwrap();
let sse_or = add_subscription_sse_handler(
State(state.clone()),
Path(("api_name".to_string(), record_id_raw.to_string())),
User::from_auth_token(&state, &user_y_token.auth_token),
)
.await;
assert!(matches!(sse_or, Err(RecordError::Forbidden)));
}
}
#[tokio::test]
async fn test_acl_selective_table_subs() {
let state = test_state(None).await.unwrap();
let conn = state.conn().clone();
conn
.execute(
"CREATE TABLE test (
id INTEGER PRIMARY KEY,
user BLOB NOT NULL,
text TEXT
) STRICT",
(),
)
.await
.unwrap();
state.table_metadata().invalidate_all().await.unwrap();
// Register message table as record api with moderator read access.
add_record_api(
&state,
"api_name",
"test",
Acls {
authenticated: vec![PermissionFlag::Read],
..Default::default()
},
AccessRules {
read: Some("EXISTS(SELECT 1 FROM test AS m WHERE _USER_.id = _ROW_.user)".to_string()),
..Default::default()
},
)
.await
.unwrap();
let manager = state.subscription_manager();
let api = state.lookup_record_api("api_name").unwrap();
let password = "Secret!1!!";
let user_x_email = "user_x@bar.com";
let user_x = create_user_for_test(&state, user_x_email, password)
.await
.unwrap();
let user_x_token = login_with_password(&state, user_x_email, password)
.await
.unwrap();
let user_y_email = "user_y@bar.com";
let _user_y = create_user_for_test(&state, user_y_email, password)
.await
.unwrap()
.into_bytes();
let user_y_token = login_with_password(&state, user_y_email, password)
.await
.unwrap();
// Assert events for table subscriptions are selective on ACLs.
{
let user_x_subscription = manager
.add_table_subscription(
state.clone(),
api.clone(),
User::from_auth_token(&state, &user_x_token.auth_token),
)
.await
.unwrap();
let user_y_subscription = manager
.add_table_subscription(
state.clone(),
api.clone(),
User::from_auth_token(&state, &user_y_token.auth_token),
)
.await
.unwrap();
assert_eq!(2, manager.num_table_subscriptions());
let record_id_raw = 1;
conn
.query_row(
"INSERT INTO test (id, user, text) VALUES ($1, $2, 'foo')",
[
trailbase_sqlite::Value::Integer(record_id_raw),
trailbase_sqlite::Value::Blob(user_x.into()),
],
)
.await
.unwrap();
let expected = serde_json::json!({
"id": record_id_raw,
"user": uuid_to_b64(&user_x),
"text": "foo",
});
match decode_db_event(user_x_subscription.stream.recv().await.unwrap()).await {
DbEvent::Insert(Some(value)) => {
assert_eq!(value, expected);
}
x => {
assert!(false, "Expected update, got: {x:?}");
}
};
// User y should *not* have received the insert event.
assert!(tokio::time::timeout(
tokio::time::Duration::from_millis(300),
user_y_subscription.stream.clone().count()
)
.await
.is_err());
assert_eq!(
user_y_subscription.stream.try_recv().err().unwrap(),
TryRecvError::Empty
);
}
// Implicitly await for scheduled cleanups to go through.
conn.query("SELECT 1", ()).await.unwrap();
assert_eq!(0, manager.num_table_subscriptions());
}
}
const NO_HOOK: Option<fn(Action, &str, &str, &PreUpdateCase)> = None;