001 (ns org.clojars.punit-naik.clj-ml.k-means
002 (:require [org.clojars.punit-naik.clj-ml.utils.generic :as generic-utils]
003 [org.clojars.punit-naik.clj-ml.utils.geometry :as geometry-utils]
004 [org.clojars.punit-naik.clj-ml.utils.matrix :as matrix-utils]))
005
(defn generate-initial-clusters
  "Builds the starting centroids for clustering.

  Returns a map from cluster index (0 up to `no-of-clusters` - 1) to a seq of
  `dimensions` zero values, i.e. every centroid starts at the origin.

  NOTE(review): because all centroids coincide, the first assignment pass
  cannot tell them apart; seeding from the data (e.g. k-means++) would give a
  proper spread -- confirm this is intended."
  [dimensions no-of-clusters]
  (let [zero-rows (make-array Double/TYPE no-of-clusters dimensions)]
    (zipmap (range) (map sequence zero-rows))))
013
(defn assign-closest-cluster
  "Picks the cluster nearest to `point`.

  `clusters-map` maps a cluster key to its centroid. The
  `geometry-utils/distance` from `point` to every centroid is computed and the
  key with the smallest distance is returned (ties are resolved by `min-key`)."
  [point clusters-map]
  (let [distance-by-key (into {}
                              (for [[k centroid] clusters-map]
                                [k (geometry-utils/distance point centroid)]))
        [closest-key _] (apply min-key val distance-by-key)]
    closest-key))
025
(defn assign-cluster
  "Groups `data-points` by their nearest cluster.

  Returns a map from every key of `clusters` to the vector of points assigned
  to it (in input order); clusters that attract no points map to `[]`."
  [data-points clusters]
  (let [empty-assignments (zipmap (keys clusters) (repeat []))]
    (reduce (fn [assignments p]
              (update assignments (assign-closest-cluster p clusters) conj p))
            empty-assignments
            data-points)))
037
(defn sum-squared-distance
  "Sums the squared `geometry-utils/distance` between the centroid `cluster`
  and each point in `data-points`. Returns 0 when there are no points."
  [data-points cluster]
  (transduce (map (fn [p] (Math/pow (geometry-utils/distance cluster p) 2)))
             +
             data-points))
043
(defn- recompute-centroid
  "Returns the new centroid for a cluster: the coordinate-wise mean of
  `assigned-points`, or a zero vector of length `dimensions` when the cluster
  attracted no points."
  [assigned-points dimensions]
  (if (seq assigned-points)
    ;; assumes matrix-utils/transpose turns a coll of points into a coll of
    ;; per-coordinate columns, so each mean is one centroid coordinate
    (->> assigned-points
         matrix-utils/transpose
         (map generic-utils/mean-coll))
    ;; an empty cluster falls back to the origin; a persistent vector (rather
    ;; than a Java array) keeps the centroid printable and `=`-comparable
    (vec (repeat dimensions 0.0))))

(defn- summarize-clusters
  "Builds the return value of `update-clusters`: one map per entry of
  `clusters` (key -> centroid) with the centroid, the points assigned to it in
  `cluster-assignments`, and their sum of squared distances to it."
  [clusters cluster-assignments]
  (mapv (fn [[cluster-key cluster]]
          (let [assigned-data-points (get cluster-assignments cluster-key)]
            {:cluster cluster
             :assigned-data-points assigned-data-points
             :sum-squared-distance (sum-squared-distance assigned-data-points cluster)}))
        clusters))

(defn update-clusters
  "Iteratively refines cluster centroids and returns the final clusters.

  Arities:
    [data-points no-of-clusters]
      Seeds `no-of-clusters` centroids with `generate-initial-clusters` (one
      dimension per coordinate of the first data point) and uses a tolerance
      of 0.001.
    [data-points clusters error-rate]
      Starts from `clusters`, a map of cluster key -> centroid, and stops once
      no centroid moves more than `error-rate`.

  Each iteration assigns every point to its nearest centroid, then moves each
  centroid to the mean of its assigned points (or to the origin if it has
  none). Centroid movement is measured with `geometry-utils/distance`.

  Returns a vector with one map per cluster:
    {:cluster              <centroid>
     :assigned-data-points <vector of points>
     :sum-squared-distance <sum of squared distances to the centroid>}"
  ([data-points no-of-clusters]
   (update-clusters
    data-points
    (generate-initial-clusters (-> data-points first count) no-of-clusters)
    0.001))
  ([data-points clusters error-rate]
   (let [dimensions (-> data-points first count)]
     (loop [clusterz clusters
            cluster-assignments (assign-cluster data-points clusters)
            converged? false]
       (if converged?
         (summarize-clusters clusterz cluster-assignments)
         (let [;; Keep each new centroid under the key of the cluster it was
               ;; computed from, rather than zipping two maps by position
               ;; (which silently relabels clusters when iteration orders differ).
               new-clusterz (into {}
                                  (map (fn [[cluster-key assigned-points]]
                                         [cluster-key (recompute-centroid assigned-points dimensions)]))
                                  cluster-assignments)
               ;; Converged when every centroid moved at most `error-rate`.
               ;; Comparing only the average of a centroid's coordinates would
               ;; miss moves that leave that average unchanged.
               all-settled? (every? (fn [[cluster-key new-cluster]]
                                      (<= (geometry-utils/distance (get clusterz cluster-key)
                                                                   new-cluster)
                                          error-rate))
                                    new-clusterz)]
           (recur new-clusterz
                  (assign-cluster data-points new-clusterz)
                  all-settled?)))))))
092
(defn rapidly-changing?
  "Tells whether the elbow curve is still moving noticeably.

  `wcss` is a sequence of maps carrying `:sum-squared-distance-mean` (as
  built by `elbow-method-data`). Returns true when the values of the last two
  entries differ by more than `threshold` (an absolute difference, default
  0.1), false otherwise.

  With fewer than two entries there is nothing to compare, so the curve is
  treated as still changing (true) instead of throwing."
  ([wcss]
   (rapidly-changing? wcss 0.1))
  ([wcss threshold]
   (if (< (count wcss) 2)
     true
     (let [[{previous :sum-squared-distance-mean}
            {latest :sum-squared-distance-mean}] (take-last 2 wcss)]
       (> (Math/abs (- latest previous)) threshold)))))
098
(defn elbow-method-data
  "Generates data points for the elbow method.

  Runs `update-clusters` for k = 1, 2, 3, ... and records one entry per k:
    {:sum-squared-distance-mean <mean of the clusters' :sum-squared-distance>
     :distance-from-origin      <geometry-utils/distance of [k mean] to [0 0]>
     :clusters                  <seq of {:cluster .. :assigned-data-points ..}>}

  Stops as soon as the latest run contains a cluster with no assigned points,
  or (once at least two entries exist) when `rapidly-changing?` reports that
  the mean no longer changes much. The latest entry, which triggered the stop,
  is dropped, so the result is a seq of the earlier entries (nil if none)."
  [data-points]
  (letfn [(has-empty-cluster? [entry]
            (some #(empty? (:assigned-data-points %)) (:clusters entry)))
          (plateaued? [history]
            (and (next history)
                 (not (rapidly-changing? history))))
          (summarise [clusters]
            (let [ssd-mean (generic-utils/mean-coll (map :sum-squared-distance clusters))]
              {:sum-squared-distance-mean ssd-mean
               :distance-from-origin (geometry-utils/distance [(count clusters) ssd-mean] [0 0])
               :clusters (map #(select-keys % [:cluster :assigned-data-points]) clusters)}))]
    (loop [history []
           k 1]
      (if (and (seq history)
               (or (has-empty-cluster? (peek history))
                   (plateaued? history)))
        (butlast history)
        (recur (conj history (summarise (update-clusters data-points k)))
               (inc k))))))
122
(defn elbow
  "Picks the elbow point from the output of `elbow-method-data`.

  Chooses the entry with the smallest `:distance-from-origin` (the earliest
  one on ties) and returns its `:clusters`; returns nil for an empty `wcss`."
  [wcss]
  (let [[closest] (sort-by :distance-from-origin wcss)]
    (:clusters closest)))
130
(defn data-with-assigned-clusters
  "Flattens clustering results into one record per data point.

  Each element of `clusters` is a map with `:cluster` (the centroid) and
  `:assigned-data-points`. Returns a lazy seq of
  `{:data-point <point> :cluster <centroid>}` maps, in cluster order and,
  within a cluster, in point order."
  [clusters]
  (for [{centroid :cluster points :assigned-data-points} clusters
        p points]
    {:data-point p :cluster centroid}))