Feat/firecrawl data source (#5232)
Co-authored-by: Nicolas <[email protected]>
Co-authored-by: chenhe <[email protected]>
Co-authored-by: takatost <[email protected]>
4 people authored Jun 14, 2024
1 parent 918ebe1 commit ba5f8af
Showing 36 changed files with 1,174 additions and 64 deletions.
3 changes: 2 additions & 1 deletion api/.env.example
@@ -215,4 +215,5 @@ WORKFLOW_MAX_EXECUTION_TIME=1200
WORKFLOW_CALL_MAX_DEPTH=5

# App configuration
-APP_MAX_EXECUTION_TIME=1200
\ No newline at end of file
+APP_MAX_EXECUTION_TIME=1200
+
4 changes: 2 additions & 2 deletions api/controllers/console/__init__.py
@@ -29,13 +29,13 @@
)

# Import auth controllers
-from .auth import activate, data_source_oauth, login, oauth
+from .auth import activate, data_source_bearer_auth, data_source_oauth, login, oauth

# Import billing controllers
from .billing import billing

# Import datasets controllers
-from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing
+from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing, website

# Import explore controllers
from .explore import (
67 changes: 67 additions & 0 deletions api/controllers/console/auth/data_source_bearer_auth.py
@@ -0,0 +1,67 @@
from flask_login import current_user
from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden

from controllers.console import api
from controllers.console.auth.error import ApiKeyAuthFailedError
from libs.login import login_required
from services.auth.api_key_auth_service import ApiKeyAuthService

from ..setup import setup_required
from ..wraps import account_initialization_required


class ApiKeyAuthDataSource(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
# The current user's role in the workspace must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id)
if data_source_api_key_bindings:
return {
'settings': [data_source_api_key_binding.to_dict() for data_source_api_key_binding in
data_source_api_key_bindings]}
return {'settings': []}


class ApiKeyAuthDataSourceBinding(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
# The current user's role in the workspace must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('category', type=str, required=True, nullable=False, location='json')
parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
ApiKeyAuthService.validate_api_key_auth_args(args)
try:
ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args)
except Exception as e:
raise ApiKeyAuthFailedError(str(e))
return {'result': 'success'}, 200


class ApiKeyAuthDataSourceBindingDelete(Resource):
@setup_required
@login_required
@account_initialization_required
def delete(self, binding_id):
# The current user's role in the workspace must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()

ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)

return {'result': 'success'}, 200


api.add_resource(ApiKeyAuthDataSource, '/api-key-auth/data-source')
api.add_resource(ApiKeyAuthDataSourceBinding, '/api-key-auth/data-source/binding')
api.add_resource(ApiKeyAuthDataSourceBindingDelete, '/api-key-auth/data-source/<uuid:binding_id>')
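For context, a minimal sketch of exercising these new endpoints from a script; the console base URL, auth header, category value, and credential shape are assumptions, not part of this diff:

```python
import requests

# Assumed console API prefix and access token; adjust for your deployment.
BASE = "http://localhost:5001/console/api"
HEADERS = {"Authorization": "Bearer <console-access-token>"}

# Bind a Firecrawl API key to the current workspace (admin/owner only).
resp = requests.post(
    f"{BASE}/api-key-auth/data-source/binding",
    headers=HEADERS,
    json={
        "category": "website",                 # assumed category value
        "provider": "firecrawl",
        "credentials": {"api_key": "fc-..."},  # assumed credential shape
    },
)
print(resp.json())  # {'result': 'success'} on success

# List the workspace's existing bindings.
print(requests.get(f"{BASE}/api-key-auth/data-source", headers=HEADERS).json())
```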
7 changes: 7 additions & 0 deletions api/controllers/console/auth/error.py
@@ -0,0 +1,7 @@
from libs.exception import BaseHTTPException


class ApiKeyAuthFailedError(BaseHTTPException):
error_code = 'auth_failed'
description = "{message}"
code = 500
22 changes: 11 additions & 11 deletions api/controllers/console/datasets/data_source.py
@@ -16,7 +16,7 @@
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
from libs.login import login_required
from models.dataset import Document
-from models.source import DataSourceBinding
+from models.source import DataSourceOauthBinding
from services.dataset_service import DatasetService, DocumentService
from tasks.document_indexing_sync_task import document_indexing_sync_task

@@ -29,9 +29,9 @@ class DataSourceApi(Resource):
@marshal_with(integrate_list_fields)
def get(self):
# get workspace data source integrates
-data_source_integrates = db.session.query(DataSourceBinding).filter(
-    DataSourceBinding.tenant_id == current_user.current_tenant_id,
-    DataSourceBinding.disabled == False
+data_source_integrates = db.session.query(DataSourceOauthBinding).filter(
+    DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+    DataSourceOauthBinding.disabled == False
).all()

base_url = request.url_root.rstrip('/')
@@ -71,7 +71,7 @@ def get(self):
def patch(self, binding_id, action):
binding_id = str(binding_id)
action = str(action)
-data_source_binding = DataSourceBinding.query.filter_by(
+data_source_binding = DataSourceOauthBinding.query.filter_by(
id=binding_id
).first()
if data_source_binding is None:
@@ -124,7 +124,7 @@ def get(self):
data_source_info = json.loads(document.data_source_info)
exist_page_ids.append(data_source_info['notion_page_id'])
# get all authorized pages
-data_source_bindings = DataSourceBinding.query.filter_by(
+data_source_bindings = DataSourceOauthBinding.query.filter_by(
tenant_id=current_user.current_tenant_id,
provider='notion',
disabled=False
@@ -163,12 +163,12 @@ class DataSourceNotionApi(Resource):
def get(self, workspace_id, page_id, page_type):
workspace_id = str(workspace_id)
page_id = str(page_id)
-data_source_binding = DataSourceBinding.query.filter(
+data_source_binding = DataSourceOauthBinding.query.filter(
     db.and_(
-        DataSourceBinding.tenant_id == current_user.current_tenant_id,
-        DataSourceBinding.provider == 'notion',
-        DataSourceBinding.disabled == False,
-        DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
+        DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+        DataSourceOauthBinding.provider == 'notion',
+        DataSourceOauthBinding.disabled == False,
+        DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
)
).first()
if not data_source_binding:
17 changes: 17 additions & 0 deletions api/controllers/console/datasets/datasets.py
@@ -315,6 +315,22 @@ def post(self):
document_model=args['doc_form']
)
extract_settings.append(extract_setting)
elif args['info_list']['data_source_type'] == 'website_crawl':
website_info_list = args['info_list']['website_info_list']
for url in website_info_list['urls']:
extract_setting = ExtractSetting(
datasource_type="website_crawl",
website_info={
"provider": website_info_list['provider'],
"job_id": website_info_list['job_id'],
"url": url,
"tenant_id": current_user.current_tenant_id,
"mode": 'crawl',
"only_main_content": website_info_list['only_main_content']
},
document_model=args['doc_form']
)
extract_settings.append(extract_setting)
else:
raise ValueError('Data source type not supported')
indexing_runner = IndexingRunner()
@@ -519,6 +535,7 @@ def get(self, vector_type):
raise ValueError(f"Unsupported vector db type {vector_type}.")



class DatasetErrorDocs(Resource):
@setup_required
@login_required
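For reference, a sketch of the `info_list` payload the new `website_crawl` branch above consumes; the field values are illustrative and the `doc_form` value is an assumption:

```python
# Hypothetical indexing-estimate request body exercising the new branch.
payload = {
    "info_list": {
        "data_source_type": "website_crawl",
        "website_info_list": {
            "provider": "firecrawl",
            "job_id": "<crawl-job-id>",            # as returned by /website/crawl
            "urls": ["https://example.com/docs"],
            "only_main_content": True,
        },
    },
    "doc_form": "text_model",  # assumed document form value
}
```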
43 changes: 43 additions & 0 deletions api/controllers/console/datasets/datasets_document.py
@@ -465,6 +465,20 @@ def get(self, dataset_id, batch):
document_model=document.doc_form
)
extract_settings.append(extract_setting)
elif document.data_source_type == 'website_crawl':
extract_setting = ExtractSetting(
datasource_type="website_crawl",
website_info={
"provider": data_source_info['provider'],
"job_id": data_source_info['job_id'],
"url": data_source_info['url'],
"tenant_id": current_user.current_tenant_id,
"mode": data_source_info['mode'],
"only_main_content": data_source_info['only_main_content']
},
document_model=document.doc_form
)
extract_settings.append(extract_setting)

else:
raise ValueError('Data source type not supported')
@@ -952,6 +966,33 @@ def post(self, dataset_id, document_id):
return document


class WebsiteDocumentSyncApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, document_id):
"""sync website document."""
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound('Document not found.')
if document.tenant_id != current_user.current_tenant_id:
raise Forbidden('No permission.')
if document.data_source_type != 'website_crawl':
raise ValueError('Document is not a website document.')
# 403 if document is archived
if DocumentService.check_archived(document):
raise ArchivedDocumentImmutableError()
# sync document
DocumentService.sync_website_document(dataset_id, document)

return {'result': 'success'}, 200


api.add_resource(GetProcessRuleApi, '/datasets/process-rule')
api.add_resource(DatasetDocumentListApi,
'/datasets/<uuid:dataset_id>/documents')
@@ -980,3 +1021,5 @@ def post(self, dataset_id, document_id):
api.add_resource(DocumentRetryApi, '/datasets/<uuid:dataset_id>/retry')
api.add_resource(DocumentRenameApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename')

api.add_resource(WebsiteDocumentSyncApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync')
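A minimal sketch of triggering the new website-sync route; the base URL and auth header are assumed as above:

```python
import requests

BASE = "http://localhost:5001/console/api"  # assumed console API prefix
HEADERS = {"Authorization": "Bearer <console-access-token>"}

# Re-crawl and re-index an existing website_crawl document.
resp = requests.get(
    f"{BASE}/datasets/<dataset-uuid>/documents/<document-uuid>/website-sync",
    headers=HEADERS,
)
print(resp.json())  # {'result': 'success'} when the sync is accepted
```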
6 changes: 6 additions & 0 deletions api/controllers/console/datasets/error.py
@@ -73,6 +73,12 @@ class InvalidMetadataError(BaseHTTPException):
code = 400


class WebsiteCrawlError(BaseHTTPException):
error_code = 'crawl_failed'
description = "{message}"
code = 500


class DatasetInUseError(BaseHTTPException):
error_code = 'dataset_in_use'
description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it."
49 changes: 49 additions & 0 deletions api/controllers/console/datasets/website.py
@@ -0,0 +1,49 @@
from flask_restful import Resource, reqparse

from controllers.console import api
from controllers.console.datasets.error import WebsiteCrawlError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from libs.login import login_required
from services.website_service import WebsiteService


class WebsiteCrawlApi(Resource):

@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('provider', type=str, choices=['firecrawl'],
required=True, nullable=True, location='json')
parser.add_argument('url', type=str, required=True, nullable=True, location='json')
parser.add_argument('options', type=dict, required=True, nullable=True, location='json')
args = parser.parse_args()
WebsiteService.document_create_args_validate(args)
# crawl url
try:
result = WebsiteService.crawl_url(args)
except Exception as e:
raise WebsiteCrawlError(str(e))
return result, 200


class WebsiteCrawlStatusApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, job_id: str):
parser = reqparse.RequestParser()
parser.add_argument('provider', type=str, choices=['firecrawl'], required=True, location='args')
args = parser.parse_args()
# get crawl status
try:
result = WebsiteService.get_crawl_status(job_id, args['provider'])
except Exception as e:
raise WebsiteCrawlError(str(e))
return result, 200


api.add_resource(WebsiteCrawlApi, '/website/crawl')
api.add_resource(WebsiteCrawlStatusApi, '/website/crawl/status/<string:job_id>')
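And a sketch of driving the two crawl routes end to end; the `options` keys and response fields are assumptions about the Firecrawl-backed service, not guaranteed by this diff:

```python
import time

import requests

BASE = "http://localhost:5001/console/api"  # assumed console API prefix
HEADERS = {"Authorization": "Bearer <console-access-token>"}

# Kick off a crawl job.
resp = requests.post(f"{BASE}/website/crawl", headers=HEADERS, json={
    "provider": "firecrawl",
    "url": "https://example.com",
    "options": {"limit": 10, "only_main_content": True},  # assumed option keys
})
job_id = resp.json().get("job_id")  # assumed response field

# Poll the status endpoint until the job finishes.
while True:
    status = requests.get(
        f"{BASE}/website/crawl/status/{job_id}",
        headers=HEADERS,
        params={"provider": "firecrawl"},
    ).json()
    if status.get("status") == "completed":  # assumed status value
        break
    time.sleep(2)
```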
19 changes: 18 additions & 1 deletion api/core/indexing_runner.py
@@ -339,7 +339,7 @@ def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSettin
def _extract(self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict) \
-> list[Document]:
# load file
-if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
+if dataset_document.data_source_type not in ["upload_file", "notion_import", "website_crawl"]:
return []

data_source_info = dataset_document.data_source_info_dict
@@ -375,6 +375,23 @@ def _extract(self, index_processor: BaseIndexProcessor, dataset_document: Datase
document_model=dataset_document.doc_form
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
elif dataset_document.data_source_type == 'website_crawl':
if (not data_source_info or 'provider' not in data_source_info
or 'url' not in data_source_info or 'job_id' not in data_source_info):
raise ValueError("no website import info found")
extract_setting = ExtractSetting(
datasource_type="website_crawl",
website_info={
"provider": data_source_info['provider'],
"job_id": data_source_info['job_id'],
"tenant_id": dataset_document.tenant_id,
"url": data_source_info['url'],
"mode": data_source_info['mode'],
"only_main_content": data_source_info['only_main_content']
},
document_model=dataset_document.doc_form
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
# update document status to splitting
self._update_document_index_status(
document_id=dataset_document.id,
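For reference, the guard above expects a `data_source_info` JSON like the following on a `website_crawl` document; the values are illustrative:

```python
# Example data_source_info payload satisfying the provider/url/job_id check.
data_source_info = {
    "provider": "firecrawl",
    "job_id": "<crawl-job-id>",
    "url": "https://example.com/docs",
    "mode": "crawl",
    "only_main_content": True,
}
```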
@@ -124,7 +124,7 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode
default=float(credentials.get('presence_penalty', 0)),
min=-2,
max=2
-)
+),
],
pricing=PriceConfig(
input=Decimal(cred_with_endpoint.get('input_price', 0)),
1 change: 1 addition & 0 deletions api/core/rag/extractor/entity/datasource_type.py
@@ -4,3 +4,4 @@
class DatasourceType(Enum):
FILE = "upload_file"
NOTION = "notion_import"
WEBSITE = "website_crawl"
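A tiny usage sketch tying the new enum member to the string literal used throughout this commit; the module path is assumed from the file location:

```python
from core.rag.extractor.entity.datasource_type import DatasourceType

assert DatasourceType.WEBSITE.value == "website_crawl"
```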