diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/auth.rs | 5 | ||||
-rw-r--r-- | src/config.rs | 5 | ||||
-rw-r--r-- | src/db.rs | 37 | ||||
-rw-r--r-- | src/error.rs | 21 | ||||
-rw-r--r-- | src/main.rs | 41 | ||||
-rw-r--r-- | src/routes/device.rs | 118 | ||||
-rw-r--r-- | src/routes/start.rs | 100 | ||||
-rw-r--r-- | src/routes/status.rs | 24 | ||||
-rw-r--r-- | src/services/ping.rs | 45 | ||||
-rw-r--r-- | src/storage.rs | 70 | ||||
-rw-r--r-- | src/wol.rs | 20 |
11 files changed, 180 insertions, 306 deletions
diff --git a/src/auth.rs b/src/auth.rs index 74008b5..c662e36 100644 --- a/src/auth.rs +++ b/src/auth.rs | |||
@@ -6,6 +6,7 @@ use axum::{ | |||
6 | response::Response, | 6 | response::Response, |
7 | }; | 7 | }; |
8 | use serde::Deserialize; | 8 | use serde::Deserialize; |
9 | use tracing::trace; | ||
9 | 10 | ||
10 | #[derive(Debug, Clone, Deserialize)] | 11 | #[derive(Debug, Clone, Deserialize)] |
11 | pub enum Methods { | 12 | pub enum Methods { |
@@ -20,15 +21,19 @@ pub async fn auth( | |||
20 | next: Next, | 21 | next: Next, |
21 | ) -> Result<Response, StatusCode> { | 22 | ) -> Result<Response, StatusCode> { |
22 | let auth = state.config.auth; | 23 | let auth = state.config.auth; |
24 | trace!(?auth.method, "auth request"); | ||
23 | match auth.method { | 25 | match auth.method { |
24 | Methods::Key => { | 26 | Methods::Key => { |
25 | if let Some(secret) = headers.get("authorization") { | 27 | if let Some(secret) = headers.get("authorization") { |
26 | if auth.secret.as_str() != secret { | 28 | if auth.secret.as_str() != secret { |
29 | trace!("auth failed, unknown secret"); | ||
27 | return Err(StatusCode::UNAUTHORIZED); | 30 | return Err(StatusCode::UNAUTHORIZED); |
28 | }; | 31 | }; |
32 | trace!("auth successfull"); | ||
29 | let response = next.run(request).await; | 33 | let response = next.run(request).await; |
30 | Ok(response) | 34 | Ok(response) |
31 | } else { | 35 | } else { |
36 | trace!("auth failed, no secret"); | ||
32 | Err(StatusCode::UNAUTHORIZED) | 37 | Err(StatusCode::UNAUTHORIZED) |
33 | } | 38 | } |
34 | } | 39 | } |
diff --git a/src/config.rs b/src/config.rs index 9636af4..bfb28be 100644 --- a/src/config.rs +++ b/src/config.rs | |||
@@ -5,7 +5,6 @@ use crate::auth; | |||
5 | 5 | ||
6 | #[derive(Debug, Clone, Deserialize)] | 6 | #[derive(Debug, Clone, Deserialize)] |
7 | pub struct Config { | 7 | pub struct Config { |
8 | pub database_url: String, | ||
9 | pub serveraddr: String, | 8 | pub serveraddr: String, |
10 | pub pingtimeout: i64, | 9 | pub pingtimeout: i64, |
11 | pub pingthreshold: i64, | 10 | pub pingthreshold: i64, |
@@ -26,9 +25,11 @@ impl Config { | |||
26 | .set_default("pingtimeout", 10)? | 25 | .set_default("pingtimeout", 10)? |
27 | .set_default("pingthreshold", 1)? | 26 | .set_default("pingthreshold", 1)? |
28 | .set_default("timeoffset", 0)? | 27 | .set_default("timeoffset", 0)? |
28 | .set_default("auth.method", "none")? | ||
29 | .set_default("auth.secret", "")? | ||
29 | .add_source(File::with_name("config.toml").required(false)) | 30 | .add_source(File::with_name("config.toml").required(false)) |
30 | .add_source(File::with_name("config.dev.toml").required(false)) | 31 | .add_source(File::with_name("config.dev.toml").required(false)) |
31 | .add_source(config::Environment::with_prefix("WEBOL").prefix_separator("_")) | 32 | .add_source(config::Environment::with_prefix("WEBOL").separator("_")) |
32 | .build()?; | 33 | .build()?; |
33 | 34 | ||
34 | config.try_deserialize() | 35 | config.try_deserialize() |
diff --git a/src/db.rs b/src/db.rs deleted file mode 100644 index a2b2009..0000000 --- a/src/db.rs +++ /dev/null | |||
@@ -1,37 +0,0 @@ | |||
1 | use serde::Serialize; | ||
2 | use sqlx::{PgPool, postgres::PgPoolOptions, types::{ipnetwork::IpNetwork, mac_address::MacAddress}}; | ||
3 | use tracing::{debug, info}; | ||
4 | use utoipa::ToSchema; | ||
5 | |||
6 | #[derive(Serialize, Debug)] | ||
7 | pub struct Device { | ||
8 | pub id: String, | ||
9 | pub mac: MacAddress, | ||
10 | pub broadcast_addr: String, | ||
11 | pub ip: IpNetwork, | ||
12 | pub times: Option<Vec<i64>> | ||
13 | } | ||
14 | |||
15 | #[derive(ToSchema)] | ||
16 | #[schema(as = Device)] | ||
17 | pub struct DeviceSchema { | ||
18 | pub id: String, | ||
19 | pub mac: String, | ||
20 | pub broadcast_addr: String, | ||
21 | pub ip: String, | ||
22 | pub times: Option<Vec<i64>> | ||
23 | } | ||
24 | |||
25 | pub async fn init_db_pool(db_url: &str) -> PgPool { | ||
26 | debug!("attempt to connect dbPool to '{}'", db_url); | ||
27 | |||
28 | let pool = PgPoolOptions::new() | ||
29 | .max_connections(5) | ||
30 | .connect(db_url) | ||
31 | .await | ||
32 | .unwrap(); | ||
33 | |||
34 | info!("dbPool successfully connected to '{}'", db_url); | ||
35 | |||
36 | pool | ||
37 | } | ||
diff --git a/src/error.rs b/src/error.rs index 006fcdb..2d70592 100644 --- a/src/error.rs +++ b/src/error.rs | |||
@@ -7,14 +7,14 @@ use mac_address::MacParseError; | |||
7 | use serde_json::json; | 7 | use serde_json::json; |
8 | use utoipa::ToSchema; | 8 | use utoipa::ToSchema; |
9 | use std::io; | 9 | use std::io; |
10 | use tracing::error; | 10 | use tracing::{error, warn}; |
11 | 11 | ||
12 | #[derive(Debug, thiserror::Error, ToSchema)] | 12 | #[derive(Debug, thiserror::Error, ToSchema)] |
13 | pub enum Error { | 13 | pub enum Error { |
14 | #[error("db: {source}")] | 14 | #[error("json: {source}")] |
15 | Db { | 15 | Json { |
16 | #[from] | 16 | #[from] |
17 | source: sqlx::Error, | 17 | source: serde_json::Error, |
18 | }, | 18 | }, |
19 | 19 | ||
20 | #[error("buffer parse: {source}")] | 20 | #[error("buffer parse: {source}")] |
@@ -50,15 +50,20 @@ pub enum Error { | |||
50 | 50 | ||
51 | impl IntoResponse for Error { | 51 | impl IntoResponse for Error { |
52 | fn into_response(self) -> Response { | 52 | fn into_response(self) -> Response { |
53 | error!("{}", self.to_string()); | 53 | // error!("{}", self.to_string()); |
54 | let (status, error_message) = match self { | 54 | let (status, error_message) = match self { |
55 | Self::Db { source } => { | 55 | Self::Json { source } => { |
56 | error!("{source}"); | 56 | error!("{source}"); |
57 | (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") | 57 | (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") |
58 | } | 58 | } |
59 | Self::Io { source } => { | 59 | Self::Io { source } => { |
60 | error!("{source}"); | 60 | if source.kind() == io::ErrorKind::NotFound { |
61 | (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") | 61 | warn!("unknown device requested"); |
62 | (StatusCode::NOT_FOUND, "Requested device not found") | ||
63 | } else { | ||
64 | error!("{source}"); | ||
65 | (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") | ||
66 | } | ||
62 | } | 67 | } |
63 | Self::ParseHeader { source } => { | 68 | Self::ParseHeader { source } => { |
64 | error!("{source}"); | 69 | error!("{source}"); |
diff --git a/src/main.rs b/src/main.rs index 70c67cf..204c318 100644 --- a/src/main.rs +++ b/src/main.rs | |||
@@ -1,8 +1,5 @@ | |||
1 | use crate::{ | 1 | use crate::{ |
2 | config::Config, | 2 | config::Config, routes::{device, start, status}, services::ping::{BroadcastCommand, StatusMap}, storage::Device |
3 | db::init_db_pool, | ||
4 | routes::{device, start, status}, | ||
5 | services::ping::{BroadcastCommand, StatusMap}, | ||
6 | }; | 3 | }; |
7 | use axum::{ | 4 | use axum::{ |
8 | middleware::from_fn_with_state, | 5 | middleware::from_fn_with_state, |
@@ -10,11 +7,10 @@ use axum::{ | |||
10 | Router, | 7 | Router, |
11 | }; | 8 | }; |
12 | use dashmap::DashMap; | 9 | use dashmap::DashMap; |
13 | use sqlx::PgPool; | ||
14 | use std::{env, sync::Arc}; | 10 | use std::{env, sync::Arc}; |
15 | use time::UtcOffset; | 11 | use time::UtcOffset; |
16 | use tokio::sync::broadcast::{channel, Sender}; | 12 | use tokio::sync::broadcast::{channel, Sender}; |
17 | use tracing::{info, level_filters::LevelFilter}; | 13 | use tracing::{info, level_filters::LevelFilter, trace}; |
18 | use tracing_subscriber::{ | 14 | use tracing_subscriber::{ |
19 | fmt::{self, time::OffsetTime}, | 15 | fmt::{self, time::OffsetTime}, |
20 | prelude::*, | 16 | prelude::*, |
@@ -26,10 +22,10 @@ use utoipa::{ | |||
26 | }; | 22 | }; |
27 | use utoipa_swagger_ui::SwaggerUi; | 23 | use utoipa_swagger_ui::SwaggerUi; |
28 | 24 | ||
25 | mod auth; | ||
29 | mod config; | 26 | mod config; |
30 | mod db; | 27 | mod storage; |
31 | mod error; | 28 | mod error; |
32 | mod auth; | ||
33 | mod routes; | 29 | mod routes; |
34 | mod services; | 30 | mod services; |
35 | mod wol; | 31 | mod wol; |
@@ -39,20 +35,16 @@ mod wol; | |||
39 | paths( | 35 | paths( |
40 | start::post, | 36 | start::post, |
41 | start::get, | 37 | start::get, |
42 | start::start_payload, | ||
43 | device::get, | 38 | device::get, |
44 | device::get_payload, | ||
45 | device::post, | 39 | device::post, |
46 | device::put, | 40 | device::put, |
47 | ), | 41 | ), |
48 | components( | 42 | components( |
49 | schemas( | 43 | schemas( |
50 | start::PayloadOld, | 44 | start::SPayload, |
51 | start::Payload, | ||
52 | start::Response, | 45 | start::Response, |
53 | device::DevicePayload, | 46 | device::DPayload, |
54 | device::GetDevicePayload, | 47 | storage::DeviceSchema, |
55 | db::DeviceSchema, | ||
56 | ) | 48 | ) |
57 | ), | 49 | ), |
58 | modifiers(&SecurityAddon), | 50 | modifiers(&SecurityAddon), |
@@ -76,7 +68,6 @@ impl Modify for SecurityAddon { | |||
76 | } | 68 | } |
77 | 69 | ||
78 | #[tokio::main] | 70 | #[tokio::main] |
79 | #[allow(deprecated)] | ||
80 | async fn main() -> color_eyre::eyre::Result<()> { | 71 | async fn main() -> color_eyre::eyre::Result<()> { |
81 | color_eyre::install()?; | 72 | color_eyre::install()?; |
82 | 73 | ||
@@ -98,35 +89,28 @@ async fn main() -> color_eyre::eyre::Result<()> { | |||
98 | .from_env_lossy(), | 89 | .from_env_lossy(), |
99 | ) | 90 | ) |
100 | .init(); | 91 | .init(); |
92 | trace!("logging initialized"); | ||
101 | 93 | ||
102 | let version = env!("CARGO_PKG_VERSION"); | 94 | Device::setup()?; |
103 | |||
104 | info!("start webol v{}", version); | ||
105 | 95 | ||
106 | let db = init_db_pool(&config.database_url).await; | 96 | let version = env!("CARGO_PKG_VERSION"); |
107 | sqlx::migrate!().run(&db).await.unwrap(); | 97 | info!(?version, "start webol"); |
108 | 98 | ||
109 | let (tx, _) = channel(32); | 99 | let (tx, _) = channel(32); |
110 | 100 | ||
111 | let ping_map: StatusMap = DashMap::new(); | 101 | let ping_map: StatusMap = DashMap::new(); |
112 | 102 | ||
113 | let shared_state = AppState { | 103 | let shared_state = AppState { |
114 | db, | ||
115 | config: config.clone(), | 104 | config: config.clone(), |
116 | ping_send: tx, | 105 | ping_send: tx, |
117 | ping_map, | 106 | ping_map, |
118 | }; | 107 | }; |
119 | 108 | ||
120 | let app = Router::new() | 109 | let app = Router::new() |
121 | .route("/start", post(start::start_payload)) | ||
122 | .route("/start/:id", post(start::post).get(start::get)) | 110 | .route("/start/:id", post(start::post).get(start::get)) |
123 | .route( | 111 | .route("/device", post(device::post).put(device::put)) |
124 | "/device", | ||
125 | post(device::post).get(device::get_payload).put(device::put), | ||
126 | ) | ||
127 | .route("/device/:id", get(device::get)) | 112 | .route("/device/:id", get(device::get)) |
128 | .route("/status", get(status::status)) | 113 | .route("/status", get(status::status)) |
129 | // TODO: Don't load on `None` Auth | ||
130 | .route_layer(from_fn_with_state(shared_state.clone(), auth::auth)) | 114 | .route_layer(from_fn_with_state(shared_state.clone(), auth::auth)) |
131 | .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi())) | 115 | .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi())) |
132 | .with_state(Arc::new(shared_state)); | 116 | .with_state(Arc::new(shared_state)); |
@@ -141,7 +125,6 @@ async fn main() -> color_eyre::eyre::Result<()> { | |||
141 | 125 | ||
142 | #[derive(Clone)] | 126 | #[derive(Clone)] |
143 | pub struct AppState { | 127 | pub struct AppState { |
144 | db: PgPool, | ||
145 | config: Config, | 128 | config: Config, |
146 | ping_send: Sender<BroadcastCommand>, | 129 | ping_send: Sender<BroadcastCommand>, |
147 | ping_map: StatusMap, | 130 | ping_map: StatusMap, |
diff --git a/src/routes/device.rs b/src/routes/device.rs index 40b5cd8..49361f2 100644 --- a/src/routes/device.rs +++ b/src/routes/device.rs | |||
@@ -1,49 +1,17 @@ | |||
1 | use crate::db::Device; | ||
2 | use crate::error::Error; | 1 | use crate::error::Error; |
3 | use axum::extract::{Path, State}; | 2 | use crate::storage::Device; |
3 | use axum::extract::Path; | ||
4 | use axum::Json; | 4 | use axum::Json; |
5 | use ipnetwork::IpNetwork; | ||
5 | use mac_address::MacAddress; | 6 | use mac_address::MacAddress; |
6 | use serde::Deserialize; | 7 | use serde::Deserialize; |
7 | use serde_json::{json, Value}; | 8 | use serde_json::{json, Value}; |
8 | use sqlx::types::ipnetwork::IpNetwork; | 9 | use std::str::FromStr; |
9 | use std::{str::FromStr, sync::Arc}; | ||
10 | use tracing::{debug, info}; | 10 | use tracing::{debug, info}; |
11 | use utoipa::ToSchema; | 11 | use utoipa::ToSchema; |
12 | 12 | ||
13 | #[utoipa::path( | 13 | #[utoipa::path( |
14 | get, | 14 | get, |
15 | path = "/device", | ||
16 | request_body = GetDevicePayload, | ||
17 | responses( | ||
18 | (status = 200, description = "Get `Device` information", body = [Device]) | ||
19 | ), | ||
20 | security(("api_key" = [])) | ||
21 | )] | ||
22 | #[deprecated] | ||
23 | pub async fn get_payload( | ||
24 | State(state): State<Arc<crate::AppState>>, | ||
25 | Json(payload): Json<GetDevicePayload>, | ||
26 | ) -> Result<Json<Value>, Error> { | ||
27 | info!("get device {}", payload.id); | ||
28 | let device = sqlx::query_as!( | ||
29 | Device, | ||
30 | r#" | ||
31 | SELECT id, mac, broadcast_addr, ip, times | ||
32 | FROM devices | ||
33 | WHERE id = $1; | ||
34 | "#, | ||
35 | payload.id | ||
36 | ) | ||
37 | .fetch_one(&state.db) | ||
38 | .await?; | ||
39 | |||
40 | debug!("got device {:?}", device); | ||
41 | |||
42 | Ok(Json(json!(device))) | ||
43 | } | ||
44 | |||
45 | #[utoipa::path( | ||
46 | get, | ||
47 | path = "/device/{id}", | 15 | path = "/device/{id}", |
48 | responses( | 16 | responses( |
49 | (status = 200, description = "Get `Device` information", body = [Device]) | 17 | (status = 200, description = "Get `Device` information", body = [Device]) |
@@ -53,22 +21,10 @@ pub async fn get_payload( | |||
53 | ), | 21 | ), |
54 | security((), ("api_key" = [])) | 22 | security((), ("api_key" = [])) |
55 | )] | 23 | )] |
56 | pub async fn get( | 24 | pub async fn get(Path(id): Path<String>) -> Result<Json<Value>, Error> { |
57 | State(state): State<Arc<crate::AppState>>, | 25 | info!("get device from path {}", id); |
58 | Path(path): Path<String>, | 26 | |
59 | ) -> Result<Json<Value>, Error> { | 27 | let device = Device::read(&id)?; |
60 | info!("get device from path {}", path); | ||
61 | let device = sqlx::query_as!( | ||
62 | Device, | ||
63 | r#" | ||
64 | SELECT id, mac, broadcast_addr, ip, times | ||
65 | FROM devices | ||
66 | WHERE id = $1; | ||
67 | "#, | ||
68 | path | ||
69 | ) | ||
70 | .fetch_one(&state.db) | ||
71 | .await?; | ||
72 | 28 | ||
73 | debug!("got device {:?}", device); | 29 | debug!("got device {:?}", device); |
74 | 30 | ||
@@ -76,13 +32,7 @@ pub async fn get( | |||
76 | } | 32 | } |
77 | 33 | ||
78 | #[derive(Deserialize, ToSchema)] | 34 | #[derive(Deserialize, ToSchema)] |
79 | #[deprecated] | 35 | pub struct DPayload { |
80 | pub struct GetDevicePayload { | ||
81 | id: String, | ||
82 | } | ||
83 | |||
84 | #[derive(Deserialize, ToSchema)] | ||
85 | pub struct DevicePayload { | ||
86 | id: String, | 36 | id: String, |
87 | mac: String, | 37 | mac: String, |
88 | broadcast_addr: String, | 38 | broadcast_addr: String, |
@@ -92,15 +42,14 @@ pub struct DevicePayload { | |||
92 | #[utoipa::path( | 42 | #[utoipa::path( |
93 | put, | 43 | put, |
94 | path = "/device", | 44 | path = "/device", |
95 | request_body = DevicePayload, | 45 | request_body = DPayload, |
96 | responses( | 46 | responses( |
97 | (status = 200, description = "add device to storage", body = [DeviceSchema]) | 47 | (status = 200, description = "add device to storage", body = [DeviceSchema]) |
98 | ), | 48 | ), |
99 | security((), ("api_key" = [])) | 49 | security((), ("api_key" = [])) |
100 | )] | 50 | )] |
101 | pub async fn put( | 51 | pub async fn put( |
102 | State(state): State<Arc<crate::AppState>>, | 52 | Json(payload): Json<DPayload>, |
103 | Json(payload): Json<DevicePayload>, | ||
104 | ) -> Result<Json<Value>, Error> { | 53 | ) -> Result<Json<Value>, Error> { |
105 | info!( | 54 | info!( |
106 | "add device {} ({}, {}, {})", | 55 | "add device {} ({}, {}, {})", |
@@ -109,20 +58,14 @@ pub async fn put( | |||
109 | 58 | ||
110 | let ip = IpNetwork::from_str(&payload.ip)?; | 59 | let ip = IpNetwork::from_str(&payload.ip)?; |
111 | let mac = MacAddress::from_str(&payload.mac)?; | 60 | let mac = MacAddress::from_str(&payload.mac)?; |
112 | let device = sqlx::query_as!( | 61 | let device = Device { |
113 | Device, | 62 | id: payload.id, |
114 | r#" | ||
115 | INSERT INTO devices (id, mac, broadcast_addr, ip) | ||
116 | VALUES ($1, $2, $3, $4) | ||
117 | RETURNING id, mac, broadcast_addr, ip, times; | ||
118 | "#, | ||
119 | payload.id, | ||
120 | mac, | 63 | mac, |
121 | payload.broadcast_addr, | 64 | broadcast_addr: payload.broadcast_addr, |
122 | ip | 65 | ip, |
123 | ) | 66 | times: None, |
124 | .fetch_one(&state.db) | 67 | }; |
125 | .await?; | 68 | device.write()?; |
126 | 69 | ||
127 | Ok(Json(json!(device))) | 70 | Ok(Json(json!(device))) |
128 | } | 71 | } |
@@ -130,15 +73,14 @@ pub async fn put( | |||
130 | #[utoipa::path( | 73 | #[utoipa::path( |
131 | post, | 74 | post, |
132 | path = "/device", | 75 | path = "/device", |
133 | request_body = DevicePayload, | 76 | request_body = DPayload, |
134 | responses( | 77 | responses( |
135 | (status = 200, description = "update device in storage", body = [DeviceSchema]) | 78 | (status = 200, description = "update device in storage", body = [DeviceSchema]) |
136 | ), | 79 | ), |
137 | security((), ("api_key" = [])) | 80 | security((), ("api_key" = [])) |
138 | )] | 81 | )] |
139 | pub async fn post( | 82 | pub async fn post( |
140 | State(state): State<Arc<crate::AppState>>, | 83 | Json(payload): Json<DPayload>, |
141 | Json(payload): Json<DevicePayload>, | ||
142 | ) -> Result<Json<Value>, Error> { | 84 | ) -> Result<Json<Value>, Error> { |
143 | info!( | 85 | info!( |
144 | "edit device {} ({}, {}, {})", | 86 | "edit device {} ({}, {}, {})", |
@@ -146,20 +88,16 @@ pub async fn post( | |||
146 | ); | 88 | ); |
147 | let ip = IpNetwork::from_str(&payload.ip)?; | 89 | let ip = IpNetwork::from_str(&payload.ip)?; |
148 | let mac = MacAddress::from_str(&payload.mac)?; | 90 | let mac = MacAddress::from_str(&payload.mac)?; |
149 | let device = sqlx::query_as!( | 91 | let times = Device::read(&payload.id)?.times; |
150 | Device, | 92 | |
151 | r#" | 93 | let device = Device { |
152 | UPDATE devices | 94 | id: payload.id, |
153 | SET mac = $1, broadcast_addr = $2, ip = $3 WHERE id = $4 | ||
154 | RETURNING id, mac, broadcast_addr, ip, times; | ||
155 | "#, | ||
156 | mac, | 95 | mac, |
157 | payload.broadcast_addr, | 96 | broadcast_addr: payload.broadcast_addr, |
158 | ip, | 97 | ip, |
159 | payload.id | 98 | times, |
160 | ) | 99 | }; |
161 | .fetch_one(&state.db) | 100 | device.write()?; |
162 | .await?; | ||
163 | 101 | ||
164 | Ok(Json(json!(device))) | 102 | Ok(Json(json!(device))) |
165 | } | 103 | } |
diff --git a/src/routes/start.rs b/src/routes/start.rs index ff3d1be..ae2b384 100644 --- a/src/routes/start.rs +++ b/src/routes/start.rs | |||
@@ -1,7 +1,7 @@ | |||
1 | use crate::db::Device; | 1 | use crate::storage::Device; |
2 | use crate::error::Error; | 2 | use crate::error::Error; |
3 | use crate::services::ping::Value as PingValue; | 3 | use crate::services::ping::Value as PingValue; |
4 | use crate::wol::{create_buffer, send_packet}; | 4 | use crate::wol::send_packet; |
5 | use axum::extract::{Path, State}; | 5 | use axum::extract::{Path, State}; |
6 | use axum::Json; | 6 | use axum::Json; |
7 | use serde::{Deserialize, Serialize}; | 7 | use serde::{Deserialize, Serialize}; |
@@ -13,57 +13,8 @@ use uuid::Uuid; | |||
13 | 13 | ||
14 | #[utoipa::path( | 14 | #[utoipa::path( |
15 | post, | 15 | post, |
16 | path = "/start", | ||
17 | request_body = PayloadOld, | ||
18 | responses( | ||
19 | (status = 200, description = "DEP", body = [Response]) | ||
20 | ), | ||
21 | security((), ("api_key" = [])) | ||
22 | )] | ||
23 | #[deprecated] | ||
24 | pub async fn start_payload( | ||
25 | State(state): State<Arc<crate::AppState>>, | ||
26 | Json(payload): Json<PayloadOld>, | ||
27 | ) -> Result<Json<Value>, Error> { | ||
28 | info!("POST request"); | ||
29 | let device = sqlx::query_as!( | ||
30 | Device, | ||
31 | r#" | ||
32 | SELECT id, mac, broadcast_addr, ip, times | ||
33 | FROM devices | ||
34 | WHERE id = $1; | ||
35 | "#, | ||
36 | payload.id | ||
37 | ) | ||
38 | .fetch_one(&state.db) | ||
39 | .await?; | ||
40 | |||
41 | info!("starting {}", device.id); | ||
42 | |||
43 | let bind_addr = "0.0.0.0:0"; | ||
44 | |||
45 | let _ = send_packet( | ||
46 | bind_addr, | ||
47 | &device.broadcast_addr, | ||
48 | &create_buffer(&device.mac.to_string())?, | ||
49 | )?; | ||
50 | let dev_id = device.id.clone(); | ||
51 | let uuid = if payload.ping.is_some_and(|ping| ping) { | ||
52 | Some(setup_ping(state, device)) | ||
53 | } else { | ||
54 | None | ||
55 | }; | ||
56 | Ok(Json(json!(Response { | ||
57 | id: dev_id, | ||
58 | boot: true, | ||
59 | uuid | ||
60 | }))) | ||
61 | } | ||
62 | |||
63 | #[utoipa::path( | ||
64 | post, | ||
65 | path = "/start/{id}", | 16 | path = "/start/{id}", |
66 | request_body = Option<Payload>, | 17 | request_body = Option<SPayload>, |
67 | responses( | 18 | responses( |
68 | (status = 200, description = "start the device with the given id", body = [Response]) | 19 | (status = 200, description = "start the device with the given id", body = [Response]) |
69 | ), | 20 | ), |
@@ -75,9 +26,9 @@ pub async fn start_payload( | |||
75 | pub async fn post( | 26 | pub async fn post( |
76 | State(state): State<Arc<crate::AppState>>, | 27 | State(state): State<Arc<crate::AppState>>, |
77 | Path(id): Path<String>, | 28 | Path(id): Path<String>, |
78 | payload: Option<Json<Payload>>, | 29 | payload: Option<Json<SPayload>>, |
79 | ) -> Result<Json<Value>, Error> { | 30 | ) -> Result<Json<Value>, Error> { |
80 | send_wol(state, &id, payload).await | 31 | send_wol(state, &id, payload) |
81 | } | 32 | } |
82 | 33 | ||
83 | #[utoipa::path( | 34 | #[utoipa::path( |
@@ -95,26 +46,16 @@ pub async fn get( | |||
95 | State(state): State<Arc<crate::AppState>>, | 46 | State(state): State<Arc<crate::AppState>>, |
96 | Path(id): Path<String>, | 47 | Path(id): Path<String>, |
97 | ) -> Result<Json<Value>, Error> { | 48 | ) -> Result<Json<Value>, Error> { |
98 | send_wol(state, &id, None).await | 49 | send_wol(state, &id, None) |
99 | } | 50 | } |
100 | 51 | ||
101 | async fn send_wol( | 52 | fn send_wol( |
102 | state: Arc<crate::AppState>, | 53 | state: Arc<crate::AppState>, |
103 | id: &str, | 54 | id: &str, |
104 | payload: Option<Json<Payload>>, | 55 | payload: Option<Json<SPayload>>, |
105 | ) -> Result<Json<Value>, Error> { | 56 | ) -> Result<Json<Value>, Error> { |
106 | info!("Start request for {id}"); | 57 | info!("start request for {id}"); |
107 | let device = sqlx::query_as!( | 58 | let device = Device::read(id)?; |
108 | Device, | ||
109 | r#" | ||
110 | SELECT id, mac, broadcast_addr, ip, times | ||
111 | FROM devices | ||
112 | WHERE id = $1; | ||
113 | "#, | ||
114 | id | ||
115 | ) | ||
116 | .fetch_one(&state.db) | ||
117 | .await?; | ||
118 | 59 | ||
119 | info!("starting {}", device.id); | 60 | info!("starting {}", device.id); |
120 | 61 | ||
@@ -122,8 +63,8 @@ async fn send_wol( | |||
122 | 63 | ||
123 | let _ = send_packet( | 64 | let _ = send_packet( |
124 | bind_addr, | 65 | bind_addr, |
125 | &device.broadcast_addr, | 66 | &device.broadcast_addr.to_string(), |
126 | &create_buffer(&device.mac.to_string())?, | 67 | &device.mac.bytes() |
127 | )?; | 68 | )?; |
128 | let dev_id = device.id.clone(); | 69 | let dev_id = device.id.clone(); |
129 | let uuid = if let Some(pl) = payload { | 70 | let uuid = if let Some(pl) = payload { |
@@ -163,6 +104,7 @@ fn setup_ping(state: Arc<crate::AppState>, device: Device) -> String { | |||
163 | uuid_gen.clone(), | 104 | uuid_gen.clone(), |
164 | PingValue { | 105 | PingValue { |
165 | ip: device.ip, | 106 | ip: device.ip, |
107 | eta: get_eta(device.clone().times), | ||
166 | online: false, | 108 | online: false, |
167 | }, | 109 | }, |
168 | ); | 110 | ); |
@@ -174,7 +116,6 @@ fn setup_ping(state: Arc<crate::AppState>, device: Device) -> String { | |||
174 | device, | 116 | device, |
175 | uuid_gen, | 117 | uuid_gen, |
176 | &state.ping_map, | 118 | &state.ping_map, |
177 | &state.db, | ||
178 | ) | 119 | ) |
179 | .await; | 120 | .await; |
180 | }); | 121 | }); |
@@ -182,15 +123,18 @@ fn setup_ping(state: Arc<crate::AppState>, device: Device) -> String { | |||
182 | uuid_ret | 123 | uuid_ret |
183 | } | 124 | } |
184 | 125 | ||
185 | #[derive(Deserialize, ToSchema)] | 126 | fn get_eta(times: Option<Vec<i64>>) -> i64 { |
186 | #[deprecated] | 127 | let times = if let Some(times) = times { |
187 | pub struct PayloadOld { | 128 | times |
188 | id: String, | 129 | } else { |
189 | ping: Option<bool>, | 130 | vec![0] |
131 | }; | ||
132 | |||
133 | times.iter().sum::<i64>() / i64::try_from(times.len()).unwrap() | ||
190 | } | 134 | } |
191 | 135 | ||
192 | #[derive(Deserialize, ToSchema)] | 136 | #[derive(Deserialize, ToSchema)] |
193 | pub struct Payload { | 137 | pub struct SPayload { |
194 | ping: Option<bool>, | 138 | ping: Option<bool>, |
195 | } | 139 | } |
196 | 140 | ||
diff --git a/src/routes/status.rs b/src/routes/status.rs index 0e25f7d..b38202b 100644 --- a/src/routes/status.rs +++ b/src/routes/status.rs | |||
@@ -3,7 +3,6 @@ use crate::AppState; | |||
3 | use axum::extract::ws::{Message, WebSocket}; | 3 | use axum::extract::ws::{Message, WebSocket}; |
4 | use axum::extract::{State, WebSocketUpgrade}; | 4 | use axum::extract::{State, WebSocketUpgrade}; |
5 | use axum::response::Response; | 5 | use axum::response::Response; |
6 | use sqlx::PgPool; | ||
7 | use std::sync::Arc; | 6 | use std::sync::Arc; |
8 | use tracing::{debug, trace}; | 7 | use tracing::{debug, trace}; |
9 | 8 | ||
@@ -18,13 +17,13 @@ pub async fn websocket(mut socket: WebSocket, state: Arc<AppState>) { | |||
18 | 17 | ||
19 | trace!("Search for uuid: {}", uuid); | 18 | trace!("Search for uuid: {}", uuid); |
20 | 19 | ||
21 | let eta = get_eta(&state.db).await; | ||
22 | let _ = socket | ||
23 | .send(Message::Text(format!("eta_{eta}_{uuid}"))) | ||
24 | .await; | ||
25 | 20 | ||
26 | let device_exists = state.ping_map.contains_key(&uuid); | 21 | let device_exists = state.ping_map.contains_key(&uuid); |
27 | if device_exists { | 22 | if device_exists { |
23 | let eta = state.ping_map.get(&uuid).unwrap().eta; | ||
24 | let _ = socket | ||
25 | .send(Message::Text(format!("eta_{eta}_{uuid}"))) | ||
26 | .await; | ||
28 | let _ = socket | 27 | let _ = socket |
29 | .send(receive_ping_broadcast(state.clone(), uuid).await) | 28 | .send(receive_ping_broadcast(state.clone(), uuid).await) |
30 | .await; | 29 | .await; |
@@ -62,18 +61,3 @@ async fn receive_ping_broadcast(state: Arc<AppState>, uuid: String) -> Message { | |||
62 | } | 61 | } |
63 | } | 62 | } |
64 | } | 63 | } |
65 | |||
66 | async fn get_eta(db: &PgPool) -> i64 { | ||
67 | let query = sqlx::query!(r#"SELECT times FROM devices;"#) | ||
68 | .fetch_one(db) | ||
69 | .await | ||
70 | .unwrap(); | ||
71 | |||
72 | let times = if let Some(times) = query.times { | ||
73 | times | ||
74 | } else { | ||
75 | vec![0] | ||
76 | }; | ||
77 | |||
78 | times.iter().sum::<i64>() / i64::try_from(times.len()).unwrap() | ||
79 | } | ||
diff --git a/src/services/ping.rs b/src/services/ping.rs index 8cf6072..1bf022d 100644 --- a/src/services/ping.rs +++ b/src/services/ping.rs | |||
@@ -1,8 +1,7 @@ | |||
1 | use crate::config::Config; | 1 | use crate::config::Config; |
2 | use crate::db::Device; | 2 | use crate::storage::Device; |
3 | use dashmap::DashMap; | 3 | use dashmap::DashMap; |
4 | use ipnetwork::IpNetwork; | 4 | use ipnetwork::IpNetwork; |
5 | use sqlx::PgPool; | ||
6 | use std::fmt::Display; | 5 | use std::fmt::Display; |
7 | use time::{Duration, Instant}; | 6 | use time::{Duration, Instant}; |
8 | use tokio::sync::broadcast::Sender; | 7 | use tokio::sync::broadcast::Sender; |
@@ -13,6 +12,7 @@ pub type StatusMap = DashMap<String, Value>; | |||
13 | #[derive(Debug, Clone)] | 12 | #[derive(Debug, Clone)] |
14 | pub struct Value { | 13 | pub struct Value { |
15 | pub ip: IpNetwork, | 14 | pub ip: IpNetwork, |
15 | pub eta: i64, | ||
16 | pub online: bool, | 16 | pub online: bool, |
17 | } | 17 | } |
18 | 18 | ||
@@ -22,7 +22,6 @@ pub async fn spawn( | |||
22 | device: Device, | 22 | device: Device, |
23 | uuid: String, | 23 | uuid: String, |
24 | ping_map: &StatusMap, | 24 | ping_map: &StatusMap, |
25 | db: &PgPool, | ||
26 | ) { | 25 | ) { |
27 | let timer = Instant::now(); | 26 | let timer = Instant::now(); |
28 | let payload = [0; 8]; | 27 | let payload = [0; 8]; |
@@ -56,27 +55,29 @@ pub async fn spawn( | |||
56 | let _ = tx.send(msg.clone()); | 55 | let _ = tx.send(msg.clone()); |
57 | if msg.command == BroadcastCommands::Success { | 56 | if msg.command == BroadcastCommands::Success { |
58 | if timer.elapsed().whole_seconds() > config.pingthreshold { | 57 | if timer.elapsed().whole_seconds() > config.pingthreshold { |
59 | sqlx::query!( | 58 | let newtimes = if let Some(mut oldtimes) = device.times { |
60 | r#" | 59 | oldtimes.push(timer.elapsed().whole_seconds()); |
61 | UPDATE devices | 60 | oldtimes |
62 | SET times = array_append(times, $1) | 61 | } else { |
63 | WHERE id = $2; | 62 | vec![timer.elapsed().whole_seconds()] |
64 | "#, | 63 | }; |
65 | timer.elapsed().whole_seconds(), | 64 | |
66 | device.id | 65 | let updatedev = Device { |
67 | ) | 66 | id: device.id, |
68 | .execute(db) | 67 | mac: device.mac, |
69 | .await | 68 | broadcast_addr: device.broadcast_addr, |
70 | .unwrap(); | 69 | ip: device.ip, |
70 | times: Some(newtimes), | ||
71 | }; | ||
72 | updatedev.write().unwrap(); | ||
71 | } | 73 | } |
72 | 74 | ||
73 | ping_map.insert( | 75 | ping_map.alter(&uuid, |_, v| Value { |
74 | uuid.clone(), | 76 | ip: v.ip, |
75 | Value { | 77 | eta: v.eta, |
76 | ip: device.ip, | 78 | online: true, |
77 | online: true, | 79 | }); |
78 | }, | 80 | |
79 | ); | ||
80 | tokio::time::sleep(tokio::time::Duration::from_secs(60)).await; | 81 | tokio::time::sleep(tokio::time::Duration::from_secs(60)).await; |
81 | } | 82 | } |
82 | trace!("remove {} from ping_map", uuid); | 83 | trace!("remove {} from ping_map", uuid); |
diff --git a/src/storage.rs b/src/storage.rs new file mode 100644 index 0000000..0da245b --- /dev/null +++ b/src/storage.rs | |||
@@ -0,0 +1,70 @@ | |||
1 | use std::{ | ||
2 | fs::{create_dir_all, File}, | ||
3 | io::{Read, Write}, | ||
4 | path::Path, | ||
5 | }; | ||
6 | |||
7 | use ipnetwork::IpNetwork; | ||
8 | use mac_address::MacAddress; | ||
9 | use serde::{Deserialize, Serialize}; | ||
10 | use serde_json::json; | ||
11 | use tracing::{debug, trace, warn}; | ||
12 | use utoipa::ToSchema; | ||
13 | |||
14 | use crate::error::Error; | ||
15 | |||
16 | #[derive(Serialize, Deserialize, Clone, Debug)] | ||
17 | pub struct Device { | ||
18 | pub id: String, | ||
19 | pub mac: MacAddress, | ||
20 | pub broadcast_addr: String, | ||
21 | pub ip: IpNetwork, | ||
22 | pub times: Option<Vec<i64>>, | ||
23 | } | ||
24 | |||
25 | impl Device { | ||
26 | const STORAGE_PATH: &'static str = "devices"; | ||
27 | |||
28 | pub fn setup() -> Result<String, Error> { | ||
29 | trace!("check for storage at {}", Self::STORAGE_PATH); | ||
30 | let sp = Path::new(Self::STORAGE_PATH); | ||
31 | if !sp.exists() { | ||
32 | warn!("device storage path doesn't exist, creating it"); | ||
33 | create_dir_all(Self::STORAGE_PATH)?; | ||
34 | }; | ||
35 | |||
36 | debug!("device storage at '{}'", Self::STORAGE_PATH); | ||
37 | |||
38 | Ok(Self::STORAGE_PATH.to_string()) | ||
39 | } | ||
40 | |||
41 | pub fn read(id: &str) -> Result<Self, Error> { | ||
42 | trace!(?id, "attempt to read file"); | ||
43 | let mut file = File::open(format!("{}/{id}.json", Self::STORAGE_PATH))?; | ||
44 | let mut buf = String::new(); | ||
45 | file.read_to_string(&mut buf)?; | ||
46 | trace!(?id, ?buf, "read successfully from file"); | ||
47 | |||
48 | let dev = serde_json::from_str(&buf)?; | ||
49 | Ok(dev) | ||
50 | } | ||
51 | |||
52 | pub fn write(&self) -> Result<(), Error> { | ||
53 | trace!(?self.id, ?self, "attempt to write to file"); | ||
54 | let mut file = File::create(format!("{}/{}.json", Self::STORAGE_PATH, self.id))?; | ||
55 | file.write_all(json!(self).to_string().as_bytes())?; | ||
56 | trace!(?self.id, "wrote successfully to file"); | ||
57 | |||
58 | Ok(()) | ||
59 | } | ||
60 | } | ||
61 | |||
62 | #[derive(ToSchema)] | ||
63 | #[schema(as = Device)] | ||
64 | pub struct DeviceSchema { | ||
65 | pub id: String, | ||
66 | pub mac: String, | ||
67 | pub broadcast_addr: String, | ||
68 | pub ip: String, | ||
69 | pub times: Option<Vec<i64>>, | ||
70 | } | ||
@@ -2,26 +2,6 @@ use std::net::{ToSocketAddrs, UdpSocket}; | |||
2 | 2 | ||
3 | use crate::error::Error; | 3 | use crate::error::Error; |
4 | 4 | ||
5 | /// Creates the magic packet from a mac address | ||
6 | /// | ||
7 | /// # Panics | ||
8 | /// | ||
9 | /// Panics if `mac_addr` is an invalid mac | ||
10 | pub fn create_buffer(mac_addr: &str) -> Result<Vec<u8>, Error> { | ||
11 | let mut mac = Vec::new(); | ||
12 | let sp = mac_addr.split(':'); | ||
13 | for f in sp { | ||
14 | mac.push(u8::from_str_radix(f, 16)?); | ||
15 | } | ||
16 | let mut buf = vec![255; 6]; | ||
17 | for _ in 0..16 { | ||
18 | for i in &mac { | ||
19 | buf.push(*i); | ||
20 | } | ||
21 | } | ||
22 | Ok(buf) | ||
23 | } | ||
24 | |||
25 | /// Sends a buffer on UDP broadcast | 5 | /// Sends a buffer on UDP broadcast |
26 | pub fn send_packet<A: ToSocketAddrs>( | 6 | pub fn send_packet<A: ToSocketAddrs>( |
27 | bind_addr: A, | 7 | bind_addr: A, |