diff --git a/Gemma/business-email-assistant/model-tuning/notebook/bakery_inquiry_model_tuned_with_gemma.ipynb b/Gemma/business-email-assistant/model-tuning/notebook/bakery_inquiry_model_tuned_with_gemma.ipynb new file mode 100644 index 0000000..d911a39 --- /dev/null +++ b/Gemma/business-email-assistant/model-tuning/notebook/bakery_inquiry_model_tuned_with_gemma.ipynb @@ -0,0 +1,1110 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "##### Copyright 2024 Google LLC." + ], + "metadata": { + "id": "ZdRRNrRu8obc" + } + }, + { + "cell_type": "code", + "source": [ + "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ], + "metadata": { + "id": "H2hlKa7K8rGt" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SDEExiAk4fLb" + }, + "source": [ + "# Fine-tune Gemma models using LoRA for Cake Boss Example\n", + "\n", + "Adding additional changes based on feedback\n", + "\n", + "
\n",
+ " ![]() | \n",
+ "
Preprocessor: \"gemma_causal_lm_preprocessor\"\n",
+ "
\n"
+ ]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+ "┃\u001b[1m \u001b[0m\u001b[1mTokenizer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Vocab #\u001b[0m\u001b[1m \u001b[0m┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+ "│ gemma_tokenizer (\u001b[38;5;33mGemmaTokenizer\u001b[0m) │ \u001b[38;5;34m256,000\u001b[0m │\n",
+ "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n"
+ ],
+ "text/html": [
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Tokenizer (type) ┃ Vocab # ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ gemma_tokenizer (GemmaTokenizer) │ 256,000 │\n", + "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n", + "\n" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "\u001b[1mModel: \"gemma_causal_lm\"\u001b[0m\n" + ], + "text/html": [ + "
Model: \"gemma_causal_lm\"\n",
+ "
\n"
+ ]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+ "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+ "│ padding_mask (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
+ "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+ "│ token_ids (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
+ "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+ "│ gemma_backbone │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2304\u001b[0m) │ \u001b[38;5;34m2,614,341,888\u001b[0m │ padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n",
+ "│ (\u001b[38;5;33mGemmaBackbone\u001b[0m) │ │ │ token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+ "│ token_embedding │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256000\u001b[0m) │ \u001b[38;5;34m589,824,000\u001b[0m │ gemma_backbone[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "│ (\u001b[38;5;33mReversibleEmbedding\u001b[0m) │ │ │ │\n",
+ "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n"
+ ],
+ "text/html": [
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ padding_mask (InputLayer) │ (None, None) │ 0 │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ token_ids (InputLayer) │ (None, None) │ 0 │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ gemma_backbone │ (None, None, 2304) │ 2,614,341,888 │ padding_mask[0][0], │\n", + "│ (GemmaBackbone) │ │ │ token_ids[0][0] │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ token_embedding │ (None, None, 256000) │ 589,824,000 │ gemma_backbone[0][0] │\n", + "│ (ReversibleEmbedding) │ │ │ │\n", + "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n", + "\n" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m2,614,341,888\u001b[0m (9.74 GB)\n" + ], + "text/html": [ + "
Total params: 2,614,341,888 (9.74 GB)\n", + "\n" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m2,614,341,888\u001b[0m (9.74 GB)\n" + ], + "text/html": [ + "
Trainable params: 2,614,341,888 (9.74 GB)\n", + "\n" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" + ], + "text/html": [ + "
Non-trainable params: 0 (0.00 B)\n", + "\n" + ] + }, + "metadata": {} + } + ], + "source": [ + "gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(\"gemma2_instruct_2b_en\")\n", + "gemma_lm.summary()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PVLXadptyo34" + }, + "source": [ + "### Cake prompt\n", + "This is from the untuned model. The results aren't exactly what we'd like\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZwQz3xxxKciD", + "outputId": "26003a01-3469-45b7-ec24-a6fa27a66f35", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "From the following get the type of inquiry, (order or request for information), filling, flavor, size, and pickup location and put it into a json\n", + "Hi,\n", + "I'd like to order a red velvet cake with custard filling. Please make it 8 inch round\n", + "and pick it up from the bakery on 22nd street.\n", + "\n", + "Thanks!\n", + " \n", + "```json\n", + "{\n", + " \"inquiry_type\": \"order\",\n", + " \"filling\": \"custard\",\n", + " \"flavor\": \"red velvet\",\n", + " \"size\": \"8 inch round\",\n", + " \"pickup_location\": \"22nd street bakery\"\n", + "}\n", + "```\n", + "```json\n", + "{\n", + " \"inquiry_type\": \"request\",\n", + " \"filling\": \"custard\",\n", + " \"flavor\": \"red velvet\",\n", + " \"size\": \"8 inch round\",\n", + " \"pickup_location\": \"22nd street bakery\"\n", + "}\n", + "```\n", + "```json\n", + "{\n", + " \"inquiry_type\": \"order\",\n", + " \"filling\": \"custard\",\n", + " \"flavor\": \"red velvet\",\n", + " \"size\": \"8 inch round\",\n", + " \"pickup_location\": \"22nd street bakery\"\n" + ] + } + ], + "source": [ + "template = \"{instruction}\\n{response}\"\n", + "\n", + "prompt = template.format(\n", + " instruction=\"\"\"From the following get the type of inquiry, (order or request for information), filling, flavor, size, and pickup 
location and put it into a json\n", + "Hi,\n", + "I'd like to order a red velvet cake with custard filling. Please make it 8 inch round\"\"\",\n", + " response=\"\",\n", + ")\n", + "# sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)\n", + "# For our use case greedy is best\n", + "# gemma_lm.compile(sampler=sampler)\n", + "gemma_lm.compile(sampler=\"greedy\")\n", + "\n", + "print(gemma_lm.generate(prompt, max_length=256))" + ] + }, + { + "cell_type": "code", + "source": [ + "import json\n", + "prompt_1 = dict(prompt = \"\"\"\n", + "Hi Indian Bakery Central,\n", + "Do you happen to have 10 pendas, and thirty bundi ladoos on hand? Also do you sell a vanilla frosting and chocolate flavor cakes. I'm looking for a 6 inch size\n", + "\"\"\",\n", + "response = json.loads(\"\"\"\n", + " {\n", + " \"type\": \"inquiry\",\n", + " \"items\": [\n", + " {\n", + " \"name\": \"pendas\",\n", + " \"quantity\": 10\n", + " },\n", + " {\n", + " \"name\": \"bundi ladoos\",\n", + " \"quantity\": 30\n", + " },\n", + " {\n", + " \"name\": \"cake\",\n", + " \"filling\": null,\n", + " \"frosting\": \"vanilla\",\n", + " \"flavor\": \"chocolate\",\n", + " \"size\": \"6 in\"\n", + " }\n", + " ]\n", + "}\n", + "\"\"\")\n", + ")\n" + ], + "metadata": { + "id": "fsut8YS9tKBp" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "{\n", + " \"training_prompt\": \"\"\"\n", + "Hi Indian Bakery Central,\n", + "Do you happen to have 10 pendas, and thirty bundi ladoos on hand? Also do you sell a vanilla frosting and chocolate flavor cakes. 
I'm looking for a 6 inch size\n", + "\"\"\"\n", + " \"response\":\"\"\"\n", + " [\n", + " {\n", + " \"name\": \"pendas\",\n", + " \"quantity\": 10\n", + " },\n", + " {\n", + " \"name\": \"bundi ladoos\",\n", + " \"quantity\": 30\n", + " },\n", + " {\n", + " \"name\": \"cake\",\n", + " \"filling\": null,\n", + " \"frosting\": \"vanilla\",\n", + " \"flavor\": \"chocolate\",\n", + " \"size\": \"6 in\"\n", + " }\n", + " ]\n", + "}\n", + "\"\"\"\n", + "}" + ], + "metadata": { + "id": "Em3oeIHFPbs9" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "prompt_2 = dict(prompt = \"\"\"\n", + "I saw your business on google maps. Do you sell jellabi and gulab jamun?\n", + "\"\"\",\n", + "response = json.loads(\"\"\"\n", + " {\n", + " \"type\": \"inquiry\",\n", + " \"items\": [\n", + " {\n", + " \"name\": \"jellabi\",\n", + " \"quantity\": null\n", + " },\n", + " {\n", + " \"name\": \"gulab jamun\",\n", + " \"quantity\": null\n", + " }\n", + " ]\n", + "}\n", + "\"\"\")\n", + ")" + ], + "metadata": { + "id": "EZLQkBcotKYD" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "prompt_3 = dict(prompt = \"\"\"\n", + "I'd like to place an order for a 8 inch red velvet cake with lemon frosting and chocolate chips topping.\n", + "\"\"\",\n", + "response = json.loads(\"\"\"\n", + " {\n", + " \"type\": \"order\",\n", + " \"items\": [\n", + " {\n", + " \"name\": \"cake\",\n", + " \"filling\": \"8inch\",\n", + " \"frosting\": \"lemon\",\n", + " \"flavor\": \"chocolate\",\n", + " \"size\": \"8 in\"\n", + " }\n", + " ]\n", + "}\n", + "\"\"\")\n", + ")" + ], + "metadata": { + "id": "LqHk5nHftKj9" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "prompt_4 = dict(prompt = \"\"\"\n", + "I'd like four jellabi and three gulab Jamun.\n", + "\"\"\",\n", + "response = json.loads(\"\"\"\n", + " {\n", + " \"type\": \"order\",\n", + " \"items\": [\n", + " {\n", + " 
\"name\": \"Jellabi\",\n", + " \"quantity\": 4\n", + " },\n", + " {\n", + " \"name\": \"Gulab Jamun\",\n", + " \"quantity\": 3\n", + " }\n", + " ]\n", + "}\n", + "\"\"\")\n", + ")\n", + "prompt_4" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "3XB9NCuX43lB", + "outputId": "4e3c1659-db59-41ae-a3ac-c47fb94c5dfc" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "{'prompt': \"\\nI'd like four jellabi and three gulab Jamun.\\n\",\n", + " 'response': {'type': 'order',\n", + " 'items': [{'name': 'Jellabi', 'quantity': 4},\n", + " {'name': 'Gulab Jamun', 'quantity': 3}]}}" + ] + }, + "metadata": {}, + "execution_count": 11 + } + ] + }, + { + "cell_type": "code", + "source": [ + "prompt_4_2 = dict(prompt = \"\"\"\n", + "Please pack me a box with 10 halva.\n", + "\"\"\",\n", + "response = json.loads(\"\"\"\n", + " {\n", + " \"type\": \"order\",\n", + " \"items\": [\n", + " {\n", + " \"name\": \"halva\",\n", + " \"quantity\": 10\n", + " }\n", + " ]\n", + "}\n", + "\"\"\")\n", + ")" + ], + "metadata": { + "id": "h7i2-tXMS7V6" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "prompt_5 = dict(prompt = \"\"\"\n", + "Do you sell strawberry cakes with vanilla frosting with custard inside?\n", + "\"\"\",\n", + "response = json.loads(\"\"\"\n", + " {\n", + " \"type\": \"inquiry\",\n", + " \"items\": [\n", + " {\n", + " \"name\": \"cake\",\n", + " \"filling\": \"custard\",\n", + " \"frosting\": \"vanilla\",\n", + " \"flavor\": \"strawberry\",\n", + " \"size\": \"null\"\n", + " }\n", + " ]\n", + "}\n", + "\"\"\")\n", + ")\n" + ], + "metadata": { + "id": "tMetn-wmjUuX" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "prompt_5_2 = dict(prompt = \"\"\"\n", + "Do you sell carrot cakes with cream cheese frosting?\n", + "\"\"\",\n", + "response = json.loads(\"\"\"\n", + " {\n", + " 
\"type\": \"inquiry\",\n", + " \"items\": [\n", + " {\n", + " \"name\": \"cake\",\n", + " \"filling\": \"null\",\n", + " \"frosting\": \"cream cheese\",\n", + " \"flavor\": \"carrot\",\n", + " \"size\": \"null\"\n", + " }\n", + " ]\n", + "}\n", + "\"\"\")\n", + ")\n", + "prompt_5" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "w8VoAbFISq-X", + "outputId": "e8b928fa-d7f0-4258-b113-7967bd570f00" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "{'prompt': '\\nDo you sell strawberry cakes with vanilla frosting with custard inside?\\n',\n", + " 'response': {'type': 'inquiry',\n", + " 'items': [{'name': 'cake',\n", + " 'filling': 'custard',\n", + " 'frosting': 'vanilla',\n", + " 'flavor': 'strawberry',\n", + " 'size': 'null'}]}}" + ] + }, + "metadata": {}, + "execution_count": 14 + } + ] + }, + { + "cell_type": "code", + "source": [ + "prompt_6 = dict(prompt = \"\"\"\n", + "I found your website. 
What kind of items do you sell?\n", + "\"\"\",\n", + "response = json.loads(\"\"\"\n", + " {\n", + " \"type\": \"inquiry\",\n", + " \"items\": [\n", + " ]\n", + "}\n", + "\"\"\")\n", + ")\n" + ], + "metadata": { + "id": "_-XPKfCL15gx" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Starts overfitting on lemon if you add this\n", + "\n", + "# prompt_7 = dict(prompt = \"\"\"\n", + "# Can I buy 18 halva, as well as a lemon cake with lemon frosting?\n", + "# \"\"\",\n", + "# response = json.loads(\"\"\"\n", + "# {\n", + "# \"type\": \"inquiry\",\n", + "# \"items\": [\n", + "# {\n", + "# \"name\": \"halva\",\n", + "# \"quantity\": 18\n", + "# },\n", + "# {\n", + "# \"filling\": null,\n", + "# \"frosting\": \"lemon\",\n", + "# \"flavor\": \"lemon\",\n", + "# \"size\": null\n", + "# }\n", + "# ]\n", + "# }\n", + "# \"\"\")\n", + "# )" + ], + "metadata": { + "id": "dzjZSbDg2DvB" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "data = []\n", + "\n", + "for prompt in [prompt_1, prompt_2, prompt_3, prompt_4, prompt_4_2, prompt_5, prompt_5_2, prompt_6]:\n", + " data.append(template.format(instruction=prompt[\"prompt\"],response=prompt[\"response\"]))" + ], + "metadata": { + "id": "FknNEB26yHRN" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Pt7Nr6a7tItO" + }, + "source": [ + "## LoRA Fine-tuning\n", + "\n", + "The LoRA rank determines the dimensionality of the trainable matrices that are added to the original weights of the LLM. It controls the expressiveness and precision of the fine-tuning adjustments.\n", + "\n", + "A higher rank means more detailed changes are possible, but also means more trainable parameters. A lower rank means less computational overhead, but potentially less precise adaptation.\n", + "\n", + "This tutorial uses a LoRA rank of 4. 
In practice, begin with a relatively small rank (such as 4, 8, 16). This is computationally efficient for experimentation. Train your model with this rank and evaluate the performance improvement on your task. Gradually increase the rank in subsequent trials and see if that further boosts performance.\n", + "\n", + "Watch for overfitting and underfitting when choosing these settings:\n", + "\n", + "* Rank\n", + "* Learning Rate" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RCucu6oHz53G", + "outputId": "90834d16-2273-4c2b-8592-efd79cb98499", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "\u001b[1mPreprocessor: \"gemma_causal_lm_preprocessor\"\u001b[0m\n" + ], + "text/html": [ + "
Preprocessor: \"gemma_causal_lm_preprocessor\"\n",
+ "
\n"
+ ]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+ "┃\u001b[1m \u001b[0m\u001b[1mTokenizer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Vocab #\u001b[0m\u001b[1m \u001b[0m┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+ "│ gemma_tokenizer (\u001b[38;5;33mGemmaTokenizer\u001b[0m) │ \u001b[38;5;34m256,000\u001b[0m │\n",
+ "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n"
+ ],
+ "text/html": [
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Tokenizer (type) ┃ Vocab # ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ gemma_tokenizer (GemmaTokenizer) │ 256,000 │\n", + "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n", + "\n" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "\u001b[1mModel: \"gemma_causal_lm\"\u001b[0m\n" + ], + "text/html": [ + "
Model: \"gemma_causal_lm\"\n",
+ "
\n"
+ ]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+ "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+ "│ padding_mask (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
+ "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+ "│ token_ids (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
+ "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+ "│ gemma_backbone │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2304\u001b[0m) │ \u001b[38;5;34m2,617,270,528\u001b[0m │ padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n",
+ "│ (\u001b[38;5;33mGemmaBackbone\u001b[0m) │ │ │ token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+ "│ token_embedding │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256000\u001b[0m) │ \u001b[38;5;34m589,824,000\u001b[0m │ gemma_backbone[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "│ (\u001b[38;5;33mReversibleEmbedding\u001b[0m) │ │ │ │\n",
+ "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n"
+ ],
+ "text/html": [
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ padding_mask (InputLayer) │ (None, None) │ 0 │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ token_ids (InputLayer) │ (None, None) │ 0 │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ gemma_backbone │ (None, None, 2304) │ 2,617,270,528 │ padding_mask[0][0], │\n", + "│ (GemmaBackbone) │ │ │ token_ids[0][0] │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ token_embedding │ (None, None, 256000) │ 589,824,000 │ gemma_backbone[0][0] │\n", + "│ (ReversibleEmbedding) │ │ │ │\n", + "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n", + "\n" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m2,617,270,528\u001b[0m (9.75 GB)\n" + ], + "text/html": [ + "
Total params: 2,617,270,528 (9.75 GB)\n", + "\n" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m2,928,640\u001b[0m (11.17 MB)\n" + ], + "text/html": [ + "
Trainable params: 2,928,640 (11.17 MB)\n", + "\n" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m2,614,341,888\u001b[0m (9.74 GB)\n" + ], + "text/html": [ + "
Non-trainable params: 2,614,341,888 (9.74 GB)\n", + "\n" + ] + }, + "metadata": {} + } + ], + "source": [ + "# Enable LoRA for the model and set the LoRA rank to 4.\n", + "gemma_lm.backbone.enable_lora(rank=4)\n", + "gemma_lm.summary()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hQQ47kcdpbZ9" + }, + "source": [ + "Note that enabling LoRA reduces the number of trainable parameters significantly (from 2.6 billion to 2.9 million)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_Peq7TnLtHse", + "outputId": "5d7fde1e-12d7-4fcf-e9da-b08fda31873e", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Epoch 1/3\n", + "\u001b[1m8/8\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m83s\u001b[0m 6s/step - loss: 0.7486 - sparse_categorical_accuracy: 0.6278\n", + "Epoch 2/3\n", + "\u001b[1m8/8\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m42s\u001b[0m 2s/step - loss: 0.5113 - sparse_categorical_accuracy: 0.6984\n", + "Epoch 3/3\n", + "\u001b[1m8/8\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 770ms/step - loss: 0.3469 - sparse_categorical_accuracy: 0.7796\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "