# Source code for spline_mesher

import os
from typing import Dict, Optional, Tuple

from numpy import float64, ndarray, uint64
from SimpleITK.SimpleITK import Image

# Set the NUMEXPR_MAX_THREADS environment variable to 16 before the remaining
# imports below run (presumably so that any numexpr-backed dependency picks it
# up at import time -- TODO confirm which dependency needs it).
os.environ["NUMEXPR_MAX_THREADS"] = "16"

import logging
import pickle
import sys
import time
from itertools import chain
from pathlib import Path

import gmsh
import numpy as np
import plotly.io as pio
import SimpleITK as sitk

from pyhexspline import cortical_sanity as csc
from pyhexspline.gmsh_mesh_builder import Mesher, TrabecularVolume
from pyhexspline.quad_refinement import QuadRefinement
from pyhexspline.spline_volume import OCC_volume

# Select plotly's "browser" renderer as the default for figures shown by this module.
pio.renderers.default = "browser"
# Name of the logger that HexMesh.mesher() obtains via logging.getLogger().
LOGGING_NAME = "MESHING"
# The imports above intentionally follow the os.environ assignment, hence E402 is silenced.
# flake8: noqa: E402


class HexMesh:
    """
    Drive the cortical/trabecular meshing pipeline for one image.

    The class only stores its configuration; all the work happens in
    :meth:`mesher`, which reads the cortical mask, fits the contour splines,
    builds the Gmsh geometry and returns the resulting mesh data.

    Args:
        settings_dict (dict): Meshing settings. ``mesher`` reads (among others)
            the keys ``aspect``, ``_slice``, ``undersampling``,
            ``slicing_coefficient``, ``n_elms_longitudinal``, ``mesh_order``,
            ``trab_refinement`` and ``center_square_length_factor``.
        img_dict (dict): Path settings with the keys ``img_basepath``,
            ``img_basefilename``, ``outputpath`` and ``meshpath``.
        sitk_image (Optional[Image]): Image to mesh. When ``None`` the mask is
            read from ``<img_basepath>/<name>/<name>_CORT_MASK_cap.mhd``.
        logger (Optional[logging.Logger]): Logger used for the quad-refinement
            progress messages. When ``None`` the ``LOGGING_NAME`` logger is
            used instead.
    """

    def __init__(
        self,
        settings_dict: dict,
        img_dict: dict,
        sitk_image: Optional[Image] = None,
        logger: Optional[logging.Logger] = None,
    ):
        self.settings_dict = settings_dict
        self.img_dict = img_dict
        self.logger = logger
        # ``None`` means: read the image from disk inside ``mesher``.
        self.sitk_image = sitk_image

    def _resolve_logger(self, fallback: logging.Logger) -> logging.Logger:
        """Return the user-supplied logger, or ``fallback`` when none was given."""
        return self.logger if self.logger is not None else fallback

    def _resolve_paths(
        self,
    ) -> Tuple[Optional[Image], Optional[str], str, str, str]:
        """
        Derive the input/output paths used by :meth:`mesher` from ``img_dict``.

        Returns:
            Tuple: ``(sitk_image, img_path_ext, filepath_ext,
            geo_unrolled_filename, mesh_file_path)``. Exactly one of
            ``sitk_image`` and ``img_path_ext`` is ``None``: the path to the
            ``_CORT_MASK_cap.mhd`` file is only built when no in-memory image
            was supplied.
        """
        img_basepath = self.img_dict["img_basepath"]
        img_basefilename = self.img_dict["img_basefilename"]

        if self.sitk_image is None:
            sitk_image = None
            img_path_ext = str(
                Path(
                    img_basepath,
                    img_basefilename,
                    img_basefilename + "_CORT_MASK_cap.mhd",
                )
            )
        else:
            sitk_image = self.sitk_image
            img_path_ext = None

        filepath_ext = str(Path(img_basepath, img_basefilename, img_basefilename))
        geo_unrolled_filename = str(
            Path(
                self.img_dict["outputpath"],
                img_basefilename,
                img_basefilename + ".geo_unrolled",
            )
        )
        mesh_file_path = str(
            Path(
                self.img_dict["meshpath"],
                img_basefilename,
                img_basefilename,
            )
        )
        return (
            sitk_image,
            img_path_ext,
            filepath_ext,
            geo_unrolled_filename,
            mesh_file_path,
        )

    def mesher(
        self,
    ) -> Tuple[
        Dict[uint64, ndarray],
        Dict[uint64, ndarray],
        int,
        Dict[int, ndarray],
        Dict[int, ndarray],
        ndarray,
        ndarray,
        float64,
        float64,
        Dict[uint64, ndarray],
        Dict[uint64, ndarray],
        ndarray,
    ]:
        """
        Perform the meshing process for cortical and trabecular volumes.

        This function is the entry point for the meshing process, which
        includes reading image data, generating splines, performing sanity
        checks, and creating the mesh using Gmsh. It also handles the
        configuration of various meshing parameters and settings.

        Side effects: initialises and finalises Gmsh, writes
        ``<mesh_file_path>_gmsh.log``, optionally writes ``.msh``/``.inp``/
        ``.vtk`` files (``write_mesh``) and optionally opens the Gmsh GUI
        (``show_gmsh``).

        Returns:
            Tuple: A tuple containing the following elements:
                - nodes (Dict[uint64, ndarray]): Dictionary of node tags.
                - elms (Dict[uint64, ndarray]): Dictionary of element tags.
                - nb_nodes (int): Number of nodes in the model.
                - centroids_cort_dict (Dict[int, ndarray]): Dictionary of
                  cortical element centroids.
                - centroids_trab_dict (Dict[int, ndarray]): Dictionary of
                  trabecular element centroids.
                - elm_vol_cort (ndarray): Array of cortical element volumes.
                - elm_vol_trab (ndarray): Array of trabecular element volumes.
                - radius_roi_cort (float64): Radius of the largest cortical
                  ROI in mm.
                - radius_roi_trab (float64): Radius of the largest trabecular
                  ROI in mm.
                - bnds_bot (Dict[uint64, ndarray]): Dictionary of bottom
                  boundary nodes.
                - bnds_top (Dict[uint64, ndarray]): Dictionary of top boundary
                  nodes.
                - reference_point_coord (ndarray): Coordinates of the
                  reference point in mm.

        Raises:
            AssertionError: If the number of volumes and centroids does not
                match.

        Note:
            Negative element volumes are reported with a CRITICAL log message
            but do not raise; the caller still receives the mesh.
        """
        logger = logging.getLogger(LOGGING_NAME)
        # ``self.logger`` is optional; fall back to the module logger so the
        # quad-refinement messages never hit ``None.info``.
        user_logger = self._resolve_logger(logger)
        logger.info("Starting meshing script...")
        start = time.time()

        (
            sitk_image,
            img_path_ext,
            filepath_ext,
            geo_unrolled_filename,
            mesh_file_path,
        ) = self._resolve_paths()

        cfg = self.settings_dict

        # Spline_volume settings
        ASPECT = int(cfg["aspect"])
        SLICE = int(cfg["_slice"])
        UNDERSAMPLING = int(cfg["undersampling"])
        SLICING_COEFFICIENT = int(cfg["slicing_coefficient"])
        INSIDE_VAL = float(cfg["inside_val"])
        OUTSIDE_VAL = float(cfg["outside_val"])
        LOWER_THRESH = float(cfg["lower_thresh"])
        UPPER_THRESH = float(cfg["upper_thresh"])
        S = int(cfg["s"])
        K = int(cfg["k"])
        INTERP_POINTS = int(cfg["interp_points"])
        DP_SIMPLIFICATION_OUTER = int(cfg["dp_simplification_outer"])
        DP_SIMPLIFICATION_INNER = int(cfg["dp_simplification_inner"])
        SHOW_PLOTS = bool(cfg["show_plots"])
        SHOW_GMSH = bool(cfg["show_gmsh"])
        WRITE_MESH = bool(cfg["write_mesh"])
        THICKNESS_TOL = float(cfg["thickness_tol"])
        PHASES = int(cfg["phases"])

        # spline_mesher settings
        N_LONGITUDINAL = int(cfg["n_elms_longitudinal"] // SLICING_COEFFICIENT)
        N_TRANSVERSE_CORT = int(cfg["n_elms_transverse_cort"])
        N_TRANSVERSE_TRAB = int(cfg["n_elms_transverse_trab"])
        N_RADIAL = int(cfg["n_elms_radial"] // 4)
        ELM_ORDER = int(cfg["mesh_order"])
        QUAD_REFINEMENT = bool(cfg["trab_refinement"])
        MESH_ANALYSIS = bool(cfg["mesh_analysis"])
        ELLIPSOID_FITTING = bool(cfg["ellipsoid_fitting"])

        # ! obscured from settings by design
        DEBUG_ORIENTATION = 0  # 0: no debug, 1: debug

        cortical_v = OCC_volume(
            sitk_image,
            img_path_ext,
            filepath_ext,
            geo_unrolled_filename,
            ASPECT=ASPECT,
            SLICE=SLICE,
            UNDERSAMPLING=UNDERSAMPLING,
            SLICING_COEFFICIENT=SLICING_COEFFICIENT,
            INSIDE_VAL=INSIDE_VAL,
            OUTSIDE_VAL=OUTSIDE_VAL,
            LOWER_THRESH=LOWER_THRESH,
            UPPER_THRESH=UPPER_THRESH,
            S=S,
            K=K,
            INTERP_POINTS=INTERP_POINTS,
            DP_SIMPLIFICATION_OUTER=DP_SIMPLIFICATION_OUTER,
            DP_SIMPLIFICATION_INNER=DP_SIMPLIFICATION_INNER,
            debug_orientation=DEBUG_ORIENTATION,
            show_plots=SHOW_PLOTS,
            thickness_tol=THICKNESS_TOL,
            phases=PHASES,
        )

        if cortical_v.sitk_image is None:
            img = sitk.PermuteAxes(sitk.ReadImage(cortical_v.img_path), [1, 2, 0])
        else:
            img = self.sitk_image

        if cortical_v.show_plots is True:
            cortical_v.plot_mhd_slice(img)
        else:
            logger.info(f"MHD slice\t\t\tshow_plots:\t{cortical_v.show_plots}")

        image_pad = cortical_v.pad_and_plot(img)
        cortical_ext, cortical_int = cortical_v.volume_splines_optimized(image_pad)

        # Cortical surface sanity check
        cortex = csc.CorticalSanityCheck(
            MIN_THICKNESS=cortical_v.MIN_THICKNESS,
            ext_contour=cortical_ext,
            int_contour=cortical_int,
            model=cortical_v.filename,
            save_plot=False,
            logger=logger,
        )

        # One chunk per unique z value of the external contour.
        z_unique = np.unique(cortical_ext[:, 2])
        cortical_ext_split = np.array_split(cortical_ext, len(z_unique))
        cortical_int_split = np.array_split(cortical_int, len(z_unique))
        cortical_int_sanity = np.zeros(np.shape(cortical_int_split))

        for i, _ in enumerate(cortical_ext_split):
            # The sanity check fills every column but the last; the last
            # column is copied over from the internal contour unchanged.
            cortical_int_sanity[i][:, :-1] = cortex.cortical_sanity_check(
                ext_contour=cortical_ext_split[i],
                int_contour=cortical_int_split[i],
                iterator=i,
                show_plots=False,
            )
            cortical_int_sanity[i][:, -1] = cortical_int_split[i][:, -1]
        cortical_int_sanity = cortical_int_sanity.reshape(-1, 3)

        gmsh.initialize()
        gmsh.clear()
        gmsh.option.setNumber("General.NumThreads", 6)
        gmsh.logger.start()

        mesher = Mesher(
            geo_unrolled_filename,
            mesh_file_path,
            slicing_coefficient=cortical_v.SLICING_COEFFICIENT,
            n_longitudinal=N_LONGITUDINAL,
            n_transverse_trab=N_TRANSVERSE_TRAB,
            n_transverse_cort=N_TRANSVERSE_CORT,
            n_radial=N_RADIAL,
            ellipsoid_fitting=ELLIPSOID_FITTING,
        )

        n_slices = len(cortical_ext_split)
        cortex_centroid = np.zeros((n_slices, 3))
        cortical_int_sanity_split = np.array_split(
            cortical_int_sanity, len(np.unique(cortical_int_sanity[:, 2]))
        )

        ext_shape = np.shape(cortical_ext_split)
        int_shape = np.shape(cortical_int_split)
        cortical_ext_centroid = np.zeros((ext_shape[0], ext_shape[1], ext_shape[2]))
        cortical_int_centroid = np.zeros((int_shape[0], int_shape[1], int_shape[2]))

        idx_list_ext = np.zeros((n_slices, 4), dtype=int)
        idx_list_int = np.zeros((n_slices, 4), dtype=int)
        intersections_ext = np.zeros((n_slices, 2, 2, 3), dtype=float)
        intersections_int = np.zeros((n_slices, 2, 2, 3), dtype=float)

        for i, _ in enumerate(cortical_ext_split):
            cortex_centroid[i][:-1] = mesher.polygon_tensor_of_inertia(
                cortical_ext_split[i],
                cortical_int_sanity_split[i],
                true_coi=False,
            )
            # Last component is taken from the last column of the slice's first point.
            cortex_centroid[i][-1] = cortical_ext_split[i][0, -1]

            (
                cortical_ext_centroid[i],
                idx_list_ext[i],
                intersections_ext_s,
            ) = mesher.insert_tensor_of_inertia(
                cortical_ext_split[i], cortex_centroid[i][:-1]
            )
            intersections_ext[i] = intersections_ext_s

            (
                cortical_int_centroid[i],
                idx_list_int[i],
                intersections_int_s,
            ) = mesher.insert_tensor_of_inertia(
                cortical_int_sanity_split[i], cortex_centroid[i][:-1]
            )
            intersections_int[i] = intersections_int_s

        cortical_ext_msh = np.reshape(cortical_ext_centroid, (-1, 3))
        cortical_int_msh = np.reshape(cortical_int_centroid, (-1, 3))

        (
            indices_coi_ext,
            cortical_ext_bspline,
            intersection_line_tags_ext,
            cortical_ext_surfs,
        ) = mesher.gmsh_geometry_formulation(cortical_ext_msh, idx_list_ext)

        (
            indices_coi_int,
            cortical_int_bspline,
            intersection_line_tags_int,
            cortical_int_surfs,
        ) = mesher.gmsh_geometry_formulation(cortical_int_msh, idx_list_int)

        intersurface_line_tags = mesher.add_interslice_segments(
            indices_coi_ext, indices_coi_int
        )
        slices_tags = mesher.add_slice_surfaces(
            cortical_ext_bspline, cortical_int_bspline, intersurface_line_tags
        )
        intersurface_surface_tags = mesher.add_intersurface_planes(
            intersurface_line_tags,
            intersection_line_tags_ext,
            intersection_line_tags_int,
        )

        intersection_line_tags = np.append(
            intersection_line_tags_ext, intersection_line_tags_int
        )
        cortical_bspline_tags = np.append(cortical_ext_bspline, cortical_int_bspline)
        cort_surfs = np.concatenate(
            (
                cortical_ext_surfs,
                cortical_int_surfs,
                slices_tags,
                intersurface_surface_tags,
            ),
            axis=None,
        )

        cort_vol_tags = mesher.add_volume(
            cortical_ext_surfs,
            cortical_int_surfs,
            slices_tags,
            intersurface_surface_tags,
        )

        intersurface_line_tags = np.array(intersurface_line_tags, dtype=int).tolist()

        # Trabecular meshing
        trabecular_volume = TrabecularVolume(
            geo_unrolled_filename,
            mesh_file_path,
            slicing_coefficient=cortical_v.SLICING_COEFFICIENT,
            n_longitudinal=N_LONGITUDINAL,
            n_transverse_cort=N_TRANSVERSE_CORT,
            n_transverse_trab=N_TRANSVERSE_TRAB,
            n_radial=N_RADIAL,
            QUAD_REFINEMENT=QUAD_REFINEMENT,
            ellipsoid_fitting=ELLIPSOID_FITTING,
        )
        trabecular_volume.set_length_factor(
            float(self.settings_dict["center_square_length_factor"])
        )

        (
            trab_point_tags,
            trab_line_tags_v,
            trab_line_tags_h,
            trab_surfs_v,
            trab_surfs_h,
            trab_vols,
        ) = trabecular_volume.get_trabecular_vol(coi_idx=intersections_int)

        # connection between inner trabecular and cortical volumes
        trab_cort_line_tags = mesher.trabecular_cortical_connection(
            coi_idx=indices_coi_int, trab_point_tags=trab_point_tags
        )
        trab_slice_surf_tags = mesher.trabecular_slices(
            trab_cort_line_tags=trab_cort_line_tags,
            trab_line_tags_h=trab_line_tags_h,
            cort_int_bspline_tags=cortical_int_bspline,
        )
        trab_plane_inertia_tags = trabecular_volume.trabecular_planes_inertia(
            trab_cort_line_tags, trab_line_tags_v, intersection_line_tags_int
        )
        cort_trab_vol_tags = mesher.get_trabecular_cortical_volume_mesh(
            trab_slice_surf_tags,
            trab_plane_inertia_tags,
            cortical_int_surfs,
            trab_surfs_v,
        )
        cort_volume_tags = np.concatenate(
            (cort_vol_tags, cort_trab_vol_tags), axis=None
        )

        trab_refinement = None
        quadref_vols = None
        if trabecular_volume.QUAD_REFINEMENT:
            user_logger.info(
                "Starting quad refinement procedure (this might take some time)"
            )

            # get coords of trab_point_tags
            # NOTE(review): the collected coordinates are not consumed below;
            # the calls are kept in case get_vertices_coords has side effects.
            coords_vertices = []
            for subset in trab_point_tags:
                coords = trabecular_volume.get_vertices_coords(subset)
                coords_vertices.append(coords)

            trab_refinement = QuadRefinement(
                nb_layers=1,
                DIM=2,
                SQUARE_SIZE_0_MM=1,
                MAX_SUBDIVISIONS=3,
                ELMS_THICKNESS=N_LONGITUDINAL,
            )
            (
                quadref_line_tags_intersurf,
                quadref_surfs,
                quadref_vols,
            ) = trab_refinement.exec_quad_refinement(
                trab_point_tags.tolist()
            )  # , coords_vertices
            user_logger.info("Trabecular refinement done")

        # * meshing
        trab_surfs = list(
            chain(
                trab_surfs_v,
                trab_surfs_h,
                trab_slice_surf_tags,
                trab_plane_inertia_tags,
            )
        )
        trab_lines_longitudinal = trab_line_tags_v
        trab_lines_transverse = trab_cort_line_tags
        trab_lines_radial = trab_line_tags_h

        trabecular_volume.meshing_transfinite(
            trab_lines_longitudinal,
            trab_lines_transverse,
            trab_lines_radial,
            trab_surfs,
            trab_vols,
            phase="trab",
        )

        # * physical groups
        mesher.factory.synchronize()
        trab_vol_tags = np.concatenate((trab_vols, cort_trab_vol_tags), axis=None)

        # The cortical group is registered first, then the trabecular one.
        cort_physical_group = mesher.model.addPhysicalGroup(
            3, cort_vol_tags, name="Cortical_Compartment"
        )

        if quadref_vols is not None:
            trab_vol_tags = np.append(trab_vol_tags, quadref_vols[0])

        trab_physical_group = mesher.model.addPhysicalGroup(
            3, trab_vol_tags, name="Trabecular_Compartment"
        )

        if trab_refinement is not None and quadref_vols is not None:
            # same as saying if QUAD_REFINEMENT is True
            trab_refinement.model.addPhysicalGroup(3, quadref_vols[0])

        print(
            f"Cortical physical group: {cort_physical_group}\nTrabecular physical group: {trab_physical_group}"
        )

        cort_longitudinal_lines = intersection_line_tags
        cort_transverse_lines = intersurface_line_tags

        mesher.meshing_transfinite(
            cort_longitudinal_lines,
            cort_transverse_lines,
            cortical_bspline_tags,
            cort_surfs,
            cort_volume_tags,
            phase="cort",
        )

        mesher.model.geo.mesh.setRecombine(2, -1)
        tot_vol_tags = [cort_vol_tags, trab_vol_tags]
        mesher.mesh_generate(dim=3, element_order=ELM_ORDER, vol_tags=tot_vol_tags)
        mesher.model.mesh.removeDuplicateNodes()
        mesher.model.mesh.removeDuplicateElements()
        mesher.model.occ.synchronize()

        mesher.logger.info("Optimising mesh")
        if ELM_ORDER == 1:
            mesher.model.mesh.optimize(method="UntangleMeshGeometry", force=True)
        elif ELM_ORDER > 1:
            mesher.model.mesh.optimize(method="HighOrderFastCurving", force=False)

        if MESH_ANALYSIS:
            # 999.9 if you want to see all the elements,
            # -0.01 if you want to see only negative Jacobians
            JAC_FULL = 999.9
            mesher.analyse_mesh_quality(hiding_thresh=JAC_FULL)

        nodes = mesher.gmsh_get_nodes()
        elms = mesher.gmsh_get_elms()
        bnds_bot, bnds_top = mesher.gmsh_get_bnds(nodes)
        reference_point_coord = mesher.gmsh_get_reference_point_coord(nodes)

        if SHOW_GMSH:
            gmsh.fltk.run()

        if WRITE_MESH:
            Path(mesher.mesh_file_path).parent.mkdir(parents=True, exist_ok=True)
            gmsh.write(f"{mesher.mesh_file_path}.msh")
            gmsh.write(f"{mesher.mesh_file_path}.inp")
            gmsh.write(f"{mesher.mesh_file_path}.vtk")

        entities_cort = mesher.model.getEntitiesForPhysicalGroup(3, cort_physical_group)
        entities_trab = mesher.model.getEntitiesForPhysicalGroup(3, trab_physical_group)

        centroids_cort = mesher.get_barycenters(tag_s=entities_cort)
        centroids_trab = mesher.get_barycenters(tag_s=entities_trab)

        elm_vol_cort = mesher.get_elm_volume(tag_s=entities_cort)
        elm_vol_trab = mesher.get_elm_volume(tag_s=entities_trab)

        min_elm_vol = min(elm_vol_cort.min(), elm_vol_trab.min())
        if min_elm_vol < 0.0:
            # Reported only; execution continues and the mesh is still returned.
            logger.critical(
                f"Negative element volume detected: {min_elm_vol:.3f} (mm^3), exiting..."
            )
        else:
            logger.info(f"Minimum cortical element volume: {np.min(elm_vol_cort):.3f}")
            logger.info(
                f"Minimum trabecular element volume: {np.min(elm_vol_trab):.3f}"
            )

        # get biggest ROI_radius
        radius_roi_cort = mesher.get_radius_longest_edge(tag_s=entities_cort)
        radius_roi_trab = mesher.get_radius_longest_edge(tag_s=entities_trab)

        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if len(elm_vol_cort) + len(elm_vol_trab) != len(centroids_cort) + len(
            centroids_trab
        ):
            raise AssertionError("The number of volumes and centroids does not match.")

        # i.e., if cort_physical_group is 1 and trab_physical_group is 2:
        if cort_physical_group < trab_physical_group:
            centroids_c = np.concatenate((centroids_cort, centroids_trab))
        else:
            centroids_c = np.concatenate((centroids_trab, centroids_cort))

        centroids_dict = mesher.get_centroids_dict(centroids_c)
        # split centroids_dict into cort and trab
        centroids_cort_dict, centroids_trab_dict = mesher.split_dict_by_array_len(
            centroids_dict, len(centroids_cort)
        )

        # get number of nodes
        node_tags_cort, _ = mesher.model.mesh.getNodesForPhysicalGroup(
            3, cort_physical_group
        )
        node_tags_trab, _ = mesher.model.mesh.getNodesForPhysicalGroup(
            3, trab_physical_group
        )
        nb_nodes = len(node_tags_cort) + len(node_tags_trab)
        logger.info(f"Number of nodes in model: {nb_nodes}")

        # Dump everything Gmsh logged since ``gmsh.logger.start()``.
        gmsh_log = gmsh.logger.get()
        Path(mesh_file_path).parent.mkdir(parents=True, exist_ok=True)
        with open(f"{mesh_file_path}_gmsh.log", "w") as f:
            for line in gmsh_log:
                f.write(line + "\n")
        gmsh.logger.stop()
        gmsh.finalize()

        end = time.time()
        elapsed = round(end - start, ndigits=1)
        logger.info(f"Elapsed time: {elapsed} (s)")
        logger.info("Meshing script finished.")

        return (
            nodes,
            elms,
            nb_nodes,
            centroids_cort_dict,
            centroids_trab_dict,
            elm_vol_cort,
            elm_vol_trab,
            radius_roi_cort,
            radius_roi_trab,
            bnds_bot,
            bnds_top,
            reference_point_coord,
        )