From 732c487d3dab4af9fc561527591d3d56299e39f2 Mon Sep 17 00:00:00 2001 From: fx Date: Mon, 9 Oct 2023 16:06:08 +0200 Subject: added db for devices --- src/db.rs | 5 +++++ src/main.rs | 20 +++++++++++++++++--- src/routes/start.rs | 30 +++++++++++++++++++++--------- 3 files changed, 43 insertions(+), 12 deletions(-) create mode 100644 src/db.rs (limited to 'src') diff --git a/src/db.rs b/src/db.rs new file mode 100644 index 0000000..79eca91 --- /dev/null +++ b/src/db.rs @@ -0,0 +1,5 @@ +pub struct Device { + pub id: String, + pub mac: String, + pub broadcast_addr: String +} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 0fe170d..761e925 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,8 @@ +use std::sync::Arc; use axum::{Router, routing::post}; +use sqlx::SqlitePool; use time::util::local_offset; -use tracing::{info, level_filters::LevelFilter}; +use tracing::{debug, info, level_filters::LevelFilter}; use tracing_subscriber::{EnvFilter, fmt::{self, time::LocalTime}, prelude::*}; use crate::routes::start::start; @@ -8,6 +10,7 @@ mod auth; mod config; mod routes; mod wol; +mod db; #[tokio::main] async fn main() { @@ -27,13 +30,21 @@ async fn main() { ) .init(); + debug!("connecting to db"); + let db = SqlitePool::connect("sqlite:devices.sqlite").await.unwrap(); + sqlx::migrate!().run(&db).await.unwrap(); + info!("connected to db"); + let version = env!("CARGO_PKG_VERSION"); - info!("Starting webol v{}", version); + info!("starting webol v{}", version); + + let shared_state = Arc::new(AppState { db }); // build our application with a single route let app = Router::new() - .route("/start", post(start)); + .route("/start", post(start)) + .with_state(shared_state); // run it with hyper on localhost:3000 axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) @@ -42,3 +53,6 @@ async fn main() { .unwrap(); } +pub struct AppState { + db: SqlitePool +} \ No newline at end of file diff --git a/src/routes/start.rs b/src/routes/start.rs index e7d7e0e..2d505fc 100644 --- a/src/routes/start.rs +++ b/src/routes/start.rs @@ -4,28 +4,40 @@ use axum::Json; use axum::response::{IntoResponse, Response}; use serde::{Deserialize, Serialize}; use std::error::Error; +use std::sync::Arc; +use axum::extract::State; use serde_json::{json, Value}; -use tracing::error; +use tracing::{error, info}; use crate::auth::{auth, AuthError}; use crate::config::SETTINGS; use crate::wol::{create_buffer, send_packet}; +use crate::db::Device; -pub async fn start(headers: HeaderMap, Json(payload): Json) -> Result, StartError> { +pub async fn start(State(state): State>, headers: HeaderMap, Json(payload): Json) -> Result, StartError> { let secret = headers.get("authorization"); if auth(secret).map_err(StartError::Auth)? { + let device = sqlx::query_as!( + Device, + r#" + SELECT id, mac, broadcast_addr + FROM devices + WHERE id = ?1; + "#, + payload.id + ).fetch_one(&state.db).await.map_err(|err| StartError::Server(Box::new(err)))?; + + info!("starting {}", device.id); + let bind_addr = SETTINGS .get_string("bindaddr") .map_err(|err| StartError::Server(Box::new(err)))?; - let broadcast_addr = SETTINGS - .get_string("broadcastaddr") - .map_err(|err| StartError::Server(Box::new(err)))?; + let _ = send_packet( &bind_addr.parse().map_err(|err| StartError::Server(Box::new(err)))?, - &broadcast_addr.parse().map_err(|err| StartError::Server(Box::new(err)))?, - // TODO: MAC saved in DB - create_buffer(std::env::var("MAC").unwrap().as_str()).map_err(|err| StartError::Server(Box::new(err)))? + &device.broadcast_addr.parse().map_err(|err| StartError::Server(Box::new(err)))?, + create_buffer(&device.mac).map_err(|err| StartError::Server(Box::new(err)))? ).map_err(|err| StartError::Server(Box::new(err))); - Ok(Json(json!(StartResponse { id: payload.id, boot: true }))) + Ok(Json(json!(StartResponse { id: device.id, boot: true }))) } else { Err(StartError::Generic) } -- cgit v1.2.3