Skip to content

Commit

Permalink
Merge branch 'master' into baichuan_support
Browse files Browse the repository at this point in the history
  • Loading branch information
mrwyattii authored Dec 15, 2023
2 parents 6edb577 + 65b7727 commit acf40e7
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 12 deletions.
14 changes: 11 additions & 3 deletions deepspeed/launcher/multinode_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,21 @@ def __init__(self, args, world_info_base64):
def backend_exists(self):
    """Look up the ``pdsh`` executable on the current PATH.

    Returns:
        Whatever ``shutil.which('pdsh')`` yields: the path to the executable
        when it is found, otherwise ``None``.
    """
    pdsh_path = shutil.which('pdsh')
    return pdsh_path

def parse_user_args(self):
    """Return ``self.args.user_args`` with pdsh-safe quoting applied.

    With pdsh, a string passed as an argument gets split on whitespace. To
    avoid that, and to support strings that themselves contain '"', every
    argument holding a space is wrapped in double quotes with its inner
    '"' characters backslash-escaped. Arguments without a space are passed
    through unchanged.

    Returns:
        A new list with one entry per user argument, in the original order.
    """

    def _quote_if_spaced(raw):
        # Only arguments containing a literal space need protecting.
        if " " not in raw:
            return raw
        escaped = raw.replace('"', '\\"')
        return '"{}"'.format(escaped)

    return [_quote_if_spaced(raw) for raw in self.args.user_args]

@property
def name(self):
    """Short identifier of this runner backend: the literal string "pdsh"."""
    backend_name = "pdsh"
    return backend_name

def parse_user_args(self):
    """Return ``self.args.user_args`` with non-flag values single-quoted.

    An argument that starts with "-" is kept as-is; every other argument is
    wrapped in single quotes.

    Returns:
        A new list with one entry per user argument, in the original order.
    """
    return [arg if arg.startswith("-") else f"'{arg}'" for arg in self.args.user_args]

def get_cmd(self, environment, active_resources):
environment['PDSH_RCMD_TYPE'] = 'ssh'
if self.args.ssh_port is not None: # only specify ssh port if it is specified
Expand Down
10 changes: 5 additions & 5 deletions deepspeed/launcher/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import os
import re
import sys
import shlex
import json
import base64
import argparse
Expand Down Expand Up @@ -389,9 +388,6 @@ def parse_num_nodes(str_num_nodes: str, elastic_training: bool):
def main(args=None):
args = parse_args(args)

# For when argparse interprets remaining args as a single string
args.user_args = shlex.split(" ".join(list(map(lambda x: x if x.startswith("-") else f'"{x}"', args.user_args))))

if args.elastic_training:
assert args.master_addr != "", "Master Addr is required when elastic training is enabled"

Expand Down Expand Up @@ -447,7 +443,11 @@ def main(args=None):
if not args.master_addr:
assert multi_node_exec
first_host = list(active_resources.keys())[0]
hostname_cmd = [f"ssh {first_host} hostname -I"]
ssh_check_cmd = "ssh "
if args.ssh_port is not None:
ssh_check_cmd += f" -p {args.ssh_port}"
ssh_check_cmd += f" {first_host} hostname -I"
hostname_cmd = [ssh_check_cmd]
try:
result = subprocess.check_output(hostname_cmd, shell=True)
except subprocess.CalledProcessError as err:
Expand Down
4 changes: 3 additions & 1 deletion deepspeed/module_inject/auto_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,9 @@ class Loading():

def is_load_module(module):
    """Tell whether ``module`` is one of the layer kinds selected for loading.

    A module qualifies when either:
      * its exact class is ``nn.Linear``, ``nn.Embedding`` or ``nn.LayerNorm``
        (this is a membership test on ``module.__class__``, so subclasses of
        those types do not match through this branch), or
      * ``module._get_name()`` is one of the names listed below.

    Args:
        module: object exposing ``__class__`` and ``_get_name()``.

    Returns:
        ``True`` when the module matches one of the rules above, else ``False``.
    """
    load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm]
    # Single definition of the name list; the shorter list that used to be
    # assigned immediately before it was overwritten and never read.
    load_layer_names = [
        "LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "FalconLinear"
    ]
    return module.__class__ in load_layers or module._get_name() in load_layer_names

def load_buffer(module, state_dict, prefix):
Expand Down
6 changes: 3 additions & 3 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,6 @@ def __init__(
# for debug purposes - can then debug print: debug_get_module_name(module)
debug_extract_module_and_param_names(model)

# needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict
self.param_names = {param: name for name, param in model.named_parameters()}

self._do_args_sanity_check(args)
self._configure_with_arguments(args, mpu)
self._do_sanity_check()
Expand All @@ -261,6 +258,9 @@ def __init__(
# Configure distributed model
self._configure_distributed_model(model)

# needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict
self.param_names = {param: name for name, param in model.named_parameters()}

self._get_model_parameters()

see_memory_usage(f"DeepSpeed Engine: After configure distributed model")
Expand Down
64 changes: 64 additions & 0 deletions tests/unit/launcher/test_user_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest
import subprocess

from deepspeed.accelerator import get_accelerator

# Skip the whole module when the accelerator reports that it is unavailable.
if not get_accelerator().is_available():
    pytest.skip("only supported in accelerator environments.", allow_module_level=True)

# Source of the throwaway user script written to disk by the fixtures below.
# It parses --prompt, --local_rank and --world_size, then prints a marker
# line that the tests look for in the launcher's stdout.
user_arg_test_script = """import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--prompt", type=str)
parser.add_argument("--local_rank", type=int, default=0)
parser.add_argument("--world_size", type=int, default=1)
args = parser.parse_args()
print("ARG PARSE SUCCESS")
"""


@pytest.fixture(scope="function")
def user_script_fp(tmpdir):
    """Write ``user_arg_test_script`` to ``user_arg_test.py`` under ``tmpdir``.

    Returns:
        The path object produced by ``tmpdir.join`` for the written script.
    """
    target = tmpdir.join("user_arg_test.py")
    with open(target, "w") as handle:
        handle.write(user_arg_test_script)
    return target


@pytest.fixture(scope="function")
def cmd(user_script_fp, prompt, multi_node):
    """Assemble the ``deepspeed`` command tuple for one parametrized case.

    ``--force_multi`` is placed directly after the executable name when
    ``multi_node`` is truthy; the rest of the command is the same in both
    cases and ends with ``--prompt`` followed by ``prompt``.
    """
    multi_node_flags = ("--force_multi", ) if multi_node else ()
    return ("deepspeed", *multi_node_flags, "--num_nodes", "1", "--num_gpus", "1", user_script_fp, "--prompt", prompt)


@pytest.mark.parametrize("prompt", [
    '''"I am 6' tall"''', """'I am 72" tall'""", """'"translate English to Romanian: "'""",
    '''I'm going to tell them "DeepSpeed is the best"'''
])
@pytest.mark.parametrize("multi_node", [True, False])
def test_user_args(cmd):
    """Run the launcher command and require the script's success marker.

    The prompts mix single and double quotes; the test passes when the
    marker printed by the user script appears in the captured stdout. On
    failure, the captured stderr is included in the assertion message.
    """
    completed = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout_text = completed.stdout.decode("utf-8")
    assert "ARG PARSE SUCCESS" in stdout_text, f"User args not parsed correctly: {completed.stderr.decode('utf-8')}"


def test_bash_string_args(tmpdir, user_script_fp):
    """Launch the user script through a bash wrapper that pipes args via xargs.

    A small bash script is written under ``tmpdir``; it stores a single-quoted
    prompt in ``ARGS`` and feeds it to ``deepspeed`` using ``echo | xargs``.
    The test passes when the user script's success marker shows up in the
    captured stdout. On failure, the captured stderr is included in the
    assertion message.
    """
    wrapper_path = tmpdir.join("bash_script.sh")
    wrapper_source = f"""
ARGS="--prompt 'DeepSpeed is the best'"
echo ${{ARGS}}|xargs deepspeed --num_nodes 1 --num_gpus 1 {user_script_fp}
"""
    with open(wrapper_path, "w") as handle:
        handle.write(wrapper_source)

    completed = subprocess.run(["bash", wrapper_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout_text = completed.stdout.decode("utf-8")
    assert "ARG PARSE SUCCESS" in stdout_text, f"User args not parsed correctly: {completed.stderr.decode('utf-8')}"

0 comments on commit acf40e7

Please sign in to comment.