Pre-load RF-DETR models at startup and add progress bar

#1
Files changed (3) hide show
  1. .gitignore +0 -2
  2. app.py +90 -397
  3. requirements.txt +2 -2
.gitignore DELETED
@@ -1,2 +0,0 @@
1
- .idea/
2
- .gradio/
 
 
 
app.py CHANGED
@@ -1,19 +1,13 @@
 
 
1
  from __future__ import annotations
2
 
3
- import os
4
- import sys
5
  import tempfile
6
  from pathlib import Path
7
 
8
  import cv2
9
  import gradio as gr
10
- import numpy as np
11
- import supervision as sv
12
- import torch
13
- from tqdm import tqdm
14
- from inference_models import AutoModel
15
-
16
- from trackers import ByteTrackTracker, OCSORTTracker, SORTTracker, frames_from_source
17
 
18
  MAX_DURATION_SECONDS = 30
19
 
@@ -28,7 +22,7 @@ MODELS = [
28
  "rfdetr-seg-large",
29
  ]
30
 
31
- TRACKERS = ["bytetrack", "sort", "ocsort"]
32
 
33
  COCO_CLASSES = [
34
  "person",
@@ -43,142 +37,18 @@ COCO_CLASSES = [
43
  "sports ball",
44
  ]
45
 
46
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
47
-
48
- print(f"Loading {len(MODELS)} models on {DEVICE}...")
49
- LOADED_MODELS = {}
50
- for model_id in MODELS:
51
- print(f" Loading {model_id}...")
52
- LOADED_MODELS[model_id] = AutoModel.from_pretrained(model_id, device=DEVICE)
53
- print("All models loaded.")
54
-
55
- COLOR_PALETTE = sv.ColorPalette.from_hex(
56
- [
57
- "#ffff00",
58
- "#ff9b00",
59
- "#ff8080",
60
- "#ff66b2",
61
- "#ff66ff",
62
- "#b266ff",
63
- "#9999ff",
64
- "#3399ff",
65
- "#66ffff",
66
- "#33ff99",
67
- "#66ff66",
68
- "#99ff00",
69
- ]
70
- )
71
-
72
- RESULTS_DIR = "results"
73
- os.makedirs(RESULTS_DIR, exist_ok=True)
74
-
75
-
76
- def _init_annotators(
77
- show_boxes: bool = False,
78
- show_masks: bool = False,
79
- show_labels: bool = False,
80
- show_ids: bool = False,
81
- show_confidence: bool = False,
82
- ) -> tuple[list, sv.LabelAnnotator | None]:
83
- """Initialize supervision annotators based on display options."""
84
- annotators: list = []
85
- label_annotator: sv.LabelAnnotator | None = None
86
-
87
- if show_masks:
88
- annotators.append(
89
- sv.MaskAnnotator(
90
- color=COLOR_PALETTE,
91
- color_lookup=sv.ColorLookup.TRACK,
92
- )
93
- )
94
-
95
- if show_boxes:
96
- annotators.append(
97
- sv.BoxAnnotator(
98
- color=COLOR_PALETTE,
99
- color_lookup=sv.ColorLookup.TRACK,
100
- )
101
- )
102
-
103
- if show_labels or show_ids or show_confidence:
104
- label_annotator = sv.LabelAnnotator(
105
- color=COLOR_PALETTE,
106
- text_color=sv.Color.BLACK,
107
- text_position=sv.Position.TOP_LEFT,
108
- color_lookup=sv.ColorLookup.TRACK,
109
- )
110
-
111
- return annotators, label_annotator
112
-
113
-
114
- def _format_labels(
115
- detections: sv.Detections,
116
- class_names: list[str],
117
- *,
118
- show_ids: bool = False,
119
- show_labels: bool = False,
120
- show_confidence: bool = False,
121
- ) -> list[str]:
122
- """Generate label strings for each detection."""
123
- labels = []
124
-
125
- for i in range(len(detections)):
126
- parts = []
127
-
128
- if show_ids and detections.tracker_id is not None:
129
- parts.append(f"#{int(detections.tracker_id[i])}")
130
-
131
- if show_labels and detections.class_id is not None:
132
- class_id = int(detections.class_id[i])
133
- if class_names and 0 <= class_id < len(class_names):
134
- parts.append(class_names[class_id])
135
- else:
136
- parts.append(str(class_id))
137
-
138
- if show_confidence and detections.confidence is not None:
139
- parts.append(f"{detections.confidence[i]:.2f}")
140
-
141
- labels.append(" ".join(parts))
142
-
143
- return labels
144
-
145
-
146
  VIDEO_EXAMPLES = [
147
  [
148
  "https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/bikes-1280x720-1.mp4",
149
  "rfdetr-small",
150
- "ocsort",
151
  0.2,
152
  30,
153
  0.3,
154
  3,
155
  0.1,
156
  0.6,
157
- 0.2,
158
- 3,
159
  [],
160
- "",
161
- True,
162
- True,
163
- False,
164
- False,
165
- True,
166
- False,
167
- ],
168
- [
169
- "https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/bikes-1280x720-1.mp4",
170
- "rfdetr-small",
171
- "ocsort",
172
- 0.2,
173
- 30,
174
- 0.3,
175
- 3,
176
- 0.1,
177
- 0.6,
178
- 0.2,
179
- 3,
180
- ["person"],
181
- "",
182
  True,
183
  True,
184
  False,
@@ -196,10 +66,7 @@ VIDEO_EXAMPLES = [
196
  3,
197
  0.3,
198
  0.6,
199
- 0.2,
200
- 3,
201
  [],
202
- "",
203
  True,
204
  True,
205
  False,
@@ -208,25 +75,22 @@ VIDEO_EXAMPLES = [
208
  True,
209
  ],
210
  [
211
- "https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/apples-1280x720-2.mp4",
212
- "rfdetr-nano",
213
- "sort",
214
  0.2,
215
  30,
216
  0.3,
217
  3,
218
  0.1,
219
  0.6,
220
- 0.2,
221
- 3,
222
- [],
223
- "",
224
- True,
225
  True,
226
  True,
227
  False,
228
  True,
229
  False,
 
230
  ],
231
  [
232
  "https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/jets-1280x720-1.mp4",
@@ -238,10 +102,7 @@ VIDEO_EXAMPLES = [
238
  3,
239
  0.1,
240
  0.6,
241
- 0.2,
242
- 3,
243
  [],
244
- "",
245
  True,
246
  True,
247
  False,
@@ -259,73 +120,25 @@ VIDEO_EXAMPLES = [
259
  3,
260
  0.1,
261
  0.6,
262
- 0.2,
263
- 3,
264
  [],
265
- "",
266
  True,
267
  True,
268
  False,
269
  False,
270
  True,
271
- True,
272
- ],
273
- [
274
- "https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/jets-1280x720-2.mp4",
275
- "rfdetr-seg-small",
276
- "bytetrack",
277
- 0.2,
278
- 30,
279
- 0.3,
280
- 3,
281
- 0.1,
282
- 0.6,
283
- 0.2,
284
- 3,
285
- [],
286
- "1",
287
- True,
288
- True,
289
- False,
290
- False,
291
- True,
292
- True,
293
- ],
294
- [
295
- "https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/suitcases-1280x720-4.mp4",
296
- "rfdetr-small",
297
- "sort",
298
- 0.2,
299
- 30,
300
- 0.3,
301
- 3,
302
- 0.1,
303
- 0.6,
304
- 0.2,
305
- 3,
306
- [],
307
- "",
308
- True,
309
- True,
310
- True,
311
- False,
312
- True,
313
  False,
314
  ],
315
  [
316
  "https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/vehicles-1280x720.mp4",
317
- "rfdetr-medium",
318
- "ocsort",
319
  0.2,
320
  30,
321
  0.3,
322
  3,
323
  0.1,
324
  0.6,
325
- 0.2,
326
- 3,
327
  [],
328
- "",
329
  True,
330
  True,
331
  True,
@@ -336,190 +149,102 @@ VIDEO_EXAMPLES = [
336
  ]
337
 
338
 
339
- def _get_video_info(path: str) -> tuple[float, int]:
340
- """Return video duration in seconds and frame count using OpenCV."""
341
- video_capture = cv2.VideoCapture(path)
342
- if not video_capture.isOpened():
343
  raise gr.Error("Could not open the uploaded video.")
344
- frames_per_second = video_capture.get(cv2.CAP_PROP_FPS)
345
- frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
346
- video_capture.release()
347
- if frames_per_second <= 0:
348
  raise gr.Error("Could not determine video frame rate.")
349
- return frame_count / frames_per_second, frame_count
350
-
351
-
352
- def _resolve_class_filter(
353
- classes: list[str] | None,
354
- class_names: list[str],
355
- ) -> list[int] | None:
356
- """Resolve class names to integer IDs."""
357
- if not classes:
358
- return None
359
-
360
- name_to_id = {name: i for i, name in enumerate(class_names)}
361
- class_filter: list[int] = []
362
- for name in classes:
363
- if name in name_to_id:
364
- class_filter.append(name_to_id[name])
365
- return class_filter if class_filter else None
366
-
367
-
368
- def _resolve_track_id_filter(track_ids_arg: str | None) -> list[int] | None:
369
- """Resolve a comma-separated string of track IDs to a list of integers.
370
-
371
- Args:
372
- track_ids_arg: Comma-separated string (e.g. `"1,3,5"`). `None` or
373
- empty string means no filter.
374
-
375
- Returns:
376
- List of integer track IDs, or `None` when no valid filter remains.
377
- """
378
- if not track_ids_arg:
379
- return None
380
-
381
- track_ids: list[int] = []
382
- for token in track_ids_arg.split(","):
383
- token = token.strip()
384
- try:
385
- track_ids.append(int(token))
386
- except ValueError:
387
- print(
388
- f"Warning: '{token}' is not a valid track ID, skipping.",
389
- file=sys.stderr,
390
- )
391
- return track_ids if track_ids else None
392
 
393
 
394
  def track(
395
  video_path: str,
396
- model_id: str,
397
- tracker_type: str,
398
  confidence: float,
399
  lost_track_buffer: int,
400
  track_activation_threshold: float,
401
  minimum_consecutive_frames: int,
402
  minimum_iou_threshold: float,
403
  high_conf_det_threshold: float,
404
- direction_consistency_weight: float,
405
- delta_t: int,
406
  classes: list[str] | None = None,
407
- track_ids: str = "",
408
  show_boxes: bool = True,
409
  show_ids: bool = True,
410
  show_labels: bool = False,
411
  show_confidence: bool = False,
412
  show_trajectories: bool = False,
413
  show_masks: bool = False,
414
- progress=gr.Progress(track_tqdm=True),
415
  ) -> str:
416
  """Run tracking on the uploaded video and return the output path."""
417
  if video_path is None:
418
  raise gr.Error("Please upload a video.")
419
 
420
- duration, total_frames = _get_video_info(video_path)
421
  if duration > MAX_DURATION_SECONDS:
422
  raise gr.Error(
423
  f"Video is {duration:.1f}s long. "
424
- f"Maximum allowed duration is {MAX_DURATION_SECONDS}s. "
425
- f"Please use the trim tool in the Input Video player to shorten it."
426
  )
427
 
428
- detection_model = LOADED_MODELS[model_id]
429
- class_names = getattr(detection_model, "class_names", [])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
 
431
- selected_class_ids = _resolve_class_filter(classes, class_names)
432
- selected_track_ids = _resolve_track_id_filter(track_ids)
 
433
 
434
- if tracker_type == "bytetrack":
435
- tracker = ByteTrackTracker(
436
- lost_track_buffer=lost_track_buffer,
437
- track_activation_threshold=track_activation_threshold,
438
- minimum_consecutive_frames=minimum_consecutive_frames,
439
- minimum_iou_threshold=minimum_iou_threshold,
440
- high_conf_det_threshold=high_conf_det_threshold,
441
- )
442
- elif tracker_type == "ocsort":
443
- tracker = OCSORTTracker(
444
- lost_track_buffer=lost_track_buffer,
445
- minimum_consecutive_frames=minimum_consecutive_frames,
446
- minimum_iou_threshold=minimum_iou_threshold,
447
- high_conf_det_threshold=high_conf_det_threshold,
448
- direction_consistency_weight=direction_consistency_weight,
449
- delta_t=delta_t,
450
- )
451
- else:
452
- tracker = SORTTracker(
453
- lost_track_buffer=lost_track_buffer,
454
- track_activation_threshold=track_activation_threshold,
455
- minimum_consecutive_frames=minimum_consecutive_frames,
456
- minimum_iou_threshold=minimum_iou_threshold,
457
- )
458
- tracker.reset()
459
 
460
- annotators, label_annotator = _init_annotators(
461
- show_boxes=show_boxes,
462
- show_masks=show_masks,
463
- show_labels=show_labels,
464
- show_ids=show_ids,
465
- show_confidence=show_confidence,
466
- )
467
- trace_annotator = None
 
 
468
  if show_trajectories:
469
- trace_annotator = sv.TraceAnnotator(
470
- color=COLOR_PALETTE,
471
- color_lookup=sv.ColorLookup.TRACK,
472
- )
473
-
474
- temporary_directory = tempfile.mkdtemp()
475
- output_path = str(Path(temporary_directory) / "output.mp4")
476
-
477
- video_info = sv.VideoInfo.from_video_path(video_path)
478
-
479
- frame_generator = frames_from_source(video_path)
480
-
481
- with sv.VideoSink(output_path, video_info=video_info) as sink:
482
- for frame_idx, frame in tqdm(
483
- frame_generator, total=total_frames, desc="Processing video..."
484
- ):
485
- predictions = detection_model(frame)
486
- if predictions:
487
- detections = predictions[0].to_supervision()
488
-
489
- if len(detections) > 0 and detections.confidence is not None:
490
- confidence_mask = detections.confidence >= confidence
491
- detections = detections[confidence_mask]
492
-
493
- if selected_class_ids is not None and len(detections) > 0:
494
- class_mask = np.isin(detections.class_id, selected_class_ids)
495
- detections = detections[class_mask]
496
- else:
497
- detections = sv.Detections.empty()
498
-
499
- tracked = tracker.update(detections)
500
-
501
- if selected_track_ids is not None and len(tracked) > 0:
502
- if tracked.tracker_id is not None:
503
- track_id_mask = np.isin(tracked.tracker_id, selected_track_ids)
504
- tracked = tracked[track_id_mask]
505
-
506
- annotated = frame.copy()
507
- if trace_annotator is not None:
508
- annotated = trace_annotator.annotate(annotated, tracked)
509
- for annotator in annotators:
510
- annotated = annotator.annotate(annotated, tracked)
511
- if label_annotator is not None:
512
- labeled = tracked[tracked.tracker_id != -1]
513
- labels = _format_labels(
514
- labeled,
515
- class_names,
516
- show_ids=show_ids,
517
- show_labels=show_labels,
518
- show_confidence=show_confidence,
519
- )
520
- annotated = label_annotator.annotate(annotated, labeled, labels=labels)
521
 
522
- sink.write_frame(annotated)
 
 
523
 
524
  return output_path
525
 
@@ -536,7 +261,7 @@ with gr.Blocks(title="Trackers Playground 🔥") as demo:
536
  input_video = gr.Video(label="Input Video")
537
  output_video = gr.Video(label="Tracked Video")
538
 
539
- track_button = gr.Button(value="Track", variant="primary")
540
 
541
  with gr.Row():
542
  model_dropdown = gr.Dropdown(
@@ -568,16 +293,6 @@ with gr.Blocks(title="Trackers Playground 🔥") as demo:
568
  label="Filter Classes",
569
  info="Only track selected classes. None selected means all.",
570
  )
571
- track_id_filter = gr.Textbox(
572
- value="",
573
- label="Filter IDs",
574
- info=(
575
- "Only display tracks with specific track IDs "
576
- "(comma-separated, e.g. 1,3,5). "
577
- "Leave empty for all."
578
- ),
579
- placeholder="e.g. 1,3,5",
580
- )
581
 
582
  with gr.Column():
583
  gr.Markdown("### Tracker")
@@ -587,7 +302,7 @@ with gr.Blocks(title="Trackers Playground 🔥") as demo:
587
  value=30,
588
  step=1,
589
  label="Lost Track Buffer",
590
- info="Frames to keep a lost track before removing it (ByteTrack, SORT, OC-SORT).",
591
  )
592
  track_activation_slider = gr.Slider(
593
  minimum=0.0,
@@ -595,47 +310,31 @@ with gr.Blocks(title="Trackers Playground 🔥") as demo:
595
  value=0.3,
596
  step=0.05,
597
  label="Track Activation Threshold",
598
- info="Minimum score for a track to be activated (ByteTrack, SORT).",
599
  )
600
- minimum_consecutive_slider = gr.Slider(
601
  minimum=1,
602
  maximum=10,
603
  value=2,
604
  step=1,
605
  label="Minimum Consecutive Frames",
606
- info="Detections needed before a track is confirmed (ByteTrack, SORT, OC-SORT).",
607
  )
608
- minimum_iou_slider = gr.Slider(
609
  minimum=0.0,
610
  maximum=1.0,
611
  value=0.1,
612
  step=0.05,
613
  label="Minimum IoU Threshold",
614
- info="Overlap required to match a detection to a track (ByteTrack, SORT, OC-SORT).",
615
  )
616
- high_confidence_slider = gr.Slider(
617
  minimum=0.0,
618
  maximum=1.0,
619
  value=0.6,
620
  step=0.05,
621
  label="High Confidence Detection Threshold",
622
- info="Detections above this are matched first (ByteTrack / OC-SORT).",
623
- )
624
- direction_consistency_slider = gr.Slider(
625
- minimum=0.0,
626
- maximum=1.0,
627
- value=0.2,
628
- step=0.05,
629
- label="Direction Consistency Weight",
630
- info="Weight for direction consistency in association cost (OC-SORT only).",
631
- )
632
- delta_t_slider = gr.Slider(
633
- minimum=1,
634
- maximum=10,
635
- value=3,
636
- step=1,
637
- label="Delta T",
638
- info="Past frames for velocity estimation during occlusion (OC-SORT only).",
639
  )
640
 
641
  with gr.Column():
@@ -682,13 +381,10 @@ with gr.Blocks(title="Trackers Playground 🔥") as demo:
682
  confidence_slider,
683
  lost_track_buffer_slider,
684
  track_activation_slider,
685
- minimum_consecutive_slider,
686
- minimum_iou_slider,
687
- high_confidence_slider,
688
- direction_consistency_slider,
689
- delta_t_slider,
690
  class_filter,
691
- track_id_filter,
692
  show_boxes_checkbox,
693
  show_ids_checkbox,
694
  show_labels_checkbox,
@@ -699,7 +395,7 @@ with gr.Blocks(title="Trackers Playground 🔥") as demo:
699
  outputs=output_video,
700
  )
701
 
702
- track_button.click(
703
  fn=track,
704
  inputs=[
705
  input_video,
@@ -708,13 +404,10 @@ with gr.Blocks(title="Trackers Playground 🔥") as demo:
708
  confidence_slider,
709
  lost_track_buffer_slider,
710
  track_activation_slider,
711
- minimum_consecutive_slider,
712
- minimum_iou_slider,
713
- high_confidence_slider,
714
- direction_consistency_slider,
715
- delta_t_slider,
716
  class_filter,
717
- track_id_filter,
718
  show_boxes_checkbox,
719
  show_ids_checkbox,
720
  show_labels_checkbox,
 
1
+ """Gradio app for the trackers library — run object tracking on uploaded videos."""
2
+
3
  from __future__ import annotations
4
 
5
+ import subprocess
 
6
  import tempfile
7
  from pathlib import Path
8
 
9
  import cv2
10
  import gradio as gr
 
 
 
 
 
 
 
11
 
12
  MAX_DURATION_SECONDS = 30
13
 
 
22
  "rfdetr-seg-large",
23
  ]
24
 
25
+ TRACKERS = ["bytetrack", "sort"]
26
 
27
  COCO_CLASSES = [
28
  "person",
 
37
  "sports ball",
38
  ]
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  VIDEO_EXAMPLES = [
41
  [
42
  "https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/bikes-1280x720-1.mp4",
43
  "rfdetr-small",
44
+ "bytetrack",
45
  0.2,
46
  30,
47
  0.3,
48
  3,
49
  0.1,
50
  0.6,
 
 
51
  [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  True,
53
  True,
54
  False,
 
66
  3,
67
  0.3,
68
  0.6,
 
 
69
  [],
 
70
  True,
71
  True,
72
  False,
 
75
  True,
76
  ],
77
  [
78
+ "https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/cars-1280x720-1.mp4",
79
+ "rfdetr-small",
80
+ "bytetrack",
81
  0.2,
82
  30,
83
  0.3,
84
  3,
85
  0.1,
86
  0.6,
87
+ ["car"],
 
 
 
 
88
  True,
89
  True,
90
  False,
91
  True,
92
  False,
93
+ False,
94
  ],
95
  [
96
  "https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/jets-1280x720-1.mp4",
 
102
  3,
103
  0.1,
104
  0.6,
 
 
105
  [],
 
106
  True,
107
  True,
108
  False,
 
120
  3,
121
  0.1,
122
  0.6,
 
 
123
  [],
 
124
  True,
125
  True,
126
  False,
127
  False,
128
  True,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  False,
130
  ],
131
  [
132
  "https://storage.googleapis.com/com-roboflow-marketing/supervision/video-examples/vehicles-1280x720.mp4",
133
+ "rfdetr-small",
134
+ "bytetrack",
135
  0.2,
136
  30,
137
  0.3,
138
  3,
139
  0.1,
140
  0.6,
 
 
141
  [],
 
142
  True,
143
  True,
144
  True,
 
149
  ]
150
 
151
 
152
def _get_video_duration(path: str) -> float:
    """Return video duration in seconds using OpenCV.

    Opens the file just long enough to read FPS and frame count, then
    computes duration as frame_count / fps.

    Raises:
        gr.Error: If the file cannot be opened or has no valid frame rate.
    """
    capture = cv2.VideoCapture(path)
    if not capture.isOpened():
        raise gr.Error("Could not open the uploaded video.")
    frames_per_second = capture.get(cv2.CAP_PROP_FPS)
    total_frames = capture.get(cv2.CAP_PROP_FRAME_COUNT)
    capture.release()
    if frames_per_second <= 0:
        raise gr.Error("Could not determine video frame rate.")
    return total_frames / frames_per_second
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
 
165
def track(
    video_path: str,
    model: str,
    tracker: str,
    confidence: float,
    lost_track_buffer: int,
    track_activation_threshold: float,
    minimum_consecutive_frames: int,
    minimum_iou_threshold: float,
    high_conf_det_threshold: float,
    classes: list[str] | None = None,
    show_boxes: bool = True,
    show_ids: bool = True,
    show_labels: bool = False,
    show_confidence: bool = False,
    show_trajectories: bool = False,
    show_masks: bool = False,
) -> str:
    """Run tracking on the uploaded video and return the output path.

    Validates the upload, enforces the duration cap, then shells out to the
    ``trackers track`` CLI with the selected model/tracker parameters and
    display flags.

    Raises:
        gr.Error: On missing upload, over-long video, or a non-zero CLI exit.
    """
    if video_path is None:
        raise gr.Error("Please upload a video.")

    duration = _get_video_duration(video_path)
    if duration > MAX_DURATION_SECONDS:
        raise gr.Error(
            f"Video is {duration:.1f}s long. Maximum allowed duration is {MAX_DURATION_SECONDS}s."
        )

    # Write the result into a fresh temp directory so concurrent runs
    # never collide on the output file.
    output_dir = tempfile.mkdtemp()
    output_path = str(Path(output_dir) / "output.mp4")

    command = [
        "trackers",
        "track",
        "--source",
        video_path,
        "--output",
        output_path,
        "--overwrite",
        "--model",
        model,
        "--model.device",
        "cuda",
        "--tracker",
        tracker,
        "--model.confidence",
        str(confidence),
        "--tracker.lost_track_buffer",
        str(lost_track_buffer),
        "--tracker.track_activation_threshold",
        str(track_activation_threshold),
        "--tracker.minimum_consecutive_frames",
        str(minimum_consecutive_frames),
        "--tracker.minimum_iou_threshold",
        str(minimum_iou_threshold),
    ]

    # high_conf_det_threshold is only understood by ByteTrack.
    if tracker == "bytetrack":
        command += ["--tracker.high_conf_det_threshold", str(high_conf_det_threshold)]

    if classes:
        command += ["--classes", ",".join(classes)]

    # Boxes have an explicit off flag; the remaining overlays are opt-in.
    command.append("--show-boxes" if show_boxes else "--no-boxes")
    optional_flags = (
        (show_ids, "--show-ids"),
        (show_labels, "--show-labels"),
        (show_confidence, "--show-confidence"),
        (show_trajectories, "--show-trajectories"),
        (show_masks, "--show-masks"),
    )
    for enabled, flag in optional_flags:
        if enabled:
            command.append(flag)

    completed = subprocess.run(command, capture_output=True, text=True)  # noqa: S603
    if completed.returncode != 0:
        # Surface only the tail of stderr to keep the error dialog readable.
        raise gr.Error(f"Tracking failed:\n{completed.stderr[-500:]}")

    return output_path
250
 
 
261
  input_video = gr.Video(label="Input Video")
262
  output_video = gr.Video(label="Tracked Video")
263
 
264
+ track_btn = gr.Button(value="Track", variant="primary")
265
 
266
  with gr.Row():
267
  model_dropdown = gr.Dropdown(
 
293
  label="Filter Classes",
294
  info="Only track selected classes. None selected means all.",
295
  )
 
 
 
 
 
 
 
 
 
 
296
 
297
  with gr.Column():
298
  gr.Markdown("### Tracker")
 
302
  value=30,
303
  step=1,
304
  label="Lost Track Buffer",
305
+ info="Frames to keep a lost track before removing it.",
306
  )
307
  track_activation_slider = gr.Slider(
308
  minimum=0.0,
 
310
  value=0.3,
311
  step=0.05,
312
  label="Track Activation Threshold",
313
+ info="Minimum score for a track to be activated.",
314
  )
315
+ min_consecutive_slider = gr.Slider(
316
  minimum=1,
317
  maximum=10,
318
  value=2,
319
  step=1,
320
  label="Minimum Consecutive Frames",
321
+ info="Detections needed before a track is confirmed.",
322
  )
323
+ min_iou_slider = gr.Slider(
324
  minimum=0.0,
325
  maximum=1.0,
326
  value=0.1,
327
  step=0.05,
328
  label="Minimum IoU Threshold",
329
+ info="Overlap required to match a detection to a track.",
330
  )
331
+ high_conf_slider = gr.Slider(
332
  minimum=0.0,
333
  maximum=1.0,
334
  value=0.6,
335
  step=0.05,
336
  label="High Confidence Detection Threshold",
337
+ info="Detections above this are matched first (ByteTrack only).",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
  )
339
 
340
  with gr.Column():
 
381
  confidence_slider,
382
  lost_track_buffer_slider,
383
  track_activation_slider,
384
+ min_consecutive_slider,
385
+ min_iou_slider,
386
+ high_conf_slider,
 
 
387
  class_filter,
 
388
  show_boxes_checkbox,
389
  show_ids_checkbox,
390
  show_labels_checkbox,
 
395
  outputs=output_video,
396
  )
397
 
398
+ track_btn.click(
399
  fn=track,
400
  inputs=[
401
  input_video,
 
404
  confidence_slider,
405
  lost_track_buffer_slider,
406
  track_activation_slider,
407
+ min_consecutive_slider,
408
+ min_iou_slider,
409
+ high_conf_slider,
 
 
410
  class_filter,
 
411
  show_boxes_checkbox,
412
  show_ids_checkbox,
413
  show_labels_checkbox,
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
  gradio>=6.3.0,<6.4.0
2
- inference-models>=0.19.0
3
- trackers==2.3.0
 
1
  gradio>=6.3.0,<6.4.0
2
+ inference-models[onnx-cpu]==0.18.6rc14
3
+ trackers==2.2.0rc1