use anyhow::{Context, Result}; use sqlx::{SqlitePool, migrate::Migrator}; use std::path::Path; use tracing::{debug, error, info, warn}; /// Manages database migrations automatically pub struct MigrationManager; impl MigrationManager { /// Ensures the database exists and runs all pending migrations pub async fn ensure_database_and_migrate(database_url: &str) -> Result { debug!("Setting up database with URL: {}", database_url); // Check if this is a memory database if database_url == "sqlite::memory:" { return Self::setup_memory_database(database_url).await; } // Extract the file path from the SQLite URL for directory creation let db_path = Self::extract_db_path(database_url)?; debug!("Extracted database path: {:?}", db_path); // Ensure the database directory exists Self::ensure_db_directory(&db_path)?; debug!("Database directory ensured"); // Create the database pool using the original URL (now in correct format) let pool = SqlitePool::connect(database_url) .await .context("Failed to connect to database")?; // Run migrations Self::run_migrations(&pool).await?; Ok(pool) } /// Sets up a memory database (for testing) async fn setup_memory_database(database_url: &str) -> Result { // Create the database pool let pool = SqlitePool::connect(database_url) .await .context("Failed to connect to memory database")?; // Run migrations Self::run_migrations(&pool).await?; Ok(pool) } /// Extracts the file path from a SQLite URL for directory creation fn extract_db_path(database_url: &str) -> Result { if database_url.starts_with("sqlite:") { let path = database_url.trim_start_matches("sqlite:"); // Keep the path as is for absolute paths Ok(Path::new(path).to_path_buf()) } else { Err(anyhow::anyhow!( "Invalid SQLite URL format: {}", database_url )) } } /// Ensures the database directory exists fn ensure_db_directory(db_path: &std::path::Path) -> Result<()> { if let Some(parent) = db_path.parent() { std::fs::create_dir_all(parent).context("Failed to create database directory")?; } Ok(()) } /// Runs all pending migrations async fn run_migrations(pool: &SqlitePool) -> Result<()> { info!("Checking for pending migrations..."); // Get the migrations directory relative to the project root let migrations_dir = std::env::current_dir() .context("Failed to get current directory")? .join("migrations"); if !migrations_dir.exists() { warn!( "Migrations directory not found at: {}", migrations_dir.display() ); return Ok(()); } // Run migrations using SQLx match Migrator::new(migrations_dir).await { Ok(migrator) => match migrator.run(pool).await { Ok(_) => { info!("Database migrations completed successfully"); Ok(()) } Err(e) => { error!("Failed to run migrations: {}", e); Err(anyhow::anyhow!("Migration failed: {}", e)) } }, Err(e) => { error!("Failed to create migrator: {}", e); Err(anyhow::anyhow!("Failed to create migrator: {}", e)) } } } /// Checks if migrations are up to date (for health checks) pub async fn check_migrations_status(pool: &SqlitePool) -> Result { let migrations_dir = std::env::current_dir() .context("Failed to get current directory")? .join("migrations"); if !migrations_dir.exists() { return Ok(true); // No migrations to check } // Check if the _sqlx_migrations table exists let table_exists = sqlx::query( "SELECT name FROM sqlite_master WHERE type='table' AND name='_sqlx_migrations'", ) .fetch_optional(pool) .await .context("Failed to check for migrations table")? .is_some(); if !table_exists { // No migrations table means no migrations have been run return Ok(false); } // Count the number of applied migrations let applied_count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM _sqlx_migrations") .fetch_one(pool) .await .context("Failed to count applied migrations")?; // Count the number of migration files let migration_files = std::fs::read_dir(&migrations_dir) .context("Failed to read migrations directory")? .filter_map(|entry| { entry.ok().and_then(|e| { if e.path().extension()?.to_str()? == "sql" { Some(e.path()) } else { None } }) }) .count(); let migrations_ok = applied_count as usize == migration_files; if !migrations_ok { warn!( "Migration mismatch: {} applied, {} files", applied_count, migration_files ); } Ok(migrations_ok) } } #[cfg(test)] mod tests { use super::*; use tempfile::TempDir; #[tokio::test] async fn test_extract_db_path() { let url = "sqlite:/path/to/database.db"; let path = MigrationManager::extract_db_path(url).unwrap(); assert_eq!(path.to_string_lossy(), "/path/to/database.db"); } #[tokio::test] async fn test_extract_db_path_invalid() { let url = "invalid:///path/to/database.db"; let result = MigrationManager::extract_db_path(url); assert!(result.is_err()); } #[tokio::test] async fn test_ensure_db_directory() { let temp_dir = TempDir::new().unwrap(); let db_path = temp_dir.path().join("subdir").join("database.db"); MigrationManager::ensure_db_directory(&db_path).unwrap(); assert!(db_path.parent().unwrap().exists()); } }