diff --git a/core/src/main/java/com/scottlogic/datahelix/generator/core/profile/constraints/atomic/NotEqualToConstraint.java b/core/src/main/java/com/scottlogic/datahelix/generator/core/profile/constraints/atomic/NotEqualToConstraint.java index 90c39a3a2..da1b54185 100644 --- a/core/src/main/java/com/scottlogic/datahelix/generator/core/profile/constraints/atomic/NotEqualToConstraint.java +++ b/core/src/main/java/com/scottlogic/datahelix/generator/core/profile/constraints/atomic/NotEqualToConstraint.java @@ -48,7 +48,7 @@ public FieldSpec toFieldSpec() { @Override public String toString(){ - return String.format("`%s` = %s", field.getName(), value); + return String.format("`%s` != %s", field.getName(), value); } @Override diff --git a/core/src/main/java/com/scottlogic/datahelix/generator/core/walker/decisionbased/RowSpecTreeSolver.java b/core/src/main/java/com/scottlogic/datahelix/generator/core/walker/decisionbased/RowSpecTreeSolver.java index 013c06c4f..55de1669c 100644 --- a/core/src/main/java/com/scottlogic/datahelix/generator/core/walker/decisionbased/RowSpecTreeSolver.java +++ b/core/src/main/java/com/scottlogic/datahelix/generator/core/walker/decisionbased/RowSpecTreeSolver.java @@ -16,15 +16,16 @@ package com.scottlogic.datahelix.generator.core.walker.decisionbased; import com.google.inject.Inject; +import com.scottlogic.datahelix.generator.common.distribution.WeightedElement; import com.scottlogic.datahelix.generator.common.profile.Field; import com.scottlogic.datahelix.generator.common.profile.Fields; -import com.scottlogic.datahelix.generator.core.fieldspecs.FieldSpecFactory; +import com.scottlogic.datahelix.generator.common.profile.InSetRecord; +import com.scottlogic.datahelix.generator.core.fieldspecs.*; import com.scottlogic.datahelix.generator.core.profile.constraints.atomic.AtomicConstraint; import com.scottlogic.datahelix.generator.core.decisiontree.ConstraintNode; import com.scottlogic.datahelix.generator.core.decisiontree.DecisionNode; import com.scottlogic.datahelix.generator.core.decisiontree.DecisionTree; -import com.scottlogic.datahelix.generator.core.fieldspecs.FieldSpec; -import com.scottlogic.datahelix.generator.core.fieldspecs.RowSpec; +import com.scottlogic.datahelix.generator.core.profile.constraints.atomic.InSetConstraint; import com.scottlogic.datahelix.generator.core.reducer.ConstraintReducer; import com.scottlogic.datahelix.generator.core.walker.pruner.Merged; import com.scottlogic.datahelix.generator.core.walker.pruner.TreePruner; @@ -51,45 +52,114 @@ public RowSpecTreeSolver(ConstraintReducer constraintReducer, this.optionPicker = optionPicker; } - public Stream createRowSpecs(DecisionTree tree) { + public Stream> createRowSpecs(DecisionTree tree) { return flatMap(reduceToRowNodes(tree.rootNode), rootNode -> toRowspec(tree.fields, rootNode)); } - private Stream toRowspec(Fields fields, ConstraintNode rootNode) { - Optional result = constraintReducer.reduceConstraintsToRowSpec(fields, rootNode); - return result.map(Stream::of).orElseGet(Stream::empty); + private Stream> toRowspec(Fields fields, WeightedElement rootNode) { + Optional result = constraintReducer.reduceConstraintsToRowSpec(fields, rootNode.element()); + return result + .map(rowSpec -> new WeightedElement<>(rowSpec, rootNode.weight())) + .map(Stream::of) + .orElseGet(Stream::empty); } /** * a row node is a constraint node with no further decisions */ - private Stream reduceToRowNodes(ConstraintNode rootNode) { - if (rootNode.getDecisions().isEmpty()) { + private Stream> reduceToRowNodes(ConstraintNode rootNode) { + return reduceToRowNodes(new WeightedElement<>(rootNode, 1)); + } + + private Stream> reduceToRowNodes(WeightedElement rootNode) { + if (rootNode.element().getDecisions().isEmpty()) { return Stream.of(rootNode); } - DecisionNode decisionNode = optionPicker.pickDecision(rootNode); - ConstraintNode rootWithoutDecision = rootNode.builder().removeDecision(decisionNode).build(); + DecisionNode decisionNode = optionPicker.pickDecision(rootNode.element()); + ConstraintNode rootWithoutDecision = rootNode.element().builder().removeDecision(decisionNode).build(); - Stream rootOnlyConstraintNodes = optionPicker.streamOptions(decisionNode) + Stream> rootOnlyConstraintNodes = optionPicker.streamOptions(decisionNode) .map(option -> combineWithRootNode(rootWithoutDecision, option)) - .filter(newNode -> !newNode.isContradictory()) - .map(Merged::get); + .filter(newNode -> !newNode.element().isContradictory()) + .map(weighted -> new WeightedElement<>(weighted.element().get(), weighted.weight())); return flatMap( rootOnlyConstraintNodes, - this::reduceToRowNodes); + weightedConstraintNode -> reduceToRowNodes( + new WeightedElement( + weightedConstraintNode.element(), + weightedConstraintNode.weight() / rootNode.weight()))); } - private Merged combineWithRootNode(ConstraintNode rootNode, ConstraintNode option) { + private WeightedElement> combineWithRootNode(ConstraintNode rootNode, ConstraintNode option) { ConstraintNode constraintNode = rootNode.builder() .addDecisions(option.getDecisions()) .addAtomicConstraints(option.getAtomicConstraints()) .addRelations(option.getRelations()) .build(); - return treePruner.pruneConstraintNode(constraintNode, getFields(option)); + /* + Find the relevance of this option in the context of the the tree. + i.e. if this option says that a field must be equal to A in the possible set of A (20%), B (30%) and C (50%) yield the weighing of A (20%) + applicabilityOfThisOption should yield 0.2 in this case (20% of 1) + */ + double applicabilityOfThisOption = option.getAtomicConstraints().stream() + .mapToDouble(optionAtomicConstraint -> rootNode.getAtomicConstraints().stream() + .filter(rootAtomicConstraint -> rootAtomicConstraint.getField().equals(optionAtomicConstraint.getField())) + .filter(rootAtomicConstraint -> rootAtomicConstraint instanceof InSetConstraint) + .map(rootAtomicConstraint -> (InSetConstraint)rootAtomicConstraint) + .findFirst() + .map(matchingRootAtomicConstraint -> { + double totalWeighting = getWeightOfAllLegalValues(matchingRootAtomicConstraint); + double relevantWeighting = getWeightOfAllPermittedLegalValues(matchingRootAtomicConstraint, optionAtomicConstraint); + + return relevantWeighting / totalWeighting; + }) + .orElse(1d)) + .sum(); + + if (applicabilityOfThisOption > 1) { + /* + the applicability of this option (e.g. A) is greater than 100%, retrieve the fractional part of the number. + if there is no fractional part, treat the applicability as 1 otherwise use the fractional part as the percentage. + + This can happen when other fields in the option match an inSet constraint, if they also have a weighting it + will be taken into account, otherwise 1 (100%) would be used, hence 100% (field1) + 20% (field2) = 1.2. + */ + + double applicabilityFraction = applicabilityOfThisOption - (int) applicabilityOfThisOption; + applicabilityOfThisOption = applicabilityFraction == 0 + ? 1 + : applicabilityFraction; + } + + if (applicabilityOfThisOption == 0){ + /*no options were applicable, maybe an option was given 0%, yield contradictory so the option isn't processed.*/ + + return new WeightedElement<>( + Merged.contradictory(), + 1 + ); + } + + /*Yield a weighted element to inform the caller of the weighting of this constraint node*/ + return new WeightedElement<>( + treePruner.pruneConstraintNode(constraintNode, getFields(option)), + applicabilityOfThisOption + ); + } + + private static double getWeightOfAllLegalValues(InSetConstraint matchingRootAtomicConstraint){ + return matchingRootAtomicConstraint.legalValues.stream() + .mapToDouble(InSetRecord::getWeightValueOrDefault).sum(); + } + + private static double getWeightOfAllPermittedLegalValues(InSetConstraint matchingRootAtomicConstraint, AtomicConstraint optionAtomicConstraint){ + return matchingRootAtomicConstraint.legalValues.stream() + .filter(legalValue -> optionAtomicConstraint.toFieldSpec().canCombineWithLegalValue(legalValue.getElement())) + .mapToDouble(InSetRecord::getWeightValueOrDefault).sum(); } private Map getFields(ConstraintNode option) { diff --git a/core/src/main/java/com/scottlogic/datahelix/generator/core/walker/rowspec/RandomRowSpecDecisionTreeWalker.java b/core/src/main/java/com/scottlogic/datahelix/generator/core/walker/rowspec/RandomRowSpecDecisionTreeWalker.java index 7b3e51a15..cf3da1f38 100644 --- a/core/src/main/java/com/scottlogic/datahelix/generator/core/walker/rowspec/RandomRowSpecDecisionTreeWalker.java +++ b/core/src/main/java/com/scottlogic/datahelix/generator/core/walker/rowspec/RandomRowSpecDecisionTreeWalker.java @@ -17,6 +17,7 @@ package com.scottlogic.datahelix.generator.core.walker.rowspec; import com.google.inject.Inject; +import com.scottlogic.datahelix.generator.common.distribution.WeightedElement; import com.scottlogic.datahelix.generator.core.decisiontree.DecisionTree; import com.scottlogic.datahelix.generator.core.fieldspecs.RowSpec; import com.scottlogic.datahelix.generator.core.generation.databags.DataBag; @@ -59,7 +60,7 @@ public Stream walk(DecisionTree tree) { } private Stream getFromCachedRowSpecs(DecisionTree tree) { - List rowSpecCache = rowSpecTreeSolver.createRowSpecs(tree).collect(Collectors.toList()); + List> rowSpecCache = rowSpecTreeSolver.createRowSpecs(tree).collect(Collectors.toList()); return Stream.generate(() -> getRandomRowSpec(rowSpecCache)); } @@ -79,11 +80,43 @@ private Stream getRowSpecAndRestart(DecisionTree tree) { } private Optional getFirstRowSpec(DecisionTree tree) { - return rowSpecTreeSolver.createRowSpecs(tree).findFirst(); + return rowSpecTreeSolver.createRowSpecs(tree).findFirst().map(WeightedElement::element); } - private RowSpec getRandomRowSpec(List rowSpecCache) { - return rowSpecCache.get(random.nextInt(rowSpecCache.size())); + /** + * Get a row spec from the rowSpecCache in a weighted manner.
+ * I.e. if there are 2 rowSpecs, one with a 70% weighting and the other with 30%.
+ * Calling this method 10 times should *ROUGHLY* emit 7 of the 70% weighted rowSpecs and 3 of the others.
+ * It does this by producing a virtual rowSpec 'range', i.e.
+ * - values between 1 and 70 represent the 70% weighted rowSpec
+ * - values between 71 and 100 represent the 30% weighted rowSpec
+ *
+ * The function then picks a random number between 1 and 100 and yields the rowSpec that encapsulates that value.
+ *
+ * As this method uses a random number generator, it will not ALWAYS yield a correct split, but it is more LIKELY than not.
+ * @param rowSpecCache a list of weighted rowSpecs (weighting is between 0 and 1) + * @return a rowSpec picked from the list of weighted rowSpecs + */ + private RowSpec getRandomRowSpec(List> rowSpecCache) { + double totalRange = rowSpecCache.stream() + .mapToDouble(WeightedElement::weight).sum(); + + double nextRowSpecFromRange = random.nextInt((int)(totalRange * 100)) / 100d; + + double currentPosition = 0d; + WeightedElement lastRowSpec = null; + + for (WeightedElement weightedRowSpec: rowSpecCache) { + currentPosition += weightedRowSpec.weight(); + + if (currentPosition >= nextRowSpecFromRange) { + return weightedRowSpec.element(); + } + + lastRowSpec = weightedRowSpec; + } + + return lastRowSpec.element(); } private DataBag createDataBag(RowSpec rowSpec) { diff --git a/core/src/main/java/com/scottlogic/datahelix/generator/core/walker/rowspec/RowSpecDecisionTreeWalker.java b/core/src/main/java/com/scottlogic/datahelix/generator/core/walker/rowspec/RowSpecDecisionTreeWalker.java index 7601344b3..4c3819e49 100644 --- a/core/src/main/java/com/scottlogic/datahelix/generator/core/walker/rowspec/RowSpecDecisionTreeWalker.java +++ b/core/src/main/java/com/scottlogic/datahelix/generator/core/walker/rowspec/RowSpecDecisionTreeWalker.java @@ -16,6 +16,7 @@ package com.scottlogic.datahelix.generator.core.walker.rowspec; import com.google.inject.Inject; +import com.scottlogic.datahelix.generator.common.distribution.WeightedElement; import com.scottlogic.datahelix.generator.common.util.FlatMappingSpliterator; import com.scottlogic.datahelix.generator.core.decisiontree.DecisionTree; import com.scottlogic.datahelix.generator.core.generation.databags.DataBag; @@ -38,7 +39,7 @@ public RowSpecDecisionTreeWalker(RowSpecTreeSolver rowSpecTreeSolver, RowSpecDat @Override public Stream walk(DecisionTree tree) { return FlatMappingSpliterator.flatMap( - rowSpecTreeSolver.createRowSpecs(tree), + rowSpecTreeSolver.createRowSpecs(tree).map(WeightedElement::element), rowSpecDataBagGenerator::createDataBags); } } diff --git a/core/src/test/java/com/scottlogic/datahelix/generator/core/builders/TestAtomicConstraintBuilder.java b/core/src/test/java/com/scottlogic/datahelix/generator/core/builders/TestAtomicConstraintBuilder.java index cbc620d75..11b68a953 100644 --- a/core/src/test/java/com/scottlogic/datahelix/generator/core/builders/TestAtomicConstraintBuilder.java +++ b/core/src/test/java/com/scottlogic/datahelix/generator/core/builders/TestAtomicConstraintBuilder.java @@ -26,6 +26,7 @@ import java.util.List; import java.util.Set; import java.util.stream.Collectors; +import java.util.stream.Stream; public class TestAtomicConstraintBuilder { private TestConstraintNodeBuilder testConstraintNodeBuilder; @@ -56,6 +57,14 @@ public TestConstraintNodeBuilder isInSet(Object... legalValues) { return testConstraintNodeBuilder; } + public TestConstraintNodeBuilder isInSet(InSetRecord... weightedValues) { + InSetConstraint inSetConstraint = new InSetConstraint( + field, + Stream.of(weightedValues).collect(Collectors.toList())); + testConstraintNodeBuilder.constraints.add(inSetConstraint); + return testConstraintNodeBuilder; + } + public TestConstraintNodeBuilder isNotInSet(Object... legalValues) { AtomicConstraint isInSetConstraint = new InSetConstraint( field, diff --git a/core/src/test/java/com/scottlogic/datahelix/generator/core/decisiontree/RowSpecTreeSolverTests.java b/core/src/test/java/com/scottlogic/datahelix/generator/core/decisiontree/RowSpecTreeSolverTests.java index 78e39520f..9d794c21a 100644 --- a/core/src/test/java/com/scottlogic/datahelix/generator/core/decisiontree/RowSpecTreeSolverTests.java +++ b/core/src/test/java/com/scottlogic/datahelix/generator/core/decisiontree/RowSpecTreeSolverTests.java @@ -16,6 +16,7 @@ package com.scottlogic.datahelix.generator.core.decisiontree; +import com.scottlogic.datahelix.generator.common.distribution.WeightedElement; import com.scottlogic.datahelix.generator.common.profile.Field; import com.scottlogic.datahelix.generator.common.profile.Fields; import com.scottlogic.datahelix.generator.common.profile.ProfileFields; @@ -32,7 +33,6 @@ import com.scottlogic.datahelix.generator.core.walker.decisionbased.RowSpecTreeSolver; import com.scottlogic.datahelix.generator.core.walker.decisionbased.SequentialOptionPicker; import com.scottlogic.datahelix.generator.core.walker.pruner.TreePruner; -import org.junit.Assert; import org.junit.jupiter.api.Test; import java.util.ArrayList; @@ -47,6 +47,7 @@ import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static com.shazam.shazamcrest.MatcherAssert.assertThat; class RowSpecTreeSolverTests { private final FieldSpecMerger fieldSpecMerger = new FieldSpecMerger(); @@ -115,9 +116,10 @@ void test() final List rowSpecs = dTreeWalker .createRowSpecs(merged) + .map(WeightedElement::element) .collect(Collectors.toList()); - Assert.assertThat(rowSpecs, notNullValue()); + assertThat(rowSpecs, notNullValue()); } } diff --git a/core/src/test/java/com/scottlogic/datahelix/generator/core/walker/decisionbased/RowSpecTreeSolverTests.java b/core/src/test/java/com/scottlogic/datahelix/generator/core/walker/decisionbased/RowSpecTreeSolverTests.java index 8be26d51f..5e8bd0a33 100644 --- a/core/src/test/java/com/scottlogic/datahelix/generator/core/walker/decisionbased/RowSpecTreeSolverTests.java +++ b/core/src/test/java/com/scottlogic/datahelix/generator/core/walker/decisionbased/RowSpecTreeSolverTests.java @@ -15,8 +15,11 @@ */ package com.scottlogic.datahelix.generator.core.walker.decisionbased; +import com.scottlogic.datahelix.generator.common.distribution.DistributedList; +import com.scottlogic.datahelix.generator.common.distribution.WeightedElement; import com.scottlogic.datahelix.generator.common.profile.Field; import com.scottlogic.datahelix.generator.common.profile.Fields; +import com.scottlogic.datahelix.generator.common.profile.InSetRecord; import com.scottlogic.datahelix.generator.common.profile.ProfileFields; import com.scottlogic.datahelix.generator.core.builders.TestConstraintNodeBuilder; import com.scottlogic.datahelix.generator.core.decisiontree.ConstraintNode; @@ -33,6 +36,7 @@ import static com.scottlogic.datahelix.generator.common.profile.FieldBuilder.createField; import static com.shazam.shazamcrest.MatcherAssert.assertThat; import static com.shazam.shazamcrest.matcher.Matchers.sameBeanAs; +import static org.hamcrest.core.Is.is; class RowSpecTreeSolverTests { private Field fieldA = createField("A"); @@ -51,16 +55,17 @@ void createRowSpecs_whenRootNodeHasNoDecisions_returnsRowSpecOfRoot() { DecisionTree tree = new DecisionTree(root, fields); //Act - Stream rowSpecs = rowSpecTreeSolver.createRowSpecs(tree); + Stream> rowSpecs = rowSpecTreeSolver.createRowSpecs(tree); //Assert List expectedRowSpecs = new ArrayList<>(); - Map fieldToFieldSpec = new HashMap<>(); - fieldToFieldSpec.put(fieldA, FieldSpecFactory.fromType(fieldA.getType())); - fieldToFieldSpec.put(fieldB, FieldSpecFactory.fromType(fieldB.getType())); - expectedRowSpecs.add(new RowSpec(fields, fieldToFieldSpec, Collections.emptyList())); + expectedRowSpecs.add(createRowSpec( + FieldSpecFactory.fromType(fieldA.getType()), + FieldSpecFactory.fromType(fieldB.getType()))); - assertThat(expectedRowSpecs, sameBeanAs(rowSpecs.collect(Collectors.toList()))); + assertThat( + expectedRowSpecs, + sameBeanAs(rowSpecs.map(WeightedElement::element).collect(Collectors.toList()))); } @Test @@ -70,16 +75,17 @@ void createRowSpecs_whenRootNodeHasNoDecisionsButSomeConstraints_returnsRowSpecO DecisionTree tree = new DecisionTree(root, fields); //Act - Stream rowSpecs = rowSpecTreeSolver.createRowSpecs(tree); + Stream> rowSpecs = rowSpecTreeSolver.createRowSpecs(tree); //Assert List expectedRowSpecs = new ArrayList<>(); - Map fieldToFieldSpec = new HashMap<>(); - fieldToFieldSpec.put(fieldA, FieldSpecFactory.fromLegalValuesList(Arrays.asList("1", "2", "3"))); - fieldToFieldSpec.put(fieldB, FieldSpecFactory.fromType(fieldB.getType())); - expectedRowSpecs.add(new RowSpec(fields, fieldToFieldSpec, Collections.emptyList())); + expectedRowSpecs.add(createRowSpec( + FieldSpecFactory.fromLegalValuesList(Arrays.asList("1", "2", "3")), + FieldSpecFactory.fromType(fieldB.getType()))); - assertThat(rowSpecs.collect(Collectors.toList()), sameBeanAs(expectedRowSpecs)); + assertThat( + rowSpecs.map(WeightedElement::element).collect(Collectors.toList()), + sameBeanAs(expectedRowSpecs)); } @Test @@ -95,19 +101,162 @@ void createRowSpecs_whenRootNodeHasSomeDecisions_returnsRowSpecOfRoot() { DecisionTree tree = new DecisionTree(root, fields); //Act - Set rowSpecs = rowSpecTreeSolver.createRowSpecs(tree).collect(Collectors.toSet()); + Set> rowSpecs = rowSpecTreeSolver.createRowSpecs(tree).collect(Collectors.toSet()); //Assert Set expectedRowSpecs = new HashSet<>(); - Map option0 = new HashMap<>(); - option0.put(fieldA, FieldSpecFactory.fromType(fieldA.getType())); - option0.put(fieldB, FieldSpecFactory.nullOnly()); - expectedRowSpecs.add(new RowSpec(fields, option0, Collections.emptyList())); - Map option1 = new HashMap<>(); - option1.put(fieldA, FieldSpecFactory.fromType(fieldA.getType())); - option1.put(fieldB, FieldSpecFactory.fromLegalValuesList(Arrays.asList("1","2","3"))); - expectedRowSpecs.add(new RowSpec(fields, option1, Collections.emptyList())); - - assertThat(rowSpecs, sameBeanAs(expectedRowSpecs)); + expectedRowSpecs.add(createRowSpec( + FieldSpecFactory.fromType(fieldA.getType()), + FieldSpecFactory.nullOnly())); + expectedRowSpecs.add(createRowSpec( + FieldSpecFactory.fromType(fieldA.getType()), + FieldSpecFactory.fromLegalValuesList(Arrays.asList("1","2","3")))); + + assertThat( + rowSpecs.stream().map(WeightedElement::element).collect(Collectors.toSet()), + sameBeanAs(expectedRowSpecs)); + } + + @Test + void createRowSpecs_whenRootNodeHasSomeWeightedDecisions_returnsCorrectlyWeightedRowSpecs() { + //Arrange + ConstraintNode root = TestConstraintNodeBuilder.constraintNode() + .where(fieldB).isInSet( + new InSetRecord("1", 0.25d), + new InSetRecord("2", 0.75d)) + .withDecision( + TestConstraintNodeBuilder.constraintNode() + .where(fieldB).isInSet("1"), + TestConstraintNodeBuilder.constraintNode() + .where(fieldB).isInSet("2")) + .build(); + DecisionTree tree = new DecisionTree(root, fields); + + //Act + Set> rowSpecs = rowSpecTreeSolver.createRowSpecs(tree).collect(Collectors.toSet()); + + //Assert + Set> expectedRowSpecs = new HashSet<>(); + expectedRowSpecs.add(new WeightedElement<>( + createRowSpec( + FieldSpecFactory.fromType(fieldA.getType()), + FieldSpecFactory.fromList(distributedListOfOneItem("1", 0.25))), + 0.25)); + expectedRowSpecs.add(new WeightedElement<>( + createRowSpec( + FieldSpecFactory.fromType(fieldA.getType()), + FieldSpecFactory.fromList(distributedListOfOneItem("2", 0.75))), + 0.75)); + + assertThat( + rowSpecs, + sameBeanAs(expectedRowSpecs)); + } + + @Test + void createRowSpecs_whenRootNodeHasSomeWeightedDecisionsAndAllValuesAreExcluded_returnsCorrectlyWeightedRowSpecs() { + //Arrange + ConstraintNode root = TestConstraintNodeBuilder.constraintNode() + .where(fieldB).isInSet( + new InSetRecord("1", 0.25d), + new InSetRecord("2", 0.75d)) + .withDecision( + TestConstraintNodeBuilder.constraintNode() + .where(fieldB).isInSet("3"), + TestConstraintNodeBuilder.constraintNode() + .where(fieldB).isInSet("4")) + .build(); + DecisionTree tree = new DecisionTree(root, fields); + + //Act + Set> rowSpecs = rowSpecTreeSolver.createRowSpecs(tree).collect(Collectors.toSet()); + + //Assert + assertThat( + rowSpecs.isEmpty(), + is(true)); + } + + @Test + void createRowSpecs_whenRootNodeHasUnweightedDecisionsAndAllValuesAreAllowed_returnsCorrectlyWeightedRowSpecs() { + //Arrange + ConstraintNode root = TestConstraintNodeBuilder.constraintNode() + .where(fieldB).isInSet("1", "2") + .withDecision( + TestConstraintNodeBuilder.constraintNode() + .where(fieldB).isInSet("1", "2", "3"), + TestConstraintNodeBuilder.constraintNode() + .where(fieldB).isInSet("1", "2", "4")) + .build(); + DecisionTree tree = new DecisionTree(root, fields); + + //Act + Set> rowSpecs = rowSpecTreeSolver.createRowSpecs(tree).collect(Collectors.toSet()); + + //Assert + Set> expectedRowSpecs = new HashSet<>(); + expectedRowSpecs.add( + new WeightedElement<>( + createRowSpec( + FieldSpecFactory.fromType(fieldA.getType()), + FieldSpecFactory.fromLegalValuesList(Arrays.asList("1", "2"))), + 1)); + expectedRowSpecs.add(new WeightedElement<>( + createRowSpec( + FieldSpecFactory.fromType(fieldA.getType()), + FieldSpecFactory.fromLegalValuesList(Arrays.asList("1", "2"))), + 1)); + + assertThat( + rowSpecs, + sameBeanAs(expectedRowSpecs)); + } + + @Test + void createRowSpecs_whenRootNodeHasNoWeightedDecisions_returnsCorrectlyWeightedRowSpecs() { + //Arrange + ConstraintNode root = TestConstraintNodeBuilder.constraintNode() + .where(fieldB).isNotNull() + .withDecision( + TestConstraintNodeBuilder.constraintNode() + .where(fieldB).isInSet("1"), + TestConstraintNodeBuilder.constraintNode() + .where(fieldB).isInSet("2")) + .build(); + DecisionTree tree = new DecisionTree(root, fields); + + //Act + Set> rowSpecs = rowSpecTreeSolver.createRowSpecs(tree).collect(Collectors.toSet()); + + //Assert + Set> expectedRowSpecs = new HashSet<>(); + expectedRowSpecs.add(new WeightedElement<>( + createRowSpec( + FieldSpecFactory.fromType(fieldA.getType()), + FieldSpecFactory.fromLegalValuesList(Collections.singletonList("1")).withNotNull()), + 1)); + expectedRowSpecs.add(new WeightedElement<>( + createRowSpec( + FieldSpecFactory.fromType(fieldA.getType()), + FieldSpecFactory.fromLegalValuesList(Collections.singletonList("2")).withNotNull()), + 1)); + + assertThat( + rowSpecs, + sameBeanAs(expectedRowSpecs)); + } + + private RowSpec createRowSpec(FieldSpec fieldSpecA, FieldSpec fieldSpecB){ + Map option = new HashMap<>(); + option.put(fieldA, fieldSpecA); + option.put(fieldB, fieldSpecB); + return new RowSpec(fields, option, Collections.emptyList()); + } + + private static DistributedList distributedListOfOneItem(T item, double weight) { + ArrayList> list = new ArrayList<>(); + list.add(new WeightedElement<>(item, weight)); + + return new DistributedList<>(list); } }