diff --git a/.github/workflows/hpu-gaudi2.yml b/.github/workflows/hpu-gaudi2.yml index f81e690e835b..ac19638e67de 100644 --- a/.github/workflows/hpu-gaudi2.yml +++ b/.github/workflows/hpu-gaudi2.yml @@ -39,13 +39,14 @@ jobs: # The type of runner that the job will run on runs-on: [self-hosted, intel, gaudi2] container: - image: vault.habana.ai/gaudi-docker/1.15.1/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest + image: vault.habana.ai/gaudi-docker/1.16.2/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest ports: - 80 options: --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice env: PT_HPU_LAZY_MODE: 0 + TORCHINDUCTOR_COMPILE_THREADS: 1 TEST_LIST: | test_accelerator.py test_autotuning.py @@ -103,7 +104,7 @@ jobs: - name: Check container state run: | ldd --version - hl-smi + hl-smi -L python -c "import torch; print('torch:', torch.__version__, torch)" python -c "import torch; print('CUDA available:', torch.cuda.is_available())" @@ -128,7 +129,7 @@ jobs: unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch cd tests export PT_HPU_LAZY_MODE=${PT_HPU_LAZY_MODE} + export TORCHINDUCTOR_COMPILE_THREADS=${TORCHINDUCTOR_COMPILE_THREADS} TEST_LIST=$(echo "$TEST_LIST" | awk 'NF{printf "%s%s", (NR>1 ? " or " : ""), $0} END{if (NR>1) print ""}') echo "TEST_LIST ${TEST_LIST}" - echo "PT_HPU_LAZY_MODE ${PT_HPU_LAZY_MODE}" pytest --verbose unit/ -k "${TEST_LIST}" diff --git a/deepspeed/linear/__init__.py b/deepspeed/linear/__init__.py index a27f1c3eaee7..9931a95a0a40 100644 --- a/deepspeed/linear/__init__.py +++ b/deepspeed/linear/__init__.py @@ -5,3 +5,4 @@ from .optimized_linear import OptimizedLinear from .config import LoRAConfig, QuantizationConfig +from .context_manager import Init, init_lora diff --git a/deepspeed/linear/config.py b/deepspeed/linear/config.py index ae9050a3c92b..2632ce7de9c4 100644 --- a/deepspeed/linear/config.py +++ b/deepspeed/linear/config.py @@ -3,7 +3,8 @@ # DeepSpeed Team -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import List @dataclass @@ -17,10 +18,19 @@ class LoRAConfig: base_weight_sharding (int): The degree to which the base weights are sharded, should typically be set to the data-parallel world size to maximize the memory reduction benefits. Defaults to 1, which means this feature is disabled. + offload (bool): offload frozen parameters to cpu when not in use + offload_ratio (float): ratio of parameters to offload to cpu when not in use + delay_lora_init (bool): initialize lora parameters at time of model init or allow manual init later + target_mods (str): target module names to apply LoRA to, defaults to llama-3.1 arch """ lora_r: int = 64 lora_alpha: float = 16. base_weight_sharding: int = 1 + offload: bool = False + offload_ratio: float = 0.0 + delay_lora_init: bool = False + target_mods: List[str] = field( + default_factory=lambda: ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']) @dataclass diff --git a/deepspeed/linear/context_manager.py b/deepspeed/linear/context_manager.py new file mode 100644 index 000000000000..204fa0fe9c1d --- /dev/null +++ b/deepspeed/linear/context_manager.py @@ -0,0 +1,90 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .optimized_linear import LoRAOptimizedLinear, OptimizedLinear + +import torch + +try: + import transformers +except ImportError: + transformers = None + + +def init_lora(model): + model.requires_grad_(False) + for m in model.modules(): + if isinstance(m, LoRAOptimizedLinear): + m.init_lora() + + +class Init(object): + """ + Init context wrapper similar in style to zero.Init. Allows for injecting OptimizedLinear during model + construction which will shard base weights and reduce overall memory usage during model init. Primarily + useful when initializing a model via transformers.AutoModelForCausalLM. + + Example usage: + lora_config = deepspeed.linear.LoRAConfig(..) + quant_config = deepspeed.linear.QuantizationConfig(..) + with deepspeed.linear.Init(lora_config=lora_config, quant_config=quant_config): + model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-405B") + + """ + + def __init__(self, lora_config=None, quant_config=None): + self._orig_nn_linear = torch.nn.Linear + self._orig_causallm_pretrained = None + if transformers != None: + self._orig_causallm_pretrained = transformers.AutoModelForCausalLM.from_pretrained + self._orig_causallm_config = transformers.AutoModelForCausalLM.from_config + self.lora_config = lora_config + self.quant_config = quant_config + self._post_init_complete = False + + def __enter__(self): + + class OptLinearWrapper: + _orig_nn_linear = self._orig_nn_linear + _lora_config = self.lora_config + _quant_config = self.quant_config + + def __new__(self, *args, **kwargs): + self._lora_config.delay_lora_init = True + kwargs['lora_config'] = self._lora_config + kwargs['quantization_config'] = self._quant_config + kwargs['linear_cls'] = self._orig_nn_linear + return OptimizedLinear(*args, **kwargs) + + def _model_init(model): + if self.lora_config != None: + init_lora(model) + self._post_init_complete = True + return model + + # ensures non-lora params are frozen and lora weights are initialized + def from_pretrained(*args, **kwargs): + model = self._orig_causallm_pretrained(*args, **kwargs) + return _model_init(model) + + def from_config(*args, **kwargs): + model = self._orig_causallm_config(*args, **kwargs) + return _model_init(model) + + torch.nn.Linear = OptLinearWrapper + if transformers != None: + transformers.AutoModelForCausalLM.from_pretrained = from_pretrained + transformers.AutoModelForCausalLM.from_config = from_config + + def __exit__(self, *args, **kwargs): + torch.nn.Linear = self._orig_nn_linear + if not self._post_init_complete: + print('WARNING: For some reason LoRA modules are not initialized, this is usually done automatically ' + 'if using transformers via (AutoModelForCausalLM from_pretrained/from_config). 
' + 'You must call `init_lora` on each module in order to use DeepSpeed LoRA, otherwise ' + 'you will error out during runtime.') + else: + transformers.AutoModelForCausalLM.from_pretrained = self._orig_causallm_pretrained + transformers.AutoModelForCausalLM.from_config = self._orig_causallm_config diff --git a/deepspeed/linear/optimized_linear.py b/deepspeed/linear/optimized_linear.py index e982785a8122..3720196aa255 100644 --- a/deepspeed/linear/optimized_linear.py +++ b/deepspeed/linear/optimized_linear.py @@ -40,7 +40,9 @@ def __new__(self, bias: bool = False, lora_config: LoRAConfig = None, quantization_config: QuantizationConfig = None, - dtype=torch.bfloat16): + device=None, + dtype=torch.bfloat16, + linear_cls=nn.Linear): if quantization_config is not None and not is_dataclass(quantization_config): raise ValueError(f"Expecting QuantizationConfig but received {type(quantization_config)}") @@ -48,7 +50,7 @@ def __new__(self, raise ValueError(f"Expecting LoRAConfig but received {type(lora_config)}") if lora_config is None and quantization_config is None: # Everything disabled, fall back to normal nn.Linear - self = nn.Linear(input_dim, output_dim, bias=bias, dtype=dtype) + self = linear_cls(input_dim, output_dim, bias=bias, dtype=dtype, device=device) elif lora_config: # lora enabled, quantization may or may not be @@ -57,7 +59,9 @@ def __new__(self, bias=bias, lora_config=lora_config, quantization_config=quantization_config, - dtype=dtype) + dtype=dtype, + device=device, + linear_cls=linear_cls) elif quantization_config: # only quantization enabled, no lora @@ -78,57 +82,121 @@ def __init__(self, lora_config: LoRAConfig = None, quantization_config: QuantizationConfig = None, device=None, - dtype=torch.bfloat16): + dtype=torch.bfloat16, + linear_cls=nn.Linear): super().__init__() self.input_dim = input_dim self.output_dim = output_dim self.bias = bias self.lora_config = lora_config self.quantization_config = quantization_config - device = get_accelerator().current_device_name() if device is None else device + self.device = get_accelerator().current_device_name() if device is None else device + self.linear_cls = linear_cls + self.dtype = dtype assert self.lora_config is not None, "DSOptimizedLinear requires a LoRA config" - + assert not self.bias, "bias=True is not supported by LoRAOptimizedLinear" self.zero_shards = self.lora_config.base_weight_sharding self.sharded_weight_size = int(float(self.input_dim) // self.zero_shards) - w = torch.nn.Parameter(torch.empty((self.output_dim, self.sharded_weight_size), dtype=dtype)) - torch.nn.init.xavier_uniform_(w) + if self.zero_shards > 1: + assert self.zero_shards == dist.get_world_size( + ), "base weight sharding is only supported across world size" + w = torch.nn.Parameter(torch.empty(self.output_dim * self.sharded_weight_size, dtype=dtype), + requires_grad=False) + else: + w = torch.nn.Parameter(torch.empty((self.output_dim, self.input_dim), dtype=dtype), requires_grad=False) + torch.nn.init.xavier_uniform_(w.reshape(self.sharded_weight_size, self.output_dim)) if self.quantization_config is not None: assert dtype == torch.bfloat16, "only bfloat16 is supported when using quantization" - self.base_weight = QuantizedParameter(w, quantization_config=quantization_config) + self.weight = QuantizedParameter(w, quantization_config=quantization_config) else: - self.base_weight = w + self.weight = w + + self.disabled = False + self._initialized = False + if not self.lora_config.delay_lora_init: + self.init_lora() + + def disable(self): + 
self.disabled = True + self.weight = torch.nn.Parameter(torch.empty((self.output_dim, self.input_dim), dtype=self.dtype), + requires_grad=False) + + def init_lora(self): + if self.disabled: + return + + if self.quantization_config is not None: + # ensure quant-param wasn't stripped, in some cases transformers will do this during model init + if not isinstance(self.weight, QuantizedParameter): + self.weight = QuantizedParameter(self.weight, quantization_config=self.quantization_config) + + self._initialized = True + self.weight.requires_grad = False - self.base_weight.requires_grad = False + # Mark base weight to prevent broadcast and ensure proper offload behavior + self.weight.ds_optim_param = True + + self.lora_scaling_factor = self.lora_config.lora_alpha / self.lora_config.lora_r - # Use RS lora for now. - self.lora_scaling_factor = self.lora_config.lora_alpha / math.sqrt(self.lora_config.lora_r) # Keeping lora weights in bf16 precision for ease of training. - self.lora_weight_1 = nn.Linear(self.input_dim, - self.lora_config.lora_r, - bias=self.bias, - device=device, - dtype=dtype) - self.lora_weight_2 = nn.Linear(self.lora_config.lora_r, - self.output_dim, - bias=self.bias, - device=device, - dtype=dtype) + self.lora_weight_1 = self.linear_cls(self.input_dim, + self.lora_config.lora_r, + bias=self.bias, + device=self.device, + dtype=self.dtype) + self.lora_weight_2 = self.linear_cls(self.lora_config.lora_r, + self.output_dim, + bias=self.bias, + device=self.device, + dtype=self.dtype) + + # initialize "A" with kaiming uniform and "B" with zeros following this + # https://github.com/huggingface/peft/blob/62122b5add8d6892f70c82eaef2147a6ba33b90b/src/peft/tuners/lora/layer.py#L155 + nn.init.kaiming_uniform_(self.lora_weight_1.weight, a=math.sqrt(5)) + nn.init.zeros_(self.lora_weight_2.weight) self.lora_weight_1.weight.requires_grad = True self.lora_weight_2.weight.requires_grad = True + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs): + if not any([target in prefix for target in self.lora_config.target_mods]): + # module does not match any target_mods, we must revert to normal nn.Linear via disable + self.disable() + return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, + unexpected_keys, error_msgs) + + if self.zero_shards > 1: + if not dist.is_initialized(): + raise RuntimeError( + "attempting to use optimized linear base weight sharding but torch-distributed is not initialized, please init first." + ) + rank = dist.get_rank() + shape_local = self.output_dim * self.sharded_weight_size + base_weight_name = f"{prefix}weight" + incoming_param = state_dict[base_weight_name] + state_dict[base_weight_name] = incoming_param.flatten().narrow(0, rank * shape_local, shape_local) + + return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs) + def full_weight(self): - # This assumes weights are evenly sharded across gpus. which might not be correct. - # in that case, we should flatten before all_gather. 
- local_weight = self.base_weight.dequantized() if isinstance(self.base_weight, - QuantizedParameter) else self.base_weight - tensor_list = [ - torch.zeros_like(local_weight, device=local_weight.device, dtype=local_weight.dtype) - for _ in range(self.zero_shards) - ] - dist.all_gather(tensor_list, local_weight) - weight = nn.Parameter(torch.cat([tensor for tensor in tensor_list], dim=1)) - return weight + base_weight = self.weight + if getattr(base_weight, 'ds_offload', False): + # move to gpu so we can dequant and all-gather + assert base_weight.device == torch.device('cpu'), \ + f"expected base weight on cpu but found {base_weight.device}" + base_weight.offload(revert=True) + local_weight = base_weight.dequantized() if isinstance(base_weight, QuantizedParameter) else base_weight + base_weight.offload() + else: + local_weight = base_weight.dequantized() if isinstance(base_weight, QuantizedParameter) else base_weight + + tensor_out = torch.empty(self.output_dim * self.input_dim, + dtype=local_weight.dtype, + device=local_weight.device) + dist.all_gather_into_tensor(tensor_out, local_weight) + return tensor_out.reshape(self.output_dim, self.input_dim) def linear_without_F_linear(self, input, weight): output = torch.mm(input.reshape(-1, input.shape[-1]), weight) @@ -136,14 +204,18 @@ def linear_without_F_linear(self, input, weight): return output def forward(self, input_tensor): + if self.disabled: + return F.linear(input_tensor, self.weight) + assert self._initialized, "init_lora was never called, please initialize before proceeding" + # Gather the sharded base weight if self.zero_shards > 1: with torch.no_grad(): base_weight = self.full_weight() elif self.quantization_config: - base_weight = self.base_weight.dequantized() + base_weight = self.weight.dequantized() else: - base_weight = self.base_weight + base_weight = self.weight base_weight_output = F.linear(input_tensor, base_weight) lora_output = self.lora_weight_2(self.lora_weight_1(input_tensor)) diff --git a/deepspeed/linear/quantization.py b/deepspeed/linear/quantization.py index 8e4f23dfba89..70fabea845ba 100644 --- a/deepspeed/linear/quantization.py +++ b/deepspeed/linear/quantization.py @@ -75,6 +75,13 @@ def dequantized(self) -> torch.Tensor: q_mantisa_bits=self.quantization_config.mantissa_bits) return self.data + def offload(self, revert=False): + if getattr(self, 'ds_offload', False): + if revert: + self.data = self.to(get_accelerator().current_device_name()) + else: + self.data = self.to('cpu') + def __getstate__(self): state = self.__dict__ state["data"] = self.data @@ -104,7 +111,9 @@ def __copy__(self): return new_instance def cuda(self, device=None, non_blocking=False): - return self.to(device="cuda" if device is None else device, non_blocking=non_blocking) + device = "cuda" if device is None else device + self.quantizer.to(device, non_blocking=non_blocking) + return self.to(device, non_blocking=non_blocking) def to(self, *args, **kwargs): """ @@ -112,6 +121,7 @@ def to(self, *args, **kwargs): quantize it. 
""" tensor = super().to(*args, **kwargs) + self.quantizer.to(*args, **kwargs) self._ensure_quantized(tensor) return tensor diff --git a/deepspeed/ops/fp_quantizer/quantize.py b/deepspeed/ops/fp_quantizer/quantize.py index 170954e0cf71..edd4ef57302c 100644 --- a/deepspeed/ops/fp_quantizer/quantize.py +++ b/deepspeed/ops/fp_quantizer/quantize.py @@ -91,6 +91,13 @@ def quantize(self, return out + def to(self, *args, **kwargs): + # Intermediate tensors may need to be moved to different devices + if hasattr(self, 'input_q'): + self.input_q = self.input_q.to(*args, **kwargs) + if hasattr(self, 'scale'): + self.scale = self.scale.to(*args, **kwargs) + def get_scales(self): return fp_quant_module.get_scales(self.scale, self.num_groups) diff --git a/deepspeed/runtime/eigenvalue.py b/deepspeed/runtime/eigenvalue.py index df63854dd1ca..36300eb904dd 100755 --- a/deepspeed/runtime/eigenvalue.py +++ b/deepspeed/runtime/eigenvalue.py @@ -7,6 +7,7 @@ from deepspeed.utils import log_dist import numpy as np import logging +from deepspeed.utils.torch import required_torch_version class Eigenvalue(object): @@ -36,12 +37,15 @@ def __init__(self, ranks=[0]) # Replace all nan/pos-inf/neg-inf to zero - # TODO: Pytorch new version may add this function, replace this one by then. def nan_to_num(self, x): - device = x.device - x = x.cpu().numpy() - x = np.nan_to_num(x=x, copy=False, nan=0.0, posinf=0.0, neginf=0.0) - return torch.from_numpy(x).to(device) + if required_torch_version(min_version=1.8): + return torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0) + else: + # Fallback to numpy based implementation for backwards-compatibility with PyTorch 1.7 or older versions. + device = x.device + x = x.cpu().numpy() + x = np.nan_to_num(x=x, copy=False, nan=0.0, posinf=0.0, neginf=0.0) + return torch.from_numpy(x).to(device) def normalize(self, v): norm_squared = self.inner_product(v, v) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index d40141132aaf..1c74c0c735a0 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -35,6 +35,8 @@ from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer from deepspeed.runtime.bf16_optimizer import BF16_Optimizer +from deepspeed.linear.optimized_linear import LoRAOptimizedLinear + from deepspeed.runtime.config import DEEPSPEED_OPTIMIZERS, \ ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \ TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, \ @@ -326,6 +328,8 @@ def __init__(self, self.sparse_tensor_module_names.add(name + ".weight") logger.info("Will convert {} to sparse tensor during training".format(name)) + self._optimized_linear_offload_setup() + self.save_non_zero_checkpoint = False self.save_zero_checkpoint = False if not isinstance(self.optimizer, DeepSpeedZeRoOffload): @@ -363,6 +367,43 @@ def __init__(self, self._is_compiled = False + def _optimized_linear_offload_setup(self): + self.optimized_linear_base_weight_sharding = False + self.optimized_linear_lora_enabled = False + offload_ratio = None + for _, module in self.module.named_modules(): + if isinstance(module, LoRAOptimizedLinear): + self.optimized_linear_lora_enabled = True + offload_ratio = None + if offload_ratio is not None: + assert offload_ratio == module.lora_config.offload_ratio, \ + "all lora_config offload ratios should be the same across the model" + offload_ratio = module.lora_config.offload_ratio + if 
module.zero_shards > 1: + # set attr so checkpoint saving can handle BWS properly + self.optimized_linear_base_weight_sharding = True + + if offload_ratio is None: + # Nothing enabled, do nothing + return + + total_params = 0 + for _, p in self.module.named_parameters(): + if hasattr(p, 'ds_optim_param'): + total_params += p.numel() + + offload_limit = total_params * offload_ratio + logger.info(f'offloading {offload_ratio*100}% of eligible params, specifically {offload_limit} params') + total_offloaded = 0 + for _, p in self.module.named_parameters(): + if hasattr(p, 'ds_optim_param'): + if total_offloaded < offload_limit: + total_offloaded += p.numel() + p.ds_offload = True + p.offload() + else: + p.ds_offload = False + def destroy(self): if self.optimizer is not None and hasattr(self.optimizer, 'destroy'): self.optimizer.destroy() @@ -1054,9 +1095,12 @@ def _broadcast_model(self): def is_replicated(p): if hasattr(p, "ds_status") and p.ds_status is not ZeroParamStatus.AVAILABLE: return False + elif hasattr(p, 'ds_optim_param'): + # do not broadcast OptimizedLinear parameters, they are unique per base weight shard + return False return True - for p in self.module.parameters(): + for n, p in self.module.named_parameters(): # Broadcast the model for different parameters if is_moe_param(p): if torch.is_tensor(p) and is_replicated(p): diff --git a/op_builder/evoformer_attn.py b/op_builder/evoformer_attn.py index 6e7721f94e01..af3aa7429775 100644 --- a/op_builder/evoformer_attn.py +++ b/op_builder/evoformer_attn.py @@ -41,18 +41,21 @@ def nvcc_args(self): args.append(f"-DGPU_ARCH={major}{minor}") return args - def is_compatible(self, verbose=True): + def is_compatible(self, verbose=False): try: import torch except ImportError: - self.warning("Please install torch if trying to pre-compile kernels") + if verbose: + self.warning("Please install torch if trying to pre-compile kernels") return False if self.cutlass_path is None: - self.warning("Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH") + if verbose: + self.warning("Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH") return False with open(f'{self.cutlass_path}/CHANGELOG.md', 'r') as f: if '3.1.0' not in f.read(): - self.warning("Please use CUTLASS version >= 3.1.0") + if verbose: + self.warning("Please use CUTLASS version >= 3.1.0") return False cuda_okay = True if not self.is_rocm_pytorch() and torch.cuda.is_available(): #ignore-cuda @@ -60,10 +63,12 @@ def is_compatible(self, verbose=True): torch_cuda_major = int(torch.version.cuda.split('.')[0]) cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda if cuda_capability < 7: - self.warning("Please use a GPU with compute capability >= 7.0") + if verbose: + self.warning("Please use a GPU with compute capability >= 7.0") cuda_okay = False if torch_cuda_major < 11 or sys_cuda_major < 11: - self.warning("Please use CUDA 11+") + if verbose: + self.warning("Please use CUDA 11+") cuda_okay = False return super().is_compatible(verbose) and cuda_okay diff --git a/op_builder/fp_quantizer.py b/op_builder/fp_quantizer.py index c7d2e72b5408..40cf504c2c83 100644 --- a/op_builder/fp_quantizer.py +++ b/op_builder/fp_quantizer.py @@ -22,11 +22,12 @@ def __init__(self, name=None): def absolute_name(self): return f'deepspeed.ops.fp_quantizer.{self.NAME}_op' - def is_compatible(self, verbose=True): + def is_compatible(self, verbose=False): try: import torch except ImportError: - self.warning("Please install torch if trying to 
pre-compile inference kernels") + if verbose: + self.warning("Please install torch if trying to pre-compile inference kernels") return False cuda_okay = True @@ -35,17 +36,20 @@ def is_compatible(self, verbose=True): torch_cuda_major = int(torch.version.cuda.split('.')[0]) cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda if cuda_capability < 8: - self.warning("NVIDIA Inference is only supported on Ampere and newer architectures") + if verbose: + self.warning("NVIDIA Inference is only supported on Ampere and newer architectures") cuda_okay = False if cuda_capability >= 8: if torch_cuda_major < 11 or sys_cuda_major < 11: - self.warning("On Ampere and higher architectures please use CUDA 11+") + if verbose: + self.warning("On Ampere and higher architectures please use CUDA 11+") cuda_okay = False try: import triton except ImportError: - self.warning(f"please install triton==2.3.0 or 2.3.1 if you want to use the FP Quantizer Kernels") + if verbose: + self.warning(f"please install triton==2.3.0 or 2.3.1 if you want to use the FP Quantizer Kernels") return False # triton 2.3.0 and 2.3.1 are okay and the only versions released in 2.3.x before 3.x was released @@ -59,9 +63,10 @@ def is_compatible(self, verbose=True): triton_mismatch = major != "2" or minor != "3" if triton_mismatch: - self.warning( - f"FP Quantizer is using an untested triton version ({installed_triton}), only 2.3.0 and 2.3.1 are known to be compatible with these kernels" - ) + if verbose: + self.warning( + f"FP Quantizer is using an untested triton version ({installed_triton}), only 2.3.0 and 2.3.1 are known to be compatible with these kernels" + ) return False return super().is_compatible(verbose) and cuda_okay diff --git a/op_builder/inference_core_ops.py b/op_builder/inference_core_ops.py index d1957f39d9a8..45e8628e669f 100755 --- a/op_builder/inference_core_ops.py +++ b/op_builder/inference_core_ops.py @@ -23,7 +23,8 @@ def is_compatible(self, verbose=True): try: import torch except ImportError: - self.warning("Please install torch if trying to pre-compile inference kernels") + if verbose: + self.warning("Please install torch if trying to pre-compile inference kernels") return False cuda_okay = True @@ -32,11 +33,13 @@ def is_compatible(self, verbose=True): torch_cuda_major = int(torch.version.cuda.split('.')[0]) cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda if cuda_capability < 6: - self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") + if verbose: + self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") cuda_okay = False if cuda_capability >= 8: if torch_cuda_major < 11 or sys_cuda_major < 11: - self.warning("On Ampere and higher architectures please use CUDA 11+") + if verbose: + self.warning("On Ampere and higher architectures please use CUDA 11+") cuda_okay = False return super().is_compatible(verbose) and cuda_okay diff --git a/op_builder/inference_cutlass_builder.py b/op_builder/inference_cutlass_builder.py index 51f7931d9435..fda6e74bbf6a 100644 --- a/op_builder/inference_cutlass_builder.py +++ b/op_builder/inference_cutlass_builder.py @@ -22,7 +22,8 @@ def is_compatible(self, verbose=True): try: import torch except ImportError: - self.warning("Please install torch if trying to pre-compile inference kernels") + if verbose: + self.warning("Please install torch if trying to pre-compile inference kernels") return False cuda_okay = True @@ -31,11 +32,13 @@ def is_compatible(self, verbose=True): 
torch_cuda_major = int(torch.version.cuda.split('.')[0]) cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda if cuda_capability < 6: - self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") + if verbose: + self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") cuda_okay = False if cuda_capability >= 8: if torch_cuda_major < 11 or sys_cuda_major < 11: - self.warning("On Ampere and higher architectures please use CUDA 11+") + if verbose: + self.warning("On Ampere and higher architectures please use CUDA 11+") cuda_okay = False return super().is_compatible(verbose) and cuda_okay diff --git a/op_builder/ragged_ops.py b/op_builder/ragged_ops.py index ec7cab91885f..a4e365786a2b 100644 --- a/op_builder/ragged_ops.py +++ b/op_builder/ragged_ops.py @@ -23,7 +23,8 @@ def is_compatible(self, verbose=True): try: import torch except ImportError: - self.warning("Please install torch if trying to pre-compile inference kernels") + if verbose: + self.warning("Please install torch if trying to pre-compile inference kernels") return False cuda_okay = True @@ -32,11 +33,13 @@ def is_compatible(self, verbose=True): torch_cuda_major = int(torch.version.cuda.split('.')[0]) cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda if cuda_capability < 6: - self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") + if verbose: + self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") cuda_okay = False if cuda_capability >= 8: if torch_cuda_major < 11 or sys_cuda_major < 11: - self.warning("On Ampere and higher architectures please use CUDA 11+") + if verbose: + self.warning("On Ampere and higher architectures please use CUDA 11+") cuda_okay = False return super().is_compatible(verbose) and cuda_okay diff --git a/op_builder/ragged_utils.py b/op_builder/ragged_utils.py index 89450e1fd30d..a855f072af8c 100755 --- a/op_builder/ragged_utils.py +++ b/op_builder/ragged_utils.py @@ -23,7 +23,8 @@ def is_compatible(self, verbose=True): try: import torch except ImportError: - self.warning("Please install torch if trying to pre-compile inference kernels") + if verbose: + self.warning("Please install torch if trying to pre-compile inference kernels") return False cuda_okay = True @@ -32,11 +33,13 @@ def is_compatible(self, verbose=True): torch_cuda_major = int(torch.version.cuda.split('.')[0]) cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda if cuda_capability < 6: - self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") + if verbose: + self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") cuda_okay = False if cuda_capability >= 8: if torch_cuda_major < 11 or sys_cuda_major < 11: - self.warning("On Ampere and higher architectures please use CUDA 11+") + if verbose: + self.warning("On Ampere and higher architectures please use CUDA 11+") cuda_okay = False return super().is_compatible(verbose) and cuda_okay diff --git a/op_builder/sparse_attn.py b/op_builder/sparse_attn.py index 188d257ff4ef..2385adc8fe9c 100644 --- a/op_builder/sparse_attn.py +++ b/op_builder/sparse_attn.py @@ -27,45 +27,51 @@ def sources(self): def cxx_args(self): return ['-O2', '-fopenmp'] - def is_compatible(self, verbose=True): + def is_compatible(self, verbose=False): # Check to see if llvm and cmake are installed since they are dependencies #required_commands = ['llvm-config|llvm-config-9', 'cmake'] #command_status = 
list(map(self.command_exists, required_commands)) #deps_compatible = all(command_status) if self.is_rocm_pytorch(): - self.warning(f'{self.NAME} is not compatible with ROCM') + if verbose: + self.warning(f'{self.NAME} is not compatible with ROCM') return False try: import torch except ImportError: - self.warning(f"unable to import torch, please install it first") + if verbose: + self.warning(f"unable to import torch, please install it first") return False # torch-cpu will not have a cuda version if torch.version.cuda is None: cuda_compatible = False - self.warning(f"{self.NAME} cuda is not available from torch") + if verbose: + self.warning(f"{self.NAME} cuda is not available from torch") else: major, minor = torch.version.cuda.split('.')[:2] cuda_compatible = (int(major) == 10 and int(minor) >= 1) or (int(major) >= 11) if not cuda_compatible: - self.warning(f"{self.NAME} requires CUDA version 10.1+") + if verbose: + self.warning(f"{self.NAME} requires CUDA version 10.1+") TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MINOR = int(torch.__version__.split('.')[1]) torch_compatible = (TORCH_MAJOR == 1 and TORCH_MINOR >= 5) if not torch_compatible: - self.warning( - f'{self.NAME} requires a torch version >= 1.5 and < 2.0 but detected {TORCH_MAJOR}.{TORCH_MINOR}') + if verbose: + self.warning( + f'{self.NAME} requires a torch version >= 1.5 and < 2.0 but detected {TORCH_MAJOR}.{TORCH_MINOR}') try: import triton except ImportError: # auto-install of triton is broken on some systems, reverting to manual install for now # see this issue: https://github.com/microsoft/DeepSpeed/issues/1710 - self.warning(f"please install triton==1.0.0 if you want to use sparse attention") + if verbose: + self.warning(f"please install triton==1.0.0 if you want to use sparse attention") return False if pkg_version: @@ -76,7 +82,9 @@ def is_compatible(self, verbose=True): triton_mismatch = installed_triton != "1.0.0" if triton_mismatch: - self.warning(f"using untested triton version ({installed_triton}), only 1.0.0 is known to be compatible") + if verbose: + self.warning( + f"using untested triton version ({installed_triton}), only 1.0.0 is known to be compatible") return False return super().is_compatible(verbose) and torch_compatible and cuda_compatible diff --git a/op_builder/spatial_inference.py b/op_builder/spatial_inference.py index 59caf57f938d..8a6b36cce0b0 100644 --- a/op_builder/spatial_inference.py +++ b/op_builder/spatial_inference.py @@ -21,7 +21,8 @@ def is_compatible(self, verbose=True): try: import torch except ImportError: - self.warning("Please install torch if trying to pre-compile inference kernels") + if verbose: + self.warning("Please install torch if trying to pre-compile inference kernels") return False cuda_okay = True @@ -31,7 +32,8 @@ def is_compatible(self, verbose=True): cuda_capability = torch.cuda.get_device_properties(0).major if cuda_capability >= 8: if torch_cuda_major < 11 or sys_cuda_major < 11: - self.warning("On Ampere and higher architectures please use CUDA 11+") + if verbose: + self.warning("On Ampere and higher architectures please use CUDA 11+") cuda_okay = False return super().is_compatible(verbose) and cuda_okay diff --git a/op_builder/transformer_inference.py b/op_builder/transformer_inference.py index 5ee902289448..88b77499cc0e 100755 --- a/op_builder/transformer_inference.py +++ b/op_builder/transformer_inference.py @@ -21,7 +21,8 @@ def is_compatible(self, verbose=True): try: import torch except ImportError: - self.warning("Please install torch if trying to 
pre-compile inference kernels") + if verbose: + self.warning("Please install torch if trying to pre-compile inference kernels") return False cuda_okay = True @@ -30,11 +31,13 @@ def is_compatible(self, verbose=True): torch_cuda_major = int(torch.version.cuda.split('.')[0]) cuda_capability = torch.cuda.get_device_properties(0).major if cuda_capability < 6: - self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") + if verbose: + self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") cuda_okay = False if cuda_capability >= 8: if torch_cuda_major < 11 or sys_cuda_major < 11: - self.warning("On Ampere and higher architectures please use CUDA 11+") + if verbose: + self.warning("On Ampere and higher architectures please use CUDA 11+") cuda_okay = False return super().is_compatible(verbose) and cuda_okay diff --git a/tests/unit/linear/test_ctx.py b/tests/unit/linear/test_ctx.py new file mode 100644 index 000000000000..e03d13fd6ce2 --- /dev/null +++ b/tests/unit/linear/test_ctx.py @@ -0,0 +1,106 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import deepspeed +import pytest +from unit.common import DistributedTest + +import deepspeed.comm as dist +from deepspeed.linear import LoRAConfig, init_lora +from deepspeed.linear.optimized_linear import LoRAOptimizedLinear +from unit.simple_model import random_dataloader, SimpleModel + +try: + import transformers +except ImportError: + transformers = None + +if transformers is None: + pytest.skip("transformers is required for this test", allow_module_level=True) + + +def injection_assert(model): + # pick out random linear that should have been replaced and initialized + q_proj = model.model.layers[1].self_attn.q_proj + + assert isinstance(q_proj, LoRAOptimizedLinear), "injection did not happen" + assert q_proj._initialized, "lora was not initialized properly" + assert isinstance(q_proj.lora_weight_1, torch.nn.Linear) + assert isinstance(q_proj.lora_weight_2, torch.nn.Linear) + + +class TestEngine(DistributedTest): + world_size = 2 + + def test_model(self): + lora_config = LoRAConfig(lora_r=16, lora_alpha=16, base_weight_sharding=2) + quant_config = None + hidden_dim = 64 + nlayers = 4 + + with deepspeed.linear.Init(lora_config=lora_config, quant_config=quant_config): + model = SimpleModel(hidden_dim=hidden_dim, nlayers=nlayers) + + init_lora(model) + + model_norms = [model.linears[i].weight.norm().item() for i in range(nlayers)] + + ds_config = { + "train_batch_size": 2, + "steps_per_print": 1, + "bf16": { + "enabled": True + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "zero_optimization": { + "stage": 1 + } + } + model, *_ = deepspeed.initialize(config=ds_config, model=model, model_parameters=model.parameters()) + + engine_norms = [model.module.linears[i].weight.norm().item() for i in range(nlayers)] + + # Ensure that sharded weights are not broadcast during engine init + assert engine_norms == model_norms, f"{dist.get_rank()=} base weight norms are not the same after engine init, {engine_norms=} != {model_norms=}" + + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.bfloat16) + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + +class TestInitTransformers(DistributedTest): + world_size = 2 + + def test_pretrained_init(self): + lora_config = 
LoRAConfig(lora_r=16, lora_alpha=16, base_weight_sharding=2) + quant_config = None + + with deepspeed.linear.Init(lora_config=lora_config, quant_config=quant_config): + model = transformers.AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-Llama-3") + + injection_assert(model) + + def test_config_init(self): + lora_config = LoRAConfig(lora_r=16, lora_alpha=16, base_weight_sharding=2) + quant_config = None + + config = transformers.AutoConfig.from_pretrained("llamafactory/tiny-random-Llama-3") + + with deepspeed.linear.Init(lora_config=lora_config, quant_config=quant_config): + model = transformers.AutoModelForCausalLM.from_config(config) + + injection_assert(model)
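Taken together, this patch adds a deepspeed.linear.Init context manager, an init_lora helper, and new LoRAConfig fields (offload, offload_ratio, delay_lora_init, target_mods). The snippet below is a minimal usage sketch assembled from the Init docstring and tests/unit/linear/test_ctx.py above; the tiny-random Llama checkpoint, LoRA hyperparameters, and ds_config values are modeled on the test and are illustrative, not prescribed defaults.

# Sketch of the new API surface introduced in this PR; values are illustrative.
import torch
import deepspeed
import transformers
from deepspeed.linear import LoRAConfig, Init, init_lora

lora_config = LoRAConfig(
    lora_r=16,
    lora_alpha=16,
    base_weight_sharding=2,   # when > 1 this must equal the data-parallel world size (new assert)
    offload=False,            # new field: CPU-offload frozen base weights when not in use
    offload_ratio=0.0,        # new field: fraction of eligible frozen params to offload
    delay_lora_init=False,    # new field: defer LoRA param init until init_lora() is called
    target_mods=['q_proj', 'k_proj', 'v_proj', 'o_proj'])  # new field: module names to adapt

# Init swaps torch.nn.Linear for OptimizedLinear while the model is being built,
# so base weights are sharded (and optionally quantized) during construction.
with Init(lora_config=lora_config, quant_config=None):
    model = transformers.AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-Llama-3")
    # from_pretrained/from_config are wrapped to freeze base weights and call init_lora;
    # for a model built outside those helpers, call init_lora(model) manually.

ds_config = {
    "train_batch_size": 2,
    "bf16": {"enabled": True},
    "optimizer": {"type": "Adam", "params": {"lr": 0.00015}},
    "zero_optimization": {"stage": 1}
}
engine, *_ = deepspeed.initialize(config=ds_config, model=model, model_parameters=model.parameters())

Because base_weight_sharding=2 must match the world size, a sketch like this is expected to run under two ranks, e.g. via the deepspeed launcher with --num_gpus 2, mirroring world_size = 2 in the tests above.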