Skip to content

Commit

Permalink
change code
Browse files Browse the repository at this point in the history
  • Loading branch information
Gambotch1 committed Jun 26, 2024
1 parent 31b0d26 commit 530ee32
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 143 deletions.
8 changes: 2 additions & 6 deletions ner-submission/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
# Dockerfile
# docker build -t fschlatt/authorship-verification-trivial:0.0.1 .
# NOTE(review): the tag in the build command above mentions
# "authorship-verification-trivial" while this file lives under
# ner-submission/ -- confirm the intended image name.
# NOTE(review): this listing appears to come from a commit diff view, so
# lines from the previous and the new revision may both be present -- confirm
# against the repository which instructions are current.

# Base image that the submission image is built on.
FROM fschlatt/natural-language-processing-exercises:0.0.1

# Install the sklearn-crfsuite package with pip.
RUN pip install sklearn-crfsuite

# Copy the scripts and the serialized model into /code inside the image.
ADD run.py /code/run.py
ADD train.py /code/train.py
ADD model.joblib /code/model.joblib

# Run the prediction script when the container starts.
# NOTE(review): ENTRYPOINT is declared twice with the same command; Docker
# only honours the last ENTRYPOINT instruction in a Dockerfile.
ENTRYPOINT ["python3", "/code/run.py"]
ENTRYPOINT [ "python3", "/code/run.py" ]
81 changes: 22 additions & 59 deletions ner-submission/run.py
Original file line number Diff line number Diff line change
@@ -1,72 +1,35 @@
from pathlib import Path
from joblib import load
from tira.rest_api_client import Client
from tira.third_party_integrations import get_output_directory
import pandas as pd

def preprocess_data(text_data):
    """Split every sentence in *text_data* into whitespace-separated tokens.

    Args:
        text_data: A pandas DataFrame with a ``sentence`` column holding one
            string per row.

    Returns:
        A list with one entry per row, in row order; each entry is the list
        of tokens produced by ``str.split()`` (runs of whitespace are
        collapsed, so an empty sentence yields an empty list).
    """
    # A frame without rows yields no sentences, even if it lacks the column.
    if len(text_data) == 0:
        return []
    # Iterate the column directly instead of indexing row by row with .iloc.
    return [sentence.split() for sentence in text_data['sentence']]

def _neighbor_features(prefix, word):
    """Return the word/casing/digit features of a neighbouring token.

    Every key is namespaced with *prefix* (``'-1'`` or ``'+1'``).
    """
    return {
        prefix + ':word': word,
        prefix + ':is_upper': word.isupper(),
        prefix + ':is_title': word.istitle(),
        prefix + ':is_digit': word.isdigit(),
    }


def extract_features(sentence, i):
    """Build the feature dictionary for the token at position *i*.

    Args:
        sentence: Sequence of token strings.
        i: Index of the token to describe.

    Returns:
        A dict holding the token itself, its casing/digit flags and its last
        three characters, followed by the same flags for the previous and
        next token. ``'BOS'`` is set instead when there is no previous token
        and ``'EOS'`` when there is no next token.

    Raises:
        IndexError: If *i* is not a valid index into *sentence*.
    """
    word = sentence[i]
    features = {
        'word': word,
        'is_upper': word.isupper(),
        'is_title': word.istitle(),
        'is_digit': word.isdigit(),
        'suffix-3': word[-3:],
    }

    # Left context, or a beginning-of-sentence marker.
    if i > 0:
        features.update(_neighbor_features('-1', sentence[i - 1]))
    else:
        features['BOS'] = True

    # Right context, or an end-of-sentence marker.
    if i < len(sentence) - 1:
        features.update(_neighbor_features('+1', sentence[i + 1]))
    else:
        features['EOS'] = True

    return features


def sent2features(sentence):
    """Return the feature dictionaries for all tokens of *sentence*, in order."""
    return [extract_features(sentence, i) for i in range(len(sentence))]


# Simple heuristic function to determine entity type based on common patterns
def simple_heuristic_token_classification(token):
    """Guess a named-entity tag for *token* from its casing alone.

    Args:
        token: A single token string.

    Returns:
        ``"B-per"`` for title-case tokens, ``"B-org"`` for all-upper-case
        tokens, and ``"O"`` otherwise. Title case is tested first, so a
        single capital letter such as ``"I"`` is tagged ``"B-per"``.
    """
    if token.istitle():
        return "B-per"  # Assume title case words are persons
    if token.isupper():
        return "B-org"  # Assume upper case words are organizations
    return "O"  # Default to outside any named entity

# Script entry point: fetch the NER input sentences through the TIRA client,
# produce one tag list per sentence and write the result as JSON lines.
# NOTE(review): this span appears to interleave two revisions of the script
# (a model-based one and a heuristic one) and has lost its indentation --
# confirm against the repository which statements are current before running.
if __name__ == "__main__":
tira = Client()

# Load the data
text_validation = tira.pd.inputs("nlpbuw-fsu-sose-24", "ner-validation-20240612-training")

# Preprocess data
val_data = preprocess_data(text_validation)
X_val = [sent2features(s) for s in val_data]

# Load the model
# The serialized model is looked up next to this script file.
model = load(Path(__file__).parent / "model.joblib")
# NOTE(review): second construction of the client; `tira` is already bound above.
tira = Client()

# Predict
y_pred = model.predict(X_val)
# loading validation data (automatically replaced by test data when run on tira)
text_validation = tira.pd.inputs(
"nlpbuw-fsu-sose-24", "ner-validation-20240612-training"
)
# NOTE(review): targets_validation is not referenced again in the visible code.
targets_validation = tira.pd.truths(
"nlpbuw-fsu-sose-24", "ner-validation-20240612-training"
)

# Save predictions
# labeling the data with simple heuristics
predictions = text_validation.copy()
# NOTE(review): 'tags' is assigned twice in a row; the second assignment
# (heuristic tags, one per token from splitting on a single space) replaces
# the first (model output).
predictions['tags'] = [list(x) for x in y_pred]
predictions['tags'] = predictions['sentence'].apply(lambda x: [simple_heuristic_token_classification(token) for token in x.split(' ')])
# Keep only the sentence id and the predicted tags.
predictions = predictions[['id', 'tags']]


# saving the prediction
output_directory = get_output_directory(str(Path(__file__).parent))
# NOTE(review): the same file is written twice with identical arguments
# (one JSON record per line).
predictions.to_json(Path(output_directory) / "predictions.jsonl", orient="records", lines=True)
predictions.to_json(
Path(output_directory) / "predictions.jsonl", orient="records", lines=True
)
78 changes: 0 additions & 78 deletions ner-submission/train.py

This file was deleted.

0 comments on commit 530ee32

Please sign in to comment.