diff --git a/examples/slm_basic_train_ex.cpp b/examples/slm_basic_train_ex.cpp index e4322fc774..35368f54a7 100644 --- a/examples/slm_basic_train_ex.cpp +++ b/examples/slm_basic_train_ex.cpp @@ -49,6 +49,9 @@ // ---------------------------------------------------------------------------------------- +using namespace std; +using namespace dlib; + // We treat each character as a token ID in [0..255]. const int MAX_TOKEN_ID = 255; const int PAD_TOKEN = 256; // an extra "pad" token if needed @@ -66,13 +69,13 @@ std::vector char_based_tokenize(const std::string& text) } // Function to shuffle samples and labels in sync -void shuffle_samples_and_labels(std::vector>& samples, std::vector& labels) { +void shuffle_samples_and_labels(std::vector>& samples, std::vector& labels) { std::vector indices(samples.size()); std::iota(indices.begin(), indices.end(), 0); // Fill with 0, 1, 2, ..., N-1 std::shuffle(indices.begin(), indices.end(), std::default_random_engine{}); // Create temporary vectors to hold shuffled data - std::vector> shuffled_samples(samples.size()); + std::vector> shuffled_samples(samples.size()); std::vector shuffled_labels(labels.size()); // Apply the shuffle @@ -93,15 +96,15 @@ int main(int argc, char** argv) { try { - dlib::command_line_parser parser; + command_line_parser parser; parser.add_option("train", "Train a small transformer on the built-in Shakespeare text"); parser.add_option("generate", "Generate text from a previously trained model (needs shakespeare_prompt)"); parser.add_option("learning-rate", "Set the learning rate for training (default: 1e-4)", 1); parser.add_option("batch-size", "Set the mini-batch size for training (default: 64)", 1); parser.add_option("generation-length", "Set the length of generated text (default: 400)", 1); - parser.add_option("alpha", "Set the initial learning rate for Adam optimizer (default: 0.004)", 1); - parser.add_option("beta1", "Set the decay rate for the first moment estimate (default: 0.9)", 1); - parser.add_option("beta2", "Set the decay rate for the second moment estimate (default: 0.999)", 1); + parser.add_option("alpha", "Set the weight decay for Adam optimizer (default: 0.004)", 1); + parser.add_option("beta1", "Set the first moment coefficient (default: 0.9)", 1); + parser.add_option("beta2", "Set the second moment coefficient (default: 0.999)", 1); parser.add_option("max-samples", "Set the maximum number of training samples (default: 50000)", 1); parser.add_option("shuffle", "Shuffle training sequences and labels before training (default: false)"); parser.parse(argc, argv); @@ -122,7 +125,7 @@ int main(int argc, char** argv) const size_t max_samples = get_option(parser, "max-samples",50000); // Default maximum number of training samples // We define a minimal config for demonstration - const long vocab_size = 257; // 0..255 for chars + 1 pad token + const long vocab_size = MAX_TOKEN_ID + 1 + 1; // 256 for chars + 1 pad token const long num_layers = 3; const long num_heads = 4; const long embedding_dim = 64; @@ -136,8 +139,8 @@ int main(int argc, char** argv) embedding_dim, max_seq_len, use_squeezing, - dlib::gelu, - dlib::dropout_10 + gelu, + dropout_10 >; // For GPU usage (if any), set gpus = {0} for a single GPU, etc. @@ -151,7 +154,7 @@ int main(int argc, char** argv) // ---------------------------------------------------------------------------------------- if (parser.option("train")) { - std::cout << "=== TRAIN MODE ===\n"; + cout << "=== TRAIN MODE ===\n"; // 1) Prepare training data (simple approach) // We will store characters from shakespeare_text into a vector @@ -160,7 +163,7 @@ int main(int argc, char** argv) auto full_tokens = char_based_tokenize(shakespeare_text); if (full_tokens.empty()) { - std::cerr << "ERROR: The Shakespeare text is empty. Please provide a valid training text.\n"; + cerr << "ERROR: The Shakespeare text is empty. Please provide a valid training text.\n"; return 0; } @@ -170,18 +173,18 @@ int main(int argc, char** argv) : 0; // Display the size of the training text and the number of sequences - std::cout << "Training text size: " << full_tokens.size() << " characters\n"; - std::cout << "Maximum number of sequences: " << max_sequences << "\n"; + cout << "Training text size: " << full_tokens.size() << " characters\n"; + cout << "Maximum number of sequences: " << max_sequences << "\n"; // Check if the text is too short if (max_sequences == 0) { - std::cerr << "ERROR: The Shakespeare text is too short for training. It must contain at least " + cerr << "ERROR: The Shakespeare text is too short for training. It must contain at least " << (max_seq_len + 1) << " characters.\n"; return 0; } - std::vector> samples; + std::vector> samples; std::vector labels; // Let's create a training set of about (N) samples from the text @@ -190,7 +193,7 @@ int main(int argc, char** argv) const size_t N = (max_sequences < max_samples) ? max_sequences : max_samples; for (size_t start = 0; start < N; ++start) { - dlib::matrix seq(max_seq_len, 1); + matrix seq(max_seq_len, 1); for (long t = 0; t < max_seq_len; ++t) seq(t, 0) = full_tokens[start + t]; samples.push_back(seq); @@ -200,18 +203,18 @@ int main(int argc, char** argv) // Shuffle samples and labels if the --shuffle option is enabled if (parser.option("shuffle")) { - std::cout << "Shuffling training sequences and labels...\n"; + cout << "Shuffling training sequences and labels...\n"; shuffle_samples_and_labels(samples, labels); } // 3) Construct the network in training mode using net_type = my_transformer_cfg::network_type; net_type net; - if (dlib::file_exists(model_file)) - dlib::deserialize(model_file) >> net; + if (file_exists(model_file)) + deserialize(model_file) >> net; // 4) Create dnn_trainer - dlib::dnn_trainer trainer(net, dlib::adam(alpha, beta1, beta2), gpus); + dnn_trainer trainer(net, adam(alpha, beta1, beta2), gpus); trainer.set_learning_rate(learning_rate); trainer.set_min_learning_rate(1e-6); trainer.set_mini_batch_size(batch_size); @@ -229,12 +232,12 @@ int main(int argc, char** argv) if (predicted[i] == labels[i]) correct++; double accuracy = (double)correct / labels.size(); - std::cout << "Training accuracy (on this sample set): " << accuracy << "\n"; + cout << "Training accuracy (on this sample set): " << accuracy << "\n"; // 7) Save the model net.clean(); - dlib::serialize(model_file) << net; - std::cout << "Model saved to " << model_file << "\n"; + serialize(model_file) << net; + cout << "Model saved to " << model_file << "\n"; } // ---------------------------------------------------------------------------------------- @@ -242,28 +245,28 @@ int main(int argc, char** argv) // ---------------------------------------------------------------------------------------- if (parser.option("generate")) { - std::cout << "=== GENERATE MODE ===\n"; + cout << "=== GENERATE MODE ===\n"; // 1) Load the trained model using net_infer = my_transformer_cfg::network_type; net_infer net; - if (dlib::file_exists(model_file)) + if (file_exists(model_file)) { - dlib::deserialize(model_file) >> net; - std::cout << "Loaded model from " << model_file << "\n"; + deserialize(model_file) >> net; + cout << "Loaded model from " << model_file << "\n"; } else { - std::cerr << "Error: model file not found. Please run --train first.\n"; + cerr << "Error: model file not found. Please run --train first.\n"; return 0; } - std::cout << my_transformer_cfg::model_info::describe() << std::endl; - std::cout << "Model parameters: " << count_parameters(net) << std::endl << std::endl; + cout << my_transformer_cfg::model_info::describe() << endl; + cout << "Model parameters: " << count_parameters(net) << endl << endl; // 2) Get the prompt from the included slm_data.h std::string prompt_text = shakespeare_prompt; if (prompt_text.empty()) { - std::cerr << "No prompt found in slm_data.h.\n"; + cerr << "No prompt found in slm_data.h.\n"; return 0; } // If prompt is longer than max_seq_len, we keep only the first window @@ -274,7 +277,7 @@ int main(int argc, char** argv) const auto prompt_tokens = char_based_tokenize(prompt_text); // Put into a dlib matrix - dlib::matrix input_seq(max_seq_len, 1); + matrix input_seq(max_seq_len, 1); // Fill with pad if prompt is shorter than max_seq_len for (long i = 0; i < max_seq_len; ++i) { @@ -284,7 +287,7 @@ int main(int argc, char** argv) input_seq(i, 0) = PAD_TOKEN; } - std::cout << "\nInitial prompt:\n" << prompt_text << " (...)\n\n\nGenerated text:\n" << prompt_text; + cout << "\nInitial prompt:\n" << prompt_text << " (...)\n\n\nGenerated text:\n" << prompt_text; // 3) Generate new text // We'll predict one character at a time, then shift the window @@ -293,7 +296,7 @@ int main(int argc, char** argv) const int next_char = net(input_seq); // single inference // Print the generated character - std::cout << static_cast(std::min(next_char, MAX_TOKEN_ID)) << std::flush; + cout << static_cast(std::min(next_char, MAX_TOKEN_ID)) << flush; // Shift left by 1 for (long i = 0; i < max_seq_len - 1; ++i) @@ -301,14 +304,14 @@ int main(int argc, char** argv) input_seq(max_seq_len - 1, 0) = std::min(next_char, MAX_TOKEN_ID); } - std::cout << "\n\n(end of generation)\n"; + cout << "\n\n(end of generation)\n"; } return 0; } - catch (std::exception& e) + catch (exception& e) { - std::cerr << "Exception thrown: " << e.what() << std::endl; + cerr << "Exception thrown: " << e.what() << endl; return 1; } } diff --git a/examples/slm_data.h b/examples/slm_data.h index 86f3bdcc10..37c08f29b8 100644 --- a/examples/slm_data.h +++ b/examples/slm_data.h @@ -6,7 +6,7 @@ #include // Utility function to concatenate text parts -std::string concatenateTexts(const std::vector& texts) { +inline std::string concatenateTexts(const std::vector& texts) { std::string result; for (const auto& text : texts) { result += text; @@ -590,4 +590,4 @@ And you shall understand from me her mind. )"; -#endif // SlmData_H \ No newline at end of file +#endif // SlmData_H diff --git a/examples/slm_defs.h b/examples/slm_defs.h index 786b1ffdd7..d556fc0dab 100644 --- a/examples/slm_defs.h +++ b/examples/slm_defs.h @@ -214,11 +214,11 @@ namespace transformer template using network_type = std::conditional_t>>>>, + repeat>>>>, classification_head>>>> + repeat>>>> >; /** @@ -283,4 +283,4 @@ namespace transformer */ } -#endif // SlmNet_H \ No newline at end of file +#endif // SlmNet_H