# Reproduces Figure 4 (supplementary figure) in the manuscript

# Script setup: attach the plotting packages and choose the results directory.
#
# NOTE: the workspace is intentionally NOT cleared here. `rm(list = ls())`
# would delete every object in the caller's global environment when this file
# is sourced, and nothing below depends on a clean workspace.
library(ggplot2)
library(tidyr)
library(ggpubr)

# Directory (with trailing slash) that holds the linear_supp<i>.rds files.
# The path is pasted directly in front of the file name, so keep the slash.
wd <- "paper_results/" # Directory to use paper results (from previous sim)
#wd <- "code/results/" 

# Simulation settings, by index i of the file linear_supp<i>.rds:
# 4-6: Varying rho x
# 7-9: Varying rho z
# 19-21: Varying sigma y

##############################
# Shared pieces for the three panels (p1, p2, p3)
##############################

# Legend labels. They are applied, in order, to the z levels that remain
# after level "4" is filtered out of the plotted data (2, 3, 1, 5, 6).
method_labels <- c("Standard Ridge",
                   "Augmented Ridge",
                   "Two-Level Ridge",
                   "fwelnet",
                   "Random Forest")

# Theme shared by every panel.
panel_theme <- theme(
  legend.position = "top",
  axis.title.x = element_text(color = "black", size = 12),
  axis.title.y = element_text(color = "black", size = 12),
  axis.text.x = element_text(color = "black", size = 14),
  axis.text.y = element_text(color = "black", size = 14),
  legend.title = element_text(size = 0),
  legend.text = element_text(size = 14),
  legend.spacing.x = unit(0.5, "cm")
)

# Read linear_supp<i>.rds for every i in `sim_ids` from directory `dir` and
# stack columns 7:12 of each file into one long data frame with columns:
#   y - the stacked values (column by column within a file, file after file)
#   x - factor 1..length(sim_ids): which file (setting) the value came from
#   z - ordered factor 1..6 with level order 2, 3, 4, 1, 5, 6: which of the
#       six columns the value came from
# Assumes each file holds a numeric matrix (or data frame) with 500 rows;
# stops with an error otherwise, because the x/z labels would be misaligned.
load_panel_data <- function(sim_ids, dir) {
  n_reps <- 500
  n_methods <- 6

  per_sim <- lapply(sim_ids, function(i) {
    results <- readRDS(paste0(dir, "linear_supp", i, ".rds"))
    results[, 7:12]
  })
  # Flatten in column-major order, exactly as c() does on a matrix.
  y <- unlist(per_sim, use.names = FALSE)

  expected <- length(sim_ids) * n_methods * n_reps
  if (length(y) != expected) {
    stop("Expected ", expected, " values (", n_reps, " rows x ", n_methods,
         " columns x ", length(sim_ids), " files) but read ", length(y), ".",
         call. = FALSE)
  }

  data.frame(
    y = y,
    x = as.factor(rep(seq_along(sim_ids), each = n_methods * n_reps)),
    z = factor(
      rep(rep(seq_len(n_methods), each = n_reps), times = length(sim_ids)),
      levels = c(2, 3, 4, 1, 5, 6),
      ordered = TRUE
    )
  )
}

# Build one boxplot panel of y against x, filled by z.
#   res      - data frame as returned by load_panel_data()
#   x_title  - x-axis title (a plotmath expression)
#   x_labels - tick labels for the three x levels, in order
# Rows with z == "4" are excluded from the plot.
plot_test_r2 <- function(res, x_title, x_labels) {
  ggplot(data = dplyr::filter(res, z != "4"),
         aes(x = x, y = y, fill = z)) +
    # Negative outlier size: presumably a trick to hide the outlier points.
    geom_boxplot(outlier.size = -10) +
    ylab(expression(Test~~R^2)) +
    xlab(x_title) +
    scale_x_discrete(labels = x_labels) +
    scale_fill_discrete(name = "", labels = method_labels) +
    panel_theme
}

##############################
# Panel A: simulations 4-6
p1 <- plot_test_r2(load_panel_data(4:6, wd),
                   x_title = expression(rho[x]),
                   x_labels = c("0", "0.25", "0.5"))

##############################
# Panel B: simulations 7-9
p2 <- plot_test_r2(load_panel_data(7:9, wd),
                   x_title = expression(rho[Z]),
                   x_labels = c("0.25", "0.5", "0.8"))

##############################
# Panel C: simulations 19-21
p3 <- plot_test_r2(load_panel_data(19:21, wd),
                   x_title = expression(sigma[y]),
                   x_labels = c("0.1", "0.5", "3"))



##### Put plots together
# Arrange the three panels on a 2 x 2 grid (the fourth cell stays empty),
# tagged A-C, with one shared legend placed above the panels.
figure <- ggarrange(p1, p2, p3,
                    labels = c("A", "B", "C"),
                    ncol = 2, nrow = 2,
                    # font.label must be a named list; a bare number is not
                    # the documented form and does not set the label size.
                    font.label = list(size = 12),
                    common.legend = TRUE, legend = "top")
# Print the assembled figure to the active graphics device.
figure
