# index_recommender.py
import pandas as pd
import re
from sql_metadata import Parser
from collections import defaultdict
from itertools import combinations
import psycopg2 as pg
from pglast import parser
from random import shuffle, randint, choice
import os.path
import pickle
import time

verbose = True


def vprint(*args):
    if verbose:
        print(*args)


CONNECTION_STRING = "dbname='project1db' user='project1user' password='project1pass' host='localhost'"
MAX_WIDTH = 4


"""
Drops all the non-constrained indexes
'borrowed' from https://stackoverflow.com/a/48862822
"""
def drop_all_indexes(conn):
    vprint("--- Dropping all existing indexes ---")
    # skip indexes that back primary key or unique constraints
    get_drop_index_queries_query = """
        SELECT format('drop index %I.%I;', s.nspname, i.relname) as drop_statement
        FROM pg_index idx
        JOIN pg_class i on i.oid = idx.indexrelid
        JOIN pg_class t on t.oid = idx.indrelid
        JOIN pg_namespace s on i.relnamespace = s.oid
        WHERE s.nspname in ('public')
        AND not idx.indisprimary AND not idx.indisunique;"""
    with conn.cursor() as cur:
        cur.execute(get_drop_index_queries_query)
        drop_queries = cur.fetchall()
        for query in drop_queries:
            cur.execute(query[0])
        # fetch the remaining drop statements (should now be empty)
        cur.execute(get_drop_index_queries_query)
        drop_queries = cur.fetchall()
    vprint("--- Indexes dropped ---")


"""
Fetches the set of all user defined tables and their columns from the current database.
"""
def get_table_column_map(conn):
    QUERY = "select table_name, column_name from information_schema.columns where table_schema='public';"
    table_columns_map = defaultdict(set)
    with conn.cursor() as cursor:
        cursor.execute(QUERY)
        db_columns = cursor.fetchall()
        for table, column in db_columns:
            table_columns_map[table].add(column)
    return table_columns_map


def filter_interesting_queries(queries):
    res = [re.sub(r'^statement: ', '', q)
           for q in queries if q.startswith('statement')]
    res = [q for q in res if 'pg_' not in q and not q.startswith('SHOW ALL') and not q.startswith(
        'COMMIT') and not q.startswith('SET') and not q.startswith('BEGIN')]
    res = [q for q in res if 'WHERE' in q or 'ORDER' in q or 'JOIN' in q
           or 'join' in q or 'where' in q or 'order' in q]
    return res


"""
Clusters the given queries by mapping each query to its query template
which does not have the specific params.
"""
def cluster_queries(queries):
    cluster_frequencies = defaultdict(int)
    clusters = {}
    gqueries = set()
    for query in queries:
        try:
            generalized = parser.fingerprint(query)
            cluster_frequencies[generalized] += 1
            clusters[generalized] = query
            gqueries.add(generalized)
        except Exception:
            # skip queries that cannot be fingerprinted
            pass
    return gqueries, cluster_frequencies, clusters


"""
Extracts groups of columns for each cluster representative query
based on whether they appear together in a clause.
"""
def get_relevant_columns(clusters, cluster_counts, table_column_map, max_width):
    table_column_combos = defaultdict(set)
    for qgroup, query in clusters.items():
        try:
            pq = Parser(query)
            columns = pq.columns_dict
            pq_tables = pq.tables
            groups = ['where', 'order_by', 'join']
            for group in groups:
                if group in columns:
                    group_table_columns = defaultdict(set)
                    for col in columns[group]:
                        if '.' not in col:
                            # qualify an unqualified column with the first table that contains it
                            for table in pq_tables:
                                if table in table_column_map and col in table_column_map[table]:
                                    col = table + '.' + col
                                    break
                        tab, c = col.split('.')
                        group_table_columns[tab].add(col)
                    for tab, cols in group_table_columns.items():
                        for width in range(1, min(len(cols) + 1, max_width + 1)):
                            for col_combo in combinations(cols, width):
                                table_column_combos[tab].add(col_combo)
        except Exception:
            # skip queries that cannot be parsed or columns that stay unqualified
            pass
    return table_column_combos


def get_potential_indexes(table_column_combos):
    potential_indexes = []
    for v in table_column_combos.values():
        potential_indexes.extend(v)
    return potential_indexes


def get_combinations_list(potential_indexes):
    combs = []
    for i in range(1, len(potential_indexes) + 1):
        combs.extend(combinations(potential_indexes, i))
    return combs


def get_random_indexes(potential_indexes):
    num_indexes = randint(1, len(potential_indexes))
    indexes_set = set()
    while len(indexes_set) < num_indexes:
        indexes_set.add(choice(potential_indexes))
    return indexes_set


"""
Processes the given workload logs and extracts columns relevant to indexing from them.
It also clusters the queries by mapping each query to its query template
which does not have the specific params.
"""
def get_columns_from_logs(logs_path, max_width=MAX_WIDTH):
    QCOL = 13  # column of the Postgres csv log that contains the statement text
    df = pd.read_csv(logs_path, header=None)
    queries = filter_interesting_queries(df[QCOL].tolist())
    with pg.connect(CONNECTION_STRING) as conn:
        table_column_map = get_table_column_map(conn)
    gqueries, cluster_frequencies, clusters = cluster_queries(queries)
    table_column_combos = get_relevant_columns(
        clusters, cluster_frequencies, table_column_map, max_width=max_width)
    return table_column_combos, cluster_frequencies, clusters


def generate_index_creation_queries(columns):
    query_template = 'CREATE INDEX ON {} ({})'
    queries = []
    for column_group in columns:
        table_name = column_group[0].split('.')[0]
        column_names = [column.split('.')[1] for column in column_group]
        queries.append(query_template.format(table_name, ', '.join(column_names)))
    return queries


def create_hypothetical_indexes(index_queries, conn):
    hypo_template = "SELECT * FROM hypopg_create_index('{}');"
    with conn.cursor() as cur:
        for index_creation_query in index_queries:
            cur.execute(hypo_template.format(index_creation_query))
            res = cur.fetchall()


"""
Scales the given cost estimation based on the frequency of each query.
"""
def get_scaled_loss(cluster_frequencies, costs):
    cost = 0.
    for cluster, count in cluster_frequencies.items():
        cost += count * costs[cluster]
    return cost


"""
Drops all the hypothetical indexes in the db.
"""
def remove_hypo_indexes(conn):
    reset_indexes_q = 'SELECT * FROM hypopg_reset();'
    with conn.cursor() as cur:
        cur.execute(reset_indexes_q)
        res = cur.fetchall()


"""
Enables hypopg in the given database
"""
def enable_hypopg(conn):
    with conn.cursor() as cur:
        cur.execute('CREATE EXTENSION IF NOT EXISTS hypopg;')


"""
Retrieves the query plan costs for a given set of queries
@param query_clusters A dictionary of the form query_fingerprint: query
"""
def get_query_costs(query_clusters, conn):
    costs = dict()
    with conn.cursor() as cur:
        for cluster, query in query_clusters.items():
            cur.execute('EXPLAIN (FORMAT JSON) ' + query)
            explain_res = cur.fetchall()
            costs[cluster] = explain_res[0][0][0]['Plan']['Total Cost']
    return costs


"""
Generates a file called actions.sql with create index statements based on its evaluation of the given workload logs.
@param log_file_path Filepath of the workload logs (csv)
@param timeout Timeout in minutes of the form Xm
@param max_iterations Specifies the number of times the function should try a new configuration
"""
def find_best_index(log_file_path, timeout, max_iterations=99999):
    TIMEOUT_BUFFER = 20
    start_time = time.time()
    timeout = int(timeout.replace('m', '')) * 60
    with open('config.json', 'w') as f:
        f.write('{"VACUUM": false}')
    with pg.connect(CONNECTION_STRING) as conn:
        state_file = log_file_path + '.statefile'
        vprint(f"--- State being maintained in {state_file} ---")
        if os.path.isfile(state_file):
            # resume from a previous run
            with open(state_file, 'rb') as f:
                saved_objects = pickle.load(f)
            best_config = saved_objects['best_config']
            min_cost = saved_objects['min_cost']
            baseline_cost = saved_objects['baseline_cost']
            table_column_combos = saved_objects['table_column_combos']
            cluster_frequencies = saved_objects['cluster_frequencies']
            clusters = saved_objects['clusters']
        else:
            table_column_combos, cluster_frequencies, clusters = get_columns_from_logs(
                log_file_path)
            # calculate the costs with no indexes
            baseline_costs = get_query_costs(clusters, conn)
            baseline_cost = get_scaled_loss(
                cluster_frequencies, baseline_costs)
            min_cost = baseline_cost
            best_config = []
        # drop all existing indexes
        drop_all_indexes(conn)
        # ensure that hypopg is enabled
        enable_hypopg(conn)
        potential_indexes = get_potential_indexes(table_column_combos)
        for i in range(max_iterations):
            # generate a random combination of indexes
            cmb = get_random_indexes(potential_indexes)
            index_q = generate_index_creation_queries(cmb)
            create_hypothetical_indexes(index_q, conn)
            costs = get_query_costs(clusters, conn)
            # weight the total cost by the cluster frequencies
            cost = get_scaled_loss(cluster_frequencies, costs)
            remove_hypo_indexes(conn)
            # prefer cheaper configurations; on ties, prefer fewer indexes
            if cost < min_cost or (cost == min_cost and len(cmb) < len(best_config)):
                min_cost = cost
                best_config = cmb
            time_elapsed = time.time() - start_time
            # periodically checkpoint the best configuration found so far
            if i % 1000 == 0 or (timeout - time_elapsed) <= TIMEOUT_BUFFER:
                index_creation_queries = generate_index_creation_queries(
                    best_config)
                with open('actions.sql', 'w') as f:
                    for query in index_creation_queries:
                        f.write("{};\n".format(query))
                with open(state_file, 'wb') as f:
                    pickle.dump({
                        'best_config': best_config,
                        'min_cost': min_cost,
                        'baseline_cost': baseline_cost,
                        'table_column_combos': table_column_combos,
                        'cluster_frequencies': cluster_frequencies,
                        'clusters': clusters
                    }, f)
            if (timeout - time_elapsed) <= TIMEOUT_BUFFER:
                vprint(
                    f'--- {time_elapsed}s have elapsed, within {TIMEOUT_BUFFER}s of the timeout {timeout}s. Exiting ---')
                return
        index_creation_queries = generate_index_creation_queries(best_config)
        vprint('--- Best Indexes ---')
        vprint(index_creation_queries)
        vprint('--- Best Indexes END ---')
        with open('actions.sql', 'w') as f:
            for query in index_creation_queries:
                f.write("{};\n".format(query))
        with open(state_file, 'wb') as f:
            pickle.dump({
                'best_config': best_config,
                'min_cost': min_cost,
                'baseline_cost': baseline_cost,
                'table_column_combos': table_column_combos,
                'cluster_frequencies': cluster_frequencies,
                'clusters': clusters
            }, f)
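

# The file as shown has no command-line entry point; the block below is a minimal,
# hypothetical sketch of how find_best_index might be invoked. The default log path
# 'workload.csv' and the '5m' budget are assumptions, not values from the original code.
if __name__ == '__main__':
    import sys
    # usage: python index_recommender.py <workload_log.csv> <timeout, e.g. 5m>
    log_path = sys.argv[1] if len(sys.argv) > 1 else 'workload.csv'
    budget = sys.argv[2] if len(sys.argv) > 2 else '5m'
    find_best_index(log_path, budget)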