diff --git a/deepspeed/launcher/multinode_runner.py b/deepspeed/launcher/multinode_runner.py
index 730146f5bcd2..07c70f3cc6c5 100644
--- a/deepspeed/launcher/multinode_runner.py
+++ b/deepspeed/launcher/multinode_runner.py
@@ -56,13 +56,27 @@ def __init__(self, args, world_info_base64):
     def backend_exists(self):
         return shutil.which('pdsh')
 
+    def parse_user_args(self):
+        """Return the user script arguments, quoted so pdsh keeps them whole.
+
+        With pdsh, a string passed as an argument gets split on whitespace.
+        Any argument containing whitespace is therefore wrapped in double
+        quotes, with embedded double quotes backslash-escaped. Arguments
+        without whitespace are returned unchanged.
+        """
+        processed_args = []
+        for arg in self.args.user_args:
+            # Check for any whitespace (tabs and newlines split arguments just
+            # like spaces do), not only the literal space character.
+            if any(char.isspace() for char in arg):
+                arg = '"{}"'.format(arg.replace('"', '\\"'))
+            processed_args.append(arg)
+        return processed_args
+
     @property
     def name(self):
         return "pdsh"
 
-    def parse_user_args(self):
-        return list(map(lambda x: x if x.startswith("-") else f"'{x}'", self.args.user_args))
-
     def get_cmd(self, environment, active_resources):
         environment['PDSH_RCMD_TYPE'] = 'ssh'
         if self.args.ssh_port is not None:  # only specify ssh port if it is specified
diff --git a/deepspeed/launcher/runner.py b/deepspeed/launcher/runner.py
index a7fa2b5053e5..4f45e1831b48 100755
--- a/deepspeed/launcher/runner.py
+++ b/deepspeed/launcher/runner.py
@@ -12,7 +12,6 @@
 import os
 import re
 import sys
-import shlex
 import json
 import base64
 import argparse
@@ -389,9 +388,6 @@ def parse_num_nodes(str_num_nodes: str, elastic_training: bool):
 def main(args=None):
     args = parse_args(args)
 
-    # For when argparse interprets remaining args as a single string
-    args.user_args = shlex.split(" ".join(list(map(lambda x: x if x.startswith("-") else f'"{x}"', args.user_args))))
-
     if args.elastic_training:
         assert args.master_addr != "", "Master Addr is required when elastic training is enabled"
 
@@ -447,7 +443,12 @@ def main(args=None):
     if not args.master_addr:
         assert multi_node_exec
         first_host = list(active_resources.keys())[0]
-        hostname_cmd = [f"ssh {first_host} hostname -I"]
+        # Run `hostname -I` on the first host over ssh, using the custom ssh port when one is given.
+        ssh_check_cmd = "ssh"
+        if args.ssh_port is not None:
+            ssh_check_cmd += f" -p {args.ssh_port}"
+        ssh_check_cmd += f" {first_host} hostname -I"
+        hostname_cmd = [ssh_check_cmd]
         try:
             result = subprocess.check_output(hostname_cmd, shell=True)
         except subprocess.CalledProcessError as err:
diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py
index 6f545d4cb13b..b9a726bec67f 100644
--- a/deepspeed/module_inject/auto_tp.py
+++ b/deepspeed/module_inject/auto_tp.py
@@ -121,7 +121,9 @@ class Loading():
 
     def is_load_module(module):
         load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm]
-        load_layer_names = ["LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm"]
+        load_layer_names = [
+            "LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "FalconLinear"
+        ]
         return module.__class__ in load_layers or module._get_name() in load_layer_names
 
     def load_buffer(module, state_dict, prefix):
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 44b44c79ba55..c5f4d3e6530d 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -232,9 +232,6 @@ def __init__(
         # for debug purposes - can then debug print: debug_get_module_name(module)
         debug_extract_module_and_param_names(model)
 
-        # needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict
-        self.param_names = {param: name for name, param in model.named_parameters()}
-
         self._do_args_sanity_check(args)
         self._configure_with_arguments(args, mpu)
         self._do_sanity_check()
@@ -261,6 +258,9 @@ def __init__(
         # Configure distributed model
         self._configure_distributed_model(model)
 
+        # needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict
+        self.param_names = {param: name for name, param in model.named_parameters()}
+
         self._get_model_parameters()
 
         see_memory_usage(f"DeepSpeed Engine: After configure distributed model")
diff --git a/tests/unit/launcher/test_user_args.py b/tests/unit/launcher/test_user_args.py
new file mode 100644
index 000000000000..99afd0f2cfa7
--- /dev/null
+++ b/tests/unit/launcher/test_user_args.py
@@ -0,0 +1,69 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import pytest
+import subprocess
+
+from deepspeed.accelerator import get_accelerator
+
+if not get_accelerator().is_available():
+    pytest.skip("only supported in accelerator environments.", allow_module_level=True)
+
+# Minimal user script: prints the success marker only if argparse accepts the arguments it receives.
+user_arg_test_script = """import argparse
+parser = argparse.ArgumentParser()
+parser.add_argument("--prompt", type=str)
+parser.add_argument("--local_rank", type=int, default=0)
+parser.add_argument("--world_size", type=int, default=1)
+args = parser.parse_args()
+print("ARG PARSE SUCCESS")
+"""
+
+
+@pytest.fixture(scope="function")
+def user_script_fp(tmpdir):
+    """Write the argparse test script into tmpdir and return its path."""
+    script_fp = tmpdir.join("user_arg_test.py")
+    with open(script_fp, "w") as f:
+        f.write(user_arg_test_script)
+    return script_fp
+
+
+@pytest.fixture(scope="function")
+def cmd(user_script_fp, prompt, multi_node):
+    """Build the deepspeed command line; --force_multi is added when multi_node is set."""
+    if multi_node:
+        cmd = ("deepspeed", "--force_multi", "--num_nodes", "1", "--num_gpus", "1", user_script_fp, "--prompt", prompt)
+    else:
+        cmd = ("deepspeed", "--num_nodes", "1", "--num_gpus", "1", user_script_fp, "--prompt", prompt)
+    return cmd
+
+
+@pytest.mark.parametrize("prompt", [
+    '''"I am 6' tall"''', """'I am 72" tall'""", """'"translate English to Romanian: "'""",
+    '''I'm going to tell them "DeepSpeed is the best"'''
+])
+@pytest.mark.parametrize("multi_node", [True, False])
+def test_user_args(cmd):
+    """The launched user script must parse its arguments for prompts mixing single and double quotes."""
+    p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+    out, err = p.communicate()
+    assert "ARG PARSE SUCCESS" in out.decode("utf-8"), f"User args not parsed correctly: {err.decode('utf-8')}"
+
+
+def test_bash_string_args(tmpdir, user_script_fp):
+    """A quoted prompt passed through a bash variable and xargs must be parsed by the user script."""
+    bash_script = f"""
+    ARGS="--prompt 'DeepSpeed is the best'"
+    echo ${{ARGS}}|xargs deepspeed --num_nodes 1 --num_gpus 1 {user_script_fp}
+    """
+
+    bash_fp = tmpdir.join("bash_script.sh")
+    with open(bash_fp, "w") as f:
+        f.write(bash_script)
+
+    p = subprocess.Popen(["bash", bash_fp], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+    out, err = p.communicate()
+    assert "ARG PARSE SUCCESS" in out.decode("utf-8"), f"User args not parsed correctly: {err.decode('utf-8')}"