VizAnnot / app.py
rntc's picture
Upload folder using huggingface_hub
f12fd7f verified
"""
Gradio app for exploring the pancreas and lymphoma ("lymphome") clinical-report annotation datasets.
"""
import os
from functools import partial
from pathlib import Path
import gradio as gr
from datasets import load_dataset
# Minimum number of real (non-placeholder) annotations a sample needs to be listed.
MIN_ANNOTATIONS = 10

# Dataset repositories passed to `load_dataset`; each can be overridden by an
# environment variable of the same name.
PANCREAS_REPO_ID = os.environ.get("PANCREAS_REPO_ID", "rntc/biomed-fr-pancreas-annotations")
LYMPHOME_REPO_ID = os.environ.get("LYMPHOME_REPO_ID", "rntc/biomed-fr-lymphome-annotations")

# Local JSONL fallback for the lymphome dataset, located in the parent of the
# directory that contains this file.
LYMPHOME_LOCAL_JSONL = Path(__file__).resolve().parent.parent.joinpath(
    "Qwen--Qwen3-235B-A22B-Instruct-2507-FP8-4-lymphome-annotation-20251201_153807.jsonl"
)

# Highlight palette; colors are handed out to variables in order of appearance
# and reused cyclically once exhausted.
COLORS = [
    "#FFEB3B", "#4CAF50", "#2196F3", "#FF9800", "#E91E63",
    "#9C27B0", "#00BCD4", "#8BC34A", "#FF5722", "#607D8B",
]
def count_real_annotations(annotation):
    """Count annotated variables that carry a real value.

    An entry counts when it is a non-empty dict whose ``value`` is truthy,
    excluding "not found" placeholders: a ``span`` containing "pas de mention"
    or a ``value`` containing "not performed" (both case-insensitive).

    Args:
        annotation: Mapping of variable name -> dict with ``"value"`` and
            ``"span"`` keys. ``None`` or any non-dict input counts as empty.

    Returns:
        The number of real annotations (0 for missing or invalid input).
    """
    if not annotation or not isinstance(annotation, dict):
        # Rows whose annotation is stored as None would otherwise raise AttributeError.
        return 0
    count = 0
    for var_data in annotation.values():
        if not var_data or not isinstance(var_data, dict):
            continue
        value = var_data.get("value")
        if not value:
            continue
        span = var_data.get("span")
        # str() so non-string spans (e.g. lists) do not break .lower().
        if span and "pas de mention" in str(span).lower():
            continue
        if "not performed" in str(value).lower():
            continue
        count += 1
    return count
def escape_html(text):
    """Escape *text* for safe use in HTML content and quoted attributes.

    ``None`` yields ``""``; any other value is converted with ``str()`` first,
    so e.g. ``0`` renders as ``"0"`` instead of disappearing. Double and single
    quotes are escaped as well, because callers embed the result inside
    ``title="..."`` attributes.

    Args:
        text: Value to render, or ``None``.

    Returns:
        The HTML-escaped string.
    """
    from html import escape

    if text is None:
        return ""
    return escape(str(text), quote=True)
def highlight_text(cr_text, annotation):
    """Render *cr_text* as HTML with the annotated spans highlighted.

    A variable is highlighted when its ``span`` is a non-empty string that
    occurs verbatim in *cr_text* and its ``value`` is truthy. Only the FIRST
    occurrence of the span is wrapped, in a ``<mark>`` whose tooltip reads
    "Variable Label: value". When spans overlap, the one starting earliest is
    kept and the others are dropped; ties keep annotation order because the
    sort is stable. Each variable label gets a color from ``COLORS``, cycling
    when there are more labels than colors.

    Args:
        cr_text: Report text to display.
        annotation: Mapping of variable name -> dict with ``"span"`` and
            ``"value"`` keys. Non-dict entries are ignored.

    Returns:
        An HTML ``<pre>`` string. The escaped text is returned without
        highlights when there is no text, no usable annotation, or no span
        found in the text.
    """
    pre_open = "<pre style='white-space:pre-wrap;line-height:1.6;'>"
    if not cr_text or not annotation or not isinstance(annotation, dict):
        return f"{pre_open}{escape_html(cr_text)}</pre>"

    # Collect spans that actually occur in the text (first occurrence only).
    spans = []
    for var_name, var_data in annotation.items():
        if not var_data or not isinstance(var_data, dict):
            continue
        span = var_data.get("span")
        value = var_data.get("value")
        # Only string spans can be searched for; other types would raise TypeError.
        if not (span and value and isinstance(span, str)):
            continue
        start = cr_text.find(span)
        if start == -1:
            continue
        spans.append(
            {
                "text": span,
                "start": start,
                "end": start + len(span),
                "var": str(var_name).replace("_", " ").title(),
                "value": str(value),
            }
        )
    if not spans:
        return f"{pre_open}{escape_html(cr_text)}</pre>"

    # Stable sort by start offset, then drop spans overlapping an already-kept one.
    spans.sort(key=lambda s: s["start"])
    kept = []
    for s in spans:
        if not kept or s["start"] >= kept[-1]["end"]:
            kept.append(s)

    parts = []
    pos = 0
    color_map = {}
    for s in kept:
        if s["start"] > pos:
            parts.append(escape_html(cr_text[pos : s["start"]]))
        if s["var"] not in color_map:
            color_map[s["var"]] = COLORS[len(color_map) % len(COLORS)]
        # Escape double quotes explicitly: the tooltip lives in a "..."-quoted attribute.
        title = f"{escape_html(s['var'])}: {escape_html(s['value'])}".replace('"', "&quot;")
        parts.append(
            f'<mark style="background:{color_map[s["var"]]};padding:1px 3px;border-radius:3px;" '
            f'title="{title}">{escape_html(s["text"])}</mark>'
        )
        pos = s["end"]
    if pos < len(cr_text):
        parts.append(escape_html(cr_text[pos:]))
    return f"{pre_open}{''.join(parts)}</pre>"
def format_table(annotation):
    """Render *annotation* as an HTML table with Variable / Value / Source columns.

    One row is emitted per dict-valued entry; other entries are skipped. An
    entry is shown as "/" with an empty source in two cases:

    - its ``value`` is falsy;
    - it is a placeholder, i.e. its ``span`` contains "pas de mention" or its
      ``value`` contains "not performed" (both case-insensitive).

    Otherwise the value is shown with its source span, truncated to 60
    characters followed by "...".

    Args:
        annotation: Mapping of variable name -> dict with ``"value"`` and
            ``"span"`` keys.

    Returns:
        The table as an HTML string, or ``"<p>No annotations</p>"`` when
        *annotation* is empty or not a dict.
    """
    if not annotation or not isinstance(annotation, dict):
        return "<p>No annotations</p>"

    max_span_chars = 60
    cell = "padding:6px 10px;border-bottom:1px solid #ddd;"
    header = "padding:8px 10px;text-align:left;border-bottom:2px solid #ddd;"

    rows = []
    for var_name, var_data in annotation.items():
        if not var_data or not isinstance(var_data, dict):
            continue
        value = var_data.get("value")
        span = var_data.get("span")
        # str() so non-string spans (e.g. lists) do not break .lower()/len().
        span_text = str(span) if span else ""
        is_missing = (
            not value
            or "pas de mention" in span_text.lower()
            or "not performed" in str(value).lower()
        )
        if is_missing:
            display_value, display_span = "/", ""
        else:
            display_value = str(value)
            display_span = (
                span_text[:max_span_chars] + "..."
                if len(span_text) > max_span_chars
                else span_text
            )
        var_label = str(var_name).replace("_", " ").title()
        rows.append(
            "<tr>"
            f'<td style="{cell}font-weight:500;">{escape_html(var_label)}</td>'
            f'<td style="{cell}color:#1565C0;">{escape_html(display_value)}</td>'
            f'<td style="{cell}color:#666;font-size:12px;font-style:italic;">'
            f"{escape_html(display_span)}</td>"
            "</tr>"
        )

    return (
        '<table style="width:100%;border-collapse:collapse;font-size:13px;">'
        '<thead><tr style="background:#f5f5f5;">'
        f'<th style="{header}">Variable</th>'
        f'<th style="{header}">Value</th>'
        f'<th style="{header}">Source</th>'
        "</tr></thead>"
        f"<tbody>{''.join(rows)}</tbody>"
        "</table>"
    )
def load_pancreas_dataset():
    """Load the ``train`` split of ``PANCREAS_REPO_ID`` via ``load_dataset``.

    Progress is printed to stdout; any loading error propagates to the caller.
    """
    repo = PANCREAS_REPO_ID
    print(f"Loading pancreas dataset from {repo}...")
    pancreas = load_dataset(repo, split="train")
    print(f"Loaded {len(pancreas)} pancreas samples")
    return pancreas
def load_lymphome_dataset():
    """Load the lymphome ``train`` split, falling back to a local JSONL file.

    ``LYMPHOME_REPO_ID`` is tried first. On any failure the error is printed.
    If ``LYMPHOME_LOCAL_JSONL`` exists, it is loaded instead; otherwise the
    original error is re-raised.
    """
    print(f"Loading lymphome dataset from {LYMPHOME_REPO_ID} (Hub)...")
    try:
        hub_dataset = load_dataset(LYMPHOME_REPO_ID, split="train")
        print(f"Loaded {len(hub_dataset)} lymphome samples from Hub")
    except Exception as exc:  # noqa: BLE001 (we want to surface any failure)
        print(f"Failed to load lymphome dataset from Hub: {exc}")
        if not LYMPHOME_LOCAL_JSONL.exists():
            # Nothing to fall back to: surface the Hub error unchanged.
            raise
        print(f"Falling back to local lymphome JSONL at {LYMPHOME_LOCAL_JSONL}")
        local_dataset = load_dataset("json", data_files=str(LYMPHOME_LOCAL_JSONL), split="train")
        print(f"Loaded {len(local_dataset)} lymphome samples from local file")
        return local_dataset
    return hub_dataset
def filter_indices(dataset, min_annotations):
    """Return the positions of samples having at least *min_annotations* real annotations.

    Args:
        dataset: Iterable of sample mappings supporting ``.get`` (e.g. rows of
            a ``datasets.Dataset``).
        min_annotations: Minimum value of ``count_real_annotations`` required.

    Returns:
        A list of integer positions into *dataset*, in iteration order.
    """
    # `or {}` also covers rows whose "annotation" is present but stored as None,
    # which `.get("annotation", {})` would pass through unchanged.
    return [
        i
        for i, sample in enumerate(dataset)
        if count_real_annotations(sample.get("annotation") or {}) >= min_annotations
    ]
def prepare_source(key, label, loader, min_annotations):
    """Load one dataset source and precompute the indices of displayable samples.

    Args:
        key: Source identifier (not used by the body).
        label: Human-readable name, used in log messages and stored in the result.
        loader: Zero-argument callable returning the dataset.
        min_annotations: Threshold forwarded to ``filter_indices``.

    Returns:
        A dict with keys ``label``, ``dataset``, ``filtered_indices``,
        ``min_annotations`` and ``error``. On success, ``error`` is ``None``.
        Any exception raised while loading or filtering is printed and stored
        as ``str(exc)`` in ``error``, with ``dataset=None`` and no indices.
    """
    info = {
        "label": label,
        "dataset": None,
        "filtered_indices": [],
        "min_annotations": min_annotations,
        "error": None,
    }
    try:
        dataset = loader()
        filtered = filter_indices(dataset, min_annotations)
        print(f"{label}: filtered to {len(filtered)} samples with >= {min_annotations} annotations")
    except Exception as exc:  # noqa: BLE001 (we want to surface any failure)
        print(f"Failed to load {label}: {exc}")
        info["error"] = str(exc)
        return info
    info["dataset"] = dataset
    info["filtered_indices"] = filtered
    return info
# Every dataset source is loaded once, at import time, keyed by its source id.
SOURCES = {
    key: prepare_source(key, label, loader, MIN_ANNOTATIONS)
    for key, label, loader in (
        ("pancreas", "Pancréas", load_pancreas_dataset),
        ("lymphome", "Lymphome", load_lymphome_dataset),
    )
}
def display_sample_for_source(source_key, slider_idx):
    """Render one filtered sample of a source for the three HTML panes.

    Args:
        source_key: Key into ``SOURCES`` (e.g. ``"pancreas"``).
        slider_idx: Position within the source's ``filtered_indices``;
            converted with ``int()``.

    Returns:
        ``(original_html, cr_html, table_html)``. The same message string is
        returned for all three panes if the source failed to load, has no
        eligible sample, or *slider_idx* is not a valid position.
    """
    source = SOURCES[source_key]
    if source["error"]:
        # The error is an arbitrary exception message; escape it for the HTML panes.
        message = escape_html(f"Dataset unavailable: {source['error']}")
        return message, message, message
    indices = source["filtered_indices"]
    if not indices:
        message = escape_html(f"No samples with >= {source['min_annotations']} annotations.")
        return message, message, message

    invalid = ("Invalid", "Invalid", "Invalid")
    try:
        position = int(slider_idx)
    except (TypeError, ValueError, OverflowError):
        # Values int() rejects (e.g. None) are reported instead of raising in the UI callback.
        return invalid
    if not 0 <= position < len(indices):
        return invalid

    real_idx = indices[position]
    sample = source["dataset"][real_idx]
    original = sample.get("original_text") or ""
    cr = sample.get("CR") or ""
    # `or {}` also covers rows whose annotation is stored as None.
    annotation = sample.get("annotation") or {}
    n_annotations = count_real_annotations(annotation)

    original_html = f"<pre style='white-space:pre-wrap;line-height:1.6;'>{escape_html(original)}</pre>"
    cr_html = (
        f"<p><b>Sample #{real_idx}</b> — {n_annotations} annotations</p>"
        + highlight_text(cr, annotation)
    )
    return original_html, cr_html, format_table(annotation)
def build_tab(source_key):
    """Add a tab for one dataset source to the ``gr.Blocks`` layout being built.

    The tab shows only a warning if the source failed to load or has no sample
    with enough annotations. Otherwise it holds a sample slider and three HTML
    panes: original text, highlighted CR, and the variables table. The panes
    are refreshed when the slider changes and when the page loads.

    Must be called while the module-level ``demo`` Blocks context is active,
    since ``demo.load`` is used to register the page-load event.
    """
    source = SOURCES[source_key]
    label = source["label"]
    indices = source["filtered_indices"]
    threshold = source["min_annotations"]

    with gr.TabItem(label):
        if source["error"]:
            gr.Markdown(f"⚠️ Could not load {label} dataset: {escape_html(source['error'])}")
        elif not indices:
            gr.Markdown(f"⚠️ No samples with >= {threshold} annotations.")
        else:
            gr.Markdown(
                f"Showing {len(indices)} samples with >= "
                f"{threshold} annotations. Hover over highlights to see values."
            )
            with gr.Row():
                slider = gr.Slider(0, len(indices) - 1, value=0, step=1, label="Sample")
            panes = []
            with gr.Row():
                for heading in (
                    "### Original (English)",
                    "### Generated CR (French)",
                    "### Extracted Variables",
                ):
                    with gr.Column():
                        gr.Markdown(heading)
                        panes.append(gr.HTML())
            # Same callback for slider moves and for the initial page load.
            for register in (slider.change, demo.load):
                register(
                    fn=partial(display_sample_for_source, source_key),
                    inputs=[slider],
                    outputs=list(panes),
                )
# Assemble the UI: a header, then one tab per dataset source.
with gr.Blocks(title="Clinical Annotations Explorer", theme=gr.themes.Base()) as demo:
    gr.Markdown("# 🔬 Clinical Annotation Explorer")
    gr.Markdown(
        "Use the tabs below to switch between pancreas and lymphome annotations. "
        "Hover over highlights to see the extracted values."
    )
    with gr.Tabs():
        for _tab_key in ("pancreas", "lymphome"):
            build_tab(_tab_key)


if __name__ == "__main__":
    demo.launch()