Commit: lint
nsheff committed Jan 14, 2025
1 parent 1592df3 commit 0257887
Showing 10 changed files with 120 additions and 75 deletions.
11 changes: 8 additions & 3 deletions interactive_tests.py
@@ -219,6 +219,7 @@


from refget import fasta_file_to_seqcol

fasta_file_to_seqcol("demo_fasta/demo0.fa")
fasta_file_to_seqcol("/home/nsheff/sandbox/HG002.alt.pat.f1_v2.unmasked.fa.gz")

@@ -235,28 +236,33 @@
sha512t24u_digest_bytes("test")


-var = 'VB84F57JQBHY8SDPXJQ06P1STTEV8ZUPLX5T76OYSH88DM8IV9A9KSYJ0DBLQM665UZZC5033LW26RR1AQFKQ84BGHZLFZKRRBWN877GB1IP15KVJ9LX6Z1K0087WJGFB0HYKYQQZ0R24A77M6PY1ZY9E3TAAO7SI2UEI3MORCQUUXFB59L54M01NF8IWYA4579DL47VJ9DVAPUVJHX9FIK4BUN71UMOOIML9UKLIGDP9N80JJMIEJQKM85M8BEWQF8KS8DN77VVWBMO76VR2K8EHTBF403IWYAF3J2Z9WEDMQKAA0IJKGCHLYK7Y3WSODLDNWCRUU3UGSVPE7CJFZI6O9RMLARB1Y66CZ125L8ERKFU0UY53SNO5DNGGC0D5DOGH8MCZQYRJXELOQNA7KOHLVPBMRNQYVP1A49I3H2Y6DE8FG0WAXIZ6RKFNEBU4ES3X18E79KRJO2DKCXAYYMRPM4WMX8WIC9EP4K6Q07T7UM7G4S4TG31FS9WOUX9BIVFL0642307KV2SFG9YNH5IZB9IJ4TUMM1D25NBBECUMMM28JGZQ2765SZOYRL3BVIZBU1G8NN8Z7N2WEK08FV22LA5YE7GB6GTCEH4ISA2WBTBUEJH65V3MX8EVEU2FDLZKI02O27N3GQT556ZI2YY44GZDWV1Z21RWOWM411X4FFJ2BZ7LQAG5I9J3U4BIF7F3ESKOOIHG388V0PG95ZF5AW1IGD2T6VM9TPJN3HRNWGMHAU3M6O1C6HJBMHB6P26CZJEBZ1K75L35KV9S9UU4NYUJH0KADJNXFI9WVRI7AG89OOVWXQ2GSBT4QUYJW1UZDJ53JQ8M1FVS8J3KTVCSXUW97M8WCNNKQOFB7LHC4YHUZSRKA103L6DPBQG3MTAKPZ9VW5PTQ9QXFX5TMJHU5YOTJAFZ80ISSPX5ZUPABZ1SUZWHRR951CBZ3TYYO88BFNLGR1HKSCZZWG471PPW561NLGINKKBBD9P'
+var = "VB84F57JQBHY8SDPXJQ06P1STTEV8ZUPLX5T76OYSH88DM8IV9A9KSYJ0DBLQM665UZZC5033LW26RR1AQFKQ84BGHZLFZKRRBWN877GB1IP15KVJ9LX6Z1K0087WJGFB0HYKYQQZ0R24A77M6PY1ZY9E3TAAO7SI2UEI3MORCQUUXFB59L54M01NF8IWYA4579DL47VJ9DVAPUVJHX9FIK4BUN71UMOOIML9UKLIGDP9N80JJMIEJQKM85M8BEWQF8KS8DN77VVWBMO76VR2K8EHTBF403IWYAF3J2Z9WEDMQKAA0IJKGCHLYK7Y3WSODLDNWCRUU3UGSVPE7CJFZI6O9RMLARB1Y66CZ125L8ERKFU0UY53SNO5DNGGC0D5DOGH8MCZQYRJXELOQNA7KOHLVPBMRNQYVP1A49I3H2Y6DE8FG0WAXIZ6RKFNEBU4ES3X18E79KRJO2DKCXAYYMRPM4WMX8WIC9EP4K6Q07T7UM7G4S4TG31FS9WOUX9BIVFL0642307KV2SFG9YNH5IZB9IJ4TUMM1D25NBBECUMMM28JGZQ2765SZOYRL3BVIZBU1G8NN8Z7N2WEK08FV22LA5YE7GB6GTCEH4ISA2WBTBUEJH65V3MX8EVEU2FDLZKI02O27N3GQT556ZI2YY44GZDWV1Z21RWOWM411X4FFJ2BZ7LQAG5I9J3U4BIF7F3ESKOOIHG388V0PG95ZF5AW1IGD2T6VM9TPJN3HRNWGMHAU3M6O1C6HJBMHB6P26CZJEBZ1K75L35KV9S9UU4NYUJH0KADJNXFI9WVRI7AG89OOVWXQ2GSBT4QUYJW1UZDJ53JQ8M1FVS8J3KTVCSXUW97M8WCNNKQOFB7LHC4YHUZSRKA103L6DPBQG3MTAKPZ9VW5PTQ9QXFX5TMJHU5YOTJAFZ80ISSPX5ZUPABZ1SUZWHRR951CBZ3TYYO88BFNLGR1HKSCZZWG471PPW561NLGINKKBBD9P"

import random
import timeit
import string

strs = []
for i in range(1000):
-    strs.append(''.join(random.choices(string.ascii_uppercase + string.digits, k=1000)))
+    strs.append("".join(random.choices(string.ascii_uppercase + string.digits, k=1000)))


# Define the functions to benchmark
def benchmark_sha512t24u_digest():
    for var in strs:
        sha512t24u_digest(var)


def benchmark_gc_count_checksum():
    for var in strs:
        gc_count.checksum_from_str(var).sha512


def benchmark_gc_sha512_only():
    for var in strs:
        gc_count.sha512t24u_digest(var)


# Benchmark the functions
time_sha512t24u_digest = timeit.timeit(benchmark_sha512t24u_digest, number=1000)
time_gc_count_checksum = timeit.timeit(benchmark_gc_count_checksum, number=1000)
@@ -265,4 +271,3 @@ def benchmark_gc_sha512_only():
print(f"sha512t24u_digest: {time_sha512t24u_digest} seconds")
print(f"gc_count.checksum_from_str().sha512: {time_gc_count_checksum} seconds")
print(f"gc_count.sha512t24u_digest: {time_gc_512_only} seconds")
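
# For reference, the quantity benchmarked above is the GA4GH "sha512t24u"
# digest: SHA-512, truncated to 24 bytes, then base64url-encoded. A minimal
# pure-Python sketch (the name `sha512t24u` and the demo call below are
# illustrative, not part of this commit):

import base64
import hashlib


def sha512t24u(seq: str, offset: int = 24) -> str:
    # Hash, keep the first `offset` bytes, then base64url-encode;
    # 24 bytes encode to a 32-character string with no padding.
    digest = hashlib.sha512(seq.encode("utf-8")).digest()
    return base64.urlsafe_b64encode(digest[:offset]).decode("ascii")


print(sha512t24u("test"))  # a 32-character base64url digest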

59 changes: 35 additions & 24 deletions refget/agents.py
@@ -59,7 +59,7 @@ def load_json(source):
    :return: The loaded JSON as a dictionary.
    """
    if os.path.isfile(source):
-        with open(source, 'r', encoding='utf-8') as file:
+        with open(source, "r", encoding="utf-8") as file:
            return json.load(file)
    else:
        try:
@@ -71,22 +71,25 @@ def load_json(source):
            raise e



class SeqColAgent(object):
    def __init__(self, engine, inherent_attrs=None):
        self.engine = engine
        self.inherent_attrs = inherent_attrs

-    def get(self, digest: str, return_format: str = "level2", attribute: str = None, itemwise_limit: int = None) -> SequenceCollection:
+    def get(
+        self,
+        digest: str,
+        return_format: str = "level2",
+        attribute: str = None,
+        itemwise_limit: int = None,
+    ) -> SequenceCollection:
        with Session(self.engine) as session:
-            statement = select(SequenceCollection).where(
-                SequenceCollection.digest == digest
-            )
+            statement = select(SequenceCollection).where(SequenceCollection.digest == digest)
            results = session.exec(statement)
            seqcol = results.one_or_none()
            if not seqcol:
                raise ValueError(f"SequenceCollection with digest '{digest}' not found")
-            if attribute: 
+            if attribute:
                return getattr(seqcol, attribute).value
            elif return_format == "level2":
                return seqcol.level2()
@@ -176,7 +179,10 @@ def list_by_offset(self, limit=50, offset=0):
            list_res = session.exec(list_stmt)
            count = cnt_res.one()
            seqcols = list_res.all()
-            return {"pagination": { "page": int(offset/limit), "page_size": limit, "total": count}, "results": seqcols}
+            return {
+                "pagination": {"page": int(offset / limit), "page_size": limit, "total": count},
+                "results": seqcols,
+            }

    def list(self, page_size=100, cursor=None):
        with Session(self.engine) as session:
@@ -189,9 +195,7 @@ def list(self, page_size=100, cursor=None):
                )
            else:
                list_stmt = (
-                    select(SequenceCollection)
-                    .limit(page_size)
-                    .order_by(SequenceCollection.digest)
+                    select(SequenceCollection).limit(page_size).order_by(SequenceCollection.digest)
                )
            cnt_stmt = select(func.count(SequenceCollection.digest))
            cnt_res = session.exec(cnt_stmt)
@@ -270,9 +274,7 @@ def add_from_fasta_pep(self, pep: peppy.Project, fa_root):
        for s in pep.samples:
            file_path = os.path.join(fa_root, s.fasta)
            print(f"Fasta to be loaded: Name: {s.sample_name} File path: {file_path}")
-            pangenome_obj[s.sample_name] = self.parent.seqcol.add_from_fasta_file(
-                file_path
-            )
+            pangenome_obj[s.sample_name] = self.parent.seqcol.add_from_fasta_file(file_path)

        p = build_pangenome_model(pangenome_obj)
        return self.add(p)
@@ -285,7 +287,11 @@ def list_by_offset(self, limit=50, offset=0):
            list_res = session.exec(list_stmt)
            count = cnt_res.one()
            seqcols = list_res.all()
-            return {"pagination": { "page": int(offset/limit), "page_size": limit, "total": count}, "results": seqcols}
+            return {
+                "pagination": {"page": int(offset / limit), "page_size": limit, "total": count},
+                "results": seqcols,
+            }


class AttributeAgent(object):
    def __init__(self, engine):
@@ -312,16 +318,17 @@ def list(self, attribute_type, offset=0, limit=50):
            list_res = session.exec(list_stmt)
            count = cnt_res.one()
            seqcols = list_res.all()
-            return {"pagination": { "page": offset*limit, "page_size": limit, "total": count}, "results": seqcols}
+            return {
+                "pagination": {"page": offset * limit, "page_size": limit, "total": count},
+                "results": seqcols,
+            }

    def search(self, attribute_type, digest, offset=0, limit=50):
        Attribute = ATTR_TYPE_MAP[attribute_type]
        with Session(self.engine) as session:
            list_stmt = (
                select(SequenceCollection)
-                .where(
-                    getattr(SequenceCollection, f"{attribute_type}_digest") == digest
-                )
+                .where(getattr(SequenceCollection, f"{attribute_type}_digest") == digest)
                .offset(offset)
                .limit(limit)
            )
@@ -332,7 +339,10 @@ def search(self, attribute_type, digest, offset=0, limit=50):
            list_res = session.exec(list_stmt)
            count = cnt_res.one()
            seqcols = list_res.all()
-            return {"pagination": { "page": offset*limit, "page_size": limit, "total": count}, "results": seqcols}
+            return {
+                "pagination": {"page": offset * limit, "page_size": limit, "total": count},
+                "results": seqcols,
+            }


class RefgetDBAgent(object):
@@ -341,7 +351,6 @@ class RefgetDBAgent(object):
    """

    def __init__(
-<<<<<<< HEAD
        self,
        engine: Optional[SqlalchemyDatabaseEngine] = None,
        postgres_str: Optional[str] = None,
@@ -358,7 +367,7 @@ def __init__(
        POSTGRES_DB = os.getenv("POSTGRES_DB")
        POSTGRES_USER = os.getenv("POSTGRES_USER")
        POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD")
-        schema=f'{SCHEMA_FILEPATH}/seqcol.json',
+        schema = (f"{SCHEMA_FILEPATH}/seqcol.json",)
        postgres_str = URL.create(
            "postgresql",
            username=POSTGRES_USER,
@@ -392,7 +401,9 @@ def __init__(
                self.inherent_attrs = self.schema_dict["ga4gh"]["inherent"]
            except KeyError:
                self.inherent_attrs = inherent_attrs
-                _LOGGER.warning(f"No 'inherent' attributes found in schema; using defaults: {inherent_attrs}")
+                _LOGGER.warning(
+                    f"No 'inherent' attributes found in schema; using defaults: {inherent_attrs}"
+                )
        else:
            _LOGGER.warning("No schema provided; using defaults")
            self.schema_dict = None
@@ -449,6 +460,6 @@ def truncate(self):
            result = session.exec(statement)
            statement = delete(SortedSequencesAttr)
            result = session.exec(statement)

            session.commit()
            return result1.rowcount
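
# The two list methods above page differently: list_by_offset() slices with
# offset/limit and wraps the rows in a pagination envelope, while list() uses
# keyset pagination on the digest column. A hedged standalone sketch of the
# keyset approach; `SequenceCollection` is the repo's SQLModel table, while
# the helper name and engine argument are illustrative:

from sqlmodel import Session, select


def list_keyset(engine, page_size=100, cursor=None):
    # Order by digest and resume strictly after the last digest seen; unlike
    # offset/limit, the scan stays stable while rows are being inserted.
    with Session(engine) as session:
        stmt = select(SequenceCollection).order_by(SequenceCollection.digest)
        if cursor:
            stmt = stmt.where(SequenceCollection.digest > cursor)
        return session.exec(stmt.limit(page_size)).all()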
38 changes: 24 additions & 14 deletions refget/clients.py
@@ -3,10 +3,12 @@

_LOGGER = logging.getLogger(__name__)


class SequencesClient(object):
    """
    A client for interacting with a refget sequences API.
    """

    def __init__(self, seq_urls=["https://www.ebi.ac.uk/ena/cram/sequence/"]):
        self.seq_urls = seq_urls

@@ -84,17 +86,17 @@ def list_collections(self, page=None, page_size=None, attribute=None, attribute_
        """
        params = {}
        if page is not None:
-            params['page'] = page
+            params["page"] = page
        if page_size is not None:
-            params['page_size'] = page_size
+            params["page_size"] = page_size

        if attribute and attribute_digest:
            endpoint = f"/list/collections/{attribute}/{attribute_digest}"
        else:
            endpoint = "/list/collections"

        return _try_urls(self.seqcol_api_urls, endpoint, params=params)

    def list_attributes(self, attribute, page=None, page_size=None):
        """
        Lists all available values for a given attribute with optional paging support.
@@ -103,40 +105,43 @@ def list_attributes(self, attribute, page=None, page_size=None):
            attribute (str): The attribute to list values for.
            page (int, optional): The page number to retrieve. Defaults to None.
            page_size (int, optional): The number of items per page. Defaults to None.

        Returns:
            dict: The JSON response containing the list of available values for the attribute.
        """
        params = {}
        if page is not None:
-            params['page'] = page
+            params["page"] = page
        if page_size is not None:
-            params['page_size'] = page_size
+            params["page_size"] = page_size

        endpoint = f"/list/attributes/{attribute}"
        return _try_urls(self.seqcol_api_urls, endpoint, params=params)


class RefGetClient(SequencesClient, SeqColClient):
    """
-    A client for interacting with a refget API, for either 
+    A client for interacting with a refget API, for either
    sequences or sequence collections, or both.
    """

-    def __init__(self,
-                 seq_api_urls=["https://www.ebi.ac.uk/ena/cram/sequence"],
-                 seqcol_api_urls=["https://seqcolapi.databio.org"]):
+    def __init__(
+        self,
+        seq_api_urls=["https://www.ebi.ac.uk/ena/cram/sequence"],
+        seqcol_api_urls=["https://seqcolapi.databio.org"],
+    ):
        if seq_api_urls:
            SequencesClient.__init__(self, seq_api_urls)
        if seqcol_api_urls:
            SeqColClient.__init__(self, seqcol_api_urls)

    def __repr__(self):
        return f"<RefGetClient(seq_api_urls={self.seq_api_urls}, seqcol_api_urls={self.seqcol_api_urls})>"


# Utilities


def _wrap_response(response):
    """
    Wraps a response in a try/except block to catch any exceptions.
@@ -152,7 +157,8 @@ def _wrap_response(response):
        return response.json()
    except requests.exceptions.RequestException as e:
        raise RuntimeError(f"An error occurred: {e}")



def _try_urls(urls, endpoint, params=None):
    """
    Tries the list of URLs in succession until a successful response is received.
@@ -173,8 +179,12 @@ def _try_urls(urls, endpoint, params=None):
            result = _wrap_response(response)
            _LOGGER.info(f"Successful response from {base_url}")
            return result
-        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout, requests.exceptions.RequestException) as e:
+        except (
+            requests.exceptions.ConnectionError,
+            requests.exceptions.Timeout,
+            requests.exceptions.RequestException,
+        ) as e:
            _LOGGER.debug(f"Error from {base_url}: {e}")
            errors.append(f"Error from {base_url}: {e}")
    error_message = "All URLs failed:\n" + "\n".join(errors)
-    raise ConnectionError(error_message)
+    raise ConnectionError(error_message)
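
# The multi-line except above guards _try_urls' failover loop. A hedged,
# self-contained sketch of that pattern (the helper name, timeout, and URL
# joining are illustrative; the real function also logs each attempt):

import requests


def try_urls(urls, endpoint, params=None):
    # Try each base URL in order; return the first successful JSON body,
    # collecting per-URL errors so the final exception explains every failure.
    errors = []
    for base_url in urls:
        try:
            response = requests.get(base_url.rstrip("/") + endpoint, params=params, timeout=10)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            errors.append(f"Error from {base_url}: {e}")
    raise ConnectionError("All URLs failed:\n" + "\n".join(errors))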
1 change: 1 addition & 0 deletions refget/const.py
@@ -57,6 +57,7 @@ def _schema_path(name):
GTARS_INSTALLED = False
try:
    from gtars.digests import digest_fasta, sha512t24u_digest

    GTARS_INSTALLED = True
except ImportError:
    GTARS_INSTALLED = False
10 changes: 7 additions & 3 deletions refget/hash_functions.py
@@ -7,6 +7,7 @@
import binascii
from typing import Union


def trunc512_digest(seq, offset=24):
    digest = hashlib.sha512(seq.encode("utf-8")).digest()
    hex_digest = binascii.hexlify(digest[:offset])
@@ -39,15 +40,13 @@ def trunc512_to_ga4gh(trunc512):
    return _ga4gh_format(digest, digest_length)



# def trunc512_digest(seq, offset=24) -> str:
#     """Deprecated GA4GH digest function"""
#     digest = hashlib.sha512(seq.encode()).digest()
#     hex_digest = binascii.hexlify(digest[:offset])
#     return hex_digest.decode()



# def sha512t24u_digest_bytes(seq: Union[str, bytes], offset: int = 24) -> str:
#     """GA4GH digest function"""
#     if isinstance(seq, str):
@@ -56,6 +55,7 @@ def trunc512_to_ga4gh(trunc512):
#     tdigest_b64us = base64.urlsafe_b64encode(digest[:offset])
#     return tdigest_b64us.decode("ascii")


def py_sha512t24u_digest(seq: Union[str, bytes], offset: int = 24) -> str:
    """GA4GH digest function in python"""
    if isinstance(seq, str):
@@ -64,16 +64,20 @@ def py_sha512t24u_digest(seq: Union[str, bytes], offset: int = 24) -> str:
    tdigest_b64us = base64.urlsafe_b64encode(digest[:offset])
    return tdigest_b64us.decode("ascii")


def md5(seq):
    return hashlib.md5(seq.encode()).hexdigest()


from .const import GTARS_INSTALLED

if GTARS_INSTALLED:
    from gtars.digests import sha512t24u_digest as gtars_sha512t24u_digest

    sha512t24u_digest = gtars_sha512t24u_digest
else:

    def gtars_sha512t24u_digest(seq):
        raise Exception("gtars is not installed")

    sha512t24u_digest = py_sha512t24u_digest
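
# The module ends by binding sha512t24u_digest to the gtars (Rust) backend
# when available, else to the pure-Python fallback, so call sites don't change
# either way. A hedged usage sketch (assumes the refget package is importable):

from refget.hash_functions import md5, sha512t24u_digest, trunc512_digest

s = "ACGT"
print(sha512t24u_digest(s))  # gtars-backed if installed, else pure Python
print(trunc512_digest(s))    # older hex-encoded truncated SHA-512
print(md5(s))                # plain MD5 hex digest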