Skip to content

Commit

Permalink
refactor and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AndersToft20 committed Nov 29, 2023
1 parent 678a93f commit 3be96c9
Show file tree
Hide file tree
Showing 8 changed files with 405 additions and 50 deletions.
4 changes: 2 additions & 2 deletions relation_extraction/LessNaive/lessNaive.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .openie import POST_corenlp
import json
import sys
ontology_file_path = 'DBpedia_Ont.ttl'

import urllib.parse
from strsimpy.normalized_levenshtein import NormalizedLevenshtein
from rapidfuzz.distance import Levenshtein
Expand Down Expand Up @@ -72,7 +72,7 @@ def do_relation_extraction(data, ontology_relations):
return tuples

def main():
ontology_relations = extract_specific_relations(ontology_file_path)
ontology_relations = extract_specific_relations()
do_relation_extraction(json.load(open("inputSentences.json")), ontology_relations)


Expand Down
27 changes: 7 additions & 20 deletions relation_extraction/NaiveMVP/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import multiprocessing as mp
from functools import partial

ontology_file_path = '../DBpedia_Ont.ttl'
threshold = 0
normalized_levenshtein = NormalizedLevenshtein()

Expand All @@ -17,50 +16,37 @@ def find_best_match(token, relations):
"Finds the best match given a token and a set of relations"
best_relation_match = ""
highest_similarity = 0
dt = datetime.datetime.now()
for relation in relations:
similarity = normalized_levenshtein.similarity(token, relation)
highest_similarity = similarity if similarity > highest_similarity else highest_similarity
best_relation_match = relation if similarity == highest_similarity else best_relation_match
# print(f"find_best_match: {(datetime.datetime.now()-dt).total_seconds()}")

return {'similarity': highest_similarity, 'predicted_relation': best_relation_match}

def filter_tokens(tokens, entity_mentions):
"Filters out tokens that are substrings of the entity mentions"

filtered_tokens = []

for entity_mention in entity_mentions:
for token in tokens:
if token not in entity_mention["name"]:
filtered_tokens.append(token)

return filtered_tokens
ems = [em["name"] for em in entity_mentions]
return [token for token in tokens if token not in ems]

def find_best_triple(sentence, relations):
"Finds the best triple by comparing each token in a sentence to every relation and returning the triple where the similarity was highest"
entity_mentions = sentence["entity_mentions"]
dt = datetime.datetime.now()
filtered_tokens = filter_tokens(sentence["tokens"], entity_mentions)
#print(f"filter_tokens: {(datetime.datetime.now()-dt).total_seconds()}")
best_triple = []
highest_similarity = 0
dt = datetime.datetime.now()
for token in filtered_tokens:
result = find_best_match(token, relations)
if result["similarity"] > highest_similarity and result["similarity"] > threshold: #Only supporting 2 entity mentions per sentence
highest_similarity = result["similarity"]
best_triple = [entity_mentions[0]["iri"], result["predicted_relation"], entity_mentions[1]["iri"]]
if highest_similarity == 0:
best_triple = [entity_mentions[0]["name"], "---",entity_mentions[1]["name"]]
#print(f"handle all tokens: {(datetime.datetime.now()-dt).total_seconds()}")
best_triple = [entity_mentions[0]["iri"], "---",entity_mentions[1]["iri"]]
return best_triple

def parse_data(data, relations):
"Parses JSON data and converts it into a dictionary with information on sentence, tokens, and entity mentions"
output = []
for file in data:
file_name = file["fileName"]
sentences_in_data = file["sentences"]

for sentence_object in sentences_in_data:
Expand All @@ -87,7 +73,8 @@ def handle_relation_post_request(data):
try:
relations = extract_specific_relations()
except Exception as E:
print(f"Exection during retrieval of relations: {str(E)}")
print(f"Exception during retrieval of relations: {str(E)}")
raise Exception(f"Exception during retrieval of relations")

try:
parsed_data = parse_data(data, relations)
Expand All @@ -103,7 +90,7 @@ def handle_relation_post_request(data):


def main():
relations = extract_specific_relations(ontology_file_path)
relations = extract_specific_relations()
# Opening JSON file
with open('inputSentences.json', 'r') as f:
# returns JSON object as a dictionary
Expand Down
35 changes: 19 additions & 16 deletions relation_extraction/get_relations.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,30 @@
import requests
import re

# def extract_specific_relations(ontology_file_path):
# "Function to extract relations based on the specified pattern"
# relations = set()
# with open(ontology_file_path, 'r', encoding='utf-8', errors='ignore') as file:
# lines = file.readlines()
# i = 0
# for i, line in enumerate(lines):
# line = line.strip()
# # Check if the line starts with a colon and the next lines contain the specified pattern
# if line.startswith(":") and i+1 < len(lines) and "a rdf:Property, owl:ObjectProperty ;" in lines[i+1]:
# relation = line.split()[0] # Extracting the relation name
# relation = relation[1:] # Remove colon
# relations.add(relation)
# i += 1
def extract_specific_relations_offline():
"Function to extract relations based on the specified pattern"
ontology_file_path = "./DBpedia_Ont.ttl"
print("Extracting relations offline...")
relations = set()
with open(ontology_file_path, 'r', encoding='utf-8', errors='ignore') as file:
lines = file.readlines()
i = 0
for i, line in enumerate(lines):
line = line.strip()
# Check if the line starts with a colon and the next lines contain the specified pattern
if line.startswith(":") and i+1 < len(lines) and "a rdf:Property, owl:ObjectProperty ;" in lines[i+1]:
relation = line.split()[0] # Extracting the relation name
relation = relation[1:] # Remove colon
relations.add(relation)
i += 1

# return sorted(relations)
return sorted(relations)

def extract_specific_relations():
"Function to extract relations based on the specified pattern"
print("Getting relations from online ontology...")
relations = []
URL = "http://192.38.54.90/ontology"
URL = "http://192.38.54.90/triples?g=http://knox_ontology"
query_string_s = 'http://dbpedia.org/ontology/'
query_string_o = 'http://www.w3.org/1999/02/22-rdf-syntax-ns#Property'
PARAMS = {"s": query_string_s, "o": query_string_o}
Expand Down
2 changes: 1 addition & 1 deletion relation_extraction/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ def format_output(output):
return formatted_output

def send_to_database_component(output):
URL = "http://192.38.54.90/knowledge-base"
URL = "http://192.38.54.90/triples?g=http://knox_database"
response = requests.post(url=URL, json=format_output(output))
print(f"db component response: {response.text}")
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
requests==2.31.0
strsimpy==0.2.1
strsimpy==0.2.1
mock==5.1.0
20 changes: 10 additions & 10 deletions test/test_server/test_server.py → test/test_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,22 +57,22 @@

class TestServerEndpoint(unittest.TestCase):
@classmethod
def setUpClass(cls):
def setUpClass(self):
lock.acquire()
cls.server_thread = Thread(target=cls.start_server)
cls.server_thread.daemon = True
cls.server_thread.start()
self.server_thread = Thread(target=self.start_server)
self.server_thread.daemon = True
self.server_thread.start()

@classmethod
def start_server(cls):
cls.server = HTTPServer(('localhost', PORT), PreProcessingHandler)
def start_server(self):
self.server = HTTPServer(('localhost', PORT), PreProcessingHandler)
lock.release()
cls.server.serve_forever()
self.server.serve_forever()

@classmethod
def tearDownClass(cls):
cls.server.shutdown()
cls.server.server_close()
def tearDownClass(self):
self.server.shutdown()
self.server.server_close()

def test_pre_processing_endpoint_with_valid_data(self):
while(lock.locked()):
Expand Down
124 changes: 124 additions & 0 deletions test/test_server/test_pre_processing_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import unittest
from server.server import *
from unittest.mock import patch, Mock, MagicMock


class TestPreProcessingHandler(unittest.TestCase):

@patch('server.server.handle_relation_post_request', return_value=Mock())
@patch.object(PreProcessingHandler, 'wrongly_formatted_request_response')
@patch.object(PreProcessingHandler, 'handled_request_body', return_value=True)
@patch.object(PreProcessingHandler, '__init__', return_value=None)
def test_do_post_tripleconstruction_valid(self, mock_init, mock_handled_body, mock_wrong_resp, mock_handle_relation):
mock_init.return_value = None
handler = PreProcessingHandler()
handler.rfile = MagicMock()
handler.wfile = MagicMock()
handler.headers = {'Content-Length': '0'}
handler.send_response = MagicMock()
handler.send_header = MagicMock()
handler.end_headers = MagicMock()

# simulate a post request call to '/tripleconstruction'
handler.path = '/tripleconstruction'
handler.do_POST()
self.assertTrue(mock_handled_body.called)
mock_handle_relation.assert_called_once()
handler.send_response.assert_called_once_with(200)
handler.send_header.assert_called_once_with('Content-type', 'text/html')
handler.end_headers.assert_called_once()


@patch('server.server.handle_relation_post_request', return_value=Mock())
@patch.object(PreProcessingHandler, 'wrongly_formatted_request_response')
@patch.object(PreProcessingHandler, 'handled_request_body', return_value=True)
@patch.object(PreProcessingHandler, '__init__', return_value=None)
def test_do_post_invalid_endpoint(self, mock_init, mock_handled_body, mock_wrong_resp, mock_handle_relation):
mock_init.return_value = None
handler = PreProcessingHandler()
handler.rfile = MagicMock()
handler.wfile = MagicMock()
handler.headers = {'Content-Length': '0'}
handler.send_response = MagicMock()
handler.send_header = MagicMock()
handler.end_headers = MagicMock()

# simulate a post request call to an invalid endpoint
handler.path = '/invalid-endpoint'
handler.send_response.reset_mock()
handler.send_header.reset_mock()
handler.end_headers.reset_mock()
handler.do_POST()
handler.send_response.assert_called_once_with(404)
handler.send_header.assert_called_once_with('Content-type','text/html')
handler.end_headers.assert_called_once()


@patch('server.server.handle_relation_post_request', return_value=Mock())
@patch.object(PreProcessingHandler, 'wrongly_formatted_request_response')
@patch.object(PreProcessingHandler, 'handled_request_body', return_value=True)
@patch.object(PreProcessingHandler, '__init__', return_value=None)
def test_do_post_wrongly_formatted_request(self, mock_init, mock_handled_body, mock_wrong_resp, mock_handle_relation):
mock_init.return_value = None
handler = PreProcessingHandler()
handler.rfile = MagicMock()
handler.wfile = MagicMock()
handler.headers = {'Content-Length': '0'}
handler.send_response = MagicMock()
handler.send_header = MagicMock()
handler.end_headers = MagicMock()
mock_handle_relation.side_effect = Exception("test exception")

# simulate a post request call to an invalid endpoint
handler.path = '/tripleconstruction'
handler.do_POST()
mock_wrong_resp.assert_called_once_with("test exception")

@patch.object(PreProcessingHandler, '__init__', return_value=None)
def test_wrongly_formatted_request_response(self, mock_init):
mock_init.return_value = None
handler = PreProcessingHandler()
handler.rfile = MagicMock()
handler.wfile = MagicMock()
handler.wfile.write = MagicMock()
handler.headers = {'Content-Length': '0'}
handler.send_response = MagicMock()
handler.send_header = MagicMock()
handler.end_headers = MagicMock()

handler.wrongly_formatted_request_response("test message")
handler.send_response.assert_called_once_with(422)
handler.send_header.assert_called_once_with('Content-type','text/html')
handler.end_headers.assert_called_once()
handler.wfile.write.assert_called_once()


@patch('server.server.PreProcessingHandler.wrongly_formatted_request_response')
@patch.object(PreProcessingHandler, '__init__', return_value=None)
def test_handled_request_body_exception(self, mock_init, mock_wrongly_formatted_request_response):
mock_init.return_value = None
pph = PreProcessingHandler()

encoded_content = "Wrongly_formatted_data".encode()
post_content = {"post_data": encoded_content, "post_json": {}}
result = pph.handled_request_body(post_content)

self.assertFalse(result)
mock_wrongly_formatted_request_response.assert_called_once()


@patch('server.server.PreProcessingHandler.wrongly_formatted_request_response')
@patch.object(PreProcessingHandler, '__init__', return_value=None)
def test_handled_request_returns_true(self, mock_init, mock_wrongly_formatted_request_response):
mock_init.return_value = None
pph = PreProcessingHandler()

encoded_content = json.dumps({"test": "correct data"}).encode()
post_content = {"post_data": encoded_content, "post_json": {}}
result = pph.handled_request_body(post_content)

self.assertTrue(result)
mock_wrongly_formatted_request_response.assert_not_called()

if __name__ == '__main__':
unittest.main()
Loading

0 comments on commit 3be96c9

Please sign in to comment.