diff --git a/llm_analysis/analysis.py b/llm_analysis/analysis.py
index 86dc1fd..17585b9 100644
--- a/llm_analysis/analysis.py
+++ b/llm_analysis/analysis.py
@@ -18,7 +18,7 @@
 from enum import Enum
 from functools import total_ordering
 from pprint import pformat
-from typing import Optional, Union, List
+from typing import Union
 
 import fire
 
@@ -1408,8 +1408,8 @@ def inference(
         )
 
         if use_kv_cache:
-            if (batch_size_per_gpu * (seq_len + num_tokens_to_generate)
-                    < self.get_pivot()):
+            if (batch_size_per_gpu *
+                (seq_len + num_tokens_to_generate) < self.get_pivot()):
                 logger.warning(
                     "kv_cache is only useful when batch_size *"
                     " (seq+num_tokens_to_generate)"
@@ -1629,16 +1629,16 @@ def config_batch_size_and_gradient_accumulation_steps(
             gradient_accumulation_steps = global_batch_size // (
                 batch_size_per_gpu * dp_size)
             assert (global_batch_size % (batch_size_per_gpu * dp_size) == 0
-                    and gradient_accumulation_steps
-                    > 0), "no valid gradient_accumulation_steps, {assert_msg}"
+                    and gradient_accumulation_steps > 0
+                    ), "no valid gradient_accumulation_steps, {assert_msg}"
         elif global_batch_size and gradient_accumulation_steps:
             # batch_size_per_gpu is None, the other two are not None
             batch_size_per_gpu = global_batch_size // (
                 gradient_accumulation_steps * dp_size)
             assert (global_batch_size %
                     (gradient_accumulation_steps * dp_size) == 0
-                    and batch_size_per_gpu
-                    > 0), "no valid batch_size_per_gpu, {assert_msg}"
+                    and batch_size_per_gpu > 0
+                    ), "no valid batch_size_per_gpu, {assert_msg}"
         elif batch_size_per_gpu and gradient_accumulation_steps:
             # global_batch_size is None, the other two are not None
             global_batch_size = (batch_size_per_gpu *
@@ -1667,9 +1667,9 @@ def config_batch_size_and_gradient_accumulation_steps(
 
         else:  # (global_batch_size and gradient_accumulation_steps are None) or (global_batch_size and batch_size_per_gpu are None) or (all are None)
             batch_size_per_gpu = max_batch_size_per_gpu
-            gradient_accumulation_steps = (1 if gradient_accumulation_steps
-                                           is None else
-                                           gradient_accumulation_steps)
+            gradient_accumulation_steps = (1 if
+                                           gradient_accumulation_steps is None
+                                           else gradient_accumulation_steps)
             global_batch_size = (batch_size_per_gpu *
                                  gradient_accumulation_steps *
                                  self.parallelism_config.dp_size)