Skip to content

Commit

Permalink
standardised import statements in __main__.py #25
Browse files Browse the repository at this point in the history
tested and runs; updated function names in dycore.physics.gas_dynamics.thermodynamic, utils.io, and data_assimilation.params
  • Loading branch information
ray-chew committed Mar 23, 2024
1 parent 58d0800 commit 09266bd
Show file tree
Hide file tree
Showing 13 changed files with 87 additions and 101 deletions.
2 changes: 1 addition & 1 deletion docs/source/apis/src.data_assimilation.params.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

.. autosummary::

da_params
init



Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ src.dycore.physics.gas\_dynamics.thermodynamic

.. autosummary::

ThermodynamicInit
init



Expand Down
2 changes: 1 addition & 1 deletion docs/source/apis/src.utils.io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
.. autosummary::

datetime
io
init
read_input


Expand Down
156 changes: 71 additions & 85 deletions src/__main__.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,46 @@
import numpy as np

# dependencies of the atmospheric flow solver
from dycore.discretisation.grid import grid_init
from dycore.discretisation.time_update import do
from dycore.utils.boundary import (
set_explicit_boundary_data,
set_ghostnodes_p2,
get_tau_y,
)
from dycore.utils.variable import States, Vars
from dycore.physics.gas_dynamics.thermodynamic import ThermodynamicInit
from dycore.physics.low_mach.mpv import MPV
import dycore.physics.hydrostatics as hydrostatic
import dycore.discretisation.grid as dis_grid
import dycore.discretisation.time_update as dis_time_update
import dycore.utils.boundary as bdry
import dycore.utils.variable as var
import dycore.physics.low_mach.mpv as lm_var
import dycore.physics.hydrostatics as hydrostatic
import dycore.physics.gas_dynamics.thermodynamic as gd_thermodynamics

# dependencies of the parallelisation by dask
from dask.distributed import Client, progress

# dependencies of the data assimilation module
from data_assimilation.params import da_params
from data_assimilation.utils import (
ensemble,
ensemble_inflation,
HSprojector_2t3D,
HSprojector_3t2D,
sparse_obs_selector,
obs_noiser,
)
from data_assimilation.letkf import da_interface, prepare_rloc

# from data_assimilation.letkf import analysis as letkf_analysis
from data_assimilation import etpf
from data_assimilation import blending
from data_assimilation import post_processing
# dependencies of the data assimilation subpackag
from data_assimilation import etpf as da_etpf
from data_assimilation import blending as da_blending
from data_assimilation import post_processing as da_post_processing
from data_assimilation import letkf as da_letkf
from data_assimilation import params as da_params
from data_assimilation import utils as da_utils

# input file
from utils.user_data import UserDataInit
from utils.io import io, get_args, sim_restart, fn_gen, init_logger
import utils.sim_params as gparams
import utils.user_data as user_data
import utils.io as io
import utils.sim_params as params

# some diagnostics
from copy import deepcopy
from time import time
from termcolor import colored
import copy
import time
import termcolor
import logging

# test module
from tests import diagnostics as diag

import logging
import tests.diagnostics as diag

debug = gparams.debug
da_debug = gparams.da_debug
debug =params.debug
da_debug = params.da_debug
output_timesteps = False
if debug == True:
output_timesteps = True
label_type = "TIME"
np.set_printoptions(precision=gparams.print_precision)
np.set_printoptions(precision = params.print_precision)

step = 0
t = 0.0
Expand All @@ -63,12 +49,12 @@
# Initialisation of data containers and helper classes
##########################################################
# get arguments for initial condition and ensemble size
N, UserData, sol_init, restart, ud_rewrite, dap_rewrite, r_params = get_args()
N, UserData, sol_init, restart, ud_rewrite, dap_rewrite, r_params = io.get_args()
if N == 1:
da_debug = False

initial_data = vars(UserData())
ud = UserDataInit(**initial_data)
ud = user_data.UserDataInit(**initial_data)
if ud_rewrite is not None:
ud.update_ud(ud_rewrite)
if hasattr(ud, "rayleigh_bc"):
Expand All @@ -77,22 +63,22 @@
output_timesteps = True
ud.coriolis_strength = np.array(ud.coriolis_strength)

elem, node = grid_init(ud)
elem, node = dis_grid.grid_init(ud)

Sol = Vars(elem.sc, ud)
Sol = var.Vars(elem.sc, ud)

flux = np.empty((3), dtype=object)
flux[0] = States(elem.sfx, ud)
flux[0] = var.States(elem.sfx, ud)
if elem.ndim > 1:
flux[1] = States(elem.sfy, ud)
flux[1] = var.States(elem.sfy, ud)
if elem.ndim > 2:
flux[2] = States(elem.sfz, ud)
flux[2] = var.States(elem.sfz, ud)

th = ThermodynamicInit(ud)
mpv = MPV(elem, node, ud)
bld = blending.Blend(ud)
th = gd_thermodynamics.init(ud)
mpv = lm_var.MPV(elem, node, ud)
bld = da_blending.Blend(ud)

init_logger(ud)
io.init_logger(ud)


##########################################################
Expand All @@ -109,35 +95,35 @@
# 1) batch_obs for the LETKF with batch observations
# 2) rloc for LETKF with grid-point localisation
# 3) etpf for the ETPF algorithm
dap = da_params(N, da_type="rloc")
dap = da_params.init(N, da_type="rloc")
if dap_rewrite is not None:
dap.update_dap(dap_rewrite)

# if elem.ndim == 2:
if dap.da_type == "rloc" and N > 1:
rloc = prepare_rloc(ud, elem, node, dap, N)
rloc = da_letkf.prepare_rloc(ud, elem, node, dap, N)

logging.info(colored("Generating initial ensemble...", "yellow"))
logging.info(termcolor.colored("Generating initial ensemble...", "yellow"))
sol_ens = np.zeros((N), dtype=object)

# Set random seed for reproducibility
np.random.seed(gparams.random_seed)
np.random.seed(params.random_seed)

seeds = np.random.randint(10000, size=N) if N > 1 else None
if seeds is not None and restart == False:
logging.info("Seeds used in generating initial ensemble spread = ", seeds)
for n in range(N):
Sol0 = deepcopy(Sol)
mpv0 = deepcopy(mpv)
Sol0 = copy.deepcopy(Sol)
mpv0 = copy.deepcopy(mpv)
Sol0 = sol_init(Sol0, mpv0, elem, node, th, ud, seed=seeds[n])
sol_ens[n] = [Sol0, deepcopy(flux), mpv0, [-np.inf, step]]
sol_ens[n] = [Sol0, copy.deepcopy(flux), mpv0, [-np.inf, step]]
elif restart == False:
sol_ens = [[sol_init(Sol, mpv, elem, node, th, ud), flux, mpv, [-np.inf, step]]]
elif restart == True:
hydrostatic.state(mpv, elem, node, th, ud)
ud.old_suffix = np.copy(ud.output_suffix)
ud.old_suffix = "_ensemble=%i%s" % (N, ud.old_suffix)
Sol0, mpv0, touts = sim_restart(
Sol0, mpv0, touts = io.sim_restart(
r_params[0], r_params[1], elem, node, ud, Sol, mpv, r_params[2]
)
sol_ens = [[Sol0, flux, mpv0, [-np.inf, step]]]
Expand All @@ -146,9 +132,9 @@
t = touts[0]

if ud.bdry_type[1].value == "radiation":
ud.tcy, ud.tny = get_tau_y(ud, elem, node, 0.5)
ud.tcy, ud.tny = bdry.get_tau_y(ud, elem, node, 0.5)

ens = ensemble(sol_ens)
ens = da_utils.ensemble(sol_ens)

##########################################################
# Load data assimilation observations
Expand All @@ -158,14 +144,14 @@
if N > 1:
obs = dap.load_obs(dap.obs_path)
# obs_mask, no calculations where entries are True
obs_mask = sparse_obs_selector(obs, elem, node, ud, dap)
obs_noisy, obs_covar = obs_noiser(obs, obs_mask, dap, rloc, elem)
obs_mask = da_utils.sparse_obs_selector(obs, elem, node, ud, dap)
obs_noisy, obs_covar = da_utils.obs_noiser(obs, obs_mask, dap, rloc, elem)
# obs_noisy_interp, obs_mask = sparse_obs_selector(obs_noisy, elem, node, ud, dap)


# add ensemble info to filename
if ud.autogen_fn:
ud.output_suffix = fn_gen(ud, dap, N)
ud.output_suffix = io.fn_gen(ud, dap, N)
# ud.output_suffix = '_ensemble=%i%s' %(N, ud.output_suffix)

# ud.output_suffix = '%s_%s' %(ud.output_suffix, 'nr')
Expand All @@ -178,7 +164,7 @@
######################################################
# Initialise writer class for I/O operations
######################################################
writer = io(ud, restart)
writer = io.init(ud, restart)
writer.check_jar()
writer.jar([ud, mpv, elem, node, dap])
# sys.exit("Let's just dill the stuff and quit!")
Expand Down Expand Up @@ -206,7 +192,7 @@

# initialise dask parallelisation and timer
# client = Client(threads_per_worker=1, n_workers=1)
tic = time()
tic = time.time()

######################################################
# Time looping over data assimilation windows
Expand All @@ -231,8 +217,8 @@
# Forecast step
######################################################
logging.info("##############################################")
logging.info(colored("Next tout = %.3f" % tout, "yellow"))
logging.info(colored("Starting forecast...", "green"))
logging.info(termcolor.colored("Next tout = %.3f" % tout, "yellow"))
logging.info(termcolor.colored("Starting forecast...", "green"))
mem_cnt = 0
for mem in ens.members(ens):
# future = client.submit(time_update, *[mem[0],mem[1],mem[2], t, tout, ud, elem, node, mem[3], th, bld, None, False])
Expand All @@ -242,8 +228,8 @@
mem[3][0] = 0 if tout_old in dap.da_times else mem[3][0]
if N == 1:
mem[3][0] = mem[3][1]
logging.info(colored("For ensemble member = %i..." % mem_cnt, "yellow"))
future = do(
logging.info(termcolor.colored("For ensemble member = %i..." % mem_cnt, "yellow"))
future = dis_time_update.do(
mem[0],
mem[1],
mem[2],
Expand Down Expand Up @@ -286,18 +272,18 @@
######################################################
for n in range(N):
Sol = results[n][dap.loc_c]
set_explicit_boundary_data(Sol, elem, ud, th, mpv)
bdry.set_explicit_boundary_data(Sol, elem, ud, th, mpv)
results[n][dap.loc_c] = Sol
p2_nodes = getattr(results[n][dap.loc_n], "p2_nodes")
set_ghostnodes_p2(p2_nodes, node, ud)
bdry.set_ghostnodes_p2(p2_nodes, node, ud)
setattr(results[n][dap.loc_n], "p2_nodes", p2_nodes)

ens.set_members(results, tout)

######################################################
# Write output before assimilating data
######################################################
logging.info(colored("Starting output...", "yellow"))
logging.info(termcolor.colored("Starting output...", "yellow"))
for n in range(N):
Sol = ens.members(ens)[n][0]
mpv = ens.members(ens)[n][2]
Expand All @@ -318,7 +304,7 @@
logging.info("Assimilating %s..." % attr)
logging.info("Assimilating %s..." % attr)
# future = client.submit(da_interface, *[s_res,obs_current,dap.inflation_factor,attr,N,ud,dap.loc[attr]])
future = da_interface(results, dap, obs, attr, tout, N, ud)
future = da_letkf.da_interface(results, dap, obs, attr, tout, N, ud)
futures.append(future)

# analysis = client.gather(futures)
Expand All @@ -338,20 +324,20 @@
##################################################
elif dap.da_type == "rloc":
logging.info(
colored("Starting analysis... for rloc algorithm", "green")
termcolor.colored("Starting analysis... for rloc algorithm", "green")
)
results = HSprojector_3t2D(results, elem, dap, N)
results = da_utils.HSprojector_3t2D(results, elem, dap, N)
results = rloc.analyse(results, obs, obs_covar, obs_mask, N, tout)
results = HSprojector_2t3D(results, elem, node, dap, N)
results = da_utils.HSprojector_2t3D(results, elem, node, dap, N)
# if hasattr(dap, 'converter'):
# results = dap.converter(results, N, mpv, elem, node, th, ud)

##################################################
# ETPF
##################################################
elif dap.da_type == "etpf":
ensemble_inflation(results, dap.attributes, dap.inflation_factor, N)
results = etpf.da_interface(
da_utils.ensemble_inflation(results, dap.attributes, dap.inflation_factor, N)
results = da_etpf.da_interface(
results,
obs,
dap.obs_attributes,
Expand All @@ -365,7 +351,7 @@
# Post-processing
##################################################
elif dap.da_type == "pprocess":
results = post_processing.interface()
results = da_post_processing.interface()

else:
assert 0, "DA type not implemented: use 'rloc', 'batch_obs' or 'etpf'."
Expand All @@ -375,18 +361,18 @@
######################################################
for n in range(N):
Sol = results[n][dap.loc_c]
set_explicit_boundary_data(Sol, elem, ud, th, mpv)
bdry.set_explicit_boundary_data(Sol, elem, ud, th, mpv)
results[n][dap.loc_c] = Sol
p2_nodes = getattr(results[n][dap.loc_n], "p2_nodes")
set_ghostnodes_p2(p2_nodes, node, ud)
bdry.set_ghostnodes_p2(p2_nodes, node, ud)
setattr(results[n][dap.loc_n], "p2_nodes", p2_nodes)

ens.set_members(results, tout)

######################################################
# Write output at tout
######################################################
logging.info(colored("Starting output...", "yellow"))
logging.info(termcolor.colored("Starting output...", "yellow"))
for n in range(N):
Sol = ens.members(ens)[n][0]
mpv = ens.members(ens)[n][2]
Expand All @@ -401,14 +387,14 @@
# synchronise_variables(mpv, Sol, elem, node, ud, th)
t = tout
tout_old = np.copy(tout)
logging.info(colored("tout = %.3f" % tout, "yellow"))
logging.info(termcolor.colored("tout = %.3f" % tout, "yellow"))

tout_cnt += 1
outer_step += 1
if outer_step > ud.stepmax:
break

toc = time()
logging.info(colored("Time taken = %.6f" % (toc - tic), "yellow"))
toc = time.time()
logging.info(termcolor.colored("Time taken = %.6f" % (toc - tic), "yellow"))

writer.close_everything()
2 changes: 1 addition & 1 deletion src/data_assimilation/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import logging

class da_params(object):
class init(object):

def __init__(self,N,da_type='rloc'):
# number of ensemble members
Expand Down
2 changes: 1 addition & 1 deletion src/dycore/discretisation/time_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def do(
Nodes grid.
step : int
Current step.
th : :class:`physics.gas_dynamics.thermodynamic.ThermodynamicInit`
th : :class:`physics.gas_dynamics.thermodynamic.init`
Thermodynamic variables of the system
bld : :class:`data_assimilation.blending.Blend()`
Blending class used to initalise interface blending methods.
Expand Down
Loading

0 comments on commit 09266bd

Please sign in to comment.