feat: Handle version command

This commit is contained in:
Baud 2024-03-19 20:31:25 +00:00
parent d2f41821c0
commit 25460d77cd
11 changed files with 150 additions and 140 deletions

View File

@ -1,20 +1,26 @@
use std::{
collections::{HashMap, VecDeque},
net::SocketAddr,
ops::DerefMut,
sync::Arc,
time::Duration,
};
use tokio::{select, sync::Semaphore};
use tokio_util::sync::CancellationToken;
use crate::{
atem_lib::atem_socket::{AtemEvent, AtemSocketCommand, AtemSocketMessage, TrackingId},
atem_lib::atem_socket::{
AtemSocket, AtemSocketCommand, AtemSocketEvent, AtemSocketMessage, TrackingId,
},
commands::{
command_base::{BasicWritableCommand, DeserializedCommand},
device_profile::version::DESERIALIZE_VERSION_RAW_NAME,
init_complete::DESERIALIZE_INIT_COMPLETE_RAW_NAME,
parse_commands::deserialize_commands,
time::DESERIALIZE_TIME_RAW_NAME,
},
enums::ProtocolVersion,
state::AtemState,
};
@ -27,13 +33,24 @@ pub enum AtemConnectionStatus {
}
pub struct Atem {
protocol_version: tokio::sync::RwLock<ProtocolVersion>,
socket: tokio::sync::RwLock<AtemSocket>,
waiting_semaphores: tokio::sync::RwLock<HashMap<TrackingId, Arc<Semaphore>>>,
socket_message_tx: tokio::sync::mpsc::Sender<AtemSocketMessage>,
}
impl Atem {
pub fn new(socket_message_tx: tokio::sync::mpsc::Sender<AtemSocketMessage>) -> Self {
pub fn new(
socket: AtemSocket,
socket_message_tx: tokio::sync::mpsc::Sender<AtemSocketMessage>,
) -> Self {
Self {
protocol_version: tokio::sync::RwLock::new(ProtocolVersion::V7_2),
socket: tokio::sync::RwLock::new(socket),
waiting_semaphores: tokio::sync::RwLock::new(HashMap::new()),
socket_message_tx,
}
@ -54,25 +71,30 @@ impl Atem {
pub async fn run(
&self,
mut atem_event_rx: tokio::sync::mpsc::UnboundedReceiver<AtemEvent>,
mut atem_event_rx: tokio::sync::mpsc::UnboundedReceiver<AtemSocketEvent>,
cancel: CancellationToken,
) {
let mut status = AtemConnectionStatus::default();
let mut state = AtemState::default();
let mut poll_interval = tokio::time::interval(Duration::from_millis(5));
while !cancel.is_cancelled() {
let tick = poll_interval.tick();
select! {
_ = cancel.cancelled() => {},
_ = tick => {},
message = atem_event_rx.recv() => match message {
Some(event) => match event {
AtemEvent::Connected => {
AtemSocketEvent::Connected => {
log::info!("Atem connected");
}
AtemEvent::Disconnected => todo!("Disconnected"),
AtemEvent::ReceivedCommands(commands) => {
AtemSocketEvent::Disconnected => todo!("Disconnected"),
AtemSocketEvent::ReceivedCommands(payload) => {
let commands = deserialize_commands(&payload, self.protocol_version.write().await.deref_mut());
self.mutate_state(&mut state, &mut status, commands).await
}
AtemEvent::AckedCommand(tracking_id) => {
AtemSocketEvent::AckedCommand(tracking_id) => {
log::debug!("Received tracking Id {tracking_id}");
if let Some(semaphore) =
self.waiting_semaphores.read().await.get(&tracking_id)
@ -89,18 +111,19 @@ impl Atem {
}
}
}
self.socket.write().await.poll().await;
}
}
pub async fn send_commands(&self, commands: Vec<Box<dyn BasicWritableCommand>>) {
let protocol_version = { self.protocol_version.read().await.clone() };
let (callback_tx, callback_rx) = tokio::sync::oneshot::channel();
self.socket_message_tx
.send(AtemSocketMessage::SendCommands {
commands: commands
.iter()
.map(|command| {
AtemSocketCommand::new(command, &crate::enums::ProtocolVersion::Unknown)
})
.map(|command| AtemSocketCommand::new(command, &protocol_version))
.collect(),
tracking_ids_callback: callback_tx,
})

View File

@ -54,10 +54,10 @@ pub struct TrackingIdsCallback {
}
#[derive(Clone)]
pub enum AtemEvent {
pub enum AtemSocketEvent {
Connected,
Disconnected,
ReceivedCommands(VecDeque<Arc<dyn DeserializedCommand>>),
ReceivedCommands(Vec<u8>),
AckedCommand(TrackingId),
}
@ -111,8 +111,11 @@ pub struct AtemSocket {
ack_timer: Option<SystemTime>,
received_without_ack: u16,
atem_event_tx: tokio::sync::mpsc::UnboundedSender<AtemEvent>,
atem_message_rx: tokio::sync::mpsc::Receiver<AtemSocketMessage>,
atem_event_tx: tokio::sync::mpsc::UnboundedSender<AtemSocketEvent>,
connected_callbacks: Mutex<Vec<tokio::sync::oneshot::Sender<bool>>>,
tick_interval: tokio::time::Interval,
}
#[derive(PartialEq, Clone)]
@ -147,7 +150,11 @@ enum AtemSocketReceiveError {
}
impl AtemSocket {
pub fn new(atem_event_tx: tokio::sync::mpsc::UnboundedSender<AtemEvent>) -> Self {
pub fn new(
atem_message_rx: tokio::sync::mpsc::Receiver<AtemSocketMessage>,
atem_event_tx: tokio::sync::mpsc::UnboundedSender<AtemSocketEvent>,
) -> Self {
let tick_interval = tokio::time::interval(Duration::from_millis(5));
Self {
connection_state: ConnectionState::Closed,
reconnect_timer: None,
@ -169,23 +176,19 @@ impl AtemSocket {
ack_timer: None,
received_without_ack: 0,
atem_message_rx,
atem_event_tx,
connected_callbacks: Mutex::default(),
tick_interval,
}
}
pub async fn run(
&mut self,
mut atem_message_rx: tokio::sync::mpsc::Receiver<AtemSocketMessage>,
cancel: tokio_util::sync::CancellationToken,
) {
let mut interval = tokio::time::interval(Duration::from_millis(5));
while !cancel.is_cancelled() {
let tick = interval.tick();
pub async fn poll(&mut self) {
let tick = self.tick_interval.tick();
select! {
_ = cancel.cancelled() => {},
_ = tick => {},
message = atem_message_rx.recv() => {
message = self.atem_message_rx.recv() => {
match message {
Some(AtemSocketMessage::Connect {
address,
@ -234,8 +237,7 @@ impl AtemSocket {
barrier.wait().await;
},
None => {
log::info!("ATEM message channel has closed, exiting event loop.");
cancel.cancel();
log::info!("ATEM message channel has closed.");
}
}
}
@ -243,7 +245,6 @@ impl AtemSocket {
self.tick().await;
}
}
pub async fn connect(&mut self, address: SocketAddr) -> Result<(), io::Error> {
let socket = UdpSocket::bind("0.0.0.0:0").await?;
@ -536,23 +537,21 @@ impl AtemSocket {
}
fn on_commands_received(&mut self, payload: &[u8]) {
let commands = deserialize_commands(payload, &self.protocol_version);
let _ = self
.atem_event_tx
.send(AtemEvent::ReceivedCommands(commands));
.send(AtemSocketEvent::ReceivedCommands(payload.to_vec()));
}
fn on_command_acknowledged(&mut self, packets: Vec<AckedPacket>) {
for ack in packets {
let _ = self
.atem_event_tx
.send(AtemEvent::AckedCommand(TrackingId(ack.tracking_id)));
.send(AtemSocketEvent::AckedCommand(TrackingId(ack.tracking_id)));
}
}
async fn on_connect(&mut self) {
let _ = self.atem_event_tx.send(AtemEvent::Connected);
let _ = self.atem_event_tx.send(AtemSocketEvent::Connected);
let mut connected_callbacks = self.connected_callbacks.lock().await;
for callback in connected_callbacks.drain(0..) {
let _ = callback.send(false);
@ -560,7 +559,7 @@ impl AtemSocket {
}
fn on_disconnect(&mut self) {
let _ = self.atem_event_tx.send(AtemEvent::Disconnected);
let _ = self.atem_event_tx.send(AtemSocketEvent::Disconnected);
}
fn start_timers(&mut self) {

View File

@ -1,4 +1,4 @@
use std::{collections::HashMap, fmt::Debug, sync::Arc};
use std::{collections::HashMap, fmt::Debug, process::Command, sync::Arc};
use crate::{enums::ProtocolVersion, state::AtemState};

View File

@ -8,12 +8,12 @@ use crate::{
pub const DESERIALIZE_PRODUCT_IDENTIFIER_RAW_NAME: &str = "_pin";
#[derive(Debug)]
pub struct ProductIdentifierCommand {
pub struct ProductIdentifier {
pub product_identifier: String,
pub model: Model,
}
impl DeserializedCommand for ProductIdentifierCommand {
impl DeserializedCommand for ProductIdentifier {
fn raw_name(&self) -> &'static str {
DESERIALIZE_PRODUCT_IDENTIFIER_RAW_NAME
}
@ -40,14 +40,14 @@ impl DeserializedCommand for ProductIdentifierCommand {
}
#[derive(Default)]
pub struct ProductIdentifierCommandDeserializer {}
pub struct ProductIdentifierDeserializer {}
impl CommandDeserializer for ProductIdentifierCommandDeserializer {
impl CommandDeserializer for ProductIdentifierDeserializer {
fn deserialize(
&self,
buffer: &[u8],
version: &ProtocolVersion,
) -> std::sync::Arc<dyn DeserializedCommand> {
) -> Arc<dyn DeserializedCommand> {
let null_byte_index = buffer
.iter()
.position(|byte| *byte == b'\0')
@ -57,7 +57,7 @@ impl CommandDeserializer for ProductIdentifierCommandDeserializer {
.expect("Malformed string");
let model = buffer[40];
Arc::new(ProductIdentifierCommand {
Arc::new(ProductIdentifier {
product_identifier: product_identifier
.to_str()
.expect("Invalid rust string")

View File

@ -8,7 +8,7 @@ use crate::{
pub const DESERIALIZE_TOPOLOGY_RAW_NAME: &str = "_top";
#[derive(Debug)]
pub struct TopologyCommand {
pub struct Topology {
mix_effects: u8,
sources: u8,
auxilliaries: u8,
@ -27,7 +27,7 @@ pub struct TopologyCommand {
only_configurable_outputs: bool,
}
impl DeserializedCommand for TopologyCommand {
impl DeserializedCommand for Topology {
fn raw_name(&self) -> &'static str {
todo!()
}
@ -44,7 +44,7 @@ impl CommandDeserializer for TopologyCommandDeserializer {
&self,
buffer: &[u8],
version: &ProtocolVersion,
) -> std::sync::Arc<dyn crate::commands::command_base::DeserializedCommand> {
) -> Arc<dyn DeserializedCommand> {
let v230offset = if *version > ProtocolVersion::V8_0_1 {
1
} else {
@ -69,7 +69,7 @@ impl CommandDeserializer for TopologyCommandDeserializer {
false
};
Arc::new(TopologyCommand {
Arc::new(Topology {
mix_effects: buffer[0],
sources: buffer[1],
downstream_keyers: buffer[2],

View File

@ -1,39 +1,25 @@
use std::sync::Arc;
use crate::{
commands::command_base::{CommandDeserializer, DeserializedCommand},
enums::ProtocolVersion,
};
use crate::{commands::command_base::DeserializedCommand, enums::ProtocolVersion};
pub const DESERIALIZE_VERSION_RAW_NAME: &str = "_ver";
#[derive(Debug)]
pub struct VersionCommand {
pub struct Version {
pub version: ProtocolVersion,
}
impl DeserializedCommand for VersionCommand {
impl DeserializedCommand for Version {
fn raw_name(&self) -> &'static str {
DESERIALIZE_VERSION_RAW_NAME
}
fn apply_to_state(&self, state: &mut crate::state::AtemState) {
state.info.api_version = self.version;
state.info.api_version = self.version.clone();
}
}
#[derive(Default)]
pub struct VersionCommandDeserializer {}
impl CommandDeserializer for VersionCommandDeserializer {
fn deserialize(
&self,
buffer: &[u8],
version: &ProtocolVersion,
) -> std::sync::Arc<dyn DeserializedCommand> {
pub fn deserialize_version(buffer: &[u8]) -> Version {
let version = u32::from_be_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]);
let version: ProtocolVersion = version.try_into().expect("Invalid protocol version");
Arc::new(VersionCommand { version })
}
Version { version }
}

View File

@ -25,7 +25,7 @@ impl CommandDeserializer for InitCompleteDeserializer {
&self,
_buffer: &[u8],
version: &ProtocolVersion,
) -> std::sync::Arc<dyn DeserializedCommand> {
) -> Arc<dyn DeserializedCommand> {
Arc::new(InitComplete {})
}
}

View File

@ -1,14 +1,14 @@
use std::{collections::VecDeque, sync::Arc};
use crate::enums::ProtocolVersion;
use crate::{
commands::device_profile::version::{deserialize_version, DESERIALIZE_VERSION_RAW_NAME},
enums::ProtocolVersion,
};
use super::{
command_base::{CommandDeserializer, DeserializedCommand},
device_profile::{
product_identifier::{
ProductIdentifierCommandDeserializer, DESERIALIZE_PRODUCT_IDENTIFIER_RAW_NAME,
},
version::{VersionCommandDeserializer, DESERIALIZE_VERSION_RAW_NAME},
device_profile::product_identifier::{
ProductIdentifierDeserializer, DESERIALIZE_PRODUCT_IDENTIFIER_RAW_NAME,
},
init_complete::{InitCompleteDeserializer, DESERIALIZE_INIT_COMPLETE_RAW_NAME},
mix_effects::program_input::{ProgramInputDeserializer, DESERIALIZE_PROGRAM_INPUT_RAW_NAME},
@ -18,9 +18,9 @@ use super::{
pub fn deserialize_commands(
payload: &[u8],
version: &ProtocolVersion,
version: &mut ProtocolVersion,
) -> VecDeque<Arc<dyn DeserializedCommand>> {
let mut parsed_commands = VecDeque::new();
let mut parsed_commands: VecDeque<Arc<dyn DeserializedCommand>> = VecDeque::new();
let mut head = 0;
while payload.len() > head + 8 {
@ -35,15 +35,21 @@ pub fn deserialize_commands(
log::debug!("Received command {} with length {}", name, length);
if let Some(deserializer) = command_deserializer_from_string(name.as_str()) {
let deserialized_command =
deserializer.deserialize(&payload[head + 8..head + length], version);
let command_buffer = &payload[head + 8..head + length];
if name == DESERIALIZE_VERSION_RAW_NAME {
let version_command = deserialize_version(command_buffer);
*version = version_command.version.clone();
log::info!("Switched to protocol version {}", version);
parsed_commands.push_back(Arc::new(version_command));
} else if let Some(deserializer) = command_deserializer_from_string(name.as_str()) {
let deserialized_command = deserializer.deserialize(command_buffer, version);
log::debug!("Received {:?}", deserialized_command);
parsed_commands.push_back(deserialized_command);
} else {
log::warn!("Received command {name} for which there is no deserializer.");
// TODO: Remove!
// todo!("Write deserializer for {name}.");
todo!("Write deserializer for {name}.");
}
head += length;
@ -54,13 +60,12 @@ pub fn deserialize_commands(
fn command_deserializer_from_string(command_str: &str) -> Option<Box<dyn CommandDeserializer>> {
match command_str {
DESERIALIZE_VERSION_RAW_NAME => Some(Box::<VersionCommandDeserializer>::default()),
DESERIALIZE_INIT_COMPLETE_RAW_NAME => Some(Box::<InitCompleteDeserializer>::default()),
DESERIALIZE_PROGRAM_INPUT_RAW_NAME => Some(Box::<ProgramInputDeserializer>::default()),
DESERIALIZE_TALLY_BY_SOURCE_RAW_NAME => Some(Box::<TallyBySourceDeserializer>::default()),
DESERIALIZE_TIME_RAW_NAME => Some(Box::<TimeDeserializer>::default()),
DESERIALIZE_PRODUCT_IDENTIFIER_RAW_NAME => {
Some(Box::<ProductIdentifierCommandDeserializer>::default())
Some(Box::<ProductIdentifierDeserializer>::default())
}
_ => None,
}

View File

@ -36,7 +36,7 @@ impl CommandDeserializer for TimeDeserializer {
&self,
buffer: &[u8],
version: &ProtocolVersion,
) -> std::sync::Arc<dyn super::command_base::DeserializedCommand> {
) -> Arc<dyn DeserializedCommand> {
let info = TimeInfo {
hour: buffer[0],
minute: buffer[1],

View File

@ -73,7 +73,7 @@ impl From<u8> for Model {
}
}
#[derive(Debug, Default, Clone, Copy, PartialEq, PartialOrd)]
#[derive(Debug, Default, Clone, PartialEq, PartialOrd)]
pub enum ProtocolVersion {
#[default]
Unknown = 0,

View File

@ -35,12 +35,10 @@ async fn main() {
tokio::sync::mpsc::channel::<AtemSocketMessage>(10);
let (atem_event_tx, atem_event_rx) = tokio::sync::mpsc::unbounded_channel();
let cancel = CancellationToken::new();
let cancel_task = cancel.clone();
let mut atem_socket = AtemSocket::new(atem_event_tx);
let atem_socket_run = atem_socket.run(socket_message_rx, cancel_task);
let mut atem_socket = AtemSocket::new(socket_message_rx, atem_event_tx);
let atem = Arc::new(Atem::new(socket_message_tx));
let atem = Arc::new(Atem::new(atem_socket, socket_message_tx));
let atem_thread = atem.clone();
let atem_run = atem_thread.run(atem_event_rx, cancel);
@ -64,7 +62,6 @@ async fn main() {
});
select! {
_ = atem_socket_run => {},
_ = atem_run => {},
_ = switch_loop => {}
}