diff --git a/api/.env.example b/api/.env.example index 4b479e6175a48..b91daab851a31 100644 --- a/api/.env.example +++ b/api/.env.example @@ -215,4 +215,5 @@ WORKFLOW_MAX_EXECUTION_TIME=1200 WORKFLOW_CALL_MAX_DEPTH=5 # App configuration -APP_MAX_EXECUTION_TIME=1200 \ No newline at end of file +APP_MAX_EXECUTION_TIME=1200 + diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 306b7384cfab4..29eac070a08fc 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -29,13 +29,13 @@ ) # Import auth controllers -from .auth import activate, data_source_oauth, login, oauth +from .auth import activate, data_source_bearer_auth, data_source_oauth, login, oauth # Import billing controllers from .billing import billing # Import datasets controllers -from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing +from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing, website # Import explore controllers from .explore import ( diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py new file mode 100644 index 0000000000000..81678f61fcbf0 --- /dev/null +++ b/api/controllers/console/auth/data_source_bearer_auth.py @@ -0,0 +1,67 @@ +from flask_login import current_user +from flask_restful import Resource, reqparse +from werkzeug.exceptions import Forbidden + +from controllers.console import api +from controllers.console.auth.error import ApiKeyAuthFailedError +from libs.login import login_required +from services.auth.api_key_auth_service import ApiKeyAuthService + +from ..setup import setup_required +from ..wraps import account_initialization_required + + +class ApiKeyAuthDataSource(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self): + # The role of the current user in the table must be admin or owner + if not current_user.is_admin_or_owner: + raise Forbidden() + data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id) + if data_source_api_key_bindings: + return { + 'settings': [data_source_api_key_binding.to_dict() for data_source_api_key_binding in + data_source_api_key_bindings]} + return {'settings': []} + + +class ApiKeyAuthDataSourceBinding(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self): + # The role of the current user in the table must be admin or owner + if not current_user.is_admin_or_owner: + raise Forbidden() + parser = reqparse.RequestParser() + parser.add_argument('category', type=str, required=True, nullable=False, location='json') + parser.add_argument('provider', type=str, required=True, nullable=False, location='json') + parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + args = parser.parse_args() + ApiKeyAuthService.validate_api_key_auth_args(args) + try: + ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args) + except Exception as e: + raise ApiKeyAuthFailedError(str(e)) + return {'result': 'success'}, 200 + + +class ApiKeyAuthDataSourceBindingDelete(Resource): + @setup_required + @login_required + @account_initialization_required + def delete(self, binding_id): + # The role of the current user in the table must be admin or owner + if not current_user.is_admin_or_owner: + raise Forbidden() + + 
ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id) + + return {'result': 'success'}, 200 + + +api.add_resource(ApiKeyAuthDataSource, '/api-key-auth/data-source') +api.add_resource(ApiKeyAuthDataSourceBinding, '/api-key-auth/data-source/binding') +api.add_resource(ApiKeyAuthDataSourceBindingDelete, '/api-key-auth/data-source/') diff --git a/api/controllers/console/auth/error.py b/api/controllers/console/auth/error.py new file mode 100644 index 0000000000000..c55ff8707d224 --- /dev/null +++ b/api/controllers/console/auth/error.py @@ -0,0 +1,7 @@ +from libs.exception import BaseHTTPException + + +class ApiKeyAuthFailedError(BaseHTTPException): + error_code = 'auth_failed' + description = "{message}" + code = 500 diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 8b210cc756bc0..0ca0f0a85653d 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -16,7 +16,7 @@ from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields from libs.login import login_required from models.dataset import Document -from models.source import DataSourceBinding +from models.source import DataSourceOauthBinding from services.dataset_service import DatasetService, DocumentService from tasks.document_indexing_sync_task import document_indexing_sync_task @@ -29,9 +29,9 @@ class DataSourceApi(Resource): @marshal_with(integrate_list_fields) def get(self): # get workspace data source integrates - data_source_integrates = db.session.query(DataSourceBinding).filter( - DataSourceBinding.tenant_id == current_user.current_tenant_id, - DataSourceBinding.disabled == False + data_source_integrates = db.session.query(DataSourceOauthBinding).filter( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.disabled == False ).all() base_url = request.url_root.rstrip('/') @@ -71,7 +71,7 @@ def get(self): def patch(self, binding_id, action): binding_id = str(binding_id) action = str(action) - data_source_binding = DataSourceBinding.query.filter_by( + data_source_binding = DataSourceOauthBinding.query.filter_by( id=binding_id ).first() if data_source_binding is None: @@ -124,7 +124,7 @@ def get(self): data_source_info = json.loads(document.data_source_info) exist_page_ids.append(data_source_info['notion_page_id']) # get all authorized pages - data_source_bindings = DataSourceBinding.query.filter_by( + data_source_bindings = DataSourceOauthBinding.query.filter_by( tenant_id=current_user.current_tenant_id, provider='notion', disabled=False @@ -163,12 +163,12 @@ class DataSourceNotionApi(Resource): def get(self, workspace_id, page_id, page_type): workspace_id = str(workspace_id) page_id = str(page_id) - data_source_binding = DataSourceBinding.query.filter( + data_source_binding = DataSourceOauthBinding.query.filter( db.and_( - DataSourceBinding.tenant_id == current_user.current_tenant_id, - DataSourceBinding.provider == 'notion', - DataSourceBinding.disabled == False, - DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"' + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.disabled == False, + DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"' ) ).first() if not data_source_binding: diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py 
index 81ef0b8925e8a..cb14abe9231d1 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -315,6 +315,22 @@ def post(self): document_model=args['doc_form'] ) extract_settings.append(extract_setting) + elif args['info_list']['data_source_type'] == 'website_crawl': + website_info_list = args['info_list']['website_info_list'] + for url in website_info_list['urls']: + extract_setting = ExtractSetting( + datasource_type="website_crawl", + website_info={ + "provider": website_info_list['provider'], + "job_id": website_info_list['job_id'], + "url": url, + "tenant_id": current_user.current_tenant_id, + "mode": 'crawl', + "only_main_content": website_info_list['only_main_content'] + }, + document_model=args['doc_form'] + ) + extract_settings.append(extract_setting) else: raise ValueError('Data source type not support') indexing_runner = IndexingRunner() @@ -519,6 +535,7 @@ def get(self, vector_type): raise ValueError(f"Unsupported vector db type {vector_type}.") + class DatasetErrorDocs(Resource): @setup_required @login_required diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index cdb8a46277b11..976b7df629243 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -465,6 +465,20 @@ def get(self, dataset_id, batch): document_model=document.doc_form ) extract_settings.append(extract_setting) + elif document.data_source_type == 'website_crawl': + extract_setting = ExtractSetting( + datasource_type="website_crawl", + website_info={ + "provider": data_source_info['provider'], + "job_id": data_source_info['job_id'], + "url": data_source_info['url'], + "tenant_id": current_user.current_tenant_id, + "mode": data_source_info['mode'], + "only_main_content": data_source_info['only_main_content'] + }, + document_model=document.doc_form + ) + extract_settings.append(extract_setting) else: raise ValueError('Data source type not support') @@ -952,6 +966,33 @@ def post(self, dataset_id, document_id): return document +class WebsiteDocumentSyncApi(DocumentResource): + @setup_required + @login_required + @account_initialization_required + def get(self, dataset_id, document_id): + """sync website document.""" + dataset_id = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound('Dataset not found.') + document_id = str(document_id) + document = DocumentService.get_document(dataset.id, document_id) + if not document: + raise NotFound('Document not found.') + if document.tenant_id != current_user.current_tenant_id: + raise Forbidden('No permission.') + if document.data_source_type != 'website_crawl': + raise ValueError('Document is not a website document.') + # 403 if document is archived + if DocumentService.check_archived(document): + raise ArchivedDocumentImmutableError() + # sync document + DocumentService.sync_website_document(dataset_id, document) + + return {'result': 'success'}, 200 + + api.add_resource(GetProcessRuleApi, '/datasets/process-rule') api.add_resource(DatasetDocumentListApi, '/datasets//documents') @@ -980,3 +1021,5 @@ def post(self, dataset_id, document_id): api.add_resource(DocumentRetryApi, '/datasets//retry') api.add_resource(DocumentRenameApi, '/datasets//documents//rename') + +api.add_resource(WebsiteDocumentSyncApi, '/datasets//documents//website-sync') diff --git a/api/controllers/console/datasets/error.py 
b/api/controllers/console/datasets/error.py index e77693b6c9495..71476764aaa15 100644 --- a/api/controllers/console/datasets/error.py +++ b/api/controllers/console/datasets/error.py @@ -73,6 +73,12 @@ class InvalidMetadataError(BaseHTTPException): code = 400 +class WebsiteCrawlError(BaseHTTPException): + error_code = 'crawl_failed' + description = "{message}" + code = 500 + + class DatasetInUseError(BaseHTTPException): error_code = 'dataset_in_use' description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it." diff --git a/api/controllers/console/datasets/website.py b/api/controllers/console/datasets/website.py new file mode 100644 index 0000000000000..bbd91256f1c29 --- /dev/null +++ b/api/controllers/console/datasets/website.py @@ -0,0 +1,49 @@ +from flask_restful import Resource, reqparse + +from controllers.console import api +from controllers.console.datasets.error import WebsiteCrawlError +from controllers.console.setup import setup_required +from controllers.console.wraps import account_initialization_required +from libs.login import login_required +from services.website_service import WebsiteService + + +class WebsiteCrawlApi(Resource): + + @setup_required + @login_required + @account_initialization_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument('provider', type=str, choices=['firecrawl'], + required=True, nullable=True, location='json') + parser.add_argument('url', type=str, required=True, nullable=True, location='json') + parser.add_argument('options', type=dict, required=True, nullable=True, location='json') + args = parser.parse_args() + WebsiteService.document_create_args_validate(args) + # crawl url + try: + result = WebsiteService.crawl_url(args) + except Exception as e: + raise WebsiteCrawlError(str(e)) + return result, 200 + + +class WebsiteCrawlStatusApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, job_id: str): + parser = reqparse.RequestParser() + parser.add_argument('provider', type=str, choices=['firecrawl'], required=True, location='args') + args = parser.parse_args() + # get crawl status + try: + result = WebsiteService.get_crawl_status(job_id, args['provider']) + except Exception as e: + raise WebsiteCrawlError(str(e)) + return result, 200 + + +api.add_resource(WebsiteCrawlApi, '/website/crawl') +api.add_resource(WebsiteCrawlStatusApi, '/website/crawl/status/') diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index f2fb3771431ae..af4bed13efbe3 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -339,7 +339,7 @@ def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSettin def _extract(self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict) \ -> list[Document]: # load file - if dataset_document.data_source_type not in ["upload_file", "notion_import"]: + if dataset_document.data_source_type not in ["upload_file", "notion_import", "website_crawl"]: return [] data_source_info = dataset_document.data_source_info_dict @@ -375,6 +375,23 @@ def _extract(self, index_processor: BaseIndexProcessor, dataset_document: Datase document_model=dataset_document.doc_form ) text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode']) + elif dataset_document.data_source_type == 'website_crawl': + if (not data_source_info or 'provider' not in data_source_info + or 'url' not in data_source_info or 'job_id' 
not in data_source_info): + raise ValueError("no website import info found") + extract_setting = ExtractSetting( + datasource_type="website_crawl", + website_info={ + "provider": data_source_info['provider'], + "job_id": data_source_info['job_id'], + "tenant_id": dataset_document.tenant_id, + "url": data_source_info['url'], + "mode": data_source_info['mode'], + "only_main_content": data_source_info['only_main_content'] + }, + document_model=dataset_document.doc_form + ) + text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode']) # update document status to splitting self._update_document_index_status( document_id=dataset_document.id, diff --git a/api/core/model_runtime/model_providers/togetherai/llm/llm.py b/api/core/model_runtime/model_providers/togetherai/llm/llm.py index 64b7341e9ff7c..bb802d407157b 100644 --- a/api/core/model_runtime/model_providers/togetherai/llm/llm.py +++ b/api/core/model_runtime/model_providers/togetherai/llm/llm.py @@ -124,7 +124,7 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode default=float(credentials.get('presence_penalty', 0)), min=-2, max=2 - ) + ), ], pricing=PriceConfig( input=Decimal(cred_with_endpoint.get('input_price', 0)), diff --git a/api/core/rag/extractor/entity/datasource_type.py b/api/core/rag/extractor/entity/datasource_type.py index 2c79e7b97b18d..19ad300d110fe 100644 --- a/api/core/rag/extractor/entity/datasource_type.py +++ b/api/core/rag/extractor/entity/datasource_type.py @@ -4,3 +4,4 @@ class DatasourceType(Enum): FILE = "upload_file" NOTION = "notion_import" + WEBSITE = "website_crawl" diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py index 33d9786691932..e474cf376f063 100644 --- a/api/core/rag/extractor/entity/extract_setting.py +++ b/api/core/rag/extractor/entity/extract_setting.py @@ -1,3 +1,5 @@ +from typing import Optional + from pydantic import BaseModel, ConfigDict from models.dataset import Document @@ -19,14 +21,33 @@ def __init__(self, **data) -> None: super().__init__(**data) +class WebsiteInfo(BaseModel): + """ + website import info. + """ + provider: str + job_id: str + url: str + mode: str + tenant_id: str + only_main_content: bool = False + + class Config: + arbitrary_types_allowed = True + + def __init__(self, **data) -> None: + super().__init__(**data) + + class ExtractSetting(BaseModel): """ Model class for provider response. 
""" datasource_type: str - upload_file: UploadFile = None - notion_info: NotionInfo = None - document_model: str = None + upload_file: Optional[UploadFile] + notion_info: Optional[NotionInfo] + website_info: Optional[WebsiteInfo] + document_model: Optional[str] model_config = ConfigDict(arbitrary_types_allowed=True) def __init__(self, **data) -> None: diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index 09d192d4101db..909bfdc137ff7 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -11,6 +11,7 @@ from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.excel_extractor import ExcelExtractor +from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor from core.rag.extractor.html_extractor import HtmlExtractor from core.rag.extractor.markdown_extractor import MarkdownExtractor from core.rag.extractor.notion_extractor import NotionExtractor @@ -154,5 +155,17 @@ def extract(cls, extract_setting: ExtractSetting, is_automatic: bool = False, tenant_id=extract_setting.notion_info.tenant_id, ) return extractor.extract() + elif extract_setting.datasource_type == DatasourceType.WEBSITE.value: + if extract_setting.website_info.provider == 'firecrawl': + extractor = FirecrawlWebExtractor( + url=extract_setting.website_info.url, + job_id=extract_setting.website_info.job_id, + tenant_id=extract_setting.website_info.tenant_id, + mode=extract_setting.website_info.mode, + only_main_content=extract_setting.website_info.only_main_content + ) + return extractor.extract() + else: + raise ValueError(f"Unsupported website provider: {extract_setting.website_info.provider}") else: raise ValueError(f"Unsupported datasource type: {extract_setting.datasource_type}") diff --git a/api/core/rag/extractor/firecrawl/firecrawl_app.py b/api/core/rag/extractor/firecrawl/firecrawl_app.py new file mode 100644 index 0000000000000..af6b568936690 --- /dev/null +++ b/api/core/rag/extractor/firecrawl/firecrawl_app.py @@ -0,0 +1,132 @@ +import json +import time + +import requests + +from extensions.ext_storage import storage + + +class FirecrawlApp: + def __init__(self, api_key=None, base_url=None): + self.api_key = api_key + self.base_url = base_url or 'https://api.firecrawl.dev' + if self.api_key is None and self.base_url == 'https://api.firecrawl.dev': + raise ValueError('No API key provided') + + def scrape_url(self, url, params=None) -> dict: + headers = { + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {self.api_key}' + } + json_data = {'url': url} + if params: + json_data.update(params) + response = requests.post( + f'{self.base_url}/v0/scrape', + headers=headers, + json=json_data + ) + if response.status_code == 200: + response = response.json() + if response['success'] == True: + data = response['data'] + return { + 'title': data.get('metadata').get('title'), + 'description': data.get('metadata').get('description'), + 'source_url': data.get('metadata').get('sourceURL'), + 'markdown': data.get('markdown') + } + else: + raise Exception(f'Failed to scrape URL. Error: {response["error"]}') + + elif response.status_code in [402, 409, 500]: + error_message = response.json().get('error', 'Unknown error occurred') + raise Exception(f'Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}') + else: + raise Exception(f'Failed to scrape URL. 
Status code: {response.status_code}') + + def crawl_url(self, url, params=None) -> str: + start_time = time.time() + headers = self._prepare_headers() + json_data = {'url': url} + if params: + json_data.update(params) + response = self._post_request(f'{self.base_url}/v0/crawl', json_data, headers) + if response.status_code == 200: + job_id = response.json().get('jobId') + return job_id + else: + self._handle_error(response, 'start crawl job') + + def check_crawl_status(self, job_id) -> dict: + headers = self._prepare_headers() + response = self._get_request(f'{self.base_url}/v0/crawl/status/{job_id}', headers) + if response.status_code == 200: + crawl_status_response = response.json() + if crawl_status_response.get('status') == 'completed': + total = crawl_status_response.get('total', 0) + if total == 0: + raise Exception('Failed to check crawl status. Error: No page found') + data = crawl_status_response.get('data', []) + url_data_list = [] + for item in data: + if isinstance(item, dict) and 'metadata' in item and 'markdown' in item: + url_data = { + 'title': item.get('metadata').get('title'), + 'description': item.get('metadata').get('description'), + 'source_url': item.get('metadata').get('sourceURL'), + 'markdown': item.get('markdown') + } + url_data_list.append(url_data) + if url_data_list: + file_key = 'website_files/' + job_id + '.txt' + if storage.exists(file_key): + storage.delete(file_key) + storage.save(file_key, json.dumps(url_data_list).encode('utf-8')) + return { + 'status': 'completed', + 'total': crawl_status_response.get('total'), + 'current': crawl_status_response.get('current'), + 'data': url_data_list + } + + else: + return { + 'status': crawl_status_response.get('status'), + 'total': crawl_status_response.get('total'), + 'current': crawl_status_response.get('current'), + 'data': [] + } + + else: + self._handle_error(response, 'check crawl status') + + def _prepare_headers(self): + return { + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {self.api_key}' + } + + def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5): + for attempt in range(retries): + response = requests.post(url, headers=headers, json=data) + if response.status_code == 502: + time.sleep(backoff_factor * (2 ** attempt)) + else: + return response + return response + + def _get_request(self, url, headers, retries=3, backoff_factor=0.5): + for attempt in range(retries): + response = requests.get(url, headers=headers) + if response.status_code == 502: + time.sleep(backoff_factor * (2 ** attempt)) + else: + return response + return response + + def _handle_error(self, response, action): + error_message = response.json().get('error', 'Unknown error occurred') + raise Exception(f'Failed to {action}. Status code: {response.status_code}. Error: {error_message}') + + diff --git a/api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py b/api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py new file mode 100644 index 0000000000000..8e2f107e5eb79 --- /dev/null +++ b/api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py @@ -0,0 +1,60 @@ +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.models.document import Document +from services.website_service import WebsiteService + + +class FirecrawlWebExtractor(BaseExtractor): + """ + Crawl and scrape websites and return content in clean llm-ready markdown. + + + Args: + url: The URL to scrape. + api_key: The API key for Firecrawl. + base_url: The base URL for the Firecrawl API. 
Defaults to 'https://api.firecrawl.dev'. + mode: The mode of operation. Defaults to 'scrape'. Options are 'crawl', 'scrape' and 'crawl_return_urls'. + """ + + def __init__( + self, + url: str, + job_id: str, + tenant_id: str, + mode: str = 'crawl', + only_main_content: bool = False + ): + """Initialize with url, api_key, base_url and mode.""" + self._url = url + self.job_id = job_id + self.tenant_id = tenant_id + self.mode = mode + self.only_main_content = only_main_content + + def extract(self) -> list[Document]: + """Extract content from the URL.""" + documents = [] + if self.mode == 'crawl': + crawl_data = WebsiteService.get_crawl_url_data(self.job_id, 'firecrawl', self._url, self.tenant_id) + if crawl_data is None: + return [] + document = Document(page_content=crawl_data.get('markdown', ''), + metadata={ + 'source_url': crawl_data.get('source_url'), + 'description': crawl_data.get('description'), + 'title': crawl_data.get('title') + } + ) + documents.append(document) + elif self.mode == 'scrape': + scrape_data = WebsiteService.get_scrape_url_data('firecrawl', self._url, self.tenant_id, + self.only_main_content) + + document = Document(page_content=scrape_data.get('markdown', ''), + metadata={ + 'source_url': scrape_data.get('source_url'), + 'description': scrape_data.get('description'), + 'title': scrape_data.get('title') + } + ) + documents.append(document) + return documents diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index 1885ad3aca893..4ec0b4fc3861c 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -9,7 +9,7 @@ from core.rag.models.document import Document from extensions.ext_database import db from models.dataset import Document as DocumentModel -from models.source import DataSourceBinding +from models.source import DataSourceOauthBinding logger = logging.getLogger(__name__) @@ -345,12 +345,12 @@ def get_notion_last_edited_time(self) -> str: @classmethod def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str: - data_source_binding = DataSourceBinding.query.filter( + data_source_binding = DataSourceOauthBinding.query.filter( db.and_( - DataSourceBinding.tenant_id == tenant_id, - DataSourceBinding.provider == 'notion', - DataSourceBinding.disabled == False, - DataSourceBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"' + DataSourceOauthBinding.tenant_id == tenant_id, + DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.disabled == False, + DataSourceOauthBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"' ) ).first() diff --git a/api/libs/bearer_data_source.py b/api/libs/bearer_data_source.py new file mode 100644 index 0000000000000..04de1fb6daefb --- /dev/null +++ b/api/libs/bearer_data_source.py @@ -0,0 +1,64 @@ +# [REVIEW] Implement if Needed? 
Do we need a new type of data source +from abc import abstractmethod + +import requests +from api.models.source import DataSourceBearerBinding +from flask_login import current_user + +from extensions.ext_database import db + + +class BearerDataSource: + def __init__(self, api_key: str, api_base_url: str): + self.api_key = api_key + self.api_base_url = api_base_url + + @abstractmethod + def validate_bearer_data_source(self): + """ + Validate the data source + """ + + +class FireCrawlDataSource(BearerDataSource): + def validate_bearer_data_source(self): + TEST_CRAWL_SITE_URL = "https://www.google.com" + FIRECRAWL_API_VERSION = "v0" + + test_api_endpoint = self.api_base_url.rstrip('/') + f"/{FIRECRAWL_API_VERSION}/scrape" + + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + data = { + "url": TEST_CRAWL_SITE_URL, + } + + response = requests.get(test_api_endpoint, headers=headers, json=data) + + return response.json().get("status") == "success" + + def save_credentials(self): + # save data source binding + data_source_binding = DataSourceBearerBinding.query.filter( + db.and_( + DataSourceBearerBinding.tenant_id == current_user.current_tenant_id, + DataSourceBearerBinding.provider == 'firecrawl', + DataSourceBearerBinding.endpoint_url == self.api_base_url, + DataSourceBearerBinding.bearer_key == self.api_key + ) + ).first() + if data_source_binding: + data_source_binding.disabled = False + db.session.commit() + else: + new_data_source_binding = DataSourceBearerBinding( + tenant_id=current_user.current_tenant_id, + provider='firecrawl', + endpoint_url=self.api_base_url, + bearer_key=self.api_key + ) + db.session.add(new_data_source_binding) + db.session.commit() diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index a865ee85ab5c6..3f2889adbefa3 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -4,7 +4,7 @@ from flask_login import current_user from extensions.ext_database import db -from models.source import DataSourceBinding +from models.source import DataSourceOauthBinding class OAuthDataSource: @@ -63,11 +63,11 @@ def get_access_token(self, code: str): 'total': len(pages) } # save data source binding - data_source_binding = DataSourceBinding.query.filter( + data_source_binding = DataSourceOauthBinding.query.filter( db.and_( - DataSourceBinding.tenant_id == current_user.current_tenant_id, - DataSourceBinding.provider == 'notion', - DataSourceBinding.access_token == access_token + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.access_token == access_token ) ).first() if data_source_binding: @@ -75,7 +75,7 @@ def get_access_token(self, code: str): data_source_binding.disabled = False db.session.commit() else: - new_data_source_binding = DataSourceBinding( + new_data_source_binding = DataSourceOauthBinding( tenant_id=current_user.current_tenant_id, access_token=access_token, source_info=source_info, @@ -98,11 +98,11 @@ def save_internal_access_token(self, access_token: str): 'total': len(pages) } # save data source binding - data_source_binding = DataSourceBinding.query.filter( + data_source_binding = DataSourceOauthBinding.query.filter( db.and_( - DataSourceBinding.tenant_id == current_user.current_tenant_id, - DataSourceBinding.provider == 'notion', - DataSourceBinding.access_token == access_token + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + 
DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.access_token == access_token ) ).first() if data_source_binding: @@ -110,7 +110,7 @@ def save_internal_access_token(self, access_token: str): data_source_binding.disabled = False db.session.commit() else: - new_data_source_binding = DataSourceBinding( + new_data_source_binding = DataSourceOauthBinding( tenant_id=current_user.current_tenant_id, access_token=access_token, source_info=source_info, @@ -121,12 +121,12 @@ def save_internal_access_token(self, access_token: str): def sync_data_source(self, binding_id: str): # save data source binding - data_source_binding = DataSourceBinding.query.filter( + data_source_binding = DataSourceOauthBinding.query.filter( db.and_( - DataSourceBinding.tenant_id == current_user.current_tenant_id, - DataSourceBinding.provider == 'notion', - DataSourceBinding.id == binding_id, - DataSourceBinding.disabled == False + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.id == binding_id, + DataSourceOauthBinding.disabled == False ) ).first() if data_source_binding: diff --git a/api/migrations/versions/7b45942e39bb_add_api_key_auth_binding.py b/api/migrations/versions/7b45942e39bb_add_api_key_auth_binding.py new file mode 100644 index 0000000000000..f63bad93457d3 --- /dev/null +++ b/api/migrations/versions/7b45942e39bb_add_api_key_auth_binding.py @@ -0,0 +1,67 @@ +"""add-api-key-auth-binding + +Revision ID: 7b45942e39bb +Revises: 47cc7df8c4f3 +Create Date: 2024-05-14 07:31:29.702766 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '7b45942e39bb' +down_revision = '4e99a8df00ff' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('data_source_api_key_auth_bindings', + sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.StringUUID(), nullable=False), + sa.Column('category', sa.String(length=255), nullable=False), + sa.Column('provider', sa.String(length=255), nullable=False), + sa.Column('credentials', sa.Text(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('disabled', sa.Boolean(), server_default=sa.text('false'), nullable=True), + sa.PrimaryKeyConstraint('id', name='data_source_api_key_auth_binding_pkey') + ) + with op.batch_alter_table('data_source_api_key_auth_bindings', schema=None) as batch_op: + batch_op.create_index('data_source_api_key_auth_binding_provider_idx', ['provider'], unique=False) + batch_op.create_index('data_source_api_key_auth_binding_tenant_id_idx', ['tenant_id'], unique=False) + + with op.batch_alter_table('data_source_bindings', schema=None) as batch_op: + batch_op.drop_index('source_binding_tenant_id_idx') + batch_op.drop_index('source_info_idx') + + op.rename_table('data_source_bindings', 'data_source_oauth_bindings') + + with op.batch_alter_table('data_source_oauth_bindings', schema=None) as batch_op: + batch_op.create_index('source_binding_tenant_id_idx', ['tenant_id'], unique=False) + batch_op.create_index('source_info_idx', ['source_info'], unique=False, postgresql_using='gin') + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + with op.batch_alter_table('data_source_oauth_bindings', schema=None) as batch_op: + batch_op.drop_index('source_info_idx', postgresql_using='gin') + batch_op.drop_index('source_binding_tenant_id_idx') + + op.rename_table('data_source_oauth_bindings', 'data_source_bindings') + + with op.batch_alter_table('data_source_bindings', schema=None) as batch_op: + batch_op.create_index('source_info_idx', ['source_info'], unique=False) + batch_op.create_index('source_binding_tenant_id_idx', ['tenant_id'], unique=False) + + with op.batch_alter_table('data_source_api_key_auth_bindings', schema=None) as batch_op: + batch_op.drop_index('data_source_api_key_auth_binding_tenant_id_idx') + batch_op.drop_index('data_source_api_key_auth_binding_provider_idx') + + op.drop_table('data_source_api_key_auth_bindings') + # ### end Alembic commands ### diff --git a/api/models/dataset.py b/api/models/dataset.py index 7f98bbde15349..9f8b15be1a45e 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -270,7 +270,7 @@ class Document(db.Model): 255), nullable=False, server_default=db.text("'text_model'::character varying")) doc_language = db.Column(db.String(255), nullable=True) - DATA_SOURCES = ['upload_file', 'notion_import'] + DATA_SOURCES = ['upload_file', 'notion_import', 'website_crawl'] @property def display_status(self): @@ -322,7 +322,7 @@ def data_source_detail_dict(self): 'created_at': file_detail.created_at.timestamp() } } - elif self.data_source_type == 'notion_import': + elif self.data_source_type == 'notion_import' or self.data_source_type == 'website_crawl': return json.loads(self.data_source_info) return {} diff --git a/api/models/source.py b/api/models/source.py index 97ba23a5bddbf..265e68f014c6c 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -1,11 +1,13 @@ +import json + from 
sqlalchemy.dialects.postgresql import JSONB from extensions.ext_database import db from models import StringUUID -class DataSourceBinding(db.Model): - __tablename__ = 'data_source_bindings' +class DataSourceOauthBinding(db.Model): + __tablename__ = 'data_source_oauth_bindings' __table_args__ = ( db.PrimaryKeyConstraint('id', name='source_binding_pkey'), db.Index('source_binding_tenant_id_idx', 'tenant_id'), @@ -20,3 +22,33 @@ class DataSourceBinding(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) disabled = db.Column(db.Boolean, nullable=True, server_default=db.text('false')) + + +class DataSourceApiKeyAuthBinding(db.Model): + __tablename__ = 'data_source_api_key_auth_bindings' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='data_source_api_key_auth_binding_pkey'), + db.Index('data_source_api_key_auth_binding_tenant_id_idx', 'tenant_id'), + db.Index('data_source_api_key_auth_binding_provider_idx', 'provider'), + ) + + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(StringUUID, nullable=False) + category = db.Column(db.String(255), nullable=False) + provider = db.Column(db.String(255), nullable=False) + credentials = db.Column(db.Text, nullable=True) # JSON + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + disabled = db.Column(db.Boolean, nullable=True, server_default=db.text('false')) + + def to_dict(self): + return { + 'id': self.id, + 'tenant_id': self.tenant_id, + 'category': self.category, + 'provider': self.provider, + 'credentials': json.loads(self.credentials), + 'created_at': self.created_at.timestamp(), + 'updated_at': self.updated_at.timestamp(), + 'disabled': self.disabled + } diff --git a/api/pyproject.toml b/api/pyproject.toml index fb6e6bf8c3630..9f2786d40641b 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -78,6 +78,9 @@ CODE_MAX_STRING_LENGTH = "80000" CODE_EXECUTION_ENDPOINT="http://127.0.0.1:8194" CODE_EXECUTION_API_KEY="dify-sandbox" +FIRECRAWL_API_KEY = "fc-" + + [tool.poetry] name = "dify-api" diff --git a/api/services/auth/__init__.py b/api/services/auth/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/api/services/auth/api_key_auth_base.py b/api/services/auth/api_key_auth_base.py new file mode 100644 index 0000000000000..dd74a8f1b539a --- /dev/null +++ b/api/services/auth/api_key_auth_base.py @@ -0,0 +1,10 @@ +from abc import ABC, abstractmethod + + +class ApiKeyAuthBase(ABC): + def __init__(self, credentials: dict): + self.credentials = credentials + + @abstractmethod + def validate_credentials(self): + raise NotImplementedError diff --git a/api/services/auth/api_key_auth_factory.py b/api/services/auth/api_key_auth_factory.py new file mode 100644 index 0000000000000..ccd0023c44d84 --- /dev/null +++ b/api/services/auth/api_key_auth_factory.py @@ -0,0 +1,14 @@ + +from services.auth.firecrawl import FirecrawlAuth + + +class ApiKeyAuthFactory: + + def __init__(self, provider: str, credentials: dict): + if provider == 'firecrawl': + self.auth = FirecrawlAuth(credentials) + else: + raise ValueError('Invalid provider') + + def validate_credentials(self): + return self.auth.validate_credentials() diff --git a/api/services/auth/api_key_auth_service.py 
b/api/services/auth/api_key_auth_service.py new file mode 100644 index 0000000000000..43d0fbf98f2df --- /dev/null +++ b/api/services/auth/api_key_auth_service.py @@ -0,0 +1,70 @@ +import json + +from core.helper import encrypter +from extensions.ext_database import db +from models.source import DataSourceApiKeyAuthBinding +from services.auth.api_key_auth_factory import ApiKeyAuthFactory + + +class ApiKeyAuthService: + + @staticmethod + def get_provider_auth_list(tenant_id: str) -> list: + data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter( + DataSourceApiKeyAuthBinding.tenant_id == tenant_id, + DataSourceApiKeyAuthBinding.disabled.is_(False) + ).all() + return data_source_api_key_bindings + + @staticmethod + def create_provider_auth(tenant_id: str, args: dict): + auth_result = ApiKeyAuthFactory(args['provider'], args['credentials']).validate_credentials() + if auth_result: + # Encrypt the api key + api_key = encrypter.encrypt_token(tenant_id, args['credentials']['config']['api_key']) + args['credentials']['config']['api_key'] = api_key + + data_source_api_key_binding = DataSourceApiKeyAuthBinding() + data_source_api_key_binding.tenant_id = tenant_id + data_source_api_key_binding.category = args['category'] + data_source_api_key_binding.provider = args['provider'] + data_source_api_key_binding.credentials = json.dumps(args['credentials'], ensure_ascii=False) + db.session.add(data_source_api_key_binding) + db.session.commit() + + @staticmethod + def get_auth_credentials(tenant_id: str, category: str, provider: str): + data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter( + DataSourceApiKeyAuthBinding.tenant_id == tenant_id, + DataSourceApiKeyAuthBinding.category == category, + DataSourceApiKeyAuthBinding.provider == provider, + DataSourceApiKeyAuthBinding.disabled.is_(False) + ).first() + if not data_source_api_key_bindings: + return None + credentials = json.loads(data_source_api_key_bindings.credentials) + return credentials + + @staticmethod + def delete_provider_auth(tenant_id: str, binding_id: str): + data_source_api_key_binding = db.session.query(DataSourceApiKeyAuthBinding).filter( + DataSourceApiKeyAuthBinding.tenant_id == tenant_id, + DataSourceApiKeyAuthBinding.id == binding_id + ).first() + if data_source_api_key_binding: + db.session.delete(data_source_api_key_binding) + db.session.commit() + + @classmethod + def validate_api_key_auth_args(cls, args): + if 'category' not in args or not args['category']: + raise ValueError('category is required') + if 'provider' not in args or not args['provider']: + raise ValueError('provider is required') + if 'credentials' not in args or not args['credentials']: + raise ValueError('credentials is required') + if not isinstance(args['credentials'], dict): + raise ValueError('credentials must be a dictionary') + if 'auth_type' not in args['credentials'] or not args['credentials']['auth_type']: + raise ValueError('auth_type is required') + diff --git a/api/services/auth/firecrawl.py b/api/services/auth/firecrawl.py new file mode 100644 index 0000000000000..69e3fb43c79da --- /dev/null +++ b/api/services/auth/firecrawl.py @@ -0,0 +1,56 @@ +import json + +import requests + +from services.auth.api_key_auth_base import ApiKeyAuthBase + + +class FirecrawlAuth(ApiKeyAuthBase): + def __init__(self, credentials: dict): + super().__init__(credentials) + auth_type = credentials.get('auth_type') + if auth_type != 'bearer': + raise ValueError('Invalid auth type, Firecrawl auth type must be 
Bearer') + self.api_key = credentials.get('config').get('api_key', None) + self.base_url = credentials.get('config').get('base_url', 'https://api.firecrawl.dev') + + if not self.api_key: + raise ValueError('No API key provided') + + def validate_credentials(self): + headers = self._prepare_headers() + options = { + 'url': 'https://example.com', + 'crawlerOptions': { + 'excludes': [], + 'includes': [], + 'limit': 1 + }, + 'pageOptions': { + 'onlyMainContent': True + } + } + response = self._post_request(f'{self.base_url}/v0/crawl', options, headers) + if response.status_code == 200: + return True + else: + self._handle_error(response) + + def _prepare_headers(self): + return { + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {self.api_key}' + } + + def _post_request(self, url, data, headers): + return requests.post(url, headers=headers, json=data) + + def _handle_error(self, response): + if response.status_code in [402, 409, 500]: + error_message = response.json().get('error', 'Unknown error occurred') + raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}') + else: + if response.text: + error_message = json.loads(response.text).get('error', 'Unknown error occurred') + raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}') + raise Exception(f'Unexpected error occurred while trying to authorize. Status code: {response.status_code}') diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 97a6f5d103f32..b3cf15811b6d9 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -31,7 +31,7 @@ DocumentSegment, ) from models.model import UploadFile -from models.source import DataSourceBinding +from models.source import DataSourceOauthBinding from services.errors.account import NoPermissionError from services.errors.dataset import DatasetInUseError, DatasetNameDuplicateError from services.errors.document import DocumentIndexingError @@ -48,6 +48,7 @@ from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task from tasks.recover_document_indexing_task import recover_document_indexing_task from tasks.retry_document_indexing_task import retry_document_indexing_task +from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task class DatasetService: @@ -508,18 +509,40 @@ def recover_document(document): @staticmethod def retry_document(dataset_id: str, documents: list[Document]): for document in documents: + # add retry flag + retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id) + cache_result = redis_client.get(retry_indexing_cache_key) + if cache_result is not None: + raise ValueError("Document is being retried, please try again later") # retry document indexing document.indexing_status = 'waiting' db.session.add(document) db.session.commit() - # add retry flag - retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id) + redis_client.setex(retry_indexing_cache_key, 600, 1) # trigger async task document_ids = [document.id for document in documents] retry_document_indexing_task.delay(dataset_id, document_ids) @staticmethod + def sync_website_document(dataset_id: str, document: Document): + # add sync flag + sync_indexing_cache_key = 'document_{}_is_sync'.format(document.id) + cache_result = redis_client.get(sync_indexing_cache_key) + if cache_result is not None: + raise ValueError("Document is being synced, please try again later") + # sync document 
indexing + document.indexing_status = 'waiting' + data_source_info = document.data_source_info_dict + data_source_info['mode'] = 'scrape' + document.data_source_info = json.dumps(data_source_info, ensure_ascii=False) + db.session.add(document) + db.session.commit() + + redis_client.setex(sync_indexing_cache_key, 600, 1) + + sync_website_document_indexing_task.delay(dataset_id, document.id) + @staticmethod def get_documents_position(dataset_id): document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first() if document: @@ -545,6 +568,9 @@ def save_document_with_dataset_id(dataset: Dataset, document_data: dict, notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] for notion_info in notion_info_list: count = count + len(notion_info['pages']) + elif document_data["data_source"]["type"] == "website_crawl": + website_info = document_data["data_source"]['info_list']['website_info_list'] + count = len(website_info['urls']) batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT']) if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") @@ -683,12 +709,12 @@ def save_document_with_dataset_id(dataset: Dataset, document_data: dict, exist_document[data_source_info['notion_page_id']] = document.id for notion_info in notion_info_list: workspace_id = notion_info['workspace_id'] - data_source_binding = DataSourceBinding.query.filter( + data_source_binding = DataSourceOauthBinding.query.filter( db.and_( - DataSourceBinding.tenant_id == current_user.current_tenant_id, - DataSourceBinding.provider == 'notion', - DataSourceBinding.disabled == False, - DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"' + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.disabled == False, + DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"' ) ).first() if not data_source_binding: @@ -717,6 +743,28 @@ def save_document_with_dataset_id(dataset: Dataset, document_data: dict, # delete not selected documents if len(exist_document) > 0: clean_notion_document_task.delay(list(exist_document.values()), dataset.id) + elif document_data["data_source"]["type"] == "website_crawl": + website_info = document_data["data_source"]['info_list']['website_info_list'] + urls = website_info['urls'] + for url in urls: + data_source_info = { + 'url': url, + 'provider': website_info['provider'], + 'job_id': website_info['job_id'], + 'only_main_content': website_info.get('only_main_content', False), + 'mode': 'crawl', + } + document = DocumentService.build_document(dataset, dataset_process_rule.id, + document_data["data_source"]["type"], + document_data["doc_form"], + document_data["doc_language"], + data_source_info, created_from, position, + account, url, batch) + db.session.add(document) + db.session.flush() + document_ids.append(document.id) + documents.append(document) + position += 1 db.session.commit() # trigger async task @@ -818,12 +866,12 @@ def update_document_with_dataset_id(dataset: Dataset, document_data: dict, notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] for notion_info in notion_info_list: workspace_id = notion_info['workspace_id'] - data_source_binding = DataSourceBinding.query.filter( + data_source_binding = DataSourceOauthBinding.query.filter( db.and_( - DataSourceBinding.tenant_id == current_user.current_tenant_id, - 
DataSourceBinding.provider == 'notion', - DataSourceBinding.disabled == False, - DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"' + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.disabled == False, + DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"' ) ).first() if not data_source_binding: @@ -835,6 +883,17 @@ def update_document_with_dataset_id(dataset: Dataset, document_data: dict, "notion_page_icon": page['page_icon'], "type": page['type'] } + elif document_data["data_source"]["type"] == "website_crawl": + website_info = document_data["data_source"]['info_list']['website_info_list'] + urls = website_info['urls'] + for url in urls: + data_source_info = { + 'url': url, + 'provider': website_info['provider'], + 'job_id': website_info['job_id'], + 'only_main_content': website_info.get('only_main_content', False), + 'mode': 'crawl', + } document.data_source_type = document_data["data_source"]["type"] document.data_source_info = json.dumps(data_source_info) document.name = file_name @@ -873,6 +932,9 @@ def save_document_without_dataset_id(tenant_id: str, document_data: dict, accoun notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] for notion_info in notion_info_list: count = count + len(notion_info['pages']) + elif document_data["data_source"]["type"] == "website_crawl": + website_info = document_data["data_source"]['info_list']['website_info_list'] + count = len(website_info['urls']) batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT']) if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") @@ -973,6 +1035,10 @@ def data_source_args_validate(cls, args: dict): if 'notion_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][ 'notion_info_list']: raise ValueError("Notion source info is required") + if args['data_source']['type'] == 'website_crawl': + if 'website_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][ + 'website_info_list']: + raise ValueError("Website source info is required") @classmethod def process_rule_args_validate(cls, args: dict): diff --git a/api/services/website_service.py b/api/services/website_service.py new file mode 100644 index 0000000000000..c166b01237b6c --- /dev/null +++ b/api/services/website_service.py @@ -0,0 +1,171 @@ +import datetime +import json + +from flask_login import current_user + +from core.helper import encrypter +from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp +from extensions.ext_redis import redis_client +from extensions.ext_storage import storage +from services.auth.api_key_auth_service import ApiKeyAuthService + + +class WebsiteService: + + @classmethod + def document_create_args_validate(cls, args: dict): + if 'url' not in args or not args['url']: + raise ValueError('url is required') + if 'options' not in args or not args['options']: + raise ValueError('options is required') + if 'limit' not in args['options'] or not args['options']['limit']: + raise ValueError('limit is required') + + @classmethod + def crawl_url(cls, args: dict) -> dict: + provider = args.get('provider') + url = args.get('url') + options = args.get('options') + credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, + 'website', + provider) + if provider == 'firecrawl': + # decrypt api_key + api_key = 
encrypter.decrypt_token( + tenant_id=current_user.current_tenant_id, + token=credentials.get('config').get('api_key') + ) + firecrawl_app = FirecrawlApp(api_key=api_key, + base_url=credentials.get('config').get('base_url', None)) + crawl_sub_pages = options.get('crawl_sub_pages', False) + only_main_content = options.get('only_main_content', False) + if not crawl_sub_pages: + params = { + 'crawlerOptions': { + "includes": [], + "excludes": [], + "generateImgAltText": True, + "limit": 1, + 'returnOnlyUrls': False, + 'pageOptions': { + 'onlyMainContent': only_main_content, + "includeHtml": False + } + } + } + else: + includes = options.get('includes').split(',') if options.get('includes') else [] + excludes = options.get('excludes').split(',') if options.get('excludes') else [] + params = { + 'crawlerOptions': { + "includes": includes if includes else [], + "excludes": excludes if excludes else [], + "generateImgAltText": True, + "limit": options.get('limit', 1), + 'returnOnlyUrls': False, + 'pageOptions': { + 'onlyMainContent': only_main_content, + "includeHtml": False + } + } + } + if options.get('max_depth'): + params['crawlerOptions']['maxDepth'] = options.get('max_depth') + job_id = firecrawl_app.crawl_url(url, params) + website_crawl_time_cache_key = f'website_crawl_{job_id}' + time = str(datetime.datetime.now().timestamp()) + redis_client.setex(website_crawl_time_cache_key, 3600, time) + return { + 'status': 'active', + 'job_id': job_id + } + else: + raise ValueError('Invalid provider') + + @classmethod + def get_crawl_status(cls, job_id: str, provider: str) -> dict: + credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, + 'website', + provider) + if provider == 'firecrawl': + # decrypt api_key + api_key = encrypter.decrypt_token( + tenant_id=current_user.current_tenant_id, + token=credentials.get('config').get('api_key') + ) + firecrawl_app = FirecrawlApp(api_key=api_key, + base_url=credentials.get('config').get('base_url', None)) + result = firecrawl_app.check_crawl_status(job_id) + crawl_status_data = { + 'status': result.get('status', 'active'), + 'job_id': job_id, + 'total': result.get('total', 0), + 'current': result.get('current', 0), + 'data': result.get('data', []) + } + if crawl_status_data['status'] == 'completed': + website_crawl_time_cache_key = f'website_crawl_{job_id}' + start_time = redis_client.get(website_crawl_time_cache_key) + if start_time: + end_time = datetime.datetime.now().timestamp() + time_consuming = abs(end_time - float(start_time)) + crawl_status_data['time_consuming'] = f"{time_consuming:.2f}" + redis_client.delete(website_crawl_time_cache_key) + else: + raise ValueError('Invalid provider') + return crawl_status_data + + @classmethod + def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict | None: + credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, + 'website', + provider) + if provider == 'firecrawl': + file_key = 'website_files/' + job_id + '.txt' + if storage.exists(file_key): + data = storage.load_once(file_key) + if data: + data = json.loads(data.decode('utf-8')) + else: + # decrypt api_key + api_key = encrypter.decrypt_token( + tenant_id=tenant_id, + token=credentials.get('config').get('api_key') + ) + firecrawl_app = FirecrawlApp(api_key=api_key, + base_url=credentials.get('config').get('base_url', None)) + result = firecrawl_app.check_crawl_status(job_id) + if result.get('status') != 'completed': + raise ValueError('Crawl job is not completed') + data = 
result.get('data') + if data: + for item in data: + if item.get('source_url') == url: + return item + return None + else: + raise ValueError('Invalid provider') + + @classmethod + def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict | None: + credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, + 'website', + provider) + if provider == 'firecrawl': + # decrypt api_key + api_key = encrypter.decrypt_token( + tenant_id=tenant_id, + token=credentials.get('config').get('api_key') + ) + firecrawl_app = FirecrawlApp(api_key=api_key, + base_url=credentials.get('config').get('base_url', None)) + params = { + 'pageOptions': { + 'onlyMainContent': only_main_content, + "includeHtml": False + } + } + result = firecrawl_app.scrape_url(url, params) + return result + else: + raise ValueError('Invalid provider') diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index c35c18799a603..4cced36ecdd85 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -11,7 +11,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment -from models.source import DataSourceBinding +from models.source import DataSourceOauthBinding @shared_task(queue='dataset') @@ -43,12 +43,12 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): page_id = data_source_info['notion_page_id'] page_type = data_source_info['type'] page_edited_time = data_source_info['last_edited_time'] - data_source_binding = DataSourceBinding.query.filter( + data_source_binding = DataSourceOauthBinding.query.filter( db.and_( - DataSourceBinding.tenant_id == document.tenant_id, - DataSourceBinding.provider == 'notion', - DataSourceBinding.disabled == False, - DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"' + DataSourceOauthBinding.tenant_id == document.tenant_id, + DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.disabled == False, + DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"' ) ).first() if not data_source_binding: diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py new file mode 100644 index 0000000000000..320da8718a12c --- /dev/null +++ b/api/tasks/sync_website_document_indexing_task.py @@ -0,0 +1,90 @@ +import datetime +import logging +import time + +import click +from celery import shared_task + +from core.indexing_runner import IndexingRunner +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import Dataset, Document, DocumentSegment +from services.feature_service import FeatureService + + +@shared_task(queue='dataset') +def sync_website_document_indexing_task(dataset_id: str, document_id: str): + """ + Async process document + :param dataset_id: + :param document_id: + + Usage: sunc_website_document_indexing_task.delay(dataset_id, document_id) + """ + start_at = time.perf_counter() + + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + + sync_indexing_cache_key = 'document_{}_is_sync'.format(document_id) + # check document limit + features = FeatureService.get_features(dataset.tenant_id) + try: + if features.billing.enabled: + vector_space = features.vector_space 
diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py
index c35c18799a603..4cced36ecdd85 100644
--- a/api/tasks/document_indexing_sync_task.py
+++ b/api/tasks/document_indexing_sync_task.py
@@ -11,7 +11,7 @@
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from extensions.ext_database import db
 from models.dataset import Dataset, Document, DocumentSegment
-from models.source import DataSourceBinding
+from models.source import DataSourceOauthBinding


 @shared_task(queue='dataset')
@@ -43,12 +43,12 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
         page_id = data_source_info['notion_page_id']
         page_type = data_source_info['type']
         page_edited_time = data_source_info['last_edited_time']
-        data_source_binding = DataSourceBinding.query.filter(
+        data_source_binding = DataSourceOauthBinding.query.filter(
             db.and_(
-                DataSourceBinding.tenant_id == document.tenant_id,
-                DataSourceBinding.provider == 'notion',
-                DataSourceBinding.disabled == False,
-                DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
+                DataSourceOauthBinding.tenant_id == document.tenant_id,
+                DataSourceOauthBinding.provider == 'notion',
+                DataSourceOauthBinding.disabled == False,
+                DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
             )
         ).first()
         if not data_source_binding:
diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py
new file mode 100644
index 0000000000000..320da8718a12c
--- /dev/null
+++ b/api/tasks/sync_website_document_indexing_task.py
@@ -0,0 +1,90 @@
+import datetime
+import logging
+import time
+
+import click
+from celery import shared_task
+
+from core.indexing_runner import IndexingRunner
+from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
+from extensions.ext_database import db
+from extensions.ext_redis import redis_client
+from models.dataset import Dataset, Document, DocumentSegment
+from services.feature_service import FeatureService
+
+
+@shared_task(queue='dataset')
+def sync_website_document_indexing_task(dataset_id: str, document_id: str):
+    """
+    Async process document
+    :param dataset_id:
+    :param document_id:
+
+    Usage: sync_website_document_indexing_task.delay(dataset_id, document_id)
+    """
+    start_at = time.perf_counter()
+
+    dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+
+    sync_indexing_cache_key = 'document_{}_is_sync'.format(document_id)
+    # check document limit
+    features = FeatureService.get_features(dataset.tenant_id)
+    try:
+        if features.billing.enabled:
+            vector_space = features.vector_space
+            if 0 < vector_space.limit <= vector_space.size:
+                raise ValueError("Your total number of documents plus the number of uploads has exceeded the limit "
+                                 "of your subscription.")
+    except Exception as e:
+        document = db.session.query(Document).filter(
+            Document.id == document_id,
+            Document.dataset_id == dataset_id
+        ).first()
+        if document:
+            document.indexing_status = 'error'
+            document.error = str(e)
+            document.stopped_at = datetime.datetime.utcnow()
+            db.session.add(document)
+            db.session.commit()
+        redis_client.delete(sync_indexing_cache_key)
+        return
+
+    logging.info(click.style('Start sync website document: {}'.format(document_id), fg='green'))
+    document = db.session.query(Document).filter(
+        Document.id == document_id,
+        Document.dataset_id == dataset_id
+    ).first()
+    try:
+        if document:
+            # clean old data
+            index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
+
+            segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
+            if segments:
+                index_node_ids = [segment.index_node_id for segment in segments]
+                # delete from vector index
+                index_processor.clean(dataset, index_node_ids)
+
+                for segment in segments:
+                    db.session.delete(segment)
+                db.session.commit()
+
+            document.indexing_status = 'parsing'
+            document.processing_started_at = datetime.datetime.utcnow()
+            db.session.add(document)
+            db.session.commit()
+
+            indexing_runner = IndexingRunner()
+            indexing_runner.run([document])
+            redis_client.delete(sync_indexing_cache_key)
+    except Exception as ex:
+        document.indexing_status = 'error'
+        document.error = str(ex)
+        document.stopped_at = datetime.datetime.utcnow()
+        db.session.add(document)
+        db.session.commit()
+        logging.info(click.style(str(ex), fg='yellow'))
+        redis_client.delete(sync_indexing_cache_key)
+        pass
+    end_at = time.perf_counter()
+    logging.info(click.style('Sync document: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))
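A hypothetical dispatch sketch for the new Celery task (not taken from this PR): the .delay(dataset_id, document_id) call mirrors the usage note in the task's docstring, while the wrapper function, the cached value, and the 600-second TTL are illustrative assumptions.

    from extensions.ext_redis import redis_client
    from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task

    def trigger_website_document_sync(dataset_id: str, document_id: str) -> None:
        # The task deletes this key on success or failure, so setting it here lets
        # callers detect an in-flight sync; the key name matches the task's
        # sync_indexing_cache_key, the TTL is only an assumed safety net.
        redis_client.setex('document_{}_is_sync'.format(document_id), 600, 1)
        sync_website_document_indexing_task.delay(dataset_id, document_id)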
diff --git a/api/tests/unit_tests/core/rag/extractor/firecrawl/__init__.py b/api/tests/unit_tests/core/rag/extractor/firecrawl/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py
new file mode 100644
index 0000000000000..a8bba11e16db1
--- /dev/null
+++ b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py
@@ -0,0 +1,33 @@
+import os
+from unittest import mock
+
+from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
+from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor
+from core.rag.models.document import Document
+from tests.unit_tests.core.rag.extractor.test_notion_extractor import _mock_response
+
+
+def test_firecrawl_web_extractor_crawl_mode(mocker):
+    url = "https://firecrawl.dev"
+    api_key = os.getenv('FIRECRAWL_API_KEY') or 'fc-'
+    base_url = 'https://api.firecrawl.dev'
+    firecrawl_app = FirecrawlApp(api_key=api_key,
+                                 base_url=base_url)
+    params = {
+        'crawlerOptions': {
+            "includes": [],
+            "excludes": [],
+            "generateImgAltText": True,
+            "maxDepth": 1,
+            "limit": 1,
+            'returnOnlyUrls': False,
+        }
+    }
+    mocked_firecrawl = {
+        "jobId": "test",
+    }
+    mocker.patch("requests.post", return_value=_mock_response(mocked_firecrawl))
+    job_id = firecrawl_app.crawl_url(url, params)
+    print(job_id)
+    assert isinstance(job_id, str)
diff --git a/api/tests/unit_tests/oss/__init__.py b/api/tests/unit_tests/oss/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/api/tests/unit_tests/oss/local/__init__.py b/api/tests/unit_tests/oss/local/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
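For context on the crawl-mode test above: _mock_response, reused from the Notion extractor tests, is expected to behave like a requests.Response whose json() returns the supplied payload. A hypothetical stand-in (an illustrative reimplementation, not the actual shared fixture) would look roughly like:

    from unittest import mock

    def _mock_response(data, status_code=200):
        # Build a requests.Response-like mock: json() yields the given payload and
        # status_code defaults to 200 so the success path of crawl_url is exercised.
        response = mock.Mock()
        response.json.return_value = data
        response.status_code = status_code
        return response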