001  (ns org.clojars.punit-naik.clj-ml.utils.linear-algebra
002    (:require [org.clojars.punit-naik.clj-ml.utils.calculus :as cu]
003              [org.clojars.punit-naik.clj-ml.utils.generic :as gu]
004              [clojure.set :refer [intersection]]))
005  
006  (defn eval-fn
007    "Evaluates a function ax^n+bx^n-1+...+z represented by a collection of it's coefficients [a b ... z]
008     at the value `x`"
009    [eq x]
010    (reduce + (map-indexed (fn [idx coeff]
011                             (* coeff (Math/pow x (- (dec (count eq)) idx)))) eq)))
012  
013  (defn factors
014    "Finds all the factors of a number"
015    ([num]
016     (loop [n (range 1 (inc (Math/abs num)))
017            result #{}]
018       (if (empty? n)
019         result
020         (recur (rest n)
021                (cond-> result
022                  (zero? (mod (Math/abs num) (first n))) (conj (first n)))))))
023    ([num decimal?]
024     (if decimal?
025       (let [[up down] (gu/rationalise (Math/abs num))]
026         (cond-> (factors up)
027           (not= down 1) (intersection (factors down))))
028       (factors num))))
029  
030  (defn isa-solution?
031    "Given the a to z terms of the quadratic equation ax^n+....+z=0 as a collection
032     And the root, this function checks if the same root is a solution for the equation or not
033     And returns the new reduced equation for finding the remaining roots using Synthetic Division"
034    [coefficients root]
035    (let [result (reduce (fn [{:keys [sum] :as acc} v]
036                           (let [s (+ (if (zero? sum) v (* root sum)) (if (zero? sum) sum v))]
037                             (-> acc
038                                 (assoc :sum s)
039                                 (update :coeffs conj s))))
040                         {:sum 0 :coeffs []} coefficients)]
041      (when (and (zero? (Math/round (* (:sum result) 1.0)))
042                 (seq (:coeffs result)))
043        (:coeffs (update result :coeffs butlast)))))
044  
045  (defn find-all-possible-solutions
046    "Given the a to z terms of the equation ax ^n+....+z=0 as a collection
047     This function finds all the possible roots of this equation"
048    [coefficients]
049    (when (> (count coefficients) 3)
050      (let [first-coefficient (first coefficients)
051            last-coefficient (last coefficients)
052            first-coeff-factors (sort (factors first-coefficient true))
053            last-coeff-factors (sort (factors last-coefficient true))]
054        (->> (map (fn [i]
055                    (map (fn [j]
056                           (if (> (Math/abs last-coefficient)
057                                  (Math/abs first-coefficient))
058                             [(/ j i) (/ (* j -1) i)]
059                             [(/ i j) (/ (* i -1) j)]))
060                         last-coeff-factors))
061                  first-coeff-factors)
062             flatten distinct sort))))
063  
064  (defn newtons-method
065    "Uses Newton's method to find the root of an equation ax^n+bx^n-1+...+z
066     Represented as a collection of it's coefficients [a b ... z]
067     It selects a root for precision upto the number set by the arg `precision`
068     x1 = x0 - ( f(x0) / f'(x0) )"
069    [eq eq-deriv precision x-0]
070    (loop [testing-root x-0
071           error 1]
072      (let [f-x (eval-fn eq testing-root)]
073        (if (or (zero? f-x)
074                (<= error (gu/error-decimal precision)))
075        testing-root
076        (let [f-dash-x (eval-fn eq-deriv testing-root)
077              new-root (if (zero? f-dash-x)
078                         ((if (= (/ testing-root (Math/abs testing-root)) -1) + -)
079                          testing-root (gu/error-decimal precision))
080                         (- testing-root (/ f-x f-dash-x)))]
081          (recur new-root (Math/abs (- new-root testing-root))))))))
082  
083  (defmulti solve-equation
084    "Given the a to z terms of the equation ax^n+....+z=0
085     This returns all the roots for the equation"
086    (fn [coefficients]
087      (condp = (count coefficients)
088        2 :linear
089        3 :quadratic
090        :default)))
091  
092  (defmethod solve-equation :linear
093    [[a b]]
094    [(gu/approximate-decimal (/ (* -1 b) a))])
095  
096  (defmethod solve-equation :quadratic
097    [[a b c]]
098    (let [negative-b (* -1 b)
099          square-root-of-b-suared-minus-4-ac (Math/sqrt (- (Math/pow b 2) (* 4 a c)))
100          twice-a (* 2 a)]
101      [(gu/approximate-decimal (/ (+ negative-b square-root-of-b-suared-minus-4-ac) twice-a))
102       (gu/approximate-decimal (/ (- negative-b square-root-of-b-suared-minus-4-ac) twice-a))]))
103  
104  (defn solve-equation-synthetic-division
105    [coefficients]
106    (loop [coeffs coefficients
107           all-possible (reverse (find-all-possible-solutions coeffs))
108           solutions []]
109      (if (empty? all-possible)
110        (map #(gu/approximate-decimal (* %1 1.0))
111             (cond-> solutions
112               (= (count coeffs) 3) (into (solve-equation coeffs))))
113        (let [testing-root (first all-possible)
114              next-eq (isa-solution? coeffs testing-root)]
115          (recur (or next-eq coeffs)
116                 (if next-eq
117                   (reverse (find-all-possible-solutions next-eq))
118                   (rest all-possible))
119                 (cond-> solutions
120                   next-eq (conj testing-root)))))))
121  
122  (defn solve-equation-newtons-method
123    [coefficients]
124    (let [precision 5]
125      (->> (find-all-possible-solutions coefficients)
126           (map #(gu/approximate-decimal
127                  (newtons-method coefficients
128                                  (cu/derivative coefficients)
129                                  precision %)
130                  precision))
131           distinct)))
132  
133  (defmethod solve-equation :default
134    [coefficients]
135    (let [sesd (solve-equation-synthetic-division coefficients)]
136      (if (seq sesd)
137        sesd
138        (solve-equation-newtons-method coefficients))))