diff --git a/MuRaL/nn_utils.py b/MuRaL/nn_utils.py index f607af8..9ad8d0e 100644 --- a/MuRaL/nn_utils.py +++ b/MuRaL/nn_utils.py @@ -123,7 +123,7 @@ def run_time_view_model_predict_m(model, dataloader, criterion, device, n_class, total_loss = 0 batch_count = 0 get_batch_time_recode = 0 - get_batch_train_recode = 0 + get_batch_predict_recode = 0 step_time = time.time() get_batch_time = time.time() @@ -136,7 +136,7 @@ def run_time_view_model_predict_m(model, dataloader, criterion, device, n_class, print("get 500 batch used time: ", get_batch_time_recode) get_batch_time_recode = 0 - batch_train_time = time.time() + batch_predict_time = time.time() cat_x = cat_x.to(device) cont_x = cont_x.to(device) distal_x = distal_x.to(device) @@ -155,10 +155,10 @@ def run_time_view_model_predict_m(model, dataloader, criterion, device, n_class, print(f"Batch Number: {batch_count}; Mean Time of 500 batch: {(time.time()-step_time)}") step_time = time.time() - get_batch_train_recode += time.time() - batch_train_time + get_batch_predict_recode += time.time() - batch_predict_time if batch_count % 500 == 0: - print("training 500 batch used time:", get_batch_train_recode) - get_batch_train_recode = 0 + print("training 500 batch used time:", get_batch_predict_recode) + get_batch_predict_recode = 0 get_batch_time = time.time() if device == torch.device('cpu'): diff --git a/MuRaL/run_predict.py b/MuRaL/run_predict.py index 0fbf1db..61e43a9 100644 --- a/MuRaL/run_predict.py +++ b/MuRaL/run_predict.py @@ -113,12 +113,6 @@ def parse_arguments(parser): between RAM memory and preprocessing speed. It is recommended to use 300k. Default: 300000.""" ).strip()) - optional.add_argument('--sampled_segments', type=int, metavar='INT', default=[10], nargs='+', - help=textwrap.dedent(""" - Number of segments chosen for generating samples for batches in DataLoader. - Default: 10. - """ ).strip()) - optional.add_argument('--pred_batch_size', metavar='INT', default=16, help=textwrap.dedent(""" Size of mini batches for prediction. Default: 16. @@ -208,7 +202,7 @@ def main(): ref_genome= args.ref_genome pred_batch_size = args.pred_batch_size - sampled_segments = args.sampled_segments + sampled_segments = 1 # Output file path pred_file = args.pred_file diff --git a/MuRaL/run_train_TL_raytune.py b/MuRaL/run_train_TL_raytune.py index 2c63c74..a999794 100644 --- a/MuRaL/run_train_TL_raytune.py +++ b/MuRaL/run_train_TL_raytune.py @@ -441,7 +441,9 @@ def main(): if not sampled_segments: sampled_segments = args.sampled_segments = para_read_from_config('sampled_segments', config) - + else: + sampled_segments = args.sampled_segments[0] + args.seq_only = config['seq_only'] diff --git a/MuRaL/run_train_raytune.py b/MuRaL/run_train_raytune.py index e4e3f1a..8a49ea9 100644 --- a/MuRaL/run_train_raytune.py +++ b/MuRaL/run_train_raytune.py @@ -209,7 +209,7 @@ def parse_arguments(parser): Default: 0.25. """ ).strip()) - learn_args.add_argument('--segment_center', type=int, metavar='INT', default=300000, + learn_args.add_argument('--segment_center', type=int, metavar='INT', default=300000, help=textwrap.dedent(""" The maximum encoding unit of the sequence. It affects trade-off between RAM memory and preprocessing speed. It is recommended to use 300k. diff --git a/dirichlet.patch b/dirichlet.patch deleted file mode 100644 index ea7c258..0000000 --- a/dirichlet.patch +++ /dev/null @@ -1,22 +0,0 @@ -diff --git a/setup.py b/../../script_v0/dirichlet_python/setup.py -index f6ce80f..77dfced 100644 ---- a/setup.py -+++ b/../../script_v0/dirichlet_python/setup.py -@@ -29,11 +29,11 @@ setuptools.setup( - ], - python_requires='>=3.6', - install_requires = [ -- 'numpy>=1.14.2', -- 'scipy>=1.0.0', -- 'scikit-learn>=0.19.1', -- 'jax', -- 'jaxlib', -- 'autograd', -+ 'numpy>=1.14.2' -+ 'scipy>=1.0.0' -+ 'scikit-learn>=0.19.1' -+ 'jax' -+ 'jaxlib' -+ 'autograd' - ] - ) diff --git a/dirichlet_install.sh b/dirichlet_install.sh deleted file mode 100644 index 11e0e06..0000000 --- a/dirichlet_install.sh +++ /dev/null @@ -1,10 +0,0 @@ -# download dirchlet package -# git clone git@github.com:dirichletcal/dirichlet_python.git -cd dirichlet_python - -# patch -patch -p1 < ../dirichlet.patch -# install dirchlet package -python setup.py install -cd .. - diff --git a/dirichlet_python/dirichletcal/calib/multinomial.py b/dirichlet_python/dirichletcal/calib/multinomial.py index e41f97b..6616938 100644 --- a/dirichlet_python/dirichletcal/calib/multinomial.py +++ b/dirichlet_python/dirichletcal/calib/multinomial.py @@ -274,7 +274,7 @@ def _newton_update(weights_0, X, XX_T, target, k, method_, maxiter=int(1024), updates = gradient / hessian else: try: - inverse = scipy.linalg.pinv2(hessian) + inverse = scipy.linalg.pinv(hessian) updates = np.matmul(inverse, gradient) except (raw_np.linalg.LinAlgError, ValueError) as err: logging.error(err) diff --git a/dirichlet_python/setup.py b/dirichlet_python/setup.py index 77dfced..37767a4 100644 --- a/dirichlet_python/setup.py +++ b/dirichlet_python/setup.py @@ -29,11 +29,11 @@ ], python_requires='>=3.6', install_requires = [ - 'numpy>=1.14.2' - 'scipy>=1.0.0' - 'scikit-learn>=0.19.1' - 'jax' - 'jaxlib' + 'numpy>=1.14.2', + 'scipy>=1.0.0', + 'scikit-learn>=0.19.1', + 'jax', + 'jaxlib', 'autograd' ] ) diff --git a/environment.yml b/environment.yml index d1afaf9..14348ca 100644 --- a/environment.yml +++ b/environment.yml @@ -1,11 +1,13 @@ channels: + - pytorch - bioconda - conda-forge - defaults dependencies: - python=3.8.5 - pip=22.0.4 - - numpy=1.21.2 + - numpy=1.23.5 + - pandas=1.4.1 - cudatoolkit=11.3 - pytorch=1.10.2 - pybigwig=0.3.17 @@ -24,3 +26,4 @@ dependencies: - jaxlib==0.4.13 - autograd==1.6.2 - protobuf==3.20.0 + - setuptools==69.0.2 diff --git a/environment_cpu.yml b/environment_cpu.yml index be74ffb..eb1df07 100644 --- a/environment_cpu.yml +++ b/environment_cpu.yml @@ -1,11 +1,13 @@ channels: + - pytorch - bioconda - conda-forge - defaults dependencies: - python=3.8.5 - pip=22.0.4 - - numpy=1.21.2 + - numpy=1.23.5 + - pandas=1.4.1 - pytorch=1.10.2 - cpuonly=2.0 - pybigwig=0.3.17 @@ -24,3 +26,4 @@ dependencies: - jaxlib==0.4.13 - autograd==1.6.2 - protobuf==3.20.0 + - setuptools==69.0.2 diff --git a/setup.py b/setup.py index de80e6f..84f3422 100644 --- a/setup.py +++ b/setup.py @@ -4,8 +4,12 @@ from setuptools import setup, find_packages -import os -os.system('bash dirichlet_install.sh') +import subprocess +try: + subprocess.check_call(['pip', 'install', '.'], cwd='dirichlet_python') +except subprocess.CalledProcessError as e: + print(f"Error installing dirichlet_python: {e}") + exit(1) def get_version(): try: