From 228879a2c390e08c227b7b1b2a5f44e1c34bf4e3 Mon Sep 17 00:00:00 2001 From: Dawid Rycerz Date: Wed, 30 Jul 2025 19:26:07 +0300 Subject: feat: add automatic migrations --- Cargo.lock | 1 + Cargo.toml | 5 +- lefthook.yml | 2 +- src/config.rs | 35 ++++++++-- src/health.rs | 59 +++++++++++++++-- src/main.rs | 24 ++++--- src/migrations.rs | 194 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 298 insertions(+), 22 deletions(-) create mode 100644 src/migrations.rs diff --git a/Cargo.lock b/Cargo.lock index f007fec..b86050c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2318,6 +2318,7 @@ dependencies = [ "serde_derive", "serde_json", "sqlx", + "tempfile", "tera", "tokio", "tokio-task-scheduler", diff --git a/Cargo.toml b/Cargo.toml index b15542f..e822019 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ lettre = "0.11.17" reqwest = { version = "0.12.22", features = ["json"] } serde = "1.0.219" serde_json = "1.0.140" -sqlx = { version = "0.8.6", features = ["sqlite", "runtime-tokio-native-tls"] } +sqlx = { version = "0.8.6", features = ["sqlite", "runtime-tokio-native-tls", "migrate", "macros"] } tera = "1.20.0" tokio = "1.46.1" tokio-task-scheduler = "1.0.0" @@ -35,9 +35,10 @@ reqwest = "0.12.22" serde = "1.0.219" serde_derive = "1.0.219" serde_json = "1.0.140" -sqlx = { version = "0.8.6", features = ["sqlite", "runtime-tokio-native-tls"] } +sqlx = { version = "0.8.6", features = ["sqlite", "runtime-tokio-native-tls", "macros"] } tera = "1.20.0" tokio = "1.46.1" tokio-task-scheduler = "1.0.0" tower-http = "0.6.6" uuid = "1.17.0" +tempfile = "3.10" diff --git a/lefthook.yml b/lefthook.yml index cabea2b..a99846b 100644 --- a/lefthook.yml +++ b/lefthook.yml @@ -16,4 +16,4 @@ pre-commit: clippy: run: cargo clippy --all-targets --all-features -- -D warnings test: - run: cargo test --all \ No newline at end of file + run: cargo test --all -- --test-threads=1 \ No newline at end of file diff --git a/src/config.rs b/src/config.rs index 
186ef82..49bd995 100644 --- a/src/config.rs +++ b/src/config.rs @@ -57,7 +57,8 @@ impl Config { let _ = std::fs::create_dir_all(parent); } - format!("sqlite://{}", db_path.display()) + // Use the correct SQLx format: sqlite:path + format!("sqlite:{}", db_path.display()) } #[allow(dead_code)] @@ -79,9 +80,11 @@ impl Config { mod tests { use super::*; use std::env; - use std::sync::Once; + use std::sync::OnceLock; + use std::sync::{Mutex, Once}; static INIT: Once = Once::new(); + static TEST_MUTEX: OnceLock<Mutex<()>> = OnceLock::new(); fn setup() { INIT.call_once(|| { @@ -95,6 +98,9 @@ mod tests { } }); + // Get the test mutex to ensure sequential execution + let _guard = TEST_MUTEX.get_or_init(|| Mutex::new(())).lock().unwrap(); + // Clear environment variables before each test unsafe { env::remove_var("DATABASE_URL"); @@ -110,7 +116,7 @@ mod tests { setup(); let home = env::var("HOME").unwrap_or_else(|_| "/tmp".to_string()); - let expected_path = format!("sqlite://{home}/.local/share/silmataivas/silmataivas.db"); + let expected_path = format!("sqlite:{home}/.local/share/silmataivas/silmataivas.db"); let config = Config::from_env(); assert_eq!(config.database_url, expected_path); @@ -124,10 +130,15 @@ mod tests { env::set_var("XDG_DATA_HOME", "/custom/data/path"); } - let expected_path = "sqlite:///custom/data/path/silmataivas/silmataivas.db"; + let expected_path = "sqlite:/custom/data/path/silmataivas/silmataivas.db"; let config = Config::from_env(); assert_eq!(config.database_url, expected_path); + + // Clean up after this test + unsafe { + env::remove_var("XDG_DATA_HOME"); + } } #[test] @@ -135,11 +146,16 @@ mod tests { setup(); unsafe { - env::set_var("DATABASE_URL", "sqlite:///explicit/path.db"); + env::set_var("DATABASE_URL", "sqlite:/explicit/path.db"); } let config = Config::from_env(); - assert_eq!(config.database_url, "sqlite:///explicit/path.db"); + assert_eq!(config.database_url, "sqlite:/explicit/path.db"); + + // Clean up after this test + unsafe { + 
env::remove_var("DATABASE_URL"); + } } #[test] @@ -156,5 +172,12 @@ mod tests { assert_eq!(config.server_port, 8080); assert_eq!(config.server_host, "127.0.0.1"); assert_eq!(config.log_level, "debug"); + + // Clean up after this test + unsafe { + env::remove_var("PORT"); + env::remove_var("HOST"); + env::remove_var("LOG_LEVEL"); + } } } diff --git a/src/health.rs b/src/health.rs index c3ae80a..546d5d4 100644 --- a/src/health.rs +++ b/src/health.rs @@ -1,16 +1,67 @@ -use axum::{Json, response::IntoResponse}; +use axum::Json; use serde_json::json; +use sqlx::SqlitePool; +use std::sync::Arc; +use tracing::warn; + +use crate::migrations::MigrationManager; #[utoipa::path( get, path = "/health", responses( - (status = 200, description = "Health check", body = inline(HealthResponse)) + (status = 200, description = "Health check response", body = serde_json::Value) ), tag = "health" )] -pub async fn health_handler() -> impl IntoResponse { - Json(json!({"status": "ok"})) +pub async fn health_handler( + axum::extract::State(pool): axum::extract::State<Arc<SqlitePool>>, +) -> Json<serde_json::Value> { + let mut status = "ok"; + let details; + + // Check database connectivity + match sqlx::query("SELECT 1").execute(&*pool).await { + Ok(_) => { + // Check migration status + match MigrationManager::check_migrations_status(&pool).await { + Ok(migrations_ok) => { + if migrations_ok { + details = json!({ + "database": "connected", + "migrations": "up_to_date" + }); + } else { + warn!("Health check: migrations are not up to date"); + details = json!({ + "database": "connected", + "migrations": "pending" + }); + } + } + Err(e) => { + warn!("Health check: failed to check migration status: {}", e); + details = json!({ + "database": "connected", + "migrations": "unknown", + "migration_error": e.to_string() + }); + } + } + } + Err(e) => { + status = "error"; + details = json!({ + "database": "disconnected", + "error": e.to_string() + }); + } + } + + Json(json!({ + "status": status, + "details": details + })) } 
#[derive(utoipa::ToSchema, serde::Serialize)] diff --git a/src/main.rs b/src/main.rs index 9f97e37..64da786 100644 --- a/src/main.rs +++ b/src/main.rs @@ -18,6 +18,7 @@ mod auth; mod config; mod health; mod locations; +mod migrations; mod notifications; mod users; mod weather_api_data; @@ -952,10 +953,11 @@ async fn main() -> anyhow::Result<()> { match cli.command.unwrap_or(Commands::Server) { Commands::Server => { - // Connect to database - let pool = SqlitePool::connect(&config.database_url) - .await - .expect("Failed to connect to DB"); + // Ensure database exists and run migrations + let pool = + migrations::MigrationManager::ensure_database_and_migrate(&config.database_url) + .await + .expect("Failed to setup database and run migrations"); // Create initial admin user if none exists { @@ -1017,7 +1019,9 @@ async fn main() -> anyhow::Result<()> { } } Commands::CreateUser { uuid } => { - let pool = SqlitePool::connect(&config.database_url).await?; + let pool = + migrations::MigrationManager::ensure_database_and_migrate(&config.database_url) + .await?; let repo = crate::users::UserRepository { db: &pool }; let user_id = uuid.unwrap_or_else(|| Uuid::new_v4().to_string()); let user = repo @@ -1075,9 +1079,10 @@ async fn main() -> anyhow::Result<()> { info!("User {} created", user_id); } Commands::CheckWeather => { - let pool = SqlitePool::connect(&config.database_url) - .await - .expect("Failed to connect to DB"); + let pool = + migrations::MigrationManager::ensure_database_and_migrate(&config.database_url) + .await + .expect("Failed to setup database and run migrations"); let poller = crate::weather_poller::WeatherPoller::new(Arc::new(pool)); info!("Manually triggering weather check..."); @@ -1115,6 +1120,7 @@ mod tests { .unwrap(); assert_eq!(response.status(), StatusCode::OK); let body = to_bytes(response.into_body(), 1024).await.unwrap(); - assert_eq!(&body[..], b"{\"status\":\"ok\"}"); + let body_str = std::str::from_utf8(&body).unwrap(); + 
assert!(body_str.contains("\"status\":\"ok\"")); } } diff --git a/src/migrations.rs b/src/migrations.rs new file mode 100644 index 0000000..62f86b3 --- /dev/null +++ b/src/migrations.rs @@ -0,0 +1,194 @@ +use anyhow::{Context, Result}; +use sqlx::{SqlitePool, migrate::Migrator}; +use std::path::Path; +use tracing::{debug, error, info, warn}; + +/// Manages database migrations automatically +pub struct MigrationManager; + +impl MigrationManager { + /// Ensures the database exists and runs all pending migrations + pub async fn ensure_database_and_migrate(database_url: &str) -> Result<SqlitePool> { + debug!("Setting up database with URL: {}", database_url); + + // Check if this is a memory database + if database_url == "sqlite::memory:" { + return Self::setup_memory_database(database_url).await; + } + + // Extract the file path from the SQLite URL for directory creation + let db_path = Self::extract_db_path(database_url)?; + debug!("Extracted database path: {:?}", db_path); + + // Ensure the database directory exists + Self::ensure_db_directory(&db_path)?; + debug!("Database directory ensured"); + + // Create the database pool using the original URL (now in correct format) + let pool = SqlitePool::connect(database_url) + .await + .context("Failed to connect to database")?; + + // Run migrations + Self::run_migrations(&pool).await?; + + Ok(pool) + } + + /// Sets up a memory database (for testing) + async fn setup_memory_database(database_url: &str) -> Result<SqlitePool> { + // Create the database pool + let pool = SqlitePool::connect(database_url) + .await + .context("Failed to connect to memory database")?; + + // Run migrations + Self::run_migrations(&pool).await?; + + Ok(pool) + } + + /// Extracts the file path from a SQLite URL for directory creation + fn extract_db_path(database_url: &str) -> Result<std::path::PathBuf> { + if database_url.starts_with("sqlite:") { + let path = database_url.trim_start_matches("sqlite:"); + // Keep the path as is for absolute paths + Ok(Path::new(path).to_path_buf()) + } else 
{ + Err(anyhow::anyhow!( + "Invalid SQLite URL format: {}", + database_url + )) + } + } + + /// Ensures the database directory exists + fn ensure_db_directory(db_path: &std::path::Path) -> Result<()> { + if let Some(parent) = db_path.parent() { + std::fs::create_dir_all(parent).context("Failed to create database directory")?; + } + Ok(()) + } + + /// Runs all pending migrations + async fn run_migrations(pool: &SqlitePool) -> Result<()> { + info!("Checking for pending migrations..."); + + // Get the migrations directory relative to the project root + let migrations_dir = std::env::current_dir() + .context("Failed to get current directory")? + .join("migrations"); + + if !migrations_dir.exists() { + warn!( + "Migrations directory not found at: {}", + migrations_dir.display() + ); + return Ok(()); + } + + // Run migrations using SQLx + match Migrator::new(migrations_dir).await { + Ok(migrator) => match migrator.run(pool).await { + Ok(_) => { + info!("Database migrations completed successfully"); + Ok(()) + } + Err(e) => { + error!("Failed to run migrations: {}", e); + Err(anyhow::anyhow!("Migration failed: {}", e)) + } + }, + Err(e) => { + error!("Failed to create migrator: {}", e); + Err(anyhow::anyhow!("Failed to create migrator: {}", e)) + } + } + } + + /// Checks if migrations are up to date (for health checks) + pub async fn check_migrations_status(pool: &SqlitePool) -> Result<bool> { + let migrations_dir = std::env::current_dir() + .context("Failed to get current directory")? + .join("migrations"); + + if !migrations_dir.exists() { + return Ok(true); // No migrations to check + } + + // Check if the _sqlx_migrations table exists + let table_exists = sqlx::query( + "SELECT name FROM sqlite_master WHERE type='table' AND name='_sqlx_migrations'", + ) + .fetch_optional(pool) + .await + .context("Failed to check for migrations table")? 
+ .is_some(); + + if !table_exists { + // No migrations table means no migrations have been run + return Ok(false); + } + + // Count the number of applied migrations + let applied_count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM _sqlx_migrations") + .fetch_one(pool) + .await + .context("Failed to count applied migrations")?; + + // Count the number of migration files + let migration_files = std::fs::read_dir(&migrations_dir) + .context("Failed to read migrations directory")? + .filter_map(|entry| { + entry.ok().and_then(|e| { + if e.path().extension()?.to_str()? == "sql" { + Some(e.path()) + } else { + None + } + }) + }) + .count(); + + let migrations_ok = applied_count as usize == migration_files; + + if !migrations_ok { + warn!( + "Migration mismatch: {} applied, {} files", + applied_count, migration_files + ); + } + + Ok(migrations_ok) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + #[tokio::test] + async fn test_extract_db_path() { + let url = "sqlite:/path/to/database.db"; + let path = MigrationManager::extract_db_path(url).unwrap(); + assert_eq!(path.to_string_lossy(), "/path/to/database.db"); + } + + #[tokio::test] + async fn test_extract_db_path_invalid() { + let url = "invalid:///path/to/database.db"; + let result = MigrationManager::extract_db_path(url); + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_ensure_db_directory() { + let temp_dir = TempDir::new().unwrap(); + let db_path = temp_dir.path().join("subdir").join("database.db"); + + MigrationManager::ensure_db_directory(&db_path).unwrap(); + + assert!(db_path.parent().unwrap().exists()); + } +} -- cgit v1.2.3