Skip to content

Commit

Permalink
add: cli for checking all fits files and inspect tables from fits fil…
Browse files Browse the repository at this point in the history
…es in given config file
  • Loading branch information
gosow9 committed Jul 25, 2024
1 parent bb73f92 commit 93d3c87
Show file tree
Hide file tree
Showing 16 changed files with 282 additions and 20 deletions.
10 changes: 10 additions & 0 deletions fits2db/adapters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from ..config.config_model import ConfigType
class DBWriter:
    """Select a database loader based on the application configuration."""

    def load_db(self, config: "ConfigType"):
        """Return the loader for the database type named in ``config``.

        Args:
            config: Configuration mapping; ``config["database"]["type"]``
                names the database backend.

        Returns:
            Whatever ``_get_loader`` returns for that backend.

        Raises:
            KeyError: If ``database`` or ``type`` is missing from ``config``.
        """
        db_type = config["database"]["type"]
        # Dispatch on the configured type, not on the builtin ``format``
        # function (which has no ``.lower()`` and would raise).
        return self._get_loader(db_type)

    def _get_loader(self, format: str):
        """Map a database type name (case-insensitive) to its loader.

        No concrete loader is wired up yet, so every type name currently
        yields ``None``.
        """
        if format.lower() == "mysql":
            # TODO: return the MySQL loader once it is hooked up here.
            return None
        return None
19 changes: 19 additions & 0 deletions fits2db/adapters/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from abc import ABC, abstractmethod
from sqlalchemy import create_engine, MetaData
import pandas as pd


class BaseLoader(ABC):
    """Abstract base for database loaders backed by a SQLAlchemy engine."""

    def __init__(self, db_url):
        """Create the engine for ``db_url`` and reflect the existing schema.

        Args:
            db_url: Database URL handed to ``create_engine``.
        """
        engine = create_engine(db_url)
        self.engine = engine
        schema = MetaData()
        self.metadata = schema
        # Populate the metadata from the tables visible through the engine.
        schema.reflect(bind=engine)

    @abstractmethod
    def create_table_if_not_exists(self, table_name, df: pd.DataFrame):
        """Subclass hook; by its name, creates ``table_name`` when absent."""

    @abstractmethod
    def upsert_data(self, table_name, df: pd.DataFrame, unique_key):
        """Subclass hook; by its name, upserts ``df`` keyed on ``unique_key``."""

18 changes: 18 additions & 0 deletions fits2db/adapters/mysql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from .base import BaseLoader
from ..config.config_model import ConfigType

class MySQL(BaseLoader):
    """Loader for MySQL databases using the ``mysqlconnector`` driver."""

    def __init__(self, config: ConfigType):
        """Store ``config`` and initialise the base loader with its URL.

        Args:
            config: Application configuration; its ``database`` section
                supplies the connection settings.
        """
        self.config = config
        db_url = self.create_db_url()
        super().__init__(db_url)

    def create_db_url(self):
        """Build the SQLAlchemy URL from the ``database`` config section.

        Reads ``user``, ``password``, ``host``, ``port`` and ``db_name``.
        User and password are percent-encoded so characters such as ``@``,
        ``:`` or ``/`` cannot be mistaken for URL delimiters.

        Returns:
            str: A ``mysql+mysqlconnector://`` connection URL.

        Raises:
            KeyError: If one of the required keys is missing.
        """
        from urllib.parse import quote_plus

        database = self.config["database"]
        user = quote_plus(str(database["user"]))
        password = quote_plus(str(database["password"]))
        host = database["host"]
        port = database["port"]
        db_name = database["db_name"]
        return f"mysql+mysqlconnector://{user}:{password}@{host}:{port}/{db_name}"


Empty file added fits2db/cli/__init__.py
Empty file.
94 changes: 94 additions & 0 deletions fits2db/cli/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import click
from ..core import Fits2db, get_all_fits


def validate_output_filename(ctx, param, value):
    """Click callback validating ``--filename`` against ``--csv``/``--excel``.

    Args:
        ctx: Click context; ``ctx.params`` is consulted for the ``csv`` and
            ``excel`` flags.
        param: The parameter being validated (unused).
        value: The filename from the command line, or the option default.

    Returns:
        ``value`` unchanged when it passes validation or when no export
        flag is set.

    Raises:
        click.BadParameter: If an export flag is set and the filename is
            missing or does not carry the matching extension.

    NOTE(review): this relies on ``csv``/``excel`` already being present in
    ``ctx.params`` when the callback runs — confirm click's processing
    order. Also, the ``--filename`` option in this file defaults to
    ``output.csv``, so ``--excel`` without an explicit filename is rejected.
    """
    wants_csv = ctx.params.get("csv")
    wants_excel = ctx.params.get("excel")
    if not (wants_csv or wants_excel):
        return value

    # Check for a missing name first, so an empty or None value produces the
    # "required" message instead of failing (or crashing) the suffix test.
    if not value:
        raise click.BadParameter(
            "Output filename is required when --csv or --excel is specified."
        )
    if wants_csv and not value.endswith(".csv"):
        raise click.BadParameter("CSV filename must have a .csv extension.")
    if wants_excel and not value.endswith(".xlsx"):
        raise click.BadParameter(
            "Excel filename must have a .xlsx extension."
        )
    return value


# Root command group; the subcommands are attached to it further down the
# module via ``cli.add_command``.
@click.group()
def cli():
    """Fits2DB CLI"""


@click.command()
@click.argument("config_path", type=click.Path(exists=True))
@click.option(
    "-f",
    "--folder",
    is_flag=True,
    default=False,
    help="Show all fits files in given folder",
)
def files(folder, config_path):
    """Prints all files from given config.yaml file"""
    # With --folder the argument is scanned directly for FITS files;
    # otherwise it is loaded as a config file that lists the paths.
    fits_paths = (
        get_all_fits([config_path])
        if folder
        else Fits2db(config_path).get_files()
    )
    for fits_path in fits_paths:
        click.echo(fits_path)


@click.command()
@click.argument("config_path", default=".", type=click.Path(exists=True))
@click.option(
    "-m",
    "--matrix",
    is_flag=True,
    default=False,
    help="Show all tables and files as matrix",
)
@click.option(
    "--csv", is_flag=True, default=False, help="Save the output as csv"
)
@click.option(
    "--excel", is_flag=True, default=False, help="Save the output as excel"
)
@click.option(
    "--filename",
    default="output.csv",
    callback=validate_output_filename,
    help="The filename for the output (required if --csv or --excel is specified).",
)
def tables(config_path, matrix, csv, excel, filename):
    """Prints all table names from all fits files from given config.yaml file"""
    fits = Fits2db(config_path)

    # Plain listing: one table name per line, export flags are not used.
    if not matrix:
        table_names, _ = fits.get_table_names()
        for table_name in table_names:
            click.echo(table_name)
        return

    # --csv wins over --excel when both are given.
    if csv:
        export_format = "csv"
    elif excel:
        export_format = "excel"
    else:
        export_format = None

    table_matrix = fits.create_table_matrix(
        output_format=export_format, output_file=filename
    )
    # The matrix is only echoed when it is not being written to a file.
    if export_format is None:
        click.echo(table_matrix.to_string())


# Attach the subcommands to the root ``cli`` group.
cli.add_command(files)
cli.add_command(tables)

# Allow running this module directly as a script.
if __name__ == "__main__":
    cli()
6 changes: 3 additions & 3 deletions fits2db/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
This module contains the configuration validation for the FITS to database application.
"""

from .config_model import ConfigFileValidator, ApplicationConfig, ConfigType

from typing import Union
import yaml
import os
from jinja2 import Environment, FileSystemLoader

from .config_model import ConfigFileValidator, ApplicationConfig, ConfigType


def get_configs(path: Union[str, os.PathLike]) -> ConfigType:
"""Loads config file from given path
Expand All @@ -32,7 +32,7 @@ def get_configs(path: Union[str, os.PathLike]) -> ConfigType:
except (TypeError, ValueError) as err:
print("Config file validation error:", err)
return {}

return data


Expand Down
4 changes: 2 additions & 2 deletions fits2db/config/config_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def check_cedentials(self) -> Self:

class FitsConfig(BaseModel):
    """Fits files configuration."""

    # NOTE(review): the updated unit test builds FitsConfig without `name` —
    # confirm whether this field is still meant to be required.
    name: str
    # Locations to search for FITS data; presumably individual files and/or
    # folders — verify against the sample config.
    paths: list
    # Table selection; presumably a mapping such as {"names": [...]} —
    # TODO confirm the expected schema.
    tables: dict


class ConfigFileValidator(BaseModel):
Expand Down
2 changes: 2 additions & 0 deletions fits2db/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .core import Fits2db, get_all_fits
__all__ = ["Fits2db","get_all_fits"]
85 changes: 85 additions & 0 deletions fits2db/core/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Core module to extract fits files and insert into db"""

from ..fits import FitsFile
from ..config import get_configs
import os
from pathlib import Path
from tqdm import tqdm
import pandas as pd


def get_all_fits(paths: list):
    """Collect every ``.fits`` file reachable from ``paths``.

    Each entry may be a directory, which is walked recursively, or a single
    file. Entries may be strings or ``os.PathLike`` objects. The ``.fits``
    extension is matched case-insensitively, so ``X.FITS`` is included.
    Entries that do not exist or are not FITS files are skipped silently.

    Args:
        paths: Iterable of file or directory paths.

    Returns:
        list[str]: Matching file paths in discovery order; duplicates are
        not removed.
    """
    all_fits_files = []
    for path in paths:
        # Normalise Path objects to str so the suffix test works on them.
        path = os.fspath(path)
        if os.path.isdir(path):
            for root, _, files in os.walk(path):
                all_fits_files.extend(
                    os.path.join(root, name)
                    for name in files
                    if name.lower().endswith(".fits")
                )
        elif os.path.isfile(path) and path.lower().endswith(".fits"):
            all_fits_files.append(path)
    return all_fits_files


def flatten_and_deduplicate(input_list):
    """Flatten arbitrarily nested lists and drop repeated leaves.

    Only ``list`` instances are descended into; every other value
    (including tuples and strings) is treated as a leaf. Leaves keep their
    first-seen order and must be hashable.

    Args:
        input_list: A possibly nested list, or a single leaf value.

    Returns:
        list: The unique leaves in depth-first, left-to-right order.
    """

    def leaves(node):
        # Yield every non-list value beneath ``node``, depth first.
        if not isinstance(node, list):
            yield node
            return
        for child in node:
            yield from leaves(child)

    # dict keys keep insertion order, so the first occurrence wins.
    return list(dict.fromkeys(leaves(input_list)))


class Fits2db:
    """Collect the FITS files named in a config file and inspect their tables."""

    def __init__(self, config_path):
        """Load the configuration and resolve the FITS files it points at.

        Args:
            config_path: Path to the configuration file.
        """
        self.config_path = Path(config_path)
        self.configs = get_configs(config_path)
        self.fits_file_paths = self.get_files()

    def get_files(self):
        """Return the FITS file paths from the config without duplicates.

        Raises:
            KeyError: If the config has no ``fits_files``/``paths`` entry.
        """
        paths = self.configs["fits_files"]["paths"]
        # dict.fromkeys drops duplicates while keeping first-seen order.
        return list(dict.fromkeys(get_all_fits(paths)))

    def get_table_names(self):
        """Read the table names of every configured FITS file.

        Files for which ``FitsFile`` raises ``ValueError`` are reported on
        stdout and skipped.

        Returns:
            tuple: ``(all_table_names, file_table_dict)`` — the unique table
            names across all files, and a mapping from each file's ``Path``
            to that file's table names.
        """
        self.all_table_names = []
        self.file_table_dict = {}
        for path in tqdm(self.fits_file_paths):
            path = Path(path)
            try:
                file = FitsFile(path)
                self.all_table_names.append(file.table_names)
                self.file_table_dict[path] = file.table_names
            except ValueError as err:
                # Best effort: report the unreadable file and keep going.
                print(err)

        self.all_table_names = flatten_and_deduplicate(self.all_table_names)
        return self.all_table_names, self.file_table_dict

    def create_table_matrix(self, output_format=None, output_file=None):
        """Build a file-by-table matrix marking which file has which table.

        Args:
            output_format: ``None`` for no export, or ``"csv"`` / ``"excel"``
                (case-insensitive) to also write the matrix to disk.
            output_file: Target file, resolved against the current working
                directory. The export is skipped when this is empty.

        Returns:
            pandas.DataFrame: One row per file name and one column per table
            name, holding ``"X"`` where the file has the table and ``""``
            elsewhere.
        """
        all_table_names, file_table_dict = self.get_table_names()
        # Rows are labelled with the bare file name, so same-named files
        # from different folders share a label.
        file_names = [path.name for path in file_table_dict]
        df = pd.DataFrame(index=file_names, columns=all_table_names)
        for path, tables in file_table_dict.items():
            for table in tables:
                df.at[path.name, table] = "X"

        df = df.fillna("")

        # Gate the export on the requested output file, not on a loop
        # variable that is unbound when there are no files.
        if output_format and output_file:
            full_file_path = os.path.join(os.getcwd(), output_file)
            if output_format.lower() == "csv":
                df.to_csv(full_file_path)
            elif output_format.lower() == "excel":
                df.to_excel(full_file_path, index=True)

        return df
4 changes: 2 additions & 2 deletions fits2db/fits/fits.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@

counter = count()


class FitsTable(TypedDict):
@dataclass
class FitsTable:
name: str
meta: pd.DataFrame
data: pd.DataFrame
Expand Down
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
name = "fits2db"
version = "0.0.1b1"
authors = [
{ name="Cédric Renda", email="gosow9@hotmail.com" },
{ name="Cédric Renda", email="cedric.renda@hotmail.com" },
]
description = "A small cli tool to load fits files into a sql database"
readme = "README.md"
Expand Down Expand Up @@ -39,13 +39,17 @@ dependencies = [
"Jinja2 >= 2.5.0",
"astropy >= 6.0.0",
"numpy >= 1.0.0, < 2.0",
"sqlalchemy >= 2.0.0"
"sqlalchemy >= 2.0.0",
"tqdm >= 4.0.0"
]

[project.urls]
Homepage = "https://github.com/pmodwrc/fits2db"
Issues = "https://github.com/pmodwrc/fits2db/issues"

[project.scripts]
fits2db = "fits2db.cli.cli:cli"

[tool.setuptools.packages.find]
where = ["."]

Expand Down
14 changes: 11 additions & 3 deletions tests/unit/data/config.yaml
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
database:
type: mysql
host: local
host: localhost
user: kd
password: pfd
token:
token:
port: 3006


fits_files:
name: bla
paths:
- tests\unit\data\test.fits
- tests\unit\data
- \\ad.pmodwrc.ch\Institute\Projects\FY-3E\JOIM\16_Flight_Data\LEVEL_1
tables:
names:
- HOUSEKEEPING



17 changes: 17 additions & 0 deletions tests/unit/test_config.py
Original file line number Diff line number Diff line change
@@ -1 +1,18 @@

import pytest
import os
from pathlib import Path
from tempfile import NamedTemporaryFile
from pydantic import ValidationError
from fits2db.config.config import get_configs

current_dir = os.path.dirname(__file__)
sample_config_file = os.path.join(current_dir, 'data', 'config.yaml')


def test_get_configs_valid():
    """Check that the sample config loads and reports a MySQL database."""
    database_section = get_configs(sample_config_file)["database"]
    assert database_section["type"] == "mysql"


6 changes: 4 additions & 2 deletions tests/unit/test_config_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,12 @@ def test_valid_application_config():
token=None,
port=3306,
)
fits_config = FitsConfig(name="example_fits")
paths = [r"tests\unit\data\test.fits",r"tests\unit\data"]
tables = {}
fits_config = FitsConfig(paths=paths, tables=tables)
app_config = ApplicationConfig(database=db_config, fits_files=fits_config)
assert app_config.database.host == "localhost"
assert app_config.fits_files.name == "example_fits"
assert app_config.fits_files.paths == [r"tests\unit\data\test.fits",r"tests\unit\data"]


def test_invalid_application_config():
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/test_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

import pytest


def test_import():
    """Smoke test: the MySQL adapter must be importable."""
    # The import itself is the check; it fails the test if it raises.
    from fits2db.adapters.mysql import MySQL
9 changes: 3 additions & 6 deletions tests/unit/test_fits_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,9 @@ def fits_file(create_sample_fits_file):
def test_get_table_valid(fits_file):
"""Test getting a valid table from the FITS file."""
table = fits_file.get_table(SAMPLE_TABLE_NAME)
assert 'name' in table
assert 'data' in table
assert 'meta' in table
assert table['name'] == SAMPLE_TABLE_NAME
assert isinstance(table['data'], pd.DataFrame)
assert isinstance(table['meta'], pd.DataFrame)
assert table.name == SAMPLE_TABLE_NAME
assert isinstance(table.data, pd.DataFrame)
assert isinstance(table.meta, pd.DataFrame)

def test_get_table_invalid_table_name(fits_file):
"""Test getting a table with an invalid name from the FITS file."""
Expand Down

0 comments on commit 93d3c87

Please sign in to comment.