aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/db.rs5
-rw-r--r--src/main.rs20
-rw-r--r--src/routes/start.rs30
3 files changed, 43 insertions, 12 deletions
diff --git a/src/db.rs b/src/db.rs
new file mode 100644
index 0000000..79eca91
--- /dev/null
+++ b/src/db.rs
@@ -0,0 +1,5 @@
1pub struct Device {
2 pub id: String,
3 pub mac: String,
4 pub broadcast_addr: String
5} \ No newline at end of file
diff --git a/src/main.rs b/src/main.rs
index 0fe170d..761e925 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,6 +1,8 @@
1use std::sync::Arc;
1use axum::{Router, routing::post}; 2use axum::{Router, routing::post};
3use sqlx::SqlitePool;
2use time::util::local_offset; 4use time::util::local_offset;
3use tracing::{info, level_filters::LevelFilter}; 5use tracing::{debug, info, level_filters::LevelFilter};
4use tracing_subscriber::{EnvFilter, fmt::{self, time::LocalTime}, prelude::*}; 6use tracing_subscriber::{EnvFilter, fmt::{self, time::LocalTime}, prelude::*};
5use crate::routes::start::start; 7use crate::routes::start::start;
6 8
@@ -8,6 +10,7 @@ mod auth;
8mod config; 10mod config;
9mod routes; 11mod routes;
10mod wol; 12mod wol;
13mod db;
11 14
12#[tokio::main] 15#[tokio::main]
13async fn main() { 16async fn main() {
@@ -27,13 +30,21 @@ async fn main() {
27 ) 30 )
28 .init(); 31 .init();
29 32
33 debug!("connecting to db");
34 let db = SqlitePool::connect("sqlite:devices.sqlite").await.unwrap();
35 sqlx::migrate!().run(&db).await.unwrap();
36 info!("connected to db");
37
30 let version = env!("CARGO_PKG_VERSION"); 38 let version = env!("CARGO_PKG_VERSION");
31 39
32 info!("Starting webol v{}", version); 40 info!("starting webol v{}", version);
41
42 let shared_state = Arc::new(AppState { db });
33 43
34 // build our application with a single route 44 // build our application with a single route
35 let app = Router::new() 45 let app = Router::new()
36 .route("/start", post(start)); 46 .route("/start", post(start))
47 .with_state(shared_state);
37 48
38 // run it with hyper on localhost:3000 49 // run it with hyper on localhost:3000
39 axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) 50 axum::Server::bind(&"0.0.0.0:3000".parse().unwrap())
@@ -42,3 +53,6 @@ async fn main() {
42 .unwrap(); 53 .unwrap();
43} 54}
44 55
56pub struct AppState {
57 db: SqlitePool
58} \ No newline at end of file
diff --git a/src/routes/start.rs b/src/routes/start.rs
index e7d7e0e..2d505fc 100644
--- a/src/routes/start.rs
+++ b/src/routes/start.rs
@@ -4,28 +4,40 @@ use axum::Json;
4use axum::response::{IntoResponse, Response}; 4use axum::response::{IntoResponse, Response};
5use serde::{Deserialize, Serialize}; 5use serde::{Deserialize, Serialize};
6use std::error::Error; 6use std::error::Error;
7use std::sync::Arc;
8use axum::extract::State;
7use serde_json::{json, Value}; 9use serde_json::{json, Value};
8use tracing::error; 10use tracing::{error, info};
9use crate::auth::{auth, AuthError}; 11use crate::auth::{auth, AuthError};
10use crate::config::SETTINGS; 12use crate::config::SETTINGS;
11use crate::wol::{create_buffer, send_packet}; 13use crate::wol::{create_buffer, send_packet};
14use crate::db::Device;
12 15
13pub async fn start(headers: HeaderMap, Json(payload): Json<StartPayload>) -> Result<Json<Value>, StartError> { 16pub async fn start(State(state): State<Arc<crate::AppState>>, headers: HeaderMap, Json(payload): Json<StartPayload>) -> Result<Json<Value>, StartError> {
14 let secret = headers.get("authorization"); 17 let secret = headers.get("authorization");
15 if auth(secret).map_err(StartError::Auth)? { 18 if auth(secret).map_err(StartError::Auth)? {
19 let device = sqlx::query_as!(
20 Device,
21 r#"
22 SELECT id, mac, broadcast_addr
23 FROM devices
24 WHERE id = ?1;
25 "#,
26 payload.id
27 ).fetch_one(&state.db).await.map_err(|err| StartError::Server(Box::new(err)))?;
28
29 info!("starting {}", device.id);
30
16 let bind_addr = SETTINGS 31 let bind_addr = SETTINGS
17 .get_string("bindaddr") 32 .get_string("bindaddr")
18 .map_err(|err| StartError::Server(Box::new(err)))?; 33 .map_err(|err| StartError::Server(Box::new(err)))?;
19 let broadcast_addr = SETTINGS 34
20 .get_string("broadcastaddr")
21 .map_err(|err| StartError::Server(Box::new(err)))?;
22 let _ = send_packet( 35 let _ = send_packet(
23 &bind_addr.parse().map_err(|err| StartError::Server(Box::new(err)))?, 36 &bind_addr.parse().map_err(|err| StartError::Server(Box::new(err)))?,
24 &broadcast_addr.parse().map_err(|err| StartError::Server(Box::new(err)))?, 37 &device.broadcast_addr.parse().map_err(|err| StartError::Server(Box::new(err)))?,
25 // TODO: MAC saved in DB 38 create_buffer(&device.mac).map_err(|err| StartError::Server(Box::new(err)))?
26 create_buffer(std::env::var("MAC").unwrap().as_str()).map_err(|err| StartError::Server(Box::new(err)))?
27 ).map_err(|err| StartError::Server(Box::new(err))); 39 ).map_err(|err| StartError::Server(Box::new(err)));
28 Ok(Json(json!(StartResponse { id: payload.id, boot: true }))) 40 Ok(Json(json!(StartResponse { id: device.id, boot: true })))
29 } else { 41 } else {
30 Err(StartError::Generic) 42 Err(StartError::Generic)
31 } 43 }