The `Task` interface is defined as follows:

```python
class Task:
    def __init__(self, name: str, concurrency: int = 1, *args, **kwargs):
        self.name = name
        self.concurrency = concurrency

    def get_indices(self) -> List[SampleIndex]:
        raise NotImplementedError()

    async def start_sample(
        self, index: SampleIndex, session: Session
    ) -> TaskSampleExecutionResult:
        raise NotImplementedError()

    def calculate_overall(self, results: List[TaskOutput]) -> Dict[str, Any]:
        raise NotImplementedError()

    def release(self):
        pass
```
To implement your own Task, you just need to inherit from `Task` and implement the corresponding interfaces, which are described as follows:

- `name`: the task name, usually specified in the config.
- `concurrency`: the maximum concurrency supported within a worker.
- `get_indices`: returns the indices of all samples.
- `start_sample`: the logic within a single sample, where `index` is the index of the sample to be tested and `session` is a proxy of the Agent.
- `calculate_overall`: calculates the score after all samples have been tested; the return format is arbitrary and will eventually be saved to `overall.json`.
- `release`: cleanup that needs to be executed after the task_worker process ends. Note that this is after the entire worker process ends, not after a particular sample ends.
The definition of the structures in the program is as follows:

```python
SampleIndex = Union[int, str]
JSONSerializable = Union[None, bool, int, float, str, List[Any], Dict[str, Any]]

class TaskSampleExecutionResult(BaseModel):
    status: SampleStatus = SampleStatus.COMPLETED
    result: JSONSerializable = None

class TaskOutput(BaseModel):
    index: Union[None, SampleIndex] = None
    status: SampleStatus = SampleStatus.RUNNING  # directly from TaskSampleExecutionResult
    result: JSONSerializable = None  # directly from TaskSampleExecutionResult
    history: Union[None, List[ChatHistoryItem]] = None

class SampleStatus(str, Enum):
    RUNNING = "running"
    COMPLETED = "completed"
    AGENT_CONTEXT_LIMIT = "agent context limit"
    AGENT_VALIDATION_FAILED = "agent validation failed"
    AGENT_INVALID_ACTION = "agent invalid action"
    TASK_LIMIT_REACHED = "task limit reached"
    UNKNOWN = "unknown"
    TASK_ERROR = "task error"

class ChatHistoryItem(BaseModel):
    role: Literal["user", "agent"]
    content: str
```
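To make the data flow concrete, here is a small illustration (the values are made up): `start_sample` returns a `TaskSampleExecutionResult`, and the framework combines it with the sample index and chat history into a `TaskOutput`.

```python
# Illustrative only; in practice the framework builds the TaskOutput itself.
ter = TaskSampleExecutionResult(status=SampleStatus.COMPLETED, result={"answer": 42})

out = TaskOutput(
    index=0,
    status=ter.status,  # copied directly from TaskSampleExecutionResult
    result=ter.result,  # copied directly from TaskSampleExecutionResult
    history=[ChatHistoryItem(role="user", content="What is 6 * 7?")],
)
```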
Note that when returning a `TaskSampleExecutionResult` from `start_sample`, you should carefully examine the completion status of the sample. If it completed normally, it should be marked as `COMPLETED`. The framework automatically aggregates statistics over these completion statuses.
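As a sketch of what that examination might look like (the round limit, prompt, and success check here are illustrative, not part of the framework):

```python
# Inside your Task subclass:
async def start_sample(self, index, session: Session) -> TaskSampleExecutionResult:
    max_rounds = 10  # illustrative task-side limit
    for round_no in range(max_rounds):
        res = await session.action({"role": "user", "content": "Your move?"})
        if res.content is not None and "DONE" in res.content:  # illustrative success check
            return TaskSampleExecutionResult(
                status=SampleStatus.COMPLETED,
                result={"rounds": round_no + 1},
            )
    # Ran out of rounds without finishing normally.
    return TaskSampleExecutionResult(status=SampleStatus.TASK_LIMIT_REACHED)
```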
The `Session` implements the following interfaces:

- `def inject(self, item: Union[ChatHistoryItem, List[ChatHistoryItem]])`: inserts one or more history records.
- `async def action(self, *injection) -> AgentOutput`: waits for the Agent's response; for convenience, it also supports inserting one or more history records at the same time. A minimal usage sketch follows this list.
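The sketch below shows both calls together (the prompts are illustrative; passing a plain dict to `action` follows the VirtualTask example further down):

```python
async def demo(session: Session) -> None:
    # Seed the history with a user instruction and a scripted agent reply.
    session.inject(ChatHistoryItem(role="user", content="You are playing a guessing game."))
    session.inject(ChatHistoryItem(role="agent", content="OK, I am ready."))
    # Wait for the real agent; action() can also inject records in the same call.
    res = await session.action({"role": "user", "content": "Guess a number from 1 to 10."})
    print(res.status, res.content)
```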
The definition of `AgentOutput` is as follows:

```python
class AgentOutput(BaseModel):
    status: AgentOutputStatus = AgentOutputStatus.NORMAL
    content: Union[str, None] = None

class AgentOutputStatus(str, Enum):
    NORMAL = "normal"
    CANCELLED = "cancelled"
    AGENT_CONTEXT_LIMIT = "agent context limit"
```
After obtaining an `AgentOutput`, you need to handle it carefully and check whether the `AgentOutputStatus` is `NORMAL`. If it is not, handle it accordingly. In particular, `CANCELLED` means that the client needs to cancel the test of this sample for some reason; in that case, end the sample quickly in whatever way ensures it does not affect subsequent tests.
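One possible way to handle this inside `start_sample` (a sketch, not prescribed by the framework; mapping `CANCELLED` to `UNKNOWN` is an assumption):

```python
# Inside your Task subclass:
async def start_sample(self, index, session: Session) -> TaskSampleExecutionResult:
    res = await session.action({"role": "user", "content": "Begin."})  # illustrative prompt
    if res.status == AgentOutputStatus.CANCELLED:
        # The client cancelled this sample; end it quickly with no result.
        return TaskSampleExecutionResult(status=SampleStatus.UNKNOWN)
    if res.status == AgentOutputStatus.AGENT_CONTEXT_LIMIT:
        return TaskSampleExecutionResult(status=SampleStatus.AGENT_CONTEXT_LIMIT)
    # ... normal handling of res.content goes here ...
    return TaskSampleExecutionResult(status=SampleStatus.COMPLETED, result={"ok": True})
```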
A simple implementation is as follows:

```python
import asyncio
from typing import Any, Dict, List

class VirtualTask(Task):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(name="virtual-task", *args, **kwargs)

    def get_indices(self) -> List[Any]:
        return list(range(10))

    async def start_sample(self, index, session: Session):
        print("task start sample")
        for loop_times in range(3):
            await asyncio.sleep(1)
            res = await session.action(
                {"role": "user", "content": "Loop: %d" % loop_times}
            )
            print("TASK", res.content)
        return TaskSampleExecutionResult(
            status=SampleStatus.COMPLETED,
            result={"result": "ok"},
        )

    def calculate_overall(self, results: List[TaskOutput]) -> Dict[str, Any]:
        return {"score": 0.4}
```
The data loaded in the original `get_data` can be bound directly to `self` in the `__init__` method, and `start_sample` can then look up the corresponding item from `self.data` by `index`. On this basis, implement `get_indices`: if the full set of test samples is a list, you can simply return `list(range(len(self.data)))`.
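A hypothetical sketch of this migration (the `data_file` parameter and JSON format are assumptions for illustration):

```python
import json
from typing import List

class MyMigratedTask(Task):
    def __init__(self, data_file: str, *args, **kwargs):
        super().__init__(*args, **kwargs)  # name/concurrency come from the config
        with open(data_file) as f:
            self.data = json.load(f)  # what get_data used to return

    def get_indices(self) -> List[SampleIndex]:
        # The full sample set is a list, so the indices are just its positions.
        return list(range(len(self.data)))
```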
First, change the original `def` to `async def`, and change the original `session.action` calls to `await session.action`. Finally, the return value needs an additional `status` field compared to the original. This lets the framework automatically count the reasons for sample errors, which is useful for further experimental analysis.
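A hedged before/after sketch (the old method name `predict_single` and the `prompt` field are illustrative, not the framework's actual names):

```python
# Before (hypothetical old-style code):
# def predict_single(self, session, data_item):
#     res = session.action({"role": "user", "content": data_item["prompt"]})
#     return {"result": res.content}

# After, as a method of your Task subclass:
async def start_sample(self, index, session: Session) -> TaskSampleExecutionResult:
    data_item = self.data[index]
    res = await session.action({"role": "user", "content": data_item["prompt"]})
    return TaskSampleExecutionResult(
        status=SampleStatus.COMPLETED,  # the newly required status field
        result={"result": res.content},
    )
```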
The additional `status` field was introduced mainly to facilitate counting samples. If you don't want to change the original `metrics`, you can also create a new `calculate_overall` function and call `self.metrics` within it.
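For example, a sketch of such a wrapper (assuming `self.metrics` takes a list of raw results, as in the old framework):

```python
# Inside your Task subclass:
def calculate_overall(self, results: List[TaskOutput]) -> Dict[str, Any]:
    # Reuse the old metric on samples that finished normally, and report
    # the completion counts alongside it.
    completed = [r.result for r in results if r.status == SampleStatus.COMPLETED]
    return {
        "metrics": self.metrics(completed),
        "completed": len(completed),
        "total": len(results),
    }
```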
If you originally overrode the `predict_all` method, it cannot be used in the new framework.