Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support CSX Jan 2020 zone plate experiment #75

Merged
merged 17 commits into from
Feb 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion nsls2ptycho/_version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
#TODO: use versioneer instead
__version__ = '1.4.0b2'
__version__ = '1.4.0b7'
101 changes: 75 additions & 26 deletions nsls2ptycho/core/CSX_databroker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
except FileNotFoundError:
print("csx.yml not found. Unable to access CSX's database.", file=sys.stderr)
csx_db = None
from csxtools.utils import get_fastccd_images, get_images_to_4D
from csxtools.utils import get_fastccd_images, get_images_to_4D, get_fastccd_flatfield


# ***************************** "Public API" *****************************
Expand All @@ -28,7 +28,7 @@
# ************************************************************************


# CSX's fastccd detector has a vertical dark stride
# CSX's fastccd detector has a vertical dark stripe
cs = 486 # pixel start point
cl = 28 # width
cedge = 988 # detector edge
Expand Down Expand Up @@ -61,29 +61,46 @@ def load_metadata(db, scan_num:int, det_name:str=''):
energy_kev = scan_supp.pgm_energy_setpoint.iat[0] / 1E+3

# get scan_type, x_range, y_range, dr_x, dr_y
extents = header.start['extents']
if scan_type.endswith('grid_scan'):
extents = header.start['extents']
x_num = plan_args['args'][7]
y_num = plan_args['args'][3]
x_range = (extents[0][1] - extents[0][0]) * 1E+3
y_range = (extents[1][1] - extents[1][0]) * 1E+3
dr_x = x_range / (x_num - 1)
dr_y = y_range / (y_num - 1)

# get points
points = np.zeros((2, nz))
points[0] = scan_data.nanop_bx * 1E+3
points[1] = scan_data.nanop_bz * 1E+3
elif scan_type == 'spiral_continuous':
extents = header.start['extents']
# The customized plan messed up...
x_range = (extents[0][1] - extents[0][0])/2 * 1E+3
y_range = (extents[1][1] - extents[1][0])/2 * 1E+3
# not available in CSX
dr_x = 0
dr_y = 0
else:
raise NotImplementedError("Ask Wen Hu to explain to Leo Fang.")

# get points
points = np.zeros((2, nz))
points[0] = scan_data.nanop_bx * 1E+3
points[1] = scan_data.nanop_bz * 1E+3
# get points
points = np.zeros((2, nz))
points[0] = scan_data.nanop_bx * 1E+3
points[1] = scan_data.nanop_bz * 1E+3
elif scan_type == 'scan_nd': # for custom plans with a zone plate
# this is a mesh scan, but a lot of info is missing in the plan...
# set the missing values to zero
x_range = 0
y_range = 0
dr_x = 0
dr_y = 0

# get points (from top motors, not bottom ones!)
points = np.zeros((2, nz))
points[0] = scan_data.nanop_tx * 1E+3
points[1] = scan_data.nanop_tz * 1E+3
else:
raise NotImplementedError("Ask Wen Hu to explain to Leo Fang.")
# get angle, ic
angle = 90 - scan_supp.tardis_theta.iat[0] # TODO(leofang): verify this

Expand Down Expand Up @@ -142,10 +159,6 @@ def save_data(db, param, scan_num:int, nx_prb:int, ny_prb:int, cx:int, cy:int, t
print("[WARNING] Bad-pixel removal is not yet supported for CSX.", file=sys.stderr)
bad_pixels = None

if zero_out is not None:
print("[WARNING] Zero masks are not yet supported for CSX.", file=sys.stderr)
zero_out = None

det_distance_m = param.z_m
det_pixel_um = param.ccd_pixel_um
num_frame = param.nz
Expand Down Expand Up @@ -174,10 +187,10 @@ def save_data(db, param, scan_num:int, nx_prb:int, ny_prb:int, cx:int, cy:int, t

# get raw data
images_stack = get_images_to_4D(itr)
raw_mean_data = _preprocess_image(images_stack)
raw_mean_data = _preprocess_image(images_stack, bad_pixels, zero_out)

# construct data array
diffamp = np.empty((num_frame, nx_prb, ny_prb))
diffamp = np.empty((num_frame, nx_prb//2*2, ny_prb//2*2))
diffamp[...] = np.rot90(raw_mean_data[:, cy-ny_prb//2:cy+ny_prb//2, cx-nx_prb//2:cx+nx_prb//2], axes=(2, 1)) # equivalent to np.flipud(arr).T
diffamp = np.fft.fftshift(diffamp, axes=(1, 2))
diffamp = np.sqrt(diffamp)
Expand Down Expand Up @@ -216,34 +229,58 @@ def save_data(db, param, scan_num:int, nx_prb:int, ny_prb:int, cx:int, cy:int, t
os.remove(symlink_path)
os.symlink(file_path, symlink_path)

'''
For actual scans:
key : value = (scan num, dark8 num, dark2 num, dark1 num) : slicerator

For flat-field scans:
key : value = (scan num, dark8 num, dark2 num, dark1 num) : flat image
'''
scan_image_itr_cache = {}


def _load_scan_image_itr(db, scan_num:int, dark8ID:int=None, dark2ID:int=None, dark1ID:int=None):
bgnd8 = db[dark8ID] if (dark8ID is not None) else None
bgnd2 = db[dark2ID] if (dark2ID is not None) else None
bgnd1 = db[dark1ID] if (dark1ID is not None) else None
def _load_scan_image_itr(db, scan_num:int, dark8:int=None, dark2:int=None, dark1:int=None, key_flat:tuple=None):
bgnd8 = db[dark8] if (dark8 is not None) else None
bgnd2 = db[dark2] if (dark2 is not None) else None
bgnd1 = db[dark1] if (dark1 is not None) else None
if (bgnd8 is None) and (bgnd2 is None) and (bgnd1 is None):
dark_headers = None
else:
dark_headers = (bgnd8, bgnd2, bgnd1)
silcerator = get_fastccd_images(db[scan_num], dark_headers=dark_headers)

flat_scan_num, flat_scan_dark8, flat_scan_num_dark2, flat_scan_dark1 = key_flat
if flat_scan_num is not None:
flat_im = scan_image_itr_cache.get(key_flat) # load flat field image from cache
if flat_im is None:
flat_im = get_fastccd_flatfield(db[flat_scan_num], dark=(db[flat_scan_dark8], db[flat_scan_num_dark2], db[flat_scan_dark1]))
scan_image_itr_cache[key_flat] = flat_im
roi = [0, 0, 960, 1000] # the entire detector; TODO: remove the hard-coded value?
else:
flat_im = None
roi = None

silcerator = get_fastccd_images(db[scan_num], dark_headers=dark_headers, flat=flat_im, roi=roi)
return silcerator


def get_single_image(db, frame_num, scan_num:int, dark8ID:int=None, dark2ID:int=None, dark1ID:int=None):
def get_single_image(db, frame_num, scan_num:int, dark8:int=None, dark2:int=None, dark1:int=None,
flat_scan_num:int=None, flat_scan_dark8:int=None, flat_scan_num_dark2:int=None, flat_scan_dark1:int=None):
# TODO: use mds_table here
key = (scan_num, dark8ID, dark2ID, dark1ID)
key_flat = (flat_scan_num, flat_scan_dark8, flat_scan_num_dark2, flat_scan_dark1)
key = (scan_num, dark8, dark2, dark1, key_flat)

if key in scan_image_itr_cache:
return _preprocess_image(scan_image_itr_cache[key][frame_num])
else:
itr = _load_scan_image_itr(db, scan_num, dark8ID, dark2ID, dark1ID)
itr = _load_scan_image_itr(db, scan_num, dark8, dark2, dark1, key_flat)
scan_image_itr_cache[key] = itr
return _preprocess_image(itr[frame_num])


def _preprocess_image(img):
def _preprocess_image(img, bad_pixels=None, zero_out=None):
# TODO: support this
assert bad_pixels is None

# average over the axis corresponding to the same scan point,
# and then remove the stripe
if img.ndim == 3:
Expand All @@ -259,6 +296,13 @@ def _preprocess_image(img):
img[img < 0.] = 0. # needed due to dark subtraction
img = np.mean(img, axis=axis)
img = stack((img[..., :cs], img[..., cs+cl:cedge]))

# The coordinates we got from GUI have the stripe removed
if zero_out is not None:
for blue_roi in zero_out:
x0, y0, w, h = blue_roi
img[..., y0:y0+h, x0:x0+w] = 0.

return img


Expand All @@ -270,6 +314,9 @@ def _expand_partial_key(scan_num:int):
if scan_num in key:
# sort key based on the number of provided dark IDs
# prefer a complete info
if key[4][0] is not None: # found match with flat-field
related_keys[5] = key
break # shortcut
if key[3] is not None:
related_keys[4] = key
break # shortcut
Expand All @@ -280,7 +327,9 @@ def _expand_partial_key(scan_num:int):
else:
related_keys[1] = key

if 4 in related_keys:
if 5 in related_keys:
key = related_keys[5]
elif 4 in related_keys:
key = related_keys[4]
elif 3 in related_keys:
key = related_keys[3]
Expand All @@ -291,7 +340,7 @@ def _expand_partial_key(scan_num:int):
print("[WARNING] Proceeding without dark IDs...", file=sys.stderr)
else:
raise ValueError("Data for scan number", scan_num, "not found. Forget to click load?")
print("Found", key, "from", related_keys)
print("Found", key) # "from", related_keys)

return key

Expand Down
2 changes: 1 addition & 1 deletion nsls2ptycho/core/HXN_databroker.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def save_data(db, param, scan_num:int, n:int, nn:int, cx:int, cy:int, threshold=
#print('depth of field: ', x_depth_of_field_m, y_depth_of_field_m)

# get data array
data = np.zeros((num_frame, n, nn)) # nz*nx*ny
data = np.zeros((num_frame, n//2*2, nn//2*2)) # nz*nx*ny
mask = []
for i in range(num_frame):
#print(param.mds_table.iat[i], file=sys.stderr)
Expand Down
3 changes: 3 additions & 0 deletions nsls2ptycho/core/databroker_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,6 @@ def _load_CSX():
raise RuntimeError("[WARNING] Cannot detect the beamline name. Databroker is disabled.")
except RuntimeError as ex:
print(ex, file=sys.stderr)


del config_path, hostname, json, os, platform, sys
6 changes: 3 additions & 3 deletions nsls2ptycho/core/ptycho_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ def __init__(self):

# param for Bragg mode
self.bragg_flag = False
self.bragg_theta = 69.41
self.bragg_gamma = 33.4
self.bragg_delta = 15.458
self.bragg_theta = 0.
self.bragg_gamma = 0.
self.bragg_delta = 0.

# partial coherence parameter
self.pc_flag = False
Expand Down
10 changes: 4 additions & 6 deletions nsls2ptycho/core/widgets/eventhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def set_curr_roi(self, ax, xy, width, height):

ax.figure.canvas.draw()

def _find_closest_rect(self, x, y, delta=2.):
def _find_closest_rect(self, x, y, delta=10.):
min_dist = 9999999.
ref_rect = None
ref_idx = -1
Expand Down Expand Up @@ -231,7 +231,7 @@ def on_press(event):
self.coord_x_ratio = event.xdata / event.x
self.coord_y_ratio = (ax.get_ylim()[0] - event.ydata) / event.y

ref_rect, ref_idx = self._find_closest_rect(event.xdata, event.ydata, delta=2.)
ref_rect, ref_idx = self._find_closest_rect(event.xdata, event.ydata)
self.rect_x0 = None
self.rect_y0 = None

Expand All @@ -241,11 +241,10 @@ def on_press(event):
self.rect_y0 = event.ydata

# make solid line for all existing roi
for rect in self.all_rect: rect.set_linestyle('solid')

for rect in self.all_rect:
rect.set_linestyle('solid')
# left click, select an existing roi
elif event.button == 1 and ref_rect is not None and ref_idx >= 0:

if event.dblclick:
clr = RED_EDGECOLOR
if self.ref_rect.get_edgecolor() == RED_EDGECOLOR:
Expand All @@ -260,7 +259,6 @@ def on_press(event):
# make solid line for all existing roi except the selected one
for rect in self.all_rect: rect.set_linestyle('solid')
self.ref_rect.set_linestyle('dashed')

# right click, delete the selected roi
elif event.button == 3 and ref_rect is not None and ref_idx >=0:
self.ref_rect = ref_rect
Expand Down
39 changes: 35 additions & 4 deletions nsls2ptycho/core/widgets/list_widget.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,56 @@
from PyQt5 import QtCore, QtGui, QtWidgets
from nsls2ptycho.core.utils import parse_range
from nsls2ptycho.ui import ui_list_dialog


#def _dragEnterEvent(listWidget, event):
# #event.setDropAction(QtCore.Qt.MoveAction)
# event.accept()
#
#def _dragMoveEvent(listWidget, event):
# #event.setDropAction(QtCore.Qt.MoveAction)
# event.accept()
#
#def _dropEvent(listWidget, event):
# #event.setDropAction(QtCore.Qt.CopyAction)
# event.accept()
# listWidget.addItem(str(event.mimeData().text()))


class ListWidget(QtWidgets.QDialog, ui_list_dialog.Ui_Form):
def __init__(self, parent=None):
super().__init__(parent)
self.setupUi(self)
QtWidgets.QApplication.setStyle('Plastique')
self.setObjectName("Assoc.Scans")

self.listWidget.setDragDropMode(QtWidgets.QAbstractItemView.DragDrop)
self.listWidget.setSelectionMode(QtWidgets.QAbstractItemView.ExtendedSelection)
#self.listWidget.setSelectionMode(QtWidgets.QAbstractItemView.SingleSelection)
self.listWidget.setAcceptDrops(True)
self.listWidget.setDragEnabled(True)
self.listWidget.setDropIndicatorShown(True)
self.listWidget.setDefaultDropAction(QtCore.Qt.MoveAction)
#self.listWidget.dragEnterEvent = lambda event: _dragEnterEvent(self.listWidget, event)
#self.listWidget.dragMoveEvent = lambda event: _dragMoveEvent(self.listWidget, event)
#self.listWidget.dropEvent = lambda event: _dropEvent(self.listWidget, event)

self.btn_add_item.clicked.connect(self.add_item)
self.btn_rm_item.clicked.connect(self.remove_item)
# self.btn_close.clicked.connect(self.close_dialog)
#

self.le_input.setToolTip(QtCore.QCoreApplication.translate("Assoc.Scans", "Set scan numbers and ranges. Example: 128-131, 137-139"))

# def close_dialog(self):
# self.destroy()

def add_item(self):
self.listWidget.addItem(str(self.le_input.text()))
items = parse_range(self.le_input.text(), batch_processing=False)
self.listWidget.addItems([str(item) for item in items])
self.le_input.setText('')
self.listWidget.sortItems()
#self.listWidget.sortItems()

def remove_item(self):
item = self.listWidget.takeItem(self.listWidget.currentRow())
self.listWidget.sortItems()
#self.listWidget.sortItems()
del item
Loading