diff --git a/.gitmodules b/.gitmodules
index ced5dcf94..ec484eb61 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -17,3 +17,7 @@
 [submodule "third_party/json"]
     path = third_party/json
     url = https://github.com/nlohmann/json
+
+[submodule "third_party/dlpack"]
+    path = third_party/dlpack
+    url = https://github.com/dmlc/dlpack
diff --git a/.vscode/c_cpp_properties.json b/.vscode/c_cpp_properties.json
index b4d0e7494..ac86a796e 100644
--- a/.vscode/c_cpp_properties.json
+++ b/.vscode/c_cpp_properties.json
@@ -4,6 +4,7 @@
             "name": "Linux",
             "includePath": [
                 "${workspaceFolder}/**",
+                "${workspaceFolder}/third_party/mscclpp/include",
                 "/usr/local/cuda/include",
                 "/opt/rocm/include"
             ],
diff --git a/ark/CMakeLists.txt b/ark/CMakeLists.txt
index 208d9f9cb..9616ea875 100644
--- a/ark/CMakeLists.txt
+++ b/ark/CMakeLists.txt
@@ -17,6 +17,7 @@ set(COMMON_LIBS ARK::numa ARK::ibverbs pthread rt)
 target_include_directories(ark_obj PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include)
 target_include_directories(ark_obj PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
 target_include_directories(ark_obj SYSTEM PRIVATE
+    ${DLPACK_INCLUDE_DIRS}
     ${JSON_INCLUDE_DIRS}
     ${MSCCLPP_INCLUDE_DIRS}
     ${IBVERBS_INCLUDE_DIRS}
diff --git a/ark/api/context.cpp b/ark/api/context.cpp
index 76baedc87..702247ddf 100644
--- a/ark/api/context.cpp
+++ b/ark/api/context.cpp
@@ -29,4 +29,8 @@ void Context::set(const std::string& key, const std::string& value,
     this->impl_->set(key, value_json, type);
 }
 
+std::string Context::dump() const {
+    return this->impl_->dump().dump();
+}
+
 }  // namespace ark
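Note: the `Context::dump()` accessor added above serializes the context records held by `Context::Impl` into a JSON string, which is useful for inspecting planner contexts. A minimal usage sketch, assuming `dump()` is publicly exposed and that `PlannerContext` (see planner.cpp further below) inherits it from `Context`; the JSON layout in the comment is illustrative, not guaranteed by the API:

    ark::Model model;
    ark::PlannerContext ctx(model);
    ctx.processor_range(0, 4);
    // e.g. {"Id":0,"Sync":[[0,true]],"ProcessorRange":[[0,[0,4]]]}
    std::string snapshot = ctx.dump();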
diff --git a/ark/api/executor.cpp b/ark/api/executor.cpp
index c8e2e7df6..af1789dc1 100644
--- a/ark/api/executor.cpp
+++ b/ark/api/executor.cpp
@@ -4,13 +4,17 @@
 #include "ark/executor.hpp"
 
 #include <cmath>
+#include <cstring>
 #include <list>
 #include <memory>
 #include <mscclpp/core.hpp>
 #include <mscclpp/proxy_channel.hpp>
+#include <mscclpp/sm_channel.hpp>
 
+#include "ark/data_type.hpp"
 #include "ark/model.hpp"
 #include "ark/planner.hpp"
+#include "buffer_registry.hpp"
 #include "codegen.hpp"
 #include "env.h"
 #include "file_io.h"
@@ -138,172 +142,339 @@ static size_t tensor_stride_bytes(const Json &tensor) {
     return nelems * DataType::from_name(tensor["DataType"]).bytes();
 }
 
-class Executor::Impl {
+class CommResource {
    public:
-    Impl(int device_id, Stream stream, const std::string &name, bool loop_mode);
-    ~Impl();
+    CommResource(int device_id, int rank, int world_size);
 
-    void init(const PlanJson& plan);
+    int rank() const { return rank_; }
 
-    int device_id() const { return device_id_; }
+    int world_size() const { return world_size_; }
 
-    Stream stream() const { return reinterpret_cast<Stream>(stream_raw_); }
+    std::shared_ptr<mscclpp::Bootstrap> bootstrap() {
+        return comm_->bootstrap();
+    }
 
-    std::string plan() const { return plan_json_.dump_pretty(); }
+    std::shared_ptr<mscclpp::Communicator> comm() { return comm_; }
 
-    void compile();
-    void launch();
-    void run(int iter);
-    void wait(int64_t max_spin_count);
-    float stop(int64_t max_spin_count);
-    void barrier();
+    std::shared_ptr<mscclpp::ProxyService> proxy_service() {
+        return proxy_service_;
+    }
 
-    uintptr_t tensor_address(const Tensor &tensor) const;
+    struct ConnectionResource {
+        std::shared_ptr<mscclpp::Connection> connection;
+        std::vector<std::shared_ptr<mscclpp::SimpleProxyChannel>>
+            proxy_channels;
+        std::vector<std::shared_ptr<mscclpp::SmChannel>> sm_channels;
+    };
 
-    void tensor_read(const Tensor &tensor, void *data, size_t bytes,
-                     Stream stream, bool is_d2d) const;
-    void tensor_write(const Tensor &tensor, const void *data, size_t bytes,
-                      Stream stream, bool is_d2d) const;
+    struct RankResource {
+        int remote_rank;
+        std::shared_ptr<ConnectionResource> ipc;
+        std::shared_ptr<ConnectionResource> eth;
+        std::shared_ptr<ConnectionResource> ib;
+    };
+
+    const std::shared_ptr<RankResource> resource(int rank) const {
+        auto it = rank_to_resource_.find(rank);
+        if (it == rank_to_resource_.end()) {
+            return nullptr;
+        }
+        return it->second;
+    }
+
+    void connect(const PlanJson &plan_json, std::shared_ptr<GpuMemory> buffer);
 
    private:
-    void init_communicator();
-    std::map<size_t, size_t> init_buffers(const Json &plan_json);
-    std::set<int> init_remote_ranks(const Json &plan_json) const;
-    void init_channels(const std::set<int> &remote_ranks);
+    int device_id_;
+    int rank_;
+    int world_size_;
+    std::shared_ptr<mscclpp::Communicator> comm_;
+    std::shared_ptr<mscclpp::ProxyService> proxy_service_;
+    std::map<int, std::shared_ptr<RankResource>> rank_to_resource_;
+};
 
-   protected:
+CommResource::CommResource(int device_id, int rank, int world_size)
+    : device_id_(device_id), rank_(rank), world_size_(world_size) {
+    auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(rank, world_size);
+    std::stringstream ip_port;
+    ip_port << get_host(0) << ":" << get_env().mscclpp_port;
+    bootstrap->initialize(ip_port.str());
+    comm_ = std::make_shared<mscclpp::Communicator>(bootstrap);
+    proxy_service_ = std::make_shared<mscclpp::ProxyService>();
+}
+
+void CommResource::connect(const PlanJson &plan_json,
+                           std::shared_ptr<GpuMemory> buffer) {
+    int rank = plan_json["Rank"];
+    std::set<int> remote_ranks;
+    for (auto &task_info : plan_json["TaskInfos"]) {
+        for (auto &op : task_info["Ops"]) {
+            for (auto &tns : op["ReadTensors"]) {
+                auto buffer = ModelBuffer::deserialize(tns["Buffer"]);
+                if (buffer->rank() != rank && buffer->rank() != -1) {
+                    remote_ranks.insert(buffer->rank());
+                }
+            }
+            for (auto &tns : op["WriteTensors"]) {
+                auto buffer = ModelBuffer::deserialize(tns["Buffer"]);
+                if (buffer->rank() != rank && buffer->rank() != -1) {
+                    remote_ranks.insert(buffer->rank());
+                }
+            }
+            for (auto &tns : op["ResultTensors"]) {
+                auto buffer = ModelBuffer::deserialize(tns["Buffer"]);
+                if (buffer->rank() != rank && buffer->rank() != -1) {
+                    remote_ranks.insert(buffer->rank());
+                }
+            }
+        }
+    }
+    if (remote_ranks.empty()) return;
+
+    int num_ranks_per_node = get_env().num_ranks_per_host;
+    auto rank_to_node = [&](int r) { return r / num_ranks_per_node; };
+    int this_node = rank_to_node(rank);
+
+    const mscclpp::Transport IBs[] = {
+        mscclpp::Transport::IB0, mscclpp::Transport::IB1,
+        mscclpp::Transport::IB2, mscclpp::Transport::IB3,
+        mscclpp::Transport::IB4, mscclpp::Transport::IB5,
+        mscclpp::Transport::IB6, mscclpp::Transport::IB7};
+
+    mscclpp::TransportFlags all_transports =
+        mscclpp::Transport::CudaIpc | mscclpp::Transport::Ethernet;
+    if (!get_env().disable_ib) {
+        all_transports |= IBs[device_id_];
+    }
+    mscclpp::RegisteredMemory regmem =
+        comm_->registerMemory(buffer->ref(), buffer->bytes(), all_transports);
+
+    using ConnectionFuture =
+        mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>;
+    std::map<int, ConnectionFuture> rank_to_ipc_connection_future;
+    std::map<int, ConnectionFuture> rank_to_eth_connection_future;
+    std::map<int, ConnectionFuture> rank_to_ib_connection_future;
+    std::map<int, mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>>
+        rank_to_remote_regmem_future;
+
+    for (auto remote_rank : remote_ranks) {
+        auto it = rank_to_resource_.find(remote_rank);
+        if (it != rank_to_resource_.end()) {
+            // connection already set
+            continue;
+        }
+        auto resource = std::make_shared<RankResource>();
+        rank_to_resource_[remote_rank] = resource;
+        int remote_node = rank_to_node(remote_rank);
+        if (remote_node == this_node) {
+            rank_to_ipc_connection_future[remote_rank] = comm_->connectOnSetup(
+                remote_rank, 0, mscclpp::Transport::CudaIpc);
+            resource->ipc = std::make_shared<ConnectionResource>();
+        }
+        if ((remote_node != this_node) && get_env().disable_ib) {
+            rank_to_eth_connection_future[remote_rank] = comm_->connectOnSetup(
+                remote_rank, 0, mscclpp::Transport::Ethernet);
+            resource->eth = std::make_shared<ConnectionResource>();
+        }
+        if (!get_env().disable_ib) {
+            rank_to_ib_connection_future[remote_rank] =
+                comm_->connectOnSetup(remote_rank, 0, IBs[device_id_]);
+            resource->ib = std::make_shared<ConnectionResource>();
+        }
+        comm_->sendMemoryOnSetup(regmem, remote_rank, 0);
+        rank_to_remote_regmem_future[remote_rank] =
+            comm_->recvMemoryOnSetup(remote_rank, 0);
+    }
+    comm_->setup();
+
+    for (auto &[remote_rank, future] : rank_to_ipc_connection_future) {
+        rank_to_resource_[remote_rank]->ipc->connection = future.get();
+    }
+    for (auto &[remote_rank, future] : rank_to_eth_connection_future) {
+        rank_to_resource_[remote_rank]->eth->connection = future.get();
+    }
+    for (auto &[remote_rank, future] : rank_to_ib_connection_future) {
+        rank_to_resource_[remote_rank]->ib->connection = future.get();
+    }
+
+    mscclpp::MemoryId regmem_id = proxy_service_->addMemory(regmem);
+    std::map<int, mscclpp::RegisteredMemory> rank_to_remote_regmem;
+    std::map<int, mscclpp::MemoryId> rank_to_remote_regmem_id;
+    for (auto &[remote_rank, future] : rank_to_remote_regmem_future) {
+        rank_to_remote_regmem[remote_rank] = future.get();
+        rank_to_remote_regmem_id[remote_rank] =
+            proxy_service_->addMemory(rank_to_remote_regmem[remote_rank]);
+    }
+
+    for (auto &[remote_rank, resource] : rank_to_resource_) {
+        auto remote_regmem_id = rank_to_remote_regmem_id[remote_rank];
+        auto add_proxy_channel =
+            [&](std::shared_ptr<ConnectionResource> conn_resource) {
+                if (!conn_resource) return;
+                conn_resource->proxy_channels.push_back(
+                    std::make_shared<mscclpp::SimpleProxyChannel>(
+                        proxy_service_->proxyChannel(
+                            proxy_service_->buildAndAddSemaphore(
+                                *comm_, conn_resource->connection)),
+                        remote_regmem_id, regmem_id));
+            };
+        // NOTE: We can create multiple proxy channels here if needed in the
+        // future
+        add_proxy_channel(resource->ipc);
+        add_proxy_channel(resource->eth);
+        add_proxy_channel(resource->ib);
+    }
+    comm_->setup();
+
+    std::map<int, std::vector<std::shared_ptr<mscclpp::SmDevice2DeviceSemaphore>>>
+        sm_semaphores;
+    for (auto &[remote_rank, resource] : rank_to_resource_) {
+        // Only ranks reachable via CUDA IPC get SM channels; skip others.
+        if (!resource->ipc) continue;
+        // NOTE: We can create multiple semaphores here if needed in the future
+        sm_semaphores[remote_rank].push_back(
+            std::make_shared<mscclpp::SmDevice2DeviceSemaphore>(
+                *comm_, resource->ipc->connection));
+    }
+    comm_->setup();
+
+    for (auto &[remote_rank, resource] : rank_to_resource_) {
+        if (!resource->ipc) continue;
+        // NOTE: We can create multiple sm channels here if needed in the
+        // future
+        resource->ipc->sm_channels.push_back(
+            std::make_shared<mscclpp::SmChannel>(
+                sm_semaphores[remote_rank][0],
+                rank_to_remote_regmem[remote_rank], regmem.data(), nullptr));
+    }
+}
+
+class PlanResourceKey {
+   public:
+    PlanResourceKey(const std::string &plan, int device_id,
+                    const std::string &name)
+        : plan_(plan), device_id_(device_id), name_(name) {}
+
+    bool operator<(const PlanResourceKey &other) const {
+        return std::tie(plan_, device_id_, name_) <
+               std::tie(other.plan_, other.device_id_, other.name_);
+    }
+
+   private:
+    std::string plan_;
     int device_id_;
     std::string name_;
-    bool loop_mode_;
+};
 
-    gpuStream stream_raw_;
+class PlanResource {
+   public:
+    PlanResource(const PlanJson &plan_json, int device_id,
+                 const std::string &name,
+                 std::shared_ptr<CommResource> &comm_resource);
 
-    int rank_;
-    int world_size_;
+    const PlanJson &plan_json() const { return plan_json_; }
 
-    bool is_launched_ = false;
-    bool is_recording_ = false;
-    float elapsed_msec_ = -1;
+    int device_id() const { return device_id_; }
+
+    const std::string &name() const { return name_; }
+
+    std::shared_ptr<GpuMemory> buffer() const { return buffer_; }
+
+    void launch_kernel(const std::string &name,
+                       const std::vector<void *> &args, gpuStream stream);
+
+   private:
+    void verify_plan();
+    void init_comm_resource();
+    void init_internal_buffers();
+    void init_comm_connections();
+    void init_kernel();
 
     PlanJson plan_json_;
-    std::map<size_t, size_t> buffer_id_to_offset_;
-    size_t total_bytes_;
-    std::shared_ptr<CodeGenerator> codegen_;
-    std::shared_ptr<GpuEvent> timer_begin_;
-    std::shared_ptr<GpuEvent> timer_end_;
+    int device_id_;
+    std::string name_;
+    std::shared_ptr<CommResource> &comm_resource_;
+
+    int rank_;
+    int world_size_;
     std::shared_ptr<GpuMemory> buffer_;
-    std::shared_ptr<GpuHostMemory> flag_;
-    std::shared_ptr<GpuStream> stream_;
+    std::map<size_t, size_t> internal_buffer_id_to_offset_;
+    // extra buffers: external buffers or buffers that are allocated by other
+    // plans
+    std::set<size_t> extra_buffer_ids_;
     std::shared_ptr<GpuKernel> kernel_;
-
-    // For communication
-    std::shared_ptr<mscclpp::Communicator> comm_;
-    std::shared_ptr<mscclpp::ProxyService> proxy_service_;
-    std::map<int, std::vector<std::shared_ptr<mscclpp::SimpleProxyChannel>>>
-        rank_to_proxy_channels_;
-    std::map<int, std::vector<std::shared_ptr<mscclpp::SmChannel>>>
-        rank_to_sm_channels_;
 };
 
-Executor::Impl::Impl(int device_id, Stream stream, const std::string &name,
-                     bool loop_mode)
-    : device_id_(device_id), name_(name), loop_mode_(loop_mode) {
+PlanResource::PlanResource(const PlanJson &plan_json, int device_id,
+                           const std::string &name,
+                           std::shared_ptr<CommResource> &comm_resource)
+    : plan_json_(plan_json),
+      device_id_(device_id),
+      name_(name),
+      comm_resource_(comm_resource) {
     if (device_id < 0) {
         ERR(InvalidUsageError, "Invalid device ID ", device_id);
     }
-    if (stream) {
-        stream_raw_ = reinterpret_cast<gpuStream>(stream);
-    } else {
-        stream_ = GpuManager::get_instance(device_id_)->create_stream();
-        stream_raw_ = stream_->get();
-    }
-}
 
-Executor::Impl::~Impl() {
-    if (is_launched_) stop(-1);
-}
+    // Verify that `plan_json` describes a valid plan
+    verify_plan();
 
-void Executor::Impl::init(const PlanJson &plan_json) {
-    plan_json_ = plan_json;
-    rank_ = plan_json_["Rank"].get<int>();
-    world_size_ = plan_json_["WorldSize"].get<int>();
+    // Construct `comm_resource_` if needed
+    init_comm_resource();
 
+    // Allocate memory for internal buffers and construct
+    // `internal_buffer_id_to_offset_` and `extra_buffer_ids_`.
+    init_internal_buffers();
+
+    // Create connections and channels to remote ranks
+    init_comm_connections();
+
+    // Construct `kernel_`.
+    init_kernel();
+}
+
+void PlanResource::verify_plan() {
+    rank_ = plan_json_["Rank"];
+    world_size_ = plan_json_["WorldSize"];
     if (rank_ < 0 || rank_ >= world_size_) {
         ERR(InvalidUsageError, "Invalid rank ", rank_, " with world size ",
             world_size_);
     }
-    if (world_size_ > 1) {
-        init_communicator();
-    }
-
     auto gpu_manager = GpuManager::get_instance(device_id_);
     if (!gpu_manager->info().arch->belongs_to(
-            Arch::from_name(plan_json.at("Architecture")))) {
+            Arch::from_name(plan_json_.at("Architecture")))) {
         LOG(WARN, "Architecture name of the plan `",
-            plan_json.at("Architecture").get<std::string>(),
+            plan_json_.at("Architecture").get<std::string>(),
             "` is not compatible with the GPU architecture `",
             gpu_manager->info().arch->name(), "`.");
     }
-
-    buffer_id_to_offset_ = init_buffers(plan_json_);
-
-    std::string buffer_id_to_offset_str;
-    for (const auto &kv : buffer_id_to_offset_) {
-        buffer_id_to_offset_str +=
-            std::to_string(kv.first) + ": " + std::to_string(kv.second) + ", ";
-    }
-
-    codegen_ = std::make_shared<CodeGenerator>(plan_json_,
-                                               buffer_id_to_offset_, name_);
-
-    timer_begin_ = gpu_manager->create_event();
-    timer_end_ = gpu_manager->create_event();
-    buffer_ = gpu_manager->malloc(total_bytes_, 65536);
-    flag_ = gpu_manager->malloc_host(
-        sizeof(int), gpuHostAllocMapped | gpuHostAllocWriteCombined);
-
-    int threads_per_block = static_cast<int>(
-        codegen_->num_warps_per_proc() * gpu_manager->info().threads_per_warp);
-    int num_sm = static_cast<int>(codegen_->num_procs());
-    size_t smem_block_total =
-        static_cast<size_t>(gpu_manager->info().smem_block_total);
-
-    if (world_size_ > 1) {
-        auto remote_ranks = init_remote_ranks(plan_json_);
-        init_channels(remote_ranks);
-    }
-
-    std::string kernel_name;
-    if (loop_mode_) {
-        kernel_name = "ark_loop_kernel";
-    } else {
-        kernel_name = "ark_kernel";
-    }
-    if (!name_.empty()) {
-        kernel_name += "_" + name_;
-    }
-
-    kernel_ = std::shared_ptr<GpuKernel>(new GpuKernel(
-        device_id_, codegen_->code(), {threads_per_block, 1, 1}, {num_sm, 1, 1},
-        std::max(smem_block_total, size_t(4)), kernel_name));
 }
 
-void Executor::Impl::init_communicator() {
-    auto bootstrap =
-        std::make_shared<mscclpp::TcpBootstrap>(rank_, world_size_);
-    std::stringstream ip_port;
-    ip_port << get_host(0) << ":" << get_env().mscclpp_port;
-    bootstrap->initialize(ip_port.str());
-    comm_ = std::make_shared<mscclpp::Communicator>(bootstrap);
+void PlanResource::init_comm_resource() {
+    if (comm_resource_) {
+        if (comm_resource_->rank() != rank_) {
+            ERR(InvalidUsageError,
+                "Rank should be consistent across all plans. Expected ",
+                rank_, " but got ", comm_resource_->rank());
+        }
+        if (comm_resource_->world_size() != world_size_) {
+            ERR(InvalidUsageError,
+                "World size should be consistent across all plans. Expected ",
+                world_size_, " but got ", comm_resource_->world_size());
+        }
+    } else if (world_size_ > 1) {
+        comm_resource_ =
+            std::make_shared<CommResource>(device_id_, rank_, world_size_);
+    }
 }
Expected ", + world_size_, " but got ", comm_resource_->world_size()); + } + } else if (world_size_ > 1) { + comm_resource_ = + std::make_shared(device_id_, rank_, world_size_); + } } -std::map Executor::Impl::init_buffers(const Json &plan_json) { +void PlanResource::init_internal_buffers() { class BufferInfo { public: BufferInfo(const std::shared_ptr buffer) : buffer(buffer), bytes(0), is_input(true), is_output(true) {} - // ID of this buffer + // Underlying ModelBuffer const std::shared_ptr buffer; // Total bytes of this buffer @@ -324,17 +495,17 @@ std::map Executor::Impl::init_buffers(const Json &plan_json) { std::set task_ids; }; - std::map buffer_id_to_offset; std::map> buffer_id_to_info; auto get_or_create_buffer_info = [&](const Json &buffer_json) { auto buffer = ModelBuffer::deserialize(buffer_json); - if (buffer_id_to_info.find(buffer->id()) == buffer_id_to_info.end()) { + auto it = buffer_id_to_info.find(buffer->id()); + if (it == buffer_id_to_info.end()) { auto buf_info = std::make_shared(buffer); buffer_id_to_info[buffer->id()] = buf_info; return buf_info; } - return buffer_id_to_info[buffer->id()]; + return it->second; }; auto retrieve_buffer_info = [&](const Json &tensor, size_t task_id, @@ -349,7 +520,7 @@ std::map Executor::Impl::init_buffers(const Json &plan_json) { buf_info->task_ids.insert(task_id); }; - for (auto &task_info : plan_json["TaskInfos"]) { + for (auto &task_info : plan_json_["TaskInfos"]) { for (auto &op : task_info["Ops"]) { size_t task_id = task_info["Id"].get(); for (auto &tns : op["ReadTensors"]) { @@ -371,41 +542,69 @@ std::map Executor::Impl::init_buffers(const Json &plan_json) { std::map> remote_rank_to_send_tag_to_buffer_id; std::map> remote_rank_to_recv_tag_to_buffer_id; + auto is_remote = [&](const std::shared_ptr &buffer) { + return buffer->rank() != rank_ && buffer->rank() != -1; + }; + // TODO: improve memory planning size_t offset = 0; - for (auto &kv : buffer_id_to_info) { - auto &buf_info = kv.second; - int r = buf_info->buffer->rank(); - if (r != rank_ && r != -1) { + for (auto &[buf_id, buf_info] : buffer_id_to_info) { + auto &buffer = buf_info->buffer; + if (is_remote(buffer)) { // this is a remote buffer - for (const auto &tag_info : buf_info->buffer->send_tags()) { - remote_rank_to_send_tag_to_buffer_id[buf_info->buffer->rank()] - [tag_info.second] = - buf_info->buffer->id(); + if (buffer->is_external()) { + ERR(InvalidUsageError, + "Communication with external buffers is not supported"); + } + int r = buffer->rank(); + for (const auto &tag_info : buffer->send_tags()) { + // This remote buffer will send data to local buffers + remote_rank_to_send_tag_to_buffer_id[r][tag_info.second] = + buf_id; } - for (const auto &tag_info : buf_info->buffer->recv_tags()) { - remote_rank_to_recv_tag_to_buffer_id[buf_info->buffer->rank()] - [tag_info.second] = - buf_info->buffer->id(); + for (const auto &tag_info : buffer->recv_tags()) { + // This remote buffer will receive data from local buffers + remote_rank_to_recv_tag_to_buffer_id[r][tag_info.second] = + buf_id; } continue; } - buffer_id_to_offset[buf_info->buffer->id()] = offset; - for (const auto &tag_info : buf_info->buffer->send_tags()) { - remote_rank_to_send_tags_and_offsets[tag_info.first] - .first.push_back(tag_info.second); - remote_rank_to_send_tags_and_offsets[tag_info.first] - .second.push_back(offset); + auto info = BufferRegistry::get_instance().get(buf_id); + if (info || buffer->is_external()) { + // This buffer is external or has been already allocated by a + // previous plan. 
+ extra_buffer_ids_.insert(buf_id); + } else { + // Assign an offset to this internal local buffer + internal_buffer_id_to_offset_[buf_id] = offset; + for (const auto &tag_info : buffer->send_tags()) { + // This local buffer will send data to remote ranks + remote_rank_to_send_tags_and_offsets[tag_info.first] + .first.push_back(tag_info.second); + remote_rank_to_send_tags_and_offsets[tag_info.first] + .second.push_back(offset); + } + for (const auto &tag_info : buffer->recv_tags()) { + // This local buffer will receive data from remote ranks + remote_rank_to_recv_tags_and_offsets[tag_info.first] + .first.push_back(tag_info.second); + remote_rank_to_recv_tags_and_offsets[tag_info.first] + .second.push_back(offset); + } + offset += buf_info->bytes; } - for (const auto &tag_info : buf_info->buffer->recv_tags()) { - remote_rank_to_recv_tags_and_offsets[tag_info.first] - .first.push_back(tag_info.second); - remote_rank_to_recv_tags_and_offsets[tag_info.first] - .second.push_back(offset); + } + size_t total_bytes = offset; + + // Allocate memory for internal local buffers + if (total_bytes > 0) { + buffer_ = + GpuManager::get_instance(device_id_)->malloc(total_bytes, 65536); + for (auto &[buf_id, buf_offset] : internal_buffer_id_to_offset_) { + BufferRegistry::get_instance().set(buf_id, buffer_->ref(buf_offset), + device_id_, false); } - offset += buf_info->bytes; } - total_bytes_ = offset; // // Send each tag (SendTag or RecvTag) and the corresponding offset to @@ -447,7 +646,7 @@ std::map Executor::Impl::init_buffers(const Json &plan_json) { auto &tags = tags_and_offsets.first; auto &offsets = tags_and_offsets.second; int len = tags.size(); - auto bootstrap = comm_->bootstrap(); + auto bootstrap = comm_resource_->bootstrap(); bootstrap->send(&len, sizeof(int), remote_rank, 0); bootstrap->send(tags.data(), tags.size() * sizeof(int), remote_rank, 1); bootstrap->send(offsets.data(), offsets.size() * sizeof(size_t), @@ -460,7 +659,7 @@ std::map Executor::Impl::init_buffers(const Json &plan_json) { auto &tags = tags_and_offsets.first; auto &offsets = tags_and_offsets.second; int len = tags.size(); - auto bootstrap = comm_->bootstrap(); + auto bootstrap = comm_resource_->bootstrap(); bootstrap->send(&len, sizeof(int), remote_rank, 3); bootstrap->send(tags.data(), tags.size() * sizeof(int), remote_rank, 4); bootstrap->send(offsets.data(), offsets.size() * sizeof(size_t), @@ -472,14 +671,21 @@ std::map Executor::Impl::init_buffers(const Json &plan_json) { std::vector tags; std::vector offsets; int len; - auto bootstrap = comm_->bootstrap(); + auto bootstrap = comm_resource_->bootstrap(); bootstrap->recv(&len, sizeof(int), remote_rank, 0); tags.resize(len); offsets.resize(len); bootstrap->recv(tags.data(), len * sizeof(int), remote_rank, 1); bootstrap->recv(offsets.data(), len * sizeof(size_t), remote_rank, 2); for (int i = 0; i < len; ++i) { - buffer_id_to_offset[send_tag_to_buffer_id[tags[i]]] = offsets[i]; + auto it = send_tag_to_buffer_id.find(tags[i]); + if (it == send_tag_to_buffer_id.end()) { + LOG(WARN, "Send tag ", tags[i], " from remote rank ", + remote_rank, " is unexpected"); + continue; + } + size_t buf_id = it->second; + internal_buffer_id_to_offset_[buf_id] = offsets[i]; } } for (auto &kv : remote_rank_to_recv_tag_to_buffer_id) { @@ -488,225 +694,287 @@ std::map Executor::Impl::init_buffers(const Json &plan_json) { std::vector tags; std::vector offsets; int len; - auto bootstrap = comm_->bootstrap(); + auto bootstrap = comm_resource_->bootstrap(); bootstrap->recv(&len, sizeof(int), 
remote_rank, 3); tags.resize(len); offsets.resize(len); bootstrap->recv(tags.data(), len * sizeof(int), remote_rank, 4); bootstrap->recv(offsets.data(), len * sizeof(size_t), remote_rank, 5); for (int i = 0; i < len; ++i) { - buffer_id_to_offset[recv_tag_to_buffer_id[tags[i]]] = offsets[i]; + auto it = recv_tag_to_buffer_id.find(tags[i]); + if (it == recv_tag_to_buffer_id.end()) { + LOG(WARN, "Recv tag ", tags[i], " from remote rank ", + remote_rank, " is unexpected"); + continue; + } + size_t buf_id = it->second; + internal_buffer_id_to_offset_[buf_id] = offsets[i]; } } - - return buffer_id_to_offset; } -std::set Executor::Impl::init_remote_ranks(const Json &plan_json) const { - std::set remote_ranks; - for (auto &task_info : plan_json["TaskInfos"]) { - for (auto &op : task_info["Ops"]) { - for (auto &tns : op["ReadTensors"]) { - auto buffer = ModelBuffer::deserialize(tns["Buffer"]); - if (buffer->rank() != rank_ && buffer->rank() != -1) { - remote_ranks.insert(buffer->rank()); - } - } - for (auto &tns : op["WriteTensors"]) { - auto buffer = ModelBuffer::deserialize(tns["Buffer"]); - if (buffer->rank() != rank_ && buffer->rank() != -1) { - remote_ranks.insert(buffer->rank()); - } - } - for (auto &tns : op["ResultTensors"]) { - auto buffer = ModelBuffer::deserialize(tns["Buffer"]); - if (buffer->rank() != rank_ && buffer->rank() != -1) { - remote_ranks.insert(buffer->rank()); - } - } - } +void PlanResource::init_comm_connections() { + if (comm_resource_ && buffer_) { + comm_resource_->connect(plan_json_, buffer_); } - return remote_ranks; } -void Executor::Impl::init_channels(const std::set &remote_ranks) { - proxy_service_ = std::make_shared(); +void PlanResource::init_kernel() { + auto gpu_manager = GpuManager::get_instance(device_id_); + auto codegen = std::make_shared( + plan_json_, internal_buffer_id_to_offset_, extra_buffer_ids_); + int num_sm = static_cast(codegen->num_procs()); + int threads_per_block = static_cast( + codegen->num_warps_per_proc() * gpu_manager->info().threads_per_warp); + size_t smem_block_total = + static_cast(gpu_manager->info().smem_block_total); - int num_ranks_per_node = get_env().num_ranks_per_host; - auto rank_to_node = [&](int rank) { return rank / num_ranks_per_node; }; - int this_node = rank_to_node(rank_); + kernel_ = std::shared_ptr( + new GpuKernel(device_id_, codegen->code(), {threads_per_block, 1, 1}, + {num_sm, 1, 1}, std::max(smem_block_total, size_t(4)))); + kernel_->compile(); - const mscclpp::Transport IBs[] = { - mscclpp::Transport::IB0, mscclpp::Transport::IB1, - mscclpp::Transport::IB2, mscclpp::Transport::IB3, - mscclpp::Transport::IB4, mscclpp::Transport::IB5, - mscclpp::Transport::IB6, mscclpp::Transport::IB7}; + if (world_size_ <= 1) return; - mscclpp::TransportFlags all_transports = - mscclpp::Transport::CudaIpc | mscclpp::Transport::Ethernet; - if (!get_env().disable_ib) { - all_transports |= IBs[device_id_]; + auto get_global_rt = [&](const std::string &symbol) { + return reinterpret_cast(kernel_->get_global(symbol)); + }; + void *proxy_chan_addr = get_global_rt("ARK_PROXY_CHANS"); + void *proxy_secondary_chan_addr = + get_global_rt("ARK_PROXY_SECONDARY_CHANS"); + void *sm_chan_addr = get_global_rt("ARK_SM_CHANS"); + std::vector proxy_handles( + world_size_); + std::vector + proxy_secondary_handles(world_size_); + std::vector sm_handles(world_size_); + for (int i = 0; i < world_size_; i++) { + if (i == rank_) continue; + auto resource = comm_resource_->resource(i); + if (!resource) continue; + std::vector p_hdls; + if (resource->ipc) 
 
-void Executor::Impl::init_channels(const std::set<int> &remote_ranks) {
-    proxy_service_ = std::make_shared<mscclpp::ProxyService>();
-
-    int num_ranks_per_node = get_env().num_ranks_per_host;
-    auto rank_to_node = [&](int rank) { return rank / num_ranks_per_node; };
-    int this_node = rank_to_node(rank_);
-
-    const mscclpp::Transport IBs[] = {
-        mscclpp::Transport::IB0, mscclpp::Transport::IB1,
-        mscclpp::Transport::IB2, mscclpp::Transport::IB3,
-        mscclpp::Transport::IB4, mscclpp::Transport::IB5,
-        mscclpp::Transport::IB6, mscclpp::Transport::IB7};
-
-    mscclpp::TransportFlags all_transports =
-        mscclpp::Transport::CudaIpc | mscclpp::Transport::Ethernet;
-    if (!get_env().disable_ib) {
-        all_transports |= IBs[device_id_];
-    }
-    mscclpp::RegisteredMemory regmem =
-        comm_->registerMemory(buffer_->ref(), buffer_->bytes(), all_transports);
-
-    std::map<int, std::vector<mscclpp::NonblockingFuture<
-                      std::shared_ptr<mscclpp::Connection>>>>
-        rank_to_connections_future;
-    std::map<int, mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>>
-        rank_to_remote_regmem_future;
-
-    for (auto remote_rank : remote_ranks) {
-        int remote_node = rank_to_node(remote_rank);
-        auto add_connection = [&](int remote_rank,
-                                  mscclpp::Transport transport) {
-            rank_to_connections_future[remote_rank].push_back(
-                comm_->connectOnSetup(remote_rank, 0, transport));
-        };
-        if (remote_node == this_node) {
-            add_connection(remote_rank, mscclpp::Transport::CudaIpc);
-            if (!get_env().disable_ib) {
-                add_connection(remote_rank, IBs[device_id_]);
-            }
-        } else {
-            add_connection(remote_rank, get_env().disable_ib
-                                            ? mscclpp::Transport::Ethernet
-                                            : IBs[device_id_]);
-        }
-        comm_->sendMemoryOnSetup(regmem, remote_rank, 0);
-        rank_to_remote_regmem_future[remote_rank] =
-            comm_->recvMemoryOnSetup(remote_rank, 0);
-    }
-    comm_->setup();
-
-    std::map<int, std::vector<std::shared_ptr<mscclpp::Connection>>>
-        rank_to_connections;
-    for (auto &kv : rank_to_connections_future) {
-        for (auto &future : kv.second) {
-            rank_to_connections[kv.first].push_back(future.get());
-        }
-    }
-    for (auto &kv : rank_to_connections) {
-        for (auto &conn : kv.second) {
-            rank_to_proxy_channels_[kv.first].push_back(
-                std::make_shared<mscclpp::SimpleProxyChannel>(
-                    proxy_service_->proxyChannel(
-                        proxy_service_->buildAndAddSemaphore(*comm_, conn)),
-                    proxy_service_->addMemory(
-                        rank_to_remote_regmem_future[kv.first].get()),
-                    proxy_service_->addMemory(regmem)));
-        }
-    }
-    comm_->setup();
-
-    std::map<int, std::vector<std::shared_ptr<mscclpp::SmDevice2DeviceSemaphore>>>
-        sm_semaphores;
-    for (auto &kv : rank_to_connections) {
-        for (auto &conn : kv.second) {
-            if (conn->transport() != mscclpp::Transport::CudaIpc) continue;
-            sm_semaphores[kv.first].push_back(
-                std::make_shared<mscclpp::SmDevice2DeviceSemaphore>(*comm_,
-                                                                    conn));
-        }
-    }
-    comm_->setup();
-
-    for (auto &kv : sm_semaphores) {
-        for (auto &sem : kv.second) {
-            rank_to_sm_channels_[kv.first].push_back(
-                std::make_shared<mscclpp::SmChannel>(
-                    sem, rank_to_remote_regmem_future[kv.first].get(),
-                    regmem.data(), nullptr));
-        }
-    }
-}
+void PlanResource::init_kernel() {
+    auto gpu_manager = GpuManager::get_instance(device_id_);
+    auto codegen = std::make_shared<CodeGenerator>(
+        plan_json_, internal_buffer_id_to_offset_, extra_buffer_ids_);
+    int num_sm = static_cast<int>(codegen->num_procs());
+    int threads_per_block = static_cast<int>(
+        codegen->num_warps_per_proc() * gpu_manager->info().threads_per_warp);
+    size_t smem_block_total =
+        static_cast<size_t>(gpu_manager->info().smem_block_total);
+
+    kernel_ = std::shared_ptr<GpuKernel>(
+        new GpuKernel(device_id_, codegen->code(), {threads_per_block, 1, 1},
+                      {num_sm, 1, 1}, std::max(smem_block_total, size_t(4))));
+    kernel_->compile();
+
+    if (world_size_ <= 1) return;
+
+    auto get_global_rt = [&](const std::string &symbol) {
+        return reinterpret_cast<void *>(kernel_->get_global(symbol));
+    };
+    void *proxy_chan_addr = get_global_rt("ARK_PROXY_CHANS");
+    void *proxy_secondary_chan_addr =
+        get_global_rt("ARK_PROXY_SECONDARY_CHANS");
+    void *sm_chan_addr = get_global_rt("ARK_SM_CHANS");
+    std::vector<mscclpp::SimpleProxyChannel::DeviceHandle> proxy_handles(
+        world_size_);
+    std::vector<mscclpp::SimpleProxyChannel::DeviceHandle>
+        proxy_secondary_handles(world_size_);
+    std::vector<mscclpp::SmChannel::DeviceHandle> sm_handles(world_size_);
+    for (int i = 0; i < world_size_; i++) {
+        if (i == rank_) continue;
+        auto resource = comm_resource_->resource(i);
+        if (!resource) continue;
+        std::vector<mscclpp::SimpleProxyChannel::DeviceHandle> p_hdls;
+        if (resource->ipc) {
+            sm_handles[i] = resource->ipc->sm_channels[0]->deviceHandle();
+            p_hdls.push_back(resource->ipc->proxy_channels[0]->deviceHandle());
+        }
+        if (resource->ib) {
+            p_hdls.push_back(resource->ib->proxy_channels[0]->deviceHandle());
+        }
+        if (resource->eth) {
+            p_hdls.push_back(resource->eth->proxy_channels[0]->deviceHandle());
+        }
+        if (p_hdls.size() > 0) {
+            proxy_handles[i] = p_hdls[0];
+        }
+        if (p_hdls.size() > 1) {
+            proxy_secondary_handles[i] = p_hdls[1];
+        }
+    }
+
+    auto tmp_stream = gpu_manager->create_stream();
+    GLOG(gpuSetDevice(device_id_));
+    GLOG(gpuMemcpyAsync(proxy_chan_addr, proxy_handles.data(),
+                        proxy_handles.size() *
+                            sizeof(mscclpp::SimpleProxyChannel::DeviceHandle),
+                        gpuMemcpyHostToDevice, tmp_stream->get()));
+    GLOG(gpuMemcpyAsync(proxy_secondary_chan_addr,
+                        proxy_secondary_handles.data(),
+                        proxy_secondary_handles.size() *
+                            sizeof(mscclpp::SimpleProxyChannel::DeviceHandle),
+                        gpuMemcpyHostToDevice, tmp_stream->get()));
+    GLOG(gpuMemcpyAsync(
+        sm_chan_addr, sm_handles.data(),
+        sm_handles.size() * sizeof(mscclpp::SmChannel::DeviceHandle),
+        gpuMemcpyHostToDevice, tmp_stream->get()));
+    GLOG(gpuStreamSynchronize(tmp_stream->get()));
+}
+
+void PlanResource::launch_kernel(const std::string &name,
+                                 const std::vector<void *> &args,
+                                 gpuStream stream) {
+    std::vector<void *> kernel_args = args;
+    for (size_t id : extra_buffer_ids_) {
+        auto info = BufferRegistry::get_instance().get(id);
+        if (!info) {
+            ERR(InternalError, "External buffer not found.");
+        } else if (info->data == nullptr) {
+            ERR(InvalidUsageError, "External buffer data is nullptr.");
+        }
+        kernel_args.push_back(&(info->data));
+    }
+    kernel_->launch(name, stream, kernel_args);
+}
+
+class Executor::Impl {
+   public:
+    Impl() {}
+    ~Impl();
+
+    int device_id() const {
+        return foreground_plan_resource_
+                   ? foreground_plan_resource_->device_id()
+                   : -1;
+    }
+
+    Stream stream() const { return reinterpret_cast<Stream>(stream_raw_); }
+
+    std::shared_ptr<GpuMemory> buffer() const {
+        return foreground_plan_resource_ ? foreground_plan_resource_->buffer()
+                                         : nullptr;
+    }
+
+    std::string plan() const {
+        return foreground_plan_resource_
+                   ? foreground_plan_resource_->plan_json().dump_pretty()
+                   : "";
+    }
+
+    std::string name() const {
+        return foreground_plan_resource_ ? foreground_plan_resource_->name()
+                                         : "";
+    }
+
+    void compile(const std::string &plan, int device_id,
+                 const std::string &name);
+    void launch(const std::unordered_map<Tensor, void *> &placeholder_data,
+                Stream stream, bool loop_mode, bool record);
+    void run(int iter,
+             const std::unordered_map<Tensor, void *> &placeholder_data);
+    void wait(int64_t max_spin_count);
+    float stop(int64_t max_spin_count);
+    void barrier();
+
+    void *tensor_address(const Tensor &tensor) const;
+
+    void tensor_read(const Tensor &tensor, void *data, size_t bytes,
+                     Stream stream, bool is_d2d) const;
+    void tensor_write(const Tensor &tensor, const void *data, size_t bytes,
+                      Stream stream, bool is_d2d) const;
+
+   protected:
+    friend class DefaultExecutor;
+
+    gpuStream stream_raw_;
+    bool loop_mode_;
+
+   private:
+    std::shared_ptr<BufferRegistry::Info> get_buffer_info(
+        const Tensor &tensor, bool fail_on_null) const;
+
+    std::map<PlanResourceKey, std::shared_ptr<PlanResource>> plan_resources_;
+    std::shared_ptr<PlanResource> foreground_plan_resource_;
+    std::shared_ptr<CommResource> comm_resource_;
+
+    bool is_launched_ = false;
+    bool is_recording_ = false;
+    float elapsed_msec_ = -1;
+
+    std::shared_ptr<GpuEvent> timer_begin_;
+    std::shared_ptr<GpuEvent> timer_end_;
+    std::shared_ptr<GpuHostMemory> flag_;
+    std::shared_ptr<GpuStream> stream_;
+};
+
+Executor::Impl::~Impl() {
+    if (is_launched_) stop(-1);
+}
 
-void Executor::Impl::compile() { kernel_->compile(); }
+void Executor::Impl::compile(const std::string &plan, int device_id,
+                             const std::string &name) {
+    if (is_launched_) {
+        ERR(InvalidUsageError, "Need to stop before re-compiling.");
+    }
+    int prev_device_id = -1;
+    if (foreground_plan_resource_) {
+        prev_device_id = foreground_plan_resource_->device_id();
+    }
+    if (prev_device_id != device_id) {
+        auto gpu_manager = GpuManager::get_instance(device_id);
+        timer_begin_ = gpu_manager->create_event();
+        timer_end_ = gpu_manager->create_event();
+        flag_ = gpu_manager->malloc_host(
+            sizeof(int), gpuHostAllocMapped | gpuHostAllocWriteCombined);
+        stream_ = gpu_manager->create_stream();
+    }
+    PlanResourceKey key(plan, device_id, name);
+    auto it = plan_resources_.find(key);
+    if (it == plan_resources_.end()) {
+        try {
+            auto plan_json = Json::parse(plan);
+            auto resource = std::make_shared<PlanResource>(
+                plan_json, device_id, name, comm_resource_);
+            plan_resources_[key] = resource;
+            foreground_plan_resource_ = resource;
+        } catch (const ::nlohmann::json::parse_error &e) {
+            ERR(InvalidUsageError, "Failed to parse the plan JSON: ", e.what());
+        }
+    } else {
+        foreground_plan_resource_ = it->second;
+    }
+}
 
-void Executor::Impl::launch() {
-    if (!kernel_->is_compiled()) {
-        ERR(InvalidUsageError, "Need to compile first before initialization.");
+void Executor::Impl::launch(
+    const std::unordered_map<Tensor, void *> &placeholder_data, Stream stream,
+    bool loop_mode, bool record) {
+    if (!foreground_plan_resource_) {
+        ERR(InvalidUsageError, "Need to compile first before launch.");
     }
     if (is_launched_) {
         LOG(WARN, "Ignore launching twice.");
         return;
     }
-    auto get_global_rt = [&](const std::string &symbol) {
-        return reinterpret_cast<void *>(kernel_->get_global(symbol));
-    };
-    if (world_size_ > 1) {
-        void *proxy_chan_addr = get_global_rt("ARK_PROXY_CHANS");
-        void *proxy_secondary_chan_addr =
-            get_global_rt("ARK_PROXY_SECONDARY_CHANS");
-        void *sm_chan_addr = get_global_rt("ARK_SM_CHANS");
-        std::vector<mscclpp::SimpleProxyChannel::DeviceHandle> proxy_handles(
-            world_size_);
-        std::vector<mscclpp::SimpleProxyChannel::DeviceHandle>
-            proxy_secondary_handles(world_size_);
-        std::vector<mscclpp::SmChannel::DeviceHandle> sm_handles(world_size_);
-        for (int i = 0; i < world_size_; i++) {
-            auto it = rank_to_proxy_channels_.find(i);
-            if (it != rank_to_proxy_channels_.end() && it->second.size() > 0) {
-                proxy_handles[i] = it->second[0]->deviceHandle();
-                if (it->second.size() > 1) {
-                    proxy_secondary_handles[i] = it->second[1]->deviceHandle();
-                }
-            }
-            auto it2 = rank_to_sm_channels_.find(i);
-            if (it2 != rank_to_sm_channels_.end() && it2->second.size() > 0) {
-                sm_handles[i] = it2->second[0]->deviceHandle();
-            }
+    for (const auto &[tensor, ptr] : placeholder_data) {
+        if (tensor.ref()->data(ptr) != ptr) {
+            ERR(InvalidUsageError,
+                "Placeholder data must be external tensors.");
         }
-        GLOG(gpuSetDevice(device_id_));
-        GLOG(gpuMemcpyAsync(
-            proxy_chan_addr, proxy_handles.data(),
-            proxy_handles.size() *
-                sizeof(mscclpp::SimpleProxyChannel::DeviceHandle),
-            gpuMemcpyHostToDevice, stream_raw_));
-        GLOG(gpuMemcpyAsync(
-            proxy_secondary_chan_addr, proxy_secondary_handles.data(),
-            proxy_secondary_handles.size() *
-                sizeof(mscclpp::SimpleProxyChannel::DeviceHandle),
-            gpuMemcpyHostToDevice, stream_raw_));
-        GLOG(gpuMemcpyAsync(
-            sm_chan_addr, sm_handles.data(),
-            sm_handles.size() * sizeof(mscclpp::SmChannel::DeviceHandle),
-            gpuMemcpyHostToDevice, stream_raw_));
-        GLOG(gpuStreamSynchronize(stream_raw_));
     }
+    stream_raw_ = stream ? reinterpret_cast<gpuStream>(stream) : stream_->get();
+    loop_mode_ = loop_mode;
 
     elapsed_msec_ = -1;
-    timer_begin_->record(stream_raw_);
-    if (world_size_ > 1) {
-        proxy_service_->startProxy();
+    if (record) {
+        timer_begin_->record(stream_raw_);
+        is_recording_ = true;
+    }
+    if (comm_resource_) {
+        comm_resource_->proxy_service()->startProxy();
     }
     if (loop_mode_) {
         // Initialize loop flags.
         atomicStoreRelaxed(flag_->ref(), 0);
-        void *buf_ptr = buffer_->ref();
+        auto buffer = foreground_plan_resource_->buffer();
+        void *buf_ptr = buffer ? buffer->ref() : nullptr;
         void *flag_ptr = flag_->ref();
         std::vector<void *> args = {&buf_ptr, &flag_ptr};
-        kernel_->launch(stream_raw_, args);
+        foreground_plan_resource_->launch_kernel("ark_loop_kernel", args,
+                                                 stream_raw_);
     }
-    is_recording_ = true;
     is_launched_ = true;
 }
 
-void Executor::Impl::run(int iter) {
+void Executor::Impl::run(
+    int iter, const std::unordered_map<Tensor, void *> &placeholder_data) {
+    for (const auto &[tensor, ptr] : placeholder_data) {
+        if (tensor.ref()->data(ptr) != ptr) {
+            ERR(InvalidUsageError,
+                "Placeholder data must be external tensors.");
+        }
+    }
     if (iter <= 0) return;
     if (loop_mode_) {
         while (atomicLoadRelaxed(flag_->ref()) > 0) {
         }
         atomicStoreRelaxed(flag_->ref(), iter);
     } else {
-        void *buf_ptr = buffer_->ref();
+        auto buffer = foreground_plan_resource_->buffer();
+        void *buf_ptr = buffer ? buffer->ref() : nullptr;
         int i = 0;
         std::vector<void *> args = {&buf_ptr, reinterpret_cast<void *>(&i)};
         for (; i < iter; i++) {
-            kernel_->launch(stream_raw_, args);
+            foreground_plan_resource_->launch_kernel("ark_kernel", args,
                                                      stream_raw_);
        }
    }
}
@@ -722,9 +990,8 @@ void Executor::Impl::wait(int64_t max_spin_count) {
         gpuError res = gpuStreamQuery(stream_raw_);
         if (res == gpuSuccess) {
             if (atomicLoadRelaxed(flag_->ref()) > 0) {
-                LOG(WARN,
+                ERR(InternalError,
                     "Stream is finished but the loop flag is still set.");
-                break;
             } else {
                 LOG(WARN,
                     "wait() is delayed by a stream query. Regarding "
@@ -759,30 +1026,46 @@ float Executor::Impl::stop(int64_t max_spin_count) {
         is_recording_ = false;
     }
     is_launched_ = false;
-    if (world_size_ > 1) {
-        proxy_service_->stopProxy();
+    if (comm_resource_) {
+        comm_resource_->proxy_service()->stopProxy();
     }
     return elapsed_msec_;
 }
 
 void Executor::Impl::barrier() {
-    if (world_size_ > 1) {
-        comm_->bootstrap()->barrier();
+    if (comm_resource_) {
+        comm_resource_->bootstrap()->barrier();
     }
 }
 
-uintptr_t Executor::Impl::tensor_address(const Tensor &tensor) const {
+std::shared_ptr<BufferRegistry::Info> Executor::Impl::get_buffer_info(
+    const Tensor &tensor, bool fail_on_null) const {
     size_t buffer_id = tensor.ref()->buffer()->id();
-    if (buffer_id_to_offset_.find(buffer_id) == buffer_id_to_offset_.end()) {
-        ERR(InternalError, "Invalid buffer ID: ", buffer_id);
+    auto &buf_reg = BufferRegistry::get_instance();
+    auto info = buf_reg.get(buffer_id);
+    if (fail_on_null && (!info || !(info->data))) {
+        ERR(InvalidUsageError,
+            "Tensor has no allocated memory. "
+            "This is likely caused by accessing a tensor that is optimized "
+            "out by the compiler or not used in any plan passed to the "
+            "executor.");
     }
-    size_t offset = buffer_id_to_offset_.at(buffer_id);
-    return reinterpret_cast<uintptr_t>(buffer_->ref(offset));
+    return info;
+}
+
+void *Executor::Impl::tensor_address(const Tensor &tensor) const {
+    auto info = get_buffer_info(tensor, false);
+    if (!info || !(info->data)) {
+        return nullptr;
+    }
+    return info->data;
 }
 
 void Executor::Impl::tensor_read(const Tensor &tensor, void *data,
                                  size_t bytes, Stream stream,
                                  bool is_d2d) const {
-    GLOG(gpuSetDevice(device_id_));
+    auto info = get_buffer_info(tensor, true);
+    size_t device_id = info->device_id;
+    GLOG(gpuSetDevice(device_id));
     std::shared_ptr<GpuStream> copy_stream;
     gpuStream copy_stream_raw;
     if (stream) {
@@ -793,7 +1076,7 @@ void Executor::Impl::tensor_read(const Tensor &tensor, void *data, size_t bytes,
                 "may cause a deadlock.");
         }
     } else {
-        copy_stream = GpuManager::get_instance(device_id_)->create_stream();
+        copy_stream = GpuManager::get_instance(device_id)->create_stream();
        copy_stream_raw = copy_stream->get();
    }
    size_t tensor_data_bytes =
@@ -803,7 +1086,7 @@ void Executor::Impl::tensor_read(const Tensor &tensor, void *data, size_t bytes,
             ") mismatches the tensor data bytes (", tensor_data_bytes, ").");
     }
     auto kind = (is_d2d) ? gpuMemcpyDeviceToDevice : gpuMemcpyDeviceToHost;
-    void *src = reinterpret_cast<void *>(tensor_address(tensor));
+    void *src = info->data;
     if (tensor.strides() == tensor.shape()) {
         GLOG(gpuMemcpyAsync(data, src, bytes, kind, copy_stream_raw));
     } else {
@@ -833,7 +1116,9 @@ void Executor::Impl::tensor_read(const Tensor &tensor, void *data, size_t bytes,
 void Executor::Impl::tensor_write(const Tensor &tensor, const void *data,
                                   size_t bytes, Stream stream,
                                   bool is_d2d) const {
-    GLOG(gpuSetDevice(device_id_));
+    auto info = get_buffer_info(tensor, true);
+    size_t device_id = info->device_id;
+    GLOG(gpuSetDevice(device_id));
     std::shared_ptr<GpuStream> copy_stream;
     gpuStream copy_stream_raw;
     if (stream) {
@@ -844,7 +1129,7 @@ void Executor::Impl::tensor_write(const Tensor &tensor, const void *data,
                 "may cause a deadlock.");
         }
     } else {
-        copy_stream = GpuManager::get_instance(device_id_)->create_stream();
+        copy_stream = GpuManager::get_instance(device_id)->create_stream();
        copy_stream_raw = copy_stream->get();
    }
    size_t tensor_data_bytes =
@@ -856,7 +1141,7 @@ void Executor::Impl::tensor_write(const Tensor &tensor, const void *data,
     size_t tensor_bytes =
         tensor.strides().nelems() * tensor.data_type().bytes();
     auto kind = (is_d2d) ? gpuMemcpyDeviceToDevice : gpuMemcpyHostToDevice;
-    void *dst = reinterpret_cast<void *>(tensor_address(tensor));
+    void *dst = info->data;
     if (tensor.strides() == tensor.shape()) {
         GLOG(gpuMemcpyAsync(dst, data, tensor_bytes, kind, copy_stream_raw));
     } else {
@@ -885,18 +1170,7 @@ void Executor::Impl::tensor_write(const Tensor &tensor, const void *data,
     GLOG(gpuStreamSynchronize(copy_stream_raw));
 }
 
-Executor::Executor(int device_id, Stream stream, const std::string &name,
-                   const std::string &plan, bool loop_mode)
-    : impl_(std::make_unique<Impl>(device_id, stream, name, loop_mode)) {
-    auto &plan_path = get_env().enforce_plan_path;
-    if (!plan_path.empty()) {
-        LOG(INFO, "Enforce executor plan path: ", plan_path);
-        impl_->init(Json::parse(read_file(plan_path)));
-    } else if (!plan.empty()) {
-        impl_->init(Json::parse(plan));
-    }
-}
+Executor::Executor() : impl_(std::make_unique<Impl>()) {}
 
 Executor::~Executor() = default;
 
@@ -904,13 +1178,27 @@ int Executor::device_id() const { return impl_->device_id(); }
 
 Stream Executor::stream() const { return impl_->stream(); }
 
+std::shared_ptr<GpuMemory> Executor::buffer() const { return impl_->buffer(); }
+
 std::string Executor::plan() const { return impl_->plan(); }
 
-void Executor::compile() { impl_->compile(); }
+std::string Executor::name() const { return impl_->name(); }
+
+void Executor::compile(const std::string &plan, int device_id,
+                       const std::string &name) {
+    impl_->compile(plan, device_id, name);
+}
 
-void Executor::launch() { impl_->launch(); }
+void Executor::launch(
+    const std::unordered_map<Tensor, void *> &placeholder_data, Stream stream,
+    bool loop_mode, bool record) {
+    impl_->launch(placeholder_data, stream, loop_mode, record);
+}
 
-void Executor::run(int iter) { impl_->run(iter); }
+void Executor::run(
+    int iter, const std::unordered_map<Tensor, void *> &placeholder_data) {
+    impl_->run(iter, placeholder_data);
+}
 
 void Executor::wait(int64_t max_spin_count) { impl_->wait(max_spin_count); }
 
@@ -924,7 +1212,7 @@ void Executor::destroy() { impl_.reset(nullptr); }
 
 bool Executor::destroyed() const { return impl_.get() == nullptr; }
 
-uintptr_t Executor::tensor_address(const Tensor &tensor) const {
+void *Executor::tensor_address(const Tensor &tensor) const {
     return impl_->tensor_address(tensor);
 }
 
@@ -941,15 +1229,24 @@ void Executor::tensor_write(const Tensor &tensor, const void *data,
 DefaultExecutor::DefaultExecutor(
     const Model &model, int device_id, Stream stream,
     const std::vector<Planner::ConfigRule> &config_rules,
-    const std::string &name, bool loop_mode)
-    : Executor((device_id < 0) ? (model.rank() % get_env().num_ranks_per_host)
-                               : device_id,
-               stream, name, "", loop_mode) {
-    Planner planner(model, impl_->device_id());
+    const std::string &name, bool loop_mode, bool record)
-    : Executor(), record_(record) {
+    : Executor(), record_(record) {
+    device_id = (device_id < 0) ? (model.rank() % get_env().num_ranks_per_host)
+                                : device_id;
+    Planner planner(model, device_id);
     for (const auto &rule : config_rules) {
         planner.install_config_rule(rule);
     }
-    impl_->init(Json::parse(planner.plan()));
+    compile(planner.plan(), device_id, name);
+    impl_->stream_raw_ = reinterpret_cast<gpuStream>(stream);
+    impl_->loop_mode_ = loop_mode;
+}
+
+void DefaultExecutor::launch(
+    const std::unordered_map<Tensor, void *> &placeholder_data) {
+    Executor::launch(placeholder_data,
+                     reinterpret_cast<Stream>(impl_->stream_raw_),
+                     impl_->loop_mode_, record_);
 }
 
 }  // namespace ark
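Note on the refactored executor lifecycle above: an `Executor` is now default-constructed and bound to plans via `compile(plan, device_id, name)`, which caches one `PlanResource` per `PlanResourceKey`; `launch()`/`run()` then operate on the foreground plan. A minimal sketch mirroring the new raw-executor unit test below (the no-argument forms of `launch()`, `run(1)`, and `wait()` rely on default arguments, which is an assumption based on these call sites):

    ark::Model m;
    ark::Tensor t = m.tensor({1024}, ark::FP32);
    m.noop(t);

    ark::Executor exe;
    exe.compile(ark::Planner(m, 0).plan(), 0);  // parse plan, allocate buffers, build kernel
    exe.launch();  // bind a stream; start comm services if needed
    exe.run(1);    // enqueue one iteration
    exe.wait();    // block until the iteration finishes
    // Going out of scope stops and destroys the executor automatically.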
diff --git a/ark/api/executor_test.cpp b/ark/api/executor_test.cpp
index dad0e9d83..22c7d7c47 100644
--- a/ark/api/executor_test.cpp
+++ b/ark/api/executor_test.cpp
@@ -3,6 +3,7 @@
 
 #include "ark/executor.hpp"
 
+#include "ark/planner.hpp"
 #include "gpu/gpu.hpp"
 #include "model/model_json.hpp"
 #include "unittest/unittest_utils.h"
@@ -20,7 +21,6 @@ ark::unittest::State test_executor() {
     UNITTEST_EQ(executor.device_id(), 0);
     UNITTEST_EQ(executor.stream(), stream);
 
-    executor.compile();
     executor.launch();
     executor.run(1);
     executor.wait();
@@ -31,7 +31,6 @@ ark::unittest::State test_executor() {
     }
     {
         ark::DefaultExecutor executor(empty, 0, stream, {}, "test", LoopMode);
-        executor.compile();
         executor.launch();
         executor.run(1);
         executor.wait();
@@ -46,9 +45,7 @@ ark::unittest::State test_executor() {
     }
     {
         ark::DefaultExecutor executor(empty, 0, stream, {}, "test", LoopMode);
-        UNITTEST_THROW(executor.launch(), ark::InvalidUsageError);
 
-        executor.compile();
         executor.launch();
         executor.launch();  // Will be ignored with a warning.
         executor.run(1);
@@ -58,6 +55,34 @@ ark::unittest::State test_executor() {
         // Stop & destroy automatically.
     }
 
+    // Raw executor test
+    ark::Model m;
+    auto tensor = m.tensor({1024}, ark::FP32);
+    m.noop(tensor);
+
+    ark::Planner planner(m, 0);
+    auto plan = planner.plan();
+    {
+        std::vector<float> array(1024);
+
+        ark::Executor exe;
+        UNITTEST_EQ(exe.tensor_address(tensor), nullptr);
+        UNITTEST_THROW(exe.tensor_read(tensor, array.data(),
+                                       array.size() * sizeof(float)),
+                       ark::InvalidUsageError);
+        UNITTEST_THROW(exe.tensor_write(tensor, array.data(),
+                                        array.size() * sizeof(float)),
+                       ark::InvalidUsageError);
+        UNITTEST_THROW(exe.launch(), ark::InvalidUsageError);
+
+        exe.compile(plan, 0);
+        UNITTEST_NE(exe.tensor_address(tensor), nullptr);
+
+        exe.launch();
+        exe.run(1);
+        exe.wait();
+    }
+
     UNITTEST_EQ(ark::gpuStreamDestroy(stream), ark::gpuSuccess);
     return ark::unittest::SUCCESS;
 }
@@ -86,9 +111,8 @@ ark::unittest::State test_executor_tensor_read_write(ark::Dims shape,
     m.noop(tensor);
 
     ark::DefaultExecutor executor(m, 0);
-    executor.compile();
-    executor.launch();
-    UNITTEST_GT(executor.tensor_address(tensor), 0);
+
+    UNITTEST_NE(executor.tensor_address(tensor), nullptr);
 
     // Copy data from CPU array to ARK tensor
     executor.tensor_write(tensor, host_data.data(),
@@ -107,20 +131,28 @@ ark::unittest::State test_executor_tensor_read_write(ark::Dims shape,
         dev_data[i] = -1;
     }
 
+    ark::gpuStream stream;
     UNITTEST_EQ(
-        ark::gpuMemcpy(dev_data.data(), dev_ptr, shape.nelems() * sizeof(float),
-                       ark::gpuMemcpyDeviceToHost),
+        ark::gpuStreamCreateWithFlags(&stream, ark::gpuStreamNonBlocking),
         ark::gpuSuccess);
+
+    UNITTEST_EQ(ark::gpuMemcpyAsync(dev_data.data(), dev_ptr,
+                                    shape.nelems() * sizeof(float),
+                                    ark::gpuMemcpyDeviceToHost, stream),
+                ark::gpuSuccess);
+    UNITTEST_EQ(ark::gpuStreamSynchronize(stream), ark::gpuSuccess);
+
     for (size_t i = 0; i < dev_data.size(); ++i) {
         UNITTEST_EQ(dev_data[i], static_cast<float>(i));
         dev_data[i] = -1;
     }
 
     // Copy -1s back to GPU array
-    UNITTEST_EQ(
-        ark::gpuMemcpy(dev_ptr, dev_data.data(), shape.nelems() * sizeof(float),
-                       ark::gpuMemcpyHostToDevice),
-        ark::gpuSuccess);
+    UNITTEST_EQ(ark::gpuMemcpyAsync(dev_ptr, dev_data.data(),
+                                    shape.nelems() * sizeof(float),
+                                    ark::gpuMemcpyHostToDevice, stream),
+                ark::gpuSuccess);
+    UNITTEST_EQ(ark::gpuStreamSynchronize(stream), ark::gpuSuccess);
 
     // Copy data from GPU array to ARK tensor
     executor.tensor_write(tensor, dev_ptr, shape.nelems() * sizeof(float),
@@ -136,10 +168,6 @@ ark::unittest::State test_executor_tensor_read_write(ark::Dims shape,
     }
 
     // Provide a stream
-    ark::gpuStream stream;
-    UNITTEST_EQ(
-        ark::gpuStreamCreateWithFlags(&stream, ark::gpuStreamNonBlocking),
-        ark::gpuSuccess);
     executor.tensor_read(tensor, host_data.data(),
                          shape.nelems() * sizeof(float), stream);
     executor.tensor_write(tensor, host_data.data(),
@@ -169,15 +197,19 @@ ark::unittest::State test_executor_tensor_read_write_stride_offset() {
 }
 
 ark::unittest::State test_executor_invalid() {
+    ark::Executor exe;
+
+    // Invalid plan.
+    UNITTEST_THROW(exe.compile("not a json", 0), ark::InvalidUsageError);
+
     // Invalid device ID.
-    UNITTEST_THROW(ark::Executor(-1, nullptr, "test", ""),
+    UNITTEST_THROW(exe.compile(ark::PlanJson().dump(), -1),
                    ark::InvalidUsageError);
 
     // Invalid rank.
     ark::PlanJson plan;
     plan["Rank"] = 1;
-    UNITTEST_THROW(ark::Executor(0, nullptr, "test", plan.dump(), true),
-                   ark::InvalidUsageError);
+    UNITTEST_THROW(exe.compile(plan.dump(), 0), ark::InvalidUsageError);
 
     return ark::unittest::SUCCESS;
 }
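Note on the planner changes below: context values become id-tagged lists instead of single overwritten values. Every `PlannerContext` setter appends a `[context_id, value]` pair under its key, and `Planner::Impl::plan()` walks these lists (via `get_context` / `get_latest_context`) to decide task merging (`Id`/`Sync`) and processor-group reuse (`ProcessorRange`). An illustrative sketch of the resulting node context under nested contexts; the exact JSON shape is an assumption based on the `set` calls below:

    ark::Model model;
    ark::PlannerContext outer(model);
    outer.processor_range(0, 4);  // appends [outer_id, [0, 4]]
    outer.sync(false);            // appends [outer_id, false]
    {
        ark::PlannerContext inner(model);
        inner.processor_range(2, 4);  // appends [inner_id, [2, 4]]
        // An op created here may carry, e.g.:
        //   "ProcessorRange": [[outer_id, [0, 4]], [inner_id, [2, 4]]]
        //   "Sync": [[outer_id, false], [inner_id, true]]
    }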
diff --git a/ark/api/planner.cpp b/ark/api/planner.cpp
index d36f33cbe..c2c98b216 100644
--- a/ark/api/planner.cpp
+++ b/ark/api/planner.cpp
@@ -11,12 +11,17 @@
 #include "model/model_json.hpp"
 #include "model/model_node.hpp"
 #include "model/model_op.hpp"
+#include "model/model_tensor.hpp"
 #include "range.hpp"
 
 namespace ark {
 
 PlannerContext::PlannerContext(Model &model) : Context(model) {
-    this->impl_->set("Id", this->id(), ContextType::Immutable);
+    this->impl_->set("Id", id());
+    Json val;
+    val.push_back(id());
+    val.push_back(true);
+    this->impl_->set("Sync", val);
 }
 
 void PlannerContext::check_range(const std::string &key,
@@ -26,7 +31,7 @@ void PlannerContext::check_range(const std::string &key,
         // ok
         return;
     }
-    auto prev_vec = prev.get<std::vector<int>>();
+    auto prev_vec = prev[1].get<std::vector<int>>();
     if (prev_vec.size() < 2 || prev_vec.size() > 3) {
         ERR(InternalError, "unexpected");
     }
@@ -40,50 +45,56 @@ void PlannerContext::check_range(const std::string &key,
 
 void PlannerContext::processor_range(int start, int end, int step) {
     check_range("ProcessorRange", {start, end, step});
+    Json val;
+    val.push_back(id());
     if (step == 1) {
-        this->impl_->set("ProcessorRange", {start, end},
-                         ContextType::Overwrite);
+        val.push_back({start, end});
     } else {
-        this->impl_->set("ProcessorRange", {start, end, step},
-                         ContextType::Overwrite);
+        val.push_back({start, end, step});
     }
+    this->impl_->set("ProcessorRange", val);
 }
 
 void PlannerContext::warp_range(int start, int end, int step) {
     check_range("WarpRange", {start, end, step});
+    Json val;
+    val.push_back(id());
     if (step == 1) {
-        this->impl_->set("WarpRange", {start, end}, ContextType::Overwrite);
+        val.push_back({start, end});
     } else {
-        this->impl_->set("WarpRange", {start, end, step},
-                         ContextType::Overwrite);
+        val.push_back({start, end, step});
     }
+    this->impl_->set("WarpRange", val);
 }
 
 void PlannerContext::sram_range(int start, int end, int step) {
     check_range("SramRange", {start, end, step});
+    Json val;
+    val.push_back(id());
     if (step == 1) {
-        this->impl_->set("SramRange", {start, end}, ContextType::Overwrite);
+        val.push_back({start, end});
    } else {
-        this->impl_->set("SramRange", {start, end, step},
-                         ContextType::Overwrite);
+        val.push_back({start, end, step});
    }
+    this->impl_->set("SramRange", val);
 }
 
 void PlannerContext::sync(bool sync) {
-    if (sync) {
-        // `true` should not overwrite `false`.
-        if (this->impl_->get("Sync") == Json(false)) {
-            LOG(WARN, "Ignoring sync(true) while sync(false) is already set");
-            return;
-        }
-        this->impl_->set("Sync", true, ContextType::Immutable);
-    } else {
-        this->impl_->set("Sync", false, ContextType::Overwrite);
-    }
+    // Sync should always be pushed together with Id.
+    Json val;
+    val.push_back(id());
+    val.push_back(sync);
+    this->impl_->set("Sync", val);
 }
 
 void PlannerContext::config(const std::string &config) {
-    this->impl_->set("Config", Json::parse(config), ContextType::Extend);
+    Json val;
+    val.push_back(id());
+    val.push_back(Json::parse(config));
+    this->impl_->set("Config", val);
 }
 
 class Planner::Impl {
@@ -132,27 +143,38 @@ std::string Planner::Impl::plan(bool pretty) const {
     size_t max_processor_id = 1;
     size_t max_warp_id = 1;
     size_t next_task_id = 0;
-    int prev_ctx_id = -1;
+    int merge_root = -1;
+    int processor_group_root = -1;
     bool first_op = true;
 
     auto get_context = [&](const ModelNodeRef &node,
                            const std::string &key) -> Json {
-        if (node->context.find(key) != node->context.end()) {
+        try {
             return node->context.at(key);
+        } catch (const Json::out_of_range &e) {
         }
-        return Json();
+        return Json::array();
+    };
+
+    auto get_latest_context = [&](const ModelNodeRef &node,
+                                  const std::string &key) -> Json {
+        auto ctx = get_context(node, key);
+        if (ctx.empty()) return Json();
+        return ctx.back();
     };
 
     for (const auto &node : model_.nodes()) {
         const auto &op = node->op;
         if (op->is_virtual()) continue;
 
-        auto ctx_config = get_context(node, "Config");
-
-        Json config;
-        if (!ctx_config.empty()) {
-            config = ctx_config;
-        } else if (!config_rules_.empty()) {
+        Json config = Json::object();
+        for (auto &obj : get_context(node, "Config")) {
+            auto &items = obj[1];
+            for (auto &item : items.items()) {
+                config[item.key()] = item.value();
+            }
+        }
+        if (config.empty() && !config_rules_.empty()) {
             const std::string op_str = op->serialize().dump();
             for (auto &rule : config_rules_) {
                 auto config_str = rule(op_str, gpu_info.arch->name());
@@ -166,18 +188,100 @@ std::string Planner::Impl::plan(bool pretty) const {
             config = op->default_config(gpu_info.arch);
         }
         check_config_field(op, config, "NumWarps");
-        check_config_field(op, config, "NumTasks");
         check_config_field(op, config, "SramBytes");
         size_t num_warps = config["NumWarps"];
-        size_t num_tasks = config["NumTasks"];
         size_t sram_bytes = config["SramBytes"];
+        size_t max_num_tasks = 0;
+        size_t num_tasks;
+
+        auto &result_tensors = op->result_tensors();
+        if (!result_tensors.empty() && config.contains("Tile")) {
+            const std::vector<DimType> tile_vec = config["Tile"];
+            std::vector<DimType> trim_leading_ones;
+            for (size_t i = 0; i < tile_vec.size(); i++) {
+                if (tile_vec[i] != 1) {
+                    trim_leading_ones = std::vector<DimType>(
+                        tile_vec.begin() + i, tile_vec.end());
+                    break;
+                }
+            }
+            if (trim_leading_ones.empty()) {
+                trim_leading_ones.push_back(1);
+            }
+            Dims tile(trim_leading_ones);
+
+            std::stringstream ss;
+            ss << "Result shape is not divided by tile " << tile
+               << ". Op: " << op->serialize().dump();
+            auto not_divided_error = ss.str();
+
+            auto &result_shape = result_tensors[0]->padded_shape();
+            if (result_shape.ndims() < tile.ndims()) {
+                ERR(PlanError, not_divided_error);
+            }
+            auto tile4 = tile.dims4();
+            auto result_shape4 = result_shape.dims4();
+            max_num_tasks = 1;
+            for (int i = 0; i < tile4.ndims(); i++) {
+                if (tile4[i] == 0 || result_shape4[i] % tile4[i] != 0) {
+                    ERR(PlanError, not_divided_error);
+                }
+                max_num_tasks *= result_shape4[i] / tile4[i];
+            }
+            if (max_num_tasks == 0) ERR(InternalError, "max_num_tasks == 0");
+        }
+        if (config.contains("NumTasks")) {
+            num_tasks = config["NumTasks"];
+            if (max_num_tasks > 0 && num_tasks > max_num_tasks) {
+                ERR(PlanError, "NumTasks (", num_tasks,
+                    ") exceeds the maximum number of tasks calculated from "
+                    "the tile (", max_num_tasks,
+                    "). Op: ", op->serialize().dump());
+            } else if (num_tasks < max_num_tasks) {
+                LOG(WARN, "NumTasks (", num_tasks,
+                    ") is less than the maximum number of tasks calculated "
+                    "from the tile (", max_num_tasks,
+                    "). Op: ", op->serialize().dump());
+            }
+        } else {
+            num_tasks = max_num_tasks;
+        }
+        if (num_tasks == 0 && op->type() != ModelOpT::from_name("Noop")) {
+            LOG(WARN,
+                "Detected a non-virtual op that does not perform any "
+                "computation. If this is unexpected, please check if "
+                "the config includes either `NumTasks` or `Tile` "
+                "field. Op: ",
+                op->serialize().dump());
+        }
         size_t granularity = config.value("Granularity", 1);
-        auto ctx_id = get_context(node, "Id");
-        auto ctx_sync = get_context(node, "Sync");
-        int id = ctx_id.empty() ? -1 : ctx_id.get<int>();
-        bool sync = ctx_sync.empty() ? true : ctx_sync.get<bool>();
-        if (id == prev_ctx_id && !sync) {
+        auto ctx_id_list = get_context(node, "Id");
+        auto ctx_sync_list = get_context(node, "Sync");
+        if (merge_root != -1) {
+            bool not_found = true;
+            for (auto ctx_id : ctx_id_list) {
+                if (ctx_id == merge_root) {
+                    not_found = false;
+                    break;
+                }
+            }
+            if (not_found) {
+                merge_root = -1;
+            }
+        }
+        bool merge_this_node = (merge_root != -1);
+        if (merge_root == -1) {
+            for (auto &item : ctx_sync_list) {
+                auto &ctx_id = item[0];
+                auto &sync = item[1];
+                if (!sync) {
+                    merge_root = ctx_id;
+                    break;
+                }
+            }
+        }
+        if (merge_this_node) {
             auto &task_info = task_infos.back();
             task_info["NumWarps"] =
                 std::max(task_info["NumWarps"].get<size_t>(), num_warps);
@@ -195,34 +299,53 @@ std::string Planner::Impl::plan(bool pretty) const {
             task_info["Ops"][0]["Config"] = config;
             task_infos.push_back(task_info);
 
-            auto ctx_processor_range = get_context(node, "ProcessorRange");
-            auto ctx_warp_range = get_context(node, "WarpRange");
-            auto ctx_sram_range = get_context(node, "SramRange");
+            auto ctx_processor_range_list =
+                get_context(node, "ProcessorRange");
+            auto ctx_warp_range = get_latest_context(node, "WarpRange");
+            auto ctx_sram_range = get_latest_context(node, "SramRange");
 
             Json processor_group;
-            if (!ctx_processor_range.empty()) {
-                processor_group["ProcessorRange"] = ctx_processor_range;
-                max_processor_id = std::max(
-                    max_processor_id, ctx_processor_range[1].get<size_t>());
-            } else {
+            Json resource_group;
+            bool new_processor_group = true;
+            bool id_found = false;
+            for (auto &item : ctx_processor_range_list) {
+                if (item[0] == processor_group_root) {
+                    id_found = true;
+                    break;
+                }
+            }
+            if (!id_found) {
+                processor_group_root = -1;
+            }
+            if (ctx_processor_range_list.size() > 2) {
+                ERR(UnsupportedError, "ProcessorRange list size > 2");
+            }
+            if (ctx_processor_range_list.empty()) {
                 size_t num_processors = std::min(num_sm, num_tasks);
                 processor_group["ProcessorRange"] = {0, num_processors};
+                resource_group["ProcessorRange"] = {0, num_processors};
                 max_processor_id = std::max(max_processor_id, num_processors);
+            } else if (processor_group_root == -1) {
+                processor_group_root = ctx_processor_range_list.front()[0];
+                processor_group["ProcessorRange"] =
+                    ctx_processor_range_list.front()[1];
+                resource_group["ProcessorRange"] =
+                    ctx_processor_range_list.back()[1];
+                max_processor_id = std::max(
+                    max_processor_id,
+                    ctx_processor_range_list.front()[1][1].get<size_t>());
+            } else {
+                new_processor_group = false;
+                resource_group["ProcessorRange"] =
+                    ctx_processor_range_list.back()[1];
             }
 
-            Json resource_group;
-            resource_group["ProcessorRange"] =
-                processor_group["ProcessorRange"];
             if (!ctx_warp_range.empty()) {
-                resource_group["WarpRange"] = ctx_warp_range;
+                resource_group["WarpRange"] = ctx_warp_range[1];
                 max_warp_id =
-                    std::max(max_warp_id, ctx_warp_range[1].get<size_t>());
+                    std::max(max_warp_id, ctx_warp_range[1][1].get<size_t>());
             } else {
                 resource_group["WarpRange"] = {0, num_warps};
                 max_warp_id = std::max(max_warp_id, num_warps);
             }
             if (!ctx_sram_range.empty()) {
-                resource_group["SramRange"] = ctx_sram_range;
+                resource_group["SramRange"] = ctx_sram_range[1];
             } else {
                 resource_group["SramRange"] = {0, sram_bytes};
             }
@@ -230,11 +353,15 @@ std::string Planner::Impl::plan(bool pretty) const {
                 {"TaskRange", {0, num_tasks}},
                 {"Granularity", granularity}}};
 
-            processor_group["ResourceGroups"] = Json::array();
-            processor_group["ResourceGroups"].push_back(resource_group);
-            processor_groups.push_back(processor_group);
+            if (new_processor_group) {
+                processor_group["ResourceGroups"] = Json::array();
+                processor_group["ResourceGroups"].push_back(resource_group);
+                processor_groups.push_back(processor_group);
+            } else {
+                processor_groups.back()["ResourceGroups"].push_back(
+                    resource_group);
+            }
         }
-        prev_ctx_id = id;
         first_op = false;
     }
diff --git a/ark/api/planner_test.cpp b/ark/api/planner_test.cpp
index 011b25d8d..919ba2f1d 100644
--- a/ark/api/planner_test.cpp
+++ b/ark/api/planner_test.cpp
@@ -7,57 +7,93 @@
 #include "unittest/unittest_utils.h"
 
 ark::unittest::State test_planner_context_processor_range() {
-    ark::Model model;
-    ark::Tensor t0 = model.tensor({1}, ark::FP32);
-    ark::Tensor t1 = model.tensor({1}, ark::FP32);
-
-    // node 0
-    ark::Tensor t2 = model.add(t0, t1);
-
-    ark::Tensor t3;
-    ark::Tensor t4;
-    ark::Tensor t5;
     {
-        // node 1
-        ark::PlannerContext ctx(model);
-        ctx.processor_range(0, 4);
-        t3 = model.relu(t2);
-
-        UNITTEST_EQ(ctx.get("ProcessorRange"), ark::Json({0, 4}).dump());
-
-        // node 2
-        ctx.processor_range(2, 4);
-        t4 = model.sqrt(t3);
-
-        UNITTEST_EQ(ctx.get("ProcessorRange"), ark::Json({2, 4}).dump());
-
-        // Invalid usage: range (0, 4) is out of previous range (2, 4)
-        UNITTEST_THROW(ctx.processor_range(0, 4), ark::PlanError);
+        ark::Model model;
+        ark::Tensor t0 = model.tensor({1}, ark::FP32);
+        ark::Tensor t1 = model.tensor({1}, ark::FP32);
+
+        // node 0
+        ark::Tensor t2 = model.add(t0, t1);
+
+        ark::Tensor t3;
+        ark::Tensor t4;
+        ark::Tensor t5;
+        {
+            // node 1
+            ark::PlannerContext ctx(model);
+            ctx.processor_range(0, 4);
+            t3 = model.relu(t2);
+
+            UNITTEST_EQ(ctx.get("ProcessorRange"), ark::Json({0, 4}).dump());
+
+            // node 2
+            ctx.processor_range(2, 4);
+            t4 = model.sqrt(t3);
+
+            UNITTEST_EQ(ctx.get("ProcessorRange"), ark::Json({2, 4}).dump());
+
+            // Invalid usage: range (0, 4) is out of previous range (2, 4)
+            UNITTEST_THROW(ctx.processor_range(0, 4), ark::PlanError);
+        }
+        {
+            // node 3
+            ark::PlannerContext ctx(model);
+            ctx.processor_range(2, 6, 2);
+            t5 = model.exp(t2);
+
+            UNITTEST_EQ(ctx.get("ProcessorRange"),
+                        ark::Json({2, 6, 2}).dump());
+        }
+
+        UNITTEST_TRUE(model.verify());
+
+        auto compressed = model.compress();
+        UNITTEST_TRUE(compressed.verify());
+
+        auto nodes = compressed.nodes();
+        UNITTEST_EQ(nodes.size(), 4);
+
+        UNITTEST_EQ(nodes[0]->context.size(), 0);
+        UNITTEST_GE(nodes[1]->context.size(), 1);
+        UNITTEST_EQ(nodes[1]->context.at("ProcessorRange"), ark::Json({0, 4}));
+        UNITTEST_GE(nodes[2]->context.size(), 1);
+        UNITTEST_EQ(nodes[2]->context.at("ProcessorRange"), ark::Json({2, 4}));
+        UNITTEST_GE(nodes[3]->context.size(), 1);
+        UNITTEST_EQ(nodes[3]->context.at("ProcessorRange"),
+                    ark::Json({2, 6, 2}));
     }
     {
-        // node 3
+        ark::Model model;
+        ark::Tensor t0 = model.tensor({1}, ark::FP32);
+        ark::Tensor 
t1 = model.tensor({1}, ark::FP32); + ark::PlannerContext ctx(model); - ctx.processor_range(2, 6, 2); - t5 = model.exp(t2); + ctx.processor_range(0, 10); - UNITTEST_EQ(ctx.get("ProcessorRange"), ark::Json({2, 6, 2}).dump()); - } + std::vector tensors; + for (size_t i = 0; i < 5; ++i) { + ark::PlannerContext subctx(model); + subctx.processor_range(0 * i, 2 * i); + auto t = model.add(t0, t1); + tensors.push_back(t); - UNITTEST_TRUE(model.verify()); + UNITTEST_EQ(ctx.get("ProcessorRange"), + ark::Json({0 * i, 2 * i}).dump()); + } - auto compressed = model.compress(); - UNITTEST_TRUE(compressed.verify()); + UNITTEST_TRUE(model.verify()); - auto nodes = compressed.nodes(); - UNITTEST_EQ(nodes.size(), 4); + auto compressed = model.compress(); + UNITTEST_TRUE(compressed.verify()); - UNITTEST_EQ(nodes[0]->context.size(), 0); - UNITTEST_GE(nodes[1]->context.size(), 1); - UNITTEST_EQ(nodes[1]->context.at("ProcessorRange"), ark::Json({0, 4})); - UNITTEST_GE(nodes[2]->context.size(), 1); - UNITTEST_EQ(nodes[2]->context.at("ProcessorRange"), ark::Json({2, 4})); - UNITTEST_GE(nodes[3]->context.size(), 1); - UNITTEST_EQ(nodes[3]->context.at("ProcessorRange"), ark::Json({2, 6, 2})); + auto nodes = compressed.nodes(); + UNITTEST_EQ(nodes.size(), 5); + + for (size_t i = 0; i < 5; ++i) { + UNITTEST_GE(nodes[i]->context.size(), 1); + UNITTEST_EQ(nodes[i]->context.at("ProcessorRange"), + ark::Json({0 * i, 2 * i})); + } + } return ark::unittest::SUCCESS; } diff --git a/ark/api/tensor.cpp b/ark/api/tensor.cpp index 4b03c3ac8..103fb8896 100644 --- a/ark/api/tensor.cpp +++ b/ark/api/tensor.cpp @@ -3,6 +3,7 @@ #include "ark/tensor.hpp" +#include "model/model_buffer.hpp" #include "model/model_data_type.hpp" #include "model/model_tensor.hpp" @@ -50,6 +51,44 @@ const DataType &Tensor::data_type() const { return NONE; } +Dims Tensor::torch_strides() const { + if (ref_) { + Dims st = ref_->strides(); + int ndims = st.ndims(); + std::vector tmp; + for (int i = 1; i < ndims; ++i) { + tmp.push_back(st[i]); + } + tmp.push_back(1); + for (int i = ndims - 2; i >= 0; --i) { + tmp[i] *= tmp[i + 1]; + } + return Dims(tmp); + } + return Dims(); +} + +void *Tensor::data() const { + if (ref_) { + return ref_->data(); + } + return nullptr; +} + +void *Tensor::data(void *data) { + if (ref_) { + return ref_->data(data); + } + return nullptr; +} + +bool Tensor::is_external() const { + if (ref_) { + return ref_->is_external(); + } + return false; +} + std::ostream &operator<<(std::ostream &os, const Tensor &tensor) { if (tensor.is_null()) { os << "null"; diff --git a/ark/buffer_registry.cpp b/ark/buffer_registry.cpp new file mode 100644 index 000000000..00c5ea28e --- /dev/null +++ b/ark/buffer_registry.cpp @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
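Editor's note: taken together, the planner and test changes above define how nested PlannerContext objects compose: each key now records a stack of (context id, value) pairs, the innermost entry wins for WarpRange/SramRange, and at most two enclosing ProcessorRange entries may be active at once. A hedged usage sketch (shapes and ranges are illustrative, not from the source):

    ark::Model model;
    ark::Tensor x = model.tensor({1024}, ark::FP32);
    ark::Tensor y = model.tensor({1024}, ark::FP32);

    ark::PlannerContext outer(model);
    outer.processor_range(0, 8);      // processor group: SMs [0, 8)
    {
        ark::PlannerContext inner(model);
        inner.processor_range(0, 4);  // nested resource group inside [0, 8)
        inner.sync(false);            // let consecutive ops share one task
        ark::Tensor z = model.add(x, y);
        model.relu(z);                // merged into the add's task slot
    }

With sync(false), the planner's merge_root logic folds consecutive nodes that carry the same context id into a single task info, taking the maximum of their NumWarps and SramBytes.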
+ +#include "buffer_registry.hpp" + +#include "gpu/gpu_logging.hpp" + +namespace ark { + +BufferRegistry &BufferRegistry::get_instance() { + static BufferRegistry instance; + return instance; +} + +void BufferRegistry::set(size_t id, void *data, int device_id, + bool is_external) { + if (data != nullptr && device_id < 0) { + gpuPointerAttributes attr; + GLOG(gpuPointerGetAttributes(&attr, data)); + device_id = attr.device; + } + buffers_[id] = + std::make_shared(data, device_id, is_external); +} + +std::shared_ptr BufferRegistry::get(size_t id) const { + auto it = buffers_.find(id); + if (it != buffers_.end()) { + return it->second; + } + return nullptr; +} + +} // namespace ark diff --git a/ark/buffer_registry.hpp b/ark/buffer_registry.hpp new file mode 100644 index 000000000..81a26e722 --- /dev/null +++ b/ark/buffer_registry.hpp @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef ARK_BUFFER_REGISTRY_HPP_ +#define ARK_BUFFER_REGISTRY_HPP_ + +#include +#include + +namespace ark { + +/// Manages addresses of all allocated buffers including externally managed +/// buffers. +class BufferRegistry { + public: + struct Info { + Info(void *data, int device_id, bool is_external) + : data(data), device_id(device_id), is_external(is_external) {} + void *data; + int device_id; + bool is_external; + }; + + ~BufferRegistry() {} + + static BufferRegistry &get_instance(); + + void set(size_t id, void *data, int device_id, bool is_external); + + std::shared_ptr get(size_t id) const; + + private: + std::unordered_map> buffers_; + BufferRegistry() {} + BufferRegistry(const BufferRegistry &) = delete; + BufferRegistry &operator=(const BufferRegistry &) = delete; +}; + +} // namespace ark + +#endif // ARK_BUFFER_REGISTRY_HPP_ diff --git a/ark/codegen.cpp b/ark/codegen.cpp index 54214277d..7ab2f5635 100644 --- a/ark/codegen.cpp +++ b/ark/codegen.cpp @@ -4,8 +4,10 @@ #include "codegen.hpp" #include +#include #include "ark/data_type.hpp" +#include "buffer_registry.hpp" #include "env.h" #include "file_io.h" #include "logging.hpp" @@ -24,7 +26,18 @@ static std::string replace( size_t pos = 0; while ((pos = result.find(kv.first, pos)) != std::string::npos) { result.replace(pos, kv.first.length(), kv.second); - pos += kv.second.length(); + if ((kv.first == "@GLOBAL_ARGS@" || kv.first == "@FUNCTION_ARGS@" || + kv.first == "@ARG_TYPES@") && + kv.second.empty()) { + size_t comma_pos = pos; + if (comma_pos >= 2 && result.substr(comma_pos - 2, 2) == ", ") { + result.erase(comma_pos - 2, 2); + pos -= 2; + } + + } else { + pos += kv.second.length(); + } } } return result; @@ -43,11 +56,11 @@ class CodeGenerator::Impl { public: Impl(const PlanJson &plan, const std::map &buffer_id_to_offset, - const std::string &name); + const std::set &extra_buffer_ids, const std::string &name); ~Impl() = default; private: - std::string def_op(const Json &op_json, size_t task_id, size_t op_idx); + std::pair def_op(const Json &op_json); std::string def_task(const Json &task_json); @@ -67,7 +80,10 @@ class CodeGenerator::Impl { protected: friend class CodeGenerator; + std::set op_hashes_; + std::set task_hashes_; std::map buffer_id_to_offset_; + std::set extra_buffer_ids_; std::string name_; int rank_; int world_size_; @@ -78,8 +94,11 @@ class CodeGenerator::Impl { CodeGenerator::Impl::Impl(const PlanJson &plan, const std::map &buffer_id_to_offset, + const std::set &extra_buffer_ids, const std::string &name) - : buffer_id_to_offset_(buffer_id_to_offset), name_(name) { + : 
buffer_id_to_offset_(buffer_id_to_offset), + extra_buffer_ids_(extra_buffer_ids), + name_(name) { rank_ = plan.at("Rank"); world_size_ = plan.at("WorldSize"); num_procs_ = plan.at("NumProcessors"); @@ -166,83 +185,170 @@ CodeGenerator::Impl::Impl(const PlanJson &plan, const std::string &template_path = ark_root + "/include/kernels/kernel_template.in"; if (!is_file(template_path)) { - ERR(InternalError, "kernel template file not found: ", template_path); + ERR(InvalidUsageError, + "kernel template file not found: ", template_path, + ". Please make sure the ARK_ROOT environment variable is set " + "correctly."); + } + + // Generate the global arguments + std::stringstream global_args_ss, function_args_ss, arg_types_ss; + for (auto buf_id : extra_buffer_ids_) { + std::string arg_name = "_ext_buf_" + std::to_string(buf_id); + global_args_ss << "void *" << arg_name << ", "; + function_args_ss << arg_name << ", "; + arg_types_ss << "void *, "; + } + std::string global_args = global_args_ss.str(); + std::string function_args = function_args_ss.str(); + std::string arg_types = arg_types_ss.str(); + if (!global_args.empty()) { + global_args.pop_back(); + global_args.pop_back(); } + if (!function_args.empty()) { + function_args.pop_back(); + function_args.pop_back(); + } + if (!arg_types.empty()) { + arg_types.pop_back(); + arg_types.pop_back(); + } + std::string template_code = read_file(template_path); std::map replacements = { {"@NUM_BLOCKS@", std::to_string(num_procs_)}, {"@NUM_WARPS_PER_BLOCK@", std::to_string(num_warps_per_proc_)}, {"@DEFINITIONS@", definitions_ss.str()}, {"@BODY@", body_ss.str()}, - {"@NAME@", (name_.empty() ? "" : "_" + name_)}, + {"@NAME@", (!name_.empty() ? "" : name_)}, + {"@GLOBAL_ARGS@", global_args}, + {"@FUNCTION_ARGS@", function_args}, + {"@ARG_TYPES@", arg_types}, }; code_ = replace(template_code, replacements); } -std::string CodeGenerator::Impl::def_op(const Json &op_json, size_t task_id, - size_t op_idx) { +std::pair CodeGenerator::Impl::def_op( + const Json &op_json) { auto op = ModelOp::deserialize(op_json); auto impl_name = op->impl_name(op_json["Config"]); auto impl_args = op->impl_args(op_json["Config"]); - std::stringstream ss; - ss << "__forceinline__ __device__ void t" << task_id << "_o" << op_idx - << "("; + std::stringstream ss_desc; size_t arg_idx = 0; for (auto &arg : impl_args) { if (arg.type_name() == "TENSOR") { auto tns = arg.value(); - ss << tns->data_type()->type_str() << "*"; + ss_desc << tns->data_type()->type_str() << "*"; } else if (arg.type_name() == "OFFSET") { - ss << "uint64_t"; + ss_desc << "uint64_t"; } else { - ss << arg.type_str(); + ss_desc << arg.type_str(); } - ss << " _" << arg_idx++ << ", "; + ss_desc << " _" << arg_idx++ << ", "; } - ss << "int _idx, int _spw) {\n " << impl_name << "("; + ss_desc << "int _idx, int _spw) {\n " << impl_name << "("; for (size_t i = 0; i < impl_args.size(); ++i) { - ss << "_" << i << ", "; + ss_desc << "_" << i << ", "; } - ss << "_idx, _spw);\n}\n"; - return ss.str(); + ss_desc << "_idx, _spw);\n}\n"; + auto desc_str = ss_desc.str(); + size_t op_hash = std::hash{}(desc_str); + std::stringstream ss; + ss << "__forceinline__ __device__ void __op_" << std::hex << op_hash + << std::dec << "("; + ss << desc_str; + return {ss.str(), op_hash}; } std::string CodeGenerator::Impl::def_task(const Json &task_json) { std::stringstream ss; - size_t op_idx = 0; + std::stringstream ss_hash_concat; + std::vector op_hash_list; for (auto &op_json : task_json["Ops"]) { - ss << this->def_op(op_json, task_json["Id"], 
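Editor's note: the new BufferRegistry is a process-wide singleton keyed by buffer id; when a caller registers a raw device pointer without a device id, the device is inferred via gpuPointerGetAttributes. A hedged sketch of the intended use (the buffer id 42 is an arbitrary example, not from the source):

    #include "buffer_registry.hpp"

    void register_external(void *dev_ptr) {
        auto &reg = ark::BufferRegistry::get_instance();
        // device_id == -1: let the registry query the pointer's device.
        reg.set(/*id=*/42, dev_ptr, /*device_id=*/-1, /*is_external=*/true);

        auto info = reg.get(42);
        if (info && info->is_external) {
            // Codegen will route this tensor through an "_ext_buf_42"
            // kernel argument instead of a fixed offset into the
            // internal buffer.
        }
    }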
op_idx++); + auto [def_str, hash] = this->def_op(op_json); + if (op_hashes_.find(hash) == op_hashes_.end()) { + ss << def_str; + op_hashes_.insert(hash); + } + ss_hash_concat << std::hex << hash; + op_hash_list.push_back(hash); } - ss << "__device__ void t" << task_json["Id"] - << "(char* _buf, int _idx, int _spw) {\n"; - op_idx = 0; + size_t task_hash = std::hash{}(ss_hash_concat.str()); + std::stringstream ss_desc; + auto &buf_reg = BufferRegistry::get_instance(); + size_t op_idx = 0; + std::map ptr_str_to_index; + std::vector ptr_str_list; for (auto &op_json : task_json["Ops"]) { auto op = ModelOp::deserialize(op_json); auto impl_args = op->impl_args(op_json["Config"]); - ss << " t" << task_json["Id"] << "_o" << op_idx++ << "("; - for (size_t i = 0; i < impl_args.size(); ++i) { - auto &arg = impl_args[i]; + ss_desc << " __op_" << std::hex << op_hash_list[op_idx++] << std::dec + << "("; + for (auto &arg : impl_args) { if (arg.type_name() == "TENSOR") { auto tns = arg.value(); - size_t buffer_offset = - buffer_id_to_offset_.at(tns->buffer()->id()); - size_t offset = buffer_offset + ModelOffset(tns).value(); - ss << "(" << tns->data_type()->type_str() << "*)&_buf[" - << offset << "]"; + size_t buffer_id = tns->buffer()->id(); + auto it = buffer_id_to_offset_.find(buffer_id); + auto buf_info = buf_reg.get(buffer_id); + std::string ptr_str; + if ((buf_info && buf_info->is_external) || + (it == buffer_id_to_offset_.end())) { + ptr_str = "_ext_buf_" + std::to_string(buffer_id); + } else { + size_t buffer_offset; + buffer_offset = it->second; + size_t offset = buffer_offset + ModelOffset(tns).value(); + ptr_str = "&_buf[" + std::to_string(offset) + "]"; + } + size_t ptr_idx; + if (ptr_str_to_index.find(ptr_str) == ptr_str_to_index.end()) { + ptr_idx = ptr_str_to_index.size(); + ptr_str_to_index[ptr_str] = ptr_idx; + ptr_str_list.push_back(ptr_str); + } else { + ptr_idx = ptr_str_to_index[ptr_str]; + } + ss_desc << "(" << tns->data_type()->type_str() << "*)_" + << ptr_idx; } else if (arg.type_name() == "OFFSET") { auto moff = arg.value(); - size_t buffer_offset = - buffer_id_to_offset_.at(moff.buffer_id()); + size_t buffer_id = moff.buffer_id(); + auto buf_info = buf_reg.get(buffer_id); + if (buf_info && buf_info->is_external) { + ERR(InternalError, "cannot offset external buffer"); + } + size_t buffer_offset; + auto it = buffer_id_to_offset_.find(buffer_id); + if (it == buffer_id_to_offset_.end()) { + ERR(InternalError, "buffer ID not found: ", buffer_id); + } + buffer_offset = it->second; size_t offset = buffer_offset + moff.value(); - ss << offset; + ss_desc << offset; } else { - ss << arg.serialize().begin().value(); + ss_desc << arg.serialize().begin().value(); } - ss << ", "; + ss_desc << ", "; + } + ss_desc << "_idx, _spw);\n"; + } + if (task_hashes_.find(task_hash) == task_hashes_.end()) { + ss << "__device__ void __task_" << std::hex << task_hash << std::dec + << "("; + for (size_t i = 0; i < ptr_str_list.size(); ++i) { + ss << "void *_" << i << ", "; } - ss << "_idx, _spw);\n"; + ss << "int _idx, int _spw) {\n" << ss_desc.str() << "}\n"; + task_hashes_.insert(task_hash); } - ss << "}\n"; + ss << "__forceinline__ __device__ void __t" << task_json["Id"] + << "(char *_buf, int _idx, int _spw, @GLOBAL_ARGS@) {\n"; + ss << " __task_" << std::hex << task_hash << std::dec << "("; + for (auto &ptr_str : ptr_str_list) { + ss << ptr_str << ", "; + } + ss << "_idx, _spw);\n}\n"; return ss.str(); } @@ -265,7 +371,8 @@ std::string CodeGenerator::Impl::task_seq( ss << "task_seq<" << proc_b << ", " 
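Editor's note: to make the hash-based deduplication in def_op/def_task concrete, this is roughly the shape of the code the generator now emits for one task. Hash values, the single tensor argument, and some_op_impl are hypothetical:

    // One definition per distinct op signature, named by its hash.
    __forceinline__ __device__ void __op_3f91a2(float *_0, int _idx,
                                                int _spw) {
        some_op_impl(_0, _idx, _spw);
    }

    // One definition per distinct op sequence; tensor pointers are
    // deduplicated and passed as void* so identical tasks share one body.
    __device__ void __task_9c04b7(void *_0, int _idx, int _spw) {
        __op_3f91a2((float *)_0, _idx, _spw);
    }

    // Thin per-task-id wrapper that binds concrete pointers; external
    // buffers arrive through @GLOBAL_ARGS@ as _ext_buf_<id> parameters.
    __forceinline__ __device__ void __t0(char *_buf, int _idx, int _spw,
                                         void *_ext_buf_42) {
        __task_9c04b7(_ext_buf_42, _idx, _spw);
    }

Identical op sequences across task ids thus collapse to one __task_ body, and only the cheap __t wrappers are duplicated.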
<< proc_e << ", " << proc_s << ", " << proc_cur << ", " << task_b << ", " << task_e << ", " << task_s << ", " << task_gran << ", " << num_slots << ", " << slot_num_warps << ", " - << slot_sram_bytes << ", t" << task_id << ">(_buf);\n"; + << slot_sram_bytes << ", __t" << task_id + << ">(_buf, @FUNCTION_ARGS@);\n"; return ss.str(); } @@ -288,10 +395,14 @@ std::string CodeGenerator::Impl::resource_group( size_t proc_b = *rg_proc_range.begin(); size_t proc_e = *rg_proc_range.end(); size_t proc_s = rg_proc_range.step(); + std::map task_infos_map; + for (auto &task_info : task_infos) { + task_infos_map[task_info.at("Id").get()] = task_info; + } std::stringstream ss; for (auto &tg : rg_json["TaskGroups"]) { size_t task_id = tg["TaskId"]; - auto &task_info = task_infos[task_id]; + auto &task_info = task_infos_map.at(task_id); Range task_range(tg["TaskRange"][0], tg["TaskRange"][1]); size_t task_gran = tg["Granularity"]; size_t num_warps_per_task = task_info["NumWarps"]; @@ -305,7 +416,7 @@ std::string CodeGenerator::Impl::resource_group( n_slots = total_warps / num_warps_per_task; } if (n_slots == 0) { - ERR(PlanError, "not enough resources for task group"); + ERR(PlanError, "not enough resources for task group: ", tg.dump()); } size_t task_b = *task_range.begin(); @@ -430,8 +541,9 @@ std::string CodeGenerator::Impl::sync_process_range(const Range &range, CodeGenerator::CodeGenerator( const PlanJson &plan, const std::map &buffer_id_to_offset, - const std::string &name) - : impl_(std::make_shared(plan, buffer_id_to_offset, name)) {} + const std::set &extra_buffer_ids, const std::string &name) + : impl_(std::make_shared(plan, buffer_id_to_offset, extra_buffer_ids, + name)) {} std::string CodeGenerator::code() const { return impl_->code_; } diff --git a/ark/codegen.hpp b/ark/codegen.hpp index 4f8307e7e..9f5947deb 100644 --- a/ark/codegen.hpp +++ b/ark/codegen.hpp @@ -6,6 +6,7 @@ #include #include +#include #include #include "model/model_json.hpp" @@ -16,6 +17,7 @@ class CodeGenerator { public: CodeGenerator(const PlanJson &plan, const std::map &buffer_id_to_offset, + const std::set &extra_buffer_ids, const std::string &name = "ark_kernel"); ~CodeGenerator() = default; diff --git a/ark/context_impl.cpp b/ark/context_impl.cpp index 9a2692ea8..c4f95f2c3 100644 --- a/ark/context_impl.cpp +++ b/ark/context_impl.cpp @@ -52,4 +52,8 @@ bool Context::Impl::has(const std::string& key) const { return context_manager_->has(key); } +Json Context::Impl::dump() const { + return context_manager_->dump(); +} + } // namespace ark diff --git a/ark/context_impl.hpp b/ark/context_impl.hpp index 1a77891b9..b79353296 100644 --- a/ark/context_impl.hpp +++ b/ark/context_impl.hpp @@ -17,10 +17,12 @@ class Context::Impl { Json get(const std::string& key) const; - void set(const std::string& key, const Json& value_json, ContextType type); + void set(const std::string& key, const Json& value_json, ContextType type = ContextType::Overwrite); bool has(const std::string& key) const; + Json dump() const; + protected: friend class Context; diff --git a/ark/cpu_timer.cpp b/ark/cpu_timer.cpp index c740de5f3..129ba7bd2 100644 --- a/ark/cpu_timer.cpp +++ b/ark/cpu_timer.cpp @@ -16,20 +16,4 @@ double cpu_timer(void) { return (tspec.tv_nsec / 1.0e9) + tspec.tv_sec; } -// Sleep in second. -int cpu_timer_sleep(double sec) { - struct timespec tspec; - tspec.tv_sec = (time_t)sec; - tspec.tv_nsec = (long)((sec - tspec.tv_sec) * 1.0e9); - return nanosleep(&tspec, 0); -} - -// Sleep in nanosecond. 
-int cpu_ntimer_sleep(long nsec) { - struct timespec tspec; - tspec.tv_sec = 0; - tspec.tv_nsec = nsec; - return nanosleep(&tspec, 0); -} - } // namespace ark diff --git a/ark/cpu_timer.h b/ark/cpu_timer.h index 52bf63d92..eaac94061 100644 --- a/ark/cpu_timer.h +++ b/ark/cpu_timer.h @@ -8,10 +8,6 @@ namespace ark { // Measure current time in second. double cpu_timer(void); -// Sleep in second. -int cpu_timer_sleep(double sec); -// Sleep in nanosecond. -int cpu_ntimer_sleep(long nsec); } // namespace ark diff --git a/ark/env.cpp b/ark/env.cpp index d8322378f..f9e7355ff 100644 --- a/ark/env.cpp +++ b/ark/env.cpp @@ -10,11 +10,11 @@ #define DEFAULT_ARK_LOG_LEVEL "INFO" #define DEFAULT_ARK_ROOT "/usr/local/ark" #define DEFAULT_ARK_TMP "/tmp/ark" -#define DEFAULT_ARK_KEEP_TMP true +#define DEFAULT_ARK_KEEP_TMP false #define DEFAULT_ARK_HOSTFILE_NAME "hostfile" #define DEFAULT_ARK_NUM_RANKS_PER_HOST 8 #define DEFAULT_ARK_DISABLE_IB false -#define DEFAULT_ARK_IGNORE_BINARY_CACHE true +#define DEFAULT_ARK_IGNORE_BINARY_CACHE false #define DEFAULT_ARK_ENFORCE_PLAN_PATH "" #define DEFAULT_ARK_MSCCLPP_PORT 50051 diff --git a/ark/gpu/gpu.hpp b/ark/gpu/gpu.hpp index 531d6c7ee..dbcd50f3e 100644 --- a/ark/gpu/gpu.hpp +++ b/ark/gpu/gpu.hpp @@ -53,6 +53,8 @@ ARK_GPU_DEFINE_TYPE_ALIAS(gpuModule, CUmodule, hipModule_t); ARK_GPU_DEFINE_TYPE_ALIAS(gpuFunction, CUfunction, hipFunction_t); ARK_GPU_DEFINE_TYPE_ALIAS(gpuFunctionAttribute, CUfunction_attribute, hipFunction_attribute); +ARK_GPU_DEFINE_TYPE_ALIAS(gpuPointerAttributes, cudaPointerAttributes, + hipPointerAttribute_t); // runtime API ARK_GPU_DEFINE_CONSTANT_ALIAS(gpuSuccess, cudaSuccess, hipSuccess); @@ -126,6 +128,8 @@ ARK_GPU_DEFINE_CONSTANT_ALIAS(gpuPointerAttributeSyncMemops, ARK_GPU_DEFINE_FUNC_ALIAS(gpuGetErrorString, cudaGetErrorString, hipGetErrorString); ARK_GPU_DEFINE_FUNC_ALIAS(gpuGetLastError, cudaGetLastError, hipGetLastError); +ARK_GPU_DEFINE_FUNC_ALIAS(gpuPointerGetAttributes, cudaPointerGetAttributes, + hipPointerGetAttributes); ARK_GPU_DEFINE_FUNC_ALIAS(gpuDeviceGetAttribute, cudaDeviceGetAttribute, hipDeviceGetAttribute); ARK_GPU_DEFINE_FUNC_ALIAS(gpuDeviceSynchronize, cudaDeviceSynchronize, diff --git a/ark/gpu/gpu_event.cpp b/ark/gpu/gpu_event.cpp index 06779b91a..9f91e384d 100644 --- a/ark/gpu/gpu_event.cpp +++ b/ark/gpu/gpu_event.cpp @@ -7,21 +7,25 @@ #include "gpu/gpu_manager.hpp" namespace ark { + class GpuEvent::Impl { public: - Impl(bool disable_timing); + Impl(int device_id, bool disable_timing); ~Impl(); Impl(const Impl&) = delete; Impl& operator=(const Impl&) = delete; + int device_id() const { return device_id_; } void record(gpuStream stream); float elapsed_msec(const GpuEvent& other) const; private: + int device_id_; gpuEvent event_; }; -GpuEvent::Impl::Impl(bool disable_timing) { +GpuEvent::Impl::Impl(int device_id, bool disable_timing) + : device_id_(device_id) { unsigned int flags = 0; if (disable_timing) { flags |= gpuEventDisableTiming; @@ -41,8 +45,10 @@ float GpuEvent::Impl::elapsed_msec(const GpuEvent& other) const { return elapsed; } -GpuEvent::GpuEvent(bool disable_timing) - : pimpl_(std::make_shared(disable_timing)) {} +GpuEvent::GpuEvent(int device_id, bool disable_timing) + : pimpl_(std::make_shared(device_id, disable_timing)) {} + +int GpuEvent::device_id() const { return pimpl_->device_id(); } void GpuEvent::record(gpuStream stream) { pimpl_->record(stream); } diff --git a/ark/gpu/gpu_event.hpp b/ark/gpu/gpu_event.hpp index bd2a7c952..2180f1320 100644 --- a/ark/gpu/gpu_event.hpp +++ b/ark/gpu/gpu_event.hpp 
@@ -19,13 +19,14 @@ class GpuEvent { GpuEvent(const GpuEvent &) = delete; GpuEvent &operator=(const GpuEvent &) = delete; + int device_id() const; void record(gpuStream stream); float elapsed_msec(const GpuEvent &other) const; protected: friend class GpuManager; - GpuEvent(bool disable_timing = false); + GpuEvent(int device_id, bool disable_timing = false); private: class Impl; diff --git a/ark/gpu/gpu_kernel.cpp b/ark/gpu/gpu_kernel.cpp index d4412f80e..a474b32a7 100644 --- a/ark/gpu/gpu_kernel.cpp +++ b/ark/gpu/gpu_kernel.cpp @@ -15,24 +15,18 @@ namespace ark { GpuKernel::GpuKernel(int gpu_id, const std::string& code, const std::array& block_dim, - const std::array& grid_dim, size_t smem_bytes, - const std::string& kernel_name) { - this->init(gpu_id, code, block_dim, grid_dim, smem_bytes, kernel_name); + const std::array& grid_dim, size_t smem_bytes) { + this->init(gpu_id, code, block_dim, grid_dim, smem_bytes); } void GpuKernel::init(int gpu_id, const std::string& code, const std::array& block_dim, - const std::array& grid_dim, size_t smem_bytes, - const std::string& kernel_name) { + const std::array& grid_dim, size_t smem_bytes) { gpu_manager_ = GpuManager::get_instance(gpu_id); code_ = code; block_dim_ = block_dim; grid_dim_ = grid_dim; smem_bytes_ = smem_bytes; - kernel_name_ = kernel_name; - if (kernel_name_.size() == 0) { - ERR(InvalidUsageError, "Invalid kernel name: ", kernel_name_); - } } void GpuKernel::compile() { @@ -45,21 +39,30 @@ void GpuKernel::compile() { } bin_ = gpu_compile({code_}, gpu_manager_->info().arch, max_reg_cnt); GLOG_DRV(gpuModuleLoadData(&module_, bin_.c_str())); - GLOG_DRV(gpuModuleGetFunction(&function_, module_, kernel_name_.c_str())); - - int static_smem_size_bytes; - GLOG_DRV(gpuFuncGetAttribute(&static_smem_size_bytes, - gpuFuncAttributeSharedSizeBytes, function_)); - int dynamic_smem_size_bytes = smem_bytes_ - static_smem_size_bytes; - GLOG_DRV(gpuFuncSetAttribute(function_, - gpuFuncAttributeMaxDynamicSharedSizeBytes, - dynamic_smem_size_bytes)); } -void GpuKernel::launch(gpuStream stream, std::vector& args) { +void GpuKernel::launch(const std::string& kernel_name, gpuStream stream, + std::vector& args) { if (!this->is_compiled()) { ERR(InvalidUsageError, "Kernel is not compiled yet."); } + if (kernel_name.size() == 0) { + ERR(InvalidUsageError, "Invalid kernel name: ", kernel_name); + } + if (kernel_name_ != kernel_name) { + GLOG_DRV( + gpuModuleGetFunction(&function_, module_, kernel_name.c_str())); + + int static_smem_size_bytes; + GLOG_DRV(gpuFuncGetAttribute(&static_smem_size_bytes, + gpuFuncAttributeSharedSizeBytes, + function_)); + int dynamic_smem_size_bytes = smem_bytes_ - static_smem_size_bytes; + GLOG_DRV(gpuFuncSetAttribute(function_, + gpuFuncAttributeMaxDynamicSharedSizeBytes, + dynamic_smem_size_bytes)); + kernel_name_ = kernel_name; + } gpu_manager_->launch(function_, grid_dim_, block_dim_, smem_bytes_, stream, args.data(), nullptr); GLOG(gpuGetLastError()); diff --git a/ark/gpu/gpu_kernel.hpp b/ark/gpu/gpu_kernel.hpp index 5308cfead..1e02cc7a1 100644 --- a/ark/gpu/gpu_kernel.hpp +++ b/ark/gpu/gpu_kernel.hpp @@ -18,19 +18,18 @@ class GpuKernel { public: GpuKernel(int gpu_id, const std::string& codes, const std::array& block_dim, - const std::array& grid_dim, size_t smem_bytes, - const std::string& kernel_name); + const std::array& grid_dim, size_t smem_bytes); void init(int gpu_id, const std::string& codes, const std::array& block_dim, - const std::array& grid_dim, size_t smem_bytes, - const std::string& kernel_name); + const 
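Editor's note: the GpuKernel change above decouples compilation from function lookup: compile() only builds the module, and launch() resolves (and caches) the entry point by name, reconfiguring dynamic shared memory whenever the name changes. A usage sketch mirroring the updated unit test:

    ark::GpuKernel kernel(/*gpu_id=*/0,
                          "extern \"C\" __global__ void kernel() {}",
                          /*block_dim=*/{1, 1, 1}, /*grid_dim=*/{1, 1, 1},
                          /*smem_bytes=*/0);
    kernel.compile();  // builds the binary; no kernel name needed yet
    std::vector<void *> args;
    kernel.launch("kernel", nullptr, args);  // first launch resolves "kernel"
    kernel.launch("kernel", nullptr, args);  // cached; no second lookup

This also explains the new is_compiled() definition: compilation is now witnessed by a non-empty binary rather than a resolved function pointer.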
std::array& grid_dim, size_t smem_bytes); void compile(); - void launch(gpuStream stream, std::vector& args); + void launch(const std::string& kernel_name, gpuStream stream, + std::vector& args); gpuDeviceptr get_global(const std::string& name, bool ignore_not_found = false) const; - bool is_compiled() const { return function_ != nullptr; } + bool is_compiled() const { return !bin_.empty(); } protected: std::shared_ptr gpu_manager_; diff --git a/ark/gpu/gpu_kernel_test.cpp b/ark/gpu/gpu_kernel_test.cpp index 342ef9656..7b9f7f176 100644 --- a/ark/gpu/gpu_kernel_test.cpp +++ b/ark/gpu/gpu_kernel_test.cpp @@ -8,13 +8,14 @@ const std::string void_kernel = "extern \"C\" __global__ void kernel() {}"; ark::unittest::State test_gpu_kernel() { - ark::GpuKernel kernel(0, void_kernel, {1, 1, 1}, {1, 1, 1}, 0, "kernel"); + ark::GpuKernel kernel(0, void_kernel, {1, 1, 1}, {1, 1, 1}, 0); UNITTEST_TRUE(!kernel.is_compiled()); kernel.compile(); UNITTEST_TRUE(kernel.is_compiled()); std::vector args; + UNITTEST_THROW(kernel.launch("", nullptr, args), ark::InvalidUsageError); for (int i = 0; i < 10; i++) { - kernel.launch(nullptr, args); + kernel.launch("kernel", nullptr, args); } return ark::unittest::SUCCESS; } diff --git a/ark/gpu/gpu_manager.cpp b/ark/gpu/gpu_manager.cpp index 2b5be490b..9c49cfbc6 100644 --- a/ark/gpu/gpu_manager.cpp +++ b/ark/gpu/gpu_manager.cpp @@ -118,7 +118,8 @@ std::shared_ptr GpuManager::malloc_host(size_t bytes, } std::shared_ptr GpuManager::create_event(bool disable_timing) const { - return std::shared_ptr(new GpuEvent(disable_timing)); + return std::shared_ptr( + new GpuEvent(pimpl_->gpu_id_, disable_timing)); } std::shared_ptr GpuManager::create_stream() const { diff --git a/ark/gpu/gpu_manager.hpp b/ark/gpu/gpu_manager.hpp index eeeda4d94..71f47e670 100644 --- a/ark/gpu/gpu_manager.hpp +++ b/ark/gpu/gpu_manager.hpp @@ -16,7 +16,7 @@ namespace ark { class GpuManager { public: - static std::shared_ptr get_instance(int gpu_id); + static std::shared_ptr get_instance(int device_id); GpuManager(const GpuManager &) = delete; ~GpuManager() = default; @@ -54,7 +54,7 @@ class GpuManager { }; private: - GpuManager(int gpu_id); + GpuManager(int device_id); class Impl; std::shared_ptr pimpl_; diff --git a/ark/include/ark/context.hpp b/ark/include/ark/context.hpp index f3eef2836..aaa22bd3a 100644 --- a/ark/include/ark/context.hpp +++ b/ark/include/ark/context.hpp @@ -17,9 +17,9 @@ enum class ContextType { class Context { public: /// - /// Construct an empty context for the given model. + /// Context handler of the given model. /// - /// @param model The model to create the context for. + /// @param model The model to manipulate the context for. /// Context(Model& model); @@ -78,6 +78,9 @@ class Context { void set(const std::string& key, const std::string& value, ContextType type = ContextType::Overwrite); + /// Return the entire context stacks as a JSON format string. + std::string dump() const; + protected: friend class PlannerContext; diff --git a/ark/include/ark/executor.hpp b/ark/include/ark/executor.hpp index 14ca87618..2e97ffe78 100644 --- a/ark/include/ark/executor.hpp +++ b/ark/include/ark/executor.hpp @@ -9,18 +9,20 @@ #include #include #include +#include #include namespace ark { using Stream = void *; +class GpuMemory; + /// Convenience class for executing a model. class Executor { public: /// Constructor. - Executor(int device_id, Stream stream, const std::string &name, - const std::string &plan, bool loop_mode = true); + Executor(); /// Destructor. 
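Editor's note: Context::dump() above exposes the per-key context stacks for inspection. A hedged sketch of a caller's view; the exact JSON layout is defined by the internal ContextManager, so the commented output is illustrative only:

    ark::Model model;
    ark::Context ctx(model);
    ctx.set("Sync", "false");         // ContextType::Overwrite by default
    std::string stacks = ctx.dump();  // e.g. {"Sync":[[0,false]]}
    // Handy when debugging which (context id, value) pairs the planner
    // will consult while laying out tasks.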
~Executor(); @@ -31,23 +33,33 @@ class Executor { /// Return the stream of the executor. Stream stream() const; + /// Return the buffer of the executor. + std::shared_ptr buffer() const; + /// Return the plan string. std::string plan() const; + /// Return the name of the executor. + std::string name() const; + /// Compile the model. This must be called before `launch()`. - void compile(); + void compile(const std::string &plan, int device_id, + const std::string &name = "executor"); - /// Launch the model (not running yet). This must be called after - /// `compile()`. - void launch(); + /// Launch the executor. This must be called after `compile()`. + void launch(const std::unordered_map &placeholder_data = {}, + Stream stream = nullptr, bool loop_mode = true, + bool record = false); - /// Run the model for `iter` iterations. - void run(int iter); + /// Run the executor for `iter` iterations. + void run( + int iter, + const std::unordered_map &placeholder_data = {}); /// Wait for the previous run to finish. void wait(int64_t max_spin_count = -1); - /// Stop the model and return the elapsed time in milliseconds. + /// Stop the executor and return the elapsed time in milliseconds. /// Once this is called, we need to call `launch()` again to run the model /// again. float stop(int64_t max_spin_count = -1); @@ -62,7 +74,7 @@ class Executor { bool destroyed() const; /// Return the raw virtual address of the tensor. - uintptr_t tensor_address(const Tensor &tensor) const; + void *tensor_address(const Tensor &tensor) const; template void tensor_read(const Tensor &tensor, std::vector &data, @@ -93,10 +105,18 @@ class Model; class DefaultExecutor : public Executor { public: - DefaultExecutor( - const Model &model, int device_id = -1, Stream stream = nullptr, - const std::vector &config_rules = {}, - const std::string &name = "DefaultExecutor", bool loop_mode = true); + DefaultExecutor(const Model &model, int device_id = -1, + Stream stream = nullptr, + const std::vector &config_rules = {}, + const std::string &name = "DefaultExecutor", + bool loop_mode = true, bool record = false); + + /// Launch the default executor. + void launch( + const std::unordered_map &placeholder_data = {}); + + private: + bool record_; }; } // namespace ark diff --git a/ark/include/ark/model.hpp b/ark/include/ark/model.hpp index 3c4f22e22..e1b1f462b 100644 --- a/ark/include/ark/model.hpp +++ b/ark/include/ark/model.hpp @@ -76,6 +76,37 @@ class Model : public ModelGraph { const Dims &padded_shape = {}, int rank = -1, const std::string &name = ""); + /// + /// Returns a tensor object associated with an external buffer. + /// + /// @param shape Shape of the tensor, where the data of interest is. + /// @param dtype Type of the tensor data. + /// @param strides Strides of each dimension of the tensor, which may be + /// different from the shape. @p strides can be considered as the actual + /// shape of the underlying data buffer. + /// @param offsets Offsets of the tensor. The data of interest starts at + /// @p offsets and ends at @p offsets + @p padded_shape. + /// @param padded_shape Padded shape of the tensor. Padding is used to + /// reserve extra space for the tensor when computation requires it. + /// Data on the padded region is allowed to be accessed by computation, + /// but it is not considered as the data of interest. The padded region is + /// initialized to zero only once when the Executor is launched. 
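Editor's note: the reworked Executor above separates construction from planning: a default-constructed executor is compiled against a plan and device, then launched and run, optionally with per-launch placeholder bindings. A hedged end-to-end sketch (plan_json stands for a serialized plan string produced by Planner):

    ark::Executor exe;
    exe.compile(plan_json, /*device_id=*/0);  // name defaults to "executor"
    exe.launch();                             // default stream, loop mode
    exe.run(10);                              // enqueue 10 iterations
    exe.wait();                               // spin until they complete
    float elapsed_ms = exe.stop();            // call launch() again to rerun

With placeholder tensors, the bindings travel through launch()/run() as an std::unordered_map<ark::Tensor, void *>, which is why tensor.hpp gains a std::hash<ark::Tensor> specialization below.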
The padded
+    /// shape should be greater than or equal to the @p shape, and the
+    /// @p strides should be greater than or equal to the padded shape. If the
+    /// @p strides are not provided, they are set to the padded shape. If the
+    /// padded shape is not provided, it is set to the @p shape.
+    /// @param rank Rank of the tensor. -1 means the rank of this model.
+    /// @param name Name of the tensor.
+    /// @param data Address of data to pass through the placeholder. If
+    /// provided, this buffer is registered with the BufferRegistry and
+    /// associated with the tensor.
+    /// @return Pointer to a tensor object that references the external buffer.
+    ///
+    Tensor placeholder(const Dims &shape, const DataType &data_type,
+                       const Dims &strides = {}, const Dims &offsets = {},
+                       const Dims &padded_shape = {}, int rank = -1,
+                       void *data = nullptr, const std::string &name = "");
+
     Tensor refer(Tensor input, const Dims &shape = {},
                  const Dims &strides = {}, const Dims &offsets = {},
                  const Dims &padded_shape = {}, const std::string &name = "");
@@ -254,7 +285,6 @@ class Model : public ModelGraph {
 
     Tensor local_all_reduce(Tensor input, int gpu_id, int gpu_num,
                             const std::string &name = "");
-
 };
 
 }  // namespace ark
diff --git a/ark/include/ark/tensor.hpp b/ark/include/ark/tensor.hpp
index 747ce5fea..aa8dcaa68 100644
--- a/ark/include/ark/tensor.hpp
+++ b/ark/include/ark/tensor.hpp
@@ -50,6 +50,14 @@ class Tensor {
     Dims padded_shape() const;
 
     const DataType &data_type() const;
+
+    Dims torch_strides() const;
+
+    void *data() const;
+
+    void *data(void *data);
+
+    bool is_external() const;
 };
 
 const Tensor NullTensor;
@@ -58,4 +66,13 @@ std::ostream &operator<<(std::ostream &os, const Tensor &tensor);
 
 }  // namespace ark
 
+namespace std {
+template <>
+struct hash<ark::Tensor> {
+    size_t operator()(const ark::Tensor &t) const noexcept {
+        return t.id();
+    }
+};
+}  // namespace std
+
 #endif  // ARK_TENSOR_HPP
diff --git a/ark/include/kernels/common/broadcast.h b/ark/include/kernels/common/broadcast.h
index 858938613..86e84e5d0 100644
--- a/ark/include/kernels/common/broadcast.h
+++ b/ark/include/kernels/common/broadcast.h
@@ -400,6 +400,8 @@ struct Broadcast1 {
         static constexpr size_t StepSize = NelemPerThread * UnitOp::NumThreads;
 
+        UnitOp::sync_threads();
+
         for (size_t tid = NelemPerThread * UnitOp::thread_id();;
              tid += StepSize) {
             size_t tid_n = tid / UnitOutDims::CHW;
@@ -435,8 +437,6 @@ struct Broadcast1 {
             }
             Intrinsic::compute(&out[idx_out], &in[idx_in]);
         }
-
-        UnitOp::sync_threads();
     }
 };
 
@@ -469,6 +469,8 @@ struct Broadcast2 {
         static constexpr size_t StepSize = NelemPerThread * UnitOp::NumThreads;
 
+        UnitOp::sync_threads();
+
         for (size_t tid = NelemPerThread * UnitOp::thread_id();;
              tid += StepSize) {
             size_t tid_n = tid / UnitOutDims::CHW;
@@ -518,8 +520,6 @@ struct Broadcast2 {
             }
             Intrinsic::compute(&out[idx_out], &in0[idx_in0], &in1[idx_in1]);
         }
-
-        UnitOp::sync_threads();
     }
 };
diff --git a/ark/include/kernels/common/ewise.h b/ark/include/kernels/common/ewise.h
index de52f4584..c77bb7abf 100644
--- a/ark/include/kernels/common/ewise.h
+++ b/ark/include/kernels/common/ewise.h
@@ -31,6 +31,8 @@ struct Ewise1 {
         int uh = UnitOp::uop_idx_h(uop_idx);
         int uw = UnitOp::uop_idx_w(uop_idx);
 
+        UnitOp::sync_threads();
+
         for (int tid = UnitOp::thread_id();; tid += UnitOp::NumThreads) {
             int tid_w = (tid * NelemPerThread) % UnitOutDims::W;
             int tid_h =
@@ -50,8 +52,6 @@ struct Ewise1 {
 
             CompType::compute(out, in, idx_n, idx_c, idx_h, idx_w);
         }
-
-        UnitOp::sync_threads();
     }
 };
 
diff --git a/ark/include/kernels/gemm_ck.h
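Editor's note: two of the Tensor additions above exist for framework interop: torch_strides() converts ARK's stride dims (the physical shape of the underlying buffer) into torch-style element strides, and the std::hash specialization lets tensors key the placeholder-data map. A small sketch with worked numbers (dev_ptr and other_dev_ptr are assumed, pre-allocated device pointers):

    ark::Model model;
    ark::Tensor t = model.placeholder({2, 3, 4}, ark::FP32, {}, {}, {}, -1,
                                      dev_ptr);

    // strides() is (2, 3, 4) here, so torch_strides() yields (12, 4, 1):
    // innermost stride 1, each outer stride the product of inner extents,
    // matching a contiguous torch tensor layout.
    ark::Dims st = t.torch_strides();

    // std::hash<ark::Tensor> (keyed on the tensor id) makes this map legal:
    std::unordered_map<ark::Tensor, void *> placeholder_data;
    placeholder_data[t] = other_dev_ptr;  // rebind at launch()/run() time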
b/ark/include/kernels/gemm_ck.h index 4054f2d37..478419691 100644 --- a/ark/include/kernels/gemm_ck.h +++ b/ark/include/kernels/gemm_ck.h @@ -376,7 +376,7 @@ template + typename UnitOp> struct CkGemm { static_assert(LeadingDimA >= 0, ""); static_assert(LeadingDimB >= 0, ""); @@ -386,7 +386,6 @@ struct CkGemm { static_assert(ProblemSizeK >= 0, ""); static_assert(TileSizeM >= 0, ""); static_assert(TileSizeN >= 0, ""); - static_assert(TileSizeK >= 0, ""); using AccumulateType = fp32; @@ -514,13 +513,13 @@ template + typename UnitOp> DEVICE void gemm_ck(DataTypeC *C, DataTypeA *A, DataTypeB *B, int uop_idx, int smem_per_warp) { using CkGemm = CkGemm; + ProblemSizeK, TileSizeM, TileSizeN, UnitOp>; CkGemm gemm; gemm.Run(C, A, B, uop_idx, smem_per_warp); } diff --git a/ark/include/kernels/gemm_cutlass.h b/ark/include/kernels/gemm_cutlass.h index 80d377290..e87c7ddd2 100644 --- a/ark/include/kernels/gemm_cutlass.h +++ b/ark/include/kernels/gemm_cutlass.h @@ -45,754 +45,53 @@ struct GemmThreadblockSwizzle { } }; -template -struct GemmConfiguration; - -//////////////////////////////////////////////////////////////////////////////// -/// SM70 FP16 -//////////////////////////////////////////////////////////////////////////////// - -template -struct GemmConfiguration> { - using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; - - using Gemm = cutlass::gemm::device::Gemm< - cutlass::half_t, LayoutA, cutlass::half_t, LayoutB, ElementOutput, - LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, cutlass::gemm::GemmShape<128, 256, 32>, - cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<8, 8, 4>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 2>; -}; - -//////////////////////////////////////////////////////////////////////////////// -/// SM80 FP16 -//////////////////////////////////////////////////////////////////////////////// - -template -struct GemmConfiguration> { - using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; - - using Gemm = cutlass::gemm::device::Gemm< - cutlass::half_t, LayoutA, cutlass::half_t, LayoutB, ElementOutput, - LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, cutlass::gemm::GemmShape<128, 256, 64>, - cutlass::gemm::GemmShape<64, 64, 64>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 3>; -}; - -template -struct GemmConfiguration> { - using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; - - using Gemm = cutlass::gemm::device::Gemm< - cutlass::half_t, LayoutA, cutlass::half_t, LayoutB, ElementOutput, - LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, cutlass::gemm::GemmShape<256, 128, 64>, - cutlass::gemm::GemmShape<64, 64, 64>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 3>; -}; - -template -struct GemmConfiguration> { - using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; - - using Gemm = cutlass::gemm::device::Gemm< - cutlass::half_t, LayoutA, cutlass::half_t, LayoutB, ElementOutput, - 
LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, cutlass::gemm::GemmShape<128, 128, 64>, - cutlass::gemm::GemmShape<64, 64, 64>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 3>; -}; - -template -struct GemmConfiguration> { - using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; - - using Gemm = cutlass::gemm::device::Gemm< - cutlass::half_t, LayoutA, cutlass::half_t, LayoutB, ElementOutput, - LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, cutlass::gemm::GemmShape<256, 64, 64>, - cutlass::gemm::GemmShape<64, 64, 64>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 3>; -}; - -template -struct GemmConfiguration> { - using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; - - using Gemm = cutlass::gemm::device::Gemm< - cutlass::half_t, LayoutA, cutlass::half_t, LayoutB, ElementOutput, - LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, cutlass::gemm::GemmShape<64, 256, 64>, - cutlass::gemm::GemmShape<64, 64, 64>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 3>; -}; - -template -struct GemmConfiguration> { - using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; - - using Gemm = cutlass::gemm::device::Gemm< - cutlass::half_t, LayoutA, cutlass::half_t, LayoutB, ElementOutput, - LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, cutlass::gemm::GemmShape<64, 128, 64>, - cutlass::gemm::GemmShape<32, 64, 64>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 4>; -}; - -template -struct GemmConfiguration> { - using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; - - using Gemm = cutlass::gemm::device::Gemm< - cutlass::half_t, LayoutA, cutlass::half_t, LayoutB, ElementOutput, - LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, cutlass::gemm::GemmShape<128, 64, 64>, - cutlass::gemm::GemmShape<64, 32, 64>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 4>; -}; +template +struct InstructionShape; -template -struct GemmConfiguration> { - using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; - - using Gemm = cutlass::gemm::device::Gemm< - cutlass::half_t, LayoutA, cutlass::half_t, LayoutB, ElementOutput, - LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, cutlass::gemm::GemmShape<64, 64, 64>, - cutlass::gemm::GemmShape<32, 32, 64>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 
6>; -}; - -template -struct GemmConfiguration> { - using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; - - using Gemm = cutlass::gemm::device::Gemm< - cutlass::half_t, LayoutA, cutlass::half_t, LayoutB, ElementOutput, - LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, cutlass::gemm::GemmShape<128, 256, 32>, - cutlass::gemm::GemmShape<64, 64, 32>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 3>; -}; - -template -struct GemmConfiguration> { - using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; - - using Gemm = cutlass::gemm::device::Gemm< - cutlass::half_t, LayoutA, cutlass::half_t, LayoutB, ElementOutput, - LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, cutlass::gemm::GemmShape<256, 128, 32>, - cutlass::gemm::GemmShape<64, 64, 32>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 3>; -}; - -template -struct GemmConfiguration> { - using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; - - using Gemm = cutlass::gemm::device::Gemm< - cutlass::half_t, LayoutA, cutlass::half_t, LayoutB, ElementOutput, - LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, cutlass::gemm::GemmShape<128, 128, 32>, - cutlass::gemm::GemmShape<64, 64, 32>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 4>; -}; - -template -struct GemmConfiguration> { - using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; - - using Gemm = cutlass::gemm::device::Gemm< - cutlass::half_t, LayoutA, cutlass::half_t, LayoutB, ElementOutput, - LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, cutlass::gemm::GemmShape<256, 64, 32>, - cutlass::gemm::GemmShape<64, 64, 32>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 4>; +template +struct InstructionShape { + using value = cutlass::gemm::GemmShape<8, 8, 4>; }; -template -struct GemmConfiguration> { - using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; - - using Gemm = cutlass::gemm::device::Gemm< - cutlass::half_t, LayoutA, cutlass::half_t, LayoutB, ElementOutput, - LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, cutlass::gemm::GemmShape<64, 256, 32>, - cutlass::gemm::GemmShape<64, 64, 32>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 4>; +template +struct InstructionShape { + static constexpr int K = std::is_same_v ? 
8 : 16; + using value = cutlass::gemm::GemmShape<16, 8, K>; }; -template -struct GemmConfiguration> { - using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; - - using Gemm = cutlass::gemm::device::Gemm< - cutlass::half_t, LayoutA, cutlass::half_t, LayoutB, ElementOutput, - LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, cutlass::gemm::GemmShape<64, 128, 32>, - cutlass::gemm::GemmShape<32, 64, 32>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 6>; -}; - -template -struct GemmConfiguration> { - using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; - - using Gemm = cutlass::gemm::device::Gemm< - cutlass::half_t, LayoutA, cutlass::half_t, LayoutB, ElementOutput, - LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, cutlass::gemm::GemmShape<128, 64, 32>, - cutlass::gemm::GemmShape<64, 32, 32>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 6>; -}; - -template -struct GemmConfiguration> { - using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; - - using Gemm = cutlass::gemm::device::Gemm< - cutlass::half_t, LayoutA, cutlass::half_t, LayoutB, ElementOutput, - LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, cutlass::gemm::GemmShape<64, 64, 32>, - cutlass::gemm::GemmShape<32, 32, 32>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 10>; -}; - -//////////////////////////////////////////////////////////////////////////////// -/// SM80 BF16 -//////////////////////////////////////////////////////////////////////////////// - -template -struct GemmConfiguration> { - using ElementOutput = cutlass::bfloat16_t; - using ElementAccumulator = float; - - using Gemm = cutlass::gemm::device::Gemm< - cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, - ElementOutput, LayoutC, ElementAccumulator, - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<128, 256, 64>, - cutlass::gemm::GemmShape<64, 64, 64>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 3>; -}; - -template -struct GemmConfiguration> { - using ElementOutput = cutlass::bfloat16_t; - using ElementAccumulator = float; - - using Gemm = cutlass::gemm::device::Gemm< - cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, - ElementOutput, LayoutC, ElementAccumulator, - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<256, 128, 64>, - cutlass::gemm::GemmShape<64, 64, 64>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 3>; -}; - -template -struct GemmConfiguration> { - using ElementOutput = cutlass::bfloat16_t; - using ElementAccumulator = float; - - using Gemm = 
cutlass::gemm::device::Gemm< - cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, - ElementOutput, LayoutC, ElementAccumulator, - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<128, 128, 64>, - cutlass::gemm::GemmShape<64, 64, 64>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 3>; -}; - -template -struct GemmConfiguration> { - using ElementOutput = cutlass::bfloat16_t; - using ElementAccumulator = float; - - using Gemm = cutlass::gemm::device::Gemm< - cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, - ElementOutput, LayoutC, ElementAccumulator, - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<256, 64, 64>, - cutlass::gemm::GemmShape<64, 64, 64>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 3>; -}; - -template -struct GemmConfiguration> { - using ElementOutput = cutlass::bfloat16_t; - using ElementAccumulator = float; - - using Gemm = cutlass::gemm::device::Gemm< - cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, - ElementOutput, LayoutC, ElementAccumulator, - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<64, 256, 64>, - cutlass::gemm::GemmShape<64, 64, 64>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 3>; -}; - -template -struct GemmConfiguration> { - using ElementOutput = cutlass::bfloat16_t; - using ElementAccumulator = float; - - using Gemm = cutlass::gemm::device::Gemm< - cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, - ElementOutput, LayoutC, ElementAccumulator, - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<64, 128, 64>, - cutlass::gemm::GemmShape<32, 64, 64>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 4>; -}; - -template -struct GemmConfiguration> { - using ElementOutput = cutlass::bfloat16_t; - using ElementAccumulator = float; - - using Gemm = cutlass::gemm::device::Gemm< - cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, - ElementOutput, LayoutC, ElementAccumulator, - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<128, 64, 64>, - cutlass::gemm::GemmShape<64, 32, 64>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 4>; -}; - -template -struct GemmConfiguration> { - using ElementOutput = cutlass::bfloat16_t; - using ElementAccumulator = float; - - using Gemm = cutlass::gemm::device::Gemm< - cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, - ElementOutput, LayoutC, ElementAccumulator, - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<64, 64, 64>, - cutlass::gemm::GemmShape<32, 32, 64>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / 
cutlass::sizeof_bits<ElementOutput>::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 6>; -}; -
-template <typename UnitOp, typename LayoutA, typename LayoutB, typename LayoutC> -struct GemmConfiguration< - UnitOp, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, - cutlass::bfloat16_t, LayoutC, cutlass::gemm::GemmShape<128, 256, 32>> { - using ElementOutput = cutlass::bfloat16_t; - using ElementAccumulator = float; - - using Gemm = cutlass::gemm::device::Gemm< - cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, - ElementOutput, LayoutC, ElementAccumulator, - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<128, 256, 32>, - cutlass::gemm::GemmShape<64, 64, 32>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 3>; -}; -
-template <typename UnitOp, typename LayoutA, typename LayoutB, typename LayoutC> -struct GemmConfiguration< - UnitOp, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, - cutlass::bfloat16_t, LayoutC, cutlass::gemm::GemmShape<256, 128, 32>> { - using ElementOutput = cutlass::bfloat16_t; - using ElementAccumulator = float; - - using Gemm = cutlass::gemm::device::Gemm< - cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, - ElementOutput, LayoutC, ElementAccumulator, - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<256, 128, 32>, - cutlass::gemm::GemmShape<64, 64, 32>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 3>; -}; -
-template <typename UnitOp, typename LayoutA, typename LayoutB, typename LayoutC> -struct GemmConfiguration< - UnitOp, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, - cutlass::bfloat16_t, LayoutC, cutlass::gemm::GemmShape<128, 128, 32>> { - using ElementOutput = cutlass::bfloat16_t; - using ElementAccumulator = float; - - using Gemm = cutlass::gemm::device::Gemm< - cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, - ElementOutput, LayoutC, ElementAccumulator, - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<128, 128, 32>, - cutlass::gemm::GemmShape<64, 64, 32>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 4>; -}; -
-template <typename UnitOp, typename LayoutA, typename LayoutB, typename LayoutC> -struct GemmConfiguration< - UnitOp, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, - cutlass::bfloat16_t, LayoutC, cutlass::gemm::GemmShape<256, 64, 32>> { - using ElementOutput = cutlass::bfloat16_t; - using ElementAccumulator = float; - - using Gemm = cutlass::gemm::device::Gemm< - cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, - ElementOutput, LayoutC, ElementAccumulator, - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<256, 64, 32>, - cutlass::gemm::GemmShape<64, 64, 32>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 4>; -}; -
-template <typename UnitOp, typename LayoutA, typename LayoutB, typename LayoutC> -struct GemmConfiguration< - UnitOp, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, - cutlass::bfloat16_t, LayoutC, cutlass::gemm::GemmShape<64, 256, 32>> { - using ElementOutput = cutlass::bfloat16_t; - using ElementAccumulator = float; - - using Gemm = cutlass::gemm::device::Gemm< - cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, - ElementOutput, LayoutC, ElementAccumulator, - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<64, 256, 32>, - cutlass::gemm::GemmShape<64, 64, 32>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 4>; -}; -
-template <typename UnitOp, typename LayoutA, typename LayoutB, typename LayoutC> -struct GemmConfiguration< - UnitOp, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, - cutlass::bfloat16_t, LayoutC, cutlass::gemm::GemmShape<64, 128, 32>> { - using ElementOutput = cutlass::bfloat16_t; - using ElementAccumulator = float; - - using Gemm = cutlass::gemm::device::Gemm< - cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, - ElementOutput, LayoutC, ElementAccumulator, - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<64, 128, 32>, - cutlass::gemm::GemmShape<32, 64, 32>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 6>; -}; -
-template <typename UnitOp, typename LayoutA, typename LayoutB, typename LayoutC> -struct GemmConfiguration< - UnitOp, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, - cutlass::bfloat16_t, LayoutC, cutlass::gemm::GemmShape<128, 64, 32>> { - using ElementOutput = cutlass::bfloat16_t; - using ElementAccumulator = float; - - using Gemm = cutlass::gemm::device::Gemm< - cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, - ElementOutput, LayoutC, ElementAccumulator, - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<128, 64, 32>, - cutlass::gemm::GemmShape<64, 32, 32>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 6>; -}; -
-template <typename UnitOp, typename LayoutA, typename LayoutB, typename LayoutC> -struct GemmConfiguration< - UnitOp, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, - cutlass::bfloat16_t, LayoutC, cutlass::gemm::GemmShape<64, 64, 32>> { - using ElementOutput = cutlass::bfloat16_t; - using ElementAccumulator = float; - - using Gemm = cutlass::gemm::device::Gemm< - cutlass::half_t, LayoutA, cutlass::half_t, LayoutB, ElementOutput, - LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, cutlass::gemm::GemmShape<64, 64, 32>, - cutlass::gemm::GemmShape<32, 32, 32>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 10>; -}; -
-//////////////////////////////////////////////////////////////////////////////// -/// SM80 FP32 -//////////////////////////////////////////////////////////////////////////////// -
-template <typename UnitOp, typename LayoutA, typename LayoutB, typename LayoutC> -struct GemmConfiguration< - UnitOp, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, float, LayoutA, - float, LayoutB, float, LayoutC, cutlass::gemm::GemmShape<128, 256, 32>> { - using ElementOutput = float; - using ElementAccumulator = float; - - using Gemm = cutlass::gemm::device::Gemm< - float, LayoutA, float, LayoutB, ElementOutput, LayoutC, - ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<128, 256, 32>, - cutlass::gemm::GemmShape<64, 64, 32>, - cutlass::gemm::GemmShape<16, 8, 8>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 3>; -}; -
-template <typename UnitOp, typename LayoutA, typename LayoutB, typename LayoutC> -struct GemmConfiguration< - UnitOp, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, float, LayoutA, - float, LayoutB, float, LayoutC, cutlass::gemm::GemmShape<128, 128, 32>> { - using ElementOutput = float; - using ElementAccumulator = float; - - using Gemm = cutlass::gemm::device::Gemm< - float, LayoutA, float, LayoutB, ElementOutput, LayoutC, - ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<128, 128, 32>, - cutlass::gemm::GemmShape<64, 64, 32>, - cutlass::gemm::GemmShape<16, 8, 8>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value, - ElementAccumulator, ElementAccumulator>, - ark::GemmThreadblockSwizzle, 3>; -}; -
-template <typename UnitOp, typename LayoutA, typename LayoutB, typename LayoutC> -struct GemmConfiguration< - UnitOp, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, float, LayoutA, - float, LayoutB, float, LayoutC, cutlass::gemm::GemmShape<64, 64, 32>> { - using ElementOutput = float; - using ElementAccumulator = float; - - using Gemm = cutlass::gemm::device::Gemm< - float, LayoutA, float, LayoutB, ElementOutput, LayoutC, - ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<64, 64, 32>, - cutlass::gemm::GemmShape<32, 32, 32>, - cutlass::gemm::GemmShape<16, 8, 8>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
+template <typename UnitOp, typename OperatorClass, typename ArchTag, + typename ElementA, typename LayoutA, typename ElementB, + typename LayoutB, typename ElementC, typename LayoutC, + typename Shape> +struct GemmConfiguration { + // Supports float, half, and bfloat16. + static_assert(std::is_same_v<ElementA, float> || + std::is_same_v<ElementA, cutlass::half_t> || + std::is_same_v<ElementA, cutlass::bfloat16_t>, + "ElementA must be float, half, or bfloat16"); + static_assert(std::is_same_v<ElementB, float> || + std::is_same_v<ElementB, cutlass::half_t> || + std::is_same_v<ElementB, cutlass::bfloat16_t>, + "ElementB must be float, half, or bfloat16"); + static_assert(std::is_same_v<ElementC, float> || + std::is_same_v<ElementC, cutlass::half_t> || + std::is_same_v<ElementC, cutlass::bfloat16_t>, + "ElementC must be float, half, or bfloat16"); + using ElementAccumulator = typename std::conditional_t< + std::is_same_v<ElementC, cutlass::bfloat16_t>, float, ElementC>; + static constexpr int NumWarps = UnitOp::NumWarps; + static constexpr int NumWarpsN = + 1 << math::div_up<math::log2_up<NumWarps>::value, 2>::value; + static constexpr int NumWarpsM = NumWarps / NumWarpsN; + using WarpShape = + cutlass::gemm::GemmShape<Shape::kM / NumWarpsM, Shape::kN / NumWarpsN, Shape::kK>; + using InstShape = typename InstructionShape<OperatorClass, ArchTag, ElementA>::value; + using Gemm = cutlass::gemm::device::Gemm< + ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, + ElementAccumulator, OperatorClass, ArchTag, Shape, WarpShape, InstShape, + cutlass::epilogue::thread::LinearCombination< + ElementC, 128 / cutlass::sizeof_bits<ElementC>::value, ElementAccumulator, ElementAccumulator>, ark::GemmThreadblockSwizzle, 3>; };
@@ -901,7 +200,7 @@ template + typename UnitOp> DEVICE void gemm_cuda(DataTypeC *C, DataTypeA *A, DataTypeB *B, int uop_idx, int smem_per_warp) { #if (ARK_TARGET_CUDA_ARCH == 60) @@ -924,6 +223,7 @@ DEVICE void gemm_cuda(DataTypeC *C, DataTypeA *A, DataTypeB *B, int uop_idx, cutlass::layout::RowMajor>::type; using LayoutC = cutlass::layout::RowMajor; + static constexpr int TileSizeK = std::is_same_v<DataTypeA, float> ?
32 : 64; using GemmKernel = typename ark::GemmConfiguration< UnitOp, cutlass::arch::OpClassTensorOp, ArchTag, DataTypeA, LayoutA, DataTypeB, LayoutB, DataTypeC, LayoutC, @@ -960,6 +260,8 @@ DEVICE void gemm_cuda(DataTypeC *C, DataTypeA *A, DataTypeB *B, int uop_idx, UnitOp::template shared_memory( smem_per_warp); + UnitOp::sync_threads(); + GemmKernel gemm_kernel{}; gemm_kernel(params, *ps); } @@ -1105,7 +407,7 @@ template + typename UnitOp> DEVICE void gemm_cutlass(DataTypeC *C, DataTypeA *A, DataTypeB *B, int uop_idx, int smem_per_warp) { using CutDataTypeA = typename cutlass::platform::conditional< @@ -1134,13 +436,13 @@ DEVICE void gemm_cutlass(DataTypeC *C, DataTypeA *A, DataTypeB *B, int uop_idx, ARK_TARGET_CUDA_ARCH == 80) gemm_cuda( - pC, pA, pB, uop_idx, smem_per_warp); + ProblemSizeK, TileSizeM, TileSizeN, UnitOp>(pC, pA, pB, uop_idx, + smem_per_warp); #elif (ARK_TARGET_CUDA_ARCH == 90) gemm_cuda_90(pC, pA, pB, uop_idx, smem_per_warp); + UnitOp>(pC, pA, pB, uop_idx, smem_per_warp); #else static_assert(false, "Unsupported CUDA arch."); #endif diff --git a/ark/include/kernels/kernel_template.in b/ark/include/kernels/kernel_template.in index a8a56f141..a05e143d3 100644 --- a/ark/include/kernels/kernel_template.in +++ b/ark/include/kernels/kernel_template.in @@ -6,8 +6,8 @@ using namespace ark; template -__forceinline__ __device__ void task_seq(char *_buf) { + void (*task)(char*, int, int, @ARG_TYPES@)> +__forceinline__ __device__ void task_seq(char *_buf, @GLOBAL_ARGS@) { if (math::geq(blockIdx.x) && math::le(blockIdx.x) && ((blockIdx.x - ProcBegin) % ProcStep == 0)) { constexpr size_t SlotNumThreads = SlotNumWarps * Arch::ThreadsPerWarp; @@ -23,7 +23,7 @@ __forceinline__ __device__ void task_seq(char *_buf) { size_t task_id = task_id_base + TaskStep * (t % TaskGranularity + t / TaskGranularity * TaskGranularity * NumProcs); if (task_id >= TaskEnd) break; - task(_buf, task_id, SramBytesPerWarp); + task(_buf, task_id, SramBytesPerWarp, @FUNCTION_ARGS@); } } } @@ -33,12 +33,12 @@ __device__ sync::State ARK_LOOP_SYNC_STATE; @DEFINITIONS@ -__device__ void ark_body(char *_buf, int _iter) { +__device__ void ark_body(char *_buf, int _iter, @GLOBAL_ARGS@) { @BODY@ } extern "C" __global__ __launch_bounds__(ARK_WARPS_PER_BLOCK * Arch::ThreadsPerWarp, 1) -void ark_loop_kernel@NAME@(char *_buf, int *_iter) { +void ark_loop_kernel@NAME@(char *_buf, int *_iter, @GLOBAL_ARGS@) { int *shared_mem = (int *)_ARK_SMEM; for (int i = threadIdx.x; i < ARK_SMEM_RESERVED_BYTES / sizeof(int); i += blockDim.x) { shared_mem[i] = 0; @@ -52,10 +52,10 @@ void ark_loop_kernel@NAME@(char *_buf, int *_iter) { sync_gpu<@NUM_BLOCKS@>(ARK_LOOP_SYNC_STATE); if (ARK_ITER < 0) return; - ark_body(_buf, 0); + ark_body(_buf, 0, @FUNCTION_ARGS@); for (int _i = 1; _i < ARK_ITER; ++_i) { sync_gpu<@NUM_BLOCKS@>(ARK_LOOP_SYNC_STATE); - ark_body(_buf, _i); + ark_body(_buf, _i, @FUNCTION_ARGS@); } if (threadIdx.x == 0) { __threadfence_system(); @@ -69,10 +69,10 @@ void ark_loop_kernel@NAME@(char *_buf, int *_iter) { } extern "C" __global__ __launch_bounds__(ARK_WARPS_PER_BLOCK * Arch::ThreadsPerWarp, 1) -void ark_kernel@NAME@(char *_buf, int _iter) { +void ark_kernel@NAME@(char *_buf, int _iter, @GLOBAL_ARGS@) { int *shared_mem = (int *)_ARK_SMEM; for (int i = threadIdx.x; i < ARK_SMEM_RESERVED_BYTES / sizeof(int); i += blockDim.x) { shared_mem[i] = 0; } - ark_body(_buf, _iter); + ark_body(_buf, _iter, @FUNCTION_ARGS@); } diff --git a/ark/include/kernels/layernorm.h b/ark/include/kernels/layernorm.h index b0f101e76..5bc17235d 
100644 --- a/ark/include/kernels/layernorm.h +++ b/ark/include/kernels/layernorm.h @@ -63,6 +63,8 @@ struct LayerNorm { (tid_c + uc * UnitOutDims::C) * InDims::HW + (tid_n + un * UnitOutDims::N) * InDims::CHW; + UnitOp::sync_threads(); + DataType mean; DataType cmp; ReduceTypeMean::template identity<1>(&mean); @@ -108,7 +110,6 @@ struct LayerNorm { out[idx_out] = type::Mul::compute( type::Sub::compute(in[idx_in], mean), variance); } - UnitOp::sync_threads(); } }; diff --git a/ark/include/kernels/matmul.h b/ark/include/kernels/matmul.h index 3b97a3907..b14f10bf6 100644 --- a/ark/include/kernels/matmul.h +++ b/ark/include/kernels/matmul.h @@ -21,22 +21,27 @@ namespace ark { /// @tparam OutDims (ark::Vec) Output tensor leading dimensions. /// @tparam NCA (ark::Vec) A 2D vector with N and C dimensions of matrix A. /// @tparam NCB (ark::Vec) A 2D vector with N and C dimensions of matrix B. -/// @tparam TileShape (ark::Vec) The tile shape of matmul computation (m, n, k). +/// @tparam TileShape (ark::Vec) The output tile shape. /// @tparam ProblemSize (ark::Vec) The problem size of matmul computation /// (m, n, k). /// @tparam LeadingDims (ark::Vec) The leading dimensions of matrix inputs /// and outputs. (lda, ldc, ldc, ldb). -/// @tparam InnerLdimA (int) The leading dimension of the inner dimension of A. -/// @tparam InnerLdimB (int) The leading dimension of the inner dimension of B. +/// @tparam BatchStrideNA (int) Element stride between A matrices consecutive in the N dimension (0 if A is broadcast along N). +/// @tparam BatchStrideCA (int) Element stride between A matrices consecutive in the C dimension (0 if A is broadcast along C). +/// @tparam BatchStrideNB (int) Element stride between B matrices consecutive in the N dimension (0 if B is broadcast along N). +/// @tparam BatchStrideCB (int) Element stride between B matrices consecutive in the C dimension (0 if B is broadcast along C). +/// @tparam BatchStrideNC (int) Element stride between output matrices consecutive in the N dimension. +/// @tparam BatchStrideCC (int) Element stride between output matrices consecutive in the C dimension. /// @tparam IsColumnA (bool) Whether matrix A is column-major. /// @tparam IsColumnB (bool) Whether matrix B is column-major. /// @tparam NumWarps (int) The number of warps per uop. /// @tparam SmemBytes (int) The size of shared memory per uop.
/// template DEVICE void matmul(DataTypeC *C, DataTypeA *A, DataTypeB *B, int uop_idx, int smem_per_warp) { @@ -44,7 +49,8 @@ DEVICE void matmul(DataTypeC *C, DataTypeA *A, DataTypeB *B, int uop_idx, "NCA should be two dimensional."); static_assert(NCB::D2 == 1 && NCB::D3 == 1, "NCB should be two dimensional."); - static_assert(TileShape::D3 == 1, "TileShape should be three dimensional."); + static_assert(TileShape::D2 == 1 && TileShape::D3 == 1, + "TileShape should be two dimensional."); static_assert(ProblemSize::D3 == 1, "ProblemSize should be three dimensional."); @@ -65,53 +71,26 @@ DEVICE void matmul(DataTypeC *C, DataTypeA *A, DataTypeB *B, int uop_idx, constexpr int ProblemSizeK = ProblemSize::D2; constexpr int TileSizeM = TileShape::D0; constexpr int TileSizeN = TileShape::D1; - constexpr int TileSizeK = TileShape::D2; - - constexpr DimType SizeA = math::mul::value; - constexpr DimType SizeB = math::mul::value; - constexpr DimType SizeC = math::mul::value; - static_assert(SizeA >= 0, ""); - static_assert(SizeB >= 0, ""); - static_assert(SizeC >= 0, ""); int un = UnitOp::uop_idx_n(uop_idx); int uc = UnitOp::uop_idx_c(uop_idx); // Broadcasting - DataTypeA *pA; - DataTypeB *pB; - DataTypeC *pC = &C[un * math::mul::value + uc * SizeC]; - if constexpr (NCA::D0 == 1 && NCA::D1 == 1) { - pA = A; - } else if constexpr (NCA::D0 == 1) { - pA = &A[uc * SizeA]; - } else if constexpr (NCA::D1 == 1) { - pA = &A[un * SizeA]; - } else { - pA = &A[un * math::mul::value + uc * SizeA]; - } - if constexpr (NCB::D0 == 1 && NCB::D1 == 1) { - pB = B; - } else if constexpr (NCB::D0 == 1) { - pB = &B[uc * SizeB]; - } else if constexpr (NCB::D1 == 1) { - pB = &B[un * SizeB]; - } else { - pB = &B[un * math::mul::value + uc * SizeB]; - } + DataTypeA *pA = &A[un * BatchStrideNA + uc * BatchStrideCA]; + DataTypeB *pB = &B[un * BatchStrideNB + uc * BatchStrideCB]; + DataTypeC *pC = &C[un * BatchStrideNC + uc * BatchStrideCC]; #if defined(ARK_TARGET_CUDA_ARCH) gemm_cutlass( + ProblemSizeK, TileSizeM, TileSizeN, UnitOp>( pC, pA, pB, uop_idx, smem_per_warp); #elif defined(ARK_TARGET_ROCM_ARCH) gemm_ck( - pC, pA, pB, uop_idx, smem_per_warp); + ProblemSizeK, TileSizeM, TileSizeN, UnitOp>(pC, pA, pB, uop_idx, + smem_per_warp); #endif - UnitOp::sync_threads(); } } // namespace ark diff --git a/ark/include/kernels/reduce.h b/ark/include/kernels/reduce.h index 9ebe6555c..62af5840b 100644 --- a/ark/include/kernels/reduce.h +++ b/ark/include/kernels/reduce.h @@ -397,6 +397,8 @@ struct WwiseReduce { DataType reduced[NelemPerThread]; + UnitOp::sync_threads(); + ReduceType::template identity(reduced); for (int idx_w = tid_w; idx_w < InShape::W; idx_w += ThreadsPerRow) { int idx_in = idx_in_base + idx_w; @@ -438,8 +440,6 @@ struct WwiseReduce { ReduceType::template postReduce<1>(&out[idx_out], &reduced[0], InShape::W); } - - UnitOp::sync_threads(); } }; diff --git a/ark/model/model_buffer.cpp b/ark/model/model_buffer.cpp index e637307fd..a54b6e81f 100644 --- a/ark/model/model_buffer.cpp +++ b/ark/model/model_buffer.cpp @@ -3,19 +3,25 @@ #include "model_buffer.hpp" +#include "buffer_registry.hpp" #include "logging.hpp" namespace ark { -ModelBuffer::ModelBuffer(int rank) : rank_(rank) { - static size_t id = 0; - id_ = id++; +size_t ModelBuffer::curr_id = 0; + +ModelBuffer::ModelBuffer(int rank, bool is_external) + : rank_(rank), is_external_(is_external) { + id_ = curr_id++; } -ModelBuffer::ModelBuffer(size_t id, int rank, +ModelBuffer::ModelBuffer(size_t id, int rank, bool is_external, const std::vector &send_tags, const 
std::vector<TagInfo> &recv_tags) - : id_(id), rank_(rank) { + : id_(id), rank_(rank), is_external_(is_external) { + if (is_external && (!send_tags.empty() || !recv_tags.empty())) { + ERR(ModelError, "External buffer cannot have send or receive tags"); + } for (const auto &info : send_tags) { send_tags_.insert(info); } @@ -32,6 +38,22 @@ void ModelBuffer::tag_recv(int remote_rank, int tag) { recv_tags_.insert(TagInfo{remote_rank, tag}); } +void *ModelBuffer::data() const { + auto info = BufferRegistry::get_instance().get(id_); + if (info) { + return info->data; + } + return nullptr; +} + +void *ModelBuffer::data(void *data) { + if (is_external_) { + BufferRegistry::get_instance().set(id_, data, -1, true); + return data; + } + return nullptr; +} + Json ModelBuffer::serialize() const { Json j; j["Id"] = id_; @@ -44,6 +66,7 @@ Json ModelBuffer::serialize() const { for (const auto &info : recv_tags_) { recv_tags.push_back({info.first, info.second}); } + j["IsExternal"] = is_external_; j["SendTags"] = send_tags; j["RecvTags"] = recv_tags; return j; @@ -57,11 +80,15 @@ std::shared_ptr<ModelBuffer> ModelBuffer::deserialize(const Json &serialized) { } else if (!serialized.contains("SendTags")) { ERR(ModelError, "ModelBuffer deserialization failed: missing SendTags"); } else if (!serialized.contains("RecvTags")) { - ERR(ModelError, "ModelBuffer deserialization failed: missing RecvTags"); + ERR(ModelError, + "ModelBuffer deserialization failed: missing RecvTags"); + } else if (!serialized.contains("IsExternal")) { + ERR(ModelError, + "ModelBuffer deserialization failed: missing IsExternal"); } - return std::make_shared<ModelBuffer>(serialized["Id"], serialized["Rank"], - serialized["SendTags"], - serialized["RecvTags"]); + return std::make_shared<ModelBuffer>( + serialized["Id"], serialized["Rank"], serialized["IsExternal"], + serialized["SendTags"], serialized["RecvTags"]); } } // namespace ark
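The two `data()` overloads above split lookup and binding: the const overload queries a process-wide registry by buffer id, while the non-const overload binds an externally owned pointer (and only for external buffers). To make those semantics concrete, here is a minimal, self-contained sketch of such a registry. The `Info` fields and the `set(id, data, -1, true)` call shape mirror the usage above; everything else, including the class name, is illustrative:

```cpp
#include <cstdio>
#include <memory>
#include <unordered_map>

// Toy stand-in for the registry used above: maps a buffer id to an
// externally owned pointer. Signatures are inferred from the calls in
// ModelBuffer::data() and are not the real BufferRegistry API.
class ToyBufferRegistry {
   public:
    struct Info {
        void *data;
        long bytes;        // -1: size unknown (externally managed)
        bool is_external;  // true: never allocated or freed by the framework
    };
    static ToyBufferRegistry &get_instance() {
        static ToyBufferRegistry inst;
        return inst;
    }
    void set(size_t id, void *data, long bytes, bool is_external) {
        storage_[id] = std::make_shared<Info>(Info{data, bytes, is_external});
    }
    std::shared_ptr<Info> get(size_t id) const {
        auto it = storage_.find(id);
        return it == storage_.end() ? nullptr : it->second;
    }

   private:
    std::unordered_map<size_t, std::shared_ptr<Info>> storage_;
};

int main() {
    float ext[4] = {1, 2, 3, 4};  // externally allocated memory
    auto &reg = ToyBufferRegistry::get_instance();
    reg.set(/*id=*/0, ext, /*bytes=*/-1, /*is_external=*/true);
    // A lookup returns the bound pointer, or nullptr when the id was
    // never bound -- matching the nullptr fallbacks above.
    std::printf("bound: %p\n", reg.get(0) ? reg.get(0)->data : nullptr);
    std::printf("unbound: %p\n", reg.get(1) ? reg.get(1)->data : nullptr);
    return 0;
}
```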
diff --git a/ark/model/model_buffer.hpp b/ark/model/model_buffer.hpp index 7ad3db206..d52f2bf26 100644 --- a/ark/model/model_buffer.hpp +++ b/ark/model/model_buffer.hpp @@ -17,15 +17,18 @@ class ModelBuffer { // (remote_rank, tag) using TagInfo = std::pair<int, int>; - ModelBuffer(int rank = -1); + ModelBuffer(int rank = -1, bool is_external = false); - ModelBuffer(size_t id, int rank, const std::vector<TagInfo> &send_tags, + ModelBuffer(size_t id, int rank, bool is_external, + const std::vector<TagInfo> &send_tags, const std::vector<TagInfo> &recv_tags); size_t id() const { return id_; } int rank() const { return rank_; } + bool is_external() const { return is_external_; } + const std::set<TagInfo> &send_tags() const { return send_tags_; } const std::set<TagInfo> &recv_tags() const { return recv_tags_; } @@ -40,13 +43,23 @@ class ModelBuffer { // but the same tag can only be used for one receiving buffer. void tag_recv(int remote_rank, int tag); + // Return the underlying data pointer if this buffer is allocated. + // Otherwise, return nullptr. + void *data() const; + + // Set the underlying data pointer if this buffer is externally managed, + // and return the input data pointer. Otherwise, return nullptr. + void *data(void *data); + Json serialize() const; static std::shared_ptr<ModelBuffer> deserialize(const Json &serialized); private: + static size_t curr_id; size_t id_; int rank_; + bool is_external_; std::set<TagInfo> send_tags_; std::set<TagInfo> recv_tags_; }; diff --git a/ark/model/model_context_manager.cpp b/ark/model/model_context_manager.cpp index f1bb62e9d..799cce785 100644 --- a/ark/model/model_context_manager.cpp +++ b/ark/model/model_context_manager.cpp @@ -27,4 +27,8 @@ Json ModelContextManager::get(const std::string& key) const { return context_stack_->get(key); } +Json ModelContextManager::dump() const { + return context_stack_->dump(); +} + } // namespace ark diff --git a/ark/model/model_context_manager.hpp b/ark/model/model_context_manager.hpp index 6aa91692e..4dc246fe8 100644 --- a/ark/model/model_context_manager.hpp +++ b/ark/model/model_context_manager.hpp @@ -24,6 +24,8 @@ class ModelContextManager { Json get(const std::string& key) const; + Json dump() const; + private: std::shared_ptr<ModelGraphContextStack> context_stack_; std::vector<std::string> keys_; diff --git a/ark/model/model_graph_impl.cpp b/ark/model/model_graph_impl.cpp index 7c1ea3fb5..7c72a7dd2 100644 --- a/ark/model/model_graph_impl.cpp +++ b/ark/model/model_graph_impl.cpp @@ -52,14 +52,15 @@ Json ModelGraphContextStack::get(const std::string &key) const { return Json(); } -std::map<std::string, Json> ModelGraphContextStack::get_all() const { - std::map<std::string, Json> cur; +Json ModelGraphContextStack::dump() const { + Json j = Json::object(); for (const auto &pair : this->storage_) { - if (!pair.second.empty()) { - cur[pair.first] = *pair.second.back(); + j[pair.first] = Json::array(); + for (const auto &value : pair.second) { + j[pair.first].emplace_back(*value); } } - return cur; + return j; } ModelGraph::Impl::Impl(const ModelGraph::Impl &other) { *this = other; } @@ -216,7 +217,7 @@ ModelNodeRef ModelGraph::Impl::add_op(ModelOpRef op) { producer->consumers.push_back(node); } - node->context = context_stack_->get_all(); + node->context = context_stack_->dump(); nodes_.push_back(node); return node; diff --git a/ark/model/model_graph_impl.hpp b/ark/model/model_graph_impl.hpp index 62944f999..b9646d057 100644 --- a/ark/model/model_graph_impl.hpp +++ b/ark/model/model_graph_impl.hpp @@ -38,7 +38,7 @@ class ModelGraphContextStack { Json get(const std::string &key) const; - std::map<std::string, Json> get_all() const; + Json dump() const; }; class ModelGraph::Impl { diff --git a/ark/model/model_json.cpp b/ark/model/model_json.cpp index c2099e2c9..31fb24d51 100644 --- a/ark/model/model_json.cpp +++ b/ark/model/model_json.cpp @@ -5,6 +5,7 @@ #include +#include "ark/dims.hpp" #include "logging.hpp" static std::stringstream &idnt(std::stringstream &ss, int indent) { @@ -26,14 +27,46 @@ static void verify_format_json(const std::string &name, const Json &json, const std::vector<std::string> &array_fields) { for (const auto &field : required_fields) { if (!json.contains(field)) { - ERR(ErrorType, - name + ": " + field + " not found. Given: " + json.dump()); + ERR(ErrorType, name, ": ", field, + " not found. Given: ", json.dump()); } } for (const auto &field : array_fields) { if (!json.at(field).is_array()) { - ERR(ErrorType, name + ": " + field + - " is not an array. Given: " + json.dump()); + ERR(ErrorType, name, ": ", field, + " is not an array. Given: ", json.dump()); + } + } +} + +template <typename ErrorType, bool ZeroNotAllowed> +static void verify_format_dims(const std::string &name, const Json &json, const std::vector<std::string> &dims_fields) { + for (const auto &field : dims_fields) { + if (!json.at(field).is_array()) { + ERR(ErrorType, name, ": ", field, + " is not an array.
Given: ", json.dump()); + } + std::vector dims; + try { + dims = json.at(field).get>(); + } catch (const std::exception &e) { + ERR(ErrorType, name, ": ", field, + " is not an array of integers. Given: ", json.dump()); + } + for (const auto &dim : dims) { + if (dim < 0) { + ERR(ErrorType, name, ": ", field, + " contains negative value. Given: ", json.dump()); + } + } + if (ZeroNotAllowed) { + for (const auto &dim : dims) { + if (dim == 0) { + ERR(ErrorType, name, ": ", field, + " contains zero value. Given: ", json.dump()); + } + } } } } @@ -52,10 +85,15 @@ static void verify_format_tensor(const Json &json) { const std::vector required_fields = { "Id", "DataType", "Shape", "Strides", "Offsets", "PaddedShape", "Buffer"}; - const std::vector array_fields = {"Shape", "Strides", - "Offsets", "PaddedShape"}; - verify_format_json("TensorJson", json, required_fields, - array_fields); + const std::vector dims_fields = {"Shape", "Strides", "Offsets", + "PaddedShape"}; + verify_format_json("TensorJson", json, required_fields, {}); + verify_format_dims("TensorJson", json, + { + "Offsets", + }); + verify_format_dims("TensorJson", json, + {"Shape", "Strides", "PaddedShape"}); verify_format_buffer(json.at("Buffer")); } @@ -264,9 +302,16 @@ static void verify_format_plan(const Json &json) { "NumWarpsPerProcessor", "TaskInfos", "ProcessorGroups"}; + if (!json.is_object()) { + std::string dumped = json.dump(); + if (dumped.size() > 100) { + dumped = dumped.substr(0, 100) + "..."; + } + ERR(PlanError, "Plan should be a JSON object. Given: ", dumped); + } for (const auto &field : required_fields) { if (!json.contains(field)) { - ERR(PlanError, field + " not found"); + ERR(PlanError, field, " not found"); } } if (!json.at("TaskInfos").is_array()) { diff --git a/ark/model/model_node.hpp b/ark/model/model_node.hpp index ca97f4540..437875676 100644 --- a/ark/model/model_node.hpp +++ b/ark/model/model_node.hpp @@ -28,7 +28,7 @@ class ModelNode { UniqueList producers; /// Graph context of this node. 
- std::map context; + Json context; }; } // namespace ark diff --git a/ark/model/model_op.cpp b/ark/model/model_op.cpp index 5db8576e8..8f222b75d 100644 --- a/ark/model/model_op.cpp +++ b/ark/model/model_op.cpp @@ -16,6 +16,7 @@ #include "ops/ops_math.hpp" #include "ops/ops_matmul.hpp" #include "ops/ops_noop.hpp" +#include "ops/ops_placeholder.hpp" #include "ops/ops_reduce.hpp" #include "ops/ops_refer.hpp" #include "ops/ops_reshape.hpp" @@ -78,6 +79,7 @@ const ModelOpType ModelOpT::from_name(const std::string &type_name) { MODEL_OP_TYPE_REGISTER(Sqrt); MODEL_OP_TYPE_REGISTER(Sub); MODEL_OP_TYPE_REGISTER(Tensor); + MODEL_OP_TYPE_REGISTER(Placeholder); MODEL_OP_TYPE_REGISTER(Transpose); MODEL_OP_TYPE_REGISTER(SendPacket); MODEL_OP_TYPE_REGISTER(RecvPacket); diff --git a/ark/model/model_op.hpp b/ark/model/model_op.hpp index f7323d6c0..ab261eb20 100644 --- a/ark/model/model_op.hpp +++ b/ark/model/model_op.hpp @@ -57,7 +57,7 @@ class ModelOp { virtual Json default_config( [[maybe_unused]] const ArchRef arch = ARCH_ANY) const { - return {{"NumTasks", 0}, {"NumWarps", 0}, {"SramBytes", 0}}; + return {{"NumWarps", 0}, {"SramBytes", 0}}; } void set_name(const std::string &name) { name_ = name; } diff --git a/ark/model/model_tensor.cpp b/ark/model/model_tensor.cpp index 713fbf62c..068783045 100644 --- a/ark/model/model_tensor.cpp +++ b/ark/model/model_tensor.cpp @@ -92,6 +92,16 @@ size_t ModelTensor::shape_bytes() const { return shape_.nelems() * data_type_->bytes(); } +void *ModelTensor::data() const { + return buffer_->data(); +} + +void *ModelTensor::data(void *data) { + return buffer_->data(data); +} + +bool ModelTensor::is_external() const { return buffer_->is_external(); } + Json ModelTensor::serialize() const { Json j; j["Id"] = id_; diff --git a/ark/model/model_tensor.hpp b/ark/model/model_tensor.hpp index 7c7afac2c..8c892f2b4 100644 --- a/ark/model/model_tensor.hpp +++ b/ark/model/model_tensor.hpp @@ -37,6 +37,12 @@ class ModelTensor { size_t shape_bytes() const; + void *data() const; + + void *data(void *data); + + bool is_external() const; + Json serialize() const; static std::shared_ptr deserialize(const Json &serialized); diff --git a/ark/ops/ops_arithmetic_test.cpp b/ark/ops/ops_arithmetic_test.cpp index 772da3276..6a878c667 100644 --- a/ark/ops/ops_arithmetic_test.cpp +++ b/ark/ops/ops_arithmetic_test.cpp @@ -216,6 +216,21 @@ ark::unittest::State test_add_broadcast() { return ark::unittest::SUCCESS; } +ark::unittest::State test_add_offset() { + { + ark::Model m; + ark::Tensor t0 = m.tensor({2, 64}, ark::FP16, {4, 128}, {2, 64}); + ark::Tensor t1 = m.tensor({2, 64}, ark::FP16); + ark::Tensor out = m.add(t0, t1); + + auto result = ark::op_test("add_offset", m, {t0, t1}, {out}, + baseline_add); + UNITTEST_LOG(result); + UNITTEST_EQ(result.max_diff[0], 0.0f); + } + return ark::unittest::SUCCESS; +} + ark::unittest::State test_add_invalid() { { ark::Model m; @@ -421,6 +436,7 @@ int main() { UNITTEST(test_add_bf16); UNITTEST(test_add_overwrite); UNITTEST(test_add_broadcast); + UNITTEST(test_add_offset); UNITTEST(test_add_invalid); UNITTEST(test_sub_fp32); UNITTEST(test_sub_invalid); diff --git a/ark/ops/ops_broadcast.cpp b/ark/ops/ops_broadcast.cpp index e5559fc32..2fd02b801 100644 --- a/ark/ops/ops_broadcast.cpp +++ b/ark/ops/ops_broadcast.cpp @@ -94,7 +94,7 @@ ModelOpBroadcast2::ModelOpBroadcast2(const std::string &type_name, std::string ModelOpBroadcast2::impl_name(const Json &config) const { check_fields_config(config, {"NumWarps", "Tile"}); int num_warps = config["NumWarps"]; - auto 
&tile_shape = config["Tile"]; + Dims unit_out_dims(config.at("Tile").get>()); return function_name_string( pascal_to_snake(type()->type_name()), @@ -104,8 +104,8 @@ std::string ModelOpBroadcast2::impl_name(const Json &config) const { vec_string(read_tensors_[1]->shape().dims4()), vec_string(write_tensors_[0]->strides().dims4()), vec_string(write_tensors_[0]->shape().dims4()), - vec_string({1, 1, tile_shape[0], tile_shape[1]}), - std::to_string(num_warps), std::to_string(0)}); + vec_string(unit_out_dims.dims4()), std::to_string(num_warps), + std::to_string(0)}); } std::vector ModelOpBroadcast2::impl_args([ diff --git a/ark/ops/ops_communication_test.cpp b/ark/ops/ops_communication_test.cpp index 8cdad41b2..de7c42833 100644 --- a/ark/ops/ops_communication_test.cpp +++ b/ark/ops/ops_communication_test.cpp @@ -25,7 +25,6 @@ ark::unittest::State test_communication_send_recv_unidir() { } ark::DefaultExecutor exe(model, gpu_id); - exe.compile(); if (gpu_id == 0) { std::vector data(1024); @@ -60,15 +59,14 @@ ark::unittest::State test_communication_send_recv_unidir() { ark::Model model(gpu_id, 2); ark::Tensor tns = model.tensor({1024}, ark::FP16); if (gpu_id == 1) { - tns = model.send(tns, 0, 0); - model.send_done(tns); + auto out_tns = model.send(tns, 0, 0); + model.send_done(out_tns); } if (gpu_id == 0) { tns = model.recv(tns, 1, 0); } ark::DefaultExecutor exe(model, gpu_id); - exe.compile(); if (gpu_id == 1) { std::vector data(1024); @@ -117,7 +115,6 @@ ark::unittest::State test_communication_send_recv_bidir() { tns2 = model.recv(tns2_data, remote_gpu_id, tag); ark::DefaultExecutor exe(model, gpu_id); - exe.compile(); std::vector data(1024); std::iota(data.begin(), data.end(), ark::half_t(gpu_id + 1)); @@ -161,7 +158,6 @@ ark::unittest::State test_communication_send_recv_bidir() { ark::Tensor sum = model.add(tns2, tns_data); ark::DefaultExecutor exe(model, gpu_id); - exe.compile(); std::vector data(1024); std::iota(data.begin(), data.end(), ark::half_t(gpu_id + 1)); @@ -232,7 +228,6 @@ ark::unittest::State test_communication_send_recv_bidir_sm() { tns2 = model.recv(tns2_data, remote_gpu_id, tag); ark::DefaultExecutor exe(model, gpu_id, nullptr, {config_rule}); - exe.compile(); std::vector data(1024); std::iota(data.begin(), data.end(), ark::half_t(gpu_id + 1)); @@ -276,7 +271,6 @@ ark::unittest::State test_communication_send_recv_bidir_sm() { ark::Tensor sum = model.add(tns2, tns_data); ark::DefaultExecutor exe(model, gpu_id, nullptr, {config_rule}); - exe.compile(); std::vector data(1024); std::iota(data.begin(), data.end(), ark::half_t(gpu_id + 1)); @@ -319,7 +313,6 @@ ark::unittest::State test_communication_send_packet() { } ark::DefaultExecutor exe(model, gpu_id); - exe.compile(); if (gpu_id == 0) { std::vector data(1024); @@ -362,7 +355,6 @@ ark::unittest::State test_communication_send_recv_reduce_packet() { model.recv_packet(shard_tensors[peer_gpu_id], peer_gpu_id, 1, 1); ark::DefaultExecutor exe(model, gpu_id); - exe.compile(); std::vector data(1024); std::iota(data.begin(), data.end(), 1.0f); @@ -433,8 +425,8 @@ ark::unittest::State test_communication_send_recv_reduce() { ark::Planner planner(model, gpu_id); planner.install_config_rule(config_rule); - ark::Executor exe(gpu_id, nullptr, "Executor", planner.plan()); - exe.compile(); + ark::Executor exe; + exe.compile(planner.plan(), gpu_id); std::vector data(1024); std::iota(data.begin(), data.end(), 1.0f); diff --git a/ark/ops/ops_embedding.cpp b/ark/ops/ops_embedding.cpp index 2e2626d4c..2d6b63720 100644 --- a/ark/ops/ops_embedding.cpp 
+++ b/ark/ops/ops_embedding.cpp @@ -21,9 +21,9 @@ ModelOpEmbedding::ModelOpEmbedding(ModelTensorRef input, ModelTensorRef weight, if (output) { check_match_data_type(weight, output); } else { - Dims input_shape = input->shape().dims4(); - Dims output_shape(input_shape[1], input_shape[2], input_shape[3], - weight->shape()[-1]); + auto shape_vec = input->shape().vector(); + shape_vec.push_back(weight->shape()[-1]); + Dims output_shape(shape_vec); output = std::make_shared<ModelTensor>( weight->data_type(), std::make_shared<ModelBuffer>(), output_shape); } diff --git a/ark/ops/ops_identity_test.cpp b/ark/ops/ops_identity_test.cpp index a6e49c9c0..eb8d3f4d4 100644 --- a/ark/ops/ops_identity_test.cpp +++ b/ark/ops/ops_identity_test.cpp @@ -58,7 +58,6 @@ ark::unittest::State test_ops_identity() { // Create an executor ark::DefaultExecutor exe(model); - exe.compile(); int num_elem = 2 * 3 * 4 * 5; diff --git a/ark/ops/ops_matmul.cpp b/ark/ops/ops_matmul.cpp index dca349f44..823bf2656 100644 --- a/ark/ops/ops_matmul.cpp +++ b/ark/ops/ops_matmul.cpp @@ -98,7 +98,7 @@ ModelOpMatmul::ModelOpMatmul(ModelTensorRef input, ModelTensorRef other, } std::string ModelOpMatmul::impl_name(const Json &config) const { - check_fields_config(config, {"NumWarps", "SramBytes", "TileShapeMNK"}); + check_fields_config(config, {"NumWarps", "SramBytes", "Tile"}); check_fields_args(args_, {"TransposeInput", "TransposeOther"}); bool trans_input = args_.at("TransposeInput").value<bool>(); @@ -125,6 +125,7 @@ std::string ModelOpMatmul::impl_name(const Json &config) const { Dims other_shape_dims4 = other->shape().dims4(); Dims input_dim_nc{input_shape_dims4[0], input_shape_dims4[1]}; Dims other_dim_nc{other_shape_dims4[0], other_shape_dims4[1]}; + Dims output_dim_nc = broadcast_shape(input_dim_nc, other_dim_nc); Dims strides_acdb{ input->strides().dims4()[-1], output->strides().dims4()[-1], @@ -132,14 +133,14 @@ std::string ModelOpMatmul::impl_name(const Json &config) const { int num_warps = config["NumWarps"]; int smem_bytes = config["SramBytes"]; - Dims tile_shape_mnk = config["TileShapeMNK"].get<std::vector<DimType>>(); - if (tile_shape_mnk.ndims() != 3) { - ERR(PlanError, "TileShapeMNK should have 3 elements"); - } - for (int i = 0; i < 3; ++i) { - if (padded_problem_size[i] % tile_shape_mnk[i] != 0) { - ERR(PlanError, "output padded shape MNK ", padded_problem_size, - " should be divisible by tile shape MNK ", tile_shape_mnk); + Dims tile_shape = config["Tile"].get<std::vector<DimType>>(); + if (tile_shape.ndims() != 2) { + ERR(PlanError, "Tile should have 2 elements"); + } + for (int i = 0; i < 2; ++i) { + if (padded_output_shape[i - 2] % tile_shape[i] != 0) { + ERR(PlanError, "output padded shape ", padded_output_shape, + " should be divisible by tile shape ", tile_shape); } } @@ -156,16 +157,51 @@ std::string ModelOpMatmul::impl_name(const Json &config) const { inner_stride_b = other->strides().dims4()[-2]; } + DimType size_a = inner_stride_a * output->strides()[-2]; + DimType size_b = inner_stride_b * output->strides()[-1]; + DimType size_c = output->strides()[-2] * output->strides()[-1]; + DimType batch_stride_c_a = input_dim_nc[1] == 1 ? 0 : size_a; + DimType batch_stride_n_a = + input_dim_nc[0] == 1 ? 0 : size_a * input_dim_nc[1]; + DimType batch_stride_c_b = other_dim_nc[1] == 1 ? 0 : size_b; + DimType batch_stride_n_b = + other_dim_nc[0] == 1 ? 0 : size_b * other_dim_nc[1]; + DimType batch_stride_c_c = output_dim_nc[1] == 1 ? 0 : size_c; + DimType batch_stride_n_c = + output_dim_nc[0] == 1 ?
0 : size_c * output_dim_nc[1]; + if (config.contains("BatchStrideNA")) { + batch_stride_n_a = config["BatchStrideNA"].get<DimType>(); + } + if (config.contains("BatchStrideNB")) { + batch_stride_n_b = config["BatchStrideNB"].get<DimType>(); + } + if (config.contains("BatchStrideNC")) { + batch_stride_n_c = config["BatchStrideNC"].get<DimType>(); + } + if (config.contains("BatchStrideCA")) { + batch_stride_c_a = config["BatchStrideCA"].get<DimType>(); + } + if (config.contains("BatchStrideCB")) { + batch_stride_c_b = config["BatchStrideCB"].get<DimType>(); + } + if (config.contains("BatchStrideCC")) { + batch_stride_c_c = config["BatchStrideCC"].get<DimType>(); + } + return function_name_string("matmul", { vec_string(output->strides().dims4()), vec_string(input_dim_nc), vec_string(other_dim_nc), - vec_string(tile_shape_mnk), + vec_string(tile_shape), vec_string(padded_problem_size), vec_string(strides_acdb), - std::to_string(inner_stride_a), - std::to_string(inner_stride_b), + std::to_string(batch_stride_n_a), + std::to_string(batch_stride_c_a), + std::to_string(batch_stride_n_b), + std::to_string(batch_stride_c_b), + std::to_string(batch_stride_n_c), + std::to_string(batch_stride_c_c), std::to_string(trans_input), std::to_string(trans_other), std::to_string(num_warps), @@ -173,8 +209,8 @@ std::string ModelOpMatmul::impl_name(const Json &config) const { }); } -std::vector<ModelOpArg> ModelOpMatmul::impl_args([ - [maybe_unused]] const Json &config) const { +std::vector<ModelOpArg> ModelOpMatmul::impl_args( + [[maybe_unused]] const Json &config) const { return {result_tensors_[0], read_tensors_[0], read_tensors_[1]}; } @@ -191,29 +227,17 @@ static const Json get_default_config(const ArchRef arch, DimType tm = (mnk[0] > mnk[1]) ? 256 : 128; DimType tn = (mnk[0] > mnk[1]) ? 128 : 256; if (arch->belongs_to(ARCH_CUDA_80) && data_type == FP32.ref()) { - return {{"NumWarps", 8}, - {"SramBytes", 147456}, - {"TileShapeMNK", {tm, tn, 32}}}; + return {{"NumWarps", 8}, {"SramBytes", 147456}, {"Tile", {tm, tn}}}; } else if (arch->belongs_to(ARCH_CUDA_80) && data_type == FP16.ref()) { - return {{"NumWarps", 8}, - {"SramBytes", 147456}, - {"TileShapeMNK", {tm, tn, 64}}}; + return {{"NumWarps", 8}, {"SramBytes", 147456}, {"Tile", {tm, tn}}}; } else if (arch->belongs_to(ARCH_CUDA_80) && data_type == BF16.ref()) { - return {{"NumWarps", 8}, - {"SramBytes", 147456}, - {"TileShapeMNK", {tm, tn, 64}}}; + return {{"NumWarps", 8}, {"SramBytes", 147456}, {"Tile", {tm, tn}}}; } else if (arch->belongs_to(ARCH_ROCM_942) && data_type == FP32.ref()) { - return {{"NumWarps", 4}, - {"SramBytes", 24672}, - {"TileShapeMNK", {tm, tn, 16}}}; + return {{"NumWarps", 4}, {"SramBytes", 24672}, {"Tile", {tm, tn}}}; } else if (arch->belongs_to(ARCH_ROCM_942) && data_type == FP16.ref()) { - return {{"NumWarps", 4}, - {"SramBytes", 24672}, - {"TileShapeMNK", {tm, tn, 32}}}; + return {{"NumWarps", 4}, {"SramBytes", 24672}, {"Tile", {tm, tn}}}; } else if (arch->belongs_to(ARCH_ROCM_942) && data_type == BF16.ref()) { - return {{"NumWarps", 4}, - {"SramBytes", 24624}, - {"TileShapeMNK", {tm, tn, 32}}}; + return {{"NumWarps", 4}, {"SramBytes", 24624}, {"Tile", {tm, tn}}}; } ERR(InternalError, "Unexpected error"); return {};
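The batch-stride selection above replaces the old in-kernel broadcasting branches: a batch dimension of extent 1 collapses to stride 0, so every batch index resolves to the same matrix. A small host-side sketch of the same rule, under illustrative sizes (the helper name is made up; `size_a` stands for the per-matrix element count computed above):

```cpp
#include <cstdio>

// Mirrors the selection logic above: extent 1 means the dimension is
// broadcast, so its stride collapses to 0 and every uop reads the same data.
long batch_stride(long dim_extent, long matrix_size) {
    return dim_extent == 1 ? 0 : matrix_size;
}

int main() {
    // Example: A has NC dims [1, 8] (broadcast along N); one A matrix holds
    // size_a elements.
    const long size_a = 64 * 128;
    long stride_c_a = batch_stride(/*extent C=*/8, size_a);
    long stride_n_a = batch_stride(/*extent N=*/1, size_a * 8);
    std::printf("A: strideC=%ld strideN=%ld\n", stride_c_a, stride_n_a);
    // The kernel then indexes pA = &A[un * strideN + uc * strideC], so the
    // N index has no effect on A while B and C still advance per batch.
    return 0;
}
```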
config.at("Tile")[1]; if (mnk[0] % tile_x != 0 || mnk[1] % tile_y != 0) { - ERR(PlanError, "output padded shape MNK ", mnk, - " should be divisible by tile shape MNK ", - config.at("TileShapeMNK")); - } - Dims result_shape = result->shape().dims4(); - size_t num_tasks = result_shape[0] * result_shape[1]; - num_tasks *= mnk[0] / tile_x; - num_tasks *= mnk[1] / tile_y; - config["NumTasks"] = num_tasks; + ERR(PlanError, "output padded shape ", Dims{mnk[0], mnk[1]}, + " should be divisible by tile shape ", config.at("Tile")); + } return config; } diff --git a/ark/ops/ops_placeholder.cpp b/ark/ops/ops_placeholder.cpp new file mode 100644 index 000000000..b654aac39 --- /dev/null +++ b/ark/ops/ops_placeholder.cpp @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include "ops_placeholder.hpp" + +#include "buffer_registry.hpp" +#include "logging.hpp" +#include "ops_common.hpp" + +namespace ark { + +ModelOpPlaceholder::ModelOpPlaceholder(ModelBufferRef buffer, const Dims &shape, + ModelDataType data_type, + const Dims &strides, const Dims &offsets, + const Dims &padded_shape, void *data) + : ModelOp("Placeholder", true) { + if (!buffer) { + buffer = std::make_shared(-1, true); + } + + BufferRegistry::get_instance().set(buffer->id(), data, -1, true); + + ModelTensorRef tensor = std::make_shared( + data_type, buffer, shape, strides, offsets, padded_shape); + + result_tensors_.emplace_back(tensor); + + verify(); +} + +Tensor Model::placeholder(const Dims &shape, const DataType &data_type, + const Dims &strides, const Dims &offsets, + const Dims &padded_shape, int rank, void *data, + const std::string &name) { + if (rank != -1) { + if (rank == this->rank()) { + rank = -1; + } else if (rank < 0 || rank >= this->world_size()) { + ERR(ModelError, "Invalid rank %d", rank); + } + } + return impl_ + ->create_op( + name, std::make_shared(rank, true), shape, + data_type.ref(), strides, offsets, padded_shape, data) + ->result_tensors()[0]; +} + +} // namespace ark diff --git a/ark/ops/ops_placeholder.hpp b/ark/ops/ops_placeholder.hpp new file mode 100644 index 000000000..14ae53144 --- /dev/null +++ b/ark/ops/ops_placeholder.hpp @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef ARK_OPS_PLACEHOLDER_HPP_ +#define ARK_OPS_PLACEHOLDER_HPP_ + +#include "ark/model.hpp" +#include "model/model_op.hpp" + +namespace ark { + +class ModelOpPlaceholder : public ModelOp { + public: + ModelOpPlaceholder() = default; + ModelOpPlaceholder(ModelBufferRef buffer, const Dims &shape, + ModelDataType data_type, const Dims &strides, + const Dims &offsets, const Dims &padded_shape, + void *data = nullptr); +}; + +} // namespace ark + +#endif // ARK_OPS_PLACEHOLDER_HPP_ diff --git a/ark/ops/ops_placeholder_test.cpp b/ark/ops/ops_placeholder_test.cpp new file mode 100644 index 000000000..e91629fc8 --- /dev/null +++ b/ark/ops/ops_placeholder_test.cpp @@ -0,0 +1,103 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+ +#include "ark/executor.hpp" +#include "gpu/gpu.hpp" +#include "logging.hpp" +#include "model/model_op.hpp" +#include "ops_test_common.hpp" + +ark::unittest::State test_ops_placeholder() { + ark::Model model; + ark::Dims shape{10, 1}; + + // Allocate GPU memory for the external buffer + float *d_ext_buffer = nullptr; + UNITTEST_EQ(ark::gpuMalloc(&d_ext_buffer, shape.nelems() * sizeof(float)), + ark::gpuSuccess); + + // Initialize GPU Memory + std::vector h_ext_buffer(shape.nelems()); + std::iota(h_ext_buffer.begin(), h_ext_buffer.end(), 1.0f); + UNITTEST_EQ(ark::gpuMemcpy(d_ext_buffer, h_ext_buffer.data(), + shape.nelems() * sizeof(float), + ark::gpuMemcpyHostToDevice), + ark::gpuSuccess); + + // Associate the initialized device buffer with a tensor produced from a + // placeholder operation + ark::Tensor tns = + model.placeholder(shape, ark::FP32, {}, {}, {}, -1, d_ext_buffer); + + ark::Tensor res = model.add(tns, 1.0); + + ark::DefaultExecutor exe(model); + + exe.launch(); + exe.run(1); + exe.stop(); + + UNITTEST_EQ(exe.tensor_address(tns), d_ext_buffer); + + // Copy tensor data from GPU to CPU + std::vector h_res(shape.nelems(), 0.0f); + exe.tensor_read(res, h_res); + + for (auto i = 0; i < shape.nelems(); ++i) { + UNITTEST_EQ(h_res[i], i + 2); + } + + UNITTEST_EQ(ark::gpuFree(d_ext_buffer), ark::gpuSuccess); + + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_placeholder_delayed_binding() { + ark::Model model; + ark::Dims shape{10, 1}; + + float *d_ext_buffer = nullptr; + UNITTEST_EQ(ark::gpuMalloc(&d_ext_buffer, shape.nelems() * sizeof(float)), + ark::gpuSuccess); + + std::vector h_ext_buffer(shape.nelems()); + std::iota(h_ext_buffer.begin(), h_ext_buffer.end(), 1.0f); + UNITTEST_EQ(ark::gpuMemcpy(d_ext_buffer, h_ext_buffer.data(), + shape.nelems() * sizeof(float), + ark::gpuMemcpyHostToDevice), + ark::gpuSuccess); + + // Create a placeholder tensor without binding the buffer yet + ark::Tensor tns = + model.placeholder(shape, ark::FP32, {}, {}, {}, -1, nullptr); + + ark::Tensor res = model.add(tns, 1.0); + + ark::DefaultExecutor exe(model); + + // Delay the binding by providing the tensor-to-address mapping at launch + std::unordered_map tensor_bindings; + tensor_bindings[tns] = reinterpret_cast(d_ext_buffer); + + exe.launch(tensor_bindings); + exe.run(1); + exe.stop(); + + // Copy tensor data from GPU to CPU + std::vector h_res(shape.nelems(), 0.0f); + exe.tensor_read(res, h_res); + + for (auto i = 0; i < shape.nelems(); ++i) { + UNITTEST_EQ(h_res[i], i + 2); + } + UNITTEST_EQ(ark::gpuFree(d_ext_buffer), ark::gpuSuccess); + + return ark::unittest::SUCCESS; +} + +int main() { + ark::init(); + UNITTEST(test_ops_placeholder); + UNITTEST(test_placeholder_delayed_binding); + return ark::unittest::SUCCESS; +} diff --git a/ark/ops/ops_reshape.cpp b/ark/ops/ops_reshape.cpp index aac22b71a..f8f5e942c 100644 --- a/ark/ops/ops_reshape.cpp +++ b/ark/ops/ops_reshape.cpp @@ -11,22 +11,28 @@ namespace ark { -// Reshape `input` to `shape`. This interface does not support -1 as a dimension -// of `shape`, because Dims does not allow -1 as a valid dimension. +// Reshape `input` to `inferred_shape`. This interface does not support -1 as a +// dimension of `inferred_shape`. 
static void reshape_helper(ModelTensorRef input, const Dims &inferred_shape, bool allowzero, Dims &new_shape, Dims &new_strides, Dims &new_offs) { const auto &orig_shape = input->shape(); const auto &orig_strides = input->strides(); const auto &orig_offsets = input->offsets(); + + std::stringstream ss; + ss << "reshape failed as the number of elements mismatch: reshape from " + << orig_shape << " to " << inferred_shape + << " (allowzero = " << allowzero << ")"; + auto nelems_mismatch_error = ss.str(); + // Calculate the new shape std::vector new_shape_vec; if (inferred_shape.ndims() == 0) { // Convert to a scalar new_shape_vec.emplace_back(1); if (orig_shape.nelems() != 1) { - ERR(ModelError, "number of elements mismatch: reshape from ", - orig_shape, " to ", inferred_shape); + ERR(ModelError, nelems_mismatch_error); } } else { DimType total_size = 1; @@ -46,13 +52,12 @@ static void reshape_helper(ModelTensorRef input, const Dims &inferred_shape, } } if (orig_shape.nelems() != total_size) { - ERR(ModelError, "number of elements mismatch: reshape from ", - orig_shape, " to ", inferred_shape); + ERR(ModelError, nelems_mismatch_error); } } new_shape = new_shape_vec; - std::stringstream ss; + ss = std::stringstream(); ss << "reshape failed as the strides of the input tensor is incompatible " "with the new shape. A workaround is copying the input tensor to a " "new tensor, so that the data becomes sequential in memory. "; @@ -104,12 +109,26 @@ static void reshape_helper(ModelTensorRef input, const Dims &inferred_shape, } else { if (orig_strides[orig_idx] != orig_shape[orig_idx] || orig_offsets[orig_idx] != 0) { - ERR(ModelError, incompatible_strides_error); - } - orig_idx--; - if (orig_idx >= 0) { - orig_shape_stack *= orig_shape[orig_idx]; - orig_strides_stack *= orig_strides[orig_idx]; + if (orig_shape[orig_idx] != 1 || reverse_strides.empty()) { + ERR(ModelError, incompatible_strides_error); + } + *reverse_strides.rbegin() *= orig_strides[orig_idx]; + DimType new_off = orig_offsets[orig_idx]; + for (auto i = orig_idx + 1; i < orig_strides.ndims(); i++) { + new_off *= orig_strides[i]; + } + *reverse_offsets.rbegin() = new_off; + orig_idx--; + if (orig_idx >= 0) { + orig_shape_stack = orig_shape[orig_idx]; + orig_strides_stack = orig_strides[orig_idx]; + } + } else { + orig_idx--; + if (orig_idx >= 0) { + orig_shape_stack *= orig_shape[orig_idx]; + orig_strides_stack *= orig_strides[orig_idx]; + } } } } diff --git a/ark/ops/ops_reshape_test.cpp b/ark/ops/ops_reshape_test.cpp index 1128c955a..550476199 100644 --- a/ark/ops/ops_reshape_test.cpp +++ b/ark/ops/ops_reshape_test.cpp @@ -9,7 +9,6 @@ void test_reshape_checker(ark::Model &m, ark::Tensor t0, ark::Tensor t1, const std::string &) { ark::DefaultExecutor exe(m); - exe.compile(); std::vector data_vec(t0.shape().nelems()); std::iota(data_vec.begin(), data_vec.end(), 1.0f); @@ -208,6 +207,38 @@ ark::unittest::State test_reshape_padded() { test_reshape_checker(model, tns0, tns1, "test_reshape_padded"); } + { + ark::Model model; + ark::Tensor tns0 = + model.tensor({1024, 1, 128}, ark::FP32, {1024, 64, 128}, {0, 8, 0}); + ark::Tensor tns1 = model.reshape(tns0, {1024, 128}); + + UNITTEST_EQ(tns1.shape(), ark::Dims(1024, 128)); + UNITTEST_EQ(tns1.strides(), ark::Dims(1024, 8192)); + UNITTEST_EQ(tns1.offsets(), ark::Dims(0, 1024)); + + // For preventing optimize-out + model.noop(tns0); + model.noop(tns1); + + test_reshape_checker(model, tns0, tns1, "test_reshape_padded"); + } + { + ark::Model model; + ark::Tensor tns0 = + model.tensor({1024, 2, 
128}, ark::FP32, {1024, 64, 128}, {0, 8, 0}); + ark::Tensor tns1 = model.reshape(tns0, {1024, 256}); + + UNITTEST_EQ(tns1.shape(), ark::Dims(1024, 256)); + UNITTEST_EQ(tns1.strides(), ark::Dims(1024, 8192)); + UNITTEST_EQ(tns1.offsets(), ark::Dims(0, 1024)); + + // For preventing optimize-out + model.noop(tns0); + model.noop(tns1); + + test_reshape_checker(model, tns0, tns1, "test_reshape_padded"); + } return ark::unittest::SUCCESS; } @@ -269,6 +300,12 @@ ark::unittest::State test_reshape_invalid() { ark::Tensor tns = model.tensor({64, 256}, ark::FP32, {64, 512}); UNITTEST_THROW(model.reshape(tns, {16384}), ark::ModelError); } + { + ark::Model model; + ark::Tensor tns = + model.tensor({1024, 1}, ark::FP32, {1024, 64}, {0, 8}); + UNITTEST_THROW(model.reshape(tns, {1024}), ark::ModelError); + } return ark::unittest::SUCCESS; } diff --git a/ark/ops/ops_scalar_test.cpp b/ark/ops/ops_scalar_test.cpp index 6afc9e1ad..47a5b40bd 100644 --- a/ark/ops/ops_scalar_test.cpp +++ b/ark/ops/ops_scalar_test.cpp @@ -66,7 +66,6 @@ ark::unittest::State test_scalar_assign_fp16() { ark::Tensor t = m.constant(7, ark::Dims(4, 2, 50), ark::FP16); ark::DefaultExecutor exe(m); - exe.compile(); exe.launch(); exe.run(1); @@ -84,7 +83,6 @@ ark::unittest::State test_scalar_assign_fp16() { ark::Tensor out = m.copy(7, t); ark::DefaultExecutor exe(m); - exe.compile(); std::vector data(4 * 2 * 50, 3); exe.tensor_write(t, data); @@ -109,7 +107,6 @@ ark::unittest::State test_scalar_assign_fp32() { ark::Tensor out = m.copy(7); ark::DefaultExecutor exe(m); - exe.compile(); exe.launch(); exe.run(1); diff --git a/ark/ops/ops_tensor_test.cpp b/ark/ops/ops_tensor_test.cpp index be6488ef1..a2c36fd8c 100644 --- a/ark/ops/ops_tensor_test.cpp +++ b/ark/ops/ops_tensor_test.cpp @@ -20,7 +20,6 @@ ark::unittest::State test_tensor_strides() { // Create an executor ark::DefaultExecutor exe(model); - exe.compile(); // Fill buffer data: {1.0, 2.0, 3.0, 4.0} std::vector data(shape.nelems()); @@ -53,7 +52,6 @@ ark::unittest::State test_tensor_memcpy() { // Create an executor ark::DefaultExecutor exe(model); - exe.compile(); // Fill buffer data: {1.0, 2.0, 3.0, ..., 3024.0} std::vector data(strides.nelems()); @@ -138,7 +136,6 @@ ark::unittest::State test_tensor_layout() { // Create an executor ark::DefaultExecutor exe(model); - exe.compile(); // Fill tensor data: {1.0, 2.0, 3.0, ..., 120.0} std::vector data(2 * 3 * 4 * 5); diff --git a/ark/ops/ops_test_common.cpp b/ark/ops/ops_test_common.cpp index 4e94d06a7..bfbe79a70 100644 --- a/ark/ops/ops_test_common.cpp +++ b/ark/ops/ops_test_common.cpp @@ -9,6 +9,7 @@ #include "ark/model.hpp" #include "ark/planner.hpp" #include "ark/random.hpp" +#include "cpu_timer.h" #include "env.h" #include "gpu/gpu_logging.hpp" #include "logging.hpp" @@ -38,7 +39,6 @@ OpsTestResult op_test( const std::vector &config_rules, bool print_on_error) { DefaultExecutor exe(model, -1, nullptr, config_rules); - exe.compile(); std::vector>> inputs_data_storages; std::vector inputs_data_refs; @@ -195,17 +195,21 @@ OpsTestResult op_test( // use a magic number here. int iter = 1000; exe.launch(); + double start = cpu_timer(); exe.run(iter); - float msec = exe.stop(); + exe.stop(); + double msec = (cpu_timer() - start) * 1000; result.iter = iter; result.msec_per_iter = msec / iter; } else { // Rough measure. 
int warmup_iter = 3; - float target_msec = 5000; + double target_msec = 5000; exe.launch(); + double start = cpu_timer(); exe.run(warmup_iter); - float warmup_msec = exe.stop(); + exe.stop(); + double warmup_msec = (cpu_timer() - start) * 1000; if (warmup_msec > target_msec) { // Warm-up was long enough. @@ -214,8 +218,10 @@ OpsTestResult op_test( } else { int iter = int(target_msec / warmup_msec) * warmup_iter; exe.launch(); + start = cpu_timer(); exe.run(iter); - float msec = exe.stop(); + exe.stop(); + double msec = (cpu_timer() - start) * 1000; result.iter = iter; result.msec_per_iter = msec / iter; } diff --git a/ark/ops/ops_transpose.cpp b/ark/ops/ops_transpose.cpp index d0f7581cc..b7a67c8c0 100644 --- a/ark/ops/ops_transpose.cpp +++ b/ark/ops/ops_transpose.cpp @@ -85,10 +85,20 @@ std::string ModelOpTranspose::impl_name(const Json &config) const { auto permutation = args_.at("Permutation").value(); auto perm_str = permutation_str(permutation); int num_warps = config["NumWarps"]; - auto &tile_shape = config["Tile"]; - Dims unit_out_dims{tile_shape[0], tile_shape[1]}; - if (tile_shape[0] < 0) unit_out_dims[0] = write_tensors_[0]->strides()[-2]; - if (tile_shape[1] < 0) unit_out_dims[1] = write_tensors_[0]->strides()[-1]; + Dims unit_out_dims{config["Tile"].get>()}; + auto result_tensor_shape = result_tensors_[0]->shape(); + if (unit_out_dims.ndims() > result_tensor_shape.ndims()) { + ERR(ModelError, + "The number of dimensions of Tile should be less than or equal to " + "the number of dimensions of the result tensor. Given Tile: ", + unit_out_dims, ", output tensor shape: ", result_tensor_shape); + } + int ndims = unit_out_dims.ndims(); + for (int i = 0; i < ndims; ++i) { + if (unit_out_dims[i] < 0) { + unit_out_dims[i] = result_tensor_shape[i - ndims]; + } + } return function_name_string( "transpose" + perm_str, diff --git a/ark/ops/ops_transpose_test.cpp b/ark/ops/ops_transpose_test.cpp index 999d2c6e9..139e1ee66 100644 --- a/ark/ops/ops_transpose_test.cpp +++ b/ark/ops/ops_transpose_test.cpp @@ -2,9 +2,13 @@ // Licensed under the MIT license. 
#include "ark/model.hpp" +#include "ark/planner.hpp" +#include "model/model_json.hpp" #include "ops_test_common.hpp" #include "unittest/unittest_utils.h" +#define SYNC_TEST 0 + template void baseline_transpose_0132(std::vector &outputs, const std::vector &output_shapes, @@ -53,6 +57,41 @@ void baseline_transpose_0231(std::vector &outputs, } }; +template +void baseline_transpose_0213(std::vector &outputs, + const std::vector &output_shapes, + const std::vector &inputs, + const std::vector &input_shapes, int) { + T *out = static_cast(outputs[0]); + T *in = static_cast(inputs[0]); + ark::Dims osh = output_shapes[0].dims4(); + ark::Dims ish = input_shapes[0].dims4(); + for (ark::DimType n = 0; n < ish[0]; ++n) { + for (ark::DimType c = 0; c < ish[1]; ++c) { + for (ark::DimType h = 0; h < ish[2]; ++h) { + for (ark::DimType w = 0; w < ish[3]; ++w) { + // out[n][h][c][w] = in[n][c][h][w] + out[w + c * osh[3] + h * osh[2] * osh[3] + + n * osh[1] * osh[2] * osh[3]] = + in[w + h * ish[3] + c * ish[3] * ish[2] + + n * ish[3] * ish[2] * ish[1]]; + } + } + } + } +}; + +template +void baseline_transpose_sync_test(std::vector &outputs, + const std::vector &, + const std::vector &inputs, + const std::vector &input_shapes, + int) { + T *out = static_cast(outputs[0]); + T *in = static_cast(inputs[0]); + ::memcpy(out, in, sizeof(T) * input_shapes[0].nelems()); +}; + ark::unittest::State test_transpose_0132_fp32() { ark::Model m; ark::Tensor t = m.tensor({5, 3, 32, 128}, ark::FP32); @@ -125,6 +164,80 @@ ark::unittest::State test_transpose_0231_bf16() { return ark::unittest::SUCCESS; } +ark::unittest::State test_transpose_0213_fp32() { + ark::Model m; + ark::Tensor t = m.tensor({5, 3, 32, 128}, ark::FP32); + ark::Tensor out = m.transpose(t, {0, 2, 1, 3}); + + auto result = ark::op_test("transpose_0213_fp32", m, {t}, {out}, + baseline_transpose_0213); + UNITTEST_LOG(result); + UNITTEST_EQ(result.max_diff[0], 0.0f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_transpose_0213_fp16() { + ark::Model m; + ark::PlannerContext ctx(m); + ctx.warp_range(0, 4); + ctx.sram_range(0, 0); + ctx.sync(false); + ctx.config(ark::Json({{"NumWarps", 4}, {"SramBytes", 0}, {"Tile", {8, 64}}}) + .dump()); + + ark::Tensor t = m.tensor({5, 256, 32, 128}, ark::FP16); + ark::Tensor out = m.transpose(t, {0, 2, 1, 3}); + + auto result = ark::op_test("transpose_0213_fp16", m, {t}, {out}, + baseline_transpose_0213); + UNITTEST_LOG(result); + UNITTEST_EQ(result.max_diff[0], 0.0f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_transpose_0213_bf16() { + ark::Model m; + ark::Tensor t = m.tensor({5, 3, 32, 128}, ark::BF16); + ark::Tensor out = m.transpose(t, {0, 2, 1, 3}); + + auto result = ark::op_test("transpose_0213_bf16", m, {t}, {out}, + baseline_transpose_0213); + UNITTEST_LOG(result); + UNITTEST_EQ(result.max_diff[0], 0.0f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_transpose_sync_test() { + ark::Model m; + ark::PlannerContext shared_ctx(m); + shared_ctx.warp_range(0, 4); + shared_ctx.sram_range(0, 0); + shared_ctx.sync(false); + + ark::Tensor in, t, out; + in = m.tensor({1, 16, 2, 64}, ark::FP16); + { + ark::PlannerContext ctx(m); + ctx.config( + ark::Json({{"NumWarps", 4}, {"SramBytes", 0}, {"Tile", {8, 64}}}) + .dump()); + t = m.transpose(in, {0, 2, 1, 3}); + } + { + ark::PlannerContext ctx(m); + ctx.config( + ark::Json({{"NumWarps", 4}, {"SramBytes", 0}, {"Tile", {8, 1, 64}}}) + .dump()); + out = m.transpose(t, {0, 2, 1, 3}); + } + + auto result = 
ark::op_test("transpose_sync_test", m, {in}, {out}, + baseline_transpose_sync_test); + UNITTEST_LOG(result); + UNITTEST_EQ(result.max_diff[0], 0.0f); + return ark::unittest::SUCCESS; +} + ark::unittest::State test_transpose_invalid() { { ark::Model m; @@ -157,6 +270,12 @@ int main() { UNITTEST(test_transpose_0231_fp32); UNITTEST(test_transpose_0231_fp16); UNITTEST(test_transpose_0231_bf16); + UNITTEST(test_transpose_0213_fp32); + UNITTEST(test_transpose_0213_fp16); + UNITTEST(test_transpose_0213_bf16); +#if (SYNC_TEST) + UNITTEST(test_transpose_sync_test); +#endif UNITTEST(test_transpose_invalid); return ark::unittest::SUCCESS; } diff --git a/ark/utils/utils_net_test.cpp b/ark/utils/utils_net_test.cpp index 4c3b6f162..95dda890c 100644 --- a/ark/utils/utils_net_test.cpp +++ b/ark/utils/utils_net_test.cpp @@ -12,6 +12,7 @@ ark::unittest::State test_ipc_hosts() { auto tmp_hostfile = tmp_dir + "/.test_ipc_hostfile"; ark::write_file(tmp_hostfile, "127.0.0.1\n127.0.0.1\n127.0.0.1\n"); ::setenv("ARK_HOSTFILE", tmp_hostfile.c_str(), 1); + ::setenv("ARK_KEEP_TMP", "1", 1); ark::init(); UNITTEST_EQ(ark::get_host(0, true), "127.0.0.1"); @@ -31,6 +32,7 @@ ark::unittest::State test_ipc_hosts_unknown_host() { auto tmp_hostfile = tmp_dir + "/.test_ipc_hostfile"; ark::write_file(tmp_hostfile, "unknown\nunknown\nunknown\n"); ::setenv("ARK_HOSTFILE", tmp_hostfile.c_str(), 1); + ::setenv("ARK_KEEP_TMP", "1", 1); ark::init(); UNITTEST_THROW(ark::get_host(0, true), ark::InvalidUsageError); diff --git a/docs/env.md b/docs/env.md index 2d5839c3b..95330a032 100644 --- a/docs/env.md +++ b/docs/env.md @@ -27,3 +27,7 @@ - `ARK_DISABLE_IB` (Default: `0`; Options: `0`, `1`) If set to `1`, disable ibverbs networking (i.e., disable multi-node execution). + +- `ARK_IGNORE_BINARY_CACHE` (Default: `1`; Options: `0`, `1`) + + If set to `1`, ignore the binary cache and force ARK to recompile binaries on each run. diff --git a/docs/plan_file.md b/docs/plan_file.md index 90a4537a2..2f93b51a0 100644 --- a/docs/plan_file.md +++ b/docs/plan_file.md @@ -6,6 +6,7 @@ See an example plan file: [Example 1](../examples/tutorial/default_plan.json) - Rank (Int) - WorldSize (Int) + - Architecture (String) - NumProcessors (Int) - NumWarpsPerProcessor (Int) - TaskInfos (Array of TaskInfo) @@ -42,6 +43,23 @@ See an example plan file: [Example 1](../examples/tutorial/default_plan.json) `ProcessorRange`, `WarpRange`, `SramRange`, and `TaskRange` are in the "range" format, i.e., `[Begin, End, Step]` that indicates an arithmetic integer sequence with a common difference of `Step`, starting from `Begin` and ends before `End` (does not include `End`). They alternatively can be in the format `[Begin, End]` that assumes `Step` is 1. +## Architecture + +A name that refers to the hardware architecture where the plan is supposed to run over. The following names are currently supported. + +- `ANY`: compatible with all architectures. + +- NVIDIA Family + - `CUDA`: compatible with all supported NVIDIA architectures. + - `CUDA_70`: compatible with NVIDIA Volta architecture. + - `CUDA_80`: compatible with NVIDIA Ampere architecture. + - `CUDA_90`: compatible with NVIDIA Hopper architecture. + +- AMD Family + - `ROCM`: compatible with all supported AMD architectures. + - `ROCM_90A`: compatible with AMD CDNA 2 (GFX90A) architecture. + - `ROCM_942`: compatible with AMD CDNA 3 (GFX942) architecture. + ## TaskInfo A `TaskInfo` object describes a sequential set of operators. The followings describe each field of `TaskInfo`. 
@@ -57,47 +75,36 @@ Structure of an `Op` object in a plan file is the same as [the one in the model
 
 ### Config Details
 
-The followings explain a few fields that many configs commonly consist of.
+The following explains a few fields that are common to many configs.
 
-- `NumWarps`: number of concurrent warps needed to calculate a single output tile.
-- `SramBytes`: bytes of SRAM needed to calculate a single output tile.
-- `NumTasks`: total number of output tiles need to compute.
+- `Tile` (Optional): up-to-4-dimensional shape of a single tile. A tile refers to the elements that each task calculates for the first result tensor. The shape of the first result tensor should be divisible by the tile shape. `Tile` may not be needed depending on the operator type.
+- `NumWarps`: number of concurrent warps needed to calculate a single tile.
+- `SramBytes`: bytes of SRAM needed to calculate a single tile.
+- `NumTasks` (Optional): total number of tiles to compute. If `NumTasks` is not provided, it will be calculated as the number of elements in the first result tensor divided by the number of elements in a single `Tile`. If neither `NumTasks` nor `Tile` is provided, no computation will be conducted (regarded as `NumTasks == 0`).
 
 The following describes the `Config` structure of different types of operators.
 
-- `Matmul`
-    - `NumWarps`
-    - `SramBytes`
-    - `NumTasks`
-    - `TileShapeMNK`: tile shape of matrix multiplication in the [M,N,K] format.
-    - `TilePadMNK`: this field is not well defined and will be updated in the future. Currently, it should be the same as `TileShapeMNK`.
-
 - `ReduceSum`, `ReduceMax`, `ReduceMean`
+    - `Tile` (Optional)
     - `NumWarps`
     - `SramBytes`
-    - `NumTasks`
+    - `NumTasks` (Optional)
     - `ImplType`: type of reduction implementation, either `WarpWise` or `ElementWise`.
 - `Send`, `SendDone`, `Recv`
-    - `NumWarps`: should be always 1.
+    - `NumWarps`
     - `SramBytes`: should always be 0.
     - `NumTasks`: should always be 1.
-- `Embedding`
-    - `NumWarps`
-    - `SramBytes`
-    - `NumTasks`
-
 - `Noop`
     - `NumWarps`: should always be 1.
     - `SramBytes`: should always be 0.
-    - `NumTasks`: should be always 0.
 - `Default`: all other operators that are not listed above follow this structure.
+    - `Tile` (Optional)
     - `NumWarps`
     - `SramBytes`
-    - `NumTasks`
-    - `Tile`: 2-dimensional shape of a single output tile.
+    - `NumTasks` (Optional)
 
 ## ProcessorGroup
@@ -116,6 +123,6 @@ A `ResourceGroup` object describes computing tasks that use the entire or a subs
 
 ## TaskGroup
 
-A `TaskGroup` object describes computing tasks. Each task can be typically considered as computing a single output tile of an operator. The `TaskId` field declares the type of task, of which details are found from `TaskInfos`. The `TaskRange` field declares tasks to run, which should be within the range `[0, NumTasks)` where `NumTasks` is found from `Config` of operators in the `TaskInfo`. If there are multiple operators in a `TaskInfo`, all operators should have the same `NumTasks`.
+A `TaskGroup` object describes computing tasks. Each task can typically be considered as computing a single result tile of an operator. The `TaskId` field declares the type of task, whose details are found in `TaskInfos`. The `TaskRange` field declares the tasks to run, which should be within the range `[0, NumTasks)` where `NumTasks` is found from `Config` of operators in the `TaskInfo`. If there are multiple operators in a `TaskInfo`, all operators should have the same `NumTasks`.
 
 Tasks in the `TaskRange` are distributed across processors in the resource group. If `Granularity` is 1, the distribution is round-robin. Otherwise, the distribution assigns `Granularity` consecutive tasks to each processor (as long as there are enough tasks), and then assigns the following task to the next processor. `Granularity` should always be a positive integer.
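To make the `NumTasks` default and the `Granularity` distribution concrete, here is a small Python sketch. The helper names are ours, not ARK APIs, and wrapping back to the first processor once every processor has received a batch of tasks is an assumption of the sketch rather than something the text above specifies:

```python
from math import prod

def default_num_tasks(result_shape, tile):
    # When NumTasks is omitted, it defaults to the number of elements in
    # the first result tensor divided by the number of elements in a Tile.
    return prod(result_shape) // prod(tile)

def distribute_tasks(task_range, processors, granularity=1):
    # Assign `granularity` consecutive tasks to each processor, then move
    # on to the next processor; granularity == 1 reduces to round-robin.
    # Wrap-around across processors is assumed here for illustration.
    assignment = {p: [] for p in processors}
    for i, task in enumerate(range(*task_range)):
        p = processors[(i // granularity) % len(processors)]
        assignment[p].append(task)
    return assignment

# A 1x512x11008 result tensor with [128, 256] tiles yields 172 tasks,
# matching the "NumTasks": 172 matmul configs in the tutorial plans below.
print(default_num_tasks([1, 512, 11008], [128, 256]))  # -> 172
print(distribute_tasks((0, 8), [0, 1, 2], granularity=2))
# -> {0: [0, 1, 6, 7], 1: [2, 3], 2: [4, 5]}
```

With `Granularity` 1 this reduces to plain round-robin, which is how the tutorial plans below (all using `"Granularity": 1`) distribute their tasks.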
diff --git a/examples/ffn/Makefile b/examples/ffn/Makefile
deleted file mode 100644
index 996f8a187..000000000
--- a/examples/ffn/Makefile
+++ /dev/null
@@ -1,23 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT license.
-
-ARK_ROOT ?= /usr/local/ark
-CUDIR ?= /usr/local/cuda
-
-CXX := g++
-CXXFLAGS := -std=c++17 -Wall -Wextra
-INCLUDE := -I$(ARK_ROOT)/include -I $(CUDIR)/include -I$(ARK_ROOT)/include/kernels
-LDFLAGS := -L$(CUDIR)/lib64/stubs -Wl,-rpath,$(CUDIR)/lib64
-LDLIBS := -lcuda -lnvidia-ml -lnvrtc -lpthread -lrt -libverbs -lnuma
-
-all: build/ffn
-
-build/ffn: build/ffn.o
-	$(CXX) -o $@ $< -L$(ARK_ROOT)/lib -lark $(LDFLAGS) $(LDLIBS)
-
-build/ffn.o: ffn.cc
-	mkdir -p $(@D)
-	$(CXX) -o $@ $(CXXFLAGS) $(INCLUDE) -c $<
-
-clean:
-	rm -r build/
diff --git a/examples/ffn/ffn.cc b/examples/ffn/ffn.cc
deleted file mode 100644
index 6eee77a7d..000000000
--- a/examples/ffn/ffn.cc
+++ /dev/null
@@ -1,450 +0,0 @@
-// Copyright (c) Microsoft Corporation.
-// Licensed under the MIT license.
-
-#include
-#include
-#include
-
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include "ark.h"
-#include "ark_utils.h"
-
-using namespace std;
-using namespace ark;
-
-void print_tensor(Tensor *tensor, Executor *exe) {
-    if (tensor == nullptr) {
-        return;
-    }
-    cout << "tensor: " << tensor->name << endl;
-    size_t tensor_size = tensor->shape_bytes();
-    half_t *data = (half_t *)malloc(tensor_size);
-    exe->tensor_memcpy(data, tensor, tensor_size);
-    for (int i = 0; i < tensor->size(); ++i) {
-        cout << data[i] << " ";
-    }
-    cout << endl;
-    delete[] data;
-}
-
-class FullyConnectedLayer {
-   public:
-    FullyConnectedLayer(int dim_input, int dim_output, TensorType dtype,
-                        Model &model)
-        : model{model} {
-        Tensor *weight = model.tensor({dim_input, dim_output}, dtype);
-        Tensor *bias = model.tensor({1, dim_output}, dtype);
-        params = {weight, bias};
-    }
-
-    Tensor *forward(Tensor *input) {
-        this->input = input;
-        Tensor *weight = params[0];
-        Tensor *output1 = model.matmul(input, weight);
-        Tensor *bias = params[1];
-        Tensor *output2 = model.add(output1, bias);
-        return output2;
-    }
-
-    Tensor *backward(Tensor *grad) {
-        Tensor *weight = params[0];
-        Tensor *bias = params[1];
-        Tensor *grad_output2 = grad;
-        Tensor *grad_bias = model.tensor(bias->shape, bias->type);
-        grad_bias = model.scale(grad_output2, 1, grad_bias);
-        Tensor *grad_output1 = grad_output2;
-        Tensor *grad_input = model.tensor(input->shape, input->type);
-        Tensor *grad_weight = model.tensor(weight->shape, weight->type);
-        grad_input =
-            model.matmul(grad_output1, weight, nullptr, 1, false, true);
-        grad_weight =
-            model.matmul(input, grad_output1, nullptr, 1, true, false);
-        grads[weight] = grad_weight;
-        grads[bias] = grad_bias;
-        return grad_input;
-    }
-
-    void apply_grads() {
-        for (auto &param : params) {
-            Tensor *grad = grads[param];
-            // the learning rate
-            Tensor *grad_scale = model.scale(grad, -0.0001);
-            Tensor *param_identity = model.identity(param);
-            model.add(param, grad_scale, param_identity);
-        }
-    }
-
-    void print_tensors(Executor *exe) {
-        print_tensor(input, exe);
-        // print the parameters.
- for (size_t i = 0; i < params.size(); ++i) { - print_tensor(params[i], exe); - } - } - - Tensor *input; - vector params; - map grads; - Model &model; -}; - -class FFN_Model { - public: - // - FFN_Model(int dim_model, TensorType dtype, Model &model, int layer_num, - int num_gpus, int gpu_id) - : model{model}, num_gpus{num_gpus}, gpu_id{gpu_id} { - for (int i = 0; i < layer_num; ++i) { - FullyConnectedLayer layer{dim_model, dim_model, dtype, model}; - layers.push_back(layer); - } - } - - Model &get_model() { return model; } - - // - Tensor *forward(Tensor *input = nullptr) { - for (size_t i = 0; i < layers.size(); ++i) { - printf("forward layer: %d\n", i); - input = layers[i].forward(input); - } - return input; - } - - // - void backward(Tensor *grad) { - for (int i = layers.size() - 1; i >= 0; --i) { - printf("backward layer: %d\n", i); - grad = layers[i].backward(grad); - } - DimType grads_size = 0; - vector grads; - - for (auto &layer : layers) { - for (auto ¶m : layer.params) { - grads.push_back(layer.grads[param]); - grads_size += layer.grads[param]->size(); - } - } - - // All-reduce gradients - if (num_gpus > 1) { - Tensor *gradients = model.tensor({1, grads_size, 1, 1}, FP16); - Tensor *idn = model.identity(gradients, {grads}); - - model.all_reduce(idn, gpu_id, num_gpus); - } - } - - void print_tensors(Executor *exe) { - for (size_t i = 0; i < layers.size(); ++i) { - printf("layer: %d\n", i); - layers[i].print_tensors(exe); - } - } - - Model &model; - // model parameters. - vector layers; - Tensor *model_input; - int num_gpus; - int gpu_id; -}; - -class LossFn { - public: - LossFn(Model &model) : model{model} {} - - Tensor *forward(Tensor *output, Tensor *ground_truth) { - this->output = output; - printf("loss forward"); - neg_ground_truth = - model.tensor(ground_truth->shape, ground_truth->type); - neg_ground_truth = model.scale(ground_truth, -1, neg_ground_truth); - diff = model.tensor(output->shape, output->type); - model.add(output, neg_ground_truth, diff); - diff1 = model.tensor(diff->shape, diff->type); - model.scale(diff, 1, diff1); - loss_tensor = model.tensor(diff->shape, diff->type); - model.mul(diff, diff1, loss_tensor); - return loss_tensor; - } - - Tensor *backward(Tensor *loss_tensor) { - printf("loss backward"); - grad_diff = model.tensor(diff->shape, diff->type); - model.mul(loss_tensor, diff, grad_diff); - return grad_diff; - } - - void print_tensors(Executor *exe) { - printf("loss_fn.output: "); - print_tensor(this->output, exe); - printf("loss_fn.neg_ground_truth: "); - print_tensor(this->neg_ground_truth, exe); - printf("loss_fn.diff: "); - print_tensor(this->diff, exe); - printf("loss_fn.diff1: "); - print_tensor(this->diff1, exe); - printf("loss_fn.neg_ground_truth: "); - print_tensor(this->neg_ground_truth, exe); - printf("loss_fn.loss_tensor: "); - print_tensor(this->loss_tensor, exe); - printf("loss_fn.grad_diff: "); - print_tensor(this->grad_diff, exe); - } - Tensor *output; - Tensor *loss_tensor; - Tensor *neg_ground_truth; - Tensor *diff; - Tensor *diff1; - Tensor *grad_diff; - Model &model; -}; - -class Trainer { - public: - Trainer(Model &model, int dim_input, int batch_size, int gpu_id, - int num_gpus) - : model{model}, - ffn_model{dim_input, FP16, model, 2, num_gpus, gpu_id}, - loss_fn{model}, - batch_size{batch_size}, - num_gpus{num_gpus}, - gpu_id{gpu_id} { - input = model.tensor({batch_size, dim_input}, FP16); - ground_truth = model.tensor({batch_size, dim_input}, FP16); - output = ffn_model.forward(input); - loss_tensor = loss_fn.forward(output, 
ground_truth); - grad_loss = model.tensor(loss_tensor->shape, loss_tensor->type); - grad_output = loss_fn.backward(grad_loss); - ffn_model.backward(grad_output); - apply_grad(); - - exe = new Executor(gpu_id, gpu_id, (int)num_gpus, model, - "sampleFFN_Model"); - exe->compile(); - } - - void init_data() { - // init the input and ground_truth. - auto data_input = - ark::utils::range_halfs(this->input->shape_bytes(), 1, 0); - exe->tensor_memcpy(this->input, data_input.get(), - this->input->shape_bytes()); - auto data_ground_truth = - ark::utils::range_halfs(this->ground_truth->shape_bytes(), 2, 0); - exe->tensor_memcpy(this->ground_truth, data_ground_truth.get(), - this->ground_truth->shape_bytes()); - // init the grad_loss with 1. - auto data_grad_loss = - ark::utils::range_halfs(this->grad_loss->shape_bytes(), 1, 0); - exe->tensor_memcpy(this->grad_loss, data_grad_loss.get(), - this->grad_loss->shape_bytes()); - // init all the parameters of the model with random values. - for (auto &layer : ffn_model.layers) { - for (auto ¶m : layer.params) { - auto data = ark::utils::rand_halfs(param->shape_bytes(), 1); - exe->tensor_memcpy(param, data.get(), param->shape_bytes()); - } - } - } - - void train(int iter, int print_interval = 1) { - exe->launch(); - if (print_interval == 0) { - // don't print the loss for debug. - exe->run(iter); - } else { - // we only print the loss every print_interval iterations for debug. - for (int i = 0; i < iter; ++i) { - exe->run(1); - exe->wait(); - if (i % print_interval == 0) { - float loss = get_loss(); - cout << "iter: " << i << ", loss: " << loss << endl; - } - } - } - float elapsed_msec = exe->stop(); - cout << "Elapsed: " << elapsed_msec / iter << " ms/iter\n"; - } - - float get_loss() { - size_t tensor_size = this->loss_tensor->shape_bytes(); - half_t *loss = (half_t *)malloc(tensor_size); - exe->tensor_memcpy(loss, this->loss_tensor, tensor_size); - float loss_sum = 0; - for (int i = 0; i < this->loss_tensor->size(); ++i) { - loss_sum += (float)loss[i]; - } - delete[] loss; - return loss_sum; - } - - void apply_grad() { - for (auto &layer : ffn_model.layers) { - layer.apply_grads(); - } - } - - void print_tensors(Executor *exe) { - printf("loss_tensor: "); - print_tensor(this->loss_tensor, exe); - printf("input: "); - print_tensor(this->input, exe); - printf("output: "); - print_tensor(this->output, exe); - printf("ground_truth: "); - print_tensor(this->ground_truth, exe); - printf("ffn_model: "); - this->ffn_model.print_tensors(exe); - printf("loss_fn: "); - this->loss_fn.print_tensors(exe); - } - - Model &model; - Tensor *loss_tensor, *input, *ground_truth, *output; - Tensor *grad_output; - Tensor *grad_loss; - FFN_Model ffn_model; - LossFn loss_fn; - Executor *exe; - int batch_size; - int num_gpus; - int gpu_id; -}; - -struct Args { - int batch_size; - int dims; - int num_gpus; - int iterations; - int print_interval; - int seed; - bool verbose; -}; - -Args parse_args(int argc, const char **argv) { - string prog = argv[0]; - vector args(argv + 1, argv + argc); - - auto print_help = [&prog]() { - cerr << "Usage: " << prog << " [options]\n" - << "Options:\n" - << " -h, --help\t\t\tPrint this help message\n" - << " -b, --batch-size \t\tBatch size\n" - << " -d, --dims \t\tDimensions\n" - << " -g, --num-gpus \t\tNumber of GPUs\n" - << " -i, --iter \t\tNumber of iterations\n" - << " -p, --print-interval \tPrint interval\n" - << " -s, --seed \t\tRandom seed\n" - << " -v, --verbose\t\t\tVerbose output\n"; - exit(0); - }; - - Args ret; - - // Default arguments - 
ret.batch_size = 1;
-    ret.dims = 64;
-    ret.num_gpus = 1;
-    ret.iterations = 10;
-    ret.print_interval = 1;
-    ret.seed = -1;
-    ret.verbose = false;
-
-    for (auto it = args.begin(); it != args.end(); ++it) {
-        if (*it == "-h" || *it == "--help") {
-            print_help();
-        } else if (*it == "-b" || *it == "--batch-size") {
-            if (++it == args.end()) {
-                cerr << "Error: missing argument for " << *(it - 1) << endl;
-                exit(1);
-            }
-            ret.batch_size = stoi(*it);
-        } else if (*it == "-d" || *it == "--dims") {
-            if (++it == args.end()) {
-                cerr << "Error: missing argument for " << *(it - 1) << endl;
-                exit(1);
-            }
-            ret.dims = stoi(*it);
-        } else if (*it == "-g" || *it == "--num-gpus") {
-            if (++it == args.end()) {
-                cerr << "Error: missing argument for " << *(it - 1) << endl;
-                exit(1);
-            }
-            ret.num_gpus = stoi(*it);
-        } else if (*it == "-i" || *it == "--iter") {
-            if (++it == args.end()) {
-                cerr << "Error: missing argument for " << *(it - 1) << endl;
-                exit(1);
-            }
-            ret.iterations = stoi(*it);
-        } else if (*it == "-p" || *it == "--print-interval") {
-            if (++it == args.end()) {
-                cerr << "Error: missing argument for " << *(it - 1) << endl;
-                exit(1);
-            }
-            ret.print_interval = stoi(*it);
-        } else if (*it == "-s" || *it == "--seed") {
-            if (++it == args.end()) {
-                cerr << "Error: missing argument for " << *(it - 1) << endl;
-                exit(1);
-            }
-            ret.seed = stoi(*it);
-        } else if (*it == "-v" || *it == "--verbose") {
-            ret.verbose = true;
-        } else {
-            cerr << "Error: unknown option " << *it << endl;
-            print_help();
-        }
-    }
-
-    return ret;
-}
-
-int main(int argc, const char **argv) {
-    Args args = parse_args(argc, argv);
-
-    cout << "--" << endl
-         << "batch_size=" << args.batch_size << endl
-         << "dims=" << args.dims << endl
-         << "num_gpus=" << args.num_gpus << endl
-         << "iterations=" << args.iterations << endl
-         << "print_interval=" << args.print_interval << endl
-         << "seed=" << args.seed << endl
-         << "verbose=" << args.verbose << endl
-         << "--" << endl;
-
-    vector pids;
-    for (int gpu_id = 0; gpu_id < args.num_gpus; ++gpu_id) {
-        pids.emplace_back(ark::utils::proc_spawn([&] {
-            ark::srand(args.seed);
-
-            Model model{gpu_id};
-            Trainer trainer{model, args.dims, args.batch_size, gpu_id,
-                            args.num_gpus};
-            trainer.init_data();
-            // train the model.
-            trainer.train(args.iterations, args.print_interval);
-            // trainer.print_tensors(trainer.exe);
-            return 0;
-        }));
-    }
-    int state = 0;
-    for (auto pid : pids) {
-        int ret = ark::utils::proc_wait(pid);
-        if (ret != 0) {
-            cerr << "E: Process " << pid << " returned " << ret << endl;
-            state = 1;
-        }
-    }
-    return state;
-}
diff --git a/examples/llama/README.md b/examples/llama/README.md
index 090dd1de3..1fe040ae0 100644
--- a/examples/llama/README.md
+++ b/examples/llama/README.md
@@ -29,10 +29,10 @@ Llama2 examples over ARK.
 4. Download Llama2 model weights and tokenizer weights.
     * The model and tokenizer should be compatible with the [official PyTorch implementation](https://github.com/facebookresearch/llama/blob/main/llama).
-5. Run the model accuracy test. `--pth_path` is the path to the model weights file (`consolidated.00.pth`).
+5. Run the model accuracy test. `--ckpt_dir` is the directory containing the model weight files (e.g., `consolidated.00.pth`).
 
     ```bash
-    python3 model_test.py --pth_path=/path/to/model/weights.pth
+    python3 model_test.py --ckpt_dir=/directory/of/model/weights
     ```
 
6. Test text generation. 
`--pth_path` is the path to the model weights file (`consolidated.00.pth`), `--tok_path` is the path to the tokenizer weights file (`tokenizer.model`), and `--params_path` is the path to the model parameters (`params.json`). diff --git a/examples/llama/model.py b/examples/llama/model.py index 925615bf3..cd1bede29 100644 --- a/examples/llama/model.py +++ b/examples/llama/model.py @@ -9,7 +9,13 @@ import math from dataclasses import dataclass from typing import Optional -import os +from ark import PlannerContext as Context + +NUM_SM = 304 +NUM_WARPS_PER_SM = 8 +NUM_WARPS = NUM_SM * NUM_WARPS_PER_SM +WARP_SIZE = 64 +SRAM_PER_SM = 65536 @dataclass @@ -88,15 +94,28 @@ def __init__( self.eps = eps self.dtype = dtype self.weight = ark.parameter([1, 1, dim], ark.fp32) + self.dim = dim def forward(self, x): - x = ark.cast(x, ark.fp32) - x2 = ark.mul(x, x) - mean = ark.reduce_mean(x2, axis=-1) - rrms = ark.rsqrt(mean) - x = ark.mul(x, rrms) - x = ark.mul(x, self.weight, x) - return ark.cast(x, self.dtype) + with Context( + sync=False, + config={ + "NumWarps": 1, + "SramBytes": 0, + "Granularity": 7, + }, + ): + with Context(config={"Tile": [self.dim]}): + x = ark.cast(x, ark.fp32) + x2 = ark.mul(x, x) + with Context(config={"Tile": [1], "ImplType": "WarpWise"}): + mean = ark.reduce_mean(x2, axis=-1) + mean = ark.add(mean, self.eps) + rrms = ark.rsqrt(mean) + with Context(config={"Tile": [self.dim]}): + x = ark.mul(x, rrms) + x = ark.mul(x, self.weight, x) + return ark.cast(x, self.dtype) class ColumnParallelLinear(ark.Module): @@ -210,7 +229,7 @@ def forward(self, x): local_result = ark.matmul( input_parallel, self.weight, transpose_other=True ) - reduced_result = ark.local_all_reduce( + reduced_result = ark.all_reduce( local_result, self.local_rank, self.world_size ) return reduced_result @@ -236,9 +255,75 @@ def __init__( self.world_size = world_size self.local_rank = local_rank - def forward(self, x): + def forward(self, x: ark.Tensor): if self.world_size == 1: - return ark.embedding(x, self.weight) + config = {"SramBytes": 0} + num_vecs = x.nelems() + if num_vecs >= NUM_WARPS: + config.update({"NumWarps": 1, "Tile": [self.dim]}) + num_parts = 1 + else: + min_elem_per_warp = WARP_SIZE * 2 + max_warps_per_vec = ( + self.dim + min_elem_per_warp - 1 + ) // min_elem_per_warp + warps_per_vec = min(max_warps_per_vec, NUM_WARPS // num_vecs) + if warps_per_vec <= NUM_WARPS_PER_SM: + config.update( + {"NumWarps": warps_per_vec, "Tile": [self.dim]} + ) + num_parts = 1 + else: + num_parts = warps_per_vec // NUM_WARPS_PER_SM + max_num_parts = 4 + assert NUM_SM % max_num_parts == 0 + assert ( + 2 ** (max_num_parts.bit_length() - 1) == max_num_parts + ) + if num_parts > max_num_parts: + num_parts = max_num_parts + # make it max power of 2 smaller than num_parts + num_parts = 2 ** (num_parts.bit_length() - 1) + config.update( + { + "NumWarps": NUM_WARPS_PER_SM, + "Tile": [self.dim // num_parts], + } + ) + with Context(processor_range=[0, NUM_SM], config=config): + if num_parts == 1: + return ark.embedding(x, self.weight) + emb_output = ark.tensor( + [x.shape()[0], x.shape()[1], self.dim], self.dtype + ) + emb_parts = [] + dim_per_part = self.dim // num_parts + for i in range(num_parts): + with Context( + processor_range=[ + i * NUM_SM // num_parts, + (i + 1) * NUM_SM // num_parts, + ] + ): + emb_parts.append( + ark.embedding( + x, + self.weight[ + :, + (i * dim_per_part) : ( + (i + 1) * dim_per_part + ), + ], + emb_output[ + :, + :, + (i * dim_per_part) : ( + (i + 1) * dim_per_part + ), + ], + ) + ) + return 
ark.identity(emb_output, deps=emb_parts) output_tensor = ark.tensor( [x.shape()[0], x.shape()[1], self.out_dim], self.dtype @@ -262,22 +347,6 @@ def forward(self, x): ) -class Linear(ark.Module): - """ - Linear layer module with weights and no bias. - """ - - def __init__( - self, in_dim: int, out_dim: int, dtype: ark.DataType = ark.fp16 - ): - super().__init__() - self.dtype = dtype - self.weight = ark.parameter([out_dim, in_dim], dtype) - - def forward(self, x): - return ark.matmul(x, self.weight, transpose_other=True) - - class Silu(ark.Module): """ Silu activation function, silu(x) = x * sigmoid(x) @@ -312,6 +381,7 @@ def __init__( hidden_dim = multiple_of * ( (hidden_dim + multiple_of - 1) // multiple_of ) + self.hidden_dim = hidden_dim self.w1 = ColumnParallelLinear( dim, hidden_dim, dtype, False, local_rank, world_size @@ -323,14 +393,99 @@ def __init__( dim, hidden_dim, dtype, False, local_rank, world_size ) - def forward(self, x): - # self.w2(F.silu(self.w1(x)) * self.w3(x)) - x1 = self.w1(x) - x1 = Silu()(x1) - x2 = self.w3(x) - x3 = ark.mul(x1, x2) - x4 = self.w2(x3) - return x4 + def forward(self, x, ffn_norm): + h = ffn_norm(x) + + seqlen = h.shape()[1] + schedule = None + if seqlen == 2048: + schedule = [ + [1792, [256, 128], 24672], + [256, [128, 128], 16480], + ] + elif seqlen == 128: + schedule = [ + [128, [128, 64], 16480], + ] + else: + raise ValueError(f"Unsupported seqlen {seqlen}") + + out_shape = h.shape() + out_shape[-1] = self.w1.out_dim + out = ark.tensor(out_shape, h.dtype()) + pos = 0 + + dim, tile, sram = schedule[0] + + with Context(sync=False, config={"Tile": tile, "NumWarps": 4}): + h_shard = h[:, pos : pos + dim, :] + out_shard = out[:, pos : pos + dim, :] + with Context(config={"SramBytes": sram}): + x1 = ark.matmul(h_shard, self.w1.weight, transpose_other=True) + with Context(config={"SramBytes": 0}): + x1 = Silu()(x1) + + # We don't need a barrier here but somehow the performance is better with it + with Context(sync=False, config={"Tile": tile, "NumWarps": 4}): + with Context(config={"SramBytes": sram}): + x2 = ark.matmul(h_shard, self.w3.weight, transpose_other=True) + with Context(config={"SramBytes": 0}): + x3 = ark.mul(x1, x2, out_shard) + out = ark.identity(out, deps=[x3]) + pos += dim + + if len(schedule) > 1: + dim, tile, sram = schedule[1] + with Context( + processor_range=[0, NUM_SM // 2], + sync=False, + config={"Tile": tile, "NumWarps": 4}, + ): + h_shard = h[:, pos : pos + dim, :] + out_shard = out[:, pos : pos + dim, :] + with Context(config={"SramBytes": sram}): + x1 = ark.matmul( + h_shard, self.w1.weight, transpose_other=True + ) + with Context(config={"SramBytes": 0}): + x1 = Silu()(x1) + + with Context( + processor_range=[NUM_SM // 2, NUM_SM], + sync=False, + config={"Tile": tile, "NumWarps": 4, "SramBytes": sram}, + ): + x2 = ark.matmul(h_shard, self.w3.weight, transpose_other=True) + with Context( + processor_range=[0, NUM_SM], + sync=False, + config={"Tile": tile, "NumWarps": 4}, + ): + with Context(config={"SramBytes": 0}): + x3 = ark.mul(x1, x2, out_shard) + out = ark.identity(out, deps=[x3]) + pos += dim + + if seqlen == 2048: + tile = [256, 128] + sram = 24672 + elif seqlen == 128: + tile = [128, 64] + sram = 16480 + else: + raise ValueError(f"Unsupported seqlen {seqlen}") + + with Context( + warp_range=[0, 4], + config={ + "NumWarps": 4, + "Tile": tile, + "SramBytes": sram, + }, + sync=False, + ): + ff = self.w2(out) + return ark.add(x, ff) def apply_rotary_emb(xq, xk, freqs_cis): @@ -356,6 +511,7 @@ def __init__( ) 
model_parallel_size = world_size self.dtype = dtype + self.args = args self.n_local_heads = args.n_heads // model_parallel_size self.n_local_kv_heads = self.n_kv_heads // model_parallel_size self.n_rep = self.n_local_heads // self.n_local_kv_heads @@ -399,49 +555,215 @@ def forward( start_pos: int, freqs_cis: ark.Tensor, mask: Optional[ark.Tensor], + attention_norm, ): bsz, seqlen, _ = x.shape() - xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) - # xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) - # xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) - # xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) - xq = ark.reshape(xq, [bsz, seqlen, self.n_local_heads, self.head_dim]) - xk = ark.reshape( - xk, [bsz, seqlen, self.n_local_kv_heads, self.head_dim] - ) - xv = ark.reshape( - xv, [bsz, seqlen, self.n_local_kv_heads, self.head_dim] - ) - if freqs_cis is not None: - xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) - # TODO: enable kv cache later - keys = xk - values = xv - # (bs, n_local_heads, seqlen, head_dim) - xq = ark.transpose(xq, [0, 2, 1, 3]) - values = ark.transpose(values, [0, 2, 1, 3]) - - # (bs, n_local_heads, head_dim, seqlen) - keys = ark.transpose(keys, [0, 2, 3, 1]) - scores = ark.matmul(xq, keys) - scores = ark.mul(scores, 1.0 / math.sqrt(self.head_dim)) - - if mask is not None: - scores = ark.add(scores, mask) - # if self.dtype == ark.fp16: - # scores = ark.cast(scores, ark.fp32) - scores = ark.softmax(scores, output=scores) - # if self.dtype == ark.fp16: - # scores = ark.cast(scores, ark.fp16) - - output = ark.matmul( - scores, values - ) # (bs, n_local_heads, seqlen, head_dim) - output = ark.transpose(output, [0, 2, 1, 3]) + + x_norm = attention_norm(x) + + xq_scratch = ark.tensor( + [ + bsz, + seqlen * self.n_local_heads, + self.n_local_heads, + self.head_dim, + ], + self.dtype, + ) + xk_scratch = ark.tensor( + [ + bsz, + seqlen * self.n_local_kv_heads, + self.n_local_kv_heads, + self.head_dim, + ], + self.dtype, + ) + + def calc_scores(xq_scratch, xk_scratch, mask): + xq = xq_scratch[:, :, 0, :] + xk = xk_scratch[:, :, 0, :] + xq = ark.reshape( + xq, [bsz, self.n_local_heads, seqlen, self.head_dim] + ) + xk = ark.reshape( + xk, [bsz, self.n_local_kv_heads, seqlen, self.head_dim] + ) + if seqlen == 2048: + tile = [256, 128] + sram = 24672 + elif seqlen == 128: + tile = [128, 128] + sram = 16480 + else: + raise ValueError(f"Unsupported seqlen {seqlen}") + + with Context( + sync=False, + config={ + "Tile": tile, + "SramBytes": sram, + "NumWarps": 4, + "BatchStrideCA": self.head_dim, + "BatchStrideNA": ( + self.n_local_heads * seqlen * self.head_dim + ), + "BatchStrideCB": self.head_dim, + "BatchStrideNB": ( + self.n_local_kv_heads * seqlen * self.head_dim + ), + }, + ): + scores = ark.matmul(xq, xk, transpose_other=True) + scores = ark.mul(scores, 1.0 / math.sqrt(self.head_dim), scores) + if mask is not None: + scores = ark.add(scores, mask, scores) + return scores + + def softmax(scores): + with Context( + sram_range=[0, 0], + sync=False, + config={ + "NumWarps": 1, + "SramBytes": 0, + }, + ): + with Context(config={"ImplType": "WarpWise", "Tile": [1]}): + max = ark.reduce_max(scores, axis=-1) + with Context(config={"Tile": [seqlen]}): + tmp = ark.sub(scores, max) + tmp = ark.exp(tmp) + with Context(config={"ImplType": "WarpWise", "Tile": [1]}): + sum = ark.reduce_sum(tmp, axis=-1) + with Context(config={"Tile": [seqlen]}): + output = ark.div(tmp, sum) + return output + + if seqlen == 2048: + tile = [256, 128] + sram = 24672 + 
elif seqlen == 128: + tile = [128, 64] + sram = 16480 + else: + raise ValueError(f"Unsupported seqlen {seqlen}") + + with Context( + processor_range=[0, 128], + config={"NumWarps": 4}, + sync=False, + ): + with Context(config={"SramBytes": sram, "Tile": tile}): + xq = ark.matmul(x_norm, self.wq.weight, transpose_other=True) + xq = ark.reshape( + xq, [bsz, seqlen, self.n_local_heads, self.head_dim] + ) + with Context( + config={"SramBytes": 0, "Tile": [tile[0], 1, tile[1]]} + ): + if freqs_cis is not None: + xq = ark.rope(xq, freqs_cis, xq_scratch[:, :seqlen, :, :]) + + xq_scratch = ark.identity(xq_scratch, deps=[xq]) + + with Context( + processor_range=[128, 256], + config={"NumWarps": 4}, + sync=False, + ): + with Context(config={"SramBytes": sram, "Tile": tile}): + xk = ark.matmul(x_norm, self.wk.weight, transpose_other=True) + xk = ark.reshape( + xk, [bsz, seqlen, self.n_local_kv_heads, self.head_dim] + ) + with Context( + config={"SramBytes": 0, "Tile": [tile[0], 1, tile[1]]} + ): + if freqs_cis is not None: + xk = ark.rope(xk, freqs_cis, xk_scratch[:, :seqlen, :, :]) + + xk_scratch = ark.identity(xk_scratch, deps=[xk]) + + with Context( + processor_range=[256, NUM_SM], + config={"NumWarps": 4}, + sync=False, + ): + with Context(config={"SramBytes": sram, "Tile": tile}): + xv = ark.matmul(x_norm, self.wv.weight, transpose_other=True) + xv = ark.reshape( + xv, [bsz, seqlen, self.n_local_kv_heads, self.head_dim] + ) + + with Context( + processor_range=[0, 256], + ): + scores = calc_scores(xq_scratch, xk_scratch, mask) + scores = softmax(scores) + + output_scratch = ark.tensor( + [ + bsz, + seqlen * self.n_local_heads, + self.n_local_heads, + self.head_dim, + ], + dtype=self.dtype, + ) + if seqlen == 2048: + tile = [256, 128] + sram = 24672 + elif seqlen == 128: + tile = [128, 128] + sram = 16480 + else: + raise ValueError(f"Unsupported seqlen {seqlen}") + + with Context( + sync=False, + config={ + "Tile": tile, + "SramBytes": sram, + "NumWarps": 4, + "BatchStrideCB": self.head_dim, + "BatchStrideNB": self.n_local_kv_heads * seqlen * self.head_dim, + "BatchStrideCC": self.head_dim, + "BatchStrideNC": self.n_local_kv_heads * seqlen * self.head_dim, + }, + ): + xv = ark.reshape(xv[:, :, 0, :], [bsz, 1, seqlen, self.head_dim]) + output = ark.reshape( + output_scratch[:, :, 0, :], + [bsz, self.n_local_heads, seqlen, self.head_dim], + ) + output = ark.matmul(scores, xv, output) + output = ark.identity( + output_scratch[:, :seqlen, :, :], deps=[output] + ) + output = ark.reshape( output, [bsz, seqlen, self.head_dim * self.n_local_heads] ) - return self.wo(output) + if seqlen == 2048: + tile = [256, 128] + sram = 24672 + elif seqlen == 128: + tile = [128, 128] + sram = 16480 + else: + raise ValueError(f"Unsupported seqlen {seqlen}") + + with Context( + config={ + "NumWarps": 4, + "Tile": tile, + "SramBytes": sram, + }, + sync=False, + ): + output = self.wo(output) + return ark.add(x, output) class TransformerBlock(ark.Module): @@ -478,11 +800,10 @@ def forward( freqs_cis: ark.Tensor, mask: Optional[ark.Tensor], ): - attention_norm_x = self.attention_norm(x) - h = self.attention.forward(attention_norm_x, start_pos, freqs_cis, mask) - h = ark.add(x, h) - out = ark.add(h, self.feed_forward(self.ffn_norm(h))) - return out + h = self.attention.forward( + x, start_pos, freqs_cis, mask, self.attention_norm + ) + return self.feed_forward(h, self.ffn_norm) class Transformer(ark.Module): @@ -522,10 +843,25 @@ def forward( freqs_cis: ark.Tensor, mask: Optional[ark.Tensor], ): - h = 
self.tok_embeddings(tokens) - - for layer in self.layers: - h = layer(h, start_pos, freqs_cis, mask) - h = self.norm(h) - output = self.output(h) - return output + with Context(warp_range=[0, NUM_WARPS_PER_SM], sram_range=[0, 49344]): + h = self.tok_embeddings(tokens) + + for layer in self.layers: + h = layer(h, start_pos, freqs_cis, mask) + h = self.norm(h) + + seqlen = h.shape()[1] + if seqlen == 2048: + tile = [256, 128] + sram = 24672 + elif seqlen == 128: + tile = [128, 128] + sram = 16480 + else: + raise ValueError(f"Unsupported seqlen {seqlen}") + + with Context( + config={"Tile": tile, "SramBytes": sram, "NumWarps": 4} + ): + output = self.output(h) + return output diff --git a/examples/llama/model_test.py b/examples/llama/model_test.py index 737d3ec8b..2ed2d0e63 100644 --- a/examples/llama/model_test.py +++ b/examples/llama/model_test.py @@ -59,8 +59,7 @@ def run_ark( output = module(*module_inputs) runtime = ark.Runtime() - # Prefer num_warps_per_sm = 16 for nvidia and 8 for amd - runtime.launch(num_warps_per_sm=8) + runtime.launch() # Load model parameters if state_dict: @@ -70,7 +69,8 @@ def run_ark( tensors = [i for i in module_inputs if isinstance(i, ark.Tensor)] tensor_data = [i for i in inputs if isinstance(i, np.ndarray)] for tensor, ndarray in zip(tensors, tensor_data): - tensor.from_numpy(ndarray) + if tensor.data_ptr() != 0: + tensor.from_numpy(ndarray) start_time = time.time() @@ -406,69 +406,6 @@ def test_attention( ) -def test_transformer_block( - args: ModelArgs, - batch_size: int, - seq_len: int, - dtype: np.dtype, - rank: int = 0, - world_size: int = 1, -): - # - freqs_cis = precompute_freqs_cis( - args.dim // args.n_heads, args.max_seq_len * 2 - )[0:seq_len] - - freqs_cis_ark = freqs_cis.astype(np.complex64) - freqs_cis_ark = ( - np.stack([freqs_cis_ark.real, freqs_cis_ark.imag], axis=-1) - .astype(dtype) - .reshape(1, seq_len, 1, args.dim // args.n_heads) - ) - - feature = np.random.uniform( - low=-1, high=1, size=(batch_size, seq_len, args.dim) - ).astype(dtype) - - module = model_ark.Attention( - args, ark.DataType.from_numpy(dtype), rank, world_size - ) - # module_inputs = [ - # ark.tensor(list(i.shape), ark.DataType.from_numpy(i.dtype)) - # if isinstance(i, np.ndarray) - # else i - # for i in inputs - # ] - feature_tensor = ark.tensor( - list(feature.shape), ark.DataType.from_numpy(feature.dtype) - ) - freqs_cis_ark_tensor = ark.tensor( - list(freqs_cis_ark.shape), ark.DataType.from_numpy(freqs_cis_ark.dtype) - ) - output = module(feature_tensor, 0, freqs_cis_ark_tensor, None) - - ark.Model.get_model().create_nodes() - print(ark.Model.get_model().serialize()) - - # test_module( - # module_class_ark=model_ark.TransformerBlock, - # module_args_ark=[ - # 0, - # args, - # ark.DataType.from_numpy(dtype), - # rank, - # world_size, - # ], - # inputs_ark=[feature, 0, freqs_cis_ark, None], - # module_class_pt=model_pt.TransformerBlock, - # module_args_pt=[0, args], - # inputs_pt=[feature.astype(dtype), 0, freqs_cis, None], - # module_name_prefix="layers.0", - # rank=rank, - # world_size=world_size, - # ) - - def test_transformer( args: ModelArgs, batch_size: int, @@ -536,8 +473,7 @@ def test(args, batch_size, seq_len, dtype, rank, world_size): # test_row_parallel_linear(args, batch_size, seq_len, dtype, rank, world_size) # test_column_parallel_linear(args, batch_size, seq_len, dtype, rank, world_size) # test_attention(args, batch_size, seq_len, dtype, rank, world_size) - test_transformer_block(args, batch_size, seq_len, dtype, rank, world_size) - # 
test_transformer(args, batch_size, seq_len, dtype, rank, world_size) + test_transformer(args, batch_size, seq_len, dtype, rank, world_size) def worker( @@ -561,16 +497,17 @@ def worker( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--ckpt_dir", type=str, required=True) - parser.add_argument("--ngpus", type=int, default=1) + parser.add_argument("--ngpus", type=int, default=1, help="Number of GPUs") + parser.add_argument("--ckpt_dir", type=str) ckpt_dir = parser.parse_args().ckpt_dir ngpus = parser.parse_args().ngpus # Configurations args = ModelArgs7B() + args.n_layers = 1 batch_size = 1 - seq_len = 512 + seq_len = 2048 dtype = np.float16 world_size = ngpus @@ -578,7 +515,7 @@ def worker( args.vocab_size = 32000 # Reduce max_seq_len due to OOM from the PyTorch model - args.max_seq_len = 512 + args.max_seq_len = 2048 # Verify the configurations assert batch_size <= args.max_batch_size diff --git a/examples/tutorial/default_plan.json b/examples/tutorial/default_plan.json deleted file mode 100644 index c6b4be243..000000000 --- a/examples/tutorial/default_plan.json +++ /dev/null @@ -1,270 +0,0 @@ -{ - "Rank": 0, - "WorldSize": 1, - "NumProcessors": 108, - "NumWarpsPerProcessor": 8, - "TaskInfos": [ - { - "Id": 0, - "NumWarps": 8, - "SramBytes": 147456, - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul", - "IsVirtual": false, - "ReadTensors": [ - {"Id":0,"DataType":"FP16","Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096],"Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":1,"DataType":"FP16","Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096],"Buffer":{"Id":1,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "WriteTensors": [ - {"Id":4,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "ResultTensors": [ - {"Id":5,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - }, - "Config": { - "NumWarps": 8, - "SramBytes": 147456, - "TileShapeMNK": [128,256,64], - "NumTasks": 172 - } - } - ] - }, - { - "Id": 1, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ - { - "Type": "Sigmoid", - "Name": "sigmoid", - "IsVirtual": false, - "ReadTensors": [ - {"Id":5,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "WriteTensors": [ - {"Id":6,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "ResultTensors": [ - {"Id":7,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "Args": {}, - "Config": { - "NumWarps": 1, - "SramBytes": 0, - "Tile": [1,64], - "NumTasks": 88064 - } - } - ] - }, - { - "Id": 2, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ - { - "Type": "Mul", - "Name": "mul", - "IsVirtual": false, - "ReadTensors": [ - {"Id":5,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - 
{"Id":7,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "WriteTensors": [ - {"Id":8,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "ResultTensors": [ - {"Id":9,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "Args": {}, - "Config": { - "NumWarps": 1, - "SramBytes": 0, - "Tile": [1,64], - "NumTasks": 88064 - } - } - ] - }, - { - "Id": 3, - "NumWarps": 8, - "SramBytes": 147456, - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":0,"DataType":"FP16","Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096],"Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":3,"DataType":"FP16","Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096],"Buffer":{"Id":3,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "WriteTensors": [ - {"Id":10,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "ResultTensors": [ - {"Id":11,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - }, - "Config": { - "NumWarps": 8, - "SramBytes": 147456, - "TileShapeMNK": [128,256,64], - "NumTasks": 172 - } - } - ] - }, - { - "Id": 4, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ - { - "Type": "Mul", - "Name": "mul_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":9,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":11,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "WriteTensors": [ - {"Id":12,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "ResultTensors": [ - {"Id":13,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "Args": {}, - "Config": { - "NumWarps": 1, - "SramBytes": 0, - "Tile": [1,64], - "NumTasks": 88064 - } - } - ] - }, - { - "Id": 5, - "NumWarps": 8, - "SramBytes": 147456, - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul_2", - "IsVirtual": false, - "ReadTensors": [ - {"Id":13,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":2,"DataType":"FP16","Shape":[4096,11008],"Strides":[4096,11008],"Offsets":[0,0],"PaddedShape":[4096,11008],"Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "WriteTensors": [ - {"Id":14,"DataType":"FP16","Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096],"Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]}} 
- ], - "ResultTensors": [ - {"Id":15,"DataType":"FP16","Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096],"Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - }, - "Config": { - "NumWarps": 8, - "SramBytes": 147456, - "TileShapeMNK": [128,256,64], - "NumTasks": 64 - } - } - ] - } - ], - "ProcessorGroups": [ - { - "ProcessorRange": [0,108], - "ResourceGroups": [ - { - "ProcessorRange": [0,108], - "WarpRange": [0,8], - "SramRange": [0,147456], - "TaskGroups": [ - {"TaskId":0,"TaskRange":[0,172],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,108], - "ResourceGroups": [ - { - "ProcessorRange": [0,108], - "WarpRange": [0,1], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":1,"TaskRange":[0,88064],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,108], - "ResourceGroups": [ - { - "ProcessorRange": [0,108], - "WarpRange": [0,1], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":2,"TaskRange":[0,88064],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,108], - "ResourceGroups": [ - { - "ProcessorRange": [0,108], - "WarpRange": [0,8], - "SramRange": [0,147456], - "TaskGroups": [ - {"TaskId":3,"TaskRange":[0,172],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,108], - "ResourceGroups": [ - { - "ProcessorRange": [0,108], - "WarpRange": [0,1], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":4,"TaskRange":[0,88064],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,64], - "ResourceGroups": [ - { - "ProcessorRange": [0,64], - "WarpRange": [0,8], - "SramRange": [0,147456], - "TaskGroups": [ - {"TaskId":5,"TaskRange":[0,64],"Granularity":1} - ] - } - ] - } - ] -} diff --git a/examples/tutorial/model.json b/examples/tutorial/model.json deleted file mode 100644 index c2b88bbd0..000000000 --- a/examples/tutorial/model.json +++ /dev/null @@ -1,140 +0,0 @@ -{ - "Rank": 0, - "WorldSize": 1, - "Nodes": [ - { - "Id": 0, - "ProducerNodeIds": [], - "ConsumerNodeIds": [1,2], - "Op": { - "Type": "Matmul", - "Name": "matmul", - "IsVirtual": false, - "ReadTensors": [ - {"Id":0,"DataType":"FP16","Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096],"Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":1,"DataType":"FP16","Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096],"Buffer":{"Id":1,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "WriteTensors": [ - {"Id":4,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "ResultTensors": [ - {"Id":5,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - } - } - }, - { - "Id": 1, - "ProducerNodeIds": [0], - "ConsumerNodeIds": [2], - "Op": { - "Type": "Sigmoid", - "Name": "sigmoid", - "IsVirtual": false, - "ReadTensors": [ - {"Id":5,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "WriteTensors": [ - 
{"Id":6,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "ResultTensors": [ - {"Id":7,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "Args": {} - } - }, - { - "Id": 2, - "ProducerNodeIds": [0,1], - "ConsumerNodeIds": [4], - "Op": { - "Type": "Mul", - "Name": "mul", - "IsVirtual": false, - "ReadTensors": [ - {"Id":5,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":7,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "WriteTensors": [ - {"Id":8,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "ResultTensors": [ - {"Id":9,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "Args": {} - } - }, - { - "Id": 3, - "ProducerNodeIds": [], - "ConsumerNodeIds": [4], - "Op": { - "Type": "Matmul", - "Name": "matmul_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":0,"DataType":"FP16","Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096],"Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":3,"DataType":"FP16","Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096],"Buffer":{"Id":3,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "WriteTensors": [ - {"Id":10,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "ResultTensors": [ - {"Id":11,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - } - } - }, - { - "Id": 4, - "ProducerNodeIds": [2,3], - "ConsumerNodeIds": [5], - "Op": { - "Type": "Mul", - "Name": "mul_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":9,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":11,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "WriteTensors": [ - {"Id":12,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "ResultTensors": [ - {"Id":13,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "Args": {} - } - }, - { - "Id": 5, - "ProducerNodeIds": [4], - "ConsumerNodeIds": [], - "Op": { - "Type": "Matmul", - "Name": "matmul_2", - "IsVirtual": false, - "ReadTensors": [ - 
{"Id":13,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":2,"DataType":"FP16","Shape":[4096,11008],"Strides":[4096,11008],"Offsets":[0,0],"PaddedShape":[4096,11008],"Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "WriteTensors": [ - {"Id":14,"DataType":"FP16","Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096],"Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "ResultTensors": [ - {"Id":15,"DataType":"FP16","Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096],"Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]}} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - } - } - } - ] -} diff --git a/examples/tutorial/plan.json b/examples/tutorial/plan.json deleted file mode 100644 index c0854e505..000000000 --- a/examples/tutorial/plan.json +++ /dev/null @@ -1,357 +0,0 @@ -{ - "Rank": 0, - "WorldSize": 1, - "NumProcessors": 108, - "NumWarpsPerProcessor": 8, - "TaskInfos": [ - { - "Id": 0, - "NumWarps": 8, - "SramBytes": 147456, - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul", - "IsVirtual": false, - "ReadTensors": [ - {"Id":0,"DataType":"FP16","Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]}, - {"Id":1,"DataType":"FP16","Buffer":{"Id":1,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096]} - ], - "WriteTensors": [ - {"Id":4,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "ResultTensors": [ - {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - }, - "Config": { - "NumWarps": 8, - "SramBytes": 147456, - "TileShapeMNK": [128,256,64], - "NumTasks": 172 - } - } - ] - }, - { - "Id": 1, - "NumWarps": 8, - "SramBytes": 0, - "Ops": [ - { - "Type": "Sigmoid", - "Name": "sigmoid", - "IsVirtual": false, - "ReadTensors": [ - {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "WriteTensors": [ - {"Id":6,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "ResultTensors": [ - {"Id":7,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "Args": {}, - "Config": { - "NumWarps": 8, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 172 - } - } - ] - }, - { - "Id": 2, - "NumWarps": 8, - "SramBytes": 0, - "Ops": [ - { - "Type": "Mul", - "Name": "mul", - "IsVirtual": false, - "ReadTensors": [ - {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]}, - 
{"Id":7,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "WriteTensors": [ - {"Id":8,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "ResultTensors": [ - {"Id":9,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "Args": {}, - "Config": { - "NumWarps": 8, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 172 - } - } - ] - }, - { - "Id": 3, - "NumWarps": 8, - "SramBytes": 147456, - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":0,"DataType":"FP16","Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]}, - {"Id":3,"DataType":"FP16","Buffer":{"Id":3,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096]} - ], - "WriteTensors": [ - {"Id":10,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "ResultTensors": [ - {"Id":11,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - }, - "Config": { - "NumWarps": 8, - "SramBytes": 147456, - "TileShapeMNK": [128,256,64], - "NumTasks": 172 - } - } - ] - }, - { - "Id": 4, - "NumWarps": 8, - "SramBytes": 0, - "Ops": [ - { - "Type": "Mul", - "Name": "mul_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":9,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]}, - {"Id":11,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "WriteTensors": [ - {"Id":12,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "ResultTensors": [ - {"Id":13,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "Args": {}, - "Config": { - "NumWarps": 8, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 172 - } - } - ] - }, - { - "Id": 5, - "NumWarps": 8, - "SramBytes": 147456, - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":16,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,8320],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,8320]}, - {"Id":17,"DataType":"FP16","Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[4096,8320],"Strides":[4096,11008],"Offsets":[0,0],"PaddedShape":[4096,8320]} - ], - "WriteTensors": [ - {"Id":14,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} 
- ], - "ResultTensors": [ - {"Id":22,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - }, - "Config": { - "NumWarps": 8, - "SramBytes": 147456, - "TileShapeMNK": [128,256,64], - "NumTasks": 64 - } - } - ] - }, - { - "Id": 6, - "NumWarps": 8, - "SramBytes": 147456, - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":18,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,2688],"Strides":[1,512,11008],"Offsets":[0,0,8320],"PaddedShape":[1,512,2688]}, - {"Id":19,"DataType":"FP16","Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[4096,2688],"Strides":[4096,11008],"Offsets":[0,8320],"PaddedShape":[4096,2688]} - ], - "WriteTensors": [ - {"Id":20,"DataType":"FP16","Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} - ], - "ResultTensors": [ - {"Id":21,"DataType":"FP16","Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - }, - "Config": { - "NumWarps": 8, - "SramBytes": 147456, - "TileShapeMNK": [128,256,64], - "NumTasks": 64 - } - } - ] - }, - { - "Id": 7, - "NumWarps": 8, - "SramBytes": 0, - "Ops": [ - { - "Type": "Add", - "Name": "add_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":22,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]}, - {"Id":21,"DataType":"FP16","Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} - ], - "WriteTensors": [ - {"Id":23,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} - ], - "ResultTensors": [ - {"Id":15,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} - ], - "Args": {}, - "Config": { - "NumWarps": 8, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 64 - } - } - ] - } - ], - "ProcessorGroups": [ - { - "ProcessorRange": [0,108], - "ResourceGroups": [ - { - "ProcessorRange": [0,108], - "WarpRange": [0,8], - "SramRange": [0,147456], - "TaskGroups": [ - {"TaskId":0,"TaskRange":[0,172],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,108], - "ResourceGroups": [ - { - "ProcessorRange": [0,108], - "WarpRange": [0,8], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":1,"TaskRange":[0,172],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,108], - "ResourceGroups": [ - { - "ProcessorRange": [0,108], - "WarpRange": [0,8], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":2,"TaskRange":[0,172],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,108], - "ResourceGroups": [ - { - "ProcessorRange": [0,108], - "WarpRange": [0,8], - "SramRange": [0,147456], - "TaskGroups": [ - {"TaskId":3,"TaskRange":[0,172],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,108], - "ResourceGroups": [ - { - "ProcessorRange": [0,108], - 
"WarpRange": [0,8], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":4,"TaskRange":[0,172],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,64], - "ResourceGroups": [ - { - "ProcessorRange": [0,64], - "WarpRange": [0,8], - "SramRange": [0,147456], - "TaskGroups": [ - {"TaskId":5,"TaskRange":[0,64],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [64,108], - "ResourceGroups": [ - { - "ProcessorRange": [64,108], - "WarpRange": [0,8], - "SramRange": [0,147456], - "TaskGroups": [ - {"TaskId":6,"TaskRange":[0,64],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,108], - "ResourceGroups": [ - { - "ProcessorRange": [0,64], - "WarpRange": [0,8], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":7,"TaskRange":[0,64],"Granularity":1} - ] - } - ] - } - ] -} diff --git a/examples/tutorial/plan_1_larger_tile.json b/examples/tutorial/plan_1_larger_tile.json deleted file mode 100644 index 3a3f66530..000000000 --- a/examples/tutorial/plan_1_larger_tile.json +++ /dev/null @@ -1,270 +0,0 @@ -{ - "Rank": 0, - "WorldSize": 1, - "NumProcessors": 108, - "NumWarpsPerProcessor": 8, - "TaskInfos": [ - { - "Id": 0, - "NumWarps": 8, - "SramBytes": 147456, - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul", - "IsVirtual": false, - "ReadTensors": [ - {"Id":0,"DataType":"FP16","Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]}, - {"Id":1,"DataType":"FP16","Buffer":{"Id":1,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096]} - ], - "WriteTensors": [ - {"Id":4,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "ResultTensors": [ - {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - }, - "Config": { - "NumWarps": 8, - "SramBytes": 147456, - "TileShapeMNK": [128,256,64], - "NumTasks": 172 - } - } - ] - }, - { - "Id": 1, - "NumWarps": 8, - "SramBytes": 0, - "Ops": [ - { - "Type": "Sigmoid", - "Name": "sigmoid", - "IsVirtual": false, - "ReadTensors": [ - {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "WriteTensors": [ - {"Id":6,"DataType":"FP16","Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "ResultTensors": [ - {"Id":7,"DataType":"FP16","Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "Args": {}, - "Config": { - "NumWarps": 8, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 172 - } - } - ] - }, - { - "Id": 2, - "NumWarps": 8, - "SramBytes": 0, - "Ops": [ - { - "Type": "Mul", - "Name": "mul", - "IsVirtual": false, - "ReadTensors": [ - {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]}, - 
{"Id":7,"DataType":"FP16","Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "WriteTensors": [ - {"Id":8,"DataType":"FP16","Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "ResultTensors": [ - {"Id":9,"DataType":"FP16","Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "Args": {}, - "Config": { - "NumWarps": 8, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 172 - } - } - ] - }, - { - "Id": 3, - "NumWarps": 8, - "SramBytes": 147456, - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":0,"DataType":"FP16","Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]}, - {"Id":3,"DataType":"FP16","Buffer":{"Id":3,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096]} - ], - "WriteTensors": [ - {"Id":10,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "ResultTensors": [ - {"Id":11,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - }, - "Config": { - "NumWarps": 8, - "SramBytes": 147456, - "TileShapeMNK": [128,256,64], - "NumTasks": 172 - } - } - ] - }, - { - "Id": 4, - "NumWarps": 8, - "SramBytes": 0, - "Ops": [ - { - "Type": "Mul", - "Name": "mul_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":9,"DataType":"FP16","Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]}, - {"Id":11,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "WriteTensors": [ - {"Id":12,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "ResultTensors": [ - {"Id":13,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "Args": {}, - "Config": { - "NumWarps": 8, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 172 - } - } - ] - }, - { - "Id": 5, - "NumWarps": 8, - "SramBytes": 147456, - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":13,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]}, - {"Id":2,"DataType":"FP16","Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[4096,11008],"Strides":[4096,11008],"Offsets":[0,0],"PaddedShape":[4096,11008]} - ], - "WriteTensors": [ - 
{"Id":14,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} - ], - "ResultTensors": [ - {"Id":15,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - }, - "Config": { - "NumWarps": 8, - "SramBytes": 147456, - "TileShapeMNK": [128,256,64], - "NumTasks": 64 - } - } - ] - } - ], - "ProcessorGroups": [ - { - "ProcessorRange": [0,108], - "ResourceGroups": [ - { - "ProcessorRange": [0,108], - "WarpRange": [0,8], - "SramRange": [0,147456], - "TaskGroups": [ - {"TaskId":0,"TaskRange":[0,172],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,108], - "ResourceGroups": [ - { - "ProcessorRange": [0,108], - "WarpRange": [0,8], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":1,"TaskRange":[0,172],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,108], - "ResourceGroups": [ - { - "ProcessorRange": [0,108], - "WarpRange": [0,8], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":2,"TaskRange":[0,172],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,108], - "ResourceGroups": [ - { - "ProcessorRange": [0,108], - "WarpRange": [0,8], - "SramRange": [0,147456], - "TaskGroups": [ - {"TaskId":3,"TaskRange":[0,172],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,108], - "ResourceGroups": [ - { - "ProcessorRange": [0,108], - "WarpRange": [0,8], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":4,"TaskRange":[0,172],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,64], - "ResourceGroups": [ - { - "ProcessorRange": [0,64], - "WarpRange": [0,8], - "SramRange": [0,147456], - "TaskGroups": [ - {"TaskId":5,"TaskRange":[0,64],"Granularity":1} - ] - } - ] - } - ] -} diff --git a/examples/tutorial/plan_2_split_k.json b/examples/tutorial/plan_2_split_k.json deleted file mode 100644 index 493515d8c..000000000 --- a/examples/tutorial/plan_2_split_k.json +++ /dev/null @@ -1,357 +0,0 @@ -{ - "Rank": 0, - "WorldSize": 1, - "NumProcessors": 108, - "NumWarpsPerProcessor": 8, - "TaskInfos": [ - { - "Id": 0, - "NumWarps": 8, - "SramBytes": 147456, - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul", - "IsVirtual": false, - "ReadTensors": [ - {"Id":0,"DataType":"FP16","Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]}, - {"Id":1,"DataType":"FP16","Buffer":{"Id":1,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096]} - ], - "WriteTensors": [ - {"Id":4,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "ResultTensors": [ - {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - }, - "Config": { - "NumWarps": 8, - "SramBytes": 147456, - "TileShapeMNK": [128,256,64], - "NumTasks": 172 - } - } - ] - }, - { - "Id": 1, - "NumWarps": 8, - "SramBytes": 0, - "Ops": [ - { - "Type": "Sigmoid", - "Name": "sigmoid", - "IsVirtual": false, - "ReadTensors": [ - 
{"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "WriteTensors": [ - {"Id":6,"DataType":"FP16","Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "ResultTensors": [ - {"Id":7,"DataType":"FP16","Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "Args": {}, - "Config": { - "NumWarps": 8, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 172 - } - } - ] - }, - { - "Id": 2, - "NumWarps": 8, - "SramBytes": 0, - "Ops": [ - { - "Type": "Mul", - "Name": "mul", - "IsVirtual": false, - "ReadTensors": [ - {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]}, - {"Id":7,"DataType":"FP16","Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "WriteTensors": [ - {"Id":8,"DataType":"FP16","Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "ResultTensors": [ - {"Id":9,"DataType":"FP16","Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "Args": {}, - "Config": { - "NumWarps": 8, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 172 - } - } - ] - }, - { - "Id": 3, - "NumWarps": 8, - "SramBytes": 147456, - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":0,"DataType":"FP16","Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]}, - {"Id":3,"DataType":"FP16","Buffer":{"Id":3,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096]} - ], - "WriteTensors": [ - {"Id":10,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "ResultTensors": [ - {"Id":11,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - }, - "Config": { - "NumWarps": 8, - "SramBytes": 147456, - "TileShapeMNK": [128,256,64], - "NumTasks": 172 - } - } - ] - }, - { - "Id": 4, - "NumWarps": 8, - "SramBytes": 0, - "Ops": [ - { - "Type": "Mul", - "Name": "mul_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":9,"DataType":"FP16","Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]}, - {"Id":11,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "WriteTensors": [ - {"Id":12,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], 
- "ResultTensors": [ - {"Id":13,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "Args": {}, - "Config": { - "NumWarps": 8, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 172 - } - } - ] - }, - { - "Id": 5, - "NumWarps": 8, - "SramBytes": 147456, - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":16,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,8320],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,8320]}, - {"Id":17,"DataType":"FP16","Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[4096,8320],"Strides":[4096,11008],"Offsets":[0,0],"PaddedShape":[4096,8320]} - ], - "WriteTensors": [ - {"Id":14,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} - ], - "ResultTensors": [ - {"Id":22,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - }, - "Config": { - "NumWarps": 8, - "SramBytes": 147456, - "TileShapeMNK": [128,256,64], - "NumTasks": 64 - } - } - ] - }, - { - "Id": 6, - "NumWarps": 8, - "SramBytes": 147456, - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":18,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,2688],"Strides":[1,512,11008],"Offsets":[0,0,8320],"PaddedShape":[1,512,2688]}, - {"Id":19,"DataType":"FP16","Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[4096,2688],"Strides":[4096,11008],"Offsets":[0,8320],"PaddedShape":[4096,2688]} - ], - "WriteTensors": [ - {"Id":20,"DataType":"FP16","Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} - ], - "ResultTensors": [ - {"Id":21,"DataType":"FP16","Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - }, - "Config": { - "NumWarps": 8, - "SramBytes": 147456, - "TileShapeMNK": [128,256,64], - "NumTasks": 64 - } - } - ] - }, - { - "Id": 7, - "NumWarps": 8, - "SramBytes": 0, - "Ops": [ - { - "Type": "Add", - "Name": "add_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":22,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]}, - {"Id":21,"DataType":"FP16","Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} - ], - "WriteTensors": [ - {"Id":23,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} - ], - "ResultTensors": [ - {"Id":15,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} - ], - "Args": {}, - "Config": { - "NumWarps": 8, - "SramBytes": 0, - "Tile": [128,256], - 
"NumTasks": 64 - } - } - ] - } - ], - "ProcessorGroups": [ - { - "ProcessorRange": [0,108], - "ResourceGroups": [ - { - "ProcessorRange": [0,108], - "WarpRange": [0,8], - "SramRange": [0,147456], - "TaskGroups": [ - {"TaskId":0,"TaskRange":[0,172],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,108], - "ResourceGroups": [ - { - "ProcessorRange": [0,108], - "WarpRange": [0,8], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":1,"TaskRange":[0,172],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,108], - "ResourceGroups": [ - { - "ProcessorRange": [0,108], - "WarpRange": [0,8], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":2,"TaskRange":[0,172],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,108], - "ResourceGroups": [ - { - "ProcessorRange": [0,108], - "WarpRange": [0,8], - "SramRange": [0,147456], - "TaskGroups": [ - {"TaskId":3,"TaskRange":[0,172],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,108], - "ResourceGroups": [ - { - "ProcessorRange": [0,108], - "WarpRange": [0,8], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":4,"TaskRange":[0,172],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,64], - "ResourceGroups": [ - { - "ProcessorRange": [0,64], - "WarpRange": [0,8], - "SramRange": [0,147456], - "TaskGroups": [ - {"TaskId":5,"TaskRange":[0,64],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [64,108], - "ResourceGroups": [ - { - "ProcessorRange": [64,108], - "WarpRange": [0,8], - "SramRange": [0,147456], - "TaskGroups": [ - {"TaskId":6,"TaskRange":[0,64],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,108], - "ResourceGroups": [ - { - "ProcessorRange": [0,64], - "WarpRange": [0,8], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":7,"TaskRange":[0,64],"Granularity":1} - ] - } - ] - } - ] -} diff --git a/examples/tutorial/plan_3_overwrite.json b/examples/tutorial/plan_3_overwrite.json deleted file mode 100644 index c0854e505..000000000 --- a/examples/tutorial/plan_3_overwrite.json +++ /dev/null @@ -1,357 +0,0 @@ -{ - "Rank": 0, - "WorldSize": 1, - "NumProcessors": 108, - "NumWarpsPerProcessor": 8, - "TaskInfos": [ - { - "Id": 0, - "NumWarps": 8, - "SramBytes": 147456, - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul", - "IsVirtual": false, - "ReadTensors": [ - {"Id":0,"DataType":"FP16","Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]}, - {"Id":1,"DataType":"FP16","Buffer":{"Id":1,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096]} - ], - "WriteTensors": [ - {"Id":4,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "ResultTensors": [ - {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - }, - "Config": { - "NumWarps": 8, - "SramBytes": 147456, - "TileShapeMNK": [128,256,64], - "NumTasks": 172 - } - } - ] - }, - { - "Id": 1, - "NumWarps": 8, - "SramBytes": 0, - "Ops": [ - { - "Type": "Sigmoid", - "Name": "sigmoid", - "IsVirtual": false, - "ReadTensors": [ - 
{"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "WriteTensors": [ - {"Id":6,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "ResultTensors": [ - {"Id":7,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "Args": {}, - "Config": { - "NumWarps": 8, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 172 - } - } - ] - }, - { - "Id": 2, - "NumWarps": 8, - "SramBytes": 0, - "Ops": [ - { - "Type": "Mul", - "Name": "mul", - "IsVirtual": false, - "ReadTensors": [ - {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]}, - {"Id":7,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "WriteTensors": [ - {"Id":8,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "ResultTensors": [ - {"Id":9,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "Args": {}, - "Config": { - "NumWarps": 8, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 172 - } - } - ] - }, - { - "Id": 3, - "NumWarps": 8, - "SramBytes": 147456, - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":0,"DataType":"FP16","Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]}, - {"Id":3,"DataType":"FP16","Buffer":{"Id":3,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096]} - ], - "WriteTensors": [ - {"Id":10,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "ResultTensors": [ - {"Id":11,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - }, - "Config": { - "NumWarps": 8, - "SramBytes": 147456, - "TileShapeMNK": [128,256,64], - "NumTasks": 172 - } - } - ] - }, - { - "Id": 4, - "NumWarps": 8, - "SramBytes": 0, - "Ops": [ - { - "Type": "Mul", - "Name": "mul_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":9,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]}, - {"Id":11,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "WriteTensors": [ - {"Id":12,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], 
- "ResultTensors": [ - {"Id":13,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} - ], - "Args": {}, - "Config": { - "NumWarps": 8, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 172 - } - } - ] - }, - { - "Id": 5, - "NumWarps": 8, - "SramBytes": 147456, - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":16,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,8320],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,8320]}, - {"Id":17,"DataType":"FP16","Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[4096,8320],"Strides":[4096,11008],"Offsets":[0,0],"PaddedShape":[4096,8320]} - ], - "WriteTensors": [ - {"Id":14,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} - ], - "ResultTensors": [ - {"Id":22,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - }, - "Config": { - "NumWarps": 8, - "SramBytes": 147456, - "TileShapeMNK": [128,256,64], - "NumTasks": 64 - } - } - ] - }, - { - "Id": 6, - "NumWarps": 8, - "SramBytes": 147456, - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":18,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,2688],"Strides":[1,512,11008],"Offsets":[0,0,8320],"PaddedShape":[1,512,2688]}, - {"Id":19,"DataType":"FP16","Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[4096,2688],"Strides":[4096,11008],"Offsets":[0,8320],"PaddedShape":[4096,2688]} - ], - "WriteTensors": [ - {"Id":20,"DataType":"FP16","Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} - ], - "ResultTensors": [ - {"Id":21,"DataType":"FP16","Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - }, - "Config": { - "NumWarps": 8, - "SramBytes": 147456, - "TileShapeMNK": [128,256,64], - "NumTasks": 64 - } - } - ] - }, - { - "Id": 7, - "NumWarps": 8, - "SramBytes": 0, - "Ops": [ - { - "Type": "Add", - "Name": "add_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":22,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]}, - {"Id":21,"DataType":"FP16","Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} - ], - "WriteTensors": [ - {"Id":23,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} - ], - "ResultTensors": [ - {"Id":15,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} - ], - "Args": {}, - "Config": { - "NumWarps": 8, - "SramBytes": 0, - "Tile": [128,256], - 
"NumTasks": 64 - } - } - ] - } - ], - "ProcessorGroups": [ - { - "ProcessorRange": [0,108], - "ResourceGroups": [ - { - "ProcessorRange": [0,108], - "WarpRange": [0,8], - "SramRange": [0,147456], - "TaskGroups": [ - {"TaskId":0,"TaskRange":[0,172],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,108], - "ResourceGroups": [ - { - "ProcessorRange": [0,108], - "WarpRange": [0,8], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":1,"TaskRange":[0,172],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,108], - "ResourceGroups": [ - { - "ProcessorRange": [0,108], - "WarpRange": [0,8], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":2,"TaskRange":[0,172],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,108], - "ResourceGroups": [ - { - "ProcessorRange": [0,108], - "WarpRange": [0,8], - "SramRange": [0,147456], - "TaskGroups": [ - {"TaskId":3,"TaskRange":[0,172],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,108], - "ResourceGroups": [ - { - "ProcessorRange": [0,108], - "WarpRange": [0,8], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":4,"TaskRange":[0,172],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,64], - "ResourceGroups": [ - { - "ProcessorRange": [0,64], - "WarpRange": [0,8], - "SramRange": [0,147456], - "TaskGroups": [ - {"TaskId":5,"TaskRange":[0,64],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [64,108], - "ResourceGroups": [ - { - "ProcessorRange": [64,108], - "WarpRange": [0,8], - "SramRange": [0,147456], - "TaskGroups": [ - {"TaskId":6,"TaskRange":[0,64],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,108], - "ResourceGroups": [ - { - "ProcessorRange": [0,64], - "WarpRange": [0,8], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":7,"TaskRange":[0,64],"Granularity":1} - ] - } - ] - } - ] -} diff --git a/examples/tutorial/plan_tutorial.py b/examples/tutorial/plan_tutorial.py deleted file mode 100644 index 560021522..000000000 --- a/examples/tutorial/plan_tutorial.py +++ /dev/null @@ -1,395 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import argparse -import ark -import time -import json -import numpy as np -from dataclasses import dataclass -from typing import Optional -from pathlib import Path - - -@dataclass -class ModelArgs: - dim: int = 4096 - n_layers: int = 32 - n_heads: int = 32 - n_kv_heads: Optional[int] = None - vocab_size: int = -1 # defined later by tokenizer - multiple_of: int = ( - 256 # make SwiGLU hidden layer size multiple of large power of 2 - ) - ffn_dim_multiplier: Optional[float] = None - norm_eps: float = 1e-5 - max_batch_size: int = 32 - max_seq_len: int = 2048 - - -class ColumnParallelLinear(ark.Module): - """Linear layer with column parallelism. - - The linear layer is defined as Y = XA + b. A is parallelized along - its second dimension as A = [A_1, ..., A_p]. - Here the weight = A^T, so we need to partition the weight matrix along - its first dimension. 
- - """ - - def __init__( - self, - in_dim: int, - out_dim: int, - dtype: np.dtype, - gather_output: bool = True, - local_rank: int = 0, - world_size: int = 1, - ): - super().__init__() - self.in_dim = in_dim - self.out_dim = out_dim - self.dtype = dtype - self.local_rank = local_rank - self.world_size = world_size - self.gather_output = gather_output - - self.weight = ark.parameter( - [out_dim // world_size, in_dim], ark.DataType.from_numpy(dtype) - ) - self.data = None - - def forward(self, x): - if self.world_size == 1 or self.gather_output == False: - return ark.matmul(x, self.weight, transpose_other=True) - # We need to concat the output_tensor_shards along the last dimension - output_tensor = ark.tensor( - [x.shape()[0], x.shape()[1], self.out_dim], - ark.DataType.from_numpy(self.dtype), - ) - output_tensor_shards = ark.sharding( - output_tensor, - axis=2, - dim_per_shard=self.out_dim // self.world_size, - ) - local_result = ark.identity( - output_tensor_shards[self.local_rank], deps=output_tensor_shards - ) - # (batch_size, seq_len, out_dim // world_size) - local_result = ark.matmul( - x, self.weight, local_result, transpose_other=True - ) - gather_input = ark.identity(output_tensor, deps=[local_result]) - # return gather_input - gather_reshape = ark.reshape( - gather_input, [x.shape()[0] * x.shape()[1], self.out_dim] - ) - gather_out = ark.local_all_gather( - gather_reshape, self.local_rank, self.world_size, 1 - ) - return ark.reshape( - gather_out, [x.shape()[0], x.shape()[1], self.out_dim] - ) - - def initialize(self): - if self.data is None: - data = np.random.uniform( - low=-0.1, high=0.1, size=self.weight.shape() - ).astype(self.dtype) - self.data = data - self.weight.from_numpy(self.data) - - -class RowParallelLinear(ark.Module): - """Linear layer with row parallelism. - - The linear layer is defined as Y = XA + b. A is parallelized along - its first dimension and X along its second dimension as: - - - - | A_1 | - | . | - A = | . | X = [X_1, ..., X_p] - | . | - | A_p | - - - - - Here the weight = A^T, so we need to partition the weight matrix along - its second dimension. 
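    A quick way to see why the gather in ColumnParallelLinear above and the
    all-reduce here are the right collectives is to check the sharded matmuls
    against the full product. A minimal NumPy sketch (illustrative only;
    arbitrary sizes, p = world size):

        import numpy as np

        p, in_dim, out_dim, n = 2, 8, 6, 4
        x = np.random.randn(n, in_dim).astype(np.float32)
        w = np.random.randn(out_dim, in_dim).astype(np.float32)  # weight = A^T
        full = x @ w.T                                           # Y = XA

        # Column parallel: shard the weight along dim 0, then concatenate
        # the per-shard outputs (the all-gather step).
        col = np.concatenate(
            [x @ ws.T for ws in np.split(w, p, axis=0)], axis=1)

        # Row parallel: shard x and the weight along in_dim, then sum the
        # partial products (the all-reduce step).
        row = sum(xs @ ws.T
                  for xs, ws in zip(np.split(x, p, axis=1),
                                    np.split(w, p, axis=1)))

        assert np.allclose(col, full, atol=1e-5)
        assert np.allclose(row, full, atol=1e-5)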
- """ - - def __init__( - self, - in_dim: int, - out_dim: int, - dtype: ark.DataType = ark.fp16, - input_is_parallel: bool = False, - local_rank: int = 0, - world_size: int = 1, - ): - super().__init__() - self.in_dim = in_dim - self.out_dim = out_dim - self.dtype = dtype - self.local_rank = local_rank - self.world_size = world_size - self.input_is_parallel = input_is_parallel - - self.weight = ark.parameter( - [out_dim, in_dim // world_size], ark.DataType.from_numpy(self.dtype) - ) - self.data = None - - def forward(self, x): - if self.world_size == 1: - return ark.matmul(x, self.weight, transpose_other=True) - x_ndims = len(x.shape()) - if self.input_is_parallel: - input_parallel = x - else: - x_shards = ark.sharding( - x, x_ndims - 1, self.in_dim // self.world_size - ) - input_parallel = x_shards[self.local_rank] - local_result = ark.matmul( - input_parallel, self.weight, transpose_other=True - ) - reduced_result = ark.local_all_reduce( - local_result, self.local_rank, self.world_size - ) - return reduced_result - - def initialize(self): - if self.data is None: - data = np.random.uniform( - low=-0.1, high=0.1, size=self.weight.shape() - ).astype(self.dtype) - self.data = data - self.weight.from_numpy(self.data) - - -class Silu(ark.Module): - """ - Silu activation function, silu(x) = x * sigmoid(x) - """ - - def __init__(self): - super().__init__() - - def forward(self, x: ark.Tensor): - x1 = ark.sigmoid(x) - return ark.mul(x, x1) - - -class FeedForward(ark.Module): - def __init__( - self, - dim: int, - hidden_dim: int, - multiple_of: int, - ffn_dim_multiplier: Optional[float], - dtype: np.dtype, - local_rank: int = 0, - world_size: int = 1, - ): - super().__init__() - hidden_dim = int(2 * hidden_dim / 3) - # custom dim factor multiplier - if ffn_dim_multiplier is not None: - hidden_dim = int(ffn_dim_multiplier * hidden_dim) - hidden_dim = multiple_of * ( - (hidden_dim + multiple_of - 1) // multiple_of - ) - - self.w1 = ColumnParallelLinear( - dim, hidden_dim, dtype, False, local_rank, world_size - ) - self.w2 = RowParallelLinear( - hidden_dim, dim, dtype, True, local_rank, world_size - ) - self.w3 = ColumnParallelLinear( - dim, hidden_dim, dtype, False, local_rank, world_size - ) - - def forward(self, x): - # self.w2(F.silu(self.w1(x)) * self.w3(x)) - x1 = self.w1(x) - x1 = Silu()(x1) - x2 = self.w3(x) - x3 = ark.mul(x1, x2) - x4 = self.w2(x3) - return x4 - - def initialize(self): - self.w1.initialize() - self.w2.initialize() - self.w3.initialize() - - -class Input(ark.Module): - def __init__( - self, batch_size: int, seq_len: int, dim: int, dtype: np.dtype - ): - super().__init__() - self.tensor = ark.tensor( - (batch_size, seq_len, dim), ark.DataType.from_numpy(dtype) - ) - self.data = None - - def forward(self): - return self.tensor - - def initialize(self): - if self.data is None: - self.data = np.random.uniform( - low=-0.1, high=0.1, size=self.tensor.shape() - ).astype(self.tensor.dtype().to_numpy()) - self.tensor.from_numpy(self.data) - - -def compare_results(result, ground_truth): - eps = np.finfo(result.dtype).eps - result = result.flatten() - ground_truth = ground_truth.flatten() - - max_value_idx = np.argmax(ground_truth) - min_value_idx = np.argmin(ground_truth) - - abs_diff = np.abs(result - ground_truth) - max_abs_diff_idx = np.argmax(abs_diff) - max_abs_diff = abs_diff[max_abs_diff_idx] - - abs_pt = np.abs(ground_truth) - rel_diff = abs_diff / (abs_pt + eps) - max_rel_diff_idx = np.argmax(rel_diff) - max_rel_diff = rel_diff[max_rel_diff_idx] - - # max rel_diff where abs_pt 
is larger than 1e-3 - max_rel_diff_3_idx = np.argmax(rel_diff * (abs_pt > 1e-3)) - max_rel_diff_3 = rel_diff[max_rel_diff_3_idx] - - mean_square_error = np.mean(np.square(result - ground_truth)) - - # Test info as string - - print( - f"Comparing ground truth vs results\n" - f" max_value: {ground_truth[max_value_idx]} vs {result[max_value_idx]} at index {max_value_idx}\n" - f" min_value: {ground_truth[min_value_idx]} vs {result[min_value_idx]} at index {min_value_idx}\n" - f" max_abs_diff: {max_abs_diff:.4e} ({ground_truth[max_abs_diff_idx]} vs {result[max_abs_diff_idx]} at index {max_abs_diff_idx})\n" - f" max_rel_diff: {max_rel_diff:.4e} ({ground_truth[max_rel_diff_idx]} vs {result[max_rel_diff_idx]} at index {max_rel_diff_idx})\n" - f" max_rel_diff_3: {max_rel_diff_3:.4e} ({ground_truth[max_rel_diff_3_idx]} vs {result[max_rel_diff_3_idx]} at index {max_rel_diff_3_idx})\n" - f" mean_square_error: {mean_square_error:.4e}\n" - ) - - -def config_rule_larger_tile(op: str, arch: str) -> str: - j = json.loads(op) - op_type = j["Type"] - if op_type == "Sigmoid" or op_type == "Mul": - pshape = j["ResultTensors"][0]["PaddedShape"] - if len(pshape) < 2 or pshape[-2] % 128 != 0 or pshape[-1] % 256 != 0: - return "" - num_tasks = pshape[-2] // 128 * pshape[-1] // 256 - cfg = { - "NumWarps": 8, - "SramBytes": 0, - "Tile": [128, 256], - "NumTasks": num_tasks, - } - return json.dumps(cfg) - return "" - - -def main(plan_path: str): - args = ModelArgs() - batch_size = 1 - seq_len = 512 - dtype = np.float16 - seed = int(time.time()) - - print(f"seed: {seed}") - np.random.seed(seed) - ark.srand(seed) - - InputModule = Input(batch_size, seq_len, args.dim, dtype) - input_tensor = InputModule() - - # Declare model - FeedForwardModule = FeedForward( - dim=args.dim, - hidden_dim=4 * args.dim, - multiple_of=args.multiple_of, - ffn_dim_multiplier=args.ffn_dim_multiplier, - dtype=dtype, - ) - output_tensor = FeedForwardModule(input_tensor) - - # Write model.json - with open("model.json", "w") as f: - f.write(ark.Model.get_model().compress().serialize()) - - # Calculate default result - ground_truth = None - with ark.Runtime.get_runtime() as rt: - planner = ark.Planner() - - # If this rule is installed, default planner will perform the same as - # `plan_1_larger_tile.json` on A100. - # planner.install_config_rule(config_rule_larger_tile) - - plan = planner.plan() - with open("default_plan.json", "w") as f: - f.write(plan) - rt.launch(plan=plan) - - # Initialize - InputModule.initialize() - FeedForwardModule.initialize() - - # Calculate output - rt.run() - ground_truth = output_tensor.to_numpy() - - # Measure throughput - iter = 100 - ts = time.time() - rt.run(iter) - elapsed_ms = (time.time() - ts) * 1e3 - print( - f"DefaultPlan elapsed time: total {elapsed_ms:.6f} ms, {elapsed_ms/iter:.6f} ms/iter" - ) - - # Run `plan_path` file if exists - if not Path(plan_path).is_file(): - print(f"File {plan_path} does not exist. 
Exiting...") - return - with ark.Runtime.get_runtime() as rt: - rt.launch(plan_path=plan_path) - - # Initialize - InputModule.initialize() - FeedForwardModule.initialize() - - # Calculate output - rt.run() - result = output_tensor.to_numpy() - - # Measure throughput - iter = 100 - ts = time.time() - rt.run(iter) - elapsed_ms = (time.time() - ts) * 1e3 - print( - f"Plan elapsed time: total {elapsed_ms:.6f} ms, {elapsed_ms/iter:.6f} ms/iter" - ) - - # Compare results - compare_results(result, ground_truth) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--plan_path", type=str, default="plan.json") - - args = parser.parse_args() - main(args.plan_path) diff --git a/examples/tutorial/planner_tutorial.py b/examples/tutorial/planner_tutorial.py index 1f6c3ac58..8702f8929 100644 --- a/examples/tutorial/planner_tutorial.py +++ b/examples/tutorial/planner_tutorial.py @@ -54,14 +54,14 @@ def eval(tensor: ark.Tensor): return tensor.to_torch() -def perf(): +def perf(num_iter: int = 1000): with ark.Runtime() as rt: rt.launch() start = time.time() - rt.run(iter=1000) + rt.run(iter=num_iter) end = time.time() - return (end - start) / 1000 + return (end - start) / num_iter if __name__ == "__main__": @@ -69,14 +69,13 @@ def perf(): shape = (32, 2048, 2048) - # input = torch.randn(*shape).to("cuda:0") - input = ark.tensor(shape) + input = torch.randn(*shape).to("cuda:0") - output = Softmax()(input) + output = Softmax()(ark.Tensor.from_torch(input)) - # if torch.allclose(eval(output), F.softmax(input, dim=-1), atol=1e-5): - # print("Correct result") - # else: - # print("Incorrect result") + if torch.allclose(eval(output), F.softmax(input, dim=-1), atol=1e-5): + print("Correct result") + else: + print("Incorrect result") print(f"Performance: {(perf() * 1e3):.3f} ms/iter") diff --git a/examples/tutorial/quickstart_tutorial.py b/examples/tutorial/quickstart_tutorial.py index ebd3f8530..1fce51452 100644 --- a/examples/tutorial/quickstart_tutorial.py +++ b/examples/tutorial/quickstart_tutorial.py @@ -41,12 +41,6 @@ def quickstart_tutorial(): output_tensor_host, input_tensor_host + other_tensor_host ) - # Stop the ARK runtime (undo Runtime.launch()) - runtime.stop() - - # Reset the ARK runtime (free all resources) - runtime.reset() - print("Quickstart tutorial is successful!") diff --git a/python/ark/__init__.py b/python/ark/__init__.py index 24e4acfc4..63480262c 100644 --- a/python/ark/__init__.py +++ b/python/ark/__init__.py @@ -31,9 +31,10 @@ def set_world_size(world_size): from .init import init from .tensor import Dims, Tensor, Parameter from .module import Module -from .runtime import Runtime +from .runtime import * from .serialize import save, load from .data_type import * +from .profiler import Profiler from .ops import * from .planner import * from .error import * diff --git a/python/ark/data_type.py b/python/ark/data_type.py index 4638cf972..3deef50f4 100644 --- a/python/ark/data_type.py +++ b/python/ark/data_type.py @@ -2,6 +2,7 @@ # Licensed under the MIT license. import numpy +from .torch import torch from . import core from . 
import log @@ -15,15 +16,14 @@ "uint8", ] - REGISTRY_DATA_TYPE = { - "fp32": {"np": numpy.float32}, - "fp16": {"np": numpy.float16}, - "bf16": {"np": None}, - "int32": {"np": numpy.int32}, - "uint32": {"np": numpy.uint32}, - "int8": {"np": numpy.int8}, - "uint8": {"np": numpy.uint8}, + "fp32": {"np": numpy.float32, "torch": torch.float32}, + "fp16": {"np": numpy.float16, "torch": torch.float16}, + "bf16": {"np": None, "torch": torch.bfloat16}, + "int32": {"np": numpy.int32, "torch": torch.int32}, + "uint32": {"np": numpy.uint32, "torch": None}, + "int8": {"np": numpy.int8, "torch": torch.int8}, + "uint8": {"np": numpy.uint8, "torch": torch.uint8}, } @@ -33,6 +33,7 @@ def __new__(cls, name, bases, attrs): if name in REGISTRY_DATA_TYPE: reg = REGISTRY_DATA_TYPE[name] new_class.to_numpy = staticmethod(lambda: reg["np"]) + new_class.to_torch = staticmethod(lambda: reg["torch"]) new_class.ctype = staticmethod(lambda: getattr(core, name.upper())) new_class.element_size = staticmethod( lambda: new_class.ctype().bytes() @@ -60,9 +61,10 @@ def from_numpy(np_type: numpy.dtype) -> "DataType": InvalidUsageError: If there is no defined conversion from numpy data type to ark data type. """ if not isinstance(np_type, numpy.dtype): - raise log.InvalidUsageError( - f"Expected a numpy data type, but got {type(np_type)}" - ) + try: + np_type = numpy.dtype(np_type) + except Exception as e: + raise log.InvalidUsageError(f"Not a numpy data type. {str(e)}") for type_name, reg in REGISTRY_DATA_TYPE.items(): if reg["np"] == np_type: return DataType.from_name(type_name) @@ -71,6 +73,28 @@ def from_numpy(np_type: numpy.dtype) -> "DataType": f" to ark data type." ) + @staticmethod + def from_torch(torch_type: torch.dtype) -> "DataType": + """ + Return the corresponding ark data type. + + Parameters: + torch_type (torch.dtype): The torch data type. + + Returns: + DataType: The corresponding ark data type. + + Raises: + ValueError: If there is no defined conversion from torch data type to ark data type. + """ + for type_name, reg in REGISTRY_DATA_TYPE.items(): + if reg["torch"] == torch_type: + return DataType.from_name(type_name) + raise ValueError( + f"Undefined conversion from torch data type {torch_type}" + f" to ark data type." + ) + @staticmethod def from_name(type_name: str) -> "DataType": """ @@ -120,6 +144,16 @@ def to_numpy() -> numpy.dtype: """ ... + @staticmethod + def to_torch() -> torch.dtype: + """ + Return the corresponding torch data type. + + Returns: + torch.dtype: The corresponding torch data type. + """ + ... + @staticmethod def ctype() -> core.CoreDataType: """ diff --git a/python/ark/executor.py b/python/ark/executor.py new file mode 100644 index 000000000..14f0817a8 --- /dev/null +++ b/python/ark/executor.py @@ -0,0 +1,26 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from .core import CoreExecutor + + +__all__ = ["Executor"] + + +class ExecutorState: + executor: CoreExecutor = None + + +class Executor: + @staticmethod + def get() -> CoreExecutor: + if ExecutorState.executor is None: + ExecutorState.executor = CoreExecutor() + return ExecutorState.executor + + @staticmethod + def reset() -> None: + if ExecutorState.executor is None: + return + ExecutorState.executor.destroy() + ExecutorState.executor = None diff --git a/python/ark/init.py b/python/ark/init.py index f8e226ad1..07eb557b3 100644 --- a/python/ark/init.py +++ b/python/ark/init.py @@ -3,15 +3,13 @@ from . 
import core from .model import Model -from .runtime import RuntimeState +from .executor import Executor __all__ = ["init"] def init(): """Initializes ARK.""" + Executor.reset() Model.reset() - if RuntimeState.executor is not None: - if not RuntimeState.executor.destroyed(): - RuntimeState.executor.destroy() core.init() diff --git a/python/ark/model.py b/python/ark/model.py index bfd74d5e0..e103d4083 100644 --- a/python/ark/model.py +++ b/python/ark/model.py @@ -2,6 +2,7 @@ # Licensed under the MIT license. from typing import NewType +from . import log from .core import CoreModel @@ -34,6 +35,13 @@ def get_world_size(): """ return ModelState.world_size + @staticmethod + def get_device_id(): + """ + Get the device id. + """ + return ModelState.device_id + @staticmethod def set_rank(rank: int): """ @@ -48,6 +56,15 @@ def set_world_size(world_size: int): """ ModelState.world_size = world_size + @staticmethod + def set_device_id(device_id: int): + """ + Set the device id. + """ + if device_id < 0: + raise log.InvalidUsageError("device_id must be non-negative") + ModelState.device_id = device_id + @staticmethod def reset(): """ @@ -57,6 +74,19 @@ def reset(): ModelState.rank = 0 ModelState.world_size = 1 + def __init__(self, rank: int = 0, world_size: int = 1): + """ + Initialize the model. + + Args: + rank: The rank of the model. + world_size: The world size of the model. + """ + super().__init__(rank, world_size) + + def __str__(self) -> str: + return self.serialize() + def compress(self) -> "Model": """ Compress the model. @@ -84,3 +114,4 @@ class ModelState: model: Model = None rank: int = 0 world_size: int = 1 + device_id: int = 0 diff --git a/python/ark/module.py b/python/ark/module.py index 368f36cf7..55d80b8e8 100644 --- a/python/ark/module.py +++ b/python/ark/module.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -import logging import numpy as np from typing import Any, Dict from .tensor import Parameter @@ -71,6 +70,9 @@ def load_state_dict( all_keys = set(state_dict.keys()) pd = self.params_dict(prefix) for name, param in pd.items(): + if param.data_ptr() == 0: + log.WARN(f"Parameter {name} is not initialized") + continue param.from_numpy(state_dict[name]) all_keys.remove(name) if all_keys: diff --git a/python/ark/ops.py b/python/ark/ops.py index fa7879e07..46145035a 100644 --- a/python/ark/ops.py +++ b/python/ark/ops.py @@ -1,9 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import List, Iterable, Union +from typing import List, Iterable, Union, Optional -from .tensor import Dims, Tensor, Parameter, NullTensor +from .tensor import Dims, Tensor, Parameter, NullTensor, _cpp_tensor +from .torch import torch, _no_torch from .data_type import DataType, fp32 from .model import Model from . 
import log @@ -12,6 +13,8 @@ __all__ = [ "tensor", "parameter", + "placeholder", + "noop", "reshape", "identity", "sharding", @@ -47,52 +50,6 @@ def is_list_or_tuple(obj): return isinstance(obj, list) or isinstance(obj, tuple) -def _tensor( - shape: Iterable[int], - dtype: DataType = fp32, - strides: Iterable[int] = [], - offsets: Iterable[int] = [], - padded_shape: Iterable[int] = [], - rank: int = -1, - name: str = "", -) -> Tensor: - if not is_list_or_tuple(shape): - raise log.InvalidUsageError( - "shape should be a list or tuple of integers" - ) - if not is_list_or_tuple(strides): - raise log.InvalidUsageError( - "strides should be a list or tuple of integers" - ) - if not is_list_or_tuple(offsets): - raise log.InvalidUsageError( - "offsets should be a list or tuple of integers" - ) - if not is_list_or_tuple(padded_shape): - raise log.InvalidUsageError( - "padded_shape should be a list or tuple of integers" - ) - # only support tensors with up to 4 dimensions - if ( - len(shape) > 4 - or len(strides) > 4 - or len(offsets) > 4 - or len(padded_shape) > 4 - ): - raise log.InvalidUsageError( - "Only support tensors with up to 4 dimensions" - ) - return Model.get_model().tensor( - Dims(shape), - dtype.ctype(), - Dims(strides), - Dims(offsets), - Dims(padded_shape), - rank, - name, - ) - - def add( input: Union[Tensor, float], other: Union[Tensor, float], @@ -147,7 +104,9 @@ def constant( def copy( - input: Union[Tensor, float], output: Tensor = NullTensor, name: str = "copy" + input: Union[Tensor, float], + output: Tensor = NullTensor, + name: str = "copy", ) -> Tensor: """ """ if output is not NullTensor: @@ -186,7 +145,9 @@ def embedding( def exp( - input: Tensor, output: Tensor = NullTensor, name: str = "exp" + input: Tensor, + output: Tensor = NullTensor, + name: str = "exp", ) -> Tensor: """ """ if output is not NullTensor: @@ -195,7 +156,9 @@ def exp( def gelu( - input: Tensor, output: Tensor = NullTensor, name: str = "gelu" + input: Tensor, + output: Tensor = NullTensor, + name: str = "gelu", ) -> Tensor: """ """ if output is not NullTensor: @@ -257,6 +220,35 @@ def noop(input: Tensor, name: str = "noop"): Model.get_model().noop(input._tensor, name) +def placeholder( + shape: Iterable[int], + dtype: DataType = fp32, + strides: Iterable[int] = [], + offsets: Iterable[int] = [], + padded_shape: Iterable[int] = [], + rank: int = -1, + data: Union[int, torch.Tensor] = 0, + name: str = "placeholder", +) -> Tensor: + """ """ + if not _no_torch and isinstance(data, torch.Tensor): + # Should we support initializing shape dtype stride offset and padded_shape + # just by passing in a torch.Tensor? 
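+        # Only the tensor's device address is recorded here; the memory
+        # itself stays owned by the caller's torch.Tensor, which therefore
+        # needs to stay alive while the placeholder is in use.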
+ data = data.data_ptr() + return Tensor( + Model.get_model().placeholder( + Dims(shape), + dtype.ctype(), + Dims(strides), + Dims(offsets), + Dims(padded_shape), + rank, + data, + name, + ) + ) + + def reduce_max( input: Tensor, axis: int, @@ -309,7 +301,9 @@ def reduce_sum( def relu( - input: Tensor, output: Tensor = NullTensor, name: str = "relu" + input: Tensor, + output: Tensor = NullTensor, + name: str = "relu", ) -> Tensor: """ """ if output is not NullTensor: @@ -365,7 +359,9 @@ def rope( def rsqrt( - input: Tensor, output: Tensor = NullTensor, name: str = "rsqrt" + input: Tensor, + output: Tensor = NullTensor, + name: str = "rsqrt", ) -> Tensor: """ """ if output is not NullTensor: @@ -384,7 +380,9 @@ def sharding( def sigmoid( - input: Tensor, output: Tensor = NullTensor, name: str = "sigmoid" + input: Tensor, + output: Tensor = NullTensor, + name: str = "sigmoid", ) -> Tensor: """ """ if output is not NullTensor: @@ -393,7 +391,9 @@ def sigmoid( def sqrt( - input: Tensor, output: Tensor = NullTensor, name: str = "sqrt" + input: Tensor, + output: Tensor = NullTensor, + name: str = "sqrt", ) -> Tensor: """ """ if output is not NullTensor: @@ -426,7 +426,9 @@ def tensor( ) -> Tensor: """ """ return Tensor( - _tensor(shape, dtype, strides, offsets, padded_shape, rank, name) + _cpp_tensor( + shape, dtype, strides, offsets, padded_shape, rank, None, name + ) ) @@ -484,7 +486,9 @@ def parameter( ) -> Parameter: """ """ return Parameter( - _tensor(shape, dtype, strides, offsets, padded_shape, name) + _cpp_tensor( + shape, dtype, strides, offsets, padded_shape, -1, None, name + ) ) @@ -514,7 +518,9 @@ def layernorm( def zeros( - shape: Iterable[int], dtype: DataType = fp32, name: str = "zeros" + shape: Iterable[int], + dtype: DataType = fp32, + name: str = "zeros", ) -> Tensor: """ """ return Tensor( diff --git a/python/ark/planner.py b/python/ark/planner.py index 3c82719be..0ed9113e1 100644 --- a/python/ark/planner.py +++ b/python/ark/planner.py @@ -5,6 +5,7 @@ import json from typing import Callable, Dict, List, Any +from . import error from .core import CorePlanner, CorePlannerContext from .model import Model @@ -155,13 +156,27 @@ def processor_groups(self) -> List[Dict[str, Any]]: @staticmethod def from_str(plan_str: str) -> "Plan": - plan = json.loads(plan_str) + try: + plan = json.loads(plan_str) + except json.JSONDecodeError: + raise error.InvalidUsageError( + "Plan string is not a valid JSON string." + ) return Plan(plan) @staticmethod def from_file(file_path: str) -> "Plan": - with open(file_path, "r") as f: - plan = json.load(f) + try: + with open(file_path, "r") as f: + plan = json.load(f) + except FileNotFoundError: + raise error.InvalidUsageError( + f"Plan file {file_path} does not exist." + ) + except json.JSONDecodeError: + raise error.InvalidUsageError( + f"Plan file {file_path} is not a valid JSON file." + ) return Plan(plan) @@ -195,6 +210,15 @@ def __init__(self, **kwargs): if config is not None: self.config(json.dumps(config)) + def dump(self) -> str: + """ + Dump the context stack. + + Returns: + str: The context stack in JSON format. + """ + return super().dump() + def __enter__(self) -> "PlannerContext": """ Enter the plan manager. @@ -227,4 +251,4 @@ def plan(self) -> Plan: """ Generate an execution plan. 
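+
+        Illustrative usage (assuming a single-device, rank-0 setup):
+
+            plan = Planner(0).plan()
+            groups = plan.processor_groups
+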
""" - return Plan.from_str(super().plan(pretty=False)) + return Plan.from_str(super().plan(pretty=True)) diff --git a/python/ark/profiler.py b/python/ark/profiler.py new file mode 100644 index 000000000..da346cb7b --- /dev/null +++ b/python/ark/profiler.py @@ -0,0 +1,115 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import sys +import time +from typing import Optional, List + +from .runtime import Runtime +from .planner import Plan + + +def timeit(plan: Plan, iter: int, loop_mode: bool, warmup: int = 3): + with Runtime() as rt: + if loop_mode: + rt.launch(plan=plan, loop_mode=loop_mode) + rt.run(iter=warmup) + rt.stop() + start_time = time.time() + rt.run(iter=iter) + elapsed = time.time() - start_time + else: + rt.launch(plan=plan, loop_mode=loop_mode) + rt.run(iter=warmup) + rt.stop() + rt.launch(plan=plan, loop_mode=loop_mode, record=True) + rt.run(iter=iter) + elapsed = rt.stop() / 1.0e3 + return elapsed / iter + + +class Profiler: + def __init__(self, plan: Plan): + self.plan = plan + + def run( + self, + iter: int = 1000, + loop_mode: bool = True, + profile_processor_groups: bool = False, + target_processor_groups: Optional[List[int]] = None, + ): + if target_processor_groups is None: + sys.stderr.write( + f"End-to-end: {timeit(self.plan, iter, loop_mode):.6f} seconds/iter\n" + ) + + if not profile_processor_groups: + return + num_processor_groups = len(self.plan.processor_groups) + new_plan = { + "Rank": self.plan.rank, + "WorldSize": self.plan.world_size, + "Architecture": self.plan.architecture, + "NumProcessors": self.plan.num_processors, + "NumWarpsPerProcessor": self.plan.num_warps_per_processor, + "TaskInfos": self.plan.task_infos, + "ProcessorGroups": [None], + } + for i in range(num_processor_groups): + if ( + target_processor_groups is not None + and i not in target_processor_groups + ): + continue + new_plan["ProcessorGroups"][0] = self.plan.processor_groups[i] + lat_per_iter = timeit(Plan(new_plan), iter, loop_mode) + sys.stderr.write( + f"Processor group {i}: {lat_per_iter:.6f} seconds/iter\n" + ) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="ARK Profiler") + parser.add_argument( + "--iter", + type=int, + default=1000, + help="Number of iterations to run for each measurement", + ) + parser.add_argument( + "--loop_mode", + action="store_true", + help="Use loop mode to measure end-to-end latency", + ) + parser.add_argument( + "--profile_processor_groups", + action="store_true", + help="Profile processor groups", + ) + parser.add_argument( + "--target_processor_groups", + type=str, + help="Target processor groups to profile", + ) + parser.add_argument( + "--plan", type=str, help="Path to the plan file", required=True + ) + args = parser.parse_args() + + target_processor_groups = None + if args.target_processor_groups is not None: + target_processor_groups = list( + map(int, args.target_processor_groups.split(",")) + ) + + plan = Plan.from_file(args.plan) + profiler = Profiler(plan) + profiler.run( + iter=args.iter, + loop_mode=args.loop_mode, + profile_processor_groups=args.profile_processor_groups, + target_processor_groups=target_processor_groups, + ) diff --git a/python/ark/runtime.py b/python/ark/runtime.py index 017350103..0edfd26ec 100644 --- a/python/ark/runtime.py +++ b/python/ark/runtime.py @@ -1,27 +1,18 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -import logging from enum import Enum -from .core import CoreExecutor +from . 
import log +from .tensor import Tensor +from .torch import torch +from .executor import Executor from .planner import Planner, Plan +from .model import Model +from typing import Dict -__all__ = ["Executor", "Runtime"] - - -class RuntimeState: - """ - The RuntimeState class is used to store the state of the model. - """ - - runtime = None - executor = None - - -class Executor(CoreExecutor): - pass +__all__ = ["Runtime"] class Runtime: @@ -29,98 +20,107 @@ class Runtime: Convenience class for running a model. """ - class State(Enum): + class StateCode(Enum): """ - Runtime states. + Runtime state code. """ Init = 0 LaunchedNotRunning = 1 Running = 2 - @staticmethod - def get_runtime() -> "Runtime": - """ - Get the runtime. - """ - if RuntimeState.runtime is None: - RuntimeState.runtime = Runtime() - return RuntimeState.runtime - def __init__(self): - self.executor: Executor = None - self.state: Runtime.State = Runtime.State.Init - RuntimeState.runtime = self - - def __del__(self): - self.reset() + self.loop_mode: bool = True + self.state: Runtime.StateCode = Runtime.StateCode.Init - def __enter__(self): + def __enter__(self) -> "Runtime": return self def __exit__(self, exc_type, exc_val, exc_tb): - self.reset() + if self.launched(): + self.stop() + + def __del__(self): + if self.launched(): + self.stop() def launched(self) -> bool: """ Check if the runtime is launched. """ return ( - self.state == Runtime.State.LaunchedNotRunning - or self.state == Runtime.State.Running + self.state == Runtime.StateCode.LaunchedNotRunning + or self.state == Runtime.StateCode.Running ) def running(self) -> bool: """ Check if the runtime is running. """ - return self.state == Runtime.State.Running + return self.state == Runtime.StateCode.Running def launch( self, plan: Plan = None, - device_id: int = 0, + device_id: int = -1, stream: int = 0, loop_mode: bool = True, + record: bool = False, + tensor_mappings: Dict = {}, ): """ Create an executor and schedule the ARK model. The scheduler will generate the CUDA kernels. The GPU context and the connection between GPUs will be initialized. The executor will compile the cuda kernels and launch the ARK runtime. 
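A usage sketch of the `Runtime` state machine above (the model definition is illustrative):

```python
import ark

x = ark.tensor([64, 64], ark.fp16)
y = ark.tensor([64, 64], ark.fp16)
ark.add(x, y)

rt = ark.Runtime()
assert not rt.launched() and not rt.running()

rt.launch()          # Init -> LaunchedNotRunning; compiles on first launch
rt.run(iter=10)      # -> Running, then back again after the implicit wait()
elapsed = rt.stop()  # -> LaunchedNotRunning; returns elapsed time (-1 if never launched)
rt.launch()          # a relaunch is required before running again
```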
""" - if self.launched(): - logging.warn("Runtime is already launched, skip launching") - return + if device_id == -1: + device_id = Model.get_device_id() + elif device_id < 0: + raise log.InvalidUsageError(f"Invalid device_id: {device_id}") plan = Planner(device_id).plan() if plan is None else plan - # If the RuntimeState is init, we need to create a new executor and - # compile the kernels - if self.state == Runtime.State.Init: - if RuntimeState.executor is not None: - if not RuntimeState.executor.destroyed(): - logging.warn("Destroying an old executor") - RuntimeState.executor.destroy() - - RuntimeState.executor = Executor( - device_id, - stream, - "ArkRuntime", - str(plan), - loop_mode, - ) - self.executor = RuntimeState.executor - self.executor.compile() - self.executor.launch() - self.state = Runtime.State.LaunchedNotRunning - - def run(self, iter=1, non_blocking=False): + plan_str = str(plan) + if self.launched(): + # Stop the current running model + self.stop() + for ark_tensor in list(tensor_mappings.keys()): + torch_tensor = tensor_mappings[ark_tensor] + if not isinstance(torch_tensor, torch.Tensor): + raise log.InvalidUsageError("Must bind PyTorch tensor") + internal_ark_tensor = ark_tensor._tensor + tensor_mappings[internal_ark_tensor] = torch_tensor.data_ptr() + del tensor_mappings[ark_tensor] + # Recompile if the previous launch was not compiled with the same info + # or if this is the first launch + exe = Executor.get() + if plan_str != exe.plan() or device_id != exe.device_id(): + exe.compile(plan_str, device_id) + exe.launch(tensor_mappings, stream, loop_mode, record) + self.state = Runtime.StateCode.LaunchedNotRunning + self.loop_mode = loop_mode + + def run( + self, + iter: int = 1, + non_blocking: bool = False, + tensor_mappings: Dict[Tensor, torch.Tensor] = {}, + ): """ Run the ARK program for iter iterations and wait for the kernel to finish. """ - if self.state != Runtime.State.LaunchedNotRunning: - logging.error("ARK runtime is not launched") - raise RuntimeError("ARK runtime is not launched") - self.state = Runtime.State.Running - self.executor.run(iter) + if self.loop_mode and tensor_mappings: + raise log.InvalidUsageError( + "`loop_mode` argument when calling `runtime.launch` " + "must be set to false in order to pass non-empty " + "tensor mappings in `runtime.run`." + ) + if self.state != Runtime.StateCode.LaunchedNotRunning: + raise log.InvalidUsageError(f"ARK runtime is not launched") + self.state = Runtime.StateCode.Running + ph_map = {} + for ark_tensor in list(tensor_mappings.keys()): + t = tensor_mappings[ark_tensor] + ph_map[ark_tensor._tensor] = t.data_ptr() + Executor.get().run(iter, ph_map) if not non_blocking: self.wait() @@ -128,20 +128,19 @@ def barrier(self): """ Barrier for all ranks. """ - if self.state != Runtime.State.LaunchedNotRunning: - logging.error("ARK runtime is not launched") - raise RuntimeError("ARK runtime is not launched") - self.executor.barrier() + if self.state != Runtime.StateCode.LaunchedNotRunning: + raise log.InvalidUsageError("ARK runtime is not launched") + Executor.get().barrier() def wait(self): """ Wait for the kernel to finish. 
""" - if self.state != Runtime.State.Running: - logging.warn("ARK runtime is not running, skip waiting") + if self.state != Runtime.StateCode.Running: + log.WARN(f"ARK runtime is not running, skip waiting") return - self.executor.wait() - self.state = Runtime.State.LaunchedNotRunning + Executor.get().wait() + self.state = Runtime.StateCode.LaunchedNotRunning def stop(self) -> float: """ @@ -149,20 +148,8 @@ def stop(self) -> float: Once this is called, we need to call `launch()` again to run the model again. """ if not self.launched(): - logging.warn("ARK runtime is never launched, skip stopping") - return - elapsed = self.executor.stop() - self.state = Runtime.State.LaunchedNotRunning + log.WARN(f"ARK runtime is never launched, skip stopping") + return -1 + elapsed = Executor.get().stop() + self.state = Runtime.StateCode.LaunchedNotRunning return elapsed - - def reset(self): - """ - Reset the runtime. - """ - if self.launched(): - self.stop() - if self.executor is not None: - if not self.executor.destroyed(): - self.executor.destroy() - self.executor = None - self.state = Runtime.State.Init diff --git a/python/ark/serialize.py b/python/ark/serialize.py index 93473202e..584111825 100644 --- a/python/ark/serialize.py +++ b/python/ark/serialize.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import pickle -import logging +from . import log def save(state_dict, state_dict_file_path: str): @@ -10,9 +10,7 @@ def save(state_dict, state_dict_file_path: str): Save the state_dict of a module to a file """ if not isinstance(state_dict, dict): - logging.warn( - "Warning: Invalid state_dict saved to", state_dict_file_path - ) + log.WARN(f"Invalid state_dict saved to {state_dict_file_path}") with open(state_dict_file_path, "wb") as f: pickle.dump(state_dict, f) @@ -24,5 +22,5 @@ def load(state_dict_file_path: str): with open(state_dict_file_path, "rb") as f: state_dict = pickle.load(f) if not isinstance(state_dict, dict): - logging.warn("Warning: Invalid state_dict file") + log.WARN("Invalid state_dict file") return state_dict diff --git a/python/ark/tensor.py b/python/ark/tensor.py index 197d92921..216318b27 100644 --- a/python/ark/tensor.py +++ b/python/ark/tensor.py @@ -2,11 +2,14 @@ # Licensed under the MIT license. import numpy as np -from typing import List +from typing import Callable, Iterable, List, Union, Type +from . import log from .core import CoreDims, CoreTensor, NullTensor -from .data_type import DataType -from .runtime import Runtime +from .torch import torch, _no_torch +from .data_type import DataType, fp32 +from .executor import Executor +from .model import Model __all__ = ["Dims", "Tensor", "Parameter", "NullTensor"] @@ -15,14 +18,121 @@ class Dims(CoreDims): pass +Initializer = Type[Callable[[], Union[torch.Tensor, np.ndarray]]] + + class Tensor: - def __init__(self, _tensor: CoreTensor): + def __init__( + self, + _tensor: CoreTensor, + initializer: Initializer = None, + requires_grad: bool = False, + ): """ Initializes a new instance of the Tensor class. Args: - _tensor (core.CoreTensor): The underlying CoreTensor object. + _tensor (core.CoreTensor): The underlying _Tensor object. + initializer (Initializer): The initializer for the Tensor. + requires_grad (bool): Whether the tensor requires gradient. Defaults to True. 
+ """ + self._tensor: CoreTensor = _tensor + self.initializer: Initializer = initializer + self.requires_grad: bool = requires_grad + + def __hash__(self): + return self._tensor.id() + + def __eq__(self, other): + if not isinstance(other, Tensor): + return False + return self._tensor.id() == other._tensor.id() + + def __getitem__(self, index) -> "Tensor": + if not isinstance(index, tuple): + index = (index,) + new_shape = [] + new_strides = [] + new_offsets = [] + new_padded_shape = [] + if len(index) > len(self.shape()): + raise log.InvalidUsageError( + f"Index has more dimensions than the tensor. Index: " + f"{index}, tensor shape: {self.shape()}" + ) + for i, idx in enumerate(index): + shape_len = self.shape()[i] + padded_shape_len = self._padded_shape()[i] + pad_len = padded_shape_len - shape_len + if isinstance(idx, int): + new_shape.append(1) + new_strides.append(self.strides()[i]) + new_offsets.append(idx) + if idx == shape_len - 1: + new_padded_shape.append(1 + pad_len) + else: + new_padded_shape.append(1) + elif isinstance(idx, slice): + start = idx.start or 0 + stop = idx.stop or self.shape()[i] + step = idx.step or 1 + if step < 0: + start, stop = stop + 1, start + 1 + if step != 1 and step != -1: + # TODO: support step other than 1 or -1 + raise log.UnsupportedError( + f"Step must be 1 or -1. Given: {step}" + ) + new_shape.append(stop - start) + new_strides.append(self.strides()[i]) + new_offsets.append(start) + if stop == shape_len: + new_padded_shape.append(stop + pad_len - start) + else: + new_padded_shape.append(stop - start) + else: + raise log.InvalidUsageError( + f"Index must be an integer or a slice. Index: {idx}" + ) + new_shape = Dims(new_shape) + new_strides = Dims(new_strides) + new_offsets = Dims(new_offsets) + new_padded_shape = Dims(new_padded_shape) + new_tensor = Tensor( + Model.get_model().refer( + self._tensor, + new_shape, + new_strides, + new_offsets, + new_padded_shape, + "", + ) + ) + new_tensor.requires_grad = self.requires_grad + return new_tensor + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + new_args = [] + for arg in args: + if isinstance(arg, Tensor): + new_args.append(Tensor.to_torch(arg)) + else: + new_args.append(arg) + new_kwargs = {} + for key, value in kwargs.items(): + if isinstance(value, Tensor): + new_kwargs[key] = Tensor.to_torch(value) + else: + new_kwargs[key] = value + return func(*new_args, **new_kwargs) + + def _padded_shape(self) -> List[int]: """ - self._tensor = _tensor + Returns the padded shape of the tensor. + """ + return self._tensor.padded_shape().vector() def shape(self) -> List[int]: """ @@ -48,6 +158,34 @@ def dtype(self) -> DataType: """ return DataType.from_ctype(self._tensor.data_type()) + def data_ptr(self) -> int: + """ + Returns the underlying data pointer. + """ + return Executor.get().tensor_address(self._tensor) + + def is_external(self) -> bool: + """ + Returns true if the tensor's data is not managed by ARK. + """ + return self._tensor.is_external() + + def _raise_if_no_data(self): + if self.data_ptr() != 0: + return + if self.is_external(): + raise log.InvalidUsageError( + "Tried to access data of an external tensor that does not " + "have data set. This is likely because this tensor is a " + "placeholder and you have not set the data." + ) + raise log.InvalidUsageError( + "Tried to access data of a tensor that is not allocated yet. 
" + "This is likely due to either you have not called " + "`Runtime.launch()` for the model or the tensor is unused " + "in the model." + ) + def to_numpy( self, ndarray: np.ndarray = None, stream: int = 0 ) -> np.ndarray: @@ -56,39 +194,142 @@ def to_numpy( a new numpy array will be created. If the tensor is not allocated, an empty numpy array without the data buffer will be returned. """ + self._raise_if_no_data() np_type = self.dtype().to_numpy() - rt = Runtime.get_runtime() - if not rt.launched(): - return np.ndarray(self.shape(), dtype=np_type, buffer=None) + if np_type is None: + raise log.InvalidUsageError( + f"Tensor data type {self.dtype().__name__} is not supported by numpy." + ) if ndarray is None: ndarray = np.zeros(self.shape(), dtype=np_type) elif not ndarray.flags["C_CONTIGUOUS"]: - raise ValueError("ndarray is not contiguous in memory") + raise log.InvalidUsageError("ndarray is not contiguous in memory") elif ndarray.shape != self.shape(): - raise ValueError("ndarray shape does not match the tensor") + raise log.InvalidUsageError( + "ndarray shape does not match the tensor" + ) elif ndarray.dtype != np_type: - raise ValueError("ndarray dtype does not match the tensor") + raise log.InvalidUsageError( + "ndarray dtype does not match the tensor" + ) elif ndarray.nbytes != self.nelems() * self.dtype().element_size(): - raise ValueError("ndarray size does not match the tensor") - rt.executor.tensor_read(self._tensor, ndarray, stream) + raise log.InvalidUsageError( + "ndarray size does not match the tensor" + ) + Executor.get().tensor_read(self._tensor, ndarray, stream) return ndarray def from_numpy(self, ndarray: np.ndarray, stream: int = 0) -> "Tensor": """ Copies the tensor from a host numpy array to the device. """ - rt = Runtime.get_runtime() - if not rt.launched(): - raise RuntimeError( - "Tensor is not allocated yet. `Tensor.from_numpy()` is " - "usable only after you call `Runtime.launch()`." - ) + self._raise_if_no_data() ndarray = ndarray.astype(self.dtype().to_numpy()) if not ndarray.flags["C_CONTIGUOUS"]: ndarray = np.ascontiguousarray(ndarray) if ndarray.nbytes != self.nelems() * self.dtype().element_size(): - raise ValueError("ndarray size does not match the tensor") - rt.executor.tensor_write(self._tensor, ndarray, stream) + raise log.InvalidUsageError( + "ndarray size does not match the tensor" + ) + Executor.get().tensor_write(self._tensor, ndarray, stream) + return self + + def to_dlpack(self): + """ + Returns a DLPack tensor that shares the same memory with the device tensor. + """ + self._raise_if_no_data() + return Executor.get().tensor_to_dlpack(self._tensor) + + @staticmethod + def from_dlpack(ext_tensor) -> "Tensor": + """ + Copies the tensor from a DLPack tensor to the device. + """ + raise log.UnsupportedError("from_dlpack is not implemented yet") + + def to_torch(self) -> torch.Tensor: + """ + Returns a torch tensor that shares the same memory with the device tensor. + """ + if _no_torch: + raise log.SystemError("torch is not available") + dl_capsule = self.to_dlpack() + torch_view = torch.utils.dlpack.from_dlpack(dl_capsule) + # Keep dl_capsule alive not to free the memory + torch_view.__ark_buffer__ = dl_capsule + return torch_view + + @staticmethod + def from_torch(tensor: torch.Tensor) -> "Tensor": + """ + Returns an ARK tensor that shares the same memory with the torch tensor. 
+ """ + if _no_torch: + raise log.SystemError("torch is not available") + elif not tensor.is_contiguous(): + raise log.InvalidUsageError("Torch tensor must be contiguous.") + elif tensor.device.type == "cpu": + raise log.InvalidUsageError("Torch tensor must be on a device.") + # TODO: support strides and offsets + ark_tensor = Tensor( + _cpp_tensor( + shape=list(tensor.shape), + dtype=DataType.from_torch(tensor.dtype), + data=tensor.data_ptr(), + ) + ) + # Share ownership of the memory with the torch tensor + ark_tensor.__torch_buffer__ = tensor + return ark_tensor + + def copy( + self, data: Union[np.ndarray, torch.Tensor], stream: int = 0 + ) -> "Tensor": + """ + Copies data into this tensor. The data type may differ, + but the size must match. + """ + self._raise_if_no_data() + tensor_bytes = self.nelems() * self.dtype().element_size() + if isinstance(data, torch.Tensor): + if not data.is_contiguous(): + data = data.contiguous() + if data.numel() * data.element_size() != tensor_bytes: + raise log.InvalidUsageError( + "data size does not match the tensor" + ) + Executor.get().tensor_write( + self._tensor, + data.data_ptr(), + tensor_bytes, + stream, + data.device.type == "cuda", + ) + data.requires_grad = self.requires_grad + if isinstance(self, Parameter): + self.torch_param = data + elif isinstance(data, np.ndarray): + if not data.flags["C_CONTIGUOUS"]: + data = np.ascontiguousarray(data) + if data.nbytes != tensor_bytes: + raise log.InvalidUsageError( + "data size does not match the tensor" + ) + Executor.get().tensor_write(self._tensor, data, stream) + else: + raise log.InvalidUsageError( + "data must be a numpy array or a torch tensor" + ) + return self + + def initialize(self) -> "Tensor": + """ + Initializes the tensor. + """ + if self.initializer is not None: + data = self.initializer() + self.copy(data) return self @@ -97,8 +338,114 @@ class Parameter(Tensor): A tensor as a parameter. """ - def __init__(self, _tensor: CoreTensor): + def __init__( + self, + tensor: CoreTensor, + from_torch: bool = False, + ): """ Initializes a new instance of the Parameter class. + Args: + _tensor (_ark_core._Tensor): The underlying _Tensor object. + from_torch: Indicates if the Parameter is tied to a torch.nn.Paramter + """ + if not _no_torch and from_torch: + _tensor = tensor._tensor + self.torch_param = tensor + self.staged_tensor = None + Tensor.__init__( + self, + _tensor, + requires_grad=tensor.requires_grad, + ) + elif isinstance(tensor, CoreTensor): + _tensor = tensor + self.torch_param = None + self.staged_tensor = None + Tensor.__init__(self, _tensor, requires_grad=False) + else: + raise log.InvalidUsageError( + "tensor must be an ARK tensor or a torch.nn.Parameter" + ) + + def update_gradient(self, ark_tensor: Tensor): + """ + Stages an ARK tensor to be used for updating the gradient of its associated parameter. 
""" - super().__init__(_tensor) + if _no_torch: + raise log.SystemError("torch is not available") + if self.torch_param is None: + raise log.InvalidUsageError( + "there is no PyTorch parameter associated with this ARK parameter" + ) + if not self.torch_param.requires_grad: + raise log.InvalidUsageError( + "parameter does not require gradient updates" + ) + if ark_tensor is None or not isinstance(ark_tensor, Tensor): + raise log.InvalidUsageError( + "cannot use non-ARK tensor to update ARK gradient" + ) + self.staged_tensor = ark_tensor + + +def _is_list_or_tuple(obj): + return isinstance(obj, list) or isinstance(obj, tuple) + + +def _cpp_tensor( + shape: Iterable[int], + dtype: DataType = fp32, + strides: Iterable[int] = [], + offsets: Iterable[int] = [], + padded_shape: Iterable[int] = [], + rank: int = -1, + data: int = None, + name: str = "", +) -> Tensor: + if not _is_list_or_tuple(shape): + raise log.InvalidUsageError( + "shape should be a list or tuple of integers" + ) + if not _is_list_or_tuple(strides): + raise log.InvalidUsageError( + "strides should be a list or tuple of integers" + ) + if not _is_list_or_tuple(offsets): + raise log.InvalidUsageError( + "offsets should be a list or tuple of integers" + ) + if not _is_list_or_tuple(padded_shape): + raise log.InvalidUsageError( + "padded_shape should be a list or tuple of integers" + ) + # only support tensors with up to 4 dimensions + if ( + len(shape) > 4 + or len(strides) > 4 + or len(offsets) > 4 + or len(padded_shape) > 4 + ): + raise ValueError("Only support tensors with up to 4 dimensions") + if data is not None: + cpp_tensor = Model.get_model().placeholder( + Dims(shape), + dtype.ctype(), + Dims(strides), + Dims(offsets), + Dims(padded_shape), + rank, + data, + name, + ) + else: + cpp_tensor = Model.get_model().tensor( + Dims(shape), + dtype.ctype(), + Dims(strides), + Dims(offsets), + Dims(padded_shape), + rank, + name, + ) + return cpp_tensor diff --git a/python/ark/torch/__init__.py b/python/ark/torch/__init__.py new file mode 100644 index 000000000..c1b6db3a2 --- /dev/null +++ b/python/ark/torch/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +try: + import torch + + _no_torch = False +except ImportError: + from . import mock as torch + + _no_torch = True diff --git a/python/ark/torch/mock.py b/python/ark/torch/mock.py new file mode 100644 index 000000000..7a7de0ae6 --- /dev/null +++ b/python/ark/torch/mock.py @@ -0,0 +1,43 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + + +class dtype: ... + + +class float32: ... + + +class float16: ... + + +class bfloat16: ... + + +class int32: ... + + +class int8: ... + + +class uint8: ... + + +class ubyte: ... + + +class Tensor: ... + + +class nn: + + class Module: ... + + class Parameter: ... + + +class autograd: + + class Function: + + def apply(self, *args, **kwargs): ... diff --git a/python/executor_py.cpp b/python/executor_py.cpp index a2195f106..5833e733b 100644 --- a/python/executor_py.cpp +++ b/python/executor_py.cpp @@ -1,12 +1,17 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. 
+#include #include #include #include #include #include +#include + +#include "gpu/gpu_memory.hpp" +#include "logging.hpp" namespace py = pybind11; @@ -40,31 +45,181 @@ static void tensor_read(ark::Executor *exe, const ark::Tensor &tensor, reinterpret_cast(stream), is_d2d); } +static DLDataType to_dl_dtype(const ark::DataType &ark_dtype) { + DLDataType dl_dtype; + dl_dtype.lanes = 1; + if (ark_dtype == ark::FP32) { + dl_dtype.code = kDLFloat; + dl_dtype.bits = 32; + } else if (ark_dtype == ark::FP16) { + dl_dtype.code = kDLFloat; + dl_dtype.bits = 16; + } else if (ark_dtype == ark::BF16) { + dl_dtype.code = kDLBfloat; + dl_dtype.bits = 16; + } else if (ark_dtype == ark::INT32) { + dl_dtype.code = kDLInt; + dl_dtype.bits = 32; + } else if (ark_dtype == ark::UINT32) { + dl_dtype.code = kDLUInt; + dl_dtype.bits = 32; + } else if (ark_dtype == ark::INT8) { + dl_dtype.code = kDLInt; + dl_dtype.bits = 8; + } else if (ark_dtype == ark::UINT8) { + dl_dtype.code = kDLUInt; + dl_dtype.bits = 8; + } else if (ark_dtype == ark::BYTE) { + dl_dtype.code = kDLUInt; + dl_dtype.bits = 8; + } else { + ERR(ark::InternalError, "unexpected"); + } + return dl_dtype; +} + +static DLDeviceType get_device_type() { +#if defined(ARK_CUDA) + return kDLCUDA; +#elif defined(ARK_ROCM) + return kDLROCM; +#else + return kDLCPU; +#endif +} + +namespace ark { + +class SharedTensor { + public: + SharedTensor(Executor &exe, const Tensor &tensor); + ~SharedTensor() = default; + + DLTensor dl_tensor() const; + + private: + std::shared_ptr buffer_; + void *data_; + int device_id_; + DataType dtype_; + std::shared_ptr> shape_; + std::shared_ptr> strides_; + std::shared_ptr> offsets_; +}; + +SharedTensor::SharedTensor(Executor &exe, const Tensor &tensor) { + buffer_ = exe.buffer(); + data_ = reinterpret_cast(exe.tensor_address(tensor)); + device_id_ = exe.device_id(); + dtype_ = tensor.data_type(); + shape_ = std::make_shared>(tensor.shape().vector()); + strides_ = + std::make_shared>(tensor.torch_strides().vector()); + offsets_ = + std::make_shared>(tensor.offsets().vector()); +} + +DLTensor SharedTensor::dl_tensor() const { + DLTensor dl_tensor; + dl_tensor.data = data_; + size_t offset_in_elements = offsets_->empty() ? 
0 : offsets_->at(0); + dl_tensor.byte_offset = offset_in_elements * dtype_.bytes(); + dl_tensor.device.device_type = get_device_type(); + dl_tensor.device.device_id = device_id_; + dl_tensor.ndim = static_cast(shape_->size()); + dl_tensor.dtype = to_dl_dtype(dtype_); + dl_tensor.shape = shape_->data(); + dl_tensor.strides = strides_->data(); + return dl_tensor; +} + +} // namespace ark + +static py::capsule tensor_to_dlpack(ark::Executor &self, + const ark::Tensor &tensor) { + auto shared_tensor = new ark::SharedTensor(self, tensor); + DLManagedTensor *dl_managed_tensor = new DLManagedTensor(); + dl_managed_tensor->dl_tensor = shared_tensor->dl_tensor(); + dl_managed_tensor->manager_ctx = shared_tensor; + dl_managed_tensor->deleter = [](DLManagedTensor *self) { + if (self->manager_ctx) { + delete static_cast(self->manager_ctx); + self->manager_ctx = nullptr; + } + }; + const char *capsule_name = "dltensor"; + PyObject *dl_capsule = PyCapsule_New( + static_cast(dl_managed_tensor), capsule_name, + [](PyObject *capsule) { + const char *name = PyCapsule_GetName(capsule); + auto *dl_managed_tensor = static_cast( + PyCapsule_GetPointer(capsule, name)); + if (dl_managed_tensor) { + dl_managed_tensor->deleter(dl_managed_tensor); + dl_managed_tensor = nullptr; + } + }); + return py::reinterpret_steal(dl_capsule); +} + void register_executor(py::module &m) { py::class_(m, "CoreExecutor") - .def(py::init([](int device_id, uintptr_t stream, - const std::string &name, const std::string &plan, - bool loop_mode) { - return new ark::Executor(device_id, - reinterpret_cast(stream), - name, plan, loop_mode); - })) + .def(py::init<>()) .def("device_id", &ark::Executor::device_id) .def("stream", [](ark::Executor *self) { return reinterpret_cast(self->stream()); }) .def("plan", &ark::Executor::plan) - .def("compile", &ark::Executor::compile) - .def("launch", &ark::Executor::launch) - .def("run", &ark::Executor::run, py::arg("iter")) + .def("name", &ark::Executor::name) + .def("compile", &ark::Executor::compile, py::arg("device_id"), + py::arg("plan"), py::arg("name") = "executor") + .def( + "launch", + [](ark::Executor *self, + const std::unordered_map + &placeholder_data, + uintptr_t stream, bool loop_mode, bool record) { + std::unordered_map tensor_ptr_map; + for (const auto &[tensor, addr] : placeholder_data) { + tensor_ptr_map[tensor] = reinterpret_cast(addr); + } + + self->launch(tensor_ptr_map, + reinterpret_cast(stream), loop_mode, + record); + }, + py::arg("placeholder_data") = + std::unordered_map(), + py::arg("stream") = 0, py::arg("loop_mode") = true, + py::arg("record") = false) + + .def( + "run", + [](ark::Executor *self, int iter, + const std::unordered_map + &placeholder_data) { + std::unordered_map tensor_ptr_map; + for (const auto &[tensor, addr] : placeholder_data) { + tensor_ptr_map[tensor] = reinterpret_cast(addr); + } + self->run(iter, tensor_ptr_map); + }, + py::arg("iter"), + py::arg("placeholder_data") = + std::unordered_map()) .def("wait", &ark::Executor::wait, py::arg("max_spin_count") = -1) .def("stop", &ark::Executor::stop, py::arg("max_spin_count") = -1) .def("barrier", &ark::Executor::barrier) .def("destroy", &ark::Executor::destroy) .def("destroyed", &ark::Executor::destroyed) - .def("tensor_address", &ark::Executor::tensor_address, - py::arg("tensor")) + .def( + "tensor_address", + [](ark::Executor *self, const ark::Tensor &tensor) { + return reinterpret_cast( + self->tensor_address(tensor)); + }, + py::arg("tensor")) .def("tensor_read", py::overload_cast(&tensor_read), @@ -82,5 
+237,6 @@ void register_executor(py::module &m) { py::overload_cast(&tensor_write), py::arg("tensor"), py::arg("address"), py::arg("bytes"), - py::arg("stream"), py::arg("is_d2d")); + py::arg("stream"), py::arg("is_d2d")) + .def("tensor_to_dlpack", &tensor_to_dlpack); } diff --git a/python/model_py.cpp b/python/model_py.cpp index b9e7ec54f..6568f3a5c 100644 --- a/python/model_py.cpp +++ b/python/model_py.cpp @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. +#include #include #include #include @@ -8,6 +9,8 @@ #include #include +#include "logging.hpp" + namespace py = pybind11; void register_model(py::module &m) { @@ -71,6 +74,19 @@ void register_model(py::module &m) { py::arg("input"), py::arg("other"), py::arg("output"), py::arg("name")) .def("noop", &ark::Model::noop, py::arg("input"), py::arg("name")) + .def( + "placeholder", + [](ark::Model &model, const ark::Dims &shape, + const ark::DataType &data_type, const ark::Dims &strides, + const ark::Dims &offsets, const ark::Dims &padded_shape, + int rank, uintptr_t data, const std::string &name) { + return model.placeholder(shape, data_type, strides, offsets, + padded_shape, rank, + reinterpret_cast(data), name); + }, + py::arg("shape"), py::arg("data_type"), py::arg("strides"), + py::arg("offsets"), py::arg("padded_shape"), py::arg("rank"), + py::arg("data"), py::arg("name")) .def("reduce_max", &ark::Model::reduce_max, py::arg("input"), py::arg("axis"), py::arg("keepdims"), py::arg("output"), py::arg("name")) @@ -80,6 +96,9 @@ void register_model(py::module &m) { .def("reduce_sum", &ark::Model::reduce_sum, py::arg("input"), py::arg("axis"), py::arg("keepdims"), py::arg("output"), py::arg("name")) + .def("refer", &ark::Model::refer, py::arg("input"), py::arg("shape"), + py::arg("strides"), py::arg("offsets"), py::arg("padded_shape"), + py::arg("name")) .def("relu", &ark::Model::relu, py::arg("input"), py::arg("output"), py::arg("name")) .def("reshape", &ark::Model::reshape, py::arg("input"), @@ -104,14 +123,9 @@ void register_model(py::module &m) { const std::string &>(&ark::Model::sub), py::arg("input"), py::arg("other"), py::arg("output"), py::arg("name")) - .def("tensor", - py::overload_cast( - &ark::Model::tensor), - py::arg("shape"), py::arg("data_type"), py::arg("strides"), - py::arg("offsets"), py::arg("padded_shape"), py::arg("rank"), - py::arg("name")) + .def("tensor", &ark::Model::tensor, py::arg("shape"), + py::arg("data_type"), py::arg("strides"), py::arg("offsets"), + py::arg("padded_shape"), py::arg("rank"), py::arg("name")) .def("transpose", &ark::Model::transpose, py::arg("input"), py::arg("permutation"), py::arg("output"), py::arg("name")) .def("all_reduce", &ark::Model::all_reduce, py::arg("input"), diff --git a/python/planner_py.cpp b/python/planner_py.cpp index f0af0fa35..b43a8fdd8 100644 --- a/python/planner_py.cpp +++ b/python/planner_py.cpp @@ -13,6 +13,7 @@ namespace py = pybind11; void register_planner(py::module &m) { py::class_(m, "CorePlannerContext") .def(py::init()) + .def("id", &ark::PlannerContext::id) .def("processor_range", &ark::PlannerContext::processor_range, py::arg("start"), py::arg("end"), py::arg("step") = 1) .def("warp_range", &ark::PlannerContext::warp_range, py::arg("start"), @@ -20,7 +21,8 @@ void register_planner(py::module &m) { .def("sram_range", &ark::PlannerContext::sram_range, py::arg("start"), py::arg("end"), py::arg("step") = 1) .def("sync", &ark::PlannerContext::sync, py::arg("sync")) - .def("config", &ark::PlannerContext::config, 
py::arg("config")); + .def("config", &ark::PlannerContext::config, py::arg("config")) + .def("dump", &ark::PlannerContext::dump); py::class_(m, "CorePlanner") .def(py::init()) diff --git a/python/tensor_py.cpp b/python/tensor_py.cpp index e85352f53..c6fde978e 100644 --- a/python/tensor_py.cpp +++ b/python/tensor_py.cpp @@ -12,15 +12,23 @@ namespace py = pybind11; void register_tensor(py::module &m) { py::class_(m, "CoreTensor") .def("id", &ark::Tensor::id) - .def("shape", &ark::Tensor::shape, py::return_value_policy::reference) - .def("strides", &ark::Tensor::strides, - py::return_value_policy::reference) - .def("offsets", &ark::Tensor::offsets, - py::return_value_policy::reference) - .def("padded_shape", &ark::Tensor::padded_shape, - py::return_value_policy::reference) - .def("data_type", &ark::Tensor::data_type, - py::return_value_policy::reference); + .def("shape", &ark::Tensor::shape) + .def("strides", &ark::Tensor::strides) + .def("offsets", &ark::Tensor::offsets) + .def("padded_shape", &ark::Tensor::padded_shape) + .def("data_type", &ark::Tensor::data_type) + .def("torch_strides", &ark::Tensor::torch_strides) + .def("data", + [](const ark::Tensor& self) { + return reinterpret_cast(self.data()); + }) + .def( + "data", + [](ark::Tensor& self, uintptr_t data) { + return self.data(reinterpret_cast(data)); + }, + py::arg("data")) + .def("is_external", &ark::Tensor::is_external); m.attr("NullTensor") = &ark::NullTensor; } diff --git a/python/unittest/test.py b/python/unittest/test.py index 693adb2d1..01f57c759 100644 --- a/python/unittest/test.py +++ b/python/unittest/test.py @@ -6,3 +6,4 @@ from test_model import * from test_ops import * from test_runtime import * +from test_tensor import * diff --git a/python/unittest/test_planner.py b/python/unittest/test_planner.py new file mode 100644 index 000000000..0a739c714 --- /dev/null +++ b/python/unittest/test_planner.py @@ -0,0 +1,40 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from common import ark, pytest_ark + + +@pytest_ark() +def test_planner_processor_range(): + input_tensor = ark.tensor([64, 64], ark.fp16) + other_tensor = ark.tensor([64, 64], ark.fp16) + + with ark.PlannerContext(processor_range=[0, 128]): + with ark.PlannerContext(processor_range=[0, 8], sync=False): + ark.add(input_tensor, other_tensor) + with ark.PlannerContext(processor_range=[8, 16], sync=False): + ark.add(input_tensor, other_tensor) + + plan = ark.Planner().plan() + + pg = plan.processor_groups + assert len(pg) == 1 + assert pg[0]["ResourceGroups"][0]["ProcessorRange"] == [0, 8] + assert pg[0]["ResourceGroups"][1]["ProcessorRange"] == [8, 16] + + +@pytest_ark() +def test_planner_sync(): + input_tensor = ark.tensor([64, 64], ark.fp16) + other_tensor = ark.tensor([64, 64], ark.fp16) + + with ark.PlannerContext(sync=False): + with ark.PlannerContext(): + ark.add(input_tensor, other_tensor) + with ark.PlannerContext(): + ark.add(input_tensor, other_tensor) + + plan = ark.Planner().plan() + + pg = plan.processor_groups + assert len(pg) == 1 diff --git a/python/unittest/test_runtime.py b/python/unittest/test_runtime.py index 269253e13..969f6140e 100644 --- a/python/unittest/test_runtime.py +++ b/python/unittest/test_runtime.py @@ -2,11 +2,71 @@ # Licensed under the MIT license. 
from common import ark, pytest_ark +import numpy as np @pytest_ark() def test_runtime_empty(): - with ark.Runtime.get_runtime() as rt: + with ark.Runtime() as rt: rt.launch() rt.run() rt.stop() + + +@pytest_ark() +def test_runtime_init(): + M, N = 64, 64 + input_tensor = ark.tensor([M, N], ark.fp16) + other_tensor = ark.tensor([M, N], ark.fp16) + output_tensor = ark.add(input_tensor, other_tensor) + runtime = ark.Runtime() + runtime.launch() + input_tensor_host = np.random.rand(M, N).astype(np.float16) + input_tensor.from_numpy(input_tensor_host) + other_tensor_host = np.random.rand(M, N).astype(np.float16) + other_tensor.from_numpy(other_tensor_host) + runtime.run() + output_tensor_host = output_tensor.to_numpy() + np.testing.assert_allclose( + output_tensor_host, input_tensor_host + other_tensor_host + ) + runtime.stop() + ark.Model.reset() + prev_output = output_tensor + new_tensor = ark.tensor([M, N], ark.fp16) + final_output = ark.add(prev_output, new_tensor) + runtime.launch() + new_tensor_host = np.random.rand(M, N).astype(np.float16) + new_tensor.from_numpy(new_tensor_host) + runtime.run() + final_output_host = final_output.to_numpy() + np.testing.assert_allclose( + final_output_host, output_tensor_host + new_tensor_host + ) + + +@pytest_ark() +def test_runtime_reuse_plans(): + M, N = 64, 64 + input_tensor = ark.tensor([M, N], ark.fp16) + other_tensor = ark.tensor([M, N], ark.fp16) + output_tensor = ark.add(input_tensor, other_tensor) + runtime = ark.Runtime() + runtime.launch() + input_tensor_host = np.random.rand(M, N).astype(np.float16) + input_tensor.from_numpy(input_tensor_host) + other_tensor_host = np.random.rand(M, N).astype(np.float16) + other_tensor.from_numpy(other_tensor_host) + runtime.run() + output_tensor_host = output_tensor.to_numpy() + np.testing.assert_allclose( + output_tensor_host, input_tensor_host + other_tensor_host + ) + runtime.stop() + ark.Model.reset() + runtime.launch() + runtime.run() + output_tensor_host = output_tensor.to_numpy() + np.testing.assert_allclose( + output_tensor_host, input_tensor_host + other_tensor_host + ) diff --git a/python/unittest/test_tensor.py b/python/unittest/test_tensor.py new file mode 100644 index 000000000..c8be143f0 --- /dev/null +++ b/python/unittest/test_tensor.py @@ -0,0 +1,42 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +from common import ark, pytest_ark +import numpy as np + + +@pytest_ark() +def test_tensor_slice(): + t0 = ark.ones([4, 64], ark.fp16) + t1 = t0[2:, :] + ark.noop(t1) + + assert t1.shape() == [2, 64] + assert t1.dtype() == ark.fp16 + assert t1.strides() == [4, 64] + + with ark.Runtime() as rt: + rt.launch() + rt.run() + + x = t1.to_numpy() + + assert np.allclose(x, np.ones([2, 64], np.float16)) + + +@pytest_ark(need_torch=True) +def test_tensor_torch(): + import torch + + ones = torch.ones(2, 1024, device=torch.device("cuda:0")) + + t = ark.Tensor.from_torch(ones) + t = ark.mul(t, 5) + + with ark.Runtime() as rt: + rt.launch() + rt.run() + + x = t.to_torch() + + assert torch.allclose(x, ones * 5) diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt index 96e442289..49251be74 100644 --- a/third_party/CMakeLists.txt +++ b/third_party/CMakeLists.txt @@ -40,6 +40,19 @@ if (NOT json_POPULATED) endif() set(JSON_INCLUDE_DIRS ${json_SOURCE_DIR}/include PARENT_SCOPE) +# DLPack +FetchContent_Declare( + dlpack + GIT_REPOSITORY https://github.com/dmlc/dlpack + GIT_TAG v0.8 + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/dlpack +) +FetchContent_GetProperties(dlpack) +if (NOT dlpack_POPULATED) + FetchContent_Populate(dlpack) +endif() +set(DLPACK_INCLUDE_DIRS ${dlpack_SOURCE_DIR}/include PARENT_SCOPE) + if(ARK_USE_CUDA) # Configure CUTLASS FetchContent_Declare( diff --git a/third_party/dlpack b/third_party/dlpack new file mode 160000 index 000000000..365b823ce --- /dev/null +++ b/third_party/dlpack @@ -0,0 +1 @@ +Subproject commit 365b823cedb281cd0240ca601aba9b78771f91a3
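As an end-to-end check of the DLPack path wired up in this diff (C++ capsule consumed by torch), a sketch assuming a CUDA build with PyTorch installed:

```python
import torch
import ark

t = ark.ones([4, 64], ark.fp16)
ark.noop(t)  # reference the tensor so it gets allocated

with ark.Runtime() as rt:
    rt.launch()
    rt.run()
    capsule = t.to_dlpack()  # "dltensor" capsule produced by tensor_to_dlpack
    view = torch.utils.dlpack.from_dlpack(capsule)  # zero-copy torch view
    assert view.shape == (4, 64)
    assert view.dtype == torch.float16
```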