diff --git a/experimental/paper_experiments/query_all_repos.py b/experimental/paper_experiments/query_all_repos.py
new file mode 100644
index 000000000..0e29d7bc5
--- /dev/null
+++ b/experimental/paper_experiments/query_all_repos.py
@@ -0,0 +1,167 @@
+from typing import List
+import requests
+import argparse
+
+
+# GitHub API endpoint for repository search
+GITHUB_REPO_URL = "https://api.github.com/search/repositories"
+# GitHub API endpoint for code search
+GITHUB_CODE_SEARCH_URL = "https://api.github.com/search/code"
+
+
+# parses the command line arguments
+def parse_arguments():
+    parser = argparse.ArgumentParser(
+        description="Searches for a string in the top-starred repositories on GitHub (Java and Scala by default)"
+    )
+    parser.add_argument(
+        "search_string",
+        help="String to search for in the repositories",
+    )
+    parser.add_argument("token", help="GitHub API token")
+    parser.add_argument(
+        "--languages",
+        nargs="+",
+        default=["Java", "Scala"],
+        help="List of languages to search for",
+    )
+    parser.add_argument(
+        "--per_page",
+        type=int,
+        default=100,
+        help="Number of results per page (max 100)",
+    )
+    parser.add_argument(
+        "--page",
+        type=int,
+        default=1,
+        help="Page number",
+    )
+    # argument to specify the path to write the results
+    parser.add_argument(
+        "--output",
+        default="results.csv",
+        help="Path to write the results",
+    )
+    return parser.parse_args()
+
+
+class MatchedRepo:
+    owner: str
+    name: str
+    stars: int
+    search_string: str
+    token: str
+    languages: list[str]
+    number_of_matches: int
+    files: list[str]
+
+    def __init__(self, name, owner, stars, search_string, token, languages):
+        self.name = name
+        self.owner = owner
+        self.stars = stars
+        self.search_string = search_string
+        self.token = token
+        self.languages = languages
+        self.number_of_matches = 0
+        self.files = []
+        self.lookup()
+
+    def lookup(self):
+        headers = {"Authorization": f"Bearer {self.token}"}
+        code_params = {
+            "q": f"{self.search_string} repo:{self.owner}/{self.name}",
+        }
+
+        code_response = requests.get(
+            GITHUB_CODE_SEARCH_URL,
+            params=code_params,
+            headers=headers,
+        )
+        if code_response.status_code == 200:
+            code_data = code_response.json()
+            self.number_of_matches = code_data["total_count"]
+            self.files = [
+                item["path"]
+                for item in code_data["items"]
+                if any(l.lower() in item["path"] for l in self.languages)
+            ]
+
+    def to_csv(self):
+        return f"{self.name}, {self.owner}, {self.stars}, {self.number_of_matches}"
+
+
+def get_repo_info(
+    response_json, search_string, token, languages
+) -> List[MatchedRepo]:
+    repositories = []
+    for item in response_json["items"]:
+        name = item["name"]
+        owner = item["owner"]["login"]
+        stars = item["stargazers_count"]
+
+        repositories.append(
+            MatchedRepo(name, owner, stars, search_string, token, languages)
+        )
+    return repositories
+
+
+def search(token, search_string, languages, output_csv):
+    # Set up the headers with your token
+    headers = {"Authorization": f"Bearer {token}"}
+    # Parameters for the repository search
+    _lang_clause = " ".join([f"language:{l}" for l in languages])
+    repo_params = {
+        "q": f"stars:>100 {_lang_clause}",
+        "sort": "stars",
+        "order": "desc",
+        "per_page": 100,  # Number of results per page (max 100)
+    }
+    counter = 0
+    try:
+        # Fetch the top starred repositories, page by page
+        while True:
+            counter += 1
+            if counter > 70:  # cap on the number of result pages fetched
+                break
+            repositories = []
+            repo_params["page"] = counter
+            print(counter)
+            response = requests.get(
+                GITHUB_REPO_URL, 
params=repo_params, headers=headers + ) + + if response.status_code == 200: + data = response.json() + repositories = get_repo_info( + response_json=data, + search_string=search_string, + token=token, + languages=languages, + ) + for r in repositories: + if r.number_of_matches > 0: + entry = r.to_csv() + with open(output_csv, "a+") as f: + f.write(entry + "\n") + print(entry) + if "next" not in response.links: + break + + else: + print( + f"Repository request failed with status code {response.status_code}" + ) + print(response.text) + break + + except requests.exceptions.RequestException as e: + print(f"An error occurred: {e}") + + +args = parse_arguments() +search( + token=args.token, + search_string=args.search_string, + languages=args.languages, + output_csv=args.output, +) diff --git a/experimental/paper_experiments/spark2to3.py b/experimental/paper_experiments/spark2to3.py new file mode 100644 index 000000000..85362dab3 --- /dev/null +++ b/experimental/paper_experiments/spark2to3.py @@ -0,0 +1,290 @@ +from typing import Any, Dict, Optional, Tuple +from tree_sitter import Node, Tree +from utils import ( + JAVA, + SCALA_SOURCE_CODE, + parse_code, + traverse_tree, + rewrite, + JAVA_SOURCE_CODE, + SCALA, +) + + +relevant_builder_method_names_mapping = { + "setAppName": "appName", + "setMaster": "master", + "set": "config", + "setAll": "all", + "setIfMissing": "ifMissing", + "setJars": "jars", + "setExecutorEnv": "executorEnv", + "setSparkHome": "sparkHome", +} + + +def get_initializer_named( + tree: Tree, name: str, language: str +) -> Optional[Node]: + for node in traverse_tree(tree): + if language == JAVA: + if node.type == "object_creation_expression": + oce_type = node.child_by_field_name("type") + if oce_type and oce_type.text.decode() == name: + return node + if language == SCALA: + if node.type == "instance_expression": + if any(c.text.decode() == name for c in node.children): + return node + + +def get_enclosing_variable_declaration_name_type( + node: Node, language: str +) -> Tuple[Node | None, str | None]: + name, nd = None, None + if language == JAVA: + if node.parent and node.parent.type == "variable_declarator": + n = node.parent.child_by_field_name("name") + if n: + name = n.text.decode() + if ( + node.parent.parent + and node.parent.parent.type == "local_variable_declaration" + ): + t = node.parent.parent.child_by_field_name("type") + if t: + nd = node.parent.parent + if language == SCALA: + if ( + node.parent + and node.parent.type == "val_definition" + ): + n = node.parent.child_by_field_name("pattern") + if n: + name = n.text.decode() + nd = node.parent + return nd, name + + +def all_enclosing_method_invocations(node: Node, language: str) -> list[Node]: + if language == JAVA: + if node.parent and node.parent.type == "method_invocation": + return [node.parent] + all_enclosing_method_invocations( + node.parent, language + ) + else: + return [] + else: + if node.parent and node.parent.parent and node.parent.parent.type == "call_expression": + return [node.parent.parent] + all_enclosing_method_invocations( + node.parent.parent, language + ) + else: + return [] + + +def build_spark_session_builder(builder_mappings: list[tuple[str, Node]]): + replacement_expr = 'new SparkSession.builder().config("spark.sql.legacy.allowUntypedScalaUDF", "true")' + for name, args in builder_mappings: + replacement_expr += f".{name}{args.text.decode()}" + return replacement_expr + + +def update_spark_conf_init( + tree: Tree, src_code: str, state: Dict[str, Any], language: str +) -> Tuple[Tree, 
str]:
+    spark_conf_init = get_initializer_named(tree, "SparkConf", language)
+    if not spark_conf_init:
+        print("No SparkConf initializer found")
+        return tree, src_code
+
+    encapsulating_method_invocations = all_enclosing_method_invocations(
+        spark_conf_init, language
+    )
+    builder_mappings = []
+    for n in encapsulating_method_invocations:
+        name = (
+            n.child_by_field_name("name")
+            if language == JAVA
+            else n.children[0].children[2]
+        )
+        if (
+            name
+            and name.text.decode()
+            in relevant_builder_method_names_mapping.keys()
+        ):
+            builder_mappings.append(
+                (
+                    relevant_builder_method_names_mapping[name.text.decode()],
+                    n.child_by_field_name("arguments"),
+                )
+            )
+
+    builder_mapping = build_spark_session_builder(builder_mappings)
+
+    outermost_node_builder_pattern = (
+        encapsulating_method_invocations[-1]
+        if encapsulating_method_invocations
+        else spark_conf_init
+    )
+
+    node, name = get_enclosing_variable_declaration_name_type(
+        outermost_node_builder_pattern, language
+    )
+
+    if not (node and name):
+        print("Not in a variable declaration")
+        return tree, src_code
+
+    declaration_replacement = get_declaration_replacement(
+        name, builder_mapping, language
+    )
+
+    state["spark_conf_name"] = name
+
+    return rewrite(node, src_code, declaration_replacement, language)
+
+
+def get_declaration_replacement(name, builder_mapping, language):
+    if language == JAVA:
+        return f"SparkSession {name} = {builder_mapping}.getOrCreate();"
+    else:
+        return f"val {name} = {builder_mapping}.getOrCreate()"
+
+
+def update_spark_context_init(
+    tree: Tree, source_code: str, state: Dict[str, Any], language: str
+):
+    if "spark_conf_name" not in state:
+        print("Needs the name of the variable holding the SparkConf")
+        return tree, source_code
+    spark_conf_name = state["spark_conf_name"]
+    init = get_initializer_named(tree, "JavaSparkContext", language)
+    if not init:
+        return tree, source_code
+
+    node, name = get_enclosing_variable_declaration_name_type(init, language)
+    if node:
+        return rewrite(
+            node,
+            source_code,
+            f"SparkContext {name} = {spark_conf_name}.sparkContext()"
+            + (";" if language == JAVA else ""),
+            language
+        )
+    else:
+        return rewrite(
+            init, source_code, f"{spark_conf_name}.sparkContext()", language
+        )
+
+
+def get_setter_call(variable_name: str, tree: Tree, language: str) -> Optional[Node]:
+    for node in traverse_tree(tree):
+        if language == JAVA:
+            if node.type == "method_invocation":
+                name = node.child_by_field_name("name")
+                r = node.child_by_field_name("object")
+                if name and r:
+                    name = name.text.decode()
+                    r = r.text.decode()
+                    if (
+                        r == variable_name
+                        and name in relevant_builder_method_names_mapping.keys()
+                    ):
+                        return node
+        if language == SCALA:
+            if node.type == "call_expression":
+                _fn = node.child_by_field_name("function")
+                if not _fn:
+                    continue
+                name = _fn.child_by_field_name("field")
+                r = _fn.child_by_field_name("value")
+                if name and r:
+                    name = name.text.decode()
+                    r = r.text.decode()
+                    if (
+                        r == variable_name
+                        and name in relevant_builder_method_names_mapping.keys()
+                    ):
+                        return node
+
+
+def update_spark_conf_setters(
+    tree: Tree, source_code: str, state: Dict[str, Any], language: str
+):
+    setter_call = get_setter_call(state["spark_conf_name"], tree, language)
+    if setter_call:
+        rcvr = state["spark_conf_name"]
+        invc = (
+            setter_call.child_by_field_name("name")
+            if language == JAVA
+            else setter_call.children[0].children[2]
+        )
+        args = setter_call.child_by_field_name("arguments")
+        if rcvr and invc and args:
+            new_fn = relevant_builder_method_names_mapping[invc.text.decode()]
+            replacement = 
f"{rcvr}.{new_fn}{args.text.decode()}" + return rewrite(setter_call, source_code, replacement, language) + return tree, source_code + + +def insert_import_statement( + tree: Tree, source_code: str, import_statement: str, language: str +): + for import_stmt in traverse_tree(tree): + if import_stmt.type == "import_declaration": + if import_stmt.text.decode() == f"import {import_statement}" + (";" if language == JAVA else ""): + return tree, source_code + + package_decl = [ + n + for n in traverse_tree(tree) + if n.type + == ("package_declaration" if language == JAVA else "package_clause") + ] + if not package_decl: + return tree, source_code + package_decl = package_decl[0] + if language == JAVA: + return rewrite( + package_decl, + source_code, + f"{package_decl.text.decode()}\nimport {import_statement};", + language + ) + return rewrite( + package_decl, + source_code, + f"{package_decl.text.decode()}\nimport {import_statement}", + language + ) + + +def run(language, source_code): + state = {} + no_change = False + while not no_change: + TREE: Tree = parse_code(language, source_code) + original_code = source_code + TREE, source_code = update_spark_conf_init( + TREE, source_code, state, language + ) + TREE, source_code = insert_import_statement( + TREE, source_code, "org.apache.spark.sql.SparkSession", language + ) + TREE, source_code = insert_import_statement( + TREE, source_code, "org.apache.spark.SparkContext", language + ) + TREE, source_code = update_spark_context_init( + TREE, source_code, state, language + ) + no_change = source_code == original_code + no_setter_found = False + while not no_setter_found: + b4_code = source_code + TREE, source_code = update_spark_conf_setters( + TREE, source_code, state, language + ) + no_setter_found = source_code == b4_code + return source_code + + +# run(JAVA, JAVA_SOURCE_CODE) +run(SCALA, SCALA_SOURCE_CODE) diff --git a/experimental/paper_experiments/utils.py b/experimental/paper_experiments/utils.py new file mode 100644 index 000000000..e5186bd1c --- /dev/null +++ b/experimental/paper_experiments/utils.py @@ -0,0 +1,133 @@ + +from tree_sitter import Language, Node, Parser, Tree +from tree_sitter_languages import get_parser + +JAVA = "java" +SCALA = "scala" + + +JAVA_SOURCE_CODE = """package com.piranha; + +public class Sample { + public static void main(String[] args) { + SparkConf conf = new SparkConf() + .setAppName("Sample App"); + + JavaSparkContext sc = new JavaSparkContext(conf); + + SparkConf conf1 = new SparkConf() + .setSparkHome(sparkHome) + .setExecutorEnv("spark.executor.extraClassPath", "test") + .setAppName(appName) + .setMaster(master) + .set("spark.driver.allowMultipleContexts", "true"); + + sc1 = new JavaSparkContext(conf1); + + + + var conf2 = new SparkConf(); + conf2.set("spark.driver.instances:", "100"); + conf2.setAppName(appName); + conf2.setSparkHome(sparkHome); + + sc2 = new JavaSparkContext(conf2); + + + } +} +""" + + + +SCALA_SOURCE_CODE = """package com.piranha + +import org.apache.spark.SparkConf +import org.apache.spark.sql.SparkSession + +class Sample { + def main(args: Array[String]): Unit = { + + val conf= new SparkConf() + .setAppName("Sample App") + + val sc = new SparkContext(conf) + + val conf1 = new SparkConf() + .setMaster(master) + .setAll(Seq(("k2", "v2"), ("k3", "v3"))) + .setAppName(appName) + .setSparkHome(sparkHome) + .setExecutorEnv("spark.executor.extraClassPath", "test") + .set("spark.driver.allowMultipleContexts", "true") + sc1 = new SparkContext(conf1) + + + val conf2 = new SparkConf() + 
.setMaster(master) + + conf2.setSparkHome(sparkHome) + + conf2.setExecutorEnv("spark.executor.extraClassPath", "test") + + + + } + +} +""" + + +Language.build_library( + # Store the library in the `build` directory + 'build/my-languages.so', + + # Include one or more languages + [ + '/Users/ketkara/repositories/open-source/tree-sitter-scala', + '/Users/ketkara/repositories/open-source/tree-sitter-java', + ] +) + +SCALA_LANGUAGE = Language('build/my-languages.so', 'scala') +JAVA_LANGUAGE = Language('build/my-languages.so', 'java') + +def parse_code(language: str, source_code: str) -> Tree: + "Helper function to parse into tree sitter nodes" + parser = Parser() + parser.set_language(JAVA_LANGUAGE if language == JAVA else SCALA_LANGUAGE) + + source_tree = parser.parse(bytes(source_code, "utf8")) + return source_tree + +def traverse_tree(tree: Tree): + cursor = tree.walk() + + reached_root = False + while reached_root == False: + yield cursor.node + + if cursor.goto_first_child(): + continue + + if cursor.goto_next_sibling(): + continue + + retracing = True + while retracing: + if not cursor.goto_parent(): + retracing = False + reached_root = True + + if cursor.goto_next_sibling(): + retracing = False + + +def rewrite(node: Node, source_code: str, replacement: str, language: str): + new_source_code = ( + source_code[: node.start_byte] + + replacement + + source_code[node.end_byte :] + ) + print(new_source_code) + return parse_code(language, new_source_code), new_source_code diff --git a/plugins/spark_upgrade/spark_config/java_rules.py b/plugins/spark_upgrade/spark_config/java_rules.py index 0d6b2559b..d3c7dd233 100644 --- a/plugins/spark_upgrade/spark_config/java_rules.py +++ b/plugins/spark_upgrade/spark_config/java_rules.py @@ -1,5 +1,21 @@ from polyglot_piranha import Filter, OutgoingEdges, Rule + +def insert_import(rule_name, fq_type: str) -> Rule: + return Rule( + name=rule_name, + query="(package_declaration) @pkg", + replace_node="pkg", + replace=f"@pkg \n import {fq_type};", + is_seed_rule=False, + filters={ + Filter( + enclosing_node="(program) @cu", + not_contains=[f"cs import {fq_type};"], + ) + }, + ) + update_enclosing_var_declaration_java = Rule( name="update_enclosing_var_declaration_java", query="cs :[type] :[conf_var] = :[rhs];", @@ -9,19 +25,8 @@ groups={"update_enclosing_var_declaration"}, ) - -insert_import_spark_session_java = Rule( - name="insert_import_spark_session_java", - query="(package_declaration) @pkg", - replace_node="pkg", - replace="@pkg \n import org.apache.spark.sql.SparkSession;", - is_seed_rule=False, - filters={ - Filter( - enclosing_node="(program) @cu", - not_contains=["cs import org.apache.spark.sql.SparkSession;"], - ) - }, +insert_import_spark_session_java = insert_import( + "insert_import_spark_session_java", "org.apache.spark.sql.SparkSession" ) update_spark_context_java = Rule( @@ -45,22 +50,10 @@ groups={"update_spark_context"}, ) - -insert_import_spark_context_java = Rule( - name="insert_import_spark_context_java", - query="(package_declaration) @pkg", - replace_node="pkg", - replace="@pkg \n import org.apache.spark.SparkContext;", - is_seed_rule=False, - filters={ - Filter( - enclosing_node="(program) @cu", - not_contains=["cs import org.apache.spark.SparkContext;"], - ) - }, +insert_import_spark_context_java = insert_import( + "insert_import_spark_context_java", "org.apache.spark.SparkContext" ) - RULES = [ update_enclosing_var_declaration_java, insert_import_spark_session_java, @@ -69,7 +62,6 @@ update_spark_context_var_decl_lhs_java, ] - 
EDGES = [ OutgoingEdges( "update_enclosing_var_declaration_java", diff --git a/plugins/spark_upgrade/spark_config/java_scala_rules.py b/plugins/spark_upgrade/spark_config/java_scala_rules.py index 6ad85a5d9..02b5a6838 100644 --- a/plugins/spark_upgrade/spark_config/java_scala_rules.py +++ b/plugins/spark_upgrade/spark_config/java_scala_rules.py @@ -1,21 +1,5 @@ -# Copyright (c) 2023 Uber Technologies, Inc. +from polyglot_piranha import Rule -#
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file -# except in compliance with the License. You may obtain a copy of the License at -#
http://www.apache.org/licenses/LICENSE-2.0 - -#
Unless required by applicable law or agreed to in writing, software distributed under the -# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing permissions and -# limitations under the License. - -from polyglot_piranha import ( - Rule, -) - - -# Rules for transforming builder patterns -# Rule to transform EntropyCalculator() arguments spark_conf_change_java_scala = Rule( name="spark_conf_change_java_scala", query="cs new SparkConf()", @@ -24,211 +8,62 @@ holes={"spark_conf"}, ) -app_name_change_java_scala = Rule( - name="app_name_change_java_scala", - query="cs :[r].setAppName(:[app_name])", - replace_node="*", - replace="@r.appName(@app_name)", - groups={"BuilderPattern"}, - is_seed_rule=False, -) - -master_name_change_java_scala = Rule( - name="master_name_change_java_scala", - query="cs :[r].setMaster(:[master])", - replace_node="*", - replace="@r.master(@master)", - groups={"BuilderPattern"}, - is_seed_rule=False, -) - -setter_name_change_java_scala = Rule( - name="setter_name_change_java_scala", - query="cs :[r].set(:[a1],:[a2])", - replace_node="*", - replace="@r.config(@a1, @a2)", - groups={"BuilderPattern"}, - is_seed_rule=False, -) - -set_all_change_java_scala = Rule( - name="set_all_change_java_scala", - query="cs :[r].setAll(:[a1])", - replace_node="*", - replace="@r.all(@a1)", - groups={"BuilderPattern"}, - is_seed_rule=False, -) - -set_if_missing_java_scala = Rule( - name="set_if_missing_java_scala", - query="cs :[r].setIfMissing(:[a1], :[a2])", - replace_node="*", - replace="@r.ifMissing(@a1)", - groups={"BuilderPattern"}, - is_seed_rule=False, -) - -set_jars_change_java_scala = Rule( - name="set_jars_change_java_scala", - query="cs :[r].setJars(:[a1])", - replace_node="*", - replace="@r.jars(@a1)", - groups={"BuilderPattern"}, - is_seed_rule=False, -) - -set_executor_env_change_1_java_scala = Rule( - name="set_executor_env_change_1_java_scala", - query="cs :[r].setExecutorEnv(:[a1])", - replace_node="*", - replace="@r.executorEnv(@a1)", - groups={"BuilderPattern"}, - is_seed_rule=False, -) - -set_executor_env_change_2_java_scala = Rule( - name="set_executor_env_change_2_java_scala", - query="cs :[r].setExecutorEnv(:[a1], :[a2])", - replace_node="*", - replace="@r.executorEnv(@a1, @a2)", - groups={"BuilderPattern"}, - is_seed_rule=False, -) - -set_spark_home_change_java_scala = Rule( - name="set_spark_home_change_java_scala", - query="cs :[r].setSparkHome(:[a1])", - replace_node="*", - replace="@r.sparkHome(@a1)", - groups={"BuilderPattern"}, - is_seed_rule=False, -) - -app_name_change_java_scala_stand_alone_call = Rule( - name="app_name_change_java_scala_stand_alone_call", - query="cs @conf_var.setAppName(:[app_name])", - replace_node="*", - replace="@conf_var.appName(@app_name)", - groups={"StandAloneCall"}, - holes= {"conf_var"}, - is_seed_rule=False, -) - -master_name_change_java_scala_stand_alone_call = Rule( - name="master_name_change_java_scala_stand_alone_call", - query="cs @conf_var.setMaster(:[master])", - replace_node="*", - replace="@conf_var.master(@master)", - groups={"StandAloneCall"}, - holes= {"conf_var"}, - is_seed_rule=False, -) - -setter_name_change_java_scala_stand_alone_call = Rule( - name="setter_name_change_java_scala_stand_alone_call", - query="cs @conf_var.set(:[a1],:[a2])", - replace_node="*", - replace="@conf_var.config(@a1, @a2)", - groups={"StandAloneCall"}, - holes= {"conf_var"}, - is_seed_rule=False, -) - 
-set_all_change_java_scala_stand_alone_call = Rule( - name="set_all_change_java_scala_stand_alone_call", - query="cs @conf_var.setAll(:[a1])", - replace_node="*", - replace="@conf_var.all(@a1)", - groups={"StandAloneCall"}, - holes= {"conf_var"}, - is_seed_rule=False, -) - -set_if_missing_java_scala_stand_alone_call = Rule( - name="set_if_missing_java_scala_stand_alone_call", - query="cs @conf_var.setIfMissing(:[a1], :[a2])", - replace_node="*", - replace="@conf_var.ifMissing(@a1)", - groups={"StandAloneCall"}, - holes= {"conf_var"}, - is_seed_rule=False, -) - -set_jars_change_java_scala_stand_alone_call = Rule( - name="set_jars_change_java_scala_stand_alone_call", - query="cs @conf_var.setJars(:[a1])", - replace_node="*", - replace="@conf_var.jars(@a1)", - groups={"StandAloneCall"}, - holes= {"conf_var"}, - is_seed_rule=False, -) - -set_executor_env_change_1_java_scala_stand_alone_call = Rule( - name="set_executor_env_change_1_java_scala_stand_alone_call", - query="cs @conf_var.setExecutorEnv(:[a1])", - replace_node="*", - replace="@conf_var.executorEnv(@a1)", - groups={"StandAloneCall"}, - holes= {"conf_var"}, - is_seed_rule=False, -) - -set_executor_env_change_2_java_scala_stand_alone_call = Rule( - name="set_executor_env_change_2_java_scala_stand_alone_call", - query="cs @conf_var.setExecutorEnv(:[a1], :[a2])", - replace_node="*", - replace="@conf_var.executorEnv(@a1, @a2)", - groups={"StandAloneCall"}, - holes= {"conf_var"}, - is_seed_rule=False, -) - -set_spark_home_change_java_scala_stand_alone_call = Rule( - name="set_spark_home_change_java_scala_stand_alone_call", - query="cs @conf_var.setSparkHome(:[a1])", - replace_node="*", - replace="@conf_var.sparkHome(@a1)", - groups={"StandAloneCall"}, - holes= {"conf_var"}, - is_seed_rule=False, -) - - - - -dummy = Rule(name="dummy", is_seed_rule=False) - -RULES = [ - # Transforms the initializer - spark_conf_change_java_scala, - - # Transforms the builder pattern - app_name_change_java_scala, - master_name_change_java_scala, - setter_name_change_java_scala, - set_all_change_java_scala, - set_if_missing_java_scala, - set_jars_change_java_scala, - set_executor_env_change_1_java_scala, - set_executor_env_change_2_java_scala, - set_spark_home_change_java_scala, - - # Transforms the stand alone calls - app_name_change_java_scala_stand_alone_call, - master_name_change_java_scala_stand_alone_call, - setter_name_change_java_scala_stand_alone_call, - set_all_change_java_scala_stand_alone_call, - set_if_missing_java_scala_stand_alone_call, - set_jars_change_java_scala_stand_alone_call, - set_executor_env_change_1_java_scala_stand_alone_call, - set_executor_env_change_2_java_scala_stand_alone_call, - set_spark_home_change_java_scala_stand_alone_call, - - - - dummy, -] - - +def get_setter_rules(name: str, query: str, replace: str) -> list[Rule]: + return [ + Rule( + name=name, + query=query.format(receiver=":[r]"), + replace_node="*", + replace=replace.format(receiver="@r"), + groups={"BuilderPattern"}, + is_seed_rule=False, + ), + Rule( + name=name + "_stand_alone_call", + query=query.format(receiver="@conf_var"), + replace_node="*", + replace=replace.format(receiver="@conf_var"), + holes={"conf_var"}, + groups={"StandAloneCall"}, + is_seed_rule=False, + ), + ] + + +RULES = [spark_conf_change_java_scala] + get_setter_rules( + "app_name_change_java_scala", + "cs {receiver}.setAppName(:[app_name])", + "{receiver}.appName(@app_name)", + ) + get_setter_rules( + "master_name_change_java_scala", + "cs {receiver}.setMaster(:[master])", + 
"{receiver}.master(@master)", + ) + get_setter_rules( + "setter_name_change_java_scala", + "cs {receiver}.set(:[a1],:[a2])", + "{receiver}.config(@a1, @a2)", + ) + get_setter_rules( + "set_all_change_java_scala", + "cs {receiver}.setAll(:[a1])", + "{receiver}.all(@a1)", + ) + get_setter_rules( + "set_if_missing_java_scala", + "cs {receiver}.setIfMissing(:[a1], :[a2])", + "{receiver}.ifMissing(@a1)", + ) + get_setter_rules( + "set_jars_change_java_scala", + "cs {receiver}.setJars(:[a1])", + "{receiver}.jars(@a1)", + ) + get_setter_rules( + "set_executor_env_2_change_java_scala", + "cs {receiver}.setExecutorEnv(:[a1], :[a2])", + "{receiver}.executorEnv(@a1, @a2)", + ) + get_setter_rules( + "set_executor_env_1_change_java_scala", + "cs {receiver}.setExecutorEnv(:[a1])", + "{receiver}.executorEnv(@a1)", + ) + get_setter_rules( + "set_spark_home_change_java_scala", + "cs {receiver}.setSparkHome(:[a1])", + "{receiver}.sparkHome(@a1)", + ) + [Rule(name="dummy", is_seed_rule=False)] diff --git a/plugins/spark_upgrade/spark_config/scala_rules.py b/plugins/spark_upgrade/spark_config/scala_rules.py index 1aac922e5..1fcfc6b90 100644 --- a/plugins/spark_upgrade/spark_config/scala_rules.py +++ b/plugins/spark_upgrade/spark_config/scala_rules.py @@ -1,7 +1,5 @@ - from polyglot_piranha import Rule - update_enclosing_var_declaration_scala = Rule( name="update_enclosing_var_declaration_scala", query="cs val :[conf_var] = :[rhs]", @@ -18,8 +16,7 @@ replace="@conf_var.sparkContext", holes={"conf_var"}, is_seed_rule=False, - groups={"update_spark_context"} + groups={"update_spark_context"}, ) - RULES = [update_enclosing_var_declaration_scala, update_spark_context_scala]