"""
extra_tomo.py
=============
Specialist LamNI subclasses for specific experimental configurations.
Import explicitly when needed, e.g.:
from csaxs_bec...extra_tomo import MagLamNI
from csaxs_bec...extra_tomo import DataDrivenLamNI
"""
import builtins
import datetime
import os
import time
import h5py
import numpy as np
from bec_lib import bec_logger
from bec_lib.alarm_handler import AlarmBase
from .lamni import LamNI
logger = bec_logger.logger

# Session objects are looked up in ``builtins``. The four names below are bound
# in this module only when ``builtins`` carries a non-None ``bec`` entry;
# otherwise they stay undefined here.
if builtins.__dict__.get("bec") is not None:
    bec, dev, umv, scans = (
        builtins.__dict__.get(_name) for _name in ("bec", "dev", "umv", "scans")
    )
[docs]
class MagLamNI(LamNI):
    """LamNI subclass for magnetic experiments (XMCD).

    Adds a slow rotation helper and allows injection of a custom
    per-angle callback via the ``lamni_at_each_angle`` builtin.
    """

    def sub_tomo_scan(self, subtomo_number, start_angle=None):
        """Run one sub-tomogram by delegating to the parent implementation.

        Args:
            subtomo_number: Forwarded unchanged to ``LamNI.sub_tomo_scan``.
            start_angle: Forwarded unchanged to ``LamNI.sub_tomo_scan``.
        """
        super().sub_tomo_scan(subtomo_number, start_angle)
        # NOTE: rotating slowly back to 0 after the sub-tomogram is currently disabled.
        # self.rotate_slowly(0)

    def rotate_slowly(self, angle, step_size=20):
        """Rotate to target angle in small steps to avoid mechanical stress.

        The cached ``dev.lsamrot`` position is used as the starting point. After
        every intermediate move the scan center is re-positioned for the
        intermediate angle using the alignment FOV offsets.

        Args:
            angle: Target position for ``dev.lsamrot``.
            step_size: Upper bound for a single move, in the same units as
                ``angle``. Must be strictly positive.

        Raises:
            ValueError: If ``step_size`` is not strictly positive.
        """
        # A zero step makes the step count overflow; a negative one yields a
        # count <= 1, so linspace would return only the start angle (or raise)
        # and the target would silently never be reached. ``not > 0`` also
        # rejects NaN.
        if not step_size > 0:
            raise ValueError(f"step_size must be a positive number, got {step_size!r}.")

        current_angle = dev.lsamrot.read(cached=True)["value"]
        # +1 because linspace counts both end points.
        steps = int(np.ceil(np.abs(current_angle - angle) / step_size)) + 1
        for target_angle in np.linspace(current_angle, angle, steps, endpoint=True):
            umv(dev.lsamrot, target_angle)
            scans.lamni_move_to_scan_center(
                self.align.tomo_fovx_offset, self.align.tomo_fovy_offset, target_angle
            )

    def _at_each_angle(self, angle: float) -> None:
        """Acquire one projection, or hand over to a user-supplied hook.

        If a ``lamni_at_each_angle`` entry exists in ``builtins`` it is called
        as ``lamni_at_each_angle(self, angle)`` and replaces the default
        behaviour entirely. Otherwise the projection is scanned and the
        reconstruction is triggered.

        Args:
            angle: Projection angle passed to the hook or to
                ``tomo_scan_projection``.
        """
        # Explicit lookup instead of relying on the bare name falling through
        # to the builtins namespace.
        if "lamni_at_each_angle" in builtins.__dict__:
            custom_hook = builtins.__dict__["lamni_at_each_angle"]
            custom_hook(self, angle)
            return

        self.tomo_scan_projection(angle)
        self.tomo_reconstruct()
[docs]
class DataDrivenLamNI(LamNI):
    """LamNI subclass that reads per-projection scan parameters from an HDF5 file.

    Instead of a fixed FOV and step size for the whole tomogram, each
    projection can have individual values for step size, loptz position
    and lateral shifts, as specified in a datadriven_params.h5 file.
    """

    def __init__(self, client):
        """Initialise the parent class and start with no projection data loaded.

        Args:
            client: Passed through unchanged to ``LamNI.__init__``.
        """
        super().__init__(client)
        # Maps parameter name -> 1-D array with one entry per projection.
        # Filled by _update_tomo_data_from_file(); the insertion order matters
        # because sub_tomo_data_driven() unpacks the values positionally.
        self.tomo_data = {}

    def tomo_scan(
        self,
        subtomo_start=1,
        start_index=None,
        fname="~/Data10/data_driven_config/datadriven_params.h5",
    ):
        """Start a data-driven tomo scan.

        Loads the projection table from ``fname``, records the file name in the
        logbook and then runs all projections.

        Args:
            subtomo_start (int): Not used to select projections (use
                ``start_index`` to resume). A new sample-database entry and PDF
                report are only created when this is 1 and ``start_index`` is
                None.
            start_index (int, optional): Skip projections before this index. Defaults to None.
            fname (str): Path to the HDF5 parameter file. Defaults to the standard location.

        Raises:
            FileNotFoundError: If ``fname`` (after ``~`` expansion) does not exist.
            ValueError: If the datasets in the file have inconsistent lengths.
        """
        bec = builtins.__dict__.get("bec")
        scans = builtins.__dict__.get("scans")

        fname = os.path.expanduser(fname)
        if not os.path.exists(fname):
            raise FileNotFoundError(f"Could not find datadriven params file in {fname}.")

        content = f"Loading tomo parameters from {fname}."
        logger.warning(content)
        msg = bec.logbook.LogbookMessage()
        msg.add_text(content).add_tag(["Data_driven_file", "BEC"])
        self.client.logbook.send_logbook_message(msg)

        self._update_tomo_data_from_file(fname)
        # Work on a copy so that consuming special angles during the scan does
        # not modify the configured list.
        self._current_special_angles = self.special_angles.copy()

        if subtomo_start == 1 and start_index is None:
            self.tomo_id = self.add_sample_database(
                self.sample_name,
                str(datetime.date.today()),
                bec.active_account.decode(),
                bec.queue.next_scan_number,
                "lamni",
                "test additional info",
                "BEC",
            )
            self.write_pdf_report()

        with scans.dataset_id_on_hold:
            self.sub_tomo_data_driven(start_index)

    def sub_tomo_scan(self, subtomo_number=None, start_angle=None):
        """Not supported for data-driven scans.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Cannot run sub_tomo_scan with DataDrivenLamNI. "
            "Use lamni.tomo_scan(start_index=<N>) to resume instead."
        )

    def _at_each_angle(
        self, angle=None, stepsize=None, loptz_pos=None, manual_shift_x=0, manual_shift_y=0
    ):
        """Apply the per-projection settings, then run the parent per-angle routine.

        Args:
            angle: Projection angle forwarded to ``LamNI._at_each_angle``.
            stepsize: Stored in ``self.tomo_shellstep`` for this projection.
            loptz_pos: Target position for ``dev.loptz``; no move when None.
            manual_shift_x: Stored in ``self.manual_shift_x``.
            manual_shift_y: Stored in ``self.manual_shift_y``.
        """
        self.manual_shift_x = manual_shift_x
        self.manual_shift_y = manual_shift_y
        self.tomo_shellstep = stepsize
        if loptz_pos is not None:
            # The rtx feedback is switched off before loptz is moved.
            dev.rtx.controller.feedback_disable()
            umv(dev.loptz, loptz_pos)
        super()._at_each_angle(angle=angle)

    def _num_repeats_for_angle(self, angle):
        """Return how often the projection at ``angle`` is acquired.

        If ``angle`` is within 0.5 of the next pending special angle, that
        special angle is consumed and ``self.special_angle_repeats`` is
        returned; otherwise the result is 1.
        """
        if not self.special_angles:
            self._current_special_angles = []
        pending = self._current_special_angles
        if pending and np.isclose(angle, pending[0], atol=0.5):
            pending.pop(0)
            return self.special_angle_repeats
        return 1

    def sub_tomo_data_driven(self, start_index=None):
        """Iterate over all projections defined in the loaded HDF5 parameter file.

        For every projection the row values are written to ``bec.metadata``.
        Projections whose angle lies in ``[0, 360.05)`` are acquired and
        repeated until the beam check passes and no timeout alarm occurred; the
        scan numbers of the last attempt are then recorded via
        ``_write_tomo_scan_number``. Other angles are skipped.

        Args:
            start_index (int, optional): Projections with a smaller index are
                skipped. None or 0 starts from the first projection.

        Raises:
            AlarmBase: Any alarm other than a ``"TimeoutError"`` alarm is re-raised.
        """
        for scan_index, scan_data in enumerate(zip(*self.tomo_data.values())):
            if start_index and scan_index < start_index:
                continue

            # Positional unpacking relies on the key order established in
            # _update_tomo_data_from_file().
            (
                angle,
                stepsize,
                loptz_pos,
                _propagation_distance,  # only exported through bec.metadata below
                manual_shift_x,
                manual_shift_y,
                subtomo_number,
            ) = scan_data

            bec.metadata.update(
                {key: float(val) for key, val in zip(self.tomo_data.keys(), scan_data)}
            )

            if not 0 <= angle < 360.05:
                continue

            print(f"Starting LamNI scan for angle {angle}")

            # Resolve the repeat count once per projection. Doing this inside
            # the retry loop left the count unbound or stale on some paths and
            # dropped the repeats on a retry, because the special angle had
            # already been consumed by the failed attempt.
            num_repeats = self._num_repeats_for_angle(angle)

            successful = False
            while not successful:
                self._start_beam_check()
                error_caught = False
                try:
                    # Reset on every attempt so that only the scans of the
                    # final attempt are recorded below.
                    start_scan_number = bec.queue.next_scan_number
                    for _ in range(num_repeats):
                        self._at_each_angle(
                            float(angle),
                            stepsize=float(stepsize),
                            loptz_pos=float(loptz_pos),
                            manual_shift_x=float(manual_shift_x),
                            manual_shift_y=float(manual_shift_y),
                        )
                except AlarmBase as exc:
                    if exc.alarm_type != "TimeoutError":
                        raise
                    # Timeout: reset the queue and retry this projection.
                    bec.queue.request_queue_reset()
                    time.sleep(2)
                    error_caught = True

                if self._was_beam_okay() and not error_caught:
                    successful = True
                else:
                    self._wait_for_beamline_checks()

            end_scan_number = bec.queue.next_scan_number
            for scan_nr in range(start_scan_number, end_scan_number):
                self._write_tomo_scan_number(scan_nr, angle, subtomo_number)

    def _update_tomo_data_from_file(self, fname: str) -> None:
        """Load projection parameters from the HDF5 file into self.tomo_data.

        The file is read into a temporary dict and validated first, so
        ``self.tomo_data`` is left untouched when the file is rejected.

        Args:
            fname: Path of the HDF5 parameter file.

        Raises:
            ValueError: If the datasets do not all have the same flattened shape.
        """
        # tomo_data key -> dataset name in the file. The order is significant:
        # sub_tomo_data_driven() unpacks the values positionally.
        dataset_names = {
            "theta": "theta",
            "stepsize": "stepsize",
            "loptz": "loptz",
            "propagation_distance": "relative_propagation_distance",
            "manual_shift_x": "manual_shift_x",
            "manual_shift_y": "manual_shift_y",
            "subtomo_id": "subtomo_id",
        }

        with h5py.File(fname, "r") as file:
            loaded = {
                key: np.array([*file[dataset]]).flatten()
                for key, dataset in dataset_names.items()
            }

        shapes = [data.shape for data in loaded.values()]
        if len(set(shapes)) > 1:
            raise ValueError(f"Tomo data file has entries of inconsistent lengths: {shapes}.")

        # Commit only after validation succeeded.
        self.tomo_data.update(loaded)