Skip to content

Commit

Permalink
llmexample is now also an OAEI matcher
Browse files Browse the repository at this point in the history
  • Loading branch information
sven-h committed Sep 19, 2023
1 parent 2c7f11f commit 5c4f1d4
Show file tree
Hide file tree
Showing 4 changed files with 196 additions and 65 deletions.
67 changes: 67 additions & 0 deletions examples/llm-transformers/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>

<oaei.mainClass>de.uni_mannheim.informatik.dws.melt.examples.llm_transformers.OLaLaForOAEI</oaei.mainClass><!-- mandatory: this class has to implement IOntologyMatchingToolBridge -->
<oaei.copyright>(C) Mannheim, 2021</oaei.copyright> <!--optional copyright appearing in the seals descriptor file -->
<oaei.license>GNU Lesser General Public License 2.1 or above</oaei.license> <!--optional license appearing in the seals descriptor file -->


<maven.deploy.skip>true</maven.deploy.skip><!-- needed to call mvn deploy without having a distributionManagement -->
<matching.version>3.4-SNAPSHOT</matching.version> <!-- version for all matching related packages -->
</properties>
Expand All @@ -23,6 +28,13 @@
<version>${matching.version}</version>
</dependency>

<!-- This dependency is necessary for web submission. It contains the server wrapper. -->
<dependency>
<groupId>de.uni-mannheim.informatik.dws.melt</groupId>
<artifactId>receiver-http</artifactId>
<version>${matching.version}</version>
</dependency>

<!-- contains alcomo filter -->
<dependency>
<groupId>de.uni-mannheim.informatik.dws.melt</groupId>
Expand Down Expand Up @@ -199,6 +211,61 @@
</execution>
</executions>
</plugin>



<!-- the following plugin will generate a docker image and save it into the target folder -->
<!-- to work with podman instead of docker, execute (replace {dashdash} with two dashes):
podman system service {dashdash}time=0 unix:/run/user/$(id -u)/podman/podman.sock
and then export DOCKER_HOST="unix:/run/user/$(id -u)/podman/podman.sock"
https://github.com/fabric8io/docker-maven-plugin/issues/1330 -->
<!-- Uncomment this plugin if you want to build the docker image.
You also need to replace {dash}{dash}no-install-recommends with the corresponding real dashes (the placeholder is used because a double dash is not allowed inside an XML comment).
<plugin>
<groupId>io.fabric8</groupId>
<artifactId>docker-maven-plugin</artifactId>
<version>0.43.4</version>
<configuration>
<images>
<image>
<name>%a-%v-web</name>
<build>
<from>nvidia/cuda:11.6.2-base-ubuntu20.04</from>
<runCmds>
<run>apt update</run>
<run>apt install default-jre python3 python3-pip python-is-python3 {dash}{dash}no-install-recommends -y</run>
<run>pip install torch numpy scikit-learn pandas gensim flask "Werkzeug==2.2.3" sentencepiece "protobuf==3.20.1" accelerate bitsandbytes transformers sentence-transformers</run>
<run>apt remove python3-pip -y</run>
<run>rm -rf /var/lib/apt/lists/*</run>
</runCmds>
<optimise>true</optimise>
<assembly><descriptorRef>web</descriptorRef></assembly>
<cmd><shell>java -cp "${project.build.finalName}.${project.packaging}:lib/*" de.uni_mannheim.informatik.dws.melt.receiver_http.Main</shell></cmd>
<workdir>/maven</workdir>
<ports><port>8080</port></ports>
</build>
</image>
</images>
</configuration>
<dependencies>
<dependency>
<groupId>de.uni-mannheim.informatik.dws.melt</groupId>
<artifactId>matching-assembly</artifactId>
<version>${matching.version}</version>
</dependency>
</dependencies>
<executions>
<execution>
<goals>
<goal>build</goal>
<goal>save</goal>
</goals>
<phase>install</phase>
</execution>
</executions>
</plugin>
-->
</plugins>
</build>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
public class CLIOptions {
private static final Logger LOGGER = LoggerFactory.getLogger(CLIOptions.class);

private static final List<String> PREDEFINED_PROMPTS = createPredefinedPrompts();
public static final List<String> PREDEFINED_PROMPTS = createPredefinedPrompts();
private static List<String> createPredefinedPrompts(){
List<String> prompts = new ArrayList<>();

Expand Down Expand Up @@ -197,6 +197,29 @@ private static List<Entry<String, TextExtractorMap>> createTextExtractors(){
extractors.add(new SimpleEntry<>("TE16VerbalizedRDFtruetrue",
wrap(false, new TextExtractorVerbalizedRDF(true, true)))); //16


//more "resource description" text extractors:
extractors.add(new SimpleEntry<>("TE17", //17
wrap(false, new TextExtractorResourceDescriptionInRDF()
.setStatementProcessor(TextExtractorResourceDescriptionInRDF.SKIP_DEFINITIONS))));
extractors.add(new SimpleEntry<>("TE18", //18
wrap(false, new TextExtractorResourceDescriptionInRDF()
.setStatementProcessor(TextExtractorResourceDescriptionInRDF.SKIP_DEFINITIONS_AND_LONG_LITERALS))));
extractors.add(new SimpleEntry<>("TE19", //19
wrap(false, new TextExtractorResourceDescriptionInRDF()
.setStatementProcessor(TextExtractorResourceDescriptionInRDF.SKIP_DEFINITIONS_AND_SHORTEN_LONG_LITERALS))));

extractors.add(new SimpleEntry<>("TE20", //20
wrap(false, new TextExtractorResourceDescriptionInRDF(true, RDFFormat.TURTLE)
.setStatementProcessor(TextExtractorResourceDescriptionInRDF.SKIP_DEFINITIONS))));
extractors.add(new SimpleEntry<>("TE21", //21
wrap(false, new TextExtractorResourceDescriptionInRDF(true, RDFFormat.TURTLE)
.setStatementProcessor(TextExtractorResourceDescriptionInRDF.SKIP_DEFINITIONS_AND_LONG_LITERALS))));
extractors.add(new SimpleEntry<>("TE22", //22
wrap(false, new TextExtractorResourceDescriptionInRDF(true, RDFFormat.TURTLE)
.setStatementProcessor(TextExtractorResourceDescriptionInRDF.SKIP_DEFINITIONS_AND_SHORTEN_LONG_LITERALS))));


return extractors;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@
import de.uni_mannheim.informatik.dws.melt.matching_eval.evaluator.EvaluatorCopyResults;
import de.uni_mannheim.informatik.dws.melt.matching_eval.evaluator.EvaluatorRank;
import de.uni_mannheim.informatik.dws.melt.matching_eval.paramtuning.ConfidenceFinder;
import de.uni_mannheim.informatik.dws.melt.matching_jena.MatcherPipelineYAAAJenaConstructor;
import de.uni_mannheim.informatik.dws.melt.matching_jena.TextExtractor;
import de.uni_mannheim.informatik.dws.melt.matching_jena.TextExtractorMap;
import de.uni_mannheim.informatik.dws.melt.matching_jena_matchers.elementlevel.HighPrecisionMatcher;
import de.uni_mannheim.informatik.dws.melt.matching_jena_matchers.filter.AdditionalConfidenceFilter;
import de.uni_mannheim.informatik.dws.melt.matching_jena_matchers.filter.BadHostsFilter;
import de.uni_mannheim.informatik.dws.melt.matching_jena_matchers.filter.ConfidenceFilter;
import de.uni_mannheim.informatik.dws.melt.matching_jena_matchers.filter.extraction.MaxWeightBipartiteExtractor;
import de.uni_mannheim.informatik.dws.melt.matching_jena_matchers.filter.extraction.NaiveDescendingExtractor;
import de.uni_mannheim.informatik.dws.melt.matching_jena_matchers.metalevel.AddAlignmentExtensions;
import de.uni_mannheim.informatik.dws.melt.matching_jena_matchers.metalevel.AddAlignmentMatcher;
import de.uni_mannheim.informatik.dws.melt.matching_jena_matchers.metalevel.ConfidenceCombiner;
import de.uni_mannheim.informatik.dws.melt.matching_jena_matchers.metalevel.ForwardAlwaysMatcher;
Expand All @@ -32,6 +35,7 @@
import de.uni_mannheim.informatik.dws.melt.matching_ml.python.nlptransformers.SentenceTransformersMatcher;
import de.uni_mannheim.informatik.dws.melt.yet_another_alignment_api.Alignment;
import de.uni_mannheim.informatik.dws.melt.yet_another_alignment_api.Correspondence;
import de.uni_mannheim.informatik.dws.melt.yet_another_alignment_api.DefaultExtensions;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -224,13 +228,8 @@ private static void run(CLIOptions cliOptions) throws Exception {
configurationName + "addhighprec"
);

if(testCase.getTrack().getName().equals("conference")){
addFixedConfidenceAndOneToOne(testCaseResults, testCase, configurationName);
addFixedConfidenceAndOneToOne(testCaseResults, testCase, configurationName + "addhighprec");
}else{
addConfidenceAndOneToOne(testCaseResults, testCase, configurationName);
addConfidenceAndOneToOne(testCaseResults, testCase, configurationName + "addhighprec");
}
addConfidenceAndOneToOne(testCaseResults, testCase, configurationName);
addConfidenceAndOneToOne(testCaseResults, testCase, configurationName + "addhighprec");


if(cliOptions.isChoose()){
Expand Down Expand Up @@ -263,23 +262,6 @@ private static void run(CLIOptions cliOptions) throws Exception {

}

/*
Executor.runMatcherOnTop(testCaseResults, configurationName,
new ConfidenceFilter(0.5),
configurationName + "_CutConfidence0.5"
);
Executor.runMatcherOnTop(testCaseResults, configurationName + "_CutConfidence0.5",
new MaxWeightBipartiteExtractor(),
configurationName + "_CutConfidence0.5_OneOne"
);
Executor.runMatcherOnTop(testCaseResults, configurationName,
new AlcomoFilter(), configurationName + "_Alcomo");
Executor.runMatcherOnTop(testCaseResults, configurationName + "_Alcomo",
new MaxWeightBipartiteExtractor(), configurationName + "_Alcomo_OneOne");
*/
ers.addAll(testCaseResults);
}
}
Expand All @@ -298,61 +280,48 @@ private static void run(CLIOptions cliOptions) throws Exception {
new EvaluatorCSV(ers).writeToDirectory(resultsDir);
LOGGER.info("Finish evaluating");
}


private static void addFixedConfidenceAndOneToOne(ExecutionResultSet testCaseResults, TestCase testCase, String oldMatcherName){
Executor.runMatcherOnTop(testCaseResults, oldMatcherName,
new MaxWeightBipartiteExtractor(),
oldMatcherName + "_OneOne"
);
for(double d : Arrays.asList(0.5, 0.6, 0.7, 0.8, 0.9)){
Executor.runMatcherOnTop(testCaseResults, oldMatcherName + "_OneOne",
new ConfidenceFilter(d),
oldMatcherName + "_OneOne_CutConfidence" + d
);
}
}

private static void addConfidenceAndOneToOne(ExecutionResultSet testCaseResults, TestCase testCase, String oldMatcherName){

double bestConfidenceCrossEncoderF1 = ConfidenceFinder.getBestConfidenceForFmeasure(testCase.getParsedReferenceAlignment(),
testCaseResults.get(testCase, oldMatcherName).getSystemAlignment(),
GoldStandardCompleteness.PARTIAL_SOURCE_COMPLETE_TARGET_COMPLETE);
findBestConfidence(testCaseResults, testCase, oldMatcherName);

Executor.runMatcherOnTop(testCaseResults, oldMatcherName,
new ConfidenceFilter(bestConfidenceCrossEncoderF1),
oldMatcherName + "_CutBestConfidence" + bestConfidenceCrossEncoderF1
);

Executor.runMatcherOnTop(testCaseResults, oldMatcherName,
Executor.runMatcherOnTop(testCaseResults, testCase, oldMatcherName,
new MaxWeightBipartiteExtractor(),
oldMatcherName + "_OneOne"
);

double bestConfidence = ConfidenceFinder.getBestConfidenceForFmeasure(testCase.getParsedReferenceAlignment(),
testCaseResults.get(testCase, oldMatcherName + "_OneOne").getSystemAlignment(),
GoldStandardCompleteness.PARTIAL_SOURCE_COMPLETE_TARGET_COMPLETE);

Executor.runMatcherOnTop(testCaseResults, oldMatcherName + "_OneOne",
new ConfidenceFilter(bestConfidence),
oldMatcherName + "_OneOne_CutBestConfidence" + bestConfidence

Executor.runMatcherOnTop(testCaseResults, testCase, oldMatcherName,
new NaiveDescendingExtractor(),
oldMatcherName + "_OneOneNaive"
);

double bestConfidenceComplete = ConfidenceFinder.getBestConfidenceForFmeasure(testCase.getParsedReferenceAlignment(),
testCaseResults.get(testCase, oldMatcherName + "_OneOne").getSystemAlignment(),
GoldStandardCompleteness.COMPLETE);
findBestConfidence(testCaseResults, testCase, oldMatcherName + "_OneOne");
findBestConfidence(testCaseResults, testCase, oldMatcherName + "_OneOneNaive");

Executor.runMatcherOnTop(testCaseResults, oldMatcherName + "_OneOne",
new ConfidenceFilter(bestConfidence),
oldMatcherName + "_OneOne_CutBestCompleteConfidence" + bestConfidenceComplete
);

Executor.runMatcherOnTop(testCaseResults, oldMatcherName + "_OneOne",
Executor.runMatcherOnTop(testCaseResults, testCase, oldMatcherName + "_OneOne",
new ConfidenceFilter(0.5),
oldMatcherName + "_OneOne_CutConfidence0.5"
);

Executor.runMatcherOnTop(testCaseResults, testCase, oldMatcherName + "_OneOneNaive",
new ConfidenceFilter(0.5),
oldMatcherName + "_OneOneNaive_CutConfidence0.5"
);
}

/**
 * Computes a confidence threshold for an already executed matcher and runs a confidence cut on top of it.
 * <p>
 * The threshold is the value returned by {@link ConfidenceFinder#getBestConfidenceForFmeasure} for the
 * parsed reference alignment of the test case and the system alignment of {@code oldMatcherName}
 * (using {@link GoldStandardCompleteness#COMPLETE}). Afterwards a pipeline consisting of a
 * {@link ConfidenceFilter} with this threshold and an {@link AddAlignmentExtensions} step is executed
 * on top of the old matcher. The result is registered in {@code testCaseResults} under the name
 * {@code oldMatcherName + "_CutBestConfidence"}.
 *
 * @param testCaseResults the result set which contains the result of {@code oldMatcherName} and receives the new result
 * @param testCase the test case whose reference alignment is used to determine the threshold
 * @param oldMatcherName the name of the matcher whose system alignment is analyzed and filtered
 */
private static void findBestConfidence(ExecutionResultSet testCaseResults, TestCase testCase, String oldMatcherName){
    // the reference alignment is fetched before the system alignment (same order as the arguments were evaluated before)
    Alignment reference = testCase.getParsedReferenceAlignment();
    Alignment system = testCaseResults.get(testCase, oldMatcherName).getSystemAlignment();
    double threshold = ConfidenceFinder.getBestConfidenceForFmeasure(
            reference, system, GoldStandardCompleteness.COMPLETE);

    // presumably the extension makes the chosen threshold visible in the resulting alignment - TODO confirm
    String extensionKey = DefaultExtensions.MeltExtensions.CONFIGURATION_BASE + "bestConfidence";
    MatcherPipelineYAAAJenaConstructor cutAndAnnotate = new MatcherPipelineYAAAJenaConstructor(
            new ConfidenceFilter(threshold),
            new AddAlignmentExtensions(extensionKey, threshold)
    );

    String newMatcherName = oldMatcherName + "_CutBestConfidence";
    Executor.runMatcherOnTop(testCaseResults, testCase, oldMatcherName, cutAndAnnotate, newMatcherName);
}


private static void runRecallOnly(CLIOptions cliOptions) throws Exception {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package de.uni_mannheim.informatik.dws.melt.examples.llm_transformers;

import de.uni_mannheim.informatik.dws.melt.matching_base.IMatcher;
import de.uni_mannheim.informatik.dws.melt.matching_jena.MatcherPipelineYAAAJenaConstructor;
import de.uni_mannheim.informatik.dws.melt.matching_jena.TextExtractor;
import de.uni_mannheim.informatik.dws.melt.matching_jena_matchers.elementlevel.HighPrecisionMatcher;
import de.uni_mannheim.informatik.dws.melt.matching_jena_matchers.filter.BadHostsFilter;
import de.uni_mannheim.informatik.dws.melt.matching_jena_matchers.filter.ConfidenceFilter;
import de.uni_mannheim.informatik.dws.melt.matching_jena_matchers.filter.extraction.NaiveDescendingExtractor;
import de.uni_mannheim.informatik.dws.melt.matching_jena_matchers.metalevel.AddAlignmentMatcher;
import de.uni_mannheim.informatik.dws.melt.matching_jena_matchers.metalevel.ConfidenceCombiner;
import de.uni_mannheim.informatik.dws.melt.matching_jena_matchers.util.StringProcessing;
import de.uni_mannheim.informatik.dws.melt.matching_jena_matchers.util.textExtractors.TextExtractorOnlyLabel;
import de.uni_mannheim.informatik.dws.melt.matching_jena_matchers.util.textExtractors.TextExtractorSet;
import de.uni_mannheim.informatik.dws.melt.matching_ml.python.nlptransformers.LLMBinaryFilter;
import de.uni_mannheim.informatik.dws.melt.matching_ml.python.nlptransformers.SentenceTransformersMatcher;
import de.uni_mannheim.informatik.dws.melt.yet_another_alignment_api.Alignment;
import java.util.Properties;
import org.apache.jena.ontology.OntModel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Matcher of the LLM transformers example which is packaged for the OAEI
 * (this class is configured as {@code oaei.mainClass} in the pom.xml of the example).
 * <p>
 * The matching pipeline consists of the following components, executed in this order:
 * <ol>
 *   <li>a {@link SentenceTransformersMatcher} (bi-encoder, top k of 5),</li>
 *   <li>a {@link LLMBinaryFilter} which uses one of the predefined prompts of {@link CLIOptions},</li>
 *   <li>a {@link ConfidenceCombiner} configured with the {@link LLMBinaryFilter} class,</li>
 *   <li>an {@link AddAlignmentMatcher} constructed with the alignment which is computed beforehand by a
 *       {@link HighPrecisionMatcher} followed by a {@link BadHostsFilter},</li>
 *   <li>a {@link NaiveDescendingExtractor} and</li>
 *   <li>a {@link ConfidenceFilter} with a threshold of 0.5.</li>
 * </ol>
 */
public class OLaLaForOAEI implements IMatcher<OntModel,Alignment,Properties> {

    /** Model name handed to the bi-encoder ("all-MiniLM-L6-v2" was noted as an alternative). */
    private static final String BI_ENCODER_MODEL = "multi-qa-mpnet-base-dot-v1";

    /** Model name handed to the LLM filter ("TaylorAI/Flash-Llama-7B" was noted as an alternative). */
    private static final String LLM_MODEL = "upstage/Llama-2-70b-instruct-v2";

    /** Index of the prompt within {@link CLIOptions#PREDEFINED_PROMPTS} which is used by the LLM filter. */
    private static final int PROMPT_INDEX = 7;

    /** Value passed to {@link SentenceTransformersMatcher#setTopK}. */
    private static final int BI_ENCODER_TOP_K = 5;

    /** Threshold of the final {@link ConfidenceFilter}. */
    private static final double FINAL_CONFIDENCE_THRESHOLD = 0.5;

    /**
     * Runs the pipeline described in the class documentation on the given ontologies.
     *
     * @param source the source ontology
     * @param target the target ontology
     * @param inputAlignment the input alignment which is passed to the high precision pipeline as well as to the main pipeline
     * @param parameters the parameters which are passed on unchanged to both pipelines
     * @return the alignment returned by the main pipeline
     * @throws Exception in case one of the pipeline components fails
     */
    @Override
    public Alignment match(OntModel source, OntModel target, Alignment inputAlignment, Properties parameters) throws Exception {

        // step 1: bi-encoder (no transformers cache is configured here)
        SentenceTransformersMatcher candidateGenerator = new SentenceTransformersMatcher(
                TextExtractor.appendStringPostProcessing(
                        new TextExtractorSet(),
                        StringProcessing::normalizeOnlyCamelCaseAndUnderscore),
                BI_ENCODER_MODEL
        );
        candidateGenerator.setMultipleTextsToMultipleExamples(true);
        candidateGenerator.setTopK(BI_ENCODER_TOP_K);
        candidateGenerator.addResourceFilter(SentenceTransformersPredicateBadHosts.class);

        // step 2: LLM based filter (again without a transformers cache)
        String prompt = CLIOptions.PREDEFINED_PROMPTS.get(PROMPT_INDEX);
        LLMBinaryFilter llmFilter = new LLMBinaryFilter(new TextExtractorOnlyLabel(), LLM_MODEL, prompt);
        llmFilter.setMultipleTextsToMultipleExamples(true);
        llmFilter
                .addGenerationArgument("max_new_tokens", 10)
                .addGenerationArgument("temperature", 0.0);
        llmFilter.addLoadingArguments(LLMConfiguration.getConfiguration(LLM_MODEL).getLoadingArguments());

        // the high precision alignment is computed up front because it is a constructor argument of a pipeline step
        Alignment highPrecisionAlignment = new MatcherPipelineYAAAJenaConstructor(
                new HighPrecisionMatcher(),
                new BadHostsFilter()
        ).match(source, target, inputAlignment, parameters);

        MatcherPipelineYAAAJenaConstructor pipeline = new MatcherPipelineYAAAJenaConstructor(
                candidateGenerator,
                llmFilter,
                new ConfidenceCombiner(LLMBinaryFilter.class),
                new AddAlignmentMatcher(highPrecisionAlignment),
                new NaiveDescendingExtractor(),
                new ConfidenceFilter(FINAL_CONFIDENCE_THRESHOLD)
        );
        return pipeline.match(source, target, inputAlignment, parameters);
    }
}

0 comments on commit 5c4f1d4

Please sign in to comment.