diff --git a/.coveragerc_py27 b/.coveragerc_py27 new file mode 100644 index 00000000..dfbfd3ef --- /dev/null +++ b/.coveragerc_py27 @@ -0,0 +1,20 @@ +[run] +branch = True +timid = True + +[report] +exclude_lines = + pragma: no cover + pragma: py2 no cover + if six.PY3 + elif six.PY3 + +partial_branches = + pragma: no cover + pragma: py2 no cover + if six.PY3 + elif six.PY3 + +show_missing = True + +fail_under = 90 diff --git a/.coveragerc_py35 b/.coveragerc_py35 new file mode 100644 index 00000000..96bb72bf --- /dev/null +++ b/.coveragerc_py35 @@ -0,0 +1,20 @@ +[run] +branch = True +timid = True + +[report] +exclude_lines = + pragma: no cover + pragma: py3 no cover + if six.PY2 + elif six.PY2 + +partial_branches = + pragma: no cover + pragma: py3 no cover + if six.PY3 + elif six.PY3 + +show_missing = True + +fail_under = 90 diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..fdf48402 --- /dev/null +++ b/.flake8 @@ -0,0 +1,3 @@ +[flake8] +application_import_names = local_mode_utils, sagemaker_mxnet_container, test, timeout +import-order-style = google diff --git a/.gitignore b/.gitignore index bb42ad22..3d8c3661 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,7 @@ dist .pytest_cache **/*.pyc **/*.py~ +.tox/* +.coverage +test/resources/local_mode_lock +.idea/* diff --git a/README.rst b/README.rst index dc7ac47b..d055317a 100644 --- a/README.rst +++ b/README.rst @@ -1,18 +1,9 @@ -========================== -SageMaker MXNet Containers -========================== +========================= +SageMaker MXNet Container +========================= -SageMaker MXNet Containers is an open source library for making the -MXNet framework run on Amazon SageMaker. - -This repository also contains Dockerfiles which install this library, MXNet, and dependencies -for building SageMaker MXNet images. - -For information on running MXNet jobs on SageMaker: `Python -SDK `__. - -For notebook examples: `SageMaker Notebook -Examples `__. +SageMaker MXNet Container is an open-source library for making Docker images for using MXNet on Amazon SageMaker. +For information on running MXNet jobs on Amazon SageMaker, please refer to the `SageMaker Python SDK documentation `__. Table of Contents ----------------- @@ -27,102 +18,148 @@ Getting Started Prerequisites ~~~~~~~~~~~~~ -Make sure you have installed all of the following prerequisites on your -development machine: +Make sure you have installed all of the following prerequisites on your development machine: - `Docker `__ - -For Testing on GPU -^^^^^^^^^^^^^^^^^^ - -- `Nvidia-Docker `__ +- For GPU testing: `nvidia-docker2 `__ Recommended ^^^^^^^^^^^ -- A python environment management tool. (e.g. - `PyEnv `__, +- A Python environment management tool (e.g. `PyEnv `__, `VirtualEnv `__) -Building your image -------------------- +Building Images +--------------- -`Amazon SageMaker `__ -utilizes Docker containers to run all training jobs & inference endpoints. +The Dockerfiles in this repository are intended to be used for building Docker images to run training jobs and inference endpoints on `Amazon SageMaker `__. -The Docker images are built from the Dockerfiles specified in -`Docker/ `__. +The current master branch of this repository contains Dockerfiles and support code for MXNet versions 1.3.0 and higher. +For MXNet versions 0.12.1-1.2.1, check out v1.0.0 of this repository. -The Docker files are grouped based on MXNet version and separated -based on Python version and processor type. +For each supported MXNet version, Dockerfiles can be found for each processor type (i.e. CPU and GPU). +For MXNet versions 0.12.1 and 1.0.0, there are separate Dockerfiles for each Python version as well. -The Docker images, used to run training & inference jobs, are built from -both corresponding "base" and "final" Dockerfiles. +All images are tagged with -- (e.g. 1.3.0-cpu-py3). -Base Images -~~~~~~~~~~~ +MXNet 1.1.0 and higher +~~~~~~~~~~~~~~~~~~~~~~ -The "base" Dockerfile encompass the installation of the framework and all of the dependencies -needed. +For these MXNet versions, there is one set of Dockerfiles for each version. +They install the SageMaker-specific support code found in this repository. -Tagging scheme is based on --. (e.g. 0.12.1-cpu-py2) +Before building these images, you need to have two files already saved locally. +The first is a pip-installable binary of the MXNet library. +This can be something you compile from source or `download from PyPI `__. -All "final" Dockerfiles build images using base images that use the tagging scheme -above. +The second is a pip-installable binary of this repository. +To create the SageMaker MXNet Container Python package: + +:: + + # Create the binary + git clone https://github.com/aws/sagemaker-mxnet-container.git + cd sagemaker-mxnet-container + python setup.py sdist + + # Copy your Python package to the appropriate "final" Dockerfile directory + cp dist/sagemaker_mxnet_container-.tar.gz docker//final -If you want to build your base docker image, then use: +Once you have those binaries, you can then build the image. +The Dockerfiles expect two build arguments: + +- ``py_version``: the Python version. +- ``framework_installable``: the path to the MXNet binary + +To build an image: :: - # All build instructions assume you're building from the same directory as the dockerfile. + # All build instructions assume you're building from the same directory as the Dockerfile. # CPU - docker build -t mxnet-base:-cpu- -f Dockerfile.cpu . + docker build -t preprod-mxnet: \ + --build-arg py_version= \ + --build-arg framework_installable= \ + -f Dockerfile.cpu . # GPU - docker build -t mxnet-base:-gpu- -f Dockerfile.gpu . + docker build -t preprod-mxnet: \ + --build-arg py_version= \ + --build-arg framework_installable= \ + -f Dockerfile.gpu . + +Don't forget the period at the end of the command! :: # Example # CPU - docker build -t mxnet-base:0.12.1-cpu-py2 -f Dockerfile.cpu . + docker build -t preprod-mxnet:1.1.0-cpu-py3 --build-arg py_version=3 + --build-arg framework_installable=mxnet-1.1.0-py2.py3-none-manylinux1_x86_64.whl -f Dockerfile.cpu . # GPU - docker build -t mxnet-base:0.12.1-gpu-py2 -f Dockerfile.gpu . + docker build -t preprod-mxnet:1.1.0-gpu-py3 --build-arg py_version=3 + --build-arg framework_installable=mxnet-1.1.0-py2.py3-none-manylinux1_x86_64.whl -f Dockerfile.gpu . -Final Images -~~~~~~~~~~~~ -The "final" Dockerfiles encompass the installation of the SageMaker specific support code. +MXNet 0.12.1 and 1.0.0 +~~~~~~~~~~~~~~~~~~~~~~ + +For these MXNet versions, there are "base" and "final" Dockerfiles for each image. +The "base" Dockerfile installs MXNet and its necessary dependencies. +The "final" Dockerfile installs the SageMaker-specific support code found in this repository. + +Base Images +^^^^^^^^^^^ + +To build a "base" image: -All "final" Dockerfiles use `base images for building `__. +:: -These "base" images are specified with the naming convention of -mxnet-base:--. + # All build instructions assume you're building from the same directory as the Dockerfile. -Before building "final" images: + # CPU + docker build -t mxnet-base:-cpu- -f Dockerfile.cpu . -Build your "base" image. Make sure it is named and tagged in accordance with your "final" -Dockerfile. + # GPU + docker build -t mxnet-base:-gpu- -f Dockerfile.gpu . +:: + + # Example + + # CPU + docker build -t mxnet-base:0.12.1-cpu-py2 -f Dockerfile.cpu . + + # GPU + docker build -t mxnet-base:0.12.1-gpu-py2 -f Dockerfile.gpu . + +Final Images +^^^^^^^^^^^^ + +All "final" Dockerfiles assume the "base" image has already been built. +Make sure the "base" image is named and tagged as expected by the "final" Dockerfile. + +In addition, the "final" Dockerfiles require a pip-installable binary of this repository. +To create the SageMaker MXNet Container Python package: :: - # Create the SageMaker MXNet Container Python package. - cd sagemaker-mxnet-containers + # Create the binary + git clone -b v1.0.0 https://github.com/aws/sagemaker-mxnet-container.git + cd sagemaker-mxnet-container python setup.py sdist - #. Copy your Python package to "final" Dockerfile directory that you are building. + # Copy your Python package to the appropriate "final" Dockerfile directory cp dist/sagemaker_mxnet_container-.tar.gz docker//final -If you want to build "final" Docker images, then use: +To build a "final" image: :: - # All build instructions assumes you're building from the same directory as the dockerfile. + # All build instructions assumes you're building from the same directory as the Dockerfile. # CPU docker build -t : -f Dockerfile.cpu . @@ -140,125 +177,111 @@ If you want to build "final" Docker images, then use: # GPU docker build -t preprod-mxnet:0.12.1-gpu-py2 -f Dockerfile.gpu . - # For building images of MXNet versions 1.1 and above - docker build -t preprod-mxnet:1.1.0-cpu-py2 --build-arg py_version=2 - --build-arg framework_installable=mxnet-1.1.0-py2.py3-none-manylinux1_x86_64.whl -f Dockerfile.cpu . - Running the tests ----------------- -Running the tests requires installation of the SageMaker MXNet Container code and its test -dependencies. +Running the tests requires installation of the SageMaker MXNet Container code and its test dependencies. :: - git clone https://github.com/aws/sagemaker-mxnet-containers.git - cd sagemaker-mxnet-containers + git clone https://github.com/aws/sagemaker-mxnet-container.git + cd sagemaker-mxnet-container pip install -e .[test] -Tests are defined in -`test/ `__ -and include unit, integration and functional tests. +Tests are defined in `test/ `__ and include unit and integration tests. +The integration tests include both running the Docker containers locally and running them on SageMaker. +The tests are compatible with only the Docker images built by Dockerfiles in the current branch. +If you want to run tests for MXNet versions 1.2.1 or below, please use the v1.0.0 tests. + +All test instructions should be run from the top level directory Unit Tests ~~~~~~~~~~ -If you want to run unit tests, then use: +To run unit tests: :: - # All test instructions should be run from the top level directory - pytest test/unit -Integration Tests -~~~~~~~~~~~~~~~~~ - -Running integration tests require `Docker `__ and `AWS -credentials `__, -as the integration tests make calls to a couple AWS services. The integration and functional -tests require configurations specified within their respective -`conftest.py `__. +Local Integration Tests +~~~~~~~~~~~~~~~~~~~~~~~ -Integration tests on GPU require `Nvidia-Docker `__. +Running local integration tests require `Docker `__ and `AWS credentials `__, +as the integration tests make calls to a couple AWS services. +Local integration tests on GPU require `nvidia-docker2 `__. +You Docker image must also be built in order to run the tests against it. -Before running integration tests: +Local integration tests use the following pytest arguments: -#. Build your Docker image. -#. Pass in the correct pytest arguments to run tests against your Docker image. +- ``docker-base-name``: the Docker image's repository. Defaults to 'preprod-mxnet'. +- ``framework-version``: the MXNet version. Defaults to the latest supported version. +- ``py-version``: the Python version. Defaults to '3'. +- ``processor``: CPU or GPU. Defaults to 'cpu'. +- ``tag``: the Docker image's tag. Defaults to --py -If you want to run local integration tests, then use: +To run local integration tests: :: - # Required arguments for integration tests are found in test/integ/conftest.py - - pytest test/integ --docker-base-name \ - --tag \ - --py-version <2_or_3> \ - --framework-version \ - --processor + pytest test/integration/local --docker-base-name \ + --tag \ + --py-version <2_or_3> \ + --framework-version \ + --processor :: # Example - pytest test/integ --docker-base-name preprod-mxnet \ - --tag 1.0 \ - --py-version 2 \ - --framework-version 0.12.1 \ - --processor cpu + pytest test/integration/local --docker-base-name preprod-mxnet \ + --tag 1.3.0-cpu-py3 \ + --py-version 3 \ + --framework-version 1.3.0 \ + --processor cpu -Functional Tests -~~~~~~~~~~~~~~~~ +SageMaker Integration Tests +~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Functional tests require your Docker image to be within an `Amazon ECR repository `__. +SageMaker integration tests require your Docker image to be within an `Amazon ECR repository `__. -The Docker-base-name is your `ECR repository namespace `__. +SageMaker integration tests use the following pytest arguments: -The instance-type is your specified `Amazon SageMaker Instance Type -`__ that the functional test will run on. +- ``docker-base-name``: the Docker image's `ECR repository namespace `__. +- ``framework-version``: the MXNet version. Defaults to the latest supported version. +- ``py-version``: the Python version. Defaults to '3'. +- ``processor``: CPU or GPU. Defaults to 'cpu'. +- ``tag``: the Docker image's tag. Defaults to --py +- ``aws-id``: your AWS account ID. +- ``instance-type``: the specified `Amazon SageMaker Instance Type `__ that the tests will run on. + Defaults to 'ml.c4.xlarge' for CPU and 'ml.p2.xlarge' for GPU. -Before running functional tests: - -#. Build your Docker image. -#. Push the image to your ECR repository. -#. Pass in the correct pytest arguments to run tests on SageMaker against the image within your ECR repository. - -If you want to run a functional end to end test on `Amazon -SageMaker `__, then use: +To run SageMaker integration tests: :: - # Required arguments for integration tests are found in test/functional/conftest.py - - pytest test/functional --aws-id \ - --docker-base-name \ - --instance-type \ - --tag \ + pytest test/integration/sagmaker --aws-id \ + --docker-base-name \ + --instance-type \ + --tag \ :: # Example - pytest test/functional --aws-id 12345678910 \ - --docker-base-name preprod-mxnet \ - --instance-type ml.m4.xlarge \ - --tag 1.0 + pytest test/integration/sagemaker --aws-id 12345678910 \ + --docker-base-name preprod-mxnet \ + --instance-type ml.m4.xlarge \ + --tag 1.3.0-cpu-py3 Contributing ------------ -Please read -`CONTRIBUTING.md `__ -for details on our code of conduct, and the process for submitting pull -requests to us. +Please read `CONTRIBUTING.md `__ +for details on our code of conduct, and the process for submitting pull requests to us. License ------- -SageMaker MXNet Containers is licensed under the Apache 2.0 License. It is copyright 2018 Amazon -.com, Inc. or its affiliates. All Rights Reserved. The license is available at: -http://aws.amazon.com/apache2.0/ +SageMaker MXNet Containers is licensed under the Apache 2.0 License. +It is copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +The license is available at: http://aws.amazon.com/apache2.0/ diff --git a/docker/0.12.1/base/Dockerfile.cpu b/docker/0.12.1/base/Dockerfile.cpu deleted file mode 100644 index 79fe0d68..00000000 --- a/docker/0.12.1/base/Dockerfile.cpu +++ /dev/null @@ -1,38 +0,0 @@ -FROM ubuntu:16.04 - -RUN apt-get update && \ - apt-get -y install build-essential libopencv-dev libopenblas-dev libjemalloc-dev libgfortran3 \ - python-dev python3-dev git wget curl nginx - -RUN cd /tmp && \ - curl -O https://bootstrap.pypa.io/get-pip.py && \ - python2 get-pip.py && \ - python3 get-pip.py - -COPY patches /patches - -RUN cd /tmp && \ - git clone --recursive https://github.com/apache/incubator-mxnet mxnet && \ - cd /tmp/mxnet && \ - git checkout tags/0.12.1 -b 0.12.1 && git submodule update --init --recursive && \ - git apply --verbose /patches/*.patch && \ - make -j$(nproc) USE_BLAS=openblas USE_MKL2017=1 USE_DIST_KVSTORE=1 && \ - cd /tmp/mxnet/python && \ - python2 setup.py install && \ - python3 setup.py install && \ - cd / && \ - rm -fr /tmp/mxnet - -# https://stackoverflow.com/questions/29274638/opencv-libdc1394-error-failed-to-initialize-libdc1394 -RUN ln -s /dev/null /dev/raw1394 - -RUN cd /tmp && \ - curl -O https://dl.influxdata.com/telegraf/releases/telegraf_1.4.2-1_amd64.deb && \ - dpkg -i telegraf_1.4.2-1_amd64.deb && \ - rm telegraf_1.4.2-1_amd64.deb - -ENV PYTHONDONTWRITEBYTECODE=1 \ - PYTHONUNBUFFERED=1 \ - LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/lib" - -WORKDIR / diff --git a/docker/0.12.1/base/Dockerfile.gpu b/docker/0.12.1/base/Dockerfile.gpu deleted file mode 100644 index 25b82ffd..00000000 --- a/docker/0.12.1/base/Dockerfile.gpu +++ /dev/null @@ -1,45 +0,0 @@ -FROM nvidia/cuda:9.0-cudnn7-devel - -RUN apt-get update && \ - apt-get -y install build-essential libopencv-dev libatlas-base-dev libcurl4-openssl-dev libgtest-dev \ - libjemalloc-dev cmake python-dev python3-dev python-opencv unzip git wget curl nginx - -# install pip -RUN cd /tmp && \ - curl -O https://bootstrap.pypa.io/get-pip.py && \ - python2 get-pip.py && \ - python3 get-pip.py - -COPY patches /patches - -# build mxnet -# from https://github.com/apache/incubator-mxnet/blob/master/docker/install/cpp.sh -RUN cd /usr/src/gtest && cmake CMakeLists.txt && make && cp *.a /usr/lib - -# using (future) version of: -# https://github.com/apache/incubator-mxnet/blob/master/docker/Dockerfiles/Dockerfile.in.lib.gpu -RUN cd /tmp && \ - git clone --recursive https://github.com/apache/incubator-mxnet mxnet && cd mxnet && \ - git checkout tags/0.12.1 -b 0.12.1 && git submodule update --init --recursive && \ - git apply --verbose /patches/*.patch && \ - make -j$(nproc) USE_CUDA=1 USE_CUDA_PATH=/usr/local/cuda USE_CUDNN=1 USE_DIST_KVSTORE=1 && \ - cd /tmp/mxnet/python && \ - python2 setup.py install && \ - python3 setup.py install && \ - cd / && \ - rm -fr /tmp/mxnet - -# https://stackoverflow.com/questions/29274638/opencv-libdc1394-error-failed-to-initialize-libdc1394 -RUN ln -s /dev/null /dev/raw1394 - -# install telegraf -RUN cd /tmp && \ - curl -O https://dl.influxdata.com/telegraf/releases/telegraf_1.4.2-1_amd64.deb && \ - dpkg -i telegraf_1.4.2-1_amd64.deb && \ - rm telegraf_1.4.2-1_amd64.deb - -ENV PYTHONDONTWRITEBYTECODE=1 \ - PYTHONUNBUFFERED=1 \ - LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/lib" - -WORKDIR / diff --git a/docker/0.12.1/base/patches/0001-Python-3-compatiblity-during-optimizer-serialization.patch b/docker/0.12.1/base/patches/0001-Python-3-compatiblity-during-optimizer-serialization.patch deleted file mode 100644 index f75f36d2..00000000 --- a/docker/0.12.1/base/patches/0001-Python-3-compatiblity-during-optimizer-serialization.patch +++ /dev/null @@ -1,25 +0,0 @@ -From c462d343b0782e82a7a91d3dcccc3863bb947dff Mon Sep 17 00:00:00 2001 -From: Tobias Domhan -Date: Sat, 25 Nov 2017 23:02:49 +0100 -Subject: [PATCH] Python 3 compatiblity during optimizer serialization. (#8334) - ---- - python/mxnet/kvstore.py | 2 +- - 1 file changed, 1 insertion(+), 1 deletion(-) - -diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py -index d068d06..ad598d0 100644 ---- a/python/mxnet/kvstore.py -+++ b/python/mxnet/kvstore.py -@@ -451,7 +451,7 @@ class KVStore(object): - # send the optimizer to server - try: - # use ASCII protocol 0, might be slower, but not a big ideal -- optim_str = pickle.dumps(optimizer, 0) -+ optim_str = py_str(pickle.dumps(optimizer, 0)) - except: - raise - self._send_command_to_servers(0, optim_str) --- -2.7.3.AMZN - diff --git a/docker/0.12.1/final/py2/Dockerfile.cpu b/docker/0.12.1/final/py2/Dockerfile.cpu deleted file mode 100644 index d528843d..00000000 --- a/docker/0.12.1/final/py2/Dockerfile.cpu +++ /dev/null @@ -1,15 +0,0 @@ -# Use local version of image built from Dockerfile.cpu in /docker/$version/base directory -FROM mxnet-base:0.12.1-cpu-py2 -ARG mxnet_support_tar=sagemaker_mxnet_container-1.0.0.tar.gz - -RUN pip2 install --no-cache \ - numpy==1.13.3 - -COPY $mxnet_support_tar . - -RUN pip2 install $mxnet_support_tar - -RUN rm $mxnet_support_tar - -# Entrypoint script comes from sagemaker_container_support -ENTRYPOINT ["/usr/bin/python2", "/usr/local/bin/entry.py"] diff --git a/docker/0.12.1/final/py2/Dockerfile.gpu b/docker/0.12.1/final/py2/Dockerfile.gpu deleted file mode 100644 index e186c0d1..00000000 --- a/docker/0.12.1/final/py2/Dockerfile.gpu +++ /dev/null @@ -1,15 +0,0 @@ -# Use local version of image built from Dockerfile.gpu in /docker/$version/base directory -FROM mxnet-base:0.12.1-gpu-py2 -ARG mxnet_support_tar=sagemaker_mxnet_container-1.0.0.tar.gz - -RUN pip2 install --no-cache \ - numpy==1.13.3 - -COPY $mxnet_support_tar . - -RUN pip2 install $mxnet_support_tar - -RUN rm $mxnet_support_tar - -# Entrypoint script comes from sagemaker_container_support -ENTRYPOINT ["/usr/bin/python2", "/usr/local/bin/entry.py"] diff --git a/docker/0.12.1/final/py3/Dockerfile.cpu b/docker/0.12.1/final/py3/Dockerfile.cpu deleted file mode 100644 index 09fc3151..00000000 --- a/docker/0.12.1/final/py3/Dockerfile.cpu +++ /dev/null @@ -1,15 +0,0 @@ -# Use local version of image built from Dockerfile.cpu in /docker/$version/base directory -FROM mxnet-base:0.12.1-cpu-py3 -ARG mxnet_support_tar=sagemaker_mxnet_container-1.0.0.tar.gz - -RUN pip3 install --no-cache \ - numpy==1.13.3 - -COPY $mxnet_support_tar . - -RUN pip3 install $mxnet_support_tar - -RUN rm $mxnet_support_tar - -# Entrypoint script comes from sagemaker_container_support -ENTRYPOINT ["/usr/bin/python3", "/usr/local/bin/entry.py"] diff --git a/docker/0.12.1/final/py3/Dockerfile.gpu b/docker/0.12.1/final/py3/Dockerfile.gpu deleted file mode 100644 index 5fc6dd96..00000000 --- a/docker/0.12.1/final/py3/Dockerfile.gpu +++ /dev/null @@ -1,15 +0,0 @@ -# Use local version of image built from Dockerfile.gpu in /docker/$version/base directory -FROM mxnet-base:0.12.1-gpu-py3 -ARG mxnet_support_tar=sagemaker_mxnet_container-1.0.0.tar.gz - -RUN pip3 install --no-cache \ - numpy==1.13.3 - -COPY $mxnet_support_tar . - -RUN pip3 install $mxnet_support_tar - -RUN rm $mxnet_support_tar - -# Entrypoint script comes from sagemaker_container_support -ENTRYPOINT ["/usr/bin/python3", "/usr/local/bin/entry.py"] diff --git a/docker/1.0.0/base/Dockerfile.cpu b/docker/1.0.0/base/Dockerfile.cpu deleted file mode 100644 index fb3e6207..00000000 --- a/docker/1.0.0/base/Dockerfile.cpu +++ /dev/null @@ -1,41 +0,0 @@ -FROM ubuntu:16.04 - -RUN apt-get update && \ - apt-get -y install build-essential libopencv-dev libopenblas-dev libjemalloc-dev libgfortran3 \ - python-dev python3-dev git wget curl nginx - -RUN cd /tmp && \ - curl -O https://bootstrap.pypa.io/get-pip.py && \ - python2 get-pip.py && \ - python3 get-pip.py - -COPY patches /patches -COPY cpu_patches /cpu_patches - -# Install mxnet for python 2 and python 3 -RUN cd /tmp && \ - git clone --recursive https://github.com/apache/incubator-mxnet mxnet && \ - cd /tmp/mxnet && \ - git checkout tags/1.0.0 -b 1.0.0 && git submodule update --init --recursive && \ - git apply --verbose /patches/*.patch && \ - git apply --verbose /cpu_patches/*.patch && \ - make -j$(nproc) USE_BLAS=openblas USE_MKL2017=1 USE_DIST_KVSTORE=1 && \ - cd /tmp/mxnet/python && \ - python2 setup.py install && \ - python3 setup.py install && \ - cd / && \ - rm -fr /tmp/mxnet - -# https://stackoverflow.com/questions/29274638/opencv-libdc1394-error-failed-to-initialize-libdc1394 -RUN ln -s /dev/null /dev/raw1394 - -RUN cd /tmp && \ - curl -O https://dl.influxdata.com/telegraf/releases/telegraf_1.4.2-1_amd64.deb && \ - dpkg -i telegraf_1.4.2-1_amd64.deb && \ - rm telegraf_1.4.2-1_amd64.deb - -ENV PYTHONDONTWRITEBYTECODE=1 \ - PYTHONUNBUFFERED=1 \ - LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/lib" - -WORKDIR / diff --git a/docker/1.0.0/base/Dockerfile.gpu b/docker/1.0.0/base/Dockerfile.gpu deleted file mode 100644 index b6b4184f..00000000 --- a/docker/1.0.0/base/Dockerfile.gpu +++ /dev/null @@ -1,45 +0,0 @@ -FROM nvidia/cuda:9.0-cudnn7-devel - -RUN apt-get update && \ - apt-get -y install build-essential libopencv-dev libatlas-base-dev libcurl4-openssl-dev libgtest-dev \ - libjemalloc-dev cmake python-dev python3-dev python-opencv unzip git wget curl nginx - -# install pip -RUN cd /tmp && \ - curl -O https://bootstrap.pypa.io/get-pip.py && \ - python2 get-pip.py && \ - python3 get-pip.py - -COPY patches /patches - -# build mxnet -# from https://github.com/apache/incubator-mxnet/blob/master/docker/install/cpp.sh -RUN cd /usr/src/gtest && cmake CMakeLists.txt && make && cp *.a /usr/lib - -# using (future) version of: -# https://github.com/apache/incubator-mxnet/blob/master/docker/Dockerfiles/Dockerfile.in.lib.gpu -RUN cd /tmp && \ - git clone --recursive https://github.com/apache/incubator-mxnet mxnet && cd mxnet && \ - git checkout tags/1.0.0 -b 1.0.0 && git submodule update --init --recursive && \ - git apply --verbose /patches/*.patch && \ - make -j$(nproc) USE_CUDA=1 USE_CUDA_PATH=/usr/local/cuda USE_CUDNN=1 USE_DIST_KVSTORE=1 && \ - cd /tmp/mxnet/python && \ - python2 setup.py install && \ - python3 setup.py install && \ - cd / && \ - rm -fr /tmp/mxnet - -# https://stackoverflow.com/questions/29274638/opencv-libdc1394-error-failed-to-initialize-libdc1394 -RUN ln -s /dev/null /dev/raw1394 - -# install telegraf -RUN cd /tmp && \ - curl -O https://dl.influxdata.com/telegraf/releases/telegraf_1.4.2-1_amd64.deb && \ - dpkg -i telegraf_1.4.2-1_amd64.deb && \ - rm telegraf_1.4.2-1_amd64.deb - -ENV PYTHONDONTWRITEBYTECODE=1 \ - PYTHONUNBUFFERED=1 \ - LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/lib" - -WORKDIR / diff --git a/docker/1.0.0/base/cpu_patches/0001-apache.incubator-mxnet.pull.9218.patch b/docker/1.0.0/base/cpu_patches/0001-apache.incubator-mxnet.pull.9218.patch deleted file mode 100644 index a8a94c77..00000000 --- a/docker/1.0.0/base/cpu_patches/0001-apache.incubator-mxnet.pull.9218.patch +++ /dev/null @@ -1,20 +0,0 @@ ---- a/prepare_mkl.sh -+++ b/prepare_mkl.sh -@@ -75,14 +75,14 @@ MXNET_ROOT=`dirname $0` - USE_MKLML=0 - # NOTE: if you update the following line, please also update the dockerfile at - # tests/ci_build/Dockerfile.mkl --VERSION_MATCH=20170908 -+VERSION_MATCH=20171227 - PLATFORM=$(uname) - if [ $PLATFORM == "Darwin" ]; then - INFIX=mac - elif [ $PLATFORM == "Linux" ]; then - INFIX=lnx - fi --ARCHIVE_BASENAME=mklml_${INFIX}_2018.0.20170908.tgz -+ARCHIVE_BASENAME=mklml_${INFIX}_2018.0.1.20171227.tgz - MKL_CONTENT_DIR=`echo $ARCHIVE_BASENAME | rev | cut -d "." -f 2- | rev` --MKLURL="https://github.com/01org/mkl-dnn/releases/download/v0.10/$ARCHIVE_BASENAME" -+MKLURL="https://github.com/intel/mkl-dnn/releases/download/v0.12/$ARCHIVE_BASENAME" - # there are diffrent MKL lib to be used for GCC and for ICC diff --git a/docker/1.0.0/base/patches/0001-Python-3-compatiblity-during-optimizer-serialization.patch b/docker/1.0.0/base/patches/0001-Python-3-compatiblity-during-optimizer-serialization.patch deleted file mode 100644 index f75f36d2..00000000 --- a/docker/1.0.0/base/patches/0001-Python-3-compatiblity-during-optimizer-serialization.patch +++ /dev/null @@ -1,25 +0,0 @@ -From c462d343b0782e82a7a91d3dcccc3863bb947dff Mon Sep 17 00:00:00 2001 -From: Tobias Domhan -Date: Sat, 25 Nov 2017 23:02:49 +0100 -Subject: [PATCH] Python 3 compatiblity during optimizer serialization. (#8334) - ---- - python/mxnet/kvstore.py | 2 +- - 1 file changed, 1 insertion(+), 1 deletion(-) - -diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py -index d068d06..ad598d0 100644 ---- a/python/mxnet/kvstore.py -+++ b/python/mxnet/kvstore.py -@@ -451,7 +451,7 @@ class KVStore(object): - # send the optimizer to server - try: - # use ASCII protocol 0, might be slower, but not a big ideal -- optim_str = pickle.dumps(optimizer, 0) -+ optim_str = py_str(pickle.dumps(optimizer, 0)) - except: - raise - self._send_command_to_servers(0, optim_str) --- -2.7.3.AMZN - diff --git a/docker/1.0.0/final/py2/Dockerfile.cpu b/docker/1.0.0/final/py2/Dockerfile.cpu deleted file mode 100644 index 29c97ce6..00000000 --- a/docker/1.0.0/final/py2/Dockerfile.cpu +++ /dev/null @@ -1,15 +0,0 @@ -# Use local version of image built from Dockerfile.cpu in /docker/$version/base directory -FROM mxnet-base:1.0.0-cpu-py2 -ARG mxnet_support_tar=sagemaker_mxnet_container-1.0.0.tar.gz - -RUN pip2 install --no-cache \ - numpy==1.13.3 - -COPY $mxnet_support_tar . - -RUN pip2 install $mxnet_support_tar - -RUN rm $mxnet_support_tar - -# Entrypoint script comes from sagemaker_container_support -ENTRYPOINT ["/usr/bin/python2", "/usr/local/bin/entry.py"] diff --git a/docker/1.0.0/final/py2/Dockerfile.gpu b/docker/1.0.0/final/py2/Dockerfile.gpu deleted file mode 100644 index d0a0ea34..00000000 --- a/docker/1.0.0/final/py2/Dockerfile.gpu +++ /dev/null @@ -1,15 +0,0 @@ -# Use local version of image built from Dockerfile.gpu in /docker/$version/base directory -FROM mxnet-base:1.0.0-gpu-py2 -ARG mxnet_support_tar=sagemaker_mxnet_container-1.0.0.tar.gz - -RUN pip2 install --no-cache \ - numpy==1.13.3 - -COPY $mxnet_support_tar . - -RUN pip2 install $mxnet_support_tar - -RUN rm $mxnet_support_tar - -# Entrypoint script comes from sagemaker_container_support -ENTRYPOINT ["/usr/bin/python2", "/usr/local/bin/entry.py"] diff --git a/docker/1.0.0/final/py3/Dockerfile.cpu b/docker/1.0.0/final/py3/Dockerfile.cpu deleted file mode 100644 index 65356fce..00000000 --- a/docker/1.0.0/final/py3/Dockerfile.cpu +++ /dev/null @@ -1,15 +0,0 @@ -# Use local version of image built from Dockerfile.cpu in /docker/$version/base directory -FROM mxnet-base:1.0.0-cpu-py3 -ARG mxnet_support_tar=sagemaker_mxnet_container-1.0.0.tar.gz - -RUN pip3 install --no-cache \ - numpy==1.13.3 - -COPY $mxnet_support_tar . - -RUN pip3 install $mxnet_support_tar - -RUN rm $mxnet_support_tar - -# Entrypoint script comes from sagemaker_container_support -ENTRYPOINT ["/usr/bin/python3", "/usr/local/bin/entry.py"] diff --git a/docker/1.0.0/final/py3/Dockerfile.gpu b/docker/1.0.0/final/py3/Dockerfile.gpu deleted file mode 100644 index 642110e2..00000000 --- a/docker/1.0.0/final/py3/Dockerfile.gpu +++ /dev/null @@ -1,15 +0,0 @@ -# Use local version of image built from Dockerfile.gpu in /docker/$version/base directory -FROM mxnet-base:1.0.0-gpu-py3 -ARG mxnet_support_tar=sagemaker_mxnet_container-1.0.0.tar.gz - -RUN pip3 install --no-cache \ - numpy==1.13.3 - -COPY $mxnet_support_tar . - -RUN pip3 install $mxnet_support_tar - -RUN rm $mxnet_support_tar - -# Entrypoint script comes from sagemaker_container_support -ENTRYPOINT ["/usr/bin/python3", "/usr/local/bin/entry.py"] diff --git a/docker/1.1.0/final/Dockerfile.cpu b/docker/1.1.0/final/Dockerfile.cpu deleted file mode 100644 index f260f535..00000000 --- a/docker/1.1.0/final/Dockerfile.cpu +++ /dev/null @@ -1,46 +0,0 @@ -FROM ubuntu:16.04 -ARG framework_installable -ARG py_version -ARG framework_support_installable=sagemaker_mxnet_container-1.0.0.tar.gz - -# Validate that arguments are specified -RUN test $framework_installable || exit 1 && \ - test $py_version || exit 1 - -WORKDIR /tmp - -RUN apt-get update && \ - apt-get -y install build-essential libopencv-dev libopenblas-dev libjemalloc-dev libgfortran3 \ - python-dev python3-dev wget curl nginx - -# Symlink /usr/bin/python to the python version we're building for. -RUN rm /usr/bin/python && ln -s "/usr/bin/python$py_version" /usr/bin/python - -# Install pip -RUN curl -O https://bootstrap.pypa.io/get-pip.py && \ - /usr/bin/python get-pip.py - -# Will install from pypi once packages are released there. For now, copy from local file system. -COPY $framework_installable . -COPY $framework_support_installable . - -RUN framework_installable_local=$(basename $framework_installable) && \ - framework_support_installable_local=$(basename $framework_support_installable) && \ - \ - pip install $framework_installable_local && \ - pip install $framework_support_installable_local && \ - \ - rm $framework_installable_local && \ - rm $framework_support_installable_local - -# This is here to make our installed version of OpenCV work. -# https://stackoverflow.com/questions/29274638/opencv-libdc1394-error-failed-to-initialize-libdc1394 -# TODO: Should we be installing OpenCV in our image like this? Is there another way we can fix this? -RUN ln -s /dev/null /dev/raw1394 - -ENV PYTHONDONTWRITEBYTECODE=1 \ - PYTHONUNBUFFERED=1 \ - LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/lib" - -# Entrypoint script comes from sagemaker_container_support -ENTRYPOINT ["/usr/bin/python", "/usr/local/bin/entry.py"] diff --git a/docker/1.1.0/final/Dockerfile.gpu b/docker/1.1.0/final/Dockerfile.gpu deleted file mode 100644 index e302cb1e..00000000 --- a/docker/1.1.0/final/Dockerfile.gpu +++ /dev/null @@ -1,46 +0,0 @@ -FROM nvidia/cuda:9.0-cudnn7-devel -ARG framework_installable -ARG py_version -ARG framework_support_installable=sagemaker_mxnet_container-1.0.0.tar.gz - -# Validate that arguments are specified -RUN test $framework_installable || exit 1 && \ - test $py_version || exit 1 - -WORKDIR /tmp - -RUN apt-get update && \ - apt-get -y install build-essential libopencv-dev libatlas-base-dev libcurl4-openssl-dev libgtest-dev \ - libjemalloc-dev python-dev python3-dev unzip git wget curl nginx - -# Symlink /usr/bin/python to the python version we're building for. -RUN rm /usr/bin/python && ln -s "/usr/bin/python$py_version" /usr/bin/python - -# install pip -RUN curl -O https://bootstrap.pypa.io/get-pip.py && \ - /usr/bin/python get-pip.py - -# Will install from pypi once packages are released there. For now, copy from local file system. -COPY $framework_installable . -COPY $framework_support_installable . - -RUN framework_installable_local=$(basename $framework_installable) && \ - framework_support_installable_local=$(basename $framework_support_installable) && \ - \ - pip install $framework_installable_local && \ - pip install $framework_support_installable_local && \ - \ - rm $framework_installable_local && \ - rm $framework_support_installable_local - -# This is here to make our installed version of OpenCV work. -# https://stackoverflow.com/questions/29274638/opencv-libdc1394-error-failed-to-initialize-libdc1394 -# TODO: Should we be installing OpenCV in our image like this? Is there another way we can fix this? -RUN ln -s /dev/null /dev/raw1394 - -ENV PYTHONDONTWRITEBYTECODE=1 \ - PYTHONUNBUFFERED=1 \ - LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/lib" - -# Entrypoint script comes from sagemaker_container_support -ENTRYPOINT ["/usr/bin/python", "/usr/local/bin/entry.py"] diff --git a/docker/1.2.1/final/Dockerfile.cpu b/docker/1.3.0/final/Dockerfile.cpu similarity index 74% rename from docker/1.2.1/final/Dockerfile.cpu rename to docker/1.3.0/final/Dockerfile.cpu index f1e7d9e1..af7cceb8 100644 --- a/docker/1.2.1/final/Dockerfile.cpu +++ b/docker/1.3.0/final/Dockerfile.cpu @@ -1,7 +1,9 @@ FROM ubuntu:16.04 ARG framework_installable ARG py_version -ARG framework_support_installable=sagemaker_mxnet_container-1.0.0.tar.gz +ARG framework_support_installable=sagemaker_mxnet_container-2.0.0.tar.gz + +LABEL com.amazonaws.sagemaker.capabilities.accept-bind-to-port=true # Validate that arguments are specified RUN test $framework_installable || exit 1 && \ @@ -10,15 +12,15 @@ RUN test $framework_installable || exit 1 && \ WORKDIR /tmp RUN apt-get update && \ - apt-get -y install build-essential libopencv-dev libopenblas-dev libjemalloc-dev libgfortran3 \ - python-dev python3-dev wget curl nginx + apt-get -y install libopencv-dev libopenblas-dev python python3 wget curl nginx # Symlink /usr/bin/python to the python version we're building for. RUN rm /usr/bin/python && ln -s "/usr/bin/python$py_version" /usr/bin/python # Install pip RUN curl -O https://bootstrap.pypa.io/get-pip.py && \ - /usr/bin/python get-pip.py + /usr/bin/python get-pip.py && \ + rm get-pip.py # Will install from pypi once packages are released there. For now, copy from local file system. COPY $framework_installable . @@ -33,8 +35,8 @@ RUN framework_installable_local=$(basename $framework_installable) && \ rm $framework_installable_local && \ rm $framework_support_installable_local -# For ONNX model import - https://onnx.ai/ -RUN pip install onnx +RUN pip install onnx==1.2.1 \ + keras-mxnet==2.2.2 # This is here to make our installed version of OpenCV work. # https://stackoverflow.com/questions/29274638/opencv-libdc1394-error-failed-to-initialize-libdc1394 @@ -45,5 +47,5 @@ ENV PYTHONDONTWRITEBYTECODE=1 \ PYTHONUNBUFFERED=1 \ LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/lib" -# Entrypoint script comes from sagemaker_container_support -ENTRYPOINT ["/usr/bin/python", "/usr/local/bin/entry.py"] +ENV SAGEMAKER_TRAINING_MODULE sagemaker_mxnet_container.training:main +ENV SAGEMAKER_SERVING_MODULE sagemaker_mxnet_container.serving:main diff --git a/docker/1.2.1/final/Dockerfile.gpu b/docker/1.3.0/final/Dockerfile.gpu similarity index 71% rename from docker/1.2.1/final/Dockerfile.gpu rename to docker/1.3.0/final/Dockerfile.gpu index 52bc94dd..ce515a9b 100644 --- a/docker/1.2.1/final/Dockerfile.gpu +++ b/docker/1.3.0/final/Dockerfile.gpu @@ -1,7 +1,9 @@ -FROM nvidia/cuda:9.0-cudnn7-devel +FROM nvidia/cuda:9.0-cudnn7-runtime ARG framework_installable ARG py_version -ARG framework_support_installable=sagemaker_mxnet_container-1.0.0.tar.gz +ARG framework_support_installable=sagemaker_mxnet_container-2.0.0.tar.gz + +LABEL com.amazonaws.sagemaker.capabilities.accept-bind-to-port=true # Validate that arguments are specified RUN test $framework_installable || exit 1 && \ @@ -10,15 +12,16 @@ RUN test $framework_installable || exit 1 && \ WORKDIR /tmp RUN apt-get update && \ - apt-get -y install build-essential libopencv-dev libatlas-base-dev libcurl4-openssl-dev libgtest-dev \ - libjemalloc-dev python-dev python3-dev unzip git wget curl nginx + apt-get -y install libopencv-dev libatlas-base-dev libcurl4-openssl-dev \ + python python3 unzip git wget curl nginx # Symlink /usr/bin/python to the python version we're building for. RUN rm /usr/bin/python && ln -s "/usr/bin/python$py_version" /usr/bin/python # install pip RUN curl -O https://bootstrap.pypa.io/get-pip.py && \ - /usr/bin/python get-pip.py + /usr/bin/python get-pip.py && \ + rm get-pip.py # Will install from pypi once packages are released there. For now, copy from local file system. COPY $framework_installable . @@ -33,8 +36,8 @@ RUN framework_installable_local=$(basename $framework_installable) && \ rm $framework_installable_local && \ rm $framework_support_installable_local -# For ONNX model import - https://onnx.ai/ -RUN pip install onnx +RUN pip install onnx==1.2.1 \ + keras-mxnet==2.2.2 # This is here to make our installed version of OpenCV work. # https://stackoverflow.com/questions/29274638/opencv-libdc1394-error-failed-to-initialize-libdc1394 @@ -45,5 +48,5 @@ ENV PYTHONDONTWRITEBYTECODE=1 \ PYTHONUNBUFFERED=1 \ LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/lib" -# Entrypoint script comes from sagemaker_container_support -ENTRYPOINT ["/usr/bin/python", "/usr/local/bin/entry.py"] +ENV SAGEMAKER_TRAINING_MODULE sagemaker_mxnet_container.training:main +ENV SAGEMAKER_SERVING_MODULE sagemaker_mxnet_container.serving:main diff --git a/setup.py b/setup.py index c0da5067..891589f3 100644 --- a/setup.py +++ b/setup.py @@ -1,9 +1,29 @@ -import os +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. +from __future__ import absolute_import + from glob import glob +import os from os.path import basename from os.path import splitext -from setuptools import setup, find_packages +from setuptools import find_packages, setup def read(fname): @@ -12,7 +32,7 @@ def read(fname): setup( name='sagemaker_mxnet_container', - version='1.0.0', + version='2.0.0', description='Open source library for creating MXNet containers to run on Amazon SageMaker.', packages=find_packages(where='src', exclude=('test',)), @@ -36,9 +56,9 @@ def read(fname): # We don't declare our dependency on mxnet here because we build with # different packages for different variants (e.g. mxnet-mkl and mxnet-cu90). - install_requires=['sagemaker-container-support >= 1.0.0, <2'], + install_requires=['sagemaker-containers>=2.2.5', 'retrying==1.3.3'], extras_require={ - 'test': ['tox', 'flake8', 'pytest', 'pytest-cov', 'pytest-xdist', 'mock', - 'requests==2.18.4', 'sagemaker'] + 'test': ['tox', 'flake8', 'pytest', 'pytest-cov', 'pytest-xdist', 'mock', 'sagemaker', + 'requests==2.18.4', 'docker-compose', 'mxnet==1.3.0.post0'] }, ) diff --git a/src/mxnet_container/__init__.py b/src/mxnet_container/__init__.py deleted file mode 100644 index bf8cce97..00000000 --- a/src/mxnet_container/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# A copy of the License is located at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the "license" file accompanying this file. This file is distributed -# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing -# permissions and limitations under the License. - -from mxnet_container.train import train -from mxnet_container.serve.transformer import load_dependencies, transformer diff --git a/src/mxnet_container/serve/__init__.py b/src/mxnet_container/serve/__init__.py deleted file mode 100644 index ecbe7b56..00000000 --- a/src/mxnet_container/serve/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# A copy of the License is located at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the "license" file accompanying this file. This file is distributed -# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing -# permissions and limitations under the License. - diff --git a/src/mxnet_container/serve/environment.py b/src/mxnet_container/serve/environment.py deleted file mode 100644 index e0791a20..00000000 --- a/src/mxnet_container/serve/environment.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# A copy of the License is located at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the "license" file accompanying this file. This file is distributed -# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing -# permissions and limitations under the License. - -import container_support as cs -import os - - -class MXNetHostingEnvironment(cs.HostingEnvironment): - DEFAULT_MODEL_FIRST_DIMENSION_SIZE_PARAM = "SAGEMAKER_DEFAULT_MODEL_FIRST_DIMENSION_SIZE" - - def __init__(self, base_dir=cs.ContainerEnvironment.BASE_DIRECTORY): - super(MXNetHostingEnvironment, self).__init__(base_dir) - self.preferred_batch_size = int(os.environ.get( - MXNetHostingEnvironment.DEFAULT_MODEL_FIRST_DIMENSION_SIZE_PARAM, '1')) - - self.update_mxnet_envvars() - - @staticmethod - def update_mxnet_envvars(): - if not os.environ.get('MXNET_CPU_WORKER_NTHREADS'): - os.environ['MXNET_CPU_WORKER_NTHREADS'] = '1' - - if not os.environ.get('MXNET_CPU_PRIORITY_NTHREADS'): - os.environ['MXNET_CPU_PRIORITY_NTHREADS'] = '1' - - if not os.environ.get('MXNET_KVSTORE_REDUCTION_NTHREADS'): - os.environ['MXNET_KVSTORE_REDUCTION_NTHREADS'] = '1' - - if not os.environ.get('MXNET_ENGINE_TYPE'): - os.environ['MXNET_ENGINE_TYPE'] = 'NaiveEngine' - - if not os.environ.get('OMP_NUM_THREADS'): - os.environ['OMP_NUM_THREADS'] = '1' diff --git a/src/mxnet_container/serve/transformer.py b/src/mxnet_container/serve/transformer.py deleted file mode 100644 index 5f10d0da..00000000 --- a/src/mxnet_container/serve/transformer.py +++ /dev/null @@ -1,419 +0,0 @@ -# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# A copy of the License is located at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the "license" file accompanying this file. This file is distributed -# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing -# permissions and limitations under the License. - -import csv -import json -import os -from types import ModuleType - -import mxnet as mx -from container_support.serving import UnsupportedContentTypeError, \ - UnsupportedAcceptTypeError, \ - UnsupportedInputShapeError, \ - JSON_CONTENT_TYPE, \ - CSV_CONTENT_TYPE -from six import StringIO - -from mxnet_container.serve.environment import MXNetHostingEnvironment -from mxnet_container.train import DEFAULT_MODEL_FILENAMES, DEFAULT_MODEL_NAME - - -def transformer(user_module): - return MXNetTransformer.from_module(user_module) - - -def load_dependencies(): - pass - - -class MXNetTransformer(object): - """A ``Transformer`` encapsulates the function(s) responsible for parsing - incoming request data, passing it through a model, and converting the - result into something that can be returned as the body of and HTTP response. - """ - - def __init__(self, model, transform_fn): - """Initialize a Transformer. - - :param model: a fully initialized model - :param transform_fn: a transformer function - """ - self.model = model - self.transform_fn = transform_fn - - def transform(self, input_data, content_type, accept): - """Transforms input data into a prediction result. - The input data is expected to have the given ``input_content_type``. - The output returned should have the given ``output_content_type``. - - :param input_data: input data - :param content_type: content type from Content-Type headers - :param accept: content type from Accept header - :return: the transformed result - """ - return self.transform_fn(self.model, input_data, content_type, accept) - - @classmethod - def select_transformer_class(cls, model): - if isinstance(model, mx.module.BaseModule): - return ModuleTransformer - - if isinstance(model, mx.gluon.block.Block): - return GluonBlockTransformer - - raise ValueError('Unsupported model type: {}'.format(model.__class__.__name__)) - - @classmethod - def from_module(cls, m): - """Initialize a Transformer using functions supplied by the given module. - - The module may provide a ``model_fn`` that returns a fully initialized model of - some kind. Generally this will be a Gluon ``Block`` or a Module API ``Module``, but - it can be anything, as long as it is compatible with the ``transform_fn``. - - If the ``model_fn`` is not present, a default implementation will be used instead. The - default implementation is compatible with the ``Module``s saved by the ``default_save`` - method in MXNetTrainingEnvironment. - - The ``model_fn`` (user-supplied or default) will be called once during - each inference worker's startup process. - - The module may supply a ``transform_fn``. If it is present, it will be used to handle - each inference request. If it is not present, then a ``transform_fn`` will be composed - by chaining an ``input_fn``, ``predict_fn`` and ``output_fn``. If any of these are - implemented in the given module, they will be used. Otherwise default implementations will - be used instead. - - :param m: a python module - :return: a configured Transformer object - """ - - if not isinstance(m, ModuleType): - raise ValueError("not a module!") - - env = MXNetHostingEnvironment() - - # load model - if hasattr(m, 'model_fn'): - model = m.model_fn(env.model_dir) - else: - model = ModuleTransformer._default_model_fn(env.model_dir, env.preferred_batch_size) - - # if user has supplied a transform_fn, we can use base MXNetTransformer directly - if hasattr(m, 'transform_fn'): - return MXNetTransformer(model, m.transform_fn) - - # otherwise we need to create a Module- or Gluon-specific subclass - transformer_class = cls.select_transformer_class(model) - return transformer_class.from_module(m, model) - - -class GluonBlockTransformer(MXNetTransformer): - def __init__(self, block, transform_fn): - super(GluonBlockTransformer, self).__init__(block, transform_fn) - - @classmethod - def from_module(cls, module, block): - input_fn = GluonBlockTransformer._get_function(module, 'input_fn') - predict_fn = GluonBlockTransformer._get_function(module, 'predict_fn') - output_fn = GluonBlockTransformer._get_function(module, 'output_fn') - - def transform_fn(block, data, content_type, accept): - i = input_fn(data, content_type) - p = predict_fn(block, i) - o, ct = output_fn(p, accept) - return o, ct - - return cls(block, transform_fn) - - @classmethod - def _get_function(cls, module, name): - if hasattr(module, name): - return getattr(module, name) - else: - return getattr(cls, '_default_' + name) - - @staticmethod - def _default_input_fn(input, content_type): - """A default input handler for Gluon ``Block``s. - :param input: the request payload - :param content_type: the request content_type (must equal JSON_CONTENT_TYPE) - :return: NDArray to pass to ``predict_fn`` - """ - if JSON_CONTENT_TYPE == content_type: - return mx.nd.array(json.loads(input)) - - raise UnsupportedContentTypeError(content_type) - - @staticmethod - def _default_predict_fn(block, ndarray): - """A default prediction function for Gluon ``Block``s. - :param block: a Gluon ``Block`` - :param ndarray: an NDArray (axis 1 = batch index) - :return: an NDArray - """ - return block(ndarray) - - @staticmethod - def _default_output_fn(ndarray, accept): - """A default output handler for Gluon ``Block``s. - - :param ndarray: an NDArray - :param accept: content type from accept header (must equal JSON_CONTENT_TYPE) - :return: a json string - :raises: UnsupportedAcceptTypeError if accept != JSON_CONTENT_TYPE - """ - if JSON_CONTENT_TYPE == accept: - return json.dumps(ndarray.asnumpy().tolist()), JSON_CONTENT_TYPE - - raise UnsupportedAcceptTypeError(accept) - - -class ModuleTransformer(MXNetTransformer): - def __init__(self, module, transform_fn): - super(ModuleTransformer, self).__init__(module, transform_fn) - - @classmethod - def from_module(cls, m, model): - """Initialize a Transformer using functions supplied by the given module. - - If the module contains a ``transform_fn``, it will be used to handle incoming request - data, execute the model prediction, and generation of response content. - - If the module does not contain a ``transform_fn``, then one will be assembled by: - - chaining a ``process_request_fn`` and ``output_fn`` if ``process_request_fn`` is defined - - otherwise: chaining an ``input_fn``, ``predict_fn``, and ``output_fn`` - Default handlers will be used for any of these that are not present in the supplied module. - - :param m: a python module - :return: a configured Transformer object - """ - # TODO remove process_request_fn? - if hasattr(m, 'process_request_fn'): - process_fn = m.process_request_fn - else: - input_fn = cls._default_input_fn if not hasattr(m, 'input_fn') else m.input_fn - predict_fn = cls._default_predict_fn if not hasattr(m, 'predict_fn') else m.predict_fn - process_fn = cls._process_request_fn(input_fn, predict_fn) - - if hasattr(m, 'output_fn'): - output_fn = m.output_fn - else: - output_fn = cls._default_output_fn - - transform_fn = cls._build_transform_fn(process_fn, output_fn) - - return cls(model, transform_fn) - - @staticmethod - def _process_request_fn(input_handler, prediction_handler): - """Construct processing function from handlers. - - :param input_handler: handles input and transforms for predict call - :param prediction_handler: consumes data from input handler and calls predict - :return: processing function - """ - - def process(model, data, content_type): - """Processing function for MXNet models. - - :param model: loaded MXNet model - :param data: data from the request - :param content_type: specified in the request - :return: a list of NDArray - """ - return prediction_handler(model, input_handler(model, data, content_type)) - - return process - - @staticmethod - def _default_input_fn(model, data, content_type): - """A default input handler for MXNet models to support default input. - - :param model: loaded MXNet model - :param data: data from the request - :param content_type: specified in the request - :return: NDArrayIter to feed to predict call - """ - - if content_type == JSON_CONTENT_TYPE: - return ModuleTransformer._process_json_input(model, data) - - if content_type == CSV_CONTENT_TYPE: - return ModuleTransformer._process_csv_input(model, data) - - raise UnsupportedContentTypeError(content_type) - - @staticmethod - def _process_json_input(model, data): - """A default inout handler for json input. - - 'data' is deserialized from json into NDArray. This array is used to create - iterator that is used to call 'predict' on the model. - - :param data: json data from the request - :return: NDArrayIter to feed to predict call - """ - - parsed = json.loads(data) - return ModuleTransformer._prepare_input_for_default_predict(model, mx.nd.array(parsed)) - - @staticmethod - def _process_csv_input(model, data): - """A default prediction function for MXNet models that takes csv as input. - - :param model: loaded MXNet model - :param data: data from the request - :return: NDArrayIter to feed to predict call - """ - - # we can only support case when there is a single data input - if len(model.data_shapes) != 1: - raise UnsupportedInputShapeError(len(model.data_shapes)) - - # model is already loaded with data shape bound, - # ignore the first parameter that is batch_size - model_data_shape = model.data_shapes[0] - (shape_name, input_data_shape) = model_data_shape - no_batch_data_shape = input_data_shape[1:] - - # let's read the csv into ndarray doing reshaping as - # specified by the model since csv is arriving flattened - csv_buff = StringIO(data) - csv_to_parse = csv.reader(csv_buff, delimiter=',') - full_array = [] - for row in csv_to_parse: - casted_list = [float(i) for i in row] - shaped_row = mx.nd.array(casted_list).reshape(no_batch_data_shape) - full_array.append(shaped_row.asnumpy().tolist()) - return ModuleTransformer._prepare_input_for_default_predict(model, mx.nd.array(full_array)) - - @staticmethod - def _prepare_input_for_default_predict(model, ndarray): - # We require model to only have one input - [data_shape] = model.data_shapes - - # Batch size is first dimension of model input - model_batch_size = data_shape[1][0] - pad_rows = max(0, model_batch_size - ndarray.shape[0]) - - # If ndarray has fewer rows than model_batch_size, then pad it with zeros. - if pad_rows: - num_pad_values = pad_rows - for dimension in ndarray.shape[1:]: - num_pad_values *= dimension - padding_shape = tuple([pad_rows] + list(ndarray.shape[1:])) - padding = mx.ndarray.zeros(shape=padding_shape) - ndarray = mx.ndarray.concat(ndarray, padding, dim=0) - - model_input = mx.io.NDArrayIter(ndarray, batch_size=model_batch_size, - last_batch_handle='pad') - - if pad_rows: - # Update the getpad method on the model_input data iterator to return the amount of - # padding. MXNet will ignore the last getpad() rows during Module predict. - def _getpad(): - return pad_rows - - model_input.getpad = _getpad - - return model_input - - @staticmethod - def _build_transform_fn(process_request_fn, output_fn): - """ Create a transformer function. - :param process_request_fn: input processing function - :param output_fn: an output handler function - :return: - """ - - def f(model, data, input_content_type, requested_output_content_type): - prediction_result = process_request_fn(model, data, input_content_type) - o, ct = output_fn(prediction_result, requested_output_content_type) - return o, ct - - return f - - @staticmethod - def _default_predict_fn(module, data): - """A default prediction function for MXNet models. - :param module: an MXNet Module - :param data: NDArrayIter - :return: a list of NDArray or list of lists of NDArray - """ - - return module.predict(data) - - @staticmethod - def _default_output_fn(data, content_type): - """A default output handler for MXNet models. - - :param data: output of ``mxnet.Module.predict(...)`` - :param content_type: requested content type by the request to be returned - :return: a json string - """ - - if content_type == JSON_CONTENT_TYPE: - result_to_send = [arr.asnumpy().tolist() for arr in data] - return json.dumps(result_to_send), JSON_CONTENT_TYPE - - if content_type == CSV_CONTENT_TYPE: - result_to_send = [arr.asnumpy().flatten() for arr in data] - str_io = StringIO() - csv_writer = csv.writer(str_io, delimiter=',') - for row in result_to_send: - csv_writer.writerow(row) - return str_io.getvalue(), CSV_CONTENT_TYPE - - raise UnsupportedAcceptTypeError(content_type) - - @staticmethod - def _default_model_fn(model_dir, preferred_batch_size): - for f in DEFAULT_MODEL_FILENAMES.values(): - path = os.path.join(model_dir, f) - if not os.path.exists(path): - raise ValueError('missing %s file' % f) - - shapes_file = os.path.join(model_dir, DEFAULT_MODEL_FILENAMES['shapes']) - data_names, data_shapes = ModuleTransformer._read_data_shapes(shapes_file, - preferred_batch_size) - - sym, args, aux = mx.model.load_checkpoint('%s/%s' % (model_dir, DEFAULT_MODEL_NAME), 0) - - # TODO mxnet ctx - better default, allow user control - mod = mx.mod.Module(symbol=sym, context=mx.cpu(), data_names=data_names, label_names=None) - mod.bind(for_training=False, data_shapes=data_shapes) - mod.set_params(args, aux, allow_missing=True) - - return mod - - @staticmethod - def _read_data_shapes(path, preferred_batch_size=1): - with open(path, 'r') as f: - signature = json.load(f) - - data_names = [] - data_shapes = [] - - for s in signature: - name = s['name'] - data_names.append(name) - - shape = s['shape'] - - if preferred_batch_size: - shape[0] = preferred_batch_size - - data_shapes.append((name, shape)) - - return data_names, data_shapes diff --git a/src/mxnet_container/train.py b/src/mxnet_container/train.py deleted file mode 100644 index 8d774329..00000000 --- a/src/mxnet_container/train.py +++ /dev/null @@ -1,210 +0,0 @@ -# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# A copy of the License is located at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the "license" file accompanying this file. This file is distributed -# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing -# permissions and limitations under the License. - -import container_support as cs -import inspect -import json -import logging -import os -import socket -import subprocess - -logger = logging.getLogger(__name__) - -DEFAULT_MODEL_NAME = "model" - -DEFAULT_MODEL_FILENAMES = { - 'symbol': 'model-symbol.json', - 'params': 'model-0000.params', - 'shapes': 'model-shapes.json', -} - -UPCOMING_SCRIPT_MODE_WARNING = ( - '\033[1;33m' # print warning in yellow - 'This required structure for training scripts will be ' - 'deprecated with the next major release of MXNet images. ' - 'The train() function will no longer be required; ' - 'instead the training script must be able to be run as a standalone script. ' - 'For more information, see https://github.com/aws/sagemaker-python-sdk/tree/master/src/sagemaker/mxnet#updating-your-mxnet-training-script.' # noqa: E501 - '\033[1;0m' -) - - -class MXNetTrainingEnvironment(cs.TrainingEnvironment): - """ Configuration for single machine and distributed mxnet training. - """ - - def __init__(self, base_dir): - super(MXNetTrainingEnvironment, self).__init__(base_dir) - - self._ps_verbose = int(self.hyperparameters.get('_ps_verbose', 0)) - self._ps_port = int(self.hyperparameters.get('_ps_port', 8000)) - self._scheduler_host = sorted(self.hosts)[0] - self._scheduler_ip = host_lookup(self._scheduler_host) - # Block until all host lookups succeed. Relies on retrying host_lookup. - for host in self.hosts: - host_lookup(host) - - @property - def distributed(self): - """ Returns True if this configuration defines a distributed learning task.""" - return len(self.hosts) > 1 - - @property - def current_host_scheduler(self): - """ Returns True if this machine should be the mxnet parameter server scheduler.""" - return self._scheduler_host == self.current_host - - def env_vars_for_role(self, role): - """ Returns environment variables for a python process to run as an - mxnet parameter server process with the specified role. - - Args: - role (str): One of "worker", "server", or "scheduler" - """ - if role not in ["worker", "scheduler", "server"]: - raise ValueError("Unexpected role {}".format(role)) - return { - 'DMLC_NUM_WORKER': str(len(self.hosts)), - 'DMLC_NUM_SERVER': str(len(self.hosts)), - 'DMLC_ROLE': role, - 'DMLC_PS_ROOT_URI': str(self._scheduler_ip), - 'DMLC_PS_ROOT_PORT': str(self._ps_port), - 'PS_VERBOSE': str(self._ps_verbose) - } - - @property - def kwargs_for_training(self): - """ Returns a dictionary of key-word arguments for input to the user supplied - module train function. """ - return { - 'hyperparameters': dict(self.hyperparameters), - 'input_data_config': dict(self.channels), - 'channel_input_dirs': dict(self.channel_dirs), - 'output_data_dir': self.output_data_dir, - 'model_dir': self.model_dir, - 'num_gpus': self.available_gpus, - 'num_cpus': self.available_cpus, - 'hosts': list(self.hosts), - 'current_host': self.current_host - } - - def default_save(self, mod): - """ Saves the specified mxnet module to ``self.model_dir``. - - This generates three files in ``self.model_dir``: - - - model-symbol.json - The serialized module symbolic graph. Formed by - invoking ```module.symbol.save``` - - model-0000.params - The serialized module parameters. Formed by - invoking ```module.save_params``` - - model-shapes.json - The serialized module input data shapes. A json list - of json data-shape objects. Each data-shape object - contains a string name and a list of integer dimensions. - Args: - mod : (mxnet.mod.Module) The module to save.""" - if not self.distributed or self.current_host_scheduler: - mod.symbol.save(os.path.join(self.model_dir, DEFAULT_MODEL_FILENAMES['symbol'])) - mod.save_params(os.path.join(self.model_dir, DEFAULT_MODEL_FILENAMES['params'])) - signature = self._build_data_shape_signature(mod) - with open(os.path.join(self.model_dir, DEFAULT_MODEL_FILENAMES['shapes']), 'w') as f: - json.dump(signature, f) - - @classmethod - def _build_data_shape_signature(cls, mod): - """ Returns a list of data shape description dicts. Each element in the - returned list is a dict with a 'name' key, mapping to a string name - and a 'shape' key, mapping to a list of ints. - """ - return [{"name": data_desc.name, "shape": [dim for dim in data_desc.shape]} - for data_desc in mod.data_shapes] - - -@cs.retry(stop_max_delay=1000 * 60 * 15, - wait_exponential_multiplier=100, - wait_exponential_max=30000) -def host_lookup(host): - """ Retrying host lookup on host """ - return socket.gethostbyname(host) - - -def _run_mxnet_process(role, mxnet_env): - """ Runs an mxnet process for the specified role with the specified - environment. - - Args: - role (str): The mxnet process role. - mxnet_env (MXNetEnvironment): The mxnet environment used to provide - environment variables for the launched process. - Returns: - (int) The launched process id """ - - role_env = os.environ.copy() - role_env.update(mxnet_env.env_vars_for_role(role)) - return subprocess.Popen("python -c 'import mxnet'", shell=True, env=role_env).pid - - -def train(base_dir=MXNetTrainingEnvironment.BASE_DIRECTORY): - """ Runs mxnet training on a user supplied module in either a local or distributed - SageMaker environment. - - The user supplied module and its dependencies are downloaded from S3, and the module - imported using a ``MXNetTrainingEnvironment`` instance. - - Training is invoked by calling a "train" function in the user supplied module. - - if the environment contains multiple hosts, then a distributed learning - task is started. This function will, in addition to running the user supplied script - as an mxnet parameter server worker process, launch an additional mxnet server - process. If the host this process is executing on is designated as the scheduler, then - this funciton will launch an mxnet scheduler parameter server process. - - Args: - base_dir (str): The SageMaker container environment base directory. - """ - logger.warning(UPCOMING_SCRIPT_MODE_WARNING) - - mxnet_env = MXNetTrainingEnvironment(base_dir) - logger.info("MXNetTrainingEnvironment: {}".format(repr(mxnet_env.__dict__))) - - if mxnet_env.user_script_archive.lower().startswith('s3://'): - mxnet_env.download_user_module() - - logger.info("Starting distributed training task") - if mxnet_env.current_host_scheduler: - _run_mxnet_process("scheduler", mxnet_env) - _run_mxnet_process("server", mxnet_env) - os.environ.update(mxnet_env.env_vars_for_role("worker")) - - user_module = mxnet_env.import_user_module() - train_args = inspect.getargspec(user_module.train) - - # avoid forcing our callers to specify **kwargs in their function - # signature. If they have **kwargs we still pass all the args, but otherwise - # we will just pass what they ask for. - if train_args.keywords is None: - kwargs_to_pass = {} - for arg in train_args.args: - if arg != "self" and arg in mxnet_env.kwargs_for_training: - kwargs_to_pass[arg] = mxnet_env.kwargs_for_training[arg] - else: - kwargs_to_pass = mxnet_env.kwargs_for_training - - model = user_module.train(**kwargs_to_pass) - if model: - if hasattr(user_module, 'save'): - user_module.save(model, mxnet_env.model_dir) - else: - mxnet_env.default_save(model) - - mxnet_env.write_success_file() diff --git a/src/sagemaker_mxnet_container/__init__.py b/src/sagemaker_mxnet_container/__init__.py new file mode 100644 index 00000000..45ba045d --- /dev/null +++ b/src/sagemaker_mxnet_container/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from sagemaker_mxnet_container import training_utils # noqa: F401 diff --git a/src/sagemaker_mxnet_container/serving.py b/src/sagemaker_mxnet_container/serving.py new file mode 100644 index 00000000..5aa98978 --- /dev/null +++ b/src/sagemaker_mxnet_container/serving.py @@ -0,0 +1,332 @@ +# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import json +import logging +import os + +import mxnet as mx +from sagemaker_containers.beta.framework import (content_types, encoders, env, errors, modules, + transformer, worker) + +logger = logging.getLogger(__name__) + +PREFERRED_BATCH_SIZE_PARAM = 'SAGEMAKER_DEFAULT_MODEL_FIRST_DIMENSION_SIZE' +DEFAULT_ENV_VARS = { + 'MXNET_CPU_WORKER_NTHREADS': '1', + 'MXNET_CPU_PRIORITY_NTHREADS': '1', + 'MXNET_KVSTORE_REDUCTION_NTHREADS': '1', + 'MXNET_ENGINE_TYPE': 'NativeEngine', + 'OMP_NUM_THREADS': '1', +} + +DEFAULT_MODEL_NAME = 'model' +DEFAULT_MODEL_FILENAMES = { + 'symbol': 'model-symbol.json', + 'params': 'model-0000.params', + 'shapes': 'model-shapes.json', +} + + +def default_model_fn(model_dir, preferred_batch_size=1): + """Function responsible for loading the model. This implementation is designed to work with + the default save function provided for MXNet training. + + Args: + model_dir (str): The directory where model files are stored + preferred_batch_size (int): The preferred batch size of the model's data shape (default: 1) + + Returns: + mxnet.mod.Module: the loaded model. + """ + for f in DEFAULT_MODEL_FILENAMES.values(): + path = os.path.join(model_dir, f) + if not os.path.exists(path): + raise ValueError('Failed to load model with default model_fn: missing file {}.' + 'Expected files: {}'.format(f, [file_name for _, file_name + in DEFAULT_MODEL_FILENAMES.items()])) + + shapes_file = os.path.join(model_dir, DEFAULT_MODEL_FILENAMES['shapes']) + preferred_batch_size = preferred_batch_size or os.environ.get(PREFERRED_BATCH_SIZE_PARAM) + data_names, data_shapes = _read_data_shapes(shapes_file, preferred_batch_size) + + sym, args, aux = mx.model.load_checkpoint(os.path.join(model_dir, DEFAULT_MODEL_NAME), 0) + + # TODO mxnet ctx - better default, allow user control + mod = mx.mod.Module(symbol=sym, context=mx.cpu(), data_names=data_names, label_names=None) + mod.bind(for_training=False, data_shapes=data_shapes) + mod.set_params(args, aux, allow_missing=True) + + return mod + + +def _read_data_shapes(path, preferred_batch_size=1): + with open(path, 'r') as f: + signature = json.load(f) + + data_names = [] + data_shapes = [] + + for s in signature: + name = s['name'] + data_names.append(name) + + shape = s['shape'] + + if preferred_batch_size: + shape[0] = preferred_batch_size + + data_shapes.append((name, shape)) + + return data_names, data_shapes + + +class MXNetTransformer(transformer.Transformer): + """Base class for creating Transformers with default methods for use with MXNet models. + """ + + VALID_CONTENT_TYPES = (content_types.JSON,) + + def __init__(self, model=None, model_fn=None, input_fn=None, predict_fn=None, output_fn=None, + error_class=None): + """Initialize an ``MXNetTransformer``. For each function, if one is not specified, + a default implementation is used. + + Args: + model (obj): a loaded model object that is ready for to be used for prediction + model_fn (fn): a function that loads a model + input_fn (fn): a function that takes request data and deserializes it for prediction + predict_fn (fn): a function that performs prediction with a model + output_fn (fn): a function that serializes a prediction into a response + error_class (Exception): the error class used to wrap functions that are not + the default ones defined in SageMaker Containers. + """ + input_fn = input_fn or self.default_input_fn + predict_fn = predict_fn or self.default_predict_fn + output_fn = output_fn or self.default_output_fn + + super(MXNetTransformer, self).__init__(model_fn=model_fn, input_fn=input_fn, + predict_fn=predict_fn, output_fn=output_fn, + error_class=error_class) + self._model = model + + def initialize(self): + """Execute any initialization necessary to start making predictions with the Transformer. + This method will load a model if it hasn't been loaded already. + """ + if self._model is None: + super(MXNetTransformer, self).initialize() + + def default_input_fn(self, input_data, content_type): + """Take request data and deserialize it into an object for prediction. + When an InvokeEndpoint operation is made against an Endpoint running SageMaker model server, + the model server receives two pieces of information: + + - The request's content type, for example "application/json" + - The request data + + The ``input_fn`` is responsible for preprocessing request data before prediction. + + Args: + input_data (obj): the request data + content_type (str): the request's content type + + Returns: + mxnet.nd.array: an MXNet NDArray + + Raises: + sagemaker_containers.beta.framework.errors.UnsupportedFormatError: if an unsupported + content type is used. + """ + if content_type in self.VALID_CONTENT_TYPES: + np_array = encoders.decode(input_data, content_type) + return mx.nd.array(np_array) + else: + raise errors.UnsupportedFormatError(content_type) + + def default_predict_fn(self, data, model): + """Use the model to create a prediction for the data. + + Args: + data (obj): input data for prediction + model (obj): the loaded model + + Returns: + obj: the prediction result + """ + transformer.default_predict_fn(data, model) + + def default_output_fn(self, prediction, accept): + """Serialize the prediction into a response. + + Args: + prediction (mxnet.nd.array): an MXNet NDArray that is the result of a prediction + accept (str): the accept content type expected by the client + + Returns: + sagemaker_containers.beta.framework.worker.Response: a Flask response object + + Raises: + sagemaker_containers.beta.framework.errors.UnsupportedFormatError: if an unsupported + accept type is used. + """ + if accept in self.VALID_CONTENT_TYPES: + return worker.Response(encoders.encode(prediction.asnumpy().tolist(), accept), accept) + else: + raise errors.UnsupportedFormatError(accept) + + +class ModuleTransformer(MXNetTransformer): + + VALID_CONTENT_TYPES = (content_types.JSON, content_types.CSV) + + def default_input_fn(self, input_data, content_type): + """Take request data and deserialize it into an object for prediction. + When an InvokeEndpoint operation is made against an Endpoint running SageMaker model server, + the model server receives two pieces of information: + + - The request's content type, for example "application/json" + - The request data + + The ``input_fn`` is responsible for preprocessing request data before prediction. + + Args: + input_data (obj): the request data + content_type (str): the request's content type + + Returns: + mxnet.io.NDArrayIter: data ready for prediction. + + Raises: + sagemaker_containers.beta.framework.errors.UnsupportedFormatError: if an unsupported + accept type is used. + """ + if content_type not in self.VALID_CONTENT_TYPES: + raise errors.UnsupportedFormatError(content_type) + + np_array = encoders.decode(input_data, content_type) + ndarray = mx.nd.array(np_array) + + # We require model to only have one input + [data_shape] = self._model.data_shapes + + # Reshape flattened CSV as specified by the model + if content_type == content_types.CSV: + _, target_shape = data_shape + ndarray = ndarray.reshape(target_shape) + + # Batch size is first dimension of model input + model_batch_size = data_shape[1][0] + pad_rows = max(0, model_batch_size - ndarray.shape[0]) + + # If ndarray has fewer rows than model_batch_size, then pad it with zeros. + if pad_rows: + num_pad_values = pad_rows + for dimension in ndarray.shape[1:]: + num_pad_values *= dimension + padding_shape = tuple([pad_rows] + list(ndarray.shape[1:])) + padding = mx.ndarray.zeros(shape=padding_shape) + ndarray = mx.ndarray.concat(ndarray, padding, dim=0) + + model_input = mx.io.NDArrayIter(ndarray, batch_size=model_batch_size, + last_batch_handle='pad') + + if pad_rows: + # Update the getpad method on the model_input data iterator to return the amount of + # padding. MXNet will ignore the last getpad() rows during Module predict. + def _getpad(): + return pad_rows + + model_input.getpad = _getpad + + return model_input + + def default_predict_fn(self, data, module): + """Use the model to create a prediction for the data. + + Args: + data (mxnet.io.NDArrayIter): input data for prediction + model (mxnet.module.BaseModule): an MXNet Module + + Returns: + list: the prediction result. This will be either a list of ``NDArray`` or + a list of lists of ``NDArray`` + """ + return module.predict(data) + + +class GluonBlockTransformer(MXNetTransformer): + def default_predict_fn(self, data, block): + """Use the model to create a prediction for the data. + + Args: + data (mxnet.nd.array): input data for prediction (deserialized by ``input_fn``) + block (mxnet.gluon.block.Block): a Gluon neural network + + Returns: + mxnet.nd.array: the prediction result + """ + return block(data) + + +def _update_mxnet_env_vars(): + for k, v in DEFAULT_ENV_VARS.items(): + if k not in os.environ: + os.environ[k] = v + + +def _transformer_with_transform_fn(model_fn, transform_fn): + user_transformer = transformer.Transformer(model_fn=model_fn, transform_fn=transform_fn) + user_transformer.initialize() + return user_transformer + + +def _user_module_transformer(user_module, model_dir): + model_fn = getattr(user_module, 'model_fn', default_model_fn) + + if hasattr(user_module, 'transform_fn'): + return _transformer_with_transform_fn(model_fn, getattr(user_module, 'transform_fn')) + + model = model_fn(model_dir) + if isinstance(model, mx.module.BaseModule): + transformer_cls = ModuleTransformer + elif isinstance(model, mx.gluon.block.Block): + transformer_cls = GluonBlockTransformer + else: + raise ValueError('Unsupported model type: {}'.format(model.__class__.__name__)) + + input_fn = getattr(user_module, 'input_fn', None) + predict_fn = getattr(user_module, 'predict_fn', None) + output_fn = getattr(user_module, 'output_fn', None) + + return transformer_cls(model=model, model_fn=model_fn, input_fn=input_fn, + predict_fn=predict_fn, output_fn=output_fn) + + +app = None + + +def main(environ, start_response): + global app + if app is None: + serving_env = env.ServingEnv() + _update_mxnet_env_vars() + + user_module = modules.import_module(serving_env.module_dir, serving_env.module_name) + user_transformer = _user_module_transformer(user_module, serving_env.model_dir) + + app = worker.Worker(transform_fn=user_transformer.transform, + module_name=serving_env.module_name) + + return app(environ, start_response) diff --git a/src/sagemaker_mxnet_container/training.py b/src/sagemaker_mxnet_container/training.py new file mode 100644 index 00000000..ff076e48 --- /dev/null +++ b/src/sagemaker_mxnet_container/training.py @@ -0,0 +1,82 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the 'license' file accompanying this file. This file is +# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import logging +import os +import socket +import subprocess + +from retrying import retry +import sagemaker_containers.beta.framework as framework + +from sagemaker_mxnet_container.training_utils import scheduler_host + +LAUNCH_PS_ENV_NAME = 'sagemaker_parameter_server_enabled' +ROLES = ['worker', 'scheduler', 'server'] + +logger = logging.getLogger(__name__) + + +def _env_vars_for_role(role, hosts, ps_port, ps_verbose): + if role in ROLES: + return { + 'DMLC_NUM_WORKER': str(len(hosts)), + 'DMLC_NUM_SERVER': str(len(hosts)), + 'DMLC_ROLE': role, + 'DMLC_PS_ROOT_URI': _host_lookup(scheduler_host(hosts)), + 'DMLC_PS_ROOT_PORT': ps_port, + 'PS_VERBOSE': ps_verbose, + } + + raise ValueError('Unexpected role: {}'.format(role)) + + +def _run_mxnet_process(role, hosts, ps_port, ps_verbose): + role_env = os.environ.copy() + role_env.update(_env_vars_for_role(role, hosts, ps_port, ps_verbose)) + subprocess.Popen("python -c 'import mxnet'", shell=True, env=role_env).pid + + +@retry(stop_max_delay=1000 * 60 * 15, wait_exponential_multiplier=100, + wait_exponential_max=30000) +def _host_lookup(host): + return socket.gethostbyname(host) + + +def _verify_hosts(hosts): + for host in hosts: + _host_lookup(host) + + +def train(env): + logger.info('MXNet training environment: {}'.format(env.to_env_vars())) + + if env.additional_framework_parameters.get(LAUNCH_PS_ENV_NAME, False): + _verify_hosts(env.hosts) + + ps_port = env.hyperparameters.get('_ps_port', '8000') + ps_verbose = env.hyperparameters.get('_ps_verbose', '0') + + logger.info('Starting distributed training task') + if scheduler_host(env.hosts) == env.current_host: + _run_mxnet_process('scheduler', env.hosts, ps_port, ps_verbose) + _run_mxnet_process('server', env.hosts, ps_port, ps_verbose) + os.environ.update(_env_vars_for_role('worker', env.hosts, ps_port, ps_verbose)) + + framework.modules.run_module(env.module_dir, env.to_cmd_args(), + env.to_env_vars(), env.module_name) + + +def main(): + train(framework.training_env()) diff --git a/src/sagemaker_mxnet_container/training_utils.py b/src/sagemaker_mxnet_container/training_utils.py new file mode 100644 index 00000000..5168d2bf --- /dev/null +++ b/src/sagemaker_mxnet_container/training_utils.py @@ -0,0 +1,62 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the 'license' file accompanying this file. This file is +# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import json +import os + +SYMBOL_PATH = 'model-symbol.json' +PARAMS_PATH = 'model-0000.params' +SHAPES_PATH = 'model-shapes.json' + + +def save(model_dir, model, current_host=None, hosts=None): + """Save an MXNet Module to a given location if the current host is the scheduler host. + + This generates three files in the model directory: + + - model-symbol.json: The serialized module symbolic graph. + Formed by invoking ``module.symbole.save``. + - model-0000.params: The serialized module parameters. + Formed by invoking ``module.save_params``. + - model-shapes.json: The serialized module input data shapes in the form of a JSON list of + JSON data-shape objects. Each data-shape object contains a string name and + a list of integer dimensions. + + Args: + model_dir (str): the directory for saving the model + model (mxnet.mod.Module): the module to be saved + """ + current_host = current_host or os.environ['SM_CURRENT_HOST'] + hosts = hosts or json.loads(os.environ['SM_HOSTS']) + + if current_host == scheduler_host(hosts): + model.symbol.save(os.path.join(model_dir, SYMBOL_PATH)) + model.save_params(os.path.join(model_dir, PARAMS_PATH)) + + signature = [{'name': data_desc.name, 'shape': [dim for dim in data_desc.shape]} + for data_desc in model.data_shapes] + with open(os.path.join(model_dir, SHAPES_PATH), 'w') as f: + json.dump(signature, f) + + +def scheduler_host(hosts): + """Return which host in a list of hosts serves as the scheduler for a parameter server setup. + + Args: + hosts (list[str]): a list of hosts + + Returns: + str: the name of the scheduler host + """ + return hosts[0] diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 00000000..74f14335 --- /dev/null +++ b/test/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import diff --git a/test/integ/conftest.py b/test/conftest.py similarity index 70% rename from test/integ/conftest.py rename to test/conftest.py index 6a62d673..35cddbe4 100644 --- a/test/integ/conftest.py +++ b/test/conftest.py @@ -1,26 +1,26 @@ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# +# # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the "license" file accompanying this file. This file is distributed -# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing # permissions and limitations under the License. +from __future__ import absolute_import import logging import os import platform - -import boto3 -import pytest import shutil import tempfile -from sagemaker import Session +import boto3 +import pytest +from sagemaker import LocalSession, Session logger = logging.getLogger(__name__) logging.getLogger('boto').setLevel(logging.INFO) @@ -35,12 +35,15 @@ def pytest_addoption(parser): parser.addoption('--docker-base-name', default='preprod-mxnet') parser.addoption('--region', default='us-west-2') - parser.addoption('--framework-version', required=True) - parser.addoption('--py-version', required=True, choices=['2', '3']) - parser.addoption('--processor', required=True, choices=['gpu','cpu']) + parser.addoption('--framework-version', default='1.3.0') + parser.addoption('--py-version', default='3', choices=['2', '3']) + parser.addoption('--processor', default='cpu', choices=['gpu', 'cpu']) + parser.addoption('--aws-id', default=None) + parser.addoption('--instance-type', default=None) # If not specified, will default to {framework-version}-{processor}-py{py-version} parser.addoption('--tag', default=None) + @pytest.fixture(scope='session') def docker_base_name(request): return request.config.getoption('--docker-base-name') @@ -66,6 +69,11 @@ def processor(request): return request.config.getoption('--processor') +@pytest.fixture(scope='session') +def aws_id(request): + return request.config.getoption('--aws-id') + + @pytest.fixture(scope='session') def tag(request, framework_version, processor, py_version): provided_tag = request.config.getoption('--tag') @@ -73,16 +81,37 @@ def tag(request, framework_version, processor, py_version): return provided_tag if provided_tag is not None else default_tag +@pytest.fixture(scope='session') +def instance_type(request, processor): + return request.config.getoption('--instance-type') or \ + 'ml.c4.xlarge' if processor == 'cpu' else 'ml.p2.xlarge' + + @pytest.fixture(scope='session') def docker_image(docker_base_name, tag): return '{}:{}'.format(docker_base_name, tag) +@pytest.fixture(scope='session') +def ecr_image(aws_id, docker_base_name, tag, region): + return '{}.dkr.ecr.{}.amazonaws.com/{}:{}'.format(aws_id, region, docker_base_name, tag) + + @pytest.fixture(scope='session') def sagemaker_session(region): return Session(boto_session=boto3.Session(region_name=region)) +@pytest.fixture(scope='session') +def sagemaker_local_session(region): + return LocalSession(boto_session=boto3.Session(region_name=region)) + + +@pytest.fixture(scope='session') +def local_instance_type(processor): + return 'local' if processor == 'cpu' else 'local_gpu' + + @pytest.fixture def opt_ml(): tmp = tempfile.mkdtemp() diff --git a/test/functional/conftest.py b/test/functional/conftest.py deleted file mode 100644 index bdfc7757..00000000 --- a/test/functional/conftest.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# A copy of the License is located at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the "license" file accompanying this file. This file is distributed -# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing -# permissions and limitations under the License. - -import logging - -import boto3 -import pytest -from sagemaker import Session - -logger = logging.getLogger(__name__) -logging.getLogger('boto').setLevel(logging.INFO) -logging.getLogger('botocore').setLevel(logging.INFO) -logging.getLogger('factory.py').setLevel(logging.INFO) -logging.getLogger('auth.py').setLevel(logging.INFO) -logging.getLogger('connectionpool.py').setLevel(logging.INFO) - - -def pytest_addoption(parser): - parser.addoption('--aws-id', required=True) - parser.addoption('--docker-base-name', default='preprod-mxnet') - parser.addoption('--instance-type', required=True) - parser.addoption('--region', default='us-west-2') - parser.addoption('--tag', required=True) - - -@pytest.fixture(scope='session') -def aws_id(request): - return request.config.getoption('--aws-id') - - -@pytest.fixture(scope='session') -def docker_base_name(request): - return request.config.getoption('--docker-base-name') - - -@pytest.fixture(scope='session') -def instance_type(request): - return request.config.getoption('--instance-type') - - -@pytest.fixture(scope='session') -def region(request): - return request.config.getoption('--region') - - -@pytest.fixture(scope='session') -def tag(request): - return request.config.getoption('--tag') - - -@pytest.fixture(scope='session') -def docker_registry(aws_id, region): - return '{}.dkr.ecr.{}.amazonaws.com'.format(aws_id, region) - - -@pytest.fixture(scope='module') -def ecr_image(docker_registry, docker_base_name, tag): - return '{}/{}:{}'.format(docker_registry, docker_base_name, tag) - - -@pytest.fixture(scope='session') -def sagemaker_session(region): - return Session(boto_session=boto3.Session(region_name=region)) diff --git a/test/integ/docker_utils.py b/test/integ/docker_utils.py deleted file mode 100644 index ef10bc15..00000000 --- a/test/integ/docker_utils.py +++ /dev/null @@ -1,205 +0,0 @@ -# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# A copy of the License is located at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the "license" file accompanying this file. This file is distributed -# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing -# permissions and limitations under the License. - -from __future__ import absolute_import - -import logging -import os -import subprocess -import sys -import tempfile -import uuid -from time import sleep - -logger = logging.getLogger(__name__) - -CYAN_COLOR = '\033[36m' -END_COLOR = '\033[0m' - - -def registry(aws_id, region): - return '{}.dkr.ecr.{}.amazonaws.com'.format(aws_id, region) - - -def train(image_name, resource_folder, processor): - docker_cmd = 'docker' if processor == 'cpu' else 'nvidia-docker' - - cmd = [docker_cmd, - 'run', - '--rm', - '-h', 'algo-1', - '-v', '{}:/opt/ml'.format(resource_folder), - '-e', 'AWS_ACCESS_KEY_ID', - '-e', 'AWS_SECRET_ACCESS_KEY', - '-e', 'AWS_SESSION_TOKEN', - image_name, 'train'] - check_call(cmd) - - -def check_call(cmd, *popenargs, **kwargs): - if isinstance(cmd, str): - cmd = cmd.split(" ") - _print_cmd(cmd) - subprocess.check_call(cmd, *popenargs, **kwargs) - - -def _print_cmd(cmd): - print('executing docker command: {}{}{}'.format(CYAN_COLOR, ' '.join(cmd), END_COLOR)) - sys.stdout.flush() - - -class Container(object): - def __init__(self, image, processor, startup_delay=1): - self.temp_dir = tempfile.gettempdir() - self.image = image - self.name = str(uuid.uuid4()) - self.startup_delay = startup_delay - self.docker_cmd = 'docker' if processor == 'cpu' else 'nvidia-docker' - - def __enter__(self): - print('in container.enter for container ' + self.image + ',' + self.name) - self.remove_container() - - cmd = [self.docker_cmd, - 'run', - '-d', - '-t', - '-e', 'AWS_ACCESS_KEY_ID', - '-e', 'AWS_SECRET_ACCESS_KEY', - '-e', 'AWS_SESSION_TOKEN', - '--entrypoint', 'bash', - '--name', self.name, - self.image] - - check_call(cmd) - - # waiting for the server to spin up - sleep(self.startup_delay) - - self.execute_command(['pip', 'install', 'requests']) - self.execute_command(['pip', 'install', 'pytest']) - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.remove_container() - - def remove_container(self): - cmd = [self.docker_cmd, - 'rm', - '-f', - self.name] - - try: - check_call(cmd) - except: - pass - - def copy(self, src, dst): - cmd = [self.docker_cmd, - 'cp', - src, - '{}:{}'.format(self.name, dst)] - - check_call(cmd) - - def execute_command(self, cmd): - - docker_cmd = [self.docker_cmd, 'exec', '-t', self.name] - docker_cmd.extend(cmd) - - _print_cmd(docker_cmd) - - lines = [] - process = subprocess.Popen(docker_cmd, stdout=subprocess.PIPE) - print( - '{}============================= container output ============================='.format( - CYAN_COLOR)) - for line in iter(process.stdout.readline, b''): - sys.stdout.write(line.decode('utf-8')) - sys.stdout.flush() - lines.append(line.decode('utf-8')) - msg = '\n{}========================= end of container output ==========================' - print(msg.format(CYAN_COLOR)) - - process.wait() - - warnings = 0 - for line in lines: - if line.startswith('WARNING'): - warnings += 1 - print(line) - else: - break - output = '\n'.join(lines[warnings:]) - - if process.returncode != 0: - print("docker exec error. output:\n{}".format(output)) - raise ValueError("non-zero exit code: {}".format(process.returncode)) - - return output - - def execute_pytest(self, tests_path): - container_test_path = '/root/{}'.format(os.path.basename(tests_path)) - self.copy(tests_path, container_test_path) - return self.execute_command(['pytest', '-vv', '-s', '--color=yes', container_test_path]) - - -class HostingContainer(Container): - def __init__(self, image, opt_ml, script_name, processor, region='us-west-2', - requirements_file=None, startup_delay=5): - super(HostingContainer, self).__init__(image=image, - processor=processor, - startup_delay=startup_delay) - self.opt_ml = opt_ml - self.script_name = script_name - self.region = region - self.requirements_file = requirements_file - - def __enter__(self): - cmd = [self.docker_cmd, - 'run', - '-d', - '-h', 'algo-1', - '-v', '{}:/opt/ml'.format(self.opt_ml), - '-e', 'AWS_ACCESS_KEY_ID', - '-e', 'AWS_SECRET_ACCESS_KEY', - '-e', 'AWS_SESSION_TOKEN', - '-e', 'SAGEMAKER_CONTAINER_LOG_LEVEL=20', - '-e', 'SAGEMAKER_REGION={}'.format(self.region), - '-e', 'SAGEMAKER_PROGRAM={}'.format(self.script_name), - '-e', 'SAGEMAKER_REQUIREMENTS={}'.format(self.requirements_file), - '--name', self.name, - self.image, 'serve'] - - check_call(cmd) - - # waiting for the server to spin up - sleep(self.startup_delay) - - self.execute_command(['pip', 'install', 'requests']) - self.execute_command(['pip', 'install', 'pytest']) - - return self - - def invoke_endpoint(self, input, content_type='application/json', accept='application/json'): - return self.execute_command([ - 'curl', - '-f', - '-H', 'Content-Type: {}'.format(content_type), - '-H', 'Accept: {}'.format(accept), - '-d', str(input), - 'http://127.0.0.1:8080/invocations' - ]) - - def ping(self): - self.execute_command(['curl', '-f', '-v', 'http://localhost:8080/ping']) diff --git a/test/integ/test_default_model_fn.py b/test/integ/test_default_model_fn.py deleted file mode 100644 index a0c0c00b..00000000 --- a/test/integ/test_default_model_fn.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# A copy of the License is located at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the "license" file accompanying this file. This file is distributed -# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing -# permissions and limitations under the License. - -import docker_utils -import utils - - -# The image should serve a MXNet model saved in the -# default format, even without a user-provided script. -def test_default_model_fn(docker_image, opt_ml, processor): - resource_path = 'test/resources/default_handlers' - for path in ['code', 'model']: - utils.copy_resource(resource_path, opt_ml, path) - - input = [[1, 2]] - - with docker_utils.HostingContainer(image=docker_image, processor=processor, - opt_ml=opt_ml, script_name='empty_module.py') as c: - c.ping() - output = c.invoke_endpoint(input) - assert '[[4.9999918937683105]]' == output diff --git a/test/integ/test_gluon_hosting.py b/test/integ/test_gluon_hosting.py deleted file mode 100644 index aeb423ce..00000000 --- a/test/integ/test_gluon_hosting.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# A copy of the License is located at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the "license" file accompanying this file. This file is distributed -# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing -# permissions and limitations under the License. - -import json - -import docker_utils -import utils - - -# The image should support serving Gluon-created models. -def test_gluon_hosting(docker_image, opt_ml, processor): - resource_path = 'test/resources/gluon_hosting' - for path in ['code', 'model']: - utils.copy_resource(resource_path, opt_ml, path) - - with open('test/resources/mnist_images/04.json', 'r') as f: - input = json.load(f) - - with docker_utils.HostingContainer(image=docker_image, processor=processor, - opt_ml=opt_ml, script_name='gluon.py') as c: - c.ping() - output = c.invoke_endpoint(input) - assert '[4.0]' == output diff --git a/test/integ/test_hosting.py b/test/integ/test_hosting.py deleted file mode 100644 index ae2e02db..00000000 --- a/test/integ/test_hosting.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# A copy of the License is located at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the "license" file accompanying this file. This file is distributed -# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing -# permissions and limitations under the License. - -import json - -import docker_utils -import utils - - -# The image should use the model_fn and transform_fn defined -# in the user-provided script when serving. -def test_hosting(docker_image, opt_ml, processor): - resource_path = 'test/resources/dummy_hosting' - utils.copy_resource(resource_path, opt_ml, 'code') - - input = json.dumps({'some': 'json'}) - - with docker_utils.HostingContainer(image=docker_image, processor=processor, - opt_ml=opt_ml, script_name='dummy_hosting_module.py') as c: - c.ping() - output = c.invoke_endpoint(input) - assert input == output diff --git a/test/integ/test_linear_regression.py b/test/integ/test_linear_regression.py deleted file mode 100644 index b1c04d81..00000000 --- a/test/integ/test_linear_regression.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# A copy of the License is located at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the "license" file accompanying this file. This file is distributed -# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing -# permissions and limitations under the License. - -from __future__ import print_function - -import os - -import docker_utils -import numpy as np -import utils -from sagemaker import fw_utils -from sagemaker.utils import sagemaker_timestamp - - -def test_linear_regression(docker_image, sagemaker_session, opt_ml, processor): - resource_path = 'test/resources/linear_regression' - - # create training data - train_data = np.random.uniform(0, 1, [1000, 2]) - train_label = np.array([train_data[i][0] + 2 * train_data[i][1] for i in range(1000)]) - - # eval data... repeat so there's enough to cover multicpu/gpu contexts - eval_data = np.array([[7, 2], [6, 10], [12, 2]]).repeat(32, 0) - eval_label = np.array([11, 26, 16]).repeat(32, 0) - - # save training data - for path in ['training', 'evaluation']: - os.makedirs(os.path.join(opt_ml, 'input', 'data', path)) - np.savetxt(os.path.join(opt_ml, 'input/data/training/train_data.txt'), train_data) - np.savetxt(os.path.join(opt_ml, 'input/data/training/train_label.txt'), train_label) - np.savetxt(os.path.join(opt_ml, 'input/data/evaluation/eval_data.txt'), eval_data) - np.savetxt(os.path.join(opt_ml, 'input/data/evaluation/eval_label.txt'), eval_label) - - s3_source_archive = fw_utils.tar_and_upload_dir(session=sagemaker_session.boto_session, - bucket=sagemaker_session.default_bucket(), - s3_key_prefix=sagemaker_timestamp(), - script='linear_regression.py', - directory=resource_path) - - utils.create_config_files('linear_regression.py', s3_source_archive.s3_prefix, opt_ml) - os.makedirs(os.path.join(opt_ml, 'model')) - - docker_utils.train(docker_image, opt_ml, processor) - - for f in ['output/success', 'model/model-symbol.json', 'model/model-0000.params', - 'model/model-shapes.json']: - assert os.path.exists(os.path.join(opt_ml, f)), 'expected file not found: {}'.format(f) diff --git a/test/integ/test_py_version.py b/test/integ/test_py_version.py deleted file mode 100644 index baaaad14..00000000 --- a/test/integ/test_py_version.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# A copy of the License is located at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the "license" file accompanying this file. This file is distributed -# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing -# permissions and limitations under the License. - -from __future__ import print_function - -import json -import os - -import docker_utils -import utils -from sagemaker import fw_utils -from sagemaker.utils import sagemaker_timestamp - - -# The image should run the user-provided code using the right python version when training. -def test_train_py_version(docker_image, sagemaker_session, py_version, opt_ml, processor): - resource_path = 'test/resources/py_version/code' - - s3_source_archive = fw_utils.tar_and_upload_dir(session=sagemaker_session.boto_session, - bucket=sagemaker_session.default_bucket(), - s3_key_prefix=sagemaker_timestamp(), - script='usermodule.py', - directory=resource_path) - - hp = _py_version_dict(py_version) - - utils.create_config_files('usermodule.py', s3_source_archive.s3_prefix, opt_ml, - additional_hp=hp) - os.makedirs(os.path.join(opt_ml, 'model')) - docker_utils.train(docker_image, opt_ml, processor) - - # The usermodule.py train_fn will assert on the expected - # python versions passed in through hyperparameters, - # and training will fail if they are incorrect. - - success_file = os.path.join(opt_ml, 'output', 'success') - assert os.path.exists(success_file), 'expected file not found: {}'.format(success_file) - - -# The image should run the user-provided code using the right python version when hosting. -def test_hosting_py_version(docker_image, py_version, opt_ml, processor): - resource_path = 'test/resources/py_version' - utils.copy_resource(resource_path, opt_ml, 'code') - - input = json.dumps(_py_version_dict(py_version)) - - with docker_utils.HostingContainer(image=docker_image, processor=processor, - opt_ml=opt_ml, script_name='usermodule.py') as c: - c.ping() - # We send the json of the expect py versions in the request. - # The usermodule.py transform_fn will assert on the python versions, - # and this request will fail and throw an exception if they are incorrect. - c.invoke_endpoint(input) - - -def _py_version_dict(py_version): - maj_to_minor = {2: 7, # Need Python 2.7 for Python 2 - 3: 4} # Need Python 3.4 or above for Python 3 - - return {'py_major_version': str(py_version), - 'py_minimum_minor_version': str(maj_to_minor[py_version])} diff --git a/test/integ/utils.py b/test/integ/utils.py deleted file mode 100644 index 4089f3d7..00000000 --- a/test/integ/utils.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# A copy of the License is located at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the "license" file accompanying this file. This file is distributed -# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing -# permissions and limitations under the License. - -import json -import logging -import os -import shutil - - -def serialize_hyperparameters(hp): - return {str(k): json.dumps(v) for (k, v) in hp.items()} - - -def save_as_json(data, filename): - with open(filename, "wt") as f: - json.dump(data, f) - - -def file_exists(resource_folder, file_name): - return os.path.exists(os.path.join(resource_folder, file_name)) - - -def create_config_files(program, s3_source_archive, path, additional_hp={}): - rc = { - "current_host": "algo-1", - "hosts": ["algo-1"] - } - - hp = {'sagemaker_region': 'us-west-2', - 'sagemaker_program': program, - 'sagemaker_submit_directory': s3_source_archive, - 'sagemaker_container_log_level': logging.INFO} - - hp.update(additional_hp) - - ic = { - "training": {"ContentType": "trainingContentType"}, - "evaluation": {"ContentType": "evalContentType"}, - "Validation": {} - } - - write_conf_files(rc, hp, ic, path) - - -def write_conf_files(rc, hp, ic, path): - os.makedirs('{}/input/config'.format(path)) - - rc_file = os.path.join(path, 'input/config/resourceconfig.json') - hp_file = os.path.join(path, 'input/config/hyperparameters.json') - ic_file = os.path.join(path, 'input/config/inputdataconfig.json') - - hp = serialize_hyperparameters(hp) - - save_as_json(rc, rc_file) - save_as_json(hp, hp_file) - save_as_json(ic, ic_file) - - -def copy_resource(resource_path, opt_ml_path, relative_src_path, relative_dst_path=None): - if not relative_dst_path: - relative_dst_path = relative_src_path - - shutil.copytree(os.path.join(resource_path, relative_src_path), - os.path.join(opt_ml_path, relative_dst_path)) diff --git a/test/integration/__init__.py b/test/integration/__init__.py new file mode 100644 index 00000000..98debc09 --- /dev/null +++ b/test/integration/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import os + +RESOURCE_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'resources')) + +MODEL_SUCCESS_FILES = { + 'output': ['success'], + 'model': ['model-symbol.json', 'model-shapes.json', 'model-0000.params'], +} + +# Workaround for the intermittent worker timeout errors +# TODO: find and solve the root cause of this issue +NUM_MODEL_SERVER_WORKERS = 2 diff --git a/test/integration/local/local_mode_utils.py b/test/integration/local/local_mode_utils.py new file mode 100644 index 00000000..f4b871b7 --- /dev/null +++ b/test/integration/local/local_mode_utils.py @@ -0,0 +1,45 @@ +# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from contextlib import contextmanager +import fcntl +import os +import tarfile +import time + +from test.integration import RESOURCE_PATH + +LOCK_PATH = os.path.join(RESOURCE_PATH, 'local_mode_lock') + + +@contextmanager +def lock(): + # Since Local Mode uses the same port for serving, we need a lock in order + # to allow concurrent test execution. + local_mode_lock_fd = open(LOCK_PATH, 'w') + local_mode_lock = local_mode_lock_fd.fileno() + + fcntl.lockf(local_mode_lock, fcntl.LOCK_EX) + + try: + yield + finally: + time.sleep(5) + fcntl.lockf(local_mode_lock, fcntl.LOCK_UN) + + +def assert_output_files_exist(output_path, directory, files): + with tarfile.open(os.path.join(output_path, '{}.tar.gz'.format(directory))) as tar: + for f in files: + tar.getmember(f) diff --git a/test/integration/local/test_default_model_fn.py b/test/integration/local/test_default_model_fn.py new file mode 100644 index 00000000..5824ccf6 --- /dev/null +++ b/test/integration/local/test_default_model_fn.py @@ -0,0 +1,40 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. +from __future__ import absolute_import + +import os + +from sagemaker.mxnet.model import MXNetModel + +import local_mode_utils +from test.integration import NUM_MODEL_SERVER_WORKERS, RESOURCE_PATH + + +# The image should serve a MXNet model saved in the +# default format, even without a user-provided script. +def test_default_model_fn(docker_image, sagemaker_local_session, local_instance_type): + default_handler_path = os.path.join(RESOURCE_PATH, 'default_handlers') + m = MXNetModel('file://{}'.format(os.path.join(default_handler_path, 'model')), 'SageMakerRole', + os.path.join(default_handler_path, 'code', 'empty_module.py'), + image=docker_image, sagemaker_session=sagemaker_local_session, + model_server_workers=NUM_MODEL_SERVER_WORKERS) + + input = [[1, 2]] + + with local_mode_utils.lock(): + try: + predictor = m.deploy(1, local_instance_type) + output = predictor.predict(input) + assert [[4.9999918937683105]] == output + finally: + sagemaker_local_session.delete_endpoint(m.endpoint_name) diff --git a/test/integration/local/test_gluon_hosting.py b/test/integration/local/test_gluon_hosting.py new file mode 100644 index 00000000..d51d78d8 --- /dev/null +++ b/test/integration/local/test_gluon_hosting.py @@ -0,0 +1,41 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. +from __future__ import absolute_import + +import json +import os + +from sagemaker.mxnet.model import MXNetModel + +import local_mode_utils +from test.integration import NUM_MODEL_SERVER_WORKERS, RESOURCE_PATH + + +# The image should support serving Gluon-created models. +def test_gluon_hosting(docker_image, sagemaker_local_session, local_instance_type): + gluon_path = os.path.join(RESOURCE_PATH, 'gluon_hosting') + m = MXNetModel('file://{}'.format(os.path.join(gluon_path, 'model')), 'SageMakerRole', + os.path.join(gluon_path, 'code', 'gluon.py'), image=docker_image, + sagemaker_session=sagemaker_local_session, + model_server_workers=NUM_MODEL_SERVER_WORKERS) + + with open(os.path.join(RESOURCE_PATH, 'mnist_images', '04.json'), 'r') as f: + input = json.load(f) + + with local_mode_utils.lock(): + try: + predictor = m.deploy(1, local_instance_type) + output = predictor.predict(input) + assert [4.0] == output + finally: + sagemaker_local_session.delete_endpoint(m.endpoint_name) diff --git a/test/integration/local/test_hosting.py b/test/integration/local/test_hosting.py new file mode 100644 index 00000000..a468043b --- /dev/null +++ b/test/integration/local/test_hosting.py @@ -0,0 +1,41 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. +from __future__ import absolute_import + +import json +import os + +from sagemaker.mxnet.model import MXNetModel + +import local_mode_utils +from test.integration import NUM_MODEL_SERVER_WORKERS, RESOURCE_PATH + + +# The image should use the model_fn and transform_fn defined +# in the user-provided script when serving. +def test_hosting(docker_image, sagemaker_local_session, local_instance_type): + hosting_resource_path = os.path.join(RESOURCE_PATH, 'dummy_hosting') + m = MXNetModel('file://{}'.format(os.path.join(hosting_resource_path, 'code')), 'SageMakerRole', + os.path.join(hosting_resource_path, 'code', 'dummy_hosting_module.py'), + image=docker_image, sagemaker_session=sagemaker_local_session, + model_server_workers=NUM_MODEL_SERVER_WORKERS) + + input = json.dumps({'some': 'json'}) + + with local_mode_utils.lock(): + try: + predictor = m.deploy(1, local_instance_type) + output = predictor.predict(input) + assert input == output + finally: + sagemaker_local_session.delete_endpoint(m.endpoint_name) diff --git a/test/integration/local/test_keras_training.py b/test/integration/local/test_keras_training.py new file mode 100644 index 00000000..4068fe3c --- /dev/null +++ b/test/integration/local/test_keras_training.py @@ -0,0 +1,37 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. +from __future__ import absolute_import + +import os + +from sagemaker.mxnet import MXNet + +import local_mode_utils +from test.integration import MODEL_SUCCESS_FILES, RESOURCE_PATH + + +def test_keras_training(docker_image, sagemaker_local_session, local_instance_type, + framework_version, tmpdir): + keras_path = os.path.join(RESOURCE_PATH, 'keras') + script_path = os.path.join(keras_path, 'keras_mnist.py') + + mx = MXNet(entry_point=script_path, role='SageMakerRole', train_instance_count=1, + train_instance_type=local_instance_type, sagemaker_session=sagemaker_local_session, + image_name=docker_image, framework_version=framework_version, + output_path='file://{}'.format(tmpdir)) + + train = 'file://{}'.format(os.path.join(keras_path, 'data')) + mx.fit({'train': train}) + + for directory, files in MODEL_SUCCESS_FILES.items(): + local_mode_utils.assert_output_files_exist(str(tmpdir), directory, files) diff --git a/test/integration/local/test_linear_regression.py b/test/integration/local/test_linear_regression.py new file mode 100644 index 00000000..b6baf675 --- /dev/null +++ b/test/integration/local/test_linear_regression.py @@ -0,0 +1,42 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. +from __future__ import absolute_import + +import os + +from sagemaker.mxnet import MXNet + +import local_mode_utils +from test.integration import MODEL_SUCCESS_FILES, RESOURCE_PATH + + +def test_linear_regression(docker_image, sagemaker_local_session, local_instance_type, + framework_version, tmpdir): + lr_path = os.path.join(RESOURCE_PATH, 'linear_regression') + + mx = MXNet(entry_point=os.path.join(lr_path, 'linear_regression.py'), role='SageMakerRole', + train_instance_count=1, train_instance_type=local_instance_type, + sagemaker_session=sagemaker_local_session, image_name=docker_image, + framework_version=framework_version, output_path='file://{}'.format(tmpdir)) + + data_path = os.path.join(lr_path, 'data') + s3_prefix = 'integ-test-data/mxnet-linear-regression' + train_input = sagemaker_local_session.upload_data(path=os.path.join(data_path, 'training'), + key_prefix=s3_prefix) + eval_input = sagemaker_local_session.upload_data(path=os.path.join(data_path, 'evaluation'), + key_prefix=s3_prefix) + + mx.fit({'training': train_input, 'evaluation': eval_input}) + + for directory, files in MODEL_SUCCESS_FILES.items(): + local_mode_utils.assert_output_files_exist(str(tmpdir), directory, files) diff --git a/test/integration/local/test_mnist_training.py b/test/integration/local/test_mnist_training.py new file mode 100644 index 00000000..b9a79511 --- /dev/null +++ b/test/integration/local/test_mnist_training.py @@ -0,0 +1,78 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. +from __future__ import absolute_import + +import os + +import numpy +from sagemaker.mxnet import MXNet +from sagemaker.predictor import csv_serializer + +import local_mode_utils +from test.integration import MODEL_SUCCESS_FILES, NUM_MODEL_SERVER_WORKERS, RESOURCE_PATH + +MNIST_PATH = os.path.join(RESOURCE_PATH, 'mnist') +SCRIPT_PATH = os.path.join(MNIST_PATH, 'mnist.py') + +TRAIN_INPUT = 'file://{}'.format(os.path.join(MNIST_PATH, 'train')) +TEST_INPUT = 'file://{}'.format(os.path.join(MNIST_PATH, 'test')) + + +def test_mnist_training_and_serving(docker_image, sagemaker_local_session, local_instance_type, + framework_version, tmpdir): + mx = MXNet(entry_point=SCRIPT_PATH, role='SageMakerRole', train_instance_count=1, + train_instance_type=local_instance_type, sagemaker_session=sagemaker_local_session, + image_name=docker_image, framework_version=framework_version, + output_path='file://{}'.format(tmpdir)) + + _train_and_assert_success(mx, str(tmpdir)) + + with local_mode_utils.lock(): + try: + model = mx.create_model(model_server_workers=NUM_MODEL_SERVER_WORKERS) + predictor = _csv_predictor(model, local_instance_type) + data = numpy.zeros(shape=(1, 1, 28, 28)) + prediction = predictor.predict(data) + finally: + mx.delete_endpoint() + + # Check that there is a probability for each possible class in the prediction + prediction_values = prediction.decode('utf-8').split(',') + assert len(prediction_values) == 10 + + +def _csv_predictor(model, instance_type): + predictor = model.deploy(1, instance_type) + predictor.content_type = 'text/csv' + predictor.serializer = csv_serializer + predictor.accept = 'text/csv' + predictor.deserializer = None + return predictor + + +def test_distributed_mnist_training(docker_image, sagemaker_local_session, framework_version, + tmpdir): + mx = MXNet(entry_point=SCRIPT_PATH, role='SageMakerRole', train_instance_count=2, + train_instance_type='local', sagemaker_session=sagemaker_local_session, + image_name=docker_image, framework_version=framework_version, + output_path='file://{}'.format(tmpdir), + hyperparameters={'sagemaker_parameter_server_enabled': True}) + + _train_and_assert_success(mx, str(tmpdir)) + + +def _train_and_assert_success(estimator, output_path): + estimator.fit({'train': TRAIN_INPUT, 'test': TEST_INPUT}) + + for directory, files in MODEL_SUCCESS_FILES.items(): + local_mode_utils.assert_output_files_exist(output_path, directory, files) diff --git a/test/integration/local/test_onnx.py b/test/integration/local/test_onnx.py new file mode 100644 index 00000000..54ce55b5 --- /dev/null +++ b/test/integration/local/test_onnx.py @@ -0,0 +1,56 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. +from __future__ import absolute_import + +import os + +import numpy +from sagemaker.mxnet import MXNet, MXNetModel + +import local_mode_utils +from test.integration import NUM_MODEL_SERVER_WORKERS, RESOURCE_PATH + +ONNX_PATH = os.path.join(RESOURCE_PATH, 'onnx') +SCRIPT_PATH = os.path.join(ONNX_PATH, 'code', 'onnx_export_import.py') + + +def test_onnx_export(docker_image, sagemaker_local_session, local_instance_type, framework_version, + tmpdir): + mx = MXNet(entry_point=SCRIPT_PATH, role='SageMakerRole', train_instance_count=1, + train_instance_type=local_instance_type, sagemaker_session=sagemaker_local_session, + image_name=docker_image, framework_version=framework_version, + output_path='file://{}'.format(tmpdir)) + + input_path = 'file://{}'.format(os.path.join(ONNX_PATH, 'mxnet_module')) + mx.fit({'train': input_path}) + + local_mode_utils.assert_output_files_exist(str(tmpdir), 'model', ['model.onnx']) + + +def test_onnx_import(docker_image, sagemaker_local_session, local_instance_type): + model_path = 'file://{}'.format(os.path.join(ONNX_PATH, 'onnx_model')) + m = MXNetModel(model_path, 'SageMakerRole', SCRIPT_PATH, image=docker_image, + sagemaker_session=sagemaker_local_session, + model_server_workers=NUM_MODEL_SERVER_WORKERS) + + input = numpy.zeros(shape=(1, 1, 28, 28)) + + with local_mode_utils.lock(): + try: + predictor = m.deploy(1, local_instance_type) + output = predictor.predict(input) + finally: + sagemaker_local_session.delete_endpoint(m.endpoint_name) + + # Check that there is a probability for each possible class in the prediction + assert len(output[0]) == 10 diff --git a/test/functional/test_mnist_distributed.py b/test/integration/sagemaker/test_mnist_distributed.py similarity index 58% rename from test/functional/test_mnist_distributed.py rename to test/integration/sagemaker/test_mnist_distributed.py index 0bebbf7a..08a436a8 100644 --- a/test/functional/test_mnist_distributed.py +++ b/test/integration/sagemaker/test_mnist_distributed.py @@ -1,58 +1,48 @@ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# +# # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the "license" file accompanying this file. This file is distributed -# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing # permissions and limitations under the License. +from __future__ import absolute_import -from sagemaker import Session +import os + +import numpy as np from sagemaker.mxnet.estimator import MXNet from sagemaker.utils import sagemaker_timestamp -from timeout import timeout, timeout_and_delete_endpoint -import numpy as np -import os +from test.integration import RESOURCE_PATH +from timeout import timeout, timeout_and_delete_endpoint -class MXNetTestEstimator(MXNet): - def __init__(self, docker_image_uri, **kwargs): - super(MXNetTestEstimator, self).__init__(**kwargs) - self.docker_image_uri = docker_image_uri - def train_image(self): - return self.docker_image_uri +def test_mxnet_distributed(sagemaker_session, ecr_image, instance_type, framework_version): + data_path = os.path.join(RESOURCE_PATH, 'mnist') + script_path = os.path.join(data_path, 'mnist.py') - def create_model(self, model_server_workers=None): - model = super(MXNetTestEstimator, self).create_model() - model.image = self.docker_image_uri - return model + mx = MXNet(entry_point=script_path, role='SageMakerRole', train_instance_count=2, + train_instance_type=instance_type, sagemaker_session=sagemaker_session, + image_name=ecr_image, framework_version=framework_version, + hyperparameters={'sagemaker_parameter_server_enabled': True}) + prefix = 'mxnet_mnist/{}'.format(sagemaker_timestamp()) -def test_mxnet_distributed(sagemaker_session, ecr_image, instance_type): with timeout(minutes=15): - script_path = 'test/resources/mnist/mnist.py' - data_path = 'test/resources/mnist' - - mx = MXNetTestEstimator(entry_point=script_path, role='SageMakerRole', - train_instance_count=2, train_instance_type=instance_type, - sagemaker_session=sagemaker_session, - docker_image_uri=ecr_image) - - - prefix = 'mxnet_mnist/{}'.format(sagemaker_timestamp()) train_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'), key_prefix=prefix + '/train') test_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'), key_prefix=prefix + '/test') + mx.fit({'train': train_input, 'test': test_input}) - + with timeout_and_delete_endpoint(estimator=mx, minutes=30): predictor = mx.deploy(initial_instance_count=1, instance_type=instance_type) - data=np.zeros(shape=(1, 1, 28, 28)) + data = np.zeros(shape=(1, 1, 28, 28)) predictor.predict(data) diff --git a/test/functional/timeout.py b/test/integration/sagemaker/timeout.py similarity index 96% rename from test/functional/timeout.py rename to test/integration/sagemaker/timeout.py index a85070db..b6302ede 100644 --- a/test/functional/timeout.py +++ b/test/integration/sagemaker/timeout.py @@ -1,19 +1,20 @@ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# +# # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the "license" file accompanying this file. This file is distributed -# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing # permissions and limitations under the License. +from __future__ import absolute_import +from contextlib import contextmanager import logging import signal -from contextlib import contextmanager from botocore.exceptions import ClientError diff --git a/test/resources/default_handlers/code/empty_module.py b/test/resources/default_handlers/code/empty_module.py index 6d3c76a8..408840ee 100644 --- a/test/resources/default_handlers/code/empty_module.py +++ b/test/resources/default_handlers/code/empty_module.py @@ -1,14 +1,14 @@ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# +# # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the "license" file accompanying this file. This file is distributed -# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing # permissions and limitations under the License. # nothing here... we are testing default model loading and handlers diff --git a/test/resources/gluon_hosting/code/gluon.py b/test/resources/gluon_hosting/code/gluon.py index 8b712fa4..ce9f3164 100644 --- a/test/resources/gluon_hosting/code/gluon.py +++ b/test/resources/gluon_hosting/code/gluon.py @@ -1,30 +1,22 @@ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# +# # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the "license" file accompanying this file. This file is distributed -# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing # permissions and limitations under the License. +from __future__ import absolute_import, print_function -from __future__ import print_function - -import json import mxnet as mx from mxnet import gluon def model_fn(model_dir): - """ - Load the gluon model. Called once when hosting service starts. - - :param: model_dir The directory where model files are stored. - :return: a model (in this case a Gluon network) - """ symbol = mx.sym.load('%s/model.json' % model_dir) outputs = mx.symbol.softmax(data=symbol, name='softmax_label') inputs = mx.sym.var('data') @@ -34,21 +26,6 @@ def model_fn(model_dir): return net -def transform_fn(net, data, input_content_type, output_content_type): - """ - Transform a request using the Gluon model. Called once per request. - - :param net: The Gluon model. - :param data: The request payload. - :param input_content_type: The request content type. - :param output_content_type: The (desired) response content type. - :return: response payload and content type. - """ - # we can use content types to vary input/output handling, but - # here we just assume json for both - parsed = json.loads(data) - nda = mx.nd.array(parsed) +def predict_fn(nda, net): output = net(nda) - prediction = mx.nd.argmax(output, axis=1) - response_body = json.dumps(prediction.asnumpy().tolist()) - return response_body, output_content_type + return mx.nd.argmax(output, axis=1) diff --git a/test/resources/keras/data/mnist.npz b/test/resources/keras/data/mnist.npz new file mode 100644 index 00000000..b2c8e721 Binary files /dev/null and b/test/resources/keras/data/mnist.npz differ diff --git a/test/resources/keras/keras_mnist.py b/test/resources/keras/keras_mnist.py new file mode 100644 index 00000000..940dddb1 --- /dev/null +++ b/test/resources/keras/keras_mnist.py @@ -0,0 +1,111 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. +'''Adapted from the Keras-MXNet example found at +https://github.com/awslabs/keras-apache-mxnet/blob/master/examples/mnist_cnn.py +''' +from __future__ import absolute_import, print_function + +import argparse +import json +import os + +import keras +from keras.models import Sequential +from keras.layers import Dense, Dropout, Flatten +from keras.layers import Conv2D, MaxPooling2D +from keras import backend as K +import numpy as np + + +def main(batch_size, epochs, num_classes, training_channel, model_dir): + # input image dimensions + img_rows, img_cols = 28, 28 + + # the data, split between train and test sets + dataset = np.load(os.path.join(training_channel, 'mnist.npz')) + x_train = dataset['x_train'] + y_train = dataset['y_train'] + x_test = dataset['x_test'] + y_test = dataset['y_test'] + + if K.image_data_format() == 'channels_first': + x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols) + x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols) + input_shape = (1, img_rows, img_cols) + else: + x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1) + x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1) + input_shape = (img_rows, img_cols, 1) + + x_train = x_train.astype('float32') + x_test = x_test.astype('float32') + x_train /= 255 + x_test /= 255 + print('x_train shape:', x_train.shape) + print(x_train.shape[0], 'train samples') + print(x_test.shape[0], 'test samples') + + # convert class vectors to binary class matrices + y_train = keras.utils.to_categorical(y_train, num_classes) + y_test = keras.utils.to_categorical(y_test, num_classes) + + model = Sequential() + model.add(Conv2D(32, kernel_size=(3, 3), + activation='relu', + input_shape=input_shape)) + model.add(Conv2D(64, (3, 3), activation='relu')) + model.add(MaxPooling2D(pool_size=(2, 2))) + model.add(Dropout(0.25)) + model.add(Flatten()) + model.add(Dense(128, activation='relu')) + model.add(Dropout(0.5)) + model.add(Dense(num_classes, activation='softmax')) + + model.compile(loss=keras.losses.categorical_crossentropy, + optimizer=keras.optimizers.Adadelta(), + metrics=['accuracy']) + + model.fit(x_train, y_train, + batch_size=batch_size, + epochs=epochs, + verbose=1, + validation_data=(x_test, y_test)) + score = model.evaluate(x_test, y_test, verbose=0) + + print('Test loss:', score[0]) + print('Test accuracy:', score[1]) + + print('Saving model in MXNet format') + model_prefix = os.path.join(model_dir, 'model') + data_name, data_shapes = keras.models.save_mxnet_model(model=model, prefix=model_prefix, + epoch=0) + + signature = [{'name': data_name, 'shape': [dim for dim in data_desc.shape]} + for data_desc in data_shapes] + with open(os.path.join(model_dir, 'model-shapes.json'), 'w') as f: + json.dump(signature, f) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + parser.add_argument('--batch-size', type=int, default=128) + parser.add_argument('--epochs', type=int, default=1) + parser.add_argument('--num_classes', type=float, default=12) + + parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR']) + parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN']) + + args = parser.parse_args() + + main(args.batch_size, args.epochs, args.num_classes, args.train, args.model_dir) diff --git a/test/resources/linear_regression/data/evaluation/eval_data.txt b/test/resources/linear_regression/data/evaluation/eval_data.txt new file mode 100644 index 00000000..e855e241 --- /dev/null +++ b/test/resources/linear_regression/data/evaluation/eval_data.txt @@ -0,0 +1,96 @@ +7.000000000000000000e+00 2.000000000000000000e+00 +7.000000000000000000e+00 2.000000000000000000e+00 +7.000000000000000000e+00 2.000000000000000000e+00 +7.000000000000000000e+00 2.000000000000000000e+00 +7.000000000000000000e+00 2.000000000000000000e+00 +7.000000000000000000e+00 2.000000000000000000e+00 +7.000000000000000000e+00 2.000000000000000000e+00 +7.000000000000000000e+00 2.000000000000000000e+00 +7.000000000000000000e+00 2.000000000000000000e+00 +7.000000000000000000e+00 2.000000000000000000e+00 +7.000000000000000000e+00 2.000000000000000000e+00 +7.000000000000000000e+00 2.000000000000000000e+00 +7.000000000000000000e+00 2.000000000000000000e+00 +7.000000000000000000e+00 2.000000000000000000e+00 +7.000000000000000000e+00 2.000000000000000000e+00 +7.000000000000000000e+00 2.000000000000000000e+00 +7.000000000000000000e+00 2.000000000000000000e+00 +7.000000000000000000e+00 2.000000000000000000e+00 +7.000000000000000000e+00 2.000000000000000000e+00 +7.000000000000000000e+00 2.000000000000000000e+00 +7.000000000000000000e+00 2.000000000000000000e+00 +7.000000000000000000e+00 2.000000000000000000e+00 +7.000000000000000000e+00 2.000000000000000000e+00 +7.000000000000000000e+00 2.000000000000000000e+00 +7.000000000000000000e+00 2.000000000000000000e+00 +7.000000000000000000e+00 2.000000000000000000e+00 +7.000000000000000000e+00 2.000000000000000000e+00 +7.000000000000000000e+00 2.000000000000000000e+00 +7.000000000000000000e+00 2.000000000000000000e+00 +7.000000000000000000e+00 2.000000000000000000e+00 +7.000000000000000000e+00 2.000000000000000000e+00 +7.000000000000000000e+00 2.000000000000000000e+00 +6.000000000000000000e+00 1.000000000000000000e+01 +6.000000000000000000e+00 1.000000000000000000e+01 +6.000000000000000000e+00 1.000000000000000000e+01 +6.000000000000000000e+00 1.000000000000000000e+01 +6.000000000000000000e+00 1.000000000000000000e+01 +6.000000000000000000e+00 1.000000000000000000e+01 +6.000000000000000000e+00 1.000000000000000000e+01 +6.000000000000000000e+00 1.000000000000000000e+01 +6.000000000000000000e+00 1.000000000000000000e+01 +6.000000000000000000e+00 1.000000000000000000e+01 +6.000000000000000000e+00 1.000000000000000000e+01 +6.000000000000000000e+00 1.000000000000000000e+01 +6.000000000000000000e+00 1.000000000000000000e+01 +6.000000000000000000e+00 1.000000000000000000e+01 +6.000000000000000000e+00 1.000000000000000000e+01 +6.000000000000000000e+00 1.000000000000000000e+01 +6.000000000000000000e+00 1.000000000000000000e+01 +6.000000000000000000e+00 1.000000000000000000e+01 +6.000000000000000000e+00 1.000000000000000000e+01 +6.000000000000000000e+00 1.000000000000000000e+01 +6.000000000000000000e+00 1.000000000000000000e+01 +6.000000000000000000e+00 1.000000000000000000e+01 +6.000000000000000000e+00 1.000000000000000000e+01 +6.000000000000000000e+00 1.000000000000000000e+01 +6.000000000000000000e+00 1.000000000000000000e+01 +6.000000000000000000e+00 1.000000000000000000e+01 +6.000000000000000000e+00 1.000000000000000000e+01 +6.000000000000000000e+00 1.000000000000000000e+01 +6.000000000000000000e+00 1.000000000000000000e+01 +6.000000000000000000e+00 1.000000000000000000e+01 +6.000000000000000000e+00 1.000000000000000000e+01 +6.000000000000000000e+00 1.000000000000000000e+01 +1.200000000000000000e+01 2.000000000000000000e+00 +1.200000000000000000e+01 2.000000000000000000e+00 +1.200000000000000000e+01 2.000000000000000000e+00 +1.200000000000000000e+01 2.000000000000000000e+00 +1.200000000000000000e+01 2.000000000000000000e+00 +1.200000000000000000e+01 2.000000000000000000e+00 +1.200000000000000000e+01 2.000000000000000000e+00 +1.200000000000000000e+01 2.000000000000000000e+00 +1.200000000000000000e+01 2.000000000000000000e+00 +1.200000000000000000e+01 2.000000000000000000e+00 +1.200000000000000000e+01 2.000000000000000000e+00 +1.200000000000000000e+01 2.000000000000000000e+00 +1.200000000000000000e+01 2.000000000000000000e+00 +1.200000000000000000e+01 2.000000000000000000e+00 +1.200000000000000000e+01 2.000000000000000000e+00 +1.200000000000000000e+01 2.000000000000000000e+00 +1.200000000000000000e+01 2.000000000000000000e+00 +1.200000000000000000e+01 2.000000000000000000e+00 +1.200000000000000000e+01 2.000000000000000000e+00 +1.200000000000000000e+01 2.000000000000000000e+00 +1.200000000000000000e+01 2.000000000000000000e+00 +1.200000000000000000e+01 2.000000000000000000e+00 +1.200000000000000000e+01 2.000000000000000000e+00 +1.200000000000000000e+01 2.000000000000000000e+00 +1.200000000000000000e+01 2.000000000000000000e+00 +1.200000000000000000e+01 2.000000000000000000e+00 +1.200000000000000000e+01 2.000000000000000000e+00 +1.200000000000000000e+01 2.000000000000000000e+00 +1.200000000000000000e+01 2.000000000000000000e+00 +1.200000000000000000e+01 2.000000000000000000e+00 +1.200000000000000000e+01 2.000000000000000000e+00 +1.200000000000000000e+01 2.000000000000000000e+00 diff --git a/test/resources/linear_regression/data/evaluation/eval_label.txt b/test/resources/linear_regression/data/evaluation/eval_label.txt new file mode 100644 index 00000000..12b7a051 --- /dev/null +++ b/test/resources/linear_regression/data/evaluation/eval_label.txt @@ -0,0 +1,96 @@ +1.100000000000000000e+01 +1.100000000000000000e+01 +1.100000000000000000e+01 +1.100000000000000000e+01 +1.100000000000000000e+01 +1.100000000000000000e+01 +1.100000000000000000e+01 +1.100000000000000000e+01 +1.100000000000000000e+01 +1.100000000000000000e+01 +1.100000000000000000e+01 +1.100000000000000000e+01 +1.100000000000000000e+01 +1.100000000000000000e+01 +1.100000000000000000e+01 +1.100000000000000000e+01 +1.100000000000000000e+01 +1.100000000000000000e+01 +1.100000000000000000e+01 +1.100000000000000000e+01 +1.100000000000000000e+01 +1.100000000000000000e+01 +1.100000000000000000e+01 +1.100000000000000000e+01 +1.100000000000000000e+01 +1.100000000000000000e+01 +1.100000000000000000e+01 +1.100000000000000000e+01 +1.100000000000000000e+01 +1.100000000000000000e+01 +1.100000000000000000e+01 +1.100000000000000000e+01 +2.600000000000000000e+01 +2.600000000000000000e+01 +2.600000000000000000e+01 +2.600000000000000000e+01 +2.600000000000000000e+01 +2.600000000000000000e+01 +2.600000000000000000e+01 +2.600000000000000000e+01 +2.600000000000000000e+01 +2.600000000000000000e+01 +2.600000000000000000e+01 +2.600000000000000000e+01 +2.600000000000000000e+01 +2.600000000000000000e+01 +2.600000000000000000e+01 +2.600000000000000000e+01 +2.600000000000000000e+01 +2.600000000000000000e+01 +2.600000000000000000e+01 +2.600000000000000000e+01 +2.600000000000000000e+01 +2.600000000000000000e+01 +2.600000000000000000e+01 +2.600000000000000000e+01 +2.600000000000000000e+01 +2.600000000000000000e+01 +2.600000000000000000e+01 +2.600000000000000000e+01 +2.600000000000000000e+01 +2.600000000000000000e+01 +2.600000000000000000e+01 +2.600000000000000000e+01 +1.600000000000000000e+01 +1.600000000000000000e+01 +1.600000000000000000e+01 +1.600000000000000000e+01 +1.600000000000000000e+01 +1.600000000000000000e+01 +1.600000000000000000e+01 +1.600000000000000000e+01 +1.600000000000000000e+01 +1.600000000000000000e+01 +1.600000000000000000e+01 +1.600000000000000000e+01 +1.600000000000000000e+01 +1.600000000000000000e+01 +1.600000000000000000e+01 +1.600000000000000000e+01 +1.600000000000000000e+01 +1.600000000000000000e+01 +1.600000000000000000e+01 +1.600000000000000000e+01 +1.600000000000000000e+01 +1.600000000000000000e+01 +1.600000000000000000e+01 +1.600000000000000000e+01 +1.600000000000000000e+01 +1.600000000000000000e+01 +1.600000000000000000e+01 +1.600000000000000000e+01 +1.600000000000000000e+01 +1.600000000000000000e+01 +1.600000000000000000e+01 +1.600000000000000000e+01 diff --git a/test/resources/linear_regression/data/training/train_data.txt b/test/resources/linear_regression/data/training/train_data.txt new file mode 100644 index 00000000..35dd03ea --- /dev/null +++ b/test/resources/linear_regression/data/training/train_data.txt @@ -0,0 +1,1000 @@ +8.383024078487886221e-01 1.677887367104703431e-01 +2.975134639944655435e-01 1.264650109599662064e-01 +2.496195570479450287e-01 8.591748407401146315e-01 +5.673422450810294837e-01 1.458089452536538433e-01 +5.832260661394322865e-01 7.166366823828255184e-01 +7.543975788532053528e-02 1.898609137606303010e-01 +1.161756569665539640e-01 6.103897427104261197e-01 +4.978558310832063016e-01 1.829751424763988998e-01 +9.532638876851960985e-01 6.610771972214268821e-02 +7.496557534864022321e-01 5.487246287652003396e-02 +8.341828997173914573e-01 1.104816762540004049e-01 +1.150912121068765614e-01 6.672354652467545888e-01 +1.897460773214154628e-02 7.846245703978612651e-01 +5.627194090337650501e-02 2.867594716432917412e-01 +8.174390194179276259e-01 4.563914419118150212e-01 +3.989178634746715080e-01 7.489820057015632582e-01 +8.160698834365822663e-01 7.183502899993565727e-01 +2.851832171545162220e-01 3.544062835716355053e-01 +9.847964820974128841e-01 8.375890996472251659e-01 +4.413986596987403699e-01 3.729481122744591204e-01 +8.333100770906549659e-01 1.214244265728061656e-01 +6.025881457603260438e-01 6.748859514084808442e-01 +2.774288229203732969e-01 8.951575035262029356e-01 +8.310747944405715293e-01 8.469795180363602904e-01 +9.606102242705764072e-01 2.261863096909838555e-01 +7.790341219499494763e-01 4.395847416268153784e-01 +8.185249724673727689e-01 8.154528733429207144e-01 +4.864674198509159053e-01 7.865969225155068356e-01 +9.774765709125590174e-01 7.100258815433094783e-01 +1.192802133214388993e-01 7.796972644067801106e-02 +3.979421707796121366e-01 9.537079477427545804e-03 +8.634802482941709156e-01 4.315626026512549007e-01 +6.507671830791106338e-02 5.803393321028155194e-01 +2.286115070914566116e-01 2.391841060812761599e-01 +7.139118803132278313e-01 2.963082837258553814e-01 +7.862211048496275945e-01 7.877350805847213033e-01 +8.549908979550473465e-01 4.502536931102080198e-01 +7.971059036210261795e-01 8.492300165291947200e-01 +5.933492067336298970e-01 6.032366908277607687e-01 +6.109971904460352876e-01 9.048489283694125529e-01 +3.336285170250096188e-01 2.593987030589036680e-01 +6.032417677737254014e-01 8.269827270469150005e-01 +1.917233916232584656e-01 9.429978485784278330e-01 +1.676145331676655026e-01 7.186871020147722611e-02 +2.987801297375362042e-01 8.894890374498145569e-01 +8.842081313558880495e-02 5.016939990677020678e-01 +4.714445148098037608e-01 6.944196004065336103e-01 +6.825831962559553157e-01 9.330737112884830431e-01 +1.596080233876779886e-01 4.490293053489992081e-01 +2.293319442968004829e-01 3.938302460609182098e-01 +1.724255200247095665e-01 9.656252673464847947e-01 +3.571546311400020102e-02 8.290361047550742768e-01 +3.564068639443062070e-01 1.311785446577451442e-01 +5.160575318976567960e-01 1.816201313117915550e-01 +3.061326303757062606e-01 3.331317916462018802e-01 +4.839776607157061372e-01 1.597060032073258640e-01 +3.233320531875456538e-01 1.544771924451002754e-01 +3.014176450696248732e-01 3.498450418777804272e-01 +5.750417202374324965e-01 2.815025161948752075e-01 +8.282908238934235667e-01 5.346921822244938838e-01 +8.652779589008368077e-01 6.387646270379421232e-01 +3.608082302376293127e-01 3.124137611494598454e-01 +3.607061761532146527e-01 4.455322911182015844e-01 +3.826876255704469987e-02 3.377869020284606227e-01 +3.800436644493996319e-01 1.023426972440211191e-02 +8.168350371976894619e-01 4.387607563547659684e-01 +3.768102037124578807e-01 9.110710754796123023e-01 +8.220030254942101156e-01 4.581763411287436139e-01 +1.694595915439977007e-01 1.684670908686005664e-01 +6.268263732250551890e-01 3.467261007113505711e-01 +8.040352074373506852e-01 6.839697290179880262e-01 +5.297253159877972628e-01 7.801082886281295314e-01 +9.204472650679746959e-01 6.809176859771448598e-03 +6.274320145445084762e-01 9.412849540108589874e-01 +2.208717638484863910e-01 2.672388586827953993e-01 +2.273923772978706515e-01 9.601296870780158255e-01 +6.668816224681960447e-01 1.081783540866855242e-01 +3.524471036896552611e-01 3.518815540013774390e-01 +5.964383677725925148e-01 8.800393950010902966e-01 +1.828686310894113154e-01 8.194458442357740457e-01 +6.867874625068572936e-01 7.272342428532760916e-01 +4.696262454662605279e-01 7.910382575586286658e-01 +1.240133014893253227e-01 8.107120105361649287e-01 +2.184692953172299168e-01 4.396747030365865117e-01 +1.909150812786594154e-01 4.044686681756488600e-01 +9.149829347144164959e-01 4.232369506474595866e-01 +7.810148179270500757e-01 7.735706391299421725e-01 +2.784715045117651666e-01 9.169654356102895365e-01 +5.935966409269489708e-01 4.749747742867101996e-01 +7.625103644874287045e-02 4.629955087520672974e-01 +4.412492795213716068e-01 7.345116388622754400e-01 +1.124375872313340130e-01 7.469995494365728739e-01 +7.302050139999879397e-01 2.826951807225084323e-01 +9.164444762399615030e-02 4.855003868554905733e-01 +9.300725486390291064e-01 5.037961688595056220e-01 +7.152702613072786653e-01 3.506846843769784172e-02 +1.693038628685927804e-01 9.908347930346177446e-01 +6.716395514934471178e-01 6.671559524759217430e-01 +7.055318334308481321e-01 8.117988463192934434e-01 +5.552755972738980805e-01 5.213329221332383412e-01 +3.065995892509220111e-01 6.976625638007041275e-01 +5.883525953831193478e-01 9.807732363996419789e-01 +4.673344799061497357e-01 7.896972864321324392e-01 +7.004011853610250249e-01 3.213698348667398674e-01 +5.920091084742643561e-01 3.054471318357341314e-01 +1.239839610132075221e-01 8.365899701547263589e-02 +1.457939770636089705e-01 7.213691276948183129e-01 +4.454130259495948652e-01 9.510865606958091023e-01 +7.846062937926735481e-01 5.644167618037644329e-01 +5.929528371943701970e-01 1.396595307293948540e-01 +9.380514689339938261e-01 8.717622571464408443e-01 +8.806271043046921543e-01 9.492755159187297176e-01 +9.184434708497772881e-01 9.141298838211823208e-01 +4.608944926443627743e-01 4.397045689174887695e-01 +2.338490711339061834e-01 6.353480332043058842e-01 +3.659690331032360300e-01 5.452304071170951394e-01 +3.677826051771501126e-01 9.061714568924033042e-01 +6.260565816992923471e-01 6.129126230225873107e-01 +3.844835631335716819e-01 8.656862201497333986e-01 +2.029361256598716245e-01 5.859680666404064597e-01 +4.084770043552404806e-01 1.537312073923035038e-01 +1.274365690811052332e-02 6.035470996215089867e-01 +7.749047074695316617e-01 6.117242373171003500e-01 +7.406181215626550518e-01 8.837384171435757052e-01 +9.125778267790626286e-01 7.791205971331660596e-01 +2.083268323224333196e-01 6.342260606520001565e-01 +7.637902100602371736e-01 9.268285068617030387e-01 +9.669409947888499834e-01 2.724413156444204231e-01 +8.669094829859033124e-01 4.521841115683111756e-02 +7.517756726253931321e-01 5.840131057956556848e-01 +8.682565025043532669e-01 5.843736672974518820e-01 +7.766415460621236555e-01 5.121018760063240638e-01 +4.750343219630869251e-01 1.057627699957567335e-01 +1.919163524300298018e-01 8.726928211535622815e-02 +8.841871727741301701e-02 2.244008153170994335e-01 +5.946262124724579490e-01 5.206392815264951901e-01 +2.047774026317988882e-01 7.868516444616732164e-01 +3.552550365931861354e-02 7.098442598521116365e-01 +1.591235920252519476e-01 8.277011471878045201e-01 +9.985348457206892903e-01 8.084832796127849885e-02 +9.850606605232240520e-01 7.781544786578342698e-02 +9.015701052498616974e-01 9.805835854201766510e-01 +5.371699437281217682e-01 5.469001913409059501e-01 +2.846456456390242806e-01 9.959521962514357174e-01 +4.031580133435924695e-01 6.303222613148458109e-01 +4.287075503852456615e-01 4.496844539941227126e-01 +4.639008224384898460e-01 6.851133963464660637e-01 +6.792573666782096842e-01 2.095271567687492231e-01 +8.136217385026819349e-01 1.697270402150350410e-01 +2.557927744094984357e-01 5.665059824807582300e-01 +1.281541089371324160e-01 8.549223496314084514e-02 +2.638698544702112736e-01 1.222485755292984644e-01 +3.047582875020528048e-01 6.275069175070332461e-01 +5.647620217370177276e-01 7.904567975298585347e-01 +4.036462600481217633e-01 7.806385381619854735e-01 +5.939400223387876210e-01 3.103020736244491706e-01 +7.197387431953912174e-01 4.809271587859470509e-01 +9.766983041485984840e-01 3.685620204928425903e-01 +6.604153321826389256e-02 1.526300811945057179e-01 +8.124566717800533189e-01 4.251228731910109415e-01 +7.209512005426614856e-01 5.137763417687180034e-01 +7.292611164144505720e-01 2.693120741546124464e-01 +1.329570016995029746e-01 4.008052825224850402e-02 +9.589113498649407541e-01 3.609583739244716094e-01 +2.769818801221202165e-02 5.697518290238317018e-01 +1.068270709019163434e-01 4.153460509448425597e-03 +1.871425861099094146e-01 7.896987946425083615e-01 +8.633058092519574345e-02 8.743037428545405998e-01 +4.724161645817301913e-01 8.731409446206293001e-01 +9.115857343860696016e-01 2.782428200986323263e-01 +6.325777061237630727e-01 3.363035209446058049e-01 +4.601150433498071290e-01 6.434064965434341810e-01 +9.588229648594517318e-01 9.275169099301002751e-01 +3.678638474423008287e-01 9.445353179519598852e-01 +1.103085667345976573e-01 8.206803046436577098e-01 +1.652374359311311602e-01 6.611442835019621933e-01 +9.705314026417570794e-01 3.580572247682171083e-01 +8.907559119146977178e-01 7.178262292373349673e-02 +6.493020361372067173e-01 9.029645961084464822e-03 +3.255857998968599842e-01 5.790046317664084841e-01 +1.595940802551198967e-01 9.287002027571378271e-01 +6.687281063376360191e-02 4.784874040266523521e-01 +2.101727105521983274e-01 5.390278141626621000e-01 +7.185277561946411540e-01 3.065072057270324946e-01 +7.529913611915978189e-01 4.621527785605770156e-01 +8.189974948343687888e-01 1.813330609888427203e-01 +8.518971644481416750e-01 7.068007272814534891e-01 +1.843020058649663673e-01 1.580029847797694842e-01 +6.231113591133967677e-01 9.711918574105556923e-01 +2.365353159412968420e-01 1.981442190777924495e-01 +1.503215362808757449e-01 1.718080464922876471e-01 +4.259226148073714802e-01 1.150065162707918365e-01 +8.561407760264651001e-01 1.280444278463812591e-01 +9.706412130992536635e-01 7.789112789907962631e-01 +2.631319180515721001e-01 9.111269429411616994e-01 +7.906060154263128137e-01 7.157317853826180576e-01 +1.816949484719589947e-01 1.982784638966361701e-01 +6.856286424429820503e-01 5.962126088148623770e-01 +1.519024795299419583e-01 7.932596323498776192e-01 +3.087794499460388975e-01 6.437959676475982773e-01 +4.083290756906061691e-01 5.111462208658594797e-01 +6.349306821086332375e-01 3.674893369853412528e-01 +6.040465714073095516e-02 3.730148375098043534e-01 +5.816092512410933058e-01 5.877069279099511601e-02 +8.657758234536604425e-01 3.932066426795367553e-01 +9.731014665173752798e-01 8.201859403191955611e-01 +3.312575732352163183e-01 9.904049458051351618e-01 +3.529610398836890983e-02 6.040642370786475368e-02 +9.400166327348009609e-01 7.666889772421199289e-01 +9.650481298595759982e-01 9.909496607908230281e-01 +6.376082761387081854e-01 7.715989066775119642e-02 +7.917263851389778617e-01 7.603850953605617580e-01 +8.653994475512121198e-01 3.103028659121853661e-01 +7.886294032337868298e-01 9.676218191546855918e-01 +6.544486419528149357e-01 9.103767910363743665e-01 +3.594094959582955218e-01 5.086077976477177964e-01 +8.075087244553775800e-01 8.921249486859645916e-01 +1.738462671576084251e-01 4.876520558562669194e-01 +6.100956714825823690e-01 1.227994862819939970e-01 +7.431119248842745773e-01 3.283844308330718942e-01 +6.049508960714664596e-01 1.061238485780235941e-02 +6.670344101478763443e-01 3.709316200128898178e-01 +8.032014188616156680e-01 3.620482501876087778e-01 +8.461810100443541982e-01 3.366049898437263321e-01 +8.243376277360814353e-01 5.626099916746604146e-01 +2.777456458651673676e-01 5.352153434576349955e-01 +2.144972128094989605e-01 6.462905748725051858e-01 +1.411054648258945887e-01 3.874013487592282701e-01 +1.455679330739455546e-01 7.866004666417265589e-01 +9.646378321458504157e-02 1.276574830233686564e-01 +9.537557360550436014e-01 9.218327072747647399e-01 +6.761169230998558266e-01 6.553251339168126854e-01 +9.438443018871440771e-01 3.136369379742454111e-02 +7.806376882304875142e-01 1.036613415943627992e-01 +6.755450199591354066e-01 8.830939075518353665e-01 +3.703200403786915951e-01 7.378558239125937446e-01 +4.083981502320577217e-02 5.317646501195525399e-01 +4.588240943760051316e-01 1.334080084370786734e-01 +3.976369325491413420e-02 3.080165119178776001e-01 +9.128749918100648175e-01 2.518579840637702461e-02 +8.910145971192089398e-01 2.800267901402442305e-01 +7.817915814544191822e-01 9.085400931363494248e-01 +4.786402605264623578e-01 7.838388744797001317e-01 +9.983789310731807642e-01 3.986135384421525263e-01 +3.054915353443502513e-01 3.759900218180347586e-01 +5.862288822424402968e-01 8.899453307850887507e-01 +6.765340553132245249e-01 3.312945153751368776e-01 +5.215770176717444695e-01 6.636277266050472079e-02 +3.762914768931964193e-01 5.353243736321113833e-01 +8.544889234164023684e-01 4.448957017112782353e-01 +8.029159032349508784e-02 3.069587128468759252e-01 +6.258645534197797566e-01 2.491265826156449092e-01 +7.422393609585329344e-01 4.593356660587325502e-01 +3.163739726904254423e-01 9.607950890652511289e-01 +3.029668438840564804e-01 9.487809125717810987e-01 +9.150429934149041378e-01 9.901244682643428607e-02 +6.641958436835612156e-01 5.693547716538129855e-01 +4.801492487993587188e-01 9.837307809886627341e-01 +5.829688097508683997e-01 4.388802447164756870e-01 +5.130316576660298722e-01 4.216077670428421964e-01 +2.010624734166412964e-01 1.489273420293244499e-01 +3.799110985123050899e-01 5.007459418709111487e-01 +5.574251988579199679e-01 5.407979817969279512e-01 +5.775213219019252886e-01 1.118383387883373770e-01 +3.307113136482112692e-01 7.787229757858391554e-01 +8.575086525727236841e-01 5.792566609388496435e-01 +2.093723714815615145e-01 7.097089697122516139e-01 +3.495779539458884200e-01 6.768775943572785669e-01 +1.112705738470751715e-01 7.583304337439176201e-01 +9.113687908399233928e-01 9.093006414144270133e-01 +6.275516389226335479e-01 3.686000584564471660e-01 +5.523575098546420481e-01 8.012823982125339972e-01 +7.785719898025718066e-01 7.821961622405214642e-01 +2.314221362538955162e-01 2.502065122934530317e-01 +9.969400903071042874e-01 9.070560639303645312e-01 +6.184620637843044610e-01 3.640655729818116360e-01 +5.533826833374978493e-01 9.781846421125112157e-01 +7.056656540474212091e-01 7.684872224990613176e-02 +7.094972538524537686e-01 2.958657351994834483e-01 +8.922266534709257790e-01 2.396104929036385611e-01 +9.088901795656982863e-01 4.592736874772258560e-01 +7.906062059344364235e-01 7.423972752090797211e-01 +1.572559381748483975e-02 5.442046586863313884e-03 +8.029383384763015830e-01 9.064380071402458050e-01 +4.474780280080000727e-01 2.051584974906528469e-01 +3.673755618999777761e-01 9.880209097345016289e-01 +4.250615504716066528e-01 9.594554828287480408e-01 +6.983651986342834794e-01 5.477076630169783744e-01 +9.610521792159071319e-01 7.942742909471527391e-01 +6.451155114680274805e-01 4.615565424813307116e-01 +3.247598738250020434e-01 8.691276554143118060e-01 +3.858907938192537301e-01 4.368324097406192097e-01 +6.768853743332736705e-02 7.120422155957784049e-02 +2.519664777694199387e-01 3.292418259014442405e-01 +3.920689511744778200e-01 6.307389569217785041e-01 +9.116129520311636592e-01 2.977841258284951165e-01 +6.064251184432498665e-01 9.336707696004856727e-01 +1.410163269998576663e-01 6.053285551030473943e-02 +6.371210597738642401e-01 3.557806181379385979e-02 +7.714756916845022516e-01 9.470162729043541949e-01 +8.554096007262812495e-01 9.930686607055322357e-01 +7.172758345498162491e-01 6.265858098513791274e-01 +8.414140473657533148e-01 4.966874202890800838e-01 +3.306656428652440249e-01 2.472794634517974099e-01 +6.511463329353402996e-01 6.516209056415722678e-01 +8.820225643933244930e-01 8.019755852650246597e-01 +7.779340791044830894e-01 5.046969613447482850e-01 +5.542777219484996687e-01 3.617756419363350462e-01 +4.847585756805794510e-01 1.952114773834285888e-01 +3.257134988998301450e-01 9.983727139976084075e-04 +5.133966839063784926e-01 9.811159825641211363e-01 +4.739891551273156534e-02 1.093404791395267139e-01 +7.925342221010112098e-01 8.909326503781184181e-01 +6.980721060596640770e-01 9.065275786665114577e-01 +6.861438119226962451e-01 9.169856535499099071e-01 +7.231947551005425412e-01 3.064293695020999797e-01 +5.415256600401792131e-01 8.602108827468467611e-01 +2.258945378495835765e-01 6.573072855590507135e-01 +3.986527099224081994e-01 5.285795244542125237e-01 +3.025482935135203055e-01 9.919005395913123024e-01 +4.338153910258238932e-01 4.358024575011643087e-01 +2.475174792259851042e-01 5.541684938604243271e-01 +4.025447365008089085e-01 3.739459273320903110e-01 +8.321271090494570810e-01 9.099078839299188370e-01 +5.146785153919719580e-01 3.797232557702716615e-01 +1.044422553781739760e-01 2.868734835397968030e-02 +2.330791089083606993e-01 1.238008772222493059e-01 +8.749137721371718079e-01 3.380391813680962976e-01 +4.227708387638877685e-01 6.469578655288212854e-01 +5.149468595291012418e-01 3.885421701514381043e-01 +2.046123698273422109e-01 5.729580570622277991e-01 +8.872730952537809523e-01 8.028148065411891965e-01 +6.799096989470755359e-01 9.521607074599734988e-01 +9.035598222764366083e-01 1.974544262734838629e-01 +9.871727975854077686e-01 1.494009905894433388e-01 +8.985718030276261814e-01 1.547013863854193039e-01 +3.596368262945428551e-01 4.187935153656615617e-01 +4.427880136542013956e-01 5.941570165267008319e-01 +4.968555709885420502e-01 4.448332783750671426e-01 +4.266589995258218693e-02 8.187415125368459057e-01 +1.529841786770963497e-01 8.023377971040025480e-01 +2.835860619678104433e-01 3.292602702079205912e-03 +5.521107226064472240e-01 4.498031862636304723e-01 +8.250388850302322608e-01 7.037417032578121567e-01 +6.731741655782760292e-01 7.162084806034378914e-01 +2.949377647847289952e-01 3.113450588166752908e-01 +9.036248699388248484e-01 3.609980282055822443e-01 +2.310580782604680916e-01 7.583371333093635869e-02 +8.684089708278489583e-01 2.316528671907031045e-01 +4.284359159912076720e-01 1.155217610395248284e-01 +8.918850210557537350e-01 9.292577227737051082e-01 +4.181795520561645274e-01 2.078362469070892393e-01 +8.155545111386341972e-02 6.054325191692860875e-01 +7.986509800341756726e-02 3.364875371445056329e-01 +3.740819990885517754e-01 6.614432408548044373e-01 +2.090305626274024942e-01 7.915071994125496557e-01 +7.205251981321542099e-01 8.238090322040405589e-01 +3.995986937307246523e-01 3.615694072664621395e-02 +8.032968235336979124e-01 3.699843208614299606e-01 +8.528756709147838899e-01 8.444379197121522074e-01 +9.137486199663812370e-01 6.745037659274336628e-01 +1.969977227256869412e-01 4.499942670997181260e-03 +1.565419907470932959e-01 1.987313269769153123e-01 +8.557953186463665984e-01 6.256300984176158542e-01 +5.352179415260442941e-01 4.977126528848289100e-01 +1.575528912997643527e-01 1.482871082529678697e-01 +3.998489784248510093e-01 9.679595516746858364e-01 +2.735152496837670588e-01 9.981633311935217412e-01 +5.085433224155877507e-02 5.960495377300722986e-02 +9.644220853710960828e-02 3.768905107056532700e-01 +2.263517137286907843e-01 4.594442166656355075e-01 +4.368493403307759371e-01 2.927338106150034758e-02 +7.716110444356890641e-01 5.988577286684080292e-01 +5.658915975088050221e-01 1.951150983790292948e-01 +1.119465796372507427e-01 1.083981660927685331e-01 +2.660377744095732133e-01 5.082524806378612148e-01 +3.360855153180662969e-01 2.354405899889167042e-01 +1.106621321590484097e-01 7.172729208623528363e-01 +6.914291509931883573e-01 9.952141184613827463e-01 +1.940206147451362106e-02 7.913193129581146401e-01 +1.947742547753916753e-01 1.079716224505279376e-01 +4.423010412464462027e-01 7.569704630565607939e-01 +9.820230073315712183e-01 9.980107503654567225e-02 +5.131140055202482175e-01 9.181783553883702265e-01 +1.102726695894991060e-01 6.960421932285588698e-01 +5.117917311727535612e-01 7.825073411482024177e-01 +4.547210701706549951e-01 4.194896587204901550e-01 +1.040742477261946686e-01 2.936268896063535250e-01 +1.339617371809707080e-01 7.375545881290336281e-01 +8.456265460760997721e-01 7.975834484820485404e-01 +1.249534142168581274e-01 2.941426687222089553e-01 +1.221504860511302137e-02 2.301704645655116011e-01 +7.036082943956578628e-01 6.675637335145860884e-01 +8.140813482154464209e-01 9.042133057914429894e-01 +5.031009675413422588e-02 8.911673993004404215e-01 +8.699340103121477874e-02 3.928902084374309434e-01 +9.355638875780134533e-01 9.742595318612956579e-02 +3.085094848789464717e-01 8.185794189840217561e-01 +6.925739728098793124e-01 5.997167102166891528e-01 +9.231250629430460819e-01 3.125797295621405292e-01 +8.408389814747191382e-01 8.159498493843351241e-01 +4.157537391147637651e-01 6.311759701258471411e-01 +3.580290531545354815e-01 5.293758720824448538e-01 +1.841598964801085536e-01 7.021177245457383975e-01 +7.635348849014744754e-02 2.380585886428914222e-01 +4.387826081365773234e-01 5.827285780118126590e-01 +9.456176352381311379e-01 7.733824091774560427e-01 +2.653191723174960215e-01 1.097204150468367168e-01 +4.426151333624752660e-01 6.522563600736013267e-01 +3.678735156451374833e-01 5.097290661839307679e-01 +9.618776700428398607e-01 7.610691466310840481e-01 +6.063372640732855867e-01 1.735247167552067538e-02 +8.448417309023040778e-01 7.258132822641665260e-01 +8.408132605171986240e-01 9.267591412696354336e-01 +4.620358379210881550e-01 1.509328769027159511e-01 +1.866974173205060428e-02 4.715917225936879786e-01 +7.905913943637078445e-01 7.500327640741093482e-01 +9.746584957004947869e-01 5.599891130763412495e-01 +9.354557650720832784e-01 5.529444068433244652e-01 +8.562210744517403382e-01 5.439230917414853428e-01 +9.890053808503045385e-01 7.562612278181509717e-01 +1.852713861303961762e-01 3.521977305619883314e-01 +4.613999606915777374e-01 9.941629351053085095e-01 +7.805736413025843312e-03 6.723821088145814917e-01 +1.067205820079188960e-01 2.479713043864242294e-01 +3.438753275276172916e-01 1.166057072887581869e-01 +4.988632291039905020e-01 1.196617386547983442e-01 +5.793395002748347178e-02 4.835071938230007804e-01 +4.771667054310673173e-01 1.538108082280676969e-01 +4.788870898000183729e-01 4.925107275522856431e-01 +6.825217672938170832e-01 1.507721315641344928e-01 +1.361340387186270595e-01 1.612456563213356953e-01 +5.439000427226557965e-01 3.504410962385700046e-01 +8.438894260647635814e-01 3.014415013827586032e-01 +4.735260180752270953e-01 1.957143995173128426e-01 +9.176613094129515424e-01 3.083205904365670103e-01 +1.930233855476229943e-01 9.457989500312418141e-01 +5.664341421196406046e-01 2.900231323129399685e-01 +1.978474493894737396e-01 4.074017026881328052e-01 +4.077400027980844222e-01 4.121838750982187438e-01 +3.395572841853528034e-01 7.636238794804683927e-02 +7.401181577988665827e-01 1.032612456261183809e-02 +7.379096487642236024e-01 3.614455574496189261e-02 +8.328150249637792379e-01 7.188593847261725012e-01 +7.227805999768921463e-02 8.538271090307636424e-02 +9.691501375311373234e-02 7.063275214016347947e-01 +8.933306350411753805e-01 7.775603047642398735e-01 +5.985121099644223897e-01 6.237277942857158974e-01 +4.956196254914534416e-01 2.992500865040209623e-01 +5.209105554580506503e-02 9.111617735973254195e-02 +3.829011363771986920e-01 8.884055998444249092e-01 +3.878071740178880011e-01 2.058684126468334208e-01 +7.390561355335544258e-01 3.169897942950061775e-02 +5.864948510060807862e-03 9.968974827681926776e-01 +9.040896497400939813e-01 2.283696201233843537e-01 +7.357577015507330565e-01 7.486648259255979099e-01 +3.264799505801634361e-01 3.113162884114172302e-01 +7.810714164649570357e-01 3.323908679670606325e-01 +7.679176339539339757e-02 3.748737110347453161e-01 +3.565165865465014239e-02 3.246034005351716711e-01 +1.093496190211734609e-01 7.597085319050164154e-01 +9.093456749524542193e-02 5.410418219259970352e-01 +2.427967970848180501e-01 6.564241187942023714e-01 +9.138229049476330745e-01 4.908421189605568502e-01 +1.291069230265344814e-01 7.251989675602796837e-01 +1.409541622977511510e-01 6.454081673611818948e-01 +2.884556794144453873e-01 9.233658815738509107e-01 +6.736188190265453013e-01 3.000741596404912226e-01 +4.563584561427888042e-01 7.974320194520057026e-01 +9.904164770621127767e-01 2.081475934431399333e-01 +7.026252992665003028e-01 2.195972215970745101e-01 +1.781834691511624547e-01 6.849485960807493168e-01 +3.567995794796274822e-01 8.349362693471118479e-02 +6.201617201556451908e-01 2.834602991292343166e-01 +8.596021118137735506e-01 6.157241713583965215e-01 +9.244697173911231580e-02 4.634685704660772565e-01 +9.424422324482255586e-01 5.754115887945733254e-01 +5.225044284502698977e-01 9.367894785287403936e-01 +8.574540812722769578e-01 4.146728604452698530e-02 +5.954524222187198568e-01 3.635496743132089481e-01 +4.066516735146737949e-01 4.262444022916811459e-02 +6.536921901009495883e-01 2.584566440277392685e-01 +4.757137841272600642e-01 6.503839357884960704e-01 +6.728020312063508479e-01 5.130133378295996804e-01 +7.264472821179496220e-01 5.624692557519872516e-01 +1.595937090410205483e-01 2.775564581827192345e-01 +6.381247994293870462e-02 2.468976353274412450e-01 +5.259745968299561758e-01 6.848567523685504810e-01 +9.601238219624145254e-01 6.402441829373576310e-01 +3.315107435458013452e-01 5.080895726293063186e-01 +2.848964747805079467e-01 6.051472003714331027e-02 +9.530843437441492982e-01 5.203492828025872985e-01 +2.323446010216982938e-01 3.328127679770490621e-01 +2.281959622431951606e-01 7.884583019913606883e-01 +3.864585852135926647e-02 5.674592339101215410e-01 +2.234721414668535688e-01 9.932490478910508402e-01 +7.826277029460694257e-01 7.088174831246227825e-01 +8.670694941972373337e-01 2.736416820378763104e-01 +1.217059723212294831e-01 6.497205479173525644e-01 +1.381521314958327196e-01 2.591855843152879313e-02 +9.854879359251045301e-02 8.975566102692902337e-01 +4.503024504493163827e-01 4.413300799456845169e-01 +8.876436613531758502e-01 6.798049580871364794e-01 +3.802225086274289012e-01 5.385627802569674660e-01 +5.918066777338415463e-04 4.088473878415384011e-01 +9.762447492477377420e-01 3.492490683492609493e-01 +2.317144613979703616e-01 8.568992556287140516e-01 +2.219766318943260375e-01 9.193324082826727084e-01 +3.303373343830153797e-01 9.616981500961963025e-01 +2.098522626610911379e-01 1.831008568626395494e-01 +1.286500753586045098e-01 3.645463749934729769e-01 +8.954369901118905961e-01 5.476737612801951816e-02 +8.614121408140957836e-01 5.961927235008219750e-01 +6.874870120376100902e-01 1.904487059006509941e-01 +5.713990629455255599e-01 1.254387492770148160e-02 +1.341243689714005649e-01 4.417310689094990916e-01 +5.485754342961137731e-01 1.382769082045059994e-01 +1.482692416728816909e-01 8.613497709871698049e-02 +8.783772167373280482e-01 4.487131089245243176e-01 +5.253435074114043424e-02 2.908691039884850893e-01 +4.351461288838956998e-01 8.576403080489769515e-02 +2.245891639734476364e-02 1.642216118959480831e-01 +7.193145046025805067e-01 9.087167134314375616e-01 +3.280772654937409438e-01 1.230552091322263397e-01 +3.437767172102235147e-01 8.515439377120637143e-01 +7.022672554086155650e-01 6.766914466511018134e-01 +3.883264006417851411e-01 2.237723726128778567e-01 +7.084290220966880103e-01 1.048521967117377329e-01 +9.210686299294170709e-01 5.702584783461603912e-01 +7.476381262799322025e-01 6.836300637619746601e-01 +6.028305540761168535e-01 9.110875948933063739e-01 +6.450752233527956570e-01 4.716567123292317865e-01 +2.128167859001461881e-01 8.073270680968812618e-01 +7.367606973152996952e-01 2.659597722649783558e-01 +5.638750405803171217e-01 4.529097210902803816e-01 +3.153400069860686772e-01 1.213027192454227965e-01 +4.383829359617239518e-01 9.970557819893532958e-02 +6.894354786430673121e-01 6.627522299795105543e-01 +3.735786823466145101e-01 6.662914972081191989e-01 +9.012340577928240615e-01 9.124747397004417948e-02 +4.171180031711874125e-02 8.046123418166754027e-01 +9.106345828391736452e-01 6.942447696194820050e-02 +3.187554560663546921e-01 7.049680839516307973e-01 +8.644028755561017352e-01 5.750808744770172787e-01 +9.175778777076349924e-01 6.652203687935333010e-01 +9.361104951046901324e-01 7.837470780191203579e-01 +7.713580864926913305e-01 5.048360170863220864e-01 +2.234269416182480539e-01 3.877918924141723345e-01 +7.495425259218346792e-01 3.513002166977918783e-03 +1.523820701451646942e-01 3.054697243597449541e-01 +1.227009922747063220e-01 3.063235261717880409e-01 +9.380491684173838385e-01 3.245521704109459726e-01 +9.874675500761005420e-01 7.659994635427876952e-01 +5.689781825085866718e-01 3.581257241530183544e-01 +5.064606194126873850e-01 6.887797761812964881e-01 +5.140342007692975113e-01 8.117578847187219004e-01 +8.141755468068677670e-01 7.246822339838220328e-01 +5.236463040872161301e-01 4.112909637619865943e-01 +8.551398989454037647e-01 8.437702618326180515e-01 +1.355641852033336869e-01 6.581783570625826529e-01 +4.101152602215402121e-01 3.345949593409534017e-01 +1.977642419303093924e-03 3.791785756323633461e-01 +4.262064572608939006e-01 1.443244231455963478e-01 +9.649505473489465857e-01 9.318261193841106405e-01 +1.187535242850421557e-01 6.535117811758486273e-01 +6.088350550834270569e-01 2.037189166341989432e-01 +8.801694795634512625e-01 7.403965482731129644e-01 +4.129598559430694982e-01 3.974636917839036165e-01 +1.048679713590531248e-01 8.192532495058443276e-01 +2.103627351008530155e-02 7.403835043175972519e-01 +5.633546357737514443e-01 8.377430096964610939e-01 +4.075170219418589035e-01 9.948231005176640940e-01 +1.818561936302240944e-01 7.338217981423564851e-01 +5.494714354181297056e-01 7.545660426739950744e-01 +1.208350216572945390e-01 6.003468245485784616e-01 +4.156888846351469713e-02 1.593720599374193192e-02 +4.685528073102881441e-01 1.828895196232701581e-01 +1.795636630958974411e-02 5.322800128455369428e-01 +5.308718608193101485e-01 1.017595403781013097e-01 +7.559668411663326149e-01 1.814265724604927632e-01 +8.295680445882546161e-01 9.351583705836271632e-01 +8.248173915931933387e-02 6.339619789880839296e-01 +6.056624270163069035e-02 4.509365602947875695e-01 +5.210182455383979283e-01 4.107197927484074196e-02 +4.375091097547595398e-01 5.964029779263316788e-01 +5.083068539382360473e-01 3.177553246043725022e-01 +2.524448938812202670e-01 2.651012298128986311e-01 +7.516135043276879202e-01 5.362427644728184895e-01 +9.664424944533788198e-01 1.216378847222130410e-01 +4.823520824385874040e-02 7.508776783894376416e-01 +7.695584242932145269e-01 3.566328754932351908e-01 +2.979596989004885677e-01 9.719635086844701233e-01 +6.527941447869687330e-01 2.428626517249086181e-01 +2.078741871135608710e-01 5.609834053575017965e-01 +9.210743681754396395e-02 6.754130475947299761e-01 +3.900667563599302667e-01 6.895898386506000843e-01 +1.863769919791595786e-01 8.036973607023739641e-01 +6.044568260864299214e-01 7.152722797582423686e-01 +5.281262399742195912e-01 7.177431359562378432e-01 +6.176592386470886042e-01 7.745337132362445365e-01 +2.912639982223417245e-01 2.281518181365307374e-01 +9.879684393695083999e-01 5.045610125949442359e-01 +6.515707444210294685e-01 2.649034315405518480e-02 +2.229044364813378420e-01 5.872962676112370994e-01 +9.076973516241017004e-01 4.003614359331484351e-01 +3.869816510997891035e-01 7.226902763751598124e-01 +2.924575625510107413e-01 2.261917768518628691e-02 +8.073775022621298980e-01 4.792023494099860725e-01 +2.852467678963364062e-01 7.236350203203090192e-01 +8.618823889927100801e-01 7.693694304743381451e-02 +1.774867927662450162e-01 9.484330371147726702e-01 +2.788065807017234832e-01 9.001755988385525820e-01 +6.489824607368465115e-01 5.278933087254380485e-01 +7.792718032258008076e-01 3.514420586683525416e-01 +8.927707920444960532e-02 8.725324421910816142e-01 +3.041142844210357188e-01 3.901649208250186351e-01 +1.160110211898776056e-01 8.125790118589920619e-01 +9.955215729939642655e-01 7.622197123565443988e-01 +7.587880529977198973e-01 4.057398369667259974e-01 +5.920462130540221635e-01 8.869973082249138496e-01 +4.751251400017607596e-01 1.276294785019704836e-01 +6.285949532548033014e-01 1.822297066679565702e-01 +3.218145878816695005e-01 4.995750014614246393e-01 +5.258931460244908873e-01 7.967565045603254292e-01 +5.441262048900243897e-01 5.399994423608776017e-01 +1.049801435526787063e-01 7.993129642989154071e-01 +8.248744372008091075e-01 7.998336471068726583e-01 +8.259243585556951261e-01 7.685882786939971512e-01 +5.557531040109136011e-01 6.515400574844844028e-01 +7.829614700624405454e-01 4.804193069304583119e-01 +8.897709265554543112e-02 8.438516841064150897e-01 +3.837375697513344397e-01 2.449650586965446131e-01 +7.085085170096523388e-01 4.564498263864957828e-02 +4.831699508173840085e-01 5.776989984771069286e-01 +6.282684089236583169e-01 6.459452093774490367e-01 +1.757629591209973752e-01 5.527782453769336524e-02 +6.796004371192584381e-01 2.266259460264508085e-01 +9.597802500860100894e-01 3.481416725340237273e-01 +7.037630854329687269e-01 9.242750200341917077e-01 +5.010837694192707970e-01 9.040882173687329049e-01 +2.933263967548637163e-01 2.138624895897545120e-02 +1.531419142295791236e-01 8.990893391665865098e-01 +2.880603188252883751e-02 6.453758487009941813e-02 +2.812732115927741861e-01 1.735009200222847880e-02 +4.837128151627079209e-02 5.982976303315697963e-01 +6.661769366905064826e-01 2.454743864204814718e-01 +9.581356502727376157e-01 5.944246469480429784e-01 +2.545108641980791120e-01 2.828186415238487239e-02 +4.776677869609158300e-02 5.153250398083841777e-02 +4.012122120774097578e-01 5.684372064953436832e-01 +4.444460135979297899e-01 9.472474525055233352e-01 +8.816138021915194134e-01 7.125032019083059609e-01 +5.914078979684496140e-01 5.651939539827961489e-01 +4.356198270658702132e-01 3.723850440160991759e-03 +3.623523609664758283e-02 3.544586168729285669e-01 +2.684828031046444519e-01 4.812385312768369117e-01 +7.101108678317759049e-01 7.717603178265882713e-01 +1.414178901200575655e-01 7.736935128054334143e-02 +7.373346068665437292e-01 3.385455254460815189e-01 +6.626870667655482272e-02 9.102973897688595439e-01 +2.516816292677741540e-01 7.901286824991042845e-01 +8.049632503826346097e-01 6.396512288771249999e-02 +8.754419863255999967e-01 3.538722430122054341e-02 +5.272611300256335731e-01 5.926642921320142143e-01 +5.822792192373116293e-02 4.430019580490932585e-01 +7.285726082707725260e-01 3.043769026795516419e-01 +8.572291486676064665e-01 3.348907162730342391e-01 +6.453875474891593855e-01 7.262912056828685348e-02 +7.088852012862842944e-01 7.481009239536384747e-02 +7.781171566854169219e-01 2.163589710620887896e-01 +8.873270129840704623e-02 5.971324008880048995e-01 +3.197106666530737940e-01 4.939157107891551934e-01 +1.554098125780527528e-01 5.887348471051101484e-03 +5.570926362958826816e-01 9.500914828218951502e-01 +5.090841116638744612e-02 3.114287715111283550e-01 +7.633917954483568913e-01 6.294851836447792337e-01 +2.688417903027584543e-02 6.448095328258510728e-01 +8.912655817351665677e-01 2.120567861541335919e-01 +3.974371388525699267e-01 9.644716287439840663e-01 +5.823049039779960578e-01 6.493975426657253580e-01 +8.678007314838097619e-01 6.052940046550494557e-01 +3.590282899309689135e-01 6.583261623203532364e-01 +8.642205683627406376e-01 6.173148440027111938e-01 +5.868537352236697924e-01 8.345389542714531927e-02 +2.641224586281778475e-01 5.851841673029556379e-02 +7.864519469784118089e-01 7.307077237336123687e-01 +7.903340313316409294e-03 8.938162095984893973e-01 +9.049730474197619223e-01 9.137966368862314726e-01 +6.154283056243046168e-01 1.986480653053938061e-01 +4.330208003295985542e-01 6.933992708173439823e-01 +9.175604616837926830e-01 9.476402874815395005e-01 +2.321518913003384199e-01 2.644195418059079872e-01 +5.577865215570303459e-01 5.621060448126775633e-01 +3.825472355323962237e-01 8.573061354370772325e-01 +2.117208067332099208e-01 7.558356270051139747e-01 +3.538540572514348170e-02 1.897269717201677253e-01 +5.119489865927520089e-01 8.464759565036672484e-03 +8.764159498513089952e-01 5.141191667863689307e-01 +7.534205138168876825e-01 5.879031560588231509e-01 +5.048356042481927375e-02 1.492291178011176944e-01 +3.330868259777490792e-01 7.583389454896984905e-01 +1.778887192449216670e-01 8.434728336780827895e-03 +3.410692868867399863e-01 3.853122731054468009e-01 +1.140515617791125758e-01 2.077992063966812575e-01 +2.883659551304600388e-01 4.258251226071103623e-01 +2.285578939978002433e-01 9.873211816270192065e-01 +8.761664863960298977e-01 8.121987996428687850e-01 +2.128292476806452616e-01 7.899499637622502402e-01 +1.245731316176624803e-01 1.783970285362517227e-01 +7.243147685690876081e-02 5.455740099877369742e-01 +4.660054411505430982e-01 6.390960722477979639e-01 +3.707579587392769715e-01 6.782273168122123774e-01 +7.856993728959337231e-01 4.544278932980941876e-01 +6.345933555704985629e-02 2.598855853329318233e-02 +1.332583267347483558e-01 9.358734870280179807e-01 +1.941169199354861874e-01 9.750203540278805647e-01 +2.655960651022588603e-01 4.727441694070015910e-01 +5.145138297179999709e-01 8.773305525549915940e-01 +4.967313515883727826e-01 5.134922754777506704e-01 +6.563620328781015756e-01 6.154721571793843049e-01 +2.137875580739523329e-01 9.784723083154184264e-01 +6.443863277904252618e-01 1.712955151830615197e-01 +8.187888796074483944e-01 5.430509914193981658e-01 +4.457050966899438471e-01 7.634855878468963519e-01 +8.393913572278954849e-02 1.055360192248533391e-01 +3.723579180551873646e-01 2.932122403186994131e-01 +9.038032762688730237e-01 3.555734467956012113e-01 +8.310038626955711383e-01 2.014581554768498650e-01 +2.879827343752314750e-01 1.347608350091109575e-01 +5.861459913445251635e-01 8.579185134890391362e-01 +8.937921512469971796e-01 1.917913905135948394e-02 +1.712115141563219645e-01 3.775686184837985238e-01 +3.609522349696062671e-01 4.016075373921546010e-01 +9.859000263795933972e-01 8.384653830161430399e-01 +1.579482868523491845e-01 7.230597817327677479e-01 +8.249424005113004110e-01 3.402621510212822109e-02 +7.788473047170800800e-01 8.843142594620712060e-03 +5.980353240565023931e-01 6.698945866730424692e-01 +4.649141291952869626e-01 3.331306745508549882e-01 +4.607568396251037202e-01 1.375452015259245320e-01 +3.940638362408231954e-01 6.578983401948277354e-01 +2.979030604451201603e-01 5.619459559238717539e-01 +9.096836294131154244e-01 1.163021801246277853e-02 +7.122482131097271907e-01 6.184617701600746420e-01 +9.351357478200653661e-01 8.711284134729423423e-01 +2.888807188501854029e-01 7.127929033685657112e-01 +1.073668832238354742e-01 6.272327059726612308e-01 +6.363240542506543029e-02 2.273810554894636349e-01 +9.109460488484705243e-01 6.034207359717773667e-01 +8.106366507744677818e-01 7.501577065736853012e-02 +9.067133139915272144e-01 1.781257310932599092e-01 +1.268500777715941830e-01 8.605972836578483776e-01 +4.433733276634036491e-01 2.074337847702457660e-01 +5.313630654852069535e-01 1.245160995356370259e-01 +4.064132924002339387e-02 2.757316761389969617e-01 +1.490490063975113966e-02 9.688502155586038445e-01 +9.415410515775940459e-01 5.469853079734202650e-01 +9.035050980999652648e-01 6.008756563417972929e-01 +2.428658053093464275e-01 2.811736339788150074e-01 +5.298575128146194579e-01 9.855281502486992196e-01 +4.288901159274234320e-01 3.573337828383784665e-01 +3.574717700835194911e-02 7.751829329336363816e-01 +9.103388224418299535e-01 5.126620837391837604e-02 +2.624568787657962643e-01 7.243498650264889038e-01 +8.555009893047700320e-01 4.116404937282136522e-01 +1.332530267811786739e-01 1.577372821782696555e-01 +3.092811646396320668e-02 4.527659131632109091e-01 +4.255568006390483271e-01 1.884521515103796396e-01 +9.184021118951570806e-01 1.685117089194491102e-01 +6.490960748899591293e-01 7.617708628429187279e-01 +8.132764529421080413e-01 1.231407394268501099e-01 +9.025704551206739090e-01 9.098733516938393162e-01 +9.954653559742755453e-01 3.406395643213395719e-01 +2.281902694349513538e-01 2.797788064032538102e-01 +4.319869960186834001e-01 6.685879624885855410e-01 +5.019243836168962858e-01 7.303962180649213876e-01 +1.309260741066866629e-01 7.018534052095382059e-01 +3.554069808688353760e-02 7.180439068550409365e-01 +9.151277169558854618e-01 7.090937278169288094e-01 +6.995489084505391864e-01 5.291440149606218002e-01 +3.393335572494574626e-01 9.993804792068130549e-01 +7.572329817682987851e-02 5.948952828867681841e-02 +2.097381367764484805e-01 1.367217079407101465e-02 +3.235728851817581964e-01 8.762297599744730547e-01 +6.431331162107156718e-01 5.367601851651536693e-01 +3.607585709205554103e-01 3.190806597254842725e-01 +4.251716264106255583e-01 8.562348034465712132e-01 +4.066581120912381442e-01 9.827586624331338117e-01 +9.043309118105447020e-01 7.632484952230000896e-01 +7.614696436021350934e-02 9.236188202668300651e-01 +5.932462175263797377e-02 4.911408406100881141e-01 +9.668933246010449345e-01 5.832294521057445058e-01 +6.662347183964689723e-02 3.975425345290490853e-01 +7.116409244823973523e-02 2.957766173393681086e-01 +9.631792152288688458e-01 6.397910284244049395e-01 +1.128629149613796789e-01 3.787830777227455537e-01 +5.451194667653219694e-01 1.112246008380932683e-01 +2.533045516384183626e-01 1.916772169684801685e-01 +1.116897468172373120e-02 6.397273873250586274e-02 +6.874816494608968487e-01 3.161627233689374705e-01 +3.291706459339697188e-01 3.514023081501465029e-01 +4.210611552446644756e-01 9.947631582004567985e-01 +8.964820479454381763e-01 5.312044695658465487e-01 +5.520807683579117553e-01 4.189032213809968930e-01 +5.504279646662376679e-01 8.640699125690795013e-01 +8.763626111920964634e-01 9.046130199065686162e-01 +1.139916820036463996e-01 9.087825724062976729e-01 +4.738680263462405895e-01 4.150775351682552960e-01 +9.512238895897489055e-01 6.964667377640976209e-02 +1.414590656787420020e-01 4.226333820871741587e-01 +1.282369778598845267e-01 7.715179538600827058e-01 +1.847309217354827782e-01 6.072762727055831533e-01 +7.998669280259028680e-01 7.436234331528007724e-01 +2.105227746358684371e-03 6.507322665403569850e-01 +8.023894395308018845e-01 6.939731168798601857e-01 +1.204948444013995701e-01 5.440075536111917964e-01 +6.202383621305045391e-01 7.012468189838982902e-02 +3.570258065373245193e-01 7.876936477849316720e-01 +8.157453079386777217e-01 8.232607988401881016e-01 +8.728229514565279956e-01 9.154248578038189921e-01 +2.322469376984092060e-01 1.134364532828557870e-03 +5.067839668015045040e-01 1.910183594724866873e-01 +1.245808710454127111e-01 9.819856416721098880e-01 +6.715344006944702349e-01 9.256659931247881445e-01 +7.926605245083367945e-01 9.672937152638602099e-01 +6.779824947790246137e-01 5.577508221221844087e-01 +1.937709372231182225e-01 3.515153565507245226e-02 +4.995223823477987768e-01 7.109049836233916064e-01 +7.463178161945269729e-01 5.743324117364775327e-01 +1.239379063948186088e-01 6.318684084676029489e-01 +1.875181145049569231e-01 8.197301972019480631e-01 +8.402578644598499569e-01 4.242018408962140441e-01 +2.678243529246217003e-01 6.230576841518946640e-01 +6.684547025660826369e-01 5.167279532946317255e-01 +5.357028898889649726e-01 5.976630823316623209e-01 +4.780859402444781470e-02 1.773563671779311290e-01 +6.449142857972468557e-01 8.303055746658202230e-01 +5.273839954317396383e-01 4.207125957999925392e-01 +1.762342332323126870e-01 4.393944051039871956e-01 +5.745506120050608567e-01 8.879382618314549669e-01 +4.712578561336036032e-01 1.881524570991451872e-01 +9.207864301979026278e-01 4.721265776693372995e-02 +5.593344057462952268e-01 1.697116004454535254e-02 +1.285041269667059893e-01 8.120506443759866189e-01 +1.373952958931543566e-01 9.128417874912870822e-01 +4.797927927505397649e-01 1.530851599816389630e-01 +8.988386452130850968e-01 4.336288161559882504e-01 +1.175939542535520133e-01 9.528203753474627780e-01 +6.853143805242034592e-02 9.899559204027282844e-01 +4.712369213667376711e-01 7.821498430902841248e-01 +7.660074850408304936e-02 4.397621066959265246e-01 +7.440848903115272472e-01 2.065276446615971206e-02 +9.690579279184877715e-01 9.038124159455668272e-01 +8.677520316851424775e-01 7.968841884723243396e-03 +7.150059719654046075e-01 1.902271594463174065e-01 +5.655878451055392775e-01 7.753998150871609374e-01 +3.771476847159315193e-01 5.314088967309682499e-02 +4.668356548341547896e-01 1.420268260268842120e-01 +7.733674248214993385e-01 9.191599567942703164e-01 +3.124880761435674392e-01 2.375488008927364447e-01 +8.856597436412122493e-01 3.864916902872622684e-01 +8.123284169990085957e-02 4.566267102606430539e-01 +7.695937128860208531e-01 1.536313709702308294e-03 +2.698634808995478140e-01 9.452697658139228576e-01 +8.680414416537420852e-01 5.957853601477216543e-01 +3.819804142131647540e-01 8.114547746686524565e-01 +1.639913461418038843e-01 6.069831370954037464e-02 +2.354814886069188651e-01 4.520356633556380510e-01 +4.477917297380786810e-01 5.222451552930984509e-01 +1.257116182264614856e-01 9.949725420833556155e-01 +3.832602759318003427e-01 7.113804967962201742e-01 +7.586015615124593880e-01 7.711228139572948725e-01 +3.173704442792059099e-01 4.460294653779597329e-01 +2.460033341796068473e-01 4.287079288236438357e-01 +1.148843304879398897e-02 4.950792280788270938e-01 +8.334981513439918555e-01 4.456172873974024728e-01 +7.482377364162514644e-01 6.628131569537574208e-01 +2.913923158975885164e-01 4.641078348173609536e-01 +2.498121371813314395e-01 9.509253979922854816e-01 +2.746487564911307411e-01 9.406035318420645774e-01 +2.820021904771764509e-01 8.776718603642663652e-01 +5.923569347935606633e-01 3.045901353139992551e-01 +9.832021019709979992e-02 1.133697466283044575e-01 +7.575807868368231723e-01 8.913146297652617678e-01 +1.200499803584796910e-01 1.625810412290606966e-01 +5.736411592810998217e-01 5.525297789065853893e-01 +5.973171245609736335e-01 3.796416167469480740e-01 +7.398704979355252842e-02 3.383623300418003854e-01 +1.189435600287831463e-01 5.426707019358838480e-01 +3.239727104423606452e-01 6.377551772898540916e-01 +6.514953263949951889e-02 6.437429913342302967e-01 +5.865338433254904116e-01 6.870900701704205638e-01 +6.764268671041336889e-01 7.829111812252187441e-01 +4.774031903028691515e-01 6.922412134057697175e-01 +5.553412318949063620e-01 4.094470218695424046e-01 +6.702597792296548507e-01 6.892860677710554995e-02 +2.608398070621155140e-01 7.359820956981057982e-01 +5.297362527411881850e-01 5.870368754892547614e-01 +4.495556863824029525e-01 2.055714641337329507e-01 +3.160105494353637745e-02 7.992069592538864065e-02 +8.715559155454377649e-01 2.190342555410986769e-01 +5.394581354897152181e-01 3.114766196383733643e-01 +2.646824402710644675e-01 2.644708182858351853e-01 +4.790240214000296426e-01 4.105537986554945729e-01 +2.339760287682213225e-02 2.111852087135975831e-01 +7.947030429688425324e-01 2.569986565779707188e-01 +7.311945670131854147e-01 5.285372826908655286e-01 +8.807686396342387924e-01 5.264275228182759836e-01 +7.115187496909101128e-01 9.584686142522280683e-03 +5.694667937026812732e-01 4.304140549976702701e-01 +1.114794538095644771e-01 9.012950227395988279e-02 +5.809108515980748733e-01 2.427899796970887092e-01 +7.480543830079553480e-01 2.669943659965072769e-01 +4.145656890299360686e-01 2.822511551775029881e-01 +7.357428154676989296e-01 3.761743886678923232e-01 +1.982175298962123300e-01 2.145578560427198189e-01 +8.030523291624790394e-01 6.733740429003510775e-01 +7.985431569558109066e-01 4.600792113841289943e-02 +3.626366184837356288e-01 6.675843961525843584e-01 +8.126970411895033664e-01 7.117348397517724301e-01 +8.543728600807161122e-01 8.763862220891506238e-01 +3.439831765635570360e-02 9.150777320789010849e-01 +9.309174853070691924e-01 9.078253988863690394e-01 +2.828498891789231751e-01 1.784800947868523391e-01 +7.400077699488807337e-01 2.858741810871601174e-02 +1.402374107513389978e-01 5.314238718751859158e-01 +6.926128819307058038e-01 4.246893979741607206e-01 +4.525197692624272072e-02 7.329835507263371408e-01 +8.631691826519332089e-01 6.734207211663955084e-01 +9.563620424281455978e-01 7.575804057530998170e-01 +1.590887085139082169e-01 3.736062867396472420e-03 +9.547665660630203499e-01 8.931008985953772239e-01 +8.692577998882099477e-01 5.616860296656878138e-02 +8.014509840718068157e-01 6.383966892081838207e-01 +5.139217873633956657e-01 4.388016676713035746e-01 +9.226322890909642860e-01 2.431893405078497938e-01 +9.383095756068016158e-02 8.649499041071950778e-01 +7.231752783516243177e-02 6.701048257886371706e-01 +7.440755967174422780e-01 2.229287066963419051e-01 +5.583000625748638646e-01 6.842632793187735096e-01 +7.291534933448510225e-01 6.448798050761928025e-01 +2.859697100748088694e-01 2.224053690704860253e-01 +6.983349285163419351e-01 2.650309054416272181e-01 +7.393835784473681283e-01 5.683017334047056890e-01 +3.281125648760345737e-01 8.528350494373129198e-01 +8.310856271089960279e-01 2.270109016976548100e-01 +2.613100192448142289e-01 8.067880232487177894e-01 +1.906703724254994459e-01 7.652906799978680397e-01 +5.021349848726747167e-01 2.467178437568799954e-01 +8.034236100027259964e-01 4.000588611526736393e-01 +8.280299802446300372e-02 7.593265680566931852e-01 +2.094310401346415329e-01 4.420948848064607528e-01 +7.605774736529228308e-01 1.983077295901262138e-01 +7.148071044953775433e-01 4.305571565443000459e-02 +2.539481688592389874e-01 9.950927771966910917e-02 +9.255422129163556333e-01 2.563625309341487357e-01 +5.552258590897310997e-01 5.247022466724801104e-01 +6.401721279147509769e-02 3.202030640431303699e-01 +9.425454907546026995e-02 2.259191401585394221e-01 +3.098009052525027007e-01 4.549292042872352670e-01 +6.332345557549374782e-01 9.977095323472180421e-01 +1.284826418370591528e-01 7.624886069340857109e-01 +8.607351624891471653e-01 3.622120301741800219e-01 +7.503446125384759524e-01 7.132646555367825325e-01 +9.682764691162804027e-03 1.231167267860211467e-01 +9.763878931189516885e-01 9.053692802601964740e-01 +7.884054615109020725e-01 2.946753766241453354e-01 +4.791241966535260133e-01 4.113693196348048309e-01 +7.700173289367712171e-01 6.297055736461186770e-01 +3.001887616662290981e-01 4.175023994163478136e-01 +6.801337579162639591e-01 4.004592520854307613e-01 +5.146554122787571073e-01 9.279315011947080194e-01 +8.100881518267603054e-01 2.754934359629414020e-01 +9.746559238693424110e-01 5.294821398894512532e-01 +1.700217155238656908e-01 5.075408119665145712e-01 +7.254889937549470202e-01 7.652061514710470913e-01 +8.876874194198891566e-01 4.016024778138125306e-01 +3.052129905495051254e-01 2.409756315254385495e-02 +8.416169013236775021e-01 2.660941948289545778e-01 +7.705144244837808243e-01 7.063650875259269712e-01 +2.967877194186074874e-02 4.786639120970216377e-01 +7.449817879894059525e-01 6.431431009682617717e-01 +2.651308362008266695e-01 7.995712666693107762e-01 +7.461182912574907400e-01 9.759385179074575589e-02 +6.370832629752050114e-01 9.176188303565137039e-01 +9.227711168620567062e-02 8.426289119799799110e-01 +2.210974527005129531e-01 9.353256284880046012e-01 +7.223180082661153634e-01 4.175342817226668224e-01 +3.862777346910406528e-02 7.601777755833890593e-01 +4.704990715940061685e-01 5.359549094288862392e-02 +7.849230423037100168e-01 2.544796317144470965e-02 +4.713624207062446780e-01 9.330759600987514890e-01 +5.934135104272557637e-01 9.149901698274116590e-01 +6.675146061662015073e-01 9.260833775048843419e-01 +9.769631128172938661e-01 9.314416488498484803e-01 +4.134763876571817542e-01 1.281813699400595752e-01 +9.860602052319290545e-01 7.985099557951832461e-01 +6.129852372669720717e-01 2.401219608273685857e-01 +4.026030944261168587e-01 3.503647793933621912e-01 +6.351573476392453621e-01 1.989757813093903094e-01 diff --git a/test/resources/linear_regression/data/training/train_label.txt b/test/resources/linear_regression/data/training/train_label.txt new file mode 100644 index 00000000..72c4b531 --- /dev/null +++ b/test/resources/linear_regression/data/training/train_label.txt @@ -0,0 +1,1000 @@ +1.173879881269729308e+00 +5.504434859143979564e-01 +1.967969238528174181e+00 +8.589601355883371703e-01 +2.016499430905083212e+00 +4.551615854065811373e-01 +1.336955142387406203e+00 +8.638061160360041013e-01 +1.085479327129481586e+00 +8.594006792394423000e-01 +1.055146252225392267e+00 +1.449562142600385739e+00 +1.588223748527864077e+00 +6.297908841899599874e-01 +1.730221903241557779e+00 +1.896881874877798024e+00 +2.252770463435295412e+00 +9.939957842977872327e-01 +2.659974681391862994e+00 +1.187294884247658722e+00 +1.076158930236267297e+00 +1.952360048577287621e+00 +2.067743829972779057e+00 +2.525033830513291999e+00 +1.412982843652544229e+00 +1.658203605203580233e+00 +2.449430719153214309e+00 +2.059661264881929466e+00 +2.397528333999177974e+00 +2.752196662027949214e-01 +4.170163297344672282e-01 +1.726605453596680828e+00 +1.225755382513542102e+00 +7.069797192540089315e-01 +1.306528447764938594e+00 +2.361691266019070312e+00 +1.755498284175463386e+00 +2.495565936679415842e+00 +1.799822588389151434e+00 +2.420695047184860282e+00 +8.524259231428169548e-01 +2.257207221867555624e+00 +2.077719088780114021e+00 +3.113519535706199548e-01 +2.077758204637165207e+00 +1.091808811270992940e+00 +1.860283715622871092e+00 +2.548730618832921291e+00 +1.057666634085676405e+00 +1.016992436418636903e+00 +2.103676054717679378e+00 +1.693787672624148755e+00 +6.187639532597964953e-01 +8.792977945212399060e-01 +9.723962136681100210e-01 +8.033896671303578652e-01 +6.322864380777462046e-01 +1.001107728825185728e+00 +1.138046752627182912e+00 +1.897675188342411445e+00 +2.142807212976721054e+00 +9.856357525365490035e-01 +1.251770758389617821e+00 +7.138425666139659453e-01 +4.005122038982038557e-01 +1.694356549907221510e+00 +2.198952354671682485e+00 +1.738355707751697343e+00 +5.063937732811988335e-01 +1.320278574647756331e+00 +2.171974665473326738e+00 +2.089941893244056104e+00 +9.340656187875175931e-01 +2.510001922566226451e+00 +7.553494812140771897e-01 +2.147651751453902413e+00 +8.832383306415670932e-01 +1.056210211692410139e+00 +2.356517157774772997e+00 +1.821760319560959296e+00 +2.141255948213409255e+00 +2.051702760583517637e+00 +1.745437322561655069e+00 +1.097818701390402829e+00 +9.998524176299571353e-01 +1.761456836009335669e+00 +2.328156096186934310e+00 +2.112402375732344240e+00 +1.543546189500369259e+00 +1.002242053952877576e+00 +1.910272557245922487e+00 +1.606436686104479872e+00 +1.295595375445004915e+00 +1.062645221334977297e+00 +1.937664886358040350e+00 +7.854071981826743487e-01 +2.150973448937828270e+00 +2.005951456445290493e+00 +2.329129526069435130e+00 +1.597941441540374763e+00 +1.701924716852330377e+00 +2.549899068182403195e+00 +2.046729052770414725e+00 +1.343140855094504760e+00 +1.202903372145732508e+00 +2.913019550441527938e-01 +1.588532232453245596e+00 +2.347586147341213181e+00 +1.913439817400202525e+00 +8.722718986531599050e-01 +2.681575983226875515e+00 +2.779178136142151700e+00 +2.746703238492141708e+00 +1.340303630479340313e+00 +1.504545137542518063e+00 +1.456429847337426420e+00 +2.180125518961956832e+00 +1.851881827744466857e+00 +2.115856003433038701e+00 +1.374872258940684544e+00 +7.159394191398474883e-01 +1.219837856151128497e+00 +1.998353182103732362e+00 +2.508094955849806240e+00 +2.470819021045394859e+00 +1.476778953626433744e+00 +2.617447223783643473e+00 +1.511823626077690719e+00 +9.573463052995655476e-01 +1.919801884216704391e+00 +2.037003837099256920e+00 +1.800845298074771783e+00 +6.865598619546003922e-01 +3.664549166607422581e-01 +5.372203479116118841e-01 +1.635904775525448329e+00 +1.778480691555145210e+00 +1.455214023363541997e+00 +1.814525886400860877e+00 +1.160231501643246288e+00 +1.140691556254791017e+00 +2.862737276090214777e+00 +1.630970326409933779e+00 +2.276550038141895715e+00 +1.663802535973283980e+00 +1.328076458373490976e+00 +1.834127615131421862e+00 +1.098311680215708019e+00 +1.153075818932752128e+00 +1.388804739371015007e+00 +2.991385788634141063e-01 +5.083670055288082024e-01 +1.559772122516119186e+00 +2.145675616796734797e+00 +1.964923336372092599e+00 +1.214544169587685962e+00 +1.681593060767285319e+00 +1.713822345134283776e+00 +3.713016956072753283e-01 +1.662702418162075091e+00 +1.748503884080097492e+00 +1.267885264723675576e+00 +2.131180582039999827e-01 +1.680828097713884084e+00 +1.167201846059875425e+00 +1.151339919208131946e-01 +1.766540175394926138e+00 +1.834938066634276943e+00 +2.218698053822988570e+00 +1.468071374583334254e+00 +1.305184748012974794e+00 +1.746928036436675491e+00 +2.813856784719652282e+00 +2.256934483346220599e+00 +1.751669176021913188e+00 +1.487526002935055658e+00 +1.686645852178191296e+00 +1.034321157762164711e+00 +6.673613280593756469e-01 +1.483595063429676841e+00 +2.016994485769395773e+00 +1.023847618687068195e+00 +1.288228338877522638e+00 +1.331542167648706254e+00 +1.677296918312751739e+00 +1.181663616812054229e+00 +2.265498619011048653e+00 +5.003079754245053357e-01 +2.565495073934508152e+00 +6.328237540968817409e-01 +4.939376292654510392e-01 +6.559356473489551531e-01 +1.112229631719227729e+00 +2.528463771080846190e+00 +2.085385803933895499e+00 +2.222069586191548929e+00 +5.782518762652313349e-01 +1.878053860072706804e+00 +1.738421744229697197e+00 +1.596371385241235341e+00 +1.430621517422325129e+00 +1.369909356079315632e+00 +8.064343321603396619e-01 +6.991506368230835378e-01 +1.652189108812733842e+00 +2.613473347155766291e+00 +2.312067464845486420e+00 +1.561089514040984172e-01 +2.473394587219040819e+00 +2.946947451441221943e+00 +7.919280574742105783e-01 +2.312496575860101267e+00 +1.486005179375582852e+00 +2.723873041543158013e+00 +2.475202224025563780e+00 +1.376625091253731004e+00 +2.591758621827306541e+00 +1.149150378870142264e+00 +8.556946440465703629e-01 +1.399880786550418366e+00 +6.261756657870711784e-01 +1.408897650173655869e+00 +1.527297919236833224e+00 +1.519390989731806751e+00 +1.949557611085402264e+00 +1.348176332780437470e+00 +1.507078362554509443e+00 +9.159081623443511289e-01 +1.718768866357398561e+00 +3.517787492613223543e-01 +2.797421150604573192e+00 +1.986767190933481197e+00 +1.006571689481993159e+00 +9.879603714192131125e-01 +2.441732835062806029e+00 +1.846031688203879195e+00 +1.104369115262310963e+00 +7.256401112501624784e-01 +6.557967170906693344e-01 +9.632465886228188667e-01 +1.451068177399697401e+00 +2.598871767727118254e+00 +2.046318009485862621e+00 +1.795606007957485817e+00 +1.057471578980419658e+00 +2.366119543812617909e+00 +1.339123086063498391e+00 +6.543025629927539111e-01 +1.446940224157419186e+00 +1.744280326838958839e+00 +6.942090160172469382e-01 +1.124117718651069575e+00 +1.660910693075998035e+00 +2.237964150820927589e+00 +2.200528669027618456e+00 +1.113067887067772599e+00 +1.802905386991187076e+00 +2.447610810776684076e+00 +1.460729299183819663e+00 +1.356247191751714265e+00 +4.989171574752901961e-01 +1.381402982254127387e+00 +1.639021162451775870e+00 +8.011979994786000425e-01 +1.888157265219889691e+00 +2.016021974450422860e+00 +1.628790310906064853e+00 +1.703333142660445443e+00 +1.627931441334910412e+00 +2.729970073668777530e+00 +1.364751755835527991e+00 +2.154922306279710043e+00 +2.342964314283614513e+00 +7.318351608408015796e-01 +2.811052218167833239e+00 +1.346593209747927844e+00 +2.509751967562520392e+00 +8.593630985472334727e-01 +1.301228724251420665e+00 +1.371447639278202901e+00 +1.827437554520149998e+00 +2.275400756352595977e+00 +2.660968699121146752e-02 +2.615814352756792971e+00 +8.577950229893057665e-01 +2.343417381368980923e+00 +2.343972516129102956e+00 +1.793780524668240339e+00 +2.549600761110212499e+00 +1.568228596430688793e+00 +2.063015184653625766e+00 +1.259555613300492150e+00 +2.100969805524830480e-01 +9.104501295723084198e-01 +1.653546865018034939e+00 +1.507181203688154003e+00 +2.473766657644221212e+00 +2.620820380204671451e-01 +7.082771834014519596e-01 +2.665508237493210864e+00 +2.841546922137345721e+00 +1.970447454252574504e+00 +1.834788887943913593e+00 +8.252245697688388448e-01 +1.954388144218484946e+00 +2.485973734923373701e+00 +1.787328001793979659e+00 +1.277829005821169872e+00 +8.751815304474366286e-01 +3.277102443278253618e-01 +2.475628649034620654e+00 +2.660798737917849932e-01 +2.574399522857247824e+00 +2.511127263392686881e+00 +2.520115119022515948e+00 +1.336053494104742612e+00 +2.261947425533872735e+00 +1.540509108967685004e+00 +1.455811758830833247e+00 +2.286349372696145021e+00 +1.305420306028152400e+00 +1.355854466946833758e+00 +1.150436591164989419e+00 +2.651942876909294533e+00 +1.274125026932515281e+00 +1.618169520861333366e-01 +4.806808633528593111e-01 +1.550992134873364403e+00 +1.716686569821530339e+00 +1.292031199831977339e+00 +1.350528483951797920e+00 +2.492902708336159456e+00 +2.584231113867022422e+00 +1.298468674823404445e+00 +1.285974778764294335e+00 +1.207974575798464789e+00 +1.197223857025865978e+00 +1.631102046707602948e+00 +1.386522127738676335e+00 +1.680148925026273998e+00 +1.757659772885101557e+00 +2.901712673719688551e-01 +1.451717095133708169e+00 +2.232522291545856685e+00 +2.105591126785151701e+00 +9.176278824180795768e-01 +1.625620926349989226e+00 +3.827255049223408090e-01 +1.331714705209255278e+00 +6.594794380702573289e-01 +2.750400466603164062e+00 +8.338520458703430061e-01 +1.292420489452435595e+00 +7.528401722924288331e-01 +1.696968480798160650e+00 +1.792044961452501806e+00 +2.368143262540235217e+00 +4.719125751840170802e-01 +1.543265465256557833e+00 +2.541751510339088416e+00 +2.262756151821248451e+00 +2.059976080676813037e-01 +5.540046447009239206e-01 +2.107055515481598196e+00 +1.530643247295702114e+00 +4.541271078057000921e-01 +2.335768081774222793e+00 +2.269841912070810430e+00 +1.700642397875732348e-01 +8.502232299484161482e-01 +1.145240147059961799e+00 +4.953961024537766322e-01 +1.969326501772505011e+00 +9.561217942668636116e-01 +3.287429118227878089e-01 +1.282542735685295643e+00 +8.069666952958997053e-01 +1.545207973883754082e+00 +2.681857387915953961e+00 +1.602040687390742901e+00 +4.107174996764475505e-01 +1.956241967359567901e+00 +1.181625157404662563e+00 +2.349470716296988559e+00 +1.502357056046616846e+00 +2.076806413469158397e+00 +1.293700387611635305e+00 +6.913280269389017185e-01 +1.609070913439037964e+00 +2.440793443040196742e+00 +7.132387516612760381e-01 +4.725559777361362235e-01 +2.038735761424829818e+00 +2.622507959798332511e+00 +1.832644895355015180e+00 +8.727738179060766655e-01 +1.130415793950272585e+00 +1.945668322846989984e+00 +1.892007393243257507e+00 +1.548284522067327140e+00 +2.472738680243389275e+00 +1.678105679366458158e+00 +1.416780797319425300e+00 +1.588395345571585349e+00 +5.524706657759302919e-01 +1.604239764160202641e+00 +2.492382453593043223e+00 +4.847600024111694550e-01 +1.747127853509677919e+00 +1.387331648012998908e+00 +2.484015963305008068e+00 +6.410422074243269375e-01 +2.296468295430637241e+00 +2.694331543056469602e+00 +7.639015917265200573e-01 +9.618531869194265616e-01 +2.290656922511926652e+00 +2.094636721853177175e+00 +2.041344578758732098e+00 +1.944067257934710913e+00 +2.501527836486606482e+00 +8.896668472543728390e-01 +2.449725830902194978e+00 +1.352569954042188716e+00 +6.026631907807673549e-01 +5.770867421051336654e-01 +7.381867064135871903e-01 +1.024948337673484922e+00 +7.847883218872027111e-01 +1.463908544904589659e+00 +9.840660304220860688e-01 +4.586253513612984500e-01 +1.244782235199795917e+00 +1.446772428830280788e+00 +8.649548171098527805e-01 +1.534302490286085563e+00 +2.084621285610106511e+00 +1.146480406745520542e+00 +1.012650854765739350e+00 +1.232107752994521910e+00 +4.922820600814464820e-01 +7.607704069240902589e-01 +8.101987602541473876e-01 +2.270533794416124351e+00 +2.430434818038419431e-01 +1.509570056556383211e+00 +2.448451244569655127e+00 +1.845967698535854185e+00 +1.094119798499495477e+00 +2.343234102652701489e-01 +2.159712336066048621e+00 +7.995439993115548427e-01 +8.024540943925556613e-01 +1.999659914046446163e+00 +1.360828889986862578e+00 +2.233087353401928876e+00 +9.491125274029978964e-01 +1.445853152399078301e+00 +8.265391854648840297e-01 +6.848584597249934847e-01 +1.628766682831206403e+00 +1.173018211347239603e+00 +1.555645034673222682e+00 +1.895507142868746886e+00 +1.579504858147093849e+00 +1.431770497020115052e+00 +2.135187442562147098e+00 +1.273767138307527746e+00 +2.051222495046800098e+00 +1.406711663948392754e+00 +1.141819742460649323e+00 +1.548080661312661199e+00 +5.237868333490498518e-01 +1.187082318414113935e+00 +2.091050454530566594e+00 +1.019384112671266829e+00 +2.093265410037372209e+00 +2.396083385507750574e+00 +9.403886533613309284e-01 +1.322551770845137753e+00 +4.919005539730100240e-01 +1.170605478156428125e+00 +1.776481655704252205e+00 +1.698828706865550320e+00 +1.851385793621924236e+00 +7.147066254064590174e-01 +5.576077505978211946e-01 +1.895688101567057249e+00 +2.240612187837129898e+00 +1.347689888804413982e+00 +4.059259148547945673e-01 +1.993782909349323784e+00 +8.979701369757964180e-01 +1.805112566225916648e+00 +1.173564326341602460e+00 +2.209970237248955360e+00 +2.200262669195315102e+00 +1.414352858272990066e+00 +1.421147068155934612e+00 +1.899892483588903058e-01 +1.893662014131090920e+00 +1.332962610340685305e+00 +2.247253577527448698e+00 +1.457348069141363833e+00 +8.182865823608106437e-01 +1.674742885946259641e+00 +1.945512972655398354e+00 +2.060641448459671565e+00 +2.253733634575407763e+00 +5.760539763863702367e-01 +8.577428253455504636e-01 +1.004971742367929632e+00 +2.053797587815739512e+00 +1.068384423838911967e+00 +5.964868128009285231e-01 +1.017586506790398637e+00 +8.251292507051257719e-01 +3.205391958703156519e-01 +1.775803434586376683e+00 +6.342725587181106128e-01 +6.066741904936910901e-01 +3.509021401892409298e-01 +2.536747931465455519e+00 +5.741876837581936233e-01 +2.046864592634350721e+00 +2.055650148710819192e+00 +8.358711458675408545e-01 +9.181334155201634761e-01 +2.061585586621737853e+00 +2.114898253803881634e+00 +2.425005743862729712e+00 +1.588388648011259230e+00 +1.827470922093908712e+00 +1.268680241845256518e+00 +1.469694482760877996e+00 +5.579454454769142702e-01 +6.377940923595946110e-01 +2.014939938602088532e+00 +1.706161676762852908e+00 +1.083729005732912309e+00 +1.650936483950469658e+00 +1.049483536763069935e+00 +1.728691623969616398e+00 +2.014564624510136071e+00 +2.248018615294701483e+00 +2.503604651142930848e+00 +1.781030120665335392e+00 +9.990107264465927228e-01 +7.565685302557905167e-01 +7.633215188646546023e-01 +7.353480446182824037e-01 +1.587153509239275895e+00 +2.519466477161675932e+00 +1.285229630814623381e+00 +1.884020171775280250e+00 +2.137549970206741534e+00 +2.263540014774511722e+00 +1.346228231611189319e+00 +2.542680422610639646e+00 +1.451920899328499104e+00 +1.079305178903446905e+00 +7.603347936840297860e-01 +7.148553035520865961e-01 +2.828602786117167867e+00 +1.425777086636739410e+00 +1.016272888351824832e+00 +2.360962576109677080e+00 +1.207887239510876842e+00 +1.743374470370741669e+00 +1.501803282145279805e+00 +2.238840655166673521e+00 +2.397163222977186869e+00 +1.649499789914937065e+00 +2.058603520766119743e+00 +1.321528670754451351e+00 +7.344330045099856097e-02 +8.343318465568284603e-01 +1.082516392000663519e+00 +7.343909415755127679e-01 +1.118819986087318252e+00 +2.699884785755509053e+00 +1.350405697135487193e+00 +9.624393632912058294e-01 +6.031622040880794122e-01 +1.630315065607423008e+00 +1.143817503146981052e+00 +7.826473535070175291e-01 +1.824099033273324899e+00 +1.209718263897805013e+00 +1.549990565022734135e+00 +1.482824175279684908e+00 +2.241886716269428703e+00 +1.138519448236785969e+00 +1.329840997828564575e+00 +1.442933532007003805e+00 +1.769246433661130435e+00 +1.793771713383907507e+00 +2.035001385602914770e+00 +1.963612511886695167e+00 +2.166726665119577788e+00 +7.475676344954031993e-01 +1.997090464559396761e+00 +7.045514307291398381e-01 +1.397496971703811930e+00 +1.708420223490398460e+00 +1.832362203850108617e+00 +3.376959179213833151e-01 +1.765782201082102043e+00 +1.732516808536954445e+00 +1.015756275087577709e+00 +2.074352866995790468e+00 +2.079157778378828425e+00 +1.704769078187722720e+00 +1.482155920562505891e+00 +1.834341963586612945e+00 +1.084444126071073100e+00 +1.741169044907861618e+00 +2.519960997707053174e+00 +1.570267726931172003e+00 +2.366040829503849974e+00 +7.303840970057017268e-01 +9.930543665907164419e-01 +1.320964590804518668e+00 +2.119406155145141746e+00 +1.624125089611779593e+00 +1.703606072150509521e+00 +2.424541731414554313e+00 +2.363100915943689539e+00 +1.858833218979882407e+00 +1.743800083923357169e+00 +1.776680460868375722e+00 +8.736676871444236658e-01 +7.997984822869514954e-01 +1.638567947771597755e+00 +1.920158827678556390e+00 +2.863186081963841056e-01 +1.132852329172159944e+00 +1.656063595154057655e+00 +2.552313125501352253e+00 +2.309260204156736496e+00 +3.360988946728146187e-01 +1.951320592562752143e+00 +1.578812016227276738e-01 +3.159733955972311437e-01 +1.244966542179410496e+00 +1.157125709531469315e+00 +2.146984944168823795e+00 +3.110745925028488568e-01 +1.508317866577684185e-01 +1.538086625068097124e+00 +2.338940918608976460e+00 +2.306620206008131113e+00 +1.721795805934041912e+00 +4.430675279461921967e-01 +7.451524698425047166e-01 +1.230959865658318275e+00 +2.253631503484952336e+00 +2.961565926811442484e-01 +1.414425657758706656e+00 +1.886863486214273911e+00 +1.831938994265982723e+00 +9.328934961580596097e-01 +9.462164349280410836e-01 +1.712589714289662002e+00 +9.442318380219176799e-01 +1.337326413629875699e+00 +1.527010581213674945e+00 +7.906457886257330925e-01 +8.585053860770119893e-01 +1.210835098809594612e+00 +1.282997503074416734e+00 +1.307542088231384181e+00 +1.671845095201549558e-01 +2.457275601939672871e+00 +6.737659541886441561e-01 +2.022362162737915359e+00 +1.316503244681977991e+00 +1.315379154043433640e+00 +2.326380396340538059e+00 +1.881099989309446663e+00 +2.078388740793908784e+00 +1.675680614571675386e+00 +2.098850256368162803e+00 +7.537615260779604309e-01 +3.811592920887689750e-01 +2.247867394445636435e+00 +1.795535759510295204e+00 +2.732566321192225089e+00 +1.012724436235092229e+00 +1.819819341964286519e+00 +2.812841036646871906e+00 +7.609909749121543943e-01 +1.681998611182385472e+00 +2.097159506406550911e+00 +1.723392060743437870e+00 +4.148393491654789322e-01 +5.288785057228253539e-01 +1.904654283424046746e+00 +1.929226825934533984e+00 +3.489417960270546626e-01 +1.849764716957146060e+00 +1.947581759184833228e-01 +1.111693833097633588e+00 +5.296499745724750907e-01 +1.140016200344680763e+00 +2.203200257251838767e+00 +2.500564085681767246e+00 +1.792729175205145742e+00 +4.813671886901659258e-01 +1.163579496832382709e+00 +1.744197585646138915e+00 +1.727212592363701837e+00 +1.694555159492121987e+00 +1.154364526236362209e-01 +2.005005300790784428e+00 +2.144157627991247317e+00 +1.211084403916262042e+00 +2.269174934827983048e+00 +1.523715902543874012e+00 +1.887306347236870074e+00 +2.170732174704789408e+00 +9.869773581565483012e-01 +1.904890862446244615e+00 +1.972676272383736551e+00 +2.950111741724962267e-01 +9.587823986925861908e-01 +1.614950169860075446e+00 +1.233920173649270868e+00 +5.575044043934533899e-01 +2.301983018322603325e+00 +9.321504293497161475e-01 +9.263487511239190120e-01 +1.164167309753915358e+00 +2.662830792411879699e+00 +1.604067850317884680e+00 +8.929948307155568532e-01 +7.965335899063215042e-01 +1.937824497402587332e+00 +1.131175478296996939e+00 +7.358472426769527841e-01 +1.709860516630478777e+00 +1.421794972292863779e+00 +9.329440654380409814e-01 +1.949171753429876475e+00 +2.677392574765950162e+00 +1.714466525587316825e+00 +1.361832295169157936e+00 +5.183945164039927000e-01 +2.117787520792025369e+00 +9.606681920892048421e-01 +1.262964776178046922e+00 +1.848044645087290938e+00 +8.582408972038951811e-01 +7.803952645564810053e-01 +5.921046815180173173e-01 +1.952605331756958940e+00 +2.035511667524434465e+00 +2.105256410783559851e+00 +8.052130732669764424e-01 +2.500913813312017897e+00 +1.143557681604180365e+00 +1.586113042875624712e+00 +1.012871239189666817e+00 +1.711156608818773961e+00 +1.678781976761197337e+00 +4.487275911377179849e-01 +9.364599427903850248e-01 +8.024611036598076064e-01 +1.255425529734055301e+00 +2.172637800575796696e+00 +1.059557931795808372e+00 +2.722317158508352541e+00 +1.676744484616954800e+00 +7.877478822414589743e-01 +1.769162920995854371e+00 +1.962716819746739061e+00 +1.534632884525763075e+00 +1.471628511796965411e+00 +2.333315172589743192e+00 +1.757836938371782676e+00 +2.338094515663083683e+00 +1.947023547541835153e-01 +2.370824783645905098e-01 +2.076032405130704195e+00 +1.716653486541023010e+00 +9.989198903715239553e-01 +2.137641233303767763e+00 +2.372175436957505656e+00 +2.430827902256544881e+00 +1.923384604893873639e+00 +1.041606302972814202e+00 +2.133352228812533724e+00 +8.617085408977450678e-01 +6.627173271269759525e-01 +2.242761272077678836e+00 +8.704290704068707862e-01 +7.675686684415085059e-01 +6.366589855753786997e-01 +1.391144521467354567e-01 +1.319807096198771790e+00 +1.031975262234262836e+00 +2.410587471645578184e+00 +1.958890987077131385e+00 +1.389887211119905430e+00 +2.278567789804396782e+00 +2.685588651005233807e+00 +1.931556826816241745e+00 +1.304023096682751071e+00 +1.090517237142568430e+00 +9.867258298530903193e-01 +1.671272885580049827e+00 +1.399283467146648974e+00 +2.287113794331504302e+00 +1.303569760827072654e+00 +2.190335673290522145e+00 +1.208509951623783163e+00 +7.604877259272841972e-01 +1.932413102107187974e+00 +2.462266905619054036e+00 +2.703672667064165758e+00 +2.345156667640663217e-01 +8.888206857464778787e-01 +2.088552154389632598e+00 +2.522866386944046635e+00 +2.727247955036057103e+00 +1.793484139023393542e+00 +2.640740085332631271e-01 +1.921332349594581990e+00 +1.894982639667482038e+00 +1.387674723330024396e+00 +1.826978508908853049e+00 +1.688661546252278045e+00 +1.513939721228410917e+00 +1.701910609155345977e+00 +1.731029054552289725e+00 +4.025213283803100728e-01 +2.305525435128887413e+00 +1.368809187031724717e+00 +1.055023043440287189e+00 +2.350427135667970902e+00 +8.475627703318939776e-01 +1.015211745731769977e+00 +5.932767258353859319e-01 +1.752605415718679227e+00 +1.963078870875728521e+00 +7.859631127138176909e-01 +1.766096277525061709e+00 +2.023234704948477791e+00 +2.048443278857877026e+00 +2.035536607547305810e+00 +9.561249618959360985e-01 +7.853904192438466714e-01 +2.776682759809621537e+00 +8.836897154545889643e-01 +1.095460290858039532e+00 +2.116387475279861263e+00 +4.834294640621251693e-01 +7.508893068879232136e-01 +2.611687338410039860e+00 +7.875856779290403287e-01 +1.658643124215736897e+00 +9.944862622211869674e-01 +7.726663403054254697e-01 +2.160403012527393418e+00 +2.059612161949185172e+00 +2.004889963550469556e+00 +2.853879735608846335e-01 +1.139552815318194856e+00 +1.492282040324275583e+00 +2.115656702393172495e+00 +1.806021269524240580e+00 +2.300847189427049244e+00 +1.209429375035125265e+00 +1.103419191826894519e+00 +1.001646889206448066e+00 +1.724732726138796801e+00 +2.073864050323766417e+00 +1.219607985532310312e+00 +2.151662933165902292e+00 +2.155855820175260007e+00 +2.037345911205709292e+00 +1.201537205421559174e+00 +3.250597034537087149e-01 +2.540210046367346486e+00 +4.452120628166010841e-01 +1.678700717094270711e+00 +1.356600358054869782e+00 +7.507117098771532993e-01 +1.204284963900550842e+00 +1.599483065022068828e+00 +1.352635515307960112e+00 +1.960713983666331650e+00 +2.242249229554571066e+00 +1.861885617114408475e+00 +1.374235275633991282e+00 +8.081169927838659506e-01 +1.732803998458327221e+00 +1.703810003719697708e+00 +8.606986146498688539e-01 +1.914424467943136587e-01 +1.309624426627635119e+00 +1.162411374766461947e+00 +7.936240768427348380e-01 +1.300131618711018788e+00 +4.457680203040172984e-01 +1.308700356124783859e+00 +1.788269132394916472e+00 +1.933623685270790649e+00 +7.306881219759546742e-01 +1.430294903698021702e+00 +2.917384583574842427e-01 +1.066490810992252403e+00 +1.282043115000969902e+00 +9.790679993849420448e-01 +1.488091592803483465e+00 +6.273332419816519678e-01 +2.149800414963181083e+00 +8.905589992326367055e-01 +1.697805410788904457e+00 +2.236166720693048227e+00 +2.607145304259017138e+00 +1.864553781814157762e+00 +2.746568283079807493e+00 +6.398100787526278532e-01 +7.971826061663127572e-01 +1.203085154501710718e+00 +1.541991677879027245e+00 +1.511219078378917002e+00 +2.210010624984724004e+00 +2.471522853934345232e+00 +1.665608342487011617e-01 +2.740968363253774687e+00 +9.815950058213475105e-01 +2.078244362488174346e+00 +1.391525122706002815e+00 +1.409010970106663763e+00 +1.823730765775070317e+00 +1.412527179412436773e+00 +1.189933010110125977e+00 +1.926826621212410995e+00 +2.018913103497236516e+00 +7.307804482157809201e-01 +1.228396739399596260e+00 +1.875987045256779506e+00 +2.033782663750660191e+00 +1.285107430504305537e+00 +1.874886065742249919e+00 +1.721251732421235525e+00 +9.955706723864347074e-01 +1.603541332308073386e+00 +1.601456134137849485e+00 +1.093620809747563039e+00 +1.157192932833175369e+00 +8.009185358042375524e-01 +4.529667242985772058e-01 +1.438267274784652994e+00 +1.604630352434691432e+00 +7.044233408777358374e-01 +5.460928293925391142e-01 +1.219659313826973346e+00 +2.628653620449373562e+00 +1.653459855705230463e+00 +1.585159222837507098e+00 +2.176873923612041128e+00 +2.559162182632050975e-01 +2.787126453639344525e+00 +1.377756214759192854e+00 +1.301862835923135675e+00 +2.029428476229008460e+00 +1.135193560498924725e+00 +1.481052262087125371e+00 +2.370518414668173257e+00 +1.361075023752642998e+00 +2.033620203648244917e+00 +1.185103339456894833e+00 +2.255901296697041314e+00 +1.690892375047514218e+00 +3.534081168545928353e-01 +1.373805290981586769e+00 +2.183244599535634656e+00 +9.870065961359040241e-01 +2.031267989925929385e+00 +1.864273369539448222e+00 +9.413059948389822518e-01 +2.472320923688232419e+00 +1.777534935646165604e+00 +2.091748709676521933e+00 +1.557386571711449008e+00 +1.558983324635882184e+00 +5.776900534797834164e-01 +8.358189686465994361e-01 +2.337514340903747545e+00 +2.423393850082079304e+00 +2.519681361175970302e+00 +2.839846410516990716e+00 +6.698391275373009046e-01 +2.583080116822295658e+00 +1.093229158921709132e+00 +1.103332653212841130e+00 +1.033108910258025981e+00 diff --git a/test/resources/linear_regression/linear_regression.py b/test/resources/linear_regression/linear_regression.py index 26b73fdb..d4ebfc91 100644 --- a/test/resources/linear_regression/linear_regression.py +++ b/test/resources/linear_regression/linear_regression.py @@ -1,24 +1,29 @@ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# +# # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the "license" file accompanying this file. This file is distributed -# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing # permissions and limitations under the License. +from __future__ import absolute_import +import argparse import logging +import json import os import mxnet as mx import numpy as np +from sagemaker_mxnet_container.training_utils import save -def train(num_cpus, num_gpus, channel_input_dirs, **kwargs): + +def train(num_cpus, num_gpus, channel_input_dirs): """ ensure mxnet is fully functional by training simple model see http://mxnet.incubator.apache.org/tutorials/python/linear-regression.html @@ -76,3 +81,18 @@ def _get_context(cpus, gpus): logging.info("mxnet context: %s" % str(ctx)) return ctx + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR']) + parser.add_argument('--input-channels', type=str, default=json.loads(os.environ['SM_TRAINING_ENV'])['channel_input_dirs']) + + args = parser.parse_args() + + num_cpus = int(os.environ['SM_NUM_CPUS']) + num_gpus = int(os.environ['SM_NUM_GPUS']) + + model = train(num_cpus, num_gpus, args.input_channels) + save(args.model_dir, model) diff --git a/test/resources/mnist/mnist.py b/test/resources/mnist/mnist.py index 9902cb5a..dc33e46d 100644 --- a/test/resources/mnist/mnist.py +++ b/test/resources/mnist/mnist.py @@ -1,17 +1,18 @@ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# +# # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the "license" file accompanying this file. This file is distributed -# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing # permissions and limitations under the License. - +import argparse import gzip +import json import logging import os import struct @@ -19,6 +20,8 @@ import mxnet as mx import numpy as np +from sagemaker_mxnet_container.training_utils import scheduler_host + def load_data(path): with gzip.open(find_file(path, "labels.gz")) as flbl: @@ -48,16 +51,17 @@ def build_graph(): return mx.sym.SoftmaxOutput(data=fc3, name='softmax') -def get_train_context(num_gpus): +def get_training_context(num_gpus): if num_gpus: return [mx.gpu(i) for i in range(num_gpus)] else: return mx.cpu() -def train(channel_input_dirs, hyperparameters, hosts, current_host, num_gpus, **kwargs): - (train_labels, train_images) = load_data(os.path.join(channel_input_dirs['train'])) - (test_labels, test_images) = load_data(os.path.join(channel_input_dirs['test'])) +def train(batch_size, epochs, learning_rate, num_gpus, training_channel, testing_channel, + hosts, current_host, model_dir): + (train_labels, train_images) = load_data(training_channel) + (test_labels, test_images) = load_data(testing_channel) # Data parallel training - shard the data so each host # only trains on a subset of the total data. @@ -68,22 +72,60 @@ def train(channel_input_dirs, hyperparameters, hosts, current_host, num_gpus, ** end = start + shard_size break - batch_size = 100 train_iter = mx.io.NDArrayIter(train_images[start:end], train_labels[start:end], batch_size, shuffle=True) val_iter = mx.io.NDArrayIter(test_images, test_labels, batch_size) + logging.getLogger().setLevel(logging.DEBUG) + kvstore = 'local' if len(hosts) == 1 else 'dist_sync' - mlp_model = mx.mod.Module( - symbol=build_graph(), - context=get_train_context(num_gpus)) + + mlp_model = mx.mod.Module(symbol=build_graph(), + context=get_training_context(num_gpus)) mlp_model.fit(train_iter, eval_data=val_iter, kvstore=kvstore, optimizer='sgd', - optimizer_params={ - 'learning_rate': float(hyperparameters.get("learning_rate", 0.1))}, + optimizer_params={'learning_rate': learning_rate}, eval_metric='acc', batch_end_callback=mx.callback.Speedometer(batch_size, 100), - num_epoch=1) - return mlp_model + num_epoch=epochs) + + if current_host == scheduler_host(hosts): + save(model_dir, mlp_model) + + +def save(model_dir, model): + model.symbol.save(os.path.join(model_dir, 'model-symbol.json')) + model.save_params(os.path.join(model_dir, 'model-0000.params')) + + signature = [{'name': data_desc.name, 'shape': [dim for dim in data_desc.shape]} + for data_desc in model.data_shapes] + with open(os.path.join(model_dir, 'model-shapes.json'), 'w') as f: + json.dump(signature, f) + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument('--batch-size', type=int, default=100) + parser.add_argument('--epochs', type=int, default=10) + parser.add_argument('--learning-rate', type=float, default=0.1) + + parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR']) + parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN']) + parser.add_argument('--test', type=str, default=os.environ['SM_CHANNEL_TEST']) + + parser.add_argument('--current-host', type=str, default=os.environ['SM_CURRENT_HOST']) + parser.add_argument('--hosts', type=list, default=json.loads(os.environ['SM_HOSTS'])) + + return parser.parse_args() + + +if __name__ == '__main__': + args = parse_args() + + num_gpus = int(os.environ['SM_NUM_GPUS']) + + train(args.batch_size, args.epochs, args.learning_rate, num_gpus, args.train, args.test, + args.hosts, args.current_host, args.model_dir) diff --git a/test/resources/onnx/code/onnx_export_import.py b/test/resources/onnx/code/onnx_export_import.py new file mode 100644 index 00000000..b3fa959e --- /dev/null +++ b/test/resources/onnx/code/onnx_export_import.py @@ -0,0 +1,75 @@ +# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import argparse +import json +import os + +import mxnet as mx +from mxnet.contrib import onnx as onnx_mxnet +import numpy as np +import onnx +from onnx import checker + + +def _read_data_shapes(path, preferred_batch_size=1): + with open(path, 'r') as f: + signature = json.load(f) + + data_shapes = [] + + for s in signature: + shape = s['shape'] + + if preferred_batch_size: + shape[0] = preferred_batch_size + + data_shapes.append(shape) + + return data_shapes + + +def _assert_onnx_validity(model_path): + model_proto = onnx.load_model(model_path) + checker.check_graph(model_proto.graph) + + +def main(training_dir, model_dir): + sym = os.path.join(training_dir, 'model-symbol.json') + params = os.path.join(training_dir, 'model-0000.params') + data_shapes = _read_data_shapes(os.path.join(training_dir, 'model-shapes.json')) + + output_path = os.path.join(model_dir, 'model.onnx') + converted_path = onnx_mxnet.export_model(sym, params, data_shapes, np.float32, output_path) + _assert_onnx_validity(converted_path) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR']) + parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN']) + return parser.parse_args() + + +if __name__ == '__main__': + args = parse_args() + main(args.train, args.model_dir) + + +def model_fn(model_dir): + sym, arg_params, aux_params = onnx_mxnet.import_model(os.path.join(model_dir, 'model.onnx')) + mod = mx.mod.Module(symbol=sym, data_names=['data'], label_names=None) + mod.bind(for_training=False, data_shapes=[('data', [100, 1, 28, 28])]) + mod.set_params(arg_params=arg_params, aux_params=aux_params) + return mod diff --git a/test/resources/onnx/mxnet_module/model-0000.params b/test/resources/onnx/mxnet_module/model-0000.params new file mode 100644 index 00000000..21479bf1 Binary files /dev/null and b/test/resources/onnx/mxnet_module/model-0000.params differ diff --git a/test/resources/onnx/mxnet_module/model-shapes.json b/test/resources/onnx/mxnet_module/model-shapes.json new file mode 100644 index 00000000..b10c2926 --- /dev/null +++ b/test/resources/onnx/mxnet_module/model-shapes.json @@ -0,0 +1 @@ +[{"name": "data", "shape": [100, 1, 28, 28]}] diff --git a/test/resources/onnx/mxnet_module/model-symbol.json b/test/resources/onnx/mxnet_module/model-symbol.json new file mode 100644 index 00000000..5478ff83 --- /dev/null +++ b/test/resources/onnx/mxnet_module/model-symbol.json @@ -0,0 +1,111 @@ +{ + "nodes": [ + { + "op": "null", + "name": "data", + "inputs": [] + }, + { + "op": "Flatten", + "name": "flatten0", + "inputs": [[0, 0, 0]] + }, + { + "op": "null", + "name": "fullyconnected0_weight", + "attrs": {"num_hidden": "128"}, + "inputs": [] + }, + { + "op": "null", + "name": "fullyconnected0_bias", + "attrs": {"num_hidden": "128"}, + "inputs": [] + }, + { + "op": "FullyConnected", + "name": "fullyconnected0", + "attrs": {"num_hidden": "128"}, + "inputs": [[1, 0, 0], [2, 0, 0], [3, 0, 0]] + }, + { + "op": "Activation", + "name": "activation0", + "attrs": {"act_type": "relu"}, + "inputs": [[4, 0, 0]] + }, + { + "op": "null", + "name": "fullyconnected1_weight", + "attrs": {"num_hidden": "64"}, + "inputs": [] + }, + { + "op": "null", + "name": "fullyconnected1_bias", + "attrs": {"num_hidden": "64"}, + "inputs": [] + }, + { + "op": "FullyConnected", + "name": "fullyconnected1", + "attrs": {"num_hidden": "64"}, + "inputs": [[5, 0, 0], [6, 0, 0], [7, 0, 0]] + }, + { + "op": "Activation", + "name": "activation1", + "attrs": {"act_type": "relu"}, + "inputs": [[8, 0, 0]] + }, + { + "op": "null", + "name": "fullyconnected2_weight", + "attrs": {"num_hidden": "10"}, + "inputs": [] + }, + { + "op": "null", + "name": "fullyconnected2_bias", + "attrs": {"num_hidden": "10"}, + "inputs": [] + }, + { + "op": "FullyConnected", + "name": "fullyconnected2", + "attrs": {"num_hidden": "10"}, + "inputs": [[9, 0, 0], [10, 0, 0], [11, 0, 0]] + }, + { + "op": "null", + "name": "softmax_label", + "inputs": [] + }, + { + "op": "SoftmaxOutput", + "name": "softmax", + "inputs": [[12, 0, 0], [13, 0, 0]] + } + ], + "arg_nodes": [0, 2, 3, 6, 7, 10, 11, 13], + "node_row_ptr": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15 + ], + "heads": [[14, 0, 0]], + "attrs": {"mxnet_version": ["int", 10300]} +} diff --git a/test/resources/onnx/onnx_model/model.onnx b/test/resources/onnx/onnx_model/model.onnx new file mode 100644 index 00000000..df78f0e7 Binary files /dev/null and b/test/resources/onnx/onnx_model/model.onnx differ diff --git a/test/unit/test_module_transformer.py b/test/unit/test_module_transformer.py deleted file mode 100644 index 3756f029..00000000 --- a/test/unit/test_module_transformer.py +++ /dev/null @@ -1,263 +0,0 @@ -# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# A copy of the License is located at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the "license" file accompanying this file. This file is distributed -# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing -# permissions and limitations under the License. - -import json -import logging -import os -import tempfile -from types import ModuleType - -import pytest -from container_support.serving import UnsupportedContentTypeError, \ - UnsupportedAcceptTypeError, \ - UnsupportedInputShapeError, \ - JSON_CONTENT_TYPE, \ - CSV_CONTENT_TYPE -from mock import Mock -from mock import patch, MagicMock - -JSON_DATA = json.dumps({'k1': 'v1', 'k2': [1, 2, 3]}) -CSV_INPUT = "1,2,3\r\n" - - -@pytest.fixture() -def mxc(): - os.environ['SAGEMAKER_CONTAINER_LOG_LEVEL'] = str(logging.INFO) - os.environ['SAGEMAKER_REGION'] = 'us-west-2' - - mxnet_mock = MagicMock() - modules = { - 'mxnet': mxnet_mock, - } - ndarray = Mock('ndarray') - ndarray.asnumpy = Mock(name='asnumpy') - ndarray.asnumpy().tolist = Mock(name='tolist', return_value=[1, 2, 3]) - ndarray.asnumpy().flatten = Mock(name='flatten', return_value=[1, 2, 3]) - ndarray.reshape = Mock(return_value=ndarray) - ndarray.shape = [1, 1, 1] - mxnet_mock.nd.array = Mock(name='array', return_value=ndarray) - - patcher = patch.dict('sys.modules', modules) - patcher.start() - import mxnet_container - yield mxnet_container - patcher.stop() - - -@pytest.fixture() -def user_module(): - m = ModuleType('mod') - m.model_fn = _model_fn - yield m - - -@pytest.fixture() -def transformer(mxc, user_module): - import mxnet_container - from mxnet_container.serve.transformer import ModuleTransformer - transform_class = 'mxnet_container.serve.transformer.MXNetTransformer.select_transformer_class' - with patch(transform_class) as select: - select.return_value = ModuleTransformer - yield mxnet_container.serve.transformer.transformer(user_module) - - -@pytest.fixture -def module_module(mxc): - mock = Mock('BaseModule') - m = ModuleType('mod') - m.model_fn = lambda x: mock - m.input_fn = lambda a, b, c: 'input({})'.format(str(b)) - m.predict_fn = lambda a, b: 'predict({})'.format(str(b)) - m.output_fn = lambda a, b: ('output({})'.format(str(a)), b) - - # extras for test - m._module = mock - return m - - -class TestModuleTransformer(object): - @patch('mxnet_container.serve.transformer.MXNetTransformer.select_transformer_class') - def test_from_module(self, select, mxc, module_module): - from mxnet_container.serve.transformer import MXNetTransformer, ModuleTransformer - select.return_value = ModuleTransformer - - t = MXNetTransformer.from_module(module_module) - assert isinstance(t, ModuleTransformer) - assert t.model == module_module._module - assert t.transform('x', JSON_CONTENT_TYPE, JSON_CONTENT_TYPE) == \ - ('output(predict(input(x)))', JSON_CONTENT_TYPE) - - def test_transformer_from_module_transform_fn(self, mxc, user_module): - import mxnet_container - user_module.transform_fn = _transform_fn - t = mxnet_container.serve.transformer.transformer(user_module) - assert t.transform("data", JSON_CONTENT_TYPE, JSON_CONTENT_TYPE) == \ - ("transform_fn data", JSON_CONTENT_TYPE) - - def test_transformer_from_module_separate_fn(self, mxc, user_module): - user_module.process_request_fn = _process_request_fn - user_module.output_fn = _output_fn - t = next(transformer(mxc, user_module)) - assert t.transform("data", JSON_CONTENT_TYPE, JSON_CONTENT_TYPE) == \ - ("output_fn predict_fn input_fn data", JSON_CONTENT_TYPE) - - @patch('mxnet_container.serve.transformer.MXNetTransformer.select_transformer_class') - @patch('mxnet_container.serve.transformer.ModuleTransformer._default_model_fn') - def test_transformer_from_module_default_fns(self, model_fn, select, mxc): - import mxnet_container - model_fn.return_value = DummyModel() - select.return_value = mxnet_container.serve.transformer.ModuleTransformer - - m = ModuleType('mod') # an empty module - t = mxnet_container.serve.transformer.transformer(m) - assert hasattr(t, 'model') - assert hasattr(t, 'transform_fn') - - def test_transformer_default_handler_json(self, mxc, transformer): - with patch('json.dumps') as patched: - patched.return_value = JSON_DATA - response, response_content_type = transformer.transform(JSON_DATA, JSON_CONTENT_TYPE, - JSON_CONTENT_TYPE) - - assert JSON_DATA == response - assert JSON_CONTENT_TYPE == response_content_type - - @patch('mxnet_container.serve.transformer.MXNetTransformer.select_transformer_class') - def test_transformer_default_handler_csv(self, select, mxc): - import mxnet_container - - m = ModuleType('mod') - m.model_fn = _model_fn_csv - select.return_value = mxnet_container.serve.transformer.ModuleTransformer - - csv_transformer = mxnet_container.serve.transformer.transformer(m) - - response, response_content_type = csv_transformer.transform(CSV_INPUT, CSV_CONTENT_TYPE, - CSV_CONTENT_TYPE) - - assert CSV_INPUT == response - assert CSV_CONTENT_TYPE == response_content_type - - @patch('mxnet_container.serve.transformer.MXNetTransformer.select_transformer_class') - def test_transformer_default_handler_csv_empty(self, select, mxc): - import mxnet_container - select.return_value = mxnet_container.serve.transformer.ModuleTransformer - - m = ModuleType('mod') - m.model_fn = _model_fn_csv - - csv_transformer = mxnet_container.serve.transformer.transformer(m) - - response, response_content_type = csv_transformer.transform("", CSV_CONTENT_TYPE, - CSV_CONTENT_TYPE) - - assert CSV_INPUT == response - assert CSV_CONTENT_TYPE == response_content_type - - @patch('mxnet_container.serve.transformer.MXNetTransformer.select_transformer_class') - def test_transformer_default_handler_csv_wrong_shape(self, select, mxc): - import mxnet_container - select.return_value = mxnet_container.serve.transformer.ModuleTransformer - - m = ModuleType('mod') - m.model_fn = _model_fn_csv_wrong_shape - - csv_transformer = mxnet_container.serve.transformer.transformer(m) - - with pytest.raises(UnsupportedInputShapeError): - csv_transformer.transform(CSV_INPUT, CSV_CONTENT_TYPE, CSV_CONTENT_TYPE) - - def test_transformer_default_handler_unsupported_content_type(self, transformer): - with pytest.raises(UnsupportedContentTypeError): - transformer.transform(JSON_DATA, "application/bad", JSON_CONTENT_TYPE) - - def test_transformer_default_handler_unsupported_accept_type(self, transformer): - with pytest.raises(UnsupportedAcceptTypeError): - transformer.transform(JSON_DATA, JSON_CONTENT_TYPE, "application/bad") - - def test_transformer_read_data_shapes(self, mxc, user_module): - from mxnet_container.serve.transformer import ModuleTransformer - data_shapes = [ - {"name": "data1", "shape": [10, 2, 3, 4]}, - {"name": "data2", "shape": [13, 4, 5, 6]} - ] - - fname = tempfile.mkstemp()[1] - try: - with open(fname, 'w') as f: - json.dump(data_shapes, f) - - names, shapes = ModuleTransformer._read_data_shapes(f.name) - assert 2 == len(shapes) - assert ('data1', [1, 2, 3, 4]) in shapes - assert ('data2', [1, 4, 5, 6]) in shapes - - finally: - os.remove(fname) - - -class DummyModel(object): - def predict(self, data): - nd_array = Mock(name='ndarray') - nd_array.asnumpy = Mock(name='asnumpy') - nd_array.asnumpy().tolist = Mock(name='tolist', return_value=[1, 2, 3]) - nd_array.asnumpy().flatten = Mock(name='flatten', return_value=[1, 2, 3]) - nd_array.shape = [1, 1, 1] - return [nd_array] - - @property - def data_shapes(self): - return [["DataDesc1", (1, 1, 3)]] - - -class DummyModelForCsv(object): - def __init__(self, wrong_shape=False): - self.make_wrong_shape = wrong_shape - - def predict(self, data): - nd_array = Mock(name='ndarray') - nd_array.asnumpy = Mock(name='asnumpy') - nd_array.asnumpy().tolist = Mock(name='tolist', return_value=[1, 2, 3]) - nd_array.asnumpy().flatten = Mock(name='flatten', return_value=[1, 2, 3]) - nd_array.shape = [1, 1, 1] - return [nd_array] - - @property - def data_shapes(self): - if self.make_wrong_shape: - return [] - return [["DataDesc1", (1, 1, 3)]] - - -def _model_fn(model_dir): - return DummyModel() - - -def _model_fn_csv(model_dir): - return DummyModelForCsv() - - -def _model_fn_csv_wrong_shape(model_dir): - return DummyModelForCsv(wrong_shape=True) - - -def _process_request_fn(model, data, content_type): - return "predict_fn input_fn " + data - - -def _output_fn(data, content_type): - return "output_fn " + data, JSON_CONTENT_TYPE - - -def _transform_fn(model, data, input_content_type, output_content_type): - return "transform_fn " + str(data), JSON_CONTENT_TYPE diff --git a/test/unit/test_serving.py b/test/unit/test_serving.py new file mode 100644 index 00000000..737c7877 --- /dev/null +++ b/test/unit/test_serving.py @@ -0,0 +1,279 @@ +# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import json +import os + +from mock import call, Mock, mock_open, patch +import mxnet as mx +import pytest +from sagemaker_containers.beta.framework import content_types, errors, transformer, worker + +from sagemaker_mxnet_container.serving import (_user_module_transformer, default_model_fn, + GluonBlockTransformer, ModuleTransformer, + MXNetTransformer) + +MODEL_DIR = 'foo/model' + + +@patch('mxnet.cpu') +@patch('mxnet.mod.Module') +@patch('mxnet.model.load_checkpoint') +@patch('os.path.exists', return_value=True) +def test_default_model_fn(path_exists, mx_load_checkpoint, mx_module, mx_cpu): + sym = Mock() + args = Mock() + aux = Mock() + mx_load_checkpoint.return_value = [sym, args, aux] + + mx_context = Mock() + mx_cpu.return_value = mx_context + + data_name = 'foo' + data_shape = [1] + signature = json.dumps([{'name': data_name, 'shape': data_shape}]) + + with patch('six.moves.builtins.open', mock_open(read_data=signature)): + default_model_fn(MODEL_DIR) + + mx_load_checkpoint.assert_called_with(os.path.join(MODEL_DIR, 'model'), 0) + + init_call = call(symbol=sym, context=mx_context, data_names=[data_name], label_names=None) + assert init_call in mx_module.mock_calls + + model = mx_module.return_value + model.bind.assert_called_with(for_training=False, data_shapes=[(data_name, data_shape)]) + model.set_params.assert_called_with(args, aux, allow_missing=True) + + +@patch('sagemaker_containers.beta.framework.functions.error_wrapper', lambda x, y: x) +def test_mxnet_transformer_init(): + t = MXNetTransformer() + + assert t._model is None + assert t._model_fn == transformer.default_model_fn + assert t._input_fn == t.default_input_fn + assert t._predict_fn == t.default_predict_fn + assert t._output_fn == t.default_output_fn + assert t.VALID_CONTENT_TYPES == (content_types.JSON,) + + +@patch('sagemaker_containers.beta.framework.functions.error_wrapper', lambda x, y: x) +def test_mxnet_transformer_init_with_args(): + model = Mock() + model_fn = Mock() + input_fn = Mock() + predict_fn = Mock() + output_fn = Mock() + error_class = Mock() + + t = MXNetTransformer(model=model, model_fn=model_fn, input_fn=input_fn, predict_fn=predict_fn, + output_fn=output_fn, error_class=error_class) + + assert t._model == model + assert t._model_fn == model_fn + assert t._input_fn == input_fn + assert t._predict_fn == predict_fn + assert t._output_fn == output_fn + assert t._error_class == error_class + + +@patch('sagemaker_containers.beta.framework.transformer.Transformer.initialize') +def test_mxnet_transformer_initialize_without_model(transformer_initialize): + t = MXNetTransformer() + t.initialize() + + transformer_initialize.assert_called_once() + + +@patch('sagemaker_containers.beta.framework.transformer.Transformer.initialize') +def test_mxnet_transformer_initialize_with_model(transformer_initialize): + t = MXNetTransformer(model=Mock()) + t.initialize() + + transformer_initialize.assert_not_called() + + +@patch('sagemaker_containers.beta.framework.encoders.decode', return_value=[0]) +def test_mxnet_transformer_default_input_fn(decode): + input_data = Mock() + content_type = 'application/json' + + t = MXNetTransformer() + deserialized_data = t.default_input_fn(input_data, content_type) + + decode.assert_called_with(input_data, content_type) + assert deserialized_data == mx.nd.array([0]) + + +def test_mxnet_transformer_default_input_fn_invalid_content_type(): + t = MXNetTransformer() + + with pytest.raises(errors.UnsupportedFormatError) as e: + t.default_input_fn(None, 'bad/content-type') + assert 'Content type bad/content-type is not supported by this framework' in str(e) + + +@patch('sagemaker_containers.beta.framework.encoders.encode') +def test_mxnet_transformer_default_output_fn(encode): + prediction = mx.ndarray.zeros(1) + accept = 'application/json' + + t = MXNetTransformer() + response = t.default_output_fn(prediction, accept) + + flattened_prediction = prediction.asnumpy().tolist() + encode.assert_called_with(flattened_prediction, accept) + + assert isinstance(response, worker.Response) + + +def test_mxnet_transformer_default_output_fn_invalid_content_type(): + t = MXNetTransformer() + + with pytest.raises(errors.UnsupportedFormatError) as e: + t.default_output_fn(None, 'bad/content-type') + assert 'Content type bad/content-type is not supported by this framework' in str(e) + + +def test_module_transformer_init_valid_content_types(): + t = ModuleTransformer() + assert content_types.JSON in t.VALID_CONTENT_TYPES + assert content_types.CSV in t.VALID_CONTENT_TYPES + + +@patch('mxnet.io.NDArrayIter') +@patch('sagemaker_containers.beta.framework.encoders.decode', return_value=[0]) +def test_module_transformer_default_input_fn_with_json(decode, mx_ndarray_iter): + model = Mock(data_shapes=[(1, (1,))]) + t = ModuleTransformer(model=model) + + input_data = Mock() + content_type = 'application/json' + t.default_input_fn(input_data, content_type) + + decode.assert_called_with(input_data, content_type) + init_call = call(mx.nd.array([0]), batch_size=1, last_batch_handle='pad') + assert init_call in mx_ndarray_iter.mock_calls + + +@patch('mxnet.nd.array') +@patch('mxnet.io.NDArrayIter') +@patch('sagemaker_containers.beta.framework.encoders.decode', return_value=[0]) +def test_module_transformer_default_input_fn_with_csv(decode, mx_ndarray_iter, mx_ndarray): + ndarray = Mock(shape=(1, (1,))) + ndarray.reshape.return_value = ndarray + mx_ndarray.return_value = ndarray + + model = Mock(data_shapes=[(1, (1,))]) + t = ModuleTransformer(model=model) + + input_data = Mock() + content_type = 'text/csv' + t.default_input_fn(input_data, content_type) + + decode.assert_called_with(input_data, content_type) + ndarray.reshape.assert_called_with((1,)) + init_call = call(mx.nd.array([0]), batch_size=1, last_batch_handle='pad') + assert init_call in mx_ndarray_iter.mock_calls + + +def test_module_transformer_default_input_fn_invalid_content_type(): + t = ModuleTransformer() + + with pytest.raises(errors.UnsupportedFormatError) as e: + t.default_input_fn(None, 'bad/content-type') + assert 'Content type bad/content-type is not supported by this framework' in str(e) + + +def test_module_transformer_default_predict_fn(): + t = ModuleTransformer() + module = Mock() + data = Mock() + + t.default_predict_fn(data, module) + module.predict.assert_called_with(data) + + +def test_gluon_transformer_default_predict_fn(): + data = [0] + block = Mock() + + t = GluonBlockTransformer() + t.default_predict_fn(data, block) + + block.assert_called_with(data) + + +@patch('sagemaker_containers.beta.framework.functions.error_wrapper', lambda x, y: x) +@patch('sagemaker_mxnet_container.serving.default_model_fn') +def test_user_module_transformer_with_transform_fn(model_fn): + class UserModule: + def __init__(self): + self.transform_fn = Mock() + + user_module = UserModule() + + t = _user_module_transformer(user_module, MODEL_DIR) + assert t._transform_fn == user_module.transform_fn + + +@patch('sagemaker_containers.beta.framework.functions.error_wrapper', lambda x, y: x) +@patch('sagemaker_mxnet_container.serving.default_model_fn') +def test_user_module_transformer_module_transformer_no_user_methods(model_fn): + module = mx.module.BaseModule() + model_fn.return_value = module + + user_module = None + t = _user_module_transformer(user_module, MODEL_DIR) + + assert isinstance(t, ModuleTransformer) + assert t._model == module + assert t._model_fn == model_fn + assert t._input_fn == t.default_input_fn + assert t._predict_fn == t.default_predict_fn + assert t._output_fn == t.default_output_fn + + +@patch('sagemaker_containers.beta.framework.functions.error_wrapper', lambda x, y: x) +def test_user_module_transformer_gluon_transformer_with_user_methods(): + gluon_block = mx.gluon.block.Block() + + class UserModule: + def __init__(self): + self.input_fn = Mock() + self.predict_fn = Mock() + self.output_fn = Mock() + + def model_fn(self, model_dir): + return gluon_block + + user_module = UserModule() + t = _user_module_transformer(user_module, MODEL_DIR) + + assert isinstance(t, GluonBlockTransformer) + assert t._model == gluon_block + assert t._model_fn == user_module.model_fn + assert t._input_fn == user_module.input_fn + assert t._predict_fn == user_module.predict_fn + assert t._output_fn == user_module.output_fn + + +@patch('sagemaker_mxnet_container.serving.default_model_fn', return_value=Mock()) +def test_user_module_transformer_unsupported_model_type(model_fn): + user_module = None + with pytest.raises(ValueError) as e: + _user_module_transformer(user_module, MODEL_DIR) + + assert 'Unsupported model type' in str(e) diff --git a/test/unit/test_train.py b/test/unit/test_train.py deleted file mode 100644 index 56398982..00000000 --- a/test/unit/test_train.py +++ /dev/null @@ -1,226 +0,0 @@ -# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# A copy of the License is located at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the "license" file accompanying this file. This file is distributed -# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing -# permissions and limitations under the License. - -import json -import os -import shutil -import tempfile -from inspect import getargspec - -import pytest -from container_support import ContainerEnvironment -from mock import patch, MagicMock, create_autospec - -INPUT_DATA_CONFIG = { - "train": {"ContentType": "trainingContentType"}, - "evaluation": {"ContentType": "evalContentType"}, - "Validation": {} -} - -HYPERPARAMETERS = { - ContainerEnvironment.USER_SCRIPT_NAME_PARAM: 'myscript.py', - ContainerEnvironment.USER_SCRIPT_ARCHIVE_PARAM: 's3://mybucket/code.tar.gz', - "sagemaker_s3_uri_training": "blah/blah", - "sagemaker_s3_uri_validation": "xxx/yyy", - 'sagemaker_region': "us-west-2" -} - - -class NoKWArgsModule: - def train(self, hyperparameters): - pass - - -class KWArgsModule: - def train(self, **kwargs): - pass - - -getargspec_orig = getargspec - - -def train_no_kwargs_mock(): - return create_autospec(NoKWArgsModule) - - -def train_kwargs_mock(): - return create_autospec(KWArgsModule) - - -@pytest.fixture() -def mxc(): - mxnet_mock = MagicMock() - modules = { - 'mxnet': mxnet_mock - } - - patcher = patch.dict('sys.modules', modules) - patcher.start() - import mxnet_container - yield mxnet_container - patcher.stop() - - -@pytest.fixture() -def training(): - d = optml() - yield d - shutil.rmtree(d) - - -def optml(): - tmp = tempfile.mkdtemp() - for d in ['input/data/training', 'input/config', 'model', 'output/data']: - os.makedirs(os.path.join(tmp, d)) - - with open(os.path.join(tmp, 'input/data/training/data.csv'), 'w') as f: - f.write('dummy data file') - - _write_resource_config(tmp, 'a', ['a', 'b']) - _write_config_file(tmp, 'inputdataconfig.json', INPUT_DATA_CONFIG) - _write_config_file(tmp, 'hyperparameters.json', _serialize_hyperparameters(HYPERPARAMETERS)) - - return tmp - - -def test_mxnet_env_is_distributed(mxc, training): - from mxnet_container.train import MXNetTrainingEnvironment - - with patch('socket.gethostbyname') as patched: - mxnet_env = MXNetTrainingEnvironment(training) - assert mxnet_env.distributed - - -def test_mxnet_env_is_not_distributed(mxc, training): - from mxnet_container.train import MXNetTrainingEnvironment - - _write_resource_config(training, 'a', ['a']) - - with patch('socket.gethostbyname') as patched: - mxnet_env = MXNetTrainingEnvironment(training) - assert not mxnet_env.distributed - - -def test_mnxet_env_env_vars(mxc, training): - from mxnet_container.train import MXNetTrainingEnvironment - - with patch('socket.gethostbyname') as patched: - patched.return_value = '0.0.0.0' - mxnet_env = MXNetTrainingEnvironment(training) - assert mxnet_env.env_vars_for_role('worker') == { - 'DMLC_NUM_WORKER': "2", - 'DMLC_NUM_SERVER': "2", - 'DMLC_ROLE': 'worker', - 'DMLC_PS_ROOT_URI': '0.0.0.0', - 'DMLC_PS_ROOT_PORT': "8000", - 'PS_VERBOSE': "0" - } - - -def test_mxnet_env_is_current_host_scheduler(mxc, training): - from mxnet_container.train import MXNetTrainingEnvironment - - with patch('socket.gethostbyname') as patched: - mxnet_env = MXNetTrainingEnvironment(training) - assert mxnet_env.current_host_scheduler - - -def test_mxnet_env_not_is_current_host_scheduler(mxc, training): - from mxnet_container.train import MXNetTrainingEnvironment - - _write_resource_config(training, 'b', ['a', 'b']) - - with patch('socket.gethostbyname') as patched: - mxnet_env = MXNetTrainingEnvironment(training) - assert not mxnet_env.current_host_scheduler - - -def test_train_with_no_kwargs_in_user_module(mxc): - from mxnet_container import train - with patch('container_support.download_s3_resource') as patched_download_s3_resource, \ - patch('container_support.untar_directory') as patched_untar_directory, \ - patch('socket.gethostbyname') as patched_gethostbyname, \ - patch('inspect.getargspec') as patched_getargspec, \ - patch('importlib.import_module', - new_callable=train_no_kwargs_mock) as patched_import_module: - patched_getargspec.return_value = getargspec_orig(NoKWArgsModule.train) - - train(optml()) - assert patched_import_module.return_value.train.called - - -def test_train_failing_script(mxc): - from mxnet_container import train - - def raise_error(*args, **kwargs): - raise ValueError("I failed") - - with patch('container_support.download_s3_resource') as patched_download_s3_resource, \ - patch('container_support.untar_directory') as patched_untar_directory, \ - patch('socket.gethostbyname') as patched_gethostbyname, \ - patch('inspect.getargspec') as patched_getargspec, \ - patch('importlib.import_module', - new_callable=train_kwargs_mock) as patched_import_module: - patched_getargspec.return_value = getargspec_orig(KWArgsModule.train) - patched_import_module.return_value.train.side_effect = raise_error - - with pytest.raises(ValueError): - train(optml()) - assert patched_import_module.return_value.train.called - - -def test_train(mxc): - from mxnet_container import train - - with patch('container_support.download_s3_resource') as patched_download_s3_resource, \ - patch('container_support.untar_directory') as patched_untar_directory, \ - patch('subprocess.Popen') as patched_Popen, \ - patch('socket.gethostbyname'), \ - patch('inspect.getargspec') as patched_getargspec, \ - patch('importlib.import_module', - new_callable=train_kwargs_mock) as patched_import_module: - patched_getargspec.return_value = getargspec_orig(KWArgsModule.train) - - train(optml()) - assert patched_Popen.call_count == 3 - assert patched_import_module.return_value.train.called - - -def test_train_save_shape(mxc, training): - with patch('socket.gethostbyname') as patched_gethostbyname: - from mxnet_container.train import MXNetTrainingEnvironment - env = MXNetTrainingEnvironment(training) - mock_module = MagicMock() - data_desc = MagicMock() - data_desc.name = "elizabeth" - data_desc.shape = [100, 200, 300, 400] - mock_module.data_shapes = [data_desc] - env.default_save(mock_module) - with open(os.path.join(env.model_dir, 'model-shapes.json')) as f: - read_data_shape = json.load(f) - expected_data_shape = [{'shape': [100, 200, 300, 400], 'name': 'elizabeth'}] - assert expected_data_shape == read_data_shape - - -def _write_config_file(path, filename, data): - path = os.path.join(path, "input/config/%s" % filename) - with open(path, 'w') as f: - json.dump(data, f) - - -def _write_resource_config(path, current_host, hosts): - _write_config_file(path, 'resourceconfig.json', {'current_host': current_host, 'hosts': hosts}) - - -def _serialize_hyperparameters(hp): - return {str(k): json.dumps(v) for (k, v) in hp.items()} diff --git a/test/unit/test_training.py b/test/unit/test_training.py new file mode 100644 index 00000000..9980829f --- /dev/null +++ b/test/unit/test_training.py @@ -0,0 +1,129 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from mock import call, MagicMock, patch +import pytest + +from sagemaker_mxnet_container import training + +MODULE_DIR = 's3://my/bucket' +MODULE_NAME = 'script_name' + +SCHEDULER = 'host-1' +SINGLE_HOST_LIST = [SCHEDULER] +MULTIPLE_HOST_LIST = [SCHEDULER, 'host-2', 'host-3'] + + +IP_ADDRESS = '0.0.0.0000' +DEFAULT_PORT = '8000' +DEFAULT_VERBOSITY = '0' +BASE_ENV_VARS = { + 'DMLC_NUM_WORKER': str(len(MULTIPLE_HOST_LIST)), + 'DMLC_NUM_SERVER': str(len(MULTIPLE_HOST_LIST)), + 'DMLC_PS_ROOT_URI': IP_ADDRESS, + 'DMLC_PS_ROOT_PORT': DEFAULT_PORT, + 'PS_VERBOSE': DEFAULT_VERBOSITY, +} + +MXNET_COMMAND = "python -c 'import mxnet'" + + +@pytest.fixture +def single_machine_training_env(): + env = MagicMock() + + env.module_dir = MODULE_DIR + env.module_name = MODULE_NAME + env.hyperparameters = {} + env.additional_framework_parameters = {} + + return env + + +@pytest.fixture +def distributed_training_env(): + env = MagicMock() + + env.module_dir = MODULE_DIR + env.module_name = MODULE_NAME + env.hyperparameters = {} + + env.hosts = MULTIPLE_HOST_LIST + env.additional_framework_parameters = { + training.LAUNCH_PS_ENV_NAME: True, + } + + return env + + +@patch('os.environ', {}) +@patch('subprocess.Popen') +@patch('sagemaker_mxnet_container.training._host_lookup') +@patch('sagemaker_mxnet_container.training._verify_hosts') +@patch('sagemaker_containers.beta.framework.modules.run_module') +def test_train_for_distributed_scheduler(run_module, verify_hosts, host_lookup, popen, + distributed_training_env): + host_lookup.return_value = IP_ADDRESS + + distributed_training_env.current_host = SCHEDULER + training.train(distributed_training_env) + + verify_hosts.assert_called_with(MULTIPLE_HOST_LIST) + + scheduler_env = BASE_ENV_VARS.copy() + scheduler_env.update({'DMLC_ROLE': 'scheduler'}) + + server_env = BASE_ENV_VARS.copy() + server_env.update({'DMLC_ROLE': 'server'}) + + calls = [call(MXNET_COMMAND, shell=True, env=scheduler_env), + call(MXNET_COMMAND, shell=True, env=server_env)] + + popen.assert_has_calls(calls) + + +@patch('os.environ', {}) +@patch('subprocess.Popen') +@patch('sagemaker_mxnet_container.training._host_lookup') +@patch('sagemaker_mxnet_container.training._verify_hosts') +@patch('sagemaker_containers.beta.framework.modules.run_module') +def test_train_for_distributed_worker(run_module, verify_hosts, host_lookup, popen, + distributed_training_env): + host_lookup.return_value = IP_ADDRESS + + distributed_training_env.current_host = 'host-2' + training.train(distributed_training_env) + + verify_hosts.assert_called_with(MULTIPLE_HOST_LIST) + + server_env = BASE_ENV_VARS.copy() + server_env.update({'DMLC_ROLE': 'server'}) + + popen.assert_called_once_with(MXNET_COMMAND, shell=True, env=server_env) + + +@patch('sagemaker_containers.beta.framework.modules.run_module') +def test_train_for_single_machine(run_module, single_machine_training_env): + training.train(single_machine_training_env) + run_module.assert_called_with(MODULE_DIR, single_machine_training_env.to_cmd_args(), + single_machine_training_env.to_env_vars(), MODULE_NAME) + + +@patch('sagemaker_mxnet_container.training.train') +@patch('sagemaker_containers.beta.framework.training_env') +def test_main(env, train, single_machine_training_env): + env.return_value = single_machine_training_env + + training.main() + train.assert_called_with(single_machine_training_env) diff --git a/test/unit/test_training_utils.py b/test/unit/test_training_utils.py new file mode 100644 index 00000000..061ed7c7 --- /dev/null +++ b/test/unit/test_training_utils.py @@ -0,0 +1,69 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the 'license' file accompanying this file. This file is +# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import json +import os + +from mock import Mock, mock_open, patch + +from sagemaker_mxnet_container import training_utils + +SCHEDULER_HOST = 'host-1' +WORKER_HOST = 'host-2' +MODEL_DIR = 'foo/model' + + +@patch('json.dump') +@patch('os.environ', {'SM_CURRENT_HOST': SCHEDULER_HOST, 'SM_HOSTS': json.dumps([SCHEDULER_HOST])}) +def test_save_single_machine(json_dump): + model = Mock() + model.data_shapes = [] + + with patch('six.moves.builtins.open', mock_open(read_data=Mock())): + training_utils.save(MODEL_DIR, model) + + model.symbol.save.assert_called_with(os.path.join(MODEL_DIR, 'model-symbol.json')) + model.save_params.assert_called_with(os.path.join(MODEL_DIR, 'model-0000.params')) + json_dump.assert_called_once + + +@patch('json.dump') +def test_save_distributed(json_dump): + model = Mock() + model.data_shapes = [] + + with patch('six.moves.builtins.open', mock_open(read_data=Mock())): + training_utils.save(MODEL_DIR, model, current_host=SCHEDULER_HOST, + hosts=[SCHEDULER_HOST, WORKER_HOST]) + + model.symbol.save.assert_called_with(os.path.join(MODEL_DIR, 'model-symbol.json')) + model.save_params.assert_called_with(os.path.join(MODEL_DIR, 'model-0000.params')) + json_dump.assert_called_once + + +def test_save_for_non_scheduler_host(): + model = Mock() + training_utils.save(MODEL_DIR, model, current_host=WORKER_HOST, + hosts=[SCHEDULER_HOST, WORKER_HOST]) + + model.symbol.save.assert_not_called + model.save_params.assert_not_called + + +def test_single_machine_scheduler_host(): + assert training_utils.scheduler_host([SCHEDULER_HOST]) == SCHEDULER_HOST + + +def test_distributed_scheduler_host(): + assert training_utils.scheduler_host([SCHEDULER_HOST, WORKER_HOST]) == SCHEDULER_HOST diff --git a/test/unit/test_transformer.py b/test/unit/test_transformer.py deleted file mode 100644 index fc8a04e2..00000000 --- a/test/unit/test_transformer.py +++ /dev/null @@ -1,180 +0,0 @@ -# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# A copy of the License is located at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the "license" file accompanying this file. This file is distributed -# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing -# permissions and limitations under the License. - -import logging -import os -from types import ModuleType - -import pytest -from container_support.serving import UnsupportedContentTypeError, \ - UnsupportedAcceptTypeError, \ - JSON_CONTENT_TYPE -from mock import Mock -from mock import patch, MagicMock - - -@pytest.fixture() -def mxc(): - os.environ['SAGEMAKER_CONTAINER_LOG_LEVEL'] = str(logging.INFO) - os.environ['SAGEMAKER_REGION'] = 'us-west-2' - - mxnet_mock = MagicMock() - modules = { - 'mxnet': mxnet_mock, - } - ndarray = Mock('ndarray') - ndarray.asnumpy = Mock(name='asnumpy') - ndarray.asnumpy().tolist = Mock(name='tolist', return_value=[1, 2, 3]) - ndarray.asnumpy().flatten = Mock(name='flatten', return_value=[1, 2, 3]) - ndarray.reshape = Mock(return_value=ndarray) - ndarray.shape = [1, 1, 1] - mxnet_mock.nd.array = Mock(name='array', return_value=ndarray) - - patcher = patch.dict('sys.modules', modules) - patcher.start() - import mxnet_container - yield mxnet_container - patcher.stop() - - -generic_model = object() - - -def generic_model_fn(model_dir): - return generic_model - - -def generic_transform_fn(model, input_data, content_type, accept): - return input_data, accept - - -@pytest.fixture -def generic_module(): - m = ModuleType('mod') - m.model_fn = generic_model_fn - m.transform_fn = generic_transform_fn - return m - - -class TestMXNetTransformer(object): - def test_from_module_complete(self, mxc): - from mxnet_container.serve.transformer import MXNetTransformer - t = MXNetTransformer.from_module(generic_module()) - assert isinstance(t, MXNetTransformer) - assert t.model == generic_model - assert t.transform_fn == generic_transform_fn - assert t.transform('x', 'content-type', 'accept') == ('x', 'accept') - - @patch('mxnet_container.serve.transformer.ModuleTransformer._default_model_fn') - def test_from_module_with_default_model_fn(self, model_fn, mxc, generic_module): - from mxnet_container.serve.transformer import MXNetTransformer - model_fn.return_value = generic_model - del generic_module.model_fn - - t = MXNetTransformer.from_module(generic_module) - # expect MXNetTransformer with transform_fn from module, model from default_model_fn - assert isinstance(t, MXNetTransformer) - assert t.model == generic_model - assert t.transform_fn == generic_transform_fn - - -@pytest.fixture -def gluon_module(mxc): - mock = Mock('Block') - m = ModuleType('mod') - m.model_fn = lambda x: mock - m.input_fn = lambda a, b: 'input({})'.format(str(a)) - m.predict_fn = lambda a, b: 'predict({})'.format(str(b)) - m.output_fn = lambda a, b: ('output({})'.format(str(a)), b) - - # extra attribute for test - m._block = mock - return m - - -class TestGluonBlockTransformer(object): - @patch('mxnet_container.serve.transformer.MXNetTransformer.select_transformer_class') - def test_from_module(self, select, mxc, gluon_module): - from mxnet_container.serve.transformer import MXNetTransformer, GluonBlockTransformer - select.return_value = GluonBlockTransformer - - t = MXNetTransformer.from_module(gluon_module) - assert isinstance(t, GluonBlockTransformer) - assert t.model == gluon_module._block - assert t.transform('x', JSON_CONTENT_TYPE, JSON_CONTENT_TYPE) == \ - ('output(predict(input(x)))', JSON_CONTENT_TYPE) - - @patch('mxnet_container.serve.transformer.MXNetTransformer.select_transformer_class') - @patch('mxnet_container.serve.transformer.GluonBlockTransformer._default_output_fn') - @patch('mxnet_container.serve.transformer.GluonBlockTransformer._default_predict_fn') - @patch('mxnet_container.serve.transformer.GluonBlockTransformer._default_input_fn') - def test_from_module_with_defaults(self, input_fn, predict_fn, output_fn, - select, mxc, gluon_module): - from mxnet_container.serve.transformer import MXNetTransformer, GluonBlockTransformer - select.return_value = GluonBlockTransformer - - # remove the handlers so we can test default handlers - del gluon_module.input_fn - del gluon_module.predict_fn - del gluon_module.output_fn - - input_fn.return_value = 'default_input' - predict_fn.return_value = 'default_predict' - output_fn.return_value = 'default_output', 'accept' - - t = MXNetTransformer.from_module(gluon_module) - assert isinstance(t, GluonBlockTransformer) - assert t.model == gluon_module._block - assert t.transform('x', JSON_CONTENT_TYPE, JSON_CONTENT_TYPE) == \ - ('default_output', 'accept') - - input_fn.assert_called_with('x', JSON_CONTENT_TYPE) - predict_fn.assert_called_with(gluon_module._block, 'default_input') - output_fn.assert_called_with('default_predict', JSON_CONTENT_TYPE) - - def test_default_input_fn(self, mxc): - import mxnet - from mxnet_container.serve.transformer import GluonBlockTransformer - _ = GluonBlockTransformer._default_input_fn('[[1,2,3,4]]', JSON_CONTENT_TYPE) - mxnet.nd.array.assert_called_with([[1, 2, 3, 4]]) - - def test_default_input_fn_unsupported_content_type(self, mxc): - from mxnet_container.serve.transformer import GluonBlockTransformer - - with pytest.raises(UnsupportedContentTypeError): - GluonBlockTransformer._default_input_fn('whatever', 'wrong content type') - - def test_default_predict_fn(self, mxc): - from mxnet_container.serve.transformer import GluonBlockTransformer - - # block, ndarray could be any compatible callable/arg pair - block = list - ndarray = [1, 2, 3] - - result = GluonBlockTransformer._default_predict_fn(block, ndarray) - - assert [1, 2, 3] == result - - def test_default_output_fn(self, mxc): - import mxnet - from mxnet_container.serve.transformer import GluonBlockTransformer - mock_ndarray = mxnet.nd.array() - output, accept = GluonBlockTransformer._default_output_fn(mock_ndarray, JSON_CONTENT_TYPE) - assert accept == JSON_CONTENT_TYPE - assert output == '[1, 2, 3]' - - def test_default_output_fn_unsupported_content_type(self, mxc): - from mxnet_container.serve.transformer import GluonBlockTransformer - - with pytest.raises(UnsupportedAcceptTypeError): - GluonBlockTransformer._default_output_fn('whatever', 'wrong content type') diff --git a/tox.ini b/tox.ini new file mode 100644 index 00000000..7fbe84b8 --- /dev/null +++ b/tox.ini @@ -0,0 +1,66 @@ +# Tox (http://tox.testrun.org/) is a tool for running tests +# in multiple virtualenvs. This configuration file will run the +# test suite on all supported python versions. To use it, "pip install tox" +# and then run "tox" from this directory. +[tox] +envlist = py27,py35,flake8 +skip_missing_interpreters = False + +[travis] +python = + 3.5: py35, flake8 + +[flake8] +max-line-length = 100 +exclude = + build/ + .git + __pycache__ + examples/ + .tox + tests/data/ + test/resources + venv/ +max-complexity = 10 +ignore = + FI10, + FI12, + FI13, + FI14, + FI15, + FI16, + FI17, + FI50, + FI51, + FI52, + FI53, + FI54, + FI55, + FI56, + FI57, + E722 + +require-code = True + +[testenv] +# TEAMCITY_VERSION environment variable exists during build on Teamcity. teamcity-messages uses it in order to enable +# reporting to TeamCity. +passenv = + TEAMCITY_VERSION + AWS_ACCESS_KEY_ID + AWS_SECRET_ACCESS_KEY + AWS_SESSION_TOKEN +# {posargs} can be passed in by additional arguments specified when invoking tox. +# Can be used to specify which tests to run, e.g.: tox -- -s +commands = + coverage run --rcfile .coveragerc_{envname} --source sagemaker_mxnet_container -m pytest {posargs} + {env:IGNORE_COVERAGE:} coverage report --fail-under=90 --include *sagemaker_mxnet_container* +deps = .[test] + +[testenv:flake8] +basepython = python +deps = + flake8 + flake8-future-import + flake8-import-order +commands = flake8