Replace the vendored refinery with a proper fork, stripping all the unused bits and bobs.

Sebastian Jeltsch
2025-05-27 16:08:19 +02:00
parent 01feb22e44
commit 5152cf9000
32 changed files with 1701 additions and 68 deletions
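The gist of the change: the upstream `embed_migrations!` proc-macro (and with it the whole `trailbase-refinery-macros` crate) is gone; migration files are now embedded with `rust-embed` and handed to the forked runner directly. A minimal sketch of the new pattern, assuming a `migrations/` directory of `V{n}__{name}.sql` files relative to the crate root:

use trailbase_refinery::{Migration, Runner};

#[derive(rust_embed::RustEmbed)]
#[folder = "migrations"]
struct Migrations;

fn migration_runner() -> Runner {
  // Turn every embedded SQL file into an unapplied Migration.
  let migrations: Vec<Migration> = Migrations::iter()
    .filter_map(|filename| {
      let file = Migrations::get(&filename)?;
      let sql = String::from_utf8_lossy(&file.data).to_string();
      Some(Migration::unapplied(&filename, &sql).expect("valid migration file"))
    })
    .collect();

  // Mirrors the call sites below: divergent migrations warn instead of aborting.
  let mut runner = Runner::new(&migrations).set_abort_divergent(false);
  runner.set_migration_table_name("_schema_history");
  runner
}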

.gitmodules vendored

@@ -1,6 +1,3 @@
[submodule "vendor/refinery"]
path = vendor/refinery
url = https://github.com/trailbaseio/refinery.git
[submodule "vendor/sqlean/bundled/sqlean"]
path = vendor/sqlean/bundled/sqlean
url = https://github.com/trailbaseio/sqlean

Cargo.lock generated

@@ -6821,6 +6821,7 @@ dependencies = [
"regex",
"reqwest",
"rusqlite",
"rust-embed",
"schemars",
"serde",
"serde_json",
@@ -6845,8 +6846,7 @@ dependencies = [
"trailbase-extension",
"trailbase-js",
"trailbase-qs",
"trailbase-refinery-core",
"trailbase-refinery-macros",
"trailbase-refinery",
"trailbase-schema",
"trailbase-sqlite",
"ts-rs",
@@ -6871,13 +6871,12 @@ dependencies = [
"log",
"once_cell",
"rusqlite",
"rust-embed",
"serde",
"serde_json",
"thiserror 2.0.12",
"tokio",
"trailbase-apalis",
"trailbase-refinery-core",
"trailbase-refinery-macros",
"trailbase-refinery",
"trailbase-sqlite",
]
@@ -6997,33 +6996,24 @@ dependencies = [
]
[[package]]
name = "trailbase-refinery-core"
version = "0.8.16"
name = "trailbase-refinery"
version = "0.1.0"
dependencies = [
"async-trait",
"cfg-if",
"futures",
"log",
"regex",
"rusqlite",
"siphasher 1.0.1",
"tempfile",
"thiserror 1.0.69",
"time",
"tokio",
"url",
"walkdir",
]
[[package]]
name = "trailbase-refinery-macros"
version = "0.8.15"
dependencies = [
"heck 0.5.0",
"proc-macro2",
"quote",
"regex",
"syn 2.0.101",
"trailbase-refinery-core",
]
[[package]]
name = "trailbase-schema"
version = "0.1.0"


@@ -12,6 +12,7 @@ members = [
"trailbase-extension",
"trailbase-js",
"trailbase-qs",
"trailbase-refinery",
"trailbase-schema",
"trailbase-sqlite",
"vendor/sqlean",
@@ -28,10 +29,6 @@ default-members = [
"trailbase-schema",
"trailbase-sqlite",
]
exclude = [
"vendor/refinery",
"vendor/rustc_tools_util",
]
# https://doc.rust-lang.org/cargo/reference/profiles.html
[profile.release]
@@ -76,8 +73,7 @@ trailbase-sqlean = { path = "vendor/sqlean", version = "0.0.2" }
trailbase-extension = { path = "trailbase-extension", version = "0.2.0" }
trailbase-js = { path = "trailbase-js", version = "0.1.0" }
trailbase-qs = { path = "trailbase-qs", version = "0.1.0" }
trailbase-refinery-core = { path = "vendor/refinery/refinery_core", version = "0.8.16", default-features = false, features = ["rusqlite-bundled"] }
trailbase-refinery-macros = { path = "vendor/refinery/refinery_macros", version = "0.8.15" }
trailbase-refinery = { path = "trailbase-refinery", version = "0.1.0" }
trailbase-schema = { path = "trailbase-schema", version = "0.1.0" }
trailbase-sqlite = { path = "trailbase-sqlite", version = "0.2.0" }
trailbase = { path = "trailbase-core", version = "0.1.0" }


@@ -28,7 +28,7 @@ async fn hello_world_handler(State(state): State<CustomState>, user: Option<User
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
env_logger::init_from_env(
env_logger::Env::new()
.default_filter_or("info,refinery_core=warn,tracing::span=warn,swc_ecma_codegen=off"),
.default_filter_or("info,trailbase_refinery=warn,tracing::span=warn,swc_ecma_codegen=off"),
);
let Server {


@@ -21,13 +21,13 @@ chrono = { version = "0.4", features = ["serde"] }
futures = "0.3.30"
futures-lite = "2.3.0"
log = "0.4.21"
rusqlite.workspace = true
rusqlite = { workspace = true }
rust-embed = { workspace = true }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
thiserror = "2.0.0"
tokio = { workspace = true }
trailbase-refinery-core = { workspace = true }
trailbase-refinery-macros = { workspace = true }
trailbase-refinery = { workspace = true }
trailbase-sqlite = { workspace = true }
[dev-dependencies]
@@ -36,7 +36,6 @@ apalis-core = { version = "0.7.0", default-features = false, features = [ "test-
email_address = "0.2.9"
once_cell = "1.19.0"
tokio = { workspace = true }
trailbase-apalis = { path = "." }
[package.metadata.docs.rs]
# defines the configuration attribute `docsrs`


@@ -63,8 +63,32 @@ impl<T, C> Clone for SqliteStorage<T, C> {
}
}
mod main {
trailbase_refinery_macros::embed_migrations!("migrations");
mod migrations {
use trailbase_refinery::{Migration, Runner};
#[derive(Clone, rust_embed::RustEmbed)]
#[folder = "migrations"]
struct Migrations;
fn load_migrations<T: rust_embed::RustEmbed>() -> Vec<Migration> {
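// Collect every embedded SQL file as an unapplied migration; ordering and divergence
// checks happen later in the Runner.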
let mut migrations = vec![];
for filename in T::iter() {
if let Some(file) = T::get(&filename) {
migrations.push(
Migration::unapplied(&filename, &String::from_utf8_lossy(&file.data)).expect("startup"),
)
}
}
return migrations;
}
pub(crate) fn migration_runner() -> Runner {
const MIGRATION_TABLE_NAME: &str = "_schema_history";
let migrations = load_migrations::<Migrations>();
let mut runner = trailbase_refinery::Runner::new(&migrations).set_abort_divergent(false);
runner.set_migration_table_name(MIGRATION_TABLE_NAME);
return runner;
}
}
impl SqliteStorage<()> {
@@ -81,7 +105,7 @@ impl SqliteStorage<()> {
)
.await?;
let runner = main::migrations::runner();
let runner = migrations::migration_runner();
let _report = conn
.call(move |conn| {


@@ -23,8 +23,7 @@ fn init_logger(dev: bool) {
// SWC is very spammy in debug builds and complains about source maps when compiling
// typescript to javascript. Since we don't care about source maps and didn't find a better
// option to mute the errors, turn it off in debug builds.
const DEFAULT: &str =
"info,refinery_core=warn,trailbase_refinery_core=warn,tracing::span=warn,swc_ecma_codegen=off";
const DEFAULT: &str = "info,trailbase_refinery=warn,tracing::span=warn,swc_ecma_codegen=off";
env_logger::Builder::from_env(if dev {
env_logger::Env::new().default_filter_or(format!("{DEFAULT},trailbase=debug"))


@@ -64,6 +64,7 @@ rand = "^0.9.0"
regex = "1.11.0"
reqwest = { version = "0.12.8", default-features = false, features = ["rustls-tls", "json"] }
rusqlite = { workspace = true }
rust-embed = { workspace = true }
serde = { version = "^1.0.203", features = ["derive"] }
serde_json = "^1.0.117"
serde_path_to_error = "0.1.16"
@@ -85,8 +86,7 @@ trailbase-assets = { workspace = true }
trailbase-extension = { workspace = true }
trailbase-js = { workspace = true, optional = true }
trailbase-qs = { workspace = true }
trailbase-refinery-core = { workspace = true }
trailbase-refinery-macros = { workspace = true }
trailbase-refinery = { workspace = true }
trailbase-schema = { workspace = true }
trailbase-sqlite = { workspace = true }
ts-rs = { version = "10", features = ["uuid-impl", "serde-json-impl"] }


@@ -32,7 +32,7 @@ pub enum AdminError {
#[error("Table lookup error: {0}")]
TableLookup(#[from] crate::schema_metadata::SchemaLookupError),
#[error("DB Migration error: {0}")]
Migration(#[from] trailbase_refinery_core::Error),
Migration(#[from] trailbase_refinery::Error),
#[error("SQL -> Json error: {0}")]
Json(#[from] trailbase_sqlite::rows::JsonError),
#[error("Schema error: {0}")]


@@ -54,7 +54,7 @@ mod tests {
#[tokio::test]
async fn test_user_creation_and_deletion() {
let _ = env_logger::try_init_from_env(
env_logger::Env::new().default_filter_or("info,refinery_core=warn"),
env_logger::Env::new().default_filter_or("info,trailbase_refinery=warn"),
);
let mailer = TestAsyncSmtpTransport::new();


@@ -307,9 +307,10 @@ pub async fn test_state(options: Option<TestStateOptions>) -> anyhow::Result<App
use crate::config::proto::{OAuthProviderConfig, OAuthProviderId};
use crate::config::validate_config;
let _ = env_logger::try_init_from_env(env_logger::Env::new().default_filter_or(
"info,refinery_core=warn,trailbase_refinery_core=warn,log::span=warn,swc_ecma_codegen=off",
));
let _ = env_logger::try_init_from_env(
env_logger::Env::new()
.default_filter_or("info,trailbase_refinery=warn,log::span=warn,swc_ecma_codegen=off"),
);
let temp_dir = temp_dir::TempDir::new()?;
tokio::fs::create_dir_all(temp_dir.child("uploads")).await?;


@@ -28,7 +28,7 @@ use crate::extract::Either;
#[tokio::test]
async fn test_auth_registration_reset_and_change_email() {
let _ = env_logger::try_init_from_env(
env_logger::Env::new().default_filter_or("info,refinery_core=warn"),
env_logger::Env::new().default_filter_or("info,trailbase_refinery=warn"),
);
let mailer = TestAsyncSmtpTransport::new();


@@ -17,7 +17,7 @@ pub enum ConnectionError {
#[error("TB SQLite error: {0}")]
TbSqlite(#[from] trailbase_sqlite::Error),
#[error("Migration error: {0}")]
Migration(#[from] trailbase_refinery_core::Error),
Migration(#[from] trailbase_refinery::Error),
}
/// Initializes a new SQLite Connection with all the default extensions, migrations and settings


@@ -2,14 +2,7 @@ use lazy_static::lazy_static;
use log::*;
use parking_lot::Mutex;
use std::path::PathBuf;
use trailbase_refinery_core::Migration;
mod main {
trailbase_refinery_macros::embed_migrations!("migrations/main");
}
mod logs {
trailbase_refinery_macros::embed_migrations!("migrations/logs");
}
use trailbase_refinery::Migration;
const MIGRATION_TABLE_NAME: &str = "_schema_history";
@@ -35,29 +28,41 @@ pub fn new_unique_migration_filename(suffix: &str) -> String {
return format!("U{timestamp}__{suffix}.sql");
}
pub(crate) fn new_migration_runner(migrations: &[Migration]) -> trailbase_refinery_core::Runner {
pub(crate) fn new_migration_runner(migrations: &[Migration]) -> trailbase_refinery::Runner {
// NOTE: divergent migrations are migrations with the same version but a different name. That
// said, `set_abort_divergent` is not a viable way for us to handle collisions (e.g. in tests),
// since setting it to false will prevent the migration from failing, but divergent migrations
// are quietly dropped on the floor and not applied. That's not ok.
let mut runner = trailbase_refinery_core::Runner::new(migrations).set_abort_divergent(false);
let mut runner = trailbase_refinery::Runner::new(migrations).set_abort_divergent(false);
runner.set_migration_table_name(MIGRATION_TABLE_NAME);
return runner;
}
fn load_migrations<T: rust_embed::RustEmbed>() -> Vec<Migration> {
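// Each embedded filename encodes prefix, version and name (e.g. `U{timestamp}__{suffix}.sql`
// as produced by new_unique_migration_filename above).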
let mut migrations = vec![];
for filename in T::iter() {
if let Some(file) = T::get(&filename) {
migrations.push(
Migration::unapplied(&filename, &String::from_utf8_lossy(&file.data)).expect("startup"),
)
}
}
return migrations;
}
pub(crate) fn apply_main_migrations(
conn: &mut rusqlite::Connection,
user_migrations_path: Option<PathBuf>,
) -> Result<bool, trailbase_refinery_core::Error> {
) -> Result<bool, trailbase_refinery::Error> {
let all_migrations = {
let mut migrations: Vec<Migration> = vec![];
let system_migrations_runner = main::migrations::runner();
migrations.extend(system_migrations_runner.get_migrations().iter().cloned());
let system_migrations: Vec<Migration> = load_migrations::<MainMigrations>();
migrations.extend(system_migrations);
if let Some(path) = user_migrations_path {
// NOTE: refinery has a bug where it will name-check the directory and write a warning... :/.
let user_migrations = trailbase_refinery_core::load_sql_migrations(path)?;
let user_migrations = trailbase_refinery::load_sql_migrations(path)?;
migrations.extend(user_migrations);
}
@@ -92,8 +97,10 @@ pub(crate) fn apply_main_migrations(
pub(crate) fn apply_logs_migrations(
logs_conn: &mut rusqlite::Connection,
) -> Result<(), trailbase_refinery_core::Error> {
let mut runner = logs::migrations::runner();
) -> Result<(), trailbase_refinery::Error> {
let migrations = load_migrations::<LogsMigrations>();
let mut runner = new_migration_runner(&migrations);
runner.set_migration_table_name(MIGRATION_TABLE_NAME);
let report = runner.run(logs_conn).map_err(|err| {
@@ -109,3 +116,11 @@ pub(crate) fn apply_logs_migrations(
return Ok(());
}
#[derive(Clone, rust_embed::RustEmbed)]
#[folder = "migrations/main"]
struct MainMigrations;
#[derive(Clone, rust_embed::RustEmbed)]
#[folder = "migrations/logs"]
struct LogsMigrations;


@@ -13,7 +13,7 @@ pub enum TransactionError {
#[error("IO error: {0}")]
IO(#[from] std::io::Error),
#[error("Migration error: {0}")]
Migration(#[from] trailbase_refinery_core::Error),
Migration(#[from] trailbase_refinery::Error),
#[error("File error: {0}")]
File(String),
}
@@ -35,7 +35,7 @@ impl TransactionLog {
conn: &trailbase_sqlite::Connection,
migration_path: impl AsRef<Path>,
filename_suffix: &str,
) -> Result<trailbase_refinery_core::Report, TransactionError> {
) -> Result<trailbase_refinery::Report, TransactionError> {
let filename = migrations::new_unique_migration_filename(filename_suffix);
let stem = Path::new(&filename)
.file_stem()
@@ -68,7 +68,7 @@ impl TransactionLog {
)
};
let migrations = vec![trailbase_refinery_core::Migration::unapplied(&stem, &sql)?];
let migrations = vec![trailbase_refinery::Migration::unapplied(&stem, &sql)?];
let runner = migrations::new_migration_runner(&migrations).set_abort_missing(false);
let report = conn


@@ -0,0 +1,26 @@
[package]
name = "trailbase-refinery"
version = "0.1.0"
authors = ["TrailBase <contact@trailbase.io>"]
description = "Fork of Refinery's refinery-core and refinery-macros crates"
license = "OSL-3.0"
repository = "https://github.com/trailbaseio/trailbase"
edition = "2024"
readme = "../README.md"
[dependencies]
async-trait = "0.1"
cfg-if = "1.0"
futures = { version = "0.3.16", optional = true, features = ["async-await"] }
log = "0.4"
regex = "1"
rusqlite = { workspace = true }
siphasher = "1.0"
thiserror = "1"
time = { version = "0.3.5", features = ["parsing", "formatting"] }
tokio = { workspace = true }
url = "2.0"
walkdir = "2.3.1"
[dev-dependencies]
tempfile = "3.1.0"


@@ -0,0 +1 @@
pub mod rusqlite;


@@ -0,0 +1,56 @@
use crate::Migration;
use crate::traits::sync::{Migrate, Query, Transaction};
use rusqlite::{Connection as RqlConnection, Error as RqlError};
use time::OffsetDateTime;
use time::format_description::well_known::Rfc3339;
fn query_applied_migrations(
transaction: &RqlConnection,
query: &str,
) -> Result<Vec<Migration>, RqlError> {
let mut stmt = transaction.prepare(query)?;
let mut rows = stmt.query([])?;
let mut applied = Vec::new();
while let Some(row) = rows.next()? {
let version = row.get(0)?;
let applied_on: String = row.get(2)?;
// Safe to call unwrap, as we stored it in RFC3339 format on the database
let applied_on = OffsetDateTime::parse(&applied_on, &Rfc3339).unwrap();
let checksum: String = row.get(3)?;
applied.push(Migration::applied(
version,
row.get(1)?,
applied_on,
checksum
.parse::<u64>()
.expect("checksum must be a valid u64"),
));
}
Ok(applied)
}
impl Transaction for RqlConnection {
type Error = RqlError;
fn execute<'a, T: Iterator<Item = &'a str>>(&mut self, queries: T) -> Result<usize, Self::Error> {
let transaction = self.transaction()?;
let mut count = 0;
for query in queries {
transaction.execute_batch(query)?;
count += 1;
}
transaction.commit()?;
Ok(count)
}
}
impl Query<Vec<Migration>> for RqlConnection {
fn query(&mut self, query: &str) -> Result<Vec<Migration>, Self::Error> {
let transaction = self.transaction()?;
let applied = query_applied_migrations(&transaction, query)?;
transaction.commit()?;
Ok(applied)
}
}
impl Migrate for RqlConnection {}


@@ -0,0 +1,95 @@
use crate::{Migration, Report};
use std::fmt;
use std::path::PathBuf;
use thiserror::Error as TError;
/// An Error that occurred during a migration cycle
#[derive(Debug)]
pub struct Error {
kind: Box<Kind>,
report: Option<Report>,
}
impl Error {
/// Instantiate a new Error
pub(crate) fn new(kind: Kind, report: Option<Report>) -> Error {
Error {
kind: Box::new(kind),
report,
}
}
/// Return the Report of the migration cycle if any
pub fn report(&self) -> Option<&Report> {
self.report.as_ref()
}
/// Return the kind of error that occurred
pub fn kind(&self) -> &Kind {
&self.kind
}
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.kind)
}
}
impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
self.kind.source()
}
}
/// Enum listing possible errors from Refinery.
#[derive(Debug, TError)]
pub enum Kind {
/// An Error from an invalid file name migration
#[error("migration name must be in the format V{{number}}__{{name}}")]
InvalidName,
/// An Error from an invalid version on a file name migration
#[error("migration version must be a valid integer")]
InvalidVersion,
/// An Error from a repeated version, migration version numbers must be unique
#[error("migration {0} is repeated, migration versions must be unique")]
RepeatedVersion(Migration),
/// An Error from a divergent version, the applied version is different from the filesystem one
#[error("applied migration {0} is different than filesystem one {1}")]
DivergentVersion(Migration, Migration),
/// An Error from a missing version, the applied version is missing from the filesystem
#[error("migration {0} is missing from the filesystem")]
MissingVersion(Migration),
/// An Error from an invalid migrations path location
#[error("invalid migrations path {0}, {1}")]
InvalidMigrationPath(PathBuf, std::io::Error),
/// An Error parsing refinery Config
#[error("Error parsing config: {0}")]
ConfigError(String),
/// An Error from an underlying database connection Error
#[error("`{0}`, `{1}`")]
Connection(String, #[source] Box<dyn std::error::Error + Sync + Send>),
/// An Error from an invalid migration file (not UTF-8 etc)
#[error("invalid migration file at path {0}, {1}")]
InvalidMigrationFile(PathBuf, std::io::Error),
}
// Helper trait for adding custom messages and applied migrations to Connection errors.
pub trait WrapMigrationError<T, E> {
fn migration_err(self, msg: &str, report: Option<&[Migration]>) -> Result<T, Error>;
}
impl<T, E> WrapMigrationError<T, E> for Result<T, E>
where
E: std::error::Error + Send + Sync + 'static,
{
fn migration_err(self, msg: &str, applied_migrations: Option<&[Migration]>) -> Result<T, Error> {
match self {
Ok(report) => Ok(report),
Err(err) => Err(Error {
kind: Box::new(Kind::Connection(msg.into(), Box::new(err))),
report: applied_migrations.map(|am| Report::new(am.to_vec())),
}),
}
}
}


@@ -0,0 +1,15 @@
mod drivers;
pub mod error;
mod runner;
pub mod traits;
mod util;
pub use crate::error::Error;
pub use crate::runner::{Migration, Report, Runner, Target};
pub use crate::traits::r#async::AsyncMigrate;
pub use crate::traits::sync::Migrate;
pub use crate::util::{
MigrationType, find_migration_files, load_sql_migrations, parse_migration_name,
};
pub use rusqlite;


@@ -0,0 +1,452 @@
use siphasher::sip::SipHasher13;
use time::OffsetDateTime;
use log::error;
use std::cmp::Ordering;
use std::collections::VecDeque;
use std::fmt;
use std::hash::{Hash, Hasher};
use crate::traits::{DEFAULT_MIGRATION_TABLE_NAME, sync::migrate as sync_migrate};
use crate::util::parse_migration_name;
use crate::{AsyncMigrate, Error, Migrate};
use std::fmt::Formatter;
/// An enum that represents the type of the Migration
#[derive(Clone, PartialEq)]
pub enum Type {
Versioned,
Unversioned,
}
impl fmt::Display for Type {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
let version_type = match self {
Type::Versioned => "V",
Type::Unversioned => "U",
};
write!(f, "{}", version_type)
}
}
impl fmt::Debug for Type {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
let version_type = match self {
Type::Versioned => "Versioned",
Type::Unversioned => "Unversioned",
};
write!(f, "{}", version_type)
}
}
/// An enum that represents the target version up to which refinery should migrate; it is used
/// by [Runner]
#[derive(Clone, Copy, Debug)]
pub enum Target {
Latest,
Version(u32),
Fake,
FakeVersion(u32),
}
// An enum that represents the state of the migration: Applied on the database,
// or Unapplied, i.e. not yet applied on the database
#[derive(Clone, Debug)]
enum State {
Applied,
Unapplied,
}
/// Represents a schema migration to be run on the database.
/// Upstream this struct was populated by the `embed_migrations!` macro; in this fork,
/// migrations are constructed directly via [`Migration::unapplied`] or loaded from disk with
/// [`load_sql_migrations`].
#[derive(Clone, Debug)]
pub struct Migration {
state: State,
name: String,
checksum: u64,
version: i32,
prefix: Type,
sql: Option<String>,
applied_on: Option<OffsetDateTime>,
}
impl Migration {
/// Create an unapplied migration, name and version are parsed from the input_name,
/// which must be in the format (U|V){1}__{2} where {1} represents the migration version
/// and {2} the name.
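///
/// A hypothetical example (illustrative names):
/// `Migration::unapplied("V1__initial", "CREATE TABLE t (id INTEGER);")` yields a versioned
/// migration with version 1 and name "initial".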
pub fn unapplied(input_name: &str, sql: &str) -> Result<Migration, Error> {
let (prefix, version, name) = parse_migration_name(input_name)?;
// Previously, `std::collections::hash_map::DefaultHasher` was used
// to calculate the checksum and the implementation at that time
// was SipHasher13. However, that implementation is not guaranteed:
// > The internal algorithm is not specified, and so it and its
// > hashes should not be relied upon over releases.
// We now explicitly use SipHasher13 to both remain compatible with
// existing migrations and prevent breaking from possible future
// changes to `DefaultHasher`.
let mut hasher = SipHasher13::new();
name.hash(&mut hasher);
version.hash(&mut hasher);
sql.hash(&mut hasher);
let checksum = hasher.finish();
Ok(Migration {
state: State::Unapplied,
name,
version,
prefix,
sql: Some(sql.into()),
applied_on: None,
checksum,
})
}
// Create a migration from an applied migration on the database
pub fn applied(
version: i32,
name: String,
applied_on: OffsetDateTime,
checksum: u64,
) -> Migration {
Migration {
state: State::Applied,
name,
checksum,
version,
// applied migrations are always versioned
prefix: Type::Versioned,
sql: None,
applied_on: Some(applied_on),
}
}
// Convert the Unapplied into an Applied Migration
pub fn set_applied(&mut self) {
self.applied_on = Some(OffsetDateTime::now_utc());
self.state = State::Applied;
}
// Get migration sql content
pub fn sql(&self) -> Option<&str> {
self.sql.as_deref()
}
/// Get the Migration version
pub fn version(&self) -> u32 {
self.version as u32
}
/// Get the Prefix
pub fn prefix(&self) -> &Type {
&self.prefix
}
/// Get the Migration Name
pub fn name(&self) -> &str {
&self.name
}
/// Get the timestamp from when the Migration was applied. `None` when unapplied.
/// Migrations returned from Runner::get_migrations() will always have `None`.
pub fn applied_on(&self) -> Option<&OffsetDateTime> {
self.applied_on.as_ref()
}
/// Get the Migration checksum. The checksum is formed from the name, version and sql of the Migration
pub fn checksum(&self) -> u64 {
self.checksum
}
}
impl fmt::Display for Migration {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(fmt, "{}{}__{}", self.prefix, self.version, self.name)
}
}
impl Eq for Migration {}
impl PartialEq for Migration {
fn eq(&self, other: &Migration) -> bool {
self.version == other.version && self.name == other.name && self.checksum() == other.checksum()
}
}
impl Ord for Migration {
fn cmp(&self, other: &Migration) -> Ordering {
self.version.cmp(&other.version)
}
}
impl PartialOrd for Migration {
fn partial_cmp(&self, other: &Migration) -> Option<Ordering> {
Some(self.cmp(other))
}
}
/// Struct that represents the report of the migration cycle,
/// a `Report` instance is returned by the [`Runner::run`] and [`Runner::run_async`] methods
/// via [`Result`]`<Report, Error>`; in case of an [`Error`] during a migration, you can access the
/// `Report` with [`Error.report`]
///
/// [`Error`]: struct.Error.html
/// [`Runner::run`]: struct.Runner.html#method.run
/// [`Runner::run_async`]: struct.Runner.html#method.run_async
/// [`Result`]: https://doc.rust-lang.org/std/result/enum.Result.html
/// [`Error.report`]: struct.Error.html#method.report
#[derive(Clone, Debug)]
pub struct Report {
applied_migrations: Vec<Migration>,
}
impl Report {
/// Instantiate a new Report
pub(crate) fn new(applied_migrations: Vec<Migration>) -> Report {
Report { applied_migrations }
}
/// Retrieves the list of applied `Migration` of the migration cycle
pub fn applied_migrations(&self) -> &Vec<Migration> {
&self.applied_migrations
}
}
/// Struct that represents the entrypoint to run the migrations;
/// in this fork an instance is created directly via [`Runner::new`] from a list of
/// [`Migration`]s (the upstream `embed_migrations!` macro was removed).
pub struct Runner {
grouped: bool,
abort_divergent: bool,
abort_missing: bool,
migrations: Vec<Migration>,
target: Target,
migration_table_name: String,
}
impl Runner {
/// Instantiate a new Runner
pub fn new(migrations: &[Migration]) -> Runner {
Runner {
grouped: false,
target: Target::Latest,
abort_divergent: true,
abort_missing: true,
migrations: migrations.to_vec(),
migration_table_name: DEFAULT_MIGRATION_TABLE_NAME.into(),
}
}
/// Get the gathered migrations.
pub fn get_migrations(&self) -> &Vec<Migration> {
&self.migrations
}
/// Set the target version up to which refinery should migrate: `Latest` migrates to the latest
/// available version; `Version` migrates to a user-provided version (a version higher than the
/// latest is ignored); `Fake` doesn't actually run any migration, it just creates and updates
/// refinery's schema migration table. By default this is set to `Latest`.
pub fn set_target(self, target: Target) -> Runner {
Runner { target, ..self }
}
/// Set true if all migrations should be grouped and run in a single transaction.
/// By default this is set to false; each migration runs in its own transaction.
///
/// # Note
///
/// set_grouped probably won't work on MySQL databases, as MySQL lacks support for transactions
/// around schema alteration operations, meaning that if a migration fails to apply you will
/// have to manually unpick the changes in order to try again (it's impossible to roll back to an
/// earlier point).
pub fn set_grouped(self, grouped: bool) -> Runner {
Runner { grouped, ..self }
}
/// Set true if the migration process should abort when divergent migrations are found,
/// i.e. applied migrations with the same version but a different name or checksum than the ones on
/// the filesystem. By default this is set to true.
pub fn set_abort_divergent(self, abort_divergent: bool) -> Runner {
Runner {
abort_divergent,
..self
}
}
/// Set true if the migration process should abort when missing migrations are found,
/// i.e. applied migrations that are not found on the filesystem,
/// or migrations found on the filesystem with a version lower than the last applied one but never
/// applied. By default this is set to true.
pub fn set_abort_missing(self, abort_missing: bool) -> Runner {
Runner {
abort_missing,
..self
}
}
/// Queries the database for the last applied migration; returns None if there are no applied
/// migrations
pub fn get_last_applied_migration<C>(&self, conn: &'_ mut C) -> Result<Option<Migration>, Error>
where
C: Migrate,
{
Migrate::get_last_applied_migration(conn, &self.migration_table_name)
}
/// Queries the database asynchronously for the last applied migration; returns None if there
/// are no applied migrations
pub async fn get_last_applied_migration_async<C>(
&self,
conn: &mut C,
) -> Result<Option<Migration>, Error>
where
C: AsyncMigrate + Send,
{
AsyncMigrate::get_last_applied_migration(conn, &self.migration_table_name).await
}
/// Queries the database for all previously applied migrations
pub fn get_applied_migrations<C>(&self, conn: &'_ mut C) -> Result<Vec<Migration>, Error>
where
C: Migrate,
{
Migrate::get_applied_migrations(conn, &self.migration_table_name)
}
/// Queries the database asynchronously for all previously applied migrations
pub async fn get_applied_migrations_async<C>(&self, conn: &mut C) -> Result<Vec<Migration>, Error>
where
C: AsyncMigrate + Send,
{
AsyncMigrate::get_applied_migrations(conn, &self.migration_table_name).await
}
/// Set the table name to use for the migrations table. The default name is
/// `_schema_history`
///
/// ### Warning
/// Changing this can be disastrous for your database. If this is changed on an existing
/// project, verify that the existing migrations table matches the name you specify here.
///
/// # Panics
///
/// If the provided `migration_table_name` is empty
pub fn set_migration_table_name<S: AsRef<str>>(&mut self, migration_table_name: S) -> &mut Self {
if migration_table_name.as_ref().is_empty() {
panic!("Migration table name must not be empty");
}
self.migration_table_name = migration_table_name.as_ref().to_string();
self
}
/// Creates an iterator over pending migrations, applying each before returning
/// the result from `next()`. If a migration fails, the iterator will return that
/// result and further calls to `next()` will return `None`.
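///
/// A hypothetical usage sketch (connection setup elided):
/// `for result in runner.run_iter(&mut conn) { let applied = result?; }`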
pub fn run_iter<C>(
self,
connection: &mut C,
) -> impl Iterator<Item = Result<Migration, Error>> + '_
where
C: Migrate,
{
RunIterator::new(self, connection)
}
/// Runs the Migrations in the supplied database connection
pub fn run<C>(&self, connection: &mut C) -> Result<Report, Error>
where
C: Migrate,
{
Migrate::migrate(
connection,
&self.migrations,
self.abort_divergent,
self.abort_missing,
self.grouped,
self.target,
&self.migration_table_name,
)
}
/// Runs the Migrations asynchronously in the supplied database connection
pub async fn run_async<C>(&self, connection: &mut C) -> Result<Report, Error>
where
C: AsyncMigrate + Send,
{
AsyncMigrate::migrate(
connection,
&self.migrations,
self.abort_divergent,
self.abort_missing,
self.grouped,
self.target,
&self.migration_table_name,
)
.await
}
}
pub struct RunIterator<'a, C> {
connection: &'a mut C,
target: Target,
migration_table_name: String,
items: VecDeque<Migration>,
failed: bool,
}
impl<'a, C> RunIterator<'a, C>
where
C: Migrate,
{
pub(crate) fn new(runner: Runner, connection: &'a mut C) -> RunIterator<'a, C> {
RunIterator {
items: VecDeque::from(
Migrate::get_unapplied_migrations(
connection,
&runner.migrations,
runner.abort_divergent,
runner.abort_missing,
&runner.migration_table_name,
)
.unwrap(),
),
connection,
target: runner.target,
migration_table_name: runner.migration_table_name.clone(),
failed: false,
}
}
}
impl<C> Iterator for RunIterator<'_, C>
where
C: Migrate,
{
type Item = Result<Migration, Error>;
fn next(&mut self) -> Option<Self::Item> {
match self.failed {
true => None,
false => self.items.pop_front().and_then(|migration| {
sync_migrate(
self.connection,
vec![migration],
self.target,
&self.migration_table_name,
false,
)
.map(|r| r.applied_migrations.first().cloned())
.map_err(|e| {
error!("migration failed: {e:?}");
self.failed = true;
e
})
.transpose()
}),
}
}
}


@@ -0,0 +1,197 @@
use crate::error::WrapMigrationError;
use crate::traits::{
ASSERT_MIGRATIONS_TABLE_QUERY, GET_APPLIED_MIGRATIONS_QUERY, GET_LAST_APPLIED_MIGRATION_QUERY,
insert_migration_query, verify_migrations,
};
use crate::{Error, Migration, Report, Target};
use async_trait::async_trait;
use std::string::ToString;
#[async_trait]
pub trait AsyncTransaction {
type Error: std::error::Error + Send + Sync + 'static;
async fn execute<'a, T: Iterator<Item = &'a str> + Send>(
&mut self,
queries: T,
) -> Result<usize, Self::Error>;
}
#[async_trait]
pub trait AsyncQuery<T>: AsyncTransaction {
async fn query(&mut self, query: &str) -> Result<T, Self::Error>;
}
async fn migrate<T: AsyncTransaction>(
transaction: &mut T,
migrations: Vec<Migration>,
target: Target,
migration_table_name: &str,
) -> Result<Report, Error> {
let mut applied_migrations = vec![];
for mut migration in migrations.into_iter() {
if let Target::Version(input_target) = target {
if input_target < migration.version() {
log::info!(
"stopping at migration: {}, due to user option",
input_target
);
break;
}
}
log::info!("applying migration: {}", migration);
migration.set_applied();
let update_query = insert_migration_query(&migration, migration_table_name);
transaction
.execute(
[
migration.sql().as_ref().expect("sql must be Some!"),
update_query.as_str(),
]
.into_iter(),
)
.await
.migration_err(
&format!("error applying migration {}", migration),
Some(&applied_migrations),
)?;
applied_migrations.push(migration);
}
Ok(Report::new(applied_migrations))
}
async fn migrate_grouped<T: AsyncTransaction>(
transaction: &mut T,
migrations: Vec<Migration>,
target: Target,
migration_table_name: &str,
) -> Result<Report, Error> {
let mut grouped_migrations = Vec::new();
let mut applied_migrations = Vec::new();
for mut migration in migrations.into_iter() {
if let Target::Version(input_target) | Target::FakeVersion(input_target) = target {
if input_target < migration.version() {
break;
}
}
migration.set_applied();
let query = insert_migration_query(&migration, migration_table_name);
let sql = migration.sql().expect("sql must be Some!").to_string();
// If Target is Fake, we only update the schema migrations table
if !matches!(target, Target::Fake | Target::FakeVersion(_)) {
applied_migrations.push(migration);
grouped_migrations.push(sql);
}
grouped_migrations.push(query);
}
match target {
Target::Fake | Target::FakeVersion(_) => {
log::info!("not going to apply any migration as fake flag is enabled");
}
Target::Latest | Target::Version(_) => {
log::info!(
"going to apply batch migrations in single transaction: {:#?}",
applied_migrations.iter().map(ToString::to_string)
);
}
};
if let Target::Version(input_target) = target {
log::info!(
"stopping at migration: {}, due to user option",
input_target
);
}
let refs = grouped_migrations.iter().map(AsRef::as_ref);
transaction
.execute(refs)
.await
.migration_err("error applying migrations", None)?;
Ok(Report::new(applied_migrations))
}
#[async_trait]
pub trait AsyncMigrate: AsyncQuery<Vec<Migration>>
where
Self: Sized,
{
// Needed because some database vendors like MSSQL have a non-SQL-standard way of checking the
// migrations table
fn assert_migrations_table_query(migration_table_name: &str) -> String {
ASSERT_MIGRATIONS_TABLE_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name)
}
async fn get_last_applied_migration(
&mut self,
migration_table_name: &str,
) -> Result<Option<Migration>, Error> {
let mut migrations = self
.query(
&GET_LAST_APPLIED_MIGRATION_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name),
)
.await
.migration_err("error getting last applied migration", None)?;
Ok(migrations.pop())
}
async fn get_applied_migrations(
&mut self,
migration_table_name: &str,
) -> Result<Vec<Migration>, Error> {
let migrations = self
.query(&GET_APPLIED_MIGRATIONS_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name))
.await
.migration_err("error getting applied migrations", None)?;
Ok(migrations)
}
async fn migrate(
&mut self,
migrations: &[Migration],
abort_divergent: bool,
abort_missing: bool,
grouped: bool,
target: Target,
migration_table_name: &str,
) -> Result<Report, Error> {
self
.execute([Self::assert_migrations_table_query(migration_table_name).as_str()].into_iter())
.await
.migration_err("error asserting migrations table", None)?;
let applied_migrations = self
.query(&GET_APPLIED_MIGRATIONS_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name))
.await
.migration_err("error getting current schema version", None)?;
let migrations = verify_migrations(
applied_migrations,
migrations.to_vec(),
abort_divergent,
abort_missing,
)?;
if migrations.is_empty() {
log::info!("no migrations to apply");
}
if grouped || matches!(target, Target::Fake | Target::FakeVersion(_)) {
migrate_grouped(self, migrations, target, migration_table_name).await
} else {
migrate(self, migrations, target, migration_table_name).await
}
}
}


@@ -0,0 +1,309 @@
use time::format_description::well_known::Rfc3339;
pub mod r#async;
pub mod sync;
use crate::runner::Type;
use crate::{Error, Migration, error::Kind};
// Verifies applied and to-be-applied migrations, returning an Error if:
// - `abort_divergent` is true and there are applied migrations with a different name and checksum
// but same version as a migration to be applied.
// - `abort_missing` is true and there are applied migrations that are missing on the file system
// - there are repeated migrations with the same version to be applied
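// Illustrative example: applied = [V1, V2], filesystem = [V1, V2', V3] where V2' has a
// different checksum: with `abort_divergent` this is a DivergentVersion error; without it,
// V2' is skipped with a warning and only V3 is returned.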
pub(crate) fn verify_migrations(
applied: Vec<Migration>,
mut migrations: Vec<Migration>,
abort_divergent: bool,
abort_missing: bool,
) -> Result<Vec<Migration>, Error> {
migrations.sort();
for app in &applied {
// iterate applied migrations on database and assert all migrations
// applied on database exist on the file system and have the same checksum
if migrations
.binary_search_by(|m| m.version().cmp(&app.version()))
.is_err()
{
if abort_missing {
return Err(Error::new(Kind::MissingVersion(app.clone()), None));
} else {
log::warn!(target: "trailbase_refinery::traits::missing", "migration {} is missing from the filesystem", app);
}
}
}
let stale = |migration: &Migration| {
applied
.last()
.map(|latest| latest.version() >= migration.version())
.unwrap_or(false)
};
let mut to_be_applied = Vec::new();
// iterate all migration files found on the file system and assert that no migrations are
// missing, i.e. migrations whose version is lower than the current version on the database yet
// were never applied; select for application all migrations with a version greater than the current
for migration in migrations {
if let Some(app) = applied
.iter()
.find(|app| app.version() == migration.version())
{
if *app == migration {
continue;
}
if abort_divergent {
return Err(Error::new(
Kind::DivergentVersion(app.clone(), migration.clone()),
None,
));
}
log::warn!(
target: "trailbase_refinery::traits::divergent",
"applied migration {app} is different than filesystem one {migration} => skipping {migration}",
);
continue;
}
if to_be_applied.contains(&migration) {
return Err(Error::new(Kind::RepeatedVersion(migration), None));
}
if migration.prefix() == &Type::Versioned && stale(&migration) {
if abort_missing {
return Err(Error::new(Kind::MissingVersion(migration), None));
}
log::error!(target: "trailbase_refinery::traits::missing", "found strictly versioned, not applied migration on file system => skipping: {migration}");
continue;
}
to_be_applied.push(migration);
}
// with these two iterations we both assert that all migrations found on the database
// exist on the file system and have the same checksum, and all migrations found
// on the file system are either on the database, or greater than the current, and therefore going
// to be applied
Ok(to_be_applied)
}
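// Renders e.g. (illustrative values):
//   INSERT INTO _schema_history (version, name, applied_on, checksum)
//   VALUES (1, 'initial', '2025-05-27T14:08:19Z', '1234567890')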
pub(crate) fn insert_migration_query(migration: &Migration, migration_table_name: &str) -> String {
format!(
"INSERT INTO {} (version, name, applied_on, checksum) VALUES ({}, '{}', '{}', '{}')",
// safe to call unwrap as we just converted it to applied, and we are sure it can be formatted
// according to RFC 3339
migration_table_name,
migration.version(),
migration.name(),
migration.applied_on().unwrap().format(&Rfc3339).unwrap(),
migration.checksum()
)
}
pub(crate) const ASSERT_MIGRATIONS_TABLE_QUERY: &str =
"CREATE TABLE IF NOT EXISTS %MIGRATION_TABLE_NAME%(
version INT4 PRIMARY KEY,
name VARCHAR(255),
applied_on VARCHAR(255),
checksum VARCHAR(255));";
pub(crate) const GET_APPLIED_MIGRATIONS_QUERY: &str = "SELECT version, name, applied_on, checksum \
FROM %MIGRATION_TABLE_NAME% ORDER BY version ASC;";
pub(crate) const GET_LAST_APPLIED_MIGRATION_QUERY: &str =
"SELECT version, name, applied_on, checksum
FROM %MIGRATION_TABLE_NAME% WHERE version=(SELECT MAX(version) from %MIGRATION_TABLE_NAME%)";
pub(crate) const DEFAULT_MIGRATION_TABLE_NAME: &str = "_schema_history";
#[cfg(test)]
mod tests {
use super::{Kind, Migration, verify_migrations};
use std::include_str;
fn get_migrations() -> Vec<Migration> {
let migration1 = Migration::unapplied(
"V1__initial.sql",
"CREATE TABLE persons (id int, name varchar(255), city varchar(255));",
)
.unwrap();
let migration2 = Migration::unapplied(
"V2__add_cars_and_motos_table.sql",
include_str!("../../tests/migrations/V1-2/V2__add_cars_and_motos_table.sql"),
)
.unwrap();
let migration3 = Migration::unapplied(
"V3__add_brand_to_cars_table",
include_str!("../../tests/migrations/V3/V3__add_brand_to_cars_table.sql"),
)
.unwrap();
let migration4 = Migration::unapplied(
"V4__add_year_field_to_cars",
"ALTER TABLE cars ADD year INTEGER;",
)
.unwrap();
vec![migration1, migration2, migration3, migration4]
}
#[test]
fn verify_migrations_returns_all_migrations_if_applied_are_empty() {
let migrations = get_migrations();
let applied: Vec<Migration> = Vec::new();
let result = verify_migrations(applied, migrations.clone(), true, true).unwrap();
assert_eq!(migrations, result);
}
#[test]
fn verify_migrations_returns_unapplied() {
let migrations = get_migrations();
let applied: Vec<Migration> = vec![
migrations[0].clone(),
migrations[1].clone(),
migrations[2].clone(),
];
let remaining = vec![migrations[3].clone()];
let result = verify_migrations(applied, migrations, true, true).unwrap();
assert_eq!(remaining, result);
}
#[test]
fn verify_migrations_fails_on_divergent() {
let migrations = get_migrations();
let applied: Vec<Migration> = vec![
migrations[0].clone(),
migrations[1].clone(),
Migration::unapplied(
"V3__add_brand_to_cars_tableeee",
include_str!("../../tests/migrations/V3/V3__add_brand_to_cars_table.sql"),
)
.unwrap(),
];
let migration = migrations[2].clone();
let err = verify_migrations(applied, migrations, true, true).unwrap_err();
match err.kind() {
Kind::DivergentVersion(applied, divergent) => {
assert_eq!(&migration, divergent);
assert_eq!("add_brand_to_cars_tableeee", applied.name());
}
_ => panic!("failed test"),
}
}
#[test]
fn verify_migrations_doesnt_fail_on_divergent() {
let migrations = get_migrations();
let applied: Vec<Migration> = vec![
migrations[0].clone(),
migrations[1].clone(),
Migration::unapplied(
"V3__add_brand_to_cars_tableeee",
include_str!("../../tests/migrations/V3/V3__add_brand_to_cars_table.sql"),
)
.unwrap(),
];
let remaining = vec![migrations[3].clone()];
let result = verify_migrations(applied, migrations, false, true).unwrap();
assert_eq!(remaining, result);
}
#[test]
fn verify_migrations_fails_on_missing_on_applied() {
let migrations = get_migrations();
let applied: Vec<Migration> = vec![migrations[0].clone(), migrations[2].clone()];
let migration = migrations[1].clone();
let err = verify_migrations(applied, migrations, true, true).unwrap_err();
match err.kind() {
Kind::MissingVersion(missing) => {
assert_eq!(&migration, missing);
}
_ => panic!("failed test"),
}
}
#[test]
fn verify_migrations_fails_on_missing_on_filesystem() {
let mut migrations = get_migrations();
let applied: Vec<Migration> = vec![
migrations[0].clone(),
migrations[1].clone(),
migrations[2].clone(),
];
let migration = migrations.remove(1);
let err = verify_migrations(applied, migrations, true, true).unwrap_err();
match err.kind() {
Kind::MissingVersion(missing) => {
assert_eq!(&migration, missing);
}
_ => panic!("failed test"),
}
}
#[test]
fn verify_migrations_doesnt_fail_on_missing_on_applied() {
let migrations = get_migrations();
let applied: Vec<Migration> = vec![migrations[0].clone(), migrations[2].clone()];
let remaining = vec![migrations[3].clone()];
let result = verify_migrations(applied, migrations, true, false).unwrap();
assert_eq!(remaining, result);
}
#[test]
fn verify_migrations_doesnt_fail_on_missing_on_filesystem() {
let mut migrations = get_migrations();
let applied: Vec<Migration> = vec![
migrations[0].clone(),
migrations[1].clone(),
migrations[2].clone(),
];
migrations.remove(1);
let remaining = vec![migrations[2].clone()];
let result = verify_migrations(applied, migrations, true, false).unwrap();
assert_eq!(remaining, result);
}
#[test]
fn verify_migrations_checks_unversioned_out_of_order_doesnt_fail() {
let mut migrations = get_migrations();
migrations.push(
Migration::unapplied(
"U0__merge_out_of_order",
include_str!("../../tests/migrations_unversioned/U0__merge_out_of_order.sql"),
)
.unwrap(),
);
let applied: Vec<Migration> = vec![
migrations[0].clone(),
migrations[1].clone(),
migrations[2].clone(),
migrations[3].clone(),
];
let remaining = vec![migrations[4].clone()];
let result = verify_migrations(applied, migrations, true, true).unwrap();
assert_eq!(remaining, result);
}
#[test]
fn verify_migrations_fails_on_repeated_migration() {
let mut migrations = get_migrations();
let repeated = migrations[0].clone();
migrations.push(repeated.clone());
let err = verify_migrations(vec![], migrations, false, true).unwrap_err();
match err.kind() {
Kind::RepeatedVersion(m) => {
assert_eq!(m, &repeated);
}
_ => panic!("failed test"),
}
}
}


@@ -0,0 +1,177 @@
use std::ops::Deref;
use crate::error::WrapMigrationError;
use crate::traits::{
ASSERT_MIGRATIONS_TABLE_QUERY, GET_APPLIED_MIGRATIONS_QUERY, GET_LAST_APPLIED_MIGRATION_QUERY,
insert_migration_query, verify_migrations,
};
use crate::{Error, Migration, Report, Target};
pub trait Transaction {
type Error: std::error::Error + Send + Sync + 'static;
fn execute<'a, T: Iterator<Item = &'a str>>(&mut self, queries: T) -> Result<usize, Self::Error>;
}
pub trait Query<T>: Transaction {
fn query(&mut self, query: &str) -> Result<T, Self::Error>;
}
pub fn migrate<T: Transaction>(
transaction: &mut T,
migrations: Vec<Migration>,
target: Target,
migration_table_name: &str,
grouped: bool,
) -> Result<Report, Error> {
let mut migration_batch = Vec::new();
let mut applied_migrations = Vec::new();
for mut migration in migrations.into_iter() {
if let Target::Version(input_target) | Target::FakeVersion(input_target) = target {
if input_target < migration.version() {
log::info!(
"stopping at migration: {}, due to user option",
input_target
);
break;
}
}
log::info!("applying migration: {}", migration);
migration.set_applied();
let insert_migration = insert_migration_query(&migration, migration_table_name);
let migration_sql = migration.sql().expect("sql must be Some!").to_string();
// If Target is Fake, we only update the schema migrations table
if !matches!(target, Target::Fake | Target::FakeVersion(_)) {
applied_migrations.push(migration);
migration_batch.push(migration_sql);
}
migration_batch.push(insert_migration);
}
match (target, grouped) {
(Target::Fake | Target::FakeVersion(_), _) => {
log::info!("not going to apply any migration as fake flag is enabled");
}
(Target::Latest | Target::Version(_), true) => {
log::info!(
"going to apply batch migrations in single transaction: {:#?}",
applied_migrations.iter().map(ToString::to_string)
);
}
(Target::Latest | Target::Version(_), false) => {
log::info!(
"preparing to apply {} migrations: {:#?}",
applied_migrations.len(),
applied_migrations.iter().map(ToString::to_string)
);
}
};
if grouped {
transaction
.execute(migration_batch.iter().map(Deref::deref))
.migration_err("error applying migrations", None)?;
} else {
for (i, update) in migration_batch.into_iter().enumerate() {
transaction
.execute([update.as_str()].into_iter())
.migration_err("error applying update", Some(&applied_migrations[0..i / 2]))?;
}
}
Ok(Report::new(applied_migrations))
}
pub trait Migrate: Query<Vec<Migration>>
where
Self: Sized,
{
fn assert_migrations_table(&mut self, migration_table_name: &str) -> Result<usize, Error> {
// Needed because some database vendors like MSSQL have a non-SQL-standard way of checking the
// migrations table, though in this case it's just to be consistent with the async trait
// `AsyncMigrate`
self
.execute(
[ASSERT_MIGRATIONS_TABLE_QUERY
.replace("%MIGRATION_TABLE_NAME%", migration_table_name)
.as_str()]
.into_iter(),
)
.migration_err("error asserting migrations table", None)
}
fn get_last_applied_migration(
&mut self,
migration_table_name: &str,
) -> Result<Option<Migration>, Error> {
let mut migrations = self
.query(
&GET_LAST_APPLIED_MIGRATION_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name),
)
.migration_err("error getting last applied migration", None)?;
Ok(migrations.pop())
}
fn get_applied_migrations(
&mut self,
migration_table_name: &str,
) -> Result<Vec<Migration>, Error> {
let migrations = self
.query(&GET_APPLIED_MIGRATIONS_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name))
.migration_err("error getting applied migrations", None)?;
Ok(migrations)
}
fn get_unapplied_migrations(
&mut self,
migrations: &[Migration],
abort_divergent: bool,
abort_missing: bool,
migration_table_name: &str,
) -> Result<Vec<Migration>, Error> {
self.assert_migrations_table(migration_table_name)?;
let applied_migrations = self.get_applied_migrations(migration_table_name)?;
let migrations = verify_migrations(
applied_migrations,
migrations.to_vec(),
abort_divergent,
abort_missing,
)?;
if migrations.is_empty() {
log::info!("no migrations to apply");
}
Ok(migrations)
}
fn migrate(
&mut self,
migrations: &[Migration],
abort_divergent: bool,
abort_missing: bool,
grouped: bool,
target: Target,
migration_table_name: &str,
) -> Result<Report, Error> {
let migrations = self.get_unapplied_migrations(
migrations,
abort_divergent,
abort_missing,
migration_table_name,
)?;
if grouped || matches!(target, Target::Fake | Target::FakeVersion(_)) {
migrate(self, migrations, target, migration_table_name, true)
} else {
migrate(self, migrations, target, migration_table_name, false)
}
}
}


@@ -0,0 +1,242 @@
use crate::Migration;
use crate::error::{Error, Kind};
use crate::runner::Type;
use regex::Regex;
use std::ffi::OsStr;
use std::path::{Path, PathBuf};
use std::sync::OnceLock;
use walkdir::{DirEntry, WalkDir};
const STEM_RE: &str = r"^([UV])(\d+(?:\.\d+)?)__(\w+)";
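// Matches e.g. "V1__initial" or "U1748354899__backfill"; the capture groups are prefix,
// version and name.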
/// Matches the stem of a migration file.
fn file_stem_re() -> &'static Regex {
static RE: OnceLock<Regex> = OnceLock::new();
RE.get_or_init(|| Regex::new(STEM_RE).unwrap())
}
/// Matches the stem + extension of a SQL migration file.
fn file_re_sql() -> &'static Regex {
static RE: OnceLock<Regex> = OnceLock::new();
RE.get_or_init(|| Regex::new([STEM_RE, r"\.sql$"].concat().as_str()).unwrap())
}
/// Matches the stem + extension of any migration file.
fn file_re_all() -> &'static Regex {
static RE: OnceLock<Regex> = OnceLock::new();
RE.get_or_init(|| Regex::new([STEM_RE, r"\.(rs|sql)$"].concat().as_str()).unwrap())
}
/// Enum containing the migration types used to search for migrations:
/// either just .sql files, or both .sql and .rs
pub enum MigrationType {
All,
Sql,
}
impl MigrationType {
fn file_match_re(&self) -> &'static Regex {
match self {
MigrationType::All => file_re_all(),
MigrationType::Sql => file_re_sql(),
}
}
}
/// Parse a migration filename stem into a prefix, version, and name.
pub fn parse_migration_name(name: &str) -> Result<(Type, i32, String), Error> {
let captures = file_stem_re()
.captures(name)
.filter(|caps| caps.len() == 4)
.ok_or_else(|| Error::new(Kind::InvalidName, None))?;
let version: i32 = captures[2]
.parse()
.map_err(|_| Error::new(Kind::InvalidVersion, None))?;
let name: String = (&captures[3]).into();
let prefix = match &captures[1] {
"V" => Type::Versioned,
"U" => Type::Unversioned,
_ => unreachable!(),
};
Ok((prefix, version, name))
}
/// Find migrations on the file system recursively across directories, given a location and a
/// [MigrationType]
pub fn find_migration_files(
location: impl AsRef<Path>,
migration_type: MigrationType,
) -> Result<impl Iterator<Item = PathBuf>, Error> {
let location: &Path = location.as_ref();
let location = location.canonicalize().map_err(|err| {
Error::new(
Kind::InvalidMigrationPath(location.to_path_buf(), err),
None,
)
})?;
let re = migration_type.file_match_re();
let file_paths = WalkDir::new(location)
.into_iter()
.filter_map(Result::ok)
.map(DirEntry::into_path)
// filter by migration file regex
.filter(
move |entry| match entry.file_name().and_then(OsStr::to_str) {
Some(_) if entry.is_dir() => false,
Some(file_name) if re.is_match(file_name) => true,
Some(file_name) => {
log::warn!(
"File \"{}\" does not adhere to the migration naming convention. Migrations must be named in the format [U|V]{{1}}__{{2}}.sql or [U|V]{{1}}__{{2}}.rs, where {{1}} represents the migration version and {{2}} the name.",
file_name
);
false
}
None => false,
},
);
Ok(file_paths)
}
/// Loads SQL migrations from a path. This enables dynamic migration discovery, as opposed to
/// embedding. The resulting collection is ordered by version.
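/// A hypothetical example: `load_sql_migrations("migrations/")` picks up `V1__first.sql` and
/// `U2__second.sql` but skips `V3__third.rs`, returning the migrations sorted by version.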
pub fn load_sql_migrations(location: impl AsRef<Path>) -> Result<Vec<Migration>, Error> {
let migration_files = find_migration_files(location, MigrationType::Sql)?;
let mut migrations = vec![];
for path in migration_files {
let sql = std::fs::read_to_string(path.as_path()).map_err(|e| {
let path = path.to_owned();
let kind = match e.kind() {
std::io::ErrorKind::NotFound => Kind::InvalidMigrationPath(path, e),
_ => Kind::InvalidMigrationFile(path, e),
};
Error::new(kind, None)
})?;
// Safe to call unwrap as find_migration_files returns canonical paths
let filename = path
.file_stem()
.and_then(|file| file.to_os_string().into_string().ok())
.unwrap();
let migration = Migration::unapplied(&filename, &sql)?;
migrations.push(migration);
}
migrations.sort();
Ok(migrations)
}
#[cfg(test)]
mod tests {
use super::{MigrationType, find_migration_files, load_sql_migrations};
use std::fs;
use std::path::PathBuf;
use tempfile::TempDir;
#[test]
fn finds_mod_migrations() {
let tmp_dir = TempDir::new().unwrap();
let migrations_dir = tmp_dir.path().join("migrations");
fs::create_dir(&migrations_dir).unwrap();
let sql1 = migrations_dir.join("V1__first.rs");
fs::File::create(&sql1).unwrap();
let sql2 = migrations_dir.join("V2__second.rs");
fs::File::create(&sql2).unwrap();
let mut mods: Vec<PathBuf> = find_migration_files(migrations_dir, MigrationType::All)
.unwrap()
.collect();
mods.sort();
assert_eq!(sql1.canonicalize().unwrap(), mods[0]);
assert_eq!(sql2.canonicalize().unwrap(), mods[1]);
}
#[test]
fn ignores_mod_files_without_migration_regex_match() {
let tmp_dir = TempDir::new().unwrap();
let migrations_dir = tmp_dir.path().join("migrations");
fs::create_dir(&migrations_dir).unwrap();
let sql1 = migrations_dir.join("V1first.rs");
fs::File::create(sql1).unwrap();
let sql2 = migrations_dir.join("V2second.rs");
fs::File::create(sql2).unwrap();
let mut mods = find_migration_files(migrations_dir, MigrationType::All).unwrap();
assert!(mods.next().is_none());
}
#[test]
fn finds_sql_migrations() {
let tmp_dir = TempDir::new().unwrap();
let migrations_dir = tmp_dir.path().join("migrations");
fs::create_dir(&migrations_dir).unwrap();
let sql1 = migrations_dir.join("V1__first.sql");
fs::File::create(&sql1).unwrap();
let sql2 = migrations_dir.join("V2__second.sql");
fs::File::create(&sql2).unwrap();
let mut mods: Vec<PathBuf> = find_migration_files(migrations_dir, MigrationType::All)
.unwrap()
.collect();
mods.sort();
assert_eq!(sql1.canonicalize().unwrap(), mods[0]);
assert_eq!(sql2.canonicalize().unwrap(), mods[1]);
}
#[test]
fn finds_unversioned_migrations() {
let tmp_dir = TempDir::new().unwrap();
let migrations_dir = tmp_dir.path().join("migrations");
fs::create_dir(&migrations_dir).unwrap();
let sql1 = migrations_dir.join("U1__first.sql");
fs::File::create(&sql1).unwrap();
let sql2 = migrations_dir.join("U2__second.sql");
fs::File::create(&sql2).unwrap();
let mut mods: Vec<PathBuf> = find_migration_files(migrations_dir, MigrationType::All)
.unwrap()
.collect();
mods.sort();
assert_eq!(sql1.canonicalize().unwrap(), mods[0]);
assert_eq!(sql2.canonicalize().unwrap(), mods[1]);
}
#[test]
fn ignores_sql_files_without_migration_regex_match() {
let tmp_dir = TempDir::new().unwrap();
let migrations_dir = tmp_dir.path().join("migrations");
fs::create_dir(&migrations_dir).unwrap();
let sql1 = migrations_dir.join("V1first.sql");
fs::File::create(sql1).unwrap();
let sql2 = migrations_dir.join("V2second.sql");
fs::File::create(sql2).unwrap();
let mut mods = find_migration_files(migrations_dir, MigrationType::All).unwrap();
assert!(mods.next().is_none());
}
#[test]
fn loads_migrations_from_path() {
let tmp_dir = TempDir::new().unwrap();
let migrations_dir = tmp_dir.path().join("migrations");
fs::create_dir(&migrations_dir).unwrap();
let sql1 = migrations_dir.join("V1__first.sql");
fs::File::create(&sql1).unwrap();
let sql2 = migrations_dir.join("V2__second.sql");
fs::File::create(&sql2).unwrap();
let rs3 = migrations_dir.join("V3__third.rs");
fs::File::create(&rs3).unwrap();
let migrations = load_sql_migrations(migrations_dir).unwrap();
assert_eq!(migrations.len(), 2);
assert_eq!(&migrations[0].to_string(), "V1__first");
assert_eq!(&migrations[1].to_string(), "V2__second");
}
}


@@ -0,0 +1,15 @@
use barrel::{types, Migration};
use crate::Sql;
pub fn migration() -> String {
let mut m = Migration::new();
m.create_table("persons", |t| {
t.add_column("id", types::primary());
t.add_column("name", types::varchar(255));
t.add_column("city", types::varchar(255));
});
m.make::<Sql>()
}


@@ -0,0 +1,8 @@
CREATE TABLE cars (
id int,
name varchar(255)
);
CREATE TABLE motos (
id int,
name varchar(255)
);


@@ -0,0 +1,2 @@
ALTER TABLE cars
ADD brand varchar(255);


@@ -0,0 +1,13 @@
use barrel::{types, Migration};
use crate::Sql;
pub fn migration() -> String {
let mut m = Migration::new();
m.change_table("motos", |t| {
t.add_column("brand", types::varchar(255).nullable(true));
});
m.make::<Sql>()
}


@@ -0,0 +1,5 @@
CREATE TABLE person_finance (
id int,
person_id int,
amount int
);


@@ -15,7 +15,7 @@ path = "benches/benchmark.rs"
harness = false
[dependencies]
base64 = { version = "0.22.1", default-features = false }
base64 = { version = "0.22.1", default-features = false, features = ["alloc"] }
crossbeam-channel = "0.5.13"
kanal = "0.1.1"
log = { version = "^0.4.21", default-features = false }

vendor/refinery vendored

Submodule vendor/refinery deleted from d8c3558786