diff --git a/Cargo.lock b/Cargo.lock index 4422846..e9e4d12 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "addr2line" @@ -80,6 +80,7 @@ version = "0.1.0" dependencies = [ "derive-getters", "derive-new", + "itertools", "log", "tokio", "tokio-util", @@ -249,6 +250,12 @@ dependencies = [ "syn 2.0.48", ] +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + [[package]] name = "env_logger" version = "0.9.0" @@ -323,6 +330,15 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683" +[[package]] +name = "itertools" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" +dependencies = [ + "either", +] + [[package]] name = "lazy_static" version = "1.4.0" diff --git a/atem-connection-rs/Cargo.toml b/atem-connection-rs/Cargo.toml index 58dc8eb..0eae179 100644 --- a/atem-connection-rs/Cargo.toml +++ b/atem-connection-rs/Cargo.toml @@ -6,6 +6,12 @@ edition = "2021" [dependencies] derive-getters = "0.2.0" derive-new = "0.6.0" +itertools = {version = "0.14.0"} log = "0.4.14" -tokio = { version = "1.13.0", features = ["full"] } -tokio-util = "0.7.10" +tokio = { version = "1.13.0", features = ["full"], optional = true } +tokio-util = { version = "0.7.10", optional = true } + +[features] +default = ["std"] + +std = ["dep:tokio", "dep:tokio-util"] diff --git a/atem-connection-rs/src/atem_lib/atem_field.rs b/atem-connection-rs/src/atem_lib/atem_field.rs new file mode 100644 index 0000000..1fa07d7 --- /dev/null +++ b/atem-connection-rs/src/atem_lib/atem_field.rs @@ -0,0 +1,88 @@ +//! Definitions and decoding of ATEM protocol fields + +/// An uninterpreted ATEM protocol field - a 4-ascii character type plus variable length data +pub struct RawField<'a> { + pub r#type: &'a [u8; 4], + pub data: &'a [u8], +} + +#[derive(Debug)] +pub struct _Ver { + pub major: u16, + pub minor: u16, +} + +impl<'a> Field<'a> for _Ver { + const TYPE: [u8; 4] = [b'_', b'v', b'e', b'r']; + + fn decode(data: &'a [u8]) -> Result { + let data = checked_len::<4>(data)?; + Ok(Self { + major: u16::from_be_bytes(data[0..=1].try_into().unwrap()), + minor: u16::from_be_bytes(data[2..=3].try_into().unwrap()), + }) + } +} + +#[derive(Debug)] +pub struct PrvI { + pub m_e_index: u8, + pub source_index: u16, + pub pvw_in_pgm: bool, +} + +impl<'a> Field<'a> for PrvI { + const TYPE: [u8; 4] = [b'P', b'r', b'v', b'I']; + fn decode(data: &'a [u8]) -> Result { + let data = checked_len::<8>(data)?; + Ok(Self { + m_e_index: data[0], + source_index: u16::from_be_bytes(data[2..=3].try_into().unwrap()), + pvw_in_pgm: data[4] != 0, + }) + } +} + +#[derive(Debug)] +pub struct PrgI { + pub m_e_index: u8, + pub source_index: u16, +} + +impl<'a> Field<'a> for PrgI { + const TYPE: [u8; 4] = [b'P', b'r', b'g', b'I']; + fn decode(data: &'a [u8]) -> Result { + let data = checked_len::<4>(data)?; + Ok(Self { + m_e_index: data[0], + source_index: u16::from_be_bytes(data[2..=3].try_into().unwrap()), + }) + } +} + +pub trait Field<'a>: Sized { + const TYPE: [u8; 4]; + fn decode(data: &'a [u8]) -> Result; + fn try_from_raw(raw: RawField<'a>) -> Result { + if Self::TYPE != *raw.r#type { + Err(FieldParsingError::MismatchedFieldType) + } else { + Self::decode(&raw.data) + } + } +} + +#[derive(Debug)] +pub enum FieldParsingError { + UnexpectedLength { expected: usize, got: usize }, + UnknownFieldType { r#type: [u8; 4] }, + MismatchedFieldType, +} + +fn checked_len<'a, const LEN: usize>(data: &'a [u8]) -> Result<&'a [u8; LEN], FieldParsingError> { + data.try_into() + .map_err(|_| FieldParsingError::UnexpectedLength { + expected: LEN, + got: data.len(), + }) +} diff --git a/atem-connection-rs/src/atem_lib/atem_packet.rs b/atem-connection-rs/src/atem_lib/atem_packet.rs index deeb5b9..2d3cfa8 100755 --- a/atem-connection-rs/src/atem_lib/atem_packet.rs +++ b/atem-connection-rs/src/atem_lib/atem_packet.rs @@ -1,16 +1,28 @@ +use core::{fmt::Display, str}; + +/// The "hello" packet to start communication with the ATEM +pub const COMMAND_CONNECT_HELLO: [u8; 20] = [ + 0x10, 0x14, 0x53, 0xab, 0x00, 0x00, 0x00, 0x00, 0x00, 0x3a, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, +]; + #[derive(Debug)] -pub struct AtemPacket<'packet_buffer> { - flags: u8, - session_id: u16, - remote_packet_id: u16, - retransmit_requested_from_packet_id: Option, - ack_reply: Option, - body: Option<&'packet_buffer [u8]>, +pub struct AtemPacket> { + buf: T, } +impl<'a> From<&'a [u8]> for AtemPacket<&'a [u8]> { + fn from(buf: &'a [u8]) -> Self { + AtemPacket { buf } + } +} + +#[derive(Debug)] pub enum AtemPacketErr { - TooShort(String), - LengthDiffers(String), + /// The packet was too short + TooShort { got: usize }, + /// The packet's stated and actual lengths were different + LengthDiffers { expected: u16, got: usize }, } #[derive(PartialEq)] @@ -34,82 +46,133 @@ impl From for u8 { } } -impl<'packet_buffer> AtemPacket<'packet_buffer> { +impl> AtemPacket { + pub fn new_checked(buf: T) -> Result { + let len = buf.as_ref().len(); + if len < 12 { + return Err(AtemPacketErr::TooShort { + got: buf.as_ref().len(), + }); + } + let p = Self { buf }; + if p.length() as usize != len { + return Err(AtemPacketErr::LengthDiffers { + expected: p.length(), + got: len, + }); + } + Ok(p) + } + pub fn length(&self) -> u16 { + u16::from_be_bytes(self.buf.as_ref()[0..=1].try_into().unwrap()) & 0x07ff + } + + pub fn flags(&self) -> u8 { + self.buf.as_ref()[0] >> 3 + } + pub fn session_id(&self) -> u16 { - self.session_id + u16::from_be_bytes(self.buf.as_ref()[2..=3].try_into().unwrap()) } - pub fn remote_packet_id(&self) -> u16 { - self.remote_packet_id + pub fn ack_number(&self) -> u16 { + u16::from_be_bytes(self.buf.as_ref()[4..=5].try_into().unwrap()) } - pub fn body(&self) -> Option<&[u8]> { - self.body + pub fn remote_sequence_number(&self) -> u16 { + u16::from_be_bytes(self.buf.as_ref()[9..=10].try_into().unwrap()) + } + + pub fn local_sequence_number(&self) -> u16 { + u16::from_be_bytes(self.buf.as_ref()[10..=11].try_into().unwrap()) } pub fn retransmit_request(&self) -> Option { - self.retransmit_requested_from_packet_id + self.has_flag(PacketFlag::RetransmitRequest) + .then_some(u16::from_be_bytes([ + self.buf.as_ref()[6], + self.buf.as_ref()[7], + ])) } pub fn ack_reply(&self) -> Option { - self.ack_reply + self.has_flag(PacketFlag::AckReply) + .then_some(u16::from_be_bytes([ + self.buf.as_ref()[4], + self.buf.as_ref()[5], + ])) } + /// Return true if this packet has the given [`PacketFlag`] pub fn has_flag(&self, flag: PacketFlag) -> bool { - self.flags & u8::from(flag) > 0 + self.flags() & u8::from(flag) > 0 + } + + /// Get an iterator over the `Field`s in this packet. + /// + /// Returns None if this is a packet without fields. + pub fn fields(&self) -> Fields { + // TODO: do we only ever get newsessionid during the handshake (i.e. not in a packet with fields)? + let has_fields = !self.has_flag(PacketFlag::NewSessionId); + Fields::new(if has_fields { self.body() } else { &[] }) + } + + pub fn body(&self) -> &[u8] { + &self.buf.as_ref()[12..] } } -impl<'packet_buffer> TryFrom<&'packet_buffer [u8]> for AtemPacket<'packet_buffer> { - type Error = AtemPacketErr; - - fn try_from(buffer: &'packet_buffer [u8]) -> Result { - if buffer.len() < 12 { - return Err(AtemPacketErr::TooShort(format!( - "Invalid packet from ATEM {:x?}", - buffer - ))); - } - - let length = u16::from_be_bytes([buffer[0], buffer[1]]) & 0x07ff; - if length as usize != buffer.len() { - return Err(AtemPacketErr::LengthDiffers(format!( - "Length of message differs, expected {} got {}", - length, - buffer.len() - ))); - } - - let flags = buffer[0] >> 3; - let session_id = u16::from_be_bytes([buffer[2], buffer[3]]); - let remote_packet_id = u16::from_be_bytes([buffer[10], buffer[11]]); - - let body = if buffer.len() > 12 { - Some(&buffer[12..]) - } else { - None - }; - - let retransmit_requested_from_packet_id = - if flags & u8::from(PacketFlag::RetransmitRequest) > 0 { - Some(u16::from_be_bytes([buffer[6], buffer[7]])) - } else { - None - }; - - let ack_reply = if flags & u8::from(PacketFlag::AckReply) > 0 { - Some(u16::from_be_bytes([buffer[4], buffer[5]])) - } else { - None - }; - - Ok(AtemPacket { - flags, - session_id, - remote_packet_id, - body, - retransmit_requested_from_packet_id, - ack_reply, - }) +impl> Display for AtemPacket { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!( + f, + "len: {}, sid: {}, flags: {}, ack: {}, localseq: {}, remoteseq: {}", + self.length(), + self.session_id(), + self.flags(), + self.ack_number(), + self.remote_sequence_number(), + self.local_sequence_number(), + ) + } +} + +/// An ATEM protocol field - a 4-ascii character type plus variable length data +pub struct Field<'a> { + pub r#type: &'a str, + pub data: &'a [u8], +} + +/// Created by [`AtemPacket::fields`] +pub struct Fields<'a> { + data: &'a [u8], + // The offset of the next field in the packet + offset: usize, +} + +impl<'a> Fields<'a> { + pub(crate) fn new(data: &'a [u8]) -> Self { + Self { data, offset: 0 } + } +} + +impl<'a> Iterator for Fields<'a> { + type Item = Field<'a>; + fn next(&mut self) -> Option { + let remain = self.data.len() - self.offset; + if remain == 0 { + return None; + } else if remain < 8 { + // TODO: is 8 indeed the minimum size for something here? (i.e. no field data) + panic!("Oh no"); + } + + let length = + u16::from_be_bytes(self.data[self.offset..=self.offset + 1].try_into().unwrap()); + // TODO: sanity check length + let r#type = str::from_utf8(&self.data[self.offset + 4..=self.offset + 7]).unwrap(); + let data = &self.data[self.offset + 8..self.offset + (length as usize)]; + self.offset += (length as usize); + Some(Field { r#type, data }) } } diff --git a/atem-connection-rs/src/atem_lib/atem_socket.rs b/atem-connection-rs/src/atem_lib/atem_socket.rs index 4945b54..e96529b 100644 --- a/atem-connection-rs/src/atem_lib/atem_socket.rs +++ b/atem-connection-rs/src/atem_lib/atem_socket.rs @@ -7,6 +7,7 @@ use std::{ time::{Duration, SystemTime}, }; +use itertools::Itertools; use tokio::{ net::UdpSocket, select, @@ -14,7 +15,7 @@ use tokio::{ }; use crate::{ - atem_lib::{atem_packet::AtemPacket, atem_util}, + atem_lib::atem_packet::{self, AtemPacket, COMMAND_CONNECT_HELLO}, commands::{ command_base::{BasicWritableCommand, DeserializedCommand}, parse_commands::deserialize_commands, @@ -258,7 +259,7 @@ impl AtemSocket { self.in_flight = vec![]; log::debug!("Reconnect"); - self.send_packet(&atem_util::COMMAND_CONNECT_HELLO).await; + self.send_packet(&COMMAND_CONNECT_HELLO).await; self.connection_state = ConnectionState::SynSent; Ok(()) @@ -386,16 +387,21 @@ impl AtemSocket { } async fn recieved_packet(&mut self, packet: &[u8]) { - let Ok(atem_packet): Result = packet.try_into() else { + let Ok(atem_packet): Result, _> = packet.try_into() else { return; }; - log::debug!("Received {:x?}", atem_packet); + log::debug!("Received {}", atem_packet,); + log::debug!( + "fields: {}", + atem_packet.fields().map(|f| f.r#type).join(",") + ); self.last_received_at = SystemTime::now(); self.session_id = atem_packet.session_id(); - let remote_packet_id = atem_packet.remote_packet_id(); + // TODO: naming seems rather off here + let remote_packet_id = atem_packet.local_sequence_number(); if atem_packet.has_flag(PacketFlag::NewSessionId) { log::debug!("New session"); @@ -418,9 +424,7 @@ impl AtemSocket { self.last_received_packed_id = remote_packet_id; self.send_or_queue_ack().await; - if let Some(body) = atem_packet.body() { - self.on_commands_received(body); - } + self.on_commands_received(atem_packet.body()); } else if self .is_packet_covered_by_ack(self.last_received_packed_id, remote_packet_id) { @@ -537,9 +541,11 @@ impl AtemSocket { } fn on_commands_received(&mut self, payload: &[u8]) { - let _ = self - .atem_event_tx - .send(AtemSocketEvent::ReceivedCommands(payload.to_vec())); + if !payload.is_empty() { + let _ = self + .atem_event_tx + .send(AtemSocketEvent::ReceivedCommands(payload.to_vec())); + } } fn on_command_acknowledged(&mut self, packets: Vec) { diff --git a/atem-connection-rs/src/atem_lib/atem_util.rs b/atem-connection-rs/src/atem_lib/atem_util.rs deleted file mode 100644 index 09b970a..0000000 --- a/atem-connection-rs/src/atem_lib/atem_util.rs +++ /dev/null @@ -1,4 +0,0 @@ -pub const COMMAND_CONNECT_HELLO: [u8; 20] = [ - 0x10, 0x14, 0x53, 0xab, 0x00, 0x00, 0x00, 0x00, 0x00, 0x3a, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, -]; diff --git a/atem-connection-rs/src/atem_lib/mod.rs b/atem-connection-rs/src/atem_lib/mod.rs index 9b3090b..388a4e1 100644 --- a/atem-connection-rs/src/atem_lib/mod.rs +++ b/atem-connection-rs/src/atem_lib/mod.rs @@ -1,3 +1,2 @@ -mod atem_packet; +pub mod atem_packet; pub mod atem_socket; -pub mod atem_util; diff --git a/atem-connection-rs/src/commands/parse_commands.rs b/atem-connection-rs/src/commands/parse_commands.rs index 1c51d1d..e32386e 100644 --- a/atem-connection-rs/src/commands/parse_commands.rs +++ b/atem-connection-rs/src/commands/parse_commands.rs @@ -1,6 +1,7 @@ use std::{collections::VecDeque, sync::Arc}; use crate::{ + atem_lib::atem_packet::{AtemPacket, Fields}, commands::device_profile::version::{deserialize_version, DESERIALIZE_VERSION_RAW_NAME}, enums::ProtocolVersion, }; @@ -25,43 +26,28 @@ use super::{ time::{TimeDeserializer, DESERIALIZE_TIME_RAW_NAME}, }; -pub fn deserialize_commands( - payload: &[u8], +pub fn deserialize_commands>( + payload: T, version: &mut ProtocolVersion, ) -> VecDeque> { let mut parsed_commands: VecDeque> = VecDeque::new(); - let mut head = 0; - while payload.len() > head + 8 { - let length = u16::from_be_bytes([payload[head], payload[head + 1]]) as usize; - let Ok(name) = String::from_utf8(payload[(head + 4)..(head + 8)].to_vec()) else { - break; - }; + for field in Fields::new(payload.as_ref()) { + let name: &str = field.r#type.try_into().unwrap(); + log::debug!("Received command {} with length {}", name, field.data.len(),); - if length < 8 { - break; - } - - log::debug!("Received command {} with length {}", name, length); - - let command_buffer = &payload[head + 8..head + length]; - - if name == DESERIALIZE_VERSION_RAW_NAME { - let version_command = deserialize_version(command_buffer); + if field.r#type == DESERIALIZE_VERSION_RAW_NAME { + let version_command = deserialize_version(field.data); *version = version_command.version.clone(); log::info!("Switched to protocol version {}", version); parsed_commands.push_back(Arc::new(version_command)); - } else if let Some(deserializer) = command_deserializer_from_string(name.as_str()) { - let deserialized_command = deserializer.deserialize(command_buffer, version); + } else if let Some(deserializer) = command_deserializer_from_string(name) { + let deserialized_command = deserializer.deserialize(field.data, version); log::debug!("Received {:?}", deserialized_command); parsed_commands.push_back(deserialized_command); } else { log::warn!("Received command {name} for which there is no deserializer."); - // TODO: Remove! - todo!("Write deserializer for {name}."); } - - head += length; } parsed_commands