diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index baf644b..7598a68 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -61,22 +61,22 @@ repos: # mypy - repo: https://github.com/pre-commit/mirrors-mypy - rev: 'v1.2.0' # Use the sha / tag you want to point at + rev: 'v1.7.0' # Use the sha / tag you want to point at hooks: - id: mypy exclude: 'jupyter-book/.*|tests/.*' args: [--config=mypy.ini, --explicit-package-bases] - additional_dependencies: [ pandas-stubs ] + additional_dependencies: [ openai ] # mypy for tests - repo: https://github.com/pre-commit/mirrors-mypy - rev: 'v1.2.0' # Use the sha / tag you want to point at + rev: 'v1.7.0' # Use the sha / tag you want to point at hooks: - id: mypy name: mypy-tests files: ^tests/ args: [--config=mypy-tests.ini] - additional_dependencies: [ pandas-stubs ] + additional_dependencies: [ openai ] # pylint - repo: local @@ -91,4 +91,4 @@ repos: "-rn", # Only display messages "-sn", # Don't display the score "--rcfile=.pylintrc", - ] \ No newline at end of file + ] diff --git a/README.md b/README.md index 8d363d9..7a50d5a 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,12 @@ ![GitHub release (latest by date)](https://img.shields.io/github/v/release/phelps-sg/openai-pygenerator) ![GitHub](https://img.shields.io/github/license/phelps-sg/openai-pygenerator?color=blue) +## Import Note + +The openai Python API versions 1.0.0 and later now incorporate retry functionality and type annotations. This package has been migrated to the new API, but should only be used as a temporary measure to ensure backward compatibility. If you are using this package in production, consider rewriting your code so that it uses the new openai API directly. + +## Overview + This is a simple type-annotated wrapper around the OpenAI Python API which: - provides configurable retry functionality, - reduces the default timeout from 10 minutes to 20 seconds (configurable), diff --git a/poetry.lock b/poetry.lock index e9d6235..bd8f357 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. [[package]] name = "aiohttp" @@ -122,6 +122,41 @@ files = [ [package.dependencies] frozenlist = ">=1.1.0" +[[package]] +name = "annotated-types" +version = "0.6.0" +description = "Reusable constraint types to use with typing.Annotated" +optional = false +python-versions = ">=3.8" +files = [ + {file = "annotated_types-0.6.0-py3-none-any.whl", hash = "sha256:0641064de18ba7a25dee8f96403ebc39113d0cb953a01429249d5c7564666a43"}, + {file = "annotated_types-0.6.0.tar.gz", hash = "sha256:563339e807e53ffd9c267e99fc6d9ea23eb8443c08f112651963e24e22f84a5d"}, +] + +[package.dependencies] +typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""} + +[[package]] +name = "anyio" +version = "3.7.1" +description = "High level compatibility layer for multiple asynchronous event loop implementations" +optional = false +python-versions = ">=3.7" +files = [ + {file = "anyio-3.7.1-py3-none-any.whl", hash = "sha256:91dee416e570e92c64041bd18b900d1d6fa78dff7048769ce5ac5ddad004fbb5"}, + {file = "anyio-3.7.1.tar.gz", hash = "sha256:44a3c9aba0f5defa43261a8b3efb97891f2bd7d804e0e1f56419befa1adfc780"}, +] + +[package.dependencies] +exceptiongroup = {version = "*", markers = "python_version < \"3.11\""} +idna = ">=2.8" +sniffio = ">=1.1" + +[package.extras] +doc = ["Sphinx", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme (>=1.2.2)", "sphinxcontrib-jquery"] +test = ["anyio[trio]", "coverage[toml] (>=4.5)", "hypothesis (>=4.0)", "mock (>=4)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] +trio = ["trio (<0.22)"] + [[package]] name = "astroid" version = "3.0.1" @@ -379,6 +414,17 @@ files = [ {file = "distlib-0.3.7.tar.gz", hash = "sha256:9dafe54b34a028eafd95039d5e5d4851a13734540f1331060d31c9916e7147a8"}, ] +[[package]] +name = "distro" +version = "1.8.0" +description = "Distro - an OS platform information API" +optional = false +python-versions = ">=3.6" +files = [ + {file = "distro-1.8.0-py3-none-any.whl", hash = "sha256:99522ca3e365cac527b44bde033f64c6945d90eb9f769703caaec52b09bbd3ff"}, + {file = "distro-1.8.0.tar.gz", hash = "sha256:02e111d1dc6a50abb8eed6bf31c3e48ed8b0830d1ea2a1b78c61765c2513fdd8"}, +] + [[package]] name = "exceptiongroup" version = "1.1.3" @@ -495,6 +541,62 @@ files = [ {file = "frozenlist-1.4.0.tar.gz", hash = "sha256:09163bdf0b2907454042edb19f887c6d33806adc71fbd54afc14908bfdc22251"}, ] +[[package]] +name = "h11" +version = "0.14.0" +description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" +optional = false +python-versions = ">=3.7" +files = [ + {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, + {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, +] + +[[package]] +name = "httpcore" +version = "1.0.2" +description = "A minimal low-level HTTP client." +optional = false +python-versions = ">=3.8" +files = [ + {file = "httpcore-1.0.2-py3-none-any.whl", hash = "sha256:096cc05bca73b8e459a1fc3dcf585148f63e534eae4339559c9b8a8d6399acc7"}, + {file = "httpcore-1.0.2.tar.gz", hash = "sha256:9fc092e4799b26174648e54b74ed5f683132a464e95643b226e00c2ed2fa6535"}, +] + +[package.dependencies] +certifi = "*" +h11 = ">=0.13,<0.15" + +[package.extras] +asyncio = ["anyio (>=4.0,<5.0)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] +trio = ["trio (>=0.22.0,<0.23.0)"] + +[[package]] +name = "httpx" +version = "0.25.1" +description = "The next generation HTTP client." +optional = false +python-versions = ">=3.8" +files = [ + {file = "httpx-0.25.1-py3-none-any.whl", hash = "sha256:fec7d6cc5c27c578a391f7e87b9aa7d3d8fbcd034f6399f9f79b45bcc12a866a"}, + {file = "httpx-0.25.1.tar.gz", hash = "sha256:ffd96d5cf901e63863d9f1b4b6807861dbea4d301613415d9e6e57ead15fc5d0"}, +] + +[package.dependencies] +anyio = "*" +certifi = "*" +httpcore = "*" +idna = "*" +sniffio = "*" + +[package.extras] +brotli = ["brotli", "brotlicffi"] +cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] + [[package]] name = "identify" version = "2.5.31" @@ -716,25 +818,25 @@ setuptools = "*" [[package]] name = "openai" -version = "0.28.1" -description = "Python client library for the OpenAI API" +version = "1.3.2" +description = "The official Python library for the openai API" optional = false python-versions = ">=3.7.1" files = [ - {file = "openai-0.28.1-py3-none-any.whl", hash = "sha256:d18690f9e3d31eedb66b57b88c2165d760b24ea0a01f150dd3f068155088ce68"}, - {file = "openai-0.28.1.tar.gz", hash = "sha256:4be1dad329a65b4ce1a660fe6d5431b438f429b5855c883435f0f7fcb6d2dcc8"}, + {file = "openai-1.3.2-py3-none-any.whl", hash = "sha256:97e2febbedc5f1308444d961df63aafb649efebf900d59dd3676fdede9bcd7b6"}, + {file = "openai-1.3.2.tar.gz", hash = "sha256:7904d8f029339931805a8962d88e955f9223a983a8fbd06e01ae40e14735362b"}, ] [package.dependencies] -aiohttp = "*" -requests = ">=2.20" -tqdm = "*" +anyio = ">=3.5.0,<4" +distro = ">=1.7.0,<2" +httpx = ">=0.23.0,<1" +pydantic = ">=1.9.0,<3" +tqdm = ">4" +typing-extensions = ">=4.5,<5" [package.extras] -datalib = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] -dev = ["black (>=21.6b0,<22.0)", "pytest (==6.*)", "pytest-asyncio", "pytest-mock"] -embeddings = ["matplotlib", "numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)", "plotly", "scikit-learn (>=1.0.2)", "scipy", "tenacity (>=8.0.1)"] -wandb = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)", "wandb"] +datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] [[package]] name = "packaging" @@ -817,6 +919,142 @@ files = [ {file = "pycodestyle-2.11.1.tar.gz", hash = "sha256:41ba0e7afc9752dfb53ced5489e89f8186be00e599e712660695b7a75ff2663f"}, ] +[[package]] +name = "pydantic" +version = "2.5.1" +description = "Data validation using Python type hints" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pydantic-2.5.1-py3-none-any.whl", hash = "sha256:dc5244a8939e0d9a68f1f1b5f550b2e1c879912033b1becbedb315accc75441b"}, + {file = "pydantic-2.5.1.tar.gz", hash = "sha256:0b8be5413c06aadfbe56f6dc1d45c9ed25fd43264414c571135c97dd77c2bedb"}, +] + +[package.dependencies] +annotated-types = ">=0.4.0" +pydantic-core = "2.14.3" +typing-extensions = ">=4.6.1" + +[package.extras] +email = ["email-validator (>=2.0.0)"] + +[[package]] +name = "pydantic-core" +version = "2.14.3" +description = "" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pydantic_core-2.14.3-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:ba44fad1d114539d6a1509966b20b74d2dec9a5b0ee12dd7fd0a1bb7b8785e5f"}, + {file = "pydantic_core-2.14.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4a70d23eedd88a6484aa79a732a90e36701048a1509078d1b59578ef0ea2cdf5"}, + {file = "pydantic_core-2.14.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7cc24728a1a9cef497697e53b3d085fb4d3bc0ef1ef4d9b424d9cf808f52c146"}, + {file = "pydantic_core-2.14.3-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ab4a2381005769a4af2ffddae74d769e8a4aae42e970596208ec6d615c6fb080"}, + {file = "pydantic_core-2.14.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:905a12bf088d6fa20e094f9a477bf84bd823651d8b8384f59bcd50eaa92e6a52"}, + {file = "pydantic_core-2.14.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:38aed5a1bbc3025859f56d6a32f6e53ca173283cb95348e03480f333b1091e7d"}, + {file = "pydantic_core-2.14.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1767bd3f6370458e60c1d3d7b1d9c2751cc1ad743434e8ec84625a610c8b9195"}, + {file = "pydantic_core-2.14.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7cb0c397f29688a5bd2c0dbd44451bc44ebb9b22babc90f97db5ec3e5bb69977"}, + {file = "pydantic_core-2.14.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9ff737f24b34ed26de62d481ef522f233d3c5927279f6b7229de9b0deb3f76b5"}, + {file = "pydantic_core-2.14.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a1a39fecb5f0b19faee9a8a8176c805ed78ce45d760259a4ff3d21a7daa4dfc1"}, + {file = "pydantic_core-2.14.3-cp310-none-win32.whl", hash = "sha256:ccbf355b7276593c68fa824030e68cb29f630c50e20cb11ebb0ee450ae6b3d08"}, + {file = "pydantic_core-2.14.3-cp310-none-win_amd64.whl", hash = "sha256:536e1f58419e1ec35f6d1310c88496f0d60e4f182cacb773d38076f66a60b149"}, + {file = "pydantic_core-2.14.3-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:f1f46700402312bdc31912f6fc17f5ecaaaa3bafe5487c48f07c800052736289"}, + {file = "pydantic_core-2.14.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:88ec906eb2d92420f5b074f59cf9e50b3bb44f3cb70e6512099fdd4d88c2f87c"}, + {file = "pydantic_core-2.14.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:056ea7cc3c92a7d2a14b5bc9c9fa14efa794d9f05b9794206d089d06d3433dc7"}, + {file = "pydantic_core-2.14.3-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:076edc972b68a66870cec41a4efdd72a6b655c4098a232314b02d2bfa3bfa157"}, + {file = "pydantic_core-2.14.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e71f666c3bf019f2490a47dddb44c3ccea2e69ac882f7495c68dc14d4065eac2"}, + {file = "pydantic_core-2.14.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f518eac285c9632be337323eef9824a856f2680f943a9b68ac41d5f5bad7df7c"}, + {file = "pydantic_core-2.14.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9dbab442a8d9ca918b4ed99db8d89d11b1f067a7dadb642476ad0889560dac79"}, + {file = "pydantic_core-2.14.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0653fb9fc2fa6787f2fa08631314ab7fc8070307bd344bf9471d1b7207c24623"}, + {file = "pydantic_core-2.14.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c54af5069da58ea643ad34ff32fd6bc4eebb8ae0fef9821cd8919063e0aeeaab"}, + {file = "pydantic_core-2.14.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:cc956f78651778ec1ab105196e90e0e5f5275884793ab67c60938c75bcca3989"}, + {file = "pydantic_core-2.14.3-cp311-none-win32.whl", hash = "sha256:5b73441a1159f1fb37353aaefb9e801ab35a07dd93cb8177504b25a317f4215a"}, + {file = "pydantic_core-2.14.3-cp311-none-win_amd64.whl", hash = "sha256:7349f99f1ef8b940b309179733f2cad2e6037a29560f1b03fdc6aa6be0a8d03c"}, + {file = "pydantic_core-2.14.3-cp311-none-win_arm64.whl", hash = "sha256:ec79dbe23702795944d2ae4c6925e35a075b88acd0d20acde7c77a817ebbce94"}, + {file = "pydantic_core-2.14.3-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:8f5624f0f67f2b9ecaa812e1dfd2e35b256487566585160c6c19268bf2ffeccc"}, + {file = "pydantic_core-2.14.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6c2d118d1b6c9e2d577e215567eedbe11804c3aafa76d39ec1f8bc74e918fd07"}, + {file = "pydantic_core-2.14.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fe863491664c6720d65ae438d4efaa5eca766565a53adb53bf14bc3246c72fe0"}, + {file = "pydantic_core-2.14.3-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:136bc7247e97a921a020abbd6ef3169af97569869cd6eff41b6a15a73c44ea9b"}, + {file = "pydantic_core-2.14.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aeafc7f5bbddc46213707266cadc94439bfa87ecf699444de8be044d6d6eb26f"}, + {file = "pydantic_core-2.14.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e16aaf788f1de5a85c8f8fcc9c1ca1dd7dd52b8ad30a7889ca31c7c7606615b8"}, + {file = "pydantic_core-2.14.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8fc652c354d3362e2932a79d5ac4bbd7170757a41a62c4fe0f057d29f10bebb"}, + {file = "pydantic_core-2.14.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f1b92e72babfd56585c75caf44f0b15258c58e6be23bc33f90885cebffde3400"}, + {file = "pydantic_core-2.14.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:75f3f534f33651b73f4d3a16d0254de096f43737d51e981478d580f4b006b427"}, + {file = "pydantic_core-2.14.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c9ffd823c46e05ef3eb28b821aa7bc501efa95ba8880b4a1380068e32c5bed47"}, + {file = "pydantic_core-2.14.3-cp312-none-win32.whl", hash = "sha256:12e05a76b223577a4696c76d7a6b36a0ccc491ffb3c6a8cf92d8001d93ddfd63"}, + {file = "pydantic_core-2.14.3-cp312-none-win_amd64.whl", hash = "sha256:1582f01eaf0537a696c846bea92082082b6bfc1103a88e777e983ea9fbdc2a0f"}, + {file = "pydantic_core-2.14.3-cp312-none-win_arm64.whl", hash = "sha256:96fb679c7ca12a512d36d01c174a4fbfd912b5535cc722eb2c010c7b44eceb8e"}, + {file = "pydantic_core-2.14.3-cp37-cp37m-macosx_10_7_x86_64.whl", hash = "sha256:71ed769b58d44e0bc2701aa59eb199b6665c16e8a5b8b4a84db01f71580ec448"}, + {file = "pydantic_core-2.14.3-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:5402ee0f61e7798ea93a01b0489520f2abfd9b57b76b82c93714c4318c66ca06"}, + {file = "pydantic_core-2.14.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eaab9dc009e22726c62fe3b850b797e7f0e7ba76d245284d1064081f512c7226"}, + {file = "pydantic_core-2.14.3-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:92486a04d54987054f8b4405a9af9d482e5100d6fe6374fc3303015983fc8bda"}, + {file = "pydantic_core-2.14.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cf08b43d1d5d1678f295f0431a4a7e1707d4652576e1d0f8914b5e0213bfeee5"}, + {file = "pydantic_core-2.14.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a8ca13480ce16daad0504be6ce893b0ee8ec34cd43b993b754198a89e2787f7e"}, + {file = "pydantic_core-2.14.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44afa3c18d45053fe8d8228950ee4c8eaf3b5a7f3b64963fdeac19b8342c987f"}, + {file = "pydantic_core-2.14.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:56814b41486e2d712a8bc02a7b1f17b87fa30999d2323bbd13cf0e52296813a1"}, + {file = "pydantic_core-2.14.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:c3dc2920cc96f9aa40c6dc54256e436cc95c0a15562eb7bd579e1811593c377e"}, + {file = "pydantic_core-2.14.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:e483b8b913fcd3b48badec54185c150cb7ab0e6487914b84dc7cde2365e0c892"}, + {file = "pydantic_core-2.14.3-cp37-none-win32.whl", hash = "sha256:364dba61494e48f01ef50ae430e392f67ee1ee27e048daeda0e9d21c3ab2d609"}, + {file = "pydantic_core-2.14.3-cp37-none-win_amd64.whl", hash = "sha256:a402ae1066be594701ac45661278dc4a466fb684258d1a2c434de54971b006ca"}, + {file = "pydantic_core-2.14.3-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:10904368261e4509c091cbcc067e5a88b070ed9a10f7ad78f3029c175487490f"}, + {file = "pydantic_core-2.14.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:260692420028319e201b8649b13ac0988974eeafaaef95d0dfbf7120c38dc000"}, + {file = "pydantic_core-2.14.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c1bf1a7b05a65d3b37a9adea98e195e0081be6b17ca03a86f92aeb8b110f468"}, + {file = "pydantic_core-2.14.3-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d7abd17a838a52140e3aeca271054e321226f52df7e0a9f0da8f91ea123afe98"}, + {file = "pydantic_core-2.14.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a5c51460ede609fbb4fa883a8fe16e749964ddb459966d0518991ec02eb8dfb9"}, + {file = "pydantic_core-2.14.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d06c78074646111fb01836585f1198367b17d57c9f427e07aaa9ff499003e58d"}, + {file = "pydantic_core-2.14.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af452e69446fadf247f18ac5d153b1f7e61ef708f23ce85d8c52833748c58075"}, + {file = "pydantic_core-2.14.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e3ad4968711fb379a67c8c755beb4dae8b721a83737737b7bcee27c05400b047"}, + {file = "pydantic_core-2.14.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:c5ea0153482e5b4d601c25465771c7267c99fddf5d3f3bdc238ef930e6d051cf"}, + {file = "pydantic_core-2.14.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:96eb10ef8920990e703da348bb25fedb8b8653b5966e4e078e5be382b430f9e0"}, + {file = "pydantic_core-2.14.3-cp38-none-win32.whl", hash = "sha256:ea1498ce4491236d1cffa0eee9ad0968b6ecb0c1cd711699c5677fc689905f00"}, + {file = "pydantic_core-2.14.3-cp38-none-win_amd64.whl", hash = "sha256:2bc736725f9bd18a60eec0ed6ef9b06b9785454c8d0105f2be16e4d6274e63d0"}, + {file = "pydantic_core-2.14.3-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:1ea992659c03c3ea811d55fc0a997bec9dde863a617cc7b25cfde69ef32e55af"}, + {file = "pydantic_core-2.14.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d2b53e1f851a2b406bbb5ac58e16c4a5496038eddd856cc900278fa0da97f3fc"}, + {file = "pydantic_core-2.14.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0c7f8e8a7cf8e81ca7d44bea4f181783630959d41b4b51d2f74bc50f348a090f"}, + {file = "pydantic_core-2.14.3-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8d3b9c91eeb372a64ec6686c1402afd40cc20f61a0866850f7d989b6bf39a41a"}, + {file = "pydantic_core-2.14.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9ef3e2e407e4cad2df3c89488a761ed1f1c33f3b826a2ea9a411b0a7d1cccf1b"}, + {file = "pydantic_core-2.14.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f86f20a9d5bee1a6ede0f2757b917bac6908cde0f5ad9fcb3606db1e2968bcf5"}, + {file = "pydantic_core-2.14.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61beaa79d392d44dc19d6f11ccd824d3cccb865c4372157c40b92533f8d76dd0"}, + {file = "pydantic_core-2.14.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d41df8e10b094640a6b234851b624b76a41552f637b9fb34dc720b9fe4ef3be4"}, + {file = "pydantic_core-2.14.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:2c08ac60c3caa31f825b5dbac47e4875bd4954d8f559650ad9e0b225eaf8ed0c"}, + {file = "pydantic_core-2.14.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:98d8b3932f1a369364606417ded5412c4ffb15bedbcf797c31317e55bd5d920e"}, + {file = "pydantic_core-2.14.3-cp39-none-win32.whl", hash = "sha256:caa94726791e316f0f63049ee00dff3b34a629b0d099f3b594770f7d0d8f1f56"}, + {file = "pydantic_core-2.14.3-cp39-none-win_amd64.whl", hash = "sha256:2494d20e4c22beac30150b4be3b8339bf2a02ab5580fa6553ca274bc08681a65"}, + {file = "pydantic_core-2.14.3-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:fe272a72c7ed29f84c42fedd2d06c2f9858dc0c00dae3b34ba15d6d8ae0fbaaf"}, + {file = "pydantic_core-2.14.3-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:7e63a56eb7fdee1587d62f753ccd6d5fa24fbeea57a40d9d8beaef679a24bdd6"}, + {file = "pydantic_core-2.14.3-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7692f539a26265cece1e27e366df5b976a6db6b1f825a9e0466395b314ee48b"}, + {file = "pydantic_core-2.14.3-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af46f0b7a1342b49f208fed31f5a83b8495bb14b652f621e0a6787d2f10f24ee"}, + {file = "pydantic_core-2.14.3-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6e2f9d76c00e805d47f19c7a96a14e4135238a7551a18bfd89bb757993fd0933"}, + {file = "pydantic_core-2.14.3-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:de52ddfa6e10e892d00f747bf7135d7007302ad82e243cf16d89dd77b03b649d"}, + {file = "pydantic_core-2.14.3-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:38113856c7fad8c19be7ddd57df0c3e77b1b2336459cb03ee3903ce9d5e236ce"}, + {file = "pydantic_core-2.14.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:354db020b1f8f11207b35360b92d95725621eb92656725c849a61e4b550f4acc"}, + {file = "pydantic_core-2.14.3-pp37-pypy37_pp73-macosx_10_7_x86_64.whl", hash = "sha256:76fc18653a5c95e5301a52d1b5afb27c9adc77175bf00f73e94f501caf0e05ad"}, + {file = "pydantic_core-2.14.3-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2646f8270f932d79ba61102a15ea19a50ae0d43b314e22b3f8f4b5fabbfa6e38"}, + {file = "pydantic_core-2.14.3-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37dad73a2f82975ed563d6a277fd9b50e5d9c79910c4aec787e2d63547202315"}, + {file = "pydantic_core-2.14.3-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:113752a55a8eaece2e4ac96bc8817f134c2c23477e477d085ba89e3aa0f4dc44"}, + {file = "pydantic_core-2.14.3-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:8488e973547e8fb1b4193fd9faf5236cf1b7cd5e9e6dc7ff6b4d9afdc4c720cb"}, + {file = "pydantic_core-2.14.3-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:3d1dde10bd9962b1434053239b1d5490fc31a2b02d8950a5f731bc584c7a5a0f"}, + {file = "pydantic_core-2.14.3-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:2c83892c7bf92b91d30faca53bb8ea21f9d7e39f0ae4008ef2c2f91116d0464a"}, + {file = "pydantic_core-2.14.3-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:849cff945284c577c5f621d2df76ca7b60f803cc8663ff01b778ad0af0e39bb9"}, + {file = "pydantic_core-2.14.3-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4aa89919fbd8a553cd7d03bf23d5bc5deee622e1b5db572121287f0e64979476"}, + {file = "pydantic_core-2.14.3-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf15145b1f8056d12c67255cd3ce5d317cd4450d5ee747760d8d088d85d12a2d"}, + {file = "pydantic_core-2.14.3-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4cc6bb11f4e8e5ed91d78b9880774fbc0856cb226151b0a93b549c2b26a00c19"}, + {file = "pydantic_core-2.14.3-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:832d16f248ca0cc96929139734ec32d21c67669dcf8a9f3f733c85054429c012"}, + {file = "pydantic_core-2.14.3-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:b02b5e1f54c3396c48b665050464803c23c685716eb5d82a1d81bf81b5230da4"}, + {file = "pydantic_core-2.14.3-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:1f2d4516c32255782153e858f9a900ca6deadfb217fd3fb21bb2b60b4e04d04d"}, + {file = "pydantic_core-2.14.3-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:0a3e51c2be472b7867eb0c5d025b91400c2b73a0823b89d4303a9097e2ec6655"}, + {file = "pydantic_core-2.14.3-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:df33902464410a1f1a0411a235f0a34e7e129f12cb6340daca0f9d1390f5fe10"}, + {file = "pydantic_core-2.14.3-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:27828f0227b54804aac6fb077b6bb48e640b5435fdd7fbf0c274093a7b78b69c"}, + {file = "pydantic_core-2.14.3-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1e2979dc80246e18e348de51246d4c9b410186ffa3c50e77924bec436b1e36cb"}, + {file = "pydantic_core-2.14.3-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b28996872b48baf829ee75fa06998b607c66a4847ac838e6fd7473a6b2ab68e7"}, + {file = "pydantic_core-2.14.3-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:ca55c9671bb637ce13d18ef352fd32ae7aba21b4402f300a63f1fb1fd18e0364"}, + {file = "pydantic_core-2.14.3-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:aecd5ed096b0e5d93fb0367fd8f417cef38ea30b786f2501f6c34eabd9062c38"}, + {file = "pydantic_core-2.14.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:44aaf1a07ad0824e407dafc637a852e9a44d94664293bbe7d8ee549c356c8882"}, + {file = "pydantic_core-2.14.3.tar.gz", hash = "sha256:3ad083df8fe342d4d8d00cc1d3c1a23f0dc84fce416eb301e69f1ddbbe124d3f"}, +] + +[package.dependencies] +typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" + [[package]] name = "pyflakes" version = "3.1.0" @@ -844,8 +1082,8 @@ astroid = ">=3.0.1,<=3.1.0-dev0" colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} dill = [ {version = ">=0.2", markers = "python_version < \"3.11\""}, + {version = ">=0.3.6", markers = "python_version >= \"3.11\""}, {version = ">=0.3.7", markers = "python_version >= \"3.12\""}, - {version = ">=0.3.6", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, ] isort = ">=4.2.5,<6" mccabe = ">=0.6,<0.8" @@ -956,27 +1194,6 @@ files = [ {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, ] -[[package]] -name = "requests" -version = "2.31.0" -description = "Python HTTP for Humans." -optional = false -python-versions = ">=3.7" -files = [ - {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"}, - {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"}, -] - -[package.dependencies] -certifi = ">=2017.4.17" -charset-normalizer = ">=2,<4" -idna = ">=2.5,<4" -urllib3 = ">=1.21.1,<3" - -[package.extras] -socks = ["PySocks (>=1.5.6,!=1.5.7)"] -use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] - [[package]] name = "setuptools" version = "68.2.2" @@ -993,6 +1210,17 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.1)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] +[[package]] +name = "sniffio" +version = "1.3.0" +description = "Sniff out which async library your code is running under" +optional = false +python-versions = ">=3.7" +files = [ + {file = "sniffio-1.3.0-py3-none-any.whl", hash = "sha256:eecefdce1e5bbfb7ad2eeaabf7c1eeb404d7757c379bd1f7e5cce9d8bf425384"}, + {file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"}, +] + [[package]] name = "tomli" version = "2.0.1" @@ -1046,22 +1274,6 @@ files = [ {file = "typing_extensions-4.8.0.tar.gz", hash = "sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef"}, ] -[[package]] -name = "urllib3" -version = "2.1.0" -description = "HTTP library with thread-safe connection pooling, file post, and more." -optional = false -python-versions = ">=3.8" -files = [ - {file = "urllib3-2.1.0-py3-none-any.whl", hash = "sha256:55901e917a5896a349ff771be919f8bd99aff50b79fe58fec595eb37bbc56bb3"}, - {file = "urllib3-2.1.0.tar.gz", hash = "sha256:df7aa8afb0148fa78488e7899b2c59b5f4ffcfa82e6c54ccb9dd37c1d7b52d54"}, -] - -[package.extras] -brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] -socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] -zstd = ["zstandard (>=0.18.0)"] - [[package]] name = "virtualenv" version = "20.24.6" @@ -1172,4 +1384,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.8.1" -content-hash = "07527ed55206b578874655df007367e3222ff7bc3dd5c4e1c8f058060b3b80e6" +content-hash = "931d1c29fdf09b71513021778e4094540010261ba4dc300768b3e4ae8a4b15ee" diff --git a/pyproject.toml b/pyproject.toml index cf9dd46..7b6402b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ packages = [{include="openai_pygenerator", from="src"}] [tool.poetry.dependencies] python = "^3.8.1" -openai = "^0.28.1" +openai = "^1.0.0" [tool.poetry.group.dev.dependencies] pytest = "^7.4.2" diff --git a/src/openai_pygenerator/openai_pygenerator.py b/src/openai_pygenerator/openai_pygenerator.py index 47b8992..7abc6fe 100644 --- a/src/openai_pygenerator/openai_pygenerator.py +++ b/src/openai_pygenerator/openai_pygenerator.py @@ -20,26 +20,22 @@ import logging import os -import time from enum import Enum, auto -from typing import Callable, Dict, Iterable, Iterator, List, NewType, Optional, TypeVar +from functools import lru_cache +from typing import Callable, Iterator, List, NewType, Optional, TypeVar -import openai from openai import OpenAI - -client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) -import urllib3.exceptions -from openai.error import ( - APIConnectionError, - APIError, - RateLimitError, - ServiceUnavailableError, +from openai.types.chat import ( + ChatCompletionAssistantMessageParam, + ChatCompletionMessage, + ChatCompletionMessageParam, + ChatCompletionUserMessageParam, ) -Completion = Dict[str, str] +Completion = ChatCompletionMessageParam Seconds = NewType("Seconds", int) Completions = Iterator[Completion] -History = Iterable[Completion] +History = List[Completion] Completer = Callable[[History, int], Completions] T = TypeVar("T") @@ -57,38 +53,46 @@ def var(name: str, to_type: Callable[[str], T], default: T) -> T: return to_type(result) +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") GPT_MODEL = var("GPT_MODEL", str, "gpt-3.5-turbo") GPT_TEMPERATURE = var("GPT_TEMPERATURE", float, 0.2) GPT_MAX_TOKENS = var("GPT_MAX_TOKENS", int, 500) GPT_MAX_RETRIES = var("GPT_MAX_RETRIES", int, 5) -GPT_RETRY_EXPONENT_SECONDS = Seconds(var("GPT_RETRY_EXPONENT_SECONDS", int, 2)) -GPT_RETRY_BASE_SECONDS = Seconds(var("GPT_RETRY_BASE_SECONDS", int, 20)) GPT_REQUEST_TIMEOUT_SECONDS = Seconds(var("GPT_REQUEST_TIMEOUT_SECONDS", int, 20)) logger = logging.getLogger(__name__) - def completer( model: str = GPT_MODEL, max_tokens: int = GPT_MAX_TOKENS, temperature: float = GPT_TEMPERATURE, max_retries: int = GPT_MAX_RETRIES, - retry_base: Seconds = GPT_RETRY_BASE_SECONDS, - retry_exponent: Seconds = GPT_RETRY_EXPONENT_SECONDS, + # retry_base: Seconds = GPT_RETRY_BASE_SECONDS, + # retry_exponent: Seconds = GPT_RETRY_EXPONENT_SECONDS, request_timeout: Seconds = GPT_REQUEST_TIMEOUT_SECONDS, ) -> Completer: + @lru_cache() + def get_client() -> OpenAI: + return OpenAI( + api_key=OPENAI_API_KEY, + timeout=request_timeout, + max_retries=max_retries, + ) + def f(messages: History, n: int = 1) -> Completions: return generate_completions( + get_client, messages, model, max_tokens, temperature, - max_retries, - retry_base, - retry_exponent, - request_timeout, - n, + n + # max_retries, + # retry_base, + # retry_exponent, + # request_timeout, + # n, ) return f @@ -98,59 +102,68 @@ def f(messages: History, n: int = 1) -> Completions: def generate_completions( + client: Callable[[], OpenAI], messages: History, model: str, max_tokens: int, temperature: float, - max_retries: int, - retry_base: Seconds, - retry_exponent: Seconds, - request_timeout: Seconds, + # max_retries: int, + # retry_base: Seconds, + # retry_exponent: Seconds, + # request_timeout: Seconds, n: int = 1, - retries: int = 0, + # retries: int = 0, ) -> Completions: logger.debug("messages = %s", messages) - try: - result = client.chat.completions.create(model=model, + # try: + result = client().chat.completions.create( + model=model, messages=messages, max_tokens=max_tokens, n=n, temperature=temperature, - request_timeout=request_timeout) - logger.debug("response = %s", result) - for choice in result.choices: # type: ignore - yield choice.message - except ( - openai.Timeout, # type: ignore - urllib3.exceptions.TimeoutError, - RateLimitError, - APIConnectionError, - APIError, - ServiceUnavailableError, - ) as err: - if isinstance(err, APIError) and not (err.http_status in [524, 502, 500]): - raise - logger.warning("Error returned from openai API: %s", err) - logger.debug("retries = %d", retries) - if retries < max_retries: - logger.info("Retrying... ") - time.sleep(retry_base + retry_exponent**retries) - for completion in generate_completions( - messages, - model, - max_tokens, - temperature, - max_retries, - retry_base, - retry_exponent, - request_timeout, - n, - retries + 1, - ): - yield completion - else: - logger.error("Maximum retries reached, aborting.") - raise + ) + # request_timeout=request_timeout) + logger.debug("response = %s", result) + for choice in result.choices: # type: ignore + yield to_message_param(choice.message) + # except ( + # openai.Timeout, # type: ignore + # urllib3.exceptions.TimeoutError, + # RateLimitError, + # APIConnectionError, + # APIError, + # ServiceUnavailableError, + # ) as err: + # if isinstance(err, APIError) and not (err.http_status in [524, 502, 500]): + # raise + # logger.warning("Error returned from openai API: %s", err) + # logger.debug("retries = %d", retries) + # if retries < max_retries: + # logger.info("Retrying... ") + # time.sleep(retry_base + retry_exponent**retries) + # for completion in generate_completions( + # messages, + # model, + # max_tokens, + # temperature, + # max_retries, + # retry_base, + # retry_exponent, + # request_timeout, + # n, + # retries + 1, + # ): + # yield completion + # else: + # logger.error("Maximum retries reached, aborting.") + # raise + + +def to_message_param(message: ChatCompletionMessage) -> Completion: + return ChatCompletionAssistantMessageParam( + {"role": message.role, "content": message.content} + ) def next_completion(completions: Completions) -> Optional[Completion]: @@ -161,11 +174,11 @@ def next_completion(completions: Completions) -> Optional[Completion]: def user_message(text: str) -> Completion: - return {"role": "user", "content": text} + return ChatCompletionUserMessageParam({"role": "user", "content": text}) def content(completion: Completion) -> str: - return completion["content"] + return str(completion["content"]) def role(completion: Completion) -> Role: diff --git a/tests/test_gpt.py b/tests/test_gpt.py index 2cb9c3d..e56871c 100644 --- a/tests/test_gpt.py +++ b/tests/test_gpt.py @@ -20,19 +20,16 @@ from typing import Iterable -import openai.error import pytest -import urllib3.exceptions as urlex -from openai.error import ( - APIConnectionError, - APIError, - RateLimitError, - ServiceUnavailableError, +from openai.types.chat import ( + ChatCompletionAssistantMessageParam, + ChatCompletionMessage, + ChatCompletionMessageParam, + ChatCompletionSystemMessageParam, + ChatCompletionUserMessageParam, ) -from openai.openai_object import OpenAIObject from openai_pygenerator import ( - GPT_MAX_RETRIES, ChatSession, Completer, Completion, @@ -40,16 +37,14 @@ History, Role, content, - gpt_completions, role, transcript, user_message, ) -def aio(text: str) -> OpenAIObject: - result = OpenAIObject() - result["message"] = text +def aio(text: str) -> ChatCompletionMessage: + result = ChatCompletionMessage(role="assistant", content=text) return result @@ -69,58 +64,6 @@ def mock_sleep(mocker): return mocker.patch("time.sleep", return_value=None) -def make_test_completion(role: str) -> Completion: - return {"role": role, "content": "testing"} - - -@pytest.mark.parametrize( - "error", - [ - RateLimitError("rate limited", http_status=429), - APIConnectionError("connection timeout"), - APIError("Gateway Timeout", http_status=524), - ServiceUnavailableError( - message=( - "openai.error.ServiceUnavailableError:" - " The server is overloaded or not ready yet" - ), - http_status=503, - ), - ], -) -def test_generate_completion(mock_openai, mock_sleep, error): - mock_openai.side_effect = [ - error, - error, - MockChoices(["Test completion 1", "Test completion 2"]), - ] - - completions = list(gpt_completions([])) # type: ignore - - assert completions == ["Test completion 1", "Test completion 2"] - assert mock_sleep.call_count == 2 - - -@pytest.mark.parametrize( - "error", - [ - RateLimitError("rate limited", http_status=429), - APIError("Gateway Timeout", http_status=524), - APIError("Server shutdown", http_status=500), - ServiceUnavailableError("Service unavailable"), - urlex.ReadTimeoutError("test-pool", "http://test", "read timeout"), # type: ignore - openai.Timeout, - ], -) -def test_generate_completion_error(mock_openai, mock_sleep, error): - mock_openai.side_effect = [error] * GPT_MAX_RETRIES - - with pytest.raises(Exception): - _ = list(gpt_completions([])) # type: ignore - - assert mock_sleep.call_count == GPT_MAX_RETRIES - - def test_user_message(): test_message = "test" result = user_message(test_message) @@ -133,7 +76,7 @@ def test_message(i: int) -> str: return f"message{i}" test_messages = [user_message(test_message(i)) for i in range(10)] - result = list(transcript(iter(test_messages))) + result = list(transcript(test_messages)) for i in range(10): assert result[i] == test_message(i) @@ -159,16 +102,38 @@ def mock_complete(_history: History, _n: int) -> Completions: ] -@pytest.mark.parametrize("role", ["user", "system", "assistant"]) -def test_content(role: str): - completion = make_test_completion(role) - assert content(completion) == "testing" +@pytest.mark.parametrize( + "message", + [ + ChatCompletionAssistantMessageParam( + {"role": "assistant", "content": "testing"} + ), + ChatCompletionUserMessageParam({"role": "user", "content": "testing"}), + ChatCompletionSystemMessageParam({"role": "system", "content": "testing"}), + ], +) +def test_content(message: ChatCompletionMessageParam): + assert content(message) == "testing" @pytest.mark.parametrize( - "test_role_str, expected", - [("user", Role.USER), ("system", Role.SYSTEM), ("assistant", Role.ASSISTANT)], + "completion, expected", + [ + ( + ChatCompletionAssistantMessageParam( + {"role": "assistant", "content": "testing"} + ), + Role.ASSISTANT, + ), + ( + ChatCompletionUserMessageParam({"role": "user", "content": "testing"}), + Role.USER, + ), + ( + ChatCompletionSystemMessageParam({"role": "system", "content": "testing"}), + Role.SYSTEM, + ), + ], ) -def test_role(test_role_str: str, expected: Role): - completion = make_test_completion(test_role_str) +def test_role(completion: Completion, expected: Role): assert role(completion) == expected