Merge pull request #12 from Clarifai/refactoring1

[DEVX-384] Evaluations of already made predictions without GT

isaac-chung authored May 17, 2024
2 parents 7450241 + d767d6e commit aae0cef
Showing 3 changed files with 31 additions and 49 deletions.
clarifai_model_utils/llm_eval/evaluator/harness_eval/judge_llm.py (17 additions, 0 deletions)
@@ -65,3 +65,20 @@ def process_rag_result(self, doc, results):
         results[m] = 0
 
     return results
+
+  def process_rag_result_df(self, df):
+    # Take the values of the `question` and `prediction` columns
+    query = df["question"]
+    prediction = df["prediction"]
+    results = dict()
+    for metric, executor in self.rag_metrics.items():
+      if metric != "correctness":
+        results.update({
+            metric: executor.evaluate_strings(input=query, prediction=prediction)['score']
+        })
+
+    for m in results:
+      if results[m] is None:
+        results[m] = 0
+
+    return results
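
As a reader's aid, here is a minimal, self-contained sketch of how the new process_rag_result_df method might be exercised. The FakeEvaluator and FakeJudge stand-ins, the metric names, and the sample row are illustrative assumptions, not part of this commit; the method body is inlined from the hunk above.

import pandas as pd


class FakeEvaluator:
  """Illustrative stand-in for the LangChain-style evaluators kept in rag_metrics."""

  def evaluate_strings(self, input, prediction):
    # A real evaluator would score the prediction with a judge LLM.
    return {"score": 1.0}


class FakeJudge:
  """Minimal host class; only rag_metrics is needed by the new method."""
  rag_metrics = {"faithfulness": FakeEvaluator(), "correctness": FakeEvaluator()}

  # Inlined from the diff above, lightly condensed.
  def process_rag_result_df(self, df):
    results = {}
    for metric, executor in self.rag_metrics.items():
      if metric != "correctness":  # correctness needs ground truth, which this path lacks
        results[metric] = executor.evaluate_strings(
            input=df["question"], prediction=df["prediction"])["score"]
    return {m: (0 if v is None else v) for m, v in results.items()}


row = pd.Series({"question": "What is RAG?", "prediction": "Retrieval-augmented generation."})
print(FakeJudge().process_rag_result_df(row))  # {'faithfulness': 1.0}

Skipping correctness is the point: it is the one RAG metric that requires a ground-truth answer, so already-made predictions can still be scored on the remaining metrics, in line with the commit's "without GT" goal.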
clarifai_model_utils/llm_eval/evaluator/harness_eval/llm.py (2 additions, 2 deletions)
@@ -11,7 +11,7 @@
 from ...constant import MODEL
 
 
-def clarifailm_completion(_self, prompt, **kwargs):
+def clarifailm_completion(_self, prompt, inference_params=None, **kwargs):
   """Query Clarifai API for completion.
   Retry with back-off until they respond
@@ -24,7 +24,7 @@ def clarifailm_completion(_self, prompt, **kwargs):
   # get final output of workflow/model
   if _self.is_model:
     response = _self.client.predict_by_bytes(
-        input_bytes=prompt, input_type="text",
+        input_bytes=prompt, input_type="text",
         inference_params=_self.inference_parameters).outputs[-1].data.text.raw
   else:
     if _self.is_rag_workflow:
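
The hunks above show the new inference_params argument but not how it is consumed, so its interaction with the instance-level _self.inference_parameters is not visible here. The fallback pattern below is an assumption about the intent, sketched as plain Python:

def resolve_inference_params(instance_defaults, inference_params=None):
  # Per-call parameters, when given, take precedence over the wrapper's defaults.
  return inference_params if inference_params is not None else instance_defaults


defaults = {"temperature": 0.7, "max_tokens": 512}
print(resolve_inference_params(defaults))                        # -> the defaults
print(resolve_inference_params(defaults, {"temperature": 0.0}))  # -> the per-call override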
clarifai_model_utils/llm_eval/utils.py (12 additions, 47 deletions)
@@ -12,6 +12,9 @@
 from google.protobuf import struct_pb2
 
 from clarifai.utils.logging import get_logger
+from clarifai.client.app import App
+from clarifai.client.input import Inputs
+from clarifai.client.model import Model
 
 logger = get_logger(name='clarifai_llm_eval-' + __file__)
 
@@ -32,21 +35,13 @@ def split_sample_general_template(text, split_word) -> tuple:
 
 def get_text_dataset_inputs(auth, user_id: str, app_id: str, dataset_id: str, max_input=100):
   stub: V2Stub = auth.get_stub()
-  user_app_id = resources_pb2.UserAppIDSet(user_id=user_id, app_id=app_id)
 
   # get number of samples of dataset
-  get_dataset_resp = stub.GetDataset(
-      service_pb2.GetDatasetRequest(
-          user_app_id=user_app_id,
-          dataset_id=dataset_id,
-      ),
-      metadata=auth.metadata)
-  if get_dataset_resp.status.code != status_code_pb2.SUCCESS:
-    logger.error(get_dataset_resp.status)
-    return False, get_dataset_resp.status.description
+  app = App(app_id=app_id, user_id=user_id)
+  dataset = app.dataset(dataset_id=dataset_id)
 
   max_per_page = 128
-  total_samples = get_dataset_resp.dataset.version.metrics['/'].inputs_count.value
+  total_samples = dataset.version.metrics['/'].inputs_count.value
   if not max_input:
     per_page = max_per_page
     chunks = total_samples // per_page
@@ -57,21 +52,12 @@ def get_text_dataset_inputs(auth, user_id: str, app_id: str, dataset_id: str, max_input=100):
   else:
     per_page = max_per_page
 
+  input_obj = Inputs(user_id=user_id, app_id=app_id, pat=auth._pat)
+
   urls = []
   for page in range(chunks + 1):
-    list_input_response = stub.ListDatasetInputs(
-        service_pb2.ListDatasetInputsRequest(
-            user_app_id=user_app_id,
-            dataset_id=dataset_id,
-            page=page,
-            per_page=per_page,
-        ),
-        metadata=auth.metadata)
-    if list_input_response.status.code != status_code_pb2.SUCCESS:
-      logger.error(list_input_response.status)
-      return False, list_input_response.status.description
-
-    _urls = [item.input.data.text.url for item in list_input_response.dataset_inputs]
+    all_inputs = list(input_obj.list_inputs(input_type='text', dataset_id=dataset_id, per_page=per_page, page_no=page))
+    _urls = [item.data.text.url for item in all_inputs]
     if len(_urls) < 1:
       break
     urls += _urls
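
Used on their own, the SDK calls this hunk switches to look roughly as follows. The IDs and the PAT are placeholders, the loop paraphrases the paging logic above, and a 1-based page_no is assumed:

from clarifai.client.input import Inputs

input_obj = Inputs(user_id="USER_ID", app_id="APP_ID", pat="YOUR_PAT")

urls, page_no = [], 1
while True:
  batch = list(input_obj.list_inputs(
      input_type="text", dataset_id="DATASET_ID", per_page=128, page_no=page_no))
  if not batch:
    break
  # list_inputs yields Input protos, so each text URL lives at data.text.url
  urls += [inp.data.text.url for inp in batch]
  page_no += 1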
@@ -124,29 +110,8 @@ def _post_call(query: str, query_id: str):
 
 
 def post_ext_metrics_eval(auth, model_id, version_id, eval_id, ext_metrics):
-  metrics = struct_pb2.Struct()
-  metrics.update(ext_metrics)
-  metrics = resources_pb2.ExtendedMetrics(user_metrics=metrics)
-
-  stub = auth.get_stub()
-  user_app_id = resources_pb2.UserAppIDSet(user_id=auth.user_id, app_id=auth.app_id)
-  post_eval = stub.PostEvaluations(
-      service_pb2.PostEvaluationsRequest(
-          user_app_id=user_app_id,
-          eval_metrics=[
-              resources_pb2.EvalMetrics(
-                  id=eval_id,
-                  model=resources_pb2.Model(
-                      id=model_id,
-                      app_id=auth.app_id,
-                      user_id=auth.user_id,
-                      model_version=resources_pb2.ModelVersion(id=version_id),
-                  ),
-                  extended_metrics=metrics if ext_metrics else None)
-          ],
-      ),
-      metadata=auth.metadata,
-  )
+  model = Model(app_id=auth.app_id, user_id=auth.user_id, model_id=model_id, model_version={"id": version_id})
+  post_eval = model.evaluate(dataset_id="", eval_id=eval_id, extended_metrics=ext_metrics)
   return post_eval


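Standalone, the Model-based replacement above can be driven like this. All IDs and the PAT are placeholders; the empty dataset_id mirrors the diff and reflects the commit's goal of posting externally computed metrics for predictions evaluated without ground truth:

from clarifai.client.model import Model

model = Model(
    user_id="USER_ID", app_id="APP_ID",
    model_id="MODEL_ID", model_version={"id": "VERSION_ID"},
    pat="YOUR_PAT")

eval_results = model.evaluate(
    dataset_id="",  # no ground-truth dataset; scores come in via extended_metrics
    eval_id="EVAL_ID",
    extended_metrics={"faithfulness": 0.92, "relevance": 0.88})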
