Make AtemPacket operate on `&[u8] `, add field parsing

This commit is contained in:
Sam W 2025-06-10 10:19:09 +01:00
parent 9f45b7b6d9
commit 16829080db
8 changed files with 273 additions and 113 deletions

18
Cargo.lock generated
View File

@ -1,6 +1,6 @@
# This file is automatically @generated by Cargo. # This file is automatically @generated by Cargo.
# It is not intended for manual editing. # It is not intended for manual editing.
version = 3 version = 4
[[package]] [[package]]
name = "addr2line" name = "addr2line"
@ -80,6 +80,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"derive-getters", "derive-getters",
"derive-new", "derive-new",
"itertools",
"log", "log",
"tokio", "tokio",
"tokio-util", "tokio-util",
@ -249,6 +250,12 @@ dependencies = [
"syn 2.0.48", "syn 2.0.48",
] ]
[[package]]
name = "either"
version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
[[package]] [[package]]
name = "env_logger" name = "env_logger"
version = "0.9.0" version = "0.9.0"
@ -323,6 +330,15 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683" checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683"
[[package]]
name = "itertools"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285"
dependencies = [
"either",
]
[[package]] [[package]]
name = "lazy_static" name = "lazy_static"
version = "1.4.0" version = "1.4.0"

View File

@ -6,6 +6,12 @@ edition = "2021"
[dependencies] [dependencies]
derive-getters = "0.2.0" derive-getters = "0.2.0"
derive-new = "0.6.0" derive-new = "0.6.0"
itertools = {version = "0.14.0"}
log = "0.4.14" log = "0.4.14"
tokio = { version = "1.13.0", features = ["full"] } tokio = { version = "1.13.0", features = ["full"], optional = true }
tokio-util = "0.7.10" tokio-util = { version = "0.7.10", optional = true }
[features]
default = ["std"]
std = ["dep:tokio", "dep:tokio-util"]

View File

@ -0,0 +1,88 @@
//! Definitions and decoding of ATEM protocol fields
/// An uninterpreted ATEM protocol field - a 4-ascii character type plus variable length data
pub struct RawField<'a> {
pub r#type: &'a [u8; 4],
pub data: &'a [u8],
}
#[derive(Debug)]
pub struct _Ver {
pub major: u16,
pub minor: u16,
}
impl<'a> Field<'a> for _Ver {
const TYPE: [u8; 4] = [b'_', b'v', b'e', b'r'];
fn decode(data: &'a [u8]) -> Result<Self, FieldParsingError> {
let data = checked_len::<4>(data)?;
Ok(Self {
major: u16::from_be_bytes(data[0..=1].try_into().unwrap()),
minor: u16::from_be_bytes(data[2..=3].try_into().unwrap()),
})
}
}
#[derive(Debug)]
pub struct PrvI {
pub m_e_index: u8,
pub source_index: u16,
pub pvw_in_pgm: bool,
}
impl<'a> Field<'a> for PrvI {
const TYPE: [u8; 4] = [b'P', b'r', b'v', b'I'];
fn decode(data: &'a [u8]) -> Result<Self, FieldParsingError> {
let data = checked_len::<8>(data)?;
Ok(Self {
m_e_index: data[0],
source_index: u16::from_be_bytes(data[2..=3].try_into().unwrap()),
pvw_in_pgm: data[4] != 0,
})
}
}
#[derive(Debug)]
pub struct PrgI {
pub m_e_index: u8,
pub source_index: u16,
}
impl<'a> Field<'a> for PrgI {
const TYPE: [u8; 4] = [b'P', b'r', b'g', b'I'];
fn decode(data: &'a [u8]) -> Result<Self, FieldParsingError> {
let data = checked_len::<4>(data)?;
Ok(Self {
m_e_index: data[0],
source_index: u16::from_be_bytes(data[2..=3].try_into().unwrap()),
})
}
}
pub trait Field<'a>: Sized {
const TYPE: [u8; 4];
fn decode(data: &'a [u8]) -> Result<Self, FieldParsingError>;
fn try_from_raw(raw: RawField<'a>) -> Result<Self, FieldParsingError> {
if Self::TYPE != *raw.r#type {
Err(FieldParsingError::MismatchedFieldType)
} else {
Self::decode(&raw.data)
}
}
}
#[derive(Debug)]
pub enum FieldParsingError {
UnexpectedLength { expected: usize, got: usize },
UnknownFieldType { r#type: [u8; 4] },
MismatchedFieldType,
}
fn checked_len<'a, const LEN: usize>(data: &'a [u8]) -> Result<&'a [u8; LEN], FieldParsingError> {
data.try_into()
.map_err(|_| FieldParsingError::UnexpectedLength {
expected: LEN,
got: data.len(),
})
}

View File

@ -1,16 +1,28 @@
use core::{fmt::Display, str};
/// The "hello" packet to start communication with the ATEM
pub const COMMAND_CONNECT_HELLO: [u8; 20] = [
0x10, 0x14, 0x53, 0xab, 0x00, 0x00, 0x00, 0x00, 0x00, 0x3a, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
];
#[derive(Debug)] #[derive(Debug)]
pub struct AtemPacket<'packet_buffer> { pub struct AtemPacket<T: AsRef<[u8]>> {
flags: u8, buf: T,
session_id: u16,
remote_packet_id: u16,
retransmit_requested_from_packet_id: Option<u16>,
ack_reply: Option<u16>,
body: Option<&'packet_buffer [u8]>,
} }
impl<'a> From<&'a [u8]> for AtemPacket<&'a [u8]> {
fn from(buf: &'a [u8]) -> Self {
AtemPacket { buf }
}
}
#[derive(Debug)]
pub enum AtemPacketErr { pub enum AtemPacketErr {
TooShort(String), /// The packet was too short
LengthDiffers(String), TooShort { got: usize },
/// The packet's stated and actual lengths were different
LengthDiffers { expected: u16, got: usize },
} }
#[derive(PartialEq)] #[derive(PartialEq)]
@ -34,82 +46,133 @@ impl From<PacketFlag> for u8 {
} }
} }
impl<'packet_buffer> AtemPacket<'packet_buffer> { impl<T: AsRef<[u8]>> AtemPacket<T> {
pub fn new_checked(buf: T) -> Result<Self, AtemPacketErr> {
let len = buf.as_ref().len();
if len < 12 {
return Err(AtemPacketErr::TooShort {
got: buf.as_ref().len(),
});
}
let p = Self { buf };
if p.length() as usize != len {
return Err(AtemPacketErr::LengthDiffers {
expected: p.length(),
got: len,
});
}
Ok(p)
}
pub fn length(&self) -> u16 {
u16::from_be_bytes(self.buf.as_ref()[0..=1].try_into().unwrap()) & 0x07ff
}
pub fn flags(&self) -> u8 {
self.buf.as_ref()[0] >> 3
}
pub fn session_id(&self) -> u16 { pub fn session_id(&self) -> u16 {
self.session_id u16::from_be_bytes(self.buf.as_ref()[2..=3].try_into().unwrap())
} }
pub fn remote_packet_id(&self) -> u16 { pub fn ack_number(&self) -> u16 {
self.remote_packet_id u16::from_be_bytes(self.buf.as_ref()[4..=5].try_into().unwrap())
} }
pub fn body(&self) -> Option<&[u8]> { pub fn remote_sequence_number(&self) -> u16 {
self.body u16::from_be_bytes(self.buf.as_ref()[9..=10].try_into().unwrap())
}
pub fn local_sequence_number(&self) -> u16 {
u16::from_be_bytes(self.buf.as_ref()[10..=11].try_into().unwrap())
} }
pub fn retransmit_request(&self) -> Option<u16> { pub fn retransmit_request(&self) -> Option<u16> {
self.retransmit_requested_from_packet_id self.has_flag(PacketFlag::RetransmitRequest)
.then_some(u16::from_be_bytes([
self.buf.as_ref()[6],
self.buf.as_ref()[7],
]))
} }
pub fn ack_reply(&self) -> Option<u16> { pub fn ack_reply(&self) -> Option<u16> {
self.ack_reply self.has_flag(PacketFlag::AckReply)
.then_some(u16::from_be_bytes([
self.buf.as_ref()[4],
self.buf.as_ref()[5],
]))
} }
/// Return true if this packet has the given [`PacketFlag`]
pub fn has_flag(&self, flag: PacketFlag) -> bool { pub fn has_flag(&self, flag: PacketFlag) -> bool {
self.flags & u8::from(flag) > 0 self.flags() & u8::from(flag) > 0
}
/// Get an iterator over the `Field`s in this packet.
///
/// Returns None if this is a packet without fields.
pub fn fields(&self) -> Fields {
// TODO: do we only ever get newsessionid during the handshake (i.e. not in a packet with fields)?
let has_fields = !self.has_flag(PacketFlag::NewSessionId);
Fields::new(if has_fields { self.body() } else { &[] })
}
pub fn body(&self) -> &[u8] {
&self.buf.as_ref()[12..]
} }
} }
impl<'packet_buffer> TryFrom<&'packet_buffer [u8]> for AtemPacket<'packet_buffer> { impl<T: AsRef<[u8]>> Display for AtemPacket<T> {
type Error = AtemPacketErr; fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(
fn try_from(buffer: &'packet_buffer [u8]) -> Result<Self, Self::Error> { f,
if buffer.len() < 12 { "len: {}, sid: {}, flags: {}, ack: {}, localseq: {}, remoteseq: {}",
return Err(AtemPacketErr::TooShort(format!( self.length(),
"Invalid packet from ATEM {:x?}", self.session_id(),
buffer self.flags(),
))); self.ack_number(),
} self.remote_sequence_number(),
self.local_sequence_number(),
let length = u16::from_be_bytes([buffer[0], buffer[1]]) & 0x07ff; )
if length as usize != buffer.len() { }
return Err(AtemPacketErr::LengthDiffers(format!( }
"Length of message differs, expected {} got {}",
length, /// An ATEM protocol field - a 4-ascii character type plus variable length data
buffer.len() pub struct Field<'a> {
))); pub r#type: &'a str,
} pub data: &'a [u8],
}
let flags = buffer[0] >> 3;
let session_id = u16::from_be_bytes([buffer[2], buffer[3]]); /// Created by [`AtemPacket::fields`]
let remote_packet_id = u16::from_be_bytes([buffer[10], buffer[11]]); pub struct Fields<'a> {
data: &'a [u8],
let body = if buffer.len() > 12 { // The offset of the next field in the packet
Some(&buffer[12..]) offset: usize,
} else { }
None
}; impl<'a> Fields<'a> {
pub(crate) fn new(data: &'a [u8]) -> Self {
let retransmit_requested_from_packet_id = Self { data, offset: 0 }
if flags & u8::from(PacketFlag::RetransmitRequest) > 0 { }
Some(u16::from_be_bytes([buffer[6], buffer[7]])) }
} else {
None impl<'a> Iterator for Fields<'a> {
}; type Item = Field<'a>;
fn next(&mut self) -> Option<Self::Item> {
let ack_reply = if flags & u8::from(PacketFlag::AckReply) > 0 { let remain = self.data.len() - self.offset;
Some(u16::from_be_bytes([buffer[4], buffer[5]])) if remain == 0 {
} else { return None;
None } else if remain < 8 {
}; // TODO: is 8 indeed the minimum size for something here? (i.e. no field data)
panic!("Oh no");
Ok(AtemPacket { }
flags,
session_id, let length =
remote_packet_id, u16::from_be_bytes(self.data[self.offset..=self.offset + 1].try_into().unwrap());
body, // TODO: sanity check length
retransmit_requested_from_packet_id, let r#type = str::from_utf8(&self.data[self.offset + 4..=self.offset + 7]).unwrap();
ack_reply, let data = &self.data[self.offset + 8..self.offset + (length as usize)];
}) self.offset += (length as usize);
Some(Field { r#type, data })
} }
} }

View File

@ -7,6 +7,7 @@ use std::{
time::{Duration, SystemTime}, time::{Duration, SystemTime},
}; };
use itertools::Itertools;
use tokio::{ use tokio::{
net::UdpSocket, net::UdpSocket,
select, select,
@ -14,7 +15,7 @@ use tokio::{
}; };
use crate::{ use crate::{
atem_lib::{atem_packet::AtemPacket, atem_util}, atem_lib::atem_packet::{self, AtemPacket, COMMAND_CONNECT_HELLO},
commands::{ commands::{
command_base::{BasicWritableCommand, DeserializedCommand}, command_base::{BasicWritableCommand, DeserializedCommand},
parse_commands::deserialize_commands, parse_commands::deserialize_commands,
@ -258,7 +259,7 @@ impl AtemSocket {
self.in_flight = vec![]; self.in_flight = vec![];
log::debug!("Reconnect"); log::debug!("Reconnect");
self.send_packet(&atem_util::COMMAND_CONNECT_HELLO).await; self.send_packet(&COMMAND_CONNECT_HELLO).await;
self.connection_state = ConnectionState::SynSent; self.connection_state = ConnectionState::SynSent;
Ok(()) Ok(())
@ -386,16 +387,21 @@ impl AtemSocket {
} }
async fn recieved_packet(&mut self, packet: &[u8]) { async fn recieved_packet(&mut self, packet: &[u8]) {
let Ok(atem_packet): Result<AtemPacket, _> = packet.try_into() else { let Ok(atem_packet): Result<AtemPacket<_>, _> = packet.try_into() else {
return; return;
}; };
log::debug!("Received {:x?}", atem_packet); log::debug!("Received {}", atem_packet,);
log::debug!(
"fields: {}",
atem_packet.fields().map(|f| f.r#type).join(",")
);
self.last_received_at = SystemTime::now(); self.last_received_at = SystemTime::now();
self.session_id = atem_packet.session_id(); self.session_id = atem_packet.session_id();
let remote_packet_id = atem_packet.remote_packet_id(); // TODO: naming seems rather off here
let remote_packet_id = atem_packet.local_sequence_number();
if atem_packet.has_flag(PacketFlag::NewSessionId) { if atem_packet.has_flag(PacketFlag::NewSessionId) {
log::debug!("New session"); log::debug!("New session");
@ -418,9 +424,7 @@ impl AtemSocket {
self.last_received_packed_id = remote_packet_id; self.last_received_packed_id = remote_packet_id;
self.send_or_queue_ack().await; self.send_or_queue_ack().await;
if let Some(body) = atem_packet.body() { self.on_commands_received(atem_packet.body());
self.on_commands_received(body);
}
} else if self } else if self
.is_packet_covered_by_ack(self.last_received_packed_id, remote_packet_id) .is_packet_covered_by_ack(self.last_received_packed_id, remote_packet_id)
{ {
@ -537,9 +541,11 @@ impl AtemSocket {
} }
fn on_commands_received(&mut self, payload: &[u8]) { fn on_commands_received(&mut self, payload: &[u8]) {
let _ = self if !payload.is_empty() {
.atem_event_tx let _ = self
.send(AtemSocketEvent::ReceivedCommands(payload.to_vec())); .atem_event_tx
.send(AtemSocketEvent::ReceivedCommands(payload.to_vec()));
}
} }
fn on_command_acknowledged(&mut self, packets: Vec<AckedPacket>) { fn on_command_acknowledged(&mut self, packets: Vec<AckedPacket>) {

View File

@ -1,4 +0,0 @@
pub const COMMAND_CONNECT_HELLO: [u8; 20] = [
0x10, 0x14, 0x53, 0xab, 0x00, 0x00, 0x00, 0x00, 0x00, 0x3a, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
];

View File

@ -1,3 +1,2 @@
mod atem_packet; pub mod atem_packet;
pub mod atem_socket; pub mod atem_socket;
pub mod atem_util;

View File

@ -1,6 +1,7 @@
use std::{collections::VecDeque, sync::Arc}; use std::{collections::VecDeque, sync::Arc};
use crate::{ use crate::{
atem_lib::atem_packet::{AtemPacket, Fields},
commands::device_profile::version::{deserialize_version, DESERIALIZE_VERSION_RAW_NAME}, commands::device_profile::version::{deserialize_version, DESERIALIZE_VERSION_RAW_NAME},
enums::ProtocolVersion, enums::ProtocolVersion,
}; };
@ -25,43 +26,28 @@ use super::{
time::{TimeDeserializer, DESERIALIZE_TIME_RAW_NAME}, time::{TimeDeserializer, DESERIALIZE_TIME_RAW_NAME},
}; };
pub fn deserialize_commands( pub fn deserialize_commands<T: AsRef<[u8]>>(
payload: &[u8], payload: T,
version: &mut ProtocolVersion, version: &mut ProtocolVersion,
) -> VecDeque<Arc<dyn DeserializedCommand>> { ) -> VecDeque<Arc<dyn DeserializedCommand>> {
let mut parsed_commands: VecDeque<Arc<dyn DeserializedCommand>> = VecDeque::new(); let mut parsed_commands: VecDeque<Arc<dyn DeserializedCommand>> = VecDeque::new();
let mut head = 0;
while payload.len() > head + 8 { for field in Fields::new(payload.as_ref()) {
let length = u16::from_be_bytes([payload[head], payload[head + 1]]) as usize; let name: &str = field.r#type.try_into().unwrap();
let Ok(name) = String::from_utf8(payload[(head + 4)..(head + 8)].to_vec()) else { log::debug!("Received command {} with length {}", name, field.data.len(),);
break;
};
if length < 8 { if field.r#type == DESERIALIZE_VERSION_RAW_NAME {
break; let version_command = deserialize_version(field.data);
}
log::debug!("Received command {} with length {}", name, length);
let command_buffer = &payload[head + 8..head + length];
if name == DESERIALIZE_VERSION_RAW_NAME {
let version_command = deserialize_version(command_buffer);
*version = version_command.version.clone(); *version = version_command.version.clone();
log::info!("Switched to protocol version {}", version); log::info!("Switched to protocol version {}", version);
parsed_commands.push_back(Arc::new(version_command)); parsed_commands.push_back(Arc::new(version_command));
} else if let Some(deserializer) = command_deserializer_from_string(name.as_str()) { } else if let Some(deserializer) = command_deserializer_from_string(name) {
let deserialized_command = deserializer.deserialize(command_buffer, version); let deserialized_command = deserializer.deserialize(field.data, version);
log::debug!("Received {:?}", deserialized_command); log::debug!("Received {:?}", deserialized_command);
parsed_commands.push_back(deserialized_command); parsed_commands.push_back(deserialized_command);
} else { } else {
log::warn!("Received command {name} for which there is no deserializer."); log::warn!("Received command {name} for which there is no deserializer.");
// TODO: Remove!
todo!("Write deserializer for {name}.");
} }
head += length;
} }
parsed_commands parsed_commands