Skip to content

Commit

Permalink
Removing an unused function.
Browse files Browse the repository at this point in the history
  • Loading branch information
kaanaksit committed Dec 19, 2023
1 parent 7c63f55 commit b46c96a
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 45 deletions.
48 changes: 4 additions & 44 deletions odak/learn/wave/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,45 +210,6 @@ def evaluate(self, input_image, target_image, plane_id = 0):
return loss



def reconstruct(self, hologram_phases):
"""
Internal function to reconstruct a given hologram.
Parameters
----------
hologram_phases : torch.tensor
A monochrome hologram phase [mxn].
Returns
-------
reconstruction_intensities : torch.tensor
Reconstructed frames.
reconstruction_intensity : torch.tensor
Reconstructed image.
peak_intensity : float
Peak intensity in the reconstructed image.
"""
torch.no_grad()
reconstruction_intensities = torch.zeros(
self.number_of_frames,
self.number_of_depth_layers,
self.number_of_channels,
self.resolution[0] * self.scale_factor,
self.resolution[1] * self.scale_factor,
device = self.device
)
for frame_id in range(self.number_of_frames):
for depth_id in range(self.number_of_depth_layers):
for channel_id in range(self.number_of_channels):
laser_power = self.propagator_get_laser_powers()[frame_id][channel_id]
hologram = generate_complex_field(laser_power * self.amplitude, hologram_phases[frame_id] * self.phase_scale[channel_id])
reconstruction_field = self.propagator(hologram, channel_id, depth_id)
reconstruction_intensities[frame_id, depth_id, channel_id] = calculate_amplitude(reconstruction_field) ** 2
return reconstruction_intensities


def double_phase_constrain(self, phase, phase_offset):
"""
Internal function to constrain a given phase similarly to double phase encoding.
Expand Down Expand Up @@ -366,10 +327,11 @@ def gradient_descent(self, number_of_iterations=100, weights=[1., 1., 0., 0.]):
self.amplitude[1::self.scale_factor, 1::self.scale_factor] = 0.
else:
phase_scaled = phase
scaled_phase = phase_scaled * self.phase_scale[channel_id]
laser_power = laser_powers[frame_id][channel_id]
amplitude = laser_power * self.amplitude
hologram = generate_complex_field(amplitude, scaled_phase)
hologram = generate_complex_field(
laser_power * self.amplitude,
phase_scaled * self.phase_scale[channel_id]
)
reconstruction_field = self.propagator(hologram, channel_id, depth_id)
intensity = calculate_amplitude(reconstruction_field) ** 2
reconstruction_intensities[frame_id, channel_id] += intensity
Expand Down Expand Up @@ -415,12 +377,10 @@ def gradient_descent(self, number_of_iterations=100, weights=[1., 1., 0., 0.]):
t.set_description(description)
del total_loss
del loss_image
del scaled_phase
del reconstruction_field
del reconstruction_intensities
del intensity
del phase
del amplitude
del hologram
logging.warning(description)
return hologram_phases.detach()
Expand Down
2 changes: 1 addition & 1 deletion odak/learn/wave/propagators.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,5 +351,5 @@ def reconstruct(self, hologram_phases, amplitude = None, no_grad = True):
frame_id,
depth_id,
channel_id
] = calculate_amplitude(reconstruction_field) ** 2
] = calculate_amplitude(reconstruction_field).detach().clone() ** 2
return reconstruction_intensities

0 comments on commit b46c96a

Please sign in to comment.