Skip to content
This repository has been archived by the owner on Apr 14, 2023. It is now read-only.

#1704/#1705 weighted decision selection #1718

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public FieldSpec toFieldSpec() {

@Override
public String toString(){
return String.format("`%s` = %s", field.getName(), value);
return String.format("`%s` != %s", field.getName(), value);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@
package com.scottlogic.datahelix.generator.core.walker.decisionbased;

import com.google.inject.Inject;
import com.scottlogic.datahelix.generator.common.distribution.WeightedElement;
import com.scottlogic.datahelix.generator.common.profile.Field;
import com.scottlogic.datahelix.generator.common.profile.Fields;
import com.scottlogic.datahelix.generator.core.fieldspecs.FieldSpecFactory;
import com.scottlogic.datahelix.generator.common.profile.InSetRecord;
import com.scottlogic.datahelix.generator.core.fieldspecs.*;
import com.scottlogic.datahelix.generator.core.profile.constraints.atomic.AtomicConstraint;
import com.scottlogic.datahelix.generator.core.decisiontree.ConstraintNode;
import com.scottlogic.datahelix.generator.core.decisiontree.DecisionNode;
import com.scottlogic.datahelix.generator.core.decisiontree.DecisionTree;
import com.scottlogic.datahelix.generator.core.fieldspecs.FieldSpec;
import com.scottlogic.datahelix.generator.core.fieldspecs.RowSpec;
import com.scottlogic.datahelix.generator.core.profile.constraints.atomic.InSetConstraint;
import com.scottlogic.datahelix.generator.core.reducer.ConstraintReducer;
import com.scottlogic.datahelix.generator.core.walker.pruner.Merged;
import com.scottlogic.datahelix.generator.core.walker.pruner.TreePruner;
Expand All @@ -51,45 +52,91 @@ public RowSpecTreeSolver(ConstraintReducer constraintReducer,
this.optionPicker = optionPicker;
}

public Stream<RowSpec> createRowSpecs(DecisionTree tree) {
public Stream<WeightedElement<RowSpec>> createRowSpecs(DecisionTree tree) {
return flatMap(reduceToRowNodes(tree.rootNode),
rootNode -> toRowspec(tree.fields, rootNode));
}

private Stream<RowSpec> toRowspec(Fields fields, ConstraintNode rootNode) {
Optional<RowSpec> result = constraintReducer.reduceConstraintsToRowSpec(fields, rootNode);
return result.map(Stream::of).orElseGet(Stream::empty);
private Stream<WeightedElement<RowSpec>> toRowspec(Fields fields, WeightedElement<ConstraintNode> rootNode) {
Optional<RowSpec> result = constraintReducer.reduceConstraintsToRowSpec(fields, rootNode.element());
return result
.map(rowSpec -> new WeightedElement<>(rowSpec, rootNode.weight()))
.map(Stream::of)
.orElseGet(Stream::empty);
}

/**
* a row node is a constraint node with no further decisions
*/
private Stream<ConstraintNode> reduceToRowNodes(ConstraintNode rootNode) {
if (rootNode.getDecisions().isEmpty()) {
private Stream<WeightedElement<ConstraintNode>> reduceToRowNodes(ConstraintNode rootNode) {
return reduceToRowNodes(new WeightedElement<>(rootNode, 1));
}

private Stream<WeightedElement<ConstraintNode>> reduceToRowNodes(WeightedElement<ConstraintNode> rootNode) {
if (rootNode.element().getDecisions().isEmpty()) {
return Stream.of(rootNode);
}

DecisionNode decisionNode = optionPicker.pickDecision(rootNode);
ConstraintNode rootWithoutDecision = rootNode.builder().removeDecision(decisionNode).build();
DecisionNode decisionNode = optionPicker.pickDecision(rootNode.element());
ConstraintNode rootWithoutDecision = rootNode.element().builder().removeDecision(decisionNode).build();

Stream<ConstraintNode> rootOnlyConstraintNodes = optionPicker.streamOptions(decisionNode)
Stream<WeightedElement<ConstraintNode>> rootOnlyConstraintNodes = optionPicker.streamOptions(decisionNode)
.map(option -> combineWithRootNode(rootWithoutDecision, option))
.filter(newNode -> !newNode.isContradictory())
.map(Merged::get);
.filter(newNode -> !newNode.element().isContradictory())
.map(weighted -> new WeightedElement<>(weighted.element().get(), weighted.weight()));

return flatMap(
rootOnlyConstraintNodes,
this::reduceToRowNodes);
weightedConstraintNode -> reduceToRowNodes(
new WeightedElement<ConstraintNode>(
weightedConstraintNode.element(),
weightedConstraintNode.weight() / rootNode.weight())));
}

private Merged<ConstraintNode> combineWithRootNode(ConstraintNode rootNode, ConstraintNode option) {
private WeightedElement<Merged<ConstraintNode>> combineWithRootNode(ConstraintNode rootNode, ConstraintNode option) {
ConstraintNode constraintNode = rootNode.builder()
.addDecisions(option.getDecisions())
.addAtomicConstraints(option.getAtomicConstraints())
.addRelations(option.getRelations())
.build();

return treePruner.pruneConstraintNode(constraintNode, getFields(option));
double applicabilityOfThisOption = option.getAtomicConstraints().stream()
.mapToDouble(optionAtomicConstraint -> rootNode.getAtomicConstraints().stream()
.filter(rootAtomicConstraint -> rootAtomicConstraint.getField().equals(optionAtomicConstraint.getField()))
.filter(rootAtomicConstraint -> rootAtomicConstraint instanceof InSetConstraint)
.map(rootAtomicConstraint -> (InSetConstraint)rootAtomicConstraint)
Comment on lines +109 to +112
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe make this more modular and change this to:
mapToDouble(optionAtomicConstraint -> findMatchingInSetConstraints(optionAtomicConstraint.getField(), rootNode.getAtomicConstraints())

and then define the method:
Stream findMatchingInSetConstraints(fieldtoMatch, atomicConstraints

.findFirst()
.map(matchingRootAtomicConstraint -> {
double totalWeighting = matchingRootAtomicConstraint.legalValues.stream()
.mapToDouble(InSetRecord::getWeightValueOrDefault).sum();

double relevantWeighting = matchingRootAtomicConstraint.legalValues.stream()
.filter(legalValue -> optionAtomicConstraint.toFieldSpec().canCombineWithLegalValue(legalValue.getElement()))
.mapToDouble(InSetRecord::getWeightValueOrDefault).sum();

return relevantWeighting / totalWeighting;
})
.orElse(1d))
.sum();

if (applicabilityOfThisOption > 1){
double applicabilityFraction = applicabilityOfThisOption - (int) applicabilityOfThisOption;
applicabilityOfThisOption = applicabilityFraction == 0
? 1
: applicabilityFraction;
}

if (applicabilityOfThisOption == 0){
return new WeightedElement<>(
Merged.contradictory(),
1
);
}

return new WeightedElement<>(
treePruner.pruneConstraintNode(constraintNode, getFields(option)),
applicabilityOfThisOption
);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very complicated. Possibly needs refactored into a few well named methods. Also as its complicated we might need a comment explaining what this thing is doing so someone who isn't familiar with this area can understand it.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added some comments to the code and refactored some of the calculations into functions, the complexity is somewhat unavoidable i think.

}

private Map<Field, FieldSpec> getFields(ConstraintNode option) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.scottlogic.datahelix.generator.core.walker.rowspec;

import com.google.inject.Inject;
import com.scottlogic.datahelix.generator.common.distribution.WeightedElement;
import com.scottlogic.datahelix.generator.core.decisiontree.DecisionTree;
import com.scottlogic.datahelix.generator.core.fieldspecs.RowSpec;
import com.scottlogic.datahelix.generator.core.generation.databags.DataBag;
Expand Down Expand Up @@ -59,7 +60,7 @@ public Stream<DataBag> walk(DecisionTree tree) {
}

private Stream<RowSpec> getFromCachedRowSpecs(DecisionTree tree) {
List<RowSpec> rowSpecCache = rowSpecTreeSolver.createRowSpecs(tree).collect(Collectors.toList());
List<WeightedElement<RowSpec>> rowSpecCache = rowSpecTreeSolver.createRowSpecs(tree).collect(Collectors.toList());
return Stream.generate(() -> getRandomRowSpec(rowSpecCache));
}

Expand All @@ -79,11 +80,46 @@ private Stream<RowSpec> getRowSpecAndRestart(DecisionTree tree) {
}

private Optional<RowSpec> getFirstRowSpec(DecisionTree tree) {
return rowSpecTreeSolver.createRowSpecs(tree).findFirst();
return rowSpecTreeSolver.createRowSpecs(tree).findFirst().map(WeightedElement::element);
}

private RowSpec getRandomRowSpec(List<RowSpec> rowSpecCache) {
return rowSpecCache.get(random.nextInt(rowSpecCache.size()));
/**
* Get a row spec from the rowSpecCache in a weighted manner.<br />
* I.e. if there are 2 rowSpecs, one with a 70% weighting and the other with 30%.<br />
* Calling this method 10 times should *ROUGHLY* emit 7 of the 70% weighted rowSpecs and 3 of the others.<br />
* It does this by producing a virtual rowSpec 'range', i.e.<br />
* - values between 1 and 70 represent the 70% weighted rowSpec<br />
* - values between 71 and 100 represent the 30% weighted rowSpec<br />
* <br />
* The function then picks a random number between 1 and 100 and yields the rowSpec that encapsulates that value.<br />
* <br />
* As this method uses a random number generator, it will not ALWAYS yield a correct split, but it is more LIKELY than not.<br />
* @param rowSpecCache a list of weighted rowSpecs (weighting is between 0 and 1)
* @return a rowSpec picked from the list of weighted rowSpecs
*/
private RowSpec getRandomRowSpec(List<WeightedElement<RowSpec>> rowSpecCache) {
double totalRange = rowSpecCache.stream()
.mapToDouble(WeightedElement::weight).sum();

double nextRowSpecFromRange = random.nextInt((int)(totalRange * 100)) / 100d;

double currentPosition = 0d;
WeightedElement<RowSpec> lastRowSpec = null;

for (WeightedElement<RowSpec> weightedRowSpec: rowSpecCache) {
double thisRowSpecRange = weightedRowSpec.weight();
double newPosition = thisRowSpecRange + currentPosition;
This conversation was marked as resolved.
Show resolved Hide resolved

if (newPosition >= nextRowSpecFromRange)
{
return weightedRowSpec.element();
}

currentPosition = newPosition;
lastRowSpec = weightedRowSpec;
}

return lastRowSpec.element();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we protect against a nullptr here in case lastRowSpec is null (eg rowSpecCache is empty could cause this). In which case return ??? (ideally some kind of 'empty' row spec rather than null - depends what calling code assumes). Or should we throw an exception as never expect this to be called with empty rowSepcCache.

}

private DataBag createDataBag(RowSpec rowSpec) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package com.scottlogic.datahelix.generator.core.walker.rowspec;

import com.google.inject.Inject;
import com.scottlogic.datahelix.generator.common.distribution.WeightedElement;
import com.scottlogic.datahelix.generator.common.util.FlatMappingSpliterator;
import com.scottlogic.datahelix.generator.core.decisiontree.DecisionTree;
import com.scottlogic.datahelix.generator.core.generation.databags.DataBag;
Expand All @@ -38,7 +39,7 @@ public RowSpecDecisionTreeWalker(RowSpecTreeSolver rowSpecTreeSolver, RowSpecDat
@Override
public Stream<DataBag> walk(DecisionTree tree) {
return FlatMappingSpliterator.flatMap(
rowSpecTreeSolver.createRowSpecs(tree),
rowSpecTreeSolver.createRowSpecs(tree).map(WeightedElement::element),
rowSpecDataBagGenerator::createDataBags);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class TestAtomicConstraintBuilder {
private TestConstraintNodeBuilder testConstraintNodeBuilder;
Expand Down Expand Up @@ -56,6 +57,14 @@ public TestConstraintNodeBuilder isInSet(Object... legalValues) {
return testConstraintNodeBuilder;
}

public TestConstraintNodeBuilder isInSet(InSetRecord... weightedValues) {
InSetConstraint inSetConstraint = new InSetConstraint(
field,
Stream.of(weightedValues).collect(Collectors.toList()));
testConstraintNodeBuilder.constraints.add(inSetConstraint);
return testConstraintNodeBuilder;
}

public TestConstraintNodeBuilder isNotInSet(Object... legalValues) {
AtomicConstraint isInSetConstraint = new InSetConstraint(
field,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.scottlogic.datahelix.generator.core.decisiontree;

import com.scottlogic.datahelix.generator.common.distribution.WeightedElement;
import com.scottlogic.datahelix.generator.common.profile.Field;
import com.scottlogic.datahelix.generator.common.profile.Fields;
import com.scottlogic.datahelix.generator.common.profile.ProfileFields;
Expand All @@ -32,7 +33,6 @@
import com.scottlogic.datahelix.generator.core.walker.decisionbased.RowSpecTreeSolver;
import com.scottlogic.datahelix.generator.core.walker.decisionbased.SequentialOptionPicker;
import com.scottlogic.datahelix.generator.core.walker.pruner.TreePruner;
import org.junit.Assert;
import org.junit.jupiter.api.Test;

import java.util.ArrayList;
Expand All @@ -47,6 +47,7 @@
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static com.shazam.shazamcrest.MatcherAssert.assertThat;

class RowSpecTreeSolverTests {
private final FieldSpecMerger fieldSpecMerger = new FieldSpecMerger();
Expand Down Expand Up @@ -115,9 +116,10 @@ void test()

final List<RowSpec> rowSpecs = dTreeWalker
.createRowSpecs(merged)
.map(WeightedElement::element)
.collect(Collectors.toList());

Assert.assertThat(rowSpecs, notNullValue());
assertThat(rowSpecs, notNullValue());
}

}
Loading