Skip to content
This repository has been archived by the owner on Apr 14, 2023. It is now read-only.

Commit

Permalink
Merge pull request #1687 from Ro4052/fix/weighted-distribution
Browse files Browse the repository at this point in the history
fix(#1552): use user provided weights instead of uniform distribution
  • Loading branch information
tjohnson-scottlogic authored Jul 16, 2020
2 parents 3c4f9ec + 14aab5c commit 8d4ea5f
Show file tree
Hide file tree
Showing 9 changed files with 104 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,16 @@ public static <T> DistributedList<T> singleton(final T element) {
return DistributedList.uniform(Collections.singleton(element));
}

/**
 * Builds a distributed list from a collection whose members may or may not already carry a weight.
 *
 * <p>A member that is itself a {@link WeightedElement} is passed through as-is (via an unchecked
 * cast to {@code WeightedElement<T>}); any other member is wrapped with
 * {@link WeightedElement#withDefaultWeight}.
 *
 * @param underlyingSet the members to place in the list
 * @param <T> the declared member type of {@code underlyingSet}
 * @return a new list containing one weighted entry per member of {@code underlyingSet}
 */
public static <T> DistributedList<T> weightedOrDefault(final Collection<T> underlyingSet) {
    return new DistributedList<>(
        underlyingSet.stream()
            .<WeightedElement<T>>map(candidate -> {
                if (candidate instanceof WeightedElement) {
                    // Already weighted: keep the caller-supplied weight.
                    return (WeightedElement<T>) candidate;
                }
                return WeightedElement.withDefaultWeight(candidate);
            })
            .collect(Collectors.toList()));
}

public static <T> DistributedList<T> uniform(final Collection<T> underlyingSet) {
return new DistributedList<>(
underlyingSet.stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.scottlogic.datahelix.generator.common.whitelist;

import java.util.Objects;
import java.util.function.Function;

/**
* Wrapper containing specified element with a weight.
Expand Down Expand Up @@ -59,6 +60,10 @@ public static <T> WeightedElement<T> ofNull() {
return (WeightedElement<T>) NULL;
}

/**
 * Creates a new weighted element whose value is the result of applying {@code parse} to the
 * value wrapped by {@code element}; the weight of {@code element} is carried over unchanged.
 *
 * @param element the weighted element whose wrapped value should be converted
 * @param parse the conversion applied to the wrapped value
 * @return a new {@code WeightedElement} holding the converted value and the original weight
 *         (declared as {@code Object}, so callers must cast if they need the wrapper type)
 */
public static Object parseValue(WeightedElement<?> element, Function<Object, Object> parse) {
    return new WeightedElement<>(parse.apply(element.element()), element.weight());
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,29 @@ public void testUniformGeneratesUniformDistribution() {
assertEquals(manualSet, uniformSet);
}

/**
 * Feeds already-weighted elements (weights 2, 3 and 5) through
 * {@code DistributedList.weightedOrDefault} and expects the result to equal a list constructed
 * directly from elements weighted 0.2, 0.3 and 0.5.
 */
@Test
public void testWeightedOrDefaultPassesThroughWeightedElements() {
    // Input: elements that already carry user-provided weights.
    List<WeightedElement<String>> suppliedElements = Arrays.asList(
        new WeightedElement<>("first", 2),
        new WeightedElement<>("second", 3),
        new WeightedElement<>("third", 5)
    );
    DistributedList<WeightedElement<String>> weightedSet = DistributedList.weightedOrDefault(suppliedElements);

    // Expected: the same elements built by hand with weights in the same 2:3:5 ratio.
    List<WeightedElement<String>> manualElements = Arrays.asList(
        new WeightedElement<>("first", 0.2),
        new WeightedElement<>("second", 0.3),
        new WeightedElement<>("third", 0.5)
    );
    DistributedList<String> manualSet = new DistributedList<>(manualElements);

    assertEquals(manualSet, weightedSet);
}

private DistributedList<String> prepareTwoElementSet() {
List<WeightedElement<String>> holders = Stream.of("first", "second", "third", "fourth")
.map(WeightedElement::withDefaultWeight)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,17 @@ public void testConstructorRetainsWeight() {
assertEquals(weight, weightedElement.weight());
}

/**
 * Checks that {@code WeightedElement.parseValue} applies the supplied function to the wrapped
 * value and keeps the original weight on the returned element.
 */
@Test
public void testParseValueParsesElement() {
    final int element = 1;
    final double weight = 1D;
    WeightedElement<Integer> weightedElement = new WeightedElement<>(element, weight);

    WeightedElement<?> parsedElement = (WeightedElement<?>) WeightedElement.parseValue(
        weightedElement,
        e -> Integer.toString((int) e)
    );

    assertTrue(parsedElement.element() instanceof String);
    // Pin the converted value and the weight too; a type-only check would still pass
    // if parseValue dropped the weight or produced the wrong string.
    assertEquals("1", parsedElement.element());
    assertEquals(Double.valueOf(weight), Double.valueOf(parsedElement.weight()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.scottlogic.datahelix.generator.common.ValidationException;
import com.scottlogic.datahelix.generator.common.profile.Field;
import com.scottlogic.datahelix.generator.common.profile.Fields;
import com.scottlogic.datahelix.generator.common.whitelist.WeightedElement;
import com.scottlogic.datahelix.generator.core.fieldspecs.relations.InMapRelation;
import com.scottlogic.datahelix.generator.common.whitelist.DistributedList;
import com.scottlogic.datahelix.generator.core.profile.constraints.atomic.*;
Expand Down Expand Up @@ -107,9 +108,12 @@ public InMapRelation createInMapRelation(InMapConstraintDTO dto, Fields fields)

private InSetConstraint createInSetConstraint(InSetConstraintDTO dto, Field field)
{
DistributedList<Object> values = DistributedList.uniform(dto.values.stream()
DistributedList<Object> values = DistributedList.weightedOrDefault(dto.values.stream()
.distinct()
.map(this::parseValue)
.map(value -> (value instanceof WeightedElement)
? WeightedElement.parseValue((WeightedElement) value, this::parseValue)
: this.parseValue(value)
)
.collect(Collectors.toList()));
return new InSetConstraint(field, values);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -84,7 +85,7 @@ private InMapConstraintDTO map(InMapFromFileConstraintDTO dto)

private InSetConstraintDTO map(InSetFromFileConstraintDTO dto)
{
List<Object> values = fileReader.setFromFile(getFile(dto.file)).stream().collect(Collectors.toList());
List<Object> values = new ArrayList<>(fileReader.setFromFile(getFile(dto.file)).distributedList());
InSetConstraintDTO inSetConstraintDTO = new InSetConstraintDTO();
inSetConstraintDTO.field = dto.field;
inSetConstraintDTO.values = values;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com.scottlogic.datahelix.generator.common.profile.FieldType;
import com.scottlogic.datahelix.generator.common.validators.ValidationResult;
import com.scottlogic.datahelix.generator.common.whitelist.WeightedElement;
import com.scottlogic.datahelix.generator.profile.dtos.FieldDTO;
import com.scottlogic.datahelix.generator.profile.dtos.constraints.atomic.AtomicConstraintDTO;
import com.scottlogic.datahelix.generator.profile.validators.profile.ConstraintValidator;
Expand Down Expand Up @@ -59,7 +60,9 @@ ValidationResult fieldTypeMustMatchValueType(T dto, Object value)
{
return ValidationResult.failure("Value " + value + " must be a boolean" + getErrorInfo(dto));
}
if (!(value instanceof Number || value instanceof String && isNumber((String)value)) && fieldType == FieldType.NUMERIC)
if (!(value instanceof Number || value instanceof String && isNumber((String)value) ||
value instanceof WeightedElement && isNumber((String)((WeightedElement) value).element())) &&
fieldType == FieldType.NUMERIC)
{
return ValidationResult.failure("Value " + value + " must be a number" + getErrorInfo(dto));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.fasterxml.jackson.core.JsonParseException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.module.SimpleModule;
import com.scottlogic.datahelix.generator.common.whitelist.WeightedElement;
import com.scottlogic.datahelix.generator.profile.dtos.constraints.ConstraintDTO;
import com.scottlogic.datahelix.generator.profile.dtos.constraints.atomic.EqualToConstraintDTO;
import com.scottlogic.datahelix.generator.profile.dtos.constraints.atomic.GranularToConstraintDTO;
Expand Down Expand Up @@ -110,7 +111,25 @@ public void shouldDeserialiseInSetCsvFileWithoutException() throws IOException {
// Assert
InSetConstraintDTO expected = new InSetConstraintDTO();
expected.field = "country";
expected.values = Collections.singletonList("test");
expected.values = Collections.singletonList(new WeightedElement<>("test", 1.0));

assertThat(actual, sameBeanAs(expected));
}

/**
 * Deserialises an {@code inSet} constraint that points at a CSV file, using a
 * {@code TestFileReader} constructed in weighted mode, and expects the resulting DTO to hold
 * "test1" and "test2" as weighted elements with weights 0.2 and 0.8.
 */
@Test
public void shouldDeserialiseWeightedInSetCsvFile() throws IOException {
    // Arrange
    final String profileJson = "{\"field\": \"country\", \"inSet\": \"countries.csv\" }";
    InSetConstraintDTO expectedDto = new InSetConstraintDTO();
    expectedDto.field = "country";
    expectedDto.values = Arrays.asList(
        new WeightedElement<>("test1", 0.2),
        new WeightedElement<>("test2", 0.8)
    );

    // Act
    ConstraintDTO actualDto = deserialiseJsonString(new TestFileReader(true), profileJson);

    // Assert
    assertThat(actualDto, sameBeanAs(expectedDto));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,44 @@
package com.scottlogic.datahelix.generator.profile;

import com.scottlogic.datahelix.generator.common.whitelist.DistributedList;
import com.scottlogic.datahelix.generator.common.whitelist.WeightedElement;
import com.scottlogic.datahelix.generator.profile.reader.FileReader;

import java.io.File;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
 * Test stand-in for {@link FileReader} that returns canned data from {@code setFromFile} and
 * {@code listFromMapFile}; the {@code file} argument of both methods is ignored.
 *
 * <p>When constructed with {@code weighted = true}, {@code setFromFile} returns two elements
 * ("test1" and "test2") supplied with weights 20 and 80; otherwise it returns a uniform list
 * containing the single value "test".
 */
public class TestFileReader extends FileReader {
    private final boolean weighted;

    /** Creates a reader whose {@code setFromFile} returns the unweighted single-value list. */
    public TestFileReader() {
        this(false);
    }

    /**
     * @param weighted whether {@code setFromFile} should return the weighted two-element list
     */
    public TestFileReader(boolean weighted) {
        super(null);
        this.weighted = weighted;
    }

    /**
     * Returns canned set contents; {@code file} is not read.
     *
     * @param file ignored
     * @return the weighted two-element list in weighted mode, otherwise a uniform list of "test"
     */
    @Override
    public DistributedList<Object> setFromFile(File file) {
        return weighted
            ? getDistributedListWithWeights()
            : DistributedList.uniform(Collections.singleton("test"));
    }

    /**
     * Returns a canned uniform list containing only "test"; both arguments are ignored.
     */
    @Override
    public DistributedList<String> listFromMapFile(File file, String key) {
        return DistributedList.uniform(Collections.singleton("test"));
    }

    /** Builds the weighted fixture: "test1" supplied with weight 20 and "test2" with weight 80. */
    private static DistributedList<Object> getDistributedListWithWeights() {
        List<Object> elements = Arrays.asList(
            new WeightedElement<>("test1", 20),
            new WeightedElement<>("test2", 80)
        );
        return DistributedList.weightedOrDefault(elements);
    }
}

0 comments on commit 8d4ea5f

Please sign in to comment.