# Usage: python buildRmatrix.py parameter_for_buildRmatrix.txt
#        Edit the variable TPM_TABLE for a different output file name.
#        Watch out for NA values in TPM.txt; these genes have no gene expression information.
#
# Purpose: make a TPM table, where each row is a gene, and each column is an experiment.  The column name is RNA-seq experiment ID.
#
# 23 Dec 2016, hui, slcu
# Last modified 5 Apr 2017, hui, slcu
# Last modified 25 Oct 2019, hui, zjnu [Comments; add a variable WARN_NA to turn on/off print NA warnings.]

import os, sys, glob

# Output path of the gene-by-experiment TPM matrix (converted to an absolute path in main).
TPM_TABLE = '../Data/history/expr/TPM.txt'
# Set to True to print one warning for every (gene, experiment) pair that is written as NA.
WARN_NA   = False

####################################
# Line prefixes recognised in the parameter file.
GLB_PARAM_SYMBOL    = '%%'  # global parameter line: TAB-separated KEY=VALUE items, e.g. GENE_LIST=...
LCL_PARAM_SYMBOL    = '%'   # local (per-record) parameter line
DATA_SYMBOL         = '@'   # starts a new RNA-seq record; the experiment ID follows the symbol
####################################

def common_part(s):
    '''
    Return the gene-level part of a transcript ID.

    Surrounding whitespace is removed, then everything from the first '.'
    onwards is dropped, e.g. 'AT1G01020.1' -> 'AT1G01020'.  An ID that has
    no '.' is returned as-is (apart from the whitespace stripping).
    '''
    gene_part, _dot, _suffix = s.strip().partition('.')
    return gene_part


def make_expression_dict(fname, myid):
    '''
    Parse a Salmon quant file into per-gene lists of isoform TPM values.

    fname -- Salmon quant file.  Its first line is a header
             (Name Length EffectiveLength TPM NumReads) and is skipped.
    myid  -- RNA-seq experiment ID, stored under the 'ID' key.

    The returned value is a dictionary which looks like

    {
      'ID': RNA-seq experiment ID,
      'isoform':
         {
           'AT1G12345': [tpm_of_isoform_1, tpm_of_isoform_2, ...],
           'AT2G12345': [...],
           ...
         }
    }

    Transcript IDs are collapsed to gene IDs with common_part, so e.g.
    AT1G12345.1 and AT1G12345.2 both contribute a TPM to 'AT1G12345'.
    Each gene therefore has one TPM per isoform, in file order.
    Blank lines (such as a trailing empty line) are ignored.

    If fname does not exist, an error is printed and the program exits.
    '''
    ID_COL  = 0 # Salmon's quant.sf file, first column is the transcript ID (e.g. AT1G01020.1)
    TPM_COL = 3 # Salmon's quant.sf file, fourth column is TPM

    if not os.path.exists(fname):
        print('ERROR [buildRmatrix.py]: file %s not exists.' % (fname))
        sys.exit()

    d = {'ID':myid, 'isoform':{}}

    with open(fname) as f:
        next(f, None)  # ignore head line, Name Length EffectiveLength TPM NumReads
        for line in f:
            lst = line.split()
            if not lst:  # blank line: previously crashed with IndexError
                continue
            common = common_part(lst[ID_COL])  # gene id without .1, .2, etc.
            tpm = float(lst[TPM_COL])
            d['isoform'].setdefault(common, []).append(tpm)

    return d


def get_max_expressed_isoform(g, d):
    '''
    Return the highest isoform TPM of gene g in expression dictionary d.

    d is shaped like the output of make_expression_dict.  When g has no
    entry in d['isoform'], the sentinel -9 is returned instead.
    '''
    isoform_tpms = d['isoform']
    return max(isoform_tpms[g]) if g in isoform_tpms else -9
    

def save_TPM_table(gene_lst, dict_lst, fname):
    '''
    Write a tab-separated gene-by-experiment TPM matrix to fname.

    gene_lst: a list of gene IDs; one output row per gene, in this order.
    dict_lst: a list of dictionaries as built by make_expression_dict, i.e.
              {'ID': experiment ID, 'isoform': {gene ID: [isoform TPMs]}}.
              One output column per dictionary, in this order.
    fname: where the gene expression level matrix will be saved.  A missing
           parent directory is created.

    The header row is 'gene_id' followed by every experiment ID.  Each cell
    holds the gene's largest isoform TPM in that experiment (formatted with
    '%4.2f'), or 'NA' when the gene is absent from that experiment.  A
    summary warning is printed if any NA was written.  If dict_lst is empty,
    a message is printed and the program exits.
    '''
    dir_name = os.path.dirname(fname)
    # dirname is '' for a bare file name, and os.makedirs('') would raise.
    if dir_name and not os.path.isdir(dir_name):
        os.makedirs(dir_name)

    if len(dict_lst) == 0:
        print('buildRmatrix.py: dict_lst is empty. Nothing to build.')
        sys.exit()

    total_count = 0       # number of total gene expression levels
    bad_count = 0         # number of NA gene expression levels.  We wish this number to be far smaller than total_count.
    missed_genes = set()  # genes that are NA in at least one experiment

    with open(fname, 'w') as f:
        # d['ID'] is the RNA-seq sample's SRA id
        head = ['gene_id'] + [d['ID'] for d in dict_lst]
        f.write('%s\n' % ('\t'.join(head)))
        for g in gene_lst:
            cells = [g]
            for d in dict_lst:
                v = get_max_expressed_isoform(g, d)
                total_count += 1
                if v != -9:
                    cells.append('%4.2f' % (v))
                else:
                    if WARN_NA:
                        print('WARNING [buildRmatrix.py]: %s not in %s.' % (g, d['ID']))
                    cells.append('NA')
                    bad_count += 1
                    missed_genes.add(g)
            f.write('%s\n' % ('\t'.join(cells)))

    # Test bad_count directly; the old ratio test divided by zero for an empty gene_lst.
    if bad_count > 0:
        print('WARNING [buildRmatrix.py]: %s contains NA values!\n%d out of %d gene expression levels (%4.1f percent) are NAs.\n%d gene IDs are in your gene list but not in the results output by Salmon.' % (fname, bad_count, total_count, 100.0* bad_count/total_count, len(missed_genes)))


def get_dict_list(d):
    '''
    Load one expression dictionary per RNA-seq experiment.

    d is the dictionary produced by make_data_dict.  IDs in d['ID_LIST'] are
    visited in order.  Every ID that also has its own entry in d is loaded
    from d[ID]['LOCATION'] via make_expression_dict.  Returns the list of
    loaded dictionaries.
    '''
    return [make_expression_dict(d[run_id]['LOCATION'], run_id)
            for run_id in d['ID_LIST']
            if run_id in d]


def get_gene_list(fname):
    '''
    Read gene IDs from the text file fname.

    The first whitespace-separated token of every non-blank line is taken as
    a gene ID; any further columns are ignored.  Returns the IDs in file order.
    The file is closed even if reading fails.
    '''
    with open(fname) as f:
        return [line.split()[0] for line in f if line.strip()]


def get_key_value(s):
    '''
    Split a 'KEY=VALUE' string into the tuple (KEY, VALUE).

    Only the first '=' separates key from value, so a value that itself
    contains '=' is kept whole (previously it was silently truncated).
    No whitespace is stripped.  Raises IndexError if s contains no '='.
    '''
    lst = s.split('=', 1)
    k, v = lst[0], lst[1]
    return (k, v)


def get_value(s, delimit):
    '''
    Return the text after the first occurrence of delimit in s, stripped.

    E.g. get_value('LOCATION: /a/b.txt', ':') returns '/a/b.txt'.  Multi-
    character delimiters are supported.  If delimit does not occur in s,
    an error is printed and the program exits.
    '''
    index = s.find(delimit)
    if index < 0:
        print('ERROR [buildRmatrix.py]: "%s" not found in line: %s' % (delimit, s))
        sys.exit()
    # Skip the whole delimiter, not just one character.
    return s[index + len(delimit):].strip()


def make_data_dict(fname):
    '''
    fname - parameter_for_buildRmatrix.txt

    Parse the RNA-seq records of the parameter file.  A line starting with
    DATA_SYMBOL ('@') starts a new record whose ID is the text after the
    symbol.  Following 'DESCRIPTION:', 'DATA_NAME:', 'DATA_FORMAT:',
    'LOCATION:' and 'NOTE:' lines fill that record.  Blank lines and lines
    starting with '#' are skipped.

    Return a dictionary which looks like

    {
      'ID_LIST': [],
      'SRR1':
        {
          'DATA_NAME': '', 'DATA_FORMAT': '', 'DESCRIPTION': '', 'NOTE': '',
          'LOCATION': path to the salmon quant file, e.g., /home/lanhui/brain/Data/R/Mapped/public/SRR953400_quant.txt
        }
    }

    A duplicated ID prints a message and exits the program.
    '''
    # (line prefix, record key) pairs for the per-record field lines.
    fields = (('DESCRIPTION:', 'DESCRIPTION'),
              ('DATA_NAME:', 'DATA_NAME'),   # was wrongly matched on 'DATA_FORMAT:'
              ('DATA_FORMAT:', 'DATA_FORMAT'),
              ('LOCATION:', 'LOCATION'),
              ('NOTE:', 'NOTE'))

    d = {'ID_LIST':[]} # ID_LIST is a list of RNA-seq experiment IDs
    s = None           # ID of the current record; field lines are attached to it
    with open(fname) as f:
        lines = f.readlines()
    for line in lines:
        line = line.strip()
        if line == '' or line.startswith('#'):
            continue
        if line.startswith(DATA_SYMBOL):
            s = line[line.rfind(DATA_SYMBOL[-1])+1:].strip()
            if s in d:
                print('Warning [buildRmatrix.py]: ID %s is duplicated.' % (s))
                sys.exit()
            d[s] = {'DATA_NAME':'', 'DATA_FORMAT':'', 'DESCRIPTION':'', 'LOCATION':'', 'NOTE':''}
            d['ID_LIST'].append(s)
            continue
        for prefix, key in fields:
            if line.startswith(prefix):
                if s is None:
                    # Previously an unbound-variable crash: no record to attach to yet.
                    print('WARNING [buildRmatrix.py]: %s appears before any %s record; ignored.' % (line, DATA_SYMBOL))
                else:
                    d[s][key] = get_value(line, ':')
                break
        # Lines starting with a single LCL_PARAM_SYMBOL (local parameters) used to
        # be passed to make_local_parameter(d[s]['PARAM'], line).  That function
        # is not defined in this script and records have no 'PARAM' key, so any
        # such line crashed.  Nothing here reads local parameters, so they
        # (and '%%' global lines, handled by make_global_param_dict) are ignored.

    return d


def make_global_param_dict(fname):
    '''
    Collect the global parameters from the parameter file fname.

    A global-parameter line starts with GLB_PARAM_SYMBOL ('%%').  The text
    after the last '%' on the line is split on TABs, and every non-empty item
    is parsed as KEY=VALUE with get_key_value.  Returns a dict of these
    parameters; 'GENE_LIST' defaults to '' when the file does not set it.
    The file is closed even if parsing fails.
    '''
    d = {'GENE_LIST':''} # GENE_LIST: gene-ID file read by main; '' if not given
    with open(fname) as f:
        for line in f:
            line = line.strip()
            if not line.startswith(GLB_PARAM_SYMBOL):
                continue
            # NOTE(review): rfind takes the LAST '%' on the line, so a value that
            # contains '%' would be cut short -- confirm values never contain '%'.
            s = line[line.rfind(GLB_PARAM_SYMBOL[-1])+1:]
            for x in s.split('\t'):  # separate items by TAB
                if x != '':
                    k, v = get_key_value(x)
                    d[k] = v
    return d

## main
# Expects exactly one argument: the parameter file with '%%' global
# parameters (GENE_LIST=...) and '@' RNA-seq records.
if len(sys.argv) < 2:
    print('Usage: python buildRmatrix.py parameter_for_buildRmatrix.txt')
    sys.exit(1)
param_file = sys.argv[1]
global_param_dict = make_global_param_dict(param_file)
if global_param_dict['GENE_LIST'] == '':
    # Without this check open('') fails with an unhelpful FileNotFoundError.
    print('ERROR [buildRmatrix.py]: no GENE_LIST given in %s.' % (param_file))
    sys.exit(1)
data_dict = make_data_dict(param_file)
TPM_TABLE = os.path.abspath(TPM_TABLE)
save_TPM_table(get_gene_list(global_param_dict['GENE_LIST']), get_dict_list(data_dict), TPM_TABLE)
#print('Done.  Check %s.' % (TPM_TABLE))