diff --git a/just_prs.py b/just_prs.py index 4a804a5..0b88d08 100644 --- a/just_prs.py +++ b/just_prs.py @@ -12,17 +12,17 @@ INVERS = "invers" class CravatPostAggregator (BasePostAggregator): - prs:dict = {} - prs_names:list = [] + prs: dict = {} + prs_names: list = [] # prs5_rsids = [] - sql_get_prs:str = """SELECT name, title, total, invers FROM prs;""" + sql_get_prs: str = """SELECT name, title, total, invers FROM prs;""" - def check(self): + def check(self) -> bool: return True - def setup (self): + def setup(self) -> None: self.sql_file:str = str(Path(__file__).parent) + "/data/prs.sqlite" if Path(self.sql_file).exists(): self.prsconn:sqlite3.Connection = sqlite3.connect(self.sql_file) @@ -53,7 +53,7 @@ def setup (self): self.result_cursor.execute("DELETE FROM prs;") - def get_prs_dataframe(self, name): + def get_prs_dataframe(self, name: str) -> pl.DataFrame: import platform sql:str = f"SELECT pos, chrom, effect_allele, weight FROM prs, position, weights WHERE prs.name = '{name}' AND prs.id = weights.prsid AND weights.posid = position.id" ol_pl = platform.platform() @@ -61,12 +61,12 @@ def get_prs_dataframe(self, name): conn_url = f"sqlite://{urllib.parse.quote(self.sql_file)}" else: conn_url = f"sqlite://{self.sql_file}" - return pl.read_sql(sql, conn_url) + return pl.read_database_uri(sql, conn_url) - def calculate_prs(self, data_df, name): + def calculate_prs(self, data_df: pl.DataFrame, name: str) -> tuple: prs_df:pl.DataFrame = self.get_prs_dataframe(name) - prs_df = prs_df.with_column((pl.col('chrom') + pl.col('pos').cast(pl.datatypes.Utf8)).alias("key")) + prs_df = prs_df.with_columns((pl.col('chrom') + pl.col('pos').cast(pl.datatypes.Utf8)).alias("key")) unite:pl.DataFrame = data_df.join(prs_df, left_on='key', right_on="key") unite1 = unite.filter(pl.col("A") == pl.col("effect_allele")) unite2 = unite.filter(pl.col("B") == pl.col("effect_allele")) @@ -79,12 +79,12 @@ def calculate_prs(self, data_df, name): return float(res1.item()) + float(res2.item()), 
unite.shape[0] - def process_file(self): + def process_file(self) -> None: self._close_db_connection() data_df = self.get_df("variant", None, 0) data_df = data_df.select(['base__pos', 'vcfinfo__zygosity', 'base__ref_base', 'base__alt_base', 'base__chrom']) - data_df = data_df.with_column(pl.col('vcfinfo__zygosity').fill_null("het")) - data_df = data_df.with_column((pl.col('base__chrom') + pl.col('base__pos').cast(pl.datatypes.Utf8)).alias("key")) + data_df = data_df.with_columns(pl.col('vcfinfo__zygosity').fill_null("het")) + data_df = data_df.with_columns((pl.col('base__chrom') + pl.col('base__pos').cast(pl.datatypes.Utf8)).alias("key")) het_zygot = data_df.filter(pl.col('vcfinfo__zygosity') == 'het') het_zygot = het_zygot.with_columns([pl.col('base__ref_base').alias("A"), pl.col('base__alt_base').alias("B")]) @@ -100,7 +100,7 @@ def process_file(self): self._open_db_connection() - def cleanup (self): + def cleanup(self) -> None: if self.result_cursor is not None: self.result_cursor.close() if self.result_conn is not None: @@ -169,7 +169,7 @@ def get_percent(self, name:str, value:float) -> float: return min_percent - def postprocess(self): + def postprocess(self) -> None: sql:str = """ INSERT INTO prs (name, sum, avg, count, title, total, percent, fraction, invers) VALUES (?,?,?,?,?,?,?,?,?);""" for name in self.prs_names: avg:float = 0 diff --git a/just_prs.yml b/just_prs.yml index fa1507e..6a464dc 100644 --- a/just_prs.yml +++ b/just_prs.yml @@ -1,5 +1,5 @@ title: Prs postagregator -version: 0.1.1 +version: 0.1.2 data_version: 0.1.1 requires_opencravat: '>=1.8.1' type: postaggregator @@ -16,6 +16,8 @@ input_columns: - base__ref_base - base__chrom - vcfinfo__zygosity +pypi_dependency: + - polars>=0.19.0 tags: - prs - longevity @@ -36,3 +38,4 @@ developer: release_note: 0.1.0: initial commit 0.1.1: updated database, added new prs + 0.1.2: updated dependencies and changed polars code according to its new version