Skip to content

Commit

Permalink
Update paths to weights for llama 8b benchmarking tests (#414)
Browse files Browse the repository at this point in the history
Signed-off-by: aviator19941 <[email protected]>
  • Loading branch information
aviator19941 authored Nov 2, 2024
1 parent 0b5c9c6 commit 8d5f850
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions sharktank/tests/models/llama/benchmark_amdgpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ class BenchmarkLlama3_1_8B(BaseBenchmarkTest):
def setUp(self):
super().setUp()
# TODO: add numpy files to Azure and download from it
self.artifacts_dir = Path("/data/llama-3.1/8b")
self.irpa_path = self.artifacts_dir / "llama8b_f16.irpa"
self.irpa_path_fp8 = self.artifacts_dir / "llama8b_fp8.irpa"
self.artifacts_dir = Path("/data/llama-3.1/weights/8b")
self.irpa_path = self.artifacts_dir / "fp16/llama3.1_8b_fp16.irpa"
self.irpa_path_fp8 = self.artifacts_dir / "f8/llama8b_fp8.irpa"
self.tensor_parallelism_size = 1
self.dir_path_8b = self.dir_path / "llama-8b"
self.temp_dir_8b = Path(self.dir_path_8b)
Expand Down Expand Up @@ -305,9 +305,9 @@ class BenchmarkLlama3_1_70B(BaseBenchmarkTest):
def setUp(self):
super().setUp()
# TODO: add numpy files to Azure and download from it
artifacts_dir = Path("/data/llama-3.1/70b")
self.irpa_path = artifacts_dir / "llama70b_f16.irpa"
self.irpa_path_fp8 = artifacts_dir / "llama70b_fp8.irpa"
artifacts_dir = Path("/data/llama-3.1/weights/70b")
self.irpa_path = artifacts_dir / "fp16/llama3.1_70b_f16.irpa"
self.irpa_path_fp8 = artifacts_dir / "f8/llama70b_fp8.irpa"
self.tensor_parallelism_size = 1
self.dir_path_70b = self.dir_path / "llama-70b"
self.temp_dir_70b = Path(self.dir_path_70b)
Expand Down Expand Up @@ -536,9 +536,9 @@ class BenchmarkLlama3_1_405B(BaseBenchmarkTest):
def setUp(self):
super().setUp()
# TODO: add numpy files to Azure and download from it
artifacts_dir = Path("/data/llama-3.1/405b")
self.irpa_path = artifacts_dir / "llama405b_f16.irpa"
self.irpa_path_fp8 = artifacts_dir / "llama405b_fp8.irpa"
artifacts_dir = Path("/data/llama-3.1/weights/405b")
self.irpa_path = artifacts_dir / "fp16/llama3.1_405b_fp16.irpa"
self.irpa_path_fp8 = artifacts_dir / "f8/llama405b_fp8.irpa"
self.tensor_parallelism_size = 8
self.dir_path_405b = self.dir_path / "llama-405b"
self.temp_dir_405b = Path(self.dir_path_405b)
Expand Down

0 comments on commit 8d5f850

Please sign in to comment.