Bring up multiple isolate workers and grab requests from a shared pool.

This commit is contained in:
Sebastian Jeltsch
2024-11-13 16:31:15 +01:00
parent 43f5820bee
commit 18dd596478
3 changed files with 271 additions and 153 deletions
+6 -2
View File
@@ -73,6 +73,8 @@ impl AppState {
let conn_clone0 = args.conn.clone();
let conn_clone1 = args.conn.clone();
RuntimeHandle::set_connection(args.conn.clone());
AppState {
state: Arc::new(InternalState {
data_dir: args.data_dir,
@@ -124,7 +126,7 @@ impl AppState {
jwt: args.jwt,
table_metadata: args.table_metadata,
runtime: RuntimeHandle::new(args.conn.clone()),
runtime: RuntimeHandle::new(),
#[cfg(test)]
cleanup: vec![],
@@ -370,6 +372,8 @@ pub async fn test_state(options: Option<TestStateOptions>) -> anyhow::Result<App
let main_conn_clone1 = main_conn.clone();
let table_metadata_clone = table_metadata.clone();
RuntimeHandle::set_connection(main_conn.clone());
return Ok(AppState {
state: Arc::new(InternalState {
data_dir: DataDir(temp_dir.path().to_path_buf()),
@@ -411,7 +415,7 @@ pub async fn test_state(options: Option<TestStateOptions>) -> anyhow::Result<App
logs_conn,
jwt: jwt::test_jwt_helper(),
table_metadata,
runtime: RuntimeHandle::new(main_conn.clone()),
runtime: RuntimeHandle::new(),
cleanup: vec![Box::new(temp_dir)],
}),
});
+264 -150
View File
@@ -6,7 +6,7 @@ use axum::Router;
use libsql::Connection;
use parking_lot::Mutex;
use rust_embed::RustEmbed;
use rustyscript::{json_args, Module, Runtime};
use rustyscript::{init_platform, json_args, Module, Runtime};
use serde::Deserialize;
use serde_json::from_value;
use std::collections::{HashMap, HashSet};
@@ -30,20 +30,31 @@ struct JsResponse {
}
enum Message {
Close,
Run(Box<dyn (FnOnce(&mut Runtime)) + Send + Sync>),
}
struct RuntimeSingleton {
struct State {
sender: crossbeam_channel::Sender<Message>,
connection: Mutex<Option<libsql::Connection>>,
}
struct RuntimeSingleton {
// Thread handle
handle: Option<std::thread::JoinHandle<()>>,
// Shared sender.
sender: crossbeam_channel::Sender<Message>,
// Isolate state.
state: Vec<State>,
}
impl Drop for RuntimeSingleton {
fn drop(&mut self) {
if let Some(handle) = self.handle.take() {
if self.sender.send(Message::Close).is_ok() {
handle.join().unwrap();
self.state.clear();
if handle.join().is_err() {
log::error!("Failed to join main rt thread");
}
}
}
@@ -51,44 +62,101 @@ impl Drop for RuntimeSingleton {
impl RuntimeSingleton {
fn new() -> Self {
let (sender, receiver) = crossbeam_channel::unbounded::<Message>();
let n_threads: usize = std::thread::available_parallelism().map_or_else(
|err| {
log::error!("Failed to get number of threads: {err}");
return 1;
},
|x| x.get(),
);
let (shared_sender, shared_receiver) = crossbeam_channel::unbounded::<Message>();
let (state, receivers): (Vec<State>, Vec<crossbeam_channel::Receiver<Message>>) = (0
..n_threads)
.map(|_index| {
let (sender, receiver) = crossbeam_channel::unbounded::<Message>();
return (
State {
sender,
connection: Mutex::new(None),
},
receiver,
);
})
.unzip();
let handle = std::thread::spawn(move || {
let mut runtime = Self::init_runtime().unwrap();
init_platform(n_threads as u32, true);
#[allow(clippy::never_loop)]
while let Ok(message) = receiver.recv() {
match message {
Message::Close => break,
Message::Run(f) => {
f(&mut runtime);
}
let threads: Vec<_> = receivers
.into_iter()
.enumerate()
.map(|(index, receiver)| {
let shared_receiver = shared_receiver.clone();
return std::thread::spawn(move || {
let mut runtime = Self::init_runtime(index).unwrap();
loop {
crossbeam_channel::select! {
recv(receiver) -> msg => {
match msg {
Ok(Message::Run(f)) => {
f(&mut runtime);
}
_ => {
log::info!("channel closed");
break;
}
}
},
recv(shared_receiver) -> msg => {
match msg {
Ok(Message::Run(f)) => {
f(&mut runtime);
}
_ => {
log::info!("shared channel closed");
break;
}
}
},
}
}
});
})
.collect();
for thread in threads {
if thread.join().is_err() {
log::error!("Failed to join worker");
}
}
});
return RuntimeSingleton {
sender,
sender: shared_sender,
handle: Some(handle),
state,
};
}
fn init_runtime() -> Result<Runtime, AnyError> {
let mut cache = import_provider::MemoryCache::default();
cache.set(
"trailbase:main",
cow_to_string(JsRuntimeAssets::get("index.js").unwrap().data),
);
let tokio_runtime = tokio::runtime::Builder::new_multi_thread()
fn init_runtime(index: usize) -> Result<Runtime, AnyError> {
let tokio_runtime = tokio::runtime::Builder::new_current_thread()
.enable_time()
.enable_io()
.thread_name("v8-runtime")
.thread_stack_size(4 * 1024 * 1024)
.build()?;
let runtime = rustyscript::Runtime::with_tokio_runtime(
let mut cache = import_provider::MemoryCache::default();
cache.set(
"trailbase:main",
cow_to_string(JsRuntimeAssets::get("index.js").unwrap().data),
);
let mut runtime = rustyscript::Runtime::with_tokio_runtime(
rustyscript::RuntimeOptions {
import_provider: Some(Box::new(cache)),
schema_whlist: HashSet::from(["trailbase".to_string()]),
@@ -97,6 +165,63 @@ impl RuntimeSingleton {
std::rc::Rc::new(tokio_runtime),
)?;
let idx = index;
runtime.register_async_function("query", move |args: Vec<serde_json::Value>| {
Box::pin(async move {
let query: String = get_arg(&args, 0)?;
let json_params: Vec<serde_json::Value> = get_arg(&args, 1)?;
let mut params: Vec<libsql::Value> = vec![];
for value in json_params {
params.push(json_value_to_param(value)?);
}
let Some(conn) = RUNTIME.state[idx].connection.lock().clone() else {
return Err(rustyscript::Error::Runtime(
"missing db connection".to_string(),
));
};
let rows = conn
.query(&query, libsql::params::Params::Positional(params))
.await
.map_err(|err| rustyscript::Error::Runtime(err.to_string()))?;
let (values, _columns) = rows_to_json_arrays(rows, usize::MAX)
.await
.map_err(|err| rustyscript::Error::Runtime(err.to_string()))?;
return serde_json::to_value(values)
.map_err(|err| rustyscript::Error::Runtime(err.to_string()));
})
})?;
let idx = index;
runtime.register_async_function("execute", move |args: Vec<serde_json::Value>| {
Box::pin(async move {
let query: String = get_arg(&args, 0)?;
let json_params: Vec<serde_json::Value> = get_arg(&args, 1)?;
let mut params: Vec<libsql::Value> = vec![];
for value in json_params {
params.push(json_value_to_param(value)?);
}
let Some(conn) = RUNTIME.state[idx].connection.lock().clone() else {
return Err(rustyscript::Error::Runtime(
"missing db connection".to_string(),
));
};
let rows_affected = conn
.execute(&query, libsql::params::Params::Positional(params))
.await
.map_err(|err| rustyscript::Error::Runtime(err.to_string()))?;
return Ok(serde_json::Value::Number(rows_affected.into()));
})
})?;
return Ok(runtime);
}
}
@@ -108,6 +233,64 @@ static RUNTIME: LazyLock<RuntimeSingleton> = LazyLock::new(RuntimeSingleton::new
pub(crate) struct RuntimeHandle;
impl RuntimeHandle {
#[cfg(not(test))]
pub(crate) fn set_connection(conn: Connection) {
for s in &RUNTIME.state {
let mut lock = s.connection.lock();
if lock.is_some() {
panic!("connection already set");
}
lock.replace(conn.clone());
}
}
#[cfg(test)]
pub(crate) fn set_connection(conn: Connection) {
for s in &RUNTIME.state {
let mut lock = s.connection.lock();
if lock.is_some() {
log::debug!("connection already set");
} else {
lock.replace(conn.clone());
}
}
}
#[cfg(test)]
pub(crate) fn override_connection(conn: Connection) {
for s in &RUNTIME.state {
let mut lock = s.connection.lock();
if lock.is_some() {
log::debug!("connection already set");
}
lock.replace(conn.clone());
}
}
pub(crate) fn new() -> Self {
return Self {};
}
async fn apply<T>(
&self,
f: impl (FnOnce(&mut rustyscript::Runtime) -> T) + Send + Sync + 'static,
) -> Result<Box<T>, AnyError>
where
T: Send + Sync + 'static,
{
let (sender, receiver) = tokio::sync::oneshot::channel::<Box<T>>();
RUNTIME.sender.send(Message::Run(Box::new(move |rt| {
if let Err(_err) = sender.send(Box::new(f(rt))) {
log::warn!("Failed to send");
}
})))?;
return Ok(receiver.await?);
}
}
pub fn json_value_to_param(value: serde_json::Value) -> Result<libsql::Value, rustyscript::Error> {
use rustyscript::Error;
return Ok(match value {
@@ -134,86 +317,6 @@ pub fn json_value_to_param(value: serde_json::Value) -> Result<libsql::Value, ru
});
}
impl RuntimeHandle {
pub(crate) fn new(conn: Connection) -> Self {
RUNTIME
.sender
.send(Message::Run(Box::new(move |runtime: &mut Runtime| {
let conn_clone = conn.clone();
runtime
.register_async_function("query", move |args: Vec<serde_json::Value>| {
let conn = conn_clone.clone();
Box::pin(async move {
let query: String = get_arg(&args, 0)?;
let json_params: Vec<serde_json::Value> = get_arg(&args, 1)?;
let mut params: Vec<libsql::Value> = vec![];
for value in json_params {
params.push(json_value_to_param(value)?);
}
let rows = conn
.query(&query, libsql::params::Params::Positional(params))
.await
.map_err(|err| rustyscript::Error::Runtime(err.to_string()))?;
let (values, _columns) = rows_to_json_arrays(rows, usize::MAX)
.await
.map_err(|err| rustyscript::Error::Runtime(err.to_string()))?;
return serde_json::to_value(values)
.map_err(|err| rustyscript::Error::Runtime(err.to_string()));
})
})
.unwrap();
runtime
.register_async_function("execute", move |args: Vec<serde_json::Value>| {
let conn = conn.clone();
Box::pin(async move {
let query: String = get_arg(&args, 0)?;
let json_params: Vec<serde_json::Value> = get_arg(&args, 1)?;
let mut params: Vec<libsql::Value> = vec![];
for value in json_params {
params.push(json_value_to_param(value)?);
}
let rows_affected = conn
.execute(&query, libsql::params::Params::Positional(params))
.await
.map_err(|err| rustyscript::Error::Runtime(err.to_string()))?;
return Ok(serde_json::Value::Number(rows_affected.into()));
})
})
.unwrap();
})))
.unwrap();
return Self {};
}
async fn apply<T>(
&self,
f: impl (FnOnce(&mut rustyscript::Runtime) -> T) + Send + Sync + 'static,
) -> Result<Box<T>, AnyError>
where
T: Send + Sync + 'static,
{
let (sender, receiver) = tokio::sync::oneshot::channel::<Box<T>>();
RUNTIME.sender.send(Message::Run(Box::new(move |rt| {
if let Err(_err) = sender.send(Box::new(f(rt))) {
log::warn!("Failed to send");
}
})))?;
return Ok(receiver.await?);
}
}
#[derive(Debug, Error)]
pub enum JsResponseError {
#[error("Precondition: {0}")]
@@ -246,7 +349,6 @@ impl IntoResponse for JsResponseError {
/// Get's called from JS to `addRoute`.
fn route_callback(
state: AppState,
router: Arc<Mutex<Option<Router<AppState>>>>,
method: String,
route: String,
@@ -280,8 +382,7 @@ fn route_callback(
})
.collect();
let js_response = state
.script_runtime()
let js_response = RuntimeHandle::new()
.apply(move |runtime| -> Result<JsResponse, rustyscript::Error> {
let response: JsResponse = runtime.call_function(
None,
@@ -351,44 +452,59 @@ where
return from_value::<T>(arg.clone()).map_err(|err| Error::Runtime(err.to_string()));
}
pub(crate) async fn install_routes(
state: AppState,
module: Module,
) -> Result<Option<Router<AppState>>, AnyError> {
return Ok(
*state
.clone()
.script_runtime()
.apply(move |runtime: &mut Runtime| {
let router = Arc::new(Mutex::new(Some(Router::<AppState>::new())));
pub(crate) async fn install_routes(module: Module) -> Result<Option<Router<AppState>>, AnyError> {
use tokio::sync::oneshot;
// First install a native callback that builds an axum router.
let router_clone = router.clone();
runtime
.register_function("route", move |args: &[serde_json::Value]| {
let method: String = get_arg(args, 0)?;
let route: String = get_arg(args, 1)?;
let receivers: Vec<_> = RUNTIME
.state
.iter()
.map(move |s| -> oneshot::Receiver<Option<Router<AppState>>> {
let module = module.clone();
let (sender, receiver) = oneshot::channel::<Option<Router<AppState>>>();
s.sender
.send(Message::Run(Box::new(move |runtime: &mut Runtime| {
let router = Arc::new(Mutex::new(Some(Router::<AppState>::new())));
route_callback(state.clone(), router_clone.clone(), method, route)
.map_err(|err| rustyscript::Error::Runtime(err.to_string()))?;
// First install a native callback that builds an axum router.
let router_clone = router.clone();
runtime
.register_function("route", move |args: &[serde_json::Value]| {
let method: String = get_arg(args, 0)?;
let route: String = get_arg(args, 1)?;
Ok(serde_json::Value::Null)
})
.unwrap();
route_callback(router_clone.clone(), method, route)
.map_err(|err| rustyscript::Error::Runtime(err.to_string()))?;
// Then execute the script/module, i.e. statements in the file scope.
runtime.load_module(&module).unwrap();
Ok(serde_json::Value::Null)
})
.unwrap();
let router: Router<AppState> = router.lock().take().unwrap();
if router.has_routes() {
return Some(router);
}
return None;
})
.await?,
);
// Then execute the script/module, i.e. statements in the file scope.
runtime.load_module(&module).unwrap();
let router: Router<AppState> = router.lock().take().unwrap();
if router.has_routes() {
sender.send(Some(router)).unwrap();
} else {
sender.send(None).unwrap();
}
})))
.unwrap();
return receiver;
})
.collect();
let mut receivers = futures::future::join_all(receivers).await;
// Note: We only return the first router assuming that js route registration is deterministic.
return Ok(receivers.swap_remove(0)?);
}
#[derive(RustEmbed, Clone)]
#[folder = "js/dist/"]
struct JsRuntimeAssets;
#[cfg(test)]
mod tests {
use super::*;
@@ -415,7 +531,7 @@ mod tests {
}
async fn test_runtime_apply() {
let handle = RuntimeHandle::new(new_mem_conn().await);
let handle = RuntimeHandle::new();
let number = handle
.apply::<i64>(|_runtime| {
return 42;
@@ -427,7 +543,7 @@ mod tests {
}
async fn test_runtime_javascript() {
let handle = RuntimeHandle::new(new_mem_conn().await);
let handle = RuntimeHandle::new();
let result = handle
.apply::<String>(|runtime| {
let context = runtime
@@ -448,7 +564,7 @@ mod tests {
return runtime
.call_function(Some(&context), "test_fun", json_args!())
.map_err(|err| {
log::error!("Failed to load call fun: {err}");
log::error!("Failed to load call test_fun: {err}");
return err;
})
.unwrap();
@@ -470,7 +586,8 @@ mod tests {
.await
.unwrap();
let handle = RuntimeHandle::new(conn);
RuntimeHandle::override_connection(conn);
let handle = RuntimeHandle::new();
let result = handle
.apply::<Vec<Vec<serde_json::Value>>>(|runtime| {
@@ -503,7 +620,7 @@ mod tests {
.await
})
.map_err(|err| {
log::error!("Failed to load call fun: {err}");
log::error!("Failed to load call test_query: {err}");
return err;
})
.unwrap();
@@ -533,7 +650,8 @@ mod tests {
.await
.unwrap();
let handle = RuntimeHandle::new(conn.clone());
RuntimeHandle::override_connection(conn.clone());
let handle = RuntimeHandle::new();
let _result = handle
.apply::<i64>(|runtime| {
@@ -566,7 +684,7 @@ mod tests {
.await
})
.map_err(|err| {
log::error!("Failed to load call fun: {err}");
log::error!("Failed to load call test_execute: {err}");
return err;
})
.unwrap();
@@ -581,7 +699,3 @@ mod tests {
assert_eq!(0, count);
}
}
#[derive(RustEmbed, Clone)]
#[folder = "js/dist/"]
struct JsRuntimeAssets;
+1 -1
View File
@@ -113,7 +113,7 @@ impl Server {
let mut js_router = Some(Router::new());
for module in modules {
let fname = module.filename().to_owned();
let router = install_routes(state.clone(), module)
let router = install_routes(module)
.await
.map_err(|err| InitError::ScriptError(err.to_string()))?;