
Commit

Merge branch 'main' of https://github.com/casangi/graphviper
Jan-Willem committed Sep 3, 2024
2 parents bdaec59 + 9892742 commit e9b48ee
Showing 5 changed files with 104 additions and 144 deletions.
18 changes: 1 addition & 17 deletions src/graphviper/__init__.py
@@ -14,20 +14,4 @@
log_to_file=False,
log_file="graphviper-logfile",
log_level="INFO",
)

# Setup environment variable to identify the configuration directory for the module
if os.path.exists(os.path.dirname(__file__) + "/config/"):
if not os.getenv("PARAMETER_CONFIG_PATH"):
os.environ["PARAMETER_CONFIG_PATH"] = os.path.dirname(__file__) + "/config/"

else:
if os.path.dirname(__file__) + "/config/" not in os.getenv(
"PARAMETER_CONFIG_PATH"
):
os.environ["PARAMETER_CONFIG_PATH"] = ":".join(
(
os.environ["PARAMETER_CONFIG_PATH"],
os.path.dirname(__file__) + "/config/",
)
)
)
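
For context, the block removed above maintained PARAMETER_CONFIG_PATH as a PATH-style, colon-separated list of configuration directories. A minimal standalone sketch of that pattern (the directory value is illustrative, not taken from the repository):

import os

# Hypothetical configuration directory; the removed block used <package>/config/.
config_dir = "/opt/my_project/config/"

existing = os.getenv("PARAMETER_CONFIG_PATH")
if not existing:
    # No value yet: start the list with this directory.
    os.environ["PARAMETER_CONFIG_PATH"] = config_dir
elif config_dir not in existing.split(":"):
    # Append PATH-style, but only if the directory is not already listed.
    os.environ["PARAMETER_CONFIG_PATH"] = ":".join((existing, config_dir))
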
77 changes: 38 additions & 39 deletions src/graphviper/graph_tools/coordinate_utils.py
@@ -11,10 +11,10 @@


def make_time_coord(
time_start: str = "2019-10-03T19:00:00.000",
time_delta: numbers.Number = 3600,
n_samples: int = 10,
time_scale: {"tai", "tcb", "tcg", "tdb", "tt", "ut1", "utc", "local"} = "utc",
time_start: str = "2019-10-03T19:00:00.000",
time_delta: numbers.Number = 3600,
n_samples: int = 10,
time_scale: {"tai", "tcb", "tcg", "tdb", "tt", "ut1", "utc", "local"} = "utc",
) -> Dict:
"""Convenience function that creates a time coordinate `measures dictionary <https://docs.google.com/spreadsheets/d/14a6qMap9M5r_vjpLnaBKxsR9TF4azN5LVdOxLacOX-s/edit#gid=1504318014>`_ that can be used to create :ref:`parallel_coords <parallel coords>` using :func:`make_parallel_coord` function.
@@ -74,10 +74,10 @@ def make_time_coord(


def make_frequency_coord(
freq_start: numbers.Number = 3 * 10**9,
freq_delta: numbers.Number = 0.4 * 10**9,
n_channels: int = 50,
velocity_frame: {"gcrs", "icrs", "hcrs", "lsrk", "lsrd", "lsr"} = "lsrk",
freq_start: numbers.Number = 3 * 10 ** 9,
freq_delta: numbers.Number = 0.4 * 10 ** 9,
n_channels: int = 50,
velocity_frame: {"gcrs", "icrs", "hcrs", "lsrk", "lsrd", "lsr"} = "lsrk",
) -> Dict:
"""Convenience function that creates a frequency coordinate `measures dictionary <https://docs.google.com/spreadsheets/d/14a6qMap9M5r_vjpLnaBKxsR9TF4azN5LVdOxLacOX-s/edit#gid=1504318014>`_ that can be used to create :ref:`parallel_coords <parallel coords>` using :func:`make_parallel_coord` function.
@@ -123,9 +123,9 @@ def make_frequency_coord(


def make_parallel_coord(
coord: Union[Dict, xr.DataArray],
n_chunks: Union[None, int] = None,
gap: Union[None, float] = None,
coord: Union[Dict, xr.DataArray],
n_chunks: Union[None, int] = None,
gap: Union[None, float] = None,
) -> Dict:
"""Creates a single parallel coordinate from a `measures dictionary <https://docs.google.com/spreadsheets/d/14a6qMap9M5r_vjpLnaBKxsR9TF4azN5LVdOxLacOX-s/edit#gid=1504318014>`_ or a `xarray.DataArray <https://docs.xarray.dev/en/stable/generated/xarray.DataArray.html>`_ with `measures attributes <https://docs.google.com/spreadsheets/d/14a6qMap9M5r_vjpLnaBKxsR9TF4azN5LVdOxLacOX-s/edit#gid=1504318014>`_.
@@ -216,7 +216,7 @@ def make_parallel_coord(


def make_parallel_coord_by_gap(coord: Union[Dict, xr.DataArray],
gap: float)-> Dict:
gap: float) -> Dict:
"""Creates a single parallel coordinate from from a `measures dictionary <https://docs.google.com/spreadsheets/d/14a6qMap9M5r_vjpLnaBKxsR9TF4azN5LVdOxLacOX-s/edit#gid=1504318014>`_ or a `xarray.DataArray <https://docs.xarray.dev/en/stable/generated/xarray.DataArray.html>`_ with `measures attributes <https://docs.google.com/spreadsheets/d/14a6qMap9M5r_vjpLnaBKxsR9TF4azN5LVdOxLacOX-s/edit#gid=1504318014>`_.
This function only returns a single :ref:`parallel_coord <parallel coord>`; to create :ref:`parallel_coords <parallel coords>` a dictionary must be created where the keys are the dimension coordinate names and the values are the respective :ref:`parallel_coord <parallel coord>`.
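
The function body shown in the next hunk splits the coordinate wherever consecutive values differ by more than gap. A standalone numpy sketch of that idea (illustrative data, not the library implementation):

import numpy as np

coord_data = np.array([0.0, 1.0, 2.0, 10.0, 11.0, 25.0])
gap = 5.0

# Differences between consecutive samples; a "jump" is any step larger than gap.
dxs = coord_data[1:] - coord_data[:-1]
jump_starts = np.argwhere(dxs > gap).flatten() + 1

# Chunk boundaries: start of the array, each jump, end of the array.
edges = [0] + list(jump_starts) + [len(coord_data)]
chunks = [coord_data[start:stop] for start, stop in zip(edges[:-1], edges[1:])]
print(chunks)  # [array([0., 1., 2.]), array([10., 11.]), array([25.])]
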
@@ -232,19 +232,19 @@ def make_parallel_coord_by_gap(coord: Union[Dict, xr.DataArray],
coord = coord.copy(deep=True).to_dict()
parallel_coord = {}
coord_data = np.array(coord['data'])
if len(coord_data.shape)>1:
if len(coord_data.shape) > 1:
raise ValueError

nx = len(coord_data)
dxs = coord_data[1:] - coord_data[:-1]
jumps0 = np.argwhere(dxs>gap).flatten()
jumps = [0]+list(jumps0) + [nx]
jumps0 = np.argwhere(dxs > gap).flatten()
jumps = [0] + list(jumps0) + [nx]
data_chunk_edges = []
data_chunks = {}
for i, rnge in enumerate(itertools.pairwise(jumps)):
i0, i1 = rnge
data_chunks[i] = coord_data[range(*rnge)]
data_chunk_edges.extend([coord_data[i0], coord_data[i1-1]])
data_chunk_edges.extend([coord_data[i0], coord_data[i1 - 1]])
parallel_coord['data_chunks'] = data_chunks
parallel_coord['data_chunk_edges'] = data_chunk_edges
parallel_coord['data'] = coord_data
@@ -253,25 +253,24 @@ def make_parallel_coord_by_gap(coord: Union[Dict, xr.DataArray],
return parallel_coord



def interpolate_data_coords_onto_parallel_coords(
parallel_coords: dict,
input_data: Union[Dict, processing_set],
interpolation_method: {
"linear",
"nearest",
"nearest-up",
"zero",
"slinear",
"quadratic",
"cubic",
"previous",
"next",
} = "nearest",
assume_sorted: bool = True,
ps_partition: Optional[
str
] = None, # Current options are {'field_name', 'spectral_window_name'}
parallel_coords: dict,
input_data: Union[Dict, processing_set],
interpolation_method: {
"linear",
"nearest",
"nearest-up",
"zero",
"slinear",
"quadratic",
"cubic",
"previous",
"next",
} = "nearest",
assume_sorted: bool = True,
ps_partition: Optional[
list[str]
] = None, # Current options are {'field_name', 'spectral_window_name'}
) -> Dict:
"""Interpolate data_coords onto parallel_coords to create the ``node_task_data_mapping``.
@@ -415,10 +414,10 @@ def interpolate_data_coords_onto_parallel_coords(
if interp_index[i] == -1 and interp_index[i + 1] == -1:
chunk_indx_start_stop[chunk_index] = slice(None)
if (
pc["data_chunks_edges"][i] < input_data[xds_name][dim][0]
pc["data_chunks_edges"][i] < input_data[xds_name][dim][0]
) and (
pc["data_chunks_edges"][i + 1]
> input_data[xds_name][dim][-1]
pc["data_chunks_edges"][i + 1]
> input_data[xds_name][dim][-1]
):
interp_index[i] = 0
interp_index[i + 1] = -2
@@ -489,7 +488,7 @@ def interpolate_data_coords_onto_parallel_coords(

#
if (
empty_chunk
empty_chunk
): # The xds with xds_name has no data for the parallel chunk (no slice on one of the dims).
del node_task_data_mapping[task_id]["data_selection"][xds_name]
task_id += 1
@@ -594,7 +593,7 @@ def _partition_ps_by_non_dimensions(ps, ps_partition_keys):
d = {}
# We loop over the cartersian product of the keys
for multi_index in itertools.product(
*[ps_split_map[key] for key in ps_partition_keys]
*[ps_split_map[key] for key in ps_partition_keys]
):
# And for each key we look up the corresponding set of xds names
sets = [
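
A toy illustration of the cartesian-product loop in the hunk above (the map values are invented; the real code gathers sets of xds names for each combination):

import itertools

ps_split_map = {
    "field_name": ["field_0", "field_1"],
    "spectral_window_name": ["spw_0", "spw_1"],
}
ps_partition_keys = ["field_name", "spectral_window_name"]

for multi_index in itertools.product(*[ps_split_map[key] for key in ps_partition_keys]):
    print(multi_index)  # ('field_0', 'spw_0'), ('field_0', 'spw_1'), ...
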
2 changes: 1 addition & 1 deletion src/graphviper/utils/data/download.py
@@ -284,7 +284,7 @@ def _download(file: str, folder: str = ".") -> NoReturn:
return

if file not in file_meta_data["metadata"].keys():
logger.error("Requested file not found: {file}")
logger.error(f"Requested file not found: {file}")
logger.info(
f"For a list of available files try using "
f"{colorize.blue('graphviper.utils.data.list_files()')}."
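
The one-character fix above matters because, without the f prefix, the placeholder is logged literally instead of being interpolated:

file = "demo_file.zip"  # illustrative name
print("Requested file not found: {file}")   # Requested file not found: {file}
print(f"Requested file not found: {file}")  # Requested file not found: demo_file.zip
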
143 changes: 59 additions & 84 deletions src/graphviper/utils/parameter.py
@@ -13,7 +13,7 @@

from graphviper.utils.protego import Protego

from typing import Callable, Any, Union, NoReturn, Dict, List, Optional
from typing import Callable, Any, Union, NoReturn, Dict, List, Optional, Tuple
from types import ModuleType


@@ -32,10 +32,10 @@ def is_notebook() -> bool:


def validate(
config_dir: str = None,
custom_checker: Callable = None,
add_data_type: Any = None,
external_logger: Callable = None,
config_dir: str = None,
custom_checker: Callable = None,
add_data_type: Any = None,
external_logger: Callable = None,
):
def function_wrapper(function):
@functools.wraps(function)
@@ -69,19 +69,44 @@ def wrapper(*args, **kwargs):
return function_wrapper


def get_path(function: Callable) -> str:
# DEPRECATED
def _get_path(function: Callable) -> str:
module = inspect.getmodule(function)
module_path = inspect.getfile(module).rstrip(".py")

if "src" in module_path:
# This represents a local developer install

base_module_path = module_path.split("src/")[0]
return base_module_path

else:
# Here we hope that we can find the package in site-packages and it is unique
# otherwise the user should provide the configuration path in the decorator.

base_module_path = module_path.split("site-packages/")[0]
return "/".join((base_module_path, "site-packages/"))
return str(pathlib.Path(base_module_path).joinpath("site-packages/"))


def get_path(function: Callable) -> tuple[str, str]:
module = inspect.getmodule(function)
module_path = inspect.getfile(module).rstrip(".py")

# Determine whether this is a developer install or a site install
tag = "src" if "src" in module_path else "site-packages"

# The base directory should be to the left of the tag
base_module_path = module_path.split(f"{tag}/")[0]

# Split the module path up and determine what the package name and location is.
split_path = module_path.split("/")
index = split_path.index(tag) + 1
package_name = split_path[index]

# Build the full package path
base_module_path = pathlib.Path(base_module_path).joinpath(f"{tag}/{package_name}")

return str(base_module_path), module_path
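
An illustrative trace of the new get_path for a developer (src/) install; the paths are invented, only the splitting logic follows the code above:

# module file: /home/user/graphviper/src/graphviper/utils/parameter.py
#   module_path  -> "/home/user/graphviper/src/graphviper/utils/parameter"
#   tag          -> "src"
#   package_name -> "graphviper"  (first path element after the tag)
#   returns      -> ("/home/user/graphviper/src/graphviper",
#                    "/home/user/graphviper/src/graphviper/utils/parameter")
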


def config_search(root: str = "/", module_name=None) -> Union[None, str]:
Expand Down Expand Up @@ -150,13 +175,13 @@ def verify_configuration(path: str, module: ModuleType) -> List[str]:


def verify(
function: Callable,
args: Dict,
meta_data: Dict[str, Union[Optional[str], Any]],
config_dir: str = None,
add_data_type: Any = None,
custom_checker: Callable = None,
external_logger: Callable = None,
function: Callable,
args: Dict,
meta_data: Dict[str, Union[Optional[str], Any]],
config_dir: str = None,
add_data_type: Any = None,
custom_checker: Callable = None,
external_logger: Callable = None,
) -> NoReturn:
colorize = console.Colorize()
function_name, module_name = meta_data.values()
@@ -179,90 +204,40 @@ def verify(
)
)

module_path = get_path(function)
logger.debug(f"Module path: {colorize.blue(module_path)}")

path = None

# First we need to find the parameter configuration files
if config_dir is not None:
tag, environment_path = config_dir.split(":")
if tag.lower() == "env":
if pathlib.Path(os.getenv(environment_path)).exists():
path = os.getenv(environment_path)
logger.debug(f"Configuration path set to: {path}")
package_path, module_path = get_path(function)
logger.info(f"Module path: {colorize.blue(package_path)}")

else:
logger.error(f"Configuration path provided does not exist: ENV={environment_path}")
# First we need to find the parameter configuration files
if pathlib.Path(package_path).joinpath("config").joinpath(f"{module_name}.param.json").exists():
logger.debug(f"Found configuration for {module_name}.{function_name} in: {colorize.blue(package_path)}")
path = str(pathlib.Path(package_path).joinpath("config"))

else:
# User-specified configuration directory takes precedence
if config_dir is not None:
if pathlib.Path(config_dir).joinpath(f"{module_name}.param.json").exists():
logger.debug(f"Setting configuration directory to user provided [{config_dir}]")
path = config_dir

# If the parameter configuration directory is not passed as an argument this environment variable should be set.
# In this case, the environment variable is set in the __init__ file of the astrohack module.
#
# This should be set according to the same pattern as PATH in terminal, ie. PATH=$PATH:/path/new
# the parsing code will expect this.
elif os.getenv("PARAMETER_CONFIG_PATH"):
for paths in os.getenv("PARAMETER_CONFIG_PATH").split(":"):
result = config_search(root=paths, module_name=module_name)
logger.debug("Result: {}".format(colorize.blue(result)))
if result:
path = result
logger.debug(
"PARAMETER_CONFIG_PATH: {dir}".format(dir=colorize.blue(result))
)
break

# If we can't find the configuration in the ENV path, we make a last-ditch effort to find it in src/
# (if that exists) or in the python site-packages/ directory before giving up.
if not path:
logger.info(
"Failed to find module in PARAMETER_CONFIG_PATH ... attempting to check common directories ..."
)
path = config_search(root=module_path, module_name=module_name)
else:
logger.warning("User provided configuration directory does not exist. Searching for parameter files ...")

if not path:
logger.error(
"{function}: Cannot find parameter configuration directory.".format(
function=function_name
)
)
assert False
# If we have been dealt only failure at this point, then we try one last-ditch effort: search the package directory!
if not path:
logger.debug(f"Couldn't determine parameter configuration directory, doing a depth search of {package_path}")
path = config_search(root=package_path, module_name=module_name)

else:
path = config_search(root=module_path, module_name=module_name)
if not path:
logger.error(
"{function}: Cannot find parameter configuration directory.".format(
function=function_name
)
)
assert False
logger.error(f"Cannot find parameter configuration directory for {function_name}")
raise FileNotFoundError

# Define parameter file name
parameter_file = module_name + ".param.json"

logger.debug(path + "/" + parameter_file)

# Load calling module to make this more general
module = importlib.import_module(function.__module__)

# This will check the configuration path and return the available modules
module_config_list = verify_configuration(path, module)

# Make sure that required module is present
if module_name not in module_config_list:
logger.error(
"Parameter file for {function} not found in {path}".format(
function=colorize.red(function_name),
path="/".join((path, parameter_file)),
)
)

raise FileNotFoundError
logger.debug(f"Parameter configuration file: {pathlib.Path(path).joinpath(parameter_file)}")

with open("/".join((path, parameter_file))) as json_file:
with open(pathlib.Path(path).joinpath(parameter_file)) as json_file:
schema = json.load(json_file)

if function_name not in schema.keys():
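
Tying the parameter changes together, a hedged sketch of applying the validate decorator; the decorated function and its parameters are invented, and the schema lookup now follows the order implemented above (the package's config/ directory first, then a user-supplied config_dir, then a depth search of the package path):

from graphviper.utils.parameter import validate

@validate()  # or validate(config_dir="/path/to/param/configs")
def make_image(image_size: int = 1024, cell_size: float = 0.01):
    # Arguments are checked against <module>.param.json before the body runs.
    ...
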
