use rusqlite::NO_PARAMS; use rusqlite::{params, Connection}; use std::collections::BTreeMap; use std::ops::Bound::{Excluded, Unbounded}; use thiserror::Error; use log::debug; #[derive(Debug, Error)] pub enum MigrationError { #[error("sql error: {0}")] SqlError(#[from] rusqlite::Error), #[error("database version {0} too new to migrate")] VersionTooNew(MigrationVersion), } pub type MigrationResult = Result; pub trait Migration { fn up(&self, conn: &Connection) -> MigrationResult<()>; fn down(&self, conn: &Connection) -> MigrationResult<()>; } pub struct SimpleMigration { pub up_sql: String, pub down_sql: String, } impl SimpleMigration { pub fn new(up_sql: T1, down_sql: T2) -> Self { Self { up_sql: up_sql.to_string(), down_sql: down_sql.to_string(), } } pub fn new_box(up_sql: T1, down_sql: T2) -> Box { Box::new(Self::new(up_sql, down_sql)) } } impl Migration for SimpleMigration { fn up(&self, conn: &Connection) -> MigrationResult<()> { conn.execute(&self.up_sql, NO_PARAMS)?; Ok(()) } fn down(&self, conn: &Connection) -> MigrationResult<()> { conn.execute(&self.down_sql, NO_PARAMS)?; Ok(()) } } pub type MigrationVersion = u32; pub const NO_MIGRATIONS: MigrationVersion = 0; pub fn get_db_version(conn: &Connection) -> MigrationResult { let table_count: u32 = conn.query_row( "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='db_version'", NO_PARAMS, |row| row.get(0), )?; if table_count == 0 { return Ok(NO_MIGRATIONS); } let version: u32 = conn.query_row( "SELECT version FROM db_version WHERE id = 1", NO_PARAMS, |row| row.get(0), )?; Ok(version) } pub fn set_db_version(conn: &Connection, version: MigrationVersion) -> MigrationResult<()> { conn.execute( " CREATE TABLE IF NOT EXISTS db_version ( id INTEGER PRIMARY KEY, version INTEGER );", NO_PARAMS)?; conn.execute( " INSERT OR REPLACE INTO db_version (id, version) VALUES (1, ?1);", params![version])?; Ok(()) } pub struct Migrations { migrations: BTreeMap>, } impl Migrations { pub fn new() -> Self { Self { migrations: BTreeMap::new(), } } pub fn add(&mut self, version: MigrationVersion, migration: Box) { self.migrations.insert(version, migration); } pub fn apply(&self, conn: &mut Connection) -> MigrationResult<()> { let db_version = get_db_version(conn)?; if db_version != 0 && !self.migrations.contains_key(&db_version) { return Err(MigrationError::VersionTooNew(db_version)); } let mig_range = self.migrations.range( (Excluded(db_version), Unbounded)); let mut trans = conn.transaction()?; let mut last_ver: MigrationVersion = 0; for (ver, mig) in mig_range { debug!("applying migration version {}", ver); mig.up(&mut trans)?; last_ver = *ver; } set_db_version(&trans, last_ver)?; trans.commit()?; Ok(()) } }