-
Notifications
You must be signed in to change notification settings - Fork 48
/
plot_reuters.py
32 lines (24 loc) · 1.18 KB
/
plot_reuters.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
'''
Created on Dec, 2016
@author: hugo
'''
from __future__ import absolute_import
import argparse
from autoencoder.testing.visualize import reuters_visualize_pca_2d, reuters_visualize_tsne
from autoencoder.utils.io_utils import load_json
def build_parser():
    """Build the command-line parser for this plotting script.

    Returns:
        An ``argparse.ArgumentParser`` that takes two positional input paths
        (doc codes, doc labels), a positional plot command (``pca`` or
        ``tsne``, matched case-insensitively) and an optional
        ``-o/--output`` path that defaults to ``out.png``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('doc_codes_file', type=str, help='path to the input doc codes file')
    parser.add_argument('doc_labels_file', type=str, help='path to the input doc labels file')
    # Lower-case the value *before* argparse checks it against ``choices``,
    # so 'PCA' / 'TSNE' are accepted as well as the lowercase spellings.
    parser.add_argument('cmd', type=str.lower, choices=['pca', 'tsne'], help='plot cmd')
    parser.add_argument('-o', '--output', type=str, default='out.png', help='path to the output file')
    return parser


def main(argv=None):
    """Visualize document codes with PCA or t-SNE from the command line.

    Both input files are read with ``load_json``. The selected plotting
    function receives the loaded codes, the loaded labels, a mapping of four
    Reuters category codes to display names, and the output path.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. The default
            ``None`` uses the real command line; pass a list for scripted use.
    """
    args = build_parser().parse_args(argv)

    # Category code -> human-readable name, handed to the plotting function
    # (presumably the set of classes to draw; confirm in autoencoder.testing.visualize).
    classes_to_visual = {'ECAT': 'ECONOMICS', 'MCAT': 'MARKETS',
                         'CCAT': 'CORPORATE/INDUSTRIAL', 'GCAT': 'GOVERNMENT/SOCIAL'}

    # ``args.cmd`` is guaranteed by ``choices`` to be one of these keys.
    plotters = {'pca': reuters_visualize_pca_2d, 'tsne': reuters_visualize_tsne}
    plot = plotters[args.cmd]

    doc_codes = load_json(args.doc_codes_file)
    doc_labels = load_json(args.doc_labels_file)
    plot(doc_codes, doc_labels, classes_to_visual, args.output)
# Run the CLI only when executed as a script, not when imported.
if __name__ == '__main__':
    main()