-
Notifications
You must be signed in to change notification settings - Fork 1
/
crawl_pubs.py
114 lines (94 loc) · 3.28 KB
/
crawl_pubs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import argparse
import csv
import os
import re
from typing import Dict, List, Optional
import matplotlib.pyplot as plt
import requests
# Directory where generated artifacts (e.g. the plot PDF) are written.
BUILD_DIR = "build"
# Google Scholar search endpoint.
BASE_URL = "https://scholar.google.com/scholar"
# Category name -> list of search terms OR-ed together into one Scholar query.
SEARCH_TERMS = {
    "offline-rl": [
        "offline reinforcement learning",
        "batch reinforcement learning",
        "offline RL",
        "batch RL",
    ],
    "rl": [
        "reinforcement learning",
    ],
}
# Extracts the comma-grouped result count from the Scholar results page,
# e.g. "About 12,300 results". Raw string per regex convention.
PATTERN = re.compile(r"About ([0-9,]*) results")
def get_query_string(search_terms: List[str]) -> str:
    """Build a Scholar query that ORs together the quoted search terms."""
    quoted_terms = [f'"{term}"' for term in search_terms]
    return " OR ".join(quoted_terms)
def get_num_articles(year: int, search_terms: List[str]) -> int:
    """Return the number of Google Scholar results for *search_terms* in *year*.

    Args:
        year: Publication year; used as both the lower and upper year bound.
        search_terms: Terms OR-ed together via get_query_string().

    Raises:
        RuntimeError: if the HTTP status is not 200 or the result count
            cannot be parsed from the response page.
    """
    response = requests.get(
        BASE_URL,
        params={
            "hl": "en",
            "as_ylo": year,  # restrict results to exactly this year
            "as_yhi": year,
            "q": get_query_string(search_terms),
        },
        timeout=30,  # never hang forever on an unresponsive server
    )
    # Explicit raises instead of assert: asserts are stripped under `python -O`.
    if response.status_code != 200:
        raise RuntimeError(f"Error in request: status {response.status_code}")
    match = PATTERN.search(response.text)
    if match is None:
        raise RuntimeError(
            f"Error processing Google Scholar query for url {response.url}."
        )
    num_articles_str = match.group(1)
    # Strip thousands separators, e.g. "12,300" -> 12300.
    return int(num_articles_str.replace(",", ""))
def plot_num_articles(
    years: List[int], num_articles_per_year: Dict[str, List[int]]
) -> None:
    """Plot per-year article counts on twin y-axes and save to BUILD_DIR.

    NOTE: only the first two categories are drawn — zip truncates to the
    two available axes (one per y-scale, since counts differ by orders of
    magnitude between categories).
    """
    fig, ax1 = plt.subplots(figsize=(10, 10))
    ax2 = ax1.twinx()  # independent right-hand y-axis for the second category
    for (category, num_articles), ax, color in zip(
        num_articles_per_year.items(), (ax1, ax2), ("blue", "red")
    ):
        ax.plot(years, num_articles, color=color, label=category)
    ax1.legend()
    ax2.legend()
    # Create the output directory first; savefig does not create it and
    # would fail with FileNotFoundError on a fresh checkout.
    os.makedirs(BUILD_DIR, exist_ok=True)
    fig.savefig(os.path.join(BUILD_DIR, "num-articles-per-year.pdf"))
def save_num_articles(
    save_path: str, years: List[int], num_articles_per_year: Dict[str, List[int]]
) -> None:
    """Write per-category article counts to a CSV file at *save_path*.

    Columns are "year" followed by one column per category, one row per year.
    """
    parent_dir = os.path.dirname(save_path)
    # dirname is "" for a bare filename; os.makedirs("") would raise.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    # newline="" is required by the csv module to avoid extra blank rows
    # on Windows (csv handles its own line endings).
    with open(save_path, "w", newline="", encoding="utf8") as csv_file:
        field_names = ["year"] + list(num_articles_per_year.keys())
        writer = csv.DictWriter(csv_file, fieldnames=field_names)
        writer.writeheader()
        for idx, year in enumerate(years):
            entry = {"year": year}
            for category, num_articles in num_articles_per_year.items():
                entry[category] = num_articles[idx]
            writer.writerow(entry)
def main(save_path: Optional[str], render: bool) -> None:
    """Crawl per-year publication counts, then optionally plot and/or save them."""
    years = list(range(2011, 2022))  # 2011..2021 inclusive
    num_articles_per_year: Dict[str, List[int]] = {
        category: [get_num_articles(year, terms) for year in years]
        for category, terms in SEARCH_TERMS.items()
    }
    if render:
        plot_num_articles(years, num_articles_per_year)
    if save_path is not None:
        save_num_articles(save_path, years, num_articles_per_year)
if __name__ == "__main__":
    # Command-line entry point: parses flags and forwards them to main().
    parser = argparse.ArgumentParser(
        description="This script crawls Google Scholar for information on the "
        "number of publications in RL in the past decade."
    )
    parser.add_argument(
        "--save_path",
        help="Path to output csv file with crawled data.",
    )
    parser.add_argument(
        "--render",
        action="store_true",
        # Fixed typo: "uisng" -> "using".
        help="Whether to render the plot using matplotlib.",
    )
    args = parser.parse_args()
    # Argument names match main()'s parameters, so unpack the namespace directly.
    main(**vars(args))