add clustering based on Francis' code
kentfitch committed Sep 13, 2024
1 parent 013e2bf commit e653d44
Showing 10 changed files with 708 additions and 8 deletions.
12 changes: 12 additions & 0 deletions .gitignore
Expand Up @@ -25,4 +25,16 @@ data/*.tsv
web/data/search.tsv
web/data/similarity.tsv
web/data/junk/*
python/data/*

python/clusterRunLog
python/clusterRunLog2
python/clusterRunLog3
python/clusterRunLog4
python/clusterRunLog5
python/clusterRunLog6
python/Generative_AI.egg-info/*
python/Generative_AI.egg-info/dependency_links.txt
python/Generative_AI.egg-info/requires.txt
python/Generative_AI.egg-info/SOURCES.txt
python/Generative_AI.egg-info/top_level.txt
6 changes: 6 additions & 0 deletions python/README
@@ -0,0 +1,6 @@
conda create -n kfcluster python=3.11 -y

then
conda activate kfcluster

cluster.py is based on Francis' https://github.com/nla/gen_ai/blob/main/src/cluster.py
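
To install the dependencies and run the clustering (commands as per setup.py and the header comments in cluster.py):

pip3 install setuptools
pip3 install -e .
python3 cluster.py data -m 8 -e 100 > clusterRunLog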
186 changes: 186 additions & 0 deletions python/cluster.py
@@ -0,0 +1,186 @@
#!/usr/bin/env python3

# kf13sep24
# uses files generated by https://hinton.nla.gov.au:4321/admin/generateFilesForClustering
# hacked from Francis' https://github.com/nla/gen_ai/blob/main/src/cluster.py
# run like this: python3 cluster.py data > clusterRunLog (defaults - only 3 clusters!? - 5358 items out of 6009 items) (24 sec)
# python3 cluster.py data -m 10 -e 50 > clusterRunLog2 (found 37 clusters, 1498 items)
# python3 cluster.py data -m 5 -e 50 > clusterRunLog3 (found only 3 clusters!) (23 sec)
# python3 cluster.py data -m 8 -e 100 > clusterRunLog4 (found 47 clusters - 1554 items) going to use this one!
# python3 cluster.py data -m 20 -e 100 > clusterRunLog5 (found 20 clusters - 1545 items)
# python3 cluster.py data -m 15 -e 100 > clusterRunLog6 (found 24 clusters - 1426 items)
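# note: -m (min_cluster_size) applies to hdbscan (the default algorithm); -e (expected_num_clusters) is only used when -a kmeans is selected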

import argparse
import json
import os
from datetime import datetime

import hdbscan
import numpy as np
from sklearn.cluster import KMeans

# from domain.metadata import (METADATA_FILE, TITLE, URL, FILE_PATH, CLUSTER_ID, CLUSTER_SIZE, CLUSTER_URL,
# SUMMARY_EMBEDDING)
# from src.index_records import clean_metadata

HDBSCAN = "hdbscan"
KMEANS = "kmeans"


def main():
    parser = argparse.ArgumentParser(description="Cluster embedding (vector) data")
    parser.add_argument("data_root", type=str, help="Root of file tree to load metadata JSON files from")
    parser.add_argument('-a', "--algorithm", type=str, default=HDBSCAN,
                        help="Clustering algorithm (hdbscan, kmeans) - default is hdbscan")
    parser.add_argument('-m', "--min_cluster_size", type=int, default=5,
                        help="Minimum cluster size (required for hdbscan) - default is 5")
    parser.add_argument('-e', "--expected_num_clusters", type=int, default=10,
                        help="Expected number of clusters (required for kmeans) - default is 10")
    parser.add_argument('-u', "--update_records", action='store_true',
                        help="Update metadata records with cluster ID and cluster size for each record in a cluster")

    args = parser.parse_args()
    process_file_tree(args.data_root, args.min_cluster_size, args.algorithm, args.expected_num_clusters,
                      args.update_records)


def process_file_tree(data_root, min_cluster_size, algorithm, expected_num_clusters, update_records):
    formatted_datetime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f"Starting clustering from: {data_root} at {formatted_datetime}")

    embeddings_array, metadata_list = get_embeddings_and_metadata(data_root)
    clusters = get_clusters(embeddings_array, algorithm, min_cluster_size, expected_num_clusters)

    if clusters:
        print(f"Number of clusters found: {clusters.labels_.max() + 1}\n")
        sorted_clusters = sort_clusters(clusters, embeddings_array, metadata_list, update_records)
        total_clustered_items = 0

        for cluster_size, cluster_id, metadata in sorted_clusters:
            print(f"Cluster {cluster_id} - Size: {cluster_size}, URL: {metadata['id']}")
            total_clustered_items += cluster_size

        print(f"\nTotal clustered items: {total_clustered_items}")
    else:
        print("No clusters generated.")

    formatted_datetime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f"Finished clustering at: {formatted_datetime}")


def get_embeddings_and_metadata(data_root):
    # walk data_root for *.json metadata files; each record needs an "id" and a "clipEmbedding"
    embeddings = []
    records_metadata = []

    for root, dirs, files in os.walk(data_root):
        for file in files:
            if file.endswith(".json"):
                file_path = os.path.join(root, file)

                try:
                    metadata = get_metadata(file_path)

                    if metadata and metadata["clipEmbedding"]:
                        # print( "read metadata " + file_path + " id " + metadata["id"])
                        embeddings.append(metadata["clipEmbedding"])
                        metadata["file_path"] = file_path
                        records_metadata.append(metadata)
                except Exception as e:
                    print(f"Error reading {file_path}: {e}")

    return np.array(embeddings), records_metadata


def get_clusters(embeddings_array, algorithm, min_cluster_size, expected_num_clusters, min_samples=3):
    print(f"Total items to cluster: {embeddings_array.shape[0]}")
    print(f"Clustering algorithm: {algorithm}")

    if algorithm == HDBSCAN:
        print(f"Minimum cluster size: {min_cluster_size}")
        print(f"Minimum samples: {min_samples}")
        clusters = hdbscan.HDBSCAN(min_samples=min_samples, min_cluster_size=min_cluster_size)
        clusters.fit(embeddings_array)
    elif algorithm == KMEANS:
        print(f"Expected number of clusters: {expected_num_clusters}")
        clusters = KMeans(n_clusters=expected_num_clusters, init="k-means++", random_state=42)
        clusters.fit(embeddings_array)
    else:
        print(f"Error: Unsupported clustering algorithm: {algorithm}")
        return None

    return clusters


def sort_clusters(clusters, embeddings_array, metadata_list, update_records):
    # summarise each cluster, write the cluster info to stdout or back into the metadata records,
    # then return (cluster_size, cluster_id, representative_metadata) tuples sorted by size
    cluster_representatives = {}
    cluster_sizes = {}
    cluster_info = []

    # Iterate through all unique cluster labels (excluding noise)
    for cluster_id in set(clusters.labels_):
        if cluster_id != -1:
            cluster_indices = np.where(clusters.labels_ == cluster_id)[0]
            cluster_size = len(cluster_indices)
            cluster_embeddings = embeddings_array[cluster_indices]
            cluster_mean = cluster_embeddings.mean(axis=0)

            # Find the closest point to the mean embedding
            closest_point_index = cluster_indices[np.argmin(np.linalg.norm(cluster_embeddings - cluster_mean, axis=1))]
            representative_metadata = metadata_list[closest_point_index]
            cluster_representatives[cluster_id] = representative_metadata["id"]
            cluster_sizes[cluster_id] = cluster_size

            # Store tuple of cluster information
            cluster_info.append((cluster_size, cluster_id, representative_metadata))

    if update_records:
        update_metadata_records(clusters, cluster_sizes, metadata_list, cluster_representatives)
    else:
        writeClusterInfo(clusters, cluster_sizes, metadata_list, cluster_representatives)

    # Sort by cluster size in descending order
    sorted_clusters = sorted(cluster_info, key=lambda x: x[0], reverse=True)

    return sorted_clusters


def writeClusterInfo(clusters, cluster_sizes, metadata_list, cluster_representatives):
    print("\nCluster info details")
    print("sourceId TAB clusterNumber TAB clusterSize TAB clusterCentroidId")

    for i, metadata in enumerate(metadata_list):
        cluster_id = clusters.labels_[i]
        if cluster_id != -1:
            print(f"{metadata['id']}\t{cluster_id}\t{cluster_sizes[cluster_id]}\t{cluster_representatives[cluster_id]}")
    print("End of cluster info details")


def update_metadata_records(clusters, cluster_sizes, metadata_list, cluster_representatives):
    print("Updating metadata records with cluster details\n")

    for i, metadata in enumerate(metadata_list):
        cluster_id = clusters.labels_[i]
        if cluster_id != -1:
            metadata["cluster_id"] = int(cluster_id)
            metadata["cluster_size"] = cluster_sizes[cluster_id]
            metadata["cluster_url"] = cluster_representatives[cluster_id]

            # Ensure 'file_path' is not written back
            file_path = metadata.pop("file_path", None)

            if file_path:
                print("Writing back " + file_path)
                with open(file_path, 'w') as f:
                    json.dump(metadata, f, indent=4)


def get_metadata(file_path):
    with open(file_path, 'r') as f:
        metadata = json.load(f)
        # metadata = clean_metadata(metadata)

    return metadata


if __name__ == '__main__':
    main()
44 changes: 44 additions & 0 deletions python/setup.py
@@ -0,0 +1,44 @@
#
# Setup Python package and module structure. This script should be
# placed in the top level of the repository and used by running:
#
# pip3 install setuptools
# pip3 install -e .
#
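# (the editable install pulls in hdbscan, scikit-learn and numpy, which cluster.py imports)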

from setuptools import setup, find_packages

setup(
    name='Generative AI',
    version='0.1.0',
    packages=find_packages(),
    install_requires=[
        'openai',
        'tenacity',
        'numpy',
        'hdbscan',
        'scikit-learn',
        'requests',
        'bs4',
        'cbor2',
        'Pillow',
        'termcolor'
    ],
    entry_points={
        'console_scripts': [
            # Define command line scripts here if needed
        ],
    },
    author='Francis Crimmins',
    author_email='[email protected]',
    description='Tools for generative AI',
    #long_description=open('README.md').read(),
    #long_description_content_type='text/markdown',
    #url='https://github.com/nla/gen_ai',
    classifiers=[
        'Programming Language :: Python :: 3',
        'License :: OSI Approved :: MIT License',
        'Operating System :: OS Independent',
    ],
    python_requires='>=3.6',
)
31 changes: 31 additions & 0 deletions web/common/demoHeader.ejs
@@ -0,0 +1,31 @@
<!DOCTYPE html>
<html>

<head>
  <meta charset="UTF-8">
  <meta http-equiv="Content-Security-Policy" content="img-src 'self' https://nla.gov.au; script-src-attr 'self' ; script-src 'self' https://fonts.googleapis.com https://apis.google.com; style-src 'self' https://fonts.googleapis.com/ 'unsafe-inline' ">

  <!-- meta http-equiv="Content-Security-Policy" content="default-src *; script-src-attr 'self' ; script-src 'self' https://fonts.googleapis.com https://apis.google.com; style-src 'self' https://fonts.googleapis.com/ 'unsafe-inline'" -->

  <link rel="icon" type="image/png" sizes="32x32" href="/static/images/favicon-32x32.png">
  <link rel="icon" type="image/png" sizes="16x16" href="/static/images/favicon-16x16.png">
  <link rel="icon" type="image/x-icon" href="/static/images/favicon.ico">

  <link rel="stylesheet" href="/static/style/css.css" type="text/css">

  <script src="/static/javascript/js.js"></script>
  <link href="https://fonts.googleapis.com/css2?family=Roboto&display=swap" rel="stylesheet">
  <link href="https://fonts.googleapis.com/css2?family=Roboto+Condensed&display=swap" rel="stylesheet">

  <title>NLA Image Search Evaluation</title>

</head>
<body>

<H3>NLA Image Search Evaluation</H3>
<HR/>
<div class="content" style="padding-top:0.5em">




45 changes: 45 additions & 0 deletions web/common/demoSearchForm.ejs
@@ -0,0 +1,45 @@
<%function setSelect(t, value) {
  return (t == value) ? " selected " : "" ;
}
%>

<form action="demoSearch" method="get" id="searchForm">
  <input type="hidden" name="searchOrSim" value="search"/>
  <table style="margin-bottom:1em">
    <tr valign='top'>
      <td nowrap>
        <input type='text' size='60' id='stxt' name="stxt" style='font-size:105%' value='<%=stxt%>' class='veryLightMauve'/>
        <button id='searchButton' type="submit" class='searchButton'><span id='searchText' style='font-size:120%;padding:15px'>Search</span></button>
      </td>
      <td nowrap style='padding-left:1em;font-variant:small-caps'>
        Method:
        <select name="method" id="method">
          <option value="0" <%=setSelect(0, method)%>>CLIP Image similarity</option>
          <option value="1" <%=setSelect(1, method)%>>NLA metadata keyword</option>
          <option value="2" <%=setSelect(2, method)%>>OpenAI description keyword</option>
          <option value="3" <%=setSelect(3, method)%>>Phi-3.5 description keyword</option>
          <option value="4" <%=setSelect(4, method)%>>80% CLIP 20% NLA metadata</option>
          <option value="5" <%=setSelect(5, method)%>>50% CLIP 50% NLA metadata</option>
          <option value="6" <%=setSelect(6, method)%>>80% CLIP 20% OpenAI description</option>
          <option value="7" <%=setSelect(7, method)%>>50% CLIP 50% OpenAI description</option>
          <option value="8" <%=setSelect(8, method)%>>80% CLIP 20% Phi-3.5 description</option>
          <option value="9" <%=setSelect(9, method)%>>50% CLIP 50% Phi-3.5 description</option>
          <option value="10" <%=setSelect(10, method)%>>50% CLIP 30% OpenAI description 20% NLA metadata</option>
          <option value="11" <%=setSelect(11, method)%>>50% CLIP 30% Phi-3.5 description 20% NLA metadata</option>
          <option value="12" <%=setSelect(12, method)%>>Phi-3 description keyword</option>
        </select>
      </td>
      <td nowrap style='padding-left:1em;font-variant:small-caps'>
        Layout:
        <select name="layout" id="layout">
          <option <%=setSelect("List", layout)%>>List</option>
          <option <%=setSelect("Grid", layout)%>>Grid</option>
        </select>
      </td>
    </tr>
  </table>
</form>

2 changes: 1 addition & 1 deletion web/common/footer.ejs
@@ -1,4 +1,4 @@
  </div> <!-- end of content div -->
- <div style='font-size:50%'>v26Aug-17:02</div>
+ <div style='font-size:50%'>v12Sep-noon</div>
  </body>
  </HTML>