# Load packages:
library(doParallel)
library(xrnet)
library(glmnet)       # Competitor
library(readxl)       # For reading excel sheet for simulations
library(pROC)         # AUC computation
library(randomForest)
library(fwelnet)

data_directory <- "Projects/twoLevelRidge/"

# Source every R file in sourceFiles/ (presumably defines project helpers such
# as generateBinaryData(), used below). Only .R/.r files are picked up so that
# stray non-R files in that folder are not handed to source(); invisible()
# keeps the lapply() result out of the batch log.
source.files <- list.files(paste0(data_directory, "sourceFiles/"),
                           pattern = "\\.[Rr]$")
invisible(lapply(paste0(data_directory, "sourceFiles/", source.files), source))

# Scenario grid: one row per simulation setting, indexed by the array task ID.
# NOTE(review): sheet 3 presumably holds the binary-outcome scenarios -- confirm.
simGrid <- read_xlsx(paste0(data_directory, "simulationList.xlsx"), sheet = 3)


# Select this job's scenario row from simGrid using the SLURM array task ID.
# For an interactive run, set `sim` by hand instead (e.g. sim <- 1).
sim <- as.numeric(Sys.getenv("SLURM_ARRAY_TASK_ID"))
if (is.na(sim) || sim < 1 || sim > nrow(simGrid)) {
    stop("SLURM_ARRAY_TASK_ID must be a row index into simGrid (1..",
         nrow(simGrid), "); got '", Sys.getenv("SLURM_ARRAY_TASK_ID"), "'",
         call. = FALSE)
}

simVals <- simGrid[sim, ]
nObs   <- simVals$nObs    # training sample size (1000 test rows are added below)
nCovs  <- simVals$nCovs
nExt   <- simVals$nExt
SNR    <- simVals$SNR
seed   <- simVals$seed
rhoX   <- simVals$rhoX
rhoZ   <- simVals$rhoZ
sigmaY <- simVals$sigmaY

# Coefficient vector (length nExt) passed as `alpha` to generateBinaryData():
# blocks of 50 / 25 / 50 / 25 entries equal to 0.1 / 0 / 0.3 / 0.1, then zeros.
# The four blocks span 150 entries, so nExt must be at least 150.
if (is.na(nExt) || nExt < 150) {
    stop("nExt must be at least 150 (got ", nExt, ")", call. = FALSE)
}
alpha0 <- 0.1 * c(rep(1, 50), rep(0, 25), rep(3, 50), rep(1, 25), rep(0, nExt - 150))


B <- 500 # Number of Monte Carlo Replications

# NOTE(review): with a fork-based doParallel backend, the per-replicate RNG
# streams presumably depend on the number of workers. Keep nCores fixed to
# reproduce earlier runs (doRNG would remove that dependence).
nCores <- 12
registerDoParallel(nCores)

# Test-set AUC of `score` for the binary outcome `y`.
# direction = "<" fixes the comparison: controls (the first level of y) are
# expected to score lower than cases. pROC's default direction = "auto" picks,
# on every call, whichever direction yields the larger AUC, which biases AUCs
# upward across Monte Carlo replicates. quiet = TRUE silences the per-call
# "Setting levels/direction" messages.
compute_auc <- function(y, score) {
    as.numeric(pROC::roc(response = y, predictor = as.vector(score),
                         direction = "<", quiet = TRUE)$auc)
}

set.seed(seed, kind = "L'Ecuyer-CMRG")
results <- foreach(b = seq_len(B), .combine = "rbind") %dopar% {

    # Statement order matters for reproducibility: data generation, sample()
    # and randomForest() draw from the RNG stream.

    # Simulate nObs training rows plus 1000 held-out test rows.
    out <- generateBinaryData(nsub = nObs + 1000, ncovs = nCovs, alpha = alpha0,
                              beta_intercept = 0.5, alpha_intercept = 0,
                              SNR.beta = SNR, sigmaY = sigmaY,
                              covZ = "ar", covX = "ar", rhoZ = rhoZ, rhoX = rhoX,
                              isBinaryZ = FALSE, isBinaryX = FALSE)

    train_rows <- 1:nObs
    X.train <- out$X[train_rows, ]
    y.train <- out$y[train_rows]
    X.test  <- out$X[-train_rows, ]
    y.test  <- out$y[-train_rows]

    # The external data Z is not split by row: it is combined with both the
    # training and the test design matrices.
    Z.train  <- out$Z
    XZ.train <- X.train %*% Z.train
    XZ.test  <- X.test %*% Z.train

    # One shared 10-fold assignment so every cross-validated fit tunes on the
    # same folds.
    foldId <- sample(10, size = nObs, replace = TRUE)

    # 1) Ridge-ridge xrnet: ridge penalty (define_penalty(0)) on both the main
    #    covariates and the external-data level.
    fit_ridge_ridge <- tune_xrnet(
        x = X.train,
        y = y.train,
        external = Z.train,
        intercept = c(TRUE, FALSE),
        family = "binomial",
        penalty_main = define_penalty(0, num_penalty = 20),
        penalty_external = define_penalty(0, num_penalty = 20),
        standardize = c(TRUE, TRUE),
        foldid = foldId
    )

    # 2) Standard ridge (glmnet, alpha = 0) ignoring the external information.
    fit.glm.ridge <- cv.glmnet(X.train, y.train, family = "binomial", alpha = 0, nlambda = 20,
                               foldid = foldId,
                               standardize = TRUE)

    # 3) Augmented ridge: X plus the derived features X %*% Z.
    fit.glm.aug <- cv.glmnet(cbind(X.train, XZ.train),
                             y.train, family = "binomial", alpha = 0, nlambda = 20,
                             foldid = foldId,
                             standardize = TRUE)

    # 4) fwelnet with Z supplied as the third argument (alpha = 0).
    fit.fwelnet <- cv.fwelnet(X.train, y.train, Z.train,
                              alpha = 0,
                              family = "binomial",
                              #nlambda = 20,
                              foldid = foldId,
                              standardize = TRUE)

    # 5) Random forest on the same augmented design as (3).
    fit.rf <- randomForest(x = cbind(X.train, XZ.train),
                           y = factor(y.train))

    # Test-set scores. AUC depends only on the ranking of the scores, so the
    # fwelnet predictions (predict() called without `type`) need not be on the
    # probability scale.
    score.glm         <- predict(fit.glm.ridge, newx = X.test, type = "response", s = "lambda.min")
    score.ridge_ridge <- predict(fit_ridge_ridge, newdata = X.test, type = "response")
    score.aug_ridge1  <- predict(fit.glm.aug, newx = cbind(X.test, XZ.test),
                                 type = "response", s = "lambda.min")
    # AUG2: same augmented model, but the X %*% Z block is zeroed at test time,
    # so only the X coefficients contribute.
    score.aug_ridge2  <- predict(fit.glm.aug,
                                 newx = cbind(X.test, matrix(0, nrow = nrow(X.test), ncol = ncol(Z.train))),
                                 type = "response", s = "lambda.min")
    score.fwelnet     <- predict(fit.fwelnet, s = "lambda.min", xnew = X.test)
    # Column 2 = probability of the second factor level of y.train (the case class).
    score.rf          <- predict(fit.rf, newdata = cbind(X.test, XZ.test), type = "prob")[, 2]

    # One row of `results`: test-set AUC of each method, in the column order
    # assigned after the loop.
    c(compute_auc(y.test, score.ridge_ridge),
      compute_auc(y.test, score.glm),
      compute_auc(y.test, score.aug_ridge1),
      compute_auc(y.test, score.aug_ridge2),
      compute_auc(y.test, score.fwelnet),
      compute_auc(y.test, score.rf))
}
stopImplicitCluster()

# Column order matches the per-replicate vector built in the foreach loop.
# NOTE: every column holds a test-set AUC (from pROC::roc()$auc). The "PMSE_"
# prefixes and the "PAUB" spelling are legacy names, left unchanged because
# downstream analysis presumably reads these files by column name.
colnames(results) <- c("PAUB_RR", "PMSE_GLM", "PMSE_AUG1", "PMSE_AUG2", "PMSE_fwelnet", "PMSE_randomForest")

# Create the output directory if needed, so a long run is not lost at the
# final write.
results_dir <- paste0(data_directory, "results/")
dir.create(results_dir, recursive = TRUE, showWarnings = FALSE)
saveRDS(results, file = paste0(results_dir, "binary_supp", sim, ".rds"))
