Skip to content

Commit

Permalink
iter
Browse files Browse the repository at this point in the history
  • Loading branch information
jimczi committed Dec 2, 2024
1 parent 184174f commit f204cc3
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@
* the inference results under the {@link SemanticTextField#INFERENCE_FIELD}.
*
* @param fieldName The original field name.
* @param originalValues The original values associated with the field name.
* @param originalValues The original values associated with the field name for indices created before {@link IndexVersions#INFERENCE_METADATA_FIELDS}, null otherwise.
* @param inference The inference result.
* @param contentType The {@link XContentType} used to store the embeddings chunks.
*/
public record SemanticTextField(
IndexVersion indexCreatedVersion,
String fieldName,
List<String> originalValues,
@Nullable List<String> originalValues,
InferenceResult inference,
XContentType contentType
) implements ToXContentObject {
Expand Down Expand Up @@ -274,17 +274,22 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws

@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<SemanticTextField, ParserContext> SEMANTIC_TEXT_FIELD_PARSER =
new ConstructingObjectParser<>(
SemanticTextFieldMapper.CONTENT_TYPE,
true,
(args, context) -> new SemanticTextField(
new ConstructingObjectParser<>(SemanticTextFieldMapper.CONTENT_TYPE, true, (args, context) -> {
List<String> originalValues = (List<String>) args[0];
if (context.indexVersionCreated.onOrAfter(IndexVersions.INFERENCE_METADATA_FIELDS)) {
if (originalValues != null && originalValues.isEmpty() == false) {
throw new IllegalArgumentException("Unknown field [" + TEXT_FIELD + "]");
}
originalValues = null;
}
return new SemanticTextField(
context.indexVersionCreated(),
context.fieldName(),
(List<String>) (args[0] == null ? List.of() : args[0]),
originalValues,
(InferenceResult) args[1],
context.xContentType()
)
);
);
});

@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<InferenceResult, ParserContext> INFERENCE_RESULT_PARSER = new ConstructingObjectParser<>(
Expand Down Expand Up @@ -332,13 +337,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
(p, c) -> MODEL_SETTINGS_PARSER.parse(p, null),
new ParseField(MODEL_SETTINGS_FIELD)
);
INFERENCE_RESULT_PARSER.declareObject(constructorArg(), (p, c) -> {
INFERENCE_RESULT_PARSER.declareField(constructorArg(), (p, c) -> {
if (c.indexVersionCreated.onOrAfter(IndexVersions.INFERENCE_METADATA_FIELDS)) {
return parseChunksMap(p);
} else {
return Map.of(c.fieldName, parseChunksArrayLegacy(p));
}
}, new ParseField(CHUNKS_FIELD));
}, new ParseField(CHUNKS_FIELD), ObjectParser.ValueType.OBJECT_ARRAY);

CHUNKS_PARSER.declareString(optionalConstructorArg(), new ParseField(TEXT_FIELD));
CHUNKS_PARSER.declareInt(optionalConstructorArg(), new ParseField(CHUNKED_START_OFFSET_FIELD));
Expand Down Expand Up @@ -372,7 +377,7 @@ private static Map<String, List<Chunk>> parseChunksMap(XContentParser parser) th

private static List<Chunk> parseChunksArrayLegacy(XContentParser parser) throws IOException {
List<Chunk> results = new ArrayList<>();
XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.nextToken(), parser);
XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
results.add(CHUNKS_PARSER.parse(parser, null));
}
Expand All @@ -397,7 +402,7 @@ public static List<Chunk> toSemanticTextFieldChunks(
chunks.add(
new Chunk(
withOffsets ? null : input,
startOffset,
withOffsets ? startOffset : -1,
withOffsets ? startOffset + chunkAsByteReference.matchedText().length() : -1,
chunkAsByteReference.bytesReference()
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.test.index.IndexVersionUtils;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
Expand All @@ -41,56 +42,68 @@
public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTextField> {
private static final String NAME = "field";

private IndexVersion currentIndexVersion;

@Override
protected Predicate<String> getRandomFieldsExcludeFilter() {
return n -> n.endsWith(CHUNKED_EMBEDDINGS_FIELD);
}

@Override
protected void assertEqualInstances(SemanticTextField expectedInstance, SemanticTextField newInstance) {
assertThat(newInstance.indexCreatedVersion(), equalTo(newInstance.indexCreatedVersion()));
assertThat(newInstance.fieldName(), equalTo(expectedInstance.fieldName()));
assertThat(newInstance.originalValues(), equalTo(expectedInstance.originalValues()));
assertThat(newInstance.inference().modelSettings(), equalTo(expectedInstance.inference().modelSettings()));
assertThat(newInstance.inference().chunks().size(), equalTo(expectedInstance.inference().chunks().size()));
SemanticTextField.ModelSettings modelSettings = newInstance.inference().modelSettings();
for (int i = 0; i < newInstance.inference().chunks().size(); i++) {
/* assertThat(newInstance.inference().chunks().get(i).text(), equalTo(expectedInstance.inference().chunks().get(i).text()));
switch (modelSettings.taskType()) {
case TEXT_EMBEDDING -> {
double[] expectedVector = parseDenseVector(
expectedInstance.inference().chunks().get(i).rawEmbeddings(),
modelSettings.dimensions(),
expectedInstance.contentType()
);
double[] newVector = parseDenseVector(
newInstance.inference().chunks().get(i).rawEmbeddings(),
modelSettings.dimensions(),
newInstance.contentType()
);
assertArrayEquals(expectedVector, newVector, 0.0000001f);
}
case SPARSE_EMBEDDING -> {
List<WeightedToken> expectedTokens = parseWeightedTokens(
expectedInstance.inference().chunks().get(i).rawEmbeddings(),
expectedInstance.contentType()
);
List<WeightedToken> newTokens = parseWeightedTokens(
newInstance.inference().chunks().get(i).rawEmbeddings(),
newInstance.contentType()
);
assertThat(newTokens, equalTo(expectedTokens));
for (var entry : newInstance.inference().chunks().entrySet()) {
var expectedChunks = expectedInstance.inference().chunks().get(entry.getKey());
assertNotNull(expectedChunks);
assertThat(entry.getValue().size(), equalTo(expectedChunks.size()));
for (int i = 0; i < entry.getValue().size(); i++) {
var actualChunk = entry.getValue().get(i);
assertThat(actualChunk.text(), equalTo(expectedChunks.get(i).text()));
assertThat(actualChunk.startOffset(), equalTo(expectedChunks.get(i).startOffset()));
assertThat(actualChunk.endOffset(), equalTo(expectedChunks.get(i).endOffset()));
switch (modelSettings.taskType()) {
case TEXT_EMBEDDING -> {
double[] expectedVector = parseDenseVector(
expectedChunks.get(i).rawEmbeddings(),
modelSettings.dimensions(),
expectedInstance.contentType()
);
double[] newVector = parseDenseVector(
actualChunk.rawEmbeddings(),
modelSettings.dimensions(),
newInstance.contentType()
);
assertArrayEquals(expectedVector, newVector, 0.0000001f);
}
case SPARSE_EMBEDDING -> {
List<WeightedToken> expectedTokens = parseWeightedTokens(
expectedChunks.get(i).rawEmbeddings(),
expectedInstance.contentType()
);
List<WeightedToken> newTokens = parseWeightedTokens(actualChunk.rawEmbeddings(), newInstance.contentType());
assertThat(newTokens, equalTo(expectedTokens));
}
default -> throw new AssertionError("Invalid task type " + modelSettings.taskType());
}
default -> throw new AssertionError("Invalid task type " + modelSettings.taskType());
}**/
}
}
}

@Override
protected SemanticTextField createTestInstance() {
currentIndexVersion = randomFrom(
IndexVersionUtils.randomPreviousCompatibleVersion(random(), IndexVersions.INFERENCE_METADATA_FIELDS),
IndexVersionUtils.randomVersionBetween(random(), IndexVersions.INFERENCE_METADATA_FIELDS, IndexVersion.current())
);
List<String> rawValues = randomList(1, 5, () -> randomSemanticTextInput().toString());
try { // try catch required for override
return randomSemanticText(
IndexVersion.current(),
currentIndexVersion,
NAME,
TestModel.createRandomInstance(),
rawValues,
Expand All @@ -104,12 +117,12 @@ protected SemanticTextField createTestInstance() {

@Override
protected SemanticTextField doParseInstance(XContentParser parser) throws IOException {
return SemanticTextField.parse(parser, new SemanticTextField.ParserContext(IndexVersion.current(), NAME, parser.contentType()));
return SemanticTextField.parse(parser, new SemanticTextField.ParserContext(currentIndexVersion, NAME, parser.contentType()));
}

@Override
protected boolean supportsUnknownFields() {
return true;
return false;
}

public void testModelSettingsValidation() {
Expand Down Expand Up @@ -218,7 +231,7 @@ public static SemanticTextField semanticTextFieldFromChunkedInferenceResults(
return new SemanticTextField(
indexVersion,
fieldName,
inputs,
indexVersion.onOrAfter(IndexVersions.INFERENCE_METADATA_FIELDS) ? null : inputs,
new SemanticTextField.InferenceResult(
model.getInferenceEntityId(),
new SemanticTextField.ModelSettings(model),
Expand Down

0 comments on commit f204cc3

Please sign in to comment.