# vim fileencoding=utf-8
# vi:ts=4:et
#
# $Date: 2004/05/14 08:18:22 $
# $Revision: 1.13 $
# =====================================================

"""library for partition"""

# Return all the partitions of the number N.
# partition(3) -> (1,1,1), (2,1), (3,)

import sys
import itertools
import sets
import math
import copy

# young
import mathformat
import error
import util
import hook

# Probe for the `reversed` and `sorted` builtins (added in Python 2.4).
# Merely naming them raises NameError on an interpreter that lacks them,
# in which case both are imported from the `compat` module instead.
try:
    reversed, sorted
except NameError, e:
    from compat import reversed, sorted

# Names exported by `from <this module> import *`.
__all__     = [
        'partition',
        'number_of_partition',
        'Partition',
        'partition_upper_bound',
        ]

# Cache of generated partitions, keyed by the integer being partitioned.
# It is read and filled by SeqOfPartition.init_partition.
_cache = {}


def partitions(n):
    """Generate every partition of n as a tuple of parts.

    Each partition is a tuple whose parts are in weakly decreasing order.
    partitions(3) yields (1, 1, 1), (2, 1), (3,) in that order.  Zero has
    exactly one partition, the empty tuple.  A negative n has no
    partition, so nothing is yielded for it.
    """
    # This function is based on the following recipe by David Eppstein.
    # http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/218332

    # Without this guard a negative n never reaches the n == 0 base case
    # and the recursion only stops when the interpreter gives up.
    if n < 0:
        return

    # base case of recursion: zero is the sum of the empty list
    if n == 0:
        yield ()
        return

    # modify partitions of n-1 to form partitions of n
    for p in partitions(n-1):
        # appending a part of size 1 always keeps the order valid
        yield p + (1,)
        # the last part may grow by one only while it stays no larger
        # than the part in front of it
        if p and (len(p) < 2 or p[-2] > p[-1]):
            yield p[:-1] + (p[-1] + 1,)

class Partition(tuple):
    """A single partition, stored as a tuple of its parts.

    Example: Partition((3, 2, 1, 1)).  conjugate() relies on the parts
    being listed in weakly decreasing order; nothing in this class checks
    or enforces that ordering.
    """
    def size(self):
        """Return the sum of parts of the partition.
        """
        # XXX
        # do we need this?
        return sum(self)

    def get_num_of_parts(self):
        """Return the number of parts of the partition.
        """
        # If partition = (3,2,1,1),
        # we say that the partition has 4 parts.
        return len(self)

    def conjugate(self):
        """Return the conjugate (transposed diagram) as a new Partition.
        """
        # Stanley EC1 p.39

        # (4,3,1,1,1) -> (5,2,2,1)

        # * * * *
        # * * *
        # *
        # *
        # *
        # --> transpose
        # * * * * *
        # * *
        # * *
        # *

        # A sentinel row of length 0 below the diagram lets the last real
        # row be handled like every other row.
        rows = self + (0,)
        columns = []

        # Row i (1-based) sticks out past row i+1 by `drop` cells, and
        # exactly those `drop` columns have height i.  Walking upwards
        # from the bottom row emits the tallest columns first, so the
        # result comes out in decreasing order.
        for i in range(len(self), 0, -1):
            drop = rows[i-1] - rows[i]
            if drop > 0:
                columns.extend([i] * drop)

        return Partition(columns)

    def __repr__(self):
        # Formatting is delegated to the project's mathformat module.
        return mathformat.pprint_partition(self)

    __str__ = __repr__

    def plot(self):
        """Print the Ferrers diagram (Ferrers graph), then a blank line.
        """
        # (3,1,1)
        # -->
        # ***
        # *
        # *

        # The single-argument parenthesised form prints the same text
        # whether `print` is a statement or a function.
        for num in self:
            print("*" * num)
        # The blank line is printed unconditionally, also for an empty
        # partition.
        print("")

    def dot_notation(self):
        """Print the partition with dots and bars.
        For example, partition (2,1,1) is "..|.|.".
        """
        print("|".join(["." * num for num in self]))

    def isdistinct(self):
        """Tests whether or not the partition consists of distinct parts.
        """
        # Duplicated parts collapse in a set, making it shorter.
        return len(self) == len(sets.ImmutableSet(self))

    def up(self):
        """Return a SeqOfPartition of the partitions obtained by adding
        one cell to this one, in the order _up generates them.
        """
        # XXX
        # Should I return SeqOfPartition instance
        # or just return a generator?
        result = SeqOfPartition()
        for shape in _up(self):
            result.add(Partition(shape))
        return result

    def down(self):
        """Return a SeqOfPartition of the partitions obtained by removing
        one cell from this one, in the order _down generates them.
        """
        # SeqOfPartition.add already ignores duplicates, so no set is
        # needed here.  Skipping it keeps the generator's order, as up()
        # does.
        result = SeqOfPartition()
        for shape in _down(self):
            result.add(Partition(shape))
        return result

    def get_hook(self):
        """Return a hook.Hook object built from this partition.
        """
        # Fulton p.53
        return hook.Hook(self)

class SeqOfPartition(object):
    """A sequence of partitions without duplicates.

    SeqOfPartition(n) with a non-zero n starts out holding every
    partition of n.  SeqOfPartition() starts empty and is filled with
    add().

    NOTE(review): `number` does double duty.  It starts as the
    constructor argument, add()/remove() then change it by one per
    partition actually added/removed, and up()/down() leave it alone.
    size() only looks at whether it is zero.  Confirm the intended
    meaning before relying on it.
    """
    __slots__ = (
            'number',
            'seq',
            )

    def __init__(self, number=0):

        self.number = number
        self.seq    = []

        # A zero (or otherwise false) number leaves the sequence empty.
        if number:
            self.init_partition()

    def __getitem__(self, index):
        return self.seq[index]

    def size(self):
        """Return the number of partitions
        """
        # Mathematica
        # Length[]

        # NOTE
        # By convention we agree that p_0(0) = p(0) = 1.
        # (Stanley, "EC Vol.1", P.28.)

        # Consequence: while `number` is 0 this reports 1 even though
        # `seq` (and therefore iteration) is empty.
        if self.number == 0:
            return 1
        return len(self.seq)
    __len__ = size

    def get_partition(self):
        """Return the underlying list of partitions (not a copy).
        """
        return self.seq

    def __iter__(self):
        return iter(self.seq)

    def add_single_partition(self, pat):
        """Add a new partition to a member.
        """
        # XXX
        # this might not be needed.
        # currently, the functionality is the same as add.
        self.add(pat)

    def add(self, aPartition):
        """Add a partition to a sequence of partitions.
        This has no effect if the partition is already present.
        """
        if aPartition not in self.seq:
            self.seq.append(aPartition)
            self.number += 1

    def remove(self, partition):
        """Remove a partition from a sequence of partitions.
        This has no effect if the partition is not a member.
        """
        if partition in self.seq:
            self.seq.remove(partition)
            self.number -= 1

    def up(self):
        """Replace the contents with the sorted, duplicate-free union of
        member.up() over all members.
        """
        candidates = [candidate
                      for member in self.seq
                      for candidate in member.up()]
        self.seq = sorted(sets.Set(candidates))

    def down(self):
        """Replace the contents with the sorted, duplicate-free union of
        member.down() over all members.
        """
        candidates = [candidate
                      for member in self.seq
                      for candidate in member.down()]
        self.seq = sorted(sets.Set(candidates))


    def init_partition(self):
        """Fill the sequence with every partition of self.number, using
        the module-level cache when it already has them.
        """
        try:
            cached = _cache[self.number]
        except KeyError:
            cached = _cache[self.number] = \
                    [Partition(p) for p in partitions(self.number)]
        # Work on a private copy.  add() and remove() mutate self.seq in
        # place, and sharing the cached list would corrupt the cache for
        # every other instance built for the same number.
        self.seq = list(cached)

    def plot(self):
        """Plot each partition
        """
        for pat in self.seq:
            pat.plot()

    def show_partition(self):
        """Print each partition held by this instance, one per line.
        """
        print("\n".join(self._show_partition()))

    def _show_partition(self):
        # One formatted string per partition, in sequence order.
        return [mathformat.pprint_partition(pat) for pat in self.seq]

    def dot_notation(self):
        """Display each partition in a dot-notation.
        """
        for part in self.seq:
            part.dot_notation()

    def distinct(self):
        """Return a new instance holding only the partitions whose parts
        are all distinct.  This instance is left unchanged.
        """
        # A shallow copy is enough: `seq` is rebound to a fresh list
        # right away, so the original list is never shared or modified.
        other = copy.copy(self)
        other.seq = [Partition(pat) for pat in self.seq
                     if Partition(pat).isdistinct()]
        return other

    def __repr__(self):
        return "\n".join(self._show_partition())

    __str__ = __repr__

def _up(shape):
    """Up operation of diagram"""
    # (2,1,1)
    # -->
    # (3,1,1)
    # (2,2,1)
    # (2,1,1,1)

    # (3,2)
    # -->
    # (4,2)
    # (3,3)
    # (3,2,1)

    size_of_prev_line = shape[0] + 1
    for index, size_of_cur_line in enumerate(shape):
        if size_of_cur_line < size_of_prev_line:
            yield shape[:index] + (shape[index] + 1, ) + shape[index+1:]
        size_of_prev_line = size_of_cur_line
    else:
        # In this case, (2,1,1) will turn into (2,1,1,1)
        # trivial case. 
        yield shape + (1, )

def _down(shape):
    """Down operation of diagram"""
    # (2,1,1)
    # -->
    # (1,1,1)
    # (2,1)

    # (3,2)
    # -->
    # (3,1)
    # (2,2)

    shape = tuple(shape)
    # generator is a pair of (2,1,1) and (1,1)

    pair = util.pairwise(shape)

    for index, (x, y) in enumerate(pair):
        if x > y:
            # In this case, (2,1,1) will turn into (1,1,1)
            yield shape[:index] + (x-1, ) + shape[index+1:]
    else:
        last_element = shape[-1]
        if last_element == 1:
            # In this case, (2,1,1) will turn into (2,1)
            yield shape[:-1]
        else:
            # In this case, (3,2) will turn into (3,1)
            yield shape[:-1] + (shape[-1]-1, )

def _compare_max_num(number, part):
    """Tests whether or not the number is equal to or greater than
    any number of the partition."""
    # In short, if the partition is in order.
    # you only need to check if number >= part[0].

    # number = 3
    # part = (2,1)
    # --> OK

    # number = 2
    # part = (2,1,1,1)
    # --> OK

    # number = 3
    # part = (4,2,1)
    # --> NG

    return number >= part[0]


def partition(number):
    """partition(n) -> Return the partitions of n.

    The result is a SeqOfPartition built for n.  error.ArgumentError is
    raised when util.is_non_negative_integer rejects the argument.
    """
    if util.is_non_negative_integer(number):
        return SeqOfPartition(number)

    # argument validity check failed
    raise error.ArgumentError(
            "argument must be a non-negative integer.", number)

def number_of_partition(number, k=None):
    """number_of_partition(n [,k]) -> Return the number of partitions of n.
    If k is specified, return the number of partitions of n with k parts.

    Invalid arguments are reported by partition() or
    partition_with_k_parts(), which raise error.ArgumentError.
    """
    if k is None:
        return partition(number).size()

    # Hand every explicitly given k to partition_with_k_parts, so that an
    # invalid k is rejected there.  Previously a bad k was silently
    # ignored and the unrestricted count was returned.
    return partition_with_k_parts(number, k)

def partition_with_k_parts(n, k):
    """Return the number of partitions of n with exactly k parts.

    Raises error.ArgumentError when util.is_non_negative_integer rejects
    n or k.
    """
    # Stanley EC1 p.28

    # argument validity check
    if not util.is_non_negative_integer(n):
        raise error.ArgumentError(
                "number must be a non-negative integer.", n)

    if not util.is_non_negative_integer(k):
        raise error.ArgumentError(
                "number must be a non-negative integer.", k)

    # NOTE(review): whether the validator accepts integral non-int values
    # is not visible from here; normalise so they can size the table.
    n, k = int(n), int(k)

    # More parts than units is impossible.
    if n < k:
        return 0

    # Build p(total, parts) row by row from the recurrence
    #     p(n, k) = p(n-1, k-1) + p(n-k, k)
    # The first term counts partitions whose smallest part is 1.  The
    # second counts those with every part >= 2 (take one off each part).
    # A table replaces the plain double recursion, which recomputed the
    # same values exponentially often.

    # Row for 0 parts: only 0 itself can be written with no parts.
    counts = [1] + [0] * n
    for parts in range(1, k + 1):
        row = [0] * (n + 1)
        # entries with total < parts stay 0
        for total in range(parts, n + 1):
            row[total] = counts[total - 1] + row[total - parts]
        counts = row
    return counts[n]

def partition_upper_bound(n):
    """Ramanujan's upper bound for number of partitions of n

    Returns the integer part of exp(pi*sqrt(2n/3)) / (4n*sqrt(3)), or 1
    when n is 0.  Raises error.ArgumentError when
    util.is_non_negative_integer rejects n.
    """
    # http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/218332
    # See Also:
    # George E. Andrews:The Theory Of Partitions, Cambridge University Press (1998) pp.70.

    # argument validity check
    valid = util.is_non_negative_integer(n)
    if not valid:
        raise error.ArgumentError(
                "number must be a non-negative integer.", n)

    # special case: the formula below would divide by zero
    if n == 0:
        return 1

    growth = math.exp(math.pi*math.sqrt(2.0*n/3.0))
    scale = 4.0*n*math.sqrt(3.0)
    return int(growth/scale)

