
# Run this script in a fresh R session instead of clearing the workspace with
# rm(list = ls()): clearing objects does not unload packages or reset options,
# so it does not give a truly clean state.

library(rjags)
library(random)
library(R2WinBUGS)
library(doParallel)

# The rest of the script relies on this file providing data.patients and
# data.total; fail early with a clear message if either is missing.
stopifnot(all(c("data.patients", "data.total") %in% load("SOFA-data.RData")))

# 30-point Gauss-Legendre quadrature on [-1, 1]: nodes xk and weights wk.
# They are used in the JAGS model to integrate the hazard over [0, Time].
xk <- c(-0.996893484074649540272,-0.98366812327974720997,-0.960021864968307512217,-0.926200047429274325879,
        -0.882560535792052681543,-0.829565762382768397443,-0.767777432104826194918,-0.697850494793315796932,
        -0.6205261829892428611405,-0.5366241481420198992642,-0.4470337695380891767806,-0.352704725530878113471,
        -0.2546369261678898464398,-0.1538699136085835469638,-0.051471842555317695833,0.051471842555317695833,
        0.153869913608583546964,0.25463692616788984644,0.352704725530878113471,0.4470337695380891767806,
        0.536624148142019899264,0.62052618298924286114,0.697850494793315796932,0.767777432104826194918,
        0.829565762382768397443,0.882560535792052681543,0.926200047429274325879,0.960021864968307512217,
        0.98366812327974720997,0.996893484074649540272)

wk <- c(0.007968192496166605615,0.018466468311090959142,0.0287847078833233693497,0.038799192569627049597,
        0.048402672830594052903,0.057493156217619066482,0.065974229882180495128,0.073755974737705206268,
        0.0807558952294202153547,0.0868997872010829798024,0.0921225222377861287176,0.0963687371746442596395,
        0.099593420586795267063,0.101762389748405504596,0.102852652893558840341,0.1028526528935588403413,
        0.1017623897484055045964,0.099593420586795267063,0.096368737174644259639,0.092122522237786128718,
        0.0868997872010829798024,0.080755895229420215355,0.07375597473770520627,0.065974229882180495127,
        0.057493156217619066482,0.048402672830594052903,0.038799192569627049597,0.02878470788332336935,
        0.0184664683110909591423,0.007968192496166605615)

# Guard against transcription errors in the hand-typed table. A Gauss-Legendre
# rule must integrate low-order polynomials on [-1, 1] exactly:
# int 1 dx = 2, int x dx = 0, int x^2 dx = 2/3.
stopifnot(length(xk) == 30L, length(wk) == 30L,
          isTRUE(all.equal(sum(wk), 2)),
          isTRUE(all.equal(sum(wk * xk), 0)),
          isTRUE(all.equal(sum(wk * xk^2), 2 / 3)))

# Patients x days matrix of log(sofa1) values (NA where there is no measurement).
# Row i holds the i-th patient's measurements, in the order the ids first appear
# in data.total, placed on days 1..M1[i]. M1[i] is that patient's number of
# measurements.
# NOTE(review): assumes rows of data.patients follow the same first-appearance
# order of ids as data.total -- confirm against the data preparation.
#
# Counts are taken from the same grouping as the rows. table() alone would sort
# the ids and silently misalign M1 with the rows when data.total is not sorted
# by id.
sofa.by.id <- split(data.total$sofa1,
                    factor(data.total$id, levels = unique(data.total$id)))
stopifnot(length(sofa.by.id) == nrow(data.patients))
M1 <- as.numeric(lengths(sofa.by.id))
ysofa <- matrix(NA, nrow(data.patients), max(data.total$day))
for (i in seq_along(sofa.by.id)) {
  ysofa[i, seq_len(M1[i])] <- log(sofa.by.id[[i]])
}

# Data adjustments made before model fitting (the reasons are not recorded in
# this script -- TODO confirm):
#  * patient 131 (original note: "3 SOFA") is dropped from the longitudinal data;
#  * patient 12 (original note: "6 SOFA") keeps only days 1-3 and is recoded as
#    censored (status 0, neither death nor alive discharge) at time 3.
# The unmodified rows are saved first in ysofa.131 / ysofa.12.
ysofa.131 <- ysofa[131, ]
ysofa.12 <- ysofa[12, ]

ysofa[12, 4:6] <- NA
ysofa <- ysofa[-131, ]

M1 <- replace(M1[-131], 12, 3)

data.patients$status[12] <- 0
data.patients$time[12] <- 3

# Data list handed to jags.model(). Patient 131 is excluded from every
# per-patient vector, matching its removal from ysofa and M1.
data.JAGS <- local({
  keep <- -131
  n.pat <- nrow(data.patients) - 1
  status <- data.patients$status[keep]

  list(
    N = n.pat,                     # number of patients
    M = M1,                        # number of longitudinal measurements
    y = ysofa,                     # longitudinal data: log(SOFA+1)
    day = 1:max(data.total$day),   # range of days
    age = data.patients$age[keep], # patient age
    Time = data.patients$time[keep], # time-to-event data
    K = length(xk),                # 30-point Gauss-Legendre quadrature
    xk = xk,                       # quadrature nodes
    wk = wk,                       # quadrature weights
    zeros = numeric(n.pat),        # zeros trick: observed zeros ...
    C = 50000,                     # ... and the constant added to -loglik
    eventD = 1 * (status == 1),    # event: death in the ICU
    eventA = 1 * (status == 2)     # event: alive at hospital discharge
  )
})

##########################################
#      Cumulative incidence function     #
##########################################
# library() errors immediately if a package is missing; require() would only
# warn and the script would fail later with a less obvious error.
library(survival)
library(reshape2)
library(ggplot2)

# Nonparametric cumulative incidence of the two competing events (multi-state
# survfit) over all rows of data.patients.
# NOTE(review): patient 131, excluded from the joint model, is still included
# here, and patient 12 enters with its recoded status/time -- confirm intent.
data <- data.frame(time = data.patients$time, status = data.patients$status)
CIF <- survfit(Surv(time, status, type = "mstate") ~ 1, data = data)

# NOTE(review): assumes pstate/lower/upper columns 1 and 2 are the states for
# status 1 (death) and status 2 (alive discharge). The column order of pstate
# has changed between survival versions -- check CIF$states.
surv.data <- data.frame(
  day = CIF$time,
  prev1 = CIF$pstate[, 1], prev1.lo = CIF$lower[, 1], prev1.up = CIF$upper[, 1],
  prev2 = CIF$pstate[, 2], prev2.lo = CIF$lower[, 2], prev2.up = CIF$upper[, 2]
)
# Append a row of zeros (day 0, incidence 0) as the curves' starting point.
surv.data <- rbind(surv.data, rep(0, 7))

# Long format (day, variable, value) for plotting both incidence curves.
surv.data2 <- melt(surv.data[, c("day", "prev1", "prev2")], id.vars = "day")
levels(surv.data2$variable) <- c("Death", "Alive discharge")

##########################################
#             SOFA and SOFA*             #
##########################################

# Long-format copies of ysofa: one row per (patient, day), NA where there is no
# measurement. id is the row index of ysofa, not the original data.total id.
#   sofa : back-transformed values, exp(y) - 1
#   sofa1: values on the modelled log scale
# Patient 131 was removed from ysofa, so its status must be dropped too;
# otherwise status is one patient too long and data.frame() fails on the length
# mismatch. Sizes come from ysofa rather than hard-coded 30 / 139.
sofa <- data.frame(y = as.vector(t(exp(ysofa) - 1)),
                   time = rep(seq_len(ncol(ysofa)), times = nrow(ysofa)),
                   id = rep(seq_len(nrow(ysofa)), each = ncol(ysofa)),
                   status = rep(data.patients$status[-131], each = ncol(ysofa)))

sofa1 <- data.frame(y = as.vector(t(ysofa)),
                    time = rep(seq_len(ncol(ysofa)), times = nrow(ysofa)),
                    id = rep(seq_len(nrow(ysofa)), each = ncol(ysofa)),
                    status = rep(data.patients$status[-131], each = ncol(ysofa)))
##########################################


# BUGS/JAGS model, written as an R function so that write.model() can export it.
# The model is:
#   - a linear mixed model for y (log scale) with a patient random intercept
#     b[i,1] and random slope b[i,2] in day, plus an age effect;
#   - two cause-specific hazards, D (death) and A (alive discharge), of the form
#     h(t) = nu * t^(nu-1) * exp(lambda + gamma*age + alpha0*b1 + alpha1*b2*t),
#     linked to the longitudinal part through the random effects.
# The likelihood of the survival part is entered via the zeros trick.
jointmodel <- function(){
  
  for(i in 1:N){
    # LONGITUDINAL SUBMODEL (observations 1..M[i] only)
    for(j in 1:M[i]){
      mu.y[i,j] <- beta[1]+b[i,1] + (beta[2]+b[i,2])*day[j] + beta[3]*age[i]
      y[i,j] ~ dnorm(mu.y[i,j],tau)
    }
    
    # COMPETING RISKS SUBMODEL
    # Hazards at the K quadrature nodes mapped from [-1, 1] to [0, Time[i]],
    # u_j = Time[i]/2*(xk[j]+1); used only to integrate the cumulative hazard.
    for(j in 1:K){
      # hazard function for dead
      hD[i,j] <- nuD*pow(Time[i]/2*(xk[j]+1), nuD-1) *
        exp( lambdaD + gammaD*age[i] + alpha0D*b[i,1] + alpha1D*b[i,2]*(Time[i]/2*(xk[j]+1)) )
      # hazard function for alive
      hA[i,j] <- nuA*pow(Time[i]/2*(xk[j]+1), nuA-1) *
        exp( lambdaA + gammaA*age[i] + alpha0A*b[i,1] + alpha1A*b[i,2]*(Time[i]/2*(xk[j]+1)) )
    }
    
    # Cumulative hazard H(t) = int_0^t h(u) du ~= (t/2) * sum_j wk[j]*h(u_j)
    # (inprod already returns a scalar; no sum() needed)
    cumHazD[i] <- (Time[i]/2) * inprod(wk, hD[i,])
    cumHazA[i] <- (Time[i]/2) * inprod(wk, hA[i,])
    log.SurvD[i] <- -cumHazD[i]
    log.SurvA[i] <- -cumHazA[i]
    
    # Hazards at the observed time Time[i] itself, used for the event term.
    # The last quadrature node (hD[i,K]) lies strictly before Time[i]
    # (at Time[i]/2*(xk[K]+1)), so it is not the hazard at the event time.
    hD.T[i] <- nuD*pow(Time[i], nuD-1) *
      exp( lambdaD + gammaD*age[i] + alpha0D*b[i,1] + alpha1D*b[i,2]*Time[i] )
    hA.T[i] <- nuA*pow(Time[i], nuA-1) *
      exp( lambdaA + gammaA*age[i] + alpha0A*b[i,1] + alpha1A*b[i,2]*Time[i] )
    
    # Zeros trick: phi = C - log-likelihood is the Poisson mean of an observed 0
    phi[i] <- C - ( eventD[i]*log(hD.T[i]) + eventA[i]*log(hA.T[i]) ) - (log.SurvD[i] + log.SurvA[i])
    zeros[i] ~ dpois(phi[i])
    
    # Random effects
    b[i,1] ~ dnorm(0,tau0)
    b[i,2] ~ dnorm(0,tau1)
  }
  
  # PRIORS AND HYPERPRIORS (dnorm uses precision; sig* are standard deviations)
  for(i in 1:3){beta[i] ~ dnorm(0,0.001)}
  tau <- pow(sig,-2)
  sig ~ dunif(0,20)
  tau0 <- pow(sig0,-2)
  sig0 ~ dunif(0,10)
  tau1 <- pow(sig1,-2)
  sig1 ~ dunif(0,10)  
  
  nuD ~ dgamma(0.1,0.1)
  nuA ~ dgamma(0.1,0.1)
  lambdaD ~ dnorm(0,0.001)
  lambdaA ~ dnorm(0,0.001)
  gammaD ~ dnorm(0,0.001)
  gammaA ~ dnorm(0,0.001)
  alpha0D ~ dnorm(0,0.001)
  alpha0A ~ dnorm(0,0.001)
  alpha1D ~ dnorm(0,0.001)
  alpha1A ~ dnorm(0,0.001)  
}

# Export the model function as a BUGS/JAGS model file read by jags.model().
filename.JAGS <- "JM.appl.txt"
write.model(jointmodel, filename.JAGS)

# Nodes monitored by coda.samples(): fixed effects, standard deviations,
# hazard parameters, association parameters and the random effects b.
params.JAGS <- c(
  "beta", "sig", "sig0", "sig1",
  "nuD", "nuA",
  "lambdaD", "lambdaA",
  "gammaD", "gammaA",
  "alpha0D", "alpha0A",
  "alpha1D", "alpha1A",
  "b"
)

# One MCMC chain per seed; register as many parallel workers as there are
# chains, so the two numbers cannot drift apart.
seeds <- c(34, 73, 1500)
registerDoParallel(length(seeds))
jags.inits <- function(i) {
  # Initial-value list for chain i: only the RNG is specified. Each chain uses
  # the "lecuyer::RngStream" generator with its own seed from `seeds`.
  setNames(list("lecuyer::RngStream", seeds[i]), c(".RNG.name", ".RNG.seed"))
}
# Combine function for foreach(.combine = "mcmc.combine", .multicombine = TRUE).
# It receives the per-task results (each the return value of coda.samples() in
# the %dopar% body) and turns them into a single coda mcmc.list.
# NOTE(review): each argument is already an mcmc.list. Confirm that
# sapply(..., mcmc) yields one mcmc object per chain with the installed coda
# version (e.g. check nchain(jags.parsamples) == length(seeds)).
mcmc.combine <- function(...){
  return(as.mcmc.list(sapply(list(...), mcmc)))
}

# MCMC settings: ni sampled iterations per chain, na adaptation iterations
# (20% of ni), thinning nt chosen so that about 500 draws per chain are kept.
ni <- 200000; na <- 0.2 * ni; nt <- max(1, floor(2 * ni / 1000))

# Run one JAGS chain per seed in parallel, each with its own L'Ecuyer RNG
# stream, and combine the chains with mcmc.combine(); the elapsed time is
# stored in time.jags. Iterating over the seeds (not the worker count)
# guarantees every chain gets a valid seed whatever the number of workers.
time.jags <- system.time(
  jags.parsamples <- foreach(
    i = seq_along(seeds), .inorder = FALSE, .packages = c("rjags", "random"),
    .combine = "mcmc.combine", .multicombine = TRUE) %dopar% {
      load.module("lecuyer")
      model.jags <- jags.model(data = data.JAGS, file = filename.JAGS,
                               inits = jags.inits(i), n.adapt = na)
      coda.samples(model.jags, variable.names = params.JAGS,
                   n.iter = ni, thin = nt)
    }
)

#-----------MCMC Convergence----------
# Trace/density plots and univariate Gelman-Rubin diagnostics for all monitored
# nodes. These include every random effect b[i, k], so expect many panels.
plot(jags.parsamples)
gelman.diag(jags.parsamples, multivariate = FALSE)
#-------------------------------------


# Stack all chains into one draws x nodes matrix, then group its columns by
# monitored parameter: sims.list$<name> holds the draws of <name> (a vector for
# scalar nodes, a matrix for indexed nodes such as beta[1..3] or b[i, k]).
bss <- do.call(rbind, jags.parsamples)
sims.list <- vector("list", length(params.JAGS))
names(sims.list) <- params.JAGS

for (p in seq_along(params.JAGS)) {
  # Match the exact name or "name[...]". A bare "^name" prefix would also
  # capture "sig0"/"sig1" for "sig" and every "beta[...]" column for "b".
  iik <- grep(paste0("^", params.JAGS[p], "(\\[|$)"), colnames(bss))
  sims.list[[p]] <- bss[, iik]
}

