forked from daveshap/PlainTextWikipedia
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dewiki_functions.py
95 lines (78 loc) · 3.33 KB
/
dewiki_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import html
import json
import os
import re
from collections import Counter
from multiprocessing import Pool, Semaphore

import psutil
import wikitextparser as wtp
from html2text import html2text as htt
from tqdm import tqdm
# Title prefixes of every non-article namespace. analyze_chunk skips any page whose
# title starts with one of these followed by ":". Only membership tests are used;
# the Counter counts themselves carry no meaning.
NAMESPACE_DICT = Counter([
    "Talk", "Media", "User", "Wikipedia", "File", "MediaWiki", "MediaWiki talk",
    "Template", "Help", "Category", "Category talk", "User talk", "Help talk",
    "Special", "Portal", "Portal talk", "TimedText", "TimedText talk", "Module",
    "Module talk", "Draft",
    # Talk namespaces that were missing. Without them, pages such as
    # "Template talk:X" were kept as if they were main-namespace articles.
    "Wikipedia talk", "File talk", "Template talk", "Draft talk",
])
# Caps how many pages may be queued in the worker pool at once:
# process_file_text acquires one permit per page and save_article releases it.
SEMAPHORE = Semaphore(1000)
def dewiki(text):
    """Turn raw wikitext into a single line of plain text.

    Wiki markup is removed with wikitextparser's ``plain_text()``. The result then goes
    through html2text to strip leftover HTML. Finally, newlines and every other run of
    whitespace collapse to a single space.
    """
    plain = wtp.parse(text).plain_text()  # drop wiki markup
    plain = htt(plain)  # drop leftover HTML
    flattened = plain.replace('\n', ' ')
    return re.sub(r'\s+', ' ', flattened)
def analyze_chunk(text):
    """Parse one ``<page>`` XML fragment from a MediaWiki dump into a plain-text record.

    :param text: the raw XML lines found between ``<page>`` and ``</page>``.
    :return: ``{'title': ..., 'text': ..., 'id': ...}`` with whitespace-stripped string
        values. Returns ``None`` when the page is a redirect, a disambiguation page, or
        outside the main namespace. Also returns ``None`` when the fragment cannot be
        parsed; the error is printed in that case.
    """
    try:
        if '<redirect title="' in text:  # redirect stub, not the main article
            return None
        # Titles in the dump are XML-escaped (e.g. "AT&amp;T"). html.unescape decodes the
        # entities without the markdown escaping and line wrapping html2text applies.
        title = html.unescape(text.split('<title>')[1].split('</title>')[0])
        # Identify disambiguation pages by their title. Searching the whole page text
        # would also discard ordinary articles whose hatnotes merely link to one.
        if '(disambiguation)' in title:
            return None
        # We are only interested in articles from the main namespace
        if title.split(":", 1)[0] in NAMESPACE_DICT:
            return None
        # NOTE(review): takes the first <id> in the fragment, presumably the page id
        # (which precedes the revision/contributor ids in the export schema); confirm.
        serial = text.split('<id>')[1].split('</id>')[0]
        # Separate the attributes of the opening <text ...> tag from what follows it.
        tag_attrs, _, after_tag = text.split('<text', 1)[1].partition('>')
        if tag_attrs.rstrip().endswith('/'):
            # A self-closing <text ... /> has no content. Without this check the XML
            # that follows it would be read as the article body.
            content = ''
        else:
            content = dewiki(after_tag.split('</text', 1)[0])
        return {'title': title.strip(), 'text': content.strip(), 'id': serial.strip()}
    except Exception as oops:
        # Best effort: one malformed page must not abort processing of the whole dump.
        print(oops)
        return None
def save_article(article, savedir):
    """Parse one page fragment and, if it yields an article, write it as a JSON file.

    Always releases one permit on the module-level ``SEMAPHORE``, even when parsing
    or writing fails. The permit was acquired by the producer before it dispatched
    this task. A failed page therefore cannot permanently shrink the number of pages
    allowed in flight.

    :param article: raw XML between ``<page>`` and ``</page>``.
    :param savedir: output directory. The document is written to
        ``os.path.join(savedir, get_valid_filename("<title>_<id>.json"))``.
    :raises OSError: if the output file cannot be written (for example, the sanitised
        name is too long for the filesystem). The semaphore is still released.
    """
    try:
        doc = analyze_chunk(article)
        if doc:
            filename = get_valid_filename(f"{doc['title']}_{doc['id']}.json")
            # os.path.join works whether or not savedir ends with a path separator.
            with open(os.path.join(savedir, filename), 'w', encoding='utf-8') as outfile:
                json.dump(doc, outfile, sort_keys=True, indent=1, ensure_ascii=False)
    finally:
        SEMAPHORE.release()
def file_linecount(filename: str) -> int:
    """Return how many lines *filename* has, counting a final unterminated line too.

    The file is read in binary mode, so no decoding is done and any encoding works.
    """
    total = 0
    with open(filename, 'rb') as stream:
        for _line in stream:
            total += 1
    return total
def handle_sub_proc_exception(e: Exception):
    """Pool ``error_callback``: print the worker's exception message and carry on."""
    message = str(e)
    print(message)
def get_valid_filename(name):
    """Sanitise *name* into a safe file name.

    Surrounding whitespace is trimmed and spaces become underscores. Every character
    other than ``-``, ``.`` and Unicode word characters is then dropped.

    :raises ValueError: if nothing usable remains (``""``, ``"."`` or ``".."``).
    """
    underscored = str(name).strip().replace(" ", "_")
    candidate = re.sub(r"(?u)[^-\w.]", "", underscored)
    if candidate not in ("", ".", ".."):
        return candidate
    raise ValueError("Could not derive file name from '%s'" % name)
def _init_pool_worker(semaphore):
    """Pool initializer: make each worker release the parent's throttling semaphore.

    Under the "spawn" start method a worker re-imports this module and would otherwise
    get its own fresh ``SEMAPHORE``. Its releases would then never reach the parent,
    which blocks forever once its permits run out.
    """
    global SEMAPHORE
    SEMAPHORE = semaphore


def process_file_text(filename, savedir):
    """Stream a MediaWiki XML dump and save each page as JSON using a process pool.

    Lines between ``<page>`` and ``</page>`` are collected and handed to
    ``save_article`` in a worker. The producer blocks on ``SEMAPHORE`` to bound the
    number of pages in flight. Returns only after every submitted page has been
    processed.

    :param filename: path to the uncompressed XML dump (read as UTF-8).
    :param savedir: output location, passed through unchanged to ``save_article``.
    """
    page_start_regex = re.compile(r"<page>")
    page_end_regex = re.compile(r"</page>")
    print(f"Counting total number of lines in {filename}...")
    total_lines = file_linecount(filename)
    article_lines = []
    with (open(filename, 'r', encoding='utf-8') as infile,
          Pool(processes=psutil.cpu_count(logical=False),
               initializer=_init_pool_worker,
               initargs=(SEMAPHORE,)) as process_pool,
          tqdm(total=total_lines) as process_bar):
        for line in infile:
            if page_start_regex.search(line):
                article_lines = []
            elif page_end_regex.search(line):  # end of article
                # One permit per queued page; save_article releases it when done.
                SEMAPHORE.acquire()
                process_pool.apply_async(save_article,
                                         args=(''.join(article_lines), savedir),
                                         error_callback=handle_sub_proc_exception)
            else:
                article_lines.append(line)
            process_bar.update(1)
        # Pool.__exit__ calls terminate(), which would discard pages still queued or in
        # progress. close() + join() waits until every submitted page has been handled.
        process_pool.close()
        process_pool.join()