Source code for ogstools.meshplotlib.core

# Copyright (c) 2012-2024, OpenGeoSys Community (http://www.opengeosys.org)
#            Distributed under a Modified BSD License.
#            See accompanying file LICENSE.txt or
#            http://www.opengeosys.org/project/license
#

"""Meshplotlib core utilities."""

import warnings
from math import nextafter
from typing import Any, Literal, Optional, Union

import numpy as np
import pyvista as pv
from matplotlib import cm as mcm
from matplotlib import colormaps, rcParams
from matplotlib import colors as mcolors
from matplotlib import figure as mfigure
from matplotlib import pyplot as plt
from matplotlib import ticker as mticker
from matplotlib.patches import Rectangle as Rect
from typeguard import typechecked

from ogstools.meshlib import MeshSeries
from ogstools.propertylib import Property, Vector
from ogstools.propertylib.presets import get_preset
from ogstools.propertylib.unit_registry import u_reg

from . import plot_features as pf
from . import setup
from .levels import compute_levels, median_exponent
from .utils import get_style_cycler

# TODO: define default data_name for regions in setup


def _q_zero_line(mesh_property: Property, levels: np.ndarray) -> bool:
    """Decide whether a white zero line should be drawn for this property.

    This is the case for properties with a bilinear (diverging) colormap,
    or for temperature when the levels range from below to above zero.

    :param mesh_property:   The property being plotted.
    :param levels:          The contour / colorbar levels (ascending).

    :returns: A truthy value if the zero line should be drawn.
    """
    bilinear = mesh_property.bilinear_cmap
    if bilinear:
        return bilinear
    is_temperature = mesh_property.data_name == "temperature"
    return is_temperature and levels[0] < 0 < levels[-1]


def get_level_boundaries(levels: np.ndarray) -> np.ndarray:
    """Return the boundaries of the cells centered on the given levels.

    Inner boundaries lie halfway between neighbouring levels. The outer
    boundaries are extrapolated by half the spacing of the outermost pair,
    so the result has one more entry than ``levels``.

    :param levels: At least two level values.

    :returns: The ``len(levels) + 1`` boundaries.
    """
    midpoints = 0.5 * (levels[:-1] + levels[1:])
    lower = levels[0] - 0.5 * (levels[1] - levels[0])
    upper = levels[-1] + 0.5 * (levels[-1] - levels[-2])
    return np.array([lower, *midpoints, upper])
def get_cmap_norm(
    levels: np.ndarray, mesh_property: Property
) -> tuple[mcolors.Colormap, mcolors.Normalize]:
    """Construct a discrete colormap and norm for the property field.

    One color per level is sampled from the property's continuous colormap.
    For bilinear colormaps the sampling is centered at zero if the levels
    straddle it. Otherwise only the half of the colormap matching the sign
    of the data is used.

    :param levels:          The (ascending) level values.
    :param mesh_property:   The property providing cmap, categoric and
                            bilinear_cmap settings.

    :returns: The discrete colormap and a matching BoundaryNorm.
    """
    vmin, vmax = levels[0], levels[-1]
    if mesh_property.categoric:
        vmin, vmax = vmin + 0.5, vmax + 0.5

    base_cmap = (
        colormaps[mesh_property.cmap]
        if isinstance(mesh_property.cmap, str)
        else mesh_property.cmap
    )

    value_norm: Union[mcolors.TwoSlopeNorm, mcolors.Normalize]
    if not mesh_property.bilinear_cmap:
        value_norm = mcolors.Normalize(vmin, vmax)
    elif vmin <= 0.0 <= vmax:
        # symmetric range around zero so that zero maps to the cmap center
        half_width = np.max(np.abs([vmin, vmax]))
        vmin, vmax = half_width * np.array([-1.0, 1.0])
        value_norm = mcolors.TwoSlopeNorm(0.0, vmin, vmax)
    else:
        # data does not straddle zero: only use one half of the diverging
        # colormap (upper half for positive, lower half for negative data)
        half_range = np.linspace(0.0, nextafter(0.5, -np.inf), 128)
        if vmax > 0.0:
            half_range += 0.5
        base_cmap = mcolors.LinearSegmentedColormap.from_list(
            "half_cmap", base_cmap(half_range)
        )
        value_norm = mcolors.Normalize(vmin, vmax)

    centers = np.append((levels[:-1] + levels[1:]) * 0.5, levels[-1])
    palette = [base_cmap(value_norm(center)) for center in centers]
    cmap = mcolors.ListedColormap(palette, name="custom")

    if mesh_property.categoric:
        edges = get_level_boundaries(levels)
    else:
        edges = levels
    norm = mcolors.BoundaryNorm(
        boundaries=edges, ncolors=len(edges), clip=False
    )
    return cmap, norm
def _format_offset(offset: Any) -> Optional[str]:
    """Return the offset as a string, or ``None`` if no offset is used.

    Integers are printed exactly. Floats keep up to 15 significant digits,
    so large offsets such as 123456.1 are not rounded to 6 significant
    digits as the plain ``g`` format would do.
    """
    if offset == 0:
        return None
    if isinstance(offset, (int, np.integer)):
        return str(offset)
    return f"{offset:.15g}"


def get_ticklabels(ticks: np.ndarray) -> tuple[list[str], Optional[str]]:
    """Get formatted tick labels and optional offset str.

    If all values in ticks are too close together (relative to their
    magnitude), offset notation is used: the labels then show
    ``tick - offset`` and the offset is returned as a string. Callers can
    mention it in the axis label.

    :param ticks: The tick values.

    :returns: The tick labels and the offset string (``None`` if no offset
        notation is used).
    """
    if median_exponent(ticks) >= 2 + median_exponent(ticks[-1] - ticks[0]):
        # use offset notation: the tick with the shortest string
        # representation becomes the offset
        label_lens = np.asarray([len(str(tick)) for tick in ticks])
        offset = ticks[np.argmin(label_lens)]
    else:
        offset = 0
    offset_str = _format_offset(offset)

    if np.issubdtype(ticks.dtype, np.integer):
        # Labels have to be relative to the offset (as in the float case),
        # otherwise the returned offset would not match the labels.
        return [str(tick) for tick in ticks - offset], offset_str

    # The notation only depends on the shifted ticks, not on the precision.
    fmt = "f" if abs(median_exponent(ticks - offset)) <= 2 else "e"
    for precision in [1, 2, 3, 4]:
        tick_labels: list[str] = [
            f"{0.0 + tick:.{precision}{fmt}}" for tick in ticks - offset
        ]
        if len(tick_labels) == len(set(tick_labels)):
            break

    # pretty hacky but seems to do the job: the outermost labels may still
    # coincide with their neighbours, so refine only those until they differ
    for idx, adj in [(0, 1), (-1, -2)]:
        if float(tick_labels[idx]) != float(tick_labels[adj]):
            continue
        for precision in range(12):
            new_ticklabel = f"{0.0 + ticks[idx] - offset:.{precision}{fmt}}"
            adj_ticklabel = f"{0.0 + ticks[adj] - offset:.{precision}{fmt}}"
            if float(new_ticklabel) != float(adj_ticklabel):
                tick_labels[idx] = new_ticklabel
                break
    if fmt != "e":
        # Strip trailing zeros and a dangling decimal point. Labels without a
        # decimal point (precision 0) must stay untouched, otherwise e.g.
        # "10" would become "1".
        tick_labels = [
            label.rstrip("0").rstrip(".") if "." in label else label
            for label in tick_labels
        ]
    return tick_labels, offset_str
def add_colorbars(
    fig: mfigure.Figure,
    ax: Union[plt.Axes, list[plt.Axes]],
    mesh_property: Property,
    levels: np.ndarray,
    pad: float = 0.05,
    labelsize: Optional[float] = None,
) -> None:
    """Add a colorbar to the matplotlib figure.

    :param fig:             Figure the colorbar is added to.
    :param ax:              Axes (or list of Axes) passed as ``ax`` to
                            ``fig.colorbar``.
    :param mesh_property:   The property the colorbar represents.
    :param levels:          The level values. For categoric properties (or
                            exactly two levels) the ticks are placed at the
                            centers between the level boundaries.
    :param pad:             Padding passed to ``fig.colorbar``.
    :param labelsize:       Font size of the label and tick labels, defaults
                            to ``setup.rcParams_scaled["font.size"]``.
    """
    if mesh_property.categoric or len(levels) == 2:
        edges = get_level_boundaries(levels)
        ticks = edges[:-1] + 0.5 * np.diff(edges)
    else:
        ticks = levels

    cmap, norm = get_cmap_norm(levels, mesh_property)
    mappable = mcm.ScalarMappable(norm=norm, cmap=cmap)
    cbar = fig.colorbar(
        mappable,
        norm=norm,
        ax=ax,
        ticks=ticks,
        drawedges=True,
        location="right",
        spacing="uniform",
        pad=pad,
    )

    # label: property label, offset (if any) and log scaling
    tick_labels, offset = get_ticklabels(ticks)
    label = mesh_property.get_label()
    if offset is not None:
        # write "+ c" instead of "- -c" for negative offsets
        label += (
            " + " + offset[1:] if offset.startswith("-") else " - " + offset
        )
    if setup.log_scaled:
        label = f"log$_{{10}}$( {label} )"
    if labelsize is None:
        labelsize = setup.rcParams_scaled["font.size"]
    cbar.set_label(label, size=labelsize)

    # tick labels: material names or integer categories override the
    # numeric labels
    is_material_ids = mesh_property.data_name == "MaterialIDs"
    if is_material_ids and setup.material_names is not None:
        tick_labels = [
            setup.material_names.get(mat_id, mat_id) for mat_id in levels
        ]
        cbar.ax.set_ylabel("")
    elif mesh_property.categoric:
        tick_labels = [str(level) for level in levels.astype(int)]
    cbar.ax.tick_params(labelsize=labelsize, direction="out")
    cbar.ax.set_yticklabels(tick_labels)

    # decorations
    if mesh_property.is_mask():
        cbar.ax.add_patch(Rect((0, 0.5), 1, -1, lw=0, fc="none", hatch="/"))
    if setup.invert_colorbar:
        cbar.ax.invert_yaxis()
    if _q_zero_line(mesh_property, ticks):
        zero_line_width = 2 * setup.rcParams_scaled["lines.linewidth"]
        cbar.ax.axhline(y=0, color="w", lw=zero_line_width)
def get_projection(
    mesh: pv.UnstructuredGrid,
) -> tuple[int, int]:
    """
    Identify which projection is used: XY, XZ or YZ.

    The axis with the largest absolute component of the mean surface cell
    normal is taken as the normal direction. The two remaining axes span
    the plotting plane.

    :param mesh: singular mesh

    :returns: The indices (0=x, 1=y, 2=z) of the two in-plane axes.
    """
    surface_normals = mesh.extract_surface().cell_normals
    normal_axis = int(np.argmax(np.abs(np.mean(surface_normals, axis=0))))
    in_plane = np.delete([0, 1, 2], normal_axis)
    return in_plane[0], in_plane[1]
def subplot(
    mesh: pv.UnstructuredGrid,
    mesh_property: Union[Property, str],
    ax: plt.Axes,
    levels: Optional[np.ndarray] = None,
) -> None:
    """
    Plot the property field of a 2D mesh on a matplotlib.axis.

    Meshes containing 3D cells are rejected with a ValueError. Point data is
    drawn as filled contours and cell data as colored faces. Element edges,
    region boundaries, streamlines (for vector properties) and a secondary
    x axis (for meshes not aligned with a coordinate plane) are added
    according to meshplotlib.setup.

    :param mesh:            The 2D mesh to plot.
    :param mesh_property:   The property (or its name) to plot.
    :param ax:              The Axes to plot into.
    :param levels:          Custom levels. If None, they are computed from
                            the data with at most ``setup.num_levels``
                            levels.

    :raises ValueError: If the mesh contains 3D elements.
    """

    mesh_property = get_preset(mesh_property, mesh)
    if mesh.get_cell(0).dimension == 3:
        msg = "meshplotlib is for 2D meshes only, but found 3D elements."
        raise ValueError(msg)

    ax.axis("auto")

    if mesh_property.mask_used(mesh):
        # draw the mask itself first, then only keep cells where mask == 1
        subplot(mesh, mesh_property.get_mask(), ax)
        mesh = mesh.ctp(True).threshold(
            value=[1, 1], scalars=mesh_property.mask
        )

    surf_tri = mesh.triangulate().extract_surface()

    # get projection
    x_id, y_id = get_projection(mesh)
    mean_normal = np.abs(np.mean(mesh.extract_surface().cell_normals, axis=0))
    projection = int(np.argmax(mean_normal))

    # faces contains a padding indicating number of points per face which gets
    # removed with this reshaping and slicing to get the array of tri's
    x, y = setup.length.transform(surf_tri.points.T[[x_id, y_id]])
    tri = surf_tri.faces.reshape((-1, 4))[:, 1:]
    values = mesh_property.magnitude.transform(surf_tri)
    if setup.log_scaled:
        # clamp tiny / non-positive values before taking the logarithm
        values_temp = np.where(values > 1e-14, values, 1e-14)
        values = np.log10(values_temp)
    p_min, p_max = np.nanmin(values), np.nanmax(values)

    if levels is None:
        num_levels = min(setup.num_levels, len(np.unique(values)))
        levels = compute_levels(p_min, p_max, num_levels)
    cmap, norm = get_cmap_norm(levels, mesh_property)

    if mesh_property.data_name in mesh.point_data:
        # point data: filled contours
        ax.tricontourf(  # type: ignore[call-overload]
            x, y, tri, values, levels=levels, cmap=cmap, norm=norm
        )
        if _q_zero_line(mesh_property, levels):
            ax.tricontour(  # type: ignore[call-overload]
                x, y, tri, values, levels=[0], colors="w"
            )
    else:
        # cell data: one color per face
        ax.tripcolor(x, y, tri, facecolors=values, cmap=cmap, norm=norm)
        if mesh_property.is_mask():
            ax.tripcolor(
                x, y, tri, facecolors=values, mask=(values == 1),
                cmap=cmap, norm=norm, hatch="/"
            )  # fmt: skip

    surf = mesh.extract_surface()

    # show_element_edges is either a bool or the name of the only property
    # for which edges are drawn
    show_edges = setup.show_element_edges
    if isinstance(setup.show_element_edges, str):
        show_edges = setup.show_element_edges == mesh_property.data_name
    if show_edges:
        pf.plot_element_edges(ax, surf, projection)

    if setup.show_region_bounds and "MaterialIDs" in mesh.cell_data:
        pf.plot_layer_boundaries(ax, surf, projection)

    if isinstance(mesh_property, Vector):
        pf.plot_streamlines(ax, surf_tri, mesh_property, projection)

    ax.margins(0, 0)  # otherwise it shrinks the plot content

    # Mesh is not aligned with a coordinate plane: add a secondary x axis
    # showing the coordinate along the projection direction at each x tick.
    if abs(max(mean_normal) - 1) > 1e-6:
        # NOTE(review): sec_id is an index into the 2-element array with the
        # projection axis removed, but it is used below as a 3D axis index
        # (mesh.center, mesh.bounds, "xyz") - confirm this mapping is intended.
        sec_id = np.argmax(np.delete(mean_normal, projection))
        sec_labels = []
        for tick in ax.get_xticks():
            # NOTE(review): ticks are in setup.length output units while
            # mesh.center / mesh.bounds are in mesh units - confirm they match.
            origin = np.array(mesh.center)
            origin[sec_id] = min(
                max(tick, mesh.bounds[2 * sec_id] + 1e-6),
                mesh.bounds[2 * sec_id + 1] - 1e-6,
            )
            sec_mesh = mesh.slice("xyz"[sec_id], origin)
            if sec_mesh.n_cells:
                sec_labels += [f"{sec_mesh.bounds[2 * projection]:.1f}"]
            else:
                sec_labels += [""]
        # TODO: use a function to make this short
        secax = ax.secondary_xaxis("top")
        secax.xaxis.set_major_locator(
            mticker.FixedLocator(list(ax.get_xticks()))
        )
        secax.set_xticklabels(sec_labels)
        secax.set_xlabel(f'{"xyz"[projection]} / {setup.length.output_unit}')
def clear_labels(axes: Union[plt.Axes, np.ndarray]) -> None:
    """Remove the x and y axis labels from one Axes or an array of Axes.

    :param axes: A single Axes or a numpy array of Axes (any shape).
    """
    for single_ax in np.array(axes).ravel():
        single_ax.set_xlabel("")
        single_ax.set_ylabel("")
@typechecked
def label_spatial_axes(
    axes: Union[plt.Axes, np.ndarray], x_label: str = "x", y_label: str = "y"
) -> None:
    """
    Add labels to x and y axis.

    The labels get the length unit from ``setup.length.output_unit``. If
    given a (2D) array of axes, only the bottom row gets x labels and only
    the left column gets y labels.

    :param axes:    A single Axes or a 2D numpy array of Axes.
    :param x_label: Name of the horizontal coordinate.
    :param y_label: Name of the vertical coordinate.
    """
    unit = setup.length.output_unit
    x_text = f"{x_label} / {unit}"
    y_text = f"{y_label} / {unit}"
    if not isinstance(axes, np.ndarray):
        axes.set_xlabel(x_text)
        axes.set_ylabel(y_text)
        return
    for bottom_ax in axes[-1, :]:
        bottom_ax.set_xlabel(x_text)
    for left_ax in axes[:, 0]:
        left_ax.set_ylabel(y_text)
def _get_rows_cols(
    meshes: Union[
        list[pv.UnstructuredGrid],
        np.ndarray,
        pv.UnstructuredGrid,
        pv.MultiBlock,
    ],
) -> tuple[int, ...]:
    """Return the (rows, cols) subplot layout for the given meshes.

    A 2D numpy array keeps its shape. A 1D numpy array, a list and a
    MultiBlock are laid out as a single row, and anything else (a single
    mesh) is a 1x1 grid. The result always has two entries, because callers
    index it as ``shape[0]`` and ``shape[1]``.

    :raises ValueError: If a numpy array that is not 1D or 2D is given.
    """
    if isinstance(meshes, np.ndarray):
        if meshes.ndim == 1:
            return (1, meshes.shape[0])
        if meshes.ndim == 2:
            return meshes.shape
        msg = "Input numpy array must be 1D or 2D."
        raise ValueError(msg)
    if isinstance(meshes, list):
        return (1, len(meshes))
    if isinstance(meshes, pv.MultiBlock):
        return (1, meshes.n_blocks)
    return (1, 1)


# TODO: fixed_figure_size -> ax aspect automatic
def _fig_init(
    rows: int, cols: int, aspect: float = 1.0
) -> tuple[mfigure.Figure, plt.Axes]:
    """Create a figure with a rows x cols grid of subplots sharing x and y.

    The figure size is scaled by ``setup.fig_scale``. It reserves room for
    the subplots (widened by ``aspect``), the colorbar(s) and the axis
    labels.

    :param rows:    Number of subplot rows.
    :param cols:    Number of subplot columns.
    :param aspect:  Width / height factor applied to each subplot's width.

    :returns: The figure and the Axes (or array of Axes) from plt.subplots.
    """
    # one colorbar column if combined, otherwise one per subplot column
    nx_cb = 1 if setup.combined_colorbar else cols
    default_size = 8
    cb_width = 3
    y_label_width = 2
    x_label_height = 1
    figsize = setup.fig_scale * np.asarray(
        [
            default_size * cols * aspect + cb_width * nx_cb + y_label_width,
            default_size * rows + x_label_height,
        ]
    )
    fig, ax = plt.subplots(
        rows,
        cols,
        dpi=setup.dpi * setup.fig_scale,
        figsize=figsize,
        layout=setup.layout,
        sharex=True,
        sharey=True,
    )
    fig.patch.set_alpha(1)
    return fig, ax
def get_combined_levels(
    meshes: np.ndarray, mesh_property: Union[Property, str]
) -> np.ndarray:
    """
    Calculate well spaced levels for the encompassing property range in meshes.

    The range spans the values of all meshes, unless fixed via
    ``setup.p_min`` / ``setup.p_max``. If all values are integers and no
    limits are fixed, the unique values themselves are returned. Otherwise
    ``compute_levels`` is used with ``setup.num_levels``. A constant field
    yields two levels enclosing a tiny range.

    :param meshes:          numpy array of meshes (any shape).
    :param mesh_property:   The property (or its name) to compute levels for.

    :returns: The levels as a numpy array.
    """
    mesh_property = get_preset(mesh_property, meshes.ravel()[0])
    p_min, p_max = np.inf, -np.inf
    unique_vals = np.array([])
    for mesh in np.ravel(meshes):
        values = mesh_property.magnitude.transform(mesh)
        if setup.log_scaled:  # TODO: can be improved
            values = np.log10(np.where(values > 1e-14, values, 1e-14))
        p_min = min(p_min, np.nanmin(values)) if setup.p_min is None else p_min
        p_max = max(p_max, np.nanmax(values)) if setup.p_max is None else p_max
        unique_vals = np.unique(
            np.concatenate((unique_vals, np.unique(values)))
        )
    p_min = setup.p_min if setup.p_min is not None else p_min
    p_max = setup.p_max if setup.p_max is not None else p_max
    if p_min == p_max:
        # Constant field: create a tiny but non-empty range. The step has to
        # scale with the magnitude. An absolute 1e-12 is lost to rounding for
        # large values (e.g. 1e5), which would give two identical levels.
        return np.array([p_min, p_max + max(1e-12, 1e-12 * abs(p_max))])
    if (
        all(val.is_integer() for val in unique_vals)
        and setup.p_min is None
        and setup.p_max is None
    ):
        return unique_vals[(p_min <= unique_vals) & (unique_vals <= p_max)]
    return compute_levels(p_min, p_max, setup.num_levels)
# TODO: Have a look at fig and ax logic and make it more readable
def _draw_plot(
    meshes: Union[list[pv.UnstructuredGrid], np.ndarray, pv.UnstructuredGrid],
    mesh_property: Property,
    fig: Optional[mfigure.Figure] = None,
    axes: Optional[plt.Axes] = None,
) -> Optional[mfigure.Figure]:
    """
    Plot the property field of meshes on existing figure.

    :param meshes: Singular mesh or 2D numpy array of meshes
    :param property: the property field to be visualized on all meshes
    :param fig: Matplotlib figure to use for plotting (optional)
    :param axes: Matplotlib Axes to use for plotting (optional)

    :returns: The given figure (None if only axes were given).

    :raises ValueError: If only one Axes but multiple meshes are given.
    :raises TypeError: If neither a figure nor axes are given.
    """
    shape = _get_rows_cols(meshes)
    np_meshes = np.reshape(meshes, shape)
    if fig is not None and axes is not None:
        np_axs = np.reshape(np.array(axes), shape)
    elif fig is not None and axes is None:
        # Only Fig is given
        # Multiple meshes should be accepted
        warnings.warn(
            "This is not a good practice. Consider providing both fig and ax instead. This option may lead to unexpected behaviour and may be removed without warning in the future.",
            Warning,
            stacklevel=4,
        )
        np_axs = np.reshape(np.asarray(fig.axes), shape)
    elif fig is None and axes is not None:
        # Only ax is given
        # Only one mesh should be accepted
        if shape != (1, 1):
            msg = "You have provided only one Axis object but multiple meshes. Provide only one mesh per Axis object, or provide Figure object instead."
            raise ValueError(msg)
        np_axs = np.reshape(np.array(axes), (1, 1))
    else:
        msg = "Neither Figure nor Axis object was provided."
        raise TypeError(msg)

    if setup.combined_colorbar:
        combined_levels = get_combined_levels(np_meshes, mesh_property)

    # levels per subplot, kept for the individual colorbars below
    subplot_levels: dict[tuple[int, int], np.ndarray] = {}
    for i in range(shape[0]):
        for j in range(shape[1]):
            subplot_levels[i, j] = (
                combined_levels
                if setup.combined_colorbar
                else get_combined_levels(np_meshes[i, j, None], mesh_property)
            )
            subplot(
                np_meshes[i, j], mesh_property, np_axs[i, j], subplot_levels[i, j]
            )

    x_id, y_id = get_projection(
        np_meshes[0, 0]
    )  # One mesh is sufficient, it should be the same for all of them
    label_spatial_axes(np_axs, "xyz"[x_id], "xyz"[y_id])
    np_axs[0, 0].set_title(setup.title_center, loc="center", y=1.02)
    np_axs[0, 0].set_title(setup.title_left, loc="left", y=1.02)
    np_axs[0, 0].set_title(setup.title_right, loc="right", y=1.02)
    # make extra space for the upper limit of the colorbar
    if setup.layout == "tight":
        # apply to the given figure, not just to pyplot's current figure
        if fig is not None:
            fig.tight_layout(pad=1.4)
        else:
            plt.tight_layout(pad=1.4)

    if setup.combined_colorbar:
        if fig is None:
            warnings.warn(
                "Cannot plot combined colorbar if Figure object is not provided!",
                Warning,
                stacklevel=4,
            )
        else:
            cb_axs = np.ravel(np.asarray(fig.axes)).tolist()
            add_colorbars(
                fig, cb_axs, mesh_property, combined_levels, pad=0.05 / shape[1]
            )
    elif fig is None:
        warnings.warn(
            "Figure object is required to plot individual colorbars for Axes objects.",
            Warning,
            stacklevel=4,
        )
    else:
        for (i, j), _levels in subplot_levels.items():
            add_colorbars(fig, np_axs[i, j], mesh_property, _levels)
    return fig
def get_data_aspect(mesh: pv.UnstructuredGrid) -> float:
    """
    Calculate the data aspect ratio of a 2D mesh.

    The ratio is the extent along the first in-plane axis divided by the
    extent along the second one, as determined by ``get_projection``.

    :param mesh: singular 2D mesh

    :returns: width / height of the mesh in its plotting plane.
    """
    x_id, y_id = get_projection(mesh)
    bounds = mesh.bounds
    width = abs(bounds[2 * x_id + 1] - bounds[2 * x_id])
    height = abs(bounds[2 * y_id + 1] - bounds[2 * y_id])
    return width / height
def update_font_sizes(
    fig: mfigure.Figure,
    fontsize: int = 20,
) -> mfigure.Figure:
    """
    Update font sizes of labels and ticks in all subplots.

    Every Axes of the figure is relabeled with ``X / <unit>`` and
    ``Y / <unit>``, using the length unit from meshplotlib.setup. Its
    existing tick labels are re-set with the new font size, and the axis
    limits are restored afterwards.

    :param fig:         Matplotlib Figure object to use for plotting
    :param fontsize:    New font size for the labels and ticks (optional)

    :returns: The same figure.
    """
    # NOTE(review): fig.axes presumably also contains colorbar axes, which
    # get the X / Y labels as well - confirm this is intended.
    unit = setup.length.output_unit
    x_label, y_label = f"X / {unit}", f"Y / {unit}"
    for subax in fig.axes:
        subax.set_xlabel(x_label, fontsize=fontsize)
        subax.set_ylabel(y_label, fontsize=fontsize)
        # setting ticks may expand the view limits, so remember them
        x_limits, y_limits = subax.get_xlim(), subax.get_ylim()

        x_ticks = subax.get_xticks()
        x_texts = [label.get_text() for label in subax.get_xticklabels()]
        subax.set_xticks(x_ticks, x_texts, fontsize=fontsize)

        y_ticks = subax.get_yticks()
        y_texts = [label.get_text() for label in subax.get_yticklabels()]
        subax.set_yticks(y_ticks, y_texts, fontsize=fontsize)

        subax.set_xlim(x_limits)
        subax.set_ylim(y_limits)
    return fig
# TODO: add as arguments: cmap, limits # TODO: num_levels should be min_levels
def plot(
    meshes: Union[list[pv.UnstructuredGrid], np.ndarray, pv.UnstructuredGrid],
    mesh_property: Union[Property, str],
    fig: Optional[mfigure.Figure] = None,
    ax: Optional[plt.Axes] = None,
) -> Optional[mfigure.Figure]:
    """
    Plot the property field of meshes with default settings.

    The resulting figure adheres to the configurations in meshplotlib.setup.
    Only 2D meshes are supported; meshes with 3D elements raise a ValueError
    when plotted. Each subplot's aspect is set from the mesh's data aspect,
    optionally clipped via ``setup.min_ax_aspect`` / ``setup.max_ax_aspect``.

    :param meshes: Singular mesh, list or numpy array of meshes
    :param property: The property field to be visualized on all meshes
    :param fig: Matplotlib Figure object to use for plotting (optional)
    :param ax: Matplotlib Axis object to use for plotting (optional)

    :returns: The figure (newly created if neither fig nor ax is given),
        or None if only ax is given.
    """
    rcParams.update(setup.rcParams_scaled)
    shape = _get_rows_cols(meshes)
    _meshes = np.reshape(meshes, shape).ravel()
    mesh_property = get_preset(mesh_property, _meshes[0])
    data_aspects = np.asarray([get_data_aspect(mesh) for mesh in _meshes])
    if setup.min_ax_aspect is None and setup.max_ax_aspect is None:
        fig_aspect = np.mean(data_aspects)
    else:
        fig_aspect = np.mean(
            np.clip(data_aspects, setup.min_ax_aspect, setup.max_ax_aspect)
        )
    ax_aspects = fig_aspect / data_aspects
    n_axs = shape[0] * shape[1]

    if ax is None and fig is None:
        _fig, _ax = _fig_init(rows=shape[0], cols=shape[1], aspect=fig_aspect)
        fig = _draw_plot(meshes, mesh_property, fig=_fig, axes=_ax)
        assert isinstance(fig, plt.Figure)
        plotted_axes = list(fig.axes[:n_axs])
    elif ax is not None and fig is None:
        _draw_plot(meshes, mesh_property, axes=ax)
        plotted_axes = list(np.ravel(np.array(ax)))
    elif ax is None and fig is not None:
        fig = _draw_plot(meshes, mesh_property, fig=fig)
        assert isinstance(fig, plt.Figure)
        plotted_axes = list(fig.axes[:n_axs])
    else:
        _draw_plot(meshes, mesh_property, fig=fig, axes=ax)
        # the given axes (not fig.axes) are the ones that were plotted on
        plotted_axes = list(np.ravel(np.array(ax)))

    # the aspects are ordered like the flattened meshes, which matches the
    # flattened order of the axes the meshes were drawn on
    for subax, aspect in zip(plotted_axes, ax_aspects):
        subax.set_aspect(1.0 / aspect)
    return fig
def plot_probe(
    mesh_series: MeshSeries,
    points: np.ndarray,
    mesh_property: Union[Property, str],
    mesh_property_abscissa: Optional[Union[Property, str]] = None,
    labels: Optional[list[str]] = None,
    time_unit: Optional[str] = "s",
    interp_method: Optional[Literal["nearest", "linear", "probefilter"]] = None,
    interp_backend_pvd: Optional[Literal["vtk", "scipy"]] = None,
    colors: Optional[list] = None,
    linestyles: Optional[list] = None,
    ax: Optional[plt.Axes] = None,
    fill_between: bool = False,
    **kwargs: Any,
) -> Optional[mfigure.Figure]:
    """
    Plot the transient property on the observation points in the MeshSeries.

    :param mesh_series:     MeshSeries object containing the data to be
                            plotted.
    :param points:          The points to sample at (a single point or an
                            array of points).
    :param mesh_property:   The property to be sampled.
    :param mesh_property_abscissa:  Optional property used for the x values
                            instead of the time. It is sampled at the same
                            points.
    :param labels:          The labels for each observation point.
    :param time_unit:       Output unit of the timevalues (used for the x
                            axis if no abscissa property is given).
    :param interp_method:   Choose the interpolation method, defaults to
                            `linear` for xdmf MeshSeries and `probefilter`
                            for pvd MeshSeries.
    :param interp_backend_pvd:  Interpolation backend for PVD MeshSeries.
    :param colors:          Colors passed to the style cycler.
    :param linestyles:      Linestyles passed to the style cycler.
    :param ax:              Axes to plot into. If None, a new figure is
                            created.
    :param fill_between:    If True, fill the area between the minimum and
                            maximum of the values (over the last axis)
                            instead of plotting lines.
    :param kwargs:          Keyword arguments passed to matplotlib's plot
                            (or fill_between) function.

    :returns: The newly created matplotlib Figure, or None if ax was given.
    """
    points = np.asarray(points)
    if points.ndim == 1:
        points = points[np.newaxis]
    mesh_property = get_preset(mesh_property, mesh_series.read(0))
    probed = mesh_series.probe(
        points, mesh_property.data_name, interp_method, interp_backend_pvd
    )
    values = mesh_property.magnitude.transform(probed)
    if values.shape[0] == 1:
        values = values.flatten()

    Q_ = u_reg.Quantity
    time_unit_conversion = Q_(Q_(mesh_series.time_unit), time_unit).magnitude

    if mesh_property_abscissa is None:
        x_values = time_unit_conversion * mesh_series.timevalues
        x_label = f"time / {time_unit}" if time_unit else "time"
    else:
        abscissa = get_preset(mesh_property_abscissa, mesh_series.read(0))
        abscissa_probed = mesh_series.probe(
            points,
            abscissa.data_name,
            interp_method,
            interp_backend_pvd,
        )
        x_values = abscissa.magnitude.transform(abscissa_probed)
        x_unit = abscissa.get_output_unit()
        x_unit_str = f" / {x_unit}" if x_unit else ""
        x_label = abscissa.output_name.replace("_", " ") + x_unit_str

    fig: Optional[mfigure.Figure] = None
    if ax is None:
        fig, ax = plt.subplots()

    ax.set_prop_cycle(get_style_cycler(len(points), colors, linestyles))
    if fill_between:
        lower, upper = np.min(values, axis=-1), np.max(values, axis=-1)
        ax.fill_between(x_values, lower, upper, label=labels, **kwargs)
    else:
        ax.plot(x_values, values, label=labels, **kwargs)
    if labels is not None:
        ax.legend(facecolor="white", framealpha=1, prop={"family": "monospace"})

    ax.set_axisbelow(True)
    grid_styles = (("major", "lightgrey", "-"), ("minor", "0.95", "--"))
    for which, color, linestyle in grid_styles:
        ax.grid(which=which, color=color, linestyle=linestyle)
    ax.set_xlabel(x_label)
    ax.set_ylabel(mesh_property.get_label())
    ax.label_outer()
    ax.minorticks_on()
    return fig