From bbe8789f3cfeceeed95082ef39a8e3fccecd0032 Mon Sep 17 00:00:00 2001 From: Bala FA Date: Tue, 12 Dec 2023 04:20:33 +0530 Subject: [PATCH] fix typing in helpers.py (#1373) Signed-off-by: Bala.FA --- minio/helpers.py | 75 +++++++++++++++++++++--------------------------- 1 file changed, 33 insertions(+), 42 deletions(-) diff --git a/minio/helpers.py b/minio/helpers.py index cde5bb463..290950cc7 100644 --- a/minio/helpers.py +++ b/minio/helpers.py @@ -26,11 +26,12 @@ import platform import re import urllib.parse -from abc import ABCMeta +from datetime import datetime from queue import Queue from threading import BoundedSemaphore, Thread -from typing import BinaryIO +from typing import BinaryIO, Mapping +from typing_extensions import Protocol from urllib3._collections import HTTPHeaderDict from . import __title__, __version__ @@ -108,30 +109,21 @@ def queryencode( def headers_to_strings( - headers: dict[str, str], + headers: Mapping[str, str | list[str] | tuple[str]], titled_key: bool = False, ) -> str: """Convert HTTP headers to multi-line string.""" - def _get_key(key: str) -> str: - return key.title() if titled_key else key - - def _get_value(value: str) -> str: - return re.sub( - r"Credential=([^/]+)", - "Credential=*REDACTED*", - re.sub( - r"Signature=([0-9a-f]+)", - "Signature=*REDACTED*", - value if isinstance(value, str) else str(value), - ), - ) if titled_key else value - - return "\n".join( - [ - f"{_get_key(key)}: {_get_value(value)}" - for key, value in headers.items() - ] - ) + values = [] + for key, value in headers.items(): + key = key.title() if titled_key else key + for item in value if isinstance(value, (list, tuple)) else [value]: + item = re.sub( + r"Credential=([^/]+)", + "Credential=*REDACTED*", + re.sub(r"Signature=([0-9a-f]+)", "Signature=*REDACTED*", item), + ) if titled_key else item + values.append(f"{key}: {item}") + return "\n".join(values) def _validate_sizes(object_size: int, part_size: int): @@ -186,22 +178,21 @@ def get_part_info(object_size: int, part_size: int) -> tuple[int, int]: return part_size, part_count -class Progress: - """Progress base class for put object API.""" - __metaclass__ = ABCMeta +class ProgressType(Protocol): + """typing stub for Put/Get object progress.""" - def set_meta(self, total_length: int, object_name: str): - """Set object information to progress.""" + def set_meta(self, object_name: str, total_length: int): + """Set process meta information.""" - def update(self, size: int): - """Update current progress size.""" + def update(self, length: int): + """Set current progress length.""" def read_part_data( stream: BinaryIO, size: int, part_data: bytes = b"", - progress: Progress | None = None, + progress: ProgressType | None = None, ) -> bytes: """Read part data of given size from stream.""" size -= len(part_data) @@ -290,13 +281,13 @@ def is_valid_policy_type(policy: str | bytes): return True -def check_ssec(sse: SseCustomerKey): +def check_ssec(sse: SseCustomerKey | None): """Check sse is SseCustomerKey type or not.""" if sse and not isinstance(sse, SseCustomerKey): raise ValueError("SseCustomerKey type is required") -def check_sse(sse: Sse): +def check_sse(sse: Sse | None): """Check sse is Sse type or not.""" if sse and not isinstance(sse, Sse): raise ValueError("Sse type is required") @@ -345,7 +336,7 @@ def url_replace( def _metadata_to_headers( - metadata: dict[str, str | list | tuple], + metadata: dict[str, str | list[str] | tuple[str]], ) -> dict[str, list[str]]: """Convert user metadata to headers.""" def normalize_key(key: str) -> str: @@ -364,7 +355,7 @@ def to_string(value) -> str: ) from exc return value - def normalize_value(values: str | list | tuple) -> list[str]: + def normalize_value(values: str | list[str] | tuple[str]) -> list[str]: if not isinstance(values, (list, tuple)): values = [values] return [to_string(value) for value in values] @@ -376,8 +367,8 @@ def normalize_value(values: str | list | tuple) -> list[str]: def normalize_headers( - headers: dict[str, str | list | tuple], -) -> dict[str, str | list | tuple]: + headers: dict[str, str | list[str] | tuple[str]] | None, +) -> dict[str, str | list[str] | tuple[str]]: """Normalize headers by prefixing 'X-Amz-Meta-' for user metadata.""" headers = {str(key): value for key, value in (headers or {}).items()} @@ -407,12 +398,12 @@ def guess_user_metadata(key: str) -> bool: def genheaders( - headers: dict[str, str | list | tuple], + headers: dict[str, str | list[str] | tuple[str]] | None, sse: Sse | None, tags: dict[str, str] | None, retention, legal_hold: bool, -) -> dict[str, str | list | tuple]: +) -> dict[str, str | list[str] | tuple[str]]: """Generate headers for given parameters.""" headers = normalize_headers(headers) headers.update(sse.headers() if sse else {}) @@ -648,7 +639,7 @@ def _build_aws_url( netloc = s3_prefix if "s3-accelerate" in s3_prefix: - if "." in bucket_name: + if "." in (bucket_name or ""): raise ValueError( f"bucket name '{bucket_name}' with '.' is not allowed " f"for accelerate endpoint" @@ -755,7 +746,7 @@ def __init__( version_id: str | None, etag: str | None, http_headers: HTTPHeaderDict, - last_modified: str | None = None, + last_modified: datetime | None = None, location: str | None = None, ): self._bucket_name = bucket_name @@ -792,7 +783,7 @@ def http_headers(self) -> HTTPHeaderDict: return self._http_headers @property - def last_modified(self) -> str | None: + def last_modified(self) -> datetime | None: """Get last-modified time.""" return self._last_modified