feat: Atem wrapper

This commit is contained in:
Baud 2024-03-01 17:11:57 +00:00
parent 5db8843ce7
commit 4a41d1f5d7
11 changed files with 365 additions and 179 deletions

48
Cargo.lock generated
View File

@ -81,8 +81,8 @@ dependencies = [
"derive-getters", "derive-getters",
"derive-new", "derive-new",
"log", "log",
"thiserror",
"tokio", "tokio",
"tokio-util",
] ]
[[package]] [[package]]
@ -95,6 +95,7 @@ dependencies = [
"env_logger", "env_logger",
"log", "log",
"tokio", "tokio",
"tokio-util",
] ]
[[package]] [[package]]
@ -271,6 +272,18 @@ dependencies = [
"once_cell", "once_cell",
] ]
[[package]]
name = "futures-core"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d"
[[package]]
name = "futures-sink"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5"
[[package]] [[package]]
name = "gimli" name = "gimli"
version = "0.25.0" version = "0.25.0"
@ -554,26 +567,6 @@ dependencies = [
"winapi-util", "winapi-util",
] ]
[[package]]
name = "thiserror"
version = "1.0.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "854babe52e4df1653706b98fcfc05843010039b406875930a70e4d9644e5c417"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa32fd3f627f367fe16f893e2597ae3c05020f8bba2666a4e6ea73d377e5714b"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.74",
]
[[package]] [[package]]
name = "thread_local" name = "thread_local"
version = "1.1.3" version = "1.1.3"
@ -613,6 +606,19 @@ dependencies = [
"syn 2.0.48", "syn 2.0.48",
] ]
[[package]]
name = "tokio-util"
version = "0.7.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5419f34732d9eb6ee4c3578b7989078579b7f039cbbb9ca2c4da015749371e15"
dependencies = [
"bytes",
"futures-core",
"futures-sink",
"pin-project-lite",
"tokio",
]
[[package]] [[package]]
name = "tracing" name = "tracing"
version = "0.1.29" version = "0.1.29"

View File

@ -7,5 +7,5 @@ edition = "2021"
derive-getters = "0.2.0" derive-getters = "0.2.0"
derive-new = "0.6.0" derive-new = "0.6.0"
log = "0.4.14" log = "0.4.14"
thiserror = "1.0.30"
tokio = { version = "1.13.0", features = ["full"] } tokio = { version = "1.13.0", features = ["full"] }
tokio-util = "0.7.10"

View File

@ -1,9 +1,112 @@
use crate::{commands::command_base::DeserializedCommand, state::AtemState}; use std::{collections::HashMap, net::SocketAddr, sync::Arc};
pub struct AtemOptions { use tokio::sync::{mpsc::error::TryRecvError, Semaphore};
address: Option<String>, use tokio_util::sync::CancellationToken;
port: Option<u16>,
debug_buffers: bool, use crate::{
disable_multi_threaded: bool, atem_lib::atem_socket::{AtemEvent, AtemSocketCommand, AtemSocketMessage, TrackingId},
child_process_timeout: Option<u64>, commands::command_base::BasicWritableCommand,
};
pub struct Atem {
waiting_semaphores: tokio::sync::RwLock<HashMap<TrackingId, Arc<Semaphore>>>,
socket_message_tx: tokio::sync::mpsc::Sender<AtemSocketMessage>,
}
impl Atem {
pub fn new(socket_message_tx: tokio::sync::mpsc::Sender<AtemSocketMessage>) -> Self {
Self {
waiting_semaphores: tokio::sync::RwLock::new(HashMap::new()),
socket_message_tx,
}
}
pub async fn connect(&self, address: SocketAddr) {
let (callback_tx, callback_rx) = tokio::sync::oneshot::channel();
self.socket_message_tx
.send(AtemSocketMessage::Connect {
address,
result_callback: callback_tx,
})
.await
.unwrap();
callback_rx.await.unwrap().unwrap();
}
pub async fn run(
&self,
mut atem_event_rx: tokio::sync::mpsc::UnboundedReceiver<AtemEvent>,
cancel: CancellationToken,
) {
while !cancel.is_cancelled() {
match atem_event_rx.try_recv() {
Ok(event) => match event {
AtemEvent::Error(_) => todo!(),
AtemEvent::Info(_) => todo!(),
AtemEvent::Debug(_) => todo!(),
AtemEvent::Connected => {
log::info!("Atem connected");
}
AtemEvent::Disconnected => todo!(),
AtemEvent::ReceivedCommand(_) => todo!(),
AtemEvent::AckedCommand(tracking_id) => {
log::debug!("Received tracking Id {tracking_id}");
if let Some(semaphore) =
self.waiting_semaphores.read().await.get(&tracking_id)
{
semaphore.add_permits(1);
} else {
log::warn!("Received tracking Id {tracking_id} with no-one waiting for it to be resolved.")
}
}
},
Err(TryRecvError::Empty) => {}
Err(TryRecvError::Disconnected) => {
log::info!("ATEM event channel has closed, exiting event loop.");
cancel.cancel();
}
}
}
}
pub async fn send_commands(&self, commands: Vec<Box<dyn BasicWritableCommand>>) {
let (callback_tx, callback_rx) = tokio::sync::oneshot::channel();
self.socket_message_tx
.send(AtemSocketMessage::SendCommands {
commands: commands
.iter()
.map(|command| {
AtemSocketCommand::new(command, &crate::enums::ProtocolVersion::Unknown)
})
.collect(),
tracking_ids_callback: callback_tx,
})
.await
.unwrap();
let callback = callback_rx.await.unwrap();
let semaphore = Arc::new(Semaphore::new(0));
for tracking_id in callback.tracking_ids.iter() {
self.waiting_semaphores
.write()
.await
.insert(tracking_id.clone(), semaphore.clone());
}
callback.barrier.wait().await;
// If this fails then the semaphore has been closed which is a darn shame but at that point
// the best we can do it continue on our merry way in life and remain oblivious to
// the fire raging in other realms.
semaphore
.acquire_many(callback.tracking_ids.len() as u32)
.await
.ok();
for tracking_id in callback.tracking_ids.iter() {
self.waiting_semaphores.write().await.remove(tracking_id);
}
}
} }

View File

@ -40,10 +40,6 @@ impl<'packet_buffer> AtemPacket<'packet_buffer> {
self.length self.length
} }
pub fn flags(&self) -> u8 {
self.flags
}
pub fn session_id(&self) -> u16 { pub fn session_id(&self) -> u16 {
self.session_id self.session_id
} }
@ -80,7 +76,7 @@ impl<'packet_buffer> TryFrom<&'packet_buffer [u8]> for AtemPacket<'packet_buffer
))); )));
} }
let length = u16::from_be_bytes(buffer[0..2].try_into().unwrap()) & 0x07ff; let length = u16::from_be_bytes([buffer[0], buffer[1]]) & 0x07ff;
if length as usize != buffer.len() { if length as usize != buffer.len() {
return Err(AtemPacketErr::LengthDiffers(format!( return Err(AtemPacketErr::LengthDiffers(format!(
"Length of message differs, expected {} got {}", "Length of message differs, expected {} got {}",
@ -90,20 +86,20 @@ impl<'packet_buffer> TryFrom<&'packet_buffer [u8]> for AtemPacket<'packet_buffer
} }
let flags = buffer[0] >> 3; let flags = buffer[0] >> 3;
let session_id = u16::from_be_bytes(buffer[2..4].try_into().unwrap()); let session_id = u16::from_be_bytes([buffer[2], buffer[3]]);
let remote_packet_id = u16::from_be_bytes(buffer[10..12].try_into().unwrap()); let remote_packet_id = u16::from_be_bytes([buffer[10], buffer[11]]);
let body = &buffer[12..]; let body = &buffer[12..];
let retransmit_requested_from_packet_id = let retransmit_requested_from_packet_id =
if flags & u8::from(PacketFlag::RetransmitRequest) > 0 { if flags & u8::from(PacketFlag::RetransmitRequest) > 0 {
Some(u16::from_be_bytes(buffer[6..8].try_into().unwrap())) Some(u16::from_be_bytes([buffer[6], buffer[7]]))
} else { } else {
None None
}; };
let ack_reply = if flags & u8::from(PacketFlag::AckReply) > 0 { let ack_reply = if flags & u8::from(PacketFlag::AckReply) > 0 {
Some(u16::from_be_bytes(buffer[4..6].try_into().unwrap())) Some(u16::from_be_bytes([buffer[4], buffer[5]]))
} else { } else {
None None
}; };

View File

@ -1,16 +1,20 @@
use std::{ use std::{
fmt::Display,
io, io,
net::SocketAddr, net::SocketAddr,
sync::Arc, sync::Arc,
time::{Duration, SystemTime}, time::{Duration, SystemTime},
}; };
use log::debug; use tokio::{net::UdpSocket, sync::Barrier, task::yield_now};
use tokio::net::UdpSocket;
use crate::{ use crate::{
atem_lib::{atem_packet::AtemPacket, atem_util}, atem_lib::{atem_packet::AtemPacket, atem_util},
commands::{command_base::DeserializedCommand, parse_commands::deserialize_commands}, commands::{
command_base::{BasicWritableCommand, DeserializedCommand},
parse_commands::deserialize_commands,
},
enums::ProtocolVersion,
}; };
use super::atem_packet::PacketFlag; use super::atem_packet::PacketFlag;
@ -27,6 +31,23 @@ const MAX_PACKET_PER_ACK: u16 = 16;
const MAX_PACKET_RECEIVE_SIZE: usize = 65535; const MAX_PACKET_RECEIVE_SIZE: usize = 65535;
const ACK_PACKET_LENGTH: u16 = 12; const ACK_PACKET_LENGTH: u16 = 12;
pub enum AtemSocketMessage {
Connect {
address: SocketAddr,
result_callback: tokio::sync::oneshot::Sender<Result<(), io::Error>>,
},
Disconnect,
SendCommands {
commands: Vec<AtemSocketCommand>,
tracking_ids_callback: tokio::sync::oneshot::Sender<TrackingIdsCallback>,
},
}
pub struct TrackingIdsCallback {
pub tracking_ids: Vec<TrackingId>,
pub barrier: Arc<Barrier>,
}
#[derive(Clone)] #[derive(Clone)]
pub enum AtemEvent { pub enum AtemEvent {
Error(String), Error(String),
@ -35,7 +56,58 @@ pub enum AtemEvent {
Connected, Connected,
Disconnected, Disconnected,
ReceivedCommand(Arc<dyn DeserializedCommand>), ReceivedCommand(Arc<dyn DeserializedCommand>),
AckedCommand(AckedPacket), AckedCommand(TrackingId),
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TrackingId(u64);
impl Display for TrackingId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
#[derive(Clone)]
pub struct AckedPacket {
pub packet_id: u16,
pub tracking_id: u64,
}
pub struct AtemSocketCommand {
payload: Vec<u8>,
raw_name: String,
}
impl AtemSocketCommand {
pub fn new(command: &Box<dyn BasicWritableCommand>, version: &ProtocolVersion) -> Self {
Self {
payload: command.payload(version),
raw_name: command.get_raw_name().to_string(),
}
}
}
pub struct AtemSocket {
connection_state: ConnectionState,
reconnect_timer: Option<SystemTime>,
retransmit_timer: Option<SystemTime>,
next_tracking_id: u64,
next_send_packet_id: u16,
session_id: u16,
socket: Option<UdpSocket>,
address: SocketAddr,
last_received_at: SystemTime,
last_received_packed_id: u16,
in_flight: Vec<InFlightPacket>,
ack_timer: Option<SystemTime>,
received_without_ack: u16,
atem_event_tx: tokio::sync::mpsc::UnboundedSender<AtemEvent>,
} }
#[derive(PartialEq, Clone)] #[derive(PartialEq, Clone)]
@ -65,62 +137,91 @@ struct InFlightPacket {
pub resent: u16, pub resent: u16,
} }
#[derive(Clone)]
pub struct AckedPacket {
pub packet_id: u16,
pub tracking_id: u64,
}
pub struct AtemSocketCommand {
payload: Vec<u8>,
raw_name: String,
tracking_id: u64,
}
pub struct AtemSocket {
connection_state: ConnectionState,
reconnect_timer: Option<SystemTime>,
retransmit_timer: Option<SystemTime>,
next_send_packet_id: u16,
session_id: u16,
socket: Option<UdpSocket>,
address: String,
port: u16,
last_received_at: SystemTime,
last_received_packed_id: u16,
in_flight: Vec<InFlightPacket>,
ack_timer: Option<SystemTime>,
received_without_ack: u16,
atem_event_tx: tokio::sync::broadcast::Sender<AtemEvent>,
}
enum AtemSocketReceiveError { enum AtemSocketReceiveError {
Closed, Closed,
} }
#[derive(Debug, Error)] impl AtemSocket {
enum AtemSocketWriteError { pub fn new(atem_event_tx: tokio::sync::mpsc::UnboundedSender<AtemEvent>) -> Self {
#[error("Socket closed")] Self {
Closed, connection_state: ConnectionState::Closed,
reconnect_timer: None,
retransmit_timer: None,
#[error("Socket disconnected")] next_tracking_id: 0,
Disconnected(#[from] io::Error),
next_send_packet_id: 1,
session_id: 0,
socket: None,
address: "0.0.0.0:0".parse().unwrap(),
last_received_at: SystemTime::now(),
last_received_packed_id: 0,
in_flight: vec![],
ack_timer: None,
received_without_ack: 0,
atem_event_tx,
}
} }
impl AtemSocket { pub async fn run(
pub async fn connect(&mut self, address: String, port: u16) -> Result<(), io::Error> { &mut self,
self.address = address.clone(); mut atem_message_rx: tokio::sync::mpsc::Receiver<AtemSocketMessage>,
self.port = port; cancel: tokio_util::sync::CancellationToken,
) {
while !cancel.is_cancelled() {
if let Ok(msg) = atem_message_rx.try_recv() {
match msg {
AtemSocketMessage::Connect {
address,
result_callback,
} => {
result_callback.send(self.connect(address).await).ok();
}
AtemSocketMessage::Disconnect => self.disconnect(),
AtemSocketMessage::SendCommands {
commands,
tracking_ids_callback,
} => {
let barrier = Arc::new(Barrier::new(2));
tracking_ids_callback
.send(TrackingIdsCallback {
tracking_ids: self.send_commands(commands).await,
barrier: barrier.clone(),
})
.ok();
// Let's play the game "Synchronisation Shenanigans"!
// So, we are sending tracking Ids to the sender of this message, the sender will then wait
// for each of these tracking Ids to be ACK'd by the ATEM. However, the sender will need to
// do ✨ some form of shenanigans ✨ in order to be ready to receive tracking Ids. So we send
// them a barrier as part of the callback so that they can tell us that they are ready for
// us to continue with ATEM communication, at which point we may immediately inform them of a
// received tracking Id matching one included in this callback.
//
// Now, if we were being 🚩 Real Proper Software Developers 🚩 we'd probably expect the receiver
// of the callback to do clever things so that if a tracking Id is received immediately, they
// then wait for something that wants that tracking Id on their side, rather than blocking this
// task so that the caller can do ✨ shenanigans ✨. However, that sounds far too clever and too
// much like 🚩 Real Actual Work 🚩 so instead we've chosen to do this and hope that whichever
// actor we're waiting on doesn't take _too_ long to do ✨ shenanigans ✨ before signalling that
// they are ready. If they do, I suggest finding whoever wrote that code and bonking them 🔨.
barrier.wait().await;
}
}
yield_now().await;
}
self.tick().await;
}
}
pub async fn connect(&mut self, address: SocketAddr) -> Result<(), io::Error> {
let socket = UdpSocket::bind("0.0.0.0:0").await?; let socket = UdpSocket::bind("0.0.0.0:0").await?;
let remote_addr = format!("{}:{}", address, port) socket.connect(address).await?;
.parse::<SocketAddr>()
.unwrap();
socket.connect(remote_addr).await?;
self.socket = Some(socket); self.socket = Some(socket);
self.start_timers(); self.start_timers();
@ -128,7 +229,7 @@ impl AtemSocket {
self.next_send_packet_id = 1; self.next_send_packet_id = 1;
self.session_id = 0; self.session_id = 0;
self.in_flight = vec![]; self.in_flight = vec![];
debug!("Reconnect"); log::debug!("Reconnect");
self.send_packet(&atem_util::COMMAND_CONNECT_HELLO).await; self.send_packet(&atem_util::COMMAND_CONNECT_HELLO).await;
self.connection_state = ConnectionState::SynSent; self.connection_state = ConnectionState::SynSent;
@ -152,11 +253,16 @@ impl AtemSocket {
} }
} }
pub async fn send_commands(&mut self, commands: Vec<AtemSocketCommand>) { pub async fn send_commands(&mut self, commands: Vec<AtemSocketCommand>) -> Vec<TrackingId> {
let mut tracking_ids: Vec<TrackingId> = Vec::with_capacity(commands.len());
for command in commands.into_iter() { for command in commands.into_iter() {
self.send_command(&command.payload, &command.raw_name, command.tracking_id) let tracking_id = self.next_packet_tracking_id();
self.send_command(&command.payload, &command.raw_name, tracking_id)
.await; .await;
tracking_ids.push(TrackingId(tracking_id));
} }
tracking_ids
} }
pub async fn send_command(&mut self, payload: &[u8], raw_name: &str, tracking_id: u64) { pub async fn send_command(&mut self, payload: &[u8], raw_name: &str, tracking_id: u64) {
@ -192,16 +298,12 @@ impl AtemSocket {
}) })
} }
pub fn subscribe_to_events(&self) -> tokio::sync::broadcast::Receiver<AtemEvent> {
self.atem_event_tx.subscribe()
}
async fn restart_connection(&mut self) { async fn restart_connection(&mut self) {
self.disconnect(); self.disconnect();
self.connect(self.address.clone(), self.port).await.ok(); self.connect(self.address.clone()).await.ok();
} }
pub async fn tick(&mut self) { async fn tick(&mut self) {
let messages = self.receive().await.ok(); let messages = self.receive().await.ok();
if let Some(messages) = messages { if let Some(messages) = messages {
for message in messages.iter() { for message in messages.iter() {
@ -220,8 +322,8 @@ impl AtemSocket {
if self.last_received_at + Duration::from_millis(CONNECTION_TIMEOUT) if self.last_received_at + Duration::from_millis(CONNECTION_TIMEOUT)
<= SystemTime::now() <= SystemTime::now()
{ {
debug!("{:?}", self.last_received_at); log::debug!("{:?}", self.last_received_at);
debug!("Connection timed out, restarting"); log::debug!("Connection timed out, restarting");
self.restart_connection().await; self.restart_connection().await;
} }
self.start_reconnect_timer(); self.start_reconnect_timer();
@ -261,7 +363,7 @@ impl AtemSocket {
return; return;
}; };
debug!("Received {:x?}", atem_packet); log::debug!("Received {:x?}", atem_packet);
self.last_received_at = SystemTime::now(); self.last_received_at = SystemTime::now();
@ -269,16 +371,17 @@ impl AtemSocket {
let remote_packet_id = atem_packet.remote_packet_id(); let remote_packet_id = atem_packet.remote_packet_id();
if atem_packet.has_flag(PacketFlag::NewSessionId) { if atem_packet.has_flag(PacketFlag::NewSessionId) {
debug!("New session"); log::debug!("New session");
self.connection_state = ConnectionState::Established; self.connection_state = ConnectionState::Established;
self.last_received_packed_id = remote_packet_id; self.last_received_packed_id = remote_packet_id;
self.send_ack(remote_packet_id).await; self.send_ack(remote_packet_id).await;
self.on_connect();
return; return;
} }
if self.connection_state == ConnectionState::Established { if self.connection_state == ConnectionState::Established {
if let Some(from_packet_id) = atem_packet.retransmit_request() { if let Some(from_packet_id) = atem_packet.retransmit_request() {
debug!("Retransmit request: {:x?}", from_packet_id); log::debug!("Retransmit request: {:x?}", from_packet_id);
self.retransmit_from(from_packet_id).await; self.retransmit_from(from_packet_id).await;
} }
@ -299,7 +402,7 @@ impl AtemSocket {
} }
if atem_packet.has_flag(PacketFlag::IsRetransmit) { if atem_packet.has_flag(PacketFlag::IsRetransmit) {
debug!("ATEM retransmitted packet {:x?}", remote_packet_id); log::debug!("ATEM retransmitted packet {:x?}", remote_packet_id);
} }
if let Some(ack_packet_id) = atem_packet.ack_reply() { if let Some(ack_packet_id) = atem_packet.ack_reply() {
@ -327,11 +430,11 @@ impl AtemSocket {
} }
async fn send_packet(&self, packet: &[u8]) { async fn send_packet(&self, packet: &[u8]) {
debug!("Send {:x?}", packet); log::debug!("Send {:x?}", packet);
if let Some(socket) = &self.socket { if let Some(socket) = &self.socket {
socket.send(packet).await.ok(); socket.send(packet).await.ok();
} else { } else {
debug!("Socket is not open") log::debug!("Socket is not open")
} }
} }
@ -347,7 +450,7 @@ impl AtemSocket {
} }
async fn send_ack(&mut self, packet_id: u16) { async fn send_ack(&mut self, packet_id: u16) {
debug!("Sending ack for packet {:x?}", packet_id); log::debug!("Sending ack for packet {:x?}", packet_id);
let flag: u8 = PacketFlag::AckReply.into(); let flag: u8 = PacketFlag::AckReply.into();
let opcode = u16::from(flag) << 11; let opcode = u16::from(flag) << 11;
let mut buffer: [u8; ACK_PACKET_LENGTH as _] = [0; 12]; let mut buffer: [u8; ACK_PACKET_LENGTH as _] = [0; 12];
@ -365,7 +468,7 @@ impl AtemSocket {
.iter() .iter()
.position(|pkt| pkt.packet_id == from_id) .position(|pkt| pkt.packet_id == from_id)
{ {
debug!( log::debug!(
"Resending from {} to {}", "Resending from {} to {}",
from_id, from_id,
self.in_flight[self.in_flight.len() - 1].packet_id self.in_flight[self.in_flight.len() - 1].packet_id
@ -382,7 +485,7 @@ impl AtemSocket {
} }
} }
} else { } else {
debug!("Unable to resend: {}", from_id); log::debug!("Unable to resend: {}", from_id);
self.restart_connection().await; self.restart_connection().await;
} }
} }
@ -395,11 +498,11 @@ impl AtemSocket {
&& self && self
.is_packet_covered_by_ack(self.next_send_packet_id, sent_packet.packet_id) .is_packet_covered_by_ack(self.next_send_packet_id, sent_packet.packet_id)
{ {
debug!("Retransmit from timeout: {}", sent_packet.packet_id); log::debug!("Retransmit from timeout: {}", sent_packet.packet_id);
self.retransmit_from(sent_packet.packet_id).await; self.retransmit_from(sent_packet.packet_id).await;
} else { } else {
debug!("Packet timed out: {}", sent_packet.packet_id); log::debug!("Packet timed out: {}", sent_packet.packet_id);
self.restart_connection().await; self.restart_connection().await;
} }
} }
@ -412,10 +515,16 @@ impl AtemSocket {
fn on_command_acknowledged(&mut self, packets: Vec<AckedPacket>) { fn on_command_acknowledged(&mut self, packets: Vec<AckedPacket>) {
for ack in packets { for ack in packets {
let _ = self.atem_event_tx.send(AtemEvent::AckedCommand(ack)); let _ = self
.atem_event_tx
.send(AtemEvent::AckedCommand(TrackingId(ack.tracking_id)));
} }
} }
fn on_connect(&mut self) {
let _ = self.atem_event_tx.send(AtemEvent::Connected);
}
fn on_disconnect(&mut self) { fn on_disconnect(&mut self) {
let _ = self.atem_event_tx.send(AtemEvent::Disconnected); let _ = self.atem_event_tx.send(AtemEvent::Disconnected);
} }
@ -439,31 +548,10 @@ impl AtemSocket {
self.retransmit_timer = self.retransmit_timer =
Some(SystemTime::now() + Duration::from_millis(RETRANSMIT_CHECK_INTERVAL)); Some(SystemTime::now() + Duration::from_millis(RETRANSMIT_CHECK_INTERVAL));
} }
}
impl Default for AtemSocket { fn next_packet_tracking_id(&mut self) -> u64 {
fn default() -> Self { self.next_tracking_id = self.next_tracking_id.checked_add(1).unwrap_or(1);
let (atem_event_tx, _) = tokio::sync::broadcast::channel(100);
Self { self.next_tracking_id
connection_state: ConnectionState::Closed,
reconnect_timer: None,
retransmit_timer: None,
next_send_packet_id: 1,
session_id: 0,
socket: None,
address: "0.0.0.0".to_string(),
port: 0,
last_received_at: SystemTime::now(),
last_received_packed_id: 0,
in_flight: vec![],
ack_timer: None,
received_without_ack: 0,
atem_event_tx,
}
} }
} }

View File

@ -11,7 +11,7 @@ pub trait CommandDeserializer: Send + Sync {
} }
pub trait SerializableCommand { pub trait SerializableCommand {
fn payload(&self, version: ProtocolVersion) -> Vec<u8>; fn payload(&self, version: &ProtocolVersion) -> Vec<u8>;
} }
pub trait BasicWritableCommand: SerializableCommand { pub trait BasicWritableCommand: SerializableCommand {

View File

@ -1,3 +1 @@
use super::command_base::{BasicWritableCommand, SerializableCommand};
pub mod program_input; pub mod program_input;

View File

@ -11,7 +11,7 @@ pub struct ProgramInput {
} }
impl SerializableCommand for ProgramInput { impl SerializableCommand for ProgramInput {
fn payload(&self, _version: crate::enums::ProtocolVersion) -> Vec<u8> { fn payload(&self, _version: &crate::enums::ProtocolVersion) -> Vec<u8> {
let mut buf = vec![0; 4]; let mut buf = vec![0; 4];
buf[..1].copy_from_slice(&self.mix_effect.to_be_bytes()); buf[..1].copy_from_slice(&self.mix_effect.to_be_bytes());
buf[2..].copy_from_slice(&self.source.to_be_bytes()); buf[2..].copy_from_slice(&self.source.to_be_bytes());

View File

@ -2,8 +2,6 @@
extern crate derive_new; extern crate derive_new;
#[macro_use] #[macro_use]
extern crate derive_getters; extern crate derive_getters;
#[macro_use]
extern crate thiserror;
pub mod atem; pub mod atem;
pub mod atem_lib; pub mod atem_lib;

View File

@ -12,3 +12,4 @@ color-eyre = "0.5.11"
env_logger = "0.9.0" env_logger = "0.9.0"
log = "0.4.14" log = "0.4.14"
tokio = "1.14.0" tokio = "1.14.0"
tokio-util = "0.7.10"

View File

@ -1,16 +1,20 @@
use std::{sync::Arc, time::Duration}; use std::{
net::{Ipv4Addr, SocketAddrV4},
str::FromStr,
sync::Arc,
time::Duration,
};
use atem_connection_rs::{ use atem_connection_rs::{
atem_lib::atem_socket::AtemSocket, atem::Atem,
commands::{ atem_lib::atem_socket::{AtemSocket, AtemSocketMessage},
command_base::{BasicWritableCommand, SerializableCommand}, commands::mix_effects::program_input::ProgramInput,
mix_effects::program_input::ProgramInput,
},
}; };
use clap::Parser; use clap::Parser;
use color_eyre::Report; use color_eyre::Report;
use tokio::{task::yield_now, time::sleep}; use tokio::time::sleep;
use tokio_util::sync::CancellationToken;
/// ATEM Rust Library Test App /// ATEM Rust Library Test App
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
@ -27,44 +31,36 @@ async fn main() {
setup_logging().unwrap(); setup_logging().unwrap();
let switch_to_source_1 = ProgramInput::new(0, 1); let (socket_message_tx, socket_message_rx) =
let switch_to_source_2 = ProgramInput::new(0, 2); tokio::sync::mpsc::channel::<AtemSocketMessage>(10);
let (atem_event_tx, atem_event_rx) = tokio::sync::mpsc::unbounded_channel();
let cancel = CancellationToken::new();
let cancel_task = cancel.clone();
let atem = tokio::sync::RwLock::new(AtemSocket::default()); let mut atem_socket = AtemSocket::new(atem_event_tx);
let atem = Arc::new(atem); tokio::spawn(async move {
atem_socket.run(socket_message_rx, cancel_task).await;
});
let atem = Arc::new(Atem::new(socket_message_tx));
let atem_thread = atem.clone(); let atem_thread = atem.clone();
tokio::spawn(async move { tokio::spawn(async move {
loop { atem_thread.run(atem_event_rx, cancel).await;
atem_thread.write().await.tick().await;
yield_now().await;
}
}); });
atem.write().await.connect(args.ip, 9910).await.ok();
let mut tracking_id = 0; let address = Ipv4Addr::from_str(&args.ip).unwrap();
let socket = SocketAddrV4::new(address, 9910);
atem.connect(socket.into()).await;
loop { loop {
tracking_id += 1;
sleep(Duration::from_millis(5000)).await; sleep(Duration::from_millis(5000)).await;
atem.write() log::info!("Switch to source 1");
.await atem.send_commands(vec![Box::new(ProgramInput::new(0, 1))])
.send_command(
&switch_to_source_1.payload(atem_connection_rs::enums::ProtocolVersion::Unknown),
switch_to_source_1.get_raw_name(),
tracking_id,
)
.await; .await;
tracking_id += 1;
sleep(Duration::from_millis(5000)).await; sleep(Duration::from_millis(5000)).await;
atem.write() log::info!("Switch to source 2");
.await atem.send_commands(vec![Box::new(ProgramInput::new(0, 2))])
.send_command(
&switch_to_source_2.payload(atem_connection_rs::enums::ProtocolVersion::Unknown),
switch_to_source_2.get_raw_name(),
tracking_id,
)
.await; .await;
tracking_id += 1;
} }
} }