Add chat websocket

This commit is contained in:
Lucas Schumacher 2023-11-01 00:14:34 -04:00
parent 62621e7fe6
commit 748b7968cd
3 changed files with 89 additions and 61 deletions

9
Cargo.lock generated
View File

@ -509,6 +509,7 @@ dependencies = [
"poem-openapi", "poem-openapi",
"serde", "serde",
"sqlx", "sqlx",
"thiserror",
"tokio", "tokio",
"tracing-subscriber", "tracing-subscriber",
] ]
@ -2019,18 +2020,18 @@ dependencies = [
[[package]] [[package]]
name = "thiserror" name = "thiserror"
version = "1.0.49" version = "1.0.50"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1177e8c6d7ede7afde3585fd2513e611227efd6481bd78d2e82ba1ce16557ed4" checksum = "f9a7210f5c9a7156bb50aa36aed4c95afb51df0df00713949448cf9e97d382d2"
dependencies = [ dependencies = [
"thiserror-impl", "thiserror-impl",
] ]
[[package]] [[package]]
name = "thiserror-impl" name = "thiserror-impl"
version = "1.0.49" version = "1.0.50"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10712f02019e9288794769fba95cd6847df9874d49d871d062172f9dd41bc4cc" checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",

View File

@ -10,10 +10,8 @@ futures-util = "0.3.17"
tracing-subscriber = { version ="0.3.9", features = ["env-filter"] } tracing-subscriber = { version ="0.3.9", features = ["env-filter"] }
sqlx = { version = "0.7.2", features = ["runtime-tokio", "sqlite"] } sqlx = { version = "0.7.2", features = ["runtime-tokio", "sqlite"] }
bcrypt = "0.15.0" bcrypt = "0.15.0"
anyhow = "1.0.75"
poem-openapi = { version = "3.0.5", features = ["websocket"] } poem-openapi = { version = "3.0.5", features = ["websocket"] }
serde = "1.0.190" serde = "1.0.190"
thiserror = "1.0.50"
[profile.dev.package.sqlx-macros] anyhow = "1.0.75"
opt-level = 3

View File

@ -1,27 +1,26 @@
//use anyhow::Result;
//use bcrypt::{hash, verify, DEFAULT_COST};
use futures_util::{SinkExt, StreamExt}; use futures_util::{SinkExt, StreamExt};
use poem::{ use poem::{
endpoint::StaticFilesEndpoint, endpoint::StaticFilesEndpoint,
//get, error::ResponseError,
handler, get, handler,
http::StatusCode,
listener::TcpListener, listener::TcpListener,
session::{CookieConfig, MemoryStorage, ServerSession, Session}, session::{CookieConfig, MemoryStorage, ServerSession, Session},
web::{ web::{
websocket::{Message, WebSocket}, websocket::{BoxWebSocketUpgraded, Message, WebSocket},
Data, Html, Path, Html,
}, },
EndpointExt, EndpointExt, Result, Route, Server,
IntoResponse,
Result,
Route,
Server,
}; };
use poem_openapi::{ use poem_openapi::{
payload::Json, types::Password, ApiResponse, Object, OpenApi, OpenApiService, Tags, payload::Json,
registry::{MetaResponse, MetaResponses, Registry},
types::Password,
ApiResponse, Object, OpenApi, OpenApiService, Tags,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::{Pool, Sqlite}; use sqlx::{Pool, Sqlite};
use tokio::sync::broadcast::Sender;
#[handler] #[handler]
fn index() -> Html<&'static str> { fn index() -> Html<&'static str> {
@ -56,7 +55,7 @@ fn index() -> Html<&'static str> {
sendForm.hidden = false; sendForm.hidden = false;
msgsArea.hidden = false; msgsArea.hidden = false;
msgInput.focus(); msgInput.focus();
ws = new WebSocket("ws://127.0.0.1:2000/api/ws/" + nameInput.value); ws = new WebSocket("ws://localhost:2000/api/ws");
ws.onmessage = function(event) { ws.onmessage = function(event) {
msgsArea.value += event.data + "\r\n"; msgsArea.value += event.data + "\r\n";
} }
@ -73,37 +72,6 @@ fn index() -> Html<&'static str> {
) )
} }
#[handler]
fn ws(
Path(name): Path<String>,
ws: WebSocket,
sender: Data<&tokio::sync::broadcast::Sender<String>>,
) -> impl IntoResponse {
let sender = sender.clone();
let mut receiver = sender.subscribe();
ws.on_upgrade(move |socket| async move {
let (mut sink, mut stream) = socket.split();
tokio::spawn(async move {
while let Some(Ok(msg)) = stream.next().await {
if let Message::Text(text) = msg {
if sender.send(format!("{name}: {text}")).is_err() {
break;
}
}
}
});
tokio::spawn(async move {
while let Ok(msg) = receiver.recv().await {
if sink.send(Message::Text(msg)).await.is_err() {
break;
}
}
});
})
}
#[derive(Debug, Object, Clone, Eq, PartialEq, Serialize, Deserialize)] #[derive(Debug, Object, Clone, Eq, PartialEq, Serialize, Deserialize)]
struct User { struct User {
#[oai(validator(max_length = 64))] #[oai(validator(max_length = 64))]
@ -159,6 +127,7 @@ enum NumResponse {
#[oai(status = 403)] #[oai(status = 403)]
AuthError, AuthError,
} }
#[derive(Tags)] #[derive(Tags)]
enum ApiTags { enum ApiTags {
/// Operations about user /// Operations about user
@ -167,12 +136,68 @@ enum ApiTags {
async fn valid_code(code: &str) -> bool { async fn valid_code(code: &str) -> bool {
"changeme" == code "changeme" == code
} }
struct Api { #[derive(Debug, thiserror::Error)]
db: Pool<Sqlite>, #[error("API Error")]
signup_open: bool, //TODO Should be in the db so it can change without restart struct ApiError();
impl ResponseError for ApiError {
fn status(&self) -> StatusCode {
StatusCode::FORBIDDEN
}
} }
impl ApiResponse for ApiError {
fn meta() -> MetaResponses {
MetaResponses {
responses: vec![MetaResponse {
description: "An Error response",
status: Some(403),
content: vec![],
headers: vec![],
}],
}
}
fn register(_registry: &mut Registry) {}
}
#[OpenApi] #[OpenApi]
impl Api { impl Api {
#[oai(path = "/ws", method = "get")]
async fn ws(
&self,
websock: WebSocket,
session: &Session,
) -> Result<BoxWebSocketUpgraded, ApiError> {
let name = session.get::<String>("user").ok_or(ApiError {})?;
let sender = self.channel.clone();
let mut receiver = sender.subscribe();
let x = websock
.on_upgrade(move |socket| async move {
let (mut sink, mut stream) = socket.split();
tokio::spawn(async move {
while let Some(Ok(msg)) = stream.next().await {
if let Message::Text(text) = msg {
if sender.send(format!("{name}: {text}")).is_err() {
break;
}
}
}
});
tokio::spawn(async move {
while let Ok(msg) = receiver.recv().await {
if sink.send(Message::Text(msg)).await.is_err() {
break;
}
}
});
})
.boxed();
Result::Ok(x)
}
#[oai(path = "/user", method = "post", tag = "ApiTags::User")] #[oai(path = "/user", method = "post", tag = "ApiTags::User")]
async fn create_user(&self, user_form: Json<NewUser>, session: &Session) -> NewUserResponse { async fn create_user(&self, user_form: Json<NewUser>, session: &Session) -> NewUserResponse {
let has_referral = match &user_form.referral { let has_referral = match &user_form.referral {
@ -271,6 +296,12 @@ impl Api {
} }
} }
struct Api {
db: Pool<Sqlite>,
channel: Sender<String>,
signup_open: bool, //TODO Should be in the db so it can change without restart
}
#[tokio::main] #[tokio::main]
async fn main() -> anyhow::Result<()> { async fn main() -> anyhow::Result<()> {
if std::env::var_os("RUST_LOG").is_none() { if std::env::var_os("RUST_LOG").is_none() {
@ -280,11 +311,14 @@ async fn main() -> anyhow::Result<()> {
let session = ServerSession::new(CookieConfig::default(), MemoryStorage::new()); let session = ServerSession::new(CookieConfig::default(), MemoryStorage::new());
let channel = tokio::sync::broadcast::channel::<String>(128).0;
//let dbpool = SqlitePool::connect("sqlite:chat.db").await?; //let dbpool = SqlitePool::connect("sqlite:chat.db").await?;
let dbpool = Pool::<Sqlite>::connect("sqlite:chat.db").await?; let dbpool = Pool::<Sqlite>::connect("sqlite:chat.db").await?;
let api = OpenApiService::new( let api = OpenApiService::new(
Api { Api {
db: dbpool, db: dbpool,
channel,
signup_open: false, signup_open: false,
}, },
"Chat", "Chat",
@ -293,12 +327,7 @@ async fn main() -> anyhow::Result<()> {
let static_endpoint = StaticFilesEndpoint::new("./client/dist").index_file("index.html"); let static_endpoint = StaticFilesEndpoint::new("./client/dist").index_file("index.html");
let app = Route::new() let app = Route::new()
//.at("/api/", get(index)) .at("/api/test", get(index))
//.at(
// "/api/ws/:name",
// get(ws.data(tokio::sync::broadcast::channel::<String>(32).0)),
//)
//.data(dbpool)
.nest("/", static_endpoint) .nest("/", static_endpoint)
.nest("/api", api) .nest("/api", api)
.with(session); .with(session);