109 lines
3.0 KiB
Rust

use futures_util::{SinkExt, StreamExt};
use poem::{
get, handler,
listener::TcpListener,
web::{
websocket::{Message, WebSocket},
Data, Html, Path,
},
EndpointExt, IntoResponse, Route, Server,
};
/// Serves the single-page chat client.
///
/// The page holds a login form, a (initially hidden) send form and a
/// (initially hidden) message log. Submitting the login form opens a
/// WebSocket to `/api/ws/<name>` on `127.0.0.1:3000`; that path has to match
/// the route the WebSocket handler is mounted on in `main`, otherwise the
/// upgrade request hits no route and the chat never connects.
#[handler]
fn index() -> Html<&'static str> {
    // The markup is a raw string so the embedded quotes need no escaping.
    Html(
        r###"
<body>
<form id="loginForm">
Name: <input id="nameInput" type="text" />
<button type="submit">Login</button>
</form>
<form id="sendForm" hidden>
Text: <input id="msgInput" type="text" />
<button type="submit">Send</button>
</form>
<textarea id="msgsArea" cols="50" rows="30" hidden></textarea>
</body>
<script>
let ws;
const loginForm = document.querySelector("#loginForm");
const sendForm = document.querySelector("#sendForm");
const nameInput = document.querySelector("#nameInput");
const msgInput = document.querySelector("#msgInput");
const msgsArea = document.querySelector("#msgsArea");
nameInput.focus();
loginForm.addEventListener("submit", function(event) {
event.preventDefault();
loginForm.hidden = true;
sendForm.hidden = false;
msgsArea.hidden = false;
msgInput.focus();
ws = new WebSocket("ws://127.0.0.1:3000/api/ws/" + nameInput.value);
ws.onmessage = function(event) {
msgsArea.value += event.data + "\r\n";
}
});
sendForm.addEventListener("submit", function(event) {
event.preventDefault();
ws.send(msgInput.value);
msgInput.value = "";
});
</script>
"###,
    )
}
#[handler]
fn ws(
Path(name): Path<String>,
ws: WebSocket,
sender: Data<&tokio::sync::broadcast::Sender<String>>,
) -> impl IntoResponse {
let sender = sender.clone();
let mut receiver = sender.subscribe();
ws.on_upgrade(move |socket| async move {
let (mut sink, mut stream) = socket.split();
tokio::spawn(async move {
while let Some(Ok(msg)) = stream.next().await {
if let Message::Text(text) = msg {
if sender.send(format!("{name}: {text}")).is_err() {
break;
}
}
}
});
tokio::spawn(async move {
while let Ok(msg) = receiver.recv().await {
if sink.send(Message::Text(msg)).await.is_err() {
break;
}
}
});
})
}
/// Starts the chat server on `127.0.0.1:3000`.
///
/// Routes: `GET /api/` serves the chat page and `GET /api/ws/:name` is the
/// WebSocket endpoint, which is given the sending half of a broadcast channel
/// (capacity 32) as shared data.
///
/// Returns the I/O error produced by the server, if any.
#[tokio::main]
async fn main() -> Result<(), std::io::Error> {
    // Fall back to debug-level poem logging when the caller did not set
    // RUST_LOG; this is done before the subscriber is initialised.
    let log_filter_missing = std::env::var_os("RUST_LOG").is_none();
    if log_filter_missing {
        std::env::set_var("RUST_LOG", "poem=debug");
    }
    tracing_subscriber::fmt::init();

    // Only the sender is kept; each connection subscribes its own receiver.
    let (chat_tx, _) = tokio::sync::broadcast::channel::<String>(32);
    let chat_endpoint = get(ws.data(chat_tx));

    let routes = Route::new()
        .at("/api/", get(index))
        .at("/api/ws/:name", chat_endpoint);

    let listener = TcpListener::bind("127.0.0.1:3000");
    Server::new(listener).run(routes).await
}