Skip to content

Commit

Permalink
Merge pull request #4 from aicis/pesto-independent-changes
Browse files Browse the repository at this point in the history
Added support for JNO protocol
  • Loading branch information
jot2re authored Oct 19, 2022
2 parents ee5eb7f + ff2bd97 commit d3002f6
Show file tree
Hide file tree
Showing 83 changed files with 3,467 additions and 1,210 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package dk.alexandra.fresco.outsourcing.client;

import static dk.alexandra.fresco.outsourcing.utils.GenericUtils.intFromBytes;

import dk.alexandra.fresco.framework.Party;
import dk.alexandra.fresco.framework.builder.numeric.field.FieldDefinition;
import dk.alexandra.fresco.framework.util.ByteAndBitConverter;
import dk.alexandra.fresco.outsourcing.network.ClientSideNetworkFactory;
import dk.alexandra.fresco.outsourcing.network.TwoPartyNetwork;
import java.math.BigInteger;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class AbstractClientBase {
private static final Logger logger = LoggerFactory.getLogger(AbstractClientBase.class);

protected FieldDefinition definition;
protected List<Party> servers;
protected Map<Integer, TwoPartyNetwork> serverNetworks;
protected int clientId;

/**
* Creates new {@link AbstractClientBase}.
*
* @param clientId client ID
* @param servers servers to connect to
*/
protected AbstractClientBase(int clientId, List<Party> servers) {
if (clientId < 1) {
throw new IllegalArgumentException("Client ID must be 1 or higher");
}
this.clientId = clientId;
this.servers = servers;
}

/**
* Connects to all worker server and initializes server map with all connected servers.
*/
protected final void initServerNetworks(ExecutorService es, TwoPartyNetwork masterNetwork,
byte[] handShakeMessage)
throws InterruptedException, java.util.concurrent.ExecutionException {
Map<Integer, Future<TwoPartyNetwork>> futureNetworks = new HashMap<>(servers.size() - 1);
for (Party s : servers.stream().filter(p -> p.getPartyId() != 1)
.collect(Collectors.toList())) {
Future<TwoPartyNetwork> futureNetwork = es.submit(connect(s, handShakeMessage));
futureNetworks.put(s.getPartyId(), futureNetwork);
}
serverNetworks = new HashMap<>(servers.size());
serverNetworks.put(1, masterNetwork);
for (Entry<Integer, Future<TwoPartyNetwork>> f : futureNetworks.entrySet()) {
serverNetworks.put(f.getKey(), f.getValue().get());
}
}

protected final void initFieldDefinition(Function<BigInteger, FieldDefinition> definitionSupplier,
TwoPartyNetwork masterNetwork) {
byte[] modResponse = masterNetwork.receive();
BigInteger modulus = new BigInteger(modResponse);
this.definition = definitionSupplier.apply(modulus);
}

/**
* Connects to server with given handshake message.
*/
protected final Callable<TwoPartyNetwork> connect(Party server, byte[] handShakeMessage) {
return () -> {
logger.info("C{}: Connecting to server {} ... ", clientId, server);
TwoPartyNetwork network =
ClientSideNetworkFactory.getNetwork(server.getHostname(), server.getPort());
network.send(handShakeMessage);
logger.info("C{}: Connected to server {}", clientId, server);
return network;
};
}

protected void handshake(Function<BigInteger, FieldDefinition> definitionSupplier,
int amount) {
logger.info("C{}: Starting handshake", clientId);
try {
ExecutorService es = Executors.newFixedThreadPool(servers.size() - 1);

Party serverOne = servers.stream().filter(p -> p.getPartyId() == 1).findFirst().get();
logger.info("C{}: connecting to master server {}", clientId, serverOne);
TwoPartyNetwork masterNetwork = es
.submit(connect(serverOne, getHandShakeMessage(0, amount))).get();
logger.info("C{}: Connected to master server", clientId);
byte[] response = masterNetwork.receive();

int priority = intFromBytes(response);
logger.info("C{}: Received priority {}", clientId, priority);

initServerNetworks(es, masterNetwork, getHandShakeMessage(priority, amount));

es.shutdown();

initFieldDefinition(definitionSupplier, masterNetwork);
} catch (Exception e) {
logger.error("Error during handshake", e);
e.printStackTrace();
}
}

protected byte[] getHandShakeMessage(int priority, int amount) {
byte[] msg = new byte[Integer.BYTES * 3];
System.arraycopy(ByteAndBitConverter.toByteArray(priority), 0, msg, 0, Integer.BYTES);
System.arraycopy(ByteAndBitConverter.toByteArray(clientId), 0, msg, Integer.BYTES,
Integer.BYTES);
System.arraycopy(ByteAndBitConverter.toByteArray(amount), 0, msg, Integer.BYTES * 2,
Integer.BYTES);
return msg;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package dk.alexandra.fresco.outsourcing.client;

import static dk.alexandra.fresco.outsourcing.utils.GenericUtils.intFromBytes;

import dk.alexandra.fresco.framework.builder.numeric.field.FieldDefinition;
import dk.alexandra.fresco.outsourcing.network.TwoPartyNetwork;
import dk.alexandra.fresco.outsourcing.server.ClientSession;
import dk.alexandra.fresco.outsourcing.server.ClientSessionHandler;
import dk.alexandra.fresco.outsourcing.server.DemoClientSessionRequestHandler;
import dk.alexandra.fresco.suite.spdz.SpdzResourcePool;
import java.util.Arrays;
import java.util.Comparator;
import java.util.PriorityQueue;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class AbstractSessionEndPoint<T extends ClientSession> implements ClientSessionHandler<T> {
private static final Logger logger = LoggerFactory
.getLogger(AbstractSessionEndPoint.class);

protected final SpdzResourcePool resourcePool;
protected final int expectedClients;
protected final PriorityQueue<DemoClientSessionRequestHandler.QueuedClient> orderingQueue;
protected final BlockingQueue<DemoClientSessionRequestHandler.QueuedClient> processingQueue;
protected final FieldDefinition definition;
protected int clientsReady;
protected int sessionsProduced;

public AbstractSessionEndPoint(SpdzResourcePool resourcePool,
FieldDefinition definition,
int expectedClients) {
if (expectedClients < 0) {
throw new IllegalArgumentException(
"Expected input clients cannot be negative, but was: " + expectedClients);
}
this.resourcePool = resourcePool;
this.definition = definition;
this.expectedClients = expectedClients;
this.processingQueue = new ArrayBlockingQueue<>(expectedClients);
this.orderingQueue = new PriorityQueue<>(expectedClients,
Comparator.comparingInt(DemoClientSessionRequestHandler.QueuedClient::getPriority));
this.clientsReady = 0;
}

protected abstract T getClientSession(DemoClientSessionRequestHandler.QueuedClient client);

@Override
public T next() {
try {
DemoClientSessionRequestHandler.QueuedClient client = processingQueue.take();
T session = getClientSession(client);
sessionsProduced++;
return session;
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}

@Override
public boolean hasNext() {
return expectedClients - sessionsProduced > 0;
}

@Override
public int registerNewSessionRequest(byte[] handshakeMessage, TwoPartyNetwork network) {
// Bytes 0-3: client priority, assigned by server 1 (big endian int)
// Bytes 4-7: unique id for client (big endian int)
// Bytes 8-11: number of inputs (big endian int)
int priority = intFromBytes(Arrays.copyOfRange(handshakeMessage, 0, Integer.BYTES * 1));
int clientId =
intFromBytes(Arrays.copyOfRange(handshakeMessage, Integer.BYTES * 1, Integer.BYTES * 2));
int numInputs =
intFromBytes(Arrays.copyOfRange(handshakeMessage, Integer.BYTES * 2, Integer.BYTES * 3));
return registerNewSessionRequest(priority, clientId, numInputs, network);
}

@Override
public int getExpectedClients() {
return expectedClients;
}

private int registerNewSessionRequest(int suggestedPriority, int clientId, int inputAmount,
TwoPartyNetwork network) {
if (resourcePool.getMyId() == 1) {
int priority = clientsReady++;
DemoClientSessionRequestHandler.QueuedClient q = new DemoClientSessionRequestHandler.QueuedClient(priority, clientId, inputAmount, network);
processingQueue.add(q);
return q.getPriority();
} else {
DemoClientSessionRequestHandler.QueuedClient q = new DemoClientSessionRequestHandler.QueuedClient(suggestedPriority, clientId, inputAmount, network);
orderingQueue.add(q);
while (!orderingQueue.isEmpty() && orderingQueue.peek().getPriority() == clientsReady) {
clientsReady++;
processingQueue.add(orderingQueue.remove());
}
logger.info(
"S{}: Finished handskake for input client {} with priority {}. Expecting {} inputs.",
resourcePool.getMyId(), q.getClientId(), q.getPriority(), q.getInputAmount());
return q.getPriority();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package dk.alexandra.fresco.outsourcing.client;

import dk.alexandra.fresco.framework.builder.numeric.field.FieldElement;
import dk.alexandra.fresco.framework.network.serializers.ByteSerializer;
import dk.alexandra.fresco.outsourcing.network.TwoPartyNetwork;
import dk.alexandra.fresco.outsourcing.server.ClientSession;

public class GenericClientSession implements ClientSession {
private final int clientId;
private final TwoPartyNetwork network;
private final ByteSerializer<FieldElement> serializer;

public GenericClientSession(int clientId, TwoPartyNetwork network,
ByteSerializer<FieldElement> serializer) {
this.clientId = clientId;
this.network = network;
this.serializer = serializer;
}

@Override
public int getClientId() {
return clientId;
}

@Override
public TwoPartyNetwork getNetwork() {
return network;
}

@Override
public ByteSerializer<FieldElement> getSerializer() {
return serializer;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package dk.alexandra.fresco.outsourcing.client;

import dk.alexandra.fresco.framework.builder.numeric.field.FieldDefinition;
import dk.alexandra.fresco.outsourcing.server.DemoClientSessionRequestHandler.QueuedClient;
import dk.alexandra.fresco.suite.spdz.SpdzResourcePool;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class GenericClientSessionEndpoint extends AbstractSessionEndPoint<GenericClientSession> {

private static final Logger logger = LoggerFactory
.getLogger(GenericClientSessionEndpoint.class);

public GenericClientSessionEndpoint(SpdzResourcePool resourcePool,
FieldDefinition definition,
int expectedClients) {
super(resourcePool, definition, expectedClients);
}

@Override
protected GenericClientSession getClientSession(QueuedClient client) {
return new GenericClientSession(client.getClientId(), client.getNetwork(), definition);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
* </p>
*/
public interface InputClient {

/**
* Inputs a list of values given as BigIntegers.
*
Expand All @@ -31,6 +30,6 @@ public interface InputClient {
*
* @param inputs a list of input values
*/
void putIntInputs(List<Integer> inputs);
void putIntInputs(List<Integer> inputs);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package dk.alexandra.fresco.outsourcing.client.ddnnt;

import dk.alexandra.fresco.framework.Party;
import dk.alexandra.fresco.framework.builder.numeric.field.FieldElement;
import dk.alexandra.fresco.outsourcing.client.AbstractClientBase;

import java.math.BigInteger;
import java.util.ArrayList;
import java.util.List;

/**
* Forms base for {@link DdnntInputClient} and {@link DdnntOutputClient}.
*/
public abstract class DdnntClientBase extends AbstractClientBase {

/**
* Creates new {@link AbstractClientBase}.
*
* @param clientId client ID
* @param servers servers to connect to
*/
DdnntClientBase(int clientId, List<Party> servers) {
super(clientId, servers);
}

/**
* Computes pairwise sum of left and right elements.
*/
final List<FieldElement> sumLists(List<FieldElement> left, List<FieldElement> right) {
if (left.size() != right.size()) {
throw new IllegalArgumentException("Left and right should be same size");
}
List<FieldElement> res = new ArrayList<>(left.size());
for (int i = 0; i < left.size(); i++) {
FieldElement b = left.get(i).add(right.get(i));
res.add(b);
}
return res;
}

/**
* Returns true if a * b = c, false otherwise.
*/
final boolean productCheck(FieldElement a, FieldElement b, FieldElement c) {
FieldElement actualProd = a.multiply(b);
BigInteger actualProdConverted = definition.convertToUnsigned(actualProd);
BigInteger expected = definition.convertToUnsigned(c);
return actualProdConverted.equals(expected);
}

}
Loading

0 comments on commit d3002f6

Please sign in to comment.