Source code for quicknxs.views.widgets.mplwidget

#!/usr/bin/env python
"""
Plotting widget taken from QuickNXS.

#TODO: refactor this or replace it with a standard solution
"""

import inspect
import logging
import os
import pickle
import tempfile
from pathlib import Path

import matplotlib.colors
import numpy as np
from matplotlib.backends.backend_qt import NavigationToolbar2QT
from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.collections import QuadMesh
from matplotlib.colors import LogNorm, Normalize
from matplotlib.figure import Figure
from qtpy import QtCore, QtGui, QtPrintSupport, QtWidgets

from quicknxs.config import plotting

try:
    import matplotlib.backends.qt_editor.figureoptions as figureoptions
except ImportError:
    figureoptions = None

# Six-stop colormap made available to matplotlib under the name "default".
cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
    "default", ["#0000ff", "#00ff00", "#ffff00", "#ff0000", "#bd7efc", "#000000"], N=256
)
# ``ColormapRegistry.register`` raises ValueError when the name is already taken
# (e.g. when this module is re-imported/reloaded), so only register it once.
if "default" not in matplotlib.colormaps:
    matplotlib.colormaps.register(cmap, name="default")


def _set_default_rc():
    """Push the project's ``font`` and ``savefig`` rc groups into matplotlib.

    The settings come from the ``plotting.font`` and ``plotting.savefig``
    mappings of the project configuration, applied in that order.
    """
    for group_name in ("font", "savefig"):
        group_settings = getattr(plotting, group_name)
        matplotlib.rc(group_name, **group_settings)


_set_default_rc()

# Folder containing the icon image files: the "icons" directory inside the
# third ancestor of this module file.
ICON_DIR = Path(__file__).parent.parent.parent.joinpath("icons")


def getIcon(filename: str) -> "QtGui.QIcon":
    """Build a ``QIcon`` from an image file located in ``ICON_DIR``.

    :param filename: name of the image file, relative to ``ICON_DIR``
    :return: an icon whose pixmap is registered for the ``Normal`` mode and ``Off`` state
    """
    pixmap = QtGui.QPixmap(str(ICON_DIR / filename))
    result = QtGui.QIcon()
    result.addPixmap(pixmap, QtGui.QIcon.Normal, QtGui.QIcon.Off)
    return result
def centerbins(xvals):
    """For a given array of bin edges, return the bin centers.

    The center of bin ``i`` is ``(xvals[i] + xvals[i + 1]) / 2``, so ``n`` edges
    yield ``n - 1`` centers. The input is flattened first (a multi-dimensional
    array is treated as one flat sequence of edges). Empty or single-element
    input returns an empty array instead of raising.

    :param xvals: array-like of bin edges
    :return: float ndarray of bin centers
    """
    edges = np.ravel(xvals)
    # Pairwise midpoints via slicing; same arithmetic as averaging each edge with
    # its right neighbour, without the roll/delete temporaries.
    return (edges[:-1] + edges[1:]) / 2
def _as_float_array(values):
    """Return *values* as a new float ndarray, with masked entries replaced by NaN.

    ``np.array(masked_array)`` would silently expose whatever values sit under
    the mask; filling with NaN keeps hidden points recognisable in the exports.
    """
    return np.array(np.ma.asarray(values, dtype=float).filled(np.nan))


def _dataset_label(label, index):
    """Return *label*, or ``dataset_<index>`` when it is empty or private.

    Matplotlib treats labels beginning with an underscore as hidden from legends
    and names unlabeled lines like ``_child0``; such names mean nothing to a
    reader of the exported file.
    """
    if not label or str(label).startswith("_"):
        return f"dataset_{index}"
    return label


def _axes_labels(ax):
    """Return the x label, y label and title of *ax*."""
    return {"xlabel": ax.get_xlabel(), "ylabel": ax.get_ylabel(), "title": ax.get_title()}


def _axes_xy_metadata(ax):
    """Return labels, title and axis scales of a 1D plot on *ax*."""
    return {**_axes_labels(ax), "xscale": ax.get_xscale(), "yscale": ax.get_yscale()}


def _norm_metadata(mappable):
    """Return the colormap name and normalisation (class name, vmin, vmax) of *mappable*."""
    norm = mappable.norm
    return {
        "cmap": mappable.get_cmap().name,
        "norm": type(norm).__name__,
        "norm_vmin": float(norm.vmin) if norm.vmin is not None else None,
        "norm_vmax": float(norm.vmax) if norm.vmax is not None else None,
    }


def _data_lines(ax):
    """Return lines that carry plotted data, excluding overlay markers (axvline/axhline)."""
    # axvline/axhline use a blended transform, so they fail this identity test.
    return [line for line in ax.lines if line.get_transform() == ax.transData]


def _errorbar_containers(ax):
    """Return ErrorbarContainer objects from the axes, if any."""
    from matplotlib.container import ErrorbarContainer

    return [c for c in ax.containers if isinstance(c, ErrorbarContainer)]


def _detect_plot_type(ax):
    """Classify the current axes content.

    Returns one of: ``"imshow"``, ``"pcolormesh"``, ``"errorbar"``, ``"line"``, ``"empty"``.
    The checks are made in that order, so e.g. an image wins over overlaid lines.
    """
    if ax.images:
        return "imshow"
    if any(isinstance(c, QuadMesh) for c in ax.collections):
        return "pcolormesh"
    if _errorbar_containers(ax):
        return "errorbar"
    if not _data_lines(ax):
        return "empty"
    return "line"


def _extract_errorbar_data(ax):
    """Extract X, Y, Error datasets from ErrorbarContainer objects.

    The error is half the length of each vertical (y) error bar, i.e. the mean of
    the lower and upper errors; it is all zeros when the container has no y errors.
    """
    datasets = []
    for container in _errorbar_containers(ax):
        data_line, _caplines, bar_collections = container
        x = _as_float_array(data_line.get_xdata())
        y = _as_float_array(data_line.get_ydata())
        if container.has_yerr and bar_collections:
            # Axes.errorbar appends the x-error LineCollection before the y-error
            # one, so when both are present the vertical bars are the LAST entry.
            # NOTE(review): assumes one bar per data point (errorevery=1).
            segments = bar_collections[-1].get_segments()
            # Matplotlib stores empty (0,) segments for NaN/masked data points;
            # return NaN as the error for those points.
            error = np.array(
                [
                    abs(seg[1, 1] - seg[0, 1]) / 2.0 if seg.ndim == 2 and seg.shape == (2, 2) else np.nan
                    for seg in segments
                ],
                dtype=float,
            )
        else:
            error = np.zeros_like(y)
        label = _dataset_label(container.get_label(), len(datasets))
        datasets.append({"x": x, "y": y, "error": error, "label": label})
    return {"datasets": datasets, **_axes_xy_metadata(ax)}


def _extract_imshow_data(ax):
    """Extract the 2D array, extent, origin, norm, and colormap from an imshow plot."""
    img = ax.images[0]
    return {
        "data": _as_float_array(img.get_array()),
        "extent": np.array(img.get_extent()),
        "origin": img.origin,
        **_norm_metadata(img),
        **_axes_labels(ax),
    }


def _extract_pcolormesh_data(ax):
    """Extract mesh coordinates and Z values from all QuadMesh objects on the axes.

    Handles both flat shading (1D edge arrays, z is one smaller per dim) and
    gouraud shading (2D node grids, z matches coordinate dims). For gouraud
    meshes the full 2D coordinate grids are preserved because off-specular and
    GISANS data use irregular grids where each cell has unique (x, y).

    Multiple surfaces (one per run file) are captured as a list. The shading and
    colormap/norm reported for the whole plot are those of the first mesh.
    """
    quadmeshes = [c for c in ax.collections if isinstance(c, QuadMesh)]
    surfaces = []
    shading = None
    for qm in quadmeshes:
        coords = qm.get_coordinates()
        z_data = _as_float_array(qm.get_array())
        # Gouraud meshes have one z value per node; flat ones one per cell.
        is_gouraud = z_data.shape == coords.shape[:2]
        if shading is None:
            shading = "gouraud" if is_gouraud else "flat"
        if is_gouraud:
            surfaces.append(
                {
                    "x_grid": np.array(coords[:, :, 0], dtype=float),
                    "y_grid": np.array(coords[:, :, 1], dtype=float),
                    "z_data": z_data,
                }
            )
        else:
            # Flat shading: the grid is assumed rectilinear, so the first row/column
            # of node coordinates give the x/y bin edges.
            x_edges = coords[0, :, 0]
            y_edges = coords[:, 0, 1]
            if z_data.ndim == 1:
                z_data = z_data.reshape(coords.shape[0] - 1, coords.shape[1] - 1)
            surfaces.append(
                {
                    "x_edges": x_edges,
                    "y_edges": y_edges,
                    "x_centers": centerbins(x_edges),
                    "y_centers": centerbins(y_edges),
                    "z_data": z_data,
                }
            )
    return {
        "surfaces": surfaces,
        "shading": shading,
        **_axes_labels(ax),
        **_norm_metadata(quadmeshes[0]),
    }


def _extract_line_data(ax):
    """Extract X, Y data from simple line plots (non-errorbar)."""
    datasets = [
        {
            "x": _as_float_array(line.get_xdata()),
            "y": _as_float_array(line.get_ydata()),
            "label": _dataset_label(line.get_label(), index),
        }
        for index, line in enumerate(_data_lines(ax))
    ]
    return {"datasets": datasets, **_axes_xy_metadata(ax)}


def _save_dat(fname, extracted, plot_type):
    """Save extracted data as ASCII table.

    :raises ValueError: if *plot_type* has no ASCII writer (e.g. ``"empty"``)
    """
    writers = {
        "errorbar": _save_dat_errorbar,
        "line": _save_dat_line,
        "imshow": _save_dat_imshow,
        "pcolormesh": _save_dat_pcolormesh,
    }
    try:
        writer = writers[plot_type]
    except KeyError:
        raise ValueError(f"Cannot save plot type '{plot_type}' as .dat") from None
    writer(fname, extracted)


def _write_xy_dat(fname, extracted, with_error):
    """Write 1D datasets as tab-separated blocks, separated by two blank lines.

    :param with_error: also write each dataset's ``error`` as a third column
    """
    datasets = extracted["datasets"]
    column_header = f"# {extracted['xlabel']}\t{extracted['ylabel']}"
    if with_error:
        column_header += "\tError"
    # UTF-8 so that non-ASCII axis labels can always be written.
    with open(fname, "w", encoding="utf-8") as f:
        f.write(f"# title: {extracted['title']}\n")
        f.write(f"# xlabel: {extracted['xlabel']}\n")
        f.write(f"# ylabel: {extracted['ylabel']}\n")
        f.write(f"# xscale: {extracted.get('xscale', 'linear')}\n")
        f.write(f"# yscale: {extracted.get('yscale', 'linear')}\n")
        for i, ds in enumerate(datasets):
            f.write(f"# Dataset {i}: {ds['label']}\n")
            f.write(column_header + "\n")
            columns = [ds["x"], ds["y"], ds["error"]] if with_error else [ds["x"], ds["y"]]
            np.savetxt(f, np.column_stack(columns), delimiter="\t")
            if i < len(datasets) - 1:
                f.write("\n\n")


def _save_dat_errorbar(fname, extracted):
    """Write errorbar datasets as ``x  y  error`` columns."""
    _write_xy_dat(fname, extracted, with_error=True)


def _save_dat_line(fname, extracted):
    """Write line datasets as ``x  y`` columns."""
    _write_xy_dat(fname, extracted, with_error=False)


def _save_dat_imshow(fname, extracted):
    """Write the image array as a tab-separated matrix with a metadata header."""
    extent = extracted["extent"]
    header_lines = [
        f"title: {extracted['title']}",
        f"xlabel: {extracted['xlabel']}",
        f"ylabel: {extracted['ylabel']}",
        f"extent: xmin={extent[0]}, xmax={extent[1]}, ymin={extent[2]}, ymax={extent[3]}",
        f"origin: {extracted['origin']}",
        f"cmap: {extracted.get('cmap', 'default')}",
        f"norm: {extracted.get('norm', 'Normalize')}",
        f"norm_vmin: {extracted.get('norm_vmin', '')}",
        f"norm_vmax: {extracted.get('norm_vmax', '')}",
    ]
    # np.savetxt defaults to latin1, which cannot encode e.g. Greek letters.
    np.savetxt(fname, extracted["data"], header="\n".join(header_lines), delimiter="\t", encoding="utf-8")


def _save_dat_pcolormesh(fname, extracted):
    """Save pcolormesh data in gnuplot splot xyz format.

    Each row is ``x y z``. Blank lines separate grid rows within a surface.
    Two blank lines separate successive surfaces. Multiple surfaces arise
    when several run files are overlaid on the same axes. The ``grid`` header
    reports the shape of the first surface only.
    """
    surfaces = extracted["surfaces"]
    shading = extracted.get("shading", "flat")
    ny, nx = surfaces[0]["z_data"].shape
    header = [
        f"title: {extracted['title']}",
        f"xlabel: {extracted['xlabel']}",
        f"ylabel: {extracted['ylabel']}",
        f"shading: {shading}",
        f"n_surfaces: {len(surfaces)}",
        f"grid: {ny} {nx}",
        f"cmap: {extracted.get('cmap', 'default')}",
        f"norm: {extracted.get('norm', 'Normalize')}",
        f"norm_vmin: {extracted.get('norm_vmin', '')}",
        f"norm_vmax: {extracted.get('norm_vmax', '')}",
        f"{extracted['xlabel']}\t{extracted['ylabel']}\tZ",
    ]
    with open(fname, "w", encoding="utf-8") as f:
        f.writelines(f"# {line}\n" for line in header)
        for si, surf in enumerate(surfaces):
            z_data = surf["z_data"]
            if shading == "gouraud":
                # One block per grid row (constant iy), node coordinates as stored.
                x_grid = surf["x_grid"]
                y_grid = surf["y_grid"]
                s_ny, s_nx = z_data.shape
                for iy in range(s_ny):
                    f.writelines(
                        f"{x_grid[iy, ix]:.6g}\t{y_grid[iy, ix]:.6g}\t{z_data[iy, ix]:.6g}\n" for ix in range(s_nx)
                    )
                    if iy < s_ny - 1:
                        f.write("\n")
            else:
                # One block per x bin center, iterating over the y bin centers.
                x_vals = surf["x_centers"]
                y_vals = surf["y_centers"]
                for ix, xc in enumerate(x_vals):
                    f.writelines(f"{xc:.6g}\t{yc:.6g}\t{z_data[iy, ix]:.6g}\n" for iy, yc in enumerate(y_vals))
                    if ix < len(x_vals) - 1:
                        f.write("\n")
            if si < len(surfaces) - 1:
                f.write("\n\n")


def _add_npz_norm_fields(save_dict, extracted):
    """Add the colormap/normalisation entries of a 2D plot to *save_dict* in place."""
    save_dict["cmap"] = np.array(extracted.get("cmap", "default"))
    save_dict["norm"] = np.array(extracted.get("norm", "Normalize"))
    for key in ("norm_vmin", "norm_vmax"):
        if extracted.get(key) is not None:
            save_dict[key] = np.array(extracted[key])


def _save_npz(fname, extracted, plot_type):
    """Save extracted data as a compressed numpy archive.

    Per-dataset/per-surface arrays are stored under indexed keys (``x_0``,
    ``z_data_1``, ...), with ``n_datasets``/``n_surfaces`` giving the count.
    """
    save_dict = {"plot_type": np.array(plot_type)}
    if plot_type in ("errorbar", "line"):
        datasets = extracted["datasets"]
        save_dict["n_datasets"] = np.array(len(datasets))
        for i, ds in enumerate(datasets):
            save_dict[f"x_{i}"] = ds["x"]
            save_dict[f"y_{i}"] = ds["y"]
            if "error" in ds:
                save_dict[f"error_{i}"] = ds["error"]
            save_dict[f"label_{i}"] = np.array(ds["label"])
        save_dict["xscale"] = np.array(extracted.get("xscale", "linear"))
        save_dict["yscale"] = np.array(extracted.get("yscale", "linear"))
    elif plot_type == "imshow":
        save_dict["data"] = extracted["data"]
        save_dict["extent"] = extracted["extent"]
        save_dict["origin"] = np.array(extracted["origin"])
        _add_npz_norm_fields(save_dict, extracted)
    elif plot_type == "pcolormesh":
        shading = extracted.get("shading", "flat")
        surfaces = extracted["surfaces"]
        save_dict["shading"] = np.array(shading)
        save_dict["n_surfaces"] = np.array(len(surfaces))
        _add_npz_norm_fields(save_dict, extracted)
        for i, surf in enumerate(surfaces):
            save_dict[f"z_data_{i}"] = surf["z_data"]
            if shading == "gouraud":
                save_dict[f"x_grid_{i}"] = surf["x_grid"]
                save_dict[f"y_grid_{i}"] = surf["y_grid"]
            else:
                save_dict[f"x_edges_{i}"] = surf["x_edges"]
                save_dict[f"y_edges_{i}"] = surf["y_edges"]
    for key in ("xlabel", "ylabel", "title"):
        save_dict[key] = np.array(extracted.get(key, ""))
    np.savez_compressed(fname, **save_dict)


def _save_pkl(fname, figure, extracted, plot_type):
    """Save figure and extracted data as a pickle file.

    The payload is a dict with ``figure``, ``plot_type`` and ``data`` keys.
    Loading it back unpickles arbitrary objects, so only open trusted files.
    """
    payload = {
        "figure": figure,
        "plot_type": plot_type,
        "data": extracted,
    }
    with open(fname, "wb") as f:
        pickle.dump(payload, f, protocol=4)
class MplCanvas(FigureCanvas):
    """A canvas for matplotlib figures, used in the MPLWidget."""

    def __init__(self, parent=None, width=3, height=3, dpi=100, sharex=None, sharey=None, adjust=None):
        """Create a figure with a single subplot and wrap it in a Qt canvas.

        :param parent: accepted for API compatibility; not used by this constructor
        :param width: figure width in inches
        :param height: figure height in inches
        :param dpi: figure resolution in dots per inch
        :param sharex: axes whose x-axis the subplot shares (passed to ``add_subplot``)
        :param sharey: axes whose y-axis the subplot shares (passed to ``add_subplot``)
        :param adjust: accepted for API compatibility; not used by this constructor
        """
        # ``adjust`` used to default to a shared mutable ``{}``; ``None`` avoids that.
        self.fig = Figure(figsize=(width, height), dpi=dpi, facecolor="None")
        self.ax = self.fig.add_subplot(111, sharex=sharex, sharey=sharey)
        self.fig.subplots_adjust(left=0.15, bottom=0.15, right=0.95, top=0.95)
        self.xtitle = ""
        self.ytitle = ""
        self.PlotTitle = ""
        self.grid_status = True
        self.xaxis_style = "linear"
        self.yaxis_style = "linear"
        self.format_labels()
        FigureCanvas.__init__(self, self.fig)
        FigureCanvas.setSizePolicy(self, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)

    def format_labels(self):
        """Apply ``self.PlotTitle`` as the axes title."""
        self.ax.set_title(self.PlotTitle)

    def sizeHint(self):
        """Return the preferred widget size."""
        # NOTE(review): the canvas width is compared with the widget *height* and
        # vice versa. This is kept as inherited from QuickNXS; confirm it is intended.
        w, h = self.get_width_height()
        w = max(w, self.height())
        h = max(h, self.width())
        return QtCore.QSize(w, h)

    def minimumSizeHint(self):
        """Return the minimum widget size (40x40 pixels)."""
        return QtCore.QSize(40, 40)

    def get_default_filetype(self):
        """Return the default file type for saving figures (``'png'``)."""
        return "png"
class MPLWidget(QtWidgets.QWidget):
    """A widget for displaying matplotlib plots, with a navigation toolbar."""

    cplot = None  # last color plot (pcolormesh/imshow result) reused by update()
    cbar = None

    def __init__(self, parent=None, with_toolbar=True, coordinates=False):
        """Build the canvas and, optionally, the stacked navigation toolbars.

        :param parent: Qt parent widget
        :param with_toolbar: create the generic and reflectivity toolbars
        :param coordinates: value assigned to each toolbar's ``coordinates`` attribute
        """
        QtWidgets.QWidget.__init__(self, parent)
        self.canvas = MplCanvas()
        self.canvas.ax2 = None
        self.vbox = QtWidgets.QVBoxLayout()
        self.vbox.addWidget(self.canvas)
        if with_toolbar:
            # Index 0: generic toolbar (initially shown); index 1: reflectivity
            # toolbar, which errorbar() switches to.
            self.stacked_toolbars = QtWidgets.QStackedWidget(self.canvas)
            self.stacked_toolbars.setSizePolicy(QtWidgets.QSizePolicy.Maximum, QtWidgets.QSizePolicy.Maximum)
            toolbar_generic = NavigationToolbarGeneric(self.canvas, self)
            toolbar_generic.coordinates = coordinates
            toolbar_refl = NavigationToolbarReflectivity(self.canvas, self)
            toolbar_refl.coordinates = coordinates
            self.stacked_toolbars.addWidget(toolbar_generic)
            self.stacked_toolbars.addWidget(toolbar_refl)
            self.toolbar = self.stacked_toolbars.currentWidget()
            self.vbox.addWidget(self.stacked_toolbars)
        else:
            self.toolbar = None
        self.setLayout(self.vbox)

    def sync_toolbar_view(self, clear_history=False):
        """Ensure navigation toolbar state matches the current plot.

        :param clear_history: drop the toolbar's stored views/positions first
        """
        if not self.toolbar:
            return
        if clear_history:
            self.toolbar._views.clear()
            self.toolbar._positions.clear()
        self.toolbar.push_current()
        self.canvas.draw()
        self.toolbar.update()

    def leaveEvent(self, event):
        """Make sure the cursor is reset to its default when leaving the widget.

        In some cases the zoom cursor does not reset when leaving the plot.
        """
        if self.toolbar:
            QtWidgets.QApplication.restoreOverrideCursor()
            # NOTE(review): newer matplotlib toolbars may track the cursor under a
            # different private name; confirm this reset still has an effect.
            self.toolbar._lastCursor = None
        return QtWidgets.QWidget.leaveEvent(self, event)

    def set_config(self, config):
        """Apply subplot parameters (``left``, ``right``, ...) to the figure."""
        self.canvas.fig.subplots_adjust(**config)

    def get_config(self):
        """Return the figure's left/right/bottom/top subplot parameters as a dict."""
        spp = self.canvas.fig.subplotpars
        return dict(left=spp.left, right=spp.right, bottom=spp.bottom, top=spp.top)

    def draw(self):
        """Convenience to redraw the graph (applies ``tight_layout`` first)."""
        self.canvas.fig.tight_layout()
        self.canvas.draw()

    def plot(self, *args, **opts):
        """Convenience wrapper for self.canvas.ax.plot."""
        result = self.canvas.ax.plot(*args, **opts)
        self.sync_toolbar_view()
        return result

    def semilogy(self, *args, **opts):
        """Convenience wrapper for self.canvas.ax.semilogy."""
        result = self.canvas.ax.semilogy(*args, **opts)
        self.sync_toolbar_view()
        return result

    def errorbar(self, *args, **opts):
        """Convenience wrapper for self.canvas.ax.errorbar.

        Switches to the reflectivity toolbar when toolbars exist. If none of
        ``fmt``, ``linestyle`` or ``ls`` is given, the line style stored in
        ``QSettings(".quicknxs")`` under ``<calling function name>/linestyle``
        (default ``"-"``) is used.
        """
        if self.toolbar:
            # change to toolbar with reflectivity-specific options
            self.stacked_toolbars.setCurrentIndex(1)
            self.toolbar = self.stacked_toolbars.currentWidget()
        if not any(key in opts for key in ("fmt", "linestyle", "ls")):
            # context=0: only the caller's name is needed, so skip reading source
            # lines for every frame on the stack.
            calling_function = str(inspect.stack(0)[1][3])
            # The toolbar is None when built with with_toolbar=False.
            if self.toolbar:
                self.toolbar.calling_function = calling_function
            setting = QtCore.QSettings(".quicknxs")
            ls = setting.value(calling_function + "/linestyle", "-")
            opts["ls"] = str(ls)
        result = self.canvas.ax.errorbar(*args, **opts)
        self.sync_toolbar_view()
        return result

    def pcolormesh(self, datax, datay, dataz, log=False, imin=None, imax=None, update=False, **opts):
        """Convenience wrapper for self.canvas.ax.pcolormesh.

        :param log: use ``LogNorm(imin, imax)`` for the color scale
        :param update: reuse the existing color plot via :meth:`update` instead of creating a new one
        :return: the color plot (``self.cplot``)
        """
        if self.cplot is None or not update:
            if log:
                self.cplot = self.canvas.ax.pcolormesh(datax, datay, dataz, norm=LogNorm(imin, imax), **opts)
            else:
                self.cplot = self.canvas.ax.pcolormesh(datax, datay, dataz, **opts)
        else:
            # NOTE(review): update() calls set_data(); confirm the QuadMesh
            # returned by pcolormesh supports it in the matplotlib version used.
            self.update(datax, datay, dataz)
        self.sync_toolbar_view()
        return self.cplot

    def imshow(self, data, log=False, imin=None, imax=None, update=True, **opts):
        """Convenience wrapper for self.canvas.ax.imshow.

        :param log: use ``LogNorm(imin, imax)`` for the color scale
        :param update: reuse the existing color plot via :meth:`update` instead of creating a new one
        :return: the color plot (``self.cplot``)
        """
        if self.cplot is None or not update:
            if log:
                self.cplot = self.canvas.ax.imshow(data, norm=LogNorm(imin, imax), **opts)
            else:
                self.cplot = self.canvas.ax.imshow(data, **opts)
        else:
            self.update(data, **opts)
        self.sync_toolbar_view()
        return self.cplot

    def set_title(self, new_title, fontsize=None):
        return self.canvas.ax.set_title(new_title, fontsize=fontsize)

    def set_xlabel(self, label, fontsize=None):
        return self.canvas.ax.set_xlabel(label, fontsize=fontsize)

    def set_ylabel(self, label, fontsize=None):
        return self.canvas.ax.set_ylabel(label, fontsize=fontsize)

    def set_xticks_fontsize(self, fontsize):
        for label in self.canvas.ax.get_xticklabels():
            label.set_fontsize(fontsize)

    def set_yticks_fontsize(self, fontsize):
        for label in self.canvas.ax.get_yticklabels():
            label.set_fontsize(fontsize)

    def set_xscale(self, scale):
        """Set the x-axis scale; a ValueError (e.g. unknown scale name) is logged and ignored."""
        try:
            return self.canvas.ax.set_xscale(scale)
        except ValueError as err:
            logging.getLogger(__name__).warning("Could not set x scale to %r: %s", scale, err)
            return None

    def set_yscale(self, scale):
        """Set the y-axis scale; a ValueError (e.g. unknown scale name) is logged and ignored."""
        try:
            return self.canvas.ax.set_yscale(scale)
        except ValueError as err:
            logging.getLogger(__name__).warning("Could not set y scale to %r: %s", scale, err)
            return None

    def clear_fig(self):
        """Remove everything from the figure and create a fresh single subplot."""
        self.cplot = None
        self.cbar = None
        self.canvas.fig.clear()
        self.canvas.ax = self.canvas.fig.add_subplot(111, sharex=None, sharey=None)
        # fig.clear() removed any secondary axes too; drop the stale reference.
        self.canvas.ax2 = None

    def clear(self):
        """Clear the primary (and secondary, if any) axes and forget the color plot."""
        self.cplot = None
        self.canvas.ax.clear()
        if self.canvas.ax2 is not None:
            self.canvas.ax2.clear()

    def update(self, *data, **opts):
        """Replace the data of the current color plot in place.

        Calls ``self.cplot.set_data(*data)`` and, when ``extent`` is given,
        ``self.cplot.set_extent``; other keyword arguments are ignored.
        """
        # NOTE(review): this shadows QWidget.update(); a Python-side repaint call
        # reaches this method instead and fails when cplot is None.
        self.cplot.set_data(*data)
        if "extent" in opts:
            self.cplot.set_extent(opts["extent"])

    def legend(self, *args, **opts):
        """Show a legend if any artist has a label; return it, or None otherwise."""
        handles, labels = self.canvas.ax.get_legend_handles_labels()
        if labels:
            return self.canvas.ax.legend(*args, **opts)

    def adjust(self, **adjustment):
        """Apply subplot parameters and resynchronise the toolbar."""
        result = self.canvas.fig.subplots_adjust(**adjustment)
        self.sync_toolbar_view()
        return result