diff --git a/atem-connection-rs/src/atem.rs b/atem-connection-rs/src/atem.rs index a3bf592..bbfd9a0 100644 --- a/atem-connection-rs/src/atem.rs +++ b/atem-connection-rs/src/atem.rs @@ -1,20 +1,26 @@ use std::{ collections::{HashMap, VecDeque}, net::SocketAddr, + ops::DerefMut, sync::Arc, + time::Duration, }; use tokio::{select, sync::Semaphore}; use tokio_util::sync::CancellationToken; use crate::{ - atem_lib::atem_socket::{AtemEvent, AtemSocketCommand, AtemSocketMessage, TrackingId}, + atem_lib::atem_socket::{ + AtemSocket, AtemSocketCommand, AtemSocketEvent, AtemSocketMessage, TrackingId, + }, commands::{ command_base::{BasicWritableCommand, DeserializedCommand}, device_profile::version::DESERIALIZE_VERSION_RAW_NAME, init_complete::DESERIALIZE_INIT_COMPLETE_RAW_NAME, + parse_commands::deserialize_commands, time::DESERIALIZE_TIME_RAW_NAME, }, + enums::ProtocolVersion, state::AtemState, }; @@ -27,13 +33,24 @@ pub enum AtemConnectionStatus { } pub struct Atem { + protocol_version: tokio::sync::RwLock, + + socket: tokio::sync::RwLock, + waiting_semaphores: tokio::sync::RwLock>>, socket_message_tx: tokio::sync::mpsc::Sender, } impl Atem { - pub fn new(socket_message_tx: tokio::sync::mpsc::Sender) -> Self { + pub fn new( + socket: AtemSocket, + socket_message_tx: tokio::sync::mpsc::Sender, + ) -> Self { Self { + protocol_version: tokio::sync::RwLock::new(ProtocolVersion::V7_2), + + socket: tokio::sync::RwLock::new(socket), + waiting_semaphores: tokio::sync::RwLock::new(HashMap::new()), socket_message_tx, } @@ -54,25 +71,30 @@ impl Atem { pub async fn run( &self, - mut atem_event_rx: tokio::sync::mpsc::UnboundedReceiver, + mut atem_event_rx: tokio::sync::mpsc::UnboundedReceiver, cancel: CancellationToken, ) { let mut status = AtemConnectionStatus::default(); let mut state = AtemState::default(); + let mut poll_interval = tokio::time::interval(Duration::from_millis(5)); + while !cancel.is_cancelled() { + let tick = poll_interval.tick(); select! { _ = cancel.cancelled() => {}, + _ = tick => {}, message = atem_event_rx.recv() => match message { Some(event) => match event { - AtemEvent::Connected => { + AtemSocketEvent::Connected => { log::info!("Atem connected"); } - AtemEvent::Disconnected => todo!("Disconnected"), - AtemEvent::ReceivedCommands(commands) => { + AtemSocketEvent::Disconnected => todo!("Disconnected"), + AtemSocketEvent::ReceivedCommands(payload) => { + let commands = deserialize_commands(&payload, self.protocol_version.write().await.deref_mut()); self.mutate_state(&mut state, &mut status, commands).await } - AtemEvent::AckedCommand(tracking_id) => { + AtemSocketEvent::AckedCommand(tracking_id) => { log::debug!("Received tracking Id {tracking_id}"); if let Some(semaphore) = self.waiting_semaphores.read().await.get(&tracking_id) @@ -89,18 +111,19 @@ impl Atem { } } } + + self.socket.write().await.poll().await; } } pub async fn send_commands(&self, commands: Vec>) { + let protocol_version = { self.protocol_version.read().await.clone() }; let (callback_tx, callback_rx) = tokio::sync::oneshot::channel(); self.socket_message_tx .send(AtemSocketMessage::SendCommands { commands: commands .iter() - .map(|command| { - AtemSocketCommand::new(command, &crate::enums::ProtocolVersion::Unknown) - }) + .map(|command| AtemSocketCommand::new(command, &protocol_version)) .collect(), tracking_ids_callback: callback_tx, }) diff --git a/atem-connection-rs/src/atem_lib/atem_socket.rs b/atem-connection-rs/src/atem_lib/atem_socket.rs index 4873c58..4945b54 100644 --- a/atem-connection-rs/src/atem_lib/atem_socket.rs +++ b/atem-connection-rs/src/atem_lib/atem_socket.rs @@ -54,10 +54,10 @@ pub struct TrackingIdsCallback { } #[derive(Clone)] -pub enum AtemEvent { +pub enum AtemSocketEvent { Connected, Disconnected, - ReceivedCommands(VecDeque>), + ReceivedCommands(Vec), AckedCommand(TrackingId), } @@ -111,8 +111,11 @@ pub struct AtemSocket { ack_timer: Option, received_without_ack: u16, - atem_event_tx: tokio::sync::mpsc::UnboundedSender, + atem_message_rx: tokio::sync::mpsc::Receiver, + atem_event_tx: tokio::sync::mpsc::UnboundedSender, connected_callbacks: Mutex>>, + + tick_interval: tokio::time::Interval, } #[derive(PartialEq, Clone)] @@ -147,7 +150,11 @@ enum AtemSocketReceiveError { } impl AtemSocket { - pub fn new(atem_event_tx: tokio::sync::mpsc::UnboundedSender) -> Self { + pub fn new( + atem_message_rx: tokio::sync::mpsc::Receiver, + atem_event_tx: tokio::sync::mpsc::UnboundedSender, + ) -> Self { + let tick_interval = tokio::time::interval(Duration::from_millis(5)); Self { connection_state: ConnectionState::Closed, reconnect_timer: None, @@ -169,80 +176,74 @@ impl AtemSocket { ack_timer: None, received_without_ack: 0, + atem_message_rx, atem_event_tx, connected_callbacks: Mutex::default(), + + tick_interval, } } - pub async fn run( - &mut self, - mut atem_message_rx: tokio::sync::mpsc::Receiver, - cancel: tokio_util::sync::CancellationToken, - ) { - let mut interval = tokio::time::interval(Duration::from_millis(5)); - while !cancel.is_cancelled() { - let tick = interval.tick(); - select! { - _ = cancel.cancelled() => {}, - _ = tick => {}, - message = atem_message_rx.recv() => { - match message { - Some(AtemSocketMessage::Connect { - address, - result_callback, - }) => { - { - let mut connected_callbacks = self.connected_callbacks.lock().await; - connected_callbacks.push(result_callback); - } - if self.connect(address).await.is_err() { - log::debug!("Connect failed"); - let mut connected_callbacks = self.connected_callbacks.lock().await; - for callback in connected_callbacks.drain(0..) { - let _ = callback.send(false); - } - } + pub async fn poll(&mut self) { + let tick = self.tick_interval.tick(); + select! { + _ = tick => {}, + message = self.atem_message_rx.recv() => { + match message { + Some(AtemSocketMessage::Connect { + address, + result_callback, + }) => { + { + let mut connected_callbacks = self.connected_callbacks.lock().await; + connected_callbacks.push(result_callback); } - Some(AtemSocketMessage::Disconnect) => self.disconnect(), - Some(AtemSocketMessage::SendCommands { - commands, - tracking_ids_callback, - }) => { - let barrier = Arc::new(Barrier::new(2)); - tracking_ids_callback - .send(TrackingIdsCallback { - tracking_ids: self.send_commands(commands).await, - barrier: barrier.clone(), - }) - .ok(); - - // Let's play the game "Synchronisation Shenanigans"! - // So, we are sending tracking Ids to the sender of this message, the sender will then wait - // for each of these tracking Ids to be ACK'd by the ATEM. However, the sender will need to - // do ✨ some form of shenanigans ✨ in order to be ready to receive tracking Ids. So we send - // them a barrier as part of the callback so that they can tell us that they are ready for - // us to continue with ATEM communication, at which point we may immediately inform them of a - // received tracking Id matching one included in this callback. - // - // Now, if we were being 🚩 Real Proper Software Developers 🚩 we'd probably expect the receiver - // of the callback to do clever things so that if a tracking Id is received immediately, they - // then wait for something that wants that tracking Id on their side, rather than blocking this - // task so that the caller can do ✨ shenanigans ✨. However, that sounds far too clever and too - // much like 🚩 Real Actual Work 🚩 so instead we've chosen to do this and hope that whichever - // actor we're waiting on doesn't take _too_ long to do ✨ shenanigans ✨ before signalling that - // they are ready. If they do, I suggest finding whoever wrote that code and bonking them 🔨. - barrier.wait().await; - }, - None => { - log::info!("ATEM message channel has closed, exiting event loop."); - cancel.cancel(); + if self.connect(address).await.is_err() { + log::debug!("Connect failed"); + let mut connected_callbacks = self.connected_callbacks.lock().await; + for callback in connected_callbacks.drain(0..) { + let _ = callback.send(false); + } } } - } - }; + Some(AtemSocketMessage::Disconnect) => self.disconnect(), + Some(AtemSocketMessage::SendCommands { + commands, + tracking_ids_callback, + }) => { + let barrier = Arc::new(Barrier::new(2)); + tracking_ids_callback + .send(TrackingIdsCallback { + tracking_ids: self.send_commands(commands).await, + barrier: barrier.clone(), + }) + .ok(); - self.tick().await; - } + // Let's play the game "Synchronisation Shenanigans"! + // So, we are sending tracking Ids to the sender of this message, the sender will then wait + // for each of these tracking Ids to be ACK'd by the ATEM. However, the sender will need to + // do ✨ some form of shenanigans ✨ in order to be ready to receive tracking Ids. So we send + // them a barrier as part of the callback so that they can tell us that they are ready for + // us to continue with ATEM communication, at which point we may immediately inform them of a + // received tracking Id matching one included in this callback. + // + // Now, if we were being 🚩 Real Proper Software Developers 🚩 we'd probably expect the receiver + // of the callback to do clever things so that if a tracking Id is received immediately, they + // then wait for something that wants that tracking Id on their side, rather than blocking this + // task so that the caller can do ✨ shenanigans ✨. However, that sounds far too clever and too + // much like 🚩 Real Actual Work 🚩 so instead we've chosen to do this and hope that whichever + // actor we're waiting on doesn't take _too_ long to do ✨ shenanigans ✨ before signalling that + // they are ready. If they do, I suggest finding whoever wrote that code and bonking them 🔨. + barrier.wait().await; + }, + None => { + log::info!("ATEM message channel has closed."); + } + } + } + }; + + self.tick().await; } pub async fn connect(&mut self, address: SocketAddr) -> Result<(), io::Error> { @@ -536,23 +537,21 @@ impl AtemSocket { } fn on_commands_received(&mut self, payload: &[u8]) { - let commands = deserialize_commands(payload, &self.protocol_version); - let _ = self .atem_event_tx - .send(AtemEvent::ReceivedCommands(commands)); + .send(AtemSocketEvent::ReceivedCommands(payload.to_vec())); } fn on_command_acknowledged(&mut self, packets: Vec) { for ack in packets { let _ = self .atem_event_tx - .send(AtemEvent::AckedCommand(TrackingId(ack.tracking_id))); + .send(AtemSocketEvent::AckedCommand(TrackingId(ack.tracking_id))); } } async fn on_connect(&mut self) { - let _ = self.atem_event_tx.send(AtemEvent::Connected); + let _ = self.atem_event_tx.send(AtemSocketEvent::Connected); let mut connected_callbacks = self.connected_callbacks.lock().await; for callback in connected_callbacks.drain(0..) { let _ = callback.send(false); @@ -560,7 +559,7 @@ impl AtemSocket { } fn on_disconnect(&mut self) { - let _ = self.atem_event_tx.send(AtemEvent::Disconnected); + let _ = self.atem_event_tx.send(AtemSocketEvent::Disconnected); } fn start_timers(&mut self) { diff --git a/atem-connection-rs/src/commands/command_base.rs b/atem-connection-rs/src/commands/command_base.rs index 921d7f9..b708af6 100644 --- a/atem-connection-rs/src/commands/command_base.rs +++ b/atem-connection-rs/src/commands/command_base.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, fmt::Debug, sync::Arc}; +use std::{collections::HashMap, fmt::Debug, process::Command, sync::Arc}; use crate::{enums::ProtocolVersion, state::AtemState}; diff --git a/atem-connection-rs/src/commands/device_profile/product_identifier.rs b/atem-connection-rs/src/commands/device_profile/product_identifier.rs index 8a59c89..6d79764 100644 --- a/atem-connection-rs/src/commands/device_profile/product_identifier.rs +++ b/atem-connection-rs/src/commands/device_profile/product_identifier.rs @@ -8,12 +8,12 @@ use crate::{ pub const DESERIALIZE_PRODUCT_IDENTIFIER_RAW_NAME: &str = "_pin"; #[derive(Debug)] -pub struct ProductIdentifierCommand { +pub struct ProductIdentifier { pub product_identifier: String, pub model: Model, } -impl DeserializedCommand for ProductIdentifierCommand { +impl DeserializedCommand for ProductIdentifier { fn raw_name(&self) -> &'static str { DESERIALIZE_PRODUCT_IDENTIFIER_RAW_NAME } @@ -40,14 +40,14 @@ impl DeserializedCommand for ProductIdentifierCommand { } #[derive(Default)] -pub struct ProductIdentifierCommandDeserializer {} +pub struct ProductIdentifierDeserializer {} -impl CommandDeserializer for ProductIdentifierCommandDeserializer { +impl CommandDeserializer for ProductIdentifierDeserializer { fn deserialize( &self, buffer: &[u8], version: &ProtocolVersion, - ) -> std::sync::Arc { + ) -> Arc { let null_byte_index = buffer .iter() .position(|byte| *byte == b'\0') @@ -57,7 +57,7 @@ impl CommandDeserializer for ProductIdentifierCommandDeserializer { .expect("Malformed string"); let model = buffer[40]; - Arc::new(ProductIdentifierCommand { + Arc::new(ProductIdentifier { product_identifier: product_identifier .to_str() .expect("Invalid rust string") diff --git a/atem-connection-rs/src/commands/device_profile/topology.rs b/atem-connection-rs/src/commands/device_profile/topology.rs index d3b7fbe..68fa16e 100644 --- a/atem-connection-rs/src/commands/device_profile/topology.rs +++ b/atem-connection-rs/src/commands/device_profile/topology.rs @@ -8,7 +8,7 @@ use crate::{ pub const DESERIALIZE_TOPOLOGY_RAW_NAME: &str = "_top"; #[derive(Debug)] -pub struct TopologyCommand { +pub struct Topology { mix_effects: u8, sources: u8, auxilliaries: u8, @@ -27,7 +27,7 @@ pub struct TopologyCommand { only_configurable_outputs: bool, } -impl DeserializedCommand for TopologyCommand { +impl DeserializedCommand for Topology { fn raw_name(&self) -> &'static str { todo!() } @@ -44,7 +44,7 @@ impl CommandDeserializer for TopologyCommandDeserializer { &self, buffer: &[u8], version: &ProtocolVersion, - ) -> std::sync::Arc { + ) -> Arc { let v230offset = if *version > ProtocolVersion::V8_0_1 { 1 } else { @@ -69,7 +69,7 @@ impl CommandDeserializer for TopologyCommandDeserializer { false }; - Arc::new(TopologyCommand { + Arc::new(Topology { mix_effects: buffer[0], sources: buffer[1], downstream_keyers: buffer[2], diff --git a/atem-connection-rs/src/commands/device_profile/version.rs b/atem-connection-rs/src/commands/device_profile/version.rs index 8154b9d..ee2338e 100644 --- a/atem-connection-rs/src/commands/device_profile/version.rs +++ b/atem-connection-rs/src/commands/device_profile/version.rs @@ -1,39 +1,25 @@ -use std::sync::Arc; - -use crate::{ - commands::command_base::{CommandDeserializer, DeserializedCommand}, - enums::ProtocolVersion, -}; +use crate::{commands::command_base::DeserializedCommand, enums::ProtocolVersion}; pub const DESERIALIZE_VERSION_RAW_NAME: &str = "_ver"; #[derive(Debug)] -pub struct VersionCommand { +pub struct Version { pub version: ProtocolVersion, } -impl DeserializedCommand for VersionCommand { +impl DeserializedCommand for Version { fn raw_name(&self) -> &'static str { DESERIALIZE_VERSION_RAW_NAME } fn apply_to_state(&self, state: &mut crate::state::AtemState) { - state.info.api_version = self.version; + state.info.api_version = self.version.clone(); } } -#[derive(Default)] -pub struct VersionCommandDeserializer {} +pub fn deserialize_version(buffer: &[u8]) -> Version { + let version = u32::from_be_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]); + let version: ProtocolVersion = version.try_into().expect("Invalid protocol version"); -impl CommandDeserializer for VersionCommandDeserializer { - fn deserialize( - &self, - buffer: &[u8], - version: &ProtocolVersion, - ) -> std::sync::Arc { - let version = u32::from_be_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]); - let version: ProtocolVersion = version.try_into().expect("Invalid protocol version"); - - Arc::new(VersionCommand { version }) - } + Version { version } } diff --git a/atem-connection-rs/src/commands/init_complete.rs b/atem-connection-rs/src/commands/init_complete.rs index bef42ee..e23fc7e 100644 --- a/atem-connection-rs/src/commands/init_complete.rs +++ b/atem-connection-rs/src/commands/init_complete.rs @@ -25,7 +25,7 @@ impl CommandDeserializer for InitCompleteDeserializer { &self, _buffer: &[u8], version: &ProtocolVersion, - ) -> std::sync::Arc { + ) -> Arc { Arc::new(InitComplete {}) } } diff --git a/atem-connection-rs/src/commands/parse_commands.rs b/atem-connection-rs/src/commands/parse_commands.rs index 0cf7ee8..cb7d854 100644 --- a/atem-connection-rs/src/commands/parse_commands.rs +++ b/atem-connection-rs/src/commands/parse_commands.rs @@ -1,14 +1,14 @@ use std::{collections::VecDeque, sync::Arc}; -use crate::enums::ProtocolVersion; +use crate::{ + commands::device_profile::version::{deserialize_version, DESERIALIZE_VERSION_RAW_NAME}, + enums::ProtocolVersion, +}; use super::{ command_base::{CommandDeserializer, DeserializedCommand}, - device_profile::{ - product_identifier::{ - ProductIdentifierCommandDeserializer, DESERIALIZE_PRODUCT_IDENTIFIER_RAW_NAME, - }, - version::{VersionCommandDeserializer, DESERIALIZE_VERSION_RAW_NAME}, + device_profile::product_identifier::{ + ProductIdentifierDeserializer, DESERIALIZE_PRODUCT_IDENTIFIER_RAW_NAME, }, init_complete::{InitCompleteDeserializer, DESERIALIZE_INIT_COMPLETE_RAW_NAME}, mix_effects::program_input::{ProgramInputDeserializer, DESERIALIZE_PROGRAM_INPUT_RAW_NAME}, @@ -18,9 +18,9 @@ use super::{ pub fn deserialize_commands( payload: &[u8], - version: &ProtocolVersion, + version: &mut ProtocolVersion, ) -> VecDeque> { - let mut parsed_commands = VecDeque::new(); + let mut parsed_commands: VecDeque> = VecDeque::new(); let mut head = 0; while payload.len() > head + 8 { @@ -35,15 +35,21 @@ pub fn deserialize_commands( log::debug!("Received command {} with length {}", name, length); - if let Some(deserializer) = command_deserializer_from_string(name.as_str()) { - let deserialized_command = - deserializer.deserialize(&payload[head + 8..head + length], version); + let command_buffer = &payload[head + 8..head + length]; + + if name == DESERIALIZE_VERSION_RAW_NAME { + let version_command = deserialize_version(command_buffer); + *version = version_command.version.clone(); + log::info!("Switched to protocol version {}", version); + parsed_commands.push_back(Arc::new(version_command)); + } else if let Some(deserializer) = command_deserializer_from_string(name.as_str()) { + let deserialized_command = deserializer.deserialize(command_buffer, version); log::debug!("Received {:?}", deserialized_command); parsed_commands.push_back(deserialized_command); } else { log::warn!("Received command {name} for which there is no deserializer."); // TODO: Remove! - // todo!("Write deserializer for {name}."); + todo!("Write deserializer for {name}."); } head += length; @@ -54,13 +60,12 @@ pub fn deserialize_commands( fn command_deserializer_from_string(command_str: &str) -> Option> { match command_str { - DESERIALIZE_VERSION_RAW_NAME => Some(Box::::default()), DESERIALIZE_INIT_COMPLETE_RAW_NAME => Some(Box::::default()), DESERIALIZE_PROGRAM_INPUT_RAW_NAME => Some(Box::::default()), DESERIALIZE_TALLY_BY_SOURCE_RAW_NAME => Some(Box::::default()), DESERIALIZE_TIME_RAW_NAME => Some(Box::::default()), DESERIALIZE_PRODUCT_IDENTIFIER_RAW_NAME => { - Some(Box::::default()) + Some(Box::::default()) } _ => None, } diff --git a/atem-connection-rs/src/commands/time.rs b/atem-connection-rs/src/commands/time.rs index ae34049..fc0b8f6 100644 --- a/atem-connection-rs/src/commands/time.rs +++ b/atem-connection-rs/src/commands/time.rs @@ -36,7 +36,7 @@ impl CommandDeserializer for TimeDeserializer { &self, buffer: &[u8], version: &ProtocolVersion, - ) -> std::sync::Arc { + ) -> Arc { let info = TimeInfo { hour: buffer[0], minute: buffer[1], diff --git a/atem-connection-rs/src/enums/mod.rs b/atem-connection-rs/src/enums/mod.rs index 21fcf2f..390321f 100644 --- a/atem-connection-rs/src/enums/mod.rs +++ b/atem-connection-rs/src/enums/mod.rs @@ -73,7 +73,7 @@ impl From for Model { } } -#[derive(Debug, Default, Clone, Copy, PartialEq, PartialOrd)] +#[derive(Debug, Default, Clone, PartialEq, PartialOrd)] pub enum ProtocolVersion { #[default] Unknown = 0, diff --git a/atem-test/src/main.rs b/atem-test/src/main.rs index 6233de9..3fe8795 100644 --- a/atem-test/src/main.rs +++ b/atem-test/src/main.rs @@ -35,12 +35,10 @@ async fn main() { tokio::sync::mpsc::channel::(10); let (atem_event_tx, atem_event_rx) = tokio::sync::mpsc::unbounded_channel(); let cancel = CancellationToken::new(); - let cancel_task = cancel.clone(); - let mut atem_socket = AtemSocket::new(atem_event_tx); - let atem_socket_run = atem_socket.run(socket_message_rx, cancel_task); + let mut atem_socket = AtemSocket::new(socket_message_rx, atem_event_tx); - let atem = Arc::new(Atem::new(socket_message_tx)); + let atem = Arc::new(Atem::new(atem_socket, socket_message_tx)); let atem_thread = atem.clone(); let atem_run = atem_thread.run(atem_event_rx, cancel); @@ -64,7 +62,6 @@ async fn main() { }); select! { - _ = atem_socket_run => {}, _ = atem_run => {}, _ = switch_loop => {} }