From e3a0d7973dd71bc45b93f1f8d51c88cb8b3bbf9b Mon Sep 17 00:00:00 2001 From: Baud Date: Tue, 27 Feb 2024 10:15:44 +0000 Subject: [PATCH] fix: Don't spawn threads --- Cargo.lock | 214 +++++--- .../src/atem_lib/atem_socket.rs | 516 ++++++++++++++++-- .../src/atem_lib/atem_socket_inner.rs | 500 ----------------- atem-connection-rs/src/atem_lib/mod.rs | 1 - .../src/commands/command_base.rs | 2 +- .../src/commands/mix_effects/program_input.rs | 2 +- atem-connection-rs/src/lib.rs | 1 - atem-test/src/main.rs | 45 +- 8 files changed, 643 insertions(+), 638 deletions(-) delete mode 100644 atem-connection-rs/src/atem_lib/atem_socket_inner.rs diff --git a/Cargo.lock b/Cargo.lock index cfc8410..93553fe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -61,7 +61,7 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e28923312444cdd728e4738b3f9c9cac739500909bb3d3c94b43551b16517648" dependencies = [ - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -71,7 +71,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1cd54b81ec8d6180e24654d0b371ad22fc3dd083b6ff8ba325b72e00c87660a7" dependencies = [ "anstyle", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -103,7 +103,7 @@ version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" dependencies = [ - "hermit-abi", + "hermit-abi 0.1.19", "libc", "winapi", ] @@ -137,9 +137,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bytes" -version = "1.1.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4872d67bab6358e59559027aa3b9157c53d9358c51423c17554809a8858e0f8" +checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" [[package]] name = "cc" @@ -292,6 +292,12 @@ dependencies = [ "libc", ] +[[package]] +name = "hermit-abi" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "379dada1584ad501b383485dd706b8afb7a70fcbc7f4da7d780638a5a6124a60" + [[package]] name = "humantime" version = "2.1.0" @@ -304,15 +310,6 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683" -[[package]] -name = "instant" -version = "0.1.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" -dependencies = [ - "cfg-if", -] - [[package]] name = "lazy_static" version = "1.4.0" @@ -321,15 +318,15 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.107" +version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbe5e23404da5b4f555ef85ebed98fb4083e55a00c317800bc2a50ede9f3d219" +checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" [[package]] name = "lock_api" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "712a4d093c9976e24e7dbca41db895dabcbac38eb5f4045393d17a95bdfb1109" +checksum = "88943dd7ef4a2e5a4bfa2753aaab3013e34ce2533d1996fb18ef591e315e2b3b" dependencies = [ "scopeguard", ] @@ -361,42 +358,22 @@ dependencies = [ [[package]] name = "mio" -version = "0.7.14" +version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8067b404fe97c70829f082dec8bcf4f71225d7eaea1d8645349cb76fa06205cc" +checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09" dependencies = [ "libc", - "log", - "miow", - "ntapi", - "winapi", -] - -[[package]] -name = "miow" -version = "0.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9f1c5b025cda876f66ef43a113f91ebc9f4ccef34843000e0adf6ebbab84e21" -dependencies = [ - "winapi", -] - -[[package]] -name = "ntapi" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f6bb902e437b6d86e03cce10a7e2af662292c5dfef23b65899ea3ac9354ad44" -dependencies = [ - "winapi", + "wasi", + "windows-sys 0.48.0", ] [[package]] name = "num_cpus" -version = "1.13.0" +version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" dependencies = [ - "hermit-abi", + "hermit-abi 0.3.8", "libc", ] @@ -423,34 +400,32 @@ checksum = "2386b4ebe91c2f7f51082d4cefa145d030e33a1842a96b12e4885cc3c01f7a55" [[package]] name = "parking_lot" -version = "0.11.2" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99" +checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" dependencies = [ - "instant", "lock_api", "parking_lot_core", ] [[package]] name = "parking_lot_core" -version = "0.8.5" +version = "0.9.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d76e8e1493bcac0d2766c42737f34458f1c8c50c0d23bcb24ea953affb273216" +checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" dependencies = [ "cfg-if", - "instant", "libc", "redox_syscall", "smallvec", - "winapi", + "windows-targets 0.48.5", ] [[package]] name = "pin-project-lite" -version = "0.2.7" +version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d31d11c69a6b52a174b42bdc0c30e5e11670f90788b2c471c31c1d17d449443" +checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" [[package]] name = "proc-macro2" @@ -472,9 +447,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.2.10" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8383f39639269cde97d255a32bdb68c047337295414940c68bdd30c2e13203ff" +checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" dependencies = [ "bitflags", ] @@ -504,9 +479,9 @@ checksum = "7ef03e0a2b150c7a90d01faf6254c9c48a41e95fb2a8c2ac1c6f0d2b9aefc342" [[package]] name = "scopeguard" -version = "1.1.0" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "sharded-slab" @@ -519,18 +494,28 @@ dependencies = [ [[package]] name = "signal-hook-registry" -version = "1.4.0" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e51e73328dc4ac0c7ccbda3a494dfa03df1de2f46018127f60c693f2648455b0" +checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" dependencies = [ "libc", ] [[package]] name = "smallvec" -version = "1.7.0" +version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ecab6c735a6bb4139c0caafd0cc3635748bbb3acf4550e8138122099251f309" +checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" + +[[package]] +name = "socket2" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05ffd9c0a93b7543e062e759284fcf5f5e3b098501104bfbdde4d404db792871" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] [[package]] name = "strsim" @@ -600,33 +585,32 @@ dependencies = [ [[package]] name = "tokio" -version = "1.14.0" +version = "1.36.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70e992e41e0d2fb9f755b37446f20900f64446ef54874f40a60c78f021ac6144" +checksum = "61285f6515fa018fb2d1e46eb21223fff441ee8db5d0f1435e8ab4f5cdb80931" dependencies = [ - "autocfg", + "backtrace", "bytes", "libc", - "memchr", "mio", "num_cpus", - "once_cell", "parking_lot", "pin-project-lite", "signal-hook-registry", + "socket2", "tokio-macros", - "winapi", + "windows-sys 0.48.0", ] [[package]] name = "tokio-macros" -version = "1.6.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9efc1aba077437943f7515666aa2b882dfabfbfdf89c819ea75a8d6e9eaba5e" +checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 1.0.74", + "syn 2.0.48", ] [[package]] @@ -700,6 +684,12 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + [[package]] name = "winapi" version = "0.3.9" @@ -731,13 +721,37 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets 0.48.5", +] + [[package]] name = "windows-sys" version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets", + "windows-targets 0.52.0", +] + +[[package]] +name = "windows-targets" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", ] [[package]] @@ -746,51 +760,93 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.52.0", + "windows_aarch64_msvc 0.52.0", + "windows_i686_gnu 0.52.0", + "windows_i686_msvc 0.52.0", + "windows_x86_64_gnu 0.52.0", + "windows_x86_64_gnullvm 0.52.0", + "windows_x86_64_msvc 0.52.0", ] +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + [[package]] name = "windows_aarch64_gnullvm" version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" + [[package]] name = "windows_aarch64_msvc" version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" + [[package]] name = "windows_i686_gnu" version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" +[[package]] +name = "windows_i686_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" + [[package]] name = "windows_i686_msvc" version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" + [[package]] name = "windows_x86_64_gnu" version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" + [[package]] name = "windows_x86_64_gnullvm" version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" + [[package]] name = "windows_x86_64_msvc" version = "0.52.0" diff --git a/atem-connection-rs/src/atem_lib/atem_socket.rs b/atem-connection-rs/src/atem_lib/atem_socket.rs index e1b5d4e..32de3ab 100644 --- a/atem-connection-rs/src/atem_lib/atem_socket.rs +++ b/atem-connection-rs/src/atem_lib/atem_socket.rs @@ -1,68 +1,506 @@ -use std::{io, sync::Arc, thread::yield_now}; +use std::{ + io, + net::SocketAddr, + sync::Arc, + time::{Duration, SystemTime}, +}; -use tokio::{sync::RwLock, task::JoinHandle}; +use log::debug; +use tokio::net::UdpSocket; -use super::atem_socket_inner::AtemSocketInner; +use crate::{ + atem_lib::{atem_packet::AtemPacket, atem_util}, + commands::{command_base::DeserializedCommand, parse_commands::deserialize_commands}, +}; + +const IN_FLIGHT_TIMEOUT: u64 = 60; +const CONNECTION_TIMEOUT: u64 = 5000; +const CONNECTION_RETRY_INTERVAL: u64 = 1000; +const RETRANSMIT_CHECK_INTERVAL: u64 = 1000; +const MAX_PACKET_RETRIES: u16 = 10; +const MAX_PACKET_ID: u16 = 1 << 15; +const MAX_PACKET_PER_ACK: u16 = 16; + +// Set to max UDP packet size, for now +const MAX_PACKET_RECEIVE_SIZE: usize = 65535; +const ACK_PACKET_LENGTH: u16 = 12; + +#[derive(Clone)] +pub enum AtemEvent { + Error(String), + Info(String), + Debug(String), + Connected, + Disconnected, + ReceivedCommand(Arc), + AckedCommand(AckedPacket), +} + +#[derive(PartialEq, Clone)] +enum ConnectionState { + Closed, + SynSent, + Established, +} + +#[allow(clippy::from_over_into)] +impl Into for ConnectionState { + fn into(self) -> u8 { + match self { + ConnectionState::Closed => 0x00, + ConnectionState::SynSent => 0x01, + ConnectionState::Established => 0x02, + } + } +} + +#[derive(PartialEq)] +enum PacketFlag { + AckRequest, + NewSessionId, + IsRetransmit, + RetransmitRequest, + AckReply, +} + +impl From for u8 { + fn from(flag: PacketFlag) -> Self { + match flag { + PacketFlag::AckRequest => 0x01, + PacketFlag::NewSessionId => 0x02, + PacketFlag::IsRetransmit => 0x04, + PacketFlag::RetransmitRequest => 0x08, + PacketFlag::AckReply => 0x10, + } + } +} + +#[derive(Clone)] +struct InFlightPacket { + packet_id: u16, + tracking_id: u64, + payload: Vec, + pub last_sent: SystemTime, + pub resent: u16, +} + +#[derive(Clone)] +pub struct AckedPacket { + pub packet_id: u16, + pub tracking_id: u64, +} + +pub struct AtemSocketCommand { + payload: Vec, + raw_name: String, + tracking_id: u64, +} pub struct AtemSocket { - socket: Arc>, + connection_state: ConnectionState, + reconnect_timer: Option, + retransmit_timer: Option, - inner_socket_handle: JoinHandle<()>, + next_send_packet_id: u16, + session_id: u16, + + socket: Option, + address: String, + port: u16, + + last_received_at: SystemTime, + last_received_packed_id: u16, + in_flight: Vec, + ack_timer: Option, + received_without_ack: u16, + + atem_event_tx: tokio::sync::broadcast::Sender, +} + +enum AtemSocketReceiveError { + Closed, } #[derive(Debug, Error)] -pub enum AtemSocketConnectionError { - #[error("Socket connection error")] - IoError(#[from] io::Error), +enum AtemSocketWriteError { + #[error("Socket closed")] + Closed, + + #[error("Socket disconnected")] + Disconnected(#[from] io::Error), } impl AtemSocket { - pub fn new() -> Self { - let socket = AtemSocketInner::new(); - let socket = Arc::new(RwLock::new(socket)); + pub async fn connect(&mut self, address: String, port: u16) -> Result<(), io::Error> { + self.address = address.clone(); + self.port = port; - let socket_clone = Arc::clone(&socket); - let handle = tokio::spawn(async move { - loop { - socket_clone.write().await.tick().await; + let socket = UdpSocket::bind("0.0.0.0:0").await?; + let remote_addr = format!("{}:{}", address, port) + .parse::() + .unwrap(); + socket.connect(remote_addr).await?; + self.socket = Some(socket); - yield_now(); - } - }); + self.start_timers(); - AtemSocket { - socket, + self.next_send_packet_id = 1; + self.session_id = 0; + self.in_flight = vec![]; + debug!("Reconnect"); - inner_socket_handle: handle, - } - } - - pub async fn connect( - &mut self, - address: String, - port: u16, - ) -> Result<(), AtemSocketConnectionError> { - self.socket.write().await.connect(address, port).await?; + self.send_packet(&atem_util::COMMAND_CONNECT_HELLO).await; + self.connection_state = ConnectionState::SynSent; Ok(()) } - pub async fn disconnect(self) { - self.inner_socket_handle.abort(); - self.socket.write().await.disconnect() + pub fn disconnect(&mut self) { + self.stop_timers(); + + self.retransmit_timer = None; + self.reconnect_timer = None; + self.ack_timer = None; + self.socket = None; + + let prev_connection_state = self.connection_state.clone(); + self.connection_state = ConnectionState::Closed; + + if prev_connection_state == ConnectionState::Established { + self.on_disconnect(); + } + } + + pub async fn send_commands(&mut self, commands: Vec) { + for command in commands.into_iter() { + self.send_command(&command.payload, &command.raw_name, command.tracking_id) + .await; + } } pub async fn send_command(&mut self, payload: &[u8], raw_name: &str, tracking_id: u64) { - self.socket - .write() - .await - .send_command(payload, raw_name, tracking_id) - .await; + let packet_id = self.next_send_packet_id; + self.next_send_packet_id += 1; + if self.next_send_packet_id >= MAX_PACKET_ID { + self.next_send_packet_id = 0; + } + + let opcode = u16::from(u8::from(PacketFlag::AckRequest)) << 11; + + let mut buffer = vec![0; 20 + payload.len()]; + + // Headers + buffer[0..2].copy_from_slice(&u16::to_be_bytes(opcode | (payload.len() as u16 + 20))); + buffer[2..4].copy_from_slice(&u16::to_be_bytes(self.session_id)); + buffer[10..12].copy_from_slice(&u16::to_be_bytes(packet_id)); + + // Command + buffer[12..14].copy_from_slice(&u16::to_be_bytes(payload.len() as u16 + 8)); + buffer[16..20].copy_from_slice(raw_name.as_bytes()); + + // Body + buffer[20..20 + payload.len()].copy_from_slice(payload); + self.send_packet(&buffer).await; + + self.in_flight.push(InFlightPacket { + packet_id, + tracking_id, + payload: buffer, + last_sent: SystemTime::now(), + resent: 0, + }) + } + + pub fn subscribe_to_events(&self) -> tokio::sync::broadcast::Receiver { + self.atem_event_tx.subscribe() + } + + async fn restart_connection(&mut self) { + self.disconnect(); + self.connect(self.address.clone(), self.port).await.ok(); + } + + pub async fn tick(&mut self) { + let messages = self.receive().await.ok(); + if let Some(messages) = messages { + for message in messages.iter() { + self.recieved_packet(message).await; + } + } + if let Some(ack_time) = self.ack_timer { + if ack_time <= SystemTime::now() { + self.ack_timer = None; + self.received_without_ack = 0; + self.send_ack(self.last_received_packed_id).await; + } + } + if let Some(reconnect_time) = self.reconnect_timer { + if reconnect_time <= SystemTime::now() { + if self.last_received_at + Duration::from_millis(CONNECTION_TIMEOUT) + <= SystemTime::now() + { + debug!("{:?}", self.last_received_at); + debug!("Connection timed out, restarting"); + self.restart_connection().await; + } + self.start_reconnect_timer(); + } + } + if let Some(retransmit_time) = self.retransmit_timer { + if retransmit_time <= SystemTime::now() { + self.check_for_retransmit().await; + self.start_retransmit_timer(); + } + } + } + + async fn receive(&mut self) -> Result>, AtemSocketReceiveError> { + let mut messages: Vec> = vec![]; + let socket = self.socket.as_mut().ok_or(AtemSocketReceiveError::Closed)?; + + let mut buf = [0; MAX_PACKET_RECEIVE_SIZE]; + if let Ok((message_size, _)) = socket.try_recv_from(&mut buf) { + messages.push(buf[0..message_size].to_owned()); + } + + Ok(messages) + } + + fn is_packet_covered_by_ack(&self, ack_id: u16, packet_id: u16) -> bool { + let tolerance: u16 = MAX_PACKET_ID / 2; + let pkt_is_shortly_before = packet_id < ack_id && packet_id + tolerance > ack_id; + let pkt_is_shortly_after = packet_id > ack_id && packet_id < ack_id + tolerance; + let pkt_is_before_wrap = packet_id > ack_id + tolerance; + packet_id == ack_id + || ((pkt_is_shortly_before || pkt_is_before_wrap) && !pkt_is_shortly_after) + } + + async fn recieved_packet(&mut self, packet: &[u8]) { + debug!("Received {:x?}", packet); + + let Ok(atem_packet): Result = packet.try_into() else { + return; + }; + + if packet.len() < 12 { + debug!("Invalid packet from ATEM {:x?}", packet); + return; + } + + self.last_received_at = SystemTime::now(); + let length = u16::from_be_bytes(packet[0..2].try_into().unwrap()) & 0x07ff; + + if length as usize != packet.len() { + debug!( + "Length of message differs, expected {} got {}", + length, + packet.len() + ); + return; + } + + let flags = packet[0] >> 3; + self.session_id = u16::from_be_bytes(packet[2..4].try_into().unwrap()); + let remote_packet_id = u16::from_be_bytes(packet[10..12].try_into().unwrap()); + + if flags & u8::from(PacketFlag::NewSessionId) > 0 { + debug!("New session"); + self.connection_state = ConnectionState::Established; + self.last_received_packed_id = remote_packet_id; + self.send_ack(remote_packet_id).await; + return; + } + + if self.connection_state == ConnectionState::Established { + if flags & u8::from(PacketFlag::RetransmitRequest) > 0 { + let from_packet_id = u16::from_be_bytes(packet[6..8].try_into().unwrap()); + debug!("Retransmit request: {:x?}", from_packet_id); + + self.retransmit_from(from_packet_id).await; + } + + if flags & u8::from(PacketFlag::AckRequest) > 0 { + if remote_packet_id == (self.last_received_packed_id + 1) % MAX_PACKET_ID { + self.last_received_packed_id = remote_packet_id; + self.send_or_queue_ack().await; + + if length > 12 { + self.on_commands_received(&packet[12..]); + } + } else if self + .is_packet_covered_by_ack(self.last_received_packed_id, remote_packet_id) + { + self.send_or_queue_ack().await; + } + } + + if flags & u8::from(PacketFlag::IsRetransmit) > 0 { + debug!("ATEM retransmitted packet {:x?}", remote_packet_id); + } + + if flags & u8::from(PacketFlag::AckReply) > 0 { + let ack_packet_id = u16::from_be_bytes(packet[4..6].try_into().unwrap()); + let mut acked_commands: Vec = vec![]; + + self.in_flight = self + .in_flight + .clone() + .into_iter() + .filter(|pkt| { + if self.is_packet_covered_by_ack(ack_packet_id, pkt.packet_id) { + acked_commands.push(AckedPacket { + packet_id: pkt.packet_id, + tracking_id: pkt.tracking_id, + }); + false + } else { + true + } + }) + .collect(); + self.on_command_acknowledged(acked_commands); + } + } + } + + async fn send_packet(&self, packet: &[u8]) { + debug!("Send {:x?}", packet); + if let Some(socket) = &self.socket { + socket.send(packet).await.ok(); + } else { + debug!("Socket is not open") + } + } + + async fn send_or_queue_ack(&mut self) { + self.received_without_ack += 1; + if self.received_without_ack >= MAX_PACKET_PER_ACK { + self.received_without_ack = 0; + self.ack_timer = None; + self.send_ack(self.last_received_packed_id).await; + } else if self.ack_timer.is_none() { + self.ack_timer = Some(SystemTime::now() + Duration::from_millis(5)); + } + } + + async fn send_ack(&mut self, packet_id: u16) { + debug!("Sending ack for packet {:x?}", packet_id); + let flag: u8 = PacketFlag::AckReply.into(); + let opcode = u16::from(flag) << 11; + let mut buffer: [u8; ACK_PACKET_LENGTH as _] = [0; 12]; + buffer[0..2].copy_from_slice(&u16::to_be_bytes(opcode as u16 | ACK_PACKET_LENGTH)); + buffer[2..4].copy_from_slice(&u16::to_be_bytes(self.session_id)); + buffer[4..6].copy_from_slice(&u16::to_be_bytes(packet_id)); + self.send_packet(&buffer).await; + } + + async fn retransmit_from(&mut self, from_id: u16) { + let from_id = from_id % MAX_PACKET_ID; + + if let Some(index) = self + .in_flight + .iter() + .position(|pkt| pkt.packet_id == from_id) + { + debug!( + "Resending from {} to {}", + from_id, + self.in_flight[self.in_flight.len() - 1].packet_id + ); + for i in index..self.in_flight.len() { + let mut sent_packet = self.in_flight[i].clone(); + if sent_packet.packet_id == from_id + || !self.is_packet_covered_by_ack(from_id, sent_packet.packet_id) + { + sent_packet.last_sent = SystemTime::now(); + sent_packet.resent += 1; + + self.send_packet(&sent_packet.payload).await; + } + } + } else { + debug!("Unable to resend: {}", from_id); + self.restart_connection().await; + } + } + + async fn check_for_retransmit(&mut self) { + for sent_packet in self.in_flight.clone() { + if sent_packet.last_sent + Duration::from_millis(IN_FLIGHT_TIMEOUT) < SystemTime::now() + { + if sent_packet.resent <= MAX_PACKET_RETRIES + && self + .is_packet_covered_by_ack(self.next_send_packet_id, sent_packet.packet_id) + { + debug!("Retransmit from timeout: {}", sent_packet.packet_id); + + self.retransmit_from(sent_packet.packet_id).await; + } else { + debug!("Packet timed out: {}", sent_packet.packet_id); + self.restart_connection().await; + } + } + } + } + + fn on_commands_received(&mut self, payload: &[u8]) { + let commands = deserialize_commands(payload); + } + + fn on_command_acknowledged(&mut self, packets: Vec) { + for ack in packets { + let _ = self.atem_event_tx.send(AtemEvent::AckedCommand(ack)); + } + } + + fn on_disconnect(&mut self) { + let _ = self.atem_event_tx.send(AtemEvent::Disconnected); + } + + fn start_timers(&mut self) { + self.start_reconnect_timer(); + self.start_retransmit_timer(); + } + + fn stop_timers(&mut self) { + self.reconnect_timer = None; + self.retransmit_timer = None; + } + + fn start_reconnect_timer(&mut self) { + self.reconnect_timer = + Some(SystemTime::now() + Duration::from_millis(CONNECTION_RETRY_INTERVAL)); + } + + fn start_retransmit_timer(&mut self) { + self.retransmit_timer = + Some(SystemTime::now() + Duration::from_millis(RETRANSMIT_CHECK_INTERVAL)); } } impl Default for AtemSocket { fn default() -> Self { - Self::new() + let (atem_event_tx, _) = tokio::sync::broadcast::channel(100); + + Self { + connection_state: ConnectionState::Closed, + reconnect_timer: None, + retransmit_timer: None, + + next_send_packet_id: 1, + session_id: 0, + + socket: None, + address: "0.0.0.0".to_string(), + port: 0, + + last_received_at: SystemTime::now(), + last_received_packed_id: 0, + in_flight: vec![], + ack_timer: None, + received_without_ack: 0, + + atem_event_tx, + } } } diff --git a/atem-connection-rs/src/atem_lib/atem_socket_inner.rs b/atem-connection-rs/src/atem_lib/atem_socket_inner.rs deleted file mode 100644 index 701db76..0000000 --- a/atem-connection-rs/src/atem_lib/atem_socket_inner.rs +++ /dev/null @@ -1,500 +0,0 @@ -use std::{ - io, - net::SocketAddr, - sync::Arc, - time::{Duration, SystemTime}, -}; - -use log::debug; -use tokio::net::UdpSocket; - -use crate::{ - atem_lib::atem_util, - commands::{command_base::DeserializedCommand, parse_commands::deserialize_commands}, -}; - -const IN_FLIGHT_TIMEOUT: u64 = 60; -const CONNECTION_TIMEOUT: u64 = 5000; -const CONNECTION_RETRY_INTERVAL: u64 = 1000; -const RETRANSMIT_CHECK_INTERVAL: u64 = 1000; -const MAX_PACKET_RETRIES: u16 = 10; -const MAX_PACKET_ID: u16 = 1 << 15; -const MAX_PACKET_PER_ACK: u16 = 16; - -// Set to max UDP packet size, for now -const MAX_PACKET_RECEIVE_SIZE: usize = 65535; -const ACK_PACKET_LENGTH: u16 = 12; - -#[derive(Clone)] -pub enum AtemEvent { - Error(String), - Info(String), - Debug(String), - Connected, - Disconnected, - ReceivedCommand(Arc), - AckedCommand(AckedPacket), -} - -#[derive(PartialEq, Clone)] -enum ConnectionState { - Closed, - SynSent, - Established, -} - -#[allow(clippy::from_over_into)] -impl Into for ConnectionState { - fn into(self) -> u8 { - match self { - ConnectionState::Closed => 0x00, - ConnectionState::SynSent => 0x01, - ConnectionState::Established => 0x02, - } - } -} - -#[derive(PartialEq)] -enum PacketFlag { - AckRequest, - NewSessionId, - IsRetransmit, - RetransmitRequest, - AckReply, -} - -impl From for u8 { - fn from(flag: PacketFlag) -> Self { - match flag { - PacketFlag::AckRequest => 0x01, - PacketFlag::NewSessionId => 0x02, - PacketFlag::IsRetransmit => 0x04, - PacketFlag::RetransmitRequest => 0x08, - PacketFlag::AckReply => 0x10, - } - } -} - -#[derive(Clone)] -struct InFlightPacket { - packet_id: u16, - tracking_id: u64, - payload: Vec, - pub last_sent: SystemTime, - pub resent: u16, -} - -#[derive(Clone)] -struct AckedPacket { - packet_id: u16, - tracking_id: u64, -} - -pub struct AtemSocketCommand { - payload: Vec, - raw_name: String, - tracking_id: u64, -} - -pub struct AtemSocketInner { - connection_state: ConnectionState, - reconnect_timer: Option, - retransmit_timer: Option, - - next_send_packet_id: u16, - session_id: u16, - - socket: Option, - address: String, - port: u16, - - last_received_at: SystemTime, - last_received_packed_id: u16, - in_flight: Vec, - ack_timer: Option, - received_without_ack: u16, - - atem_event_tx: tokio::sync::broadcast::Sender, -} - -enum AtemSocketReceiveError { - Closed, -} - -#[derive(Debug, Error)] -enum AtemSocketWriteError { - #[error("Socket closed")] - Closed, - - #[error("Socket disconnected")] - Disconnected(#[from] io::Error), -} - -impl AtemSocketInner { - pub fn new() -> Self { - let (atem_event_tx, _) = tokio::sync::broadcast::channel(100); - - AtemSocketInner { - connection_state: ConnectionState::Closed, - reconnect_timer: None, - retransmit_timer: None, - - next_send_packet_id: 1, - session_id: 0, - - socket: None, - address: "0.0.0.0".to_string(), - port: 0, - - last_received_at: SystemTime::now(), - last_received_packed_id: 0, - in_flight: vec![], - ack_timer: None, - received_without_ack: 0, - - atem_event_tx, - } - } - - pub async fn connect(&mut self, address: String, port: u16) -> Result<(), io::Error> { - self.address = address.clone(); - self.port = port; - - let socket = UdpSocket::bind("0.0.0.0:0").await?; - let remote_addr = format!("{}:{}", address, port) - .parse::() - .unwrap(); - socket.connect(remote_addr).await?; - self.socket = Some(socket); - - self.start_timers(); - - self.next_send_packet_id = 1; - self.session_id = 0; - self.in_flight = vec![]; - debug!("Reconnect"); - - self.send_packet(&atem_util::COMMAND_CONNECT_HELLO).await; - self.connection_state = ConnectionState::SynSent; - - Ok(()) - } - - pub fn disconnect(&mut self) { - self.stop_timers(); - - self.retransmit_timer = None; - self.reconnect_timer = None; - self.ack_timer = None; - self.socket = None; - - let prev_connection_state = self.connection_state.clone(); - self.connection_state = ConnectionState::Closed; - - if prev_connection_state == ConnectionState::Established { - self.on_disconnect(); - } - } - - pub async fn send_commands(&mut self, commands: Vec) { - for command in commands.into_iter() { - self.send_command(&command.payload, &command.raw_name, command.tracking_id) - .await; - } - } - - pub async fn send_command(&mut self, payload: &[u8], raw_name: &str, tracking_id: u64) { - let packet_id = self.next_send_packet_id; - self.next_send_packet_id += 1; - if self.next_send_packet_id >= MAX_PACKET_ID { - self.next_send_packet_id = 0; - } - - let opcode = u16::from(u8::from(PacketFlag::AckRequest)) << 11; - - let mut buffer = vec![0; 20 + payload.len()]; - - // Headers - buffer[0..2].copy_from_slice(&u16::to_be_bytes(opcode | (payload.len() as u16 + 20))); - buffer[2..4].copy_from_slice(&u16::to_be_bytes(self.session_id)); - buffer[10..12].copy_from_slice(&u16::to_be_bytes(packet_id)); - - // Command - buffer[12..14].copy_from_slice(&u16::to_be_bytes(payload.len() as u16 + 8)); - buffer[16..20].copy_from_slice(raw_name.as_bytes()); - - // Body - buffer[20..20 + payload.len()].copy_from_slice(payload); - self.send_packet(&buffer).await; - - self.in_flight.push(InFlightPacket { - packet_id, - tracking_id, - payload: buffer, - last_sent: SystemTime::now(), - resent: 0, - }) - } - - pub fn subscribe_to_events(&self) -> tokio::sync::broadcast::Receiver { - self.atem_event_tx.subscribe() - } - - async fn restart_connection(&mut self) { - self.disconnect(); - self.connect(self.address.clone(), self.port).await.ok(); - } - - pub async fn tick(&mut self) { - let messages = self.receive().await.ok(); - if let Some(messages) = messages { - for message in messages.iter() { - self.recieved_packet(message).await; - } - } - if let Some(ack_time) = self.ack_timer { - if ack_time <= SystemTime::now() { - self.ack_timer = None; - self.received_without_ack = 0; - self.send_ack(self.last_received_packed_id).await; - } - } - if let Some(reconnect_time) = self.reconnect_timer { - if reconnect_time <= SystemTime::now() { - if self.last_received_at + Duration::from_millis(CONNECTION_TIMEOUT) - <= SystemTime::now() - { - debug!("{:?}", self.last_received_at); - debug!("Connection timed out, restarting"); - self.restart_connection().await; - } - self.start_reconnect_timer(); - } - } - if let Some(retransmit_time) = self.retransmit_timer { - if retransmit_time <= SystemTime::now() { - self.check_for_retransmit().await; - self.start_retransmit_timer(); - } - } - } - - async fn receive(&mut self) -> Result>, AtemSocketReceiveError> { - let mut messages: Vec> = vec![]; - let socket = self.socket.as_mut().ok_or(AtemSocketReceiveError::Closed)?; - - let mut buf = [0; MAX_PACKET_RECEIVE_SIZE]; - if let Ok((message_size, _)) = socket.try_recv_from(&mut buf) { - messages.push(buf[0..message_size].to_owned()); - } - - Ok(messages) - } - - fn is_packet_covered_by_ack(&self, ack_id: u16, packet_id: u16) -> bool { - let tolerance: u16 = MAX_PACKET_ID / 2; - let pkt_is_shortly_before = packet_id < ack_id && packet_id + tolerance > ack_id; - let pkt_is_shortly_after = packet_id > ack_id && packet_id < ack_id + tolerance; - let pkt_is_before_wrap = packet_id > ack_id + tolerance; - packet_id == ack_id - || ((pkt_is_shortly_before || pkt_is_before_wrap) && !pkt_is_shortly_after) - } - - async fn recieved_packet(&mut self, packet: &[u8]) { - debug!("Received {:x?}", packet); - - if packet.len() < 12 { - debug!("Invalid packet from ATEM {:x?}", packet); - return; - } - - self.last_received_at = SystemTime::now(); - let length = u16::from_be_bytes(packet[0..2].try_into().unwrap()) & 0x07ff; - - if length as usize != packet.len() { - debug!( - "Length of message differs, expected {} got {}", - length, - packet.len() - ); - return; - } - - let flags = packet[0] >> 3; - self.session_id = u16::from_be_bytes(packet[2..4].try_into().unwrap()); - let remote_packet_id = u16::from_be_bytes(packet[10..12].try_into().unwrap()); - - if flags & u8::from(PacketFlag::NewSessionId) > 0 { - debug!("New session"); - self.connection_state = ConnectionState::Established; - self.last_received_packed_id = remote_packet_id; - self.send_ack(remote_packet_id).await; - return; - } - - if self.connection_state == ConnectionState::Established { - if flags & u8::from(PacketFlag::RetransmitRequest) > 0 { - let from_packet_id = u16::from_be_bytes(packet[6..8].try_into().unwrap()); - debug!("Retransmit request: {:x?}", from_packet_id); - - self.retransmit_from(from_packet_id).await; - } - - if flags & u8::from(PacketFlag::AckRequest) > 0 { - if remote_packet_id == (self.last_received_packed_id + 1) % MAX_PACKET_ID { - self.last_received_packed_id = remote_packet_id; - self.send_or_queue_ack().await; - - if length > 12 { - self.on_commands_received(&packet[12..]); - } - } else if self - .is_packet_covered_by_ack(self.last_received_packed_id, remote_packet_id) - { - self.send_or_queue_ack().await; - } - } - - if flags & u8::from(PacketFlag::IsRetransmit) > 0 { - debug!("ATEM retransmitted packet {:x?}", remote_packet_id); - } - - if flags & u8::from(PacketFlag::AckReply) > 0 { - let ack_packet_id = u16::from_be_bytes(packet[4..6].try_into().unwrap()); - let mut acked_commands: Vec = vec![]; - - self.in_flight = self - .in_flight - .clone() - .into_iter() - .filter(|pkt| { - if self.is_packet_covered_by_ack(ack_packet_id, pkt.packet_id) { - acked_commands.push(AckedPacket { - packet_id: pkt.packet_id, - tracking_id: pkt.tracking_id, - }); - false - } else { - true - } - }) - .collect(); - self.on_command_acknowledged(acked_commands); - } - } - } - - async fn send_packet(&self, packet: &[u8]) { - debug!("Send {:x?}", packet); - if let Some(socket) = &self.socket { - socket.send(packet).await.ok(); - } else { - debug!("Socket is not open") - } - } - - async fn send_or_queue_ack(&mut self) { - self.received_without_ack += 1; - if self.received_without_ack >= MAX_PACKET_PER_ACK { - self.received_without_ack = 0; - self.ack_timer = None; - self.send_ack(self.last_received_packed_id).await; - } else if self.ack_timer.is_none() { - self.ack_timer = Some(SystemTime::now() + Duration::from_millis(5)); - } - } - - async fn send_ack(&mut self, packet_id: u16) { - debug!("Sending ack for packet {:x?}", packet_id); - let flag: u8 = PacketFlag::AckReply.into(); - let opcode = u16::from(flag) << 11; - let mut buffer: [u8; ACK_PACKET_LENGTH as _] = [0; 12]; - buffer[0..2].copy_from_slice(&u16::to_be_bytes(opcode as u16 | ACK_PACKET_LENGTH)); - buffer[2..4].copy_from_slice(&u16::to_be_bytes(self.session_id)); - buffer[4..6].copy_from_slice(&u16::to_be_bytes(packet_id)); - self.send_packet(&buffer).await; - } - - async fn retransmit_from(&mut self, from_id: u16) { - let from_id = from_id % MAX_PACKET_ID; - - if let Some(index) = self - .in_flight - .iter() - .position(|pkt| pkt.packet_id == from_id) - { - debug!( - "Resending from {} to {}", - from_id, - self.in_flight[self.in_flight.len() - 1].packet_id - ); - for i in index..self.in_flight.len() { - let mut sent_packet = self.in_flight[i].clone(); - if sent_packet.packet_id == from_id - || !self.is_packet_covered_by_ack(from_id, sent_packet.packet_id) - { - sent_packet.last_sent = SystemTime::now(); - sent_packet.resent += 1; - - self.send_packet(&sent_packet.payload).await; - } - } - } else { - debug!("Unable to resend: {}", from_id); - self.restart_connection().await; - } - } - - async fn check_for_retransmit(&mut self) { - for sent_packet in self.in_flight.clone() { - if sent_packet.last_sent + Duration::from_millis(IN_FLIGHT_TIMEOUT) < SystemTime::now() - { - if sent_packet.resent <= MAX_PACKET_RETRIES - && self - .is_packet_covered_by_ack(self.next_send_packet_id, sent_packet.packet_id) - { - debug!("Retransmit from timeout: {}", sent_packet.packet_id); - - self.retransmit_from(sent_packet.packet_id).await; - } else { - debug!("Packet timed out: {}", sent_packet.packet_id); - self.restart_connection().await; - } - } - } - } - - fn on_commands_received(&mut self, payload: &[u8]) { - let commands = deserialize_commands(payload); - } - - fn on_command_acknowledged(&mut self, packets: Vec) { - for ack in packets { - let _ = self.atem_event_tx.send(AtemEvent::AckedCommand(ack)); - } - } - - fn on_disconnect(&mut self) { - let _ = self.atem_event_tx.send(AtemEvent::Disconnected); - } - - fn start_timers(&mut self) { - self.start_reconnect_timer(); - self.start_retransmit_timer(); - } - - fn stop_timers(&mut self) { - self.reconnect_timer = None; - self.retransmit_timer = None; - } - - fn start_reconnect_timer(&mut self) { - self.reconnect_timer = - Some(SystemTime::now() + Duration::from_millis(CONNECTION_RETRY_INTERVAL)); - } - - fn start_retransmit_timer(&mut self) { - self.retransmit_timer = - Some(SystemTime::now() + Duration::from_millis(RETRANSMIT_CHECK_INTERVAL)); - } -} diff --git a/atem-connection-rs/src/atem_lib/mod.rs b/atem-connection-rs/src/atem_lib/mod.rs index 6a8af6a..9b3090b 100644 --- a/atem-connection-rs/src/atem_lib/mod.rs +++ b/atem-connection-rs/src/atem_lib/mod.rs @@ -1,4 +1,3 @@ mod atem_packet; pub mod atem_socket; -mod atem_socket_inner; pub mod atem_util; diff --git a/atem-connection-rs/src/commands/command_base.rs b/atem-connection-rs/src/commands/command_base.rs index b44bf84..be68974 100644 --- a/atem-connection-rs/src/commands/command_base.rs +++ b/atem-connection-rs/src/commands/command_base.rs @@ -15,7 +15,7 @@ pub trait SerializableCommand { } pub trait BasicWritableCommand: SerializableCommand { - fn get_raw_name() -> &'static str; + fn get_raw_name(&self) -> &'static str; fn get_minimum_version(&self) -> ProtocolVersion; } diff --git a/atem-connection-rs/src/commands/mix_effects/program_input.rs b/atem-connection-rs/src/commands/mix_effects/program_input.rs index 618c6bc..b288776 100644 --- a/atem-connection-rs/src/commands/mix_effects/program_input.rs +++ b/atem-connection-rs/src/commands/mix_effects/program_input.rs @@ -21,7 +21,7 @@ impl SerializableCommand for ProgramInput { } impl BasicWritableCommand for ProgramInput { - fn get_raw_name() -> &'static str { + fn get_raw_name(&self) -> &'static str { "CPgI" } diff --git a/atem-connection-rs/src/lib.rs b/atem-connection-rs/src/lib.rs index f0a8e48..ce14b38 100644 --- a/atem-connection-rs/src/lib.rs +++ b/atem-connection-rs/src/lib.rs @@ -2,7 +2,6 @@ extern crate derive_new; #[macro_use] extern crate derive_getters; -extern crate tokio; #[macro_use] extern crate thiserror; diff --git a/atem-test/src/main.rs b/atem-test/src/main.rs index 28b1d98..1c966dd 100644 --- a/atem-test/src/main.rs +++ b/atem-test/src/main.rs @@ -1,4 +1,4 @@ -use std::time::Duration; +use std::{sync::Arc, time::Duration}; use atem_connection_rs::{ atem_lib::atem_socket::AtemSocket, @@ -10,7 +10,7 @@ use atem_connection_rs::{ use clap::Parser; use color_eyre::Report; -use tokio::time::sleep; +use tokio::{task::yield_now, time::sleep}; /// ATEM Rust Library Test App #[derive(Parser, Debug)] @@ -30,27 +30,40 @@ async fn main() { let switch_to_source_1 = ProgramInput::new(0, 1); let switch_to_source_2 = ProgramInput::new(0, 2); - let mut atem = AtemSocket::new(); - atem.connect(args.ip, 9910).await.ok(); + let atem = tokio::sync::RwLock::new(AtemSocket::default()); + let atem = Arc::new(atem); + let atem_thread = atem.clone(); + tokio::spawn(async move { + loop { + atem_thread.write().await.tick().await; + + yield_now().await; + } + }); + atem.write().await.connect(args.ip, 9910).await.ok(); let mut tracking_id = 0; loop { tracking_id += 1; sleep(Duration::from_millis(5000)).await; - atem.send_command( - &switch_to_source_1.payload(atem_connection_rs::enums::ProtocolVersion::Unknown), - switch_to_source_1.get_raw_name(), - tracking_id, - ) - .await; + atem.write() + .await + .send_command( + &switch_to_source_1.payload(atem_connection_rs::enums::ProtocolVersion::Unknown), + switch_to_source_1.get_raw_name(), + tracking_id, + ) + .await; tracking_id += 1; sleep(Duration::from_millis(5000)).await; - atem.send_command( - &switch_to_source_2.payload(atem_connection_rs::enums::ProtocolVersion::Unknown), - switch_to_source_2.get_raw_name(), - tracking_id, - ) - .await; + atem.write() + .await + .send_command( + &switch_to_source_2.payload(atem_connection_rs::enums::ProtocolVersion::Unknown), + switch_to_source_2.get_raw_name(), + tracking_id, + ) + .await; tracking_id += 1; } }