diff --git a/.gitignore b/.gitignore
index ab364ad8a7e7..5b9cc7ac3156 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,31 +1,40 @@
+# Ignore Python compiled files
 *.pyc
-.idea/
-*~
-*.swp
+
+# Ignore IDE-specific files and directories
+.idea/
+.vscode/
+.theia/
+
+# Ignore temporary and backup files
+*~
+*.swp
+
+# Ignore log files
 *.log
+
+# Ignore generated version info file
 deepspeed/git_version_info_installed.py
+
+# Ignore Python bytecode cache
 __pycache__
 
 # Build + installation data
 build/
 dist/
 *.so
 deepspeed.egg-info/
 build.txt
 
-# Website
+# Website generated files
 docs/_site/
 docs/build
 docs/code-docs/source/_build
 docs/code-docs/_build
 docs/code-docs/build
 .sass-cache/
 .jekyll-cache/
 .jekyll-metadata
 
 # Testing data
 tests/unit/saved_checkpoint/
-
-# Dev/IDE data
-.vscode
-.theia
diff --git a/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_multistage.h b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_multistage.h
index 50ba58b1d1dd..3760ccab852a 100644
--- a/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_multistage.h
+++ b/csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_multistage.h
@@ -243,10 +243,18 @@ class CustomMmaMultistage : public CustomMmaBase {
     }
 
     CUTLASS_DEVICE
-    bool set_prologue_done(bool value) { prologue_done_ = value; }
+    bool set_prologue_done(bool value)
+    {
+        prologue_done_ = value;
+        return true;
+    }
 
     CUTLASS_DEVICE
-    bool set_zero_outside_bounds(bool value) { zero_outside_bounds_ = value; }
+    bool set_zero_outside_bounds(bool value)
+    {
+        zero_outside_bounds_ = value;
+        return true;
+    }
 
     template <bool kLoadA = true, bool kLoadB = true>
     CUTLASS_DEVICE static void prologue(typename Base::SharedStorage& shared_storage,
diff --git a/deepspeed/checkpoint/constants.py b/deepspeed/checkpoint/constants.py
index 75bb8d4d6c8f..db3668de610b 100644
--- a/deepspeed/checkpoint/constants.py
+++ b/deepspeed/checkpoint/constants.py
@@ -57,7 +57,7 @@
 UNIVERSAL_CHECKPOINT_VERSION_VALUE = 0.2
 
 # Vocabulary padding
-VOCAB_DIVISIBILITY_PADDING_TENSOR = 'vocab_divisibility_padding_tensor'
+VOCAB_TENSOR = 'vocab_tensor'
 PADDED_VOCAB_SIZE = 'padded_vocab_size'
 ORIGINAL_VOCAB_SIZE = 'original_vocab_size'
 
diff --git a/deepspeed/checkpoint/ds_to_universal.py b/deepspeed/checkpoint/ds_to_universal.py
index 7fb96ce98e29..327b2a8cbae7 100755
--- a/deepspeed/checkpoint/ds_to_universal.py
+++ b/deepspeed/checkpoint/ds_to_universal.py
@@ -26,8 +26,7 @@
     PARAM_SHAPES,
     PARAM,
     CAT_DIM,
-    VOCAB_DIVISIBILITY_PADDING_TENSOR,
-    ORIGINAL_VOCAB_SIZE,
+    VOCAB_TENSOR,
     UNIVERSAL_CHECKPOINT_INFO,
     VOCABULARY_PARAMETER_PATTERNS,
     PIPELINE_REPLICATED_PARAMETER_PATTERNS,
@@ -55,6 +54,10 @@ def parse_arguments():
     parser.add_argument('--keep_temp_folder',
                         action='store_true',
                         help='Preserve temporary folder of intermediate checkpoint slice files. Useful for debugging.')
+    parser.add_argument('--no_strict',
+                        dest='strict',
+                        action='store_false',
+                        help='Do not perform validity checks on converted checkpoint.')
     args = parser.parse_args()
     print(f'args = {args}')
     return args
@@ -149,15 +152,8 @@ def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape):
     return slices
 
 
-def _get_vocab_divisibility_padding_tensor(universal_checkpoint_info, padded_vocab_tensor):
-    original_vocab_size = universal_checkpoint_info.get(ORIGINAL_VOCAB_SIZE)
-    if padded_vocab_tensor.shape[0] > original_vocab_size:
-        return padded_vocab_tensor[-1]
-    else:
-        return torch.zeros(padded_vocab_tensor.shape[1])
-
-
 def merge_tp_slices(ds_checkpoint, dir, slice_dir, tp_degree, name_and_shape):
+
     name, shape = name_and_shape
     slice_base_path = os.path.join(slice_dir, name)
     param_base_path = os.path.join(dir, name)
@@ -167,6 +163,18 @@ def merge_tp_slices(ds_checkpoint, dir, slice_dir, tp_degree, name_and_shape):
     parameters_to_average = universal_checkpoint_info.get(PARAMETER_TO_AVERAGE_PATTERNS, [])
     parameters_with_row_parallelism = universal_checkpoint_info.get(PARAMETER_WITH_ROW_PARALLELISM_PATTERNS, [])
     vocabulary_parameters = universal_checkpoint_info.get(VOCABULARY_PARAMETER_PATTERNS, [])
+    unmatched_patterns = set(replicated_parameters + parameters_to_average + parameters_with_row_parallelism +
+                             vocabulary_parameters)
+
+    def get_matched_pattern(patterns_, name_):
+        matched_ = [pattern_ for pattern_ in patterns_ if re.match(pattern_, name_)]
+        assert len(matched_) <= 1, f'Got more than one matching pattern={matched_} for {name_}'
+        if matched_:
+            pattern_ = matched_[0]
+            unmatched_patterns.discard(pattern_)
+            return pattern_
+        return None
+
     for state in ("fp32", "exp_avg", "exp_avg_sq"):
         slices = _merge_zero_shards(slice_base_path, state, tp_degree, shape)
         final_path = os.path.join(param_base_path, f"{state}.pt")
@@ -174,32 +182,34 @@ def merge_tp_slices(ds_checkpoint, dir, slice_dir, tp_degree, name_and_shape):
         #print(f"Expected shape: {shape}")
         #print(f"Fragment sizes:", list(frag.shape for frag in slices))
         ckpt_dict = {}
-        if any(re.match(pattern, name) for pattern in replicated_parameters):
+        if get_matched_pattern(replicated_parameters, name):
             if len(slices) > 1:
                 assert all([slices[0].equal(other_slice) for other_slice in slices[1:]])
             param = slices[0]
             # print(f'replicate {name} using first slice')
-        elif any(re.match(pattern, name) for pattern in parameters_to_average):
+        elif get_matched_pattern(parameters_to_average, name):
             param = sum(slices) / len(slices)
             # print(f'merge {name} using average')
         else:
-            cat_dim = 1 if any(re.match(pattern, name) for pattern in parameters_with_row_parallelism) else 0
+            cat_dim = 1 if get_matched_pattern(parameters_with_row_parallelism, name) else 0
             # print(f"merge {name} with CAT DIM: {cat_dim}")
             param = torch.cat(slices, dim=cat_dim)
             ckpt_dict[CAT_DIM] = cat_dim
 
-        if any(re.match(pattern, name) for pattern in vocabulary_parameters):
+        if get_matched_pattern(vocabulary_parameters, name):
             #print(f"Before {param.shape=}")
             # strip padding
-            #param = _strip_vocab_padding(ds_checkpoint, param)
-            ckpt_dict[VOCAB_DIVISIBILITY_PADDING_TENSOR] = _get_vocab_divisibility_padding_tensor(
-                universal_checkpoint_info, param)
+            original_vocab_size = universal_checkpoint_info['original_vocab_size']
+            param = param[:original_vocab_size, :]
+            ckpt_dict[VOCAB_TENSOR] = True
             #print(f"After {param.shape=}")
 
         #print(f"Final shape: {param.shape}")
         ckpt_dict[PARAM] = param
         _save_checkpoint(final_path, ckpt_dict)
 
+    return unmatched_patterns
+
 
 def _get_chunks(l, n):
     for i in range(0, len(l), n):
@@ -208,10 +218,13 @@ def _get_chunks(l, n):
 
 def _do_parallel_work(do_work, work_chunks, num_workers):
     pool = multiprocessing.Pool(num_workers)
+    results = []
     for batch in tqdm.tqdm(work_chunks):
-        pool.map(do_work, batch)
+        res = pool.map(do_work, batch)
+        results.extend(res)
     pool.close()
     pool.join()
+    return results
 
 
 def _extract_zero_shard_files(args, ds_checkpoint, temp_dir):
@@ -232,7 +245,16 @@ def _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir):
     #pprint(work_chunks)
     zero_output_folder = os.path.join(args.output_folder, "zero")
     do_work = partial(merge_tp_slices, ds_checkpoint, zero_output_folder, temp_dir, ds_checkpoint.tp_degree)
-    _do_parallel_work(do_work, work_chunks, args.num_merge_workers)
+    unmatched_patterns_lists = _do_parallel_work(do_work, work_chunks, args.num_merge_workers)
+
+    # verify that all patterns were used
+    # if a pattern was not used by any of the workers, then it was not used at all -> assert/alert
+    sets = [set(lst) for lst in unmatched_patterns_lists]
+    unmatched_patterns = list(set.intersection(*sets))
+    if args.strict:
+        assert not unmatched_patterns, f'Unused patterns={unmatched_patterns} while merging tp slices'
+    elif unmatched_patterns:
+        print(f'Warning: Unused patterns={unmatched_patterns} while merging tp slices')
 
 
 def _save_optimizer_state(args, ds_checkpoint):
diff --git a/deepspeed/checkpoint/universal_checkpoint.py b/deepspeed/checkpoint/universal_checkpoint.py
index cd4d7d51a4c2..cfb04b778f50 100644
--- a/deepspeed/checkpoint/universal_checkpoint.py
+++ b/deepspeed/checkpoint/universal_checkpoint.py
@@ -6,7 +6,7 @@
 import os
 import torch
 import types
-from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_DIVISIBILITY_PADDING_TENSOR, CAT_DIM)
+from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_TENSOR, CAT_DIM)
 
 
 def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
@@ -43,21 +43,16 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
         # the converter to universal currently strips the original padding completely so the saved
         # weight is padding-free and we just need to add new padding depending on the target TP
         # degree
-        vocab_divisibility_padding_tensor = ckpt_dict.get(VOCAB_DIVISIBILITY_PADDING_TENSOR, None)
-        if vocab_divisibility_padding_tensor is not None:
+        is_vocab_tensor = ckpt_dict.get(VOCAB_TENSOR, False)
+        if is_vocab_tensor:
             # In the absence of data passed from the user wrt new padded vocab specific to tp degree
             # we can again derive that data by reverse engineering the target shapes like so:
             padded_target_vocab_size = self.shape[0] * tp_world_size
+            assert padded_target_vocab_size >= full_hp_param.shape[0], \
+                f'Vocab tensor padded size {padded_target_vocab_size} < loaded universal size {full_hp_param.shape[0]}'
             if padded_target_vocab_size > full_hp_param.shape[0]:
-                # Need to expand
                 padding_size = padded_target_vocab_size - full_hp_param.shape[0]
-                # Implement the following concat in efficient way using pad
-                #full_hp_param = torch.cat((full_hp_param, padding_tensor), 0)
                 full_hp_param = torch.nn.functional.pad(full_hp_param, (0, 0, 0, padding_size), "constant", 0)
-                full_hp_param[:-padding_size, :] = vocab_divisibility_padding_tensor
-            else:
-                # Need to shrink or keep the same
-                full_hp_param = full_hp_param[:padded_target_vocab_size, :]
 
         full_param_numel = full_hp_param.numel()
         tp_slice_numel = self.numel()
diff --git a/deepspeed/ops/deepspeed4science/evoformer_attn.py b/deepspeed/ops/deepspeed4science/evoformer_attn.py
index ba7e20e51d50..da5843d6de31 100644
--- a/deepspeed/ops/deepspeed4science/evoformer_attn.py
+++ b/deepspeed/ops/deepspeed4science/evoformer_attn.py
@@ -30,7 +30,7 @@ def _attention(Q, K, V, bias1, bias2):
     return O, lse
 
 
-def attention_bwd(dO, Q, K, V, O, lse, bias1, bias2):
+def attention_bwd(dO, Q, K, V, O, lse, bias1, bias2, bias1_grad, bias2_grad):
     assert max(Q.shape[-1], V.shape[-1]) <= 64, "Hidden size is too large. Need to change kMax to a larger value"
     dQ = torch.empty_like(Q, dtype=Q.dtype)
     dK = torch.empty_like(K, dtype=K.dtype)
@@ -44,8 +44,14 @@ def attention_bwd(dO, Q, K, V, O, lse, bias1, bias2):
     if kernel_ is None:
         kernel_ = EvoformerAttnBuilder().load()
     delta = torch.empty_like(lse)
-    dB1 = torch.zeros_like(bias1, dtype=torch.float32)
-    dB2 = torch.zeros_like(bias2, dtype=torch.float32)
+    if bias1_grad:
+        dB1 = torch.zeros_like(bias1, dtype=torch.float32)
+    else:
+        dB1 = torch.tensor([], dtype=torch.float32, device=bias1.device)
+    if bias2_grad:
+        dB2 = torch.zeros_like(bias2, dtype=torch.float32)
+    else:
+        dB2 = torch.tensor([], dtype=torch.float32, device=bias2.device)
     kernel_.attention_bwd(dO, Q, K, V, O, lse, delta, bias1, bias2, dQ, dK, dV, dB1, dB2)
     return dQ, dK, dV, dB1.to(dO.dtype), dB2.to(dO.dtype)
 
@@ -69,10 +75,12 @@ def forward(ctx, q, k, v, bias1=None, bias2=None):
     @staticmethod
     def backward(ctx, grad_output):
         q, k, v, o, lse, bias1, bias2 = ctx.saved_tensors
-        dQ, dK, dV, dB1, dB2 = attention_bwd(grad_output, q, k, v, o, lse, bias1, bias2)
-        if bias1.numel() == 0:
+        is_b1_grad = bias1.numel() != 0 and ctx.needs_input_grad[3]
+        is_b2_grad = bias2.numel() != 0 and ctx.needs_input_grad[4]
+        dQ, dK, dV, dB1, dB2 = attention_bwd(grad_output, q, k, v, o, lse, bias1, bias2, is_b1_grad, is_b2_grad)
+        if not is_b1_grad:
             dB1 = None
-        if bias2.numel() == 0:
+        if not is_b2_grad:
             dB2 = None
         return dQ, dK, dV, dB1, dB2
 
@@ -90,13 +98,9 @@ def DS4Sci_EvoformerAttention(Q, K, V, biases):
     bias_2_shape = lambda x: (x.shape[0], 1, x.shape[3], x.shape[2], x.shape[2])
 
     if biases[0] is not None:
-        assert biases[0].shape == bias_1_shape(Q)
-    else:
-        biases[0] = Q.new_zeros(bias_1_shape(Q))
+        assert biases[0].shape == bias_1_shape(Q), "bias1 shape is incorrect"
 
     if biases[1] is not None:
-        assert biases[1].shape == bias_2_shape(Q)
-    else:
-        biases[1] = Q.new_zeros(bias_2_shape(Q))
+        assert biases[1].shape == bias_2_shape(Q), "bias2 shape is incorrect"
 
     return EvoformerFusedAttention.apply(Q, K, V, biases[0], biases[1])
diff --git a/deepspeed/runtime/activation_checkpointing/checkpointing.py b/deepspeed/runtime/activation_checkpointing/checkpointing.py
index 108cb37b57fb..72a7bc0516ba 100644
--- a/deepspeed/runtime/activation_checkpointing/checkpointing.py
+++ b/deepspeed/runtime/activation_checkpointing/checkpointing.py
@@ -237,6 +237,14 @@ def model_parallel_cuda_manual_seed(seed):
     _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, model_parallel_seed)
 
 
+def model_parallel_reconfigure_tp_seed(seed):
+    global mpu
+    tp_rank = bwc_tensor_model_parallel_rank(mpu)
+    model_parallel_seed = seed + 2718 + tp_rank
+    with _CUDA_RNG_STATE_TRACKER.fork():
+        get_accelerator().manual_seed(model_parallel_seed)
+
+
 def get_partition_start(item):
     global mp_rank, mp_size, mp_group
     size = item.numel()
diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py
index 550af8fac057..494816e6a846 100644
--- a/deepspeed/runtime/bf16_optimizer.py
+++ b/deepspeed/runtime/bf16_optimizer.py
@@ -75,7 +75,6 @@ def __init__(self,
         self.fp32_groups_gradient_flat_partition = []
         self.fp32_groups_has_gradients = []
 
-        self.step_count = 0
         self.group_paddings = []
 
         if self.using_real_optimizer:
@@ -252,7 +251,6 @@ def step(self, closure=None):
         self.update_lp_params()
 
         self.clear_hp_grads()
-        self.step_count += 1
 
     def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs):
         """Perform a backward pass and copy the low-precision gradients to the
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index b683f1f37e12..0aeda20fbeb3 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -2726,6 +2726,8 @@ def load_checkpoint(self,
 
         if self.load_universal_checkpoint():
             self.optimizer.update_lp_params()
+            if load_zero_checkpoint:
+                self.update_optimizer_step(step=client_states['iteration'] + 1)
 
         return load_path, client_states
 
@@ -2903,6 +2905,24 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True):
         logger.info(f"loading {len(zero_sd_list)} zero partition checkpoints for rank {self.global_rank}")
         return True
 
+    def update_optimizer_step(self, step):
+
+        def set_step(d):
+            if isinstance(d['step'], torch.Tensor):
+                d['step'] = torch.tensor(step, dtype=d['step'].dtype, device=d['step'].device)
+            else:
+                d['step'] = step
+
+        optimizer = self.optimizer
+        base_optimizer = optimizer.optimizer
+        state = base_optimizer.state
+        for group in optimizer.param_groups:
+            if 'step' in group:
+                set_step(group)
+            for p in group['params']:
+                if p in state and len(state[p]) > 0 and 'step' in state[p]:
+                    set_step(state[p])
+
     def _get_mp_rank_zero_checkpoint_names(self, load_dir, tag, mp_rank, dp_world_size, bf16_mode):
         zero_ckpt_names = []
         for dp_rank in range(dp_world_size):
diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py
index a929ca5842b6..182f806c839c 100755
--- a/deepspeed/runtime/fp16/fused_optimizer.py
+++ b/deepspeed/runtime/fp16/fused_optimizer.py
@@ -107,7 +107,6 @@ def __init__(self,
 
         self.clip_grad = clip_grad
         self.norm_type = 2
-        self.step_count = 0
 
         if required_torch_version(max_version=0.4):
             self.clip_grad_norm = torch.nn.utils.clip_grad_norm
@@ -289,8 +288,6 @@ def step(self, closure=None):
 
         self.timers.log(STEP_TIMERS)
 
-        self.step_count += 1
-
         return self.overflow
 
     def _get_norm_with_moe_layers(self, all_groups_norm):
diff --git a/tests/benchmarks/DS4Sci_EvoformerAttention_bench.py b/tests/benchmarks/DS4Sci_EvoformerAttention_bench.py
index b19eae7272c4..e3d8825f5415 100644
--- a/tests/benchmarks/DS4Sci_EvoformerAttention_bench.py
+++ b/tests/benchmarks/DS4Sci_EvoformerAttention_bench.py
@@ -3,7 +3,7 @@
 
 # DeepSpeed Team
 """
-This script is to test the correctness of the DS4Sci_EvoformerAttention op.
+This script benchmarks the performance of the DS4Sci_EvoformerAttention op.
 To run the script,
 1. Clone the CUTLASS repo. E.g. git clone https://github.com/NVIDIA/cutlass.git
 2. Specify the CUTLASS_PATH environment variable. E.g. export CUTLASS_PATH=$(pwd)/cutlass
@@ -83,7 +83,7 @@ def benchmark():
     Q = torch.randn(batch, N, seq_len, heads, dim, dtype=dtype, device="cuda", requires_grad=True)
     K = torch.randn(batch, N, seq_len, heads, dim, dtype=dtype, device="cuda", requires_grad=True)
     V = torch.randn(batch, N, seq_len, heads, dim, dtype=dtype, device="cuda", requires_grad=True)
-    bias1 = torch.randn(batch, N, 1, 1, seq_len, dtype=dtype, device="cuda", requires_grad=True)
+    bias1 = torch.randn(batch, N, 1, 1, seq_len, dtype=dtype, device="cuda", requires_grad=False)
     bias2 = torch.randn(batch, 1, heads, seq_len, seq_len, dtype=dtype, device="cuda", requires_grad=True)
     # warm up
     DS4Sci_EvoformerAttention(Q, K, V, [bias1, bias2])
diff --git a/tests/unit/ops/deepspeed4science/test_DS4Sci_EvoformerAttention.py b/tests/unit/ops/deepspeed4science/test_DS4Sci_EvoformerAttention.py
index f8cd46e29228..25624f3a6818 100644
--- a/tests/unit/ops/deepspeed4science/test_DS4Sci_EvoformerAttention.py
+++ b/tests/unit/ops/deepspeed4science/test_DS4Sci_EvoformerAttention.py
@@ -69,42 +69,35 @@ def test_DS4Sci_EvoformerAttention(dtype, tensor_shape):
                     dtype=dtype,
                     device=get_accelerator().device_name(),
                     requires_grad=True)
-    bias1 = torch.randn(batch,
-                        n,
-                        1,
-                        1,
-                        seq_len,
-                        dtype=dtype,
-                        device=get_accelerator().device_name(),
-                        requires_grad=True)
-    bias2 = torch.randn(batch,
-                        1,
-                        heads,
-                        seq_len,
-                        seq_len,
-                        dtype=dtype,
-                        device=get_accelerator().device_name(),
-                        requires_grad=True)
+    mask = torch.randint(0, 2, (batch, n, 1, 1, seq_len), dtype=dtype, device=get_accelerator().device_name())
+    mask_bias = 1e9 * (mask - 1)
+    bias = torch.randn(batch,
+                       1,
+                       heads,
+                       seq_len,
+                       seq_len,
+                       dtype=dtype,
+                       device=get_accelerator().device_name(),
+                       requires_grad=True)
     dummy_out = torch.rand_like(Q, dtype=dtype, device=get_accelerator().device_name())
-    ref_out = attention_reference(Q, K, V, [bias1, bias2], 1 / (dim**0.5))
+    ref_out = attention_reference(Q, K, V, [mask_bias, bias], 1 / (dim**0.5))
     ref_out.backward(dummy_out)
     ref_dv, V.grad = V.grad.clone(), None
     ref_dk, K.grad = K.grad.clone(), None
     ref_dq, Q.grad = Q.grad.clone(), None
-    ref_db1, bias1.grad = bias1.grad.clone(), None
-    ref_db2, bias2.grad = bias2.grad.clone(), None
+    ref_db, bias.grad = bias.grad.clone(), None
 
-    out = DS4Sci_EvoformerAttention(Q, K, V, [bias1, bias2])
+    out = DS4Sci_EvoformerAttention(Q, K, V, [mask_bias, bias])
     out.backward(dummy_out)
     dv, v_grad = V.grad.clone(), None
     dk, k_grad = K.grad.clone(), None
     dq, q_grad = Q.grad.clone(), None
-    db1, bias1.grad = bias1.grad.clone(), None
-    db2, bias2.grad = bias2.grad.clone(), None
+    db, bias.grad = bias.grad.clone(), None
 
-    assert torch.allclose(ref_out, out, atol=2e-2, rtol=0), f"\n{ref_out} \n {out}"
-    assert torch.allclose(ref_dv, dv, atol=2e-2, rtol=0), f"\n{ref_dv} \n {dv}"
-    assert torch.allclose(ref_dk, dk, atol=2e-2, rtol=0), f"\n{ref_dk} \n {dk}"
-    assert torch.allclose(ref_dq, dq, atol=2e-2, rtol=0), f"\n{ref_dq} \n {dq}"
-    assert torch.allclose(ref_db1, db1, atol=2e-2, rtol=1e-2), f"{ref_db1} \n {db1}"
-    assert torch.allclose(ref_db2, db2, atol=2e-2, rtol=1e-2), f"{ref_db2} \n {db2}"
+    eps = 1e-2 if dtype == torch.float16 else 5e-2
+
+    assert torch.max(torch.abs(ref_out - out)).item() < eps, f"out eps: {torch.max(torch.abs(ref_out - out))}"
+    assert torch.max(torch.abs(ref_dv - dv)) < eps, f"dv eps: {torch.max(torch.abs(ref_dv - dv))}"
+    assert torch.max(torch.abs(ref_dk - dk)) < eps, f"dk eps: {torch.max(torch.abs(ref_dk - dk))}"
+    assert torch.max(torch.abs(ref_dq - dq)) < eps, f"dq eps: {torch.max(torch.abs(ref_dq - dq))}"
+    assert torch.max(torch.abs(ref_db - db)) < 2 * eps, f"db eps: {torch.max(torch.abs(ref_db - db))}"