# Start from an empty global environment so that no objects left over from an
# earlier session are picked up by the code below.
# NOTE(review): this also wipes the workspace of anyone who source()s the
# script interactively; running it in a fresh R session would make this line
# unnecessary.
rm(list = ls())
library(glmnet)
library(xrnet)
library(doParallel)
library(dplyr)
library(xtune)
library(randomForest)
library(fwelnet)
library(ggplot2)
library(pROC)

# All paths below are relative: the working directory is expected to be
# ".../JDS2107-007/". Show it so a wrong location is easy to spot.
getwd()

# Load the cleaned METABRIC data. The rest of the script uses clinical.train,
# clinical.test, X.train, X.test and metagene_ext without defining them, so
# they are presumably provided by this file.
load("RDA/Data/METABRIC/metabric_cleaned.RData")

# Survival horizons in years (X-year survival). Main text uses 5-year survival.
tmonth <- c(5, 7.5, 10)

# Index into tmonth selecting the horizon analysed in this run:
# 1 = 5-year OS, 2 = 7.5-year survival, 3 = 10-year survival.
sim <- 1

# Restrict a clinical table to the analysis cohort.
#
# The same selection has to be applied to the training and the test data, so
# it lives in one helper instead of two copies that could drift apart.
#
# Arguments:
#   clinical  data frame with columns ftime, fstatus, er_neg and her2_pos.
#   horizon   survival horizon, on the same scale as ftime / 365.
# Returns:
#   `clinical` with the extra column surv_in_month (= ftime / 365) and only
#   the rows that
#     * were followed beyond the horizon, or had fstatus == 1 at or before it
#       (rows that end before the horizon with any other status are dropped),
#     * and have er_neg == 0 and her2_pos == 0.
# Note: despite its name, surv_in_month is compared against horizons that the
# script describes as years (presumably ftime is in days -- TODO confirm).
select_cohort <- function(clinical, horizon) {
  clinical %>%
    mutate(surv_in_month = ftime / 365) %>%
    filter(
      surv_in_month > horizon | (surv_in_month <= horizon & fstatus == 1)
    ) %>%
    filter(
      er_neg == 0 & her2_pos == 0
    )
}

clinical.train <- select_cohort(clinical.train, tmonth[sim])

clinical.test <- select_cohort(clinical.test, tmonth[sim])


#- Build the external (meta-feature) matrix
# Label the four columns of metagene_ext, then divide every column by the
# number of its non-zero entries.
external_mat <- metagene_ext
colnames(external_mat) <- c("CIN", "MES", "LYM", "FGD3-SUSD3")
external_mat <- apply(
  external_mat,
  MARGIN = 2,
  FUN = function(weights) weights / sum(weights != 0)
)

#- Align the design matrices with the filtered clinical tables
# Rows are selected (and ordered) by METABRIC_ID; this assumes the row names
# of X.train / X.test are METABRIC IDs -- TODO confirm.
train_ids <- clinical.train$METABRIC_ID
test_ids <- clinical.test$METABRIC_ID
X.train <- X.train[train_ids, ]
X.test <- X.test[test_ids, ]

# Binary outcome is the raw event indicator.
# NOTE(review): rows followed beyond the horizon keep whatever fstatus they
# have, so an event after the horizon is still labelled 1 here -- confirm that
# this is the intended definition of X-year survival.
y.train <- clinical.train$fstatus
y.test <- clinical.test$fstatus

set.seed(2021)

# Create folds
# Fold labels (1..10) are drawn separately within the y == 1 and the y == 0
# rows. Labels are sampled with replacement, so fold sizes are random rather
# than exactly equal. The y == 1 draw must stay first to keep the random
# number stream, and therefore the folds, unchanged.
idx_event <- which(y.train == 1)
idx_nonevent <- which(y.train == 0)

fold1 <- sample(seq_len(10), length(idx_event), replace = TRUE)
fold0 <- sample(seq_len(10), length(idx_nonevent), replace = TRUE)

# Rows whose outcome is neither 0 nor 1 (if any) keep the initial value 0.
foldId <- numeric(nrow(X.train))
foldId[idx_event] <- fold1
foldId[idx_nonevent] <- fold0
# Unstratified alternative:
#foldId <- sample(1:10, size = dim(X.train)[1], replace = TRUE)



#----- Run Ridge-Ridge (xrnet)
# Both penalties below use penalty_type = 0, which xrnet documents as ridge,
# so the progress message says "Ridge-Ridge" (it used to say "Lasso-Ridge",
# which did not match the model being fitted or the name of the fit object).
writeLines("Running Ridge-Ridge")

# Number of workers registered for the parallel cross-validation; reused by
# the glmnet fits further down.
nCores <- 10
registerDoParallel(nCores)
fit_ridge_ridge <- tune_xrnet(x = X.train,
                              y = y.train,
                              external = external_mat,
                              family = "binomial",
                              loss = "deviance",
                              # Main penalty (on the features in x)
                              penalty_main = define_penalty(
                                penalty_type = 0,
                                num_penalty = 20),
                              # Second-order penalty (on the external data)
                              penalty_ext = define_penalty(
                                penalty_type = 0,
                                num_penalty = 20),
                              # Two-element flags: first entry for x, second
                              # for the external data (per the xrnet docs).
                              intercept = c(TRUE, FALSE),
                              standardize = c(TRUE, FALSE),
                              nfolds = 10,
                              foldid = foldId,
                              parallel = TRUE)
stopImplicitCluster()

# Predicted probabilities on the test set and the resulting ROC / AUC.
rr_pred <- predict(fit_ridge_ridge, newdata = X.test,
                   type = "response")
roc(y.test, rr_pred)


#----- Run GLMNET
# Plain ridge logistic regression (alpha = 0) on the features only, tuned by
# cross-validated deviance on the same folds as the xrnet fit.
registerDoParallel(nCores)
fit_glmnet_ridge <- cv.glmnet(
  x = X.train,
  y = y.train,
  family = "binomial",
  alpha = 0,
  nlambda = 20,
  type.measure = "deviance",
  nfolds = 10,
  foldid = foldId,
  parallel = TRUE
)
stopImplicitCluster()

# Test-set probabilities at the lambda minimising the CV deviance.
glm_pred <- predict(
  fit_glmnet_ridge,
  newx = X.test,
  s = "lambda.min",
  type = "response"
)
roc(y.test, drop(glm_pred))


# Ridge logistic regression on the features augmented with the meta-features
# X %*% external_mat.
# The training and test designs must be built from the SAME external matrix.
# Prediction previously used the raw metagene_ext while training used the
# rescaled external_mat, which put the test meta-features on a different scale
# from the one the model was fitted on.
aug_train <- cbind(X.train, X.train %*% external_mat)
aug_test <- cbind(X.test, X.test %*% external_mat)

registerDoParallel(nCores)
fit_glmnet_ridge2 <- cv.glmnet(aug_train, y.train,
                               foldid = foldId,
                               alpha = 0,
                               family = "binomial",
                               type.measure = "deviance",
                               parallel = TRUE)

stopImplicitCluster()

# Test-set probabilities at the lambda minimising the CV deviance.
aug_pred <- predict(fit_glmnet_ridge2, newx = aug_test,
                    type = "response", s = "lambda.min")
roc(y.test, drop(aug_pred))

#----- Run competing methods

# xtune: ridge model for a binary outcome with external_mat as the external
# information. Only the fitted object is kept; no test predictions are made
# from it in this script.
fit_xtune <- xtune(
  X = X.train,
  Y = y.train,
  Z = external_mat,
  family = "binary",
  method = "ridge"
)

# feature-weighted elastic net: alpha = 0, cross-validated deviance on the
# same folds as the other methods.
fit_fwelnet <- cv.fwelnet(
  x = X.train,
  y = y.train,
  z = external_mat,
  family = "binomial",
  alpha = 0,
  type.measure = "deviance",
  nfolds = 10,
  foldid = foldId
)

# Test-set probabilities at the lambda minimising the CV deviance.
fwelnet_pred <- predict(
  fit_fwelnet,
  xnew = X.test,
  s = "lambda.min",
  type = "response"
)

roc(y.test, drop(fwelnet_pred))

# random forest
# Classification forest on the features augmented with the meta-features
# X %*% external_mat; the outcome is passed as a factor so that randomForest
# fits a classifier.
rf_train <- cbind(X.train, X.train %*% external_mat)
rf_test <- cbind(X.test, X.test %*% external_mat)

fit_rf <- randomForest(
  x = rf_train,
  y = factor(y.train)
)

# Class-probability matrix for the test set; the second column is used as the
# score (presumably the probability of class "1" -- verify the factor levels).
rf_pred <- predict(fit_rf, newdata = rf_test, type = "prob")

roc(y.test, rf_pred[, 2])



# Persist the test outcome, the test-set predictions, the test design matrix
# and the xtune fit for the selected horizon (sim).
# The output path must be passed as `file =`: an unnamed argument is taken by
# save() as the name of an object to store, which leaves `file` missing and
# makes the call fail.
# NOTE(review): rf_pred is computed above but is not part of the saved objects;
# add it here if the random-forest predictions are needed downstream.
out_file <- paste0("RDA/METABRIC/results/metabric_complete_", sim, ".RData")

# Create the results directory if it does not exist yet, so that the long
# model fits are not lost to a missing folder.
dir.create(dirname(out_file), recursive = TRUE, showWarnings = FALSE)

save(y.test, rr_pred, glm_pred, aug_pred, fwelnet_pred, X.test, fit_xtune,
     file = out_file)
