diff --git a/src/common/indexed_resources.h b/src/common/indexed_resources.h new file mode 100644 index 0000000000..4f3fefb45d --- /dev/null +++ b/src/common/indexed_resources.h @@ -0,0 +1,56 @@ +// SPDX-FileCopyrightText: Copyright 2024 shadPS4 Emulator Project +// SPDX-License-Identifier: GPL-2.0-or-later + +#include + +#include +#include +#include +#include + +#include + +namespace Common { + +template ::max()> +class IndexedResources { +public: + IndexedResources() { + m_free_indices += boost::icl::interval::closed(0, MaxIndex); + } + + template + std::optional Create(Types&&... args) { + std::unique_lock lock{m_mutex}; + if (m_free_indices.empty()) { + return {}; + } + auto index = first(*m_free_indices.begin()); + m_free_indices -= index; + m_container.emplace(index, T(std::forward(args)...)); + return index; + } + + void Destroy(Index index) { + std::unique_lock lock{m_mutex}; + if (m_container.erase(index) > 0) { + m_free_indices += index; + } + } + + std::optional> Get(Index index) { + std::shared_lock lock{m_mutex}; + auto it = m_container.find(index); + if (it == m_container.end()) { + return {}; + } + return it->second; + } + +private: + std::shared_mutex m_mutex; + std::unordered_map m_container; + boost::icl::interval_set m_free_indices; +}; + +} // namespace Common diff --git a/src/core/libraries/ajm/ajm.cpp b/src/core/libraries/ajm/ajm.cpp index 2420832cf1..19bc6511f4 100644 --- a/src/core/libraries/ajm/ajm.cpp +++ b/src/core/libraries/ajm/ajm.cpp @@ -1,14 +1,20 @@ // SPDX-FileCopyrightText: Copyright 2024 shadPS4 Emulator Project // SPDX-License-Identifier: GPL-2.0-or-later +#include +#include #include #include #include +#include +#include #include #include +#include #include "common/alignment.h" #include "common/assert.h" +#include "common/indexed_resources.h" #include "common/logging/log.h" #include "common/scope_exit.h" #include "core/libraries/ajm/ajm.h" @@ -31,20 +37,44 @@ namespace Libraries::Ajm { static constexpr u32 AJM_INSTANCE_STATISTICS = 0x80000; -static constexpr u32 SCE_AJM_WAIT_INFINITE = -1; +static constexpr u32 ORBIS_AJM_WAIT_INFINITE = -1; static constexpr u32 MaxInstances = 0x2fff; static constexpr u32 MaxBatches = 1000; +struct AjmJob { + u32 instance_id = 0; + AjmJobFlags flags = {.raw = 0}; + AjmJobInput input; + AjmJobOutput output; +}; + struct BatchInfo { - u16 instance{}; - u16 offset_in_qwords{}; // Needed for AjmBatchError? - bool waiting{}; + u32 id; + std::atomic_bool waiting{}; + std::atomic_bool canceled{}; std::binary_semaphore finished{0}; - int result{}; + boost::container::small_vector jobs; }; +template +ChunkType& AjmBufferExtract(u8*& p_cursor) { + auto* const result = reinterpret_cast(p_cursor); + p_cursor += sizeof(ChunkType); + return *result; +} + +template +void AjmBufferSkip(u8*& p_cursor) { + p_cursor += sizeof(ChunkType); +} + +template +ChunkType& AjmBufferPeek(u8* p_cursor) { + return *reinterpret_cast(p_cursor); +} + struct AjmDevice { u32 max_prio{}; u32 min_prio{}; @@ -53,9 +83,15 @@ struct AjmDevice { std::array is_registered{}; std::array free_instances{}; std::array, MaxInstances> instances; - std::vector> batches{}; + Common::IndexedResources, MaxBatches> batches; std::mutex batches_mutex; + std::jthread worker_thread{}; + std::queue> batch_queue{}; + std::mutex batch_queue_mutex{}; + std::mutex batch_queue_cv_mutex{}; + std::condition_variable_any batch_queue_cv{}; + [[nodiscard]] bool IsRegistered(AjmCodecType type) const { return is_registered[static_cast(type)]; } @@ -66,13 +102,116 @@ struct AjmDevice { AjmDevice() { std::iota(free_instances.begin(), free_instances.end(), 1); + worker_thread = std::jthread([this](std::stop_token stop) { this->WorkerThread(stop); }); + } + + void WorkerThread(std::stop_token stop) { + while (!stop.stop_requested()) { + { + std::unique_lock lock(batch_queue_cv_mutex); + if (!batch_queue_cv.wait(lock, stop, [this] { return !batch_queue.empty(); })) { + continue; + } + } + + std::shared_ptr batch; + { + std::lock_guard lock(batch_queue_mutex); + batch = batch_queue.front(); + batch_queue.pop(); + } + ProcessBatch(batch->id, batch->jobs); + batch->finished.release(); + } + } + + void ProcessBatch(u32 id, std::span jobs) { + // Perform operation requested by control flags. + for (auto& job : jobs) { + LOG_DEBUG(Lib_Ajm, "Processing job {} for instance {}. flags = {:#x}", id, + job.instance_id, job.flags.raw); + + AjmInstance* p_instance = instances[job.instance_id].get(); + + const auto control_flags = job.flags.control_flags; + if (True(control_flags & AjmJobControlFlags::Reset)) { + LOG_INFO(Lib_Ajm, "Resetting instance {}", job.instance_id); + p_instance->Reset(); + } + if (True(control_flags & AjmJobControlFlags::Initialize)) { + LOG_INFO(Lib_Ajm, "Initializing instance {}", job.instance_id); + ASSERT_MSG(job.input.init_params.has_value(), + "Initialize called without control buffer"); + auto& params = job.input.init_params.value(); + p_instance->Initialize(¶ms, sizeof(params)); + } + if (True(control_flags & AjmJobControlFlags::Resample)) { + LOG_ERROR(Lib_Ajm, "Unimplemented: resample params"); + ASSERT_MSG(job.input.resample_parameters.has_value(), + "Resample paramters are absent"); + p_instance->resample_parameters = job.input.resample_parameters.value(); + } + + const auto sideband_flags = job.flags.sideband_flags; + if (True(sideband_flags & AjmJobSidebandFlags::Format)) { + ASSERT_MSG(job.input.format.has_value(), "Format parameters are absent"); + p_instance->format = job.input.format.value(); + } + if (True(sideband_flags & AjmJobSidebandFlags::GaplessDecode)) { + ASSERT_MSG(job.input.gapless_decode.has_value(), + "Gapless decode parameters are absent"); + auto& params = job.input.gapless_decode.value(); + p_instance->gapless.total_samples = params.total_samples; + p_instance->gapless.skip_samples = params.skip_samples; + } + + ASSERT_MSG(job.input.buffers.size() <= job.output.buffers.size(), + "Unsupported combination of input/output buffers."); + + for (size_t i = 0; i < job.input.buffers.size(); ++i) { + // Decode as much of the input bitstream as possible. + const auto& in_buffer = job.input.buffers[i]; + auto& out_buffer = job.output.buffers[i]; + + const u8* in_address = in_buffer.data(); + u8* out_address = out_buffer.data(); + const auto [in_remain, out_remain] = p_instance->Decode( + in_address, in_buffer.size(), out_address, out_buffer.size(), &job.output); + + if (job.output.p_stream != nullptr) { + job.output.p_stream->input_consumed += in_buffer.size() - in_remain; + job.output.p_stream->output_written += out_buffer.size() - out_remain; + job.output.p_stream->total_decoded_samples += p_instance->decoded_samples; + } + } + + if (job.output.p_gapless_decode != nullptr) { + *job.output.p_gapless_decode = p_instance->gapless; + } + + if (job.output.p_codec_info != nullptr) { + p_instance->GetCodecInfo(job.output.p_codec_info); + } + } } }; static std::unique_ptr dev{}; -int PS4_SYSV_ABI sceAjmBatchCancel() { - LOG_ERROR(Lib_Ajm, "(STUBBED) called"); +int PS4_SYSV_ABI sceAjmBatchCancel(const u32 context_id, const u32 batch_id) { + std::shared_ptr batch{}; + { + std::lock_guard guard(dev->batches_mutex); + const auto opt_batch = dev->batches.Get(batch_id); + if (!opt_batch.has_value()) { + return ORBIS_AJM_ERROR_INVALID_BATCH; + } + + batch = opt_batch.value().get(); + } + + batch->canceled = true; + return ORBIS_OK; } @@ -81,23 +220,6 @@ int PS4_SYSV_ABI sceAjmBatchErrorDump() { return ORBIS_OK; } -template -ChunkType& AjmBufferExtract(CursorType& p_cursor) { - auto* const result = reinterpret_cast(p_cursor); - p_cursor += sizeof(ChunkType); - return *result; -} - -template -void AjmBufferSkip(CursorType& p_cursor) { - p_cursor += sizeof(ChunkType); -} - -template -ChunkType& AjmBufferPeek(CursorType p_cursor) { - return *reinterpret_cast(p_cursor); -} - void* PS4_SYSV_ABI sceAjmBatchJobControlBufferRa(void* p_buffer, u32 instance_id, u64 flags, void* p_sideband_input, size_t sideband_input_size, void* p_sideband_output, @@ -116,6 +238,7 @@ void* PS4_SYSV_ABI sceAjmBatchJobControlBufferRa(void* p_buffer, u32 instance_id if (p_return_address != nullptr) { auto& chunk_ra = AjmBufferExtract(p_current); chunk_ra.header.ident = AjmIdentReturnAddressBuf; + chunk_ra.header.payload = 0; chunk_ra.header.size = 0; chunk_ra.p_address = p_return_address; } @@ -123,6 +246,7 @@ void* PS4_SYSV_ABI sceAjmBatchJobControlBufferRa(void* p_buffer, u32 instance_id { auto& chunk_input = AjmBufferExtract(p_current); chunk_input.header.ident = AjmIdentInputControlBuf; + chunk_input.header.payload = 0; chunk_input.header.size = sideband_input_size; chunk_input.p_address = p_sideband_input; } @@ -149,6 +273,7 @@ void* PS4_SYSV_ABI sceAjmBatchJobControlBufferRa(void* p_buffer, u32 instance_id { auto& chunk_output = AjmBufferExtract(p_current); chunk_output.header.ident = AjmIdentOutputControlBuf; + chunk_output.header.payload = 0; chunk_output.header.size = sideband_output_size; chunk_output.p_address = p_sideband_output; } @@ -166,6 +291,7 @@ void* PS4_SYSV_ABI sceAjmBatchJobInlineBuffer(void* p_buffer, const void* p_data auto& header = AjmBufferExtract(p_current); header.ident = AjmIdentInlineBuf; + header.payload = 0; header.size = Common::AlignUp(data_input_size, 8); *pp_batch_address = p_current; @@ -191,6 +317,7 @@ void* PS4_SYSV_ABI sceAjmBatchJobRunBufferRa(void* p_buffer, u32 instance_id, u6 if (p_return_address != nullptr) { auto& chunk_ra = AjmBufferExtract(p_current); chunk_ra.header.ident = AjmIdentReturnAddressBuf; + chunk_ra.header.payload = 0; chunk_ra.header.size = 0; chunk_ra.p_address = p_return_address; } @@ -198,6 +325,7 @@ void* PS4_SYSV_ABI sceAjmBatchJobRunBufferRa(void* p_buffer, u32 instance_id, u6 { auto& chunk_input = AjmBufferExtract(p_current); chunk_input.header.ident = AjmIdentInputRunBuf; + chunk_input.header.payload = 0; chunk_input.header.size = data_input_size; chunk_input.p_address = p_data_input; } @@ -217,6 +345,7 @@ void* PS4_SYSV_ABI sceAjmBatchJobRunBufferRa(void* p_buffer, u32 instance_id, u6 { auto& chunk_output = AjmBufferExtract(p_current); chunk_output.header.ident = AjmIdentOutputRunBuf; + chunk_output.header.payload = 0; chunk_output.header.size = data_output_size; chunk_output.p_address = p_data_output; } @@ -224,6 +353,7 @@ void* PS4_SYSV_ABI sceAjmBatchJobRunBufferRa(void* p_buffer, u32 instance_id, u6 { auto& chunk_output = AjmBufferExtract(p_current); chunk_output.header.ident = AjmIdentOutputControlBuf; + chunk_output.header.payload = 0; chunk_output.header.size = sideband_output_size; chunk_output.p_address = p_sideband_output; } @@ -250,6 +380,7 @@ void* PS4_SYSV_ABI sceAjmBatchJobRunSplitBufferRa( if (p_return_address != nullptr) { auto& chunk_ra = AjmBufferExtract(p_current); chunk_ra.header.ident = AjmIdentReturnAddressBuf; + chunk_ra.header.payload = 0; chunk_ra.header.size = 0; chunk_ra.p_address = p_return_address; } @@ -257,6 +388,7 @@ void* PS4_SYSV_ABI sceAjmBatchJobRunSplitBufferRa( for (s32 i = 0; i < num_data_input_buffers; i++) { auto& chunk_input = AjmBufferExtract(p_current); chunk_input.header.ident = AjmIdentInputRunBuf; + chunk_input.header.payload = 0; chunk_input.header.size = p_data_input_buffers[i].size; chunk_input.p_address = p_data_input_buffers[i].p_address; } @@ -276,6 +408,7 @@ void* PS4_SYSV_ABI sceAjmBatchJobRunSplitBufferRa( for (s32 i = 0; i < num_data_output_buffers; i++) { auto& chunk_output = AjmBufferExtract(p_current); chunk_output.header.ident = AjmIdentOutputRunBuf; + chunk_output.header.payload = 0; chunk_output.header.size = p_data_output_buffers[i].size; chunk_output.p_address = p_data_output_buffers[i].p_address; } @@ -283,6 +416,7 @@ void* PS4_SYSV_ABI sceAjmBatchJobRunSplitBufferRa( { auto& chunk_output = AjmBufferExtract(p_current); chunk_output.header.ident = AjmIdentOutputControlBuf; + chunk_output.header.payload = 0; chunk_output.header.size = sideband_output_size; chunk_output.p_address = p_sideband_output; } @@ -291,58 +425,67 @@ void* PS4_SYSV_ABI sceAjmBatchJobRunSplitBufferRa( return p_current; } -int PS4_SYSV_ABI sceAjmBatchStartBuffer(u32 context, const u8* batch, u32 batch_size, +int PS4_SYSV_ABI sceAjmBatchStartBuffer(u32 context, u8* p_batch, u32 batch_size, const int priority, AjmBatchError* batch_error, u32* out_batch_id) { LOG_INFO(Lib_Ajm, "called context = {}, batch_size = {:#x}, priority = {}", context, batch_size, priority); if ((batch_size & 7) != 0) { + LOG_ERROR(Lib_Ajm, "ORBIS_AJM_ERROR_MALFORMED_BATCH"); return ORBIS_AJM_ERROR_MALFORMED_BATCH; } const auto batch_info = std::make_shared(); - if (dev->batches.size() >= MaxBatches) { + auto batch_id = dev->batches.Create(batch_info); + if (!batch_id.has_value()) { LOG_ERROR(Lib_Ajm, "Too many batches in job!"); return ORBIS_AJM_ERROR_OUT_OF_MEMORY; } + batch_info->id = batch_id.value(); + *out_batch_id = batch_id.value(); - *out_batch_id = static_cast(dev->batches.size()); - dev->batches.push_back(batch_info); - - const u8* p_current = batch; - const u8* p_batch_end = batch + batch_size; + u8* p_current = p_batch; + u8* const p_batch_end = p_current + batch_size; while (p_current < p_batch_end) { auto& header = AjmBufferExtract(p_current); ASSERT(header.ident == AjmIdentJob); + batch_info->jobs.push_back(AjmJob{}); + auto& job = batch_info->jobs.back(); + job.instance_id = header.payload; + std::optional job_flags = {}; std::optional input_control_buffer = {}; std::optional output_control_buffer = {}; + std::optional inline_buffer = {}; boost::container::small_vector input_run_buffers; boost::container::small_vector output_run_buffers; // Read parameters of a job auto* const p_job_end = p_current + header.size; while (p_current < p_job_end) { - auto& header = AjmBufferPeek(p_current); + auto& header = AjmBufferPeek(p_current); switch (header.ident) { case Identifier::AjmIdentInputRunBuf: { - input_run_buffers.emplace_back(AjmBufferExtract(p_current)); + auto& buffer = AjmBufferExtract(p_current); + u8* p_begin = reinterpret_cast(buffer.p_address); + job.input.buffers.emplace_back( + std::vector(p_begin, p_begin + buffer.header.size)); break; } case Identifier::AjmIdentInputControlBuf: { ASSERT_MSG(!input_control_buffer.has_value(), "Only one instance of input control buffer is allowed per job"); - input_control_buffer = AjmBufferExtract(p_current); + input_control_buffer = AjmBufferExtract(p_current); break; } case Identifier::AjmIdentControlFlags: case Identifier::AjmIdentRunFlags: { ASSERT_MSG(!job_flags.has_value(), "Only one instance of job flags is allowed per job"); - auto& flags_chunk = AjmBufferExtract(p_current); + auto& flags_chunk = AjmBufferExtract(p_current); job_flags = AjmJobFlags{ .raw = (u64(flags_chunk.payload) << 32) + flags_chunk.size, }; @@ -350,129 +493,113 @@ int PS4_SYSV_ABI sceAjmBatchStartBuffer(u32 context, const u8* batch, u32 batch_ } case Identifier::AjmIdentReturnAddressBuf: { // Ignore return address buffers. - AjmBufferSkip(p_current); + AjmBufferSkip(p_current); + break; + } + case Identifier::AjmIdentInlineBuf: { + ASSERT_MSG(!output_control_buffer.has_value(), + "Only one instance of inline buffer is allowed per job"); + inline_buffer = AjmBufferExtract(p_current); break; } case Identifier::AjmIdentOutputRunBuf: { - output_run_buffers.emplace_back(AjmBufferExtract(p_current)); + auto& buffer = AjmBufferExtract(p_current); + u8* p_begin = reinterpret_cast(buffer.p_address); + job.output.buffers.emplace_back( + std::span(p_begin, p_begin + buffer.header.size)); break; } case Identifier::AjmIdentOutputControlBuf: { ASSERT_MSG(!output_control_buffer.has_value(), "Only one instance of output control buffer is allowed per job"); - output_control_buffer = AjmBufferExtract(p_current); + output_control_buffer = AjmBufferExtract(p_current); break; } default: - LOG_ERROR(Lib_Ajm, "Unknown chunk: {}", header.ident); - p_current += header.size; - break; + UNREACHABLE_MSG("Unknown chunk: {}", header.ident); } } - const u32 instance = header.payload; - AjmInstance* p_instance = dev->instances[instance].get(); + job.flags = job_flags.value(); // Perform operation requested by control flags. const auto control_flags = job_flags.value().control_flags; - if (True(control_flags & AjmJobControlFlags::Reset)) { - LOG_INFO(Lib_Ajm, "Resetting instance {}", instance); - p_instance->Reset(); - } if (True(control_flags & AjmJobControlFlags::Initialize)) { - LOG_INFO(Lib_Ajm, "Initializing instance {}", instance); ASSERT_MSG(input_control_buffer.has_value(), "Initialize called without control buffer"); const auto& in_buffer = input_control_buffer.value(); - p_instance->Initialize(in_buffer.p_address, in_buffer.header.size); + job.input.init_params = AjmDecAt9InitializeParameters{}; + std::memcpy(&job.input.init_params.value(), in_buffer.p_address, in_buffer.header.size); } if (True(control_flags & AjmJobControlFlags::Resample)) { - LOG_ERROR(Lib_Ajm, "Unimplemented: Set resample params of instance {}", instance); + ASSERT_MSG(inline_buffer.has_value(), + "Resample paramters are stored in the inline buffer"); + auto* p_buffer = reinterpret_cast(inline_buffer.value().p_address); + job.input.resample_parameters = + AjmBufferExtract(p_buffer); } - AjmSidebandResult* p_result = nullptr; - AjmSidebandStream* p_stream = nullptr; - AjmSidebandFormat* p_format = nullptr; - AjmSidebandGaplessDecode* p_gapless_decode = nullptr; - AjmSidebandMFrame* p_mframe = nullptr; - u8* p_codec_info = nullptr; + // Initialize sideband input parameters + if (input_control_buffer.has_value()) { + auto* p_sideband = reinterpret_cast(input_control_buffer.value().p_address); + auto* const p_end = p_sideband + input_control_buffer.value().header.size; - // Initialize sideband structures. + const auto sideband_flags = job_flags.value().sideband_flags; + if (True(sideband_flags & AjmJobSidebandFlags::Format) && p_sideband < p_end) { + job.input.format = AjmBufferExtract(p_sideband); + } + if (True(sideband_flags & AjmJobSidebandFlags::GaplessDecode) && p_sideband < p_end) { + job.input.gapless_decode = AjmBufferExtract(p_sideband); + } + + ASSERT_MSG(p_sideband <= p_end, "Input sideband out of bounds"); + } + + // Initialize sideband output parameters if (output_control_buffer.has_value()) { auto* p_sideband = reinterpret_cast(output_control_buffer.value().p_address); - p_result = &AjmBufferExtract(p_sideband); - *p_result = AjmSidebandResult{}; + auto* const p_end = p_sideband + output_control_buffer.value().header.size; + job.output.p_result = &AjmBufferExtract(p_sideband); + *job.output.p_result = AjmSidebandResult{}; const auto sideband_flags = job_flags.value().sideband_flags; - if (True(sideband_flags & AjmJobSidebandFlags::Stream)) { - p_stream = &AjmBufferExtract(p_sideband); - *p_stream = AjmSidebandStream{}; + if (True(sideband_flags & AjmJobSidebandFlags::Stream) && p_sideband < p_end) { + job.output.p_stream = &AjmBufferExtract(p_sideband); + *job.output.p_stream = AjmSidebandStream{}; } - if (True(sideband_flags & AjmJobSidebandFlags::Format)) { + if (True(sideband_flags & AjmJobSidebandFlags::Format) && p_sideband < p_end) { LOG_ERROR(Lib_Ajm, "SIDEBAND_FORMAT is not implemented"); - p_format = &AjmBufferExtract(p_sideband); - *p_format = AjmSidebandFormat{}; + job.output.p_format = &AjmBufferExtract(p_sideband); + *job.output.p_format = AjmSidebandFormat{}; } - if (True(sideband_flags & AjmJobSidebandFlags::GaplessDecode)) { - LOG_ERROR(Lib_Ajm, "SIDEBAND_GAPLESS_DECODE is not implemented"); - p_gapless_decode = &AjmBufferExtract(p_sideband); - if (input_control_buffer) { - memcpy(&p_instance->gapless, input_control_buffer->p_address, - sizeof(AjmSidebandGaplessDecode)); - LOG_INFO(Lib_Ajm, - "Setting gapless params instance = {}, total_samples = {}, " - "skip_samples = {}", - instance, p_instance->gapless.total_samples, - p_instance->gapless.skip_samples); - } else { - LOG_ERROR(Lib_Ajm, "Requesting gapless structure!"); - } - *p_gapless_decode = AjmSidebandGaplessDecode{}; - } - const auto run_flags = job_flags.value().run_flags; - if (True(run_flags & AjmJobRunFlags::MultipleFrames)) { - p_mframe = &AjmBufferExtract(p_sideband); - *p_mframe = AjmSidebandMFrame{}; - } - if (True(run_flags & AjmJobRunFlags::GetCodecInfo)) { - p_codec_info = p_sideband; - p_sideband += p_instance->GetCodecInfoSize(); + if (True(sideband_flags & AjmJobSidebandFlags::GaplessDecode) && p_sideband < p_end) { + job.output.p_gapless_decode = + &AjmBufferExtract(p_sideband); + *job.output.p_gapless_decode = AjmSidebandGaplessDecode{}; } - } - - // Perform operation requested by run flags. - ASSERT_MSG(input_run_buffers.size() == output_run_buffers.size(), - "Run operation with uneven input/output buffers."); - - for (size_t i = 0; i < input_run_buffers.size(); ++i) { - // Decode as much of the input bitstream as possible. - const auto& in_buffer = input_run_buffers[i]; - const auto& out_buffer = output_run_buffers[i]; - const u8* in_address = reinterpret_cast(in_buffer.p_address); - u8* out_address = reinterpret_cast(out_buffer.p_address); - const auto [in_remain, out_remain, num_frames] = p_instance->Decode(in_address, in_buffer.header.size, - out_address, out_buffer.header.size); - - if (p_stream != nullptr) { - p_stream->input_consumed += in_buffer.header.size - in_remain; - p_stream->output_written += out_buffer.header.size - out_remain; - p_stream->total_decoded_samples += p_instance->decoded_samples; + const auto run_flags = job_flags.value().run_flags; + if (True(run_flags & AjmJobRunFlags::MultipleFrames) && p_sideband < p_end) { + job.output.p_mframe = &AjmBufferExtract(p_sideband); + *job.output.p_mframe = AjmSidebandMFrame{}; } - if (p_mframe != nullptr) { - p_mframe->num_frames += num_frames; + if (True(run_flags & AjmJobRunFlags::GetCodecInfo) && p_sideband < p_end) { + job.output.p_codec_info = p_sideband; } - } - if (p_codec_info) { - p_instance->GetCodecInfo(p_codec_info); + ASSERT_MSG(p_sideband <= p_end, "Output sideband out of bounds"); } + } - p_result->result = 0; - p_result->internal_result = 0; + { + std::lock_guard lock(dev->batch_queue_mutex); + dev->batch_queue.push(batch_info); } - batch_info->finished.release(); + { + std::unique_lock lock(dev->batch_queue_cv_mutex); + dev->batch_queue_cv.notify_all(); + } return ORBIS_OK; } @@ -482,27 +609,44 @@ int PS4_SYSV_ABI sceAjmBatchWait(const u32 context, const u32 batch_id, const u3 LOG_INFO(Lib_Ajm, "called context = {}, batch_id = {}, timeout = {}", context, batch_id, timeout); - if (batch_id > 0xFF || batch_id >= dev->batches.size()) { - return ORBIS_AJM_ERROR_INVALID_BATCH; - } + std::shared_ptr batch{}; + { + std::lock_guard guard(dev->batches_mutex); + const auto opt_batch = dev->batches.Get(batch_id); + if (!opt_batch.has_value()) { + LOG_INFO(Lib_Ajm, "ORBIS_AJM_ERROR_INVALID_BATCH"); + return ORBIS_AJM_ERROR_INVALID_BATCH; + } - const auto& batch = dev->batches[batch_id]; + batch = opt_batch.value().get(); + } - if (batch->waiting) { + bool expected = false; + if (!batch->waiting.compare_exchange_strong(expected, true)) { + LOG_INFO(Lib_Ajm, "ORBIS_AJM_ERROR_BUSY"); return ORBIS_AJM_ERROR_BUSY; } - batch->waiting = true; - SCOPE_EXIT { + if (timeout == ORBIS_AJM_WAIT_INFINITE) { + batch->finished.acquire(); + } else if (!batch->finished.try_acquire_for(std::chrono::milliseconds(timeout))) { batch->waiting = false; - }; - - if (!batch->finished.try_acquire_for(std::chrono::milliseconds(timeout))) { + LOG_INFO(Lib_Ajm, "ORBIS_AJM_ERROR_IN_PROGRESS"); return ORBIS_AJM_ERROR_IN_PROGRESS; } - dev->batches.erase(dev->batches.begin() + batch_id); - return 0; + { + std::lock_guard guard(dev->batches_mutex); + dev->batches.Destroy(batch_id); + } + + if (batch->canceled) { + LOG_INFO(Lib_Ajm, "ORBIS_AJM_ERROR_CANCELLED"); + return ORBIS_AJM_ERROR_CANCELLED; + } + + LOG_INFO(Lib_Ajm, "ORBIS_OK"); + return ORBIS_OK; } int PS4_SYSV_ABI sceAjmDecAt9ParseConfigData() { @@ -573,6 +717,7 @@ int PS4_SYSV_ABI sceAjmInstanceCreate(u32 context, AjmCodecType codec_type, AjmI instance->index = index; instance->codec_type = codec_type; instance->num_channels = flags.channels; + instance->flags = flags; dev->instances[index] = std::move(instance); *out_instance = index; LOG_INFO(Lib_Ajm, "called codec_type = {}, flags = {:#x}, instance = {}", diff --git a/src/core/libraries/ajm/ajm.h b/src/core/libraries/ajm/ajm.h index f4f737ba6b..183d955926 100644 --- a/src/core/libraries/ajm/ajm.h +++ b/src/core/libraries/ajm/ajm.h @@ -7,6 +7,8 @@ #include "common/enum.h" #include "common/types.h" +#include "core/libraries/ajm/ajm_instance.h" + namespace Core::Loader { class SymbolsResolver; } @@ -87,21 +89,10 @@ union AjmJobFlags { }; }; -union AjmInstanceFlags { - u64 raw; - struct { - u64 version : 3; - u64 channels : 4; - u64 format : 3; - u64 pad : 22; - u64 codec : 28; - }; -}; - struct AjmDecMp3ParseFrame; enum class AjmCodecType : u32; -int PS4_SYSV_ABI sceAjmBatchCancel(); +int PS4_SYSV_ABI sceAjmBatchCancel(const u32 context_id, const u32 batch_id); int PS4_SYSV_ABI sceAjmBatchErrorDump(); void* PS4_SYSV_ABI sceAjmBatchJobControlBufferRa(void* p_buffer, u32 instance_id, u64 flags, void* p_sideband_input, size_t sideband_input_size, @@ -121,9 +112,8 @@ void* PS4_SYSV_ABI sceAjmBatchJobRunSplitBufferRa( size_t num_data_input_buffers, const AjmBuffer* p_data_output_buffers, size_t num_data_output_buffers, void* p_sideband_output, size_t sideband_output_size, void* p_return_address); -int PS4_SYSV_ABI sceAjmBatchStartBuffer(u32 context, const u8* batch, u32 batch_size, - const int priority, AjmBatchError* batch_error, - u32* out_batch_id); +int PS4_SYSV_ABI sceAjmBatchStartBuffer(u32 context, u8* batch, u32 batch_size, const int priority, + AjmBatchError* batch_error, u32* out_batch_id); int PS4_SYSV_ABI sceAjmBatchWait(const u32 context, const u32 batch_id, const u32 timeout, AjmBatchError* const batch_error); int PS4_SYSV_ABI sceAjmDecAt9ParseConfigData(); diff --git a/src/core/libraries/ajm/ajm_at9.cpp b/src/core/libraries/ajm/ajm_at9.cpp index ddffdc36b0..7e13c31e71 100644 --- a/src/core/libraries/ajm/ajm_at9.cpp +++ b/src/core/libraries/ajm/ajm_at9.cpp @@ -43,7 +43,7 @@ void AjmAt9Decoder::Initialize(const void* buffer, u32 buffer_size) { ASSERT_MSG(buffer_size == sizeof(AjmDecAt9InitializeParameters), "Incorrect At9 initialization buffer size {}", buffer_size); const auto params = reinterpret_cast(buffer); - std::memcpy(config_data, params->config_data, SCE_AT9_CONFIG_DATA_SIZE); + std::memcpy(config_data, params->config_data, ORBIS_AT9_CONFIG_DATA_SIZE); AjmAt9Decoder::Reset(); } @@ -58,8 +58,8 @@ void AjmAt9Decoder::GetCodecInfo(void* out_info) { codec_info->uiSuperFrameSize = decoder_codec_info.superframeSize; } -std::tuple AjmAt9Decoder::Decode(const u8* in_buf, u32 in_size, u8* out_buf, - u32 out_size) { +std::tuple AjmAt9Decoder::Decode(const u8* in_buf, u32 in_size_in, u8* out_buf, + u32 out_size_in, AjmJobOutput* output) { const auto decoder_handle = static_cast(handle); Atrac9CodecInfo codec_info; Atrac9GetCodecInfo(handle, &codec_info); @@ -67,8 +67,11 @@ std::tuple AjmAt9Decoder::Decode(const u8* in_buf, u32 in_size, u int bytes_used = 0; int num_superframes = 0; + u32 in_size = in_size_in; + u32 out_size = out_size_in; + const auto ShouldDecode = [&] { - if (in_size <= 0 || out_size <= 0) { + if (in_size == 0 || out_size == 0) { return false; } if (gapless.total_samples != 0 && gapless.total_samples < decoded_samples) { @@ -84,16 +87,26 @@ std::tuple AjmAt9Decoder::Decode(const u8* in_buf, u32 in_size, u ASSERT_MSG(ret == At9Status::ERR_SUCCESS, "Atrac9Decode failed ret = {:#x}", ret); in_buf += bytes_used; in_size -= bytes_used; + if (output->p_mframe) { + ++output->p_mframe->num_frames; + } num_frames++; bytes_remain -= bytes_used; - if (gapless.skip_samples != 0) { - gapless.skip_samples -= decoder_handle->Config.FrameSamples; + if (gapless.skipped_samples < gapless.skip_samples) { + gapless.skipped_samples += decoder_handle->Config.FrameSamples; + if (gapless.skipped_samples > gapless.skip_samples) { + const auto size = gapless.skipped_samples - gapless.skip_samples; + const auto start = decoder_handle->Config.FrameSamples - size; + memcpy(out_buf, pcm_buffer.data() + start, size * sizeof(s16)); + out_buf += size * sizeof(s16); + out_size -= size * sizeof(s16); + } } else { memcpy(out_buf, pcm_buffer.data(), written_size); out_buf += written_size; out_size -= written_size; - decoded_samples += decoder_handle->Config.FrameSamples; } + decoded_samples += decoder_handle->Config.FrameSamples; if ((num_frames % codec_info.framesInSuperframe) == 0) { in_buf += bytes_remain; in_size -= bytes_remain; @@ -102,8 +115,15 @@ std::tuple AjmAt9Decoder::Decode(const u8* in_buf, u32 in_size, u } } - LOG_TRACE(Lib_Ajm, "Decoded {} samples, frame count: {}", decoded_samples, frame_index); - return std::tuple(in_size, out_size, num_superframes); + if (gapless.total_samples == decoded_samples) { + decoded_samples = 0; + if (flags.gapless_loop) { + gapless.skipped_samples = 0; + } + } + + LOG_TRACE(Lib_Ajm, "Decoded {} samples, frame count: {}", decoded_samples, num_frames); + return std::tuple(in_size, out_size); } } // namespace Libraries::Ajm diff --git a/src/core/libraries/ajm/ajm_at9.h b/src/core/libraries/ajm/ajm_at9.h index 905858ae37..d4e268eea0 100644 --- a/src/core/libraries/ajm/ajm_at9.h +++ b/src/core/libraries/ajm/ajm_at9.h @@ -15,13 +15,7 @@ extern "C" { namespace Libraries::Ajm { -constexpr u32 SCE_AT9_CONFIG_DATA_SIZE = 4; -constexpr s32 SCE_AJM_DEC_AT9_MAX_CHANNELS = 8; - -struct AjmDecAt9InitializeParameters { - u8 config_data[SCE_AT9_CONFIG_DATA_SIZE]; - u32 reserved; -}; +constexpr s32 ORBIS_AJM_DEC_AT9_MAX_CHANNELS = 8; struct AjmSidebandDecAt9CodecInfo { u32 uiSuperFrameSize; @@ -35,7 +29,7 @@ struct AjmAt9Decoder final : AjmInstance { bool decoder_initialized = false; std::fstream file; int length; - u8 config_data[SCE_AT9_CONFIG_DATA_SIZE]; + u8 config_data[ORBIS_AT9_CONFIG_DATA_SIZE]; explicit AjmAt9Decoder(); ~AjmAt9Decoder() override; @@ -49,8 +43,8 @@ struct AjmAt9Decoder final : AjmInstance { return sizeof(AjmSidebandDecAt9CodecInfo); } - std::tuple Decode(const u8* in_buf, u32 in_size, u8* out_buf, - u32 out_size) override; + std::tuple Decode(const u8* in_buf, u32 in_size, u8* out_buf, u32 out_size, + AjmJobOutput* output) override; }; } // namespace Libraries::Ajm diff --git a/src/core/libraries/ajm/ajm_instance.h b/src/core/libraries/ajm/ajm_instance.h index de501f0592..a14eb5ac7a 100644 --- a/src/core/libraries/ajm/ajm_instance.h +++ b/src/core/libraries/ajm/ajm_instance.h @@ -6,6 +6,12 @@ #include "common/enum.h" #include "common/types.h" +#include + +#include +#include +#include + extern "C" { struct AVCodec; struct AVCodecContext; @@ -14,6 +20,21 @@ struct AVCodecParserContext; namespace Libraries::Ajm { +constexpr int ORBIS_AJM_RESULT_NOT_INITIALIZED = 0x00000001; +constexpr int ORBIS_AJM_RESULT_INVALID_DATA = 0x00000002; +constexpr int ORBIS_AJM_RESULT_INVALID_PARAMETER = 0x00000004; +constexpr int ORBIS_AJM_RESULT_PARTIAL_INPUT = 0x00000008; +constexpr int ORBIS_AJM_RESULT_NOT_ENOUGH_ROOM = 0x00000010; +constexpr int ORBIS_AJM_RESULT_STREAM_CHANGE = 0x00000020; +constexpr int ORBIS_AJM_RESULT_TOO_MANY_CHANNELS = 0x00000040; +constexpr int ORBIS_AJM_RESULT_UNSUPPORTED_FLAG = 0x00000080; +constexpr int ORBIS_AJM_RESULT_SIDEBAND_TRUNCATED = 0x00000100; +constexpr int ORBIS_AJM_RESULT_PRIORITY_PASSED = 0x00000200; +constexpr int ORBIS_AJM_RESULT_CODEC_ERROR = 0x40000000; +constexpr int ORBIS_AJM_RESULT_FATAL = 0x80000000; + +constexpr u32 ORBIS_AT9_CONFIG_DATA_SIZE = 4; + enum class AjmCodecType : u32 { Mp3Dec = 0, At9Dec = 1, @@ -60,15 +81,63 @@ struct AjmSidebandGaplessDecode { u16 skipped_samples; }; +struct AjmSidebandResampleParameters { + float ratio; + uint32_t flags; +}; + +struct AjmDecAt9InitializeParameters { + u8 config_data[ORBIS_AT9_CONFIG_DATA_SIZE]; + u32 reserved; +}; + +union AjmSidebandInitParameters { + AjmDecAt9InitializeParameters at9; + u8 reserved[8]; +}; + +struct AjmJobInput { + std::optional init_params; + std::optional resample_parameters; + std::optional format; + std::optional gapless_decode; + boost::container::small_vector, 4> buffers; +}; + +struct AjmJobOutput { + boost::container::small_vector, 4> buffers; + AjmSidebandResult* p_result = nullptr; + AjmSidebandStream* p_stream = nullptr; + AjmSidebandFormat* p_format = nullptr; + AjmSidebandGaplessDecode* p_gapless_decode = nullptr; + AjmSidebandMFrame* p_mframe = nullptr; + u8* p_codec_info = nullptr; +}; + +union AjmInstanceFlags { + u64 raw; + struct { + u64 version : 3; + u64 channels : 4; + u64 format : 3; + u64 gapless_loop : 1; + u64 pad : 21; + u64 codec : 28; + }; +}; + struct AjmInstance { AjmCodecType codec_type; AjmFormatEncoding fmt{}; + AjmInstanceFlags flags{.raw = 0}; u32 num_channels{}; u32 index{}; u32 bytes_remain{}; u32 num_frames{}; u32 decoded_samples{}; + AjmSidebandFormat format{}; AjmSidebandGaplessDecode gapless{}; + AjmSidebandResampleParameters resample_parameters{}; explicit AjmInstance() = default; virtual ~AjmInstance() = default; @@ -80,8 +149,8 @@ struct AjmInstance { virtual void GetCodecInfo(void* out_info) = 0; virtual u32 GetCodecInfoSize() = 0; - virtual std::tuple Decode(const u8* in_buf, u32 in_size, u8* out_buf, - u32 out_size) = 0; + virtual std::tuple Decode(const u8* in_buf, u32 in_size, u8* out_buf, u32 out_size, + AjmJobOutput* output) = 0; }; } // namespace Libraries::Ajm diff --git a/src/core/libraries/ajm/ajm_mp3.cpp b/src/core/libraries/ajm/ajm_mp3.cpp index bed3569d4a..acdf3d3f07 100644 --- a/src/core/libraries/ajm/ajm_mp3.cpp +++ b/src/core/libraries/ajm/ajm_mp3.cpp @@ -77,8 +77,8 @@ void AjmMp3Decoder::Reset() { num_frames = 0; } -std::tuple AjmMp3Decoder::Decode(const u8* buf, u32 in_size, u8* out_buf, - u32 out_size) { +std::tuple AjmMp3Decoder::Decode(const u8* buf, u32 in_size, u8* out_buf, u32 out_size, + AjmJobOutput* output) { AVPacket* pkt = av_packet_alloc(); while (in_size > 0 && out_size > 0) { int ret = av_parser_parse2(parser, c, &pkt->data, &pkt->size, buf, in_size, AV_NOPTS_VALUE, @@ -119,7 +119,10 @@ std::tuple AjmMp3Decoder::Decode(const u8* buf, u32 in_size, u8* } } av_packet_free(&pkt); - return std::make_tuple(in_size, out_size, num_frames); + if (output->p_mframe) { + output->p_mframe->num_frames += num_frames; + } + return std::make_tuple(in_size, out_size); } int AjmMp3Decoder::ParseMp3Header(const u8* buf, u32 stream_size, int parse_ofl, diff --git a/src/core/libraries/ajm/ajm_mp3.h b/src/core/libraries/ajm/ajm_mp3.h index 2d654bdae4..69ba25e197 100644 --- a/src/core/libraries/ajm/ajm_mp3.h +++ b/src/core/libraries/ajm/ajm_mp3.h @@ -74,8 +74,8 @@ struct AjmMp3Decoder : public AjmInstance { return sizeof(AjmSidebandDecMp3CodecInfo); } - std::tuple Decode(const u8* in_buf, u32 in_size, u8* out_buf, - u32 out_size) override; + std::tuple Decode(const u8* in_buf, u32 in_size, u8* out_buf, u32 out_size, + AjmJobOutput* output) override; static int ParseMp3Header(const u8* buf, u32 stream_size, int parse_ofl, AjmDecMp3ParseFrame* frame);