diff --git a/Clustering.py b/Clustering.py index 079b528..318d9d3 100644 --- a/Clustering.py +++ b/Clustering.py @@ -97,14 +97,24 @@ def classify_a_point(point, groups, k): index = freq.index(max(freq)) return index -def cluster(data, no_of_clusters, k): +def cluster(data, no_of_clusters, k, max_iterations): groups=initialize(data, no_of_clusters) groups=[[element] for element in groups] - for i in range(data.shape[0]): - group_no = classify_a_point(data[i,:], groups,k) - groups[group_no].append(data[i,:]) + for n in range(max_iterations): + for i in range(data.shape[0]): + group_no = classify_a_point(data[i,:], groups,k) + groups[group_no].append(data[i,:]) return groups +def plot_clusters(groups): + colors = cm.rainbow(np.linspace(0, 1, len(groups))) + for group in range(len(groups)): + plt.scatter(*zip(*groups[group]), color=colors) + plt.show() + +plot_clusters(cluster(data, 4, 5,5)) + +