# Source code for frontend._plot_

# File: _plot_.py
# Author: Ryoichi Ando (ryoichi.ando@zozo.com)
# License: Apache v2.0

from typing import Optional

import numpy as np
import pythreejs as p3s
from IPython.display import display
from meshplot import plot
from meshplot.Viewer import Viewer

from frontend._utils_ import Utils

"""Default shading settings for light mode."""
# Fallback shading values used when dark mode is off: smooth (non-flat)
# shading with a visible wireframe, and black lines and points.
LIGHT_DEFAULT_SHADING = dict(
    flat=False,
    wireframe=True,
    line_width=1.0,
    line_color="black",
    point_color="black",
)

"""Default shading settings for dark mode."""
# Fallback shading values used when dark mode is on: same geometry settings
# as the light scheme, plus a dark background with white lines and points.
DARK_DEFAULT_SHADING = dict(
    flat=False,
    wireframe=True,
    line_width=1.0,
    background="#222222",
    line_color="white",
    point_color="white",
)


class PlotManager:
    """Factory for :class:`Plot` objects sharing one dark/light setting.

    The manager starts in dark mode. Whether the code runs inside a Jupyter
    notebook is queried once, at construction time, and then cached.
    """

    def __init__(self) -> None:
        """Initialize the plot manager with dark mode turned on."""
        self._darkmode = True
        self._in_jupyter_notebook = Utils.in_jupyter_notebook()

    def darkmode(self, darkmode: bool) -> None:
        """Turn dark mode on or off for plots created afterwards.

        Args:
            darkmode (bool): True to turn on dark mode, False otherwise.
        """
        self._darkmode = darkmode

    def create(self) -> "Plot":
        """Create a new plot using the current dark mode setting.

        Returns:
            Plot: A fresh plot object.
        """
        return Plot(self._darkmode)

    def is_jupyter_notebook(self) -> bool:
        """Return the notebook flag that was cached at construction time.

        Returns:
            bool: True if the manager was created inside a Jupyter notebook.
        """
        return self._in_jupyter_notebook


class PlotAdder:
    """Add extra geometry to a plot that already has a viewer.

    Every method is a no-op outside a Jupyter notebook and always returns
    the parent plot so calls can be chained.
    """

    def __init__(self, parent: "Plot") -> None:
        """Initialize the plot adder.

        Args:
            parent (Plot): The plot whose viewer and shading are used.
        """
        self._parent = parent
        self._in_jupyter_notebook = Utils.in_jupyter_notebook()

    def _require_viewer(self):
        """Return the parent's viewer, raising if no plot was created yet.

        Raises:
            RuntimeError: If the parent plot has no viewer.
        """
        viewer = self._parent._viewer
        if viewer is None:
            raise RuntimeError("No plot to add to")
        return viewer

    def tri(self, vert: np.ndarray, tri: np.ndarray, color: np.ndarray) -> "Plot":
        """Add a triangle mesh to the plot.

        Args:
            vert (np.ndarray): The vertices (#x3) of the mesh.
            tri (np.ndarray): The triangle elements (#x3) of the mesh.
            color (np.ndarray): The color (#x3) of the mesh. Each value
                should be in [0,1].

        Returns:
            Plot: The plot object.

        Raises:
            RuntimeError: If running in a notebook and the plot has no viewer.
        """
        if self._in_jupyter_notebook:
            viewer = self._require_viewer()
            viewer.add_mesh(vert, tri, color, shading=self._parent._shading)
        return self._parent

    def edge(self, vert: np.ndarray, edge: np.ndarray) -> "Plot":
        """Add edges to the plot.

        Args:
            vert (np.ndarray): The vertices (#x3) of the edges.
            edge (np.ndarray): The edge elements (#x2) of the edges.

        Returns:
            Plot: The plot object.

        Raises:
            RuntimeError: If running in a notebook and the plot has no viewer.
        """
        if self._in_jupyter_notebook:
            viewer = self._require_viewer()
            shading = self._parent._shading
            # np.array() copies by default, so the caller's arrays are never
            # shared with the widget buffers.
            position = np.array(vert, dtype=np.float32, order="C")
            index = np.array(edge, dtype=np.uint32, order="C").ravel()
            geometry = p3s.BufferGeometry(
                attributes={
                    "position": p3s.BufferAttribute(position, normalized=False),
                    "index": p3s.BufferAttribute(index, normalized=False),
                }
            )
            material = p3s.LineBasicMaterial(
                linewidth=shading["line_width"], color=shading["line_color"]
            )
            # NOTE(review): pythreejs also offers LineSegments for rendering
            # independent index pairs; confirm Line is the intended primitive.
            line = p3s.Line(geometry=geometry, material=material)
            obj = {
                "geometry": geometry,
                "mesh": line,
                "material": material,
                "max": np.max(position, axis=0),
                "min": np.min(position, axis=0),
                "type": "Lines",
                "wireframe": None,
            }
            # The viewer exposes no public hook for a prebuilt object, so the
            # name-mangled private method is called directly.
            viewer._Viewer__add_object(obj)  # type: ignore
        return self._parent

    def point(self, vert: np.ndarray) -> "Plot":
        """Add points to the plot.

        Args:
            vert (np.ndarray): The vertices (#x3) of the points.

        Returns:
            Plot: The plot object.

        Raises:
            RuntimeError: If running in a notebook and the plot has no viewer.
        """
        if self._in_jupyter_notebook:
            viewer = self._require_viewer()
            viewer.add_points(vert, shading=self._parent._shading)
        return self._parent


class Plot:
    """A single plot backed by a meshplot viewer.

    Outside a Jupyter notebook every plotting method is a no-op that simply
    returns the plot, so the same script can run headless.
    """

    def __init__(self, _darkmode: bool):
        """Initialize the plot.

        Args:
            _darkmode (bool): True to turn on dark mode, False otherwise.
        """
        self._in_jupyter_notebook = Utils.in_jupyter_notebook()
        self._darkmode = _darkmode
        self._viewer = None
        self._shading = {}
        self.add = PlotAdder(self)

    def is_jupyter_notebook(self) -> bool:
        """Return the notebook flag that was cached at construction time."""
        return self._in_jupyter_notebook

    def to_html(self, path: str = "") -> None:
        """Export an HTML file with the plot.

        Args:
            path (str): The filename to save the HTML file.

        Raises:
            RuntimeError: If running in a notebook and nothing was plotted.
        """
        if self._in_jupyter_notebook:
            if self._viewer is None:
                raise RuntimeError("No plot to save")
            self._viewer.save(path)

    def has_view(self) -> bool:
        """Return whether the plot currently has a viewer."""
        return self._viewer is not None

    def overwrite_shading(self, shading: dict) -> dict:
        """Fill missing shading keys with the defaults of the current mode.

        Keys already present in ``shading`` are kept. The dictionary is
        modified in place and also returned.

        Args:
            shading (dict): The shading settings to complete.

        Returns:
            dict: The same dictionary, with missing keys filled in.
        """
        defaults = DARK_DEFAULT_SHADING if self._darkmode else LIGHT_DEFAULT_SHADING
        for key, value in defaults.items():
            shading.setdefault(key, value)
        return shading

    def _resolve_shading(self, shading: Optional[dict]) -> dict:
        """Return a private, default-completed copy of ``shading``."""
        # Working on a copy keeps the defaults from leaking into the caller's
        # dictionary or into later calls.
        return self.overwrite_shading({} if shading is None else dict(shading))

    @staticmethod
    def _stitch_segments(
        vert: np.ndarray, ind: np.ndarray, w: np.ndarray
    ) -> tuple[np.ndarray, np.ndarray]:
        """Build one line segment per stitch.

        Each segment starts at ``vert[ind[i, 0]]`` and ends at
        ``w[i, 0] * vert[ind[i, 1]] + w[i, 1] * vert[ind[i, 2]]``.
        """
        # Pair rows up to the shorter of the two inputs.
        count = min(len(ind), len(w))
        ind = np.asarray(ind)[:count]
        w = np.asarray(w)[:count]
        start = vert[ind[:, 0]]
        end = w[:, 0:1] * vert[ind[:, 1]] + w[:, 1:2] * vert[ind[:, 2]]
        stitch_vert = np.empty((2 * count, vert.shape[1]), dtype=end.dtype)
        stitch_vert[0::2] = start
        stitch_vert[1::2] = end
        stitch_edge = np.arange(2 * count).reshape(-1, 2)
        return stitch_vert, stitch_edge

    @staticmethod
    def _cut_surface(
        vert: np.ndarray, tet: np.ndarray, axis: int, cut: float
    ) -> np.ndarray:
        """Return the boundary triangles (#x3) of the tets beyond the cut.

        A tet is kept when its centroid coordinate along ``axis`` is greater
        than ``min + cut * (max - min)``. Faces shared by two kept tets cancel
        out, leaving only the outer surface.
        """
        coord = vert[:, axis]
        lo, hi = np.min(coord), np.max(coord)
        threshold = lo + cut * (hi - lo)
        centroid = (
            coord[tet[:, 0]] + coord[tet[:, 1]] + coord[tet[:, 2]] + coord[tet[:, 3]]
        ) / 4
        local_faces = ((0, 1, 2), (0, 2, 3), (0, 1, 3), (1, 2, 3))
        surface = {}
        for cell in tet[centroid > threshold].tolist():
            for corners in local_faces:
                face = [cell[i] for i in corners]
                # A sorted tuple is an orientation-independent key that cannot
                # overflow, unlike packing the indices into one integer.
                key = tuple(sorted(face))
                if key in surface:
                    del surface[key]
                else:
                    surface[key] = face
        # reshape keeps the (#x3) layout even when no face survives.
        return np.array(list(surface.values()), dtype=tet.dtype).reshape(-1, 3)

    def curve(
        self,
        vert: np.ndarray,
        _edge: np.ndarray = np.zeros(0),
        shading: Optional[dict] = None,
    ) -> "Plot":
        """Plot a curve.

        Args:
            vert (np.ndarray): The vertices (#x2 or #x3) of the curve.
                Two-dimensional vertices get a zero third coordinate.
            _edge (np.ndarray): The edge elements (#x2) of the curve. When
                empty, consecutive vertices are connected and the last vertex
                is joined back to the first.
            shading (dict): The shading settings. Missing keys are filled
                with the defaults of the current mode.

        Returns:
            Plot: The plot object.
        """
        if not self._in_jupyter_notebook:
            return self
        shading = self._resolve_shading(shading)
        if _edge.size == 0:
            count = len(vert)
            idx = np.arange(count)
            edge = np.stack([idx, (idx + 1) % count], axis=1)
        else:
            edge = _edge
        if vert.shape[1] == 2:
            _pts = np.concatenate([vert, np.zeros((vert.shape[0], 1))], axis=1)
        else:
            _pts = vert
        viewer = Viewer(shading)
        viewer.reset()
        # The viewer and shading must be set before add.edge reads them.
        self._viewer = viewer
        self._shading = shading
        self.add.edge(_pts, edge)
        display(self._viewer._renderer)
        return self

    def tri(
        self,
        vert: np.ndarray,
        tri: np.ndarray,
        stitch: tuple[np.ndarray, np.ndarray] = (np.zeros(0), np.zeros(0)),
        color: np.ndarray = np.array([1.0, 0.85, 0.0]),
        shading: Optional[dict] = None,
    ) -> "Plot":
        """Plot a triangle mesh.

        Args:
            vert (np.ndarray): The vertices (#x3) of the mesh.
            tri (np.ndarray): The triangle elements (#x3) of the mesh.
            stitch (tuple[np.ndarray, np.ndarray]): The stitch data
                (index #x3 and weight #x2). Drawn as extra edges when both
                arrays are non-empty.
            color (np.ndarray): The color (#x3) of the mesh. Each value
                should be in [0,1].
            shading (dict): The shading settings. Missing keys are filled
                with the defaults of the current mode.

        Returns:
            Plot: The plot object.

        Raises:
            ValueError: If ``tri`` does not have three columns.
        """
        if not self._in_jupyter_notebook:
            return self
        if tri.shape[1] != 3:
            raise ValueError("triangles must have 3 vertices")
        shading = self._resolve_shading(shading)
        self._viewer = plot(vert, tri, color, shading=shading)
        self._shading = shading
        assert isinstance(self._viewer, Viewer)
        ind, w = stitch
        if len(ind) and len(w):
            stitch_vert, stitch_edge = self._stitch_segments(vert, ind, w)
            self._viewer.add_edges(stitch_vert, stitch_edge, shading=shading)
        return self

    def tet(
        self,
        vert: np.ndarray,
        tet: np.ndarray,
        axis: int = 0,
        cut: float = 0.5,
        color: np.ndarray = np.array([1.0, 0.85, 0.0]),
        shading: Optional[dict] = None,
    ) -> "Plot":
        """Plot a tetrahedral mesh as the surface of its cut-away part.

        Args:
            vert (np.ndarray): The vertices (#x3) of the mesh.
            tet (np.ndarray): The tetrahedral elements (#x4) of the mesh.
            axis (int): The axis to cut the mesh.
            cut (float): The cut ratio. Only tets whose centroid lies beyond
                this fraction of the extent along ``axis`` are shown.
            color (np.ndarray): The color (#x3) of the mesh. Each value
                should be in [0,1].
            shading (dict): The shading settings.

        Returns:
            Plot: The plot object.

        Raises:
            ValueError: If ``vert`` does not have three columns or ``tet``
                does not have four columns.
        """
        if not self._in_jupyter_notebook:
            return self
        if vert.shape[1] != 3:
            raise ValueError("vertices must have 3 coordinates")
        if tet.shape[1] != 4:
            raise ValueError("tetrahedra must have 4 vertices")
        faces = self._cut_surface(vert, tet, axis, cut)
        return self.tri(vert, faces, color=color, shading=shading)

    def update(self, vert: np.ndarray) -> "Plot":
        """Update the plot with new vertices.

        Only the first object registered in the viewer is updated.

        Args:
            vert (np.ndarray): The new vertices (#x3).

        Returns:
            Plot: The plot object.

        Raises:
            RuntimeError: If running in a notebook and nothing was plotted.
        """
        if self._in_jupyter_notebook:
            if self._viewer is None:
                raise RuntimeError("No plot to update")
            # The viewer keeps its objects in a name-mangled private member.
            objects = self._viewer._Viewer__objects  # type: ignore
            position = np.array(vert, dtype=np.float32, order="C")
            geo = objects[0]["geometry"]
            geo.attributes["position"].array = position
            geo.attributes["position"].needsUpdate = True
            # Normals are recomputed to match the shading mode in use.
            if self._shading["flat"]:
                geo.exec_three_obj_method("computeFaceNormals")
            else:
                geo.exec_three_obj_method("computeVertexNormals")
        return self