[Model] Enable Llama xfail inference on CI
- Disable TVM param conversion in order to reduce DRAM memory usage
- This behaviour disables TVM const prop and verification, but isn't a blocker for a functional model
- Compile memory requirement is reduced to around 25 GB, which fits below the 32 GB of host memory available on CI

Closes #247
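
For context, a minimal sketch of what the re-enabled xfail test exercises end to end, pieced together from the diffs below. The helper (load_model) and the forge.compile call are taken from the changed files; the tokenizer usage and prompt text are illustrative assumptions, not part of this commit.

import forge
from transformers import LlamaTokenizer

from test.mlir.llama.utils.utils import load_model

# Sketch only: load Open Llama 3B with the memory-saving compiler config
# set inside load_model() (see the utils.py diff below), then compile.
model_path = "openlm-research/open_llama_3b"
framework_model = load_model(model_path)

# Tokenize a sample prompt (prompt text is illustrative, not from the commit).
tokenizer = LlamaTokenizer.from_pretrained(model_path)
input_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids

# Compile for the device; the test is marked xfail, so a failure past this
# point is expected on CI.
compiled_model = forge.compile(framework_model, input_ids)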
nvukobratTT committed Sep 9, 2024
1 parent 66ace96 commit 02313ce
Showing 2 changed files with 8 additions and 2 deletions.
forge/test/mlir/llama/test_llama_inference.py (3 changes: 2 additions & 1 deletion)
@@ -10,7 +10,7 @@
from test.mlir.llama.utils.utils import load_model


@pytest.mark.xfail(reason="Tile broadcast op is not supported on MLIR.")
@pytest.mark.xfail()
def test_llama_inference():
# Load Llama 3B model and tokenizer
model_path = "openlm-research/open_llama_3b"
@@ -27,6 +27,7 @@ def test_llama_inference():
# Compile the model
compiled_model = forge.compile(framework_model, input_ids)


@pytest.mark.skip(reason="No need to run in CI, this is PoC that should be mapped to work on device.")
def test_llama_inference_no_cache_cpu():
"""
forge/test/mlir/llama/utils/utils.py (7 changes: 6 additions & 1 deletion)
@@ -9,7 +9,12 @@
def load_model(model_path="openlm-research/open_llama_3b", use_cache=False):
# Compiler configurations
compiler_cfg = forge.config._get_global_compiler_config()

# Disable CPU fallback; we want to run the whole model on device
compiler_cfg.enable_tvm_cpu_fallback = False
# Reduce compile memory usage; this disables TVM verification
# and TVM constant evaluation (Forge const eval remains enabled)
compiler_cfg.convert_framework_params_to_tvm = False

# Load Llama 3B model
config = LlamaConfig()
@@ -23,5 +28,5 @@ def load_model(model_path="openlm-research/open_llama_3b", use_cache=False):
model_path, device_map="auto", config=config
)
framework_model.eval()

return framework_model
