From a419a78a0ac4b0f0c84230e803ecc97fc2e3a708 Mon Sep 17 00:00:00 2001 From: Sam Willcocks Date: Tue, 10 Jun 2025 10:19:09 +0100 Subject: [PATCH] Make AtemPacket operate on `&[u8] `, add field parsing --- .../src/atem_lib/atem_packet.rs | 170 ++++++++++++------ .../src/atem_lib/atem_socket_inner.rs | 72 +++----- atem-connection-rs/src/atem_lib/atem_util.rs | 4 - atem-connection-rs/src/atem_lib/mod.rs | 3 +- 4 files changed, 142 insertions(+), 107 deletions(-) delete mode 100644 atem-connection-rs/src/atem_lib/atem_util.rs diff --git a/atem-connection-rs/src/atem_lib/atem_packet.rs b/atem-connection-rs/src/atem_lib/atem_packet.rs index 18a9b02..f57a797 100755 --- a/atem-connection-rs/src/atem_lib/atem_packet.rs +++ b/atem-connection-rs/src/atem_lib/atem_packet.rs @@ -1,14 +1,24 @@ -pub struct AtemPacket { - length: u16, - flags: u8, - session_id: u16, - remote_packet_id: u16, - body: Vec, +use core::{fmt::Display, str}; + +// TODO: we don't need itertools once https://github.com/rust-lang/rust/issues/79524 lands +use itertools::Itertools; + +/// The "hello" packet to start communication with the ATEM +pub const COMMAND_CONNECT_HELLO: [u8; 20] = [ + 0x10, 0x14, 0x53, 0xab, 0x00, 0x00, 0x00, 0x00, 0x00, 0x3a, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, +]; + +pub struct AtemPacket> { + buf: T, } +#[derive(Debug)] pub enum AtemPacketErr { - TooShort(String), - LengthDiffers(String), + /// The packet was too short + TooShort { got: usize }, + /// The packet's stated and actual lengths were different + LengthDiffers { expected: u16, got: usize }, } #[derive(PartialEq)] @@ -32,64 +42,122 @@ impl From for u8 { } } -impl AtemPacket { +impl> AtemPacket { + pub fn new_checked(buf: T) -> Result { + let len = buf.as_ref().len(); + if len < 12 { + return Err(AtemPacketErr::TooShort { + got: buf.as_ref().len(), + }); + } + let p = Self { buf }; + if p.length() as usize != len { + return Err(AtemPacketErr::LengthDiffers { + expected: p.length(), + got: len, + }); + } + Ok(p) + } pub fn length(&self) -> u16 { - self.length + u16::from_be_bytes(self.buf.as_ref()[0..=1].try_into().unwrap()) & 0x07ff } pub fn flags(&self) -> u8 { - self.flags + self.buf.as_ref()[0] >> 3 } pub fn session_id(&self) -> u16 { - self.session_id + u16::from_be_bytes(self.buf.as_ref()[2..=3].try_into().unwrap()) + } + + pub fn ack_number(&self) -> u16 { + u16::from_be_bytes(self.buf.as_ref()[4..=5].try_into().unwrap()) + } + + pub fn remote_sequence_number(&self) -> u16 { + u16::from_be_bytes(self.buf.as_ref()[9..=10].try_into().unwrap()) } pub fn remote_packet_id(&self) -> u16 { - self.remote_packet_id - } - - pub fn body(&self) -> Vec { - self.body.clone() + u16::from_be_bytes(self.buf.as_ref()[10..=11].try_into().unwrap()) } + /// Return true if this packet has the given [`PacketFlag`] pub fn has_flag(&self, flag: PacketFlag) -> bool { - self.flags & u8::from(flag) > 0 + self.flags() & u8::from(flag) > 0 + } + + /// Get an iterator over the `Field`s in this packet. + /// + /// Returns None if this is a packet without fields. + pub fn fields(&self) -> Option { + // TODO: do we only ever get newsessionid during the handshake (i.e. not in a packet with fields)? + if self.has_flag(PacketFlag::NewSessionId) { + None + } else { + Some(Fields { + data: self.body(), + offset: 0, + }) + } + } + + pub fn body(&self) -> &[u8] { + &self.buf.as_ref()[12..] } } -impl TryFrom<&[u8]> for AtemPacket { - type Error = AtemPacketErr; - - fn try_from(buffer: &[u8]) -> Result { - if buffer.len() < 12 { - return Err(AtemPacketErr::TooShort(format!( - "Invalid packet from ATEM {:x?}", - buffer - ))); - } - - let length = u16::from_be_bytes(buffer[0..2].try_into().unwrap()) & 0x07ff; - if length as usize != buffer.len() { - return Err(AtemPacketErr::LengthDiffers(format!( - "Length of message differs, expected {} got {}", - length, - buffer.len() - ))); - } - - let flags = buffer[0] >> 3; - let session_id = u16::from_be_bytes(buffer[2..4].try_into().unwrap()); - let remote_packet_id = u16::from_be_bytes(buffer[10..12].try_into().unwrap()); - - let body = buffer[12..].to_vec(); - - Ok(AtemPacket { - length, - flags, - session_id, - remote_packet_id, - body, - }) +#[cfg(feature = "std")] +impl> Display for AtemPacket { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!( + f, + "sid: {}, len: {}, fields: ({})", + self.session_id(), + self.length(), + self.fields() + .and_then(|f| Some( + f.map(|f| str::from_utf8(f.r#type).unwrap()) + .intersperse(", ") + .collect::() + )) + .unwrap_or("none".into()) + ) + } +} + +/// An ATEM protocol field - a 4-ascii character type plus variable length data +pub struct Field<'a> { + r#type: &'a [u8; 4], + data: &'a [u8], +} + +/// Created by [`AtemPacket::fields`] +pub struct Fields<'a> { + data: &'a [u8], + offset: usize, +} + +impl<'a> Iterator for Fields<'a> { + type Item = Field<'a>; + fn next(&mut self) -> Option { + let remain = self.data.len() - self.offset; + if remain == 0 { + return None; + } else if remain < 8 { + // TODO: is 8 indeed the minimum size for something here? (i.e. no field data) + panic!("Oh no"); + } + + let length = + u16::from_be_bytes(self.data[self.offset..=self.offset + 1].try_into().unwrap()); + // TODO: sanity check length + let r#type: &[u8; 4] = self.data[self.offset + 4..=self.offset + 7] + .try_into() + .unwrap(); + let data = &self.data[self.offset + 8..self.offset + (length as usize)]; + self.offset += (length as usize); + Some(Field { r#type, data }) } } diff --git a/atem-connection-rs/src/atem_lib/atem_socket_inner.rs b/atem-connection-rs/src/atem_lib/atem_socket_inner.rs index c9c2594..95c25ab 100644 --- a/atem-connection-rs/src/atem_lib/atem_socket_inner.rs +++ b/atem-connection-rs/src/atem_lib/atem_socket_inner.rs @@ -4,10 +4,10 @@ use std::{ time::{Duration, SystemTime}, }; -use log::debug; +use log::{debug, trace}; use tokio::net::UdpSocket; -use crate::atem_lib::atem_util; +use crate::atem_lib::atem_packet::{AtemPacket, PacketFlag, COMMAND_CONNECT_HELLO}; const IN_FLIGHT_TIMEOUT: u64 = 60; const CONNECTION_TIMEOUT: u64 = 5000; @@ -39,27 +39,6 @@ impl Into for ConnectionState { } } -#[derive(PartialEq)] -enum PacketFlag { - AckRequest, - NewSessionId, - IsRetransmit, - RetransmitRequest, - AckReply, -} - -impl From for u8 { - fn from(flag: PacketFlag) -> Self { - match flag { - PacketFlag::AckRequest => 0x01, - PacketFlag::NewSessionId => 0x02, - PacketFlag::IsRetransmit => 0x04, - PacketFlag::RetransmitRequest => 0x08, - PacketFlag::AckReply => 0x10, - } - } -} - #[derive(Clone)] struct InFlightPacket { packet_id: u16, @@ -152,7 +131,7 @@ impl AtemSocketInner { self.in_flight = vec![]; debug!("Reconnect"); - self.send_packet(&atem_util::COMMAND_CONNECT_HELLO).await; + self.send_packet(&COMMAND_CONNECT_HELLO).await; self.connection_state = ConnectionState::SynSent; Ok(()) @@ -277,30 +256,23 @@ impl AtemSocketInner { } async fn recieved_packet(&mut self, packet: &[u8]) { - debug!("RECV {:x?}", packet); + trace!("RX Raw: {:x?}", packet); - if packet.len() < 12 { - debug!("Invalid packet from ATEM {:x?}", packet); - return; - } + let checked = match AtemPacket::new_checked(packet) { + Ok(p) => p, + Err(e) => { + debug!("Invalid packet ({:?}): {:x?}", e, packet); + return; + } + }; + debug!("RX: {}", checked); self.last_received_at = SystemTime::now(); - let length = u16::from_be_bytes(packet[0..2].try_into().unwrap()) & 0x07ff; - if length as usize != packet.len() { - debug!( - "Length of message differs, expected {} got {}", - length, - packet.len() - ); - return; - } + self.session_id = checked.session_id(); + let remote_packet_id = checked.remote_packet_id(); - let flags = packet[0] >> 3; - self.session_id = u16::from_be_bytes(packet[2..4].try_into().unwrap()); - let remote_packet_id = u16::from_be_bytes(packet[10..12].try_into().unwrap()); - - if flags & u8::from(PacketFlag::NewSessionId) > 0 { + if checked.flags() & u8::from(PacketFlag::NewSessionId) > 0 { debug!("New session"); self.connection_state = ConnectionState::Established; self.last_received_packed_id = remote_packet_id; @@ -309,20 +281,20 @@ impl AtemSocketInner { } if self.connection_state == ConnectionState::Established { - if flags & u8::from(PacketFlag::RetransmitRequest) > 0 { + if checked.has_flag(PacketFlag::RetransmitRequest) { let from_packet_id = u16::from_be_bytes(packet[6..8].try_into().unwrap()); debug!("Retransmit request: {:x?}", from_packet_id); self.retransmit_from(from_packet_id).await; } - if flags & u8::from(PacketFlag::AckRequest) > 0 { + if checked.has_flag(PacketFlag::AckRequest) { if remote_packet_id == (self.last_received_packed_id + 1) % MAX_PACKET_ID { self.last_received_packed_id = remote_packet_id; self.send_or_queue_ack().await; - if length > 12 { - self.on_command_received(&packet[12..], remote_packet_id); + if checked.length() > 12 { + self.on_command_received(checked.body(), remote_packet_id); } } else if self .is_packet_covered_by_ack(self.last_received_packed_id, remote_packet_id) @@ -331,12 +303,12 @@ impl AtemSocketInner { } } - if flags & u8::from(PacketFlag::IsRetransmit) > 0 { + if checked.has_flag(PacketFlag::IsRetransmit) { debug!("ATEM retransmitted packet {:x?}", remote_packet_id); } - if flags & u8::from(PacketFlag::AckReply) > 0 { - let ack_packet_id = u16::from_be_bytes(packet[4..6].try_into().unwrap()); + if checked.has_flag(PacketFlag::AckReply) { + let ack_packet_id = checked.ack_number(); let mut acked_commands: Vec = vec![]; self.in_flight = self diff --git a/atem-connection-rs/src/atem_lib/atem_util.rs b/atem-connection-rs/src/atem_lib/atem_util.rs deleted file mode 100644 index 09b970a..0000000 --- a/atem-connection-rs/src/atem_lib/atem_util.rs +++ /dev/null @@ -1,4 +0,0 @@ -pub const COMMAND_CONNECT_HELLO: [u8; 20] = [ - 0x10, 0x14, 0x53, 0xab, 0x00, 0x00, 0x00, 0x00, 0x00, 0x3a, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, -]; diff --git a/atem-connection-rs/src/atem_lib/mod.rs b/atem-connection-rs/src/atem_lib/mod.rs index 6a8af6a..e0ef1be 100644 --- a/atem-connection-rs/src/atem_lib/mod.rs +++ b/atem-connection-rs/src/atem_lib/mod.rs @@ -1,4 +1,3 @@ -mod atem_packet; +pub mod atem_packet; pub mod atem_socket; mod atem_socket_inner; -pub mod atem_util;