// atem-connection-rs/atem-connection-rs/src/atem_lib/atem_socket.rs

// 577 lines
// 19 KiB
// Rust

use std::{
fmt::Display,
io,
net::SocketAddr,
sync::Arc,
time::{Duration, SystemTime},
};
use tokio::{
net::UdpSocket,
sync::{Barrier, Mutex},
task::yield_now,
};
use crate::{
atem_lib::{atem_packet::AtemPacket, atem_util},
commands::{
command_base::{BasicWritableCommand, DeserializedCommand},
parse_commands::deserialize_commands,
},
enums::ProtocolVersion,
};
use super::atem_packet::PacketFlag;
/// Milliseconds an in-flight packet may go unacknowledged before it is treated as overdue.
const IN_FLIGHT_TIMEOUT: u64 = 60;
/// Milliseconds without any received packet after which the connection is restarted.
const CONNECTION_TIMEOUT: u64 = 5000;
/// Milliseconds between successive connection-timeout checks.
const CONNECTION_RETRY_INTERVAL: u64 = 1000;
/// Milliseconds between successive scans of the in-flight queue for overdue packets.
const RETRANSMIT_CHECK_INTERVAL: u64 = 1000;
/// Largest `resent` count for which an overdue packet is still retransmitted instead of
/// restarting the connection.
const MAX_PACKET_RETRIES: u16 = 10;
/// Outgoing packet ids wrap back to zero when they reach this value.
const MAX_PACKET_ID: u16 = 1 << 15;
/// Number of received-but-unacknowledged packets that forces an immediate acknowledgement.
const MAX_PACKET_PER_ACK: u16 = 16;
// Set to max UDP packet size, for now
const MAX_PACKET_RECEIVE_SIZE: usize = 65535;
/// Size in bytes of an acknowledgement packet.
const ACK_PACKET_LENGTH: u16 = 12;
/// Requests handled by [`AtemSocket::run`], delivered over its message channel.
pub enum AtemSocketMessage {
/// Open a connection to `address`.
///
/// `result_callback` is stored until the connection attempt resolves; it receives `false`
/// when the socket cannot be opened.
Connect {
address: SocketAddr,
result_callback: tokio::sync::oneshot::Sender<bool>,
},
/// Close the socket and stop all timers.
Disconnect,
/// Send `commands` and report the tracking id assigned to each of them through
/// `tracking_ids_callback`.
SendCommands {
commands: Vec<AtemSocketCommand>,
tracking_ids_callback: tokio::sync::oneshot::Sender<TrackingIdsCallback>,
},
}
/// Reply to [`AtemSocketMessage::SendCommands`].
pub struct TrackingIdsCallback {
/// One tracking id per command that was sent, in the order the commands were given.
pub tracking_ids: Vec<TrackingId>,
/// Two-party barrier; the socket task waits on it after handing out this reply, so the
/// receiver must also wait on it to let the socket task continue.
pub barrier: Arc<Barrier>,
}
/// Events published on the socket's event channel.
#[derive(Clone)]
pub enum AtemEvent {
Error(String),
Info(String),
Debug(String),
/// Sent when a packet carrying the new-session flag has been received.
Connected,
/// Sent when an established connection is closed.
Disconnected,
ReceivedCommand(Arc<dyn DeserializedCommand>),
/// Sent once per in-flight command covered by a received acknowledgement.
AckedCommand(TrackingId),
}
/// Opaque identifier assigned to each command handed to the socket, used to match
/// acknowledgement events to the command that was sent.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TrackingId(u64);
impl Display for TrackingId {
/// Writes the wrapped numeric id.
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
/// An in-flight packet that has been acknowledged.
#[derive(Clone)]
pub struct AckedPacket {
/// Packet id the packet was sent with.
pub packet_id: u16,
/// Tracking id assigned to the command carried by the packet.
pub tracking_id: u64,
}
/// A command serialized and ready to be framed into a packet by the socket.
pub struct AtemSocketCommand {
    /// Serialized command body.
    payload: Vec<u8>,
    /// Raw command name written into the command header.
    raw_name: String,
}

impl AtemSocketCommand {
    /// Serializes `command` for the given protocol `version` and records its raw name.
    pub fn new(command: &Box<dyn BasicWritableCommand>, version: &ProtocolVersion) -> Self {
        let payload = command.payload(version);
        let raw_name = command.get_raw_name().to_string();
        Self { payload, raw_name }
    }
}
/// UDP socket state machine: connection handshake, acknowledgements and retransmission.
pub struct AtemSocket {
// Current handshake state.
connection_state: ConnectionState,
// Deadline of the next connection-timeout check; `None` while timers are stopped.
reconnect_timer: Option<SystemTime>,
// Deadline of the next scan for overdue in-flight packets; `None` while timers are stopped.
retransmit_timer: Option<SystemTime>,
// Most recently issued tracking id (incremented before each new id is handed out).
next_tracking_id: u64,
// Packet id that will be stamped on the next outgoing command.
next_send_packet_id: u16,
// Session id copied from the most recently received packet and echoed in outgoing headers.
session_id: u16,
// The open socket, or `None` while disconnected.
socket: Option<UdpSocket>,
// Address used when the connection is restarted.
address: SocketAddr,
// Time the last parseable packet arrived.
last_received_at: SystemTime,
// Id of the last packet accepted from the remote side.
last_received_packed_id: u16,
// Commands sent but not yet acknowledged, in send order.
in_flight: Vec<InFlightPacket>,
// Deadline at which a queued acknowledgement is sent; `None` when nothing is queued.
ack_timer: Option<SystemTime>,
// Number of packets received since the last acknowledgement was sent.
received_without_ack: u16,
// Channel on which `AtemEvent`s are published.
atem_event_tx: tokio::sync::mpsc::UnboundedSender<AtemEvent>,
// Callbacks of pending `Connect` requests, resolved when the attempt succeeds or fails.
connected_callbacks: Mutex<Vec<tokio::sync::oneshot::Sender<bool>>>,
}
/// Handshake state of the connection.
#[derive(Debug, Clone, PartialEq, Eq)]
enum ConnectionState {
    /// No connection is open.
    Closed,
    /// The hello packet has been sent and no session has been established yet.
    SynSent,
    /// A session has been established.
    Established,
}

/// Numeric code of each state.
///
/// Implemented as `From` so that both `u8::from(state)` and `state.into()` are available.
impl From<ConnectionState> for u8 {
    fn from(state: ConnectionState) -> u8 {
        match state {
            ConnectionState::Closed => 0x00,
            ConnectionState::SynSent => 0x01,
            ConnectionState::Established => 0x02,
        }
    }
}
/// A command packet that has been sent and is awaiting acknowledgement.
#[derive(Clone)]
struct InFlightPacket {
// Packet id written into the packet header.
packet_id: u16,
// Tracking id reported back once the packet is acknowledged.
tracking_id: u64,
// The complete serialized packet (headers and body), kept for retransmission.
payload: Vec<u8>,
// Time the packet was last sent.
pub last_sent: SystemTime,
// Number of times the packet has been retransmitted.
pub resent: u16,
}
/// Reasons `AtemSocket::receive` cannot poll for packets.
enum AtemSocketReceiveError {
/// No socket is currently open.
Closed,
}
impl AtemSocket {
pub fn new(atem_event_tx: tokio::sync::mpsc::UnboundedSender<AtemEvent>) -> Self {
Self {
connection_state: ConnectionState::Closed,
reconnect_timer: None,
retransmit_timer: None,
next_tracking_id: 0,
next_send_packet_id: 1,
session_id: 0,
socket: None,
address: "0.0.0.0:0".parse().unwrap(),
last_received_at: SystemTime::now(),
last_received_packed_id: 0,
in_flight: vec![],
ack_timer: None,
received_without_ack: 0,
atem_event_tx,
connected_callbacks: Mutex::default(),
}
}
/// Drives the socket until `cancel` is triggered.
///
/// Each iteration handles at most one pending [`AtemSocketMessage`], then runs one
/// [`tick`](Self::tick) (receive, acknowledgements, timers), then yields to the runtime.
///
/// The yield happens on every iteration: none of the polling calls here suspends the task
/// on its own, so without it an idle loop would never hand control back to the executor.
pub async fn run(
    &mut self,
    mut atem_message_rx: tokio::sync::mpsc::Receiver<AtemSocketMessage>,
    cancel: tokio_util::sync::CancellationToken,
) {
    while !cancel.is_cancelled() {
        if let Ok(msg) = atem_message_rx.try_recv() {
            match msg {
                AtemSocketMessage::Connect {
                    address,
                    result_callback,
                } => {
                    // `&mut self` gives exclusive access, so no locking is required.
                    self.connected_callbacks.get_mut().push(result_callback);
                    if self.connect(address).await.is_err() {
                        // The socket could not be opened: fail every pending connect request.
                        for callback in self.connected_callbacks.get_mut().drain(..) {
                            let _ = callback.send(false);
                        }
                    }
                }
                AtemSocketMessage::Disconnect => self.disconnect(),
                AtemSocketMessage::SendCommands {
                    commands,
                    tracking_ids_callback,
                } => {
                    let barrier = Arc::new(Barrier::new(2));
                    let tracking_ids = self.send_commands(commands).await;
                    let delivered = tracking_ids_callback
                        .send(TrackingIdsCallback {
                            tracking_ids,
                            barrier: Arc::clone(&barrier),
                        })
                        .is_ok();
                    // The requester needs time to register interest in the tracking ids
                    // before an acknowledgement for one of them can be published. The
                    // barrier included in the reply lets the requester signal that it is
                    // ready; until then ATEM communication is paused here.
                    //
                    // If the reply could not be delivered, nobody holds the other side of
                    // the barrier, so waiting would block this task forever.
                    if delivered {
                        barrier.wait().await;
                    }
                }
            }
        }
        self.tick().await;
        // Always give other tasks a chance to run, even when nothing was received.
        yield_now().await;
    }
}
/// Opens a UDP socket to `address`, resets the per-connection state and sends the hello
/// packet that starts the handshake.
///
/// The address is remembered so that a later connection restart targets the same peer.
///
/// # Errors
///
/// Returns the I/O error if the local socket cannot be bound or connected. In that case
/// the existing state (including any previously remembered address) is left untouched.
pub async fn connect(&mut self, address: SocketAddr) -> Result<(), io::Error> {
    let socket = UdpSocket::bind("0.0.0.0:0").await?;
    socket.connect(address).await?;

    // Record the target only once the socket is usable; restart_connection reads it.
    self.address = address;
    self.socket = Some(socket);
    self.start_timers();

    self.next_send_packet_id = 1;
    self.session_id = 0;
    self.in_flight = vec![];
    log::debug!("Reconnect");

    self.send_packet(&atem_util::COMMAND_CONNECT_HELLO).await;
    self.connection_state = ConnectionState::SynSent;

    Ok(())
}
/// Closes the socket, cancels every timer and marks the connection closed.
///
/// A disconnect notification is emitted only if the connection had been established.
pub fn disconnect(&mut self) {
    // Cancels the reconnect and retransmit timers.
    self.stop_timers();
    self.ack_timer = None;
    self.socket = None;

    let previous_state =
        std::mem::replace(&mut self.connection_state, ConnectionState::Closed);
    if previous_state == ConnectionState::Established {
        self.on_disconnect();
    }
}
pub async fn send_commands(&mut self, commands: Vec<AtemSocketCommand>) -> Vec<TrackingId> {
let mut tracking_ids: Vec<TrackingId> = Vec::with_capacity(commands.len());
for command in commands.into_iter() {
let tracking_id = self.next_packet_tracking_id();
self.send_command(&command.payload, &command.raw_name, tracking_id)
.await;
tracking_ids.push(TrackingId(tracking_id));
}
tracking_ids
}
/// Frames one command into a packet, sends it and queues it as in-flight.
///
/// Packet layout as written here: bytes 0..2 hold the flags and total length, 2..4 the
/// session id, 10..12 the packet id, 12..14 the command length (body + 8), 16..20 the raw
/// command name, and the body starts at byte 20.
///
/// # Panics
///
/// Panics if `raw_name` is not exactly four bytes long.
pub async fn send_command(&mut self, payload: &[u8], raw_name: &str, tracking_id: u64) {
    // Claim the current packet id and advance the counter, wrapping at MAX_PACKET_ID.
    let packet_id = self.next_send_packet_id;
    self.next_send_packet_id += 1;
    if self.next_send_packet_id >= MAX_PACKET_ID {
        self.next_send_packet_id = 0;
    }

    let flags = u16::from(u8::from(PacketFlag::AckRequest)) << 11;
    let body_len = payload.len();
    let mut packet = vec![0u8; 20 + body_len];

    // Packet header
    packet[0..2].copy_from_slice(&(flags | (body_len as u16 + 20)).to_be_bytes());
    packet[2..4].copy_from_slice(&self.session_id.to_be_bytes());
    packet[10..12].copy_from_slice(&packet_id.to_be_bytes());
    // Command header
    packet[12..14].copy_from_slice(&(body_len as u16 + 8).to_be_bytes());
    packet[16..20].copy_from_slice(raw_name.as_bytes());
    // Command body
    packet[20..].copy_from_slice(payload);

    self.send_packet(&packet).await;

    // Keep the full packet so it can be resent until it is acknowledged.
    let queued = InFlightPacket {
        packet_id,
        tracking_id,
        payload: packet,
        last_sent: SystemTime::now(),
        resent: 0,
    };
    self.in_flight.push(queued)
}
/// Tears the connection down and immediately tries to reconnect to the stored address.
///
/// A failure to reopen the socket is ignored.
async fn restart_connection(&mut self) {
    self.disconnect();
    let target = self.address;
    let _ = self.connect(target).await;
}
/// Performs one round of socket housekeeping.
///
/// In order: processes whatever `receive` returned, sends a queued acknowledgement whose
/// deadline has passed, runs the connection-timeout check, and runs the retransmit check.
async fn tick(&mut self) {
    let received = self.receive().await;
    if let Ok(packets) = received {
        for packet in &packets {
            self.recieved_packet(packet).await;
        }
    }

    // Delayed acknowledgement.
    let ack_due = matches!(self.ack_timer, Some(deadline) if deadline <= SystemTime::now());
    if ack_due {
        self.ack_timer = None;
        self.received_without_ack = 0;
        self.send_ack(self.last_received_packed_id).await;
    }

    // Connection timeout: restart when nothing has been received for too long.
    let reconnect_due =
        matches!(self.reconnect_timer, Some(deadline) if deadline <= SystemTime::now());
    if reconnect_due {
        let silent_until = self.last_received_at + Duration::from_millis(CONNECTION_TIMEOUT);
        if silent_until <= SystemTime::now() {
            log::debug!("{:?}", self.last_received_at);
            log::debug!("Connection timed out, restarting");
            self.restart_connection().await;
        }
        self.start_reconnect_timer();
    }

    // Periodic scan for overdue in-flight packets.
    let retransmit_due =
        matches!(self.retransmit_timer, Some(deadline) if deadline <= SystemTime::now());
    if retransmit_due {
        self.check_for_retransmit().await;
        self.start_retransmit_timer();
    }
}
/// Polls the socket once without waiting.
///
/// Returns the datagram that was ready, if any, as a list of zero or one messages.
///
/// # Errors
///
/// Returns `AtemSocketReceiveError::Closed` when no socket is open.
async fn receive(&mut self) -> Result<Vec<Vec<u8>>, AtemSocketReceiveError> {
    let Some(socket) = self.socket.as_mut() else {
        return Err(AtemSocketReceiveError::Closed);
    };

    let mut buf = [0u8; MAX_PACKET_RECEIVE_SIZE];
    let messages = match socket.try_recv_from(&mut buf) {
        Ok((message_size, _)) => vec![buf[..message_size].to_vec()],
        // Nothing ready (or a receive error): report no messages.
        Err(_) => Vec::new(),
    };
    Ok(messages)
}
/// Returns whether an acknowledgement for `ack_id` also covers `packet_id`.
///
/// A packet is covered when its id equals `ack_id`, when it lies less than half the id
/// space before `ack_id`, or when it lies more than half the id space after `ack_id`
/// (meaning it was sent before the id counter wrapped).
///
/// Both ids may come straight from the network, so the comparison is done in `u32` to
/// keep the additions from overflowing on out-of-range values.
fn is_packet_covered_by_ack(&self, ack_id: u16, packet_id: u16) -> bool {
    let tolerance = u32::from(MAX_PACKET_ID / 2);
    let ack = u32::from(ack_id);
    let pkt = u32::from(packet_id);

    let pkt_is_shortly_before = pkt < ack && pkt + tolerance > ack;
    let pkt_is_before_wrap = pkt > ack + tolerance;

    pkt == ack || pkt_is_shortly_before || pkt_is_before_wrap
}
/// Processes one raw datagram.
///
/// Unparseable data is ignored. Otherwise the receive time and session id are updated and:
/// * a packet with the new-session flag establishes the connection, is acknowledged
///   immediately and ends processing;
/// * while established, retransmit requests are served, packets requesting an
///   acknowledgement are acknowledged (and their body dispatched when they are the next
///   expected packet), and acknowledgements retire the matching in-flight packets.
async fn recieved_packet(&mut self, packet: &[u8]) {
    let Ok(atem_packet): Result<AtemPacket, _> = packet.try_into() else {
        return;
    };
    log::debug!("Received {:x?}", atem_packet);

    self.last_received_at = SystemTime::now();
    self.session_id = atem_packet.session_id();
    let remote_packet_id = atem_packet.remote_packet_id();

    if atem_packet.has_flag(PacketFlag::NewSessionId) {
        log::debug!("New session");
        self.connection_state = ConnectionState::Established;
        self.last_received_packed_id = remote_packet_id;
        self.send_ack(remote_packet_id).await;
        self.on_connect();
        return;
    }

    if self.connection_state != ConnectionState::Established {
        return;
    }

    if let Some(from_packet_id) = atem_packet.retransmit_request() {
        log::debug!("Retransmit request: {:x?}", from_packet_id);
        self.retransmit_from(from_packet_id).await;
    }

    if atem_packet.has_flag(PacketFlag::AckRequest) {
        // The stored id originates from the network; wrapping_add keeps an out-of-range
        // value from overflowing instead of being reduced by the modulo.
        let expected_id = self.last_received_packed_id.wrapping_add(1) % MAX_PACKET_ID;
        if remote_packet_id == expected_id {
            self.last_received_packed_id = remote_packet_id;
            self.send_or_queue_ack().await;
            // Anything longer than the 12-byte header carries commands.
            if atem_packet.length() > 12 {
                self.on_commands_received(atem_packet.body());
            }
        } else if self
            .is_packet_covered_by_ack(self.last_received_packed_id, remote_packet_id)
        {
            // Already-seen packet: acknowledge it again without reprocessing its body.
            self.send_or_queue_ack().await;
        }
    }

    if atem_packet.has_flag(PacketFlag::IsRetransmit) {
        log::debug!("ATEM retransmitted packet {:x?}", remote_packet_id);
    }

    if let Some(ack_packet_id) = atem_packet.ack_reply() {
        let mut acked_commands: Vec<AckedPacket> = Vec::new();
        // Move the queue out so it can be filtered in place while `self` is borrowed for
        // the coverage check, instead of cloning every stored packet.
        let mut still_in_flight = std::mem::take(&mut self.in_flight);
        still_in_flight.retain(|pkt| {
            let acked = self.is_packet_covered_by_ack(ack_packet_id, pkt.packet_id);
            if acked {
                acked_commands.push(AckedPacket {
                    packet_id: pkt.packet_id,
                    tracking_id: pkt.tracking_id,
                });
            }
            !acked
        });
        self.in_flight = still_in_flight;
        self.on_command_acknowledged(acked_commands);
    }
}
/// Sends a raw packet on the open socket.
///
/// Sending is best-effort: when no socket is open or the send fails, the problem is
/// logged and the packet is dropped.
async fn send_packet(&self, packet: &[u8]) {
    log::debug!("Send {:x?}", packet);
    let Some(socket) = &self.socket else {
        log::debug!("Socket is not open");
        return;
    };
    if let Err(err) = socket.send(packet).await {
        // Surface the failure instead of discarding it silently; retransmission or the
        // connection timeout will recover if the packet mattered.
        log::warn!("Failed to send packet: {}", err);
    }
}
/// Counts a received packet and either acknowledges right away or schedules a delayed
/// acknowledgement.
///
/// Once `MAX_PACKET_PER_ACK` packets have accumulated the acknowledgement is sent
/// immediately; otherwise a 5 ms timer is started unless one is already running.
async fn send_or_queue_ack(&mut self) {
    self.received_without_ack += 1;

    if self.received_without_ack < MAX_PACKET_PER_ACK {
        if self.ack_timer.is_none() {
            let deadline = SystemTime::now() + Duration::from_millis(5);
            self.ack_timer = Some(deadline);
        }
        return;
    }

    self.received_without_ack = 0;
    self.ack_timer = None;
    self.send_ack(self.last_received_packed_id).await;
}
/// Builds and sends an acknowledgement for `packet_id`.
///
/// The packet consists of the flags and length in bytes 0..2, the session id in bytes
/// 2..4 and the acknowledged packet id in bytes 4..6; the remaining bytes are zero.
async fn send_ack(&mut self, packet_id: u16) {
    log::debug!("Sending ack for packet {:x?}", packet_id);

    let opcode = u16::from(u8::from(PacketFlag::AckReply)) << 11;
    let mut ack = [0u8; ACK_PACKET_LENGTH as usize];
    ack[..2].copy_from_slice(&(opcode | ACK_PACKET_LENGTH).to_be_bytes());
    ack[2..4].copy_from_slice(&self.session_id.to_be_bytes());
    ack[4..6].copy_from_slice(&packet_id.to_be_bytes());

    self.send_packet(&ack).await;
}
/// Resends the in-flight packet with id `from_id` and the packets queued after it.
///
/// Each resent packet has its `last_sent` time refreshed and its `resent` counter
/// incremented in the in-flight queue itself, so that retry limits and timeouts see the
/// retransmission. If `from_id` is not in flight, the connection is restarted.
async fn retransmit_from(&mut self, from_id: u16) {
    let from_id = from_id % MAX_PACKET_ID;

    let Some(index) = self
        .in_flight
        .iter()
        .position(|pkt| pkt.packet_id == from_id)
    else {
        log::debug!("Unable to resend: {}", from_id);
        self.restart_connection().await;
        return;
    };

    log::debug!(
        "Resending from {} to {}",
        from_id,
        // `position` found an entry, so the queue is non-empty.
        self.in_flight[self.in_flight.len() - 1].packet_id
    );

    for i in index..self.in_flight.len() {
        let packet_id = self.in_flight[i].packet_id;
        if packet_id != from_id && self.is_packet_covered_by_ack(from_id, packet_id) {
            continue;
        }

        // Update the queued entry itself rather than a copy of it.
        let entry = &mut self.in_flight[i];
        entry.last_sent = SystemTime::now();
        entry.resent = entry.resent.saturating_add(1);

        self.send_packet(&self.in_flight[i].payload).await;
    }
}
/// Looks for the first in-flight packet that has been unacknowledged for longer than
/// `IN_FLIGHT_TIMEOUT` and reacts to it.
///
/// If that packet is still within its retry budget it is retransmitted together with the
/// packets queued after it; otherwise the connection is restarted. Only the first overdue
/// packet is acted on: both reactions already deal with every later packet, and both
/// change the in-flight queue, so continuing over the old contents would resend packets
/// twice or restart the connection repeatedly.
async fn check_for_retransmit(&mut self) {
    let now = SystemTime::now();
    let timeout = Duration::from_millis(IN_FLIGHT_TIMEOUT);

    let overdue = self
        .in_flight
        .iter()
        .find(|pkt| pkt.last_sent + timeout < now)
        .map(|pkt| (pkt.packet_id, pkt.resent));
    let Some((packet_id, resent)) = overdue else {
        return;
    };

    if resent <= MAX_PACKET_RETRIES
        && self.is_packet_covered_by_ack(self.next_send_packet_id, packet_id)
    {
        log::debug!("Retransmit from timeout: {}", packet_id);
        self.retransmit_from(packet_id).await;
    } else {
        log::debug!("Packet timed out: {}", packet_id);
        self.restart_connection().await;
    }
}
/// Deserializes the commands contained in a received packet body.
// NOTE(review): the deserialized commands are currently discarded; presumably they should
// be published as `AtemEvent::ReceivedCommand` on `atem_event_tx` — confirm against the
// return type of `deserialize_commands`.
fn on_commands_received(&mut self, payload: &[u8]) {
let commands = deserialize_commands(payload);
}
/// Publishes one `AckedCommand` event per acknowledged packet, in the given order.
///
/// Events that cannot be delivered because the receiver is gone are dropped.
fn on_command_acknowledged(&mut self, packets: Vec<AckedPacket>) {
    packets
        .into_iter()
        .map(|acked| AtemEvent::AckedCommand(TrackingId(acked.tracking_id)))
        .for_each(|event| {
            let _ = self.atem_event_tx.send(event);
        });
}
/// Announces an established connection.
///
/// Publishes `AtemEvent::Connected` and resolves every pending connect request with
/// `true` (a failed attempt resolves them with `false`).
fn on_connect(&mut self) {
    let _ = self.atem_event_tx.send(AtemEvent::Connected);
    // This runs inside the async socket task, where a blocking lock would panic.
    // `&mut self` already guarantees exclusive access, so no lock is needed.
    for callback in self.connected_callbacks.get_mut().drain(..) {
        let _ = callback.send(true);
    }
}
/// Publishes `AtemEvent::Disconnected`; a missing receiver is ignored.
fn on_disconnect(&mut self) {
    self.atem_event_tx.send(AtemEvent::Disconnected).ok();
}
/// Arms both periodic timers: the connection-timeout check and the retransmit check.
fn start_timers(&mut self) {
self.start_reconnect_timer();
self.start_retransmit_timer();
}
/// Cancels the connection-timeout and retransmit timers.
fn stop_timers(&mut self) {
    self.reconnect_timer.take();
    self.retransmit_timer.take();
}
/// Schedules the next connection-timeout check `CONNECTION_RETRY_INTERVAL` ms from now.
fn start_reconnect_timer(&mut self) {
    let interval = Duration::from_millis(CONNECTION_RETRY_INTERVAL);
    let deadline = SystemTime::now() + interval;
    self.reconnect_timer = Some(deadline);
}
/// Schedules the next retransmit check `RETRANSMIT_CHECK_INTERVAL` ms from now.
fn start_retransmit_timer(&mut self) {
    let interval = Duration::from_millis(RETRANSMIT_CHECK_INTERVAL);
    let deadline = SystemTime::now() + interval;
    self.retransmit_timer = Some(deadline);
}
/// Advances the tracking-id counter and returns the new value.
///
/// When the counter would overflow it restarts at 1.
fn next_packet_tracking_id(&mut self) -> u64 {
    let issued = match self.next_tracking_id.checked_add(1) {
        Some(id) => id,
        None => 1,
    };
    self.next_tracking_id = issued;
    issued
}
}