diff --git a/Dockerfile b/Dockerfile index 71f2468..8354865 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,7 +7,7 @@ RUN apt-get update && \ apt-get install -y git build-essential \ liblapack-dev libopenblas-dev libgl1 libxrender1 -RUN git clone git@github.com:scil-vital/TractOracleNet.git +RUN git clone https://github.com/scil-vital/TractOracleNet.git WORKDIR /TractOracleNet diff --git a/TractOracleNet/runners/predictor.py b/TractOracleNet/runners/predictor.py index 20ed3f1..c8476a3 100644 --- a/TractOracleNet/runners/predictor.py +++ b/TractOracleNet/runners/predictor.py @@ -10,7 +10,7 @@ from tqdm import tqdm from scilpy.io.utils import ( - assert_inputs_exist, assert_outputs_exist) + assert_inputs_exist, assert_outputs_exist, add_overwrite_arg) from TractOracleNet.utils import get_data, save_filtered_streamlines from TractOracleNet.models.utils import get_model @@ -37,7 +37,7 @@ def __init__( self.threshold = train_dto['threshold'] self.batch_size = train_dto['batch_size'] self.out = train_dto['out'] - self.save_rejected = train_dto['rejected'] + self.rejected = train_dto['rejected'] self.nofilter = train_dto['nofilter'] def predict(self, model, sft): @@ -133,22 +133,25 @@ def run(self): predictions = self.predict(model, sft) # Save the filtered streamlines - if not self.nofilter and not self.dense: + if not self.dense: # Fetch the streamlines that passed the gauntlet - ids = np.argwhere( - predictions > self.threshold).squeeze() + if self.nofilter: + ids = np.arange(0, len(predictions)) + else: + # Save the filtered streamlines + print('Kept {}/{} streamlines ({}%).'.format(len(ids), + len(sft), (len(ids) / len(sft) * 100))) - new_sft = StatefulTractogram.from_sft(sft[ids].streamlines, sft) + ids = np.argwhere( + predictions > self.threshold).squeeze() - # Save the filtered streamlines - print('Kept {}/{} streamlines ({}%).'.format(len(ids), - len(sft), (len(ids) / len(sft) * 100))) + new_sft = StatefulTractogram.from_sft(sft[ids].streamlines, sft) # Save the streamlines save_filtered_streamlines(new_sft, predictions[ids], self.out) # Save the streamlines that rejected - if self.save_rejected: + if self.rejected: # Fetch the streamlines that rejected rejected_ids = np.setdiff1d(np.arange(predictions.shape[0]), ids) @@ -158,7 +161,7 @@ def run(self): # Save the streamlines save_filtered_streamlines( - new_sft, predictions[rejected_ids], self.save_rejected) + new_sft, predictions[rejected_ids], self.rejected) else: # Save all streamlines sft.data_per_point['score'] = predictions @@ -199,6 +202,7 @@ def _build_arg_parser(parser): ' Streamlines\' endpoints should be uniformized for' ' best visualization.') + add_overwrite_arg(parser) def parse_args(): """ Filter a tractogram. """ diff --git a/install.sh b/install.sh index 186d731..b35a239 100755 --- a/install.sh +++ b/install.sh @@ -13,12 +13,8 @@ if [ -x "$(command -v nvidia-smi)" ]; then FOUND_CUDA=$(nvidia-smi | grep "CUDA Version" | awk '{print $9}' | sed 's/\.//g') if (( $FOUND_CUDA == 116 )); then CUDA_VERSION="cu116" - elif (( $FOUND_CUDA == 117 )); then + elif (( $FOUND_CUDA >= 117 )); then CUDA_VERSION="cu117" - elif (( $FOUND_CUDA == 118 )); then - CUDA_VERSION="cu118" - elif (( $FOUND_CUDA == 121 )); then - CUDA_VERSION="cu121" else CUDA_VERSION="cpu" echo "CUDA version ${FOUND_CUDA} is not compatible. Installing PyTorch without CUDA support."