feat: Handle version command
This commit is contained in:
parent
d2f41821c0
commit
25460d77cd
|
@ -1,20 +1,26 @@
|
|||
use std::{
|
||||
collections::{HashMap, VecDeque},
|
||||
net::SocketAddr,
|
||||
ops::DerefMut,
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use tokio::{select, sync::Semaphore};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::{
|
||||
atem_lib::atem_socket::{AtemEvent, AtemSocketCommand, AtemSocketMessage, TrackingId},
|
||||
atem_lib::atem_socket::{
|
||||
AtemSocket, AtemSocketCommand, AtemSocketEvent, AtemSocketMessage, TrackingId,
|
||||
},
|
||||
commands::{
|
||||
command_base::{BasicWritableCommand, DeserializedCommand},
|
||||
device_profile::version::DESERIALIZE_VERSION_RAW_NAME,
|
||||
init_complete::DESERIALIZE_INIT_COMPLETE_RAW_NAME,
|
||||
parse_commands::deserialize_commands,
|
||||
time::DESERIALIZE_TIME_RAW_NAME,
|
||||
},
|
||||
enums::ProtocolVersion,
|
||||
state::AtemState,
|
||||
};
|
||||
|
||||
|
@ -27,13 +33,24 @@ pub enum AtemConnectionStatus {
|
|||
}
|
||||
|
||||
pub struct Atem {
|
||||
protocol_version: tokio::sync::RwLock<ProtocolVersion>,
|
||||
|
||||
socket: tokio::sync::RwLock<AtemSocket>,
|
||||
|
||||
waiting_semaphores: tokio::sync::RwLock<HashMap<TrackingId, Arc<Semaphore>>>,
|
||||
socket_message_tx: tokio::sync::mpsc::Sender<AtemSocketMessage>,
|
||||
}
|
||||
|
||||
impl Atem {
|
||||
pub fn new(socket_message_tx: tokio::sync::mpsc::Sender<AtemSocketMessage>) -> Self {
|
||||
pub fn new(
|
||||
socket: AtemSocket,
|
||||
socket_message_tx: tokio::sync::mpsc::Sender<AtemSocketMessage>,
|
||||
) -> Self {
|
||||
Self {
|
||||
protocol_version: tokio::sync::RwLock::new(ProtocolVersion::V7_2),
|
||||
|
||||
socket: tokio::sync::RwLock::new(socket),
|
||||
|
||||
waiting_semaphores: tokio::sync::RwLock::new(HashMap::new()),
|
||||
socket_message_tx,
|
||||
}
|
||||
|
@ -54,25 +71,30 @@ impl Atem {
|
|||
|
||||
pub async fn run(
|
||||
&self,
|
||||
mut atem_event_rx: tokio::sync::mpsc::UnboundedReceiver<AtemEvent>,
|
||||
mut atem_event_rx: tokio::sync::mpsc::UnboundedReceiver<AtemSocketEvent>,
|
||||
cancel: CancellationToken,
|
||||
) {
|
||||
let mut status = AtemConnectionStatus::default();
|
||||
let mut state = AtemState::default();
|
||||
|
||||
let mut poll_interval = tokio::time::interval(Duration::from_millis(5));
|
||||
|
||||
while !cancel.is_cancelled() {
|
||||
let tick = poll_interval.tick();
|
||||
select! {
|
||||
_ = cancel.cancelled() => {},
|
||||
_ = tick => {},
|
||||
message = atem_event_rx.recv() => match message {
|
||||
Some(event) => match event {
|
||||
AtemEvent::Connected => {
|
||||
AtemSocketEvent::Connected => {
|
||||
log::info!("Atem connected");
|
||||
}
|
||||
AtemEvent::Disconnected => todo!("Disconnected"),
|
||||
AtemEvent::ReceivedCommands(commands) => {
|
||||
AtemSocketEvent::Disconnected => todo!("Disconnected"),
|
||||
AtemSocketEvent::ReceivedCommands(payload) => {
|
||||
let commands = deserialize_commands(&payload, self.protocol_version.write().await.deref_mut());
|
||||
self.mutate_state(&mut state, &mut status, commands).await
|
||||
}
|
||||
AtemEvent::AckedCommand(tracking_id) => {
|
||||
AtemSocketEvent::AckedCommand(tracking_id) => {
|
||||
log::debug!("Received tracking Id {tracking_id}");
|
||||
if let Some(semaphore) =
|
||||
self.waiting_semaphores.read().await.get(&tracking_id)
|
||||
|
@ -89,18 +111,19 @@ impl Atem {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.socket.write().await.poll().await;
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_commands(&self, commands: Vec<Box<dyn BasicWritableCommand>>) {
|
||||
let protocol_version = { self.protocol_version.read().await.clone() };
|
||||
let (callback_tx, callback_rx) = tokio::sync::oneshot::channel();
|
||||
self.socket_message_tx
|
||||
.send(AtemSocketMessage::SendCommands {
|
||||
commands: commands
|
||||
.iter()
|
||||
.map(|command| {
|
||||
AtemSocketCommand::new(command, &crate::enums::ProtocolVersion::Unknown)
|
||||
})
|
||||
.map(|command| AtemSocketCommand::new(command, &protocol_version))
|
||||
.collect(),
|
||||
tracking_ids_callback: callback_tx,
|
||||
})
|
||||
|
|
|
@ -54,10 +54,10 @@ pub struct TrackingIdsCallback {
|
|||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum AtemEvent {
|
||||
pub enum AtemSocketEvent {
|
||||
Connected,
|
||||
Disconnected,
|
||||
ReceivedCommands(VecDeque<Arc<dyn DeserializedCommand>>),
|
||||
ReceivedCommands(Vec<u8>),
|
||||
AckedCommand(TrackingId),
|
||||
}
|
||||
|
||||
|
@ -111,8 +111,11 @@ pub struct AtemSocket {
|
|||
ack_timer: Option<SystemTime>,
|
||||
received_without_ack: u16,
|
||||
|
||||
atem_event_tx: tokio::sync::mpsc::UnboundedSender<AtemEvent>,
|
||||
atem_message_rx: tokio::sync::mpsc::Receiver<AtemSocketMessage>,
|
||||
atem_event_tx: tokio::sync::mpsc::UnboundedSender<AtemSocketEvent>,
|
||||
connected_callbacks: Mutex<Vec<tokio::sync::oneshot::Sender<bool>>>,
|
||||
|
||||
tick_interval: tokio::time::Interval,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Clone)]
|
||||
|
@ -147,7 +150,11 @@ enum AtemSocketReceiveError {
|
|||
}
|
||||
|
||||
impl AtemSocket {
|
||||
pub fn new(atem_event_tx: tokio::sync::mpsc::UnboundedSender<AtemEvent>) -> Self {
|
||||
pub fn new(
|
||||
atem_message_rx: tokio::sync::mpsc::Receiver<AtemSocketMessage>,
|
||||
atem_event_tx: tokio::sync::mpsc::UnboundedSender<AtemSocketEvent>,
|
||||
) -> Self {
|
||||
let tick_interval = tokio::time::interval(Duration::from_millis(5));
|
||||
Self {
|
||||
connection_state: ConnectionState::Closed,
|
||||
reconnect_timer: None,
|
||||
|
@ -169,23 +176,19 @@ impl AtemSocket {
|
|||
ack_timer: None,
|
||||
received_without_ack: 0,
|
||||
|
||||
atem_message_rx,
|
||||
atem_event_tx,
|
||||
connected_callbacks: Mutex::default(),
|
||||
|
||||
tick_interval,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run(
|
||||
&mut self,
|
||||
mut atem_message_rx: tokio::sync::mpsc::Receiver<AtemSocketMessage>,
|
||||
cancel: tokio_util::sync::CancellationToken,
|
||||
) {
|
||||
let mut interval = tokio::time::interval(Duration::from_millis(5));
|
||||
while !cancel.is_cancelled() {
|
||||
let tick = interval.tick();
|
||||
pub async fn poll(&mut self) {
|
||||
let tick = self.tick_interval.tick();
|
||||
select! {
|
||||
_ = cancel.cancelled() => {},
|
||||
_ = tick => {},
|
||||
message = atem_message_rx.recv() => {
|
||||
message = self.atem_message_rx.recv() => {
|
||||
match message {
|
||||
Some(AtemSocketMessage::Connect {
|
||||
address,
|
||||
|
@ -234,8 +237,7 @@ impl AtemSocket {
|
|||
barrier.wait().await;
|
||||
},
|
||||
None => {
|
||||
log::info!("ATEM message channel has closed, exiting event loop.");
|
||||
cancel.cancel();
|
||||
log::info!("ATEM message channel has closed.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -243,7 +245,6 @@ impl AtemSocket {
|
|||
|
||||
self.tick().await;
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn connect(&mut self, address: SocketAddr) -> Result<(), io::Error> {
|
||||
let socket = UdpSocket::bind("0.0.0.0:0").await?;
|
||||
|
@ -536,23 +537,21 @@ impl AtemSocket {
|
|||
}
|
||||
|
||||
fn on_commands_received(&mut self, payload: &[u8]) {
|
||||
let commands = deserialize_commands(payload, &self.protocol_version);
|
||||
|
||||
let _ = self
|
||||
.atem_event_tx
|
||||
.send(AtemEvent::ReceivedCommands(commands));
|
||||
.send(AtemSocketEvent::ReceivedCommands(payload.to_vec()));
|
||||
}
|
||||
|
||||
fn on_command_acknowledged(&mut self, packets: Vec<AckedPacket>) {
|
||||
for ack in packets {
|
||||
let _ = self
|
||||
.atem_event_tx
|
||||
.send(AtemEvent::AckedCommand(TrackingId(ack.tracking_id)));
|
||||
.send(AtemSocketEvent::AckedCommand(TrackingId(ack.tracking_id)));
|
||||
}
|
||||
}
|
||||
|
||||
async fn on_connect(&mut self) {
|
||||
let _ = self.atem_event_tx.send(AtemEvent::Connected);
|
||||
let _ = self.atem_event_tx.send(AtemSocketEvent::Connected);
|
||||
let mut connected_callbacks = self.connected_callbacks.lock().await;
|
||||
for callback in connected_callbacks.drain(0..) {
|
||||
let _ = callback.send(false);
|
||||
|
@ -560,7 +559,7 @@ impl AtemSocket {
|
|||
}
|
||||
|
||||
fn on_disconnect(&mut self) {
|
||||
let _ = self.atem_event_tx.send(AtemEvent::Disconnected);
|
||||
let _ = self.atem_event_tx.send(AtemSocketEvent::Disconnected);
|
||||
}
|
||||
|
||||
fn start_timers(&mut self) {
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use std::{collections::HashMap, fmt::Debug, sync::Arc};
|
||||
use std::{collections::HashMap, fmt::Debug, process::Command, sync::Arc};
|
||||
|
||||
use crate::{enums::ProtocolVersion, state::AtemState};
|
||||
|
||||
|
|
|
@ -8,12 +8,12 @@ use crate::{
|
|||
pub const DESERIALIZE_PRODUCT_IDENTIFIER_RAW_NAME: &str = "_pin";
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ProductIdentifierCommand {
|
||||
pub struct ProductIdentifier {
|
||||
pub product_identifier: String,
|
||||
pub model: Model,
|
||||
}
|
||||
|
||||
impl DeserializedCommand for ProductIdentifierCommand {
|
||||
impl DeserializedCommand for ProductIdentifier {
|
||||
fn raw_name(&self) -> &'static str {
|
||||
DESERIALIZE_PRODUCT_IDENTIFIER_RAW_NAME
|
||||
}
|
||||
|
@ -40,14 +40,14 @@ impl DeserializedCommand for ProductIdentifierCommand {
|
|||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ProductIdentifierCommandDeserializer {}
|
||||
pub struct ProductIdentifierDeserializer {}
|
||||
|
||||
impl CommandDeserializer for ProductIdentifierCommandDeserializer {
|
||||
impl CommandDeserializer for ProductIdentifierDeserializer {
|
||||
fn deserialize(
|
||||
&self,
|
||||
buffer: &[u8],
|
||||
version: &ProtocolVersion,
|
||||
) -> std::sync::Arc<dyn DeserializedCommand> {
|
||||
) -> Arc<dyn DeserializedCommand> {
|
||||
let null_byte_index = buffer
|
||||
.iter()
|
||||
.position(|byte| *byte == b'\0')
|
||||
|
@ -57,7 +57,7 @@ impl CommandDeserializer for ProductIdentifierCommandDeserializer {
|
|||
.expect("Malformed string");
|
||||
let model = buffer[40];
|
||||
|
||||
Arc::new(ProductIdentifierCommand {
|
||||
Arc::new(ProductIdentifier {
|
||||
product_identifier: product_identifier
|
||||
.to_str()
|
||||
.expect("Invalid rust string")
|
||||
|
|
|
@ -8,7 +8,7 @@ use crate::{
|
|||
pub const DESERIALIZE_TOPOLOGY_RAW_NAME: &str = "_top";
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TopologyCommand {
|
||||
pub struct Topology {
|
||||
mix_effects: u8,
|
||||
sources: u8,
|
||||
auxilliaries: u8,
|
||||
|
@ -27,7 +27,7 @@ pub struct TopologyCommand {
|
|||
only_configurable_outputs: bool,
|
||||
}
|
||||
|
||||
impl DeserializedCommand for TopologyCommand {
|
||||
impl DeserializedCommand for Topology {
|
||||
fn raw_name(&self) -> &'static str {
|
||||
todo!()
|
||||
}
|
||||
|
@ -44,7 +44,7 @@ impl CommandDeserializer for TopologyCommandDeserializer {
|
|||
&self,
|
||||
buffer: &[u8],
|
||||
version: &ProtocolVersion,
|
||||
) -> std::sync::Arc<dyn crate::commands::command_base::DeserializedCommand> {
|
||||
) -> Arc<dyn DeserializedCommand> {
|
||||
let v230offset = if *version > ProtocolVersion::V8_0_1 {
|
||||
1
|
||||
} else {
|
||||
|
@ -69,7 +69,7 @@ impl CommandDeserializer for TopologyCommandDeserializer {
|
|||
false
|
||||
};
|
||||
|
||||
Arc::new(TopologyCommand {
|
||||
Arc::new(Topology {
|
||||
mix_effects: buffer[0],
|
||||
sources: buffer[1],
|
||||
downstream_keyers: buffer[2],
|
||||
|
|
|
@ -1,39 +1,25 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
commands::command_base::{CommandDeserializer, DeserializedCommand},
|
||||
enums::ProtocolVersion,
|
||||
};
|
||||
use crate::{commands::command_base::DeserializedCommand, enums::ProtocolVersion};
|
||||
|
||||
pub const DESERIALIZE_VERSION_RAW_NAME: &str = "_ver";
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct VersionCommand {
|
||||
pub struct Version {
|
||||
pub version: ProtocolVersion,
|
||||
}
|
||||
|
||||
impl DeserializedCommand for VersionCommand {
|
||||
impl DeserializedCommand for Version {
|
||||
fn raw_name(&self) -> &'static str {
|
||||
DESERIALIZE_VERSION_RAW_NAME
|
||||
}
|
||||
|
||||
fn apply_to_state(&self, state: &mut crate::state::AtemState) {
|
||||
state.info.api_version = self.version;
|
||||
state.info.api_version = self.version.clone();
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct VersionCommandDeserializer {}
|
||||
|
||||
impl CommandDeserializer for VersionCommandDeserializer {
|
||||
fn deserialize(
|
||||
&self,
|
||||
buffer: &[u8],
|
||||
version: &ProtocolVersion,
|
||||
) -> std::sync::Arc<dyn DeserializedCommand> {
|
||||
pub fn deserialize_version(buffer: &[u8]) -> Version {
|
||||
let version = u32::from_be_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]);
|
||||
let version: ProtocolVersion = version.try_into().expect("Invalid protocol version");
|
||||
|
||||
Arc::new(VersionCommand { version })
|
||||
}
|
||||
Version { version }
|
||||
}
|
||||
|
|
|
@ -25,7 +25,7 @@ impl CommandDeserializer for InitCompleteDeserializer {
|
|||
&self,
|
||||
_buffer: &[u8],
|
||||
version: &ProtocolVersion,
|
||||
) -> std::sync::Arc<dyn DeserializedCommand> {
|
||||
) -> Arc<dyn DeserializedCommand> {
|
||||
Arc::new(InitComplete {})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,14 +1,14 @@
|
|||
use std::{collections::VecDeque, sync::Arc};
|
||||
|
||||
use crate::enums::ProtocolVersion;
|
||||
use crate::{
|
||||
commands::device_profile::version::{deserialize_version, DESERIALIZE_VERSION_RAW_NAME},
|
||||
enums::ProtocolVersion,
|
||||
};
|
||||
|
||||
use super::{
|
||||
command_base::{CommandDeserializer, DeserializedCommand},
|
||||
device_profile::{
|
||||
product_identifier::{
|
||||
ProductIdentifierCommandDeserializer, DESERIALIZE_PRODUCT_IDENTIFIER_RAW_NAME,
|
||||
},
|
||||
version::{VersionCommandDeserializer, DESERIALIZE_VERSION_RAW_NAME},
|
||||
device_profile::product_identifier::{
|
||||
ProductIdentifierDeserializer, DESERIALIZE_PRODUCT_IDENTIFIER_RAW_NAME,
|
||||
},
|
||||
init_complete::{InitCompleteDeserializer, DESERIALIZE_INIT_COMPLETE_RAW_NAME},
|
||||
mix_effects::program_input::{ProgramInputDeserializer, DESERIALIZE_PROGRAM_INPUT_RAW_NAME},
|
||||
|
@ -18,9 +18,9 @@ use super::{
|
|||
|
||||
pub fn deserialize_commands(
|
||||
payload: &[u8],
|
||||
version: &ProtocolVersion,
|
||||
version: &mut ProtocolVersion,
|
||||
) -> VecDeque<Arc<dyn DeserializedCommand>> {
|
||||
let mut parsed_commands = VecDeque::new();
|
||||
let mut parsed_commands: VecDeque<Arc<dyn DeserializedCommand>> = VecDeque::new();
|
||||
let mut head = 0;
|
||||
|
||||
while payload.len() > head + 8 {
|
||||
|
@ -35,15 +35,21 @@ pub fn deserialize_commands(
|
|||
|
||||
log::debug!("Received command {} with length {}", name, length);
|
||||
|
||||
if let Some(deserializer) = command_deserializer_from_string(name.as_str()) {
|
||||
let deserialized_command =
|
||||
deserializer.deserialize(&payload[head + 8..head + length], version);
|
||||
let command_buffer = &payload[head + 8..head + length];
|
||||
|
||||
if name == DESERIALIZE_VERSION_RAW_NAME {
|
||||
let version_command = deserialize_version(command_buffer);
|
||||
*version = version_command.version.clone();
|
||||
log::info!("Switched to protocol version {}", version);
|
||||
parsed_commands.push_back(Arc::new(version_command));
|
||||
} else if let Some(deserializer) = command_deserializer_from_string(name.as_str()) {
|
||||
let deserialized_command = deserializer.deserialize(command_buffer, version);
|
||||
log::debug!("Received {:?}", deserialized_command);
|
||||
parsed_commands.push_back(deserialized_command);
|
||||
} else {
|
||||
log::warn!("Received command {name} for which there is no deserializer.");
|
||||
// TODO: Remove!
|
||||
// todo!("Write deserializer for {name}.");
|
||||
todo!("Write deserializer for {name}.");
|
||||
}
|
||||
|
||||
head += length;
|
||||
|
@ -54,13 +60,12 @@ pub fn deserialize_commands(
|
|||
|
||||
fn command_deserializer_from_string(command_str: &str) -> Option<Box<dyn CommandDeserializer>> {
|
||||
match command_str {
|
||||
DESERIALIZE_VERSION_RAW_NAME => Some(Box::<VersionCommandDeserializer>::default()),
|
||||
DESERIALIZE_INIT_COMPLETE_RAW_NAME => Some(Box::<InitCompleteDeserializer>::default()),
|
||||
DESERIALIZE_PROGRAM_INPUT_RAW_NAME => Some(Box::<ProgramInputDeserializer>::default()),
|
||||
DESERIALIZE_TALLY_BY_SOURCE_RAW_NAME => Some(Box::<TallyBySourceDeserializer>::default()),
|
||||
DESERIALIZE_TIME_RAW_NAME => Some(Box::<TimeDeserializer>::default()),
|
||||
DESERIALIZE_PRODUCT_IDENTIFIER_RAW_NAME => {
|
||||
Some(Box::<ProductIdentifierCommandDeserializer>::default())
|
||||
Some(Box::<ProductIdentifierDeserializer>::default())
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
|
|
|
@ -36,7 +36,7 @@ impl CommandDeserializer for TimeDeserializer {
|
|||
&self,
|
||||
buffer: &[u8],
|
||||
version: &ProtocolVersion,
|
||||
) -> std::sync::Arc<dyn super::command_base::DeserializedCommand> {
|
||||
) -> Arc<dyn DeserializedCommand> {
|
||||
let info = TimeInfo {
|
||||
hour: buffer[0],
|
||||
minute: buffer[1],
|
||||
|
|
|
@ -73,7 +73,7 @@ impl From<u8> for Model {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone, Copy, PartialEq, PartialOrd)]
|
||||
#[derive(Debug, Default, Clone, PartialEq, PartialOrd)]
|
||||
pub enum ProtocolVersion {
|
||||
#[default]
|
||||
Unknown = 0,
|
||||
|
|
|
@ -35,12 +35,10 @@ async fn main() {
|
|||
tokio::sync::mpsc::channel::<AtemSocketMessage>(10);
|
||||
let (atem_event_tx, atem_event_rx) = tokio::sync::mpsc::unbounded_channel();
|
||||
let cancel = CancellationToken::new();
|
||||
let cancel_task = cancel.clone();
|
||||
|
||||
let mut atem_socket = AtemSocket::new(atem_event_tx);
|
||||
let atem_socket_run = atem_socket.run(socket_message_rx, cancel_task);
|
||||
let mut atem_socket = AtemSocket::new(socket_message_rx, atem_event_tx);
|
||||
|
||||
let atem = Arc::new(Atem::new(socket_message_tx));
|
||||
let atem = Arc::new(Atem::new(atem_socket, socket_message_tx));
|
||||
let atem_thread = atem.clone();
|
||||
let atem_run = atem_thread.run(atem_event_rx, cancel);
|
||||
|
||||
|
@ -64,7 +62,6 @@ async fn main() {
|
|||
});
|
||||
|
||||
select! {
|
||||
_ = atem_socket_run => {},
|
||||
_ = atem_run => {},
|
||||
_ = switch_loop => {}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue