# Usage: python create_edges0.py parameter_for_net.txt
#
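# Make it faster by computing all correlations in a single Rscript subprocess.
#
# Quickly create edges using all samples in the expression matrix (the
# EXPRESSION_MATRIX entry of the parameter file, e.g. TPM.txt).  TFs and
# targets are from ../Data/information/target_tf.txt.  Results will be
# written to EDGE_POOL_DIR/edges.txt.simple.correlation.all.conditions.YYYYMMDD_HHMMSS
# target_tf.txt is produced by make_target_tf.py.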
# 
#
# 26 JAN 2017, hui, slcu
# Last modified 5 APR 2017, hui, slcu

import sys, os, operator, itertools
from datetime import datetime
from geneid2name import make_gene_name_AGI_map_dict, get_gene_name
from param4net import make_global_param_dict
from configure import EDGE_POOL_DIR

# Scratch files exchanged with the generated R script.  They all live in one
# temp directory and are removed by the driver at the end of a run.
_TEMP_DIR     = '../Data/temp'
TARGET_FILE   = _TEMP_DIR + '/all_targets.txt'              # one target ID per line (R input)
TF_FILE       = _TEMP_DIR + '/all_tfs.txt'                  # one TF ID per line (R input)
RESULT_FILE   = _TEMP_DIR + '/corr_all.txt'                 # "target<TAB>tf<TAB>r" lines (R output)
R_SCRIPT_FILE = _TEMP_DIR + '/compute_simple_correlation.r'

HISTORY_DIR   = EDGE_POOL_DIR   # edges.txt.* result files are written here


def get_value(s, delimit):
    ''' Return the text between the first and second occurrence of *delimit*
    in *s* (or up to the end of *s* when there is no second one), with
    surrounding whitespace removed.  Raises IndexError if *delimit* is absent. '''
    return s.split(delimit)[1].strip()


def get_gene_list(fname):
    ''' Return the first whitespace-separated token of every non-blank line
    of *fname*, in file order (duplicates are kept). '''
    result = []
    with open(fname) as f:  # closed even if reading fails
        for line in f:
            lst = line.split()  # split() with no argument also strips the line
            if lst:  # skip blank / whitespace-only lines instead of raising IndexError
                result.append(lst[0])
    return result


def make_tf_dict(fname):
    ''' Parse a tab-separated target/TF file into a nested dict.

    Each non-blank line is "target<TAB>tf<TAB>conditions", where conditions is
    a whitespace-separated list.  Returns {target: {tf: [cond, ...]}}; when a
    (target, tf) pair repeats, the conditions of its last line win.  A line
    with no third column (e.g. an empty conditions field, whose trailing tab
    is removed by strip()) gets an empty condition list. '''
    d = {}
    with open(fname) as f:
        for line in f:
            line = line.strip()
            if not line:  # tolerate blank lines
                continue
            lst = line.split('\t')
            target = lst[0]
            tf     = lst[1]
            cond   = lst[2].split() if len(lst) > 2 else []
            d.setdefault(target, {})[tf] = cond
    return d


def get_targets_and_tfs(fname):
    ''' Return (sorted unique targets, sorted unique TFs) taken from the first
    and second tab-separated columns of *fname*.  Blank lines are skipped. '''
    targets = set()
    tfs = set()
    with open(fname) as f:
        for line in f:
            line = line.strip()
            if not line:  # tolerate blank lines
                continue
            lst = line.split('\t')
            targets.add(lst[0])
            tfs.add(lst[1])
    return sorted(targets), sorted(tfs)


def write_lst_to_file(lst, fname):
    ''' Write each string of *lst* to *fname* on its own line, overwriting
    the file. '''
    with open(fname, 'w') as f:  # closed even if a write raises
        for x in lst:
            f.write(x + '\n')


def _r_quote(s):
    ''' Return *s* as a single-quoted R string literal.  Backslashes and single
    quotes are escaped so Windows paths or names with apostrophes survive. '''
    return "'" + s.replace('\\', '\\\\').replace("'", "\\'") + "'"


def make_r_script(fname, result_file, data_file, target_file, tf_file, r_tau=0.75):
    ''' Write to *fname* an R script that correlates targets against TFs.

    The script reads the expression matrix *data_file* (header without the
    gene-id field, see edit_headline) and the one-ID-per-line lists
    *target_file* and *tf_file*.  It keeps genes whose row SD exceeds
    max((min SD + 25th-percentile SD) / 2, 0.05), applies log(x + 1), treats
    values < 0.01 as missing, and writes every target/TF pair with
    r_tau <= |r| <= 0.99 (Pearson, pairwise-complete) to *result_file* as
    "target<TAB>tf<TAB>r" lines.

    r_tau -- minimum absolute correlation to report.
    '''
    head_lines = [
        'OUTPUT_FILE   <- %s' % _r_quote(result_file),
        'DATA_FILE     <- %s' % _r_quote(data_file),
        'TARGET_FILE   <- %s' % _r_quote(target_file),
        'TF_FILE       <- %s' % _r_quote(tf_file),
        # %.15g keeps the caller's threshold; %0.2f used to round e.g. 0.755 to 0.76.
        'tau           <- %.15g' % r_tau,
    ]

    # Rows are selected with drop=FALSE so that a single target or TF stays a
    # one-row matrix that keeps its row name.  Without it, a lone target is
    # dropped to a vector and cor() fails.  Duplicating a lone TF row instead
    # reports each of its edges twice.
    body = '''
       targets       <- read.table(TARGET_FILE, header=FALSE)
       tfs           <- read.table(TF_FILE, header=FALSE)
       X             <- read.table(DATA_FILE, header=TRUE, check.names=FALSE)
       targets       <- as.vector(targets$V1)
       tfs           <- as.vector(tfs$V1)
       all_genes     <- rownames(X)

       X           <- as.matrix(X)
       sd.1        <- apply(X, 1, sd) # sd of each row
       s0          <- apply(X, 1, function(c) sum(c==0)) # number of zeros in each row
       sd.tau      <- (quantile(sd.1,na.rm=TRUE)[1] +  quantile(sd.1,na.rm=TRUE)[2]) / 2.0 # min SD
       good        <- sd.1 > max(sd.tau, 0.05)
       tf_good     <- which( good & (all_genes %in% tfs) == T )
       target_good <- which( good & (all_genes %in% targets) == T )

       # compute correlation coefficient
       X <- log(X + 1)
       X[X<0.01] <- NA
       # drop=FALSE keeps a single selected gene as a one-row matrix with its row name
       c <- cor(t(X[target_good, , drop=FALSE]), t(X[tf_good, , drop=FALSE]), use='pairwise.complete.obs')
       index <- !is.na(c) & abs(c) >= tau &  abs(c) <= 0.99
       row_names <- rownames(c)
       col_names <- colnames(c)
       result <- data.frame(row = row_names[row(c)[index]], col = col_names[col(c)[index]], r = c[index])

       # write results
       write.table(result, OUTPUT_FILE, col.names=F, row.names=F, sep='\\t', quote=F)
    '''

    content = '\n'.join(head_lines) + '\n' + body
    # Strip the Python-side indentation from every line of the generated script.
    lst = [x.strip() for x in content.split('\n')]
    with open(fname, 'w') as f:
        f.write('\n'.join(lst))


def edit_headline(fname):
    ''' Remove gene_id from first line.  For easier R matrix reading.

    Writes a copy of *fname* to fname + '.copy' whose header line has its first
    tab-separated field removed; all other lines are copied unchanged.

    Returns (new_fname, num_rnaseq), where num_rnaseq is the number of header
    fields remaining after the first one is removed.
    Raises ValueError if *fname* is empty. '''
    new_fname = fname + '.copy'
    with open(fname) as src:
        first = src.readline()
        if not first:
            raise ValueError('edit_headline: %s is empty.' % fname)
        # Strip trailing whitespace only.  A leading tab means the first header
        # field is empty, and strip() would wrongly drop the first sample name.
        head_lst = first.rstrip().split('\t')[1:]
        with open(new_fname, 'w') as dst:
            dst.write('\t'.join(head_lst) + '\n')
            for line in src:  # stream the data lines instead of loading the whole matrix
                dst.write(line)
    return new_fname, len(head_lst)


def establish_edges(corr_fname, target_tf_fname, result_fname, agi2name_dict, num_rnaseq, glb_param_dict):
    ''' Turn correlation output into an edge file.

    corr_fname      -- tab-separated "target<TAB>tf<TAB>r" lines.
    target_tf_fname -- target/TF/conditions file.  Only (target, tf) pairs
                       listed there are kept (parsed with make_tf_dict).
    result_fname    -- output path (overwritten).
    agi2name_dict   -- lookup handed to get_gene_name to append gene names.
    num_rnaseq      -- number of expression samples, written as a column.
    glb_param_dict  -- parameters.  EXPRESSION_MATRIX_DESCRIPTION fills the
                       last column; a blank or missing value gives 'all'.

    Each output line has 10 tab-separated fields: target+name, tf+name,
    score (%4.2f), 'all', num_rnaseq, conditions, '.', date (YYYYMMDD),
    score, method_or_tissue. '''
    big_tf_dict = make_tf_dict(target_tf_fname)

    # Loop-invariant fields, computed once rather than per correlation line.
    curr_date = datetime.now().strftime('%Y%m%d')
    description = glb_param_dict.get('EXPRESSION_MATRIX_DESCRIPTION', '')
    method_or_tissue = 'all' if description.strip() == '' else description

    out_lines = []  # collected and joined once; avoids quadratic string +=
    with open(corr_fname) as f:
        for line in f:
            line = line.strip()
            if not line:  # tolerate blank lines
                continue
            lst = line.split('\t')
            target = lst[0]
            tf     = lst[1]
            tf_conds = big_tf_dict.get(target)
            if tf_conds is None or tf not in tf_conds:
                continue
            score_str  = '%4.2f' % (float(lst[2]))
            target_str = target + ' ' + get_gene_name(target, agi2name_dict)
            tf_str     = tf     + ' ' + get_gene_name(tf,     agi2name_dict)
            cond_str   = ' '.join(tf_conds[tf])
            fields = [target_str, tf_str, score_str, 'all', str(num_rnaseq), cond_str, '.', curr_date, score_str, method_or_tissue]
            out_lines.append('\t'.join(fields) + '\n')

    with open(result_fname, 'w') as f:
        f.writelines(out_lines)


def target_tf_file_compare_same(fname1, fname2):
    ''' Return True if both paths are existing regular files with identical
    text content.  Return False otherwise, including when either path is
    missing or is a directory. '''
    for fname in (fname1, fname2):
        if not os.path.isfile(fname):  # isfile: a directory would pass exists() then crash open()
            return False
    with open(fname1) as f1, open(fname2) as f2:  # both closed even on read errors
        return f1.read() == f2.read()
    

## main
import subprocess  # only the driver below spawns a child process (Rscript)

if len(sys.argv) < 2:
    print('Usage: python create_edges0.py parameter_for_net.txt')
    sys.exit(1)

param_file = sys.argv[1] # a single parameter file
glb_param_dict = make_global_param_dict(param_file)
agi2name_dict = make_gene_name_AGI_map_dict(glb_param_dict['GENE_ID_AND_GENE_NAME'])

target_tf_fname = '../Data/information/target_tf.txt'
if not os.path.exists(target_tf_fname):
    print('create_edges0: file %s does not exist.  Produce this file use make_target_tf.py.' % (target_tf_fname))
    sys.exit(1)  # non-zero so calling shell scripts can detect the failure

# Scratch files to delete when the run ends, whether it succeeds or not.
temp_files = [TARGET_FILE, TF_FILE, R_SCRIPT_FILE, RESULT_FILE]
try:
    all_targets, all_tfs = get_targets_and_tfs(target_tf_fname)
    write_lst_to_file(all_targets, TARGET_FILE)
    write_lst_to_file(all_tfs, TF_FILE)

    data_file, num_rnaseq = edit_headline(glb_param_dict['EXPRESSION_MATRIX'])
    temp_files.append(data_file)

    make_r_script(R_SCRIPT_FILE, RESULT_FILE, data_file, TARGET_FILE, TF_FILE, 0.60)

    # A result left over from an earlier run must not be mistaken for this run's output.
    if os.path.isfile(RESULT_FILE):
        os.remove(RESULT_FILE)
    # Pass an argument list (no shell) so paths with spaces/metacharacters are safe.
    status = subprocess.call(['Rscript', R_SCRIPT_FILE])
    if status != 0 or not os.path.isfile(RESULT_FILE):
        print('create_edges0: Rscript %s failed (exit status %d).' % (R_SCRIPT_FILE, status))
        sys.exit(1)

    if not os.path.isdir(HISTORY_DIR):
        os.makedirs(HISTORY_DIR)

    curr_time = datetime.now().strftime('%Y%m%d_%H%M%S')
    result_fname = os.path.join(HISTORY_DIR, 'edges.txt.simple.correlation.all.conditions.' + curr_time)
    establish_edges(RESULT_FILE, target_tf_fname, result_fname, agi2name_dict, num_rnaseq, glb_param_dict)
finally:
    # Portable replacement for `rm -f ...`; also runs when the steps above fail.
    for fname in temp_files:
        if os.path.isfile(fname):
            os.remove(fname)

print('Done. Check %s.' % (result_fname))