|
|
""" |
|
|
Gradio app to explore pancreas or lymphome clinical report annotations. |
|
|
""" |
|
|
|
|
|
import os |
|
|
from functools import partial |
|
|
from pathlib import Path |
|
|
|
|
|
import gradio as gr |
|
|
from datasets import load_dataset |
|
|
|
|
|
# Threshold passed to prepare_source() for both datasets: a sample is listed
# only when it has at least this many real (non-placeholder) annotations.
MIN_ANNOTATIONS = 10

# Hub repository ids; each one can be overridden through the environment
# variable of the same name.
PANCREAS_REPO_ID = os.environ.get("PANCREAS_REPO_ID", "rntc/biomed-fr-pancreas-annotations")
LYMPHOME_REPO_ID = os.environ.get("LYMPHOME_REPO_ID", "rntc/biomed-fr-lymphome-annotations")

# Local JSONL export, looked up one directory above the directory holding this
# file; load_lymphome_dataset() falls back to it when the Hub load fails.
LYMPHOME_LOCAL_JSONL = Path(__file__).resolve().parent.parent.joinpath(
    "Qwen--Qwen3-235B-A22B-Instruct-2507-FP8-4-lymphome-annotation-20251201_153807.jsonl"
)

# Highlight background colours. highlight_text() hands them out in order of
# first appearance of a variable and wraps around when the list is exhausted.
COLORS = [
    "#FFEB3B", "#4CAF50", "#2196F3", "#FF9800", "#E91E63",
    "#9C27B0", "#00BCD4", "#8BC34A", "#FF5722", "#607D8B",
]
|
|
|
|
|
|
|
|
def count_real_annotations(annotation):
    """Count the entries of ``annotation`` that carry a real extracted value.

    An entry counts when it is a non-empty dict whose ``"value"`` is truthy,
    unless it is a "not found" placeholder, i.e. its ``"span"`` contains
    "pas de mention" or the string form of its value contains "not performed"
    (both checks are case-insensitive).

    Args:
        annotation: Mapping of variable name to a ``{"value": ..., "span": ...}``
            dict. ``None`` or an empty mapping yields 0, matching the
            ``not annotation`` guards in ``highlight_text`` and ``format_table``.

    Returns:
        The number of real annotations as an int.
    """
    # A missing annotation would otherwise fail on ``None.values()``.
    if not annotation:
        return 0

    count = 0
    for var_data in annotation.values():
        if not (var_data and isinstance(var_data, dict)):
            continue
        value = var_data.get("value")
        if not value:
            continue
        span = var_data.get("span", "")
        # str() keeps a non-string span (e.g. a number) from raising on .lower().
        if span and "pas de mention" in str(span).lower():
            continue
        if "not performed" in str(value).lower():
            continue
        count += 1
    return count
|
|
|
|
|
|
|
|
def escape_html(text):
    """Return ``text`` escaped for safe inclusion in HTML content or attributes.

    The ampersand, angle brackets and both quote characters are replaced by
    their HTML entities. Quotes are included because callers interpolate the
    result into a double-quoted ``title`` attribute.

    Args:
        text: Any value; it is converted with ``str()`` before escaping.

    Returns:
        The escaped string, or ``""`` when ``text`` is falsy (``None``, ``""``).
    """
    # Local import: the module does not import ``html`` at file level.
    from html import escape

    if not text:
        return ""
    return escape(str(text), quote=True)
|
|
|
|
|
|
|
|
def highlight_text(cr_text, annotation):
    """Render ``cr_text`` as a ``<pre>`` block with annotated spans highlighted.

    For every annotation entry that is a dict with a truthy ``"span"`` and
    ``"value"`` and whose span occurs in ``cr_text``, the first occurrence of
    the span is wrapped in a ``<mark>`` whose tooltip shows
    "<Variable Label>: <value>". When spans overlap, the one starting first is
    kept and the later ones are dropped. Each variable label gets a colour
    from ``COLORS`` in order of first appearance, cycling when needed.

    Args:
        cr_text: The report text to render.
        annotation: Mapping of variable name to ``{"span": ..., "value": ...}``.

    Returns:
        An HTML string. Without text, annotation or any matching span, the
        escaped text is returned in a plain ``<pre>`` block.
    """
    if not (cr_text and annotation):
        return f"<pre style='white-space:pre-wrap;'>{escape_html(cr_text)}</pre>"

    # Each candidate is (begin, end, span text, variable label, value string).
    candidates = []
    for var_name, var_data in annotation.items():
        if not (var_data and isinstance(var_data, dict)):
            continue
        span_text = var_data.get("span")
        value = var_data.get("value")
        if not (span_text and value and span_text in cr_text):
            continue
        begin = cr_text.find(span_text)
        candidates.append(
            (
                begin,
                begin + len(span_text),
                span_text,
                var_name.replace("_", " ").title(),
                str(value),
            )
        )

    if not candidates:
        return f"<pre style='white-space:pre-wrap;'>{escape_html(cr_text)}</pre>"

    # Stable sort on the start offset only, so ties keep annotation order.
    candidates.sort(key=lambda cand: cand[0])

    # Drop any span that starts before the previously kept span has ended.
    kept = []
    for cand in candidates:
        if kept and cand[0] < kept[-1][1]:
            continue
        kept.append(cand)

    pieces = []
    cursor = 0
    palette = {}
    for begin, end, span_text, label, value in kept:
        if begin > cursor:
            pieces.append(escape_html(cr_text[cursor:begin]))

        # A new label takes the next colour; len(palette) is its ordinal.
        colour = palette.setdefault(label, COLORS[len(palette) % len(COLORS)])
        pieces.append(
            f'<mark style="background:{colour};padding:1px 3px;border-radius:3px;" '
            f'title="{escape_html(label)}: {escape_html(value)}">'
            f'{escape_html(span_text)}</mark>'
        )
        cursor = end

    if cursor < len(cr_text):
        pieces.append(escape_html(cr_text[cursor:]))

    return f"<pre style='white-space:pre-wrap;line-height:1.6;'>{''.join(pieces)}</pre>"
|
|
|
|
|
|
|
|
def format_table(annotation):
    """Render the annotation mapping as a three-column HTML table.

    Columns are Variable (title-cased name with underscores turned into
    spaces), Value and Source (the span, cut to 60 characters plus "..." when
    longer). Entries without a value, and "not found" placeholders (span
    containing "pas de mention" or value containing "not performed"), show
    "/" as value and an empty source. Entries that are not non-empty dicts
    are skipped.

    Args:
        annotation: Mapping of variable name to ``{"value": ..., "span": ...}``.

    Returns:
        The table as an HTML string, or ``"<p>No annotations</p>"`` when
        ``annotation`` is empty or ``None``.
    """
    if not annotation:
        return "<p>No annotations</p>"

    rows = []
    for var_name, var_data in annotation.items():
        if not (var_data and isinstance(var_data, dict)):
            continue

        value = var_data.get("value")
        span = var_data.get("span", "")
        var_label = var_name.replace("_", " ").title()

        is_placeholder = (
            not value
            or bool(span and "pas de mention" in span.lower())
            or "not performed" in str(value).lower()
        )
        if is_placeholder:
            display_value, display_span = "/", ""
        elif span and len(span) > 60:
            display_value, display_span = str(value), span[:60] + "..."
        else:
            display_value, display_span = str(value), span or ""

        rows.append(
            f"""<tr>
                <td style="padding:6px 10px;border-bottom:1px solid #ddd;font-weight:500;">{escape_html(var_label)}</td>
                <td style="padding:6px 10px;border-bottom:1px solid #ddd;color:#1565C0;">{escape_html(display_value)}</td>
                <td style="padding:6px 10px;border-bottom:1px solid #ddd;color:#666;font-size:12px;font-style:italic;">{escape_html(display_span)}</td>
            </tr>"""
        )

    return f"""<table style="width:100%;border-collapse:collapse;font-size:13px;">
        <thead><tr style="background:#f5f5f5;">
            <th style="padding:8px 10px;text-align:left;border-bottom:2px solid #ddd;">Variable</th>
            <th style="padding:8px 10px;text-align:left;border-bottom:2px solid #ddd;">Value</th>
            <th style="padding:8px 10px;text-align:left;border-bottom:2px solid #ddd;">Source</th>
        </tr></thead>
        <tbody>{"".join(rows)}</tbody>
    </table>"""
|
|
|
|
|
|
|
|
def load_pancreas_dataset():
    """Load the ``train`` split of the pancreas annotations from the Hub.

    The repository is taken from ``PANCREAS_REPO_ID``. Progress is printed
    before and after the load.

    Returns:
        The dataset object returned by ``load_dataset``.
    """
    print(f"Loading pancreas dataset from {PANCREAS_REPO_ID}...")
    pancreas = load_dataset(PANCREAS_REPO_ID, split="train")
    print(f"Loaded {len(pancreas)} pancreas samples")
    return pancreas
|
|
|
|
|
|
|
|
def load_lymphome_dataset():
    """Load the lymphome annotations, preferring the Hub over the local file.

    The ``train`` split of ``LYMPHOME_REPO_ID`` is tried first. If that fails
    for any reason and ``LYMPHOME_LOCAL_JSONL`` exists, the local JSONL file
    is loaded instead.

    Returns:
        The dataset object returned by ``load_dataset``.

    Raises:
        Exception: The original Hub error is re-raised when no local file is
            available; an error from the local load propagates as well.
    """
    print(f"Loading lymphome dataset from {LYMPHOME_REPO_ID} (Hub)...")
    try:
        hub_data = load_dataset(LYMPHOME_REPO_ID, split="train")
        print(f"Loaded {len(hub_data)} lymphome samples from Hub")
        return hub_data
    except Exception as exc:
        print(f"Failed to load lymphome dataset from Hub: {exc}")
        # No fallback available: surface the Hub failure unchanged.
        if not LYMPHOME_LOCAL_JSONL.exists():
            raise
        print(f"Falling back to local lymphome JSONL at {LYMPHOME_LOCAL_JSONL}")
        local_data = load_dataset("json", data_files=str(LYMPHOME_LOCAL_JSONL), split="train")
        print(f"Loaded {len(local_data)} lymphome samples from local file")
        return local_data
|
|
|
|
|
|
|
|
def filter_indices(dataset, min_annotations):
    """Return the positions of samples that have enough real annotations.

    Args:
        dataset: Iterable of samples exposing ``.get("annotation", {})``.
        min_annotations: Minimum count from ``count_real_annotations`` for a
            sample to be kept.

    Returns:
        List of indices into ``dataset``, in iteration order.
    """
    kept = []
    for position, sample in enumerate(dataset):
        n_real = count_real_annotations(sample.get("annotation", {}))
        if n_real >= min_annotations:
            kept.append(position)
    return kept
|
|
|
|
|
|
|
|
def prepare_source(key, label, loader, min_annotations):
    """Load one dataset source and precompute the indices worth displaying.

    Any exception raised while loading or filtering is caught, printed and
    recorded in the result instead of propagating, so one failing source does
    not prevent the app from starting.

    Args:
        key: Source identifier; accepted for the caller's convenience and not
            used inside this function.
        label: Human-readable name used in log messages and in the result.
        loader: Zero-argument callable returning the dataset.
        min_annotations: Threshold forwarded to ``filter_indices``.

    Returns:
        Dict with keys ``label``, ``dataset``, ``filtered_indices``,
        ``min_annotations`` and ``error``. On failure ``dataset`` is ``None``,
        ``filtered_indices`` is empty and ``error`` holds the exception text;
        on success ``error`` is ``None``.
    """
    dataset, kept, error = None, [], None
    try:
        dataset = loader()
        kept = filter_indices(dataset, min_annotations)
        print(f"{label}: filtered to {len(kept)} samples with >= {min_annotations} annotations")
    except Exception as exc:
        print(f"Failed to load {label}: {exc}")
        # Reset everything: the loader may have succeeded before the failure.
        dataset, kept, error = None, [], str(exc)

    return {
        "label": label,
        "dataset": dataset,
        "filtered_indices": kept,
        "min_annotations": min_annotations,
        "error": error,
    }
|
|
|
|
|
|
|
|
# (key, tab label, loader) for every dataset shown in the app, in load order.
_SOURCE_SPECS = (
    ("pancreas", "Pancréas", load_pancreas_dataset),
    ("lymphome", "Lymphome", load_lymphome_dataset),
)

# Built once at import time: each entry is the dict returned by
# prepare_source(), so a source that failed to load is kept with its error.
SOURCES = {
    key: prepare_source(key, label, loader, MIN_ANNOTATIONS)
    for key, label, loader in _SOURCE_SPECS
}
|
|
|
|
|
|
|
|
def display_sample_for_source(source_key, slider_idx):
    """Build the three HTML panels for one sample of a dataset source.

    Args:
        source_key: Key into ``SOURCES``.
        slider_idx: Position within the source's filtered indices; converted
            with ``int()``.

    Returns:
        A 3-tuple ``(original_html, cr_html, table_html)``. When the source
        failed to load, has no qualifying samples, or the position is out of
        range, all three items are the same plain message.
    """
    source = SOURCES[source_key]

    notice = None
    if source["error"]:
        notice = f"Dataset unavailable: {source['error']}"
    elif not source["filtered_indices"]:
        notice = f"No samples with >= {source['min_annotations']} annotations."
    if notice is not None:
        return notice, notice, notice

    indices = source["filtered_indices"]
    position = int(slider_idx)
    if not 0 <= position < len(indices):
        return ("Invalid",) * 3

    # Map the slider position back to the row index in the full dataset.
    real_idx = indices[position]
    sample = source["dataset"][real_idx]

    annotation = sample.get("annotation", {})
    n_annotations = count_real_annotations(annotation)

    original_panel = (
        "<pre style='white-space:pre-wrap;line-height:1.6;'>"
        f"{escape_html(sample.get('original_text', ''))}</pre>"
    )
    header = f"<p><b>Sample #{real_idx}</b> — {n_annotations} annotations</p>"
    cr_panel = header + highlight_text(sample.get("CR", ""), annotation)

    return original_panel, cr_panel, format_table(annotation)
|
|
|
|
|
|
|
|
def build_tab(source_key):
    """Create the Gradio tab for one dataset source.

    The tab shows a warning when the source failed to load or has no
    qualifying samples. Otherwise it contains a sample slider and three HTML
    panels (original text, generated CR, extracted variables) that are
    refreshed when the slider changes and when the page loads.

    Must be called while the module-level ``demo`` Blocks context is open,
    because the load event is registered on ``demo``.

    Args:
        source_key: Key into ``SOURCES``.
    """
    source = SOURCES[source_key]
    label = source["label"]

    with gr.TabItem(label):
        if source["error"]:
            gr.Markdown(f"⚠️ Could not load {label} dataset: {escape_html(source['error'])}")
            return

        n_samples = len(source["filtered_indices"])
        min_required = source["min_annotations"]

        if n_samples == 0:
            gr.Markdown(f"⚠️ No samples with >= {min_required} annotations.")
            return

        gr.Markdown(
            f"Showing {n_samples} samples with >= "
            f"{min_required} annotations. Hover over highlights to see values."
        )

        with gr.Row():
            slider = gr.Slider(
                minimum=0,
                maximum=n_samples - 1,
                value=0,
                step=1,
                label="Sample",
            )

        with gr.Row():
            with gr.Column():
                gr.Markdown("### Original (English)")
                original_html = gr.HTML()
            with gr.Column():
                gr.Markdown("### Generated CR (French)")
                cr_html = gr.HTML()
            with gr.Column():
                gr.Markdown("### Extracted Variables")
                table_html = gr.HTML()

        # Same callback for both events: redraw the panels for the slider value.
        render = partial(display_sample_for_source, source_key)
        panels = [original_html, cr_html, table_html]

        slider.change(fn=render, inputs=[slider], outputs=panels)
        # Fill the panels on first page load, before any slider interaction.
        demo.load(fn=render, inputs=[slider], outputs=panels)
|
|
|
|
|
|
|
|
|
|
|
# Top-level UI definition. `demo` is the module-level Blocks app that
# build_tab() refers to when it registers its page-load event.
with gr.Blocks(title="Clinical Annotations Explorer", theme=gr.themes.Base()) as demo:
    gr.Markdown("# 🔬 Clinical Annotation Explorer")
    gr.Markdown(
        "Use the tabs below to switch between pancreas and lymphome annotations. "
        "Hover over highlights to see the extracted values."
    )

    # One tab per entry of SOURCES; the calls must stay inside this `with`
    # block so that `demo` is bound and the components attach to it.
    with gr.Tabs():
        build_tab("pancreas")
        build_tab("lymphome")

# Start the web server only when the file is run as a script.
if __name__ == "__main__":
    demo.launch()
|
|
|