Skip to content

Commit

Permalink
Allow loading datasets from disk using load_from_disk method. (#53)
Browse files Browse the repository at this point in the history
* feat: Allow loading datasets from disk using `load_from_disk` method.

* Fixing the type of error being catched.
  • Loading branch information
dmilcevski authored Dec 1, 2023
1 parent 80e952e commit 15279e7
Showing 1 changed file with 12 additions and 14 deletions.
26 changes: 12 additions & 14 deletions src/alignment/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
from typing import List, Literal, Optional

from datasets import DatasetDict, concatenate_datasets, load_dataset
from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
from datasets.builder import DatasetGenerationError

from .configs import DataArguments

Expand Down Expand Up @@ -145,20 +146,17 @@ def mix_datasets(dataset_mixer: dict, splits: Optional[List[str]] = None, shuffl
for ds, frac in dataset_mixer.items():
fracs.append(frac)
for split in splits:
try:
# Try first if dataset on a Hub repo
dataset = load_dataset(ds, split=split)
except DatasetGenerationError:
# If not, check local dataset
dataset = load_from_disk(os.path.join(ds, split))

if "train" in split:
raw_train_datasets.append(
load_dataset(
ds,
split=split,
)
)
raw_train_datasets.append(dataset)
elif "test" in split:
raw_val_datasets.append(
load_dataset(
ds,
split=split,
)
)
raw_val_datasets.append(dataset)
else:
raise ValueError(f"Split type {split} not recognized as one of test or train.")

Expand Down

0 comments on commit 15279e7

Please sign in to comment.