From 338286de9344dfa4e185966334eec0418136f282 Mon Sep 17 00:00:00 2001 From: Ameya Ketkar <94497232+ketkarameya@users.noreply.github.com> Date: Sun, 22 Oct 2023 19:51:08 -0700 Subject: [PATCH 01/12] Experiment tree-sitter imperative --- experimental/paper_experiments/spark2to3.py | 170 ++++++++++++++++++++ experimental/paper_experiments/utils.py | 78 +++++++++ 2 files changed, 248 insertions(+) create mode 100644 experimental/paper_experiments/spark2to3.py create mode 100644 experimental/paper_experiments/utils.py diff --git a/experimental/paper_experiments/spark2to3.py b/experimental/paper_experiments/spark2to3.py new file mode 100644 index 000000000..d97fee47d --- /dev/null +++ b/experimental/paper_experiments/spark2to3.py @@ -0,0 +1,170 @@ +from typing import Any, Dict, Optional, Tuple +from tree_sitter import Node, Tree +from utils import parse_code, traverse_tree, rewrite, SOURCE_CODE + + +relevant_builder_method_names_mapping = { + "setAppName": "appName", + "setMaster": "master", + "set": "config", + "setAll": "all", + "setIfMissing": "ifMissing", + "setJars": "jars", + "setExecutorEnv": "executorEnv", + "setSparkHome": "sparkHome", +} + + +def get_initializer_named(tree: Tree, name: str): + for node in traverse_tree(tree): + if node.type == "object_creation_expression": + oce_type = node.child_by_field_name("type") + if oce_type and oce_type.text.decode() == name: + return node + + +def get_enclosing_variable_declaration_name_type( + node: Node, +) -> Tuple[Node | None, str | None, str | None]: + name, typ, nd = None, None, None + if node.parent and node.parent.type == "variable_declarator": + n = node.parent.child_by_field_name("name") + if n: + name = n.text.decode() + if ( + node.parent.parent + and node.parent.parent.type == "local_variable_declaration" + ): + t = node.parent.parent.child_by_field_name("type") + if t: + typ = t.text.decode() + nd = node.parent.parent + return nd, name, typ + + +def all_enclosing_method_invocations(node: Node) -> list[Node]: + if node.parent and node.parent.type == "method_invocation": + return [node.parent] + all_enclosing_method_invocations(node.parent) + else: + return [] + + +def build_spark_session_builder(builder_mappings: list[tuple[str, Node]]): + replacement_expr = 'new SparkSession.builder().config("spark.sql.legacy.allowUntypedScalaUDF", "true")' + for name, args in builder_mappings: + replacement_expr += f".{name}{args.text.decode()}" + return replacement_expr + + +def update_spark_conf_init( + tree: Tree, src_code: str, state: Dict[str, Any] +) -> Tuple[Tree, str]: + spark_conf_init = get_initializer_named(tree, "SparkConf") + if not spark_conf_init: + print("No SparkConf initializer found") + return tree, src_code + + encapsulating_method_invocations = all_enclosing_method_invocations( + spark_conf_init + ) + builder_mappings = [] + for n in encapsulating_method_invocations: + name = n.child_by_field_name("name") + if ( + name + and name.text.decode() + in relevant_builder_method_names_mapping.keys() + ): + builder_mappings.append( + ( + relevant_builder_method_names_mapping[name.text.decode()], + n.child_by_field_name("arguments"), + ) + ) + + builder_mapping = build_spark_session_builder(builder_mappings) + + outermost_node_builder_pattern = ( + encapsulating_method_invocations[-1] + if encapsulating_method_invocations + else spark_conf_init + ) + + node, name, typ = get_enclosing_variable_declaration_name_type( + outermost_node_builder_pattern + ) + + if not (node and name and typ): + print("Not in a variable declaration") + 
return tree, src_code + + declaration_replacement = ( + f"SparkSession {name} = {builder_mapping}.getOrCreate();" + ) + + state["spark_conf_name"] = name + + return rewrite(node, src_code, declaration_replacement) + + +def update_spark_context_init( + tree: Tree, source_code: str, state: Dict[str, Any] +): + if "spark_conf_name" not in state: + print("Needs the name of the variable holding the SparkConf") + return tree, source_code + spark_conf_name = state["spark_conf_name"] + init = get_initializer_named(tree, "JavaSparkContext") + if not init: + return tree, source_code + + node, name, typ = get_enclosing_variable_declaration_name_type(init) + if node: + return rewrite( + node, + source_code, + f"SparkContext {name} = {spark_conf_name}.sparkContext()", + ) + else: + return rewrite(init, source_code, f"{spark_conf_name}.sparkContext()") + + +def get_setter_call(variable_name: str, tree: Tree) -> Optional[Node]: + for node in traverse_tree(tree): + if node.type == "method_invocation": + name = node.child_by_field_name("name") + r = node.child_by_field_name("object") + if name and r: + name = name.text.decode() + r = r.text.decode() + if r == variable_name and name in relevant_builder_method_names_mapping.keys(): + return node + + +def update_spark_conf_setters( + tree: Tree, source_code: str, state: Dict[str, Any] +): + setter_call = get_setter_call(state["spark_conf_name"], tree) + if setter_call: + rcvr = state["spark_conf_name"] + invc = setter_call.child_by_field_name("name") + args = setter_call.child_by_field_name("arguments") + if rcvr and invc and args: + new_fn = relevant_builder_method_names_mapping[invc.text.decode()] + replacement = f"{rcvr}.{new_fn}{args.text.decode()}" + return rewrite(setter_call, source_code, replacement) + return tree, source_code + +state = {} +no_change = False +while not no_change: + TREE: Tree = parse_code("java", SOURCE_CODE) + original_code = SOURCE_CODE + TREE, SOURCE_CODE = update_spark_conf_init(TREE, SOURCE_CODE, state) + TREE, SOURCE_CODE = update_spark_context_init(TREE, SOURCE_CODE, state) + no_change = SOURCE_CODE == original_code + no_setter_found = False + while not no_setter_found: + b4_code = SOURCE_CODE + TREE, SOURCE_CODE = update_spark_conf_setters(TREE, SOURCE_CODE, state) + no_setter_found = SOURCE_CODE == b4_code diff --git a/experimental/paper_experiments/utils.py b/experimental/paper_experiments/utils.py new file mode 100644 index 000000000..0f609388f --- /dev/null +++ b/experimental/paper_experiments/utils.py @@ -0,0 +1,78 @@ + +from tree_sitter import Node, Tree +from tree_sitter_languages import get_parser + + +SOURCE_CODE = """package com.piranha; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; + +public class Sample { + public static void main(String[] args) { + SparkConf conf = new SparkConf() + .setAppName("Sample App"); + + JavaSparkContext sc = new JavaSparkContext(conf); + + SparkConf conf1 = new SparkConf() + .setSparkHome(sparkHome) + .setExecutorEnv("spark.executor.extraClassPath", "test") + .setAppName(appName) + .setMaster(master) + .set("spark.driver.allowMultipleContexts", "true"); + + sc1 = new JavaSparkContext(conf1); + + + + var conf2 = new SparkConf(); + conf2.set("spark.driver.instances:", "100"); + conf2.setAppName(appName); + conf2.setSparkHome(sparkHome); + + sc2 = new JavaSparkContext(conf2); + + + } +} +""" + + +def parse_code(language: str, source_code: str) -> Tree: + "Helper function to parse into tree sitter nodes" + parser = get_parser(language) + 
source_tree = parser.parse(bytes(source_code, "utf8")) + return source_tree + +def traverse_tree(tree: Tree): + cursor = tree.walk() + + reached_root = False + while reached_root == False: + yield cursor.node + + if cursor.goto_first_child(): + continue + + if cursor.goto_next_sibling(): + continue + + retracing = True + while retracing: + if not cursor.goto_parent(): + retracing = False + reached_root = True + + if cursor.goto_next_sibling(): + retracing = False + + +def rewrite(node: Node, source_code: str, replacement: str): + new_source_code = ( + source_code[: node.start_byte] + + replacement + + source_code[node.end_byte :] + ) + print(new_source_code) + return parse_code("java", new_source_code), new_source_code From ad6e4c0a583ba61d41d85f3cc6f65bf9c7a59522 Mon Sep 17 00:00:00 2001 From: Ameya Ketkar <94497232+ketkarameya@users.noreply.github.com> Date: Mon, 23 Oct 2023 08:14:01 -0700 Subject: [PATCH 02/12] Experiment tree-sitter imperative --- experimental/paper_experiments/spark2to3.py | 33 ++++++-- .../spark_upgrade/spark_config/java_rules.py | 4 - .../spark_config/java_scala_rules.py | 84 ++++--------------- .../spark_upgrade/spark_config/scala_rules.py | 3 +- 4 files changed, 46 insertions(+), 78 deletions(-) diff --git a/experimental/paper_experiments/spark2to3.py b/experimental/paper_experiments/spark2to3.py index d97fee47d..48bb8999f 100644 --- a/experimental/paper_experiments/spark2to3.py +++ b/experimental/paper_experiments/spark2to3.py @@ -11,7 +11,7 @@ "setIfMissing": "ifMissing", "setJars": "jars", "setExecutorEnv": "executorEnv", - "setSparkHome": "sparkHome", + "setSparkHome": "sparkHome", } @@ -134,27 +134,50 @@ def get_setter_call(variable_name: str, tree: Tree) -> Optional[Node]: if node.type == "method_invocation": name = node.child_by_field_name("name") r = node.child_by_field_name("object") - if name and r: + if name and r: name = name.text.decode() r = r.text.decode() - if r == variable_name and name in relevant_builder_method_names_mapping.keys(): + if ( + r == variable_name + and name in relevant_builder_method_names_mapping.keys() + ): return node def update_spark_conf_setters( tree: Tree, source_code: str, state: Dict[str, Any] ): - setter_call = get_setter_call(state["spark_conf_name"], tree) + setter_call = get_setter_call(state["spark_conf_name"], tree) if setter_call: rcvr = state["spark_conf_name"] invc = setter_call.child_by_field_name("name") args = setter_call.child_by_field_name("arguments") if rcvr and invc and args: new_fn = relevant_builder_method_names_mapping[invc.text.decode()] - replacement = f"{rcvr}.{new_fn}{args.text.decode()}" + replacement = f"{rcvr}.{new_fn}{args.text.decode()}" return rewrite(setter_call, source_code, replacement) return tree, source_code + +def insert_import_statement( + tree: Tree, source_code: str, import_statement: str +): + for import_stmt in traverse_tree(tree): + if import_stmt.type == "import_declaration": + if import_stmt.text.decode() == import_statement: + return tree, source_code + + package_decl = tree.root_node.child_by_field_name("package_declaration") + if not package_decl: + return tree, source_code + + return rewrite( + package_decl, + source_code, + f"{package_decl.text.decode()}\n{import_statement}", + ) + + state = {} no_change = False while not no_change: diff --git a/plugins/spark_upgrade/spark_config/java_rules.py b/plugins/spark_upgrade/spark_config/java_rules.py index 0d6b2559b..c636ee848 100644 --- a/plugins/spark_upgrade/spark_config/java_rules.py +++ 
b/plugins/spark_upgrade/spark_config/java_rules.py @@ -9,7 +9,6 @@ groups={"update_enclosing_var_declaration"}, ) - insert_import_spark_session_java = Rule( name="insert_import_spark_session_java", query="(package_declaration) @pkg", @@ -45,7 +44,6 @@ groups={"update_spark_context"}, ) - insert_import_spark_context_java = Rule( name="insert_import_spark_context_java", query="(package_declaration) @pkg", @@ -60,7 +58,6 @@ }, ) - RULES = [ update_enclosing_var_declaration_java, insert_import_spark_session_java, @@ -69,7 +66,6 @@ update_spark_context_var_decl_lhs_java, ] - EDGES = [ OutgoingEdges( "update_enclosing_var_declaration_java", diff --git a/plugins/spark_upgrade/spark_config/java_scala_rules.py b/plugins/spark_upgrade/spark_config/java_scala_rules.py index 6ad85a5d9..b516ebbbd 100644 --- a/plugins/spark_upgrade/spark_config/java_scala_rules.py +++ b/plugins/spark_upgrade/spark_config/java_scala_rules.py @@ -1,21 +1,5 @@ -# Copyright (c) 2023 Uber Technologies, Inc. +from polyglot_piranha import Rule -#
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file -# except in compliance with the License. You may obtain a copy of the License at -#
http://www.apache.org/licenses/LICENSE-2.0 - -#
Unless required by applicable law or agreed to in writing, software distributed under the -# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing permissions and -# limitations under the License. - -from polyglot_piranha import ( - Rule, -) - - -# Rules for transforming builder patterns -# Rule to transform EntropyCalculator() arguments spark_conf_change_java_scala = Rule( name="spark_conf_change_java_scala", query="cs new SparkConf()", @@ -78,20 +62,11 @@ is_seed_rule=False, ) -set_executor_env_change_1_java_scala = Rule( +set_executor_env_change_java_scala = Rule( name="set_executor_env_change_1_java_scala", - query="cs :[r].setExecutorEnv(:[a1])", - replace_node="*", - replace="@r.executorEnv(@a1)", - groups={"BuilderPattern"}, - is_seed_rule=False, -) - -set_executor_env_change_2_java_scala = Rule( - name="set_executor_env_change_2_java_scala", - query="cs :[r].setExecutorEnv(:[a1], :[a2])", + query="cs :[r].setExecutorEnv:[args]", replace_node="*", - replace="@r.executorEnv(@a1, @a2)", + replace="@r.executorEnv@args", groups={"BuilderPattern"}, is_seed_rule=False, ) @@ -111,7 +86,7 @@ replace_node="*", replace="@conf_var.appName(@app_name)", groups={"StandAloneCall"}, - holes= {"conf_var"}, + holes={"conf_var"}, is_seed_rule=False, ) @@ -121,7 +96,7 @@ replace_node="*", replace="@conf_var.master(@master)", groups={"StandAloneCall"}, - holes= {"conf_var"}, + holes={"conf_var"}, is_seed_rule=False, ) @@ -131,7 +106,7 @@ replace_node="*", replace="@conf_var.config(@a1, @a2)", groups={"StandAloneCall"}, - holes= {"conf_var"}, + holes={"conf_var"}, is_seed_rule=False, ) @@ -141,7 +116,7 @@ replace_node="*", replace="@conf_var.all(@a1)", groups={"StandAloneCall"}, - holes= {"conf_var"}, + holes={"conf_var"}, is_seed_rule=False, ) @@ -151,7 +126,7 @@ replace_node="*", replace="@conf_var.ifMissing(@a1)", groups={"StandAloneCall"}, - holes= {"conf_var"}, + holes={"conf_var"}, is_seed_rule=False, ) @@ -161,27 +136,17 @@ replace_node="*", replace="@conf_var.jars(@a1)", groups={"StandAloneCall"}, - holes= {"conf_var"}, + holes={"conf_var"}, is_seed_rule=False, ) -set_executor_env_change_1_java_scala_stand_alone_call = Rule( +set_executor_env_change_java_scala_stand_alone_call = Rule( name="set_executor_env_change_1_java_scala_stand_alone_call", - query="cs @conf_var.setExecutorEnv(:[a1])", + query="cs @conf_var.setExecutorEnv:[args]", replace_node="*", - replace="@conf_var.executorEnv(@a1)", + replace="@conf_var.executorEnv@args", groups={"StandAloneCall"}, - holes= {"conf_var"}, - is_seed_rule=False, -) - -set_executor_env_change_2_java_scala_stand_alone_call = Rule( - name="set_executor_env_change_2_java_scala_stand_alone_call", - query="cs @conf_var.setExecutorEnv(:[a1], :[a2])", - replace_node="*", - replace="@conf_var.executorEnv(@a1, @a2)", - groups={"StandAloneCall"}, - holes= {"conf_var"}, + holes={"conf_var"}, is_seed_rule=False, ) @@ -191,44 +156,29 @@ replace_node="*", replace="@conf_var.sparkHome(@a1)", groups={"StandAloneCall"}, - holes= {"conf_var"}, + holes={"conf_var"}, is_seed_rule=False, ) - - - dummy = Rule(name="dummy", is_seed_rule=False) RULES = [ - # Transforms the initializer spark_conf_change_java_scala, - - # Transforms the builder pattern app_name_change_java_scala, master_name_change_java_scala, setter_name_change_java_scala, set_all_change_java_scala, set_if_missing_java_scala, set_jars_change_java_scala, - set_executor_env_change_1_java_scala, - 
set_executor_env_change_2_java_scala, + set_executor_env_change_java_scala, set_spark_home_change_java_scala, - - # Transforms the stand alone calls app_name_change_java_scala_stand_alone_call, master_name_change_java_scala_stand_alone_call, setter_name_change_java_scala_stand_alone_call, set_all_change_java_scala_stand_alone_call, set_if_missing_java_scala_stand_alone_call, set_jars_change_java_scala_stand_alone_call, - set_executor_env_change_1_java_scala_stand_alone_call, - set_executor_env_change_2_java_scala_stand_alone_call, + set_executor_env_change_java_scala_stand_alone_call, set_spark_home_change_java_scala_stand_alone_call, - - - dummy, ] - - diff --git a/plugins/spark_upgrade/spark_config/scala_rules.py b/plugins/spark_upgrade/spark_config/scala_rules.py index 1aac922e5..05222767d 100644 --- a/plugins/spark_upgrade/spark_config/scala_rules.py +++ b/plugins/spark_upgrade/spark_config/scala_rules.py @@ -1,4 +1,3 @@ - from polyglot_piranha import Rule @@ -18,7 +17,7 @@ replace="@conf_var.sparkContext", holes={"conf_var"}, is_seed_rule=False, - groups={"update_spark_context"} + groups={"update_spark_context"}, ) From c44be8ea88c23f928294d4242beddd58b9910d43 Mon Sep 17 00:00:00 2001 From: Ameya Ketkar <94497232+ketkarameya@users.noreply.github.com> Date: Mon, 23 Oct 2023 08:39:00 -0700 Subject: [PATCH 03/12] . --- .../spark_config/java_scala_rules.py | 233 +++++------------- .../spark_upgrade/spark_config/scala_rules.py | 2 - 2 files changed, 59 insertions(+), 176 deletions(-) diff --git a/plugins/spark_upgrade/spark_config/java_scala_rules.py b/plugins/spark_upgrade/spark_config/java_scala_rules.py index b516ebbbd..02b5a6838 100644 --- a/plugins/spark_upgrade/spark_config/java_scala_rules.py +++ b/plugins/spark_upgrade/spark_config/java_scala_rules.py @@ -8,177 +8,62 @@ holes={"spark_conf"}, ) -app_name_change_java_scala = Rule( - name="app_name_change_java_scala", - query="cs :[r].setAppName(:[app_name])", - replace_node="*", - replace="@r.appName(@app_name)", - groups={"BuilderPattern"}, - is_seed_rule=False, -) - -master_name_change_java_scala = Rule( - name="master_name_change_java_scala", - query="cs :[r].setMaster(:[master])", - replace_node="*", - replace="@r.master(@master)", - groups={"BuilderPattern"}, - is_seed_rule=False, -) - -setter_name_change_java_scala = Rule( - name="setter_name_change_java_scala", - query="cs :[r].set(:[a1],:[a2])", - replace_node="*", - replace="@r.config(@a1, @a2)", - groups={"BuilderPattern"}, - is_seed_rule=False, -) - -set_all_change_java_scala = Rule( - name="set_all_change_java_scala", - query="cs :[r].setAll(:[a1])", - replace_node="*", - replace="@r.all(@a1)", - groups={"BuilderPattern"}, - is_seed_rule=False, -) - -set_if_missing_java_scala = Rule( - name="set_if_missing_java_scala", - query="cs :[r].setIfMissing(:[a1], :[a2])", - replace_node="*", - replace="@r.ifMissing(@a1)", - groups={"BuilderPattern"}, - is_seed_rule=False, -) - -set_jars_change_java_scala = Rule( - name="set_jars_change_java_scala", - query="cs :[r].setJars(:[a1])", - replace_node="*", - replace="@r.jars(@a1)", - groups={"BuilderPattern"}, - is_seed_rule=False, -) - -set_executor_env_change_java_scala = Rule( - name="set_executor_env_change_1_java_scala", - query="cs :[r].setExecutorEnv:[args]", - replace_node="*", - replace="@r.executorEnv@args", - groups={"BuilderPattern"}, - is_seed_rule=False, -) - -set_spark_home_change_java_scala = Rule( - name="set_spark_home_change_java_scala", - query="cs :[r].setSparkHome(:[a1])", - replace_node="*", - 
replace="@r.sparkHome(@a1)", - groups={"BuilderPattern"}, - is_seed_rule=False, -) - -app_name_change_java_scala_stand_alone_call = Rule( - name="app_name_change_java_scala_stand_alone_call", - query="cs @conf_var.setAppName(:[app_name])", - replace_node="*", - replace="@conf_var.appName(@app_name)", - groups={"StandAloneCall"}, - holes={"conf_var"}, - is_seed_rule=False, -) - -master_name_change_java_scala_stand_alone_call = Rule( - name="master_name_change_java_scala_stand_alone_call", - query="cs @conf_var.setMaster(:[master])", - replace_node="*", - replace="@conf_var.master(@master)", - groups={"StandAloneCall"}, - holes={"conf_var"}, - is_seed_rule=False, -) - -setter_name_change_java_scala_stand_alone_call = Rule( - name="setter_name_change_java_scala_stand_alone_call", - query="cs @conf_var.set(:[a1],:[a2])", - replace_node="*", - replace="@conf_var.config(@a1, @a2)", - groups={"StandAloneCall"}, - holes={"conf_var"}, - is_seed_rule=False, -) - -set_all_change_java_scala_stand_alone_call = Rule( - name="set_all_change_java_scala_stand_alone_call", - query="cs @conf_var.setAll(:[a1])", - replace_node="*", - replace="@conf_var.all(@a1)", - groups={"StandAloneCall"}, - holes={"conf_var"}, - is_seed_rule=False, -) - -set_if_missing_java_scala_stand_alone_call = Rule( - name="set_if_missing_java_scala_stand_alone_call", - query="cs @conf_var.setIfMissing(:[a1], :[a2])", - replace_node="*", - replace="@conf_var.ifMissing(@a1)", - groups={"StandAloneCall"}, - holes={"conf_var"}, - is_seed_rule=False, -) - -set_jars_change_java_scala_stand_alone_call = Rule( - name="set_jars_change_java_scala_stand_alone_call", - query="cs @conf_var.setJars(:[a1])", - replace_node="*", - replace="@conf_var.jars(@a1)", - groups={"StandAloneCall"}, - holes={"conf_var"}, - is_seed_rule=False, -) - -set_executor_env_change_java_scala_stand_alone_call = Rule( - name="set_executor_env_change_1_java_scala_stand_alone_call", - query="cs @conf_var.setExecutorEnv:[args]", - replace_node="*", - replace="@conf_var.executorEnv@args", - groups={"StandAloneCall"}, - holes={"conf_var"}, - is_seed_rule=False, -) - -set_spark_home_change_java_scala_stand_alone_call = Rule( - name="set_spark_home_change_java_scala_stand_alone_call", - query="cs @conf_var.setSparkHome(:[a1])", - replace_node="*", - replace="@conf_var.sparkHome(@a1)", - groups={"StandAloneCall"}, - holes={"conf_var"}, - is_seed_rule=False, -) - -dummy = Rule(name="dummy", is_seed_rule=False) - -RULES = [ - spark_conf_change_java_scala, - app_name_change_java_scala, - master_name_change_java_scala, - setter_name_change_java_scala, - set_all_change_java_scala, - set_if_missing_java_scala, - set_jars_change_java_scala, - set_executor_env_change_java_scala, - set_spark_home_change_java_scala, - app_name_change_java_scala_stand_alone_call, - master_name_change_java_scala_stand_alone_call, - setter_name_change_java_scala_stand_alone_call, - set_all_change_java_scala_stand_alone_call, - set_if_missing_java_scala_stand_alone_call, - set_jars_change_java_scala_stand_alone_call, - set_executor_env_change_java_scala_stand_alone_call, - set_spark_home_change_java_scala_stand_alone_call, - dummy, -] +def get_setter_rules(name: str, query: str, replace: str) -> list[Rule]: + return [ + Rule( + name=name, + query=query.format(receiver=":[r]"), + replace_node="*", + replace=replace.format(receiver="@r"), + groups={"BuilderPattern"}, + is_seed_rule=False, + ), + Rule( + name=name + "_stand_alone_call", + query=query.format(receiver="@conf_var"), + replace_node="*", + 
replace=replace.format(receiver="@conf_var"), + holes={"conf_var"}, + groups={"StandAloneCall"}, + is_seed_rule=False, + ), + ] + + +RULES = [spark_conf_change_java_scala] + get_setter_rules( + "app_name_change_java_scala", + "cs {receiver}.setAppName(:[app_name])", + "{receiver}.appName(@app_name)", + ) + get_setter_rules( + "master_name_change_java_scala", + "cs {receiver}.setMaster(:[master])", + "{receiver}.master(@master)", + ) + get_setter_rules( + "setter_name_change_java_scala", + "cs {receiver}.set(:[a1],:[a2])", + "{receiver}.config(@a1, @a2)", + ) + get_setter_rules( + "set_all_change_java_scala", + "cs {receiver}.setAll(:[a1])", + "{receiver}.all(@a1)", + ) + get_setter_rules( + "set_if_missing_java_scala", + "cs {receiver}.setIfMissing(:[a1], :[a2])", + "{receiver}.ifMissing(@a1)", + ) + get_setter_rules( + "set_jars_change_java_scala", + "cs {receiver}.setJars(:[a1])", + "{receiver}.jars(@a1)", + ) + get_setter_rules( + "set_executor_env_2_change_java_scala", + "cs {receiver}.setExecutorEnv(:[a1], :[a2])", + "{receiver}.executorEnv(@a1, @a2)", + ) + get_setter_rules( + "set_executor_env_1_change_java_scala", + "cs {receiver}.setExecutorEnv(:[a1])", + "{receiver}.executorEnv(@a1)", + ) + get_setter_rules( + "set_spark_home_change_java_scala", + "cs {receiver}.setSparkHome(:[a1])", + "{receiver}.sparkHome(@a1)", + ) + [Rule(name="dummy", is_seed_rule=False)] diff --git a/plugins/spark_upgrade/spark_config/scala_rules.py b/plugins/spark_upgrade/spark_config/scala_rules.py index 05222767d..1fcfc6b90 100644 --- a/plugins/spark_upgrade/spark_config/scala_rules.py +++ b/plugins/spark_upgrade/spark_config/scala_rules.py @@ -1,6 +1,5 @@ from polyglot_piranha import Rule - update_enclosing_var_declaration_scala = Rule( name="update_enclosing_var_declaration_scala", query="cs val :[conf_var] = :[rhs]", @@ -20,5 +19,4 @@ groups={"update_spark_context"}, ) - RULES = [update_enclosing_var_declaration_scala, update_spark_context_scala] From 044a2732069440a0a55879d8f5a832b7b033aec6 Mon Sep 17 00:00:00 2001 From: Ameya Ketkar <94497232+ketkarameya@users.noreply.github.com> Date: Mon, 23 Oct 2023 08:46:53 -0700 Subject: [PATCH 04/12] . 
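Deduplicate the import-insertion rules in java_rules.py behind an
insert_import(rule_name, fq_type) helper. As a rough illustration, the
SparkSession rule generated by the helper is the same Rule this patch removes
from the hand-written list:

    # equivalent of insert_import("insert_import_spark_session_java",
    #                             "org.apache.spark.sql.SparkSession")
    insert_import_spark_session_java = Rule(
        name="insert_import_spark_session_java",
        query="(package_declaration) @pkg",
        replace_node="pkg",
        replace="@pkg \n import org.apache.spark.sql.SparkSession;",
        is_seed_rule=False,
        filters={
            Filter(
                enclosing_node="(program) @cu",
                not_contains=["cs import org.apache.spark.sql.SparkSession;"],
            )
        },
    )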
--- .../spark_upgrade/spark_config/java_rules.py | 44 +++++++++---------- 1 file changed, 20 insertions(+), 24 deletions(-) diff --git a/plugins/spark_upgrade/spark_config/java_rules.py b/plugins/spark_upgrade/spark_config/java_rules.py index c636ee848..d3c7dd233 100644 --- a/plugins/spark_upgrade/spark_config/java_rules.py +++ b/plugins/spark_upgrade/spark_config/java_rules.py @@ -1,5 +1,21 @@ from polyglot_piranha import Filter, OutgoingEdges, Rule + +def insert_import(rule_name, fq_type: str) -> Rule: + return Rule( + name=rule_name, + query="(package_declaration) @pkg", + replace_node="pkg", + replace=f"@pkg \n import {fq_type};", + is_seed_rule=False, + filters={ + Filter( + enclosing_node="(program) @cu", + not_contains=[f"cs import {fq_type};"], + ) + }, + ) + update_enclosing_var_declaration_java = Rule( name="update_enclosing_var_declaration_java", query="cs :[type] :[conf_var] = :[rhs];", @@ -9,18 +25,8 @@ groups={"update_enclosing_var_declaration"}, ) -insert_import_spark_session_java = Rule( - name="insert_import_spark_session_java", - query="(package_declaration) @pkg", - replace_node="pkg", - replace="@pkg \n import org.apache.spark.sql.SparkSession;", - is_seed_rule=False, - filters={ - Filter( - enclosing_node="(program) @cu", - not_contains=["cs import org.apache.spark.sql.SparkSession;"], - ) - }, +insert_import_spark_session_java = insert_import( + "insert_import_spark_session_java", "org.apache.spark.sql.SparkSession" ) update_spark_context_java = Rule( @@ -44,18 +50,8 @@ groups={"update_spark_context"}, ) -insert_import_spark_context_java = Rule( - name="insert_import_spark_context_java", - query="(package_declaration) @pkg", - replace_node="pkg", - replace="@pkg \n import org.apache.spark.SparkContext;", - is_seed_rule=False, - filters={ - Filter( - enclosing_node="(program) @cu", - not_contains=["cs import org.apache.spark.SparkContext;"], - ) - }, +insert_import_spark_context_java = insert_import( + "insert_import_spark_context_java", "org.apache.spark.SparkContext" ) RULES = [ From b799d252210c42862f44eb41299c7c23612a1b12 Mon Sep 17 00:00:00 2001 From: Ameya Ketkar <94497232+ketkarameya@users.noreply.github.com> Date: Mon, 23 Oct 2023 11:16:56 -0700 Subject: [PATCH 05/12] . 
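Fix insert_import_statement in spark2to3.py: the duplicate check now compares
against the full "import <fq-name>;" statement, the package declaration is
found by scanning the tree for a package_declaration node instead of a field
lookup on the root, and the driver loop inserts the SparkSession and
SparkContext imports. The Spark imports are dropped from the embedded Java
sample so the insertion is actually exercised; roughly, each accepted import
is spliced in right after the package declaration of the rewritten source:

    package com.piranha;
    import org.apache.spark.sql.SparkSession;

and a repeated call with the same import is a no-op.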
--- experimental/paper_experiments/spark2to3.py | 10 ++++++---- experimental/paper_experiments/utils.py | 3 --- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/experimental/paper_experiments/spark2to3.py b/experimental/paper_experiments/spark2to3.py index 48bb8999f..25570b486 100644 --- a/experimental/paper_experiments/spark2to3.py +++ b/experimental/paper_experiments/spark2to3.py @@ -164,17 +164,17 @@ def insert_import_statement( ): for import_stmt in traverse_tree(tree): if import_stmt.type == "import_declaration": - if import_stmt.text.decode() == import_statement: + if import_stmt.text.decode() == f'import {import_statement};': return tree, source_code - package_decl = tree.root_node.child_by_field_name("package_declaration") + package_decl = [n for n in traverse_tree(tree) if n.type == "package_declaration"] if not package_decl: return tree, source_code - + package_decl = package_decl[0] return rewrite( package_decl, source_code, - f"{package_decl.text.decode()}\n{import_statement}", + f"{package_decl.text.decode()}\nimport {import_statement};", ) @@ -184,6 +184,8 @@ def insert_import_statement( TREE: Tree = parse_code("java", SOURCE_CODE) original_code = SOURCE_CODE TREE, SOURCE_CODE = update_spark_conf_init(TREE, SOURCE_CODE, state) + TREE, SOURCE_CODE= insert_import_statement(TREE, SOURCE_CODE, "org.apache.spark.sql.SparkSession") + TREE, SOURCE_CODE= insert_import_statement(TREE, SOURCE_CODE, "org.apache.spark.SparkContext") TREE, SOURCE_CODE = update_spark_context_init(TREE, SOURCE_CODE, state) no_change = SOURCE_CODE == original_code no_setter_found = False diff --git a/experimental/paper_experiments/utils.py b/experimental/paper_experiments/utils.py index 0f609388f..c43f57a90 100644 --- a/experimental/paper_experiments/utils.py +++ b/experimental/paper_experiments/utils.py @@ -5,9 +5,6 @@ SOURCE_CODE = """package com.piranha; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; - public class Sample { public static void main(String[] args) { SparkConf conf = new SparkConf() From 0166a45b32ccfeda621a6c1dd455daaa0fa4ae7a Mon Sep 17 00:00:00 2001 From: Ameya Ketkar <94497232+ketkarameya@users.noreply.github.com> Date: Mon, 23 Oct 2023 20:51:16 -0700 Subject: [PATCH 06/12] . 
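Generalize the imperative rewrite to Scala as well as Java: every helper in
spark2to3.py takes a language argument, utils.py gains a Scala sample plus
parsers built with Language.build_library from local tree-sitter-java and
tree-sitter-scala checkouts (the paths are machine-specific), and a
run(language, source_code) entry point drives the migration. The per-language
branches exist because the same builder call parses into differently shaped
nodes; a minimal sketch of the difference, using the tree_sitter_languages
wheels from the first patch:

    from tree_sitter_languages import get_parser

    java_tree = get_parser("java").parse(b'class A { void m() { conf.setAppName("x"); } }')
    scala_tree = get_parser("scala").parse(b'object A { conf.setAppName("x") }')
    # Java:  the setter call is a method_invocation with object/name/arguments fields.
    # Scala: it is a call_expression whose function is a field_expression with
    #        value/field children, which is what the Scala branches query.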
--- experimental/paper_experiments/spark2to3.py | 243 ++++++++++++++------ experimental/paper_experiments/utils.py | 68 +++++- 2 files changed, 232 insertions(+), 79 deletions(-) diff --git a/experimental/paper_experiments/spark2to3.py b/experimental/paper_experiments/spark2to3.py index 25570b486..85362dab3 100644 --- a/experimental/paper_experiments/spark2to3.py +++ b/experimental/paper_experiments/spark2to3.py @@ -1,6 +1,14 @@ from typing import Any, Dict, Optional, Tuple from tree_sitter import Node, Tree -from utils import parse_code, traverse_tree, rewrite, SOURCE_CODE +from utils import ( + JAVA, + SCALA_SOURCE_CODE, + parse_code, + traverse_tree, + rewrite, + JAVA_SOURCE_CODE, + SCALA, +) relevant_builder_method_names_mapping = { @@ -15,38 +23,64 @@ } -def get_initializer_named(tree: Tree, name: str): +def get_initializer_named( + tree: Tree, name: str, language: str +) -> Optional[Node]: for node in traverse_tree(tree): - if node.type == "object_creation_expression": - oce_type = node.child_by_field_name("type") - if oce_type and oce_type.text.decode() == name: - return node + if language == JAVA: + if node.type == "object_creation_expression": + oce_type = node.child_by_field_name("type") + if oce_type and oce_type.text.decode() == name: + return node + if language == SCALA: + if node.type == "instance_expression": + if any(c.text.decode() == name for c in node.children): + return node def get_enclosing_variable_declaration_name_type( - node: Node, -) -> Tuple[Node | None, str | None, str | None]: - name, typ, nd = None, None, None - if node.parent and node.parent.type == "variable_declarator": - n = node.parent.child_by_field_name("name") - if n: - name = n.text.decode() + node: Node, language: str +) -> Tuple[Node | None, str | None]: + name, nd = None, None + if language == JAVA: + if node.parent and node.parent.type == "variable_declarator": + n = node.parent.child_by_field_name("name") + if n: + name = n.text.decode() + if ( + node.parent.parent + and node.parent.parent.type == "local_variable_declaration" + ): + t = node.parent.parent.child_by_field_name("type") + if t: + nd = node.parent.parent + if language == SCALA: if ( - node.parent.parent - and node.parent.parent.type == "local_variable_declaration" + node.parent + and node.parent.type == "val_definition" ): - t = node.parent.parent.child_by_field_name("type") - if t: - typ = t.text.decode() - nd = node.parent.parent - return nd, name, typ - - -def all_enclosing_method_invocations(node: Node) -> list[Node]: - if node.parent and node.parent.type == "method_invocation": - return [node.parent] + all_enclosing_method_invocations(node.parent) + n = node.parent.child_by_field_name("pattern") + if n: + name = n.text.decode() + nd = node.parent + return nd, name + + +def all_enclosing_method_invocations(node: Node, language: str) -> list[Node]: + if language == JAVA: + if node.parent and node.parent.type == "method_invocation": + return [node.parent] + all_enclosing_method_invocations( + node.parent, language + ) + else: + return [] else: - return [] + if node.parent and node.parent.parent and node.parent.parent.type == "call_expression": + return [node.parent.parent] + all_enclosing_method_invocations( + node.parent.parent, language + ) + else: + return [] def build_spark_session_builder(builder_mappings: list[tuple[str, Node]]): @@ -57,19 +91,23 @@ def build_spark_session_builder(builder_mappings: list[tuple[str, Node]]): def update_spark_conf_init( - tree: Tree, src_code: str, state: Dict[str, Any] + tree: Tree, src_code: 
str, state: Dict[str, Any], language: str ) -> Tuple[Tree, str]: - spark_conf_init = get_initializer_named(tree, "SparkConf") + spark_conf_init = get_initializer_named(tree, "SparkConf", language) if not spark_conf_init: print("No SparkConf initializer found") return tree, src_code encapsulating_method_invocations = all_enclosing_method_invocations( - spark_conf_init + spark_conf_init, language ) builder_mappings = [] for n in encapsulating_method_invocations: - name = n.child_by_field_name("name") + name = ( + n.child_by_field_name("name") + if language == JAVA + else n.children[0].children[2] + ) if ( name and name.text.decode() @@ -90,106 +128,163 @@ def update_spark_conf_init( else spark_conf_init ) - node, name, typ = get_enclosing_variable_declaration_name_type( - outermost_node_builder_pattern + node, name = get_enclosing_variable_declaration_name_type( + outermost_node_builder_pattern, language ) - if not (node and name and typ): + if not (node and name): print("Not in a variable declaration") return tree, src_code - declaration_replacement = ( - f"SparkSession {name} = {builder_mapping}.getOrCreate();" + declaration_replacement = get_declaration_replacement( + name, builder_mapping, language ) state["spark_conf_name"] = name - return rewrite(node, src_code, declaration_replacement) + return rewrite(node, src_code, declaration_replacement, language) + + +def get_declaration_replacement(name, builder_mapping, language): + if language == JAVA: + return f"SparkSession {name} = {builder_mapping}.getOrCreate();" + else: + return f"val {name} = {builder_mapping}.getOrCreate()" def update_spark_context_init( - tree: Tree, source_code: str, state: Dict[str, Any] + tree: Tree, source_code: str, state: Dict[str, Any], language: str ): if "spark_conf_name" not in state: print("Needs the name of the variable holding the SparkConf") return tree, source_code spark_conf_name = state["spark_conf_name"] - init = get_initializer_named(tree, "JavaSparkContext") + init = get_initializer_named(tree, "JavaSparkContext", language) if not init: return tree, source_code - node, name, typ = get_enclosing_variable_declaration_name_type(init) + node, name = get_enclosing_variable_declaration_name_type(init, language) if node: return rewrite( node, source_code, f"SparkContext {name} = {spark_conf_name}.sparkContext()", + language ) else: return rewrite(init, source_code, f"{spark_conf_name}.sparkContext()") -def get_setter_call(variable_name: str, tree: Tree) -> Optional[Node]: +def get_setter_call(variable_name: str, tree: Tree, language: str) -> Optional[Node]: for node in traverse_tree(tree): - if node.type == "method_invocation": - name = node.child_by_field_name("name") - r = node.child_by_field_name("object") - if name and r: - name = name.text.decode() - r = r.text.decode() - if ( - r == variable_name - and name in relevant_builder_method_names_mapping.keys() - ): - return node + if language == JAVA: + if node.type == "method_invocation": + name = node.child_by_field_name("name") + r = node.child_by_field_name("object") + if name and r: + name = name.text.decode() + r = r.text.decode() + if ( + r == variable_name + and name in relevant_builder_method_names_mapping.keys() + ): + return node + if language == SCALA: + if node.type == "call_expression": + _fn = node.child_by_field_name("function") + if not _fn: + continue + name = _fn.child_by_field_name("field") + r = _fn.child_by_field_name("value") + if name and r: + name = name.text.decode() + r = r.text.decode() + if ( + r == variable_name + and name 
in relevant_builder_method_names_mapping.keys() + ): + return node def update_spark_conf_setters( - tree: Tree, source_code: str, state: Dict[str, Any] + tree: Tree, source_code: str, state: Dict[str, Any], language: str ): - setter_call = get_setter_call(state["spark_conf_name"], tree) + setter_call = get_setter_call(state["spark_conf_name"], tree, language) if setter_call: rcvr = state["spark_conf_name"] - invc = setter_call.child_by_field_name("name") + invc = ( + setter_call.child_by_field_name("name") + if language == JAVA + else setter_call.children[0].children[2] + ) args = setter_call.child_by_field_name("arguments") if rcvr and invc and args: new_fn = relevant_builder_method_names_mapping[invc.text.decode()] replacement = f"{rcvr}.{new_fn}{args.text.decode()}" - return rewrite(setter_call, source_code, replacement) + return rewrite(setter_call, source_code, replacement, language) return tree, source_code def insert_import_statement( - tree: Tree, source_code: str, import_statement: str + tree: Tree, source_code: str, import_statement: str, language: str ): for import_stmt in traverse_tree(tree): if import_stmt.type == "import_declaration": - if import_stmt.text.decode() == f'import {import_statement};': + if import_stmt.text.decode() == f"import {import_statement}" + (";" if language == JAVA else ""): return tree, source_code - package_decl = [n for n in traverse_tree(tree) if n.type == "package_declaration"] + package_decl = [ + n + for n in traverse_tree(tree) + if n.type + == ("package_declaration" if language == JAVA else "package_clause") + ] if not package_decl: return tree, source_code package_decl = package_decl[0] + if language == JAVA: + return rewrite( + package_decl, + source_code, + f"{package_decl.text.decode()}\nimport {import_statement};", + language + ) return rewrite( package_decl, source_code, - f"{package_decl.text.decode()}\nimport {import_statement};", + f"{package_decl.text.decode()}\nimport {import_statement}", + language ) -state = {} -no_change = False -while not no_change: - TREE: Tree = parse_code("java", SOURCE_CODE) - original_code = SOURCE_CODE - TREE, SOURCE_CODE = update_spark_conf_init(TREE, SOURCE_CODE, state) - TREE, SOURCE_CODE= insert_import_statement(TREE, SOURCE_CODE, "org.apache.spark.sql.SparkSession") - TREE, SOURCE_CODE= insert_import_statement(TREE, SOURCE_CODE, "org.apache.spark.SparkContext") - TREE, SOURCE_CODE = update_spark_context_init(TREE, SOURCE_CODE, state) - no_change = SOURCE_CODE == original_code - no_setter_found = False - while not no_setter_found: - b4_code = SOURCE_CODE - TREE, SOURCE_CODE = update_spark_conf_setters(TREE, SOURCE_CODE, state) - no_setter_found = SOURCE_CODE == b4_code +def run(language, source_code): + state = {} + no_change = False + while not no_change: + TREE: Tree = parse_code(language, source_code) + original_code = source_code + TREE, source_code = update_spark_conf_init( + TREE, source_code, state, language + ) + TREE, source_code = insert_import_statement( + TREE, source_code, "org.apache.spark.sql.SparkSession", language + ) + TREE, source_code = insert_import_statement( + TREE, source_code, "org.apache.spark.SparkContext", language + ) + TREE, source_code = update_spark_context_init( + TREE, source_code, state, language + ) + no_change = source_code == original_code + no_setter_found = False + while not no_setter_found: + b4_code = source_code + TREE, source_code = update_spark_conf_setters( + TREE, source_code, state, language + ) + no_setter_found = source_code == b4_code + return 
source_code + + +# run(JAVA, JAVA_SOURCE_CODE) +run(SCALA, SCALA_SOURCE_CODE) diff --git a/experimental/paper_experiments/utils.py b/experimental/paper_experiments/utils.py index c43f57a90..e5186bd1c 100644 --- a/experimental/paper_experiments/utils.py +++ b/experimental/paper_experiments/utils.py @@ -1,9 +1,12 @@ -from tree_sitter import Node, Tree +from tree_sitter import Language, Node, Parser, Tree from tree_sitter_languages import get_parser +JAVA = "java" +SCALA = "scala" -SOURCE_CODE = """package com.piranha; + +JAVA_SOURCE_CODE = """package com.piranha; public class Sample { public static void main(String[] args) { @@ -36,9 +39,64 @@ """ + +SCALA_SOURCE_CODE = """package com.piranha + +import org.apache.spark.SparkConf +import org.apache.spark.sql.SparkSession + +class Sample { + def main(args: Array[String]): Unit = { + + val conf= new SparkConf() + .setAppName("Sample App") + + val sc = new SparkContext(conf) + + val conf1 = new SparkConf() + .setMaster(master) + .setAll(Seq(("k2", "v2"), ("k3", "v3"))) + .setAppName(appName) + .setSparkHome(sparkHome) + .setExecutorEnv("spark.executor.extraClassPath", "test") + .set("spark.driver.allowMultipleContexts", "true") + sc1 = new SparkContext(conf1) + + + val conf2 = new SparkConf() + .setMaster(master) + + conf2.setSparkHome(sparkHome) + + conf2.setExecutorEnv("spark.executor.extraClassPath", "test") + + + + } + +} +""" + + +Language.build_library( + # Store the library in the `build` directory + 'build/my-languages.so', + + # Include one or more languages + [ + '/Users/ketkara/repositories/open-source/tree-sitter-scala', + '/Users/ketkara/repositories/open-source/tree-sitter-java', + ] +) + +SCALA_LANGUAGE = Language('build/my-languages.so', 'scala') +JAVA_LANGUAGE = Language('build/my-languages.so', 'java') + def parse_code(language: str, source_code: str) -> Tree: "Helper function to parse into tree sitter nodes" - parser = get_parser(language) + parser = Parser() + parser.set_language(JAVA_LANGUAGE if language == JAVA else SCALA_LANGUAGE) + source_tree = parser.parse(bytes(source_code, "utf8")) return source_tree @@ -65,11 +123,11 @@ def traverse_tree(tree: Tree): retracing = False -def rewrite(node: Node, source_code: str, replacement: str): +def rewrite(node: Node, source_code: str, replacement: str, language: str): new_source_code = ( source_code[: node.start_byte] + replacement + source_code[node.end_byte :] ) print(new_source_code) - return parse_code("java", new_source_code), new_source_code + return parse_code(language, new_source_code), new_source_code From 1e87f107170bd4f9a540bb84e8495682a951abcd Mon Sep 17 00:00:00 2001 From: Ameya Ketkar <94497232+ketkarameya@users.noreply.github.com> Date: Tue, 24 Oct 2023 08:29:04 -0700 Subject: [PATCH 07/12] . 
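Add query_all_repos.py, a small client for the GitHub search APIs that pages
through the most-starred repositories for the requested languages and reports
which of them contain a given code string. A hedged usage sketch, with the
argument names taken from parse_arguments and a placeholder token:

    python query_all_repos.py "new SparkConf()" <GITHUB_TOKEN> --languages Java Scala

For every repository whose code search returns matches, the script prints
owner/name together with the total match count.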
--- .../paper_experiments/query_all_repos.py | 116 ++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 experimental/paper_experiments/query_all_repos.py diff --git a/experimental/paper_experiments/query_all_repos.py b/experimental/paper_experiments/query_all_repos.py new file mode 100644 index 000000000..5cbcffbe9 --- /dev/null +++ b/experimental/paper_experiments/query_all_repos.py @@ -0,0 +1,116 @@ +import requests +import argparse + + +# GitHub API endpoint for repository search +GITHUB_REPO_URL = "https://api.github.com/search/repositories" +# GitHub API endpoint for code search +GITHUB_CODE_SEARCH_URL = "https://api.github.com/search/code" + +# # String to search for in the repositories +# search_string = "new SparkConf()" + + +# parses the command line arguments +def parse_arguments(): + parser = argparse.ArgumentParser( + description="Searches for a string in the top starred Java and Scala repositories on GitHub" + ) + parser.add_argument( + "search_string", + help="String to search for in the repositories", + ) + parser.add_argument("token", help="Github CLI token") + parser.add_argument( + "--languages", + nargs="+", + default=["Java", "Scala"], + help="List of languages to search for", + ) + parser.add_argument( + "--per_page", + default=100, + help="Number of results per page (max 100)", + ) + parser.add_argument( + "--page", + default=1, + help="Page number", + ) + return parser.parse_args() + + +def lookup_keyword_in_repos(headers, search_string, repos): + for repo in repos: + code_params = { + "q": f"{search_string} repo:{repo['owner']}/{repo['name']}", + } + + code_response = requests.get( + "https://api.github.com/search/code", + params=code_params, + headers=headers, + ) + + if code_response.status_code == 200: + code_data = code_response.json() + if code_data["total_count"] > 0: + print(f"{repo['owner']}/{repo['name']} - {code_data['total_count']}") + + +def get_repo_info(response_json): + repositories = [] + for item in response_json["items"]: + repositories.append( + {"name": item["name"], "owner": item["owner"]["login"]} + ) + return repositories + + +def search(token, search_string, languages): + # Set up the headers with your token + headers = {"Authorization": f"Bearer {token}"} + # Parameters for the repository search + _lang_clause= " ".join([f"language:{l}" for l in languages]) + repo_params = { + "q": f"stars:>100 {_lang_clause}", + "sort": "stars", + "order": "desc", + "per_page": 100, # Number of results per page (max 100) + "page": 1, # Page number + } + + try: + # Fetch the top starred repositories + while True: + repositories = [] + response = requests.get( + GITHUB_REPO_URL, params=repo_params, headers=headers + ) + + if response.status_code == 200: + + data = response.json() + repositories = get_repo_info(data) + + if "next" not in response.links: + break + + repo_params["page"] += 1 + + else: + print( + f"Repository request failed with status code {response.status_code}" + ) + print(response.text) + break + + lookup_keyword_in_repos(headers, search_string, repositories) + + except requests.exceptions.RequestException as e: + print(f"An error occurred: {e}") + + + +args = parse_arguments() +search(args.token, args.search_string, args.languages) From 59c348143ce3b93ce8e968536a231fbd8e5d1d06 Mon Sep 17 00:00:00 2001 From: Ameya Ketkar <94497232+ketkarameya@users.noreply.github.com> Date: Tue, 24 Oct 2023 11:10:35 -0700 Subject: [PATCH 08/12] . 
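Restructure query_all_repos.py around a MatchedRepo class that runs the
per-repository code search itself and appends its results to the CSV selected
by --output (default results.csv). For illustration, each matching repository
becomes one row of the form

    name, owner, stars, number_of_matches, file1|file2|...

where the file list keeps only the matched paths that mention one of the
requested language names; the concrete values are whatever the search API
returns.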
--- .../paper_experiments/query_all_repos.py | 89 ++++++++++++++----- 1 file changed, 69 insertions(+), 20 deletions(-) diff --git a/experimental/paper_experiments/query_all_repos.py b/experimental/paper_experiments/query_all_repos.py index 5cbcffbe9..d6bfdc031 100644 --- a/experimental/paper_experiments/query_all_repos.py +++ b/experimental/paper_experiments/query_all_repos.py @@ -1,3 +1,4 @@ +from typing import List import requests import argparse @@ -37,13 +38,38 @@ def parse_arguments(): default=1, help="Page number", ) + # argument to specify the path to write the results + parser.add_argument( + "--output", + default="results.csv", + help="Path to write the results", + ) return parser.parse_args() -def lookup_keyword_in_repos(headers, search_string, repos): - for repo in repos: +class MatchedRepo: + owner: str + name: str + stars: int + search_string: str + token: str + languages: list[str] + number_of_matches: int = 0 + files: list[str] = [] + + def __init__(self, name, owner, stars, search_string, token, languages): + self.name = name + self.owner = owner + self.stars = stars + self.search_string = search_string + self.token = token + self.languages = languages + self.lookup() + + def lookup(self): + headers = {"Authorization": f"Bearer {self.token}"} code_params = { - "q": f"{search_string} repo:{repo['owner']}/{repo['name']}", + "q": f"{self.search_string} repo:{self.owner}/{self.name}", } code_response = requests.get( @@ -51,27 +77,39 @@ def lookup_keyword_in_repos(headers, search_string, repos): params=code_params, headers=headers, ) - if code_response.status_code == 200: code_data = code_response.json() - if code_data["total_count"] > 0: - print(f"{repo['owner']}/{repo['name']} - {code_data['total_count']}") - - -def get_repo_info(response_json): + self.number_of_matches = code_data["total_count"] + self.files = [ + item["path"] + for item in code_data["items"] + if any(l for l in self.languages if l in item["path"]) + ] + def to_csv(self): + files = "|".join(self.files) + return f"{self.name}, {self.owner}, {self.stars}, {self.number_of_matches}, {files}" + + +def get_repo_info( + response_json, search_string, token, languages +) -> List[MatchedRepo]: repositories = [] for item in response_json["items"]: + name = item["name"] + owner = item["owner"]["login"] + stars = item["stargazers_count"] + repositories.append( - {"name": item["name"], "owner": item["owner"]["login"]} + MatchedRepo(name, owner, stars, search_string, token, languages) ) return repositories -def search(token, search_string, languages): +def search(token, search_string, languages, output_csv): # Set up the headers with your token headers = {"Authorization": f"Bearer {token}"} # Parameters for the repository search - _lang_clause= " ".join([f"language:{l}" for l in languages]) + _lang_clause = " ".join([f"language:{l}" for l in languages]) repo_params = { "q": f"stars:>100 {_lang_clause}", "sort": "stars", @@ -89,13 +127,22 @@ def search(token, search_string, languages): ) if response.status_code == 200: - data = response.json() - repositories = get_repo_info(data) - + repositories = get_repo_info( + response_json=data, + search_string=search_string, + token=token, + languages=languages, + ) + for r in repositories: + if r.number_of_matches > 0: + entry = r.to_csv() + with open(output_csv, "a+") as f: + f.write(entry + "\n") + print(entry) if "next" not in response.links: break - + repo_params["page"] += 1 else: @@ -104,13 +151,15 @@ def search(token, search_string, languages): ) print(response.text) break - 
- lookup_keyword_in_repos(headers, search_string, repositories) except requests.exceptions.RequestException as e: print(f"An error occurred: {e}") - args = parse_arguments() -search(args.token, args.search_string, args.languages) +search( + token=args.token, + search_string=args.search_string, + languages=args.languages, + output_csv=args.output, +) From 26afbbd4a19a7c52832b774ddb61a45729788846 Mon Sep 17 00:00:00 2001 From: Ameya Ketkar <94497232+ketkarameya@users.noreply.github.com> Date: Tue, 24 Oct 2023 11:20:01 -0700 Subject: [PATCH 09/12] . --- experimental/paper_experiments/query_all_repos.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/experimental/paper_experiments/query_all_repos.py b/experimental/paper_experiments/query_all_repos.py index d6bfdc031..706be9fce 100644 --- a/experimental/paper_experiments/query_all_repos.py +++ b/experimental/paper_experiments/query_all_repos.py @@ -115,12 +115,15 @@ def search(token, search_string, languages, output_csv): "sort": "stars", "order": "desc", "per_page": 100, # Number of results per page (max 100) - "page": 1, # Page number + # "page": 1, # Page number } - + counter = 10 try: # Fetch the top starred repositories while True: + counter += 1 + if counter > 10: + break repositories = [] response = requests.get( GITHUB_REPO_URL, params=repo_params, headers=headers From c55bb3419b84d0e30d4cfc34f3216f9424969f26 Mon Sep 17 00:00:00 2001 From: Ameya Ketkar <94497232+ketkarameya@users.noreply.github.com> Date: Tue, 24 Oct 2023 11:20:42 -0700 Subject: [PATCH 10/12] . --- experimental/paper_experiments/query_all_repos.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/experimental/paper_experiments/query_all_repos.py b/experimental/paper_experiments/query_all_repos.py index 706be9fce..e54d0f03d 100644 --- a/experimental/paper_experiments/query_all_repos.py +++ b/experimental/paper_experiments/query_all_repos.py @@ -117,7 +117,7 @@ def search(token, search_string, languages, output_csv): "per_page": 100, # Number of results per page (max 100) # "page": 1, # Page number } - counter = 10 + counter = 0 try: # Fetch the top starred repositories while True: From ed7ffc2500fd4890d63a2ff0122a5adc0698ecf5 Mon Sep 17 00:00:00 2001 From: Ameya Ketkar <94497232+ketkarameya@users.noreply.github.com> Date: Tue, 24 Oct 2023 11:27:15 -0700 Subject: [PATCH 11/12] . 
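Tighten the search loop: language names are lower-cased before being matched
against file paths, the page number is taken straight from the loop counter
(repo_params["page"] = counter) instead of being incremented at the bottom of
the loop, and the files column is dropped from the CSV, leaving rows of the
form

    name, owner, stars, number_of_matches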
--- experimental/paper_experiments/query_all_repos.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/experimental/paper_experiments/query_all_repos.py b/experimental/paper_experiments/query_all_repos.py index e54d0f03d..63531de16 100644 --- a/experimental/paper_experiments/query_all_repos.py +++ b/experimental/paper_experiments/query_all_repos.py @@ -83,11 +83,11 @@ def lookup(self): self.files = [ item["path"] for item in code_data["items"] - if any(l for l in self.languages if l in item["path"]) + if any(l for l in self.languages if l.lower() in item["path"]) ] def to_csv(self): - files = "|".join(self.files) - return f"{self.name}, {self.owner}, {self.stars}, {self.number_of_matches}, {files}" + # files = "|".join(self.files) + return f"{self.name}, {self.owner}, {self.stars}, {self.number_of_matches}"#, {files}" def get_repo_info( @@ -115,9 +115,8 @@ def search(token, search_string, languages, output_csv): "sort": "stars", "order": "desc", "per_page": 100, # Number of results per page (max 100) - # "page": 1, # Page number } - counter = 0 + counter = 1 try: # Fetch the top starred repositories while True: @@ -125,6 +124,7 @@ def search(token, search_string, languages, output_csv): if counter > 10: break repositories = [] + repo_params["page"] = counter response = requests.get( GITHUB_REPO_URL, params=repo_params, headers=headers ) @@ -146,8 +146,6 @@ def search(token, search_string, languages, output_csv): if "next" not in response.links: break - repo_params["page"] += 1 - else: print( f"Repository request failed with status code {response.status_code}" From c0146f63e65946ea931e88511992121252953115 Mon Sep 17 00:00:00 2001 From: Ameya Ketkar <94497232+ketkarameya@users.noreply.github.com> Date: Tue, 24 Oct 2023 12:28:54 -0700 Subject: [PATCH 12/12] . --- experimental/paper_experiments/query_all_repos.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/experimental/paper_experiments/query_all_repos.py b/experimental/paper_experiments/query_all_repos.py index 63531de16..0e29d7bc5 100644 --- a/experimental/paper_experiments/query_all_repos.py +++ b/experimental/paper_experiments/query_all_repos.py @@ -116,15 +116,16 @@ def search(token, search_string, languages, output_csv): "order": "desc", "per_page": 100, # Number of results per page (max 100) } - counter = 1 + counter = 0 try: # Fetch the top starred repositories while True: counter += 1 - if counter > 10: + if counter > 70: break repositories = [] repo_params["page"] = counter + print(counter) response = requests.get( GITHUB_REPO_URL, params=repo_params, headers=headers )