342 lines
9.9 KiB
Rust

use futures_util::{SinkExt, StreamExt};
use poem::{
endpoint::StaticFilesEndpoint,
error::ResponseError,
get, handler,
http::StatusCode,
listener::TcpListener,
session::{CookieConfig, MemoryStorage, ServerSession, Session},
web::{
websocket::{BoxWebSocketUpgraded, Message, WebSocket},
Html,
},
EndpointExt, Result, Route, Server,
};
use poem_openapi::{
payload::Json,
registry::{MetaResponse, MetaResponses, Registry},
types::Password,
ApiResponse, Object, OpenApi, OpenApiService, Tags,
};
use serde::{Deserialize, Serialize};
use sqlx::{Pool, Sqlite};
use tokio::sync::broadcast::Sender;
/// Serve a small self-contained HTML page for trying the chat WebSocket by hand.
///
/// Submitting the login form reveals the message form and transcript area and
/// opens a WebSocket to `/api/ws` on the current host (`wss://` when the page
/// was loaded over HTTPS, `ws://` otherwise). Incoming frames are appended to
/// the transcript; submitting the message form sends the typed text. The value
/// typed into the name field is not transmitted by this script.
#[handler]
fn index() -> Html<&'static str> {
    // The markup is kept verbatim in one constant so the handler body is a
    // single expression.
    const PAGE: &str = r###"
<body>
<form id="loginForm">
Name: <input id="nameInput" type="text" />
<button type="submit">Login</button>
</form>
<form id="sendForm" hidden>
Text: <input id="msgInput" type="text" />
<button type="submit">Send</button>
</form>
<textarea id="msgsArea" cols="50" rows="30" hidden></textarea>
</body>
<script>
let ws;
const loginForm = document.querySelector("#loginForm");
const sendForm = document.querySelector("#sendForm");
const nameInput = document.querySelector("#nameInput");
const msgInput = document.querySelector("#msgInput");
const msgsArea = document.querySelector("#msgsArea");
nameInput.focus();
loginForm.addEventListener("submit", function(event) {
event.preventDefault();
loginForm.hidden = true;
sendForm.hidden = false;
msgsArea.hidden = false;
msgInput.focus();
let protocol = window.location.protocol == "https:" ? "wss://" : "ws://";
ws = new WebSocket(protocol + window.location.host + "/api/ws");
ws.onmessage = function(event) {
msgsArea.value += event.data + "\r\n";
}
});
sendForm.addEventListener("submit", function(event) {
event.preventDefault();
ws.send(msgInput.value);
msgInput.value = "";
});
</script>
"###;
    Html(PAGE)
}
// Public representation of an account as returned by the API: just the
// username. Used as the JSON payload of the successful `UserResponse` and
// `NewUserResponse` variants.
#[derive(Debug, Object, Clone, Eq, PartialEq, Serialize, Deserialize)]
struct User {
// Validated to be at most 64 characters long.
#[oai(validator(max_length = 64))]
username: String,
}
// Response type shared by the "who am I" and login endpoints.
#[derive(ApiResponse)]
enum UserResponse {
// 200: the session user, serialized as JSON.
#[oai(status = 200)]
Ok(Json<User>),
// 403: no user in the session, or the supplied credentials were rejected.
#[oai(status = 403)]
AuthError,
}
// Request body for account creation. The length validators below are declared
// through `#[oai(validator(...))]`; the referral code is optional.
#[derive(Debug, Object, Clone, Eq, PartialEq)]
struct NewUser {
/// Username
#[oai(validator(max_length = 64))]
username: String,
/// Password
#[oai(validator(max_length = 50))]
password: Password,
/// Invite / Referral Code
#[oai(validator(max_length = 8))]
referral: Option<String>,
}
// Possible outcomes of the account-creation endpoint.
//
// NOTE(review): four distinct failure variants all map to HTTP 403, so a client
// cannot tell them apart from the status code alone — confirm this is intended.
#[derive(ApiResponse)]
enum NewUserResponse {
// 200: the account was created; carries the new user as JSON.
#[oai(status = 200)]
Ok(Json<User>),
// 403: a row with the requested username already exists.
#[oai(status = 403)]
UsernameTaken,
// 403: a referral code was supplied but did not validate.
#[oai(status = 403)]
InvalidReferral,
// 403: signup is not open and no valid referral code was supplied.
#[oai(status = 403)]
SignupClosed,
// 403: the password could not be hashed.
#[oai(status = 403)]
InvalidPassword,
// 500: a database query failed while creating the account.
#[oai(status = 500)]
InternalServerError,
}
// Request body for the login endpoint: the credentials to verify. The length
// limits mirror those on `NewUser`.
#[derive(Debug, Object, Clone, Eq, PartialEq)]
struct UserLogin {
/// Username
#[oai(validator(max_length = 64))]
username: String,
/// Password
#[oai(validator(max_length = 50))]
password: Password,
}
// Response type of the per-session counter endpoint.
#[derive(ApiResponse)]
enum NumResponse {
// 200: the counter value after this request, as a JSON number.
#[oai(status = 200)]
Ok(Json<u32>),
// 403: the session has no logged-in user.
#[oai(status = 403)]
AuthError,
}
// Tag groups referenced by the `tag = "ApiTags::..."` arguments on the API
// operations.
#[derive(Tags)]
enum ApiTags {
/// Operations about user
User,
}
/// Report whether `code` is an accepted invite/referral code.
///
/// The only accepted value is the single hard-coded string matched below;
/// every other input, including the empty string, is rejected. The comparison
/// is exact and case-sensitive.
async fn valid_code(code: &str) -> bool {
    matches!(code, "changeme")
}
/// Payload-free error returned by handlers that reject a request.
///
/// Its display text is the fixed string given in the `#[error]` attribute; it
/// carries no information about the cause.
#[derive(Debug, thiserror::Error)]
#[error("API Error")]
struct ApiError();
/// Maps every [`ApiError`] to an HTTP 403 response.
impl ResponseError for ApiError {
/// Always `403 Forbidden`, regardless of what caused the error.
fn status(&self) -> StatusCode {
StatusCode::FORBIDDEN
}
}
/// OpenAPI metadata for [`ApiError`], so it can appear as the error type of an
/// operation.
impl ApiResponse for ApiError {
    /// Describe the one response this error produces: status 403 with no
    /// declared content and no declared headers.
    fn meta() -> MetaResponses {
        let forbidden = MetaResponse {
            description: "An Error response",
            status: Some(403),
            content: Vec::new(),
            headers: Vec::new(),
        };
        MetaResponses {
            responses: vec![forbidden],
        }
    }

    /// Intentionally empty: the response declares no content, so there are no
    /// schema types to add to the registry.
    fn register(_registry: &mut Registry) {}
}
#[OpenApi]
impl Api {
#[oai(path = "/ws", method = "get")]
async fn ws(
&self,
websock: WebSocket,
session: &Session,
) -> Result<BoxWebSocketUpgraded, ApiError> {
let name = session.get::<String>("user").ok_or(ApiError {})?;
let sender = self.channel.clone();
let mut receiver = sender.subscribe();
let x = websock
.on_upgrade(move |socket| async move {
let (mut sink, mut stream) = socket.split();
tokio::spawn(async move {
while let Some(Ok(msg)) = stream.next().await {
if let Message::Text(text) = msg {
if sender.send(format!("{name}: {text}")).is_err() {
break;
}
}
}
});
tokio::spawn(async move {
while let Ok(msg) = receiver.recv().await {
if sink.send(Message::Text(msg)).await.is_err() {
break;
}
}
});
})
.boxed();
Result::Ok(x)
}
#[oai(path = "/user", method = "post", tag = "ApiTags::User")]
async fn create_user(&self, user_form: Json<NewUser>, session: &Session) -> NewUserResponse {
let has_referral = match &user_form.referral {
Some(code) if valid_code(code).await => true,
Some(_) => return NewUserResponse::InvalidReferral,
None => false,
};
if !self.signup_open && !has_referral {
return NewUserResponse::SignupClosed;
}
//TODO propper handling of query errors
let username = user_form.username.clone();
let x = sqlx::query!("SELECT * FROM users WHERE username = ?", username)
.fetch_optional(&self.db)
.await;
if let Ok(Some(_current_user)) = x {
return NewUserResponse::UsernameTaken;
}
let hash = match bcrypt::hash(user_form.password.as_str(), bcrypt::DEFAULT_COST) {
Ok(hash) => hash,
_ => return NewUserResponse::InvalidPassword,
};
let res = sqlx::query!(
"INSERT INTO users(username, password, admin) VALUES(?,?,?)",
username,
hash,
false
)
.execute(&self.db)
.await;
match res {
Err(_e) => NewUserResponse::InternalServerError,
Ok(_) => {
let user = User { username };
session.set("user", user.username.clone());
NewUserResponse::Ok(Json(user))
}
}
}
#[oai(path = "/user", method = "get", tag = "ApiTags::User")]
async fn get_user(&self, session: &Session) -> UserResponse {
if let Some(username) = session.get::<String>("user") {
UserResponse::Ok(Json(User { username }))
} else {
UserResponse::AuthError
}
}
#[oai(path = "/auth", method = "delete", tag = "ApiTags::User")]
async fn deauth_user(&self, session: &Session) -> Result<()> {
session.purge();
Result::Ok(())
}
#[oai(path = "/auth", method = "post", tag = "ApiTags::User")]
async fn auth_user(&self, user: Json<UserLogin>, session: &Session) -> UserResponse {
let password = user.password.as_str();
let username = user.username.to_string();
let result = sqlx::query!("SELECT * FROM users WHERE username = ?", username)
.fetch_one(&self.db)
.await;
match result {
Ok(user) if bcrypt::verify(password, &user.password).unwrap_or(false) => {
let current_user = User {
username: username.clone(),
};
session.set("user", username);
UserResponse::Ok(Json(current_user))
}
_ => UserResponse::AuthError,
}
}
#[oai(path = "/num", method = "get", tag = "ApiTags::User")]
//async fn get_num(&self, session: &Session) -> Json<u32> {
async fn get_num(&self, session: &Session) -> NumResponse {
//if session.get("user")
if let None = session.get::<String>("user") {
return NumResponse::AuthError;
}
match session.get::<u32>("num") {
Some(i) => {
let new: u32 = i + 1;
session.set("num", new);
NumResponse::Ok(Json(new))
}
None => {
session.set("num", 1_u32);
NumResponse::Ok(Json(1))
}
}
}
}
/// Shared state behind every API operation.
struct Api {
/// SQLite connection pool used for the `users` table queries.
db: Pool<Sqlite>,
/// Sending half of the chat broadcast channel; WebSocket connections clone
/// it to publish messages and call `subscribe()` on it to receive them.
channel: Sender<String>,
/// When `false`, account creation requires a valid referral code.
signup_open: bool, //TODO Should be in the db so it can change without restart
}
/// Configure logging, open the database, assemble the routes and serve the
/// application on `127.0.0.1:3000` until the server stops or fails.
///
/// Routes: `/api/test` serves the built-in test page, `/api` hosts the OpenAPI
/// service, and everything else is served from `./client/dist` with
/// `index.html` as the index file. Sessions are cookie-based and stored in
/// memory; signup starts out closed.
///
/// # Errors
///
/// Returns an error if the SQLite database `sqlite:chat.db` cannot be opened
/// or if the server terminates with an error.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Default to poem's debug logs unless the caller already picked a filter.
    // NOTE(review): this mutates the process environment from inside the
    // already-running runtime — confirm that is acceptable on the target
    // platform.
    if std::env::var_os("RUST_LOG").is_none() {
        std::env::set_var("RUST_LOG", "poem=debug");
    }
    tracing_subscriber::fmt::init();

    let sessions = ServerSession::new(CookieConfig::default(), MemoryStorage::new());

    // Only the sender is kept; each WebSocket connection subscribes its own
    // receiver later, so the initial receiver is dropped right away.
    let (channel, _) = tokio::sync::broadcast::channel::<String>(128);

    let db = Pool::<Sqlite>::connect("sqlite:chat.db").await?;

    let state = Api {
        db,
        channel,
        signup_open: false,
    };
    let api = OpenApiService::new(state, "Chat", "0.0");

    let static_files = StaticFilesEndpoint::new("./client/dist").index_file("index.html");

    let app = Route::new()
        .at("/api/test", get(index))
        .nest("/", static_files)
        .nest("/api", api)
        .with(sessions);

    let listener = TcpListener::bind("127.0.0.1:3000");
    Server::new(listener).run(app).await?;
    Ok(())
}