# Make tissue-specific networks: write the TF-only edge skeleton and per-tissue degree tables.

import os, sys
from geneid2name import make_gene_name_AGI_map_dict

def get_tfs(fname_lst):
    """Count how often each TF occurs in the given edge files.

    Every non-blank line is split on tabs; the TF identifier is the first
    whitespace-separated token of the second tab-separated field.

    Args:
        fname_lst: iterable of paths to tab-separated edge files.

    Returns:
        dict mapping TF identifier -> number of lines (over all files) in
        which it appears as the TF.

    Raises:
        IndexError: if a non-blank line has fewer than two tab-separated
            fields or an empty TF field.
    """
    d = {}
    for fname in fname_lst:
        # The context manager closes the file even when a malformed line raises.
        with open(fname) as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Blank (e.g. trailing) lines carry no edge; skip them
                    # instead of failing on the missing second field.
                    continue
                tf = line.split('\t')[1].split()[0]
                d[tf] = d.get(tf, 0) + 1
    return d

def get_tissue_from_fname(fname):
    """Return the first known tissue name that occurs as a substring of fname.

    Returns 'unknown' when none of the known tissue names is found.
    """
    # Candidates are tried in this order and the first hit wins, so
    # 'seedling' is tested before 'seed' and 'meristem' before 'stem'.
    known_tissues = (
        'seedling', 'meristem', 'flower', 'aerial', 'shoot',
        'seed', 'leaf', 'root', 'stem',
    )
    return next(
        (tissue for tissue in known_tissues if tissue in fname),
        'unknown',
    )

def get_edges_consisting_of_tfs(fname_lst, tf_dict):
    """Collect, per tissue, the edges whose target and TF are both in tf_dict.

    The tissue of each file is taken from get_tissue_from_fname(fname).
    Each non-blank line is split on tabs: field 0 is the target, field 1 the
    TF and field 2 the score; the identifier of a node is the first
    whitespace-separated token of its field.

    Args:
        fname_lst: iterable of paths to tab-separated edge files.
        tf_dict: container of TF identifiers (only membership is used).

    Returns:
        dict of the form {tissue: {'<target>_<tf>': [(target_field,
        tf_field, score), ...]}} where the fields are the untouched
        tab-separated columns and score is a float.

    Raises:
        IndexError: if a non-blank line has fewer than three fields.
        ValueError: if the score field is not a valid float.
    """
    d = {}
    for fname in fname_lst:
        kt = get_tissue_from_fname(fname)
        # setdefault keeps edges already gathered for this tissue, so several
        # files resolving to the same tissue accumulate instead of the last
        # file silently replacing the earlier ones.
        tissue_edges = d.setdefault(kt, {})
        with open(fname) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # blank lines carry no edge
                lst = line.split('\t')
                target = lst[0].split()[0]
                tf = lst[1].split()[0]
                score = float(lst[2])
                if tf in tf_dict and target in tf_dict:
                    k = target + '_' + tf
                    tissue_edges.setdefault(k, []).append(
                        (lst[0], lst[1], score))
    return d

def get_degree(fname_lst, tf_dict):
    """Count per-tissue degrees of every node found in the edge files.

    The tissue of each file is taken from get_tissue_from_fname(fname).
    For each non-blank line the target is the first token of tab field 0
    and the TF the first token of tab field 1.  Every line is counted:
    the TF's out-degree and the target's in-degree each grow by one, and
    both nodes grow by one in the combined table.

    Args:
        fname_lst: iterable of paths to tab-separated edge files.
        tf_dict: accepted for interface compatibility but not consulted;
            the original TF-only filter was short-circuited with
            ``if True or ...``, so all edges have always been counted.

    Returns:
        Tuple (d_all, d_out, d_in); each is {tissue: {node_id: count}}.

    Raises:
        IndexError: if a non-blank line has fewer than two fields.
    """
    d_out = {}
    d_in = {}
    d_all = {}
    for fname in fname_lst:
        kt = get_tissue_from_fname(fname)
        # setdefault lets several files of the same tissue accumulate rather
        # than resetting the counts gathered from an earlier file.
        out_deg = d_out.setdefault(kt, {})
        in_deg = d_in.setdefault(kt, {})
        all_deg = d_all.setdefault(kt, {})
        with open(fname) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # blank lines carry no edge
                lst = line.split('\t')
                target = lst[0].split()[0]
                tf = lst[1].split()[0]
                out_deg[tf] = out_deg.get(tf, 0) + 1
                in_deg[target] = in_deg.get(target, 0) + 1
                all_deg[target] = all_deg.get(target, 0) + 1
                all_deg[tf] = all_deg.get(tf, 0) + 1
    return d_all, d_out, d_in
    

def _node_label(field):
    """Return 'ID' or 'ID_NAME' for one node field of the form 'ID NAME;...'."""
    tokens = field.split()
    gene_id = tokens[0]
    # A field without a name token is treated like the '.' placeholder
    # instead of raising IndexError.
    name = tokens[1].split(';')[0] if len(tokens) > 1 else '.'
    if name == '.':
        return gene_id
    return gene_id + '_' + name


def simplify(s):
    """Shorten the two node fields of a tab-separated edge line.

    Args:
        s: string 'target_field<TAB>tf_field<TAB>score', where each node
            field is an identifier optionally followed by whitespace and a
            ';'-separated name list.

    Returns:
        'A<TAB>B<TAB>score' where each of A and B is the identifier alone
        when the first name is '.' (or absent), otherwise
        '<identifier>_<first name>'.  The score text is passed through.

    Raises:
        IndexError: if s has fewer than three tab-separated fields or a
            node field is empty.
    """
    lst = s.split('\t')
    return '%s\t%s\t%s' % (_node_label(lst[0]), _node_label(lst[1]), lst[2])
    
# main
# Path handed to make_gene_name_AGI_map_dict (imported from geneid2name);
# the result is used below as a tf -> name lookup.
GENE_ID_TO_GENE_NAME = '../Data/information/AGI-to-gene-names_v2.txt'
agi2name_dict = make_gene_name_AGI_map_dict(GENE_ID_TO_GENE_NAME)

# One edge file per tissue; the paths differ only in the tissue name.
_EDGE_FILE_TEMPLATE = (
    '/home/hui/network/v03/Data/history/edges/many_targets/'
    'edges.txt.simple.correlation.%s.txt.20170629_203729'
)
edge_file_lst = [
    _EDGE_FILE_TEMPLATE % tissue
    for tissue in (
        'seedling', 'meristem', 'flower', 'aerial', 'shoot',
        'seed', 'leaf', 'root', 'stem',
    )
]

# Occurrence count of every TF seen in any of the edge files.
tf_dict = get_tfs(edge_file_lst)

# Write the TF-only skeleton: for every tissue, one line per target/TF pair.
print('Total number of TFs: %d' % (len(tf_dict)))
# Compute the edges before opening the output so a failure here does not
# leave behind an empty, unclosed result file.
d0 = get_edges_consisting_of_tfs(edge_file_lst, tf_dict)
with open('result.skeleton.txt', 'w') as f:
    for kt in d0:  # kt is tissue
        d = d0[kt]  # {'target_tf': [(target_field, tf_field, score), ...]}
        f.write('##TF skeleton size in %s: %d.\n' % (kt, len(d)))
        for k in d:
            # Keep the edge with the largest absolute score.  The previous
            # loop reset its running maximum on every iteration (and stored
            # the signed score), so it always emitted the last edge.
            best = max(d[k], key=lambda edge: abs(edge[2]))
            s = '%s\t%s\t%4.2f' % (best[0], best[1], best[2])
            f.write(simplify(s) + '\n')

# for each TF, get its out-degree and in-degree in each tissue
def _write_degree_table(out_fname, degree_by_tissue, tf_names, name_map):
    """Write one tab-separated table of per-tissue degrees.

    The header is 'TF' followed by the tissues of degree_by_tissue.  Each
    following row starts with '<tf> <name>' (name is '.' when name_map has
    no entry for the TF or maps it to itself) and then holds one count per
    tissue, 0 when the TF is absent from that tissue.

    Args:
        out_fname: path of the file to (over)write.
        degree_by_tissue: {tissue: {node_id: count}}.
        tf_names: iterable of TF identifiers, one row each.
        name_map: mapping from TF identifier to display name.
    """
    tissues = list(degree_by_tissue)
    with open(out_fname, 'w') as out:
        out.write('%s\n' % ('\t'.join(['TF'] + tissues)))
        for tf in tf_names:
            name = '.'
            if tf in name_map and name_map[tf] != tf:
                name = name_map[tf]
            cells = [tf + ' ' + name]
            for tissue in tissues:
                cells.append('%d' % (degree_by_tissue[tissue].get(tf, 0)))
            out.write('\t'.join(cells) + '\n')


dd_all, dd_out, dd_in = get_degree(edge_file_lst, tf_dict)
_write_degree_table('result.out.txt', dd_out, tf_dict, agi2name_dict)
_write_degree_table('result.in.txt', dd_in, tf_dict, agi2name_dict)
_write_degree_table('result.all.txt', dd_all, tf_dict, agi2name_dict)