diff --git a/.github/workflows/push-server.yml b/.github/workflows/push-server.yml index 61ec5152dc..92b22a69c7 100644 --- a/.github/workflows/push-server.yml +++ b/.github/workflows/push-server.yml @@ -450,3 +450,147 @@ jobs: push: true tags: ${{ steps.meta.outputs.tags }} labels: ${{ steps.meta.outputs.labels }} + + word-id-cronjob: + name: Push Word ID cronjob + runs-on: ${{ matrix.runner }} + strategy: + matrix: + runner: [blacksmith-8vcpu-ubuntu-2204] + platform: [linux/amd64] + exclude: + - runner: blacksmith-8vcpu-ubuntu-2204 + platform: linux/arm64 + - runner: blacksmith-8vcpu-ubuntu-2204-arm + platform: linux/amd64 + steps: + - name: Checkout the repo + uses: actions/checkout@v4 + + - name: Setup buildx + uses: docker/setup-buildx-action@v3 + + - name: Login to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_PASSWORD }} + + - name: Docker meta + id: meta + uses: docker/metadata-action@v5 + with: + # list of Docker images to use as base name for tags + images: | + trieve/word-id-cronjob + tags: | + type=raw,latest + type=sha + + - name: Build and push Docker image + uses: docker/build-push-action@v5 + with: + platforms: ${{ matrix.platform }} + cache-from: type=registry,ref=trieve/buildcache:word-id-cronjob-${{matrix.runner}} + cache-to: type=registry,ref=trieve/buildcache:word-id-cronjob-${{matrix.runner}},mode=max + context: server/ + file: ./server/Dockerfile.word-id-cronjob + push: true + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + + word-worker: + name: Push Word Worker + runs-on: ${{ matrix.runner }} + strategy: + matrix: + runner: [blacksmith-8vcpu-ubuntu-2204] + platform: [linux/amd64] + exclude: + - runner: blacksmith-8vcpu-ubuntu-2204 + platform: linux/arm64 + - runner: blacksmith-8vcpu-ubuntu-2204-arm + platform: linux/amd64 + steps: + - name: Checkout the repo + uses: actions/checkout@v4 + + - name: Setup buildx + uses: docker/setup-buildx-action@v3 + + - name: Login to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_PASSWORD }} + + - name: Docker meta + id: meta + uses: docker/metadata-action@v5 + with: + # list of Docker images to use as base name for tags + images: | + trieve/word-worker + tags: | + type=raw,latest + type=sha + + - name: Build and push Docker image + uses: docker/build-push-action@v5 + with: + platforms: ${{ matrix.platform }} + cache-from: type=registry,ref=trieve/buildcache:word-worker-${{matrix.runner}} + cache-to: type=registry,ref=trieve/buildcache:word-worker-${{matrix.runner}},mode=max + context: server/ + file: ./server/Dockerfile.word-worker + push: true + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + + bktree-worker: + name: Push BK-Tree Worker + runs-on: ${{ matrix.runner }} + strategy: + matrix: + runner: [blacksmith-8vcpu-ubuntu-2204] + platform: [linux/amd64] + exclude: + - runner: blacksmith-8vcpu-ubuntu-2204 + platform: linux/arm64 + - runner: blacksmith-8vcpu-ubuntu-2204-arm + platform: linux/amd64 + steps: + - name: Checkout the repo + uses: actions/checkout@v4 + + - name: Setup buildx + uses: docker/setup-buildx-action@v3 + + - name: Login to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_PASSWORD }} + + - name: Docker meta + id: meta + uses: docker/metadata-action@v5 + with: + # list of Docker images to use as base name for tags + 
images: | + trieve/bktree-worker + tags: | + type=raw,latest + type=sha + + - name: Build and push Docker image + uses: docker/build-push-action@v5 + with: + platforms: ${{ matrix.platform }} + cache-from: type=registry,ref=trieve/buildcache:bktree-worker-${{matrix.runner}} + cache-to: type=registry,ref=trieve/buildcache:bktree-worker-${{matrix.runner}},mode=max + context: server/ + file: ./server/Dockerfile.bktree-worker + push: true + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} diff --git a/frontends/search/src/components/GroupPage.tsx b/frontends/search/src/components/GroupPage.tsx index 089f976515..e254b8b447 100644 --- a/frontends/search/src/components/GroupPage.tsx +++ b/frontends/search/src/components/GroupPage.tsx @@ -263,6 +263,18 @@ export const GroupPage = (props: GroupPageProps) => { slim_chunks: search.debounced.slimChunks, page_size: search.debounced.pageSize, get_total_pages: search.debounced.getTotalPages, + typo_options: { + correct_typos: search.debounced.correctTypos, + one_typo_word_range: { + min: search.debounced.oneTypoWordRangeMin, + max: search.debounced.oneTypoWordRangeMax, + }, + two_typo_word_range: { + min: search.debounced.twoTypoWordRangeMin, + max: search.debounced.twoTypoWordRangeMax, + }, + disable_on_words: search.debounced.disableOnWords, + }, highlight_options: { highlight_results: search.debounced.highlightResults, highlight_strategy: search.debounced.highlightStrategy, diff --git a/frontends/search/src/components/ResultsPage.tsx b/frontends/search/src/components/ResultsPage.tsx index 5b4f3b166d..438ae3d098 100644 --- a/frontends/search/src/components/ResultsPage.tsx +++ b/frontends/search/src/components/ResultsPage.tsx @@ -282,6 +282,18 @@ const ResultsPage = (props: ResultsPageProps) => { slim_chunks: props.search.debounced.slimChunks ?? false, page_size: props.search.debounced.pageSize ?? 10, get_total_pages: props.search.debounced.getTotalPages ?? false, + typo_options: { + correct_typos: props.search.debounced.correctTypos, + one_typo_word_range: { + min: props.search.debounced.oneTypoWordRangeMin, + max: props.search.debounced.oneTypoWordRangeMax, + }, + two_typo_word_range: { + min: props.search.debounced.twoTypoWordRangeMin, + max: props.search.debounced.twoTypoWordRangeMax, + }, + disable_on_words: props.search.debounced.disableOnWords, + }, highlight_options: { highlight_results: props.search.debounced.highlightResults ?? true, highlight_strategy: diff --git a/frontends/search/src/components/SearchForm.tsx b/frontends/search/src/components/SearchForm.tsx index 25be35cf10..989c60952d 100644 --- a/frontends/search/src/components/SearchForm.tsx +++ b/frontends/search/src/components/SearchForm.tsx @@ -1051,6 +1051,13 @@ const SearchForm = (props: { pageSize: 10, getTotalPages: false, highlightStrategy: "exactmatch", + correctTypos: false, + oneTypoWordRangeMin: 5, + oneTypoWordRangeMax: 8, + twoTypoWordRangeMin: 8, + twoTypoWordRangeMax: null, + disableOnWords: [], + typoTolerance: false, highlightResults: true, highlightDelimiters: ["?", ".", "!"], highlightMaxLength: 8, @@ -1195,7 +1202,7 @@ const SearchForm = (props: { />
- +
+
+ + { + setTempSearchValues((prev) => { + return { + ...prev, + correctTypos: e.target.checked, + }; + }); + }} + /> +
+
+ + { + setTempSearchValues((prev) => { + return { + ...prev, + oneTypoWordRangeMin: parseInt( + e.currentTarget.value, + ), + }; + }); + }} + /> +
+
+ + { + setTempSearchValues((prev) => { + return { + ...prev, + oneTypoWordRangeMax: + e.currentTarget.value === "" + ? null + : parseInt(e.currentTarget.value), + }; + }); + }} + /> +
+
+ + { + setTempSearchValues((prev) => { + return { + ...prev, + twoTypoWordRangeMin: parseInt( + e.currentTarget.value, + ), + }; + }); + }} + /> +
+
+ + { + setTempSearchValues((prev) => { + return { + ...prev, + twoTypoWordRangeMax: + e.currentTarget.value === "" + ? null + : parseInt(e.currentTarget.value), + }; + }); + }} + /> +
+
+ + { + if (e.currentTarget.value === " ") { + setTempSearchValues((prev) => { + return { + ...prev, + disableOnWords: [" "], + }; + }); + return; + } + + setTempSearchValues((prev) => { + return { + ...prev, + disableOnWords: + e.currentTarget.value.split(","), + }; + }); + }} + /> +
{ sort_by: JSON.stringify(state.sort_by), pageSize: state.pageSize.toString(), getTotalPages: state.getTotalPages.toString(), + correctTypos: state.correctTypos.toString(), + oneTypoWordRangeMin: state.oneTypoWordRangeMin.toString(), + oneTypoWordRangeMax: state.oneTypoWordRangeMax?.toString() ?? "8", + twoTypoWordRangeMin: state.twoTypoWordRangeMin.toString(), + twoTypoWordRangeMax: state.twoTypoWordRangeMax?.toString() ?? "", + disableOnWords: state.disableOnWords.join(","), highlightStrategy: state.highlightStrategy, highlightResults: state.highlightResults.toString(), highlightThreshold: state.highlightThreshold.toString(), @@ -121,6 +139,13 @@ const fromStateToParams = (state: SearchOptions): Params => { }; }; +const parseIntOrNull = (str: string | undefined) => { + if (!str || str === "") { + return null; + } + return parseInt(str); +}; + const fromParamsToState = ( params: Partial, ): Omit => { @@ -136,6 +161,12 @@ const fromParamsToState = ( initalState.sort_by, pageSize: parseInt(params.pageSize ?? "10"), getTotalPages: (params.getTotalPages ?? "false") === "true", + correctTypos: (params.correctTypos ?? "false") === "true", + oneTypoWordRangeMin: parseInt(params.oneTypoWordRangeMin ?? "5"), + oneTypoWordRangeMax: parseIntOrNull(params.oneTypoWordRangeMax), + twoTypoWordRangeMin: parseInt(params.oneTypoWordRangeMin ?? "8"), + twoTypoWordRangeMax: parseIntOrNull(params.twoTypoWordRangeMax), + disableOnWords: params.disableOnWords?.split(",") ?? [], highlightResults: (params.highlightResults ?? "true") === "true", highlightStrategy: isHighlightStrategy(params.highlightStrategy) ? params.highlightStrategy diff --git a/helm/local-values.yaml b/helm/local-values.yaml index 0afdea0ede..ff825a1973 100644 --- a/helm/local-values.yaml +++ b/helm/local-values.yaml @@ -22,6 +22,10 @@ containers: tag: latest sync_qdrant: tag: latest + bktree_worker: + tag: latest + word_worker: + tag: latest search: tag: latest chat: diff --git a/helm/templates/bktree-worker-deployment.yaml b/helm/templates/bktree-worker-deployment.yaml new file mode 100644 index 0000000000..a4c5e8c53c --- /dev/null +++ b/helm/templates/bktree-worker-deployment.yaml @@ -0,0 +1,114 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: bktree-worker + labels: + app.kubernetes.io/name: bktree-worker + app.kubernetes.io/instance: {{ .Release.Name }} +spec: + selector: + matchLabels: + app.kubernetes.io/name: bktree-worker + app.kubernetes.io/instance: {{ .Release.Name }} + template: + metadata: + labels: + app.kubernetes.io/name: bktree-worker + app.kubernetes.io/instance: {{ .Release.Name }} + spec: + containers: + - name: bktree-worker + image: {{ printf "%s%s:%s" (ternary "trieve/" "localhost:5001/" (ne .Values.environment "local")) "bktree-worker" .Values.containers.bktree_worker.tag }} + env: + - name: ADMIN_API_KEY + value: {{ .Values.config.trieve.adminApiKey }} + - name: BASE_SERVER_URL + value: {{ .Values.config.trieve.baseServerUrl }} + - name: REDIS_URL + value: {{ ( ternary "redis://:redis@trieve-redis-master.default.svc.cluster.local:6379" .Values.config.redis.uri (eq .Values.config.redis.useSubchart true)) }} + - name: QDRANT_URL + value: {{ ( ternary "http://trieve-qdrant.default.svc.cluster.local:6334" .Values.config.qdrant.qdrantUrl (eq .Values.config.qdrant.useSubchart true)) }} + - name: QDRANT_API_KEY + value: {{ .Values.config.qdrant.apiKey }} + - name: QUANTIZE_VECTORS + value: {{ .Values.config.qdrant.quantizeVectors | quote }} + - name: REPLICATION_FACTOR + value: {{ 
.Values.config.qdrant.replicationFactor | quote }} + - name: DATABASE_URL + value: {{ ( ternary "postgres://postgres:password@trieve-postgresql.default.svc.cluster.local:5432/trieve" .Values.postgres.dbURI (eq .Values.postgres.useSubchart true)) }} + - name: SMTP_RELAY + value: {{ .Values.config.smtp.relay }} + - name: SMTP_USERNAME + value: {{ .Values.config.smtp.username }} + - name: SMTP_PASSWORD + value: {{ .Values.config.smtp.password }} + - name: SMTP_EMAIL_ADDRESS + value: {{ .Values.config.smtp.emailAddress }} + - name: OPENAI_API_KEY + value: {{ .Values.config.openai.apiKey }} + - name: LLM_API_KEY + value: {{ .Values.config.llm.apiKey }} + - name: SECRET_KEY + value: {{ .Values.config.trieve.secretKey | quote }} + - name: SALT + value: {{ .Values.config.trieve.salt }} + - name: S3_ENDPOINT + value: {{ .Values.config.s3.endpoint }} + - name: S3_ACCESS_KEY + value: {{ .Values.config.s3.accessKey }} + - name: S3_SECRET_KEY + value: {{ .Values.config.s3.secretKey }} + - name: S3_BUCKET + value: {{ .Values.config.s3.bucket }} + - name: COOKIE_SECURE + value: {{ .Values.config.trieve.cookieSecure | quote }} + - name: QDRANT_COLLECTION + value: {{ .Values.config.qdrant.collection }} + - name: TIKA_URL + value: http://tika.default.svc.cluster.local:9998 + - name: OPENAI_BASE_URL + value: {{ .Values.config.openai.baseUrl }} + - name: STRIPE_SECRET + value: {{ .Values.config.stripe.secret }} + - name: STRIPE_WEBHOOK_SECRET + value: {{ .Values.config.stripe.webhookSecret }} + - name: ADMIN_DASHBOARD_URL + value: {{ .Values.config.trieve.adminDashboardUrl }} + - name: OIDC_CLIENT_SECRET + value: {{ .Values.config.oidc.clientSecret }} + - name: OIDC_CLIENT_ID + value: {{ .Values.config.oidc.clientId }} + - name: OIDC_AUTH_REDIRECT_URL + value: {{ .Values.config.oidc.authRedirectUrl }} + - name: OIDC_ISSUER_URL + value: {{ .Values.config.oidc.issuerUrl }} + - name: GPU_SERVER_ORIGIN + value: {{ .Values.config.trieve.gpuServerOrigin }} + - name: SPARSE_SERVER_QUERY_ORIGIN + value: {{ .Values.config.trieve.sparseServerQueryOrigin }} + - name: SPARSE_SERVER_DOC_ORIGIN + value: {{ .Values.config.trieve.sparseServerDocOrigin }} + - name: SPARSE_SERVER_ORIGIN + value: {{ .Values.config.trieve.sparseServerOrigin }} + - name: EMBEDDING_SERVER_ORIGIN + value: {{ .Values.config.trieve.embeddingServerOrigin }} + - name: EMBEDDING_SERVER_ORIGIN_BGEM3 + value: {{ .Values.config.trieve.embeddingServerOriginBGEM3 }} + - name: RERANKER_SERVER_ORIGIN + value: {{ .Values.config.trieve.rerankerServerOrigin }} + - name: UNLIMITED + value: {{ .Values.config.trieve.unlimited | quote }} + - name: REDIS_CONNECTIONS + value: "2" + - name: AWS_REGION + value: {{ .Values.config.s3.region }} + - name: CLICKHOUSE_URL + value: {{ .Values.config.analytics.clickhouseUrl | quote }} + - name: CLICKHOUSE_DB + value: {{ .Values.config.analytics.clickhouseDB | quote }} + - name: CLICKHOUSE_USER + value: {{ .Values.config.analytics.clickhouseUser | quote }} + - name: CLICKHOUSE_PASSWORD + value: {{ .Values.config.analytics.clickhousePassword | quote }} + - name: USE_ANALYTICS + value: {{ .Values.config.analytics.enabled | quote }} diff --git a/helm/templates/wordworker-deployment.yaml b/helm/templates/wordworker-deployment.yaml new file mode 100644 index 0000000000..563be23c9f --- /dev/null +++ b/helm/templates/wordworker-deployment.yaml @@ -0,0 +1,146 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: word-worker + labels: + app.kubernetes.io/name: word-worker + app.kubernetes.io/instance: {{ .Release.Name }} +spec: + 
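+  # replicas is not set, so Kubernetes defaults the word-worker Deployment to a single replica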
selector: + matchLabels: + app.kubernetes.io/name: word-worker + app.kubernetes.io/instance: {{ .Release.Name }} + template: + metadata: + labels: + app.kubernetes.io/name: word-worker + app.kubernetes.io/instance: {{ .Release.Name }} + spec: + serviceAccountName: cloud-postgres-service-account + containers: + - name: word-worker + image: {{ printf "%s%s:%s" (ternary "trieve/" "localhost:5001/" (ne .Values.environment "local")) "word-worker" .Values.containers.word_worker.tag }} + env: + - name: ADMIN_API_KEY + value: {{ .Values.config.trieve.adminApiKey }} + - name: BASE_SERVER_URL + value: {{ .Values.config.trieve.baseServerUrl }} + - name: REDIS_URL + value: {{ ( ternary "redis://:redis@trieve-redis-master.default.svc.cluster.local:6379" .Values.config.redis.uri (eq .Values.config.redis.useSubchart true)) }} + - name: QDRANT_URL + value: {{ ( ternary "http://trieve-qdrant.default.svc.cluster.local:6334" .Values.config.qdrant.qdrantUrl (eq .Values.config.qdrant.useSubchart true)) }} + - name: QDRANT_API_KEY + value: {{ .Values.config.qdrant.apiKey }} + - name: QUANTIZE_VECTORS + value: {{ .Values.config.qdrant.quantizeVectors | quote }} + - name: REPLICATION_FACTOR + value: {{ .Values.config.qdrant.replicationFactor | quote }} + - name: DATABASE_URL + value: {{ ( ternary "postgres://postgres:password@trieve-postgresql.default.svc.cluster.local:5432/trieve" .Values.postgres.dbURI (eq .Values.postgres.useSubchart true)) }} + - name: SMTP_RELAY + value: {{ .Values.config.smtp.relay }} + - name: SMTP_USERNAME + value: {{ .Values.config.smtp.username }} + - name: SMTP_PASSWORD + value: {{ .Values.config.smtp.password }} + - name: SMTP_EMAIL_ADDRESS + value: {{ .Values.config.smtp.emailAddress }} + - name: OPENAI_API_KEY + value: {{ .Values.config.openai.apiKey }} + - name: LLM_API_KEY + value: {{ .Values.config.llm.apiKey }} + - name: SECRET_KEY + value: {{ .Values.config.trieve.secretKey | quote }} + - name: SALT + value: {{ .Values.config.trieve.salt }} + - name: S3_ENDPOINT + value: {{ .Values.config.s3.endpoint }} + - name: S3_ACCESS_KEY + value: {{ .Values.config.s3.accessKey }} + - name: S3_SECRET_KEY + value: {{ .Values.config.s3.secretKey }} + - name: S3_BUCKET + value: {{ .Values.config.s3.bucket }} + - name: COOKIE_SECURE + value: {{ .Values.config.trieve.cookieSecure | quote }} + - name: QDRANT_COLLECTION + value: {{ .Values.config.qdrant.collection }} + - name: TIKA_URL + value: http://tika.default.svc.cluster.local:9998 + - name: OPENAI_BASE_URL + value: {{ .Values.config.openai.baseUrl }} + - name: STRIPE_SECRET + value: {{ .Values.config.stripe.secret }} + - name: STRIPE_WEBHOOK_SECRET + value: {{ .Values.config.stripe.webhookSecret }} + - name: ADMIN_DASHBOARD_URL + value: {{ .Values.config.trieve.adminDashboardUrl }} + - name: OIDC_CLIENT_SECRET + value: {{ .Values.config.oidc.clientSecret }} + - name: OIDC_CLIENT_ID + value: {{ .Values.config.oidc.clientId }} + - name: OIDC_AUTH_REDIRECT_URL + value: {{ .Values.config.oidc.authRedirectUrl }} + - name: OIDC_ISSUER_URL + value: {{ .Values.config.oidc.issuerUrl }} + - name: GPU_SERVER_ORIGIN + value: {{ .Values.config.trieve.gpuServerOrigin }} + - name: SPARSE_SERVER_QUERY_ORIGIN + value: {{ .Values.config.trieve.sparseServerQueryOrigin }} + - name: SPARSE_SERVER_DOC_ORIGIN + value: {{ .Values.config.trieve.sparseServerDocOrigin }} + - name: SPARSE_SERVER_ORIGIN + value: {{ .Values.config.trieve.sparseServerOrigin }} + - name: EMBEDDING_SERVER_ORIGIN + value: {{ .Values.config.trieve.embeddingServerOrigin }} + - name: 
EMBEDDING_SERVER_ORIGIN_BGEM3 + value: {{ .Values.config.trieve.embeddingServerOriginBGEM3 }} + - name: RERANKER_SERVER_ORIGIN + value: {{ .Values.config.trieve.rerankerServerOrigin }} + - name: UNLIMITED + value: {{ .Values.config.trieve.unlimited | quote }} + - name: REDIS_CONNECTIONS + value: "2" + - name: AWS_REGION + value: {{ .Values.config.s3.region }} + - name: CLICKHOUSE_URL + value: {{ .Values.config.analytics.clickhouseUrl | quote }} + - name: CLICKHOUSE_DB + value: {{ .Values.config.analytics.clickhouseDB | quote }} + - name: CLICKHOUSE_USER + value: {{ .Values.config.analytics.clickhouseUser | quote }} + - name: CLICKHOUSE_PASSWORD + value: {{ .Values.config.analytics.clickhousePassword | quote }} + - name: USE_ANALYTICS + value: {{ .Values.config.analytics.enabled | quote }} + {{- if eq $.Values.environment "gcloud" }} + - name: cloud-sql-proxy + # It is recommended to use the latest version of the Cloud SQL Auth Proxy + # Make sure to update on a regular schedule! + image: gcr.io/cloud-sql-connectors/cloud-sql-proxy:2.8.0 + args: + - "--structured-logs" + - "--auto-iam-authn" + # Replace DB_PORT with the port the proxy should listen on + - "--port=5432" + - "studious-lore-405302:us-west1:trieve-cloud" + securityContext: + # The default Cloud SQL Auth Proxy image runs as the + # "nonroot" user and group (uid: 65532) by default. + runAsNonRoot: true + # You should use resource requests/limits as a best practice to prevent + # pods from consuming too many resources and affecting the execution of + # other pods. You should adjust the following values based on what your + # application needs. For details, see + # https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/ + resources: + requests: + # The proxy's memory use scales linearly with the number of active + # connections. Fewer open connections will use less memory. Adjust + # this value based on your application's requirements. + memory: "2Gi" + # The proxy's CPU use scales linearly with the amount of IO between + # the database and the application. Adjust this value based on your + # application's requirements. 
+ cpu: "1" + {{- end }} diff --git a/helm/values.yaml.tpl b/helm/values.yaml.tpl index c4e08e1cce..233c1764e7 100644 --- a/helm/values.yaml.tpl +++ b/helm/values.yaml.tpl @@ -23,6 +23,10 @@ containers: tag: latest sync_qdrant: tag: latest + bktree_worker: + tag: latest + word_worker: + tag: latest search: tag: latest chat: diff --git a/scripts/reset-bktree.sh b/scripts/reset-bktree.sh new file mode 100644 index 0000000000..96fd743900 --- /dev/null +++ b/scripts/reset-bktree.sh @@ -0,0 +1,61 @@ +#!/bin/bash + +# Default connection details +REDIS_HOST="localhost" +REDIS_PORT="6379" +REDIS_PASSWORD="" +CLICKHOUSE_HOST="localhost" +CLICKHOUSE_PORT="8123" +CLICKHOUSE_USER="default" +CLICKHOUSE_PASSWORD="password" +CLICKHOUSE_DB="default" + +# Function to print usage +usage() { + echo "Usage: $0 -d <dataset_id> [-rh <redis_host>] [-rp <redis_port>] [-rw <redis_password>] [-ch <clickhouse_host>] [-cp <clickhouse_port>] [-cu <clickhouse_user>] [-cw <clickhouse_password>] [-cd <clickhouse_db>]" + exit 1 +} + +# Parse command line arguments (the flags are multi-character, so parse them manually instead of with getopts) +while [ $# -gt 0 ]; do + case "$1" in + -d) DATASET_ID="$2"; shift 2 ;; + -rh) REDIS_HOST="$2"; shift 2 ;; + -rp) REDIS_PORT="$2"; shift 2 ;; + -rw) REDIS_PASSWORD="$2"; shift 2 ;; + -ch) CLICKHOUSE_HOST="$2"; shift 2 ;; + -cp) CLICKHOUSE_PORT="$2"; shift 2 ;; + -cu) CLICKHOUSE_USER="$2"; shift 2 ;; + -cw) CLICKHOUSE_PASSWORD="$2"; shift 2 ;; + -cd) CLICKHOUSE_DB="$2"; shift 2 ;; + *) usage ;; + esac +done + +# Check if dataset_id is provided +if [ -z "$DATASET_ID" ]; then + echo "Error: dataset_id is required" + usage +fi + +# Construct Redis CLI command +REDIS_CMD="redis-cli -h $REDIS_HOST -p $REDIS_PORT" +if [ -n "$REDIS_PASSWORD" ]; then + REDIS_CMD="$REDIS_CMD -a $REDIS_PASSWORD" +fi + +# Delete matching keys from Redis (DEL does not expand patterns, so scan for them first) +echo "Deleting keys matching *$DATASET_ID from Redis..." +$REDIS_CMD --scan --pattern "*$DATASET_ID" | xargs -r $REDIS_CMD DEL + +# Delete row from ClickHouse +echo "Deleting row with dataset_id=$DATASET_ID from ClickHouse..." 
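+# Requires the clickhouse-client CLI; ALTER TABLE ... DELETE runs as an asynchronous mutation in ClickHouse.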
+clickhouse-client \ + --host "$CLICKHOUSE_HOST" \ + --port "$CLICKHOUSE_PORT" \ + --user "$CLICKHOUSE_USER" \ + --password "$CLICKHOUSE_PASSWORD" \ + --database "$CLICKHOUSE_DB" \ + --query "ALTER TABLE dataset_words_last_processed DELETE WHERE dataset_id = '$DATASET_ID'" + +echo "Cleanup completed for dataset_id: $DATASET_ID" \ No newline at end of file diff --git a/server/Cargo.lock b/server/Cargo.lock index a2fb633736..1a61a1c36f 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -722,6 +722,15 @@ dependencies = [ "redis 0.25.4", ] +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + [[package]] name = "bitflags" version = "1.3.2" @@ -743,6 +752,15 @@ dependencies = [ "crunchy", ] +[[package]] +name = "bktree" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bb1e744816f6a3b9e962186091867f3e5959d4dac995777ec254631cb00b21c" +dependencies = [ + "num", +] + [[package]] name = "blake2b_simd" version = "1.0.2" @@ -1149,6 +1167,19 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1137cd7e7fc0fb5d3c5a8678be38ec56e819125d8d7907411fe24ccb943faca8" +dependencies = [ + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-epoch", + "crossbeam-queue", + "crossbeam-utils", +] + [[package]] name = "crossbeam-channel" version = "0.5.13" @@ -1177,6 +1208,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "crossbeam-queue" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df0346b5d5e76ac2fe4e327c5fd1118d6be7c51dfb18f9b7922923f287471e35" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.20" @@ -1737,9 +1777,9 @@ dependencies = [ [[package]] name = "flate2" -version = "1.0.30" +version = "1.0.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f54427cfd1c7829e2a139fcefea601bf088ebca651d2bf53ebc600eac295dae" +checksum = "7f211bbe8e69bbd0cfdea405084f128ae8b4aaa6b0b522fc8f2b009084797920" dependencies = [ "crc32fast", "miniz_oxide", @@ -2895,6 +2935,30 @@ dependencies = [ "winapi", ] +[[package]] +name = "num" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23" +dependencies = [ + "num-bigint", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", +] + +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + [[package]] name = "num-bigint-dig" version = "0.8.4" @@ -2947,6 +3011,17 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-rational" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +dependencies = [ + "num-bigint", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -5678,11 +5753,14 @@ dependencies = [ "async-stripe", "base64 0.22.1", "bb8-redis", + "bincode", + "bktree", "blake3", "cfg-if", "chm", "chrono", "clickhouse 0.12.0", + 
"crossbeam", "crossbeam-channel", "dateparser", "derive_more", @@ -5690,12 +5768,14 @@ dependencies = [ "diesel-async", "diesel_migrations", "dotenvy", + "flate2", "futures", "futures-util", "glob", "itertools 0.13.0", "lazy_static", "lettre", + "levenshtein_automata", "log", "murmur3", "ndarray", @@ -5708,6 +5788,7 @@ dependencies = [ "prometheus", "qdrant-client", "rand 0.8.5", + "rayon", "redis 0.25.4", "regex", "regex-split", @@ -5723,6 +5804,7 @@ dependencies = [ "signal-hook", "simple-server-timing-header", "simsearch", + "strsim 0.11.1", "tantivy", "time", "tokio", diff --git a/server/Cargo.toml b/server/Cargo.toml index ed093a13b6..6bab8d5790 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -17,10 +17,22 @@ panic = "abort" name = "trieve-server" path = "src/main.rs" +[[bin]] +name = "word-id-cronjob" +path = "src/bin/word-id-cronjob.rs" + [[bin]] name = "ingestion-worker" path = "src/bin/ingestion-worker.rs" +[[bin]] +name = "bktree-worker" +path = "src/bin/bktree-worker.rs" + +[[bin]] +name = "word-worker" +path = "src/bin/word-worker.rs" + [[bin]] name = "file-worker" path = "src/bin/file-worker.rs" @@ -106,7 +118,7 @@ reqwest = { version = "0.12.2", features = ["json"] } rand = "0.8.5" dotenvy = "0.15.7" simsearch = "0.2.4" -lazy_static = { version = "1.4.0" } +lazy_static = "1.4.0" actix-files = "0.6.2" utoipa = { version = "4.2", features = [ "actix_extras", @@ -148,6 +160,14 @@ prometheus = "0.13.4" chm = "0.1.16" murmur3 = "0.5.2" tantivy = "0.22.0" +strsim = "0.11.1" +levenshtein_automata = "0.2.1" +bktree = "1.0.1" +flate2 = "1.0.31" +bincode = "1.3" +rayon = "1.10.0" +crossbeam = "0.8.4" + [build-dependencies] dotenvy = "0.15.7" diff --git a/server/Dockerfile.bktree-worker b/server/Dockerfile.bktree-worker new file mode 100644 index 0000000000..007be040f5 --- /dev/null +++ b/server/Dockerfile.bktree-worker @@ -0,0 +1,28 @@ +FROM rust:1.80-slim-bookworm AS chef +# We only pay the installation cost once, +# it will be cached from the second build onwards +RUN apt-get update -y && apt-get -y install pkg-config libssl-dev libpq-dev g++ curl +RUN cargo install cargo-chef +WORKDIR app + +FROM chef AS planner +COPY . . +RUN cargo chef prepare --recipe-path recipe.json + +FROM chef AS builder +COPY --from=planner /app/recipe.json recipe.json +# Build dependencies - this is the caching Docker layer! +RUN cargo chef cook --release --recipe-path recipe.json --bin "bktree-worker" +# Build application +COPY . . +RUN cargo build --release --features "runtime-env" --bin "bktree-worker" + +FROM debian:bookworm-slim as runtime +RUN apt-get update -y && apt-get -y install pkg-config libssl-dev libpq-dev ca-certificates +WORKDIR /app +COPY ./migrations/ /app/migrations +COPY --from=builder /app/target/release/bktree-worker /app/bktree-worker + + +EXPOSE 8090 +ENTRYPOINT ["/app/bktree-worker"] diff --git a/server/Dockerfile.word-id-cronjob b/server/Dockerfile.word-id-cronjob new file mode 100644 index 0000000000..6cc82d7137 --- /dev/null +++ b/server/Dockerfile.word-id-cronjob @@ -0,0 +1,28 @@ +FROM rust:1.80-slim-bookworm AS chef +# We only pay the installation cost once, +# it will be cached from the second build onwards +RUN apt-get update -y && apt-get -y install pkg-config libssl-dev libpq-dev g++ curl +RUN cargo install cargo-chef +WORKDIR app + +FROM chef AS planner +COPY . . +RUN cargo chef prepare --recipe-path recipe.json + +FROM chef AS builder +COPY --from=planner /app/recipe.json recipe.json +# Build dependencies - this is the caching Docker layer! 
+RUN cargo chef cook --release --recipe-path recipe.json --bin "word-id-cronjob" +# Build application +COPY . . +RUN cargo build --release --features "runtime-env" --bin "word-id-cronjob" + +FROM debian:bookworm-slim as runtime +RUN apt-get update -y && apt-get -y install pkg-config libssl-dev libpq-dev ca-certificates +WORKDIR /app +COPY ./migrations/ /app/migrations +COPY --from=builder /app/target/release/word-id-cronjob /app/word-id-cronjob + + +EXPOSE 8090 +ENTRYPOINT ["/app/word-id-cronjob"] diff --git a/server/Dockerfile.word-worker b/server/Dockerfile.word-worker new file mode 100644 index 0000000000..454f1d1676 --- /dev/null +++ b/server/Dockerfile.word-worker @@ -0,0 +1,28 @@ +FROM rust:1.80-slim-bookworm AS chef +# We only pay the installation cost once, +# it will be cached from the second build onwards +RUN apt-get update -y && apt-get -y install pkg-config libssl-dev libpq-dev g++ curl +RUN cargo install cargo-chef +WORKDIR app + +FROM chef AS planner +COPY . . +RUN cargo chef prepare --recipe-path recipe.json + +FROM chef AS builder +COPY --from=planner /app/recipe.json recipe.json +# Build dependencies - this is the caching Docker layer! +RUN cargo chef cook --release --recipe-path recipe.json --bin "word-worker" +# Build application +COPY . . +RUN cargo build --release --features "runtime-env" --bin "word-worker" + +FROM debian:bookworm-slim as runtime +RUN apt-get update -y && apt-get -y install pkg-config libssl-dev libpq-dev ca-certificates +WORKDIR /app +COPY ./migrations/ /app/migrations +COPY --from=builder /app/target/release/word-worker /app/word-worker + + +EXPOSE 8090 +ENTRYPOINT ["/app/word-worker"] diff --git a/server/ch_migrations/1723258343_store_words_in_clickhouse/down.sql b/server/ch_migrations/1723258343_store_words_in_clickhouse/down.sql new file mode 100644 index 0000000000..8eb04f3b52 --- /dev/null +++ b/server/ch_migrations/1723258343_store_words_in_clickhouse/down.sql @@ -0,0 +1,2 @@ + +DROP TABLE IF EXISTS words_datasets; diff --git a/server/ch_migrations/1723258343_store_words_in_clickhouse/up.sql b/server/ch_migrations/1723258343_store_words_in_clickhouse/up.sql new file mode 100644 index 0000000000..28ac871785 --- /dev/null +++ b/server/ch_migrations/1723258343_store_words_in_clickhouse/up.sql @@ -0,0 +1,12 @@ +CREATE TABLE IF NOT EXISTS words_datasets ( + id UUID NOT NULL, + dataset_id UUID NOT NULL, + word String NOT NULL, + count Int32 NOT NULL, + created_at DateTime DEFAULT now() NOT NULL, + INDEX idx_created_at created_at TYPE minmax GRANULARITY 8192, + INDEX idx_id id TYPE minmax GRANULARITY 8192 +) ENGINE = SummingMergeTree(created_at) +ORDER BY (dataset_id, word) +PARTITION BY dataset_id; + diff --git a/server/ch_migrations/1723490007_create_last_processed_table/down.sql b/server/ch_migrations/1723490007_create_last_processed_table/down.sql new file mode 100644 index 0000000000..484495f3ac --- /dev/null +++ b/server/ch_migrations/1723490007_create_last_processed_table/down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS dataset_words_last_processed; diff --git a/server/ch_migrations/1723490007_create_last_processed_table/up.sql b/server/ch_migrations/1723490007_create_last_processed_table/up.sql new file mode 100644 index 0000000000..e2b23894dd --- /dev/null +++ b/server/ch_migrations/1723490007_create_last_processed_table/up.sql @@ -0,0 +1,6 @@ +CREATE TABLE IF NOT EXISTS dataset_words_last_processed ( + last_processed DateTime DEFAULT now() NOT NULL, + dataset_id UUID NOT NULL, +) ENGINE = ReplacingMergeTree(last_processed) +ORDER BY 
(dataset_id) +PARTITION BY dataset_id; diff --git a/server/migrations/2024-08-02-200541_word_dataset_table/down.sql b/server/migrations/2024-08-02-200541_word_dataset_table/down.sql new file mode 100644 index 0000000000..a8ab58383e --- /dev/null +++ b/server/migrations/2024-08-02-200541_word_dataset_table/down.sql @@ -0,0 +1,7 @@ +-- This file should undo anything in `up.sql` + +DROP TABLE IF EXISTS "words_datasets"; +DROP TABLE IF EXISTS "words_in_datasets"; +DROP TABLE IF EXISTS "dataset_words_last_processed"; + + diff --git a/server/migrations/2024-08-02-200541_word_dataset_table/up.sql b/server/migrations/2024-08-02-200541_word_dataset_table/up.sql new file mode 100644 index 0000000000..ab43d0ae6f --- /dev/null +++ b/server/migrations/2024-08-02-200541_word_dataset_table/up.sql @@ -0,0 +1,25 @@ +-- Your SQL goes here + +CREATE TABLE IF NOT EXISTS "words_in_datasets" ( + id UUID PRIMARY KEY, + word TEXT NOT NULL, + UNIQUE(word) +); + +CREATE TABLE IF NOT EXISTS "words_datasets" ( + id UUID PRIMARY KEY, + dataset_id UUID NOT NULL, + word_id UUID NOT NULL, + count INT NOT NULL, + UNIQUE(dataset_id, word_id), + FOREIGN KEY (dataset_id) REFERENCES "datasets"(id) ON DELETE CASCADE, + FOREIGN KEY (word_id) REFERENCES "words_in_datasets"(id) ON DELETE CASCADE +); + +CREATE TABLE IF NOT EXISTS "dataset_words_last_processed" ( + id UUID PRIMARY KEY, + last_processed TIMESTAMP NULL, + dataset_id UUID NOT NULL, + FOREIGN KEY (dataset_id) REFERENCES "datasets"(id) ON DELETE CASCADE, + UNIQUE(dataset_id) +); diff --git a/server/migrations/2024-08-09-013645_update_dataset_updated_at/down.sql b/server/migrations/2024-08-09-013645_update_dataset_updated_at/down.sql new file mode 100644 index 0000000000..2b44495d86 --- /dev/null +++ b/server/migrations/2024-08-09-013645_update_dataset_updated_at/down.sql @@ -0,0 +1,4 @@ +-- This file should undo anything in `up.sql` +-- Finally, let's drop the trigger and function +DROP TRIGGER IF EXISTS trigger_update_dataset_timestamp ON chunk_metadata; +DROP FUNCTION IF EXISTS update_dataset_timestamp(); \ No newline at end of file diff --git a/server/migrations/2024-08-09-013645_update_dataset_updated_at/up.sql b/server/migrations/2024-08-09-013645_update_dataset_updated_at/up.sql new file mode 100644 index 0000000000..c856109716 --- /dev/null +++ b/server/migrations/2024-08-09-013645_update_dataset_updated_at/up.sql @@ -0,0 +1,39 @@ +CREATE OR REPLACE FUNCTION update_chunk_metadata_counts() +RETURNS TRIGGER AS $$ +DECLARE + d_id UUID; + new_count INT; +BEGIN + SELECT dataset_id INTO d_id FROM modified WHERE dataset_id IS NOT NULL LIMIT 1; + IF d_id IS NULL THEN + RETURN NULL; + END IF; + SELECT COUNT(modified.id) INTO new_count FROM modified; + + IF TG_OP = 'INSERT' THEN + -- Update dataset_usage_counts + INSERT INTO dataset_usage_counts (dataset_id, chunk_count) + VALUES (d_id, new_count) + ON CONFLICT (dataset_id) DO UPDATE + SET chunk_count = dataset_usage_counts.chunk_count + new_count; + + -- Update dataset + UPDATE datasets + SET updated_at = CURRENT_TIMESTAMP + WHERE id = d_id; + + ELSIF TG_OP = 'DELETE' THEN + -- Update dataset_usage_counts + UPDATE dataset_usage_counts + SET chunk_count = dataset_usage_counts.chunk_count - new_count + WHERE dataset_id = d_id; + + -- Update dataset + UPDATE datasets + SET updated_at = CURRENT_TIMESTAMP + WHERE id = d_id; + END IF; + + RETURN NULL; +END; +$$ LANGUAGE plpgsql; diff --git a/server/migrations/2024-08-09-024547_add_created_at_to_word/down.sql b/server/migrations/2024-08-09-024547_add_created_at_to_word/down.sql 
new file mode 100644 index 0000000000..330b2abf25 --- /dev/null +++ b/server/migrations/2024-08-09-024547_add_created_at_to_word/down.sql @@ -0,0 +1,2 @@ +-- This file should undo anything in `up.sql` +ALTER TABLE words_datasets DROP COLUMN IF EXISTS created_at; \ No newline at end of file diff --git a/server/migrations/2024-08-09-024547_add_created_at_to_word/up.sql b/server/migrations/2024-08-09-024547_add_created_at_to_word/up.sql new file mode 100644 index 0000000000..1ca1af8298 --- /dev/null +++ b/server/migrations/2024-08-09-024547_add_created_at_to_word/up.sql @@ -0,0 +1,3 @@ +-- Your SQL goes here +ALTER TABLE words_datasets +ADD COLUMN IF NOT EXISTS created_at TIMESTAMP NOT NULL DEFAULT NOW(); \ No newline at end of file diff --git a/server/migrations/2024-08-10-032512_delete_tables/down.sql b/server/migrations/2024-08-10-032512_delete_tables/down.sql new file mode 100644 index 0000000000..8526f01622 --- /dev/null +++ b/server/migrations/2024-08-10-032512_delete_tables/down.sql @@ -0,0 +1,16 @@ +-- This file should undo anything in `up.sql` +CREATE TABLE IF NOT EXISTS "words_in_datasets" ( + id UUID PRIMARY KEY, + word TEXT NOT NULL, + UNIQUE(word) +); + +CREATE TABLE IF NOT EXISTS "words_datasets" ( + id UUID PRIMARY KEY, + dataset_id UUID NOT NULL, + word_id UUID NOT NULL, + count INT NOT NULL, + UNIQUE(dataset_id, word_id), + FOREIGN KEY (dataset_id) REFERENCES "datasets"(id) ON DELETE CASCADE, + FOREIGN KEY (word_id) REFERENCES "words_in_datasets"(id) ON DELETE CASCADE +); diff --git a/server/migrations/2024-08-10-032512_delete_tables/up.sql b/server/migrations/2024-08-10-032512_delete_tables/up.sql new file mode 100644 index 0000000000..591cd969c6 --- /dev/null +++ b/server/migrations/2024-08-10-032512_delete_tables/up.sql @@ -0,0 +1,5 @@ +-- Your SQL goes here +DROP TABLE IF EXISTS "words_datasets"; +DROP TABLE IF EXISTS "words_in_datasets"; + + diff --git a/server/migrations/2024-08-12-191216_delete_last_processed_table/down.sql b/server/migrations/2024-08-12-191216_delete_last_processed_table/down.sql new file mode 100644 index 0000000000..a8589faf50 --- /dev/null +++ b/server/migrations/2024-08-12-191216_delete_last_processed_table/down.sql @@ -0,0 +1,8 @@ +-- This file should undo anything in `up.sql` +CREATE TABLE IF NOT EXISTS "dataset_words_last_processed" ( + id UUID PRIMARY KEY, + last_processed TIMESTAMP NULL, + dataset_id UUID NOT NULL, + FOREIGN KEY (dataset_id) REFERENCES "datasets"(id) ON DELETE CASCADE, + UNIQUE(dataset_id) +); diff --git a/server/migrations/2024-08-12-191216_delete_last_processed_table/up.sql b/server/migrations/2024-08-12-191216_delete_last_processed_table/up.sql new file mode 100644 index 0000000000..6eab645e73 --- /dev/null +++ b/server/migrations/2024-08-12-191216_delete_last_processed_table/up.sql @@ -0,0 +1,2 @@ +-- Your SQL goes +DROP TABLE IF EXISTS "dataset_words_last_processed"; \ No newline at end of file diff --git a/server/src/bin/bktree-worker.rs b/server/src/bin/bktree-worker.rs new file mode 100644 index 0000000000..122fb6def7 --- /dev/null +++ b/server/src/bin/bktree-worker.rs @@ -0,0 +1,387 @@ +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, +}; + +use chm::tools::migrations::SetupArgs; +use rand::Rng; +use sentry::{Hub, SentryFutureExt}; +use signal_hook::consts::SIGTERM; +use tracing_subscriber::{prelude::*, EnvFilter, Layer}; +use trieve_server::{ + data::models::RedisPool, + errors::ServiceError, + get_env, + operators::{ + chunk_operator::get_last_processed_from_clickhouse, + 
dataset_operator::{scroll_words_from_dataset, update_dataset_last_processed_query}, + words_operator::{BkTree, CreateBkTreeMessage}, + }, +}; + +#[allow(clippy::print_stdout)] +fn main() { + dotenvy::dotenv().ok(); + let sentry_url = std::env::var("SENTRY_URL"); + let _guard = if let Ok(sentry_url) = sentry_url { + let guard = sentry::init(( + sentry_url, + sentry::ClientOptions { + release: sentry::release_name!(), + traces_sample_rate: 1.0, + ..Default::default() + }, + )); + + tracing_subscriber::Registry::default() + .with(sentry::integrations::tracing::layer()) + .with( + tracing_subscriber::fmt::layer().with_filter( + EnvFilter::from_default_env() + .add_directive(tracing_subscriber::filter::LevelFilter::INFO.into()), + ), + ) + .init(); + + log::info!("Sentry monitoring enabled"); + Some(guard) + } else { + tracing_subscriber::Registry::default() + .with( + tracing_subscriber::fmt::layer().with_filter( + EnvFilter::from_default_env() + .add_directive(tracing_subscriber::filter::LevelFilter::INFO.into()), + ), + ) + .init(); + + None + }; + + let should_terminate = Arc::new(AtomicBool::new(false)); + signal_hook::flag::register(SIGTERM, Arc::clone(&should_terminate)) + .expect("Failed to register shutdown hook"); + + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("Failed to create tokio runtime") + .block_on( + async move { + let redis_url = get_env!("REDIS_URL", "REDIS_URL is not set"); + let redis_connections: u32 = std::env::var("REDIS_CONNECTIONS") + .unwrap_or("2".to_string()) + .parse() + .unwrap_or(2); + + let redis_manager = bb8_redis::RedisConnectionManager::new(redis_url) + .expect("Failed to connect to redis"); + + let redis_pool = bb8_redis::bb8::Pool::builder() + .max_size(redis_connections) + .connection_timeout(std::time::Duration::from_secs(2)) + .build(redis_manager) + .await + .expect("Failed to create redis pool"); + + let web_redis_pool = actix_web::web::Data::new(redis_pool); + + let args = SetupArgs { + url: Some(get_env!("CLICKHOUSE_URL", "CLICKHOUSE_URL is not set").to_string()), + user: Some( + get_env!("CLICKHOUSE_USER", "CLICKHOUSE_USER is not set").to_string(), + ), + password: Some( + get_env!("CLICKHOUSE_PASSWORD", "CLICKHOUSE_PASSWORD is not set") + .to_string(), + ), + database: Some( + get_env!("CLICKHOUSE_DB", "CLICKHOUSE_DB is not set").to_string(), + ), + }; + + let clickhouse_client = clickhouse::Client::default() + .with_url(args.url.as_ref().unwrap()) + .with_user(args.user.as_ref().unwrap()) + .with_password(args.password.as_ref().unwrap()) + .with_database(args.database.as_ref().unwrap()) + .with_option("async_insert", "1") + .with_option("wait_for_async_insert", "0"); + + let should_terminate = Arc::new(AtomicBool::new(false)); + signal_hook::flag::register(SIGTERM, Arc::clone(&should_terminate)) + .expect("Failed to register shutdown hook"); + + bktree_worker(should_terminate, web_redis_pool, clickhouse_client).await + } + .bind_hub(Hub::new_from_top(Hub::current())), + ); +} + +#[allow(clippy::print_stdout)] +async fn bktree_worker( + should_terminate: Arc, + redis_pool: actix_web::web::Data, + clickhouse_client: clickhouse::Client, +) { + log::info!("Starting bk tree service thread"); + + let mut redis_conn_sleep = std::time::Duration::from_secs(1); + + #[allow(unused_assignments)] + let mut opt_redis_connection = None; + + loop { + let borrowed_redis_connection = match redis_pool.get().await { + Ok(redis_connection) => Some(redis_connection), + Err(err) => { + log::error!("Failed to get redis connection 
outside of loop: {:?}", err); + None + } + }; + + if borrowed_redis_connection.is_some() { + opt_redis_connection = borrowed_redis_connection; + break; + } + + tokio::time::sleep(redis_conn_sleep).await; + redis_conn_sleep = std::cmp::min(redis_conn_sleep * 2, std::time::Duration::from_secs(300)); + } + + let mut redis_connection = + opt_redis_connection.expect("Failed to get redis connection outside of loop"); + + let mut broken_pipe_sleep = std::time::Duration::from_secs(10); + + loop { + if should_terminate.load(Ordering::Relaxed) { + log::info!("Shutting down"); + break; + } + + let payload_result: Result, redis::RedisError> = redis::cmd("SPOP") + .arg("bktree_creation") + .query_async(&mut *redis_connection) + .await; + + let serialized_message = match payload_result { + Ok(payload) => { + broken_pipe_sleep = std::time::Duration::from_secs(10); + + if payload.is_empty() { + continue; + } + let _: Result = redis::cmd("SADD") + .arg("bktree_processing") + .query_async(&mut *redis_connection) + .await; + + payload + .first() + .expect("Payload must have a first element") + .clone() + } + Err(err) => { + log::error!("Unable to process {:?}", err); + + if err.is_io_error() { + tokio::time::sleep(broken_pipe_sleep).await; + broken_pipe_sleep = + std::cmp::min(broken_pipe_sleep * 2, std::time::Duration::from_secs(300)); + } + + continue; + } + }; + + let create_tree_msg: CreateBkTreeMessage = match serde_json::from_str(&serialized_message) { + Ok(message) => message, + Err(err) => { + log::error!( + "Failed to deserialize message, was not a CreateBkTreeMessage: {:?}", + err + ); + continue; + } + }; + + let mut id_offset = uuid::Uuid::nil(); + log::info!("Processing dataset {}", create_tree_msg.dataset_id); + + let mut bk_tree = if let Ok(Some(bktree)) = + BkTree::from_redis(create_tree_msg.dataset_id, redis_pool.clone()).await + { + bktree + } else { + BkTree::new() + }; + + let mut failed = false; + + let last_processed = + get_last_processed_from_clickhouse(&clickhouse_client, create_tree_msg.dataset_id) + .await; + + let last_processed = match last_processed { + Ok(last_processed) => last_processed.map(|lp| lp.last_processed), + Err(err) => { + let _ = readd_error_to_queue(create_tree_msg.clone(), &err, redis_pool.clone()) + .await + .map_err(|e| { + eprintln!("Failed to readd error to queue: {:?}", e); + }); + continue; + } + }; + + while let Ok(Some(word_and_counts)) = scroll_words_from_dataset( + create_tree_msg.dataset_id, + id_offset, + last_processed, + 5000, + &clickhouse_client, + ) + .await + .map_err(|err| { + let err = err.clone(); + let redis_pool = redis_pool.clone(); + let create_tree_msg = create_tree_msg.clone(); + tokio::spawn(async move { + let _ = readd_error_to_queue(create_tree_msg.clone(), &err, redis_pool.clone()) + .await + .map_err(|e| { + eprintln!("Failed to readd error to queue: {:?}", e); + }); + }); + failed = true; + }) { + dbg!(id_offset); + if let Some(last_word) = word_and_counts.last() { + id_offset = last_word.id; + } + + let word_and_counts = word_and_counts + .into_iter() + .map(|words| (words.word, words.count)) + .collect::>(); + + bk_tree.insert_all(word_and_counts); + } + + if failed { + continue; + } + + match bk_tree + .save(create_tree_msg.dataset_id, redis_pool.clone()) + .await + { + Ok(()) => { + let _ = redis::cmd("LREM") + .arg("bktree_processing") + .arg(1) + .arg(serialized_message.clone()) + .query_async::(&mut *redis_connection) + .await; + + log::info!( + "Succesfully created bk-tree for {}", + create_tree_msg.dataset_id + ); + } + 
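+ // Saving the BK-tree failed: readd_error_to_queue puts the message back on the bktree_creation set for retry (dead-lettered after 3 attempts).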
Err(err) => { + let _ = readd_error_to_queue( + create_tree_msg.clone(), + &ServiceError::InternalServerError(format!( + "Failed to serialize tree: {:?}", + err + )), + redis_pool.clone(), + ) + .await; + } + } + + match update_dataset_last_processed_query(create_tree_msg.dataset_id, &clickhouse_client) + .await + { + Ok(_) => {} + Err(err) => { + log::error!("Failed to update last processed {:?}", err); + } + } + let sleep_duration = rand::thread_rng().gen_range(1..=10); + tokio::time::sleep(std::time::Duration::from_secs(sleep_duration)).await; + } +} + +pub async fn readd_error_to_queue( + message: CreateBkTreeMessage, + error: &ServiceError, + redis_pool: actix_web::web::Data, +) -> Result<(), ServiceError> { + let mut message = message; + + let old_payload_message = serde_json::to_string(&message).map_err(|_| { + ServiceError::InternalServerError("Failed to reserialize input for retry".to_string()) + })?; + + let mut redis_conn = redis_pool + .get() + .await + .map_err(|err| ServiceError::BadRequest(err.to_string()))?; + + let _ = redis::cmd("SREM") + .arg("bktree_processing") + .arg(1) + .arg(old_payload_message.clone()) + .query_async::(&mut *redis_conn) + .await; + + message.attempt_number += 1; + + if message.attempt_number == 3 { + log::error!("Failed to construct bktree 3 times {:?}", error); + let mut redis_conn = redis_pool + .get() + .await + .map_err(|err| ServiceError::BadRequest(err.to_string()))?; + + redis::cmd("SADD") + .arg("bktree_dead_letters") + .arg(old_payload_message) + .query_async(&mut *redis_conn) + .await + .map_err(|err| ServiceError::BadRequest(err.to_string()))?; + + return Err(ServiceError::InternalServerError(format!( + "Failed to construct bktree {:?}", + error + ))); + } else { + let new_payload_message = serde_json::to_string(&message).map_err(|_| { + ServiceError::InternalServerError("Failed to reserialize input for retry".to_string()) + })?; + + let mut redis_conn = redis_pool + .get() + .await + .map_err(|err| ServiceError::BadRequest(err.to_string()))?; + + log::error!( + "Failed to insert data, re-adding {:?} retry: {:?}", + error, + message.attempt_number + ); + + redis::cmd("SADD") + .arg("bktree_creation") + .arg(&new_payload_message) + .query_async(&mut *redis_conn) + .await + .map_err(|err| ServiceError::BadRequest(err.to_string()))? 
+ } + + Ok(()) +} diff --git a/server/src/bin/file-worker.rs b/server/src/bin/file-worker.rs index 39da0a253b..20274e633c 100644 --- a/server/src/bin/file-worker.rs +++ b/server/src/bin/file-worker.rs @@ -232,6 +232,19 @@ async fn file_worker( Ok(Some(file_id)) => { log::info!("Uploaded file: {:?}", file_id); + event_queue + .send(ClickHouseEvent::WorkerEvent( + models::WorkerEvent::from_details( + file_worker_message.dataset_id, + models::EventType::FileUploaded { + file_id, + file_name: file_worker_message.upload_file_data.file_name.clone(), + }, + ) + .into(), + )) + .await; + let _ = redis::cmd("LREM") .arg("file_processing") .arg(1) diff --git a/server/src/bin/word-id-cronjob.rs b/server/src/bin/word-id-cronjob.rs new file mode 100644 index 0000000000..6b1f09861f --- /dev/null +++ b/server/src/bin/word-id-cronjob.rs @@ -0,0 +1,191 @@ +use chm::tools::migrations::SetupArgs; +use diesel_async::pooled_connection::{AsyncDieselConnectionManager, ManagerConfig}; +use futures::future::join_all; +use itertools::Itertools; +use tracing_subscriber::{prelude::*, EnvFilter, Layer}; +use trieve_server::{ + errors::ServiceError, + establish_connection, get_env, + operators::{ + chunk_operator::{ + get_last_processed_from_clickhouse, scroll_chunk_ids_for_dictionary_query, + }, + dataset_operator::get_all_dataset_ids, + words_operator::ProcessWordsFromDatasetMessage, + }, +}; + +#[allow(clippy::print_stdout)] +#[tokio::main] +async fn main() -> Result<(), ServiceError> { + dotenvy::dotenv().ok(); + log::info!("Starting id worker service thread"); + let sentry_url = std::env::var("SENTRY_URL"); + let _guard = if let Ok(sentry_url) = sentry_url { + let guard = sentry::init(( + sentry_url, + sentry::ClientOptions { + release: sentry::release_name!(), + traces_sample_rate: 1.0, + ..Default::default() + }, + )); + + tracing_subscriber::Registry::default() + .with(sentry::integrations::tracing::layer()) + .with( + tracing_subscriber::fmt::layer().with_filter( + EnvFilter::from_default_env() + .add_directive(tracing_subscriber::filter::LevelFilter::INFO.into()), + ), + ) + .init(); + + log::info!("Sentry monitoring enabled"); + Some(guard) + } else { + tracing_subscriber::Registry::default() + .with( + tracing_subscriber::fmt::layer().with_filter( + EnvFilter::from_default_env() + .add_directive(tracing_subscriber::filter::LevelFilter::INFO.into()), + ), + ) + .init(); + + None + }; + + let redis_url = get_env!("REDIS_URL", "REDIS_URL is not set"); + let redis_connections: u32 = std::env::var("REDIS_CONNECTIONS") + .unwrap_or("2".to_string()) + .parse() + .unwrap_or(2); + + let redis_manager = + bb8_redis::RedisConnectionManager::new(redis_url).expect("Failed to connect to redis"); + + let redis_pool = bb8_redis::bb8::Pool::builder() + .max_size(redis_connections) + .connection_timeout(std::time::Duration::from_secs(2)) + .build(redis_manager) + .await + .expect("Failed to create redis pool"); + + let database_url = get_env!("DATABASE_URL", "DATABASE_URL is not set"); + + let mut config = ManagerConfig::default(); + config.custom_setup = Box::new(establish_connection); + + let mgr = AsyncDieselConnectionManager::::new_with_config( + database_url, + config, + ); + + let pool = diesel_async::pooled_connection::deadpool::Pool::builder(mgr) + .max_size(3) + .build() + .expect("Failed to create diesel_async pool"); + + let pool = actix_web::web::Data::new(pool.clone()); + + let args = SetupArgs { + url: Some(get_env!("CLICKHOUSE_URL", "CLICKHOUSE_URL is not set").to_string()), + user: 
Some(get_env!("CLICKHOUSE_USER", "CLICKHOUSE_USER is not set").to_string()), + password: Some( + get_env!("CLICKHOUSE_PASSWORD", "CLICKHOUSE_PASSWORD is not set").to_string(), + ), + database: Some(get_env!("CLICKHOUSE_DB", "CLICKHOUSE_DB is not set").to_string()), + }; + + let clickhouse_client = clickhouse::Client::default() + .with_url(args.url.as_ref().unwrap()) + .with_user(args.user.as_ref().unwrap()) + .with_password(args.password.as_ref().unwrap()) + .with_database(args.database.as_ref().unwrap()) + .with_option("async_insert", "1") + .with_option("wait_for_async_insert", "0"); + + let dataset_ids = get_all_dataset_ids(pool.clone()).await?; + let dataset_ids_and_processed = dataset_ids + .into_iter() + .map(|dataset_id| { + let clickhouse_client = clickhouse_client.clone(); + async move { + ( + dataset_id, + get_last_processed_from_clickhouse(&clickhouse_client, dataset_id).await, + ) + } + }) + .collect_vec(); + + let dataset_ids_and_processed = join_all(dataset_ids_and_processed).await; + + for (dataset_id, last_processed) in dataset_ids_and_processed { + let mut chunk_id_offset = uuid::Uuid::nil(); + + let last_processed = last_processed.map_err(|_| { + ServiceError::InternalServerError( + "Failed to get last processed from clickhouse".to_string(), + ) + })?; + while let Some(chunk_id_dataset_id_list) = scroll_chunk_ids_for_dictionary_query( + pool.clone(), + dataset_id, + last_processed.clone(), + 10000, + chunk_id_offset, + ) + .await? + { + if let Some((chunk_id, _)) = chunk_id_dataset_id_list.last() { + chunk_id_offset = *chunk_id + } + let redis_futures = + chunk_id_dataset_id_list + .chunks(500) + .map(|chunk_id_dataset_id_list| { + let pool = redis_pool.clone(); + async move { + let mut redis_conn = pool + .get() + .await + .map_err(|err| ServiceError::BadRequest(err.to_string()))?; + let process_words_msg = ProcessWordsFromDatasetMessage { + chunks_to_process: chunk_id_dataset_id_list.to_vec(), + attempt_number: 0, + }; + + match serde_json::to_string(&process_words_msg).map_err(|_| { + ServiceError::InternalServerError( + "Failed to serialize message".to_string(), + ) + }) { + Ok(serialized_msg) => redis::cmd("LPUSH") + .arg("create_dictionary") + .arg(serialized_msg) + .query_async::( + &mut *redis_conn, + ) + .await + .map_err(|_| { + ServiceError::InternalServerError( + "Failed to send message to redis".to_string(), + ) + }), + Err(err) => Err(err), + } + } + }); + + let _ = join_all(redis_futures) + .await + .into_iter() + .collect::, ServiceError>>()?; + log::info!("Scrolled {} chunks", chunk_id_dataset_id_list.len()); + } + } + + Ok(()) +} diff --git a/server/src/bin/word-worker.rs b/server/src/bin/word-worker.rs new file mode 100644 index 0000000000..157ad0dc5a --- /dev/null +++ b/server/src/bin/word-worker.rs @@ -0,0 +1,417 @@ +#![allow(clippy::print_stdout)] +use actix_web::web; +use chm::tools::migrations::SetupArgs; +use diesel_async::pooled_connection::{AsyncDieselConnectionManager, ManagerConfig}; +use futures::future::join_all; +use itertools::Itertools; +use sentry::{Hub, SentryFutureExt}; +use signal_hook::consts::SIGTERM; +use std::{ + collections::HashMap, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, +}; +use tracing_subscriber::{prelude::*, EnvFilter, Layer}; +use trieve_server::{ + data::models, + errors::ServiceError, + establish_connection, get_env, + operators::{ + chunk_operator::get_chunk_html_from_ids_query, + dataset_operator::add_words_to_dataset, + parse_operator::convert_html_to_text, + words_operator::{CreateBkTreeMessage, 
ProcessWordsFromDatasetMessage}, + }, +}; + +#[allow(clippy::print_stdout)] +fn main() -> Result<(), ServiceError> { + dotenvy::dotenv().ok(); + let sentry_url = std::env::var("SENTRY_URL"); + let _guard = if let Ok(sentry_url) = sentry_url { + let guard = sentry::init(( + sentry_url, + sentry::ClientOptions { + release: sentry::release_name!(), + traces_sample_rate: 1.0, + ..Default::default() + }, + )); + + tracing_subscriber::Registry::default() + .with(sentry::integrations::tracing::layer()) + .with( + tracing_subscriber::fmt::layer().with_filter( + EnvFilter::from_default_env() + .add_directive(tracing_subscriber::filter::LevelFilter::INFO.into()), + ), + ) + .init(); + + log::info!("Sentry monitoring enabled"); + Some(guard) + } else { + tracing_subscriber::Registry::default() + .with( + tracing_subscriber::fmt::layer().with_filter( + EnvFilter::from_default_env() + .add_directive(tracing_subscriber::filter::LevelFilter::INFO.into()), + ), + ) + .init(); + + None + }; + + let database_url = get_env!("DATABASE_URL", "DATABASE_URL is not set"); + + let mut config = ManagerConfig::default(); + config.custom_setup = Box::new(establish_connection); + + let mgr = AsyncDieselConnectionManager::::new_with_config( + database_url, + config, + ); + + let pool = diesel_async::pooled_connection::deadpool::Pool::builder(mgr) + .max_size(3) + .build() + .expect("Failed to create diesel_async pool"); + + let web_pool = actix_web::web::Data::new(pool.clone()); + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .expect("Failed to create tokio runtime") + .block_on( + async move { + let redis_url = get_env!("REDIS_URL", "REDIS_URL is not set"); + + let redis_manager = bb8_redis::RedisConnectionManager::new(redis_url) + .expect("Failed to connect to redis"); + + let redis_pool = bb8_redis::bb8::Pool::builder() + .connection_timeout(std::time::Duration::from_secs(2)) + .build(redis_manager) + .await + .expect("Failed to create redis pool"); + + let web_redis_pool = actix_web::web::Data::new(redis_pool); + + let args = SetupArgs { + url: Some(get_env!("CLICKHOUSE_URL", "CLICKHOUSE_URL is not set").to_string()), + user: Some( + get_env!("CLICKHOUSE_USER", "CLICKHOUSE_USER is not set").to_string(), + ), + password: Some( + get_env!("CLICKHOUSE_PASSWORD", "CLICKHOUSE_PASSWORD is not set") + .to_string(), + ), + database: Some( + get_env!("CLICKHOUSE_DB", "CLICKHOUSE_DB is not set").to_string(), + ), + }; + + let clickhouse_client = clickhouse::Client::default() + .with_url(args.url.as_ref().unwrap()) + .with_user(args.user.as_ref().unwrap()) + .with_password(args.password.as_ref().unwrap()) + .with_database(args.database.as_ref().unwrap()) + .with_option("async_insert", "1") + .with_option("wait_for_async_insert", "0"); + + let should_terminate = Arc::new(AtomicBool::new(false)); + signal_hook::flag::register(SIGTERM, Arc::clone(&should_terminate)) + .expect("Failed to register shutdown hook"); + word_worker( + should_terminate, + web_redis_pool, + web_pool, + clickhouse_client, + ) + .await + } + .bind_hub(Hub::new_from_top(Hub::current())), + ); + + Ok(()) +} + +async fn word_worker( + should_terminate: Arc, + redis_pool: actix_web::web::Data, + web_pool: actix_web::web::Data, + clickhouse_client: clickhouse::Client, +) { + log::info!("Starting word worker service thread"); + let mut redis_conn_sleep = std::time::Duration::from_secs(1); + + #[allow(unused_assignments)] + let mut opt_redis_connection = None; + + loop { + let borrowed_redis_connection = match redis_pool.get().await { 
+ Ok(redis_connection) => Some(redis_connection), + Err(err) => { + log::error!("Failed to get redis connection outside of loop: {:?}", err); + None + } + }; + + if borrowed_redis_connection.is_some() { + opt_redis_connection = borrowed_redis_connection; + break; + } + + tokio::time::sleep(redis_conn_sleep).await; + redis_conn_sleep = std::cmp::min(redis_conn_sleep * 2, std::time::Duration::from_secs(300)); + } + + let mut redis_connection = + opt_redis_connection.expect("Failed to get redis connection outside of loop"); + + let mut broken_pipe_sleep = std::time::Duration::from_secs(10); + + loop { + if should_terminate.load(Ordering::Relaxed) { + log::info!("Shutting down"); + break; + } + + let payload_result: Result<Vec<String>, redis::RedisError> = redis::cmd("brpoplpush") + .arg("create_dictionary") + .arg("process_dictionary") + .arg(1.0) + .query_async(&mut *redis_connection) + .await; + + let serialized_msg = match payload_result { + Ok(payload) => { + broken_pipe_sleep = std::time::Duration::from_secs(10); + + if payload.is_empty() { + continue; + } + + payload + .first() + .expect("Payload must have a first element") + .clone() + } + Err(err) => { + log::error!("Unable to process {:?}", err); + + if err.is_io_error() { + tokio::time::sleep(broken_pipe_sleep).await; + broken_pipe_sleep = + std::cmp::min(broken_pipe_sleep * 2, std::time::Duration::from_secs(300)); + } + + continue; + } + }; + + let msg: ProcessWordsFromDatasetMessage = match serde_json::from_str(&serialized_msg) { + Ok(message) => message, + Err(err) => { + log::error!( + "Failed to deserialize message, was not a ProcessWordsFromDatasetMessage: {:?}", + err + ); + continue; + } + }; + + match process_chunks( + msg.clone(), + web_pool.clone(), + redis_pool.clone(), + clickhouse_client.clone(), + ) + .await + { + Ok(()) => { + log::info!("Processed {} chunks", msg.chunks_to_process.len()); + } + Err(err) => { + log::error!("Failed to process dataset: {:?}", err); + let _ = readd_error_to_queue(msg.clone(), err, redis_pool.clone()).await; + } + } + } +} + +async fn process_chunks( + message: ProcessWordsFromDatasetMessage, + pool: web::Data<Pool>, + redis_pool: web::Data<RedisPool>, + clickhouse_client: clickhouse::Client, +) -> Result<(), ServiceError> { + let mut word_count_map: HashMap<(uuid::Uuid, String), i32> = HashMap::new(); + if let Some(chunks) = get_chunk_html_from_ids_query( + message + .chunks_to_process + .clone() + .into_iter() + .map(|x| x.0) + .collect(), + pool.clone(), + ) + .await?
+ { + let chunks = chunks + .into_iter() + // add dataset_id back to chunks + .zip(message.chunks_to_process.clone().into_iter().map(|x| x.1)) + .collect_vec(); + + for ((_, chunk), dataset_id) in &chunks { + let content = convert_html_to_text(chunk); + for word in content + .split([' ', '\n', '\t', '\r', ',', '.', ';', ':', '!', '?'].as_ref()) + .filter(|word| !word.is_empty()) + { + let word = word + .replace(|c: char| !c.is_alphabetic(), "") + .to_lowercase() + .chars() + .take(50) + .join(""); + if let Some(count) = word_count_map.get_mut(&(*dataset_id, word.clone())) { + *count += 1; + } else { + word_count_map.insert((*dataset_id, word), 1); + } + } + } + } + + let (dataset_id_word, counts): (Vec<_>, Vec<_>) = word_count_map + .into_iter() + .sorted_by_key(|((_, word), _)| word.clone()) + .unzip(); + + let words_and_counts = dataset_id_word + .into_iter() + .zip(counts.into_iter()) + .dedup_by(|((_, word1), _), ((_, word2), _)| word1 == word2) + .collect_vec(); + + let word_dataset_relation_futs = words_and_counts + .chunks(5000) + .map(|ids_counts| { + let words = ids_counts.iter().map(|((_, w), _)| w.clone()).collect_vec(); + let dataset_ids = ids_counts + .iter() + .map(|((d, _), _)| d.to_owned()) + .collect_vec(); + let counts = ids_counts + .iter() + .map(|((_, _), c)| c.to_owned()) + .collect_vec(); + add_words_to_dataset(words, counts, dataset_ids, &clickhouse_client) + }) + .collect_vec(); + + join_all(word_dataset_relation_futs) + .await + .into_iter() + .collect::, ServiceError>>()?; + + let serialized_payload = serde_json::to_string(&message).map_err(|_| { + ServiceError::InternalServerError("Failed to reserialize input".to_string()) + })?; + + let mut redis_conn = redis_pool + .get() + .await + .map_err(|err| ServiceError::BadRequest(err.to_string()))?; + + let _ = redis::cmd("LREM") + .arg("process_dictionary") + .arg(1) + .arg(serialized_payload) + .query_async::(&mut *redis_conn) + .await; + + let create_tree_msgs = words_and_counts + .iter() + .map(|((dataset_id, _), _)| *dataset_id) + .unique() + .map(|id| { + let msg = CreateBkTreeMessage { + dataset_id: id, + attempt_number: 0, + }; + + serde_json::to_string(&msg).map_err(|_| { + ServiceError::InternalServerError("Failed to serialize message".to_string()) + }) + }) + .collect::, ServiceError>>()?; + + redis::cmd("SADD") + .arg("bktree_creation") + .arg(create_tree_msgs) + .query_async::(&mut *redis_conn) + .await + .map_err(|_| { + ServiceError::InternalServerError("Failed to send message to redis".to_string()) + })?; + + Ok(()) +} + +#[tracing::instrument(skip(redis_pool))] +pub async fn readd_error_to_queue( + mut message: ProcessWordsFromDatasetMessage, + error: ServiceError, + redis_pool: actix_web::web::Data, +) -> Result<(), ServiceError> { + let old_payload_message = serde_json::to_string(&message).map_err(|_| { + ServiceError::InternalServerError("Failed to reserialize input for retry".to_string()) + })?; + + let mut redis_conn = redis_pool + .get() + .await + .map_err(|err| ServiceError::BadRequest(err.to_string()))?; + + let _ = redis::cmd("lrem") + .arg("process_dictionary") + .arg(1) + .arg(old_payload_message.clone()) + .query_async::(&mut *redis_conn) + .await; + + message.attempt_number += 1; + + if message.attempt_number == 3 { + log::error!("Failed to process dataset 3 times: {:?}", error); + redis::cmd("lpush") + .arg("dictionary_dead_letters") + .arg(old_payload_message) + .query_async(&mut *redis_conn) + .await + .map_err(|err| ServiceError::BadRequest(err.to_string()))?; + return 
Err(ServiceError::InternalServerError(format!( + "Failed to process dataset 3 times: {:?}", + error + ))); + } + + let new_payload_message = serde_json::to_string(&message).map_err(|_| { + ServiceError::InternalServerError("Failed to reserialize input for retry".to_string()) + })?; + + redis::cmd("lpush") + .arg("create_dictionary") + .arg(&new_payload_message) + .query_async(&mut *redis_conn) + .await + .map_err(|err| ServiceError::BadRequest(err.to_string()))?; + + Ok(()) +} diff --git a/server/src/data/models.rs b/server/src/data/models.rs index 91c706bfdc..e2e5da570f 100644 --- a/server/src/data/models.rs +++ b/server/src/data/models.rs @@ -4,7 +4,8 @@ use super::schema::*; use crate::errors::ServiceError; use crate::get_env; use crate::handlers::chunk_handler::{ - AutocompleteReqPayload, ChunkFilter, FullTextBoost, SearchChunksReqPayload, SemanticBoost, + AutocompleteReqPayload, ChunkFilter, FullTextBoost, ParsedQuery, SearchChunksReqPayload, + SemanticBoost, }; use crate::handlers::file_handler::UploadFileReqPayload; use crate::handlers::group_handler::{SearchOverGroupsReqPayload, SearchWithinGroupReqPayload}; @@ -327,7 +328,7 @@ impl FromSql for GeoInfo { fn from_sql(bytes: PgValue) -> deserialize::Result<Self> { let bytes = bytes.as_bytes(); - if bytes[0] != 1 { + if bytes.get(0) != Some(&1) { return Err("Unsupported JSONB encoding version".into()); } serde_json::from_slice(&bytes[1..]).map_err(Into::into) @@ -1871,9 +1872,9 @@ pub struct Dataset { pub created_at: chrono::NaiveDateTime, pub updated_at: chrono::NaiveDateTime, pub organization_id: uuid::Uuid, + pub server_configuration: serde_json::Value, pub tracking_id: Option<String>, pub deleted: i32, - pub server_configuration: serde_json::Value, } impl Dataset { @@ -5019,6 +5020,28 @@ pub struct HighlightOptions { pub highlight_window: Option<u32>, } +#[derive(Serialize, Deserialize, Debug, Clone, ToSchema, Default)] +/// Typo Options lets you specify different methods to correct typos in the query. If not specified, typos will not be corrected. +pub struct TypoOptions { + /// Set correct_typos to true to correct typos in the query. If not specified, this defaults to false. + pub correct_typos: Option<bool>, + /// The range of word lengths for which the query will be corrected if a word has one typo. If not specified, this defaults to 5-8. + pub one_typo_word_range: Option<TypoRange>, + /// The range of word lengths for which the query will be corrected if a word has two typos. If not specified, this defaults to 8-inf. + pub two_typo_word_range: Option<TypoRange>, + /// Words that should not be corrected. If not specified, this defaults to an empty list. + pub disable_on_word: Option<Vec<String>>, +} + +#[derive(Serialize, Deserialize, Debug, Clone, ToSchema, Default)] +/// The TypoRange struct specifies the range of word lengths (in characters) for which a typo will be corrected. +pub struct TypoRange { + /// The minimum word length (in characters) at which a typo will be corrected. If not specified, this defaults to 5. + pub min: u32, + /// The maximum word length (in characters) at which a typo will be corrected. If not specified, this defaults to 8. + pub max: Option<u32>, +} + #[derive(Debug, Serialize, Deserialize, ToSchema, Clone, Default)] /// LLM options to use for the completion. If not specified, this defaults to the dataset's LLM options.
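// Editor's illustration (not part of this diff): how a caller might populate the new
// `TypoOptions`/`TypoRange` structs above, with the documented defaults written out
// explicitly. The concrete values are assumptions for the example only.
//
// let typo_options = TypoOptions {
//     correct_typos: Some(true),
//     one_typo_word_range: Some(TypoRange { min: 5, max: Some(8) }),
//     two_typo_word_range: Some(TypoRange { min: 8, max: None }),
//     disable_on_word: Some(vec!["trieve".to_string()]),
// };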
pub struct LLMOptions { @@ -5174,6 +5197,7 @@ impl<'de> Deserialize<'de> for SearchChunksReqPayload { use_quote_negated_terms: Option, remove_stop_words: Option, user_id: Option, + typo_options: Option, #[serde(flatten)] other: std::collections::HashMap, } @@ -5204,6 +5228,7 @@ impl<'de> Deserialize<'de> for SearchChunksReqPayload { use_quote_negated_terms: helper.use_quote_negated_terms, remove_stop_words: helper.remove_stop_words, user_id: helper.user_id, + typo_options: helper.typo_options, }) } } @@ -5228,6 +5253,7 @@ impl<'de> Deserialize<'de> for AutocompleteReqPayload { use_quote_negated_terms: Option, remove_stop_words: Option, user_id: Option, + typo_options: Option, #[serde(flatten)] other: std::collections::HashMap, } @@ -5257,6 +5283,7 @@ impl<'de> Deserialize<'de> for AutocompleteReqPayload { use_quote_negated_terms: helper.use_quote_negated_terms, remove_stop_words: helper.remove_stop_words, user_id: helper.user_id, + typo_options: helper.typo_options, }) } } @@ -5284,6 +5311,7 @@ impl<'de> Deserialize<'de> for SearchWithinGroupReqPayload { use_quote_negated_terms: Option, remove_stop_words: Option, user_id: Option, + typo_options: Option, #[serde(flatten)] other: std::collections::HashMap, } @@ -5316,6 +5344,7 @@ impl<'de> Deserialize<'de> for SearchWithinGroupReqPayload { use_quote_negated_terms: helper.use_quote_negated_terms, remove_stop_words: helper.remove_stop_words, user_id: helper.user_id, + typo_options: helper.typo_options, }) } } @@ -5340,6 +5369,7 @@ impl<'de> Deserialize<'de> for SearchOverGroupsReqPayload { use_quote_negated_terms: Option, remove_stop_words: Option, user_id: Option, + typo_options: Option, #[serde(flatten)] other: std::collections::HashMap, } @@ -5365,6 +5395,7 @@ impl<'de> Deserialize<'de> for SearchOverGroupsReqPayload { score_threshold: helper.score_threshold, slim_chunks: helper.slim_chunks, use_quote_negated_terms: helper.use_quote_negated_terms, + typo_options: helper.typo_options, remove_stop_words: helper.remove_stop_words, user_id: helper.user_id, }) @@ -5527,6 +5558,15 @@ pub struct MultiQuery { pub weight: f32, } +impl From<(ParsedQuery, f32)> for MultiQuery { + fn from((query, weight): (ParsedQuery, f32)) -> Self { + Self { + query: query.query, + weight, + } + } +} + #[derive(Debug, Serialize, Deserialize, ToSchema, Clone, PartialEq)] #[serde(untagged)] /// Query is the search query. This can be any string. The query will be used to create an embedding vector and/or SPLADE vector which will be used to find the result set. You can either provide one query, or multiple with weights. Multi-query only works with Semantic Search and is not compatible with cross encoder re-ranking or highlights. @@ -5545,3 +5585,27 @@ impl QueryTypes { } } } + +#[derive(Debug, Serialize, Deserialize, Row, Clone, ToSchema)] +pub struct WordDataset { + #[serde(with = "clickhouse::serde::uuid")] + pub id: uuid::Uuid, + #[serde(with = "clickhouse::serde::uuid")] + pub dataset_id: uuid::Uuid, + pub word: String, + pub count: i32, + #[serde(with = "clickhouse::serde::time::datetime")] + pub created_at: OffsetDateTime, +} + +impl WordDataset { + pub fn from_details(word: String, dataset_id: uuid::Uuid, count: i32) -> Self { + Self { + id: uuid::Uuid::new_v4(), + word, + dataset_id, + count, + created_at: OffsetDateTime::now_utc(), + } + } +} diff --git a/server/src/data/schema.rs b/server/src/data/schema.rs index 3921eba22a..96b9ecd825 100644 --- a/server/src/data/schema.rs +++ b/server/src/data/schema.rs @@ -90,9 +90,9 @@ diesel::table! 
{ created_at -> Timestamp, updated_at -> Timestamp, organization_id -> Uuid, + server_configuration -> Jsonb, tracking_id -> Nullable, deleted -> Int4, - server_configuration -> Jsonb, } } diff --git a/server/src/handlers/auth_handler.rs b/server/src/handlers/auth_handler.rs index daf69c157c..a8c0d04f81 100644 --- a/server/src/handlers/auth_handler.rs +++ b/server/src/handlers/auth_handler.rs @@ -178,7 +178,11 @@ pub async fn create_account( .organization, ), None => { - let org_name = email.split('@').collect::>()[0] + let org_name = email + .split('@') + .collect::>() + .get(0) + .unwrap_or(&"") .to_string() .replace(' ', "-"); ( diff --git a/server/src/handlers/chunk_handler.rs b/server/src/handlers/chunk_handler.rs index 1830e3b9db..b489ea1ba2 100644 --- a/server/src/handlers/chunk_handler.rs +++ b/server/src/handlers/chunk_handler.rs @@ -5,7 +5,7 @@ use crate::data::models::{ HighlightOptions, IngestSpecificChunkMetadata, Pool, QueryTypes, RagQueryEventClickhouse, RecommendType, RecommendationEventClickhouse, RecommendationStrategy, RedisPool, ScoreChunk, ScoreChunkDTO, SearchMethod, SearchQueryEventClickhouse, SlimChunkMetadataWithScore, - SortByField, SortOptions, UnifiedId, UpdateSpecificChunkMetadata, + SortByField, SortOptions, TypoOptions, UnifiedId, UpdateSpecificChunkMetadata, }; use crate::errors::ServiceError; use crate::get_env; @@ -963,6 +963,8 @@ pub struct SearchChunksReqPayload { pub remove_stop_words: Option, /// User ID is the id of the user who is making the request. This is used to track user interactions with the search results. pub user_id: Option, + /// Typo options lets you specify different methods to handle typos in the search query. If not specified, this defaults to no typo handling. + pub typo_options: Option, } impl Default for SearchChunksReqPayload { @@ -982,6 +984,7 @@ impl Default for SearchChunksReqPayload { use_quote_negated_terms: None, remove_stop_words: None, user_id: None, + typo_options: None, } } } @@ -990,6 +993,7 @@ impl Default for SearchChunksReqPayload { #[schema(title = "V1")] pub struct SearchChunkQueryResponseBody { pub score_chunks: Vec, + pub corrected_query: Option, pub total_chunk_pages: i64, } @@ -998,6 +1002,7 @@ pub struct SearchChunkQueryResponseBody { pub struct SearchResponseBody { pub id: uuid::Uuid, pub chunks: Vec, + pub corrected_query: Option, pub total_pages: i64, } @@ -1019,6 +1024,7 @@ impl SearchChunkQueryResponseBody { .into_iter() .map(|chunk| chunk.into()) .collect(), + corrected_query: self.corrected_query, total_pages: self.total_chunk_pages, } } @@ -1083,7 +1089,7 @@ pub fn parse_query( let re = Regex::new(r#""(?:[^"\\]|\\.)*""#).expect("Regex pattern is always valid"); let quote_words: Vec = re .captures_iter(&query) - .map(|capture| capture[0].to_string()) + .filter_map(|capture| capture.get(0).map(|capture| capture.as_str().to_string())) .filter(|word| !word.is_empty()) .collect::>(); @@ -1140,18 +1146,21 @@ pub fn parse_query( ("ApiKey" = ["readonly"]), ) )] -#[tracing::instrument(skip(pool, event_queue))] +#[tracing::instrument(skip(pool, event_queue, redis_pool))] pub async fn search_chunks( - mut data: web::Json, + data: web::Json, _user: LoggedUser, pool: web::Data, event_queue: web::Data, + redis_pool: web::Data, api_version: APIVersion, dataset_org_plan_sub: DatasetAndOrgWithSubAndPlan, ) -> Result { let dataset_config = DatasetConfiguration::from_json(dataset_org_plan_sub.dataset.server_configuration.clone()); + let mut data = data.into_inner(); + let parsed_query = match data.query.clone() { 
QueryTypes::Single(query) => ParsedQueryTypes::Single(parse_query( query.clone(), @@ -1186,6 +1195,7 @@ pub async fn search_chunks( data.clone(), parsed_query.to_parsed_query()?, pool, + redis_pool, dataset_org_plan_sub.dataset.clone(), &dataset_config, &mut timer, @@ -1197,6 +1207,7 @@ pub async fn search_chunks( data.clone(), parsed_query, pool, + redis_pool, dataset_org_plan_sub.dataset.clone(), &dataset_config, &mut timer, @@ -1332,6 +1343,7 @@ pub struct AutocompleteReqPayload { pub remove_stop_words: Option, /// User ID is the id of the user who is making the request. This is used to track user interactions with the search results. pub user_id: Option, + pub typo_options: Option, } impl From for SearchChunksReqPayload { @@ -1351,6 +1363,7 @@ impl From for SearchChunksReqPayload { use_quote_negated_terms: autocomplete_data.use_quote_negated_terms, remove_stop_words: autocomplete_data.remove_stop_words, user_id: autocomplete_data.user_id, + typo_options: autocomplete_data.typo_options, } } } @@ -1376,12 +1389,13 @@ impl From for SearchChunksReqPayload { ("ApiKey" = ["readonly"]), ) )] -#[tracing::instrument(skip(pool, event_queue))] +#[tracing::instrument(skip(pool, event_queue, redis_pool))] pub async fn autocomplete( data: web::Json, _user: LoggedUser, pool: web::Data, event_queue: web::Data, + redis_pool: web::Data, api_version: APIVersion, dataset_org_plan_sub: DatasetAndOrgWithSubAndPlan, ) -> Result { @@ -1404,6 +1418,7 @@ pub async fn autocomplete( data.clone(), parsed_query, pool, + redis_pool, dataset_org_plan_sub.dataset.clone(), &dataset_config, &mut timer, @@ -1645,6 +1660,7 @@ impl From for SearchChunksReqPayload { use_quote_negated_terms: count_data.use_quote_negated_terms, remove_stop_words: None, user_id: None, + typo_options: None, } } } @@ -2513,11 +2529,19 @@ pub async fn generate_off_chunks( )) })?; - let chat_content = assistant_completion.choices[0].message.content.clone(); + let chat_content = match assistant_completion.choices.get(0) { + Some(choice) => choice.message.content.clone(), + None => { + return Err(ServiceError::InternalServerError( + "Failed to get response completion".into(), + ) + .into()) + } + }; - let completion_content = match &chat_content { + let completion_content = match chat_content.clone() { ChatMessageContent::Text(text) => text.clone(), - _ => "".to_string(), + _ => "Failed to get response completion".to_string(), }; let clickhouse_rag_event = RagQueryEventClickhouse { @@ -2579,7 +2603,10 @@ pub async fn generate_off_chunks( let completion_stream = stream.map(move |response| -> Result { if let Ok(response) = response { - let chat_content = response.choices[0].delta.content.clone(); + let chat_content = match response.choices.get(0) { + Some(choice) => choice.delta.content.clone(), + None => Some("failed to get response completion".to_string()), + }; if let Some(message) = chat_content.clone() { s.send(message).unwrap(); } diff --git a/server/src/handlers/group_handler.rs b/server/src/handlers/group_handler.rs index a498ea54c8..1082cfab00 100644 --- a/server/src/handlers/group_handler.rs +++ b/server/src/handlers/group_handler.rs @@ -10,7 +10,7 @@ use crate::{ ChunkMetadataStringTagSet, DatasetAndOrgWithSubAndPlan, DatasetConfiguration, HighlightOptions, Pool, QueryTypes, RecommendType, RecommendationEventClickhouse, RecommendationStrategy, RedisPool, ScoreChunk, ScoreChunkDTO, SearchMethod, - SearchQueryEventClickhouse, SortOptions, UnifiedId, + SearchQueryEventClickhouse, SortOptions, TypoOptions, UnifiedId, }, errors::ServiceError, 
middleware::api_version::APIVersion, @@ -277,7 +277,10 @@ pub async fn create_chunk_group( .collect::>(); if created_groups.len() == 1 { - Ok(HttpResponse::Ok().json(created_groups[0].clone())) + match created_groups.get(0) { + Some(group) => Ok(HttpResponse::Ok().json(group.clone())), + None => Ok(HttpResponse::Ok().json(serde_json::json!({}))), + } } else { Ok(HttpResponse::Ok().json(created_groups)) } @@ -1298,6 +1301,7 @@ pub async fn get_recommended_groups( let group_qdrant_query_result = SearchOverGroupsQueryResult { search_results: recommended_groups_from_qdrant.clone(), + corrected_query: None, total_chunk_pages: (recommended_groups_from_qdrant.len() as f64 / 10.0).ceil() as i64, }; @@ -1409,6 +1413,7 @@ pub struct SearchWithinGroupReqPayload { pub remove_stop_words: Option, /// The user_id is the id of the user who is making the request. This is used to track user interactions with the search results. pub user_id: Option, + pub typo_options: Option, } impl From for SearchChunksReqPayload { @@ -1428,6 +1433,7 @@ impl From for SearchChunksReqPayload { use_quote_negated_terms: search_within_group_data.use_quote_negated_terms, remove_stop_words: search_within_group_data.remove_stop_words, user_id: search_within_group_data.user_id, + typo_options: search_within_group_data.typo_options, } } } @@ -1437,6 +1443,7 @@ impl From for SearchChunksReqPayload { pub struct SearchWithinGroupResults { pub bookmarks: Vec, pub group: ChunkGroupAndFileId, + pub corrected_query: Option, pub total_pages: i64, } @@ -1445,6 +1452,7 @@ pub struct SearchWithinGroupResults { pub struct SearchWithinGroupResponseBody { pub id: uuid::Uuid, pub chunks: Vec, + pub corrected_query: Option, pub total_pages: i64, } @@ -1466,6 +1474,7 @@ impl SearchWithinGroupResults { .into_iter() .map(|chunk| chunk.into()) .collect(), + corrected_query: self.corrected_query, total_pages: self.total_pages, } } @@ -1492,11 +1501,12 @@ impl SearchWithinGroupResults { ("ApiKey" = ["readonly"]), ) )] -#[tracing::instrument(skip(pool, event_queue))] +#[tracing::instrument(skip(pool, event_queue, redis_pool))] pub async fn search_within_group( data: web::Json, pool: web::Data, event_queue: web::Data, + redis_pool: web::Data, api_version: APIVersion, _required_user: LoggedUser, dataset_org_plan_sub: DatasetAndOrgWithSubAndPlan, @@ -1504,6 +1514,8 @@ pub async fn search_within_group( let dataset_config = DatasetConfiguration::from_json(dataset_org_plan_sub.dataset.server_configuration.clone()); + let data = data.into_inner(); + //search over the links as well let group_id = data.group_id; let dataset_id = dataset_org_plan_sub.dataset.id; @@ -1551,8 +1563,10 @@ pub async fn search_within_group( parsed_query.to_parsed_query()?, group, search_pool, + redis_pool, dataset_org_plan_sub.dataset.clone(), &dataset_config, + &mut timer, ) .await? } @@ -1562,8 +1576,10 @@ pub async fn search_within_group( parsed_query, group, search_pool, + redis_pool, dataset_org_plan_sub.dataset.clone(), &dataset_config, + &mut timer, ) .await? } @@ -1642,6 +1658,7 @@ pub struct SearchOverGroupsReqPayload { pub remove_stop_words: Option, /// The user_id is the id of the user who is making the request. This is used to track user interactions with the search results. 
pub user_id: Option, + pub typo_options: Option, } /// Search Over Groups @@ -1665,11 +1682,12 @@ pub struct SearchOverGroupsReqPayload { ("ApiKey" = ["readonly"]), ) )] -#[tracing::instrument(skip(pool, event_queue))] +#[tracing::instrument(skip(pool, event_queue, redis_pool))] pub async fn search_over_groups( data: web::Json, pool: web::Data, event_queue: web::Data, + redis_pool: web::Data, api_version: APIVersion, _required_user: LoggedUser, dataset_org_plan_sub: DatasetAndOrgWithSubAndPlan, @@ -1713,6 +1731,7 @@ pub async fn search_over_groups( data.clone(), parsed_query, pool, + redis_pool, dataset_org_plan_sub.dataset.clone(), &dataset_config, &mut timer, @@ -1724,6 +1743,7 @@ pub async fn search_over_groups( data.clone(), parsed_query.to_parsed_query()?, pool, + redis_pool, dataset_org_plan_sub.dataset.clone(), &dataset_config, &mut timer, @@ -1741,6 +1761,7 @@ pub async fn search_over_groups( data.clone(), parsed_query, pool, + redis_pool, dataset_org_plan_sub.dataset.clone(), &dataset_config, &mut timer, diff --git a/server/src/handlers/message_handler.rs b/server/src/handlers/message_handler.rs index a9b4733bfb..527cb248c5 100644 --- a/server/src/handlers/message_handler.rs +++ b/server/src/handlers/message_handler.rs @@ -5,7 +5,7 @@ use super::{ use crate::{ data::models::{ self, ChunkMetadata, DatasetAndOrgWithSubAndPlan, DatasetConfiguration, HighlightOptions, - LLMOptions, Pool, SearchMethod, SuggestType, + LLMOptions, Pool, RedisPool, SearchMethod, SuggestType, }, errors::ServiceError, get_env, @@ -131,6 +131,7 @@ pub async fn create_message( dataset_org_plan_sub: DatasetAndOrgWithSubAndPlan, event_queue: web::Data, pool: web::Data, + redis_pool: web::Data, ) -> Result { let message_count_pool = pool.clone(); let message_count_org_id = dataset_org_plan_sub.organization.organization.id; @@ -225,6 +226,7 @@ pub async fn create_message( dataset_org_plan_sub.dataset, stream_response_pool, event_queue, + redis_pool, dataset_config, create_message_data, ) @@ -389,13 +391,14 @@ impl From for CreateMessageReqPayload { ("ApiKey" = ["readonly"]), ) )] -#[tracing::instrument(skip(pool, event_queue))] +#[tracing::instrument(skip(pool, event_queue, redis_pool))] pub async fn edit_message( data: web::Json, user: AdminOnly, dataset_org_plan_sub: DatasetAndOrgWithSubAndPlan, pool: web::Data, event_queue: web::Data, + redis_pool: web::Data, ) -> Result { let topic_id: uuid::Uuid = data.topic_id; let message_sort_order = data.message_sort_order; @@ -428,6 +431,7 @@ pub async fn edit_message( dataset_org_plan_sub, event_queue, third_pool, + redis_pool, ) .await } @@ -453,13 +457,14 @@ pub async fn edit_message( ("ApiKey" = ["readonly"]), ) )] -#[tracing::instrument(skip(pool, event_queue))] +#[tracing::instrument(skip(pool, event_queue, redis_pool))] pub async fn regenerate_message_patch( data: web::Json, user: AdminOnly, dataset_org_plan_sub: DatasetAndOrgWithSubAndPlan, pool: web::Data, event_queue: web::Data, + redis_pool: web::Data, ) -> Result { let topic_id = data.topic_id; let dataset_config = @@ -487,6 +492,7 @@ pub async fn regenerate_message_patch( dataset_org_plan_sub.dataset, create_message_pool, event_queue, + redis_pool.clone(), dataset_config, data.into_inner().into(), ) @@ -559,6 +565,7 @@ pub async fn regenerate_message_patch( dataset_org_plan_sub.dataset, create_message_pool, event_queue, + redis_pool.clone(), dataset_config, data.into_inner().into(), ) @@ -587,15 +594,24 @@ pub async fn regenerate_message_patch( ) )] #[deprecated] -#[tracing::instrument(skip(pool, 
event_queue))] +#[tracing::instrument(skip(pool, event_queue, redis_pool))] pub async fn regenerate_message( data: web::Json, user: AdminOnly, dataset_org_plan_sub: DatasetAndOrgWithSubAndPlan, pool: web::Data, event_queue: web::Data, + redis_pool: web::Data, ) -> Result { - regenerate_message_patch(data, user, dataset_org_plan_sub, pool, event_queue).await + regenerate_message_patch( + data, + user, + dataset_org_plan_sub, + pool, + event_queue, + redis_pool, + ) + .await } #[derive(Deserialize, Serialize, Debug, ToSchema)] @@ -643,6 +659,7 @@ pub async fn get_suggested_queries( data: web::Json, dataset_org_plan_sub: DatasetAndOrgWithSubAndPlan, pool: web::Data, + redis_pool: web::Data, _required_user: LoggedUser, ) -> Result { let dataset_id = dataset_org_plan_sub.dataset.id; @@ -669,7 +686,6 @@ pub async fn get_suggested_queries( ) .into() }; - let search_type = data.search_type.clone().unwrap_or(SearchMethod::Hybrid); let filters = data.filters.clone(); @@ -692,6 +708,7 @@ pub async fn get_suggested_queries( search_req_payload, parsed_query, pool, + redis_pool, dataset_org_plan_sub.dataset.clone(), &dataset_config, &mut Timer::new(), @@ -702,6 +719,7 @@ pub async fn get_suggested_queries( search_req_payload, ParsedQueryTypes::Single(parsed_query), pool, + redis_pool, dataset_org_plan_sub.dataset.clone(), &dataset_config, &mut Timer::new(), @@ -865,6 +883,7 @@ pub async fn get_suggested_queries( Some(cleaned_query) } }) + .map(|query| query.to_string().trim().trim_matches('\n').to_string()) .collect(); while queries.len() < 3 { diff --git a/server/src/lib.rs b/server/src/lib.rs index 6015155d90..f264b713fc 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -9,7 +9,7 @@ use crate::{ handlers::{auth_handler::build_oidc_client, metrics_handler::Metrics}, operators::{ clickhouse_operator::EventQueue, qdrant_operator::create_new_qdrant_collection_query, - user_operator::create_default_user, + user_operator::create_default_user, words_operator::BKTreeCache, }, }; use actix_cors::Cors; @@ -411,6 +411,8 @@ impl Modify for SecurityAddon { data::models::SortOptions, data::models::LLMOptions, data::models::HighlightOptions, + data::models::TypoOptions, + data::models::TypoRange, data::models::SortByField, data::models::SortBySearchType, data::models::ReRankOptions, @@ -629,6 +631,8 @@ pub fn main() -> std::io::Result<()> { (clickhouse::Client::default(), EventQueue::default()) }; + BKTreeCache::enforce_cache_ttl(); + let metrics = Metrics::new().map_err(|e| { std::io::Error::new(std::io::ErrorKind::Other, format!("Failed to create metrics {:?}", e)) diff --git a/server/src/operators/chunk_operator.rs b/server/src/operators/chunk_operator.rs index a4baaf78f0..a9fa4e765e 100644 --- a/server/src/operators/chunk_operator.rs +++ b/server/src/operators/chunk_operator.rs @@ -18,6 +18,7 @@ use crate::{ }; use actix_web::web; use chrono::NaiveDateTime; +use clickhouse::Row; use dateparser::DateTimeUtc; use diesel::dsl::{not, sql}; use diesel::prelude::*; @@ -29,6 +30,7 @@ use itertools::Itertools; use serde::{Deserialize, Serialize}; use simsearch::{SearchOptions, SimSearch}; use std::collections::{HashMap, HashSet}; +use time::OffsetDateTime; use utoipa::ToSchema; use super::group_operator::create_groups_query; @@ -2427,3 +2429,103 @@ pub async fn get_pg_point_ids_from_qdrant_point_ids( Ok(chunk_ids) } + +#[tracing::instrument(skip(pool))] +pub async fn get_chunk_html_from_ids_query( + chunk_ids: Vec, + pool: web::Data, +) -> Result>, ServiceError> { + use crate::data::schema::chunk_metadata::dsl 
as chunk_metadata_columns; + let mut conn = pool.get().await.unwrap(); + + let chunk_htmls = chunk_metadata_columns::chunk_metadata + .select(( + chunk_metadata_columns::id, + chunk_metadata_columns::chunk_html.assume_not_null(), + )) + .filter(chunk_metadata_columns::id.eq_any(chunk_ids)) + .load::<(uuid::Uuid, String)>(&mut conn) + .await + .map_err(|_| ServiceError::NotFound("Failed to get chunk_htmls".to_string()))?; + + if chunk_htmls.is_empty() { + return Ok(None); + } + Ok(Some(chunk_htmls)) +} + +pub async fn scroll_chunk_ids_for_dictionary_query( + pool: web::Data, + dataset_id: uuid::Uuid, + last_processed: Option, + limit: i64, + offset: uuid::Uuid, +) -> Result>, ServiceError> { + use crate::data::schema::chunk_metadata::dsl as chunk_metadata_columns; + + let mut conn = pool + .get() + .await + .map_err(|_| ServiceError::BadRequest("Could not get database connection".to_string()))?; + + let mut chunk_ids = chunk_metadata_columns::chunk_metadata + .select(( + chunk_metadata_columns::id, + chunk_metadata_columns::dataset_id, + )) + .filter(chunk_metadata_columns::dataset_id.eq(dataset_id)) + .filter(chunk_metadata_columns::id.gt(offset)) + .into_boxed(); + + if let Some(last_processed) = last_processed { + let last_processed = + NaiveDateTime::from_timestamp(last_processed.last_processed.unix_timestamp(), 0); + + chunk_ids = chunk_ids.filter(chunk_metadata_columns::created_at.gt(last_processed)); + } + + let chunk_ids = chunk_ids + .order_by(chunk_metadata_columns::id) + .limit(limit) + .load::<(uuid::Uuid, uuid::Uuid)>(&mut conn) + .await + .map_err(|_| { + log::error!("Failed to scroll dataset ids for dictionary"); + ServiceError::InternalServerError( + "Failed to scroll dataset ids for dictionary".to_string(), + ) + })?; + + if chunk_ids.is_empty() { + return Ok(None); + } + Ok(Some(chunk_ids)) +} + +#[derive(Debug, Clone, Serialize, Deserialize, Row)] +pub struct DatasetLastProcessed { + #[serde(with = "clickhouse::serde::uuid")] + pub dataset_id: uuid::Uuid, + #[serde(with = "clickhouse::serde::time::datetime")] + pub last_processed: OffsetDateTime, +} + +pub async fn get_last_processed_from_clickhouse( + clickhouse_client: &clickhouse::Client, + dataset_id: uuid::Uuid, +) -> Result, ServiceError> { + let query = format!( + "SELECT dataset_id, min(last_processed) as last_processed FROM dataset_words_last_processed WHERE dataset_id = '{}' GROUP BY dataset_id LIMIT 1", + dataset_id + ); + + let last_processed = clickhouse_client + .query(&query) + .fetch_optional::() + .await + .map_err(|_| { + ServiceError::InternalServerError("Failed to get last processed".to_string()) + })?; + + Ok(last_processed) +} diff --git a/server/src/operators/dataset_operator.rs b/server/src/operators/dataset_operator.rs index 3658833b2b..2f044c65a3 100644 --- a/server/src/operators/dataset_operator.rs +++ b/server/src/operators/dataset_operator.rs @@ -1,6 +1,7 @@ use crate::data::models::{ DatasetAndOrgWithSubAndPlan, DatasetAndUsage, DatasetConfiguration, DatasetUsageCount, Organization, OrganizationWithSubAndPlan, RedisPool, StripePlan, StripeSubscription, UnifiedId, + WordDataset, }; use crate::handlers::dataset_handler::{GetDatasetsPagination, TagsWithCount}; use crate::operators::clickhouse_operator::ClickHouseEvent; @@ -12,11 +13,14 @@ use crate::{ errors::ServiceError, }; use actix_web::web; +use clickhouse::Row; use diesel::dsl::count; use diesel::prelude::*; use diesel::result::{DatabaseErrorKind, Error as DBError}; use diesel_async::RunQueryDsl; +use itertools::Itertools; use 
serde::{Deserialize, Serialize}; +use time::{format_description, OffsetDateTime}; use super::clickhouse_operator::EventQueue; @@ -109,6 +113,24 @@ pub async fn get_deleted_dataset_by_unifiedid_query( Ok(dataset) } +#[tracing::instrument(skip(pool))] +pub async fn get_all_dataset_ids(pool: web::Data) -> Result, ServiceError> { + use crate::data::schema::datasets::dsl as datasets_columns; + let mut conn = pool + .get() + .await + .map_err(|_| ServiceError::BadRequest("Could not get database connection".to_string()))?; + + let datasets = datasets_columns::datasets + .select(datasets_columns::id) + .filter(datasets_columns::deleted.eq(0)) + .load::(&mut conn) + .await + .map_err(|_| ServiceError::NotFound("Could not find dataset".to_string()))?; + + Ok(datasets) +} + #[tracing::instrument(skip(pool))] pub async fn get_dataset_and_organization_from_dataset_id_query( id: UnifiedId, @@ -668,3 +690,166 @@ pub async fn get_tags_in_dataset_query( Ok((items, total_count)) } + +pub async fn scroll_dataset_ids_query( + offset: uuid::Uuid, + limit: i64, + pool: web::Data, +) -> Result>, ServiceError> { + use crate::data::schema::datasets::dsl as datasets_columns; + + let mut conn = pool + .get() + .await + .map_err(|_| ServiceError::BadRequest("Could not get database connection".to_string()))?; + + let datasets = datasets_columns::datasets + .select(datasets_columns::id) + .filter(datasets_columns::id.gt(offset)) + .order_by(datasets_columns::id) + .limit(limit) + .load::(&mut conn) + .await + .map_err(|_| ServiceError::NotFound("Failed to get datasets".to_string()))?; + + if datasets.is_empty() { + return Ok(None); + } + Ok(Some(datasets)) +} + +pub async fn add_words_to_dataset( + words: Vec, + counts: Vec, + dataset_ids: Vec, + clickhouse_client: &clickhouse::Client, +) -> Result<(), ServiceError> { + let rows = words + .into_iter() + .zip(counts) + .zip(dataset_ids) + .map(|((w, count), dataset_id)| WordDataset::from_details(w, dataset_id, count)) + .collect_vec(); + + let mut words_inserter = clickhouse_client + .insert("default.words_datasets") + .map_err(|e| { + log::error!("Error inserting words_datasets: {:?}", e); + sentry::capture_message("Error inserting words_datasets", sentry::Level::Error); + ServiceError::InternalServerError(format!("Error inserting words_datasets: {:?}", e)) + })?; + + for row in rows { + words_inserter.write(&row).await.map_err(|e| { + log::error!("Error inserting words_datasets: {:?}", e); + sentry::capture_message("Error inserting words_datasets", sentry::Level::Error); + ServiceError::InternalServerError(format!("Error inserting words_datasets: {:?}", e)) + })?; + } + + words_inserter.end().await.map_err(|e| { + log::error!("Error inserting words_datasets: {:?}", e); + sentry::capture_message("Error inserting words_datasets", sentry::Level::Error); + ServiceError::InternalServerError(format!("Error inserting words_datasets: {:?}", e)) + })?; + + Ok(()) +} + +#[derive(Serialize, Deserialize, Clone, Debug, Row)] +pub struct WordDatasetCount { + #[serde(with = "clickhouse::serde::uuid")] + pub id: uuid::Uuid, + pub word: String, + pub count: i32, +} + +#[tracing::instrument(skip(clickhouse_client))] +pub async fn scroll_words_from_dataset( + dataset_id: uuid::Uuid, + offset: uuid::Uuid, + last_processed: Option, + limit: i64, + clickhouse_client: &clickhouse::Client, +) -> Result>, ServiceError> { + let mut query = format!( + " + SELECT + id, + word, + count, + FROM words_datasets + WHERE dataset_id = '{}' AND id > '{}' + ", + dataset_id, offset, + ); + + if let 
Some(last_processed) = last_processed { + query = format!( + "{} AND created_at >= '{}'", + query, + last_processed + .format( + &format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]",) + .unwrap() + ) + .map_err(|e| { + log::error!("Error formatting last processed time: {:?}", e); + sentry::capture_message( + "Error formatting last processed time", + sentry::Level::Error, + ); + ServiceError::InternalServerError(format!( + "Error formatting last processed time: {:?}", + e + )) + })? + ); + } + + query = format!("{} ORDER BY id LIMIT {}", query, limit); + + let words = clickhouse_client + .query(&query) + .fetch_all::() + .await + .map_err(|e| { + log::error!("Error fetching words from dataset: {:?}", e); + sentry::capture_message("Error fetching words from dataset", sentry::Level::Error); + ServiceError::InternalServerError(format!("Error fetching words from dataset: {:?}", e)) + })?; + + if words.is_empty() { + Ok(None) + } else { + Ok(Some(words)) + } +} + +pub async fn update_dataset_last_processed_query( + dataset_id: uuid::Uuid, + clickhouse_client: &clickhouse::Client, +) -> Result<(), ServiceError> { + let query = format!( + " + INSERT INTO dataset_words_last_processed (dataset_id, last_processed) + VALUES ('{}', now()) + ", + dataset_id + ); + + clickhouse_client + .query(&query) + .execute() + .await + .map_err(|e| { + log::error!("Error updating last processed time: {:?}", e); + sentry::capture_message("Error updating last processed time", sentry::Level::Error); + ServiceError::InternalServerError(format!( + "Error updating last processed time: {:?}", + e + )) + })?; + + Ok(()) +} diff --git a/server/src/operators/invitation_operator.rs b/server/src/operators/invitation_operator.rs index 34ea272ef0..100ce613e1 100644 --- a/server/src/operators/invitation_operator.rs +++ b/server/src/operators/invitation_operator.rs @@ -60,7 +60,7 @@ pub async fn send_invitation( humans@trieve.ai", org_name, inv_url, - inv_url.split('?').collect::>()[0], + inv_url.split('?').collect::>().get(0).unwrap_or(&""), invitation.email ); diff --git a/server/src/operators/message_operator.rs b/server/src/operators/message_operator.rs index fd93ef733d..cb0ee6e5a2 100644 --- a/server/src/operators/message_operator.rs +++ b/server/src/operators/message_operator.rs @@ -1,6 +1,6 @@ use crate::data::models::{ self, ChunkMetadataStringTagSet, ChunkMetadataTypes, Dataset, DatasetConfiguration, QueryTypes, - RagQueryEventClickhouse, SearchMethod, + RagQueryEventClickhouse, RedisPool, SearchMethod, }; use crate::diesel::prelude::*; use crate::get_env; @@ -222,13 +222,15 @@ pub async fn delete_message_query( Ok(()) } -#[tracing::instrument(skip(pool, event_queue))] +#[allow(clippy::too_many_arguments)] +#[tracing::instrument(skip(pool, redis_pool, event_queue))] pub async fn stream_response( messages: Vec, topic_id: uuid::Uuid, dataset: Dataset, pool: web::Data, event_queue: web::Data, + redis_pool: web::Data, dataset_config: DatasetConfiguration, create_message_req_payload: CreateMessageReqPayload, ) -> Result { @@ -374,6 +376,7 @@ pub async fn stream_response( search_chunk_data.clone(), parsed_query, pool.clone(), + redis_pool, dataset.clone(), &dataset_config, &mut search_timer, @@ -557,8 +560,12 @@ pub async fn stream_response( )) })?; - let completion_content = match &assistant_completion.choices[0].message.content { - ChatMessageContent::Text(text) => text.clone(), + let completion_content = match &assistant_completion + .choices + .get(0) + .map(|chat_completion_choice| 
chat_completion_choice.message.content.clone()) + { + Some(ChatMessageContent::Text(text)) => text.clone(), _ => "".to_string(), }; @@ -670,7 +677,11 @@ pub async fn stream_response( let chunk_stream = stream::iter(vec![Ok(Bytes::from(chunk_metadatas_stringified1))]); let completion_stream = stream.map(move |response| -> Result { if let Ok(response) = response { - let chat_content = response.choices[0].delta.content.clone(); + let chat_content = response + .choices + .get(0) + .map(|chat_completion_content| chat_completion_content.delta.content.clone()) + .unwrap_or(None); if let Some(message) = chat_content.clone() { s.send(message).unwrap(); } diff --git a/server/src/operators/mod.rs b/server/src/operators/mod.rs index 9496604d8e..03876b1ab0 100644 --- a/server/src/operators/mod.rs +++ b/server/src/operators/mod.rs @@ -16,3 +16,4 @@ pub mod search_operator; pub mod stripe_operator; pub mod topic_operator; pub mod user_operator; +pub mod words_operator; diff --git a/server/src/operators/model_operator.rs b/server/src/operators/model_operator.rs index b00648e478..ef9c3faf8d 100644 --- a/server/src/operators/model_operator.rs +++ b/server/src/operators/model_operator.rs @@ -430,7 +430,7 @@ pub async fn get_dense_vectors( "query" => EmbeddingInput::String( format!( "{}{}", - dataset_config.EMBEDDING_QUERY_PREFIX, &clipped_messages[0] + dataset_config.EMBEDDING_QUERY_PREFIX, &clipped_messages.get(0).unwrap_or(&"".to_string()) ) .to_string(), ), diff --git a/server/src/operators/search_operator.rs b/server/src/operators/search_operator.rs index 01a16386a7..dec03ac53c 100644 --- a/server/src/operators/search_operator.rs +++ b/server/src/operators/search_operator.rs @@ -12,10 +12,11 @@ use super::model_operator::{ use super::qdrant_operator::{ count_qdrant_query, search_over_groups_query, GroupSearchResults, QdrantSearchQuery, VectorType, }; +use super::words_operator::correct_query; use crate::data::models::{ convert_to_date_time, ChunkGroup, ChunkGroupAndFileId, ChunkMetadata, ChunkMetadataTypes, ConditionType, ContentChunkMetadata, Dataset, DatasetConfiguration, GeoInfoWithBias, - HasIDCondition, QdrantSortBy, QueryTypes, ReRankOptions, ScoreChunk, ScoreChunkDTO, + HasIDCondition, QdrantSortBy, QueryTypes, ReRankOptions, RedisPool, ScoreChunk, ScoreChunkDTO, SearchMethod, SlimChunkMetadata, SortByField, SortBySearchType, SortOrder, UnifiedId, }; use crate::handlers::chunk_handler::{ @@ -985,6 +986,7 @@ pub async fn get_group_tag_set_filter_condition( #[derive(Serialize, Deserialize, Clone, Debug)] pub struct SearchOverGroupsQueryResult { pub search_results: Vec, + pub corrected_query: Option, pub total_chunk_pages: i64, } @@ -1042,6 +1044,7 @@ pub async fn retrieve_group_qdrant_points_query( Ok(SearchOverGroupsQueryResult { search_results: point_ids, + corrected_query: None, total_chunk_pages: pages, }) } @@ -1104,6 +1107,7 @@ impl From for SearchOverGroupsResults { #[schema(title = "V1")] pub struct DeprecatedSearchOverGroupsResponseBody { pub group_chunks: Vec, + pub corrected_query: Option, pub total_chunk_pages: i64, } @@ -1116,6 +1120,7 @@ impl DeprecatedSearchOverGroupsResponseBody { .into_iter() .map(|chunk| chunk.into()) .collect(), + corrected_query: self.corrected_query, total_pages: self.total_chunk_pages, } } @@ -1126,6 +1131,7 @@ impl DeprecatedSearchOverGroupsResponseBody { pub struct SearchOverGroupsResponseBody { pub id: uuid::Uuid, pub results: Vec, + pub corrected_query: Option, pub total_pages: i64, } @@ -1288,6 +1294,7 @@ pub async fn retrieve_chunks_for_groups( 
Ok(DeprecatedSearchOverGroupsResponseBody { group_chunks, + corrected_query: None, total_chunk_pages: search_over_groups_query_result.total_chunk_pages, }) } @@ -1529,6 +1536,7 @@ pub async fn retrieve_chunks_from_point_ids( Ok(SearchChunkQueryResponseBody { score_chunks, + corrected_query: None, total_chunk_pages: search_chunk_query_results.total_chunk_pages, }) } @@ -1780,11 +1788,12 @@ async fn get_qdrant_vector( } } -#[tracing::instrument(skip(timer, pool))] +#[tracing::instrument(skip(timer, pool, redis_pool))] pub async fn search_chunks_query( - data: SearchChunksReqPayload, + mut data: SearchChunksReqPayload, parsed_query: ParsedQueryTypes, pool: web::Data, + redis_pool: web::Data, dataset: Dataset, config: &DatasetConfiguration, timer: &mut Timer, @@ -1801,12 +1810,36 @@ pub async fn search_chunks_query( }; sentry::configure_scope(|scope| scope.set_span(Some(transaction.clone()))); - timer.add("start to create dense embedding vector"); + let mut parsed_query = parsed_query.clone(); + let mut corrected_query = None; + + if let Some(options) = &data.typo_options { + timer.add("start correcting query"); + match parsed_query { + ParsedQueryTypes::Single(ref mut query) => { + corrected_query = + correct_query(query.query.clone(), dataset.id, redis_pool, options).await?; + query.query = corrected_query.clone().unwrap_or(query.query.clone()); + data.query = QueryTypes::Single(query.query.clone()); + } + ParsedQueryTypes::Multi(ref mut queries) => { + for (query, _) in queries { + corrected_query = + correct_query(query.query.clone(), dataset.id, redis_pool.clone(), options) + .await?; + query.query = corrected_query.clone().unwrap_or(query.query.clone()); + } + } + } + timer.add("corrected query"); + } - timer.add("computed dense embedding"); + timer.add("start to create dense embedding vector"); let vector = get_qdrant_vector(data.clone().search_type, parsed_query.clone(), config).await?; + timer.add("computed dense embedding"); + let (sort_by, rerank_by) = match data.sort_options.as_ref().map(|d| d.sort_by.clone()) { Some(Some(sort_by)) => match sort_by { QdrantSortBy::Field(field) => (Some(field.clone()), None), @@ -1893,15 +1926,18 @@ pub async fn search_chunks_query( timer.add("reranking"); transaction.finish(); + result_chunks.corrected_query = corrected_query; + Ok(result_chunks) } #[allow(clippy::too_many_arguments)] -#[tracing::instrument(skip(timer, pool))] +#[tracing::instrument(skip(timer, pool, redis_pool))] pub async fn search_hybrid_chunks( - data: SearchChunksReqPayload, + mut data: SearchChunksReqPayload, parsed_query: ParsedQuery, pool: web::Data, + redis_pool: web::Data, dataset: Dataset, config: &DatasetConfiguration, timer: &mut Timer, @@ -1918,6 +1954,21 @@ pub async fn search_hybrid_chunks( }; sentry::configure_scope(|scope| scope.set_span(Some(transaction.clone()))); + let mut parsed_query = parsed_query.clone(); + let mut corrected_query = None; + + if let Some(options) = &data.typo_options { + timer.add("start correcting query"); + corrected_query = + correct_query(parsed_query.query.clone(), dataset.id, redis_pool, options).await?; + parsed_query.query = corrected_query + .clone() + .unwrap_or(parsed_query.query.clone()); + data.query = QueryTypes::Single(parsed_query.query.clone()); + + timer.add("corrected query"); + } + let dataset_config = DatasetConfiguration::from_json(dataset.server_configuration.clone()); let dense_vector_future = get_dense_vector( @@ -2033,6 +2084,7 @@ pub async fn search_hybrid_chunks( SearchChunkQueryResponseBody { score_chunks: 
reranked_chunks, + corrected_query, total_chunk_pages: result_chunks.total_chunk_pages, } }; @@ -2066,17 +2118,43 @@ pub async fn search_hybrid_chunks( } #[allow(clippy::too_many_arguments)] -#[tracing::instrument(skip(pool))] +#[tracing::instrument(skip(pool, timer, redis_pool))] pub async fn search_groups_query( - data: SearchWithinGroupReqPayload, + mut data: SearchWithinGroupReqPayload, parsed_query: ParsedQueryTypes, group: ChunkGroupAndFileId, pool: web::Data, + redis_pool: web::Data, dataset: Dataset, config: &DatasetConfiguration, + timer: &mut Timer, ) -> Result { let vector = get_qdrant_vector(data.clone().search_type, parsed_query.clone(), config).await?; + let mut parsed_query = parsed_query.clone(); + let mut corrected_query = None; + + if let Some(options) = &data.typo_options { + timer.add("start correcting query"); + match parsed_query { + ParsedQueryTypes::Single(ref mut query) => { + corrected_query = + correct_query(query.query.clone(), dataset.id, redis_pool, options).await?; + query.query = corrected_query.clone().unwrap_or(query.query.clone()); + data.query = QueryTypes::Single(query.query.clone()); + } + ParsedQueryTypes::Multi(ref mut queries) => { + for (query, _) in queries { + corrected_query = + correct_query(query.query.clone(), dataset.id, redis_pool.clone(), options) + .await?; + query.query = corrected_query.clone().unwrap_or(query.query.clone()); + } + } + } + timer.add("corrected query"); + } + let (sort_by, rerank_by) = match data.sort_options.as_ref().map(|d| d.sort_by.clone()) { Some(Some(sort_by)) => match sort_by { QdrantSortBy::Field(field) => (Some(field.clone()), None), @@ -2161,24 +2239,42 @@ pub async fn search_groups_query( Ok(SearchWithinGroupResults { bookmarks: result_chunks.score_chunks, group, + corrected_query, total_pages: result_chunks.total_chunk_pages, }) } #[allow(clippy::too_many_arguments)] -#[tracing::instrument(skip(pool))] +#[tracing::instrument(skip(pool, timer, redis_pool))] pub async fn search_hybrid_groups( - data: SearchWithinGroupReqPayload, + mut data: SearchWithinGroupReqPayload, parsed_query: ParsedQuery, group: ChunkGroupAndFileId, pool: web::Data, + redis_pool: web::Data, dataset: Dataset, config: &DatasetConfiguration, + timer: &mut Timer, ) -> Result { let dataset_config = DatasetConfiguration::from_json(dataset.server_configuration.clone()); + let mut parsed_query = parsed_query.clone(); + let mut corrected_query = None; + + if let Some(options) = &data.typo_options { + timer.add("start correcting query"); + corrected_query = + correct_query(parsed_query.query.clone(), dataset.id, redis_pool, options).await?; + parsed_query.query = corrected_query + .clone() + .unwrap_or(parsed_query.query.clone()); + data.query = QueryTypes::Single(parsed_query.query.clone()); + + timer.add("corrected query"); + } + let dense_vector_future = get_dense_vector( - data.query.clone().to_single_query()?, + parsed_query.query.clone(), None, "query", dataset_config.clone(), @@ -2328,6 +2424,7 @@ pub async fn search_hybrid_groups( SearchChunkQueryResponseBody { score_chunks: reranked_chunks, + corrected_query: None, total_chunk_pages: result_chunks.total_chunk_pages, } }; @@ -2335,21 +2432,47 @@ pub async fn search_hybrid_groups( Ok(SearchWithinGroupResults { bookmarks: reranked_chunks.score_chunks, group, + corrected_query, total_pages: result_chunks.total_chunk_pages, }) } -#[tracing::instrument(skip(timer, pool))] +#[tracing::instrument(skip(timer, pool, redis_pool))] pub async fn semantic_search_over_groups( - data: 
SearchOverGroupsReqPayload, + mut data: SearchOverGroupsReqPayload, parsed_query: ParsedQueryTypes, pool: web::Data, + redis_pool: web::Data, dataset: Dataset, config: &DatasetConfiguration, timer: &mut Timer, ) -> Result { let dataset_config = DatasetConfiguration::from_json(dataset.server_configuration.clone()); + let mut parsed_query = parsed_query.clone(); + let mut corrected_query = None; + + if let Some(options) = &data.typo_options { + timer.add("start correcting query"); + match parsed_query { + ParsedQueryTypes::Single(ref mut query) => { + corrected_query = + correct_query(query.query.clone(), dataset.id, redis_pool, options).await?; + query.query = corrected_query.clone().unwrap_or(query.query.clone()); + data.query = QueryTypes::Single(query.query.clone()); + } + ParsedQueryTypes::Multi(ref mut queries) => { + for (query, _) in queries { + corrected_query = + correct_query(query.query.clone(), dataset.id, redis_pool.clone(), options) + .await?; + query.query = corrected_query.clone().unwrap_or(query.query.clone()); + } + } + } + timer.add("corrected query"); + } + timer.add("start to create dense embedding vector"); let embedding_vector = get_qdrant_vector( @@ -2400,15 +2523,17 @@ pub async fn semantic_search_over_groups( timer.add("fetched from postgres"); //TODO: rerank for groups + result_chunks.corrected_query = corrected_query; Ok(result_chunks) } -#[tracing::instrument(skip(timer, pool))] +#[tracing::instrument(skip(timer, pool, redis_pool))] pub async fn full_text_search_over_groups( - data: SearchOverGroupsReqPayload, + mut data: SearchOverGroupsReqPayload, parsed_query: ParsedQueryTypes, pool: web::Data, + redis_pool: web::Data, dataset: Dataset, config: &DatasetConfiguration, timer: &mut Timer, @@ -2424,6 +2549,30 @@ pub async fn full_text_search_over_groups( timer.add("computed sparse vector"); + let mut parsed_query = parsed_query.clone(); + let mut corrected_query = None; + + if let Some(options) = &data.typo_options { + timer.add("start correcting query"); + match parsed_query { + ParsedQueryTypes::Single(ref mut query) => { + corrected_query = + correct_query(query.query.clone(), dataset.id, redis_pool, options).await?; + query.query = corrected_query.clone().unwrap_or(query.query.clone()); + data.query = QueryTypes::Single(query.query.clone()); + } + ParsedQueryTypes::Multi(ref mut queries) => { + for (query, _) in queries { + corrected_query = + correct_query(query.query.clone(), dataset.id, redis_pool.clone(), options) + .await?; + query.query = corrected_query.clone().unwrap_or(query.query.clone()); + } + } + } + timer.add("corrected query"); + } + let search_over_groups_qdrant_result = retrieve_group_qdrant_points_query( embedding_vector, data.page.unwrap_or(1), @@ -2463,6 +2612,7 @@ pub async fn full_text_search_over_groups( timer.add("fetched from postgres"); //TODO: rerank for groups + result_groups_with_chunk_hits.corrected_query = corrected_query; Ok(result_groups_with_chunk_hits) } @@ -2527,17 +2677,33 @@ async fn cross_encoder_for_groups( Ok(group_results) } -#[tracing::instrument(skip(timer, pool))] +#[tracing::instrument(skip(timer, pool, redis_pool))] pub async fn hybrid_search_over_groups( - data: SearchOverGroupsReqPayload, + mut data: SearchOverGroupsReqPayload, parsed_query: ParsedQuery, pool: web::Data, + redis_pool: web::Data, dataset: Dataset, config: &DatasetConfiguration, timer: &mut Timer, ) -> Result { let dataset_config = DatasetConfiguration::from_json(dataset.server_configuration.clone()); + let mut parsed_query = 
parsed_query.clone(); + let mut corrected_query = None; + + if let Some(options) = &data.typo_options { + timer.add("start correcting query"); + corrected_query = + correct_query(parsed_query.query.clone(), dataset.id, redis_pool, options).await?; + parsed_query.query = corrected_query + .clone() + .unwrap_or(parsed_query.query.clone()); + data.query = QueryTypes::Single(parsed_query.query.clone()); + + timer.add("corrected query"); + } + timer.add("start to create dense embedding vector and sparse vector"); let dense_embedding_vectors_future = get_dense_vector( @@ -2602,6 +2768,7 @@ pub async fn hybrid_search_over_groups( let combined_search_chunk_query_results = SearchOverGroupsQueryResult { search_results: combined_results, + corrected_query: None, total_chunk_pages: semantic_results.total_chunk_pages, }; @@ -2652,7 +2819,9 @@ pub async fn hybrid_search_over_groups( timer.add("reranking"); if let Some(score_threshold) = data.score_threshold { - reranked_chunks.retain(|chunk| chunk.metadata[0].score >= score_threshold.into()); + reranked_chunks.retain(|chunk| { + chunk.metadata.get(0).map(|m| m.score).unwrap_or(0.0) >= score_threshold.into() + }); reranked_chunks.iter_mut().for_each(|chunk| { chunk .metadata @@ -2662,6 +2831,7 @@ pub async fn hybrid_search_over_groups( let result_chunks = DeprecatedSearchOverGroupsResponseBody { group_chunks: reranked_chunks, + corrected_query, total_chunk_pages: combined_search_chunk_query_results.total_chunk_pages, }; @@ -2670,16 +2840,33 @@ pub async fn hybrid_search_over_groups( Ok(result_chunks) } -#[tracing::instrument(skip(timer, pool))] +#[tracing::instrument(skip(timer, pool, redis_pool))] pub async fn autocomplete_chunks_query( - data: AutocompleteReqPayload, + mut data: AutocompleteReqPayload, parsed_query: ParsedQuery, pool: web::Data, + redis_pool: web::Data, dataset: Dataset, config: &DatasetConfiguration, timer: &mut Timer, ) -> Result { let parent_span = sentry::configure_scope(|scope| scope.get_span()); + + let mut parsed_query = parsed_query.clone(); + let mut corrected_query = None; + + if let Some(options) = &data.typo_options { + timer.add("start correcting query"); + corrected_query = + correct_query(parsed_query.query.clone(), dataset.id, redis_pool, options).await?; + parsed_query.query = corrected_query + .clone() + .unwrap_or(parsed_query.query.clone()); + data.query.clone_from(&parsed_query.query); + + timer.add("corrected query"); + } + let transaction: sentry::TransactionOrSpan = match &parent_span { Some(parent) => parent .start_child("semantic search", "Search Semantic Chunks") @@ -2729,10 +2916,11 @@ pub async fn autocomplete_chunks_query( .await?, ]; - qdrant_query[0] - .filter - .must - .push(Condition::matches_text("content", data.query.clone())); + if let Some(q) = qdrant_query.get_mut(0) { + q.filter + .must + .push(Condition::matches_text("content", data.query.clone())); + } if data.extend_results.unwrap_or(false) { qdrant_query.push( @@ -2814,6 +3002,8 @@ pub async fn autocomplete_chunks_query( timer.add("reranking"); transaction.finish(); + result_chunks.corrected_query = corrected_query; + Ok(result_chunks) } diff --git a/server/src/operators/words_operator.rs b/server/src/operators/words_operator.rs new file mode 100644 index 0000000000..4c424deeab --- /dev/null +++ b/server/src/operators/words_operator.rs @@ -0,0 +1,588 @@ +use std::{ + collections::{HashMap, HashSet}, + io::Write, + sync::{Arc, Mutex}, + time::{Duration, Instant}, +}; + +use crate::{ + data::models::{RedisPool, TypoOptions, TypoRange}, + 
diff --git a/server/src/operators/words_operator.rs b/server/src/operators/words_operator.rs
new file mode 100644
index 0000000000..4c424deeab
--- /dev/null
+++ b/server/src/operators/words_operator.rs
@@ -0,0 +1,588 @@
+use std::{
+    collections::{HashMap, HashSet},
+    io::Write,
+    sync::{Arc, Mutex},
+    time::{Duration, Instant},
+};
+
+use crate::{
+    data::models::{RedisPool, TypoOptions, TypoRange},
+    errors::ServiceError,
+};
+use actix_web::web;
+use flate2::{
+    write::{GzDecoder, GzEncoder},
+    Compression,
+};
+use itertools::Itertools;
+use lazy_static::lazy_static;
+use rayon::prelude::*;
+use serde::{Deserialize, Deserializer, Serialize, Serializer};
+use std::collections::VecDeque;
+use tokio::sync::RwLock;
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+struct Node {
+    word: String,
+    count: i32,
+    children: Vec<(isize, Node)>,
+}
+
+/// A BK-tree data structure
+///
+#[derive(Clone, Debug)]
+pub struct BkTree {
+    root: Option<Box<Node>>,
+}
+
+#[derive(Serialize, Deserialize)]
+struct FlatNode {
+    parent_index: Option<usize>,
+    distance: Option<isize>,
+    word: String,
+    count: i32,
+}
+
+impl Serialize for BkTree {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        let mut queue = VecDeque::new();
+        let mut flat_tree = Vec::new();
+
+        if let Some(root) = &self.root {
+            queue.push_back((None, None, root.as_ref()));
+        }
+
+        while let Some((parent_index, distance, node)) = queue.pop_front() {
+            let current_index = flat_tree.len();
+            flat_tree.push(FlatNode {
+                parent_index,
+                distance,
+                word: node.word.clone(),
+                count: node.count,
+            });
+
+            for (child_distance, child) in &node.children {
+                queue.push_back((Some(current_index), Some(*child_distance), child));
+            }
+        }
+
+        let binary_data = bincode::serialize(&flat_tree).map_err(serde::ser::Error::custom)?;
+        serializer.serialize_bytes(&binary_data)
+    }
+}
+
+impl<'de> Deserialize<'de> for BkTree {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let binary_data: Vec<u8> = Vec::deserialize(deserializer)?;
+        let flat_tree: Vec<FlatNode> =
+            bincode::deserialize(&binary_data).map_err(serde::de::Error::custom)?;
+
+        if flat_tree.is_empty() {
+            return Ok(BkTree { root: None });
+        }
+
+        let mut nodes: Vec<Node> = flat_tree
+            .iter()
+            .map(|flat_node| Node {
+                word: flat_node.word.clone(),
+                count: flat_node.count,
+                children: Vec::new(),
+            })
+            .collect();
+
+        // Reconstruct the tree structure
+        for i in (1..nodes.len()).rev() {
+            let parent_index = flat_tree[i].parent_index.unwrap();
+            let distance = flat_tree[i].distance.unwrap();
+            let child = nodes.remove(i);
+            nodes[parent_index].children.push((distance, child));
+        }
+
+        Ok(BkTree {
+            root: Some(Box::new(nodes.remove(0))),
+        })
+    }
+}
+
+impl Default for BkTree {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+pub fn levenshtein_distance<S: AsRef<str>>(a: &S, b: &S) -> isize {
+    let a = a.as_ref().to_lowercase();
+    let b = b.as_ref().to_lowercase();
+
+    if a == b {
+        return 0;
+    }
+
+    let a_len = a.chars().count();
+    let b_len = b.chars().count();
+
+    if a_len == 0 {
+        return b_len as isize;
+    }
+
+    if b_len == 0 {
+        return a_len as isize;
+    }
+
+    let mut res = 0;
+    let mut cache: Vec<usize> = (1..).take(a_len).collect();
+    let mut a_dist;
+    let mut b_dist;
+
+    for (ib, cb) in b.chars().enumerate() {
+        res = ib;
+        a_dist = ib;
+        for (ia, ca) in a.chars().enumerate() {
+            b_dist = if ca == cb { a_dist } else { a_dist + 1 };
+            a_dist = cache[ia];
+
+            res = if a_dist > res {
+                if b_dist > res {
+                    res + 1
+                } else {
+                    b_dist
+                }
+            } else if b_dist > a_dist {
+                a_dist + 1
+            } else {
+                b_dist
+            };
+
+            cache[ia] = res;
+        }
+    }
+
+    res as isize
+}
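
The custom `Serialize`/`Deserialize` impls above avoid recursing over the tree by flattening it breadth-first into `FlatNode` records that carry a parent index and edge distance. A self-contained sketch of the same round trip, with serde/bincode and the word counts left out; `Node`, `Flat`, `flatten`, and `rebuild` here are illustrative stand-ins:

```rust
use std::collections::VecDeque;

#[derive(Debug, PartialEq)]
struct Node {
    word: String,
    children: Vec<(isize, Node)>,
}

struct Flat {
    parent: Option<usize>,
    distance: Option<isize>,
    word: String,
}

// Breadth-first pass: one record per node, pointing back at its parent's index.
fn flatten(root: &Node) -> Vec<Flat> {
    let mut out = Vec::new();
    let mut queue: VecDeque<(Option<usize>, Option<isize>, &Node)> = VecDeque::new();
    queue.push_back((None, None, root));
    while let Some((parent, distance, node)) = queue.pop_front() {
        let index = out.len();
        out.push(Flat { parent, distance, word: node.word.clone() });
        for (d, child) in &node.children {
            queue.push_back((Some(index), Some(*d), child));
        }
    }
    out
}

fn rebuild(flat: Vec<Flat>) -> Node {
    let mut nodes: Vec<Node> = flat
        .iter()
        .map(|f| Node { word: f.word.clone(), children: Vec::new() })
        .collect();
    // Children always appear after their parents in BFS order, so walking backwards
    // lets each node be moved into its parent exactly once.
    for i in (1..nodes.len()).rev() {
        let child = nodes.remove(i);
        nodes[flat[i].parent.unwrap()].children.push((flat[i].distance.unwrap(), child));
    }
    nodes.remove(0)
}

fn main() {
    let tree = Node {
        word: "shoes".into(),
        children: vec![(3, Node { word: "shirt".into(), children: vec![] })],
    };
    let rebuilt = rebuild(flatten(&tree));
    assert_eq!(rebuilt, tree);
    println!("round trip ok");
}
```

One side effect of the reverse-order rebuild (in the sketch and, as far as I can tell, in the impl above) is that siblings come back in reverse order; that is harmless for BK-tree lookups, which never depend on child order.
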
+
+impl BkTree {
+    /// Create a new BK-tree
+    pub fn new() -> Self {
+        Self { root: None }
+    }
+
+    /// Insert every element from a given iterator in the BK-tree
+    pub fn insert_all<I: IntoIterator<Item = (String, i32)>>(&mut self, iter: I) {
+        for i in iter {
+            self.insert(i);
+        }
+    }
+
+    /// Insert a new element in the BK-tree
+    pub fn insert(&mut self, val: (String, i32)) {
+        match self.root {
+            None => {
+                self.root = Some(Box::new(Node {
+                    word: val.0,
+                    count: val.1,
+                    children: Vec::new(),
+                }))
+            }
+            Some(ref mut root_node) => {
+                let mut u = &mut **root_node;
+                loop {
+                    let k = levenshtein_distance(&u.word, &val.0);
+                    if k == 0 {
+                        u.count = val.1;
+                        return;
+                    }
+
+                    if val.1 == 1 {
+                        return;
+                    }
+
+                    let v = u.children.iter().position(|(dist, _)| *dist == k);
+                    match v {
+                        None => {
+                            u.children.push((
+                                k,
+                                Node {
+                                    word: val.0,
+                                    count: val.1,
+                                    children: Vec::new(),
+                                },
+                            ));
+                            return;
+                        }
+                        Some(pos) => {
+                            let (_, ref mut vnode) = u.children[pos];
+                            u = vnode;
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    /// Find the closest elements to a given value present in the BK-tree
+    ///
+    /// Returns pairs of element references and distances
+    pub fn find(&self, val: String, max_dist: isize) -> Vec<((&String, &i32), isize)> {
+        match self.root {
+            None => Vec::new(),
+            Some(ref root) => {
+                let found = Arc::new(Mutex::new(Vec::new()));
+                let mut candidates: Vec<&Node> = vec![root];
+
+                while !candidates.is_empty() {
+                    let next_candidates: Vec<&Node> = if candidates.len() > 1000 {
+                        candidates
+                            .par_iter()
+                            .flat_map(|&n| {
+                                let distance = levenshtein_distance(&n.word, &val);
+                                let mut local_candidates = Vec::new();
+
+                                if distance <= max_dist {
+                                    found.lock().unwrap().push(((&n.word, &n.count), distance));
+                                }
+
+                                for (arc, node) in &n.children {
+                                    if (*arc - distance).abs() <= max_dist {
+                                        local_candidates.push(node);
+                                    }
+                                }
+
+                                local_candidates
+                            })
+                            .collect()
+                    } else {
+                        candidates
+                            .iter()
+                            .flat_map(|&n| {
+                                let distance = levenshtein_distance(&n.word, &val);
+                                if distance <= max_dist {
+                                    found.lock().unwrap().push(((&n.word, &n.count), distance));
+                                }
+                                n.children
+                                    .iter()
+                                    .filter(|(arc, _)| (*arc - distance).abs() <= max_dist)
+                                    .map(|(_, node)| node)
+                                    .collect::<Vec<&Node>>()
+                            })
+                            .collect()
+                    };
+
+                    candidates = next_candidates;
+                }
+
+                let mut result = Arc::try_unwrap(found).unwrap().into_inner().unwrap();
+                result.sort_by_key(|&(_, dist)| dist);
+                result
+            }
+        }
+    }
+
+    /// Create an iterator over references of BK-tree elements, in no particular order
+    pub fn iter(&self) -> Iter {
+        let mut queue = Vec::new();
+        if let Some(ref root) = self.root {
+            queue.push(&**root);
+        }
+        Iter { queue }
+    }
+
+    pub async fn from_redis(
+        dataset_id: uuid::Uuid,
+        redis_pool: web::Data<RedisPool>,
+    ) -> Result<Option<Self>, ServiceError> {
+        let mut redis_conn = redis_pool.get().await.map_err(|_| {
+            ServiceError::InternalServerError("Failed to get redis connection".to_string())
+        })?;
+
+        let compressed_bk_tree: Option<Vec<u8>> = redis::cmd("GET")
+            .arg(format!("bk_tree_{}", dataset_id))
+            .query_async(&mut *redis_conn)
+            .await
+            .map_err(|err| ServiceError::BadRequest(err.to_string()))?;
+
+        if let Some(compressed_bk_tree) = compressed_bk_tree {
+            let buf = Vec::new();
+            let mut decoder = GzDecoder::new(buf);
+            decoder.write_all(&compressed_bk_tree).map_err(|err| {
+                ServiceError::InternalServerError(format!("Failed to decompress bk tree {}", err))
+            })?;
+
+            let serialized_bk_tree = decoder.finish().map_err(|err| {
+                ServiceError::InternalServerError(format!(
+                    "Failed to finish decompressing bk tree {}",
+                    err
+                ))
+            })?;
+
+            let tree = bincode::deserialize(&serialized_bk_tree).map_err(|err| {
+                ServiceError::InternalServerError(format!("Failed to deserialize bk tree {}", err))
+            })?;
+
+            Ok(Some(tree))
+        } else {
+            Ok(None)
+        }
+    }
+
+    pub async fn save(
+        &self,
+        dataset_id: uuid::Uuid,
+        redis_pool: web::Data<RedisPool>,
+    ) -> Result<(), ServiceError> {
+        if self.root.is_none() {
+            return Ok(());
+        }
+        let mut redis_conn = redis_pool.get().await.map_err(|_| {
+            ServiceError::InternalServerError("Failed to get redis connection".to_string())
+        })?;
+
+        let uncompressed_bk_tree = bincode::serialize(self).map_err(|_| {
+            ServiceError::InternalServerError("Failed to serialize bk tree".to_string())
+        })?;
+
+        let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
+        encoder.write_all(&uncompressed_bk_tree).map_err(|_| {
+            ServiceError::InternalServerError("Failed to compress bk tree".to_string())
+        })?;
+
+        let serialized_bk_tree = encoder.finish().map_err(|_| {
+            ServiceError::InternalServerError("Failed to finish compressing bk tree".to_string())
+        })?;
+
+        redis::cmd("SET")
+            .arg(format!("bk_tree_{}", dataset_id))
+            .arg(serialized_bk_tree)
+            .query_async(&mut *redis_conn)
+            .await
+            .map_err(|err| ServiceError::BadRequest(err.to_string()))?;
+
+        Ok(())
+    }
+}
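
`find` above prunes whole subtrees with the usual BK-tree rule: if the current node is at distance `d` from the query, a child stored under edge weight `w` can only contain matches when `|w - d| <= max_dist`, which follows from the triangle inequality. A toy, single-threaded version over integers (absolute difference as the metric); `NumNode`, `insert`, and `find` here are illustrative, not the types above:

```rust
#[derive(Debug)]
struct NumNode {
    value: i64,
    children: Vec<(i64, NumNode)>, // (distance to parent, child)
}

fn insert(root: &mut NumNode, value: i64) {
    let d = (root.value - value).abs();
    if d == 0 {
        return;
    }
    match root.children.iter_mut().find(|(w, _)| *w == d) {
        Some((_, child)) => insert(child, value),
        None => root.children.push((d, NumNode { value, children: Vec::new() })),
    }
}

fn find(root: &NumNode, query: i64, max_dist: i64, out: &mut Vec<i64>) {
    let d = (root.value - query).abs();
    if d <= max_dist {
        out.push(root.value);
    }
    for (w, child) in &root.children {
        // Triangle-inequality pruning: skip subtrees that cannot contain a match.
        if (w - d).abs() <= max_dist {
            find(child, query, max_dist, out);
        }
    }
}

fn main() {
    let mut root = NumNode { value: 50, children: Vec::new() };
    for v in [40, 47, 53, 90] {
        insert(&mut root, v);
    }
    let mut hits = Vec::new();
    find(&root, 51, 3, &mut hits);
    hits.sort();
    assert_eq!(hits, vec![50, 53]);
    println!("{hits:?}");
}
```

The production `find` applies the same rule iteratively, switching to a rayon parallel pass once the candidate frontier grows past 1000 nodes.
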
+
+/// Iterator over BK-tree elements
+pub struct IntoIter {
+    queue: Vec<Node>,
+}
+
+impl Iterator for IntoIter {
+    type Item = (String, i32);
+    fn next(&mut self) -> Option<Self::Item> {
+        self.queue.pop().map(|node| {
+            self.queue.extend(node.children.into_iter().map(|(_, n)| n));
+            (node.word, node.count)
+        })
+    }
+}
+
+/// Iterator over BK-tree elements, by reference
+pub struct Iter<'a> {
+    queue: Vec<&'a Node>,
+}
+
+impl<'a> Iterator for Iter<'a> {
+    type Item = (&'a String, &'a i32);
+    fn next(&mut self) -> Option<Self::Item> {
+        self.queue.pop().map(|node| {
+            self.queue.extend(node.children.iter().map(|(_, n)| n));
+            (&node.word, &node.count)
+        })
+    }
+}
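
For reference, the storage format used by `save`/`from_redis` earlier in this file is plain bincode wrapped in gzip via the `flate2` write adapters. A round-trip sketch of that format outside of Redis; `WordCount` is a made-up payload, and the assumed Cargo deps are `bincode` 1.x, `flate2`, and `serde` with the derive feature:

```rust
use flate2::write::{GzDecoder, GzEncoder};
use flate2::Compression;
use serde::{Deserialize, Serialize};
use std::io::Write;

#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct WordCount {
    word: String,
    count: i32,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let words = vec![
        WordCount { word: "shoes".into(), count: 42 },
        WordCount { word: "shirt".into(), count: 7 },
    ];

    // Serialize, then compress by writing the bytes through a gzip encoder.
    let raw = bincode::serialize(&words)?;
    let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
    encoder.write_all(&raw)?;
    let compressed = encoder.finish()?; // this is the blob that would be SET in Redis

    // Decompress by writing the compressed bytes through a gzip decoder, then decode.
    let mut decoder = GzDecoder::new(Vec::new());
    decoder.write_all(&compressed)?;
    let decompressed = decoder.finish()?;
    let roundtrip: Vec<WordCount> = bincode::deserialize(&decompressed)?;

    assert_eq!(roundtrip, words);
    println!("{} compressed bytes", compressed.len());
    Ok(())
}
```
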
+
+#[derive(Serialize, Deserialize, Clone, Debug)]
+pub struct ProcessWordsFromDatasetMessage {
+    pub chunks_to_process: Vec<(uuid::Uuid, uuid::Uuid)>, // chunk_id, dataset_id
+    pub attempt_number: usize,
+}
+
+#[derive(Serialize, Deserialize, Clone)]
+pub struct CreateBkTreeMessage {
+    pub dataset_id: uuid::Uuid,
+    pub attempt_number: usize,
+}
+
+struct BKTreeCacheEntry {
+    bktree: BkTree,
+    expiration: Instant,
+}
+
+pub struct BKTreeCache {
+    cache: RwLock<HashMap<uuid::Uuid, BKTreeCacheEntry>>,
+}
+
+lazy_static! {
+    static ref BKTREE_CACHE: BKTreeCache = BKTreeCache::new();
+}
+
+impl BKTreeCache {
+    fn new() -> Self {
+        Self {
+            cache: RwLock::new(HashMap::new()),
+        }
+    }
+
+    fn insert_with_ttl(&self, id: uuid::Uuid, bktree: BkTree, ttl: Duration) {
+        if let Ok(mut cache) = self.cache.try_write() {
+            let entry = BKTreeCacheEntry {
+                bktree,
+                expiration: Instant::now() + ttl,
+            };
+            cache.insert(id, entry);
+        };
+    }
+
+    fn get_if_valid(&self, id: &uuid::Uuid) -> Option<BkTree> {
+        match self.cache.try_read() {
+            Ok(cache) => cache.get(id).and_then(|entry| {
+                if Instant::now() < entry.expiration {
+                    Some(entry.bktree.clone())
+                } else {
+                    None
+                }
+            }),
+            _ => None,
+        }
+    }
+
+    fn remove_expired(&self) {
+        if let Ok(mut cache) = self.cache.try_write() {
+            cache.retain(|_, entry| Instant::now() < entry.expiration);
+        }
+    }
+
+    pub fn enforce_cache_ttl() {
+        tokio::spawn(async move {
+            let mut interval = tokio::time::interval(Duration::from_secs(60)); // Run every 60 seconds
+
+            loop {
+                interval.tick().await;
+                BKTREE_CACHE.remove_expired();
+            }
+        });
+    }
+}
+
+fn correct_query_helper(tree: &BkTree, query: String, options: &TypoOptions) -> Option<String> {
+    let query_split_by_whitespace = query
+        .split_whitespace()
+        .map(|s| s.to_string())
+        .collect_vec();
+    let mut query_split_to_correction: HashMap<String, String> = HashMap::new();
+    let excluded_words = options
+        .clone()
+        .disable_on_word
+        .unwrap_or_default()
+        .into_iter()
+        .map(|s| s.to_lowercase())
+        .collect::<HashSet<String>>();
+
+    for split in &query_split_by_whitespace {
+        if excluded_words.contains(&split.to_lowercase()) {
+            continue;
+        }
+
+        let exact_match = tree.find(split.to_string(), 0);
+
+        if !exact_match.is_empty() {
+            continue;
+        }
+
+        let mut corrections = vec![];
+
+        let num_chars = split.chars().collect_vec().len();
+
+        let single_typo_range = options.clone().one_typo_word_range.unwrap_or(TypoRange {
+            min: 5,
+            max: Some(8),
+        });
+
+        if num_chars >= (single_typo_range.min as usize)
+            && num_chars <= (single_typo_range.max.unwrap_or(u32::MAX) as usize)
+        {
+            corrections.extend_from_slice(&tree.find(split.to_string(), 1));
+        }
+
+        let two_typo_range = options
+            .clone()
+            .two_typo_word_range
+            .unwrap_or(TypoRange { min: 8, max: None });
+
+        if num_chars >= (two_typo_range.min as usize)
+            && num_chars <= (two_typo_range.max.unwrap_or(u32::MAX) as usize)
+        {
+            corrections.extend_from_slice(&tree.find(split.to_string(), 2));
+        }
+
+        corrections.sort_by(|((_, freq_a), _), ((_, freq_b), _)| (**freq_b).cmp(*freq_a));
+
+        if let Some(((correction, _), _)) = corrections.get(0) {
+            query_split_to_correction.insert(split.to_string(), correction.to_string());
+        }
+    }
+
+    let mut corrected_query = query.clone();
+
+    if !query_split_to_correction.is_empty() {
+        for (og_string, correction) in query_split_to_correction {
+            corrected_query = corrected_query.replacen(&og_string, &correction, 1);
+        }
+        Some(corrected_query)
+    } else {
+        None
+    }
+}
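
`correct_query_helper` above only spends edit distance on words long enough to deserve it, then ranks the candidates from `find` by word frequency. A compressed sketch of that length budget using the defaults (`one_typo_word_range` 5..=8, `two_typo_word_range` 8 and up); `max_typos_for_len` is a hypothetical helper, not part of the diff:

```rust
// With the default ranges, short words are never corrected, mid-length words get
// one edit of slack, and long words get up to two. The ranges overlap at length 8,
// which is why the original collects candidates from both `find` calls.
fn max_typos_for_len(len: usize) -> isize {
    let one_typo = 5..=8; // default one_typo_word_range
    let two_typo_min = 8; // default two_typo_word_range has no upper bound
    if len >= two_typo_min {
        2
    } else if one_typo.contains(&len) {
        1
    } else {
        0
    }
}

fn main() {
    assert_eq!(max_typos_for_len(3), 0); // "cat": left alone
    assert_eq!(max_typos_for_len(5), 1); // "shoez": one edit away is enough
    assert_eq!(max_typos_for_len(11), 2); // "cmofrotable": two edits allowed
}
```
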
+
+#[tracing::instrument(skip(redis_pool))]
+pub async fn correct_query(
+    query: String,
+    dataset_id: uuid::Uuid,
+    redis_pool: web::Data<RedisPool>,
+    options: &TypoOptions,
+) -> Result<Option<String>, ServiceError> {
+    if matches!(options.correct_typos, None | Some(false)) {
+        return Ok(None);
+    }
+
+    match BKTREE_CACHE.get_if_valid(&dataset_id) {
+        Some(tree) => Ok(correct_query_helper(&tree, query, options)),
+        None => {
+            let dataset_id = dataset_id;
+            let redis_pool = redis_pool.clone();
+            log::info!("Pulling new BK tree from Redis");
+            tokio::spawn(async move {
+                match BkTree::from_redis(dataset_id, redis_pool).await {
+                    // TTL of 1 day
+                    Ok(Some(bktree)) => {
+                        BKTREE_CACHE.insert_with_ttl(
+                            dataset_id,
+                            bktree,
+                            Duration::from_secs(60 * 60 * 24),
+                        );
+                        log::info!(
+                            "Inserted new BK tree into cache for dataset_id: {:?}",
+                            dataset_id
+                        );
+                    }
+                    Ok(None) => {
+                        log::info!("No BK tree found in Redis for dataset_id: {:?}", dataset_id);
+                    }
+                    Err(e) => {
+                        log::info!(
+                            "Failed to insert new BK tree into cache {:?} for dataset_id: {:?}",
+                            e,
+                            dataset_id
+                        );
+                    }
+                };
+            });
+            Ok(None)
+        }
+    }
+}
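
A test module one might append to the bottom of `words_operator.rs` to pin down the tree's behaviour (an illustrative sketch, not part of this diff). Note that `insert` skips new words whose count is 1 unless they become the root, so the fixtures use larger counts:

```rust
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn finds_close_words_within_distance() {
        let mut tree = BkTree::new();
        tree.insert_all(vec![
            ("shoes".to_string(), 40),
            ("shirt".to_string(), 25),
            ("shorts".to_string(), 10),
        ]);

        // "shoez" is one substitution away from "shoes".
        let hits = tree.find("shoez".to_string(), 1);
        assert_eq!(hits.len(), 1);
        let ((word, count), dist) = hits[0];
        assert_eq!(word, "shoes");
        assert_eq!(*count, 40);
        assert_eq!(dist, 1);

        // Exact matches come back at distance 0.
        let exact = tree.find("shirt".to_string(), 0);
        let ((word, _), dist) = exact[0];
        assert_eq!(word, "shirt");
        assert_eq!(dist, 0);
    }

    #[test]
    fn round_trips_through_bincode() {
        let mut tree = BkTree::new();
        tree.insert_all(vec![
            ("comfortable".to_string(), 12),
            ("portable".to_string(), 3),
        ]);

        let bytes = bincode::serialize(&tree).expect("serialize");
        let restored: BkTree = bincode::deserialize(&bytes).expect("deserialize");

        let mut words: Vec<String> = restored.iter().map(|(w, _)| w.clone()).collect();
        words.sort();
        assert_eq!(words, vec!["comfortable".to_string(), "portable".to_string()]);
    }
}
```

Also worth noting from `correct_query` above: on a cache miss it returns `Ok(None)` immediately and warms the per-dataset cache in a background task, so the first search after a cold start goes out uncorrected and later searches get the corrected query.
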