diff --git a/rl/llm/engines.py b/rl/llm/engines.py index 5593f88..3b9108b 100644 --- a/rl/llm/engines.py +++ b/rl/llm/engines.py @@ -47,7 +47,7 @@ class InferenceOutput(BaseModel): prompt: InferenceInput text: str - logprobs: list[dict[str, float]] | None = None + logprobs: list[dict[int, float]] | None = None metadata: dict[str, Any] = Field(default_factory=dict)