Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GSLzma: Fix a file handle leak in GSDumpLzma #11118

Merged
merged 1 commit into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 29 additions & 44 deletions pcsx2/GS/GSLzma.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

#include "common/AlignedMalloc.h"
#include "common/Console.h"
#include "common/FileSystem.h"
#include "common/ScopedGuard.h"
#include "common/StringUtil.h"
#include "common/BitUtils.h"
Expand Down Expand Up @@ -48,7 +47,7 @@ bool GSDumpFile::GetPreviewImageFromDump(const char* filename, u32* width, u32*
return false;
}

std::unique_ptr<u8[]> header_bits = std::make_unique<u8[]>(header_size);
std::unique_ptr<u8[]> header_bits = std::make_unique_for_overwrite<u8[]>(header_size);
if (!dump->Read(header_bits.get(), header_size))
return false;

Expand Down Expand Up @@ -251,7 +250,7 @@ namespace
~GSDumpLzma() override;

protected:
bool Open(std::FILE* fp, Error* error) override;
bool Open(FileSystem::ManagedCFilePtr fp, Error* error) override;
bool IsEof() override;
size_t Read(void* ptr, size_t size) override;

Expand All @@ -269,7 +268,7 @@ namespace

bool DecompressNextBlock();

std::FILE* m_fp = nullptr;

std::vector<Block> m_blocks;
size_t m_stream_size = 0;

Expand All @@ -286,14 +285,12 @@ namespace

GSDumpLzma::~GSDumpLzma()
{


XzUnpacker_Free(&m_unpacker);
}

bool GSDumpLzma::Open(std::FILE* fp, Error* error)
bool GSDumpLzma::Open(FileSystem::ManagedCFilePtr fp, Error* error)
{
m_fp = fp;
m_fp = std::move(fp);

GSInit7ZCRCTables();

Expand Down Expand Up @@ -324,7 +321,7 @@ namespace
*pos = new_pos;
return SZ_OK;
}},
m_fp};
m_fp.get()};

CLookToRead2 look_stream = {};
LookToRead2_Init(&look_stream);
Expand All @@ -349,7 +346,7 @@ namespace
Xzs_Free(&xzs, &g_Alloc);
});

const s64 file_size = FileSystem::FSize64(m_fp);
const s64 file_size = FileSystem::FSize64(m_fp.get());
Int64 start_pos = file_size;
SRes res = Xzs_ReadBackward(&xzs, &look_stream.vt, &start_pos, nullptr, &g_Alloc);
if (res != SZ_OK)
Expand Down Expand Up @@ -402,8 +399,8 @@ namespace
if (block.compressed_size > m_block_read_buffer.size())
m_block_read_buffer.resize(Common::AlignUpPow2(block.compressed_size, _128kb));

if (FileSystem::FSeek64(m_fp, static_cast<s64>(block.file_offset), SEEK_SET) != 0 ||
std::fread(m_block_read_buffer.data(), block.compressed_size, 1, m_fp) != 1)
if (FileSystem::FSeek64(m_fp.get(), static_cast<s64>(block.file_offset), SEEK_SET) != 0 ||
std::fread(m_block_read_buffer.data(), block.compressed_size, 1, m_fp.get()) != 1)
{
Console.ErrorFmt("Failed to read {} bytes from offset {}", block.file_offset, block.compressed_size);
return false;
Expand Down Expand Up @@ -473,8 +470,6 @@ namespace
static constexpr u32 INPUT_BUFFER_SIZE = 512 * _1kb;
static constexpr u32 OUTPUT_BUFFER_SIZE = 2 * _1mb;

std::FILE* m_fp = nullptr;

ZSTD_DStream* m_strm = nullptr;
ZSTD_inBuffer m_inbuf = {};

Expand All @@ -489,7 +484,7 @@ namespace
GSDumpDecompressZst();
~GSDumpDecompressZst() override;

bool Open(std::FILE* fp, Error* error) override;
bool Open(FileSystem::ManagedCFilePtr fp, Error* error) override;
bool IsEof() override;
size_t Read(void* ptr, size_t size) override;
};
Expand All @@ -502,21 +497,18 @@ namespace
ZSTD_freeDStream(m_strm);

if (m_inbuf.src)
_aligned_free((void*)m_inbuf.src);
_aligned_free(const_cast<void*>(m_inbuf.src));
if (m_area)
_aligned_free(m_area);

if (m_fp)
std::fclose(m_fp);
}

bool GSDumpDecompressZst::Open(std::FILE* fp, Error* error)
bool GSDumpDecompressZst::Open(FileSystem::ManagedCFilePtr fp, Error* error)
{
m_fp = fp;
m_fp = std::move(fp);
m_strm = ZSTD_createDStream();

m_area = (uint8_t*)_aligned_malloc(OUTPUT_BUFFER_SIZE, 32);
m_inbuf.src = (uint8_t*)_aligned_malloc(INPUT_BUFFER_SIZE, 32);
m_area = static_cast<uint8_t*>(_aligned_malloc(OUTPUT_BUFFER_SIZE, 32));
m_inbuf.src = static_cast<uint8_t*>(_aligned_malloc(INPUT_BUFFER_SIZE, 32));
m_inbuf.pos = 0;
m_inbuf.size = 0;
m_avail = 0;
Expand All @@ -530,12 +522,12 @@ namespace
while (outbuf.pos == 0)
{
// Nothing left in the input buffer. Read data from the file
if (m_inbuf.pos == m_inbuf.size && !std::feof(m_fp))
if (m_inbuf.pos == m_inbuf.size && !std::feof(m_fp.get()))
{
m_inbuf.size = fread((void*)m_inbuf.src, 1, INPUT_BUFFER_SIZE, m_fp);
m_inbuf.size = fread(const_cast<void*>(m_inbuf.src), 1, INPUT_BUFFER_SIZE, m_fp.get());
m_inbuf.pos = 0;

if (ferror(m_fp))
if (ferror(m_fp.get()))
{
Console.Error("Zst read error: %s", strerror(errno));
return false;
Expand All @@ -557,13 +549,13 @@ namespace

bool GSDumpDecompressZst::IsEof()
{
return feof(m_fp) && m_avail == 0 && m_inbuf.pos == m_inbuf.size;
return feof(m_fp.get()) && m_avail == 0 && m_inbuf.pos == m_inbuf.size;
}

size_t GSDumpDecompressZst::Read(void* ptr, size_t size)
{
size_t off = 0;
uint8_t* dst = (uint8_t*)ptr;
uint8_t* dst = static_cast<uint8_t*>(ptr);
while (size && !IsEof())
{
if (m_avail == 0)
Expand All @@ -587,43 +579,36 @@ namespace

class GSDumpRaw final : public GSDumpFile
{
std::FILE* m_fp;

public:
GSDumpRaw();
~GSDumpRaw() override;

bool Open(std::FILE* fp, Error* error) override;
bool Open(FileSystem::ManagedCFilePtr fp, Error* error) override;
bool IsEof() override;
size_t Read(void* ptr, size_t size) override;
};

GSDumpRaw::GSDumpRaw() = default;

GSDumpRaw::~GSDumpRaw()
{
if (m_fp)
std::fclose(m_fp);
}
GSDumpRaw::~GSDumpRaw() = default;

bool GSDumpRaw::Open(std::FILE* fp, Error* error)
bool GSDumpRaw::Open(FileSystem::ManagedCFilePtr fp, Error* error)
{
m_fp = fp;
m_fp = std::move(fp);
return true;
}

bool GSDumpRaw::IsEof()
{
return !!feof(m_fp);
return !!feof(m_fp.get());
}

size_t GSDumpRaw::Read(void* ptr, size_t size)
{
size_t ret = fread(ptr, 1, size, m_fp);
if (ret != size && ferror(m_fp))
size_t ret = fread(ptr, 1, size, m_fp.get());
if (ret != size && ferror(m_fp.get()))
{
fprintf(stderr, "GSDumpRaw:: Read error (%zu/%zu)\n", ret, size);
return ret;
}

return ret;
Expand All @@ -634,7 +619,7 @@ namespace

std::unique_ptr<GSDumpFile> GSDumpFile::OpenGSDump(const char* filename, Error* error)
{
std::FILE* fp = FileSystem::OpenCFile(filename, "rb", error);
FileSystem::ManagedCFilePtr fp = FileSystem::OpenManagedCFile(filename, "rb", error);
if (!fp)
return nullptr;

Expand All @@ -646,7 +631,7 @@ std::unique_ptr<GSDumpFile> GSDumpFile::OpenGSDump(const char* filename, Error*
else
file = std::make_unique<GSDumpRaw>();

if (!file->Open(fp, error))
if (!file->Open(std::move(fp), error))
file = {};

return file;
Expand Down
7 changes: 6 additions & 1 deletion pcsx2/GS/GSLzma.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

#pragma once

#include "common/FileSystem.h"

#include <memory>
#include <string>
#include <vector>
Expand Down Expand Up @@ -305,10 +307,13 @@ class GSDumpFile
protected:
GSDumpFile();

virtual bool Open(std::FILE* fp, Error* error) = 0;
virtual bool Open(FileSystem::ManagedCFilePtr fp, Error* error) = 0;
virtual bool IsEof() = 0;
virtual size_t Read(void* ptr, size_t size) = 0;

protected:
FileSystem::ManagedCFilePtr m_fp;

private:
std::string m_serial;
u32 m_crc = 0;
Expand Down