diff --git a/Cargo.lock b/Cargo.lock index e9e4d12..2273f83 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -80,6 +80,7 @@ version = "0.1.0" dependencies = [ "derive-getters", "derive-new", + "enumflags2", "itertools", "log", "tokio", @@ -256,6 +257,26 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +[[package]] +name = "enumflags2" +version = "0.7.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1027f7680c853e056ebcec683615fb6fbbc07dbaa13b4d5d9442b146ded4ecef" +dependencies = [ + "enumflags2_derive", +] + +[[package]] +name = "enumflags2_derive" +version = "0.7.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67c78a4d8fdf9953a5c9d458f9efe940fd97a0cab0941c075a813ac594733827" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.48", +] + [[package]] name = "env_logger" version = "0.9.0" diff --git a/atem-connection-rs/Cargo.toml b/atem-connection-rs/Cargo.toml index 0eae179..ffc5820 100644 --- a/atem-connection-rs/Cargo.toml +++ b/atem-connection-rs/Cargo.toml @@ -10,6 +10,7 @@ itertools = {version = "0.14.0"} log = "0.4.14" tokio = { version = "1.13.0", features = ["full"], optional = true } tokio-util = { version = "0.7.10", optional = true } +enumflags2 = { version = "0.7.12", default-features = false } [features] default = ["std"] diff --git a/atem-connection-rs/src/atem_lib/atem_packet.rs b/atem-connection-rs/src/atem_lib/atem_packet.rs index 832ef11..94cd59f 100755 --- a/atem-connection-rs/src/atem_lib/atem_packet.rs +++ b/atem-connection-rs/src/atem_lib/atem_packet.rs @@ -1,4 +1,5 @@ use core::{fmt::Display, str}; +use enumflags2::{bitflags, BitFlags}; use super::atem_field::{Field, FieldParsingError, RawField}; @@ -23,30 +24,26 @@ impl<'a> TryFrom<&'a [u8]> for AtemPacket<&'a [u8]> { #[derive(Debug)] pub enum AtemPacketErr { /// The packet was too short - TooShort { got: usize }, + TooShort { + got: usize, + }, /// The packet's stated and actual lengths were different - LengthDiffers { expected: u16, got: usize }, + LengthDiffers { + expected: u16, + got: usize, + }, + InvalidFlags, } -#[derive(PartialEq)] +#[bitflags] +#[repr(u8)] +#[derive(PartialEq, Copy, Clone, Debug)] pub enum PacketFlag { - AckRequest, - NewSessionId, - IsRetransmit, - RetransmitRequest, - AckReply, -} - -impl From for u8 { - fn from(flag: PacketFlag) -> Self { - match flag { - PacketFlag::AckRequest => 0x01, - PacketFlag::NewSessionId => 0x02, - PacketFlag::IsRetransmit => 0x04, - PacketFlag::RetransmitRequest => 0x08, - PacketFlag::AckReply => 0x10, - } - } + AckRequest = 0x1, + NewSessionId = 0x2, + IsRetransmit = 0x4, + RetransmitRequest = 0x8, + AckReply = 0x10, } impl> AtemPacket { @@ -64,16 +61,26 @@ impl> AtemPacket { got: len, }); } + // Check flags are valid + let _: BitFlags = p + .flags_raw() + .try_into() + .map_err(|_| AtemPacketErr::InvalidFlags)?; Ok(p) } pub fn length(&self) -> u16 { u16::from_be_bytes(self.buf.as_ref()[0..=1].try_into().unwrap()) & 0x07ff } - pub fn flags(&self) -> u8 { + fn flags_raw(&self) -> u8 { self.buf.as_ref()[0] >> 3 } + pub fn flags(&self) -> BitFlags { + // We `unwrap` here, but given we check the flags in the constructor this should never panic. + self.flags_raw().try_into().unwrap() + } + pub fn session_id(&self) -> u16 { u16::from_be_bytes(self.buf.as_ref()[2..=3].try_into().unwrap()) } @@ -91,7 +98,8 @@ impl> AtemPacket { } pub fn retransmit_request(&self) -> Option { - self.has_flag(PacketFlag::RetransmitRequest) + self.flags() + .contains(PacketFlag::RetransmitRequest) .then_some(u16::from_be_bytes([ self.buf.as_ref()[6], self.buf.as_ref()[7], @@ -99,21 +107,17 @@ impl> AtemPacket { } pub fn ack_reply(&self) -> Option { - self.has_flag(PacketFlag::AckReply) + self.flags() + .contains(PacketFlag::AckReply) .then_some(self.ack_number()) } - /// Return true if this packet has the given [`PacketFlag`] - pub fn has_flag(&self, flag: PacketFlag) -> bool { - self.flags() & u8::from(flag) > 0 - } - /// Get an iterator over the `Field`s in this packet. /// /// Returns None if this is a packet without fields. pub fn raw_fields(&self) -> RawFields { // TODO: do we only ever get newsessionid during the handshake (i.e. not in a packet with fields)? - let has_fields = !self.has_flag(PacketFlag::NewSessionId); + let has_fields = !self.flags().contains(PacketFlag::NewSessionId); RawFields::new(if has_fields { self.body() } else { &[] }) } diff --git a/atem-connection-rs/src/atem_lib/atem_socket.rs b/atem-connection-rs/src/atem_lib/atem_socket.rs index a006a77..06b94cb 100644 --- a/atem-connection-rs/src/atem_lib/atem_socket.rs +++ b/atem-connection-rs/src/atem_lib/atem_socket.rs @@ -300,7 +300,7 @@ impl AtemSocket { self.next_send_packet_id = 0; } - let opcode = u16::from(u8::from(PacketFlag::AckRequest)) << 11; + let opcode = u16::from(PacketFlag::AckRequest as u8) << 11; let mut buffer = vec![0; 20 + payload.len()]; @@ -403,7 +403,7 @@ impl AtemSocket { // TODO: naming seems rather off here let remote_packet_id = atem_packet.local_sequence_number(); - if atem_packet.has_flag(PacketFlag::NewSessionId) { + if atem_packet.flags().contains(PacketFlag::NewSessionId) { log::debug!("New session"); self.connection_state = ConnectionState::Established; self.last_received_packed_id = remote_packet_id; @@ -419,7 +419,7 @@ impl AtemSocket { self.retransmit_from(from_packet_id).await; } - if atem_packet.has_flag(PacketFlag::AckRequest) { + if atem_packet.flags().contains(PacketFlag::AckRequest) { if remote_packet_id == (self.last_received_packed_id + 1) % MAX_PACKET_ID { self.last_received_packed_id = remote_packet_id; self.send_or_queue_ack().await; @@ -432,7 +432,7 @@ impl AtemSocket { } } - if atem_packet.has_flag(PacketFlag::IsRetransmit) { + if atem_packet.flags().contains(PacketFlag::IsRetransmit) { log::debug!("ATEM retransmitted packet {:x?}", remote_packet_id); } @@ -482,7 +482,7 @@ impl AtemSocket { async fn send_ack(&mut self, packet_id: u16) { log::debug!("Sending ack for packet {:x?}", packet_id); - let flag: u8 = PacketFlag::AckReply.into(); + let flag: u8 = PacketFlag::AckReply as u8; let opcode = u16::from(flag) << 11; let mut buffer: [u8; ACK_PACKET_LENGTH as _] = [0; 12]; buffer[0..2].copy_from_slice(&u16::to_be_bytes(opcode | ACK_PACKET_LENGTH));