1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
|
use anyhow::{Context, Result};
use sqlx::{SqlitePool, migrate::Migrator};
use std::path::Path;
use tracing::{debug, error, info, warn};
/// Manages database migrations automatically
pub struct MigrationManager;
impl MigrationManager {
/// Ensures the database exists and runs all pending migrations
pub async fn ensure_database_and_migrate(database_url: &str) -> Result<SqlitePool> {
debug!("Setting up database with URL: {}", database_url);
// Check if this is a memory database
if database_url == "sqlite::memory:" {
return Self::setup_memory_database(database_url).await;
}
// Extract the file path from the SQLite URL for directory creation
let db_path = Self::extract_db_path(database_url)?;
debug!("Extracted database path: {:?}", db_path);
// Ensure the database directory exists
Self::ensure_db_directory(&db_path)?;
debug!("Database directory ensured");
// Create the database pool using the original URL (now in correct format)
let pool = SqlitePool::connect(database_url)
.await
.context("Failed to connect to database")?;
// Run migrations
Self::run_migrations(&pool).await?;
Ok(pool)
}
/// Sets up a memory database (for testing)
async fn setup_memory_database(database_url: &str) -> Result<SqlitePool> {
// Create the database pool
let pool = SqlitePool::connect(database_url)
.await
.context("Failed to connect to memory database")?;
// Run migrations
Self::run_migrations(&pool).await?;
Ok(pool)
}
/// Extracts the file path from a SQLite URL for directory creation
fn extract_db_path(database_url: &str) -> Result<std::path::PathBuf> {
if database_url.starts_with("sqlite:") {
let path = database_url.trim_start_matches("sqlite:");
// Keep the path as is for absolute paths
Ok(Path::new(path).to_path_buf())
} else {
Err(anyhow::anyhow!(
"Invalid SQLite URL format: {}",
database_url
))
}
}
/// Ensures the database directory exists
fn ensure_db_directory(db_path: &std::path::Path) -> Result<()> {
if let Some(parent) = db_path.parent() {
std::fs::create_dir_all(parent).context("Failed to create database directory")?;
}
Ok(())
}
/// Runs all pending migrations
async fn run_migrations(pool: &SqlitePool) -> Result<()> {
info!("Checking for pending migrations...");
// Get the migrations directory relative to the project root
let migrations_dir = std::env::current_dir()
.context("Failed to get current directory")?
.join("migrations");
if !migrations_dir.exists() {
warn!(
"Migrations directory not found at: {}",
migrations_dir.display()
);
return Ok(());
}
// Run migrations using SQLx
match Migrator::new(migrations_dir).await {
Ok(migrator) => match migrator.run(pool).await {
Ok(_) => {
info!("Database migrations completed successfully");
Ok(())
}
Err(e) => {
error!("Failed to run migrations: {}", e);
Err(anyhow::anyhow!("Migration failed: {}", e))
}
},
Err(e) => {
error!("Failed to create migrator: {}", e);
Err(anyhow::anyhow!("Failed to create migrator: {}", e))
}
}
}
/// Checks if migrations are up to date (for health checks)
pub async fn check_migrations_status(pool: &SqlitePool) -> Result<bool> {
let migrations_dir = std::env::current_dir()
.context("Failed to get current directory")?
.join("migrations");
if !migrations_dir.exists() {
return Ok(true); // No migrations to check
}
// Check if the _sqlx_migrations table exists
let table_exists = sqlx::query(
"SELECT name FROM sqlite_master WHERE type='table' AND name='_sqlx_migrations'",
)
.fetch_optional(pool)
.await
.context("Failed to check for migrations table")?
.is_some();
if !table_exists {
// No migrations table means no migrations have been run
return Ok(false);
}
// Count the number of applied migrations
let applied_count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM _sqlx_migrations")
.fetch_one(pool)
.await
.context("Failed to count applied migrations")?;
// Count the number of migration files
let migration_files = std::fs::read_dir(&migrations_dir)
.context("Failed to read migrations directory")?
.filter_map(|entry| {
entry.ok().and_then(|e| {
if e.path().extension()?.to_str()? == "sql" {
Some(e.path())
} else {
None
}
})
})
.count();
let migrations_ok = applied_count as usize == migration_files;
if !migrations_ok {
warn!(
"Migration mismatch: {} applied, {} files",
applied_count, migration_files
);
}
Ok(migrations_ok)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // The helpers under test are synchronous, so plain `#[test]` is enough;
    // no async runtime has to be started for them.

    #[test]
    fn test_extract_db_path() {
        let path = MigrationManager::extract_db_path("sqlite:/path/to/database.db").unwrap();
        assert_eq!(path.as_path(), Path::new("/path/to/database.db"));
    }

    #[test]
    fn test_extract_db_path_relative() {
        let path = MigrationManager::extract_db_path("sqlite:data/database.db").unwrap();
        assert_eq!(path.as_path(), Path::new("data/database.db"));
    }

    #[test]
    fn test_extract_db_path_invalid() {
        let result = MigrationManager::extract_db_path("invalid:///path/to/database.db");
        assert!(result.is_err());
        assert!(MigrationManager::extract_db_path("").is_err());
    }

    #[test]
    fn test_ensure_db_directory() {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("subdir").join("database.db");
        MigrationManager::ensure_db_directory(&db_path).unwrap();
        assert!(db_path.parent().unwrap().is_dir());
        // Calling again on an already existing directory must also succeed.
        MigrationManager::ensure_db_directory(&db_path).unwrap();
    }

    #[test]
    fn test_ensure_db_directory_bare_file_name() {
        // A bare file name has no directory component, so there is nothing to create.
        MigrationManager::ensure_db_directory(Path::new("database.db")).unwrap();
    }
}
|