;;;; Functional Combinators

(define (compose f g)
  ;; Return a procedure computing (f (g arg ...)).  F must take exactly
  ;; one argument; G may take any fixed number, and the composition
  ;; takes the same number of arguments as G.
  (assert (= (get-arity f) 1))
  (let* ((arity-of-g (get-arity g))
         (composed
          (lambda args
            (assert (= (length args) arity-of-g))
            (let ((intermediate (apply g args)))
              (f intermediate)))))
    (restrict-arity composed arity-of-g)))

#|
((compose (lambda (x)
	    (list 'foo x))
	  (lambda (x)
	    (list 'bar x)))
 'z)
;Value: (foo (bar z))
|#

(define (parallel-combine combine t1 t2)
  ;; Return a procedure that hands the same arguments to both T1 and T2
  ;; and passes the two results to COMBINE.  T1 and T2 must have equal
  ;; arity; the result has that arity too.
  (let ((arity (get-arity t1))
        (other-arity (get-arity t2)))
    (assert (= arity other-arity))
    (restrict-arity
     (lambda args
       (assert (= (length args) arity))
       (let ((r1 (apply t1 args))
             (r2 (apply t2 args)))
         (combine r1 r2)))
     arity)))

#|
((parallel-combine list
		   (lambda (x y z)
		     (list 'foo x y z))
		   (lambda (u v w)
		     (list 'bar u v w)))
 'a 'b 'c)
;Value: ((foo a b c) (bar a b c))
|#

(define (spread-combine combine t1 t2)
  ;; Return a procedure of N1+N2 arguments (N1, N2 = arities of T1, T2).
  ;; The first N1 arguments go to T1, the remaining N2 go to T2, and
  ;; COMBINE receives the two results.
  (let ((n1 (get-arity t1))
        (n2 (get-arity t2)))
    (define (the-combination . args)
      (assert (= (length args) (+ n1 n2)))
      (combine
       (apply t1 (list-head args n1))
       ;; T2's arguments begin right after T1's, i.e. at index N1.
       ;; Skipping N2 elements here would only be right when N1 = N2.
       (apply t2 (list-tail args n1))))
    (restrict-arity the-combination (+ n1 n2))))

#|
((spread-combine list
		 (lambda (x y z)
		   (list 'foo x y z))
		 (lambda (u v w)
		   (list 'bar u v w)))
 'a 'b 'c 'd 'e 'f)
;Value: ((foo a b c) (bar d e f))
|#

(define (discard-argument i)
  ;; Return a transformer: given F of arity K, produce a procedure of
  ;; arity K+1 that drops its argument at (0-based) position I and
  ;; applies F to the remaining arguments in their original order.
  (lambda (f)
    (let ((m (+ (get-arity f) 1)))
      (assert (< i m))
      (restrict-arity
       (lambda args
         (assert (= (length args) m))
         (let ((before (list-head args i))
               (after (list-tail args (+ i 1))))
           (apply f (append before after))))
       m))))

#|
(((discard-argument 2)
  (lambda (x y z)
    (list 'foo x y z)))
 'a 'b 'c 'd)
;Value: (foo a b d)
|#

(define (curry-argument i)
  ;; Return a transformer: given F of arity K, produce a procedure of
  ;; arity K-1.  Calling it with those K-1 arguments returns a
  ;; one-argument procedure; its argument is inserted at (0-based)
  ;; position I of the argument list before F is applied.
  (lambda (f)
    (let ((m (- (get-arity f) 1)))
      (define (the-combination . args)
        (assert (= (length args) m))
        (lambda (x)
          (apply f
                 (list-insert args i x))))
      ;; Validate I up front (as discard-argument does) rather than
      ;; failing later inside list-insert.  I = M inserts at the end.
      ;; This also rejects an arity-0 F, for which M would be -1.
      (assert (<= 0 i m))
      (restrict-arity the-combination m))))

#|
((((curry-argument 2)
   (lambda (x y z w)
     (list 'foo x y z w)))
  'a 'b 'c)
 'd)
;Value: (foo a b d c)
|#


(define (permute-arguments . permutation)
  ;; Return a transformer: the resulting procedure reorders its arguments
  ;; so that F's K-th argument is the caller's argument at index
  ;; (list-ref PERMUTATION K).  PERMUTATION must have one entry per
  ;; argument of F.
  (lambda (f)
    (let ((n (get-arity f)))
      (assert (= n (length permutation)))
      (restrict-arity
       (lambda args
         (assert (= n (length args)))
         (apply f (permute permutation args)))
       n))))

#|
(((permute-arguments 1 2 0 3)
  (lambda (x y z w)
    (list 'foo x y z w)))
 'a 'b 'c 'd)
;Value: (foo b c a d)
|#

(define (curry-arguments . target-indices)
  ;; Return a transformer: given F of arity N and M = (length
  ;; TARGET-INDICES), produce a procedure of arity N-M.  Calling it with
  ;; N-M arguments returns a procedure of arity M.  Its arguments are
  ;; inserted one at a time: the J-th goes at index (list-ref
  ;; TARGET-INDICES J) of the list built so far.  When the indices are
  ;; given in increasing order, they are the final argument positions.
  (lambda (f)
    (let ((n (get-arity f))
          (m (length target-indices)))
      (define (the-combination . oargs)
        (assert (= (length oargs) (- n m)))
        ;; Register the arity of the returned procedure.  A bare variadic
        ;; lambda would fail get-arity's fixed-arity assertion, so its
        ;; result could not be passed to the other combinators.
        (restrict-arity
         (lambda nargs
           (assert (= (length nargs) m))
           (let lp ((is target-indices)
                    (args oargs)
                    (nargs nargs))
             (if (null? is)
                 (apply f args)
                 (lp (cdr is)
                     (list-insert args
                                  (car is)
                                  (car nargs))
                     (cdr nargs)))))
         m))
      ;; Cannot curry more arguments than F accepts.
      (assert (<= m n))
      (restrict-arity the-combination
                      (- n m)))))

#|
((((curry-arguments 0 2)
   (lambda (x y z w)
     (list 'foo x y z w)))
  'a 'b)
 'c 'd)
;Value: (foo c a d b)
|#

(define (fan-out-argument . target-indices)
  ;; Return a transformer: given F of arity N and M = (length
  ;; TARGET-INDICES), produce a procedure of arity N-M.  Calling it with
  ;; N-M arguments returns a one-argument procedure.  That procedure's
  ;; argument is inserted at each index of TARGET-INDICES in turn, each
  ;; index relative to the list built so far, and F is then applied.
  (lambda (f)
    (let ((n (get-arity f))
          (m (length target-indices)))
      (restrict-arity
       (lambda oargs
         (assert (= (length oargs) (- n m)))
         (lambda (x)
           ;; DO steps update in parallel, so (car is) below still refers
           ;; to the index being consumed in this iteration.
           (do ((is target-indices (cdr is))
                (args oargs (list-insert args (car is) x)))
               ((null? is) (apply f args)))))
       (- n m)))))

#|
((((fan-out-argument 0 2)
   (lambda (x y z w)
     (list 'foo x y z w)))
  'a 'b)
 'c)
;Value: (foo c a c b)
|#

;;; Arities declared via restrict-arity, keyed by the procedure object
;;; itself (compared with eq?).
(define arity-table
  (make-eq-hash-table))

;;; Record that PROC takes exactly NARGS arguments, then hand PROC back
;;; unchanged so a call can wrap a procedure expression in place.
(define (restrict-arity proc nargs)
  (let ((table arity-table))
    (hash-table/put! table proc nargs))
  proc)

;;; Return the fixed number of arguments PROC accepts.  Use the value
;;; recorded by restrict-arity if there is one; otherwise derive it from
;;; procedure-arity, asserting that the minimum and maximum agree.
(define (get-arity proc)
  (or (hash-table/get arity-table proc #f)
      (let ((a (procedure-arity proc)))
        ;; procedure-arity yields (min . max), with max = #f for a rest
        ;; argument.  eqv? (unlike =) accepts #f, so variadic procedures
        ;; fail this assertion instead of raising a wrong-type error in =.
        (assert (eqv? (car a) (cdr a)))
        (car a))))

(define (list-insert lst index value)
  ;; Return a list equal to LST with VALUE placed at (0-based) position
  ;; INDEX.  INDEX may equal (length LST) to append at the end.  The
  ;; portion of LST after INDEX is shared with the result, not copied.
  (append (list-head lst index)
          (cons value (list-tail lst index))))


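;;; Given a permutation (a list of indices) and a list to be permuted,
;;; return a new list whose K-th element is the element of the input
;;; list at index (list-ref permutation K).  For example, the
;;; permutation (1 2 0) applied to (a b c) yields (b c a).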

(define (permute permutation lst)
  ;; Walk the index list, picking the corresponding element of LST each
  ;; time.
  (let collect ((indices permutation))
    (if (null? indices)
        '()
        (cons (list-ref lst (car indices))
              (collect (cdr indices))))))

