Skip to content
This repository has been archived by the owner on Apr 14, 2023. It is now read-only.

#1704/#1705 weighted decision selection #1718

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public FieldSpec toFieldSpec() {

@Override
public String toString(){
return String.format("`%s` = %s", field.getName(), value);
return String.format("`%s` != %s", field.getName(), value);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@
package com.scottlogic.datahelix.generator.core.walker.decisionbased;

import com.google.inject.Inject;
import com.scottlogic.datahelix.generator.common.distribution.WeightedElement;
import com.scottlogic.datahelix.generator.common.profile.Field;
import com.scottlogic.datahelix.generator.common.profile.Fields;
import com.scottlogic.datahelix.generator.core.fieldspecs.FieldSpecFactory;
import com.scottlogic.datahelix.generator.common.profile.InSetRecord;
import com.scottlogic.datahelix.generator.core.fieldspecs.*;
import com.scottlogic.datahelix.generator.core.profile.constraints.atomic.AtomicConstraint;
import com.scottlogic.datahelix.generator.core.decisiontree.ConstraintNode;
import com.scottlogic.datahelix.generator.core.decisiontree.DecisionNode;
import com.scottlogic.datahelix.generator.core.decisiontree.DecisionTree;
import com.scottlogic.datahelix.generator.core.fieldspecs.FieldSpec;
import com.scottlogic.datahelix.generator.core.fieldspecs.RowSpec;
import com.scottlogic.datahelix.generator.core.profile.constraints.atomic.InSetConstraint;
import com.scottlogic.datahelix.generator.core.reducer.ConstraintReducer;
import com.scottlogic.datahelix.generator.core.walker.pruner.Merged;
import com.scottlogic.datahelix.generator.core.walker.pruner.TreePruner;
Expand All @@ -51,45 +52,114 @@ public RowSpecTreeSolver(ConstraintReducer constraintReducer,
this.optionPicker = optionPicker;
}

public Stream<RowSpec> createRowSpecs(DecisionTree tree) {
public Stream<WeightedElement<RowSpec>> createRowSpecs(DecisionTree tree) {
return flatMap(reduceToRowNodes(tree.rootNode),
rootNode -> toRowspec(tree.fields, rootNode));
}

private Stream<RowSpec> toRowspec(Fields fields, ConstraintNode rootNode) {
Optional<RowSpec> result = constraintReducer.reduceConstraintsToRowSpec(fields, rootNode);
return result.map(Stream::of).orElseGet(Stream::empty);
private Stream<WeightedElement<RowSpec>> toRowspec(Fields fields, WeightedElement<ConstraintNode> rootNode) {
Optional<RowSpec> result = constraintReducer.reduceConstraintsToRowSpec(fields, rootNode.element());
return result
.map(rowSpec -> new WeightedElement<>(rowSpec, rootNode.weight()))
.map(Stream::of)
.orElseGet(Stream::empty);
}

/**
* a row node is a constraint node with no further decisions
*/
private Stream<ConstraintNode> reduceToRowNodes(ConstraintNode rootNode) {
if (rootNode.getDecisions().isEmpty()) {
private Stream<WeightedElement<ConstraintNode>> reduceToRowNodes(ConstraintNode rootNode) {
return reduceToRowNodes(new WeightedElement<>(rootNode, 1));
}

private Stream<WeightedElement<ConstraintNode>> reduceToRowNodes(WeightedElement<ConstraintNode> rootNode) {
if (rootNode.element().getDecisions().isEmpty()) {
return Stream.of(rootNode);
}

DecisionNode decisionNode = optionPicker.pickDecision(rootNode);
ConstraintNode rootWithoutDecision = rootNode.builder().removeDecision(decisionNode).build();
DecisionNode decisionNode = optionPicker.pickDecision(rootNode.element());
ConstraintNode rootWithoutDecision = rootNode.element().builder().removeDecision(decisionNode).build();

Stream<ConstraintNode> rootOnlyConstraintNodes = optionPicker.streamOptions(decisionNode)
Stream<WeightedElement<ConstraintNode>> rootOnlyConstraintNodes = optionPicker.streamOptions(decisionNode)
.map(option -> combineWithRootNode(rootWithoutDecision, option))
.filter(newNode -> !newNode.isContradictory())
.map(Merged::get);
.filter(newNode -> !newNode.element().isContradictory())
.map(weighted -> new WeightedElement<>(weighted.element().get(), weighted.weight()));

return flatMap(
rootOnlyConstraintNodes,
this::reduceToRowNodes);
weightedConstraintNode -> reduceToRowNodes(
new WeightedElement<ConstraintNode>(
weightedConstraintNode.element(),
weightedConstraintNode.weight() / rootNode.weight())));
}

private Merged<ConstraintNode> combineWithRootNode(ConstraintNode rootNode, ConstraintNode option) {
private WeightedElement<Merged<ConstraintNode>> combineWithRootNode(ConstraintNode rootNode, ConstraintNode option) {
ConstraintNode constraintNode = rootNode.builder()
.addDecisions(option.getDecisions())
.addAtomicConstraints(option.getAtomicConstraints())
.addRelations(option.getRelations())
.build();

return treePruner.pruneConstraintNode(constraintNode, getFields(option));
/*
Find the relevance of this option in the context of the tree.
i.e. if this option says that a field must be equal to A in the possible set of A (20%), B (30%) and C (50%), yield the weighting of A (20%).
applicabilityOfThisOption should yield 0.2 in this case (20% of 1)
*/
double applicabilityOfThisOption = option.getAtomicConstraints().stream()
.mapToDouble(optionAtomicConstraint -> rootNode.getAtomicConstraints().stream()
.filter(rootAtomicConstraint -> rootAtomicConstraint.getField().equals(optionAtomicConstraint.getField()))
.filter(rootAtomicConstraint -> rootAtomicConstraint instanceof InSetConstraint)
.map(rootAtomicConstraint -> (InSetConstraint)rootAtomicConstraint)
Comment on lines +109 to +112
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe make this more modular and change this to:
mapToDouble(optionAtomicConstraint -> findMatchingInSetConstraints(optionAtomicConstraint.getField(), rootNode.getAtomicConstraints()))

and then define the method:
Stream<InSetConstraint> findMatchingInSetConstraints(Field fieldToMatch, Collection<AtomicConstraint> atomicConstraints)

.findFirst()
.map(matchingRootAtomicConstraint -> {
double totalWeighting = getWeightOfAllLegalValues(matchingRootAtomicConstraint);
double relevantWeighting = getWeightOfAllPermittedLegalValues(matchingRootAtomicConstraint, optionAtomicConstraint);

return relevantWeighting / totalWeighting;
Comment on lines +115 to +118
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe encapsulate this in a calculateApplicability method

})
.orElse(1d))
.sum();

if (applicabilityOfThisOption > 1) {
/*
the applicability of this option (e.g. A) is greater than 100%, retrieve the fractional part of the number.
if there is no fractional part, treat the applicability as 1 otherwise use the fractional part as the percentage.

This can happen when other fields in the option match an inSet constraint, if they also have a weighting it
will be taken into account, otherwise 1 (100%) would be used, hence 100% (field1) + 20% (field2) = 1.2.
*/

double applicabilityFraction = applicabilityOfThisOption - (int) applicabilityOfThisOption;
applicabilityOfThisOption = applicabilityFraction == 0
? 1
: applicabilityFraction;
}

if (applicabilityOfThisOption == 0){
/*no options were applicable, maybe an option was given 0%, yield contradictory so the option isn't processed.*/

return new WeightedElement<>(
Merged.contradictory(),
1
);
}

/*Yield a weighted element to inform the caller of the weighting of this constraint node*/
return new WeightedElement<>(
treePruner.pruneConstraintNode(constraintNode, getFields(option)),
applicabilityOfThisOption
);
}

/**
 * Sums the weightings of every legal value carried by the given in-set constraint,
 * using each record's default weight where none was supplied.
 */
private static double getWeightOfAllLegalValues(InSetConstraint matchingRootAtomicConstraint){
    double totalWeight = 0d;
    for (InSetRecord legalValue : matchingRootAtomicConstraint.legalValues) {
        totalWeight += legalValue.getWeightValueOrDefault();
    }
    return totalWeight;
}

/**
 * Sums the weightings of the legal values in the root in-set constraint that can also be
 * combined with the option's constraint (i.e. the values the option still permits).
 *
 * @param matchingRootAtomicConstraint the in-set constraint carrying the weighted legal values
 * @param optionAtomicConstraint the option's constraint used to filter the legal values
 * @return the combined weight of the permitted legal values
 */
private static double getWeightOfAllPermittedLegalValues(InSetConstraint matchingRootAtomicConstraint, AtomicConstraint optionAtomicConstraint){
    // hoist the loop-invariant conversion: the original re-built the field spec for every legal value
    FieldSpec optionFieldSpec = optionAtomicConstraint.toFieldSpec();
    return matchingRootAtomicConstraint.legalValues.stream()
        .filter(legalValue -> optionFieldSpec.canCombineWithLegalValue(legalValue.getElement()))
        .mapToDouble(InSetRecord::getWeightValueOrDefault)
        .sum();
}

private Map<Field, FieldSpec> getFields(ConstraintNode option) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.scottlogic.datahelix.generator.core.walker.rowspec;

import com.google.inject.Inject;
import com.scottlogic.datahelix.generator.common.distribution.WeightedElement;
import com.scottlogic.datahelix.generator.core.decisiontree.DecisionTree;
import com.scottlogic.datahelix.generator.core.fieldspecs.RowSpec;
import com.scottlogic.datahelix.generator.core.generation.databags.DataBag;
Expand Down Expand Up @@ -59,7 +60,7 @@ public Stream<DataBag> walk(DecisionTree tree) {
}

private Stream<RowSpec> getFromCachedRowSpecs(DecisionTree tree) {
List<RowSpec> rowSpecCache = rowSpecTreeSolver.createRowSpecs(tree).collect(Collectors.toList());
List<WeightedElement<RowSpec>> rowSpecCache = rowSpecTreeSolver.createRowSpecs(tree).collect(Collectors.toList());
return Stream.generate(() -> getRandomRowSpec(rowSpecCache));
}

Expand All @@ -79,11 +80,43 @@ private Stream<RowSpec> getRowSpecAndRestart(DecisionTree tree) {
}

private Optional<RowSpec> getFirstRowSpec(DecisionTree tree) {
return rowSpecTreeSolver.createRowSpecs(tree).findFirst();
return rowSpecTreeSolver.createRowSpecs(tree).findFirst().map(WeightedElement::element);
}

private RowSpec getRandomRowSpec(List<RowSpec> rowSpecCache) {
return rowSpecCache.get(random.nextInt(rowSpecCache.size()));
/**
* Get a row spec from the rowSpecCache in a weighted manner.<br />
* I.e. if there are 2 rowSpecs, one with a 70% weighting and the other with 30%.<br />
* Calling this method 10 times should *ROUGHLY* emit 7 of the 70% weighted rowSpecs and 3 of the others.<br />
* It does this by producing a virtual rowSpec 'range', i.e.<br />
* - values between 1 and 70 represent the 70% weighted rowSpec<br />
* - values between 71 and 100 represent the 30% weighted rowSpec<br />
* <br />
* The function then picks a random number between 1 and 100 and yields the rowSpec that encapsulates that value.<br />
* <br />
* As this method uses a random number generator, it will not ALWAYS yield a correct split, but it is more LIKELY than not.<br />
* @param rowSpecCache a list of weighted rowSpecs (weighting is between 0 and 1)
* @return a rowSpec picked from the list of weighted rowSpecs
*/
private RowSpec getRandomRowSpec(List<WeightedElement<RowSpec>> rowSpecCache) {
double totalRange = rowSpecCache.stream()
.mapToDouble(WeightedElement::weight).sum();

double nextRowSpecFromRange = random.nextInt((int)(totalRange * 100)) / 100d;

double currentPosition = 0d;
WeightedElement<RowSpec> lastRowSpec = null;

for (WeightedElement<RowSpec> weightedRowSpec: rowSpecCache) {
currentPosition += weightedRowSpec.weight();

if (currentPosition >= nextRowSpecFromRange) {
return weightedRowSpec.element();
}

lastRowSpec = weightedRowSpec;
}

return lastRowSpec.element();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we protect against a nullptr here in case lastRowSpec is null (eg rowSpecCache is empty could cause this). In which case return ??? (ideally some kind of 'empty' row spec rather than null - depends what calling code assumes). Or should we throw an exception as never expect this to be called with empty rowSepcCache.

}

private DataBag createDataBag(RowSpec rowSpec) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package com.scottlogic.datahelix.generator.core.walker.rowspec;

import com.google.inject.Inject;
import com.scottlogic.datahelix.generator.common.distribution.WeightedElement;
import com.scottlogic.datahelix.generator.common.util.FlatMappingSpliterator;
import com.scottlogic.datahelix.generator.core.decisiontree.DecisionTree;
import com.scottlogic.datahelix.generator.core.generation.databags.DataBag;
Expand All @@ -38,7 +39,7 @@ public RowSpecDecisionTreeWalker(RowSpecTreeSolver rowSpecTreeSolver, RowSpecDat
@Override
public Stream<DataBag> walk(DecisionTree tree) {
return FlatMappingSpliterator.flatMap(
rowSpecTreeSolver.createRowSpecs(tree),
rowSpecTreeSolver.createRowSpecs(tree).map(WeightedElement::element),
rowSpecDataBagGenerator::createDataBags);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class TestAtomicConstraintBuilder {
private TestConstraintNodeBuilder testConstraintNodeBuilder;
Expand Down Expand Up @@ -56,6 +57,14 @@ public TestConstraintNodeBuilder isInSet(Object... legalValues) {
return testConstraintNodeBuilder;
}

/**
 * Adds an in-set constraint for the field using explicitly weighted records,
 * then returns the enclosing node builder for chaining.
 */
public TestConstraintNodeBuilder isInSet(InSetRecord... weightedValues) {
    List<InSetRecord> legalValues = Stream.of(weightedValues).collect(Collectors.toList());
    testConstraintNodeBuilder.constraints.add(new InSetConstraint(field, legalValues));
    return testConstraintNodeBuilder;
}

public TestConstraintNodeBuilder isNotInSet(Object... legalValues) {
AtomicConstraint isInSetConstraint = new InSetConstraint(
field,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.scottlogic.datahelix.generator.core.decisiontree;

import com.scottlogic.datahelix.generator.common.distribution.WeightedElement;
import com.scottlogic.datahelix.generator.common.profile.Field;
import com.scottlogic.datahelix.generator.common.profile.Fields;
import com.scottlogic.datahelix.generator.common.profile.ProfileFields;
Expand All @@ -32,7 +33,6 @@
import com.scottlogic.datahelix.generator.core.walker.decisionbased.RowSpecTreeSolver;
import com.scottlogic.datahelix.generator.core.walker.decisionbased.SequentialOptionPicker;
import com.scottlogic.datahelix.generator.core.walker.pruner.TreePruner;
import org.junit.Assert;
import org.junit.jupiter.api.Test;

import java.util.ArrayList;
Expand All @@ -47,6 +47,7 @@
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static com.shazam.shazamcrest.MatcherAssert.assertThat;

class RowSpecTreeSolverTests {
private final FieldSpecMerger fieldSpecMerger = new FieldSpecMerger();
Expand Down Expand Up @@ -115,9 +116,10 @@ void test()

final List<RowSpec> rowSpecs = dTreeWalker
.createRowSpecs(merged)
.map(WeightedElement::element)
.collect(Collectors.toList());

Assert.assertThat(rowSpecs, notNullValue());
assertThat(rowSpecs, notNullValue());
}

}
Loading