Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added visualization support for Collection queries. #3

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,29 +1,50 @@
# ChromaViz

A package for visualising vector embedding collections as part of the [Chroma](https://trychroma.com) vector database.
A package for visualising vector embedding collections as part of the [Chroma](https://trychroma.com) vector database.

Uses [Flask](https://flask.palletsprojects.com/en/2.3.x/), [Vite](https://vitejs.dev), and [react-three-fiber](https://github.com/pmndrs/react-three-fiber) to host a live 3D view of the data in a web browser, should perform well up to 10k+ documents. Dimensional reduction is performed using PCA for colors down to 50 dimensions, followed by tSNE down to 3.
Uses [Flask](https://flask.palletsprojects.com/en/2.3.x/), [Vite](https://vitejs.dev),
and [react-three-fiber](https://github.com/pmndrs/react-three-fiber) to host a live 3D view of the data in a web
browser; it should perform well up to 10k+ documents. Dimensionality reduction is performed using PCA for colors down
to 50 dimensions, followed by t-SNE down to 3.

## How to Use

`pip install chromaviz` or `pip install git+https://github.com/mtybadger/chromaviz/`.
After installing from pip, simply call `visualize_collection` with a valid ChromaDB collection, and chromaviz will do the rest.
```
After installing from pip, simply call `visualize_collection` with a valid ChromaDB collection, and chromaviz will do
the rest.

```python
from chromaviz import visualize_collection
visualize_collection(chromadb.Collection)
```
It also works with Langchain+Chroma, as in:

Visualization of query results:

```python
from chromaviz import visualize_collection
import chromadb
client = chromadb.HttpClient()
collection = client.get_collection("my_collection")
visualize_collection(collection, query="My question goes here", n_results=100)
```

It also works with Langchain+Chroma, as in:

```python
from langchain.vectorstores import Chroma
vectordb = Chroma.from_documents(data, embeddings, ids)

from chromaviz import visualize_collection
visualize_collection(vectordb._collection)
```

## Screenshots

![Screenshot of ChromaViz on a biological dataset](/images/1.png)
![Screenshot of ChromaViz close up](/images/2.png)

## To-Do

- [ ] More dimensional reduction options and flexibility
- [ ] Refactor extremely shoddy React code
- [ ] Improve UX
70 changes: 47 additions & 23 deletions chromaviz/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,58 +11,61 @@
import webbrowser

import importlib.resources

# Flask application object; the route handlers in this module are registered on it.
app = Flask(__name__)
# Wrap the app with CORS (imported above this view; presumably flask_cors -- TODO confirm).
CORS(app)

from flask import cli
from flask import Response
from flask import request

# Replace flask.cli's banner hook with a no-op (presumably to keep startup output quiet).
cli.show_server_banner = lambda *_: None

# Module-level holder for the collection payload: visualize_collection() overwrites it
# and the /data endpoint reads it. The nested empty list is only a placeholder.
data = [[]]



@app.route("/")
def hello_world():
    """Return the text of ``index.html`` bundled inside the ``chromaviz`` package."""
    with importlib.resources.open_text("chromaviz", "index.html") as page:
        return page.read()


@app.route('/assets/<path:filename>')
def serve_assets(filename):
    """Serve a front-end asset bundled inside the ``chromaviz`` package.

    The MIME type starts as HTML and is switched to JavaScript or CSS when
    the requested name contains ``.js`` or ``.css``. The CSS check runs last,
    so it wins if both substrings are present.

    Args:
        filename: Path captured from the ``/assets/<path:filename>`` URL; it
            is passed unchanged as the package resource name to open.

    Returns:
        A ``Response`` carrying the resource text and the selected MIME type.
    """
    mime = 'text/html'
    for marker, content_type in ((".js", 'text/javascript'), ('.css', 'text/css')):
        if marker in filename:
            mime = content_type

    with importlib.resources.open_text("chromaviz", filename) as asset:
        body = asset.read()
    return Response(body, mimetype=mime)


@app.route("/data")
def data_api():
global data
df = pd.DataFrame.from_dict(data=data["embeddings"])
print(df)
print('Size of the dataframe: {}'.format(df.shape))

pca_50 = PCA(n_components=50)
pca_result_50 = pca_50.fit_transform(df)

print('Cumulative explained variation for 50 principal components: {}'.format(np.sum(pca_50.explained_variance_ratio_)))
print('Cumulative explained variation for 50 principal components: {}'.format(
np.sum(pca_50.explained_variance_ratio_)))

time_start = time.time()

tsne = TSNE(n_components=3, verbose=0, perplexity=40, n_iter=300)
tsne_pca_results = tsne.fit_transform(pca_result_50)

print('t-SNE done! Time elapsed: {} seconds'.format(time.time()-time_start))
print('t-SNE done! Time elapsed: {} seconds'.format(time.time() - time_start))
tsne_pca_results = tsne_pca_results / 3

pca_3 = PCA(n_components=3)
Expand All @@ -72,22 +75,43 @@ def data_api():
groups = np.argmax(pca_result_50, axis=1)

points = []
for position, document, metadata, id, group in zip(tsne_pca_results.tolist(), data["documents"], data["metadatas"], data["ids"], groups.tolist()):
for position, document, metadata, id, group in zip(tsne_pca_results.tolist(), data["documents"], data["metadatas"],
data["ids"], groups.tolist()):
point = {
'position': position,
'document': document,
'metadata': metadata,
'id': id,
'group': group
'position': position,
'document': document,
'metadata': metadata,
'id': id,
'group': group
}
points.append(point)
return json.dumps({'points': points})


# NOTE(review): this client is created at import time and is not referenced by any
# code visible in this view -- confirm whether it is still needed.
client = chromadb.Client()

def visualize_collection(col: chromadb.api.models.Collection.Collection, query: str = None,
                         n_results: int = 50) -> None:
    """Load a collection (or the hits of one query) and launch the 3D viewer.

    The selected records are stored in the module-level ``data`` variable,
    which the ``/data`` endpoint reads. A browser tab is then opened on the
    local server and the Flask app is started on port 5000.

    Args:
        col: Chroma collection to visualise.
        query: Optional query text. When given, only the results of this
            query are visualised instead of the whole collection.
        n_results: Number of query results to request. ``None`` and values
            below 50 are raised to 50 (a warning is printed for the latter).

    Raises:
        Exception: If the query returns fewer than 50 results.
    """
    global data

    # The /data endpoint fits a 50-component PCA on the embeddings, which is
    # presumably why fewer than 50 points are rejected here.
    min_points = 50
    include = ["documents", "metadatas", "embeddings"]

    if query is None:
        data = col.get(include=include)
    else:
        if n_results is None:
            n_results = min_points
        if n_results < min_points:
            print("Warning: n_results is less than 50. This may lead to unexpected results.")
            n_results = min_points

        result = col.query(query_texts=[query], n_results=n_results, include=include)

        # Query results are nested one list per query text, so the hits for
        # our single query live at index 0. Counting the outer list (always
        # length 1 here) would reject every query.
        hits = result["ids"][0] if result["ids"] else []
        if len(hits) < min_points:
            raise Exception("Query returned less than 50 results. This may lead to unexpected results.")

        def first_query(key):
            # Flatten the per-query nesting; a field that was not returned stays None.
            rows = result[key]
            return list(rows[0]) if rows else None

        data = {
            "ids": first_query("ids"),
            "embeddings": first_query("embeddings"),
            "documents": first_query("documents"),
            "metadatas": first_query("metadatas"),
            "distances": first_query("distances"),
        }

    webbrowser.open('http://127.0.0.1:5000')
    app.run(port=5000, debug=False)
    return
Loading