### Section 3.1: Binary Z
# Load packages:
# NOTE: the global workspace is intentionally not cleared here. Under Rscript /
# SLURM the session starts empty anyway, and when the script is sourced
# interactively, clearing would silently destroy the caller's objects. Every
# object used below is (re)assigned by this script before it is read.

library(doParallel)
library(xrnet)
library(glmnet) # Competitor
library(readxl)  # For reading excel sheet for simulations
library(xtune) # Competitor
library(randomForest) # Competitor
library(GRridge) # Competitor
library(fwelnet) # Competitor

# Root directory of the project; every input and output path below is built
# relative to it (note the trailing slash, which paste0() relies on).
data_directory <- "Projects/twoLevelRidge/"

# Helper scripts to load. Only R files are listed so that stray non-R files
# (data, notes, editor backups) in sourceFiles/ are never passed to source().
source.files <- list.files(paste0(data_directory, "sourceFiles/"),
                           pattern = "\\.[Rr]$")
if (length(source.files) == 0L) {
    stop("No R source files found in ", paste0(data_directory, "sourceFiles/"),
         call. = FALSE)
}

# Load source files. source() is called for its side effects only, so the
# return values are discarded instead of being collected and auto-printed.
# (Presumably these files define generateSimulationDataFixedZ(), which is
# used in the simulation loop -- it is not defined in this script.)
invisible(lapply(paste0(data_directory, "sourceFiles/", source.files), source))

# Simulation settings, read from the second sheet of the workbook; one row is
# selected per job further below.
simGrid <- read_xlsx(paste0(data_directory, "simulationList.xlsx"), sheet = 2)

# Row of simGrid to run, taken from the SLURM job-array index. It is parsed as
# an integer so that it also pastes cleanly into the output file name.
sim <- suppressWarnings(as.integer(Sys.getenv("SLURM_ARRAY_TASK_ID")))
#sim <- 2  # uncomment to run a single scenario outside of SLURM

# Fail fast: an unset/invalid task id would otherwise give sim = NA, and
# simGrid[NA, ] silently returns a row of NAs that only breaks much later.
if (is.na(sim) || sim < 1L || sim > nrow(simGrid)) {
    stop("SLURM_ARRAY_TASK_ID must be an integer between 1 and ", nrow(simGrid),
         "; got '", Sys.getenv("SLURM_ARRAY_TASK_ID"), "'.", call. = FALSE)
}

# Settings of the selected scenario.
simVals <- simGrid[sim, ]
nObs <- simVals$nObs
nCovs <- simVals$nCovs
SNR  <- simVals$SNR
seed <- simVals$seed
rhoX <- simVals$rhoX
rhoZ <- simVals$rhoZ  # read from the grid but not used elsewhere in this script
sigmaY <- simVals$sigmaY
nExt <- 6  # Number of external-data (Z) columns

# Second-level coefficients passed to the data generator, one per Z column.
alpha0 <- rep(0.1, nExt)


B <- 500 # Number of Monte Carlo Replications

nCores <- 14
registerDoParallel(nCores)
# L'Ecuyer-CMRG is the RNG kind designed for parallel streams.
# NOTE(review): whether the %dopar% workers actually draw reproducible streams
# from this seed depends on the registered backend -- confirm if exact
# reproducibility across runs is required.
set.seed(seed, kind = "L'Ecuyer-CMRG")
# Monte Carlo loop. Each replication simulates one data set, fits every method
# on the first nObs observations and evaluates it on the remaining 1000. Each
# replication returns a numeric vector of length 12 (6 prediction-error norms
# followed by 6 R-squared values); the vectors are stacked row-wise.
# `.packages` loads the modelling packages on the workers so that the loop
# does not rely on a fork backend inheriting the master session's packages.
results <- foreach(b = seq_len(B), .combine = "rbind",
                   .packages = c("xrnet", "glmnet", "fwelnet",
                                 "randomForest")) %dopar% {

    # Generate binary external data Z (nCovs x nExt): an entry is 1 when the
    # underlying standard normal draw exceeds the 90% quantile in absolute
    # value (probability 0.2), and 0 otherwise.
    Z <- matrix(rnorm(nCovs * nExt), ncol = nExt)
    Z <- 1 * (abs(Z) >= qnorm(0.90))

    # Simulate nObs training observations plus 1000 held-out test observations.
    out <- generateSimulationDataFixedZ(nsub = nObs + 1000, ncovs = nCovs, alpha = alpha0,
                                        Z = Z,
                                        beta_intercept = 0.2, alpha_intercept = 0,
                                        SNR.beta = SNR, sigmaY = sigmaY,
                                        covX = "ar", rhoX = rhoX,
                                        isBinaryX = FALSE)

    X.train <- out$X[1:nObs, ]
    Z.train <- out$Z
    y.train <- out$y[1:nObs]

    X.test <- out$X[-(1:nObs), ]
    y.test <- out$y[-(1:nObs)]

    # Design matrices augmented with the X %*% Z columns; built once and
    # shared by the augmented ridge and the random forest.
    XZ.train <- cbind(X.train, X.train %*% Z.train)
    XZ.test  <- cbind(X.test, X.test %*% Z.train)

    # Common cross-validation folds so that all CV-tuned methods are compared
    # on identical splits.
    foldId <- sample(10, size = nObs, replace = TRUE)

    # 1) Fit ridge-ridge: xrnet with penalty type 0 on both the main effects
    #    and the external-data level.
    fit_ridge_ridge <- tune_xrnet(
        x = X.train,
        y = y.train,
        external = Z.train,
        intercept = c(TRUE, FALSE),
        family = "gaussian",
        penalty_main = define_penalty(0, num_penalty = 20),
        penalty_external = define_penalty(0, num_penalty = 20),
        standardize = c(TRUE, TRUE),
        foldid = foldId
    )

    # Fit ridge regression via glmnet w/o taking into account external information
    fit.glm.ridge <- cv.glmnet(X.train, y.train, family = "gaussian", alpha = 0, nlambda = 20,
                               foldid = foldId,
                               standardize = TRUE)

    # Fit augmented ridge: ordinary ridge on [X, XZ]
    fit.glm.aug <- cv.glmnet(XZ.train,
                             y.train, family = "gaussian", alpha = 0, nlambda = 20,
                             foldid = foldId,
                             standardize = TRUE)

    # Competitors

    # xtune (currently disabled)
    #fit.xtune <- xtune(X.train, y.train, Z.train, method = "ridge", family = "linear", message = FALSE)

    # GRridge (not implemented in this script)

    # fwelnet
    fit.fwelnet <- cv.fwelnet(X.train, y.train, Z.train, family = "gaussian",
                              alpha = 0,
                              foldid = foldId,
                              #nlambda = 20,
                              standardize = TRUE)

    # random forest on the augmented design
    fit.rf <- randomForest(XZ.train, y = y.train)

    # Perform and store prediction results
    pred.glm           <- predict(fit.glm.ridge, newx = X.test, s = "lambda.min")
    pred.ridge_ridge   <- predict(fit_ridge_ridge, newdata = X.test)
    # Augmented ridge, variant 1: test design uses the actual X %*% Z columns.
    pred.aug_ridge1    <- predict(fit.glm.aug, newx = XZ.test, s = "lambda.min")
    # Augmented ridge, variant 2: the X %*% Z columns are set to zero at test time.
    pred.aug_ridge2    <- predict(fit.glm.aug,
                                  newx = cbind(X.test, matrix(0, nrow = nrow(X.test), ncol = ncol(Z.train))),
                                  s = "lambda.min")
    #pred.xtune        <- predict(fit.xtune, newX = X.test)
    pred.fwelnet       <- predict(fit.fwelnet, s = "lambda.min", xnew = X.test)
    pred.randomForest  <- predict(fit.rf, XZ.test)

    # Euclidean distance between observed and predicted test responses, i.e.
    # sqrt(sum((y - yhat)^2)). NOTE: these values are stored under the PMSE_*
    # column names after the loop, but they are residual norms, not mean
    # squared errors; squaring them gives the residual sum of squares.
    residNorm <- function(pred) {
        as.numeric(dist(rbind(y.test, as.vector(pred))))
    }

    errNorm <- c(residNorm(pred.ridge_ridge),
                 residNorm(pred.glm),
                 residNorm(pred.aug_ridge1),
                 residNorm(pred.aug_ridge2),
                 #residNorm(pred.xtune),
                 residNorm(pred.fwelnet),
                 residNorm(pred.randomForest))

    # Total sum of squares of the test responses, for the out-of-sample R-squared.
    TSS <- sum((y.test - mean(y.test))^2)

    # Result row: the six error norms followed by the six R-squared values
    # (1 - RSS / TSS), in the same method order.
    c(errNorm, 1 - errNorm^2 / TSS)

}
# Shut down the implicit cluster created by registerDoParallel().
stopImplicitCluster()

# Label the result columns: first the six prediction-error columns, then the
# six R-squared columns, in the order returned by each replication.
colnames(results) <- c("PMSE_RR", "PMSE_GLM", "PMSE_AUG1", "PMSE_AUG2", "PMSE_fwelnet", "PMSE_randomForest",
                       "R2_RR", "R2_GLM", "R2_AUG1", "R2_AUG2", "R2_fwelnet", "R2_randomForest")

# Make sure the output directory exists before saving, so that a finished
# simulation is not lost to a missing folder. dir.create() is called
# unconditionally with warnings off: it simply returns FALSE if the directory
# already exists, which also makes it safe when several array tasks reach this
# point at the same time.
results_dir <- paste0(data_directory, "results/")
dir.create(results_dir, showWarnings = FALSE, recursive = TRUE)

saveRDS(results, file = paste0(results_dir, "linear2_", sim, ".rds"))


