From a544085807c2928ef73ff70a6b55eec18e7c2313 Mon Sep 17 00:00:00 2001 From: turboderp Date: Fri, 1 Sep 2023 12:53:50 +0200 Subject: [PATCH] Command line argument to override rope_theta from config.json --- model_init.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/model_init.py b/model_init.py index d694a7d3..7703f07c 100644 --- a/model_init.py +++ b/model_init.py @@ -15,6 +15,7 @@ def add_args(parser): parser.add_argument("-l", "--length", type = int, help = "Maximum sequence length", default = 2048) parser.add_argument("-cpe", "--compress_pos_emb", type = float, help = "Compression factor for positional embeddings", default = 1.0) parser.add_argument("-a", "--alpha", type = float, help = "alpha for context size extension via embedding extension", default = 1.0) + parser.add_argument("-theta", "--theta", type = float, help = "theta (base) for RoPE embeddings") parser.add_argument("-gpfix", "--gpu_peer_fix", action = "store_true", help = "Prevent direct copies of data between GPUs") @@ -140,6 +141,9 @@ def make_config(args): config.silu_no_half2 = args.silu_no_half2 config.concurrent_streams = args.concurrent_streams + if args.theta: + config.rotary_embedding_base = args.theta + return config