# Usage: python download_and_map.py
#        python download_and_map.py run_ids.txt
#
#        Edit DAILY_MAP_NUMBER and MIN_FASTQ_FILE_SIZE in configure.py
#
# This program checks RNA_SEQ_INFO_FILE for *public* RNA-seq data that have not yet been downloaded, makes a list of them, then downloads and maps them using Salmon.
# It is very important to prepare RNA_SEQ_INFO_FILE first (see parse_ena_xml.py).  In fact, only the first column of that file is required: a list of RNA-seq run IDs.
#
# Purpose: automate downloading RNA-seq files from ENA or SRA.  This program checks MAPPED_RDATA_DIR and RAW_RDATA_DIR to ensure that already downloaded or mapped data are not downloaded or mapped again.
#
# Note: use download_data2() to download from SRA, and download_and_map_data() to download from ENA (closer to Cambridge, so faster).  This script depends on get_TPM_by_salmon.py.
#
# 23 DEC 2016, hui, slcu. Updated: 9 Feb 2017
# Last modified 10 APR 2017, hui, slcu
# Last reviewed 31 July 2018
# Last revised 10 Feb 2021
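#
# Example run_ids.txt (illustrative): '#' lines are ignored and only the first
# column of each line is used:
#   # batch prepared 2017-02-09
#   SRR1234567
#   ERR2345678    optional annotation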

import os, sys, glob, json
import fnmatch
import time
import re
from datetime import datetime

##########################################################################################
from configure import DAILY_MAP_NUMBER, MIN_FASTQ_FILE_SIZE, RNA_SEQ_INFO_FILE, DOWNLOADED_SRA_ID_LOG_FILE, IGNORED_SRA_ID_LOG_FILE, UPDATE_NETWORK_LOG_FILE, MAPPED_RDATA_DIR, RAW_RDATA_DIR, SALMON_MAP_RESULT_DIR

FASTQ_DUMP_PATH = '/home/hui/software/sratoolkit/sratoolkit.2.8.0-ubuntu64/bin/fastq-dump'

##########################################################################################

def glob_files(directory, pattern):
    ''' return all file names (without paths) given directory and pattern '''
    result = []
    for root, dirnames, filenames in os.walk(directory):
        for filename in fnmatch.filter(filenames, pattern):
            result.append(filename)
    return result


def glob_files_include_path(directory, pattern):
    ''' return all file names (with paths) given directory and pattern '''
    result = []

    for root, dirnames, filenames in os.walk(directory):
        for filename in fnmatch.filter(filenames, pattern):
            result.append(os.path.join(root, filename))
    # also check the download*.txt files, where downloaded files are recorded.
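    # Illustrative download*.txt line (an assumption based on the parsing below):
    # a downloaded file name such as SRR1234567_1.fastq.gz, possibly with a
    # relative path; only the base name is matched against the pattern.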
    for fname in glob.glob(os.path.join(directory, 'download*.txt')):
        f = open(fname)
        lines = f.readlines()
        f.close()
        for line in lines:
            line = line.strip()
            if fnmatch.fnmatch(os.path.basename(line), pattern):
                result.append(os.path.join(directory, line))
    return result


def get_list(fname):
    ''' Read a file into a list of unique run IDs (SRR/ERR/DRR), taking the first field of each line. '''
    if not os.path.exists(fname):
        return []
    
    result = []
    f = open(fname)
    d = {}
    for line in f:
        line = line.strip()
        if line != '':
            lst = line.split()
            s = lst[0].strip() # SRR, ERR, or DRR id
            if (not s in d) and ('SRR' in s or 'ERR' in s or 'DRR' in s):
                d[s] = 1
                result.append(s)
    f.close()
    return result # only return unique elements


def make_download_list(mapped_dir, rna_data_info_dict):
    ''' 
    Return the next batch of run IDs to download.  These samples must not have been downloaded or mapped yet.

    mapped_dir - directory containing all mapped samples
    rna_data_info_dict - a dictionary describing all RNA-seq samples from ENA (see read_ena_data_info_json)
    '''

    result = []
    mapped_files = glob_files(mapped_dir, '*_quant.txt')
    mapped_run_ids = get_list(DOWNLOADED_SRA_ID_LOG_FILE)
    small_ids = get_list(IGNORED_SRA_ID_LOG_FILE) # these files are too small
    for run_id in sorted(rna_data_info_dict.keys(), reverse=True): # SRR first, then ERR, then DRR
        info = rna_data_info_dict[run_id]
        include_me = info['library_strategy'].lower() == 'rna-seq' and info['library_source'].lower() == 'transcriptomic'
        if (run_id + '_quant.txt') not in mapped_files and run_id not in result \
           and run_id not in small_ids and run_id not in mapped_run_ids and include_me: # not mapped yet and is RNA-seq
            result.append(run_id)
    return result


def num_of_digits(s):
    count = 0
    for c in s:
        if c.isdigit():
            count += 1
    return count


def get_file_url(fname):
    ''' Extract fastq.gz FTP URLs from a wget --spider log file. '''
    f = open(fname)
    url_list = []
    for line in f:
        line = line.strip()
        if 'ftp://' in line and '.fastq.gz' in line:
            lst  = line.split()
            address = lst[-1].strip()
            if '.fastq.gz' in address and address.startswith('ftp://') and not address in url_list:
                url_list.append(address)
    f.close()
    return url_list


def get_file_size(fname):
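    ''' Parse the remote file size from a wget --spider FTP log: the size is the last
        field of the "==> SIZE ..." line, e.g. "==> SIZE SRR1234567_1.fastq.gz ... 845623901"
        (illustrative). '''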
    sz = 0
    f = open(fname)
    for line in f:
        line = line.strip()
        if line.startswith('==> SIZE'):
            lst = line.split()
            sz =  int(lst[-1])
    f.close()
    return sz


def get_remote_file_size(link):
    cmd = 'rm -f ../Data/temp/wget_temp_file1.txt'
    os.system(cmd)
    
    cmd = 'wget --spider %s 2> ../Data/temp/wget_temp_file1.txt' % (link)
    os.system(cmd)
    return get_file_size('../Data/temp/wget_temp_file1.txt')
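

# A minimal alternative sketch (an assumption, not part of the original pipeline;
# nothing calls it): the same SIZE query can be issued with Python's standard
# ftplib instead of shelling out to wget.
def get_remote_file_size_ftplib(link):
    ''' Return the size in bytes of ftp://host/path, or 0 if unavailable. '''
    from ftplib import FTP
    from urllib.parse import urlparse # Python 3; on Python 2: from urlparse import urlparse
    parsed = urlparse(link)
    ftp = FTP(parsed.netloc)   # e.g. ftp.sra.ebi.ac.uk
    ftp.login()                # anonymous login, which ENA's FTP server accepts
    ftp.voidcmd('TYPE I')      # the FTP SIZE command is only reliable in binary mode
    size = ftp.size(parsed.path)
    ftp.quit()
    return size if size is not None else 0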


def get_sample_id(fname):
    ''' Extract the run ID from a fastq file name, e.g. SRR1234567_1.fastq.gz -> SRR1234567. '''
    index = fname.find('.fastq.gz')
    if index < 0:
        return ''

    s = fname[0:index]
    lst = s.split('_')
    return lst[0]


def download_and_map_data(lst, daily_map_num, dest):
    ''' Download data from ENA; fast (but downloads can be interrupted) '''
    downloaded_files = [] # a list of paths to downloaded files, small files (size less than MIN_FASTQ_FILE_SIZE) won't be included in the list
    map_list = []

    if len(lst) < daily_map_num or daily_map_num < 1: # nothing to do if the candidate list cannot fill today's quota
        return downloaded_files, map_list
    
    count = 0
    for line in lst: # lst - a list of run IDs
        run_id = line
        dir1 = line[0:6]
        dir2 = ''
        n = num_of_digits(line)
        address = ''
        if n == 6: # follow ENA's data path convention
            address = 'ftp://ftp.sra.ebi.ac.uk/vol1/fastq/%s/%s/' % (dir1, run_id)
        elif n == 7:
            dir2 = '00' + run_id[-1]
            address = 'ftp://ftp.sra.ebi.ac.uk/vol1/fastq/%s/%s/%s/' % (dir1, dir2, run_id)
        elif n == 8:
            dir2 = '0' + run_id[-2:]
            address = 'ftp://ftp.sra.ebi.ac.uk/vol1/fastq/%s/%s/%s/' % (dir1, dir2, run_id)
        elif n == 9:
            dir2 = run_id[-3:]
            address = 'ftp://ftp.sra.ebi.ac.uk/vol1/fastq/%s/%s/%s/' % (dir1, dir2, run_id)

        if os.path.exists('../Data/temp/wget_temp_file0.txt'):
            cmd = 'rm -f ../Data/temp/wget_temp_file0.txt'
            os.system(cmd)

        cmd = 'wget --spider -T 20 %s 2> ../Data/temp/wget_temp_file0.txt' % (os.path.join(address, '*.gz'))
        os.system(cmd)

        url_lst = get_file_url('../Data/temp/wget_temp_file0.txt')
        if url_lst == []:
            write_download_log_file(IGNORED_SRA_ID_LOG_FILE, run_id+'\n')            

        time.sleep(1)

        curr_lst = []
        for link in url_lst:
            sz = get_remote_file_size(link)
            if  sz >= MIN_FASTQ_FILE_SIZE:  # remote file must be big enough
                cmd = 'wget %s -P %s' % (link, dest)
                os.system(cmd)
                file_path = os.path.join(dest, os.path.basename(link))
                curr_lst.append(file_path)
                downloaded_files.append(file_path)
            else:
                print('[download_and_map.py] IGNORE [%d MB] %s' % (int(sz/1000000.0), link))
                file_name = os.path.basename(link)
                sample_id = get_sample_id(file_name)
                write_download_log_file(IGNORED_SRA_ID_LOG_FILE, sample_id+'\n')


        print(curr_lst)
        if curr_lst != []:
            salmon_map(curr_lst)
            map_list.append(run_id)            
            count += 1

        # Remove raw files (as they occupy lots of space)
        for f in downloaded_files:
            if os.path.exists(f):
                print('[download_and_map.py] To save space, I am removing %s.' % (f))
                os.remove(f)
                time.sleep(1)

        if count >= daily_map_num:
            return downloaded_files, map_list
                
        time.sleep(3)

    return  downloaded_files, map_list


def download_data2(lst, dest):
    ''' Download data from SRA, slow '''
    if not os.path.exists(FASTQ_DUMP_PATH):
        print('%s does not exist.' % (FASTQ_DUMP_PATH))
        sys.exit()
        
    downloaded_files = [] # a list of paths to downloaded files, small files (size less than MIN_FASTQ_FILE_SIZE) won't be downloaded
    for line in lst:
        run_id = line.strip()
        cmd = '%s -N 1000000 --split-files --skip-technical %s' % (FASTQ_DUMP_PATH, run_id) # note: -N is fastq-dump's minimum spot id, so reads before spot 1,000,000 are skipped
        print('\n' + cmd)
        os.system(cmd)
        if glob.glob('%s*fastq*' % (run_id)) != []: # files are successfully downloaded
            cmd = 'mv %s*fastq* %s' % (run_id, dest)
            print(cmd)
            os.system(cmd)
            fastq_file_lst = glob.glob( os.path.join(dest, '%s*fastq*' % (run_id)) )
            if len(fastq_file_lst) == 1: # rename file
                cmd = 'mv %s %s' % (fastq_file_lst[0], os.path.join(dest, run_id+'.fastq'))
                os.system(cmd)

            cmd = 'gzip %s' % (  os.path.join(dest, run_id + '*.fastq') )
            print(cmd)
            os.system(cmd)
            for fname in glob.glob( os.path.join(dest, '%s*gz' % (run_id)) ) :
                downloaded_files.append(fname)
        else:
            write_download_log_file(IGNORED_SRA_ID_LOG_FILE, run_id+'\n')

    return  downloaded_files


def salmon_map(lst):
    gz_file = '../Data/temp/gz_files.txt' # parameter file for get_TPM_by_salmon.py; gz means gzip, as fastq files are usually zipped
    if os.path.exists(gz_file):
        cmd = 'rm -f %s' % (gz_file) # remove old parameter file (if any)
        os.system(cmd)
    
    f = open(gz_file, 'w')
    f.write('\n'.join(lst)) # lst contains paths to fastq files
    f.close()

    print('Start mapping %s' % ('  '.join(lst)))
    cmd = 'python get_TPM_by_salmon.py %s > /dev/null 2>&1' % (gz_file) # mapped files will be saved in SALMON_MAP_RESULT_DIR
    os.system(cmd)
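
# Example call (illustrative paths):
#   salmon_map(['../Data/R/SRR1234567_1.fastq.gz', '../Data/R/SRR1234567_2.fastq.gz'])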


def write_download_log_file(fname, s):
    f = open(fname, 'a') # mode 'a' creates the file if it does not exist
    f.write(s)
    f.close()


def write_network_log_file(s, fname):
    f = open(fname, 'a')
    curr_time = datetime.now().strftime('%Y-%m-%d %H:%M')
    s = '[' + curr_time + ']: ' + s
    if not '\n' in s:
        s += '\n'
    f.write(s)
    f.close()


def last_session_finished(fname):
    ''' Return True iff the most recent START/DONE status line in fname is DONE. '''
    if not os.path.exists(fname):
        return True
    f = open(fname)
    lines = f.readlines()
    f.close()
    # Check last status
    last_status = ''
    for line in lines:
        line = line.strip()
        if line.upper().startswith('START'):
            last_status = 'START'
        if line.upper().startswith('DONE'):
            last_status = 'DONE'
    return last_status == 'DONE'
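
# Illustrative DOWNLOADED_SRA_ID_LOG_FILE content scanned by last_session_finished():
#   START at 2017-02-09_1030
#   SRR1234567
#   DONE at 2017-02-09_1845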


def read_ena_data_info_json(fname):
    with open(fname) as json_data:
        json_dict = json.load(json_data)
    return json_dict
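
# The JSON file (built by parse_ena_xml.py) maps run IDs to metadata dictionaries;
# make_download_list() relies on the 'library_strategy' and 'library_source'
# fields.  Illustrative entry:
#   {"SRR1234567": {"library_strategy": "RNA-Seq", "library_source": "TRANSCRIPTOMIC"}}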


def read_run_ids_from_file(fname):
    ''' Read run IDs (the first field of each non-comment line) from a user-provided file. '''
    f = open(fname)
    result = []
    for line in f:
        line = line.strip()
        if not line.startswith('#') and 'RR' in line:
            result.append(line.split()[0])
    f.close()
    return list(set(result))


## main

# For filtering RNA-seq data
if not os.path.exists(RNA_SEQ_INFO_FILE):
    print('[download_and_map.py] Must provide %s. See parse_ena_xml.py about how to make it.' % (RNA_SEQ_INFO_FILE))
    sys.exit()

# If there is not enough disk space for storing the downloaded sequencing data, stop
available_G = 4 * os.statvfs('/home').f_bavail / (1024*1024) # available space in GB; assumes 4 KB blocks, Linux/UNIX only
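# Note: a block-size-agnostic alternative (not used here) would be:
#   available_G = os.statvfs('/home').f_frsize * os.statvfs('/home').f_bavail / (1024.0**3)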
if available_G < 2 * DAILY_MAP_NUMBER:
    print('[download_and_map.py] home directory does not have enough space (only %d G available) ' % (available_G))
    write_network_log_file('[download_and_map.py] home directory does not have enough space (only %d G available).' % (available_G), UPDATE_NETWORK_LOG_FILE)
    sys.exit()

if not last_session_finished(DOWNLOADED_SRA_ID_LOG_FILE): # last session not finished
    s = '[download_and_map.py] last downloading and mapping session not finished yet. Check file %s for details.' % (DOWNLOADED_SRA_ID_LOG_FILE)
    write_network_log_file(s, UPDATE_NETWORK_LOG_FILE)
    sys.exit()

rna_data_info_dict = read_ena_data_info_json(RNA_SEQ_INFO_FILE) # rna_data_info_dict contains only RNA-seq IDs.

# Generate DRR/ERR/SRR ids to download
if len(sys.argv) > 1:  # user has provided a list of IDs in a file
    download_list = read_run_ids_from_file(sys.argv[1])
    DAILY_MAP_NUMBER = len(download_list)
else:
    print('[download_and_map.py] Prepare download list ...')
    download_list = make_download_list(MAPPED_RDATA_DIR, rna_data_info_dict)
    print('[download_and_map.py] There are %d candidate run IDs; up to %d of them will be downloaded and mapped.' % (len(download_list), DAILY_MAP_NUMBER))


# Record the session start in the download log
curr_time = datetime.now().strftime('%Y-%m-%d_%H%M')
write_download_log_file(DOWNLOADED_SRA_ID_LOG_FILE, 'START at %s\n' % (curr_time))

# Download these RNA-seq IDs and map them using salmon 
print('[download_and_map.py] Start downloading and mapping ...')
downloaded_file_paths, map_list = download_and_map_data(download_list, DAILY_MAP_NUMBER, RAW_RDATA_DIR) # or we can use the function download_data2 to download from SRA (in US).

# Move all files to MAPPED_RDATA_DIR
curr_time = datetime.now().strftime('%Y-%m-%d_%H%M') # timestamp for the DONE entry written below
new_dir_name = MAPPED_RDATA_DIR
if not os.path.isdir(new_dir_name):
    os.makedirs(new_dir_name)

# after mapping is finished, move all resulting files to new_dir_name (MAPPED_RDATA_DIR)
if glob.glob('%s/*_quant.txt' % (SALMON_MAP_RESULT_DIR.rstrip('/'))) != []:
    cmd = 'mv %s/*_quant.txt %s' % (SALMON_MAP_RESULT_DIR.rstrip('/'), new_dir_name)
    os.system(cmd)
    print('[download_and_map.py] Done. Check directory %s.' % (os.path.abspath(new_dir_name)))
else:
    print('[download_and_map.py] No quant files to move.')


write_download_log_file(DOWNLOADED_SRA_ID_LOG_FILE, '%s\n' % ('\n'.join(map_list)))
write_download_log_file(DOWNLOADED_SRA_ID_LOG_FILE, 'DONE at %s\n' % (curr_time))