Implement query linter suggestions #1306

Merged 5 commits on Aug 14, 2023
3 changes: 3 additions & 0 deletions docs_website/docs/changelog/breaking_change.md
@@ -7,6 +7,9 @@ slug: /changelog

Here is the list of breaking changes you should be aware of when updating Querybook:

## v3.27.0
Updated the properties of the `QueryValidationResult` object: `line` and `ch` have been replaced with `start_line` and `start_ch`, respectively.

## v3.22.0

Updated the charset of the `data_element` table to `utf8mb4`. If your MySQL database's default charset is not utf8, please run the SQL below to update it if needed.
2 changes: 1 addition & 1 deletion package.json
@@ -1,6 +1,6 @@
 {
     "name": "querybook",
-    "version": "3.26.3",
+    "version": "3.27.0",
     "description": "A Big Data Webapp",
     "private": true,
     "scripts": {
lib/query_analysis/validation/base_query_validator.py
@@ -21,25 +21,34 @@ class QueryValidationSeverity(Enum):
 class QueryValidationResult(object):
     def __init__(
         self,
-        line: int,  # 0 based
-        ch: int,  # location of the starting token
+        start_line: int,  # 0 based
+        start_ch: int,  # location of the starting token
         severity: QueryValidationSeverity,
         message: str,
         obj_type: QueryValidationResultObjectType = QueryValidationResultObjectType.LINT,
+        end_line: int = None,  # 0 based
+        end_ch: int = None,  # location of the ending token
+        suggestion: str = None,
     ):
         self.type = obj_type
-        self.line = line
-        self.ch = ch
+        self.start_line = start_line
+        self.start_ch = start_ch
+        self.end_line = end_line
+        self.end_ch = end_ch
         self.severity = severity
         self.message = message
+        self.suggestion = suggestion

     def to_dict(self):
         return {
             "type": self.type.value,
-            "line": self.line,
-            "ch": self.ch,
+            "start_line": self.start_line,
+            "start_ch": self.start_ch,
+            "end_line": self.end_line,
+            "end_ch": self.end_ch,
             "severity": self.severity.value,
             "message": self.message,
+            "suggestion": self.suggestion,
         }


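For downstream consumers, the renamed and added fields look like this in practice (a minimal sketch; the positions and message are made-up values for illustration):

```python
from lib.query_analysis.validation.base_query_validator import (
    QueryValidationResult,
    QueryValidationSeverity,
)

# Hypothetical values: a warning spanning characters 7-11 of line 0.
result = QueryValidationResult(
    start_line=0,
    start_ch=7,
    severity=QueryValidationSeverity.WARNING,
    message="Using UNION ALL instead of UNION will execute faster",
    end_line=0,
    end_ch=11,
    suggestion="UNION ALL",
)

# to_dict() now emits start_line/start_ch (plus end_line/end_ch)
# in place of the old line/ch keys.
assert result.to_dict()["start_line"] == 0
assert result.to_dict()["end_ch"] == 11
```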
lib/query_analysis/validation/validators/base_sqlglot_validator.py (new file)
@@ -0,0 +1,62 @@
from abc import ABCMeta, abstractmethod
from typing import List, Tuple
from sqlglot import Tokenizer
from sqlglot.tokens import Token

from lib.query_analysis.validation.base_query_validator import (
    QueryValidationResult,
    QueryValidationResultObjectType,
    QueryValidationSeverity,
)


class BaseSQLGlotValidator(metaclass=ABCMeta):
    @property
    @abstractmethod
    def message(self) -> str:
        raise NotImplementedError()

    @property
    @abstractmethod
    def severity(self) -> QueryValidationSeverity:
        raise NotImplementedError()

    @property
    @abstractmethod
    def tokenizer(self) -> Tokenizer:
        raise NotImplementedError()

    def _tokenize_query(self, query: str) -> List[Token]:
        return self.tokenizer.tokenize(query)

    def _get_query_coordinate_by_index(self, query: str, index: int) -> Tuple[int, int]:
        rows = query[: index + 1].splitlines(keepends=False)
        return len(rows) - 1, len(rows[-1]) - 1

    def _get_query_validation_result(
        self,
        query: str,
        start_index: int,
        end_index: int,
        suggestion: str = None,
        validation_result_object_type=QueryValidationResultObjectType.LINT,
    ):
        start_line, start_ch = self._get_query_coordinate_by_index(query, start_index)
        end_line, end_ch = self._get_query_coordinate_by_index(query, end_index)

        return QueryValidationResult(
            start_line,
            start_ch,
            self.severity,
            self.message,
            validation_result_object_type,
            end_line=end_line,
            end_ch=end_ch,
            suggestion=suggestion,
        )

    @abstractmethod
    def get_query_validation_results(
        self, query: str, raw_tokens: List[Token] = None
    ) -> List[QueryValidationResult]:
        raise NotImplementedError()
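To illustrate the contract, here is a hypothetical subclass (not part of this PR) that flags `SELECT *` using the generic sqlglot tokenizer; only `message`, `severity`, `tokenizer`, and `get_query_validation_results` need to be supplied. This assumes the sqlglot version pinned by this PR, where tokens carry absolute `start`/`end` character indices:

```python
from typing import List

from sqlglot import Tokenizer, TokenType
from sqlglot.tokens import Token


class SelectStarValidator(BaseSQLGlotValidator):
    """Hypothetical example validator, for illustration only."""

    @property
    def message(self) -> str:
        return "Selecting explicit columns instead of * avoids reading unused data"

    @property
    def severity(self) -> QueryValidationSeverity:
        return QueryValidationSeverity.WARNING

    @property
    def tokenizer(self) -> Tokenizer:
        return Tokenizer()

    def get_query_validation_results(
        self, query: str, raw_tokens: List[Token] = None
    ) -> List[QueryValidationResult]:
        if raw_tokens is None:
            raw_tokens = self._tokenize_query(query)
        results = []
        for i, token in enumerate(raw_tokens[:-1]):
            # Flag the "*" that immediately follows a SELECT keyword
            if (
                token.token_type == TokenType.SELECT
                and raw_tokens[i + 1].token_type == TokenType.STAR
            ):
                results.append(
                    self._get_query_validation_result(
                        query, raw_tokens[i + 1].start, raw_tokens[i + 1].end
                    )
                )
        return results
```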
lib/query_analysis/validation/validators/presto_optimizing_validator.py (new file)
@@ -0,0 +1,223 @@
from typing import List
from sqlglot import TokenType, Tokenizer
from sqlglot.dialects import Trino
from sqlglot.tokens import Token

from lib.query_analysis.validation.base_query_validator import (
    BaseQueryValidator,
    QueryValidationResult,
    QueryValidationSeverity,
)
from lib.query_analysis.validation.validators.presto_explain_validator import (
    PrestoExplainValidator,
)
from lib.query_analysis.validation.validators.base_sqlglot_validator import (
    BaseSQLGlotValidator,
)


class BasePrestoSQLGlotValidator(BaseSQLGlotValidator):
    @property
    def tokenizer(self) -> Tokenizer:
        return Trino.Tokenizer()


class UnionAllValidator(BasePrestoSQLGlotValidator):
    @property
    def message(self):
        return "Using UNION ALL instead of UNION will execute faster"

    @property
    def severity(self) -> QueryValidationSeverity:
        return QueryValidationSeverity.WARNING

    def get_query_validation_results(
        self, query: str, raw_tokens: List[Token] = None
    ) -> List[QueryValidationResult]:
        if raw_tokens is None:
            raw_tokens = self._tokenize_query(query)
        validation_errors = []
        for i, token in enumerate(raw_tokens):
            if token.token_type == TokenType.UNION:
                if (
                    i < len(raw_tokens) - 1
                    and raw_tokens[i + 1].token_type != TokenType.ALL
                ):
                    validation_errors.append(
                        self._get_query_validation_result(
                            query, token.start, token.end, "UNION ALL"
                        )
                    )
        return validation_errors
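A quick sketch of the expected behavior (assuming sqlglot is installed; expected output noted in comments):

```python
validator = UnionAllValidator()
results = validator.get_query_validation_results(
    "SELECT id FROM t1 UNION SELECT id FROM t2"
)
assert len(results) == 1
print(results[0].suggestion)                        # UNION ALL
print(results[0].start_line, results[0].start_ch)   # 0 18 (position of UNION)
```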


class ApproxDistinctValidator(BasePrestoSQLGlotValidator):
    @property
    def message(self):
        return (
            "Using APPROX_DISTINCT(x) instead of COUNT(DISTINCT x) will execute faster"
        )

    @property
    def severity(self) -> QueryValidationSeverity:
        return QueryValidationSeverity.WARNING

    def get_query_validation_results(
        self, query: str, raw_tokens: List[Token] = None
    ) -> List[QueryValidationResult]:
        if raw_tokens is None:
            raw_tokens = self._tokenize_query(query)

        validation_errors = []
        for i, token in enumerate(raw_tokens):
            if (
                i < len(raw_tokens) - 2
                and token.token_type == TokenType.VAR
                and token.text.lower().strip() == "count"
                and raw_tokens[i + 1].token_type == TokenType.L_PAREN
                and raw_tokens[i + 2].token_type == TokenType.DISTINCT
            ):
                validation_errors.append(
                    self._get_query_validation_result(
                        query,
                        token.start,
                        raw_tokens[i + 2].end,
                        "APPROX_DISTINCT(",
                    )
                )
        return validation_errors
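Likewise for the `COUNT(DISTINCT ...)` pattern: the flagged span runs from `COUNT` through `DISTINCT`, and the suggested replacement text is the function prefix (a sketch under the same sqlglot assumption):

```python
validator = ApproxDistinctValidator()
results = validator.get_query_validation_results(
    "SELECT COUNT(DISTINCT user_id) FROM events"
)
assert len(results) == 1
print(results[0].suggestion)  # APPROX_DISTINCT(
```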


class RegexpLikeValidator(BasePrestoSQLGlotValidator):
    @property
    def message(self):
        return "Combining multiple LIKEs into one REGEXP_LIKE will execute faster"

    @property
    def severity(self) -> QueryValidationSeverity:
        return QueryValidationSeverity.WARNING

    def _get_regexp_like_suggestion(self, column_name: str, like_strings: List[str]):
        sanitized_like_strings = [
            like_string.strip("\"'") for like_string in like_strings
        ]
        return f"REGEXP_LIKE({column_name}, '{'|'.join(sanitized_like_strings)}')"

    def get_query_validation_results(
        self, query: str, raw_tokens: List[Token] = None
    ) -> List[QueryValidationResult]:
        if raw_tokens is None:
            raw_tokens = self._tokenize_query(query)

        validation_errors = []

        start_column_token = None
        like_strings = []
        token_idx = 0
        while token_idx < len(raw_tokens) - 2:
            token_1 = raw_tokens[token_idx]
            token_2 = raw_tokens[token_idx + 1]
            token_3 = raw_tokens[token_idx + 2]

            # Check if the next set of three tokens matches a "like" phrase
            # (i.e. <column> LIKE <string>)
            if (
                token_1.token_type == TokenType.VAR
                and (
                    start_column_token is None
                    or token_1.text == start_column_token.text
                )
                and token_2.token_type == TokenType.LIKE
                and token_3.token_type == TokenType.STRING
            ):
                if start_column_token is None:
                    start_column_token = raw_tokens[token_idx]
                like_strings.append(token_3.text)
                token_idx += 3
                if (
                    token_idx == len(raw_tokens)
                    or raw_tokens[token_idx].token_type != TokenType.OR
                ):  # No "OR" token follows the phrase, so no additional phrases can be combined
                    # Check if there are multiple phrases that can be combined
                    if len(like_strings) > 1:
                        validation_errors.append(
                            self._get_query_validation_result(
                                query,
                                start_column_token.start,
                                raw_tokens[token_idx - 1].end,
                                suggestion=self._get_regexp_like_suggestion(
                                    start_column_token.text, like_strings
                                ),
                            )
                        )
                    start_column_token = None
                    like_strings = []

            # If the next tokens do not match the "like" phrase pattern, check whether
            # a suggestion can be made from previously matched phrases
            elif start_column_token is not None:
                if (
                    len(like_strings) > 1
                ):  # Check if a validation suggestion can be created
                    validation_errors.append(
                        self._get_query_validation_result(
                            query,
                            start_column_token.start,
                            raw_tokens[token_idx - 1].end,
                            suggestion=self._get_regexp_like_suggestion(
                                start_column_token.text, like_strings
                            ),
                        )
                    )
                start_column_token = None
                like_strings = []
            token_idx += 1

        return validation_errors
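And for chained `LIKE`s on the same column joined by `OR`, the suggestion collapses the patterns into a single `REGEXP_LIKE` alternation (sketch; expected output in the comment):

```python
validator = RegexpLikeValidator()
results = validator.get_query_validation_results(
    "SELECT * FROM t WHERE col LIKE '%foo%' OR col LIKE '%bar%'"
)
assert len(results) == 1
print(results[0].suggestion)  # REGEXP_LIKE(col, '%foo%|%bar%')
```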


class PrestoOptimizingValidator(BaseQueryValidator):
    def languages(self):
        return ["presto", "trino"]

    def _get_explain_validator(self):
        return PrestoExplainValidator("")

    def _get_sqlglot_validators(self) -> List[BaseSQLGlotValidator]:
        return [
            UnionAllValidator(),
            ApproxDistinctValidator(),
            RegexpLikeValidator(),
        ]

    def _get_sql_glot_validation_results(
        self, query: str
    ) -> List[QueryValidationResult]:
        validation_suggestions = []

        query_raw_tokens = None
        for validator in self._get_sqlglot_validators():
            if query_raw_tokens is None:
                query_raw_tokens = validator._tokenize_query(query)
            validation_suggestions.extend(
                validator.get_query_validation_results(
                    query, raw_tokens=query_raw_tokens
                )
            )

        return validation_suggestions

    def _get_presto_explain_validation_results(
        self, query: str, uid: int, engine_id: int
    ) -> List[QueryValidationResult]:
        return self._get_explain_validator().validate(query, uid, engine_id)

    def validate(
        self,
        query: str,
        uid: int,
        engine_id: int,
    ) -> List[QueryValidationResult]:
        validation_results = [
            *self._get_presto_explain_validation_results(query, uid, engine_id),
            *self._get_sql_glot_validation_results(query),
        ]
        return validation_results
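Putting it together: the EXPLAIN-based check needs a live engine, so this is a sketch only. `uid=1` and `engine_id=1` are placeholder values for a configured Presto/Trino engine, and the empty-string constructor argument mirrors how `PrestoExplainValidator` is instantiated above:

```python
# Sketch only: PrestoExplainValidator runs EXPLAIN against the engine,
# so uid/engine_id must reference a real user and query engine.
validator = PrestoOptimizingValidator("")
validation_results = validator.validate(
    "SELECT COUNT(DISTINCT user_id) FROM t1 UNION SELECT user_id FROM t2",
    uid=1,        # placeholder user id
    engine_id=1,  # placeholder query engine id
)
for result in validation_results:
    print(result.severity, result.message, result.suggestion)
```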
@@ -41,11 +41,11 @@ def test_simple(self):
             1, statement_start_locations, error_line=0, error_ch=2, error_msg=""
         )

-        self.assertEqual(validation_result.line, 0)
-        self.assertEqual(validation_result.ch, 12)
+        self.assertEqual(validation_result.start_line, 0)
+        self.assertEqual(validation_result.start_ch, 12)

         validation_result = self._validator._map_statement_error_to_query(
             2, statement_start_locations, error_line=0, error_ch=5, error_msg=""
         )
-        self.assertEqual(validation_result.line, 1)
-        self.assertEqual(validation_result.ch, 5)
+        self.assertEqual(validation_result.start_line, 1)
+        self.assertEqual(validation_result.start_ch, 5)