From b58fdc412b39532b7d8b099e02f42b671e6db908 Mon Sep 17 00:00:00 2001
From: Dean Hazineh
Date: Thu, 9 Jan 2025 11:39:46 -0500
Subject: [PATCH] test mod

---
 dflat/GDSII/assemble.py      | 20 +++++----
 dflat/render/fft_convolve.py | 22 +++++-----
 dflat/render/util_meas.py    | 82 +++++++++++++++++++++++++++++++++---
 docs/api/rcwa.rst            |  3 +-
 4 files changed, 100 insertions(+), 27 deletions(-)

diff --git a/dflat/GDSII/assemble.py b/dflat/GDSII/assemble.py
index 8bebd4b..9856c58 100644
--- a/dflat/GDSII/assemble.py
+++ b/dflat/GDSII/assemble.py
@@ -213,7 +213,9 @@ def assemble_standard_shapes(
     if cell_fun == gdspy.Round:
         if len(shape_params) == 1:
             shape_params = [shape_params[0], shape_params[0]]
-        shape = cell_fun((xoffset, yoffset), shape_params)
+        shape = cell_fun(
+            (xoffset, yoffset), shape_params, number_of_points=number_of_points
+        )
     elif cell_fun == gdspy.Rectangle:
         shape_params += [xoffset, yoffset]
         shape = cell_fun((xoffset, yoffset), shape_params)
@@ -221,17 +223,17 @@
     else:
         raise ValueError
     cell.add(shape)
 
-    # Add lens markers
-    hx = cell_size[1] * pshape[1] / gds_unit
-    hy = cell_size[0] * pshape[0] / gds_unit
-    ms = marker_size / gds_unit
-    cell_annot = lib.new_cell(f"TEXT_{unique_id}")
-    add_marker_tag(cell_annot, ms, hx, hy)
+    # # Add lens markers
+    # hx = cell_size[1] * pshape[1] / gds_unit
+    # hy = cell_size[0] * pshape[0] / gds_unit
+    # ms = marker_size / gds_unit
+    # cell_annot = lib.new_cell(f"TEXT_{unique_id}")
+    # add_marker_tag(cell_annot, ms, hx, hy)
 
-    # Create top-level cell and add references
+    # # Create top-level cell and add references
     top_cell = lib.new_cell(f"TOP_CELL_{unique_id}")
     top_cell.add(gdspy.CellReference(cell))
-    top_cell.add(gdspy.CellReference(cell_annot))
+    # top_cell.add(gdspy.CellReference(cell_annot))
 
     # Write GDS file
     lib.write_gds(savepath)
diff --git a/dflat/render/fft_convolve.py b/dflat/render/fft_convolve.py
index 0b54116..8af5b05 100644
--- a/dflat/render/fft_convolve.py
+++ b/dflat/render/fft_convolve.py
@@ -6,9 +6,7 @@
 from dflat.radial_tranforms import resize_with_crop_or_pad
 
-
-
-def general_convolve(image, filter, rfft=False, mode="valid"):
+def general_convolve(image, filter, rfft=False, mode="valid", adjoint=False):
     """Runs the Fourier space convolution between an image and filter, where the filter kernels may have a different
     size from the image shape.
 
     Args:
@@ -40,7 +38,7 @@
     filter_resh = resize_with_crop_or_pad(filter, *image_shape[-2:], radial_flag=False)
 
     ### Run the convolution (Defualt to using a checkpoint of the fourier transform)
-    image = checkpoint(fourier_convolve, image, filter_resh, rfft)
+    image = checkpoint(fourier_convolve, image, filter_resh, rfft, adjoint)
     image = torch.real(image)
 
     if mode == "valid":
@@ -90,27 +88,31 @@ def weiner_deconvolve(image, filter, const=1e-4, abs=False):
     return image
 
 
-def fourier_convolve(image, filter, rfft=False):
+def fourier_convolve(image, filter, rfft=False, adjoint=False):
     """Computes the convolution of two signals (real or complex) using frequency space multiplcation. Convolution is done over the two inner-most dimensions.
 
     Args:
         `image` (float or complex): Image to apply filter to, of shape [..., Ny, Nx]
         `filter` (float or complex): Filter kernel; The kernel must be the same shape as the image
+        `adjoint` (bool, optional): If True, the filter's Fourier transform is conjugated, applying the adjoint of the convolution. Defaults to False.
 
     Returns:
         complex: Image with filter convolved, same shape as input
     """
+    # Ensure inputs are complex
     TORCH_ZERO = torch.tensor(0.0).to(dtype=image.dtype, device=image.device)
 
     if rfft:
-        fourier_product = rfft2(ifftshift(image)) * rfft2(ifftshift(filter))
+        kf = rfft2(ifftshift(filter))
+        kf = torch.conj(kf) if adjoint else kf
+        fourier_product = rfft2(ifftshift(image)) * kf
         fourier_product = fftshift(irfft2(fourier_product))
     else:
         image = torch.complex(image, TORCH_ZERO) if not image.is_complex() else image
-        filter = (
-            torch.complex(filter, TORCH_ZERO) if not filter.is_complex() else filter
-        )
-        fourier_product = fft2(ifftshift(image)) * fft2(ifftshift(filter))
+        kf = torch.complex(filter, TORCH_ZERO) if not filter.is_complex() else filter
+        kf = fft2(ifftshift(kf))
+        kf = torch.conj(kf) if adjoint else kf
+        fourier_product = fft2(ifftshift(image)) * kf
         fourier_product = fftshift(ifft2(fourier_product))
 
     return fourier_product
diff --git a/dflat/render/util_meas.py b/dflat/render/util_meas.py
index 9d2162e..49e19e2 100644
--- a/dflat/render/util_meas.py
+++ b/dflat/render/util_meas.py
@@ -9,11 +9,11 @@
 def hsi_to_rgb(
     hsi,
     wavelength_set_m,
-    demosaic=False,
     gamma=False,
     tensor_ordering=False,
     normalize=True,
     projection="Basler_Bayer",
+    process="ideal",
 ):
     """Converts a batched hyperspectral datacube of shape [minibatch, Height, Width, Channels] to RGB. If tensor_ordering
     is true, input may instead be passed with the more common tensor shape [B, Ch, H, W]. The CIE1931 color matching functions are used by default.
@@ -21,25 +21,30 @@
     Args:
         hsi (float): Hyperspectral cube with shsape [B, H, W, Ch] or [B, Ch, H, W] if tensor_ordering is True.
         wavelength_set_m (float): List of wavelengths corresponding to the input channel dimension.
-        demosaic (bool, optional): If True, a Bayer filter mask is applied to the RGB images and then interpolation is used to match experiment. Defaults to True.
         gamma (bool, optional): Applies gamma transformation to the input images. Defaults to True.
         tensor_ordering (bool, optional): If True, allows passing in a HSI with the more covenient pytorch to_tensor form. Defaults to False.
         normalize (bool, optional): If true, the returned projection is max normalized to 1.
         projection (str, optional): Either "CIE1931" or "Basler_Bayer". Specifies the color spectral curves.
-
-    Returns:
+        process (str, optional): One of 'ideal', 'raw', or 'demosaic'. 'ideal' returns 3 color channels with no spatial resolution loss; 'demosaic' applies the Bayer mask and then interpolates; 'raw' returns the 1-channel spatially mosaiced measurement. Defaults to 'ideal'.
+    Returns:
         RGB: Stack of images with output channels=3
     """
     assert projection in [
         "CIE1931",
         "Basler_Bayer",
     ], "Projection must be one of ['CIE1931', 'Basler_Bayer']."
+    assert process in [
+        "ideal",
+        "raw",
+        "demosaic",
+    ], "Process must be one of ['ideal', 'raw', 'demosaic']."
 
     input_tensor = torch.is_tensor(hsi)
     if not input_tensor:
         hsi = torch.tensor(hsi)
     if tensor_ordering:
         hsi = hsi.transpose(-3, -1).transpose(-3, -2).contiguous()
+
     assert (
         len(wavelength_set_m) == hsi.shape[-1]
     ), "List of wavelengths should match the input channel dimension."
@@ -54,14 +59,22 @@ def hsi_to_rgb(
     rgb = torch.matmul(hsi, spec)
     scale = torch.amax(rgb, dim=(-3, -2, -1), keepdim=True)
 
-    if normalize:
+
+    if process == "demosaic":
+        rgb = bayer_interpolate(bayer_mask(rgb))
+    elif process == "raw":
+        rgb = bayer_mask(rgb)
+        rgb = torch.sum(rgb, axis=-1, keepdims=True)
+
+    if normalize or gamma:
         rgb = rgb / scale
 
-    if demosaic:
-        rgb = bayer_interpolate(bayer_mask(rgb))
+
     if gamma:
         rgb = gamma_correction(rgb)
+
     if tensor_ordering:
         rgb = rgb.transpose(-3, -1).transpose(-2, -1).contiguous()
+
     if not input_tensor:
         rgb = rgb.cpu().numpy()
 
@@ -185,3 +198,58 @@ def photons_to_ADU(
         return torch.clip(electrons_signal, min=0)
     else:
         return electrons_signal
+
+
+def rgb_to_hsi_adjoint(
+    rgb,
+    wavelength_set_m,
+    tensor_ordering=False,
+    normalize=True,
+    projection="Basler_Bayer",
+):
+    """Computes the adjoint approximation of the rgb to hsi projection (used for some algorithm initializations).
+
+    Args:
+        rgb (float): Three channel RGB measurement
+        wavelength_set_m (float): List of wavelengths corresponding to the input channel dimension.
+        tensor_ordering (bool, optional): If True, allows passing in the RGB with the more convenient pytorch to_tensor form. Defaults to False.
+        normalize (bool, optional): If true, the returned projection is max normalized to 1.
+        projection (str, optional): Either "CIE1931" or "Basler_Bayer". Specifies the color spectral curves.
+
+    Returns:
+        float: Hyperspectral initialization
+    """
+
+    assert projection.lower() in [
+        "cie1931",
+        "basler_bayer",
+    ], "Projection must be one of ['cie1931', 'basler_bayer']."
+
+    input_tensor = torch.is_tensor(rgb)
+    if not input_tensor:
+        rgb = torch.tensor(rgb)
+
+    if tensor_ordering:
+        rgb = rgb.transpose(-3, -1).transpose(-3, -2).contiguous()  # ... h w c
+    assert 3 == rgb.shape[-1], "Channel dimension must be 3 for adjoint transform"
+
+    if projection.lower() == "cie1931":
+        spec = get_rgb_bar_CIE1931(wavelength_set_m * 1e9)
+    elif projection.lower() == "basler_bayer":
+        spec, _ = get_QETrans_Basler_Bayer(wavelength_set_m * 1e9)
+        spec = np.concatenate([spec[:, 0:1], spec[:, 2:]], axis=-1)
+        spec = spec / np.sum(spec, axis=0, keepdims=True)
+
+    spec = torch.tensor(spec).type_as(rgb)  # [C, 3]
+    out = torch.matmul(rgb, spec.T)
+
+    if normalize:
+        out = out / torch.amax(out, dim=(-3, -2, -1), keepdim=True)
+
+    if tensor_ordering:
+        out = out.transpose(-3, -1).transpose(-2, -1).contiguous()
+
+    if not input_tensor:
+        out = out.cpu().numpy()
+
+    return out
diff --git a/docs/api/rcwa.rst b/docs/api/rcwa.rst
index dcf06f4..a7c5d88 100644
--- a/docs/api/rcwa.rst
+++ b/docs/api/rcwa.rst
@@ -13,5 +13,6 @@ Public Functions
    :members:
    :undoc-members:
    :show-inheritance:
-   :inherited-members:
+
+   .. automethod:: forward
 
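
Usage sketch, not part of the patch: a minimal dot-product check of the new adjoint flag on fourier_convolve. The import path dflat.render.fft_convolve is assumed from the file paths in this diff; with adjoint=True the filter's spectrum is conjugated, so for real inputs the two inner products below should agree.

    import torch
    from dflat.render.fft_convolve import fourier_convolve  # import path assumed from this diff

    torch.manual_seed(0)
    x = torch.rand(1, 1, 64, 64)  # image
    k = torch.rand(1, 1, 64, 64)  # filter kernel, same spatial size as the image
    y = torch.rand(1, 1, 64, 64)  # probe signal

    Ax = fourier_convolve(x, k, rfft=True, adjoint=False)  # forward convolution (real output via rfft)
    Aty = fourier_convolve(y, k, rfft=True, adjoint=True)  # adjoint: filter spectrum is conjugated

    # For a linear operator A, <A x, y> should equal <x, A^H y> up to numerical precision.
    print(torch.sum(Ax * y).item(), torch.sum(x * Aty).item())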
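
Usage sketch, not part of the patch: the new process modes of hsi_to_rgb together with rgb_to_hsi_adjoint. The import path dflat.render.util_meas is assumed from this diff, and wavelength_set_m is given as a numpy array since the code scales it by 1e9.

    import numpy as np
    import torch
    from dflat.render.util_meas import hsi_to_rgb, rgb_to_hsi_adjoint  # import path assumed from this diff

    wavelength_set_m = np.linspace(450e-9, 650e-9, 8)
    hsi = torch.rand(1, 8, 64, 64)  # [B, Ch, H, W], so tensor_ordering=True below

    # 'ideal' keeps three full-resolution color channels; 'raw' returns the 1-channel Bayer mosaic.
    rgb_ideal = hsi_to_rgb(hsi, wavelength_set_m, tensor_ordering=True, process="ideal")
    raw_mosaic = hsi_to_rgb(hsi, wavelength_set_m, tensor_ordering=True, process="raw")

    # Approximate back-projection of an RGB measurement, e.g. to initialize a reconstruction.
    hsi_init = rgb_to_hsi_adjoint(rgb_ideal, wavelength_set_m, tensor_ordering=True)
    print(rgb_ideal.shape, raw_mosaic.shape, hsi_init.shape)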