#####################################################################################
######Creates a large mesh collating all aseg subcortices into one big surface#######
import pyvista as pv
import numpy as np
import pandas as pd
import os
import vtk
import tempfile
import time
from typing import Optional, Union, Sequence, Tuple
from pathlib import Path
from subcortexmesh import template_data_fetch
[docs]
def merge_all(
    inputdir: Union[str, Path],
    template: str,
    toolboxdata: Optional[Union[str, Path]] = None,
    metric: Union[str, Sequence[str]] = ('thickness', 'curvature', 'surfarea'),
    plot_merged: bool = False,
    overwrite: bool = True,
    silent: bool = False,
):
    """Merging all subcortical outputs into a single surface object

    This function creates a new mesh merging all subcortical meshes outputted by
    mesh_metrics() in a given template, for available metrics separately, keeping
    their vertex-wise values. Every vertex additionally receives an integer
    "roi_id" point array: the position of its ROI in the template's lookup table
    ("<merged name>_roi_id.txt" in the toolbox data). It will only work if all
    subcortices have been processed. The merged mesh is saved inside each subject
    directory of the input directory, as "allaseg_<metric>.vtk" (fsaverage) or
    "allfslfirst_<metric>.vtk" (fslfirst).

    For fslfirst, the cerebella need to have been created in FSL FIRST inside the same
    output directory as run_first_all's, naming them "*R_Cereb_first" and "*L_Cereb_first",
    and processed with subseg_getvol() and vol2surf(). See subseg_getvol()'s description
    for guidance.

    Parallel processes: to avoid conflicts, subjects will be skipped if a "isrunning" tmp
    file exists to mark them as currently processing. The tmp file is removed at the end
    (also when processing fails) or replaced if 1 hour old. If a process has been killed,
    remove the tmp manually to rerun a subject before the 1 hour (its path is printed
    when flagged).

    Parameters
    ----------
    inputdir : str, Path
        The path where the surface-based metrics were outputted (using mesh_metrics()),
        containing one sub-directory per subject. The outputdir will be the inputdir.
    template: str
        The name of the template the surfaces are supposed to be matching to. For FreeSurfer
        outputs, it is 'fsaverage'. For FSL FIRST, it is 'fslfirst'.
    toolboxdata : str, Path, optional
        The path of the "subcortexmesh_data" package data directory. The default path
        is assumed to be the user's home directory (pathlib's Path.home()). Users will
        be prompted to download it if not found.
    metric: str, Sequence
        The name(s) of the metric(s) to be merged. Options are "thickness",
        "curvature", "surfarea", and default is all of them. A single string is
        treated as one metric name.
    plot_merged: bool
        Whether to plot each merged mesh (with vis_merged()) after it is saved.
        Default is False.
    overwrite : bool
        Whether files are to be overwritten or skipped if already made. Default
        is True.
    silent : bool
        If True, messages about the process are not printed. Default is False.

    Raises
    ------
    ValueError
        If template is neither 'fsaverage' nor 'fslfirst'.
    """
    ###################################################################
    ###################################################################
    # template data is needed (it holds the ROI order lookup table)
    toolboxdata = template_data_fetch(datapath=toolboxdata, template=template)

    # merged mesh file prefix and number of ROI meshes expected per template
    template_settings = {
        'fsaverage': ('allaseg', 19),
        'fslfirst': ('allfslfirst', 17),
    }
    if template not in template_settings:
        raise ValueError(
            f"Unknown template '{template}': expected 'fsaverage' or 'fslfirst'."
        )
    mergedmesh, nroi = template_settings[template]

    # a lone metric name is one metric, not a string to substring-match against
    if isinstance(metric, str):
        metric = [metric]

    # ROI label order from the template lookup table; read lazily, at most once
    roi_order = None

    # Subfunctions
    def load_mesh(path):
        """Read a legacy-format .vtk polydata file."""
        reader = vtk.vtkPolyDataReader()
        reader.SetFileName(str(path))
        reader.Update()
        return reader.GetOutput()

    def mesh_merger(subdir, ordered_meshes):
        """Append the mesh files of subdir into one polydata.

        Each vertex is tagged with an integer "roi_id" point array equal to the
        position of its source file in ordered_meshes.
        """
        append_filter = vtk.vtkAppendPolyData()
        for roi_index, mesh_name in enumerate(ordered_meshes):
            polydata = load_mesh(os.path.join(subdir, mesh_name))
            # integer point array holding this ROI's index for every vertex
            tag_array = vtk.vtkIntArray()
            tag_array.SetName("roi_id")
            tag_array.SetNumberOfTuples(polydata.GetNumberOfPoints())
            tag_array.FillComponent(0, roi_index)
            polydata.GetPointData().AddArray(tag_array)
            append_filter.AddInputData(polydata)
        append_filter.Update()
        return append_filter.GetOutput()

    ###################################################################
    ###################################################################
    # run on subject output directories
    sub_list = [
        d for d in os.listdir(inputdir)
        if os.path.isdir(os.path.join(inputdir, d))
    ]
    for subindex, subid in enumerate(sub_list, start=1):
        # unique tmp file to avoid parallel loop conflicts
        fname = os.path.join(tempfile.gettempdir(), f"{subid}_isrunning_merge.tmp")
        if os.path.exists(fname):
            # skip the subject if another process flagged it less than 1h ago
            tmp_lifetime = (time.time() - os.path.getmtime(fname)) / 3600
            if tmp_lifetime < 1:
                if not silent:
                    print(f"{subid} already running (tmp file: {fname}).")
                continue
        # create the tmp file, or replace a stale one so its age restarts
        with open(fname, "w"):
            pass
        try:
            if not silent:
                print(f"Creating all-aseg surfaces for {subid}... [{subindex}/{len(sub_list)}]")
            subdir = os.path.join(inputdir, subid)
            for measure in ['thickness', 'surfarea', 'curvature']:
                if measure not in metric:
                    continue
                out_path = f"{inputdir}/{subid}/{mergedmesh}_{measure}.vtk"
                if os.path.exists(out_path) and not overwrite:
                    if not silent:
                        print(f"=> {measure} already merged")
                    continue
                # surface files of that metric; explicitly do not list the merged vtk
                mesh_list = [
                    f for f in os.listdir(subdir)
                    if measure in f and f.endswith(".vtk") and not f.startswith("all")
                ]
                if not mesh_list:
                    if not silent:
                        print(f"No mesh file (.vtk) found at all for {subid}'s surface {measure}.")
                    continue
                if len(mesh_list) != nroi:
                    if not silent:
                        print(f"=> {measure} ignored: all subcortices of the {template} template ({nroi}) must be available.")
                    continue
                if not silent:
                    print(f"=> Merging {measure} ...")
                if roi_order is None:
                    roi_lookup = pd.read_csv(
                        f"{toolboxdata}/template_data/{template}/surfaces/{mergedmesh}_roi_id.txt",
                        sep="\t",
                    )
                    roi_order = roi_lookup['label'].tolist()
                # force the files to follow the template's ROI order (lookup table),
                # so that roi_id i corresponds to the i-th lookup row
                mesh_list_sorted = [
                    f for roi in roi_order for f in mesh_list if f.startswith(roi)
                ]
                # each file must match exactly one label, otherwise roi_ids would shift
                if sorted(mesh_list_sorted) != sorted(mesh_list):
                    if not silent:
                        print(f"=> {measure} ignored: surface files do not match the {template} ROI labels one-to-one.")
                    continue
                merged_mesh = mesh_merger(subdir, mesh_list_sorted)
                # guarantee overwriting
                if os.path.exists(out_path):
                    os.remove(out_path)
                writer = vtk.vtkPolyDataWriter()
                writer.SetFileName(out_path)
                writer.SetInputData(merged_mesh)
                writer.Write()
                # plot only after saving, so an interactive window cannot block the write
                if plot_merged:
                    vis_merged(merged_vtk=merged_mesh)
        finally:
            # cleanup tmp file even on failure, so the subject is not locked for 1h
            try:
                os.remove(fname)
            except FileNotFoundError:
                pass
###################################################################
###################################################################
#interactive plot function that plots all surfaces together, with a slider to spread them out from each other
[docs]
def vis_merged(
    merged_vtk: Union[str, Path, vtk.vtkPolyData],
    cmap: str = "viridis",
    clim: Optional[Tuple[float, float]] = None,
    smooth_mesh: Optional[int] = 0
):
    """Interactive 3D viewer for a merged subcortical surface.

    Loads a merged mesh produced by merge_all(), separates ROIs by
    roi_id, and displays them with a slider (0 to 50) that moves every
    structure away from the global centroid along the direction of its
    own centroid.

    Authors: Charly H.A. Billaud, Nicolas P.M. Lavarde

    Parameters
    ----------
    merged_vtk : str, Path, vtk.vtkPolyData
        Path to the merged .vtk file produced by merge_all() or the VTK polydata variable itself.
    cmap: str
        Name of the matplotlib colormap used to color the surface values. Default is "viridis".
    clim: Tuple, optional
        Minimum and maximum value of the color bar. Default is the minimum and maximum
        of the mesh's active scalar values (NaNs ignored).
    smooth_mesh: int, optional
        Number of iterations of cosmetic (windowed sinc) smoothing to make the surface
        appear smoother. Default is 0 (no smoothing).

    Raises
    ------
    ValueError
        If the mesh has no 'roi_id' point array, no active scalar array, or no
        vertex with a non-negative roi_id.
    """
    # load from disk, or wrap the in-memory polydata
    if isinstance(merged_vtk, (str, Path)):
        reader = vtk.vtkPolyDataReader()
        reader.SetFileName(str(merged_vtk))
        reader.Update()
        mesh = pv.wrap(reader.GetOutput())
    else:
        mesh = pv.wrap(merged_vtk)

    # appearance smoother
    if smooth_mesh is not None and smooth_mesh > 0:
        smoother = vtk.vtkWindowedSincPolyDataFilter()
        smoother.SetInputData(mesh)
        smoother.SetNumberOfIterations(smooth_mesh)
        smoother.SetPassBand(0.001)
        smoother.NonManifoldSmoothingOn()
        smoother.NormalizeCoordinatesOn()
        smoother.Update()
        mesh = pv.wrap(smoother.GetOutput())

    # validate everything before a plotting window is created
    if 'roi_id' not in mesh.point_data.keys():
        raise ValueError(
            f"'roi_id' point array not found in {merged_vtk}. "
            "Make sure the file was produced by merge_all()."
        )
    measure = mesh.active_scalars_name
    if measure is None:
        raise ValueError(
            "No active scalar (surface-based value) found in the mesh. Make sure the file was produced by merge_all()."
        )
    roi_ids = np.array(mesh.point_data['roi_id']).astype(int)

    # one surface per ROI id; ids without vertices (or negative ids) are skipped
    wrapped_meshes = []
    n_roi = int(roi_ids.max()) + 1 if roi_ids.size else 0
    for roi_index in range(n_roi):
        mask = roi_ids == roi_index
        if not mask.any():
            continue
        sub_ug = mesh.extract_points(mask, adjacent_cells=False)
        wrapped_meshes.append(sub_ug.extract_surface(algorithm='dataset_surface'))
    if not wrapped_meshes:
        raise ValueError(f"No vertices with a valid 'roi_id' found in {merged_vtk}.")

    all_points = np.vstack([wm.points for wm in wrapped_meshes])
    global_centroid = all_points.mean(axis=0)
    original_points = [wm.points.copy() for wm in wrapped_meshes]

    # unit direction from the global centroid to each ROI centroid
    # (left as the zero vector when both centroids coincide)
    directions = []
    for wm in wrapped_meshes:
        offset = wm.points.mean(axis=0) - global_centroid
        norm = np.linalg.norm(offset)
        directions.append(offset / norm if norm > 0 else offset)

    # compute clim from the full mesh before splitting
    if clim is None:
        scalars_data = mesh.point_data[measure]
        clim = [np.nanmin(scalars_data), np.nanmax(scalars_data)]

    plotter = pv.Plotter()
    for wm in wrapped_meshes:
        plotter.add_mesh(wm, scalars=measure, cmap=cmap, clim=clim, nan_color='lightgrey')
    # Y flipped as VTK's coord syst not following RAS
    plotter.reset_camera()
    loc, foc, _ = plotter.camera_position
    plotter.camera_position = [loc, foc, (0, -1, 1)]

    def update_distance(distfactor):
        """Slider callback: shift every ROI by distfactor along its direction."""
        for wm, orig, direction in zip(wrapped_meshes, original_points, directions):
            wm.points[:] = orig + direction * distfactor
        plotter.render()

    plotter.add_slider_widget(update_distance, rng=[0, 50], value=0)
    if isinstance(merged_vtk, (str, Path)):
        plotter.show(title=f"{str(merged_vtk)} - {measure}")
    else:
        plotter.show(title=f"{measure}")
#####################################################################################
######Flat 2D grid preview of all subcortical ROIs from a merged VTK surface########
[docs]
def vis_merged_flat(
    merged_vtk: Union[str, Path, vtk.vtkPolyData],
    output_path: Union[str, Path] = "flat_plot.png",
    ncols: int = 4,
    silent: bool = False,
    scalars: Optional[str] = None,
    cmap: str = 'viridis',
    clim: Optional[Tuple[float, float]] = None,
    smooth_mesh: Optional[int] = 0,
    toolboxdata: Optional[Union[str, Path]] = None
):
    """Flat 2D grid preview of all subcortical ROIs from a merged VTK surface

    Reads a merged .vtk file produced by merge_all(), which contains all
    subcortical structures concatenated into a single mesh where each ROI
    was assigned a number ID (the "roi_id" point array). For each ROI, the
    corresponding sub-mesh is extracted, centered on its centroid, and placed
    in a 2D grid saved as PNG (rendered off-screen).

    Layout: one row per left/right pair of structures (top views of left and
    right in columns 1-2, bottom views of right and left in columns 3-4),
    then one row per unpaired structure (top view in column 2, bottom view
    in column 3). The scalar bar is drawn in the last cell of the last row.

    The template (fsaverage or fslfirst) is detected from the vertex count
    of the merged mesh, and ROI names are read from its lookup table.

    Authors: Nicolas P.M. Lavarde, Charly H.A. Billaud

    Parameters
    ----------
    merged_vtk : str, Path, vtk.vtkPolyData
        Path to the merged .vtk file produced by merge_all(), or the VTK polydata itself.
    output_path : str, Path
        Path for the output PNG file. Default is 'flat_plot.png'.
    ncols : int
        Currently unused: the layout always has 4 columns. Kept for backward
        compatibility.
    silent : bool
        Whether to suppress progress messages. Default is False.
    scalars: str, optional
        Name of the vertex-wise value to display. Default is the mesh's active scalar,
        i.e. the measure assigned by mesh_metrics() ('thickness', 'curvature', or
        'surfarea'). Can also be the 'roi_id' assigned by merge_all().
    cmap: str
        Name of the matplotlib colormap used to color the surface values. Default is "viridis".
    clim: Tuple, optional
        Minimum and maximum value of the color bar. Default is the minimum and maximum
        of the displayed values over the whole mesh (NaNs ignored).
    smooth_mesh: int, optional
        Number of iterations of cosmetic (windowed sinc) smoothing to make the surface
        appear smoother. Default is 0 (no smoothing).
    toolboxdata : str, Path, optional
        The path of the "subcortexmesh_data" package data directory. The default path
        is assumed to be the user's home directory (pathlib's Path.home()). Users will
        be prompted to download it if not found.

    Raises
    ------
    ValueError
        If the mesh lacks the 'roi_id' array or an active scalar, or if its vertex
        count matches neither supported template.
    """
    ###################################################################
    ##########################LOAD MERGED MESH#########################
    if isinstance(merged_vtk, (str, Path)):
        reader = vtk.vtkPolyDataReader()
        reader.SetFileName(str(merged_vtk))
        reader.Update()
        mesh = pv.wrap(reader.GetOutput())
    else:
        mesh = pv.wrap(merged_vtk)
    if 'roi_id' not in mesh.point_data.keys():
        raise ValueError(
            f"'roi_id' point array not found in {merged_vtk}. "
            "Make sure the file was produced by merge_all()."
        )
    roi_ids = np.array(mesh.point_data['roi_id']).astype(int)
    # appearance smoother (moves points only, so roi_ids stay aligned)
    if smooth_mesh is not None and smooth_mesh > 0:
        smoother = vtk.vtkWindowedSincPolyDataFilter()
        smoother.SetInputData(mesh)
        smoother.SetNumberOfIterations(smooth_mesh)
        smoother.SetPassBand(0.001)
        smoother.NonManifoldSmoothingOn()
        smoother.NormalizeCoordinatesOn()
        smoother.Update()
        mesh = pv.wrap(smoother.GetOutput())

    ###################################################################
    ##############################DEFINE SCALARS#######################
    # default scalar is the metric stored as the active scalar
    measure = mesh.active_scalars_name
    if measure is None:
        raise ValueError(
            "No active scalar (surface-based value) found in the mesh. Make sure the file was produced by merge_all().")
    if scalars is None:
        scalars = measure
    else:
        mesh.set_active_scalars(scalars)
    scalars_data = mesh.point_data[scalars]
    # compute clim from the full mesh before splitting
    if clim is None:
        clim = [np.nanmin(scalars_data), np.nanmax(scalars_data)]

    ###################################################################
    ######################EXTRACT ROI SUB-MESHES#######################
    # the template is recognised from the vertex count of its merged mesh
    template_by_npoints = {
        95718: ('fsaverage', 'allaseg', 19),
        82412: ('fslfirst', 'allfslfirst', 17),
    }
    detected = template_by_npoints.get(mesh.n_points)
    if detected is None:
        raise ValueError(
            f"Cannot auto-detect n_roi for mesh with {mesh.n_points} vertices. "
            "Expected 95718 (fsaverage, 19 ROIs) or 82412 (fslfirst, 17 ROIs)."
        )
    template, mergedmesh, n_roi = detected
    toolboxdata = template_data_fetch(datapath=toolboxdata, template=template)
    roi_id_path = f"{toolboxdata}/template_data/{template}/surfaces/{mergedmesh}_roi_id.txt"
    roi_lookup = pd.read_csv(roi_id_path, sep='\t')
    roi_names = dict(zip(roi_lookup['id'].astype(int), roi_lookup['label']))

    # submeshes[i] is the centered surface of roi_id i, or None if it has no vertices
    submeshes = []
    for i in range(n_roi):
        mask = roi_ids == i
        if not mask.any():
            if not silent:
                print(f" ROI {i}: no vertices found, skipping.")
            submeshes.append(None)
            continue
        # extract_points keeps only faces whose ALL vertices are in the mask
        sub_ug = mesh.extract_points(mask, adjacent_cells=False)
        sub = sub_ug.extract_surface(algorithm='dataset_surface').copy()
        # center each structure at origin so cells don't overlap
        sub.points -= sub.points.mean(axis=0)
        submeshes.append(sub)

    ###################################################################
    ####################BUILD PAIR/SINGLETON ROWS######################
    # Pass 1: map stripped names to their left/right roi_id
    left_map = {}
    right_map = {}
    for roi_id, label in roi_names.items():
        ll = label.lower()
        if ll.startswith("left-"):
            left_map[ll[5:]] = roi_id
        elif ll.startswith("right-"):
            right_map[ll[6:]] = roi_id
    # Pass 2: build ordered row list — pairs first (roi_id order), singletons last
    seen = set()
    rows = []
    singleton_ids = []
    for roi_id in sorted(roi_names.keys()):
        if roi_id in seen:
            continue
        ll = roi_names[roi_id].lower()
        if ll.startswith("left-"):
            right_id = right_map.get(ll[5:])
            if right_id is not None:
                rows.append(('pair', roi_id, right_id))
                seen.add(right_id)
            else:
                singleton_ids.append(roi_id)
        elif ll.startswith("right-"):
            left_id = left_map.get(ll[6:])
            # only fires when right precedes left in roi_id order (atypical)
            if left_id is not None and left_id not in seen:
                rows.append(('pair', left_id, roi_id))
                seen.add(left_id)
            else:
                singleton_ids.append(roi_id)
        else:
            singleton_ids.append(roi_id)
        seen.add(roi_id)
    rows.extend(('singleton', roi_id) for roi_id in singleton_ids)

    ###################################################################
    ##########################GRID RENDERING###########################
    nrows = len(rows)
    plotter = pv.Plotter(
        shape=(nrows, 4),
        off_screen=True,
        window_size=(4 * 300, nrows * 300),
    )
    try:
        # applies to every subplot renderer
        plotter.set_background("white")
        # fixed orthographic cameras; Y flipped as VTK's coord syst not following RAS
        cam_top = ((0, -500, 0), (0, 0, 0), (0, -1, -1))
        cam_bottom = ((0, 500, 0), (0, 0, 0), (0, -1, -1))

        def _place_mesh(sub_idx, col, row_idx, camera, text_label, view_label, is_empty_label=False):
            """Draw ROI sub_idx (or an '(empty)' note) in cell (row_idx, col)."""
            plotter.subplot(row_idx, col)
            sub = submeshes[sub_idx] if sub_idx < len(submeshes) else None
            if sub is not None:
                plotter.add_mesh(
                    sub,
                    show_edges=False,
                    smooth_shading=True,
                    ambient=0.3,
                    diffuse=0.7,
                    scalars=sub.point_data[scalars],
                    cmap=cmap,
                    clim=clim,
                    nan_color='lightgrey',
                    show_scalar_bar=False,
                )
                plotter.add_text(text_label, position="upper_edge", font_size=8, color="black")
                plotter.add_text(view_label, position="lower_edge", font_size=8, color="black")
            else:
                plotter.add_text(
                    f"{text_label}\n(empty)" if not is_empty_label else text_label,
                    position="upper_edge",
                    font_size=8,
                    color="gray")
            plotter.camera.parallel_projection = True
            plotter.camera_position = camera
            plotter.reset_camera()

        for row_idx, row_def in enumerate(rows):
            if row_def[0] == 'pair':
                _, left_id, right_id = row_def
                left_label = roi_names.get(left_id, f"ROI {left_id}")
                right_label = roi_names.get(right_id, f"ROI {right_id}")
                _place_mesh(left_id, 0, row_idx, cam_top, left_label, "top view")
                _place_mesh(right_id, 1, row_idx, cam_top, right_label, "top view")
                _place_mesh(right_id, 2, row_idx, cam_bottom, right_label, "bottom view")
                _place_mesh(left_id, 3, row_idx, cam_bottom, left_label, "bottom view")
            else:  # singleton: centered in columns 2-3
                _, roi_id = row_def
                label = roi_names.get(roi_id, f"ROI {roi_id}")
                _place_mesh(roi_id, 1, row_idx, cam_top, label, "top view")
                _place_mesh(roi_id, 2, row_idx, cam_bottom, label, "bottom view")

        # add scalar bar to last square (empty when the last row is a singleton)
        plotter.subplot(nrows - 1, 3)
        plotter.add_text(scalars, position="upper_edge", font_size=8, color="black")
        plotter.add_scalar_bar(
            vertical=True,
            position_x=0.3,
            position_y=0.15,
            width=0.8,
            height=0.6,
            fmt="%.2f",
            label_font_size=10
        )
        plotter.screenshot(str(output_path))
    finally:
        # release the off-screen render window even if rendering failed
        plotter.close()
    if not silent:
        print(f"Saved flat plot to: {output_path}")