# balance function after 2019

# Compute covariate balance diagnostics before and after restructuring
# (matching / weighting) of a dataset.
#
# Arguments
#   rawdata:  the full covariate dataset (matrix or data frame, one column
#             per covariate)
#   treat:    the vector of treatment assignments (1 = treated, 0 = control)
#             for the full dataset
#   matched:  vector of weights to apply to the full dataset to create the
#             restructured data:
#             --for matching without replacement these will all be 0's and 1's
#             --for one-to-one matching with replacement these will all be
#               non-negative integers
#             --for IPTW or more complicated matching methods these could be
#               any non-negative numbers
#   estimand: one of "ATT", "ATC" or "ATE"; selects the denominator of the
#             standardized difference (treated sd, control sd, or the root of
#             the average of the two variances)
#
# Value
#   An object of class "balance": a list with
#     diff.means.raw      K x 5 matrix for the raw data
#     diff.means.matched  K x 5 matrix for the restructured data
#     covnames            covariate names (generated as v1, v2, ... if absent)
#     binary              1 if the covariate has at most two distinct values,
#                         0 otherwise
#   Matrix columns are (1) treat mean, (2) control mean, (3) diff in means,
#   (4) abs std diff, (5) ratio of sds (control over treated; only filled in
#   for covariates with more than two distinct values).
#
# NOTE(review): only the control units are weighted by `matched`; treated
# summaries are unweighted in both tables. Confirm this is what is wanted
# when the estimand is ATC or ATE.
balance <- function (rawdata, treat, matched, estimand="ATT")
{
  if (missing(rawdata)) stop("rawdata is required")
  if (missing(matched)) stop("argument matched is required")
  if (missing(treat)) stop("treatment vector (treat) is required")
  # reject unknown estimands instead of silently treating them as ATC
  estimand <- match.arg(estimand, c("ATT", "ATC", "ATE"))
  if (is.null(dim(rawdata)) || ncol(rawdata) < 1) {
    stop("rawdata must be a matrix or data frame with at least one column")
  }
  if (NROW(treat) != nrow(rawdata) || NROW(matched) != nrow(rawdata)) {
    stop("treat and matched must have one entry per row of rawdata")
  }
  # weighted variances come from Hmisc; fail loudly if it is unavailable
  if (!requireNamespace("Hmisc", quietly = TRUE)) {
    stop("package 'Hmisc' is required (wtd.var)")
  }
  cat("Balance diagnostics assume that the estimand is the",estimand,"\n")

  covnames <- colnames(rawdata)
  if (is.null(covnames)) {
    cat("No covariate names provided.  Generic names will be generated.")
    covnames <- paste("v", seq_len(ncol(rawdata)), sep = "")
  }
  K <- length(covnames)
  stat.names <- c("treat", "control", "unstd.diff", "abs.std.diff", "ratio")

  diff.means <- matrix(NA_real_, K, 5, dimnames = list(covnames, stat.names))
  diff.means.matched <- diff.means
  std.denom <- numeric(K)
  binary <- rep(1, K)

  is.trt <- treat == 1
  is.ctl <- treat == 0
  wts0 <- matched[is.ctl]

  for (i in seq_len(K)) {
    x.t <- rawdata[is.trt, i]
    x.c <- rawdata[is.ctl, i]
    # index by position: generated names do not exist in rawdata
    is.cont <- length(unique(rawdata[, i])) > 2
    if (is.cont) {
      binary[i] <- 0
    }

    # ---- balance on the RAW DATA ----
    mean.t <- mean(x.t)
    mean.c <- mean(x.c)
    var.t <- var(x.t)
    var.c <- var(x.c)
    # denominator in standardized difference calculations
    std.denom[i] <- switch(estimand,
                           ATE = sqrt((var.t + var.c) / 2),
                           ATT = sqrt(var.t),
                           ATC = sqrt(var.c))
    raw.diff <- mean.t - mean.c
    diff.means[i, 1:4] <- c(mean.t, mean.c, raw.diff,
                            abs(raw.diff / std.denom[i]))
    if (is.cont) {
      # ratio of sds (control over treated), continuous covariates only
      diff.means[i, 5] <- sqrt(var.c / var.t)
    }

    # ---- balance on the RESTRUCTURED DATA ----
    wmean.c <- weighted.mean(x.c, w = wts0)
    wvar.c <- Hmisc::wtd.var(x.c, weights = wts0)
    m.diff <- mean.t - wmean.c
    # the denominator is deliberately the one from the raw data so that the
    # raw and restructured standardized differences are on the same scale
    diff.means.matched[i, 1:4] <- c(mean.t, wmean.c, m.diff,
                                    abs(m.diff) / std.denom[i])
    if (is.cont) {
      diff.means.matched[i, 5] <- sqrt(wvar.c / var.t)
    }
  }

  out <- list(diff.means.raw = diff.means,
              diff.means.matched = diff.means.matched,
              covnames = covnames, binary = binary)
  class(out) <- "balance"
  return(out)
}


# Print method for objects of class "balance".
#
# Arguments
#   x:        a "balance" object (uses x$diff.means.raw and
#             x$diff.means.matched)
#   ...:      ignored
#   combined: if FALSE (default) print the full raw and matched tables one
#             after the other; otherwise print a single table holding the
#             absolute standardized differences and the sd ratios for the raw
#             and matched data side by side
#   digits:   number of decimal places used when rounding
#
# Missing entries are shown as "--". Returns x invisibly.
print.balance <- function(x, ..., combined=FALSE, digits= 2)
{
  if (!combined) {
    cat("Balance Statistics for Unmatched Data\n")
    cat("--\n")
    print(round(x$diff.means.raw, digits=digits), na.print="--")
    cat("--\n")
    cat("\n")
    cat("Balance Statistics for Matched Data\n")
    cat("--\n")
    print(round(x$diff.means.matched, digits=digits), na.print="--")
    cat("--\n")
    cat("\n")
  }
  else {
    cat("Balance Statistics\n")
    cat("--\n")
    # columns 4 and 5 of each table are abs.std.diff and ratio; drop = FALSE
    # keeps a one-covariate table as a matrix
    raw.part <- x$diff.means.raw[, 4:5, drop = FALSE]
    matched.part <- x$diff.means.matched[, 4:5, drop = FALSE]
    both <- cbind(raw.part[, 1, drop = FALSE], matched.part[, 1, drop = FALSE],
                  raw.part[, 2, drop = FALSE], matched.part[, 2, drop = FALSE])
    # label columns so raw and matched are distinguishable
    colnames(both) <- c("abs.std.diff.raw", "abs.std.diff.matched",
                        "ratio.raw", "ratio.matched")
    print(round(both, digits=digits), na.print="--")
  }
  invisible(x)
}

### NEXT NEED TO FIGURE OUT HOW TO REVERSE THE ORDER OF THE COVARIATES

# Plot method for objects of class "balance": dot plot of the (absolute)
# differences in means before (open circles) and after (filled circles)
# restructuring, one row per covariate, first covariate at the top.
#
# Arguments
#   x:            a "balance" object
#   longcovnames: optional axis labels; either one per plotted covariate or
#                 one per covariate in x (in which case they are subset to
#                 match which.covs)
#   which.covs:   "mixed"  plots all covariates as abs std diffs
#                 "binary" plots only binary covariates, as abs unstd diffs
#                 "cont"   plots only non-binary covariates, as abs std diffs
#   v.axis:       draw the value axis along the top?
#   cex.main, cex.vars, cex.pts: text / point sizes for the title, the
#                 covariate labels and the points
#   mar:          plot margins; the left margin is enlarged before plotting
#   plot:         if FALSE only an empty plotting region is set up
#   x.max:        optional upper limit of the horizontal axis (lower is 0)
#   ...:          passed to plot() only when plot = FALSE
#
# Value
#   list(raw = , matched = ) with the plotted values in plotting order
#   (i.e. reversed relative to the order of the covariates in x).
#
# Side effect: par(mar) is changed and is left changed on return.
plot.balance <- function(x, longcovnames=NULL, which.covs="mixed",
                v.axis=TRUE, cex.main=1, cex.vars=0.8, cex.pts=0.8,
                mar=c(4, 3, 5.1, 2), plot=TRUE, x.max = NULL,...)
{
  which.covs <- match.arg(which.covs, c("mixed", "binary", "cont"))

  n.all <- nrow(x$diff.means.raw)
  is.bin <- x$binary == 1
  keep <- switch(which.covs,
                 mixed = rep(TRUE, n.all),
                 binary = is.bin,
                 cont = !is.bin)
  if (!any(keep)) {
    stop("no covariates of type '", which.covs, "' to plot")
  }

  if (which.covs == "binary") {
    # binary covariates: absolute unstandardized difference
    pts <- abs(x$diff.means.raw[keep, 3])       # before matching
    pts2 <- abs(x$diff.means.matched[keep, 3])  # after matching
    main <- "Absolute Difference in Means"
  }
  else {
    # column 4 already holds the absolute standardized difference
    pts <- x$diff.means.raw[keep, 4]            # before matching
    pts2 <- x$diff.means.matched[keep, 4]       # after matching
    main <- "Absolute Standardized Difference in Means"
  }
  covnames <- x$covnames[keep]
  K <- length(pts)
  idx <- seq_len(K)

  if (is.null(longcovnames)) {
    longcovnames <- covnames
  }
  else if (length(longcovnames) == n.all && K < n.all) {
    # labels were given for every covariate: keep those being plotted
    longcovnames <- longcovnames[keep]
  }
  if (length(longcovnames) != K) {
    stop("longcovnames must have one entry per plotted covariate")
  }

  # widen the left margin for the covariate labels
  # NOTE(review): because min.mar is read back right after par(mar = mar),
  # the min() term is bounded by mar[2] itself, so label length has little
  # effect here -- confirm whether max() was intended.
  par(mar = mar)
  maxchar <- max(nchar(as.character(longcovnames)))
  min.mar <- par("mar")
  mar[2] <- min(min.mar[2], trunc(mar[2] + maxchar/10)) + mar[2] + 0.1
  par(mar = mar)

  ## now reverse the order of everything so the plot proceeds from
  ## top to bottom with respect to original ordering of variables
  pts <- rev(pts)
  pts2 <- rev(pts2)
  longcovnames <- rev(longcovnames)

  if (plot) {
    # xlim = NULL is plot()'s default, i.e. the range of the data
    x.range <- if (is.null(x.max)) NULL else c(0, x.max)
    plot(c(pts, pts2), c(idx, idx),
         bty="n", xlab="", ylab="",
         xaxt="n", yaxt="n", type="n",
         xlim=x.range,
         main=main, cex.main=cex.main)
    abline(v=0, lty=2)
    points(pts, idx, cex=cex.pts)           # before matching
    points(pts2, idx, pch=19, cex=cex.pts)  # after matching
    if (v.axis) {
      axis(3)
    }
    axis(2, at=idx, labels=longcovnames,
         las=2, hadj=1, lty=0, cex.axis=cex.vars)
  }
  else {
    plot(c(pts, pts2), c(idx, idx),
         bty="n", xlab="", ylab="",
         xaxt="n", yaxt="n",
         type="n", axes=FALSE,
         main="", cex.main=cex.main, ...)
  }
  return(list("raw"=pts, "matched"=pts2))
}