diff --git a/precondition/datamix_gemma/Pretokenization_for_Dolly,_MetaMath,_and_CodeAlpaca,_OpenWebMath.ipynb b/precondition/datamix_gemma/Pretokenization_for_Dolly,_MetaMath,_and_CodeAlpaca,_OpenWebMath.ipynb new file mode 100644 index 0000000..b820dea --- /dev/null +++ b/precondition/datamix_gemma/Pretokenization_for_Dolly,_MetaMath,_and_CodeAlpaca,_OpenWebMath.ipynb @@ -0,0 +1,1001 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9PEefz8wEcoY" + }, + "outputs": [], + "source": [ + "import os\n", + "\n", + "VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:\"string\"}\n", + "# weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')\n", + "ckpt_path = 'g_mini/2b_it_v1p1_orbax/1'\n", + "vocab_path = 'home/mriviere/g_mini/tokenizer/gemini_bpe_256k_v5_no_tags_cleared_v1.model'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "yWaP_LPoEcoY" + }, + "outputs": [], + "source": [ + "# @title Python imports\n", + "import re\n", + "import string\n", + "\n", + "# We import JAX and some related packages.\n", + "import chex\n", + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "# We will use tensorflow to handle the dataset\n", + "import tensorflow as tf\n", + "import tensorflow_datasets as tfds\n", + "\n", + "import enum as Enum\n", + "import random\n", + "\n", + "from absl import logging\n", + "import jax.dlpack\n", + "\n", + "\n", + "# Finally, we import Gemma.\n", + "from colabtools import adhoc_import\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "tpHoTcNKoLyX" + }, + "outputs": [], + "source": [ + "a = jnp.zeros(10)\n", + "a.devices()\n", + "a = jax.device_put(a, jax.devices('cpu')[0])\n", + "print(a.devices())\n", + "#a" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iDPHJpx_zPja" + }, + "source": [ + "## Inspect Dolly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "pg8SfQH0EcoY" + }, + "outputs": [], + "source": [ + "ds = tfds.load('huggingface:databricks__databricks_dolly_15k', split='train')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9GVk7a3LG-h7" + }, + "outputs": [], + "source": [ + "ds.cardinality()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "sAWW7VUoEgE_" + }, + "outputs": [], + "source": [ + "for element in ds.take(1):\n", + " print(element)\n", + " for key, val in element.items():\n", + " print(f'{key}: {val}')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cmFvqmADzRUg" + }, + "source": [ + "## Inspect MetaMath" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "1-YxWE0KPvwg" + }, + "outputs": [], + "source": [ + "ds = tfds.load('huggingface:meta_math__metamathqa', split='train')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Huu12Qw2zMFG" + }, + "outputs": [], + "source": [ + "ds.cardinality()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "z1PqJsKTzNtw" + }, + "outputs": [], + "source": [ + "for element in ds.take(1):\n", + " print(element)\n", + " for key, val in element.items():\n", + " print(f'{key}: {val}')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Gj7UtBtg9Nd5" + }, + "source": [ + "## Inspect CodeAlpaca" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "1FG5pIsA9PCW" + }, + "outputs": [], + "source": [ + "ds = tfds.load('huggingface:sahil2801__codealpaca_20k', split='train')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "nlofOj5b9Qi-" + }, + "outputs": [], + "source": [ + "ds.cardinality()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "eQSyvW-W9S51" + }, + "outputs": [], + "source": [ + "for element in ds.take(1):\n", + " print(element)\n", + " for key, val in element.items():\n", + " print(f'{key}: {val}')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3M732bDgUojC" + }, + "source": [ + "## Inspect Open Web Math" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2tW9ryGUUqDV" + }, + "outputs": [], + "source": [ + "ds = tfds.load('huggingface:open_web_math__open_web_math', split='train')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5S6r9WcBUxXn" + }, + "outputs": [], + "source": [ + "ds.cardinality()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "cTJdf2GwUzpg" + }, + "outputs": [], + "source": [ + "for element in ds.take(1):\n", + " print(element)\n", + " for key, val in element.items():\n", + " print(f'{key}: {val}')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NYC42hJgEcoY" + }, + "source": [ + "### Tokenizer\n", + "\n", + "Let's start by loading our vocabulary base tokenizer, which we'll construct using the [SentencePiece](https://github.com/google/sentencepiece) library." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "TpyG5YW1EcoY" + }, + "outputs": [], + "source": [ + "vocab = spm.SentencePieceProcessor()\n", + "vocab.Load(vocab_path)\n", + "\n", + "vocab_list = [(id, vocab.IdToPiece(id)) for id in range(vocab.GetPieceSize())]\n", + "letters = ['A', 'B', 'C', 'D']\n", + "res_dict = {}\n", + "for id, piece in vocab_list:\n", + " try:\n", + " letter = piece[piece.find(next(filter(str.isalpha, piece)))]\n", + " if letter in letters:\n", + " res_dict[id] = letter\n", + " except:\n", + " pass\n", + "\n", + "class DatasetSplit(Enum.Enum):\n", + " TRAIN = 'train'\n", + "\n", + "@chex.dataclass(frozen=True)\n", + "class TrainingInput:\n", + " # Input tokens provided to model\n", + " input_tokens: jax.Array\n", + "\n", + " # A mask that determines which tokens contribute to the target loss\n", + " # calculation\n", + " target_mask: jax.Array" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "L9cjK0uxEcoY" + }, + "outputs": [], + "source": [ + "class GemmaTokenizer:\n", + " \"\"\"Custom wrapper around a SentencePieceProcessor for tensorflow.\"\"\"\n", + "\n", + " def __init__(self,\n", + " spm_processor: spm.SentencePieceProcessor):\n", + " self._spm_processor = spm_processor\n", + "\n", + " @property\n", + " def pad_id(self) -\u003e int:\n", + " \"\"\"Fast access to the pad id.\"\"\"\n", + " return self._spm_processor.pad_id()\n", + "\n", + " def tokenize(self,\n", + " example: str | bytes,\n", + " prefix: str = '',\n", + " suffix: str = '',\n", + " add_eos: bool = True) -\u003e jax.Array:\n", + " \"\"\"\n", + " Tokenization function.\n", + "\n", + " Args:\n", + " example: input string to tokenize.\n", + " prefix: prefix to add to the input string.\n", + " suffix: suffix to add to the input string.\n", + " add_eos: if True, add an end of sentence token at the end of the output\n", + " sequence.\n", + " Returns:\n", + " Tokens corresponding to the input string.\n", + " \"\"\"\n", + " int_list = [self._spm_processor.bos_id()]\n", + " int_list.extend(self._spm_processor.EncodeAsIds(prefix + example + suffix))\n", + " if add_eos:\n", + " int_list.append(self._spm_processor.eos_id())\n", + "\n", + " return jnp.array(int_list, dtype=jnp.int32)\n", + "\n", + " def tokenize_tf_op(self,\n", + " str_tensor: tf.Tensor,\n", + " prefix: str = '',\n", + " suffix: str = '',\n", + " add_eos: bool = True) -\u003e tf.Tensor:\n", + " \"\"\"Tensforflow operator for the tokenize function.\"\"\"\n", + " encoded = tf.numpy_function(\n", + " self.tokenize,\n", + " [str_tensor, prefix, suffix, add_eos],\n", + " tf.int32)\n", + " encoded.set_shape([None])\n", + " return encoded\n", + "\n", + " def to_string(self, tokens: jax.Array) -\u003e str:\n", + " \"\"\"Convert an array of tokens to a string.\"\"\"\n", + " return self._spm_processor.EncodeIds(tokens.tolist())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "r-x0aTugEcoY" + }, + "source": [ + "### Data loader\n", + "\n", + "We can now wrap everything a build our data loader." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "qlFBVXRcnKen" + }, + "outputs": [], + "source": [ + "# @title\n", + "\"\"\"Base class for dataset builders.\"\"\"\n", + "\n", + "class DatasetBuilder:\n", + " \"\"\"Base class for dataset builders.\n", + "\n", + " This class provides the interface for dataset builders.\n", + " \"\"\"\n", + "\n", + " def __init__(self, tokenizer: GemmaTokenizer,\n", + " max_seq_len: int):\n", + " \"\"\"Constructor.\n", + "\n", + " Args:\n", + " tokenizer: Gemma tokenizer to use.\n", + " max_seq_len: size of each sequence in a given batch.\n", + " \"\"\"\n", + " self._tokenizer = tokenizer\n", + " self._max_seq_len = max_seq_len\n", + "\n", + " def _pad_up_to_max_len(\n", + " self, input_tensor: tf.Tensor, pad_value: int | bool\n", + " ) -\u003e tf.Tensor:\n", + " \"\"\"Pads the given tensor up to max_seq_len.\"\"\"\n", + " seq_len = tf.shape(input_tensor)[0]\n", + " to_pad = tf.maximum(0, self._max_seq_len - seq_len)\n", + " return tf.pad(\n", + " input_tensor,\n", + " [[0, to_pad]],\n", + " mode='CONSTANT',\n", + " constant_values=pad_value\n", + " )\n", + "\n", + " def get_train_dataset(self):\n", + " raise NotImplementedError()\n", + "\n", + " def get_validation_dataset(self, batch_size: int):\n", + " raise NotImplementedError()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_QoszBGn40mz" + }, + "source": [ + "## MetaMath" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9v2JF2e_kisZ" + }, + "outputs": [], + "source": [ + "class MetaMathDatasetBuilder(dataset_builder.DatasetBuilder):\n", + " \"\"\"Dataset builder for the MetaMath dataset.\"\"\"\n", + "\n", + " N_ITEMS = {DatasetSplit.TRAIN: 395000}\n", + "\n", + " BUFFER_SIZE_SHUFFLE = 100\n", + " QUERY_PREFIX = 'Query: \\n'\n", + " QUERY_SUFFIX = '\\n'\n", + " RESPONSE_PREFIX = 'Response: \\n'\n", + " RESPONSE_SUFFIX = '\\n'\n", + "\n", + " def __init__(\n", + " self, tokenizer: GemmaTokenizer, max_seq_len: int\n", + " ):\n", + " \"\"\"Constructor.\n", + "\n", + " Args:\n", + " tokenizer: Gemma tokenizer to use.\n", + " max_seq_len: size of each sequence in a given batch.\n", + " \"\"\"\n", + " self._tokenizer = tokenizer\n", + " self._base_data = {\n", + " DatasetSplit.TRAIN: tfds.load(\n", + " 'huggingface:meta_math__metamathqa', split='train',\n", + " ),\n", + " }\n", + " # logging.info(f'sciq size: {self._base_data[DatasetSplit.TRAIN].cardinality().numpy()}')\n", + " self._max_seq_len = max_seq_len\n", + "\n", + " def _tokenize_query(self, example: tf.Tensor):\n", + " \"\"\"Tokenization function for the Question.\"\"\"\n", + " return self._tokenizer.tokenize_tf_op(\n", + " example,\n", + " prefix=self.QUERY_PREFIX,\n", + " suffix=self.QUERY_SUFFIX,\n", + " add_eos=False,\n", + " )\n", + "\n", + " def _tokenize_response(self, example: tf.Tensor):\n", + " return self._tokenizer.tokenize_tf_op(\n", + " example,\n", + " prefix=self.RESPONSE_PREFIX,\n", + " suffix=self.RESPONSE_SUFFIX,\n", + " add_eos=False,\n", + " )\n", + "\n", + " def _to_training_input(\n", + " self,\n", + " query_tokens: jax.Array,\n", + " response_tokens: jax.Array,\n", + " ):\n", + " \"\"\"Build a training input from a tuple of source and destination tokens.\"\"\"\n", + "\n", + " # The input sequence fed to the model is simply the concatenation of the\n", + " # source and the destination.\n", + " tokens = tf.concat(\n", + " [query_tokens, response_tokens], axis=0\n", + " )\n", + "\n", + " # To prevent the model from updating based on the source (input)\n", + " # tokens, add a target mask to each input.\n", + " query_mask = tf.zeros_like(query_tokens, dtype=tf.bool)\n", + " response_mask = tf.ones_like(response_tokens, dtype=tf.bool)\n", + " mask = tf.concat([query_mask, response_mask], axis=0)\n", + "\n", + " # If the output tokens sequence is smaller than the target sequence size,\n", + " # then pad it with pad tokens.\n", + " tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)\n", + "\n", + " # Don't want to perform the backward pass on the pad tokens.\n", + " mask = self._pad_up_to_max_len(mask, False)\n", + " return TrainingInput( #type: ignore\n", + " input_tokens=tokens, #type:ignore\n", + " target_mask=mask, #type:ignore\n", + " )# type: ignore\n", + "\n", + " def get_train_dataset(self):\n", + " \"\"\"Build the training dataset.\"\"\"\n", + "\n", + " ds = self._base_data[DatasetSplit.TRAIN].map(\n", + " lambda x: (\n", + " self._tokenize_query(x['query']),\n", + " self._tokenize_response(x['response'])\n", + " )\n", + " )\n", + " ds = ds.map(lambda x, y: self._to_training_input(x, y),\n", + " num_parallel_calls=tf.data.AUTOTUNE)\n", + " ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] \u003c= self._max_seq_len)\n", + " ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)\n", + " return ds" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "SmH2TvurGo3p" + }, + "outputs": [], + "source": [ + "metamath_path = '/home/xinyic/metamath/metamath_data.tfrecord'\n", + "tokenizer = GemmaTokenizer(vocab)\n", + "metamath_dataset_builder = MetaMathDatasetBuilder(tokenizer, max_seq_len=1000) # why is this the case?\n", + "train_ds = metamath_dataset_builder.get_train_dataset()\n", + "train_ds = train_ds.as_numpy_iterator()\n", + "it = 0\n", + "with tf.io.TFRecordWriter(metamath_path) as writer:\n", + " for train_record in train_ds:\n", + " record_bytes = tf.train.Example( features=tf.train.Features(feature={'input_tokens': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_record.input_tokens.tobytes()])), \"target_mask\": tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_record.target_mask.tobytes()]))})).SerializeToString()\n", + " writer.write(record_bytes)\n", + " print(f'it: {it}')\n", + " it += 1" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "LnGpfvaC-X5S" + }, + "source": [ + "## CodeAlpaca" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "n9WmCtsA1Gwk" + }, + "outputs": [], + "source": [ + "class CodeAlpacaDatasetBuilder(dataset_builder.DatasetBuilder):\n", + " \"\"\"Dataset builder for the CodeAlpaca dataset.\"\"\"\n", + "\n", + " N_ITEMS = {DatasetSplit.TRAIN: 20022}\n", + " BUFFER_SIZE_SHUFFLE = 100\n", + "\n", + " def __init__(\n", + " self, tokenizer: GemmaTokenizer, max_seq_len: int\n", + " ):\n", + " \"\"\"Constructor.\n", + "\n", + " Args:\n", + " tokenizer: Gemma tokenizer to use.\n", + " max_seq_len: size of each sequence in a given batch.\n", + " \"\"\"\n", + " self._tokenizer = tokenizer\n", + " self._base_data = {\n", + " DatasetSplit.TRAIN: tfds.load(\n", + " 'huggingface:sahil2801__codealpaca_20k', split='train'\n", + " ),\n", + " }\n", + " # logging.info(f'orca math size: {self._base_data[DatasetSplit.TRAIN].cardinality().numpy()}')\n", + " self._max_seq_len = max_seq_len\n", + "\n", + " def _tokenize_input(self, example: tf.Tensor):\n", + " \"\"\"Tokenization function for the Input.\"\"\"\n", + " return self._tokenizer.tokenize_tf_op(\n", + " example,\n", + " prefix='Input: \\n',\n", + " suffix='\\n',\n", + " add_eos=False,\n", + " )\n", + "\n", + " def _tokenize_instruction(self, example: tf.Tensor):\n", + " \"\"\"Tokenization function for the Instruction.\"\"\"\n", + " return self._tokenizer.tokenize_tf_op(\n", + " example,\n", + " prefix='Instruction: \\n',\n", + " suffix='\\n',\n", + " add_eos=False,\n", + " )\n", + "\n", + " def _tokenize_output(self, example: tf.Tensor):\n", + " \"\"\"Tokenization function for the Output.\"\"\"\n", + " return self._tokenizer.tokenize_tf_op(\n", + " example,\n", + " prefix='Output: \\n',\n", + " suffix='\\n',\n", + " add_eos=False,\n", + " )\n", + "\n", + " def _to_training_input(\n", + " self,\n", + " input_tokens: jax.Array,\n", + " instruction_tokens: jax.Array,\n", + " output_tokens: jax.Array,\n", + " ):\n", + " \"\"\"Build a training input from a tuple of source and destination tokens.\"\"\"\n", + "\n", + " # The input sequence fed to the model is simply the concatenation of the\n", + " # source and the destination.\n", + " tokens = tf.concat(\n", + " [input_tokens, instruction_tokens, output_tokens], axis=0\n", + " )\n", + "\n", + " # To prevent the model from updating based on the source (input)\n", + " # tokens, add a target mask to each input.\n", + " input_mask = tf.zeros_like(input_tokens, dtype=tf.bool)\n", + " instruction_mask = tf.zeros_like(instruction_tokens, dtype=tf.bool)\n", + " output_mask = tf.ones_like(output_tokens, dtype=tf.bool)\n", + " mask = tf.concat([input_mask, instruction_mask, output_mask], axis=0)\n", + "\n", + " # If the output tokens sequence is smaller than the target sequence size,\n", + " # then pad it with pad tokens.\n", + " tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)\n", + "\n", + " # Don't want to perform the backward pass on the pad tokens.\n", + " mask = self._pad_up_to_max_len(mask, False)\n", + " return TrainingInput( #type: ignore\n", + " input_tokens=tokens, #type:ignore\n", + " target_mask=mask, #type:ignore\n", + " )# type: ignore\n", + "\n", + " def get_train_dataset(self):\n", + " \"\"\"Build the training dataset.\"\"\"\n", + "\n", + " ds = self._base_data[DatasetSplit.TRAIN].map(\n", + " lambda x: (\n", + " self._tokenize_input(x['input']),\n", + " self._tokenize_instruction(x['instruction']),\n", + " self._tokenize_output(x['output']),\n", + " )\n", + " )\n", + " ds = ds.map(lambda x, y, z: self._to_training_input(x, y, z))\n", + " ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] \u003c= self._max_seq_len)\n", + " ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)\n", + " return ds" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2l6tS-Vc_8K0" + }, + "outputs": [], + "source": [ + "codealpaca_path = '/home/xinyic/codealpaca/codealpaca_data.tfrecord'\n", + "tokenizer = GemmaTokenizer(vocab)\n", + "codealpaca_dataset_builder = CodeAlpacaDatasetBuilder(tokenizer, max_seq_len=1000)\n", + "train_ds = codealpaca_dataset_builder.get_train_dataset()\n", + "train_ds = train_ds.as_numpy_iterator()\n", + "it = 0\n", + "with tf.io.TFRecordWriter(codealpaca_path) as writer:\n", + " for train_record in train_ds:\n", + " record_bytes = tf.train.Example( features=tf.train.Features(feature={'input_tokens': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_record.input_tokens.tobytes()])), \"target_mask\": tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_record.target_mask.tobytes()]))})).SerializeToString()\n", + " writer.write(record_bytes)\n", + " print(f'it: {it}')\n", + " it += 1" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fCMwRA661aI3" + }, + "source": [ + "## Dolly-15K" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "1tHFWjCYO5G0" + }, + "outputs": [], + "source": [ + "class DollyDatasetBuilder(dataset_builder.DatasetBuilder):\n", + " \"\"\"Dataset builder for the Dolly dataset.\"\"\"\n", + "\n", + " N_ITEMS = {DatasetSplit.TRAIN: 15011}\n", + "\n", + "\n", + " BUFFER_SIZE_SHUFFLE = 100\n", + " CONTEXT_PREFIX = 'Context: \\n'\n", + " CONTEXT_SUFFIX = '\\n'\n", + " INSTRUCTION_PREFIX = 'Instruction: \\n'\n", + " INSTRUCTION_SUFFIX = '\\n'\n", + " RESPONSE_PREFIX = 'Response: \\n'\n", + " RESPONSE_SUFFIX = '\\n'\n", + "\n", + " def __init__(\n", + " self, tokenizer: gemma_tokenizer.GemmaTokenizer, max_seq_len: int\n", + " ):\n", + " \"\"\"Constructor.\n", + "\n", + " Args:\n", + " tokenizer: Gemma tokenizer to use.\n", + " max_seq_len: size of each sequence in a given batch.\n", + " \"\"\"\n", + " self._tokenizer = tokenizer\n", + " self._base_data = {\n", + " DatasetSplit.TRAIN: tfds.load(\n", + " 'huggingface:databricks__databricks_dolly_15k', split='train'\n", + " ),\n", + " }\n", + " logging.info(f'dolly size: {self._base_data[DatasetSplit.TRAIN].cardinality().numpy()}')\n", + " self._max_seq_len = max_seq_len\n", + "\n", + " def _tokenize_context(self, example: tf.Tensor):\n", + " \"\"\"Tokenization function for the context.\"\"\"\n", + " return self._tokenizer.tokenize_tf_op(\n", + " example,\n", + " prefix=self.CONTEXT_PREFIX,\n", + " suffix=self.CONTEXT_SUFFIX,\n", + " add_eos=False,\n", + " )\n", + "\n", + " def _tokenize_response(self, example: tf.Tensor):\n", + " \"\"\"Tokenization function for the Response.\"\"\"\n", + " return self._tokenizer.tokenize_tf_op(\n", + " example,\n", + " prefix=self.RESPONSE_PREFIX,\n", + " suffix=self.RESPONSE_SUFFIX,\n", + " add_eos=False,\n", + " )\n", + "\n", + " def _tokenize_instruction(self, example: tf.Tensor):\n", + " \"\"\"Tokenization function for the instruction.\"\"\"\n", + " return self._tokenizer.tokenize_tf_op(\n", + " example,\n", + " prefix=self.INSTRUCTION_PREFIX,\n", + " suffix=self.INSTRUCTION_SUFFIX,\n", + " add_eos=False,\n", + " )\n", + "\n", + " def _to_training_input(\n", + " self,\n", + " instruction_tokens: jax.Array,\n", + " context_tokens: jax.Array,\n", + " response_tokens: jax.Array,\n", + " ):\n", + " \"\"\"Build a training input from a tuple of source and destination tokens.\"\"\"\n", + "\n", + " # The input sequence fed to the model is simply the concatenation of the\n", + " # source and the destination.\n", + " tokens = tf.concat(\n", + " [instruction_tokens, context_tokens, response_tokens], axis=0\n", + " )\n", + "\n", + " # To prevent the model from updating based on the source (input)\n", + " # tokens, add a target mask to each input.\n", + " context_mask = tf.zeros_like(context_tokens, dtype=tf.bool)\n", + " instruction_mask = tf.zeros_like(instruction_tokens, dtype=tf.bool)\n", + " response_mask = tf.ones_like(response_tokens, dtype=tf.bool)\n", + " mask = tf.concat([instruction_mask, context_mask, response_mask], axis=0)\n", + "\n", + " # If the output tokens sequence is smaller than the target sequence size,\n", + " # then pad it with pad tokens.\n", + " tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)\n", + "\n", + " # Don't want to perform the backward pass on the pad tokens.\n", + " mask = self._pad_up_to_max_len(mask, False)\n", + " return dataset_builder.TrainingInput( #type: ignore\n", + " input_tokens=tokens, #type:ignore\n", + " target_mask=mask, #type:ignore\n", + " )# type: ignore\n", + "\n", + " def get_train_dataset(self):\n", + " \"\"\"Build the training dataset.\"\"\"\n", + "\n", + " ds = self._base_data[DatasetSplit.TRAIN].map(\n", + " lambda x: (\n", + " self._tokenize_instruction(x['instruction']),\n", + " self._tokenize_context(x['context']),\n", + " self._tokenize_response(x['response'])\n", + " )\n", + " )\n", + " ds = ds.map(lambda x, y, z: self._to_training_input(x, y, z))\n", + " ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] \u003c= self._max_seq_len)\n", + " ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)\n", + "\n", + " return ds" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YCUF_wIIKleS" + }, + "outputs": [], + "source": [ + "dolly_path = '/home/xinyic/dolly/dolly_data.tfrecord'\n", + "tokenizer = GemmaTokenizer(vocab)\n", + "dolly_dataset_builder = DollyDatasetBuilder(tokenizer, max_seq_len=1000) # why is this the case?\n", + "train_ds = dolly_dataset_builder.get_train_dataset()\n", + "train_ds = train_ds.as_numpy_iterator()\n", + "it = 0\n", + "with tf.io.TFRecordWriter(dolly_path) as writer:\n", + " for train_record in train_ds:\n", + " record_bytes = tf.train.Example( features=tf.train.Features(feature={'input_tokens': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_record.input_tokens.tobytes()])), \"target_mask\": tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_record.target_mask.tobytes()]))})).SerializeToString()\n", + " writer.write(record_bytes)\n", + " print(f'it: {it}')\n", + " it += 1" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ux42ZSnMUarN" + }, + "source": [ + "## Open Web Math" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "cH-_ajIYUXSc" + }, + "outputs": [], + "source": [ + "class OpenWebMathDatasetBuilder(dataset_builder.DatasetBuilder):\n", + " \"\"\"Dataset builder for the Open Web Math dataset.\"\"\"\n", + "\n", + " N_ITEMS = {DatasetSplit.TRAIN: 6315233}\n", + "\n", + "\n", + " BUFFER_SIZE_SHUFFLE = 100\n", + " CONTEXT_PREFIX = 'Context: \\n'\n", + " CONTEXT_SUFFIX = '\\n'\n", + " INSTRUCTION_PREFIX = 'Instruction: \\n'\n", + " INSTRUCTION_SUFFIX = '\\n'\n", + " RESPONSE_PREFIX = 'Response: \\n'\n", + " RESPONSE_SUFFIX = '\\n'\n", + "\n", + " def __init__(\n", + " self, tokenizer: gemma_tokenizer.GemmaTokenizer, max_seq_len: int\n", + " ):\n", + " \"\"\"Constructor.\n", + "\n", + " Args:\n", + " tokenizer: Gemma tokenizer to use.\n", + " max_seq_len: size of each sequence in a given batch.\n", + " \"\"\"\n", + " self._tokenizer = tokenizer\n", + " self._base_data = {\n", + " DatasetSplit.TRAIN: tfds.load(\n", + " 'huggingface:databricks__databricks_dolly_15k', split='train'\n", + " ),\n", + " }\n", + " logging.info(f'dolly size: {self._base_data[DatasetSplit.TRAIN].cardinality().numpy()}')\n", + " self._max_seq_len = max_seq_len\n", + "\n", + " def _tokenize_context(self, example: tf.Tensor):\n", + " \"\"\"Tokenization function for the context.\"\"\"\n", + " return self._tokenizer.tokenize_tf_op(\n", + " example,\n", + " prefix=self.CONTEXT_PREFIX,\n", + " suffix=self.CONTEXT_SUFFIX,\n", + " add_eos=False,\n", + " )\n", + "\n", + " def _tokenize_response(self, example: tf.Tensor):\n", + " \"\"\"Tokenization function for the Response.\"\"\"\n", + " return self._tokenizer.tokenize_tf_op(\n", + " example,\n", + " prefix=self.RESPONSE_PREFIX,\n", + " suffix=self.RESPONSE_SUFFIX,\n", + " add_eos=False,\n", + " )\n", + "\n", + " def _tokenize_instruction(self, example: tf.Tensor):\n", + " \"\"\"Tokenization function for the instruction.\"\"\"\n", + " return self._tokenizer.tokenize_tf_op(\n", + " example,\n", + " prefix=self.INSTRUCTION_PREFIX,\n", + " suffix=self.INSTRUCTION_SUFFIX,\n", + " add_eos=False,\n", + " )\n", + "\n", + " def _to_training_input(\n", + " self,\n", + " instruction_tokens: jax.Array,\n", + " context_tokens: jax.Array,\n", + " response_tokens: jax.Array,\n", + " ):\n", + " \"\"\"Build a training input from a tuple of source and destination tokens.\"\"\"\n", + "\n", + " # The input sequence fed to the model is simply the concatenation of the\n", + " # source and the destination.\n", + " tokens = tf.concat(\n", + " [instruction_tokens, context_tokens, response_tokens], axis=0\n", + " )\n", + "\n", + " # To prevent the model from updating based on the source (input)\n", + " # tokens, add a target mask to each input.\n", + " context_mask = tf.zeros_like(context_tokens, dtype=tf.bool)\n", + " instruction_mask = tf.zeros_like(instruction_tokens, dtype=tf.bool)\n", + " response_mask = tf.ones_like(response_tokens, dtype=tf.bool)\n", + " mask = tf.concat([instruction_mask, context_mask, response_mask], axis=0)\n", + "\n", + " # If the output tokens sequence is smaller than the target sequence size,\n", + " # then pad it with pad tokens.\n", + " tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)\n", + "\n", + " # Don't want to perform the backward pass on the pad tokens.\n", + " mask = self._pad_up_to_max_len(mask, False)\n", + " return dataset_builder.TrainingInput( #type: ignore\n", + " input_tokens=tokens, #type:ignore\n", + " target_mask=mask, #type:ignore\n", + " )# type: ignore\n", + "\n", + " def get_train_dataset(self):\n", + " \"\"\"Build the training dataset.\"\"\"\n", + "\n", + " ds = self._base_data[DatasetSplit.TRAIN].map(\n", + " lambda x: (\n", + " self._tokenize_instruction(x['instruction']),\n", + " self._tokenize_context(x['context']),\n", + " self._tokenize_response(x['response'])\n", + " )\n", + " )\n", + " ds = ds.map(lambda x, y, z: self._to_training_input(x, y, z))\n", + " ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] \u003c= self._max_seq_len)\n", + " ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)\n", + "\n", + " return ds" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2cF0EeA0h6ky" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "last_runtime": { + "build_target": "//learning/deepmind/dm_python:dm_notebook3_tpu", + "kind": "private" + }, + "private_outputs": true, + "provenance": [ + { + "file_id": "10egMS7uVoQJlKU0Z6Zm6rmBfYm-QozJO", + "timestamp": 1729907590715 + }, + { + "file_id": "1NDGtgpcaiOgC15gOirVXaLC6Af_TArwW", + "timestamp": 1721739741159 + }, + { + "file_id": "/third_party/py/gemma/colabs/fine_tuning_tutorial.ipynb", + "timestamp": 1720546560685 + } + ] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}