🔥 Leveraging a Hierarchical Attention Network with Transformer (HANTransformer) for document classification on the 20 Newsgroups dataset.
Contents:
1)Overview
2)Features
3)Model Architecture
4)Prerequisites
5)Installation
6)Preprocessing
7)Training
8)Evaluation
9)Usage
10)Contributing
11)License
The Hierarchical Attention Network with Transformer (HANTransformer) is a sophisticated model designed for document classification tasks. By leveraging a hierarchical structure, the model effectively captures both word-level and sentence-level information, integrating Part-of-Speech (POS) tags and rule-based embeddings to enhance its understanding of textual data. This project demonstrates the implementation, preprocessing, and training processes required to deploy the HANTransformer using the 20 Newsgroups dataset.
1)Hierarchical Structure: Processes documents at both word and sentence levels.
2)Transformer-Based: Utilizes multi-head self-attention mechanisms for contextual understanding.
3)Incorporates POS Tags and Rule Embeddings: Enhances feature representation with linguistic information.
4)Scalable Preprocessing: Efficiently tokenizes and encodes data using multiprocessing.
5)Flexible Configuration: Easily adjustable hyperparameters for various use-cases.
6)Comprehensive Training Pipeline: Includes training, evaluation, and model saving functionalities.
The HANTransformer model comprises several key components:
1)Fusion Layer: Combines word embeddings, POS tag embeddings, and rule embeddings using a gating mechanism (see the sketch after this list).
2)Positional Encoding: Adds learnable positional information to embeddings.
3)Multi-Head Self-Attention: Captures dependencies and relationships within the data.
4)Transformer Encoder Layers: Stacks multiple layers of attention and feed-forward networks for deep feature extraction.
5)Attention Mechanisms: Applies attention at both word and sentence levels to generate meaningful representations.
6)Classification Head: Outputs logits corresponding to the target classes.
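For intuition, here is a minimal sketch of the gating idea behind the fusion layer; the dimension and parameter names are hypothetical, and the layer in this repository may be structured differently:

```python
import torch
import torch.nn as nn

class FusionGate(nn.Module):
    """Illustrative gated fusion of word, POS-tag, and rule embeddings."""

    def __init__(self, d_word: int, d_pos: int, d_rule: int, d_model: int):
        super().__init__()
        d_in = d_word + d_pos + d_rule
        self.fuse = nn.Linear(d_in, d_model)        # mixes the three streams
        self.gate = nn.Linear(d_in, d_model)        # decides how much fused signal to keep
        self.word_proj = nn.Linear(d_word, d_model)

    def forward(self, word_emb, pos_emb, rule_emb):
        # Concatenate the three embedding streams along the feature dimension.
        combined = torch.cat([word_emb, pos_emb, rule_emb], dim=-1)
        fused = torch.tanh(self.fuse(combined))
        g = torch.sigmoid(self.gate(combined))      # gate values in (0, 1)
        # Blend the fused features with the plain word embedding, per dimension.
        return g * fused + (1 - g) * self.word_proj(word_emb)
```

The gate lets the network decide, feature by feature, how strongly the POS and rule information should influence each word representation.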
Prerequisites:
Python 3.7+
pip package manager
Installation:
git clone https://github.com/yourusername/your-repo-name.git
cd your-repo-name
conda create -n HANT python=3.11
conda activate HANT
pip install -r requirements.txt
python -m spacy download en_core_web_sm
The preprocessing pipeline tokenizes the text data, builds vocabularies, encodes the texts, POS tags, and rules, and saves the processed data for training.
Steps:
1)Tokenization: Splits documents into sentences and words (illustrated in the sketch below).
2)Vocabulary Building: Constructs vocabularies for words, POS tags, and rules.
3)Encoding: Converts tokens and tags into numerical IDs.
4)Mask Creation: Generates attention and sentence masks to handle padding.
5)Saving Processed Data: Stores the preprocessed data in JSON format.
python preprocess.py
This will create a data directory containing processed_data.json.
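To illustrate steps 1 and 4 above, here is a minimal sketch of sentence/word tokenization and mask creation. It assumes the spaCy en_core_web_sm pipeline installed earlier and hypothetical padding limits; preprocess.py may organize this differently:

```python
import spacy

nlp = spacy.load("en_core_web_sm")

def tokenize_and_mask(text: str, max_sents: int = 10, max_words: int = 30):
    """Split a document into sentences and words, then build padding masks."""
    doc = nlp(text)
    # Step 1: sentence and word tokenization, truncated to the padding limits.
    sents = [[tok.text.lower() for tok in sent if not tok.is_space][:max_words]
             for sent in doc.sents][:max_sents]
    # Step 4: word-level attention mask (1 = real token, 0 = padding position).
    attn_mask = [[1] * len(s) + [0] * (max_words - len(s)) for s in sents]
    attn_mask += [[0] * max_words] * (max_sents - len(sents))
    # Sentence-level mask (1 = real sentence, 0 = padded sentence slot).
    sent_mask = [1] * len(sents) + [0] * (max_sents - len(sents))
    return sents, attn_mask, sent_mask

sents, attn_mask, sent_mask = tokenize_and_mask(
    "The rocket launch was delayed. Weather was the main factor.")
```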
The training script initializes the model, loads the preprocessed data, and trains the HANTransformer on the 20 Newsgroups dataset.
Steps:
1)Load Processed Data: Reads the preprocessed JSON data.
2)Create Datasets and DataLoaders: Prepares data for batching.
3)Initialize the Model: Sets up the HANTransformer with specified hyperparameters.
4)Define Loss and Optimizer: Uses CrossEntropyLoss and the Adam optimizer.
5)Training Loop: Trains the model for a defined number of epochs, evaluating on the test set after each one (see the sketch below).
6)Model Saving: Saves the best-performing model based on test accuracy.
python train.py
The best model will be saved as best_model.pt in the data directory.
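As a rough sketch of steps 4 to 6 (not the literal contents of train.py), the core loop could look like the following; the batch layout and the model's forward signature are assumptions:

```python
import torch
import torch.nn as nn

def train(model, train_loader, test_loader, num_epochs=10, lr=1e-3, device="cpu"):
    model.to(device)
    criterion = nn.CrossEntropyLoss()                           # step 4: loss
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)     # step 4: optimizer
    best_acc = 0.0
    for epoch in range(num_epochs):                             # step 5: training loop
        model.train()
        for docs, attn_mask, sent_mask, labels in train_loader:  # assumed batch layout
            docs, labels = docs.to(device), labels.to(device)
            attn_mask, sent_mask = attn_mask.to(device), sent_mask.to(device)
            optimizer.zero_grad()
            logits = model(docs, attn_mask, sent_mask)          # assumed forward signature
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

        # Evaluate on the test set after each epoch.
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for docs, attn_mask, sent_mask, labels in test_loader:
                docs, labels = docs.to(device), labels.to(device)
                attn_mask, sent_mask = attn_mask.to(device), sent_mask.to(device)
                preds = model(docs, attn_mask, sent_mask).argmax(dim=-1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        acc = correct / total

        if acc > best_acc:                                      # step 6: keep the best model
            best_acc = acc
            torch.save(model.state_dict(), "data/best_model.pt")
        print(f"epoch {epoch + 1}: test accuracy {acc:.4f}")
```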
During training, the model is evaluated on the test set after each epoch; test accuracy is reported and used to select the best checkpoint. To run the evaluation on its own:
python evaluate.py
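A standalone evaluation script could follow the sketch below; the module names, HANTransformer constructor arguments, and data-loading helper are placeholders rather than this repository's actual API:

```python
import torch

# Placeholder imports: the real module and helper names may differ in this repo.
from model import HANTransformer
from data_utils import get_test_loader

device = "cuda" if torch.cuda.is_available() else "cpu"
model = HANTransformer(num_classes=20)                    # placeholder arguments
model.load_state_dict(torch.load("data/best_model.pt", map_location=device))
model.to(device).eval()

correct = total = 0
with torch.no_grad():
    for docs, attn_mask, sent_mask, labels in get_test_loader():
        docs, labels = docs.to(device), labels.to(device)
        attn_mask, sent_mask = attn_mask.to(device), sent_mask.to(device)
        preds = model(docs, attn_mask, sent_mask).argmax(dim=-1)   # assumed signature
        correct += (preds == labels).sum().item()
        total += labels.size(0)
print(f"Test accuracy: {correct / total:.4f}")
```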
After training, we can make predictions by running the prediction script:
python predict.py
This will prompt the user for an input sentence and print the predicted class.
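A minimal sketch of what such a script might do is shown below; the imports from model and preprocess are hypothetical helper names, while fetch_20newsgroups comes from scikit-learn and supplies the 20 class labels:

```python
import torch
from sklearn.datasets import fetch_20newsgroups

# Hypothetical project imports; the actual names in this repo may differ.
from model import HANTransformer
from preprocess import encode_document   # raw text -> (doc_ids, attn_mask, sent_mask) tensors

class_names = fetch_20newsgroups(subset="train").target_names   # the 20 newsgroup labels

device = "cuda" if torch.cuda.is_available() else "cpu"
model = HANTransformer(num_classes=len(class_names))            # placeholder arguments
model.load_state_dict(torch.load("data/best_model.pt", map_location=device))
model.to(device).eval()

text = input("Enter a sentence to classify: ")
doc_ids, attn_mask, sent_mask = encode_document(text)           # hypothetical helper
with torch.no_grad():
    logits = model(doc_ids.unsqueeze(0).to(device),             # add a batch dimension
                   attn_mask.unsqueeze(0).to(device),
                   sent_mask.unsqueeze(0).to(device))
print("Predicted class:", class_names[logits.argmax(dim=-1).item()])
```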