# Section 3.2: Continuous Z

# Load packages:
rm(list = ls())
library(doParallel)
library(xrnet)
library(glmnet) # Competitor
library(readxl)  # For reading excel sheet for simulations
library(xtune)
library(fwelnet)
library(randomForest)

# Load source files
# Every file found in code/sourceFiles/ is sourced; a plain loop is used so
# that nothing is auto-printed and no return value has to be simplified.
source.dir <- "code/sourceFiles"
source.files <- list.files(source.dir)
for (source.path in file.path(source.dir, source.files)) {
    source(source.path)
}

# Simulation grid: one row per scenario (first sheet of the workbook).
simGrid <- read_xlsx(file.path("code", "simulationList.xlsx"), sheet = 1)

# This is where you define what simulation you want to run:

# 1-3: Varying p
# 4-6: Varying rho x
# 7-9: Varying rho z
# 10-12: Varying SNR
# 13-15: Varying q
# 16-18: Varying n
# 19-21: Varying y
# 22- : Varying other SNRs


# Run from here to run the simulations for a single scenario
sim <- 2

# Fail early if `sim` does not point at exactly one row of the simulation grid.
if (length(sim) != 1 || !(sim %in% seq_len(nrow(simGrid)))) {
    stop("`sim` must be a single row index of simGrid", call. = FALSE)
}

# Pull the scenario parameters out of the selected row of the grid.
simVals <- simGrid[sim, ]
nObs <- simVals$nObs
nCovs <- simVals$nCovs
nExt <- simVals$nExt
SNR  <- simVals$SNR
seed <- simVals$seed
rhoX <- simVals$rhoX
rhoZ <- simVals$rhoZ
sigmaY <- simVals$sigmaY

# alpha0 hard-codes 150 leading entries (50 + 25 + 50 + 25) and pads the
# remainder with zeros, so it can only be built when nExt >= 150.
if (length(nExt) != 1 || is.na(nExt) || nExt < 150) {
    stop("nExt must be a single value >= 150 to construct alpha0", call. = FALSE)
}

# Coefficient vector of length nExt passed to the simulator as `alpha`:
# 0.01 x (50 ones, 25 zeros, 50 threes, 25 ones, then zeros).
alpha0 <- 0.01 * c(rep(1, 50), rep(0, 25), rep(3, 50), rep(1, 25), rep(0, nExt - 150))

B <- 500 # Number of Monte Carlo Replications

# Run the B Monte Carlo replicates in parallel. Each replicate simulates one
# data set, fits every method on the first nObs rows and returns a length-12
# numeric vector: six test-set prediction errors followed by the six matching
# test-set R^2 values. Replicates are stacked row-wise into `results`.
nCores <- 12
registerDoParallel(nCores)
# Seed with the L'Ecuyer-CMRG generator (the multiple-stream RNG kind used by
# R's `parallel` package).
set.seed(seed, kind = "L'Ecuyer-CMRG")
results <- foreach(b = seq_len(B), .combine = 'rbind') %dopar% {
    cat(b)  # progress marker
    # Simulate data (nObs training + 1000 test)
    out <- generateSimulationData(nsub = nObs + 1000, ncovs = nCovs, alpha = alpha0,
                                  beta_intercept = 0.2, alpha_intercept = 0,
                                  SNR.beta = SNR, sigmaY = sigmaY,
                                  covZ = "ar", covX = "ar", rhoZ = rhoZ, rhoX = rhoX,
                                  isBinaryZ = FALSE, isBinaryX = FALSE)

    # Split rows into training (first nObs) and test (the rest). Z is used
    # as returned by the simulator for both training and prediction.
    trainRows <- seq_len(nObs)
    X.train <- out$X[trainRows, ]
    Z.train <- out$Z
    y.train <- out$y[trainRows]

    X.test <- out$X[-trainRows, ]
    y.test <- out$y[-trainRows]

    # One set of CV folds shared by every cross-validated fit below.
    foldId <- sample(10, size = nObs, replace = TRUE)

    # Augmented design matrices [X, XZ], built once and reused by the
    # augmented ridge and the random forest.
    XZ.train <- cbind(X.train, X.train %*% Z.train)
    XZ.test <- cbind(X.test, X.test %*% Z.train)

    # 1) Fit ridge-ridge xrnet (penalty type 0 on both X and Z)
    fit_ridge_ridge <- tune_xrnet(
        x = X.train,
        y = y.train,
        external = Z.train,
        intercept = c(TRUE, FALSE),
        family = "gaussian",
        penalty_main = define_penalty(0, num_penalty = 20),
        penalty_external = define_penalty(0, num_penalty = 20),
        standardize = c(TRUE, TRUE),
        foldid = foldId
    )

    # Fit ridge regression via glmnet w/o taking into account external information
    fit.glm.ridge <- cv.glmnet(X.train, y.train, family = "gaussian", alpha = 0, nlambda = 20,
                               foldid = foldId,
                               standardize = TRUE)

    # Fit augmented ridge on [X, XZ]
    fit.glm.aug <- cv.glmnet(XZ.train,
                             y.train, family = "gaussian", alpha = 0, nlambda = 20,
                             foldid = foldId,
                             standardize = TRUE)

    # Competitors
    # fwelnet
    fit.fwelnet <- cv.fwelnet(X.train, y.train, Z.train, family = "gaussian",
                              alpha = 0,
                              foldid = foldId,
                              standardize = TRUE)

    # random forest
    fit.rf <- randomForest(XZ.train, y = y.train)

    # Perform and store prediction results
    pred.glm         <- predict(fit.glm.ridge, newx = X.test, s = "lambda.min")
    pred.ridge_ridge <- predict(fit_ridge_ridge, newdata = X.test)
    # AUG1 predicts with the full augmented test design; AUG2 replaces the
    # XZ columns with zeros.
    pred.aug_ridge1  <- predict(fit.glm.aug, newx = XZ.test, s = "lambda.min")
    pred.aug_ridge2  <- predict(fit.glm.aug,
                                newx = cbind(X.test, matrix(0, nrow = nrow(X.test), ncol = ncol(Z.train))),
                                s = "lambda.min")
    pred.fwelnet     <- predict(fit.fwelnet, s = "lambda.min", xnew = X.test)
    pred.rf          <- predict(fit.rf, XZ.test)

    # Prediction error on the test set. NOTE: dist() on the two-row matrix is
    # the Euclidean distance between observed and predicted responses, i.e.
    # sqrt(sum of squared errors) -- not a mean. It is squared again below to
    # recover the SSE for R^2.
    testError <- function(pred) {
        as.numeric(dist(rbind(y.test, as.vector(pred))))
    }

    # Order must match the column names assigned to `results` afterwards:
    # RR, GLM, AUG1, AUG2, fwelnet, randomForest. No xtune model is fitted in
    # this script, so no xtune error is computed.
    err <- c(
        testError(pred.ridge_ridge),
        testError(pred.glm),
        testError(pred.aug_ridge1),
        testError(pred.aug_ridge2),
        testError(pred.fwelnet),
        testError(pred.rf)
    )

    # Test-set R^2 = 1 - SSE / TSS for each method, in the same order.
    TSS <- sum((y.test - mean(y.test))^2)
    c(err, 1 - (err^2) / TSS)
}
stopImplicitCluster()

# Label the twelve result columns: the six prediction-error columns first,
# then the six R^2 columns, both in the order RR, GLM, AUG1, AUG2, fwelnet,
# randomForest.
methodLabels <- c("RR", "GLM", "AUG1", "AUG2", "fwelnet", "randomForest")
colnames(results) <- c(paste0("PMSE_", methodLabels), paste0("R2_", methodLabels))

# Make sure the output directory exists so a long simulation run is not lost
# to a failed save, then write the results for this scenario.
resultsDir <- "code/results"
dir.create(resultsDir, showWarnings = FALSE, recursive = TRUE)
saveRDS(results, file = file.path(resultsDir, paste0("linear_supp", sim, ".rds")))
