#pragma once
#include "base/abstract_partition_manager.hpp"
#include "base/magic.hpp"
#include "base/message.hpp"
#include "base/third_party/sarray.h"
#include "base/threadsafe_queue.hpp"
#include "worker/abstract_callback_runner.hpp"
#include <cinttypes>
#include <functional>
#include <utility>
#include <vector>
namespace csci5570 {
/**
 * Provides the worker-side abstraction of the model and exposes its API to users.
 * Each model in an application is handled by exactly one KVClientTable.
 *
 * @tparam Val type of the model parameter values
 */
template <typename Val>
class KVClientTable {
  // Key/value pairs in the layout expected by AbstractPartitionManager::Slice.
  using KVPairs = std::pair<third_party::SArray<Key>, third_party::SArray<double>>;
public:
/**
* @param app_thread_id user thread id
* @param model_id model id
* @param sender_queue the work queue of a sender communication thread
* @param partition_manager model partition manager
* @param callback_runner callback runner to handle received replies from servers
*/
  KVClientTable(uint32_t app_thread_id, uint32_t model_id, ThreadsafeQueue<Message>* const sender_queue,
                const AbstractPartitionManager* const partition_manager, AbstractCallbackRunner* const callback_runner)
      : app_thread_id_(app_thread_id),
        model_id_(model_id),
        sender_queue_(sender_queue),
        callback_runner_(callback_runner),
        partition_manager_(partition_manager) {}
// ========== API ========== //
// void Clock(const std::vector<Key>& keys) {
// third_party::SArray<Key> ktmp;
// int i = 0;
// while (i < keys.size()) {
// ktmp.push_back(keys[i]);
// i++;
// }
// std::vector<std::pair<int, third_party::SArray<Key>>> sliced;
// partition_manager_->Slice(ktmp, &sliced);
// uint32_t count = 0;
// while (count < sliced.size()) {
// Message m;
// m.meta.sender = app_thread_id_;
// m.meta.recver = sliced[count].first;
// m.meta.flag = Flag::kClock;
// m.meta.model_id = model_id_;
// sender_queue_->Push(m);
// count++;
// }
// }
  void Clock() {
    // Send a clock signal to every server thread so the servers can advance the model clock.
    std::vector<uint32_t> server_thread_ids = partition_manager_->GetServerThreadIds();
    for (const uint32_t server_id : server_thread_ids) {
      Message m;
      m.meta.sender = app_thread_id_;
      m.meta.recver = server_id;
      m.meta.flag = Flag::kClock;
      m.meta.model_id = model_id_;
      sender_queue_->Push(m);
    }
  }
// vector version
  void Add(const std::vector<Key>& keys, const std::vector<Val>& vals) {
    // Copy keys and values into SArrays; values are stored as double to match KVPairs.
    third_party::SArray<Key> ktmp;
    third_party::SArray<double> vtmp;
    for (const Key& key : keys) ktmp.push_back(key);
    for (const Val& val : vals) vtmp.push_back(val);
    // Partition the key/value pairs by server.
    std::vector<std::pair<int, KVPairs>> sliced;
    partition_manager_->Slice(std::make_pair(ktmp, vtmp), &sliced);
    // Send one kAdd message per server partition.
    for (const auto& slice : sliced) {
      Message m;
      m.meta.sender = app_thread_id_;
      m.meta.recver = slice.first;
      m.meta.flag = Flag::kAdd;
      m.meta.model_id = model_id_;
      // Serialize the keys of this partition into the message payload.
      third_party::SArray<char> key_char;
      key_char = slice.second.first;
      // Convert the double values back to Val before serialization.
      third_party::SArray<Val> val_tmp;
      for (size_t j = 0; j < slice.second.second.size(); ++j) {
        val_tmp.push_back(static_cast<Val>(slice.second.second[j]));
      }
      third_party::SArray<char> val_char;
      val_char = val_tmp;
      m.AddData(key_char);
      m.AddData(val_char);
      sender_queue_->Push(m);
    }
  }
  void Get(const std::vector<Key>& keys, std::vector<Val>* vals) {
    // Copy the keys into an SArray and partition them by server.
    third_party::SArray<Key> ktmp;
    for (const Key& key : keys) ktmp.push_back(key);
    std::vector<std::pair<int, third_party::SArray<Key>>> sliced;
    partition_manager_->Slice(ktmp, &sliced);
    // Collect the values returned by each server; the reply carries the values in data[1].
    std::vector<Val> tmp;
    std::function<void(Message&)> recv_handle = [&](Message& m) {
      third_party::SArray<Val> vtp;
      vtp = m.data[1];
      for (size_t i = 0; i < vtp.size(); ++i) {
        tmp.push_back(vtp[i]);
      }
    };
    // Once all replies have arrived, hand the collected values back to the caller.
    std::function<void()> recv_finish_handle = [&]() { vals->assign(tmp.begin(), tmp.end()); };
    callback_runner_->RegisterRecvHandle(app_thread_id_, model_id_, recv_handle);
    callback_runner_->RegisterRecvFinishHandle(app_thread_id_, model_id_, recv_finish_handle);
    callback_runner_->NewRequest(app_thread_id_, model_id_, sliced.size());
    // Send one kGet message per server partition.
    for (const auto& slice : sliced) {
      Message m;
      m.meta.sender = app_thread_id_;
      m.meta.recver = slice.first;
      m.meta.model_id = model_id_;
      m.meta.flag = Flag::kGet;
      third_party::SArray<char> key_char;
      key_char = slice.second;
      m.AddData(key_char);
      sender_queue_->Push(m);
    }
    // Block until the expected number of replies has been received.
    callback_runner_->WaitRequest(app_thread_id_, model_id_);
  }
  // sarray version (placeholders, currently unimplemented)
void Add(const third_party::SArray<Key>& keys, const third_party::SArray<Val>& vals) {}
void Get(const third_party::SArray<Key>& keys, third_party::SArray<Val>* vals) {}
// ========== API ========== //
private:
uint32_t app_thread_id_; // identifies the user thread
uint32_t model_id_; // identifies the model on servers
ThreadsafeQueue<Message>* const sender_queue_; // not owned
AbstractCallbackRunner* const callback_runner_; // not owned
const AbstractPartitionManager* const partition_manager_; // not owned
}; // class KVClientTable
} // namespace csci5570
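// ---------------------------------------------------------------------------
// Example usage (illustrative sketch only).
// This shows how an application/worker thread might drive a KVClientTable.
// The concrete helper types used to construct the table (SimplePartitionManager,
// SimpleCallbackRunner) are assumed names for illustration and are not defined
// in this header; any AbstractPartitionManager / AbstractCallbackRunner
// implementation provided by the engine can be substituted.
//
//   using namespace csci5570;
//
//   ThreadsafeQueue<Message> sender_queue;     // drained by a sender thread
//   SimplePartitionManager partition_manager;  // assumed implementation
//   SimpleCallbackRunner callback_runner;      // assumed implementation
//
//   KVClientTable<double> table(/*app_thread_id=*/1, /*model_id=*/0,
//                               &sender_queue, &partition_manager, &callback_runner);
//
//   std::vector<Key> keys{0, 1, 2};
//   std::vector<double> grads{0.1, -0.2, 0.3};
//   table.Add(keys, grads);     // push parameter updates to the servers
//
//   std::vector<double> params;
//   table.Get(keys, &params);   // blocks until replies from all partitions arrive
//
//   table.Clock();              // signal the end of an iteration to all servers
// ---------------------------------------------------------------------------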