diff --git a/atem-connection-rs/src/atem_lib/atem_packet.rs b/atem-connection-rs/src/atem_lib/atem_packet.rs index d597c0b..c81be4e 100755 --- a/atem-connection-rs/src/atem_lib/atem_packet.rs +++ b/atem-connection-rs/src/atem_lib/atem_packet.rs @@ -1,3 +1,4 @@ +//! This module contains [`AtemPacket`] which is a zero*ish* copy abstraction over a [`AsRef`]`<[u8]>`. use core::{fmt::Display, str}; use enumflags2::{BitFlags, bitflags}; @@ -9,20 +10,40 @@ pub const COMMAND_CONNECT_HELLO: [u8; 20] = [ 0x00, 0x00, 0x00, 0x00, ]; +/// Maximum atem packet length. This is determined by the maximum value that can be stored +/// in the length field. +pub const MAX_LEN: usize = field::LEN_MASK as usize; + +mod field { + use core::ops::{Range, RangeInclusive}; + + pub(crate) const FLAGS_LEN_H: usize = 0; + pub(crate) const FLAGS_MASK: u8 = 0b1111_1000; + pub(crate) const LEN_L: usize = 1; + pub(crate) const LEN: RangeInclusive = FLAGS_LEN_H..=LEN_L; + pub(crate) const LEN_MASK: u16 = 0x7ff; + + pub(crate) const SESSION_ID: Range = 2..4; + pub(crate) const ACK_NUMBER: Range = 4..6; + pub(crate) const REMOTE_SEQ_NUM: Range = 8..10; + pub(crate) const LOCAL_SEQ_NUM: Range = 10..12; +} + +/// An ATEM protocol packet #[derive(Debug)] pub struct AtemPacket> { buf: T, } impl<'a> TryFrom<&'a [u8]> for AtemPacket<&'a [u8]> { - type Error = AtemPacketErr; + type Error = AtemPacketParseError; fn try_from(buf: &'a [u8]) -> Result { AtemPacket::new_checked(buf) } } #[derive(Debug)] -pub enum AtemPacketErr { +pub enum AtemPacketParseError { /// The packet was too short TooShort { got: usize, @@ -35,28 +56,116 @@ pub enum AtemPacketErr { InvalidFlags, } +/// ATEM packet flags #[bitflags] #[repr(u8)] #[derive(PartialEq, Copy, Clone, Debug)] pub enum PacketFlag { + /// AKA "reliable". This packet should be ACKed by the other end AckRequest = 0x1, + /// AKA "syn" - used in the connection handshake NewSessionId = 0x2, + /// This packet is a retransmit of a previous one IsRetransmit = 0x4, + /// Request retransmission of a previous sequence ID RetransmitRequest = 0x8, + /// This packet is an ACK AckReply = 0x10, } +/// An error while constructing a packet +#[derive(Debug)] +pub enum AtemPacketInitError { + /// The provided buf was too small + BufTooSmall { need: usize, got: usize }, + /// The provided payload was too long + DataTooLong { got: usize, max: usize }, +} + +impl<'a> AtemPacket<&'a mut [u8]> { + /// Initialise an [`AtemPacket`] into the given `buf`, with optional payload copied from `data`. + /// Returns the constructed packet, and the remaining spare space in `buf`, if there was any. + pub fn init<'b>( + buf: &'a mut [u8], + flags: BitFlags, + session_id: u16, + local_seq_num: u16, + data: Option<&'b [u8]>, + ) -> Result<(Self, &'a mut [u8]), AtemPacketInitError> { + let len = 12 + data.map_or(0, |d| d.len()); + if len > buf.len() { + return Err(AtemPacketInitError::BufTooSmall { + need: len, + got: buf.len(), + }); + } + if len > MAX_LEN { + return Err(AtemPacketInitError::DataTooLong { + got: len - 12, + max: MAX_LEN - 12, + }); + } + let (pkt, rem) = buf.split_at_mut(len); + let mut p = AtemPacket { buf: pkt }; + p.set_flags(flags) + .set_session_id(session_id) + .set_local_seq_num(local_seq_num) + .set_len(len.try_into().unwrap()); + if let Some(d) = data { + p.set_data(d); + } + Ok((p, rem)) + } +} + +impl + AsMut<[u8]>> AtemPacket { + pub fn set_flags(&mut self, flags: BitFlags) -> &mut Self { + let prev = self.buf.as_ref()[field::FLAGS_LEN_H]; + self.buf.as_mut()[field::FLAGS_LEN_H] = (prev & !field::FLAGS_MASK) | (flags.bits() << 3); + self + } + + pub fn set_len(&mut self, value: u16) -> &mut Self { + let v = value.to_be_bytes(); + self.buf.as_mut()[field::FLAGS_LEN_H] &= v[0] | field::FLAGS_MASK; + self.buf.as_mut()[field::LEN_L] = v[1]; + self + } + + pub fn set_session_id(&mut self, value: u16) -> &mut Self { + self.buf.as_mut()[field::SESSION_ID].copy_from_slice(&value.to_be_bytes()); + self + } + pub fn set_local_seq_num(&mut self, value: u16) -> &mut Self { + self.buf.as_mut()[field::LOCAL_SEQ_NUM].copy_from_slice(&value.to_be_bytes()); + self + } + pub fn set_remote_seq_num(&mut self, value: u16) -> &mut Self { + self.buf.as_mut()[field::REMOTE_SEQ_NUM].copy_from_slice(&value.to_be_bytes()); + self + } + pub fn set_data<'a>(&mut self, value: &'a [u8]) -> &mut Self { + self.buf.as_mut()[12..].copy_from_slice(value); + self + } + + pub fn set_ack_num(&mut self, value: u16) -> &mut Self { + self.buf.as_mut()[field::ACK_NUMBER].copy_from_slice(&value.to_be_bytes()); + self + } +} + impl> AtemPacket { - pub fn new_checked(buf: T) -> Result { + pub fn new_checked(buf: T) -> Result { let len = buf.as_ref().len(); if len < 12 { - return Err(AtemPacketErr::TooShort { + return Err(AtemPacketParseError::TooShort { got: buf.as_ref().len(), }); } let p = Self { buf }; if p.length() as usize != len { - return Err(AtemPacketErr::LengthDiffers { + return Err(AtemPacketParseError::LengthDiffers { expected: p.length(), got: len, }); @@ -65,15 +174,20 @@ impl> AtemPacket { let _: BitFlags = p .flags_raw() .try_into() - .map_err(|_| AtemPacketErr::InvalidFlags)?; + .map_err(|_| AtemPacketParseError::InvalidFlags)?; Ok(p) } + + /// Consumes self, returning the inner packet buffer + pub fn inner(self) -> T { + self.buf + } pub fn length(&self) -> u16 { - u16::from_be_bytes(self.buf.as_ref()[0..=1].try_into().unwrap()) & 0x07ff + u16::from_be_bytes(self.buf.as_ref()[field::LEN].try_into().unwrap()) & field::LEN_MASK } fn flags_raw(&self) -> u8 { - self.buf.as_ref()[0] >> 3 + self.buf.as_ref()[field::FLAGS_LEN_H] >> 3 } pub fn flags(&self) -> BitFlags { @@ -82,19 +196,19 @@ impl> AtemPacket { } pub fn session_id(&self) -> u16 { - u16::from_be_bytes(self.buf.as_ref()[2..=3].try_into().unwrap()) + u16::from_be_bytes(self.buf.as_ref()[field::SESSION_ID].try_into().unwrap()) } pub fn ack_number(&self) -> u16 { - u16::from_be_bytes(self.buf.as_ref()[4..=5].try_into().unwrap()) + u16::from_be_bytes(self.buf.as_ref()[field::ACK_NUMBER].try_into().unwrap()) } pub fn remote_sequence_number(&self) -> u16 { - u16::from_be_bytes(self.buf.as_ref()[9..=10].try_into().unwrap()) + u16::from_be_bytes(self.buf.as_ref()[field::REMOTE_SEQ_NUM].try_into().unwrap()) } pub fn local_sequence_number(&self) -> u16 { - u16::from_be_bytes(self.buf.as_ref()[10..=11].try_into().unwrap()) + u16::from_be_bytes(self.buf.as_ref()[field::LOCAL_SEQ_NUM].try_into().unwrap()) } pub fn retransmit_request(&self) -> Option { @@ -183,3 +297,21 @@ impl<'a> Iterator for RawFields<'a> { Some(RawField { r#type, data }) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_init() { + let mut buf = [0u8; 12]; + let (mut p, _) = + AtemPacket::init(&mut buf, PacketFlag::AckReply.into(), 1234, 5678, None).unwrap(); + p.set_remote_seq_num(9012); + assert!(p.flags().contains(PacketFlag::AckReply)); + assert_eq!(p.session_id(), 1234); + assert_eq!(p.local_sequence_number(), 5678); + assert_eq!(p.remote_sequence_number(), 9012); + assert_eq!(p.length(), 12); + } +} diff --git a/atem-connection-rs/src/atem_lib/atem_socket.rs b/atem-connection-rs/src/atem_lib/atem_socket.rs index f6dcd10..2739436 100644 --- a/atem-connection-rs/src/atem_lib/atem_socket.rs +++ b/atem-connection-rs/src/atem_lib/atem_socket.rs @@ -303,22 +303,18 @@ impl AtemSocket { self.next_send_packet_id = 0; } - let opcode = u16::from(PacketFlag::AckRequest as u8) << 11; - let mut buffer = vec![0; 20 + payload.len()]; - // Headers - buffer[0..2].copy_from_slice(&u16::to_be_bytes(opcode | (payload.len() as u16 + 20))); - buffer[2..4].copy_from_slice(&u16::to_be_bytes(self.session_id)); - buffer[10..12].copy_from_slice(&u16::to_be_bytes(packet_id)); + let (p, _) = AtemPacket::init( + &mut buffer, + PacketFlag::AckRequest.into(), + self.session_id, + packet_id, + Some([raw_name.as_bytes(), payload].concat().as_slice()), + ) + .unwrap(); - // Command - buffer[12..14].copy_from_slice(&u16::to_be_bytes(payload.len() as u16 + 8)); - buffer[16..20].copy_from_slice(raw_name.as_bytes()); - - // Body - buffer[20..20 + payload.len()].copy_from_slice(payload); - self.send_packet(&buffer).await; + self.send_packet(&p.inner()).await; self.in_flight.push(InFlightPacket { packet_id, @@ -401,7 +397,7 @@ impl AtemSocket { self.last_received_at = Instant::now(); self.session_id = atem_packet.session_id(); - // TODO: naming seems rather off here + // TODO: bit of a naming clash here let remote_packet_id = atem_packet.local_sequence_number(); if atem_packet.flags().contains(PacketFlag::NewSessionId) { @@ -483,13 +479,17 @@ impl AtemSocket { async fn send_ack(&mut self, packet_id: u16) { log::debug!("Sending ack for packet {:x?}", packet_id); - let flag: u8 = PacketFlag::AckReply as u8; - let opcode = u16::from(flag) << 11; let mut buffer: [u8; ACK_PACKET_LENGTH as _] = [0; 12]; - buffer[0..2].copy_from_slice(&u16::to_be_bytes(opcode | ACK_PACKET_LENGTH)); - buffer[2..4].copy_from_slice(&u16::to_be_bytes(self.session_id)); - buffer[4..6].copy_from_slice(&u16::to_be_bytes(packet_id)); - self.send_packet(&buffer).await; + let (mut p, _) = AtemPacket::init( + &mut buffer, + PacketFlag::AckReply.into(), + self.session_id, + 0, + None, + ) + .unwrap(); + p.set_ack_num(packet_id); + self.send_packet(p.inner()).await; } async fn retransmit_from(&mut self, from_id: u16) {