Showing 3 changed files with 141 additions and 32 deletions.
Dockerfile
@@ -1,19 +1,10 @@
-# Use an appropriate base image for natural language processing exercises
+# Dockerfile
 FROM fschlatt/natural-language-processing-exercises:0.0.1
-
-# Install system dependencies
-RUN apt-get update && \
-    apt-get install -y --no-install-recommends \
-    build-essential \
-    gcc \
-    g++ \
-    && rm -rf /var/lib/apt/lists/*
+RUN pip install sklearn-crfsuite
 
-# Install Python dependencies
-RUN pip install tira rest_api_client pandas sklearn_crfsuite matplotlib seqeval evaluate
-
-# Add the script to the image
 ADD run.py /code/run.py
+ADD train.py /code/train.py
+ADD model.joblib /code/model.joblib
 
 # Define the entry point
 ENTRYPOINT ["python3", "/code/run.py"]
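The slimmed-down image keeps only the RUN pip install sklearn-crfsuite layer and bakes the pretrained model.joblib into /code, so train.py must be run before the image is built (ADD requires the file to exist at build time). As a quick sanity check, a hypothetical snippet like the following (not part of the commit) could be run inside the container to confirm the pickled model loads from the path run.py expects:

# smoke_test.py -- hypothetical helper, not part of the commit.
# Confirms the CRF pickled by train.py loads from the path run.py uses.
from pathlib import Path
from joblib import load

model = load(Path("/code/model.joblib"))
print(type(model).__name__)  # expected: CRF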
run.py
@@ -1,32 +1,72 @@
-import json
 from pathlib import Path
+from joblib import load
 from tira.rest_api_client import Client
 from tira.third_party_integrations import get_output_directory
+import pandas as pd
 
-def generate_predictions(text_validation):
-    predictions = []
-    for _, row in text_validation.iterrows():
-        sentence = row['sentence']
-        tokens = sentence.split()
-        tags = ['B-geo'] * len(tokens)  # Simplified tagging logic
-        predictions.append({'id': int(row['id']), 'tags': tags})
-    return predictions
+def preprocess_data(text_data):
+    data = []
+    for i in range(len(text_data)):
+        sentence = text_data.iloc[i]['sentence'].split()
+        data.append(sentence)
+    return data
+
+def extract_features(sentence, i):
+    word = sentence[i]
+    features = {
+        'word': word,
+        'is_upper': word.isupper(),
+        'is_title': word.istitle(),
+        'is_digit': word.isdigit(),
+        'suffix-3': word[-3:],
+    }
+    if i > 0:
+        word1 = sentence[i-1]
+        features.update({
+            '-1:word': word1,
+            '-1:is_upper': word1.isupper(),
+            '-1:is_title': word1.istitle(),
+            '-1:is_digit': word1.isdigit(),
+        })
+    else:
+        features['BOS'] = True
+
+    if i < len(sentence)-1:
+        word1 = sentence[i+1]
+        features.update({
+            '+1:word': word1,
+            '+1:is_upper': word1.isupper(),
+            '+1:is_title': word1.istitle(),
+            '+1:is_digit': word1.isdigit(),
+        })
+    else:
+        features['EOS'] = True
+
+    return features
+
+def sent2features(sentence):
+    return [extract_features(sentence, i) for i in range(len(sentence))]
 
 if __name__ == "__main__":
     tira = Client()
 
-    # Load validation data (automatically replaced by test data when run on TIRA)
+    # Load the data
     text_validation = tira.pd.inputs("nlpbuw-fsu-sose-24", "ner-validation-20240612-training")
 
-    # Generate predictions
-    predictions = generate_predictions(text_validation)
+    # Preprocess data
+    val_data = preprocess_data(text_validation)
+    X_val = [sent2features(s) for s in val_data]
 
-    # Save predictions
-    output_directory = get_output_directory(str(Path(__file__).parent))
-    output_file = Path(output_directory) / "predictions.jsonl"
+    # Load the model
+    model = load(Path(__file__).parent / "model.joblib")
 
-    with open(output_file, 'w') as f:
-        for prediction in predictions:
-            f.write(json.dumps(prediction) + "\n")
+    # Predict
+    y_pred = model.predict(X_val)
 
-    print(f"Predictions saved to {output_file}")
+    # Save predictions
+    predictions = text_validation.copy()
+    predictions['tags'] = [list(x) for x in y_pred]
+    predictions = predictions[['id', 'tags']]
+
+    output_directory = get_output_directory(str(Path(__file__).parent))
+    predictions.to_json(Path(output_directory) / "predictions.jsonl", orient="records", lines=True)
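For illustration (not part of the commit), here is what sent2features, as defined above, produces for the first token of a toy sentence; each token becomes one feature dict with a one-token window on either side:

# Illustration only: assumes the sent2features/extract_features defined
# in run.py above are in scope.
sentence = ["Angela", "Merkel", "visited", "Paris"]
feats = sent2features(sentence)
print(feats[0])
# {'word': 'Angela', 'is_upper': False, 'is_title': True, 'is_digit': False,
#  'suffix-3': 'ela', 'BOS': True, '+1:word': 'Merkel', '+1:is_upper': False,
#  '+1:is_title': True, '+1:is_digit': False}

model.predict then returns one tag sequence per sentence, and to_json(orient="records", lines=True) writes each as one JSON object per line, e.g. {"id": 0, "tags": ["B-per", "I-per", "O", "B-geo"]} (the tag values here are illustrative).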
train.py
@@ -0,0 +1,78 @@
+from pathlib import Path
+from joblib import dump
+import pandas as pd
+import sklearn_crfsuite
+from sklearn_crfsuite import metrics
+from tira.rest_api_client import Client
+
+def preprocess_data(text_data, labels_data):
+    data = []
+    for i in range(len(text_data)):
+        sentence = text_data.iloc[i]['sentence'].split()
+        labels = labels_data.iloc[i]['tags']
+        data.append((sentence, labels))
+    return data
+
+def extract_features(sentence, i):
+    word = sentence[i]
+    features = {
+        'word': word,
+        'is_upper': word.isupper(),
+        'is_title': word.istitle(),
+        'is_digit': word.isdigit(),
+        'suffix-3': word[-3:],
+    }
+    if i > 0:
+        word1 = sentence[i-1]
+        features.update({
+            '-1:word': word1,
+            '-1:is_upper': word1.isupper(),
+            '-1:is_title': word1.istitle(),
+            '-1:is_digit': word1.isdigit(),
+        })
+    else:
+        features['BOS'] = True
+
+    if i < len(sentence)-1:
+        word1 = sentence[i+1]
+        features.update({
+            '+1:word': word1,
+            '+1:is_upper': word1.isupper(),
+            '+1:is_title': word1.istitle(),
+            '+1:is_digit': word1.isdigit(),
+        })
+    else:
+        features['EOS'] = True
+
+    return features
+
+def sent2features(sentence):
+    return [extract_features(sentence, i) for i in range(len(sentence))]
+
+def sent2labels(sentence):
+    return [label for label in sentence]
+
+if __name__ == "__main__":
+    tira = Client()
+
+    # Load the data
+    text_train = tira.pd.inputs("nlpbuw-fsu-sose-24", "ner-training-20240612-training")
+    targets_train = tira.pd.truths("nlpbuw-fsu-sose-24", "ner-training-20240612-training")
+
+    # Preprocess data
+    train_data = preprocess_data(text_train, targets_train)
+    X_train = [sent2features(s) for s, t in train_data]
+    y_train = [t for s, t in train_data]
+
+    # Train CRF model
+    crf = sklearn_crfsuite.CRF(
+        algorithm='lbfgs',
+        c1=0.1,
+        c2=0.1,
+        max_iterations=100,
+        all_possible_transitions=True
+    )
+    crf.fit(X_train, y_train)
+
+    # Save the model
+    dump(crf, Path(__file__).parent / "model.joblib")
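train.py imports sklearn_crfsuite.metrics but never calls it. If a quick quality check were wanted, a held-out score could be computed along these lines (a sketch only; the 90/10 split and the assumption of a majority 'O' tag are not part of the commit):

# Hypothetical evaluation sketch, not in the commit: hold out the last 10%
# of sentences, fit on the rest, and report weighted F1 over entity tags.
# Assumes X_train, y_train, crf, and metrics from train.py above.
split = int(0.9 * len(X_train))
crf.fit(X_train[:split], y_train[:split])
y_dev_pred = crf.predict(X_train[split:])
labels = [l for l in crf.classes_ if l != 'O']  # assumes the usual 'O' tag
print(metrics.flat_f1_score(y_train[split:], y_dev_pred,
                            average='weighted', labels=labels))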