diff options
Diffstat (limited to 'src/migrations.rs')
| -rw-r--r-- | src/migrations.rs | 194 |
1 files changed, 194 insertions, 0 deletions
diff --git a/src/migrations.rs b/src/migrations.rs new file mode 100644 index 0000000..62f86b3 --- /dev/null +++ b/src/migrations.rs @@ -0,0 +1,194 @@ +use anyhow::{Context, Result}; +use sqlx::{SqlitePool, migrate::Migrator}; +use std::path::Path; +use tracing::{debug, error, info, warn}; + +/// Manages database migrations automatically +pub struct MigrationManager; + +impl MigrationManager { + /// Ensures the database exists and runs all pending migrations + pub async fn ensure_database_and_migrate(database_url: &str) -> Result<SqlitePool> { + debug!("Setting up database with URL: {}", database_url); + + // Check if this is a memory database + if database_url == "sqlite::memory:" { + return Self::setup_memory_database(database_url).await; + } + + // Extract the file path from the SQLite URL for directory creation + let db_path = Self::extract_db_path(database_url)?; + debug!("Extracted database path: {:?}", db_path); + + // Ensure the database directory exists + Self::ensure_db_directory(&db_path)?; + debug!("Database directory ensured"); + + // Create the database pool using the original URL (now in correct format) + let pool = SqlitePool::connect(database_url) + .await + .context("Failed to connect to database")?; + + // Run migrations + Self::run_migrations(&pool).await?; + + Ok(pool) + } + + /// Sets up a memory database (for testing) + async fn setup_memory_database(database_url: &str) -> Result<SqlitePool> { + // Create the database pool + let pool = SqlitePool::connect(database_url) + .await + .context("Failed to connect to memory database")?; + + // Run migrations + Self::run_migrations(&pool).await?; + + Ok(pool) + } + + /// Extracts the file path from a SQLite URL for directory creation + fn extract_db_path(database_url: &str) -> Result<std::path::PathBuf> { + if database_url.starts_with("sqlite:") { + let path = database_url.trim_start_matches("sqlite:"); + // Keep the path as is for absolute paths + Ok(Path::new(path).to_path_buf()) + } else { + Err(anyhow::anyhow!( + "Invalid SQLite URL format: {}", + database_url + )) + } + } + + /// Ensures the database directory exists + fn ensure_db_directory(db_path: &std::path::Path) -> Result<()> { + if let Some(parent) = db_path.parent() { + std::fs::create_dir_all(parent).context("Failed to create database directory")?; + } + Ok(()) + } + + /// Runs all pending migrations + async fn run_migrations(pool: &SqlitePool) -> Result<()> { + info!("Checking for pending migrations..."); + + // Get the migrations directory relative to the project root + let migrations_dir = std::env::current_dir() + .context("Failed to get current directory")? + .join("migrations"); + + if !migrations_dir.exists() { + warn!( + "Migrations directory not found at: {}", + migrations_dir.display() + ); + return Ok(()); + } + + // Run migrations using SQLx + match Migrator::new(migrations_dir).await { + Ok(migrator) => match migrator.run(pool).await { + Ok(_) => { + info!("Database migrations completed successfully"); + Ok(()) + } + Err(e) => { + error!("Failed to run migrations: {}", e); + Err(anyhow::anyhow!("Migration failed: {}", e)) + } + }, + Err(e) => { + error!("Failed to create migrator: {}", e); + Err(anyhow::anyhow!("Failed to create migrator: {}", e)) + } + } + } + + /// Checks if migrations are up to date (for health checks) + pub async fn check_migrations_status(pool: &SqlitePool) -> Result<bool> { + let migrations_dir = std::env::current_dir() + .context("Failed to get current directory")? + .join("migrations"); + + if !migrations_dir.exists() { + return Ok(true); // No migrations to check + } + + // Check if the _sqlx_migrations table exists + let table_exists = sqlx::query( + "SELECT name FROM sqlite_master WHERE type='table' AND name='_sqlx_migrations'", + ) + .fetch_optional(pool) + .await + .context("Failed to check for migrations table")? + .is_some(); + + if !table_exists { + // No migrations table means no migrations have been run + return Ok(false); + } + + // Count the number of applied migrations + let applied_count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM _sqlx_migrations") + .fetch_one(pool) + .await + .context("Failed to count applied migrations")?; + + // Count the number of migration files + let migration_files = std::fs::read_dir(&migrations_dir) + .context("Failed to read migrations directory")? + .filter_map(|entry| { + entry.ok().and_then(|e| { + if e.path().extension()?.to_str()? == "sql" { + Some(e.path()) + } else { + None + } + }) + }) + .count(); + + let migrations_ok = applied_count as usize == migration_files; + + if !migrations_ok { + warn!( + "Migration mismatch: {} applied, {} files", + applied_count, migration_files + ); + } + + Ok(migrations_ok) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + #[tokio::test] + async fn test_extract_db_path() { + let url = "sqlite:/path/to/database.db"; + let path = MigrationManager::extract_db_path(url).unwrap(); + assert_eq!(path.to_string_lossy(), "/path/to/database.db"); + } + + #[tokio::test] + async fn test_extract_db_path_invalid() { + let url = "invalid:///path/to/database.db"; + let result = MigrationManager::extract_db_path(url); + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_ensure_db_directory() { + let temp_dir = TempDir::new().unwrap(); + let db_path = temp_dir.path().join("subdir").join("database.db"); + + MigrationManager::ensure_db_directory(&db_path).unwrap(); + + assert!(db_path.parent().unwrap().exists()); + } +} |
