Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite scheduler and make it smol #165

Merged
merged 17 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
794 changes: 592 additions & 202 deletions Cargo.lock

Large diffs are not rendered by default.

25 changes: 19 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,19 @@ urlencoding = "2.1"

### RUNTIME

blocking = "1.5"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
mlua = { version = "0.9.1", features = ["luau", "luau-jit", "serialize"] }
tokio = { version = "1.24", features = ["full", "tracing"] }
os_str_bytes = { version = "6.4", features = ["conversions"] }
os_str_bytes = { version = "7.0", features = ["conversions"] }

mlua-luau-scheduler = { version = "0.0.2" }
mlua = { version = "0.9.6", features = [
"luau",
"luau-jit",
"async",
"serialize",
] }

### SERDE

Expand All @@ -101,12 +109,17 @@ toml = { version = "0.8", features = ["preserve_order"] }

### NET

hyper = { version = "0.14", features = ["full"] }
hyper-tungstenite = { version = "0.11" }
hyper = { version = "1.1", features = ["full"] }
hyper-util = { version = "0.1", features = ["full"] }
http = "1.0"
http-body-util = { version = "0.1" }
hyper-tungstenite = { version = "0.13" }

reqwest = { version = "0.11", default-features = false, features = [
"rustls-tls",
] }
tokio-tungstenite = { version = "0.20", features = ["rustls-tls-webpki-roots"] }

tokio-tungstenite = { version = "0.21", features = ["rustls-tls-webpki-roots"] }

### DATETIME
chrono = "0.4"
Expand All @@ -115,7 +128,7 @@ chrono_lc = "0.1"
### CLI

anyhow = { optional = true, version = "1.0" }
env_logger = { optional = true, version = "0.10" }
env_logger = { optional = true, version = "0.11" }
itertools = { optional = true, version = "0.12" }
clap = { optional = true, version = "4.1", features = ["derive"] }
include_dir = { optional = true, version = "0.7", features = ["glob"] }
Expand Down
2 changes: 1 addition & 1 deletion src/lune/builtins/fs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use copy::copy;
use metadata::FsMetadata;
use options::FsWriteOptions;

pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
pub fn create(lua: &Lua) -> LuaResult<LuaTable> {
TableBuilder::new(lua)?
.with_async_function("readFile", fs_read_file)?
.with_async_function("readDir", fs_read_dir)?
Expand Down
7 changes: 2 additions & 5 deletions src/lune/builtins/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,7 @@ pub enum LuneBuiltin {
Roblox,
}

impl<'lua> LuneBuiltin
where
'lua: 'static, // FIXME: Remove static lifetime bound here when builtin libraries no longer need it
{
impl LuneBuiltin {
pub fn name(&self) -> &'static str {
match self {
Self::DateTime => "datetime",
Expand All @@ -47,7 +44,7 @@ where
}
}

pub fn create(&self, lua: &'lua Lua) -> LuaResult<LuaMultiValue<'lua>> {
pub fn create<'lua>(&self, lua: &'lua Lua) -> LuaResult<LuaMultiValue<'lua>> {
let res = match self {
Self::DateTime => datetime::create(lua),
Self::Fs => fs::create(lua),
Expand Down
111 changes: 98 additions & 13 deletions src/lune/builtins/net/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,14 @@ use std::str::FromStr;

use mlua::prelude::*;

use hyper::{header::HeaderName, http::HeaderValue, HeaderMap};
use reqwest::{IntoUrl, Method, RequestBuilder};
use reqwest::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_ENCODING};

use crate::lune::{
builtins::serde::compress_decompress::{decompress, CompressDecompressFormat},
util::TableBuilder,
};

use super::{config::RequestConfig, util::header_map_to_table};

const REGISTRY_KEY: &str = "NetClient";

Expand Down Expand Up @@ -35,33 +41,88 @@ impl NetClientBuilder {

pub fn build(self) -> LuaResult<NetClient> {
let client = self.builder.build().into_lua_err()?;
Ok(NetClient(client))
Ok(NetClient { inner: client })
}
}

#[derive(Debug, Clone)]
pub struct NetClient(reqwest::Client);
pub struct NetClient {
inner: reqwest::Client,
}

impl NetClient {
pub fn request<U: IntoUrl>(&self, method: Method, url: U) -> RequestBuilder {
self.0.request(method, url)
pub fn from_registry(lua: &Lua) -> Self {
lua.named_registry_value(REGISTRY_KEY)
.expect("Failed to get NetClient from lua registry")
}

pub fn into_registry(self, lua: &Lua) {
lua.set_named_registry_value(REGISTRY_KEY, self)
.expect("Failed to store NetClient in lua registry");
}

pub fn from_registry(lua: &Lua) -> Self {
lua.named_registry_value(REGISTRY_KEY)
.expect("Failed to get NetClient from lua registry")
pub async fn request(&self, config: RequestConfig) -> LuaResult<NetClientResponse> {
// Create and send the request
let mut request = self.inner.request(config.method, config.url);
for (query, values) in config.query {
request = request.query(
&values
.iter()
.map(|v| (query.as_str(), v))
.collect::<Vec<_>>(),
);
}
for (header, values) in config.headers {
for value in values {
request = request.header(header.as_str(), value);
}
}
let res = request
.body(config.body.unwrap_or_default())
.send()
.await
.into_lua_err()?;

// Extract status, headers
let res_status = res.status().as_u16();
let res_status_text = res.status().canonical_reason();
let res_headers = res.headers().clone();

// Read response bytes
let mut res_bytes = res.bytes().await.into_lua_err()?.to_vec();
let mut res_decompressed = false;

// Check for extra options, decompression
if config.options.decompress {
let decompress_format = res_headers
.iter()
.find(|(name, _)| {
name.as_str()
.eq_ignore_ascii_case(CONTENT_ENCODING.as_str())
})
.and_then(|(_, value)| value.to_str().ok())
.and_then(CompressDecompressFormat::detect_from_header_str);
if let Some(format) = decompress_format {
res_bytes = decompress(format, res_bytes).await?;
res_decompressed = true;
}
}

Ok(NetClientResponse {
ok: (200..300).contains(&res_status),
status_code: res_status,
status_message: res_status_text.unwrap_or_default().to_string(),
headers: res_headers,
body: res_bytes,
body_decompressed: res_decompressed,
})
}
}

impl LuaUserData for NetClient {}

impl<'lua> FromLua<'lua> for NetClient {
fn from_lua(value: LuaValue<'lua>, _: &'lua Lua) -> LuaResult<Self> {
impl FromLua<'_> for NetClient {
fn from_lua(value: LuaValue, _: &Lua) -> LuaResult<Self> {
if let LuaValue::UserData(ud) = value {
if let Ok(ctx) = ud.borrow::<NetClient>() {
return Ok(ctx.clone());
Expand All @@ -71,10 +132,34 @@ impl<'lua> FromLua<'lua> for NetClient {
}
}

impl<'lua> From<&'lua Lua> for NetClient {
fn from(value: &'lua Lua) -> Self {
impl From<&Lua> for NetClient {
fn from(value: &Lua) -> Self {
value
.named_registry_value(REGISTRY_KEY)
.expect("Missing require context in lua registry")
}
}

pub struct NetClientResponse {
ok: bool,
status_code: u16,
status_message: String,
headers: HeaderMap,
body: Vec<u8>,
body_decompressed: bool,
}

impl NetClientResponse {
pub fn into_lua_table(self, lua: &Lua) -> LuaResult<LuaTable> {
TableBuilder::new(lua)?
.with_value("ok", self.ok)?
.with_value("statusCode", self.status_code)?
.with_value("statusMessage", self.status_message)?
.with_value(
"headers",
header_map_to_table(lua, self.headers, self.body_decompressed)?,
)?
.with_value("body", lua.create_string(&self.body)?)?
.build_readonly()
}
}
Loading
Loading