diff --git a/clarifai_model_utils/llm_eval/evaluator/harness_eval/judge_llm.py b/clarifai_model_utils/llm_eval/evaluator/harness_eval/judge_llm.py
index 023b419..ea38d64 100644
--- a/clarifai_model_utils/llm_eval/evaluator/harness_eval/judge_llm.py
+++ b/clarifai_model_utils/llm_eval/evaluator/harness_eval/judge_llm.py
@@ -65,3 +65,20 @@ def process_rag_result(self, doc, results):
         results[m] = 0
     return results
+
+  def process_rag_result_df(self, df):
+    # Take the `question` and `prediction` values and score them with each RAG metric
+    query = df["question"]
+    prediction = df["prediction"]
+    results = dict()
+    for metric, executor in self.rag_metrics.items():
+      if metric != "correctness":
+        results.update({
+            metric: executor.evaluate_strings(input=query, prediction=prediction)['score']
+        })
+
+    for m in results:
+      if results[m] is None:
+        results[m] = 0
+
+    return results
diff --git a/clarifai_model_utils/llm_eval/evaluator/harness_eval/llm.py b/clarifai_model_utils/llm_eval/evaluator/harness_eval/llm.py
index 99c12e8..7a89220 100644
--- a/clarifai_model_utils/llm_eval/evaluator/harness_eval/llm.py
+++ b/clarifai_model_utils/llm_eval/evaluator/harness_eval/llm.py
@@ -11,7 +11,7 @@
 from ...constant import MODEL
 
 
-def clarifailm_completion(_self, prompt, **kwargs):
+def clarifailm_completion(_self, prompt, inference_params=None, **kwargs):
   """Query Clarifai API for completion.
 
   Retry with back-off until they respond
@@ -24,7 +24,7 @@ def clarifailm_completion(_self, prompt, **kwargs):
   # get final output of workflow/model
   if _self.is_model:
     response = _self.client.predict_by_bytes(
-        input_bytes=prompt, input_type="text",
+        input_bytes=prompt, input_type="text", inference_params=_self.inference_parameters
     ).outputs[-1].data.text.raw
   else:
     if _self.is_rag_workflow:
diff --git a/clarifai_model_utils/llm_eval/utils.py b/clarifai_model_utils/llm_eval/utils.py
index 7c1a531..f9c973f 100644
--- a/clarifai_model_utils/llm_eval/utils.py
+++ b/clarifai_model_utils/llm_eval/utils.py
@@ -12,6 +12,9 @@
 from google.protobuf import struct_pb2
 
 from clarifai.utils.logging import get_logger
+from clarifai.client.app import App
+from clarifai.client.input import Inputs
+from clarifai.client.model import Model
 
 logger = get_logger(name='clarifai_llm_eval-' + __file__)
 
@@ -32,21 +35,13 @@ def split_sample_general_template(text, split_word) -> tuple:
 def get_text_dataset_inputs(auth, user_id: str, app_id: str, dataset_id: str, max_input=100):
   stub: V2Stub = auth.get_stub()
-  user_app_id = resources_pb2.UserAppIDSet(user_id=user_id, app_id=app_id)
   # get number of samples of dataset
-  get_dataset_resp = stub.GetDataset(
-      service_pb2.GetDatasetRequest(
-          user_app_id=user_app_id,
-          dataset_id=dataset_id,
-      ),
-      metadata=auth.metadata)
-  if get_dataset_resp.status.code != status_code_pb2.SUCCESS:
-    logger.error(get_dataset_resp.status)
-    return False, get_dataset_resp.status.description
+  app = App(app_id=app_id, user_id=user_id)
+  dataset = app.dataset(dataset_id=dataset_id)
   max_per_page = 128
-  total_samples = get_dataset_resp.dataset.version.metrics['/'].inputs_count.value
+  total_samples = dataset.version.metrics['/'].inputs_count.value
   if not max_input:
     per_page = max_per_page
     chunks = total_samples // per_page
@@ -57,21 +52,12 @@ def get_text_dataset_inputs(auth, user_id: str, app_id: str, dataset_id: str, ma
   else:
     per_page = max_per_page
 
+  input_obj = Inputs(user_id=user_id, app_id=app_id, pat=auth._pat)
+
   urls = []
   for page in range(chunks + 1):
-    list_input_response = stub.ListDatasetInputs(
-        service_pb2.ListDatasetInputsRequest(
-            user_app_id=user_app_id,
-            dataset_id=dataset_id,
-            page=page,
-            per_page=per_page,
-        ),
-        metadata=auth.metadata)
-    if list_input_response.status.code != status_code_pb2.SUCCESS:
-      logger.error(list_input_response.status)
-      return False, list_input_response.status.description
-
-    _urls = [item.input.data.text.url for item in list_input_response.dataset_inputs]
+    all_inputs = list(input_obj.list_inputs(input_type='text', dataset_id=dataset_id, per_page=per_page, page_no=page))
+    _urls = [item.data.text.url for item in all_inputs]
     if len(_urls) < 1:
       break
     urls += _urls
@@ -124,29 +110,8 @@ def _post_call(query: str, query_id: str):
 
 
 def post_ext_metrics_eval(auth, model_id, version_id, eval_id, ext_metrics):
-  metrics = struct_pb2.Struct()
-  metrics.update(ext_metrics)
-  metrics = resources_pb2.ExtendedMetrics(user_metrics=metrics)
-
-  stub = auth.get_stub()
-  user_app_id = resources_pb2.UserAppIDSet(user_id=auth.user_id, app_id=auth.app_id)
-  post_eval = stub.PostEvaluations(
-      service_pb2.PostEvaluationsRequest(
-          user_app_id=user_app_id,
-          eval_metrics=[
-              resources_pb2.EvalMetrics(
-                  id=eval_id,
-                  model=resources_pb2.Model(
-                      id=model_id,
-                      app_id=auth.app_id,
-                      user_id=auth.user_id,
-                      model_version=resources_pb2.ModelVersion(id=version_id),
-                  ),
-                  extended_metrics=metrics if ext_metrics else None)
-          ],
-      ),
-      metadata=auth.metadata,
-  )
+  model = Model(app_id=auth.app_id, user_id=auth.user_id, model_id=model_id, model_version={"id": version_id})
+  post_eval = model.evaluate(dataset_id="", eval_id=eval_id, extended_metrics=ext_metrics)
   return post_eval
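
Note (not part of the patch): a minimal, self-contained sketch of the contract the new process_rag_result_df method relies on, under the assumption that each entry of self.rag_metrics maps a metric name to an executor whose evaluate_strings(input=..., prediction=...) returns a dict with a 'score' key. The DummyLengthRatioExecutor, the metric name, and the sample row below are illustrative stand-ins, not part of the codebase.

# Illustrative stand-in executor; only the evaluate_strings() return shape matters here.
class DummyLengthRatioExecutor:

  def evaluate_strings(self, input, prediction):
    # Return None for an empty prediction to exercise the None -> 0 coercion path.
    score = None if not prediction else min(len(prediction) / max(len(input), 1), 1.0)
    return {"score": score}


rag_metrics = {"length_ratio": DummyLengthRatioExecutor()}
row = {"question": "What is evaluated?", "prediction": "Each generated answer."}

# Same scoring loop as process_rag_result_df: skip "correctness", keep the 'score' value.
results = {
    name: executor.evaluate_strings(input=row["question"], prediction=row["prediction"])["score"]
    for name, executor in rag_metrics.items()
    if name != "correctness"
}
# Mirror the trailing normalization: a None score becomes 0.
results = {k: (0 if v is None else v) for k, v in results.items()}
print(results)  # {'length_ratio': 1.0}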
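
Note (not part of the patch): a sketch of the SDK-based dataset listing that get_text_dataset_inputs switches to, assuming CLARIFAI_PAT holds a valid personal access token; the user, app, and dataset IDs are placeholders. Unlike the patch, this sketch pages from page_no=1 until an empty batch instead of precomputing chunks.

import os

from clarifai.client.app import App
from clarifai.client.input import Inputs

pat = os.environ["CLARIFAI_PAT"]  # assumed: a valid personal access token

app = App(app_id="my-app", user_id="my-user", pat=pat)  # placeholder IDs
dataset = app.dataset(dataset_id="my-dataset")
# Same metrics field the patched helper reads to size the dataset.
total_samples = dataset.version.metrics['/'].inputs_count.value

input_obj = Inputs(user_id="my-user", app_id="my-app", pat=pat)
urls = []
page_no = 1
while True:
  # Same list_inputs() call as the patch, paged until a page comes back empty.
  batch = list(
      input_obj.list_inputs(
          input_type='text', dataset_id="my-dataset", per_page=128, page_no=page_no))
  if not batch:
    break
  urls += [inp.data.text.url for inp in batch]
  page_no += 1

print(f"collected {len(urls)} text input URLs out of {total_samples} reported")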