-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathRS.py
78 lines (61 loc) · 3.62 KB
/
RS.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""
This is the implementation of the Rayleigh-Sommerfeld algorithm. Refer to Goodman, Joseph W.
Introduction to Fourier optics. Roberts and Company Publishers, 2005, for details of the principle.
This code is adapted from a Matlab script by Xin Liu and was converted into a GPU
parallel-computing Python script by Haoyu Wei (haoyu.wei97@gmail.com).
This code and data are released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC). In a nutshell:
# The license is only for non-commercial use (commercial licenses can be obtained from authors).
# The material is provided as-is, with no warranties whatsoever.
# If you publish any code, data, or scientific work based on this, please cite our work.
Technical Paper:
Haoyu Wei, Xin Liu, Xiang Hao, Edmund Y. Lam, and Yifan Peng, "Modeling off-axis
diffraction with the least-sampling angular spectrum method," Optica 10, 959-962 (2023)
"""
import torch
import math
from tqdm import tqdm
class RSDiffraction_GPU():
    '''
    Rayleigh-Sommerfeld diffraction by direct summation, optimized for
    parallel computing.

    Every target sample (s, t) receives the sum, over all source samples
    (x, y), of ``E0(x, y) * h(r)`` with

        r = sqrt((s - x)**2 + (t - y)**2 + z**2)
        h = -1 / (2*pi) * (1j*k - 1/r) * exp(1j*k*r) * z / r**2

    The sum is NOT multiplied by a source-plane area element (dx * dy);
    callers that need the integral value must apply that factor themselves.

    x,s are horizontal. y,t are vertical. All grids are built with
    ``indexing='xy'``, so every 2-D array is laid out as
    (vertical, horizontal): the input field is (len(yvec), len(xvec)) and the
    output field is (len(tvec), len(svec)).

    The computation is tiled: at most ``block_sz**4`` kernel values are
    materialized at any time, which bounds peak device memory.
    '''

    def __init__(self, z, xvec, yvec, svec, tvec, wavelengths, device, block_sz=100) -> None:
        '''
        x,s are horizontal. y,t are vertical.

        Args:
            z: axial distance between the source and the target plane.
            xvec, yvec: 1-D source-plane sample coordinates.
            svec, tvec: 1-D target-plane sample coordinates.
            wavelengths: wavelength used for the wavenumber k = 2*pi / wavelengths
                (presumably a single scalar -- confirm against callers). Must be
                in the same length unit as z and the coordinates.
            device: torch device on which grids are stored and the sum is run.
            block_sz: tile edge length along each of the four axes. Depends on
                your memory, e.g., 128 needs ~24GB GPU memory.

        Raises:
            ValueError: if block_sz is not positive.
        '''
        if block_sz < 1:
            raise ValueError(f'block_sz must be a positive integer, got {block_sz}')
        self.device = device
        self.k = 2 * torch.pi / wavelengths
        self.z = z
        # Force float64 so the kernel comes out complex128, the dtype the field
        # is cast to in __call__; lists, ints or float32 inputs would otherwise
        # yield a complex64 kernel that cannot be contracted with the field.
        xvec, yvec = (torch.tensor(v, dtype=torch.float64) for v in (xvec, yvec))
        svec, tvec = (torch.tensor(v, dtype=torch.float64) for v in (svec, tvec))
        xx, yy = torch.meshgrid(xvec, yvec, indexing='xy')  # (len(yvec), len(xvec))
        ss, tt = torch.meshgrid(svec, tvec, indexing='xy')  # (len(tvec), len(svec))
        self.ss, self.tt = ss.to(device), tt.to(device)
        self.xx, self.yy = xx.to(device), yy.to(device)
        self.block_sz = block_sz

    def _blocks(self, length):
        # Consecutive slices of at most block_sz elements covering range(length).
        return [slice(start, start + self.block_sz)
                for start in range(0, length, self.block_sz)]

    def __call__(self, E0):
        '''
        Propagate a source-plane field to the target plane.

        Args:
            E0: 2-D complex field sampled on the source grid, laid out as
                (len(yvec), len(xvec)).

        Returns:
            numpy.ndarray (complex128) of shape (len(tvec), len(svec)).
        '''
        E0 = torch.tensor(E0, dtype=torch.complex128, device=self.device)
        # 'xy' meshgrids are (vertical, horizontal): dim 0 follows y / t and
        # dim 1 follows x / s. Each loop must therefore be sized from the axis
        # it actually slices, otherwise non-square grids lose rows or columns.
        LY, LX = E0.shape[-2:]
        LT, LS = self.ss.shape
        Eout = torch.zeros((LT, LS), dtype=E0.dtype, device=E0.device)
        for t_rows in tqdm(self._blocks(LT), desc='tvec', position=0):
            for s_cols in tqdm(self._blocks(LS), desc='svec', position=1, leave=False):
                ss_ = self.ss[t_rows, s_cols]
                tt_ = self.tt[t_rows, s_cols]
                block_sum = torch.zeros_like(ss_, dtype=E0.dtype)
                for y_rows in tqdm(self._blocks(LY), desc='yvec', position=2, leave=False):
                    for x_cols in tqdm(self._blocks(LX), desc='xvec', position=3, leave=False):
                        E0_ = E0[y_rows, x_cols]
                        # Two trailing singleton axes let the source tile broadcast
                        # against the target tile: result axes are (y, x, t, s).
                        xx_st = self.xx[y_rows, x_cols][..., None, None]
                        yy_st = self.yy[y_rows, x_cols][..., None, None]
                        r = torch.sqrt((ss_ - xx_st)**2 + (tt_ - yy_st)**2 + self.z**2)
                        h = -1 / (2 * torch.pi) * (1j * self.k - 1 / r) * torch.exp(1j * self.k * r) * self.z / r**2
                        # Contract the two source axes, keep the two target axes.
                        block_sum += torch.einsum('yx,yxts->ts', E0_, h)
                Eout[t_rows, s_cols] = block_sum
        return Eout.cpu().numpy()