# Usage: python create_edges.py parameter_for_net.txt > edges.txt.20170227_1618
#
# 01 DEC 2016, hui

import sys, os, operator, itertools
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stat
from datetime import datetime

import rpy2.robjects as r
from rpy2.robjects.packages import importr
from rpy2.robjects import FloatVector

import warnings
from geneid2name import make_gene_name_AGI_map_dict, get_gene_name
from param4net import make_global_param_dict
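
# NOTE: get_two_components(), get_three_components() and get_three_components_and_evaluate() below
# use a class named LinearRegressionsMixture that is not imported in this file.  It is assumed to
# come from a local mixture-of-linear-regressions module that must be on the import path, e.g.
# (hypothetical module name):
#   from linear_regressions_mixture import LinearRegressionsMixture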

####### Utility files #############
GENE_ID_TO_GENE_NAME = '../Data/information/AGI-to-gene-names_v2.txt'

####################################
GLB_PARAM_SYMBOL    = '%%'
DATA_SYMBOL         = '@'
TOP_N_TF            = 50
MIN_NUM_CONDITION   = 20
SIGNAL_INPUT_RATIO_TAU = 1.5
REMOVE_HORIZONTAL_STRIP_TAU = 0.05
REMOVE_VERTICAL_STRIP_TAU   = 0.05
MIN_NUMBER_OF_POINTS_FOR_MIXTURE_OF_GAUSSIAN = 30
####################################

def get_two_components(y, x):
    K = 2
    epsilon = 1e-4
    lam = 0.1
    iterations = 25
    random_restarts = 2

    # Remove NaNs or Infs
    warn_msg = ''
    sz = len(x)
    if sz < MIN_NUMBER_OF_POINTS_FOR_MIXTURE_OF_GAUSSIAN: # too few points, ignore.
        return None, None, 'IGNORE'
    # print('DEBUG')
    # print(y)
    # print(x)
    # print(type(y))
    # print(type(x))    
    index = np.isfinite(x) & np.isfinite(y)
    if sum(index) < sz:
        warn_msg = np.array_str(x) + ',' +  np.array_str(y)
        if sum(index) < MIN_NUMBER_OF_POINTS_FOR_MIXTURE_OF_GAUSSIAN:
            return None, None, 'IGNORE'
            
    x = x[index]
    y = y[index]
    
    # Train the model
    model = LinearRegressionsMixture(np.expand_dims(x, axis=1), np.expand_dims(y, axis=1), K=K)
    model.train(epsilon=epsilon, lam=lam, iterations=iterations, random_restarts=random_restarts, verbose=False)
    idx1 = (model.gamma[:,0] >  model.gamma[:,1]) # model.gamma is an n-by-K matrix of posterior membership probabilities
    idx2 = (model.gamma[:,1] >  model.gamma[:,0]) 
    return idx1, idx2, warn_msg
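
# Example (hypothetical arrays; mirrors the call in two_way() below).  The boolean masks refer to
# the input order, assuming the inputs contain no NaN/Inf values:
#   idx1, idx2, msg = get_two_components(target_elst, tf_elst)
#   if msg != 'IGNORE':
#       r1, p1 = stat.pearsonr(target_elst[idx1], tf_elst[idx1])  # correlation within component 1
#       r2, p2 = stat.pearsonr(target_elst[idx2], tf_elst[idx2])  # correlation within component 2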


def get_three_components(y, x, cond_lst):
    K = 3
    epsilon = 1e-4
    lam = 0.1
    iterations = 50
    random_restarts = 5

    # Remove NaNs or Infs
    warn_msg = ''
    sz = len(x)
    if sz < MIN_NUMBER_OF_POINTS_FOR_MIXTURE_OF_GAUSSIAN: # too few points, ignore.
        return None, None, None, None, None, None, None, None, None, 'IGNORE'

    index = np.isfinite(x) & np.isfinite(y)
    if sum(index) < sz:
        warn_msg = 'HAS_NAN_OR_INFINITE'
        if sum(index) < MIN_NUMBER_OF_POINTS_FOR_MIXTURE_OF_GAUSSIAN:
            return None, None, None, None, None, None, None, None, None, 'IGNORE'
            
    xx = np.array(x[index])
    yy = np.array(y[index])
    cond_lst2 = np.array(cond_lst)
    cond_lst2 = cond_lst2[index]

    
    # Train the model
    model = LinearRegressionsMixture(np.expand_dims(xx, axis=1), np.expand_dims(yy, axis=1), K=K)
    model.train(epsilon=epsilon, lam=lam, iterations=iterations, random_restarts=random_restarts, verbose=False)
    idx1 = np.array(model.gamma[:,0] >  model.gamma[:,1]) & np.array(model.gamma[:,0] >  model.gamma[:,2]) # model.gamma is an n-by-K matrix of posterior membership probabilities
    idx2 = np.array(model.gamma[:,1] >  model.gamma[:,0]) & np.array(model.gamma[:,1] >  model.gamma[:,2])
    idx3 = np.array(model.gamma[:,2] >  model.gamma[:,0]) & np.array(model.gamma[:,2] >  model.gamma[:,1])
    return xx[idx1], yy[idx1], xx[idx2], yy[idx2], xx[idx3], yy[idx3], list(cond_lst2[idx1]), list(cond_lst2[idx2]), list(cond_lst2[idx3]), warn_msg


def get_three_components_and_evaluate(y, x, cond_lst):
    K = 3
    epsilon = 1e-4
    lam = 0.1
    iterations = 50
    random_restarts = 5

    # Remove NaNs or Infs
    warn_msg = ''
    sz = len(x)
    if sz < MIN_NUMBER_OF_POINTS_FOR_MIXTURE_OF_GAUSSIAN: # too few points, ignore.
        return None, None, None, None, None, None, None, None, None, 'IGNORE'

    index = np.isfinite(x) & np.isfinite(y)
    if sum(index) < sz:
        warn_msg = 'HAS_NAN_OR_INFINITE'
        if sum(index) < MIN_NUMBER_OF_POINTS_FOR_MIXTURE_OF_GAUSSIAN:
            return None, None, None, None, None, None, None, None, None, 'IGNORE'
            
    xx = np.array(x[index])
    yy = np.array(y[index])
    cond_lst2 = np.array(cond_lst)
    cond_lst2 = cond_lst2[index]

    # Train the model
    model = LinearRegressionsMixture(np.expand_dims(xx, axis=1), np.expand_dims(yy, axis=1), K=K)
    model.train(epsilon=epsilon, lam=lam, iterations=iterations, random_restarts=random_restarts, verbose=False)
    idx1 = np.array(model.gamma[:,0] >  model.gamma[:,1]) & np.array(model.gamma[:,0] >  model.gamma[:,2]) # model.gamma is an n-by-K matrix of posterior membership probabilities
    idx2 = np.array(model.gamma[:,1] >  model.gamma[:,0]) & np.array(model.gamma[:,1] >  model.gamma[:,2])
    idx3 = np.array(model.gamma[:,2] >  model.gamma[:,0]) & np.array(model.gamma[:,2] >  model.gamma[:,1])
    rmse_avg, rmse_std = model.cross_validate(k_fold=10, verbose=False, silent=True)
    warn_msg = 'rmse_avg=%4.2f,rmse_sd=%4.2f' % (rmse_avg, rmse_std)
    return xx[idx1], yy[idx1], xx[idx2], yy[idx2], xx[idx3], yy[idx3], list(cond_lst2[idx1]), list(cond_lst2[idx2]), list(cond_lst2[idx3]), warn_msg


def get_three_components_mixtools(y, x, cond_lst):

    # Remove NaNs or Infs
    warn_msg = ''
    sz = len(x)
    if sz < MIN_NUMBER_OF_POINTS_FOR_MIXTURE_OF_GAUSSIAN: # too few points, ignore.
        return None, None, None, None, None, None, None, None, None, 'IGNORE'

    index = np.isfinite(x) & np.isfinite(y)
    if sum(index) < sz:
        warn_msg = 'HAS_NAN_OR_INFINITE'
        if sum(index) < MIN_NUMBER_OF_POINTS_FOR_MIXTURE_OF_GAUSSIAN:
            return None, None, None, None, None, None, None, None, None, 'IGNORE'
            
    xx = np.array(x[index])
    yy = np.array(y[index])
    cond_lst2 = np.array(cond_lst)
    cond_lst2 = cond_lst2[index]

    # Train the model
    mixtools =  importr('mixtools')
    try:
        result = mixtools.regmixEM(FloatVector(yy), FloatVector(xx), epsilon = 1e-04, k=3, maxit=100)
    except Exception: # regmixEM may raise an R error (propagated via rpy2); skip this gene pair if it does
        return None, None, None, None, None, None, None, None, None, 'IGNORE'
    posterior = result[result.names.index('posterior')]
    posterior = np.array(posterior)
    l = np.argmax(posterior, axis=1) # class information
    idx1 = l == 0
    idx2 = l == 1
    idx3 = l == 2    
    warn_msg = 'loglik=%4.2f' % (np.array(result[result.names.index('loglik')])[0])

    return xx[idx1], yy[idx1], xx[idx2], yy[idx2], xx[idx3], yy[idx3], list(cond_lst2[idx1]), list(cond_lst2[idx2]), list(cond_lst2[idx3]), warn_msg
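
# Example (hypothetical arrays; mirrors the call in two_way() below):
#   aa1, bb1, aa2, bb2, aa3, bb3, cond1, cond2, cond3, msg = \
#       get_three_components_mixtools(target_elst, tf_elst, clist)
#   if msg != 'IGNORE':
#       r1, p1 = stat.pearsonr(aa1, bb1)  # correlation within component 1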


def read_matrix_data(fname):
    '''
    fname - a data file; the first line is the header and the first column contains row names (gene IDs).
    '''
    
    lineno = 0
    colid = []
    rowid = []
    d =  {}  # {gene1:{cond1:val1, cond2:val2, ...}, gene2: {...}, ...}
    d2 = {} # {cond1:{gene1:val1, gene2:val2, ...}, cond2: {...}, ...}
    d3 = {} # {gene1: [], gene2: [], ...}
    d4 = {} # {cond1:[], cond2:[], ...}

    f = open(fname)
    lines = f.readlines()
    f.close()

    head_line = lines[0].strip()
    lst = head_line.split()
    colid = lst[1:]

    for c in colid:
        d2[c] = {}
        d4[c] = []
    
    for line in lines[1:]:
        line = line.strip()
        lst = line.split()
        g = lst[0]
        rowid.append(g)
        d[g] = {}
        levels = lst[1:]
        if len(levels) != len(colid):
            print('Incomplete columns at row %s' % (g))
            sys.exit()
            
        d3[g] = []
        for i in range(len(colid)):
            c = colid[i]
            d[g][c]  = float(levels[i])
            d2[c][g] = float(levels[i])
            d3[g].append(float(levels[i]))
            d4[c].append(float(levels[i]))
        lineno += 1

    d_return = {}
    d_return['xy'] = d  # first gene, then condition
    d_return['yx'] = d2 # first condition, then gene
    d_return['xx'] = d3 # each item is an array of gene expression levels, i.e., each item is a row
    d_return['yy'] = d4 # each item is an array of gene expression levels, i.e., each item is a column
    d_return['nrow'] = lineno # number of data rows (the header line is not counted)
    d_return['ncol'] = len(colid)
    d_return['rowid'] = rowid
    d_return['colid'] = colid    

    d4_sorted = {}
    for k in d4:
        d4_sorted[k] = sorted(d4[k], reverse=True)
    d_return['yy_sorted'] = d4_sorted

    return d_return
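
# Expected matrix format (whitespace-separated; the gene and sample IDs below are hypothetical).
# The first header token is ignored and the remaining header tokens become 'colid':
#   gene        R0000SRR0000001XX  R0000SRR0000002XX
#   AT1G000010  0.0                3.2
#   AT1G000020  1.5                0.0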


def get_value(s, delimit):
    lst = s.split(delimit, 1) # split on the first delimiter only, so values containing the delimiter are kept intact
    return lst[1].strip()

def read_info_data(fname):
    ''' Read ChIP-seq data information. '''

    if not os.path.exists(fname):
        print('%s does not exist.' % (fname))
        sys.exit()
        
    d = {'ID_LIST':[]}
    f = open(fname)
    lines = f.readlines()
    f.close()
    for line in lines:
        line = line.strip()
        if line == '' or line.startswith('#') or line.startswith('%'):
            continue
        if line.startswith(DATA_SYMBOL):
            s = line[line.rfind(DATA_SYMBOL[-1])+1:]
            s = s.strip()
            if s in d:
                print('ID %s duplicate' % (s))
                sys.exit()
            d[s] = {'PROTEIN_ID':'', 'PROTEN_NAME':'', 'DATA_NAME':'', 'DATA_FORMAT':'', 'DESCRIPTION':'', 'LOCATION':'', 'NOTE':''}
            d['ID_LIST'].append(s)
        if line.startswith('DESCRIPTION:'):
            d[s]['DESCRIPTION'] = get_value(line, ':')
        elif line.startswith('PROTEN_NAME:'):
            d[s]['PROTEN_NAME'] = get_value(line, ':')
        elif line.startswith('PROTEIN_ID:'):
            d[s]['PROTEIN_ID'] = get_value(line, ':')
        elif line.startswith('DATA_NAME:'):
            d[s]['DATA_NAME'] = get_value(line, ':')                        
        elif line.startswith('DATA_FORMAT:'):
            d[s]['DATA_FORMAT'] = get_value(line, ':')
        elif line.startswith('LOCATION:'):
            d[s]['LOCATION'] = get_value(line, ':')
        elif line.startswith('NOTE:'):
            d[s]['NOTE'] = get_value(line, ':')

    return d
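
# Expected info file format (the sample ID and values below are hypothetical).  Each record starts
# with a line beginning with DATA_SYMBOL ('@') giving the sample ID, followed by KEY: value lines:
#   @C0001000000001
#   PROTEIN_ID: AT1G000010
#   PROTEN_NAME: EXAMPLE_TF
#   DATA_FORMAT: narrowPeak
#   DESCRIPTION: example ChIP-seq sample
# Lines starting with '#' or '%' are ignored.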


def get_gene_list(fname):
    result = []
    f = open(fname)
    for line in f:
        line = line.strip()
        if line == '': # skip blank lines
            continue
        lst = line.split()
        result.append(lst[0])
    f.close()
    return result


def get_related_condition(s, info_dict):
    ''' Return a list of sample IDs related to the ';'-separated keywords in s.  TBD: currently returns all sample IDs. '''
    return info_dict['ID_LIST']


def update_global_param_dict(glb_param_dict, info_dict):
    if glb_param_dict['RESEARCH_KEYWORDS'] == '':
        glb_param_dict['USER_CONDITION_LIST'] = info_dict['ID_LIST']
    else:
        glb_param_dict['USER_CONDITION_LIST'] = get_related_condition(glb_param_dict['RESEARCH_KEYWORDS'], info_dict)


def get_threshold(lst):
    x = np.array(lst)
    x = x[x > 0]
    return np.median(x)


def get_threshold2(lst, glb_param_dict):
    x = np.array(lst)
    x = x[x > 0]
    max_num  = int(glb_param_dict['MAX_NUM_TARGETS'])
    percent = float(glb_param_dict['OVERFLOW_TARGETS_PERCENTAGE'])
    n = len(x)
    if n < max_num:
        return x[-1]
    else: # include some overflowing targets, but not all
        overflow = n - max_num
        keep = int(overflow * percent)
        index = min(keep + max_num, n - 1) # clamp so the index stays within the array
        return x[index]
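
# Worked example (hypothetical parameter values): with MAX_NUM_TARGETS=100 and
# OVERFLOW_TARGETS_PERCENTAGE=0.1, a column with 150 positive values (sorted in descending order,
# as in bind_dict['yy_sorted']) gives overflow=50, keep=5, index=105, so roughly the top 105
# binding values pass the threshold.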

def get_tf(g, bind_dict, info_dict, input_dict, glb_param_dict):
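    # For 'BW' (signal) ChIP samples, g is linked to the sample's TF when the signal/input ratio
    # exceeds SIGNAL_INPUT_RATIO_TAU (and the input value is positive and not excessively large).
    # For 'NARROWPEAK' samples, g is linked when its binding value is at or above the per-column
    # threshold computed by get_threshold2().  The returned dict maps TF gene IDs to the list of
    # supporting ChIP sample IDs.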
    tf_dict = {}
    d = bind_dict['xy']
    input_d = input_dict['xy']
    input_cond = input_dict['colid'][0] # use the first column as input (improve)
    if g in d.keys():
        for c in bind_dict['colid']:
            bind_val = d[g][c]
            if info_dict[c]['DATA_FORMAT'].upper() == 'BW':
                input_val = input_d[g][input_cond]
                if g == 'AT1G65480': # FT, target is FT
                    #print('DEBUG target:%s protein=%s bv=%g, iv=%g, ratio=%g' % (g, info_dict[c]['PROTEIN_ID'], bind_val, input_val, bind_val/input_val))
                    pass
                if input_val > 0 and input_val < 10000 and (bind_val / input_val) > SIGNAL_INPUT_RATIO_TAU: # input_val should also be not too large
                    g2 = info_dict[c]['PROTEIN_ID']
                    if g2 != '':
                        if not g2 in tf_dict:
                            tf_dict[g2] = [c]
                        else:
                            tf_dict[g2].append(c)
            elif info_dict[c]['DATA_FORMAT'].upper() == 'NARROWPEAK':
                #tau = bind_dict['yy_sorted'][c][TOP_N_TF]
                tau = get_threshold2(bind_dict['yy_sorted'][c], glb_param_dict)
                #print('DEBUG target=%s %s %g >= %g' % (g, info_dict[c]['PROTEIN_ID'], bind_val, tau))
                if bind_val >= tau: # change later
                    g2 = info_dict[c]['PROTEIN_ID']
                    if g2 != '':
                        if not g2 in tf_dict:
                            tf_dict[g2] = [c]
                        else:
                            tf_dict[g2].append(c)

    return tf_dict



def get_gene_expression(gene_id, cond_lst, expr_dict, takelog=False):

    num_cond = len(cond_lst)
    elst = [None]*num_cond
    d = expr_dict['xy']
    for i in range(num_cond):
        c = cond_lst[i]
        x = d[gene_id][c]
        if takelog == True:
            elst[i] = np.log(x+1)
        else:
            elst[i] = x
    return np.array( elst )


def float_equal(x, y):
    return np.abs(x-y) < 0.001


def get_gene_expression2(gene_id1, gene_id2, cond_lst, expr_dict, takelog=False):
    ''' Get gene expression for two genes.  Conditions in which both genes have zero TPM values are ignored. '''
    num_cond = len(cond_lst)
    elst1 = [None]*num_cond
    elst2 = [None]*num_cond
    clst  = [None]*num_cond
    d = expr_dict['xy']
    j = 0
    for i in range(num_cond):
        c = cond_lst[i]
        x = expr_dict['xy'][gene_id1][c]
        y = expr_dict['xy'][gene_id2][c]
        #print('DEBUG %s %s %g %g c=%s' % (gene_id1, gene_id2, x, y, c))
        #print('DEBUG at2g07745 at R0000SRR1802166XX %g' % (expr_dict['xy']['AT2G07754']['R0000SRR1802166XX']))
        if not float_equal(x,0.0) or not float_equal(y,0.0): # at least one is not zero
            if takelog == True: # add 1 to each value before taking the logarithm, so zero TPM maps to zero
                elst1[j] = np.log(x+1)
                elst2[j] = np.log(y+1)
            else:
                elst1[j] = x
                elst2[j] = y
            clst[j] = c
            j += 1
    return ( np.array(elst1[0:j]), np.array(elst2[0:j]), clst[0:j] )



def get_gene_expression3(gene_id1, gene_id2, cond_lst, expr_dict, takelog=False):
    '''
    Get gene expression for two genes.  Conditions in which both genes have zero TPM values are ignored.
    In addition, points falling in the vertical or horizontal strip near the axes (below the strip thresholds) are removed.
    '''
    num_cond = len(cond_lst)
    elst1 = [None]*num_cond
    elst2 = [None]*num_cond
    mark_cond = [True]*num_cond # indicate if a condition should be included
    clst  = []
    d = expr_dict['xy']

    for i in range(num_cond):
        c = cond_lst[i]
        x = expr_dict['xy'][gene_id1][c]
        y = expr_dict['xy'][gene_id2][c]
        if not float_equal(x,0.0) or not float_equal(y,0.0): # at least one is not zero
            if takelog == True: # add 1 to each value before taking the logarithm, so zero TPM maps to zero
                elst1[i] = np.log(x+1)
                elst2[i] = np.log(y+1)
            else:
                elst1[i] = x
                elst2[i] = y
            if elst1[i] < REMOVE_VERTICAL_STRIP_TAU or elst2[i] < REMOVE_HORIZONTAL_STRIP_TAU: # don't include this condition if its values are in the strip
                mark_cond[i] = False
            else:
                clst.append(c)
        else:
            mark_cond[i] = False

    a = np.array(elst1)
    a = a[mark_cond]
    b = np.array(elst2)
    b = b[mark_cond]
    return (a.astype(np.float64), b.astype(np.float64), clst)


#################### select condition stuff ###############

def get_yhat(slope, intercept, x):
    yhat = intercept + slope * x 
    return yhat

def select_points_theil(x, y, max_diff):
    theil_result = stat.mstats.theilslopes(y, x) # theilslopes(y, x) fits y as a function of x
    slope = theil_result[0]
    intercept = theil_result[1]
    yhat =  get_yhat(slope, intercept, x)
    d = y - yhat
    d_abs = np.abs(d)
    index = d_abs < max_diff
    return (x[index], y[index], index)

def common_elements(list1, list2):
     return sorted(list(set(list1).intersection(list2)))

def two_points(p1, p2):
    '''Return slope and intercept '''
    x1, y1 = p1
    x2, y2 = p2
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return (m, b)

def select_points_diagonal(x, y, max_diff, direction):
    '''
    Select points close to a diagonal reference line.  The line passes through a lower and an upper
    extreme point shared by the x- and y-orderings; points whose vertical distance to the line is
    below max_diff are kept.  direction is 'pos' (increasing) or 'neg' (decreasing).
    '''

    n = len(x)
    if n < 3:
        return (x, y)

    if direction.lower() == 'pos':
        index_x = np.argsort(x)
    elif direction.lower() == 'neg':
        index_x = np.argsort(-1*x)
    else:
        print('%s must be pos or neg' % (direction))
        sys.exit()
        
    index_y = np.argsort(y)

    # get lower (or upper) end point
    idx = None
    for i in range(2,n+1): # get common index
        s1 = index_x[0:i]
        s2 = index_y[0:i]
        s  = common_elements(s1, s2)
        if s != []:
            idx = s[0]
            break
    p1 = (x[idx], y[idx])

    # get upper (or lower) end point
    index_x = list(reversed(index_x)) # reverse list
    index_y = list(reversed(index_y)) 

    idx = None
    for i in range(2,n+1): # get common index
        s1 = index_x[0:i]
        s2 = index_y[0:i]
        s  = common_elements(s1, s2)
        if s != []:
            idx = s[0]
            break    
    p2 = (x[idx], y[idx])
    
    slope, intercept = two_points(p1, p2)

    yhat =  get_yhat(slope, intercept, x)
    d = y - yhat
    d_abs = np.abs(d)
    try:
        index = d_abs < max_diff
    except RuntimeWarning:
        print(d_abs)
        print(max_diff)
        sys.exit()
        
    return (x[index], y[index], index)


def correlation_is_significant(r, p, glb_param_dict):
    return (not np.isnan(r)) and np.abs(r) > float(glb_param_dict['TWO_WAY_CORRELATION_CUTOFF']) and p < float(glb_param_dict['TWO_WAY_CORRELATION_PVALUE_CUTOFF'])


def subset_list(lst, bool_index):
    if len(lst) != len(bool_index):
        print('subset_list: list sizes are not equal (%d, %d)' % (len(lst), len(bool_index)))
        sys.exit()

    n = len(lst)
    result = []
    for i in range(n):
        if bool_index[i] == True:
            result.append(lst[i])
    return result


def print_dict(d, glb_param_dict):
    agi2name_dict = glb_param_dict['name_conversion_dict']
    curr_date = datetime.now().strftime('%Y%m%d') # add date to the end of each line, for future reference or filtering
    if d['type'] == 'two-way':
        gene_id = d['target']
        head = '%s %s\t' % (gene_id, get_gene_name(gene_id, agi2name_dict))
        gene_id = d['TF']
        head += '%s %s\t' % (gene_id, get_gene_name(gene_id, agi2name_dict))
        d2 = d['significant']
        if 'all' in d2:
            s = '%4.2f\t%s\t%s\t%s\t%s' % (d2['all']['score'], 'all', '.', ' '.join(d2['all']['chip_id']), '.')
            s += '\t' + curr_date
            print(head + s)
            sys.stdout.flush() # flush stdout so we can see the results immediately.
        if 'user' in d2:
            s = '%4.2f\t%s\t%s\t%s\t%s' % (d2['user']['score'], 'user', ' '.join(d2['user']['signal_set']), ' '.join(d2['user']['chip_id']), '.')
            print(head + s)
        if 'pos' in d2:
            s = '%4.2f\t%s\t%s\t%s\t%s' % (d2['pos']['score'], 'pos', ' '.join(d2['pos']['signal_set']), ' '.join(d2['pos']['chip_id']), '.')
            print(head + s)
        if 'neg' in d2:
            s = '%4.2f\t%s\t%s\t%s\t%s' % (d2['neg']['score'], 'neg', ' '.join(d2['neg']['signal_set']), ' '.join(d2['neg']['chip_id']), '.')
            print(head + s)
        if 'mix' in d2:
            n = len(d2['mix']['score'])
            msg_lst = d2['mix'].get('message', ['.'] * n) # the two-component branch does not record messages
            for i in range(n):
                s = '%4.2f\t%s\t%s\t%s\t%s' % (d2['mix']['score'][i], 'mix', ' '.join(d2['mix']['signal_set'][i]), ' '.join(d2['mix']['chip_id']), msg_lst[i])
                s += '\t' + curr_date
                print(head + s)
                sys.stdout.flush() # flush stdout so we can see the results immediately.
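
# Each edge line printed above is tab-separated: target_id and name, TF_id and name, correlation
# score, type ('all', 'user', 'pos', 'neg' or 'mix'), the signal set (space-separated sample IDs,
# or '.' for 'all'), the supporting ChIP sample IDs, and a message ('.' if none); 'all' and 'mix'
# lines also have the current date appended.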
                

def two_way(target, tf_dict, expr_dict, expr_info_dict, glb_param_dict):
    '''
    Check whether the target gene has a relationship with each of its candidate TFs.

    tf_dict: a dictionary of TFs, {tf_gene_id: [chip_id, ...]}

    For each significant Target-TF pair, a dictionary of the following form is built and printed
    via print_dict() (appending to the returned list is currently disabled):

        'type'  : 'two-way'
        'target': target gene ID
        'TF'    : TF gene ID
        'significant': {
            'all' : {'signal_set':[...], 'score':r, 'chip_id':[...]}
            'pos' : {'signal_set':[...], 'score':r, 'chip_id':[...]}
            'neg' : {'signal_set':[...], 'score':r, 'chip_id':[...]}
            'user': {'signal_set':[...], 'score':r, 'chip_id':[...]}
            'mix' : {'signal_set':[[...], ...], 'score':[r1, r2, ...], 'chip_id':[...], 'message':[...]}
        }
    '''

    result_dict_lst = [] # a list of dictionaries, one for each TF, each dict contains info for a Target-TF pair
    
    target_gene_id = target
    all_cond_lst = expr_dict['colid'] # Use all RNA-seq samples.   TBD, can be glb_param_dict['USER_CONDITION_LIST']
    logrithmize = glb_param_dict['LOGRITHMIZE'].upper() == 'YES' # take the logarithm of TPM values
    #target_elst = get_gene_expression(target_gene_id, all_cond_lst, expr_dict, takelog=logrithmize) # a list of gene expression levels

    if not target_gene_id in expr_dict['rowid']: # target gene not in expression table, cannot do anything
        return  result_dict_lst
    
    for tf_gene_id in sorted(tf_dict.keys()):  # tf_gene_id is a TF gene ID

        chip_id = tf_dict[tf_gene_id] # a list of chip experiment IDs, e.g., C00000000000

        if not tf_gene_id in expr_dict['rowid']: # tf gene not in expression table, cannot do anything
            continue

        # get gene expression profiles for target and TF.  If both target and TF are 0 in an RNA-seq sample, that sample is ignored.
        target_elst, tf_elst, clist = get_gene_expression3(target_gene_id, tf_gene_id, all_cond_lst, expr_dict, takelog=logrithmize)

        r, p = stat.pearsonr(target_elst, tf_elst)

        d = {}
        d['target'] = target_gene_id
        d['TF'] = tf_gene_id
        d['type'] = 'two-way'
        d['significant'] = {}

        all_good = False
        if correlation_is_significant(r, p, glb_param_dict):
            d['significant']['all'] =  {}
            d['significant']['all']['signal_set'] =  clist # a list of sample IDs, returned by get_gene_expression3
            d['significant']['all']['score'] =  r
            d['significant']['all']['chip_id']  = chip_id            
            all_good = True
            
        user_cond_lst =  glb_param_dict['USER_CONDITION_LIST']
        if glb_param_dict['RESEARCH_KEYWORDS'] != '' and user_cond_lst != []:
             target_elst_user, tf_elst_user, clist_user = get_gene_expression3(target_gene_id, tf_gene_id, user_cond_lst, expr_dict, takelog=logrithmize)
             
             r, p = stat.pearsonr(target_elst_user, tf_elst_user)
             if correlation_is_significant(r, p, glb_param_dict):
                d['significant']['user'] =  {}
                d['significant']['user']['signal_set'] = user_cond_lst
                d['significant']['user']['score'] = r
                d['significant']['user']['chip_id']  = chip_id

        # obsolete
        max_diff = glb_param_dict['SELECT_POINTS_DIAGONAL_MAX_DIFF']
        if glb_param_dict['LOOK_FOR_POS_CORRELATION'] == 'YES':
            aa, bb, index_pos = select_points_diagonal(target_elst, tf_elst, max_diff, 'pos')
            r_pos, p_pos = stat.pearsonr(aa, bb)
            if correlation_is_significant(r_pos, p_pos, glb_param_dict) and sum(index_pos) >= MIN_NUM_CONDITION:
                d['significant']['pos'] =  {}
                d['significant']['pos']['signal_set'] = subset_list(all_cond_lst, index_pos)
                d['significant']['pos']['score'] = r_pos
                d['significant']['pos']['chip_id']  = chip_id

        # obsolete
        if glb_param_dict['LOOK_FOR_NEG_CORRELATION'] == 'YES':
            aa, bb, index_neg = select_points_diagonal(target_elst, tf_elst, max_diff, 'neg')
            r_neg, p_neg = stat.pearsonr(aa, bb)
            if correlation_is_significant(r_neg, p_neg, glb_param_dict) and sum(index_neg) >= MIN_NUM_CONDITION:
                d['significant']['neg'] =  {}
                d['significant']['neg']['signal_set'] = subset_list(all_cond_lst, index_neg)
                d['significant']['neg']['score'] = r_neg
                d['significant']['neg']['chip_id']  = chip_id

        K = int(glb_param_dict['NUMBER_OF_COMPONENTS'])
        if glb_param_dict['MIXTURE_OF_REGRESSION'] == 'YES' and not all_good: # look harder only when using all RNA-seq samples does not yield a significant correlation
            if K == 2: # for now consider two components
                #print('DEBUG len1=%d, len=%d' % (len(target_elst), len(tf_elst)))
                #print('DEBUG %s, %s, %s' % (target_gene_id, tf_gene_id, ' '.join(clist)))                   
                index1, index2, msg = get_two_components(target_elst, tf_elst) # fit a two-component mixture of linear regressions
                if msg != 'IGNORE':
                    aa = target_elst[index1]
                    bb = tf_elst[index1]
                    r_mix1, p_mix1 = stat.pearsonr(aa, bb)
                    aa = target_elst[index2]
                    bb = tf_elst[index2]
                    r_mix2, p_mix2 = stat.pearsonr(aa, bb)
                    #print('DEBUG %s %s r_mix1:%g r_mix2:%g' % (target_gene_id, tf_gene_id, r_mix1, r_mix2))
                    flag1 = correlation_is_significant(r_mix1, p_mix1, glb_param_dict)
                    flag2 = correlation_is_significant(r_mix2, p_mix2, glb_param_dict)
                    if flag1 or flag2:
                        d['significant']['mix'] =  {}
                        d['significant']['mix']['signal_set'] = []
                        d['significant']['mix']['score'] = []
                        d['significant']['mix']['chip_id']  = chip_id
                        if flag1:
                            d['significant']['mix']['signal_set'].append(subset_list(clist, index1))
                            d['significant']['mix']['score'].append(r_mix1)
                        if flag2:
                            d['significant']['mix']['signal_set'].append(subset_list(clist, index2))
                            d['significant']['mix']['score'].append(r_mix2)

            if K == 3: # three components
                aa1, bb1, aa2, bb2, aa3, bb3, cond1, cond2, cond3, msg = get_three_components_mixtools(target_elst, tf_elst, clist) # fit a three-component mixture of regressions with R mixtools
                if msg != 'IGNORE':
                    r_mix1, p_mix1 = stat.pearsonr(aa1, bb1)
                    r_mix2, p_mix2 = stat.pearsonr(aa2, bb2)
                    r_mix3, p_mix3 = stat.pearsonr(aa3, bb3)
                    #print('DEBUG %s, %s' % (target_gene_id, tf_gene_id))
                    #print('DEBUG rmix1=%g, pmix1=%g' % (r_mix1, p_mix1))
                    #print('DEBUG rmix2=%g, pmix2=%g' % (r_mix2, p_mix2))
                    #print('DEBUG rmix3=%g, pmix3=%g' % (r_mix3, p_mix3))
                    #print('DEBUG %d %d %d' %(len(aa1), len(aa2), len(aa3)))
                    min_num_points = int(glb_param_dict['CORRELATION_BASED_ON_AT_LEAST_N_POINTS'])
                    flag1 = correlation_is_significant(r_mix1, p_mix1, glb_param_dict) and len(aa1) > min_num_points
                    flag2 = correlation_is_significant(r_mix2, p_mix2, glb_param_dict) and len(aa2) > min_num_points
                    flag3 = correlation_is_significant(r_mix3, p_mix3, glb_param_dict) and len(aa3) > min_num_points                   
                    if flag1 or flag2 or flag3:
                        d['significant']['mix'] =  {}
                        d['significant']['mix']['signal_set'] = []
                        d['significant']['mix']['score'] = []
                        d['significant']['mix']['chip_id']  = chip_id
                        d['significant']['mix']['message'] = []
                        if flag1:
                            d['significant']['mix']['signal_set'].append(cond1)
                            d['significant']['mix']['score'].append(r_mix1)
                            d['significant']['mix']['message'].append(msg)
                        if flag2:
                            d['significant']['mix']['signal_set'].append(cond2)
                            d['significant']['mix']['score'].append(r_mix2)
                            d['significant']['mix']['message'].append(msg)                            
                        if flag3:
                            d['significant']['mix']['signal_set'].append(cond3)
                            d['significant']['mix']['score'].append(r_mix3)
                            d['significant']['mix']['message'].append(msg)                            
                            
        if len(d['significant']) > 0: # significant edges exist
            print_dict(d, glb_param_dict)
            #result_dict_lst.append(d)

    return result_dict_lst


def three_way(target, tf_lst, expr_dict, expr_info_dict, glb_param_dict):
    ''' TBD '''
    return []


def establish_edges(expr_dict, expr_info_dict, bind_dict, bind_info_dict, input_dict, glb_param_dict):
    high_gene_lst = glb_param_dict['HIGH_PRIORITY_GENE'].split()    
    gene_lst = get_gene_list(glb_param_dict['GENE_LIST'])
    final_gene_lst = list(set(high_gene_lst)) # unique genes
    for x in gene_lst:
        if not x in high_gene_lst:
            final_gene_lst.append(x)
    
    update_global_param_dict(glb_param_dict, expr_info_dict)
    result_d = {'two_way_edges':{}, 'three_way_edges':{}}
    for g in final_gene_lst:
        tf_dict = get_tf(g, bind_dict, bind_info_dict, input_dict, glb_param_dict)
        if len(tf_dict) > 0:
            key = g
            if glb_param_dict['TWO_WAY'] == 'YES':
                two_dict = two_way(g, tf_dict, expr_dict, expr_info_dict, glb_param_dict)
                result_d['two_way_edges'][key] = two_dict
            if glb_param_dict['THREE_WAY'] == 'YES':
                three_dict = three_way(g, tf_dict, expr_dict, expr_info_dict, glb_param_dict) 
                result_d['three_way_edges'][key] = three_dict

    return result_d


def dumpclean(obj):
    '''
    Show dictionary content, recursively.
    '''
    if type(obj) == dict:
        for k, v in obj.items():
            if hasattr(v, '__iter__'):
                print(k)
                dumpclean(v)
            else:
                print('%s : %s' % (k, v))
    elif type(obj) == list:
        for v in obj:
            if hasattr(v, '__iter__'):
                dumpclean(v)
            else:
                print(v)
    else:
        print(obj)


# obsolete
def print_dict_list(dict_lst, agi2name_dict):
    for d in dict_lst:
        #dumpclean(d)
        if d['type'] == 'two-way':
            gene_id = d['target']
            head = '%s %s\t' % (gene_id, get_gene_name(gene_id, agi2name_dict))
            gene_id = d['TF']
            head += '%s %s\t' % (gene_id, get_gene_name(gene_id, agi2name_dict))
            d2 = d['significant']
            if 'all' in d2:
                s = '%4.2f\t%s\t%s\t%s' % (d2['all']['score'], 'all', '.', ' '.join(d2['all']['chip_id']))
                print(head + s)
            if 'user' in d2:
                s = '%4.2f\t%s\t%s\t%s' % (d2['user']['score'], 'user', ' '.join(d2['user']['signal_set']), ' '.join(d2['user']['chip_id']))
                print(head + s)
            if 'pos' in d2:
                s = '%4.2f\t%s\t%s\t%s' % (d2['pos']['score'], 'pos', ' '.join(d2['pos']['signal_set']), ' '.join(d2['pos']['chip_id']))
                print(head + s)
            if 'neg' in d2:
                s = '%4.2f\t%s\t%s\t%s' % (d2['neg']['score'], 'neg', ' '.join(d2['neg']['signal_set']), ' '.join(d2['neg']['chip_id']))
                print(head + s)
            if 'mix' in d2:
                n = len(d2['mix']['score'])
                for i in range(n):
                    s = '%4.2f\t%s\t%s\t%s' % (d2['mix']['score'][i], 'mix', ' '.join(d2['mix']['signal_set'][i]), ' '.join(d2['mix']['chip_id']))
                    print(head + s)
                    
def print_result(d, agi2name_dict):
    for k in d:
        print(k) # two-way or three-way
        d2 = d[k] 
        for k2 in d2: # k2 is a gene
            dlst = d2[k2]
            print_dict_list(dlst, agi2name_dict)



########## main ##################################################
r.r['options'](warn=-1) # suppress warning messages from R (via rpy2)
warnings.filterwarnings("ignore")
param_file = sys.argv[1] # a single parameter file
glb_param_dict = make_global_param_dict(param_file)
agi2name_dict = make_gene_name_AGI_map_dict(GENE_ID_TO_GENE_NAME)
glb_param_dict['name_conversion_dict'] = agi2name_dict
#print('Read expression data')
expr_dict = read_matrix_data(glb_param_dict['EXPRESSION_MATRIX'])
#print('DEBUG at2g07754 at R0000SRR1802166XX %g' % (expr_dict['xy']['AT2G07754']['R0000SRR1802166XX']))
expr_info_dict = read_info_data(glb_param_dict['EXPRESSION_INFO'])
#print('Read binding data')
bind_dict = read_matrix_data(glb_param_dict['BINDING_MATRIX'])
bind_info_dict = read_info_data(glb_param_dict['BINDING_INFO'])
input_dict = read_matrix_data(glb_param_dict['INPUT_MATRIX']) # newly added, for comparing with bw files
#print('Establish edges')
edge_d = establish_edges(expr_dict, expr_info_dict, bind_dict, bind_info_dict, input_dict, glb_param_dict)
#print_result(edge_d, agi2name_dict)