Add discord voice basics

Joins a voice channel when a call is running, leaves when ended
Autoformat
This commit is contained in:
Sam W 2022-07-17 13:47:27 +01:00
parent 6042cc7a82
commit 26c481deba
4 changed files with 333 additions and 573 deletions

698
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -10,13 +10,13 @@ rand = "0.8.5"
rsip = "0.4.0" rsip = "0.4.0"
rtp = "0.6.5" rtp = "0.6.5"
sdp-rs = "0.2.1" sdp-rs = "0.2.1"
songbird = { version = "0.2.2", features = ["builtin-queue"]} songbird = { git = "https://github.com/serenity-rs/songbird", branch = "next", default_features = false, features = ["driver", "twilight-rustls", "zlib-stock"] }
tokio = { version = "1.19.2", features = ["full"] } tokio = { version = "1.19.2", features = ["full"] }
tokio-stream = "0.1.9" tokio-stream = "0.1.9"
tokio-util = { version = "0.7.3", features = ["net", "codec"] } tokio-util = { version = "0.7.3", features = ["net", "codec"] }
tracing = "0.1.35" tracing = "0.1.35"
tracing-subscriber = "0.3.14" tracing-subscriber = "0.3.14"
twilight-gateway = "0.11.1" twilight-gateway = "0.11.1"
twilight-http = "0.11.1" twilight-http = { version = "0.11.1"}
twilight-model = "0.11.3" twilight-model = "0.11.3"
webrtc-util = "0.5.4" webrtc-util = "0.5.4"

View File

@ -2,6 +2,7 @@ use crate::StdErr;
use futures::stream::StreamExt; use futures::stream::StreamExt;
use rand::seq::SliceRandom; use rand::seq::SliceRandom;
use rand::{prelude::ThreadRng, thread_rng}; use rand::{prelude::ThreadRng, thread_rng};
use songbird::Songbird;
use std::cell::RefCell; use std::cell::RefCell;
use std::sync::Arc; use std::sync::Arc;
use tokio::select; use tokio::select;
@ -21,18 +22,37 @@ use twilight_model::{
thread_local!(static RNG: RefCell<ThreadRng> = RefCell::new(thread_rng())); thread_local!(static RNG: RefCell<ThreadRng> = RefCell::new(thread_rng()));
const GREETINGS: &'static [&'static str] = const GREETINGS: &'static [&'static str] = &[
&["You rung?", "Hello there", "Right back atcha", "Yes?"]; "You rung?",
"Hello there",
"Right back atcha",
"Yes?",
"Wow, rude",
"Go stick your head in a pig",
];
#[derive(Debug)]
pub struct Call {
pub guild_id: u64,
pub channel_id: u64,
pub done: mpsc::Receiver<()>,
}
struct State {
songbird: Songbird,
client: Client,
}
pub async fn run_discord( pub async fn run_discord(
token: &str, token: &str,
mut calls: mpsc::Receiver<Call>,
mut shutdown: broadcast::Receiver<()>, mut shutdown: broadcast::Receiver<()>,
_done: mpsc::Sender<()>, _done: mpsc::Sender<()>,
) -> StdErr<()> { ) -> StdErr<()> {
event!(Level::INFO, "Starting..."); event!(Level::INFO, "Starting...");
let (cluster, mut events) = Cluster::builder( let (cluster, mut events) = Cluster::builder(
token.to_owned(), token.to_owned(),
Intents::GUILD_MESSAGES | Intents::MESSAGE_CONTENT, Intents::GUILD_MESSAGES | Intents::MESSAGE_CONTENT | Intents::GUILD_VOICE_STATES,
) )
.presence(UpdatePresencePayload::new( .presence(UpdatePresencePayload::new(
vec![MinimalActivity { vec![MinimalActivity {
@ -47,21 +67,32 @@ pub async fn run_discord(
)?) )?)
.build() .build()
.await?; .await?;
let cluster = Arc::new(cluster); let cluster = Arc::new(cluster);
let cluster_spawn = Arc::clone(&cluster); let client = Client::new(token.to_owned());
let me = client.current_user().exec().await?.model().await?.id;
let sb = Songbird::twilight(cluster.clone(), me);
let state = Arc::new(State {
songbird: sb,
client,
});
let cluster_spawn = cluster.clone();
tokio::spawn(async move { cluster_spawn.up().await }); tokio::spawn(async move { cluster_spawn.up().await });
let client = Arc::new(Client::new(token.to_owned()));
loop { loop {
select!( select!(
Some((shard_id, ev)) = events.next() => { Some((shard_id, ev)) = events.next() => {
let client = client.clone(); state.songbird.process(&ev).await;
let state_ = state.clone();
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = handle_event(ev, shard_id, client).await { if let Err(e) = handle_event(ev, shard_id, state_, me).await {
event!(Level::ERROR, err=?e, "Error handling discord event"); event!(Level::ERROR, err=?e, "Error handling discord event");
} }
}); });
}, },
Some(call) = calls.recv() => {
let state_ = state.clone();
tokio::spawn(handle_call(call, state_));
},
_ = shutdown.recv() => { _ = shutdown.recv() => {
event!(Level::INFO, "Shutting down..."); event!(Level::INFO, "Shutting down...");
cluster.down(); cluster.down();
@ -74,18 +105,41 @@ pub async fn run_discord(
Ok(()) Ok(())
} }
#[instrument(skip(call, state))]
async fn handle_call(mut call: Call, state: Arc<State>) {
let (_handle, success) = state.songbird.join(call.guild_id, call.channel_id).await;
match success {
Ok(()) => event!(Level::INFO, %call.guild_id, %call.channel_id, "Joined channel"),
Err(err) => {
event!(Level::ERROR, %call.guild_id, %call.channel_id, %err, "Error joining channel")
}
}
let _ = call.done.recv().await;
match state.songbird.leave(call.guild_id).await {
Ok(()) => event!(Level::INFO, %call.guild_id, %call.channel_id, "Left channel"),
Err(err) => {
event!(Level::INFO, %call.guild_id, %call.channel_id, %err, "Error leaving channel")
}
}
}
fn mentions(msg: &Message, us: Id<UserMarker>) -> bool { fn mentions(msg: &Message, us: Id<UserMarker>) -> bool {
msg.mentions.iter().filter(|m| m.id == us).next().is_some() msg.mentions.iter().filter(|m| m.id == us).next().is_some()
} }
#[instrument(skip(client))] #[instrument(skip(state))]
async fn handle_event(ev: Event, shard_id: u64, client: Arc<Client>) -> StdErr<()> { async fn handle_event(
let me = client.current_user().exec().await?.model().await?; ev: Event,
shard_id: u64,
state: Arc<State>,
me: Id<UserMarker>,
) -> StdErr<()> {
match ev { match ev {
Event::MessageCreate(msg) => { Event::MessageCreate(msg) => {
if mentions(&msg.0, me.id) { if mentions(&msg.0, me) {
let greet = RNG.with(|rng| GREETINGS.choose(&mut *rng.borrow_mut()).unwrap()); let greet = RNG.with(|rng| GREETINGS.choose(&mut *rng.borrow_mut()).unwrap());
client state
.client
.create_message(msg.channel_id) .create_message(msg.channel_id)
.content(greet)? .content(greet)?
.exec() .exec()

View File

@ -1,18 +1,18 @@
use futures::{sink::Sink, SinkExt}; use futures::{sink::Sink, SinkExt};
use tokio::sync::{broadcast, mpsc};
use tokio::signal;
use rsip::common::method::Method as SipMethod; use rsip::common::method::Method as SipMethod;
use rsip::headers::header::Header as SipHeader; use rsip::headers::header::Header as SipHeader;
use rsip::message::{request::Request, response::Response, SipMessage}; use rsip::message::{request::Request, response::Response, SipMessage};
use sdp_rs::lines::media::{MediaType, ProtoType}; use sdp_rs::lines::media::{MediaType, ProtoType};
use sdp_rs::lines::{attribute::Rtpmap, Attribute, Media}; use sdp_rs::lines::{attribute::Rtpmap, Attribute, Media};
use sdp_rs::{MediaDescription, SessionDescription}; use sdp_rs::{MediaDescription, SessionDescription};
use std::net::SocketAddr;
use std::str;
use std::env; use std::env;
use std::net::IpAddr; use std::net::IpAddr;
use std::net::SocketAddr;
use std::str;
use std::time::Duration; use std::time::Duration;
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
use tokio::signal;
use tokio::sync::{broadcast, mpsc};
use tokio::time::sleep; use tokio::time::sleep;
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use tokio_util::{ use tokio_util::{
@ -20,17 +20,15 @@ use tokio_util::{
udp::UdpFramed, udp::UdpFramed,
}; };
use tracing::{event, instrument, Level}; use tracing::{event, instrument, Level};
use songbird::driver::Driver;
mod codecs; mod codecs;
use codecs::{SipCodec, RtpCodec}; use codecs::{RtpCodec, SipCodec};
mod discord; mod discord;
const SIP_PORT: u16 = 5060; const SIP_PORT: u16 = 5060;
const BIND_ADDR: &str = "0.0.0.0"; // for now const BIND_ADDR: &str = "0.0.0.0"; // for now
type StdErr<T> = Result<T, Box<dyn std::error::Error>>; type StdErr<T> = Result<T, Box<dyn std::error::Error>>;
struct Server {} struct Server {}
@ -42,8 +40,12 @@ struct CurrentCall {
} }
impl Server { impl Server {
#[instrument(skip(shutdown, _done))] #[instrument(skip(call_tx, shutdown, _done))]
async fn run_sip(mut shutdown: broadcast::Receiver<()>, _done: mpsc::Sender<()>) -> StdErr<()> { async fn run_sip(
call_tx: mpsc::Sender<discord::Call>,
mut shutdown: broadcast::Receiver<()>,
_done: mpsc::Sender<()>,
) -> StdErr<()> {
event!(Level::INFO, "Starting..."); event!(Level::INFO, "Starting...");
let socket = UdpSocket::bind(format!("{}:{}", BIND_ADDR, SIP_PORT)).await?; let socket = UdpSocket::bind(format!("{}:{}", BIND_ADDR, SIP_PORT)).await?;
let mut framed = UdpFramed::new(socket, SipCodec {}); let mut framed = UdpFramed::new(socket, SipCodec {});
@ -89,9 +91,10 @@ impl Server {
// remote's address. This means that `handle_call` // remote's address. This means that `handle_call`
// doesn't need to be aware of the remote at all. // doesn't need to be aware of the remote at all.
let res_tx = response_tx.clone(); let res_tx = response_tx.clone();
let call_tx = call_tx.clone();
tokio::spawn(async move { tokio::spawn(async move {
let res_tx = PollSender::new(res_tx).with::<_, _, _, PollSendError<_>>(|res| {futures::future::ready(Ok((res, remote)))}); let res_tx = PollSender::new(res_tx).with::<_, _, _, PollSendError<_>>(|res| {futures::future::ready(Ok((res, remote)))});
Self::handle_call(&req, res_tx, req_rx).await.unwrap(); Self::handle_call(&req, res_tx, req_rx, call_tx).await.unwrap();
}); });
} }
Some(call) => { Some(call) => {
@ -158,12 +161,16 @@ impl Server {
if r.encoding_name == "opus" { if r.encoding_name == "opus" {
let prefix = format!("{} ", r.payload_type); let prefix = format!("{} ", r.payload_type);
// Find the matching fmtp if there is one // Find the matching fmtp if there is one
let fmtp = md.attributes.iter().filter_map(|a| match a{ let fmtp = md
.attributes
.iter()
.filter_map(|a| match a {
Attribute::Other(fmt, Some(params)) if fmt == "fmtp" => { Attribute::Other(fmt, Some(params)) if fmt == "fmtp" => {
params.strip_prefix(&prefix) params.strip_prefix(&prefix)
}, }
_ => None _ => None,
}).next(); })
.next();
Some((r.clone(), fmtp)) Some((r.clone(), fmtp))
} else { } else {
None None
@ -178,13 +185,14 @@ impl Server {
// Handle a call // Handle a call
#[instrument( #[instrument(
level = "info", level = "info",
skip(invite, responses, requests) skip(invite, responses, requests, call_tx)
fields() fields()
)] )]
async fn handle_call<T: Sink<Response> + std::marker::Unpin>( async fn handle_call<T: Sink<Response> + std::marker::Unpin>(
invite: &Request, invite: &Request,
mut responses: T, mut responses: T,
mut requests: mpsc::Receiver<Request>, mut requests: mpsc::Receiver<Request>,
call_tx: mpsc::Sender<discord::Call>,
) -> StdErr<()> { ) -> StdErr<()> {
let mut base_res = Response::default(); let mut base_res = Response::default();
// Copy headers from the invite // Copy headers from the invite
@ -230,7 +238,7 @@ impl Server {
let socket = UdpSocket::bind(format!("{}:0", BIND_ADDR)).await?; let socket = UdpSocket::bind(format!("{}:0", BIND_ADDR)).await?;
let rtp_port = socket.local_addr()?.port(); let rtp_port = socket.local_addr()?.port();
event!(Level::INFO, rtp_port, "Bound RTP port"); event!(Level::INFO, rtp_port, "Bound RTP port");
let mut rtp_framed = UdpFramed::new(socket, RtpCodec{}); let mut rtp_framed = UdpFramed::new(socket, RtpCodec {});
let mut res = base_res.clone(); let mut res = base_res.clone();
res.status_code = 180.into(); res.status_code = 180.into();
responses.send(res).await; responses.send(res).await;
@ -240,8 +248,11 @@ impl Server {
res.status_code = 200.into(); res.status_code = 200.into();
// TODO: fix this lmao // TODO: fix this lmao
let ip: IpAddr = "10.23.2.134".parse()?; let ip: IpAddr = "10.23.2.134".parse()?;
res.headers.push(SipHeader::Contact(format!("sip:{}:{}", ip, SIP_PORT).into())); res.headers.push(SipHeader::Contact(
res.headers.push(SipHeader::ContentType("application/sdp".into())); format!("sip:{}:{}", ip, SIP_PORT).into(),
));
res.headers
.push(SipHeader::ContentType("application/sdp".into()));
let md = MediaDescription { let md = MediaDescription {
media: Media { media: Media {
media: MediaType::Audio, media: MediaType::Audio,
@ -250,19 +261,26 @@ impl Server {
proto: ProtoType::RtpAvp, proto: ProtoType::RtpAvp,
fmt: "101".into(), fmt: "101".into(),
}, },
connections: vec!(sdp_rs::lines::Connection{ connections: vec![sdp_rs::lines::Connection {
nettype: "IN".into(), nettype: "IN".into(),
addrtype: "IP4".into(), addrtype: "IP4".into(),
connection_address: ip.into(), connection_address: ip.into(),
}), }],
bandwidths: vec!(sdp_rs::lines::Bandwidth{bwtype:"TIAS".into(), bandwidth: 64000}), bandwidths: vec![sdp_rs::lines::Bandwidth {
attributes: vec!(sdp_rs::lines::Attribute::Other("rtpmap".into(), Some("101 opus/48000/2".into()))), bwtype: "TIAS".into(),
bandwidth: 64000,
}],
attributes: vec![sdp_rs::lines::Attribute::Other(
"rtpmap".into(),
Some("101 opus/48000/2".into()),
)],
info: None, info: None,
key: None, key: None,
}.into(); }
.into();
let sd: String = SessionDescription { let sd: String = SessionDescription {
version: sdp_rs::lines::Version::V0, version: sdp_rs::lines::Version::V0,
origin: sdp_rs::lines::Origin{ origin: sdp_rs::lines::Origin {
username: "-".into(), username: "-".into(),
sess_id: "foobar".into(), sess_id: "foobar".into(),
sess_version: "foobar".into(), sess_version: "foobar".into(),
@ -273,26 +291,37 @@ impl Server {
session_name: "discosip".to_owned().into(), session_name: "discosip".to_owned().into(),
session_info: None, session_info: None,
uri: None, uri: None,
emails: vec!(), emails: vec![],
phones: vec!(), phones: vec![],
connection: None, connection: None,
bandwidths: vec!(sdp_rs::lines::Bandwidth{bwtype:"AS".into(), bandwidth: 84}), bandwidths: vec![sdp_rs::lines::Bandwidth {
times: vec!(sdp_rs::Time{ bwtype: "AS".into(),
active: sdp_rs::lines::Active{ bandwidth: 84,
start: 0, }],
stop: 0, times: vec![sdp_rs::Time {
}, active: sdp_rs::lines::Active { start: 0, stop: 0 },
repeat: vec!(), repeat: vec![],
zone: None, zone: None,
}).try_into()?, }]
.try_into()?,
key: None, key: None,
attributes: vec!(), attributes: vec![],
media_descriptions: vec!(md), media_descriptions: vec![md],
}.to_string(); }
.to_string();
res.body = sd.as_bytes().into(); res.body = sd.as_bytes().into();
res.headers.push(SipHeader::ContentLength((res.body.len() as u32).into())); res.headers
.push(SipHeader::ContentLength((res.body.len() as u32).into()));
responses.send(res).await; responses.send(res).await;
// Discord
let (disc_done, disc_done_rx) = mpsc::channel(1);
let disc_call = discord::Call {
guild_id: env::var("DISCORD_GUILD")?.parse()?,
channel_id: env::var("DISCORD_CHANNEL")?.parse()?,
done: disc_done_rx,
};
call_tx.send(disc_call).await?;
loop { loop {
tokio::select! { tokio::select! {
@ -308,12 +337,12 @@ impl Server {
} }
}, },
Some(rtp_frame) = rtp_framed.next() => { Some(_) = rtp_framed.next() => {
event!(Level::INFO, "Got RTP Packet!"); //event!(Level::INFO, "Got RTP Packet!");
}, },
} }
} }
disc_done.send(()).await?;
event!(Level::INFO, "Call handler loop done."); event!(Level::INFO, "Call handler loop done.");
Ok(()) Ok(())
@ -327,12 +356,13 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let (shutdown_tx, shutdown_rx_1) = broadcast::channel(1); let (shutdown_tx, shutdown_rx_1) = broadcast::channel(1);
let (done_tx, mut done_rx) = mpsc::channel(1); let (done_tx, mut done_rx) = mpsc::channel(1);
let done_2 = done_tx.clone(); let done_2 = done_tx.clone();
tokio::spawn( async move { let (call_tx, call_rx) = mpsc::channel(1);
discord::run_discord(&discord_token, shutdown_rx_1, done_tx.clone()).await; tokio::spawn(async move {
discord::run_discord(&discord_token, call_rx, shutdown_rx_1, done_tx.clone()).await;
}); });
let sd_2 = shutdown_tx.subscribe(); let sd_2 = shutdown_tx.subscribe();
tokio::spawn(async move { tokio::spawn(async move {
Server::run_sip(sd_2, done_2).await; Server::run_sip(call_tx, sd_2, done_2).await;
}); });
signal::ctrl_c().await?; signal::ctrl_c().await?;