From 3bc7cf8ed36016ca3da9438a98f4fe8b8e6f9e61 Mon Sep 17 00:00:00 2001 From: FxQnLr Date: Thu, 15 Feb 2024 17:17:30 +0100 Subject: Closes #10 & #12. Added `thiserror` crate and changed to `IntoSocketAddr` for easier usage and error handling --- src/auth.rs | 16 +++++++++---- src/error.rs | 65 ++++++++++++++++++++++++++++++++-------------------- src/routes/device.rs | 12 +++++----- src/routes/start.rs | 9 ++++---- src/wol.rs | 18 +++++++++------ 5 files changed, 72 insertions(+), 48 deletions(-) (limited to 'src') diff --git a/src/auth.rs b/src/auth.rs index feca652..eb4d1bf 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -9,7 +9,7 @@ pub fn auth(config: &Config, secret: Option<&HeaderValue>) -> Result) -> Result (StatusCode::UNAUTHORIZED, "Wrong credentials"), Self::MissingSecret => (StatusCode::BAD_REQUEST, "Missing credentials"), - Self::HeaderToStr(err) => { - error!("server error: {}", err.to_string()); + Self::HeaderToStr { source } => { + error!("auth: {}", source); (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") }, } diff --git a/src/error.rs b/src/error.rs index 56d6c52..4f1bedd 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,44 +1,59 @@ -use std::io; +use crate::auth::Error as AuthError; use axum::http::StatusCode; -use axum::Json; use axum::response::{IntoResponse, Response}; +use axum::Json; use serde_json::json; +use std::io; use tracing::error; -use crate::auth::Error as AuthError; -#[derive(Debug)] +#[derive(Debug, thiserror::Error)] pub enum Error { + #[error("generic error")] Generic, - Auth(AuthError), - DB(sqlx::Error), - IpParse(::Err), - BufferParse(std::num::ParseIntError), - Broadcast(io::Error), + + #[error("auth: {source}")] + Auth { + #[from] + source: AuthError, + }, + + #[error("db: {source}")] + Db { + #[from] + source: sqlx::Error, + }, + + #[error("buffer parse: {source}")] + ParseInt { + #[from] + source: std::num::ParseIntError, + }, + + #[error("io: {source}")] + Io { + #[from] + source: io::Error, + }, } impl IntoResponse for Error { fn into_response(self) -> Response { + error!("{}", self.to_string()); let (status, error_message) = match self { - Self::Auth(err) => { - err.get() - }, + Self::Auth { source } => source.get(), Self::Generic => (StatusCode::INTERNAL_SERVER_ERROR, ""), - Self::IpParse(err) => { - error!("server error: {}", err.to_string()); - (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") - }, - Self::DB(err) => { - error!("server error: {}", err.to_string()); + Self::Db { source } => { + error!("{source}"); (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") - }, - Self::Broadcast(err) => { - error!("server error: {}", err.to_string()); + } + Self::Io { source } => { + error!("{source}"); (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") - }, - Self::BufferParse(err) => { - error!("server error: {}", err.to_string()); + } + Self::ParseInt { source } => { + error!("{source}"); (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") - }, + } }; let body = Json(json!({ "error": error_message, diff --git a/src/routes/device.rs b/src/routes/device.rs index c85df1b..aa52cf7 100644 --- a/src/routes/device.rs +++ b/src/routes/device.rs @@ -12,7 +12,7 @@ use crate::error::Error; pub async fn get(State(state): State>, headers: HeaderMap, Json(payload): Json) -> Result, Error> { info!("add device {}", payload.id); let secret = headers.get("authorization"); - if auth(&state.config, secret).map_err(Error::Auth)? { + if auth(&state.config, secret)? { let device = sqlx::query_as!( Device, r#" @@ -21,7 +21,7 @@ pub async fn get(State(state): State>, headers: HeaderMap, WHERE id = $1; "#, payload.id - ).fetch_one(&state.db).await.map_err(Error::DB)?; + ).fetch_one(&state.db).await?; debug!("got device {:?}", device); @@ -39,7 +39,7 @@ pub struct GetDevicePayload { pub async fn put(State(state): State>, headers: HeaderMap, Json(payload): Json) -> Result, Error> { info!("add device {} ({}, {}, {})", payload.id, payload.mac, payload.broadcast_addr, payload.ip); let secret = headers.get("authorization"); - if auth(&state.config, secret).map_err(Error::Auth)? { + if auth(&state.config, secret)? { sqlx::query!( r#" INSERT INTO devices (id, mac, broadcast_addr, ip) @@ -49,7 +49,7 @@ pub async fn put(State(state): State>, headers: HeaderMap, payload.mac, payload.broadcast_addr, payload.ip - ).execute(&state.db).await.map_err(Error::DB)?; + ).execute(&state.db).await?; Ok(Json(json!(PutDeviceResponse { success: true }))) } else { @@ -73,7 +73,7 @@ pub struct PutDeviceResponse { pub async fn post(State(state): State>, headers: HeaderMap, Json(payload): Json) -> Result, Error> { info!("edit device {} ({}, {}, {})", payload.id, payload.mac, payload.broadcast_addr, payload.ip); let secret = headers.get("authorization"); - if auth(&state.config, secret).map_err(Error::Auth)? { + if auth(&state.config, secret)? { let device = sqlx::query_as!( Device, r#" @@ -85,7 +85,7 @@ pub async fn post(State(state): State>, headers: HeaderMap, payload.broadcast_addr, payload.ip, payload.id - ).fetch_one(&state.db).await.map_err(Error::DB)?; + ).fetch_one(&state.db).await?; Ok(Json(json!(device))) } else { diff --git a/src/routes/start.rs b/src/routes/start.rs index ce95bf3..66b7cb4 100644 --- a/src/routes/start.rs +++ b/src/routes/start.rs @@ -20,7 +20,7 @@ pub async fn start( ) -> Result, Error> { info!("POST request"); let secret = headers.get("authorization"); - let authorized = auth(&state.config, secret).map_err(Error::Auth)?; + let authorized = auth(&state.config, secret)?; if authorized { let device = sqlx::query_as!( Device, @@ -32,16 +32,15 @@ pub async fn start( payload.id ) .fetch_one(&state.db) - .await - .map_err(Error::DB)?; + .await?; info!("starting {}", device.id); let bind_addr = "0.0.0.0:0"; let _ = send_packet( - &bind_addr.parse().map_err(Error::IpParse)?, - &device.broadcast_addr.parse().map_err(Error::IpParse)?, + bind_addr, + &device.broadcast_addr, &create_buffer(&device.mac)?, )?; let dev_id = device.id.clone(); diff --git a/src/wol.rs b/src/wol.rs index 83c0ee6..31cf350 100644 --- a/src/wol.rs +++ b/src/wol.rs @@ -1,4 +1,4 @@ -use std::net::{SocketAddr, UdpSocket}; +use std::net::{ToSocketAddrs, UdpSocket}; use crate::error::Error; @@ -11,8 +11,8 @@ pub fn create_buffer(mac_addr: &str) -> Result, Error> { let mut mac = Vec::new(); let sp = mac_addr.split(':'); for f in sp { - mac.push(u8::from_str_radix(f, 16).map_err(Error::BufferParse)?); - }; + mac.push(u8::from_str_radix(f, 16)?); + } let mut buf = vec![255; 6]; for _ in 0..16 { for i in &mac { @@ -23,8 +23,12 @@ pub fn create_buffer(mac_addr: &str) -> Result, Error> { } /// Sends a buffer on UDP broadcast -pub fn send_packet(bind_addr: &SocketAddr, broadcast_addr: &SocketAddr, buffer: &[u8]) -> Result { - let socket = UdpSocket::bind(bind_addr).map_err(Error::Broadcast)?; - socket.set_broadcast(true).map_err(Error::Broadcast)?; - socket.send_to(buffer, broadcast_addr).map_err(Error::Broadcast) +pub fn send_packet( + bind_addr: A, + broadcast_addr: A, + buffer: &[u8], +) -> Result { + let socket = UdpSocket::bind(bind_addr)?; + socket.set_broadcast(true)?; + Ok(socket.send_to(buffer, broadcast_addr)?) } -- cgit v1.2.3