From acdc888653caede6862d3e25f3a0636ca9560a0e Mon Sep 17 00:00:00 2001 From: Sebastian Jeltsch Date: Mon, 13 Jan 2025 13:02:32 +0100 Subject: [PATCH] Add more "realtime" tests especially for ACL checking --- trailbase-core/src/records/subscribe.rs | 407 ++++++++++++++++++++---- 1 file changed, 343 insertions(+), 64 deletions(-) diff --git a/trailbase-core/src/records/subscribe.rs b/trailbase-core/src/records/subscribe.rs index 9c1e507b..8f57b8cc 100644 --- a/trailbase-core/src/records/subscribe.rs +++ b/trailbase-core/src/records/subscribe.rs @@ -222,6 +222,7 @@ impl SubscriptionManager { }; } + #[cfg(test)] pub fn num_record_subscriptions(&self) -> usize { let mut count: usize = 0; for table in self.state.record_subscriptions.read().values() { @@ -232,10 +233,20 @@ impl SubscriptionManager { return count; } + #[cfg(test)] + pub fn num_table_subscriptions(&self) -> usize { + let mut count: usize = 0; + for table in self.state.table_subscriptions.read().values() { + count += table.len(); + } + return count; + } + fn broker_subscriptions( s: &ManagerState, conn: &rusqlite::Connection, subs: &[Subscription], + record_subscriptions: bool, record: &[(&str, rusqlite::types::ValueRef<'_>)], event: &Event, ) -> Vec { @@ -250,13 +261,15 @@ impl SubscriptionManager { if let Err(_err) = api.check_record_level_read_access(conn, Permission::Read, record, sub.user.as_ref()) { - // This can happen if the record api configuration has changed since originally - // subscribed. In this case we just send and error and cancel the subscription. - if let Ok(ev) = Event::default().json_data(DbEvent::Error("Access denied".into())) { - let _ = sub.sender.try_send(ev); + if record_subscriptions { + // This can happen if the record api configuration has changed since originally + // subscribed. In this case we just send and error and cancel the subscription. + if let Ok(ev) = Event::default().json_data(DbEvent::Error("Access denied".into())) { + let _ = sub.sender.try_send(ev); + } + dead_subscriptions.push(idx); + sub.sender.close(); } - dead_subscriptions.push(idx); - sub.sender.close(); continue; } @@ -346,7 +359,7 @@ impl SubscriptionManager { break 'record_subs; }; - let dead_subscriptions = Self::broker_subscriptions(s, conn, subs, &record, &event); + let dead_subscriptions = Self::broker_subscriptions(s, conn, subs, true, &record, &event); if dead_subscriptions.is_empty() && action != RecordAction::Delete { // No cleanup needed. break 'record_subs; @@ -396,7 +409,7 @@ impl SubscriptionManager { break 'table_subs; }; - let dead_subscriptions = Self::broker_subscriptions(s, conn, subs, &record, &event); + let dead_subscriptions = Self::broker_subscriptions(s, conn, subs, false, &record, &event); if dead_subscriptions.is_empty() && action != RecordAction::Delete { // No cleanup needed. break 'table_subs; @@ -646,12 +659,18 @@ async fn decode_sse_json_event(event: Event) -> serde_json::Value { #[cfg(test)] mod tests { - use super::DbEvent; - use super::*; + use async_channel::TryRecvError; + use futures::StreamExt; use trailbase_sqlite::params; + use super::DbEvent; + use super::*; + + use crate::admin::user::*; use crate::app_state::test_state; + use crate::auth::api::login::login_with_password; use crate::records::{add_record_api, AccessRules, Acls, PermissionFlag}; + use crate::util::uuid_to_b64; async fn decode_db_event(event: Event) -> DbEvent { let json = decode_sse_json_event(event).await; @@ -737,6 +756,15 @@ mod tests { assert_eq!(1, manager.num_record_subscriptions()); + // This should do nothing since nobody is subscribed to id = 5. + let _ = conn + .query_row( + "INSERT INTO test (id, text) VALUES ($1, 'baz')", + [trailbase_sqlite::Value::Integer(5)], + ) + .await + .unwrap(); + conn .execute( "UPDATE test SET text = $1 WHERE _rowid_ = $2", @@ -782,68 +810,78 @@ mod tests { let manager = state.subscription_manager(); let api = state.lookup_record_api("api_name").unwrap(); - let cleanup = manager - .add_table_subscription(state.clone(), api, None) - .await - .unwrap(); - let receiver = &cleanup.stream; - let record_id_raw = 0; - conn - .query_row( - "INSERT INTO test (id, text) VALUES ($1, 'foo')", - params!(record_id_raw), - ) - .await - .unwrap(); + { + let cleanup = manager + .add_table_subscription(state.clone(), api, None) + .await + .unwrap(); + let receiver = &cleanup.stream; - conn - .execute( - "UPDATE test SET text = $1 WHERE id = $2", - params!("bar", record_id_raw), - ) - .await - .unwrap(); + assert_eq!(1, manager.num_table_subscriptions()); - let expected = serde_json::json!({ - "id": record_id_raw, - "text": "foo", - }); - match decode_db_event(receiver.recv().await.unwrap()).await { - DbEvent::Insert(Some(value)) => { - assert_eq!(value, expected); - } - x => { - assert!(false, "Expected update, got: {x:?}"); - } - }; + let record_id_raw = 0; + conn + .query_row( + "INSERT INTO test (id, text) VALUES ($1, 'foo')", + params!(record_id_raw), + ) + .await + .unwrap(); - let expected = serde_json::json!({ - "id": record_id_raw, - "text": "bar", - }); - match decode_db_event(receiver.recv().await.unwrap()).await { - DbEvent::Update(Some(value)) => { - assert_eq!(value, expected); - } - x => { - assert!(false, "Expected update, got: {x:?}"); - } - }; + conn + .execute( + "UPDATE test SET text = $1 WHERE id = $2", + params!("bar", record_id_raw), + ) + .await + .unwrap(); - conn - .execute("DELETE FROM test WHERE id = $1", params!(record_id_raw)) - .await - .unwrap(); + let expected = serde_json::json!({ + "id": record_id_raw, + "text": "foo", + }); + match decode_db_event(receiver.recv().await.unwrap()).await { + DbEvent::Insert(Some(value)) => { + assert_eq!(value, expected); + } + x => { + assert!(false, "Expected update, got: {x:?}"); + } + }; - match decode_db_event(receiver.recv().await.unwrap()).await { - DbEvent::Delete(Some(value)) => { - assert_eq!(value, expected); - } - x => { - assert!(false, "Expected update, got: {x:?}"); + let expected = serde_json::json!({ + "id": record_id_raw, + "text": "bar", + }); + match decode_db_event(receiver.recv().await.unwrap()).await { + DbEvent::Update(Some(value)) => { + assert_eq!(value, expected); + } + x => { + assert!(false, "Expected update, got: {x:?}"); + } + }; + + conn + .execute("DELETE FROM test WHERE id = $1", params!(record_id_raw)) + .await + .unwrap(); + + match decode_db_event(receiver.recv().await.unwrap()).await { + DbEvent::Delete(Some(value)) => { + assert_eq!(value, expected); + } + x => { + assert!(false, "Expected update, got: {x:?}"); + } } } + + // Implicitly await for scheduled cleanups to go through. + conn.query("SELECT 1", ()).await.unwrap(); + + assert_eq!(0, manager.num_table_subscriptions()); } #[tokio::test] @@ -883,6 +921,247 @@ mod tests { assert_eq!(0, manager.num_record_subscriptions()); } + + #[tokio::test] + async fn subscription_acl_test() { + let state = test_state(None).await.unwrap(); + let conn = state.conn().clone(); + + conn + .execute( + "CREATE TABLE test ( + id INTEGER PRIMARY KEY, + user BLOB NOT NULL, + text TEXT + ) STRICT", + (), + ) + .await + .unwrap(); + + state.table_metadata().invalidate_all().await.unwrap(); + + // Register message table as record api with moderator read access. + add_record_api( + &state, + "api_name", + "test", + Acls { + authenticated: vec![PermissionFlag::Read], + ..Default::default() + }, + AccessRules { + read: Some("EXISTS(SELECT 1 FROM test AS m WHERE _USER_.id = _ROW_.user)".to_string()), + ..Default::default() + }, + ) + .await + .unwrap(); + + let user_x_email = "user_x@bar.com"; + let password = "Secret!1!!"; + + let sse_or = add_subscription_sse_handler( + State(state.clone()), + Path(("api_name".to_string(), "*".to_string())), + None, + ) + .await; + + assert!(matches!(sse_or, Err(RecordError::Forbidden))); + + let user_x = create_user_for_test(&state, user_x_email, password) + .await + .unwrap() + .into_bytes(); + let user_x_token = login_with_password(&state, user_x_email, password) + .await + .unwrap(); + + // Check that we can subscribe to table wide changes. + { + let _ = add_subscription_sse_handler( + State(state.clone()), + Path(("api_name".to_string(), "*".to_string())), + User::from_auth_token(&state, &user_x_token.auth_token), + ) + .await + .unwrap(); + } + + let record_id_raw = 0; + let record_id = trailbase_sqlite::Value::Integer(record_id_raw); + let _rowid: i64 = conn + .query_row( + "INSERT INTO test (id, user, text) VALUES ($1, $2, 'foo') RETURNING _rowid_", + [ + record_id.clone(), + trailbase_sqlite::Value::Blob(user_x.to_vec()), + ], + ) + .await + .unwrap() + .unwrap() + .get(0) + .unwrap(); + + // Assert user_x can subscribe to their record. + { + let _ = add_subscription_sse_handler( + State(state.clone()), + Path(("api_name".to_string(), record_id_raw.to_string())), + User::from_auth_token(&state, &user_x_token.auth_token), + ) + .await + .unwrap(); + } + + // Assert user_y cannot subscribe to user_x's record. + { + let user_y_email = "user_y@bar.com"; + let _user_y = create_user_for_test(&state, user_y_email, password) + .await + .unwrap() + .into_bytes(); + let user_y_token = login_with_password(&state, user_y_email, password) + .await + .unwrap(); + + let sse_or = add_subscription_sse_handler( + State(state.clone()), + Path(("api_name".to_string(), record_id_raw.to_string())), + User::from_auth_token(&state, &user_y_token.auth_token), + ) + .await; + + assert!(matches!(sse_or, Err(RecordError::Forbidden))); + } + } + + #[tokio::test] + async fn test_acl_selective_table_subs() { + let state = test_state(None).await.unwrap(); + let conn = state.conn().clone(); + + conn + .execute( + "CREATE TABLE test ( + id INTEGER PRIMARY KEY, + user BLOB NOT NULL, + text TEXT + ) STRICT", + (), + ) + .await + .unwrap(); + + state.table_metadata().invalidate_all().await.unwrap(); + + // Register message table as record api with moderator read access. + add_record_api( + &state, + "api_name", + "test", + Acls { + authenticated: vec![PermissionFlag::Read], + ..Default::default() + }, + AccessRules { + read: Some("EXISTS(SELECT 1 FROM test AS m WHERE _USER_.id = _ROW_.user)".to_string()), + ..Default::default() + }, + ) + .await + .unwrap(); + + let manager = state.subscription_manager(); + let api = state.lookup_record_api("api_name").unwrap(); + + let password = "Secret!1!!"; + let user_x_email = "user_x@bar.com"; + let user_x = create_user_for_test(&state, user_x_email, password) + .await + .unwrap(); + let user_x_token = login_with_password(&state, user_x_email, password) + .await + .unwrap(); + + let user_y_email = "user_y@bar.com"; + let _user_y = create_user_for_test(&state, user_y_email, password) + .await + .unwrap() + .into_bytes(); + let user_y_token = login_with_password(&state, user_y_email, password) + .await + .unwrap(); + + // Assert events for table subscriptions are selective on ACLs. + { + let user_x_subscription = manager + .add_table_subscription( + state.clone(), + api.clone(), + User::from_auth_token(&state, &user_x_token.auth_token), + ) + .await + .unwrap(); + + let user_y_subscription = manager + .add_table_subscription( + state.clone(), + api.clone(), + User::from_auth_token(&state, &user_y_token.auth_token), + ) + .await + .unwrap(); + + assert_eq!(2, manager.num_table_subscriptions()); + + let record_id_raw = 1; + conn + .query_row( + "INSERT INTO test (id, user, text) VALUES ($1, $2, 'foo')", + [ + trailbase_sqlite::Value::Integer(record_id_raw), + trailbase_sqlite::Value::Blob(user_x.into()), + ], + ) + .await + .unwrap(); + + let expected = serde_json::json!({ + "id": record_id_raw, + "user": uuid_to_b64(&user_x), + "text": "foo", + }); + + match decode_db_event(user_x_subscription.stream.recv().await.unwrap()).await { + DbEvent::Insert(Some(value)) => { + assert_eq!(value, expected); + } + x => { + assert!(false, "Expected update, got: {x:?}"); + } + }; + + // User y should *not* have received the insert event. + assert!(tokio::time::timeout( + tokio::time::Duration::from_millis(300), + user_y_subscription.stream.clone().count() + ) + .await + .is_err()); + assert_eq!( + user_y_subscription.stream.try_recv().err().unwrap(), + TryRecvError::Empty + ); + } + + // Implicitly await for scheduled cleanups to go through. + conn.query("SELECT 1", ()).await.unwrap(); + + assert_eq!(0, manager.num_table_subscriptions()); + } } const NO_HOOK: Option = None;