===================
== Martin Trojer ==
Programming Blog

===================

The M Word

clojure monads
(ns monads
  (:require clojure.set))

;; The two monad operations are declared as unbound dynamic vars. Each demo
;; below installs its own implementation with `binding`, so the helper
;; functions written against `return`/`bind` work for every monad.
(declare ^:dynamic return)
(declare ^:dynamic bind)

(defn lift-inc
  "Increments the plain value v and wraps the result using the currently
  bound `return`."
  [v]
  (-> v inc return))

(defn m-add
  "Adds n to the value carried by the monadic value mv.
  The addition runs inside the currently bound `bind`, and the sum is
  re-wrapped with the currently bound `return`."
  [mv n]
  (let [add-n (fn [x] (return (+ x n)))]
    (bind mv add-n)))

(defn m-div
  "Divides the value carried by the monadic value mv by n.
  When n is zero, mv is not consulted at all and `(return nil)` is produced
  instead; otherwise the quotient is re-wrapped with `return`."
  [mv n]
  (if-not (zero? n)
    (bind mv (fn [dividend] (return (/ dividend n))))
    (return nil)))

;; ===============================
;; Identity monad

;; A monadic value here is a zero-argument thunk holding the plain value;
;; bind forces the thunk and passes the value straight to the next step.
(binding [return (fn [x] (fn [] x))
          bind   (fn [m step] (step (m)))]
  (let [pipeline (m-div (m-add (bind (return 0) lift-inc) 10) 2)]
    ;; Calling the final thunk extracts the computed value.
    (pipeline)))
;; -> 11/2

;; ===============================
;; Maybe monad

;; A monadic value is a zero-argument thunk; a thunk yielding nil stands for
;; "no value". Once a step produces nil, every later step is skipped and nil
;; is propagated to the end of the chain.
(binding [return (fn [v]
                   (fn [] v))
          bind (fn [mv f]
                 ;; Only nil means "no value". Testing with nil? (rather than
                 ;; the truthiness test of if-let) lets a legitimate `false`
                 ;; result flow on to f instead of being collapsed to nil.
                 (let [v (mv)]
                   (if (nil? v)
                     (return nil)
                     (f v))))]
  ;; m-div by zero yields (return nil), so the trailing m-add is skipped.
  ((->
    (return 0)
    (m-add 10)
    (m-div 0)
    (m-add 2))))
;; -> nil

;; ===============================
;; State monad

(defn setf
  "Returns a one-argument function that ignores its argument and yields the
  pair [v v]."
  [v]
  (fn [_ignored]
    (vector v v)))
(defn getf
  "Takes s and yields the pair [s s]."
  [s]
  (vector s s))

(defn set-state
  "Binds the monadic value mv to `setf`: the value carried by mv is handed
  to `setf`, whose resulting function yields that value in both positions
  of its result pair."
  [mv]
  (bind mv (fn [v] (setf v))))

(defn get-state
  "Binds the monadic value mv to a step that discards mv's value and
  continues with `getf`."
  [mv]
  (let [discard-value (fn [_] getf)]
    (bind mv discard-value)))

(defn add-from-state
  "Binds the monadic value mv to a step that reads a second value through
  `getf` and returns the sum of that value and the one carried by mv."
  [mv]
  (letfn [(add-state-to [v]
            (bind getf
                  (fn [state] (return (+ state v)))))]
    (bind mv add-state-to)))

;; A monadic value is a function from a state to a [value new-state] pair.
;; bind runs the first computation, feeds its value to the next step, and
;; runs the resulting computation against the updated state.
(binding [return (fn [value]
                   (fn [state] [value state]))
          bind (fn [m step]
                 (fn [state]
                   (let [[value next-state] (m state)
                         next-m (step value)]
                     (next-m next-state))))]
  (let [;; Compute 10, store it as the state, then add the state to it.
        set-then-add (-> (return 0)
                         (m-add 10)
                         set-state
                         add-from-state)
        ;; Reads the state and replaces it with its successor.
        tick (bind getf (fn [n] (setf (inc n))))
        ;; Increments the carried value and bumps the state via tick.
        counted-inc (fn [m]
                      (bind m (fn [value]
                                (bind tick
                                      (fn [_] (return (inc value)))))))
        counted (-> (return 5)
                    (m-add 5)
                    counted-inc
                    counted-inc)]
    ;; Both computations are run with an initial state of 0.
    [(set-then-add 0)
     (counted 0)]))
;; -> [[20 10] [12 2]]

;; ===============================
;; Continuation monad

(defn halt
  "Returns a function that takes a continuation, never calls it, and yields
  x directly."
  [x]
  (fn [_k] x))

(defn bounce
  "Returns a function that takes a continuation k and, instead of calling
  it right away, yields a zero-argument thunk that calls k with x."
  [x]
  (fn [k]
    #(k x)))

(defn mark
  "Ignores x and returns a function that hands back the continuation it is
  given, uncalled."
  [x]
  (fn [k]
    (identity k)))

;; A monadic value is a function that takes a continuation k. return passes
;; its value to k; bind runs the first computation with a continuation that
;; feeds the value to the next step and resumes with the original k.
(binding [return (fn [value]
                   (fn [k] (k value)))
          bind (fn [m step]
                 (fn [k]
                   (m (fn [value]
                        (let [next-m (step value)]
                          (next-m k))))))]
  (let [;; halt drops the rest of the chain, so the m-add never runs.
        halted  (-> (return 21)
                    (bind lift-inc)
                    (bind halt)
                    (m-add 10))
        ;; bounce hands back a thunk; trampoline below resumes the chain.
        bounced (-> (return 21)
                    (bind lift-inc)
                    (bind bounce)
                    (m-add 10))
        ;; mark hands back the rest of the chain as a reusable function.
        marked  (-> (return 21)
                    (bind mark)
                    (m-add 10)
                    (bind lift-inc))]
    [(halted identity)
     (trampoline (bounced identity))
     ;; doall forces the mapping while return/bind are still bound.
     (doall (map (marked identity)
                 [0 1 2]))]))
;; -> [22 32 (11 12 13)]

;; ===============================
;; List monad

(defn lift-id-half-double
  "Passes three values to the currently bound `return`: x itself, x divided
  by 2, and x multiplied by 2, in that order."
  [x]
  (let [half   (/ x 2)
        double (* x 2)]
    (return x half double)))

;; A monadic value is a zero-argument thunk producing a list of values.
;; return wraps all of its arguments; bind applies f to every element, runs
;; each resulting thunk, and concatenates the lists.
(binding [return (fn [& v]
                   (fn [] (apply list v)))
          bind (fn [mv f]
                 (fn []
                   ;; mapcat is lazy. Without doall, elements forced after
                   ;; this binding form has exited would call f, and through
                   ;; it the then-unbound `return`, which throws. Realising
                   ;; the seq here keeps every call inside the binding scope.
                   (doall (mapcat (fn [x] ((f x)))
                                  (mv)))))]
  ((->
    (return 4 8)
    (bind lift-id-half-double))))
;; -> (4 2 8 8 4 16)

;; ===============================
;; Set monad

;; A monadic value is a zero-argument thunk producing a set of values.
;; return collects its arguments into a set; bind applies f to every member,
;; runs each resulting thunk, and unions the sets.
(binding [return (fn [& members]
                   (fn [] (apply hash-set members)))
          bind (fn [m step]
                 (fn []
                   (let [run-step (fn [x] ((step x)))
                         result-sets (map run-step (m))]
                     (apply clojure.set/union result-sets))))]
  (let [pipeline (-> (return 4 8)
                     (bind lift-id-half-double))]
    (pipeline)))
;; -> #{2 4 8 16}