diff --git a/Cargo.lock b/Cargo.lock index e23cfa7..b740ffa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -509,6 +509,7 @@ dependencies = [ "poem-openapi", "serde", "sqlx", + "thiserror", "tokio", "tracing-subscriber", ] @@ -2019,18 +2020,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.49" +version = "1.0.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1177e8c6d7ede7afde3585fd2513e611227efd6481bd78d2e82ba1ce16557ed4" +checksum = "f9a7210f5c9a7156bb50aa36aed4c95afb51df0df00713949448cf9e97d382d2" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.49" +version = "1.0.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10712f02019e9288794769fba95cd6847df9874d49d871d062172f9dd41bc4cc" +checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 23fd865..c0d7d85 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,10 +10,8 @@ futures-util = "0.3.17" tracing-subscriber = { version ="0.3.9", features = ["env-filter"] } sqlx = { version = "0.7.2", features = ["runtime-tokio", "sqlite"] } bcrypt = "0.15.0" -anyhow = "1.0.75" poem-openapi = { version = "3.0.5", features = ["websocket"] } serde = "1.0.190" - -[profile.dev.package.sqlx-macros] -opt-level = 3 +thiserror = "1.0.50" +anyhow = "1.0.75" diff --git a/src/main.rs b/src/main.rs index b7e7b67..1f1c804 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,27 +1,26 @@ -//use anyhow::Result; -//use bcrypt::{hash, verify, DEFAULT_COST}; use futures_util::{SinkExt, StreamExt}; use poem::{ endpoint::StaticFilesEndpoint, - //get, - handler, + error::ResponseError, + get, handler, + http::StatusCode, listener::TcpListener, session::{CookieConfig, MemoryStorage, ServerSession, Session}, web::{ - websocket::{Message, WebSocket}, - Data, Html, Path, + websocket::{BoxWebSocketUpgraded, Message, WebSocket}, + Html, }, - EndpointExt, - IntoResponse, - Result, - Route, - Server, + EndpointExt, Result, Route, Server, }; use poem_openapi::{ - payload::Json, types::Password, ApiResponse, Object, OpenApi, OpenApiService, Tags, + payload::Json, + registry::{MetaResponse, MetaResponses, Registry}, + types::Password, + ApiResponse, Object, OpenApi, OpenApiService, Tags, }; use serde::{Deserialize, Serialize}; use sqlx::{Pool, Sqlite}; +use tokio::sync::broadcast::Sender; #[handler] fn index() -> Html<&'static str> { @@ -56,7 +55,7 @@ fn index() -> Html<&'static str> { sendForm.hidden = false; msgsArea.hidden = false; msgInput.focus(); - ws = new WebSocket("ws://127.0.0.1:2000/api/ws/" + nameInput.value); + ws = new WebSocket("ws://localhost:2000/api/ws"); ws.onmessage = function(event) { msgsArea.value += event.data + "\r\n"; } @@ -73,37 +72,6 @@ fn index() -> Html<&'static str> { ) } -#[handler] -fn ws( - Path(name): Path, - ws: WebSocket, - sender: Data<&tokio::sync::broadcast::Sender>, -) -> impl IntoResponse { - let sender = sender.clone(); - let mut receiver = sender.subscribe(); - ws.on_upgrade(move |socket| async move { - let (mut sink, mut stream) = socket.split(); - - tokio::spawn(async move { - while let Some(Ok(msg)) = stream.next().await { - if let Message::Text(text) = msg { - if sender.send(format!("{name}: {text}")).is_err() { - break; - } - } - } - }); - - tokio::spawn(async move { - while let Ok(msg) = receiver.recv().await { - if sink.send(Message::Text(msg)).await.is_err() { - break; - } - } - }); - }) -} - #[derive(Debug, Object, Clone, Eq, PartialEq, Serialize, Deserialize)] struct User { #[oai(validator(max_length = 64))] @@ -159,6 +127,7 @@ enum NumResponse { #[oai(status = 403)] AuthError, } + #[derive(Tags)] enum ApiTags { /// Operations about user @@ -167,12 +136,68 @@ enum ApiTags { async fn valid_code(code: &str) -> bool { "changeme" == code } -struct Api { - db: Pool, - signup_open: bool, //TODO Should be in the db so it can change without restart +#[derive(Debug, thiserror::Error)] +#[error("API Error")] +struct ApiError(); + +impl ResponseError for ApiError { + fn status(&self) -> StatusCode { + StatusCode::FORBIDDEN + } } +impl ApiResponse for ApiError { + fn meta() -> MetaResponses { + MetaResponses { + responses: vec![MetaResponse { + description: "An Error response", + status: Some(403), + content: vec![], + headers: vec![], + }], + } + } + fn register(_registry: &mut Registry) {} +} + #[OpenApi] impl Api { + #[oai(path = "/ws", method = "get")] + async fn ws( + &self, + websock: WebSocket, + session: &Session, + ) -> Result { + let name = session.get::("user").ok_or(ApiError {})?; + + let sender = self.channel.clone(); + let mut receiver = sender.subscribe(); + let x = websock + .on_upgrade(move |socket| async move { + let (mut sink, mut stream) = socket.split(); + + tokio::spawn(async move { + while let Some(Ok(msg)) = stream.next().await { + if let Message::Text(text) = msg { + if sender.send(format!("{name}: {text}")).is_err() { + break; + } + } + } + }); + + tokio::spawn(async move { + while let Ok(msg) = receiver.recv().await { + if sink.send(Message::Text(msg)).await.is_err() { + break; + } + } + }); + }) + .boxed(); + + Result::Ok(x) + } + #[oai(path = "/user", method = "post", tag = "ApiTags::User")] async fn create_user(&self, user_form: Json, session: &Session) -> NewUserResponse { let has_referral = match &user_form.referral { @@ -271,6 +296,12 @@ impl Api { } } +struct Api { + db: Pool, + channel: Sender, + signup_open: bool, //TODO Should be in the db so it can change without restart +} + #[tokio::main] async fn main() -> anyhow::Result<()> { if std::env::var_os("RUST_LOG").is_none() { @@ -280,11 +311,14 @@ async fn main() -> anyhow::Result<()> { let session = ServerSession::new(CookieConfig::default(), MemoryStorage::new()); + let channel = tokio::sync::broadcast::channel::(128).0; + //let dbpool = SqlitePool::connect("sqlite:chat.db").await?; let dbpool = Pool::::connect("sqlite:chat.db").await?; let api = OpenApiService::new( Api { db: dbpool, + channel, signup_open: false, }, "Chat", @@ -293,12 +327,7 @@ async fn main() -> anyhow::Result<()> { let static_endpoint = StaticFilesEndpoint::new("./client/dist").index_file("index.html"); let app = Route::new() - //.at("/api/", get(index)) - //.at( - // "/api/ws/:name", - // get(ws.data(tokio::sync::broadcast::channel::(32).0)), - //) - //.data(dbpool) + .at("/api/test", get(index)) .nest("/", static_endpoint) .nest("/api", api) .with(session);