"""
 huffman.py 

 generate huffman codes from probabilities

 history
  Feb 2014   original
  Dec 2019   added entropy function; upgraded to python3

 tested with python 3.6

 Jim Mahoney | cs.marlboro.college | MIT License
"""
from heapq import heappush, heappop
from numpy import log2

def get_probabilities(symbols):
    """ Return a dict of {symbol:probability} given a collection of symbols.

        Each probability is the symbol's relative frequency in symbols.
        Any iterable of hashable symbols is accepted, including one-shot
        iterators such as generators; an empty input gives an empty dict.
        Keys appear in the order the symbols were first seen.
        >>> p = get_probabilities(['a', 'b', 'a', 'c'])
        >>> [(x, p[x]) for x in sorted(p.keys())]
        [('a', 0.5), ('b', 0.25), ('c', 0.25)]
    """
    from collections import Counter
    counts = Counter(symbols)  # {symbol:count}, first-seen order
    # Total from the counts rather than len(symbols), so iterators work too.
    n = float(sum(counts.values()))
    return {s: c / n for (s, c) in counts.items()}

def entropy(probabilities):
    """ Return the information entropy, in bits, of an independent
        probability model.

        probabilities is a dict of {symbol:probability}; the result is
        sum(p * log2(1/p)) over its values.
        Zero-probability symbols are skipped. Their term would otherwise divide
        by zero, and its limiting value as p -> 0 is 0.
        >>> float(entropy({'a': 0.5, 'b': 0.5}))
        1.0
    """
    return sum(p * log2(1 / p) for p in probabilities.values() if p != 0)

class PriorityQueue:
    """ A min priority queue is a data structure which can
          * add values (push),
          * find (peek) the smallest, and
          * remove (pop) the smallest.
        Values are ordered by sortkey(value), which is the value itself by
        default. Heap entries are (key, value) tuples, so entries with equal
        keys fall back to comparing the values.
        >>> pq = PriorityQueue([5, 3, 10, 7])
        >>> (pq.pop(), pq.pop())
        (3, 5)
        >>> pq.push(9)
        >>> (pq.pop(), pq.pop())
        (7, 9)
    """
    def __init__(self, values=(), sortkey=lambda x: x):
        # The default is an immutable () rather than a shared [].
        # Any iterable of initial values is accepted.
        self.sortkey = sortkey
        self.data = []  # heapq-ordered list of (sortkey(value), value)
        for value in values:
            self.push(value)
    def peek(self):
        """ Return the smallest value without removing it.
            Raises IndexError if the queue is empty. """
        return self.data[0][1]
    def push(self, value):
        """ Add value to the queue. """
        heappush(self.data, (self.sortkey(value), value))
    def pop(self):
        """ Remove and return the smallest value.
            Raises IndexError if the queue is empty. """
        (_key, item) = heappop(self.data)
        return item
    def __len__(self):
        return len(self.data)
    def values(self):
        """ Return a new list of the queued values in heap (not sorted) order. """
        return [value for (_key, value) in self.data]

class BinaryTree:
    """ A node in a binary tree, or equivalently the root of a binary subtree.

        name   - label used in graphviz output; defaults to str(data)
        data   - payload stored at the node; nodes compare with < by data
        parent, left, right - neighbouring BinaryTree nodes, or None
    """
    def __init__(self, name='', data=0,
                 parent=None, left=None, right=None):
        self.name = name if name != '' else str(data)
        self.data = data
        self.parent = parent
        self.set_children(left, right)
    def set_children(self, left, right):
        """ Set this node's children (either may be None) and point each
            non-None child's parent back at this node. """
        self.left = left
        self.right = right
        for child in (left, right):
            if child is not None:
                child.parent = self
    def __lt__(self, other):
        return self.data < other.data
    def graphviz(self, labels=False):
        """ Return a graphviz (dot) text description of the tree rooted here.

            With labels=False, edges use each node's name as its dot id.
            With labels=True, edges use id(node) as the dot id, and each node
            gets a label showing its name and data.
        """
        # for a directed graph, use 'digraph {' and '->' instead of '--'
        result = 'graph {\n'
        if labels:
            result += self._graphviz_labels()
        result += self._graphviz_subtree(use_ids=labels)
        return result + '}\n'
    def _graphviz_labels(self):
        """ Return one line per node in this subtree setting a label with
            its name and data. """
        try:
            # Two significant digits, e.g. 0.125 -> 0.12, for float data.
            data_text = '{:0.2}'.format(self.data)
        except (ValueError, TypeError):
            # ints (including the default data=0) and None reject a
            # precision spec; show them unformatted instead of crashing.
            data_text = str(self.data)
        result = '  {} [label="{} ({})"];\n'.format(
            id(self), self.name, data_text)
        for child in (self.left, self.right):
            if child:
                result += child._graphviz_labels()
        return result
    def _graphviz_subtree(self, use_ids):
        """ Recursively return the edge lines for this node and those below,
            left subtree before right. """
        result = ''
        for child in (self.left, self.right):
            if child:
                result += '  {} -- {};\n'.format(
                    self.name if not use_ids else id(self),
                    child.name if not use_ids else id(child))
                result += child._graphviz_subtree(use_ids)
        return result

class Huffman(PriorityQueue):
    """ A class which creates a Huffman code from
        a dict of symbol names and probabilities.

        Attributes set by the constructor:
          probabilities - the {symbol:probability} dict passed in
          symbols       - the symbols, sorted
          leaves        - the terminal BinaryTree nodes, one per symbol
          huffman_tree  - root BinaryTree node (None when there are no symbols)
          huffman_code  - {symbol:codeword} dict of '0'/'1' strings
        A lone symbol is given the one-bit codeword '0'. Any symbol name,
        including '*', may be used.
        >>> h = Huffman({'00': 0.6, '01': 0.2, '10': 0.1, '11': 0.1})
        >>> [(s, h.huffman_code[s]) for s in sorted(h.huffman_code.keys())]
        [('00', '1'), ('01', '00'), ('10', '010'), ('11', '011')]
        >>> h.mean_code_length()
        1.6
        >>> Huffman({'x': 1.0}).huffman_code
        {'x': '0'}
        >>> sorted(Huffman({'*': 0.5, 'a': 0.5}).huffman_code.items())
        [('*', '0'), ('a', '1')]
    """
    def __init__(self, symbol_probabilities):
        self.probabilities = symbol_probabilities
        self.symbols = sorted(self.probabilities.keys())
        self.huffman_tree = None   # root of huffman binary tree
        self.huffman_code = {}     # {symbol:code} dictionary
        PriorityQueue.__init__(self, sortkey=lambda x: x.data)
        for (symbol, probability) in self.probabilities.items():
            self.push(BinaryTree(name=symbol, data=probability))
        self.leaves = self.values() # save a copy of list of terminal nodes
        self._build_huffman_tree()
        self._build_huffman_code()
    def _build_huffman_tree(self):
        """ Build the huffman tree and store its root in self.huffman_tree. """
        # Repeatedly combine the two least probable nodes that are still in
        # the queue into a new parent whose probability is their sum, until
        # only one node remains.
        while len(self) > 1:
            item1 = self.pop()
            item2 = self.pop()
            data = item1.data + item2.data # probability
            self.push(BinaryTree(name='*', data=data,
                                 left=item1, right=item2))
        # The survivor is the root. It is the last merged node, or the lone
        # leaf when there is one symbol. For no symbols it is None.
        self.huffman_tree = self.peek() if len(self) else None
    def _build_huffman_code(self, node=None, code=''):
        """ Build a dictionary of {symbol:codeword} in self.huffman_code.

            Walks the subtree at node (default: the whole tree), appending
            '0' for each left branch and '1' for each right branch. Leaves
            are visited left to right.
        """
        if node is None:
            node = self.huffman_tree
        if node is None:           # no symbols at all
            return
        # Use an explicit stack so that deep trees from skewed
        # distributions cannot hit the recursion limit.
        stack = [(node, code)]
        while stack:
            current, prefix = stack.pop()
            if current.left is None and current.right is None:
                # Terminal node, i.e. an original symbol. It is identified by
                # having no children, not by name, so a symbol named '*' works.
                # An empty prefix means a lone-symbol tree: give it one bit.
                self.huffman_code[current.name] = prefix or '0'
            else:
                # Push right first so the left subtree is processed first.
                if current.right is not None:
                    stack.append((current.right, prefix + '1'))
                if current.left is not None:
                    stack.append((current.left, prefix + '0'))
    def mean_code_length(self):
        """ Return the expected codeword length in bits,
            sum of probability * len(codeword) over the symbols. """
        return sum(self.probabilities[sym] * len(self.huffman_code[sym])
                   for sym in self.symbols)

def print_demo_graphviz():
    """ Print a hand-built four-node tree in graphviz (dot) format.

        Node 0 has children 1 and 2, and node 1 has the single child 3.
        >>> print_demo_graphviz()
        graph {
          0 -- 1;
          1 -- 3;
          0 -- 2;
        }
        <BLANKLINE>
    """
    # To generate demo_graph.png :
    # $ python huffman.py demo_graphviz | dot -Tpng > demo_graph.png
    root, first, second, grandchild = (BinaryTree(name=k) for k in range(4))
    first.set_children(grandchild, None)
    root.set_children(first, second)
    print(root.graphviz())

def print_huffman_graphviz(symbol_probabilities=None):
    """ Print a Huffman tree in graphviz (dot) format, with each node
        labelled by its name and probability.

        symbol_probabilities is a {symbol:probability} dict with at least
        one symbol. It defaults to the demo distribution
        {'00': 0.6, '01': 0.2, '10': 0.1, '11': 0.1}.
        Node identifiers are Python id() values, so they differ between runs.
    """
    # To generate huffman_graph.png :
    # $ python huffman.py huffman_graphviz | dot -Tpng > huffman_graph.png
    if symbol_probabilities is None:
        symbol_probabilities = {'00': 0.6, '01': 0.2, '10': 0.1, '11': 0.1}
    h = Huffman(symbol_probabilities)
    print(h.huffman_tree.graphviz(labels=True))
    
def main():
    """ Dispatch on the first command-line argument:
          demo_graphviz    - print the small demo tree in dot format
          huffman_graphviz - print the demo Huffman tree in dot format
        Any other argument, or none, does nothing.
    """
    import sys
    commands = {
        'demo_graphviz': print_demo_graphviz,
        'huffman_graphviz': print_huffman_graphviz,
    }
    arguments = sys.argv[1:]
    action = commands.get(arguments[0]) if arguments else None
    if action is not None:
        action()
    
if __name__ == '__main__':
    # Run the embedded doctests first (they print nothing when all pass),
    # then act on any command-line argument.
    from doctest import testmod
    testmod()
    main()