# Section 3.3: Logistic Regression (Binary Outcome)

# Load packages.
# NOTE: the workspace is intentionally not wiped with rm(list = ls()) here.
# That call does not give a clean session (attached packages and options
# survive) and silently deletes the caller's objects; run this script in a
# fresh R session instead.

library(doParallel)
library(xrnet)
library(glmnet) # Competitor
library(readxl)  # For reading excel sheet for simulations
library(pROC)
library(randomForest)
library(fwelnet)

# Load source files: every file found in code/sourceFiles/ is sourced into the
# global environment, in the order returned by list.files().
source.files <- list.files("code/sourceFiles/")
# source() is called only for its side effects, so the return values are
# discarded instead of being collected (and auto-printed) by sapply().
invisible(lapply(paste0("code/sourceFiles/", source.files), source))

# Simulation grid: one row per scenario, read from the third sheet.
simGrid <- read_xlsx("code/simulationList.xlsx", sheet = 3)

# This is where you define what simulation you want to run:

# 1-3: Varying p
# 4-6: Varying rho x
# 7-9: Varying rho z
# 10-12: Varying SNR
# 13-15: Varying q
# 16-18: Varying n
# 19-21: Varying y
# 22- : Varying other SNRs

# Run from here to run the simulations for a single scenario.
# `sim` is the row of simGrid to simulate (see the scenario index above).
sim <- 2

# Fail fast on an invalid scenario index instead of letting an out-of-range
# row propagate missing values or an obscure error into the simulation.
stopifnot(
  length(sim) == 1,
  sim %in% seq_len(nrow(simGrid))
)

# Unpack the settings of the selected scenario.
simVals <- simGrid[sim, ]
nObs <- simVals$nObs     # training-set size (rows 1:nObs of the generated data)
nCovs <- simVals$nCovs   # passed to generateBinaryData() as `ncovs`
nExt <- simVals$nExt     # NOTE(review): not used by the code visible in this script
SNR  <- simVals$SNR      # passed to generateBinaryData() as `SNR.beta`
seed <- simVals$seed     # RNG seed set before the Monte Carlo loop
rhoX <- simVals$rhoX     # passed to generateBinaryData() as `rhoX`
rhoZ <- simVals$rhoZ     # passed to generateBinaryData() as `rhoZ`
sigmaY <- simVals$sigmaY # passed to generateBinaryData() as `sigmaY`



B <- 500 # Number of Monte Carlo replications

nCores <- 12
registerDoParallel(nCores)
# NOTE(review): whether this seed makes the parallel replications reproducible
# depends on the backend doParallel registers on this platform -- confirm.
set.seed(seed, kind = "L'Ecuyer-CMRG")

# Each replication generates a fresh data set (nObs training rows plus 1000
# test rows), fits six competing models on the training rows and returns their
# test-set AUCs as an unnamed numeric vector of length 6. The vectors are
# stacked row-wise into `results` (one row per replication).
results <- foreach(b = seq_len(B), .combine = "rbind") %dopar% {

  # Test-set AUC of a score for which larger values indicate the second
  # outcome level. The direction is fixed to "<" (controls score lower than
  # cases): pROC's default "auto" orients every curve by the observed medians,
  # so an AUC could never drop below 0.5 and the Monte Carlo averages would be
  # biased upward.
  testAUC <- function(response, score) {
    as.numeric(roc(response, as.vector(score), direction = "<")$auc)
  }

  # NOTE(review): `alpha0` is not defined in this script; presumably it is
  # created by one of the sourced files -- confirm.
  out <- generateBinaryData(nsub = nObs + 1000, ncovs = nCovs, alpha = alpha0,
                            beta_intercept = 0.5, alpha_intercept = 0,
                            SNR.beta = SNR, sigmaY = sigmaY,
                            covZ = "ar", covX = "ar", rhoZ = rhoZ, rhoX = rhoX,
                            isBinaryZ = FALSE, isBinaryX = FALSE)

  # First nObs rows are used for training, the remaining rows for testing.
  X.train <- out$X[1:nObs, ]
  Z.train <- out$Z
  y.train <- out$y[1:nObs]

  X.test <- out$X[-(1:nObs), ]
  y.test <- out$y[-(1:nObs)]

  # Augmented designs [X, XZ], shared by the augmented ridge and the random
  # forest (computed once instead of once per model).
  XZ.train <- cbind(X.train, X.train %*% Z.train)
  XZ.test  <- cbind(X.test, X.test %*% Z.train)
  # Same layout as XZ.test but with the XZ columns zeroed out.
  X0.test  <- cbind(X.test, matrix(0, nrow = nrow(X.test), ncol = ncol(Z.train)))

  # Common CV folds so that all cross-validated fits are tuned on the same splits.
  foldId <- sample(10, size = nObs, replace = TRUE)

  # 1) Ridge-ridge xrnet (penalty type 0 on both the main and external effects)
  fit_ridge_ridge <- tune_xrnet(
    x = X.train,
    y = y.train,
    external = Z.train,
    intercept = c(TRUE, FALSE),
    family = "binomial",
    penalty_main = define_penalty(0, num_penalty = 20),
    penalty_external = define_penalty(0, num_penalty = 20),
    standardize = c(TRUE, TRUE),
    foldid = foldId
  )

  # 2) Ridge regression via glmnet w/o taking into account external information
  fit.glm.ridge <- cv.glmnet(X.train, y.train, family = "binomial", alpha = 0,
                             nlambda = 20,
                             foldid = foldId,
                             standardize = TRUE)

  # 3) Augmented ridge: glmnet ridge on [X, XZ]
  fit.glm.aug <- cv.glmnet(XZ.train, y.train, family = "binomial", alpha = 0,
                           nlambda = 20,
                           foldid = foldId,
                           standardize = TRUE)

  # 4) Feature-weighted elastic net (alpha = 0)
  fit.fwelnet <- cv.fwelnet(X.train, y.train, Z.train,
                            alpha = 0,
                            family = "binomial",
                            #nlambda = 20,
                            foldid = foldId,
                            standardize = TRUE)

  # 5) Random forest on the augmented design [X, XZ]
  fit.rf <- randomForest(x = XZ.train, y = factor(y.train))

  # Test-set predictions
  pred.ridge_ridge <- predict(fit_ridge_ridge, newdata = X.test, type = "response")
  pred.glm         <- predict(fit.glm.ridge, newx = X.test, type = "response",
                              s = "lambda.min")
  pred.aug_ridge1  <- predict(fit.glm.aug, newx = XZ.test, type = "response",
                              s = "lambda.min")
  pred.aug_ridge2  <- predict(fit.glm.aug, newx = X0.test, type = "response",
                              s = "lambda.min")
  pred.fwelnet     <- predict(fit.fwelnet, s = "lambda.min", xnew = X.test)
  # Second column = predicted probability of the second factor level.
  pred.rf          <- predict(fit.rf, newdata = XZ.test, type = "prob")[, 2]

  # Test-set AUC of every method; the order must match the column labels
  # assigned to `results` after the loop.
  c(
    testAUC(y.test, pred.ridge_ridge),
    testAUC(y.test, pred.glm),
    testAUC(y.test, pred.aug_ridge1),
    testAUC(y.test, pred.aug_ridge2),
    testAUC(y.test, pred.fwelnet),
    testAUC(y.test, pred.rf)
  )
}
stopImplicitCluster()

# Label the six result columns (one per method, in the order they are returned
# by each replication). The entries are test-set AUCs even though most labels
# carry a "PMSE_" prefix; the labels are kept as they are because other scripts
# may read them. NOTE(review): "PAUB_RR" looks like a typo for "PAUC_RR" --
# confirm before renaming.
colnames(results) <- c(
  "PAUB_RR",
  "PMSE_GLM",
  "PMSE_AUG1",
  "PMSE_AUG2",
  "PMSE_fwelnet",
  "PMSE_randomForest"
)

# Write the result matrix to one file per scenario, indexed by `sim`.
saveRDS(
  results,
  file = paste0("code/results/binary_supp", sim, ".rds")
)
