|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import json |
|
|
import nibabel as nib |
|
|
import torch |
|
|
import os |
|
|
import argparse |
|
|
import os |
|
|
import json |
|
|
import torch |
|
|
import numpy as np |
|
|
import nibabel as nib |
|
|
import torch.nn.functional as F |
|
|
|
|
|
import numpy as np |
|
|
import scipy.ndimage as ndi |
|
|
from skimage.morphology import ball |
|
|
from skimage.measure import label |
|
|
from scipy.ndimage import label as cc_label |
|
|
import numpy as np |
|
|
from scipy.ndimage import binary_dilation, binary_erosion |
|
|
from skimage.morphology import ball |
|
|
from collections import Counter |
|
|
|
|
|
def multiple_lesions_corrected(label_np, lesion_label=23, lung_labels=[28,29,30,31,32], change_percent=20):
    """
    Grow or shrink every connected lesion by roughly ``change_percent`` percent.

    Each connected component of ``lesion_label`` is processed on its own with
    face-connected (6-neighbourhood in 3D) morphological steps:

    * ``change_percent < 0``: the lesion is eroded layer by layer until its
      volume is at or below the target. Voxels that are removed receive the
      most frequent lung label found within a radius-3 neighbourhood of the
      lesion (``30`` when no lung voxel is nearby; ties resolve to the
      smallest label value). Erosion stops early rather than deleting the
      lesion completely, unless the target volume itself is 0.
    * ``change_percent >= 0``: the lesion is dilated layer by layer, and only
      voxels that are lung tissue (or already part of this lesion) may be
      claimed. Growth stops when the target is reached, when the next layer
      would exceed the target by more than 10 %, or when no further voxel can
      be added.

    Because whole layers are added or removed, the final volume only
    approximates the target.

    Args:
        label_np: Label array (any dimensionality; typically a 3D volume).
            It is not modified; a copy is returned.
        lesion_label: Label value for lesions/nodules (default: 23).
        lung_labels: Lung tissue label values (default: [28,29,30,31,32]).
            The list is only read, never mutated.
        change_percent: Percentage change (positive for growth, negative for
            shrinking).

    Returns:
        Modified copy of ``label_np`` with adjusted lesion sizes.
    """
    # Function-scope import keeps this block self-contained; scipy.ndimage is
    # already a dependency of this module.
    from scipy.ndimage import generate_binary_structure

    label_np = label_np.copy()
    lung_mask = np.isin(label_np, lung_labels)
    cc, num_lesions = cc_label(label_np == lesion_label)

    if num_lesions == 0:
        print("No lesions found.")
        return label_np

    ndim = label_np.ndim
    # Face-connected unit element; in 3D this is the same footprint as a
    # radius-1 ball.
    struct = generate_binary_structure(ndim, 1)
    # Radius-3 sphere used to sample the labels surrounding a lesion.
    probe_radius = 3
    axes = np.ogrid[(slice(-probe_radius, probe_radius + 1),) * ndim]
    probe = sum(axis ** 2 for axis in axes) <= probe_radius ** 2

    scale = 1 + change_percent / 100.0

    for i in range(1, num_lesions + 1):
        single_lesion_mask = cc == i
        original_volume = int(single_lesion_mask.sum())
        if original_volume == 0:
            continue

        print(f"Processing lesion {i}: original volume = {original_volume} voxels")

        target_volume = int(original_volume * scale)
        print(f"Target volume for lesion {i}: {target_volume} voxels ({scale:.2f}x original)")

        current_mask = single_lesion_mask.copy()
        current_volume = original_volume

        if change_percent < 0:
            # Pick the lung label that replaces the voxels being removed.
            border = binary_dilation(single_lesion_mask, structure=probe) & ~single_lesion_mask
            neighbors = label_np[border]
            lung_neighbors = neighbors[np.isin(neighbors, lung_labels)]
            if lung_neighbors.size:
                values, counts = np.unique(lung_neighbors, return_counts=True)
                fill_label = values[np.argmax(counts)]
            else:
                fill_label = 30

            # Erosion strictly reduces a non-empty mask, so this terminates.
            # The eroded mask is always a subset of the lesion, so no lung
            # constraint is needed here (the lesion itself is not in
            # ``lung_mask``; intersecting with it would wipe the lesion).
            while current_volume > target_volume:
                eroded = binary_erosion(current_mask, structure=struct)
                eroded_volume = int(eroded.sum())
                if eroded_volume == 0 and target_volume > 0:
                    print(f"⚠️ Lesion {i}: further erosion would remove the lesion; keeping {current_volume} voxels (target: {target_volume})")
                    break
                current_mask, current_volume = eroded, eroded_volume

            # Replace the whole lesion with lung tissue; the surviving core
            # is re-labelled below.
            label_np[single_lesion_mask] = fill_label
        else:
            # The lesion's own voxels must stay eligible, otherwise each step
            # would drop the core and only the newly added shell would be
            # counted towards the target.
            allowed = lung_mask | single_lesion_mask

            for _ in range(min(1000, target_volume)):
                if current_volume >= target_volume:
                    print(f"✅ Lesion {i}: target volume reached at {current_volume} voxels (target: {target_volume})")
                    break

                grown = binary_dilation(current_mask, structure=struct) & allowed
                grown_volume = int(grown.sum())

                if grown_volume > target_volume * 1.1:
                    print(f"✅ Lesion {i}: stopping to avoid overshoot. Current: {current_volume}, next would be: {grown_volume}, target: {target_volume}")
                    break

                if grown_volume == current_volume:
                    # Nothing could be added, so repeating the step cannot
                    # help either.
                    print(f"⚠️ Lesion {i}: growth stopped by boundaries at {current_volume} voxels (target: {target_volume})")
                    break

                current_mask, current_volume = grown, grown_volume

        label_np[current_mask] = lesion_label

        actual_ratio = current_volume / original_volume
        print(f"Lesion {i} final result: {original_volume} → {current_volume} voxels ({actual_ratio:.2f}x, target was {scale:.2f}x)")

    return label_np
|
|
|
|
|
import numpy as np |
|
|
from scipy.ndimage import distance_transform_edt, label as cc_label, binary_dilation, label |
|
|
from skimage.morphology import ball |
|
|
from collections import Counter |
|
|
|
|
|
def shrink_lesions_preserve_shape_connectivity(label_np, lesion_label=23, lung_labels=[28,29,30,31,32], shrink_percent=50, min_keep_voxels=10):
    """
    Shrink each connected lesion towards its core using a distance transform.

    For every connected component of ``lesion_label`` the voxels farthest
    from the lesion boundary are kept until the target volume is reached
    (``original * (1 - shrink_percent / 100)``, but never fewer than
    ``min_keep_voxels``). Of the kept voxels only the largest connected
    component survives. All other voxels of the lesion are replaced by the
    most frequent lung label found within a radius-3 neighbourhood of the
    lesion (``30`` when no lung voxel is nearby; ties resolve to the smallest
    label value).

    A lesion whose shrunk core is empty or smaller than ``min_keep_voxels``
    is skipped, i.e. left exactly as it was in the input.

    Args:
        label_np (np.ndarray): Label array (any dimensionality; typically a
            3D volume). It is not modified; a copy is returned.
        lesion_label (int): Label used for lesions (default: 23).
        lung_labels (list): Lung region labels (default: 28-32). Only read.
        shrink_percent (float): Percentage to shrink (e.g., 50).
        min_keep_voxels (int): Minimum voxels to keep in a shrunk lesion.

    Returns:
        np.ndarray: Updated copy of the label array.
    """
    label_np = label_np.copy()
    cc, num_lesions = cc_label(label_np == lesion_label)

    if num_lesions == 0:
        print("No lesions found.")
        return label_np

    # Radius-3 sphere used to sample the labels surrounding a lesion.
    probe_radius = 3
    axes = np.ogrid[(slice(-probe_radius, probe_radius + 1),) * label_np.ndim]
    probe = sum(axis ** 2 for axis in axes) <= probe_radius ** 2

    for i in range(1, num_lesions + 1):
        lesion_i_mask = cc == i
        voxel_indices = np.argwhere(lesion_i_mask)
        original_volume = len(voxel_indices)
        if original_volume == 0:
            continue

        target_volume = int(original_volume * (1 - shrink_percent / 100.0))
        # Clamp at 0 so a negative count can never turn into a
        # "drop the last n" slice below.
        target_volume = max(target_volume, min_keep_voxels, 0)

        # Distance of every lesion voxel to the lesion boundary. argwhere and
        # boolean indexing both walk the array in C order, so the two line up.
        distances = distance_transform_edt(lesion_i_mask)[lesion_i_mask]
        # Stable sort keeps the result deterministic when distances tie.
        order = np.argsort(-distances, kind="stable")
        core_voxels = voxel_indices[order[:target_volume]]

        # The kept voxels are a subset of the lesion, so they need no lung
        # check (lesion voxels are, by definition, not labelled as lung).
        shrunk_mask = np.zeros(label_np.shape, dtype=bool)
        shrunk_mask[tuple(core_voxels.T)] = True

        cc_shrunk, num_cc = cc_label(shrunk_mask)
        if num_cc == 0:
            print(f"⚠️ Lesion {i} lost all connectivity, skipped.")
            continue

        # Component sizes in one pass; index 0 is the background.
        sizes = np.bincount(cc_shrunk.ravel())[1:]
        largest_cc = cc_shrunk == (np.argmax(sizes) + 1)
        kept_volume = int(sizes.max())
        if kept_volume < min_keep_voxels:
            print(f"⚠️ Lesion {i} shrunk below min threshold, skipped.")
            continue

        # Only now touch the labels, so skipped lesions stay intact.
        border = binary_dilation(lesion_i_mask, structure=probe) & ~lesion_i_mask
        neighbors = label_np[border]
        lung_neighbors = neighbors[np.isin(neighbors, lung_labels)]
        if lung_neighbors.size:
            values, counts = np.unique(lung_neighbors, return_counts=True)
            fill_label = values[np.argmax(counts)]
        else:
            fill_label = 30

        label_np[lesion_i_mask] = fill_label
        label_np[largest_cc] = lesion_label
        print(f"✅ Lesion {i}: shrunk from {original_volume} → {kept_volume} voxels")

    return label_np
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
|
from datetime import datetime |
|
|
|
|
|
def augment_and_save_masks_from_json(json_path, dict_to_read, data_root, lunglesion_lbl, scale_percent, mode, save_dir, log_file=None, random_seed=None, prefix="aug_"):
    """
    Resize the lesions of every mask listed in a JSON file and save the results.

    The JSON file is expected to map ``dict_to_read`` to a list of entries,
    each with a ``'label'`` key holding the mask path relative to
    ``data_root``. Every mask is loaded with nibabel, processed with
    ``shrink_lesions_preserve_shape_connectivity`` (``mode='shrink'``) or
    ``multiple_lesions_corrected`` (``mode='grow'``), and written to
    ``save_dir`` as ``prefix + <original file name>`` using the source
    image's affine and header.

    Args:
        json_path: Path to the JSON file describing the dataset.
        dict_to_read: Key in the JSON document whose list is processed.
        data_root: Directory that the ``'label'`` paths are relative to.
        lunglesion_lbl: Label value of the lesions to resize.
        scale_percent: Percentage passed to the shrink/grow routine.
        mode: ``'shrink'`` or ``'grow'``.
        save_dir: Output directory (created if missing).
        log_file: Optional path of a log file; when given, progress is
            written there for the duration of this call.
        random_seed: Accepted for interface compatibility; not used by this
            function.
        prefix: Prefix prepended to each output file name.

    Raises:
        ValueError: If ``mode`` is neither ``'shrink'`` nor ``'grow'``.
    """
    # Fail before any I/O: an unknown mode would otherwise surface as a
    # NameError (or silently reuse the previous entry's result).
    if mode not in ("shrink", "grow"):
        raise ValueError(f"Unsupported mode {mode!r}; expected 'shrink' or 'grow'.")

    logger = logging.getLogger(__name__)
    file_handler = None
    if log_file:
        file_handler = logging.FileHandler(log_file)
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
        logger.addHandler(file_handler)
        logger.setLevel(logging.INFO)

    try:
        start_time = datetime.now()
        logger.info(f"Starting augmentation process at {start_time}")
        logger.info(f"Parameters: json_path={json_path}, dict_to_read={dict_to_read}, scale_percent={scale_percent}%, mode={mode}")

        with open(json_path, 'r') as f:
            data = json.load(f)

        entries = data[dict_to_read]
        total = len(entries)
        logger.info(f"Loaded JSON file with {total} entries")

        os.makedirs(save_dir, exist_ok=True)

        for idx, mask_entry in enumerate(entries):
            logger.info(f"Processing entry {idx + 1}/{total}: {mask_entry['label']}")

            mask_path = os.path.join(data_root, mask_entry['label'])
            nii = nib.load(mask_path)
            mask_data = nii.get_fdata()

            if mode == 'shrink':
                augmented_np = shrink_lesions_preserve_shape_connectivity(mask_data, lesion_label=lunglesion_lbl, lung_labels=[28,29,30,31,32], shrink_percent=scale_percent, min_keep_voxels=10)
            else:
                augmented_np = multiple_lesions_corrected(mask_data, lesion_label=lunglesion_lbl, lung_labels=[28,29,30,31,32], change_percent=scale_percent)

            original_volume = np.sum(mask_data == lunglesion_lbl)
            augmented_volume = np.sum(augmented_np == lunglesion_lbl)
            volume_ratio = 100 * augmented_volume / original_volume if original_volume > 0 else 0

            logger.info(f"Original lesion volume: {original_volume} voxels")
            logger.info(f"Augmented lesion volume: {augmented_volume} voxels")
            logger.info(f"Volume ratio: {volume_ratio:.2f}% of original")

            # get_fdata() always yields floats. When the source stores
            # integers, cast back so the labels are written verbatim instead
            # of being rescaled to fit the header's integer dtype.
            disk_dtype = nii.get_data_dtype()
            if np.issubdtype(disk_dtype, np.integer):
                augmented_np = np.rint(augmented_np).astype(disk_dtype)

            new_path = os.path.join(save_dir, prefix + os.path.basename(mask_path))
            nib.save(nib.Nifti1Image(augmented_np, nii.affine, nii.header), new_path)
            logger.info(f"Augmented and saved: {new_path}")

        end_time = datetime.now()
        logger.info(f"Augmentation process completed at {end_time}")
        logger.info(f"Total processing time: {end_time - start_time}")
        logger.info(f"Successfully processed {total} files")
    finally:
        # Detach and close the handler so repeated calls neither duplicate
        # log lines nor leak file descriptors.
        if file_handler is not None:
            logger.removeHandler(file_handler)
            file_handler.close()
|
|
|
|
|
def main():
    """Parse the command-line arguments and run the mask augmentation."""
    parser = argparse.ArgumentParser(description="Augment and save masks from JSON config.")
    parser.add_argument("--json_path", required=True, help="Path to the input JSON file.")
    parser.add_argument("--dict_to_read", required=True, help="Dictionary key to read in JSON.")
    parser.add_argument("--data_root", required=True, help="Root directory for mask files.")
    parser.add_argument("--lunglesion_lbl", type=int, required=True, help="Label value of the lung lesions to resize.")
    parser.add_argument("--scale_percent", type=int, required=True, help="Percentage by which lesions are shrunk or grown.")
    parser.add_argument('--mode', type=str, choices=['shrink', 'grow'], required=True, help="Operation to perform: 'shrink' or 'grow'.")
    parser.add_argument("--save_dir", required=True, help="Directory to save augmented masks.")
    parser.add_argument("--random_seed", type=int, default=None, help="Random seed (optional).")
    parser.add_argument("--prefix", default="aug_", help="Prefix for output files (optional).")
    parser.add_argument("--log_file", default=None, help="Log file path (optional).")

    args = parser.parse_args()

    augment_and_save_masks_from_json(
        json_path=args.json_path,
        dict_to_read=args.dict_to_read,
        data_root=args.data_root,
        lunglesion_lbl=args.lunglesion_lbl,
        scale_percent=args.scale_percent,
        mode=args.mode,
        save_dir=args.save_dir,
        log_file=args.log_file,
        random_seed=args.random_seed,
        prefix=args.prefix,
    )


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|
|
|
def improved_grow_logic(label_np, lesion_label=23, lung_labels=[28,29,30,31,32], change_percent=50):
    """
    Grow every connected lesion into surrounding lung tissue.

    Each connected component of ``lesion_label`` is dilated one
    face-connected layer at a time. Only voxels that carry a lung label (or
    already belong to the lesion) may be claimed. Growth stops as soon as the
    lesion reaches ``original * (1 + change_percent / 100)`` voxels or no
    further voxel can be added, so the final volume can exceed the target by
    up to one layer.

    Args:
        label_np: Label array (any dimensionality; typically a 3D volume).
            It is not modified; a copy is returned.
        lesion_label: Label value for lesions (default: 23).
        lung_labels: Lung tissue label values (default: [28,29,30,31,32]).
            Only read, never mutated.
        change_percent: Percentage by which each lesion should grow.

    Returns:
        Modified copy of ``label_np`` with the grown lesions.
    """
    # Function-scope import keeps this block self-contained; scipy.ndimage is
    # already a dependency of this module.
    from scipy.ndimage import generate_binary_structure

    label_np = label_np.copy()
    lung_mask = np.isin(label_np, lung_labels)
    cc, num_lesions = cc_label(label_np == lesion_label)

    # Face-connected unit element; in 3D this is the same footprint as a
    # radius-1 ball.
    struct = generate_binary_structure(label_np.ndim, 1)

    for i in range(1, num_lesions + 1):
        single_lesion_mask = cc == i
        original_volume = int(single_lesion_mask.sum())
        if original_volume == 0:
            continue

        target_volume = int(original_volume * (1 + change_percent / 100.0))
        # The lesion's own voxels must stay eligible, otherwise each step
        # would drop the core and only the newly added shell would be
        # counted towards the target.
        allowed = lung_mask | single_lesion_mask

        current_mask = single_lesion_mask.copy()
        current_volume = original_volume

        for _ in range(min(1000, target_volume)):
            grown = binary_dilation(current_mask, structure=struct) & allowed
            grown_volume = int(grown.sum())

            if grown_volume == current_volume:
                # Nothing could be added, so repeating the step cannot help.
                print(f"⚠️ Lesion {i}: growth stopped by boundaries at {current_volume} voxels")
                break

            current_mask, current_volume = grown, grown_volume

            if current_volume >= target_volume:
                print(f"✅ Lesion {i}: target volume reached at {current_volume} voxels")
                break

        label_np[current_mask] = lesion_label

        volume_ratio = current_volume / original_volume
        print(f"Lesion {i}: {original_volume} → {current_volume} voxels ({volume_ratio:.2f}x)")

    return label_np