import functools
import threading
import time
import numpy as np
from bec_lib.logger import bec_logger
from ophyd import Component as Cpt
from ophyd import Device, PositionerBase, Signal, SignalRO
from ophyd.status import wait as status_wait
from ophyd.utils import LimitError, ReadOnlyError
from ophyd_devices.utils.controller import Controller, threadlocked
from ophyd_devices.utils.socket import SocketIO, SocketSignal, raise_if_disconnected
from prettytable import PrettyTable
# Module-level logger, taken from bec_lib's bec_logger.
logger = bec_logger.logger
def channel_checked(fcn):
    """Decorator to catch attempted access to channels that are not available.

    The decorated method must take the channel as its first argument after
    ``self``, either positionally or as the keyword argument ``channel``.
    Before the method runs, the channel is passed to ``self._check_channel``,
    which is expected to raise if the channel is not available.

    Args:
        fcn (callable): Method to wrap.

    Raises:
        TypeError: Raised by the wrapper if no channel argument was supplied.

    Returns:
        callable: Wrapped method with the same metadata as ``fcn``.
    """

    @functools.wraps(fcn)
    def wrapper(self, *args, **kwargs):
        # The channel may arrive positionally or by keyword; support both so
        # that a keyword call does not fail with an IndexError on args[0].
        if args:
            channel = args[0]
        elif "channel" in kwargs:
            channel = kwargs["channel"]
        else:
            raise TypeError(f"{fcn.__name__}() missing required argument: 'channel'")
        # pylint: disable=protected-access
        self._check_channel(channel)
        return fcn(self, *args, **kwargs)

    return wrapper
class NpointError(Exception):
    """Root exception type for errors specific to the nPoint driver."""
class NPointController(Controller):
    """
    Controller for nPoint piezo stages. This class inherits from the Controller class
    and provides a singleton interface to the nPoint controller.

    Messages are handled as lists of two-character hex strings (without a
    leading ``0x``). A message starts with a command byte (read: ``A0``,
    write: ``A2``), followed by the address and optional data in reversed
    byte order, and ends with the trailing byte ``55``.
    """

    _axes_per_controller = 3
    _read_single_loc_bit = "A0"
    _write_single_loc_bit = "A2"
    _trailing_bit = "55"
    _range_offset = "78"
    _channel_base = ["11", "83"]
    # Positions are exchanged as raw integer counts: ``_position_counts``
    # counts correspond to ``_position_span`` position units.
    _position_counts = 1048574
    _position_span = 100

    def show_all(self) -> None:
        """Display current status of all channels

        Prints a table with range, current position and target position of
        every channel, or a short notice if the controller is not connected.

        Returns:
            None
        """
        if not self.connected:
            print("npoint controller is currently disabled.")
            return
        print(f"Connected to controller at {self._socket_host}:{self._socket_port}")
        t = PrettyTable()
        t.field_names = ["Channel", "Range", "Position", "Target"]
        for ii in range(self._axes_per_controller):
            t.add_row([ii, self._get_range(ii), self.get_current_pos(ii), self.get_target_pos(ii)])
        print(t)

    def _channel_address(self, channel: int, first_channel_byte: int, register: str) -> list:
        """Build the address of a per-channel memory location.

        Args:
            channel (int): Zero-based channel index.
            first_channel_byte (int): Third address byte for channel 0; each
                further channel adds 16 (0x10) to it.
            register (str): Last address byte as a two-character hex string.

        Returns:
            list: Address as a list of hex strings, most significant byte first.
        """
        # Zero-pad so that every entry stays a valid input for bytes.fromhex.
        return [*self._channel_base, f"{first_channel_byte + 16 * channel:02x}", register]

    @channel_checked
    def _get_range(self, channel: int) -> int:
        """Get the range of the specified channel axis.

        Args:
            channel (int): Channel for which the range should be requested.

        Raises:
            RuntimeError: Raised if the received message doesn't have the expected number of bytes (10).

        Returns:
            int: Range
        """
        # for first channel: 0x11 83 10 78
        addr = self._channel_address(channel, 0x10, self._range_offset)
        send_buffer = self.__read_single_location_buffer(addr)
        recvd = self._put_and_receive(send_buffer)
        if len(recvd) != 10:
            raise RuntimeError(
                f"Received buffer is corrupted. Expected 10 bytes and instead got {len(recvd)}"
            )
        # Data bytes sit between the echoed 5-byte header and the trailing byte.
        return self._hex_list_to_int(recvd[5:-1], signed=False)

    @channel_checked
    def get_current_pos(self, channel: int) -> float:
        """Read the current position of the specified channel.

        Args:
            channel (int): Zero-based channel index.

        Returns:
            float: Current position, converted from raw counts to position units.
        """
        # for first channel: 0x11 83 13 34
        addr = self._channel_address(channel, 0x13, "34")
        send_buffer = self.__read_single_location_buffer(addr)
        recvd = self._put_and_receive(send_buffer)
        raw_counts = self._hex_list_to_int(recvd[5:-1])
        return raw_counts / self._position_counts * self._position_span

    @channel_checked
    def set_target_pos(self, channel: int, pos: float) -> None:
        """Write a new target position for the specified channel.

        Args:
            channel (int): Zero-based channel index.
            pos (float): Target position in position units.

        Raises:
            OverflowError: Raised by ``int.to_bytes`` if the converted target
                does not fit into a signed 32-bit integer.

        Returns:
            None
        """
        # for first channel: 0x11 83 12 18 00 00 00 00
        addr = self._channel_address(channel, 0x12, "18")
        target = int(round(self._position_counts / self._position_span * pos))
        # Big-endian here; the write buffer reverses the data bytes on assembly.
        data = [f"{m:02x}" for m in target.to_bytes(4, byteorder="big", signed=True)]
        send_buffer = self.__write_single_location_buffer(addr, data)
        self._put(send_buffer)

    @channel_checked
    def get_target_pos(self, channel: int) -> float:
        """Read the target position of the specified channel.

        Args:
            channel (int): Zero-based channel index.

        Returns:
            float: Target position, converted from raw counts to position units.
        """
        # for first channel: 0x11 83 12 18
        addr = self._channel_address(channel, 0x12, "18")
        send_buffer = self.__read_single_location_buffer(addr)
        recvd = self._put_and_receive(send_buffer)
        raw_counts = self._hex_list_to_int(recvd[5:-1])
        return raw_counts / self._position_counts * self._position_span

    @channel_checked
    def _set_servo(self, channel: int, enable: bool) -> None:
        """Placeholder for switching the servo of a channel; currently a no-op.

        The method only prints "Not tested" and returns without contacting the
        controller.

        Args:
            channel (int): Zero-based channel index.
            enable (bool): Requested servo state (currently ignored).

        Returns:
            None
        """
        # The untested draft of this method wrote the data bytes 00 00 00 01
        # (enable) or 00 00 00 00 (disable) to the servo address, which is
        # 0x11 83 10 84 for the first channel.
        print("Not tested")

    @channel_checked
    def _get_servo(self, channel: int) -> int:
        """Read the servo status value of the specified channel.

        Args:
            channel (int): Zero-based channel index.

        Returns:
            int: Value stored at the channel's servo address.
        """
        # for first channel: 0x11 83 10 84 00 00 00 00
        addr = self._channel_address(channel, 0x10, "84")
        send_buffer = self.__read_single_location_buffer(addr)
        recvd = self._put_and_receive(send_buffer)
        return self._hex_list_to_int(recvd[5:-1])

    @threadlocked
    def _put(self, buffer: list) -> None:
        """Translates a list of hex values to bytes and sends them to the socket.

        Args:
            buffer (list): List of hex values without leading 0x

        Returns:
            None
        """
        payload = b"".join(bytes.fromhex(m) for m in buffer)
        self.sock.put(payload)

    @threadlocked
    def _put_and_receive(self, msg_hex_list: list) -> list:
        """Send msg to socket and wait for a reply.

        Args:
            msg_hex_list (list): List of hex values without leading 0x.

        Raises:
            RuntimeError: Raised if the reply does not start with the same two
                bytes as the sent message.

        Returns:
            list: Received message as a list of hex values
        """
        payload = b"".join(bytes.fromhex(m) for m in msg_hex_list)
        self.sock.put(payload)
        recv_msg = self.sock.receive()
        recv_hex_list = [hex(m) for m in recv_msg]
        self._verify_received_msg(msg_hex_list, recv_hex_list)
        return recv_hex_list

    def _verify_received_msg(self, in_list: list, out_list: list) -> None:
        """Ensure that the first address bits of sent and received messages are the same.

        Args:
            in_list (list): list containing the sent message
            out_list (list): list containing the received message

        Raises:
            RuntimeError: Raised if first two address bits of 'in' and 'out' are not identical

        Returns:
            None
        """
        # Compare numerically, since the two lists format hex values differently
        # ("A0" on the way out, "0xa0" on the way in).
        sent_head = [int(val, 16) for val in in_list[:2]]
        received_head = [int(val, 16) for val in out_list[:2]]
        if sent_head != received_head:
            raise RuntimeError("Connection failure. Please restart the controller.")

    def _check_channel(self, channel: int) -> None:
        """Validate a zero-based channel index.

        Args:
            channel (int): Channel index to validate.

        Raises:
            ValueError: Raised if the channel is negative or not smaller than
                the number of axes per controller.

        Returns:
            None
        """
        if channel < 0:
            # A negative index cannot be turned into a valid address byte.
            raise ValueError(f"Channel index must not be negative (got {channel})")
        if channel >= self._axes_per_controller:
            raise ValueError(
                f"Channel {channel+1} exceeds the available number of channels ({self._axes_per_controller})"
            )

    @staticmethod
    def _hex_list_to_int(in_buffer: list, byteorder="little", signed=True) -> int:
        """Translate hex list to int.

        The input list is not modified.

        Args:
            in_buffer (list): Input buffer; received as list of hex values
            byteorder (str, optional): Byteorder of in_buffer. Defaults to "little".
            signed (bool, optional): Whether the hex list represents a signed int. Defaults to True.

        Returns:
            int: Translated integer.
        """
        # Work on a copy so the caller's list keeps its order.
        values = list(in_buffer)
        if byteorder == "little":
            values.reverse()
        # make sure that all hex strings have the same format ("FF")
        raw = b"".join(bytes.fromhex(f"{int(m, 16):02x}") for m in values)
        return int.from_bytes(raw, byteorder="big", signed=signed)

    @staticmethod
    def __read_single_location_buffer(addr) -> list:
        """Prepare buffer for reading from a single memory location (hex address).

        Number of bytes: 6
        Format: 0xA0 [addr] 0x55
        Return Value: 0xA0 [addr] [data] 0x55
        Sample Hex Transmission from PC to LC.400: A0 18 12 83 11 55
        Sample Hex Return Transmission from LC.400 to PC: A0 18 12 83 11 64 00 00 00 55

        The address list is not modified; it is added in reversed order.

        Args:
            addr (list): Hex address to read from

        Returns:
            list: List of hex values representing the read instruction.
        """
        buffer = [NPointController._read_single_loc_bit]
        if isinstance(addr, list):
            buffer.extend(reversed(addr))
        else:
            buffer.append(addr)
        buffer.append(NPointController._trailing_bit)
        return buffer

    @staticmethod
    def __write_single_location_buffer(addr: list, data: list) -> list:
        """Prepare buffer for writing to a single memory location (hex address).

        Number of bytes: 10
        Format: 0xA2 [addr] [data] 0x55
        Return Value: none
        Sample Hex Transmission from PC to C.400: A2 18 12 83 11 E8 03 00 00 55

        Neither input list is modified; both are added in reversed order.

        Args:
            addr (list): List of hex values representing the address to write to.
            data (list): List of hex values representing the data that should be written.

        Returns:
            list: List of hex values representing the write instruction.
        """
        buffer = [NPointController._write_single_loc_bit]
        for part in (addr, data):
            if isinstance(part, list):
                buffer.extend(reversed(part))
            else:
                buffer.append(part)
        buffer.append(NPointController._trailing_bit)
        return buffer

    @staticmethod
    def __read_array():
        """Not implemented."""
        raise NotImplementedError

    @staticmethod
    def __write_next_command():
        """Not implemented."""
        raise NotImplementedError

    def __del__(self):
        # Use getattr so a partially initialised instance does not raise here.
        if getattr(self, "connected", False):
            print("Closing npoint socket")
            self.off()
class NpointSignalBase(SocketSignal):
    """
    Common base for all nPoint socket signals.

    Stores the signal name and takes the controller and its socket from the
    parent device.
    """

    def __init__(self, signal_name, **kwargs):
        self.signal_name = signal_name
        super().__init__(**kwargs)
        # The parent device owns the controller; keep shortcuts to it and its socket.
        controller: NPointController = self.parent.controller
        self.controller = controller
        self.sock = controller.sock
class NpointSignalRO(NpointSignalBase):
    """
    Read-only variant of the nPoint signal base class.

    Marks the signal as not writable in its metadata and rejects every
    attempt to set it.
    """

    def __init__(self, signal_name, **kwargs):
        super().__init__(signal_name, **kwargs)
        self._metadata.update(write_access=False)

    @threadlocked
    def _socket_set(self, val):
        raise ReadOnlyError("Read-only signals cannot be set")
class NpointReadbackSignal(NpointSignalRO):
    """
    Read-only signal reporting the current position of an nPoint piezo stage.
    """

    @threadlocked
    def _socket_get(self):
        # Query the controller for this axis, then apply the axis sign.
        axis = self.parent
        controller_position = self.controller.get_current_pos(axis.axis_Id_numeric)
        return controller_position * axis.sign
class NpointSetpointSignal(NpointSignalBase):
    """
    Signal to set the target position of an nPoint piezo stage.

    Values exchanged with the user are multiplied by the parent's ``sign``
    before they are sent to the controller and after they are read back from it.
    """

    def __init__(self, signal_name, **kwargs):
        super().__init__(signal_name, **kwargs)
        # Last target that was sent, stored with the sign already applied.
        self.setpoint = 0.0

    @threadlocked
    def _socket_get(self):
        """Return the controller's target position with the axis sign applied."""
        return self.controller.get_target_pos(self.parent.axis_Id_numeric) * self.parent.sign

    @threadlocked
    def _socket_set(self, val):
        """Send a new target position to the controller.

        Args:
            val (float): Requested target position as given by the user.
        """
        # Apply the sign exactly once. Applying it a second time would cancel
        # it for sign=-1 and contradict _socket_get and the readback signal.
        target_val = val * self.parent.sign
        self.setpoint = target_val
        return self.controller.set_target_pos(self.parent.axis_Id_numeric, target_val)
class NpointMotorIsMoving(SignalRO):
    """
    Read-only signal that tells whether the motor is currently moving.
    """

    def set_motor_is_moving(self, value: int) -> None:
        """
        Store a new moving state in the signal's readback value.

        Args:
            value (int): 1 if the motor is moving, 0 otherwise.
        """
        # Written directly to the internal readback because the signal is read-only.
        self._readback = value
class NPointAxis(Device, PositionerBase):
    """
    NPointAxis class, which inherits from Device and PositionerBase. This class
    represents an axis of an nPoint piezo stage and provides the necessary
    functionality to move the axis and read its current position.
    """

    USER_ACCESS = ["controller"]
    readback = Cpt(NpointReadbackSignal, signal_name="readback", kind="hinted")
    user_setpoint = Cpt(NpointSetpointSignal, signal_name="setpoint")
    motor_is_moving = Cpt(NpointMotorIsMoving, value=0, kind="normal")
    settle_time: Cpt[Signal] = Cpt(Signal, value=0.1, kind="config")
    high_limit_travel = Cpt(Signal, value=0, kind="omitted")
    low_limit_travel = Cpt(Signal, value=0, kind="omitted")

    SUB_READBACK = "readback"
    SUB_CONNECTION_CHANGE = "connection_change"
    _default_sub = SUB_READBACK

    def __init__(
        self,
        axis_Id,
        prefix="",
        *,
        name,
        kind=None,
        read_attrs=None,
        configuration_attrs=None,
        parent=None,
        host="mpc2680.psi.ch",
        port=8085,
        limits=None,
        sign=1,
        socket_cls=SocketIO,
        tolerance: float = 0.05,
        device_manager=None,
        **kwargs,
    ):
        """Create an axis and register it with its controller.

        Args:
            axis_Id (str): Single-character axis identifier ("A" maps to channel 0).
            prefix (str, optional): Prefix passed on to the ophyd Device.
            name (str): Device name.
            kind, read_attrs, configuration_attrs, parent: Passed on to the ophyd Device.
            host (str, optional): Host name of the controller.
            port (int, optional): Port of the controller.
            limits (sequence, optional): Pair (low, high) of soft limits.
            sign (int, optional): Factor applied to positions exchanged with the controller.
            socket_cls (type, optional): Socket class handed to the controller.
            tolerance (float, optional): Absolute tolerance used to judge whether a move succeeded.
            device_manager (optional): Passed on to the controller.

        Raises:
            ValueError: Raised if ``limits`` is given but does not contain exactly two values.
        """
        # The controller must exist before super().__init__ runs, because the
        # signal components read parent.controller in their constructors.
        self.controller = NPointController(
            socket_cls=socket_cls, socket_host=host, socket_port=port, device_manager=device_manager
        )
        self.axis_Id = axis_Id
        self.sign = sign
        self.controller.set_axis(axis=self, axis_nr=self.axis_Id_numeric)
        self.tolerance = tolerance
        super().__init__(
            prefix,
            name=name,
            kind=kind,
            read_attrs=read_attrs,
            configuration_attrs=configuration_attrs,
            parent=parent,
            **kwargs,
        )
        # The readback is published under the device name; move() looks it up by that key.
        self.readback.name = self.name
        self.controller.subscribe(
            self._update_connection_state, event_type=self.SUB_CONNECTION_CHANGE
        )
        self._update_connection_state()

        if limits is not None:
            # Explicit check instead of assert, which is stripped under -O.
            if len(limits) != 2:
                raise ValueError(f"limits must contain exactly two values, got {len(limits)}")
            self.low_limit_travel.put(limits[0])
            self.high_limit_travel.put(limits[1])

    def wait_for_connection(self, timeout: float = 30.0) -> bool:
        """Switch the controller on and synchronise the setpoint with the readback.

        Up to five attempts are made. After a TimeoutError the controller is
        switched off again and the next attempt starts one second later.

        Args:
            timeout (float, optional): Timeout passed to ``controller.on``.

        Raises:
            TimeoutError: Raised if all five attempts timed out.

        Returns:
            bool: True once the connection has been established.
        """
        last_error = None
        for _ in range(5):
            try:
                self.controller.on(timeout=timeout)
                self._update_setpoint_from_readback()
            except TimeoutError as exc:
                last_error = exc
                self.controller.off(update_config=False)
                time.sleep(1)
            else:
                return True
        raise TimeoutError(
            f"NPointAxis {self.name}: Failed to update the setpoint from the readback value after 5 attempts during startup. This may happen occasionally. "
            f"Try to reload the config and if the problem persists, check the connection to the nPoint controller "
            f"and ensure that it is powered on and accessible at {self.controller._socket_host}:{self.controller._socket_port}."
        ) from last_error

    def _update_setpoint_from_readback(self):
        """
        The setpoint is only stored locally. After a restart,
        we need to update it to match the current readback value.
        """
        # NOTE(review): NpointSetpointSignal._socket_set stores val * sign in
        # `setpoint`, while the readback value stored here already has the sign
        # applied once - confirm which convention is intended for sign=-1.
        self.user_setpoint.setpoint = self.readback.get()

    def destroy(self):
        """Make sure to turn off the controller socket on destroy."""
        self.controller.off(update_config=False)
        return super().destroy()

    @property
    def limits(self):
        """Soft limits as a tuple (low, high)."""
        return (self.low_limit_travel.get(), self.high_limit_travel.get())

    @property
    def low_limit(self):
        """Lower soft limit."""
        return self.limits[0]

    @property
    def high_limit(self):
        """Upper soft limit."""
        return self.limits[1]

    def check_value(self, pos):
        """Check that the position is within the soft limits

        The limits are only enforced when low < high; with the default values
        (both 0) every position is accepted.

        Raises:
            LimitError: Raised if the position lies outside active limits.
        """
        low_limit, high_limit = self.limits
        if low_limit < high_limit and not (low_limit <= pos <= high_limit):
            raise LimitError(f"position={pos} not within limits {self.limits}")

    def _update_connection_state(self, **kwargs):
        """Copy the controller's connection state into the metadata of all signals."""
        for walk in self.walk_signals():
            walk.item._metadata["connected"] = self.controller.connected

    @raise_if_disconnected
    def move(self, position, wait=True, **kwargs):
        """Move to a specified position, optionally waiting for motion to
        complete.

        Parameters
        ----------
        position
            Position to move to
        moved_cb : callable
            Call this callback when movement has finished. This callback must
            accept one keyword argument: 'obj' which will be set to this
            positioner instance.
        timeout : float, optional
            Maximum time to wait for the motion. Defaults to 10.

        Returns
        -------
        status : MoveStatus

        Raises
        ------
        TimeoutError
            When motion takes longer than `timeout`
        ValueError
            On invalid positions
        RuntimeError
            If motion fails other than timing out
        """
        self._started_moving = False
        timeout = kwargs.pop("timeout", 10)
        status = super().move(position, timeout=timeout, **kwargs)
        self.user_setpoint.put(position, wait=False)

        def move_and_finish():
            # Runs in a background thread: publish the readback, wait for the
            # settle time, then compare the final readback with the target.
            success = False
            try:
                self.motor_is_moving.set_motor_is_moving(1)
                val = self.readback.read()
                self._run_subs(sub_type=self.SUB_READBACK, value=val, timestamp=time.time())
                time.sleep(max(self.settle_time.get(), 0))
                self.motor_is_moving.set_motor_is_moving(0)
                val = self.readback.read()
                self._run_subs(sub_type=self.SUB_READBACK, value=val, timestamp=time.time())
                success = bool(
                    np.isclose(val[self.name]["value"], position, atol=self.tolerance)
                )
            except Exception as exc:  # pylint: disable=broad-except
                # Without this handler a failure would leave the status
                # pending until its timeout expires.
                logger.error(f"NPointAxis {self.name}: move to {position} failed: {exc}")
                self.motor_is_moving.set_motor_is_moving(0)
            self._done_moving(success=success)

        threading.Thread(target=move_and_finish, daemon=True).start()
        try:
            if wait:
                status_wait(status)
        except KeyboardInterrupt:
            self.stop()
            raise
        return status

    @property
    def axis_Id(self):
        """
        Return the axis_Id_alpha.
        """
        return self._axis_Id_alpha

    @axis_Id.setter
    def axis_Id(self, val: str):
        """
        Set the axis_Id_alpha and axis_Id_numeric based on the alpha value.

        Args:
            val (str): Single-character axis identifier.

        Raises:
            TypeError: Raised if val is not a string.
            ValueError: Raised if val is not exactly one character long.
        """
        if not isinstance(val, str):
            raise TypeError(f"Expected value of type str but received {type(val)}")
        if len(val) != 1:
            raise ValueError("Only single-character axis_Ids are supported.")
        self._axis_Id_alpha = val
        # "a"/"A" -> 0, "b"/"B" -> 1, ...
        self._axis_Id_numeric = ord(val.lower()) - 97

    @property
    def axis_Id_numeric(self):
        """
        Return the numeric value of the axis_Id.
        """
        return self._axis_Id_numeric

    @axis_Id_numeric.setter
    def axis_Id_numeric(self, val: int):
        """
        Set the axis_Id_numeric and axis_Id_alpha based on the numeric value.

        Args:
            val (int): Numeric axis identifier (0 -> "A", ..., 25 -> "Z").

        Raises:
            TypeError: Raised if val is not an int.
            ValueError: Raised if val is outside 0..25.
        """
        if not isinstance(val, int):
            raise TypeError(f"Expected value of type int but received {type(val)}")
        if not 0 <= val < 26:
            raise ValueError("Numeric value exceeds supported range.")
        # Store the letter as the alpha id and the number as the numeric id.
        self._axis_Id_alpha = chr(val + 97).capitalize()
        self._axis_Id_numeric = val

    @property
    def egu(self):
        """The engineering units (EGU) for positions"""
        return "um"

    def stage(self) -> list[object]:
        """Stage the device; defers to the parent implementation."""
        return super().stage()

    def unstage(self) -> list[object]:
        """Unstage the device; defers to the parent implementation."""
        return super().unstage()
if __name__ == "__main__":
    # Manual check against a real controller: connect, move axis A and print
    # its reading.
    npx = NPointAxis(axis_Id="A", name="npx", host="nPoint000003.psi.ch", port=23)
    # Not used below; kept because constructing it registers axis B.
    npy = NPointAxis(axis_Id="B", name="npy", host="nPoint000003.psi.ch", port=23)
    npx.controller.on()
    try:
        print("socket is open, axis is ready!")
        npx.move(10)
        print(npx.read())
    finally:
        # Close the socket even if the move or the read fails.
        npx.controller.off()