diff --git a/.gitmodules b/.gitmodules index 2bce7349..af6197d1 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,3 @@ -[submodule "vendor/refinery"] - path = vendor/refinery - url = https://github.com/trailbaseio/refinery.git [submodule "vendor/sqlean/bundled/sqlean"] path = vendor/sqlean/bundled/sqlean url = https://github.com/trailbaseio/sqlean diff --git a/Cargo.lock b/Cargo.lock index 213db7a4..2601d948 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6821,6 +6821,7 @@ dependencies = [ "regex", "reqwest", "rusqlite", + "rust-embed", "schemars", "serde", "serde_json", @@ -6845,8 +6846,7 @@ dependencies = [ "trailbase-extension", "trailbase-js", "trailbase-qs", - "trailbase-refinery-core", - "trailbase-refinery-macros", + "trailbase-refinery", "trailbase-schema", "trailbase-sqlite", "ts-rs", @@ -6871,13 +6871,12 @@ dependencies = [ "log", "once_cell", "rusqlite", + "rust-embed", "serde", "serde_json", "thiserror 2.0.12", "tokio", - "trailbase-apalis", - "trailbase-refinery-core", - "trailbase-refinery-macros", + "trailbase-refinery", "trailbase-sqlite", ] @@ -6997,33 +6996,24 @@ dependencies = [ ] [[package]] -name = "trailbase-refinery-core" -version = "0.8.16" +name = "trailbase-refinery" +version = "0.1.0" dependencies = [ "async-trait", "cfg-if", + "futures", "log", "regex", "rusqlite", "siphasher 1.0.1", + "tempfile", "thiserror 1.0.69", "time", + "tokio", "url", "walkdir", ] -[[package]] -name = "trailbase-refinery-macros" -version = "0.8.15" -dependencies = [ - "heck 0.5.0", - "proc-macro2", - "quote", - "regex", - "syn 2.0.101", - "trailbase-refinery-core", -] - [[package]] name = "trailbase-schema" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 708ed568..eec97409 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ members = [ "trailbase-extension", "trailbase-js", "trailbase-qs", + "trailbase-refinery", "trailbase-schema", "trailbase-sqlite", "vendor/sqlean", @@ -28,10 +29,6 @@ default-members = [ "trailbase-schema", 
"trailbase-sqlite", ] -exclude = [ - "vendor/refinery", - "vendor/rustc_tools_util", -] # https://doc.rust-lang.org/cargo/reference/profiles.html [profile.release] @@ -76,8 +73,7 @@ trailbase-sqlean = { path = "vendor/sqlean", version = "0.0.2" } trailbase-extension = { path = "trailbase-extension", version = "0.2.0" } trailbase-js = { path = "trailbase-js", version = "0.1.0" } trailbase-qs = { path = "trailbase-qs", version = "0.1.0" } -trailbase-refinery-core = { path = "vendor/refinery/refinery_core", version = "0.8.16", default-features = false, features = ["rusqlite-bundled"] } -trailbase-refinery-macros = { path = "vendor/refinery/refinery_macros", version = "0.8.15" } +trailbase-refinery = { path = "trailbase-refinery", version = "0.1.0" } trailbase-schema = { path = "trailbase-schema", version = "0.1.0" } trailbase-sqlite = { path = "trailbase-sqlite", version = "0.2.0" } trailbase = { path = "trailbase-core", version = "0.1.0" } diff --git a/examples/custom-binary/src/main.rs b/examples/custom-binary/src/main.rs index d52151da..c4580090 100644 --- a/examples/custom-binary/src/main.rs +++ b/examples/custom-binary/src/main.rs @@ -28,7 +28,7 @@ async fn hello_world_handler(State(state): State, user: Option Result<(), Box> { env_logger::init_from_env( env_logger::Env::new() - .default_filter_or("info,refinery_core=warn,tracing::span=warn,swc_ecma_codegen=off"), + .default_filter_or("info,trailbase_refinery=warn,tracing::span=warn,swc_ecma_codegen=off"), ); let Server { diff --git a/trailbase-apalis/Cargo.toml b/trailbase-apalis/Cargo.toml index ad923fe2..e28bb929 100644 --- a/trailbase-apalis/Cargo.toml +++ b/trailbase-apalis/Cargo.toml @@ -21,13 +21,13 @@ chrono = { version = "0.4", features = ["serde"] } futures = "0.3.30" futures-lite = "2.3.0" log = "0.4.21" -rusqlite.workspace = true +rusqlite = { workspace = true } +rust-embed = { workspace = true } serde = { version = "1", features = ["derive"] } serde_json = "1" thiserror = "2.0.0" tokio = { workspace 
= true } -trailbase-refinery-core = { workspace = true } -trailbase-refinery-macros = { workspace = true } +trailbase-refinery = { workspace = true } trailbase-sqlite = { workspace = true } [dev-dependencies] @@ -36,7 +36,6 @@ apalis-core = { version = "0.7.0", default-features = false, features = [ "test- email_address = "0.2.9" once_cell = "1.19.0" tokio = { workspace = true } -trailbase-apalis = { path = "." } [package.metadata.docs.rs] # defines the configuration attribute `docsrs` diff --git a/trailbase-apalis/src/sqlite.rs b/trailbase-apalis/src/sqlite.rs index e7141715..08537a91 100644 --- a/trailbase-apalis/src/sqlite.rs +++ b/trailbase-apalis/src/sqlite.rs @@ -63,8 +63,32 @@ impl Clone for SqliteStorage { } } -mod main { - trailbase_refinery_macros::embed_migrations!("migrations"); +mod migrations { + use trailbase_refinery::{Migration, Runner}; + + #[derive(Clone, rust_embed::RustEmbed)] + #[folder = "migrations"] + struct Migrations; + + fn load_migrations() -> Vec { + let mut migrations = vec![]; + for filename in T::iter() { + if let Some(file) = T::get(&filename) { + migrations.push( + Migration::unapplied(&filename, &String::from_utf8_lossy(&file.data)).expect("startup"), + ) + } + } + return migrations; + } + + pub(crate) fn migration_runner() -> Runner { + const MIGRATION_TABLE_NAME: &str = "_schema_history"; + let migrations = load_migrations::(); + let mut runner = trailbase_refinery::Runner::new(&migrations).set_abort_divergent(false); + runner.set_migration_table_name(MIGRATION_TABLE_NAME); + return runner; + } } impl SqliteStorage<()> { @@ -81,7 +105,7 @@ impl SqliteStorage<()> { ) .await?; - let runner = main::migrations::runner(); + let runner = migrations::migration_runner(); let _report = conn .call(move |conn| { diff --git a/trailbase-cli/src/bin/trail.rs b/trailbase-cli/src/bin/trail.rs index 87eec6d4..aa01f597 100644 --- a/trailbase-cli/src/bin/trail.rs +++ b/trailbase-cli/src/bin/trail.rs @@ -23,8 +23,7 @@ fn init_logger(dev: bool) { 
// SWC is very spammy in in debug builds and complaints about source maps when compiling // typescript to javascript. Since we don't care about source maps and didn't find a better // option to mute the errors, turn it off in debug builds. - const DEFAULT: &str = - "info,refinery_core=warn,trailbase_refinery_core=warn,tracing::span=warn,swc_ecma_codegen=off"; + const DEFAULT: &str = "info,trailbase_refinery=warn,tracing::span=warn,swc_ecma_codegen=off"; env_logger::Builder::from_env(if dev { env_logger::Env::new().default_filter_or(format!("{DEFAULT},trailbase=debug")) diff --git a/trailbase-core/Cargo.toml b/trailbase-core/Cargo.toml index b44bbe3d..9ba523c9 100644 --- a/trailbase-core/Cargo.toml +++ b/trailbase-core/Cargo.toml @@ -64,6 +64,7 @@ rand = "^0.9.0" regex = "1.11.0" reqwest = { version = "0.12.8", default-features = false, features = ["rustls-tls", "json"] } rusqlite = { workspace = true } +rust-embed = { workspace = true } serde = { version = "^1.0.203", features = ["derive"] } serde_json = "^1.0.117" serde_path_to_error = "0.1.16" @@ -85,8 +86,7 @@ trailbase-assets = { workspace = true } trailbase-extension = { workspace = true } trailbase-js = { workspace = true, optional = true } trailbase-qs = { workspace = true } -trailbase-refinery-core = { workspace = true } -trailbase-refinery-macros = { workspace = true } +trailbase-refinery = { workspace = true } trailbase-schema = { workspace = true } trailbase-sqlite = { workspace = true } ts-rs = { version = "10", features = ["uuid-impl", "serde-json-impl"] } diff --git a/trailbase-core/src/admin/error.rs b/trailbase-core/src/admin/error.rs index ab623eef..baabc28a 100644 --- a/trailbase-core/src/admin/error.rs +++ b/trailbase-core/src/admin/error.rs @@ -32,7 +32,7 @@ pub enum AdminError { #[error("Table lookup error: {0}")] TableLookup(#[from] crate::schema_metadata::SchemaLookupError), #[error("DB Migration error: {0}")] - Migration(#[from] trailbase_refinery_core::Error), + Migration(#[from] 
trailbase_refinery::Error), #[error("SQL -> Json error: {0}")] Json(#[from] trailbase_sqlite::rows::JsonError), #[error("Schema error: {0}")] diff --git a/trailbase-core/src/admin/user/mod.rs b/trailbase-core/src/admin/user/mod.rs index 18c4d03f..ad909c7c 100644 --- a/trailbase-core/src/admin/user/mod.rs +++ b/trailbase-core/src/admin/user/mod.rs @@ -54,7 +54,7 @@ mod tests { #[tokio::test] async fn test_user_creation_and_deletion() { let _ = env_logger::try_init_from_env( - env_logger::Env::new().default_filter_or("info,refinery_core=warn"), + env_logger::Env::new().default_filter_or("info,trailbase_refinery=warn"), ); let mailer = TestAsyncSmtpTransport::new(); diff --git a/trailbase-core/src/app_state.rs b/trailbase-core/src/app_state.rs index bc8cd6e7..3710b7da 100644 --- a/trailbase-core/src/app_state.rs +++ b/trailbase-core/src/app_state.rs @@ -307,9 +307,10 @@ pub async fn test_state(options: Option) -> anyhow::Result String { return format!("U{timestamp}__{suffix}.sql"); } -pub(crate) fn new_migration_runner(migrations: &[Migration]) -> trailbase_refinery_core::Runner { +pub(crate) fn new_migration_runner(migrations: &[Migration]) -> trailbase_refinery::Runner { // NOTE: divergent migrations are migrations with the same version but a different name. That // said, `set_abort_divergent` is not a viable way for us to handle collisions (e.g. in tests), // since setting it to false, will prevent the migration from failing but divergent migrations // are quietly dropped on the floor and not applied. That's not ok. 
- let mut runner = trailbase_refinery_core::Runner::new(migrations).set_abort_divergent(false); + let mut runner = trailbase_refinery::Runner::new(migrations).set_abort_divergent(false); runner.set_migration_table_name(MIGRATION_TABLE_NAME); return runner; } +fn load_migrations() -> Vec { + let mut migrations = vec![]; + for filename in T::iter() { + if let Some(file) = T::get(&filename) { + migrations.push( + Migration::unapplied(&filename, &String::from_utf8_lossy(&file.data)).expect("startup"), + ) + } + } + return migrations; +} + pub(crate) fn apply_main_migrations( conn: &mut rusqlite::Connection, user_migrations_path: Option, -) -> Result { +) -> Result { let all_migrations = { let mut migrations: Vec = vec![]; - let system_migrations_runner = main::migrations::runner(); - migrations.extend(system_migrations_runner.get_migrations().iter().cloned()); + let system_migrations_runner: Vec = load_migrations::(); + migrations.extend(system_migrations_runner); if let Some(path) = user_migrations_path { // NOTE: refinery has a bug where it will name-check the directory and write a warning... :/. 
- let user_migrations = trailbase_refinery_core::load_sql_migrations(path)?; + let user_migrations = trailbase_refinery::load_sql_migrations(path)?; migrations.extend(user_migrations); } @@ -92,8 +97,10 @@ pub(crate) fn apply_main_migrations( pub(crate) fn apply_logs_migrations( logs_conn: &mut rusqlite::Connection, -) -> Result<(), trailbase_refinery_core::Error> { - let mut runner = logs::migrations::runner(); +) -> Result<(), trailbase_refinery::Error> { + let migrations = load_migrations::(); + + let mut runner = new_migration_runner(&migrations); runner.set_migration_table_name(MIGRATION_TABLE_NAME); let report = runner.run(logs_conn).map_err(|err| { @@ -109,3 +116,11 @@ pub(crate) fn apply_logs_migrations( return Ok(()); } + +#[derive(Clone, rust_embed::RustEmbed)] +#[folder = "migrations/main"] +struct MainMigrations; + +#[derive(Clone, rust_embed::RustEmbed)] +#[folder = "migrations/logs"] +struct LogsMigrations; diff --git a/trailbase-core/src/transaction.rs b/trailbase-core/src/transaction.rs index 8a5f9300..cfa8b8d3 100644 --- a/trailbase-core/src/transaction.rs +++ b/trailbase-core/src/transaction.rs @@ -13,7 +13,7 @@ pub enum TransactionError { #[error("IO error: {0}")] IO(#[from] std::io::Error), #[error("Migration error: {0}")] - Migration(#[from] trailbase_refinery_core::Error), + Migration(#[from] trailbase_refinery::Error), #[error("File error: {0}")] File(String), } @@ -35,7 +35,7 @@ impl TransactionLog { conn: &trailbase_sqlite::Connection, migration_path: impl AsRef, filename_suffix: &str, - ) -> Result { + ) -> Result { let filename = migrations::new_unique_migration_filename(filename_suffix); let stem = Path::new(&filename) .file_stem() @@ -68,7 +68,7 @@ impl TransactionLog { ) }; - let migrations = vec![trailbase_refinery_core::Migration::unapplied(&stem, &sql)?]; + let migrations = vec![trailbase_refinery::Migration::unapplied(&stem, &sql)?]; let runner = migrations::new_migration_runner(&migrations).set_abort_missing(false); let report = 
conn diff --git a/trailbase-refinery/Cargo.toml b/trailbase-refinery/Cargo.toml new file mode 100644 index 00000000..1533dc3d --- /dev/null +++ b/trailbase-refinery/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "trailbase-refinery" +version = "0.1.0" +authors = ["TrailBase "] +description = "Fork of Refinery's refinery-core/macro crates" +license = "OSL-3.0" +repository = "https://github.com/trailbaseio/trailbase" +edition = "2024" +readme = "../README.md" + +[dependencies] +async-trait = "0.1" +cfg-if = "1.0" +futures = { version = "0.3.16", optional = true, features = ["async-await"] } +log = "0.4" +regex = "1" +rusqlite = { workspace = true } +siphasher = "1.0" +thiserror = "1" +time = { version = "0.3.5", features = ["parsing", "formatting"] } +tokio = { workspace = true } +url = "2.0" +walkdir = "2.3.1" + +[dev-dependencies] +tempfile = "3.1.0" diff --git a/trailbase-refinery/src/drivers/mod.rs b/trailbase-refinery/src/drivers/mod.rs new file mode 100644 index 00000000..86b1d615 --- /dev/null +++ b/trailbase-refinery/src/drivers/mod.rs @@ -0,0 +1 @@ +pub mod rusqlite; diff --git a/trailbase-refinery/src/drivers/rusqlite.rs b/trailbase-refinery/src/drivers/rusqlite.rs new file mode 100644 index 00000000..038eb71e --- /dev/null +++ b/trailbase-refinery/src/drivers/rusqlite.rs @@ -0,0 +1,56 @@ +use crate::Migration; +use crate::traits::sync::{Migrate, Query, Transaction}; +use rusqlite::{Connection as RqlConnection, Error as RqlError}; +use time::OffsetDateTime; +use time::format_description::well_known::Rfc3339; + +fn query_applied_migrations( + transaction: &RqlConnection, + query: &str, +) -> Result, RqlError> { + let mut stmt = transaction.prepare(query)?; + let mut rows = stmt.query([])?; + let mut applied = Vec::new(); + while let Some(row) = rows.next()? 
{ + let version = row.get(0)?; + let applied_on: String = row.get(2)?; + // Safe to call unwrap, as we stored it in RFC3339 format on the database + let applied_on = OffsetDateTime::parse(&applied_on, &Rfc3339).unwrap(); + + let checksum: String = row.get(3)?; + applied.push(Migration::applied( + version, + row.get(1)?, + applied_on, + checksum + .parse::() + .expect("checksum must be a valid u64"), + )); + } + Ok(applied) +} + +impl Transaction for RqlConnection { + type Error = RqlError; + fn execute<'a, T: Iterator>(&mut self, queries: T) -> Result { + let transaction = self.transaction()?; + let mut count = 0; + for query in queries { + transaction.execute_batch(query)?; + count += 1; + } + transaction.commit()?; + Ok(count) + } +} + +impl Query> for RqlConnection { + fn query(&mut self, query: &str) -> Result, Self::Error> { + let transaction = self.transaction()?; + let applied = query_applied_migrations(&transaction, query)?; + transaction.commit()?; + Ok(applied) + } +} + +impl Migrate for RqlConnection {} diff --git a/trailbase-refinery/src/error.rs b/trailbase-refinery/src/error.rs new file mode 100644 index 00000000..acca39c6 --- /dev/null +++ b/trailbase-refinery/src/error.rs @@ -0,0 +1,95 @@ +use crate::{Migration, Report}; +use std::fmt; +use std::path::PathBuf; +use thiserror::Error as TError; + +/// An Error occurred during a migration cycle +#[derive(Debug)] +pub struct Error { + kind: Box, + report: Option, +} + +impl Error { + /// Instantiate a new Error + pub(crate) fn new(kind: Kind, report: Option) -> Error { + Error { + kind: Box::new(kind), + report, + } + } + + /// Return the Report of the migration cycle if any + pub fn report(&self) -> Option<&Report> { + self.report.as_ref() + } + + /// Return the kind of error occurred + pub fn kind(&self) -> &Kind { + &self.kind + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.kind) + } +} + +impl std::error::Error for Error { + 
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + self.kind.source() + } +} + +/// Enum listing possible errors from Refinery. +#[derive(Debug, TError)] +pub enum Kind { + /// An Error from an invalid file name migration + #[error("migration name must be in the format V{{number}}__{{name}}")] + InvalidName, + /// An Error from an invalid version on a file name migration + #[error("migration version must be a valid integer")] + InvalidVersion, + /// An Error from a repeated version, migration version numbers must be unique + #[error("migration {0} is repeated, migration versions must be unique")] + RepeatedVersion(Migration), + /// An Error from an divergent version, the applied version is different to the filesystem one + #[error("applied migration {0} is different than filesystem one {1}")] + DivergentVersion(Migration, Migration), + /// An Error from an divergent version, the applied version is missing on the filesystem + #[error("migration {0} is missing from the filesystem")] + MissingVersion(Migration), + /// An Error from an invalid migrations path location + #[error("invalid migrations path {0}, {1}")] + InvalidMigrationPath(PathBuf, std::io::Error), + /// An Error parsing refinery Config + #[error("Error parsing config: {0}")] + ConfigError(String), + /// An Error from an underlying database connection Error + #[error("`{0}`, `{1}`")] + Connection(String, #[source] Box), + /// An Error from an invalid migration file (not UTF-8 etc) + #[error("invalid migration file at path {0}, {1}")] + InvalidMigrationFile(PathBuf, std::io::Error), +} + +// Helper trait for adding custom messages and applied migrations to Connection error's. 
+pub trait WrapMigrationError { + fn migration_err(self, msg: &str, report: Option<&[Migration]>) -> Result; +} + +impl WrapMigrationError for Result +where + E: std::error::Error + Send + Sync + 'static, +{ + fn migration_err(self, msg: &str, applied_migrations: Option<&[Migration]>) -> Result { + match self { + Ok(report) => Ok(report), + Err(err) => Err(Error { + kind: Box::new(Kind::Connection(msg.into(), Box::new(err))), + report: applied_migrations.map(|am| Report::new(am.to_vec())), + }), + } + } +} diff --git a/trailbase-refinery/src/lib.rs b/trailbase-refinery/src/lib.rs new file mode 100644 index 00000000..275cf1cd --- /dev/null +++ b/trailbase-refinery/src/lib.rs @@ -0,0 +1,15 @@ +mod drivers; +pub mod error; +mod runner; +pub mod traits; +mod util; + +pub use crate::error::Error; +pub use crate::runner::{Migration, Report, Runner, Target}; +pub use crate::traits::r#async::AsyncMigrate; +pub use crate::traits::sync::Migrate; +pub use crate::util::{ + MigrationType, find_migration_files, load_sql_migrations, parse_migration_name, +}; + +pub use rusqlite; diff --git a/trailbase-refinery/src/runner.rs b/trailbase-refinery/src/runner.rs new file mode 100644 index 00000000..04f7ff17 --- /dev/null +++ b/trailbase-refinery/src/runner.rs @@ -0,0 +1,452 @@ +use siphasher::sip::SipHasher13; +use time::OffsetDateTime; + +use log::error; +use std::cmp::Ordering; +use std::collections::VecDeque; +use std::fmt; +use std::hash::{Hash, Hasher}; + +use crate::traits::{DEFAULT_MIGRATION_TABLE_NAME, sync::migrate as sync_migrate}; +use crate::util::parse_migration_name; +use crate::{AsyncMigrate, Error, Migrate}; +use std::fmt::Formatter; + +/// An enum set that represents the type of the Migration +#[derive(Clone, PartialEq)] +pub enum Type { + Versioned, + Unversioned, +} + +impl fmt::Display for Type { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let version_type = match self { + Type::Versioned => "V", + Type::Unversioned => "U", + }; + write!(f, "{}", 
version_type) + } +} + +impl fmt::Debug for Type { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let version_type = match self { + Type::Versioned => "Versioned", + Type::Unversioned => "Unversioned", + }; + write!(f, "{}", version_type) + } +} + +/// An enum set that represents the target version up to which refinery should migrate, it is used +/// by [Runner] +#[derive(Clone, Copy, Debug)] +pub enum Target { + Latest, + Version(u32), + Fake, + FakeVersion(u32), +} + +// an Enum set that represents the state of the migration: Applied on the database, +// or Unapplied yet to be applied on the database +#[derive(Clone, Debug)] +enum State { + Applied, + Unapplied, +} + +/// Represents a schema migration to be run on the database, +/// this struct is used by the [`embed_migrations!`] macro to gather migration files +/// and shouldn't be needed by the user +/// +/// [`embed_migrations!`]: macro.embed_migrations.html +#[derive(Clone, Debug)] +pub struct Migration { + state: State, + name: String, + checksum: u64, + version: i32, + prefix: Type, + sql: Option, + applied_on: Option, +} + +impl Migration { + /// Create an unapplied migration, name and version are parsed from the input_name, + /// which must be named in the format (U|V){1}__{2}.rs where {1} represents the migration version + /// and {2} the name. + pub fn unapplied(input_name: &str, sql: &str) -> Result { + let (prefix, version, name) = parse_migration_name(input_name)?; + + // Previously, `std::collections::hash_map::DefaultHasher` was used + // to calculate the checksum and the implementation at that time + // was SipHasher13. However, that implementation is not guaranteed: + // > The internal algorithm is not specified, and so it and its + // > hashes should not be relied upon over releases. + // We now explicitly use SipHasher13 to both remain compatible with + // existing migrations and prevent breaking from possible future + // changes to `DefaultHasher`. 
+ let mut hasher = SipHasher13::new(); + name.hash(&mut hasher); + version.hash(&mut hasher); + sql.hash(&mut hasher); + let checksum = hasher.finish(); + + Ok(Migration { + state: State::Unapplied, + name, + version, + prefix, + sql: Some(sql.into()), + applied_on: None, + checksum, + }) + } + + // Create a migration from an applied migration on the database + pub fn applied( + version: i32, + name: String, + applied_on: OffsetDateTime, + checksum: u64, + ) -> Migration { + Migration { + state: State::Applied, + name, + checksum, + version, + // applied migrations are always versioned + prefix: Type::Versioned, + sql: None, + applied_on: Some(applied_on), + } + } + + // convert the Unapplied into an Applied Migration + pub fn set_applied(&mut self) { + self.applied_on = Some(OffsetDateTime::now_utc()); + self.state = State::Applied; + } + + // Get migration sql content + pub fn sql(&self) -> Option<&str> { + self.sql.as_deref() + } + + /// Get the Migration version + pub fn version(&self) -> u32 { + self.version as u32 + } + + /// Get the Prefix + pub fn prefix(&self) -> &Type { + &self.prefix + } + + /// Get the Migration Name + pub fn name(&self) -> &str { + &self.name + } + + /// Get the timestamp from when the Migration was applied. `None` when unapplied. + /// Migrations returned from Runner::get_migrations() will always have `None`. + pub fn applied_on(&self) -> Option<&OffsetDateTime> { + self.applied_on.as_ref() + } + + /// Get the Migration checksum. 
Checksum is formed from the name version and sql of the Migration + pub fn checksum(&self) -> u64 { + self.checksum + } +} + +impl fmt::Display for Migration { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "{}{}__{}", self.prefix, self.version, self.name) + } +} + +impl Eq for Migration {} + +impl PartialEq for Migration { + fn eq(&self, other: &Migration) -> bool { + self.version == other.version && self.name == other.name && self.checksum() == other.checksum() + } +} + +impl Ord for Migration { + fn cmp(&self, other: &Migration) -> Ordering { + self.version.cmp(&other.version) + } +} + +impl PartialOrd for Migration { + fn partial_cmp(&self, other: &Migration) -> Option { + Some(self.cmp(other)) + } +} + +/// Struct that represents the report of the migration cycle, +/// a `Report` instance is returned by the [`Runner::run`] and [`Runner::run_async`] methods +/// via [`Result`]``, on case of an [`Error`] during a migration, you can access the +/// `Report` with [`Error.report`] +/// +/// [`Error`]: struct.Error.html +/// [`Runner::run`]: struct.Runner.html#method.run +/// [`Runner::run_async`]: struct.Runner.html#method.run_async +/// [`Result`]: https://doc.rust-lang.org/std/result/enum.Result.html +/// [`Error.report`]: struct.Error.html#method.report +#[derive(Clone, Debug)] +pub struct Report { + applied_migrations: Vec, +} + +impl Report { + /// Instantiate a new Report + pub(crate) fn new(applied_migrations: Vec) -> Report { + Report { applied_migrations } + } + + /// Retrieves the list of applied `Migration` of the migration cycle + pub fn applied_migrations(&self) -> &Vec { + &self.applied_migrations + } +} + +/// Struct that represents the entrypoint to run the migrations, +/// an instance of this struct is returned by the [`embed_migrations!`] macro. 
+/// `Runner` should not need to be instantiated manually +/// +/// [`embed_migrations!`]: macro.embed_migrations.html +pub struct Runner { + grouped: bool, + abort_divergent: bool, + abort_missing: bool, + migrations: Vec, + target: Target, + migration_table_name: String, +} + +impl Runner { + /// instantiate a new Runner + pub fn new(migrations: &[Migration]) -> Runner { + Runner { + grouped: false, + target: Target::Latest, + abort_divergent: true, + abort_missing: true, + migrations: migrations.to_vec(), + migration_table_name: DEFAULT_MIGRATION_TABLE_NAME.into(), + } + } + + /// Get the gathered migrations. + pub fn get_migrations(&self) -> &Vec { + &self.migrations + } + + /// Set the target version up to which refinery should migrate, Latest migrates to the latest + /// version available Version migrates to a user provided version, a Version with a higher + /// version than the latest will be ignored, and Fake doesn't actually run any migration, just + /// creates and updates refinery's schema migration table by default this is set to Latest + pub fn set_target(self, target: Target) -> Runner { + Runner { target, ..self } + } + + /// Set true if all migrations should be grouped and run in a single transaction. + /// by default this is set to false, each migration runs on their own transaction + /// + /// # Note + /// + /// set_grouped won't probably work on MySQL Databases as MySQL lacks support for transactions + /// around schema alteration operations, meaning that if a migration fails to apply you will + /// have to manually unpick the changes in order to try again (it’s impossible to roll back to an + /// earlier point). + pub fn set_grouped(self, grouped: bool) -> Runner { + Runner { grouped, ..self } + } + + /// Set true if migration process should abort if divergent migrations are found + /// i.e. applied migrations with the same version but different name or checksum from the ones on + /// the filesystem. 
by default this is set to true + pub fn set_abort_divergent(self, abort_divergent: bool) -> Runner { + Runner { + abort_divergent, + ..self + } + } + + /// Set true if migration process should abort if missing migrations are found + /// i.e. applied migrations that are not found on the filesystem, + /// or migrations found on filesystem with a version inferior to the last one applied but not + /// applied. by default this is set to true + pub fn set_abort_missing(self, abort_missing: bool) -> Runner { + Runner { + abort_missing, + ..self + } + } + + /// Queries the database for the last applied migration, returns None if there aren't applied + /// Migrations + pub fn get_last_applied_migration(&self, conn: &'_ mut C) -> Result, Error> + where + C: Migrate, + { + Migrate::get_last_applied_migration(conn, &self.migration_table_name) + } + + /// Queries the database asynchronously for the last applied migration, returns None if there + /// aren't applied Migrations + pub async fn get_last_applied_migration_async( + &self, + conn: &mut C, + ) -> Result, Error> + where + C: AsyncMigrate + Send, + { + AsyncMigrate::get_last_applied_migration(conn, &self.migration_table_name).await + } + + /// Queries the database for all previous applied migrations + pub fn get_applied_migrations(&self, conn: &'_ mut C) -> Result, Error> + where + C: Migrate, + { + Migrate::get_applied_migrations(conn, &self.migration_table_name) + } + + /// Queries the database asynchronously for all previous applied migrations + pub async fn get_applied_migrations_async(&self, conn: &mut C) -> Result, Error> + where + C: AsyncMigrate + Send, + { + AsyncMigrate::get_applied_migrations(conn, &self.migration_table_name).await + } + + /// Set the table name to use for the migrations table. The default name is + /// `refinery_schema_history` + /// + /// ### Warning + /// Changing this can be disastrous for your database. 
You should verify that the migrations table + /// has the same name as the name you specify here, if this is changed on an existing project. + /// + /// # Panics + /// + /// If the provided `migration_table_name` is empty + pub fn set_migration_table_name>(&mut self, migration_table_name: S) -> &mut Self { + if migration_table_name.as_ref().is_empty() { + panic!("Migration table name must not be empty"); + } + + self.migration_table_name = migration_table_name.as_ref().to_string(); + self + } + + /// Creates an iterator over pending migrations, applying each before returning + /// the result from `next()`. If a migration fails, the iterator will return that + /// result and further calls to `next()` will return `None`. + pub fn run_iter( + self, + connection: &mut C, + ) -> impl Iterator> + '_ + where + C: Migrate, + { + RunIterator::new(self, connection) + } + + /// Runs the Migrations in the supplied database connection + pub fn run(&self, connection: &mut C) -> Result + where + C: Migrate, + { + Migrate::migrate( + connection, + &self.migrations, + self.abort_divergent, + self.abort_missing, + self.grouped, + self.target, + &self.migration_table_name, + ) + } + + /// Runs the Migrations asynchronously in the supplied database connection + pub async fn run_async(&self, connection: &mut C) -> Result + where + C: AsyncMigrate + Send, + { + AsyncMigrate::migrate( + connection, + &self.migrations, + self.abort_divergent, + self.abort_missing, + self.grouped, + self.target, + &self.migration_table_name, + ) + .await + } +} + +pub struct RunIterator<'a, C> { + connection: &'a mut C, + target: Target, + migration_table_name: String, + items: VecDeque, + failed: bool, +} +impl<'a, C> RunIterator<'a, C> +where + C: Migrate, +{ + pub(crate) fn new(runner: Runner, connection: &'a mut C) -> RunIterator<'a, C> { + RunIterator { + items: VecDeque::from( + Migrate::get_unapplied_migrations( + connection, + &runner.migrations, + runner.abort_divergent, + runner.abort_missing, + 
&runner.migration_table_name, + ) + .unwrap(), + ), + connection, + target: runner.target, + migration_table_name: runner.migration_table_name.clone(), + failed: false, + } + } +} +impl Iterator for RunIterator<'_, C> +where + C: Migrate, +{ + type Item = Result; + + fn next(&mut self) -> Option { + match self.failed { + true => None, + false => self.items.pop_front().and_then(|migration| { + sync_migrate( + self.connection, + vec![migration], + self.target, + &self.migration_table_name, + false, + ) + .map(|r| r.applied_migrations.first().cloned()) + .map_err(|e| { + error!("migration failed: {e:?}"); + self.failed = true; + e + }) + .transpose() + }), + } + } +} diff --git a/trailbase-refinery/src/traits/async.rs b/trailbase-refinery/src/traits/async.rs new file mode 100644 index 00000000..b10aef48 --- /dev/null +++ b/trailbase-refinery/src/traits/async.rs @@ -0,0 +1,197 @@ +use crate::error::WrapMigrationError; +use crate::traits::{ + ASSERT_MIGRATIONS_TABLE_QUERY, GET_APPLIED_MIGRATIONS_QUERY, GET_LAST_APPLIED_MIGRATION_QUERY, + insert_migration_query, verify_migrations, +}; +use crate::{Error, Migration, Report, Target}; + +use async_trait::async_trait; +use std::string::ToString; + +#[async_trait] +pub trait AsyncTransaction { + type Error: std::error::Error + Send + Sync + 'static; + + async fn execute<'a, T: Iterator + Send>( + &mut self, + queries: T, + ) -> Result; +} + +#[async_trait] +pub trait AsyncQuery: AsyncTransaction { + async fn query(&mut self, query: &str) -> Result; +} + +async fn migrate( + transaction: &mut T, + migrations: Vec, + target: Target, + migration_table_name: &str, +) -> Result { + let mut applied_migrations = vec![]; + + for mut migration in migrations.into_iter() { + if let Target::Version(input_target) = target { + if input_target < migration.version() { + log::info!( + "stopping at migration: {}, due to user option", + input_target + ); + break; + } + } + + log::info!("applying migration: {}", migration); + 
migration.set_applied(); + let update_query = insert_migration_query(&migration, migration_table_name); + transaction + .execute( + [ + migration.sql().as_ref().expect("sql must be Some!"), + update_query.as_str(), + ] + .into_iter(), + ) + .await + .migration_err( + &format!("error applying migration {}", migration), + Some(&applied_migrations), + )?; + applied_migrations.push(migration); + } + Ok(Report::new(applied_migrations)) +} + +async fn migrate_grouped( + transaction: &mut T, + migrations: Vec, + target: Target, + migration_table_name: &str, +) -> Result { + let mut grouped_migrations = Vec::new(); + let mut applied_migrations = Vec::new(); + + for mut migration in migrations.into_iter() { + if let Target::Version(input_target) | Target::FakeVersion(input_target) = target { + if input_target < migration.version() { + break; + } + } + + migration.set_applied(); + let query = insert_migration_query(&migration, migration_table_name); + + let sql = migration.sql().expect("sql must be Some!").to_string(); + + // If Target is Fake, we only update schema migrations table + if !matches!(target, Target::Fake | Target::FakeVersion(_)) { + applied_migrations.push(migration); + grouped_migrations.push(sql); + } + grouped_migrations.push(query); + } + + match target { + Target::Fake | Target::FakeVersion(_) => { + log::info!("not going to apply any migration as fake flag is enabled"); + } + Target::Latest | Target::Version(_) => { + log::info!( + "going to apply batch migrations in single transaction: {:#?}", + applied_migrations.iter().map(ToString::to_string) + ); + } + }; + + if let Target::Version(input_target) = target { + log::info!( + "stopping at migration: {}, due to user option", + input_target + ); + } + + let refs = grouped_migrations.iter().map(AsRef::as_ref); + + transaction + .execute(refs) + .await + .migration_err("error applying migrations", None)?; + + Ok(Report::new(applied_migrations)) +} + +#[async_trait] +pub trait AsyncMigrate: AsyncQuery> 
+where + Self: Sized, +{ + // Needed cause some database vendors like Mssql have a non sql standard way of checking the + // migrations table + fn assert_migrations_table_query(migration_table_name: &str) -> String { + ASSERT_MIGRATIONS_TABLE_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name) + } + + async fn get_last_applied_migration( + &mut self, + migration_table_name: &str, + ) -> Result, Error> { + let mut migrations = self + .query( + &GET_LAST_APPLIED_MIGRATION_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name), + ) + .await + .migration_err("error getting last applied migration", None)?; + + Ok(migrations.pop()) + } + + async fn get_applied_migrations( + &mut self, + migration_table_name: &str, + ) -> Result, Error> { + let migrations = self + .query(&GET_APPLIED_MIGRATIONS_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name)) + .await + .migration_err("error getting applied migrations", None)?; + + Ok(migrations) + } + + async fn migrate( + &mut self, + migrations: &[Migration], + abort_divergent: bool, + abort_missing: bool, + grouped: bool, + target: Target, + migration_table_name: &str, + ) -> Result { + self + .execute([Self::assert_migrations_table_query(migration_table_name).as_str()].into_iter()) + .await + .migration_err("error asserting migrations table", None)?; + + let applied_migrations = self + .query(&GET_APPLIED_MIGRATIONS_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name)) + .await + .migration_err("error getting current schema version", None)?; + + let migrations = verify_migrations( + applied_migrations, + migrations.to_vec(), + abort_divergent, + abort_missing, + )?; + + if migrations.is_empty() { + log::info!("no migrations to apply"); + } + + if grouped || matches!(target, Target::Fake | Target::FakeVersion(_)) { + migrate_grouped(self, migrations, target, migration_table_name).await + } else { + migrate(self, migrations, target, migration_table_name).await + } + } +} diff --git 
a/trailbase-refinery/src/traits/mod.rs b/trailbase-refinery/src/traits/mod.rs new file mode 100644 index 00000000..6e5f0758 --- /dev/null +++ b/trailbase-refinery/src/traits/mod.rs @@ -0,0 +1,309 @@ +use time::format_description::well_known::Rfc3339; + +pub mod r#async; +pub mod sync; + +use crate::runner::Type; +use crate::{Error, Migration, error::Kind}; + +// Verifies applied and to be applied migrations returning Error if: +// - `abort_divergent` is true and there are applied migrations with a different name and checksum +// but same version as a migration to be applied. +// - `abort_missing` is true and there are applied migrations that are missing on the file system +// - there are repeated migrations with the same version to be applied +pub(crate) fn verify_migrations( + applied: Vec, + mut migrations: Vec, + abort_divergent: bool, + abort_missing: bool, +) -> Result, Error> { + migrations.sort(); + + for app in &applied { + // iterate applied migrations on database and assert all migrations + // applied on database exist on the file system and have the same checksum + if migrations + .binary_search_by(|m| m.version().cmp(&app.version())) + .is_err() + { + if abort_missing { + return Err(Error::new(Kind::MissingVersion(app.clone()), None)); + } else { + log::warn!(target: "trailbase_refinery::traits::missing", "migration {} is missing from the filesystem", app); + } + } + } + + let stale = |migration: &Migration| { + applied + .last() + .map(|latest| latest.version() >= migration.version()) + .unwrap_or(false) + }; + + let mut to_be_applied = Vec::new(); + // iterate all migration files found on file system and assert that there are no migrations + missing: migrations whose version is lower than the current version on the database, yet + were not applied. 
select to be applied all migrations with version greater than current + for migration in migrations { + if let Some(app) = applied + .iter() + .find(|app| app.version() == migration.version()) + { + if *app == migration { + continue; + } + + if abort_divergent { + return Err(Error::new( + Kind::DivergentVersion(app.clone(), migration.clone()), + None, + )); + } + + log::warn!( + target: "trailbase_refinery::traits::divergent", + "applied migration {app} is different than filesystem one {migration} => skipping {migration}", + ); + continue; + } + + if to_be_applied.contains(&migration) { + return Err(Error::new(Kind::RepeatedVersion(migration), None)); + } + + if migration.prefix() == &Type::Versioned && stale(&migration) { + if abort_missing { + return Err(Error::new(Kind::MissingVersion(migration), None)); + } + + log::error!(target: "trailbase_refinery::traits::missing", "found strictly versioned, not applied migration on file system => skipping: {migration}"); + continue; + } + + to_be_applied.push(migration); + } + + // with these two iterations we both assert that all migrations found on the database + exist on the file system and have the same checksum, and all migrations found + // on the file system are either on the database, or greater than the current, and therefore going + to be applied + Ok(to_be_applied) +} + +pub(crate) fn insert_migration_query(migration: &Migration, migration_table_name: &str) -> String { + format!( + "INSERT INTO {} (version, name, applied_on, checksum) VALUES ({}, '{}', '{}', '{}')", + // safe to call unwrap as we just converted it to applied, and we are sure it can be formatted + // according to RFC 3339 + migration_table_name, + migration.version(), + migration.name(), + migration.applied_on().unwrap().format(&Rfc3339).unwrap(), + migration.checksum() + ) +} + +pub(crate) const ASSERT_MIGRATIONS_TABLE_QUERY: &str = + "CREATE TABLE IF NOT EXISTS %MIGRATION_TABLE_NAME%( + version INT4 PRIMARY KEY, + name VARCHAR(255), + 
applied_on VARCHAR(255), + checksum VARCHAR(255));"; + +pub(crate) const GET_APPLIED_MIGRATIONS_QUERY: &str = "SELECT version, name, applied_on, checksum \ + FROM %MIGRATION_TABLE_NAME% ORDER BY version ASC;"; + +pub(crate) const GET_LAST_APPLIED_MIGRATION_QUERY: &str = + "SELECT version, name, applied_on, checksum + FROM %MIGRATION_TABLE_NAME% WHERE version=(SELECT MAX(version) from %MIGRATION_TABLE_NAME%)"; + +pub(crate) const DEFAULT_MIGRATION_TABLE_NAME: &str = "_schema_history"; + +#[cfg(test)] +mod tests { + use super::{Kind, Migration, verify_migrations}; + use std::include_str; + + fn get_migrations() -> Vec { + let migration1 = Migration::unapplied( + "V1__initial.sql", + "CREATE TABLE persons (id int, name varchar(255), city varchar(255));", + ) + .unwrap(); + + let migration2 = Migration::unapplied( + "V2__add_cars_and_motos_table.sql", + include_str!("../../tests/migrations/V1-2/V2__add_cars_and_motos_table.sql"), + ) + .unwrap(); + + let migration3 = Migration::unapplied( + "V3__add_brand_to_cars_table", + include_str!("../../tests/migrations/V3/V3__add_brand_to_cars_table.sql"), + ) + .unwrap(); + + let migration4 = Migration::unapplied( + "V4__add_year_field_to_cars", + "ALTER TABLE cars ADD year INTEGER;", + ) + .unwrap(); + + vec![migration1, migration2, migration3, migration4] + } + + #[test] + fn verify_migrations_returns_all_migrations_if_applied_are_empty() { + let migrations = get_migrations(); + let applied: Vec = Vec::new(); + let result = verify_migrations(applied, migrations.clone(), true, true).unwrap(); + assert_eq!(migrations, result); + } + + #[test] + fn verify_migrations_returns_unapplied() { + let migrations = get_migrations(); + let applied: Vec = vec![ + migrations[0].clone(), + migrations[1].clone(), + migrations[2].clone(), + ]; + let remaining = vec![migrations[3].clone()]; + let result = verify_migrations(applied, migrations, true, true).unwrap(); + assert_eq!(remaining, result); + } + + #[test] + fn 
verify_migrations_fails_on_divergent() { + let migrations = get_migrations(); + let applied: Vec = vec![ + migrations[0].clone(), + migrations[1].clone(), + Migration::unapplied( + "V3__add_brand_to_cars_tableeee", + include_str!("../../tests/migrations/V3/V3__add_brand_to_cars_table.sql"), + ) + .unwrap(), + ]; + + let migration = migrations[2].clone(); + let err = verify_migrations(applied, migrations, true, true).unwrap_err(); + match err.kind() { + Kind::DivergentVersion(applied, divergent) => { + assert_eq!(&migration, divergent); + assert_eq!("add_brand_to_cars_tableeee", applied.name()); + } + _ => panic!("failed test"), + } + } + + #[test] + fn verify_migrations_doesnt_fail_on_divergent() { + let migrations = get_migrations(); + let applied: Vec = vec![ + migrations[0].clone(), + migrations[1].clone(), + Migration::unapplied( + "V3__add_brand_to_cars_tableeee", + include_str!("../../tests/migrations/V3/V3__add_brand_to_cars_table.sql"), + ) + .unwrap(), + ]; + let remaining = vec![migrations[3].clone()]; + let result = verify_migrations(applied, migrations, false, true).unwrap(); + assert_eq!(remaining, result); + } + + #[test] + fn verify_migrations_fails_on_missing_on_applied() { + let migrations = get_migrations(); + let applied: Vec = vec![migrations[0].clone(), migrations[2].clone()]; + let migration = migrations[1].clone(); + let err = verify_migrations(applied, migrations, true, true).unwrap_err(); + match err.kind() { + Kind::MissingVersion(missing) => { + assert_eq!(&migration, missing); + } + _ => panic!("failed test"), + } + } + + #[test] + fn verify_migrations_fails_on_missing_on_filesystem() { + let mut migrations = get_migrations(); + let applied: Vec = vec![ + migrations[0].clone(), + migrations[1].clone(), + migrations[2].clone(), + ]; + let migration = migrations.remove(1); + let err = verify_migrations(applied, migrations, true, true).unwrap_err(); + match err.kind() { + Kind::MissingVersion(missing) => { + assert_eq!(&migration, missing); 
+ } + _ => panic!("failed test"), + } + } + + #[test] + fn verify_migrations_doesnt_fail_on_missing_on_applied() { + let migrations = get_migrations(); + let applied: Vec = vec![migrations[0].clone(), migrations[2].clone()]; + let remaining = vec![migrations[3].clone()]; + let result = verify_migrations(applied, migrations, true, false).unwrap(); + assert_eq!(remaining, result); + } + + #[test] + fn verify_migrations_doesnt_fail_on_missing_on_filesystem() { + let mut migrations = get_migrations(); + let applied: Vec = vec![ + migrations[0].clone(), + migrations[1].clone(), + migrations[2].clone(), + ]; + migrations.remove(1); + let remaining = vec![migrations[2].clone()]; + let result = verify_migrations(applied, migrations, true, false).unwrap(); + assert_eq!(remaining, result); + } + + #[test] + fn verify_migrations_checks_unversioned_out_of_order_doesnt_fail() { + let mut migrations = get_migrations(); + migrations.push( + Migration::unapplied( + "U0__merge_out_of_order", + include_str!("../../tests/migrations_unversioned/U0__merge_out_of_order.sql"), + ) + .unwrap(), + ); + let applied: Vec = vec![ + migrations[0].clone(), + migrations[1].clone(), + migrations[2].clone(), + migrations[3].clone(), + ]; + + let remaining = vec![migrations[4].clone()]; + let result = verify_migrations(applied, migrations, true, true).unwrap(); + assert_eq!(remaining, result); + } + + #[test] + fn verify_migrations_fails_on_repeated_migration() { + let mut migrations = get_migrations(); + let repeated = migrations[0].clone(); + migrations.push(repeated.clone()); + + let err = verify_migrations(vec![], migrations, false, true).unwrap_err(); + match err.kind() { + Kind::RepeatedVersion(m) => { + assert_eq!(m, &repeated); + } + _ => panic!("failed test"), + } + } +} diff --git a/trailbase-refinery/src/traits/sync.rs b/trailbase-refinery/src/traits/sync.rs new file mode 100644 index 00000000..9af5f61c --- /dev/null +++ b/trailbase-refinery/src/traits/sync.rs @@ -0,0 +1,177 @@ +use 
std::ops::Deref; + +use crate::error::WrapMigrationError; +use crate::traits::{ + ASSERT_MIGRATIONS_TABLE_QUERY, GET_APPLIED_MIGRATIONS_QUERY, GET_LAST_APPLIED_MIGRATION_QUERY, + insert_migration_query, verify_migrations, +}; +use crate::{Error, Migration, Report, Target}; + +pub trait Transaction { + type Error: std::error::Error + Send + Sync + 'static; + + fn execute<'a, T: Iterator>(&mut self, queries: T) -> Result; +} + +pub trait Query: Transaction { + fn query(&mut self, query: &str) -> Result; +} + +pub fn migrate( + transaction: &mut T, + migrations: Vec, + target: Target, + migration_table_name: &str, + grouped: bool, +) -> Result { + let mut migration_batch = Vec::new(); + let mut applied_migrations = Vec::new(); + + for mut migration in migrations.into_iter() { + if let Target::Version(input_target) | Target::FakeVersion(input_target) = target { + if input_target < migration.version() { + log::info!( + "stopping at migration: {}, due to user option", + input_target + ); + break; + } + } + + log::info!("applying migration: {}", migration); + migration.set_applied(); + let insert_migration = insert_migration_query(&migration, migration_table_name); + let migration_sql = migration.sql().expect("sql must be Some!").to_string(); + + // If Target is Fake, we only update schema migrations table + if !matches!(target, Target::Fake | Target::FakeVersion(_)) { + applied_migrations.push(migration); + migration_batch.push(migration_sql); + } + migration_batch.push(insert_migration); + } + + match (target, grouped) { + (Target::Fake | Target::FakeVersion(_), _) => { + log::info!("not going to apply any migration as fake flag is enabled"); + } + (Target::Latest | Target::Version(_), true) => { + log::info!( + "going to apply batch migrations in single transaction: {:#?}", + applied_migrations.iter().map(ToString::to_string) + ); + } + (Target::Latest | Target::Version(_), false) => { + log::info!( + "preparing to apply {} migrations: {:#?}", + 
applied_migrations.len(), + applied_migrations.iter().map(ToString::to_string) + ); + } + }; + + if grouped { + transaction + .execute(migration_batch.iter().map(Deref::deref)) + .migration_err("error applying migrations", None)?; + } else { + for (i, update) in migration_batch.into_iter().enumerate() { + transaction + .execute([update.as_str()].into_iter()) + .migration_err("error applying update", Some(&applied_migrations[0..i / 2]))?; + } + } + + Ok(Report::new(applied_migrations)) +} + +pub trait Migrate: Query> +where + Self: Sized, +{ + fn assert_migrations_table(&mut self, migration_table_name: &str) -> Result { + // Needed because some database vendors like Mssql have a non sql standard way of checking the + // migrations table, though in this case it's just to be consistent with the async trait + // `AsyncMigrate` + self + .execute( + [ASSERT_MIGRATIONS_TABLE_QUERY + .replace("%MIGRATION_TABLE_NAME%", migration_table_name) + .as_str()] + .into_iter(), + ) + .migration_err("error asserting migrations table", None) + } + + fn get_last_applied_migration( + &mut self, + migration_table_name: &str, + ) -> Result, Error> { + let mut migrations = self + .query( + &GET_LAST_APPLIED_MIGRATION_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name), + ) + .migration_err("error getting last applied migration", None)?; + + Ok(migrations.pop()) + } + + fn get_applied_migrations( + &mut self, + migration_table_name: &str, + ) -> Result, Error> { + let migrations = self + .query(&GET_APPLIED_MIGRATIONS_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name)) + .migration_err("error getting applied migrations", None)?; + + Ok(migrations) + } + + fn get_unapplied_migrations( + &mut self, + migrations: &[Migration], + abort_divergent: bool, + abort_missing: bool, + migration_table_name: &str, + ) -> Result, Error> { + self.assert_migrations_table(migration_table_name)?; + + let applied_migrations = self.get_applied_migrations(migration_table_name)?; + + let 
migrations = verify_migrations( + applied_migrations, + migrations.to_vec(), + abort_divergent, + abort_missing, + )?; + + if migrations.is_empty() { + log::info!("no migrations to apply"); + } + + Ok(migrations) + } + + fn migrate( + &mut self, + migrations: &[Migration], + abort_divergent: bool, + abort_missing: bool, + grouped: bool, + target: Target, + migration_table_name: &str, + ) -> Result { + let migrations = self.get_unapplied_migrations( + migrations, + abort_divergent, + abort_missing, + migration_table_name, + )?; + + if grouped || matches!(target, Target::Fake | Target::FakeVersion(_)) { + migrate(self, migrations, target, migration_table_name, true) + } else { + migrate(self, migrations, target, migration_table_name, false) + } + } +} diff --git a/trailbase-refinery/src/util.rs b/trailbase-refinery/src/util.rs new file mode 100644 index 00000000..2566d432 --- /dev/null +++ b/trailbase-refinery/src/util.rs @@ -0,0 +1,242 @@ +use crate::Migration; +use crate::error::{Error, Kind}; +use crate::runner::Type; +use regex::Regex; +use std::ffi::OsStr; +use std::path::{Path, PathBuf}; +use std::sync::OnceLock; +use walkdir::{DirEntry, WalkDir}; + +const STEM_RE: &str = r"^([U|V])(\d+(?:\.\d+)?)__(\w+)"; + +/// Matches the stem of a migration file. +fn file_stem_re() -> &'static Regex { + static RE: OnceLock = OnceLock::new(); + RE.get_or_init(|| Regex::new(STEM_RE).unwrap()) +} + +/// Matches the stem + extension of a SQL migration file. +fn file_re_sql() -> &'static Regex { + static RE: OnceLock = OnceLock::new(); + RE.get_or_init(|| Regex::new([STEM_RE, r"\.sql$"].concat().as_str()).unwrap()) +} + +/// Matches the stem + extension of any migration file. 
+fn file_re_all() -> &'static Regex { + static RE: OnceLock = OnceLock::new(); + RE.get_or_init(|| Regex::new([STEM_RE, r"\.(rs|sql)$"].concat().as_str()).unwrap()) +} + +/// enum containing the migration types used to search for migrations +/// either just .sql files or both .sql and .rs +pub enum MigrationType { + All, + Sql, +} + +impl MigrationType { + fn file_match_re(&self) -> &'static Regex { + match self { + MigrationType::All => file_re_all(), + MigrationType::Sql => file_re_sql(), + } + } +} + +/// Parse a migration filename stem into a prefix, version, and name. +pub fn parse_migration_name(name: &str) -> Result<(Type, i32, String), Error> { + let captures = file_stem_re() + .captures(name) + .filter(|caps| caps.len() == 4) + .ok_or_else(|| Error::new(Kind::InvalidName, None))?; + let version: i32 = captures[2] + .parse() + .map_err(|_| Error::new(Kind::InvalidVersion, None))?; + + let name: String = (&captures[3]).into(); + let prefix = match &captures[1] { + "V" => Type::Versioned, + "U" => Type::Unversioned, + _ => unreachable!(), + }; + + Ok((prefix, version, name)) +} + +/// find migrations on file system recursively across directories given a location and +/// [MigrationType] +pub fn find_migration_files( + location: impl AsRef, + migration_type: MigrationType, +) -> Result, Error> { + let location: &Path = location.as_ref(); + let location = location.canonicalize().map_err(|err| { + Error::new( + Kind::InvalidMigrationPath(location.to_path_buf(), err), + None, + ) + })?; + + let re = migration_type.file_match_re(); + let file_paths = WalkDir::new(location) + .into_iter() + .filter_map(Result::ok) + .map(DirEntry::into_path) + // filter by migration file regex + .filter( + move |entry| match entry.file_name().and_then(OsStr::to_str) { + Some(_) if entry.is_dir() => false, + Some(file_name) if re.is_match(file_name) => true, + Some(file_name) => { + log::warn!( + "File \"{}\" does not adhere to the migration naming convention. 
Migrations must be named in the format [U|V]{{1}}__{{2}}.sql or [U|V]{{1}}__{{2}}.rs, where {{1}} represents the migration version and {{2}} the name.", + file_name + ); + false + } + None => false, + }, + ); + + Ok(file_paths) +} + +/// Loads SQL migrations from a path. This enables dynamic migration discovery, as opposed to +/// embedding. The resulting collection is ordered by version. +pub fn load_sql_migrations(location: impl AsRef) -> Result, Error> { + let migration_files = find_migration_files(location, MigrationType::Sql)?; + + let mut migrations = vec![]; + + for path in migration_files { + let sql = std::fs::read_to_string(path.as_path()).map_err(|e| { + let path = path.to_owned(); + let kind = match e.kind() { + std::io::ErrorKind::NotFound => Kind::InvalidMigrationPath(path, e), + _ => Kind::InvalidMigrationFile(path, e), + }; + + Error::new(kind, None) + })?; + + //safe to call unwrap as find_migration_filenames returns canonical paths + let filename = path + .file_stem() + .and_then(|file| file.to_os_string().into_string().ok()) + .unwrap(); + + let migration = Migration::unapplied(&filename, &sql)?; + migrations.push(migration); + } + + migrations.sort(); + Ok(migrations) +} + +#[cfg(test)] +mod tests { + use super::{MigrationType, find_migration_files, load_sql_migrations}; + use std::fs; + use std::path::PathBuf; + use tempfile::TempDir; + + #[test] + fn finds_mod_migrations() { + let tmp_dir = TempDir::new().unwrap(); + let migrations_dir = tmp_dir.path().join("migrations"); + fs::create_dir(&migrations_dir).unwrap(); + let sql1 = migrations_dir.join("V1__first.rs"); + fs::File::create(&sql1).unwrap(); + let sql2 = migrations_dir.join("V2__second.rs"); + fs::File::create(&sql2).unwrap(); + + let mut mods: Vec = find_migration_files(migrations_dir, MigrationType::All) + .unwrap() + .collect(); + mods.sort(); + assert_eq!(sql1.canonicalize().unwrap(), mods[0]); + assert_eq!(sql2.canonicalize().unwrap(), mods[1]); + } + + #[test] + fn 
ignores_mod_files_without_migration_regex_match() { + let tmp_dir = TempDir::new().unwrap(); + let migrations_dir = tmp_dir.path().join("migrations"); + fs::create_dir(&migrations_dir).unwrap(); + let sql1 = migrations_dir.join("V1first.rs"); + fs::File::create(sql1).unwrap(); + let sql2 = migrations_dir.join("V2second.rs"); + fs::File::create(sql2).unwrap(); + + let mut mods = find_migration_files(migrations_dir, MigrationType::All).unwrap(); + assert!(mods.next().is_none()); + } + + #[test] + fn finds_sql_migrations() { + let tmp_dir = TempDir::new().unwrap(); + let migrations_dir = tmp_dir.path().join("migrations"); + fs::create_dir(&migrations_dir).unwrap(); + let sql1 = migrations_dir.join("V1__first.sql"); + fs::File::create(&sql1).unwrap(); + let sql2 = migrations_dir.join("V2__second.sql"); + fs::File::create(&sql2).unwrap(); + + let mut mods: Vec = find_migration_files(migrations_dir, MigrationType::All) + .unwrap() + .collect(); + mods.sort(); + assert_eq!(sql1.canonicalize().unwrap(), mods[0]); + assert_eq!(sql2.canonicalize().unwrap(), mods[1]); + } + + #[test] + fn finds_unversioned_migrations() { + let tmp_dir = TempDir::new().unwrap(); + let migrations_dir = tmp_dir.path().join("migrations"); + fs::create_dir(&migrations_dir).unwrap(); + let sql1 = migrations_dir.join("U1__first.sql"); + fs::File::create(&sql1).unwrap(); + let sql2 = migrations_dir.join("U2__second.sql"); + fs::File::create(&sql2).unwrap(); + + let mut mods: Vec = find_migration_files(migrations_dir, MigrationType::All) + .unwrap() + .collect(); + mods.sort(); + assert_eq!(sql1.canonicalize().unwrap(), mods[0]); + assert_eq!(sql2.canonicalize().unwrap(), mods[1]); + } + + #[test] + fn ignores_sql_files_without_migration_regex_match() { + let tmp_dir = TempDir::new().unwrap(); + let migrations_dir = tmp_dir.path().join("migrations"); + fs::create_dir(&migrations_dir).unwrap(); + let sql1 = migrations_dir.join("V1first.sql"); + fs::File::create(sql1).unwrap(); + let sql2 = 
migrations_dir.join("V2second.sql"); + fs::File::create(sql2).unwrap(); + + let mut mods = find_migration_files(migrations_dir, MigrationType::All).unwrap(); + assert!(mods.next().is_none()); + } + + #[test] + fn loads_migrations_from_path() { + let tmp_dir = TempDir::new().unwrap(); + let migrations_dir = tmp_dir.path().join("migrations"); + fs::create_dir(&migrations_dir).unwrap(); + let sql1 = migrations_dir.join("V1__first.sql"); + fs::File::create(&sql1).unwrap(); + let sql2 = migrations_dir.join("V2__second.sql"); + fs::File::create(&sql2).unwrap(); + let rs3 = migrations_dir.join("V3__third.rs"); + fs::File::create(&rs3).unwrap(); + + let migrations = load_sql_migrations(migrations_dir).unwrap(); + assert_eq!(migrations.len(), 2); + assert_eq!(&migrations[0].to_string(), "V1__first"); + assert_eq!(&migrations[1].to_string(), "V2__second"); + } +} diff --git a/trailbase-refinery/tests/migrations/V1-2/V1__initial.rs b/trailbase-refinery/tests/migrations/V1-2/V1__initial.rs new file mode 100644 index 00000000..3a84e506 --- /dev/null +++ b/trailbase-refinery/tests/migrations/V1-2/V1__initial.rs @@ -0,0 +1,15 @@ +use barrel::{types, Migration}; + +use crate::Sql; + +pub fn migration() -> String { + let mut m = Migration::new(); + + m.create_table("persons", |t| { + t.add_column("id", types::primary()); + t.add_column("name", types::varchar(255)); + t.add_column("city", types::varchar(255)); + }); + + m.make::() +} diff --git a/trailbase-refinery/tests/migrations/V1-2/V2__add_cars_and_motos_table.sql b/trailbase-refinery/tests/migrations/V1-2/V2__add_cars_and_motos_table.sql new file mode 100644 index 00000000..13390796 --- /dev/null +++ b/trailbase-refinery/tests/migrations/V1-2/V2__add_cars_and_motos_table.sql @@ -0,0 +1,8 @@ +CREATE TABLE cars ( + id int, + name varchar(255) +); +CREATE TABLE motos ( + id int, + name varchar(255) +); diff --git a/trailbase-refinery/tests/migrations/V3/V3__add_brand_to_cars_table.sql 
b/trailbase-refinery/tests/migrations/V3/V3__add_brand_to_cars_table.sql new file mode 100644 index 00000000..c689a6aa --- /dev/null +++ b/trailbase-refinery/tests/migrations/V3/V3__add_brand_to_cars_table.sql @@ -0,0 +1,2 @@ +ALTER TABLE cars +ADD brand varchar(255); diff --git a/trailbase-refinery/tests/migrations/V4__add_year_to_motos_table.rs b/trailbase-refinery/tests/migrations/V4__add_year_to_motos_table.rs new file mode 100644 index 00000000..57623df9 --- /dev/null +++ b/trailbase-refinery/tests/migrations/V4__add_year_to_motos_table.rs @@ -0,0 +1,13 @@ +use barrel::{types, Migration}; + +use crate::Sql; + +pub fn migration() -> String { + let mut m = Migration::new(); + + m.change_table("motos", |t| { + t.add_column("brand", types::varchar(255).nullable(true)); + }); + + m.make::() +} diff --git a/trailbase-refinery/tests/migrations_unversioned/U0__merge_out_of_order.sql b/trailbase-refinery/tests/migrations_unversioned/U0__merge_out_of_order.sql new file mode 100644 index 00000000..15321d97 --- /dev/null +++ b/trailbase-refinery/tests/migrations_unversioned/U0__merge_out_of_order.sql @@ -0,0 +1,5 @@ +CREATE TABLE person_finance ( + id int, + person_id int, + amount int +); diff --git a/trailbase-sqlite/Cargo.toml b/trailbase-sqlite/Cargo.toml index b910ecc0..e1ad744c 100644 --- a/trailbase-sqlite/Cargo.toml +++ b/trailbase-sqlite/Cargo.toml @@ -15,7 +15,7 @@ path = "benches/benchmark.rs" harness = false [dependencies] -base64 = { version = "0.22.1", default-features = false } +base64 = { version = "0.22.1", default-features = false, features = ["alloc"] } crossbeam-channel = "0.5.13" kanal = "0.1.1" log = { version = "^0.4.21", default-features = false } diff --git a/vendor/refinery b/vendor/refinery deleted file mode 160000 index d8c35587..00000000 --- a/vendor/refinery +++ /dev/null @@ -1 +0,0 @@ -Subproject commit d8c35587867d8d24d28297c530a7417fa921a4b0