summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/lib.rs5
-rw-r--r--src/main.rs27
-rw-r--r--src/users.rs22
3 files changed, 51 insertions, 3 deletions
diff --git a/src/lib.rs b/src/lib.rs
index d96556a..6f4a795 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -113,7 +113,6 @@ mod locations_api {
pub struct CreateLocation {
pub latitude: f64,
pub longitude: f64,
- pub user_id: i64,
}
#[derive(Deserialize)]
@@ -147,12 +146,12 @@ mod locations_api {
}
pub async fn create_location(
- AuthUser(_): AuthUser,
+ AuthUser(user): AuthUser,
State(pool): State<Arc<SqlitePool>>,
Json(payload): Json<CreateLocation>,
) -> Result<Json<Location>, String> {
let repo = LocationRepository { db: &pool };
- repo.create_location(payload.latitude, payload.longitude, payload.user_id)
+ repo.create_location(payload.latitude, payload.longitude, user.id)
.await
.map(Json)
.map_err(|e| e.to_string())
diff --git a/src/main.rs b/src/main.rs
index a387657..1cb1515 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,9 +1,11 @@
use silmataivas::app_with_state;
+use silmataivas::users::{UserRepository, UserRole};
use sqlx::SqlitePool;
use std::env;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::fs;
+use uuid::Uuid;
#[tokio::main]
async fn main() {
@@ -16,6 +18,31 @@ async fn main() {
let pool = SqlitePool::connect(&db_path)
.await
.expect("Failed to connect to DB");
+
+ // Create initial admin user if none exists
+ {
+ let repo = UserRepository { db: &pool };
+ match repo.any_admin_exists().await {
+ Ok(false) => {
+ let admin_token =
+ env::var("ADMIN_TOKEN").unwrap_or_else(|_| Uuid::new_v4().to_string());
+ match repo
+ .create_user(Some(admin_token.clone()), Some(UserRole::Admin))
+ .await
+ {
+ Ok(_) => println!("Initial admin user created. Token: {admin_token}"),
+ Err(e) => eprintln!("Failed to create initial admin user: {e}"),
+ }
+ }
+ Ok(true) => {
+ // At least one admin exists, do nothing
+ }
+ Err(e) => {
+ eprintln!("Failed to check for existing admin users: {e}");
+ }
+ }
+ }
+
let app = app_with_state(Arc::new(pool));
let addr = SocketAddr::from(([0, 0, 0, 0], 4000));
let listener = tokio::net::TcpListener::bind(addr)
diff --git a/src/users.rs b/src/users.rs
index 0cfa440..129bf15 100644
--- a/src/users.rs
+++ b/src/users.rs
@@ -77,6 +77,14 @@ impl<'a> UserRepository<'a> {
.await?;
Ok(())
}
+
+ pub async fn any_admin_exists(&self) -> Result<bool, sqlx::Error> {
+ let count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM users WHERE role = ?")
+ .bind(UserRole::Admin)
+ .fetch_one(self.db)
+ .await?;
+ Ok(count.0 > 0)
+ }
}
#[cfg(test)]
@@ -141,4 +149,18 @@ mod tests {
let users = repo.list_users().await.unwrap();
assert_eq!(users.len(), 2);
}
+
+ #[tokio::test]
+ async fn test_any_admin_exists() {
+ let db = setup_db().await;
+ let repo = UserRepository { db: &db };
+ // No admin yet
+ assert!(!repo.any_admin_exists().await.unwrap());
+ // Add a user (not admin)
+ repo.create_user(None, Some(UserRole::User)).await.unwrap();
+ assert!(!repo.any_admin_exists().await.unwrap());
+ // Add an admin
+ repo.create_user(None, Some(UserRole::Admin)).await.unwrap();
+ assert!(repo.any_admin_exists().await.unwrap());
+ }
}