Make AtemPacket operate on `&[u8] `, add field parsing

This commit is contained in:
Sam W 2025-06-10 10:19:09 +01:00
parent 9f45b7b6d9
commit 16829080db
8 changed files with 273 additions and 113 deletions

18
Cargo.lock generated
View File

@ -1,6 +1,6 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
version = 4
[[package]]
name = "addr2line"
@ -80,6 +80,7 @@ version = "0.1.0"
dependencies = [
"derive-getters",
"derive-new",
"itertools",
"log",
"tokio",
"tokio-util",
@ -249,6 +250,12 @@ dependencies = [
"syn 2.0.48",
]
[[package]]
name = "either"
version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
[[package]]
name = "env_logger"
version = "0.9.0"
@ -323,6 +330,15 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683"
[[package]]
name = "itertools"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285"
dependencies = [
"either",
]
[[package]]
name = "lazy_static"
version = "1.4.0"

View File

@ -6,6 +6,12 @@ edition = "2021"
[dependencies]
derive-getters = "0.2.0"
derive-new = "0.6.0"
itertools = {version = "0.14.0"}
log = "0.4.14"
tokio = { version = "1.13.0", features = ["full"] }
tokio-util = "0.7.10"
tokio = { version = "1.13.0", features = ["full"], optional = true }
tokio-util = { version = "0.7.10", optional = true }
[features]
default = ["std"]
std = ["dep:tokio", "dep:tokio-util"]

View File

@ -0,0 +1,88 @@
//! Definitions and decoding of ATEM protocol fields
/// An uninterpreted ATEM protocol field - a 4-ascii character type plus variable length data
pub struct RawField<'a> {
pub r#type: &'a [u8; 4],
pub data: &'a [u8],
}
#[derive(Debug)]
pub struct _Ver {
pub major: u16,
pub minor: u16,
}
impl<'a> Field<'a> for _Ver {
const TYPE: [u8; 4] = [b'_', b'v', b'e', b'r'];
fn decode(data: &'a [u8]) -> Result<Self, FieldParsingError> {
let data = checked_len::<4>(data)?;
Ok(Self {
major: u16::from_be_bytes(data[0..=1].try_into().unwrap()),
minor: u16::from_be_bytes(data[2..=3].try_into().unwrap()),
})
}
}
#[derive(Debug)]
pub struct PrvI {
pub m_e_index: u8,
pub source_index: u16,
pub pvw_in_pgm: bool,
}
impl<'a> Field<'a> for PrvI {
const TYPE: [u8; 4] = [b'P', b'r', b'v', b'I'];
fn decode(data: &'a [u8]) -> Result<Self, FieldParsingError> {
let data = checked_len::<8>(data)?;
Ok(Self {
m_e_index: data[0],
source_index: u16::from_be_bytes(data[2..=3].try_into().unwrap()),
pvw_in_pgm: data[4] != 0,
})
}
}
#[derive(Debug)]
pub struct PrgI {
pub m_e_index: u8,
pub source_index: u16,
}
impl<'a> Field<'a> for PrgI {
const TYPE: [u8; 4] = [b'P', b'r', b'g', b'I'];
fn decode(data: &'a [u8]) -> Result<Self, FieldParsingError> {
let data = checked_len::<4>(data)?;
Ok(Self {
m_e_index: data[0],
source_index: u16::from_be_bytes(data[2..=3].try_into().unwrap()),
})
}
}
pub trait Field<'a>: Sized {
const TYPE: [u8; 4];
fn decode(data: &'a [u8]) -> Result<Self, FieldParsingError>;
fn try_from_raw(raw: RawField<'a>) -> Result<Self, FieldParsingError> {
if Self::TYPE != *raw.r#type {
Err(FieldParsingError::MismatchedFieldType)
} else {
Self::decode(&raw.data)
}
}
}
#[derive(Debug)]
pub enum FieldParsingError {
UnexpectedLength { expected: usize, got: usize },
UnknownFieldType { r#type: [u8; 4] },
MismatchedFieldType,
}
fn checked_len<'a, const LEN: usize>(data: &'a [u8]) -> Result<&'a [u8; LEN], FieldParsingError> {
data.try_into()
.map_err(|_| FieldParsingError::UnexpectedLength {
expected: LEN,
got: data.len(),
})
}

View File

@ -1,16 +1,28 @@
use core::{fmt::Display, str};
/// The "hello" packet to start communication with the ATEM
pub const COMMAND_CONNECT_HELLO: [u8; 20] = [
0x10, 0x14, 0x53, 0xab, 0x00, 0x00, 0x00, 0x00, 0x00, 0x3a, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
];
#[derive(Debug)]
pub struct AtemPacket<'packet_buffer> {
flags: u8,
session_id: u16,
remote_packet_id: u16,
retransmit_requested_from_packet_id: Option<u16>,
ack_reply: Option<u16>,
body: Option<&'packet_buffer [u8]>,
pub struct AtemPacket<T: AsRef<[u8]>> {
buf: T,
}
impl<'a> From<&'a [u8]> for AtemPacket<&'a [u8]> {
fn from(buf: &'a [u8]) -> Self {
AtemPacket { buf }
}
}
#[derive(Debug)]
pub enum AtemPacketErr {
TooShort(String),
LengthDiffers(String),
/// The packet was too short
TooShort { got: usize },
/// The packet's stated and actual lengths were different
LengthDiffers { expected: u16, got: usize },
}
#[derive(PartialEq)]
@ -34,82 +46,133 @@ impl From<PacketFlag> for u8 {
}
}
impl<'packet_buffer> AtemPacket<'packet_buffer> {
impl<T: AsRef<[u8]>> AtemPacket<T> {
pub fn new_checked(buf: T) -> Result<Self, AtemPacketErr> {
let len = buf.as_ref().len();
if len < 12 {
return Err(AtemPacketErr::TooShort {
got: buf.as_ref().len(),
});
}
let p = Self { buf };
if p.length() as usize != len {
return Err(AtemPacketErr::LengthDiffers {
expected: p.length(),
got: len,
});
}
Ok(p)
}
pub fn length(&self) -> u16 {
u16::from_be_bytes(self.buf.as_ref()[0..=1].try_into().unwrap()) & 0x07ff
}
pub fn flags(&self) -> u8 {
self.buf.as_ref()[0] >> 3
}
pub fn session_id(&self) -> u16 {
self.session_id
u16::from_be_bytes(self.buf.as_ref()[2..=3].try_into().unwrap())
}
pub fn remote_packet_id(&self) -> u16 {
self.remote_packet_id
pub fn ack_number(&self) -> u16 {
u16::from_be_bytes(self.buf.as_ref()[4..=5].try_into().unwrap())
}
pub fn body(&self) -> Option<&[u8]> {
self.body
pub fn remote_sequence_number(&self) -> u16 {
u16::from_be_bytes(self.buf.as_ref()[9..=10].try_into().unwrap())
}
pub fn local_sequence_number(&self) -> u16 {
u16::from_be_bytes(self.buf.as_ref()[10..=11].try_into().unwrap())
}
pub fn retransmit_request(&self) -> Option<u16> {
self.retransmit_requested_from_packet_id
self.has_flag(PacketFlag::RetransmitRequest)
.then_some(u16::from_be_bytes([
self.buf.as_ref()[6],
self.buf.as_ref()[7],
]))
}
pub fn ack_reply(&self) -> Option<u16> {
self.ack_reply
self.has_flag(PacketFlag::AckReply)
.then_some(u16::from_be_bytes([
self.buf.as_ref()[4],
self.buf.as_ref()[5],
]))
}
/// Return true if this packet has the given [`PacketFlag`]
pub fn has_flag(&self, flag: PacketFlag) -> bool {
self.flags & u8::from(flag) > 0
self.flags() & u8::from(flag) > 0
}
/// Get an iterator over the `Field`s in this packet.
///
/// Returns None if this is a packet without fields.
pub fn fields(&self) -> Fields {
// TODO: do we only ever get newsessionid during the handshake (i.e. not in a packet with fields)?
let has_fields = !self.has_flag(PacketFlag::NewSessionId);
Fields::new(if has_fields { self.body() } else { &[] })
}
pub fn body(&self) -> &[u8] {
&self.buf.as_ref()[12..]
}
}
impl<'packet_buffer> TryFrom<&'packet_buffer [u8]> for AtemPacket<'packet_buffer> {
type Error = AtemPacketErr;
fn try_from(buffer: &'packet_buffer [u8]) -> Result<Self, Self::Error> {
if buffer.len() < 12 {
return Err(AtemPacketErr::TooShort(format!(
"Invalid packet from ATEM {:x?}",
buffer
)));
}
let length = u16::from_be_bytes([buffer[0], buffer[1]]) & 0x07ff;
if length as usize != buffer.len() {
return Err(AtemPacketErr::LengthDiffers(format!(
"Length of message differs, expected {} got {}",
length,
buffer.len()
)));
}
let flags = buffer[0] >> 3;
let session_id = u16::from_be_bytes([buffer[2], buffer[3]]);
let remote_packet_id = u16::from_be_bytes([buffer[10], buffer[11]]);
let body = if buffer.len() > 12 {
Some(&buffer[12..])
} else {
None
};
let retransmit_requested_from_packet_id =
if flags & u8::from(PacketFlag::RetransmitRequest) > 0 {
Some(u16::from_be_bytes([buffer[6], buffer[7]]))
} else {
None
};
let ack_reply = if flags & u8::from(PacketFlag::AckReply) > 0 {
Some(u16::from_be_bytes([buffer[4], buffer[5]]))
} else {
None
};
Ok(AtemPacket {
flags,
session_id,
remote_packet_id,
body,
retransmit_requested_from_packet_id,
ack_reply,
})
impl<T: AsRef<[u8]>> Display for AtemPacket<T> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(
f,
"len: {}, sid: {}, flags: {}, ack: {}, localseq: {}, remoteseq: {}",
self.length(),
self.session_id(),
self.flags(),
self.ack_number(),
self.remote_sequence_number(),
self.local_sequence_number(),
)
}
}
/// An ATEM protocol field - a 4-ascii character type plus variable length data
pub struct Field<'a> {
pub r#type: &'a str,
pub data: &'a [u8],
}
/// Created by [`AtemPacket::fields`]
pub struct Fields<'a> {
data: &'a [u8],
// The offset of the next field in the packet
offset: usize,
}
impl<'a> Fields<'a> {
pub(crate) fn new(data: &'a [u8]) -> Self {
Self { data, offset: 0 }
}
}
impl<'a> Iterator for Fields<'a> {
type Item = Field<'a>;
fn next(&mut self) -> Option<Self::Item> {
let remain = self.data.len() - self.offset;
if remain == 0 {
return None;
} else if remain < 8 {
// TODO: is 8 indeed the minimum size for something here? (i.e. no field data)
panic!("Oh no");
}
let length =
u16::from_be_bytes(self.data[self.offset..=self.offset + 1].try_into().unwrap());
// TODO: sanity check length
let r#type = str::from_utf8(&self.data[self.offset + 4..=self.offset + 7]).unwrap();
let data = &self.data[self.offset + 8..self.offset + (length as usize)];
self.offset += (length as usize);
Some(Field { r#type, data })
}
}

View File

@ -7,6 +7,7 @@ use std::{
time::{Duration, SystemTime},
};
use itertools::Itertools;
use tokio::{
net::UdpSocket,
select,
@ -14,7 +15,7 @@ use tokio::{
};
use crate::{
atem_lib::{atem_packet::AtemPacket, atem_util},
atem_lib::atem_packet::{self, AtemPacket, COMMAND_CONNECT_HELLO},
commands::{
command_base::{BasicWritableCommand, DeserializedCommand},
parse_commands::deserialize_commands,
@ -258,7 +259,7 @@ impl AtemSocket {
self.in_flight = vec![];
log::debug!("Reconnect");
self.send_packet(&atem_util::COMMAND_CONNECT_HELLO).await;
self.send_packet(&COMMAND_CONNECT_HELLO).await;
self.connection_state = ConnectionState::SynSent;
Ok(())
@ -386,16 +387,21 @@ impl AtemSocket {
}
async fn recieved_packet(&mut self, packet: &[u8]) {
let Ok(atem_packet): Result<AtemPacket, _> = packet.try_into() else {
let Ok(atem_packet): Result<AtemPacket<_>, _> = packet.try_into() else {
return;
};
log::debug!("Received {:x?}", atem_packet);
log::debug!("Received {}", atem_packet,);
log::debug!(
"fields: {}",
atem_packet.fields().map(|f| f.r#type).join(",")
);
self.last_received_at = SystemTime::now();
self.session_id = atem_packet.session_id();
let remote_packet_id = atem_packet.remote_packet_id();
// TODO: naming seems rather off here
let remote_packet_id = atem_packet.local_sequence_number();
if atem_packet.has_flag(PacketFlag::NewSessionId) {
log::debug!("New session");
@ -418,9 +424,7 @@ impl AtemSocket {
self.last_received_packed_id = remote_packet_id;
self.send_or_queue_ack().await;
if let Some(body) = atem_packet.body() {
self.on_commands_received(body);
}
self.on_commands_received(atem_packet.body());
} else if self
.is_packet_covered_by_ack(self.last_received_packed_id, remote_packet_id)
{
@ -537,10 +541,12 @@ impl AtemSocket {
}
fn on_commands_received(&mut self, payload: &[u8]) {
if !payload.is_empty() {
let _ = self
.atem_event_tx
.send(AtemSocketEvent::ReceivedCommands(payload.to_vec()));
}
}
fn on_command_acknowledged(&mut self, packets: Vec<AckedPacket>) {
for ack in packets {

View File

@ -1,4 +0,0 @@
pub const COMMAND_CONNECT_HELLO: [u8; 20] = [
0x10, 0x14, 0x53, 0xab, 0x00, 0x00, 0x00, 0x00, 0x00, 0x3a, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
];

View File

@ -1,3 +1,2 @@
mod atem_packet;
pub mod atem_packet;
pub mod atem_socket;
pub mod atem_util;

View File

@ -1,6 +1,7 @@
use std::{collections::VecDeque, sync::Arc};
use crate::{
atem_lib::atem_packet::{AtemPacket, Fields},
commands::device_profile::version::{deserialize_version, DESERIALIZE_VERSION_RAW_NAME},
enums::ProtocolVersion,
};
@ -25,43 +26,28 @@ use super::{
time::{TimeDeserializer, DESERIALIZE_TIME_RAW_NAME},
};
pub fn deserialize_commands(
payload: &[u8],
pub fn deserialize_commands<T: AsRef<[u8]>>(
payload: T,
version: &mut ProtocolVersion,
) -> VecDeque<Arc<dyn DeserializedCommand>> {
let mut parsed_commands: VecDeque<Arc<dyn DeserializedCommand>> = VecDeque::new();
let mut head = 0;
while payload.len() > head + 8 {
let length = u16::from_be_bytes([payload[head], payload[head + 1]]) as usize;
let Ok(name) = String::from_utf8(payload[(head + 4)..(head + 8)].to_vec()) else {
break;
};
for field in Fields::new(payload.as_ref()) {
let name: &str = field.r#type.try_into().unwrap();
log::debug!("Received command {} with length {}", name, field.data.len(),);
if length < 8 {
break;
}
log::debug!("Received command {} with length {}", name, length);
let command_buffer = &payload[head + 8..head + length];
if name == DESERIALIZE_VERSION_RAW_NAME {
let version_command = deserialize_version(command_buffer);
if field.r#type == DESERIALIZE_VERSION_RAW_NAME {
let version_command = deserialize_version(field.data);
*version = version_command.version.clone();
log::info!("Switched to protocol version {}", version);
parsed_commands.push_back(Arc::new(version_command));
} else if let Some(deserializer) = command_deserializer_from_string(name.as_str()) {
let deserialized_command = deserializer.deserialize(command_buffer, version);
} else if let Some(deserializer) = command_deserializer_from_string(name) {
let deserialized_command = deserializer.deserialize(field.data, version);
log::debug!("Received {:?}", deserialized_command);
parsed_commands.push_back(deserialized_command);
} else {
log::warn!("Received command {name} for which there is no deserializer.");
// TODO: Remove!
todo!("Write deserializer for {name}.");
}
head += length;
}
parsed_commands