From 2f9f18b80a9e2134f674f345e48a5f21de5efadd Mon Sep 17 00:00:00 2001 From: FxQnLr Date: Sun, 18 Feb 2024 21:16:46 +0100 Subject: Refactor stuff. Use Postgres Types --- ...f32262fd643b452aacca373ee527c978e816115de6.json | 8 +- ...f6b4f7ee6dae2130c2109fb6f1e47e0990ec395744.json | 4 +- ...b8cf06394a302d288e91f5eedde65db6630021f737.json | 4 +- Cargo.lock | 50 ++++++- Cargo.toml | 4 +- migrations/20231009123228_devices.sql | 4 +- src/db.rs | 6 +- src/error.rs | 22 +++ src/main.rs | 54 ++++---- src/routes.rs | 3 + src/routes/device.rs | 19 ++- src/routes/mod.rs | 3 - src/routes/start.rs | 78 ++++++----- src/routes/status.rs | 79 ++++++++++- src/services.rs | 1 + src/services/mod.rs | 1 - src/services/ping.rs | 154 +++++++++------------ 17 files changed, 317 insertions(+), 177 deletions(-) create mode 100644 src/routes.rs delete mode 100644 src/routes/mod.rs create mode 100644 src/services.rs delete mode 100644 src/services/mod.rs diff --git a/.sqlx/query-1dc5f44967ffdee882f4cef32262fd643b452aacca373ee527c978e816115de6.json b/.sqlx/query-1dc5f44967ffdee882f4cef32262fd643b452aacca373ee527c978e816115de6.json index 33d524d..dd85eaa 100644 --- a/.sqlx/query-1dc5f44967ffdee882f4cef32262fd643b452aacca373ee527c978e816115de6.json +++ b/.sqlx/query-1dc5f44967ffdee882f4cef32262fd643b452aacca373ee527c978e816115de6.json @@ -11,7 +11,7 @@ { "ordinal": 1, "name": "mac", - "type_info": "Varchar" + "type_info": "Macaddr" }, { "ordinal": 2, @@ -21,7 +21,7 @@ { "ordinal": 3, "name": "ip", - "type_info": "Varchar" + "type_info": "Inet" }, { "ordinal": 4, @@ -31,9 +31,9 @@ ], "parameters": { "Left": [ + "Macaddr", "Varchar", - "Varchar", - "Varchar", + "Inet", "Text" ] }, diff --git a/.sqlx/query-62c84231c7e9c85dc91d71f6b4f7ee6dae2130c2109fb6f1e47e0990ec395744.json b/.sqlx/query-62c84231c7e9c85dc91d71f6b4f7ee6dae2130c2109fb6f1e47e0990ec395744.json index 5ec47e3..905bb51 100644 --- a/.sqlx/query-62c84231c7e9c85dc91d71f6b4f7ee6dae2130c2109fb6f1e47e0990ec395744.json +++ b/.sqlx/query-62c84231c7e9c85dc91d71f6b4f7ee6dae2130c2109fb6f1e47e0990ec395744.json @@ -11,7 +11,7 @@ { "ordinal": 1, "name": "mac", - "type_info": "Varchar" + "type_info": "Macaddr" }, { "ordinal": 2, @@ -21,7 +21,7 @@ { "ordinal": 3, "name": "ip", - "type_info": "Varchar" + "type_info": "Inet" }, { "ordinal": 4, diff --git a/.sqlx/query-adead45e1a6b02d5eabd68b8cf06394a302d288e91f5eedde65db6630021f737.json b/.sqlx/query-adead45e1a6b02d5eabd68b8cf06394a302d288e91f5eedde65db6630021f737.json index bc4bdd3..d25b12e 100644 --- a/.sqlx/query-adead45e1a6b02d5eabd68b8cf06394a302d288e91f5eedde65db6630021f737.json +++ b/.sqlx/query-adead45e1a6b02d5eabd68b8cf06394a302d288e91f5eedde65db6630021f737.json @@ -6,9 +6,9 @@ "parameters": { "Left": [ "Varchar", + "Macaddr", "Varchar", - "Varchar", - "Varchar" + "Inet" ] }, "nullable": [] diff --git a/Cargo.lock b/Cargo.lock index 835335b..5d10375 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -71,7 +71,7 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "edcdbedc2236483ab103a53415653d6b4442ea6141baf1ffa85df29635e88436" dependencies = [ - "nix", + "nix 0.27.1", "rand", ] @@ -826,6 +826,15 @@ dependencies = [ "hashbrown 0.14.3", ] +[[package]] +name = "ipnetwork" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf466541e9d546596ee94f9f69590f89473455f88372423e0008fc1a7daf100e" +dependencies = [ + "serde", +] + [[package]] name = "itertools" version = "0.12.1" @@ -912,6 +921,17 @@ version = "0.4.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" +[[package]] +name = "mac_address" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4863ee94f19ed315bf3bc00299338d857d4b5bc856af375cc97d237382ad3856" +dependencies = [ + "nix 0.23.2", + "serde", + "winapi", +] + [[package]] name = "matchers" version = "0.1.0" @@ -943,6 +963,15 @@ version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" +[[package]] +name = "memoffset" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aa361d4faea93603064a027415f07bd8e1d5c88c9fbf68bf56a285428fd79ce" +dependencies = [ + "autocfg", +] + [[package]] name = "mime" version = "0.3.17" @@ -975,6 +1004,19 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "nix" +version = "0.23.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f3790c00a0150112de0f4cd161e3d7fc4b2d8a5542ffc35f099a2562aecb35c" +dependencies = [ + "bitflags 1.3.2", + "cc", + "cfg-if", + "libc", + "memoffset", +] + [[package]] name = "nix" version = "0.27.1" @@ -1709,7 +1751,9 @@ dependencies = [ "hashlink", "hex", "indexmap", + "ipnetwork", "log", + "mac_address", "memchr", "once_cell", "paste", @@ -1829,8 +1873,10 @@ dependencies = [ "hkdf", "hmac", "home", + "ipnetwork", "itoa", "log", + "mac_address", "md-5", "memchr", "once_cell", @@ -2360,6 +2406,8 @@ dependencies = [ "color-eyre", "config", "dashmap", + "ipnetwork", + "mac_address", "serde", "serde_json", "sqlx", diff --git a/Cargo.toml b/Cargo.toml index f4633c9..c320da1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,10 +14,12 @@ time = { version = "0.3", features = ["macros"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" config = "0.14" -sqlx = { version = "0.7", features = ["postgres", "runtime-tokio"]} +sqlx = { version = "0.7", features = ["postgres", "runtime-tokio", "ipnetwork", "mac_address"]} surge-ping = "0.8" axum-macros = "0.4" uuid = { version = "1.6", features = ["v4", "fast-rng"] } dashmap = "5.5" color-eyre = "0.6" thiserror = "1.0" +ipnetwork = "0.20.0" +mac_address = { version = "1.1.5", features = ["serde"] } diff --git a/migrations/20231009123228_devices.sql b/migrations/20231009123228_devices.sql index d36946c..6983ada 100644 --- a/migrations/20231009123228_devices.sql +++ b/migrations/20231009123228_devices.sql @@ -2,8 +2,8 @@ CREATE TABLE IF NOT EXISTS "devices" ( "id" VARCHAR(255) PRIMARY KEY NOT NULL, - "mac" VARCHAR(17) NOT NULL, + "mac" MACADDR NOT NULL, "broadcast_addr" VARCHAR(39) NOT NULL, - "ip" VARCHAR(39) NOT NULL, + "ip" INET NOT NULL, "times" BIGINT[] ) diff --git a/src/db.rs b/src/db.rs index 489a000..47e907d 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,13 +1,13 @@ use serde::Serialize; -use sqlx::{PgPool, postgres::PgPoolOptions}; +use sqlx::{PgPool, postgres::PgPoolOptions, types::{ipnetwork::IpNetwork, mac_address::MacAddress}}; use tracing::{debug, info}; #[derive(Serialize, Debug)] pub struct Device { pub id: String, - pub mac: String, + pub mac: MacAddress, pub broadcast_addr: String, - pub ip: String, + pub ip: IpNetwork, pub times: Option> } diff --git a/src/error.rs b/src/error.rs index 63b214e..66a61f4 100644 --- a/src/error.rs +++ b/src/error.rs @@ -2,6 +2,8 @@ use axum::http::header::ToStrError; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; use axum::Json; +use ::ipnetwork::IpNetworkError; +use mac_address::MacParseError; use serde_json::json; use std::io; use tracing::error; @@ -29,6 +31,18 @@ pub enum Error { source: ToStrError, }, + #[error("string parse: {source}")] + IpParse { + #[from] + source: IpNetworkError, + }, + + #[error("mac parse: {source}")] + MacParse { + #[from] + source: MacParseError, + }, + #[error("io: {source}")] Io { #[from] @@ -57,6 +71,14 @@ impl IntoResponse for Error { error!("{source}"); (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") } + Self::MacParse { source } => { + error!("{source}"); + (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") + } + Self::IpParse { source } => { + error!("{source}"); + (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") + } }; let body = Json(json!({ "error": error_message, diff --git a/src/main.rs b/src/main.rs index 4ef129b..7d8c1da 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,42 +1,44 @@ -use std::env; -use std::sync::Arc; -use axum::{Router, routing::post}; -use axum::routing::{get, put}; -use dashmap::DashMap; -use sqlx::PgPool; -use time::util::local_offset; -use tokio::sync::broadcast::{channel, Sender}; -use tracing::{info, level_filters::LevelFilter}; -use tracing_subscriber::{EnvFilter, fmt::{self, time::LocalTime}, prelude::*}; use crate::config::Config; use crate::db::init_db_pool; use crate::routes::device; use crate::routes::start::start; use crate::routes::status::status; -use crate::services::ping::{BroadcastCommands, StatusMap}; +use crate::services::ping::StatusMap; +use axum::routing::{get, put}; +use axum::{routing::post, Router}; +use dashmap::DashMap; +use services::ping::BroadcastCommand; +use sqlx::PgPool; +use tracing_subscriber::fmt::time::UtcTime; +use std::env; +use std::sync::Arc; +use tokio::sync::broadcast::{channel, Sender}; +use tracing::{info, level_filters::LevelFilter}; +use tracing_subscriber::{ + fmt, + prelude::*, + EnvFilter, +}; mod auth; mod config; -mod routes; -mod wol; mod db; mod error; +mod routes; mod services; +mod wol; #[tokio::main] async fn main() -> color_eyre::eyre::Result<()> { - color_eyre::install()?; + - unsafe { local_offset::set_soundness(local_offset::Soundness::Unsound); } let time_format = time::macros::format_description!("[year]-[month]-[day] [hour]:[minute]:[second]"); - let loc = LocalTime::new(time_format); + let loc = UtcTime::new(time_format); tracing_subscriber::registry() - .with(fmt::layer() - .with_timer(loc) - ) + .with(fmt::layer().with_timer(loc)) .with( EnvFilter::builder() .with_default_directive(LevelFilter::INFO.into()) @@ -56,8 +58,13 @@ async fn main() -> color_eyre::eyre::Result<()> { let (tx, _) = channel(32); let ping_map: StatusMap = DashMap::new(); - - let shared_state = Arc::new(AppState { db, config: config.clone(), ping_send: tx, ping_map }); + + let shared_state = Arc::new(AppState { + db, + config: config.clone(), + ping_send: tx, + ping_map, + }); let app = Router::new() .route("/start", post(start)) @@ -69,8 +76,7 @@ async fn main() -> color_eyre::eyre::Result<()> { let addr = config.serveraddr; info!("start server on {}", addr); - let listener = tokio::net::TcpListener::bind(addr) - .await?; + let listener = tokio::net::TcpListener::bind(addr).await?; axum::serve(listener, app).await?; Ok(()) @@ -79,6 +85,6 @@ async fn main() -> color_eyre::eyre::Result<()> { pub struct AppState { db: PgPool, config: Config, - ping_send: Sender, + ping_send: Sender, ping_map: StatusMap, } diff --git a/src/routes.rs b/src/routes.rs new file mode 100644 index 0000000..d5ab0d6 --- /dev/null +++ b/src/routes.rs @@ -0,0 +1,3 @@ +pub mod start; +pub mod device; +pub mod status; \ No newline at end of file diff --git a/src/routes/device.rs b/src/routes/device.rs index 5ca574a..2f0093d 100644 --- a/src/routes/device.rs +++ b/src/routes/device.rs @@ -4,9 +4,11 @@ use crate::error::Error; use axum::extract::State; use axum::http::HeaderMap; use axum::Json; +use mac_address::MacAddress; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; -use std::sync::Arc; +use sqlx::types::ipnetwork::IpNetwork; +use std::{sync::Arc, str::FromStr}; use tracing::{debug, info}; pub async fn get( @@ -14,7 +16,7 @@ pub async fn get( headers: HeaderMap, Json(payload): Json, ) -> Result, Error> { - info!("add device {}", payload.id); + info!("get device {}", payload.id); let secret = headers.get("authorization"); let authorized = matches!(auth(&state.config, secret)?, crate::auth::Response::Success); if authorized { @@ -52,18 +54,21 @@ pub async fn put( "add device {} ({}, {}, {})", payload.id, payload.mac, payload.broadcast_addr, payload.ip ); + let secret = headers.get("authorization"); let authorized = matches!(auth(&state.config, secret)?, crate::auth::Response::Success); if authorized { + let ip = IpNetwork::from_str(&payload.ip)?; + let mac = MacAddress::from_str(&payload.mac)?; sqlx::query!( r#" INSERT INTO devices (id, mac, broadcast_addr, ip) VALUES ($1, $2, $3, $4); "#, payload.id, - payload.mac, + mac, payload.broadcast_addr, - payload.ip + ip ) .execute(&state.db) .await?; @@ -99,6 +104,8 @@ pub async fn post( let secret = headers.get("authorization"); let authorized = matches!(auth(&state.config, secret)?, crate::auth::Response::Success); if authorized { + let ip = IpNetwork::from_str(&payload.ip)?; + let mac = MacAddress::from_str(&payload.mac)?; let device = sqlx::query_as!( Device, r#" @@ -106,9 +113,9 @@ pub async fn post( SET mac = $1, broadcast_addr = $2, ip = $3 WHERE id = $4 RETURNING id, mac, broadcast_addr, ip, times; "#, - payload.mac, + mac, payload.broadcast_addr, - payload.ip, + ip, payload.id ) .fetch_one(&state.db) diff --git a/src/routes/mod.rs b/src/routes/mod.rs deleted file mode 100644 index d5ab0d6..0000000 --- a/src/routes/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod start; -pub mod device; -pub mod status; \ No newline at end of file diff --git a/src/routes/start.rs b/src/routes/start.rs index ec4f98f..4888325 100644 --- a/src/routes/start.rs +++ b/src/routes/start.rs @@ -12,7 +12,6 @@ use std::sync::Arc; use tracing::{debug, info}; use uuid::Uuid; -#[axum_macros::debug_handler] pub async fn start( State(state): State>, headers: HeaderMap, @@ -41,45 +40,11 @@ pub async fn start( let _ = send_packet( bind_addr, &device.broadcast_addr, - &create_buffer(&device.mac)?, + &create_buffer(&device.mac.to_string())?, )?; let dev_id = device.id.clone(); let uuid = if payload.ping.is_some_and(|ping| ping) { - let mut uuid: Option = None; - for (key, value) in state.ping_map.clone() { - if value.ip == device.ip { - debug!("service already exists"); - uuid = Some(key); - break; - } - } - let uuid_gen = match uuid { - Some(u) => u, - None => Uuid::new_v4().to_string(), - }; - let uuid_genc = uuid_gen.clone(); - - tokio::spawn(async move { - debug!("init ping service"); - state.ping_map.insert( - uuid_gen.clone(), - PingValue { - ip: device.ip.clone(), - online: false, - }, - ); - - crate::services::ping::spawn( - state.ping_send.clone(), - &state.config, - device, - uuid_gen.clone(), - &state.ping_map, - &state.db, - ) - .await; - }); - Some(uuid_genc) + Some(setup_ping(state, device)) } else { None }; @@ -93,6 +58,45 @@ pub async fn start( } } +fn setup_ping(state: Arc, device: Device) -> String { + let mut uuid: Option = None; + for (key, value) in state.ping_map.clone() { + if value.ip == device.ip { + debug!("service already exists"); + uuid = Some(key); + break; + } + } + let uuid_gen = match uuid { + Some(u) => u, + None => Uuid::new_v4().to_string(), + }; + let uuid_ret = uuid_gen.clone(); + + debug!("init ping service"); + state.ping_map.insert( + uuid_gen.clone(), + PingValue { + ip: device.ip, + online: false, + }, + ); + + tokio::spawn(async move { + crate::services::ping::spawn( + state.ping_send.clone(), + &state.config, + device, + uuid_gen, + &state.ping_map, + &state.db, + ) + .await; + }); + + uuid_ret +} + #[derive(Deserialize)] pub struct Payload { id: String, diff --git a/src/routes/status.rs b/src/routes/status.rs index 31ef996..0e25f7d 100644 --- a/src/routes/status.rs +++ b/src/routes/status.rs @@ -1,10 +1,79 @@ -use std::sync::Arc; +use crate::services::ping::BroadcastCommand; +use crate::AppState; +use axum::extract::ws::{Message, WebSocket}; use axum::extract::{State, WebSocketUpgrade}; use axum::response::Response; -use crate::AppState; -use crate::services::ping::status_websocket; +use sqlx::PgPool; +use std::sync::Arc; +use tracing::{debug, trace}; -#[axum_macros::debug_handler] pub async fn status(State(state): State>, ws: WebSocketUpgrade) -> Response { - ws.on_upgrade(move |socket| status_websocket(socket, state)) + ws.on_upgrade(move |socket| websocket(socket, state)) +} + +pub async fn websocket(mut socket: WebSocket, state: Arc) { + trace!("wait for ws message (uuid)"); + let msg = socket.recv().await; + let uuid = msg.unwrap().unwrap().into_text().unwrap(); + + trace!("Search for uuid: {}", uuid); + + let eta = get_eta(&state.db).await; + let _ = socket + .send(Message::Text(format!("eta_{eta}_{uuid}"))) + .await; + + let device_exists = state.ping_map.contains_key(&uuid); + if device_exists { + let _ = socket + .send(receive_ping_broadcast(state.clone(), uuid).await) + .await; + } else { + debug!("didn't find any device"); + let _ = socket.send(Message::Text(format!("notfound_{uuid}"))).await; + }; + + let _ = socket.close().await; +} + +async fn receive_ping_broadcast(state: Arc, uuid: String) -> Message { + let pm = state.ping_map.clone().into_read_only(); + let device = pm.get(&uuid).expect("fatal error"); + debug!("got device: {} (online: {})", device.ip, device.online); + if device.online { + debug!("already started"); + Message::Text(BroadcastCommand::success(uuid).to_string()) + } else { + loop { + trace!("wait for tx message"); + let message = state + .ping_send + .subscribe() + .recv() + .await + .expect("fatal error"); + trace!("got message {:?}", message); + + if message.uuid != uuid { + continue; + } + trace!("message == uuid success"); + return Message::Text(message.to_string()); + } + } +} + +async fn get_eta(db: &PgPool) -> i64 { + let query = sqlx::query!(r#"SELECT times FROM devices;"#) + .fetch_one(db) + .await + .unwrap(); + + let times = if let Some(times) = query.times { + times + } else { + vec![0] + }; + + times.iter().sum::() / i64::try_from(times.len()).unwrap() } diff --git a/src/services.rs b/src/services.rs new file mode 100644 index 0000000..a766209 --- /dev/null +++ b/src/services.rs @@ -0,0 +1 @@ +pub mod ping; diff --git a/src/services/mod.rs b/src/services/mod.rs deleted file mode 100644 index a766209..0000000 --- a/src/services/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod ping; diff --git a/src/services/ping.rs b/src/services/ping.rs index 9b164c8..9191f86 100644 --- a/src/services/ping.rs +++ b/src/services/ping.rs @@ -1,59 +1,58 @@ -use std::str::FromStr; -use std::net::IpAddr; -use std::sync::Arc; - -use axum::extract::ws::WebSocket; -use axum::extract::ws::Message; +use crate::config::Config; +use crate::db::Device; use dashmap::DashMap; +use ipnetwork::IpNetwork; use sqlx::PgPool; +use std::fmt::Display; use time::{Duration, Instant}; use tokio::sync::broadcast::Sender; use tracing::{debug, error, trace}; -use crate::AppState; -use crate::config::Config; -use crate::db::Device; pub type StatusMap = DashMap; #[derive(Debug, Clone)] pub struct Value { - pub ip: String, - pub online: bool + pub ip: IpNetwork, + pub online: bool, } -pub async fn spawn(tx: Sender, config: &Config, device: Device, uuid: String, ping_map: &StatusMap, db: &PgPool) { +pub async fn spawn( + tx: Sender, + config: &Config, + device: Device, + uuid: String, + ping_map: &StatusMap, + db: &PgPool, +) { let timer = Instant::now(); let payload = [0; 8]; - let ping_ip = IpAddr::from_str(&device.ip).expect("bad ip"); - - let mut msg: Option = None; + let mut msg: Option = None; while msg.is_none() { - let ping = surge_ping::ping( - ping_ip, - &payload - ).await; + let ping = surge_ping::ping(device.ip.ip(), &payload).await; if let Err(ping) = ping { let ping_timeout = matches!(ping, surge_ping::SurgeError::Timeout { .. }); if !ping_timeout { error!("{}", ping.to_string()); - msg = Some(BroadcastCommands::Error(uuid.clone())); + msg = Some(BroadcastCommand::error(uuid.clone())); } if timer.elapsed() >= Duration::minutes(config.pingtimeout) { - msg = Some(BroadcastCommands::Timeout(uuid.clone())); + msg = Some(BroadcastCommand::timeout(uuid.clone())); } } else { - let (_, duration) = ping.map_err(|err| error!("{}", err.to_string())).expect("fatal error"); + let (_, duration) = ping + .map_err(|err| error!("{}", err.to_string())) + .expect("fatal error"); debug!("ping took {:?}", duration); - msg = Some(BroadcastCommands::Success(uuid.clone())); + msg = Some(BroadcastCommand::success(uuid.clone())); }; } let msg = msg.expect("fatal error"); let _ = tx.send(msg.clone()); - if let BroadcastCommands::Success(..) = msg { + if let BroadcastCommands::Success = msg.command { sqlx::query!( r#" UPDATE devices @@ -62,8 +61,17 @@ pub async fn spawn(tx: Sender, config: &Config, device: Devic "#, timer.elapsed().whole_seconds(), device.id - ).execute(db).await.unwrap(); - ping_map.insert(uuid.clone(), Value { ip: device.ip.clone(), online: true }); + ) + .execute(db) + .await + .unwrap(); + ping_map.insert( + uuid.clone(), + Value { + ip: device.ip, + online: true, + }, + ); tokio::time::sleep(tokio::time::Duration::from_secs(60)).await; } trace!("remove {} from ping_map", uuid); @@ -72,74 +80,48 @@ pub async fn spawn(tx: Sender, config: &Config, device: Devic #[derive(Clone, Debug, PartialEq)] pub enum BroadcastCommands { - Success(String), - Timeout(String), - Error(String), + Success, + Timeout, + Error, } -pub async fn status_websocket(mut socket: WebSocket, state: Arc) { - trace!("wait for ws message (uuid)"); - let msg = socket.recv().await; - let uuid = msg.unwrap().unwrap().into_text().unwrap(); - - trace!("Search for uuid: {}", uuid); - - let eta = get_eta(&state.db).await; - let _ = socket.send(Message::Text(format!("eta_{eta}_{uuid}"))).await; +#[derive(Clone, Debug, PartialEq)] +pub struct BroadcastCommand { + pub uuid: String, + pub command: BroadcastCommands, +} - let device_exists = state.ping_map.contains_key(&uuid); - if device_exists { - let _ = socket.send(process_device(state.clone(), uuid).await).await; - } else { - debug!("didn't find any device"); - let _ = socket.send(Message::Text(format!("notfound_{uuid}"))).await; - }; +impl Display for BroadcastCommand { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let prefix = match self.command { + BroadcastCommands::Success => "start", + BroadcastCommands::Timeout => "timeout", + BroadcastCommands::Error => "error", + }; - let _ = socket.close().await; + f.write_str(format!("{prefix}_{}", self.uuid).as_str()) + } } -async fn get_eta(db: &PgPool) -> i64 { - let query = sqlx::query!( - r#"SELECT times FROM devices;"# - ).fetch_one(db).await.unwrap(); - - let times = match query.times { - None => { vec![0] }, - Some(t) => t, - }; - times.iter().sum::() / i64::try_from(times.len()).unwrap() +impl BroadcastCommand { + pub fn success(uuid: String) -> Self { + Self { + uuid, + command: BroadcastCommands::Success, + } + } -} + pub fn timeout(uuid: String) -> Self { + Self { + uuid, + command: BroadcastCommands::Timeout, + } + } -async fn process_device(state: Arc, uuid: String) -> Message { - let pm = state.ping_map.clone().into_read_only(); - let device = pm.get(&uuid).expect("fatal error"); - debug!("got device: {} (online: {})", device.ip, device.online); - if device.online { - debug!("already started"); - Message::Text(format!("start_{uuid}")) - } else { - loop { - trace!("wait for tx message"); - let message = state.ping_send.subscribe().recv().await.expect("fatal error"); - trace!("got message {:?}", message); - return match message { - BroadcastCommands::Success(msg_uuid) => { - if msg_uuid != uuid { continue; } - trace!("message == uuid success"); - Message::Text(format!("start_{uuid}")) - }, - BroadcastCommands::Timeout(msg_uuid) => { - if msg_uuid != uuid { continue; } - trace!("message == uuid timeout"); - Message::Text(format!("timeout_{uuid}")) - }, - BroadcastCommands::Error(msg_uuid) => { - if msg_uuid != uuid { continue; } - trace!("message == uuid error"); - Message::Text(format!("error_{uuid}")) - } - } + pub fn error(uuid: String) -> Self { + Self { + uuid, + command: BroadcastCommands::Error, } } } -- cgit v1.2.3