diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/auth.rs | 24 | ||||
-rw-r--r-- | src/config.rs | 32 | ||||
-rw-r--r-- | src/db.rs | 16 | ||||
-rw-r--r-- | src/error.rs | 8 | ||||
-rw-r--r-- | src/main.rs | 41 | ||||
-rw-r--r-- | src/routes/device.rs | 28 | ||||
-rw-r--r-- | src/routes/start.rs | 76 | ||||
-rw-r--r-- | src/routes/status.rs | 2 | ||||
-rw-r--r-- | src/services/ping.rs | 84 | ||||
-rw-r--r-- | src/wol.rs | 14 |
10 files changed, 172 insertions, 153 deletions
diff --git a/src/auth.rs b/src/auth.rs index e4b1c2f..feca652 100644 --- a/src/auth.rs +++ b/src/auth.rs | |||
@@ -1,18 +1,15 @@ | |||
1 | use axum::headers::HeaderValue; | 1 | use axum::http::{StatusCode, HeaderValue}; |
2 | use axum::http::StatusCode; | ||
3 | use axum::http::header::ToStrError; | 2 | use axum::http::header::ToStrError; |
4 | use tracing::{debug, error, trace}; | 3 | use tracing::{debug, error, trace}; |
5 | use crate::auth::AuthError::{MissingSecret, WrongSecret}; | 4 | use crate::auth::Error::{MissingSecret, WrongSecret}; |
6 | use crate::config::SETTINGS; | 5 | use crate::config::Config; |
7 | 6 | ||
8 | pub fn auth(secret: Option<&HeaderValue>) -> Result<bool, AuthError> { | 7 | pub fn auth(config: &Config, secret: Option<&HeaderValue>) -> Result<bool, Error> { |
9 | debug!("auth request with secret {:?}", secret); | 8 | debug!("auth request with secret {:?}", secret); |
10 | if let Some(value) = secret { | 9 | if let Some(value) = secret { |
11 | trace!("value exists"); | 10 | trace!("value exists"); |
12 | let key = SETTINGS | 11 | let key = &config.apikey; |
13 | .get_string("apikey") | 12 | if value.to_str().map_err(Error::HeaderToStr)? == key.as_str() { |
14 | .map_err(AuthError::Config)?; | ||
15 | if value.to_str().map_err(AuthError::HeaderToStr)? == key.as_str() { | ||
16 | debug!("successful auth"); | 13 | debug!("successful auth"); |
17 | Ok(true) | 14 | Ok(true) |
18 | } else { | 15 | } else { |
@@ -26,22 +23,17 @@ pub fn auth(secret: Option<&HeaderValue>) -> Result<bool, AuthError> { | |||
26 | } | 23 | } |
27 | 24 | ||
28 | #[derive(Debug)] | 25 | #[derive(Debug)] |
29 | pub enum AuthError { | 26 | pub enum Error { |
30 | WrongSecret, | 27 | WrongSecret, |
31 | MissingSecret, | 28 | MissingSecret, |
32 | Config(config::ConfigError), | ||
33 | HeaderToStr(ToStrError) | 29 | HeaderToStr(ToStrError) |
34 | } | 30 | } |
35 | 31 | ||
36 | impl AuthError { | 32 | impl Error { |
37 | pub fn get(self) -> (StatusCode, &'static str) { | 33 | pub fn get(self) -> (StatusCode, &'static str) { |
38 | match self { | 34 | match self { |
39 | Self::WrongSecret => (StatusCode::UNAUTHORIZED, "Wrong credentials"), | 35 | Self::WrongSecret => (StatusCode::UNAUTHORIZED, "Wrong credentials"), |
40 | Self::MissingSecret => (StatusCode::BAD_REQUEST, "Missing credentials"), | 36 | Self::MissingSecret => (StatusCode::BAD_REQUEST, "Missing credentials"), |
41 | Self::Config(err) => { | ||
42 | error!("server error: {}", err.to_string()); | ||
43 | (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") | ||
44 | }, | ||
45 | Self::HeaderToStr(err) => { | 37 | Self::HeaderToStr(err) => { |
46 | error!("server error: {}", err.to_string()); | 38 | error!("server error: {}", err.to_string()); |
47 | (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") | 39 | (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") |
diff --git a/src/config.rs b/src/config.rs index 4c79810..4319ffc 100644 --- a/src/config.rs +++ b/src/config.rs | |||
@@ -1,11 +1,25 @@ | |||
1 | use config::Config; | 1 | use config::File; |
2 | use once_cell::sync::Lazy; | 2 | use serde::Deserialize; |
3 | 3 | ||
4 | pub static SETTINGS: Lazy<Config> = Lazy::new(setup); | 4 | #[derive(Debug, Clone, Deserialize)] |
5 | pub struct Config { | ||
6 | pub database_url: String, | ||
7 | pub apikey: String, | ||
8 | pub serveraddr: String, | ||
9 | pub pingtimeout: i64, | ||
10 | } | ||
11 | |||
12 | impl Config { | ||
13 | pub fn load() -> Result<Self, config::ConfigError> { | ||
14 | let config = config::Config::builder() | ||
15 | .set_default("serveraddr", "0.0.0.0:7229")? | ||
16 | .set_default("pingtimeout", 10)? | ||
17 | .add_source(File::with_name("config.toml").required(false)) | ||
18 | .add_source(File::with_name("config.dev.toml").required(false)) | ||
19 | .add_source(config::Environment::with_prefix("WEBOL").prefix_separator("_")) | ||
20 | .build()?; | ||
21 | |||
22 | config.try_deserialize() | ||
23 | } | ||
24 | } | ||
5 | 25 | ||
6 | fn setup() -> Config { | ||
7 | Config::builder() | ||
8 | .add_source(config::Environment::with_prefix("WEBOL").separator("_")) | ||
9 | .build() | ||
10 | .unwrap() | ||
11 | } \ No newline at end of file | ||
@@ -1,13 +1,7 @@ | |||
1 | #[cfg(debug_assertions)] | ||
2 | use std::env; | ||
3 | |||
4 | use serde::Serialize; | 1 | use serde::Serialize; |
5 | use sqlx::{PgPool, postgres::PgPoolOptions}; | 2 | use sqlx::{PgPool, postgres::PgPoolOptions}; |
6 | use tracing::{debug, info}; | 3 | use tracing::{debug, info}; |
7 | 4 | ||
8 | #[cfg(not(debug_assertions))] | ||
9 | use crate::config::SETTINGS; | ||
10 | |||
11 | #[derive(Serialize, Debug)] | 5 | #[derive(Serialize, Debug)] |
12 | pub struct Device { | 6 | pub struct Device { |
13 | pub id: String, | 7 | pub id: String, |
@@ -17,18 +11,12 @@ pub struct Device { | |||
17 | pub times: Option<Vec<i64>> | 11 | pub times: Option<Vec<i64>> |
18 | } | 12 | } |
19 | 13 | ||
20 | pub async fn init_db_pool() -> PgPool { | 14 | pub async fn init_db_pool(db_url: &str) -> PgPool { |
21 | #[cfg(not(debug_assertions))] | ||
22 | let db_url = SETTINGS.get_string("database.url").unwrap(); | ||
23 | |||
24 | #[cfg(debug_assertions)] | ||
25 | let db_url = env::var("DATABASE_URL").unwrap(); | ||
26 | |||
27 | debug!("attempt to connect dbPool to '{}'", db_url); | 15 | debug!("attempt to connect dbPool to '{}'", db_url); |
28 | 16 | ||
29 | let pool = PgPoolOptions::new() | 17 | let pool = PgPoolOptions::new() |
30 | .max_connections(5) | 18 | .max_connections(5) |
31 | .connect(&db_url) | 19 | .connect(db_url) |
32 | .await | 20 | .await |
33 | .unwrap(); | 21 | .unwrap(); |
34 | 22 | ||
diff --git a/src/error.rs b/src/error.rs index 5b82534..56d6c52 100644 --- a/src/error.rs +++ b/src/error.rs | |||
@@ -4,10 +4,10 @@ use axum::Json; | |||
4 | use axum::response::{IntoResponse, Response}; | 4 | use axum::response::{IntoResponse, Response}; |
5 | use serde_json::json; | 5 | use serde_json::json; |
6 | use tracing::error; | 6 | use tracing::error; |
7 | use crate::auth::AuthError; | 7 | use crate::auth::Error as AuthError; |
8 | 8 | ||
9 | #[derive(Debug)] | 9 | #[derive(Debug)] |
10 | pub enum WebolError { | 10 | pub enum Error { |
11 | Generic, | 11 | Generic, |
12 | Auth(AuthError), | 12 | Auth(AuthError), |
13 | DB(sqlx::Error), | 13 | DB(sqlx::Error), |
@@ -16,7 +16,7 @@ pub enum WebolError { | |||
16 | Broadcast(io::Error), | 16 | Broadcast(io::Error), |
17 | } | 17 | } |
18 | 18 | ||
19 | impl IntoResponse for WebolError { | 19 | impl IntoResponse for Error { |
20 | fn into_response(self) -> Response { | 20 | fn into_response(self) -> Response { |
21 | let (status, error_message) = match self { | 21 | let (status, error_message) = match self { |
22 | Self::Auth(err) => { | 22 | Self::Auth(err) => { |
@@ -45,4 +45,4 @@ impl IntoResponse for WebolError { | |||
45 | })); | 45 | })); |
46 | (status, body).into_response() | 46 | (status, body).into_response() |
47 | } | 47 | } |
48 | } \ No newline at end of file | 48 | } |
diff --git a/src/main.rs b/src/main.rs index e96b736..4ef129b 100644 --- a/src/main.rs +++ b/src/main.rs | |||
@@ -8,12 +8,12 @@ use time::util::local_offset; | |||
8 | use tokio::sync::broadcast::{channel, Sender}; | 8 | use tokio::sync::broadcast::{channel, Sender}; |
9 | use tracing::{info, level_filters::LevelFilter}; | 9 | use tracing::{info, level_filters::LevelFilter}; |
10 | use tracing_subscriber::{EnvFilter, fmt::{self, time::LocalTime}, prelude::*}; | 10 | use tracing_subscriber::{EnvFilter, fmt::{self, time::LocalTime}, prelude::*}; |
11 | use crate::config::SETTINGS; | 11 | use crate::config::Config; |
12 | use crate::db::init_db_pool; | 12 | use crate::db::init_db_pool; |
13 | use crate::routes::device::{get_device, post_device, put_device}; | 13 | use crate::routes::device; |
14 | use crate::routes::start::start; | 14 | use crate::routes::start::start; |
15 | use crate::routes::status::status; | 15 | use crate::routes::status::status; |
16 | use crate::services::ping::{BroadcastCommands, PingMap}; | 16 | use crate::services::ping::{BroadcastCommands, StatusMap}; |
17 | 17 | ||
18 | mod auth; | 18 | mod auth; |
19 | mod config; | 19 | mod config; |
@@ -24,7 +24,10 @@ mod error; | |||
24 | mod services; | 24 | mod services; |
25 | 25 | ||
26 | #[tokio::main] | 26 | #[tokio::main] |
27 | async fn main() { | 27 | async fn main() -> color_eyre::eyre::Result<()> { |
28 | |||
29 | color_eyre::install()?; | ||
30 | |||
28 | unsafe { local_offset::set_soundness(local_offset::Soundness::Unsound); } | 31 | unsafe { local_offset::set_soundness(local_offset::Soundness::Unsound); } |
29 | let time_format = | 32 | let time_format = |
30 | time::macros::format_description!("[year]-[month]-[day] [hour]:[minute]:[second]"); | 33 | time::macros::format_description!("[year]-[month]-[day] [hour]:[minute]:[second]"); |
@@ -43,35 +46,39 @@ async fn main() { | |||
43 | 46 | ||
44 | let version = env!("CARGO_PKG_VERSION"); | 47 | let version = env!("CARGO_PKG_VERSION"); |
45 | 48 | ||
49 | let config = Config::load()?; | ||
50 | |||
46 | info!("start webol v{}", version); | 51 | info!("start webol v{}", version); |
47 | 52 | ||
48 | let db = init_db_pool().await; | 53 | let db = init_db_pool(&config.database_url).await; |
49 | sqlx::migrate!().run(&db).await.unwrap(); | 54 | sqlx::migrate!().run(&db).await.unwrap(); |
50 | 55 | ||
51 | let (tx, _) = channel(32); | 56 | let (tx, _) = channel(32); |
52 | 57 | ||
53 | let ping_map: PingMap = DashMap::new(); | 58 | let ping_map: StatusMap = DashMap::new(); |
54 | 59 | ||
55 | let shared_state = Arc::new(AppState { db, ping_send: tx, ping_map }); | 60 | let shared_state = Arc::new(AppState { db, config: config.clone(), ping_send: tx, ping_map }); |
56 | 61 | ||
57 | let app = Router::new() | 62 | let app = Router::new() |
58 | .route("/start", post(start)) | 63 | .route("/start", post(start)) |
59 | .route("/device", get(get_device)) | 64 | .route("/device", get(device::get)) |
60 | .route("/device", put(put_device)) | 65 | .route("/device", put(device::put)) |
61 | .route("/device", post(post_device)) | 66 | .route("/device", post(device::post)) |
62 | .route("/status", get(status)) | 67 | .route("/status", get(status)) |
63 | .with_state(shared_state); | 68 | .with_state(shared_state); |
64 | 69 | ||
65 | let addr = SETTINGS.get_string("serveraddr").unwrap_or("0.0.0.0:7229".to_string()); | 70 | let addr = config.serveraddr; |
66 | info!("start server on {}", addr); | 71 | info!("start server on {}", addr); |
67 | axum::Server::bind(&addr.parse().unwrap()) | 72 | let listener = tokio::net::TcpListener::bind(addr) |
68 | .serve(app.into_make_service()) | 73 | .await?; |
69 | .await | 74 | axum::serve(listener, app).await?; |
70 | .unwrap(); | 75 | |
76 | Ok(()) | ||
71 | } | 77 | } |
72 | 78 | ||
73 | pub struct AppState { | 79 | pub struct AppState { |
74 | db: PgPool, | 80 | db: PgPool, |
81 | config: Config, | ||
75 | ping_send: Sender<BroadcastCommands>, | 82 | ping_send: Sender<BroadcastCommands>, |
76 | ping_map: PingMap, | 83 | ping_map: StatusMap, |
77 | } \ No newline at end of file | 84 | } |
diff --git a/src/routes/device.rs b/src/routes/device.rs index 678d117..c85df1b 100644 --- a/src/routes/device.rs +++ b/src/routes/device.rs | |||
@@ -1,18 +1,18 @@ | |||
1 | use std::sync::Arc; | 1 | use std::sync::Arc; |
2 | use axum::extract::State; | 2 | use axum::extract::State; |
3 | use axum::headers::HeaderMap; | ||
4 | use axum::Json; | 3 | use axum::Json; |
4 | use axum::http::HeaderMap; | ||
5 | use serde::{Deserialize, Serialize}; | 5 | use serde::{Deserialize, Serialize}; |
6 | use serde_json::{json, Value}; | 6 | use serde_json::{json, Value}; |
7 | use tracing::{debug, info}; | 7 | use tracing::{debug, info}; |
8 | use crate::auth::auth; | 8 | use crate::auth::auth; |
9 | use crate::db::Device; | 9 | use crate::db::Device; |
10 | use crate::error::WebolError; | 10 | use crate::error::Error; |
11 | 11 | ||
12 | pub async fn get_device(State(state): State<Arc<crate::AppState>>, headers: HeaderMap, Json(payload): Json<GetDevicePayload>) -> Result<Json<Value>, WebolError> { | 12 | pub async fn get(State(state): State<Arc<crate::AppState>>, headers: HeaderMap, Json(payload): Json<GetDevicePayload>) -> Result<Json<Value>, Error> { |
13 | info!("add device {}", payload.id); | 13 | info!("add device {}", payload.id); |
14 | let secret = headers.get("authorization"); | 14 | let secret = headers.get("authorization"); |
15 | if auth(secret).map_err(WebolError::Auth)? { | 15 | if auth(&state.config, secret).map_err(Error::Auth)? { |
16 | let device = sqlx::query_as!( | 16 | let device = sqlx::query_as!( |
17 | Device, | 17 | Device, |
18 | r#" | 18 | r#" |
@@ -21,13 +21,13 @@ pub async fn get_device(State(state): State<Arc<crate::AppState>>, headers: Head | |||
21 | WHERE id = $1; | 21 | WHERE id = $1; |
22 | "#, | 22 | "#, |
23 | payload.id | 23 | payload.id |
24 | ).fetch_one(&state.db).await.map_err(WebolError::DB)?; | 24 | ).fetch_one(&state.db).await.map_err(Error::DB)?; |
25 | 25 | ||
26 | debug!("got device {:?}", device); | 26 | debug!("got device {:?}", device); |
27 | 27 | ||
28 | Ok(Json(json!(device))) | 28 | Ok(Json(json!(device))) |
29 | } else { | 29 | } else { |
30 | Err(WebolError::Generic) | 30 | Err(Error::Generic) |
31 | } | 31 | } |
32 | } | 32 | } |
33 | 33 | ||
@@ -36,10 +36,10 @@ pub struct GetDevicePayload { | |||
36 | id: String, | 36 | id: String, |
37 | } | 37 | } |
38 | 38 | ||
39 | pub async fn put_device(State(state): State<Arc<crate::AppState>>, headers: HeaderMap, Json(payload): Json<PutDevicePayload>) -> Result<Json<Value>, WebolError> { | 39 | pub async fn put(State(state): State<Arc<crate::AppState>>, headers: HeaderMap, Json(payload): Json<PutDevicePayload>) -> Result<Json<Value>, Error> { |
40 | info!("add device {} ({}, {}, {})", payload.id, payload.mac, payload.broadcast_addr, payload.ip); | 40 | info!("add device {} ({}, {}, {})", payload.id, payload.mac, payload.broadcast_addr, payload.ip); |
41 | let secret = headers.get("authorization"); | 41 | let secret = headers.get("authorization"); |
42 | if auth(secret).map_err(WebolError::Auth)? { | 42 | if auth(&state.config, secret).map_err(Error::Auth)? { |
43 | sqlx::query!( | 43 | sqlx::query!( |
44 | r#" | 44 | r#" |
45 | INSERT INTO devices (id, mac, broadcast_addr, ip) | 45 | INSERT INTO devices (id, mac, broadcast_addr, ip) |
@@ -49,11 +49,11 @@ pub async fn put_device(State(state): State<Arc<crate::AppState>>, headers: Head | |||
49 | payload.mac, | 49 | payload.mac, |
50 | payload.broadcast_addr, | 50 | payload.broadcast_addr, |
51 | payload.ip | 51 | payload.ip |
52 | ).execute(&state.db).await.map_err(WebolError::DB)?; | 52 | ).execute(&state.db).await.map_err(Error::DB)?; |
53 | 53 | ||
54 | Ok(Json(json!(PutDeviceResponse { success: true }))) | 54 | Ok(Json(json!(PutDeviceResponse { success: true }))) |
55 | } else { | 55 | } else { |
56 | Err(WebolError::Generic) | 56 | Err(Error::Generic) |
57 | } | 57 | } |
58 | } | 58 | } |
59 | 59 | ||
@@ -70,10 +70,10 @@ pub struct PutDeviceResponse { | |||
70 | success: bool | 70 | success: bool |
71 | } | 71 | } |
72 | 72 | ||
73 | pub async fn post_device(State(state): State<Arc<crate::AppState>>, headers: HeaderMap, Json(payload): Json<PostDevicePayload>) -> Result<Json<Value>, WebolError> { | 73 | pub async fn post(State(state): State<Arc<crate::AppState>>, headers: HeaderMap, Json(payload): Json<PostDevicePayload>) -> Result<Json<Value>, Error> { |
74 | info!("edit device {} ({}, {}, {})", payload.id, payload.mac, payload.broadcast_addr, payload.ip); | 74 | info!("edit device {} ({}, {}, {})", payload.id, payload.mac, payload.broadcast_addr, payload.ip); |
75 | let secret = headers.get("authorization"); | 75 | let secret = headers.get("authorization"); |
76 | if auth(secret).map_err(WebolError::Auth)? { | 76 | if auth(&state.config, secret).map_err(Error::Auth)? { |
77 | let device = sqlx::query_as!( | 77 | let device = sqlx::query_as!( |
78 | Device, | 78 | Device, |
79 | r#" | 79 | r#" |
@@ -85,11 +85,11 @@ pub async fn post_device(State(state): State<Arc<crate::AppState>>, headers: Hea | |||
85 | payload.broadcast_addr, | 85 | payload.broadcast_addr, |
86 | payload.ip, | 86 | payload.ip, |
87 | payload.id | 87 | payload.id |
88 | ).fetch_one(&state.db).await.map_err(WebolError::DB)?; | 88 | ).fetch_one(&state.db).await.map_err(Error::DB)?; |
89 | 89 | ||
90 | Ok(Json(json!(device))) | 90 | Ok(Json(json!(device))) |
91 | } else { | 91 | } else { |
92 | Err(WebolError::Generic) | 92 | Err(Error::Generic) |
93 | } | 93 | } |
94 | } | 94 | } |
95 | 95 | ||
diff --git a/src/routes/start.rs b/src/routes/start.rs index 1555db3..ce95bf3 100644 --- a/src/routes/start.rs +++ b/src/routes/start.rs | |||
@@ -1,23 +1,26 @@ | |||
1 | use axum::headers::HeaderMap; | 1 | use crate::auth::auth; |
2 | use crate::db::Device; | ||
3 | use crate::error::Error; | ||
4 | use crate::services::ping::Value as PingValue; | ||
5 | use crate::wol::{create_buffer, send_packet}; | ||
6 | use axum::extract::State; | ||
7 | use axum::http::HeaderMap; | ||
2 | use axum::Json; | 8 | use axum::Json; |
3 | use serde::{Deserialize, Serialize}; | 9 | use serde::{Deserialize, Serialize}; |
4 | use std::sync::Arc; | ||
5 | use axum::extract::State; | ||
6 | use serde_json::{json, Value}; | 10 | use serde_json::{json, Value}; |
11 | use std::sync::Arc; | ||
7 | use tracing::{debug, info}; | 12 | use tracing::{debug, info}; |
8 | use uuid::Uuid; | 13 | use uuid::Uuid; |
9 | use crate::auth::auth; | ||
10 | use crate::config::SETTINGS; | ||
11 | use crate::wol::{create_buffer, send_packet}; | ||
12 | use crate::db::Device; | ||
13 | use crate::error::WebolError; | ||
14 | use crate::services::ping::PingValue; | ||
15 | 14 | ||
16 | #[axum_macros::debug_handler] | 15 | #[axum_macros::debug_handler] |
17 | pub async fn start(State(state): State<Arc<crate::AppState>>, headers: HeaderMap, Json(payload): Json<StartPayload>) -> Result<Json<Value>, WebolError> { | 16 | pub async fn start( |
17 | State(state): State<Arc<crate::AppState>>, | ||
18 | headers: HeaderMap, | ||
19 | Json(payload): Json<Payload>, | ||
20 | ) -> Result<Json<Value>, Error> { | ||
18 | info!("POST request"); | 21 | info!("POST request"); |
19 | let secret = headers.get("authorization"); | 22 | let secret = headers.get("authorization"); |
20 | let authorized = auth(secret).map_err(WebolError::Auth)?; | 23 | let authorized = auth(&state.config, secret).map_err(Error::Auth)?; |
21 | if authorized { | 24 | if authorized { |
22 | let device = sqlx::query_as!( | 25 | let device = sqlx::query_as!( |
23 | Device, | 26 | Device, |
@@ -27,18 +30,19 @@ pub async fn start(State(state): State<Arc<crate::AppState>>, headers: HeaderMap | |||
27 | WHERE id = $1; | 30 | WHERE id = $1; |
28 | "#, | 31 | "#, |
29 | payload.id | 32 | payload.id |
30 | ).fetch_one(&state.db).await.map_err(WebolError::DB)?; | 33 | ) |
34 | .fetch_one(&state.db) | ||
35 | .await | ||
36 | .map_err(Error::DB)?; | ||
31 | 37 | ||
32 | info!("starting {}", device.id); | 38 | info!("starting {}", device.id); |
33 | 39 | ||
34 | let bind_addr = SETTINGS | 40 | let bind_addr = "0.0.0.0:0"; |
35 | .get_string("bindaddr") | ||
36 | .unwrap_or("0.0.0.0:1111".to_string()); | ||
37 | 41 | ||
38 | let _ = send_packet( | 42 | let _ = send_packet( |
39 | &bind_addr.parse().map_err(WebolError::IpParse)?, | 43 | &bind_addr.parse().map_err(Error::IpParse)?, |
40 | &device.broadcast_addr.parse().map_err(WebolError::IpParse)?, | 44 | &device.broadcast_addr.parse().map_err(Error::IpParse)?, |
41 | create_buffer(&device.mac)? | 45 | &create_buffer(&device.mac)?, |
42 | )?; | 46 | )?; |
43 | let dev_id = device.id.clone(); | 47 | let dev_id = device.id.clone(); |
44 | let uuid = if payload.ping.is_some_and(|ping| ping) { | 48 | let uuid = if payload.ping.is_some_and(|ping| ping) { |
@@ -49,7 +53,7 @@ pub async fn start(State(state): State<Arc<crate::AppState>>, headers: HeaderMap | |||
49 | uuid = Some(key); | 53 | uuid = Some(key); |
50 | break; | 54 | break; |
51 | } | 55 | } |
52 | }; | 56 | } |
53 | let uuid_gen = match uuid { | 57 | let uuid_gen = match uuid { |
54 | Some(u) => u, | 58 | Some(u) => u, |
55 | None => Uuid::new_v4().to_string(), | 59 | None => Uuid::new_v4().to_string(), |
@@ -58,26 +62,46 @@ pub async fn start(State(state): State<Arc<crate::AppState>>, headers: HeaderMap | |||
58 | 62 | ||
59 | tokio::spawn(async move { | 63 | tokio::spawn(async move { |
60 | debug!("init ping service"); | 64 | debug!("init ping service"); |
61 | state.ping_map.insert(uuid_gen.clone(), PingValue { ip: device.ip.clone(), online: false }); | 65 | state.ping_map.insert( |
66 | uuid_gen.clone(), | ||
67 | PingValue { | ||
68 | ip: device.ip.clone(), | ||
69 | online: false, | ||
70 | }, | ||
71 | ); | ||
62 | 72 | ||
63 | crate::services::ping::spawn(state.ping_send.clone(), device, uuid_gen.clone(), &state.ping_map, &state.db).await | 73 | crate::services::ping::spawn( |
74 | state.ping_send.clone(), | ||
75 | &state.config, | ||
76 | device, | ||
77 | uuid_gen.clone(), | ||
78 | &state.ping_map, | ||
79 | &state.db, | ||
80 | ) | ||
81 | .await; | ||
64 | }); | 82 | }); |
65 | Some(uuid_genc) | 83 | Some(uuid_genc) |
66 | } else { None }; | 84 | } else { |
67 | Ok(Json(json!(StartResponse { id: dev_id, boot: true, uuid }))) | 85 | None |
86 | }; | ||
87 | Ok(Json(json!(Response { | ||
88 | id: dev_id, | ||
89 | boot: true, | ||
90 | uuid | ||
91 | }))) | ||
68 | } else { | 92 | } else { |
69 | Err(WebolError::Generic) | 93 | Err(Error::Generic) |
70 | } | 94 | } |
71 | } | 95 | } |
72 | 96 | ||
73 | #[derive(Deserialize)] | 97 | #[derive(Deserialize)] |
74 | pub struct StartPayload { | 98 | pub struct Payload { |
75 | id: String, | 99 | id: String, |
76 | ping: Option<bool>, | 100 | ping: Option<bool>, |
77 | } | 101 | } |
78 | 102 | ||
79 | #[derive(Serialize)] | 103 | #[derive(Serialize)] |
80 | struct StartResponse { | 104 | struct Response { |
81 | id: String, | 105 | id: String, |
82 | boot: bool, | 106 | boot: bool, |
83 | uuid: Option<String>, | 107 | uuid: Option<String>, |
diff --git a/src/routes/status.rs b/src/routes/status.rs index 45f3e51..31ef996 100644 --- a/src/routes/status.rs +++ b/src/routes/status.rs | |||
@@ -7,4 +7,4 @@ use crate::services::ping::status_websocket; | |||
7 | #[axum_macros::debug_handler] | 7 | #[axum_macros::debug_handler] |
8 | pub async fn status(State(state): State<Arc<AppState>>, ws: WebSocketUpgrade) -> Response { | 8 | pub async fn status(State(state): State<Arc<AppState>>, ws: WebSocketUpgrade) -> Response { |
9 | ws.on_upgrade(move |socket| status_websocket(socket, state)) | 9 | ws.on_upgrade(move |socket| status_websocket(socket, state)) |
10 | } \ No newline at end of file | 10 | } |
diff --git a/src/services/ping.rs b/src/services/ping.rs index c3bdced..9b164c8 100644 --- a/src/services/ping.rs +++ b/src/services/ping.rs | |||
@@ -2,26 +2,26 @@ use std::str::FromStr; | |||
2 | use std::net::IpAddr; | 2 | use std::net::IpAddr; |
3 | use std::sync::Arc; | 3 | use std::sync::Arc; |
4 | 4 | ||
5 | use axum::extract::{ws::WebSocket}; | 5 | use axum::extract::ws::WebSocket; |
6 | use axum::extract::ws::Message; | 6 | use axum::extract::ws::Message; |
7 | use dashmap::DashMap; | 7 | use dashmap::DashMap; |
8 | use sqlx::PgPool; | 8 | use sqlx::PgPool; |
9 | use time::{Duration, Instant}; | 9 | use time::{Duration, Instant}; |
10 | use tokio::sync::broadcast::{Sender}; | 10 | use tokio::sync::broadcast::Sender; |
11 | use tracing::{debug, error, trace}; | 11 | use tracing::{debug, error, trace}; |
12 | use crate::AppState; | 12 | use crate::AppState; |
13 | use crate::config::SETTINGS; | 13 | use crate::config::Config; |
14 | use crate::db::Device; | 14 | use crate::db::Device; |
15 | 15 | ||
16 | pub type PingMap = DashMap<String, PingValue>; | 16 | pub type StatusMap = DashMap<String, Value>; |
17 | 17 | ||
18 | #[derive(Debug, Clone)] | 18 | #[derive(Debug, Clone)] |
19 | pub struct PingValue { | 19 | pub struct Value { |
20 | pub ip: String, | 20 | pub ip: String, |
21 | pub online: bool | 21 | pub online: bool |
22 | } | 22 | } |
23 | 23 | ||
24 | pub async fn spawn(tx: Sender<BroadcastCommands>, device: Device, uuid: String, ping_map: &PingMap, db: &PgPool) { | 24 | pub async fn spawn(tx: Sender<BroadcastCommands>, config: &Config, device: Device, uuid: String, ping_map: &StatusMap, db: &PgPool) { |
25 | let timer = Instant::now(); | 25 | let timer = Instant::now(); |
26 | let payload = [0; 8]; | 26 | let payload = [0; 8]; |
27 | 27 | ||
@@ -40,7 +40,7 @@ pub async fn spawn(tx: Sender<BroadcastCommands>, device: Device, uuid: String, | |||
40 | error!("{}", ping.to_string()); | 40 | error!("{}", ping.to_string()); |
41 | msg = Some(BroadcastCommands::Error(uuid.clone())); | 41 | msg = Some(BroadcastCommands::Error(uuid.clone())); |
42 | } | 42 | } |
43 | if timer.elapsed() >= Duration::minutes(SETTINGS.get_int("pingtimeout").unwrap_or(10)) { | 43 | if timer.elapsed() >= Duration::minutes(config.pingtimeout) { |
44 | msg = Some(BroadcastCommands::Timeout(uuid.clone())); | 44 | msg = Some(BroadcastCommands::Timeout(uuid.clone())); |
45 | } | 45 | } |
46 | } else { | 46 | } else { |
@@ -63,7 +63,7 @@ pub async fn spawn(tx: Sender<BroadcastCommands>, device: Device, uuid: String, | |||
63 | timer.elapsed().whole_seconds(), | 63 | timer.elapsed().whole_seconds(), |
64 | device.id | 64 | device.id |
65 | ).execute(db).await.unwrap(); | 65 | ).execute(db).await.unwrap(); |
66 | ping_map.insert(uuid.clone(), PingValue { ip: device.ip.clone(), online: true }); | 66 | ping_map.insert(uuid.clone(), Value { ip: device.ip.clone(), online: true }); |
67 | tokio::time::sleep(tokio::time::Duration::from_secs(60)).await; | 67 | tokio::time::sleep(tokio::time::Duration::from_secs(60)).await; |
68 | } | 68 | } |
69 | trace!("remove {} from ping_map", uuid); | 69 | trace!("remove {} from ping_map", uuid); |
@@ -85,17 +85,14 @@ pub async fn status_websocket(mut socket: WebSocket, state: Arc<AppState>) { | |||
85 | trace!("Search for uuid: {}", uuid); | 85 | trace!("Search for uuid: {}", uuid); |
86 | 86 | ||
87 | let eta = get_eta(&state.db).await; | 87 | let eta = get_eta(&state.db).await; |
88 | let _ = socket.send(Message::Text(format!("eta_{}_{}", eta, uuid))).await; | 88 | let _ = socket.send(Message::Text(format!("eta_{eta}_{uuid}"))).await; |
89 | 89 | ||
90 | let device_exists = state.ping_map.contains_key(&uuid); | 90 | let device_exists = state.ping_map.contains_key(&uuid); |
91 | match device_exists { | 91 | if device_exists { |
92 | true => { | 92 | let _ = socket.send(process_device(state.clone(), uuid).await).await; |
93 | let _ = socket.send(process_device(state.clone(), uuid).await).await; | 93 | } else { |
94 | }, | 94 | debug!("didn't find any device"); |
95 | false => { | 95 | let _ = socket.send(Message::Text(format!("notfound_{uuid}"))).await; |
96 | debug!("didn't find any device"); | ||
97 | let _ = socket.send(Message::Text(format!("notfound_{}", uuid))).await; | ||
98 | }, | ||
99 | }; | 96 | }; |
100 | 97 | ||
101 | let _ = socket.close().await; | 98 | let _ = socket.close().await; |
@@ -110,7 +107,7 @@ async fn get_eta(db: &PgPool) -> i64 { | |||
110 | None => { vec![0] }, | 107 | None => { vec![0] }, |
111 | Some(t) => t, | 108 | Some(t) => t, |
112 | }; | 109 | }; |
113 | times.iter().sum::<i64>() / times.len() as i64 | 110 | times.iter().sum::<i64>() / i64::try_from(times.len()).unwrap() |
114 | 111 | ||
115 | } | 112 | } |
116 | 113 | ||
@@ -118,34 +115,31 @@ async fn process_device(state: Arc<AppState>, uuid: String) -> Message { | |||
118 | let pm = state.ping_map.clone().into_read_only(); | 115 | let pm = state.ping_map.clone().into_read_only(); |
119 | let device = pm.get(&uuid).expect("fatal error"); | 116 | let device = pm.get(&uuid).expect("fatal error"); |
120 | debug!("got device: {} (online: {})", device.ip, device.online); | 117 | debug!("got device: {} (online: {})", device.ip, device.online); |
121 | match device.online { | 118 | if device.online { |
122 | true => { | 119 | debug!("already started"); |
123 | debug!("already started"); | 120 | Message::Text(format!("start_{uuid}")) |
124 | Message::Text(format!("start_{}", uuid)) | 121 | } else { |
125 | }, | 122 | loop { |
126 | false => { | 123 | trace!("wait for tx message"); |
127 | loop{ | 124 | let message = state.ping_send.subscribe().recv().await.expect("fatal error"); |
128 | trace!("wait for tx message"); | 125 | trace!("got message {:?}", message); |
129 | let message = state.ping_send.subscribe().recv().await.expect("fatal error"); | 126 | return match message { |
130 | trace!("got message {:?}", message); | 127 | BroadcastCommands::Success(msg_uuid) => { |
131 | return match message { | 128 | if msg_uuid != uuid { continue; } |
132 | BroadcastCommands::Success(msg_uuid) => { | 129 | trace!("message == uuid success"); |
133 | if msg_uuid != uuid { continue; } | 130 | Message::Text(format!("start_{uuid}")) |
134 | trace!("message == uuid success"); | 131 | }, |
135 | Message::Text(format!("start_{}", uuid)) | 132 | BroadcastCommands::Timeout(msg_uuid) => { |
136 | }, | 133 | if msg_uuid != uuid { continue; } |
137 | BroadcastCommands::Timeout(msg_uuid) => { | 134 | trace!("message == uuid timeout"); |
138 | if msg_uuid != uuid { continue; } | 135 | Message::Text(format!("timeout_{uuid}")) |
139 | trace!("message == uuid timeout"); | 136 | }, |
140 | Message::Text(format!("timeout_{}", uuid)) | 137 | BroadcastCommands::Error(msg_uuid) => { |
141 | }, | 138 | if msg_uuid != uuid { continue; } |
142 | BroadcastCommands::Error(msg_uuid) => { | 139 | trace!("message == uuid error"); |
143 | if msg_uuid != uuid { continue; } | 140 | Message::Text(format!("error_{uuid}")) |
144 | trace!("message == uuid error"); | ||
145 | Message::Text(format!("error_{}", uuid)) | ||
146 | } | ||
147 | } | 141 | } |
148 | } | 142 | } |
149 | } | 143 | } |
150 | } | 144 | } |
151 | } \ No newline at end of file | 145 | } |
@@ -1,17 +1,17 @@ | |||
1 | use std::net::{SocketAddr, UdpSocket}; | 1 | use std::net::{SocketAddr, UdpSocket}; |
2 | 2 | ||
3 | use crate::error::WebolError; | 3 | use crate::error::Error; |
4 | 4 | ||
5 | /// Creates the magic packet from a mac address | 5 | /// Creates the magic packet from a mac address |
6 | /// | 6 | /// |
7 | /// # Panics | 7 | /// # Panics |
8 | /// | 8 | /// |
9 | /// Panics if `mac_addr` is an invalid mac | 9 | /// Panics if `mac_addr` is an invalid mac |
10 | pub fn create_buffer(mac_addr: &str) -> Result<Vec<u8>, WebolError> { | 10 | pub fn create_buffer(mac_addr: &str) -> Result<Vec<u8>, Error> { |
11 | let mut mac = Vec::new(); | 11 | let mut mac = Vec::new(); |
12 | let sp = mac_addr.split(':'); | 12 | let sp = mac_addr.split(':'); |
13 | for f in sp { | 13 | for f in sp { |
14 | mac.push(u8::from_str_radix(f, 16).map_err(WebolError::BufferParse)?) | 14 | mac.push(u8::from_str_radix(f, 16).map_err(Error::BufferParse)?); |
15 | }; | 15 | }; |
16 | let mut buf = vec![255; 6]; | 16 | let mut buf = vec![255; 6]; |
17 | for _ in 0..16 { | 17 | for _ in 0..16 { |
@@ -23,8 +23,8 @@ pub fn create_buffer(mac_addr: &str) -> Result<Vec<u8>, WebolError> { | |||
23 | } | 23 | } |
24 | 24 | ||
25 | /// Sends a buffer on UDP broadcast | 25 | /// Sends a buffer on UDP broadcast |
26 | pub fn send_packet(bind_addr: &SocketAddr, broadcast_addr: &SocketAddr, buffer: Vec<u8>) -> Result<usize, WebolError> { | 26 | pub fn send_packet(bind_addr: &SocketAddr, broadcast_addr: &SocketAddr, buffer: &[u8]) -> Result<usize, Error> { |
27 | let socket = UdpSocket::bind(bind_addr).map_err(WebolError::Broadcast)?; | 27 | let socket = UdpSocket::bind(bind_addr).map_err(Error::Broadcast)?; |
28 | socket.set_broadcast(true).map_err(WebolError::Broadcast)?; | 28 | socket.set_broadcast(true).map_err(Error::Broadcast)?; |
29 | socket.send_to(&buffer, broadcast_addr).map_err(WebolError::Broadcast) | 29 | socket.send_to(buffer, broadcast_addr).map_err(Error::Broadcast) |
30 | } | 30 | } |