From 9058f191b69ecafc8fdeace227ac113412d03888 Mon Sep 17 00:00:00 2001 From: FxQnLr Date: Sun, 25 Feb 2024 15:15:19 +0100 Subject: Closes #16. Impl auth as extractor --- src/auth.rs | 30 -------------- src/error.rs | 6 +-- src/extractors.rs | 24 +++++++++++ src/main.rs | 20 ++++----- src/routes/device.rs | 113 ++++++++++++++++++++------------------------------- src/routes/start.rs | 65 +++++++++++++---------------- 6 files changed, 107 insertions(+), 151 deletions(-) delete mode 100644 src/auth.rs create mode 100644 src/extractors.rs (limited to 'src') diff --git a/src/auth.rs b/src/auth.rs deleted file mode 100644 index 22f87e7..0000000 --- a/src/auth.rs +++ /dev/null @@ -1,30 +0,0 @@ -use axum::http::HeaderValue; -use tracing::{debug, trace}; -use crate::config::Config; -use crate::error::Error; - -pub fn auth(config: &Config, secret: Option<&HeaderValue>) -> Result { - debug!("auth request with secret {:?}", secret); - let res = if let Some(value) = secret { - trace!("auth value exists"); - let key = &config.apikey; - if value.to_str()? == key.as_str() { - debug!("successful auth"); - Response::Success - } else { - debug!("unsuccessful auth (wrong secret)"); - Response::WrongSecret - } - } else { - debug!("unsuccessful auth (no secret)"); - Response::MissingSecret - }; - Ok(res) -} - -#[derive(Debug)] -pub enum Response { - Success, - WrongSecret, - MissingSecret -} diff --git a/src/error.rs b/src/error.rs index 66a61f4..513b51b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,8 +1,8 @@ +use ::ipnetwork::IpNetworkError; use axum::http::header::ToStrError; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; use axum::Json; -use ::ipnetwork::IpNetworkError; use mac_address::MacParseError; use serde_json::json; use std::io; @@ -10,9 +10,6 @@ use tracing::error; #[derive(Debug, thiserror::Error)] pub enum Error { - #[error("generic error")] - Generic, - #[error("db: {source}")] Db { #[from] @@ -54,7 +51,6 @@ impl IntoResponse for Error { fn into_response(self) -> Response { error!("{}", self.to_string()); let (status, error_message) = match self { - Self::Generic => (StatusCode::INTERNAL_SERVER_ERROR, ""), Self::Db { source } => { error!("{source}"); (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") diff --git a/src/extractors.rs b/src/extractors.rs new file mode 100644 index 0000000..4d441e9 --- /dev/null +++ b/src/extractors.rs @@ -0,0 +1,24 @@ +use axum::{ + extract::{Request, State}, + http::{HeaderMap, StatusCode}, + middleware::Next, + response::Response, +}; + +use crate::AppState; + +pub async fn auth( + State(state): State, + headers: HeaderMap, + request: Request, + next: Next, +) -> Result { + let secret = headers.get("authorization"); + match secret { + Some(token) if token == state.config.apikey.as_str() => { + let response = next.run(request).await; + Ok(response) + } + _ => Err(StatusCode::UNAUTHORIZED), + } +} diff --git a/src/main.rs b/src/main.rs index 7d8c1da..eae89f6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,26 +4,23 @@ use crate::routes::device; use crate::routes::start::start; use crate::routes::status::status; use crate::services::ping::StatusMap; +use axum::middleware::from_fn_with_state; use axum::routing::{get, put}; use axum::{routing::post, Router}; use dashmap::DashMap; use services::ping::BroadcastCommand; use sqlx::PgPool; -use tracing_subscriber::fmt::time::UtcTime; use std::env; use std::sync::Arc; use tokio::sync::broadcast::{channel, Sender}; use tracing::{info, level_filters::LevelFilter}; -use tracing_subscriber::{ - fmt, - prelude::*, - EnvFilter, -}; +use tracing_subscriber::fmt::time::UtcTime; +use tracing_subscriber::{fmt, prelude::*, EnvFilter}; -mod auth; mod config; mod db; mod error; +mod extractors; mod routes; mod services; mod wol; @@ -31,7 +28,6 @@ mod wol; #[tokio::main] async fn main() -> color_eyre::eyre::Result<()> { color_eyre::install()?; - let time_format = time::macros::format_description!("[year]-[month]-[day] [hour]:[minute]:[second]"); @@ -59,12 +55,12 @@ async fn main() -> color_eyre::eyre::Result<()> { let ping_map: StatusMap = DashMap::new(); - let shared_state = Arc::new(AppState { + let shared_state = AppState { db, config: config.clone(), ping_send: tx, ping_map, - }); + }; let app = Router::new() .route("/start", post(start)) @@ -72,7 +68,8 @@ async fn main() -> color_eyre::eyre::Result<()> { .route("/device", put(device::put)) .route("/device", post(device::post)) .route("/status", get(status)) - .with_state(shared_state); + .route_layer(from_fn_with_state(shared_state.clone(), extractors::auth)) + .with_state(Arc::new(shared_state)); let addr = config.serveraddr; info!("start server on {}", addr); @@ -82,6 +79,7 @@ async fn main() -> color_eyre::eyre::Result<()> { Ok(()) } +#[derive(Clone)] pub struct AppState { db: PgPool, config: Config, diff --git a/src/routes/device.rs b/src/routes/device.rs index 2f0093d..d39d98e 100644 --- a/src/routes/device.rs +++ b/src/routes/device.rs @@ -1,8 +1,6 @@ -use crate::auth::auth; use crate::db::Device; use crate::error::Error; use axum::extract::State; -use axum::http::HeaderMap; use axum::Json; use mac_address::MacAddress; use serde::{Deserialize, Serialize}; @@ -13,31 +11,24 @@ use tracing::{debug, info}; pub async fn get( State(state): State>, - headers: HeaderMap, Json(payload): Json, ) -> Result, Error> { info!("get device {}", payload.id); - let secret = headers.get("authorization"); - let authorized = matches!(auth(&state.config, secret)?, crate::auth::Response::Success); - if authorized { - let device = sqlx::query_as!( - Device, - r#" - SELECT id, mac, broadcast_addr, ip, times - FROM devices - WHERE id = $1; - "#, - payload.id - ) - .fetch_one(&state.db) - .await?; + let device = sqlx::query_as!( + Device, + r#" + SELECT id, mac, broadcast_addr, ip, times + FROM devices + WHERE id = $1; + "#, + payload.id + ) + .fetch_one(&state.db) + .await?; - debug!("got device {:?}", device); + debug!("got device {:?}", device); - Ok(Json(json!(device))) - } else { - Err(Error::Generic) - } + Ok(Json(json!(device))) } #[derive(Deserialize)] @@ -47,7 +38,6 @@ pub struct GetDevicePayload { pub async fn put( State(state): State>, - headers: HeaderMap, Json(payload): Json, ) -> Result, Error> { info!( @@ -55,28 +45,22 @@ pub async fn put( payload.id, payload.mac, payload.broadcast_addr, payload.ip ); - let secret = headers.get("authorization"); - let authorized = matches!(auth(&state.config, secret)?, crate::auth::Response::Success); - if authorized { - let ip = IpNetwork::from_str(&payload.ip)?; - let mac = MacAddress::from_str(&payload.mac)?; - sqlx::query!( - r#" - INSERT INTO devices (id, mac, broadcast_addr, ip) - VALUES ($1, $2, $3, $4); - "#, - payload.id, - mac, - payload.broadcast_addr, - ip - ) - .execute(&state.db) - .await?; + let ip = IpNetwork::from_str(&payload.ip)?; + let mac = MacAddress::from_str(&payload.mac)?; + sqlx::query!( + r#" + INSERT INTO devices (id, mac, broadcast_addr, ip) + VALUES ($1, $2, $3, $4); + "#, + payload.id, + mac, + payload.broadcast_addr, + ip + ) + .execute(&state.db) + .await?; - Ok(Json(json!(PutDeviceResponse { success: true }))) - } else { - Err(Error::Generic) - } + Ok(Json(json!(PutDeviceResponse { success: true }))) } #[derive(Deserialize)] @@ -94,37 +78,30 @@ pub struct PutDeviceResponse { pub async fn post( State(state): State>, - headers: HeaderMap, Json(payload): Json, ) -> Result, Error> { info!( "edit device {} ({}, {}, {})", payload.id, payload.mac, payload.broadcast_addr, payload.ip ); - let secret = headers.get("authorization"); - let authorized = matches!(auth(&state.config, secret)?, crate::auth::Response::Success); - if authorized { - let ip = IpNetwork::from_str(&payload.ip)?; - let mac = MacAddress::from_str(&payload.mac)?; - let device = sqlx::query_as!( - Device, - r#" - UPDATE devices - SET mac = $1, broadcast_addr = $2, ip = $3 WHERE id = $4 - RETURNING id, mac, broadcast_addr, ip, times; - "#, - mac, - payload.broadcast_addr, - ip, - payload.id - ) - .fetch_one(&state.db) - .await?; + let ip = IpNetwork::from_str(&payload.ip)?; + let mac = MacAddress::from_str(&payload.mac)?; + let device = sqlx::query_as!( + Device, + r#" + UPDATE devices + SET mac = $1, broadcast_addr = $2, ip = $3 WHERE id = $4 + RETURNING id, mac, broadcast_addr, ip, times; + "#, + mac, + payload.broadcast_addr, + ip, + payload.id + ) + .fetch_one(&state.db) + .await?; - Ok(Json(json!(device))) - } else { - Err(Error::Generic) - } + Ok(Json(json!(device))) } #[derive(Deserialize)] diff --git a/src/routes/start.rs b/src/routes/start.rs index 4888325..d4c0802 100644 --- a/src/routes/start.rs +++ b/src/routes/start.rs @@ -1,10 +1,8 @@ -use crate::auth::auth; use crate::db::Device; use crate::error::Error; use crate::services::ping::Value as PingValue; use crate::wol::{create_buffer, send_packet}; use axum::extract::State; -use axum::http::HeaderMap; use axum::Json; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; @@ -14,48 +12,41 @@ use uuid::Uuid; pub async fn start( State(state): State>, - headers: HeaderMap, Json(payload): Json, ) -> Result, Error> { info!("POST request"); - let secret = headers.get("authorization"); - let authorized = matches!(auth(&state.config, secret)?, crate::auth::Response::Success); - if authorized { - let device = sqlx::query_as!( - Device, - r#" - SELECT id, mac, broadcast_addr, ip, times - FROM devices - WHERE id = $1; - "#, - payload.id - ) - .fetch_one(&state.db) - .await?; + let device = sqlx::query_as!( + Device, + r#" + SELECT id, mac, broadcast_addr, ip, times + FROM devices + WHERE id = $1; + "#, + payload.id + ) + .fetch_one(&state.db) + .await?; - info!("starting {}", device.id); + info!("starting {}", device.id); - let bind_addr = "0.0.0.0:0"; + let bind_addr = "0.0.0.0:0"; - let _ = send_packet( - bind_addr, - &device.broadcast_addr, - &create_buffer(&device.mac.to_string())?, - )?; - let dev_id = device.id.clone(); - let uuid = if payload.ping.is_some_and(|ping| ping) { - Some(setup_ping(state, device)) - } else { - None - }; - Ok(Json(json!(Response { - id: dev_id, - boot: true, - uuid - }))) + let _ = send_packet( + bind_addr, + &device.broadcast_addr, + &create_buffer(&device.mac.to_string())?, + )?; + let dev_id = device.id.clone(); + let uuid = if payload.ping.is_some_and(|ping| ping) { + Some(setup_ping(state, device)) } else { - Err(Error::Generic) - } + None + }; + Ok(Json(json!(Response { + id: dev_id, + boot: true, + uuid + }))) } fn setup_ping(state: Arc, device: Device) -> String { -- cgit v1.2.3