From 2f9f18b80a9e2134f674f345e48a5f21de5efadd Mon Sep 17 00:00:00 2001 From: FxQnLr Date: Sun, 18 Feb 2024 21:16:46 +0100 Subject: Refactor stuff. Use Postgres Types --- src/routes/device.rs | 19 +++++++++---- src/routes/mod.rs | 3 -- src/routes/start.rs | 78 +++++++++++++++++++++++++++------------------------ src/routes/status.rs | 79 ++++++++++++++++++++++++++++++++++++++++++++++++---- 4 files changed, 128 insertions(+), 51 deletions(-) delete mode 100644 src/routes/mod.rs (limited to 'src/routes') diff --git a/src/routes/device.rs b/src/routes/device.rs index 5ca574a..2f0093d 100644 --- a/src/routes/device.rs +++ b/src/routes/device.rs @@ -4,9 +4,11 @@ use crate::error::Error; use axum::extract::State; use axum::http::HeaderMap; use axum::Json; +use mac_address::MacAddress; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; -use std::sync::Arc; +use sqlx::types::ipnetwork::IpNetwork; +use std::{sync::Arc, str::FromStr}; use tracing::{debug, info}; pub async fn get( @@ -14,7 +16,7 @@ pub async fn get( headers: HeaderMap, Json(payload): Json, ) -> Result, Error> { - info!("add device {}", payload.id); + info!("get device {}", payload.id); let secret = headers.get("authorization"); let authorized = matches!(auth(&state.config, secret)?, crate::auth::Response::Success); if authorized { @@ -52,18 +54,21 @@ pub async fn put( "add device {} ({}, {}, {})", payload.id, payload.mac, payload.broadcast_addr, payload.ip ); + let secret = headers.get("authorization"); let authorized = matches!(auth(&state.config, secret)?, crate::auth::Response::Success); if authorized { + let ip = IpNetwork::from_str(&payload.ip)?; + let mac = MacAddress::from_str(&payload.mac)?; sqlx::query!( r#" INSERT INTO devices (id, mac, broadcast_addr, ip) VALUES ($1, $2, $3, $4); "#, payload.id, - payload.mac, + mac, payload.broadcast_addr, - payload.ip + ip ) .execute(&state.db) .await?; @@ -99,6 +104,8 @@ pub async fn post( let secret = headers.get("authorization"); let authorized = matches!(auth(&state.config, secret)?, crate::auth::Response::Success); if authorized { + let ip = IpNetwork::from_str(&payload.ip)?; + let mac = MacAddress::from_str(&payload.mac)?; let device = sqlx::query_as!( Device, r#" @@ -106,9 +113,9 @@ pub async fn post( SET mac = $1, broadcast_addr = $2, ip = $3 WHERE id = $4 RETURNING id, mac, broadcast_addr, ip, times; "#, - payload.mac, + mac, payload.broadcast_addr, - payload.ip, + ip, payload.id ) .fetch_one(&state.db) diff --git a/src/routes/mod.rs b/src/routes/mod.rs deleted file mode 100644 index d5ab0d6..0000000 --- a/src/routes/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod start; -pub mod device; -pub mod status; \ No newline at end of file diff --git a/src/routes/start.rs b/src/routes/start.rs index ec4f98f..4888325 100644 --- a/src/routes/start.rs +++ b/src/routes/start.rs @@ -12,7 +12,6 @@ use std::sync::Arc; use tracing::{debug, info}; use uuid::Uuid; -#[axum_macros::debug_handler] pub async fn start( State(state): State>, headers: HeaderMap, @@ -41,45 +40,11 @@ pub async fn start( let _ = send_packet( bind_addr, &device.broadcast_addr, - &create_buffer(&device.mac)?, + &create_buffer(&device.mac.to_string())?, )?; let dev_id = device.id.clone(); let uuid = if payload.ping.is_some_and(|ping| ping) { - let mut uuid: Option = None; - for (key, value) in state.ping_map.clone() { - if value.ip == device.ip { - debug!("service already exists"); - uuid = Some(key); - break; - } - } - let uuid_gen = match uuid { - Some(u) => u, - None => Uuid::new_v4().to_string(), - }; - let uuid_genc = uuid_gen.clone(); - - tokio::spawn(async move { - debug!("init ping service"); - state.ping_map.insert( - uuid_gen.clone(), - PingValue { - ip: device.ip.clone(), - online: false, - }, - ); - - crate::services::ping::spawn( - state.ping_send.clone(), - &state.config, - device, - uuid_gen.clone(), - &state.ping_map, - &state.db, - ) - .await; - }); - Some(uuid_genc) + Some(setup_ping(state, device)) } else { None }; @@ -93,6 +58,45 @@ pub async fn start( } } +fn setup_ping(state: Arc, device: Device) -> String { + let mut uuid: Option = None; + for (key, value) in state.ping_map.clone() { + if value.ip == device.ip { + debug!("service already exists"); + uuid = Some(key); + break; + } + } + let uuid_gen = match uuid { + Some(u) => u, + None => Uuid::new_v4().to_string(), + }; + let uuid_ret = uuid_gen.clone(); + + debug!("init ping service"); + state.ping_map.insert( + uuid_gen.clone(), + PingValue { + ip: device.ip, + online: false, + }, + ); + + tokio::spawn(async move { + crate::services::ping::spawn( + state.ping_send.clone(), + &state.config, + device, + uuid_gen, + &state.ping_map, + &state.db, + ) + .await; + }); + + uuid_ret +} + #[derive(Deserialize)] pub struct Payload { id: String, diff --git a/src/routes/status.rs b/src/routes/status.rs index 31ef996..0e25f7d 100644 --- a/src/routes/status.rs +++ b/src/routes/status.rs @@ -1,10 +1,79 @@ -use std::sync::Arc; +use crate::services::ping::BroadcastCommand; +use crate::AppState; +use axum::extract::ws::{Message, WebSocket}; use axum::extract::{State, WebSocketUpgrade}; use axum::response::Response; -use crate::AppState; -use crate::services::ping::status_websocket; +use sqlx::PgPool; +use std::sync::Arc; +use tracing::{debug, trace}; -#[axum_macros::debug_handler] pub async fn status(State(state): State>, ws: WebSocketUpgrade) -> Response { - ws.on_upgrade(move |socket| status_websocket(socket, state)) + ws.on_upgrade(move |socket| websocket(socket, state)) +} + +pub async fn websocket(mut socket: WebSocket, state: Arc) { + trace!("wait for ws message (uuid)"); + let msg = socket.recv().await; + let uuid = msg.unwrap().unwrap().into_text().unwrap(); + + trace!("Search for uuid: {}", uuid); + + let eta = get_eta(&state.db).await; + let _ = socket + .send(Message::Text(format!("eta_{eta}_{uuid}"))) + .await; + + let device_exists = state.ping_map.contains_key(&uuid); + if device_exists { + let _ = socket + .send(receive_ping_broadcast(state.clone(), uuid).await) + .await; + } else { + debug!("didn't find any device"); + let _ = socket.send(Message::Text(format!("notfound_{uuid}"))).await; + }; + + let _ = socket.close().await; +} + +async fn receive_ping_broadcast(state: Arc, uuid: String) -> Message { + let pm = state.ping_map.clone().into_read_only(); + let device = pm.get(&uuid).expect("fatal error"); + debug!("got device: {} (online: {})", device.ip, device.online); + if device.online { + debug!("already started"); + Message::Text(BroadcastCommand::success(uuid).to_string()) + } else { + loop { + trace!("wait for tx message"); + let message = state + .ping_send + .subscribe() + .recv() + .await + .expect("fatal error"); + trace!("got message {:?}", message); + + if message.uuid != uuid { + continue; + } + trace!("message == uuid success"); + return Message::Text(message.to_string()); + } + } +} + +async fn get_eta(db: &PgPool) -> i64 { + let query = sqlx::query!(r#"SELECT times FROM devices;"#) + .fetch_one(db) + .await + .unwrap(); + + let times = if let Some(times) = query.times { + times + } else { + vec![0] + }; + + times.iter().sum::() / i64::try_from(times.len()).unwrap() } -- cgit v1.2.3