#!/usr/bin/env python
"""
Plotting widget taken from QuickNXS.
#TODO: refactor this or replace it with a standard solution
"""
import inspect
import logging
import os
import pickle
import tempfile
from pathlib import Path
import matplotlib.colors
import numpy as np
from matplotlib.backends.backend_qt import NavigationToolbar2QT
from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.collections import QuadMesh
from matplotlib.colors import LogNorm, Normalize
from matplotlib.figure import Figure
from qtpy import QtCore, QtGui, QtPrintSupport, QtWidgets
from quicknxs.config import plotting
try:
import matplotlib.backends.qt_editor.figureoptions as figureoptions
except ImportError:
figureoptions = None
# Custom colour ramp running blue -> green -> yellow -> red -> light purple -> black,
# interpolated over 256 levels.
cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
    "default", ["#0000ff", "#00ff00", "#ffff00", "#ff0000", "#bd7efc", "#000000"], N=256
)
# Register the ramp with matplotlib under the name "default" (done once, at import time).
matplotlib.colormaps.register(cmap, name="default")
def _set_default_rc():
    """Push the ``font`` and ``savefig`` option groups from ``plotting`` into matplotlib's rc settings."""
    for group in ("font", "savefig"):
        # Each group on ``plotting`` is unpacked as keyword arguments for matplotlib.rc.
        matplotlib.rc(group, **getattr(plotting, group))


# Apply the defaults as soon as the module is imported.
_set_default_rc()
# Directory that contains the icon image files
ICON_DIR = Path(__file__).parent.parent.parent / "icons"
[docs]
def getIcon(filename: str) -> "QtGui.QIcon":
    """Build a ``QIcon`` from an image file located in ``ICON_DIR``.

    :param filename: name of the image file, relative to ``ICON_DIR``
    :returns: an icon carrying that image as its Normal/Off pixmap
    """
    pixmap = QtGui.QPixmap(str(ICON_DIR / filename))
    result = QtGui.QIcon()
    result.addPixmap(pixmap, QtGui.QIcon.Normal, QtGui.QIcon.Off)
    return result
[docs]
def centerbins(xvals):
    """For a given numpy array of bin edges, return the bin centers.

    Each center is the mean of two consecutive edges, so ``n`` edges give
    ``n - 1`` centers. The input is flattened before pairing. Fewer than two
    edges (including an empty input) yields an empty array instead of raising.

    :param xvals: array-like of bin edges
    :returns: 1D float array of bin centers
    """
    edges = np.asanyarray(xvals).ravel()
    # Pair every edge with its right-hand neighbour; no wrap-around element is
    # produced, so nothing has to be deleted afterwards.
    return (edges[:-1] + edges[1:]) / 2
def _data_lines(ax):
"""Return lines that carry plotted data, excluding overlay markers (axvline/axhline)."""
return [line for line in ax.lines if line.get_transform() == ax.transData]
def _errorbar_containers(ax):
    """Collect every ``ErrorbarContainer`` found in ``ax.containers`` (possibly none)."""
    from matplotlib.container import ErrorbarContainer

    found = []
    for container in ax.containers:
        if isinstance(container, ErrorbarContainer):
            found.append(container)
    return found
def _detect_plot_type(ax):
    """Classify what is currently drawn on *ax*.

    Checks run in a fixed priority order and the first match wins: images,
    then ``QuadMesh`` collections, then errorbar containers, then data lines.

    Returns one of: ``"imshow"``, ``"pcolormesh"``, ``"errorbar"``, ``"line"``, ``"empty"``.
    """
    if ax.images:
        return "imshow"
    for collection in ax.collections:
        if isinstance(collection, QuadMesh):
            return "pcolormesh"
    if _errorbar_containers(ax):
        return "errorbar"
    return "line" if _data_lines(ax) else "empty"
def _extract_errorbar_data(ax):
    """Gather x, y and error arrays from every ErrorbarContainer on *ax*.

    The error of a point is half the vertical span of its segment in the first
    bar collection of the container. When a container has no bar collections
    the error is all zeros.

    :returns: dict with a ``"datasets"`` list (``x``, ``y``, ``error``, ``label``
        per container) plus the axes labels, title and axis scales
    """

    def half_span(segment):
        # Matplotlib stores empty (0,) segments for NaN/masked data points;
        # those points get NaN as their error.
        if segment.ndim == 2 and segment.shape == (2, 2):
            return abs(segment[1, 1] - segment[0, 1]) / 2.0
        return np.nan

    collected = []
    for index, container in enumerate(_errorbar_containers(ax)):
        points = container[0]
        xs = np.array(points.get_xdata(), dtype=float)
        ys = np.array(points.get_ydata(), dtype=float)

        bars = container[2]
        if bars:
            errs = np.array([half_span(seg) for seg in bars[0].get_segments()])
        else:
            errs = np.zeros_like(ys)

        # Fall back to a positional name when the container has no label.
        name = container.get_label() or f"dataset_{index}"
        collected.append({"x": xs, "y": ys, "error": errs, "label": name})

    return {
        "datasets": collected,
        "xlabel": ax.get_xlabel(),
        "ylabel": ax.get_ylabel(),
        "title": ax.get_title(),
        "xscale": ax.get_xscale(),
        "yscale": ax.get_yscale(),
    }
def _extract_imshow_data(ax):
"""Extract the 2D array, extent, origin, norm, and colormap from an imshow plot."""
img = ax.images[0]
data = np.array(img.get_array(), dtype=float)
extent = np.array(img.get_extent())
norm = img.norm
return {
"data": data,
"extent": extent,
"origin": img.origin,
"cmap": img.get_cmap().name,
"norm": type(norm).__name__,
"norm_vmin": float(norm.vmin) if norm.vmin is not None else None,
"norm_vmax": float(norm.vmax) if norm.vmax is not None else None,
"xlabel": ax.get_xlabel(),
"ylabel": ax.get_ylabel(),
"title": ax.get_title(),
}
def _extract_pcolormesh_data(ax):
    """Collect coordinates and Z values from every QuadMesh on *ax*.

    Two layouts are recognised per mesh:

    * gouraud shading: the Z array has the same shape as the node grid, and the
      full 2D ``x_grid``/``y_grid`` are kept, because off-specular and GISANS
      data use irregular grids where each cell has unique (x, y);
    * flat shading: Z is one smaller than the grid in each dimension; the 1D
      edges are read from the first row/column of the grid and the bin centers
      are derived from them.

    One surface is produced per mesh (one per run file). The reported shading,
    colormap and norm are those of the first mesh.
    """
    meshes = [artist for artist in ax.collections if isinstance(artist, QuadMesh)]
    reference = meshes[0]
    scaling = reference.norm

    surfaces = []
    for mesh in meshes:
        grid = mesh.get_coordinates()
        values = np.array(mesh.get_array(), dtype=float)
        rows, cols = grid.shape[0], grid.shape[1]

        if values.shape == (rows, cols):
            # Z defined on the nodes: gouraud layout.
            surface = {
                "x_grid": np.array(grid[:, :, 0], dtype=float),
                "y_grid": np.array(grid[:, :, 1], dtype=float),
                "z_data": values,
            }
        else:
            # Z defined on the cells: flat layout.
            x_edges = grid[0, :, 0]
            y_edges = grid[:, 0, 1]
            if values.ndim == 1:
                values = values.reshape(rows - 1, cols - 1)
            surface = {
                "x_edges": x_edges,
                "y_edges": y_edges,
                "x_centers": centerbins(x_edges),
                "y_centers": centerbins(y_edges),
                "z_data": values,
            }
        surfaces.append(surface)

    # Only gouraud surfaces carry "x_grid", so the first surface tells the shading type.
    shading = "gouraud" if "x_grid" in surfaces[0] else "flat"

    lower, upper = scaling.vmin, scaling.vmax
    return {
        "surfaces": surfaces,
        "shading": shading,
        "xlabel": ax.get_xlabel(),
        "ylabel": ax.get_ylabel(),
        "title": ax.get_title(),
        "cmap": reference.get_cmap().name,
        "norm": type(scaling).__name__,
        "norm_vmin": None if lower is None else float(lower),
        "norm_vmax": None if upper is None else float(upper),
    }
def _extract_line_data(ax):
    """Gather x and y arrays from the plain (non-errorbar) data lines of *ax*.

    :returns: dict with a ``"datasets"`` list (``x``, ``y``, ``label`` per line)
        plus the axes labels, title and axis scales
    """
    collected = [
        {
            "x": np.array(artist.get_xdata(), dtype=float),
            "y": np.array(artist.get_ydata(), dtype=float),
            # Fall back to a positional name when the line has no label.
            "label": artist.get_label() or f"dataset_{index}",
        }
        for index, artist in enumerate(_data_lines(ax))
    ]
    return {
        "datasets": collected,
        "xlabel": ax.get_xlabel(),
        "ylabel": ax.get_ylabel(),
        "title": ax.get_title(),
        "xscale": ax.get_xscale(),
        "yscale": ax.get_yscale(),
    }
def _save_dat(fname, extracted, plot_type):
"""Save extracted data as ASCII table."""
if plot_type == "errorbar":
_save_dat_errorbar(fname, extracted)
elif plot_type == "line":
_save_dat_line(fname, extracted)
elif plot_type == "imshow":
_save_dat_imshow(fname, extracted)
elif plot_type == "pcolormesh":
_save_dat_pcolormesh(fname, extracted)
else:
raise ValueError(f"Cannot save plot type '{plot_type}' as .dat")
def _save_dat_errorbar(fname, extracted):
with open(fname, "w") as f:
f.write(f"# title: {extracted['title']}\n")
f.write(f"# xlabel: {extracted['xlabel']}\n")
f.write(f"# ylabel: {extracted['ylabel']}\n")
f.write(f"# xscale: {extracted.get('xscale', 'linear')}\n")
f.write(f"# yscale: {extracted.get('yscale', 'linear')}\n")
for i, ds in enumerate(extracted["datasets"]):
f.write(f"# Dataset {i}: {ds['label']}\n")
f.write(f"# {extracted['xlabel']}\t{extracted['ylabel']}\tError\n")
block = np.column_stack([ds["x"], ds["y"], ds["error"]])
np.savetxt(f, block, delimiter="\t")
if i < len(extracted["datasets"]) - 1:
f.write("\n\n")
def _save_dat_line(fname, extracted):
with open(fname, "w") as f:
f.write(f"# title: {extracted['title']}\n")
f.write(f"# xlabel: {extracted['xlabel']}\n")
f.write(f"# ylabel: {extracted['ylabel']}\n")
f.write(f"# xscale: {extracted.get('xscale', 'linear')}\n")
f.write(f"# yscale: {extracted.get('yscale', 'linear')}\n")
for i, ds in enumerate(extracted["datasets"]):
f.write(f"# Dataset {i}: {ds['label']}\n")
f.write(f"# {extracted['xlabel']}\t{extracted['ylabel']}\n")
block = np.column_stack([ds["x"], ds["y"]])
np.savetxt(f, block, delimiter="\t")
if i < len(extracted["datasets"]) - 1:
f.write("\n\n")
def _save_dat_imshow(fname, extracted):
header_lines = [
f"title: {extracted['title']}",
f"xlabel: {extracted['xlabel']}",
f"ylabel: {extracted['ylabel']}",
f"extent: xmin={extracted['extent'][0]}, xmax={extracted['extent'][1]}, "
f"ymin={extracted['extent'][2]}, ymax={extracted['extent'][3]}",
f"origin: {extracted['origin']}",
f"cmap: {extracted.get('cmap', 'default')}",
f"norm: {extracted.get('norm', 'Normalize')}",
f"norm_vmin: {extracted.get('norm_vmin', '')}",
f"norm_vmax: {extracted.get('norm_vmax', '')}",
]
np.savetxt(fname, extracted["data"], header="\n".join(header_lines), delimiter="\t")
def _save_dat_pcolormesh(fname, extracted):
"""Save pcolormesh data in gnuplot splot xyz format.
Each row is ``x y z``. Blank lines separate grid rows within a surface.
Two blank lines separate successive surfaces. Multiple surfaces arise
when several run files are overlaid on the same axes.
"""
surfaces = extracted["surfaces"]
shading = extracted.get("shading", "flat")
ny, nx = surfaces[0]["z_data"].shape
with open(fname, "w") as f:
f.write(f"# title: {extracted['title']}\n")
f.write(f"# xlabel: {extracted['xlabel']}\n")
f.write(f"# ylabel: {extracted['ylabel']}\n")
f.write(f"# shading: {shading}\n")
f.write(f"# n_surfaces: {len(surfaces)}\n")
f.write(f"# grid: {ny} {nx}\n")
f.write(f"# cmap: {extracted.get('cmap', 'default')}\n")
f.write(f"# norm: {extracted.get('norm', 'Normalize')}\n")
f.write(f"# norm_vmin: {extracted.get('norm_vmin', '')}\n")
f.write(f"# norm_vmax: {extracted.get('norm_vmax', '')}\n")
f.write(f"# {extracted['xlabel']}\t{extracted['ylabel']}\tZ\n")
for si, surf in enumerate(surfaces):
z_data = surf["z_data"]
s_ny, s_nx = z_data.shape
if shading == "gouraud":
x_grid = surf["x_grid"]
y_grid = surf["y_grid"]
for iy in range(s_ny):
f.writelines(
f"{x_grid[iy, ix]:.6g}\t{y_grid[iy, ix]:.6g}\t{z_data[iy, ix]:.6g}\n" for ix in range(s_nx)
)
if iy < s_ny - 1:
f.write("\n")
else:
x_vals = surf["x_centers"]
y_vals = surf["y_centers"]
for ix, xc in enumerate(x_vals):
for iy, yc in enumerate(y_vals):
f.write(f"{xc:.6g}\t{yc:.6g}\t{z_data[iy, ix]:.6g}\n")
if ix < len(x_vals) - 1:
f.write("\n")
if si < len(surfaces) - 1:
f.write("\n\n")
def _save_npz(fname, extracted, plot_type):
"""Save extracted data as a compressed numpy archive."""
save_dict = {"plot_type": np.array(plot_type)}
if plot_type in ("errorbar", "line"):
save_dict["n_datasets"] = np.array(len(extracted["datasets"]))
for i, ds in enumerate(extracted["datasets"]):
save_dict[f"x_{i}"] = ds["x"]
save_dict[f"y_{i}"] = ds["y"]
if "error" in ds:
save_dict[f"error_{i}"] = ds["error"]
save_dict[f"label_{i}"] = np.array(ds["label"])
save_dict["xscale"] = np.array(extracted.get("xscale", "linear"))
save_dict["yscale"] = np.array(extracted.get("yscale", "linear"))
elif plot_type == "imshow":
save_dict["data"] = extracted["data"]
save_dict["extent"] = extracted["extent"]
save_dict["origin"] = np.array(extracted["origin"])
save_dict["cmap"] = np.array(extracted.get("cmap", "default"))
save_dict["norm"] = np.array(extracted.get("norm", "Normalize"))
if extracted.get("norm_vmin") is not None:
save_dict["norm_vmin"] = np.array(extracted["norm_vmin"])
if extracted.get("norm_vmax") is not None:
save_dict["norm_vmax"] = np.array(extracted["norm_vmax"])
elif plot_type == "pcolormesh":
shading = extracted.get("shading", "flat")
surfaces = extracted["surfaces"]
save_dict["shading"] = np.array(shading)
save_dict["n_surfaces"] = np.array(len(surfaces))
save_dict["cmap"] = np.array(extracted.get("cmap", "default"))
save_dict["norm"] = np.array(extracted.get("norm", "Normalize"))
if extracted.get("norm_vmin") is not None:
save_dict["norm_vmin"] = np.array(extracted["norm_vmin"])
if extracted.get("norm_vmax") is not None:
save_dict["norm_vmax"] = np.array(extracted["norm_vmax"])
for i, surf in enumerate(surfaces):
save_dict[f"z_data_{i}"] = surf["z_data"]
if shading == "gouraud":
save_dict[f"x_grid_{i}"] = surf["x_grid"]
save_dict[f"y_grid_{i}"] = surf["y_grid"]
else:
save_dict[f"x_edges_{i}"] = surf["x_edges"]
save_dict[f"y_edges_{i}"] = surf["y_edges"]
save_dict["xlabel"] = np.array(extracted.get("xlabel", ""))
save_dict["ylabel"] = np.array(extracted.get("ylabel", ""))
save_dict["title"] = np.array(extracted.get("title", ""))
np.savez_compressed(fname, **save_dict)
def _save_pkl(fname, figure, extracted, plot_type):
"""Save figure and extracted data as a pickle file."""
payload = {
"figure": figure,
"plot_type": plot_type,
"data": extracted,
}
with open(fname, "wb") as f:
pickle.dump(payload, f, protocol=4)
[docs]
class MplCanvas(FigureCanvas):
"""A canvas for matplotlib figures, used in the MPLWidget."""
def __init__(self, parent=None, width=3, height=3, dpi=100, sharex=None, sharey=None, adjust={}):
    """Create the figure with a single subplot and initialise the Qt canvas on it.

    ``width``, ``height`` and ``dpi`` are passed to ``Figure``; ``sharex`` and
    ``sharey`` are forwarded to ``add_subplot``. ``parent`` and ``adjust`` are
    accepted but not used in this method.
    """
    # NOTE(review): ``adjust={}`` is a mutable default argument; it is harmless
    # while the parameter stays unused, but ``None`` would be the safer default.
    self.fig = Figure(figsize=(width, height), dpi=dpi, facecolor="None")
    self.ax = self.fig.add_subplot(111, sharex=sharex, sharey=sharey)
    # Fixed subplot margins, given as fractions of the figure size.
    self.fig.subplots_adjust(left=0.15, bottom=0.15, right=0.95, top=0.95)
    # Initial label, title, grid and axis-scale state.
    self.xtitle = ""
    self.ytitle = ""
    self.PlotTitle = ""
    self.grid_status = True
    self.xaxis_style = "linear"
    self.yaxis_style = "linear"
    # format_labels() is defined outside the visible part of the class;
    # presumably it applies the attributes set above - TODO confirm.
    self.format_labels()
    # The canvas base class is initialised only after the figure has been set up.
    FigureCanvas.__init__(self, self.fig)
    # Let the widget grow in both directions when the layout has spare room.
    FigureCanvas.setSizePolicy(self, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding)
    FigureCanvas.updateGeometry(self)
[docs]
def sizeHint(self):
    """Return the preferred widget size as a ``QSize``.

    Starts from ``get_width_height()`` and enlarges each component so that the
    width is at least ``self.height()`` and the height at least ``self.width()``.
    """
    # NOTE(review): width is compared with height() and height with width();
    # confirm that this crossing is intended.
    canvas_w, canvas_h = self.get_width_height()
    hint_w = max(canvas_w, self.height())
    hint_h = max(canvas_h, self.width())
    return QtCore.QSize(hint_w, hint_h)
[docs]
def minimumSizeHint(self):
    """Return the smallest recommended widget size: a 40 x 40 ``QSize``."""
    side = 40
    return QtCore.QSize(side, side)
[docs]
def get_default_filetype(self):
    """Tell callers which file type to use by default when saving a figure.

    :returns: always the string ``'png'``
    """
    default_extension = "png"
    return default_extension