-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add clustering based on Francis' code
- Loading branch information
Showing
10 changed files
with
708 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
conda create -n kfcluster python=3.11 -y | ||
|
||
SO | ||
conda activate kfcluster | ||
|
||
then based on Francis' https://github.com/nla/gen_ai/blob/main/src/cluster.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,186 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# kf13sep24 | ||
# uses files generated by https://hinton.nla.gov.au:4321/admin/generateFilesForClustering | ||
# hacked from Francis' https://github.com/nla/gen_ai/blob/main/src/cluster.py | ||
# run like this: python3 cluster.py data > clusterRunLog (defaults - only 3 clusters!? - 5358 items out of 6009 items) (24 sec) | ||
# python3 cluster.py data -m 10 -e 50 > clusterRunLog2 (found 37 clusters, 1498 items) | ||
# python3 cluster.py data -m 5 -e 50 > clusterRunLog3 (found only 3 clusters!) (23 sec) | ||
# python3 cluster.py data -m 8 -e 100 > clusterRunLog4 (found 47 clusters - 1554 items) going to use this one! | ||
# python3 cluster.py data -m 20 -e 100 > clusterRunLog5 (found 20 clusters - 1545 items) | ||
# python3 cluster.py data -m 15 -e 100 > clusterRunLog6 (found 24 clusters - 1426 items) | ||
|
||
import argparse | ||
import json | ||
import os | ||
from datetime import datetime | ||
|
||
import hdbscan | ||
import numpy as np | ||
from sklearn.cluster import KMeans | ||
|
||
# from domain.metadata import (METADATA_FILE, TITLE, URL, FILE_PATH, CLUSTER_ID, CLUSTER_SIZE, CLUSTER_URL, | ||
# SUMMARY_EMBEDDING) | ||
# from src.index_records import clean_metadata | ||
|
||
HDBSCAN = "hdbscan" | ||
KMEANS = "kmeans" | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser(description="Cluster embedding (vector) data") | ||
parser.add_argument("data_root", type=str, help="Root of file tree to load metadata JSON files from") | ||
parser.add_argument('-a', "--algorithm", type=str, default=HDBSCAN, | ||
help="Clustering algorithm (hdbscan, kmeans) - default is hdbscan") | ||
parser.add_argument('-m', "--min_cluster_size", type=int, default="5", | ||
help="Minimum cluster size (required for hdbscan) - default is 5") | ||
parser.add_argument('-e', "--expected_num_clusters", type=int, default="10", | ||
help="Expected number of clusters (required for kmeans) - default is 10") | ||
parser.add_argument('-u', "--update_records", action='store_true', | ||
help="Update metadata records with cluster ID and cluster size for each record in a cluster") | ||
|
||
|
||
args = parser.parse_args() | ||
process_file_tree(args.data_root, args.min_cluster_size, args.algorithm, args.expected_num_clusters, | ||
args.update_records) | ||
|
||
|
||
def process_file_tree(data_root, min_cluster_size, algorithm, expected_num_clusters, update_records): | ||
formatted_datetime = datetime.now().strftime("%Y-%m-%d %H:%M:%S") | ||
print(f"Starting clustering from: {data_root} at {formatted_datetime}") | ||
|
||
embeddings_array, metadata_list = get_embeddings_and_metadata(data_root) | ||
clusters = get_clusters(embeddings_array, algorithm, min_cluster_size, expected_num_clusters) | ||
|
||
if clusters: | ||
print(f"Number of clusters found: {clusters.labels_.max() + 1}\n") | ||
sorted_clusters = sort_clusters(clusters, embeddings_array, metadata_list, update_records) | ||
total_clustered_items = 0 | ||
|
||
for cluster_size, cluster_id, metadata in sorted_clusters: | ||
print(f"Cluster {cluster_id} - Size: {cluster_size}, URL: {metadata['id']}") | ||
total_clustered_items += cluster_size | ||
|
||
print(f"\nTotal clustered items: {total_clustered_items}") | ||
else: | ||
print("No clusters generated.") | ||
|
||
formatted_datetime = datetime.now().strftime("%Y-%m-%d %H:%M:%S") | ||
print(f"Finished clustering at: {formatted_datetime}") | ||
|
||
|
||
def get_embeddings_and_metadata(data_root): | ||
embeddings = [] | ||
records_metadata = [] | ||
|
||
for root, dirs, files in os.walk(data_root): | ||
for file in files: | ||
if file.endswith(".json"): | ||
file_path = os.path.join(root, file) | ||
|
||
try: | ||
metadata = get_metadata(file_path) | ||
|
||
if metadata and metadata["clipEmbedding"]: | ||
# print( "read metadata " + file_path + " id " + metadata["id"]) | ||
embeddings.append(metadata["clipEmbedding"]) | ||
metadata["file_path"] = file_path | ||
records_metadata.append(metadata) | ||
except Exception as e: | ||
print(f"Error reading {file_path}: {e}") | ||
|
||
return np.array(embeddings), records_metadata | ||
|
||
|
||
def get_clusters(embeddings_array, algorithm, min_cluster_size, expected_num_clusters, min_samples=3): | ||
print(f"Total items to cluster: {embeddings_array.shape[0]}") | ||
print(f"Clustering algorithm: {algorithm}") | ||
|
||
if algorithm == HDBSCAN: | ||
print(f"Minimum cluster size: {min_cluster_size}") | ||
print(f"Minimum samples: {min_samples}") | ||
clusters = hdbscan.HDBSCAN(min_samples=min_samples, min_cluster_size=min_cluster_size) | ||
clusters.fit(embeddings_array) | ||
elif algorithm == KMEANS: | ||
print(f"Expected number of clusters: {expected_num_clusters}") | ||
clusters = KMeans(n_clusters=expected_num_clusters, init="k-means++", random_state=42) | ||
clusters.fit(embeddings_array) | ||
else: | ||
print(f"Error: Unsupported clustering algorithm: {algorithm}") | ||
return None | ||
|
||
return clusters | ||
|
||
|
||
def sort_clusters(clusters, embeddings_array, metadata_list, update_records): | ||
cluster_representatives = {} | ||
cluster_sizes = {} | ||
cluster_info = [] | ||
|
||
# Iterate through all unique cluster labels (excluding noise) | ||
for cluster_id in set(clusters.labels_): | ||
if cluster_id != -1: | ||
cluster_indices = np.where(clusters.labels_ == cluster_id)[0] | ||
cluster_size = len(cluster_indices) | ||
cluster_embeddings = embeddings_array[cluster_indices] | ||
cluster_mean = cluster_embeddings.mean(axis=0) | ||
|
||
# Find the closest point to the mean embedding | ||
closest_point_index = cluster_indices[np.argmin(np.linalg.norm(cluster_embeddings - cluster_mean, axis=1))] | ||
representative_metadata = metadata_list[closest_point_index] | ||
cluster_representatives[cluster_id] = representative_metadata["id"] | ||
cluster_sizes[cluster_id] = cluster_size | ||
|
||
# Store tuple of cluster information | ||
cluster_info.append((cluster_size, cluster_id, representative_metadata)) | ||
|
||
if update_records: | ||
update_metadata_records(clusters, cluster_sizes, metadata_list, cluster_representatives) | ||
else: | ||
writeClusterInfo(clusters, cluster_sizes, metadata_list, cluster_representatives) | ||
|
||
# Sort by cluster size in descending order | ||
sorted_clusters = sorted(cluster_info, key=lambda x: x[0], reverse=True) | ||
|
||
return sorted_clusters | ||
|
||
|
||
def writeClusterInfo(clusters, cluster_sizes, metadata_list, cluster_representatives): | ||
print("\nCluster info details") | ||
print("sourceId TAB clusterNumber TAB clusterSize TAB clusterCentroidId") | ||
|
||
for i, metadata in enumerate(metadata_list): | ||
cluster_id = clusters.labels_[i] | ||
if cluster_id != -1: | ||
print(f"{metadata['id']}\t{cluster_id}\t{cluster_sizes[cluster_id]}\t{cluster_representatives[cluster_id]}") | ||
print("End of cluster info details") | ||
|
||
def update_metadata_records(clusters, cluster_sizes, metadata_list, cluster_representatives): | ||
print("Updating metadata records with cluster details\n") | ||
|
||
for i, metadata in enumerate(metadata_list): | ||
cluster_id = clusters.labels_[i] | ||
if cluster_id != -1: | ||
metadata["cluster_id"] = int(cluster_id) | ||
metadata["cluster_size"] = cluster_sizes[cluster_id] | ||
metadata["cluster_url"] = cluster_representatives[cluster_id] | ||
|
||
# Ensure 'file_path' is not written back | ||
file_path = metadata.pop("file_path", None) | ||
|
||
if file_path: | ||
print("Writing back " + file_path) | ||
with open(file_path, 'w') as f: | ||
json.dump(metadata, f, indent=4) | ||
|
||
|
||
def get_metadata(file_path): | ||
with open(file_path, 'r') as f: | ||
metadata = json.load(f) | ||
# metadata = clean_metadata(metadata) | ||
|
||
return metadata | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# | ||
# Setup Python package and module structure. This script should be | ||
# placed in the top level of the repository and used by running: | ||
# | ||
# pip3 install setuptools | ||
# pip3 install -e . | ||
# | ||
|
||
from setuptools import setup, find_packages | ||
|
||
setup( | ||
name='Generative AI', | ||
version='0.1.0', | ||
packages=find_packages(), | ||
install_requires=[ | ||
'openai', | ||
'tenacity', | ||
'numpy', | ||
'hdbscan', | ||
'scikit-learn', | ||
'requests', | ||
'bs4', | ||
'cbor2', | ||
'Pillow', | ||
'termcolor' | ||
], | ||
entry_points={ | ||
'console_scripts': [ | ||
# Define command line scripts here if needed | ||
], | ||
}, | ||
author='Francis Crimmins', | ||
author_email='[email protected]', | ||
description='Tools for generative AI', | ||
#long_description=open('README.md').read(), | ||
#long_description_content_type='text/markdown', | ||
#url='https://github.com/nla/gen_ai', | ||
classifiers=[ | ||
'Programming Language :: Python :: 3', | ||
'License :: OSI Approved :: MIT License', | ||
'Operating System :: OS Independent', | ||
], | ||
python_requires='>=3.6', | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
<!DOCTYPE html> | ||
<html> | ||
|
||
<head> | ||
<meta charset="UTF-8"> | ||
<meta http-equiv="Content-Security-Policy" content="img-src 'self' https://nla.gov.au; script-src-attr 'self' ; script-src 'self' https://fonts.googleapis.com https://apis.google.com; style-src 'self' https://fonts.googleapis.com/ 'unsafe-inline' "> | ||
|
||
<!-- meta http-equiv="Content-Security-Policy" content="default-src *; script-src-attr 'self' ; script-src 'self' https://fonts.googleapis.com https://apis.google.com; style-src 'self' https://fonts.googleapis.com/ 'unsafe-inline'" --> | ||
|
||
<link rel="icon" type="image/png" sizes="32x32" href="/static/images/favicon-32x32.png"> | ||
<link rel="icon" type="image/png" sizes="16x16" href="/static/images/favicon-16x16.png"> | ||
<link rel="icon" type="image/png" href="/static/images/favicon.ico"> | ||
|
||
<link rel="stylesheet" href="/static/style/css.css" type="text/css"> | ||
|
||
<script src="/static/javascript/js.js"></script> | ||
<link href="https://fonts.googleapis.com/css2?family=Roboto&display=swap" rel="stylesheet"> | ||
<link href="https://fonts.googleapis.com/css2?family=Roboto+Condensed&display=swap" rel="stylesheet"> | ||
|
||
<title>NLA Image Search Evaluation</title> | ||
|
||
</head> | ||
<body> | ||
|
||
<H3>NLA Image Search Evaluation</H3> | ||
<HR/> | ||
<div class="content" style="padding-top:0.5em"> | ||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
<%function setSelect(t, value) { | ||
return (t == value) ? " selected " : "" ; | ||
} | ||
%> | ||
|
||
<form action="demoSearch" method="get" id="searchForm"> | ||
<input type="hidden" name="searchOrSim" value="search"/> | ||
<table style="margin-bottom:1em"> | ||
<tr valign='top'> | ||
<td nowrap> | ||
<input type='text' size='60' id='stxt' name="stxt" style='font-size:105%' value='<%=stxt%>' class='veryLightMauve'/> | ||
<button id='searchButton' type="submit" class='searchButton'><span id='searchText' style='font-size:120%;padding:15px'>Search</span></button> | ||
</td> | ||
<td nowrap style='padding-left:1em;font-variant:small-caps'> | ||
|
||
Method: | ||
<select name="method" id="method"> | ||
<option value="0" <%=setSelect(0, method)%>>CLIP Image similarity</option> | ||
<option value="1" <%=setSelect(1, method)%>>NLA metadata keyword</option> | ||
<option value="2" <%=setSelect(2, method)%>>OpenAI description keyword</option> | ||
<option value="3" <%=setSelect(3, method)%>>Phi-3.5 description keyword</option> | ||
<option value="4" <%=setSelect(4, method)%>>80% CLIP 20% NLA metadata</option> | ||
<option value="5" <%=setSelect(5, method)%>>50% CLIP 50% NLA metadata</option> | ||
<option value="6" <%=setSelect(6, method)%>>80% CLIP 20% OpenAI description</option> | ||
<option value="7" <%=setSelect(7, method)%>>50% CLIP 50% OpenAI description</option> | ||
<option value="8" <%=setSelect(8, method)%>>80% CLIP 20% Phi-3.5 description</option> | ||
<option value="9" <%=setSelect(9, method)%>>50% CLIP 50% Phi-3.5 description</option> | ||
<option value="10" <%=setSelect(10, method)%>>50% CLIP 30% OpenAI description 20% NLA metadata</option> | ||
<option value="11" <%=setSelect(11, method)%>>50% CLIP 30% Phi-3.5 description 20% NLA metadata</option> | ||
<option value="12" <%=setSelect(12, method)%>>Phi-3 description keyword</option> | ||
</select> | ||
</td> | ||
<td nowrap style='padding-left:1em;font-variant:small-caps'> | ||
Layout: | ||
<select name="layout" id="layout"> | ||
<option <%=setSelect("List", layout)%>>List</option> | ||
<option <%=setSelect("Grid", layout)%>>Grid</option> | ||
</select> | ||
</td> | ||
</tr> | ||
|
||
</table> | ||
</form> | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
</div> <!-- end of content div --> | ||
<div style='font-size:50%'>v26Aug-17:02</div> | ||
<div style='font-size:50%'>v12Sep-noon</div> | ||
</body> | ||
</HTML> |
Oops, something went wrong.