Commit

Add GPU memory control to the Python interface
黄宇扬 committed Jul 17, 2024
1 parent 3f594cf commit 4c3da0d
Showing 4 changed files with 63 additions and 2 deletions.
38 changes: 36 additions & 2 deletions tools/fastllm_pytools/llm.py
@@ -76,6 +76,15 @@
fastllm_lib.apply_chat_template.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_void_p, ctypes.c_void_p]
fastllm_lib.apply_chat_template.restype = ctypes.c_char_p

fastllm_lib.set_kv_cache_limit_llm_model.argtypes = [ctypes.c_int, ctypes.c_int64]

fastllm_lib.set_max_batch_llm_model.argtypes = [ctypes.c_int, ctypes.c_int]

fastllm_lib.set_verbose_llm_model.argtypes = [ctypes.c_int, ctypes.c_bool]

fastllm_lib.get_max_input_len_llm_model.argtypes = [ctypes.c_int]
fastllm_lib.get_max_input_len_llm_model.restype = ctypes.c_int

def set_cpu_threads(threads: int):
    fastllm_lib.set_cpu_threads(threads);

@@ -841,7 +850,32 @@ def release_memory(self):
        fastllm_lib.release_memory(self.model)

    def set_save_history(self, save: bool):
-        fastllm_lib.set_save_history(self.model, save);
+        fastllm_lib.set_save_history(self.model, save)

    def set_atype(self, atype: str):
-        fastllm_lib.set_model_atype(self.model, str(atype).encode());
+        fastllm_lib.set_model_atype(self.model, str(atype).encode())

    def set_kv_cache_limit(self, limit: str):
        # Parse a human-readable size such as "10k", "100m" or "4g" into bytes.
        limit_bytes = 0
        try:
            if (limit.endswith('k') or limit.endswith('K')):
                limit_bytes = int(limit[:-1]) * 1024
            elif (limit.endswith('m') or limit.endswith('M')):
                limit_bytes = int(limit[:-1]) * 1024 * 1024
            elif (limit.endswith('g') or limit.endswith('G')):
                limit_bytes = int(limit[:-1]) * 1024 * 1024 * 1024
            else:
                limit_bytes = int(limit)
        except:
            print('set_kv_cache_limit error, param should be like "10k" or "10m" or "1g"')
            exit(0)
        fastllm_lib.set_kv_cache_limit_llm_model(self.model, ctypes.c_int64(limit_bytes))

    def set_max_batch(self, batch: int):
        fastllm_lib.set_max_batch_llm_model(self.model, batch)

    def set_verbose(self, verbose: int):
        fastllm_lib.set_verbose_llm_model(self.model, verbose)

    def get_max_input_len(self):
        return fastllm_lib.get_max_input_len_llm_model(self.model)
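
For context, a minimal usage sketch of the new Python-side methods. The import path, model file name, and loading call are illustrative assumptions, not part of this commit:

    from ftllm import llm                       # assumed package/module name for the fastllm python tools
    model = llm.model("model.flm")              # placeholder model file
    model.set_kv_cache_limit("4g")              # cap KV cache memory at 4 GiB (parsed to bytes, see set_kv_cache_limit)
    model.set_max_batch(16)                     # at most 16 queries inferred concurrently at a time
    model.set_verbose(True)                     # enable the model's verbose flag (as server.py does below)
    print(model.get_max_input_len())            # maximum input length reported by the model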
1 change: 1 addition & 0 deletions tools/fastllm_pytools/server.py
@@ -62,6 +62,7 @@ def init_logging(log_level = logging.INFO, log_file:str = None):
args = parse_args()
logging.info(args)
model = make_normal_llm_model(args)
model.set_verbose(True)
fastllm_completion = FastLLmCompletion(model_name = args.model_name,
model = model)
uvicorn.run(app, host = args.host, port = args.port)
6 changes: 6 additions & 0 deletions tools/fastllm_pytools/util.py
@@ -9,6 +9,8 @@ def make_normal_parser(des: str) -> argparse.ArgumentParser:
    parser.add_argument('--dtype', type = str, default = "float16", help = 'weight dtype (only applies when loading HF models)')
    parser.add_argument('--atype', type = str, default = "float32", help = 'inference dtype, float32 or float16')
    parser.add_argument('--cuda_embedding', action = 'store_true', help = 'run embedding on CUDA')
    parser.add_argument('--kv_cache_limit', type = str, default = "auto", help = 'maximum KV cache memory usage')
    parser.add_argument('--max_batch', type = int, default = -1, help = 'maximum number of queries inferred concurrently at a time')
    parser.add_argument('--device', type = str, help = 'device to use')
    return parser

@@ -29,4 +31,8 @@ def make_normal_llm_model(args):
        llm.set_cuda_embedding(True)
    model = llm.model(args.path, dtype = args.dtype, tokenizer_type = "auto")
    model.set_atype(args.atype)
    if (args.max_batch > 0):
        model.set_max_batch(args.max_batch)
    if (args.kv_cache_limit != "" and args.kv_cache_limit != "auto"):
        model.set_kv_cache_limit(args.kv_cache_limit)
    return model
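
With these flags wired through make_normal_llm_model, the server can be started with a KV cache cap and a batch limit. A hypothetical invocation (the model-path argument and entry point are assumptions that may differ by install method):

    python3 tools/fastllm_pytools/server.py --path model.flm --kv_cache_limit 4g --max_batch 16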
20 changes: 20 additions & 0 deletions tools/src/pytools.cpp
@@ -391,4 +391,24 @@ extern "C" {
        auto model = models.GetModel(modelId);
        model->AddPromptCache(input);
    }

    DLL_EXPORT void set_kv_cache_limit_llm_model(int modelId, long long bytes) {
        auto model = models.GetModel(modelId);
        model->kvCacheLimit = bytes;
    }

    DLL_EXPORT void set_max_batch_llm_model(int modelId, int batch) {
        auto model = models.GetModel(modelId);
        model->maxBatch = batch;
    }

    DLL_EXPORT void set_verbose_llm_model(int modelId, bool verbose) {
        auto model = models.GetModel(modelId);
        model->verbose = verbose;
    }

    DLL_EXPORT int get_max_input_len_llm_model(int modelId) {
        auto model = models.GetModel(modelId);
        return model->max_positions;
    }
};
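
The exported symbols above line up with the ctypes declarations added at the top of llm.py: the KV cache limit crosses the boundary as a 64-bit byte count, the batch limit as an int, and the verbose flag as a bool. A minimal sketch of calling the exports directly through the loaded library handle (the import path is an assumption; model_id is the integer handle that a loaded model keeps in self.model, as the methods in llm.py show):

    import ctypes
    from ftllm.llm import fastllm_lib   # assumed module path; llm.py loads the shared library into fastllm_lib

    # model_id = model.model, the handle of a previously loaded model
    fastllm_lib.set_kv_cache_limit_llm_model(model_id, ctypes.c_int64(4 * 1024 * 1024 * 1024))  # 4 GiB in bytes
    fastllm_lib.set_max_batch_llm_model(model_id, 16)
    fastllm_lib.set_verbose_llm_model(model_id, True)
    max_len = fastllm_lib.get_max_input_len_llm_model(model_id)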
