feat: Handle version command

This commit is contained in:
Baud 2024-03-19 20:31:25 +00:00
parent d2f41821c0
commit 25460d77cd
11 changed files with 150 additions and 140 deletions

View File

@ -1,20 +1,26 @@
use std::{ use std::{
collections::{HashMap, VecDeque}, collections::{HashMap, VecDeque},
net::SocketAddr, net::SocketAddr,
ops::DerefMut,
sync::Arc, sync::Arc,
time::Duration,
}; };
use tokio::{select, sync::Semaphore}; use tokio::{select, sync::Semaphore};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use crate::{ use crate::{
atem_lib::atem_socket::{AtemEvent, AtemSocketCommand, AtemSocketMessage, TrackingId}, atem_lib::atem_socket::{
AtemSocket, AtemSocketCommand, AtemSocketEvent, AtemSocketMessage, TrackingId,
},
commands::{ commands::{
command_base::{BasicWritableCommand, DeserializedCommand}, command_base::{BasicWritableCommand, DeserializedCommand},
device_profile::version::DESERIALIZE_VERSION_RAW_NAME, device_profile::version::DESERIALIZE_VERSION_RAW_NAME,
init_complete::DESERIALIZE_INIT_COMPLETE_RAW_NAME, init_complete::DESERIALIZE_INIT_COMPLETE_RAW_NAME,
parse_commands::deserialize_commands,
time::DESERIALIZE_TIME_RAW_NAME, time::DESERIALIZE_TIME_RAW_NAME,
}, },
enums::ProtocolVersion,
state::AtemState, state::AtemState,
}; };
@ -27,13 +33,24 @@ pub enum AtemConnectionStatus {
} }
pub struct Atem { pub struct Atem {
protocol_version: tokio::sync::RwLock<ProtocolVersion>,
socket: tokio::sync::RwLock<AtemSocket>,
waiting_semaphores: tokio::sync::RwLock<HashMap<TrackingId, Arc<Semaphore>>>, waiting_semaphores: tokio::sync::RwLock<HashMap<TrackingId, Arc<Semaphore>>>,
socket_message_tx: tokio::sync::mpsc::Sender<AtemSocketMessage>, socket_message_tx: tokio::sync::mpsc::Sender<AtemSocketMessage>,
} }
impl Atem { impl Atem {
pub fn new(socket_message_tx: tokio::sync::mpsc::Sender<AtemSocketMessage>) -> Self { pub fn new(
socket: AtemSocket,
socket_message_tx: tokio::sync::mpsc::Sender<AtemSocketMessage>,
) -> Self {
Self { Self {
protocol_version: tokio::sync::RwLock::new(ProtocolVersion::V7_2),
socket: tokio::sync::RwLock::new(socket),
waiting_semaphores: tokio::sync::RwLock::new(HashMap::new()), waiting_semaphores: tokio::sync::RwLock::new(HashMap::new()),
socket_message_tx, socket_message_tx,
} }
@ -54,25 +71,30 @@ impl Atem {
pub async fn run( pub async fn run(
&self, &self,
mut atem_event_rx: tokio::sync::mpsc::UnboundedReceiver<AtemEvent>, mut atem_event_rx: tokio::sync::mpsc::UnboundedReceiver<AtemSocketEvent>,
cancel: CancellationToken, cancel: CancellationToken,
) { ) {
let mut status = AtemConnectionStatus::default(); let mut status = AtemConnectionStatus::default();
let mut state = AtemState::default(); let mut state = AtemState::default();
let mut poll_interval = tokio::time::interval(Duration::from_millis(5));
while !cancel.is_cancelled() { while !cancel.is_cancelled() {
let tick = poll_interval.tick();
select! { select! {
_ = cancel.cancelled() => {}, _ = cancel.cancelled() => {},
_ = tick => {},
message = atem_event_rx.recv() => match message { message = atem_event_rx.recv() => match message {
Some(event) => match event { Some(event) => match event {
AtemEvent::Connected => { AtemSocketEvent::Connected => {
log::info!("Atem connected"); log::info!("Atem connected");
} }
AtemEvent::Disconnected => todo!("Disconnected"), AtemSocketEvent::Disconnected => todo!("Disconnected"),
AtemEvent::ReceivedCommands(commands) => { AtemSocketEvent::ReceivedCommands(payload) => {
let commands = deserialize_commands(&payload, self.protocol_version.write().await.deref_mut());
self.mutate_state(&mut state, &mut status, commands).await self.mutate_state(&mut state, &mut status, commands).await
} }
AtemEvent::AckedCommand(tracking_id) => { AtemSocketEvent::AckedCommand(tracking_id) => {
log::debug!("Received tracking Id {tracking_id}"); log::debug!("Received tracking Id {tracking_id}");
if let Some(semaphore) = if let Some(semaphore) =
self.waiting_semaphores.read().await.get(&tracking_id) self.waiting_semaphores.read().await.get(&tracking_id)
@ -89,18 +111,19 @@ impl Atem {
} }
} }
} }
self.socket.write().await.poll().await;
} }
} }
pub async fn send_commands(&self, commands: Vec<Box<dyn BasicWritableCommand>>) { pub async fn send_commands(&self, commands: Vec<Box<dyn BasicWritableCommand>>) {
let protocol_version = { self.protocol_version.read().await.clone() };
let (callback_tx, callback_rx) = tokio::sync::oneshot::channel(); let (callback_tx, callback_rx) = tokio::sync::oneshot::channel();
self.socket_message_tx self.socket_message_tx
.send(AtemSocketMessage::SendCommands { .send(AtemSocketMessage::SendCommands {
commands: commands commands: commands
.iter() .iter()
.map(|command| { .map(|command| AtemSocketCommand::new(command, &protocol_version))
AtemSocketCommand::new(command, &crate::enums::ProtocolVersion::Unknown)
})
.collect(), .collect(),
tracking_ids_callback: callback_tx, tracking_ids_callback: callback_tx,
}) })

View File

@ -54,10 +54,10 @@ pub struct TrackingIdsCallback {
} }
#[derive(Clone)] #[derive(Clone)]
pub enum AtemEvent { pub enum AtemSocketEvent {
Connected, Connected,
Disconnected, Disconnected,
ReceivedCommands(VecDeque<Arc<dyn DeserializedCommand>>), ReceivedCommands(Vec<u8>),
AckedCommand(TrackingId), AckedCommand(TrackingId),
} }
@ -111,8 +111,11 @@ pub struct AtemSocket {
ack_timer: Option<SystemTime>, ack_timer: Option<SystemTime>,
received_without_ack: u16, received_without_ack: u16,
atem_event_tx: tokio::sync::mpsc::UnboundedSender<AtemEvent>, atem_message_rx: tokio::sync::mpsc::Receiver<AtemSocketMessage>,
atem_event_tx: tokio::sync::mpsc::UnboundedSender<AtemSocketEvent>,
connected_callbacks: Mutex<Vec<tokio::sync::oneshot::Sender<bool>>>, connected_callbacks: Mutex<Vec<tokio::sync::oneshot::Sender<bool>>>,
tick_interval: tokio::time::Interval,
} }
#[derive(PartialEq, Clone)] #[derive(PartialEq, Clone)]
@ -147,7 +150,11 @@ enum AtemSocketReceiveError {
} }
impl AtemSocket { impl AtemSocket {
pub fn new(atem_event_tx: tokio::sync::mpsc::UnboundedSender<AtemEvent>) -> Self { pub fn new(
atem_message_rx: tokio::sync::mpsc::Receiver<AtemSocketMessage>,
atem_event_tx: tokio::sync::mpsc::UnboundedSender<AtemSocketEvent>,
) -> Self {
let tick_interval = tokio::time::interval(Duration::from_millis(5));
Self { Self {
connection_state: ConnectionState::Closed, connection_state: ConnectionState::Closed,
reconnect_timer: None, reconnect_timer: None,
@ -169,23 +176,19 @@ impl AtemSocket {
ack_timer: None, ack_timer: None,
received_without_ack: 0, received_without_ack: 0,
atem_message_rx,
atem_event_tx, atem_event_tx,
connected_callbacks: Mutex::default(), connected_callbacks: Mutex::default(),
tick_interval,
} }
} }
pub async fn run( pub async fn poll(&mut self) {
&mut self, let tick = self.tick_interval.tick();
mut atem_message_rx: tokio::sync::mpsc::Receiver<AtemSocketMessage>,
cancel: tokio_util::sync::CancellationToken,
) {
let mut interval = tokio::time::interval(Duration::from_millis(5));
while !cancel.is_cancelled() {
let tick = interval.tick();
select! { select! {
_ = cancel.cancelled() => {},
_ = tick => {}, _ = tick => {},
message = atem_message_rx.recv() => { message = self.atem_message_rx.recv() => {
match message { match message {
Some(AtemSocketMessage::Connect { Some(AtemSocketMessage::Connect {
address, address,
@ -234,8 +237,7 @@ impl AtemSocket {
barrier.wait().await; barrier.wait().await;
}, },
None => { None => {
log::info!("ATEM message channel has closed, exiting event loop."); log::info!("ATEM message channel has closed.");
cancel.cancel();
} }
} }
} }
@ -243,7 +245,6 @@ impl AtemSocket {
self.tick().await; self.tick().await;
} }
}
pub async fn connect(&mut self, address: SocketAddr) -> Result<(), io::Error> { pub async fn connect(&mut self, address: SocketAddr) -> Result<(), io::Error> {
let socket = UdpSocket::bind("0.0.0.0:0").await?; let socket = UdpSocket::bind("0.0.0.0:0").await?;
@ -536,23 +537,21 @@ impl AtemSocket {
} }
fn on_commands_received(&mut self, payload: &[u8]) { fn on_commands_received(&mut self, payload: &[u8]) {
let commands = deserialize_commands(payload, &self.protocol_version);
let _ = self let _ = self
.atem_event_tx .atem_event_tx
.send(AtemEvent::ReceivedCommands(commands)); .send(AtemSocketEvent::ReceivedCommands(payload.to_vec()));
} }
fn on_command_acknowledged(&mut self, packets: Vec<AckedPacket>) { fn on_command_acknowledged(&mut self, packets: Vec<AckedPacket>) {
for ack in packets { for ack in packets {
let _ = self let _ = self
.atem_event_tx .atem_event_tx
.send(AtemEvent::AckedCommand(TrackingId(ack.tracking_id))); .send(AtemSocketEvent::AckedCommand(TrackingId(ack.tracking_id)));
} }
} }
async fn on_connect(&mut self) { async fn on_connect(&mut self) {
let _ = self.atem_event_tx.send(AtemEvent::Connected); let _ = self.atem_event_tx.send(AtemSocketEvent::Connected);
let mut connected_callbacks = self.connected_callbacks.lock().await; let mut connected_callbacks = self.connected_callbacks.lock().await;
for callback in connected_callbacks.drain(0..) { for callback in connected_callbacks.drain(0..) {
let _ = callback.send(false); let _ = callback.send(false);
@ -560,7 +559,7 @@ impl AtemSocket {
} }
fn on_disconnect(&mut self) { fn on_disconnect(&mut self) {
let _ = self.atem_event_tx.send(AtemEvent::Disconnected); let _ = self.atem_event_tx.send(AtemSocketEvent::Disconnected);
} }
fn start_timers(&mut self) { fn start_timers(&mut self) {

View File

@ -1,4 +1,4 @@
use std::{collections::HashMap, fmt::Debug, sync::Arc}; use std::{collections::HashMap, fmt::Debug, process::Command, sync::Arc};
use crate::{enums::ProtocolVersion, state::AtemState}; use crate::{enums::ProtocolVersion, state::AtemState};

View File

@ -8,12 +8,12 @@ use crate::{
pub const DESERIALIZE_PRODUCT_IDENTIFIER_RAW_NAME: &str = "_pin"; pub const DESERIALIZE_PRODUCT_IDENTIFIER_RAW_NAME: &str = "_pin";
#[derive(Debug)] #[derive(Debug)]
pub struct ProductIdentifierCommand { pub struct ProductIdentifier {
pub product_identifier: String, pub product_identifier: String,
pub model: Model, pub model: Model,
} }
impl DeserializedCommand for ProductIdentifierCommand { impl DeserializedCommand for ProductIdentifier {
fn raw_name(&self) -> &'static str { fn raw_name(&self) -> &'static str {
DESERIALIZE_PRODUCT_IDENTIFIER_RAW_NAME DESERIALIZE_PRODUCT_IDENTIFIER_RAW_NAME
} }
@ -40,14 +40,14 @@ impl DeserializedCommand for ProductIdentifierCommand {
} }
#[derive(Default)] #[derive(Default)]
pub struct ProductIdentifierCommandDeserializer {} pub struct ProductIdentifierDeserializer {}
impl CommandDeserializer for ProductIdentifierCommandDeserializer { impl CommandDeserializer for ProductIdentifierDeserializer {
fn deserialize( fn deserialize(
&self, &self,
buffer: &[u8], buffer: &[u8],
version: &ProtocolVersion, version: &ProtocolVersion,
) -> std::sync::Arc<dyn DeserializedCommand> { ) -> Arc<dyn DeserializedCommand> {
let null_byte_index = buffer let null_byte_index = buffer
.iter() .iter()
.position(|byte| *byte == b'\0') .position(|byte| *byte == b'\0')
@ -57,7 +57,7 @@ impl CommandDeserializer for ProductIdentifierCommandDeserializer {
.expect("Malformed string"); .expect("Malformed string");
let model = buffer[40]; let model = buffer[40];
Arc::new(ProductIdentifierCommand { Arc::new(ProductIdentifier {
product_identifier: product_identifier product_identifier: product_identifier
.to_str() .to_str()
.expect("Invalid rust string") .expect("Invalid rust string")

View File

@ -8,7 +8,7 @@ use crate::{
pub const DESERIALIZE_TOPOLOGY_RAW_NAME: &str = "_top"; pub const DESERIALIZE_TOPOLOGY_RAW_NAME: &str = "_top";
#[derive(Debug)] #[derive(Debug)]
pub struct TopologyCommand { pub struct Topology {
mix_effects: u8, mix_effects: u8,
sources: u8, sources: u8,
auxilliaries: u8, auxilliaries: u8,
@ -27,7 +27,7 @@ pub struct TopologyCommand {
only_configurable_outputs: bool, only_configurable_outputs: bool,
} }
impl DeserializedCommand for TopologyCommand { impl DeserializedCommand for Topology {
fn raw_name(&self) -> &'static str { fn raw_name(&self) -> &'static str {
todo!() todo!()
} }
@ -44,7 +44,7 @@ impl CommandDeserializer for TopologyCommandDeserializer {
&self, &self,
buffer: &[u8], buffer: &[u8],
version: &ProtocolVersion, version: &ProtocolVersion,
) -> std::sync::Arc<dyn crate::commands::command_base::DeserializedCommand> { ) -> Arc<dyn DeserializedCommand> {
let v230offset = if *version > ProtocolVersion::V8_0_1 { let v230offset = if *version > ProtocolVersion::V8_0_1 {
1 1
} else { } else {
@ -69,7 +69,7 @@ impl CommandDeserializer for TopologyCommandDeserializer {
false false
}; };
Arc::new(TopologyCommand { Arc::new(Topology {
mix_effects: buffer[0], mix_effects: buffer[0],
sources: buffer[1], sources: buffer[1],
downstream_keyers: buffer[2], downstream_keyers: buffer[2],

View File

@ -1,39 +1,25 @@
use std::sync::Arc; use crate::{commands::command_base::DeserializedCommand, enums::ProtocolVersion};
use crate::{
commands::command_base::{CommandDeserializer, DeserializedCommand},
enums::ProtocolVersion,
};
pub const DESERIALIZE_VERSION_RAW_NAME: &str = "_ver"; pub const DESERIALIZE_VERSION_RAW_NAME: &str = "_ver";
#[derive(Debug)] #[derive(Debug)]
pub struct VersionCommand { pub struct Version {
pub version: ProtocolVersion, pub version: ProtocolVersion,
} }
impl DeserializedCommand for VersionCommand { impl DeserializedCommand for Version {
fn raw_name(&self) -> &'static str { fn raw_name(&self) -> &'static str {
DESERIALIZE_VERSION_RAW_NAME DESERIALIZE_VERSION_RAW_NAME
} }
fn apply_to_state(&self, state: &mut crate::state::AtemState) { fn apply_to_state(&self, state: &mut crate::state::AtemState) {
state.info.api_version = self.version; state.info.api_version = self.version.clone();
} }
} }
#[derive(Default)] pub fn deserialize_version(buffer: &[u8]) -> Version {
pub struct VersionCommandDeserializer {}
impl CommandDeserializer for VersionCommandDeserializer {
fn deserialize(
&self,
buffer: &[u8],
version: &ProtocolVersion,
) -> std::sync::Arc<dyn DeserializedCommand> {
let version = u32::from_be_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]); let version = u32::from_be_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]);
let version: ProtocolVersion = version.try_into().expect("Invalid protocol version"); let version: ProtocolVersion = version.try_into().expect("Invalid protocol version");
Arc::new(VersionCommand { version }) Version { version }
}
} }

View File

@ -25,7 +25,7 @@ impl CommandDeserializer for InitCompleteDeserializer {
&self, &self,
_buffer: &[u8], _buffer: &[u8],
version: &ProtocolVersion, version: &ProtocolVersion,
) -> std::sync::Arc<dyn DeserializedCommand> { ) -> Arc<dyn DeserializedCommand> {
Arc::new(InitComplete {}) Arc::new(InitComplete {})
} }
} }

View File

@ -1,14 +1,14 @@
use std::{collections::VecDeque, sync::Arc}; use std::{collections::VecDeque, sync::Arc};
use crate::enums::ProtocolVersion; use crate::{
commands::device_profile::version::{deserialize_version, DESERIALIZE_VERSION_RAW_NAME},
enums::ProtocolVersion,
};
use super::{ use super::{
command_base::{CommandDeserializer, DeserializedCommand}, command_base::{CommandDeserializer, DeserializedCommand},
device_profile::{ device_profile::product_identifier::{
product_identifier::{ ProductIdentifierDeserializer, DESERIALIZE_PRODUCT_IDENTIFIER_RAW_NAME,
ProductIdentifierCommandDeserializer, DESERIALIZE_PRODUCT_IDENTIFIER_RAW_NAME,
},
version::{VersionCommandDeserializer, DESERIALIZE_VERSION_RAW_NAME},
}, },
init_complete::{InitCompleteDeserializer, DESERIALIZE_INIT_COMPLETE_RAW_NAME}, init_complete::{InitCompleteDeserializer, DESERIALIZE_INIT_COMPLETE_RAW_NAME},
mix_effects::program_input::{ProgramInputDeserializer, DESERIALIZE_PROGRAM_INPUT_RAW_NAME}, mix_effects::program_input::{ProgramInputDeserializer, DESERIALIZE_PROGRAM_INPUT_RAW_NAME},
@ -18,9 +18,9 @@ use super::{
pub fn deserialize_commands( pub fn deserialize_commands(
payload: &[u8], payload: &[u8],
version: &ProtocolVersion, version: &mut ProtocolVersion,
) -> VecDeque<Arc<dyn DeserializedCommand>> { ) -> VecDeque<Arc<dyn DeserializedCommand>> {
let mut parsed_commands = VecDeque::new(); let mut parsed_commands: VecDeque<Arc<dyn DeserializedCommand>> = VecDeque::new();
let mut head = 0; let mut head = 0;
while payload.len() > head + 8 { while payload.len() > head + 8 {
@ -35,15 +35,21 @@ pub fn deserialize_commands(
log::debug!("Received command {} with length {}", name, length); log::debug!("Received command {} with length {}", name, length);
if let Some(deserializer) = command_deserializer_from_string(name.as_str()) { let command_buffer = &payload[head + 8..head + length];
let deserialized_command =
deserializer.deserialize(&payload[head + 8..head + length], version); if name == DESERIALIZE_VERSION_RAW_NAME {
let version_command = deserialize_version(command_buffer);
*version = version_command.version.clone();
log::info!("Switched to protocol version {}", version);
parsed_commands.push_back(Arc::new(version_command));
} else if let Some(deserializer) = command_deserializer_from_string(name.as_str()) {
let deserialized_command = deserializer.deserialize(command_buffer, version);
log::debug!("Received {:?}", deserialized_command); log::debug!("Received {:?}", deserialized_command);
parsed_commands.push_back(deserialized_command); parsed_commands.push_back(deserialized_command);
} else { } else {
log::warn!("Received command {name} for which there is no deserializer."); log::warn!("Received command {name} for which there is no deserializer.");
// TODO: Remove! // TODO: Remove!
// todo!("Write deserializer for {name}."); todo!("Write deserializer for {name}.");
} }
head += length; head += length;
@ -54,13 +60,12 @@ pub fn deserialize_commands(
fn command_deserializer_from_string(command_str: &str) -> Option<Box<dyn CommandDeserializer>> { fn command_deserializer_from_string(command_str: &str) -> Option<Box<dyn CommandDeserializer>> {
match command_str { match command_str {
DESERIALIZE_VERSION_RAW_NAME => Some(Box::<VersionCommandDeserializer>::default()),
DESERIALIZE_INIT_COMPLETE_RAW_NAME => Some(Box::<InitCompleteDeserializer>::default()), DESERIALIZE_INIT_COMPLETE_RAW_NAME => Some(Box::<InitCompleteDeserializer>::default()),
DESERIALIZE_PROGRAM_INPUT_RAW_NAME => Some(Box::<ProgramInputDeserializer>::default()), DESERIALIZE_PROGRAM_INPUT_RAW_NAME => Some(Box::<ProgramInputDeserializer>::default()),
DESERIALIZE_TALLY_BY_SOURCE_RAW_NAME => Some(Box::<TallyBySourceDeserializer>::default()), DESERIALIZE_TALLY_BY_SOURCE_RAW_NAME => Some(Box::<TallyBySourceDeserializer>::default()),
DESERIALIZE_TIME_RAW_NAME => Some(Box::<TimeDeserializer>::default()), DESERIALIZE_TIME_RAW_NAME => Some(Box::<TimeDeserializer>::default()),
DESERIALIZE_PRODUCT_IDENTIFIER_RAW_NAME => { DESERIALIZE_PRODUCT_IDENTIFIER_RAW_NAME => {
Some(Box::<ProductIdentifierCommandDeserializer>::default()) Some(Box::<ProductIdentifierDeserializer>::default())
} }
_ => None, _ => None,
} }

View File

@ -36,7 +36,7 @@ impl CommandDeserializer for TimeDeserializer {
&self, &self,
buffer: &[u8], buffer: &[u8],
version: &ProtocolVersion, version: &ProtocolVersion,
) -> std::sync::Arc<dyn super::command_base::DeserializedCommand> { ) -> Arc<dyn DeserializedCommand> {
let info = TimeInfo { let info = TimeInfo {
hour: buffer[0], hour: buffer[0],
minute: buffer[1], minute: buffer[1],

View File

@ -73,7 +73,7 @@ impl From<u8> for Model {
} }
} }
#[derive(Debug, Default, Clone, Copy, PartialEq, PartialOrd)] #[derive(Debug, Default, Clone, PartialEq, PartialOrd)]
pub enum ProtocolVersion { pub enum ProtocolVersion {
#[default] #[default]
Unknown = 0, Unknown = 0,

View File

@ -35,12 +35,10 @@ async fn main() {
tokio::sync::mpsc::channel::<AtemSocketMessage>(10); tokio::sync::mpsc::channel::<AtemSocketMessage>(10);
let (atem_event_tx, atem_event_rx) = tokio::sync::mpsc::unbounded_channel(); let (atem_event_tx, atem_event_rx) = tokio::sync::mpsc::unbounded_channel();
let cancel = CancellationToken::new(); let cancel = CancellationToken::new();
let cancel_task = cancel.clone();
let mut atem_socket = AtemSocket::new(atem_event_tx); let mut atem_socket = AtemSocket::new(socket_message_rx, atem_event_tx);
let atem_socket_run = atem_socket.run(socket_message_rx, cancel_task);
let atem = Arc::new(Atem::new(socket_message_tx)); let atem = Arc::new(Atem::new(atem_socket, socket_message_tx));
let atem_thread = atem.clone(); let atem_thread = atem.clone();
let atem_run = atem_thread.run(atem_event_rx, cancel); let atem_run = atem_thread.run(atem_event_rx, cancel);
@ -64,7 +62,6 @@ async fn main() {
}); });
select! { select! {
_ = atem_socket_run => {},
_ = atem_run => {}, _ = atem_run => {},
_ = switch_loop => {} _ = switch_loop => {}
} }