312 lines
8.9 KiB
Rust
312 lines
8.9 KiB
Rust
//use anyhow::Result;
|
|
//use bcrypt::{hash, verify, DEFAULT_COST};
|
|
use futures_util::{SinkExt, StreamExt};
|
|
use poem::{
|
|
endpoint::StaticFilesEndpoint,
|
|
//get,
|
|
handler,
|
|
listener::TcpListener,
|
|
session::{CookieConfig, MemoryStorage, ServerSession, Session},
|
|
web::{
|
|
websocket::{Message, WebSocket},
|
|
Data, Html, Path,
|
|
},
|
|
EndpointExt,
|
|
IntoResponse,
|
|
Result,
|
|
Route,
|
|
Server,
|
|
};
|
|
use poem_openapi::{
|
|
payload::Json, types::Password, ApiResponse, Object, OpenApi, OpenApiService, Tags,
|
|
};
|
|
use serde::{Deserialize, Serialize};
|
|
use sqlx::{Pool, Sqlite};
|
|
|
|
#[handler]
|
|
fn index() -> Html<&'static str> {
|
|
Html(
|
|
r###"
|
|
<body>
|
|
<form id="loginForm">
|
|
Name: <input id="nameInput" type="text" />
|
|
<button type="submit">Login</button>
|
|
</form>
|
|
|
|
<form id="sendForm" hidden>
|
|
Text: <input id="msgInput" type="text" />
|
|
<button type="submit">Send</button>
|
|
</form>
|
|
|
|
<textarea id="msgsArea" cols="50" rows="30" hidden></textarea>
|
|
</body>
|
|
<script>
|
|
let ws;
|
|
const loginForm = document.querySelector("#loginForm");
|
|
const sendForm = document.querySelector("#sendForm");
|
|
const nameInput = document.querySelector("#nameInput");
|
|
const msgInput = document.querySelector("#msgInput");
|
|
const msgsArea = document.querySelector("#msgsArea");
|
|
|
|
nameInput.focus();
|
|
|
|
loginForm.addEventListener("submit", function(event) {
|
|
event.preventDefault();
|
|
loginForm.hidden = true;
|
|
sendForm.hidden = false;
|
|
msgsArea.hidden = false;
|
|
msgInput.focus();
|
|
ws = new WebSocket("ws://127.0.0.1:2000/api/ws/" + nameInput.value);
|
|
ws.onmessage = function(event) {
|
|
msgsArea.value += event.data + "\r\n";
|
|
}
|
|
});
|
|
|
|
sendForm.addEventListener("submit", function(event) {
|
|
event.preventDefault();
|
|
ws.send(msgInput.value);
|
|
msgInput.value = "";
|
|
});
|
|
|
|
</script>
|
|
"###,
|
|
)
|
|
}
|
|
|
|
#[handler]
|
|
fn ws(
|
|
Path(name): Path<String>,
|
|
ws: WebSocket,
|
|
sender: Data<&tokio::sync::broadcast::Sender<String>>,
|
|
) -> impl IntoResponse {
|
|
let sender = sender.clone();
|
|
let mut receiver = sender.subscribe();
|
|
ws.on_upgrade(move |socket| async move {
|
|
let (mut sink, mut stream) = socket.split();
|
|
|
|
tokio::spawn(async move {
|
|
while let Some(Ok(msg)) = stream.next().await {
|
|
if let Message::Text(text) = msg {
|
|
if sender.send(format!("{name}: {text}")).is_err() {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
tokio::spawn(async move {
|
|
while let Ok(msg) = receiver.recv().await {
|
|
if sink.send(Message::Text(msg)).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
});
|
|
})
|
|
}
|
|
|
|
#[derive(Debug, Object, Clone, Eq, PartialEq, Serialize, Deserialize)]
|
|
struct User {
|
|
#[oai(validator(max_length = 64))]
|
|
username: String,
|
|
}
|
|
// Response of the user-lookup and login endpoints: the session's user, or 403.
#[derive(ApiResponse)]
enum UserResponse {
    // 200 with the logged-in user.
    #[oai(status = 200)]
    Ok(Json<User>),
    // 403: no user on this session, or the login attempt was rejected.
    #[oai(status = 403)]
    AuthError,
}
|
|
// Request body for account creation.
#[derive(Debug, Object, Clone, Eq, PartialEq)]
struct NewUser {
    /// Username
    #[oai(validator(max_length = 64))]
    username: String,
    // NOTE(review): the validator presumably counts characters, while bcrypt
    // only consumes a limited number of input bytes — confirm long multi-byte
    // passwords behave as intended.
    /// Password
    #[oai(validator(max_length = 50))]
    password: Password,
    // Optional; when present it must be accepted by `valid_code`, otherwise
    // signup is rejected with `InvalidReferral`.
    /// Invite / Referral Code
    #[oai(validator(max_length = 8))]
    referral: Option<String>,
}
|
|
// Response of account creation.
// NOTE(review): four variants share status 403 and carry no payload, so a
// client may be unable to tell them apart — consider distinct codes or bodies.
#[derive(ApiResponse)]
enum NewUserResponse {
    // Account created; body is the new user.
    #[oai(status = 200)]
    Ok(Json<User>),
    // An account with this username already exists.
    #[oai(status = 403)]
    UsernameTaken,
    // A referral code was supplied but is not accepted.
    #[oai(status = 403)]
    InvalidReferral,
    // Signups are closed and no valid referral code was supplied.
    #[oai(status = 403)]
    SignupClosed,
    // Hashing the supplied password failed.
    #[oai(status = 403)]
    InvalidPassword,
    // Server-side failure while storing the account.
    #[oai(status = 500)]
    InternalServerError,
}
|
|
// Request body for logging in with a username/password pair.
#[derive(Debug, Object, Clone, Eq, PartialEq)]
struct UserLogin {
    /// Username
    #[oai(validator(max_length = 64))]
    username: String,
    /// Password
    #[oai(validator(max_length = 50))]
    password: Password,
}
|
|
// Response of the per-session counter endpoint.
#[derive(ApiResponse)]
enum NumResponse {
    // 200 with the counter value after this request's increment.
    #[oai(status = 200)]
    Ok(Json<u32>),
    // 403: no user is logged in on this session.
    #[oai(status = 403)]
    AuthError,
}
|
|
// Tags referenced by the `#[oai(tag = ...)]` attributes on the API operations.
#[derive(Tags)]
enum ApiTags {
    /// Operations about user
    User,
}
|
|
/// Reports whether `code` is an accepted invite/referral code.
///
/// Exactly one hard-coded code is accepted, compared case-sensitively.
/// Declared `async` even though the comparison never awaits.
async fn valid_code(code: &str) -> bool {
    const ACCEPTED_CODE: &str = "changeme";
    code.eq(ACCEPTED_CODE)
}
|
|
/// State shared by the OpenAPI endpoints.
struct Api {
    /// SQLite connection pool used for the `users` table queries.
    db: Pool<Sqlite>,
    /// When `false`, creating an account requires a valid referral code.
    signup_open: bool, //TODO Should be in the db so it can change without restart
}
|
|
#[OpenApi]
|
|
impl Api {
|
|
#[oai(path = "/user", method = "post", tag = "ApiTags::User")]
|
|
async fn create_user(&self, user_form: Json<NewUser>, session: &Session) -> NewUserResponse {
|
|
let has_referral = match &user_form.referral {
|
|
Some(code) if valid_code(code).await => true,
|
|
Some(_) => return NewUserResponse::InvalidReferral,
|
|
None => false,
|
|
};
|
|
if !self.signup_open && !has_referral {
|
|
return NewUserResponse::SignupClosed;
|
|
}
|
|
|
|
//TODO propper handling of query errors
|
|
let username = user_form.username.clone();
|
|
let x = sqlx::query!("SELECT * FROM users WHERE username = ?", username)
|
|
.fetch_optional(&self.db)
|
|
.await;
|
|
if let Ok(Some(_current_user)) = x {
|
|
return NewUserResponse::UsernameTaken;
|
|
}
|
|
|
|
let hash = match bcrypt::hash(user_form.password.as_str(), bcrypt::DEFAULT_COST) {
|
|
Ok(hash) => hash,
|
|
_ => return NewUserResponse::InvalidPassword,
|
|
};
|
|
|
|
let res = sqlx::query!(
|
|
"INSERT INTO users(username, password, admin) VALUES(?,?,?)",
|
|
username,
|
|
hash,
|
|
false
|
|
)
|
|
.execute(&self.db)
|
|
.await;
|
|
|
|
match res {
|
|
Err(_e) => NewUserResponse::InternalServerError,
|
|
Ok(_) => {
|
|
let user = User { username };
|
|
session.set("user", user.username.clone());
|
|
NewUserResponse::Ok(Json(user))
|
|
}
|
|
}
|
|
}
|
|
#[oai(path = "/user", method = "get", tag = "ApiTags::User")]
|
|
async fn get_user(&self, session: &Session) -> UserResponse {
|
|
if let Some(username) = session.get::<String>("user") {
|
|
UserResponse::Ok(Json(User { username }))
|
|
} else {
|
|
UserResponse::AuthError
|
|
}
|
|
}
|
|
#[oai(path = "/auth", method = "delete", tag = "ApiTags::User")]
|
|
async fn deauth_user(&self, session: &Session) -> Result<()> {
|
|
session.purge();
|
|
Result::Ok(())
|
|
}
|
|
|
|
#[oai(path = "/auth", method = "post", tag = "ApiTags::User")]
|
|
async fn auth_user(&self, user: Json<UserLogin>, session: &Session) -> UserResponse {
|
|
let password = user.password.as_str();
|
|
let username = user.username.to_string();
|
|
|
|
let result = sqlx::query!("SELECT * FROM users WHERE username = ?", username)
|
|
.fetch_one(&self.db)
|
|
.await;
|
|
match result {
|
|
Ok(user) if bcrypt::verify(password, &user.password).unwrap_or(false) => {
|
|
let current_user = User {
|
|
username: username.clone(),
|
|
};
|
|
session.set("user", username);
|
|
UserResponse::Ok(Json(current_user))
|
|
}
|
|
_ => UserResponse::AuthError,
|
|
}
|
|
}
|
|
|
|
#[oai(path = "/num", method = "get", tag = "ApiTags::User")]
|
|
//async fn get_num(&self, session: &Session) -> Json<u32> {
|
|
async fn get_num(&self, session: &Session) -> NumResponse {
|
|
//if session.get("user")
|
|
if let None = session.get::<String>("user") {
|
|
return NumResponse::AuthError;
|
|
}
|
|
match session.get::<u32>("num") {
|
|
Some(i) => {
|
|
let new: u32 = i + 1;
|
|
session.set("num", new);
|
|
NumResponse::Ok(Json(new))
|
|
}
|
|
None => {
|
|
session.set("num", 1_u32);
|
|
NumResponse::Ok(Json(1))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[tokio::main]
|
|
async fn main() -> anyhow::Result<()> {
|
|
if std::env::var_os("RUST_LOG").is_none() {
|
|
std::env::set_var("RUST_LOG", "poem=debug");
|
|
}
|
|
tracing_subscriber::fmt::init();
|
|
|
|
let session = ServerSession::new(CookieConfig::default(), MemoryStorage::new());
|
|
|
|
//let dbpool = SqlitePool::connect("sqlite:chat.db").await?;
|
|
let dbpool = Pool::<Sqlite>::connect("sqlite:chat.db").await?;
|
|
let api = OpenApiService::new(
|
|
Api {
|
|
db: dbpool,
|
|
signup_open: false,
|
|
},
|
|
"Chat",
|
|
"0.0",
|
|
);
|
|
let static_endpoint = StaticFilesEndpoint::new("./client/dist").index_file("index.html");
|
|
|
|
let app = Route::new()
|
|
//.at("/api/", get(index))
|
|
//.at(
|
|
// "/api/ws/:name",
|
|
// get(ws.data(tokio::sync::broadcast::channel::<String>(32).0)),
|
|
//)
|
|
//.data(dbpool)
|
|
.nest("/", static_endpoint)
|
|
.nest("/api", api)
|
|
.with(session);
|
|
|
|
Server::new(TcpListener::bind("127.0.0.1:3000"))
|
|
.run(app)
|
|
.await?;
|
|
|
|
Ok(())
|
|
}
|