Add chat websocket
This commit is contained in:
parent
62621e7fe6
commit
748b7968cd
9
Cargo.lock
generated
9
Cargo.lock
generated
@ -509,6 +509,7 @@ dependencies = [
|
|||||||
"poem-openapi",
|
"poem-openapi",
|
||||||
"serde",
|
"serde",
|
||||||
"sqlx",
|
"sqlx",
|
||||||
|
"thiserror",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tracing-subscriber",
|
"tracing-subscriber",
|
||||||
]
|
]
|
||||||
@ -2019,18 +2020,18 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "thiserror"
|
name = "thiserror"
|
||||||
version = "1.0.49"
|
version = "1.0.50"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "1177e8c6d7ede7afde3585fd2513e611227efd6481bd78d2e82ba1ce16557ed4"
|
checksum = "f9a7210f5c9a7156bb50aa36aed4c95afb51df0df00713949448cf9e97d382d2"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"thiserror-impl",
|
"thiserror-impl",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "thiserror-impl"
|
name = "thiserror-impl"
|
||||||
version = "1.0.49"
|
version = "1.0.50"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "10712f02019e9288794769fba95cd6847df9874d49d871d062172f9dd41bc4cc"
|
checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
|
|||||||
@ -10,10 +10,8 @@ futures-util = "0.3.17"
|
|||||||
tracing-subscriber = { version ="0.3.9", features = ["env-filter"] }
|
tracing-subscriber = { version ="0.3.9", features = ["env-filter"] }
|
||||||
sqlx = { version = "0.7.2", features = ["runtime-tokio", "sqlite"] }
|
sqlx = { version = "0.7.2", features = ["runtime-tokio", "sqlite"] }
|
||||||
bcrypt = "0.15.0"
|
bcrypt = "0.15.0"
|
||||||
anyhow = "1.0.75"
|
|
||||||
poem-openapi = { version = "3.0.5", features = ["websocket"] }
|
poem-openapi = { version = "3.0.5", features = ["websocket"] }
|
||||||
serde = "1.0.190"
|
serde = "1.0.190"
|
||||||
|
thiserror = "1.0.50"
|
||||||
[profile.dev.package.sqlx-macros]
|
anyhow = "1.0.75"
|
||||||
opt-level = 3
|
|
||||||
|
|
||||||
|
|||||||
135
src/main.rs
135
src/main.rs
@ -1,27 +1,26 @@
|
|||||||
//use anyhow::Result;
|
|
||||||
//use bcrypt::{hash, verify, DEFAULT_COST};
|
|
||||||
use futures_util::{SinkExt, StreamExt};
|
use futures_util::{SinkExt, StreamExt};
|
||||||
use poem::{
|
use poem::{
|
||||||
endpoint::StaticFilesEndpoint,
|
endpoint::StaticFilesEndpoint,
|
||||||
//get,
|
error::ResponseError,
|
||||||
handler,
|
get, handler,
|
||||||
|
http::StatusCode,
|
||||||
listener::TcpListener,
|
listener::TcpListener,
|
||||||
session::{CookieConfig, MemoryStorage, ServerSession, Session},
|
session::{CookieConfig, MemoryStorage, ServerSession, Session},
|
||||||
web::{
|
web::{
|
||||||
websocket::{Message, WebSocket},
|
websocket::{BoxWebSocketUpgraded, Message, WebSocket},
|
||||||
Data, Html, Path,
|
Html,
|
||||||
},
|
},
|
||||||
EndpointExt,
|
EndpointExt, Result, Route, Server,
|
||||||
IntoResponse,
|
|
||||||
Result,
|
|
||||||
Route,
|
|
||||||
Server,
|
|
||||||
};
|
};
|
||||||
use poem_openapi::{
|
use poem_openapi::{
|
||||||
payload::Json, types::Password, ApiResponse, Object, OpenApi, OpenApiService, Tags,
|
payload::Json,
|
||||||
|
registry::{MetaResponse, MetaResponses, Registry},
|
||||||
|
types::Password,
|
||||||
|
ApiResponse, Object, OpenApi, OpenApiService, Tags,
|
||||||
};
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use sqlx::{Pool, Sqlite};
|
use sqlx::{Pool, Sqlite};
|
||||||
|
use tokio::sync::broadcast::Sender;
|
||||||
|
|
||||||
#[handler]
|
#[handler]
|
||||||
fn index() -> Html<&'static str> {
|
fn index() -> Html<&'static str> {
|
||||||
@ -56,7 +55,7 @@ fn index() -> Html<&'static str> {
|
|||||||
sendForm.hidden = false;
|
sendForm.hidden = false;
|
||||||
msgsArea.hidden = false;
|
msgsArea.hidden = false;
|
||||||
msgInput.focus();
|
msgInput.focus();
|
||||||
ws = new WebSocket("ws://127.0.0.1:2000/api/ws/" + nameInput.value);
|
ws = new WebSocket("ws://localhost:2000/api/ws");
|
||||||
ws.onmessage = function(event) {
|
ws.onmessage = function(event) {
|
||||||
msgsArea.value += event.data + "\r\n";
|
msgsArea.value += event.data + "\r\n";
|
||||||
}
|
}
|
||||||
@ -73,37 +72,6 @@ fn index() -> Html<&'static str> {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[handler]
|
|
||||||
fn ws(
|
|
||||||
Path(name): Path<String>,
|
|
||||||
ws: WebSocket,
|
|
||||||
sender: Data<&tokio::sync::broadcast::Sender<String>>,
|
|
||||||
) -> impl IntoResponse {
|
|
||||||
let sender = sender.clone();
|
|
||||||
let mut receiver = sender.subscribe();
|
|
||||||
ws.on_upgrade(move |socket| async move {
|
|
||||||
let (mut sink, mut stream) = socket.split();
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
|
||||||
while let Some(Ok(msg)) = stream.next().await {
|
|
||||||
if let Message::Text(text) = msg {
|
|
||||||
if sender.send(format!("{name}: {text}")).is_err() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
|
||||||
while let Ok(msg) = receiver.recv().await {
|
|
||||||
if sink.send(Message::Text(msg)).await.is_err() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Object, Clone, Eq, PartialEq, Serialize, Deserialize)]
|
#[derive(Debug, Object, Clone, Eq, PartialEq, Serialize, Deserialize)]
|
||||||
struct User {
|
struct User {
|
||||||
#[oai(validator(max_length = 64))]
|
#[oai(validator(max_length = 64))]
|
||||||
@ -159,6 +127,7 @@ enum NumResponse {
|
|||||||
#[oai(status = 403)]
|
#[oai(status = 403)]
|
||||||
AuthError,
|
AuthError,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Tags)]
|
#[derive(Tags)]
|
||||||
enum ApiTags {
|
enum ApiTags {
|
||||||
/// Operations about user
|
/// Operations about user
|
||||||
@ -167,12 +136,68 @@ enum ApiTags {
|
|||||||
async fn valid_code(code: &str) -> bool {
|
async fn valid_code(code: &str) -> bool {
|
||||||
"changeme" == code
|
"changeme" == code
|
||||||
}
|
}
|
||||||
struct Api {
|
#[derive(Debug, thiserror::Error)]
|
||||||
db: Pool<Sqlite>,
|
#[error("API Error")]
|
||||||
signup_open: bool, //TODO Should be in the db so it can change without restart
|
struct ApiError();
|
||||||
|
|
||||||
|
impl ResponseError for ApiError {
|
||||||
|
fn status(&self) -> StatusCode {
|
||||||
|
StatusCode::FORBIDDEN
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
impl ApiResponse for ApiError {
|
||||||
|
fn meta() -> MetaResponses {
|
||||||
|
MetaResponses {
|
||||||
|
responses: vec![MetaResponse {
|
||||||
|
description: "An Error response",
|
||||||
|
status: Some(403),
|
||||||
|
content: vec![],
|
||||||
|
headers: vec![],
|
||||||
|
}],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn register(_registry: &mut Registry) {}
|
||||||
|
}
|
||||||
|
|
||||||
#[OpenApi]
|
#[OpenApi]
|
||||||
impl Api {
|
impl Api {
|
||||||
|
#[oai(path = "/ws", method = "get")]
|
||||||
|
async fn ws(
|
||||||
|
&self,
|
||||||
|
websock: WebSocket,
|
||||||
|
session: &Session,
|
||||||
|
) -> Result<BoxWebSocketUpgraded, ApiError> {
|
||||||
|
let name = session.get::<String>("user").ok_or(ApiError {})?;
|
||||||
|
|
||||||
|
let sender = self.channel.clone();
|
||||||
|
let mut receiver = sender.subscribe();
|
||||||
|
let x = websock
|
||||||
|
.on_upgrade(move |socket| async move {
|
||||||
|
let (mut sink, mut stream) = socket.split();
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
while let Some(Ok(msg)) = stream.next().await {
|
||||||
|
if let Message::Text(text) = msg {
|
||||||
|
if sender.send(format!("{name}: {text}")).is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
while let Ok(msg) = receiver.recv().await {
|
||||||
|
if sink.send(Message::Text(msg)).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
})
|
||||||
|
.boxed();
|
||||||
|
|
||||||
|
Result::Ok(x)
|
||||||
|
}
|
||||||
|
|
||||||
#[oai(path = "/user", method = "post", tag = "ApiTags::User")]
|
#[oai(path = "/user", method = "post", tag = "ApiTags::User")]
|
||||||
async fn create_user(&self, user_form: Json<NewUser>, session: &Session) -> NewUserResponse {
|
async fn create_user(&self, user_form: Json<NewUser>, session: &Session) -> NewUserResponse {
|
||||||
let has_referral = match &user_form.referral {
|
let has_referral = match &user_form.referral {
|
||||||
@ -271,6 +296,12 @@ impl Api {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct Api {
|
||||||
|
db: Pool<Sqlite>,
|
||||||
|
channel: Sender<String>,
|
||||||
|
signup_open: bool, //TODO Should be in the db so it can change without restart
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> anyhow::Result<()> {
|
async fn main() -> anyhow::Result<()> {
|
||||||
if std::env::var_os("RUST_LOG").is_none() {
|
if std::env::var_os("RUST_LOG").is_none() {
|
||||||
@ -280,11 +311,14 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
let session = ServerSession::new(CookieConfig::default(), MemoryStorage::new());
|
let session = ServerSession::new(CookieConfig::default(), MemoryStorage::new());
|
||||||
|
|
||||||
|
let channel = tokio::sync::broadcast::channel::<String>(128).0;
|
||||||
|
|
||||||
//let dbpool = SqlitePool::connect("sqlite:chat.db").await?;
|
//let dbpool = SqlitePool::connect("sqlite:chat.db").await?;
|
||||||
let dbpool = Pool::<Sqlite>::connect("sqlite:chat.db").await?;
|
let dbpool = Pool::<Sqlite>::connect("sqlite:chat.db").await?;
|
||||||
let api = OpenApiService::new(
|
let api = OpenApiService::new(
|
||||||
Api {
|
Api {
|
||||||
db: dbpool,
|
db: dbpool,
|
||||||
|
channel,
|
||||||
signup_open: false,
|
signup_open: false,
|
||||||
},
|
},
|
||||||
"Chat",
|
"Chat",
|
||||||
@ -293,12 +327,7 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
let static_endpoint = StaticFilesEndpoint::new("./client/dist").index_file("index.html");
|
let static_endpoint = StaticFilesEndpoint::new("./client/dist").index_file("index.html");
|
||||||
|
|
||||||
let app = Route::new()
|
let app = Route::new()
|
||||||
//.at("/api/", get(index))
|
.at("/api/test", get(index))
|
||||||
//.at(
|
|
||||||
// "/api/ws/:name",
|
|
||||||
// get(ws.data(tokio::sync::broadcast::channel::<String>(32).0)),
|
|
||||||
//)
|
|
||||||
//.data(dbpool)
|
|
||||||
.nest("/", static_endpoint)
|
.nest("/", static_endpoint)
|
||||||
.nest("/api", api)
|
.nest("/api", api)
|
||||||
.with(session);
|
.with(session);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user