Skip to content

Commit

Permalink
Merge pull request #39 from weavel-ai/dev
Browse files Browse the repository at this point in the history
fix: fix WebSocketClient initialization, StrEnum bugs
  • Loading branch information
aschung01 authored Oct 9, 2024
2 parents aae11b2 + 570ba09 commit 076c5e6
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 26 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

setup(
name="weavel",
version="1.11.0",
version="1.11.1"
packages=find_namespace_packages(),
entry_points={},
description="Weavel, automated prompt engineering and observability for LLM applications",
Expand Down
15 changes: 7 additions & 8 deletions weavel/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
WsLocalGlobalMetricRequest,
WsLocalMetricRequest,
WsLocalTask,
WsServerOptimizeResponse,
WsServerTask,
)

Expand Down Expand Up @@ -90,7 +89,7 @@ def __init__(
flush_interval=flush_interval,
flush_batch_size=flush_batch_size,
)
self.ws_client = WebsocketClient(api_key=api_key, base_url=base_url)
self.ws_client = WebsocketClient(api_key=self.api_key, base_url=self.base_url)
self._generator_var: contextvars.ContextVar[Optional[BaseGenerator]] = (
contextvars.ContextVar("generator")
)
Expand Down Expand Up @@ -946,15 +945,15 @@ def _get_global_metric(self) -> Optional[BaseGlobalMetric]:
def _set_global_metric(self, global_metric: Optional[BaseGlobalMetric]):
self._global_metric_var.set(global_metric)

@websocket_handler(WsLocalTask.GENERATE.value)
@websocket_handler(WsLocalTask.GENERATE)
async def handle_generation_request(self, data: WsLocalGenerateRequest):
logger.debug("Handling generation request...")
generator = self._get_generator()
if not generator:
raise AttributeError("Generate not set")
return await generator(prompt=Prompt(**data["prompt"]), inputs=data["inputs"])

@websocket_handler(WsLocalTask.EVALUATE.value)
@websocket_handler(WsLocalTask.EVALUATE)
async def handle_evaluation_request(
self, data: WsLocalEvaluateRequest
) -> WsLocalEvaluateResponse:
Expand Down Expand Up @@ -988,7 +987,7 @@ async def handle_evaluation_request(
"global_result": global_result.model_dump(),
}

@websocket_handler(WsLocalTask.METRIC.value)
@websocket_handler(WsLocalTask.METRIC)
async def handle_metric_request(self, data: WsLocalMetricRequest):
logger.debug("Handling metric request...")
metric = self._get_metric()
Expand All @@ -997,7 +996,7 @@ async def handle_metric_request(self, data: WsLocalMetricRequest):
res = await metric(dataset_item=data["dataset_item"], pred=data["pred"])
return res.model_dump()

@websocket_handler(WsLocalTask.GLOBAL_METRIC.value)
@websocket_handler(WsLocalTask.GLOBAL_METRIC)
async def handle_global_metric_request(self, data: WsLocalGlobalMetricRequest):
logger.debug("Handling global metric request...")
global_metric = self._get_global_metric()
Expand All @@ -1007,7 +1006,7 @@ async def handle_global_metric_request(self, data: WsLocalGlobalMetricRequest):
res = await global_metric(results=results)
return res.model_dump()

@websocket_handler(WsServerTask.OPTIMIZE.value)
@websocket_handler(WsServerTask.OPTIMIZE)
async def handle_optimization_result(self, data: Dict[str, Any]):
# Extract the correlation_id from the response data
correlation_id = data.get("correlation_id")
Expand Down Expand Up @@ -1101,7 +1100,7 @@ async def optimize(
dataset = await self.acreate_dataset(name=dataset_name)
dataset_created = True
dataset_items = [
WvDatasetItem(inputs=item["inputs"], outputs=item["outputs"])
WvDatasetItem(inputs=item["inputs"], outputs=item.get("outputs", None))
for item in trainset
]
await self.acreate_dataset_items(dataset_name, dataset_items)
Expand Down
16 changes: 7 additions & 9 deletions weavel/clients/websocket_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ async def connect_to_gateway(self):
self.endpoint,
extra_headers=headers,
ping_interval=30,
ping_timeout=10,
close_timeout=60,
ping_timeout=60,
close_timeout=60 * 5,
)
logger.info("WebSocket connection established")
self.heartbeat_task = asyncio.create_task(self.heartbeat())
Expand Down Expand Up @@ -197,10 +197,6 @@ async def _process_message(self, message: str):
else:
logger.warning(f"Unknown message type: {message_type}")

# Optionally handle reset events or other logic here
if message_type in self.relevant_message_types():
logger.debug(f"Ignoring reset_event for {message_type}")

except Exception:
logger.exception("Error processing message")

Expand All @@ -210,8 +206,10 @@ def relevant_message_types(self) -> List[str]:
Add all relevant message types that should reset the timeout here.
"""
return [
WsLocalTask.GENERATE.value,
WsLocalTask.EVALUATE.value,
WsLocalTask.GENERATE,
WsLocalTask.EVALUATE,
WsLocalTask.METRIC,
WsLocalTask.GLOBAL_METRIC,
# Add other message types as needed
]

Expand Down Expand Up @@ -370,7 +368,7 @@ async def request(self, type: WsServerTask, data: Dict[str, Any] = {}):

message = {
"correlation_id": correlation_id,
"type": type.value,
"type": type,
"data": data,
}
try:
Expand Down
12 changes: 4 additions & 8 deletions weavel/types/websocket.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,19 @@
from enum import StrEnum
from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Union
from typing_extensions import TypedDict
from openai.types.chat.completion_create_params import ChatCompletionMessageParam
from ape.common.types import DatasetItem, MetricResult, GlobalMetricResult
class WsLocalTask(StrEnum):

class WsLocalTask(str, Enum):
GENERATE = "GENERATE"
EVALUATE = "EVALUATE"
METRIC = "METRIC"
GLOBAL_METRIC = "GLOBAL_METRIC"


class WsServerTask(StrEnum):
class WsServerTask(str, Enum):
OPTIMIZE = "OPTIMIZE"


class WsServerOptimizeResponse(StrEnum):
OPTIMIZATION_COMPLETE = "OPTIMIZATION_COMPLETE"


class BaseWsLocalRequest(TypedDict):
type: WsLocalTask
correlation_id: str
Expand Down

0 comments on commit 076c5e6

Please sign in to comment.