diff options
author | FxQnLr <[email protected]> | 2024-02-25 15:15:19 +0100 |
---|---|---|
committer | FxQnLr <[email protected]> | 2024-02-25 15:15:19 +0100 |
commit | 9058f191b69ecafc8fdeace227ac113412d03888 (patch) | |
tree | 88ae071fa31c9a5831722ec82878ccf8fd2b224a /src | |
parent | 2f9f18b80a9e2134f674f345e48a5f21de5efadd (diff) | |
download | webol-9058f191b69ecafc8fdeace227ac113412d03888.tar webol-9058f191b69ecafc8fdeace227ac113412d03888.tar.gz webol-9058f191b69ecafc8fdeace227ac113412d03888.zip |
Closes #16. Impl auth as extractor
Diffstat (limited to 'src')
-rw-r--r-- | src/auth.rs | 30 | ||||
-rw-r--r-- | src/error.rs | 6 | ||||
-rw-r--r-- | src/extractors.rs | 24 | ||||
-rw-r--r-- | src/main.rs | 20 | ||||
-rw-r--r-- | src/routes/device.rs | 113 | ||||
-rw-r--r-- | src/routes/start.rs | 65 |
6 files changed, 107 insertions, 151 deletions
diff --git a/src/auth.rs b/src/auth.rs deleted file mode 100644 index 22f87e7..0000000 --- a/src/auth.rs +++ /dev/null | |||
@@ -1,30 +0,0 @@ | |||
1 | use axum::http::HeaderValue; | ||
2 | use tracing::{debug, trace}; | ||
3 | use crate::config::Config; | ||
4 | use crate::error::Error; | ||
5 | |||
6 | pub fn auth(config: &Config, secret: Option<&HeaderValue>) -> Result<Response, Error> { | ||
7 | debug!("auth request with secret {:?}", secret); | ||
8 | let res = if let Some(value) = secret { | ||
9 | trace!("auth value exists"); | ||
10 | let key = &config.apikey; | ||
11 | if value.to_str()? == key.as_str() { | ||
12 | debug!("successful auth"); | ||
13 | Response::Success | ||
14 | } else { | ||
15 | debug!("unsuccessful auth (wrong secret)"); | ||
16 | Response::WrongSecret | ||
17 | } | ||
18 | } else { | ||
19 | debug!("unsuccessful auth (no secret)"); | ||
20 | Response::MissingSecret | ||
21 | }; | ||
22 | Ok(res) | ||
23 | } | ||
24 | |||
25 | #[derive(Debug)] | ||
26 | pub enum Response { | ||
27 | Success, | ||
28 | WrongSecret, | ||
29 | MissingSecret | ||
30 | } | ||
diff --git a/src/error.rs b/src/error.rs index 66a61f4..513b51b 100644 --- a/src/error.rs +++ b/src/error.rs | |||
@@ -1,8 +1,8 @@ | |||
1 | use ::ipnetwork::IpNetworkError; | ||
1 | use axum::http::header::ToStrError; | 2 | use axum::http::header::ToStrError; |
2 | use axum::http::StatusCode; | 3 | use axum::http::StatusCode; |
3 | use axum::response::{IntoResponse, Response}; | 4 | use axum::response::{IntoResponse, Response}; |
4 | use axum::Json; | 5 | use axum::Json; |
5 | use ::ipnetwork::IpNetworkError; | ||
6 | use mac_address::MacParseError; | 6 | use mac_address::MacParseError; |
7 | use serde_json::json; | 7 | use serde_json::json; |
8 | use std::io; | 8 | use std::io; |
@@ -10,9 +10,6 @@ use tracing::error; | |||
10 | 10 | ||
11 | #[derive(Debug, thiserror::Error)] | 11 | #[derive(Debug, thiserror::Error)] |
12 | pub enum Error { | 12 | pub enum Error { |
13 | #[error("generic error")] | ||
14 | Generic, | ||
15 | |||
16 | #[error("db: {source}")] | 13 | #[error("db: {source}")] |
17 | Db { | 14 | Db { |
18 | #[from] | 15 | #[from] |
@@ -54,7 +51,6 @@ impl IntoResponse for Error { | |||
54 | fn into_response(self) -> Response { | 51 | fn into_response(self) -> Response { |
55 | error!("{}", self.to_string()); | 52 | error!("{}", self.to_string()); |
56 | let (status, error_message) = match self { | 53 | let (status, error_message) = match self { |
57 | Self::Generic => (StatusCode::INTERNAL_SERVER_ERROR, ""), | ||
58 | Self::Db { source } => { | 54 | Self::Db { source } => { |
59 | error!("{source}"); | 55 | error!("{source}"); |
60 | (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") | 56 | (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") |
diff --git a/src/extractors.rs b/src/extractors.rs new file mode 100644 index 0000000..4d441e9 --- /dev/null +++ b/src/extractors.rs | |||
@@ -0,0 +1,24 @@ | |||
1 | use axum::{ | ||
2 | extract::{Request, State}, | ||
3 | http::{HeaderMap, StatusCode}, | ||
4 | middleware::Next, | ||
5 | response::Response, | ||
6 | }; | ||
7 | |||
8 | use crate::AppState; | ||
9 | |||
10 | pub async fn auth( | ||
11 | State(state): State<AppState>, | ||
12 | headers: HeaderMap, | ||
13 | request: Request, | ||
14 | next: Next, | ||
15 | ) -> Result<Response, StatusCode> { | ||
16 | let secret = headers.get("authorization"); | ||
17 | match secret { | ||
18 | Some(token) if token == state.config.apikey.as_str() => { | ||
19 | let response = next.run(request).await; | ||
20 | Ok(response) | ||
21 | } | ||
22 | _ => Err(StatusCode::UNAUTHORIZED), | ||
23 | } | ||
24 | } | ||
diff --git a/src/main.rs b/src/main.rs index 7d8c1da..eae89f6 100644 --- a/src/main.rs +++ b/src/main.rs | |||
@@ -4,26 +4,23 @@ use crate::routes::device; | |||
4 | use crate::routes::start::start; | 4 | use crate::routes::start::start; |
5 | use crate::routes::status::status; | 5 | use crate::routes::status::status; |
6 | use crate::services::ping::StatusMap; | 6 | use crate::services::ping::StatusMap; |
7 | use axum::middleware::from_fn_with_state; | ||
7 | use axum::routing::{get, put}; | 8 | use axum::routing::{get, put}; |
8 | use axum::{routing::post, Router}; | 9 | use axum::{routing::post, Router}; |
9 | use dashmap::DashMap; | 10 | use dashmap::DashMap; |
10 | use services::ping::BroadcastCommand; | 11 | use services::ping::BroadcastCommand; |
11 | use sqlx::PgPool; | 12 | use sqlx::PgPool; |
12 | use tracing_subscriber::fmt::time::UtcTime; | ||
13 | use std::env; | 13 | use std::env; |
14 | use std::sync::Arc; | 14 | use std::sync::Arc; |
15 | use tokio::sync::broadcast::{channel, Sender}; | 15 | use tokio::sync::broadcast::{channel, Sender}; |
16 | use tracing::{info, level_filters::LevelFilter}; | 16 | use tracing::{info, level_filters::LevelFilter}; |
17 | use tracing_subscriber::{ | 17 | use tracing_subscriber::fmt::time::UtcTime; |
18 | fmt, | 18 | use tracing_subscriber::{fmt, prelude::*, EnvFilter}; |
19 | prelude::*, | ||
20 | EnvFilter, | ||
21 | }; | ||
22 | 19 | ||
23 | mod auth; | ||
24 | mod config; | 20 | mod config; |
25 | mod db; | 21 | mod db; |
26 | mod error; | 22 | mod error; |
23 | mod extractors; | ||
27 | mod routes; | 24 | mod routes; |
28 | mod services; | 25 | mod services; |
29 | mod wol; | 26 | mod wol; |
@@ -31,7 +28,6 @@ mod wol; | |||
31 | #[tokio::main] | 28 | #[tokio::main] |
32 | async fn main() -> color_eyre::eyre::Result<()> { | 29 | async fn main() -> color_eyre::eyre::Result<()> { |
33 | color_eyre::install()?; | 30 | color_eyre::install()?; |
34 | |||
35 | 31 | ||
36 | let time_format = | 32 | let time_format = |
37 | time::macros::format_description!("[year]-[month]-[day] [hour]:[minute]:[second]"); | 33 | time::macros::format_description!("[year]-[month]-[day] [hour]:[minute]:[second]"); |
@@ -59,12 +55,12 @@ async fn main() -> color_eyre::eyre::Result<()> { | |||
59 | 55 | ||
60 | let ping_map: StatusMap = DashMap::new(); | 56 | let ping_map: StatusMap = DashMap::new(); |
61 | 57 | ||
62 | let shared_state = Arc::new(AppState { | 58 | let shared_state = AppState { |
63 | db, | 59 | db, |
64 | config: config.clone(), | 60 | config: config.clone(), |
65 | ping_send: tx, | 61 | ping_send: tx, |
66 | ping_map, | 62 | ping_map, |
67 | }); | 63 | }; |
68 | 64 | ||
69 | let app = Router::new() | 65 | let app = Router::new() |
70 | .route("/start", post(start)) | 66 | .route("/start", post(start)) |
@@ -72,7 +68,8 @@ async fn main() -> color_eyre::eyre::Result<()> { | |||
72 | .route("/device", put(device::put)) | 68 | .route("/device", put(device::put)) |
73 | .route("/device", post(device::post)) | 69 | .route("/device", post(device::post)) |
74 | .route("/status", get(status)) | 70 | .route("/status", get(status)) |
75 | .with_state(shared_state); | 71 | .route_layer(from_fn_with_state(shared_state.clone(), extractors::auth)) |
72 | .with_state(Arc::new(shared_state)); | ||
76 | 73 | ||
77 | let addr = config.serveraddr; | 74 | let addr = config.serveraddr; |
78 | info!("start server on {}", addr); | 75 | info!("start server on {}", addr); |
@@ -82,6 +79,7 @@ async fn main() -> color_eyre::eyre::Result<()> { | |||
82 | Ok(()) | 79 | Ok(()) |
83 | } | 80 | } |
84 | 81 | ||
82 | #[derive(Clone)] | ||
85 | pub struct AppState { | 83 | pub struct AppState { |
86 | db: PgPool, | 84 | db: PgPool, |
87 | config: Config, | 85 | config: Config, |
diff --git a/src/routes/device.rs b/src/routes/device.rs index 2f0093d..d39d98e 100644 --- a/src/routes/device.rs +++ b/src/routes/device.rs | |||
@@ -1,8 +1,6 @@ | |||
1 | use crate::auth::auth; | ||
2 | use crate::db::Device; | 1 | use crate::db::Device; |
3 | use crate::error::Error; | 2 | use crate::error::Error; |
4 | use axum::extract::State; | 3 | use axum::extract::State; |
5 | use axum::http::HeaderMap; | ||
6 | use axum::Json; | 4 | use axum::Json; |
7 | use mac_address::MacAddress; | 5 | use mac_address::MacAddress; |
8 | use serde::{Deserialize, Serialize}; | 6 | use serde::{Deserialize, Serialize}; |
@@ -13,31 +11,24 @@ use tracing::{debug, info}; | |||
13 | 11 | ||
14 | pub async fn get( | 12 | pub async fn get( |
15 | State(state): State<Arc<crate::AppState>>, | 13 | State(state): State<Arc<crate::AppState>>, |
16 | headers: HeaderMap, | ||
17 | Json(payload): Json<GetDevicePayload>, | 14 | Json(payload): Json<GetDevicePayload>, |
18 | ) -> Result<Json<Value>, Error> { | 15 | ) -> Result<Json<Value>, Error> { |
19 | info!("get device {}", payload.id); | 16 | info!("get device {}", payload.id); |
20 | let secret = headers.get("authorization"); | 17 | let device = sqlx::query_as!( |
21 | let authorized = matches!(auth(&state.config, secret)?, crate::auth::Response::Success); | 18 | Device, |
22 | if authorized { | 19 | r#" |
23 | let device = sqlx::query_as!( | 20 | SELECT id, mac, broadcast_addr, ip, times |
24 | Device, | 21 | FROM devices |
25 | r#" | 22 | WHERE id = $1; |
26 | SELECT id, mac, broadcast_addr, ip, times | 23 | "#, |
27 | FROM devices | 24 | payload.id |
28 | WHERE id = $1; | 25 | ) |
29 | "#, | 26 | .fetch_one(&state.db) |
30 | payload.id | 27 | .await?; |
31 | ) | ||
32 | .fetch_one(&state.db) | ||
33 | .await?; | ||
34 | 28 | ||
35 | debug!("got device {:?}", device); | 29 | debug!("got device {:?}", device); |
36 | 30 | ||
37 | Ok(Json(json!(device))) | 31 | Ok(Json(json!(device))) |
38 | } else { | ||
39 | Err(Error::Generic) | ||
40 | } | ||
41 | } | 32 | } |
42 | 33 | ||
43 | #[derive(Deserialize)] | 34 | #[derive(Deserialize)] |
@@ -47,7 +38,6 @@ pub struct GetDevicePayload { | |||
47 | 38 | ||
48 | pub async fn put( | 39 | pub async fn put( |
49 | State(state): State<Arc<crate::AppState>>, | 40 | State(state): State<Arc<crate::AppState>>, |
50 | headers: HeaderMap, | ||
51 | Json(payload): Json<PutDevicePayload>, | 41 | Json(payload): Json<PutDevicePayload>, |
52 | ) -> Result<Json<Value>, Error> { | 42 | ) -> Result<Json<Value>, Error> { |
53 | info!( | 43 | info!( |
@@ -55,28 +45,22 @@ pub async fn put( | |||
55 | payload.id, payload.mac, payload.broadcast_addr, payload.ip | 45 | payload.id, payload.mac, payload.broadcast_addr, payload.ip |
56 | ); | 46 | ); |
57 | 47 | ||
58 | let secret = headers.get("authorization"); | 48 | let ip = IpNetwork::from_str(&payload.ip)?; |
59 | let authorized = matches!(auth(&state.config, secret)?, crate::auth::Response::Success); | 49 | let mac = MacAddress::from_str(&payload.mac)?; |
60 | if authorized { | 50 | sqlx::query!( |
61 | let ip = IpNetwork::from_str(&payload.ip)?; | 51 | r#" |
62 | let mac = MacAddress::from_str(&payload.mac)?; | 52 | INSERT INTO devices (id, mac, broadcast_addr, ip) |
63 | sqlx::query!( | 53 | VALUES ($1, $2, $3, $4); |
64 | r#" | 54 | "#, |
65 | INSERT INTO devices (id, mac, broadcast_addr, ip) | 55 | payload.id, |
66 | VALUES ($1, $2, $3, $4); | 56 | mac, |
67 | "#, | 57 | payload.broadcast_addr, |
68 | payload.id, | 58 | ip |
69 | mac, | 59 | ) |
70 | payload.broadcast_addr, | 60 | .execute(&state.db) |
71 | ip | 61 | .await?; |
72 | ) | ||
73 | .execute(&state.db) | ||
74 | .await?; | ||
75 | 62 | ||
76 | Ok(Json(json!(PutDeviceResponse { success: true }))) | 63 | Ok(Json(json!(PutDeviceResponse { success: true }))) |
77 | } else { | ||
78 | Err(Error::Generic) | ||
79 | } | ||
80 | } | 64 | } |
81 | 65 | ||
82 | #[derive(Deserialize)] | 66 | #[derive(Deserialize)] |
@@ -94,37 +78,30 @@ pub struct PutDeviceResponse { | |||
94 | 78 | ||
95 | pub async fn post( | 79 | pub async fn post( |
96 | State(state): State<Arc<crate::AppState>>, | 80 | State(state): State<Arc<crate::AppState>>, |
97 | headers: HeaderMap, | ||
98 | Json(payload): Json<PostDevicePayload>, | 81 | Json(payload): Json<PostDevicePayload>, |
99 | ) -> Result<Json<Value>, Error> { | 82 | ) -> Result<Json<Value>, Error> { |
100 | info!( | 83 | info!( |
101 | "edit device {} ({}, {}, {})", | 84 | "edit device {} ({}, {}, {})", |
102 | payload.id, payload.mac, payload.broadcast_addr, payload.ip | 85 | payload.id, payload.mac, payload.broadcast_addr, payload.ip |
103 | ); | 86 | ); |
104 | let secret = headers.get("authorization"); | 87 | let ip = IpNetwork::from_str(&payload.ip)?; |
105 | let authorized = matches!(auth(&state.config, secret)?, crate::auth::Response::Success); | 88 | let mac = MacAddress::from_str(&payload.mac)?; |
106 | if authorized { | 89 | let device = sqlx::query_as!( |
107 | let ip = IpNetwork::from_str(&payload.ip)?; | 90 | Device, |
108 | let mac = MacAddress::from_str(&payload.mac)?; | 91 | r#" |
109 | let device = sqlx::query_as!( | 92 | UPDATE devices |
110 | Device, | 93 | SET mac = $1, broadcast_addr = $2, ip = $3 WHERE id = $4 |
111 | r#" | 94 | RETURNING id, mac, broadcast_addr, ip, times; |
112 | UPDATE devices | 95 | "#, |
113 | SET mac = $1, broadcast_addr = $2, ip = $3 WHERE id = $4 | 96 | mac, |
114 | RETURNING id, mac, broadcast_addr, ip, times; | 97 | payload.broadcast_addr, |
115 | "#, | 98 | ip, |
116 | mac, | 99 | payload.id |
117 | payload.broadcast_addr, | 100 | ) |
118 | ip, | 101 | .fetch_one(&state.db) |
119 | payload.id | 102 | .await?; |
120 | ) | ||
121 | .fetch_one(&state.db) | ||
122 | .await?; | ||
123 | 103 | ||
124 | Ok(Json(json!(device))) | 104 | Ok(Json(json!(device))) |
125 | } else { | ||
126 | Err(Error::Generic) | ||
127 | } | ||
128 | } | 105 | } |
129 | 106 | ||
130 | #[derive(Deserialize)] | 107 | #[derive(Deserialize)] |
diff --git a/src/routes/start.rs b/src/routes/start.rs index 4888325..d4c0802 100644 --- a/src/routes/start.rs +++ b/src/routes/start.rs | |||
@@ -1,10 +1,8 @@ | |||
1 | use crate::auth::auth; | ||
2 | use crate::db::Device; | 1 | use crate::db::Device; |
3 | use crate::error::Error; | 2 | use crate::error::Error; |
4 | use crate::services::ping::Value as PingValue; | 3 | use crate::services::ping::Value as PingValue; |
5 | use crate::wol::{create_buffer, send_packet}; | 4 | use crate::wol::{create_buffer, send_packet}; |
6 | use axum::extract::State; | 5 | use axum::extract::State; |
7 | use axum::http::HeaderMap; | ||
8 | use axum::Json; | 6 | use axum::Json; |
9 | use serde::{Deserialize, Serialize}; | 7 | use serde::{Deserialize, Serialize}; |
10 | use serde_json::{json, Value}; | 8 | use serde_json::{json, Value}; |
@@ -14,48 +12,41 @@ use uuid::Uuid; | |||
14 | 12 | ||
15 | pub async fn start( | 13 | pub async fn start( |
16 | State(state): State<Arc<crate::AppState>>, | 14 | State(state): State<Arc<crate::AppState>>, |
17 | headers: HeaderMap, | ||
18 | Json(payload): Json<Payload>, | 15 | Json(payload): Json<Payload>, |
19 | ) -> Result<Json<Value>, Error> { | 16 | ) -> Result<Json<Value>, Error> { |
20 | info!("POST request"); | 17 | info!("POST request"); |
21 | let secret = headers.get("authorization"); | 18 | let device = sqlx::query_as!( |
22 | let authorized = matches!(auth(&state.config, secret)?, crate::auth::Response::Success); | 19 | Device, |
23 | if authorized { | 20 | r#" |
24 | let device = sqlx::query_as!( | 21 | SELECT id, mac, broadcast_addr, ip, times |
25 | Device, | 22 | FROM devices |
26 | r#" | 23 | WHERE id = $1; |
27 | SELECT id, mac, broadcast_addr, ip, times | 24 | "#, |
28 | FROM devices | 25 | payload.id |
29 | WHERE id = $1; | 26 | ) |
30 | "#, | 27 | .fetch_one(&state.db) |
31 | payload.id | 28 | .await?; |
32 | ) | ||
33 | .fetch_one(&state.db) | ||
34 | .await?; | ||
35 | 29 | ||
36 | info!("starting {}", device.id); | 30 | info!("starting {}", device.id); |
37 | 31 | ||
38 | let bind_addr = "0.0.0.0:0"; | 32 | let bind_addr = "0.0.0.0:0"; |
39 | 33 | ||
40 | let _ = send_packet( | 34 | let _ = send_packet( |
41 | bind_addr, | 35 | bind_addr, |
42 | &device.broadcast_addr, | 36 | &device.broadcast_addr, |
43 | &create_buffer(&device.mac.to_string())?, | 37 | &create_buffer(&device.mac.to_string())?, |
44 | )?; | 38 | )?; |
45 | let dev_id = device.id.clone(); | 39 | let dev_id = device.id.clone(); |
46 | let uuid = if payload.ping.is_some_and(|ping| ping) { | 40 | let uuid = if payload.ping.is_some_and(|ping| ping) { |
47 | Some(setup_ping(state, device)) | 41 | Some(setup_ping(state, device)) |
48 | } else { | ||
49 | None | ||
50 | }; | ||
51 | Ok(Json(json!(Response { | ||
52 | id: dev_id, | ||
53 | boot: true, | ||
54 | uuid | ||
55 | }))) | ||
56 | } else { | 42 | } else { |
57 | Err(Error::Generic) | 43 | None |
58 | } | 44 | }; |
45 | Ok(Json(json!(Response { | ||
46 | id: dev_id, | ||
47 | boot: true, | ||
48 | uuid | ||
49 | }))) | ||
59 | } | 50 | } |
60 | 51 | ||
61 | fn setup_ping(state: Arc<crate::AppState>, device: Device) -> String { | 52 | fn setup_ping(state: Arc<crate::AppState>, device: Device) -> String { |