Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add disable SSL verification #1262

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions mock_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from werkzeug.wrappers import Request, Response

import weaviate
from mock_tests.mock_data import mock_class
from weaviate.connect.base import ConnectionParams, ProtocolParams
from weaviate.proto.v1 import (
batch_pb2,
Expand All @@ -21,8 +22,6 @@
weaviate_pb2_grpc,
)

from mock_tests.mock_data import mock_class

MOCK_IP = "127.0.0.1"
MOCK_PORT = 23536
MOCK_PORT_GRPC = 23537
Expand Down Expand Up @@ -105,18 +104,17 @@ def slow_post(request: Request) -> Response:
yield weaviate_no_auth_mock


# Implement the health check service
class MockHealthServicer(HealthServicer):
def Check(self, request: HealthCheckRequest, context: ServicerContext) -> HealthCheckResponse:
return HealthCheckResponse(status=HealthCheckResponse.SERVING)


@pytest.fixture(scope="function")
def start_grpc_server() -> Generator[grpc.Server, None, None]:
# Create a gRPC server
server: grpc.Server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))

# Implement the health check service
class MockHealthServicer(HealthServicer):
def Check(
self, request: HealthCheckRequest, context: ServicerContext
) -> HealthCheckResponse:
return HealthCheckResponse(status=HealthCheckResponse.SERVING)

# Add the health check service to the server
add_HealthServicer_to_server(MockHealthServicer(), server)

Expand Down
125 changes: 125 additions & 0 deletions mock_tests/test_ssl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import json
import ssl
from concurrent import futures
from typing import Iterable

import grpc
import pytest
import trustme
from grpc_health.v1.health_pb2_grpc import add_HealthServicer_to_server
from pytest_httpserver import HTTPServer
from werkzeug.wrappers import Response

import weaviate
from mock_tests.conftest import MockHealthServicer, MOCK_IP, MOCK_PORT_GRPC

SERVER = "127.0.0.1"
MOCK_PORT_GRPC_SSL = 23538
PORT = 23539


@pytest.fixture(scope="session")
def httpserver_ssl_context():
ca = trustme.CA()
server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
server_cert = ca.issue_cert(SERVER)
server_cert.configure_cert(server_context)

return server_context


@pytest.fixture(scope="session")
def make_httpserver(httpserver_ssl_context) -> Iterable[HTTPServer]:
server = HTTPServer(host=SERVER, port=PORT, ssl_context=httpserver_ssl_context)
server.start()
yield server
server.clear()
if server.is_running():
server.stop()


@pytest.fixture(scope="module")
def start_grpc_server_ssl() -> grpc.Server:
# Create a gRPC server
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))

# Add the health check service to the server
add_HealthServicer_to_server(MockHealthServicer(), server)

# Create server credentials using the SSL context
ca = trustme.CA()
server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
server_cert = ca.issue_cert(SERVER)
server_cert.configure_cert(server_context)
server_credentials = grpc.ssl_server_credentials(
[(server_cert.private_key_pem.bytes(), server_cert.cert_chain_pems[0].bytes())]
)

# Listen on a specific port with SSL
server.add_secure_port(f"[::]:{MOCK_PORT_GRPC_SSL}", server_credentials)
server.start()

yield server

# Teardown - stop the server
server.stop(0)


def test_disable_ssl_verification(
make_httpserver: HTTPServer, start_grpc_server_ssl: grpc.Server, start_grpc_server: grpc.Server
):
make_httpserver.expect_request("/v1/.well-known/ready").respond_with_json({})
make_httpserver.expect_request("/v1/meta").respond_with_json({"version": "1.24"})
make_httpserver.expect_request("/v1/nodes").respond_with_json({"nodes": [{"gitHash": "ABC"}]})
make_httpserver.expect_request("/v1/.well-known/openid-configuration").respond_with_response(
Response(json.dumps({}), status=404)
)

assert make_httpserver.port == PORT
assert make_httpserver.host == SERVER

# test http connection with ssl
with pytest.raises(weaviate.exceptions.WeaviateConnectionError):
weaviate.connect_to_custom(
http_port=PORT,
http_host=SERVER,
grpc_port=MOCK_PORT_GRPC,
http_secure=True,
grpc_host=MOCK_IP,
grpc_secure=False,
)

# test grpc connection with ssl
with pytest.raises(weaviate.exceptions.WeaviateConnectionError):
weaviate.connect_to_custom(
http_port=PORT,
http_host=SERVER,
grpc_port=MOCK_PORT_GRPC_SSL,
http_secure=True,
grpc_host=SERVER,
grpc_secure=True,
)

# test http connection with ssl and verify disabled
weaviate.connect_to_custom(
http_port=PORT,
http_host=SERVER,
grpc_port=MOCK_PORT_GRPC,
http_secure=True,
grpc_host=MOCK_IP,
grpc_secure=False,
additional_config=weaviate.config.AdditionalConfig(disable_ssl_verification=True),
)

# test grpc connection with ssl and verify disabled
weaviate.connect_to_custom(
http_port=PORT,
http_host=SERVER,
grpc_port=MOCK_PORT_GRPC_SSL,
http_secure=True,
grpc_host=SERVER,
grpc_secure=True,
additional_config=weaviate.config.AdditionalConfig(disable_ssl_verification=True),
)

make_httpserver.check_assertions()
1 change: 1 addition & 0 deletions requirements-devel.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pytest-xdist==3.6.1
werkzeug==3.0.3
pytest-httpserver==1.0.12
py-spy==0.3.14
trustme>=1.1.0

numpy>=1.24.4,<3.0.0
pandas>=2.0.3,<3.0.0
Expand Down
2 changes: 2 additions & 0 deletions test/collection/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest

from weaviate.config import ConnectionConfig
from weaviate.connect import ConnectionV4, ConnectionParams

Expand All @@ -11,6 +12,7 @@ def connection() -> ConnectionV4:
(10, 60),
None,
True,
False,
None,
ConnectionConfig(),
None,
Expand Down
11 changes: 4 additions & 7 deletions weaviate/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,20 @@
from httpx import HTTPError as HttpxError
from requests.exceptions import ConnectionError as RequestsConnectionError

from weaviate import syncify
from weaviate.backup.backup import _BackupAsync
from weaviate.backup.sync import _Backup


from weaviate import syncify
from weaviate.event_loop import _EventLoopSingleton, _EventLoop
from .auth import AuthCredentials
from .backup import Backup
from .batch import Batch
from .classification import Classification

from .client_base import _WeaviateClientBase
from .cluster import Cluster
from .collections.collections.async_ import _CollectionsAsync
from .collections.collections.sync import _Collections
from .collections.batch.client import _BatchClientWrapper
from .collections.cluster import _Cluster, _ClusterAsync
from .collections.collections.async_ import _CollectionsAsync
from .collections.collections.sync import _Collections
from .config import AdditionalConfig, Config
from .connect import Connection
from .connect.base import (
Expand All @@ -40,7 +38,6 @@
)
from .gql import Query
from .schema import Schema
from weaviate.event_loop import _EventLoopSingleton, _EventLoop
from .types import NUMBER
from .util import _get_valid_timeout_config, _type_request_response
from .warnings import _Warnings
Expand Down
4 changes: 1 addition & 3 deletions weaviate/client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@
import asyncio
from typing import Optional, Tuple, Union, Dict, Any


from weaviate.collections.classes.internal import _GQLEntryReturnType, _RawGQLReturn

from weaviate.integrations import _Integrations

from .auth import AuthCredentials
from .config import AdditionalConfig
from .connect import ConnectionV4
Expand Down Expand Up @@ -83,6 +80,7 @@ def __init__(
proxies=config.proxies,
trust_env=config.trust_env,
loop=self._loop,
disable_ssl_verification=config.disable_ssl_verification,
)

self.integrations = _Integrations(self._connection)
Expand Down
1 change: 1 addition & 0 deletions weaviate/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class AdditionalConfig(BaseModel):
proxies: Union[str, Proxies, None] = Field(default=None)
timeout_: Union[Tuple[int, int], Timeout] = Field(default_factory=Timeout, alias="timeout")
trust_env: bool = Field(default=False)
disable_ssl_verification: bool = Field(default=False)

@property
def timeout(self) -> Timeout:
Expand Down
39 changes: 34 additions & 5 deletions weaviate/connect/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import datetime
import os
import socket
import ssl
import time
from abc import ABC, abstractmethod
from typing import Dict, Tuple, TypeVar, Union, cast
Expand All @@ -8,14 +10,13 @@
import grpc # type: ignore
from grpc import ssl_channel_credentials
from grpc.aio import Channel # type: ignore

# from grpclib.client import Channel

from pydantic import BaseModel, field_validator, model_validator

from weaviate.config import Proxies
from weaviate.types import NUMBER

# from grpclib.client import Channel


JSONPayload = Union[dict, list]
TIMEOUT_TYPE_RETURN = Tuple[NUMBER, NUMBER]
Expand Down Expand Up @@ -111,15 +112,43 @@ def _grpc_address(self) -> Tuple[str, int]:
def _grpc_target(self) -> str:
return f"{self.grpc.host}:{self.grpc.port}"

def _grpc_channel(self, proxies: Dict[str, str]) -> Channel:
def _grpc_channel(self, proxies: Dict[str, str], enable_ssl_verification: bool) -> Channel:
if (p := proxies.get("grpc")) is not None:
options: list = [*GRPC_DEFAULT_OPTIONS, ("grpc.http_proxy", p)]
else:
options = GRPC_DEFAULT_OPTIONS
if self.grpc.secure:
if enable_ssl_verification:
creds = ssl_channel_credentials()
else:
import logging

logging.basicConfig(level=logging.DEBUG)

# download certificate from server. This is super hacky, but the grpc library does NOT offer a way to
# disable certificate verification. There are probably a number of edge cases that this does not cover.
context = ssl.create_default_context()
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
targets = self.grpc.host.replace("http://", "", 1).split(":")

with socket.create_connection((targets[0], self.grpc.port)) as sock:
with context.wrap_socket(
sock, server_hostname=self._grpc_target
) as secure_sock:
cert_binary = secure_sock.getpeercert(binary_form=True)
if cert_binary is None:
raise ValueError(
"Failed to retrieve the server certificate to bypass ssl verification."
)

cert = ssl.DER_cert_to_PEM_cert(cert_binary)

creds = ssl_channel_credentials(root_certificates=cert.encode())

return grpc.aio.secure_channel(
target=self._grpc_target,
credentials=ssl_channel_credentials(),
credentials=creds,
options=options,
)
else:
Expand Down
8 changes: 7 additions & 1 deletion weaviate/connect/v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def __init__(
timeout_config: TimeoutConfig,
proxies: Union[str, Proxies, None],
trust_env: bool,
disable_ssl_verification: bool,
additional_headers: Optional[Dict[str, Any]],
connection_config: ConnectionConfig,
loop: asyncio.AbstractEventLoop, # required for background token refresh
Expand All @@ -115,6 +116,7 @@ def __init__(
self.timeout_config = timeout_config
self.__connection_config = connection_config
self.__trust_env = trust_env
self.__enable_ssl_verification = not disable_ssl_verification
self._weaviate_version = _ServerVersion.from_string("")
self.__connected = False
self.__loop = loop
Expand Down Expand Up @@ -211,6 +213,7 @@ def __make_mounts(self) -> Dict[str, AsyncHTTPTransport]:
proxy=Proxy(url=proxy),
retries=self.__connection_config.session_pool_max_retries,
trust_env=self.__trust_env,
verify=self.__enable_ssl_verification,
)
for key, proxy in self._proxies.items()
if key != "grpc"
Expand All @@ -221,6 +224,7 @@ def __make_async_client(self) -> AsyncClient:
headers=self._headers,
mounts=self.__make_mounts(),
trust_env=self.__trust_env,
verify=self.__enable_ssl_verification,
)

def __make_clients(self) -> None:
Expand All @@ -229,7 +233,9 @@ def __make_clients(self) -> None:
async def _open_connections(
self, auth_client_secret: Optional[AuthCredentials], skip_init_checks: bool
) -> None:
self._grpc_channel = self._connection_params._grpc_channel(proxies=self._proxies)
self._grpc_channel = self._connection_params._grpc_channel(
proxies=self._proxies, enable_ssl_verification=self.__enable_ssl_verification
)
assert self._grpc_channel is not None
self._grpc_stub = weaviate_pb2_grpc.WeaviateStub(self._grpc_channel)

Expand Down
Loading