Skip to content

Commit

Permalink
ref[provider]: refactor consistent hash load balancing
Browse files Browse the repository at this point in the history
  • Loading branch information
jaysunxiao committed Jan 20, 2024
1 parent 471f061 commit d5249ed
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import com.zfoo.net.util.FastTreeMapIntLong;
import com.zfoo.net.util.HashUtils;
import com.zfoo.protocol.ProtocolManager;
import com.zfoo.protocol.collection.CollectionUtils;
import com.zfoo.protocol.collection.HashSetLong;
import com.zfoo.protocol.exception.RunException;
import com.zfoo.protocol.model.Pair;
import com.zfoo.protocol.registration.ProtocolModule;
Expand All @@ -40,10 +40,19 @@ public class ConsistentHashLoadBalancer extends AbstractConsumerLoadBalancer {

public static final ConsistentHashLoadBalancer INSTANCE = new ConsistentHashLoadBalancer();

private volatile int lastClientSessionChangeId = 0;
private static final AtomicReferenceArray<FastTreeMapIntLong> consistentHashMap = new AtomicReferenceArray<>(ProtocolManager.MAX_MODULE_NUM);
private static final AtomicReferenceArray<ConsistentCache> consistentHashMap = new AtomicReferenceArray<>(ProtocolManager.MAX_MODULE_NUM);
private static final int VIRTUAL_NODE_NUMS = 200;

public static class ConsistentCache {
public HashSetLong providerSids;
public FastTreeMapIntLong treeMap;

public ConsistentCache(HashSetLong providerSids, FastTreeMapIntLong treeMap) {
this.providerSids = providerSids;
this.treeMap = treeMap;
}
}

public ConsistentHashLoadBalancer() {
}

Expand All @@ -64,66 +73,53 @@ public Session selectProvider(List<Session> providers, Object packet, Object arg
return RandomLoadBalancer.getInstance().selectProvider(providers, packet, argument);
}

updateConsistentHashMap(providers);

var module = ProtocolManager.moduleByProtocol(packet.getClass());
var fastTreeMap = consistentHashMap.get(module.getId());
if (fastTreeMap == null) {
fastTreeMap = updateModuleToConsistentHash(providers, module);
var consistentCache = consistentHashMap.get(module.getId());
if (consistentCache == null) {
consistentCache = updateModuleToConsistentHash(providers, module);
}
if (fastTreeMap == null) {
throw new RunException("ConsistentHashLoadBalancer [protocol:{}][argument:{}], no service provides the [module:{}]", packet.getClass(), argument, module);
var providerSids = consistentCache.providerSids;
// 一致性hash缓存不一致同样进行更新操作
if (providerSids.size() != providers.size() || providers.stream().anyMatch(it -> !providerSids.contains(it.getSid()))) {
consistentCache = updateModuleToConsistentHash(providers, module);
}
var nearestIndex = fastTreeMap.indexOfNearestCeilingKey(HashUtils.fnvHash(argument));
var treeMap = consistentCache.treeMap;
var nearestIndex = treeMap.indexOfNearestCeilingKey(HashUtils.fnvHash(argument));
if (nearestIndex < 0) {
throw new RunException("no service provides the [module:{}]", module);
}
var sid = fastTreeMap.getByIndex(nearestIndex);
var sid = treeMap.getByIndex(nearestIndex);
var session = NetContext.getSessionManager().getClientSession(sid);
if (session == null) {
throw new RunException("unknown no service provides the [module:{}]", module);
}
return session;
}

private void updateConsistentHashMap(List<Session> providers) {
// 如果更新时间不匹配,则更新到最新的服务提供者
var currentClientSessionChangeId = NetContext.getSessionManager().getClientSessionChangeId();
if (currentClientSessionChangeId != lastClientSessionChangeId) {
for (byte i = 0; i < ProtocolManager.MAX_MODULE_NUM; i++) {
var consistentHash = consistentHashMap.get(i);
if (consistentHash == null) {
continue;
}
var module = ProtocolManager.moduleByModuleId(i);
updateModuleToConsistentHash(providers, module);
}
lastClientSessionChangeId = currentClientSessionChangeId;
}
}


@Nullable
private FastTreeMapIntLong updateModuleToConsistentHash(List<Session> providers, ProtocolModule module) {
private ConsistentCache updateModuleToConsistentHash(List<Session> providers, ProtocolModule module) {
var sessionStringList = providers.stream()
.map(session -> new Pair<>(session.getConsumerRegister().toString(), session.getSid()))
.sorted((a, b) -> a.getKey().compareTo(b.getKey()))
.toList();

var consistentHash = new ConsistentHash<>(sessionStringList, VIRTUAL_NODE_NUMS);
var virtualNodeTreeMap = consistentHash.getVirtualNodeTreeMap();
if (CollectionUtils.isEmpty(virtualNodeTreeMap)) {
consistentHashMap.set(module.getId(), null);
return null;
}

var virtualTreeMap = new TreeMap<Integer, Long>();
for (var entry : virtualNodeTreeMap.entrySet()) {
virtualTreeMap.put(entry.getKey(), entry.getValue().getValue());
}

// 缓存服务提供者的sid
var sidSet = new HashSetLong(16);
providers.forEach(it -> sidSet.add(it.getSid()));
// 使用更高性能的tree map
var fastTreeMap = new FastTreeMapIntLong(virtualTreeMap);
consistentHashMap.set(module.getId(), fastTreeMap);
return fastTreeMap;

var consistentCache = new ConsistentCache(sidSet, fastTreeMap);
consistentHashMap.set(module.getId(), consistentCache);
return consistentCache;
}

}
2 changes: 0 additions & 2 deletions net/src/main/java/com/zfoo/net/session/ISessionManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,4 @@ public interface ISessionManager {

int clientSessionSize();

int getClientSessionChangeId();

}
13 changes: 0 additions & 13 deletions net/src/main/java/com/zfoo/net/session/SessionManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ public class SessionManager implements ISessionManager {

private static final Logger logger = LoggerFactory.getLogger(SessionManager.class);

private static final AtomicInteger CLIENT_ATOMIC = new AtomicInteger(0);

/**
* EN: As a server, the Session is connected by other clients
* CN: 作为服务器,被别的客户端连接的Session
Expand All @@ -38,17 +36,13 @@ public class SessionManager implements ISessionManager {
*/
private final ConcurrentHashMapLongObject<Session> serverSessionMap = new ConcurrentHashMapLongObject<>(128);


/**
* EN: As a client, connect to another server and save Sessions
* CN: 作为客户端,连接别的服务器上后,保存下来的Session
* 如:自己配置了Consumer,说明自己作为消费者将要消费远程接口,就会创建一个TcpClient去连接Provider,那么连接上后,就会保存下来到这个Map中
*/
private final ConcurrentHashMapLongObject<Session> clientSessionMap = new ConcurrentHashMapLongObject<>(8);

private volatile int clientSessionChangeId = CLIENT_ATOMIC.incrementAndGet();


@Override
public void addServerSession(Session session) {
if (serverSessionMap.containsKey(session.getSid())) {
Expand Down Expand Up @@ -91,7 +85,6 @@ public void addClientSession(Session session) {
return;
}
clientSessionMap.put(session.getSid(), session);
clientSessionChangeId = CLIENT_ATOMIC.incrementAndGet();
}

@Override
Expand All @@ -103,7 +96,6 @@ public void removeClientSession(Session session) {
try (session) {
clientSessionMap.remove(session.getSid());
}
clientSessionChangeId = CLIENT_ATOMIC.incrementAndGet();
}

@Override
Expand All @@ -121,9 +113,4 @@ public int clientSessionSize() {
return clientSessionMap.size();
}

@Override
public int getClientSessionChangeId() {
return clientSessionChangeId;
}

}
15 changes: 5 additions & 10 deletions net/src/main/java/com/zfoo/net/util/ConsistentHash.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,13 @@ public ConsistentHash(List<Pair<K, V>> realNodes, int virtualNodes) {
this.realNodes.addAll(realNodes);
this.virtualNodes = virtualNodes;

// 初始化
// 再添加虚拟节点,遍历LinkedList使用foreach循环效率会比较高
for (var realNode : realNodes) {
addNode(realNode);
}
}

public void addNode(Pair<K, V> realNode) {
for (var i = 0; i < this.virtualNodes; i++) {
var virtualNode = realNode.getKey().toString() + "&&VN" + i;
var hash = HashUtils.fnvHash(virtualNode);
virtualNodeTreeMap.put(hash, realNode);
for (var i = 0; i < this.virtualNodes; i++) {
var virtualNode = realNode.getKey().toString() + "&&VN" + i;
var hash = HashUtils.fnvHash(virtualNode);
virtualNodeTreeMap.put(hash, realNode);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ public boolean contains(Object o) {
return map.containsKey(o);
}

public boolean contains(long key) {
return map.containsKey(key);
}

@Override
public boolean add(Long e) {
return map.put(e, Boolean.TRUE) == null;
Expand Down

0 comments on commit d5249ed

Please sign in to comment.