# Last modified on 9 Aug 2019 by Hui Lan

# Input locations, relative to the working directory the script is run from.
DATA.FILE      <- '../Data/history/expr/TPM.txt'
TARGET.TF.FILE <- '../Data/information/target_tf.txt'
AGINAME.FILE   <- '../Data/information/AGI-to-gene-names_v2.txt'

# Correlation cut-off: an overall correlation at or above it skips the pair,
# and a sample cluster is accepted only if |correlation| exceeds it.
r.tau        <- 0.60
# Smallest number of clusters the sample tree is cut into.
min.cluster  <- 3


# Abort before doing any work if an input file is missing; the files are
# checked in the order listed, so the first missing one is reported.
for (required.file in c(DATA.FILE, TARGET.TF.FILE, AGINAME.FILE)) {
    if (!file.exists(required.file)) {
        stop(sprintf('[correlation_per_group.R] Unable to find %s', required.file))
    }
}


# Load the expression table. Assumed layout (TODO confirm): a 'gene_id' column
# plus one column per sample.
cat(sprintf('Read %s\n', DATA.FILE))
X <- read.table(DATA.FILE, header=TRUE, check.names=FALSE)
all.id <- X$gene_id
X$gene_id <- NULL        # drop the id column so only sample columns remain ...
rownames(X) <- all.id    # ... and index the rows by gene id instead
all.genes <- rownames(X)

# Sample-count threshold for computing a correlation: 50, or
# ceil(sqrt(#samples)) when that is larger.
min.sample  <- max(50, ceiling(sqrt(ncol(X))))
# Largest number of clusters the sample tree is cut into: ceil(sqrt(#samples)),
# raised to at least min.cluster + 1 and capped at 55.
max.cluster <- min(55, max(min.cluster + 1, ceiling(ncol(X)^0.50)))


# Gene filter: keep genes whose TPM row sum exceeds the number of samples
# (mean TPM strictly above 1) and whose TPM standard deviation exceeds sd.tau.
# Genes whose sd is undefined are dropped.
rowsum.tau <- ncol(X)
sd.val     <- apply(X, 1, sd)
lambda     <- 0.3  # only referenced by the disabled data-driven threshold below
# Disabled alternative: blend of the median and 3rd quartile of sd.val.
#sd.tau  <- lambda * summary(sd.val)[3] + (1-lambda) * summary(sd.val)[5] # genes whose gene expression varies least are to be filtered
sd.tau     <- 1
index.row  <- !is.na(sd.val) & sd.val > sd.tau & rowSums(X) > rowsum.tau

# Keep the surviving genes on a log(TPM + 1) scale.
X <- log(X[index.row, ] + 1.0)

# Standardize every row of X (z-score): subtract the row mean, then divide by
# the row standard deviation, so each row ends up with mean 0 and sd 1.
# X may be a data.frame or a matrix; the result keeps X's class and dimnames.
# Rows with zero standard deviation produce non-finite values.
normalize <- function(X) {
    row.mean <- apply(X, 1, mean)
    row.sd   <- apply(X, 1, sd)
    # A vector of length nrow(X) recycles down each column, so every element
    # is paired with the statistic of its own row.
    (X - row.mean) / row.sd
}

X2 <- normalize(X)  # row-standardized log expression, used below to cluster samples

cat(sprintf('Read %s\n', AGINAME.FILE))
# quote='' disables quote processing, so gene names containing ' or " (e.g.
# 5'-...) no longer break the whitespace-separated parse.
agi        <- read.table(AGINAME.FILE, stringsAsFactors=FALSE, quote='')

cat(sprintf('Read %s\n', TARGET.TF.FILE))
# Tab-separated, no header. Column 1 is the target, column 2 the TF, column 3
# is copied to the output as the condition field.
target.tf  <- read.table(TARGET.TF.FILE, header=FALSE, check.names=FALSE, sep='\t', stringsAsFactors=FALSE)
total.pair <- nrow(target.tf)

cat(sprintf('min.cluster=%d, max.cluster=%d, min.sample=%d, r.tau=%4.2f\n', min.cluster, max.cluster, min.sample, r.tau))
cat('Hclust ...\n')
# Average-linkage tree over samples (columns of X2), Euclidean distance.
clusters <- hclust(dist(t(X2)), method = 'average')
cat('Go through pairs..\n')
# Output name carries a timestamp so repeated runs do not overwrite each other.
output.file <- paste('../Data/history/edges/one_target/edges.txt', 'group', format(Sys.time(), '%b.%d.%Y.%H%M%S'), sep='.')
f <- file(output.file, 'w')

# Cluster membership depends only on the sample tree, not on the gene pair, so
# cut the tree once per searched cluster count instead of once per pair.
cluster.cuts <- lapply(seq(min.cluster, max.cluster, 2),
                       function(cn) cutree(clusters, cn))

# Search every cut for the largest sample cluster in which the TF and target
# are correlated beyond r.tau in absolute value (negative correlations count).
#
# x, y       : expression of the TF and the target over all samples, in the
#              column order of X (the order the sample tree was built in).
# cuts       : list of cutree() membership vectors, named by sample.
# min.sample : a cluster must contain more than this many samples.
# r.tau      : |correlation| threshold.
# Returns list(r, n, samples): the winning cluster's correlation, size and
# sample names. n is 0 when no cluster qualifies. On equal size, the first
# cluster found wins.
find.best.group <- function(x, y, cuts, min.sample, r.tau) {
    best <- list(r = 0.0, n = 0, samples = NULL)
    for (membership in cuts) {
        sample.names <- names(membership)
        for (k in unique(membership)) {
            in.group <- membership == k
            xg <- x[in.group]
            yg <- y[in.group]
            n  <- length(xg)
            # Both x and y must vary inside the cluster. isTRUE() turns an NA sd
            # into "does not vary" instead of an if() error.
            varies <- n > min.sample && isTRUE(sd(xg) > 0.1) && isTRUE(sd(yg) > 0.1)
            r <- if (varies) cor(xg, yg) else 0.0
            if (n > min.sample && isTRUE(abs(r) > r.tau) && n > best$n) {
                best <- list(r = r, n = n, samples = sample.names[in.group])
            }
        }
    }
    best
}

kept.genes <- rownames(X)  # genes that survived filtering (a subset of all.genes)

# finally= guarantees the output connection is closed even if a pair errors.
tryCatch({
    for (i in seq_len(total.pair)) {  # seq_len: no iterations when there are no pairs

        gene.tf     <- as.vector(target.tf[i, 2])
        gene.target <- as.vector(target.tf[i, 1])
        # Both genes must have survived filtering. kept.genes is a subset of
        # all.genes, so this also guarantees both are in all.genes.
        if (!(gene.tf %in% kept.genes) || !(gene.target %in% kept.genes)) {
            next
        }

        # Expression over all samples (log scale).
        x <- as.vector(t(X[gene.tf, ]))
        y <- as.vector(t(X[gene.target, ]))

        # If too few samples, or the correlation over all samples is already
        # good, don't look for group correlation. Near-zero values are ignored here.
        small <- x < 0.01 | y < 0.01
        x.1 <- x[!small]
        y.1 <- y[!small]
        if (length(x.1) < min.sample) {
            next
        }
        # cor() is NA when either vector is constant. Treat that as "not good"
        # instead of letting if() fail on NA.
        if (isTRUE(cor(x.1, y.1) >= r.tau)) {
            next
        }

        # NOTE(review): the overall check above drops near-zero samples, but the
        # per-cluster search uses every sample. Confirm this is intended.
        best <- find.best.group(x, y, cluster.cuts, min.sample, r.tau)

        # save results
        if (best$n > 0) {
            # match() yields exactly one name: the first entry, or NA for an
            # unknown id. Exactly one line is therefore written per edge. A
            # zero-length name would make sprintf() return character(0) and
            # silently drop the edge.
            name1 <- agi$V2[match(gene.tf, agi$V1)]
            name2 <- agi$V2[match(gene.target, agi$V1)]
            curr.date    <- gsub('-', '', Sys.Date())
            loglik       <- '-991.0'  # fixed value; presumably a placeholder since no model is fit here
            num.sub.cond <- length(best$samples)
            cond         <- as.vector(target.tf[i, 3])
            result <- sprintf('%s %s\t%s %s\t%4.2f\t%s\t%s\t%s\t%s\t%s\t%4.2f\t%s\n', gene.target, name2, gene.tf, name1, best$r, 'mix', num.sub.cond, cond, loglik, curr.date, best$r, 'hclust.group')
            cat(result, file=f, sep='')
        }
    }
}, finally = close(f))