-
Notifications
You must be signed in to change notification settings - Fork 1
/
data_split.py
35 lines (32 loc) · 1.3 KB
/
data_split.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import csv
from collections import Counter

import pandas as pd
# Load the labelled dataset (semicolon-separated) and count how many rows
# belong to each topic, so every topic can later be split with the same
# train/val/test proportions.
df = pd.read_csv("./dataset/labelled_newscatcher_dataset.csv", encoding='utf-8', sep=';')
labels = set(df['topic'])  # distinct topic names (not referenced elsewhere in this script)
contents = df['title']     # the 'title' column is what gets written out as 'content'
# Rows per topic, keyed in first-appearance order -- the same order the old
# hand-rolled try/except loop produced, so the printed dict is unchanged.
count = dict(Counter(df['topic']))
# Running per-topic counter consumed while assigning rows to splits; every
# topic starts at 0 (initialised once instead of being reset on every row).
cal = dict.fromkeys(count, 0)
print(count)
# Per-topic (stratified) split: within each topic, rows whose running index is
# below 70% of that topic's count go to train, below 85% go to val, and the
# rest go to test. Rows are not shuffled, so assignment follows CSV row order.
TRAIN_CUTOFF = 0.7   # cumulative fraction of each topic assigned to train
VAL_CUTOFF = 0.85    # cumulative fraction assigned to train + val
train, val, test = [], [], []
# Labels must come from the same 'topic' column that `count`/`cal` were built
# from; reading a different column (previously 'category') raises KeyError.
# zip() pairs each label with its text without relying on the Series index.
for label, content in zip(df['topic'], contents):
    seen = cal[label]  # rows of this topic already assigned to a split
    if seen < count[label] * TRAIN_CUTOFF:
        bucket = train
    elif seen < count[label] * VAL_CUTOFF:
        bucket = val
    else:
        bucket = test
    bucket.append({'label': label, 'content': content})
    cal[label] = seen + 1
# Write each split as a headerless, semicolon-delimited CSV with the columns
# (label, content). Mode 'w' (not 'a') so re-running the script replaces the
# split files instead of appending duplicate rows to them.
for split_name, rows in (('train', train), ('val', val), ('test', test)):
    with open(f'./dataset/{split_name}.csv', 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, ['label', 'content'], delimiter=';')
        writer.writerows(rows)  # writerows writes all rows in a single call