Skip to content

Commit

Permalink
redispipeline publish at flush
Browse files Browse the repository at this point in the history
  • Loading branch information
a114j0y committed Sep 23, 2024
1 parent 24979b0 commit 65cba5f
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 33 deletions.
39 changes: 23 additions & 16 deletions common/producerstatetable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,39 +13,37 @@ using namespace std;

namespace swss {

ProducerStateTable::ProducerStateTable(DBConnector *db, const string &tableName)
: ProducerStateTable(new RedisPipeline(db, 1), tableName, false)
ProducerStateTable::ProducerStateTable(DBConnector *db, const string &tableName, bool flushPub)
: ProducerStateTable(new RedisPipeline(db, 1), tableName, false, flushPub)
{
m_pipeowned = true;
}

ProducerStateTable::ProducerStateTable(RedisPipeline *pipeline, const string &tableName, bool buffered)
ProducerStateTable::ProducerStateTable(RedisPipeline *pipeline, const string &tableName, bool buffered, bool flushPub)
: TableBase(tableName, SonicDBConfig::getSeparator(pipeline->getDBConnector()))
, TableName_KeySet(tableName)
, m_flushPub(flushPub)
, m_buffered(buffered)
, m_pipeowned(false)
, m_tempViewActive(false)
, m_pipe(pipeline)
{
if (m_flushPub) {
m_pipe->addChannel(getChannelName(m_pipe->getDbId()));
}
// num in luaSet and luaDel means number of elements that were added to the key set,
// not including all the elements already present into the set.
string luaSet =
"local added = redis.call('SADD', KEYS[2], ARGV[2])\n"
"for i = 0, #KEYS - 3 do\n"
" redis.call('HSET', KEYS[3 + i], ARGV[3 + i * 2], ARGV[4 + i * 2])\n"
"end\n"
" if added > 0 then \n"
" redis.call('PUBLISH', KEYS[1], ARGV[1])\n"
"end\n";
m_shaSet = m_pipe->loadRedisScript(luaSet);

string luaDel =
"local added = redis.call('SADD', KEYS[2], ARGV[2])\n"
"redis.call('SADD', KEYS[4], ARGV[2])\n"
"redis.call('DEL', KEYS[3])\n"
"if added > 0 then \n"
" redis.call('PUBLISH', KEYS[1], ARGV[1])\n"
"end\n";
"redis.call('DEL', KEYS[3])\n";
m_shaDel = m_pipe->loadRedisScript(luaDel);

string luaBatchedSet =
Expand All @@ -59,9 +57,6 @@ ProducerStateTable::ProducerStateTable(RedisPipeline *pipeline, const string &ta
" redis.call('HSET', KEYS[3] .. KEYS[4 + i], attr, val)\n"
" end\n"
" idx = idx + tonumber(ARGV[idx]) * 2 + 1\n"
"end\n"
"if added > 0 then \n"
" redis.call('PUBLISH', KEYS[1], ARGV[1])\n"
"end\n";
m_shaBatchedSet = m_pipe->loadRedisScript(luaBatchedSet);

Expand All @@ -71,9 +66,6 @@ ProducerStateTable::ProducerStateTable(RedisPipeline *pipeline, const string &ta
" added = added + redis.call('SADD', KEYS[2], KEYS[5 + i])\n"
" redis.call('SADD', KEYS[3], KEYS[5 + i])\n"
" redis.call('DEL', KEYS[4] .. KEYS[5 + i])\n"
"end\n"
"if added > 0 then \n"
" redis.call('PUBLISH', KEYS[1], ARGV[1])\n"
"end\n";
m_shaBatchedDel = m_pipe->loadRedisScript(luaBatchedDel);

Expand All @@ -88,6 +80,21 @@ ProducerStateTable::ProducerStateTable(RedisPipeline *pipeline, const string &ta

string luaApplyView = loadLuaScript("producer_state_table_apply_view.lua");
m_shaApplyView = m_pipe->loadRedisScript(luaApplyView);

if (!m_flushPub) {
string luaPub =
"if added > 0 then \n"
" redis.call('PUBLISH', KEYS[1], ARGV[1])\n"
"end\n";
luaSet += luaPub;
luaDel += luaPub;
luaBatchedSet += luaPub;
luaBatchedDel += luaPub;
m_shaSet = m_pipe->loadRedisScript(luaSet);
m_shaDel = m_pipe->loadRedisScript(luaDel);
m_shaBatchedSet = m_pipe->loadRedisScript(luaBatchedSet);
m_shaBatchedDel = m_pipe->loadRedisScript(luaBatchedDel);
}
}

ProducerStateTable::~ProducerStateTable()
Expand Down
5 changes: 3 additions & 2 deletions common/producerstatetable.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ namespace swss {
class ProducerStateTable : public TableBase, public TableName_KeySet
{
public:
ProducerStateTable(DBConnector *db, const std::string &tableName);
ProducerStateTable(RedisPipeline *pipeline, const std::string &tableName, bool buffered = false);
ProducerStateTable(DBConnector *db, const std::string &tableName, bool flushPub = false);
ProducerStateTable(RedisPipeline *pipeline, const std::string &tableName, bool buffered = false, bool flushPub = false);
virtual ~ProducerStateTable();

void setBuffered(bool buffered);
Expand Down Expand Up @@ -51,6 +51,7 @@ class ProducerStateTable : public TableBase, public TableName_KeySet

void apply_temp_view();
private:
bool m_flushPub;
bool m_buffered;
bool m_pipeowned;
bool m_tempViewActive;
Expand Down
44 changes: 44 additions & 0 deletions common/redispipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

#include <string>
#include <queue>
#include <unordered_set>
#include <functional>
#include <chrono>
#include <iostream>
#include "redisreply.h"
#include "rediscommand.h"
#include "dbconnector.h"
Expand All @@ -22,9 +25,11 @@ class RedisPipeline {
RedisPipeline(const DBConnector *db, size_t sz = 128)
: COMMAND_MAX(sz)
, m_remaining(0)
, m_shaPub("")
{
m_db = db->newConnector(NEWCONNECTOR_TIMEOUT);
initializeOwnerTid();
lastHeartBeat = std::chrono::steady_clock::now();
}

~RedisPipeline() {
Expand Down Expand Up @@ -113,11 +118,19 @@ class RedisPipeline {

void flush()
{
lastHeartBeat = std::chrono::steady_clock::now();

if (m_remaining == 0) {
return;
}

while(m_remaining)
{
// Construct an object to use its dtor, so that resource is released
RedisReply r(pop());
}

publish();
}

size_t size()
Expand Down Expand Up @@ -145,12 +158,43 @@ class RedisPipeline {
m_ownerTid = gettid();
}

void addChannel(std::string channel)
{
if (m_channels.find(channel) != m_channels.end())
return;

m_channels.insert(channel);
m_luaPub += "redis.call('PUBLISH', '" + channel + "', 'G');";
m_shaPub = loadRedisScript(m_luaPub);
}

int getIdleTime(std::chrono::time_point<std::chrono::steady_clock> tcurrent=std::chrono::steady_clock::now())
{
return static_cast<int>(std::chrono::duration_cast<std::chrono::milliseconds>(tcurrent - lastHeartBeat).count());
}

void publish() {
if (m_shaPub == "") {
return;
}
RedisCommand cmd;
cmd.format(
"EVALSHA %s 0",
m_shaPub.c_str());
RedisReply r(m_db, cmd);
}

private:
DBConnector *m_db;
std::queue<int> m_expectedTypes;
size_t m_remaining;
long int m_ownerTid;

std::string m_luaPub;
std::string m_shaPub;
std::chrono::time_point<std::chrono::steady_clock> lastHeartBeat; // marks the timestamp of latest pipeline flush being invoked
std::unordered_set<std::string> m_channels;

void mayflush()
{
if (m_remaining >= COMMAND_MAX)
Expand Down
40 changes: 25 additions & 15 deletions tests/redis_piped_state_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,12 @@ static inline void validateFields(const string& key, const vector<FieldValueTupl
}
}

static void producerWorker(int index)
static void producerWorker(int index, bool flushPub)
{
string tableName = "UT_REDIS_THREAD_" + to_string(index);
DBConnector db(TEST_DB, 0, true);
RedisPipeline pipeline(&db);
ProducerStateTable p(&pipeline, tableName, true);
ProducerStateTable p(&pipeline, tableName, true, flushPub);

for (int i = 0; i < NUMBER_OF_OPS; i++)
{
Expand Down Expand Up @@ -117,19 +117,23 @@ static void consumerWorker(int index)
cs.addSelectable(&c);
while ((ret = cs.select(&selectcs)) == Select::OBJECT)
{
c.pop(kco);
if (kfvOp(kco) == "SET")
{
numberOfKeysSet++;
validateFields(kfvKey(kco), kfvFieldsValues(kco));
} else if (kfvOp(kco) == "DEL")
std::deque<KeyOpFieldsValuesTuple> entries;
c.pops(entries);

for (auto& kco: entries)
{
numberOfKeyDeleted++;
if (kfvOp(kco) == "SET")
{
numberOfKeysSet++;
validateFields(kfvKey(kco), kfvFieldsValues(kco));
} else if (kfvOp(kco) == "DEL")
{
numberOfKeyDeleted++;
}

if ((i++ % 100) == 0)
cout << "-" << flush;
}

if ((i++ % 100) == 0)
cout << "-" << flush;

if (numberOfKeyDeleted == NUMBER_OF_OPS)
break;
}
Expand Down Expand Up @@ -654,7 +658,10 @@ TEST(ConsumerStateTable, async_test)
for (int i = 0; i < NUMBER_OF_THREADS; i++)
{
consumerThreads[i] = new thread(consumerWorker, i);
producerThreads[i] = new thread(producerWorker, i);
if (i < NUMBER_OF_THREADS/2)
producerThreads[i] = new thread(producerWorker, i, false);
else
producerThreads[i] = new thread(producerWorker, i, true);
}

cout << "Done. Waiting for all job to finish " << NUMBER_OF_OPS << " jobs." << endl;
Expand Down Expand Up @@ -689,7 +696,10 @@ TEST(ConsumerStateTable, async_multitable)
{
consumers[i] = new ConsumerStateTable(&db, string("UT_REDIS_THREAD_") +
to_string(i));
producerThreads[i] = new thread(producerWorker, i);
if (i < NUMBER_OF_THREADS/2)
producerThreads[i] = new thread(producerWorker, i, false);
else
producerThreads[i] = new thread(producerWorker, i, true);
}

for (i = 0; i < NUMBER_OF_THREADS; i++)
Expand Down

0 comments on commit 65cba5f

Please sign in to comment.