#
# Ankerst and Neumair 2022
#

# > str(pbcg)
# $ site       : Factor w/ 11 levels
# $ pca7       : num  0 1 ...
# $ age        : num  59 64 ...
# $ lpsa2      : num  2.58 1.59 ...
# $ aa         : num  0 1 ...
# $ priorbiopsy: int  0 0 ...
# $ dre        : int  0 1 ...
# $ famhist1   : int  1 0 ...


library(tidyverse)
library(ggrepel)
library(ggmap)
library(naniar)

# Fig 1
# Longitude, latitude and display name of each of the 11 sites shown on the map.
nodes <- data.frame(
  lon = c(
    8.55, -81.621277, -78.898621, -122.4550, -66.664513, -73.9561118,
    -79.366318, -92.467369, 9.993682, 9.26249, -98.491142
  ),
  lat = c(
    47.36667, 41.503201, 35.994034, 37.7579, 18.200178, 40.7640249,
    43.722825, 44.022705, 53.551086, 45.50551, 29.424349
  ),
  name = c(
    "Zurich", "Cleveland Clinic", "Durham VA", "UCSF", "Puerto Rico VA",
    "MSKCC", "Sunnybrook", "Mayo Clinic", "Hamburg", "San Raffaele",
    "UT Health"
  )
)

# Bounding box spanning all sites (extended by the fraction f = 0.1) and the
# terrain background map requested for that box.
sbbox <- make_bbox(lon = nodes$lon, lat = nodes$lat, f = 0.1)
sq_map <- get_map(
  location = sbbox,
  maptype = "terrain",
  source = "google"
)

# Background map with one white-filled circle and one repelled bold label per
# site; theme_nothing() strips axes, grid and legend.
ggmap(sq_map) +
  geom_point(
    data = nodes,
    mapping = aes(x = lon, y = lat),
    shape = 21, size = 3, stroke = 0.5,
    fill = "white", color = "black"
  ) +
  geom_label_repel(
    data = nodes,
    mapping = aes(x = lon, y = lat, label = name),
    color = "black", fontface = "bold", size = 4,
    point.size = 3
  ) +
  theme_nothing()
#ggsave("Fig1.pdf", dpi = 600, width = 10, height = 3.5) 


# Fig 2
# Missingness heat map: percentage of missing values of each variable of
# `pbcg1`, split by cohort (`site`).
gg_miss_fct(pbcg1, fct = site) +
  # Fix the colour scale to the full 0-100 % range so the shading is comparable
  # regardless of the largest percentage observed.
  scale_fill_gradient(
    limits = c(0, 100),
    low = "snow2",
    high = "firebrick3",
    name = "Percent missing"
  ) +
  theme(
    text = element_text(size = 18, face = "bold"),
    axis.text.x = element_text(angle = 0, hjust = 0.5),
    axis.title.y = element_blank(),
    legend.text = element_text(size = 12),
    legend.position = "bottom",
    # `linewidth` replaces the `size` argument, which is deprecated for
    # line/rect theme elements since ggplot2 3.4.0 and triggers a warning.
    panel.border = element_rect(fill = NA, color = "black", linewidth = 1)
  ) +
  xlab("Cohort number")
#ggsave("Fig2.pdf", dpi = 600, width = 7, height = 5) 

# Fig 3
# Per-cohort records restricted to the cohort label, the outcome and the six
# risk factors used in Fig 3.
pbcg_short <- pbcg1[, c(
  "site", "pca", "lpsa2", "dre", "age", "aa", "priorbiopsy", "famhist1"
)]

# Second copy of the same records, relabelled as one pooled pseudo-cohort.
pbcgall <- pbcg_short
pbcgall$site <- "Overall"

# Stack the individual cohorts and the pooled copy so both are handled by the
# same per-site code below.
data <- rbind(pbcg_short, pbcgall)

# Columns that are converted to factors without any recoding.
as_is_cols <- c("pca", "site", "dre", "aa", "famhist1")
data[as_is_cols] <- lapply(data[as_is_cols], factor)

# Dichotomise the continuous covariates: lpsa2 above log2(6) and age above 65
# are coded 1, otherwise 0 (missing values stay missing).
data$lpsa2 <- factor(as.numeric(data$lpsa2 > log(6, 2)))
data$age <- factor(as.numeric(data$age > 65))

# Level "1" is listed first so that it becomes the reference level.
data$priorbiopsy <- factor(data$priorbiopsy, levels = c("1", "0"))

# For every cohort (each level of data$site) and every risk factor, compute
#   * the univariate odds ratio from a logistic regression of pca on the factor,
#   * the proportion of non-missing records in which the factor equals 1,
#   * the font face encoding significance ("bold" if p < 0.05, else "plain").
# Results are collected in three data frames with one row per covariate and one
# column per cohort.
covariates <- names(data)[-c(1:2)]

odds_list <- as.data.frame(covariates)
perc_list <- as.data.frame(covariates)
sig_list <- as.data.frame(covariates)
for (i in levels(data$site)) {
  odds <- list()
  perc <- list()
  sig <- list()
  # Rows of the current cohort; computed once and reused for all covariates.
  site_rows <- which(data$site == i)
  for (j in covariates) {
    x <- data[site_rows, j]
    # Variance of the covariate within the cohort after dropping missing
    # values; NA (too few records) or 0 (no variation) means no model can be
    # fitted, so the cohort is skipped for this covariate.
    spread <- var(na.omit(data[site_rows, j, drop = FALSE]))
    if (is.na(spread) || spread == 0) {
      odds[[j]] <- NA
      sig[[j]] <- NA
      perc[[j]] <- NA
    } else {
      y <- data[site_rows, "pca"]
      # Fit the model once and read both the estimate and its p-value from the
      # same coefficient table (row 2 = the covariate, row 1 = the intercept).
      coefs <- summary(
        glm(y ~ x, family = binomial(link = "logit"))
      )$coefficients
      odds[[j]] <- exp(coefs[2, 1])
      # ifelse() is kept so that an NA p-value yields NA rather than an error.
      sig[[j]] <- ifelse(coefs[2, 4] < 0.05, "bold", "plain")
      # NOTE(review): the proportion counts level 1 for every covariate, while
      # priorbiopsy is modelled with "1" as the reference level (its odds ratio
      # refers to level 0) - confirm this is the intended x-axis value.
      perc[[j]] <- sum(x == 1, na.rm = TRUE) / sum(!is.na(x))
    }
  }
  odds_list[, i] <- unlist(odds)
  perc_list[, i] <- unlist(perc)
  sig_list[, i] <- unlist(sig)
}

# Reshape the three covariate-by-cohort tables to long format (one row per
# covariate/cohort pair) and bind them column-wise. All three tables share the
# same layout, so their long forms are in the same row order.
perc <- gather(perc_list, site, percentage, -1)
odds <- gather(odds_list, site, odds, -1)
sig <- gather(sig_list, site, significance, -1)
odds <- odds$odds
significance <- as.factor(sig$significance)
data <- cbind(perc, odds, significance)

# Colour code: 2 for the pooled "Overall" group, 1 for individual cohorts.
data$sitecode <- factor(ifelse(data$site == "Overall", 2, 1))

# Display names and facet order of the covariates. Each column name is mapped
# to its label explicitly, so the labelling does not depend on the alphabetical
# order in which factor() would sort the column names.
data$covariates <- factor(
  data$covariates,
  levels = c("lpsa2", "dre", "age", "aa", "priorbiopsy", "famhist1"),
  labels = c(
    "PSA > 6",
    "DRE = Abnormal",
    "Age > 65",
    "African ancestry = Yes",
    "Prior negative biopsy = No",
    "Family History first degree = Yes"
  )
)

# Drop cohort/covariate pairs for which no estimate was computed (NA rows).
data <- na.omit(data)

# Covariates to show in the figure.
vari <- c(
  "PSA > 6",
  "DRE = Abnormal",
  "Family History first degree = Yes",
  "Age > 65",
  "African ancestry = Yes",
  "Prior negative biopsy = No"
)

# Fig 3: one panel per risk factor showing, for every cohort and the pooled
# group, the odds ratio (log-scaled y axis) against the percentage of records
# with the risk factor. Points are labelled with the cohort number; the label
# is bold where the `significance` column says so.
data %>%
  filter(covariates %in% vari) %>%
  # Attach the label shown next to each point: cohorts are numbered 11 down to
  # 1 in the order given by levels_order_prevalence, the pooled group keeps its
  # name.
  left_join(
    data.frame(
      name = c(as.character(levels_order_prevalence), "Overall"),
      number = c(11:1, "Overall")
    ),
    by = c("site" = "name")
  ) %>%
  ggplot(aes(x = percentage * 100, y = odds, color = sitecode)) +
  geom_point(size = 1.5) +
  facet_wrap(~covariates, dir = "v", ncol = 2) +
  geom_text_repel(
    aes(label = number, fontface = significance),
    max.overlaps = 30
  ) +
  labs(
    x = "Proportion with risk factor in %",
    y = "Odds ratio for clinically significant prostate cancer"
  ) +
  # Individual cohorts in blue, pooled group in red; the legend is suppressed.
  scale_color_manual(
    name = "region",
    breaks = c(1, 2),
    labels = c("cohorts", "overall"),
    values = c("dodgerblue3", "firebrick3")
  ) +
  guides(color = "none") +
  theme_bw(base_size = 12) +
  theme(
    text = element_text(size = 18),
    strip.background = element_blank(),
    strip.text = element_text(face = "bold", hjust = 0)
  ) +
  # Fixed axis ranges shared by all panels.
  scale_x_continuous(limits = c(0, 62), breaks = seq(0, 60, 20)) +
  scale_y_continuous(
    trans = "log",
    limits = c(0.3, 4),
    breaks = c(0.3, 0.5, 0.7, 1, 2, 3, 4)
  )
#ggsave("Fig3.pdf", dpi = 600, width = 10, height = 10)

# Fig 4 to 6 are screenshots

# Fig 7
# Two hypothetical patients for whom risks are predicted under every model.
low <- data.frame(
  age = 60, lpsa2 = log(1, 2), aa = 0, priorbiopsy = 1, dre = 0,
  famhist1 = 0, famhist2 = 0, famhist_bca = 0, prosvol_l2 = 5.459,
  ari_use = 1, hispanic = 0, priorpsa = 1
)
high <- data.frame(
  age = 75, lpsa2 = log(4, 2), aa = 1, priorbiopsy = 0, dre = 1,
  famhist1 = 1, famhist2 = 1, famhist_bca = 1, prosvol_l2 = 5.459,
  ari_use = 0, hispanic = 1, priorpsa = 0
)

# Indicator matrix of predictor subsets: one row per model, one column per
# entry of `variables`, 1 = predictor included. The first two predictors are
# always included (age and psa per the original comment; assumes these are the
# first two entries of `variables` - TODO confirm), all 0/1 combinations are
# generated for the remaining ones.
combi <- as.matrix(
  expand.grid(rep(list(c(0, 1)), length(variables[-c(1:2)])))
)
combi <- cbind(matrix(1, nrow = nrow(combi), ncol = 2), combi)
colnames(combi) <- variables

# Predicted probabilities: one row per model, one column per patient.
pred_ac <- matrix(NA, nrow = nrow(combi), ncol = 2)
colnames(pred_ac) <- c("low", "high")

# seq_len() yields an empty sequence when there are no rows, whereas
# 1:nrow(combi) would iterate over c(1, 0).
for (i in seq_len(nrow(combi))) {
  # Predictors included in this model.
  use_variables <- colnames(combi)[combi[i, ] == 1]

  # Main-effects logistic regression on the records that are complete for
  # site, the outcome and the selected predictors.
  mod <- glm(
    as.formula(paste("pca7 ~ ", paste(use_variables, collapse = "+"))),
    data = na.omit(subset(pbcg, select = c("site", "pca7", use_variables))),
    family = binomial(link = "logit")
  )

  pred_ac[i, 1] <- predict(mod, newdata = low, type = "response")
  pred_ac[i, 2] <- predict(mod, newdata = high, type = "response")
}

# Number of predictors used by each model.
pred_ac <- cbind(pred_ac, "nvar" = rowSums(combi))

# Fig 7: predicted risk (in %) of the two example patients under every model,
# plotted against the number of predictors in the model; one panel per patient
# with its own y axis.
pred_ac %>%
  as.data.frame() %>%
  # Long format: one row per model and patient.
  gather("Variable", "Prediction", -nvar) %>%
  mutate(
    Variable = plyr::revalue(
      Variable,
      c("low" = "Low-risk patient", "high" = "High-risk patient")
    )
  ) %>%
  ggplot(aes(x = nvar, y = Prediction * 100, fill = Variable)) +
  # Jitter the points so models with the same predictor count do not overlap.
  geom_jitter(shape = 21, size = 2, alpha = 0.7, color = "black") +
  scale_x_continuous(
    name = "Number of available predictors",
    breaks = seq(2, 12, 1)
  ) +
  scale_y_continuous(
    name = "Risk prediction (%)",
    breaks = c(seq(2, 10, 2), seq(30, 80, 10))
  ) +
  scale_fill_manual(values = c("firebrick3", "dodgerblue3")) +
  facet_wrap(~Variable, scales = "free_y", nrow = 2) +
  theme_bw(base_size = 12) +
  theme(
    text = element_text(size = 18, face = "bold"),
    legend.position = "none",
    strip.background = element_blank(),
    strip.text = element_text(size = 18, face = "bold")
  )
#ggsave("Fig7.pdf", width = 6, height = 6)
