diff --git a/crates/core/src/extract/content_type.rs b/crates/core/src/extract/content_type.rs new file mode 100644 index 00000000..6dc2b351 --- /dev/null +++ b/crates/core/src/extract/content_type.rs @@ -0,0 +1,89 @@ +use axum::extract::{FromRequest, Request}; +use axum::http::header::{ACCEPT, CONTENT_TYPE}; +use axum::http::{HeaderMap, StatusCode}; +use axum::response::{IntoResponse, Response}; +use thiserror::Error; + +/// Supported request content types. +/// +/// We error for unsupported content types and fall back to application/json for unspecified +/// content types. +pub enum RequestContentType { + // Unknown, + Json, + Multipart, + Form, +} + +#[derive(Debug, Error)] +pub enum ContentTypeRejection { + #[error("Unsupported Content-Type: {0}")] + UnsupportedContentType(String), +} + +impl IntoResponse for ContentTypeRejection { + fn into_response(self) -> Response { + return (StatusCode::BAD_REQUEST, self.to_string()).into_response(); + } +} + +impl RequestContentType { + #[inline] + pub fn from_headers(headers: &HeaderMap) -> Result { + return match headers.get(CONTENT_TYPE).map(|h| h.as_bytes()) { + Some(content_type) if content_type.starts_with(b"application/json") => { + Ok(RequestContentType::Json) + } + Some(content_type) if content_type.starts_with(b"application/x-www-form-urlencoded") => { + Ok(RequestContentType::Form) + } + Some(content_type) if content_type.starts_with(b"multipart/form-data") => { + Ok(RequestContentType::Multipart) + } + Some(content_type) => Err(ContentTypeRejection::UnsupportedContentType( + String::from_utf8_lossy(content_type).into(), + )), + // QUESTION: Not convinced this is a sensible default for "None" but convenient for testing + // with curl. + None => Ok(RequestContentType::Json), + }; + } +} + +pub enum ResponseContentType { + Json, +} + +#[allow(unused)] +impl ResponseContentType { + pub fn from_headers(headers: &HeaderMap) -> Result { + // We mimic the requests's content type. However, we won't reply in forms. + if let Ok(_request_content_type) = RequestContentType::from_headers(headers) { + return Ok(ResponseContentType::Json); + } + + for value in headers.get_all(ACCEPT) { + if value == "application/json" { + return Ok(ResponseContentType::Json); + } + } + + return Err(ContentTypeRejection::UnsupportedContentType( + headers + .get(CONTENT_TYPE) + .and_then(|c| c.to_str().map(|c| c.to_string()).ok()) + .unwrap_or_default(), + )); + } +} + +impl FromRequest for ResponseContentType +where + S: Send + Sync, +{ + type Rejection = ContentTypeRejection; + + async fn from_request(req: Request, _state: &S) -> Result { + return Self::from_headers(req.headers()); + } +} diff --git a/crates/core/src/extract/either.rs b/crates/core/src/extract/either.rs index 831e6115..2c82e52f 100644 --- a/crates/core/src/extract/either.rs +++ b/crates/core/src/extract/either.rs @@ -1,21 +1,21 @@ use axum::Json; use axum::extract::{Form, FromRequest, Request, rejection::*}; use axum::http::StatusCode; -use axum::http::header::CONTENT_TYPE; use axum::response::{IntoResponse, Response}; use serde::Serialize; use serde::de::DeserializeOwned; use thiserror::Error; use trailbase_schema::FileUploadInput; +use crate::extract::content_type::{ContentTypeRejection, RequestContentType}; use crate::extract::multipart::{Rejection as MultipartRejection, parse_multipart}; #[derive(Debug, Error)] pub enum EitherRejection { // #[error("Missing Content-Type")] // MissingContentType, - #[error("Unsupported Content-Type found")] - UnsupportedContentType, + #[error("Unsupported Content-Type: {0}")] + UnsupportedContentType(String), #[error("Form error: {0}")] Form(#[from] FormRejection), #[error("Json error: {0}")] @@ -30,15 +30,21 @@ impl IntoResponse for EitherRejection { } } -// NOTE: For serde_json::Value as T, the different formats will produce very different results, -// e.g. json has a notion of types, whereas Multipart and Form don't. They're s practically a: -// Map> +pub enum ResponseContentType { + Json, +} + +/// Deserialization helper to support requests in multiple formats. +/// +/// Eventually, we'd like to support Avro as well. In which case, we might have to delay +/// de-serialization to pass a schema or we'll only be able to support generic: +/// `Map` types, which may still provide some compression benefits +/// :shrug:. #[derive(Debug)] pub enum Either { Json(T), Multipart(T, Vec), Form(T), - // Proto(DynamicMessage), } impl FromRequest for Either @@ -49,30 +55,23 @@ where type Rejection = EitherRejection; async fn from_request(req: Request, state: &S) -> Result { - return match req.headers().get(CONTENT_TYPE) { - Some(x) if x.as_ref().starts_with(b"application/json") => { - let Json(value): Json = Json::from_request(req, state).await?; - Ok(Either::Json(value)) + return match RequestContentType::from_headers(req.headers()) { + Ok(RequestContentType::Json) => { + Ok(Either::Json(Json::::from_request(req, state).await?.0)) } - Some(x) if x.as_ref().starts_with(b"application/x-www-form-urlencoded") => { + Ok(RequestContentType::Form) => { let Form(value): Form = Form::from_request(req, state).await?; Ok(Either::Form(value)) } - Some(x) if x.as_ref().starts_with(b"multipart/form-data") => { + Ok(RequestContentType::Multipart) => { let (value, files) = parse_multipart(req).await?; Ok(Either::Multipart(value, files)) } - // Some(x) if x == "application/x-protobuf" => { - // return Ok(Either::Proto(DynamicMessage::decode::from_request(req, - // state).await.unwrap())); } - Some(_) => Err(EitherRejection::UnsupportedContentType), - None => { - // TODO: Not convinced this is a sensible default for "None" but convenient for testing with - // curl. - let Json(value): Json = Json::from_request(req, state).await?; - Ok(Either::Json(value)) - // Err(EitherRejection::MissingContentType), - } + Err(err) => match err { + ContentTypeRejection::UnsupportedContentType(v) => { + Err(EitherRejection::UnsupportedContentType(v)) + } + }, }; } } @@ -157,7 +156,7 @@ mod test { "#}; let request = axum::http::Request::builder() - .header("content-type", "application/json; boundary=fieldB") + .header("ContenT-tYpe", "application/json") .header("content-length", body.len()) .body(axum::body::Body::from(body)) .unwrap(); diff --git a/crates/core/src/extract/mod.rs b/crates/core/src/extract/mod.rs index 93fc330d..6410df9f 100644 --- a/crates/core/src/extract/mod.rs +++ b/crates/core/src/extract/mod.rs @@ -1,3 +1,4 @@ +mod content_type; mod either; pub mod ip; mod multipart;