# Usage: python create_edges4.py parameter_for_net.txt
# Purpose:
# This script will generate MAX_PROCESS work_on_AGI#.R scripts, each for a target gene. So treat each target separately.
# The results will be saved as edges.txt.AT3G12580.20170308, where AT3G12580 is target gene id, and 20170308 is date.
# The edges.txt files will be merged together later.
# Hopefully it will be faster.
# Make it faster by handling each target separately.
# Make the memory footprint smaller by splitting TPM.txt into small json files, and converting binding.txt to target_tf.txt first.
# So we don't need to load the big matrices, TPM.txt and binding.txt.
#
#  7 Mar 2017, slcu, hui

import sys, os, operator, itertools
from datetime import datetime
import time
import json
import subprocess
from geneid2name import make_gene_name_AGI_map_dict
from param4net import make_global_param_dict

EDGE_FILE = '../Data/history/edges/edges.txt'  # existing edges; used to prioritise targets that have few edges so far
EDGE_DIR = '../Data/history/edges/one_target'  # a directory storing all edge files, one for each target gene
GENE_ID_TO_GENE_NAME = '../Data/information/AGI-to-gene-names_v2.txt' # maps AGI gene IDs to gene names
TIME_INTERVAL = 10 # seconds to wait between polls of running R scripts before launching another one
MAX_PROCESS = 5 # maximum number of R scripts allowed to run at the same time  CHANGE
K = 2              # CHANGE number of mixture components to use (regmixEM's k)

####################################
DATA_SYMBOL         = '@'  # marker symbol used in the parameter file format
####################################

def get_gene_list(fname):
    '''Return the first whitespace-delimited token (the gene ID) of each
    non-empty line in fname, in file order.

    fname - text file whose lines start with a gene ID.
    '''
    result = []
    with open(fname) as f:  # context manager guarantees the file is closed
        for line in f:
            lst = line.split()
            if lst:  # skip blank lines instead of raising IndexError on lst[0]
                result.append(lst[0])
    return result

def get_ordered_gene_list(fname):
    '''Return the genes of fname ordered by ascending number of existing
    edges in EDGE_FILE, so targets with fewer known edges get processed
    first.  Genes with no edge at all lead the list.

    fname - gene list file readable by get_gene_list.

    Note: the previous version silently dropped genes that had no edge
    yet, contradicting the stated intent of treating them first.
    '''
    gene_list = get_gene_list(fname)

    # count existing edges per target gene
    edge_count = {}
    with open(EDGE_FILE) as f:
        for line in f:
            lst = line.strip().split('\t')
            if len(lst) < 2:
                continue  # skip blank or malformed lines
            target = lst[0].split()[0]
            edge_count[target] = edge_count.get(target, 0) + 1

    gene_set = set(gene_list)  # O(1) membership tests
    # genes never seen in EDGE_FILE have zero edges -- highest priority
    result_gene_lst = [g for g in gene_list if g not in edge_count]
    for t in sorted(edge_count.items(), key=operator.itemgetter(1)):  # fewer edges first
        if t[0] in gene_set:
            result_gene_lst.append(t[0])
    return result_gene_lst


def make_tf_dict(fname):
    '''Build a binding dictionary from a tab-delimited target_tf file.

    fname - file with lines "target<TAB>tf<TAB>conditions".
    Returns {target: {tf: conditions, ...}, ...}.  A later line with the
    same (target, tf) pair overwrites the earlier conditions string,
    matching the original behavior.
    '''
    d = {}
    with open(fname) as f:  # context manager guarantees the file is closed
        for line in f:
            lst = line.strip().split('\t')
            if len(lst) < 3:
                continue  # skip blank or malformed lines instead of raising IndexError
            target, tf, cond = lst[0], lst[1], lst[2]
            d.setdefault(target, {})[tf] = cond
    return d

def make_r_script(fname, target, tf_dict, abs_jsonTPM_dir, num_component):
    '''Write an R script to fname that, for one target gene, tests the
    expression correlation between the target and each candidate TF,
    first over all samples and then per mixture component (regmixEM).

    fname           - path of the R script to create
    target          - AGI id of the target gene (becomes id2 in R)
    tf_dict         - {tf_id: condition_string} for this target
    abs_jsonTPM_dir - absolute path to the directory of per-gene TPM json files
    num_component   - number of mixture components for regmixEM (K)

    Returns fname.  The generated script appends edge lines to
    EDGE_DIR/edges.txt.<target>.<yyyymmdd>.k<num_component>.
    '''
    # Build the R header: variables specific to this target.
    head =  'k.lst <- c(%d)\n' % (num_component) 
    head += 'target <- \'%s\'\n' % (target)
    head += 'id2 <- target\n'
    # Assemble parallel R vectors: tfs[i] pairs with conditions[i].
    tfs = ''
    conds = ''
    for k in tf_dict.keys():
        tfs += '\'%s\',' % (k)
        conds += '\'%s\',' % (tf_dict[k])
    head += 'tfs <- c(' + tfs.rstrip(',') + ')\n'  # rstrip drops the trailing comma
    head += 'conditions <- c(' + conds.rstrip(',') + ')\n'
    head += 'jsonTPM.dir <- \'%s\'\n' % (abs_jsonTPM_dir)
    head += 'AGINAME_FILE   <- \'%s\'\n' % (os.path.abspath(GENE_ID_TO_GENE_NAME))
    head += 'output.file <- paste(\'%s/edges.txt\', id2, gsub(\'-\',\'\',Sys.Date()), \'k%d\', sep=\'.\')\n' % (EDGE_DIR, num_component)
    # Fixed R body shared by all targets; leading tabs are layout only and
    # are stripped below before writing.
    body = '''
	####### Read data #########################################
	CORR_THRESHOLD <- 0.7
	MIN_SIZE       <- 100
	agi        <- read.table(AGINAME_FILE, sep='\\t', header=FALSE, row.names=1, stringsAsFactors=F) # AGINAME_FILE cannot contain quotes
	#######################################################
	library(mixtools)
	library(rjson)
	name2 <- agi[id2,1]
	result <- ''
	for (i in 1:length(tfs)) {
	    curr.date <- gsub('-','',Sys.Date())
	    id1 <- tfs[i]
	    name1 <- agi[id1,1]
	    cond <- conditions[i]

	    file.x <- paste(jsonTPM.dir, paste(id1, '.json', sep=''), sep='/')
	    if (!file.exists(file.x)) { next }
	    x <- as.data.frame(fromJSON(file = file.x))
	    x <- log(x+1)
	    rcond.x <- names(x)
	    x <- as.vector(t(x)) # convert it to a vector

	    file.y <- paste(jsonTPM.dir, paste(id2, '.json', sep=''), sep='/')
	    if (!file.exists(file.y)) { break }
	    y <- as.data.frame(fromJSON(file = file.y))
	    y <- log(y+1)
	    rcond.y <- names(y)
	    y <- as.vector(t(y)) # convert it to a vector

	    rna.sample.id <- rcond.x
	    if (all(rcond.x == rcond.y) == FALSE | id1 == id2) { # if the IDs in two json files do not match, or target is the same as tf, then ignore 
	       next
	    }

	    index <- x < 0.01 | y < 0.01 # don't include data that is too small
	    x <- x[!index]
	    y <- y[!index]
	    r <- cor(x, y)
	    if (abs(r) >= CORR_THRESHOLD) {
	        s = sprintf('%s %s\\t%s %s\\t%4.2f\\t%s\\t%s\\t%s\\t%s\\t%s\\n', id2, name2, id1, name1, r, 'all', '.', cond, '.', curr.date)
	        result <- paste(result, s, sep='')
	        next  # a good correlation is found using all experiments, so not necessary to look further
	    }
	
	    rna.sample.id <- rna.sample.id[!index] # important to make the index work

	    pos_r_max   <- -2
	    pos_r_N     <- 0
	    pos_r_index <- c()
	    pos_r_loglik <- -100000000
	
	    neg_r_max   <- 2
	    neg_r_N     <- 0
	    neg_r_index <- c()
	    neg_r_loglik <- -100000000

	    for (k in k.lst) {
	        em.out <- regmixEM(y, x, maxit=150, epsilon=1e-04, k=k)
	        for (j in seq(1,k,1)) {
	            index <- which(max.col(em.out$posterior) == j)
	            size <- length(index)
	            r <- cor(em.out$x[index,2], em.out$y[index])
	            if (!is.na(r) && r >= CORR_THRESHOLD && size >= MIN_SIZE && r > pos_r_max && size > pos_r_N) {
	                pos_r_max <- r
	                pos_r_N   <- size
	                pos_r_index <- index
	                pos_r_loglik <- em.out$loglik
	            }
	            if (!is.na(r) && r <= -CORR_THRESHOLD && size >= MIN_SIZE && r < neg_r_max && size > neg_r_N) {
	                neg_r_max <- r
	                neg_r_N   <- size
	                neg_r_index <- index
	                neg_r_loglik <- em.out$loglik
	            }
	        }
	    }
	
	    if (pos_r_max > 0) { # has a good positive correlation
	        sub.cond <- paste(rna.sample.id[pos_r_index], collapse=' ')
	        s = sprintf('%s %s\\t%s %s\\t%4.2f\\t%s\\t%s\\t%s\\t%4.2f\\t%s\\n', id2, name2, id1, name1, pos_r_max, 'mix', sub.cond, cond, pos_r_loglik, curr.date)
	        result <- paste(result, s, sep='')
	    }
	    if (neg_r_max < 0) { # has a good negative correlation
	        sub.cond <- paste(rna.sample.id[neg_r_index], collapse=' ')
	        s = sprintf('%s %s\\t%s %s\\t%4.2f\\t%s\\t%s\\t%s\\t%4.2f\\t%s\\n', id2, name2, id1, name1, neg_r_max, 'mix', sub.cond, cond, neg_r_loglik, curr.date)
	        result <- paste(result, s, sep='')
	    }
	}
	cat(result, file=output.file, sep='')
    '''
    f = open(fname, 'w')
    content = head + body
    # Strip the template's leading tab from each line before writing.
    f.write('\n'.join([line.lstrip('\t').rstrip() for line in content.split('\n')]))
    f.close()
    return fname

def wait_a_moment(n, prefix):
    '''Block until at most n R scripts whose command line contains prefix
    are running.  Polls every TIME_INTERVAL seconds.

    n      - maximum number of concurrent matching processes allowed
    prefix - substring identifying our worker R scripts (e.g. WORK20170308_1200)
    '''
    def count_running():
        # 'grep -v grep' excludes both the grep command itself and the
        # 'sh -c' wrapper (whose command line also contains "grep") from
        # the count; the original code counted them and the trailing empty
        # line from split('\n'), inflating the total by ~3 every poll.
        ps = subprocess.Popen('ps aux | grep %s | grep -v grep' % (prefix),
                              shell=True, stdout=subprocess.PIPE,
                              universal_newlines=True)  # text mode works on both py2 and py3
        out = ps.communicate()[0]
        return len([line for line in out.split('\n') if line.strip() != ''])

    time.sleep(TIME_INTERVAL)
    while count_running() > n:
        time.sleep(TIME_INTERVAL)
    
def establish_edges(jsonTPM_dir, d, glb_param_dict, rprefix):
    '''Generate and launch (in the background) one R script per target gene.

    jsonTPM_dir    - absolute path to the per-gene TPM json directory
    d              - binding dictionary {target:{tf1:c1, tf2:c2}, ...}, c1 c2 are strings of conditions
    glb_param_dict - global parameters (GENE_LIST, HIGH_PRIORITY_GENE, ...)
    rprefix        - unique prefix for this run's R script names, used to
                     count running processes in wait_a_moment
    '''
    # targets with fewer edges get higher priority, e.g. targets never having
    # an edge are treated first
    gene_lst = get_ordered_gene_list(glb_param_dict['GENE_LIST'])
    high_gene_lst = glb_param_dict['HIGH_PRIORITY_GENE'].split()  # high priority genes CHANGE

    if not os.path.isdir(EDGE_DIR):
        os.makedirs(EDGE_DIR)

    # Make a list of targets with high-priority targets at the beginning.
    # Copy high_gene_lst so the caller's list is not mutated (the original
    # aliased it and appended in place); use a set for O(1) membership.
    high_set = set(high_gene_lst)
    final_gene_lst = list(high_gene_lst)
    final_gene_lst.extend([x for x in gene_lst if x not in high_set])

    # process each target, high priority genes first
    for target in final_gene_lst:
        tf_dict = d.get(target, {})  # {tf1:c1, tf2:c2}; empty if target not in binding dict
        if len(tf_dict) > 0:  # it has TFs, usually the case
            r_file = '../Data/temp/%s_%s_K%d.R' % (rprefix, target, K)
            make_r_script(r_file, target, tf_dict, jsonTPM_dir, K)
            os.system('Rscript %s &' % (r_file))  # run the Rscript in background
            # throttle: if MAX_PROCESS scripts are already running, wait
            wait_a_moment(MAX_PROCESS, rprefix)
                
## main
# Pipeline: 1) split TPM.txt into per-gene json files, 2) derive target_tf.txt
# from binding.txt, 3) launch one R script per target to compute edges.
param_file = sys.argv[1] # a single parameter file for building network, parameter_for_net.txt
glb_param_dict = make_global_param_dict(param_file)
agi2name_dict = make_gene_name_AGI_map_dict(GENE_ID_TO_GENE_NAME) # for gene names

#print('Make jsonTPM ...')  # CHANGE
cmd = 'python TPM2JSON.py %s' % (param_file) # make jsonTPM directory. The TPM values are not log-transformed.
os.system(cmd)
curr_time = datetime.now().strftime('%Y%m%d_%H%M%S')
JSON_DIR = '../Data/history/expr/jsonTPM_%s' % (curr_time) # for each TPM.txt, there should be a unique jsonTPM directory.
# rename the freshly-made jsonTPM directory so this run has its own timestamped copy
cmd = 'mv ../Data/history/expr/jsonTPM %s' % (JSON_DIR)
os.system(cmd)

#print('Make target tf using binding.txt')
target_tf_fname = '../Data/information/target_tf.txt.' + curr_time
cmd = 'python make_target_tf.py %s > %s' % (param_file, target_tf_fname)  # make target_tf.txt CHANGE better to make a temporary copy for this program
os.system(cmd)

# Example paths for re-running with previously generated inputs:
#JSON_DIR = '../Data/history/expr/jsonTPM_20170310_1153'
#target_tf_fname = '../Data/information/target_tf.txt.20170310_1153'
#print('Establish edges')
big_tf_dict = make_tf_dict(target_tf_fname)
# unique per-run prefix so wait_a_moment counts only this run's R scripts
rscript_prefix = 'WORK%s' % (datetime.now().strftime('%Y%m%d_%H%M'))
establish_edges(os.path.abspath(JSON_DIR), big_tf_dict, glb_param_dict, rscript_prefix)