# Usage: python slice_binding_to_JSON.py parameter_for_net.txt
import sys, os, operator, itertools
import numpy as np
import json

# Output directory passed to make_json_file by the main script: one
# <row name>.json file is written there per matrix row.
# NOTE(review): relative path, so it is resolved against the current working directory.
JSON_DIR = '../Data/history/bind/json2' # contains json for all genes

# Line prefix marking global "KEY=value" parameter lines (see make_global_param_dict).
GLB_PARAM_SYMBOL    = '%%'
# Line prefix marking the start of a new data record and its ID (see read_info_data).
DATA_SYMBOL         = '@'

def read_matrix_data(fname):
    ''' 
    Read a whitespace-separated matrix file.

    fname - a file, first line is head, first column is row name.
            Blank lines are ignored.

    Returns a dict with keys:
      'xy'        {row: {col: value}}          first row (gene), then column (condition)
      'yx'        {col: {row: value}}          first column, then row
      'xx'        {row: [values in column order]}
      'yy'        {col: [values in row order]}
      'yy_sorted' {col: [values sorted in descending order]}
      'nrow'      number of data rows read
      'ncol'      number of columns in the header (excluding the row-name column)
      'rowid'     row names in file order
      'colid'     column names in header order

    Prints a message and exits with status 1 if a data row has a different
    number of values than the header has columns.
    '''
    with open(fname) as f:
        # Drop blank lines (e.g. trailing newlines) so they are not parsed as rows.
        lines = [line.strip() for line in f if line.strip()]

    rowid = []
    d = {}   # {gene1:{cond1:val1, cond2:val2, ...}, gene2: {...}, ...}
    d2 = {}  # {cond1:{gene1:val1, gene2:val2, ...}, cond2: {...}, ...}
    d3 = {}  # {gene1: [], gene2: [], ...}
    d4 = {}  # {cond1:[], cond2:[], ...}

    colid = lines[0].split()[1:] if lines else []
    for c in colid:
        d2[c] = {}
        d4[c] = []

    for line in lines[1:]:
        lst = line.split()
        g = lst[0]
        levels = lst[1:]
        if len(levels) != len(colid):
            print('Incomplete columns at row %s' % (g))
            sys.exit(1)
        rowid.append(g)
        values = [float(x) for x in levels]  # parse each cell once
        d[g] = dict(zip(colid, values))
        d3[g] = values
        for c, v in zip(colid, values):
            d2[c][g] = v
            d4[c].append(v)

    d_return = {}
    d_return['xy'] = d  # first gene, then condition
    d_return['yx'] = d2 # first condition, then gene
    d_return['xx'] = d3 # each item is an array of values, i.e., each item is a row
    d_return['yy'] = d4 # each item is an array of values, i.e., each item is a column
    d_return['nrow'] = len(rowid)  # count of data rows (header excluded)
    d_return['ncol'] = len(colid)
    d_return['rowid'] = rowid
    d_return['colid'] = colid
    d_return['yy_sorted'] = {k: sorted(v, reverse=True) for k, v in d4.items()}

    return d_return

# Read parameters


def get_key_value(s):
    '''
    Split a "KEY=value" string at its first '=' and return (key, value),
    both whitespace-stripped. The value may itself contain '='.
    Raises IndexError if s contains no '='.
    '''
    lst = s.split('=', 1)  # maxsplit=1 keeps any '=' inside the value intact
    k, v = lst[0], lst[1]
    return (k.strip(), v.strip())


def get_value(s, delimit):
    '''
    Return the whitespace-stripped text after the first occurrence of
    `delimit` in s. Later occurrences are kept as part of the value
    (e.g. a URL or path containing ':').
    Raises IndexError if `delimit` does not occur in s.
    '''
    lst = s.split(delimit, 1)
    return lst[1].strip()

def read_info_data(fname):
    ''' Read chip-seq data information.

    The file is a sequence of records. A line starting with DATA_SYMBOL opens
    a new record; its ID is the text after the last DATA_SYMBOL character on
    that line. Subsequent "FIELD: value" lines fill in the current record.
    Blank lines, lines starting with '#' or '%', and unrecognised lines are
    ignored.

    fname - path of the information file.

    Returns {'ID_LIST': [id1, id2, ...], id1: {...}, id2: {...}, ...} where
    each record has the keys PROTEIN_ID, PROTEN_NAME, DATA_NAME, DATA_FORMAT,
    DESCRIPTION, LOCATION and NOTE (empty string when not given).

    Prints a message and exits with status 1 if fname does not exist, an ID
    is repeated, or a field line appears before the first ID line.
    '''
    # 'PROTEN_NAME' (sic) is the spelling expected in the input file and used
    # as the output key; it is kept for compatibility with existing files.
    fields = ('PROTEIN_ID', 'PROTEN_NAME', 'DATA_NAME', 'DATA_FORMAT',
              'DESCRIPTION', 'LOCATION', 'NOTE')

    if not os.path.exists(fname):
        print('%s not exists.' % (fname) )
        sys.exit(1)

    d = {'ID_LIST':[]}
    s = None  # ID of the record currently being filled
    with open(fname) as f:
        for line in f:
            line = line.strip()
            if line == '' or line.startswith('#') or line.startswith('%'):
                continue
            if line.startswith(DATA_SYMBOL):
                s = line[line.rfind(DATA_SYMBOL[-1])+1:].strip()
                if s in d:
                    print('ID %s duplicate' % (s))
                    sys.exit(1)
                d[s] = dict.fromkeys(fields, '')
                d['ID_LIST'].append(s)
                continue
            for field in fields:
                prefix = field + ':'
                if line.startswith(prefix):
                    if s is None:
                        print('%s appears before any %s line' % (prefix, DATA_SYMBOL))
                        sys.exit(1)
                    # Keep everything after the first ':' so values such as
                    # URLs or paths that contain ':' are not truncated.
                    d[s][field] = line[len(prefix):].strip()
                    break

    return d


def make_global_param_dict(fname):
    '''
    Collect global parameters from a parameter file.

    Every line starting with GLB_PARAM_SYMBOL holds one or more TAB-separated
    "KEY=value" items; all of them are merged into one dict (a later
    occurrence of a key overwrites an earlier one). Other lines are ignored.

    fname - path of the parameter file.
    Returns {KEY: value} with keys and values whitespace-stripped.
    '''
    d = {}
    with open(fname) as f:
        for line in f:
            line = line.strip()
            if not line.startswith(GLB_PARAM_SYMBOL):
                continue
            # Strip only the leading marker (and any extra marker characters
            # right after it) so a '%' inside a value is preserved.
            s = line[len(GLB_PARAM_SYMBOL):].lstrip(GLB_PARAM_SYMBOL[-1])
            for x in s.split('\t'):  # separate items by TAB
                if x.strip():  # skip empty / whitespace-only items
                    k, v = get_key_value(x)
                    d[k] = v
    return d


def make_json_file(bind_dict, bind_info_dict, dir_name, glb_param_dict):
    '''
    Write one JSON file per row (gene) of a binding matrix.

    bind_dict      - dict as returned by read_matrix_data; uses 'xy', 'rowid'
                     and 'colid'.
    bind_info_dict - dict as returned by read_info_data; every matrix column
                     needs an entry with a 'DATA_FORMAT' field.
    dir_name       - output directory, created (with parents) if missing.
    glb_param_dict - global parameters; not used here, kept for the
                     caller's interface.

    Each file dir_name/<row name>.json maps column name ->
    {'v': value, 't': type}, with keys in sorted order. The type is the
    upper-cased DATA_FORMAT, with 'NARROWPEAK' shortened to 'NP'.
    Prints a message and exits with status 1 if a column has no entry in
    bind_info_dict.
    '''
    os.makedirs(dir_name, exist_ok=True)  # no-op when the directory exists

    # The type tag depends only on the column, so compute it once per column
    # instead of once per (row, column) pair.
    type_of = {}
    for col in bind_dict['colid']:
        if col not in bind_info_dict:
            print('No information for data %s' % (col))
            sys.exit(1)
        data_type = bind_info_dict[col]['DATA_FORMAT'].upper()
        type_of[col] = 'NP' if data_type == 'NARROWPEAK' else data_type  # short name for narrowPeak

    d = bind_dict['xy']
    for g in bind_dict['rowid']:
        values = d[g]
        # Insertion order of the dict (sorted keys) is the key order in the JSON.
        record = {k: {'v': values[k], 't': type_of[k]} for k in sorted(values)}
        filename = os.path.join(dir_name, g + '.json')
        with open(filename, 'w') as f:
            json.dump(record, f)
    

### main
def main(argv=None):
    '''
    Convert the binding matrix named in a parameter file into per-row JSON files.

    argv - command-line arguments (defaults to sys.argv); argv[1] is the
           parameter file, which must define BINDING_MATRIX and BINDING_INFO
           on GLB_PARAM_SYMBOL lines.
    Output is written to JSON_DIR. Prints a message and exits with status 1
    on missing arguments or parameters.
    '''
    if argv is None:
        argv = sys.argv
    if len(argv) < 2:
        print('Usage: python slice_binding_to_JSON.py parameter_for_net.txt')
        sys.exit(1)
    param_file = argv[1] # a single parameter file
    glb_param_dict = make_global_param_dict(param_file)
    for key in ('BINDING_MATRIX', 'BINDING_INFO'):
        if key not in glb_param_dict:
            print('Parameter %s missing in %s' % (key, param_file))
            sys.exit(1)
    binding_dict = read_matrix_data(glb_param_dict['BINDING_MATRIX'])
    bind_info_dict = read_info_data(glb_param_dict['BINDING_INFO'])
    make_json_file(binding_dict, bind_info_dict, JSON_DIR, glb_param_dict)


if __name__ == '__main__':
    main()