Skip to content

Commit

Permalink
Open backing file to validate shared memory
Browse files Browse the repository at this point in the history
  • Loading branch information
fpetrini15 committed Apr 13, 2024
1 parent 7405f83 commit 126f7dc
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 56 deletions.
63 changes: 47 additions & 16 deletions qa/L0_shared_memory/shared_memory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class SharedMemoryTest(tu.TestResultCollector):
def _setUp(self, protocol, log_file_path):
self._tritonserver_ipaddr = os.environ.get("TRITONSERVER_IPADDR", "localhost")
self._test_windows = bool(int(os.environ.get("TEST_WINDOWS", 0)))
self._shm_key_prefix = "/" if not self._test_windows else "Global\\"
self._shm_key_prefix = "/" if not self._test_windows else ""
self._timeout = os.environ.get("SERVER_TIMEOUT", 120)
self._protocol = protocol
self._test_passed = False
Expand Down Expand Up @@ -95,9 +95,11 @@ def _setup_client(self):

def _build_server_args(self):
if self._test_windows:
backend_dir = "C:\\opt\\tritonserver\\backends"
model_dir = "C:\\opt\\tritonserver\\qa\\L0_shared_memory\\models"
self._server_executable = "C:\\opt\\tritonserver\\bin\\tritonserver.exe"
backend_dir = os.environ.get("BACKEND_DIR", "C:\\opt\\tritonserver\\backends")
model_dir = os.environ.get("MODELDIR", (os.getcwd() + "\\models"))
self._server_executable = os.environ.get(
"SERVER", "C:\\tritonserver\\bin\\tritonserver.exe"
)
else:
triton_dir = os.environ.get("TRITON_DIR", "/opt/tritonserver")
backend_dir = os.environ.get("BACKEND_DIR", f"{triton_dir}/backends")
Expand Down Expand Up @@ -178,21 +180,44 @@ def _configure_server(
input1_data = np.ones(shape=16, dtype=np.int32)
shm.set_shared_memory_region(shm_ip0_handle, [input0_data])
shm.set_shared_memory_region(shm_ip1_handle, [input1_data])
self.triton_client.register_system_shared_memory(
"input0_data", (self._shm_key_prefix + "input0_data"), register_byte_size, offset=register_offset
)
self.triton_client.register_system_shared_memory(
"input1_data", (self._shm_key_prefix + "input1_data"), register_byte_size, offset=register_offset
)
self.triton_client.register_system_shared_memory(
"output0_data", (self._shm_key_prefix + "output0_data"), register_byte_size, offset=register_offset
)
self.triton_client.register_system_shared_memory(
"output1_data", (self._shm_key_prefix + "output1_data"), register_byte_size, offset=register_offset
)
try:
self._triton_client.register_system_shared_memory(
"input0_data",
(self._shm_key_prefix + "input0_data"),
register_byte_size,
offset=register_offset,
)
self._triton_client.register_system_shared_memory(
"input1_data",
(self._shm_key_prefix + "input1_data"),
register_byte_size,
offset=register_offset,
)
self._triton_client.register_system_shared_memory(
"output0_data",
(self._shm_key_prefix + "output0_data"),
register_byte_size,
offset=register_offset,
)
self._triton_client.register_system_shared_memory(
"output1_data",
(self._shm_key_prefix + "output1_data"),
register_byte_size,
offset=register_offset,
)
except utils.InferenceServerException as e:
shm_handles = [
shm_ip0_handle,
shm_ip1_handle,
shm_op0_handle,
shm_op1_handle,
]
self._cleanup_server(shm_handles)
raise (e)
return [shm_ip0_handle, shm_ip1_handle, shm_op0_handle, shm_op1_handle]

# Test teardown helper: asks the server (through the Triton client) to
# unregister system shared memory, then destroys each local region handle
# supplied in `shm_handles`.
def _cleanup_server(self, shm_handles):
# Called without a region name; presumably this unregisters every system
# shared-memory region on the server -- confirm against the client docs.
self._triton_client.unregister_system_shared_memory()
# Release the locally created regions one by one.
for shm_handle in shm_handles:
shm.destroy_shared_memory_region(shm_handle)

Expand Down Expand Up @@ -326,6 +351,7 @@ def test_valid_create_set_register(self, protocol):
self.assertEqual(len(shm_status), 1)
else:
self.assertEqual(len(shm_status.regions), 1)
self._triton_client.unregister_system_shared_memory()
shm.destroy_shared_memory_region(shm_op0_handle)
self._test_passed = True

Expand Down Expand Up @@ -355,6 +381,7 @@ def test_different_name_same_key(self, protocol):
)
except Exception as ex:
self.assertIn("registering an active shared memory key", str(ex))
self._triton_client.unregister_system_shared_memory()
shm.destroy_shared_memory_region(shm_op0_handle)
self._test_passed = True

Expand All @@ -379,6 +406,7 @@ def test_unregister_before_register(self, protocol):
self.assertEqual(len(shm_status), 0)
else:
self.assertEqual(len(shm_status.regions), 0)
self._triton_client.unregister_system_shared_memory()
shm.destroy_shared_memory_region(shm_op0_handle)
self._test_passed = True

Expand All @@ -405,6 +433,7 @@ def test_unregister_after_register(self, protocol):
self.assertEqual(len(shm_status), 0)
else:
self.assertEqual(len(shm_status.regions), 0)
self._triton_client.unregister_system_shared_memory()
shm.destroy_shared_memory_region(shm_op0_handle)
self._test_passed = True

Expand Down Expand Up @@ -438,6 +467,7 @@ def test_reregister_after_register(self, protocol):
self.assertEqual(len(shm_status), 1)
else:
self.assertEqual(len(shm_status.regions), 1)
self._triton_client.unregister_system_shared_memory()
shm.destroy_shared_memory_region(shm_op0_handle)
self._test_passed = True

Expand Down Expand Up @@ -664,6 +694,7 @@ def test_register_out_of_bound(self, protocol):
register_byte_size=0,
register_offset=create_byte_size + 1,
)
self._test_passed = True


if __name__ == "__main__":
Expand Down
65 changes: 46 additions & 19 deletions src/shared_memory_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
#else
#define TRITON_SHM_FILE_ROOT "C:\\triton_shm\\"
#endif

namespace triton { namespace server {
Expand Down Expand Up @@ -122,7 +124,7 @@ OpenCudaIPCRegion(

TRITONSERVER_Error*
SharedMemoryManager::OpenSharedMemoryRegion(
const std::string& shm_key, ShmFile** shm_file)
const std::string& shm_key, std::shared_ptr<ShmFile>& shm_file)
{
#ifdef _WIN32
HANDLE shm_handle = OpenFileMapping(
Expand All @@ -138,8 +140,7 @@ SharedMemoryManager::OpenSharedMemoryRegion(
std::string("Unable to open shared memory region: '" + shm_key + "'")
.c_str());
}
// Dynamic memory will eventually be owned by uniqe_ptr
*shm_file = new ShmFile(shm_handle);
shm_file = std::make_shared<ShmFile>(shm_handle);
#else
// get shared memory region descriptor
int shm_fd = shm_open(shm_key.c_str(), O_RDWR, S_IRUSR | S_IWUSR);
Expand All @@ -150,8 +151,7 @@ SharedMemoryManager::OpenSharedMemoryRegion(
std::string("Unable to open shared memory region: '" + shm_key + "'")
.c_str());
}
// Dynamic memory will eventually be owned by uniqe_ptr
*shm_file = new ShmFile(shm_fd);
shm_file = std::make_shared<ShmFile>(shm_fd);
#endif
return nullptr;
}
Expand All @@ -161,14 +161,42 @@ SharedMemoryManager::GetSharedMemoryRegionSize(
const std::string& shm_key, ShmFile* shm_file, uint64_t* shm_region_size)
{
#ifdef WIN32
BY_HANDLE_FILE_INFORMATION info;
if(!GetFileInformationByHandle(shm_file->shm_handle_, &info)) {
LOG_VERBOSE(1) << "GetFileInformationByHandle failed with error code: " << GetWindowsError();
// Open file for reading
LPCSTR backing_file_path =
std::string(TRITON_SHM_FILE_ROOT + shm_key).c_str();
HANDLE backing_file_handle = CreateFile(
backing_file_path, 0, FILE_SHARE_READ, NULL, OPEN_EXISTING,
FILE_ATTRIBUTE_NORMAL, NULL);
if (backing_file_handle == INVALID_HANDLE_VALUE) {
LOG_VERBOSE(1) << "Failed to open backing file with error code: "
<< GetWindowsError();
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INTERNAL,
std::string("Invalid shared memory region: '" + shm_key + "'").c_str());
}
// Construct its size
uint64_t file_size;
DWORD high_order_size;
DWORD low_order_size = GetFileSize(backing_file_handle, &high_order_size);
if (low_order_size == INVALID_FILE_SIZE) {
CloseHandle(backing_file_handle);
LOG_VERBOSE(1) << "GetFileSize failed with error code: "
<< GetWindowsError();
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INTERNAL,
std::string("Invalid shared memory region: '" + shm_key + "'").c_str());
} else if (high_order_size != NULL) {
file_size = ((uint64_t)high_order_size << 32) | low_order_size;
} else {
file_size = low_order_size;
}
if (!CloseHandle(backing_file_handle)) {
LOG_VERBOSE(1) << "failed to close backing file with error: "
<< GetWindowsError();
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INTERNAL,
std::string("Invalid shared memory region: '" + shm_key + "'").c_str());
}
uint64_t file_size = ((uint64_t)info.nFileSizeHigh << 32) | info.nFileSizeLow;
*shm_region_size = file_size;
#else
struct stat file_status;
Expand Down Expand Up @@ -201,6 +229,7 @@ SharedMemoryManager::CheckSharedMemoryRegionSize(
RETURN_IF_ERR(GetSharedMemoryRegionSize(shm_key, shm_file, &shm_region_size));
// User-provided offset and byte_size should not go out-of-bounds.
if ((offset + byte_size) > shm_region_size) {
CloseSharedMemoryRegion(shm_file);
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
std::string(
Expand Down Expand Up @@ -263,7 +292,7 @@ SharedMemoryManager::MapSharedMemory(
byte_size);

if (*mapped_addr == NULL) {
CloseSharedMemoryRegion(shm_handle);
CloseSharedMemoryRegion(shm_file);
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INTERNAL,
std::string(
Expand All @@ -289,9 +318,7 @@ SharedMemoryManager::MapSharedMemory(
// Destructor: calls UnregisterAll() once per memory type.
// NOTE(review): whatever UnregisterAll() returns is discarded here; if it
// returns an error object that the caller is expected to free, it is dropped
// -- confirm.
SharedMemoryManager::~SharedMemoryManager()
{
// CPU (system) regions are handled unconditionally.
UnregisterAll(TRITONSERVER_MEMORY_CPU);
// The GPU call is compiled only when _WIN32 is not defined.
#ifndef _WIN32
UnregisterAll(TRITONSERVER_MEMORY_GPU);
#endif
}

TRITONSERVER_Error*
Expand All @@ -310,22 +337,22 @@ SharedMemoryManager::RegisterSystemSharedMemory(

// register
void* mapped_addr;
ShmFile* shm_file = nullptr;
std::shared_ptr<ShmFile> shm_file;
bool shm_file_exists = false;

// don't re-open if shared memory is already open
for (auto itr = shared_memory_map_.begin(); itr != shared_memory_map_.end();
++itr) {
if (itr->second->shm_key_ == shm_key) {
shm_file = itr->second->platform_handle_.get();
shm_file = itr->second->platform_handle_;
shm_file_exists = true;
break;
}
}

// open and set new shm_file if new shared memory key
if (!shm_file_exists) {
RETURN_IF_ERR(OpenSharedMemoryRegion(shm_key, &shm_file));
RETURN_IF_ERR(OpenSharedMemoryRegion(shm_key, shm_file));
} else {
// FIXME: DLIS-6448 - We should allow users the flexibility to register
// the same key under different names with different attributes.
Expand All @@ -338,13 +365,13 @@ SharedMemoryManager::RegisterSystemSharedMemory(
}

// Enforce that registered region is in-bounds of shm file object.
RETURN_IF_ERR(
CheckSharedMemoryRegionSize(name, shm_key, shm_file, offset, byte_size));
RETURN_IF_ERR(CheckSharedMemoryRegionSize(
name, shm_key, shm_file.get(), offset, byte_size));

// Mmap and then close the shared memory descriptor
TRITONSERVER_Error* err_map =
MapSharedMemory(shm_file, offset, byte_size, &mapped_addr);
TRITONSERVER_Error* err_close = CloseSharedMemoryRegion(shm_file);
MapSharedMemory(shm_file.get(), offset, byte_size, &mapped_addr);
TRITONSERVER_Error* err_close = CloseSharedMemoryRegion(shm_file.get());
if (err_map != nullptr) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
Expand Down
40 changes: 19 additions & 21 deletions src/shared_memory_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,12 @@ class SharedMemoryManager {
/// Wraps the platform handle of an opened shared memory object and releases
/// it when the wrapper is destroyed.
///
/// NOTE(review): RegisterSystemSharedMemory() also hands this object to
/// CloseSharedMemoryRegion() right after mapping. If that call already
/// releases the handle, the destructor below releases it a second time --
/// confirm and, if so, make one of the two paths the single owner.
struct ShmFile {
#ifdef _WIN32
  /// File-mapping handle passed to the constructor.
  HANDLE shm_handle_;
  ShmFile(HANDLE shm_handle) : shm_handle_(shm_handle) {}
  ~ShmFile() { CloseHandle(shm_handle_); }
#else
  /// File descriptor passed to the constructor.
  int shm_fd_;
  ShmFile(int fd) : shm_fd_(fd) {}
  // Close the stored member: the constructor parameter `fd` is not in scope
  // inside the destructor.
  ~ShmFile() { close(shm_fd_); }
#endif  // _WIN32

  // The destructor releases the handle, so a copy would release the same
  // handle twice. Share instances through a smart pointer instead.
  ShmFile(const ShmFile&) = delete;
  ShmFile& operator=(const ShmFile&) = delete;
};

Expand All @@ -150,7 +152,7 @@ class SharedMemoryManager {
/// opened shared memory object.
/// \return a TRITONSERVER_Error indicating success or failure.
TRITONSERVER_Error* OpenSharedMemoryRegion(
const std::string& shm_key, ShmFile** shm_file);
const std::string& shm_key, std::shared_ptr<ShmFile>& shm_file);

/// Get the size of the shared memory region.
/// \param shm_key The name of the shared memory object
Expand All @@ -160,9 +162,8 @@ class SharedMemoryManager {
/// \param shm_region_size A pointer to store the size of the
/// shared memory region.
/// \return a TRITONSERVER_Error indicating success or failure.
TRITONSERVER_Error*
GetSharedMemoryRegionSize(
const std::string& shm_key, int shm_fd, size_t* shm_region_size)
TRITONSERVER_Error* GetSharedMemoryRegionSize(
const std::string& shm_key, ShmFile* shm_file, size_t* shm_region_size);

/// Validate that offset + byte_size does not exceed the size of
/// the registered shared memory region.
Expand All @@ -175,10 +176,9 @@ class SharedMemoryManager {
/// start of the block.
/// \param byte_size The size, in bytes of the block.
/// \return a TRITONSERVER_Error indicating success or failure.
TRITONSERVER_Error*
CheckSharedMemoryRegionSize(
const std::string& name, const std::string& shm_key, ShmFile* shm_file,
size_t offset, size_t byte_size)
TRITONSERVER_Error* CheckSharedMemoryRegionSize(
const std::string& name, const std::string& shm_key, ShmFile* shm_file,
size_t offset, size_t byte_size);

/// Close the shared memory object.
/// \param shm_file The file handle/descriptor of the
Expand Down Expand Up @@ -210,35 +210,33 @@ class SharedMemoryManager {
struct SharedMemoryInfo {
SharedMemoryInfo(
const std::string& name, const std::string& shm_key,
const size_t offset, const size_t byte_size, ShmFile* shm_file,
void* mapped_addr, const TRITONSERVER_MemoryType kind,
const int64_t device_id)
const size_t offset, const size_t byte_size,
std::shared_ptr<ShmFile> shm_file, void* mapped_addr,
const TRITONSERVER_MemoryType kind, const int64_t device_id)
: name_(name), shm_key_(shm_key), offset_(offset),
byte_size_(byte_size), mapped_addr_(mapped_addr), kind_(kind),
device_id_(device_id)
platform_handle_(shm_file), byte_size_(byte_size),
mapped_addr_(mapped_addr), kind_(kind), device_id_(device_id)
{
if (shm_file != nullptr) {
platform_handle_.reset(shm_file);
}
}

std::string name_;
std::string shm_key_;
size_t offset_;
std::shared_ptr<ShmFile> platform_handle_;
size_t byte_size_;
void* mapped_addr_;
TRITONSERVER_MemoryType kind_;
int64_t device_id_;
std::unique_ptr<ShmFile> platform_handle_;
};

#ifdef TRITON_ENABLE_GPU
struct CUDASharedMemoryInfo : SharedMemoryInfo {
CUDASharedMemoryInfo(
const std::string& name, const std::string& shm_key,
const size_t offset, const size_t byte_size, ShmFile* shm_file,
void* mapped_addr, const TRITONSERVER_MemoryType kind,
const int64_t device_id, const cudaIpcMemHandle_t* cuda_ipc_handle)
const size_t offset, const size_t byte_size,
std::shared_ptr<ShmFile> shm_file, void* mapped_addr,
const TRITONSERVER_MemoryType kind, const int64_t device_id,
const cudaIpcMemHandle_t* cuda_ipc_handle)
: SharedMemoryInfo(
name, shm_key, offset, byte_size, shm_file, mapped_addr, kind,
device_id),
Expand Down

0 comments on commit 126f7dc

Please sign in to comment.