Skip to content

Commit

Permalink
Add train_test_split
Browse files — browse the repository at this point in the history
  • Loading branch information
stephantul committed Dec 24, 2024
1 parent e8058bb commit a59127e
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions model2vec/train/classifier.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import torch
from sklearn.model_selection import train_test_split
from tokenizers import Tokenizer
from torch import nn

Expand Down Expand Up @@ -71,21 +72,21 @@ def loss_calculator(

def fit(
self,
train_texts: list[str],
train_labels: list[str],
validation_texts: list[str],
validation_labels: list[str],
texts: list[str],
labels: list[str],
**kwargs: Any,
) -> FinetunableStaticModel:
"""Fit a model."""
classes = sorted(set(train_labels) | set(validation_labels))
classes = sorted(set(labels))
self.classes_ = classes

if len(self.classes) != self.out_dim:
self.out_dim = len(self.classes)
self.head = self.construct_head()

label_mapping = {label: idx for idx, label in enumerate(self.classes)}
train_texts, validation_texts, train_labels, validation_labels = train_test_split(texts, labels, test_size=0.1)

# Turn labels into a LongTensor
train_labels_tensor = torch.Tensor([label_mapping[label] for label in train_labels]).long()
train_dataset = TextDataset(train_texts, train_labels_tensor, self.tokenizer)
Expand Down

0 comments on commit a59127e

Please sign in to comment.