diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/config.rs | 35 | ||||
| -rw-r--r-- | src/health.rs | 59 | ||||
| -rw-r--r-- | src/main.rs | 24 | ||||
| -rw-r--r-- | src/migrations.rs | 194 |
4 files changed, 293 insertions, 19 deletions
diff --git a/src/config.rs b/src/config.rs index 186ef82..49bd995 100644 --- a/src/config.rs +++ b/src/config.rs @@ -57,7 +57,8 @@ impl Config { let _ = std::fs::create_dir_all(parent); } - format!("sqlite://{}", db_path.display()) + // Use the correct SQLx format: sqlite:path + format!("sqlite:{}", db_path.display()) } #[allow(dead_code)] @@ -79,9 +80,11 @@ impl Config { mod tests { use super::*; use std::env; - use std::sync::Once; + use std::sync::OnceLock; + use std::sync::{Mutex, Once}; static INIT: Once = Once::new(); + static TEST_MUTEX: OnceLock<Mutex<()>> = OnceLock::new(); fn setup() { INIT.call_once(|| { @@ -95,6 +98,9 @@ mod tests { } }); + // Get the test mutex to ensure sequential execution + let _guard = TEST_MUTEX.get_or_init(|| Mutex::new(())).lock().unwrap(); + // Clear environment variables before each test unsafe { env::remove_var("DATABASE_URL"); @@ -110,7 +116,7 @@ mod tests { setup(); let home = env::var("HOME").unwrap_or_else(|_| "/tmp".to_string()); - let expected_path = format!("sqlite://{home}/.local/share/silmataivas/silmataivas.db"); + let expected_path = format!("sqlite:{home}/.local/share/silmataivas/silmataivas.db"); let config = Config::from_env(); assert_eq!(config.database_url, expected_path); @@ -124,10 +130,15 @@ mod tests { env::set_var("XDG_DATA_HOME", "/custom/data/path"); } - let expected_path = "sqlite:///custom/data/path/silmataivas/silmataivas.db"; + let expected_path = "sqlite:/custom/data/path/silmataivas/silmataivas.db"; let config = Config::from_env(); assert_eq!(config.database_url, expected_path); + + // Clean up after this test + unsafe { + env::remove_var("XDG_DATA_HOME"); + } } #[test] @@ -135,11 +146,16 @@ mod tests { setup(); unsafe { - env::set_var("DATABASE_URL", "sqlite:///explicit/path.db"); + env::set_var("DATABASE_URL", "sqlite:/explicit/path.db"); } let config = Config::from_env(); - assert_eq!(config.database_url, "sqlite:///explicit/path.db"); + assert_eq!(config.database_url, "sqlite:/explicit/path.db"); + + // Clean up after this test + unsafe { + env::remove_var("DATABASE_URL"); + } } #[test] @@ -156,5 +172,12 @@ mod tests { assert_eq!(config.server_port, 8080); assert_eq!(config.server_host, "127.0.0.1"); assert_eq!(config.log_level, "debug"); + + // Clean up after this test + unsafe { + env::remove_var("PORT"); + env::remove_var("HOST"); + env::remove_var("LOG_LEVEL"); + } } } diff --git a/src/health.rs b/src/health.rs index c3ae80a..546d5d4 100644 --- a/src/health.rs +++ b/src/health.rs @@ -1,16 +1,67 @@ -use axum::{Json, response::IntoResponse}; +use axum::Json; use serde_json::json; +use sqlx::SqlitePool; +use std::sync::Arc; +use tracing::warn; + +use crate::migrations::MigrationManager; #[utoipa::path( get, path = "/health", responses( - (status = 200, description = "Health check", body = inline(HealthResponse)) + (status = 200, description = "Health check response", body = serde_json::Value) ), tag = "health" )] -pub async fn health_handler() -> impl IntoResponse { - Json(json!({"status": "ok"})) +pub async fn health_handler( + axum::extract::State(pool): axum::extract::State<Arc<SqlitePool>>, +) -> Json<serde_json::Value> { + let mut status = "ok"; + let details; + + // Check database connectivity + match sqlx::query("SELECT 1").execute(&*pool).await { + Ok(_) => { + // Check migration status + match MigrationManager::check_migrations_status(&pool).await { + Ok(migrations_ok) => { + if migrations_ok { + details = json!({ + "database": "connected", + "migrations": "up_to_date" + }); + } else { + warn!("Health check: migrations are not up to date"); + details = json!({ + "database": "connected", + "migrations": "pending" + }); + } + } + Err(e) => { + warn!("Health check: failed to check migration status: {}", e); + details = json!({ + "database": "connected", + "migrations": "unknown", + "migration_error": e.to_string() + }); + } + } + } + Err(e) => { + status = "error"; + details = json!({ + "database": "disconnected", + "error": e.to_string() + }); + } + } + + Json(json!({ + "status": status, + "details": details + })) } #[derive(utoipa::ToSchema, serde::Serialize)] diff --git a/src/main.rs b/src/main.rs index 9f97e37..64da786 100644 --- a/src/main.rs +++ b/src/main.rs @@ -18,6 +18,7 @@ mod auth; mod config; mod health; mod locations; +mod migrations; mod notifications; mod users; mod weather_api_data; @@ -952,10 +953,11 @@ async fn main() -> anyhow::Result<()> { match cli.command.unwrap_or(Commands::Server) { Commands::Server => { - // Connect to database - let pool = SqlitePool::connect(&config.database_url) - .await - .expect("Failed to connect to DB"); + // Ensure database exists and run migrations + let pool = + migrations::MigrationManager::ensure_database_and_migrate(&config.database_url) + .await + .expect("Failed to setup database and run migrations"); // Create initial admin user if none exists { @@ -1017,7 +1019,9 @@ async fn main() -> anyhow::Result<()> { } } Commands::CreateUser { uuid } => { - let pool = SqlitePool::connect(&config.database_url).await?; + let pool = + migrations::MigrationManager::ensure_database_and_migrate(&config.database_url) + .await?; let repo = crate::users::UserRepository { db: &pool }; let user_id = uuid.unwrap_or_else(|| Uuid::new_v4().to_string()); let user = repo @@ -1075,9 +1079,10 @@ async fn main() -> anyhow::Result<()> { info!("User {} created", user_id); } Commands::CheckWeather => { - let pool = SqlitePool::connect(&config.database_url) - .await - .expect("Failed to connect to DB"); + let pool = + migrations::MigrationManager::ensure_database_and_migrate(&config.database_url) + .await + .expect("Failed to setup database and run migrations"); let poller = crate::weather_poller::WeatherPoller::new(Arc::new(pool)); info!("Manually triggering weather check..."); @@ -1115,6 +1120,7 @@ mod tests { .unwrap(); assert_eq!(response.status(), StatusCode::OK); let body = to_bytes(response.into_body(), 1024).await.unwrap(); - assert_eq!(&body[..], b"{\"status\":\"ok\"}"); + let body_str = std::str::from_utf8(&body).unwrap(); + assert!(body_str.contains("\"status\":\"ok\"")); } } diff --git a/src/migrations.rs b/src/migrations.rs new file mode 100644 index 0000000..62f86b3 --- /dev/null +++ b/src/migrations.rs @@ -0,0 +1,194 @@ +use anyhow::{Context, Result}; +use sqlx::{SqlitePool, migrate::Migrator}; +use std::path::Path; +use tracing::{debug, error, info, warn}; + +/// Manages database migrations automatically +pub struct MigrationManager; + +impl MigrationManager { + /// Ensures the database exists and runs all pending migrations + pub async fn ensure_database_and_migrate(database_url: &str) -> Result<SqlitePool> { + debug!("Setting up database with URL: {}", database_url); + + // Check if this is a memory database + if database_url == "sqlite::memory:" { + return Self::setup_memory_database(database_url).await; + } + + // Extract the file path from the SQLite URL for directory creation + let db_path = Self::extract_db_path(database_url)?; + debug!("Extracted database path: {:?}", db_path); + + // Ensure the database directory exists + Self::ensure_db_directory(&db_path)?; + debug!("Database directory ensured"); + + // Create the database pool using the original URL (now in correct format) + let pool = SqlitePool::connect(database_url) + .await + .context("Failed to connect to database")?; + + // Run migrations + Self::run_migrations(&pool).await?; + + Ok(pool) + } + + /// Sets up a memory database (for testing) + async fn setup_memory_database(database_url: &str) -> Result<SqlitePool> { + // Create the database pool + let pool = SqlitePool::connect(database_url) + .await + .context("Failed to connect to memory database")?; + + // Run migrations + Self::run_migrations(&pool).await?; + + Ok(pool) + } + + /// Extracts the file path from a SQLite URL for directory creation + fn extract_db_path(database_url: &str) -> Result<std::path::PathBuf> { + if database_url.starts_with("sqlite:") { + let path = database_url.trim_start_matches("sqlite:"); + // Keep the path as is for absolute paths + Ok(Path::new(path).to_path_buf()) + } else { + Err(anyhow::anyhow!( + "Invalid SQLite URL format: {}", + database_url + )) + } + } + + /// Ensures the database directory exists + fn ensure_db_directory(db_path: &std::path::Path) -> Result<()> { + if let Some(parent) = db_path.parent() { + std::fs::create_dir_all(parent).context("Failed to create database directory")?; + } + Ok(()) + } + + /// Runs all pending migrations + async fn run_migrations(pool: &SqlitePool) -> Result<()> { + info!("Checking for pending migrations..."); + + // Get the migrations directory relative to the project root + let migrations_dir = std::env::current_dir() + .context("Failed to get current directory")? + .join("migrations"); + + if !migrations_dir.exists() { + warn!( + "Migrations directory not found at: {}", + migrations_dir.display() + ); + return Ok(()); + } + + // Run migrations using SQLx + match Migrator::new(migrations_dir).await { + Ok(migrator) => match migrator.run(pool).await { + Ok(_) => { + info!("Database migrations completed successfully"); + Ok(()) + } + Err(e) => { + error!("Failed to run migrations: {}", e); + Err(anyhow::anyhow!("Migration failed: {}", e)) + } + }, + Err(e) => { + error!("Failed to create migrator: {}", e); + Err(anyhow::anyhow!("Failed to create migrator: {}", e)) + } + } + } + + /// Checks if migrations are up to date (for health checks) + pub async fn check_migrations_status(pool: &SqlitePool) -> Result<bool> { + let migrations_dir = std::env::current_dir() + .context("Failed to get current directory")? + .join("migrations"); + + if !migrations_dir.exists() { + return Ok(true); // No migrations to check + } + + // Check if the _sqlx_migrations table exists + let table_exists = sqlx::query( + "SELECT name FROM sqlite_master WHERE type='table' AND name='_sqlx_migrations'", + ) + .fetch_optional(pool) + .await + .context("Failed to check for migrations table")? + .is_some(); + + if !table_exists { + // No migrations table means no migrations have been run + return Ok(false); + } + + // Count the number of applied migrations + let applied_count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM _sqlx_migrations") + .fetch_one(pool) + .await + .context("Failed to count applied migrations")?; + + // Count the number of migration files + let migration_files = std::fs::read_dir(&migrations_dir) + .context("Failed to read migrations directory")? + .filter_map(|entry| { + entry.ok().and_then(|e| { + if e.path().extension()?.to_str()? == "sql" { + Some(e.path()) + } else { + None + } + }) + }) + .count(); + + let migrations_ok = applied_count as usize == migration_files; + + if !migrations_ok { + warn!( + "Migration mismatch: {} applied, {} files", + applied_count, migration_files + ); + } + + Ok(migrations_ok) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + #[tokio::test] + async fn test_extract_db_path() { + let url = "sqlite:/path/to/database.db"; + let path = MigrationManager::extract_db_path(url).unwrap(); + assert_eq!(path.to_string_lossy(), "/path/to/database.db"); + } + + #[tokio::test] + async fn test_extract_db_path_invalid() { + let url = "invalid:///path/to/database.db"; + let result = MigrationManager::extract_db_path(url); + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_ensure_db_directory() { + let temp_dir = TempDir::new().unwrap(); + let db_path = temp_dir.path().join("subdir").join("database.db"); + + MigrationManager::ensure_db_directory(&db_path).unwrap(); + + assert!(db_path.parent().unwrap().exists()); + } +} |
