This repository has been archived by the owner on Apr 14, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 50
#1704/#1705 weighted decision selection #1718
Draft
ghost
wants to merge
7
commits into
master
Choose a base branch
from
feature/1704-weighted-decision-selection
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
2fc7da4
chore(#1704): Update string representation in notEqual constraint
b3b6ec6
fix(#1704): Add weighted decision analysis to generation
9cf6fe5
chore(#1704): Add javaDoc to method
0d0e4cf
chore(#1704): Minor refactor from code review comments
0fef0a4
chore(#1704): Refactor tests to use helper methods to clarify intent
e94a048
chore(#1704): Refactor logic into specific functions
73bc213
chore(#1704): Add comments to explain the purpose of the method state…
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,15 +16,16 @@ | |
package com.scottlogic.datahelix.generator.core.walker.decisionbased; | ||
|
||
import com.google.inject.Inject; | ||
import com.scottlogic.datahelix.generator.common.distribution.WeightedElement; | ||
import com.scottlogic.datahelix.generator.common.profile.Field; | ||
import com.scottlogic.datahelix.generator.common.profile.Fields; | ||
import com.scottlogic.datahelix.generator.core.fieldspecs.FieldSpecFactory; | ||
import com.scottlogic.datahelix.generator.common.profile.InSetRecord; | ||
import com.scottlogic.datahelix.generator.core.fieldspecs.*; | ||
import com.scottlogic.datahelix.generator.core.profile.constraints.atomic.AtomicConstraint; | ||
import com.scottlogic.datahelix.generator.core.decisiontree.ConstraintNode; | ||
import com.scottlogic.datahelix.generator.core.decisiontree.DecisionNode; | ||
import com.scottlogic.datahelix.generator.core.decisiontree.DecisionTree; | ||
import com.scottlogic.datahelix.generator.core.fieldspecs.FieldSpec; | ||
import com.scottlogic.datahelix.generator.core.fieldspecs.RowSpec; | ||
import com.scottlogic.datahelix.generator.core.profile.constraints.atomic.InSetConstraint; | ||
import com.scottlogic.datahelix.generator.core.reducer.ConstraintReducer; | ||
import com.scottlogic.datahelix.generator.core.walker.pruner.Merged; | ||
import com.scottlogic.datahelix.generator.core.walker.pruner.TreePruner; | ||
|
@@ -51,45 +52,114 @@ public RowSpecTreeSolver(ConstraintReducer constraintReducer, | |
this.optionPicker = optionPicker; | ||
} | ||
|
||
public Stream<RowSpec> createRowSpecs(DecisionTree tree) { | ||
public Stream<WeightedElement<RowSpec>> createRowSpecs(DecisionTree tree) { | ||
return flatMap(reduceToRowNodes(tree.rootNode), | ||
rootNode -> toRowspec(tree.fields, rootNode)); | ||
} | ||
|
||
private Stream<RowSpec> toRowspec(Fields fields, ConstraintNode rootNode) { | ||
Optional<RowSpec> result = constraintReducer.reduceConstraintsToRowSpec(fields, rootNode); | ||
return result.map(Stream::of).orElseGet(Stream::empty); | ||
private Stream<WeightedElement<RowSpec>> toRowspec(Fields fields, WeightedElement<ConstraintNode> rootNode) { | ||
Optional<RowSpec> result = constraintReducer.reduceConstraintsToRowSpec(fields, rootNode.element()); | ||
return result | ||
.map(rowSpec -> new WeightedElement<>(rowSpec, rootNode.weight())) | ||
.map(Stream::of) | ||
.orElseGet(Stream::empty); | ||
} | ||
|
||
/** | ||
* a row node is a constraint node with no further decisions | ||
*/ | ||
private Stream<ConstraintNode> reduceToRowNodes(ConstraintNode rootNode) { | ||
if (rootNode.getDecisions().isEmpty()) { | ||
private Stream<WeightedElement<ConstraintNode>> reduceToRowNodes(ConstraintNode rootNode) { | ||
return reduceToRowNodes(new WeightedElement<>(rootNode, 1)); | ||
} | ||
|
||
private Stream<WeightedElement<ConstraintNode>> reduceToRowNodes(WeightedElement<ConstraintNode> rootNode) { | ||
if (rootNode.element().getDecisions().isEmpty()) { | ||
return Stream.of(rootNode); | ||
} | ||
|
||
DecisionNode decisionNode = optionPicker.pickDecision(rootNode); | ||
ConstraintNode rootWithoutDecision = rootNode.builder().removeDecision(decisionNode).build(); | ||
DecisionNode decisionNode = optionPicker.pickDecision(rootNode.element()); | ||
ConstraintNode rootWithoutDecision = rootNode.element().builder().removeDecision(decisionNode).build(); | ||
|
||
Stream<ConstraintNode> rootOnlyConstraintNodes = optionPicker.streamOptions(decisionNode) | ||
Stream<WeightedElement<ConstraintNode>> rootOnlyConstraintNodes = optionPicker.streamOptions(decisionNode) | ||
.map(option -> combineWithRootNode(rootWithoutDecision, option)) | ||
.filter(newNode -> !newNode.isContradictory()) | ||
.map(Merged::get); | ||
.filter(newNode -> !newNode.element().isContradictory()) | ||
.map(weighted -> new WeightedElement<>(weighted.element().get(), weighted.weight())); | ||
|
||
return flatMap( | ||
rootOnlyConstraintNodes, | ||
this::reduceToRowNodes); | ||
weightedConstraintNode -> reduceToRowNodes( | ||
new WeightedElement<ConstraintNode>( | ||
weightedConstraintNode.element(), | ||
weightedConstraintNode.weight() / rootNode.weight()))); | ||
} | ||
|
||
private Merged<ConstraintNode> combineWithRootNode(ConstraintNode rootNode, ConstraintNode option) { | ||
private WeightedElement<Merged<ConstraintNode>> combineWithRootNode(ConstraintNode rootNode, ConstraintNode option) { | ||
ConstraintNode constraintNode = rootNode.builder() | ||
.addDecisions(option.getDecisions()) | ||
.addAtomicConstraints(option.getAtomicConstraints()) | ||
.addRelations(option.getRelations()) | ||
.build(); | ||
|
||
return treePruner.pruneConstraintNode(constraintNode, getFields(option)); | ||
/* | ||
Find the relevance of this option in the context of the the tree. | ||
i.e. if this option says that a field must be equal to A in the possible set of A (20%), B (30%) and C (50%) yield the weighing of A (20%) | ||
applicabilityOfThisOption should yield 0.2 in this case (20% of 1) | ||
*/ | ||
double applicabilityOfThisOption = option.getAtomicConstraints().stream() | ||
.mapToDouble(optionAtomicConstraint -> rootNode.getAtomicConstraints().stream() | ||
.filter(rootAtomicConstraint -> rootAtomicConstraint.getField().equals(optionAtomicConstraint.getField())) | ||
.filter(rootAtomicConstraint -> rootAtomicConstraint instanceof InSetConstraint) | ||
.map(rootAtomicConstraint -> (InSetConstraint)rootAtomicConstraint) | ||
.findFirst() | ||
.map(matchingRootAtomicConstraint -> { | ||
double totalWeighting = getWeightOfAllLegalValues(matchingRootAtomicConstraint); | ||
double relevantWeighting = getWeightOfAllPermittedLegalValues(matchingRootAtomicConstraint, optionAtomicConstraint); | ||
|
||
return relevantWeighting / totalWeighting; | ||
Comment on lines
+115
to
+118
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe encapsulate this in a calculateApplicability method |
||
}) | ||
.orElse(1d)) | ||
.sum(); | ||
|
||
if (applicabilityOfThisOption > 1) { | ||
/* | ||
the applicability of this option (e.g. A) is greater than 100%, retrieve the fractional part of the number. | ||
if there is no fractional part, treat the applicability as 1 otherwise use the fractional part as the percentage. | ||
|
||
This can happen when other fields in the option match an inSet constraint, if they also have a weighting it | ||
will be taken into account, otherwise 1 (100%) would be used, hence 100% (field1) + 20% (field2) = 1.2. | ||
*/ | ||
|
||
double applicabilityFraction = applicabilityOfThisOption - (int) applicabilityOfThisOption; | ||
applicabilityOfThisOption = applicabilityFraction == 0 | ||
? 1 | ||
: applicabilityFraction; | ||
} | ||
|
||
if (applicabilityOfThisOption == 0){ | ||
/*no options were applicable, maybe an option was given 0%, yield contradictory so the option isn't processed.*/ | ||
|
||
return new WeightedElement<>( | ||
Merged.contradictory(), | ||
1 | ||
); | ||
} | ||
|
||
/*Yield a weighted element to inform the caller of the weighting of this constraint node*/ | ||
return new WeightedElement<>( | ||
treePruner.pruneConstraintNode(constraintNode, getFields(option)), | ||
applicabilityOfThisOption | ||
); | ||
} | ||
|
||
private static double getWeightOfAllLegalValues(InSetConstraint matchingRootAtomicConstraint){ | ||
return matchingRootAtomicConstraint.legalValues.stream() | ||
.mapToDouble(InSetRecord::getWeightValueOrDefault).sum(); | ||
} | ||
|
||
private static double getWeightOfAllPermittedLegalValues(InSetConstraint matchingRootAtomicConstraint, AtomicConstraint optionAtomicConstraint){ | ||
return matchingRootAtomicConstraint.legalValues.stream() | ||
.filter(legalValue -> optionAtomicConstraint.toFieldSpec().canCombineWithLegalValue(legalValue.getElement())) | ||
.mapToDouble(InSetRecord::getWeightValueOrDefault).sum(); | ||
} | ||
|
||
private Map<Field, FieldSpec> getFields(ConstraintNode option) { | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
package com.scottlogic.datahelix.generator.core.walker.rowspec; | ||
|
||
import com.google.inject.Inject; | ||
import com.scottlogic.datahelix.generator.common.distribution.WeightedElement; | ||
import com.scottlogic.datahelix.generator.core.decisiontree.DecisionTree; | ||
import com.scottlogic.datahelix.generator.core.fieldspecs.RowSpec; | ||
import com.scottlogic.datahelix.generator.core.generation.databags.DataBag; | ||
|
@@ -59,7 +60,7 @@ public Stream<DataBag> walk(DecisionTree tree) { | |
} | ||
|
||
private Stream<RowSpec> getFromCachedRowSpecs(DecisionTree tree) { | ||
List<RowSpec> rowSpecCache = rowSpecTreeSolver.createRowSpecs(tree).collect(Collectors.toList()); | ||
List<WeightedElement<RowSpec>> rowSpecCache = rowSpecTreeSolver.createRowSpecs(tree).collect(Collectors.toList()); | ||
return Stream.generate(() -> getRandomRowSpec(rowSpecCache)); | ||
} | ||
|
||
|
@@ -79,11 +80,43 @@ private Stream<RowSpec> getRowSpecAndRestart(DecisionTree tree) { | |
} | ||
|
||
private Optional<RowSpec> getFirstRowSpec(DecisionTree tree) { | ||
return rowSpecTreeSolver.createRowSpecs(tree).findFirst(); | ||
return rowSpecTreeSolver.createRowSpecs(tree).findFirst().map(WeightedElement::element); | ||
} | ||
|
||
private RowSpec getRandomRowSpec(List<RowSpec> rowSpecCache) { | ||
return rowSpecCache.get(random.nextInt(rowSpecCache.size())); | ||
/** | ||
* Get a row spec from the rowSpecCache in a weighted manner.<br /> | ||
* I.e. if there are 2 rowSpecs, one with a 70% weighting and the other with 30%.<br /> | ||
* Calling this method 10 times should *ROUGHLY* emit 7 of the 70% weighted rowSpecs and 3 of the others.<br /> | ||
* It does this by producing a virtual rowSpec 'range', i.e.<br /> | ||
* - values between 1 and 70 represent the 70% weighted rowSpec<br /> | ||
* - values between 71 and 100 represent the 30% weighted rowSpec<br /> | ||
* <br /> | ||
* The function then picks a random number between 1 and 100 and yields the rowSpec that encapsulates that value.<br /> | ||
* <br /> | ||
* As this method uses a random number generator, it will not ALWAYS yield a correct split, but it is more LIKELY than not.<br /> | ||
* @param rowSpecCache a list of weighted rowSpecs (weighting is between 0 and 1) | ||
* @return a rowSpec picked from the list of weighted rowSpecs | ||
*/ | ||
private RowSpec getRandomRowSpec(List<WeightedElement<RowSpec>> rowSpecCache) { | ||
double totalRange = rowSpecCache.stream() | ||
.mapToDouble(WeightedElement::weight).sum(); | ||
|
||
double nextRowSpecFromRange = random.nextInt((int)(totalRange * 100)) / 100d; | ||
|
||
double currentPosition = 0d; | ||
WeightedElement<RowSpec> lastRowSpec = null; | ||
|
||
for (WeightedElement<RowSpec> weightedRowSpec: rowSpecCache) { | ||
currentPosition += weightedRowSpec.weight(); | ||
|
||
if (currentPosition >= nextRowSpecFromRange) { | ||
return weightedRowSpec.element(); | ||
} | ||
|
||
lastRowSpec = weightedRowSpec; | ||
} | ||
|
||
return lastRowSpec.element(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we protect against a nullptr here in case lastRowSpec is null (eg rowSpecCache is empty could cause this). In which case return ??? (ideally some kind of 'empty' row spec rather than null - depends what calling code assumes). Or should we throw an exception as never expect this to be called with empty rowSepcCache. |
||
} | ||
|
||
private DataBag createDataBag(RowSpec rowSpec) { | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe make this more modular and change this to:
mapToDouble(optionAtomicConstraint -> findMatchingInSetConstraints(optionAtomicConstraint.getField(), rootNode.getAtomicConstraints())
and then define the method:
and then define the method:
Stream<InSetConstraint> findMatchingInSetConstraints(Field fieldToMatch, Collection<AtomicConstraint> atomicConstraints)