Source code for multimodars.ccta.labeling

from __future__ import annotations


from pathlib import Path
import trimesh
import numpy as np

from ..multimodars import (
    find_centerline_bounded_points_simple,
    remove_occluded_points_ray_triangle,
    clean_outlier_points,
    build_adjacency_map,
    find_points_by_cl_region,
    PyCenterline,
)
from .._converters import numpy_to_centerline
from ..io.read_geometrical import read_mesh
from .debug_plots import plot_results_key


def _load_mesh(source: Path | str | trimesh.Trimesh) -> trimesh.Trimesh:
    """Return *source* as a mesh, reading it with :func:`read_mesh` when it is a path.

    A one-line size summary is printed. Any error raised while reading the
    file is reported on stdout and then re-raised unchanged.
    """
    if isinstance(source, trimesh.Trimesh):
        print(
            f"Using provided mesh: {len(source.vertices)} vertices, {len(source.faces)} faces"
        )
        return source
    try:
        mesh = read_mesh(source)
        print(f"Loaded mesh: {len(mesh.vertices)} vertices, {len(mesh.faces)} faces")
    except Exception as e:
        print(f"Error reading CCTA mesh from {source}: {e}")
        raise
    return mesh


def _load_centerline(
    source: Path | str | PyCenterline | np.ndarray,
    display_name: str,
    error_name: str,
) -> PyCenterline:
    """Return *source* as a ``PyCenterline``.

    ``PyCenterline`` objects are used as-is. NumPy arrays are converted with
    :func:`numpy_to_centerline`. Anything else is treated as a path to a
    comma-delimited file, read with ``np.genfromtxt`` and then converted.

    Parameters
    ----------
    source : Path, str, PyCenterline or numpy.ndarray
        Centerline object, array, or path to a CSV file.
    display_name : str
        Name used in the informational messages (e.g. ``"aorta"``).
    error_name : str
        Name used in the error message (e.g. ``"Aorta"``).

    Raises
    ------
    Exception
        Any error from reading or converting the file is printed and re-raised.
    """
    if isinstance(source, PyCenterline):
        centerline = source
        print(f"Using provided {display_name} centerline: {len(centerline.points)} points")
        return centerline
    if isinstance(source, np.ndarray):
        centerline = numpy_to_centerline(source)
        print(f"Using provided {display_name} centerline: {len(centerline.points)} points")
        return centerline
    try:
        raw = np.genfromtxt(source, delimiter=",")
        centerline = numpy_to_centerline(raw)
        print(f"Loaded {display_name} centerline: {len(centerline.points)} points")
    except Exception as e:
        print(f"Error reading {error_name} centerline from {source}: {e}")
        raise
    return centerline


def _remove_intramural_points(
    mesh: trimesh.Trimesh,
    cl_coronary: PyCenterline,
    cl_aorta: PyCenterline,
    points_found,
    n_points_intramural: int,
    tol: float,
    vessel_name: str,
):
    """Strip the intramural course of an anomalous coronary from *points_found*.

    Only the mesh faces touching *points_found* are handed to
    ``remove_occluded_points_ray_triangle``. According to this module's own
    notes, that routine creates rays between the aortic and coronary
    centerlines and removes faces when 3 consecutive faces are "pierced" by
    a ray.

    Returns
    -------
    tuple
        ``(kept_points, removed_points)``. *kept_points* is whatever the routine
        returned. *removed_points* lists the entries of *points_found* that were
        dropped, in their original order.
    """
    print(f"Applying occlusion removal for anomalous {vessel_name}...")
    faces_for_rust = _prepare_faces_for_rust(mesh, points=points_found, tol=tol)
    kept_points = remove_occluded_points_ray_triangle(
        centerline_coronary=cl_coronary,
        centerline_aorta=cl_aorta,
        range_coronary=n_points_intramural,
        points=points_found,
        faces=faces_for_rust,
    )
    # A set gives O(1) membership tests; the list lookup was O(len(kept_points)).
    kept_set = set(kept_points)
    removed_points = [p for p in points_found if p not in kept_set]
    print(f"{vessel_name}: relabeled {len(removed_points)} points in intramural course")
    return kept_points, removed_points


def label_geometry(
    path_ccta_geometry: Path | str | trimesh.Trimesh,
    path_centerline_aorta: Path | str | PyCenterline | np.ndarray,
    path_centerline_rca: Path | str | PyCenterline | np.ndarray,
    path_centerline_lca: Path | str | PyCenterline | np.ndarray,
    anomalous_rca: bool = False,
    anomalous_lca: bool = False,
    n_points_intramural: int = 120,
    bounding_sphere_radius_mm: float = 3.0,
    tolerance_float: float = 1e-6,
    control_plot: bool = True,
) -> tuple[dict, tuple[PyCenterline, PyCenterline, PyCenterline]]:
    """Label CCTA mesh vertices as aorta, RCA, or LCA using centerline-based region detection.

    Loads a 3-D surface mesh and three centerlines (aorta, RCA, LCA), then
    assigns each mesh vertex to one of the anatomical regions. Candidate
    coronary vertices are collected with a rolling sphere of fixed radius
    around the coronary centerline points.

    For anomalous vessels, an extra occlusion-removal step uses ray-triangle
    intersection to strip intramural segments. To do so, a ray is cast from
    every aorta point to the centerline points of the anomalous section. If
    3 faces are intersected by the ray, the points of the first face are taken
    to belong to the intramural section.

    Outlier cleaning follows, and then an adjacency-map reclassification
    that cleans up isolated mis-labelled vertices.

    Parameters
    ----------
    path_ccta_geometry : Path, str or trimesh.Trimesh
        A mesh object, or a path to the CCTA surface mesh file (any format
        supported by :func:`multimodars.io.read_geometrical.read_mesh`).
    path_centerline_aorta : Path, str, PyCenterline or numpy.ndarray
        The aortic centerline: a ``PyCenterline``, an array accepted by
        ``numpy_to_centerline``, or a path to a CSV file (comma-delimited,
        columns: x, y, z, …).
    path_centerline_rca : Path, str, PyCenterline or numpy.ndarray
        The RCA centerline, in the same forms as *path_centerline_aorta*.
    path_centerline_lca : Path, str, PyCenterline or numpy.ndarray
        The LCA centerline, in the same forms as *path_centerline_aorta*.
    anomalous_rca : bool, optional
        When ``True``, applies ray-triangle occlusion removal to the RCA region
        to handle anomalous (intramural) courses. Default is ``False``.
    anomalous_lca : bool, optional
        When ``True``, applies ray-triangle occlusion removal to the LCA region.
        Default is ``False``.
    n_points_intramural : int, optional
        Number of coronary centerline points examined during occlusion
        removal (the intramural segment length). Default is ``120``.
    bounding_sphere_radius_mm : float, optional
        Radius, in millimetres, of the rolling sphere used to collect candidate
        mesh vertices around each centerline point. Default is ``3.0``.
    tolerance_float : float, optional
        Distance tolerance used when matching mesh vertices to points during
        face lookup. Default is ``1e-6``.
    control_plot : bool, optional
        When ``True``, opens an interactive 3-D scene showing the labelled mesh
        after processing. Default is ``True``.

    Returns
    -------
    results : dict
        Dictionary produced by the final adjacency-based reclassification,
        with keys:

        * ``"mesh"`` - the original :class:`trimesh.Trimesh` object.
        * ``"aorta_points"`` - list of ``(x, y, z)`` tuples for aortic vertices.
        * ``"rca_points"`` - list of ``(x, y, z)`` tuples for RCA vertices.
        * ``"lca_points"`` - list of ``(x, y, z)`` tuples for LCA vertices.
        * ``"rca_removed_points"`` - RCA vertices removed by occlusion
          detection and not restored by the reclassification.
        * ``"lca_removed_points"`` - LCA vertices removed by occlusion
          detection and not restored by the reclassification.
    centerlines : tuple
        A 3-tuple ``(cl_rca, cl_lca, cl_aorta)`` of ``PyCenterline`` objects.

    Raises
    ------
    Exception
        Re-raises any error that occurs while reading the mesh or centerline
        files, after printing a descriptive message.
    """
    mesh = _load_mesh(path_ccta_geometry)
    # Loading order (aorta, LCA, RCA) is kept so console output is unchanged.
    cl_aorta = _load_centerline(path_centerline_aorta, "aorta", "Aorta")
    cl_lca = _load_centerline(path_centerline_lca, "LCA", "LCA")
    cl_rca = _load_centerline(path_centerline_rca, "RCA", "RCA")

    points_list = [tuple(vertex) for vertex in mesh.vertices.tolist()]

    # Rust implementation using a rolling sphere with fixed radius
    rca_points_found = find_centerline_bounded_points_simple(
        cl_rca, points_list, bounding_sphere_radius_mm
    )
    lca_points_found = find_centerline_bounded_points_simple(
        cl_lca, points_list, bounding_sphere_radius_mm
    )
    print(f"\nRCA points found: {len(rca_points_found)}")
    print(f"LCA points found: {len(lca_points_found)}")

    rca_removed_points = []
    lca_removed_points = []
    if anomalous_rca:
        final_rca_points_found, rca_removed_points = _remove_intramural_points(
            mesh,
            cl_rca,
            cl_aorta,
            rca_points_found,
            n_points_intramural=n_points_intramural,
            tol=tolerance_float,
            vessel_name="RCA",
        )
    else:
        final_rca_points_found = rca_points_found.copy()

    if anomalous_lca:
        final_lca_points_found, lca_removed_points = _remove_intramural_points(
            mesh,
            cl_lca,
            cl_aorta,
            lca_points_found,
            n_points_intramural=n_points_intramural,
            tol=tolerance_float,
            vessel_name="LCA",
        )
    else:
        final_lca_points_found = lca_points_found.copy()

    print("\nRemoving LCA and RCA island points...")
    aortic_points = _find_aortic_points(
        mesh.vertices, final_rca_points_found, final_lca_points_found
    )
    print(f"length before: {len(final_lca_points_found)}")
    final_lca_points, final_aortic_points = clean_outlier_points(
        final_lca_points_found, aortic_points, 2.0, 0.4
    )
    # based on patient data, only precleaning anyways, rest done by final_reclassification
    final_rca_points, _ = clean_outlier_points(
        final_rca_points_found, final_aortic_points, 2.0, 0.4
    )
    final_aortic_points = _find_aortic_points(
        mesh.vertices, final_rca_points, final_lca_points
    )
    # add also the rca_removed points and lca_removed points to aortic points
    final_aortic_points = list(
        set(final_aortic_points) | set(rca_removed_points) | set(lca_removed_points)
    )
    print(f"length after: {len(final_lca_points)}")

    results = {
        "mesh": mesh,
        "aorta_points": final_aortic_points,
        # Use the outlier-cleaned RCA set, mirroring the LCA entry; previously the
        # uncleaned list was stored here and the RCA cleaning had no effect.
        "rca_points": final_rca_points,
        "lca_points": final_lca_points,
        "rca_removed_points": rca_removed_points,
        "lca_removed_points": lca_removed_points,
    }

    # final reclassification based on adjacency map
    print("\nApplying final reclassification based on adjacency map...")
    new_results = _final_reclassification(results)
    print(f"aorta_points:{len(new_results['aorta_points'])}")
    print(f"rca_points:{len(new_results['rca_points'])}")
    print(f"lca_points:{len(new_results['lca_points'])}")
    print(f"rca_removed_points:{len(new_results['rca_removed_points'])}")
    print(f"lca_removed_points:{len(new_results['lca_removed_points'])}")

    if control_plot:
        plot_results_key(
            new_results,
            aorta_points=True,
            rca_points=True,
            lca_points=True,
            rca_removed_points=True,
            proximal_points=True,
            distal_points=False,
            anomalous_points=False,
            cl_rca=cl_rca,
            cl_lca=cl_lca,
            cl_aorta=cl_aorta,
        )

    return new_results, (cl_rca, cl_lca, cl_aorta)
def _prepare_faces_for_rust(
    mesh: trimesh.Trimesh, *, points=None, face_indices=None, tol: float = 1e-6
):
    """Convert selected mesh faces to the Rust-friendly format.

    Parameters
    ----------
    mesh : trimesh.Trimesh
        Source mesh whose faces will be converted.
    points : list of tuple, optional
        If provided and *face_indices* is ``None``, the face indices are
        derived by finding which faces reference vertices closest to these
        points (within *tol*).
    face_indices : iterable of int, optional
        Explicit face indices to convert. When given, *points* is ignored.
        When both are ``None``, all faces are converted.
    tol : float, optional
        Distance tolerance for vertex matching when using *points*.
        Default is ``1e-6``.

    Returns
    -------
    list of tuple
        Each element is a ``((x0,y0,z0), (x1,y1,z1), (x2,y2,z2))`` triple of
        vertex coordinate tuples (Python floats), suitable for passing to Rust
        functions.
    """
    if face_indices is None:
        if points is not None:
            face_indices = _find_faces_for_points(mesh, points, tol=tol)
        else:
            face_indices = list(range(len(mesh.faces)))

    index_array = np.asarray(list(face_indices), dtype=np.intp)
    if index_array.size == 0:
        return []

    # Fancy indexing gathers all selected triangles at once:
    # faces[...] -> (k, 3) vertex ids, vertices[...] -> (k, 3, 3) coordinates.
    corner_ids = np.asarray(mesh.faces)[index_array][:, :3]
    triangles = np.asarray(mesh.vertices, dtype=np.float64)[corner_ids]
    # .tolist() yields plain Python floats, as the original float() mapping did.
    return [tuple(tuple(corner) for corner in tri) for tri in triangles.tolist()]


def _find_faces_for_points(mesh: trimesh.Trimesh, points_found, tol: float = 1e-6):
    """Find face indices whose vertices are within tolerance of the given points.

    For each point in *points_found* the nearest mesh vertex is located. Any
    face that references at least one of those vertices is included in the
    result.

    Points that coincide exactly with a mesh vertex are resolved through a
    coordinate lookup. Only the remaining points fall back to a full distance
    scan over all vertices. When several vertices share the same coordinates,
    the lowest index is used, which is the same tie-break ``np.argmin`` applies.

    Parameters
    ----------
    mesh : trimesh.Trimesh
        Mesh to search.
    points_found : array-like of shape (N, 3)
        Query points.
    tol : float, optional
        Maximum distance from a query point to a mesh vertex for the vertex to
        be considered a match. Default is ``1e-6``.

    Returns
    -------
    list of int
        Ascending indices into ``mesh.faces`` for all faces that contain at
        least one matched vertex. Returns an empty list when *points_found* is
        empty, the mesh has no vertices or faces, or no vertices fall within
        *tol*.
    """
    points_array = np.asarray(points_found, dtype=np.float64)
    if points_array.size == 0:
        return []
    points_array = np.atleast_2d(points_array)

    verts = np.asarray(mesh.vertices, dtype=np.float64)
    if len(verts) == 0:
        return []

    # Exact-coordinate fast path: query points usually come straight from
    # mesh.vertices, so most need no O(V) distance scan. setdefault keeps the
    # first occurrence, matching np.argmin's tie-break. The path is skipped
    # when tol < 0 (or NaN), where the original never matched anything.
    first_index_of = {}
    if tol >= 0:
        for idx, coord in enumerate(map(tuple, verts.tolist())):
            first_index_of.setdefault(coord, idx)

    found_vertex_indices = set()
    for p in points_array:
        idx = first_index_of.get(tuple(p.tolist()))
        if idx is None:
            distances = np.linalg.norm(verts - p, axis=1)
            idx = int(np.argmin(distances))
            # `not <=` also rejects NaN distances, as the original comparison did.
            if not distances[idx] <= tol:
                continue
        found_vertex_indices.add(idx)

    if not found_vertex_indices:
        return []

    faces = np.asarray(mesh.faces)
    if len(faces) == 0:
        return []
    wanted = np.fromiter(found_vertex_indices, dtype=np.intp, count=len(found_vertex_indices))
    # A face matches if any of its three corners is a matched vertex.
    touches = np.isin(faces[:, :3], wanted).any(axis=1)
    return np.flatnonzero(touches).tolist()


def _find_aortic_points(all_vertices, rca_points, lca_points):
    """Return mesh vertices that belong neither to the RCA nor to the LCA region.

    Parameters
    ----------
    all_vertices : array-like of shape (N, 3)
        All vertex coordinates of the mesh.
    rca_points : list of tuple
        Vertices classified as RCA.
    lca_points : list of tuple
        Vertices classified as LCA.

    Returns
    -------
    list of tuple
        ``(x, y, z)`` tuples, in vertex order, for vertices not present in
        *rca_points* or *lca_points*.
    """
    coronary = set(rca_points) | set(lca_points)
    return [vertex for vertex in map(tuple, all_vertices) if vertex not in coronary]


def _final_reclassification(results: dict) -> dict:
    """Refine vertex labels using a mesh adjacency map.

    Applies two adjacency-based correction rules:

    * **Logic A** - An isolated RCA or LCA vertex (no same-label neighbours)
      is re-assigned to the aorta class.
    * **Logic B** - A vertex that was removed by occlusion detection, but
      whose neighbours are predominantly (> 70 %) the corresponding coronary
      label, is restored to that label.

    Rules are evaluated against the labels as they were *before* this pass,
    so a change made for one vertex does not influence its neighbours here.

    Parameters
    ----------
    results : dict
        Dictionary produced by :func:`label_geometry` containing keys
        ``"mesh"``, ``"rca_points"``, ``"lca_points"``,
        ``"rca_removed_points"``, and ``"lca_removed_points"``. Any
        ``"aorta_points"`` entry is ignored; every vertex not listed elsewhere
        starts as aorta.

    Returns
    -------
    dict
        New dictionary with the same keys as *results* plus
        ``"aorta_points"``, with corrected point lists.
    """
    mesh = results["mesh"]
    vertices = mesh.vertices
    n_vertices = len(vertices)

    # 1. Create a coordinate -> index map for fast lookup
    coord_to_idx = {tuple(coord): i for i, coord in enumerate(vertices)}

    # 2. Create the initial label array (default 0 = aorta). Later keys
    #    overwrite earlier ones for vertices listed more than once.
    labels = np.zeros(n_vertices, dtype=np.uint8)
    for key, label in (
        ("rca_points", 1),
        ("lca_points", 2),
        ("rca_removed_points", 3),
        ("lca_removed_points", 4),
    ):
        for pt in results[key]:
            idx = coord_to_idx.get(pt)
            if idx is not None:
                labels[idx] = label

    # 3. Build Adjacency Map
    adj_map = build_adjacency_map(mesh.faces.tolist())
    new_labels = labels.copy()

    # 4. Apply logic
    for i in range(n_vertices):
        neighbors = list(adj_map.get(i, []))
        if not neighbors:
            continue

        neighbor_labels = labels[neighbors]
        current_label = labels[i]

        # LOGIC A: Isolated RCA/LCA -> Aorta
        if current_label == 1 and not np.any(neighbor_labels == 1):
            new_labels[i] = 0
        elif current_label == 2 and not np.any(neighbor_labels == 2):
            new_labels[i] = 0

        # LOGIC B: removed RCA/LCA point whose neighbours are mostly (> 70 %)
        # the matching coronary label -> restore that label
        elif current_label == 3:
            coronary_neighbors = np.sum(neighbor_labels == 1)
            if coronary_neighbors > (len(neighbors) * 0.7):
                new_labels[i] = 1
        elif current_label == 4:
            coronary_neighbors = np.sum(neighbor_labels == 2)
            if coronary_neighbors > (len(neighbors) * 0.7):
                new_labels[i] = 2

    # 5. Convert back to coordinate lists; flatnonzero yields ascending
    #    indices, so vertex order is the same as a full scan.
    def _coords_with(label):
        return [tuple(vertices[i]) for i in np.flatnonzero(new_labels == label)]

    updated_results = {
        "mesh": mesh,
        "rca_points": _coords_with(1),
        "lca_points": _coords_with(2),
        "rca_removed_points": _coords_with(3),
        "lca_removed_points": _coords_with(4),
    }
    updated_results["aorta_points"] = _coords_with(0)
    return updated_results
def label_anomalous_region(
    centerline,
    frames,
    results: dict,
    results_key: str = "rca_points",
    debug_plot: bool = False,
) -> dict:
    """Split a coronary region into proximal, anomalous and distal vertex sets.

    The intravascular imaging *frames* are passed, together with *centerline*,
    to ``find_points_by_cl_region``. That routine decides where the anomalous
    (intramural) segment begins and ends and tags every vertex listed under
    ``results[results_key]`` with its sub-region.

    Parameters
    ----------
    centerline : PyCenterline
        Centerline of the coronary vessel of interest.
    frames : list of PyFrame
        Ordered list of intravascular imaging frames for the vessel.
    results : dict
        Labelled results dictionary (e.g. from :func:`label_geometry`). It must
        contain ``"mesh"`` and the key named by *results_key*.
    results_key : str, optional
        Key in *results* whose point list is partitioned.
        Default is ``"rca_points"``.
    debug_plot : bool, optional
        When ``True``, opens an interactive visualisation of the three
        sub-regions. Default is ``False``.

    Returns
    -------
    dict
        The same *results* object, modified in place, with three new keys:

        * ``"proximal_points"`` - vertices proximal to the anomalous segment.
        * ``"distal_points"`` - vertices distal to the anomalous segment.
        * ``"anomalous_points"`` - vertices within the anomalous segment.

        ``"aorta_points"`` is also overwritten. It becomes every mesh vertex
        that appears in none of ``"rca_points"``, ``"lca_points"`` (when
        present) or the three new lists.
    """
    regions = find_points_by_cl_region(
        centerline=centerline,
        frames=frames,
        points=results[results_key],
    )
    proximal_points, distal_points, anomalous_points = regions

    results.update(
        proximal_points=proximal_points,
        distal_points=distal_points,
        anomalous_points=anomalous_points,
    )

    # Everything not claimed by a coronary list is treated as aorta.
    coronary = set()
    for key in ("rca_points", "lca_points"):
        coronary.update(results.get(key, []))
    for region_points in (proximal_points, distal_points, anomalous_points):
        coronary.update(region_points)

    mesh_vertices = results["mesh"].vertices
    results["aorta_points"] = [
        vertex for vertex in map(tuple, mesh_vertices) if vertex not in coronary
    ]

    print("\nApplying anomalous labeling based on aligned intravascular frames...")
    for region_key in ("proximal_points", "distal_points", "anomalous_points"):
        print(f"{region_key}: {len(results[region_key])}")

    if debug_plot:
        plot_results_key(
            results=results,
            aorta_points=False,
            rca_points=False,
            lca_points=False,
            rca_removed_points=False,
            proximal_points=True,
            distal_points=True,
            anomalous_points=True,
            cl_rca=centerline,
            cl_lca=None,
            cl_aorta=None,
        )

    return results
def label_branches(
    centerline,
    results: dict,
    results_key: str = "rca_points",
    branch_id: int | list[int] = 0,
    bounding_sphere_radius_mm: float = 3.0,
) -> dict:
    """Partition a coronary region into main branch and per-side-branch point sets.

    Parameters
    ----------
    centerline : PyCenterline
        Centerline of the coronary vessel of interest.
    results : dict
        Labelled results dictionary (e.g. from :func:`label_geometry`). Must
        contain the key specified by *results_key*.
    results_key : str, optional
        Key in *results* whose point list is partitioned.
        Default is ``"rca_points"``.
    branch_id : int or list of int, optional
        Branch index, or list of branch indices, whose combined points form
        the main branch (e.g. ``[0, 1]`` for LAD + LCx). Python ints and
        NumPy integer scalars are both accepted. Default is ``0``.
    bounding_sphere_radius_mm : float, optional
        Radius of the rolling sphere used to collect candidate vertices.
        Default is ``3.0``.

    Returns
    -------
    dict
        The same *results* object, modified in place, extended with:

        * ``"{results_key}_main"`` - vertices along the main branch(es).
        * ``"{results_key}_side"`` - all remaining vertices (aggregate).
        * ``"{results_key}_side_{k}"`` - vertices near side branch *k*, one
          key per side branch discovered in *centerline*. A point may appear
          in more than one side-branch set when it sits near a bifurcation;
          the Voronoi inside :func:`discretize_vessel` resolves the
          assignment.
    """
    # NumPy integer scalars are not ``int`` instances; without this check
    # ``list(np.int64(0))`` would raise TypeError.
    if isinstance(branch_id, (int, np.integer)):
        branch_ids = [int(branch_id)]
    else:
        branch_ids = list(branch_id)
    main_id_set = set(branch_ids)

    coronary_points = results[results_key]

    # Collect main-branch points.
    main_set: set = set()
    for bid in branch_ids:
        branch = centerline.get_branch(bid)
        points_found = find_centerline_bounded_points_simple(
            branch, coronary_points, bounding_sphere_radius_mm
        )
        main_set.update(points_found)

    main_points = [p for p in coronary_points if p in main_set]
    side_points = [p for p in coronary_points if p not in main_set]

    results[f"{results_key}_main"] = main_points
    results[f"{results_key}_side"] = side_points

    # Split side points per individual side branch.
    n_branches = len(centerline.branch_start_indices)
    side_branch_ids = [k for k in range(n_branches) if k not in main_id_set]

    print(f"\nBranch labeling for '{results_key}' (branch_ids={branch_ids}):")
    print(f"  {results_key}_main: {len(main_points)}")
    print(f"  {results_key}_side: {len(side_points)}")

    for k in side_branch_ids:
        branch_k = centerline.get_branch(k)
        branch_k_points = find_centerline_bounded_points_simple(
            branch_k, side_points, bounding_sphere_radius_mm
        )
        results[f"{results_key}_side_{k}"] = branch_k_points
        print(f"  {results_key}_side_{k}: {len(branch_k_points)}")

    return results