diff --git a/src/game_api/script/usertypes/socket_lua.cpp b/src/game_api/script/usertypes/socket_lua.cpp index aa4896f12..6cbe79be3 100644 --- a/src/game_api/script/usertypes/socket_lua.cpp +++ b/src/game_api/script/usertypes/socket_lua.cpp @@ -1,184 +1,13 @@ #include "socket_lua.hpp" +#include "socket.hpp" -#include // for GetModuleHandleA, GetProcAddress -#include // for max -#include // for DetourAttach, DetourTransactionBegin -#include // for exception -#include // for operator new -#include // for inet_address -#include // for udp_socket -#include // for global_table, proxy_key_t, function -#include // for ssize_t -#include // for thread -#include // for get -#include // for move -#include // for max, min -#include // for InternetCloseHandle, InternetOpenA, InternetG... -#include // for sockaddr_in, SOCKET -#include // for inet_ntop +#include // for global_table, proxy_key_t, function +#include // for ssize_t #include "logger.h" // for DEBUG, ByteStr #include "script/lua_backend.hpp" // for LuaBackend #include "script/safe_cb.hpp" // for make_safe_cb -#pragma comment(lib, "wininet.lib") - -void udp_data(sockpp::udp_socket socket, UdpServer* server) -{ - ssize_t n; - char buf[500]; - sockpp::inet_address src; - while (server->kill_thr.test(std::memory_order_acquire) && (n = socket.recv_from(buf, sizeof(buf), &src)) > 0) - { - std::optional ret = server->cb(std::string(buf, n)); - if (ret) - { - socket.send_to(ret.value(), src); - } - } -} - -UdpServer::UdpServer(std::string host_, in_port_t port_, std::function cb_) - : host(host_), port(port_), cb(cb_) -{ - sock.bind(sockpp::inet_address(host, port)); - kill_thr.test_and_set(); - thr = std::thread(udp_data, std::move(sock), this); -} -void UdpServer::clear() // TODO: fix and expose: this and the destructor causes deadlock -{ - kill_thr.clear(std::memory_order_release); - thr.join(); -} -UdpServer::~UdpServer() -{ - if (thr.joinable()) - { - kill_thr.clear(std::memory_order_release); - thr.join(); - } -} - -using NetFun = int(SOCKET, char*, int, int, sockaddr_in*, int*); -NetFun* g_sendto_trampoline{nullptr}; -NetFun* g_recvfrom_trampoline{nullptr}; -int mySendto(SOCKET s, char* buf, int len, int flags, sockaddr_in* addr, int* tolen) -{ - auto ret = g_sendto_trampoline(s, buf, len, flags, addr, tolen); - char ip[16] = ""; - inet_ntop(addr->sin_family, &addr->sin_addr, ip, sizeof(ip)); - DEBUG("SEND: {}:{} | {}", ip, addr->sin_port, ByteStr{buf}); - return ret; -} - -int myRecvfrom(SOCKET s, char* buf, int len, int flags, sockaddr_in* addr, int* fromlen) -{ - auto ret = g_recvfrom_trampoline(s, buf, len, flags, addr, fromlen); - char ip[16] = ""; - inet_ntop(addr->sin_family, &addr->sin_addr, ip, sizeof(ip)); - DEBUG("RECV: {}:{} | {}", ip, addr->sin_port, ByteStr{buf}); - return ret; -} - -bool http_get(const char* sURL, std::string& out, std::string& err) -{ - const int BUFFER_SIZE = 32768; - DWORD iFlags; - const char* sAgent = "curl"; - const char* sHeader = NULL; - HINTERNET hInternet; - HINTERNET hConnect; - char acBuffer[BUFFER_SIZE]; - DWORD iReadBytes; - DWORD iBytesToRead = 0; - DWORD iReadBytesOfRq = 4; - - // Get connection state - InternetGetConnectedState(&iFlags, 0); - if (iFlags & INTERNET_CONNECTION_OFFLINE) - { - err = "Can't connect to the internet"; - return false; - } - - // Open internet session - if (!(iFlags & INTERNET_CONNECTION_PROXY)) - { - hInternet = InternetOpenA(sAgent, INTERNET_OPEN_TYPE_PRECONFIG_WITH_NO_AUTOPROXY, NULL, NULL, 0); - } - else - { - hInternet = InternetOpenA(sAgent, INTERNET_OPEN_TYPE_PRECONFIG, NULL, NULL, 0); - } - if (hInternet) - { - if (sHeader == NULL) - { - sHeader = "Accept: */*\r\n\r\n"; - } - - hConnect = InternetOpenUrlA(hInternet, sURL, sHeader, lstrlenA(sHeader), INTERNET_FLAG_DONT_CACHE | INTERNET_FLAG_PRAGMA_NOCACHE | INTERNET_FLAG_RELOAD, 0); - if (!hConnect) - { - InternetCloseHandle(hInternet); - err = "Can't connect to the url"; - return false; - } - - // Get content size - if (!HttpQueryInfo(hConnect, HTTP_QUERY_CONTENT_LENGTH | HTTP_QUERY_FLAG_NUMBER, (LPVOID)&iBytesToRead, &iReadBytesOfRq, NULL)) - { - iBytesToRead = 0; - } - - do - { - if (!InternetReadFile(hConnect, acBuffer, BUFFER_SIZE, &iReadBytes)) - { - InternetCloseHandle(hInternet); - err = "GET request failed"; - return false; - } - if (iReadBytes > 0) - { - out += std::string(acBuffer, iReadBytes); - } - if (iReadBytes <= 0) - { - break; - } - } while (TRUE); - InternetCloseHandle(hInternet); - } - else - { - err = "Can't connect to the internet"; - return false; - } - - return true; -} - -void http_get_async(HttpRequest* req) -{ - if (http_get(req->url.c_str(), req->response, req->error)) - { - req->cb(req->response, std::nullopt); - } - else - { - req->cb(std::nullopt, req->error); - } - delete req; -} - -HttpRequest::HttpRequest(std::string url_, std::function cb_) - : url(url_), cb(cb_) -{ - std::thread thr(http_get_async, this); - thr.detach(); -} - namespace NSocket { void register_usertypes(sol::state& lua) @@ -204,20 +33,7 @@ void register_usertypes(sol::state& lua) }; /// Hook the sendto and recvfrom functions and start dumping network data to terminal - lua["dump_network"] = []() - { - g_sendto_trampoline = (NetFun*)GetProcAddress(GetModuleHandleA("ws2_32.dll"), "sendto"); - g_recvfrom_trampoline = (NetFun*)GetProcAddress(GetModuleHandleA("ws2_32.dll"), "recvfrom"); - DetourTransactionBegin(); - DetourUpdateThread(GetCurrentThread()); - DetourAttach((void**)&g_sendto_trampoline, mySendto); - DetourAttach((void**)&g_recvfrom_trampoline, myRecvfrom); - const LONG error = DetourTransactionCommit(); - if (error != NO_ERROR) - { - DEBUG("Failed hooking network: {}\n", error); - } - }; + lua["dump_network"] = dump_network; /// Send a synchronous HTTP GET request and return response as a string or nil on an error lua["http_get"] = [&lua](std::string url) -> sol::optional diff --git a/src/game_api/script/usertypes/socket_lua.hpp b/src/game_api/script/usertypes/socket_lua.hpp index 972f8af99..711f09eea 100644 --- a/src/game_api/script/usertypes/socket_lua.hpp +++ b/src/game_api/script/usertypes/socket_lua.hpp @@ -1,42 +1,6 @@ #pragma once -#include // for atomic_flag -#include // for function -#include // for in_port_t -#include // for state, optional -#include // for string -#include // for thread - -#include "sockpp/udp_socket.h" // for udp_socket - -class UdpServer -{ - public: - using SocketCb = std::optional(std::string); - - UdpServer(std::string host, in_port_t port, std::function cb); - ~UdpServer(); - void clear(); - - std::string host; - in_port_t port; - std::function cb; - std::thread thr; - std::atomic_flag kill_thr; - sockpp::udp_socket sock; -}; - -class HttpRequest -{ - public: - using HttpCb = void(std::optional, std::optional); - - HttpRequest(std::string url, std::function cb); - std::string url; - std::function cb; - std::string response; - std::string error; -}; +#include // for state, optional namespace NSocket { diff --git a/src/game_api/socket.cpp b/src/game_api/socket.cpp new file mode 100644 index 000000000..5240b5b39 --- /dev/null +++ b/src/game_api/socket.cpp @@ -0,0 +1,189 @@ +#include "socket.hpp" + +#include // for GetModuleHandleA, GetProcAddress +#include // for max +#include // for DetourAttach, DetourTransactionBegin +#include // for exception +#include // for operator new +#include // for inet_address +#include // for udp_socket +#include // for thread +#include // for get +#include // for move +#include // for max, min +#include // for InternetCloseHandle, InternetOpenA, InternetG... +#include // for sockaddr_in, SOCKET +#include // for inet_ntop + +#pragma comment(lib, "wininet.lib") + +using NetFun = int(SOCKET, char*, int, int, sockaddr_in*, int*); +NetFun* g_sendto_trampoline{nullptr}; +NetFun* g_recvfrom_trampoline{nullptr}; +int mySendto(SOCKET s, char* buf, int len, int flags, sockaddr_in* addr, int* tolen) +{ + auto ret = g_sendto_trampoline(s, buf, len, flags, addr, tolen); + char ip[16] = ""; + inet_ntop(addr->sin_family, &addr->sin_addr, ip, sizeof(ip)); + DEBUG("SEND: {}:{} | {}", ip, addr->sin_port, ByteStr{buf}); + return ret; +} + +int myRecvfrom(SOCKET s, char* buf, int len, int flags, sockaddr_in* addr, int* fromlen) +{ + auto ret = g_recvfrom_trampoline(s, buf, len, flags, addr, fromlen); + char ip[16] = ""; + inet_ntop(addr->sin_family, &addr->sin_addr, ip, sizeof(ip)); + DEBUG("RECV: {}:{} | {}", ip, addr->sin_port, ByteStr{buf}); + return ret; +} + +void dump_network() +{ + g_sendto_trampoline = (NetFun*)GetProcAddress(GetModuleHandleA("ws2_32.dll"), "sendto"); + g_recvfrom_trampoline = (NetFun*)GetProcAddress(GetModuleHandleA("ws2_32.dll"), "recvfrom"); + DetourTransactionBegin(); + DetourUpdateThread(GetCurrentThread()); + DetourAttach((void**)&g_sendto_trampoline, mySendto); + DetourAttach((void**)&g_recvfrom_trampoline, myRecvfrom); + const LONG error = DetourTransactionCommit(); + if (error != NO_ERROR) + { + DEBUG("Failed hooking network: {}\n", error); + } +} + +void udp_data(sockpp::udp_socket socket, UdpServer* server) +{ + ssize_t n; + char buf[500]; + sockpp::inet_address src; + while (server->kill_thr.test(std::memory_order_acquire) && (n = socket.recv_from(buf, sizeof(buf), &src)) > 0) + { + std::optional ret = server->cb(std::string(buf, n)); + if (ret) + { + socket.send_to(ret.value(), src); + } + } +} + +UdpServer::UdpServer(std::string host_, in_port_t port_, std::function cb_) + : host(host_), port(port_), cb(cb_) +{ + sock.bind(sockpp::inet_address(host, port)); + kill_thr.test_and_set(); + thr = std::thread(udp_data, std::move(sock), this); +} +void UdpServer::clear() // TODO: fix and expose: this and the destructor causes deadlock +{ + kill_thr.clear(std::memory_order_release); + thr.join(); +} +UdpServer::~UdpServer() +{ + if (thr.joinable()) + { + kill_thr.clear(std::memory_order_release); + thr.join(); + } +} + +bool http_get(const char* sURL, std::string& out, std::string& err) +{ + const int BUFFER_SIZE = 32768; + DWORD iFlags; + const char* sAgent = "curl"; + const char* sHeader = NULL; + HINTERNET hInternet; + HINTERNET hConnect; + char acBuffer[BUFFER_SIZE]; + DWORD iReadBytes; + DWORD iBytesToRead = 0; + DWORD iReadBytesOfRq = 4; + + // Get connection state + InternetGetConnectedState(&iFlags, 0); + if (iFlags & INTERNET_CONNECTION_OFFLINE) + { + err = "Can't connect to the internet"; + return false; + } + + // Open internet session + if (!(iFlags & INTERNET_CONNECTION_PROXY)) + { + hInternet = InternetOpenA(sAgent, INTERNET_OPEN_TYPE_PRECONFIG_WITH_NO_AUTOPROXY, NULL, NULL, 0); + } + else + { + hInternet = InternetOpenA(sAgent, INTERNET_OPEN_TYPE_PRECONFIG, NULL, NULL, 0); + } + if (hInternet) + { + if (sHeader == NULL) + { + sHeader = "Accept: */*\r\n\r\n"; + } + + hConnect = InternetOpenUrlA(hInternet, sURL, sHeader, lstrlenA(sHeader), INTERNET_FLAG_DONT_CACHE | INTERNET_FLAG_PRAGMA_NOCACHE | INTERNET_FLAG_RELOAD, 0); + if (!hConnect) + { + InternetCloseHandle(hInternet); + err = "Can't connect to the url"; + return false; + } + + // Get content size + if (!HttpQueryInfo(hConnect, HTTP_QUERY_CONTENT_LENGTH | HTTP_QUERY_FLAG_NUMBER, (LPVOID)&iBytesToRead, &iReadBytesOfRq, NULL)) + { + iBytesToRead = 0; + } + + do + { + if (!InternetReadFile(hConnect, acBuffer, BUFFER_SIZE, &iReadBytes)) + { + InternetCloseHandle(hInternet); + err = "GET request failed"; + return false; + } + if (iReadBytes > 0) + { + out += std::string(acBuffer, iReadBytes); + } + if (iReadBytes <= 0) + { + break; + } + } while (TRUE); + InternetCloseHandle(hInternet); + } + else + { + err = "Can't connect to the internet"; + return false; + } + + return true; +} + +void http_get_async(HttpRequest* req) +{ + if (http_get(req->url.c_str(), req->response, req->error)) + { + req->cb(req->response, std::nullopt); + } + else + { + req->cb(std::nullopt, req->error); + } + delete req; +} + +HttpRequest::HttpRequest(std::string url_, std::function cb_) + : url(url_), cb(cb_) +{ + std::thread thr(http_get_async, this); + thr.detach(); +} diff --git a/src/game_api/socket.hpp b/src/game_api/socket.hpp new file mode 100644 index 000000000..c6904967a --- /dev/null +++ b/src/game_api/socket.hpp @@ -0,0 +1,41 @@ +#pragma once + +#include // for atomic_flag +#include // for function +#include // for in_port_t +#include // for string +#include // for thread + +#include "sockpp/udp_socket.h" // for udp_socket + +class UdpServer +{ + public: + using SocketCb = std::optional(std::string); + + UdpServer(std::string host, in_port_t port, std::function cb); + ~UdpServer(); + void clear(); + + std::string host; + in_port_t port; + std::function cb; + std::thread thr; + std::atomic_flag kill_thr; + sockpp::udp_socket sock; +}; + +class HttpRequest +{ + public: + using HttpCb = void(std::optional, std::optional); + + HttpRequest(std::string url, std::function cb); + std::string url; + std::function cb; + std::string response; + std::string error; +}; + +void dump_network(); +bool http_get(const char* sURL, std::string& out, std::string& err); diff --git a/src/injected/ui.cpp b/src/injected/ui.cpp index 22d479be3..1d1ff746e 100644 --- a/src/injected/ui.cpp +++ b/src/injected/ui.cpp @@ -50,6 +50,7 @@ #include "screen.hpp" #include "script.hpp" #include "settings_api.hpp" +#include "socket.hpp" #include "sound_manager.hpp" // TODO: remove from here? #include "state.hpp" #include "steam_api.hpp" @@ -298,6 +299,7 @@ std::string g_change_key = ""; const char* inifile = "imgui.ini"; const std::string cfgfile = "overlunky.ini"; std::string scriptpath = "Overlunky/Scripts"; +const std::string version_check_url = "https://api.github.com/repos/spelunky-fyi/overlunky/git/ref/tags/whip"; std::string fontfile = ""; std::vector fontsize = {14.0f, 32.0f, 72.0f}; @@ -365,13 +367,146 @@ std::map options = { {"vsync", true}, {"uncap_unfocused_fps", true}, {"pause_loading", false}, - {"pause_update_camera", false}}; + {"pause_update_camera", false}, + {"update_check", true}, +}; double g_engine_fps = 60.0, g_unfocused_fps = 33.0; double fps_min = 0, fps_max = 600.0; float g_speedhack_ui_multiplier = 1.0; float g_speedhack_old_multiplier = 1.0; +enum class VERSION_CHECK +{ + HIDDEN, + DISABLED, + CHECKING, + FAILED, + OLD, + LATEST, + STABLE, +}; +struct VersionCheck +{ + std::string message; + float color[4]; + float fade; +}; +struct VersionCheckStatus +{ + VERSION_CHECK state; + float opacity; + int64_t start; + float color[4]; +}; +VersionCheckStatus version_check_status{VERSION_CHECK::HIDDEN, 1.0f, 0}; +std::array version_check_messages{ + VersionCheck{"", {0.0f, 0.0f, 0.0f, 0.0f}, 0}, + VersionCheck{"Automatic update check is disabled...", {0.5f, 0.5f, 0.5f, 0.8f}, 600.0f}, + VersionCheck{"Checking for updates on GitHub...", {0.3f, 0.6f, 1.0f, 0.9f}, 0}, + VersionCheck{"Automatic update check failed. Please retry, check GitHub or use Modlunky to update!", {0.8f, 0.0f, 0.0f, 1.0f}, 900.0f}, + VersionCheck{"This is not the latest build of Overlunky WHIP! Please run the Overlunky launcher or use Modlunky to get the latest build!", {0.8f, 0.0f, 0.0f, 1.0f}, 3600.0f}, + VersionCheck{"This is the latest build of Overlunky WHIP! Yippee!", {0.0f, 0.8f, 0.2f, 0.8f}, 900.0f}, + VersionCheck{"This is a stable build of Overlunky. Get the WHIP build for automatic updates!", {0.9f, 0.6f, 0.0f, 0.8f}, 600.0f}, +}; +void render_version_warning() +{ + if (version_check_status.state == VERSION_CHECK::HIDDEN) + return; + + version_check_status.color[0] = version_check_messages[(int)version_check_status.state].color[0]; + version_check_status.color[1] = version_check_messages[(int)version_check_status.state].color[1]; + version_check_status.color[2] = version_check_messages[(int)version_check_status.state].color[2]; + version_check_status.color[3] = version_check_messages[(int)version_check_status.state].color[3]; + + if (version_check_messages[(int)version_check_status.state].fade > 0) + { + if (version_check_status.start == 0) + version_check_status.start = get_global_frame_count(); + + auto duration = get_global_frame_count() - version_check_status.start; + version_check_status.opacity = 1.0f - duration / version_check_messages[(int)version_check_status.state].fade; + if (version_check_status.opacity <= 0.0f) + { + version_check_status.state = VERSION_CHECK::HIDDEN; + return; + } + version_check_status.color[3] = version_check_status.opacity; + } + + auto& render_api = RenderAPI::get(); + const float scale{0.0004f}; + static TextRenderingInfo tri{}; + tri.set_text(version_check_messages[(int)version_check_status.state].message, 0, 0, scale, scale, 1, 0); + const auto [w, h] = tri.text_size(); + tri.y = -1.0f + std::abs(h) / 2.0f + std::abs(h) + 0.005f; + render_api.draw_text(&tri, version_check_status.color); +} + +void get_version_info(std::optional res, std::optional err) +{ + // http error + if (!res.has_value()) + { + version_check_status.state = VERSION_CHECK::FAILED; + if (err.has_value()) + DEBUG("UpdateCheck: Error: {}", err.value()); + return; + } + std::string data = res.value(); + + // some github error + if (data.find("overlunky") == std::string::npos) + { + version_check_status.state = VERSION_CHECK::FAILED; + DEBUG("UpdateCheck: {}", version_check_messages[(int)version_check_status.state].message); + return; + } + + std::string version = fmt::format("\"sha\": \"{}", get_version()); + + // old version + if (data.find(version) == std::string::npos) + { + version_check_status.state = VERSION_CHECK::OLD; + DEBUG("UpdateCheck: {}", version_check_messages[(int)version_check_status.state].message); + return; + } + // latest version + else + { + version_check_status.state = VERSION_CHECK::LATEST; + DEBUG("UpdateCheck: {}", version_check_messages[(int)version_check_status.state].message); + return; + } +} + +void version_check(bool force = false) +{ + version_check_status = VersionCheckStatus{VERSION_CHECK::HIDDEN, 1.0f, 0}; + + if (!options["update_check"] && !force) + { + version_check_status.state = VERSION_CHECK::DISABLED; + DEBUG("UpdateCheck: {}", version_check_messages[(int)version_check_status.state].message); + return; + } + + auto version = std::string(get_version()); + + // not a whip build + if (version.find(".") != std::string::npos) + { + version_check_status.state = VERSION_CHECK::STABLE; + DEBUG("UpdateCheck: {}", version_check_messages[(int)version_check_status.state].message); + return; + } + + version_check_status.state = VERSION_CHECK::CHECKING; + DEBUG("UpdateCheck: {}", version_check_messages[(int)version_check_status.state].message); + new HttpRequest(std::move(version_check_url), get_version_info); +} + void hook_savegame() { static bool savegame_hooked = false; @@ -5831,6 +5966,12 @@ void render_options() ImGui::Checkbox("Show tooltips", &options["show_tooltips"]); tooltip("Am I annoying you already :("); + if (ImGui::Checkbox("Automatically check for updates", &options["update_check"])) + { + if (options["update_check"]) + version_check(true); + } + tooltip("Check if you're running the latest build of Overlunky WHIP on start."); endmenu(); } @@ -5906,6 +6047,10 @@ void render_options() { if (ImGui::BeginMenu("Help")) { + if (ImGui::MenuItem("Check for updates")) + version_check(true); + if (ImGui::MenuItem("Get latest version here")) + ShellExecuteA(NULL, "open", "https://github.com/spelunky-fyi/overlunky/releases/tag/whip", NULL, NULL, SW_SHOWNORMAL); if (ImGui::MenuItem("README")) ShellExecuteA(NULL, "open", "https://github.com/spelunky-fyi/overlunky#overlunky", NULL, NULL, SW_SHOWNORMAL); if (ImGui::MenuItem("API Documentation")) @@ -8995,6 +9140,7 @@ void imgui_init(ImGuiContext*) refresh_script_files(); autorun_scripts(); set_colors(); + version_check(); windows["tool_entity"] = new Window({"Spawner", is_tab_detached("tool_entity"), is_tab_open("tool_entity")}); windows["tool_door"] = new Window({"Warp", is_tab_detached("tool_door"), is_tab_open("tool_door")}); windows["tool_camera"] = new Window({"Camera", is_tab_detached("tool_camera"), is_tab_open("tool_camera")}); @@ -9325,6 +9471,10 @@ void imgui_draw() } if (ImGui::BeginMenu("Help")) { + if (ImGui::MenuItem("Check for updates")) + version_check(true); + if (ImGui::MenuItem("Get latest version here")) + ShellExecuteA(NULL, "open", "https://github.com/spelunky-fyi/overlunky/releases/tag/whip", NULL, NULL, SW_SHOWNORMAL); if (ImGui::MenuItem("README")) ShellExecuteA(NULL, "open", "https://github.com/spelunky-fyi/overlunky#overlunky", NULL, NULL, SW_SHOWNORMAL); if (ImGui::MenuItem("API Documentation")) @@ -9604,7 +9754,7 @@ void init_ui() render_api.set_advanced_hud(); const std::string version_string = fmt::format("Overlunky {}", get_version()); - const float scale{0.00035f}; + const float scale{0.0004f}; static TextRenderingInfo tri{}; tri.set_text(version_string, 0, 0, scale, scale, 1, 0); @@ -9618,6 +9768,7 @@ void init_ui() auto& render_api_l = RenderAPI::get(); static const float color[4]{1.0f, 1.0f, 1.0f, 0.3f}; render_vanilla_stuff(); + render_version_warning(); render_api_l.draw_text(&tri, color); }); }