-
Notifications
You must be signed in to change notification settings - Fork 0
/
split_train_val.py
50 lines (34 loc) · 1.36 KB
/
split_train_val.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import os
import numpy as np
import shutil
# Split every class folder under ``root_total`` into a training part and a
# validation part, copying (not moving) the files into ``root_train`` and
# ``root_val`` respectively. The source folder is left untouched.

# Fixed seed so the shuffles below draw the same permutations on every run.
np.random.seed(2016)

root_train = 'D:/programs/Kaggle_NCFM-master/data/train_split'
root_val = 'D:/programs/Kaggle_NCFM-master/data/val_split'
root_total = r'D:\programs\Kaggle_NCFM-master\data\train'

# Names of the class sub-directories read from ``root_total``; the same
# sub-directory names are used under ``root_train`` and ``root_val``.
FishNames = ['calling', 'normal', 'smoking', 'smoking_calling']

# Running totals over all classes, reported at the end of the script.
nbr_train_samples = 0
nbr_val_samples = 0

# Training proportion
split_proportion = 0.8


def _copy_images(image_names, source_dir, target_dir):
    """Copy the files named in *image_names* from *source_dir* to *target_dir*.

    *target_dir* is created first, together with any missing parent
    directories; an already existing directory is reused. The directory is
    created even when *image_names* is empty.

    Returns the number of files copied.
    """
    os.makedirs(target_dir, exist_ok=True)
    for image_name in image_names:
        shutil.copy(os.path.join(source_dir, image_name),
                    os.path.join(target_dir, image_name))
    return len(image_names)


for fish in FishNames:
    class_dir = os.path.join(root_total, fish)

    # os.listdir() returns entries in arbitrary order. Sorting first makes
    # the seeded shuffle produce the same split on every run and platform,
    # so a re-run overwrites the same files instead of leaving stale copies
    # of one image in both the train and the val folder.
    total_images = sorted(os.listdir(class_dir))

    # The first ``split_proportion`` share (rounded down) goes to training,
    # the remainder to validation.
    nbr_train = int(len(total_images) * split_proportion)
    np.random.shuffle(total_images)
    train_images = total_images[:nbr_train]
    val_images = total_images[nbr_train:]

    nbr_train_samples += _copy_images(
        train_images, class_dir, os.path.join(root_train, fish))
    nbr_val_samples += _copy_images(
        val_images, class_dir, os.path.join(root_val, fish))

print('Finish splitting train and val images!')
print('# training samples: {}, # val samples: {}'.format(nbr_train_samples, nbr_val_samples))