# Usage: python make_graphviz_file3C.py AT1G19850
#
#        Make plot: python make_graphviz_file3C.py AT1G65480 && sfdp -Goverlap=false -Tpdf -o result.flower.pdf result.flower.gv
#                   unflatten -f -l 3 result.flower.gv | dot -Tsvg -o result.flower.svg 
#
# Purpose: Generate Graphviz input files, one per tissue (result.<tissue>.gv,
# e.g. result.flower.gv). The single parameter (e.g. AT1G19850) is the query TF.
# The neighbours of the query TF's neighbours are also shown.
#
# Created 10 July 2017, hui, slcu

import random
import numpy as np
import sys
from geneid2name import make_gene_name_AGI_map_dict, get_gene_name

# Sampling fraction for drawing: an edge is emitted only when
# random.uniform(0, 1) <= PERCENT, so 1 keeps every edge.
PERCENT = 1
# get_tf_tissue keeps a tissue for a TF whenever the TF's count there
# reaches this value, independent of the row median.
NUM_TARGETS_CUTOFF = 5

def get_tf_tissue(fname, cutoff=None):
    '''Map each TF to the tissues in which it has a high count.

    fname is a tab-separated table. The first row is a header whose first
    column is the TF column and whose remaining columns are tissue names.
    Every other row is "<TF id ...><TAB><int><TAB><int>...". Only the first
    whitespace-separated token of the first column is used as the TF key.

    A tissue is kept for a TF when its count is >= max(median of that TF's
    row, 1), or when the count is >= cutoff. cutoff defaults to
    NUM_TARGETS_CUTOFF.

    Blank lines are ignored, and an empty file yields {}.

    Returns {tf: [tissue, ...]} with tissues in header order.
    '''
    if cutoff is None:
        cutoff = NUM_TARGETS_CUTOFF
    with open(fname) as f:
        lines = [line.strip() for line in f]
    lines = [line for line in lines if line]  # blank lines carry no data
    if not lines:
        return {}
    head_lst = lines[0].split('\t')[1:]  # drop the TF column header
    d = {}
    for line in lines[1:]:
        lst = line.split('\t')
        counts = [int(x) for x in lst[1:]]
        threshold = max(np.median(np.array(counts)), 1)
        tissue = []
        for i, count in enumerate(counts):
            if count >= threshold or count >= cutoff:
                tissue.append(head_lst[i])
        tf = lst[0].split()[0]
        d[tf] = tissue
    return d


def get_tissue_from_fname(fname):
    '''Return the first known tissue name found as a substring of fname.

    Candidates are tested in a fixed priority order. 'seedling' is tested
    before 'seed' and 'meristem' before 'stem', so the longer names win.
    Returns 'unknown' when no candidate occurs in fname.
    '''
    candidates = ('seedling', 'meristem', 'flower', 'aerial', 'shoot',
                  'seed', 'leaf', 'root', 'stem')
    return next((name for name in candidates if name in fname), 'unknown')


def get_edge(fname):
    ''' Return d = {'flower':{'tf':[target1,target2, ...]}, 'seed':{}} and d2.

    fname is split into sections by header lines starting with '#'. Each
    header's tissue is derived with get_tissue_from_fname (so it may be
    'unknown'). Every other non-blank line is tab-separated:
    "<target>\t<tf>\t<correlation>". For both ids only the part before the
    first '_' is kept.

    d:  {tissue: {tf: [target, ...]}}, targets in file order (duplicates kept).
    d2: {tissue: {tf: {target: abs(correlation)}}}.

    Sections that repeat a tissue are merged rather than reset. Blank lines
    are skipped.

    Raises ValueError if an edge line appears before any '#' header.
    '''
    d = {}
    d2 = {}  # the actual correlation coefficient, absolute value
    tissue = None
    with open(fname) as f:
        for lineno, raw in enumerate(f, 1):
            line = raw.strip()
            if not line:
                continue  # blank line: nothing to parse
            if line.startswith('#'):
                tissue = get_tissue_from_fname(line)
                # setdefault so a repeated header does not discard earlier edges
                d.setdefault(tissue, {})
                d2.setdefault(tissue, {})
                continue
            if tissue is None:
                raise ValueError('%s:%d: edge line before any "#" tissue header'
                                 % (fname, lineno))
            lst = line.split('\t')
            target = (lst[0].split('_'))[0]
            tf = (lst[1].split('_'))[0]
            strength = abs(float(lst[2]))
            d[tissue].setdefault(tf, []).append(target)
            d2[tissue].setdefault(tf, {})[target] = strength
    return d, d2


def in_same_tissue(source, target, node_dict):
    '''Tell whether source and target map to equal values in node_dict.

    Raises KeyError if either key is missing.
    '''
    source_value = node_dict[source]
    target_value = node_dict[target]
    return source_value == target_value


def make_label(a, b):
    '''Build a node label from id a and the ';'-separated name string b.

    When b is exactly '.', a is returned unchanged. This is presumably the
    gene-name lookup's "no name" placeholder; confirm against geneid2name.
    Otherwise the first ';'-separated entry of b is appended to a after an
    underscore.
    '''
    if b != '.':
        first_name = b.split(';', 1)[0]
        return a + '_' + first_name
    return a


def has_predecessor(tf, d):
    '''Return True if any key of d other than tf lists tf among its targets.

    d = {'tf': [target1, target2]}; a self-loop alone does not count.
    '''
    return any(tf in d[k] and k != tf for k in d)

def get_num_successors(tf, d):
    '''Return the length of tf's target list in d, or 0 if tf is not a key.'''
    return len(d[tf]) if tf in d else 0

def get_shape(tf, d):
    ''' Return the Graphviz node shape encoding tf's role in one tissue.

    d = {'tf':[target1, target2]} (one tissue's entry from get_edge).

      'doublecircle'  tf has targets and is a target of another TF
      'ellipse'       tf has targets only (regulator)
      'egg'           tf is a target of another TF only (regulatee)
      'box'           neither of the above
    '''
    p = has_predecessor(tf, d)
    s = get_num_successors(tf, d)
    if s > 0 and p:
        return 'doublecircle'
    if s > 0:
        return 'ellipse' # regulator
    if p:
        return 'egg' # regulatee
    # Explicit fallback. Returning None here would be formatted into the DOT
    # output as "shape=None", an unknown shape that Graphviz warns about.
    return 'box'


def get_color(tf, edge_dict, tissue):
    '''Return a fill colour showing how concentrated tf's targets are in tissue.

    edge_dict is {tissue: {tf: [targets]}} as returned by get_edge.

    The share is tf's target count in `tissue` divided by its count summed
    over all tissues. It is mapped onto an ordered palette: the first half
    (springgreen shades) covers shares below 0.5 and the second half (gold
    shades) covers shares of 0.5 and above. A share of 1.0 gets the last
    colour.

    Returns 'azure' when tf has no targets in any tissue.
    Raises KeyError if `tissue` is not in edge_dict and tf has targets elsewhere.
    '''
    colours = ['springgreen', 'springgreen1', 'springgreen2', 'springgreen3', 'springgreen4',
               'gold', 'gold1', 'gold2', 'gold3', 'gold4']
    total = sum(get_num_successors(tf, edge_dict[k]) for k in edge_dict)
    if total == 0:
        return 'azure'
    n = get_num_successors(tf, edge_dict[tissue])
    # Scale by the palette length rather than a literal 10, so editing the
    # palette cannot make the index run off the end or leave colours unused.
    return colours[min(int(len(colours) * 1.0 * n / total), len(colours) - 1)]


def make_more_string(n, tissue, edge_dict, colour_dict, agi2name_dict, query_tf):
    '''Return DOT statements for the second-hop neighbourhood of node n.

    n is one of query_tf's direct neighbours in `tissue`; write_graphviz_file
    calls this once per neighbour.

    Parameters:
      edge_dict      {tissue: {tf: [targets]}} as returned by get_edge.
      colour_dict    maps tissue -> node border colour.
      agi2name_dict  passed to get_gene_name to build labels.

    Two kinds of statements are produced:
      * for each target of n except query_tf: a node "<target>_<tissue>.2"
        and a gold edge "<n>_<tissue>" -> that node;
      * for each TF other than query_tf that lists n as a target: a node
        "<tf>_<tissue>.2" and a red edge from that node -> "<n>_<tissue>".
    Each edge is kept only if random.uniform(0, 1) <= PERCENT. The node
    statement is emitted whether or not the edge is kept.

    The '.2' suffix gives second-hop genes DOT ids distinct from the
    first-hop ids "<gene>_<tissue>", so a gene reached at both hops appears
    as two separate nodes. The same '.2' node may be declared more than
    once (once per neighbour that reaches it).
    NOTE(review): this presumably relies on Graphviz merging repeated
    declarations of one node id.
    '''
    result = ''
    d = edge_dict[tissue]
    if n in d:  # n has targets of its own in this tissue
        for target in d[n]:
            ll = make_label(target, get_gene_name(target, agi2name_dict))
            shape = get_shape(target, d)
            color = get_color(target, edge_dict, tissue)
            node_target = target + '_' + tissue + '.2'
            # The n -> query_tf edge is already drawn (red) by write_graphviz_file.
            if target != query_tf:
                result += '     \"%s\" [label=\"%s\", fillcolor=%s, color=%s, shape=%s, style=filled];\n' % (node_target, ll, color, colour_dict[tissue], shape)

                node_n = n + '_' + tissue
                if random.uniform(0, 1) <= PERCENT:
                    result += '     \"%s\" -> \"%s\" [color=%s];\n' % (node_n, node_target, 'gold')

    for tf in d:
        # The query_tf -> n edge is already drawn (gold) by write_graphviz_file.
        if tf != query_tf:
            if n in d[tf]: # n is a target of tf, i.e. tf regulates n
                ll = make_label(tf, get_gene_name(tf, agi2name_dict))
                shape = get_shape(tf, d)
                color = get_color(tf, edge_dict, tissue)
                node_tf = tf + '_' + tissue + '.2'
                result += '     \"%s\" [label=\"%s\", fillcolor=%s, color=%s, shape=%s, style=filled];\n' % (node_tf, ll, color, colour_dict[tissue], shape)

                node_n = n + '_' + tissue
                if random.uniform(0, 1) <= PERCENT:
                    result += '     \"%s\" -> \"%s\" [color=%s];\n' % (node_tf, node_n, 'red')

    return result


def _query_neighbourhood(tissue, edge_dict, colour_dict, agi2name_dict, query_tf):
    '''Collect DOT statements for query_tf and its direct neighbours in one tissue.

    Gold edges go from query_tf to each of its targets. Red edges go from
    every other TF that lists query_tf as a target to query_tf. Each edge is
    kept only if random.uniform(0, 1) <= PERCENT. The query node is declared
    whenever query_tf has targets, or has a drawn regulator, in this tissue.

    Returns (nodes, edges, neighbours):
      nodes       node-statement strings, each gene declared once;
      edges       edge-statement strings, each edge emitted once;
      neighbours  neighbour ids in the order they were first added.
    '''
    d = edge_dict[tissue]
    nodes = []
    edges = []
    neighbours = []
    declared = set()
    edge_seen = set()

    def declare(gene):
        # Emit gene's node statement the first time it is needed.
        if gene in declared:
            return
        declared.add(gene)
        label = make_label(gene, get_gene_name(gene, agi2name_dict))
        if gene == query_tf:
            label = 'Query gene: ' + label
            fill = 'DeepSkyBlue'
        else:
            fill = get_color(gene, edge_dict, tissue)
            neighbours.append(gene)
        nodes.append('     \"%s\" [label=\"%s\", fillcolor=%s, color=%s, shape=%s, style=filled];\n'
                     % (gene + '_' + tissue, label, fill, colour_dict[tissue], get_shape(gene, d)))

    def connect(source, target, colour):
        # Emit each edge once, independent of whether its endpoint nodes were
        # already declared and of the order in which TFs are visited.
        line = '     \"%s\" -> \"%s\" [color=%s];\n' % (source + '_' + tissue, target + '_' + tissue, colour)
        if line not in edge_seen:
            edge_seen.add(line)
            edges.append(line)

    for tf in d:
        if tf == query_tf:
            declare(query_tf)
            for target in d[tf]:
                if random.uniform(0, 1) <= PERCENT:
                    declare(target)
                    connect(tf, target, 'gold')
        else:  # check if tf is a regulator of the query
            for target in d[tf]:
                if target == query_tf and random.uniform(0, 1) <= PERCENT:
                    declare(tf)
                    declare(query_tf)  # query_tf may have no targets of its own here
                    connect(tf, query_tf, 'red')
    return nodes, edges, neighbours


def write_graphviz_file(fname, edge_dict, colour_dict, agi2name_dict, query_tf):
    '''Write one Graphviz DOT file per tissue centred on the query TF.

    For every tissue in edge_dict with at least one node to draw, the file
    "<fname>.<tissue>.gv" contains:
      * the query node, filled DeepSkyBlue with its label prefixed "Query gene: ";
      * its direct neighbours, with gold edges query -> target and red edges
        regulator -> query;
      * the second-hop statements built by make_more_string.
    Tissues with nothing to draw produce no file. fname is only used as the
    prefix of these per-tissue file names.

    edge_dict is {tissue: {tf: [targets]}} as returned by get_edge.
    colour_dict maps tissue -> node border colour.
    '''
    header = 'digraph G {\n    graph[splines=true, ranksep=3, fontname=Arial];\n    node[fontname=Arial];\n'
    for tissue in edge_dict:
        nodes, edges, neighbours = _query_neighbourhood(
            tissue, edge_dict, colour_dict, agi2name_dict, query_tf)
        if not nodes:
            continue
        # Neighbours are visited in first-seen order, so the output is the same on every run.
        more = [make_more_string(n, tissue, edge_dict, colour_dict, agi2name_dict, query_tf)
                for n in neighbours]
        with open(fname + '.' + tissue + '.gv', 'w') as f:
            f.write(header + ''.join(nodes) + ''.join(edges) + ''.join(more) + '}\n')

# main

GENE_ID_TO_GENE_NAME    = '../Data/information/AGI-to-gene-names_v2.txt'

edge_file = 'result.skeleton.txt'  # prepared by test_network4.py
node_file = 'result.out.txt'       # prepared by test_network4.py; read only by get_tf_tissue

# Border colour of each tissue's nodes in the generated graphs.
tissue_colour_dict = {
        'seedling':'greenyellow',
        'meristem':'skyblue4',
        'flower':'lightpink',
        'aerial':'cyan',
        'shoot':'forestgreen',
        'seed':'black',
        'leaf':'green',
        'root':'gold',
        'stem':'orange4'}


if __name__ == '__main__':
    # The single required argument is the query TF's AGI code.
    if len(sys.argv) < 2:
        # sys.exit with a message prints it to stderr and exits with status 1.
        sys.exit('Usage: python %s AGI_CODE   (e.g. AT1G19850)' % sys.argv[0])
    query_tf = sys.argv[1]

    # Load the AGI -> gene-name map only after the arguments are validated.
    agi2name_dict = make_gene_name_AGI_map_dict(GENE_ID_TO_GENE_NAME)
    #tf_tissue_dict = get_tf_tissue(node_file )
    edge_dict, edge_dict_r = get_edge(edge_file)
    write_graphviz_file('result', edge_dict, tissue_colour_dict, agi2name_dict, query_tf)