diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 000000000..c73c6149f --- /dev/null +++ b/examples/README.md @@ -0,0 +1,11 @@ +## Running an example + +From the examples folder, run: +`PYTHONPATH=../ python your_example.py` + +## Adding a new example + +1. Clone new_example.py +2. Implement your example +3. Run it (as per above) +4. 👍 diff --git a/examples/meter_event_stream.py b/examples/meter_event_stream.py new file mode 100644 index 000000000..0d02bb9d8 --- /dev/null +++ b/examples/meter_event_stream.py @@ -0,0 +1,44 @@ +from datetime import datetime, timezone +import stripe + +# Global variable for the meter event session +meter_event_session = None + + +def refresh_meter_event_session(api_key): + global meter_event_session + + # Check if the session is None or expired + if meter_event_session is None or datetime.fromisoformat( + meter_event_session["expires_at"] + ) <= datetime.now(timezone.utc): + # Create a new meter event session if the existing session has expired + client = stripe.StripeClient(api_key) + meter_event_session = client.v2.billing.meter_event_session.create() + + +def send_meter_event(meter_event, api_key): + # Refresh the meter event session if necessary + refresh_meter_event_session(api_key) + if not meter_event_session: + raise RuntimeError("Unable to refresh meter event session") + + # Create a meter event with the current session's authentication token + client = stripe.StripeClient(meter_event_session["authentication_token"]) + client.v2.billing.meter_event_stream.create( + params={"events": [meter_event]} + ) + + +# Set your API key here +api_key = "{{API_KEY}}" +customer_id = "{{CUSTOMER_ID}}" + +# Send meter event +send_meter_event( + { + "event_name": "alpaca_ai_tokens", + "payload": {"stripe_customer_id": customer_id, "value": "25"}, + }, + api_key, +) diff --git a/examples/new_example.py b/examples/new_example.py new file mode 100644 index 000000000..53a93e7e6 --- /dev/null +++ b/examples/new_example.py @@ -0,0 +1,8 @@ +import stripe + +# Set your API key here +api_key = "{{API_KEY}}" + +print("Hello world") +# client = stripe.StripeClient(api_key) +# client.v2.... diff --git a/examples/stripe_webhook_handler.py b/examples/stripe_webhook_handler.py new file mode 100644 index 000000000..492fa0485 --- /dev/null +++ b/examples/stripe_webhook_handler.py @@ -0,0 +1,40 @@ +import os +from stripe import StripeClient +from stripe.events import V1BillingMeterErrorReportTriggeredEvent + +from flask import Flask, request, jsonify + +app = Flask(__name__) +api_key = os.environ.get("STRIPE_API_KEY") +webhook_secret = os.environ.get("WEBHOOK_SECRET") + +client = StripeClient(api_key) + + +@app.route("/webhook", methods=["POST"]) +def webhook(): + webhook_body = request.data + sig_header = request.headers.get("Stripe-Signature") + + try: + thin_event = client.parse_thin_event( + webhook_body, sig_header, webhook_secret + ) + + # Fetch the event data to understand the failure + event = client.v2.core.events.retrieve(thin_event.id) + if isinstance(event, V1BillingMeterErrorReportTriggeredEvent): + # CHECK: fetch_object is present and callable, returning a strongly-typed object (without casting) + meter = event.fetch_related_object() + meter_id = meter.id + + # Record the failures and alert your team + # Add your logic here + + return jsonify(success=True), 200 + except Exception as e: + return jsonify(error=str(e)), 400 + + +if __name__ == "__main__": + app.run(port=4242) diff --git a/stripe/__init__.py b/stripe/__init__.py index 90e300584..0300781a1 100644 --- a/stripe/__init__.py +++ b/stripe/__init__.py @@ -2,6 +2,7 @@ from typing import Optional import sys as _sys import os +import warnings # Stripe Python bindings # API docs at http://stripe.com/docs/api @@ -25,6 +26,7 @@ DEFAULT_API_BASE: str = "https://api.stripe.com" DEFAULT_CONNECT_API_BASE: str = "https://connect.stripe.com" DEFAULT_UPLOAD_API_BASE: str = "https://files.stripe.com" +DEFAULT_METER_EVENTS_API_BASE: str = "https://meter-events.stripe.com" api_key: Optional[str] = None @@ -32,22 +34,62 @@ api_base: str = DEFAULT_API_BASE connect_api_base: str = DEFAULT_CONNECT_API_BASE upload_api_base: str = DEFAULT_UPLOAD_API_BASE +meter_events_api_base: str = DEFAULT_METER_EVENTS_API_BASE api_version: str = _ApiVersion.CURRENT verify_ssl_certs: bool = True proxy: Optional[str] = None default_http_client: Optional["HTTPClient"] = None app_info: Optional[AppInfo] = None enable_telemetry: bool = True -max_network_retries: int = 0 +max_network_retries: int = 2 ca_bundle_path: str = os.path.join( os.path.dirname(__file__), "data", "ca-certificates.crt" ) +# Lazily initialized stripe.default_http_client +default_http_client = None +_default_proxy = None + + +def ensure_default_http_client(): + if default_http_client: + _warn_if_mismatched_proxy() + return + _init_default_http_client() + + +def _init_default_http_client(): + global _default_proxy + global default_http_client + + # If the stripe.default_http_client has not been set by the user + # yet, we'll set it here. This way, we aren't creating a new + # HttpClient for every request. + default_http_client = new_default_http_client( + verify_ssl_certs=verify_ssl_certs, proxy=proxy + ) + _default_proxy = proxy + + +def _warn_if_mismatched_proxy(): + global _default_proxy + from stripe import proxy + + if proxy != _default_proxy: + warnings.warn( + "stripe.proxy was updated after sending a " + "request - this is a no-op. To use a different proxy, " + "set stripe.default_http_client to a new client " + "configured with the proxy." + ) + + # Set to either 'debug' or 'info', controls console logging log: Optional[Literal["debug", "info"]] = None # OAuth from stripe._oauth import OAuth as OAuth +from stripe._oauth_service import OAuthService as OAuthService # Webhooks from stripe._webhook import ( @@ -58,6 +100,8 @@ # StripeClient from stripe._stripe_client import StripeClient as StripeClient # noqa +from stripe.v2._event import ThinEvent as ThinEvent # noqa + # Sets some basic information about the running application that's sent along # with API requests. Useful for plugin authors to identify their plugin when @@ -180,8 +224,6 @@ def set_app_info( from stripe import _request_metrics as request_metrics from stripe._file import File as FileUpload - import warnings - # Python 3.7+ supports module level __getattr__ that allows us to lazy load deprecated modules # this matters because if we pre-load all modules from api_resources while suppressing warning # users will never see those warnings @@ -218,6 +260,7 @@ def __getattr__(name): checkout as checkout, climate as climate, entitlements as entitlements, + events as events, financial_connections as financial_connections, forwarding as forwarding, identity as identity, @@ -229,6 +272,7 @@ def __getattr__(name): terminal as terminal, test_helpers as test_helpers, treasury as treasury, + v2 as v2, ) from stripe._account import Account as Account from stripe._account_capability_service import ( @@ -355,6 +399,9 @@ def __getattr__(name): from stripe._ephemeral_key_service import ( EphemeralKeyService as EphemeralKeyService, ) +from stripe._error import ( + TemporarySessionExpiredError as TemporarySessionExpiredError, +) from stripe._event import Event as Event from stripe._event_service import EventService as EventService from stripe._exchange_rate import ExchangeRate as ExchangeRate @@ -530,6 +577,7 @@ def __getattr__(name): from stripe._usage_record_summary import ( UsageRecordSummary as UsageRecordSummary, ) +from stripe._v2_services import V2Services as V2Services from stripe._webhook_endpoint import WebhookEndpoint as WebhookEndpoint from stripe._webhook_endpoint_service import ( WebhookEndpointService as WebhookEndpointService, diff --git a/stripe/_api_mode.py b/stripe/_api_mode.py index 1503e56e0..7290784d3 100644 --- a/stripe/_api_mode.py +++ b/stripe/_api_mode.py @@ -1,4 +1,4 @@ from typing_extensions import Literal -ApiMode = Literal["V1"] +ApiMode = Literal["V1", "V2"] diff --git a/stripe/_api_requestor.py b/stripe/_api_requestor.py index 61f3085eb..65bb449fe 100644 --- a/stripe/_api_requestor.py +++ b/stripe/_api_requestor.py @@ -29,15 +29,14 @@ log_info, dashboard_link, _convert_to_stripe_object, + get_api_mode, ) from stripe._version import VERSION import stripe._error as error import stripe.oauth_error as oauth_error from stripe._multipart_data_generator import MultipartDataGenerator from urllib.parse import urlencode -from stripe._encode import ( - _api_encode, -) +from stripe._encode import _api_encode, _json_encode_date_callback from stripe._stripe_response import ( StripeResponse, StripeStreamResponse, @@ -183,8 +182,8 @@ def request( base_address: BaseAddress, usage: Optional[List[str]] = None, ) -> "StripeObject": + api_mode = get_api_mode(url) requestor = self._replace_options(options) - api_mode = "V1" rbody, rcode, rheaders = requestor.request_raw( method.lower(), url, @@ -195,15 +194,17 @@ def request( options=options, usage=usage, ) - resp = requestor._interpret_response(rbody, rcode, rheaders) + resp = requestor._interpret_response(rbody, rcode, rheaders, api_mode) - return _convert_to_stripe_object( + obj = _convert_to_stripe_object( resp=resp, params=params, requestor=requestor, api_mode=api_mode, ) + return obj + async def request_async( self, method: str, @@ -214,7 +215,7 @@ async def request_async( base_address: BaseAddress, usage: Optional[List[str]] = None, ) -> "StripeObject": - api_mode = "V1" + api_mode = get_api_mode(url) requestor = self._replace_options(options) rbody, rcode, rheaders = await requestor.request_raw_async( method.lower(), @@ -226,15 +227,17 @@ async def request_async( options=options, usage=usage, ) - resp = requestor._interpret_response(rbody, rcode, rheaders) + resp = requestor._interpret_response(rbody, rcode, rheaders, api_mode) - return _convert_to_stripe_object( + obj = _convert_to_stripe_object( resp=resp, params=params, requestor=requestor, api_mode=api_mode, ) + return obj + def request_stream( self, method: str, @@ -245,7 +248,7 @@ def request_stream( base_address: BaseAddress, usage: Optional[List[str]] = None, ) -> StripeStreamResponse: - api_mode = "V1" + api_mode = get_api_mode(url) stream, rcode, rheaders = self.request_raw( method.lower(), url, @@ -262,6 +265,7 @@ def request_stream( cast(IOBase, stream), rcode, rheaders, + api_mode, ) return resp @@ -275,7 +279,7 @@ async def request_stream_async( base_address: BaseAddress, usage: Optional[List[str]] = None, ) -> StripeStreamResponseAsync: - api_mode = "V1" + api_mode = get_api_mode(url) stream, rcode, rheaders = await self.request_raw_async( method.lower(), url, @@ -290,10 +294,13 @@ async def request_stream_async( stream, rcode, rheaders, + api_mode, ) return resp - def handle_error_response(self, rbody, rcode, resp, rheaders) -> NoReturn: + def handle_error_response( + self, rbody, rcode, resp, rheaders, api_mode + ) -> NoReturn: try: error_data = resp["error"] except (KeyError, TypeError): @@ -316,15 +323,60 @@ def handle_error_response(self, rbody, rcode, resp, rheaders) -> NoReturn: ) if err is None: - err = self.specific_api_error( - rbody, rcode, resp, rheaders, error_data + err = ( + self.specific_v2_api_error( + rbody, rcode, resp, rheaders, error_data + ) + if api_mode == "V2" + else self.specific_v1_api_error( + rbody, rcode, resp, rheaders, error_data + ) ) raise err - def specific_api_error(self, rbody, rcode, resp, rheaders, error_data): + def specific_v2_api_error(self, rbody, rcode, resp, rheaders, error_data): + type = error_data.get("type") + code = error_data.get("code") + message = error_data.get("message") + error_args = { + "message": message, + "http_body": rbody, + "http_status": rcode, + "json_body": resp, + "headers": rheaders, + "code": code, + } + log_info( - "Stripe API error received", + "Stripe v2 API error received", + error_code=code, + error_type=error_data.get("type"), + error_message=message, + error_param=error_data.get("param"), + ) + + if type == "idempotency_error": + return error.IdempotencyError( + message, + rbody, + rcode, + resp, + rheaders, + code, + ) + # switchCases: The beginning of the section generated from our OpenAPI spec + elif type == "temporary_session_expired": + return error.TemporarySessionExpiredError(**error_args) + # switchCases: The end of the section generated from our OpenAPI spec + + return self.specific_v1_api_error( + rbody, rcode, resp, rheaders, error_data + ) + + def specific_v1_api_error(self, rbody, rcode, resp, rheaders, error_data): + log_info( + "Stripe v1 API error received", error_code=error_data.get("code"), error_type=error_data.get("type"), error_message=error_data.get("message"), @@ -402,8 +454,13 @@ def specific_oauth_error(self, rbody, rcode, resp, rheaders, error_code): return None - def request_headers(self, method, options: RequestOptions): - user_agent = "Stripe/v1 PythonBindings/%s" % (VERSION,) + def request_headers( + self, method: HttpVerb, api_mode: ApiMode, options: RequestOptions + ): + user_agent = "Stripe/%s PythonBindings/%s" % ( + api_mode.lower(), + VERSION, + ) if stripe.app_info: user_agent += " " + self._format_app_info(stripe.app_info) @@ -436,13 +493,23 @@ def request_headers(self, method, options: RequestOptions): if stripe_account: headers["Stripe-Account"] = stripe_account + stripe_context = options.get("stripe_context") + if stripe_context: + headers["Stripe-Context"] = stripe_context + idempotency_key = options.get("idempotency_key") if idempotency_key: headers["Idempotency-Key"] = idempotency_key - if method == "post": + # IKs should be set for all POST requests and v2 delete requests + if method == "post" or (api_mode == "V2" and method == "delete"): headers.setdefault("Idempotency-Key", str(uuid.uuid4())) - headers["Content-Type"] = "application/x-www-form-urlencoded" + + if method == "post": + if api_mode == "V2": + headers["Content-Type"] = "application/json" + else: + headers["Content-Type"] = "application/x-www-form-urlencoded" stripe_version = options.get("stripe_version") if stripe_version: @@ -462,10 +529,19 @@ def _args_for_request_with_retries( usage: Optional[List[str]] = None, ): """ - Mechanism for issuing an API call + Mechanism for issuing an API call. Used by request_raw and request_raw_async. """ request_options = merge_options(self._options, options) + # Special stripe_version handling for v2 requests: + if ( + options + and "stripe_version" in options + and (options["stripe_version"] is not None) + ): + # If user specified an API version, honor it + request_options["stripe_version"] = options["stripe_version"] + if request_options.get("api_key") is None: raise error.AuthenticationError( "No API key provided. (HINT: set your API key using " @@ -480,14 +556,19 @@ def _args_for_request_with_retries( url, ) - encoded_params = urlencode(list(_api_encode(params or {}))) + encoded_params = urlencode(list(_api_encode(params or {}, api_mode))) # Don't use strict form encoding by changing the square bracket control # characters back to their literals. This is fine by the server, and # makes these parameter strings easier to read. encoded_params = encoded_params.replace("%5B", "[").replace("%5D", "]") - encoded_body = encoded_params + if api_mode == "V2": + encoded_body = json.dumps( + params or {}, default=_json_encode_date_callback + ) + else: + encoded_body = encoded_params supplied_headers = None if ( @@ -496,7 +577,12 @@ def _args_for_request_with_retries( ): supplied_headers = dict(request_options["headers"]) - headers = self.request_headers(method, request_options) + headers = self.request_headers( + # this cast is safe because the blocks below validate that `method` is one of the allowed values + cast(HttpVerb, method), + api_mode, + request_options, + ) if method == "get" or method == "delete": if params: @@ -714,6 +800,7 @@ def _interpret_response( rbody: object, rcode: int, rheaders: Mapping[str, str], + api_mode: ApiMode, ) -> StripeResponse: try: if hasattr(rbody, "decode"): @@ -734,30 +821,17 @@ def _interpret_response( rheaders, ) if self._should_handle_code_as_error(rcode): - self.handle_error_response(rbody, rcode, resp.data, rheaders) - return resp - - async def _interpret_streaming_response_async( - self, - stream: AsyncIterable[bytes], - rcode: int, - rheaders: Mapping[str, str], - ) -> StripeStreamResponseAsync: - if self._should_handle_code_as_error(rcode): - json_content = b"".join([chunk async for chunk in stream]) - self._interpret_response(json_content, rcode, rheaders) - # _interpret_response is guaranteed to throw since we've checked self._should_handle_code_as_error - raise RuntimeError( - "_interpret_response should have raised an error" + self.handle_error_response( + rbody, rcode, resp.data, rheaders, api_mode ) - else: - return StripeStreamResponseAsync(stream, rcode, rheaders) + return resp def _interpret_streaming_response( self, stream: IOBase, rcode: int, rheaders: Mapping[str, str], + api_mode: ApiMode, ) -> StripeStreamResponse: # Streaming response are handled with minimal processing for the success # case (ie. we don't want to read the content). When an error is @@ -775,10 +849,27 @@ def _interpret_streaming_response( % self._get_http_client().name ) - self._interpret_response(json_content, rcode, rheaders) + self._interpret_response(json_content, rcode, rheaders, api_mode) # _interpret_response is guaranteed to throw since we've checked self._should_handle_code_as_error raise RuntimeError( "_interpret_response should have raised an error" ) else: return StripeStreamResponse(stream, rcode, rheaders) + + async def _interpret_streaming_response_async( + self, + stream: AsyncIterable[bytes], + rcode: int, + rheaders: Mapping[str, str], + api_mode: ApiMode, + ) -> StripeStreamResponseAsync: + if self._should_handle_code_as_error(rcode): + json_content = b"".join([chunk async for chunk in stream]) + self._interpret_response(json_content, rcode, rheaders, api_mode) + # _interpret_response is guaranteed to throw since we've checked self._should_handle_code_as_error + raise RuntimeError( + "_interpret_response should have raised an error" + ) + else: + return StripeStreamResponseAsync(stream, rcode, rheaders) diff --git a/stripe/_api_resource.py b/stripe/_api_resource.py index 1fb402ab8..2866b42c1 100644 --- a/stripe/_api_resource.py +++ b/stripe/_api_resource.py @@ -99,6 +99,7 @@ async def _request_async( params=None, *, base_address: BaseAddress = "api", + api_mode: ApiMode = "V1", ) -> StripeObject: obj = await StripeObject._request_async( self, @@ -109,7 +110,7 @@ async def _request_async( ) if type(self) is type(obj): - self._refresh_from(values=obj, api_mode="V1") + self._refresh_from(values=obj, api_mode=api_mode) return self else: return obj diff --git a/stripe/_base_address.py b/stripe/_base_address.py index aa7a133e7..b45e6eda5 100644 --- a/stripe/_base_address.py +++ b/stripe/_base_address.py @@ -2,10 +2,11 @@ from typing_extensions import NotRequired, TypedDict, Literal -BaseAddress = Literal["api", "files", "connect"] +BaseAddress = Literal["api", "files", "connect", "meter_events"] class BaseAddresses(TypedDict): api: NotRequired[Optional[str]] connect: NotRequired[Optional[str]] files: NotRequired[Optional[str]] + meter_events: NotRequired[Optional[str]] diff --git a/stripe/_encode.py b/stripe/_encode.py index 9552a739e..181038ef0 100644 --- a/stripe/_encode.py +++ b/stripe/_encode.py @@ -2,7 +2,7 @@ import datetime import time from collections import OrderedDict -from typing import Generator, Tuple, Any +from typing import Generator, Optional, Tuple, Any def _encode_datetime(dttime: datetime.datetime): @@ -21,7 +21,15 @@ def _encode_nested_dict(key, data, fmt="%s[%s]"): return d -def _api_encode(data) -> Generator[Tuple[str, Any], None, None]: +def _json_encode_date_callback(value): + if isinstance(value, datetime.datetime): + return _encode_datetime(value) + return value + + +def _api_encode( + data, api_mode: Optional[str] +) -> Generator[Tuple[str, Any], None, None]: for key, value in data.items(): if value is None: continue @@ -29,15 +37,16 @@ def _api_encode(data) -> Generator[Tuple[str, Any], None, None]: yield (key, value.stripe_id) elif isinstance(value, list) or isinstance(value, tuple): for i, sv in enumerate(value): + encoded_key = key if api_mode == "V2" else "%s[%d]" % (key, i) if isinstance(sv, dict): - subdict = _encode_nested_dict("%s[%d]" % (key, i), sv) - for k, v in _api_encode(subdict): + subdict = _encode_nested_dict(encoded_key, sv) + for k, v in _api_encode(subdict, api_mode): yield (k, v) else: - yield ("%s[%d]" % (key, i), sv) + yield (encoded_key, sv) elif isinstance(value, dict): subdict = _encode_nested_dict(key, value) - for subkey, subvalue in _api_encode(subdict): + for subkey, subvalue in _api_encode(subdict, api_mode): yield (subkey, subvalue) elif isinstance(value, datetime.datetime): yield (key, _encode_datetime(value)) diff --git a/stripe/_error.py b/stripe/_error.py index aba72701d..3b486ee79 100644 --- a/stripe/_error.py +++ b/stripe/_error.py @@ -1,7 +1,6 @@ from typing import Dict, Optional, Union, cast -# Used for global variable -import stripe # noqa: IMP101 +import stripe # noqa from stripe._error_object import ErrorObject @@ -13,7 +12,7 @@ class StripeError(Exception): headers: Optional[Dict[str, str]] code: Optional[str] request_id: Optional[str] - error: Optional[ErrorObject] + error: Optional["ErrorObject"] def __init__( self, @@ -76,10 +75,13 @@ def _construct_error_object(self) -> Optional[ErrorObject]: or not isinstance(self.json_body["error"], dict) ): return None + from stripe._error_object import ErrorObject return ErrorObject._construct_from( values=self.json_body["error"], requestor=stripe._APIRequestor._global_instance(), + # We pass in API mode as "V1" here because it's required, + # but ErrorObject is reused for both V1 and V2 errors. api_mode="V1", ) @@ -177,3 +179,11 @@ class SignatureVerificationError(StripeError): def __init__(self, message, sig_header, http_body=None): super(SignatureVerificationError, self).__init__(message, http_body) self.sig_header = sig_header + + +# classDefinitions: The beginning of the section generated from our OpenAPI spec +class TemporarySessionExpiredError(StripeError): + pass + + +# classDefinitions: The end of the section generated from our OpenAPI spec diff --git a/stripe/_http_client.py b/stripe/_http_client.py index 459b40ae5..0db2ef3f5 100644 --- a/stripe/_http_client.py +++ b/stripe/_http_client.py @@ -148,7 +148,7 @@ class _Proxy(TypedDict): http: Optional[str] https: Optional[str] - MAX_DELAY = 2 + MAX_DELAY = 5 INITIAL_DELAY = 0.5 MAX_RETRY_AFTER = 60 _proxy: Optional[_Proxy] @@ -242,10 +242,11 @@ def _sleep_time_seconds( self, num_retries: int, response: Optional[Tuple[Any, Any, Mapping[str, str]]] = None, - ): - # Apply exponential backoff with initial_network_retry_delay on the - # number of num_retries so far as inputs. - # Do not allow the number to exceed max_network_retry_delay. + ) -> float: + """ + Apply exponential backoff with initial_network_retry_delay on the number of num_retries so far as inputs. + Do not allow the number to exceed `max_network_retry_delay`. + """ sleep_seconds = min( HTTPClient.INITIAL_DELAY * (2 ** (num_retries - 1)), HTTPClient.MAX_DELAY, @@ -263,9 +264,11 @@ def _sleep_time_seconds( return sleep_seconds - def _add_jitter_time(self, sleep_seconds: float): - # Randomize the value in [(sleep_seconds/ 2) to (sleep_seconds)] - # Also separated method here to isolate randomness for tests + def _add_jitter_time(self, sleep_seconds: float) -> float: + """ + Randomize the value in `[(sleep_seconds/ 2) to (sleep_seconds)]`. + Also separated method here to isolate randomness for tests + """ sleep_seconds *= 0.5 * (1 + random.uniform(0, 1)) return sleep_seconds @@ -900,6 +903,11 @@ def close(self): pass +class _Proxy(TypedDict): + http: Optional[ParseResult] + https: Optional[ParseResult] + + class PycurlClient(HTTPClient): class _ParsedProxy(TypedDict, total=False): http: Optional[ParseResult] @@ -1025,7 +1033,7 @@ def _request_internal( self._curl.setopt(self.pycurl.TIMEOUT, 80) self._curl.setopt( self.pycurl.HTTPHEADER, - ["%s: %s" % (k, v) for k, v in iter(dict(headers).items())], + ["%s: %s" % (k, v) for k, v in dict(headers).items()], ) if self._verify_ssl_certs: self._curl.setopt(self.pycurl.CAINFO, stripe.ca_bundle_path) diff --git a/stripe/_multipart_data_generator.py b/stripe/_multipart_data_generator.py index 3151df83e..1e0be2ba1 100644 --- a/stripe/_multipart_data_generator.py +++ b/stripe/_multipart_data_generator.py @@ -19,7 +19,7 @@ def __init__(self, chunk_size: int = 1028): def add_params(self, params): # Flatten parameters first - params = dict(_api_encode(params)) + params = dict(_api_encode(params, "V1")) for key, value in params.items(): if value is None: diff --git a/stripe/_oauth.py b/stripe/_oauth.py index c2bd478b0..9940bd3ed 100644 --- a/stripe/_oauth.py +++ b/stripe/_oauth.py @@ -315,7 +315,7 @@ def authorize_url( OAuth._set_client_id(params) if "response_type" not in params: params["response_type"] = "code" - query = urlencode(list(_api_encode(params))) + query = urlencode(list(_api_encode(params, "V1"))) url = connect_api_base + path + "?" + query return url diff --git a/stripe/_oauth_service.py b/stripe/_oauth_service.py index 7c24269f5..2a1edc83c 100644 --- a/stripe/_oauth_service.py +++ b/stripe/_oauth_service.py @@ -57,7 +57,7 @@ def authorize_url( self._set_client_id(params) if "response_type" not in params: params["response_type"] = "code" - query = urlencode(list(_api_encode(params))) + query = urlencode(list(_api_encode(params, "V1"))) # connect_api_base will be always set to stripe.DEFAULT_CONNECT_API_BASE # if it is not overridden on the client explicitly. diff --git a/stripe/_object_classes.py b/stripe/_object_classes.py index 027e50d72..fb0a0ba5b 100644 --- a/stripe/_object_classes.py +++ b/stripe/_object_classes.py @@ -148,3 +148,12 @@ stripe.WebhookEndpoint.OBJECT_NAME: stripe.WebhookEndpoint, # Object classes: The end of the section generated from our OpenAPI spec } + +V2_OBJECT_CLASSES = { + # V2 Object classes: The beginning of the section generated from our OpenAPI spec + stripe.v2.billing.MeterEvent.OBJECT_NAME: stripe.v2.billing.MeterEvent, + stripe.v2.billing.MeterEventAdjustment.OBJECT_NAME: stripe.v2.billing.MeterEventAdjustment, + stripe.v2.billing.MeterEventSession.OBJECT_NAME: stripe.v2.billing.MeterEventSession, + stripe.v2.Event.OBJECT_NAME: stripe.v2.Event, + # V2 Object classes: The end of the section generated from our OpenAPI spec +} diff --git a/stripe/_request_options.py b/stripe/_request_options.py index caa26fa50..e97cf1e1c 100644 --- a/stripe/_request_options.py +++ b/stripe/_request_options.py @@ -7,6 +7,7 @@ class RequestOptions(TypedDict): api_key: NotRequired["str|None"] stripe_version: NotRequired["str|None"] stripe_account: NotRequired["str|None"] + stripe_context: NotRequired["str|None"] max_network_retries: NotRequired["int|None"] idempotency_key: NotRequired["str|None"] content_type: NotRequired["str|None"] @@ -25,6 +26,7 @@ def merge_options( return { "api_key": requestor.api_key, "stripe_account": requestor.stripe_account, + "stripe_context": requestor.stripe_context, "stripe_version": requestor.stripe_version, "max_network_retries": requestor.max_network_retries, "idempotency_key": None, @@ -36,6 +38,8 @@ def merge_options( "api_key": request.get("api_key") or requestor.api_key, "stripe_account": request.get("stripe_account") or requestor.stripe_account, + "stripe_context": request.get("stripe_context") + or requestor.stripe_context, "stripe_version": request.get("stripe_version") or requestor.stripe_version, "max_network_retries": request.get("max_network_retries") @@ -62,6 +66,7 @@ def extract_options_from_dict( "api_key", "stripe_version", "stripe_account", + "stripe_context", "max_network_retries", "idempotency_key", "content_type", diff --git a/stripe/_requestor_options.py b/stripe/_requestor_options.py index 1314a8a71..6f8ebb328 100644 --- a/stripe/_requestor_options.py +++ b/stripe/_requestor_options.py @@ -8,6 +8,7 @@ class RequestorOptions(object): api_key: Optional[str] stripe_account: Optional[str] + stripe_context: Optional[str] stripe_version: Optional[str] base_addresses: BaseAddresses max_network_retries: Optional[int] @@ -16,12 +17,14 @@ def __init__( self, api_key: Optional[str] = None, stripe_account: Optional[str] = None, + stripe_context: Optional[str] = None, stripe_version: Optional[str] = None, base_addresses: BaseAddresses = {}, max_network_retries: Optional[int] = None, ): self.api_key = api_key self.stripe_account = stripe_account + self.stripe_context = stripe_context self.stripe_version = stripe_version self.base_addresses = {} @@ -33,6 +36,10 @@ def __init__( self.base_addresses["connect"] = base_addresses.get("connect") if base_addresses.get("files") is not None: self.base_addresses["files"] = base_addresses.get("files") + if base_addresses.get("meter_events") is not None: + self.base_addresses["meter_events"] = base_addresses.get( + "meter_events" + ) self.max_network_retries = max_network_retries @@ -43,6 +50,7 @@ def to_dict(self): return { "api_key": self.api_key, "stripe_account": self.stripe_account, + "stripe_context": self.stripe_context, "stripe_version": self.stripe_version, "base_addresses": self.base_addresses, "max_network_retries": self.max_network_retries, @@ -59,6 +67,7 @@ def base_addresses(self): "api": stripe.api_base, "connect": stripe.connect_api_base, "files": stripe.upload_api_base, + "meter_events": stripe.meter_events_api_base, } @property @@ -73,6 +82,10 @@ def stripe_version(self): def stripe_account(self): return None + @property + def stripe_context(self): + return None + @property def max_network_retries(self): return stripe.max_network_retries diff --git a/stripe/_stripe_client.py b/stripe/_stripe_client.py index 8f52c2da8..661c018b1 100644 --- a/stripe/_stripe_client.py +++ b/stripe/_stripe_client.py @@ -7,10 +7,13 @@ DEFAULT_API_BASE, DEFAULT_CONNECT_API_BASE, DEFAULT_UPLOAD_API_BASE, + DEFAULT_METER_EVENTS_API_BASE, ) +from stripe._api_mode import ApiMode from stripe._error import AuthenticationError from stripe._api_requestor import _APIRequestor +from stripe._request_options import extract_options_from_dict from stripe._requestor_options import RequestorOptions, BaseAddresses from stripe._client_options import _ClientOptions from stripe._http_client import ( @@ -19,10 +22,14 @@ new_http_client_async_fallback, ) from stripe._api_version import _ApiVersion +from stripe._stripe_object import StripeObject +from stripe._stripe_response import StripeResponse +from stripe._util import _convert_to_stripe_object, get_api_mode from stripe._webhook import Webhook, WebhookSignature from stripe._event import Event +from stripe.v2._event import ThinEvent -from typing import Optional, Union, cast +from typing import Any, Dict, Optional, Union, cast # Non-generated services from stripe._oauth_service import OAuthService @@ -100,6 +107,7 @@ from stripe._transfer_service import TransferService from stripe._treasury_service import TreasuryService from stripe._webhook_endpoint_service import WebhookEndpointService +from stripe._v2_services import V2Services # services: The end of the section generated from our OpenAPI spec @@ -109,6 +117,7 @@ def __init__( api_key: str, *, stripe_account: Optional[str] = None, + stripe_context: Optional[str] = None, stripe_version: Optional[str] = None, base_addresses: BaseAddresses = {}, client_id: Optional[str] = None, @@ -140,12 +149,14 @@ def __init__( "api": DEFAULT_API_BASE, "connect": DEFAULT_CONNECT_API_BASE, "files": DEFAULT_UPLOAD_API_BASE, + "meter_events": DEFAULT_METER_EVENTS_API_BASE, **base_addresses, } requestor_options = RequestorOptions( api_key=api_key, stripe_account=stripe_account, + stripe_context=stripe_context, stripe_version=stripe_version or _ApiVersion.CURRENT, base_addresses=base_addresses, max_network_retries=max_network_retries, @@ -252,9 +263,27 @@ def __init__( self.transfers = TransferService(self._requestor) self.treasury = TreasuryService(self._requestor) self.webhook_endpoints = WebhookEndpointService(self._requestor) + self.v2 = V2Services(self._requestor) # top-level services: The end of the section generated from our OpenAPI spec - def construct_event( + def parse_thin_event( + self, + raw: Union[bytes, str, bytearray], + sig_header: str, + secret: str, + tolerance: int = Webhook.DEFAULT_TOLERANCE, + ) -> ThinEvent: + payload = ( + cast(Union[bytes, bytearray], raw).decode("utf-8") + if hasattr(raw, "decode") + else cast(str, raw) + ) + + WebhookSignature.verify_header(payload, sig_header, secret, tolerance) + + return ThinEvent(payload) + + def parse_snapshot_event( self, payload: Union[bytes, str], sig_header: str, @@ -274,3 +303,68 @@ def construct_event( ) return event + + def raw_request(self, method_: str, url_: str, **params): + params = params.copy() + options, params = extract_options_from_dict(params) + api_mode = get_api_mode(url_) + base_address = params.pop("base", "api") + + stripe_context = params.pop("stripe_context", None) + + # stripe-context goes *here* and not in api_requestor. Properties + # go on api_requestor when you want them to persist onto requests + # made when you call instance methods on APIResources that come from + # the first request. No need for that here, as we aren't deserializing APIResources + if stripe_context is not None: + options["headers"] = options.get("headers", {}) + assert isinstance(options["headers"], dict) + options["headers"].update({"Stripe-Context": stripe_context}) + + rbody, rcode, rheaders = self._requestor.request_raw( + method_, + url_, + params=params, + options=options, + base_address=base_address, + api_mode=api_mode, + usage=["raw_request"], + ) + + return self._requestor._interpret_response( + rbody, rcode, rheaders, api_mode + ) + + async def raw_request_async(self, method_: str, url_: str, **params): + params = params.copy() + options, params = extract_options_from_dict(params) + api_mode = get_api_mode(url_) + base_address = params.pop("base", "api") + + rbody, rcode, rheaders = await self._requestor.request_raw_async( + method_, + url_, + params=params, + options=options, + base_address=base_address, + api_mode=api_mode, + usage=["raw_request"], + ) + + return self._requestor._interpret_response( + rbody, rcode, rheaders, api_mode + ) + + def deserialize( + self, + resp: Union[StripeResponse, Dict[str, Any]], + params: Optional[Dict[str, Any]] = None, + *, + api_mode: ApiMode, + ) -> StripeObject: + return _convert_to_stripe_object( + resp=resp, + params=params, + requestor=self._requestor, + api_mode=api_mode, + ) diff --git a/stripe/_stripe_object.py b/stripe/_stripe_object.py index 2cc00104d..e8fd042e6 100644 --- a/stripe/_stripe_object.py +++ b/stripe/_stripe_object.py @@ -81,8 +81,6 @@ class StripeObject(Dict[str, Any]): class _ReprJSONEncoder(json.JSONEncoder): def default(self, o: Any) -> Any: if isinstance(o, datetime.datetime): - # pyright complains that _encode_datetime is "private", but it's - # private to outsiders, not to stripe_object return _encode_datetime(o) return super(StripeObject._ReprJSONEncoder, self).default(o) diff --git a/stripe/_thin_event.py b/stripe/_thin_event.py new file mode 100644 index 000000000..ed14105f3 --- /dev/null +++ b/stripe/_thin_event.py @@ -0,0 +1,102 @@ +# -*- coding: utf-8 -*- + +from stripe._stripe_object import StripeObject +from stripe._util import sanitize_id +from typing import Optional +from typing_extensions import Literal +from typing import ClassVar + + +# The beginning of the section generated from our OpenAPI spec +class ThinEvent(StripeObject): + OBJECT_NAME: ClassVar[Literal["event"]] = "event" + + class Reason(StripeObject): + class Request(StripeObject): + id: str + """ + ID of the API request that caused the event. + """ + idempotency_key: str + """ + The idempotency key transmitted during the request. + """ + + type: Literal["request"] + """ + Open Enum. Event reason type. + """ + request: Optional[Request] + """ + Information on the API request that instigated the event. + """ + _inner_class_types = {"request": Request} + + class RelatedObject(StripeObject): + id: str + """ + Unique identifier for the object relevant to the event. + """ + type: str + """ + Type of the object relevant to the event. + """ + url: str + """ + URL to retrieve the resource. + """ + + created: str + """ + Time at which the object was created. + """ + id: str + """ + Unique identifier for the event. + """ + object: Literal["event"] + """ + String representing the object's type. Objects of the same type share the same value of the object field. + """ + reason: Reason + """ + Reason for the event. + """ + related_object: RelatedObject + """ + Object containing the reference to API resource relevant to the event. + """ + type: str + """ + The type of the event. + """ + _inner_class_types = {"reason": Reason, "related_object": RelatedObject} + + # The end of the section generated from our OpenAPI spec + def fetch_object(self) -> StripeObject: + url = self.related_object.get("url") + if url is None: + raise ValueError( + "Unexpected: cannot call fetch_object on an event without a 'url' field" + ) + return self._requestor.request( + "get", + url, + base_address="api", + ) + + def fetch_data(self, event_data_cls) -> StripeObject: + full_event = self._requestor.request( + "get", + "/v2/events/{id}".format(id=sanitize_id(self.get("id"))), + base_address="api", + ) + assert isinstance(full_event, dict) + data = full_event["data"] + if data is None: + raise ValueError( + "Unexpected: fetch_data returned an event without a 'data' field" + ) + return event_data_cls._construct_from( + values=data, requestor=self._requestor, api_mode="V2" + ) diff --git a/stripe/_util.py b/stripe/_util.py index 6458d70a7..81777ee94 100644 --- a/stripe/_util.py +++ b/stripe/_util.py @@ -192,8 +192,19 @@ def secure_compare(val1, val2): return result == 0 -def get_object_classes(): +def get_thin_event_classes(): + from stripe.events._event_classes import THIN_EVENT_CLASSES + + return THIN_EVENT_CLASSES + + +def get_object_classes(api_mode): # This is here to avoid a circular dependency + if api_mode == "V2": + from stripe._object_classes import V2_OBJECT_CLASSES + + return V2_OBJECT_CLASSES + from stripe._object_classes import OBJECT_CLASSES return OBJECT_CLASSES @@ -310,7 +321,20 @@ def _convert_to_stripe_object( resp = resp.copy() klass_name = resp.get("object") if isinstance(klass_name, str): - klass = get_object_classes().get(klass_name, StripeObject) + if api_mode == "V2" and klass_name == "event": + event_name = resp.get("type", "") + klass = get_thin_event_classes().get( + event_name, stripe.StripeObject + ) + else: + klass = get_object_classes(api_mode).get( + klass_name, stripe.StripeObject + ) + # TODO: this is a horrible hack. The API needs + # to return something for `object` here. + + elif "data" in resp and "next_page_url" in resp: + klass = stripe.v2.ListObject elif klass_ is not None: klass = klass_ else: @@ -393,6 +417,13 @@ def sanitize_id(id): return quotedId +def get_api_mode(url): + if url.startswith("/v2"): + return "V2" + else: + return "V1" + + class class_method_variant(object): def __init__(self, class_method_name): self.class_method_name = class_method_name diff --git a/stripe/_v2_services.py b/stripe/_v2_services.py new file mode 100644 index 000000000..93e3e7064 --- /dev/null +++ b/stripe/_v2_services.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- +# File generated from our OpenAPI spec +from stripe._stripe_service import StripeService +from stripe.v2._billing_service import BillingService +from stripe.v2._core_service import CoreService + + +class V2Services(StripeService): + def __init__(self, requestor): + super().__init__(requestor) + self.billing = BillingService(self._requestor) + self.core = CoreService(self._requestor) diff --git a/stripe/events/__init__.py b/stripe/events/__init__.py new file mode 100644 index 000000000..bcf79de80 --- /dev/null +++ b/stripe/events/__init__.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- +# File generated from our OpenAPI spec +from stripe.events._v1_billing_meter_error_report_triggered_event import ( + V1BillingMeterErrorReportTriggeredEvent as V1BillingMeterErrorReportTriggeredEvent, +) +from stripe.events._v1_billing_meter_no_meter_found_event import ( + V1BillingMeterNoMeterFoundEvent as V1BillingMeterNoMeterFoundEvent, +) diff --git a/stripe/events/_event_classes.py b/stripe/events/_event_classes.py new file mode 100644 index 000000000..cfbfe23ba --- /dev/null +++ b/stripe/events/_event_classes.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# File generated from our OpenAPI spec +from stripe.events._v1_billing_meter_error_report_triggered_event import ( + V1BillingMeterErrorReportTriggeredEvent, +) +from stripe.events._v1_billing_meter_no_meter_found_event import ( + V1BillingMeterNoMeterFoundEvent, +) + + +THIN_EVENT_CLASSES = { + V1BillingMeterErrorReportTriggeredEvent.LOOKUP_TYPE: V1BillingMeterErrorReportTriggeredEvent, + V1BillingMeterNoMeterFoundEvent.LOOKUP_TYPE: V1BillingMeterNoMeterFoundEvent, +} diff --git a/stripe/events/_v1_billing_meter_error_report_triggered_event.py b/stripe/events/_v1_billing_meter_error_report_triggered_event.py new file mode 100644 index 000000000..f20157177 --- /dev/null +++ b/stripe/events/_v1_billing_meter_error_report_triggered_event.py @@ -0,0 +1,122 @@ +# -*- coding: utf-8 -*- +# File generated from our OpenAPI spec +from stripe._stripe_object import StripeObject +from stripe.billing._meter import Meter +from stripe.v2._event import Event +from typing import List, cast +from typing_extensions import Literal + + +class V1BillingMeterErrorReportTriggeredEvent(Event): + LOOKUP_TYPE = "v1.billing.meter.error_report_triggered" + type: Literal["v1.billing.meter.error_report_triggered"] + + class V1BillingMeterErrorReportTriggeredEventData(StripeObject): + class Reason(StripeObject): + class ErrorType(StripeObject): + class SampleError(StripeObject): + class Request(StripeObject): + identifier: str + """ + The request idempotency key. + """ + + error_message: str + """ + The error message. + """ + request: Request + """ + The request causes the error. + """ + _inner_class_types = {"request": Request} + + code: Literal[ + "archived_meter", + "meter_event_customer_not_found", + "meter_event_dimension_count_too_high", + "meter_event_invalid_value", + "meter_event_no_customer_defined", + "missing_dimension_payload_keys", + "no_meter", + "timestamp_in_future", + "timestamp_too_far_in_past", + ] + """ + Open Enum. + """ + error_count: int + """ + The number of errors of this type. + """ + sample_errors: List[SampleError] + """ + A list of sample errors of this type. + """ + _inner_class_types = {"sample_errors": SampleError} + + error_count: int + """ + The total error count within this window. + """ + error_types: List[ErrorType] + """ + The error details. + """ + _inner_class_types = {"error_types": ErrorType} + + developer_message_summary: str + """ + Extra field included in the event's `data` when fetched from /v2/events. + """ + reason: Reason + """ + This contains information about why meter error happens. + """ + validation_end: str + """ + The end of the window that is encapsulated by this summary. + """ + validation_start: str + """ + The start of the window that is encapsulated by this summary. + """ + _inner_class_types = {"reason": Reason} + + data: V1BillingMeterErrorReportTriggeredEventData + """ + Data for the v1.billing.meter.error_report_triggered event + """ + + class RelatedObject(StripeObject): + id: str + """ + Unique identifier for the object relevant to the event. + """ + type: str + """ + Type of the object relevant to the event. + """ + url: str + """ + URL to retrieve the resource. + """ + + related_object: RelatedObject + """ + Object containing the reference to API resource relevant to the event + """ + + def fetch_related_object(self) -> Meter: + """ + Retrieves the related object from the API. Makes an API request on every call. + """ + return cast( + Meter, + self._requestor.request( + "get", + self.related_object.url, + base_address="api", + options={"stripe_account": self.context}, + ), + ) diff --git a/stripe/events/_v1_billing_meter_no_meter_found_event.py b/stripe/events/_v1_billing_meter_no_meter_found_event.py new file mode 100644 index 000000000..680c094aa --- /dev/null +++ b/stripe/events/_v1_billing_meter_no_meter_found_event.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +# File generated from our OpenAPI spec +from stripe._stripe_object import StripeObject +from stripe.v2._event import Event +from typing import List +from typing_extensions import Literal + + +class V1BillingMeterNoMeterFoundEvent(Event): + LOOKUP_TYPE = "v1.billing.meter.no_meter_found" + type: Literal["v1.billing.meter.no_meter_found"] + + class V1BillingMeterNoMeterFoundEventData(StripeObject): + class Reason(StripeObject): + class ErrorType(StripeObject): + class SampleError(StripeObject): + class Request(StripeObject): + identifier: str + """ + The request idempotency key. + """ + + error_message: str + """ + The error message. + """ + request: Request + """ + The request causes the error. + """ + _inner_class_types = {"request": Request} + + code: Literal[ + "archived_meter", + "meter_event_customer_not_found", + "meter_event_dimension_count_too_high", + "meter_event_invalid_value", + "meter_event_no_customer_defined", + "missing_dimension_payload_keys", + "no_meter", + "timestamp_in_future", + "timestamp_too_far_in_past", + ] + """ + Open Enum. + """ + error_count: int + """ + The number of errors of this type. + """ + sample_errors: List[SampleError] + """ + A list of sample errors of this type. + """ + _inner_class_types = {"sample_errors": SampleError} + + error_count: int + """ + The total error count within this window. + """ + error_types: List[ErrorType] + """ + The error details. + """ + _inner_class_types = {"error_types": ErrorType} + + developer_message_summary: str + """ + Extra field included in the event's `data` when fetched from /v2/events. + """ + reason: Reason + """ + This contains information about why meter error happens. + """ + validation_end: str + """ + The end of the window that is encapsulated by this summary. + """ + validation_start: str + """ + The start of the window that is encapsulated by this summary. + """ + _inner_class_types = {"reason": Reason} + + data: V1BillingMeterNoMeterFoundEventData + """ + Data for the v1.billing.meter.no_meter_found event + """ diff --git a/stripe/v2/__init__.py b/stripe/v2/__init__.py new file mode 100644 index 000000000..d8a2170e1 --- /dev/null +++ b/stripe/v2/__init__.py @@ -0,0 +1,10 @@ +from stripe.v2._list_object import ListObject as ListObject +from stripe.v2._amount import Amount as Amount, AmountParam as AmountParam + + +# The beginning of the section generated from our OpenAPI spec +from stripe.v2 import billing as billing, core as core +from stripe.v2._billing_service import BillingService as BillingService +from stripe.v2._core_service import CoreService as CoreService +from stripe.v2._event import Event as Event +# The end of the section generated from our OpenAPI spec diff --git a/stripe/v2/_amount.py b/stripe/v2/_amount.py new file mode 100644 index 000000000..97a6bf63a --- /dev/null +++ b/stripe/v2/_amount.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# NOT codegenned +from typing_extensions import TypedDict +from stripe._stripe_object import StripeObject + + +class Amount(StripeObject): + value: int + currency: str + + +class AmountParam(TypedDict): + value: int + currency: str diff --git a/stripe/v2/_billing_service.py b/stripe/v2/_billing_service.py new file mode 100644 index 000000000..77d36d39a --- /dev/null +++ b/stripe/v2/_billing_service.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- +# File generated from our OpenAPI spec +from stripe._stripe_service import StripeService +from stripe.v2.billing._meter_event_adjustment_service import ( + MeterEventAdjustmentService, +) +from stripe.v2.billing._meter_event_service import MeterEventService +from stripe.v2.billing._meter_event_session_service import ( + MeterEventSessionService, +) +from stripe.v2.billing._meter_event_stream_service import ( + MeterEventStreamService, +) + + +class BillingService(StripeService): + def __init__(self, requestor): + super().__init__(requestor) + self.meter_event_session = MeterEventSessionService(self._requestor) + self.meter_event_adjustments = MeterEventAdjustmentService( + self._requestor, + ) + self.meter_event_stream = MeterEventStreamService(self._requestor) + self.meter_events = MeterEventService(self._requestor) diff --git a/stripe/v2/_core_service.py b/stripe/v2/_core_service.py new file mode 100644 index 000000000..96c4a6f2e --- /dev/null +++ b/stripe/v2/_core_service.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- +# File generated from our OpenAPI spec +from stripe._stripe_service import StripeService +from stripe.v2.core._event_service import EventService + + +class CoreService(StripeService): + def __init__(self, requestor): + super().__init__(requestor) + self.events = EventService(self._requestor) diff --git a/stripe/v2/_event.py b/stripe/v2/_event.py new file mode 100644 index 000000000..2abb588dc --- /dev/null +++ b/stripe/v2/_event.py @@ -0,0 +1,124 @@ +# -*- coding: utf-8 -*- + +import json +from typing import ClassVar, Optional + +from typing_extensions import Literal + +from stripe._stripe_object import StripeObject + +# This describes the common format for the pull payload of a V2 ThinEvent +# more specific classes will add `data` and `fetch_related_objects()` as needed + + +# The beginning of the section generated from our OpenAPI spec +class Event(StripeObject): + OBJECT_NAME: ClassVar[Literal["v2.core.event"]] = "v2.core.event" + + class Reason(StripeObject): + class Request(StripeObject): + id: str + """ + ID of the API request that caused the event. + """ + idempotency_key: str + """ + The idempotency key transmitted during the request. + """ + + type: Literal["request"] + """ + Open Enum. Event reason type. + """ + request: Optional[Request] + """ + Information on the API request that instigated the event. + """ + _inner_class_types = {"request": Request} + + context: Optional[str] + """ + Authentication context needed to fetch the event or related object. + """ + created: str + """ + Time at which the object was created. + """ + id: str + """ + Unique identifier for the event. + """ + livemode: bool + """ + Has the value `true` if the object exists in live mode or the value `false` if the object exists in test mode. + """ + object: Literal["v2.core.event"] + """ + String representing the object's type. Objects of the same type share the same value of the object field. + """ + reason: Optional[Reason] + """ + Reason for the event. + """ + type: str + """ + The type of the event. + """ + _inner_class_types = {"reason": Reason} + + +# The end of the section generated from our OpenAPI spec + + +class Reason: + id: str + idempotency_key: str + + def __init__(self, d) -> None: + self.id = d["id"] + self.idempotency_key = d["idempotency_key"] + + def __repr__(self) -> str: + return f"" + + +class RelatedObject: + id: str + type: str + url: str + + def __init__(self, d) -> None: + self.id = d["id"] + self.type_ = d["type"] + self.url = d["url"] + + def __repr__(self) -> str: + return f"" + + +class ThinEvent: + """ + ThinEvent represents the json that's delivered from an Event Destination. It's a basic `dict` with no additional methods or properties. Use it to check basic information about a delivered event. If you want more details, use `stripe.v2.Event.retrieve(thin_event.id)` to fetch the full event object. + """ + + id: str + type: str + created: str + context: Optional[str] = None + related_object: Optional[RelatedObject] = None + reason: Optional[Reason] = None + + def __init__(self, payload: str) -> None: + parsed = json.loads(payload) + + self.id = parsed["id"] + self.type = parsed["type"] + self.created = parsed["created"] + self.context = parsed.get("context") + if parsed.get("related_object"): + self.related_object = RelatedObject(parsed["related_object"]) + if parsed.get("reason"): + self.reason = Reason(parsed["reason"]) + + def __repr__(self) -> str: + return f"" diff --git a/stripe/v2/_list_object.py b/stripe/v2/_list_object.py new file mode 100644 index 000000000..a9d73546c --- /dev/null +++ b/stripe/v2/_list_object.py @@ -0,0 +1,59 @@ +from stripe._stripe_object import StripeObject +from typing import List, Optional, TypeVar, Generic + + +T = TypeVar("T", bound=StripeObject) + + +class ListObject(StripeObject, Generic[T]): + """ + Represents one page of a list of V2 Stripe objects. Use `.data` to access + the objects on this page, or use + + for item in list_object.auto_paging_iter(): + # do something with item + + to iterate over this and all following pages. + """ + + OBJECT_NAME = "list" + data: List[StripeObject] + next_page_url: Optional[str] + + def __getitem__(self, k): + if isinstance(k, str): # type: ignore + return super(ListObject, self).__getitem__(k) + else: + raise KeyError( + "You tried to access the %s index, but ListObjectV2 types only " + "support string keys. (HINT: List calls return an object with " + "a 'data' (which is the data array). You likely want to call " + ".data[%s])" % (repr(k), repr(k)) + ) + + def __iter__(self): + return getattr(self, "data", []).__iter__() + + def __len__(self): + return getattr(self, "data", []).__len__() + + def __reversed__(self): + return getattr(self, "data", []).__reversed__() + + def auto_paging_iter(self): + page = self.data + next_page_url = self.next_page_url + while True: + for item in page: + yield item + if next_page_url is None: + break + + result = self._request( + "get", + next_page_url, + base_address="api", + ) + assert isinstance(result, ListObject) + page = result.data + next_page_url = result.next_page_url diff --git a/stripe/v2/billing/__init__.py b/stripe/v2/billing/__init__.py new file mode 100644 index 000000000..ff5fd91c6 --- /dev/null +++ b/stripe/v2/billing/__init__.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +# File generated from our OpenAPI spec +from stripe.v2.billing._meter_event import MeterEvent as MeterEvent +from stripe.v2.billing._meter_event_adjustment import ( + MeterEventAdjustment as MeterEventAdjustment, +) +from stripe.v2.billing._meter_event_adjustment_service import ( + MeterEventAdjustmentService as MeterEventAdjustmentService, +) +from stripe.v2.billing._meter_event_service import ( + MeterEventService as MeterEventService, +) +from stripe.v2.billing._meter_event_session import ( + MeterEventSession as MeterEventSession, +) +from stripe.v2.billing._meter_event_session_service import ( + MeterEventSessionService as MeterEventSessionService, +) +from stripe.v2.billing._meter_event_stream_service import ( + MeterEventStreamService as MeterEventStreamService, +) diff --git a/stripe/v2/billing/_meter_event.py b/stripe/v2/billing/_meter_event.py new file mode 100644 index 000000000..ce33c36cd --- /dev/null +++ b/stripe/v2/billing/_meter_event.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +# File generated from our OpenAPI spec +from stripe._stripe_object import StripeObject +from typing import ClassVar, Dict +from typing_extensions import Literal + + +class MeterEvent(StripeObject): + """ + Fix me empty_doc_string. + """ + + OBJECT_NAME: ClassVar[Literal["billing.meter_event"]] = ( + "billing.meter_event" + ) + created: str + """ + The creation time of this meter event. + """ + event_name: str + """ + The name of the meter event. Corresponds with the `event_name` field on a meter. + """ + identifier: str + """ + A unique identifier for the event. If not provided, one will be generated. We recommend using a globally unique identifier for this. We'll enforce uniqueness within a rolling 24 hour period. + """ + livemode: bool + """ + Has the value `true` if the object exists in live mode or the value `false` if the object exists in test mode. + """ + object: Literal["billing.meter_event"] + """ + String representing the object's type. Objects of the same type share the same value of the object field. + """ + payload: Dict[str, str] + """ + The payload of the event. This must contain the fields corresponding to a meter's + `customer_mapping.event_payload_key` (default is `stripe_customer_id`) and + `value_settings.event_payload_key` (default is `value`). Read more about the payload. + """ + timestamp: str + """ + The time of the event. Must be within the past 35 calendar days or up to + 5 minutes in the future. Defaults to current timestamp if not specified. + """ diff --git a/stripe/v2/billing/_meter_event_adjustment.py b/stripe/v2/billing/_meter_event_adjustment.py new file mode 100644 index 000000000..7561e67ba --- /dev/null +++ b/stripe/v2/billing/_meter_event_adjustment.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# File generated from our OpenAPI spec +from stripe._stripe_object import StripeObject +from typing import ClassVar +from typing_extensions import Literal + + +class MeterEventAdjustment(StripeObject): + OBJECT_NAME: ClassVar[Literal["billing.meter_event_adjustment"]] = ( + "billing.meter_event_adjustment" + ) + + class Cancel(StripeObject): + identifier: str + """ + Unique identifier for the event. You can only cancel events within 24 hours of Stripe receiving them. + """ + + cancel: Cancel + """ + Specifies which event to cancel. + """ + created: str + """ + The time the adjustment was created. + """ + event_name: str + """ + The name of the meter event. Corresponds with the `event_name` field on a meter. + """ + id: str + """ + The unique id of this meter event adjustment. + """ + livemode: bool + """ + Has the value `true` if the object exists in live mode or the value `false` if the object exists in test mode. + """ + object: Literal["billing.meter_event_adjustment"] + """ + String representing the object's type. Objects of the same type share the same value of the object field. + """ + status: Literal["complete", "pending"] + """ + Open Enum. The meter event adjustment's status. + """ + type: Literal["cancel"] + """ + Open Enum. Specifies whether to cancel a single event or a range of events for a time period. Time period cancellation is not supported yet. + """ + _inner_class_types = {"cancel": Cancel} diff --git a/stripe/v2/billing/_meter_event_adjustment_service.py b/stripe/v2/billing/_meter_event_adjustment_service.py new file mode 100644 index 000000000..9533243f8 --- /dev/null +++ b/stripe/v2/billing/_meter_event_adjustment_service.py @@ -0,0 +1,67 @@ +# -*- coding: utf-8 -*- +# File generated from our OpenAPI spec +from stripe._request_options import RequestOptions +from stripe._stripe_service import StripeService +from stripe.v2.billing._meter_event_adjustment import MeterEventAdjustment +from typing import cast +from typing_extensions import Literal, TypedDict + + +class MeterEventAdjustmentService(StripeService): + class CreateParams(TypedDict): + cancel: "MeterEventAdjustmentService.CreateParamsCancel" + """ + Specifies which event to cancel. + """ + event_name: str + """ + The name of the meter event. Corresponds with the `event_name` field on a meter. + """ + type: Literal["cancel"] + """ + Specifies whether to cancel a single event or a range of events for a time period. Time period cancellation is not supported yet. + """ + + class CreateParamsCancel(TypedDict): + identifier: str + """ + Unique identifier for the event. You can only cancel events within 24 hours of Stripe receiving them. + """ + + def create( + self, + params: "MeterEventAdjustmentService.CreateParams", + options: RequestOptions = {}, + ) -> MeterEventAdjustment: + """ + Creates a meter event adjustment to cancel a previously sent meter event. + """ + return cast( + MeterEventAdjustment, + self._request( + "post", + "/v2/billing/meter_event_adjustments", + base_address="api", + params=params, + options=options, + ), + ) + + async def create_async( + self, + params: "MeterEventAdjustmentService.CreateParams", + options: RequestOptions = {}, + ) -> MeterEventAdjustment: + """ + Creates a meter event adjustment to cancel a previously sent meter event. + """ + return cast( + MeterEventAdjustment, + await self._request_async( + "post", + "/v2/billing/meter_event_adjustments", + base_address="api", + params=params, + options=options, + ), + ) diff --git a/stripe/v2/billing/_meter_event_service.py b/stripe/v2/billing/_meter_event_service.py new file mode 100644 index 000000000..50eb75009 --- /dev/null +++ b/stripe/v2/billing/_meter_event_service.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +# File generated from our OpenAPI spec +from stripe._request_options import RequestOptions +from stripe._stripe_service import StripeService +from stripe.v2.billing._meter_event import MeterEvent +from typing import Dict, cast +from typing_extensions import NotRequired, TypedDict + + +class MeterEventService(StripeService): + class CreateParams(TypedDict): + event_name: str + """ + The name of the meter event. Corresponds with the `event_name` field on a meter. + """ + identifier: NotRequired[str] + """ + A unique identifier for the event. If not provided, one will be generated. + We recommend using a globally unique identifier for this. We'll enforce + uniqueness within a rolling 24 hour period. + """ + payload: Dict[str, str] + """ + The payload of the event. This must contain the fields corresponding to a meter's + `customer_mapping.event_payload_key` (default is `stripe_customer_id`) and + `value_settings.event_payload_key` (default is `value`). Read more about + the + [payload](https://docs.stripe.com/billing/subscriptions/usage-based/recording-usage#payload-key-overrides). + """ + timestamp: NotRequired[str] + """ + The time of the event. Must be within the past 35 calendar days or up to + 5 minutes in the future. Defaults to current timestamp if not specified. + """ + + def create( + self, + params: "MeterEventService.CreateParams", + options: RequestOptions = {}, + ) -> MeterEvent: + """ + Creates a meter event. Events are validated synchronously, but are processed asynchronously. Supports up to 1,000 events per second in livemode. For higher rate-limits, please use meter event streams instead. + """ + return cast( + MeterEvent, + self._request( + "post", + "/v2/billing/meter_events", + base_address="api", + params=params, + options=options, + ), + ) + + async def create_async( + self, + params: "MeterEventService.CreateParams", + options: RequestOptions = {}, + ) -> MeterEvent: + """ + Creates a meter event. Events are validated synchronously, but are processed asynchronously. Supports up to 1,000 events per second in livemode. For higher rate-limits, please use meter event streams instead. + """ + return cast( + MeterEvent, + await self._request_async( + "post", + "/v2/billing/meter_events", + base_address="api", + params=params, + options=options, + ), + ) diff --git a/stripe/v2/billing/_meter_event_session.py b/stripe/v2/billing/_meter_event_session.py new file mode 100644 index 000000000..f1d96650e --- /dev/null +++ b/stripe/v2/billing/_meter_event_session.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +# File generated from our OpenAPI spec +from stripe._stripe_object import StripeObject +from typing import ClassVar +from typing_extensions import Literal + + +class MeterEventSession(StripeObject): + OBJECT_NAME: ClassVar[Literal["billing.meter_event_session"]] = ( + "billing.meter_event_session" + ) + authentication_token: str + """ + The authentication token for this session. Use this token when calling the + high-throughput meter event API. + """ + created: str + """ + The creation time of this session. + """ + expires_at: str + """ + The time at which this session will expire. + """ + id: str + """ + The unique id of this auth session. + """ + livemode: bool + """ + Has the value `true` if the object exists in live mode or the value `false` if the object exists in test mode. + """ + object: Literal["billing.meter_event_session"] + """ + String representing the object's type. Objects of the same type share the same value of the object field. + """ diff --git a/stripe/v2/billing/_meter_event_session_service.py b/stripe/v2/billing/_meter_event_session_service.py new file mode 100644 index 000000000..600c80362 --- /dev/null +++ b/stripe/v2/billing/_meter_event_session_service.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# File generated from our OpenAPI spec +from stripe._request_options import RequestOptions +from stripe._stripe_service import StripeService +from stripe.v2.billing._meter_event_session import MeterEventSession +from typing import cast +from typing_extensions import TypedDict + + +class MeterEventSessionService(StripeService): + class CreateParams(TypedDict): + pass + + def create( + self, + params: "MeterEventSessionService.CreateParams" = {}, + options: RequestOptions = {}, + ) -> MeterEventSession: + """ + Creates a meter event session to send usage on the high-throughput meter event stream. Authentication tokens are only valid for 15 minutes, so you will need to create a new meter event session when your token expires. + """ + return cast( + MeterEventSession, + self._request( + "post", + "/v2/billing/meter_event_session", + base_address="api", + params=params, + options=options, + ), + ) + + async def create_async( + self, + params: "MeterEventSessionService.CreateParams" = {}, + options: RequestOptions = {}, + ) -> MeterEventSession: + """ + Creates a meter event session to send usage on the high-throughput meter event stream. Authentication tokens are only valid for 15 minutes, so you will need to create a new meter event session when your token expires. + """ + return cast( + MeterEventSession, + await self._request_async( + "post", + "/v2/billing/meter_event_session", + base_address="api", + params=params, + options=options, + ), + ) diff --git a/stripe/v2/billing/_meter_event_stream_service.py b/stripe/v2/billing/_meter_event_stream_service.py new file mode 100644 index 000000000..84e6908e5 --- /dev/null +++ b/stripe/v2/billing/_meter_event_stream_service.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +# File generated from our OpenAPI spec +from stripe._request_options import RequestOptions +from stripe._stripe_service import StripeService +from typing import Dict, List +from typing_extensions import NotRequired, TypedDict + + +class MeterEventStreamService(StripeService): + class CreateParams(TypedDict): + events: List["MeterEventStreamService.CreateParamsEvent"] + """ + List of meter events to include in the request. + """ + + class CreateParamsEvent(TypedDict): + event_name: str + """ + The name of the meter event. Corresponds with the `event_name` field on a meter. + """ + identifier: NotRequired[str] + """ + A unique identifier for the event. If not provided, one will be generated. + We recommend using a globally unique identifier for this. We'll enforce + uniqueness within a rolling 24 hour period. + """ + payload: Dict[str, str] + """ + The payload of the event. This must contain the fields corresponding to a meter's + `customer_mapping.event_payload_key` (default is `stripe_customer_id`) and + `value_settings.event_payload_key` (default is `value`). Read more about + the + [payload](https://docs.stripe.com/billing/subscriptions/usage-based/recording-usage#payload-key-overrides). + """ + timestamp: NotRequired[str] + """ + The time of the event. Must be within the past 35 calendar days or up to + 5 minutes in the future. Defaults to current timestamp if not specified. + """ + + def create( + self, + params: "MeterEventStreamService.CreateParams", + options: RequestOptions = {}, + ) -> None: + """ + Creates meter events. Events are processed asynchronously, including validation. Requires a meter event session for authentication. Supports up to 10,000 requests per second in livemode. For even higher rate-limits, contact sales. + """ + self._request( + "post", + "/v2/billing/meter_event_stream", + base_address="meter_events", + params=params, + options=options, + ) + + async def create_async( + self, + params: "MeterEventStreamService.CreateParams", + options: RequestOptions = {}, + ) -> None: + """ + Creates meter events. Events are processed asynchronously, including validation. Requires a meter event session for authentication. Supports up to 10,000 requests per second in livemode. For even higher rate-limits, contact sales. + """ + await self._request_async( + "post", + "/v2/billing/meter_event_stream", + base_address="meter_events", + params=params, + options=options, + ) diff --git a/stripe/v2/core/__init__.py b/stripe/v2/core/__init__.py new file mode 100644 index 000000000..5879a6c60 --- /dev/null +++ b/stripe/v2/core/__init__.py @@ -0,0 +1,3 @@ +# -*- coding: utf-8 -*- +# File generated from our OpenAPI spec +from stripe.v2.core._event_service import EventService as EventService diff --git a/stripe/v2/core/_event_service.py b/stripe/v2/core/_event_service.py new file mode 100644 index 000000000..7bffa637e --- /dev/null +++ b/stripe/v2/core/_event_service.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- +# File generated from our OpenAPI spec +from stripe._request_options import RequestOptions +from stripe._stripe_service import StripeService +from stripe._util import sanitize_id +from stripe.v2._event import Event +from stripe.v2._list_object import ListObject +from typing import cast +from typing_extensions import NotRequired, TypedDict + + +class EventService(StripeService): + class ListParams(TypedDict): + limit: NotRequired[int] + object_id: str + """ + Primary object ID used to retrieve related events. + """ + page: NotRequired[str] + + class RetrieveParams(TypedDict): + pass + + def list( + self, params: "EventService.ListParams", options: RequestOptions = {} + ) -> ListObject[Event]: + """ + List events, going back up to 30 days. + """ + return cast( + ListObject[Event], + self._request( + "get", + "/v2/core/events", + base_address="api", + params=params, + options=options, + ), + ) + + async def list_async( + self, params: "EventService.ListParams", options: RequestOptions = {} + ) -> ListObject[Event]: + """ + List events, going back up to 30 days. + """ + return cast( + ListObject[Event], + await self._request_async( + "get", + "/v2/core/events", + base_address="api", + params=params, + options=options, + ), + ) + + def retrieve( + self, + id: str, + params: "EventService.RetrieveParams" = {}, + options: RequestOptions = {}, + ) -> Event: + """ + Retrieves the details of an event. + """ + return cast( + Event, + self._request( + "get", + "/v2/core/events/{id}".format(id=sanitize_id(id)), + base_address="api", + params=params, + options=options, + ), + ) + + async def retrieve_async( + self, + id: str, + params: "EventService.RetrieveParams" = {}, + options: RequestOptions = {}, + ) -> Event: + """ + Retrieves the details of an event. + """ + return cast( + Event, + await self._request_async( + "get", + "/v2/core/events/{id}".format(id=sanitize_id(id)), + base_address="api", + params=params, + options=options, + ), + ) diff --git a/tests/api_resources/abstract/test_api_resource.py b/tests/api_resources/abstract/test_api_resource.py index ebb533c9f..005309f20 100644 --- a/tests/api_resources/abstract/test_api_resource.py +++ b/tests/api_resources/abstract/test_api_resource.py @@ -95,7 +95,7 @@ def test_convert_to_stripe_object(self): } converted = stripe.util.convert_to_stripe_object( - sample, "akey", None, None + sample, "akey", None, None, api_mode="V1" ) # Types diff --git a/tests/api_resources/test_list_object.py b/tests/api_resources/test_list_object.py index 98a4e72ce..fe6340a14 100644 --- a/tests/api_resources/test_list_object.py +++ b/tests/api_resources/test_list_object.py @@ -95,13 +95,15 @@ def test_empty_list(self): def test_iter(self): arr = [{"id": 1}, {"id": 2}, {"id": 3}] - expected = stripe.util.convert_to_stripe_object(arr) + expected = stripe.util.convert_to_stripe_object(arr, api_mode="V1") lo = stripe.ListObject.construct_from({"data": arr}, None) assert list(lo) == expected def test_iter_reversed(self): arr = [{"id": 1}, {"id": 2}, {"id": 3}] - expected = stripe.util.convert_to_stripe_object(list(reversed(arr))) + expected = stripe.util.convert_to_stripe_object( + list(reversed(arr)), api_mode="V1" + ) lo = stripe.ListObject.construct_from({"data": arr}, None) assert list(reversed(lo)) == expected diff --git a/tests/api_resources/test_list_object_v2.py b/tests/api_resources/test_list_object_v2.py new file mode 100644 index 000000000..d86ed54c5 --- /dev/null +++ b/tests/api_resources/test_list_object_v2.py @@ -0,0 +1,147 @@ +from __future__ import absolute_import, division, print_function + +import json + +import pytest + +import stripe +from stripe.v2._list_object import ListObject +from tests.http_client_mock import HTTPClientMock + + +class TestListObjectV2(object): + @pytest.fixture + def list_object(self): + return ListObject.construct_from( + { + "data": ["a", "b", "c"], + "next_page_url": None, + "previous_page_url": None, + }, + "mykey", + ) + + def test_iter(self): + arr = ["a", "b", "c"] + expected = stripe.util.convert_to_stripe_object(arr, api_mode="V2") + lo = ListObject.construct_from({"data": arr}, None) + assert list(lo) == expected + + @staticmethod + def pageable_model_response(ids, next_page_url): + return { + "data": [{"id": id, "object": "pageablemodel"} for id in ids], + "next_page_url": next_page_url, + } + + def test_iter_one_page(self, http_client_mock): + lo = ListObject.construct_from( + self.pageable_model_response(["pm_123", "pm_124"], None), "mykey" + ) + + http_client_mock.assert_no_request() + + seen = [item["id"] for item in lo.auto_paging_iter()] + + assert seen == ["pm_123", "pm_124"] + + def test_iter_two_pages(self, http_client_mock): + method = "get" + path = "/v2/pageablemodels" + + lo = ListObject.construct_from( + self.pageable_model_response( + ["pm_123", "pm_124"], "/v2/pageablemodels?foo=bar&page=page_2" + ), + None, + ) + + http_client_mock.stub_request( + method, + path=path, + query_string="foo=bar&page=page_3", + rbody=json.dumps( + self.pageable_model_response(["pm_127", "pm_128"], None) + ), + ) + + http_client_mock.stub_request( + method, + path=path, + query_string="foo=bar&page=page_2", + rbody=json.dumps( + self.pageable_model_response( + ["pm_125", "pm_126"], + "/v2/pageablemodels?foo=bar&page=page_3", + ) + ), + ) + + seen = [item["id"] for item in lo.auto_paging_iter()] + + http_client_mock.assert_requested( + method, path=path, query_string="foo=bar&page=page_2" + ) + + http_client_mock.assert_requested( + method, path=path, query_string="foo=bar&page=page_3" + ) + + assert seen == [ + "pm_123", + "pm_124", + "pm_125", + "pm_126", + "pm_127", + "pm_128", + ] + + def test_iter_forwards_api_key(self, http_client_mock: HTTPClientMock): + client = stripe.StripeClient( + http_client=http_client_mock.get_mock_http_client(), + api_key="sk_test_xyz", + ) + + method = "get" + query_string_1 = "object_id=obj_123" + query_string_2 = "object_id=obj_123&page=page_2" + path = "/v2/core/events" + + http_client_mock.stub_request( + method, + path=path, + query_string=query_string_1, + rbody='{"data": [{"id": "x"}], "next_page_url": "/v2/core/events?object_id=obj_123&page=page_2"}', + rcode=200, + rheaders={}, + ) + + http_client_mock.stub_request( + method, + path=path, + query_string=query_string_2, + rbody='{"data": [{"id": "y"}, {"id": "z"}], "next_page_url": null}', + rcode=200, + rheaders={}, + ) + + lo = client.v2.core.events.list( + params={"object_id": "obj_123"}, + options={"api_key": "sk_test_iter_forwards_options"}, + ) + + seen = [item["id"] for item in lo.auto_paging_iter()] + + assert seen == ["x", "y", "z"] + http_client_mock.assert_requested( + method, + path=path, + query_string=query_string_1, + api_key="sk_test_iter_forwards_options", + ) + http_client_mock.assert_requested( + method, + path=path, + query_string=query_string_2, + api_key="sk_test_iter_forwards_options", + ) diff --git a/tests/api_resources/test_search_result_object.py b/tests/api_resources/test_search_result_object.py index 6e4d3f535..83266f0b9 100644 --- a/tests/api_resources/test_search_result_object.py +++ b/tests/api_resources/test_search_result_object.py @@ -82,7 +82,7 @@ def test_empty_search_result(self): def test_iter(self): arr = [{"id": 1}, {"id": 2}, {"id": 3}] - expected = stripe.util.convert_to_stripe_object(arr) + expected = stripe.util.convert_to_stripe_object(arr, api_mode="V1") sro = stripe.SearchResultObject.construct_from({"data": arr}, None) assert list(sro) == expected diff --git a/tests/fixtures/card.json b/tests/fixtures/card.json deleted file mode 100644 index 97a40a318..000000000 --- a/tests/fixtures/card.json +++ /dev/null @@ -1,26 +0,0 @@ -{ - "id": "card_123", - "object": "payment_methods.card", - "address_city": null, - "address_country": null, - "address_line1": null, - "address_line1_check": null, - "address_line2": null, - "address_state": null, - "address_zip": null, - "address_zip_check": null, - "brand": "Visa", - "country": "US", - "customer": "cus_123", - "cvc_check": null, - "dynamic_last4": null, - "exp_month": 8, - "exp_year": 2019, - "fingerprint": "Xt5EWLLDS7FJjR1c", - "funding": "credit", - "last4": "4242", - "metadata": { - }, - "name": null, - "tokenization_method": null -} diff --git a/tests/fixtures/financial_account.json b/tests/fixtures/financial_account.json new file mode 100644 index 000000000..a3020431a --- /dev/null +++ b/tests/fixtures/financial_account.json @@ -0,0 +1,14 @@ +{ + "id": "fa_123", + "object": "financial_account", + "balance_types": [ + "storage" + ], + "country": "US", + "created": "2023-10-20T05:47:49.766Z", + "description": "", + "requested_currencies": [ + "usd" + ], + "status": "open" +} diff --git a/tests/http_client_mock.py b/tests/http_client_mock.py index 0e2ea588a..db13f432c 100644 --- a/tests/http_client_mock.py +++ b/tests/http_client_mock.py @@ -336,6 +336,7 @@ def assert_requested( api_key=None, stripe_version=None, stripe_account=None, + stripe_context=None, content_type=None, idempotency_key=None, user_agent=None, @@ -366,6 +367,7 @@ def assert_requested( api_key=api_key, stripe_version=stripe_version, stripe_account=stripe_account, + stripe_context=stripe_context, content_type=content_type, idempotency_key=idempotency_key, user_agent=user_agent, diff --git a/tests/test_api_requestor.py b/tests/test_api_requestor.py index 8acb0e87d..82f10e5a3 100644 --- a/tests/test_api_requestor.py +++ b/tests/test_api_requestor.py @@ -3,25 +3,25 @@ import tempfile import uuid from collections import OrderedDict +from urllib.parse import urlencode, urlsplit import pytest +import urllib3 import stripe from stripe import util +from stripe._api_requestor import _api_encode, _APIRequestor +from stripe._request_options import RequestOptions +from stripe._requestor_options import ( + RequestorOptions, + _GlobalRequestorOptions, +) +from stripe._stripe_object import StripeObject from stripe._stripe_response import ( StripeStreamResponse, StripeStreamResponseAsync, ) -from stripe._api_requestor import _APIRequestor, _api_encode -from stripe._stripe_object import StripeObject -from stripe._requestor_options import ( - _GlobalRequestorOptions, -) -from stripe._request_options import RequestOptions - -from urllib.parse import urlencode, urlsplit - -import urllib3 +from tests.http_client_mock import HTTPClientMock VALID_API_METHODS = ("get", "post", "delete") @@ -49,6 +49,19 @@ def __repr__(self): return "AnyUUID4Matcher()" +class IsNoneMatcher: + """ + Matcher to make assertions against None because `assert_requested` doesn't + run checks if you pass `None` as the expected value. + """ + + def __eq__(self, other): + return other is None + + def __repr__(self): + return "None (from IsNoneMatcher())" + + class TestAPIRequestor(object): ENCODE_INPUTS = { "dict": { @@ -113,8 +126,12 @@ def requestor(self, http_client_mock): return requestor @property - def valid_path(self): - return "/foo" + def v1_path(self): + return "/v1/foo" + + @property + def v2_path(self): + return "/v2/foo" def encoder_check(self, key): stk_key = "my%s" % (key,) @@ -162,9 +179,46 @@ def test_param_encoding(self, requestor, http_client_mock): http_client_mock.assert_requested("get", query_string=query_string) + def test_param_api_mode_preview(self, requestor, http_client_mock): + http_client_mock.stub_request( + "post", path=self.v2_path, rbody="{}", rcode=200 + ) + + requestor.request( + "post", self.v2_path, self.ENCODE_INPUTS, base_address="api" + ) + + expectation = '{"dict": {"astring": "bar", "anint": 5, "anull": null, "adatetime": 1356994800, "atuple": [1, 2], "adict": {"foo": "bar", "boz": 5}, "alist": ["foo", "bar"]}, "list": [1, "foo", "baz"], "string": "boo", "unicode": "\\u1234", "datetime": 1356994801, "none": null}' + + http_client_mock.assert_requested( + "post", + content_type="application/json", + post_data=expectation, + is_json=True, + ) + + def test_encodes_null_values_preview(self, requestor, http_client_mock): + http_client_mock.stub_request( + "post", path=self.v2_path, rbody="{}", rcode=200 + ) + + requestor.request( + "post", + self.v2_path, + {"foo": None}, + base_address="api", + ) + + http_client_mock.assert_requested( + "post", + content_type="application/json", + post_data='{"foo": null}', + is_json=True, + ) + def test_dictionary_list_encoding(self): params = {"foo": {"0": {"bar": "bat"}}} - encoded = list(_api_encode(params)) + encoded = list(_api_encode(params, "V1")) key, value = encoded[0] assert key == "foo[0][bar]" @@ -181,7 +235,7 @@ def test_ordereddict_encoding(self): ] ) } - encoded = list(_api_encode(params)) + encoded = list(_api_encode(params, "V1")) assert encoded[0][0] == "ordered[one]" assert encoded[1][0] == "ordered[two]" @@ -224,11 +278,11 @@ def test_url_construction(self, requestor, http_client_mock): def test_empty_methods(self, requestor, http_client_mock): for meth in VALID_API_METHODS: http_client_mock.stub_request( - meth, path=self.valid_path, rbody="{}", rcode=200 + meth, path=self.v1_path, rbody="{}", rcode=200 ) resp = requestor.request( - meth, self.valid_path, {}, base_address="api" + meth, self.v1_path, {}, base_address="api" ) if meth == "post": @@ -246,13 +300,13 @@ async def test_empty_methods_async(self, requestor, http_client_mock): for meth in VALID_API_METHODS: http_client_mock.stub_request( meth, - path=self.valid_path, + path=self.v1_path, rbody="{}", rcode=200, ) resp = await requestor.request_async( - meth, self.valid_path, {}, base_address="api" + meth, self.v1_path, {}, base_address="api" ) if meth == "post": @@ -277,14 +331,14 @@ async def async_iter(): for meth in VALID_API_METHODS: http_client_mock.stub_request( meth, - path=self.valid_path, + path=self.v1_path, rbody=async_iter(), rcode=200, ) resp = await requestor.request_stream_async( meth, - self.valid_path, + self.v1_path, {}, base_address="api", ) @@ -305,14 +359,14 @@ def test_empty_methods_streaming_response( for meth in VALID_API_METHODS: http_client_mock.stub_request( meth, - path=self.valid_path, + path=self.v1_path, rbody=util.io.BytesIO(b"thisisdata"), rcode=200, ) resp = requestor.request_stream( meth, - self.valid_path, + self.v1_path, {}, base_address="api", ) @@ -338,7 +392,7 @@ def test_methods_with_params_and_response( http_client_mock.stub_request( method, - path=self.valid_path, + path=self.v1_path, query_string=encoded if method != "post" else "", rbody='{"foo": "bar", "baz": 6}', rcode=200, @@ -352,7 +406,7 @@ def test_methods_with_params_and_response( resp = requestor.request( method, - self.valid_path, + self.v1_path, params, base_address="api", ) @@ -368,7 +422,7 @@ def test_methods_with_params_and_response( else: abs_url = "%s%s?%s" % ( stripe.api_base, - self.valid_path, + self.v1_path, encoded, ) http_client_mock.assert_requested(method, abs_url=abs_url) @@ -384,7 +438,7 @@ def test_methods_with_params_and_streaming_response( http_client_mock.stub_request( method, - path=self.valid_path, + path=self.v1_path, query_string=encoded if method != "post" else "", rbody=util.io.BytesIO(b'{"foo": "bar", "baz": 6}'), rcode=200, @@ -398,7 +452,7 @@ def test_methods_with_params_and_streaming_response( resp = requestor.request_stream( method, - self.valid_path, + self.v1_path, params, base_address="api", ) @@ -411,19 +465,19 @@ def test_methods_with_params_and_streaming_response( else: abs_url = "%s%s?%s" % ( stripe.api_base, - self.valid_path, + self.v1_path, encoded, ) http_client_mock.assert_requested(method, abs_url=abs_url) def test_uses_headers(self, requestor, http_client_mock): http_client_mock.stub_request( - "get", path=self.valid_path, rbody="{}", rcode=200 + "get", path=self.v1_path, rbody="{}", rcode=200 ) request_options: RequestOptions = {"headers": {"foo": "bar"}} requestor.request( "get", - self.valid_path, + self.v1_path, {}, options=request_options, base_address="api", @@ -432,12 +486,12 @@ def test_uses_headers(self, requestor, http_client_mock): def test_uses_api_version(self, requestor, http_client_mock): http_client_mock.stub_request( - "get", path=self.valid_path, rbody="{}", rcode=200 + "get", path=self.v1_path, rbody="{}", rcode=200 ) request_options: RequestOptions = {"stripe_version": "fooversion"} requestor.request( "get", - self.valid_path, + self.v1_path, options=request_options, base_address="api", ) @@ -448,7 +502,7 @@ def test_uses_api_version(self, requestor, http_client_mock): def test_prefers_headers_api_version(self, requestor, http_client_mock): http_client_mock.stub_request( - "get", path=self.valid_path, rbody="{}", rcode=200 + "get", path=self.v1_path, rbody="{}", rcode=200 ) request_options: RequestOptions = { "stripe_version": "fooversion", @@ -456,7 +510,7 @@ def test_prefers_headers_api_version(self, requestor, http_client_mock): } requestor.request( "get", - self.valid_path, + self.v1_path, {}, options=request_options, base_address="api", @@ -471,10 +525,10 @@ def test_uses_instance_key(self, requestor, http_client_mock): requestor = requestor._replace_options(RequestOptions(api_key=key)) http_client_mock.stub_request( - "get", path=self.valid_path, rbody="{}", rcode=200 + "get", path=self.v1_path, rbody="{}", rcode=200 ) - requestor.request("get", self.valid_path, {}, base_address="api") + requestor.request("get", self.v1_path, {}, base_address="api") http_client_mock.assert_requested("get", api_key=key) assert requestor.api_key == key @@ -486,16 +540,66 @@ def test_uses_instance_account(self, requestor, http_client_mock): ) http_client_mock.stub_request( - "get", path=self.valid_path, rbody="{}", rcode=200 + "get", path=self.v1_path, rbody="{}", rcode=200 ) - requestor.request("get", self.valid_path, {}, base_address="api") + requestor.request("get", self.v1_path, {}, base_address="api") http_client_mock.assert_requested( "get", stripe_account=account, ) + def test_removes_None_account( + self, requestor, http_client_mock: HTTPClientMock + ): + """ + important test! + + If there's no context on a retrieved event, it's important that passing `stripe-account: None` + in the generated fetch_related_object doesn't actually send the null header + """ + account = None + requestor = requestor._replace_options( + RequestOptions(stripe_account=account) + ) + + http_client_mock.stub_request( + "get", path=self.v1_path, rbody="{}", rcode=200 + ) + + requestor.request("get", self.v1_path, {}, base_address="api") + + assert len(http_client_mock.get_all_calls()) == 1 + call = http_client_mock.get_last_call() + assert call.headers is not None + + assert "Stripe-Account" not in call.headers + + def test_uses_instance_context(self, http_client_mock): + context = "acct_bar" + + requestor = _APIRequestor( + options=RequestorOptions( + **{ + **_GlobalRequestorOptions().to_dict(), + "stripe_context": context, + } + ), + client=http_client_mock.get_mock_http_client(), + ) + + http_client_mock.stub_request( + "get", path=self.v1_path, rbody="{}", rcode=200 + ) + + requestor.request("get", self.v1_path, {}, base_address="api") + + http_client_mock.assert_requested( + "get", + stripe_context=context, + ) + def test_sets_default_http_client(self, mocker): assert not stripe.default_http_client @@ -528,9 +632,9 @@ def test_uses_app_info(self, requestor, http_client_mock): ) http_client_mock.stub_request( - "get", path=self.valid_path, rbody="{}", rcode=200 + "get", path=self.v1_path, rbody="{}", rcode=200 ) - requestor.request("get", self.valid_path, {}, base_address="api") + requestor.request("get", self.v1_path, {}, base_address="api") ua = "Stripe/v1 PythonBindings/%s" % (stripe.VERSION,) ua += " MyAwesomePlugin/1.2.34 (https://myawesomeplugin.info)" @@ -557,7 +661,7 @@ def test_handles_failed_platform_call( self, requestor, mocker, http_client_mock ): http_client_mock.stub_request( - "get", path=self.valid_path, rbody="{}", rcode=200 + "get", path=self.v1_path, rbody="{}", rcode=200 ) def fail(): @@ -565,7 +669,7 @@ def fail(): mocker.patch("platform.platform", side_effect=fail) - requestor.request("get", self.valid_path, {}, {}, base_address="api") + requestor.request("get", self.v1_path, {}, {}, base_address="api") last_call = http_client_mock.get_last_call() last_call.assert_method("get") @@ -577,104 +681,130 @@ def fail(): ) def test_uses_given_idempotency_key(self, requestor, http_client_mock): - meth = "post" + method = "post" http_client_mock.stub_request( - meth, path=self.valid_path, rbody="{}", rcode=200 + method, path=self.v1_path, rbody="{}", rcode=200 ) request_options: RequestOptions = {"idempotency_key": "123abc"} requestor.request( - meth, - self.valid_path, + method, + self.v1_path, {}, options=request_options, base_address="api", ) http_client_mock.assert_requested( - meth, idempotency_key="123abc", post_data="" + method, idempotency_key="123abc", post_data="" ) def test_uuid4_idempotency_key_when_not_given( self, requestor, http_client_mock ): - meth = "post" + method = "post" + http_client_mock.stub_request( + method, path=self.v1_path, rbody="{}", rcode=200 + ) + requestor.request(method, self.v1_path, {}, base_address="api") + + http_client_mock.assert_requested( + method, idempotency_key=AnyUUID4Matcher(), post_data="" + ) + + def test_generates_default_idempotency_key_for_v2_delete( + self, requestor, http_client_mock + ): + method = "delete" + http_client_mock.stub_request( + method, path=self.v2_path, rbody="{}", rcode=200 + ) + requestor.request(method, self.v2_path, {}, base_address="api") + + http_client_mock.assert_requested( + method, idempotency_key=AnyUUID4Matcher() + ) + + def test_skips_generates_default_idempotency_key_for_v1_delete( + self, requestor, http_client_mock + ): + method = "delete" http_client_mock.stub_request( - meth, path=self.valid_path, rbody="{}", rcode=200 + method, path=self.v1_path, rbody="{}", rcode=200 ) - requestor.request(meth, self.valid_path, {}, base_address="api") + requestor.request(method, self.v1_path, {}, base_address="api") http_client_mock.assert_requested( - meth, idempotency_key=AnyUUID4Matcher(), post_data="" + method, idempotency_key=IsNoneMatcher() ) def test_fails_without_api_key(self, requestor): stripe.api_key = None with pytest.raises(stripe.error.AuthenticationError): - requestor.request("get", self.valid_path, {}, base_address="api") + requestor.request("get", self.v1_path, {}, base_address="api") def test_invalid_request_error_404(self, requestor, http_client_mock): http_client_mock.stub_request( - "get", path=self.valid_path, rbody='{"error": {}}', rcode=404 + "get", path=self.v1_path, rbody='{"error": {}}', rcode=404 ) with pytest.raises(stripe.error.InvalidRequestError): - requestor.request("get", self.valid_path, {}, base_address="api") + requestor.request("get", self.v1_path, {}, base_address="api") def test_invalid_request_error_400(self, requestor, http_client_mock): http_client_mock.stub_request( - "get", path=self.valid_path, rbody='{"error": {}}', rcode=400 + "get", path=self.v1_path, rbody='{"error": {}}', rcode=400 ) with pytest.raises(stripe.error.InvalidRequestError): - requestor.request("get", self.valid_path, {}, base_address="api") + requestor.request("get", self.v1_path, {}, base_address="api") def test_idempotency_error(self, requestor, http_client_mock): http_client_mock.stub_request( "get", - path=self.valid_path, + path=self.v1_path, rbody='{"error": {"type": "idempotency_error"}}', rcode=400, ) with pytest.raises(stripe.error.IdempotencyError): - requestor.request("get", self.valid_path, {}, base_address="api") + requestor.request("get", self.v1_path, {}, base_address="api") def test_authentication_error(self, requestor, http_client_mock): http_client_mock.stub_request( - "get", path=self.valid_path, rbody='{"error": {}}', rcode=401 + "get", path=self.v1_path, rbody='{"error": {}}', rcode=401 ) with pytest.raises(stripe.error.AuthenticationError): - requestor.request("get", self.valid_path, {}, base_address="api") + requestor.request("get", self.v1_path, {}, base_address="api") def test_permissions_error(self, requestor, http_client_mock): http_client_mock.stub_request( - "get", path=self.valid_path, rbody='{"error": {}}', rcode=403 + "get", path=self.v1_path, rbody='{"error": {}}', rcode=403 ) with pytest.raises(stripe.error.PermissionError): - requestor.request("get", self.valid_path, {}, base_address="api") + requestor.request("get", self.v1_path, {}, base_address="api") def test_card_error(self, requestor, http_client_mock): http_client_mock.stub_request( "get", - path=self.valid_path, + path=self.v1_path, rbody='{"error": {"code": "invalid_expiry_year"}}', rcode=402, ) with pytest.raises(stripe.error.CardError) as excinfo: - requestor.request("get", self.valid_path, {}, base_address="api") + requestor.request("get", self.v1_path, {}, base_address="api") assert excinfo.value.code == "invalid_expiry_year" def test_rate_limit_error(self, requestor, http_client_mock): http_client_mock.stub_request( - "get", path=self.valid_path, rbody='{"error": {}}', rcode=429 + "get", path=self.v1_path, rbody='{"error": {}}', rcode=429 ) with pytest.raises(stripe.error.RateLimitError): - requestor.request("get", self.valid_path, {}, base_address="api") + requestor.request("get", self.v1_path, {}, base_address="api") def test_old_rate_limit_error(self, requestor, http_client_mock): """ @@ -682,29 +812,29 @@ def test_old_rate_limit_error(self, requestor, http_client_mock): """ http_client_mock.stub_request( "get", - path=self.valid_path, + path=self.v1_path, rbody='{"error": {"code":"rate_limit"}}', rcode=400, ) with pytest.raises(stripe.error.RateLimitError): - requestor.request("get", self.valid_path, {}, base_address="api") + requestor.request("get", self.v1_path, {}, base_address="api") def test_server_error(self, requestor, http_client_mock): http_client_mock.stub_request( - "get", path=self.valid_path, rbody='{"error": {}}', rcode=500 + "get", path=self.v1_path, rbody='{"error": {}}', rcode=500 ) with pytest.raises(stripe.error.APIError): - requestor.request("get", self.valid_path, {}, base_address="api") + requestor.request("get", self.v1_path, {}, base_address="api") def test_invalid_json(self, requestor, http_client_mock): http_client_mock.stub_request( - "get", path=self.valid_path, rbody="{", rcode=200 + "get", path=self.v1_path, rbody="{", rcode=200 ) with pytest.raises(stripe.error.APIError): - requestor.request("get", self.valid_path, {}, base_address="api") + requestor.request("get", self.v1_path, {}, base_address="api") def test_invalid_method(self, requestor): with pytest.raises(stripe.error.APIConnectionError): @@ -713,49 +843,49 @@ def test_invalid_method(self, requestor): def test_oauth_invalid_requestor_error(self, requestor, http_client_mock): http_client_mock.stub_request( "get", - path=self.valid_path, + path=self.v1_path, rbody='{"error": "invalid_request"}', rcode=400, ) with pytest.raises(stripe.oauth_error.InvalidRequestError): - requestor.request("get", self.valid_path, {}, base_address="api") + requestor.request("get", self.v1_path, {}, base_address="api") def test_invalid_client_error(self, requestor, http_client_mock): http_client_mock.stub_request( "get", - path=self.valid_path, + path=self.v1_path, rbody='{"error": "invalid_client"}', rcode=401, ) with pytest.raises(stripe.oauth_error.InvalidClientError): - requestor.request("get", self.valid_path, {}, base_address="api") + requestor.request("get", self.v1_path, {}, base_address="api") def test_invalid_grant_error(self, requestor, http_client_mock): http_client_mock.stub_request( "get", - path=self.valid_path, + path=self.v1_path, rbody='{"error": "invalid_grant"}', rcode=400, ) with pytest.raises(stripe.oauth_error.InvalidGrantError): - requestor.request("get", self.valid_path, {}, base_address="api") + requestor.request("get", self.v1_path, {}, base_address="api") def test_extract_error_from_stream_request_for_bytes( self, requestor, http_client_mock ): http_client_mock.stub_request( "get", - path=self.valid_path, + path=self.v1_path, rbody=util.io.BytesIO(b'{"error": "invalid_grant"}'), rcode=400, ) with pytest.raises(stripe.oauth_error.InvalidGrantError): requestor.request_stream( - "get", self.valid_path, {}, base_address="api" + "get", self.v1_path, {}, base_address="api" ) def test_extract_error_from_stream_request_for_response( @@ -764,7 +894,7 @@ def test_extract_error_from_stream_request_for_response( # Responses don't have getvalue, they only have a read method. http_client_mock.stub_request( "get", - path=self.valid_path, + path=self.v1_path, rbody=urllib3.response.HTTPResponse( body=util.io.BytesIO(b'{"error": "invalid_grant"}'), preload_content=False, @@ -774,20 +904,20 @@ def test_extract_error_from_stream_request_for_response( with pytest.raises(stripe.oauth_error.InvalidGrantError): requestor.request_stream( - "get", self.valid_path, {}, base_address="api" + "get", self.v1_path, {}, base_address="api" ) def test_raw_request_with_file_param(self, requestor, http_client_mock): test_file = tempfile.NamedTemporaryFile() test_file.write("\u263a".encode("utf-16")) test_file.seek(0) - meth = "post" + method = "post" path = "/v1/files" params = {"file": test_file, "purpose": "dispute_evidence"} supplied_headers = {"Content-Type": "multipart/form-data"} - http_client_mock.stub_request(meth, path=path, rbody="{}", rcode=200) + http_client_mock.stub_request(method, path=path, rbody="{}", rcode=200) requestor.request( - meth, + method, path, params, supplied_headers, diff --git a/tests/test_generated_examples.py b/tests/test_generated_examples.py index 826a79ff9..3676af8ab 100644 --- a/tests/test_generated_examples.py +++ b/tests/test_generated_examples.py @@ -4460,6 +4460,26 @@ async def test_checkout_sessions_post_2_service_async( post_data="success_url=https%3A%2F%2Fexample.com%2Fsuccess&line_items[0][price]=price_xxxxxxxxxxxxx&line_items[0][quantity]=2&mode=payment", ) + def test_core_events_get_service( + self, http_client_mock: HTTPClientMock + ) -> None: + http_client_mock.stub_request( + "get", + "/v2/core/events/ll_123", + ) + client = StripeClient( + "sk_test_123", + http_client=http_client_mock.get_mock_http_client(), + ) + + client.v2.core.events.retrieve("ll_123") + http_client_mock.assert_requested( + "get", + path="/v2/core/events/ll_123", + query_string="", + api_base="https://api.stripe.com", + ) + def test_country_specs_get(self, http_client_mock: HTTPClientMock) -> None: stripe.CountrySpec.list(limit=3) http_client_mock.assert_requested( diff --git a/tests/test_http_client.py b/tests/test_http_client.py index d6cbbbdf5..b671b1e7a 100644 --- a/tests/test_http_client.py +++ b/tests/test_http_client.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, List from typing_extensions import Type from unittest.mock import call import pytest @@ -96,11 +96,13 @@ def test_new_http_client_async_fallback_no_import_found( class TestRetrySleepTimeDefaultHttpClient(StripeClientTestCase): from contextlib import contextmanager - def assert_sleep_times(self, client, expected): - until = len(expected) - actual = list( - map(lambda i: client._sleep_time_seconds(i + 1), range(until)) - ) + def assert_sleep_times( + self, client: _http_client.HTTPClient, expected: List[float] + ): + # the sleep duration for a request after N retries + actual = [ + client._sleep_time_seconds(i + 1) for i in range(len(expected)) + ] assert expected == actual @contextmanager @@ -128,7 +130,7 @@ def test_maximum_delay(self): client = _http_client.new_default_http_client() client._add_jitter_time = lambda sleep_seconds: sleep_seconds max_delay = _http_client.HTTPClient.MAX_DELAY - expected = [0.5, 1.0, max_delay, max_delay, max_delay] + expected = [0.5, 1.0, 2.0, 4.0, max_delay, max_delay, max_delay] self.assert_sleep_times(client, expected) def test_retry_after_header(self): @@ -1090,7 +1092,7 @@ class TestAPIEncode(StripeClientTestCase): def test_encode_dict(self): body = {"foo": {"dob": {"month": 1}, "name": "bat"}} - values = [t for t in _api_encode(body)] + values = [t for t in _api_encode(body, "V1")] assert ("foo[dob][month]", 1) in values assert ("foo[name]", "bat") in values @@ -1098,11 +1100,19 @@ def test_encode_dict(self): def test_encode_array(self): body = {"foo": [{"dob": {"month": 1}, "name": "bat"}]} - values = [t for t in _api_encode(body)] + values = [t for t in _api_encode(body, "V1")] assert ("foo[0][dob][month]", 1) in values assert ("foo[0][name]", "bat") in values + def test_encode_v2_array(self): + body = {"foo": [{"dob": {"month": 1}, "name": "bat"}]} + + values = [t for t in _api_encode(body, "V2")] + + assert ("foo[dob][month]", 1) in values + assert ("foo[name]", "bat") in values + class TestHTTPXClient(StripeClientTestCase, ClientTestBase): REQUEST_CLIENT: Type[_http_client.HTTPXClient] = _http_client.HTTPXClient diff --git a/tests/test_integration.py b/tests/test_integration.py index a96f9b533..843937541 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -13,6 +13,8 @@ from collections import defaultdict from typing import List, Dict, Tuple, Optional +from stripe._stripe_client import StripeClient + if platform.python_implementation() == "PyPy": pytest.skip("skip integration tests with PyPy", allow_module_level=True) @@ -101,7 +103,6 @@ def setup_stripe(self): stripe._default_proxy = None stripe.enable_telemetry = False stripe.max_network_retries = 3 - stripe.proxy = None yield stripe.api_base = orig_attrs["api_base"] stripe.upload_api_base = orig_attrs["api_base"] @@ -348,7 +349,9 @@ async def async_http_client(self, request, anyio_backend): async def set_global_async_http_client(self, async_http_client): stripe.default_http_client = async_http_client - async def test_async_success(self, set_global_async_http_client): + async def test_async_raw_request_success( + self, set_global_async_http_client + ): class MockServerRequestHandler(MyTestHandler): default_body = '{"id": "cus_123", "object": "customer"}'.encode( "utf-8" @@ -357,11 +360,16 @@ class MockServerRequestHandler(MyTestHandler): self.setup_mock_server(MockServerRequestHandler) - stripe.api_base = "http://localhost:%s" % self.mock_server_port - - cus = await stripe.Customer.create_async( - description="My test customer" + client = StripeClient( + "sk_test_123", + base_addresses={ + "api": "http://localhost:%s" % self.mock_server_port + }, + ) + resp = await client.raw_request_async( + "post", "/v1/customers", description="My test customer" ) + cus = client.deserialize(resp.data, api_mode="V1") reqs = MockServerRequestHandler.get_requests(1) req = reqs[0] @@ -370,14 +378,15 @@ class MockServerRequestHandler(MyTestHandler): assert req.command == "POST" assert isinstance(cus, stripe.Customer) - async def test_async_timeout(self, set_global_async_http_client): + async def test_async_raw_request_timeout( + self, set_global_async_http_client + ): class MockServerRequestHandler(MyTestHandler): def do_request(self, n): time.sleep(0.02) return super().do_request(n) self.setup_mock_server(MockServerRequestHandler) - stripe.api_base = "http://localhost:%s" % self.mock_server_port # If we set HTTPX's generic timeout the test is flaky (sometimes it's a ReadTimeout, sometimes its a ConnectTimeout) # so we set only the read timeout specifically. hc = stripe.default_http_client @@ -391,11 +400,20 @@ def do_request(self, n): expected_message = "A ServerTimeoutError was raised" else: raise ValueError(f"Unknown http client: {hc.name}") - stripe.max_network_retries = 0 exception = None try: - await stripe.Customer.create_async(description="My test customer") + client = StripeClient( + "sk_test_123", + http_client=hc, + base_addresses={ + "api": "http://localhost:%s" % self.mock_server_port + }, + max_network_retries=0, + ) + await client.raw_request_async( + "post", "/v1/customers", description="My test customer" + ) except stripe.APIConnectionError as e: exception = e @@ -403,7 +421,9 @@ def do_request(self, n): assert expected_message in str(exception.user_message) - async def test_async_retries(self, set_global_async_http_client): + async def test_async_raw_request_retries( + self, set_global_async_http_client + ): class MockServerRequestHandler(MyTestHandler): def do_request(self, n): if n == 0: @@ -417,16 +437,26 @@ def do_request(self, n): pass self.setup_mock_server(MockServerRequestHandler) - stripe.api_base = "http://localhost:%s" % self.mock_server_port - await stripe.Customer.create_async(description="My test customer") + client = StripeClient( + "sk_test_123", + base_addresses={ + "api": "http://localhost:%s" % self.mock_server_port + }, + max_network_retries=stripe.max_network_retries, + ) + await client.raw_request_async( + "post", "/v1/customers", description="My test customer" + ) reqs = MockServerRequestHandler.get_requests(2) req = reqs[0] assert req.path == "/v1/customers" - async def test_async_unretryable(self, set_global_async_http_client): + async def test_async_raw_request_unretryable( + self, set_global_async_http_client + ): class MockServerRequestHandler(MyTestHandler): def do_request(self, n): return ( @@ -438,11 +468,18 @@ def do_request(self, n): pass self.setup_mock_server(MockServerRequestHandler) - stripe.api_base = "http://localhost:%s" % self.mock_server_port exception = None try: - await stripe.Customer.create_async(description="My test customer") + client = StripeClient( + "sk_test_123", + base_addresses={ + "api": "http://localhost:%s" % self.mock_server_port + }, + ) + await client.raw_request_async( + "post", "/v1/customers", description="My test customer" + ) except stripe.AuthenticationError as e: exception = e diff --git a/tests/test_raw_request.py b/tests/test_raw_request.py new file mode 100644 index 000000000..459b0d98b --- /dev/null +++ b/tests/test_raw_request.py @@ -0,0 +1,225 @@ +from __future__ import absolute_import, division, print_function + +import datetime + +import stripe + +from tests.test_api_requestor import GMT1 + + +class TestRawRequest(object): + ENCODE_INPUTS = { + "type": "standard", + "int": 123, + "datetime": datetime.datetime(2013, 1, 1, second=1, tzinfo=GMT1()), + } + POST_REL_URL = "/v1/accounts" + GET_REL_URL = "/v1/accounts/acct_123" + POST_REL_URL_V2 = "/v2/billing/meter_event_session" + GET_REL_URL_V2 = "/v2/accounts/acct_123" + + def test_form_request_get( + self, http_client_mock, stripe_mock_stripe_client + ): + http_client_mock.stub_request( + "get", + path=self.GET_REL_URL, + rbody='{"id": "acct_123", "object": "account"}', + rcode=200, + rheaders={}, + ) + + resp = stripe_mock_stripe_client.raw_request("get", self.GET_REL_URL) + http_client_mock.assert_requested("get", path=self.GET_REL_URL) + + deserialized = stripe_mock_stripe_client.deserialize( + resp, api_mode="V1" + ) + assert isinstance(deserialized, stripe.Account) + + def test_form_request_post( + self, http_client_mock, stripe_mock_stripe_client + ): + http_client_mock.stub_request( + "post", + path=self.POST_REL_URL, + rbody='{"id": "acct_123", "object": "account"}', + rcode=200, + rheaders={}, + ) + + expectation = "type=standard&int=123&datetime=1356994801" + + resp = stripe_mock_stripe_client.raw_request( + "post", self.POST_REL_URL, **self.ENCODE_INPUTS + ) + + http_client_mock.assert_requested( + "post", + path=self.POST_REL_URL, + content_type="application/x-www-form-urlencoded", + post_data=expectation, + ) + + deserialized = stripe_mock_stripe_client.deserialize( + resp, api_mode="V1" + ) + assert isinstance(deserialized, stripe.Account) + + def test_preview_request_post( + self, http_client_mock, stripe_mock_stripe_client + ): + http_client_mock.stub_request( + "post", + path=self.POST_REL_URL_V2, + rbody='{"id": "bmes_123", "object": "billing.meter_event_session"}', + rcode=200, + rheaders={}, + ) + + params = dict({}, **self.ENCODE_INPUTS) + expectation = ( + '{"type": "standard", "int": 123, "datetime": 1356994801}' + ) + + resp = stripe_mock_stripe_client.raw_request( + "post", self.POST_REL_URL_V2, **params + ) + + http_client_mock.assert_requested( + "post", + path=self.POST_REL_URL_V2, + content_type="application/json", + post_data=expectation, + is_json=True, + ) + + deserialized = stripe_mock_stripe_client.deserialize( + resp, api_mode="V2" + ) + assert isinstance(deserialized, stripe.v2.billing.MeterEventSession) + + def test_form_request_with_extra_headers( + self, http_client_mock, stripe_mock_stripe_client + ): + http_client_mock.stub_request( + "get", + path=self.GET_REL_URL, + rbody='{"id": "acct_123", "object": "account"}', + rcode=200, + rheaders={}, + ) + + extra_headers = {"foo": "bar", "Stripe-Account": "acct_123"} + params = {"headers": extra_headers} + + stripe_mock_stripe_client.raw_request( + "get", self.GET_REL_URL, **params + ) + + http_client_mock.assert_requested( + "get", + path=self.GET_REL_URL, + extra_headers=extra_headers, + ) + + def test_preview_request_default_api_version( + self, http_client_mock, stripe_mock_stripe_client + ): + http_client_mock.stub_request( + "get", + path=self.GET_REL_URL_V2, + rbody='{"id": "acct_123", "object": "account"}', + rcode=200, + rheaders={}, + ) + params = {} + + stripe_mock_stripe_client.raw_request( + "get", self.GET_REL_URL_V2, **params + ) + + http_client_mock.assert_requested( + "get", + path=self.GET_REL_URL_V2, + ) + + def test_preview_request_overridden_api_version( + self, http_client_mock, stripe_mock_stripe_client + ): + http_client_mock.stub_request( + "post", + path=self.POST_REL_URL_V2, + rbody='{"id": "acct_123", "object": "account"}', + rcode=200, + rheaders={}, + ) + stripe_version_override = "2023-05-15.preview" + params = { + "stripe_version": stripe_version_override, + } + + stripe_mock_stripe_client.raw_request( + "post", self.POST_REL_URL_V2, **params + ) + + http_client_mock.assert_requested( + "post", + path=self.POST_REL_URL_V2, + content_type="application/json", + stripe_version=stripe_version_override, + post_data="{}", + is_json=True, + ) + + # TODO(jar) this test is not applicable yet, but may be some day + # @pytest.mark.anyio + # async def test_form_request_get_async( + # self, http_client_mock, stripe_mock_stripe_client + # ): + # http_client_mock.stub_request( + # "get", + # path=self.GET_REL_URL, + # rbody='{"id": "acct_123", "object": "account"}', + # rcode=200, + # rheaders={}, + # ) + # + # resp = await stripe_mock_stripe_client.raw_request_async( + # "get", self.GET_REL_URL + # ) + # + # http_client_mock.assert_requested("get", path=self.GET_REL_URL) + # + # deserialized = stripe_mock_stripe_client.deserialize(resp) + # assert isinstance(deserialized, stripe.Account) + # + def test_raw_request_usage_reported( + self, http_client_mock, stripe_mock_stripe_client + ): + http_client_mock.stub_request( + "post", + path=self.POST_REL_URL, + rbody='{"id": "acct_123", "object": "account"}', + rcode=200, + rheaders={}, + ) + + expectation = "type=standard&int=123&datetime=1356994801" + + resp = stripe_mock_stripe_client.raw_request( + "post", self.POST_REL_URL, **self.ENCODE_INPUTS + ) + + http_client_mock.assert_requested( + "post", + path=self.POST_REL_URL, + content_type="application/x-www-form-urlencoded", + post_data=expectation, + usage=["raw_request"], + ) + + deserialized = stripe_mock_stripe_client.deserialize( + resp, api_mode="V1" + ) + assert isinstance(deserialized, stripe.Account) diff --git a/tests/test_request_options.py b/tests/test_request_options.py index b57995ba1..27d1fe026 100644 --- a/tests/test_request_options.py +++ b/tests/test_request_options.py @@ -42,6 +42,7 @@ def test_extract_from_dict(self): "api_key": "sk_test_123", "stripe_version": "2020-01-01", "stripe_account": "acct_123", + "stripe_context": "wksp_123", "idempotency_key": "idemp_123", "headers": { "X-Stripe-Header": "Some-Value", @@ -52,6 +53,7 @@ def test_extract_from_dict(self): assert options.get("api_key") == "sk_test_123" assert options.get("stripe_version") == "2020-01-01" assert options.get("stripe_account") == "acct_123" + assert options.get("stripe_context") == "wksp_123" assert options.get("idempotency_key") == "idemp_123" assert options.get("headers") == {"X-Stripe-Header": "Some-Value"} assert remaining == {"foo": "bar"} diff --git a/tests/test_requestor_options.py b/tests/test_requestor_options.py index 2ed3731ad..b818d39a8 100644 --- a/tests/test_requestor_options.py +++ b/tests/test_requestor_options.py @@ -10,6 +10,7 @@ def test_to_dict(self): requestor = RequestorOptions( api_key="sk_test_123", stripe_account="acct_123", + stripe_context="wksp_123", stripe_version="2019-12-03", base_addresses={ "api": "https://api.example.com", @@ -21,6 +22,7 @@ def test_to_dict(self): assert requestor.to_dict() == { "api_key": "sk_test_123", "stripe_account": "acct_123", + "stripe_context": "wksp_123", "stripe_version": "2019-12-03", "base_addresses": { "api": "https://api.example.com", @@ -38,16 +40,22 @@ def test_global_options_get_updated( orig_api_base = stripe.api_base orig_connect_base = stripe.connect_api_base orig_upload_base = stripe.upload_api_base + orig_meter_events_base = stripe.meter_events_api_base orig_max_network_retries = stripe.max_network_retries assert global_options.api_key == orig_api_key assert global_options.base_addresses["api"] == orig_api_base assert global_options.base_addresses["connect"] == orig_connect_base assert global_options.base_addresses["files"] == orig_upload_base + assert ( + global_options.base_addresses["meter_events"] + == orig_meter_events_base + ) assert global_options.stripe_account is None stripe.api_key = "sk_test_555555555" stripe.api_base = "https://api.example.com" stripe.connect_api_base = "https://connect.example.com" stripe.upload_api_base = "https://upload.example.com" + stripe.meter_events_api_base = "https://meter-events.example.com" stripe.max_network_retries = 3 assert global_options.api_key == "sk_test_555555555" assert ( @@ -61,10 +69,15 @@ def test_global_options_get_updated( global_options.base_addresses["files"] == "https://upload.example.com" ) + assert ( + global_options.base_addresses["meter_events"] + == "https://meter-events.example.com" + ) assert global_options.stripe_account is None assert global_options.max_network_retries == 3 stripe.api_key = orig_api_key stripe.api_base = orig_api_base stripe.connect_api_base = orig_connect_base stripe.upload_api_base = orig_upload_base + stripe.meter_events_api_base = orig_meter_events_base stripe.max_network_retries = orig_max_network_retries diff --git a/tests/test_stripe_client.py b/tests/test_stripe_client.py index 6abfe396b..a724fc9bb 100644 --- a/tests/test_stripe_client.py +++ b/tests/test_stripe_client.py @@ -2,7 +2,11 @@ import stripe import pytest +from stripe.v2._event import Event from stripe._http_client import new_default_http_client +from stripe.events._v1_billing_meter_error_report_triggered_event import ( + V1BillingMeterErrorReportTriggeredEvent, +) class TestStripeClient(object): @@ -28,6 +32,32 @@ def test_v1_customers_retrieve( http_client_mock.assert_requested(method, path=path) assert customer.id is not None + def test_v2_events_retrieve(self, http_client_mock): + method = "get" + path = "/v2/core/events/evt_123" + http_client_mock.stub_request( + method, + path=path, + rbody='{"id": "evt_123","object": "event", "type": "v1.billing.meter.error_report_triggered"}', + rcode=200, + rheaders={}, + ) + client = stripe.StripeClient( + api_key="keyinfo_test_123", + http_client=http_client_mock.get_mock_http_client(), + ) + event = client.v2.core.events.retrieve("evt_123") + + http_client_mock.assert_requested( + method, + api_base=stripe.DEFAULT_API_BASE, + path=path, + api_key="keyinfo_test_123", + ) + assert event.id is not None + assert isinstance(event, Event) + assert isinstance(event, V1BillingMeterErrorReportTriggeredEvent) + def test_no_api_key(self): with pytest.raises(stripe.error.AuthenticationError): stripe.StripeClient(None) # type: ignore @@ -61,12 +91,14 @@ def test_client_level_options(self, http_client_mock): api_base = "https://example.com" api_key = "sk_test_456" stripe_account = "acct_123" + stripe_context = "wksp_123" stripe_client = stripe.StripeClient( api_key=api_key, http_client=http_client_mock.get_mock_http_client(), base_addresses={"api": api_base}, stripe_account=stripe_account, + stripe_context=stripe_context, ) stripe_client.customers.retrieve("cus_xxxxxxxxxxxxx") @@ -77,6 +109,7 @@ def test_client_level_options(self, http_client_mock): path=path, api_key=api_key, stripe_account=stripe_account, + stripe_context=stripe_context, stripe_version=stripe.api_version, ) @@ -111,15 +144,18 @@ def test_request_level_options(self, http_client_mock): client_api_base = "https://example.com" client_api_key = "sk_test_456" client_stripe_account = "acct_123" + client_stripe_context = "wksp_123" request_api_key = "sk_test_789" request_stripe_account = "acct_456" + request_stripe_context = "wksp_456" stripe_client = stripe.StripeClient( api_key=client_api_key, http_client=http_client_mock.get_mock_http_client(), base_addresses={"api": client_api_base}, stripe_account=client_stripe_account, + stripe_context=client_stripe_context, ) stripe_client.customers.retrieve( @@ -127,6 +163,7 @@ def test_request_level_options(self, http_client_mock): options={ "api_key": request_api_key, "stripe_account": request_stripe_account, + "stripe_context": request_stripe_context, }, ) @@ -136,6 +173,7 @@ def test_request_level_options(self, http_client_mock): path=path, api_key=request_api_key, stripe_account=request_stripe_account, + stripe_context=request_stripe_context, stripe_version=stripe.api_version, ) @@ -179,6 +217,31 @@ def test_separate_clients_have_separate_options(self, http_client_mock): stripe_version=stripe.api_version, ) + def test_v2_encodes_none_as_null(self, http_client_mock): + http_client_mock.stub_request( + "post", + path="/v2/billing/meter_events", + rbody='{"event_name": "cool", "payload": {}, "identifier": null}', + rcode=200, + rheaders={}, + ) + + client = stripe.StripeClient( + api_key="sk_test_123", + http_client=http_client_mock.get_mock_http_client(), + ) + + client.v2.billing.meter_events.create( + {"event_name": "cool", "payload": {}, "identifier": None} # type: ignore - None is not valid for `identifier` + ) + + http_client_mock.assert_requested( + "post", + content_type="application/json", + post_data='{"event_name": "cool", "payload": {}, "identifier": null}', + is_json=True, + ) + def test_carries_over_requestor_options_to_resource( self, http_client_mock ): @@ -230,7 +293,8 @@ def test_user_options_are_not_mutated(self, http_client_mock): http_client_mock.stub_request( "get", - path="/v1/accounts", + path="/v2/core/events", + query_string="object_id=obj_123", rbody='{"data": [{"id": "x"}], "next_page": "page_2"}', rcode=200, rheaders={}, @@ -238,7 +302,9 @@ def test_user_options_are_not_mutated(self, http_client_mock): my_options: stripe.RequestOptions = {"api_key": "sk_test_xyz"} - client.accounts.list(options=my_options) + client.v2.core.events.list( + {"object_id": "obj_123"}, options=my_options + ) assert my_options == {"api_key": "sk_test_xyz"} diff --git a/tests/test_util.py b/tests/test_util.py index df045a749..93b75c84f 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -135,7 +135,7 @@ def test_convert_to_stripe_object_and_back(self): "livemode": False, } - obj = util.convert_to_stripe_object(resp) + obj = util.convert_to_stripe_object(resp, api_mode="V1") assert isinstance(obj, stripe.Balance) assert isinstance(obj.available, list) assert isinstance(obj.available[0], stripe.stripe_object.StripeObject) @@ -149,4 +149,6 @@ def test_convert_to_stripe_object_and_back(self): def test_sanitize_id(self): sanitized_id = util.sanitize_id("cu %x 123") + if isinstance(sanitized_id, bytes): + sanitized_id = sanitized_id.decode("utf-8", "strict") assert sanitized_id == "cu++%25x+123" diff --git a/tests/test_v2_error.py b/tests/test_v2_error.py new file mode 100644 index 000000000..f2828c377 --- /dev/null +++ b/tests/test_v2_error.py @@ -0,0 +1,141 @@ +from __future__ import absolute_import, division, print_function + +import json + +import pytest + +import stripe +from stripe import error +from tests.http_client_mock import HTTPClientMock + + +class TestV2Error(object): + @pytest.fixture(scope="function") + def stripe_client(self, http_client_mock): + return stripe.StripeClient( + api_key="keyinfo_test_123", + http_client=http_client_mock.get_mock_http_client(), + ) + + def test_raises_v2_error( + self, + stripe_client: stripe.StripeClient, + http_client_mock: HTTPClientMock, + ): + method = "get" + path = "/v2/core/events/evt_123" + + error_response = { + "error": { + "type": "temporary_session_expired", + "code": "session_bad", + "message": "you messed up", + } + } + http_client_mock.stub_request( + method, + path=path, + rbody=json.dumps(error_response), + rcode=400, + rheaders={}, + ) + + try: + stripe_client.v2.core.events.retrieve("evt_123") + except error.TemporarySessionExpiredError as e: + assert e.code == "session_bad" + assert e.error.code == "session_bad" + assert e.error.message == "you messed up" + else: + assert False, "Should have raised a TemporarySessionExpiredError" + + http_client_mock.assert_requested( + method, + path=path, + api_key="keyinfo_test_123", + ) + + @pytest.mark.skip("python doesn't have any errors with invalid params yet") + def test_raises_v2_error_with_field( + self, + stripe_client: stripe.StripeClient, + http_client_mock: HTTPClientMock, + ): + method = "post" + path = "/v2/payment_methods/us_bank_accounts" + + error_response = { + "error": { + "type": "invalid_payment_method", + "code": "invalid_us_bank_account", + "message": "bank account is invalid", + "invalid_param": "routing_number", + } + } + http_client_mock.stub_request( + method, + path=path, + rbody=json.dumps(error_response), + rcode=400, + rheaders={}, + ) + + try: + stripe_client.v2.payment_methods.us_bank_accounts.create( + params={"account_number": "123", "routing_number": "456"} + ) + except error.InvalidPaymentMethodError as e: + assert e.invalid_param == "routing_number" + assert e.error.code == "invalid_us_bank_account" + assert e.error.message == "bank account is invalid" + else: + assert False, "Should have raised a InvalidUsBankAccountError" + + http_client_mock.assert_requested( + method, + path=path, + api_key="keyinfo_test_123", + ) + + def test_falls_back_to_v1_error( + self, + stripe_client: stripe.StripeClient, + http_client_mock: HTTPClientMock, + ): + method = "post" + path = "/v2/billing/meter_events" + + error_response = { + "error": { + "code": "invalid_request", + "message": "your request is invalid", + "param": "invalid_param", + } + } + http_client_mock.stub_request( + method, + path=path, + rbody=json.dumps(error_response), + rcode=400, + rheaders={"request-id": "123"}, + ) + + try: + stripe_client.v2.billing.meter_events.create( + {"event_name": "asdf", "payload": {}} + ) + except error.InvalidRequestError as e: + assert e.param == "invalid_param" + assert repr(e) == ( + "InvalidRequestError(message='your request is invalid', " + "param='invalid_param', code='invalid_request', " + "http_status=400, request_id='123')" + ) + else: + assert False, "Should have raised a InvalidRequestError" + + http_client_mock.assert_requested( + method, + path=path, + api_key="keyinfo_test_123", + ) diff --git a/tests/test_v2_event.py b/tests/test_v2_event.py new file mode 100644 index 000000000..cfc94cf05 --- /dev/null +++ b/tests/test_v2_event.py @@ -0,0 +1,110 @@ +import json +from typing import Callable + +import pytest + +import stripe +from stripe import ThinEvent +from tests.test_webhook import DUMMY_WEBHOOK_SECRET, generate_header + +EventParser = Callable[[str], ThinEvent] + + +class TestV2Event(object): + @pytest.fixture(scope="function") + def v2_payload_no_data(self): + return json.dumps( + { + "id": "evt_234", + "object": "event", + "type": "financial_account.balance.opened", + "created": "2022-02-15T00:27:45.330Z", + "related_object": { + "id": "fa_123", + "type": "financial_account", + "url": "/v2/financial_accounts/fa_123", + "stripe_context": "acct_123", + }, + "reason": { + "id": "foo", + "idempotency_key": "bar", + }, + } + ) + + @pytest.fixture(scope="function") + def v2_payload_with_data(self): + return json.dumps( + { + "id": "evt_234", + "object": "event", + "type": "financial_account.balance.opened", + "created": "2022-02-15T00:27:45.330Z", + "related_object": { + "id": "fa_123", + "type": "financial_account", + "url": "/v2/financial_accounts/fa_123", + "stripe_context": "acct_123", + }, + "data": { + "containing_compartment_id": "compid", + "id": "foo", + "type": "bufo", + }, + } + ) + + @pytest.fixture(scope="function") + def stripe_client(self, http_client_mock): + return stripe.StripeClient( + api_key="keyinfo_test_123", + stripe_context="wksp_123", + http_client=http_client_mock.get_mock_http_client(), + ) + + @pytest.fixture(scope="function") + def parse_thin_event( + self, stripe_client: stripe.StripeClient + ) -> EventParser: + """ + helper to simplify parsing and validating events given a payload + returns a function that has the client pre-bound + """ + + def _parse_thin_event(payload: str): + return stripe_client.parse_thin_event( + payload, generate_header(payload=payload), DUMMY_WEBHOOK_SECRET + ) + + return _parse_thin_event + + def test_parses_thin_event( + self, parse_thin_event: EventParser, v2_payload_no_data: str + ): + event = parse_thin_event(v2_payload_no_data) + + assert isinstance(event, ThinEvent) + assert event.id == "evt_234" + + assert event.related_object + assert event.related_object.id == "fa_123" + + assert event.reason + assert event.reason.id == "foo" + + def test_parses_thin_event_with_data( + self, parse_thin_event: EventParser, v2_payload_with_data: str + ): + event = parse_thin_event(v2_payload_with_data) + + assert isinstance(event, ThinEvent) + assert not hasattr(event, "data") + assert event.reason is None + + def test_validates_signature( + self, stripe_client: stripe.StripeClient, v2_payload_no_data + ): + with pytest.raises(stripe.error.SignatureVerificationError): + stripe_client.parse_thin_event( + v2_payload_no_data, "bad header", DUMMY_WEBHOOK_SECRET + ) diff --git a/tests/test_webhook.py b/tests/test_webhook.py index 53389f725..8c190acb7 100644 --- a/tests/test_webhook.py +++ b/tests/test_webhook.py @@ -135,7 +135,7 @@ def test_timestamp_off_but_no_tolerance(self): class TestStripeClientConstructEvent(object): def test_construct_event(self, stripe_mock_stripe_client): header = generate_header() - event = stripe_mock_stripe_client.construct_event( + event = stripe_mock_stripe_client.parse_snapshot_event( DUMMY_WEBHOOK_PAYLOAD, header, DUMMY_WEBHOOK_SECRET ) assert isinstance(event, stripe.Event) @@ -144,21 +144,21 @@ def test_raise_on_json_error(self, stripe_mock_stripe_client): payload = "this is not valid JSON" header = generate_header(payload=payload) with pytest.raises(ValueError): - stripe_mock_stripe_client.construct_event( + stripe_mock_stripe_client.parse_snapshot_event( payload, header, DUMMY_WEBHOOK_SECRET ) def test_raise_on_invalid_header(self, stripe_mock_stripe_client): header = "bad_header" with pytest.raises(stripe.error.SignatureVerificationError): - stripe_mock_stripe_client.construct_event( + stripe_mock_stripe_client.parse_snapshot_event( DUMMY_WEBHOOK_PAYLOAD, header, DUMMY_WEBHOOK_SECRET ) def test_construct_event_from_bytearray(self, stripe_mock_stripe_client): header = generate_header() payload = bytearray(DUMMY_WEBHOOK_PAYLOAD, "utf-8") - event = stripe_mock_stripe_client.construct_event( + event = stripe_mock_stripe_client.parse_snapshot_event( payload, header, DUMMY_WEBHOOK_SECRET ) assert isinstance(event, stripe.Event) @@ -166,7 +166,7 @@ def test_construct_event_from_bytearray(self, stripe_mock_stripe_client): def test_construct_event_from_bytes(self, stripe_mock_stripe_client): header = generate_header() payload = bytes(DUMMY_WEBHOOK_PAYLOAD, "utf-8") - event = stripe_mock_stripe_client.construct_event( + event = stripe_mock_stripe_client.parse_snapshot_event( payload, header, DUMMY_WEBHOOK_SECRET ) assert isinstance(event, stripe.Event) @@ -181,7 +181,7 @@ def test_construct_event_inherits_requestor(self, http_client_mock): http_client=http_client_mock.get_mock_http_client(), ) header = generate_header() - event = client.construct_event( + event = client.parse_snapshot_event( DUMMY_WEBHOOK_PAYLOAD, header, DUMMY_WEBHOOK_SECRET ) assert event._requestor == client._requestor