From 39083d36642b5e89244299608d45643ba5f53091 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Thu, 3 Aug 2023 16:19:53 -0700 Subject: [PATCH] [stabilityai_pytorch][inference] Stability AI Inference DLC (#3195) Co-authored-by: arjkesh <33526713+arjkesh@users.noreply.github.com> Co-authored-by: Shantanu Tripathi --- src/image_builder.py | 1 + stabilityai/pytorch/inference/buildspec.yml | 34 +--- .../docker/2.0/py3/cu118/Dockerfile.gpu | 47 +++++ .../torchserve-stabilityai-entrypoint.py | 48 +++++ .../container_tests/bin/security_checks.py | 10 +- .../inference/test_pytorch_inference.py | 6 +- .../test_boottime_container_security.py | 2 +- test/dlc_tests/sanity/test_pre_release.py | 6 +- .../pytorch/inference/integration/__init__.py | 5 + .../integration/sagemaker/test_stabilityai.py | 132 +++++++++++- .../stabilityai/sdxl-v1/model_gpu/.gitignore | 1 + .../stabilityai/sdxl-v1/model_gpu/README.md | 3 + .../sdxl-v1/model_gpu/code/sdxl_inference.py | 188 ++++++++++++++++++ test/testrunner.py | 2 +- 14 files changed, 444 insertions(+), 41 deletions(-) create mode 100644 stabilityai/pytorch/inference/docker/2.0/py3/cu118/Dockerfile.gpu create mode 100644 stabilityai/pytorch/inference/docker/build_artifacts/torchserve-stabilityai-entrypoint.py create mode 100644 test/sagemaker_tests/pytorch/inference/resources/stabilityai/sdxl-v1/model_gpu/.gitignore create mode 100644 test/sagemaker_tests/pytorch/inference/resources/stabilityai/sdxl-v1/model_gpu/README.md create mode 100644 test/sagemaker_tests/pytorch/inference/resources/stabilityai/sdxl-v1/model_gpu/code/sdxl_inference.py diff --git a/src/image_builder.py b/src/image_builder.py index d973eb1bb320..b75740b236ab 100644 --- a/src/image_builder.py +++ b/src/image_builder.py @@ -82,6 +82,7 @@ def image_builder(buildspec, image_types=[], device_types=[]): if ( "huggingface" in str(BUILDSPEC["framework"]) or "autogluon" in str(BUILDSPEC["framework"]) + or "stabilityai" in str(BUILDSPEC["framework"]) or "trcomp" in str(BUILDSPEC["framework"]) ): os.system("echo login into public ECR") diff --git a/stabilityai/pytorch/inference/buildspec.yml b/stabilityai/pytorch/inference/buildspec.yml index 271e9120ba3d..3ae0f93d0948 100644 --- a/stabilityai/pytorch/inference/buildspec.yml +++ b/stabilityai/pytorch/inference/buildspec.yml @@ -9,40 +9,17 @@ arch_type: x86 repository_info: inference_repository: &INFERENCE_REPOSITORY image_type: &INFERENCE_IMAGE_TYPE inference - root: !join [ *BASE_FRAMEWORK, "/", *INFERENCE_IMAGE_TYPE ] + root: !join [ "stabilityai/", *BASE_FRAMEWORK, "/", *INFERENCE_IMAGE_TYPE ] repository_name: &REPOSITORY_NAME !join [pr, "-", "stabilityai", "-", *BASE_FRAMEWORK, "-", *INFERENCE_IMAGE_TYPE] repository: &REPOSITORY !join [ *ACCOUNT_ID, .dkr.ecr., *REGION, .amazonaws.com/, *REPOSITORY_NAME ] context: inference_context: &INFERENCE_CONTEXT - torchserve-ec2-entrypoint: - source: docker/build_artifacts/torchserve-ec2-entrypoint.py - target: torchserve-ec2-entrypoint.py torchserve-entrypoint: - source: docker/build_artifacts/torchserve-entrypoint.py + source: docker/build_artifacts/torchserve-stabilityai-entrypoint.py target: torchserve-entrypoint.py - config: - source: docker/build_artifacts/config.properties - target: config.properties - deep_learning_container: - source: ../../src/deep_learning_container.py - target: deep_learning_container.py images: - BuildStabilityaiPytorchCpuPy310InferenceDockerImage: - <<: *INFERENCE_REPOSITORY - build: &STABILITYAI_PYTORCH_CPU_INFERENCE_PY3 false - image_size_baseline: 4900 - device_type: 
&DEVICE_TYPE cpu - python_version: &DOCKER_PYTHON_VERSION py3 - tag_python_version: &TAG_PYTHON_VERSION py310 - os_version: &OS_VERSION ubuntu20.04 - diffusers_version: &DIFFUSERS_VERSION 1.2.3 - tag: !join [ *VERSION, "-", 'diffusers',*DIFFUSERS_VERSION, '-', *DEVICE_TYPE, "-", *TAG_PYTHON_VERSION, "-", *OS_VERSION, "-sagemaker" ] - docker_file: !join [ docker/, *SHORT_VERSION, /, *DOCKER_PYTHON_VERSION, /Dockerfile., *DEVICE_TYPE ] - target: sagemaker - context: - <<: *INFERENCE_CONTEXT BuildStabilityaiPytorchGpuPy310InferenceDockerImage: <<: *INFERENCE_REPOSITORY build: &STABILITYAI_PYTORCH_GPU_INFERENCE_PY3 false @@ -52,10 +29,9 @@ images: tag_python_version: &TAG_PYTHON_VERSION py310 cuda_version: &CUDA_VERSION cu118 os_version: &OS_VERSION ubuntu20.04 - diffusers_version: &DIFFUSERS_VERSION 1.2.3 - tag: !join [ *VERSION, "-", 'diffusers',*DIFFUSERS_VERSION, '-', *DEVICE_TYPE, "-", *TAG_PYTHON_VERSION, "-", *CUDA_VERSION, "-", *OS_VERSION, "-sagemaker" ] + sgm_version: &SGM_VERSION 0.1.0 + tag: !join [ *VERSION, "-", 'sgm',*SGM_VERSION, '-', *DEVICE_TYPE, "-", *TAG_PYTHON_VERSION, "-", *CUDA_VERSION, "-", *OS_VERSION, "-sagemaker" ] docker_file: !join [ docker/, *SHORT_VERSION, /, *DOCKER_PYTHON_VERSION, /, *CUDA_VERSION, /Dockerfile., *DEVICE_TYPE ] - target: sagemaker context: - <<: *INFERENCE_CONTEXT + <<: *INFERENCE_CONTEXT diff --git a/stabilityai/pytorch/inference/docker/2.0/py3/cu118/Dockerfile.gpu b/stabilityai/pytorch/inference/docker/2.0/py3/cu118/Dockerfile.gpu new file mode 100644 index 000000000000..028b50e42486 --- /dev/null +++ b/stabilityai/pytorch/inference/docker/2.0/py3/cu118/Dockerfile.gpu @@ -0,0 +1,47 @@ +FROM 763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:2.0.1-gpu-py310-cu118-ubuntu20.04-sagemaker + +LABEL dlc_major_version="1" +ARG PYTHON=python3 +ARG XFORMERS_VERSION=0.0.20 + +# xformers must be installed from source due to the older version of python in the DLC +RUN pip install ninja \ + && pip install -v -U git+https://github.com/facebookresearch/xformers.git@v${XFORMERS_VERSION}#egg=xformers + +ARG SGM_VERSION=0.1.0 + +# Install Stability Generative Models, at the moment the wheel install does not work so we need the full repo +RUN cd /tmp \ + && git clone https://github.com/stability-ai/generative-models -b ${SGM_VERSION} \ + && cd generative-models \ + && pip install -r requirements/pt2.txt \ + && pip install . 
\ + && rm -rf /tmp/generative-models + +# Resolve pip check conflicts and other issues +RUN pip install --no-cache-dir -U \ + "awscli>=1.29.15" \ + "boto3>=1.28.15" \ + "certifi>=2023.07.22" \ + "pyopenssl>=23.2.0" \ + "cryptography>=41.0.2" \ + "transformers>=4.23.0" + +# Configure Torchserve for large model loading +ENV TS_DEFAULT_RESPONSE_TIMEOUT=1000 + +# Copy custom entrypoint, which can unpack cache files +ENV HUGGINGFACE_HUB_CACHE=/tmp/cache/huggingface/hub +ENV TRANSFORMERS_CACHE=/tmp/cache/huggingface/transformers +COPY torchserve-entrypoint.py /usr/local/bin/dockerd-entrypoint.py +RUN mkdir -p /tmp/cache/huggingface \ + && chmod +x /usr/local/bin/dockerd-entrypoint.py + +RUN HOME_DIR=/root \ + && curl -o ${HOME_DIR}/oss_compliance.zip https://aws-dlinfra-utilities.s3.amazonaws.com/oss_compliance.zip \ + && unzip ${HOME_DIR}/oss_compliance.zip -d ${HOME_DIR}/ \ + && cp ${HOME_DIR}/oss_compliance/test/testOSSCompliance /usr/local/bin/testOSSCompliance \ + && chmod +x /usr/local/bin/testOSSCompliance \ + && chmod +x ${HOME_DIR}/oss_compliance/generate_oss_compliance.sh \ + && ${HOME_DIR}/oss_compliance/generate_oss_compliance.sh ${HOME_DIR} ${PYTHON} \ + && rm -rf ${HOME_DIR}/oss_compliance* diff --git a/stabilityai/pytorch/inference/docker/build_artifacts/torchserve-stabilityai-entrypoint.py b/stabilityai/pytorch/inference/docker/build_artifacts/torchserve-stabilityai-entrypoint.py new file mode 100644 index 000000000000..851b537c3b28 --- /dev/null +++ b/stabilityai/pytorch/inference/docker/build_artifacts/torchserve-stabilityai-entrypoint.py @@ -0,0 +1,48 @@ +# Copyright 2019-2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import os +import shlex +import subprocess +import sys + +from sagemaker_inference import environment + +SAI_MODEL_CACHE_FILE = os.path.join( + environment.model_dir, os.getenv("SAI_MODEL_CACHE_FILE", "stabilityai-model-cache.tar") +) +SAI_MODEL_CACHE_PATH = os.getenv("SAI_MODEL_CACHE_PATH", "/tmp/cache") +SAI_MODEL_CACHE_STATUS_FILE = os.path.join(SAI_MODEL_CACHE_PATH, ".model-cache-unpacked") +if os.path.exists(SAI_MODEL_CACHE_FILE) and not os.path.exists(SAI_MODEL_CACHE_STATUS_FILE): + subprocess.check_call( + [ + "tar", + "-x", + "-z" if SAI_MODEL_CACHE_FILE.endswith(".gz") else "", + "-f", + SAI_MODEL_CACHE_FILE, + "-C", + SAI_MODEL_CACHE_PATH, + ] + ) + +if sys.argv[1] == "serve": + from sagemaker_pytorch_serving_container import serving + + serving.main() +else: + subprocess.check_call(shlex.split(" ".join(sys.argv[1:]))) + +# prevent docker exit +subprocess.call(["tail", "-f", "/dev/null"]) diff --git a/test/dlc_tests/container_tests/bin/security_checks.py b/test/dlc_tests/container_tests/bin/security_checks.py index f5c7828f2427..dad2635e972d 100644 --- a/test/dlc_tests/container_tests/bin/security_checks.py +++ b/test/dlc_tests/container_tests/bin/security_checks.py @@ -3,15 +3,19 @@ import os import time import calendar +import argparse LOGGER = logging.getLogger(__name__) logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) def main(): + parser = argparse.ArgumentParser() + parser.add_argument("-i", "--image_uri", help="Provide Image Uri", default="") + args = parser.parse_args() home_dir = os.path.expanduser("~") check_that_cache_dir_is_removed(home_dir) - check_that_global_tmp_dir_is_empty() + check_that_global_tmp_dir_is_empty(image_uri=args.image_uri) check_vim_info_does_not_exists(home_dir) check_bash_history(home_dir) check_if_any_files_in_subfolder_with_mask_was_last_modified_before_the_boottime( @@ -49,7 +53,7 @@ def check_that_cache_dir_is_removed(home_dir): ) -def check_that_global_tmp_dir_is_empty(): +def check_that_global_tmp_dir_is_empty(image_uri=""): global_tmp_dir_path = "/tmp/" global_tmp_dir_content = [f for f in os.listdir(global_tmp_dir_path)] for f in global_tmp_dir_content: @@ -60,6 +64,8 @@ def check_that_global_tmp_dir_is_empty(): and "ccNPSUr9.s" not in f and "hsperfdata" not in f ): + if "stabilityai" in image_uri and "cache" in f.lower(): + continue raise ValueError( "/tmp folder includes file that probably should not be there: {}".format(f) ) diff --git a/test/dlc_tests/ec2/pytorch/inference/test_pytorch_inference.py b/test/dlc_tests/ec2/pytorch/inference/test_pytorch_inference.py index 7a48ad67386f..9900dc18ea32 100644 --- a/test/dlc_tests/ec2/pytorch/inference/test_pytorch_inference.py +++ b/test/dlc_tests/ec2/pytorch/inference/test_pytorch_inference.py @@ -133,7 +133,7 @@ def test_ec2_pytorch_inference_eia_gpu( @pytest.mark.usefixtures("feature_torchaudio_present") -@pytest.mark.usefixtures("sagemaker") +@pytest.mark.usefixtures("sagemaker", "stabilityai") @pytest.mark.integration("pt_torchaudio_gpu") @pytest.mark.model("N/A") @pytest.mark.parametrize("ec2_instance_type", PT_EC2_GPU_INSTANCE_TYPE, indirect=True) @@ -163,7 +163,7 @@ def test_pytorch_inference_torchaudio_cpu(pytorch_inference, ec2_connection, cpu @pytest.mark.usefixtures("feature_torchdata_present") -@pytest.mark.usefixtures("sagemaker") +@pytest.mark.usefixtures("sagemaker", "stabilityai") @pytest.mark.integration("pt_torchdata_gpu") @pytest.mark.model("N/A") @pytest.mark.parametrize("ec2_instance_type", PT_EC2_GPU_INSTANCE_TYPE, 
indirect=True) @@ -246,7 +246,7 @@ def ec2_pytorch_inference(image_uri, processor, ec2_connection, region): ec2_connection.run(f"docker rm -f {container_name}", warn=True, hide=True) -@pytest.mark.usefixtures("sagemaker") +@pytest.mark.usefixtures("sagemaker", "stabilityai") @pytest.mark.integration("telemetry") @pytest.mark.model("N/A") @pytest.mark.parametrize("ec2_instance_type", PT_EC2_SINGLE_GPU_INSTANCE_TYPE, indirect=True) diff --git a/test/dlc_tests/sanity/test_boottime_container_security.py b/test/dlc_tests/sanity/test_boottime_container_security.py index 816c14a9de00..0156a9ae8cc8 100644 --- a/test/dlc_tests/sanity/test_boottime_container_security.py +++ b/test/dlc_tests/sanity/test_boottime_container_security.py @@ -20,6 +20,6 @@ def test_security(image): ) try: docker_exec_cmd = f"docker exec -i {container_name}" - run(f"{docker_exec_cmd} python /test/bin/security_checks.py ", hide=True) + run(f"{docker_exec_cmd} python /test/bin/security_checks.py --image_uri {image}", hide=True) finally: run(f"docker rm -f {container_name}", hide=True) diff --git a/test/dlc_tests/sanity/test_pre_release.py b/test/dlc_tests/sanity/test_pre_release.py index 0a00d5b3f7cf..79f70c72c34c 100644 --- a/test/dlc_tests/sanity/test_pre_release.py +++ b/test/dlc_tests/sanity/test_pre_release.py @@ -65,6 +65,10 @@ def test_stray_files(image): # Running list of allowed files in the /tmp directory allowed_tmp_files = ["hsperfdata_root"] + # Allow cache dir for SAI images + if "stabilityai" in image: + allowed_tmp_files.append("cache") + # Ensure stray artifacts are not in the tmp directory tmp = run_cmd_on_container(container_name, ctx, "ls -A /tmp") _assert_artifact_free(tmp, stray_artifacts) @@ -716,7 +720,7 @@ def test_cuda_paths(gpu): python_version = re.search(r"(py\d+)", image).group(1) short_python_version = None image_tag = re.search( - r":(\d+(\.\d+){2}(-(transformers|diffusers)\d+(\.\d+){2})?-(gpu)-(py\d+)(-cu\d+)-(ubuntu\d+\.\d+)((-ec2)?-example|-ec2|-sagemaker-lite|-sagemaker-full|-sagemaker)?)", + r":(\d+(\.\d+){2}(-(transformers|diffusers|sgm)\d+(\.\d+){2})?-(gpu)-(py\d+)(-cu\d+)-(ubuntu\d+\.\d+)((-ec2)?-example|-ec2|-sagemaker-lite|-sagemaker-full|-sagemaker)?)", image, ).group(1) diff --git a/test/sagemaker_tests/pytorch/inference/integration/__init__.py b/test/sagemaker_tests/pytorch/inference/integration/__init__.py index 26be82de15e7..d85f10c8825f 100644 --- a/test/sagemaker_tests/pytorch/inference/integration/__init__.py +++ b/test/sagemaker_tests/pytorch/inference/integration/__init__.py @@ -46,6 +46,11 @@ resnet_neuronx_image_list = os.path.join(model_neuronx_dir, "imagenet1000_clsidx_to_labels.txt") call_model_fn_once_script = os.path.join(resources_path, code_sub_dir, "call_model_fn_once.py") +stabilityai_path = os.path.join(resources_path, "stabilityai") +sdxl_path = os.path.join(stabilityai_path, "sdxl-v1") +sdxl_gpu_path = os.path.join(sdxl_path, gpu_sub_dir) +sdxl_gpu_script = os.path.join(sdxl_gpu_path, code_sub_dir, "sdxl_inference.py") + ROLE = "dummy/unused-role" DEFAULT_TIMEOUT = 20 diff --git a/test/sagemaker_tests/pytorch/inference/integration/sagemaker/test_stabilityai.py b/test/sagemaker_tests/pytorch/inference/integration/sagemaker/test_stabilityai.py index f306f7960dee..d0edf45c847a 100644 --- a/test/sagemaker_tests/pytorch/inference/integration/sagemaker/test_stabilityai.py +++ b/test/sagemaker_tests/pytorch/inference/integration/sagemaker/test_stabilityai.py @@ -1,21 +1,145 @@ from __future__ import absolute_import +import os import sys +from io import BytesIO +from 
PIL import Image import pytest +import sagemaker +from sagemaker.pytorch import PyTorchModel +from sagemaker.serializers import JSONSerializer +from sagemaker.deserializers import BytesDeserializer + +import time import logging +from ...integration import sdxl_gpu_path, sdxl_gpu_script + + +from .timeout import timeout_and_delete_endpoint +from .... import invoke_pytorch_helper_function LOGGER = logging.getLogger(__name__) LOGGER.setLevel(logging.INFO) LOGGER.addHandler(logging.StreamHandler(sys.stdout)) -@pytest.mark.model("mnist") +@pytest.mark.model("sdxl") @pytest.mark.processor("gpu") @pytest.mark.gpu_test @pytest.mark.stabilityai_only -def test_mnist_distributed_gpu_stabilityai_one( - framework_version, ecr_image, instance_type, sagemaker_regions +def test_sdxl_v1_0_gpu_stabilityai(framework_version, ecr_image, instance_type, sagemaker_regions): + instance_type = "ml.g5.4xlarge" + model_bucket = "stabilityai-public-packages" + model_prefix = "model-packages/sdxl-v1-0-dlc" + model_file = "model.tar.gz" + inference_request = { + "text_prompts": [{"text": "A wonderous machine creating images"}], + "height": 1024, + "width": 1024, + } + function_args = { + "framework_version": framework_version, + "instance_type": instance_type, + "model_bucket": model_bucket, + "model_prefix": model_prefix, + "model_file": model_file, + "sdxl_script": sdxl_gpu_script, + "inference_request": inference_request, + } + invoke_pytorch_helper_function(ecr_image, sagemaker_regions, _test_sdxl_v1_0, function_args) + + +def _test_sdxl_v1_0( + ecr_image, + sagemaker_session, + framework_version, + instance_type, + model_bucket, + model_prefix, + model_file, + sdxl_script, + inference_request, + verify_logs=True, ): - raise Exception("Deliberate error 1") + endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-pytorch-serving") + + LOGGER.info(f"Downloading s3://{model_bucket}{model_prefix} to {sdxl_gpu_path}") + sagemaker_session.download_data( + path=sdxl_gpu_path, bucket=model_bucket, key_prefix=f"{model_prefix}/{model_file}" + ) + + model_data = sagemaker_session.upload_data( + path=os.path.join(sdxl_gpu_path, model_file), key_prefix="sagemaker-pytorch-serving/models" + ) + + pytorch = PyTorchModel( + model_data=model_data, + role="SageMakerRole", + framework_version=framework_version, + image_uri=ecr_image, + sagemaker_session=sagemaker_session, + entry_point=sdxl_script, + ) + + with timeout_and_delete_endpoint(endpoint_name, sagemaker_session, minutes=60): + predictor = pytorch.deploy( + initial_instance_count=1, + instance_type=instance_type, + endpoint_name=endpoint_name, + serializer=JSONSerializer(), + deserializer=BytesDeserializer(accept="image/png"), + ) + + # Model loading can take up to 5 min so we must wait + time.sleep(60 * 5) + + output = predictor.predict(inference_request) + + image = Image.open(BytesIO(output)) + assert image.height == inference_request["height"] + assert image.width == inference_request["width"] + + # Check for Cloudwatch logs + if verify_logs: + _check_for_cloudwatch_logs(endpoint_name, sagemaker_session) + + +def _check_for_cloudwatch_logs(endpoint_name, sagemaker_session): + client = sagemaker_session.boto_session.client("logs") + log_group_name = f"/aws/sagemaker/Endpoints/{endpoint_name}" + time.sleep(30) + identify_log_stream = client.describe_log_streams( + logGroupName=log_group_name, orderBy="LastEventTime", descending=True, limit=5 + ) + + try: + log_stream_name = identify_log_stream["logStreams"][0]["logStreamName"] + except IndexError as e: + raise 
RuntimeError( + f"Unable to look up log streams for the log group {log_group_name}" + ) from e + + log_events_response = client.get_log_events( + logGroupName=log_group_name, logStreamName=log_stream_name, limit=50, startFromHead=True + ) + + records_available = bool(log_events_response["events"]) + + if not records_available: + raise RuntimeError( + f"records_available variable is false... No cloudwatch events getting logged for the group {log_group_name}" + ) + else: + LOGGER.info( + f"Most recently logged events were found for the given log group {log_group_name} & log stream {log_stream_name}... Now verifying that TorchServe endpoint is logging on cloudwatch" + ) + check_for_torchserve_response = client.filter_log_events( + logGroupName=log_group_name, + logStreamNames=[log_stream_name], + filterPattern="Torch worker started.", + limit=10, + interleaved=False, + ) + assert bool(check_for_torchserve_response["events"]) diff --git a/test/sagemaker_tests/pytorch/inference/resources/stabilityai/sdxl-v1/model_gpu/.gitignore b/test/sagemaker_tests/pytorch/inference/resources/stabilityai/sdxl-v1/model_gpu/.gitignore new file mode 100644 index 000000000000..9ea69925795a --- /dev/null +++ b/test/sagemaker_tests/pytorch/inference/resources/stabilityai/sdxl-v1/model_gpu/.gitignore @@ -0,0 +1 @@ +model.tar.gz \ No newline at end of file diff --git a/test/sagemaker_tests/pytorch/inference/resources/stabilityai/sdxl-v1/model_gpu/README.md b/test/sagemaker_tests/pytorch/inference/resources/stabilityai/sdxl-v1/model_gpu/README.md new file mode 100644 index 000000000000..b033497ee42c --- /dev/null +++ b/test/sagemaker_tests/pytorch/inference/resources/stabilityai/sdxl-v1/model_gpu/README.md @@ -0,0 +1,3 @@ +# Stability AI DLC SDXL Integration Test Resources + +Due to the size of the models, these tests download from S3 then re-upload. 
\ No newline at end of file diff --git a/test/sagemaker_tests/pytorch/inference/resources/stabilityai/sdxl-v1/model_gpu/code/sdxl_inference.py b/test/sagemaker_tests/pytorch/inference/resources/stabilityai/sdxl-v1/model_gpu/code/sdxl_inference.py new file mode 100644 index 000000000000..dd72df351db7 --- /dev/null +++ b/test/sagemaker_tests/pytorch/inference/resources/stabilityai/sdxl-v1/model_gpu/code/sdxl_inference.py @@ -0,0 +1,188 @@ +import base64 +from io import BytesIO +from einops import rearrange +import json +from PIL import Image +from pytorch_lightning import seed_everything +import numpy as np +from sagemaker_inference.errors import BaseInferenceToolkitError +import sgm +from sgm.inference.api import ( + ModelArchitecture, + SamplingParams, + SamplingPipeline, + Sampler, +) +from sgm.inference.helpers import get_input_image_tensor, embed_watermark +import os + + +def model_fn(model_dir, context=None): + # Enable the refiner by default + disable_refiner = os.environ.get("SDXL_DISABLE_REFINER", "false").lower() == "true" + + sgm_path = os.path.dirname(sgm.__file__) + config_path = os.path.join(sgm_path, "configs/inference") + base_pipeline = SamplingPipeline( + ModelArchitecture.SDXL_V1_BASE, model_path=model_dir, config_path=config_path + ) + if disable_refiner: + print("Refiner model disabled by SDXL_DISABLE_REFINER environment variable") + refiner_pipeline = None + else: + refiner_pipeline = SamplingPipeline( + ModelArchitecture.SDXL_V1_REFINER, + model_path=model_dir, + config_path=config_path, + ) + + return {"base": base_pipeline, "refiner": refiner_pipeline} + + +def input_fn(request_body, request_content_type): + if request_content_type == "application/json": + model_input = json.loads(request_body) + if not "text_prompts" in model_input: + raise BaseInferenceToolkitError(400, "Invalid Request", "text_prompts missing") + return model_input + else: + raise BaseInferenceToolkitError( + 400, "Invalid Request", "Content-type must be application/json" + ) + + +def predict_fn(data, model, context=None): + # Only a single positive and optionally a single negative prompt are supported by this example. 
+ prompts = [] + negative_prompts = [] + if "text_prompts" in data: + for text_prompt in data["text_prompts"]: + if "text" not in text_prompt: + raise BaseInferenceToolkitError( + 400, "Invalid Request", "text missing from text_prompt" + ) + if "weight" not in text_prompt: + text_prompt["weight"] = 1.0 + if text_prompt["weight"] < 0: + negative_prompts.append(text_prompt["text"]) + else: + prompts.append(text_prompt["text"]) + + if len(prompts) != 1: + raise BaseInferenceToolkitError( + 400, + "Invalid Request", + "One prompt with positive or default weight must be supplied", + ) + if len(negative_prompts) > 1: + raise BaseInferenceToolkitError( + 400, "Invalid Request", "Only one negative weighted prompt can be supplied" + ) + + seed = 0 + height = 1024 + width = 1024 + sampler_name = "DPMPP2MSampler" + cfg_scale = 7.0 + steps = 50 + use_pipeline = model["refiner"] is not None + init_image = None + image_strength = 0.35 + + if "height" in data: + height = data["height"] + if "width" in data: + width = data["width"] + if "sampler" in data: + sampler_name = data["sampler"] + if "cfg_scale" in data: + cfg_scale = data["cfg_scale"] + if "steps" in data: + steps = data["steps"] + if "seed" in data: + seed = data["seed"] + seed_everything(seed) + if "use_pipeline" in data: + use_pipeline = data["use_pipeline"] + if "init_image" in data: + if "image_strength" in data: + image_strength = data["image_strength"] + try: + init_image_bytes = BytesIO(base64.b64decode(data["init_image"])) + init_image_bytes.seek(0) + if init_image_bytes is not None: + init_image = get_input_image_tensor(Image.open(init_image_bytes)) + except Exception as e: + raise BaseInferenceToolkitError(400, "Invalid Request", "Unable to decode init_image") + + if model["refiner"] is None and use_pipeline: + raise BaseInferenceToolkitError(400, "Invalid Request", "Pipeline is not available") + + try: + if init_image is not None: + img_height, img_width = init_image.shape[2], init_image.shape[3] + output = model["base"].image_to_image( + params=SamplingParams( + width=img_width, + height=img_height, + steps=steps, + sampler=Sampler(sampler_name), + scale=cfg_scale, + img2img_strength=image_strength, + ), + image=init_image, + prompt=prompts[0], + negative_prompt=negative_prompts[0] if len(negative_prompts) > 0 else "", + return_latents=use_pipeline, + ) + else: + output = model["base"].text_to_image( + params=SamplingParams( + width=width, + height=height, + steps=steps, + sampler=Sampler(sampler_name), + scale=cfg_scale, + ), + prompt=prompts[0], + negative_prompt=negative_prompts[0] if len(negative_prompts) > 0 else "", + return_latents=use_pipeline, + ) + + if isinstance(output, (tuple, list)): + samples, samples_z = output + else: + samples = output + samples_z = None + + if use_pipeline and samples_z is not None: + print("Running Refinement Stage") + samples = model["refiner"].refiner( + params=SamplingParams( + steps=50, sampler=Sampler.EULER_EDM, scale=5.0, img2img_strength=0.3 + ), + image=samples_z, + prompt=prompts[0], + negative_prompt=negative_prompts[0] if len(negative_prompts) > 0 else "", + ) + + samples = embed_watermark(samples) + images = [] + for sample in samples: + sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c") + image_bytes = BytesIO() + Image.fromarray(sample.astype(np.uint8)).save(image_bytes, format="PNG") + image_bytes.seek(0) + images.append(image_bytes.read()) + + return images + + except ValueError as e: + raise BaseInferenceToolkitError(400, "Invalid Request", str(e)) + + +def 
output_fn(prediction, accept): + # This only returns a single image since that's all the example code supports + if accept != "image/png": + raise BaseInferenceToolkitError(400, "Invalid Request", "Accept header must be image/png") + return prediction[0], accept diff --git a/test/testrunner.py b/test/testrunner.py index d6ae624b6899..2cce734b7c6f 100644 --- a/test/testrunner.py +++ b/test/testrunner.py @@ -58,7 +58,7 @@ def run_sagemaker_local_tests(images, pytest_cache_params): return # Run sagemaker Local tests framework, _ = get_framework_and_version_from_tag(images[0]) - framework = framework.replace("_trcomp", "") + framework = framework.replace("_trcomp", "").replace("stabilityai_", "") sm_tests_path = ( os.path.join("test", "sagemaker_tests", framework) if "huggingface" not in framework
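
Note on the model cache mechanism: the custom entrypoint added above (torchserve-stabilityai-entrypoint.py) looks for a stabilityai-model-cache.tar archive at the top level of the SageMaker model directory and unpacks it into /tmp/cache before TorchServe starts, so that HUGGINGFACE_HUB_CACHE and TRANSFORMERS_CACHE are already populated. The following is a minimal packaging sketch, assuming the cache tarball simply contains a "huggingface" directory laid out the way the Dockerfile's cache environment variables expect; the helper name and local paths are illustrative and not part of this patch.

    import tarfile

    def build_model_archive(weights_dir, hf_cache_root, output="model.tar.gz"):
        # Pack the local Hugging Face cache so it extracts as huggingface/...
        # under /tmp/cache, matching HUGGINGFACE_HUB_CACHE=/tmp/cache/huggingface/hub
        # set in the Dockerfile.
        with tarfile.open("stabilityai-model-cache.tar", "w") as cache_tar:
            cache_tar.add(hf_cache_root, arcname="huggingface")

        # The model archive carries the weights/code plus the cache tarball at its
        # top level, which is where the entrypoint looks for it (environment.model_dir).
        with tarfile.open(output, "w:gz") as model_tar:
            model_tar.add(weights_dir, arcname=".")
            model_tar.add("stabilityai-model-cache.tar",
                          arcname="stabilityai-model-cache.tar")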
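
Note on the endpoint contract: the integration test above sends a JSON body of text_prompts/height/width and reads back PNG bytes, and the handler's input_fn and output_fn accept only application/json in and image/png out. A client-side sketch using that same contract follows; the endpoint name and region are placeholders, not values from this patch.

    import json
    import boto3

    runtime = boto3.client("sagemaker-runtime", region_name="us-west-2")  # region is a placeholder

    request = {
        "text_prompts": [{"text": "A wonderous machine creating images"}],
        "height": 1024,
        "width": 1024,
    }

    response = runtime.invoke_endpoint(
        EndpointName="sagemaker-pytorch-serving-example",  # placeholder endpoint name
        ContentType="application/json",  # input_fn only accepts application/json
        Accept="image/png",              # output_fn only returns image/png
        Body=json.dumps(request),
    )

    # The response body is the raw PNG returned by output_fn.
    with open("sdxl-output.png", "wb") as f:
        f.write(response["Body"].read())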