# Last modified 13 August 2019

TARGET_TF_FILE     <- "../Data/information/target_tf.txt"
DATA_FILE          <- "../Data/history/expr/TPM.txt" # A TPM table
AGINAME_FILE       <- "../Data/information/AGI-to-gene-names_v2.txt"
# Pearson r threshold. Pairs with |r| >= CORR_THRESHOLD over all kept samples
# are skipped; a mixture component is reported when its r >= CORR_THRESHOLD
# (positive edge) or r <= -CORR_THRESHOLD (negative edge).
CORR_THRESHOLD     <- 0.5
# Minimum number of samples a mixture component must contain to be reported.
MIN_SIZE           <- 100


# Make sure we have required files; stop at the first one that is missing.
# A single loop keeps the error prefix identical for every file.
for (required.file in c(TARGET_TF_FILE, DATA_FILE, AGINAME_FILE)) {
    if (!file.exists(required.file)) {
        stop(sprintf('[create_edges_k2.R] Unable to find %s', required.file))
    }
}


####### Read data #########################################
# Expression table: a `gene_id` column plus one column per sample. gene_id
# becomes the row names so rows can be looked up by AGI id below; the
# remaining column names are kept as rna.sample.id.
X <- read.table(DATA_FILE, header=TRUE, check.names=FALSE, stringsAsFactors=FALSE)
if (!('gene_id' %in% colnames(X))) {
    stop(sprintf('[create_edges_k2.R] %s has no gene_id column', DATA_FILE))
}
gene_id        <- X[['gene_id']]
X[['gene_id']] <- NULL
row.names(X)   <- gene_id
X              <- as.matrix(X)
rna.sample.id  <- colnames(X)

# Target/TF pairs: tab-separated, no header. Column 1 = target gene,
# column 2 = TF, column 3 = condition label (copied verbatim to the output).
target_tf <- as.matrix(read.table(TARGET_TF_FILE, sep='\t', header=FALSE,
                                  stringsAsFactors=FALSE))
if (ncol(target_tf) < 3) {
    stop(sprintf('[create_edges_k2.R] %s must have 3 columns (target, TF, condition)',
                 TARGET_TF_FILE))
}
targets    <- target_tf[, 1]
tfs        <- target_tf[, 2]
conditions <- target_tf[, 3]

# AGI id (V1) -> gene name (V2), whitespace-separated.
# quote='' treats quote characters as ordinary text, so a gene name containing
# an apostrophe cannot swallow the following lines. For quote-free files the
# result is the same as reading with default quoting.
# NOTE(review): names containing spaces would still split into extra fields.
agi <- read.table(AGINAME_FILE, stringsAsFactors=FALSE, quote='')
#######################################################

# mixtools provides regmixEM(), the mixture-of-regressions EM fit used below.
library(mixtools)
options(max.print = 999999999)

# Output file: '<dir>/edges.txt.mixtools.<timestamp>', where the timestamp is
# formatted with '%b.%d.%Y.%H%M%S' so each run writes a fresh file.
run.stamp   <- format(Sys.time(), '%b.%d.%Y.%H%M%S')
output.file <- paste0('../Data/history/edges/one_target/edges.txt', '.mixtools.', run.stamp)
f <- file(output.file, open = 'w')

# Return the gene name that `agi_table` (V1 = AGI id, V2 = name) lists for `id`,
# falling back to the id itself. Exactly one string is always returned: a
# zero-length name would make sprintf() return character(0) and silently drop
# the edge, and several names would duplicate it.
lookup_gene_name <- function(id, agi_table) {
    name <- agi_table$V2[match(id, agi_table$V1)]
    if (is.na(name)) id else name
}

# Format one mixture-model edge as a newline-terminated, tab-separated line:
# "target name", "tf name", r, 'mix', component size, condition, log-likelihood,
# date, r (repeated), 'mixtool'.
format_mix_edge <- function(target, target_name, tf, tf_name, r, n, cond, loglik, date) {
    sprintf('%s %s\t%s %s\t%4.2f\t%s\t%d\t%s\t%4.2f\t%s\t%4.2f\t%s\n',
            target, target_name, tf, tf_name, r, 'mix', n, cond, loglik, date, r, 'mixtool')
}

# For every (target, TF) pair, fit a 2-component mixture of regressions and
# write an edge for the component(s) whose within-component correlation
# passes CORR_THRESHOLD and whose size passes MIN_SIZE.
for (i in seq_along(targets)) {
    curr.date <- gsub('-', '', Sys.Date())
    id1 <- tfs[i]      # TF
    id2 <- targets[i]  # target gene
    if (!(id1 %in% gene_id) || !(id2 %in% gene_id)) {
        next
    }

    name1 <- lookup_gene_name(id1, agi)
    name2 <- lookup_gene_name(id2, agi)
    cond  <- conditions[i]

    # log(TPM + 1), then drop samples where either gene is (near) unexpressed.
    # x and y keep the sample names as their names.
    x <- log(X[id1, ] + 1)
    y <- log(X[id2, ] + 1)
    low <- x < 0.01 | y < 0.01
    x <- x[!low]
    y <- y[!low]
    if (length(x) < 3 || sd(x) < 0.1 || sd(y) < 0.1) {
        next
    }

    # Pairs already correlated across all kept samples are not written by this
    # script; only sub-population (mixture) edges are recorded here.
    if (abs(cor(x, y)) >= CORR_THRESHOLD) {
        next
    }

    k <- 2
    # Assign the tryCatch() result itself: on a fitting error em.out becomes NULL
    # rather than keeping the previous pair's fit (or being undefined on the
    # first pair). NOTE(review): regmixEM draws random starting values when none
    # are supplied, so results may differ between runs unless a seed is set.
    em.out <- tryCatch(regmixEM(y, x, maxit=100, epsilon=1e-03, k=k),
                       error=function(e) NULL)
    if (is.null(em.out)) {
        next
    }

    # Hard-assign each sample to its most probable component once, so the
    # components partition the samples (max.col's default tie-break is random).
    # NB: `members` below indexes the FILTERED samples, i.e. names(x), not
    # rna.sample.id.
    cluster <- max.col(em.out$posterior, ties.method='first')

    best_pos <- NULL  # list(r, n) of the best positively correlated component
    best_neg <- NULL  # list(r, n) of the best negatively correlated component
    for (j in seq_len(k)) {
        members <- which(cluster == j)
        size <- length(members)
        # cor() is undefined for fewer than 2 points
        if (size < max(MIN_SIZE, 2)) {
            next
        }
        r <- cor(em.out$x[members, 2], em.out$y[members])  # column 1 is the intercept
        if (is.na(r)) {
            next
        }
        # A later component replaces the current best only if it is both
        # more extreme in r and strictly larger.
        if (r >= CORR_THRESHOLD &&
            (is.null(best_pos) || (r > best_pos$r && size > best_pos$n))) {
            best_pos <- list(r=r, n=size)
        }
        if (r <= -CORR_THRESHOLD &&
            (is.null(best_neg) || (r < best_neg$r && size > best_neg$n))) {
            best_neg <- list(r=r, n=size)
        }
    }

    # Positive edge first, then negative, as before.
    for (best in list(best_pos, best_neg)) {
        if (!is.null(best)) {
            cat(format_mix_edge(id2, name2, id1, name1, best$r, best$n, cond,
                                em.out$loglik, curr.date),
                file=f, sep='')
        }
    }
}

close(f)