# Usage: python create_edges0.py parameter_for_net.txt
#
# Purpose: create tissue-specific edges (one edge file per tissue).
#
# Quickly create edges using all samples in TPM.txt that share the same
# tissue. TFs and targets are taken from target_tf.txt. Results are
# written to
# ../Data/history/edges/many_targets/edges.txt.simple.correlation.<tissue>.<YYYYmmdd_HHMMSS>
# target_tf.txt is produced by make_target_tf.py.
# 
#
# 26 JAN 2017, hui, slcu
# Last modified 13 June 2017, hui, slcu
# Last modified  8 Aug  2019, hui, zjnu

import sys, os, operator, itertools, glob
from datetime import datetime
from configure import UPDATE_NETWORK_LOG_FILE
from geneid2name import make_gene_name_AGI_map_dict, get_gene_name
from param4net import make_global_param_dict

# Temporary lists written by this script (the distinct targets / TFs found in
# target_tf.txt, one per line); they are deleted at the end of the run.
TARGET_FILE = '../Data/temp/all_targets.txt'
TF_FILE = '../Data/temp/all_tfs.txt'

# R script executed with `Rscript` to compute the per-tissue correlations.
R_SCRIPT_FILE = 'correlation_per_tissue.R'

# Tab-separated table with a header line; column 1 is the run id and column 5
# the tissue. Make sure this file is the same as TISSUE.FILE in R_SCRIPT_FILE.
TISSUE_INFO_FILE = '../Data/information/experiment.and.tissue.txt'

# Output directory for the edges.txt.* result files.
HISTORY_DIR = '../Data/history/edges/many_targets'


def get_value(s, delimit):
    """Return the second `delimit`-separated field of `s`, whitespace-stripped.

    Raises IndexError when `delimit` does not occur in `s`.
    """
    return s.split(delimit)[1].strip()


def get_gene_list(fname):
    """Return the first whitespace-separated token of every non-blank line in `fname`.

    Blank lines (including a trailing empty line) are skipped; previously they
    raised IndexError.
    """
    result = []
    with open(fname) as f:
        for line in f:
            fields = line.split()  # split() also drops leading/trailing whitespace
            if fields:
                result.append(fields[0])
    return result


def make_tf_dict(fname):
    """Build a nested mapping ``target -> {tf: cond_list}`` from a target/TF file.

    Each non-blank line is tab-separated: target, tf, and a third column that is
    split on whitespace into `cond_list` (presumably condition/sample ids --
    confirm against make_target_tf.py). A missing or empty third column gives
    an empty list. A later row for the same (target, tf) overwrites an earlier one.
    """
    d = {}
    with open(fname) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines, e.g. a trailing empty line
            lst = line.split('\t')
            target = lst[0]
            tf = lst[1]
            # strip() removes a trailing tab, so an empty condition column leaves
            # only two fields; treat that as "no conditions" instead of crashing.
            cond = lst[2].split() if len(lst) > 2 else []
            d.setdefault(target, {})[tf] = cond
    return d


def get_targets_and_tfs(fname):
    """Return ``(sorted distinct targets, sorted distinct tfs)`` from a target/TF file.

    Each non-blank line is tab-separated with the target in column 1 and the TF
    in column 2; further columns are ignored. Blank lines are skipped (they
    previously raised IndexError).
    """
    targets = set()
    tfs = set()
    with open(fname) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            lst = line.split('\t')
            targets.add(lst[0])
            tfs.add(lst[1])
    return sorted(targets), sorted(tfs)


def write_lst_to_file(lst, fname):
    """Overwrite `fname` with the strings in `lst`, one per line."""
    with open(fname, 'w') as out:
        out.writelines(item + '\n' for item in lst)


def establish_edges(corr_fname, target_tf_fname, result_fname, agi2name_dict, tissue_dict, loglikhood_dict):
    """Write the edges of `corr_fname` whose (target, tf) pair appears in `target_tf_fname`.

    corr_fname      -- tab-separated lines: target, tf, score, tissue, number of
                       RNA-seq samples (presumably produced by R_SCRIPT_FILE --
                       confirm the column order there).
    target_tf_fname -- target/TF file parsed with make_tf_dict. Only pairs listed
                       there are kept, and their condition list is copied out.
    result_fname    -- output file (overwritten). One tab-separated line per kept
                       edge: target+name, tf+name, score ('%4.2f'), 'all',
                       #samples, conditions, loglik placeholder, date (YYYYmmdd),
                       score, tissue.
    agi2name_dict   -- passed to get_gene_name to append gene names to the ids.
    tissue_dict     -- currently unused; kept for interface compatibility.
    loglikhood_dict -- tissue -> value written in the loglik column. '-9999.0' is
                       used for tissues not in the dict.
    """
    big_tf_dict = make_tf_dict(target_tf_fname)
    # One date stamp for the whole file (previously recomputed for every line).
    curr_date = datetime.now().strftime('%Y%m%d')

    rows = []
    with open(corr_fname) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines, e.g. a trailing empty line
            lst = line.split('\t')
            target = lst[0]
            tf = lst[1]
            tf_conds = big_tf_dict.get(target, {})
            if tf not in tf_conds:
                continue  # only keep pairs listed in target_tf_fname
            score_str = '%4.2f' % (float(lst[2]))
            tissue = lst[3]
            num_rnaseq_id = lst[4]
            # The default used to be bound to a misspelled name (`loglike`).
            # An unknown tissue then raised NameError, or silently reused the
            # previous line's value.
            loglik = loglikhood_dict.get(tissue, '-9999.0')
            target_str = target + ' ' + get_gene_name(target, agi2name_dict)
            tf_str = tf + ' ' + get_gene_name(tf, agi2name_dict)
            cond_str = ' '.join(tf_conds[tf])
            rows.append('\t'.join([target_str, tf_str, score_str, 'all', num_rnaseq_id,
                                   cond_str, loglik, curr_date, score_str, tissue]))

    with open(result_fname, 'w') as f:
        f.writelines(row + '\n' for row in rows)


def get_tissue_from_filename(s, d):
    """Return ``(key, d[key])`` for the key of `d` that occurs as a substring of `s`.

    When several keys occur in `s` (e.g. 'seed' and 'seedling', or 'stem' and
    'meristem'), the longest one wins. This makes the result independent of the
    dict's iteration order. Among equally long matches, the first in `d`'s order
    is returned. Returns ``('unknown', '-9999.0')`` when no key occurs in `s`.
    """
    matches = [k for k in d if k in s]
    if not matches:
        return 'unknown', '-9999.0'
    best = max(matches, key=len)  # max() keeps the first of equally long keys
    return best, d[best]


def make_tissue_dict(fname):
    """Parse the run/tissue table into one mixed dictionary.

    The first line of `fname` is a header and is skipped, as are blank lines.
    Every other line is tab-separated; column 1 is the run id and column 5 the
    tissue label. The returned dict holds two kinds of entries:
      * run id -> full tissue label (e.g. 'flower.anther');
      * broad tissue (text before the first '.') -> list of run ids, in file
        order (e.g. 'flower' -> [...]).
    """
    with open(fname) as handle:
        rows = handle.readlines()

    mapping = {}
    for raw in rows[1:]:
        row = raw.strip()
        if not row:
            continue
        fields = row.split('\t')
        run_id = fields[0]
        tissue = fields[4]
        mapping[run_id] = tissue
        broad = tissue.split('.')[0]  # drop sub-categories: 'flower.anther' -> 'flower'
        if broad in mapping:
            mapping[broad].append(run_id)
        else:
            mapping[broad] = [run_id]
    return mapping


def target_tf_file_compare_same(fname1, fname2):
    """Return True iff both files exist and their text contents are identical."""
    if not (os.path.exists(fname1) and os.path.exists(fname2)):
        return False
    contents = []
    for name in (fname1, fname2):
        with open(name) as handle:
            contents.append(handle.read())
    return contents[0] == contents[1]


def _log_error(msg):
    """Append `msg`, prefixed with a timestamp, to UPDATE_NETWORK_LOG_FILE."""
    # The script used to call `write_log_file`, which is neither imported nor
    # defined here, so its error path crashed with NameError instead of logging.
    with open(UPDATE_NETWORK_LOG_FILE, 'a') as log_f:
        log_f.write('[%s] %s\n' % (datetime.now().strftime('%Y-%m-%d %H:%M:%S'), msg))


########## main ##################################################
if len(sys.argv) < 2:
    print('Usage: python %s parameter_for_net.txt' % (os.path.basename(sys.argv[0])))
    sys.exit(1)

param_file = sys.argv[1]  # a single parameter file
glb_param_dict = make_global_param_dict(param_file)
agi2name_dict = make_gene_name_AGI_map_dict(glb_param_dict['GENE_ID_AND_GENE_NAME'])

target_tf_fname = '../Data/information/target_tf.txt'
if not os.path.exists(target_tf_fname):
    _log_error('[create_edges0B.py] Critical file %s does not exists.' % (target_tf_fname))
    sys.exit()

# Dump the distinct targets and TFs; these temp files are removed at the end.
all_targets, all_tfs = get_targets_and_tfs(target_tf_fname)
write_lst_to_file(all_targets, TARGET_FILE)
write_lst_to_file(all_tfs, TF_FILE)

# The R script is expected to write the per-tissue files matched by
# TEMP_EDGE_PATTERN below (assumption based on that glob; confirm in the R code).
if not os.path.exists(R_SCRIPT_FILE):
    _log_error('[create_edges0B.py] R script %s does not exist.' % (R_SCRIPT_FILE))
    sys.exit()
status = os.system('Rscript %s' % (R_SCRIPT_FILE))
if status != 0:
    # Carry on as before (any output that exists is still processed), but leave a trace.
    _log_error('[create_edges0B.py] Rscript %s returned status %s.' % (R_SCRIPT_FILE, status))

# Recognised tissues: filenames are matched against these keys by
# get_tissue_from_filename. Each value is written verbatim in the loglik column.
loglikhood_dict = {
    'seedling': '-999.0',
    'meristem': '-998.0',
    'root': '-997.0',
    'leaf': '-996.0',
    'flower': '-995.0',
    'shoot': '-994.0',
    'seed': '-993.0',
    'stem': '-992.0',
    'aerial': '-990.0',
}

if not os.path.isdir(HISTORY_DIR):
    os.makedirs(HISTORY_DIR)

TEMP_EDGE_PATTERN = '../Data/temp/edges.txt.simple.correlation.tissue.*.txt'
file_lst = sorted(glob.glob(TEMP_EDGE_PATTERN))
curr_time = datetime.now().strftime('%Y%m%d_%H%M%S')

if not os.path.exists(TISSUE_INFO_FILE):
    _log_error('[create_edges0B.py] Tissue information file %s does not exist.' % (TISSUE_INFO_FILE))
    sys.exit()
tissue_dict = make_tissue_dict(TISSUE_INFO_FILE)  # run id -> tissue, and broad tissue -> list of run ids

for fname in file_lst:
    tissue = get_tissue_from_filename(fname, loglikhood_dict)[0]
    if tissue == 'unknown':
        continue
    print(fname)
    result_fname = os.path.join(HISTORY_DIR, 'edges.txt.simple.correlation.%s.%s' % (tissue, curr_time))
    establish_edges(fname, target_tf_fname, result_fname, agi2name_dict, tissue_dict, loglikhood_dict)

# Remove temporary files directly (equivalent to `rm -f`, without spawning a shell).
for tmp_fname in [TARGET_FILE, TF_FILE] + glob.glob(TEMP_EDGE_PATTERN):
    if os.path.exists(tmp_fname):
        os.remove(tmp_fname)