diff --git a/libmc/_client.pyx b/libmc/_client.pyx index 87c61b6a..14c0533a 100644 --- a/libmc/_client.pyx +++ b/libmc/_client.pyx @@ -24,7 +24,7 @@ import threading import zlib import marshal import warnings - +from contextlib import contextmanager cdef extern from "Common.h" namespace "douban::mc": ctypedef enum op_code_t: @@ -388,10 +388,7 @@ cdef class PyClientSettings: self.init() - def init(): - pass - - def _args(): + def _args(self): return (self.servers, self.do_split, self.comp_threshold, self.noreply, self.prefix, self.hash_fn, self.failover, self.encoding) @@ -405,7 +402,7 @@ cdef class PyClient(PyClientSettings): cdef object _thread_ident cdef object _created_stack - def init(): + def init(self): self.last_error = RET_OK self._thread_ident = None self._created_stack = traceback.extract_stack() @@ -1132,10 +1129,10 @@ cdef class PyClient(PyClientSettings): return errCodeToString(self.last_error) -class PyPoolClient(PyClient): +cdef class PyPoolClient(PyClient): cdef IndexedClient* _indexed - def init(): + def init(self): self.last_error = RET_OK self._thread_ident = None self._created_stack = traceback.extract_stack() @@ -1144,29 +1141,36 @@ class PyPoolClient(PyClient): pass -class PyClientPool(PyClientSettings): - worker = PyPoolClient +cdef class PyClientPool(PyClientSettings): cdef list clients + cdef ClientPool* _imp - def init(): + cdef init(self): self._imp = new ClientPool() self._imp.config(CFG_HASH_FUNCTION, self.hash_fn) self.clients = [] - def setup(self, IndexedClientClient* imp): - worker = __class__.worker(*self._args()) + cdef setup(self, IndexedClient* imp): + worker = PyPoolClient(*self._args()) worker._indexed = imp - worker._imp = imp.c + worker._imp = &imp.c return worker def acquire(self): worker = self._imp._acquire() if worker.index >= len(self.clients): - clients += [None] * (worker.index - len(self.clients)) - clients.append(setup(worker)) + self.clients += [None] * (worker.index - len(self.clients)) + self.clients.append(self.setup(worker)) elif self.clients[worker.index] == None: - self.clients[i] = setup(worker); - return self.clients[i] + self.clients[worker.index] = self.setup(worker) + return self.clients[worker.index] def release(self, PyPoolClient worker): self._imp._release(worker._indexed) + + @contextmanager + def client(self): + try: + yield self.acquire() + finally: + self.release() diff --git a/src/ClientPool.cpp b/src/ClientPool.cpp index 02c982cb..753ba603 100644 --- a/src/ClientPool.cpp +++ b/src/ClientPool.cpp @@ -1,5 +1,5 @@ //#include -//#include +#include #include #include #include "ClientPool.h" @@ -63,8 +63,8 @@ int ClientPool::updateServers(const char* const* hosts, const uint32_t* ports, std::atomic rv = 0; std::lock_guard updating(m_fifo_access); - //std::for_each(std::execution::par_unseq, irange(), irange(m_clients.size()), std::for_each(irange(), irange(m_clients.size()), + //std::for_each(std::execution::par_unseq, irange(), irange(m_clients.size()), [this, &rv](int i) { std::lock_guard updating_worker(*m_thread_workers[i]); const int err = m_clients[i].c.updateServers( @@ -92,8 +92,8 @@ int ClientPool::growPool(size_t by) { size_t from = m_clients.size(); m_clients.resize(from + by); std::atomic rv = 0; - //std::for_each(std::execution::par_unseq, irange(from), irange(from + by), std::for_each(irange(from), irange(from + by), + //std::for_each(std::execution::par_unseq, irange(from), irange(from + by), [this, &rv](int i) { const int err = setup(&m_clients[i].c); m_clients[i].index = i; @@ -125,9 +125,8 @@ IndexedClient* ClientPool::_acquire() { const auto growing = shouldGrowUnsafe(); m_acquiring_growth.unlock_shared(); if (growing) { - //std::thread acquire_overflow(&ClientPool::autoGrow, this); - //acquire_overflow.detach(); - autoGrow(); + std::thread acquire_overflow(&ClientPool::autoGrow, this); + acquire_overflow.detach(); } int idx = acquireWorker(); diff --git a/tests/test_client_pool.cpp b/tests/test_client_pool.cpp index 611b1b52..6ae7ed73 100644 --- a/tests/test_client_pool.cpp +++ b/tests/test_client_pool.cpp @@ -12,6 +12,7 @@ const unsigned int data_size = 10; const unsigned int n_servers = 20; const unsigned int start_port = 21211; const char host[] = "127.0.0.1"; +unsigned int n_threads = 8; TEST(test_client_pool, simple_set_get) { uint32_t ports[n_servers]; @@ -36,7 +37,7 @@ TEST(test_client_pool, simple_set_get) { const char* keys = &key[0]; const char* values = &value[0]; - for (unsigned int j = 0; j < n_ops; j++) { + for (unsigned int j = 0; j < n_ops * n_threads; j++) { gen_random(key, data_size); gen_random(value, data_size); auto c = pool->acquire(); @@ -52,7 +53,6 @@ TEST(test_client_pool, simple_set_get) { } TEST(test_client_pool, threaded_set_get) { - unsigned int n_threads = 8; uint32_t ports[n_servers]; const char* hosts[n_servers]; for (unsigned int i = 0; i < n_servers; i++) { @@ -63,6 +63,7 @@ TEST(test_client_pool, threaded_set_get) { std::thread* threads = new std::thread[n_threads]; ClientPool* pool = new ClientPool(); pool->config(CFG_HASH_FUNCTION, OPT_HASH_FNV1A_32); + //pool->config(CFG_INITIAL_CLIENTS, 4); pool->init(hosts, ports, n_servers); for (unsigned int i = 0; i < n_threads; i++) { diff --git a/tests/test_client_pool.py b/tests/test_client_pool.py new file mode 100644 index 00000000..5964bec4 --- /dev/null +++ b/tests/test_client_pool.py @@ -0,0 +1,38 @@ +# coding: utf-8 +import unittest +from threading import Thread +from libmc import ClientPool + + +class ThreadedSingleServerCase(unittest.TestCase): + def setUp(self): + self.pool = ClientPool(["127.0.0.1:21211"]) + + def misc(self): + for i in range(5): + with self.pool.client() as mc: + tid = str(mc._get_current_thread_ident() + (i,)) + f, t = 'foo ' + tid, 'tuiche ' + tid + mc.get_multi([f, t]) + mc.delete(f) + mc.delete(t) + assert mc.get(f) is None + assert mc.get(t) is None + + mc.set(f, 'biu') + mc.set(t, 'bb') + assert mc.get(f) == 'biu' + assert mc.get(t) == 'bb' + assert (mc.get_multi([f, t]) == + {f: 'biu', t: 'bb'}) + mc.set_multi({f: 1024, t: '8964'}) + assert (mc.get_multi([f, t]) == + {f: 1024, t: '8964'}) + + def test_misc(self): + ts = [Thread(target=self.misc) for i in range(8)] + for t in ts: + t.start() + + for t in ts: + t.join()