"""QTL files output from QTL Parser with two columns added at the end for expression and probeQuality value

Output file format is
pCutoffA, pCutoffB, maxDist, cisDist, minExp, maxExp, probeQCutoff
nTotal, nCommon, nNotCommon, nCisCommon, nTransCommon, nCisNotCommon, nTransNotCommon, aveCisDist, medianCisDist, aveNQTLs

Note:  The number of common QTLs may sometimes differ between sets.  This occurs when there are two QTLs in one data set within maxDist
of one QTL in the second data set.  The actual number of common QTLs should be taken to be the smaller of the two values, since this preserves
one-to-one correspondence between QTLs in the two data sets.  This will never happen when maxDist < 1/2 * (distance originally used to choose
the best QTL peak), and will very seldom happen when cisDist is set reasonably low, because QTLs must operate in the same
mode (both cis or both trans) to be considered the same QTL.

Output consists of 41 diagnostic fields, reproduced as a tab-delimited line for easy copying into Excel etc.:
file	pCutoffA	pCutoffB	LRSCutoffA	LRSCutoffB	maxDist	cisDistA	cisDistB	minExp	maxExp	probeQCutoff	minP	nTotal	nCommon	nNotCommon	nCisCommon	nTransCommon	nCisNotCommon	nTransNotCommon	aveCisDist	stDevCisDist	medianCisDist	aveNQTLs	posAdd	negAdd	posAddCommon	negAddCommon	posAddCommonCis	posAddCommonTrans	negAddCommonCis	negAddCommonTrans	1 QTL	2 QTL	3 QTL	4 QTL	5 QTL	6 QTL	7 QTL	8 QTL	9 QTL	10 QTL


experiment parameters:
file	pCutoffA	pCutoffB	LRSCutoffA	LRSCutoffB	maxDist	cisDistA	cisDistB	minExp	maxExp	probeQCutoff	minP

total and common QTL measures
nTotal	nCommon	nNotCommon	nCisCommon	nTransCommon	nCisNotCommon	nTransNotCommon

distances from peak marker to transcript location for accuracy measures
aveCisDist	stDevCisDist	medianCisDist

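average QTL number and sign of additive effect (in BXDs and B6D2F2s, B6 effect is negative); common QTLs are also split into cis and trans
aveNQTLs	posAdd	negAdd	posAddCommon	negAddCommon	posAddCommonCis	posAddCommonTrans	negAddCommonCis	negAddCommonTrans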

histogram of number of QTLs
1 QTL	2 QTL	3 QTL	4 QTL	5 QTL	6 QTL	7 QTL	8 QTL	9 QTL	10 QTL

"""

from stats import lmean, lmedian, lsamplestdev

class QTLComparer(object):
    """Compare the QTLs found for the same probe sets in two QTL files.

    flowControl() runs the whole pipeline: load and filter both files
    (loadQTLFile), flag the QTLs shared between the two sets (Comparer),
    summarise each set (makeStats) and append one tab-delimited summary line
    per file to a statistics file (makeStatsOutput).
    """

    def loadQTLFile(self, fileAsList, pCutoff, LRSCutoff, cisDist, minExp, maxExp, probeQCutoff, minP):
        """Filter parsed QTL rows and group the surviving QTLs by probe set.

        Arguments:
        fileAsList     -- rows as returned by splitFileIntoLists; column 0 is the
                          probe set name, and columns 1-11, 16 and 17 are passed to QTL().
        pCutoff, minP  -- keep a QTL only if minP <= P <= pCutoff.
        LRSCutoff      -- keep a QTL only if LRS >= LRSCutoff.
        cisDist        -- a QTL on its gene's chromosome whose peak lies strictly
                          less than cisDist Mb from the gene is marked cis
                          (qTrans = 0); every other QTL is marked trans (qTrans = 1).
        minExp, maxExp -- keep a QTL only if minExp <= exp <= maxExp.
        probeQCutoff   -- keep a QTL only if probesetQ >= probeQCutoff.

        Returns a dict mapping probe set name -> list of QTL objects (qCommon = 0).
        Empty rows, such as blank lines in the input file, are skipped.
        """
        QTLDictionary = {}

        for li in fileAsList:
            if not li:
                continue  # blank line: nothing to parse
            probeSet = li[0]
            qtlEntry = QTL(li[1], li[2], li[3], li[4], li[5], li[6], li[7], li[8],
                           li[9], li[10], li[11], li[16], li[17], 0, -1)

            if not (minP <= qtlEntry.P <= pCutoff
                    and qtlEntry.LRS >= LRSCutoff
                    and minExp <= qtlEntry.exp <= maxExp
                    and qtlEntry.probesetQ >= probeQCutoff):
                continue

            # Probe-set name matching between M430v2.0 and M430A/M430B arrays is
            # intentionally disabled. To re-enable it, truncate probeSet at the
            # first "_A" and then at the first "_B".

            # qTrans = 0 marks a cis QTL, 1 a trans QTL.
            if qtlEntry.gChr == qtlEntry.qChr and abs(qtlEntry.qMb - qtlEntry.gMb) < cisDist:
                qtlEntry.qTrans = 0
            else:
                qtlEntry.qTrans = 1

            QTLDictionary.setdefault(probeSet, []).append(qtlEntry)

        return QTLDictionary

    @staticmethod
    def _isSameQTL(qtlA, qtlB, maxDist):
        """Return True if qtlA and qtlB count as the same QTL (criteria in Comparer)."""
        # Additive effects must share a strict sign; add == 0 never matches.
        sameAdd = (qtlA.add > 0 and qtlB.add > 0) or (qtlA.add < 0 and qtlB.add < 0)
        # Both must be cis (0) or both trans (1).
        sameMode = qtlA.qTrans == qtlB.qTrans and qtlA.qTrans in (0, 1)
        return (qtlA.qChr == qtlB.qChr
                and abs(qtlB.qMb - qtlA.qMb) <= maxDist
                and sameAdd
                and sameMode)

    def Comparer(self, setA, setB, maxDist):
        """Flag the QTLs that are common to setA and setB.

        For every probe set present in both dicts, each pair (QTL from A, QTL
        from B) is tested. The pair is the same QTL when:
        - both peaks are on the same chromosome, at most maxDist Mb apart;
        - the additive effects have the same non-zero sign;
        - both QTLs are cis, or both are trans.
        Both members of a matching pair get qCommon = 1. One QTL may match
        several in the other set, so the common counts of the two sets can differ.

        The QTL objects are updated in place; [setA, setB] is returned.
        """
        for probeSet, qtlsA in setA.items():
            qtlsB = setB.get(probeSet)
            if not qtlsB:
                continue  # no QTLs for this probe set in set B
            for qtlA in qtlsA:
                for qtlB in qtlsB:
                    if self._isSameQTL(qtlA, qtlB, maxDist):
                        qtlA.qCommon = 1
                        qtlB.qCommon = 1
        return [setA, setB]

    def makeStats(self, set):
        """Summarise one QTL set after Comparer has run on it.

        set -- dict mapping probe set -> list of QTL objects. qCommon is 1 for
               common and 0 otherwise; qTrans is 0 for cis and 1 for trans.

        Returns a list. Element 0 is the 10-entry histogram of QTLs per probe
        set: strings giving the number of probe sets with exactly 1..10 QTLs.
        It is followed by, in order:
            nTotal, nCommon, nNotCommon, nCisCommon, nTransCommon,
            nCisNotCommon, nTransNotCommon,
            aveCisDist, stdevCisDist, medianCisDist, aveNQTLs,
            posAdd, negAdd, posAddCommon, negAddCommon,
            posAddCommonCis, posAddCommonTrans, negAddCommonCis, negAddCommonTrans.
        The cis-distance statistics are -1 when there are no cis QTLs.
        aveNQTLs is -1 when the set is empty.
        """
        nTotal = 0
        nCommon = 0
        nNotCommon = 0
        nCisCommon = 0
        nTransCommon = 0
        nCisNotCommon = 0
        nTransNotCommon = 0

        # Sign of the additive effect, overall and for common QTLs.
        # NOTE(review): an older inline comment called add > 0 "B6 high", but the
        # module docstring says the B6 effect is negative in BXDs/B6D2F2s. Confirm.
        posAdd = 0
        negAdd = 0
        posAddCommon = 0
        negAddCommon = 0
        posAddCommonCis = 0
        posAddCommonTrans = 0
        negAddCommonCis = 0
        negAddCommonTrans = 0

        cisQTLDist = []  # |gene Mb - peak Mb| for every cis QTL
        nQTLs = []       # number of QTLs per probe set, as floats

        for key in set:
            qtls = set[key]
            for qtl in qtls:
                nTotal += 1
                isCis = qtl.qTrans == 0
                isTrans = qtl.qTrans == 1

                # Common vs. not common, each split into cis and trans.
                if qtl.qCommon == 1:
                    nCommon += 1
                    if isTrans:
                        nTransCommon += 1
                    if isCis:
                        nCisCommon += 1
                if qtl.qCommon == 0:
                    nNotCommon += 1
                    if isTrans:
                        nTransNotCommon += 1
                    if isCis:
                        nCisNotCommon += 1

                # QTLs with add == 0 count as neither positive nor negative.
                if qtl.add > 0:
                    posAdd += 1
                    if qtl.qCommon == 1:
                        posAddCommon += 1
                        if isCis:
                            posAddCommonCis += 1
                        else:
                            posAddCommonTrans += 1
                if qtl.add < 0:
                    negAdd += 1
                    if qtl.qCommon == 1:
                        negAddCommon += 1
                        if isCis:
                            negAddCommonCis += 1
                        else:
                            negAddCommonTrans += 1

                if isCis:
                    cisQTLDist.append(abs(qtl.gMb - qtl.qMb))
            nQTLs.append(float(len(qtls)))

        # Peak-to-transcript distance statistics for cis QTLs.
        if cisQTLDist:
            aveCisDist = lmean(cisQTLDist)
            medianCisDist = lmedian(cisQTLDist)
            stdevCisDist = lsamplestdev(cisQTLDist)
        else:
            print('no cis QTLs')
            # -1 is the "not available" sentinel. stdevCisDist used to be left
            # unbound here, which crashed the return statement below.
            aveCisDist = -1
            medianCisDist = -1
            stdevCisDist = -1

        # Average number of QTLs per probe set and the QTL-count histogram.
        if nQTLs:
            aveNQTLs = lmean(nQTLs)
            histQTLs = histogram(nQTLs)
        else:
            print('no QTLs')
            aveNQTLs = -1  # previously left unbound, crashing the return
            histQTLs = {}
        # histQTLs keys are floats, but 1 == 1.0 hashes alike, so integer lookups match.
        # Always emit 10 columns so every output row has the same width.
        qtlFreqList = [str(histQTLs.get(i, 0)) for i in range(1, 11)]

        return [qtlFreqList, nTotal, nCommon, nNotCommon, nCisCommon, nTransCommon, nCisNotCommon, nTransNotCommon, aveCisDist, stdevCisDist, medianCisDist, aveNQTLs, posAdd, negAdd, posAddCommon, negAddCommon, posAddCommonCis, posAddCommonTrans, negAddCommonCis, negAddCommonTrans]

    def makeStatsOutput(self, parameters, file, stats, outFile):
        """Print one tab-delimited summary line and append it to outFile.

        The line contains, in order: file, the stringified parameters, the
        stringified stats[1:], and then the histogram strings in stats[0] (as
        produced by makeStats).
        """
        paramText = "\t".join(map(str, parameters))
        statText = "\t".join(map(str, stats[1:]))
        histText = "\t".join(stats[0])
        output_print = "\t".join([file, paramText, statText, histText])
        print(output_print)
        # Append mode: successive runs accumulate rows in the same stats file.
        with open(outFile, mode='a') as handle:
            handle.write(output_print + "\n")

    def makeQTLOutput(self, set, fileName):
        """Append every QTL in set to fileName, one tab-delimited line per QTL.

        Columns: probe set, g, gChr, gMb, gGMb, gMarker, LRS, add, P, qChr, qMb,
        qGMb, exp, probesetQ, qCommon, qTrans.  An empty set writes nothing
        and does not create the file.
        """
        lines = []
        for key in set:
            for qtl in set[key]:
                fields = [key, qtl.g, qtl.gChr, qtl.gMb, qtl.gGMb, qtl.gMarker,
                          qtl.LRS, qtl.add, qtl.P, qtl.qChr, qtl.qMb, qtl.qGMb,
                          qtl.exp, qtl.probesetQ, qtl.qCommon, qtl.qTrans]
                lines.append("\t".join(map(str, fields)) + "\n")
        if not lines:
            return
        # Use a single handle that is guaranteed to close. The old code reopened
        # the file for every QTL and never closed it (`file.close` lacked its
        # call parentheses).
        with open(fileName, mode='a') as handle:
            handle.writelines(lines)

    def flowControl(self, fileA, fileB, maxDist, pCutoffA, pCutoffB, LRSCutoffA, LRSCutoffB, cisDistA, cisDistB, minExp, maxExp, probeQCutoff, outStatsFile, minP):
        """Compare QTL files fileA and fileB and append a stats line for each to outStatsFile.

        pCutoff*, LRSCutoff* and cisDist* apply to their own file.
        minExp, maxExp and probeQCutoff apply to both files.
        minP applies to fileA only; set B is loaded with minP = 0, so confirmation
        in B is bounded only from above.
        """
        setA_asList = splitFileIntoLists(fileA)
        setB_asList = splitFileIntoLists(fileB)

        setA = self.loadQTLFile(setA_asList, pCutoffA, LRSCutoffA, cisDistA, minExp, maxExp, probeQCutoff, minP)
        setB = self.loadQTLFile(setB_asList, pCutoffB, LRSCutoffB, cisDistB, minExp, maxExp, probeQCutoff, 0)

        comparedSetA, comparedSetB = self.Comparer(setA, setB, maxDist)

        statsA = self.makeStats(comparedSetA)
        statsB = self.makeStats(comparedSetB)

        parameters = [pCutoffA, pCutoffB, LRSCutoffA, LRSCutoffB, maxDist, cisDistA, cisDistB, minExp, maxExp, probeQCutoff, minP]

        self.makeStatsOutput(parameters, fileA, statsA, outStatsFile)
        self.makeStatsOutput(parameters, fileB, statsB, outStatsFile)

        # Per-QTL dumps, disabled:
        # self.makeQTLOutput(comparedSetA, "out" + fileA)
        # self.makeQTLOutput(comparedSetB, "out" + fileB)
        


class QTL:
    """One QTL record read from a QTL-parser output row.

    Attributes prefixed with g describe the probe set's gene/transcript
    position; those prefixed with q describe the QTL peak. Numeric fields are
    converted to float on construction; identifier-like fields are stored as given.
    """

    def __init__(self, g, gChr, gMb, gGMb, gMarker, LRS, add, P, qChr, qMb, qGMb, exp, probesetQ, qCommon, qTrans):
        self.g = g                            # gene field, stored unconverted
        self.gChr = gChr                      # gene chromosome (compared to qChr for cis calls)
        self.gMb = float(gMb)                 # gene position, Mb
        self.gGMb = float(gGMb)               # second gene position column; units unverified
        self.gMarker = gMarker                # stored unconverted
        self.LRS = float(LRS)                 # likelihood-ratio statistic
        self.add = float(add)                 # additive effect; its sign is compared between sets
        self.P = float(P)                     # P value
        self.qChr = qChr                      # QTL peak chromosome
        self.qMb = float(qMb)                 # QTL peak position, Mb
        self.qGMb = float(qGMb)               # second peak position column; units unverified
        self.exp = float(exp)                 # expression value
        self.probesetQ = float(probesetQ)     # probe-set quality measure
        self.qCommon = qCommon                # 1 once matched to a QTL in the other set, else 0
        self.qTrans = qTrans                  # 0 = cis, 1 = trans (-1 before classification)

    def __repr__(self):
        # Readable form so that dumping a {probeSet: [QTL, ...]} dict while
        # debugging shows the QTL data rather than object addresses.
        return ("QTL(g=%r, gChr=%r, gMb=%r, qChr=%r, qMb=%r, LRS=%r, add=%r, P=%r, "
                "qCommon=%r, qTrans=%r)" % (self.g, self.gChr, self.gMb, self.qChr, self.qMb,
                                            self.LRS, self.add, self.P, self.qCommon, self.qTrans))



def splitFileIntoLists(filename):
    """Read a whitespace-delimited text file into a list of token rows.

    Each element is line.split() for one line of the file, so blank lines
    become empty lists. The file is closed even if reading raises.
    """
    with open(filename) as a_file:
        return [line.split() for line in a_file]

    
def histogram(A, flAsList=False):
    """Count how many times each distinct value occurs in A.

    Adapted from an example posted by Samuel Reynolds (2004-07-08) on the
    ActiveState Programmer Network.

    A        -- any iterable of hashable values.
    flAsList -- if true, return a list of (value, count) pairs instead of a dict.

    Returns a dict mapping value -> count, or the list described above.
    """
    H = {}
    for val in A:
        H[val] = H.get(val, 0) + 1
    if flAsList:
        # list() so the result really is a list on Python 3, where items() is a view.
        return list(H.items())
    return H


if __name__ == "__main__":
    # The two QTL files to compare (QTL parser output with expression and
    # probe-quality columns appended).
    fileA, fileB = "IBR_M_0405_P_1000000.top", "BRF2_M_0805_P_1000000.top"

    # Peak-to-peak distance (Mb) within which two QTLs count as the same QTL.
    maxDist = 10

    # Marker-to-gene distance (Mb) within which a QTL counts as cis, per set.
    cisDistA, cisDistB = 5, 10

    # Accepted expression range (inclusive).
    minExp, maxExp = 0, 100

    # P-value upper bounds for each set.
    pCutoffA = pCutoffB = 0.05
    # Lower P bound, applied to fileA only; fileB uses 0 because confirmation
    # should only be limited by an upper bound.
    minP = 0.0

    # LRS lower bounds for each set.
    LRSCutoffA = LRSCutoffB = 0

    # Probe-quality threshold; not yet used, so left at 0.
    probeQCutoff = 0

    # Statistics are appended to this file.
    outStatsFile = "outStats.txt"

    QTLComparer().flowControl(
        fileA=fileA,
        fileB=fileB,
        maxDist=maxDist,
        pCutoffA=pCutoffA,
        pCutoffB=pCutoffB,
        LRSCutoffA=LRSCutoffA,
        LRSCutoffB=LRSCutoffB,
        cisDistA=cisDistA,
        cisDistB=cisDistB,
        minExp=minExp,
        maxExp=maxExp,
        probeQCutoff=probeQCutoff,
        outStatsFile=outStatsFile,
        minP=minP,
    )

