Skip to content

Commit

Permalink
Integrate llm auto classification
Browse files Browse the repository at this point in the history
  • Loading branch information
sudan45 authored and AdityaKhatri committed Dec 4, 2024
1 parent 1664543 commit 968a83b
Show file tree
Hide file tree
Showing 9 changed files with 188 additions and 6 deletions.
8 changes: 7 additions & 1 deletion apps/assisted_tagging/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@
from admin_auto_filters.filters import AutocompleteFilterFactory
from django.contrib import admin

from assisted_tagging.models import AssistedTaggingModelPredictionTag, AssistedTaggingPrediction, DraftEntry, LLMAssistedTaggingPredication
from assisted_tagging.models import (
AssistedTaggingModelPredictionTag,
AssistedTaggingPrediction,
DraftEntry,
LLMAssistedTaggingPredication
)
from deep.admin import VersionAdmin

admin.site.register(LLMAssistedTaggingPredication)


@admin.register(DraftEntry)
class DraftEntryAdmin(VersionAdmin):
search_fields = ['lead']
Expand Down
1 change: 1 addition & 0 deletions apps/assisted_tagging/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def batch_load_fn(self, keys):
_map[assisted_tagging.draft_entry_id].append(assisted_tagging)
return Promise.resolve([_map.get(key, []) for key in keys])


class LLMDraftEntryPredicationsLoader(DataLoaderWithContext):
def batch_load_fn(self, keys):
llm_assisted_tagging_qs = LLMAssistedTaggingPredication.objects.filter(draft_entry_id__in=keys)
Expand Down
4 changes: 2 additions & 2 deletions apps/assisted_tagging/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from deep.deepl import DeeplServiceEndpoint
from deepl_integration.handlers import (
LlmAssistedTaggingDraftEntryHandler,
AutoAssistedTaggingDraftEntryHandler,
LLMAutoAssistedTaggingDraftEntryHandler,
BaseHandler as DeepHandler
)

Expand Down Expand Up @@ -102,7 +102,7 @@ def trigger_request_for_draft_entry_task(draft_entry_id):
@redis_lock('trigger_request_for_auto_draft_entry_task_{0}', 60 * 60 * 0.5)
def trigger_request_for_auto_draft_entry_task(lead_id):
lead = Lead.objects.get(id=lead_id)
return AutoAssistedTaggingDraftEntryHandler.auto_trigger_request_to_extractor(lead)
return LLMAutoAssistedTaggingDraftEntryHandler.auto_trigger_request_to_extractor(lead)


@shared_task
Expand Down
147 changes: 147 additions & 0 deletions apps/deepl_integration/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1249,3 +1249,150 @@ def save_data(cls, draft_entry, data):
draft_entry.save_geo_data()
draft_entry.save()
return draft_entry


class LLMAutoAssistedTaggingDraftEntryHandler(BaseHandler):
model = Lead
callback_url_name = 'auto-llm-assisted_tagging_draft_entry_prediction_callback'

@classmethod
def auto_trigger_request_to_extractor(cls, lead):
lead_preview = LeadPreview.objects.get(lead=lead)
payload = {
"documents": [
{
"client_id": cls.get_client_id(lead),
"text_extraction_id": str(lead_preview.text_extraction_id),
}
],
'project_id': lead.project_id,
'af_id': lead.project.analysis_framework_id,
"callback_url": cls.get_callback_url()
}
response_content = None
try:
response = requests.post(
url=DeeplServiceEndpoint.LLM_ENTRY_EXTRACTION_CLASSIFICATION,
headers=cls.REQUEST_HEADERS,
json=payload
)
response_content = response.content
if response.status_code == 202:
lead.auto_entry_extraction_status = Lead.AutoExtractionStatus.PENDING
lead.save(update_fields=('auto_entry_extraction_status',))
return True

except Exception:
logger.error('Entry Extraction send failed, Exception occurred!!', exc_info=True)
lead.auto_entry_extraction_status = Lead.AutoExtractionStatus.FAILED
lead.save(update_fields=('auto_entry_extraction_status',))
logger.error(
'Entry Extraction send failed!!',
extra={
'data': {
'payload': payload,
'response': response_content,
},
},
)

# --- Callback logics
@staticmethod
def _get_or_create_models_version(models_data):
def get_versions_map():
return {
(model_version.model.model_id, model_version.version): model_version
for model_version in AssistedTaggingModelVersion.objects.filter(
reduce(
lambda acc, item: acc | item,
[
models.Q(
model__model_id=model_data['name'],
version=model_data['version'],
)
for model_data in models_data
],
)
).select_related('model').all()
}

existing_model_versions = get_versions_map()
new_model_versions = [
model_data
for model_data in models_data
if (model_data['name'], model_data['version']) not in existing_model_versions
]

if new_model_versions:
AssistedTaggingModelVersion.objects.bulk_create([
AssistedTaggingModelVersion(
model=AssistedTaggingModel.objects.get_or_create(
model_id=model_data['name'],
defaults=dict(
name=model_data['name'],
),
)[0],
version=model_data['version'],
)
for model_data in models_data
])
existing_model_versions = get_versions_map()
return existing_model_versions

@classmethod
def _process_model_preds(cls, model_version, draft_entry, model_prediction):
prediction_status = model_prediction['prediction_status']
if not prediction_status: # If False no tags are provided
return

tags = model_prediction.get('classification', {}) # NLP TagId

common_attrs = dict(
model_version=model_version,
draft_entry_id=draft_entry.id,
)
LLMAssistedTaggingPredication.objects.create(
**common_attrs,
model_tags=tags
)
# draft_entry.prediction_status = DraftEntry.PredictionStatus.DONE
# draft_entry.save(update_fields='prediction_status')

@classmethod
@transaction.atomic
def save_data(cls, lead, data_url):
# NOTE: Schema defined here
# - https://docs.google.com/document/d/1NmjOO5sOrhJU6b4QXJBrGAVk57_NW87mLJ9wzeY_NZI/edit#heading=h.t3u7vdbps5pt
data = RequestHelper(url=data_url, ignore_error=True).json()
draft_entry_qs = DraftEntry.objects.filter(lead=lead, type=DraftEntry.Type.AUTO)
if draft_entry_qs.exists():
raise serializers.ValidationError('Draft entries already exit')
for model_preds in data['blocks']:
if not model_preds['relevant']:
continue
models_version_map = cls._get_or_create_models_version([
data['classification_model_info']
])
draft = DraftEntry.objects.create(
page=model_preds['page'],
text_order=model_preds['textOrder'],
project=lead.project,
lead=lead,
excerpt=model_preds['text'],
prediction_status=DraftEntry.PredictionStatus.STARTED,
type=DraftEntry.Type.AUTO
)
if model_preds['geolocations']:
geo_areas_qs = GeoAreaGqlFilterSet(
data={'titles': [geo['entity'] for geo in model_preds['geolocations']]},
queryset=GeoArea.get_for_project(lead.project)
).qs.distinct('title')
draft.related_geoareas.set(geo_areas_qs)

model_version = models_version_map[
(data['classification_model_info']['name'], data['classification_model_info']['version'])
]
cls._process_model_preds(model_version, draft, model_preds)
lead.auto_entry_extraction_status = Lead.AutoExtractionStatus.SUCCESS
lead.save(update_fields=('auto_entry_extraction_status',))
return lead
17 changes: 16 additions & 1 deletion apps/deepl_integration/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
AnalysisAutomaticSummaryHandler,
AnalyticalStatementNGramHandler,
AnalyticalStatementGeoHandler,
AutoAssistedTaggingDraftEntryHandler
AutoAssistedTaggingDraftEntryHandler,
LLMAutoAssistedTaggingDraftEntryHandler,
)

from deduplication.tasks.indexing import index_lead_and_calculate_duplicates
Expand Down Expand Up @@ -302,6 +303,20 @@ def create(self, validated_data):
)


class AutoLLMAssistedTaggingDraftEntryCallbackSerializer(BaseCallbackSerializer):
entry_extraction_classification_path = serializers.URLField(required=True)
text_extraction_id = serializers.CharField(required=True)
status = serializers.IntegerField()
nlp_handler = LLMAutoAssistedTaggingDraftEntryHandler

def create(self, validated_data):
obj = validated_data['object']
return self.nlp_handler.save_data(
obj,
validated_data['entry_extraction_classification_path'],
)


class EntriesCollectionBaseCallbackSerializer(DeeplServerBaseCallbackSerializer):
model: Type[DeeplTrackBaseModel]
presigned_s3_url = serializers.URLField()
Expand Down
7 changes: 6 additions & 1 deletion apps/deepl_integration/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@

from .serializers import (
AssistedTaggingDraftEntryPredictionCallbackSerializer,
AutoLLMAssistedTaggingDraftEntryCallbackSerializer,
LlmAssistedTaggingDraftEntryPredictionCallbackSerializer,
LeadExtractCallbackSerializer,
UnifiedConnectorLeadExtractCallbackSerializer,
AnalysisTopicModelCallbackSerializer,
AnalysisAutomaticSummaryCallbackSerializer,
AnalyticalStatementNGramCallbackSerializer,
AnalyticalStatementGeoCallbackSerializer,
AutoAssistedTaggingDraftEntryCallbackSerializer
AutoAssistedTaggingDraftEntryCallbackSerializer,
)


Expand All @@ -43,6 +44,10 @@ class AutoTaggingDraftEntryPredictionCallbackView(BaseCallbackView):
serializer = AutoAssistedTaggingDraftEntryCallbackSerializer


class AutoLLMTaggingDraftEntryPredictionCallbackView(BaseCallbackView):
serializer = AutoLLMAssistedTaggingDraftEntryCallbackSerializer


class LeadExtractCallbackView(BaseCallbackView):
serializer = LeadExtractCallbackSerializer

Expand Down
1 change: 1 addition & 0 deletions deep/deepl.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ class DeeplServiceEndpoint():
ASSISTED_TAGGING_ENTRY_PREDICT_ENDPOINT = f'{DEEPL_SERVER_DOMAIN}/api/v1/entry-classification/'
ENTRY_EXTRACTION_CLASSIFICATION = f'{DEEPL_SERVER_DOMAIN}/api/v1/entry-extraction-classification/'
LLM_ASSISTED_TAGGING_ENTRY_PREDICT_ENDPOINT = f'{DEEPL_SERVER_DOMAIN}/api/v1/entry-classification-llm/'
LLM_ENTRY_EXTRACTION_CLASSIFICATION = f'{DEEPL_SERVER_DOMAIN}/api/v1/entry-extraction-classification-llm/'
7 changes: 7 additions & 0 deletions deep/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@
AssistedTaggingDraftEntryPredictionCallbackView,
LlmAssistedTaggingDraftEntryPredictionCallbackView,
AutoTaggingDraftEntryPredictionCallbackView,
AutoLLMTaggingDraftEntryPredictionCallbackView,
LeadExtractCallbackView,
UnifiedConnectorLeadExtractCallbackView,
AnalysisTopicModelCallbackView,
Expand Down Expand Up @@ -587,6 +588,12 @@ def get_api_path(path):
name='llm-assisted_tagging_draft_entry_prediction_callback',
),

re_path(
get_api_path(r'callback/auto-llm-assisted-tagging-draft-entry-prediction/$'),
AutoLLMTaggingDraftEntryPredictionCallbackView.as_view(),
name='auto-llm-assisted_tagging_draft_entry_prediction_callback',
),

re_path(
get_api_path(r'callback/analysis-topic-model/$'),
AnalysisTopicModelCallbackView.as_view(),
Expand Down
2 changes: 1 addition & 1 deletion schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -4614,9 +4614,9 @@ type JwtTokenType {

type LLMAssistedTaggingPredictionType {
id: ID!
modelTags: GenericScalar
modelVersion: ID!
draftEntry: ID!
modelTags: GenericScalar
}

enum LeadAutoEntryExtractionTypeEnum {
Expand Down

0 comments on commit 968a83b

Please sign in to comment.