diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/auth.rs | 37 | ||||
-rw-r--r-- | src/config.rs | 10 | ||||
-rw-r--r-- | src/extractors.rs | 24 | ||||
-rw-r--r-- | src/main.rs | 24 | ||||
-rw-r--r-- | src/routes/device.rs | 49 | ||||
-rw-r--r-- | src/routes/start.rs | 106 |
6 files changed, 179 insertions, 71 deletions
diff --git a/src/auth.rs b/src/auth.rs new file mode 100644 index 0000000..74008b5 --- /dev/null +++ b/src/auth.rs | |||
@@ -0,0 +1,37 @@ | |||
1 | use crate::AppState; | ||
2 | use axum::{ | ||
3 | extract::{Request, State}, | ||
4 | http::{HeaderMap, StatusCode}, | ||
5 | middleware::Next, | ||
6 | response::Response, | ||
7 | }; | ||
8 | use serde::Deserialize; | ||
9 | |||
10 | #[derive(Debug, Clone, Deserialize)] | ||
11 | pub enum Methods { | ||
12 | Key, | ||
13 | None, | ||
14 | } | ||
15 | |||
16 | pub async fn auth( | ||
17 | State(state): State<AppState>, | ||
18 | headers: HeaderMap, | ||
19 | request: Request, | ||
20 | next: Next, | ||
21 | ) -> Result<Response, StatusCode> { | ||
22 | let auth = state.config.auth; | ||
23 | match auth.method { | ||
24 | Methods::Key => { | ||
25 | if let Some(secret) = headers.get("authorization") { | ||
26 | if auth.secret.as_str() != secret { | ||
27 | return Err(StatusCode::UNAUTHORIZED); | ||
28 | }; | ||
29 | let response = next.run(request).await; | ||
30 | Ok(response) | ||
31 | } else { | ||
32 | Err(StatusCode::UNAUTHORIZED) | ||
33 | } | ||
34 | } | ||
35 | Methods::None => Ok(next.run(request).await), | ||
36 | } | ||
37 | } | ||
diff --git a/src/config.rs b/src/config.rs index 9605361..9636af4 100644 --- a/src/config.rs +++ b/src/config.rs | |||
@@ -1,14 +1,22 @@ | |||
1 | use config::File; | 1 | use config::File; |
2 | use serde::Deserialize; | 2 | use serde::Deserialize; |
3 | 3 | ||
4 | use crate::auth; | ||
5 | |||
4 | #[derive(Debug, Clone, Deserialize)] | 6 | #[derive(Debug, Clone, Deserialize)] |
5 | pub struct Config { | 7 | pub struct Config { |
6 | pub database_url: String, | 8 | pub database_url: String, |
7 | pub apikey: String, | ||
8 | pub serveraddr: String, | 9 | pub serveraddr: String, |
9 | pub pingtimeout: i64, | 10 | pub pingtimeout: i64, |
10 | pub pingthreshold: i64, | 11 | pub pingthreshold: i64, |
11 | pub timeoffset: i8, | 12 | pub timeoffset: i8, |
13 | pub auth: Auth, | ||
14 | } | ||
15 | |||
16 | #[derive(Debug, Clone, Deserialize)] | ||
17 | pub struct Auth { | ||
18 | pub method: auth::Methods, | ||
19 | pub secret: String, | ||
12 | } | 20 | } |
13 | 21 | ||
14 | impl Config { | 22 | impl Config { |
diff --git a/src/extractors.rs b/src/extractors.rs deleted file mode 100644 index 4d441e9..0000000 --- a/src/extractors.rs +++ /dev/null | |||
@@ -1,24 +0,0 @@ | |||
1 | use axum::{ | ||
2 | extract::{Request, State}, | ||
3 | http::{HeaderMap, StatusCode}, | ||
4 | middleware::Next, | ||
5 | response::Response, | ||
6 | }; | ||
7 | |||
8 | use crate::AppState; | ||
9 | |||
10 | pub async fn auth( | ||
11 | State(state): State<AppState>, | ||
12 | headers: HeaderMap, | ||
13 | request: Request, | ||
14 | next: Next, | ||
15 | ) -> Result<Response, StatusCode> { | ||
16 | let secret = headers.get("authorization"); | ||
17 | match secret { | ||
18 | Some(token) if token == state.config.apikey.as_str() => { | ||
19 | let response = next.run(request).await; | ||
20 | Ok(response) | ||
21 | } | ||
22 | _ => Err(StatusCode::UNAUTHORIZED), | ||
23 | } | ||
24 | } | ||
diff --git a/src/main.rs b/src/main.rs index 00fc6ce..70c67cf 100644 --- a/src/main.rs +++ b/src/main.rs | |||
@@ -11,8 +11,8 @@ use axum::{ | |||
11 | }; | 11 | }; |
12 | use dashmap::DashMap; | 12 | use dashmap::DashMap; |
13 | use sqlx::PgPool; | 13 | use sqlx::PgPool; |
14 | use time::UtcOffset; | ||
15 | use std::{env, sync::Arc}; | 14 | use std::{env, sync::Arc}; |
15 | use time::UtcOffset; | ||
16 | use tokio::sync::broadcast::{channel, Sender}; | 16 | use tokio::sync::broadcast::{channel, Sender}; |
17 | use tracing::{info, level_filters::LevelFilter}; | 17 | use tracing::{info, level_filters::LevelFilter}; |
18 | use tracing_subscriber::{ | 18 | use tracing_subscriber::{ |
@@ -29,7 +29,7 @@ use utoipa_swagger_ui::SwaggerUi; | |||
29 | mod config; | 29 | mod config; |
30 | mod db; | 30 | mod db; |
31 | mod error; | 31 | mod error; |
32 | mod extractors; | 32 | mod auth; |
33 | mod routes; | 33 | mod routes; |
34 | mod services; | 34 | mod services; |
35 | mod wol; | 35 | mod wol; |
@@ -37,19 +37,21 @@ mod wol; | |||
37 | #[derive(OpenApi)] | 37 | #[derive(OpenApi)] |
38 | #[openapi( | 38 | #[openapi( |
39 | paths( | 39 | paths( |
40 | start::start, | 40 | start::post, |
41 | start::get, | ||
42 | start::start_payload, | ||
41 | device::get, | 43 | device::get, |
42 | device::get_path, | 44 | device::get_payload, |
43 | device::post, | 45 | device::post, |
44 | device::put, | 46 | device::put, |
45 | ), | 47 | ), |
46 | components( | 48 | components( |
47 | schemas( | 49 | schemas( |
50 | start::PayloadOld, | ||
48 | start::Payload, | 51 | start::Payload, |
49 | start::Response, | 52 | start::Response, |
50 | device::PutDevicePayload, | 53 | device::DevicePayload, |
51 | device::GetDevicePayload, | 54 | device::GetDevicePayload, |
52 | device::PostDevicePayload, | ||
53 | db::DeviceSchema, | 55 | db::DeviceSchema, |
54 | ) | 56 | ) |
55 | ), | 57 | ), |
@@ -116,14 +118,16 @@ async fn main() -> color_eyre::eyre::Result<()> { | |||
116 | }; | 118 | }; |
117 | 119 | ||
118 | let app = Router::new() | 120 | let app = Router::new() |
119 | .route("/start", post(start::start)) | 121 | .route("/start", post(start::start_payload)) |
122 | .route("/start/:id", post(start::post).get(start::get)) | ||
120 | .route( | 123 | .route( |
121 | "/device", | 124 | "/device", |
122 | post(device::post).get(device::get).put(device::put), | 125 | post(device::post).get(device::get_payload).put(device::put), |
123 | ) | 126 | ) |
124 | .route("/device/:id", get(device::get_path)) | 127 | .route("/device/:id", get(device::get)) |
125 | .route("/status", get(status::status)) | 128 | .route("/status", get(status::status)) |
126 | .route_layer(from_fn_with_state(shared_state.clone(), extractors::auth)) | 129 | // TODO: Don't load on `None` Auth |
130 | .route_layer(from_fn_with_state(shared_state.clone(), auth::auth)) | ||
127 | .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi())) | 131 | .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi())) |
128 | .with_state(Arc::new(shared_state)); | 132 | .with_state(Arc::new(shared_state)); |
129 | 133 | ||
diff --git a/src/routes/device.rs b/src/routes/device.rs index d01d9f0..40b5cd8 100644 --- a/src/routes/device.rs +++ b/src/routes/device.rs | |||
@@ -20,7 +20,7 @@ use utoipa::ToSchema; | |||
20 | security(("api_key" = [])) | 20 | security(("api_key" = [])) |
21 | )] | 21 | )] |
22 | #[deprecated] | 22 | #[deprecated] |
23 | pub async fn get( | 23 | pub async fn get_payload( |
24 | State(state): State<Arc<crate::AppState>>, | 24 | State(state): State<Arc<crate::AppState>>, |
25 | Json(payload): Json<GetDevicePayload>, | 25 | Json(payload): Json<GetDevicePayload>, |
26 | ) -> Result<Json<Value>, Error> { | 26 | ) -> Result<Json<Value>, Error> { |
@@ -49,11 +49,11 @@ pub async fn get( | |||
49 | (status = 200, description = "Get `Device` information", body = [Device]) | 49 | (status = 200, description = "Get `Device` information", body = [Device]) |
50 | ), | 50 | ), |
51 | params( | 51 | params( |
52 | ("id" = String, Path, description = "Device id") | 52 | ("id" = String, Path, description = "device id") |
53 | ), | 53 | ), |
54 | security(("api_key" = [])) | 54 | security((), ("api_key" = [])) |
55 | )] | 55 | )] |
56 | pub async fn get_path( | 56 | pub async fn get( |
57 | State(state): State<Arc<crate::AppState>>, | 57 | State(state): State<Arc<crate::AppState>>, |
58 | Path(path): Path<String>, | 58 | Path(path): Path<String>, |
59 | ) -> Result<Json<Value>, Error> { | 59 | ) -> Result<Json<Value>, Error> { |
@@ -76,22 +76,31 @@ pub async fn get_path( | |||
76 | } | 76 | } |
77 | 77 | ||
78 | #[derive(Deserialize, ToSchema)] | 78 | #[derive(Deserialize, ToSchema)] |
79 | #[deprecated] | ||
79 | pub struct GetDevicePayload { | 80 | pub struct GetDevicePayload { |
80 | id: String, | 81 | id: String, |
81 | } | 82 | } |
82 | 83 | ||
84 | #[derive(Deserialize, ToSchema)] | ||
85 | pub struct DevicePayload { | ||
86 | id: String, | ||
87 | mac: String, | ||
88 | broadcast_addr: String, | ||
89 | ip: String, | ||
90 | } | ||
91 | |||
83 | #[utoipa::path( | 92 | #[utoipa::path( |
84 | put, | 93 | put, |
85 | path = "/device", | 94 | path = "/device", |
86 | request_body = PutDevicePayload, | 95 | request_body = DevicePayload, |
87 | responses( | 96 | responses( |
88 | (status = 200, description = "List matching todos by query", body = [DeviceSchema]) | 97 | (status = 200, description = "add device to storage", body = [DeviceSchema]) |
89 | ), | 98 | ), |
90 | security(("api_key" = [])) | 99 | security((), ("api_key" = [])) |
91 | )] | 100 | )] |
92 | pub async fn put( | 101 | pub async fn put( |
93 | State(state): State<Arc<crate::AppState>>, | 102 | State(state): State<Arc<crate::AppState>>, |
94 | Json(payload): Json<PutDevicePayload>, | 103 | Json(payload): Json<DevicePayload>, |
95 | ) -> Result<Json<Value>, Error> { | 104 | ) -> Result<Json<Value>, Error> { |
96 | info!( | 105 | info!( |
97 | "add device {} ({}, {}, {})", | 106 | "add device {} ({}, {}, {})", |
@@ -118,26 +127,18 @@ pub async fn put( | |||
118 | Ok(Json(json!(device))) | 127 | Ok(Json(json!(device))) |
119 | } | 128 | } |
120 | 129 | ||
121 | #[derive(Deserialize, ToSchema)] | ||
122 | pub struct PutDevicePayload { | ||
123 | id: String, | ||
124 | mac: String, | ||
125 | broadcast_addr: String, | ||
126 | ip: String, | ||
127 | } | ||
128 | |||
129 | #[utoipa::path( | 130 | #[utoipa::path( |
130 | post, | 131 | post, |
131 | path = "/device", | 132 | path = "/device", |
132 | request_body = PostDevicePayload, | 133 | request_body = DevicePayload, |
133 | responses( | 134 | responses( |
134 | (status = 200, description = "List matching todos by query", body = [DeviceSchema]) | 135 | (status = 200, description = "update device in storage", body = [DeviceSchema]) |
135 | ), | 136 | ), |
136 | security(("api_key" = [])) | 137 | security((), ("api_key" = [])) |
137 | )] | 138 | )] |
138 | pub async fn post( | 139 | pub async fn post( |
139 | State(state): State<Arc<crate::AppState>>, | 140 | State(state): State<Arc<crate::AppState>>, |
140 | Json(payload): Json<PostDevicePayload>, | 141 | Json(payload): Json<DevicePayload>, |
141 | ) -> Result<Json<Value>, Error> { | 142 | ) -> Result<Json<Value>, Error> { |
142 | info!( | 143 | info!( |
143 | "edit device {} ({}, {}, {})", | 144 | "edit device {} ({}, {}, {})", |
@@ -162,11 +163,3 @@ pub async fn post( | |||
162 | 163 | ||
163 | Ok(Json(json!(device))) | 164 | Ok(Json(json!(device))) |
164 | } | 165 | } |
165 | |||
166 | #[derive(Deserialize, ToSchema)] | ||
167 | pub struct PostDevicePayload { | ||
168 | id: String, | ||
169 | mac: String, | ||
170 | broadcast_addr: String, | ||
171 | ip: String, | ||
172 | } | ||
diff --git a/src/routes/start.rs b/src/routes/start.rs index ef6e8f2..ff3d1be 100644 --- a/src/routes/start.rs +++ b/src/routes/start.rs | |||
@@ -2,27 +2,28 @@ use crate::db::Device; | |||
2 | use crate::error::Error; | 2 | use crate::error::Error; |
3 | use crate::services::ping::Value as PingValue; | 3 | use crate::services::ping::Value as PingValue; |
4 | use crate::wol::{create_buffer, send_packet}; | 4 | use crate::wol::{create_buffer, send_packet}; |
5 | use axum::extract::State; | 5 | use axum::extract::{Path, State}; |
6 | use axum::Json; | 6 | use axum::Json; |
7 | use serde::{Deserialize, Serialize}; | 7 | use serde::{Deserialize, Serialize}; |
8 | use serde_json::{json, Value}; | 8 | use serde_json::{json, Value}; |
9 | use utoipa::ToSchema; | ||
10 | use std::sync::Arc; | 9 | use std::sync::Arc; |
11 | use tracing::{debug, info}; | 10 | use tracing::{debug, info}; |
11 | use utoipa::ToSchema; | ||
12 | use uuid::Uuid; | 12 | use uuid::Uuid; |
13 | 13 | ||
14 | #[utoipa::path( | 14 | #[utoipa::path( |
15 | post, | 15 | post, |
16 | path = "/start", | 16 | path = "/start", |
17 | request_body = Payload, | 17 | request_body = PayloadOld, |
18 | responses( | 18 | responses( |
19 | (status = 200, description = "List matching todos by query", body = [Response]) | 19 | (status = 200, description = "DEP", body = [Response]) |
20 | ), | 20 | ), |
21 | security(("api_key" = [])) | 21 | security((), ("api_key" = [])) |
22 | )] | 22 | )] |
23 | pub async fn start( | 23 | #[deprecated] |
24 | pub async fn start_payload( | ||
24 | State(state): State<Arc<crate::AppState>>, | 25 | State(state): State<Arc<crate::AppState>>, |
25 | Json(payload): Json<Payload>, | 26 | Json(payload): Json<PayloadOld>, |
26 | ) -> Result<Json<Value>, Error> { | 27 | ) -> Result<Json<Value>, Error> { |
27 | info!("POST request"); | 28 | info!("POST request"); |
28 | let device = sqlx::query_as!( | 29 | let device = sqlx::query_as!( |
@@ -59,6 +60,89 @@ pub async fn start( | |||
59 | }))) | 60 | }))) |
60 | } | 61 | } |
61 | 62 | ||
63 | #[utoipa::path( | ||
64 | post, | ||
65 | path = "/start/{id}", | ||
66 | request_body = Option<Payload>, | ||
67 | responses( | ||
68 | (status = 200, description = "start the device with the given id", body = [Response]) | ||
69 | ), | ||
70 | params( | ||
71 | ("id" = String, Path, description = "device id") | ||
72 | ), | ||
73 | security((), ("api_key" = [])) | ||
74 | )] | ||
75 | pub async fn post( | ||
76 | State(state): State<Arc<crate::AppState>>, | ||
77 | Path(id): Path<String>, | ||
78 | payload: Option<Json<Payload>>, | ||
79 | ) -> Result<Json<Value>, Error> { | ||
80 | send_wol(state, &id, payload).await | ||
81 | } | ||
82 | |||
83 | #[utoipa::path( | ||
84 | get, | ||
85 | path = "/start/{id}", | ||
86 | responses( | ||
87 | (status = 200, description = "start the device with the given id", body = [Response]) | ||
88 | ), | ||
89 | params( | ||
90 | ("id" = String, Path, description = "device id") | ||
91 | ), | ||
92 | security((), ("api_key" = [])) | ||
93 | )] | ||
94 | pub async fn get( | ||
95 | State(state): State<Arc<crate::AppState>>, | ||
96 | Path(id): Path<String>, | ||
97 | ) -> Result<Json<Value>, Error> { | ||
98 | send_wol(state, &id, None).await | ||
99 | } | ||
100 | |||
101 | async fn send_wol( | ||
102 | state: Arc<crate::AppState>, | ||
103 | id: &str, | ||
104 | payload: Option<Json<Payload>>, | ||
105 | ) -> Result<Json<Value>, Error> { | ||
106 | info!("Start request for {id}"); | ||
107 | let device = sqlx::query_as!( | ||
108 | Device, | ||
109 | r#" | ||
110 | SELECT id, mac, broadcast_addr, ip, times | ||
111 | FROM devices | ||
112 | WHERE id = $1; | ||
113 | "#, | ||
114 | id | ||
115 | ) | ||
116 | .fetch_one(&state.db) | ||
117 | .await?; | ||
118 | |||
119 | info!("starting {}", device.id); | ||
120 | |||
121 | let bind_addr = "0.0.0.0:0"; | ||
122 | |||
123 | let _ = send_packet( | ||
124 | bind_addr, | ||
125 | &device.broadcast_addr, | ||
126 | &create_buffer(&device.mac.to_string())?, | ||
127 | )?; | ||
128 | let dev_id = device.id.clone(); | ||
129 | let uuid = if let Some(pl) = payload { | ||
130 | if pl.ping.is_some_and(|ping| ping) { | ||
131 | Some(setup_ping(state, device)) | ||
132 | } else { | ||
133 | None | ||
134 | } | ||
135 | } else { | ||
136 | None | ||
137 | }; | ||
138 | |||
139 | Ok(Json(json!(Response { | ||
140 | id: dev_id, | ||
141 | boot: true, | ||
142 | uuid | ||
143 | }))) | ||
144 | } | ||
145 | |||
62 | fn setup_ping(state: Arc<crate::AppState>, device: Device) -> String { | 146 | fn setup_ping(state: Arc<crate::AppState>, device: Device) -> String { |
63 | let mut uuid: Option<String> = None; | 147 | let mut uuid: Option<String> = None; |
64 | for (key, value) in state.ping_map.clone() { | 148 | for (key, value) in state.ping_map.clone() { |
@@ -99,11 +183,17 @@ fn setup_ping(state: Arc<crate::AppState>, device: Device) -> String { | |||
99 | } | 183 | } |
100 | 184 | ||
101 | #[derive(Deserialize, ToSchema)] | 185 | #[derive(Deserialize, ToSchema)] |
102 | pub struct Payload { | 186 | #[deprecated] |
187 | pub struct PayloadOld { | ||
103 | id: String, | 188 | id: String, |
104 | ping: Option<bool>, | 189 | ping: Option<bool>, |
105 | } | 190 | } |
106 | 191 | ||
192 | #[derive(Deserialize, ToSchema)] | ||
193 | pub struct Payload { | ||
194 | ping: Option<bool>, | ||
195 | } | ||
196 | |||
107 | #[derive(Serialize, ToSchema)] | 197 | #[derive(Serialize, ToSchema)] |
108 | pub struct Response { | 198 | pub struct Response { |
109 | id: String, | 199 | id: String, |