-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinterpolate_potentials.py
executable file
·144 lines (98 loc) · 4.72 KB
/
interpolate_potentials.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#!/software/anaconda3/bin/python
# --- Import standard modules to the python path.
import os
import pdb
import argparse
import pdb
import time
import pickle
import numpy as np
import pandas as pd
import itertools
import multiprocessing
from functools import partial
import astropy.units as u
import astropy.constants as C
from galpy.potential import RazorThinExponentialDiskPotential, DoubleExponentialDiskPotential, NFWPotential
from galpy.potential import interpRZPotential
from kickIT import utils
# --- Specify arguments for the interpolation function
def parse_commandline():
    """Parse the arguments given on the command line.

    Returns
    -------
    argparse.Namespace
        Parsed options with the attributes ``gal_path``, ``multiproc``,
        ``interp_path``, ``Rgrid``, ``Zgrid``, ``Rgrid_max`` and ``Zgrid_max``.

    Raises
    ------
    SystemExit
        Raised by argparse when the arguments are invalid or when the
        mandatory ``--gal-path`` option is missing.
    """
    parser = argparse.ArgumentParser()

    # default information
    # --gal-path is mandatory: without it the script would only fail later,
    # with an unhelpful TypeError, when trying to open ``None``.
    parser.add_argument('-g', '--gal-path', type=str, required=True, help="Path to pickled gal file that we want to create interpolants for.")
    parser.add_argument('-mp', '--multiproc', type=str, default=None, help="If specified, will parallelize over the number of cores provided as an argument. Can also use the string 'max' to parallelize over all available cores. Default is None.")
    parser.add_argument('--interp-path', type=str, default='./interp.pkl', help="Path to where the interpolation file will be saved. Default is './interp.pkl'.")

    # defining grid properties for interpolations
    parser.add_argument('-rg', '--Rgrid', type=int, default=100, help="Number of gridpoints for the R-component of the interpolation model. Default is 100.")
    parser.add_argument('-zg', '--Zgrid', type=int, default=50, help="Number of gridpoints for the Z-component of the interpolation model. Default is 50.")
    parser.add_argument('--Rgrid-max', type=float, default=1e3, help="Maximum R value for interpolated potentials. Default is 1e3.")
    parser.add_argument('--Zgrid-max', type=float, default=1e2, help="Maximum Z value for interpolated potentials. Default is 1e2.")

    return parser.parse_args()
def main(args):
    """Build interpolated potentials for a pickled galaxy and save them.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line options. The attributes read here are
        ``gal_path``, ``interp_path``, ``multiproc``, ``Rgrid``, ``Zgrid``,
        ``Rgrid_max`` and ``Zgrid_max``.
    """
    # --- read galaxy file
    # NOTE(review): unpickling can execute arbitrary code; only point
    # --gal-path at files from a trusted source.
    with open(args.gal_path, 'rb') as gal_file:
        gal = pickle.load(gal_file)

    # --- create interpolants for the potentials in gal class
    interps = construct_interpolants(
        gal,
        multiproc=args.multiproc,
        Rgrid=args.Rgrid,
        Zgrid=args.Zgrid,
        Rgrid_max=args.Rgrid_max,
        Zgrid_max=args.Zgrid_max,
    )

    # --- save the interpolants; the context manager guarantees the file is
    # flushed and closed even if pickling fails part-way through
    with open(args.interp_path, 'wb') as interp_file:
        pickle.dump(interps, interp_file)
def construct_interpolants(gal, multiproc=None, Rgrid=500, Zgrid=100, Rgrid_max=1000, Zgrid_max=100, ro=8*u.kpc, vo=220*u.km/u.s):
    """Create interpolants for the combined potentials stored on ``gal``.

    One interpolated potential is built per entry of
    ``gal.full_potentials_natural``, either serially or with a process pool.

    Parameters
    ----------
    gal : object
        Galaxy object providing ``full_potentials_natural`` (iterable of
        potentials) and, for the serial progress messages, ``redz``.
    multiproc : int, str or None, optional
        Number of worker processes, or ``'max'`` to use every available
        core. A falsy value (the default) runs serially.
    Rgrid, Zgrid : int, optional
        Number of grid points in R and in z.
    Rgrid_max, Zgrid_max : float, optional
        Upper edge of the R and z grids, in kpc. The R grid starts at
        1e-4 kpc and the z grid at 0.
    ro, vo : astropy Quantity, optional
        Distance and velocity scales. ``ro`` converts the grid to natural
        units; both are forwarded to ``interp_func``.

    Returns
    -------
    list
        Interpolated potentials, in the same order as
        ``gal.full_potentials_natural``.
    """
    print('Creating interpolation models of combined galactic potentials at each redshift...\n')

    # --- create the grid of rads and heights we will be using
    rad_range = np.asarray([1e-4, Rgrid_max])*u.kpc
    height_range = np.asarray([0, Zgrid_max])*u.kpc

    # need to convert Rs and Zs to natural units
    # (assumes ro is expressed in kpc so that the ratio is a pure number)
    rads = (rad_range / ro).value
    heights = (height_range / ro).value

    # galpy's logR=True reconstructs the grid as exp(linspace(*rgrid)), so
    # the end points must be *natural* logs; log10 would shrink the covered
    # range well below Rgrid_max.
    logrs = (*np.log(rads), Rgrid)
    zs = (*heights, Zgrid)

    # --- set up the interpolation function
    # ro/vo are forwarded so the interpolator uses the same scales as the
    # unit conversion above rather than silently falling back to defaults.
    func = partial(interp_func, rgrid=logrs, zgrid=zs, ro=ro, vo=vo)

    # --- enable multiprocessing, if specified
    if multiproc:
        if multiproc == 'max':
            mp = multiprocessing.cpu_count()
        else:
            mp = int(multiproc)

        # the context manager shuts the worker processes down once the map
        # has returned (or if it raises)
        with multiprocessing.Pool(mp) as pool:
            start = time.time()
            print('Parallelizing interpolations over {0:d} cores...\n'.format(mp))
            interpolated_potentials = pool.map(func, gal.full_potentials_natural)
            stop = time.time()
        print(' finished! It took {0:0.2f}s\n'.format(stop-start))

    # otherwise, do this in serial
    else:
        print('Interpolating potentials in serial...\n')
        interpolated_potentials = []
        for ii, data in enumerate(gal.full_potentials_natural):
            start = time.time()
            ip = func(data)
            end = time.time()
            print(' interpolated potential for step {0:d} (z={1:0.2f}) created in {2:0.2f}s...'.format(ii,gal.redz[ii],end-start))
            interpolated_potentials.append(ip)

    return interpolated_potentials
# --- define interpolating function
def interp_func(potentials, rgrid, zgrid, ro=8*u.kpc, vo=220*u.km/u.s):
    """Wrap ``potentials`` in an ``interpRZPotential`` on the given grid.

    Parameters
    ----------
    potentials
        Potential (or combination of potentials) handed to
        ``interpRZPotential`` as its first argument.
    rgrid, zgrid : tuple
        Grid specifications forwarded unchanged; ``rgrid`` is interpreted
        as logarithmic because ``logR=True`` is set.
    ro, vo : astropy Quantity, optional
        Distance and velocity scales forwarded to ``interpRZPotential``.

    Returns
    -------
    interpRZPotential
        Interpolator with R- and z-force interpolation enabled and
        symmetry about z=0 assumed.
    """
    # fixed interpolation settings used for every potential
    fixed_options = dict(logR=True, interpRforce=True, interpzforce=True, zsym=True)
    return interpRZPotential(potentials, rgrid=rgrid, zgrid=zgrid, ro=ro, vo=vo, **fixed_options)
# MAIN FUNCTION
if __name__ == '__main__':
    # Script entry point: parse the command line and hand the options to main().
    main(parse_commandline())