(ns monads
(:require clojure.set))
;; The two monad operations. Both vars are declared dynamic and left
;; unbound here; each demo below supplies a concrete implementation with
;; `binding`, so the generic helpers (lift-inc, m-add, m-div, ...) run
;; unchanged under different monads.
(declare ^:dynamic return)
(declare ^:dynamic bind)
(defn lift-inc
  "Increments `v` and wraps the result with the currently bound `return`."
  [v]
  (-> v inc return))
(defn m-add
  "Binds `mv` and yields its value plus `n`, wrapped with `return`."
  [mv n]
  (bind mv #(return (+ % n))))
(defn m-div
  "Binds `mv` and yields its value divided by `n`, wrapped with `return`.

  A zero divisor yields nil instead of throwing ArithmeticException.
  `mv` is still sequenced through `bind` in that case, so whatever
  effects it carries under the active monad (e.g. state changes, a
  halted continuation, a nil short-circuit) are kept instead of being
  discarded."
  [mv n]
  (bind mv (fn [v]
             (return (when-not (zero? n)
                       (/ v n))))))
;; ===============================
;; Identity monad
;; Monadic values are zero-argument thunks; bind forces the thunk and
;; passes its value straight to f.
(binding [return (fn [v] (fn [] v))
          bind   (fn [mv f] (-> (mv) f))]
  (let [program (-> (return 0)
                    (bind lift-inc)
                    (m-add 10)
                    (m-div 2))]
    (program)))
;; -> 11/2
;; ===============================
;; Maybe monad
;; Monadic values are thunks; a thunk that yields nil means "nothing".
;; bind short-circuits on nil only. `false` is an ordinary value and is
;; still passed on to f (a truthiness test such as `if-let` would
;; wrongly drop it).
(binding [return (fn [v]
                   (fn [] v))
          bind   (fn [mv f]
                   (let [v (mv)]
                     (if (nil? v)
                       (return nil)
                       (f v))))]
  ((-> (return 0)
       (m-add 10)
       (m-div 0)
       (m-add 2))))
;; -> nil
;; ===============================
;; State monad
(defn setf
  "State action that replaces the state with `v` and also yields `v`."
  [v]
  (fn [_old-state]
    (vector v v)))
;; State action that yields the current state and leaves it unchanged.
(def getf
  (fn [state]
    (vector state state)))
(defn set-state
  "Runs `mv`, then stores its value as the new state via `setf`."
  [mv]
  (bind mv #(setf %)))
(defn get-state
  "Runs `mv`, discards its value, and continues with `getf`."
  [mv]
  (bind mv (constantly getf)))
(defn add-from-state
  "Runs `mv`, reads the state with `getf`, and yields the state plus
  `mv`'s value, wrapped with `return`."
  [mv]
  (bind mv
        (fn [v]
          (bind getf #(return (+ % v))))))
;; Monadic values are functions state -> [value new-state].
(binding [return (fn [v]
                   (fn [s] [v s]))
          bind   (fn [mv f]
                   (fn [s]
                     (let [[v s-next] (mv s)
                           next-mv    (f v)]
                       (next-mv s-next))))]
  (let [initial-state 0
        ;; 0 + 10, store 10 as state, then add the stored state
        run-1   ((-> (return 0)
                     (m-add 10)
                     set-state
                     add-from-state)
                 initial-state)
        ;; increments the state and yields the new state
        m-count (bind getf (fn [s] (setf (inc s))))
        ;; increments the value and, as a side effect, bumps the counter
        m-inc   (fn [mv]
                  (bind mv (fn [v]
                             (bind m-count
                                   (fn [_] (return (inc v)))))))
        run-2   ((-> (return 5)
                     (m-add 5)
                     m-inc
                     m-inc)
                 initial-state)]
    [run-1 run-2]))
;; -> [[20 10] [12 2]]
;; ===============================
;; Continuation monad
(defn halt
  "Continuation action that ignores its continuation and returns `x`
  directly, so nothing after it runs."
  [x]
  (fn [_continuation] x))
(defn bounce
  "Continuation action that returns a zero-argument thunk which, when
  called, passes `x` to the continuation `c`. Meant to be driven by
  `trampoline`."
  [x]
  (fn [c]
    #(c x)))
(defn mark
  "Continuation action that ignores `x` and hands back its continuation
  unchanged, so the caller can re-enter the rest of the computation."
  [x]
  identity)
;; Monadic values are functions continuation -> result.
(binding [return (fn [v]
                   (fn [c] (c v)))
          bind   (fn [mv f]
                   (fn [c]
                     (mv (fn [v]
                           (let [next-mv (f v)]
                             (next-mv c))))))]
  (let [;; halt drops the continuation, so (m-add 10) never runs
        halted   (-> (return 21)
                     (bind lift-inc)
                     (bind halt)
                     (m-add 10))
        r-halt   (halted identity)
        ;; bounce returns a thunk; trampoline resumes the computation
        bounced  (-> (return 21)
                     (bind lift-inc)
                     (bind bounce)
                     (m-add 10))
        r-bounce (trampoline (bounced identity))
        ;; mark returns the continuation, which is then applied to each
        ;; input; doall forces it while the binding is still active
        marked   (-> (return 21)
                     (bind mark)
                     (m-add 10)
                     (bind lift-inc))
        r-mark   (doall (map (marked identity) [0 1 2]))]
    [r-halt r-bounce r-mark]))
;; -> [22 32 (11 12 13)]
;; ===============================
;; List monad
(defn lift-id-half-double
  "Passes three results to the variadic `return`, in this order: `x`,
  half of `x`, and twice `x`."
  [x]
  (let [half  (/ x 2)
        twice (* x 2)]
    (return x half twice)))
;; Monadic values are thunks yielding a list of results; bind maps f
;; over every result and concatenates the lists f produces.
(binding [return (fn [& vs]
                   (fn [] (apply list vs)))
          bind   (fn [mv f]
                   ;; `doall` is required: `mapcat` is lazy, and any element
                   ;; realized after this `binding` exits would call `f` (and
                   ;; hence `return`) while the dynamic vars are unbound.
                   (fn [] (doall (mapcat (comp #(%) f) (mv)))))]
  ((-> (return 4 8)
       (bind lift-id-half-double))))
;; -> (4 2 8 8 4 16)
;; ===============================
;; Set monad
;; Monadic values are thunks yielding a set of results; bind applies f
;; to each element and unions the resulting sets.
(binding [return (fn [& vs]
                   (fn [] (set vs)))
          bind   (fn [mv f]
                   (fn []
                     (->> (mv)
                          (map #((f %)))
                          (apply clojure.set/union))))]
  (let [program (-> (return 4 8)
                    (bind lift-id-half-double))]
    (program)))
;; -> #{2 4 8 16}