#!/usr/bin/python
# Copyright 2007 by Peter Cock.  All rights reserved.
# This code is part of the Biopython distribution and governed by its
# license.  Please see the LICENSE file that should have been included
# as part of this package.

"""Code for dealing with sequence alignments.

In particular, this file defines an Alignment class (superseding the existing
class defined in the Bio.Align.Generic module)."""

from Bio.SeqRecord import SeqRecord
from Bio import Alphabet
from Bio.Seq import Seq
from LocatableSeq import LocatableSeq

#TODO - Make this a subclass of Bio.Align.Generic.Alignment, or vice-versa?
#Note we are explicitly subclassing the list object here.
class Alignment(list):
    """Multiple Sequence Alignment class.

    Can be treated as both a list of SeqRecords, and as an array of
    characters.

    Each row is a SeqRecord.  Rows holding plain Seq objects must all have
    the same length; rows holding LocatableSeq objects may start at
    different offsets (their .start attribute), and the alignment length is
    then the furthest end point (start + length) of any row.
    """
    def __init__(self, records, alphabet=Alphabet.single_letter_alphabet):
        """Initialize a new Alignment object.

        Arguments:
        records - A list or iterator/generator of SeqRecord objects,
                  all of the same sequence length (rows holding
                  LocatableSeq objects are exempt from this check).

        alphabet - Optional alphabet; all the records must have this alphabet
                   or a subclass of it.

        After loading the records, every row's sequence gets
        internal_ref = False and nd = ' ' (see set_ali_is_ref / set_nd).
        """
        self.alphabet = alphabet
        #Get the error checking done by our append method
        for record in records :
            self.append(record)
        self.set_ali_is_ref(True)
        self.set_nd(' ')

    def set_ali_is_ref(self, ali_is_ref):
        """Store ali_is_ref and push its negation to every row's sequence.

        Each row's sequence object gets internal_ref = not ali_is_ref.  The
        exact meaning of internal_ref is defined by LocatableSeq (presumably
        whether coordinates are relative to the sequence rather than to the
        alignment -- TODO confirm).  Rows are shared objects, so this also
        affects any other alignment holding the same records.
        """
        self._ali_is_ref = ali_is_ref
        for record in self:
            record.seq.internal_ref = not self._ali_is_ref

    def set_nd(self, nd):
        """Store nd and copy it onto every row's sequence object.

        Presumably nd is the character a LocatableSeq reports for columns
        outside its own span (the unit test in this module expects ' ' there
        with the default) -- confirm against LocatableSeq.  Rows are shared
        objects, so this also affects other alignments holding them.
        """
        self._nd = nd
        for record in self:
            record.seq.nd = self._nd

    def _sub_alignment(self, records):
        """Return a new alignment of this class sharing the given rows.

        The constructor resets the rows' internal_ref/nd settings to their
        defaults; because the rows are shared with this alignment, this
        alignment's own settings are re-applied afterwards so that slicing
        does not silently modify this alignment.
        """
        answer = self.__class__(records, self.alphabet)
        answer.set_ali_is_ref(getattr(self, "_ali_is_ref", True))
        answer.set_nd(getattr(self, "_nd", ' '))
        assert isinstance(answer, Alignment)
        return answer

    def __len__(self) :
        """Returns the number of sequence records in the alignment.

        Use the get_alignment_length() method to get the sequence length
        """
        #Note - My only reason to explicitly define this is to set the docstring
        return list.__len__(self)

    def __str__(self) :
        """Returns a multi-line string representing the alignment

        This output is intended to be readable, but large alignments are
        shown truncated (at most 20 rows, long rows elided in the middle).
        """
        cols = self.get_alignment_length()
        rows = len(self)
        answer = "%s alignment with %i rows and %i columns\n" \
                 % (str(self.alphabet), rows, cols)
        show_rows = min(20, rows)
        #Use a plain list slice: building a sub-Alignment here would reset
        #the nd/internal_ref settings of the shared rows as a side effect.
        shown = list.__getitem__(self, slice(0, show_rows))
        lines = []
        for rec in shown :
            text = rec.seq.tostring()
            if cols <= 60 :
                lines.append("%s %s" % (text, rec.id))
            else :
                lines.append("%s..%s %s" % (text[:54], text[-3:], rec.id))
        answer += "\n".join(lines)
        if show_rows < rows :
            answer += "\n..."
        return answer

    def __repr__(self) :
        """Returns a representation of the object for debugging"""
        return "<Bio.Align.Alignment instance (%i records of length %i, %s) at %x>" \
               % (len(self), self.get_alignment_length(), repr(self.alphabet), id(self))
        #This version is useful for doing eval(repr(alignment)), but is can be VERY long
        #return "Bio.Align.Alignment(%s, %s)" \
        #       % (list.__repr__(self), repr(self.alphabet))

    def __iter__(self) :
        """Iterate over alignment rows as SeqRecord objects.

        e.g.

        for record in align :
            print record.id
            print record.seq"""
        #Note - My only reason to explicitly define this is to set the docstring
        return list.__iter__(self)

    def __getitem__(self, index):
        """Retrieve a sub-alignment, row, column or element.

        Depending on the indices, you can get a SeqRecord object
        (representing a single row), strings (for single characters),
        Seq objects (for a whole or partial row or column) or another
        alignment (representing some or all of the alignment).

        align[r,c] gives a single character as a string
        align[r] gives a row as a SeqRecord
        align[r,:] or align[r,c1:c2] gives all or part of a row as a Seq
        align[:,c] or align[r1:r2,c] gives all or part of a column as a Seq
        align[:] and align[:,:] give a copy of the alignment

        Anything else gives a sub alignment, e.g.
        align[0:2] or align[0:2,:] uses only row 0 and 1
        align[:,1:3] uses only columns 1 and 2
        align[0:2,1:3] uses only rows 0 & 1 and only cols 1 & 2

        Raises TypeError for indices that are not one or two dimensional,
        or whose components are not integers or slices.
        """
        #First we deal with one-dimensional indices.
        if isinstance(index, int) :
            #e.g. result = align[x]
            #Return a SeqRecord
            record = list.__getitem__(self, index)
            assert isinstance(record, SeqRecord)
            return record
        elif isinstance(index, slice) :
            #e.g. result = align[x:y:z]
            #(on Python 2 simple slices like align[x:y] go via __getslice__)
            #Return a sub alignment containing only
            #some/all records (with full length sequences)
            return self._sub_alignment(list.__getitem__(self, index))

        n = len(index)
        if n != 2 :
            raise TypeError("Only expect one or two dimensional indices")
        row_index = index[0]
        col_index = index[1]
        if not (isinstance(row_index, int) or isinstance(row_index, slice)) \
        or not (isinstance(col_index, int) or isinstance(col_index, slice)) :
            raise TypeError("Expect integers or slices")

        if isinstance(row_index, int) and isinstance(col_index, int) :
            #e.g. result = align[x,y]
            #Return a single character string
            return list.__getitem__(self,row_index).seq[col_index]
        elif isinstance(col_index, int) :
            #e.g. result = align[:,col]
            #or   result = align[x:y,col]
            #or   result = align[x:y:z,col]
            #Return a (partial) column as a Seq object
            return Seq("".join([rec.seq[col_index] \
                                for rec in list.__getitem__(self,row_index)]), \
                       self.alphabet)
        elif isinstance(row_index, int) :
            #e.g. result = align[row,:]
            #or   result = align[row,x:y]
            #or   result = align[row,x:y:z]
            #Return a (partial) row as a Seq object
            #TODO - Get the SeqRecord to do the splicing, instead of the Seq
            #which may give another SeqRecord...
            record = list.__getitem__(self,row_index)
            assert isinstance(record, SeqRecord)
            return record.seq[col_index]

        assert isinstance(row_index, slice)
        assert isinstance(col_index, slice)
        #e.g. result = align[:,:]
        #or   result = align[:,5:10]
        #or   result = align[3:4,5:10]
        #Return a sub alignment containing
        #some/all records and some/all columns
        answer = self.__class__([], self.alphabet)
        for rec in list.__getitem__(self, row_index) :
            #TODO - Assuming SeqRecord splicing gives another SeqRecord,
            #       let the SeqRecord do the column-splice and take care
            #       of preserving the id/name/description and any annotations.
            assert isinstance(rec, SeqRecord)
            answer.append(SeqRecord(rec.seq[col_index], \
                                    id=rec.id,
                                    name=rec.name,
                                    description=rec.description))
        assert isinstance(answer, Alignment)
        return answer

    def __getslice__(self, i, j) :
        """Return rows i to j as a new alignment (Python 2 slice hook)."""
        #Needed on Python 2, otherwise align[:] gives a plain list
        return self._sub_alignment(list.__getslice__(self, i, j))

    def __setslice__(self, i, j, value) :
        #TODO - Support this?
        #Would need error checks like the append method
        raise NotImplementedError

    def __setitem__(self, index, value) :
        #TODO - Support this? With double indexing?
        #Would need error checks like the append method
        raise NotImplementedError

    #TODO - Support double indexed __delitem__ ?
    #Wait until SeqRecord support it?

    #TODO - Support dropping columns via double indexed __delitem__ ?
    #Wait until SeqRecord supports deleting slices?
    #How does a numpy array do this?

    #TODO - Add docstring for __imul__
    #(it seems to work fine - but I can't think of a reason to use it!)

    def insert(self, index, record) :
        """Insert a SeqRecord (row) before position index.

        Performs the same checks as the append method (raises TypeError or
        ValueError for unsuitable records)."""
        self._check_record(record)
        list.insert(self, index, record)
        if isinstance(record.seq, LocatableSeq):
            self._update_external_lengths()

    #__insert__ is not a Python hook; kept as an alias for old callers.
    __insert__ = insert

    def _joint_alphabet(self, other) :
        """Return the alphabet to use when combining self with other.

        other may be an Alignment (or anything with an alphabet attribute)
        or a plain list of SeqRecords, in which case self.alphabet is used.
        Raises ValueError if neither alphabet is an instance of the other's
        class.
        """
        try :
            other_alphabet = other.alphabet
        except AttributeError :
            #Must be a list of SeqRecords, we'll use the current alphabet
            return self.alphabet
        if isinstance(other_alphabet, self.alphabet.__class__) :
            return self.alphabet
        elif isinstance(self.alphabet, other_alphabet.__class__) :
            return other_alphabet
        raise ValueError("The alignments have incompatible alphabets")

    def __add__(self, other) :
        """Add two alignments whose sequences are the same length.

        other - Another Alignment or a list of SeqRecord objects.

        Think of this as adding two lists of SeqRecord objects,
        (i.e. adding more rows).  The sequences must all have the same
        lengths, and compatible alphabets.

        Usage:
        alignment_one = Alignment(...)
        alignment_two = Alignment(...)
        my_alignment = alignment_one + alignment_two

        Returns a new alignment"""
        answer = self.__class__([], self._joint_alphabet(other))
        answer.extend(self)
        answer.extend(other)
        assert isinstance(answer, Alignment)
        return answer

    def __radd__(self, other) :
        """Return other + self as a new alignment (other's rows first).

        other - A list/iterator/generator of SeqRecord objects.

        Usage:
        my_alignment = list_of_records + alignment

        Because Alignment subclasses list, Python calls this method before
        list.__add__, so it must return a proper alignment (not None) and
        must not modify self."""
        answer = self.__class__([], self._joint_alphabet(other))
        answer.extend(other)
        answer.extend(self)
        assert isinstance(answer, Alignment)
        return answer

    def __iadd__(self, other) :
        """Add more SeqRecords whose sequences are the same length, in place.

        other - Another Alignment, or a list/iterator/generator
                of SeqRecord objects.

        Usage:
        my_alignment = Alignment(...)
        extra_rows = Alignment(...)
        my_alignment += extra_rows
        """
        #TODO - Explicitly check the other.alphabet?
        #Right now we just check the alphabet of other's records which
        #is more flexible.
        #
        #TODO - Check everything is compatible before changing self?
        #Otherwise any update that fails can leave self partly modified.
        #Note we would have to turn any generators into lists, but they
        #would end up in memory anyway as part of this alignment.
        self.extend(other)
        return self

    def get_alignment_length(self) :
        """Return the maximum length of the alignment.

        Returns 0 for an empty alignment.  If the first row holds a
        LocatableSeq, this is the largest start + length over all rows;
        otherwise it is the (common) length of the row sequences.
        """
        if len(self) == 0 :
            # no records, is length zero or undefined?
            return 0
        if isinstance(self[0].seq, LocatableSeq):
            return max([len(rec.seq) + rec.seq.start for rec in self])
        #This is perhaps being paranoid, but double check the lengths agree:
        length = len(self[0].seq)
        for record in self:
            assert length == len(record.seq), \
                "Alignment corrupted - has sequences of different lengths"
        return length

    def _check_record(self, record) :
        """Raise TypeError/ValueError if record cannot be added as a row."""
        if not isinstance(record, SeqRecord) :
            raise TypeError("You must supply alignment rows as SeqRecord objects")
        if not isinstance(record.seq.alphabet, self.alphabet.__class__) :
            raise ValueError("Additional records (rows) must have a compatible alphabet")
        #LocatableSeq rows may be shorter/offset, so skip the length check
        if not isinstance(record.seq, LocatableSeq):
            if len(self) > 0 and len(record.seq) != self.get_alignment_length() :
                raise ValueError("Additional records (rows) must have the same sequence length")

    def _update_external_lengths(self) :
        """Tell every row's sequence the current alignment length."""
        length = self.get_alignment_length()
        for rec in self:
            rec.seq.set_external_seq_length(length)

    def append(self, record) :
        """Add another SeqRecord (row) to the alignment.

        Raises TypeError if record is not a SeqRecord, and ValueError if its
        alphabet is incompatible or (for non-LocatableSeq rows) its length
        differs from the alignment length."""
        self._check_record(record)
        #Get the list class to deal with this
        list.append(self, record)
        if isinstance(record.seq, LocatableSeq):
            #we should set the external_length in all the seqs
            self._update_external_lengths()

    def extend(self, records) :
        """Add additional SeqRecord objects (rows) to the alignment."""
        #Get the error checking done by our append method
        for record in records :
            self.append(record)

    def get_column(self,col):
        """Retrieve a single column as a string [LEGACY METHOD].

        Returns a string for the requested column.

        This method is included only for backwards compatibility with the
        Bio.Align.Generic.Alignment class, and may be deprecated in a
        future release of Biopython."""
        return self[:,col].tostring()

    def get_all_seqs(self) :
        """Retrieve all the rows as a list of SeqRecord objects [LEGACY METHOD].

        This method is included only for backwards compatibility with the
        Bio.Align.Generic.Alignment class, and may be deprecated in a
        future release of Biopython."""
        return list(self)

    def get_seq_by_num(self, number):
        """Retrieve a single sequence (row) as a Seq object [LEGACY METHOD]

        Returns a Seq object for the requested row.

        This method is included only for backwards compatibility with the
        Bio.Align.Generic.Alignment class, and may be deprecated in a
        future release of Biopython."""
        return self[number].seq

    def add_sequence(self, descriptor, sequence):
        """Add another sequence (row) to the alignment (as a string) [LEGACY METHOD].

        Raises TypeError if sequence is not a plain string.

        This method is included only to provide limited backwards compatibility
        with the Bio.Align.Generic.Alignment class, and may be deprecated in a
        future release of Biopython."""
        if not isinstance(sequence, str) :
            raise TypeError("This method expects the sequence as a plain string")
        self.append(SeqRecord(Seq(sequence, self.alphabet),
                              id = descriptor, description = descriptor))

import unittest

class LocatableSeqTests(unittest.TestCase):
    """Unit tests for Alignment objects whose rows are LocatableSeq objects."""
    def test_Alignment(self):
        """Rows starting at different offsets can be indexed and sliced."""
        alpha = Alphabet.generic_nucleotide
        #Layout (column numbers on top; blanks are outside each sequence):
        #0 1 2 3 4 5 6 7 8 9
        #  T A G T C G T
        #T T A G
        #      G T C G T T T
        seq1 = 'TAGTCGT'
        seq2 = 'TTAG'
        seq3 = 'GTCGTTT'
        s1 = LocatableSeq(seq1, alpha, 1)
        s2 = LocatableSeq(seq2, alpha, 0)
        s3 = LocatableSeq(seq3, alpha, 3)
        sr1 = SeqRecord(s1, id = 'seq1')
        sr2 = SeqRecord(s2, id = 'seq2')
        sr3 = SeqRecord(s3, id = 'seq3')
        ali = Alignment([sr1, sr2, sr3], alpha)
        #Row access keeps each row's start offset
        self.assertEqual(ali[0].seq.start, 1)
        self.assertEqual(ali[1].seq.start, 0)
        #Row slices keep the start offsets too
        srs = ali[0:2]
        self.assertEqual(srs[1].seq.start, 0)
        srs = ali[0:1]
        self.assertEqual(srs[0].seq.start, 1)
        #Element access uses alignment coordinates; ' ' outside a row's span
        self.assertEqual(ali[0,2], 'A')
        self.assertEqual(ali[0,0], ' ')
        #Column and partial-row access
        self.assertEqual(ali[:,1].tostring(), 'TT ')
        self.assertEqual(ali[1,2:7].tostring(), 'AG   ')



if __name__ == "__main__" :

    print "Mini self test..."
    from Bio.Seq import Seq

    def pretty(align) :
        answer = "Alignment, %i by %i\n" \
                 % (len(align), align.get_alignment_length())
        answer += "\n".join([rec.seq.tostring() for rec in align])
        return answer    

    raw_data = ["ACGATCAGCTAGCT", "CCGATCAGCTAGCT", "ACGATGAGCTAGCT"]
    rec_names = ["Alpha", "Beta", "Gamma"]
    seq_recs = [SeqRecord(Seq(data, Alphabet.generic_dna), id=name) \
                for (data, name) in zip(raw_data, rec_names)]

    for a in [Alignment(seq_recs, Alphabet.generic_nucleotide),
              Alignment(seq_recs, Alphabet.generic_dna),
              Alignment(seq_recs)]:
        print
        print "Using a.alphabet = %s" % repr(a.alphabet)
        print a

        #Iterating over the rows...
        for rec in a :
            assert isinstance(rec, SeqRecord)
        for r,rec in enumerate(a) :
            assert isinstance(rec, SeqRecord)
            assert raw_data[r] == rec.seq.tostring()
        print "Alignment iteraction as SeqRecord OK"

        #Check getting rows back
        for r in range(-len(raw_data), len(raw_data)) :
            rec = a[r]
            assert isinstance(rec, SeqRecord)
            assert raw_data[r] == rec.seq.tostring()
            assert a[r].seq == a.get_seq_by_num(r)
        print "Alignment row access as SeqRecord OK"

        #Check getting rows back as Seq
        for r in range(-len(raw_data), len(raw_data)) :
            row = a[r,:]
            assert isinstance(row, Seq)
            assert raw_data[r] == row.tostring()
            assert a[r,:].tostring() == a.get_seq_by_num(r).tostring()
            assert a[r,:].tostring() == a[r].seq.tostring()

        print "Alignment row access as Seq OK"

        #Check getting elements back:
        for r in range(-len(raw_data), len(raw_data)) :
            for c in range(-a.get_alignment_length(), a.get_alignment_length()) :
                element = a[r,c]
                assert isinstance(element, str)
                assert element == raw_data[r][c]
        print "Alignment element access as string OK"

        #Check getting columns back:
        for c in range(-a.get_alignment_length(), a.get_alignment_length()) :
            if c >= 0 :
                assert a[:,c].tostring() == a.get_column(c)
            col = a[:,c]
            assert isinstance(col, Seq)
            assert isinstance(col.alphabet, a.alphabet.__class__)
            assert len(col) == 3
            assert col.tostring() == "".join([row[c] for row in raw_data])

            col = a[1:2,c]
            assert isinstance(col, Seq)
            assert isinstance(col.alphabet, a.alphabet.__class__)
            assert len(col) == 1
            assert col.tostring() == raw_data[1][c]

            col = a[0:2,c]
            assert isinstance(col, Seq)
            assert isinstance(col.alphabet, a.alphabet.__class__)
            assert len(col) == 2
            assert col.tostring() == raw_data[0][c] + raw_data[1][c]

            col = a[-1:,c]
            assert isinstance(col, Seq)
            assert isinstance(col.alphabet, a.alphabet.__class__)
            assert len(col) == 1
            assert col.tostring() == raw_data[-1][c]
     
        print "Alignment column access as Seq OK"

        #Check getting alignments back:
        sub = a[:] #does this copy annotation?
        assert isinstance(sub, Alignment)
        assert pretty(sub) == pretty(a)

        sub = a[:,:] #does this copy annotation?
        assert isinstance(sub, Alignment)
        assert pretty(sub) == pretty(a)
     
        sub = a[0:len(raw_data),:]
        assert isinstance(sub, Alignment)
        assert pretty(sub) == pretty(a)

        sub = a[1:2]
        assert isinstance(sub, Alignment)
        assert len(sub.get_all_seqs()) == 1
        assert sub[0].seq.tostring() == raw_data[1]
        assert sub[0,:].tostring() == raw_data[1]

        sub = a[1:3]
        assert isinstance(sub, Alignment)
        assert len(sub.get_all_seqs()) == 2
        assert sub[0].seq.tostring() == raw_data[1]
        assert sub[0,:].tostring() == raw_data[1]
        assert sub[1].seq.tostring() == raw_data[2]
        assert sub[1,:].tostring() == raw_data[2]

        sub = a[0:2]
        assert isinstance(sub, Alignment)
        assert len(sub.get_all_seqs()) == 2
        assert sub[0].seq.tostring() == raw_data[0]
        assert sub[0,:].tostring() == raw_data[0]
        assert sub[1].seq.tostring() == raw_data[1]
        assert sub[1,:].tostring() == raw_data[1]

        #print pretty(sub)
        print "Alignment access as sub alignment OK"
        
    print
    try :
        a = Alignment(seq_recs, Alphabet.generic_rna)
        assert False, "Should have failed!"
    except ValueError, e :
        print "As expected, creating alignment with incompatible alphabet failed:"
        print str(e)

    print

    b = Alignment([SeqRecord(Seq("ACGATGAGCCAGCT", Alphabet.generic_dna), id="Delta"),
                   SeqRecord(Seq("ACCATGAGCCAGCT", Alphabet.generic_dna), id="Epsilon"),
                   SeqRecord(Seq("ACCATGAGCAAGCT", Alphabet.generic_dna), id="Gamma"),
                   SeqRecord(Seq("ACGATGAGCAAGCT", Alphabet.generic_dna), id="Zeta")],
                  Alphabet.generic_dna)

    c = Alignment([SeqRecord(Seq("DEKLSLKDSLKDLA", Alphabet.generic_protein), id="P1"),
                   SeqRecord(Seq("DEKISLKDSVKDLA", Alphabet.generic_protein), id="P2"),
                   SeqRecord(Seq("DEKISLKDSVKELA", Alphabet.generic_protein), id="P2"),
                   SeqRecord(Seq("DDKLSLKDSVKELA", Alphabet.generic_protein), id="P4")],
                  Alphabet.generic_protein)
    
    print "a:", a
    print "b:", b
    print "c:", b
    assert len(a) == 3
    assert len(b) == 4
    assert len(c) == 4

    print
    print "Checking __add__ next:"
    print "a+b:", a+b
    assert len(a) + len(b) == len(a+b)
    print "b+a:", b+a
    assert len(a) + len(b) == len(b+a)
    assert len(a) == 3
    assert len(b) == 4
    try :
        print "b+c:", b+c
        assert False, "Should have failed!"
    except ValueError, e :
        #Good!
        print "b+c: Fails, %s" % str(e)

    print "c:", c
    print "c[0:2]:", c[0:2]
    print "c[3:4]:", c[3:4]
    print "c[0:2]+c[3:4]:", c[0:2]+c[3:4]

    print
    print "Checking __radd__ next:"
    print "doing a+=b ..."
    a+=b
    assert len(a) == 7
    assert len(b) == 4
    assert [rec.id for rec in a] == ['Alpha', 'Beta', 'Gamma', 'Delta', 'Epsilon', 'Gamma', 'Zeta']
    print "a:", a
    print "doing b+=a ..."
    b+=a
    print "b:", b
    assert len(a) == 7
    assert len(b) == 11


    print "Done"
    unittest.main()
