Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[FEATURE] Put every class in training set #21028

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions tools/im2rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,19 @@ def write_list(path_out, image_list):
line += '%s\n' % item[1]
fout.write(line)


def each_class_to_beginning(image_list: list):
"""Take off one photo of each class"""
images = {}
for elt in image_list:
cls = elt[-1]
if cls not in images:
images[cls] = elt
unique_classes = list(images.values())
for elt in unique_classes:
image_list.remove(elt)
return unique_classes + image_list

def make_list(args):
"""Generates .lst file.
Parameters
Expand All @@ -101,6 +114,7 @@ def make_list(args):
if args.shuffle is True:
random.seed(100)
random.shuffle(image_list)
image_list = each_class_to_beginning(image_list)
N = len(image_list)
chunk_size = (N + args.chunks - 1) // args.chunks
for i in range(args.chunks):
Expand All @@ -115,10 +129,10 @@ def make_list(args):
write_list(args.prefix + str_chunk + '.lst', chunk)
else:
if args.test_ratio:
write_list(args.prefix + str_chunk + '_test.lst', chunk[:sep_test])
write_list(args.prefix + str_chunk + '_test.lst', chunk[sep:sep+sep_test])
if args.train_ratio + args.test_ratio < 1.0:
write_list(args.prefix + str_chunk + '_val.lst', chunk[sep_test + sep:])
write_list(args.prefix + str_chunk + '_train.lst', chunk[sep_test:sep_test + sep])
write_list(args.prefix + str_chunk + '_val.lst', chunk[sep+sep_test:])
write_list(args.prefix + str_chunk + '_train.lst', chunk[:sep])

def read_list(path_in):
"""Reads the .lst file and generates corresponding iterator.
Expand Down