rm(list = ls())
library(ggplot2)
library(tidyr)
library(ggpubr)

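# Result-file indices (presumably the i in linear2_<i>.rds) per setting:
# Varying p: 1-3
# Varying SNR: 10-12
# Varying q: 13-15
# Varying n: 16-18
# NOTE(review): the plots below read p from files 1-3, n from files 4-6 and
# SNR from files 10, 8, 7, which disagrees with the SNR and n ranges listed
# above; confirm which mapping is current.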

# Helpers ----------------------------------------------------------------------

# Directory holding the linear2_<i>.rds simulation outputs. Relative path:
# it resolves against the current working directory.
results_dir <- "Desktop/twoLevelRidge/results"

# Load columns 5:8 of linear2_<i>.rds for every i in `file_ids` and stack them
# into a long data frame with columns:
#   y - the stored values (plotted as test R^2),
#   x - factor: position of the file within `file_ids` (levels "1", "2", ...),
#   z - ordered factor: column position within 5:8 (1..4), with levels
#       reordered to 2, 3, 4, 1 so the legend order matches the labels used in
#       plot_test_r2().
# NOTE(review): each file is assumed to be a matrix with one row per replicate
# and one column per method -- confirm against the simulation code. The
# replicate count is read from each file instead of being hard-coded.
load_test_r2 <- function(file_ids, dir = results_dir) {
  per_file <- lapply(seq_along(file_ids), function(k) {
    path <- file.path(dir, paste0("linear2_", file_ids[[k]], ".rds"))
    vals <- as.matrix(readRDS(path)[, 5:8, drop = FALSE])
    data.frame(
      # Column-major flattening: all rows of column 5, then column 6, ...
      y = as.vector(vals),
      x = k,
      z = rep(seq_len(ncol(vals)), each = nrow(vals))
    )
  })
  res <- do.call(rbind, per_file)
  res$x <- factor(res$x, levels = seq_along(file_ids))
  res$z <- factor(res$z, levels = c(2, 3, 4, 1), ordered = TRUE)
  res
}

# Box plots of test R^2 per setting for three of the four loaded methods.
#   res      - data frame returned by load_test_r2()
#   x_title  - x-axis title (e.g. a plotmath expression)
#   x_labels - tick labels for the settings, in the order of `file_ids`
# Method level "4" (column 8 of the result files) is dropped. The remaining
# levels, in legend order 2, 3, 1, are labelled Standard, Augmented and
# Two-Level Ridge.
plot_test_r2 <- function(res, x_title, x_labels) {
  ggplot(dplyr::filter(res, z != "4"), aes(x = x, y = y, fill = z)) +
    # Hide outlier points; boxes and whiskers are unaffected.
    geom_boxplot(outlier.shape = NA) +
    ylab(expression(Test~~R^2)) +
    xlab(x_title) +
    scale_x_discrete(labels = x_labels) +
    scale_fill_discrete(name = "",
                        labels = c("Standard Ridge",
                                   "Augmented Ridge",
                                   "Two-Level Ridge")) +
    theme(legend.position = "right",
          axis.title.x = element_text(color = "black", size = 12),
          axis.title.y = element_text(color = "black", size = 12),
          axis.text.x = element_text(color = "black", size = 14),
          axis.text.y = element_text(color = "black", size = 14),
          legend.title = element_blank(),
          legend.text = element_text(size = 14),
          legend.spacing.x = unit(0.5, "cm"))
}

# Panels -----------------------------------------------------------------------

p1 <- plot_test_r2(load_test_r2(1:3),
                   x_title = expression(p),
                   x_labels = c("400", "1000", "2000"))

p2 <- plot_test_r2(load_test_r2(4:6),
                   x_title = expression(n),
                   x_labels = c("400", "800", "1000"))

# Files are listed in the order in which the SNR tick labels should appear.
p3 <- plot_test_r2(load_test_r2(c(10, 8, 7)),
                   x_title = expression(SNR),
                   x_labels = c("0.001", "0.5", "2"))

# Put plots together -----------------------------------------------------------
# ggarrange() only reads *named* elements of `font.label`; a bare number has
# no names and is silently ignored, so the size must be passed as a list.
figure <- ggarrange(p1, p2, p3,
                    labels = c("A", "B", "C"),
                    ncol = 2, nrow = 2,
                    font.label = list(size = 12),
                    common.legend = TRUE, legend = "top")
figure

