forked from mlua-rs/mlua
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path async_tcp_server.rs
110 lines (92 loc) · 3.31 KB
/
async_tcp_server.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
use std::io;
use std::net::SocketAddr;
use std::rc::Rc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::task;
use mlua::{chunk, Function, Lua, RegistryKey, String as LuaString, UserData, UserDataMethods};
/// Lua-visible wrapper around a Tokio [`TcpStream`].
///
/// Methods exposed to Lua:
/// - `peer_addr()` -> remote address formatted as a string
/// - `read(size)` (async) -> up to `size` bytes; an empty string means EOF
/// - `write(data)` (async) -> writes every byte of `data`, returns the byte count
/// - `close()` (async) -> shuts down the write half of the stream
///
/// The async methods are driven in this file through `Function::call_async`.
struct LuaTcpStream(TcpStream);

impl UserData for LuaTcpStream {
    fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
        methods.add_method("peer_addr", |_, this, ()| {
            Ok(this.0.peer_addr()?.to_string())
        });

        // Reads at most `size` bytes. A zero-length result (empty Lua string)
        // means the peer closed its side of the connection (EOF).
        methods.add_async_method_mut("read", |lua, this, size: usize| async move {
            let mut buf = vec![0; size];
            let n = this.0.read(&mut buf).await?;
            buf.truncate(n);
            lua.create_string(&buf)
        });

        // A single `write` call may accept only part of the buffer. `write_all`
        // keeps writing until every byte is sent, so replies are never silently
        // truncated. Returns the number of bytes written, which is all of them.
        methods.add_async_method_mut("write", |_, this, data: LuaString| async move {
            let bytes = data.as_bytes();
            this.0.write_all(bytes).await?;
            Ok(bytes.len())
        });

        // Shuts down the write half of the stream. The underlying socket is
        // presumably released only when the userdata is collected — verify.
        methods.add_async_method_mut("close", |_, this, ()| async move {
            this.0.shutdown().await?;
            Ok(())
        });
    }
}
async fn run_server(lua: Lua, handler: RegistryKey) -> io::Result<()> {
let addr: SocketAddr = ([127, 0, 0, 1], 3000).into();
let listener = TcpListener::bind(addr).await.expect("cannot bind addr");
println!("Listening on {}", addr);
let lua = Rc::new(lua);
let handler = Rc::new(handler);
loop {
let (stream, _) = match listener.accept().await {
Ok(res) => res,
Err(err) if is_transient_error(&err) => continue,
Err(err) => return Err(err),
};
let lua = lua.clone();
let handler = handler.clone();
task::spawn_local(async move {
let handler: Function = lua
.registry_value(&handler)
.expect("cannot get Lua handler");
let stream = LuaTcpStream(stream);
if let Err(err) = handler.call_async::<_, ()>(stream).await {
eprintln!("{}", err);
}
});
}
}
/// Entry point: builds the Lua connection handler and runs the echo server.
///
/// The runtime is single-threaded, and the server runs inside a
/// [`task::LocalSet`] because `run_server` spawns its tasks with `spawn_local`.
#[tokio::main(flavor = "current_thread")]
async fn main() {
    let lua = Lua::new();

    // Lua handler invoked once per accepted connection. It echoes each trimmed
    // message back until the client sends "bye" or closes the connection.
    let handler_fn = lua
        .load(chunk! {
            function(stream)
                local peer_addr = stream:peer_addr()
                print("connected from "..peer_addr)

                while true do
                    local data = stream:read(100)
                    // An empty read means the peer closed the connection (EOF).
                    // Without this check the loop would keep reading "" and
                    // writing replies instead of ending the session.
                    if data == "" then
                        print("["..peer_addr.."] disconnected")
                        return
                    end

                    data = data:match("^%s*(.-)%s*$") // trim
                    print("["..peer_addr.."] "..data)
                    if data == "bye" then
                        stream:write("bye bye\n")
                        stream:close()
                        return
                    end
                    stream:write("echo: "..data.."\n")
                end
            end
        })
        .eval::<Function>()
        .expect("cannot create Lua handler");

    // Store the handler in the registry so connection tasks can fetch it
    // from the shared Lua state.
    let handler = lua
        .create_registry_value(handler_fn)
        .expect("cannot store Lua handler");

    task::LocalSet::new()
        .run_until(run_server(lua, handler))
        .await
        .expect("cannot run server")
}
/// Reports whether an `accept` error should be skipped rather than ending the
/// server: `true` only for refused, aborted, or reset connections.
fn is_transient_error(e: &io::Error) -> bool {
    matches!(
        e.kind(),
        io::ErrorKind::ConnectionRefused
            | io::ErrorKind::ConnectionAborted
            | io::ErrorKind::ConnectionReset
    )
}