summaryrefslogtreecommitdiff
path: root/src/migrations.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/migrations.rs')
-rw-r--r--src/migrations.rs194
1 files changed, 194 insertions, 0 deletions
diff --git a/src/migrations.rs b/src/migrations.rs
new file mode 100644
index 0000000..62f86b3
--- /dev/null
+++ b/src/migrations.rs
@@ -0,0 +1,194 @@
+use anyhow::{Context, Result};
+use sqlx::{SqlitePool, migrate::Migrator};
+use std::path::Path;
+use tracing::{debug, error, info, warn};
+
+/// Manages database migrations automatically
+pub struct MigrationManager;
+
+impl MigrationManager {
+ /// Ensures the database exists and runs all pending migrations
+ pub async fn ensure_database_and_migrate(database_url: &str) -> Result<SqlitePool> {
+ debug!("Setting up database with URL: {}", database_url);
+
+ // Check if this is a memory database
+ if database_url == "sqlite::memory:" {
+ return Self::setup_memory_database(database_url).await;
+ }
+
+ // Extract the file path from the SQLite URL for directory creation
+ let db_path = Self::extract_db_path(database_url)?;
+ debug!("Extracted database path: {:?}", db_path);
+
+ // Ensure the database directory exists
+ Self::ensure_db_directory(&db_path)?;
+ debug!("Database directory ensured");
+
+ // Create the database pool using the original URL (now in correct format)
+ let pool = SqlitePool::connect(database_url)
+ .await
+ .context("Failed to connect to database")?;
+
+ // Run migrations
+ Self::run_migrations(&pool).await?;
+
+ Ok(pool)
+ }
+
+ /// Sets up a memory database (for testing)
+ async fn setup_memory_database(database_url: &str) -> Result<SqlitePool> {
+ // Create the database pool
+ let pool = SqlitePool::connect(database_url)
+ .await
+ .context("Failed to connect to memory database")?;
+
+ // Run migrations
+ Self::run_migrations(&pool).await?;
+
+ Ok(pool)
+ }
+
+ /// Extracts the file path from a SQLite URL for directory creation
+ fn extract_db_path(database_url: &str) -> Result<std::path::PathBuf> {
+ if database_url.starts_with("sqlite:") {
+ let path = database_url.trim_start_matches("sqlite:");
+ // Keep the path as is for absolute paths
+ Ok(Path::new(path).to_path_buf())
+ } else {
+ Err(anyhow::anyhow!(
+ "Invalid SQLite URL format: {}",
+ database_url
+ ))
+ }
+ }
+
+ /// Ensures the database directory exists
+ fn ensure_db_directory(db_path: &std::path::Path) -> Result<()> {
+ if let Some(parent) = db_path.parent() {
+ std::fs::create_dir_all(parent).context("Failed to create database directory")?;
+ }
+ Ok(())
+ }
+
+ /// Runs all pending migrations
+ async fn run_migrations(pool: &SqlitePool) -> Result<()> {
+ info!("Checking for pending migrations...");
+
+ // Get the migrations directory relative to the project root
+ let migrations_dir = std::env::current_dir()
+ .context("Failed to get current directory")?
+ .join("migrations");
+
+ if !migrations_dir.exists() {
+ warn!(
+ "Migrations directory not found at: {}",
+ migrations_dir.display()
+ );
+ return Ok(());
+ }
+
+ // Run migrations using SQLx
+ match Migrator::new(migrations_dir).await {
+ Ok(migrator) => match migrator.run(pool).await {
+ Ok(_) => {
+ info!("Database migrations completed successfully");
+ Ok(())
+ }
+ Err(e) => {
+ error!("Failed to run migrations: {}", e);
+ Err(anyhow::anyhow!("Migration failed: {}", e))
+ }
+ },
+ Err(e) => {
+ error!("Failed to create migrator: {}", e);
+ Err(anyhow::anyhow!("Failed to create migrator: {}", e))
+ }
+ }
+ }
+
+ /// Checks if migrations are up to date (for health checks)
+ pub async fn check_migrations_status(pool: &SqlitePool) -> Result<bool> {
+ let migrations_dir = std::env::current_dir()
+ .context("Failed to get current directory")?
+ .join("migrations");
+
+ if !migrations_dir.exists() {
+ return Ok(true); // No migrations to check
+ }
+
+ // Check if the _sqlx_migrations table exists
+ let table_exists = sqlx::query(
+ "SELECT name FROM sqlite_master WHERE type='table' AND name='_sqlx_migrations'",
+ )
+ .fetch_optional(pool)
+ .await
+ .context("Failed to check for migrations table")?
+ .is_some();
+
+ if !table_exists {
+ // No migrations table means no migrations have been run
+ return Ok(false);
+ }
+
+ // Count the number of applied migrations
+ let applied_count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM _sqlx_migrations")
+ .fetch_one(pool)
+ .await
+ .context("Failed to count applied migrations")?;
+
+ // Count the number of migration files
+ let migration_files = std::fs::read_dir(&migrations_dir)
+ .context("Failed to read migrations directory")?
+ .filter_map(|entry| {
+ entry.ok().and_then(|e| {
+ if e.path().extension()?.to_str()? == "sql" {
+ Some(e.path())
+ } else {
+ None
+ }
+ })
+ })
+ .count();
+
+ let migrations_ok = applied_count as usize == migration_files;
+
+ if !migrations_ok {
+ warn!(
+ "Migration mismatch: {} applied, {} files",
+ applied_count, migration_files
+ );
+ }
+
+ Ok(migrations_ok)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use tempfile::TempDir;
+
+ #[tokio::test]
+ async fn test_extract_db_path() {
+ let url = "sqlite:/path/to/database.db";
+ let path = MigrationManager::extract_db_path(url).unwrap();
+ assert_eq!(path.to_string_lossy(), "/path/to/database.db");
+ }
+
+ #[tokio::test]
+ async fn test_extract_db_path_invalid() {
+ let url = "invalid:///path/to/database.db";
+ let result = MigrationManager::extract_db_path(url);
+ assert!(result.is_err());
+ }
+
+ #[tokio::test]
+ async fn test_ensure_db_directory() {
+ let temp_dir = TempDir::new().unwrap();
+ let db_path = temp_dir.path().join("subdir").join("database.db");
+
+ MigrationManager::ensure_db_directory(&db_path).unwrap();
+
+ assert!(db_path.parent().unwrap().exists());
+ }
+}