From 3e6a72428824c5a50a873a4284b86d0a9e47a778 Mon Sep 17 00:00:00 2001 From: fx Date: Mon, 9 Oct 2023 17:26:59 +0200 Subject: db int for api --- src/db.rs | 3 ++ src/error.rs | 31 ++++++++++++++++++ src/main.rs | 35 ++++++++++++++++---- src/routes/device.rs | 92 ++++++++++++++++++++++++++++++++++++++++++++++++++++ src/routes/mod.rs | 3 +- src/routes/start.rs | 52 ++++++++--------------------- 6 files changed, 169 insertions(+), 47 deletions(-) create mode 100644 src/error.rs create mode 100644 src/routes/device.rs (limited to 'src') diff --git a/src/db.rs b/src/db.rs index 79eca91..87943ca 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,3 +1,6 @@ +use serde::Serialize; + +#[derive(Serialize)] pub struct Device { pub id: String, pub mac: String, diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..afed111 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,31 @@ +use std::error::Error; +use axum::http::StatusCode; +use axum::Json; +use axum::response::{IntoResponse, Response}; +use serde_json::json; +use tracing::error; +use crate::auth::AuthError; + +pub enum WebolError { + Auth(AuthError), + Generic, + Server(Box), +} + +impl IntoResponse for WebolError { + fn into_response(self) -> Response { + let (status, error_message) = match self { + WebolError::Auth(err) => err.get(), + WebolError::Generic => (StatusCode::INTERNAL_SERVER_ERROR, ""), + WebolError::Server(err) => { + error!("server error: {}", err.to_string()); + (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") + }, + + }; + let body = Json(json!({ + "error": error_message, + })); + (status, body).into_response() + } +} diff --git a/src/main.rs b/src/main.rs index 761e925..bb37dc2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,9 +1,13 @@ +use std::env; use std::sync::Arc; use axum::{Router, routing::post}; -use sqlx::SqlitePool; +use axum::routing::{get, put}; +use sqlx::PgPool; +use sqlx::postgres::PgPoolOptions; use time::util::local_offset; use tracing::{debug, info, level_filters::LevelFilter}; use tracing_subscriber::{EnvFilter, fmt::{self, time::LocalTime}, prelude::*}; +use crate::routes::device::{get_device, post_device, put_device}; use crate::routes::start::start; mod auth; @@ -11,6 +15,7 @@ mod config; mod routes; mod wol; mod db; +mod error; #[tokio::main] async fn main() { @@ -30,20 +35,20 @@ async fn main() { ) .init(); - debug!("connecting to db"); - let db = SqlitePool::connect("sqlite:devices.sqlite").await.unwrap(); - sqlx::migrate!().run(&db).await.unwrap(); - info!("connected to db"); - let version = env!("CARGO_PKG_VERSION"); info!("starting webol v{}", version); + let db = init_db_pool().await; + let shared_state = Arc::new(AppState { db }); // build our application with a single route let app = Router::new() .route("/start", post(start)) + .route("/device", get(get_device)) + .route("/device", put(put_device)) + .route("/device", post(post_device)) .with_state(shared_state); // run it with hyper on localhost:3000 @@ -54,5 +59,21 @@ async fn main() { } pub struct AppState { - db: SqlitePool + db: PgPool +} + +async fn init_db_pool() -> PgPool { + let db_url = env::var("DATABASE_URL").unwrap(); + + debug!("attempting to connect dbPool to '{}'", db_url); + + let pool = PgPoolOptions::new() + .max_connections(5) + .connect(&db_url) + .await + .unwrap(); + + info!("dbPool successfully connected to '{}'", db_url); + + pool } \ No newline at end of file diff --git a/src/routes/device.rs b/src/routes/device.rs new file mode 100644 index 0000000..d5d7144 --- /dev/null +++ b/src/routes/device.rs @@ -0,0 +1,92 @@ +use std::sync::Arc; +use axum::extract::State; +use axum::headers::HeaderMap; +use axum::Json; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; +use crate::auth::auth; +use crate::db::Device; +use crate::error::WebolError; + +pub async fn get_device(State(state): State>, headers: HeaderMap, Json(payload): Json) -> Result, WebolError> { + let secret = headers.get("authorization"); + if auth(secret).map_err(WebolError::Auth)? { + let device = sqlx::query_as!( + Device, + r#" + SELECT id, mac, broadcast_addr + FROM devices + WHERE id = $1; + "#, + payload.id + ).fetch_one(&state.db).await.map_err(|err| WebolError::Server(Box::new(err)))?; + + Ok(Json(json!(device))) + } else { + Err(WebolError::Generic) + } +} + +#[derive(Deserialize)] +pub struct GetDevicePayload { + id: String, +} + +pub async fn put_device(State(state): State>, headers: HeaderMap, Json(payload): Json) -> Result, WebolError> { + let secret = headers.get("authorization"); + if auth(secret).map_err(WebolError::Auth)? { + sqlx::query!( + r#" + INSERT INTO devices (id, mac, broadcast_addr) + VALUES ($1, $2, $3); + "#, + payload.id, + payload.mac, + payload.broadcast_addr + ).execute(&state.db).await.map_err(|err| WebolError::Server(Box::new(err)))?; + + Ok(Json(json!(PutDeviceResponse { success: true }))) + } else { + Err(WebolError::Generic) + } +} + +#[derive(Deserialize)] +pub struct PutDevicePayload { + id: String, + mac: String, + broadcast_addr: String, +} + +#[derive(Serialize)] +pub struct PutDeviceResponse { + success: bool +} + +pub async fn post_device(State(state): State>, headers: HeaderMap, Json(payload): Json) -> Result, WebolError> { + let secret = headers.get("authorization"); + if auth(secret).map_err(WebolError::Auth)? { + let device = sqlx::query_as!( + Device, + r#" + UPDATE devices + SET mac = $1, broadcast_addr = $2 WHERE id = $3 + RETURNING id, mac, broadcast_addr; + "#, + payload.mac, + payload.broadcast_addr, + payload.id + ).fetch_one(&state.db).await.map_err(|err| WebolError::Server(Box::new(err)))?; + + Ok(Json(json!(device))) + } else { + Err(WebolError::Generic) + } +} + +#[derive(Deserialize)] +pub struct PostDevicePayload { + id: String, + mac: String, + broadcast_addr: String, +} \ No newline at end of file diff --git a/src/routes/mod.rs b/src/routes/mod.rs index 78d4375..12fbfab 100644 --- a/src/routes/mod.rs +++ b/src/routes/mod.rs @@ -1 +1,2 @@ -pub mod start; \ No newline at end of file +pub mod start; +pub mod device; \ No newline at end of file diff --git a/src/routes/start.rs b/src/routes/start.rs index 2d505fc..d16ea4e 100644 --- a/src/routes/start.rs +++ b/src/routes/start.rs @@ -1,45 +1,43 @@ use axum::headers::HeaderMap; -use axum::http::StatusCode; use axum::Json; -use axum::response::{IntoResponse, Response}; use serde::{Deserialize, Serialize}; -use std::error::Error; use std::sync::Arc; use axum::extract::State; use serde_json::{json, Value}; -use tracing::{error, info}; -use crate::auth::{auth, AuthError}; +use tracing::info; +use crate::auth::auth; use crate::config::SETTINGS; use crate::wol::{create_buffer, send_packet}; use crate::db::Device; +use crate::error::WebolError; -pub async fn start(State(state): State>, headers: HeaderMap, Json(payload): Json) -> Result, StartError> { +pub async fn start(State(state): State>, headers: HeaderMap, Json(payload): Json) -> Result, WebolError> { let secret = headers.get("authorization"); - if auth(secret).map_err(StartError::Auth)? { + if auth(secret).map_err(WebolError::Auth)? { let device = sqlx::query_as!( Device, r#" SELECT id, mac, broadcast_addr FROM devices - WHERE id = ?1; + WHERE id = $1; "#, payload.id - ).fetch_one(&state.db).await.map_err(|err| StartError::Server(Box::new(err)))?; + ).fetch_one(&state.db).await.map_err(|err| WebolError::Server(Box::new(err)))?; info!("starting {}", device.id); let bind_addr = SETTINGS .get_string("bindaddr") - .map_err(|err| StartError::Server(Box::new(err)))?; + .map_err(|err| WebolError::Server(Box::new(err)))?; let _ = send_packet( - &bind_addr.parse().map_err(|err| StartError::Server(Box::new(err)))?, - &device.broadcast_addr.parse().map_err(|err| StartError::Server(Box::new(err)))?, - create_buffer(&device.mac).map_err(|err| StartError::Server(Box::new(err)))? - ).map_err(|err| StartError::Server(Box::new(err))); + &bind_addr.parse().map_err(|err| WebolError::Server(Box::new(err)))?, + &device.broadcast_addr.parse().map_err(|err| WebolError::Server(Box::new(err)))?, + create_buffer(&device.mac).map_err(|err| WebolError::Server(Box::new(err)))? + ).map_err(|err| WebolError::Server(Box::new(err))); Ok(Json(json!(StartResponse { id: device.id, boot: true }))) } else { - Err(StartError::Generic) + Err(WebolError::Generic) } } @@ -53,28 +51,4 @@ pub struct StartPayload { struct StartResponse { id: String, boot: bool, -} - -pub enum StartError { - Auth(AuthError), - Generic, - Server(Box), -} - -impl IntoResponse for StartError { - fn into_response(self) -> Response { - let (status, error_message) = match self { - StartError::Auth(err) => err.get(), - StartError::Generic => (StatusCode::INTERNAL_SERVER_ERROR, ""), - StartError::Server(err) => { - error!("server error: {}", err.to_string()); - (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") - }, - - }; - let body = Json(json!({ - "error": error_message, - })); - (status, body).into_response() - } } \ No newline at end of file -- cgit v1.2.3