# Usage: python make_graphviz_file3B.py AT1G19850
#
#        Make plot: python make_graphviz_file3B.py AT1G65480 && dot -Tpdf -o result.pdf result.gv
#                   python make_graphviz_file3B.py AT1G65480 && neato -Goverlap=false -Tpdf -o result.pdf result.gv
#
#                   The plot is saved in result.pdf.  Change 'pdf' to 'svg' to get an SVG image instead.
#                   Each tissue name is shown in a yellow box.  A double circle represents a gene that is both a regulator and a regulatee,
#                   an egg represents a regulatee only, and an oval represents a regulator only.  A yellow (gold) arrow points from the
#                   query TF to a gene it regulates; a red arrow points to the query TF from a gene that regulates it.
#
# Input file is specified in variable edge_file (result.skeleton.txt).  This file is generated by test_network4.py. 
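# Every line starting with '#' is a section header from which the tissue name is taken, e.g., '##TF skeleton size in shoot: 15735.' gives 'shoot'.
# The edge lines that follow a header are tab-separated: target, TF, correlation coefficient.
# Edit the variable tissue_colour_dict and the list tissue_lst in function get_tissue_from_fname() to match the tissue names used in these headers.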
#
#
# Purpose: Generate result.gv for the Graphviz program dot.  The single
# parameter, e.g. AT1G19850, is a TF.  result.gv contains all edges from/to
# the TF in each tissue.  Each tissue gets its own copy of the nodes (node
# IDs are suffixed with the tissue name).  We can convert result.gv to a
# figure using 'dot -Tpdf -o result.pdf result.gv'.
#
# Created 6 July 2017, hui, slcu
# Last modified 11 July 2017, hui, slcu

import random
import numpy as np
import sys
from geneid2name import make_gene_name_AGI_map_dict, get_gene_name

# NOTE(review): not referenced anywhere else in this script; presumably a
# leftover threshold on the number of targets per TF -- confirm before use.
NUM_TARGETS_CUTOFF = 5


def get_tissue_from_fname(fname):
    '''Return the first known tissue name that occurs as a substring of fname.

    Despite its name, fname may be any string, e.g. the header line
    '##TF skeleton size in shoot: 15735.'.  Candidates are tried in a fixed
    order, so 'seedling' is matched before 'seed'.  Returns 'unknown' when
    no candidate occurs.
    '''
    candidates = ('seedling', 'meristem', 'flower', 'aerial', 'shoot',
                  'seed', 'leaf', 'root', 'stem')
    return next((name for name in candidates if name in fname), 'unknown')


def get_edge(fname):
    '''Parse an edge file into per-tissue adjacency and strength dicts.

    Every line starting with '#' opens a new tissue section: the tissue name
    is derived from the line by get_tissue_from_fname() and that tissue's
    dicts are (re)initialised to empty.  Every other non-blank line is a
    tab-separated record whose first three columns are target, TF and a
    correlation coefficient; only the part of each ID before the first '_'
    is kept.

    Returns (d, d2) where
        d  = {'flower': {'tf': [target1, target2, ...]}, 'seed': {}, ...}
        d2 = {'flower': {'tf': {target1: abs(coef), ...}}, ...}

    Raises ValueError if an edge line appears before any '#' header line.
    '''
    d = {}
    d2 = {}  # the actual correlation coefficient, absolute value
    tissue = None
    with open(fname) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # blank lines (e.g. at end of file) carry no edge
            if line.startswith('#'):
                tissue = get_tissue_from_fname(line)
                d[tissue] = {}
                d2[tissue] = {}
                continue
            if tissue is None:
                raise ValueError('%s: edge line before any tissue header: %r' % (fname, line))
            lst = line.split('\t')
            target = lst[0].split('_')[0]
            tf = lst[1].split('_')[0]
            d[tissue].setdefault(tf, []).append(target)
            d2[tissue].setdefault(tf, {})[target] = abs(float(lst[2]))
    return d, d2


def in_same_tissue(source, target, node_dict):
    '''Return True when node_dict maps source and target to equal values.'''
    source_tissue = node_dict[source]
    target_tissue = node_dict[target]
    return source_tissue == target_tissue

def make_label(a, b):
    '''Build a node label from gene ID a and gene-name string b.

    When b is '.', a is returned unchanged; otherwise the part of b before
    the first ';' is appended to a after a single space.
    '''
    if b != '.':
        first_name = b.partition(';')[0]
        return a + ' ' + first_name
    return a


def has_predecessor(tf, d):
    '''Return True if some TF other than tf lists tf among its targets in d.

    d maps each TF to its list of targets; a self-loop does not count.
    '''
    return any(tf in targets and source != tf for source, targets in d.items())

def get_num_successors(tf, d):
    '''Return the number of targets tf has in d, or 0 if tf is not a key of d.'''
    return len(d[tf]) if tf in d else 0

def get_shape(tf, d):
    '''Pick a Graphviz node shape describing tf's role in d.

    d maps each TF to its list of targets, e.g. {'tf': [target1, target2]}.
      'doublecircle' - tf has targets and is a target of another TF
      'oval'         - tf has targets only (regulator)
      'egg'          - tf is only a target of another TF (regulatee)
      'point'        - neither
    '''
    is_regulatee = has_predecessor(tf, d)
    is_regulator = get_num_successors(tf, d) > 0
    if is_regulator:
        return 'doublecircle' if is_regulatee else 'oval'
    return 'egg' if is_regulatee else 'point'

def get_color(tf, edge_dict, tissue):
    '''Return a fill colour reflecting tf's share of targets in `tissue`.

    edge_dict maps tissue -> {tf: [targets]}.  The number of tf's targets in
    `tissue`, divided by its total over all tissues, is scaled by 10,
    truncated, capped at the last index and used to pick from a 10-entry
    palette (a larger share selects a later entry).  A TF with no targets in
    any tissue gets 'azure'.
    '''
    palette = ['springgreen', 'springgreen1', 'springgreen2', 'springgreen3', 'springgreen4',
               'gold', 'gold1', 'gold2', 'gold3', 'gold4']
    total = sum(get_num_successors(tf, edge_dict[t]) for t in edge_dict)
    if total == 0:
        return 'azure'  # tf regulates nothing anywhere
    share = get_num_successors(tf, edge_dict[tissue])
    index = int(10 * 1.0 * share / total)
    return palette[min(index, len(palette) - 1)]

def _graphviz_node_stmt(gene, tissue, edge_dict, colour_dict, agi2name_dict):
    '''Return the DOT statement declaring gene's node in tissue.

    The node ID is '<gene>_<tissue>', so a gene gets a separate node in each
    tissue.  The fill colour comes from get_color() and the shape from
    get_shape().  The border colour comes from colour_dict, or is 'grey' for
    tissues missing from it (e.g. 'unknown').
    '''
    label = make_label(gene, get_gene_name(gene, agi2name_dict))
    label = label.replace('"', '\\"')  # a bare quote would end the DOT string early
    shape = get_shape(gene, edge_dict[tissue])
    fill = get_color(gene, edge_dict, tissue)
    border = colour_dict.get(tissue, 'grey')
    return '     "%s_%s" [label="%s", fillcolor=%s, color=%s, shape=%s, style=filled];\n' % (
        gene, tissue, label, fill, border, shape)


def _graphviz_edge_stmt(source_id, target_id, colour):
    '''Return the DOT statement for a coloured directed edge between two node IDs.'''
    return '     "%s" -> "%s" [color=%s];\n' % (source_id, target_id, colour)


def write_graphviz_file(fname, edge_dict, colour_dict, agi2name_dict, query_tf):
    '''Write a Graphviz (DOT) file with query_tf's edges in every tissue.

    edge_dict maps tissue -> {tf: [target, ...]} (the first value returned by
    get_edge()).  For each tissue this writes:
      - a gold edge from query_tf to each of its targets,
      - a red edge to query_tf from each other TF listing it as a target,
      - a yellow box labelled with the upper-cased tissue name.
    The label boxes are grouped with rank=sink.  Node IDs are
    '<gene>_<tissue>', so each tissue forms its own group of nodes.
    colour_dict maps tissue -> node border colour; agi2name_dict is passed to
    get_gene_name() to build node labels.  The file fname is overwritten.
    '''
    graph_dict = {}  # tissue -> {'head': str, 'nodes': [stmt], 'edges': [stmt]}
    last_node = {}   # tissue -> node ID the invisible edge to the label box starts from

    for k in edge_dict:  # k is a tissue name
        d = edge_dict[k]  # {'tf1': [target1, target2, ...]}
        nodes, edges = [], []
        graph_dict[k] = {'head': '', 'nodes': nodes, 'edges': edges}
        node_added = set()  # genes already declared in this tissue
        edge_added = set()  # (source gene, target gene) pairs already drawn
        query_node = query_tf + '_' + k

        for tf in d:
            node_tf = tf + '_' + k
            if tf == query_tf:
                # Outgoing edges: query_tf -> each of its targets.
                if tf not in node_added:
                    nodes.append(_graphviz_node_stmt(tf, k, edge_dict, colour_dict, agi2name_dict))
                    node_added.add(tf)
                for target in d[tf]:
                    node_target = target + '_' + k
                    if target not in node_added:
                        nodes.append(_graphviz_node_stmt(target, k, edge_dict, colour_dict, agi2name_dict))
                        node_added.add(target)
                        last_node[k] = node_target
                    if (tf, target) not in edge_added:
                        edges.append(_graphviz_edge_stmt(node_tf, node_target, 'gold'))
                        edge_added.add((tf, target))
            elif query_tf in d[tf]:
                # Incoming edge: tf -> query_tf.
                if tf not in node_added:
                    nodes.append(_graphviz_node_stmt(tf, k, edge_dict, colour_dict, agi2name_dict))
                    node_added.add(tf)
                    last_node[k] = query_node
                if query_tf not in node_added:
                    # Declare query_tf here as well: if it has no targets in
                    # this tissue it is never declared by the branch above,
                    # and Graphviz would draw it unlabelled and unstyled.
                    nodes.append(_graphviz_node_stmt(query_tf, k, edge_dict, colour_dict, agi2name_dict))
                    node_added.add(query_tf)
                if (tf, query_tf) not in edge_added:
                    edges.append(_graphviz_edge_stmt(node_tf, query_node, 'red'))
                    edge_added.add((tf, query_tf))

        if nodes:
            nodes.append('     "%s_label_node" [label="%s", shape=box, color=yellow, style=filled, height=0.8, width=1.6];\n'
                         % (k, k.upper()))

    out = ['digraph G {\n    graph[splines=true, ranksep=2, fontname=Arial];\n    node[fontname=Arial];\n']
    out.append('    {rank=sink; ')  # move label nodes to the bottom
    out.extend('%s_label_node;' % k for k in last_node if graph_dict[k]['nodes'])
    out.append('}\n')
    for k, graph in graph_dict.items():
        out.append(graph['head'])
        out.extend(graph['nodes'])
        out.extend(graph['edges'])
        if k in last_node:
            # Invisible edge tying the label box to this tissue's nodes.
            out.append('     "%s" -> "%s_label_node" [arrowhead=none, style=invis];\n' % (last_node[k], k))
    out.append('}\n')

    # Write only after everything is built, so an error leaves no truncated file.
    with open(fname, 'w') as f:
        f.write(''.join(out))


# main
#
# Usage: python make_graphviz_file3B.py <TF gene ID>
# Reads edge_file and writes result.gv in the current directory.

GENE_ID_TO_GENE_NAME = '/home/hui/network/v03/Data/information/AGI-to-gene-names_v2.txt'

edge_file = 'result.skeleton.txt'  # prepared by test_network4.py

# Border colour of each tissue's nodes; keys should match the names
# returned by get_tissue_from_fname().
tissue_colour_dict = {
        'seedling':'greenyellow',
        'meristem':'skyblue4',
        'flower':'lightpink',
        'aerial':'cyan',
        'shoot':'forestgreen',
        'seed':'black',
        'leaf':'green',
        'root':'gold',
        'stem':'orange4'}

# Validate the command line before reading any input file.  sys.exit() with
# a message prints it to stderr and exits with status 1.
if len(sys.argv) < 2:
    sys.exit('Need to specify a gene ID, e.g., AT1G19850.')
query_tf = sys.argv[1]

agi2name_dict = make_gene_name_AGI_map_dict(GENE_ID_TO_GENE_NAME)
edge_dict, edge_dict_r = get_edge(edge_file)
write_graphviz_file('result.gv', edge_dict, tissue_colour_dict, agi2name_dict, query_tf)