Source code for csaxs_bec.devices.npoint.npoint

import functools
import threading
import time

import numpy as np
from bec_lib.logger import bec_logger
from ophyd import Component as Cpt
from ophyd import Device, PositionerBase, Signal, SignalRO
from ophyd.status import wait as status_wait
from ophyd.utils import LimitError, ReadOnlyError
from ophyd_devices.utils.controller import Controller, threadlocked
from ophyd_devices.utils.socket import SocketIO, SocketSignal, raise_if_disconnected
from prettytable import PrettyTable

# Module-level logger, taken from bec_lib's `bec_logger`.
logger = bec_logger.logger


def channel_checked(fcn):
    """Decorator to catch attempted access to channels that are not available.

    The wrapped method must take the channel index as its first argument after
    ``self``. The channel may be passed positionally or as the keyword
    ``channel``; it is validated with ``self._check_channel`` before the
    wrapped method runs.
    """

    @functools.wraps(fcn)
    def wrapper(self, *args, **kwargs):
        # Support both positional and keyword channels. Reading only args[0]
        # raised an IndexError for calls such as ``get_current_pos(channel=0)``.
        if args:
            channel = args[0]
        elif "channel" in kwargs:
            channel = kwargs["channel"]
        else:
            # No channel supplied: let the wrapped call raise its usual TypeError.
            return fcn(self, *args, **kwargs)
        # pylint: disable=protected-access
        self._check_channel(channel)
        return fcn(self, *args, **kwargs)

    return wrapper
class NpointError(Exception):
    """Common base exception for nPoint-related errors."""
class NPointController(Controller):
    """
    Controller for nPoint piezo stages.

    This class inherits from the Controller class and provides a singleton
    interface to the nPoint controller. Commands are built as lists of hex
    byte strings (without a leading ``0x``) that are joined into raw bytes
    before being written to the socket.
    """

    _axes_per_controller = 3
    # Command prefix for reading a single memory location.
    _read_single_loc_bit = "A0"
    # Command prefix for writing a single memory location.
    _write_single_loc_bit = "A2"
    # Terminator byte appended to every command frame.
    _trailing_bit = "55"
    # Low address byte of a channel's range register.
    _range_offset = "78"
    # Upper two address bytes shared by all channel registers.
    _channel_base = ["11", "83"]
    # Conversion factor between raw controller counts and the position value
    # returned/accepted by this class (pos = counts / _position_scale * 100).
    _position_scale = 1048574
    # Reply length of a single-location read: A0 [4 addr bytes] [4 data bytes] 55
    _read_reply_length = 10

    def show_all(self) -> None:
        """Display current status of all channels

        Returns:
            None
        """
        if not self.connected:
            print("npoint controller is currently disabled.")
            return
        print(f"Connected to controller at {self._socket_host}:{self._socket_port}")
        t = PrettyTable()
        t.field_names = ["Channel", "Range", "Position", "Target"]
        for ii in range(self._axes_per_controller):
            t.add_row([ii, self._get_range(ii), self.get_current_pos(ii), self.get_target_pos(ii)])
        print(t)

    def _channel_address(self, channel: int, register: int, offset: str) -> list:
        """Build the address of a per-channel register, most significant byte first.

        Args:
            channel (int): Zero-based channel index.
            register (int): Third address byte for channel 0; channels are 0x10 apart.
            offset (str): Lowest address byte as a hex string.

        Returns:
            list: Fresh list of four hex strings, e.g. ["11", "83", "13", "34"].
        """
        return [*self._channel_base, f"{register + 16 * channel:02x}", offset]

    def _read_location(self, addr: list, signed: bool = True) -> int:
        """Read and decode the 32-bit value stored at a single memory location.

        Args:
            addr (list): Address as hex strings, most significant byte first.
            signed (bool, optional): Interpret the payload as signed. Defaults to True.

        Raises:
            RuntimeError: Raised if the reply does not have the expected number of bytes (10).

        Returns:
            int: Decoded value.
        """
        send_buffer = self.__read_single_location_buffer(addr)
        recvd = self._put_and_receive(send_buffer)
        if len(recvd) != self._read_reply_length:
            raise RuntimeError(
                f"Received buffer is corrupted. Expected {self._read_reply_length} bytes and instead got {len(recvd)}"
            )
        # Payload sits between the echoed prefix+address (5 bytes) and the trailing byte.
        return self._hex_list_to_int(recvd[5:-1], signed=signed)

    @channel_checked
    def _get_range(self, channel: int) -> int:
        """Get the range of the specified channel axis.

        Args:
            channel (int): Channel for which the range should be requested.

        Raises:
            RuntimeError: Raised if the received message doesn't have the expected number of bytes (10).

        Returns:
            int: Range
        """
        # for first channel: 0x11 83 10 78
        return self._read_location(
            self._channel_address(channel, 0x10, self._range_offset), signed=False
        )

    @channel_checked
    def get_current_pos(self, channel: int) -> float:
        """Read the current position of a channel.

        Args:
            channel (int): Zero-based channel index.

        Raises:
            RuntimeError: Raised if the reply has an unexpected length.

        Returns:
            float: Raw position counts scaled by ``100 / _position_scale``.
        """
        # for first channel: 0x11 83 13 34
        raw = self._read_location(self._channel_address(channel, 0x13, "34"))
        return raw / self._position_scale * 100

    @channel_checked
    def set_target_pos(self, channel: int, pos: float) -> None:
        """Write a new target position for a channel.

        Args:
            channel (int): Zero-based channel index.
            pos (float): Target in the same scaling as ``get_current_pos``.
        """
        # for first channel: 0x11 83 12 18 00 00 00 00
        addr = self._channel_address(channel, 0x12, "18")
        target = int(round(self._position_scale / 100 * pos))
        # Big-endian bytes here; the write-buffer builder reverses them to little-endian.
        data = [f"{m:02x}" for m in target.to_bytes(4, byteorder="big", signed=True)]
        send_buffer = self.__write_single_location_buffer(addr, data)
        self._put(send_buffer)

    @channel_checked
    def get_target_pos(self, channel: int) -> float:
        """Read the currently configured target position of a channel.

        Args:
            channel (int): Zero-based channel index.

        Raises:
            RuntimeError: Raised if the reply has an unexpected length.

        Returns:
            float: Raw target counts scaled by ``100 / _position_scale``.
        """
        # for first channel: 0x11 83 12 18
        raw = self._read_location(self._channel_address(channel, 0x12, "18"))
        return raw / self._position_scale * 100

    @channel_checked
    def _set_servo(self, channel: int, enable: bool) -> None:
        """Enable/disable the servo loop. Not implemented yet (untested on hardware)."""
        print("Not tested")
        return
        # # for first channel: 0x11 83 10 84 00 00 00 00
        # addr = self._channel_address(channel, 0x10, "84")
        # if enable:
        #     data = ["00"] * 3 + ["01"]
        # else:
        #     data = ["00"] * 4
        # send_buffer = self.__write_single_location_buffer(addr, data)
        # self._put(send_buffer)

    @channel_checked
    def _get_servo(self, channel: int) -> int:
        """Read the servo status word of a channel.

        Args:
            channel (int): Zero-based channel index.

        Returns:
            int: Raw status value.
        """
        # for first channel: 0x11 83 10 84 00 00 00 00
        return self._read_location(self._channel_address(channel, 0x10, "84"))

    @threadlocked
    def _put(self, buffer: list) -> None:
        """Translates a list of hex values to bytes and sends them to the socket.

        Args:
            buffer (list): List of hex values without leading 0x

        Returns:
            None
        """
        buffer = b"".join([bytes.fromhex(m) for m in buffer])
        self.sock.put(buffer)

    @threadlocked
    def _put_and_receive(self, msg_hex_list: list) -> list:
        """Send msg to socket and wait for a reply.

        Args:
            msg_hex_list (list): List of hex values without leading 0x.

        Returns:
            list: Received message as a list of hex values
        """
        buffer = b"".join([bytes.fromhex(m) for m in msg_hex_list])
        self.sock.put(buffer)
        recv_msg = self.sock.receive()
        recv_hex_list = [hex(m) for m in recv_msg]
        self._verify_received_msg(msg_hex_list, recv_hex_list)
        return recv_hex_list

    def _verify_received_msg(self, in_list: list, out_list: list) -> None:
        """Ensure that the first address bits of sent and received messages are the same.

        Args:
            in_list (list): list containing the sent message
            out_list (list): list containing the received message

        Raises:
            RuntimeError: Raised if first two address bits of 'in' and 'out' are not identical

        Returns:
            None
        """
        # first, translate hex (str) values to int
        in_list_int = [int(val, 16) for val in in_list]
        out_list_int = [int(val, 16) for val in out_list]

        # first ints of the reply should be the same. Otherwise something went wrong
        if not in_list_int[:2] == out_list_int[:2]:
            raise RuntimeError("Connection failure. Please restart the controller.")

    def _check_channel(self, channel: int) -> None:
        """Validate a zero-based channel index.

        Raises:
            ValueError: If the channel is negative or not available on this controller.
        """
        if channel < 0:
            # Negative channels would produce malformed (negative) hex addresses.
            raise ValueError(f"Channel {channel} is invalid. Channel indices must be non-negative.")
        if channel >= self._axes_per_controller:
            raise ValueError(
                f"Channel {channel+1} exceeds the available number of channels ({self._axes_per_controller})"
            )

    @staticmethod
    def _hex_list_to_int(in_buffer: list, byteorder="little", signed=True) -> int:
        """Translate hex list to int.

        The input list is not modified.

        Args:
            in_buffer (list): Input buffer; received as list of hex values
            byteorder (str, optional): Byteorder of in_buffer. Defaults to "little".
            signed (bool, optional): Whether the hex list represents a signed int. Defaults to True.

        Returns:
            int: Translated integer.
        """
        # Any value other than "little" is treated as big-endian, as before.
        order = "little" if byteorder == "little" else "big"
        raw = bytes(int(m, 16) for m in in_buffer)
        return int.from_bytes(raw, byteorder=order, signed=signed)

    @staticmethod
    def __read_single_location_buffer(addr) -> list:
        """Prepare buffer for reading from a single memory location (hex address).

        Number of bytes: 6
        Format: 0xA0 [addr] 0x55
        Return Value: 0xA0 [addr] [data] 0x55
        Sample Hex Transmission from PC to LC.400: A0 18 12 83 11 55
        Sample Hex Return Transmission from LC.400 to PC: A0 18 12 83 11 64 00 00 00 55

        Args:
            addr (list): Hex address to read from (most significant byte first).
                The list is not modified.

        Returns:
            list: List of hex values representing the read instruction.
        """
        # The address goes on the wire least significant byte first.
        addr_part = list(reversed(addr)) if isinstance(addr, list) else [addr]
        return [
            NPointController._read_single_loc_bit,
            *addr_part,
            NPointController._trailing_bit,
        ]

    @staticmethod
    def __write_single_location_buffer(addr: list, data: list) -> list:
        """Prepare buffer for writing to a single memory location (hex address).

        Number of bytes: 10
        Format: 0xA2 [addr] [data] 0x55
        Return Value: none
        Sample Hex Transmission from PC to LC.400: A2 18 12 83 11 E8 03 00 00 55

        Args:
            addr (list): List of hex values representing the address to write to
                (most significant byte first). Not modified.
            data (list): List of hex values representing the data that should be
                written (most significant byte first). Not modified.

        Returns:
            list: List of hex values representing the write instruction.
        """
        # Address and data are both sent least significant byte first.
        addr_part = list(reversed(addr)) if isinstance(addr, list) else [addr]
        data_part = list(reversed(data)) if isinstance(data, list) else [data]
        return [
            NPointController._write_single_loc_bit,
            *addr_part,
            *data_part,
            NPointController._trailing_bit,
        ]

    @staticmethod
    def __read_array():
        raise NotImplementedError

    @staticmethod
    def __write_next_command():
        raise NotImplementedError

    def __del__(self):
        # getattr guards against partially initialised instances.
        if getattr(self, "connected", False):
            print("Closing npoint socket")
            self.off()
class NpointSignalBase(SocketSignal):
    """
    Common base for nPoint signals; borrows the controller and socket from the parent axis.
    """

    def __init__(self, signal_name, **kwargs):
        self.signal_name = signal_name
        super().__init__(**kwargs)
        owner_controller: NPointController = self.parent.controller
        self.controller = owner_controller
        self.sock = owner_controller.sock
class NpointSignalRO(NpointSignalBase):
    """
    Read-only nPoint signal: it can be read but any attempt to set it fails.
    """

    def __init__(self, signal_name, **kwargs):
        super().__init__(signal_name, **kwargs)
        # Mark the signal as non-writable in its metadata.
        self._metadata["write_access"] = False

    @threadlocked
    def _socket_set(self, val):
        raise ReadOnlyError("Read-only signals cannot be set")
class NpointReadbackSignal(NpointSignalRO):
    """
    Read-only signal with the current position of an nPoint axis, multiplied by the axis sign.
    """

    @threadlocked
    def _socket_get(self):
        axis = self.parent
        raw_position = self.controller.get_current_pos(axis.axis_Id_numeric)
        return raw_position * axis.sign
class NpointSetpointSignal(NpointSignalBase):
    """
    Signal to set the target position of an nPoint piezo stage.

    Values read from / written to this signal are in user coordinates, i.e.
    the controller value multiplied by the parent's ``sign``.
    """

    def __init__(self, signal_name, **kwargs):
        super().__init__(signal_name, **kwargs)
        # Last commanded target in user coordinates (same convention as
        # _socket_get); only stored locally.
        self.setpoint = 0.0

    @threadlocked
    def _socket_get(self):
        return self.controller.get_target_pos(self.parent.axis_Id_numeric) * self.parent.sign

    @threadlocked
    def _socket_set(self, val):
        self.setpoint = val
        # Convert user -> controller coordinates exactly once. Applying the sign
        # twice cancelled it out, so inverted axes (sign=-1) were never inverted.
        controller_val = val * self.parent.sign
        return self.controller.set_target_pos(self.parent.axis_Id_numeric, controller_val)
class NpointMotorIsMoving(SignalRO):
    """
    Read-only flag telling whether the axis is currently moving (1) or idle (0).
    """

    def set_motor_is_moving(self, value: int) -> None:
        """
        Overwrite the cached readback with a new moving state.

        Args:
            value (int): 1 while the motor moves, 0 once it has stopped.
        """
        self._readback = value
class NPointAxis(Device, PositionerBase):
    """
    NPointAxis class, which inherits from Device and PositionerBase.

    This class represents an axis of an nPoint piezo stage and provides the
    necessary functionality to move the axis and read its current position.
    """

    USER_ACCESS = ["controller"]
    readback = Cpt(NpointReadbackSignal, signal_name="readback", kind="hinted")
    user_setpoint = Cpt(NpointSetpointSignal, signal_name="setpoint")
    motor_is_moving = Cpt(NpointMotorIsMoving, value=0, kind="normal")
    settle_time: Cpt[Signal] = Cpt(Signal, value=0.1, kind="config")

    high_limit_travel = Cpt(Signal, value=0, kind="omitted")
    low_limit_travel = Cpt(Signal, value=0, kind="omitted")

    SUB_READBACK = "readback"
    SUB_CONNECTION_CHANGE = "connection_change"
    _default_sub = SUB_READBACK

    def __init__(
        self,
        axis_Id,
        prefix="",
        *,
        name,
        kind=None,
        read_attrs=None,
        configuration_attrs=None,
        parent=None,
        host="mpc2680.psi.ch",
        port=8085,
        limits=None,
        sign=1,
        socket_cls=SocketIO,
        tolerance: float = 0.05,
        device_manager=None,
        **kwargs,
    ):
        """
        Args:
            axis_Id (str): Single-letter axis identifier ("A" -> channel 0, "B" -> 1, ...).
            host (str): Controller host name.
            port (int): Controller port.
            limits (sequence, optional): (low, high) soft limits.
            sign (int): Multiplier between controller and user coordinates.
            tolerance (float): Absolute tolerance used to judge a move as successful.

        Raises:
            ValueError: If ``limits`` is given but does not contain exactly two values.
        """
        self.controller = NPointController(
            socket_cls=socket_cls, socket_host=host, socket_port=port, device_manager=device_manager
        )
        self.axis_Id = axis_Id
        self.sign = sign
        self.controller.set_axis(axis=self, axis_nr=self.axis_Id_numeric)
        self.tolerance = tolerance

        super().__init__(
            prefix,
            name=name,
            kind=kind,
            read_attrs=read_attrs,
            configuration_attrs=configuration_attrs,
            parent=parent,
            **kwargs,
        )
        self.readback.name = self.name
        self.controller.subscribe(
            self._update_connection_state, event_type=self.SUB_CONNECTION_CHANGE
        )
        self._update_connection_state()
        if limits is not None:
            # Explicit check instead of assert, which is stripped under -O.
            if len(limits) != 2:
                raise ValueError(
                    f"Expected limits to contain exactly two values (low, high), got {limits!r}"
                )
            self.low_limit_travel.put(limits[0])
            self.high_limit_travel.put(limits[1])

    def wait_for_connection(self, timeout: float = 30.0) -> None:
        """Open the controller connection and sync the local setpoint, retrying up to 5 times.

        Raises:
            TimeoutError: If all attempts time out.
        """
        for _ in range(5):
            try:
                self.controller.on(timeout=timeout)
                self._update_setpoint_from_readback()
            except TimeoutError:
                self.controller.off(update_config=False)
                time.sleep(1)
            else:
                break
        else:
            raise TimeoutError(
                f"NPointAxis {self.name}: Failed to update the setpoint from the readback value after 5 attempts during startup. This may happen occasionally. "
                f"Try to reload the config and if the problem persists, check the connection to the nPoint controller "
                f"and ensure that it is powered on and accessible at {self.controller._socket_host}:{self.controller._socket_port}."
            )

    def _update_setpoint_from_readback(self):
        """
        The setpoint is only stored locally. After a restart, we need to update
        it to match the current readback value.
        """
        self.user_setpoint.setpoint = self.readback.get()

    def destroy(self):
        """Make sure to turn off the controller socket on destroy."""
        self.controller.off(update_config=False)
        return super().destroy()

    @property
    def limits(self):
        return (self.low_limit_travel.get(), self.high_limit_travel.get())

    @property
    def low_limit(self):
        return self.limits[0]

    @property
    def high_limit(self):
        return self.limits[1]

    def check_value(self, pos):
        """Check that the position is within the soft limits (disabled when low >= high)."""
        low_limit, high_limit = self.limits

        if low_limit < high_limit and not (low_limit <= pos <= high_limit):
            raise LimitError(f"position={pos} not within limits {self.limits}")

    def _update_connection_state(self, **kwargs):
        for walk in self.walk_signals():
            walk.item._metadata["connected"] = self.controller.connected

    @raise_if_disconnected
    def move(self, position, wait=True, **kwargs):
        """Move to a specified position, optionally waiting for motion to complete.

        Parameters
        ----------
        position
            Position to move to
        moved_cb : callable
            Call this callback when movement has finished. This callback must
            accept one keyword argument: 'obj' which will be set to this
            positioner instance.
        timeout : float, optional
            Maximum time to wait for the motion. If None, the default timeout
            for this positioner is used.

        Returns
        -------
        status : MoveStatus

        Raises
        ------
        TimeoutError
            When motion takes longer than `timeout`
        ValueError
            On invalid positions
        RuntimeError
            If motion fails other than timing out
        """
        self._started_moving = False
        timeout = kwargs.pop("timeout", 10)
        status = super().move(position, timeout=timeout, **kwargs)
        self.user_setpoint.put(position, wait=False)

        def move_and_finish():
            # Always report completion, even on errors; otherwise the status
            # would never finish and motor_is_moving would stay at 1.
            success = False
            try:
                self.motor_is_moving.set_motor_is_moving(1)
                val = self.readback.read()
                self._run_subs(sub_type=self.SUB_READBACK, value=val, timestamp=time.time())
                # No motion feedback is polled: wait the settle time, then compare.
                time.sleep(max(self.settle_time.get(), 0))
                self.motor_is_moving.set_motor_is_moving(0)
                val = self.readback.read()
                self._run_subs(sub_type=self.SUB_READBACK, value=val, timestamp=time.time())
                success = bool(np.isclose(val[self.name]["value"], position, atol=self.tolerance))
            except Exception:  # pylint: disable=broad-except
                logger.exception(f"NPointAxis {self.name}: error while finishing move to {position}")
                self.motor_is_moving.set_motor_is_moving(0)
            finally:
                self._done_moving(success=success)

        threading.Thread(target=move_and_finish, daemon=True).start()
        try:
            if wait:
                status_wait(status)
        except KeyboardInterrupt:
            self.stop()
            raise

        return status

    @property
    def axis_Id(self):
        """
        Return the axis_Id_alpha.
        """
        return self._axis_Id_alpha

    @axis_Id.setter
    def axis_Id(self, val: str):
        """
        Set the axis_Id_alpha and axis_Id_numeric based on the alpha value.

        Args:
            val (str): Single ASCII letter axis identifier ("A"/"a" -> 0).

        Raises:
            TypeError: If val is not a str.
            ValueError: If val is not a single ASCII letter.
        """
        if not isinstance(val, str):
            raise TypeError(f"Expected value of type str but received {type(val)}")
        if len(val) != 1:
            raise ValueError("Only single-character axis_Ids are supported.")
        if not (val.isascii() and val.isalpha()):
            # Non-letters would map to negative or out-of-range channel numbers.
            raise ValueError(f"axis_Id must be an ASCII letter, got {val!r}.")
        self._axis_Id_alpha = val
        self._axis_Id_numeric = ord(val.lower()) - 97

    @property
    def axis_Id_numeric(self):
        """
        Return the numeric value of the axis_Id.
        """
        return self._axis_Id_numeric

    @axis_Id_numeric.setter
    def axis_Id_numeric(self, val: int):
        """
        Set the axis_Id_numeric and axis_Id_alpha based on the numeric value.

        Args:
            val (int): Numeric axis identifier in [0, 25].

        Raises:
            TypeError: If val is not an int.
            ValueError: If val is outside [0, 25].
        """
        if not isinstance(val, int):
            raise TypeError(f"Expected value of type int but received {type(val)}")
        # 26 letters -> valid indices are 0..25 (26 would map to "{").
        if not 0 <= val < 26:
            raise ValueError("Numeric value exceeds supported range.")
        # The two attributes were previously assigned the wrong way round.
        self._axis_Id_numeric = val
        self._axis_Id_alpha = chr(val + 97).upper()

    @property
    def egu(self):
        """The engineering units (EGU) for positions"""
        return "um"

    def stage(self) -> list[object]:
        return super().stage()

    def unstage(self) -> list[object]:
        return super().unstage()
if __name__ == "__main__":
    npx = NPointAxis(axis_Id="A", name="npx", host="nPoint000003.psi.ch", port=23)
    # Instantiated to register axis B as well; not used further below.
    npy = NPointAxis(axis_Id="B", name="npy", host="nPoint000003.psi.ch", port=23)
    npx.controller.on()
    try:
        print("socket is open, axis is ready!")
        npx.move(10)
        print(npx.read())
    finally:
        # Release the socket even if the move or read fails.
        npx.controller.off()