;;;======================================================================
;;; NLP code for use with Natural Language Understanding, 2nd ed.
;;; Copyright (C) 1994 James F. Allen
;;;
;;; This program is free software; you can redistribute it and/or modify
;;; it under the terms of the GNU General Public License as published by
;;; the Free Software Foundation; either version 2, or (at your option)
;;; any later version.
;;;
;;; This program is distributed in the hope that it will be useful,
;;; but WITHOUT ANY WARRANTY; without even the implied warranty of
;;; MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
;;; GNU General Public License for more details.
;;;
;;; You should have received a copy of the GNU General Public License
;;; along with this program; if not, write to the Free Software
;;; Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
;;;======================================================================


;;  This file contains the code for computing and maintaining the probability tables
;;   needed for the stochastic parser.

(let (
      ;;   These are the four tables needed to drive the parser

      (*CFruleprobs*)       ;;  context free rule probabilities indexed by rule id
      (*CSruleprobs*)       ;;  context sensitive rule probabilities, nested: rule id -> word -> prob
      (*bigrams*)           ;;  bigram probabilities, nested: category -> next category -> prob
      (*lexGenProbs*)       ;;  lexical generation probabilities indexed by lexical entry id

      ;;   These are temporary tables used in the corpus analysis

      (*lexCatCount*)       ;;  temporary count for each lexical category
      (*bigramtotals*)      ;;  temporary count of bigrams, nested: category -> next category -> count
      (*CSruletotals*)      ;;  temporary count of CSrules, nested: rule id -> word -> count
      (*CFruletotals*)      ;;  temporary count of CFrules indexed by rule id
      (*lexGentotals*)      ;;  temporary count of lexical entry use indexed by entry id
      )

  ;*****************************
  ;;INITIALIZATION

  ;;  INIT-TABLES creates fresh, empty probability and count tables sized from
  ;;   the current grammar and lexicon.  If either is empty it prints a warning,
  ;;   leaves the existing tables untouched and returns NIL (the value of FORMAT);
  ;;   otherwise it returns the non-NIL bigram-totals table created last.
  (defun init-tables nil
    (let ((gsize (length (get-grammar)))
          (lsize (length (defined-words))))
      ;; ZEROP rather than EQ: EQ is not guaranteed to work on numbers.
      (cond ((zerop gsize) (Format t "Warning: No grammar has been specified"))
            ((zerop lsize) (Format t "Warning: No lexicon has been specified"))
            (t
             (setq *CSruleProbs* (make-hash-table :size gsize))
             (setq *CFruleProbs* (make-hash-table :size gsize))
             (setq *bigrams* (make-hash-table :size 20))
             (setq *LexGenProbs* (make-hash-table :size lsize))

             (setq *lexcatcount* (make-hash-table :size 20))
             (setq *LexGentotals* (make-hash-table :size 20))
             (setq *CSruletotals* (make-hash-table :size gsize))
             (setq *CFruletotals* (make-hash-table :size gsize))
             (setq *bigramtotals* (make-hash-table :size 20))))))

  ;;  Raw access to the rule count tables
  (defun CSruletotals ()
    *CSruletotals*)

  (defun CFruletotals ()
    *CFruletotals*)

  ;;  Count one more occurrence of lexical category CAT
  (defun updateCatCount (cat)
    (updateCount cat *lexcatcount* 1))

  ;;  Times CAT has been counted; an unseen category reports 1
  ;;   (presumably so that callers dividing by it never divide by zero).
  (defun getCatCount (cat)
    (gethash cat *lexCatCount* 1))

  ;;  Count one use of RULE on a path whose leftmost word is WORD
  (defun updateCSruleCount (rule word)
    (update2count rule word *CSruletotals* 1))

  ;;  Count one use of RULE
  (defun updateCFruleCount (rule)
    (updateCount rule *CFruletotals* 1))

  ;;  Count one occurrence of category CAT2 immediately following CAT1
  (defun updateBigramCount (cat1 cat2)
    (update2count cat1 cat2 *bigramtotals* 1))

  (defun getBigramCounts ()
    *bigramtotals*)

  ;;  Store VAL under TABLE[KEY1][KEY2], creating the inner table on demand.
  ;;   Unlike UPDATE2COUNT this overwrites an existing value instead of adding
  ;;   to it.  Returns VAL.
  (defun set2hash (key1 key2 table val)
    (let ((inner (gethash key1 table)))
      (unless inner
        (setq inner (make-hash-table :size 20))
        (setf (gethash key1 table) inner))
      (setf (gethash key2 inner) val)))

  ;;  Set P(CAT2 | CAT1).  (Previously this added VAL to any existing value,
  ;;   so recomputing probabilities without INIT-TABLES accumulated them.)
  (defun setBigramProb (cat1 cat2 val)
    (set2hash cat1 cat2 *bigrams* val))

  ;;  P(CAT2 | CAT1), or .0001 for an unseen pair
  (defun getBigramProb (cat1 cat2)
    (let ((val (get2hash cat1 cat2 *bigrams*)))
      (if val val .0001)))

  ;;  Collecting the lexical statistics

  ;;  Count one use of the lexical entry ID
  (defun updateLexCount (id)
    (updatecount id *LexGentotals* 1))

  ;;  Times entry ID was used; unseen entries report .5 (presumably smoothing)
  (defun getLexGenTotal (id)
    (gethash id *LexGenTotals* .5))

  (defun setLexGenProb (id val)
    (setf (gethash id *LexGenProbs*) val))

  ;;  Lexical generation probability of entry ID, or .0001 if unknown
  (defun getLexGenProb (id)
    (let ((val (gethash id *LexGenProbs*)))
      (if val val .0001)))

  (defun setCFruleprob (rule-id val)
    (setf (gethash rule-id *CFruleProbs*) val))

  ;;  CF probability of rule ID, or .0001 if unknown
  (defun getCFruleProb (id)
    (gethash id *CFruleProbs* .0001))

  ;;  Set the CS probability of RULE-ID given first word WORD (overwrites)
  (defun setCSruleProb (rule-id word val)
    (set2hash rule-id word *CSruleProbs* val))

  (defun getCSruleProb (rule-id word)
    (let ((val (get2hash rule-id word *CSruleProbs*)))
      (if val val .0001)))    ;; return default of .0001 if not found

  ;;  True when all four parser tables exist
  (defun prob-tables-defined nil
    (and (hash-table-p *CFruleProbs*)
         (hash-table-p *CSruleProbs*)
         (hash-table-p *LexGenProbs*)
         (hash-table-p *bigrams*)))

  ;;   These functions are used to save the probabilities

  ;;  DUMP-STATS prints, to standard output, forms that rebuild the four
  ;;   parser tables when read back in; each form calls one of the setter
  ;;   functions below on a rebuilt hash table.
  (defun dump-stats nil
    (save2hashtable 'SetBigramProbs *bigrams*)
    (save2hashtable 'SetCSruleprobs *CSruleprobs*)
    (savehashtable 'SetCFruleprobs *CFruleprobs*)
    (savehashtable 'SetlexGenProbs *LexGenProbs*))

  (defun setBigrams (table)
    (setq *bigrams* table))

  ;;  Name used in the forms written by DUMP-STATS; without it a dumped
  ;;   statistics file could not be loaded back.
  (defun setBigramProbs (table)
    (setq *bigrams* table))

  (defun setCSruleprobs (table)
    (setq *CSruleProbs* table))

  (defun setCFruleProbs (table)
    (setq *CFruleProbs* table))

  (defun setlexGenProbs (table)
    (setq *LexGenProbs* table))

)   ;; end scope of statistical variables

;; =========================================================================================
;;  Aux functions to manipulate hash tables

;;  update a simple hash table
(defun updateCount (val table amt)
  ;; Add AMT to TABLE[VAL] (missing keys start at 0); returns the new total.
  (incf (gethash val table 0) amt))

;;  Maintaining nested hash tables

(defun get2hash (val1 val2 table)
  ;; Look up TABLE[VAL1][VAL2]; NIL when the outer key is absent.
  (let ((inner (gethash val1 table)))
    (when inner
      (gethash val2 inner))))

;;  update a nested hashtable
(defun update2count (val1 val2 table amt)
  ;; Add AMT to TABLE[VAL1][VAL2], creating the inner table (size 20) the
  ;; first time VAL1 is seen.  Returns the new total.
  (let ((inner (or (gethash val1 table)
                   (setf (gethash val1 table) (make-hash-table :size 20)))))
    (updateCount val2 inner amt)))

;; =========================================================================================

;;  RECORDING THE CORPUS
;;  When a sentence has been parsed, this function writes out the desired parse
;;   in a form that can be read in later.


;; WRITE-TREE prints out the parse tree rooted at the specified entry
;;  The format of each constituent is
;;    (category rule-id list-of-feature-values subconstituent1 ... subconstituentn)

(defun write-tree (entry-name)
  (let ((root (get-entry-by-name entry-name)))
    ;; start at depth 0 with no inherited bindings
    (dump-tree 0 root nil)))

;;  DUMP-TREE prints ENTRY's constituent as "(cat rule-id features" on a new
;;   line (after PRINT-BLANKS with PREFIX, the nesting depth), then each
;;   subconstituent one level deeper, then the closing ")".
;;   BINDINGS come from the parent; they are merged with the results of
;;   CONSTIT-MATCH on this entry's RHS and SUBST-IN'd into each subentry
;;   (presumably variable bindings -- confirm against MERGE-LISTS / SUBST-IN).
;;   Returns NIL.
(defun dump-tree (prefix entry bindings)
  (let* ((constit (entry-constit entry))
         (subconstitnames (getsubconstitnames 1 constit))
         (subentries (mapcar #'get-entry-by-name subconstitnames))
         (subconstits (mapcar #'entry-constit subentries))
         (bndgs (merge-lists
                 (cons bindings (mapcar #'constit-match (entry-rhs entry) subconstits)))))
    (Format t "~%")
    (print-blanks prefix)
    (Format t "(~S ~S ~S" (constit-cat constit) (entry-rule-id entry) (constit-feats constit))
    ;; Substitute into every subentry first, then print each one.  (The old
    ;; code also zipped in SUBCONSTITNAMES, whose variable was never used.)
    (mapc #'(lambda (e) (dump-tree (1+ prefix) e bndgs))
          (mapcar #'(lambda (e) (subst-in e bndgs)) subentries))
    (Format t ")")))

;; =========================================================================================
;;  DATA ANALYSIS

;;  INTERPRET-TREE traverses a tree and collects counts on rules and lexical
;;  items used

(defun interpret-tree (tree)
  ;; Bracket the sentence with the START and END pseudo-categories so the
  ;; first and last words also contribute bigram counts.
  (init-lastcat)
  (updateCatCount 'START)
  (traverse-tree tree nil)
  (let ((final-cat (lastcat)))
    (updateBigramCount final-cat 'END)))

;;  TREE has the form (category rule-id features subconstituent...).
;;   PATH-TO-ROOT lists the ids of the rules whose leftmost descendant chain
;;   leads to this node; it is handed to the leftmost child only.
(defun traverse-tree (tree path-to-root)
  (let ((cat (nth 0 tree))
        (id (nth 1 tree))
        (constit (nth 2 tree))
        (subconstits (nthcdr 3 tree)))
    (if (and (null subconstits) (member cat (getLexicalCats)))
        ;; a leaf whose category is lexical: count the word itself
        (update-lexical-counts cat (get-fvalue constit 'lex) id path-to-root)
        (progn
          (updateCFruleCount id)
          ;; the leftmost child inherits the path, extended with this rule
          (let ((leftmost (first subconstits)))
            (when leftmost
              (traverse-tree leftmost (cons id path-to-root))))
          ;; the remaining children start with an empty path
          (mapcar #'(lambda (child) (traverse-tree child nil))
                  (rest subconstits))))))

(defun update-lexical-counts (cat word word-id path-to-root)
  ;; category and lexical-entry counts
  (updateCatCount cat)
  (updateLexCount word-id)
  ;; every rule on the leftmost path gets a context-sensitive count for WORD
  (dolist (rule path-to-root)
    (updateCSruleCount rule word))
  ;; bigram with the previous word's category, then remember this one
  (updateBigramCount (lastcat) cat)
  (updateLastCat cat))

;;  Find the element of (GET-LEXICON) whose second item is the lexical
;;   entry with id ID.  NOTE(review): this treats GET-LEXICON as a list of
;;   (word entry) pairs, while COMPUTELEXPROBS MAPHASHes over it -- confirm
;;   which representation is current.
(defun find-lex-entry-by-id (id)
  (find id (get-lexicon)
        :test #'eq
        :key #'(lambda (pair) (lex-entry-id (second pair)))))

;;  The grammar rule whose RULE-ID is ID, or NIL
(defun find-rule-by-id (id)
  (find id (get-grammar)
        :test #'eq
        :key #'(lambda (rule) (rule-id rule))))

;;  Functions to maintain the last cat to count bigrams

(let ((*lastcat* nil))
  ;; *LASTCAT* holds the category of the most recently counted word;
  ;; INIT-LASTCAT resets it to the START pseudo-category for a new sentence.
  (defun init-lastcat ()
    (setf *lastcat* 'START))
  (defun lastcat ()
    *lastcat*)
  (defun updateLastCat (cat)
    (setf *lastcat* cat)))


;;  DOCORPUS estimates all the probability tables from C, a list of parse
;;   trees in the format written by WRITE-TREE.  The tables are reset first.
;;   If INIT-TABLES reports that no grammar or lexicon is loaded (it returns
;;   NIL after printing a warning), nothing else is done and NIL is returned.
;;   Previously processing continued anyway, either failing on uninitialized
;;   tables or silently adding into stale totals.
(defun docorpus (c)
  (when (init-tables)
    (mapc #'interpret-tree c)
    (computeLexprobs)
    (computeruleprobs)))

;******************************************************************
;*
;*  computing the probabilities
;*
;*******************

;;  This computes the context-independent lexical generation probabilities and the bigrams.
;;   LexGenProbs are computed for each lexical id as #-times-id-seen / #-times-its-category-seen
;;   Bigrams are computed for two cats c1 and c2 as #-times-c2-follows-c1 / #-times-c1-seen

(defun computeLexprobs nil
  ;;  Lexical generation probabilities: for every lexical entry,
  ;;   P(entry | its category) = #times-entry-seen / #times-category-seen.
  ;;   NOTE(review): GET-LEXICON is used here as a hash table word -> entries,
  ;;   but FIND-LEX-ENTRY-BY-ID treats it as a list -- confirm.
  (maphash #'(lambda (word lexentries)
               (declare (ignore word))
               (mapc #'(lambda (entry)
                         (let ((id (lex-entry-id entry)))
                           (setLexGenProb id
                                          (/ (getLexGenTotal id)
                                             (getCatCount (constit-cat (lex-entry-constit entry)))))))
                     lexentries))
           (get-lexicon))

  ;;  Bigram probabilities from the bigram totals:
  ;;   P(cat2 | cat1) = #times-cat2-follows-cat1 / #times-cat1-seen
  (maphash #'(lambda (cat1 row)
               (let ((cat1-count (getCatCount cat1)))   ; same for the whole row
                 (maphash #'(lambda (cat2 count)
                              (setBigramProb cat1 cat2 (/ count cat1-count)))
                          row)))
           (getBigramCounts)))
  
;;  This computes the CF and CS rule probabilities from the totals found
;;   by analyzing the corpus.
;;   CFruleProbs are computed as #-times-rule-seen / #-times-mother-cat-seen
;;     (e.g., (count of NP -> ART N used) / (count of all NPs))
;;   CSruleProbs are computed relative to the constituent's first word w:
;;     (#-times-rule-seen with first word w) / (#-times-mother-cat-seen with first word w)

(defun computeruleprobs nil
  ;;  Context-free part: P(rule | mother) = uses of the rule divided by the
  ;;   uses of every rule with the same mother category.
  (let ((mother-totals (make-hash-table :size 10)))
    ;; pass 1: total rule uses per mother category
    (maphash #'(lambda (rule-id n)
                 (updateCount (get-cat-from-id rule-id) mother-totals n))
             (CFruleTotals))
    ;; pass 2: normalize each rule by its mother's total
    (maphash #'(lambda (rule-id n)
                 (let ((mother-total (gethash (get-cat-from-id rule-id) mother-totals)))
                   (setCFruleProb rule-id (/ n mother-total))))
             (CFruleTotals)))
  ;;  Context-sensitive part: the same idea, but totals are keyed on both
  ;;   the mother category and the word.
  (let ((mother-word-totals (make-hash-table :size 10)))
    ;; pass 1: total uses per (mother category, word)
    (maphash #'(lambda (rule-id per-word)
                 (let ((mother (get-cat-from-id rule-id)))
                   (maphash #'(lambda (word n)
                                (update2Count mother word mother-word-totals n))
                            per-word)))
             (CSruleTotals))
    ;; pass 2: normalize each (rule, word) count by its (mother, word) total
    (maphash #'(lambda (rule-id per-word)
                 (let ((mother (get-cat-from-id rule-id)))
                   (maphash #'(lambda (word n)
                                (let ((mother-word-total
                                        (get2hash mother word mother-word-totals)))
                                  (setCSruleProb rule-id word (/ n mother-word-total))))
                            per-word)))
             (CSruleTotals))))
    
  
;; find the category of the mother constituent in the rule
;;   note special cases for the two gap insertion rules                 

(defun get-cat-from-id (id)
  ;; The two gap-introduction rules are handled specially; every other id
  ;; is looked up in the grammar and its LHS category returned.
  (case id
    (NP-GAP-INTRO 'NP)
    (GAP-INTRO 'PP)
    (otherwise (constit-cat (rule-lhs (find-rule-by-id id))))))


;************************************************************
;*
;*    COMPUTING CONTEXT DEPENDENT LEXICAL PROBABILITIES FOR A SENTENCE
;*
;************************************************************

;; This applies the FORWARD algorithm to the possible lexical
;;   constituents, modifying their probabilities (which were originally
;;   set to the CF probability). Unlike the algorithm in the book, which uses
;;   an N by T array, where N is the number of lexical categories and T the
;;   number of words, this uses a one-dimensional array of size T, where
;;   each element is a list of the possible lexical entries at that position.
;;   This allows for shorter lists, and also allows a word to have two distinct
;;   interpretations of the same category. The answers, normalized so that the
;;   entries at each position sum to 1, are stored in the prob field of each
;;   lexical entry.

;;  GETCSLEXICALENTRIES looks up every word of sentence S (a list of words)
;;   and sets each lexical entry's PROB field to its forward probability,
;;   normalized per word position.  Returns all entries in sentence order as
;;   one list, NIL for an empty sentence, or NIL after printing an error when
;;   the probability tables are not defined.
(defun GetCSlexicalEntries (s)
  (cond
    ;; Nothing can be computed without the statistics tables.
    ((not (prob-tables-defined))
     (Format t "~%~%Error: probability tables are not defined~%~%"))
    ;; An empty sentence has no entries (the old code stored into element 0
    ;; of a zero-length array here).
    ((null s) nil)
    (t
     (let* ((L (length s))
            ;; prob[tt] = list of lexical entries for the word at position tt
            (prob (make-array L)))
       ;; Initialization: P(cat | START) * P(entry | cat)
       (setf (aref prob 0) (lookupword (first s) 0))
       (mapc #'(lambda (entry)
                 (setf (entry-prob entry)
                       (* (getBigramProb 'START (constit-cat (entry-constit entry)))
                          (getLexGenProb (entry-rule-id entry)))))
             (aref prob 0))

       ;; Sweeping forward.  Walk the words directly rather than calling NTH
       ;; on every step, which made the sweep quadratic in sentence length.
       (loop for word in (rest s)
             for tt from 1
             do (let ((prevEntries (aref prob (1- tt))))
                  (setf (aref prob tt) (lookupword word tt))
                  ;; For each new entry at position tt, sum over the entries at
                  ;; position tt-1 of prev-prob * P(cat | prev-cat), then
                  ;; multiply by P(entry | cat).
                  (mapc #'(lambda (entry)
                            (let ((sum 0)
                                  (cat (constit-cat (entry-constit entry))))
                              (dolist (prev prevEntries)
                                (setq sum (+ sum
                                             (* (entry-prob prev)
                                                (getBigramProb (constit-cat (entry-constit prev))
                                                               cat)))))
                              (setf (entry-prob entry)
                                    (* sum (getLexGenProb (entry-rule-id entry))))))
                        (aref prob tt))))

       ;; Normalize so the probabilities at each position sum to 1;
       ;; positions whose total is 0 are left as they are.
       (dotimes (tt L)
         (let ((sum 0))
           (dolist (e (aref prob tt))
             (setq sum (+ sum (entry-prob e))))
           (when (> sum 0)
             (dolist (e (aref prob tt))
               (setf (entry-prob e) (/ (entry-prob e) sum))))))

       ;; Return the entries in sentence order (linear, unlike repeated APPEND).
       (loop for tt below L append (aref prob tt))))))

;; =========================================================================================
;;
;;  I/O ROUTINES

;;  Print, to standard output, a form (FN (BuildHashTable '((key value) ...)))
;;   that rebuilds TABLE when read back in.  A NIL TABLE prints an empty list.
(defun saveHashTable (fn table)
  (flet ((write-entry (key value)
           (Format t "~%         (~S ~s)" key value)))
    (Format t "~%~%(~s ~%    (BuildHashTable '(" fn)
    (when table
      (maphash #'write-entry table))
    (Format t "~%)))")))

;;  Print, to standard output, a form
;;   (FN (Build2HashTable '((key1 ((key2 value) ...)) ...)))
;;   that rebuilds the nested TABLE when read back in.
(defun save2HashTable (fn table)
  (flet ((write-row (key inner)
           (Format t "~%         (~S (" key)
           (maphash #'(lambda (key2 value)
                        (Format t "(~S ~S)" key2 value))
                    inner)
           (Format t "))")))
    (Format t "~%~%(~s ~%    (Build2HashTable '(" fn)
    (when table
      (maphash #'write-row table))
    (Format t "~%)))")))

 (defun BuildHashTable (vals)
   ;;  Build a hash table from VALS, a list of (key value) pairs such as the
   ;;   ones SAVEHASHTABLE writes; a later duplicate key wins.
   (let ((table (make-hash-table :size (length vals))))
     (dolist (pair vals table)
       (setf (gethash (first pair) table) (second pair)))))
 
(defun Build2HashTable (vals)
  ;;  Rebuild a nested hash table from VALS, a list of rows
  ;;   (key1 ((key2 value) ...)) such as SAVE2HASHTABLE writes.  Values go in
  ;;   through UPDATE2COUNT, so a repeated (key1 key2) pair is summed.
  (let ((table (make-hash-table :size (length vals))))
    (dolist (row vals table)
      (let ((key1 (first row)))
        (dolist (cell (second row))
          (update2Count key1 (first cell) table (second cell)))))))

