"""Static, off-screen deformation plots (module ``tensormesh.visualization.static_plot``)."""

import torch
import numpy as np
from .utils import mesh_to_pyvista, setup_headless, pv, HAS_PYVISTA, _PYVISTA_INSTALL_HINT

# Upper bound on the number of load arrows drawn, to keep the figure readable.
_MAX_FORCE_ARROWS = 30


def _to_numpy(data):
    """Return *data* as a NumPy array, detaching torch tensors and moving them to CPU."""
    if isinstance(data, torch.Tensor):
        return data.detach().cpu().numpy()
    return np.asarray(data)


def _add_fixed_nodes(plotter, pv_mesh, fixed_nodes, glyph_scale):
    """Draw a blue cube at every fixed node (boolean mask or integer indices)."""
    selector = np.ravel(_to_numpy(fixed_nodes))
    if selector.dtype != bool:
        # Integer indices. The cast also makes an empty Python list (which
        # np.asarray turns into float64) usable as an index array.
        selector = selector.astype(np.intp, copy=False)
    fixed_pts = pv_mesh.points[selector]
    if len(fixed_pts) == 0:
        return
    cubes = pv.PolyData(fixed_pts).glyph(
        scale=False, geom=pv.Cube(), orient=False, factor=glyph_scale)
    plotter.add_mesh(cubes, color="blue", label="Fixed")


def _add_force_arrows(plotter, pv_mesh, force_vectors, glyph_scale,
                      max_arrows=_MAX_FORCE_ARROWS):
    """Draw red arrows at loaded nodes, subsampled to at most *max_arrows*."""
    vecs = _to_numpy(force_vectors).astype(float, copy=False)
    if vecs.ndim != 2 or vecs.shape[1] not in (2, 3):
        raise ValueError(
            f"force_vectors must have shape [N, 2] or [N, 3], got {vecs.shape}")
    if vecs.shape[0] == 0:
        return
    if vecs.shape[1] == 2:
        # Pad 2-D forces with a zero z-component so glyphs can be oriented in 3-D.
        vecs = np.column_stack([vecs, np.zeros(len(vecs))])

    magnitude = np.linalg.norm(vecs, axis=1)
    # A relative threshold drops round-off noise; an all-zero field yields no arrows.
    load_mask = magnitude > 1e-9 * magnitude.max()
    if not load_mask.any():
        return

    load_pts = pv_mesh.points[load_mask]
    load_vecs = vecs[load_mask]
    if len(load_pts) > max_arrows:
        # Use a local, fixed-seed generator. Images are reproducible and the
        # caller's global NumPy RNG state is not advanced.
        rng = np.random.default_rng(0)
        keep = np.sort(rng.choice(len(load_pts), max_arrows, replace=False))
        load_pts = load_pts[keep]
        load_vecs = load_vecs[keep]

    cloud = pv.PolyData(load_pts)
    cloud["vectors"] = load_vecs
    arrows = cloud.glyph(orient="vectors", scale=False,
                         factor=glyph_scale * 2.0, geom=pv.Arrow())
    plotter.add_mesh(arrows, color="red", label="Load")


def _apply_camera(plotter, camera_position):
    """Orient the camera by name.

    Unrecognised or non-string values keep PyVista's default view.
    """
    if not isinstance(camera_position, str):
        return
    views = {
        'isometric': plotter.view_isometric,
        'xy': plotter.view_xy,
        'xz': plotter.view_xz,
        'yz': plotter.view_yz,
    }
    view = views.get(camera_position)
    if view is not None:
        view()


def plot_deformation(mesh, displacement: torch.Tensor, file_name: str,
                     scale_factor: float = 1.0,
                     camera_position='isometric',
                     fixed_nodes=None, force_vectors=None,
                     linearize: bool = True):
    """
    Save a static comparison plot of undeformed (wireframe) vs deformed (solid)
    mesh, optionally showing boundary conditions (fixed nodes and force vectors).

    Parameters
    ----------
    mesh : tensormesh.Mesh
        Mesh to render; converted with ``mesh_to_pyvista``.
    displacement : torch.Tensor or array-like
        Nodal displacement field. It is attached as the ``"displacement"`` point
        array, used to warp the mesh, and used to color the deformed surface.
    file_name : str
        Output filename (e.g. 'result.png').
    scale_factor : float
        Scale factor for deformation. Default 1.0.
    camera_position : str
        'isometric', 'xy', 'xz' or 'yz'. Any other value leaves PyVista's
        default camera orientation unchanged.
    fixed_nodes : torch.Tensor/ndarray/sequence, optional
        Boolean mask or integer indices of fixed nodes, drawn as blue cubes.
    force_vectors : torch.Tensor/ndarray, optional
        Force vectors at nodes, shape [N, 3]. A shape of [N, 2] is padded with
        a zero z-component. Drawn as red arrows, at most 30 of them.
    linearize : bool
        If True, convert high-order elements to linear elements for robust
        visualization. Default True.

    Raises
    ------
    ImportError
        If PyVista is not installed.
    ValueError
        If ``force_vectors`` is not a 2-D array with 2 or 3 columns.
    """
    if not HAS_PYVISTA:
        raise ImportError(_PYVISTA_INSTALL_HINT)
    setup_headless()

    u = _to_numpy(displacement)
    pv_mesh = mesh_to_pyvista(mesh, point_data={"displacement": u}, linearize=linearize)

    # Glyph size is 2% of the bounding-box diagonal, so markers look the same
    # regardless of the mesh's physical units.
    xmin, xmax, ymin, ymax, zmin, zmax = pv_mesh.bounds
    diag = np.linalg.norm([xmax - xmin, ymax - ymin, zmax - zmin])
    glyph_scale = diag * 0.02

    plotter = pv.Plotter(off_screen=True, window_size=[1024, 768])
    try:
        # 1. Undeformed mesh: grey, semi-transparent wireframe.
        plotter.add_mesh(pv_mesh, style='wireframe', color='grey', opacity=0.3,
                         label='Original')

        # 2. Deformed mesh: solid, colored by the displacement point array.
        warped = pv_mesh.warp_by_vector(vectors="displacement", factor=scale_factor)
        plotter.add_mesh(warped, scalars="displacement", cmap="jet", show_edges=True,
                         edge_color="black", label='Deformed', smooth_shading=True,
                         split_sharp_edges=True)

        # 3. Boundary-condition markers.
        if fixed_nodes is not None:
            _add_fixed_nodes(plotter, pv_mesh, fixed_nodes, glyph_scale)
        if force_vectors is not None:
            _add_force_arrows(plotter, pv_mesh, force_vectors, glyph_scale)

        # ':g' keeps small factors readable (0.01 -> "0.01x", not "0.0x").
        plotter.add_text(f"Deformation Scale: {scale_factor:g}x", position='upper_left')
        plotter.add_axes()
        plotter.add_legend()

        _apply_camera(plotter, camera_position)
        plotter.reset_camera()
        plotter.camera.zoom(1.2)

        plotter.screenshot(file_name)
    finally:
        # Release the render window even if rendering or saving fails.
        plotter.close()
    print(f"Comparison plot saved to {file_name}")