Skip to content

Commit

Permalink
Added test for llama-3.1-8b model finetuning demo
Browse files Browse the repository at this point in the history
  • Loading branch information
abhijeet-dhumal authored and openshift-merge-bot[bot] committed Aug 28, 2024
1 parent 69495c4 commit 2a0386c
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 14 deletions.
2 changes: 0 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ require (
sigs.k8s.io/kueue v0.6.2
)

replace github.com/project-codeflare/codeflare-common => /home/abdhumal/abhidev/RedHatDev/codeflare-common

require (
github.com/aymerick/douceur v0.2.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
Expand Down
40 changes: 40 additions & 0 deletions tests/odh/mnist_raytune_hpo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,46 @@ func mnistRayTuneHpo(t *testing.T, numGpus int) {
ContainElement(WithTransform(KueueWorkloadAdmitted, BeTrueBecause("Workload failed to be admitted"))),
),
)
time.Sleep(30 * time.Second)

// Fetch created raycluster
rayClusterName := "mnisthpotest"
rayCluster, err := test.Client().Ray().RayV1().RayClusters(namespace.Name).Get(test.Ctx(), rayClusterName, metav1.GetOptions{})
test.Expect(err).ToNot(HaveOccurred())

// Initialise a RayCluster client so RayJob details can be fetched from the cluster dashboard via its REST API
dashboardUrl := GetDashboardUrl(test, namespace, rayCluster)
rayClusterClientConfig := RayClusterClientConfig{Address: dashboardUrl.String(), Client: nil, InsecureSkipVerify: true}
rayClient, err := NewRayClusterClient(rayClusterClientConfig, test.Config().BearerToken)
if err != nil {
test.T().Errorf("%s", err)
}

jobID := GetTestJobId(test, rayClient, dashboardUrl.Host)
test.Expect(jobID).ToNot(Equal(nil))

// Wait for the job to either succeed or fail
var rayJobStatus string
fmt.Printf("Waiting for job to be Succeeded...\n")
test.Eventually(func() string {
resp, err := rayClient.GetJobDetails(jobID)
test.Expect(err).ToNot(HaveOccurred())
rayJobStatusVal := resp.Status
if rayJobStatusVal == "SUCCEEDED" || rayJobStatusVal == "FAILED" {
fmt.Printf("JobStatus : %s\n", rayJobStatusVal)
rayJobStatus = rayJobStatusVal
return rayJobStatus
}
if rayJobStatus != rayJobStatusVal && rayJobStatusVal != "SUCCEEDED" {
fmt.Printf("JobStatus : %s...\n", rayJobStatusVal)
rayJobStatus = rayJobStatusVal
}
return rayJobStatus
}, TestTimeoutDouble, 3*time.Second).Should(Or(Equal("SUCCEEDED"), Equal("FAILED")), "Job did not complete within the expected time")
test.Expect(rayJobStatus).To(Equal("SUCCEEDED"), "RayJob failed !")

// Store job logs in output directory
WriteRayJobAPILogs(test, rayClient, jobID)

// Fetch created raycluster
rayClusterName := "mnisthpotest"
Expand Down
30 changes: 18 additions & 12 deletions tests/odh/ray_finetune_llm_deepspeed_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,14 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

func TestRayFinetuneLlmDeepspeedDemo(t *testing.T) {
rayFinetuneLlmDeepspeed(t, 1)
// TestRayFinetuneLlmDeepspeedDemoLlama_2_7b runs the Ray fine-tuning demo
// test on a single GPU, using the zero_3_llama_2_7b.json DeepSpeed config
// file.
func TestRayFinetuneLlmDeepspeedDemoLlama_2_7b(t *testing.T) {
	const (
		gpuCount        = 1
		deepspeedConfig = "zero_3_llama_2_7b.json"
	)
	rayFinetuneLlmDeepspeed(t, gpuCount, deepspeedConfig)
}
// TestRayFinetuneLlmDeepspeedDemoLlama_31_8b runs the Ray fine-tuning demo
// test on a single GPU, using the zero_3_offload_optim_param.json DeepSpeed
// config file.
//
// NOTE(review): the notebook text being replaced refers to
// "zero_3_offload_optim+param.json" (with a '+'); confirm that a file with
// the underscore spelling exists under deepspeed_configs.
func TestRayFinetuneLlmDeepspeedDemoLlama_31_8b(t *testing.T) {
	const (
		gpuCount        = 1
		deepspeedConfig = "zero_3_offload_optim_param.json"
	)
	rayFinetuneLlmDeepspeed(t, gpuCount, deepspeedConfig)
}

func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int) {
func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int, modelConfigFile string) {
test := With(t)

// Create a namespace
Expand All @@ -51,7 +54,7 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int) {
// List the text replacements (old -> new) to apply to the ray_finetune_llm_deepspeed.ipynb notebook
requiredChangesInNotebook := map[string]string{
"import os": "import os,time,sys",
"import sys": "!cp /opt/app-root/notebooks/* ./",
"import sys": "!cp /opt/app-root/notebooks/* ./\\n\",\n\t\"!ls",
"from codeflare_sdk.cluster.auth import TokenAuthentication": "from codeflare_sdk.cluster.auth import TokenAuthentication\\n\",\n\t\"from codeflare_sdk.job import RayJobClient",
"token = ''": fmt.Sprintf("token = '%s'", userToken),
"server = ''": fmt.Sprintf("server = '%s'", GetOpenShiftApiUrl(test)),
Expand All @@ -61,23 +64,26 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int) {
"num_workers=7": "num_workers=1",
"worker_cpu_requests=16": "worker_cpu_requests=4",
"worker_cpu_limits=16": "worker_cpu_limits=4",
"worker_memory_requests=128": "worker_memory_requests=60",
"worker_memory_limits=256": "worker_memory_limits=60",
"worker_memory_requests=128": "worker_memory_requests=64",
"worker_memory_limits=256": "worker_memory_limits=128",
"head_memory=128": "head_memory=48",
"client = cluster.job_client": "ray_dashboard = cluster.cluster_dashboard_uri()\\n\",\n\t\"header = {\\\"Authorization\\\": \\\"Bearer " + userToken + "\\\"}\\n\",\n\t\"client = RayJobClient(address=ray_dashboard, headers=header, verify=False)\\n",
"--num-devices=8": fmt.Sprintf("--num-devices=%d", numGpus),
"--num-epochs=3": fmt.Sprintf("--num-epochs=%d", 1),
"--ds-config=./deepspeed_configs/zero_3_llama_2_7b.json": "--ds-config=./zero_3_llama_2_7b.json \\\"\\n\",\n\t\" \\\"--lora-config=./lora.json \\\"\\n\",\n\t\" \\\"--as-test",
"'pip': 'requirements.txt'": "'pip': '/opt/app-root/src/requirements.txt'",
"'working_dir': './'": "'working_dir': '/opt/app-root/src'",
"client.stop_job(submission_id)": "finished = False\\n\",\n\t\"while not finished:\\n\",\n\t\" time.sleep(1)\\n\",\n\t\" status = client.get_job_status(submission_id)\\n\",\n\t\" finished = (status == \\\"SUCCEEDED\\\")\\n\",\n\t\"if finished:\\n\",\n\t\" print(\\\"Job completed Successfully !\\\")\\n\",\n\t\"else:\\n\",\n\t\" print(\\\"Job failed !\\\")\\n\",\n\t\"time.sleep(10)\\n",
"--ds-config=./deepspeed_configs/zero_3_offload_optim+param.json": fmt.Sprintf("--ds-config=./%s \\\"\\n\",\n\t\" \\\"--lora-config=./lora.json \\\"\\n\",\n\t\" \\\"--as-test", modelConfigFile),
"--batch-size-per-device=32": "--batch-size-per-device=6",
"--eval-batch-size-per-device=32": "--eval-batch-size-per-device=6",
"'pip': 'requirements.txt'": "'pip': '/opt/app-root/src/requirements.txt'",
"'working_dir': './'": "'working_dir': '/opt/app-root/src'",
"client.stop_job(submission_id)": "finished = False\\n\",\n\t\"while not finished:\\n\",\n\t\" time.sleep(1)\\n\",\n\t\" status = client.get_job_status(submission_id)\\n\",\n\t\" finished = (status == \\\"SUCCEEDED\\\")\\n\",\n\t\"if finished:\\n\",\n\t\" print(\\\"Job completed Successfully !\\\")\\n\",\n\t\"else:\\n\",\n\t\" print(\\\"Job failed !\\\")\\n\",\n\t\"time.sleep(10)\\n",
}

updatedNotebookContent := string(ReadFileExt(test, workingDirectory+"/../../examples/ray-finetune-llm-deepspeed/ray_finetune_llm_deepspeed.ipynb"))
for oldValue, newValue := range requiredChangesInNotebook {
updatedNotebookContent = strings.Replace(updatedNotebookContent, oldValue, newValue, -1)
}
updatedNotebook := []byte(updatedNotebookContent)
os.WriteFile("demo.ipynb", updatedNotebook, 0644)

// Test configuration
jupyterNotebookConfigMapFileName := "ray_finetune_llm_deepspeed.ipynb"
Expand All @@ -87,7 +93,7 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int) {
"requirements.txt": ReadFileExt(test, workingDirectory+"/../../examples/ray-finetune-llm-deepspeed/requirements.txt"),
"create_dataset.py": ReadFileExt(test, workingDirectory+"/../../examples/ray-finetune-llm-deepspeed/create_dataset.py"),
"lora.json": ReadFileExt(test, workingDirectory+"/../../examples/ray-finetune-llm-deepspeed/lora_configs/lora.json"),
"zero_3_llama_2_7b.json": ReadFileExt(test, workingDirectory+"/../../examples/ray-finetune-llm-deepspeed/deepspeed_configs/zero_3_llama_2_7b.json"),
modelConfigFile: ReadFileExt(test, fmt.Sprintf(workingDirectory+"/../../examples/ray-finetune-llm-deepspeed/deepspeed_configs/%s", modelConfigFile)),
"utils.py": ReadFileExt(test, workingDirectory+"/../../examples/ray-finetune-llm-deepspeed/utils.py"),
}

Expand Down Expand Up @@ -120,7 +126,7 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int) {

// Initialise a RayCluster client so RayJob details can be fetched from the cluster dashboard via its REST API
dashboardUrl := GetDashboardUrl(test, namespace, rayCluster)
rayClusterClientConfig := RayClusterClientConfig{Address: dashboardUrl.String(), Client: nil, SkipTlsVerification: true}
rayClusterClientConfig := RayClusterClientConfig{Address: dashboardUrl.String(), Client: nil, InsecureSkipVerify: true}
rayClient, err := NewRayClusterClient(rayClusterClientConfig, test.Config().BearerToken)
if err != nil {
test.T().Errorf("%s", err)
Expand Down

0 comments on commit 2a0386c

Please sign in to comment.