Skip to content

Commit

Permalink
feat(LAB-3244): on LLM dynamic projects export annotations at conversation level
Browse files Browse the repository at this point in the history
  • Loading branch information
FannyGaudin authored and baptiste-olivier committed Nov 22, 2024
1 parent c2bb207 commit 3e777ee
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 22 deletions.
46 changes: 30 additions & 16 deletions src/kili/llm/services/export/dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def export(
chat_items = label["chatItems"]
annotations = label["annotations"]
rounds = self._build_rounds(chat_items, annotations, json_interface)
total_rounds = len(rounds)
for step, round in enumerate(rounds):
raw_data = _format_raw_data(
round["context"]
Expand All @@ -59,25 +60,30 @@ def export(
label["id"],
obfuscated_models,
)
formatted_response = _format_json_response(
json_interface["jobs"],
round["annotations"],
round["completion"],
obfuscated_models,
)
label_data = {
"author": label["author"]["email"],
"created_at": label["createdAt"],
"label_type": label["labelType"],
"label": {},
}
if formatted_response["turn"]:
label_data["label"]["turn"] = formatted_response["turn"]
if step == total_rounds - 1 and formatted_response["conversation"]:
label_data["label"]["conversation"] = formatted_response["conversation"]

result[f"{step}"] = {
"external_id": asset["externalId"],
"metadata": asset["jsonMetadata"],
"models": _format_models_object(
asset["assetProjectModels"], obfuscated_models
),
"labels": [
{
"author": label["author"]["email"],
"created_at": label["createdAt"],
"label_type": label["labelType"],
"label": _format_json_response(
json_interface["jobs"],
round["annotations"],
round["completion"],
obfuscated_models,
),
}
],
"labels": [label_data],
"raw_data": raw_data,
"status": asset["status"],
}
Expand Down Expand Up @@ -173,6 +179,9 @@ def _build_rounds(self, chat_items, annotations, json_interface):
has_children = True
parent_target = chat_item["id"]

current_round["annotations"] += [
annotation for annotation in annotations if annotation["chatItemId"] is None
]
rounds.append(current_round)
new_context = (
current_round["context"]
Expand All @@ -191,6 +200,9 @@ def _build_rounds(self, chat_items, annotations, json_interface):

raise ValueError(f"Role {node['role']} not supported")
if current_round["prompt"] is not None:
current_round["annotations"] += [
annotation for annotation in annotations if annotation["chatItemId"] is None
]
rounds.append(current_round)
return rounds

Expand Down Expand Up @@ -225,8 +237,8 @@ def _format_comparison_annotation(annotation, completions, job, obfuscated_model

def _format_json_response(
jobs_config: Dict, annotations: List[Dict], completions: List[Dict], obfuscated_models: Dict
) -> Dict[str, Union[str, List[str]]]:
result = {}
) -> Dict[str, Dict[str, Union[str, List[str]]]]:
result = {"turn": {}, "conversation": {}}
for annotation in annotations:
formatted_response = None
job = jobs_config[annotation["job"]]
Expand All @@ -243,8 +255,10 @@ def _format_json_response(
logging.warning(
f"Annotation with job {annotation['job']} with mlTask {job['mlTask']} not supported. Ignored in the export."
)
elif "level" in job and job["level"] == "conversation":
result["conversation"][annotation["job"]] = formatted_response
else:
result[annotation["job"]] = formatted_response
result["turn"][annotation["job"]] = formatted_response

return result

Expand Down
4 changes: 3 additions & 1 deletion src/kili/llm/services/export/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def _process_llm_v1(

def _format_json_response(
jobs_config: Dict, json_response: Dict
) -> Dict[str, Union[str, List[str]]]:
) -> Dict[str, Dict[str, Union[str, List[str]]]]:
result = {}
for job_name, job_value in json_response.items():
job_config = jobs_config[job_name]
Expand All @@ -83,6 +83,8 @@ def _format_json_response(
result[job_name] = job_value["text"]
else:
logging.warning(f"Job {job_name} with mlTask {job_config['mlTask']} not supported")
if len(result) != 0:
return {"conversation": result}
return result


Expand Down
99 changes: 96 additions & 3 deletions tests/unit/llm/services/export/test_dynamic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import copy

import pytest

from kili.llm.presentation.client.llm import LlmClientMethods
Expand Down Expand Up @@ -278,7 +280,9 @@
"author": "[email protected]",
"created_at": "2024-08-06T12:30:42.122Z",
"label_type": "DEFAULT",
"label": {"COMPARISON_JOB": "A_3", "CLASSIFICATION_JOB": ["BOTH_ARE_GOOD"]},
"label": {
"turn": {"COMPARISON_JOB": "A_3", "CLASSIFICATION_JOB": ["BOTH_ARE_GOOD"]},
},
}
],
},
Expand Down Expand Up @@ -353,7 +357,9 @@
"author": "[email protected]",
"created_at": "2024-08-06T12:30:42.122Z",
"label_type": "DEFAULT",
"label": {"COMPARISON_JOB": "B_1"},
"label": {
"turn": {"COMPARISON_JOB": "B_1"},
},
}
],
},
Expand Down Expand Up @@ -442,7 +448,9 @@
"author": "[email protected]",
"created_at": "2024-08-06T12:30:42.122Z",
"label_type": "DEFAULT",
"label": {"COMPARISON_JOB": "A_2"},
"label": {
"turn": {"COMPARISON_JOB": "A_2"},
},
}
],
},
Expand Down Expand Up @@ -616,3 +624,88 @@ def test_export_dynamic_empty_json_interface(mocker):
kili_llm.export(
project_id="project_id",
)


def test_export_dynamic_with_conversation_level(mocker):
    """Dynamic export with conversation-level jobs in the json interface.

    Annotations whose ``chatItemId`` is ``None`` belong to the whole
    conversation, so the export must place them under
    ``label["conversation"]`` of the last round rather than under ``"turn"``.
    """
    # Extend the shared json interface with two conversation-level jobs.
    json_iface = copy.deepcopy(mock_json_interface)
    json_iface["jobs"]["CLASSIFICATION_JOB_0"] = {
        "content": {
            "categories": {
                "GOOD": {"children": [], "name": "Good", "id": "category7"},
                "BAD": {"children": [], "name": "Bad", "id": "category8"},
            },
            "input": "radio",
        },
        "level": "conversation",
        "instruction": "Overall quality",
        "mlTask": "CLASSIFICATION",
        "required": 1,
        "isChild": False,
        "isNew": False,
    }
    json_iface["jobs"]["TRANSCRIPTION_JOB"] = {
        "content": {"input": "textField"},
        "level": "conversation",
        "instruction": "Write something about the overall quality",
        "mlTask": "TRANSCRIPTION",
        "required": 1,
        "isChild": False,
        "isNew": False,
    }

    # Attach two conversation-level annotations (chatItemId is None) to the
    # first asset's first label.
    assets = copy.deepcopy(mock_fetch_assets)
    conversation_annotations = [
        {
            "id": "20241025134207822-9",
            "job": "CLASSIFICATION_JOB_0",
            "path": [],
            "labelId": "clzief6q2003e7tc91jm46uii",
            "chatItemId": None,
            "annotationValue": {
                "categories": ["GOOD"],
            },
            "__typename": "ClassificationAnnotation",
        },
        {
            "id": "20241025134209366-10",
            "job": "TRANSCRIPTION_JOB",
            "path": [],
            "labelId": "clzief6q2003e7tc91jm46uii",
            "chatItemId": None,
            "annotationValue": {
                "text": "something",
            },
            "__typename": "TranscriptionAnnotation",
        },
    ]
    assets[0]["labels"][0]["annotations"].extend(conversation_annotations)

    # The conversation-level answers are expected on the final round ("2").
    expected = copy.deepcopy(expected_export)
    expected[0]["2"]["labels"][0]["label"]["conversation"] = {
        "CLASSIFICATION_JOB_0": ["GOOD"],
        "TRANSCRIPTION_JOB": "something",
    }

    kili_api_gateway = mocker.MagicMock()
    kili_api_gateway.count_assets.return_value = 3
    kili_api_gateway.get_project.return_value = {
        "jsonInterface": json_iface,
        "inputType": "LLM_INSTR_FOLLOWING",
        "title": "Test project",
        "id": "project_id",
        "dataConnections": None,
    }
    kili_api_gateway.list_assets.return_value = assets

    kili_llm = LlmClientMethods(kili_api_gateway)

    result = kili_llm.export(project_id="project_id")
    assert result == expected
4 changes: 2 additions & 2 deletions tests/unit/llm/services/export/test_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
"author": "[email protected]",
"created_at": "2024-08-05T13:03:00.051Z",
"label_type": "DEFAULT",
"label": {"CLASSIFICATION_JOB": ["A_BETTER_THAN_B"]},
"label": {"conversation": {"CLASSIFICATION_JOB": ["A_BETTER_THAN_B"]}},
}
],
},
Expand Down Expand Up @@ -117,7 +117,7 @@
"author": "[email protected]",
"created_at": "2024-08-05T13:03:03.061Z",
"label_type": "DEFAULT",
"label": {"CLASSIFICATION_JOB": ["B_BETTER_THAN_A"]},
"label": {"conversation": {"CLASSIFICATION_JOB": ["B_BETTER_THAN_A"]}},
}
],
},
Expand Down

0 comments on commit 3e777ee

Please sign in to comment.