"""
Container class for instrument-specific data handling.
Abstracts out how we obtain information from the data file.
"""
# pylint: disable=invalid-name, too-many-instance-attributes, line-too-long, bare-except
import logging
import math
from typing import TYPE_CHECKING, List, Optional
import mantid.simpleapi as api
import numpy as np
from mantid.dataobjects import EventWorkspace
from mr_reduction.dead_time_correction import apply_dead_time_correction
from mr_reduction.filter_events import split_error_events, split_events
from mr_reduction.settings import PolarizationLogs
from quicknxs.interfaces.data_handling.filepath import FilePath
if TYPE_CHECKING:
from quicknxs.interfaces.configuration import Configuration
from quicknxs.interfaces.data_handling.data_set import CrossSectionData
# Physical constants, SI units
h: float = 6.626e-34  # Planck constant [m^2 kg s^-1]
m: float = 1.675e-27  # neutron mass [kg]

# Cross-section label reserved for unpolarized data
UNPOLARIZED_XS_LABEL: str = "Off_Off"
[docs]
def get_cross_section_label(ws, entry_name):
    """Return the proper cross-section label.

    The entry name (e.g. "Off_On") says whether the polarizer and analyzer were on:
    a name starting with "on" means the polarizer was on, a name ending with "on"
    means the analyzer was on (case-insensitive). The ``PolarizerLabel`` and
    ``AnalyzerLabel`` logs then decide whether the ON state is "+" (label 1) or
    "-" (label 0), and the OFF state gets the opposite sign.

    Parameters
    ----------
    ws:
        Workspace whose run holds the optional ``PolarizerLabel``/``AnalyzerLabel`` logs
    entry_name:
        Cross-section entry name; converted with ``str()``

    Returns
    -------
    str
        The concatenated polarizer/analyzer signs (e.g. "+-"), or, when neither log
        yields a sign, the entry name with "_" replaced by "-".
    """
    entry_name = str(entry_name)
    lowered = entry_name.lower()
    run = ws.getRun()
    pol_label = _flipper_label(run, "PolarizerLabel", lowered.startswith("on"))
    ana_label = _flipper_label(run, "AnalyzerLabel", lowered.endswith("on"))
    if not pol_label and not ana_label:
        return entry_name.replace("_", "-")
    return f"{pol_label}{ana_label}"


def _flipper_label(run, log_name, is_on):
    """Return "+", "-" or "" for one flipper, based on the ``log_name`` label log of ``run``."""
    # The label log defines whether the ON state (1) or the OFF state (0) is "+"
    if log_name not in run:
        return ""
    label_id = run.getProperty(log_name).value
    if isinstance(label_id, np.ndarray):
        label_id = int(label_id[0])
    if label_id == 1:
        return "+" if is_on else "-"
    if label_id == 0:
        return "-" if is_on else "+"
    return ""
[docs]
def remove_low_event_workspaces(ws_list, nbr_events_cutoff):
    """
    Filter out the workspaces that hold fewer events than a threshold.

    Parameters
    ----------
    ws_list: list[EventWorkspace]
        Workspaces to check; each one must carry a ``cross_section_id`` log
    nbr_events_cutoff: int
        Minimum number of events a workspace needs in order to be kept

    Returns
    -------
    List[EventWorkspace]
        The workspaces of ``ws_list``, in their original order, that have at least
        ``nbr_events_cutoff`` events. A warning is logged for every dropped workspace.
    """
    kept = []
    for workspace in ws_list:
        cross_section = workspace.getRun()["cross_section_id"].value
        event_count = workspace.getNumberEvents()
        if event_count < nbr_events_cutoff:
            logging.warning("Too few events for %s: %s", cross_section, event_count)
            continue
        kept.append(workspace)
    return kept
[docs]
class InsufficientEventCountError(Exception):
    """Exception raised when the number of events in the workspace is too low.

    Raised while loading data when every cross-section of a file falls below
    the minimum event count.
    """
[docs]
class Instrument:
    """Instrument class. Holds the data handling that is unique to a specific instrument."""

    # Detector size in pixels, passed to RefRoi by integrate_detector
    n_x_pixel: int = 304
    n_y_pixel: int = 256
    huber_x_cut: float = 6.5
    peak_range_offset: int = 50
    # Maximum allowed difference of wavelength and slit widths in direct_beam_match
    tolerance: float = 0.05
    pixel_width: float = 0.0007
    instrument_name: str = "REF_M"
    instrument_dir: str = "/SNS/REF_M"
    file_search_template: str = "/SNS/REF_M/*/nexus/REF_M_%s"
    legacy_search_template: str = "/SNS/REF_M/*/data/REF_M_%s"
    # Option to use the slow flipper logs rather than the Analyzer/Polarizer logs
    USE_SLOW_FLIPPER_LOG: bool = False
    # Data-set attributes compared by direct_beam_match and direct_beam_distance
    _SLIT_WIDTHS = ("slit1_width", "slit2_width", "slit3_width")

    def __init__(self):
        # Names of the sample logs used to filter events by polarizer/analyzer state
        self.pol_state = "PolarizerState"
        self.pol_veto = "PolarizerVeto"
        self.ana_state = "AnalyzerState"
        self.ana_veto = "AnalyzerVeto"

    @staticmethod
    def _cross_section_id(ws):
        """Return the value of the ``cross_section_id`` log of a workspace."""
        return ws.getRun()["cross_section_id"].value

    @classmethod
    def _find_cross_section(cls, ws_list, xs_id):
        """Return the first workspace of ``ws_list`` whose ``cross_section_id`` equals ``xs_id``, or None."""
        return next((ws for ws in ws_list if cls._cross_section_id(ws) == xs_id), None)

    def _polarization_logs(self):
        """Return a new mr_reduction ``PolarizationLogs`` pointing at this instrument's log names."""
        # PolarizationLogs doesn't support initialization with parameters,
        # so the attributes are set individually. This could be improved in mr_reduction.
        pol_logs = PolarizationLogs()
        pol_logs.POL_STATE = self.pol_state
        pol_logs.ANA_STATE = self.ana_state
        pol_logs.POL_VETO = self.pol_veto
        pol_logs.ANA_VETO = self.ana_veto
        return pol_logs

    def _requires_slow_flipper_log(self, file_path: str) -> bool:
        """Check whether a post-epics file lacks the polarizer/analyzer state logs it should have.

        Returns True when the ``Polarizer`` log is positive but the ``pol_state`` log is
        missing, or when the ``Analyzer`` log is positive but the ``ana_state`` log is missing.
        """
        # Only the sample logs are inspected, so the events themselves are not loaded
        meta_ws = api.LoadEventNexus(Filename=file_path, OutputWorkspace="raw_events", MetaDataOnly=True)
        try:
            run = meta_ws.getRun()
            polarizer = run.getProperty("Polarizer").value[0]
            analyzer = run.getProperty("Analyzer").value[0]
            missing = bool(
                (polarizer > 0 and self.pol_state not in run) or (analyzer > 0 and self.ana_state not in run)
            )
        finally:
            # Always drop the temporary workspace; split_events reloads the file itself
            api.DeleteWorkspace("raw_events")
        if missing:
            logging.warning("Missing polarizer/analyzer meta-data in %s: using slow flipper logs", file_path)
        return missing

    def _apply_dead_time(self, xs_list, err_list, configuration):
        """Apply the dead-time correction to every cross-section except "unfiltered".

        Each workspace is corrected together with the error-event workspace that has the
        same ``cross_section_id``; when there is none, it is corrected without error events.
        """
        corrected = []
        for ws in xs_list:
            xs_id = self._cross_section_id(ws)
            if xs_id == "unfiltered":
                continue
            dt_args = (
                ws,
                configuration.paralyzable_deadtime,
                configuration.deadtime_value,
                configuration.deadtime_tof_step,
            )
            err_ws = self._find_cross_section(err_list, xs_id)
            if err_ws is None:
                logging.warning("Could not find error events for [%s]", xs_id)
                corrected.append(apply_dead_time_correction(*dt_args))
            else:
                corrected.append(apply_dead_time_correction(*dt_args, error_ws=err_ws))
        return corrected

    def _get_xs_list(self, file_path: str, ws_root_name: str, configuration: "Configuration") -> "List[EventWorkspace]":
        """Load the cross-sections from a data file. Handles both pre- and post-epics data.

        Parameters
        ----------
        file_path:
            Path to the data file (".nxs" for pre-epics data, ".nxs.h5" for post-epics data)
        ws_root_name:
            Root name of the workspace (used to rename the cross-sections after loading)
        configuration:
            Reduction parameters. If None, no minimum event count is enforced and no
            dead-time correction is applied.

        Returns
        -------
        List[EventWorkspace]
            List of cross-section workspaces, without the "unfiltered" one

        Raises
        ------
        RuntimeError
            If the file extension is neither ".nxs" nor ".nxs.h5"
        InsufficientEventCountError
            If the data file does not contain enough events
        """
        is_pre_epics = file_path.endswith(".nxs")
        if not is_pre_epics and not file_path.endswith(".nxs.h5"):
            raise RuntimeError(f"Unknown file type: {file_path}")

        use_slow_flipper_log = self.USE_SLOW_FLIPPER_LOG
        if not is_pre_epics and not use_slow_flipper_log:
            # Post-epics data may lack the fast polarizer/analyzer state logs
            use_slow_flipper_log = self._requires_slow_flipper_log(file_path)

        nbr_events_min = configuration.nbr_events_min if configuration is not None else 0
        too_few_msg = f"All cross-sections contain fewer than {nbr_events_min} events in: {file_path}"
        try:
            xs_list = split_events(
                file_path=file_path,
                output_workspace=ws_root_name,
                min_event_count=0,  # Filtered below so that each dropped cross-section is logged
                use_slow_flipper_log=use_slow_flipper_log,
                polarization_logs=self._polarization_logs(),
            )
        except ValueError as e:
            # split_events raises ValueError when there are insufficient events at the workspace level
            raise InsufficientEventCountError(too_few_msg) from e
        xs_list = remove_low_event_workspaces(xs_list, nbr_events_min)
        if not xs_list:
            raise InsufficientEventCountError(too_few_msg)

        # Dead-time correction only applies to post-epics data
        if configuration is None or not configuration.apply_deadtime or is_pre_epics:
            return [ws for ws in xs_list if self._cross_section_id(ws) != "unfiltered"]
        err_list = split_error_events(
            file_path=file_path,
            output_workspace=f"{ws_root_name}_err",
            use_slow_flipper_log=use_slow_flipper_log,
            polarization_logs=self._polarization_logs(),
        )
        return self._apply_dead_time(xs_list, err_list, configuration)

    def load_data(self, file_path: str, configuration: Optional["Configuration"] = None) -> "List[EventWorkspace]":
        """Load one or more data sets according to the needs of the instrument.

        This function assumes that when loading more than one data file, the files are congruent and their
        events will be added together. Cross-sections are matched across files by their ``cross_section_id``
        log; only the cross-sections present in the first file are returned.

        Parameters
        ----------
        file_path: str
            Absolute path to one or more data files. If more than one, paths should be concatenated
            with the plus symbol '+'.
        configuration: Configuration, optional
            Reduction configuration parameters

        Returns
        -------
        List[EventWorkspace]:
            A list of EventWorkspaces, one for each cross-section

        Raises
        ------
        InsufficientEventCountError
            If the data file does not contain enough events
        """
        fp_instance = FilePath(file_path)
        ws_run_numbers = fp_instance.run_numbers(string_representation="long")
        # Each file gets its own workspace root name so that loading one file does not overwrite another
        all_xs_lists = [
            self._get_xs_list(path, FilePath(path).run_numbers(string_representation="short"), configuration)
            for path in fp_instance.single_paths
        ]
        xs_list = all_xs_lists[0]
        # Add the events of each subsequent file into the matching cross-section of the first file
        for i, ws in enumerate(xs_list):
            ws_name = str(ws)
            xs_id = self._cross_section_id(ws)
            for other_list in all_xs_lists[1:]:
                other_ws = self._find_cross_section(other_list, xs_id)
                if other_ws is None:
                    logging.warning("Cross-section %s is missing from one of the files in %s: not merged", xs_id, file_path)
                    continue
                xs_list[i] = api.Plus(LHSWorkspace=ws_name, RHSWorkspace=str(other_ws), OutputWorkspace=ws_name)
        # Insert a log indicating which run numbers contributed to each cross-section
        for ws in xs_list:
            api.AddSampleLog(
                Workspace=str(ws),
                LogName="run_numbers",
                LogText=ws_run_numbers,
                LogType="String",
            )
        return xs_list

    @classmethod
    def mid_q_value(cls, ws: "EventWorkspace") -> float:
        """Get the mid q value, at the requested wl mid-point.

        This is used when sorting out data sets and doesn't need any overwrites.
        """
        wl = ws.getRun().getProperty("LambdaRequest").value[0]
        # MRGetTheta returns radians (scattering_angle_from_data converts it with 180/pi)
        theta_d = api.MRGetTheta(ws)
        return 4.0 * math.pi * math.sin(theta_d) / wl

    @classmethod
    def scattering_angle_from_data(cls, data_object: "CrossSectionData") -> float:
        """Compute the scattering angle from a CrossSectionData object, in degrees.

        The direct-pixel and direct-angle-offset overwrites are passed to MRGetTheta
        only when the configuration enables them; otherwise None is passed.
        """
        config = data_object.configuration
        dirpix = config.direct_pixel_overwrite if config.set_direct_pixel else None
        dangle0 = config.direct_angle_offset_overwrite if config.set_direct_angle_offset else None
        theta = api.MRGetTheta(
            data_object.event_workspace,
            SpecularPixel=config.peak_position,
            DAngle0Overwrite=dangle0,
            DirectPixelOverwrite=dirpix,
        )
        return theta * 180.0 / math.pi

    @classmethod
    def check_direct_beam(cls, ws) -> bool:
        """Determine whether this data is a direct beam.

        Returns True when the first value of the ``data_type`` log is 1, and False
        otherwise, including when the log cannot be read.
        """
        try:
            return bool(ws.getRun().getProperty("data_type").value[0] == 1)
        except Exception:  # best effort: a missing or unreadable log means "not a direct beam"
            return False

    def direct_beam_match(self, scattering, direct_beam, skip_slits=False):
        """Verify whether two data sets are compatible.

        The wavelength centers must agree within ``self.tolerance``; unless ``skip_slits``
        is True, each of the three slit widths must agree within it as well.
        """

        def close(a, b):
            return math.fabs(a - b) < self.tolerance

        if not close(scattering.lambda_center, direct_beam.lambda_center):
            return False
        if skip_slits:
            return True
        return all(close(getattr(scattering, name), getattr(direct_beam, name)) for name in self._SLIT_WIDTHS)

    def direct_beam_distance(self, scattering, direct_beam) -> float:
        """Return a Euclidean squared-distance between slit widths in scattered, direct beams."""
        return sum((getattr(scattering, name) - getattr(direct_beam, name)) ** 2 for name in self._SLIT_WIDTHS)

    @classmethod
    def get_info(cls, workspace, data_object):
        """
        Retrieve information that is specific to this particular instrument.

        @param workspace: Mantid workspace
        @param data_object: CrossSectionData object, whose attributes are set in place
        """
        data = workspace.getRun()
        data_object.lambda_center = data["LambdaRequest"].value[0]
        data_object.dangle = data["DANGLE"].getStatistics().mean
        # Slit widths come from the BL4A motor-gap logs when present, otherwise from the S<n>HWidth logs
        if "BL4A:Mot:S1:X:Gap" in data:
            slit_logs = ("BL4A:Mot:S1:X:Gap", "BL4A:Mot:S2:X:Gap", "BL4A:Mot:S3:X:Gap")
        else:
            slit_logs = ("S1HWidth", "S2HWidth", "S3HWidth")
        data_object.slit1_width, data_object.slit2_width, data_object.slit3_width = (
            data[name].value[0] for name in slit_logs
        )
        data_object.huber_x = data["HuberX"].getStatistics().mean
        angle_log = "SampleAngle" if "SampleAngle" in data else "SANGLE"
        data_object.sangle = data[angle_log].getStatistics().mean
        # Distances are scaled by 1e-3 (presumably mm -> m)
        data_object.dist_sam_det = data["SampleDetDis"].value[0] * 1e-3
        moderator_sample = data["ModeratorSamDis"].value[0] * 1e-3
        data_object.dist_mod_det = moderator_sample + data_object.dist_sam_det
        # NOTE(review): 2.75 is presumably the monitor-to-sample distance in m — confirm
        data_object.dist_mod_mon = moderator_sample - 2.75
        # Get these from instrument
        instrument = workspace.getInstrument()
        data_object.pixel_width = float(instrument.getNumberParameter("pixel-width")[0]) / 1000.0
        data_object.n_det_size_x = int(instrument.getNumberParameter("number-of-x-pixels")[0])  # 304
        data_object.n_det_size_y = int(instrument.getNumberParameter("number-of-y-pixels")[0])  # 256
        data_object.det_size_x = data_object.n_det_size_x * data_object.pixel_width  # horizontal size of detector [m]
        data_object.det_size_y = data_object.n_det_size_y * data_object.pixel_width  # vertical size of detector [m]
        # The following active area used to be taken from instrument.DETECTOR_REGION
        data_object.active_area_x = (8, 295)
        data_object.active_area_y = (8, 246)
        # Convert to standard names
        data_object.direct_pixel = data["DIRPIX"].getStatistics().mean
        data_object.angle_offset = data["DANGLE0"].getStatistics().mean
        # Get proper cross-section label
        data_object.cross_section_label = get_cross_section_label(workspace, data_object.entry_name)
        data_object.is_direct_beam = cls.check_direct_beam(workspace)

    def integrate_detector(self, ws: "EventWorkspace", specular: bool = True):
        """Integrate a workspace along either the main direction or the low-resolution direction.

        ws:
            Mantid workspace to integrate
        specular:
            If True, the workspace is integrated over the low-resolution direction.
            If False, the workspace is integrated over the main direction.

        Returns the integrated and transposed workspace.
        """
        # Variable names are kept as-is: Mantid's simpleapi names output workspaces after them
        ws_summed = api.RefRoi(
            InputWorkspace=ws,
            IntegrateY=specular,
            NXPixel=self.n_x_pixel,
            NYPixel=self.n_y_pixel,
            ConvertToQ=False,
            OutputWorkspace="ws_summed",
        )
        integrated = api.Integration(ws_summed)
        integrated = api.Transpose(integrated)
        return integrated