From ddf88c64bdaa8b0c7a00d6773f95411d28567a90 Mon Sep 17 00:00:00 2001 From: Sasha Szpakowski Date: Wed, 30 Aug 2023 21:47:26 -0300 Subject: [PATCH] Update lua-https to latest commit (love2d/lua-https@6dbce69) --- CMakeLists.txt | 11 +- .../luahttps/src/common/HTTPRequest.cpp | 3 - .../luahttps/src/common/HTTPRequest.h | 18 +- src/libraries/luahttps/src/common/HTTPS.cpp | 10 + .../luahttps/src/common/HTTPSClient.cpp | 4 +- src/libraries/luahttps/src/common/config.h | 5 + .../luahttps/src/generic/CurlClient.cpp | 148 +++++++++--- .../luahttps/src/generic/CurlClient.h | 18 +- src/libraries/luahttps/src/lua/main.cpp | 1 + .../luahttps/src/windows/WinINetClient.cpp | 224 ++++++++++++++++++ .../luahttps/src/windows/WinINetClient.h | 16 ++ 11 files changed, 405 insertions(+), 53 deletions(-) create mode 100644 src/libraries/luahttps/src/windows/WinINetClient.cpp create mode 100644 src/libraries/luahttps/src/windows/WinINetClient.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 96ff2be8f..7e5bc4fcd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1619,6 +1619,8 @@ set(LOVE_SRC_3P_LUAHTTPS_LUA set(LOVE_SRC_3P_LUAHTTPS_WINDOWS src/libraries/luahttps/src/windows/SChannelConnection.cpp src/libraries/luahttps/src/windows/SChannelConnection.h + src/libraries/luahttps/src/windows/WinINetClient.cpp + src/libraries/luahttps/src/windows/WinINetClient.h ) # These are platform-dependent but have ifdef guards to make sure they only @@ -1638,10 +1640,17 @@ endif() set(LOVE_LINK_L3P_LUAHTTPS) if(MSVC) set(LOVE_LINK_L3P_LUAHTTPS - ${LOVE_LINK_L3P_LUASOCKET_LIBLUASOCKET} + ${LOVE_LINK_L3P_LUAHTTPS} ws2_32 secur32 ) + + if(NOT CMAKE_SYSTEM_NAME STREQUAL "WindowsStore") + set(LOVE_LINK_L3P_LUAHTTPS + ${LOVE_LINK_L3P_LUAHTTPS} + wininet + ) + endif() endif() add_library(love_3p_luahttps ${LOVE_SRC_3P_LUAHTTPS}) diff --git a/src/libraries/luahttps/src/common/HTTPRequest.cpp b/src/libraries/luahttps/src/common/HTTPRequest.cpp index a8c3fbab1..5d86fd889 100644 --- a/src/libraries/luahttps/src/common/HTTPRequest.cpp +++ b/src/libraries/luahttps/src/common/HTTPRequest.cpp @@ -50,9 +50,6 @@ HTTPSClient::Reply HTTPRequest::request(const HTTPSClient::Request &req) request << "Host: " << info.hostname << "\r\n"; - if (hasData && req.headers.count("Content-Type") == 0) - request << "Content-Type: application/x-www-form-urlencoded\r\n"; - if (hasData) request << "Content-Length: " << req.postdata.size() << "\r\n"; diff --git a/src/libraries/luahttps/src/common/HTTPRequest.h b/src/libraries/luahttps/src/common/HTTPRequest.h index 14722d654..fe21c12ce 100644 --- a/src/libraries/luahttps/src/common/HTTPRequest.h +++ b/src/libraries/luahttps/src/common/HTTPRequest.h @@ -8,14 +8,6 @@ class HTTPRequest { public: - typedef std::function ConnectionFactory; - HTTPRequest(ConnectionFactory factory); - - HTTPSClient::Reply request(const HTTPSClient::Request &req); - -private: - ConnectionFactory factory; - struct DissectedURL { bool valid; @@ -25,6 +17,14 @@ class HTTPRequest std::string query; // TODO: Auth? }; + typedef std::function ConnectionFactory; + + HTTPRequest(ConnectionFactory factory); + + HTTPSClient::Reply request(const HTTPSClient::Request &req); + + static DissectedURL parseUrl(const std::string &url); - DissectedURL parseUrl(const std::string &url); +private: + ConnectionFactory factory; }; diff --git a/src/libraries/luahttps/src/common/HTTPS.cpp b/src/libraries/luahttps/src/common/HTTPS.cpp index d4ca45452..2667f2a4d 100644 --- a/src/libraries/luahttps/src/common/HTTPS.cpp +++ b/src/libraries/luahttps/src/common/HTTPS.cpp @@ -19,6 +19,9 @@ #ifdef HTTPS_BACKEND_ANDROID # include "../android/AndroidClient.h" #endif +#ifdef HTTPS_BACKEND_WININET +# include "../windows/WinINetClient.h" +#endif #ifdef HTTPS_BACKEND_CURL static CurlClient curlclient; @@ -35,6 +38,9 @@ #ifdef HTTPS_BACKEND_ANDROID static AndroidClient androidclient; #endif +#ifdef HTTPS_BACKEND_WININET + static WinINetClient wininetclient; +#endif static HTTPSClient *clients[] = { #ifdef HTTPS_BACKEND_CURL @@ -42,6 +48,10 @@ static HTTPSClient *clients[] = { #endif #ifdef HTTPS_BACKEND_OPENSSL &opensslclient, +#endif + // WinINet must be above SChannel +#ifdef HTTPS_BACKEND_WININET + &wininetclient, #endif #ifdef HTTPS_BACKEND_SCHANNEL &schannelclient, diff --git a/src/libraries/luahttps/src/common/HTTPSClient.cpp b/src/libraries/luahttps/src/common/HTTPSClient.cpp index 6e32ea533..25876ef28 100644 --- a/src/libraries/luahttps/src/common/HTTPSClient.cpp +++ b/src/libraries/luahttps/src/common/HTTPSClient.cpp @@ -30,8 +30,8 @@ bool HTTPSClient::ci_string_less::operator()(const std::string &lhs, const std:: } HTTPSClient::Request::Request(const std::string &url) - : url(url) - , method("") +: url(url) +, method("GET") { } diff --git a/src/libraries/luahttps/src/common/config.h b/src/libraries/luahttps/src/common/config.h index 2889ba257..e55c3fe3f 100644 --- a/src/libraries/luahttps/src/common/config.h +++ b/src/libraries/luahttps/src/common/config.h @@ -5,6 +5,11 @@ #elif defined(WIN32) || defined(_WIN32) #define HTTPS_BACKEND_SCHANNEL #define HTTPS_USE_WINSOCK + #include + #if !defined(WINAPI_FAMILY) || (WINAPI_FAMILY == WINAPI_FAMILY_DESKTOP_APP) + // WinINet is only supported on desktop. + #define HTTPS_BACKEND_WININET + #endif #elif defined(__ANDROID__) #define HTTPS_BACKEND_ANDROID #elif defined(__APPLE__) diff --git a/src/libraries/luahttps/src/generic/CurlClient.cpp b/src/libraries/luahttps/src/generic/CurlClient.cpp index 451db4b58..ac8468912 100644 --- a/src/libraries/luahttps/src/generic/CurlClient.cpp +++ b/src/libraries/luahttps/src/generic/CurlClient.cpp @@ -1,49 +1,129 @@ +#ifdef _WIN32 +#define NOMINMAX +#define WIN32_LEAN_AND_MEAN +#endif + #include "CurlClient.h" #ifdef HTTPS_BACKEND_CURL -#include +#include #include #include #include +// Dynamic library loader +#ifdef _WIN32 +#include +#else +#include +#endif + +typedef struct StringReader +{ + const std::string *str; + size_t pos; +} StringReader; + +template +static inline bool loadSymbol(T &var, void *handle, const char *name) +{ +#ifdef _WIN32 + var = (T) GetProcAddress((HMODULE) handle, name); +#else + var = (T) dlsym(handle, name); +#endif + return var != nullptr; +} + CurlClient::Curl::Curl() +: handle(nullptr) +, loaded(false) +, global_cleanup(nullptr) +, easy_init(nullptr) +, easy_cleanup(nullptr) +, easy_setopt(nullptr) +, easy_perform(nullptr) +, easy_getinfo(nullptr) +, slist_append(nullptr) +, slist_free_all(nullptr) { - void *handle = dlopen("libcurl.so", RTLD_LAZY); +#ifdef _WIN32 + handle = (void *) LoadLibraryA("libcurl.dll"); +#else + handle = dlopen("libcurl.so.4", RTLD_LAZY); +#endif if (!handle) - { - loaded = false; return; - } - void (*global_init)() = (void(*)()) dlsym(handle, "curl_global_init"); - easy_init = (CURL*(*)()) dlsym(handle, "curl_easy_init"); - easy_cleanup = (void(*)(CURL*)) dlsym(handle, "curl_easy_cleanup"); - easy_setopt = (CURLcode(*)(CURL*,CURLoption,...)) dlsym(handle, "curl_easy_setopt"); - easy_perform = (CURLcode(*)(CURL*)) dlsym(handle, "curl_easy_perform"); - easy_getinfo = (CURLcode(*)(CURL*,CURLINFO,...)) dlsym(handle, "curl_easy_getinfo"); - slist_append = (curl_slist*(*)(curl_slist*,const char*)) dlsym(handle, "curl_slist_append"); - slist_free_all = (void(*)(curl_slist*)) dlsym(handle, "curl_slist_free_all"); + // Load symbols + decltype(&curl_global_init) global_init = nullptr; + if (!loadSymbol(global_init, handle, "curl_global_init")) + return; + if (!loadSymbol(global_cleanup, handle, "curl_global_cleanup")) + return; + if (!loadSymbol(easy_init, handle, "curl_easy_init")) + return; + if (!loadSymbol(easy_cleanup, handle, "curl_easy_cleanup")) + return; + if (!loadSymbol(easy_setopt, handle, "curl_easy_setopt")) + return; + if (!loadSymbol(easy_perform, handle, "curl_easy_perform")) + return; + if (!loadSymbol(easy_getinfo, handle, "curl_easy_getinfo")) + return; + if (!loadSymbol(slist_append, handle, "curl_slist_append")) + return; + if (!loadSymbol(slist_free_all, handle, "curl_slist_free_all")) + return; + + global_init(CURL_GLOBAL_DEFAULT); + loaded = true; +} + +CurlClient::Curl::~Curl() +{ + if (loaded) + global_cleanup(); + + if (handle) +#ifdef _WIN32 + FreeLibrary((HMODULE) handle); +#else + dlclose(handle); +#endif +} + +static char toUppercase(char c) +{ + int ch = (unsigned char) c; + return toupper(ch); +} - loaded = (global_init && easy_init && easy_cleanup && easy_setopt && easy_perform && easy_getinfo && slist_append && slist_free_all); +static size_t stringReader(char *ptr, size_t size, size_t nmemb, StringReader *reader) +{ + const char *data = reader->str->data(); + size_t len = reader->str->length(); + size_t maxCount = (len - reader->pos) / size; + size_t desiredCount = std::min(maxCount, nmemb); + size_t desiredBytes = desiredCount * size; - if (!loaded) - return; + std::copy(data + reader->pos, data + desiredBytes, ptr); + reader->pos += desiredBytes; - global_init(); + return desiredCount; } -static size_t stringstreamWriter(char *ptr, size_t size, size_t nmemb, void *userdata) +static size_t stringstreamWriter(char *ptr, size_t size, size_t nmemb, std::stringstream *ss) { - std::stringstream *ss = (std::stringstream*) userdata; size_t count = size*nmemb; ss->write(ptr, count); return count; } -static size_t headerWriter(char *ptr, size_t size, size_t nmemb, void *userdata) +static size_t headerWriter(char *ptr, size_t size, size_t nmemb, std::map *userdata) { - std::map &headers = *((std::map*) userdata); + std::map &headers = *userdata; size_t count = size*nmemb; std::string line(ptr, count); size_t split = line.find(':'); @@ -64,7 +144,10 @@ bool CurlClient::valid() const HTTPSClient::Reply CurlClient::request(const HTTPSClient::Request &req) { Reply reply; - reply.responseCode = 400; + reply.responseCode = 0; + + // Use sensible default header for later + HTTPSClient::header_map newHeaders = req.headers; CURL *handle = curl.easy_init(); if (!handle) @@ -72,23 +155,26 @@ HTTPSClient::Reply CurlClient::request(const HTTPSClient::Request &req) curl.easy_setopt(handle, CURLOPT_URL, req.url.c_str()); curl.easy_setopt(handle, CURLOPT_FOLLOWLOCATION, 1L); + curl.easy_setopt(handle, CURLOPT_CUSTOMREQUEST, req.method.c_str()); - if (req.method == "PUT") - curl.easy_setopt(handle, CURLOPT_PUT, 1L); - else if (req.method == "POST") - curl.easy_setopt(handle, CURLOPT_POST, 1L); - else - curl.easy_setopt(handle, CURLOPT_CUSTOMREQUEST, req.method.c_str()); + StringReader reader {}; if (req.postdata.size() > 0 && (req.method != "GET" && req.method != "HEAD")) { - curl.easy_setopt(handle, CURLOPT_POSTFIELDS, req.postdata.c_str()); - curl.easy_setopt(handle, CURLOPT_POSTFIELDSIZE, req.postdata.size()); + reader.str = &req.postdata; + reader.pos = 0; + curl.easy_setopt(handle, CURLOPT_UPLOAD, 1L); + curl.easy_setopt(handle, CURLOPT_READFUNCTION, stringReader); + curl.easy_setopt(handle, CURLOPT_READDATA, &reader); + curl.easy_setopt(handle, CURLOPT_INFILESIZE_LARGE, (curl_off_t) req.postdata.length()); } + if (req.method == "HEAD") + curl.easy_setopt(handle, CURLOPT_NOBODY, 1L); + // Curl doesn't copy memory, keep the strings around std::vector lines; - for (auto &header : req.headers) + for (auto &header : newHeaders) { std::stringstream line; line << header.first << ": " << header.second; diff --git a/src/libraries/luahttps/src/generic/CurlClient.h b/src/libraries/luahttps/src/generic/CurlClient.h index 9274fdfc5..30d4a9bcb 100644 --- a/src/libraries/luahttps/src/generic/CurlClient.h +++ b/src/libraries/luahttps/src/generic/CurlClient.h @@ -18,16 +18,20 @@ class CurlClient : public HTTPSClient static struct Curl { Curl(); + ~Curl(); + void *handle; bool loaded; - CURL *(*easy_init)(); - void (*easy_cleanup)(CURL *handle); - CURLcode (*easy_setopt)(CURL *handle, CURLoption option, ...); - CURLcode (*easy_perform)(CURL *easy_handle); - CURLcode (*easy_getinfo)(CURL *curl, CURLINFO info, ...); + decltype(&curl_global_cleanup) global_cleanup; - curl_slist *(*slist_append)(curl_slist *list, const char *string); - void (*slist_free_all)(curl_slist *list); + decltype(&curl_easy_init) easy_init; + decltype(&curl_easy_cleanup) easy_cleanup; + decltype(&curl_easy_setopt) easy_setopt; + decltype(&curl_easy_perform) easy_perform; + decltype(&curl_easy_getinfo) easy_getinfo; + + decltype(&curl_slist_append) slist_append; + decltype(&curl_slist_free_all) slist_free_all; } curl; }; diff --git a/src/libraries/luahttps/src/lua/main.cpp b/src/libraries/luahttps/src/lua/main.cpp index a2d843dbd..126772676 100644 --- a/src/libraries/luahttps/src/lua/main.cpp +++ b/src/libraries/luahttps/src/lua/main.cpp @@ -78,6 +78,7 @@ static int w_request(lua_State *L) if (!lua_isnoneornil(L, -1)) { req.postdata = w_checkstring(L, -1); + req.headers["Content-Type"] = "application/x-www-form-urlencoded"; defaultMethod = "POST"; } lua_pop(L, 1); diff --git a/src/libraries/luahttps/src/windows/WinINetClient.cpp b/src/libraries/luahttps/src/windows/WinINetClient.cpp new file mode 100644 index 000000000..5cf787641 --- /dev/null +++ b/src/libraries/luahttps/src/windows/WinINetClient.cpp @@ -0,0 +1,224 @@ +#include "WinINetClient.h" + +#ifdef HTTPS_BACKEND_WININET + +#include +#include +#include +#include + +#include +#include + +#include "../common/HTTPRequest.h" + +class LazyHInternetLoader final +{ +public: + LazyHInternetLoader(): hInternet(nullptr) { } + ~LazyHInternetLoader() + { + if (hInternet) + InternetCloseHandle(hInternet); + } + + HINTERNET getInstance() + { + if (!init) + { + hInternet = InternetOpenA("", INTERNET_OPEN_TYPE_PRECONFIG, nullptr, nullptr, 0); + if (hInternet) + { + // Try to enable HTTP2 + DWORD httpProtocol = HTTP_PROTOCOL_FLAG_HTTP2; + InternetSetOptionA(hInternet, INTERNET_OPTION_ENABLE_HTTP_PROTOCOL, &httpProtocol, sizeof(DWORD)); + SetLastError(0); // If it errors, ignore. + } + } + + return hInternet; + } + +private: + bool init; + HINTERNET hInternet; +}; + +static thread_local LazyHInternetLoader hInternetCache; + +bool WinINetClient::valid() const +{ + // Allow disablement of WinINet backend. + const char *disabler = getenv("LUAHTTPS_DISABLE_WININET"); + if (disabler && strcmp(disabler, "1") == 0) + return false; + + return hInternetCache.getInstance() != nullptr; +} + +HTTPSClient::Reply WinINetClient::request(const HTTPSClient::Request &req) +{ + Reply reply; + reply.responseCode = 0; + + // Parse URL + auto parsedUrl = HTTPRequest::parseUrl(req.url); + + // Default flags + DWORD inetFlags = + INTERNET_FLAG_NO_AUTH | + INTERNET_FLAG_NO_CACHE_WRITE | + INTERNET_FLAG_NO_COOKIES | + INTERNET_FLAG_NO_UI; + + if (parsedUrl.schema == "https") + inetFlags |= INTERNET_FLAG_SECURE; + else if (parsedUrl.schema != "http") + return reply; + + // Keep-Alive + auto connectHeader = req.headers.find("Connection"); + auto headerEnd = req.headers.end(); + if ((connectHeader != headerEnd && connectHeader->second != "close") || connectHeader == headerEnd) + inetFlags |= INTERNET_FLAG_KEEP_CONNECTION; + + // Open internet + HINTERNET hInternet = hInternetCache.getInstance(); + if (hInternet == nullptr) + return reply; + + // Connect + HINTERNET hConnect = InternetConnectA( + hInternet, + parsedUrl.hostname.c_str(), + parsedUrl.port, + nullptr, nullptr, + INTERNET_SERVICE_HTTP, + INTERNET_FLAG_EXISTING_CONNECT, + (DWORD_PTR) this + ); + if (!hConnect) + return reply; + + std::string httpMethod = req.method; + std::transform( + httpMethod.begin(), + httpMethod.end(), + httpMethod.begin(), + [](char c) {return (char)toupper((unsigned char) c); } + ); + + // Open HTTP request + HINTERNET hHTTP = HttpOpenRequestA( + hConnect, + httpMethod.c_str(), + parsedUrl.query.c_str(), + nullptr, + nullptr, + nullptr, + inetFlags, + (DWORD_PTR) this + ); + if (!hHTTP) + { + InternetCloseHandle(hConnect); + return reply; + } + + // Send additional headers + HttpAddRequestHeadersA(hHTTP, "User-Agent:", 0, HTTP_ADDREQ_FLAG_REPLACE); + for (const auto &header: req.headers) + { + std::string headerString = header.first + ": " + header.second + "\r\n"; + HttpAddRequestHeadersA(hHTTP, headerString.c_str(), headerString.length(), HTTP_ADDREQ_FLAG_ADD | HTTP_ADDREQ_FLAG_REPLACE); + } + + // POST data + const char *postData = nullptr; + if (req.postdata.length() > 0 && (httpMethod != "GET" && httpMethod != "HEAD")) + { + char temp[48]; + int len = sprintf(temp, "Content-Length: %u\r\n", (unsigned int) req.postdata.length()); + postData = req.postdata.c_str(); + + HttpAddRequestHeadersA(hHTTP, temp, len, HTTP_ADDREQ_FLAG_ADD | HTTP_ADDREQ_FLAG_REPLACE); + } + + // Send away! + BOOL result = HttpSendRequestA(hHTTP, nullptr, 0, (void *) postData, (DWORD) req.postdata.length()); + if (!result) + { + InternetCloseHandle(hHTTP); + InternetCloseHandle(hConnect); + return reply; + } + + DWORD bufferLength = sizeof(DWORD); + DWORD headerCounter = 0; + + // Status code + DWORD statusCode = 0; + if (!HttpQueryInfoA(hHTTP, HTTP_QUERY_STATUS_CODE | HTTP_QUERY_FLAG_NUMBER, &statusCode, &bufferLength, &headerCounter)) + { + InternetCloseHandle(hHTTP); + InternetCloseHandle(hConnect); + return reply; + } + + // Query headers + std::vector responseHeaders; + bufferLength = 0; + HttpQueryInfoA(hHTTP, HTTP_QUERY_RAW_HEADERS, responseHeaders.data(), &bufferLength, &headerCounter); + if (GetLastError() != ERROR_INSUFFICIENT_BUFFER) + { + InternetCloseHandle(hHTTP); + InternetCloseHandle(hConnect); + return reply; + } + + responseHeaders.resize(bufferLength); + if (!HttpQueryInfoA(hHTTP, HTTP_QUERY_RAW_HEADERS, responseHeaders.data(), &bufferLength, &headerCounter)) + { + InternetCloseHandle(hHTTP); + InternetCloseHandle(hConnect); + return reply; + } + + for (const char *headerData = responseHeaders.data(); *headerData; headerData += strlen(headerData) + 1) + { + const char *value = strchr(headerData, ':'); + if (value) + { + ptrdiff_t keyLen = (ptrdiff_t) (value - headerData); + reply.headers[std::string(headerData, keyLen)] = value + 2; // +2, colon and 1 space character. + } + } + responseHeaders.resize(1); + + // Read response + std::stringstream responseData; + for (;;) + { + constexpr DWORD BUFFER_SIZE = 4096; + char buffer[BUFFER_SIZE]; + DWORD readed = 0; + + BOOL ret = InternetQueryDataAvailable(hHTTP, &readed, 0, 0); + if (!ret || readed == 0) + break; + + if (!InternetReadFile(hHTTP, buffer, BUFFER_SIZE, &readed)) + break; + + responseData.write(buffer, readed); + } + + reply.body = responseData.str(); + reply.responseCode = statusCode; + + InternetCloseHandle(hHTTP); + InternetCloseHandle(hConnect); + return reply; +} + +#endif // HTTPS_BACKEND_WININET diff --git a/src/libraries/luahttps/src/windows/WinINetClient.h b/src/libraries/luahttps/src/windows/WinINetClient.h new file mode 100644 index 000000000..61d28279b --- /dev/null +++ b/src/libraries/luahttps/src/windows/WinINetClient.h @@ -0,0 +1,16 @@ +#pragma once + +#include "../common/config.h" + +#ifdef HTTPS_BACKEND_WININET + +#include "../common/HTTPSClient.h" + +class WinINetClient: public HTTPSClient +{ +public: + bool valid() const override; + HTTPSClient::Reply request(const HTTPSClient::Request &req) override; +}; + +#endif // HTTPS_BACKEND_WININET