# Usage: python draw_subnetwork.py
#        (no command-line arguments are read; edge files are taken from ../Analysis/edges.txt.2020*)
# Purpose: draw a sub-network for a list of genes taken from the thermomorphogenesis paper
#          "Molecular and genetic control of plant thermomorphogenesis", https://doi.org/10.1038/nplants.2015.190
#
# Created on 5 December 2019 by Hui Lan (lanhui@zjnu.edu.cn)


import os, sys
import networkx as nx
import pylab as plt
import glob
import math
from networkx.algorithms.distance_measures import diameter, eccentricity


def build_network_from_file(edge_fname, gene_lst, gene_dict):
    '''Build a directed sub-network of the genes in gene_lst from an edge file.

    edge_fname -- path to a tab-separated edge file.  Only lines with exactly
                  10 columns are used.  Column 0 is the target gene and
                  column 1 the source gene, each written as "ID LABEL" where
                  LABEL is '.' when unknown or otherwise a ';'-separated list
                  of names.  Column 8 is the edge strength (parsed as float)
                  and column 9 the method or tissue.
    gene_lst   -- gene IDs to keep.  Each becomes a node named gene_dict[ID],
                  even if it has no edge in the file.
    gene_dict  -- maps gene ID to node name; must contain every ID in gene_lst
                  (KeyError otherwise).

    Returns an nx.DiGraph with source -> target edges carrying 'weight' and
    'strength' (both the column-8 float) and 'method' (column 9).  Nodes that
    take part in an edge also get 'full_name' (the whole LABEL, '' when '.')
    and 'label' (the first name in LABEL, or the gene ID when LABEL is '.').
    Lines whose two genes are not both in gene_lst are skipped without
    parsing their remaining columns.
    '''
    G = nx.DiGraph()

    for g in gene_lst:
        G.add_node(gene_dict[g])

    wanted = set(gene_lst)  # O(1) membership tests while scanning the whole file

    def label_and_name(fields, gene_id):
        # fields is "ID LABEL".split(); a LABEL of '.' means no known name.
        names = fields[1]
        if names == '.':
            return gene_id, ''
        return names.split(';')[0], names

    with open(edge_fname) as f:  # closed even if a malformed line raises
        for line in f:
            lst = line.strip().split('\t')
            if len(lst) != 10:
                continue

            target_fields = lst[0].split()
            source_fields = lst[1].split()
            g1 = target_fields[0]  # target gene ID
            g2 = source_fields[0]  # source gene ID
            if g1 not in wanted or g2 not in wanted:
                continue  # edge lies outside the sub-network

            strength = float(lst[8])
            method_or_tissue = lst[9]
            g1_label, g1_name = label_and_name(target_fields, g1)
            g2_label, g2_name = label_and_name(source_fields, g2)

            G.add_node(gene_dict[g1], full_name=g1_name, label=g1_label)
            G.add_node(gene_dict[g2], full_name=g2_name, label=g2_label)
            # g2 is the source and g1 the target.
            G.add_edge(gene_dict[g2], gene_dict[g1], weight=strength, strength=strength, method=method_or_tissue)

    return G



def build_thermomorphogenesis_network(gene_lst, gene_dict):
    '''Return the reference network from the thermomorphogenesis paper.

    Paper: "Molecular and genetic control of plant thermomorphogenesis",
    https://doi.org/10.1038/nplants.2015.190 (edges transcribed from the
    author's thermo.png figure).

    gene_lst  -- gene IDs; each is added as a node named gene_dict[ID].
    gene_dict -- maps gene ID to gene name.

    Edges are added by gene name (source -> target), all with weight=2.
    Returns an nx.DiGraph.
    '''
    network = nx.DiGraph()
    network.add_nodes_from(gene_dict[gene_id] for gene_id in gene_lst)

    # Each of these regulators points at every gene in `targets`.
    regulators = ('PIF4', 'HY5', 'FCA', 'BZR1', 'ARF6')
    targets = ('PAR1', 'PRE1', 'YUC8', 'TAA1', 'IAA4', 'SAUR21')
    links = [(regulator, target) for regulator in regulators for target in targets]

    # Remaining individual interactions from the figure.
    links += [
        ('PAR1', 'IBH1'), ('PRE1', 'IBH1'),
        ('PAR1', 'PIF4'), ('PRE1', 'PIF4'),
        ('IBH1', 'HBI1'),
        ('IAA4', 'ARF6'),
    ]

    for source, target in links:
        network.add_edge(source, target, weight=2)

    return network


def compute_total_edge_weight(edges, G):
    '''Return the sum of G[u][v]['weight'] over the given edges.

    edges -- iterable of edge tuples; only the first two items (u, v) are
             read, so both (u, v) and (u, v, data) tuples are accepted.
    G     -- graph (or nested mapping) indexable as G[u][v]['weight'].

    An empty `edges` gives 0.
    '''
    running = 0
    for edge in edges:
        source, target = edge[0], edge[1]
        running = running + G[source][target]['weight']
    return running


def draw_graph(G, fname, tau=2.5):
    '''Draw G on a circular layout and save the figure to fname.

    G     -- graph whose edges all carry a 'weight' attribute (KeyError otherwise).
    fname -- output path passed to plt.savefig.
    tau   -- edges with weight > tau are drawn solid; the rest dashed.

    Each node is labelled with its 'label' attribute when present, otherwise
    with the node itself.  The figure is closed afterwards so later plots do
    not overlap.
    '''
    pos = nx.circular_layout(G)
    elarge = [(u, v) for (u, v, d) in G.edges(data=True) if d['weight'] > tau]
    esmall = [(u, v) for (u, v, d) in G.edges(data=True) if d['weight'] <= tau]
    labels = {n: d.get('label', n) for (n, d) in G.nodes(data=True)}
    nx.draw_networkx_nodes(G, pos, alpha=0.1)
    nx.draw_networkx_edges(G, pos, edgelist=elarge, width=1, alpha=0.2)
    nx.draw_networkx_edges(G, pos, edgelist=esmall, width=1, alpha=0.1, edge_color='k', style='dashed')
    # The labels dict used to be computed and then ignored; pass it so the
    # 'label' attributes are actually drawn.
    nx.draw_networkx_labels(G, pos, labels=labels, font_size=8, font_color='k', font_family='sans-serif')
    plt.axis('off')
    plt.savefig(fname)
    plt.close()


def better_date(s):
    '''Turn an 8-character 'YYYYMMDD' string into 'YYYY-MM-DD'.

    Strings of any other length are returned unchanged.
    '''
    if len(s) != 8:
        return s
    year, month, day = s[:4], s[4:6], s[6:]
    return '-'.join((year, month, day))


def draw_graph2(G, fname, date):
    '''Draw G on a circular layout, titled with `date`, and save it to fname.

    Every edge is dashed, with line width sqrt(weight).  Edges must carry a
    'weight' attribute (KeyError otherwise).  `date` is passed through
    better_date, so 'YYYYMMDD' becomes 'YYYY-MM-DD'.  The figure is closed
    afterwards so later plots do not overlap.
    '''
    layout = nx.circular_layout(G)
    edge_records = list(G.edges(data=True))
    edge_list = [(src, dst) for src, dst, _ in edge_records]
    edge_widths = [math.sqrt(attrs['weight']) for _, _, attrs in edge_records]

    nx.draw_networkx_nodes(G, layout, alpha=0.05)
    nx.draw_networkx_edges(G, layout, edgelist=edge_list, width=edge_widths, alpha=0.2, edge_color='k', style='dashed')
    nx.draw_networkx_labels(G, layout, font_size=11, font_color='b', font_family='sans-serif')
    plt.axis('off')
    plt.title(better_date(date))
    plt.savefig(fname)
    plt.close()


def draw_graph3(G, fname):
    '''Draw every edge of G as a faint dashed line on a circular layout.

    The figure is saved to fname and then closed.  Closing matters: without
    it the next plot would be drawn on top of this one.
    '''
    layout = nx.circular_layout(G)
    edge_list = [(src, dst) for src, dst in G.edges()]

    nx.draw_networkx_nodes(G, layout, alpha=0.05)
    nx.draw_networkx_edges(G, layout, edgelist=edge_list, alpha=0.2, edge_color='k', style='dashed')
    nx.draw_networkx_labels(G, layout, font_size=11, font_color='b', font_family='sans-serif')
    plt.axis('off')
    plt.savefig(fname)
    plt.close()
    

## main

# Full list of Arabidopsis gene IDs considered for the thermomorphogenesis network.
thermomorphogenesis_genes = [
    'AT4G28720',
    'AT2G25930',
    'AT2G40080',
    'AT3G46640',
    'AT5G11260',
    'AT2G43010',
    'AT3G59060',
    'AT4G10180',
    'AT2G32950',
    'AT3G13550',
    'AT4G05420',
    'AT4G21100',
    'AT2G46340',
    'AT4G11110',
    'AT3G15354',
    'AT1G53090',
    'AT1G02340',
    'AT4G08920',
    'AT4G39950',
    'AT2G22330',
    'AT2G42870',
    'AT5G39860',
    'AT1G70560',
    'AT3G62980',
    'AT4G03190',
    'AT3G26810',
    'AT1G12820',
    'AT4G24390',
    'AT5G49980',
    'AT5G01830',
    'AT5G18010',
    'AT5G18020',
    'AT5G18050',
    'AT5G18060',
    'AT5G18080',
    'AT1G29440',
    'AT1G29510',
    'AT4G18710',
    'AT1G75080',
    'AT1G30330',
    'AT1G19850',
    'AT3G33520',
    'AT4G16280',
    'AT2G43060',
    'AT2G18300',
    'AT4G16780',
    'AT1G01060',
    'AT1G22770',
    'AT4G25420',
    'AT1G15550',
    'AT1G78440',
    'AT5G43700',
    'AT4G32280',
    'AT2G38120',
    'AT1G15580',
]


# Subset of gene IDs used to build the sub-networks.  Every entry ends with a
# comma so that uncommenting the optional genes at the bottom cannot silently
# concatenate two adjacent string literals into one bogus ID.
thermomorphogenesis_genes_small = [
    'AT2G43010',  # PIF4
    'AT5G11260',  # HY5
    'AT2G42870',  # PAR1
    'AT5G39860',  # PRE1
    'AT5G43700',  # IAA4
    'AT4G16280',  # FCA
    'AT2G43060',  # IBH1
    'AT2G18300',  # HBI1
    'AT4G28720',  # YUC8
    'AT1G70560',  # TAA1
    'AT1G30330',  # ARF6
    'AT1G19850',  # ARF5
    'AT5G01830',  # SAUR21
    'AT1G75080',  # BZR1
    # 'AT2G25930',  # ELF3
    # 'AT2G40080',  # ELF4
    # 'AT3G46640',  # LUX
]

# Gene ID -> gene name; the names are used as node identifiers in the graphs.
gene_dict = {
    'AT2G43010': 'PIF4',
    'AT5G11260': 'HY5',
    'AT2G42870': 'PAR1',
    'AT5G39860': 'PRE1',
    'AT5G43700': 'IAA4',
    'AT4G16280': 'FCA',
    'AT2G43060': 'IBH1',
    'AT2G18300': 'HBI1',
    'AT4G28720': 'YUC8',
    'AT1G70560': 'TAA1',
    'AT1G30330': 'ARF6',
    'AT1G19850': 'ARF5',
    'AT5G01830': 'SAUR21',
    'AT1G75080': 'BZR1',
    'AT2G25930': 'ELF3',
    'AT2G40080': 'ELF4',
    'AT3G46640': 'LUX',
}

print('Make sub graphs ...')

# Reference network G0: the edges reported in the thermomorphogenesis paper.
G0 = build_thermomorphogenesis_network(thermomorphogenesis_genes_small, gene_dict)
print('Number of edges in the paper thermomorphogenesis is %d' % (len(G0.edges())))
draw_graph2(G0, '../Data/temp/graph-%s.pdf' % ('20160101'), '')

# One sub-network per dated edge file; compressed copies are skipped.
graph_lst = []
graph_names = []
for fname in sorted(glob.glob('../Analysis/edges.txt.2020*')):
    if '.gz' in fname:
        continue
    print(fname)
    graph_names.append(fname)
    graph_lst.append(build_network_from_file(fname, thermomorphogenesis_genes_small, gene_dict))


for name, graph in zip(graph_names, graph_lst):
    date = name.split('.')[-1]  # suffix after the last '.', e.g. '20200101'
    print('Graph from %s' % (name))
    print('Number of edges is %d' % (len(graph.edges())))
    print('------------------------------------------------------------------------')
    draw_graph2(graph, '../Data/temp/graph-%s.pdf' % (date), date)

if not graph_lst:
    # Previously this fell through to a NameError on the undefined G below.
    sys.exit('No edge files matched ../Analysis/edges.txt.2020*; nothing to compare with G0.')

# The comparisons below use the last graph in sorted file-name order (the
# latest date, assuming YYYYMMDD suffixes).  nx.difference requires both graphs
# to have the same node set; G0 and G are both seeded from
# thermomorphogenesis_genes_small via gene_dict, so they do.
G = graph_lst[-1]

print('Compute network differences ...')

print('In G0 but not in G ...')
Gdiff1 = nx.difference(G0, G)
draw_graph3(Gdiff1, '../Data/temp/graph-in-G0-not-in-G.pdf')
print(Gdiff1.edges())
print(len(Gdiff1.edges()))

print('In G but not in G0 ...')
Gdiff2 = nx.difference(G, G0)
draw_graph3(Gdiff2, '../Data/temp/graph-in-G-not-in-G0.pdf')
print(Gdiff2.edges())
print(len(Gdiff2.edges()))

print('Compute network intersection ...')
print('In both ...')
Gcommon = nx.intersection(G0, G)
draw_graph3(Gcommon, '../Data/temp/graph-in-G0-and-in-G.pdf')
print(Gcommon.edges())
print(len(Gcommon.edges()))

print('Compute edit distance ...')
ged = nx.algorithms.similarity.graph_edit_distance(G0, G)  # exact search; can take a very long time
print('Edit distance is %4.0f' % (ged))