diff --git a/qa/L0_shared_memory/shared_memory_test.py b/qa/L0_shared_memory/shared_memory_test.py index 0790f1d558..321c80f058 100755 --- a/qa/L0_shared_memory/shared_memory_test.py +++ b/qa/L0_shared_memory/shared_memory_test.py @@ -56,7 +56,7 @@ class SharedMemoryTest(tu.TestResultCollector): def _setUp(self, protocol, log_file_path): self._tritonserver_ipaddr = os.environ.get("TRITONSERVER_IPADDR", "localhost") self._test_windows = bool(int(os.environ.get("TEST_WINDOWS", 0))) - self._shm_key_prefix = "/" if not self._test_windows else "Global\\" + self._shm_key_prefix = "/" if not self._test_windows else "" self._timeout = os.environ.get("SERVER_TIMEOUT", 120) self._protocol = protocol self._test_passed = False @@ -95,9 +95,11 @@ def _setup_client(self): def _build_server_args(self): if self._test_windows: - backend_dir = "C:\\opt\\tritonserver\\backends" - model_dir = "C:\\opt\\tritonserver\\qa\\L0_shared_memory\\models" - self._server_executable = "C:\\opt\\tritonserver\\bin\\tritonserver.exe" + backend_dir = os.environ.get("BACKEND_DIR", "C:\\opt\\tritonserver\\backends") + model_dir = os.environ.get("MODELDIR", (os.getcwd() + "\\models")) + self._server_executable = os.environ.get( + "SERVER", "C:\\tritonserver\\bin\\tritonserver.exe" + ) else: triton_dir = os.environ.get("TRITON_DIR", "/opt/tritonserver") backend_dir = os.environ.get("BACKEND_DIR", f"{triton_dir}/backends") @@ -178,21 +180,44 @@ def _configure_server( input1_data = np.ones(shape=16, dtype=np.int32) shm.set_shared_memory_region(shm_ip0_handle, [input0_data]) shm.set_shared_memory_region(shm_ip1_handle, [input1_data]) - self.triton_client.register_system_shared_memory( - "input0_data", (self._shm_key_prefix + "input0_data"), register_byte_size, offset=register_offset - ) - self.triton_client.register_system_shared_memory( - "input1_data", (self._shm_key_prefix + "input1_data"), register_byte_size, offset=register_offset - ) - self.triton_client.register_system_shared_memory( - 
"output0_data", (self._shm_key_prefix + "output0_data"), register_byte_size, offset=register_offset - ) - self.triton_client.register_system_shared_memory( - "output1_data", (self._shm_key_prefix + "output1_data"), register_byte_size, offset=register_offset - ) + try: + self._triton_client.register_system_shared_memory( + "input0_data", + (self._shm_key_prefix + "input0_data"), + register_byte_size, + offset=register_offset, + ) + self._triton_client.register_system_shared_memory( + "input1_data", + (self._shm_key_prefix + "input1_data"), + register_byte_size, + offset=register_offset, + ) + self._triton_client.register_system_shared_memory( + "output0_data", + (self._shm_key_prefix + "output0_data"), + register_byte_size, + offset=register_offset, + ) + self._triton_client.register_system_shared_memory( + "output1_data", + (self._shm_key_prefix + "output1_data"), + register_byte_size, + offset=register_offset, + ) + except utils.InferenceServerException as e: + shm_handles = [ + shm_ip0_handle, + shm_ip1_handle, + shm_op0_handle, + shm_op1_handle, + ] + self._cleanup_server(shm_handles) + raise (e) return [shm_ip0_handle, shm_ip1_handle, shm_op0_handle, shm_op1_handle] def _cleanup_server(self, shm_handles): + self._triton_client.unregister_system_shared_memory() for shm_handle in shm_handles: shm.destroy_shared_memory_region(shm_handle) @@ -326,6 +351,7 @@ def test_valid_create_set_register(self, protocol): self.assertEqual(len(shm_status), 1) else: self.assertEqual(len(shm_status.regions), 1) + self._triton_client.unregister_system_shared_memory() shm.destroy_shared_memory_region(shm_op0_handle) self._test_passed = True @@ -355,6 +381,7 @@ def test_different_name_same_key(self, protocol): ) except Exception as ex: self.assertIn("registering an active shared memory key", str(ex)) + self._triton_client.unregister_system_shared_memory() shm.destroy_shared_memory_region(shm_op0_handle) self._test_passed = True @@ -379,6 +406,7 @@ def 
test_unregister_before_register(self, protocol): self.assertEqual(len(shm_status), 0) else: self.assertEqual(len(shm_status.regions), 0) + self._triton_client.unregister_system_shared_memory() shm.destroy_shared_memory_region(shm_op0_handle) self._test_passed = True @@ -405,6 +433,7 @@ def test_unregister_after_register(self, protocol): self.assertEqual(len(shm_status), 0) else: self.assertEqual(len(shm_status.regions), 0) + self._triton_client.unregister_system_shared_memory() shm.destroy_shared_memory_region(shm_op0_handle) self._test_passed = True @@ -438,6 +467,7 @@ def test_reregister_after_register(self, protocol): self.assertEqual(len(shm_status), 1) else: self.assertEqual(len(shm_status.regions), 1) + self._triton_client.unregister_system_shared_memory() shm.destroy_shared_memory_region(shm_op0_handle) self._test_passed = True @@ -664,6 +694,7 @@ def test_register_out_of_bound(self, protocol): register_byte_size=0, register_offset=create_byte_size + 1, ) + self._test_passed = True if __name__ == "__main__": diff --git a/src/shared_memory_manager.cc b/src/shared_memory_manager.cc index 503f449471..906a119adf 100644 --- a/src/shared_memory_manager.cc +++ b/src/shared_memory_manager.cc @@ -36,6 +36,8 @@ #include #include #include +#else +#define TRITON_SHM_FILE_ROOT "C:\\triton_shm\\" #endif namespace triton { namespace server { @@ -122,7 +124,7 @@ OpenCudaIPCRegion( TRITONSERVER_Error* SharedMemoryManager::OpenSharedMemoryRegion( - const std::string& shm_key, ShmFile** shm_file) + const std::string& shm_key, std::shared_ptr& shm_file) { #ifdef _WIN32 HANDLE shm_handle = OpenFileMapping( @@ -138,8 +140,7 @@ SharedMemoryManager::OpenSharedMemoryRegion( std::string("Unable to open shared memory region: '" + shm_key + "'") .c_str()); } - // Dynamic memory will eventually be owned by uniqe_ptr - *shm_file = new ShmFile(shm_handle); + shm_file = std::make_shared(shm_handle); #else // get shared memory region descriptor int shm_fd = shm_open(shm_key.c_str(), O_RDWR, 
S_IRUSR | S_IWUSR); @@ -150,8 +151,7 @@ SharedMemoryManager::OpenSharedMemoryRegion( std::string("Unable to open shared memory region: '" + shm_key + "'") .c_str()); } - // Dynamic memory will eventually be owned by uniqe_ptr - *shm_file = new ShmFile(shm_fd); + shm_file = std::make_shared<ShmFile>(shm_fd); #endif return nullptr; } @@ -161,14 +161,42 @@ SharedMemoryManager::GetSharedMemoryRegionSize( const std::string& shm_key, ShmFile* shm_file, uint64_t* shm_region_size) { #ifdef WIN32 - BY_HANDLE_FILE_INFORMATION info; - if(!GetFileInformationByHandle(shm_file->shm_handle_, &info)) { - LOG_VERBOSE(1) << "GetFileInformationByHandle failed with error code: " << GetWindowsError(); + // Open file for reading; the path string must outlive the CreateFile call + const std::string backing_file_path = + std::string(TRITON_SHM_FILE_ROOT) + shm_key; + HANDLE backing_file_handle = CreateFile( + backing_file_path.c_str(), 0, FILE_SHARE_READ, NULL, OPEN_EXISTING, + FILE_ATTRIBUTE_NORMAL, NULL); + if (backing_file_handle == INVALID_HANDLE_VALUE) { + LOG_VERBOSE(1) << "Failed to open backing file with error code: " + << GetWindowsError(); + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + std::string("Invalid shared memory region: '" + shm_key + "'").c_str()); + } + // Construct its size + uint64_t file_size; + DWORD high_order_size; + DWORD low_order_size = GetFileSize(backing_file_handle, &high_order_size); + if (low_order_size == INVALID_FILE_SIZE) { + CloseHandle(backing_file_handle); + LOG_VERBOSE(1) << "GetFileSize failed with error code: " + << GetWindowsError(); + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + std::string("Invalid shared memory region: '" + shm_key + "'").c_str()); + } else if (high_order_size != 0) { + file_size = ((uint64_t)high_order_size << 32) | low_order_size; + } else { + file_size = low_order_size; + } + if (!CloseHandle(backing_file_handle)) { + LOG_VERBOSE(1) << "failed to close backing file with error: " + << GetWindowsError(); return TRITONSERVER_ErrorNew( 
TRITONSERVER_ERROR_INTERNAL, std::string("Invalid shared memory region: '" + shm_key + "'").c_str()); } - uint64_t file_size = ((uint64_t)info.nFileSizeHigh << 32) | info.nFileSizeLow; *shm_region_size = file_size; #else struct stat file_status; @@ -201,6 +229,7 @@ SharedMemoryManager::CheckSharedMemoryRegionSize( RETURN_IF_ERR(GetSharedMemoryRegionSize(shm_key, shm_file, &shm_region_size)); // User-provided offset and byte_size should not go out-of-bounds. if ((offset + byte_size) > shm_region_size) { + CloseSharedMemoryRegion(shm_file); return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INVALID_ARG, std::string( @@ -263,7 +292,7 @@ SharedMemoryManager::MapSharedMemory( byte_size); if (*mapped_addr == NULL) { - CloseSharedMemoryRegion(shm_handle); + CloseSharedMemoryRegion(shm_file); return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INTERNAL, std::string( @@ -289,9 +318,7 @@ SharedMemoryManager::MapSharedMemory( SharedMemoryManager::~SharedMemoryManager() { UnregisterAll(TRITONSERVER_MEMORY_CPU); -#ifndef _WIN32 UnregisterAll(TRITONSERVER_MEMORY_GPU); -#endif } TRITONSERVER_Error* @@ -310,14 +337,14 @@ SharedMemoryManager::RegisterSystemSharedMemory( // register void* mapped_addr; - ShmFile* shm_file = nullptr; + std::shared_ptr shm_file; bool shm_file_exists = false; // don't re-open if shared memory is already open for (auto itr = shared_memory_map_.begin(); itr != shared_memory_map_.end(); ++itr) { if (itr->second->shm_key_ == shm_key) { - shm_file = itr->second->platform_handle_.get(); + shm_file = itr->second->platform_handle_; shm_file_exists = true; break; } @@ -325,7 +352,7 @@ SharedMemoryManager::RegisterSystemSharedMemory( // open and set new shm_file if new shared memory key if (!shm_file_exists) { - RETURN_IF_ERR(OpenSharedMemoryRegion(shm_key, &shm_file)); + RETURN_IF_ERR(OpenSharedMemoryRegion(shm_key, shm_file)); } else { // FIXME: DLIS-6448 - We should allow users the flexibility to register // the same key under different names with different 
attributes. @@ -338,13 +365,13 @@ SharedMemoryManager::RegisterSystemSharedMemory( } // Enforce that registered region is in-bounds of shm file object. - RETURN_IF_ERR( - CheckSharedMemoryRegionSize(name, shm_key, shm_file, offset, byte_size)); + RETURN_IF_ERR(CheckSharedMemoryRegionSize( + name, shm_key, shm_file.get(), offset, byte_size)); // Mmap and then close the shared memory descriptor TRITONSERVER_Error* err_map = - MapSharedMemory(shm_file, offset, byte_size, &mapped_addr); - TRITONSERVER_Error* err_close = CloseSharedMemoryRegion(shm_file); + MapSharedMemory(shm_file.get(), offset, byte_size, &mapped_addr); + TRITONSERVER_Error* err_close = CloseSharedMemoryRegion(shm_file.get()); if (err_map != nullptr) { return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INVALID_ARG, diff --git a/src/shared_memory_manager.h b/src/shared_memory_manager.h index 666a0e2908..3f2865788a 100644 --- a/src/shared_memory_manager.h +++ b/src/shared_memory_manager.h @@ -136,10 +136,12 @@ class SharedMemoryManager { struct ShmFile { #ifdef _WIN32 HANDLE shm_handle_; - ShmFile(HANDLE shm_handle) : shm_file_(shm_handle){}; + ShmFile(HANDLE shm_handle) : shm_handle_(shm_handle){}; + ~ShmFile() { CloseHandle(shm_handle_); }; #else int shm_fd_; ShmFile(int fd) : shm_fd_(fd){}; + ~ShmFile() { close(shm_fd_); }; #endif // _WIN32 }; @@ -150,7 +152,7 @@ class SharedMemoryManager { /// opened shared memory object. /// \return a TRITONSERVER_Error indicating success or failure. TRITONSERVER_Error* OpenSharedMemoryRegion( - const std::string& shm_key, ShmFile** shm_file); + const std::string& shm_key, std::shared_ptr<ShmFile>& shm_file); /// Get the size of the shared memory region. /// \param shm_key The name of the shared memory object @@ -160,9 +162,8 @@ class SharedMemoryManager { /// \param shm_region_size A pointer to store the size of the /// shared memory region. /// \return a TRITONSERVER_Error indicating success or failure. 
- TRITONSERVER_Error* - GetSharedMemoryRegionSize( - const std::string& shm_key, int shm_fd, size_t* shm_region_size) + TRITONSERVER_Error* GetSharedMemoryRegionSize( + const std::string& shm_key, ShmFile* shm_file, size_t* shm_region_size); /// Validate that offset + byte_size does not exceed the size of /// the registered shared memory region. @@ -175,10 +176,9 @@ class SharedMemoryManager { /// start of the block. /// \param byte_size The size, in bytes of the block. /// \return a TRITONSERVER_Error indicating success or failure. - TRITONSERVER_Error* - CheckSharedMemoryRegionSize( - const std::string& name, const std::string& shm_key, ShmFile* shm_file, - size_t offset, size_t byte_size) + TRITONSERVER_Error* CheckSharedMemoryRegionSize( + const std::string& name, const std::string& shm_key, ShmFile* shm_file, + size_t offset, size_t byte_size); /// Close the shared memory object. /// \param shm_file The file handle/descriptor of the the @@ -210,35 +210,33 @@ class SharedMemoryManager { struct SharedMemoryInfo { SharedMemoryInfo( const std::string& name, const std::string& shm_key, - const size_t offset, const size_t byte_size, ShmFile* shm_file, - void* mapped_addr, const TRITONSERVER_MemoryType kind, - const int64_t device_id) + const size_t offset, const size_t byte_size, + std::shared_ptr shm_file, void* mapped_addr, + const TRITONSERVER_MemoryType kind, const int64_t device_id) : name_(name), shm_key_(shm_key), offset_(offset), - byte_size_(byte_size), mapped_addr_(mapped_addr), kind_(kind), - device_id_(device_id) + platform_handle_(shm_file), byte_size_(byte_size), + mapped_addr_(mapped_addr), kind_(kind), device_id_(device_id) { - if (shm_file != nullptr) { - platform_handle_.reset(shm_file); - } } std::string name_; std::string shm_key_; size_t offset_; + std::shared_ptr platform_handle_; size_t byte_size_; void* mapped_addr_; TRITONSERVER_MemoryType kind_; int64_t device_id_; - std::unique_ptr platform_handle_; }; #ifdef TRITON_ENABLE_GPU struct 
CUDASharedMemoryInfo : SharedMemoryInfo { CUDASharedMemoryInfo( const std::string& name, const std::string& shm_key, - const size_t offset, const size_t byte_size, ShmFile* shm_file, - void* mapped_addr, const TRITONSERVER_MemoryType kind, - const int64_t device_id, const cudaIpcMemHandle_t* cuda_ipc_handle) + const size_t offset, const size_t byte_size, + std::shared_ptr shm_file, void* mapped_addr, + const TRITONSERVER_MemoryType kind, const int64_t device_id, + const cudaIpcMemHandle_t* cuda_ipc_handle) : SharedMemoryInfo( name, shm_key, offset, byte_size, shm_file, mapped_addr, kind, device_id),