Skip to content

Commit

Permalink
Merge pull request #31 from PascalEgn/fix-memory-usage
Browse files Browse the repository at this point in the history
classifier: fix memory usage
  • Loading branch information
PascalEgn authored Aug 2, 2024
2 parents 86e3221 + 6972469 commit d59c610
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 24 deletions.
47 changes: 29 additions & 18 deletions inspire_classifier/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,24 +151,30 @@ def train():
train_and_save_classifier()


def predict_coreness(title, abstract):
def initialize_classifier():
"""
Predicts class-wise probabilities given the title and abstract.
Initializes the classifier.
"""
text = title + " <ENDTITLE> " + abstract
categories = ["rejected", "non_core", "core"]
try:
classifier = Classifier(
cuda_device_id=current_app.config["CLASSIFIER_CUDA_DEVICE_ID"]
)
except IOError as error:
raise IOError("Data ITOS not found.") from error

classifier = Classifier(
cuda_device_id=current_app.config["CLASSIFIER_CUDA_DEVICE_ID"]
)
try:
classifier.load_trained_classifier_weights(path_for("trained_classifier"))
except IOError as error:
raise IOError("Could not load the trained classifier weights.") from error
raise IOError(
"Could not load the trained classifier weights.",
path_for("trained_classifier"),
) from error

return classifier


def predict_coreness(classifier, title, abstract):
"""
Predicts class-wise probabilities given the title and abstract.
"""
text = title + " <ENDTITLE> " + abstract
categories = ["rejected", "non_core", "core"]
class_probabilities = classifier.predict(
text, temperature=current_app.config["CLASSIFIER_SOFTMAX_TEMPERATUR"]
)
Expand All @@ -191,17 +197,22 @@ def validate(validation_df):
raise IOError("There was a problem loading the classifier model") from error
predictions = []
validation_df = validation_df.sample(frac=1, random_state=42)
for _, row in tqdm(
validation_df.iterrows(), total=len(validation_df.label.values)
):
for _, row in tqdm(validation_df.iterrows(), total=len(validation_df.label.values)):
predicted_value = classifier.predict(
row.text, temperature=current_app.config["CLASSIFIER_SOFTMAX_TEMPERATUR"]
)
predicted_class = np.argmax(predicted_value)
predictions.append(predicted_class)

validation_df.insert(2, 'predicted_label', predictions)
validation_df.insert(2, "predicted_label", predictions)
validation_df.to_csv(f"{path_for('data')}/validation_results.csv", index=False)
print("f1 score ", f1_score(validation_df["label"], validation_df["predicted_label"], average="micro"))
pprint(classification_report(validation_df["label"], validation_df["predicted_label"]))
print(
"f1 score ",
f1_score(
validation_df["label"], validation_df["predicted_label"], average="micro"
),
)
pprint(
classification_report(validation_df["label"], validation_df["predicted_label"])
)
pprint(confusion_matrix(validation_df["label"], validation_df["predicted_label"]))
6 changes: 4 additions & 2 deletions inspire_classifier/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from prometheus_flask_exporter.multiprocess import GunicornInternalPrometheusMetrics
from webargs.flaskparser import use_args

from inspire_classifier.api import predict_coreness
from inspire_classifier.api import initialize_classifier, predict_coreness

from . import serializers

Expand All @@ -55,6 +55,8 @@ def create_app():
app.config["CLASSIFIER_BASE_PATH"] = app.instance_path
app.config.from_object("inspire_classifier.config")
app.config.from_pyfile("classifier.cfg", silent=True)
with app.app_context():
classifier = initialize_classifier()

@app.route("/api/health")
def date():
Expand All @@ -69,7 +71,7 @@ def date():
)
def core_classifier(args):
"""Endpoint for the CORE classifier."""
prediction = predict_coreness(args["title"], args["abstract"])
prediction = predict_coreness(classifier, args["title"], args["abstract"])
response = coreness_schema.dump(prediction)
return response

Expand Down
16 changes: 12 additions & 4 deletions scripts/create_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,13 @@ def __init__(self, index, query_filters, year_from, year_to, month_from, month_t
self.inspire_categories_field = "inspire_categories.term"
self.query_filters = [
query_filters
& Q("range", _created={"gte": f"{self.year_from}-{self.month_from}",
"lt": f"{self.year_to}-{self.month_to}",}),
& Q(
"range",
_created={
"gte": f"{self.year_from}-{self.month_from}",
"lt": f"{self.year_to}-{self.month_to}",
},
),
]

def _postprocess_record_data(self, record_data):
Expand Down Expand Up @@ -143,7 +148,9 @@ def prepare_inspire_classifier_dataset(data, save_data_path):
inspire_data_df["text"] = (
inspire_data_df["title"] + " <ENDTITLE> " + inspire_data_df["abstract"]
)
inspire_classifier_data_df = inspire_data_df[["id", "inspire_categories", "label", "text"]]
inspire_classifier_data_df = inspire_data_df[
["id", "inspire_categories", "label", "text"]
]
inspire_classifier_data_df.to_pickle(save_data_path)


Expand All @@ -163,7 +170,8 @@ def get_inspire_classifier_dataset(year_from, year_to, month_from, month_to):
month_to = f"{month_to:02d}-31"
print(f"Fetching {year_from}-{month_from} to {year_to}-{month_to}")
inspire_classifier_dataset_path = os.path.join(
os.getcwd(), f"inspire_classifier_dataset_{year_from}-{month_from}_{year_to}-{month_to}.pkl"
os.getcwd(),
f"inspire_classifier_dataset_{year_from}-{month_from}_{year_to}-{month_to}.pkl",
)
data = get_data_for_decisions(year_from, year_to, month_from, month_to)
prepare_inspire_classifier_dataset(data, inspire_classifier_dataset_path)
Expand Down

0 comments on commit d59c610

Please sign in to comment.