library(metap)
library(reshape)


#MCRTs (Algorithm 1), p-value combiners and synthetic data generators.


data_swd <- function(data,lag,sample_split=TRUE){
  
  # data: a stepped-wedge trial data set with columns s (treatment step), t (time) and y (outcome)
  # lag: the time lag of the treatment effect
  # sample_split: if the sample splitting in MCRTs (Algorithm 1) is used
  # Returns a list of (treated, control) data-set pairs, one per permutation test.
  
  num_steps <- max(data$t)
  num_tests <- num_steps - lag - 1
  # Guard both branches: previously only the sample-splitting branch checked this,
  # and the non-splitting branch silently ran the reversed loop 1:0 (j = 1, then j = 0).
  if (num_tests <= 0){
    stop("Error: Please enter a time lag < number of time steps - 1")
  }
  data_list <- list()
  if (sample_split == FALSE){
    for (j in seq_len(num_tests)){
      treated_ <- subset(data, s == j & t == (j + lag))
      control_ <- subset(data, s > (j + lag) & t == (j + lag))
      data_list[[j]] <- list(treated_, control_)
    }
  }else if (sample_split == TRUE){
    
    num_groups <- min(lag + 1, num_tests)
    
    # Partition the steps into groups whose members are at least lag+1 apart,
    # so the tests within one group use disjoint (time-shifted) observations.
    group_list <- list()
    for (g in seq_len(num_groups)){
      step <- g
      group <- step
      while ((step + lag + 1) <= num_steps){
        step <- step + lag + 1
        group <- c(group, step)
      }
      group_list[[length(group_list) + 1]] <- group
    }
    # Within each group, test step group[j] against all later steps of the group.
    for (group in group_list){
      group_size <- length(group)
      for (j in seq_len(group_size - 1)){
        treated_ <- subset(data, s == group[j] & t == (group[j] + lag))
        control_ <- subset(data, s %in% group[(j + 1):group_size] & t == (group[j] + lag))
        data_list[[length(data_list) + 1]] <- list(treated_, control_)
      }
    }
  } else{
    stop("Error: R stops due to an incorrect setup")
  }
  return(data_list)
}








crt_swd <- function(data_list,type,beta=0,exact=FALSE,b=1000,n_beta=1000,const=3,print=FALSE){
  
  # data_list: a list of data sets for the permutation tests from the function "data_swd"
  # type: 'p_values' to return a vector of p-values, 'ci' to return p-value matrices for CI inversion
  # beta: the constant treatment effect parameter in the null hypothesis
  # exact: if TRUE run the tests exactly, otherwise via Monte-Carlo
  # b: number of permutations (assignments) per test when running Monte-Carlo
  # const: [-const, const] is the beta range searched when inverting the tests to confidence intervals
  # n_beta: number of beta grid points on [-const, const]
  # print: if TRUE, report which steps each permutation test compares
  # For type == 'ci', returns the p-values on all grid points for "global_ci" to invert.

  d <- length(data_list)
  
  if (print == TRUE){
    cat("\n")
    cat("Number of tests:", d)
    cat("\n")
  }

  if (type == 'p_values'){
    # Preallocate rather than growing the vector inside the loop;
    # seq_len() is also safe when data_list is empty (1:d would loop over c(1, 0)).
    p_list <- numeric(d)
    for (j in seq_len(d)){
      data_i <- data_list[[j]]
      if (print == TRUE){
        cat("Permutation between steps", sort(unique(data_i[[1]]$s)), "and", sort(unique(data_i[[2]]$s)))
        cat("\n")
      }
      p_list[j] <- crt_p(data_i, beta, exact, b)
    }

  }else if (type == 'ci'){
    # Row j of each matrix holds the grid of one-sided p-values for test j.
    p_matrix_1 <- matrix(0, d, n_beta)
    p_matrix_2 <- matrix(0, d, n_beta)
    for (j in seq_len(d)){
      data_i <- data_list[[j]]
      if (print == TRUE){
        cat("Permutation between steps", sort(unique(data_i[[1]]$s)), "and", sort(unique(data_i[[2]]$s)))
        cat("\n")
      }
      out <- crt_ci(data_i, const, n_beta, exact, b)
      p_matrix_1[j,] <- out[[1]]
      p_matrix_2[j,] <- out[[2]]
    }
    
    p_list <- list(p_matrix_1, p_matrix_2)
  
  } else {
    stop("Error: R stops due to an incorrect setup")
  }

  return(p_list)
}








global_test <- function(data_list,p_list,method,alpha = 0.1){
  
  # data_list: a list of data sets from the function "data_swd" for computing the weights in the Z-score combiner
  # p_list: the p-values of all the tests from the function "crt_swd"
  # method: how to combine the p-values ('bonferroni', 'fisher' or 'zscore')
  # alpha: significance level
  # Returns TRUE when the intersection of the null hypotheses is rejected.
  
  # A single test needs no combining: compare it to alpha directly.
  if (length(p_list) == 1){
    return(p_list <= alpha)
  }
  
  if (method == 'bonferroni'){
    # Reject when the smallest p-value clears the corrected threshold.
    reject <- min(p_list) <= alpha / length(p_list)
  } else if (method == 'fisher'){
    # Fisher's method: combine via the sum of log p-values.
    combined_p <- sumlog(p_list, log.p = FALSE)[["p"]]
    reject <- combined_p <= alpha
  } else if (method == 'zscore'){
    # Weighted Z-score combiner: small p-values give negative z-scores,
    # so a very negative weighted sum yields a small combined p-value.
    weights <- optimal_weight(data_list)
    z_stat <- sum(weights * qnorm(p_list, mean = 0, sd = 1))
    reject <- pnorm(z_stat, 0, 1) <= alpha
  } else {
    stop("Error: R stops due to an incorrect setup")
  }
  
  return(reject)
}








global_ci <- function(data_list,p_list,method,alpha=0.1,const=3){
  
  #data_list: a list of data sets from the function "data_swd" for computing the weights in the Z-score combiner
  #p_list: the p-values on all the beta grids from the function "crt_swd" (a list of two tests-by-n_beta matrices)
  #const: input the range [-const, const] that we previously searched for beta's in the function "crt_swd"
  #method: choose how to combine the p-values (only 'fisher' or 'zscore' are supported when there are multiple tests)
  #alpha: significance level
  #return the confidence interval of the lagged effect
  
  # p_matrix_1 / p_matrix_2 hold the two one-sided p-values per (test, beta grid point);
  # presumably lower-tail and upper-tail respectively, matching crt_ci's
  # mean(counter_t < obs_t) and mean(counter_t > obs_t) — verify against crt_ci.
  p_matrix_1 <- p_list[[1]]
  p_matrix_2 <- p_list[[2]]
  significance <- 0
  n_beta <- NCOL(p_matrix_1)
  # Start the per-side acceptance level slightly above alpha/2 and relax it
  # downward (by 0.0025 per iteration) until the interval's achieved coverage
  # reaches 1 - alpha, or the level hits zero.
  alpha_0 <- alpha/2+0.05

  while (significance <(1-alpha) & alpha_0>= 0.00){
    # list1/list2 collect the beta grid values NOT rejected at level alpha_0
    # (one list per tail); list1_p/list2_p hold the corresponding combined p-values.
    list1 <- NULL
    list1_p <- NULL
    list2 <- NULL
    list2_p <- NULL
    
    
    if (NROW(p_matrix_1) >1){
      # Multiple tests: combine each grid column's p-values across tests.
      if (method=='zscore'){
        w <- optimal_weight(data_list)
        for (beta in 1:n_beta){
          # Map the grid index to the hypothesized effect in (-const, const].
          add <- 2*const*beta/n_beta -const
          
          T_1 <- sum(w * qnorm(p_matrix_1[,beta],mean=0,sd=1))
          p <- pnorm(T_1, 0, 1)
          if (p >=alpha_0){
            list1 <- c(list1,add)
            list1_p <- c(list1_p, p)
          }
          
          T_2 <- sum(w * qnorm(p_matrix_2[,beta],mean=0,sd=1))
          p <- pnorm(T_2, 0, 1)
          if (p >=alpha_0){
            list2 <- c(list2,add)
            list2_p <- c(list2_p,p)
          }
        }
      } else if(method=='fisher'){
        
        for (beta in 1:n_beta){
          add <- 2*const*beta/n_beta -const
          p <- sumlog(p_matrix_1[,beta],log.p = FALSE)['p'][[1]] 
          if (p >=alpha_0){
            list1 <- c(list1,add)
            list1_p <- c(list1_p,p)
          }
          
          p <- sumlog(p_matrix_2[,beta],log.p = FALSE)['p'][[1]] 
          if (p >=alpha_0){
            list2 <- c(list2,add)
            list2_p <- c(list2_p,p)
          }
        }
      } else {
        stop("Error: R stops due to an incorrect setup")
      }
    }else if(NROW(p_matrix_1) ==1){
      # Single test: no combining needed, use the raw p-values directly.
      for (beta in 1:n_beta){
        add <- 2*const*beta/n_beta -const
        p <- p_matrix_1[1,beta]
        if (p >=alpha_0){
          list1 <- c(list1,add)
          list1_p <- c(list1_p,p)
        }
        p <- p_matrix_2[1,beta]
        if (p >=alpha_0){
          list2 <- c(list2,add)
          list2_p <- c(list2_p,p)
        }
      } 
    }else {
      stop("Error: R stops due to an incorrect setup")
    }
    
    # Interval endpoints: the largest accepted beta from one tail and the
    # smallest from the other.
    # NOTE(review): if every grid point is rejected, list1/list2 are NULL and
    # max()/min() return -Inf/Inf with a warning — confirm the grid range
    # [-const, const] is wide enough for the application.
    u <- max(list1)
    l <- min(list2)
    
    # Achieved coverage of [l, u], estimated from the boundary p-values.
    ci_levels <- c(list2_p[list2==l],1-list1_p[list1==u])
    ci<-c(l,u)
    
    
    
    significance <- ci_levels[2]-ci_levels[1]
    alpha_0 <- alpha_0 - 0.0025
  }
  
  
  #cat("Significance Level:", significance)
  
  return(ci)
}





crt_p <- function(data_i, beta=0,exact=FALSE,b=1000){
  
  # data_i: list(treated, control) data sets for the i-th permutation test
  # beta: the constant treatment effect parameter in the null hypothesis
  # exact: if TRUE enumerate all assignments; otherwise Monte-Carlo
  # b: number of Monte-Carlo permutations
  # Returns a one-sided permutation p-value, clamped into (0, 1) so that
  # downstream combiners (Fisher's log, qnorm) stay finite.
  
  mu_treated <- data_i[[1]]$y
  mu_treated <- mu_treated[!is.na(mu_treated)]
  mu_control <- data_i[[2]]$y
  mu_control <- mu_control[!is.na(mu_control)]
  
  num1 <- length(mu_treated)
  num2 <- length(mu_control)
  
  # Each row of rr encodes one treated/control assignment as
  # difference-in-means weights (1/num1 treated, -1/num2 control).
  rr <- perm_matrix(num1, num2, exact, b)
  
  # Subtract the hypothesized constant effect from the treated arm,
  # then compare the observed statistic to its permutation distribution.
  x <- mu_treated - beta
  y <- mu_control
  obs_t <- mean(x) - mean(y)
  counter_t <- rr %*% c(x, y)
  p_value <- mean(counter_t > obs_t)
  # Clamp away from 0 and 1.
  p_value <- min(max(p_value, 0.0000001), 0.9999999)
  return(p_value)
}


crt_ci <- function(data_i,const,n_beta,exact,b){
  
  # data_i: list(treated, control) data sets for the i-th permutation test
  # const: [-const, const] is the beta range searched (same as used in "crt_swd")
  # exact: if TRUE enumerate all assignments; otherwise Monte-Carlo
  # n_beta: number of beta grid points on [-const, const]
  # b: number of Monte-Carlo permutations
  # Returns list(lower-tail p-values, upper-tail p-values), one entry per beta
  # grid point, each clamped into (0, 1).
  
  mu_treated <- data_i[[1]]$y
  mu_treated <- mu_treated[!is.na(mu_treated)]
  mu_control <- data_i[[2]]$y
  mu_control <- mu_control[!is.na(mu_control)]
  
  num1 <- length(mu_treated)
  num2 <- length(mu_control)

  # One shared permutation matrix reused across all grid points.
  rr <- perm_matrix(num1, num2, exact, b)
  
  # Preallocate instead of growing the vectors with c() on every iteration.
  vec1 <- numeric(n_beta)
  vec2 <- numeric(n_beta)
  
  for (beta in seq_len(n_beta)){
    # Map the grid index to a hypothesized effect in (-const, const].
    add <- 2*const*beta/n_beta - const
    
    y_1 <- mu_treated - add
    y_2 <- mu_control
    y_12 <- c(y_1, y_2)
   
    obs_t <- mean(y_1) - mean(y_2)
    counter_t <- rr %*% y_12
    p_value_1 <- mean(counter_t < obs_t)
    p_value_2 <- mean(counter_t > obs_t)
    vec1[beta] <- min(max(p_value_1, 0.0000001), 0.9999999)
    vec2[beta] <- min(max(p_value_2, 0.0000001), 0.9999999)
  }
  return(list(vec1, vec2))
}








perm_matrix <- function(num1,num2, exact,b){
  
  # Construct a matrix whose rows are weight vectors for computing the
  # difference-in-means statistic: 1/num1 on treated slots, -1/num2 on control
  # slots, so rr %*% y gives mean(treated) - mean(control) for each assignment.
  # num1/num2: treated/control sample sizes
  # exact: if TRUE enumerate all choose(num1+num2, num1) assignments, otherwise
  #        sample b random assignments
  # b: number of Monte-Carlo rows when exact == FALSE
  
  num <- num1 + num2
  if (exact == TRUE){
    group <- combn(1:num, num1)
    NC <- NCOL(group)
    rr <- matrix(-1/num2, nrow = NC, ncol = num)
    for (j in seq_len(NC)){
      # Bug fix: treated slots were weighted 1/num instead of 1/num1, which made
      # the exact permutation statistic inconsistent with the observed statistic
      # mean(x) - mean(y) (and with the Monte-Carlo branch below).
      rr[j, group[, j]] <- 1/num1
    }
  }else if (exact == FALSE){
    weights <- c(rep(1/num1, num1), rep(-1/num2, num2))
    rr <- matrix(0, nrow = b, ncol = num)
    for (j in seq_len(b)){
      rr[j,] <- sample(weights, size = num, replace = FALSE)
    }

  }else{
    stop("Error: R stops due to an incorrect setup")
  }
  return(rr)
}








optimal_weight <- function(data_list){
  
  # Compute the weights for the weighted Z-score combiner.
  # Each test's weight is proportional to the square root of its precision
  # (inverse of a scaled variance of the difference-in-means statistic).
  
  # Drop missing outcomes before counting or computing variances.
  clean <- function(values) {
    values[!is.na(values)]
  }
  
  # First pass: total number of non-missing observations across all tests.
  m_total <- 0
  for (pair in data_list) {
    m_total <- m_total + NROW(clean(pair[[1]]$y)) + NROW(clean(pair[[2]]$y))
  }
  
  # Second pass: one precision term per test.
  precision <- vapply(data_list, function(pair) {
    treated <- clean(pair[[1]]$y)
    control <- clean(pair[[2]]$y)
    1 / (var(treated) * m_total / NROW(control) + var(control) * m_total / NROW(treated))
  }, numeric(1))
  
  sqrt(precision / sum(precision))
}




synthetic <- function(num_individuals,num_steps,effect_vector,seed,weight=0,index=1,normalize=TRUE,include_x=TRUE){
  
  # Generate synthetic stepped-wedge data for Simulations 1 & 2.
  #
  # num_individuals: number of units
  # num_steps: number of treatment steps (time runs over 0..num_steps)
  # effect_vector: lagged treatment effects; effect_vector[g+1] is added at lag g = t - s
  # seed: RNG seed so repeated calls with the same arguments reproduce the same data
  # weight: mixing weight between the linear time trend and the nonlinear term
  # index: nonlinear time trend (1: quadratic, 2: exponential, 3: tanh)
  # normalize: if TRUE, subtract each unit's baseline (t == 0) outcome from its trajectory
  # include_x: if TRUE, the covariate x enters the time trend
  # Returns a long-format data frame with columns id, t, y, s, x and lag dummies z_0, z_1, ...
  
  # Bug fix: the seed argument used to be accepted but never applied.
  set.seed(seed)
  
  data_matrix <- matrix(0, nrow = num_individuals, ncol = num_steps + 1)
  x <- rnorm(num_individuals, 0, 0.25)
  random_effect <- rnorm(num_individuals, 0, 0.25)
  
  # Assign units to treatment-start steps in (nearly) equal-sized groups,
  # then shuffle the assignment across units.
  subsample <- round(num_individuals/num_steps)
  design <- c(rep(subsample, num_steps - 1), num_individuals - subsample*(num_steps - 1))
  s <- NULL
  for (t in 1:num_steps){
    s <- c(s, rep(t, design[t]))
  }
  s <- sample(s, size = num_individuals, replace = FALSE)

  for (t in 0:num_steps){
    
    # Baseline time trend, optionally shifted by the covariate.
    if (include_x == TRUE){
      factor_t <- 0.5*(x + t)
    }else{
      factor_t <- 0.5*t
    }
    if (index == 1){
      add <- 1*factor_t*factor_t
    }else if (index == 2){
      add <- 2*exp(factor_t/2)
    }else if (index == 3){
      add <- 5*tanh(factor_t)
    }else{
      stop("Error: R stops due to an incorrect setup")
    }
    
    noise <- rnorm(num_individuals, 0, 0.1)
    
    # Mix linear and nonlinear trends; add unit random effect and observation noise.
    data_matrix[, t + 1] <- (1 - weight)*factor_t + weight*add + random_effect + noise
    
    # Add the lagged treatment effect for units whose treatment started at step s[i] <= t.
    for (i in seq_len(num_individuals)){
      gap <- t - s[i]
      if (gap >= 0){
        data_matrix[i, t + 1] <- data_matrix[i, t + 1] + effect_vector[gap + 1]
      }
    }
  }
  
  # Reshape from wide (one column per time point) to long format.
  x_data <- data.frame(x)
  x_data$id <- seq.int(num_individuals)
  
  s_data <- data.frame(s)
  s_data$id <- seq.int(num_individuals)
  
  mydata <- data.frame(data_matrix)
  colnames(mydata) <- c(0:num_steps)
  mydata$id <- seq.int(num_individuals)
  
  mdata <- melt(mydata, id = c("id"))
  colnames(mdata)[2] <- "t"
  colnames(mdata)[3] <- "y"
  mdata <- mdata[order(mdata$id),]
  
  total <- merge(s_data, x_data, by = "id")
  final <- merge(mdata, total, by = "id")
  
  # melt() returns t as a factor with levels "0".."num_steps" in order,
  # so the level index minus one recovers the numeric time.
  final$t <- as.numeric(final$t) - 1
  lag <- final$t - final$s
  num <- NROW(final)
  # One-hot lag dummies: z_g == 1 when the row is observed g steps after treatment start.
  z_matrix <- matrix(0, nrow = num, ncol = num_steps)
  for (j in seq_len(num)){
    if (lag[j] >= 0){
      z_matrix[j, lag[j] + 1] <- 1
    }
  }
  z_data <- data.frame(z_matrix)
  for (i in 0:(NCOL(z_data) - 1)){
    colnames(z_data)[i + 1] <- paste0("z_", i)
  }
  data_xyz <- cbind(final, z_data)
  
  if (normalize == TRUE){
    # Subtract each unit's t == 0 outcome so trajectories start at zero.
    rows <- data_xyz$id  # loop-invariant; hoisted out of the per-unit loop
    for (i in seq_len(num_individuals)){
      counter <- subset(data_xyz, id == i & t == 0)
      data_xyz[rows == i,]$y <- data_xyz[rows == i,]$y - counter$y
    }
  }
  
  return(data_xyz)
}















