Skip to content

Commit

Permalink
Add support for shell import scripts in automation (#1157)
Browse files Browse the repository at this point in the history
* Add support for shell import scripts in automation

* support scripts without extension

* lint

* comment

* comment
  • Loading branch information
ajaits authored Dec 27, 2024
1 parent d948488 commit 50f4607
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 17 deletions.
66 changes: 51 additions & 15 deletions import-automation/executor/app/executor/import_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ def _import_one_helper(
Args: See _import_one.
"""
import_name = import_spec['import_name']
urls = import_spec.get('data_download_url')
if urls:
for url in urls:
Expand Down Expand Up @@ -370,20 +371,23 @@ def _import_one_helper(
)
else:
# Run import script locally.
script_interpreter = _get_script_interpreter(
script_path, interpreter_path)
process = _run_user_script(
interpreter_path=interpreter_path,
interpreter_path=script_interpreter,
script_path=script_path,
timeout=self.config.user_script_timeout,
args=self.config.user_script_args,
cwd=absolute_import_dir,
env=self.config.user_script_env,
name=import_name,
)
_log_process(process=process)
process.check_returncode()

inputs = self._upload_import_inputs(
import_dir=absolute_import_dir,
output_dir=f'{relative_import_dir}/{import_spec["import_name"]}',
output_dir=f'{relative_import_dir}/{import_name}',
version=version,
import_spec=import_spec,
)
Expand Down Expand Up @@ -527,7 +531,8 @@ def run_and_handle_exception(
def _run_with_timeout_async(args: List[str],
timeout: float,
cwd: str = None,
env: dict = None) -> subprocess.CompletedProcess:
env: dict = None,
name: str = None) -> subprocess.CompletedProcess:
"""Runs a command in a subprocess asynchronously and emits the stdout/stderr.
Args:
Expand All @@ -542,9 +547,8 @@ def _run_with_timeout_async(args: List[str],
Same exceptions as subprocess.run.
"""
try:
logging.info(
f'Launching async command: {args} with timeout {timeout} in {cwd}, env:'
f' {env}')
logging.info(f'Launching async command for {name}: {args} '
f'with timeout {timeout} in {cwd}, env: {env}')
start_time = time.time()
stdout = []
stderr = []
Expand All @@ -559,18 +563,19 @@ def _run_with_timeout_async(args: List[str],
# Log output continuously until the command completes.
for line in process.stderr:
stderr.append(line)
logging.info(f'Process stderr: {line}')
logging.info(f'Process stderr:{name}: {line}')
for line in process.stdout:
stdout.append(line)
logging.info(f'Process stdout: {line}')
logging.info(f'Process stdout:{name}: {line}')

# Wait in case script has closed stderr/stdout early.
process.wait()
end_time = time.time()

return_code = process.returncode
end_msg = (
f'Completed script: "{args}", Return code: {return_code}, time:'
f' {end_time - start_time:.3f} secs.\n')
end_msg = (f'Completed script:{name}: "{args}", '
f'Return code: {return_code}, '
f'time: {end_time - start_time:.3f} secs.\n')
logging.info(end_msg)
return subprocess.CompletedProcess(
args=args,
Expand All @@ -581,8 +586,8 @@ def _run_with_timeout_async(args: List[str],
except Exception as e:
message = traceback.format_exc()
logging.exception(
f'An unexpected exception was thrown: {e} when running {args}:'
f' {message}')
f'An unexpected exception was thrown: {e} when running {name}:'
f'{args}: {message}')
return subprocess.CompletedProcess(
args=args,
returncode=1,
Expand Down Expand Up @@ -669,13 +674,41 @@ def _create_venv(requirements_path: Iterable[str], venv_dir: str,
return os.path.join(venv_dir, 'bin/python3'), process


def _get_script_interpreter(script: str, py_interpreter: str) -> str:
"""Returns the interpreter for the script.
Args:
script: user script to be executed
py_interpreter: Path to python within virtual environment
Returns:
interpreter for user script, such as python for .py, bash for .sh
Returns None if the script has no extension.
"""
if not script:
return None

base, ext = os.path.splitext(script.split(' ')[0])
match ext:
case '.py':
return py_interpreter
case '.sh':
return 'bash'
case _:
logging.info(f'Unknown extension for script: {script}.')
return None

return py_interpreter


def _run_user_script(
interpreter_path: str,
script_path: str,
timeout: float,
args: list = None,
cwd: str = None,
env: dict = None,
name: str = None,
) -> subprocess.CompletedProcess:
"""Runs a user Python script.
Expand All @@ -688,6 +721,7 @@ def _run_user_script(
the command line.
cwd: Current working directory of the process as a string.
env: Dict of environment variables for the user script run.
name: Name of the script.
Returns:
subprocess.CompletedProcess object used to run the script.
Expand All @@ -696,11 +730,13 @@ def _run_user_script(
subprocess.TimeoutExpired: The user script did not finish
within timeout.
"""
script_args = [interpreter_path]
script_args = []
if interpreter_path:
script_args.append(interpreter_path)
script_args.extend(script_path.split(' '))
if args:
script_args.extend(args)
return _run_with_timeout_async(script_args, timeout, cwd, env)
return _run_with_timeout_async(script_args, timeout, cwd, env, name)


def _clean_time(
Expand Down
7 changes: 5 additions & 2 deletions import-automation/executor/schedule_update_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@
'A string specifying the path of an import in the following format:'
'<path_to_directory_relative_to_repository_root>:<import_name>.'
'Example: scripts/us_usda/quickstats:UsdaAgSurvey')
flags.DEFINE_string('config_override', _CONFIG_OVERRIDE_FILE,
'Config file with overridden parameters.')

_FLAGS(sys.argv)

Expand Down Expand Up @@ -301,8 +303,9 @@ def main(_):
cfg.gcp_project_id = _FLAGS.gke_project_id

logging.info(
f'Updating any config fields from local file: {_CONFIG_OVERRIDE_FILE}.')
cfg = _override_configs(_CONFIG_OVERRIDE_FILE, cfg)
f'Updating any config fields from local file: {_FLAGS.config_override}.'
)
cfg = _override_configs(_FLAGS.config_override, cfg)

logging.info('Reading Cloud scheduler configs from GCS.')
scheduler_config_dict = _get_cloud_config(_FLAGS.scheduler_config_filename)
Expand Down

0 comments on commit 50f4607

Please sign in to comment.