Dev deepfm for check #385

Open · wants to merge 23 commits into base branch `dev_deepfm`
106 changes: 106 additions & 0 deletions KnowledgeDistillation/KnowledgeDistillation/README.md
@@ -0,0 +1,106 @@
# Knowledge Distillation
Implementation of the knowledge distillation algorithm with [OneFlow](https://github.com/Oneflow-Inc/oneflow#install-with-pip-package).

---
## Knowledge Distillation
Very large models are difficult to serve online. Knowledge distillation (KD) addresses this by training a model with a much smaller parameter count to achieve accuracy comparable to the large model.

KD involves two roles, the Teacher and the Student:
- Teacher: the original large model, usually trained directly on the supervised data. At inference time it produces a probability distribution over the classes for each sample;
- Student: the small model to be obtained. It learns from the per-sample probability distributions produced by the Teacher, i.e. it learns the Teacher's prior.

Simple KD therefore proceeds in two steps: first train the Teacher model, then train the Student model.

In addition, we also train a student model without the Teacher's guidance for comparison, to verify that the algorithm is indeed effective.
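At the core of step two is the student's loss: a soft-target term that matches the student's temperature-softened distribution to the Teacher's, mixed with the ordinary hard-label cross entropy via the weight `--alpha`. The repository's `train.py` is not part of this diff, so the following is only a minimal sketch of that loss in the usual Hinton-style formulation; the function name `distillation_loss` is illustrative:

```python
import oneflow.nn.functional as F


def distillation_loss(student_logits, teacher_logits, labels, temperature, alpha):
    # Soft-target term: KL(teacher || student) between temperature-softened
    # distributions, built from softmax/log_softmax primitives.
    p_teacher = F.softmax(teacher_logits / temperature, dim=1)
    log_p_teacher = F.log_softmax(teacher_logits / temperature, dim=1)
    log_p_student = F.log_softmax(student_logits / temperature, dim=1)
    soft = (p_teacher * (log_p_teacher - log_p_student)).sum(dim=1).mean()
    # The T^2 factor keeps the soft-term gradient scale comparable
    # across different temperatures.
    soft = soft * temperature ** 2
    # Hard-target term: plain cross entropy against the ground-truth labels.
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1.0 - alpha) * hard
```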

## Data Acquisition
MNIST (60,000 training images, 10,000 test images).
OneFlow already provides the data-acquisition code (OFRecord format), so the dataset is downloaded automatically when the scripts run; no manual download is needed. For details about the data itself, see http://yann.lecun.com/exdb/mnist/
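For reference, `infer.py` loads the test split through flowvision with the standard MNIST normalization; the same pattern works for the training split with `train=True`:

```python
from flowvision import datasets, transforms

# Downloads MNIST into ./ on first use and applies the standard normalization.
dataset = datasets.MNIST(
    "./",
    train=False,
    download=True,
    transform=transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    ),
)
```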

## Requirements

This project uses the nightly build of OneFlow, which you can install with one of the following commands.
CPU:
```bash
python3 -m pip install -f https://staging.oneflow.info/branch/master/cpu --pre oneflow
```
GPU:
```bash
python3 -m pip install -f https://staging.oneflow.info/branch/master/cu112 --pre oneflow
```
You can install other dependencies using the following command.
```bash
pip install -r requirements.txt
```

## Experiment Settings

Run the command below to launch training. The `--model_type` parameter selects the training task; its options are explained after the command:
```bash
bash train.sh
```

"teacher": training Teacher model.
"student_kd": train the student model under the guidance of the Teacher model.
"student": train the student model (without the guidance of the Teacher model).
"compare": compare the advantages of the above three models.

### Step 1: Training the Teacher Model
Set `--model_type` to `teacher` and run the command to train the Teacher model.

This yields a test accuracy of 98.19%.

### Step 2: Training the Student Model with the Teacher Model
Set `--model_type` to `student_kd` and run the command to train the student_kd model.

Select the Teacher checkpoint that performs best on the test set, and use the soft labels it produces as supervision when training the student model.

Take the checkpoint with the highest accuracy (e.g. `output/model_save/teacher`), then set `--load_teacher_checkpoint_dir="./output/model_save/teacher"` to load the Teacher model; the student model is then trained under the Teacher's guidance, as sketched below.
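For illustration, loading the Teacher checkpoint and producing soft labels can look like this minimal sketch, which mirrors the loading pattern in `infer.py` (variable names are illustrative):

```python
import oneflow as flow
import oneflow.nn.functional as F

from model import TeacherNet

teacher = TeacherNet()
teacher.load_state_dict(flow.load("./output/model_save/teacher"))
teacher.eval()

with flow.no_grad():
    # Temperature-softened Teacher outputs serve as the student's soft targets.
    soft_targets = F.softmax(teacher(data) / temperature, dim=1)
```

Here `data` and `temperature` stand for a batch from the training loader and the `--temperature` flag, respectively.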

This yields a test accuracy of 89.19%.

### Step 3: Training the Student Model without the Teacher Model
Set `--model_type` to `student` and run the command to train the student model.

This yields a test accuracy of 89.19%.

### Step 4: Compare the Performance of the Three Models

Set `--model_type` to `compare` to run the three tasks above in one pass and verify that knowledge distillation does improve the student model's performance.
![](./output/images/compare_result.jpg)

### Explanation of Command Parameters for the `train.py` Script
| parameters | meaning | remarks |
| -------- | -------- | -------- |
| --model_save_dir | Path to save the model. | The model names are: teacher, student_kd, student. |
| --image_save_name | File name for the picture saved after training. | |
| --load_teacher_checkpoint_dir | Path from which to load the Teacher model. | |
| --model_type | Type of model. | Only four options: teacher, student_kd, student, and compare. |
| --epochs | Number of epochs to train the models. | |
| --batch_size | Training batch size on each device. | |
| --temperature | Distillation temperature for the Teacher model. | |
| --alpha | Weight coefficient for the soft-target term in the student loss. | |
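For instance, a student_kd run that combines these flags might look like the following (the flag names follow the table above; the epoch, batch-size, and alpha values are illustrative, not the repository's defaults):

```bash
python train.py \
    --model_type="student_kd" \
    --model_save_dir="./output/model_save" \
    --load_teacher_checkpoint_dir="./output/model_save/teacher" \
    --epochs=10 \
    --batch_size=64 \
    --temperature=10 \
    --alpha=0.7
```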


## Inference on a Single Image

### Quick Start
You can run the following command to infer a single picture:
```bash
bash infer.sh
```

### Explanation of Command Parameters for the `infer.py` Script
You can also adjust the parameters below to control the prediction.

| parameters | meaning | remarks |
| -------- | -------- | -------- |
| --model_type | Type of model. | Only three options: teacher, student_kd, and student. |
| --model_load_dir | Path from which to load the model. | |
| --image_save_name | Path at which to save the predicted picture. | |
| --picture_index | Index of the flowvision test-set image to run inference on. | The range is 0-9999. |
| --temperature | Distillation temperature applied to the model output. | |

The predicted results look similar to the following figure:
![](./output/images/infer_result.jpg)
From top to bottom: the input image, the predicted distribution at temperature 1, and the softened distribution at the distillation temperature.
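To see what the temperature does, here is a small worked example using the same `softmax_t` helper as `infer.py` (the logits are made-up values for illustration):

```python
import numpy as np


def softmax_t(x, t):
    # Temperature-scaled softmax: larger t flattens the distribution.
    x_exp = np.exp(x / t)
    return x_exp / np.sum(x_exp)


logits = np.array([8.0, 2.0, 1.0])
print(softmax_t(logits, 1.0))   # sharp:    ~[0.997, 0.002, 0.001]
print(softmax_t(logits, 10.0))  # softened: ~[0.49, 0.27, 0.24]
```

At temperature 1 the distribution is dominated by the largest logit; at temperature 10 the relative similarity of the other classes becomes visible, which is exactly the signal the student learns from.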
96 changes: 96 additions & 0 deletions KnowledgeDistillation/KnowledgeDistillation/infer.py
@@ -0,0 +1,96 @@
import argparse
import os
import time

import matplotlib.pyplot as plt
import numpy as np

import oneflow as flow
from flowvision import datasets, transforms

from model import TeacherNet, StudentNet


def softmax_t(x, t):
    # Temperature-scaled softmax: dividing by t > 1 flattens the distribution.
    x_exp = np.exp(x / t)
    return x_exp / np.sum(x_exp)


if __name__ == "__main__":

    parser = argparse.ArgumentParser("flags for knowledge distillation inference")
    parser.add_argument(
        "--model_load_dir", type=str, default="./output/model_save/teacher"
    )
    parser.add_argument(
        "--model_type",
        type=str,
        default="teacher",
        choices=["teacher", "student_kd", "student"],
    )
    parser.add_argument("--picture_index", type=int, default=0)
    parser.add_argument("--temperature", type=float, default=10.0)
    parser.add_argument(
        "--image_save_name",
        type=str,
        default="./output/images/infer.jpg",
        required=False,
        help="images save name",
    )

    args = parser.parse_args()
    args.device = flow.device("cuda" if flow.cuda.is_available() else "cpu")

    start_t = time.perf_counter()
    print("***** Model Init *****")
    if args.model_type == "teacher":
        model = TeacherNet()
    else:  # student_kd, student
        model = StudentNet()
    model.load_state_dict(flow.load(args.model_load_dir))
    end_t = time.perf_counter()
    print(f"***** Model Init Finish, time elapsed {end_t - start_t:.6f} s *****")

    model = model.to(args.device)
    model.eval()

    # Dataset: the MNIST test split with standard normalization.
    dataset = datasets.MNIST(
        "./",
        train=False,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    )
    # Restrict the loader to the single requested picture.
    subset_indices = [args.picture_index]
    subset = flow.utils.data.Subset(dataset, subset_indices)
    data_loader = flow.utils.data.DataLoader(subset, batch_size=1, shuffle=False)
    with flow.no_grad():
        data, target = next(iter(data_loader))
        data, target = data.to(args.device), target.to(args.device)
        output = model(data)

    test_x = data.cpu().numpy()
    y_out = output.cpu().numpy()
    y_out = y_out[0]  # logits of the single sample
    print("Output (NO softmax):", y_out)
    print("the number is", flow.argmax(output).cpu().numpy())

    # Top: input image; middle: softmax at T=1; bottom: softened softmax.
    plt.subplot(3, 1, 1)
    plt.imshow(test_x[0, 0])

    plt.subplot(3, 1, 2)
    plt.bar(list(range(10)), softmax_t(y_out, 1), width=0.3)

    plt.subplot(3, 1, 3)
    plt.bar(list(range(10)), softmax_t(y_out, args.temperature), width=0.3)

    # Create the output directory if it does not exist.
    directory = os.path.dirname(os.path.abspath(args.image_save_name))
    os.makedirs(directory, exist_ok=True)

    plt.savefig(args.image_save_name)
    print("picture saved.")
6 changes: 6 additions & 0 deletions KnowledgeDistillation/KnowledgeDistillation/infer.sh
@@ -0,0 +1,6 @@
python infer.py \
--model_type="student_kd" \
--model_load_dir="./output/model_save/student_kd" \
--image_save_name="./output/images/infer_result.jpg" \
--picture_index=9999 \
--temperature=10
43 changes: 43 additions & 0 deletions KnowledgeDistillation/KnowledgeDistillation/model.py
@@ -0,0 +1,43 @@
import oneflow as flow
import oneflow.nn as nn
import oneflow.nn.functional as F


class TeacherNet(nn.Module):
    """The large CNN teacher (~1.2M parameters)."""

    def __init__(self):
        super(TeacherNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.3)
        self.dropout2 = nn.Dropout(0.5)
        # After two 3x3 convs and a 2x2 max-pool, a 1x28x28 input
        # becomes 64x12x12 = 9216 features.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = flow.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        output = self.fc2(x)
        return output


class StudentNet(nn.Module):
    """The small fully connected student (~0.1M parameters)."""

    def __init__(self):
        super(StudentNet, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = flow.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # Note that ReLU is also applied to the final layer,
        # so the student's output logits are non-negative.
        output = F.relu(self.fc3(x))
        return output
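As a quick sanity check of the two architectures (a sketch assuming MNIST-shaped input; not part of the repository's scripts):

```python
import oneflow as flow

from model import TeacherNet, StudentNet

teacher, student = TeacherNet(), StudentNet()
n_teacher = sum(p.numel() for p in teacher.parameters())
n_student = sum(p.numel() for p in student.parameters())
# The student has roughly a tenth of the teacher's parameters.
print(n_teacher, n_student)

x = flow.randn(1, 1, 28, 28)  # a dummy MNIST-shaped batch
print(teacher(x).shape, student(x).shape)  # both (1, 10)
```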
3 changes: 3 additions & 0 deletions KnowledgeDistillation/KnowledgeDistillation/requirements.txt
@@ -0,0 +1,3 @@
flowvision==0.2.0
matplotlib==3.4.3
numpy==1.21.2