Skip to content

Commit

Permalink
partially fix cython integration
Browse files Browse the repository at this point in the history
  • Loading branch information
kentslaney committed Dec 21, 2023
1 parent 550ef71 commit f099e68
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 26 deletions.
40 changes: 22 additions & 18 deletions libmc/_client.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import threading
import zlib
import marshal
import warnings

from contextlib import contextmanager

cdef extern from "Common.h" namespace "douban::mc":
ctypedef enum op_code_t:
Expand Down Expand Up @@ -388,10 +388,7 @@ cdef class PyClientSettings:

self.init()

def init():
pass

def _args():
def _args(self):
return (self.servers, self.do_split, self.comp_threshold, self.noreply,
self.prefix, self.hash_fn, self.failover, self.encoding)

Expand All @@ -405,7 +402,7 @@ cdef class PyClient(PyClientSettings):
cdef object _thread_ident
cdef object _created_stack

def init():
def init(self):
self.last_error = RET_OK
self._thread_ident = None
self._created_stack = traceback.extract_stack()
Expand Down Expand Up @@ -1132,10 +1129,10 @@ cdef class PyClient(PyClientSettings):
return errCodeToString(self.last_error)


class PyPoolClient(PyClient):
cdef class PyPoolClient(PyClient):
cdef IndexedClient* _indexed

def init():
def init(self):
self.last_error = RET_OK
self._thread_ident = None
self._created_stack = traceback.extract_stack()
Expand All @@ -1144,29 +1141,36 @@ class PyPoolClient(PyClient):
pass


class PyClientPool(PyClientSettings):
worker = PyPoolClient
cdef class PyClientPool(PyClientSettings):
cdef list clients
cdef ClientPool* _imp

def init():
cdef init(self):
self._imp = new ClientPool()
self._imp.config(CFG_HASH_FUNCTION, self.hash_fn)
self.clients = []

def setup(self, IndexedClientClient* imp):
worker = __class__.worker(*self._args())
cdef setup(self, IndexedClient* imp):
worker = PyPoolClient(*self._args())
worker._indexed = imp
worker._imp = imp.c
worker._imp = &imp.c
return worker

def acquire(self):
worker = self._imp._acquire()
if worker.index >= len(self.clients):
clients += [None] * (worker.index - len(self.clients))
clients.append(setup(worker))
self.clients += [None] * (worker.index - len(self.clients))
self.clients.append(self.setup(worker))
elif self.clients[worker.index] == None:
self.clients[i] = setup(worker);
return self.clients[i]
self.clients[worker.index] = self.setup(worker)
return self.clients[worker.index]

def release(self, PyPoolClient worker):
self._imp._release(worker._indexed)

@contextmanager
def client(self):
try:
yield self.acquire()
finally:
self.release()
11 changes: 5 additions & 6 deletions src/ClientPool.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//#include <execution>
//#include <thread>
#include <thread>
#include <atomic>
#include <array>
#include "ClientPool.h"
Expand Down Expand Up @@ -63,8 +63,8 @@ int ClientPool::updateServers(const char* const* hosts, const uint32_t* ports,

std::atomic<int> rv = 0;
std::lock_guard<std::mutex> updating(m_fifo_access);
//std::for_each(std::execution::par_unseq, irange(), irange(m_clients.size()),
std::for_each(irange(), irange(m_clients.size()),
//std::for_each(std::execution::par_unseq, irange(), irange(m_clients.size()),
[this, &rv](int i) {
std::lock_guard<std::mutex> updating_worker(*m_thread_workers[i]);
const int err = m_clients[i].c.updateServers(
Expand Down Expand Up @@ -92,8 +92,8 @@ int ClientPool::growPool(size_t by) {
size_t from = m_clients.size();
m_clients.resize(from + by);
std::atomic<int> rv = 0;
//std::for_each(std::execution::par_unseq, irange(from), irange(from + by),
std::for_each(irange(from), irange(from + by),
//std::for_each(std::execution::par_unseq, irange(from), irange(from + by),
[this, &rv](int i) {
const int err = setup(&m_clients[i].c);
m_clients[i].index = i;
Expand Down Expand Up @@ -125,9 +125,8 @@ IndexedClient* ClientPool::_acquire() {
const auto growing = shouldGrowUnsafe();
m_acquiring_growth.unlock_shared();
if (growing) {
//std::thread acquire_overflow(&ClientPool::autoGrow, this);
//acquire_overflow.detach();
autoGrow();
std::thread acquire_overflow(&ClientPool::autoGrow, this);
acquire_overflow.detach();
}

int idx = acquireWorker();
Expand Down
5 changes: 3 additions & 2 deletions tests/test_client_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ const unsigned int data_size = 10;
const unsigned int n_servers = 20;
const unsigned int start_port = 21211;
const char host[] = "127.0.0.1";
unsigned int n_threads = 8;

TEST(test_client_pool, simple_set_get) {
uint32_t ports[n_servers];
Expand All @@ -36,7 +37,7 @@ TEST(test_client_pool, simple_set_get) {
const char* keys = &key[0];
const char* values = &value[0];

for (unsigned int j = 0; j < n_ops; j++) {
for (unsigned int j = 0; j < n_ops * n_threads; j++) {
gen_random(key, data_size);
gen_random(value, data_size);
auto c = pool->acquire();
Expand All @@ -52,7 +53,6 @@ TEST(test_client_pool, simple_set_get) {
}

TEST(test_client_pool, threaded_set_get) {
unsigned int n_threads = 8;
uint32_t ports[n_servers];
const char* hosts[n_servers];
for (unsigned int i = 0; i < n_servers; i++) {
Expand All @@ -63,6 +63,7 @@ TEST(test_client_pool, threaded_set_get) {
std::thread* threads = new std::thread[n_threads];
ClientPool* pool = new ClientPool();
pool->config(CFG_HASH_FUNCTION, OPT_HASH_FNV1A_32);
//pool->config(CFG_INITIAL_CLIENTS, 4);
pool->init(hosts, ports, n_servers);

for (unsigned int i = 0; i < n_threads; i++) {
Expand Down
38 changes: 38 additions & 0 deletions tests/test_client_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# coding: utf-8
import unittest
from threading import Thread
from libmc import ClientPool


class ThreadedSingleServerCase(unittest.TestCase):
def setUp(self):
self.pool = ClientPool(["127.0.0.1:21211"])

def misc(self):
for i in range(5):
with self.pool.client() as mc:
tid = str(mc._get_current_thread_ident() + (i,))
f, t = 'foo ' + tid, 'tuiche ' + tid
mc.get_multi([f, t])
mc.delete(f)
mc.delete(t)
assert mc.get(f) is None
assert mc.get(t) is None

mc.set(f, 'biu')
mc.set(t, 'bb')
assert mc.get(f) == 'biu'
assert mc.get(t) == 'bb'
assert (mc.get_multi([f, t]) ==
{f: 'biu', t: 'bb'})
mc.set_multi({f: 1024, t: '8964'})
assert (mc.get_multi([f, t]) ==
{f: 1024, t: '8964'})

def test_misc(self):
ts = [Thread(target=self.misc) for i in range(8)]
for t in ts:
t.start()

for t in ts:
t.join()

0 comments on commit f099e68

Please sign in to comment.