# Source code for tensormesh.visualization.draw_graph

import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection, LineCollection
from matplotlib.patches import Polygon, Arc
from matplotlib import patches
from typing import Optional, Union
from mpl_toolkits.mplot3d.art3d import Line3DCollection
from mpl_toolkits.mplot3d import Axes3D  # Add this import
from ..sparse import SparseMatrix
from .utils import ScipySparseMatrix, as_ndarray, as_sparse_matrix, dim

def _prepare_axis(ax, is_3d):
    """Return an axis whose projection matches the dimensionality of the points."""
    if ax is None:
        if is_3d:
            fig = plt.figure(figsize=(10, 10))
            ax = fig.add_subplot(111, projection='3d')
            ax.set_box_aspect([1, 1, 1])
        else:
            _, ax = plt.subplots(figsize=(10, 10))
            ax.set_aspect('equal')
        return ax

    if (ax.name == "3d") == is_3d:
        # The caller's axis already has the right projection.
        return ax

    # Projection mismatch: swap the caller's axis for one of the right kind,
    # placed at the same position in the same figure.
    fig = ax.figure
    pos = ax.get_position()
    ax.remove()
    if is_3d:
        ax = fig.add_axes(pos, projection='3d')
        ax.set_box_aspect([1, 1, 1])
    else:
        ax = fig.add_axes(pos)
        ax.set_aspect('equal')
    return ax


def _self_loop_3d(center, radius=0.01, n_samples=50):
    """Return ``n_samples`` points of a randomly oriented circle drawn next to ``center``."""
    t = np.linspace(0, 2 * np.pi, n_samples)

    # Random unit normal of the plane that contains the loop.
    normal = np.random.randn(3)
    normal = normal / np.linalg.norm(normal)

    # Seed vector for the in-plane basis. It must not be (anti-)parallel to
    # the normal, otherwise its projection onto the plane is ~0 and the
    # normalisation below divides by ~0. Checking the absolute value covers
    # both the +x and the -x direction.
    if np.allclose(np.abs(normal), [1, 0, 0]):
        seed = np.array([0, 1, 0])
    else:
        seed = np.array([1, 0, 0])
    in_plane_u = seed - np.dot(seed, normal) * normal
    in_plane_u = in_plane_u / np.linalg.norm(in_plane_u)
    in_plane_v = np.cross(normal, in_plane_u)

    # The circle centre is shifted along the normal by half the radius.
    center_offset = normal * radius * 0.5
    return (center + center_offset
            + radius * (np.cos(t)[:, np.newaxis] * in_plane_u
                        + np.sin(t)[:, np.newaxis] * in_plane_v))


def draw_graph(sparse_matrix: Union[SparseMatrix, ScipySparseMatrix],
               points: Union[torch.Tensor, np.ndarray],
               draw_points: bool = True,
               point_color: str = 'orange',
               color: str = "blue",
               linewidth: int = 3,
               alpha: float = 0.5,
               ax: Optional[Union[plt.Axes, Axes3D]] = None
               ) -> Union[plt.Axes, Axes3D]:
    """
    Draw the graph described by a sparse matrix on top of a set of points.

    Every off-diagonal entry ``(row, col)`` is drawn as a straight segment
    between ``points[row]`` and ``points[col]``; every diagonal entry is
    drawn as a small circle (self-loop) at the corresponding point.

    Parameters
    ----------
    sparse_matrix: Union[tensormesh.sparse.SparseMatrix, ScipySparseMatrix]
        the sparse matrix of shape [n_points, n_points]
    points: Union[torch.Tensor, np.ndarray]
        2D tensor of shape [n_points, 2] or [n_points, 3]
        the points of the mesh
    draw_points: bool, optional
        whether to scatter the points on top of the edges, default is True
    point_color: str, optional
        the color of the points, default is 'orange'
    color: str, optional
        the color of the edge, default is "blue"
    linewidth: int, optional
        the width of the edge, default is 3
    alpha: float, optional
        the opacity of the edges and self-loops, default is 0.5
    ax: matplotlib.axes.Axes, optional
        the axis, default is None (a new 10x10 figure is created).
        If the projection of ``ax`` (2D / 3D) does not match the points,
        ``ax`` is removed from its figure and replaced by a new axis at the
        same position; always use the returned axis.

    Returns
    -------
    ax: matplotlib.axes.Axes
        the axis that was drawn on

    Raises
    ------
    AssertionError
        if ``points`` is not 2-dimensional, does not have 2 or 3 columns,
        or ``sparse_matrix`` is not of shape [n_points, n_points]

    Notes
    -----
    The orientation of 3D self-loops is drawn from ``np.random``, so 3D
    output is not deterministic unless the NumPy global seed is fixed.
    """
    # assertion
    assert dim(points) == 2, f"points.dim() must be 2, but got {dim(points)}"
    n_points = points.shape[0]
    assert sparse_matrix.shape == (n_points, n_points), f"sparse_matrix.shape must be ({n_points}, {n_points}), but got {sparse_matrix.shape}"
    n_dims = points.shape[1]
    # Only 2D and 3D points can be rendered; fail early instead of erroring
    # somewhere inside matplotlib.
    assert n_dims in (2, 3), f"points.shape[1] must be 2 or 3, but got {n_dims}"

    # input prepare
    points_np = as_ndarray(points)
    sparse_matrix = as_sparse_matrix(sparse_matrix)
    is_3d = n_dims == 3
    ax = _prepare_axis(ax, is_3d)

    # Split the entries into proper edges and self-loops (row == col).
    # NOTE(review): `edges` is indexed as a [2, n_entries] tensor here --
    # confirm against SparseMatrix.
    where_self_loop = sparse_matrix.row == sparse_matrix.col
    edges = sparse_matrix.edges[:, ~where_self_loop]
    self_loops = sparse_matrix.edges[:, where_self_loop]
    edges_np = edges.detach().cpu().numpy()
    self_loops_np = self_loops.detach().cpu().numpy()
    # Both rows of a self-loop entry are equal, so the first row is enough.
    loop_vertices = self_loops_np[0]

    # points_np[edges_np.T] has shape [n_edges, 2, n_dims]: one segment per edge.
    segments = points_np[edges_np.T]
    if is_3d:
        lines = Line3DCollection(segments, color=color, linewidth=linewidth, alpha=alpha)
        arcs = Line3DCollection(
            [_self_loop_3d(points_np[i], radius=0.01) for i in loop_vertices],
            color=color, linewidth=linewidth, alpha=alpha,
        )
    else:
        lines = LineCollection(segments, color=color, linewidth=linewidth, alpha=alpha)
        arcs = PatchCollection(
            [
                Arc(xy=(points_np[i][0], points_np[i][1]),
                    width=0.01, height=0.01,
                    angle=0, theta1=0, theta2=360,
                    color=color, alpha=alpha, linewidth=linewidth)
                for i in loop_vertices
            ],
            # Keep the per-patch color/alpha/linewidth set on each Arc.
            match_original=True,
        )

    ax.add_collection(lines)
    ax.add_collection(arcs)

    if draw_points:
        if is_3d:
            ax.scatter(points_np[:, 0], points_np[:, 1], points_np[:, 2], c=point_color)
        else:
            ax.scatter(points_np[:, 0], points_np[:, 1], c=point_color)

    return ax