-
Notifications
You must be signed in to change notification settings - Fork 50
#1704/#1705 weighted decision selection #1718
base: master
Are you sure you want to change the base?
Changes from 3 commits
2fc7da4
b3b6ec6
9cf6fe5
0d0e4cf
0fef0a4
e94a048
73bc213
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,15 +16,16 @@ | |
package com.scottlogic.datahelix.generator.core.walker.decisionbased; | ||
|
||
import com.google.inject.Inject; | ||
import com.scottlogic.datahelix.generator.common.distribution.WeightedElement; | ||
import com.scottlogic.datahelix.generator.common.profile.Field; | ||
import com.scottlogic.datahelix.generator.common.profile.Fields; | ||
import com.scottlogic.datahelix.generator.core.fieldspecs.FieldSpecFactory; | ||
import com.scottlogic.datahelix.generator.common.profile.InSetRecord; | ||
import com.scottlogic.datahelix.generator.core.fieldspecs.*; | ||
import com.scottlogic.datahelix.generator.core.profile.constraints.atomic.AtomicConstraint; | ||
import com.scottlogic.datahelix.generator.core.decisiontree.ConstraintNode; | ||
import com.scottlogic.datahelix.generator.core.decisiontree.DecisionNode; | ||
import com.scottlogic.datahelix.generator.core.decisiontree.DecisionTree; | ||
import com.scottlogic.datahelix.generator.core.fieldspecs.FieldSpec; | ||
import com.scottlogic.datahelix.generator.core.fieldspecs.RowSpec; | ||
import com.scottlogic.datahelix.generator.core.profile.constraints.atomic.InSetConstraint; | ||
import com.scottlogic.datahelix.generator.core.reducer.ConstraintReducer; | ||
import com.scottlogic.datahelix.generator.core.walker.pruner.Merged; | ||
import com.scottlogic.datahelix.generator.core.walker.pruner.TreePruner; | ||
|
@@ -51,45 +52,91 @@ public RowSpecTreeSolver(ConstraintReducer constraintReducer, | |
this.optionPicker = optionPicker; | ||
} | ||
|
||
public Stream<RowSpec> createRowSpecs(DecisionTree tree) { | ||
public Stream<WeightedElement<RowSpec>> createRowSpecs(DecisionTree tree) { | ||
return flatMap(reduceToRowNodes(tree.rootNode), | ||
rootNode -> toRowspec(tree.fields, rootNode)); | ||
} | ||
|
||
private Stream<RowSpec> toRowspec(Fields fields, ConstraintNode rootNode) { | ||
Optional<RowSpec> result = constraintReducer.reduceConstraintsToRowSpec(fields, rootNode); | ||
return result.map(Stream::of).orElseGet(Stream::empty); | ||
private Stream<WeightedElement<RowSpec>> toRowspec(Fields fields, WeightedElement<ConstraintNode> rootNode) { | ||
Optional<RowSpec> result = constraintReducer.reduceConstraintsToRowSpec(fields, rootNode.element()); | ||
return result | ||
.map(rowSpec -> new WeightedElement<>(rowSpec, rootNode.weight())) | ||
.map(Stream::of) | ||
.orElseGet(Stream::empty); | ||
} | ||
|
||
/** | ||
* a row node is a constraint node with no further decisions | ||
*/ | ||
private Stream<ConstraintNode> reduceToRowNodes(ConstraintNode rootNode) { | ||
if (rootNode.getDecisions().isEmpty()) { | ||
private Stream<WeightedElement<ConstraintNode>> reduceToRowNodes(ConstraintNode rootNode) { | ||
return reduceToRowNodes(new WeightedElement<>(rootNode, 1)); | ||
} | ||
|
||
private Stream<WeightedElement<ConstraintNode>> reduceToRowNodes(WeightedElement<ConstraintNode> rootNode) { | ||
if (rootNode.element().getDecisions().isEmpty()) { | ||
return Stream.of(rootNode); | ||
} | ||
|
||
DecisionNode decisionNode = optionPicker.pickDecision(rootNode); | ||
ConstraintNode rootWithoutDecision = rootNode.builder().removeDecision(decisionNode).build(); | ||
DecisionNode decisionNode = optionPicker.pickDecision(rootNode.element()); | ||
ConstraintNode rootWithoutDecision = rootNode.element().builder().removeDecision(decisionNode).build(); | ||
|
||
Stream<ConstraintNode> rootOnlyConstraintNodes = optionPicker.streamOptions(decisionNode) | ||
Stream<WeightedElement<ConstraintNode>> rootOnlyConstraintNodes = optionPicker.streamOptions(decisionNode) | ||
.map(option -> combineWithRootNode(rootWithoutDecision, option)) | ||
.filter(newNode -> !newNode.isContradictory()) | ||
.map(Merged::get); | ||
.filter(newNode -> !newNode.element().isContradictory()) | ||
.map(weighted -> new WeightedElement<>(weighted.element().get(), weighted.weight())); | ||
|
||
return flatMap( | ||
rootOnlyConstraintNodes, | ||
this::reduceToRowNodes); | ||
weightedConstraintNode -> reduceToRowNodes( | ||
new WeightedElement<ConstraintNode>( | ||
weightedConstraintNode.element(), | ||
weightedConstraintNode.weight() / rootNode.weight()))); | ||
} | ||
|
||
private Merged<ConstraintNode> combineWithRootNode(ConstraintNode rootNode, ConstraintNode option) { | ||
private WeightedElement<Merged<ConstraintNode>> combineWithRootNode(ConstraintNode rootNode, ConstraintNode option) { | ||
ConstraintNode constraintNode = rootNode.builder() | ||
.addDecisions(option.getDecisions()) | ||
.addAtomicConstraints(option.getAtomicConstraints()) | ||
.addRelations(option.getRelations()) | ||
.build(); | ||
|
||
return treePruner.pruneConstraintNode(constraintNode, getFields(option)); | ||
double applicabilityOfThisOption = option.getAtomicConstraints().stream() | ||
.mapToDouble(optionAtomicConstraint -> rootNode.getAtomicConstraints().stream() | ||
.filter(rootAtomicConstraint -> rootAtomicConstraint.getField().equals(optionAtomicConstraint.getField())) | ||
.filter(rootAtomicConstraint -> rootAtomicConstraint instanceof InSetConstraint) | ||
.map(rootAtomicConstraint -> (InSetConstraint)rootAtomicConstraint) | ||
.findFirst() | ||
.map(matchingRootAtomicConstraint -> { | ||
double totalWeighting = matchingRootAtomicConstraint.legalValues.stream() | ||
.mapToDouble(InSetRecord::getWeightValueOrDefault).sum(); | ||
|
||
double relevantWeighting = matchingRootAtomicConstraint.legalValues.stream() | ||
.filter(legalValue -> optionAtomicConstraint.toFieldSpec().canCombineWithLegalValue(legalValue.getElement())) | ||
.mapToDouble(InSetRecord::getWeightValueOrDefault).sum(); | ||
|
||
return relevantWeighting / totalWeighting; | ||
}) | ||
.orElse(1d)) | ||
.sum(); | ||
|
||
if (applicabilityOfThisOption > 1){ | ||
double applicabilityFraction = applicabilityOfThisOption - (int) applicabilityOfThisOption; | ||
applicabilityOfThisOption = applicabilityFraction == 0 | ||
? 1 | ||
: applicabilityFraction; | ||
} | ||
|
||
if (applicabilityOfThisOption == 0){ | ||
return new WeightedElement<>( | ||
Merged.contradictory(), | ||
1 | ||
); | ||
} | ||
|
||
return new WeightedElement<>( | ||
treePruner.pruneConstraintNode(constraintNode, getFields(option)), | ||
applicabilityOfThisOption | ||
); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is very complicated. It possibly needs to be refactored into a few well-named methods. Also, as it's complicated, we might need a comment explaining what this thing is doing so someone who isn't familiar with this area can understand it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've added some comments to the code and refactored some of the calculations into functions; the complexity is somewhat unavoidable, I think. |
||
} | ||
|
||
private Map<Field, FieldSpec> getFields(ConstraintNode option) { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
package com.scottlogic.datahelix.generator.core.walker.rowspec; | ||
|
||
import com.google.inject.Inject; | ||
import com.scottlogic.datahelix.generator.common.distribution.WeightedElement; | ||
import com.scottlogic.datahelix.generator.core.decisiontree.DecisionTree; | ||
import com.scottlogic.datahelix.generator.core.fieldspecs.RowSpec; | ||
import com.scottlogic.datahelix.generator.core.generation.databags.DataBag; | ||
|
@@ -59,7 +60,7 @@ public Stream<DataBag> walk(DecisionTree tree) { | |
} | ||
|
||
private Stream<RowSpec> getFromCachedRowSpecs(DecisionTree tree) { | ||
List<RowSpec> rowSpecCache = rowSpecTreeSolver.createRowSpecs(tree).collect(Collectors.toList()); | ||
List<WeightedElement<RowSpec>> rowSpecCache = rowSpecTreeSolver.createRowSpecs(tree).collect(Collectors.toList()); | ||
return Stream.generate(() -> getRandomRowSpec(rowSpecCache)); | ||
} | ||
|
||
|
@@ -79,11 +80,46 @@ private Stream<RowSpec> getRowSpecAndRestart(DecisionTree tree) { | |
} | ||
|
||
private Optional<RowSpec> getFirstRowSpec(DecisionTree tree) { | ||
return rowSpecTreeSolver.createRowSpecs(tree).findFirst(); | ||
return rowSpecTreeSolver.createRowSpecs(tree).findFirst().map(WeightedElement::element); | ||
} | ||
|
||
private RowSpec getRandomRowSpec(List<RowSpec> rowSpecCache) { | ||
return rowSpecCache.get(random.nextInt(rowSpecCache.size())); | ||
/** | ||
* Get a row spec from the rowSpecCache in a weighted manner.<br /> | ||
* I.e. if there are 2 rowSpecs, one with a 70% weighting and the other with 30%.<br /> | ||
* Calling this method 10 times should *ROUGHLY* emit 7 of the 70% weighted rowSpecs and 3 of the others.<br /> | ||
* It does this by producing a virtual rowSpec 'range', i.e.<br /> | ||
* - values between 1 and 70 represent the 70% weighted rowSpec<br /> | ||
* - values between 71 and 100 represent the 30% weighted rowSpec<br /> | ||
* <br /> | ||
* The function then picks a random number between 1 and 100 and yields the rowSpec that encapsulates that value.<br /> | ||
* <br /> | ||
* As this method uses a random number generator, it will not ALWAYS yield a correct split, but it is more LIKELY than not.<br /> | ||
* @param rowSpecCache a list of weighted rowSpecs (weighting is between 0 and 1) | ||
* @return a rowSpec picked from the list of weighted rowSpecs | ||
*/ | ||
private RowSpec getRandomRowSpec(List<WeightedElement<RowSpec>> rowSpecCache) { | ||
double totalRange = rowSpecCache.stream() | ||
.mapToDouble(WeightedElement::weight).sum(); | ||
|
||
double nextRowSpecFromRange = random.nextInt((int)(totalRange * 100)) / 100d; | ||
|
||
double currentPosition = 0d; | ||
WeightedElement<RowSpec> lastRowSpec = null; | ||
|
||
for (WeightedElement<RowSpec> weightedRowSpec: rowSpecCache) { | ||
double thisRowSpecRange = weightedRowSpec.weight(); | ||
double newPosition = thisRowSpecRange + currentPosition; | ||
This conversation was marked as resolved.
Show resolved
Hide resolved
|
||
|
||
if (newPosition >= nextRowSpecFromRange) | ||
{ | ||
return weightedRowSpec.element(); | ||
} | ||
|
||
currentPosition = newPosition; | ||
lastRowSpec = weightedRowSpec; | ||
} | ||
|
||
return lastRowSpec.element(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we protect against a NullPointerException here in case lastRowSpec is null (e.g. an empty rowSpecCache could cause this)? In which case return ??? (ideally some kind of 'empty' row spec rather than null — depends what the calling code assumes). Or should we throw an exception, as we never expect this to be called with an empty rowSpecCache? |
||
} | ||
|
||
private DataBag createDataBag(RowSpec rowSpec) { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe make this more modular and change this to:
mapToDouble(optionAtomicConstraint -> findMatchingInSetConstraints(optionAtomicConstraint.getField(), rootNode.getAtomicConstraints()))
and then define the method:
Stream<InSetConstraint> findMatchingInSetConstraints(Field fieldToMatch, Collection<AtomicConstraint> atomicConstraints)