Spaces:
Running
Running
| # | |
| # ============================================================ | |
| # app.R | Shiny App for Rerandomization with fastrerandomize | |
| # ============================================================ | |
| # 1) The user can upload or simulate a covariate dataset (X). | |
| # 2) They specify rerandomization parameters: n_treated, acceptance prob, etc. | |
| # 3) The app generates a set of accepted randomizations under rerandomization. | |
| # 4) The user can optionally upload or simulate outcomes (Y) and run a randomization test. | |
| # 5) The app displays distribution of the balance measure (e.g., Hotelling's T^2) | |
| # and final p-value/fiducial interval, along with run-time comparisons between | |
| # fastrerandomize and base R methods. | |
| # | |
| # ---------------------------- | |
| # Load required packages | |
| # ---------------------------- | |
| options(error=NULL) | |
| library(shiny) | |
| library(shinydashboard) | |
| library(DT) # For data tables | |
| library(ggplot2) # For basic plotting | |
| library(fastrerandomize) # Our rerandomization package | |
| library(parallel) # For detecting CPU cores | |
| # For production apps, ensure fastrerandomize is installed: | |
| # install.packages("devtools") | |
| # devtools::install_github("cjerzak/fastrerandomize-software/fastrerandomize") | |
| # --------------------------------------------------------- | |
| # HELPER FUNCTIONS (BASE R) | |
| # --------------------------------------------------------- | |
# 1) Compute Hotelling's T^2 in base R for a single assignment vector.
#
#    T^2 = (n0 * n1 / (n0 + n1)) * d' S^{-1} d
#
#    where d is the treated-minus-control difference in covariate means and
#    S is cov(X) computed over all units (used in place of a pooled estimate).
#
#    Arguments:
#      X  numeric matrix of covariates (rows = units, columns = covariates)
#      W  0/1 assignment vector with one entry per row of X
#    Returns:
#      a single numeric value, or NA_real_ when either arm is empty.
baseR_hotellingT2 <- function(X, W) {
  n <- length(W)
  n1 <- sum(W)
  n0 <- n - n1
  if (n1 == 0 || n0 == 0) {
    return(NA_real_) # one arm is empty, so the statistic is undefined
  }

  xbar_treat <- colMeans(X[W == 1, , drop = FALSE])
  xbar_control <- colMeans(X[W == 0, , drop = FALSE])
  diff_vec <- xbar_treat - xbar_control

  S <- cov(X)
  Sinv <- tryCatch(solve(S), error = function(e) NULL)
  if (is.null(Sinv)) {
    # Fallback when S cannot be inverted: diagonal approximation.
    # A covariate with zero (or non-finite) variance carries no balance
    # information, so it gets weight 0. Using 1 / 0 = Inf here would turn
    # the whole statistic into NaN (0 * Inf) for every assignment.
    variances <- diag(S)
    inv_var <- numeric(length(variances))
    usable <- is.finite(variances) & variances > 0
    inv_var[usable] <- 1 / variances[usable]
    Sinv <- diag(inv_var, nrow = length(inv_var))
  }

  (n0 * n1 / (n0 + n1)) * c(t(diff_vec) %*% Sinv %*% diff_vec)
}
# 2) Generate candidate randomizations in base R and keep the best-balanced
#    fraction according to Hotelling's T^2 (lower = better balance).
#
#    Arguments:
#      n_units      total number of units
#      n_treated    number of treated units in every assignment
#      X            numeric covariate matrix (n_units rows)
#      accept_prob  quantile of the T^2 distribution used as acceptance cutoff
#      random_type  "exact" enumerates all assignments; any other value
#                   draws random permutations (Monte Carlo)
#      max_draws    number of Monte Carlo draws (ignored for "exact")
#      batch_size   draws evaluated per batch (ignored for "exact")
#    Returns:
#      list(randomizations = accepted 0/1 matrix, one row per assignment,
#           balance        = T^2 of each accepted row)
baseR_generate_randomizations <- function(n_units, n_treated, X, accept_prob, random_type,
                                          max_draws, batch_size) {
  # T^2 of every row of a 0/1 assignment matrix.
  balance_of_rows <- function(assignments) {
    vapply(
      seq_len(nrow(assignments)),
      function(i) baseR_hotellingT2(X, assignments[i, ]),
      numeric(1)
    )
  }

  if (random_type == "exact") {
    # -------------- EXACT RANDOMIZATIONS --------------
    # For safety, warn when full enumeration is likely to explode.
    n_comb_total <- choose(n_units, n_treated)
    if (n_comb_total > 1e6) {
      warning(
        sprintf(
          paste0(
            "Exact randomization is requested, but that is %s combinations.\n",
            "This may be infeasible in terms of memory/time.\n",
            "Consider Monte Carlo instead."
          ),
          format(n_comb_total, big.mark = ",")
        ),
        immediate. = TRUE
      )
    }

    treated_idx <- combn(n_units, n_treated) # one column per assignment
    n_comb <- ncol(treated_idx)
    assignment_mat <- matrix(0, nrow = n_comb, ncol = n_units)
    # Set all treated cells at once via (row, column) index pairs.
    assignment_mat[cbind(
      rep(seq_len(n_comb), each = nrow(treated_idx)),
      as.vector(treated_idx)
    )] <- 1
    balance <- balance_of_rows(assignment_mat)
  } else {
    # -------------- MONTE CARLO RANDOMIZATIONS --------------
    if (is.null(max_draws) || is.null(batch_size) ||
        is.na(max_draws) || is.na(batch_size) ||
        max_draws < 1 || batch_size < 1) {
      stop("max_draws and batch_size must be positive for Monte Carlo randomization.",
           call. = FALSE)
    }
    base_assign <- c(rep(1, n_treated), rep(0, n_units - n_treated))
    batch_count <- ceiling(max_draws / batch_size)
    # Preallocate per-batch storage instead of growing vectors in the loop.
    batch_assign <- vector("list", batch_count)
    batch_balance <- vector("list", batch_count)
    drawn <- 0
    for (b in seq_len(batch_count)) {
      n_here <- min(batch_size, max_draws - drawn)
      drawn <- drawn + n_here
      perms <- matrix(0, nrow = n_here, ncol = n_units)
      for (j in seq_len(n_here)) {
        perms[j, ] <- sample(base_assign)
      }
      batch_assign[[b]] <- perms
      batch_balance[[b]] <- balance_of_rows(perms)
    }
    assignment_mat <- do.call(rbind, batch_assign)
    balance <- unlist(batch_balance)
  }

  # Drop assignments whose balance could not be computed.
  valid <- !is.na(balance)
  assignment_mat <- assignment_mat[valid, , drop = FALSE]
  balance <- balance[valid]
  if (length(balance) == 0) {
    return(list(randomizations = assignment_mat, balance = balance))
  }

  # Acceptance threshold. The comparison is inclusive so that the draw
  # sitting exactly on the cutoff is kept: accept_prob = 1 accepts every
  # assignment, and ties at the cutoff cannot empty the accepted set.
  cutoff <- quantile(balance, probs = accept_prob)
  accepted <- balance <= cutoff

  list(
    randomizations = assignment_mat[accepted, , drop = FALSE],
    balance = balance[accepted]
  )
}
# Helper: difference in mean outcome between the treated group (W == 1)
# and the control group (W == 0).
diff_in_means <- function(Y, W) {
  treated_mean <- mean(Y[W == 1])
  control_mean <- mean(Y[W == 0])
  treated_mean - control_mean
}
# Helper: difference in means that would be observed under assignment
# `Wprime` if the constant treatment effect were exactly `tau`.
#   1. Strip `tau` from the units treated under `obsW` to get the implied
#      control outcomes.
#   2. Add `tau` back for the units treated under `Wprime`.
#   3. Take the difference in means with respect to `Wprime`.
compute_diff_at_tau_for_oneW <- function(Wprime, obsY, obsW, tau) {
  control_outcome <- obsY - obsW * tau
  treated_now <- Wprime == 1
  relabelled <- control_outcome
  relabelled[treated_now] <- control_outcome[treated_now] + tau
  diff_in_means(relabelled, Wprime)
}
# 3a) Base R randomization test on the difference in means.
#     Returns list(p_value, tau_obs, FI):
#       p_value  share of candidate assignments in `allW` whose absolute
#                difference in means on `obsY` is at least the observed one
#       tau_obs  observed difference in means under `obsW`
#       FI       result of baseR_find_fiducial_interval() when `findFI` is
#                TRUE, otherwise NULL
baseR_randomization_test <- function(obsW, obsY, allW, findFI = FALSE, alpha = 0.05) {
  observed_effect <- diff_in_means(obsY, obsW)

  # Each row of allW is handed to diff_in_means() as its W argument.
  reference_effects <- apply(allW, MARGIN = 1, FUN = diff_in_means, Y = obsY)
  at_least_as_extreme <- abs(reference_effects) >= abs(observed_effect)

  interval <- if (findFI) {
    baseR_find_fiducial_interval(obsW, obsY, allW, observed_effect, alpha = alpha)
  } else {
    NULL
  }

  list(p_value = mean(at_least_as_extreme), tau_obs = observed_effect, FI = interval)
}
# 3b) Fiducial interval for the base R randomization test.
#     1) Start from a bracket of half-width 3 * |tau_obs| around tau_obs and
#        refine both ends with stochastic (Robbins-Monro style) updates, each
#        driven by one randomly chosen assignment from `allW`.
#     2) Grid-search 100 values of tau around the refined bracket and keep
#        those the two-sided randomization test does not reject.
#
#     Arguments:
#       obsW, obsY         observed assignment and outcomes
#       allW               candidate assignments, one per row
#       tau_obs            observed difference in means
#       alpha              test level (interval keeps tau with
#                          min(tail fractions) > alpha / 2)
#       c_initial          kept for interface compatibility; currently unused
#       n_search_attempts  number of bracket-update steps
#     Returns:
#       c(lower, upper), or c(NA, NA) when no tau on the grid is accepted or
#       the inputs do not allow a search (no candidate rows, non-finite tau_obs).
baseR_find_fiducial_interval <- function(obsW, obsY, allW, tau_obs, alpha = 0.05, c_initial = 2,
                                         n_search_attempts = 500) {
  n_allW <- nrow(allW)
  if (is.null(n_allW) || n_allW < 1 || !is.finite(tau_obs)) {
    return(c(NA_real_, NA_real_))
  }

  # Initial bracket. abs() keeps lower <= upper when tau_obs is negative.
  half_width <- 3 * abs(tau_obs)
  lowerBound_est <- tau_obs - half_width
  upperBound_est <- tau_obs + half_width

  z_alpha <- qnorm(1 - alpha)
  k <- 2 / (z_alpha * (2 * pi)^(-1 / 2) * exp(-z_alpha^2 / 2))

  for (step_t in seq_len(n_search_attempts)) {
    # One random candidate assignment drives both updates of this step.
    Wprime <- allW[sample.int(n_allW, 1), ]

    # ~~~~~ update lowerBound ~~~~~
    # Positive delta moves the bound up with weight alpha / 2; otherwise it
    # moves down with weight 1 - alpha / 2.
    tau_at_step_lower <- compute_diff_at_tau_for_oneW(Wprime, obsY, obsW, lowerBound_est)
    delta <- tau_obs - tau_at_step_lower
    weight_lower <- if (tau_at_step_lower < tau_obs) alpha / 2 else 1 - alpha / 2
    lowerBound_est <- lowerBound_est + k * delta * weight_lower / step_t

    # ~~~~~ update upperBound ~~~~~
    # Mirror image: positive delta2 moves the bound down with weight alpha / 2.
    tau_at_step_upper <- compute_diff_at_tau_for_oneW(Wprime, obsY, obsW, upperBound_est)
    delta2 <- tau_at_step_upper - tau_obs
    weight_upper <- if (tau_at_step_upper > tau_obs) alpha / 2 else 1 - alpha / 2
    upperBound_est <- upperBound_est - k * delta2 * weight_upper / step_t
  }

  # Grid search. The grid always extends outward from the bracket:
  # one unit below the lower end, and |upper| above the upper end
  # (equal to doubling for a positive upper end).
  bracket <- range(lowerBound_est, upperBound_est)
  grid_lower <- bracket[1] - 1
  grid_upper <- bracket[2] + abs(bracket[2])
  tau_seq <- seq(grid_lower, grid_upper, length.out = 100)

  accepted <- vapply(tau_seq, function(tau_pseudo) {
    # Distribution of the difference in means if the true effect were tau_pseudo.
    diffs_pseudo <- apply(allW, 1, compute_diff_at_tau_for_oneW,
                          obsY = obsY, obsW = obsW, tau = tau_pseudo)
    frac_ge <- mean(diffs_pseudo >= tau_obs)
    frac_le <- mean(diffs_pseudo <= tau_obs)
    # Two-sided test: tau_pseudo is kept when neither tail is too small.
    min(frac_ge, frac_le) > alpha / 2
  }, logical(1))
  accepted <- !is.na(accepted) & accepted # undefined comparisons count as rejected

  if (!any(accepted)) {
    return(c(NA_real_, NA_real_))
  }
  c(min(tau_seq[accepted]), max(tau_seq[accepted]))
}
| # --------------------------------------------------------- | |
| # UI Section: a shinydashboard page with a linked header, a three-item sidebar menu plus a Share button, and one tab body per menu item | |
| # --------------------------------------------------------- | |
| ui <- dashboardPage( | |
| # ========== Header: the title is a link to the package site, opened in a new tab ========== | |
| dashboardHeader( | |
| title = span( | |
| style = "font-weight: 600; font-size: 14px;", | |
| a( | |
| href = "https://fastrerandomize.github.io/", | |
| "fastrerandomize.github.io", | |
| target = "_blank", | |
| style = "color: white; text-decoration: underline; font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;" | |
| ) | |
| ) | |
| ), | |
| # ========== Sidebar: each menuItem's tabName must match a tabItem in the body ========== | |
| dashboardSidebar( | |
| sidebarMenu( | |
| menuItem("1. Data & Covariates", tabName = "datatab", icon = icon("database")), | |
| menuItem("2. Generate Randomizations", tabName = "gennet", icon = icon("random")), | |
| menuItem("3. Randomization Test", tabName = "randtest", icon = icon("flask")), | |
| # ---- Share button: raw HTML for the button followed by inline JS that wires its click handler ---- | |
| # Both pieces are wrapped in HTML(...) so Shiny emits them verbatim; the script is placed | |
| # after the button markup because it looks the button up by its id ("share-button"). | |
| tags$div( | |
| style = "text-align: left; margin: 1em 0 1em 1em;", | |
| HTML(' | |
| <button id="share-button" | |
| style=" | |
| display: inline-flex; | |
| align-items: center; | |
| justify-content: center; | |
| gap: 8px; | |
| padding: 5px 10px; | |
| font-size: 16px; | |
| font-weight: normal; | |
| color: #000; | |
| background-color: #fff; | |
| border: 1px solid #ddd; | |
| border-radius: 6px; | |
| cursor: pointer; | |
| box-shadow: 0 1.5px 0 #000; | |
| "> | |
| <svg width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor" | |
| stroke-width="2" stroke-linecap="round" stroke-linejoin="round"> | |
| <circle cx="18" cy="5" r="3"></circle> | |
| <circle cx="6" cy="12" r="3"></circle> | |
| <circle cx="18" cy="19" r="3"></circle> | |
| <line x1="8.59" y1="13.51" x2="15.42" y2="17.49"></line> | |
| <line x1="15.41" y1="6.51" x2="8.59" y2="10.49"></line> | |
| </svg> | |
| <strong>Share</strong> | |
| </button> | |
| '), | |
| # Click handler: Web Share API when available, else clipboard copy, else a hidden textarea with execCommand('copy') | |
| tags$script( | |
| HTML(" | |
| (function() { | |
| const shareBtn = document.getElementById('share-button'); | |
| // Reusable helper function to show a small “Copied!” message | |
| function showCopyNotification() { | |
| const notification = document.createElement('div'); | |
| notification.innerText = 'Copied to clipboard'; | |
| notification.style.position = 'fixed'; | |
| notification.style.bottom = '20px'; | |
| notification.style.right = '20px'; | |
| notification.style.backgroundColor = 'rgba(0, 0, 0, 0.8)'; | |
| notification.style.color = '#fff'; | |
| notification.style.padding = '8px 12px'; | |
| notification.style.borderRadius = '4px'; | |
| notification.style.zIndex = '9999'; | |
| document.body.appendChild(notification); | |
| setTimeout(() => { notification.remove(); }, 2000); | |
| } | |
| shareBtn.addEventListener('click', function() { | |
| const currentURL = window.location.href; | |
| const pageTitle = document.title || 'Check this out!'; | |
| // If browser supports Web Share API | |
| if (navigator.share) { | |
| navigator.share({ | |
| title: pageTitle, | |
| text: '', | |
| url: currentURL | |
| }) | |
| .catch((error) => { | |
| console.log('Sharing failed', error); | |
| }); | |
| } else { | |
| // Fallback: Copy URL | |
| if (navigator.clipboard && navigator.clipboard.writeText) { | |
| navigator.clipboard.writeText(currentURL).then(() => { | |
| showCopyNotification(); | |
| }, (err) => { | |
| console.error('Could not copy text: ', err); | |
| }); | |
| } else { | |
| // Double fallback for older browsers | |
| const textArea = document.createElement('textarea'); | |
| textArea.value = currentURL; | |
| document.body.appendChild(textArea); | |
| textArea.select(); | |
| try { | |
| document.execCommand('copy'); | |
| showCopyNotification(); | |
| } catch (err) { | |
| alert('Please copy this link:\\n' + currentURL); | |
| } | |
| document.body.removeChild(textArea); | |
| } | |
| } | |
| }); | |
| })(); | |
| ") | |
| ) | |
| ) | |
| # ---- End of Share button snippet ---- | |
| ) | |
| ), | |
| # ========== Body: page-level CSS plus one tabItem per sidebar entry ========== | |
| dashboardBody( | |
| # Page-level CSS: smaller grey helper text, red Shiny output errors, spacing below inputs | |
| tags$head( | |
| tags$style(HTML(" | |
| .smalltext { font-size: 90%; color: #555; } | |
| .shiny-output-error { color: red; } | |
| .shiny-input-container { margin-bottom: 15px; } | |
| ")) | |
| ), | |
| tabItems( | |
| # ------------------------------------------------ | |
| # 1) Data & Covariates Tab: choose upload vs simulation (panels toggle on input.data_source) and preview X | |
| # ------------------------------------------------ | |
| tabItem( | |
| tabName = "datatab", | |
| fluidRow( | |
| box(width = 5, title = "Covariate Data: Upload or Simulate", | |
| status = "primary", solidHeader = TRUE, | |
| radioButtons("data_source", "Data Source:", | |
| choices = c("Upload CSV" = "upload", | |
| "Simulate data" = "simulate"), | |
| selected = "simulate"), | |
| conditionalPanel( | |
| condition = "input.data_source == 'upload'", | |
| fileInput("file_covariates", "Choose CSV File", | |
| accept = c(".csv")), | |
| helpText("Columns = features/covariates, rows = units.") | |
| ), | |
| conditionalPanel( | |
| condition = "input.data_source == 'simulate'", | |
| numericInput("sim_n", "Number of units (rows)", | |
| value = 64, min = 10), | |
| numericInput("sim_p", "Number of covariates (columns)", | |
| value = 32, min = 2), | |
| actionButton("simulate_btn", "Simulate X") | |
| ) | |
| ), | |
| box(width = 7, title = "Preview of Covariates (X)", | |
| status = "info", solidHeader = TRUE, | |
| DTOutput("covariates_table")) | |
| ) | |
| ), | |
| # ------------------------------------------------ | |
| # 2) Generate Randomizations Tab: parameters on the left, summary boxes, histogram and hardware note on the right | |
| # ------------------------------------------------ | |
| tabItem( | |
| tabName = "gennet", | |
| fluidRow( | |
| box(width = 4, title = "Rerandomization Parameters", | |
| status = "primary", solidHeader = TRUE, | |
| numericInput("n_treated", "Number Treated (n_treated)", | |
| value = 10, min = 1), | |
| selectInput("random_type", "Randomization Type:", | |
| choices = c("Monte Carlo" = "monte_carlo", | |
| "Exact" = "exact"), | |
| selected = "monte_carlo"), | |
| numericInput("accept_prob", "Acceptance Probability (stringency)", | |
| value = 0.01, min = 0.0001, max = 1), | |
| conditionalPanel( | |
| condition = "input.random_type == 'monte_carlo'", | |
| numericInput("max_draws", "Max Draws (MC)", value = 1e5, min = 1e3), | |
| numericInput("batch_size", "Batch Size (MC)", value = 1e3, min = 1e2) | |
| ), | |
| actionButton("generate_btn", "Generate Randomizations") | |
| ), | |
| box(width = 8, title = "Summary of Accepted Randomizations", | |
| status = "info", solidHeader = TRUE, | |
| # First row of boxes: number of accepted randomizations and minimum balance measure | |
| fluidRow( | |
| column(width = 6, valueBoxOutput("n_accepted_box", width = 12)), | |
| column(width = 6, valueBoxOutput("balance_min_box", width = 12)) | |
| ), | |
| # Second row of boxes: generation time for fastrerandomize and for base R | |
| fluidRow( | |
| column(width = 6, valueBoxOutput("fastrerand_time_box", width = 12)), | |
| column(width = 6, valueBoxOutput("baseR_time_box", width = 12)) | |
| ), | |
| br(), | |
| plotOutput("balance_hist", height = "250px"), | |
| # Hardware info note (filled by the server's hardware_info output) | |
| br(), | |
| uiOutput("hardware_info") | |
| ) | |
| ) | |
| ), | |
| # ------------------------------------------------ | |
| # 3) Randomization Test Tab: outcome source (panels toggle on input.outcome_source), run button, and results | |
| # ------------------------------------------------ | |
| tabItem( | |
| tabName = "randtest", | |
| fluidRow( | |
| box(width = 4, title = "Randomization Test Setup", | |
| status = "primary", solidHeader = TRUE, | |
| radioButtons("outcome_source", "Outcome Data (Y):", | |
| choices = c("Simulate Y" = "simulate", | |
| "Upload CSV" = "uploadY"), | |
| selected = "simulate"), | |
| conditionalPanel( | |
| condition = "input.outcome_source == 'simulate'", | |
| numericInput("true_tau", "True Effect (simulate)", 1, step = 0.5), | |
| numericInput("noise_sd", "Noise SD for Y", 0.5, step = 0.1), | |
| actionButton("simulateY_btn", "Simulate Y") | |
| ), | |
| conditionalPanel( | |
| condition = "input.outcome_source == 'uploadY'", | |
| fileInput("file_outcomes", "Choose CSV File with outcome vector Y", | |
| accept = c(".csv")), | |
| helpText("Single column with length = #units.") | |
| ), | |
| br(), | |
| actionButton("run_randtest_btn", "Run Randomization Test"), | |
| checkboxInput("findFI", "Compute Fiducial Interval?", value = FALSE) | |
| ), | |
| box(width = 8, title = "Test Results", status = "info", solidHeader = TRUE, | |
| # First row: p-value and observed effect from the fastrerandomize test | |
| fluidRow( | |
| column(width = 6, valueBoxOutput("pvalue_box", width = 12)), | |
| column(width = 6, valueBoxOutput("tauobs_box", width = 12)) | |
| ), | |
| # Second row: test run time for fastrerandomize and for base R | |
| fluidRow( | |
| column(width = 6, valueBoxOutput("fastrerand_test_time_box", width = 12)), | |
| column(width = 6, valueBoxOutput("baseR_test_time_box", width = 12)) | |
| ), | |
| # Fiducial interval text for the fastrerandomize result | |
| uiOutput("fi_text"), | |
| # Base R results, separated from the rows above by a horizontal rule | |
| tags$hr(), | |
| fluidRow( | |
| column(width = 6, valueBoxOutput("pvalue_box_baseR", width = 12)), | |
| column(width = 6, valueBoxOutput("tauobs_box_baseR", width = 12)) | |
| ), | |
| fluidRow( | |
| column(width = 12, uiOutput("fi_text_baseR")) | |
| ), | |
| br(), | |
| plotOutput("test_plot", height = "280px") | |
| ) | |
| ) | |
| ) | |
| ) # end tabItems | |
| ) # end dashboardBody | |
| ) # end dashboardPage | |
| # --------------------------------------------------------- | |
| # SERVER | |
| # --------------------------------------------------------- | |
# Server logic for the app.
#
# State is held in reactiveVal()s: the covariate matrix X, the outcome
# vector Y, the accepted randomizations and test results from both
# implementations (fastrerandomize and base R), and their run times.
# Button observers fill that state; render* outputs only read it.
# User-facing problems are reported with showNotification() so that an
# observer never fails silently and never raises an unhandled error.
server <- function(input, output, session) {
  # -------------------------------------------------------
  # Local helpers
  # -------------------------------------------------------
  # TRUE when `x` is a single finite whole number.
  is_whole_number <- function(x) {
    is.numeric(x) && length(x) == 1 && is.finite(x) && x == round(x)
  }

  # Run `task()` with errors captured as condition objects, and time it.
  timed_try <- function(task) {
    started <- Sys.time()
    value <- tryCatch(task(), error = function(e) e)
    list(value = value, elapsed = difftime(Sys.time(), started, units = "secs"))
  }

  # Value box showing an elapsed time in seconds, or "---" before any run.
  time_box <- function(elapsed, label, color) {
    shown <- if (is.null(elapsed)) "---" else round(as.numeric(elapsed), 3)
    valueBox(shown, label, icon = icon("clock"), color = color)
  }

  # Heading plus "[lower, upper]" text for a test result that carries an FI.
  fi_block <- function(result, heading) {
    if (is.null(result) || is.null(result$FI)) {
      return(NULL)
    }
    fi_lower <- round(result$FI[1], 4)
    fi_upper <- round(result$FI[2], 4)
    tagList(
      strong(heading),
      p(sprintf("[%.4f, %.4f]", fi_lower, fi_upper))
    )
  }

  # -------------------------------------------------------
  # 1. Covariate Data Handling
  # -------------------------------------------------------
  # The covariate matrix X, shared by generation and outcome simulation.
  X_data <- reactiveVal(NULL)

  # Uploaded covariates: keep numeric columns only and refuse missing values.
  observeEvent(input$file_covariates, {
    req(input$file_covariates)
    df <- tryCatch(
      read.csv(input$file_covariates$datapath, header = TRUE),
      error = function(e) NULL
    )
    if (is.null(df)) {
      showNotification("Could not read the covariate CSV file.", type = "error")
      return(NULL)
    }
    numeric_cols <- vapply(df, is.numeric, logical(1))
    if (!any(numeric_cols) || nrow(df) < 2) {
      showNotification(
        "The covariate file needs at least one numeric column and two rows.",
        type = "error"
      )
      return(NULL)
    }
    if (!all(numeric_cols)) {
      showNotification(
        paste("Dropped non-numeric column(s):",
              paste(names(df)[!numeric_cols], collapse = ", ")),
        type = "warning"
      )
    }
    X <- as.matrix(df[, numeric_cols, drop = FALSE])
    if (anyNA(X)) {
      showNotification("Covariates contain missing values; please remove them first.",
                       type = "error")
      return(NULL)
    }
    X_data(X)
  })

  # Simulated covariates: i.i.d. N(0, 1) entries.
  observeEvent(input$simulate_btn, {
    n <- input$sim_n
    p <- input$sim_p
    if (!is_whole_number(n) || !is_whole_number(p) || n < 2 || p < 1) {
      showNotification("Number of units and covariates must be positive whole numbers.",
                       type = "error")
      return(NULL)
    }
    X_data(matrix(rnorm(n * p), nrow = n, ncol = p))
  })

  # Preview of X, numeric columns shown to 3 significant digits.
  output$covariates_table <- renderDT({
    req(X_data())
    df <- as.data.frame(X_data())
    numeric_cols <- vapply(df, is.numeric, logical(1))
    df[numeric_cols] <- lapply(df[numeric_cols], signif, digits = 3)
    datatable(df, options = list(scrollX = TRUE, pageLength = 10))
  })

  # -------------------------------------------------------
  # 2. Generate Rerandomizations
  # -------------------------------------------------------
  # Accepted randomizations from fastrerandomize and from base R,
  # together with the time each implementation took.
  RerandResult <- reactiveVal(NULL)
  RerandResult_base <- reactiveVal(NULL)
  fastrand_time <- reactiveVal(NULL)
  baseR_time <- reactiveVal(NULL)

  observeEvent(input$generate_btn, {
    req(X_data())
    X <- X_data()
    n_units <- nrow(X)
    n_treated <- input$n_treated
    # Both arms must be non-empty. validate()/need() is not used here
    # because inside an observer it stops silently without any message.
    if (!is_whole_number(n_treated) || n_treated < 1 || n_treated >= n_units) {
      showNotification(
        "Number treated must be a whole number between 1 and (number of units - 1).",
        type = "error"
      )
      return(NULL)
    }
    is_mc <- identical(input$random_type, "monte_carlo")
    max_draws <- if (is_mc) input$max_draws else NULL
    batch_size <- if (is_mc) input$batch_size else NULL

    withProgress(message = "Computing results...", value = 0, {
      # =========== 1) fastrerandomize generation timing ===========
      fast <- timed_try(function() {
        generate_randomizations(
          n_units = n_units,
          n_treated = n_treated,
          X = X,
          randomization_accept_prob = input$accept_prob,
          randomization_type = input$random_type,
          max_draws = max_draws,
          batch_size = batch_size,
          verbose = FALSE
        )
      })
      if (inherits(fast$value, "error")) {
        showNotification(
          paste("Error generating randomizations (fastrerandomize):",
                conditionMessage(fast$value)),
          type = "error"
        )
        RerandResult(NULL)
      } else {
        RerandResult(fast$value)
      }
      fastrand_time(fast$elapsed)
      incProgress(0.5)

      # =========== 2) base R generation timing ===========
      base <- timed_try(function() {
        baseR_generate_randomizations(
          n_units = n_units,
          n_treated = n_treated,
          X = X,
          accept_prob = input$accept_prob,
          random_type = input$random_type,
          max_draws = max_draws,
          batch_size = batch_size
        )
      })
      if (inherits(base$value, "error")) {
        showNotification(
          paste("Error generating randomizations (base R):",
                conditionMessage(base$value)),
          type = "error"
        )
        RerandResult_base(NULL)
      } else {
        RerandResult_base(base$value)
      }
      baseR_time(base$elapsed)
      incProgress(0.5)
    })
  })

  # Summaries of accepted randomizations (fastrerandomize result)
  output$n_accepted_box <- renderValueBox({
    rr <- RerandResult()
    if (is.null(rr) || is.null(rr$randomizations)) {
      valueBox("0", "Accepted Randomizations", icon = icon("ban"), color = "red")
    } else {
      valueBox(nrow(rr$randomizations), "Accepted Randomizations",
               icon = icon("check"), color = "green")
    }
  })

  output$balance_min_box <- renderValueBox({
    rr <- RerandResult()
    if (is.null(rr) || is.null(rr$balance)) {
      valueBox("---", "Min Balance Measure", icon = icon("question"), color = "orange")
    } else {
      valueBox(round(min(rr$balance), 4), "Min Balance Measure",
               icon = icon("thumbs-up"), color = "blue")
    }
  })

  # Timings for generation: fastrerandomize and base R
  output$fastrerand_time_box <- renderValueBox({
    time_box(fastrand_time(), "fastrerandomize generation time (secs)", "teal")
  })

  output$baseR_time_box <- renderValueBox({
    time_box(baseR_time(), "base R generation time (secs)", "lime")
  })

  # Histogram of the balance measure (fastrerandomize result)
  output$balance_hist <- renderPlot({
    rr <- RerandResult()
    req(rr, rr$balance)
    df <- data.frame(balance = rr$balance)
    spread <- diff(range(df$balance))
    # A zero bin width (all accepted balances equal, or a single accepted
    # randomization) is rejected by ggplot2, so use one bin in that case.
    hist_layer <- if (is.finite(spread) && spread > 0) {
      geom_histogram(binwidth = spread / 30, fill = "darkblue", alpha = 0.7)
    } else {
      geom_histogram(bins = 1, fill = "darkblue", alpha = 0.7)
    }
    ggplot(df, aes(x = balance)) +
      hist_layer +
      labs(title = "Distribution of Balance Measure",
           x = "Balance (e.g. T^2)",
           y = "Frequency") +
      theme_minimal(base_size = 14)
  })

  # Hardware info (CPU cores, GPU note)
  output$hardware_info <- renderUI({
    num_cores <- detectCores(logical = TRUE)
    HTML(paste(
      "<strong>System Hardware Info:</strong><br/>",
      "Number of CPU cores detected:", num_cores, "<br/>",
      "With additional CPU or GPU, greater speedups can be expected.<br/>",
      "Note: Speedups greatest in high-dimensional or large-N settings.<br/>"
    ))
  })

  # -------------------------------------------------------
  # 3. Randomization Test
  # -------------------------------------------------------
  Y_data <- reactiveVal(NULL)

  # (A) Simulated outcomes: Y = X %*% beta + tau * W + noise, where W is the
  #     first accepted fastrerandomize assignment (treated as "observed").
  observeEvent(input$simulateY_btn, {
    req(RerandResult())
    rr <- RerandResult()
    if (is.null(rr$randomizations) || nrow(rr$randomizations) < 1) {
      showNotification("No accepted randomizations found. Cannot simulate Y for the 'observed' assignment.", type = "error")
      return(NULL)
    }
    true_tau <- input$true_tau
    noise_sd <- input$noise_sd
    if (!is.numeric(true_tau) || length(true_tau) != 1 || is.na(true_tau) ||
        !is.numeric(noise_sd) || length(noise_sd) != 1 || is.na(noise_sd) ||
        noise_sd < 0) {
      showNotification("True effect must be a number and noise SD a non-negative number.",
                       type = "error")
      return(NULL)
    }
    obsW <- rr$randomizations[1, ]
    nunits <- length(obsW)
    Xval <- X_data()
    if (is.null(Xval)) {
      showNotification("No covariate data found to help simulate outcomes. Using intercept-only model.", type = "warning")
      Xval <- matrix(0, nrow = nunits, ncol = 1)
    }
    # X may have been replaced after the randomizations were generated.
    if (nrow(Xval) != nunits) {
      showNotification(
        "Covariates no longer match the generated randomizations. Please regenerate randomizations.",
        type = "error"
      )
      return(NULL)
    }
    beta <- rnorm(ncol(Xval), 0, 1) # random coefficients
    linear_part <- Xval %*% beta
    Ysim <- as.numeric(linear_part + obsW * true_tau + rnorm(nunits, 0, noise_sd))
    Y_data(Ysim)
  })

  # (B) Uploaded outcomes: a single numeric column without a header row.
  observeEvent(input$file_outcomes, {
    req(input$file_outcomes)
    dfy <- tryCatch(
      read.csv(input$file_outcomes$datapath, header = FALSE),
      error = function(e) NULL
    )
    if (is.null(dfy)) {
      showNotification("Could not read the outcome CSV file.", type = "error")
      return(NULL)
    }
    if (ncol(dfy) > 1) {
      showNotification("Please provide a single-column CSV for Y.", type = "error")
      return(NULL)
    }
    y <- suppressWarnings(as.numeric(dfy[[1]]))
    if (length(y) == 0 || anyNA(y)) {
      showNotification(
        "Y must contain only numeric values (no header row, no missing entries).",
        type = "error"
      )
      return(NULL)
    }
    Y_data(y)
  })

  # Test results and run times for both implementations
  RandTestResult <- reactiveVal(NULL)
  RandTestResult_base <- reactiveVal(NULL)
  fastrand_test_time <- reactiveVal(NULL)
  baseR_test_time <- reactiveVal(NULL)

  observeEvent(input$run_randtest_btn, {
    req(RerandResult())
    rr <- RerandResult()
    req(rr$randomizations)
    if (is.null(Y_data())) {
      showNotification("No outcome data Y found. Upload or simulate first.", type = "error")
      return(NULL)
    }
    obsW <- rr$randomizations[1, ]
    obsY <- Y_data()
    # Guard against silent recycling when Y and the assignment differ in length.
    if (length(obsY) != length(obsW)) {
      showNotification(
        sprintf("Outcome vector has %d values but the assignment has %d units.",
                length(obsY), length(obsW)),
        type = "error"
      )
      return(NULL)
    }

    withProgress(message = "Computing results...", value = 0, {
      # =========== 1) fastrerandomize randomization_test timing ===========
      fast <- timed_try(function() {
        randomization_test(
          obsW = obsW,
          obsY = obsY,
          candidate_randomizations = rr$randomizations,
          findFI = input$findFI
        )
      })
      if (inherits(fast$value, "error")) {
        showNotification(
          paste("Error in randomization_test (fastrerandomize):",
                conditionMessage(fast$value)),
          type = "error"
        )
        RandTestResult(NULL)
      } else {
        RandTestResult(fast$value)
      }
      fastrand_test_time(fast$elapsed)
      incProgress(0.5)

      # =========== 2) base R randomization test timing ===========
      rr_base <- RerandResult_base()
      if (is.null(rr_base) || is.null(rr_base$randomizations) ||
          nrow(rr_base$randomizations) < 1) {
        showNotification("No base R randomizations found. Cannot run base R test.", type = "error")
        RandTestResult_base(NULL)
      } else {
        base <- timed_try(function() {
          baseR_randomization_test(
            obsW = obsW,
            obsY = obsY,
            allW = rr_base$randomizations,
            findFI = input$findFI
          )
        })
        if (inherits(base$value, "error")) {
          showNotification(
            paste("Error in randomization_test (base R):",
                  conditionMessage(base$value)),
            type = "error"
          )
          RandTestResult_base(NULL)
        } else {
          RandTestResult_base(base$value)
        }
        baseR_test_time(base$elapsed)
      }
      incProgress(0.5)
    })
  })

  # p-value and observed tau from the fastrerandomize test
  output$pvalue_box <- renderValueBox({
    rt <- RandTestResult()
    if (is.null(rt)) {
      valueBox("---", "p-value (fastrerandomize)", icon = icon("question"), color = "blue")
    } else {
      valueBox(round(rt$p_value, 4), "p-value (fastrerandomize)",
               icon = icon("list-check"), color = "purple")
    }
  })

  output$tauobs_box <- renderValueBox({
    rt <- RandTestResult()
    if (is.null(rt)) {
      valueBox("---", "Observed Effect (fastrerandomize)", icon = icon("question"), color = "maroon")
    } else {
      valueBox(round(rt$tau_obs, 4), "Observed Effect (fastrerandomize)",
               icon = icon("bullseye"), color = "maroon")
    }
  })

  # Times for randomization test
  output$fastrerand_test_time_box <- renderValueBox({
    time_box(fastrand_test_time(), "fastrerandomize test time (secs)", "teal")
  })

  output$baseR_test_time_box <- renderValueBox({
    time_box(baseR_test_time(), "base R test time (secs)", "lime")
  })

  # Fiducial intervals, shown only when the corresponding result has one
  output$fi_text <- renderUI({
    fi_block(RandTestResult(), "Fiducial Interval (fastrerandomize, 95%):")
  })

  output$fi_text_baseR <- renderUI({
    fi_block(RandTestResult_base(), "Fiducial Interval (base R, 95%):")
  })

  # Simple plot of the observed effect. The full randomization distribution
  # is not stored by either method, so only the point estimate is drawn,
  # with a dashed reference line at zero.
  output$test_plot <- renderPlot({
    rt <- RandTestResult()
    if (is.null(rt)) {
      plot.new()
      title("No test results yet.")
      return(NULL)
    }
    obs_val <- rt$tau_obs
    ggplot(data.frame(x = obs_val, y = 0), aes(x, y)) +
      geom_point(size = 4, color = "red") +
      xlim(c(obs_val - abs(obs_val) * 2 - 1, obs_val + abs(obs_val) * 2 + 1)) +
      labs(title = "Observed Treatment Effect (fastrerandomize)",
           x = "Effect Size", y = "") +
      theme_minimal(base_size = 14) +
      geom_vline(xintercept = 0, linetype = "dashed", color = "gray40")
  })
}
| # --------------------------------------------------------- | |
| # Run the Application: combine the `ui` and `server` objects into a Shiny app object | |
| # --------------------------------------------------------- | |
| shinyApp(ui = ui, server = server) | |