use rusqlite::NO_PARAMS; use rusqlite::{params, Connection}; use std::collections::BTreeMap; use std::ops::Bound::{Excluded, Unbounded}; use thiserror::Error; use tracing::{debug, trace, info}; #[derive(Debug, Error)] pub enum MigrationError { #[error("sql error: {0}")] SqlError(#[from] rusqlite::Error), #[error("migration version {0} up failed, sql error: {1}")] MigrationUpFailed(MigrationVersion, rusqlite::Error), #[error("migration version {0} down failed, sql error: {1}")] MigrationDownFailed(MigrationVersion, rusqlite::Error), #[error("database version {0} too new to migrate")] VersionTooNew(MigrationVersion), } pub type MigrationResult = Result; pub type MigrationVersion = u32; pub trait Migration { fn version(&self) -> MigrationVersion; fn up(&self, conn: &Connection) -> MigrationResult<()>; fn down(&self, conn: &Connection) -> MigrationResult<()>; } pub struct SimpleMigration { pub version: MigrationVersion, pub up_sql: String, pub down_sql: String, } impl SimpleMigration { pub fn new( version: MigrationVersion, up_sql: T1, down_sql: T2, ) -> Self { Self { version, up_sql: up_sql.to_string(), down_sql: down_sql.to_string(), } } pub fn new_box( version: MigrationVersion, up_sql: T1, down_sql: T2, ) -> Box { Box::new(Self::new(version, up_sql, down_sql)) } } impl Migration for SimpleMigration { fn version(&self) -> MigrationVersion { self.version } fn up(&self, conn: &Connection) -> MigrationResult<()> { conn.execute_batch(&self.up_sql) .map_err(|sql_err| MigrationError::MigrationUpFailed(self.version, sql_err))?; Ok(()) } fn down(&self, conn: &Connection) -> MigrationResult<()> { conn.execute_batch(&self.down_sql) .map_err(|sql_err| MigrationError::MigrationDownFailed(self.version, sql_err))?; Ok(()) } } pub const NO_MIGRATIONS: MigrationVersion = 0; pub fn get_db_version(conn: &Connection) -> MigrationResult { let table_count: u32 = conn.query_row( "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='db_version'", NO_PARAMS, |row| row.get(0), )?; if table_count == 0 { return Ok(NO_MIGRATIONS); } let version: u32 = conn.query_row( "SELECT v.version \ FROM db_version AS v \ WHERE id = 1", NO_PARAMS, |row| row.get(0), )?; Ok(version) } pub fn set_db_version(conn: &Connection, version: MigrationVersion) -> MigrationResult<()> { conn.execute( " CREATE TABLE IF NOT EXISTS db_version ( id INTEGER PRIMARY KEY, version INTEGER );", NO_PARAMS, )?; conn.execute( " INSERT OR REPLACE INTO db_version (id, version) VALUES (1, ?1);", params![version], )?; Ok(()) } pub struct Migrations { migrations: BTreeMap>, } impl Migrations { pub fn new() -> Self { Self { migrations: BTreeMap::new(), } } pub fn add(&mut self, migration: Box) { assert!( migration.version() != NO_MIGRATIONS, "migration has bad vesion" ); self.migrations.insert(migration.version(), migration); } pub fn apply(&self, conn: &mut Connection) -> MigrationResult<()> { let db_version = get_db_version(conn)?; trace!(db_version, "read db_version"); if db_version != 0 && !self.migrations.contains_key(&db_version) { return Err(MigrationError::VersionTooNew(db_version)); } let mig_range = self.migrations.range((Excluded(db_version), Unbounded)); let trans = conn.transaction()?; let mut last_ver: MigrationVersion = NO_MIGRATIONS; for (ver, mig) in mig_range { debug!(version = ver, "applying migration version"); mig.up(&trans)?; last_ver = *ver; } if last_ver != NO_MIGRATIONS { info!(old_version = db_version, new_version = last_ver, "applied database migrations"); set_db_version(&trans, last_ver)?; } trans.commit()?; Ok(()) } }