cjerzak committed on
Commit
3caf059
·
verified ·
1 Parent(s): 5ceabec

Update app.R

Browse files
Files changed (1) hide show
  1. app.R +300 -197
app.R CHANGED
@@ -1,3 +1,4 @@
 
1
  # ============================================================
2
  # app.R | Shiny App for Rerandomization with fastrerandomize
3
  # ============================================================
@@ -5,8 +6,10 @@
5
  # 2) They specify rerandomization parameters: n_treated, acceptance prob, etc.
6
  # 3) The app generates a set of accepted randomizations under rerandomization.
7
  # 4) The user can optionally upload or simulate outcomes (Y) and run a randomization test.
8
- # 5) The app displays distribution of the balance measure (e.g., Hotelling's T^2) and final p-value/fiducial interval.
9
-
 
 
10
  # ----------------------------
11
  # Load required packages
12
  # ----------------------------
@@ -20,6 +23,125 @@ library(fastrerandomize) # Our rerandomization package
20
  # install.packages("devtools")
21
  # devtools::install_github("cjerzak/fastrerandomize-software/fastrerandomize")
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  # ---------------------------------------------------------
24
  # UI Section
25
  # ---------------------------------------------------------
@@ -80,8 +202,8 @@ ui <- dashboardPage(
80
 
81
  conditionalPanel(
82
  condition = "input.data_source == 'simulate'",
83
- numericInput("sim_n", "Number of units (rows)", value = 20, min = 2),
84
- numericInput("sim_p", "Number of covariates (columns)", value = 3, min = 1),
85
  actionButton("simulate_btn", "Simulate X")
86
  )
87
  ),
@@ -89,17 +211,6 @@ ui <- dashboardPage(
89
  box(width = 7, title = "Preview of Covariates (X)",
90
  status = "info", solidHeader = TRUE,
91
  DTOutput("covariates_table"))
92
- ),
93
-
94
- # Performance info for data steps
95
- fluidRow(
96
- box(width = 12, title = "Performance Info for Data Steps", status = "warning", solidHeader = TRUE,
97
- p("Time to upload X (CSV):"),
98
- textOutput("time_data_upload"),
99
- br(),
100
- p("Time to simulate X:"),
101
- textOutput("time_data_sim")
102
- )
103
  )
104
  ),
105
 
@@ -122,7 +233,7 @@ ui <- dashboardPage(
122
  value = 0.05, min = 0.0001, max = 1),
123
  conditionalPanel(
124
  condition = "input.random_type == 'monte_carlo'",
125
- numericInput("max_draws", "Max Draws (MC)", value = 1e4, min = 1e3),
126
  numericInput("batch_size", "Batch Size (MC)", value = 1e3, min = 1e2)
127
  ),
128
  actionButton("generate_btn", "Generate Randomizations")
@@ -131,19 +242,14 @@ ui <- dashboardPage(
131
  box(width = 8, title = "Summary of Accepted Randomizations",
132
  status = "info", solidHeader = TRUE,
133
  fluidRow(
134
- valueBoxOutput("n_accepted_box", width = 6),
135
- valueBoxOutput("balance_min_box", width = 6)
 
 
136
  ),
137
  br(),
138
  plotOutput("balance_hist", height = "250px")
139
  )
140
- ),
141
-
142
- # Performance info for randomization generation
143
- fluidRow(
144
- box(width = 12, title = "Performance Info for Generation", status = "warning", solidHeader = TRUE,
145
- textOutput("time_generate")
146
- )
147
  )
148
  ),
149
 
@@ -182,27 +288,15 @@ ui <- dashboardPage(
182
 
183
  box(width = 8, title = "Test Results", status = "info", solidHeader = TRUE,
184
  fluidRow(
185
- valueBoxOutput("pvalue_box", width = 6),
186
- valueBoxOutput("tauobs_box", width = 6)
 
 
187
  ),
188
  uiOutput("fi_text"),
189
  br(),
190
  plotOutput("test_plot", height = "280px")
191
  )
192
- ),
193
-
194
- # Performance info for randomization test
195
- fluidRow(
196
- box(width = 12, title = "Performance Info for Randomization Test", status = "warning", solidHeader = TRUE,
197
- p("Time to upload Y (CSV):"),
198
- textOutput("time_data_uploadY"),
199
- br(),
200
- p("Time to simulate Y:"),
201
- textOutput("time_data_simY"),
202
- br(),
203
- p("Time to run randomization test:"),
204
- textOutput("time_randtest")
205
- )
206
  )
207
  )
208
 
@@ -215,31 +309,18 @@ ui <- dashboardPage(
215
  # ---------------------------------------------------------
216
  server <- function(input, output, session) {
217
 
218
- # -- ReactiveVals to store performance times (seconds)
219
- time_data_upload <- reactiveVal(NA_real_)
220
- time_data_sim <- reactiveVal(NA_real_)
221
- time_generate <- reactiveVal(NA_real_)
222
- time_data_uploadY <- reactiveVal(NA_real_)
223
- time_data_simY <- reactiveVal(NA_real_)
224
- time_randtest <- reactiveVal(NA_real_)
225
-
226
  # -------------------------------------------------------
227
  # 1. Covariate Data Handling
228
  # -------------------------------------------------------
229
  # We store the covariate matrix X in a reactiveVal for convenient reuse
230
  X_data <- reactiveVal(NULL)
231
 
232
- # Observe file input (upload) for X
233
  observeEvent(input$file_covariates, {
234
  req(input$file_covariates)
235
  inFile <- input$file_covariates
236
-
237
- start_time <- Sys.time()
238
  df <- tryCatch(read.csv(inFile$datapath, header = TRUE),
239
  error = function(e) NULL)
240
- end_time <- Sys.time()
241
- time_data_upload(as.numeric(difftime(end_time, start_time, units = "secs")))
242
-
243
  if (!is.null(df)) {
244
  X_data(as.matrix(df))
245
  }
@@ -249,13 +330,8 @@ server <- function(input, output, session) {
249
  observeEvent(input$simulate_btn, {
250
  n <- input$sim_n
251
  p <- input$sim_p
252
-
253
- start_time <- Sys.time()
254
  # Basic simulation of N(0,1) data
255
  simX <- matrix(rnorm(n * p), nrow = n, ncol = p)
256
- end_time <- Sys.time()
257
- time_data_sim(as.numeric(difftime(end_time, start_time, units = "secs")))
258
-
259
  X_data(simX)
260
  })
261
 
@@ -266,30 +342,17 @@ server <- function(input, output, session) {
266
  options = list(scrollX = TRUE, pageLength = 5))
267
  })
268
 
269
- # --- Performance outputs for Data & Covariates
270
- output$time_data_upload <- renderText({
271
- t <- time_data_upload()
272
- if (is.na(t)) {
273
- "Not run yet."
274
- } else {
275
- paste0(round(t, 3), " seconds")
276
- }
277
- })
278
-
279
- output$time_data_sim <- renderText({
280
- t <- time_data_sim()
281
- if (is.na(t)) {
282
- "Not run yet."
283
- } else {
284
- paste0(round(t, 3), " seconds")
285
- }
286
- })
287
-
288
  # -------------------------------------------------------
289
  # 2. Generate Rerandomizations
290
  # -------------------------------------------------------
291
- # We'll keep the accepted randomizations in a reactiveVal
 
292
  RerandResult <- reactiveVal(NULL)
 
 
 
 
 
293
 
294
  observeEvent(input$generate_btn, {
295
  req(X_data())
@@ -298,37 +361,53 @@ server <- function(input, output, session) {
298
  "Number treated cannot exceed total units.")
299
  )
300
 
301
- # withProgress to show progress bar in the UI
302
- withProgress(message = "Computing randomizations...", value = 0, {
303
-
304
- # Measure time
305
- start_time <- Sys.time()
306
-
307
- # We call generate_randomizations() from fastrerandomize
308
- nunits <- nrow(X_data())
309
- out <- tryCatch({
310
- generate_randomizations(
311
- n_units = nunits,
312
- n_treated = input$n_treated,
313
- X = X_data(),
314
- randomization_accept_prob= input$accept_prob,
315
- randomization_type = input$random_type,
316
- max_draws = if (input$random_type == "monte_carlo") input$max_draws else NULL,
317
- batch_size = if (input$random_type == "monte_carlo") input$batch_size else NULL,
318
- verbose = FALSE
319
- )
320
- }, error = function(e) e)
321
-
322
- # End time
323
- end_time <- Sys.time()
324
- time_generate(as.numeric(difftime(end_time, start_time, units = "secs")))
325
-
326
- if (inherits(out, "error")) {
327
- showNotification(paste("Error generating randomizations:", out$message), type = "error")
328
- return(NULL)
329
- }
330
  RerandResult(out)
331
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
  })
333
 
334
  # Summaries of accepted randomizations
@@ -352,7 +431,28 @@ server <- function(input, output, session) {
352
  }
353
  })
354
 
355
- # Plot histogram of the balance measure
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
356
  output$balance_hist <- renderPlot({
357
  rr <- RerandResult()
358
  req(rr, rr$balance)
@@ -365,16 +465,6 @@ server <- function(input, output, session) {
365
  theme_minimal(base_size = 14)
366
  })
367
 
368
- # --- Performance output for randomization generation
369
- output$time_generate <- renderText({
370
- t <- time_generate()
371
- if (is.na(t)) {
372
- "Not run yet."
373
- } else {
374
- paste0(round(t, 3), " seconds")
375
- }
376
- })
377
-
378
  # -------------------------------------------------------
379
  # 3. Randomization Test
380
  # -------------------------------------------------------
@@ -385,11 +475,14 @@ server <- function(input, output, session) {
385
  req(RerandResult())
386
  rr <- RerandResult()
387
 
388
- nunits <- ncol(rr$randomizations) # #units is #cols in randomizations
389
-
390
- start_time <- Sys.time()
391
  # We'll just use the first accepted randomization as the "observed" assignment
 
 
 
 
 
392
  obsW <- rr$randomizations[1, ]
 
393
 
394
  # Basic data generation: Y = X * beta + tau * W + noise
395
  Xval <- X_data()
@@ -401,9 +494,6 @@ server <- function(input, output, session) {
401
  beta <- rnorm(ncol(Xval), 0, 1)
402
  linear_part <- Xval %*% beta
403
  Ysim <- as.numeric(linear_part + obsW * input$true_tau + rnorm(nunits, 0, input$noise_sd))
404
- end_time <- Sys.time()
405
-
406
- time_data_simY(as.numeric(difftime(end_time, start_time, units = "secs")))
407
 
408
  Y_data(Ysim)
409
  })
@@ -412,12 +502,7 @@ server <- function(input, output, session) {
412
  observeEvent(input$file_outcomes, {
413
  req(input$file_outcomes)
414
  inFile <- input$file_outcomes
415
-
416
- start_time <- Sys.time()
417
  dfy <- tryCatch(read.csv(inFile$datapath, header = FALSE), error=function(e) NULL)
418
- end_time <- Sys.time()
419
- time_data_uploadY(as.numeric(difftime(end_time, start_time, units = "secs")))
420
-
421
  if (!is.null(dfy)) {
422
  if (ncol(dfy) > 1) {
423
  showNotification("Please provide a single-column CSV for Y.", type="error")
@@ -429,6 +514,11 @@ server <- function(input, output, session) {
429
 
430
  # The randomization test result:
431
  RandTestResult <- reactiveVal(NULL)
 
 
 
 
 
432
 
433
  observeEvent(input$run_randtest_btn, {
434
  req(RerandResult())
@@ -439,44 +529,59 @@ server <- function(input, output, session) {
439
  return(NULL)
440
  }
441
 
442
- withProgress(message = "Computing randomization test...", value = 0, {
443
-
444
- start_time <- Sys.time()
445
-
446
- obsW <- rr$randomizations[1, ]
447
- obsY <- Y_data()
448
- cands <- rr$randomizations
449
-
450
- # Check that Y has same length as a single W
451
- if (length(obsY) != length(obsW)) {
452
- showNotification("Dimension mismatch: Y must match number of units in the randomization.",
453
- type = "error")
454
- return(NULL)
455
- }
456
-
457
- # Call the randomization_test function
458
- outTest <- tryCatch({
459
- randomization_test(
460
- obsW = obsW,
461
- obsY = obsY,
462
- candidate_randomizations = cands,
463
- findFI = input$findFI
464
- )
465
- }, error=function(e) e)
466
-
467
- end_time <- Sys.time()
468
- time_randtest(as.numeric(difftime(end_time, start_time, units = "secs")))
469
-
470
- if (inherits(outTest, "error")) {
471
- showNotification(paste("Error in randomization_test:", outTest$message), type="error")
472
- return(NULL)
473
- }
474
-
475
  RandTestResult(outTest)
476
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
477
  })
478
 
479
- # Display p-value and observed tau
480
  output$pvalue_box <- renderValueBox({
481
  rt <- RandTestResult()
482
  if (is.null(rt)) {
@@ -495,6 +600,27 @@ server <- function(input, output, session) {
495
  }
496
  })
497
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
498
  # If we have a fiducial interval, display it
499
  output$fi_text <- renderUI({
500
  rt <- RandTestResult()
@@ -510,13 +636,18 @@ server <- function(input, output, session) {
510
  )
511
  })
512
 
513
- # A simple plot for the randomization distribution
514
- # (no distribution stored by default, so just show the observed effect)
 
515
  output$test_plot <- renderPlot({
516
  rt <- RandTestResult()
517
  if (is.null(rt)) {
 
 
 
518
  return(NULL)
519
  }
 
520
  obs_val <- rt$tau_obs
521
 
522
  ggplot(data.frame(x = obs_val, y = 0), aes(x, y)) +
@@ -527,34 +658,6 @@ server <- function(input, output, session) {
527
  theme_minimal(base_size = 14) +
528
  geom_vline(xintercept = 0, linetype="dashed", color="gray40")
529
  })
530
-
531
- # --- Performance outputs for outcomes and randomization test
532
- output$time_data_uploadY <- renderText({
533
- t <- time_data_uploadY()
534
- if (is.na(t)) {
535
- "Not run yet."
536
- } else {
537
- paste0(round(t, 3), " seconds")
538
- }
539
- })
540
-
541
- output$time_data_simY <- renderText({
542
- t <- time_data_simY()
543
- if (is.na(t)) {
544
- "Not run yet."
545
- } else {
546
- paste0(round(t, 3), " seconds")
547
- }
548
- })
549
-
550
- output$time_randtest <- renderText({
551
- t <- time_randtest()
552
- if (is.na(t)) {
553
- "Not run yet."
554
- } else {
555
- paste0(round(t, 3), " seconds")
556
- }
557
- })
558
  }
559
 
560
  # ---------------------------------------------------------
 
1
+ #
2
  # ============================================================
3
  # app.R | Shiny App for Rerandomization with fastrerandomize
4
  # ============================================================
 
6
  # 2) They specify rerandomization parameters: n_treated, acceptance prob, etc.
7
  # 3) The app generates a set of accepted randomizations under rerandomization.
8
  # 4) The user can optionally upload or simulate outcomes (Y) and run a randomization test.
9
+ # 5) The app displays distribution of the balance measure (e.g., Hotelling's T^2)
10
+ # and final p-value/fiducial interval, along with run-time comparisons between
11
+ # fastrerandomize and base R methods.
12
+ #
13
  # ----------------------------
14
  # Load required packages
15
  # ----------------------------
 
23
  # install.packages("devtools")
24
  # devtools::install_github("cjerzak/fastrerandomize-software/fastrerandomize")
25
 
26
+ # ---------------------------------------------------------
27
+ # HELPER FUNCTIONS (BASE R)
28
+ # ---------------------------------------------------------
29
# 1) Compute Hotelling's T^2 in base R
#
# Balance statistic for a single assignment vector W:
#   T^2 = (n0 * n1 / (n0 + n1)) * (xbar1 - xbar0)^T * S_inv * (xbar1 - xbar0)
# S is the sample covariance of the full covariate matrix, cov(X); it is not
# a within-group pooled estimate.
#
# Arguments:
#   X  Numeric covariate matrix (or object coercible with as.matrix()),
#      one row per unit.
#   W  Assignment vector with one entry per row of X (1 = treated,
#      0 = control).
# Returns:
#   A single numeric T^2 value, or NA_real_ when either arm is empty.
baseR_hotellingT2 <- function(X, W) {
  X <- as.matrix(X)
  # Logical row indexing would silently recycle a short W, so fail loudly.
  if (length(W) != nrow(X)) {
    stop("W must have one entry per row of X.", call. = FALSE)
  }

  n <- length(W)
  n1 <- sum(W)
  n0 <- n - n1
  if (n1 == 0 || n0 == 0) return(NA_real_) # invalid scenario

  diff_vec <- colMeans(X[W == 1, , drop = FALSE]) -
    colMeans(X[W == 0, , drop = FALSE])

  S <- cov(X)
  Sinv <- tryCatch(solve(S), error = function(e) NULL)
  if (is.null(Sinv)) {
    # Fallback when S is singular: diagonal approximation. A constant
    # covariate has zero variance, and 1 / 0 = Inf would turn the quadratic
    # form into NaN (0 * Inf). Such a column carries no balance information,
    # so it gets weight 0 instead.
    inv_var <- 1 / diag(S)
    inv_var[!is.finite(inv_var)] <- 0
    Sinv <- diag(inv_var, nrow = length(inv_var))
  }

  (n0 * n1 / (n0 + n1)) * drop(crossprod(diff_vec, Sinv %*% diff_vec))
}
52
+
53
# 2) Generate randomizations in base R and keep the best-balanced fraction.
#
# Candidate assignments are scored with baseR_hotellingT2(); candidates whose
# T^2 is at or below the accept_prob quantile of all valid scores are kept.
#
# Arguments:
#   n_units      Number of units (columns of the assignment matrix).
#   n_treated    Number of treated units in every assignment.
#   X            Covariate matrix passed to baseR_hotellingT2().
#   accept_prob  Quantile level in [0, 1] used as the acceptance cutoff.
#   random_type  "exact" enumerates all combn(n_units, n_treated)
#                assignments; any other value draws max_draws random
#                permutations (Monte Carlo).
#   max_draws    Number of Monte Carlo draws (ignored for "exact").
#   batch_size   Kept for signature parity with the caller; the draw matrix
#                is preallocated, so it does not affect the result.
# Returns:
#   list(randomizations = <0/1 matrix, one accepted assignment per row>,
#        balance        = <T^2 of each accepted assignment>)
baseR_generate_randomizations <- function(n_units, n_treated, X, accept_prob, random_type,
                                          max_draws, batch_size) {

  if (random_type == "exact") {
    # -------------- EXACT RANDOMIZATIONS --------------
    treated_idx <- combn(n_units, n_treated)
    n_comb <- ncol(treated_idx)
    assignment_mat <- matrix(0, nrow = n_comb, ncol = n_units)
    # (row, col) index pairs: column i of treated_idx lists the treated
    # units of candidate i.
    assignment_mat[cbind(rep(seq_len(n_comb), each = nrow(treated_idx)),
                         as.vector(treated_idx))] <- 1
  } else {
    # -------------- MONTE CARLO RANDOMIZATIONS --------------
    if (is.null(max_draws) || max_draws < 1) {
      stop("max_draws must be a positive number for Monte Carlo randomization.",
           call. = FALSE)
    }
    base_assign <- c(rep(1, n_treated), rep(0, n_units - n_treated))
    # Preallocate instead of growing lists/vectors batch by batch.
    assignment_mat <- matrix(0, nrow = max_draws, ncol = n_units)
    for (j in seq_len(nrow(assignment_mat))) {
      assignment_mat[j, ] <- sample(base_assign)
    }
  }

  # T^2 for every candidate row
  T2vals <- vapply(
    seq_len(nrow(assignment_mat)),
    function(i) baseR_hotellingT2(X, assignment_mat[i, ]),
    numeric(1)
  )

  # Drop any NA (pathological cases such as an empty arm)
  valid <- !is.na(T2vals)
  assignment_mat <- assignment_mat[valid, , drop = FALSE]
  T2vals <- T2vals[valid]
  if (length(T2vals) == 0L) {
    return(list(randomizations = assignment_mat, balance = T2vals))
  }

  # Acceptance threshold. The comparison is inclusive: with tied minima
  # (an assignment and its complement share the same T^2) the quantile can
  # equal the minimum, and a strict `<` would then accept nothing.
  cutoff <- quantile(T2vals, probs = accept_prob, names = FALSE)
  accepted <- T2vals <= cutoff

  list(randomizations = assignment_mat[accepted, , drop = FALSE],
       balance = T2vals[accepted])
}
127
+
128
# 3) Base R randomization test: difference in means
#
# Two-sided test that compares the observed treated-minus-control difference
# in mean outcomes with the same statistic recomputed under every candidate
# assignment, holding obsY fixed.
#
# Arguments:
#   obsW  Observed assignment vector (1 = treated, 0 = control).
#   obsY  Observed outcomes, one per unit.
#   allW  Matrix of candidate assignments, one per row (a plain vector is
#         treated as a single candidate).
# Returns:
#   list(p_value = share of candidates with |diff| >= |observed diff|,
#        tau_obs = observed difference in means)
#   p_value is NA_real_ when no candidate yields a finite difference.
baseR_randomization_test <- function(obsW, obsY, allW) {
  allW <- if (is.null(dim(allW))) matrix(allW, nrow = 1L) else as.matrix(allW)

  # Logical subsetting of obsY would silently misalign on a length mismatch.
  if (length(obsW) != length(obsY) || ncol(allW) != length(obsY)) {
    stop("obsW, obsY and the columns of allW must all have the same length.",
         call. = FALSE)
  }

  # The same function is used for the observed and candidate statistics so
  # that an identical assignment gives a bit-identical value.
  diff_in_means <- function(w) mean(obsY[w == 1]) - mean(obsY[w == 0])

  obs_diff <- diff_in_means(obsW)

  diffs <- vapply(
    seq_len(nrow(allW)),
    function(i) diff_in_means(allW[i, ]),
    numeric(1)
  )
  # A candidate with an empty arm gives NaN; drop it so that one degenerate
  # row does not turn the whole p-value into NA.
  diffs <- diffs[is.finite(diffs)]

  # p-value = fraction whose absolute diff >= observed
  pval <- if (length(diffs) > 0L) mean(abs(diffs) >= abs(obs_diff)) else NA_real_
  list(p_value = pval, tau_obs = obs_diff)
}
144
+
145
  # ---------------------------------------------------------
146
  # UI Section
147
  # ---------------------------------------------------------
 
202
 
203
  conditionalPanel(
204
  condition = "input.data_source == 'simulate'",
205
+ numericInput("sim_n", "Number of units (rows)", value = 50, min = 2),
206
+ numericInput("sim_p", "Number of covariates (columns)", value = 10, min = 1),
207
  actionButton("simulate_btn", "Simulate X")
208
  )
209
  ),
 
211
  box(width = 7, title = "Preview of Covariates (X)",
212
  status = "info", solidHeader = TRUE,
213
  DTOutput("covariates_table"))
 
 
 
 
 
 
 
 
 
 
 
214
  )
215
  ),
216
 
 
233
  value = 0.05, min = 0.0001, max = 1),
234
  conditionalPanel(
235
  condition = "input.random_type == 'monte_carlo'",
236
+ numericInput("max_draws", "Max Draws (MC)", value = 1e5, min = 1e3),
237
  numericInput("batch_size", "Batch Size (MC)", value = 1e3, min = 1e2)
238
  ),
239
  actionButton("generate_btn", "Generate Randomizations")
 
242
  box(width = 8, title = "Summary of Accepted Randomizations",
243
  status = "info", solidHeader = TRUE,
244
  fluidRow(
245
+ valueBoxOutput("n_accepted_box", width = 4),
246
+ valueBoxOutput("balance_min_box", width = 4),
247
+ valueBoxOutput("fastrerand_time_box", width = 4),
248
+ valueBoxOutput("baseR_time_box", width = 4)
249
  ),
250
  br(),
251
  plotOutput("balance_hist", height = "250px")
252
  )
 
 
 
 
 
 
 
253
  )
254
  ),
255
 
 
288
 
289
  box(width = 8, title = "Test Results", status = "info", solidHeader = TRUE,
290
  fluidRow(
291
+ valueBoxOutput("pvalue_box", width = 4),
292
+ valueBoxOutput("tauobs_box", width = 4),
293
+ valueBoxOutput("fastrerand_test_time_box", width = 4),
294
+ valueBoxOutput("baseR_test_time_box", width = 4)
295
  ),
296
  uiOutput("fi_text"),
297
  br(),
298
  plotOutput("test_plot", height = "280px")
299
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
  )
301
  )
302
 
 
309
  # ---------------------------------------------------------
310
  server <- function(input, output, session) {
311
 
 
 
 
 
 
 
 
 
312
  # -------------------------------------------------------
313
  # 1. Covariate Data Handling
314
  # -------------------------------------------------------
315
  # We store the covariate matrix X in a reactiveVal for convenient reuse
316
  X_data <- reactiveVal(NULL)
317
 
318
+ # Observe file input or simulation for X
319
  observeEvent(input$file_covariates, {
320
  req(input$file_covariates)
321
  inFile <- input$file_covariates
 
 
322
  df <- tryCatch(read.csv(inFile$datapath, header = TRUE),
323
  error = function(e) NULL)
 
 
 
324
  if (!is.null(df)) {
325
  X_data(as.matrix(df))
326
  }
 
330
  observeEvent(input$simulate_btn, {
331
  n <- input$sim_n
332
  p <- input$sim_p
 
 
333
  # Basic simulation of N(0,1) data
334
  simX <- matrix(rnorm(n * p), nrow = n, ncol = p)
 
 
 
335
  X_data(simX)
336
  })
337
 
 
342
  options = list(scrollX = TRUE, pageLength = 5))
343
  })
344
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
  # -------------------------------------------------------
346
  # 2. Generate Rerandomizations
347
  # -------------------------------------------------------
348
+ # We'll keep the accepted randomizations from fastrerandomize in RerandResult
349
+ # and from base R in RerandResult_base.
350
  RerandResult <- reactiveVal(NULL)
351
+ RerandResult_base <- reactiveVal(NULL)
352
+
353
+ # We also store their run times
354
+ fastrand_time <- reactiveVal(NULL)
355
+ baseR_time <- reactiveVal(NULL)
356
 
357
  observeEvent(input$generate_btn, {
358
  req(X_data())
 
361
  "Number treated cannot exceed total units.")
362
  )
363
 
364
+ # =========== 1) fastrerandomize generation timing ===========
365
+ t0_fast <- Sys.time()
366
+ out <- tryCatch({
367
+ generate_randomizations(
368
+ n_units = nrow(X_data()),
369
+ n_treated = input$n_treated,
370
+ X = X_data(),
371
+ randomization_accept_prob= input$accept_prob,
372
+ randomization_type = input$random_type,
373
+ max_draws = if (input$random_type == "monte_carlo") input$max_draws else NULL,
374
+ batch_size = if (input$random_type == "monte_carlo") input$batch_size else NULL,
375
+ verbose = FALSE
376
+ )
377
+ }, error = function(e) e)
378
+ t1_fast <- Sys.time()
379
+
380
+ if (inherits(out, "error")) {
381
+ showNotification(paste("Error generating randomizations (fastrerandomize):", out$message), type = "error")
382
+ RerandResult(NULL)
383
+ } else {
 
 
 
 
 
 
 
 
 
384
  RerandResult(out)
385
+ }
386
+ fastrand_time(difftime(t1_fast, t0_fast, units = "secs"))
387
+
388
+ # =========== 2) base R generation timing ===========
389
+ t0_base <- Sys.time()
390
+ out_base <- tryCatch({
391
+ baseR_generate_randomizations(
392
+ n_units = nrow(X_data()),
393
+ n_treated = input$n_treated,
394
+ X = X_data(),
395
+ accept_prob= input$accept_prob,
396
+ random_type= input$random_type,
397
+ max_draws = if (input$random_type == "monte_carlo") input$max_draws else NULL,
398
+ batch_size = if (input$random_type == "monte_carlo") input$batch_size else NULL
399
+ )
400
+ }, error = function(e) e)
401
+ t1_base <- Sys.time()
402
+
403
+ if (inherits(out_base, "error")) {
404
+ showNotification(paste("Error generating randomizations (base R):", out_base$message), type = "error")
405
+ RerandResult_base(NULL)
406
+ } else {
407
+ RerandResult_base(out_base)
408
+ }
409
+ baseR_time(difftime(t1_base, t0_base, units = "secs"))
410
+
411
  })
412
 
413
  # Summaries of accepted randomizations
 
431
  }
432
  })
433
 
434
+ # Timings for generation: fastrerandomize and base R
435
+ output$fastrerand_time_box <- renderValueBox({
436
+ tm <- fastrand_time()
437
+ if (is.null(tm)) {
438
+ valueBox("---", "fastrerandomize generation time (secs)", icon = icon("clock"), color = "teal")
439
+ } else {
440
+ valueBox(round(as.numeric(tm), 3), "fastrerandomize generation time (secs)",
441
+ icon = icon("clock"), color = "teal")
442
+ }
443
+ })
444
+
445
+ output$baseR_time_box <- renderValueBox({
446
+ tm <- baseR_time()
447
+ if (is.null(tm)) {
448
+ valueBox("---", "base R generation time (secs)", icon = icon("clock"), color = "lime")
449
+ } else {
450
+ valueBox(round(as.numeric(tm), 3), "base R generation time (secs)",
451
+ icon = icon("clock"), color = "lime")
452
+ }
453
+ })
454
+
455
+ # Plot histogram of the balance measure (fastrerandomize result)
456
  output$balance_hist <- renderPlot({
457
  rr <- RerandResult()
458
  req(rr, rr$balance)
 
465
  theme_minimal(base_size = 14)
466
  })
467
 
 
 
 
 
 
 
 
 
 
 
468
  # -------------------------------------------------------
469
  # 3. Randomization Test
470
  # -------------------------------------------------------
 
475
  req(RerandResult())
476
  rr <- RerandResult()
477
 
 
 
 
478
  # We'll just use the first accepted randomization as the "observed" assignment
479
+ if (is.null(rr$randomizations) || nrow(rr$randomizations) < 1) {
480
+ showNotification("No accepted randomizations found. Cannot simulate Y for the 'observed' assignment.", type = "error")
481
+ return(NULL)
482
+ }
483
+
484
  obsW <- rr$randomizations[1, ]
485
+ nunits <- length(obsW)
486
 
487
  # Basic data generation: Y = X * beta + tau * W + noise
488
  Xval <- X_data()
 
494
  beta <- rnorm(ncol(Xval), 0, 1)
495
  linear_part <- Xval %*% beta
496
  Ysim <- as.numeric(linear_part + obsW * input$true_tau + rnorm(nunits, 0, input$noise_sd))
 
 
 
497
 
498
  Y_data(Ysim)
499
  })
 
502
  observeEvent(input$file_outcomes, {
503
  req(input$file_outcomes)
504
  inFile <- input$file_outcomes
 
 
505
  dfy <- tryCatch(read.csv(inFile$datapath, header = FALSE), error=function(e) NULL)
 
 
 
506
  if (!is.null(dfy)) {
507
  if (ncol(dfy) > 1) {
508
  showNotification("Please provide a single-column CSV for Y.", type="error")
 
514
 
515
  # The randomization test result:
516
  RandTestResult <- reactiveVal(NULL)
517
+ RandTestResult_base <- reactiveVal(NULL)
518
+
519
+ # We'll store their times:
520
+ fastrand_test_time <- reactiveVal(NULL)
521
+ baseR_test_time <- reactiveVal(NULL)
522
 
523
  observeEvent(input$run_randtest_btn, {
524
  req(RerandResult())
 
529
  return(NULL)
530
  }
531
 
532
+ obsW <- rr$randomizations[1, ]
533
+ obsY <- Y_data()
534
+
535
+ # =========== 1) fastrerandomize randomization_test timing ===========
536
+ t0_testfast <- Sys.time()
537
+ outTest <- tryCatch({
538
+ randomization_test(
539
+ obsW = obsW,
540
+ obsY = obsY,
541
+ candidate_randomizations = rr$randomizations,
542
+ findFI = input$findFI
543
+ )
544
+ }, error=function(e) e)
545
+ t1_testfast <- Sys.time()
546
+
547
+ if (inherits(outTest, "error")) {
548
+ showNotification(paste("Error in randomization_test (fastrerandomize):", outTest$message), type="error")
549
+ RandTestResult(NULL)
550
+ } else {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
551
  RandTestResult(outTest)
552
+ }
553
+ fastrand_test_time(difftime(t1_testfast, t0_testfast, units = "secs"))
554
+
555
+ # =========== 2) base R randomization test timing ===========
556
+ # We must also have the base R set of randomizations
557
+ req(RerandResult_base())
558
+ rr_base <- RerandResult_base()
559
+ if (is.null(rr_base$randomizations) || nrow(rr_base$randomizations) < 1) {
560
+ showNotification("No base R randomizations found. Cannot run base R test.", type = "error")
561
+ RandTestResult_base(NULL)
562
+ return(NULL)
563
+ }
564
+
565
+ t0_testbase <- Sys.time()
566
+ outTestBase <- tryCatch({
567
+ baseR_randomization_test(
568
+ obsW = obsW,
569
+ obsY = obsY,
570
+ allW = rr_base$randomizations
571
+ )
572
+ }, error = function(e) e)
573
+ t1_testbase <- Sys.time()
574
+
575
+ if (inherits(outTestBase, "error")) {
576
+ showNotification(paste("Error in randomization_test (base R):", outTestBase$message), type="error")
577
+ RandTestResult_base(NULL)
578
+ } else {
579
+ RandTestResult_base(outTestBase)
580
+ }
581
+ baseR_test_time(difftime(t1_testbase, t0_testbase, units = "secs"))
582
  })
583
 
584
+ # Display p-value and observed tau (from the fastrerandomize test)
585
  output$pvalue_box <- renderValueBox({
586
  rt <- RandTestResult()
587
  if (is.null(rt)) {
 
600
  }
601
  })
602
 
603
+ # Times for randomization test
604
+ output$fastrerand_test_time_box <- renderValueBox({
605
+ tm <- fastrand_test_time()
606
+ if (is.null(tm)) {
607
+ valueBox("---", "fastrerandomize test time (secs)", icon = icon("clock"), color = "teal")
608
+ } else {
609
+ valueBox(round(as.numeric(tm), 3), "fastrerandomize test time (secs)",
610
+ icon = icon("clock"), color = "teal")
611
+ }
612
+ })
613
+
614
+ output$baseR_test_time_box <- renderValueBox({
615
+ tm <- baseR_test_time()
616
+ if (is.null(tm)) {
617
+ valueBox("---", "base R test time (secs)", icon = icon("clock"), color = "lime")
618
+ } else {
619
+ valueBox(round(as.numeric(tm), 3), "base R test time (secs)",
620
+ icon = icon("clock"), color = "lime")
621
+ }
622
+ })
623
+
624
  # If we have a fiducial interval, display it
625
  output$fi_text <- renderUI({
626
  rt <- RandTestResult()
 
636
  )
637
  })
638
 
639
+ # A simple plot for the randomization distribution (for demonstration).
640
+ # In the minimal example, we do not store the entire distribution in 'randomization_test',
641
+ # so we simply show the observed effect with a placeholder.
642
  output$test_plot <- renderPlot({
643
  rt <- RandTestResult()
644
  if (is.null(rt)) {
645
+ # no test run yet
646
+ plot.new()
647
+ title("No test results yet.")
648
  return(NULL)
649
  }
650
+ # Just display the observed effect
651
  obs_val <- rt$tau_obs
652
 
653
  ggplot(data.frame(x = obs_val, y = 0), aes(x, y)) +
 
658
  theme_minimal(base_size = 14) +
659
  geom_vline(xintercept = 0, linetype="dashed", color="gray40")
660
  })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
661
  }
662
 
663
  # ---------------------------------------------------------