# vim fileencoding=utf-8
# vi:ts=4:et
#
# $Date: 2004/05/14 08:20:31 $
# $Revision: 1.11 $
# =====================================================

"""library for Young tableaux"""

import sys
import sets
import operator
import itertools
import copy
import bisect

# young
import partition
import combination
import hook
import mathformat
import error
import util

# for CVS head users
try:
    reversed, sorted
except NameError:
    from compat import reversed, sorted, tee
    itertools.tee = tee

# Public API exported by "from young import *".
__all__ = ['young', 'Young', 'Skew', 'Word']

# Module-wide cache shared by every YoungTableaux instance:
#   {partition : YoungGenerator holding the tableaux of that shape}
_cache = {}

class Young(list):
    """Class for Young tableaux.

    A tableau is stored as a list of rows (self[0] is the top row), each
    row being a list of entries.  The constructor copies every row, so
    in-place operations such as bump() never modify the caller's lists.
    """
    def __init__(self, content=()):
        """Build a tableau from content, a list or tuple of rows.

        content -- a list or tuple whose items are all lists or tuples.
                   Every row is copied into a fresh list, so tuple rows
                   become mutable and the caller's rows are never shared.

        Raises ValueError if content is non-empty and is not a list/tuple
        of lists/tuples.
        """
        if content:
            valid = isinstance(content, (list, tuple))
            if valid:
                # Check every row, not just the first one.
                for row in content:
                    if not isinstance(row, (list, tuple)):
                        valid = False
                        break
            if not valid:
                raise ValueError(
                        "tableau must be a sequence of a sequence\n" +
                        repr(content) + "\n" +
                        "    was given")
        list.__init__(self, [list(row) for row in content])

    def __repr__(self):
        return mathformat.pprint_tableaux(self)

    __str__ = __repr__

    def get_partition(self):
        """Return the shape of the tableau (its row lengths) as a
        partition.Partition.
        """
        return partition.Partition(itertools.imap(len, self))

    def get_hook(self):
        """Return a hook.Hook object for the shape of this tableau."""
        # Fulton p.53
        return hook.Hook(self.get_partition())

    def bump(self, element):
        """Row-insert element into this tableau in place and return self.

        element -- an int/long, row-inserted with single_bump(), or a
                   tableau given as a Young/list/tuple, whose entries are
                   bumped in via self * element.

        Raises TypeError for any other kind of element.
        """
        if isinstance(element, (int, long)):
            self.single_bump(element, 0)
            return self
        if isinstance(element, (Young, tuple, list)):
            return self * element
        raise TypeError("cannot bump %r into a tableau" % (element,))

    def single_bump(self, item, row=0):
        """Row-insert a single item, starting at row ``row`` (row bumping).

        In each row, item replaces the leftmost entry strictly greater than
        it (located with bisect_right) and the displaced entry moves on to
        the next row.  If no entry is greater, item is appended to the row;
        if the row does not exist, a new row [item] is added at the bottom.

          [1,3] <- 2    =>  [1,2] / [3]
          [1,2,2] <- 3  =>  [1,2,2,3]
        """
        # Iterative rather than recursive, so tall tableaux are fine.
        while row < len(self):
            current = self[row]
            n = bisect.bisect_right(current, item)
            if n == len(current):
                current.append(item)
                return
            # Put item in place and carry the bumped entry to the next row.
            item, current[n] = current[n], item
            row += 1
        self.append([item])

    def __mul__(self, other):
        """Row-bump the entries of other into self, in place; return self.

        The entries are taken in the order given by
        util.flatten_tableau(other, True).
        """
        for box in util.flatten_tableau(other, True):
            self.single_bump(box, 0)
        return self

    def toskew(self, other):
        """Return the skew tableau self/other (Fulton p.12).

        other -- a Young tableau contained in self: it has no more rows
                 than self and none of its rows is longer than the
                 corresponding row of self.

        self is deep-copied, so it is left unchanged.
        Raises ValueError if other is not contained in self.
        """
        assert isinstance(other, Young)

        # Compare row lengths directly.  Every row of other must be
        # checked, including rows beyond the last row of self.
        outer = [len(row) for row in self]
        inner = [len(row) for row in other]
        contained = len(inner) <= len(outer)
        for big, small in zip(outer, inner):
            if small > big:
                contained = False
                break
        if not contained:
            raise ValueError("\n" + str(other) + "is not containd in\n" + str(self))
        return Skew(copy.deepcopy(self), other)

    def toword(self):
        """Return the Word of this tableau (Fulton p.17)."""
        return Word(util.flatten_tableau(self, True))

    # foo /  bar     # obj.__div__(foo, bar)  (also __truediv__ under
    #                # "from __future__ import division")
    # foo /= bar     # obj.__idiv__(foo, bar)
    __idiv__ = __div__ = __truediv__ = toskew

class Skew(list):
    """Class for skew tableaux.

    A skew tableau large/small is stored like a Young tableau (a list of
    rows) in which the cells of the inner shape hold None.
    """
    # NOTE
    # This implementation is in its experimental stage.

    def __init__(self, large, small):
        """Build the skew tableau large/small.

        large -- the outer tableau, a sequence of rows.  Each row is copied,
                 so large itself is never modified.
        small -- the inner tableau; it must provide get_partition() and be
                 contained in large.  For the part n at index i of its
                 partition, the first n cells of row i become None.
        """
        list.__init__(self, [list(row) for row in large])
        for i, n in enumerate(small.get_partition()):
            self[i][:n] = [None] * n

    def __repr__(self):
        return mathformat.pprint_skew_tableaux(self)

    __str__ = __repr__

    def jeu_de_taquin(self, tab=None):
        """Slide every hole (None) out of the tableau; return a Young tableau.

        tab -- the rows to work on.  When omitted (None), a deep copy of
               self is used and self is left unchanged; a tab supplied by
               the caller is modified in place.
        """
        # jeu de taquin / sliding / rectification: see Fulton p.12-.
        #
        # Each step takes the hole that is rightmost in the lowest row still
        # containing a hole and moves it one cell (see _slide_once).  A hole
        # that reaches the outer edge is dropped.  Steps repeat until no
        # hole remains.
        #
        # A loop is used instead of one recursive call per step, so large
        # tableaux cannot exhaust the recursion limit.  tab is compared with
        # None rather than truth-tested: an empty working tableau (reached
        # when every cell was a hole) must end the process, not restart it
        # from a fresh copy of self.
        if tab is None:
            yng = copy.deepcopy(self)
        else:
            yng = tab

        while True:
            hole = self._lowest_hole(yng)
            if hole is None:
                return Young(yng)
            yng = self._slide_once(yng, hole[0], hole[1])

    # alias
    rectify = taquin = jeu_de_taquin

    def _lowest_hole(self, yng):
        """Return (r, c) of the rightmost None in the lowest row that has
        one, or None when yng contains no None at all.
        """
        for r in range(len(yng) - 1, -1, -1):
            row = yng[r]
            for c in range(len(row) - 1, -1, -1):
                if row[c] is None:
                    return r, c
        return None

    def _slide_once(self, yng, r, c):
        """Move the hole at yng[r][c] one step and return the updated rows.

        The caller must pass the position found by _lowest_hole.  Then no
        row below r holds a hole and c is the rightmost hole of row r, so
        every neighbour compared here is a real entry.  The returned object
        is yng itself, except when the last row is dropped; in that case it
        is a new list.
        In the pictures, @ is the hole and * (or x, y) is a box.
        """
        row = yng[r]
        in_last_row = (r == len(yng) - 1)
        at_row_end = (c == len(row) - 1)
        # Is there a box directly below the hole?
        has_below = (not in_last_row) and c < len(yng[r + 1])

        if at_row_end:
            if has_below:
                # * @      * *
                # * *  =>  * @      move the hole down
                row[c], yng[r + 1][c] = yng[r + 1][c], row[c]
            elif in_last_row and c == 0:
                # * * *      * * *
                # @     =>          the last row is only the hole: drop it
                return yng[:r]
            else:
                # * * @      * *
                # *     =>   *      nothing below: drop the hole
                yng[r] = row[:-1]
        elif has_below and yng[r + 1][c] <= row[c + 1]:
            # @ y        x y
            # x    =>    @      x <= y: move the hole down
            row[c], yng[r + 1][c] = yng[r + 1][c], row[c]
        else:
            # @ y        y @
            # x    =>    x      x > y, or no box below: move the hole right
            row[c], row[c + 1] = row[c + 1], row[c]
        return yng


def get_candidate():
    """Placeholder intended to replace YoungGenerator.get_empty_space().

    Not implemented yet: it always returns None.
    """
    return None

class YoungGenerator(list):
    """All standard Young tableaux of one shape.

    Based on a partition of N, generate every way of placing 1..N into the
    shape.  set_tableaux() stores the results in this list as Young objects.
    """

    __slots__ = (
            'shape',       # partition (row lengths)
            'size',        # largest number placed
            )

    def __init__(self, shape, size=0):
        """shape -- the partition, copied with shape[:].
        size  -- largest number to place; 0 (default) means sum(shape),
                 e.g. (3,2,1,1) -> 7.
        """
        self.shape = shape[:]
        if size:
            self.size = size
        else:
            self.size = sum(shape)

    def tableaux_initializer(self):
        """Return an empty container for the shape, every cell None.

        (3,1,1) -> [[None, None, None], [None], [None]]
        """
        return [[None] * num for num in self.shape]

    def get_empty_space(self, tableaux):
        """Yield the (row, column) cells where the next number may be put.

        For each row, the candidate is its leftmost empty (falsy) cell.  It
        is yielded only if its column is strictly left of the column last
        yielded (initially self.size).  Full rows yield nothing.  The
        tableaux may be modified between yields; each row is read when it
        is reached.
        """
        limit = self.size
        for i, row in enumerate(tableaux):
            for j, cell in enumerate(row):
                if cell:
                    continue
                if j < limit:
                    limit = j
                    yield (i, j)
                # Only the first empty cell of a row matters.
                break

    def __repr__(self):
        return "\n".join([mathformat.pprint_tableaux(tab) for tab in self])

    __str__ = __repr__

    def set_tableaux(self):
        """Replace the contents of this list with every tableau of the
        shape, each wrapped in Young.
        """
        # Clear first, so that calling this twice does not duplicate results.
        del self[:]
        for diagram in self.generate(self.tableaux_initializer(), 1):
            self.append(Young(diagram))

    def generate(self, tableaux=None, index=1):
        """Yield every completion of tableaux using the numbers index..size.

        tableaux -- partially filled container (rows of numbers and None).
                    It is filled in place while searching and each cell is
                    reset to None after its branch is finished.  When empty
                    or omitted, an empty container for self.shape is used.
        index    -- the next number to place.

        Each yielded tableau is an independent deep copy.
        """
        if not tableaux:
            tableaux = self.tableaux_initializer()

        if index > self.size:
            yield copy.deepcopy(tableaux)
            return

        for row, col in self.get_empty_space(tableaux):
            tableaux[row][col] = index
            for diagram in self.generate(tableaux, index + 1):
                yield diagram
            # Backtrack before trying the next cell.
            tableaux[row][col] = None


class YoungTableaux(dict):
    """This class holds a sequence of tableaux, grouped by shape.

    Maps each partition of seq_of_partition to the YoungGenerator (a list)
    holding the tableaux of that shape.  Iterating yields the tableaux
    themselves, and len() counts tableaux, not shapes.
    """
    __slots__ = (
            'total',             # cached tableau count (0 = recompute)
            'square',            # cached sum of squared per-shape counts
            'seq_of_partition',  # the shapes whose tableaux are held
            )

    def __init__(self, seq_of_partition):
        """seq_of_partition -- the shapes.  It is iterated by set_tableaux()
        and its up()/down() methods are used by up() and down().
        """
        self.seq_of_partition = seq_of_partition
        self.total = self.calc_total()
        self.square = self.calc_square()

    def refresh(self):
        """Recompute self.total and self.square from the stored tableaux."""
        self.total = self.calc_total()
        self.square = self.calc_square()

    def __iter__(self):
        # NOTE
        # Yields the tableaux of every shape in dictionary order, which is
        # not sorted, so the order is not always the same.
        return itertools.chain(*self.values())

    def set_tableaux(self):
        """(Re)build the mapping from the current seq_of_partition.

        Each shape's tableaux come from the module cache _cache and are
        generated (and cached) on first use.  total and square are
        refreshed afterwards.
        """
        # Clear unconditionally.  "if self:" would go through the overridden
        # __len__ (a tableau count); when that count is 0, stale keys
        # would survive.
        self.clear()

        for part in self.seq_of_partition:
            try:
                tableaux = _cache[part]
            except KeyError:
                tableaux = YoungGenerator(part)
                tableaux.set_tableaux()
                _cache[part] = tableaux
            self[part] = tableaux

        self.refresh()

    def calc_total(self):
        """Return the number of tableaux over all shapes."""
        return sum([len(tableaux) for tableaux in self.values()])

    def size(self):
        """Return the total number of tableaux (cached in self.total)."""
        # TODO
        # we need a better name for this method.
        if not self.total:
            self.total = self.calc_total()
        return self.total

    __len__ = size

    def get_size_by_hook(self):
        """Return (and store in self.total) the number of tableaux computed
        with the hook length formula instead of by counting.
        """
        self.total = sum([hook.Hook(part).hook_length_formula() for part in self.seq_of_partition])
        return self.total

    def up(self):
        """Up operation of diagram: apply seq_of_partition.up() and rebuild."""
        self.seq_of_partition.up()
        # The shapes changed, so rebuild the tableaux.  set_tableaux() also
        # refreshes total and square.
        self.set_tableaux()

    def down(self):
        """Down operation of diagram: apply seq_of_partition.down() and
        rebuild.
        """
        self.seq_of_partition.down()
        # The shapes changed, so rebuild the tableaux.  set_tableaux() also
        # refreshes total and square.
        self.set_tableaux()

    def calc_square(self):
        """Return the sum over all shapes of (number of tableaux) ** 2."""
        return sum([len(tableaux) ** 2 for tableaux in self.values()])

    def get_square(self):
        """Return the cached sum of squares, computing it if it is 0."""
        if not self.square:
            self.square = self.calc_square()
        return self.square

    def report(self):
        """Print the number of tableaux per shape and its square.

        The totals are also stored in self.total and self.square.
        Sample output:

        partition             number  square
        ----------------------------------------
        (4)                        1       1
        (3,1)                      3       9
        (2,2)                      2       4
        (2,1,1)                    3       9
        (1,1,1,1)                  1       1
        ----------------------------------------
        total                     10      24
        """
        print('%-20s%8s%8s' % ('partition', 'number', 'square'))
        print("-" * 40)

        total = squares = 0
        for part in sorted(self.keys(), reverse=True):
            num = len(self[part])
            square = num * num
            total += num
            squares += square
            print('%-20s%8d%8d' % (mathformat.pprint_partition(part), num, square))

        self.total = total
        self.square = squares
        print("-" * 40)
        print('%-20s%8d%8d' % ('total', self.total, self.square))

    def get_partition(self):
        """Return all partitions of the tableaux (the seq_of_partition)."""
        return self.seq_of_partition

    def display(self):
        """Display all the tableaux."""
        print(self)

    def __repr__(self):
        return "\n".join(self._formatter())

    __str__ = __repr__

    def _formatter(self):
        """Yield the lines of __repr__: for each shape, largest first, the
        repr of the shape and then each of its tableaux.
        """
        for part in sorted(self.keys(), reverse=True):
            yield repr(part)
            for tableau in self[part]:
                yield mathformat.pprint_tableaux(tableau)


class Word(list):
    """This class represents a Word: a flat list of tableau entries.

    Its weakly increasing runs (see word_breaker) are read as the rows of a
    tableau, bottom row first (Fulton p.17).
    """

    def __repr__(self):
        # example of words
        # (5 6)(4 4 6)(2 3 5 5)(1 2 2 3)
        groups = ['(' + ' '.join([str(x) for x in run]) + ')'
                  for run in word_breaker(self)]
        return 'Word[' + ''.join(groups) + ']'

    __str__ = __repr__

    def __mul__(self, other):
        """Multiply this word by other, in place, and return self.

        other -- a Word, whose tableau is multiplied into this word's
                 tableau, or anything Young.bump accepts (an int, or a
                 tableau given as a Young/list/tuple).

        The contents of this word are replaced by the word of the product.
        """
        # Fulton P.19-20
        # Proposition 1 & its Corollary
        yng = self.toyoung()
        if isinstance(other, Word):
            # Young.__mul__ bumps other's tableau into yng in place.
            yng * other.toyoung()
        else:
            yng.bump(other)

        # Replace the contents in place and return self, so that "w * v"
        # is usable as an expression (list.__init__ returns None).
        self[:] = yng.toword()
        return self

    def toyoung(self):
        """Return a Young instance whose rows are the runs of this word,
        in reverse order.
        """
        return Young(list(reversed(list(word_breaker(self)))))

def word_breaker(seq):
    """Break a Word into its maximal weakly increasing runs, as lists.

    Fulton P. 17.  Example:
        5 6 4 4 6 2 3 5 5 1 2 2 3
        ==>
        5 6 | 4 4 6 | 2 3 5 5 | 1 2 2 3

        list(word_breaker([5,6,4,4,6,2,3,5,5,1,2,2,3]))
        == [[5, 6], [4, 4, 6], [2, 3, 5, 5], [1, 2, 2, 3]]

    An empty sequence yields a single empty run [].  A new run starts only
    when an element is smaller than the one before it, so non-empty input
    never produces an empty run, whatever the value of its first element.
    """
    if not seq:
        yield []
        return

    run = []
    for num in seq:
        if run and num < run[-1]:
            # descent: close the current run and start a new one
            yield run
            run = [num]
        else:
            # still weakly increasing (or the very first element)
            run.append(num)
    if run:
        yield run

def func_chain(func, iterable):
    """Lazily flatten func over iterable (a chain-like flat-map).

    For each element of iterable, in order, apply func and yield every item
    of the iterable it returns.  func is only called as iteration reaches
    each element, much like itertools.chain over the mapped results.
    """
    for element in iterable:
        produced = func(element)
        for item in produced:
            yield item

def ispartition(iterable):
    """Test whether iterable is a valid partition; return True if it is.

    A valid partition is a sequence of non-negative integers (as judged by
    util.is_non_negative_integer) in weakly decreasing order.  iterable is
    read exactly once, so generators and other one-shot iterators are
    checked correctly.  All elements are checked for being integers before
    the order is checked.

    Raises error.PartitionFormatError with args (message, iterable).
    """
    # Materialize once.  Passing the same iterator to both zip() and
    # sorted() would exhaust it before any pair is compared.
    parts = list(iterable)

    for x in parts:
        if not util.is_non_negative_integer(x):
            raise error.PartitionFormatError(
                    "each element must be integer", iterable)

    # Weakly decreasing <=> no element is larger than its predecessor.
    for prev, cur in zip(parts, parts[1:]):
        if cur > prev:
            raise error.PartitionFormatError(
                    "partition is not in descending order", iterable)
    return True

def young_of_sequence(shape):
    """Return a YoungTableaux holding every Young tableau of one shape.

    shape -- a partition.Partition, or any value partition.Partition()
             accepts (it is converted first).  The shape is not otherwise
             validated here.
    """
    part = shape
    if not isinstance(part, partition.Partition):
        part = partition.Partition(part)

    # A one-element sequence of partitions, so YoungTableaux holds just
    # this shape.
    shapes = partition.SeqOfPartition()
    shapes.add(part)

    tableaux = YoungTableaux(shapes)
    tableaux.set_tableaux()
    return tableaux

def young_of_number(number):
    """Return a YoungTableaux holding the tableaux of every partition of
    number, the shapes being those returned by partition.partition(number).
    """
    tableaux = YoungTableaux(partition.partition(number))
    tableaux.set_tableaux()
    return tableaux

def young(arg):
    """young(arg) -> Return all the standard Young tableaux.

    If arg is a non-negative integer n,
    return all the tableaux of every partition of n (see young_of_number).

    If arg is a partition -- a tuple, a list or a partition.Partition --
    return all the tableaux of that shape (see young_of_sequence).

    Raises TypeError for any other argument.
    """
    # There are two ways to call young():
    #   young(3)         # a number
    #   young((2,1,1))   # a partition
    if util.is_non_negative_integer(arg):
        return young_of_number(arg)
    # young_of_sequence already handles Partition instances, so accept
    # them here as well.
    if isinstance(arg, (tuple, list, partition.Partition)):
        return young_of_sequence(arg)
    raise TypeError("argument must be integer or partition")

