use futures_util::{SinkExt, StreamExt}; use poem::{ endpoint::StaticFilesEndpoint, error::ResponseError, get, handler, http::StatusCode, listener::TcpListener, session::{CookieConfig, MemoryStorage, ServerSession, Session}, web::{ websocket::{BoxWebSocketUpgraded, Message, WebSocket}, Html, }, EndpointExt, Result, Route, Server, }; use poem_openapi::{ payload::Json, registry::{MetaResponse, MetaResponses, Registry}, types::Password, ApiResponse, Object, OpenApi, OpenApiService, Tags, }; use serde::{Deserialize, Serialize}; use sqlx::{Pool, Sqlite}; use tokio::sync::broadcast::Sender; #[handler] fn index() -> Html<&'static str> { Html( r###"
Name:
"###, ) } #[derive(Debug, Object, Clone, Eq, PartialEq, Serialize, Deserialize)] struct User { #[oai(validator(max_length = 64))] username: String, } #[derive(ApiResponse)] enum UserResponse { #[oai(status = 200)] Ok(Json), #[oai(status = 403)] AuthError, } #[derive(Debug, Object, Clone, Eq, PartialEq)] struct NewUser { /// Username #[oai(validator(max_length = 64))] username: String, /// Password #[oai(validator(max_length = 50))] password: Password, /// Invite / Referral Code #[oai(validator(max_length = 8))] referral: Option, } #[derive(ApiResponse)] enum NewUserResponse { #[oai(status = 200)] Ok(Json), #[oai(status = 403)] UsernameTaken, #[oai(status = 403)] InvalidReferral, #[oai(status = 403)] SignupClosed, #[oai(status = 403)] InvalidPassword, #[oai(status = 500)] InternalServerError, } #[derive(Debug, Object, Clone, Eq, PartialEq)] struct UserLogin { /// Username #[oai(validator(max_length = 64))] username: String, /// Password #[oai(validator(max_length = 50))] password: Password, } #[derive(ApiResponse)] enum NumResponse { #[oai(status = 200)] Ok(Json), #[oai(status = 403)] AuthError, } #[derive(Tags)] enum ApiTags { /// Operations about user User, } async fn valid_code(code: &str) -> bool { "changeme" == code } #[derive(Debug, thiserror::Error)] #[error("API Error")] struct ApiError(); impl ResponseError for ApiError { fn status(&self) -> StatusCode { StatusCode::FORBIDDEN } } impl ApiResponse for ApiError { fn meta() -> MetaResponses { MetaResponses { responses: vec![MetaResponse { description: "An Error response", status: Some(403), content: vec![], headers: vec![], }], } } fn register(_registry: &mut Registry) {} } #[OpenApi] impl Api { #[oai(path = "/ws", method = "get")] async fn ws( &self, websock: WebSocket, session: &Session, ) -> Result { let name = session.get::("user").ok_or(ApiError {})?; let sender = self.channel.clone(); let mut receiver = sender.subscribe(); let x = websock .on_upgrade(move |socket| async move { let (mut sink, mut stream) = socket.split(); tokio::spawn(async move { while let Some(Ok(msg)) = stream.next().await { if let Message::Text(text) = msg { if sender.send(format!("{name}: {text}")).is_err() { break; } } } }); tokio::spawn(async move { while let Ok(msg) = receiver.recv().await { if sink.send(Message::Text(msg)).await.is_err() { break; } } }); }) .boxed(); Result::Ok(x) } #[oai(path = "/user", method = "post", tag = "ApiTags::User")] async fn create_user(&self, user_form: Json, session: &Session) -> NewUserResponse { let has_referral = match &user_form.referral { Some(code) if valid_code(code).await => true, Some(_) => return NewUserResponse::InvalidReferral, None => false, }; if !self.signup_open && !has_referral { return NewUserResponse::SignupClosed; } //TODO propper handling of query errors let username = user_form.username.clone(); let x = sqlx::query!("SELECT * FROM users WHERE username = ?", username) .fetch_optional(&self.db) .await; if let Ok(Some(_current_user)) = x { return NewUserResponse::UsernameTaken; } let hash = match bcrypt::hash(user_form.password.as_str(), bcrypt::DEFAULT_COST) { Ok(hash) => hash, _ => return NewUserResponse::InvalidPassword, }; let res = sqlx::query!( "INSERT INTO users(username, password, admin) VALUES(?,?,?)", username, hash, false ) .execute(&self.db) .await; match res { Err(_e) => NewUserResponse::InternalServerError, Ok(_) => { let user = User { username }; session.set("user", user.username.clone()); NewUserResponse::Ok(Json(user)) } } } #[oai(path = "/user", method = "get", tag = "ApiTags::User")] async fn get_user(&self, session: &Session) -> UserResponse { if let Some(username) = session.get::("user") { UserResponse::Ok(Json(User { username })) } else { UserResponse::AuthError } } #[oai(path = "/auth", method = "delete", tag = "ApiTags::User")] async fn deauth_user(&self, session: &Session) -> Result<()> { session.purge(); Result::Ok(()) } #[oai(path = "/auth", method = "post", tag = "ApiTags::User")] async fn auth_user(&self, user: Json, session: &Session) -> UserResponse { let password = user.password.as_str(); let username = user.username.to_string(); let result = sqlx::query!("SELECT * FROM users WHERE username = ?", username) .fetch_one(&self.db) .await; match result { Ok(user) if bcrypt::verify(password, &user.password).unwrap_or(false) => { let current_user = User { username: username.clone(), }; session.set("user", username); UserResponse::Ok(Json(current_user)) } _ => UserResponse::AuthError, } } #[oai(path = "/num", method = "get", tag = "ApiTags::User")] //async fn get_num(&self, session: &Session) -> Json { async fn get_num(&self, session: &Session) -> NumResponse { //if session.get("user") if let None = session.get::("user") { return NumResponse::AuthError; } match session.get::("num") { Some(i) => { let new: u32 = i + 1; session.set("num", new); NumResponse::Ok(Json(new)) } None => { session.set("num", 1_u32); NumResponse::Ok(Json(1)) } } } } struct Api { db: Pool, channel: Sender, signup_open: bool, //TODO Should be in the db so it can change without restart } #[tokio::main] async fn main() -> anyhow::Result<()> { if std::env::var_os("RUST_LOG").is_none() { std::env::set_var("RUST_LOG", "poem=debug"); } tracing_subscriber::fmt::init(); let session = ServerSession::new(CookieConfig::default(), MemoryStorage::new()); let channel = tokio::sync::broadcast::channel::(128).0; //let dbpool = SqlitePool::connect("sqlite:chat.db").await?; let dbpool = Pool::::connect("sqlite:chat.db").await?; let api = OpenApiService::new( Api { db: dbpool, channel, signup_open: false, }, "Chat", "0.0", ); let static_endpoint = StaticFilesEndpoint::new("./client/dist").index_file("index.html"); let app = Route::new() .at("/api/test", get(index)) .nest("/", static_endpoint) .nest("/api", api) .with(session); Server::new(TcpListener::bind("127.0.0.1:3000")) .run(app) .await?; Ok(()) }