# This file is copyrighted by the author, and may not be redistributed without prior permission.
# Contact: biopython AT maubp DOT freeserve PERIOD co DOT uk
# Peter, 25 June 2006
class distanceMatrix :
    """Store a symmetric distance matrix with zero diagonal entries, and write
    it out in PHYLIP distance matrix format (see asphylip).

    Internally only the strict lower triangle is held: self._data[i] is a list
    of the i distances from entry i to entries 0..i-1.  The diagonal is always
    zero and upper triangular lookups are answered by symmetry.

    Needs proper slicing support.
    Could try using Numeric/Numpy internally for more speed/less memory"""

    def __init__(self, names=None, distances=None, n=None) :
        """Create a distance matrix.

        If names is omitted, names will be automatically assigned
        ("Name 1", "Name 2", ...).
        If distances are omitted, zero will be assumed.
        If both are missing, then an n by n (or empty) matrix is returned.

        The distance can be a normal array, in which case we check for
        symmetry and zero diagonals.

        Also, a lower triangular form may be used, e.g. for a 3x3 case:
        [[], [x], [y, z]] or including the diagonal, [[0], [x, 0], [y, z, 0]]
        rather than the fully symmetric version [[0, x, y], [x, 0, z], [y, z, 0]]

        The names list is copied, so later calls to add() do not modify the
        caller's list.
        """
        # Use len() rather than truthiness so array-like inputs (whose truth
        # value may be ambiguous) behave the same as plain lists.
        has_distances = distances is not None and len(distances) > 0

        if names and has_distances :
            assert len(names) == len(distances)

        if names :
            assert n is None
            self._names = list(names)
            n = len(names)
        elif has_distances :
            assert n is None
            n = len(distances)
            self._names = ["Name %i" % (i+1) for i in range(0, n)]
        elif n is not None and n > 0 :
            self._names = ["Name %i" % (i+1) for i in range(0, n)]
        else :
            self._names = []
            n = 0

        # Row i of the strict lower triangle has i entries, all zero for now.
        self._data = [[0.0] * i for i in range(0, n)]

        if has_distances :
            for i in range(0, len(distances)) :
                row = distances[i]
                if len(row) < i :
                    print("Warning: Missing data from %s row" % (self._names[i]))
                # Entries below the diagonal are stored...
                for j in range(0, min(i, len(row))) :
                    self.set_distance(i, j, row[j])
                # ...while any diagonal/upper entries supplied are only checked.
                for j in range(i, min(n, len(row))) :
                    if i < j :
                        assert row[j] == distances[j][i], "Non symmetric distance"
                    else :
                        assert 0.0 == float(row[j]), "Non zero diagonal"

    def add(self, name, distances=None, allowDuplicateNames=True) :
        """Add an additional row/column to the distance matrix.

        If the matrix is currently N by N, then 'distances' should be a list of
        N distances (or N+1 including a final zero distance for this entry to
        itself).  If distances is omitted or empty, all zeros are used.

        Raises ValueError if name is already present and allowDuplicateNames
        is False."""
        if name in self._names and not allowDuplicateNames :
            raise ValueError("Repeated name: " + name)

        if distances is None or len(distances) == 0 :
            distances = [0] * len(self)

        if len(distances) == len(self)+1 :
            # Drop the optional trailing self-distance, which must be zero.
            assert float(distances[-1])==0.0
            distances = distances[:-1]
        assert len(self) == len(distances), "Have %i entries already, got %i new distances " % (len(self), len(distances))

        self._names.append(name)
        # Build a real list (map() is lazy on Python 3).
        self._data.append([float(d) for d in distances])

    def name(self, index) :
        """Returns the name of the entry at the given index"""
        return self._names[index]

    def names(self) :
        """Returns a COPY of the list of names"""
        return self._names[:]

    def __len__(self) :
        """Returns the number of rows/cols/names in the distance matrix"""
        assert len(self._names) == len(self._data)
        return len(self._names)

    def shape(self) :
        """The array-like property shape returns the tuple (n,n) where n is the number of rows/cols/names"""
        l = len(self)
        return (l,l)

    def array(self) :
        """Returns a COPY of the distance matrix as a list of lists of floats

        e.g. To convert to a Numeric array of floats, use Numeric.asarray(s.array(),'f')"""
        return [self.row(i) for i in range(0,len(self))]

    def get_distance(self, i, j) :
        """Read distances from the matrix

        Note object.get_distance(i,j) == object[i,j] however get_distance is
        slightly faster.  Diagonal entries are always 0.0."""
        assert 0 <= i < len(self._names)
        assert 0 <= j < len(self._names)
        if i > j :
            return self._data[i][j]
        elif i < j :
            return self._data[j][i]
        else :
            return 0.0

    def set_distance(self, i, j, dist) :
        """Set distances in the matrix

        Note that in addition to object.set_distance(i,j, dist) you can also
        use object[i,j]=dist however set_distance is slightly faster.
        Setting a diagonal entry is only allowed if dist is zero."""
        assert 0 <= i < len(self._names)
        assert 0 <= j < len(self._names)
        if i > j :
            self._data[i][j] = float(dist)
        elif i < j :
            self._data[j][i] = float(dist)
        else :
            assert dist==0

    def row(self, index) :
        """Returns a COPY of that row/column of the symmetric distance matrix"""
        return [self.get_distance(i,index) for i in range(0, len(self._names))]

    def col(self, index) :
        """Returns a COPY of that row/column of the symmetric distance matrix"""
        return self.row(index)

    def __getitem__(self, index):
        """Read a distance from the matrix; NOT SLICE AWARE

        Recommended syntax is object[i,j] which returns a single distance.

        Using object[i] returns a COPY of that row/column of the symmetric
        distance matrix, thus object[i][j] == object[i,j] but involves
        creating a temporary list.

        Note that object.get_distance(i,j) is slightly faster than object[i,j]"""
        try :
            n=len(index)
        except TypeError:
            #Happens for integers
            n=1

        if n==1 :
            return self.row(index)
        elif n==2 :
            return self.get_distance(index[0],index[1])
        else :
            raise IndexError("Invalid number of dimensions")

    def __setitem__(self, index, value):
        """Set a distance in the matrix; NOT SLICE AWARE

        Use the syntax object[i,j]=dist but note that object.set_distance(i,j,dist)
        is slightly faster.  Assigning a whole row via object[i]=... is not
        supported and raises IndexError."""
        try :
            n=len(index)
        except TypeError:
            #Happens for integers
            n=1

        if n==1 :
            raise IndexError("Not supported")
        elif n==2 :
            return self.set_distance(index[0],index[1], value)
        else :
            raise IndexError("Invalid number of dimensions")

    def __str__(self) :
        """One line per entry: the space padded name, then its row of
        distances to three decimal places.  An empty matrix gives ""."""
        if not self._names :
            return ""
        name_length = max([len(name) for name in self._names])
        lines = []
        for (i,name) in enumerate(self._names) :
            lines.append("%s [%s]" % (name.ljust(name_length), ", ".join(["%0.3f" % d for d in self.row(i)])))
        return "\n".join(lines)

    def __eq__(self, other) :
        """Are two distance matrices equal?

        They must have the same names, in the same order.  Plus, of course, the
        same distances.  Comparison with any other type is left to Python
        (NotImplemented), which makes it compare unequal."""
        if not isinstance(other, distanceMatrix) :
            return NotImplemented
        #First, the easy quick test
        if self._names != other._names : return False
        #Next, the hard test
        return self._data == other._data

    def __ne__(self, other) :
        """Are two distance matrices non-equal?"""
        return not (self == other)

    def asphylip(self, lowerTriangular=False, truncateNames=True, ambiguousWarning=False, sep="\t") :
        """Returns a string representing the distance matrix in PHYLIP format.

        If lowerTriangular=True, then only a lower triangular matrix is output - the diagonal
        and all upper triangular entries are omitted.  This roughly halves the resulting filesize.

        If lowerTriangular=False (the default) a full symmetric matrix is generated.

        If truncateNames=True (the default) then names are truncated and space filled to exactly
        ten characters only, in strict agreement with the PHYLIP definition.

        If truncateNames=False, then names are space filled to the length of the longest
        name (or 10 if longer), as done by Clustalw.

        If ambiguousWarning=True, a warning is printed when the (possibly truncated)
        output names are not unique.

        Finally, 'sep' is the field separator and defaults to a tab.  This is not explicitly
        specified in the file format definition.
        """
        #TODO - Option to specify the number of decimal places used for the distances?
        warning=False
        output=[sep+str(len(self))]
        output_names = []

        if truncateNames :
            name_length = 10
        else :
            # The leading 10 keeps this safe for an empty matrix.
            name_length = max([10] + [len(name) for name in self._names])
        for i in range(0, len(self)) :
            #Truncate name, then space pad:
            name = self._names[i][0:name_length].ljust(name_length) + sep
            if ambiguousWarning :
                if not warning :
                    warning = name in output_names
                    output_names.append(name)

            if lowerTriangular :
                data = self._data[i][0:i]
            else :
                data = self.row(i)
            output.append(name + sep.join(["%f" % d for d in data]))
        if ambiguousWarning and warning : print("Warning - resulting file has repeated names")
        return "\n".join(output)

def read_phylip_distance_matrix(handle, allowDuplicateNames=True) :
    """Read a PHYLIP format distance matrix from a handle, returning a distanceMatrix.

    The first line must hold the number of entries, n.  Each following
    non-blank line holds a name followed by whitespace separated distances;
    blank lines are ignored.  Names are taken as the first whitespace
    delimited word, so names containing spaces are not supported.

    NOTE - Will only look at the LOWER triangular entries generating a symmetric matrix.
    It is ASSUMED that the matrix in the file is symmetric as it should be.  If a full
    matrix file is supplied, we do check for a zero diagonal.

    The handle is closed before returning (also if an error occurs)."""
    try :
        # strip() removes the trailing newline/carriage return and any padding.
        n = int(handle.readline().strip())

        s = distanceMatrix()

        i=0
        for line in handle :
            parts = line.split()
            if not parts :
                continue
            i=i+1
            #Row i (1-based) has i-1 lower triangular entries then the diagonal.
            #Include the diagonal entry (if present) so we can check it is zero.
            #The add function will convert the strings to floats.
            s.add(parts[0], parts[1:(i+1)], allowDuplicateNames)
        assert i==n, "Number of lines, %i, did not match header %i" % (i,n)
    finally :
        handle.close()
    return s

def read_phylip_distance_matrix_from_file(filename, allowDuplicateNames=True) :
    """Read a PHYLIP format distance matrix from the named file.

    NOTE - Will only look at the LOWER triangular entries generating a symmetric matrix.
    It is assumed that the matrix in the file is symmetric as it should be.  If a full
    matrix file is supplied, we do check for a zero diagonal.

    allowDuplicateNames is passed on to read_phylip_distance_matrix (it was
    previously ignored).  The file is closed by read_phylip_distance_matrix."""
    return read_phylip_distance_matrix(open(filename,"r"), allowDuplicateNames)


if __name__ == "__main__" :
    # Self tests: running this module directly exercises distanceMatrix and the
    # PHYLIP reader/writer, stopping with an AssertionError on the first failure.
    print "Running self tests..."
    from StringIO import StringIO

    def test_write_read(s, truncateNames=True) :
        """Check that s survives a round trip through asphylip() and
        read_phylip_distance_matrix(), in both the lower triangular and the
        full symmetric output layouts."""
        assert s == read_phylip_distance_matrix(StringIO(s.asphylip(lowerTriangular=True,  truncateNames=truncateNames, ambiguousWarning=False))), \
               "Write/read in lower triangular form failed"
        assert s == read_phylip_distance_matrix(StringIO(s.asphylip(lowerTriangular=False, truncateNames=truncateNames, ambiguousWarning=False))), \
               "Write/read in full symmetric form failed"


    
    #############################################
    # Build matrix using lower triangular input #
    #############################################
    s = distanceMatrix()
    s.add("a")
    s.add("b",[0.1])
    s.add("c",[0.2,0.5])
    test_write_read(s)
    assert str(s) == "a [0.000, 0.100, 0.200]\nb [0.100, 0.000, 0.500]\nc [0.200, 0.500, 0.000]"

    print
    print "Simple 3x3 example:"
    print s


    ###########################################################
    # Build matrix using lower triangular input with diagonal #
    ###########################################################
    # The trailing zero on each row is the self-distance; add() drops it.
    s2 = distanceMatrix()
    s2.add("a",[0])
    s2.add("b",[0.1,0])
    s2.add("c",[0.2,0.5,0])
    assert s == s2

    ####################################
    # Test construction with distances #
    ####################################

    #Set up a few 3x3 distance matrix inputs...
    distance_expressions = [
        [[0, 0.1, 0.2],[0.1, 0, 0.5], [0.2, 0.5, 0]], # full symmetric
        [[0,],[0.1, 0], [0.2, 0.5, 0]],               # lower triangular with zero diagonal
        [[],[0.1], [0.2, 0.5]]                        # lower triangular
    ]

    # Optionally also check loading from a Numeric array, if Numeric is installed.
    try :
        import Numeric
        d = Numeric.array([[0, 0.1, 0.2],[0.1, 0, 0.5], [0.2, 0.5, 0]])
        distance_expressions.append(d)
    except ImportError:
        print "Not testing loading from a Numeric array"
        pass

    # Every input layout must give the same matrix as the one built with add().
    for d in distance_expressions :
        assert s == distanceMatrix(["a","b","c"], d)


    ###############################################
    # Test reading fully symmetric PHYLIP format  #
    ###############################################

    #Uses mixture of tabs and spaces to be as evil as possible!
    test_file_data = """     8
    V_Harveyi_PATH             0.000  0.524  0.697  0.710  0.643  0.700  0.671  0.734 
    B_subtilis_YXEM            0.524  0.000  0.672  0.663  0.671  0.692  0.639  0.729 
    B_subtilis_GlnH_homo_YCKK  0.697  0.672  0.000  0.516  0.622  0.710  0.664  0.749 
    YA80_HAEIN                 0.710  0.663  0.516  0.000  0.639  0.739  0.654  0.769   
    FLIY_ECOLI                 0.643  0.671  0.622  0.639  0.000  0.699  0.647  0.716  
    E_coli_GlnH                0.700  0.692  0.710  0.739  0.699  0.000  0.702  0.724  
    Deinococcus_radiodurans    0.671  0.639  0.664  0.654  0.647  0.702  0.000  0.663 
    HISJ_E_COLI                0.734  0.729  0.749  0.769  0.716  0.724  0.663  0.000"""
    # (Redundant: StringIO was already imported at the start of this block.)
    from StringIO import StringIO
    s = read_phylip_distance_matrix(StringIO(test_file_data))

    print
    print "Small 8x8 example with non-truncated names:"
    print s

    test_write_read(s, truncateNames=False)

    #Now compare results of writing and reading a truncated matrix...
    #Truncation to 10 characters makes the two "B_subtilis..." names identical.
    s_truncated = read_phylip_distance_matrix(StringIO(s.asphylip(lowerTriangular=False,  truncateNames=True,  ambiguousWarning=False)))

    print
    print "Small 8x8 example with truncated names:"
    print s_truncated

    test_write_read(s_truncated, truncateNames=True)
    test_write_read(s_truncated, truncateNames=False)

    #This uses two spaces as the separator, has no white space before the size in the header
    #AND has some blank lines at the end.
    test_truncated_file_data = """8   
    V_Harveyi_  0.000000  0.524000  0.697000  0.710000  0.643000  0.700000  0.671000  0.734000
    B_subtilis  0.524000  0.000000  0.672000  0.663000  0.671000  0.692000  0.639000  0.729000
    B_subtilis  0.697000  0.672000  0.000000  0.516000  0.622000  0.710000  0.664000  0.749000
    YA80_HAEIN  0.710000  0.663000  0.516000  0.000000  0.639000  0.739000  0.654000  0.769000
    FLIY_ECOLI  0.643000  0.671000  0.622000  0.639000  0.000000  0.699000  0.647000  0.716000
    E_coli_Gln  0.700000  0.692000  0.710000  0.739000  0.699000  0.000000  0.702000  0.724000
    Deinococcu  0.671000  0.639000  0.664000  0.654000  0.647000  0.702000  0.000000  0.663000
    HISJ_E_COL  0.734000  0.729000  0.749000  0.769000  0.716000  0.724000  0.663000  0.000000


    """
    assert s_truncated == read_phylip_distance_matrix(StringIO(test_truncated_file_data))

    #######################
    # Test element access #
    #######################
    def element_check(s) :
        """Check that s[i][j], s[i,j] and get_distance agree, that the matrix
        is symmetric with a zero diagonal, and that row() == col()."""
        for i in range(0,len(s)) :
            assert s[i][i] == s[i,i] == 0
            assert s[i] == s.row(i) == s.col(i)
            for j in range(0,len(s)) :
                assert s[i][j] == s[i,j] == s.get_distance(i,j) == s.get_distance(j,i)

    element_check(s)
    s[0,1]=0.3333
    element_check(s)
    s.set_distance(0,1,0.66666)
    element_check(s)


    #Note this behaviour: s[0] returns a COPY of the row (see row()), so
    #assigning into that copy leaves the matrix itself unchanged.
    s[0,1]=0.3333
    s[0][1]=0.6666 #For now, this has no effect!!! (modifies a temporary list)
    assert s[0,1]==0.3333

    print
    print "Tests passed"

