use axum::extract::ws::Message; use axum::extract::{State, WebSocketUpgrade}; use axum::response::Html; use axum::{ body::Bytes, extract::Path, http::{header, HeaderValue, Method}, response::IntoResponse, routing::get, Router, }; use message::{ClientRequest, HyperdeckMonitorState, ServerEvent}; use serde::{Deserialize, Serialize}; use std::{ collections::HashMap, net::{Ipv4Addr, SocketAddr}, sync::Arc, time::Duration, }; use tokio::sync::{Mutex, RwLock}; use tower::ServiceBuilder; use tower_http::timeout::TimeoutLayer; use tower_http::ServiceBuilderExt; use tower_http::{ cors::{Any, CorsLayer}, trace::{DefaultMakeSpan, DefaultOnResponse, TraceLayer}, LatencyUnit, }; use tracing::info; use uuid::Uuid; pub mod message; mod ws; #[derive(Debug, Clone)] pub struct Client { pub sender: Option>, } type Clients = Arc>>; pub async fn initialize_api( mut state_rx: tokio::sync::broadcast::Receiver, client_request_tx: tokio::sync::mpsc::UnboundedSender, ) { info!("Initializing API"); let clients: Clients = Default::default(); let state = Arc::new(RwLock::new(state_rx.recv().await.unwrap())); let state_clients = clients.clone(); let state_loop = state.clone(); tokio::spawn(async move { loop { if let Ok(hyperdeck_monitor_state) = state_rx.recv().await { let mut state = state_loop.write().await; *state = hyperdeck_monitor_state.clone(); let clients = state_clients.lock().await; let state_json = serde_json::to_string(&ServerEvent::HyperdeckMonitorState( hyperdeck_monitor_state.into(), )) .unwrap(); for (_, client) in clients.iter() { if let Some(sender) = &client.sender { let message: Message = Message::Text(state_json.clone()); let _ = sender.send(message); } } } } }); let app_state = AppState { state, client_request_tx, clients, port: 9681, }; let addr = SocketAddr::from((Ipv4Addr::UNSPECIFIED, app_state.port)); info!("Listening on {}", addr); // TODO: This could fail, need to figure out how to get a result from this let _ = axum::Server::bind(&addr) .serve(app(app_state).into_make_service()) .await; } #[derive(Clone)] struct AppState { state: Arc>, client_request_tx: tokio::sync::mpsc::UnboundedSender, clients: Clients, port: u16, } fn app(state: AppState) -> Router { let sensitive_headers: Arc<[_]> = vec![header::AUTHORIZATION, header::COOKIE].into(); let middleware = ServiceBuilder::new() // Mark the `Authorization` and `Cookie` headers as sensitive so it doesn't show in logs .sensitive_request_headers(sensitive_headers.clone()) // Add high level tracing/logging to all requests .layer( TraceLayer::new_for_http() .on_body_chunk(|chunk: &Bytes, latency: Duration, _: &tracing::Span| { tracing::trace!(size_bytes = chunk.len(), latency = ?latency, "sending body chunk") }) .make_span_with(DefaultMakeSpan::new().include_headers(true)) .on_response(DefaultOnResponse::new().include_headers(true).latency_unit(LatencyUnit::Micros)), ) .sensitive_response_headers(sensitive_headers) // Set a timeout .layer(TimeoutLayer::new(Duration::from_secs(10))) // Box the response body so it implements `Default` which is required by axum .map_response_body(axum::body::boxed) // Compress responses .compression() // Set a `Content-Type` if there isn't one already. .insert_response_header_if_not_present( header::CONTENT_TYPE, HeaderValue::from_static("application/octet-stream"), ); let cors = CorsLayer::new() .allow_methods(vec![ Method::GET, Method::POST, Method::PUT, Method::DELETE, Method::OPTIONS, ]) .allow_headers(Any) .allow_origin(Any) .allow_credentials(false); Router::new() .route("/", get(get_index)) .route("/ws", get(upgrade_ws)) .layer(middleware) .layer(cors) .with_state(state) } #[derive(Debug, Serialize, Deserialize)] pub struct WebSocketUpgradeRequest {} async fn get_index() -> Html { Html(format!("Hello!")) } #[axum::debug_handler] async fn upgrade_ws(state: State, ws: WebSocketUpgrade) -> impl IntoResponse { info!("New client websocket connection"); let client_id = uuid::Uuid::new_v4(); state .clients .lock() .await .insert(client_id.clone(), Client { sender: None }); let client = state.clients.lock().await.get(&client_id).cloned().unwrap(); ws.on_upgrade(move |socket| { ws::client_connection( state.client_request_tx.clone(), socket, client_id, state.state.clone(), state.clients.clone(), client, ) }) }