Skip to content

Commit

Permalink
chore: #16913 - replace AddressBook with Roster in SignedStateValidat…
Browse files Browse the repository at this point in the history
…or (#16999)

chore: #16913 - replace AddressBook with Roster in SignedStateValidator

Signed-off-by: Tim Farber-Newman <[email protected]>
  • Loading branch information
timfn-hg authored Dec 10, 2024
1 parent 06c74c3 commit 6a8c5e8
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@

import static com.swirlds.logging.legacy.LogMarker.EXCEPTION;

import com.hedera.hapi.node.state.roster.Roster;
import com.swirlds.common.context.PlatformContext;
import com.swirlds.platform.config.StateConfig;
import com.swirlds.platform.state.signed.SignedState;
import com.swirlds.platform.state.signed.SignedStateInvalidException;
import com.swirlds.platform.state.signed.SignedStateValidationData;
import com.swirlds.platform.state.signed.SignedStateValidator;
import com.swirlds.platform.system.address.AddressBook;
import edu.umd.cs.findbugs.annotations.NonNull;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
Expand All @@ -48,9 +48,9 @@ public DefaultSignedStateValidator(@NonNull final PlatformContext platformContex
* {@inheritDoc}
*/
public void validate(
final SignedState signedState, final AddressBook addressBook, SignedStateValidationData previousStateData) {
final SignedState signedState, final Roster roster, SignedStateValidationData previousStateData) {
throwIfOld(signedState, previousStateData);
signedState.pruneInvalidSignatures(addressBook);
signedState.pruneInvalidSignatures(roster);
signedState.throwIfNotVerifiable();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import com.swirlds.platform.crypto.CryptoStatic;
import com.swirlds.platform.metrics.ReconnectMetrics;
import com.swirlds.platform.network.Connection;
import com.swirlds.platform.roster.RosterUtils;
import com.swirlds.platform.state.MerkleRoot;
import com.swirlds.platform.state.signed.ReservedSignedState;
import com.swirlds.platform.state.signed.SignedState;
Expand Down Expand Up @@ -106,8 +105,7 @@ public ReconnectLearner(
this.statistics = Objects.requireNonNull(statistics);

// Save some of the current state data for validation
this.stateValidationData = new SignedStateValidationData(
currentState.getReadablePlatformState(), RosterUtils.buildAddressBook(roster));
this.stateValidationData = new SignedStateValidationData(currentState.getReadablePlatformState(), roster);
}

/**
Expand Down Expand Up @@ -159,7 +157,7 @@ public ReservedSignedState execute(@NonNull final SignedStateValidator validator
try {
receiveSignatures();
reservedSignedState = reconnect();
validator.validate(reservedSignedState.get(), RosterUtils.buildAddressBook(roster), stateValidationData);
validator.validate(reservedSignedState.get(), roster, stateValidationData);
ReconnectUtils.endReconnectHandshake(connection);
SignedStateFileReader.unregisterServiceStates(reservedSignedState.get());
return reservedSignedState;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import static com.swirlds.platform.state.signed.SignedStateHistory.SignedStateAction.RELEASE;
import static com.swirlds.platform.state.signed.SignedStateHistory.SignedStateAction.RESERVE;

import com.hedera.hapi.node.state.roster.Roster;
import com.hedera.hapi.node.state.roster.RosterEntry;
import com.swirlds.base.time.Time;
import com.swirlds.common.crypto.Signature;
import com.swirlds.common.platform.NodeId;
Expand All @@ -46,9 +48,11 @@
import com.swirlds.state.merkle.SigSet;
import edu.umd.cs.findbugs.annotations.NonNull;
import edu.umd.cs.findbugs.annotations.Nullable;
import java.security.cert.X509Certificate;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.logging.log4j.LogManager;
Expand Down Expand Up @@ -623,6 +627,32 @@ private boolean isSignatureValid(@Nullable final Address address, @NonNull final
state.getHash().getBytes(), signature.getBytes(), address.getSigPublicKey());
}

/**
* Check if a signature is valid. If a node has no weight or is missing a certificate, we consider the signature to
* be invalid.
*
* @param rosterEntry the roster entry of the signer, or null if there was no signing address
* @param signature the signature to check
* @return true if the signature is valid, else false
*/
private boolean isSignatureValid(@Nullable final RosterEntry rosterEntry, @NonNull final Signature signature) {
if (rosterEntry == null) {
return false;
}

if (rosterEntry.weight() == 0) {
return false;
}

X509Certificate cert = RosterUtils.fetchGossipCaCertificate(rosterEntry);

if (cert == null) {
return false;
}

return signatureVerifier.verifySignature(state.getHash().getBytes(), signature.getBytes(), cert.getPublicKey());
}

/**
* Add a signature to the sigset if the signature is valid.
*
Expand Down Expand Up @@ -702,6 +732,41 @@ public void pruneInvalidSignatures(@NonNull final AddressBook trustedAddressBook
}
}

/**
* Remove all invalid signatures from a signed state.
*
* @param trustedRoster use this roster to determine signature validity instead of using the roster from the signed
* state. (Useful if validating signed states from untrusted sources.)
*/
public void pruneInvalidSignatures(@NonNull final Roster trustedRoster) {
Objects.requireNonNull(trustedRoster);

final Map<Long, RosterEntry> entriesByNodeId = RosterUtils.toMap(trustedRoster);
final List<NodeId> signaturesToRemove = new ArrayList<>();

for (final NodeId nodeId : sigSet) {
final RosterEntry entry = entriesByNodeId.get(nodeId.id());
if (!isSignatureValid(entry, sigSet.getSignature(nodeId))) {
signaturesToRemove.add(nodeId);
}
}

for (final NodeId nodeId : signaturesToRemove) {
sigSet.removeSignature(nodeId);
}

long newWeight = 0;

for (final NodeId nodeId : sigSet) {
final RosterEntry entry = entriesByNodeId.get(nodeId.id());
if (entry != null) {
newWeight += entry.weight();
}
}

signingWeight = newWeight;
}

/**
* Get the reservation history for this object (if configured to gather history)
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@

package com.swirlds.platform.state.signed;

import com.hedera.hapi.node.state.roster.Roster;
import com.swirlds.common.crypto.Hash;
import com.swirlds.platform.roster.RosterUtils;
import com.swirlds.platform.state.PlatformStateAccessor;
import com.swirlds.platform.system.address.AddressBook;
import edu.umd.cs.findbugs.annotations.NonNull;
import edu.umd.cs.findbugs.annotations.Nullable;
import java.time.Instant;
Expand All @@ -30,23 +31,22 @@
* the minimum round to be considered a valid state
* @param consensusTimestamp
* The consensus timestamp from an earlier state
* @param addressBookHash
* The address book hash value for the current address book (mostly used for diagnostics).
* @param rosterHash
* The roster hash value for the current roster (mostly used for diagnostics).
* @param consensusEventsRunningHash
* The running hash of the consensus event hashes throughout history
*/
public record SignedStateValidationData(
long round,
@NonNull Instant consensusTimestamp,
@Nullable Hash addressBookHash,
@Nullable Hash rosterHash,
@NonNull Hash consensusEventsRunningHash) {

public SignedStateValidationData(
@NonNull final PlatformStateAccessor that, @Nullable final AddressBook addressBook) {
public SignedStateValidationData(@NonNull final PlatformStateAccessor that, @Nullable final Roster roster) {
this(
that.getRound(),
that.getConsensusTimestamp(),
addressBook == null ? null : addressBook.getHash(),
roster == null ? null : RosterUtils.hash(roster),
that.getLegacyRunningEventHash());
}

Expand All @@ -64,8 +64,8 @@ public String getInfoString() {
.append(consensusTimestamp)
.append(", consensus Events running hash = ")
.append(consensusEventsRunningHash)
.append(", address book hash = ")
.append(addressBookHash != null ? addressBookHash : "not provided")
.append(", roster hash = ")
.append(rosterHash != null ? rosterHash : "not provided")
.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,27 @@

package com.swirlds.platform.state.signed;

import com.hedera.hapi.node.state.roster.Roster;
import com.swirlds.platform.state.service.ReadablePlatformStateStore;
import com.swirlds.platform.system.address.AddressBook;

/**
* Validates a signed state received via reconnect.
*/
public interface SignedStateValidator {

/**
* Determines if a signed state is valid with the address book. Validation usually includes
* verifying that the signed state is signed with a sufficient number of valid signatures to meet a certain weighting
* threshold, but other requirements could be included as well.
* Determines if a signed state is valid with the roster. Validation usually includes verifying that the signed
* state is signed with a sufficient number of valid signatures to meet a certain weighting threshold, but other
* requirements could be included as well.
*
* @param signedState the signed state to validate
* @param addressBook the address book used for this signed state
* @param roster the roster used for this signed state
* @param previousStateData A {@link SignedStateValidationData} containing data from the
* {@link ReadablePlatformStateStore} in the state before the signed state to be validated.
* This may be used to ensure signed state is usable and valid, and also contains useful information for
* diagnostics produced when the signed state is not considered valid.
* @throws SignedStateInvalidException if the signed state is not valid
*/
void validate(
final SignedState signedState,
final AddressBook addressBook,
final SignedStateValidationData previousStateData)
void validate(final SignedState signedState, final Roster roster, final SignedStateValidationData previousStateData)
throws SignedStateInvalidException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertThrows;

import com.hedera.hapi.node.state.roster.Roster;
import com.hedera.hapi.node.state.roster.RosterEntry;
import com.swirlds.common.context.PlatformContext;
import com.swirlds.common.crypto.Hash;
import com.swirlds.common.crypto.Signature;
Expand All @@ -31,11 +33,11 @@
import com.swirlds.common.test.fixtures.platform.TestPlatformContextBuilder;
import com.swirlds.merkledb.MerkleDb;
import com.swirlds.platform.crypto.SignatureVerifier;
import com.swirlds.platform.roster.RosterUtils;
import com.swirlds.platform.state.signed.SignedState;
import com.swirlds.platform.state.signed.SignedStateInvalidException;
import com.swirlds.platform.state.signed.SignedStateValidationData;
import com.swirlds.platform.system.address.AddressBook;
import com.swirlds.platform.test.fixtures.addressbook.RandomAddressBuilder;
import com.swirlds.platform.test.fixtures.addressbook.RandomRosterEntryBuilder;
import com.swirlds.platform.test.fixtures.state.RandomSignedStateGenerator;
import edu.umd.cs.findbugs.annotations.NonNull;
import java.util.ArrayList;
Expand Down Expand Up @@ -71,7 +73,7 @@ class DefaultSignedStateValidatorTests {

private static final int ROUND = 0;

private AddressBook addressBook;
private Roster roster;

private DefaultSignedStateValidator validator;

Expand Down Expand Up @@ -235,16 +237,16 @@ private static List<Node> initNodes() {
.collect(Collectors.toList()));
}

@NonNull
private static AddressBook createAddressBook(@NonNull final Random random, @NonNull final List<Node> nodes) {
final AddressBook addressBook = new AddressBook();
private static Roster createRoster(@NonNull final Random random, @NonNull final List<Node> nodes) {
List<RosterEntry> rosterEntries = new ArrayList<>(nodes.size());
for (final Node node : nodes.stream().sorted().toList()) {
addressBook.add(RandomAddressBuilder.create(random)
.withNodeId(node.id)
rosterEntries.add(RandomRosterEntryBuilder.create(random)
.withNodeId(node.id.id())
.withWeight(node.weight)
.build());
}
return addressBook;

return new Roster(rosterEntries);
}

@BeforeEach
Expand All @@ -262,7 +264,7 @@ void tearDown() {
@DisplayName("Signed State Validation")
void testSignedStateValidationRandom(final String desc, final List<Node> nodes, final List<Node> signingNodes) {
final Randotron randotron = Randotron.create();
addressBook = createAddressBook(randotron, nodes);
roster = createRoster(randotron, nodes);

final PlatformContext platformContext =
TestPlatformContextBuilder.create().build();
Expand All @@ -271,18 +273,18 @@ void testSignedStateValidationRandom(final String desc, final List<Node> nodes,

final SignedState signedState = stateSignedByNodes(signingNodes);
final SignedStateValidationData originalData =
new SignedStateValidationData(signedState.getState().getReadablePlatformState(), addressBook);
new SignedStateValidationData(signedState.getState().getReadablePlatformState(), roster);

final boolean shouldSucceed = stateHasEnoughWeight(nodes, signingNodes);
if (shouldSucceed) {
assertDoesNotThrow(
() -> validator.validate(signedState, addressBook, originalData),
() -> validator.validate(signedState, roster, originalData),
"State signed with a majority of weight (%s out of %s) should pass validation."
.formatted(getValidSignatureWeight(signingNodes), getTotalWeight(nodes)));
} else {
assertThrows(
SignedStateInvalidException.class,
() -> validator.validate(signedState, addressBook, originalData),
() -> validator.validate(signedState, roster, originalData),
"State not signed with a majority of weight (%s out of %s) should NOT pass validation."
.formatted(getValidSignatureWeight(signingNodes), getTotalWeight(nodes)));
}
Expand Down Expand Up @@ -342,7 +344,7 @@ private SignedState stateSignedByNodes(final List<Node> signingNodes) {

return new RandomSignedStateGenerator()
.setRound(ROUND)
.setAddressBook(addressBook)
.setAddressBook(RosterUtils.buildAddressBook(roster))
.setStateHash(stateHash)
.setSignatures(nodeSigs(signingNodes))
.setSignatureVerifier(signatureVerifier)
Expand Down

0 comments on commit 6a8c5e8

Please sign in to comment.