Commit b979209

Run black formatting on hpc-prune.py.
ghuls committed Nov 21, 2022
1 parent b69accb commit b979209
Showing 1 changed file with 41 additions and 29 deletions.
scripts/hpc-prune.py: 70 changes (41 additions, 29 deletions)

@@ -36,7 +36,9 @@ def create_logging_handler(debug: bool) -> logging.Handler:
     # to DEBUG, information will still be outputted. In addition, errors and warnings are more
     # severe than info and therefore will always be outputted to the log.
     ch.setLevel(logging.DEBUG if debug else logging.INFO)
-    ch.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
+    ch.setFormatter(
+        logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+    )
     return ch


@@ -49,19 +51,23 @@ def __exit__(*x):


 def create_argument_parser():
-    parser = argparse.ArgumentParser(prog='hpc-prune')
-    parser.add_argument('config_filename',
-                        type=str, default=CONFIG_FILENAME,
-                        help='Name of configuration file to use.')
-    parser.add_argument('-i', '--input',
-                        type=str,
-                        help='Input file/stream.')
-    parser.add_argument('-o', '--output',
-                        type=str,
-                        help='Output file/stream.')
-    parser.add_argument('--num_workers',
-                        type=int, default=cpu_count(),
-                        help='The number of workers to use. Only valid if using dask_multiprocessing, custom_multiprocessing or local as mode. (default: {}).'.format(cpu_count()))
+    parser = argparse.ArgumentParser(prog="hpc-prune")
+    parser.add_argument(
+        "config_filename",
+        type=str,
+        default=CONFIG_FILENAME,
+        help="Name of configuration file to use.",
+    )
+    parser.add_argument("-i", "--input", type=str, help="Input file/stream.")
+    parser.add_argument("-o", "--output", type=str, help="Output file/stream.")
+    parser.add_argument(
+        "--num_workers",
+        type=int,
+        default=cpu_count(),
+        help="The number of workers to use. Only valid if using dask_multiprocessing, custom_multiprocessing or local as mode. (default: {}).".format(
+            cpu_count()
+        ),
+    )

     return parser

@@ -76,40 +82,46 @@ def run(args):
     cfg = ConfigParser()
     cfg.read(args.config_filename)

-    in_fname = cfg['data']['modules'] if not args.input else args.input
+    in_fname = cfg["data"]["modules"] if not args.input else args.input
     LOGGER.info("Loading modules from {}.".format(in_fname))
     # Loading from YAML is extremely slow, so this is a potential spot for performance improvement.
     # Possible improvements are switching to JSON or using a CLoader:
     # https://stackoverflow.com/questions/27743711/can-i-speedup-yaml
-    if in_fname.endswith('.yaml'):
+    if in_fname.endswith(".yaml"):
         modules = load_from_yaml(in_fname)
     else:
-        with open(in_fname, 'rb') as f:
+        with open(in_fname, "rb") as f:
             modules = pickle.load(f)
     # Filter out modules with too few genes.
-    min_genes = int(cfg['parameters']['min_genes'])
+    min_genes = int(cfg["parameters"]["min_genes"])
     modules = list(filter(lambda m: len(m) >= min_genes, modules))

     LOGGER.info("Loading databases.")

     def name(fname):
         return os.path.splitext(os.path.basename(fname))[0]
-    db_fnames = list(mapcat(glob.glob, cfg['data']['databases'].split(";")))
+
+    db_fnames = list(mapcat(glob.glob, cfg["data"]["databases"].split(";")))
     dbs = [RankingDatabase(fname=fname, name=name(fname)) for fname in db_fnames]

     LOGGER.info("Calculating regulons.")
-    motif_annotations_fname = cfg['data']['motif_annotations']
-    mode= cfg['parameters']['mode']
+    motif_annotations_fname = cfg["data"]["motif_annotations"]
+    mode = cfg["parameters"]["mode"]
     with ProgressBar() if mode == "dask_multiprocessing" else NoProgressBar():
-        df = prune2df(dbs, modules, motif_annotations_fname,
-                      rank_threshold=int(cfg['parameters']['rank_threshold']),
-                      auc_threshold=float(cfg['parameters']['auc_threshold']),
-                      nes_threshold=float(cfg['parameters']['nes_threshold']),
-                      client_or_address=mode,
-                      module_chunksize=cfg['parameters']['chunk_size'],
-                      num_workers=args.num_workers)
+        df = prune2df(
+            dbs,
+            modules,
+            motif_annotations_fname,
+            rank_threshold=int(cfg["parameters"]["rank_threshold"]),
+            auc_threshold=float(cfg["parameters"]["auc_threshold"]),
+            nes_threshold=float(cfg["parameters"]["nes_threshold"]),
+            client_or_address=mode,
+            module_chunksize=cfg["parameters"]["chunk_size"],
+            num_workers=args.num_workers,
+        )

     LOGGER.info("Writing results to file.")
-    df.to_csv(cfg['parameters']['output'] if not args.output else args.output)
+    df.to_csv(cfg["parameters"]["output"] if not args.output else args.output)


 if __name__ == "__main__":
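
The diff shows every config key the script reads, so a sketch of a matching configuration may help readers trying to run it. The section and option names below come straight from the cfg[...] reads in the diff; every value is a placeholder, not the project's documented defaults:

# Hypothetical hpc-prune configuration, reconstructed from the cfg[...] reads
# in the diff above; section/option names are real, all values are placeholders.
from configparser import ConfigParser

EXAMPLE_CONFIG = """
[data]
modules = modules.yaml
databases = db_dir/*.feather
motif_annotations = motifs.tbl

[parameters]
min_genes = 20
rank_threshold = 5000
auc_threshold = 0.05
nes_threshold = 3.0
mode = custom_multiprocessing
chunk_size = 100
output = motifs.csv
"""

cfg = ConfigParser()
cfg.read_string(EXAMPLE_CONFIG)
assert int(cfg["parameters"]["min_genes"]) == 20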
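
Similarly, a sketch of driving the script's own entry points, assuming create_argument_parser and run are in scope (for example, inside the script itself); hpc-prune.ini, modules.yaml and motifs.csv are placeholder names:

# Hypothetical invocation; assumes create_argument_parser and run from
# hpc-prune.py are in scope, and the file names are placeholders.
parser = create_argument_parser()
args = parser.parse_args(
    ["hpc-prune.ini", "-i", "modules.yaml", "-o", "motifs.csv", "--num_workers", "4"]
)
run(args)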

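The comment inside run() flags YAML loading as slow and points at a CLoader. A minimal sketch of that suggestion using PyYAML, falling back to the pure-Python loader where libyaml is missing; load_modules_fast is a hypothetical helper, not part of the script, which itself uses pySCENIC's load_from_yaml:

# Sketch of the CLoader idea from the comment in run(); load_modules_fast is a
# hypothetical helper, not part of hpc-prune.py.
import yaml

try:
    # CSafeLoader is the libyaml-backed loader; it is only available when
    # PyYAML was built with libyaml support.
    from yaml import CSafeLoader as SafeLoader
except ImportError:
    from yaml import SafeLoader

def load_modules_fast(fname):
    with open(fname, "r") as f:
        return yaml.load(f, Loader=SafeLoader)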
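Finally, the conditional with-statement in run() works because NoProgressBar is a no-op context manager; the second hunk's header shows its __exit__(*x) signature. A minimal sketch of that pattern, in which everything except __exit__ is an assumption:

# No-op context manager pattern; only the __exit__ signature is visible in the
# diff, so the rest of this class is an assumption.
class NoProgressBar:
    def __enter__(self):
        return self

    def __exit__(*x):
        # Accepting everything via *x covers self plus the three exception
        # arguments; returning None lets any exception propagate.
        pass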