577 lines
19 KiB
Rust
577 lines
19 KiB
Rust
use std::{
|
|
fmt::Display,
|
|
io,
|
|
net::SocketAddr,
|
|
sync::Arc,
|
|
time::{Duration, SystemTime},
|
|
};
|
|
|
|
use tokio::{
|
|
net::UdpSocket,
|
|
sync::{Barrier, Mutex},
|
|
task::yield_now,
|
|
};
|
|
|
|
use crate::{
|
|
atem_lib::{atem_packet::AtemPacket, atem_util},
|
|
commands::{
|
|
command_base::{BasicWritableCommand, DeserializedCommand},
|
|
parse_commands::deserialize_commands,
|
|
},
|
|
enums::ProtocolVersion,
|
|
};
|
|
|
|
use super::atem_packet::PacketFlag;
|
|
|
|
/// Time in milliseconds a sent packet may stay unacknowledged before the
/// retransmit check treats it as timed out.
const IN_FLIGHT_TIMEOUT: u64 = 60;
/// Time in milliseconds without any received packet after which the
/// connection is restarted.
const CONNECTION_TIMEOUT: u64 = 5000;
/// Interval in milliseconds between connection-timeout checks.
const CONNECTION_RETRY_INTERVAL: u64 = 1000;
/// Interval in milliseconds between scans of the in-flight packets.
const RETRANSMIT_CHECK_INTERVAL: u64 = 1000;
/// Upper bound on a packet's retransmission count; a timed-out packet that has
/// been resent more often than this restarts the connection instead.
const MAX_PACKET_RETRIES: u16 = 10;
/// Outgoing packet ids wrap back to 0 when they reach this value (15-bit id space).
const MAX_PACKET_ID: u16 = 1 << 15;
/// Number of received packets after which an ack is sent immediately instead
/// of waiting for the ack timer.
const MAX_PACKET_PER_ACK: u16 = 16;

// Set to max UDP packet size, for now
const MAX_PACKET_RECEIVE_SIZE: usize = 65535;
/// Size in bytes of an ack packet; also written as its length field.
const ACK_PACKET_LENGTH: u16 = 12;
|
|
|
|
/// Control messages consumed by the socket task's `run` loop.
pub enum AtemSocketMessage {
    /// Open a connection to `address`.
    ///
    /// `result_callback` is resolved with `false` if the local UDP socket cannot be
    /// bound or connected; otherwise it is resolved once the ATEM starts a new session.
    Connect {
        address: SocketAddr,
        result_callback: tokio::sync::oneshot::Sender<bool>,
    },
    /// Close the socket and stop all timers.
    Disconnect,
    /// Send `commands`. The tracking ids assigned to them are returned through
    /// `tracking_ids_callback`, together with a barrier the receiver waits on to
    /// let the socket task continue.
    SendCommands {
        commands: Vec<AtemSocketCommand>,
        tracking_ids_callback: tokio::sync::oneshot::Sender<TrackingIdsCallback>,
    },
}
|
|
|
|
/// Reply to [`AtemSocketMessage::SendCommands`].
pub struct TrackingIdsCallback {
    /// One tracking id per sent command, in the order the commands were given.
    pub tracking_ids: Vec<TrackingId>,
    /// Two-party barrier: the socket task waits on it until the receiver has also
    /// waited, signalling that it is ready to observe acks for `tracking_ids`.
    pub barrier: Arc<Barrier>,
}
|
|
|
|
/// Events published by the socket on its unbounded event channel.
#[derive(Clone)]
pub enum AtemEvent {
    Error(String),
    Info(String),
    Debug(String),
    /// The ATEM started a new session (a packet carrying the `NewSessionId` flag arrived).
    Connected,
    /// A previously established connection was closed.
    Disconnected,
    /// A command received from the ATEM.
    // NOTE(review): not emitted by the socket code in this file — confirm where it is produced.
    ReceivedCommand(Arc<dyn DeserializedCommand>),
    /// The packet carrying the command with this tracking id was acknowledged by the ATEM.
    AckedCommand(TrackingId),
}
|
|
|
|
/// Opaque id assigned to each outgoing command; reported back in
/// [`AtemEvent::AckedCommand`] once the ATEM acknowledges it.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TrackingId(u64);
|
|
|
|
impl Display for TrackingId {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
write!(f, "{}", self.0)
|
|
}
|
|
}
|
|
|
|
/// A sent packet that the ATEM has acknowledged.
#[derive(Clone)]
pub struct AckedPacket {
    /// Local packet id the ack covered.
    pub packet_id: u16,
    /// Tracking id of the command carried by that packet.
    pub tracking_id: u64,
}
|
|
|
|
/// A command serialized and ready to be framed into a packet.
pub struct AtemSocketCommand {
    /// Serialized command body.
    payload: Vec<u8>,
    /// Command name; written into the 4-byte name field of the command header.
    raw_name: String,
}
|
|
|
|
impl AtemSocketCommand {
|
|
pub fn new(command: &Box<dyn BasicWritableCommand>, version: &ProtocolVersion) -> Self {
|
|
Self {
|
|
payload: command.payload(version),
|
|
raw_name: command.get_raw_name().to_string(),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Client side of the ATEM UDP protocol.
///
/// Holds the socket, packet numbering, acknowledgement bookkeeping, and the timers
/// that drive acks, retransmits and reconnects.
pub struct AtemSocket {
    connection_state: ConnectionState,
    /// Deadline of the next connection-timeout check (`None` when stopped).
    reconnect_timer: Option<SystemTime>,
    /// Deadline of the next scan for packets needing retransmission (`None` when stopped).
    retransmit_timer: Option<SystemTime>,

    /// Last tracking id handed out to an outgoing command.
    next_tracking_id: u64,

    /// Id to use for the next outgoing packet.
    next_send_packet_id: u16,
    /// Session id taken from the most recently received packet.
    session_id: u16,

    socket: Option<UdpSocket>,
    /// Address that `restart_connection` reconnects to.
    address: SocketAddr,

    /// When the last packet was received; compared against the connection timeout.
    last_received_at: SystemTime,
    /// Id of the last remote packet accepted in sequence (this is the id acked back).
    last_received_packed_id: u16,
    /// Sent packets still awaiting an ack, in send order.
    in_flight: Vec<InFlightPacket>,
    /// Deadline for flushing a queued ack (`None` when no ack is pending).
    ack_timer: Option<SystemTime>,
    /// Packets received since the last ack was sent.
    received_without_ack: u16,

    /// Channel on which connection and ack events are published.
    atem_event_tx: tokio::sync::mpsc::UnboundedSender<AtemEvent>,
    /// Pending `Connect` requests waiting for the outcome of the connection attempt.
    connected_callbacks: Mutex<Vec<tokio::sync::oneshot::Sender<bool>>>,
}
|
|
|
|
/// Handshake state of an [`AtemSocket`].
///
/// A fieldless enum, so it derives `Copy`, `Eq` and `Debug` in addition to the
/// original `Clone`/`PartialEq`; state snapshots need no explicit `.clone()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ConnectionState {
    /// No socket open.
    Closed,
    /// Hello packet sent, waiting for the ATEM to start a session.
    SynSent,
    /// A session is active.
    Established,
}
|
|
|
|
#[allow(clippy::from_over_into)]
|
|
impl Into<u8> for ConnectionState {
|
|
fn into(self) -> u8 {
|
|
match self {
|
|
ConnectionState::Closed => 0x00,
|
|
ConnectionState::SynSent => 0x01,
|
|
ConnectionState::Established => 0x02,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// A sent packet that has not been acknowledged yet.
#[derive(Clone)]
struct InFlightPacket {
    /// Local packet id written into the packet header.
    packet_id: u16,
    /// Tracking id reported when the packet is acked.
    tracking_id: u64,
    /// Complete encoded packet, kept for retransmission.
    payload: Vec<u8>,
    /// When the packet was last (re)sent.
    pub last_sent: SystemTime,
    /// Number of times the packet has been retransmitted.
    pub resent: u16,
}
|
|
|
|
/// Reasons a receive attempt could not be made.
enum AtemSocketReceiveError {
    /// No UDP socket is currently open.
    Closed,
}
|
|
|
|
impl AtemSocket {
|
|
pub fn new(atem_event_tx: tokio::sync::mpsc::UnboundedSender<AtemEvent>) -> Self {
|
|
Self {
|
|
connection_state: ConnectionState::Closed,
|
|
reconnect_timer: None,
|
|
retransmit_timer: None,
|
|
|
|
next_tracking_id: 0,
|
|
|
|
next_send_packet_id: 1,
|
|
session_id: 0,
|
|
|
|
socket: None,
|
|
address: "0.0.0.0:0".parse().unwrap(),
|
|
|
|
last_received_at: SystemTime::now(),
|
|
last_received_packed_id: 0,
|
|
in_flight: vec![],
|
|
ack_timer: None,
|
|
received_without_ack: 0,
|
|
|
|
atem_event_tx,
|
|
connected_callbacks: Mutex::default(),
|
|
}
|
|
}
|
|
|
|
pub async fn run(
|
|
&mut self,
|
|
mut atem_message_rx: tokio::sync::mpsc::Receiver<AtemSocketMessage>,
|
|
cancel: tokio_util::sync::CancellationToken,
|
|
) {
|
|
while !cancel.is_cancelled() {
|
|
if let Ok(msg) = atem_message_rx.try_recv() {
|
|
match msg {
|
|
AtemSocketMessage::Connect {
|
|
address,
|
|
result_callback,
|
|
} => {
|
|
{
|
|
let mut connected_callbacks = self.connected_callbacks.lock().await;
|
|
connected_callbacks.push(result_callback);
|
|
}
|
|
if self.connect(address).await.is_err() {
|
|
let mut connected_callbacks = self.connected_callbacks.lock().await;
|
|
for callback in connected_callbacks.drain(0..) {
|
|
let _ = callback.send(false);
|
|
}
|
|
}
|
|
}
|
|
AtemSocketMessage::Disconnect => self.disconnect(),
|
|
AtemSocketMessage::SendCommands {
|
|
commands,
|
|
tracking_ids_callback,
|
|
} => {
|
|
let barrier = Arc::new(Barrier::new(2));
|
|
tracking_ids_callback
|
|
.send(TrackingIdsCallback {
|
|
tracking_ids: self.send_commands(commands).await,
|
|
barrier: barrier.clone(),
|
|
})
|
|
.ok();
|
|
|
|
// Let's play the game "Synchronisation Shenanigans"!
|
|
// So, we are sending tracking Ids to the sender of this message, the sender will then wait
|
|
// for each of these tracking Ids to be ACK'd by the ATEM. However, the sender will need to
|
|
// do ✨ some form of shenanigans ✨ in order to be ready to receive tracking Ids. So we send
|
|
// them a barrier as part of the callback so that they can tell us that they are ready for
|
|
// us to continue with ATEM communication, at which point we may immediately inform them of a
|
|
// received tracking Id matching one included in this callback.
|
|
//
|
|
// Now, if we were being 🚩 Real Proper Software Developers 🚩 we'd probably expect the receiver
|
|
// of the callback to do clever things so that if a tracking Id is received immediately, they
|
|
// then wait for something that wants that tracking Id on their side, rather than blocking this
|
|
// task so that the caller can do ✨ shenanigans ✨. However, that sounds far too clever and too
|
|
// much like 🚩 Real Actual Work 🚩 so instead we've chosen to do this and hope that whichever
|
|
// actor we're waiting on doesn't take _too_ long to do ✨ shenanigans ✨ before signalling that
|
|
// they are ready. If they do, I suggest finding whoever wrote that code and bonking them 🔨.
|
|
barrier.wait().await;
|
|
}
|
|
}
|
|
|
|
yield_now().await;
|
|
}
|
|
|
|
self.tick().await;
|
|
}
|
|
}
|
|
|
|
pub async fn connect(&mut self, address: SocketAddr) -> Result<(), io::Error> {
|
|
let socket = UdpSocket::bind("0.0.0.0:0").await?;
|
|
socket.connect(address).await?;
|
|
self.socket = Some(socket);
|
|
|
|
self.start_timers();
|
|
|
|
self.next_send_packet_id = 1;
|
|
self.session_id = 0;
|
|
self.in_flight = vec![];
|
|
log::debug!("Reconnect");
|
|
|
|
self.send_packet(&atem_util::COMMAND_CONNECT_HELLO).await;
|
|
self.connection_state = ConnectionState::SynSent;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub fn disconnect(&mut self) {
|
|
self.stop_timers();
|
|
|
|
self.retransmit_timer = None;
|
|
self.reconnect_timer = None;
|
|
self.ack_timer = None;
|
|
self.socket = None;
|
|
|
|
let prev_connection_state = self.connection_state.clone();
|
|
self.connection_state = ConnectionState::Closed;
|
|
|
|
if prev_connection_state == ConnectionState::Established {
|
|
self.on_disconnect();
|
|
}
|
|
}
|
|
|
|
pub async fn send_commands(&mut self, commands: Vec<AtemSocketCommand>) -> Vec<TrackingId> {
|
|
let mut tracking_ids: Vec<TrackingId> = Vec::with_capacity(commands.len());
|
|
for command in commands.into_iter() {
|
|
let tracking_id = self.next_packet_tracking_id();
|
|
self.send_command(&command.payload, &command.raw_name, tracking_id)
|
|
.await;
|
|
tracking_ids.push(TrackingId(tracking_id));
|
|
}
|
|
|
|
tracking_ids
|
|
}
|
|
|
|
pub async fn send_command(&mut self, payload: &[u8], raw_name: &str, tracking_id: u64) {
|
|
let packet_id = self.next_send_packet_id;
|
|
self.next_send_packet_id += 1;
|
|
if self.next_send_packet_id >= MAX_PACKET_ID {
|
|
self.next_send_packet_id = 0;
|
|
}
|
|
|
|
let opcode = u16::from(u8::from(PacketFlag::AckRequest)) << 11;
|
|
|
|
let mut buffer = vec![0; 20 + payload.len()];
|
|
|
|
// Headers
|
|
buffer[0..2].copy_from_slice(&u16::to_be_bytes(opcode | (payload.len() as u16 + 20)));
|
|
buffer[2..4].copy_from_slice(&u16::to_be_bytes(self.session_id));
|
|
buffer[10..12].copy_from_slice(&u16::to_be_bytes(packet_id));
|
|
|
|
// Command
|
|
buffer[12..14].copy_from_slice(&u16::to_be_bytes(payload.len() as u16 + 8));
|
|
buffer[16..20].copy_from_slice(raw_name.as_bytes());
|
|
|
|
// Body
|
|
buffer[20..20 + payload.len()].copy_from_slice(payload);
|
|
self.send_packet(&buffer).await;
|
|
|
|
self.in_flight.push(InFlightPacket {
|
|
packet_id,
|
|
tracking_id,
|
|
payload: buffer,
|
|
last_sent: SystemTime::now(),
|
|
resent: 0,
|
|
})
|
|
}
|
|
|
|
async fn restart_connection(&mut self) {
|
|
self.disconnect();
|
|
self.connect(self.address.clone()).await.ok();
|
|
}
|
|
|
|
async fn tick(&mut self) {
|
|
let messages = self.receive().await.ok();
|
|
if let Some(messages) = messages {
|
|
for message in messages.iter() {
|
|
self.recieved_packet(message).await;
|
|
}
|
|
}
|
|
if let Some(ack_time) = self.ack_timer {
|
|
if ack_time <= SystemTime::now() {
|
|
self.ack_timer = None;
|
|
self.received_without_ack = 0;
|
|
self.send_ack(self.last_received_packed_id).await;
|
|
}
|
|
}
|
|
if let Some(reconnect_time) = self.reconnect_timer {
|
|
if reconnect_time <= SystemTime::now() {
|
|
if self.last_received_at + Duration::from_millis(CONNECTION_TIMEOUT)
|
|
<= SystemTime::now()
|
|
{
|
|
log::debug!("{:?}", self.last_received_at);
|
|
log::debug!("Connection timed out, restarting");
|
|
self.restart_connection().await;
|
|
}
|
|
self.start_reconnect_timer();
|
|
}
|
|
}
|
|
if let Some(retransmit_time) = self.retransmit_timer {
|
|
if retransmit_time <= SystemTime::now() {
|
|
self.check_for_retransmit().await;
|
|
self.start_retransmit_timer();
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn receive(&mut self) -> Result<Vec<Vec<u8>>, AtemSocketReceiveError> {
|
|
let mut messages: Vec<Vec<u8>> = vec![];
|
|
let socket = self.socket.as_mut().ok_or(AtemSocketReceiveError::Closed)?;
|
|
|
|
let mut buf = [0; MAX_PACKET_RECEIVE_SIZE];
|
|
if let Ok((message_size, _)) = socket.try_recv_from(&mut buf) {
|
|
messages.push(buf[0..message_size].to_owned());
|
|
}
|
|
|
|
Ok(messages)
|
|
}
|
|
|
|
fn is_packet_covered_by_ack(&self, ack_id: u16, packet_id: u16) -> bool {
|
|
let tolerance: u16 = MAX_PACKET_ID / 2;
|
|
let pkt_is_shortly_before = packet_id < ack_id && packet_id + tolerance > ack_id;
|
|
let pkt_is_shortly_after = packet_id > ack_id && packet_id < ack_id + tolerance;
|
|
let pkt_is_before_wrap = packet_id > ack_id + tolerance;
|
|
packet_id == ack_id
|
|
|| ((pkt_is_shortly_before || pkt_is_before_wrap) && !pkt_is_shortly_after)
|
|
}
|
|
|
|
async fn recieved_packet(&mut self, packet: &[u8]) {
|
|
let Ok(atem_packet): Result<AtemPacket, _> = packet.try_into() else {
|
|
return;
|
|
};
|
|
|
|
log::debug!("Received {:x?}", atem_packet);
|
|
|
|
self.last_received_at = SystemTime::now();
|
|
|
|
self.session_id = atem_packet.session_id();
|
|
let remote_packet_id = atem_packet.remote_packet_id();
|
|
|
|
if atem_packet.has_flag(PacketFlag::NewSessionId) {
|
|
log::debug!("New session");
|
|
self.connection_state = ConnectionState::Established;
|
|
self.last_received_packed_id = remote_packet_id;
|
|
self.send_ack(remote_packet_id).await;
|
|
self.on_connect();
|
|
return;
|
|
}
|
|
|
|
if self.connection_state == ConnectionState::Established {
|
|
if let Some(from_packet_id) = atem_packet.retransmit_request() {
|
|
log::debug!("Retransmit request: {:x?}", from_packet_id);
|
|
|
|
self.retransmit_from(from_packet_id).await;
|
|
}
|
|
|
|
if atem_packet.has_flag(PacketFlag::AckRequest) {
|
|
if remote_packet_id == (self.last_received_packed_id + 1) % MAX_PACKET_ID {
|
|
self.last_received_packed_id = remote_packet_id;
|
|
self.send_or_queue_ack().await;
|
|
|
|
if atem_packet.length() > 12 {
|
|
self.on_commands_received(atem_packet.body());
|
|
}
|
|
} else if self
|
|
.is_packet_covered_by_ack(self.last_received_packed_id, remote_packet_id)
|
|
{
|
|
self.send_or_queue_ack().await;
|
|
}
|
|
}
|
|
|
|
if atem_packet.has_flag(PacketFlag::IsRetransmit) {
|
|
log::debug!("ATEM retransmitted packet {:x?}", remote_packet_id);
|
|
}
|
|
|
|
if let Some(ack_packet_id) = atem_packet.ack_reply() {
|
|
let mut acked_commands: Vec<AckedPacket> = vec![];
|
|
|
|
self.in_flight = self
|
|
.in_flight
|
|
.clone()
|
|
.into_iter()
|
|
.filter(|pkt| {
|
|
if self.is_packet_covered_by_ack(ack_packet_id, pkt.packet_id) {
|
|
acked_commands.push(AckedPacket {
|
|
packet_id: pkt.packet_id,
|
|
tracking_id: pkt.tracking_id,
|
|
});
|
|
false
|
|
} else {
|
|
true
|
|
}
|
|
})
|
|
.collect();
|
|
self.on_command_acknowledged(acked_commands);
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn send_packet(&self, packet: &[u8]) {
|
|
log::debug!("Send {:x?}", packet);
|
|
if let Some(socket) = &self.socket {
|
|
socket.send(packet).await.ok();
|
|
} else {
|
|
log::debug!("Socket is not open")
|
|
}
|
|
}
|
|
|
|
async fn send_or_queue_ack(&mut self) {
|
|
self.received_without_ack += 1;
|
|
if self.received_without_ack >= MAX_PACKET_PER_ACK {
|
|
self.received_without_ack = 0;
|
|
self.ack_timer = None;
|
|
self.send_ack(self.last_received_packed_id).await;
|
|
} else if self.ack_timer.is_none() {
|
|
self.ack_timer = Some(SystemTime::now() + Duration::from_millis(5));
|
|
}
|
|
}
|
|
|
|
async fn send_ack(&mut self, packet_id: u16) {
|
|
log::debug!("Sending ack for packet {:x?}", packet_id);
|
|
let flag: u8 = PacketFlag::AckReply.into();
|
|
let opcode = u16::from(flag) << 11;
|
|
let mut buffer: [u8; ACK_PACKET_LENGTH as _] = [0; 12];
|
|
buffer[0..2].copy_from_slice(&u16::to_be_bytes(opcode as u16 | ACK_PACKET_LENGTH));
|
|
buffer[2..4].copy_from_slice(&u16::to_be_bytes(self.session_id));
|
|
buffer[4..6].copy_from_slice(&u16::to_be_bytes(packet_id));
|
|
self.send_packet(&buffer).await;
|
|
}
|
|
|
|
async fn retransmit_from(&mut self, from_id: u16) {
|
|
let from_id = from_id % MAX_PACKET_ID;
|
|
|
|
if let Some(index) = self
|
|
.in_flight
|
|
.iter()
|
|
.position(|pkt| pkt.packet_id == from_id)
|
|
{
|
|
log::debug!(
|
|
"Resending from {} to {}",
|
|
from_id,
|
|
self.in_flight[self.in_flight.len() - 1].packet_id
|
|
);
|
|
for i in index..self.in_flight.len() {
|
|
let mut sent_packet = self.in_flight[i].clone();
|
|
if sent_packet.packet_id == from_id
|
|
|| !self.is_packet_covered_by_ack(from_id, sent_packet.packet_id)
|
|
{
|
|
sent_packet.last_sent = SystemTime::now();
|
|
sent_packet.resent += 1;
|
|
|
|
self.send_packet(&sent_packet.payload).await;
|
|
}
|
|
}
|
|
} else {
|
|
log::debug!("Unable to resend: {}", from_id);
|
|
self.restart_connection().await;
|
|
}
|
|
}
|
|
|
|
async fn check_for_retransmit(&mut self) {
|
|
for sent_packet in self.in_flight.clone() {
|
|
if sent_packet.last_sent + Duration::from_millis(IN_FLIGHT_TIMEOUT) < SystemTime::now()
|
|
{
|
|
if sent_packet.resent <= MAX_PACKET_RETRIES
|
|
&& self
|
|
.is_packet_covered_by_ack(self.next_send_packet_id, sent_packet.packet_id)
|
|
{
|
|
log::debug!("Retransmit from timeout: {}", sent_packet.packet_id);
|
|
|
|
self.retransmit_from(sent_packet.packet_id).await;
|
|
} else {
|
|
log::debug!("Packet timed out: {}", sent_packet.packet_id);
|
|
self.restart_connection().await;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
fn on_commands_received(&mut self, payload: &[u8]) {
|
|
let commands = deserialize_commands(payload);
|
|
}
|
|
|
|
fn on_command_acknowledged(&mut self, packets: Vec<AckedPacket>) {
|
|
for ack in packets {
|
|
let _ = self
|
|
.atem_event_tx
|
|
.send(AtemEvent::AckedCommand(TrackingId(ack.tracking_id)));
|
|
}
|
|
}
|
|
|
|
fn on_connect(&mut self) {
|
|
let _ = self.atem_event_tx.send(AtemEvent::Connected);
|
|
let mut connected_callbacks = self.connected_callbacks.blocking_lock();
|
|
for callback in connected_callbacks.drain(0..) {
|
|
let _ = callback.send(false);
|
|
}
|
|
}
|
|
|
|
fn on_disconnect(&mut self) {
|
|
let _ = self.atem_event_tx.send(AtemEvent::Disconnected);
|
|
}
|
|
|
|
fn start_timers(&mut self) {
|
|
self.start_reconnect_timer();
|
|
self.start_retransmit_timer();
|
|
}
|
|
|
|
fn stop_timers(&mut self) {
|
|
self.reconnect_timer = None;
|
|
self.retransmit_timer = None;
|
|
}
|
|
|
|
fn start_reconnect_timer(&mut self) {
|
|
self.reconnect_timer =
|
|
Some(SystemTime::now() + Duration::from_millis(CONNECTION_RETRY_INTERVAL));
|
|
}
|
|
|
|
fn start_retransmit_timer(&mut self) {
|
|
self.retransmit_timer =
|
|
Some(SystemTime::now() + Duration::from_millis(RETRANSMIT_CHECK_INTERVAL));
|
|
}
|
|
|
|
fn next_packet_tracking_id(&mut self) -> u64 {
|
|
self.next_tracking_id = self.next_tracking_id.checked_add(1).unwrap_or(1);
|
|
|
|
self.next_tracking_id
|
|
}
|
|
}
|