diff --git a/trailbase-core/src/records/create_record.rs b/trailbase-core/src/records/create_record.rs index b2c21b00..8b17e654 100644 --- a/trailbase-core/src/records/create_record.rs +++ b/trailbase-core/src/records/create_record.rs @@ -145,7 +145,7 @@ pub async fn create_record_handler( params_list.push( lazy_params .consume() - .map_err(|_| RecordError::BadRequest("Parameter conversion"))?, + .map_err(|_| RecordError::BadRequest("Invalid Parameters"))?, ); } @@ -207,6 +207,7 @@ mod test { use crate::util::{id_to_b64, uuid_to_b64}; use serde_json::json; + use trailbase_sqlite::params; #[tokio::test] async fn test_simple_record_api_create() { @@ -271,16 +272,14 @@ mod test { .unwrap(); } - let value: Option = state - .conn() - .read_query_row_f( - "SELECT value FROM simple WHERE owner = ?1", - trailbase_sqlite::params!(user_x), - |row| row.get(0), - ) - .await - .unwrap(); - assert_eq!(value, Some(9)); + assert_eq!( + state + .conn() + .read_query_value::("SELECT value FROM simple WHERE owner = ?1", params!(user_x)) + .await + .unwrap(), + Some(9) + ); { // Make sure user.id == owner ACL check works diff --git a/trailbase-core/src/records/params.rs b/trailbase-core/src/records/params.rs index 2f56d8c1..04f4326b 100644 --- a/trailbase-core/src/records/params.rs +++ b/trailbase-core/src/records/params.rs @@ -397,15 +397,39 @@ pub fn simple_json_value_to_param( try_json_array_to_blob(arr)? } serde_json::Value::Null => Value::Null, - serde_json::Value::Bool(b) => Value::Integer(b as i64), + serde_json::Value::Bool(b) => { + if col_type != ColumnDataType::Integer { + return Err(ParamsError::UnexpectedType("Bool", format!("{col_type:?}"))); + } + Value::Integer(b as i64) + } serde_json::Value::String(str) => json_string_to_value(col_type, str)?, serde_json::Value::Number(number) => { if let Some(n) = number.as_i64() { - Value::Integer(n) + match col_type { + ColumnDataType::Integer => Value::Integer(n), + // NOTE: "as" is lossy conversion. Does not panic. + ColumnDataType::Real => Value::Real(n as f64), + _ => { + return Err(ParamsError::UnexpectedType("int", format!("{col_type:?}"))); + } + } } else if let Some(n) = number.as_u64() { - Value::Integer(n as i64) + match col_type { + // NOTE: "as" is lossy conversion. Does not panic. + ColumnDataType::Integer => Value::Integer(n as i64), + ColumnDataType::Real => Value::Real(n as f64), + _ => { + return Err(ParamsError::UnexpectedType("uint", format!("{col_type:?}"))); + } + } } else if let Some(n) = number.as_f64() { - Value::Real(n) + match col_type { + ColumnDataType::Real => Value::Real(n), + _ => { + return Err(ParamsError::UnexpectedType("real", format!("{col_type:?}"))); + } + } } else { warn!("Not a valid number: {number:?}"); return Err(ParamsError::NotANumber); diff --git a/trailbase-core/src/records/update_record.rs b/trailbase-core/src/records/update_record.rs index 1b8c7c0e..58eb8ffa 100644 --- a/trailbase-core/src/records/update_record.rs +++ b/trailbase-core/src/records/update_record.rs @@ -66,7 +66,7 @@ pub async fn update_record_handler( api.has_file_columns(), lazy_params .consume() - .map_err(|err| RecordError::Internal(err.into()))?, + .map_err(|_| RecordError::BadRequest("Invalid Parameters"))?, ) .await .map_err(|err| RecordError::Internal(err.into()))?; @@ -77,6 +77,7 @@ pub async fn update_record_handler( #[cfg(test)] mod test { use axum::extract::Query; + use serde_json::json; use trailbase_sqlite::params; use super::*; @@ -94,6 +95,106 @@ mod test { use crate::test::unpack_json_response; use crate::util::{b64_to_id, id_to_b64}; + #[tokio::test] + async fn test_simple_record_api_update() { + let state = test_state(None).await.unwrap(); + + state + .conn() + .execute_batch( + r#" + CREATE TABLE "update" ( + "id" INTEGER PRIMARY KEY, + "int" INTEGER NOT NULL DEFAULT (-1), + "float" REAL NOT NULL, + "text" TEXT + ) STRICT; + "#, + ) + .await + .unwrap(); + + state.schema_metadata().invalidate_all().await.unwrap(); + + add_record_api_config( + &state, + RecordApiConfig { + name: Some("update_api".to_string()), + table_name: Some("update".to_string()), + acl_world: [ + PermissionFlag::Create as i32, + PermissionFlag::Read as i32, + PermissionFlag::Update as i32, + ] + .into(), + ..Default::default() + }, + ) + .await + .unwrap(); + + let _ = create_record_handler( + State(state.clone()), + Path("update_api".to_string()), + Query(CreateRecordQuery::default()), + None, + Either::Json( + json_row_from_value(json!({ + "id": 1, + "float": 5, + })) + .unwrap() + .into(), + ), + ) + .await + .unwrap(); + + let _ = update_record_handler( + State(state.clone()), + Path(("update_api".to_string(), "1".to_string())), + None, + Either::Json( + json_row_from_value(json!({ + "int": 4, + })) + .unwrap() + .into(), + ), + ) + .await + .unwrap(); + + assert_eq!( + state + .conn() + .read_query_value::(r#"SELECT "int" FROM "update" WHERE id = 1"#, ()) + .await + .unwrap(), + Some(4) + ); + + // Test that bad input leads to bad request. + let response = update_record_handler( + State(state.clone()), + Path(("update_api".to_string(), "1".to_string())), + None, + Either::Json( + json_row_from_value(json!({ + "int": 4.1, + })) + .unwrap() + .into(), + ), + ) + .await; + + assert!(matches!( + response.err().unwrap(), + RecordError::BadRequest(_) + )) + } + #[tokio::test] async fn test_record_api_update() { let state = test_state(None).await.unwrap();