# Source code for frontend._plot_

# File: _plot_.py
# Author: Ryoichi Ando (ryoichi.ando@zozo.com)
# License: Apache v2.0

import copy
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
import pythreejs as p3s
from IPython.display import display

from frontend._utils_ import Utils

from ._render_ import OpenGLRenderer


class PlotManager:
    """Factory that hands out :class:`Plot` objects sharing one ``PlotParam``."""

    def __init__(self) -> None:
        """Detect the notebook environment and create default plot parameters."""
        in_notebook = Utils.in_jupyter_notebook()
        self._in_jupyter_notebook = in_notebook
        self.param = PlotParam()

    def create(self, engine: str = "threejs") -> "Plot":
        """Create a new plot.

        Args:
            engine (str): Rendering backend, ``"threejs"`` or ``"opengl"``.

        Returns:
            Plot: A plot that receives this manager's ``param`` object.
        """
        shared_param = self.param
        new_plot = Plot(engine, shared_param)
        return new_plot

    def is_jupyter_notebook(self) -> bool:
        """Return whether the code is running inside a Jupyter notebook."""
        result = self._in_jupyter_notebook
        return result
class Plot:
    """A single plot bound to a rendering engine.

    Every drawing method is a no-op (it just returns ``self``) when the code
    is not running inside a Jupyter notebook.
    """

    def __init__(self, engine: str, param: "PlotParam") -> None:
        """Initialize the plot.

        Args:
            engine (str): Rendering backend, ``"threejs"`` or ``"opengl"``.
            param (PlotParam): Default plot parameters. Individual calls may
                override fields through ``param_override``.

        Raises:
            ValueError: If ``engine`` is not a known backend.
        """
        self._in_jupyter_notebook = Utils.in_jupyter_notebook()
        if engine == "threejs":
            self._engine = ThreejsPlotEngine()
        elif engine == "opengl":
            self._engine = OpenGLRenderEngine()
        else:
            raise ValueError(f"Unknown engine: {engine}")
        self._vert = np.zeros(0)
        self._color = np.zeros(0)
        self.param = param

    def is_jupyter_notebook(self) -> bool:
        """Check if the code is running in a Jupyter notebook."""
        return self._in_jupyter_notebook

    def plot(
        self,
        vert: np.ndarray,
        color: np.ndarray = np.zeros(0),
        tri: np.ndarray = np.zeros(0),
        seg: np.ndarray = np.zeros(0),
        pts: np.ndarray = np.zeros(0),
        param_override: Optional[dict] = None,
    ) -> "Plot":
        """Plot a mesh.

        Args:
            vert (np.ndarray): The vertices (#x3) of the mesh.
            color (np.ndarray): The color (#x3) of the mesh. Each value should
                be in [0,1]. When empty, the engine uses ``param.default_color``.
            tri (np.ndarray): The triangle elements (#x3) of the mesh.
            seg (np.ndarray): The edge elements (#x2) of the mesh.
            pts (np.ndarray): The point elements (vertex indices) of the mesh.
            param_override (dict, optional): ``PlotParam`` field values that
                are applied to a copy of ``self.param`` for this call only.

        Returns:
            Plot: The plot object.
        """
        if self._in_jupyter_notebook:
            # Work on a copy so overrides never leak into the shared parameters.
            param = copy.deepcopy(self.param)
            for key, value in (param_override or {}).items():
                setattr(param, key, value)
            # Keep private copies; update() writes into them in place.
            self._vert = vert.copy()
            self._color = color.copy()
            self._engine.plot(self._vert, self._color, tri, seg, pts, param)
        return self

    def update(
        self, vert: Optional[np.ndarray] = None, color: Optional[np.ndarray] = None
    ):
        """Update vertex positions and/or colors of the current plot.

        Only the leading ``len(vert)`` / ``len(color)`` rows are overwritten.
        Any extra vertices that ``tri`` appended for stitches keep their
        previous values.

        Args:
            vert (np.ndarray, optional): New positions for the first vertices.
            color (np.ndarray, optional): New colors for the first vertices.
        """
        if not self._in_jupyter_notebook:
            # plot() does nothing outside a notebook, so there is nothing to update.
            return
        if vert is not None:
            self._vert[0 : len(vert)] = vert
            vert = self._vert
        if color is not None:
            self._color[0 : len(color)] = color
            color = self._color
        self._engine.update(vert, color)

    def tri(
        self,
        vert: np.ndarray,
        tri: np.ndarray,
        stitch: tuple[np.ndarray, np.ndarray] = (np.zeros(0), np.zeros(0)),
        color: np.ndarray = np.zeros(0),
        param_override: Optional[dict] = None,
    ) -> "Plot":
        """Plot a triangle mesh.

        Each stitch adds one segment from vertex ``ind[0]`` to the blend
        ``w[0] * vert[ind[1]] + w[1] * vert[ind[2]]``. Both endpoints are
        appended as new vertices. When ``color`` is given, the new endpoints
        get colors the same way: a copy of ``color[ind[0]]`` and the weighted
        blend of ``color[ind[1]]`` and ``color[ind[2]]``.

        Args:
            vert (np.ndarray): The vertices (#x3, or #x2 for z = 0) of the mesh.
            tri (np.ndarray): The triangle elements (#x3) of the mesh.
            stitch (tuple[np.ndarray, np.ndarray]): The stitch data
                (index #x3 and weight #x2).
            color (np.ndarray): The color (#x3) of the mesh. Each value should
                be in [0,1].
            param_override (dict, optional): The parameter override.

        Returns:
            Plot: The plot object.

        Raises:
            ValueError: If ``tri`` does not have 3 columns.
        """
        if self._in_jupyter_notebook:
            if tri.shape[1] != 3:
                raise ValueError("triangles must have 3 vertices")
            if vert.shape[1] == 2:
                # Lift 2D input onto the z = 0 plane.
                vert = np.concatenate(
                    [vert, np.zeros((vert.shape[0], 1), dtype=np.uint32)], axis=1
                )
            else:
                vert = vert.copy()
            stitch_ind, stitch_w = stitch
            if len(stitch_ind) and len(stitch_w):
                has_color = len(color) > 0
                new_vert = []
                new_color = []
                edge = []
                for idx, weight in zip(stitch_ind, stitch_w):
                    w0, w1 = weight[0], weight[1]
                    base = len(vert) + len(new_vert)
                    new_vert.append(vert[idx[0]])
                    new_vert.append(w0 * vert[idx[1]] + w1 * vert[idx[2]])
                    edge.append([base, base + 1])
                    if has_color:
                        new_color.append(color[idx[0]])
                        new_color.append(w0 * color[idx[1]] + w1 * color[idx[2]])
                vert = np.vstack([vert, np.array(new_vert)])
                if has_color:
                    # Keep color aligned with the extended vertex array; the
                    # engines assert len(color) == len(vert).
                    color = np.vstack([color, np.array(new_color)])
                edge = np.array(edge)
            else:
                edge = np.zeros(0)
            self.plot(vert, color, tri, edge, np.zeros(0), param_override)
        return self

    def edge(
        self,
        vert: np.ndarray,
        edge: np.ndarray,
        color: np.ndarray,
        param_override: Optional[dict] = None,
    ) -> "Plot":
        """Add edges to the plot.

        Args:
            vert (np.ndarray): The vertices (#x3) of the edges.
            edge (np.ndarray): The edge elements (#x2) of the edges.
            color (np.ndarray): The color (#x3) of the edges. Each value should
                be in [0,1].
            param_override (dict, optional): The parameter override.

        Returns:
            Plot: The plot object.
        """
        if self._in_jupyter_notebook:
            self.plot(vert, color, np.zeros(0), edge, np.zeros(0), param_override)
        return self

    def point(self, vert: np.ndarray, param_override: Optional[dict] = None) -> "Plot":
        """Add points to the plot.

        Args:
            vert (np.ndarray): The vertices (#x3) of the points.
            param_override (dict, optional): The parameter override.

        Returns:
            Plot: The plot object.
        """
        if self._in_jupyter_notebook:
            self.plot(
                vert,
                np.zeros(0),
                np.zeros(0),
                np.zeros(0),
                np.arange(len(vert)),
                param_override,
            )
        return self

    def curve(
        self,
        vert: np.ndarray,
        _edge: np.ndarray = np.zeros(0),
        color: np.ndarray = np.zeros(0),
        param_override: Optional[dict] = None,
    ) -> "Plot":
        """Plot a curve.

        Args:
            vert (np.ndarray): The vertices (#x3, or #x2 for z = 0) of the curve.
            _edge (np.ndarray): The edge elements (#x2) of the curve. When
                empty, consecutive vertices are joined and the last vertex is
                connected back to the first (a closed loop).
            color (np.ndarray): The color (#x3) of the curve. Each value should
                be in [0,1].
            param_override (dict, optional): The parameter override.

        Returns:
            Plot: The plot object.
        """
        if self._in_jupyter_notebook:
            if _edge.size == 0:
                idx = np.arange(len(vert), dtype=np.uint32)
                # Pair i with (i + 1) % n, closing the loop.
                edge = np.column_stack([idx, np.roll(idx, -1)])
            else:
                edge = _edge
            if vert.shape[1] == 2:
                _pts = np.concatenate(
                    [vert, np.zeros((vert.shape[0], 1), dtype=np.uint32)], axis=1
                )
            else:
                _pts = vert
            self.edge(_pts, edge, color, param_override)
        return self

    def tet(
        self,
        vert: np.ndarray,
        tet: np.ndarray,
        axis: int = 0,
        cut: float = 0.5,
        color: np.ndarray = np.zeros(0),
        param_override: Optional[dict] = None,
    ) -> "Plot":
        """Plot the cut-away boundary surface of a tetrahedral mesh.

        A tetrahedron is kept when its centroid's ``axis`` coordinate exceeds
        ``min + cut * (max - min)`` of that coordinate over all vertices. Faces
        that belong to exactly one kept tetrahedron are drawn as triangles.
        Flat shading is enabled unless ``param_override`` says otherwise.

        Args:
            vert (np.ndarray): The vertices (#x3) of the mesh.
            tet (np.ndarray): The tetrahedral elements (#x4) of the mesh.
            axis (int): The axis to cut the mesh.
            cut (float): The cut ratio.
            color (np.ndarray): The color (#x3) of the mesh. Each value should
                be in [0,1].
            param_override (dict, optional): The parameter override. It is not
                modified.

        Returns:
            Plot: The plot object.

        Raises:
            ValueError: If ``vert`` is not #x3 or ``tet`` is not #x4.
        """
        # Copy so that neither the caller's dict nor a default is mutated.
        param_override = dict(param_override or {})
        param_override.setdefault("flat_shading", True)
        if not self._in_jupyter_notebook:
            return self
        if vert.shape[1] != 3:
            raise ValueError("vertices must have 3 coordinates")
        if tet.shape[1] != 4:
            raise ValueError("tetrahedra must have 4 vertices")

        coord = vert[:, axis]
        lo, hi = np.min(coord), np.max(coord)
        threshold = lo + cut * (hi - lo)
        centroid = coord[tet].mean(axis=1)
        kept = tet[centroid > threshold]

        local_faces = ((0, 1, 2), (0, 2, 3), (0, 1, 3), (1, 2, 3))
        # Sorted index tuples identify a face independent of orientation;
        # unlike an integer hash they cannot overflow or collide.
        boundary: dict[tuple[int, ...], list] = {}
        for t in kept:
            for local in local_faces:
                face = [t[i] for i in local]
                key = tuple(sorted(int(i) for i in face))
                if key in boundary:
                    # Shared by two kept tetrahedra: an interior face.
                    del boundary[key]
                else:
                    boundary[key] = face
        # reshape keeps an empty result (everything cut away) as shape (0, 3).
        faces = np.array(list(boundary.values())).reshape(-1, 3)
        return self.tri(vert, faces, color=color, param_override=param_override)
@dataclass
class PlotBuffer:
    """pythreejs buffer attributes backing the current plot."""

    vert: Optional[p3s.BufferAttribute] = None
    tri: Optional[p3s.BufferAttribute] = None
    color: Optional[p3s.BufferAttribute] = None
    pts: Optional[p3s.BufferAttribute] = None
    seg: Optional[p3s.BufferAttribute] = None


@dataclass
class PlotGeometry:
    """pythreejs geometries for triangles, points and segments."""

    tri: Optional[p3s.BufferGeometry] = None
    pts: Optional[p3s.BufferGeometry] = None
    seg: Optional[p3s.BufferGeometry] = None


@dataclass
class PlotObject:
    """pythreejs scene-graph objects created by ``ThreejsPlotEngine.plot``."""

    tri: Optional[p3s.Mesh] = None
    pts: Optional[p3s.Points] = None
    seg: Optional[p3s.LineSegments] = None
    wireframe: Optional[p3s.Mesh] = None
    light_0: Optional[p3s.DirectionalLight] = None
    light_1: Optional[p3s.AmbientLight] = None
    camera: Optional[p3s.PerspectiveCamera] = None
    scene: Optional[p3s.Scene] = None
    renderer: Optional[p3s.Renderer] = None


@dataclass
class PlotParam:
    """Plot parameters; fields may be overridden per call via ``param_override``."""

    direct_intensity: float = 1.0  # DirectionalLight intensity
    ambient_intensity: float = 0.7  # AmbientLight intensity
    wireframe: bool = True  # overlay a black wireframe on triangle meshes
    flat_shading: bool = False  # face normals instead of vertex normals
    pts_scale: float = 0.004  # size passed to PointsMaterial
    pts_color: str = "white"  # color passed to PointsMaterial
    # Used for every vertex when no color is given. An ndarray is unhashable,
    # so Python 3.11+ dataclasses reject it as a plain default; build one per
    # instance instead.
    default_color: np.ndarray = field(
        default_factory=lambda: np.array([1.0, 0.8, 0.2])
    )
    # Point moved to the origin; the bounding-box center when None.
    lookat: Optional[list[float]] = None
    # Camera height, as a multiple of the largest bounding-box extent.
    eyeup: float = 0.0
    fov: float = 50.0  # PerspectiveCamera field of view
    width: int = 600  # renderer width
    height: int = 600  # renderer height


class ThreejsPlotEngine:
    """Display meshes, segments and points in a notebook with pythreejs."""

    def __init__(self):
        self.buff = PlotBuffer()
        self.geom = PlotGeometry()
        self.obj = PlotObject()
        self.flat_shading = False

    @staticmethod
    def _index_buffer(elements: np.ndarray) -> Optional[p3s.BufferAttribute]:
        """Return a flattened uint32 index buffer, or None if ``elements`` is empty."""
        if len(elements) == 0:
            return None
        return p3s.BufferAttribute(elements.astype("uint32").ravel(), normalized=False)

    def _geometry(
        self, index: Optional[p3s.BufferAttribute], with_color: bool
    ) -> Optional[p3s.BufferGeometry]:
        """Wrap ``index`` with the shared vertex (and optionally color) buffers."""
        if index is None:
            return None
        attributes = dict(position=self.buff.vert, index=index)
        if with_color:
            attributes["color"] = self.buff.color
        return p3s.BufferGeometry(attributes=attributes)

    def _recompute_normals(self) -> None:
        """Recompute triangle normals for the current shading mode, if any triangles."""
        if self.geom.tri is None:
            return
        if self.flat_shading:
            self.geom.tri.exec_three_obj_method("computeFaceNormals")
        else:
            self.geom.tri.exec_three_obj_method("computeVertexNormals")

    def plot(
        self,
        vert: np.ndarray,
        color: np.ndarray,
        tri: np.ndarray,
        seg: np.ndarray,
        pts: np.ndarray,
        param: Optional[PlotParam] = None,
    ):
        """Build a new pythreejs scene for the given elements and display it.

        Args:
            vert: Vertex positions; must be non-empty.
            color: Per-vertex colors. When empty, ``param.default_color`` is
                used for every vertex.
            tri: Triangle indices (may be empty).
            seg: Segment indices (may be empty).
            pts: Point indices (may be empty).
            param: Plot parameters; a fresh ``PlotParam()`` when omitted.
        """
        if param is None:
            param = PlotParam()
        assert len(vert) > 0
        if len(color) == 0:
            color = np.tile(param.default_color, (len(vert), 1))
        assert len(color) == len(vert)
        color = color.astype("float32")
        vert = vert.astype("float32")

        vmin = np.min(vert, axis=0)
        bbox = np.max(vert, axis=0) - vmin
        # Every object is translated by ``center`` so that the look-at point
        # ends up at the origin.
        if param.lookat is None:
            center = list(-vmin - 0.5 * bbox)
        else:
            center = list(-np.array(param.lookat))

        self.buff.vert = p3s.BufferAttribute(vert, normalized=False)
        self.buff.color = p3s.BufferAttribute(color)
        self.buff.tri = self._index_buffer(tri)
        self.buff.pts = self._index_buffer(pts)
        self.buff.seg = self._index_buffer(seg)

        self.geom.tri = self._geometry(self.buff.tri, with_color=True)
        self.geom.pts = self._geometry(self.buff.pts, with_color=False)
        self.geom.seg = self._geometry(self.buff.seg, with_color=True)

        self.flat_shading = param.flat_shading
        self._recompute_normals()

        self.obj.tri = None
        if self.geom.tri is not None:
            self.obj.tri = p3s.Mesh(
                geometry=self.geom.tri,
                material=p3s.MeshStandardMaterial(
                    vertexColors="VertexColors",
                    side="DoubleSide",
                    flatShading=param.flat_shading,
                    # Presumably offsets the filled faces so the wireframe
                    # overlay (same geometry) does not z-fight with them.
                    polygonOffset=True,
                    polygonOffsetFactor=1,
                    polygonOffsetUnits=1,
                ),
                position=center,
            )

        self.obj.pts = None
        if self.geom.pts is not None:
            self.obj.pts = p3s.Points(
                geometry=self.geom.pts,
                material=p3s.PointsMaterial(
                    size=param.pts_scale,
                    color=param.pts_color,
                ),
                position=center,
            )

        self.obj.seg = None
        if self.geom.seg is not None:
            self.obj.seg = p3s.LineSegments(
                geometry=self.geom.seg,
                material=p3s.LineBasicMaterial(vertexColors="VertexColors"),
                position=center,
            )

        self.obj.wireframe = None
        if param.wireframe and self.obj.tri is not None:
            self.obj.wireframe = p3s.Mesh(
                geometry=self.geom.tri,
                material=p3s.MeshBasicMaterial(
                    color="black",
                    wireframe=True,
                ),
                position=center,
            )

        scale = float(np.max(bbox))
        if scale <= 0.0:
            # Zero extent (e.g. a single point) would put the camera exactly
            # on the object; fall back to a unit distance.
            scale = 1.0
        position = [0.0, scale * param.eyeup, 1.25 * scale]
        self.obj.light_0 = p3s.DirectionalLight(
            position=position, intensity=param.direct_intensity
        )
        self.obj.light_1 = p3s.AmbientLight(intensity=param.ambient_intensity)
        # The directional light is a child of the camera, so it moves with it.
        self.obj.camera = p3s.PerspectiveCamera(
            position=position,
            fov=param.fov,
            aspect=param.width / param.height,
            children=[self.obj.light_0],
        )
        children = [self.obj.camera, self.obj.light_1]
        for child in (self.obj.tri, self.obj.wireframe, self.obj.pts, self.obj.seg):
            if child is not None:
                children.append(child)
        self.obj.scene = p3s.Scene(children=children, background="#222222")
        self.obj.renderer = p3s.Renderer(
            camera=self.obj.camera,
            scene=self.obj.scene,
            controls=[p3s.OrbitControls(controlling=self.obj.camera)],
            antialias=True,
            width=param.width,
            height=param.height,
        )
        display(self.obj.renderer)

    def update(
        self, vert: Optional[np.ndarray] = None, color: Optional[np.ndarray] = None
    ):
        """Push new vertex positions and/or colors to the displayed scene.

        Triangle normals are recomputed whenever the positions change.
        ``plot`` must have been called first.
        """
        if vert is not None:
            assert self.buff.vert is not None
            self.buff.vert.array = vert.astype("float32")
            self.buff.vert.needsUpdate = True
            self._recompute_normals()
        if color is not None:
            assert self.buff.color is not None
            self.buff.color.array = color.astype("float32")
            self.buff.color.needsUpdate = True


class OpenGLRenderEngine:
    """Render with ``OpenGLRenderer`` and show the result via IPython display."""

    def __init__(self) -> None:
        self._handle = None
        # Populated by plot(); update() requires a prior plot().
        self._vert: Optional[np.ndarray] = None
        self._color: Optional[np.ndarray] = None
        self._tri: Optional[np.ndarray] = None
        self._seg: Optional[np.ndarray] = None
        self._pts: Optional[np.ndarray] = None
        self._param: Optional[PlotParam] = None

    def _render(
        self,
        vert: np.ndarray,
        color: np.ndarray,
        tri: np.ndarray,
        seg: np.ndarray,
    ):
        """Render one frame and show it, reusing the display handle after the first."""
        engine = OpenGLRenderer()
        # NOTE(review): seg is passed before tri; this must match
        # OpenGLRenderer.render's signature, which is not visible here.
        image = engine.render(
            vert,
            color,
            seg,
            tri,
            None,
        )
        if self._handle is None:
            self._handle = display(image, display_id=True)
        else:
            self._handle.update(image)

    def plot(
        self,
        vert: np.ndarray,
        color: np.ndarray,
        tri: np.ndarray,
        seg: np.ndarray,
        pts: np.ndarray,
        param: Optional[PlotParam] = None,
    ):
        """Store the elements and render them.

        ``pts`` is stored but not rendered by this engine. Of ``param``, only
        ``default_color`` is used here.
        """
        if param is None:
            param = PlotParam()
        assert len(vert) > 0
        if len(color) == 0:
            color = np.tile(param.default_color, (len(vert), 1))
        assert len(color) == len(vert)
        self._vert = vert.copy()
        self._color = color.copy()
        self._tri = tri.copy()
        self._seg = seg.copy()
        self._pts = pts.copy()
        self._param = param
        self._render(self._vert, self._color, self._tri, self._seg)

    def update(
        self, vert: Optional[np.ndarray] = None, color: Optional[np.ndarray] = None
    ):
        """Replace vertex positions and/or colors and re-render."""
        assert self._vert is not None, "plot() must be called before update()"
        if vert is not None:
            self._vert = vert.copy()
        if color is not None:
            self._color = color.copy()
        self._render(self._vert, self._color, self._tri, self._seg)