From 4a41d1f5d7f049605a8e550d39a9c896bfda7681 Mon Sep 17 00:00:00 2001 From: Baud Date: Fri, 1 Mar 2024 17:11:57 +0000 Subject: [PATCH] feat: Atem wrapper --- Cargo.lock | 48 +-- atem-connection-rs/Cargo.toml | 2 +- atem-connection-rs/src/atem.rs | 117 ++++++- .../src/atem_lib/atem_packet.rs | 14 +- .../src/atem_lib/atem_socket.rs | 288 ++++++++++++------ .../src/commands/command_base.rs | 2 +- .../src/commands/mix_effects.rs | 2 - .../src/commands/mix_effects/program_input.rs | 2 +- atem-connection-rs/src/lib.rs | 2 - atem-test/Cargo.toml | 1 + atem-test/src/main.rs | 66 ++-- 11 files changed, 365 insertions(+), 179 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 93553fe..4422846 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -81,8 +81,8 @@ dependencies = [ "derive-getters", "derive-new", "log", - "thiserror", "tokio", + "tokio-util", ] [[package]] @@ -95,6 +95,7 @@ dependencies = [ "env_logger", "log", "tokio", + "tokio-util", ] [[package]] @@ -271,6 +272,18 @@ dependencies = [ "once_cell", ] +[[package]] +name = "futures-core" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" + +[[package]] +name = "futures-sink" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" + [[package]] name = "gimli" version = "0.25.0" @@ -554,26 +567,6 @@ dependencies = [ "winapi-util", ] -[[package]] -name = "thiserror" -version = "1.0.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "854babe52e4df1653706b98fcfc05843010039b406875930a70e4d9644e5c417" -dependencies = [ - "thiserror-impl", -] - -[[package]] -name = "thiserror-impl" -version = "1.0.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa32fd3f627f367fe16f893e2597ae3c05020f8bba2666a4e6ea73d377e5714b" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.74", -] - [[package]] name = "thread_local" version = "1.1.3" @@ -613,6 +606,19 @@ dependencies = [ "syn 2.0.48", ] +[[package]] +name = "tokio-util" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5419f34732d9eb6ee4c3578b7989078579b7f039cbbb9ca2c4da015749371e15" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + [[package]] name = "tracing" version = "0.1.29" diff --git a/atem-connection-rs/Cargo.toml b/atem-connection-rs/Cargo.toml index 73ee78b..58dc8eb 100644 --- a/atem-connection-rs/Cargo.toml +++ b/atem-connection-rs/Cargo.toml @@ -7,5 +7,5 @@ edition = "2021" derive-getters = "0.2.0" derive-new = "0.6.0" log = "0.4.14" -thiserror = "1.0.30" tokio = { version = "1.13.0", features = ["full"] } +tokio-util = "0.7.10" diff --git a/atem-connection-rs/src/atem.rs b/atem-connection-rs/src/atem.rs index 2bc214e..f50f504 100644 --- a/atem-connection-rs/src/atem.rs +++ b/atem-connection-rs/src/atem.rs @@ -1,9 +1,112 @@ -use crate::{commands::command_base::DeserializedCommand, state::AtemState}; +use std::{collections::HashMap, net::SocketAddr, sync::Arc}; -pub struct AtemOptions { - address: Option, - port: Option, - debug_buffers: bool, - disable_multi_threaded: bool, - child_process_timeout: Option, +use tokio::sync::{mpsc::error::TryRecvError, Semaphore}; +use tokio_util::sync::CancellationToken; + +use crate::{ + atem_lib::atem_socket::{AtemEvent, AtemSocketCommand, AtemSocketMessage, TrackingId}, + commands::command_base::BasicWritableCommand, +}; + +pub struct Atem { + waiting_semaphores: tokio::sync::RwLock>>, + socket_message_tx: tokio::sync::mpsc::Sender, +} + +impl Atem { + pub fn new(socket_message_tx: tokio::sync::mpsc::Sender) -> Self { + Self { + waiting_semaphores: tokio::sync::RwLock::new(HashMap::new()), + socket_message_tx, + } + } + + pub async fn connect(&self, address: SocketAddr) { + let (callback_tx, callback_rx) = tokio::sync::oneshot::channel(); + self.socket_message_tx + .send(AtemSocketMessage::Connect { + address, + result_callback: callback_tx, + }) + .await + .unwrap(); + + callback_rx.await.unwrap().unwrap(); + } + + pub async fn run( + &self, + mut atem_event_rx: tokio::sync::mpsc::UnboundedReceiver, + cancel: CancellationToken, + ) { + while !cancel.is_cancelled() { + match atem_event_rx.try_recv() { + Ok(event) => match event { + AtemEvent::Error(_) => todo!(), + AtemEvent::Info(_) => todo!(), + AtemEvent::Debug(_) => todo!(), + AtemEvent::Connected => { + log::info!("Atem connected"); + } + AtemEvent::Disconnected => todo!(), + AtemEvent::ReceivedCommand(_) => todo!(), + AtemEvent::AckedCommand(tracking_id) => { + log::debug!("Received tracking Id {tracking_id}"); + if let Some(semaphore) = + self.waiting_semaphores.read().await.get(&tracking_id) + { + semaphore.add_permits(1); + } else { + log::warn!("Received tracking Id {tracking_id} with no-one waiting for it to be resolved.") + } + } + }, + Err(TryRecvError::Empty) => {} + Err(TryRecvError::Disconnected) => { + log::info!("ATEM event channel has closed, exiting event loop."); + cancel.cancel(); + } + } + } + } + + pub async fn send_commands(&self, commands: Vec>) { + let (callback_tx, callback_rx) = tokio::sync::oneshot::channel(); + self.socket_message_tx + .send(AtemSocketMessage::SendCommands { + commands: commands + .iter() + .map(|command| { + AtemSocketCommand::new(command, &crate::enums::ProtocolVersion::Unknown) + }) + .collect(), + tracking_ids_callback: callback_tx, + }) + .await + .unwrap(); + let callback = callback_rx.await.unwrap(); + + let semaphore = Arc::new(Semaphore::new(0)); + + for tracking_id in callback.tracking_ids.iter() { + self.waiting_semaphores + .write() + .await + .insert(tracking_id.clone(), semaphore.clone()); + } + + callback.barrier.wait().await; + + // If this fails then the semaphore has been closed which is a darn shame but at that point + // the best we can do it continue on our merry way in life and remain oblivious to + // the fire raging in other realms. + semaphore + .acquire_many(callback.tracking_ids.len() as u32) + .await + .ok(); + + for tracking_id in callback.tracking_ids.iter() { + self.waiting_semaphores.write().await.remove(tracking_id); + } + } } diff --git a/atem-connection-rs/src/atem_lib/atem_packet.rs b/atem-connection-rs/src/atem_lib/atem_packet.rs index c1a14a9..2220788 100755 --- a/atem-connection-rs/src/atem_lib/atem_packet.rs +++ b/atem-connection-rs/src/atem_lib/atem_packet.rs @@ -40,10 +40,6 @@ impl<'packet_buffer> AtemPacket<'packet_buffer> { self.length } - pub fn flags(&self) -> u8 { - self.flags - } - pub fn session_id(&self) -> u16 { self.session_id } @@ -80,7 +76,7 @@ impl<'packet_buffer> TryFrom<&'packet_buffer [u8]> for AtemPacket<'packet_buffer ))); } - let length = u16::from_be_bytes(buffer[0..2].try_into().unwrap()) & 0x07ff; + let length = u16::from_be_bytes([buffer[0], buffer[1]]) & 0x07ff; if length as usize != buffer.len() { return Err(AtemPacketErr::LengthDiffers(format!( "Length of message differs, expected {} got {}", @@ -90,20 +86,20 @@ impl<'packet_buffer> TryFrom<&'packet_buffer [u8]> for AtemPacket<'packet_buffer } let flags = buffer[0] >> 3; - let session_id = u16::from_be_bytes(buffer[2..4].try_into().unwrap()); - let remote_packet_id = u16::from_be_bytes(buffer[10..12].try_into().unwrap()); + let session_id = u16::from_be_bytes([buffer[2], buffer[3]]); + let remote_packet_id = u16::from_be_bytes([buffer[10], buffer[11]]); let body = &buffer[12..]; let retransmit_requested_from_packet_id = if flags & u8::from(PacketFlag::RetransmitRequest) > 0 { - Some(u16::from_be_bytes(buffer[6..8].try_into().unwrap())) + Some(u16::from_be_bytes([buffer[6], buffer[7]])) } else { None }; let ack_reply = if flags & u8::from(PacketFlag::AckReply) > 0 { - Some(u16::from_be_bytes(buffer[4..6].try_into().unwrap())) + Some(u16::from_be_bytes([buffer[4], buffer[5]])) } else { None }; diff --git a/atem-connection-rs/src/atem_lib/atem_socket.rs b/atem-connection-rs/src/atem_lib/atem_socket.rs index 3c4d5c6..612d262 100644 --- a/atem-connection-rs/src/atem_lib/atem_socket.rs +++ b/atem-connection-rs/src/atem_lib/atem_socket.rs @@ -1,16 +1,20 @@ use std::{ + fmt::Display, io, net::SocketAddr, sync::Arc, time::{Duration, SystemTime}, }; -use log::debug; -use tokio::net::UdpSocket; +use tokio::{net::UdpSocket, sync::Barrier, task::yield_now}; use crate::{ atem_lib::{atem_packet::AtemPacket, atem_util}, - commands::{command_base::DeserializedCommand, parse_commands::deserialize_commands}, + commands::{ + command_base::{BasicWritableCommand, DeserializedCommand}, + parse_commands::deserialize_commands, + }, + enums::ProtocolVersion, }; use super::atem_packet::PacketFlag; @@ -27,6 +31,23 @@ const MAX_PACKET_PER_ACK: u16 = 16; const MAX_PACKET_RECEIVE_SIZE: usize = 65535; const ACK_PACKET_LENGTH: u16 = 12; +pub enum AtemSocketMessage { + Connect { + address: SocketAddr, + result_callback: tokio::sync::oneshot::Sender>, + }, + Disconnect, + SendCommands { + commands: Vec, + tracking_ids_callback: tokio::sync::oneshot::Sender, + }, +} + +pub struct TrackingIdsCallback { + pub tracking_ids: Vec, + pub barrier: Arc, +} + #[derive(Clone)] pub enum AtemEvent { Error(String), @@ -35,7 +56,58 @@ pub enum AtemEvent { Connected, Disconnected, ReceivedCommand(Arc), - AckedCommand(AckedPacket), + AckedCommand(TrackingId), +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct TrackingId(u64); + +impl Display for TrackingId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +#[derive(Clone)] +pub struct AckedPacket { + pub packet_id: u16, + pub tracking_id: u64, +} + +pub struct AtemSocketCommand { + payload: Vec, + raw_name: String, +} + +impl AtemSocketCommand { + pub fn new(command: &Box, version: &ProtocolVersion) -> Self { + Self { + payload: command.payload(version), + raw_name: command.get_raw_name().to_string(), + } + } +} + +pub struct AtemSocket { + connection_state: ConnectionState, + reconnect_timer: Option, + retransmit_timer: Option, + + next_tracking_id: u64, + + next_send_packet_id: u16, + session_id: u16, + + socket: Option, + address: SocketAddr, + + last_received_at: SystemTime, + last_received_packed_id: u16, + in_flight: Vec, + ack_timer: Option, + received_without_ack: u16, + + atem_event_tx: tokio::sync::mpsc::UnboundedSender, } #[derive(PartialEq, Clone)] @@ -65,62 +137,91 @@ struct InFlightPacket { pub resent: u16, } -#[derive(Clone)] -pub struct AckedPacket { - pub packet_id: u16, - pub tracking_id: u64, -} - -pub struct AtemSocketCommand { - payload: Vec, - raw_name: String, - tracking_id: u64, -} - -pub struct AtemSocket { - connection_state: ConnectionState, - reconnect_timer: Option, - retransmit_timer: Option, - - next_send_packet_id: u16, - session_id: u16, - - socket: Option, - address: String, - port: u16, - - last_received_at: SystemTime, - last_received_packed_id: u16, - in_flight: Vec, - ack_timer: Option, - received_without_ack: u16, - - atem_event_tx: tokio::sync::broadcast::Sender, -} - enum AtemSocketReceiveError { Closed, } -#[derive(Debug, Error)] -enum AtemSocketWriteError { - #[error("Socket closed")] - Closed, - - #[error("Socket disconnected")] - Disconnected(#[from] io::Error), -} - impl AtemSocket { - pub async fn connect(&mut self, address: String, port: u16) -> Result<(), io::Error> { - self.address = address.clone(); - self.port = port; + pub fn new(atem_event_tx: tokio::sync::mpsc::UnboundedSender) -> Self { + Self { + connection_state: ConnectionState::Closed, + reconnect_timer: None, + retransmit_timer: None, + next_tracking_id: 0, + + next_send_packet_id: 1, + session_id: 0, + + socket: None, + address: "0.0.0.0:0".parse().unwrap(), + + last_received_at: SystemTime::now(), + last_received_packed_id: 0, + in_flight: vec![], + ack_timer: None, + received_without_ack: 0, + + atem_event_tx, + } + } + + pub async fn run( + &mut self, + mut atem_message_rx: tokio::sync::mpsc::Receiver, + cancel: tokio_util::sync::CancellationToken, + ) { + while !cancel.is_cancelled() { + if let Ok(msg) = atem_message_rx.try_recv() { + match msg { + AtemSocketMessage::Connect { + address, + result_callback, + } => { + result_callback.send(self.connect(address).await).ok(); + } + AtemSocketMessage::Disconnect => self.disconnect(), + AtemSocketMessage::SendCommands { + commands, + tracking_ids_callback, + } => { + let barrier = Arc::new(Barrier::new(2)); + tracking_ids_callback + .send(TrackingIdsCallback { + tracking_ids: self.send_commands(commands).await, + barrier: barrier.clone(), + }) + .ok(); + + // Let's play the game "Synchronisation Shenanigans"! + // So, we are sending tracking Ids to the sender of this message, the sender will then wait + // for each of these tracking Ids to be ACK'd by the ATEM. However, the sender will need to + // do ✨ some form of shenanigans ✨ in order to be ready to receive tracking Ids. So we send + // them a barrier as part of the callback so that they can tell us that they are ready for + // us to continue with ATEM communication, at which point we may immediately inform them of a + // received tracking Id matching one included in this callback. + // + // Now, if we were being 🚩 Real Proper Software Developers 🚩 we'd probably expect the receiver + // of the callback to do clever things so that if a tracking Id is received immediately, they + // then wait for something that wants that tracking Id on their side, rather than blocking this + // task so that the caller can do ✨ shenanigans ✨. However, that sounds far too clever and too + // much like 🚩 Real Actual Work 🚩 so instead we've chosen to do this and hope that whichever + // actor we're waiting on doesn't take _too_ long to do ✨ shenanigans ✨ before signalling that + // they are ready. If they do, I suggest finding whoever wrote that code and bonking them 🔨. + barrier.wait().await; + } + } + + yield_now().await; + } + + self.tick().await; + } + } + + pub async fn connect(&mut self, address: SocketAddr) -> Result<(), io::Error> { let socket = UdpSocket::bind("0.0.0.0:0").await?; - let remote_addr = format!("{}:{}", address, port) - .parse::() - .unwrap(); - socket.connect(remote_addr).await?; + socket.connect(address).await?; self.socket = Some(socket); self.start_timers(); @@ -128,7 +229,7 @@ impl AtemSocket { self.next_send_packet_id = 1; self.session_id = 0; self.in_flight = vec![]; - debug!("Reconnect"); + log::debug!("Reconnect"); self.send_packet(&atem_util::COMMAND_CONNECT_HELLO).await; self.connection_state = ConnectionState::SynSent; @@ -152,11 +253,16 @@ impl AtemSocket { } } - pub async fn send_commands(&mut self, commands: Vec) { + pub async fn send_commands(&mut self, commands: Vec) -> Vec { + let mut tracking_ids: Vec = Vec::with_capacity(commands.len()); for command in commands.into_iter() { - self.send_command(&command.payload, &command.raw_name, command.tracking_id) + let tracking_id = self.next_packet_tracking_id(); + self.send_command(&command.payload, &command.raw_name, tracking_id) .await; + tracking_ids.push(TrackingId(tracking_id)); } + + tracking_ids } pub async fn send_command(&mut self, payload: &[u8], raw_name: &str, tracking_id: u64) { @@ -192,16 +298,12 @@ impl AtemSocket { }) } - pub fn subscribe_to_events(&self) -> tokio::sync::broadcast::Receiver { - self.atem_event_tx.subscribe() - } - async fn restart_connection(&mut self) { self.disconnect(); - self.connect(self.address.clone(), self.port).await.ok(); + self.connect(self.address.clone()).await.ok(); } - pub async fn tick(&mut self) { + async fn tick(&mut self) { let messages = self.receive().await.ok(); if let Some(messages) = messages { for message in messages.iter() { @@ -220,8 +322,8 @@ impl AtemSocket { if self.last_received_at + Duration::from_millis(CONNECTION_TIMEOUT) <= SystemTime::now() { - debug!("{:?}", self.last_received_at); - debug!("Connection timed out, restarting"); + log::debug!("{:?}", self.last_received_at); + log::debug!("Connection timed out, restarting"); self.restart_connection().await; } self.start_reconnect_timer(); @@ -261,7 +363,7 @@ impl AtemSocket { return; }; - debug!("Received {:x?}", atem_packet); + log::debug!("Received {:x?}", atem_packet); self.last_received_at = SystemTime::now(); @@ -269,16 +371,17 @@ impl AtemSocket { let remote_packet_id = atem_packet.remote_packet_id(); if atem_packet.has_flag(PacketFlag::NewSessionId) { - debug!("New session"); + log::debug!("New session"); self.connection_state = ConnectionState::Established; self.last_received_packed_id = remote_packet_id; self.send_ack(remote_packet_id).await; + self.on_connect(); return; } if self.connection_state == ConnectionState::Established { if let Some(from_packet_id) = atem_packet.retransmit_request() { - debug!("Retransmit request: {:x?}", from_packet_id); + log::debug!("Retransmit request: {:x?}", from_packet_id); self.retransmit_from(from_packet_id).await; } @@ -299,7 +402,7 @@ impl AtemSocket { } if atem_packet.has_flag(PacketFlag::IsRetransmit) { - debug!("ATEM retransmitted packet {:x?}", remote_packet_id); + log::debug!("ATEM retransmitted packet {:x?}", remote_packet_id); } if let Some(ack_packet_id) = atem_packet.ack_reply() { @@ -327,11 +430,11 @@ impl AtemSocket { } async fn send_packet(&self, packet: &[u8]) { - debug!("Send {:x?}", packet); + log::debug!("Send {:x?}", packet); if let Some(socket) = &self.socket { socket.send(packet).await.ok(); } else { - debug!("Socket is not open") + log::debug!("Socket is not open") } } @@ -347,7 +450,7 @@ impl AtemSocket { } async fn send_ack(&mut self, packet_id: u16) { - debug!("Sending ack for packet {:x?}", packet_id); + log::debug!("Sending ack for packet {:x?}", packet_id); let flag: u8 = PacketFlag::AckReply.into(); let opcode = u16::from(flag) << 11; let mut buffer: [u8; ACK_PACKET_LENGTH as _] = [0; 12]; @@ -365,7 +468,7 @@ impl AtemSocket { .iter() .position(|pkt| pkt.packet_id == from_id) { - debug!( + log::debug!( "Resending from {} to {}", from_id, self.in_flight[self.in_flight.len() - 1].packet_id @@ -382,7 +485,7 @@ impl AtemSocket { } } } else { - debug!("Unable to resend: {}", from_id); + log::debug!("Unable to resend: {}", from_id); self.restart_connection().await; } } @@ -395,11 +498,11 @@ impl AtemSocket { && self .is_packet_covered_by_ack(self.next_send_packet_id, sent_packet.packet_id) { - debug!("Retransmit from timeout: {}", sent_packet.packet_id); + log::debug!("Retransmit from timeout: {}", sent_packet.packet_id); self.retransmit_from(sent_packet.packet_id).await; } else { - debug!("Packet timed out: {}", sent_packet.packet_id); + log::debug!("Packet timed out: {}", sent_packet.packet_id); self.restart_connection().await; } } @@ -412,10 +515,16 @@ impl AtemSocket { fn on_command_acknowledged(&mut self, packets: Vec) { for ack in packets { - let _ = self.atem_event_tx.send(AtemEvent::AckedCommand(ack)); + let _ = self + .atem_event_tx + .send(AtemEvent::AckedCommand(TrackingId(ack.tracking_id))); } } + fn on_connect(&mut self) { + let _ = self.atem_event_tx.send(AtemEvent::Connected); + } + fn on_disconnect(&mut self) { let _ = self.atem_event_tx.send(AtemEvent::Disconnected); } @@ -439,31 +548,10 @@ impl AtemSocket { self.retransmit_timer = Some(SystemTime::now() + Duration::from_millis(RETRANSMIT_CHECK_INTERVAL)); } -} -impl Default for AtemSocket { - fn default() -> Self { - let (atem_event_tx, _) = tokio::sync::broadcast::channel(100); + fn next_packet_tracking_id(&mut self) -> u64 { + self.next_tracking_id = self.next_tracking_id.checked_add(1).unwrap_or(1); - Self { - connection_state: ConnectionState::Closed, - reconnect_timer: None, - retransmit_timer: None, - - next_send_packet_id: 1, - session_id: 0, - - socket: None, - address: "0.0.0.0".to_string(), - port: 0, - - last_received_at: SystemTime::now(), - last_received_packed_id: 0, - in_flight: vec![], - ack_timer: None, - received_without_ack: 0, - - atem_event_tx, - } + self.next_tracking_id } } diff --git a/atem-connection-rs/src/commands/command_base.rs b/atem-connection-rs/src/commands/command_base.rs index be68974..0fb2b5c 100644 --- a/atem-connection-rs/src/commands/command_base.rs +++ b/atem-connection-rs/src/commands/command_base.rs @@ -11,7 +11,7 @@ pub trait CommandDeserializer: Send + Sync { } pub trait SerializableCommand { - fn payload(&self, version: ProtocolVersion) -> Vec; + fn payload(&self, version: &ProtocolVersion) -> Vec; } pub trait BasicWritableCommand: SerializableCommand { diff --git a/atem-connection-rs/src/commands/mix_effects.rs b/atem-connection-rs/src/commands/mix_effects.rs index cca5d48..776e8d0 100644 --- a/atem-connection-rs/src/commands/mix_effects.rs +++ b/atem-connection-rs/src/commands/mix_effects.rs @@ -1,3 +1 @@ -use super::command_base::{BasicWritableCommand, SerializableCommand}; - pub mod program_input; diff --git a/atem-connection-rs/src/commands/mix_effects/program_input.rs b/atem-connection-rs/src/commands/mix_effects/program_input.rs index b288776..044a6fd 100644 --- a/atem-connection-rs/src/commands/mix_effects/program_input.rs +++ b/atem-connection-rs/src/commands/mix_effects/program_input.rs @@ -11,7 +11,7 @@ pub struct ProgramInput { } impl SerializableCommand for ProgramInput { - fn payload(&self, _version: crate::enums::ProtocolVersion) -> Vec { + fn payload(&self, _version: &crate::enums::ProtocolVersion) -> Vec { let mut buf = vec![0; 4]; buf[..1].copy_from_slice(&self.mix_effect.to_be_bytes()); buf[2..].copy_from_slice(&self.source.to_be_bytes()); diff --git a/atem-connection-rs/src/lib.rs b/atem-connection-rs/src/lib.rs index a7c3350..b82f233 100644 --- a/atem-connection-rs/src/lib.rs +++ b/atem-connection-rs/src/lib.rs @@ -2,8 +2,6 @@ extern crate derive_new; #[macro_use] extern crate derive_getters; -#[macro_use] -extern crate thiserror; pub mod atem; pub mod atem_lib; diff --git a/atem-test/Cargo.toml b/atem-test/Cargo.toml index 64ade34..6496268 100644 --- a/atem-test/Cargo.toml +++ b/atem-test/Cargo.toml @@ -12,3 +12,4 @@ color-eyre = "0.5.11" env_logger = "0.9.0" log = "0.4.14" tokio = "1.14.0" +tokio-util = "0.7.10" diff --git a/atem-test/src/main.rs b/atem-test/src/main.rs index 1c966dd..0c1e1d2 100644 --- a/atem-test/src/main.rs +++ b/atem-test/src/main.rs @@ -1,16 +1,20 @@ -use std::{sync::Arc, time::Duration}; +use std::{ + net::{Ipv4Addr, SocketAddrV4}, + str::FromStr, + sync::Arc, + time::Duration, +}; use atem_connection_rs::{ - atem_lib::atem_socket::AtemSocket, - commands::{ - command_base::{BasicWritableCommand, SerializableCommand}, - mix_effects::program_input::ProgramInput, - }, + atem::Atem, + atem_lib::atem_socket::{AtemSocket, AtemSocketMessage}, + commands::mix_effects::program_input::ProgramInput, }; use clap::Parser; use color_eyre::Report; -use tokio::{task::yield_now, time::sleep}; +use tokio::time::sleep; +use tokio_util::sync::CancellationToken; /// ATEM Rust Library Test App #[derive(Parser, Debug)] @@ -27,44 +31,36 @@ async fn main() { setup_logging().unwrap(); - let switch_to_source_1 = ProgramInput::new(0, 1); - let switch_to_source_2 = ProgramInput::new(0, 2); + let (socket_message_tx, socket_message_rx) = + tokio::sync::mpsc::channel::(10); + let (atem_event_tx, atem_event_rx) = tokio::sync::mpsc::unbounded_channel(); + let cancel = CancellationToken::new(); + let cancel_task = cancel.clone(); - let atem = tokio::sync::RwLock::new(AtemSocket::default()); - let atem = Arc::new(atem); + let mut atem_socket = AtemSocket::new(atem_event_tx); + tokio::spawn(async move { + atem_socket.run(socket_message_rx, cancel_task).await; + }); + + let atem = Arc::new(Atem::new(socket_message_tx)); let atem_thread = atem.clone(); tokio::spawn(async move { - loop { - atem_thread.write().await.tick().await; - - yield_now().await; - } + atem_thread.run(atem_event_rx, cancel).await; }); - atem.write().await.connect(args.ip, 9910).await.ok(); - let mut tracking_id = 0; + let address = Ipv4Addr::from_str(&args.ip).unwrap(); + let socket = SocketAddrV4::new(address, 9910); + atem.connect(socket.into()).await; + loop { - tracking_id += 1; sleep(Duration::from_millis(5000)).await; - atem.write() - .await - .send_command( - &switch_to_source_1.payload(atem_connection_rs::enums::ProtocolVersion::Unknown), - switch_to_source_1.get_raw_name(), - tracking_id, - ) + log::info!("Switch to source 1"); + atem.send_commands(vec![Box::new(ProgramInput::new(0, 1))]) .await; - tracking_id += 1; sleep(Duration::from_millis(5000)).await; - atem.write() - .await - .send_command( - &switch_to_source_2.payload(atem_connection_rs::enums::ProtocolVersion::Unknown), - switch_to_source_2.get_raw_name(), - tracking_id, - ) + log::info!("Switch to source 2"); + atem.send_commands(vec![Box::new(ProgramInput::new(0, 2))]) .await; - tracking_id += 1; } }