This project implements a multi-label classification model that combines image and text data using a custom CNN and text embedding models. The model is trained with PyTorch. The training code addresses class imbalance, applies data augmentation, and reports evaluation metrics for multi-label classification.
Ensure you have the following software installed:
- Python 3.7 or higher
- PyTorch
- torchvision
- scikit-learn
- pandas
- numpy
- torchmetrics
- torchtext
- PIL (Pillow)

The project uses the following files:

- `train.csv` and `test.csv`: CSV files containing the training and testing data, with columns 'ImageID', 'Caption', and 'Labels'.
- `fixed_data/`: Directory containing the processed images.
- `vocab.pkl`: The vocabulary file generated from the captions.
- `main.py`: The main script containing the code for data preprocessing, model definition, training, and evaluation.
- `README.md`: This file.
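
As a rough illustration of how the 'Labels' column could feed a multi-label model, the sketch below turns each row's labels into a multi-hot vector. It assumes, purely for illustration, that 'Labels' holds space-separated integer class IDs and that the number of classes is known in advance. Neither detail is stated in this README, and `to_multi_hot` is a hypothetical helper, not a function from `main.py`.

```python
import csv
import io


def to_multi_hot(label_field, num_classes):
    """Convert a string such as "1 3" into a 0/1 list of length num_classes.

    Assumption: labels are space-separated integer IDs in [0, num_classes).
    """
    vector = [0] * num_classes
    for token in label_field.split():
        vector[int(token)] = 1
    return vector


# Tiny in-memory stand-in for train.csv (made-up rows, for illustration only).
sample_csv = io.StringIO(
    "ImageID,Caption,Labels\n"
    "0.jpg,A dog on a beach,1 3\n"
    "1.jpg,Two people riding bikes,0\n"
)

rows = list(csv.DictReader(sample_csv))
targets = [to_multi_hot(row["Labels"], num_classes=4) for row in rows]
print(targets)
```

If the real files use a different label format (for example comma-separated names), only the parsing inside `to_multi_hot` would need to change.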

To set up the project:

- Clone the repository:

      git clone <repository_url>
      cd <repository_directory>

- Install the required Python packages:

      pip install torch torchvision scikit-learn pandas numpy torchmetrics torchtext pillow

- Prepare the dataset:
  - Ensure `train.csv` and `test.csv` are in the project root directory.
  - Ensure the images mentioned in the CSV files are placed in the `fixed_data/` directory.
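
Before training, it can help to confirm that every image referenced in the CSV files is actually present. The following is a minimal sketch of such a check, assuming each 'ImageID' value is a file name relative to the image directory (the README does not confirm this). `find_missing_images` is a hypothetical helper and is not part of `main.py`.

```python
import csv
from pathlib import Path


def find_missing_images(csv_path, image_dir):
    """Return the ImageID values in csv_path that have no file in image_dir.

    Assumption: each 'ImageID' value is a file name relative to image_dir.
    """
    image_dir = Path(image_dir)
    with open(csv_path, newline="", encoding="utf-8") as handle:
        image_ids = [row["ImageID"] for row in csv.DictReader(handle)]
    return [name for name in image_ids if not (image_dir / name).is_file()]


if __name__ == "__main__":
    # Only check the CSV files that exist in the current directory.
    for csv_name in ("train.csv", "test.csv"):
        if Path(csv_name).is_file():
            missing = find_missing_images(csv_name, "fixed_data")
            print(f"{csv_name}: {len(missing)} image(s) missing")
```

An empty result for both files suggests the dataset layout matches what the training script expects.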

Run the pipeline steps from Python:

- Preprocess the images:

      from main import process_dataset
      process_dataset('path/to/input_folder', 'path/to/output_folder', target_size=300)

- Build the vocabulary:

      from main import build_vocab
      build_vocab('train.csv', 'test.csv', 'vocab.pkl')

- Train the model:

      from main import train_model
      train_model('train.csv', 'vocab.pkl')
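
This README does not describe what `build_vocab` writes to `vocab.pkl`. As a rough mental model only, a caption vocabulary is often a token-to-index mapping saved with `pickle`. The sketch below assumes lowercase whitespace tokenization, `<pad>` and `<unk>` special tokens, and a `min_freq` cutoff. All of these are illustrative assumptions, and `build_vocab_sketch` and `encode` are hypothetical helpers, not the functions in `main.py`.

```python
import pickle
from collections import Counter


def build_vocab_sketch(captions, min_freq=1):
    """Map each sufficiently frequent token to an integer index."""
    counts = Counter(
        token for caption in captions for token in caption.lower().split()
    )
    # Reserved indices: 0 for padding, 1 for unknown tokens (an assumption).
    vocab = {"<pad>": 0, "<unk>": 1}
    # Sort by descending frequency, then alphabetically, for a stable order.
    for token, freq in sorted(counts.items(), key=lambda item: (-item[1], item[0])):
        if freq >= min_freq:
            vocab[token] = len(vocab)
    return vocab


def encode(caption, vocab):
    """Turn a caption into a list of indices, using <unk> for unseen tokens."""
    return [vocab.get(token, vocab["<unk>"]) for token in caption.lower().split()]


captions = ["A dog on a beach", "A dog and a cat"]
vocab = build_vocab_sketch(captions)

# Round-trip through pickle, as a vocabulary file on disk would be.
restored = pickle.loads(pickle.dumps(vocab))
print(encode("a cat on a sofa", restored))
```

The word "sofa" never appears in the sample captions, so it maps to the `<unk>` index.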

To start training the model, run the following command:

    python main.py
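
The overview mentions class imbalance without saying how it is handled. One common approach in multi-label training is to weight the positive examples of each class by the ratio of negatives to positives. Whether `main.py` does this is not stated here, so the sketch below is only an illustration in plain Python, and `positive_weights` is a hypothetical helper.

```python
def positive_weights(targets):
    """Return negatives/positives per class for a list of multi-hot rows.

    Classes with no positive example get weight 1.0 to avoid division by zero.
    """
    num_samples = len(targets)
    num_classes = len(targets[0])
    weights = []
    for class_index in range(num_classes):
        positives = sum(row[class_index] for row in targets)
        negatives = num_samples - positives
        weights.append(negatives / positives if positives else 1.0)
    return weights


# Made-up multi-hot targets: 4 samples, 3 classes.
targets = [
    [1, 0, 0],
    [1, 1, 0],
    [1, 0, 0],
    [0, 1, 0],
]
weights = positive_weights(targets)
print(weights)
```

In PyTorch, a list like this is typically converted with `torch.tensor(weights)` and passed as the `pos_weight` argument of `torch.nn.BCEWithLogitsLoss`, so that rare classes contribute more to the loss.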