# Usage
# -----
# python get_TPM_by_salmon.py file-containing-paths-to-fastq-gz-files
# An example of file-containing-paths-to-fastq-gz-files is gz_files.txt.
# Edit the first few capitalised variables in this file.
#
# Purpose
# ------
# 
# Build index, get TPM values for all fastq files.
#
# 30 NOV 2016, SLUC, hui
# Last reviewed 31 July 2018
# Last modified by Hui 10 Sep 2019

import sys, os, glob, shutil
from configure import SALMON, SALMON_INDEX, TRANSCRIPTOME, SALMON_MAP_RESULT_DIR, KMER

#TRANSCRIPTOME   = '/home/hui/tair10/AtRTD2_19April2016.fa'
#TRANSCRIPTOME   = '../Data/information/ath_genes_index_v2.fa'

def build_salmon_index(transcriptome_file, salmon_index_dir, k):
    ''' Build a salmon index in salmon_index_dir unless that directory already exists.

    transcriptome_file -- path passed to `salmon index -t`.
    salmon_index_dir   -- directory passed to `salmon index -i`; if it already
                          exists, nothing is done.
    k                  -- integer passed to `salmon index -k`.

    Returns None.  The exit status of the salmon command is not checked.
    '''
    # Check the directory given by the caller, not the global SALMON_INDEX,
    # so the function honours its own parameter.
    if os.path.exists(salmon_index_dir):
        return

    os.makedirs(salmon_index_dir)
    cmd = '%s index -t %s -i %s --type quasi -k %d' % (SALMON, transcriptome_file, salmon_index_dir, k)
    os.system(cmd)


def assert_file_exist(s):
    ''' Terminate the program if the path s does not exist.

    Prints a message naming the missing path and raises SystemExit with
    exit status 1.  Returns None when the path exists.
    '''
    if not os.path.exists(s):
        print('File %s not exists.' % (s))
        # A missing input is a failure: exit with a non-zero status so that
        # the shell or a calling pipeline does not mistake it for success.
        sys.exit(1)
    

def salmon_fatal_error(fname):
    ''' Return True iff the file fname exists and one of its lines contains
    the text "I won't proceed".

    A missing file is reported as False (no fatal error found).
    '''
    if not os.path.exists(fname):
        return False
    # The with-block closes the file even if reading fails; the scan stops
    # at the first matching line instead of loading the whole log first.
    with open(fname) as f:
        return any('I won\'t proceed' in line for line in f)


def get_TPM(src_dir, file_id, salmon_index, result_dir):
    ''' Run `salmon quant` for one sample and keep its quant.sf as
    result_dir/<file_id>_quant.txt.

    src_dir      -- directory searched for the sample's fastq files.
    file_id      -- file-name prefix of the sample; files are found by
                    globbing <file_id>*.fastq* and <file_id>*_*.fastq.gz.
    salmon_index -- index directory passed to `salmon quant -i`.
    result_dir   -- output directory (created if missing).

    Two or more <file_id>*_*.fastq.gz files: the first two (sorted) are used
    as the -1/-2 mates.  Otherwise, exactly one <file_id>*.fastq* file: it is
    quantified with -r.  Any other combination prints a warning and
    returns without running salmon.

    salmon writes to result_dir/<file_id>_transcript_quant.  If quant.sf
    is produced there, it is copied out (unless the salmon log reports a
    fatal error) and the working directory is removed; if quant.sf is not
    produced, the working directory is left in place.  Returns None.
    '''
    lst = sorted(glob.glob(os.path.join(src_dir, file_id + '*.fastq*')))
    lst2 = sorted(glob.glob(os.path.join(src_dir, file_id + '*_*.fastq.gz')))  # _1.fastq.gz and _2.fastq.gz
    num_file = len(lst)
    num_file2 = len(lst2)

    if not os.path.isdir(result_dir):
        os.makedirs(result_dir)

    dest_dir = os.path.join(result_dir, file_id + '_transcript_quant')
    if not os.path.isdir(dest_dir):
        os.makedirs(dest_dir)

    if num_file2 >= 2:
        # Paired-end reads: the first two sorted matches are the mates.
        file_path1 = lst2[0]
        file_path2 = lst2[1]
        print(file_path1)
        print(file_path2)
        assert_file_exist(file_path1)
        assert_file_exist(file_path2)
        cmd = '%s quant -i %s -l A -1 %s -2 %s -o %s' % (SALMON, salmon_index, file_path1, file_path2, dest_dir)
        print(cmd)
        os.system(cmd)
    elif num_file == 1:
        # A single fastq file: single-end reads.
        file_path = lst[0]
        print(file_path)
        assert_file_exist(file_path)
        cmd = '%s quant -i %s -l A -r %s -o %s' % (SALMON, salmon_index, file_path, dest_dir)
        os.system(cmd)
    else:
        # Neither a usable pair nor a single file.  (The former
        # "more than two" branch could never be reached and was dropped.)
        print('Warning: skip %s as it has less than two _*.fastq.gz files' % (file_id))
        return

    output_file_name = os.path.join(result_dir, file_id + '_quant.txt')
    quant_file = os.path.join(dest_dir, 'quant.sf')
    if os.path.exists(quant_file):
        # Read the log from the directory salmon actually wrote to, rather
        # than from the global SALMON_MAP_RESULT_DIR, so that any result_dir
        # passed by the caller is handled correctly.
        log_file = os.path.join(dest_dir, 'logs', 'salmon_quant.log')
        if not salmon_fatal_error(log_file):
            # shutil.copy avoids spawning a shell and copes with spaces in paths.
            shutil.copy(quant_file, output_file_name)
        shutil.rmtree(dest_dir)


def get_id(s):
    ''' Return the part of file name s that precedes its fastq suffix.

    The markers '_1.fastq', '_2.fastq', '_3.fastq' and '.fastq' are tried
    in that order; the text before the first occurrence of the first marker
    found at a position greater than zero is returned.  If no marker
    qualifies, the string 'NA' is returned.
    '''
    for marker in ('_1.fastq', '_2.fastq', '_3.fastq', '.fastq'):
        pos = s.find(marker)
        # Position 0 would give an empty id, so it does not count as a match.
        if pos > 0:
            return s[:pos]
    return 'NA'


def get_src_dir_and_file_id(fname):
    ''' Return a dictionary where key is SRR/ERR/DRR id, and value is tuple (path, a number)

    fname is a text file with one fastq path per line.  For every id (as
    computed by get_id from the file-name part of the line) the value holds
    the directory part of the most recent line with that id and the number
    of lines seen for that id.  A line without '/' gets the directory './'.
    Blank lines are ignored.  File names that get_id does not recognise
    are collected under the key 'NA'.
    '''
    result = {}
    with open(fname) as f:
        for line in f:
            line = line.strip()
            if not line:
                # A blank line names no file; skipping it avoids creating
                # a bogus 'NA' entry.
                continue
            index = line.rfind('/')
            path = './' if index == -1 else line[:index]
            file_id = get_id(line[index+1:])
            count = result[file_id][1] + 1 if file_id in result else 1
            result[file_id] = (path, count)
    return result

        
### build salmon index
build_salmon_index(TRANSCRIPTOME, SALMON_INDEX, KMER)

# The first command-line argument is a text file listing fastq paths,
# e.g. the output of: find ../Data/R/Raw -name "*.gz"
fname = sys.argv[1]
src_id = get_src_dir_and_file_id(fname)

# Quantify every sample id found in the list; only the directory part of
# each (path, count) value is needed here.
for file_id, info in src_id.items():
    src_dir = info[0]
    get_TPM(src_dir, file_id, SALMON_INDEX, SALMON_MAP_RESULT_DIR)