Skip to content

Commit

Permalink
minor tweaks to the transformer example (#3048)
Browse files Browse the repository at this point in the history
* minor tweaks to the transformer example

* actually take advantage of using namespace std;
  • Loading branch information
arrufat authored Jan 24, 2025
1 parent 8fdd2a6 commit 28c46fb
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 44 deletions.
77 changes: 40 additions & 37 deletions examples/slm_basic_train_ex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@

// ----------------------------------------------------------------------------------------

using namespace std;
using namespace dlib;

// We treat each character as a token ID in [0..255].
const int MAX_TOKEN_ID = 255;
const int PAD_TOKEN = 256; // an extra "pad" token if needed
Expand All @@ -66,13 +69,13 @@ std::vector<int> char_based_tokenize(const std::string& text)
}

// Function to shuffle samples and labels in sync
void shuffle_samples_and_labels(std::vector<dlib::matrix<int, 0, 1>>& samples, std::vector<unsigned long>& labels) {
void shuffle_samples_and_labels(std::vector<matrix<int, 0, 1>>& samples, std::vector<unsigned long>& labels) {
std::vector<size_t> indices(samples.size());
std::iota(indices.begin(), indices.end(), 0); // Fill with 0, 1, 2, ..., N-1
std::shuffle(indices.begin(), indices.end(), std::default_random_engine{});

// Create temporary vectors to hold shuffled data
std::vector<dlib::matrix<int, 0, 1>> shuffled_samples(samples.size());
std::vector<matrix<int, 0, 1>> shuffled_samples(samples.size());
std::vector<unsigned long> shuffled_labels(labels.size());

// Apply the shuffle
Expand All @@ -93,15 +96,15 @@ int main(int argc, char** argv)
{
try
{
dlib::command_line_parser parser;
command_line_parser parser;
parser.add_option("train", "Train a small transformer on the built-in Shakespeare text");
parser.add_option("generate", "Generate text from a previously trained model (needs shakespeare_prompt)");
parser.add_option("learning-rate", "Set the learning rate for training (default: 1e-4)", 1);
parser.add_option("batch-size", "Set the mini-batch size for training (default: 64)", 1);
parser.add_option("generation-length", "Set the length of generated text (default: 400)", 1);
parser.add_option("alpha", "Set the initial learning rate for Adam optimizer (default: 0.004)", 1);
parser.add_option("beta1", "Set the decay rate for the first moment estimate (default: 0.9)", 1);
parser.add_option("beta2", "Set the decay rate for the second moment estimate (default: 0.999)", 1);
parser.add_option("alpha", "Set the weight decay for Adam optimizer (default: 0.004)", 1);
parser.add_option("beta1", "Set the first moment coefficient (default: 0.9)", 1);
parser.add_option("beta2", "Set the second moment coefficient (default: 0.999)", 1);
parser.add_option("max-samples", "Set the maximum number of training samples (default: 50000)", 1);
parser.add_option("shuffle", "Shuffle training sequences and labels before training (default: false)");
parser.parse(argc, argv);
Expand All @@ -122,7 +125,7 @@ int main(int argc, char** argv)
const size_t max_samples = get_option(parser, "max-samples",50000); // Default maximum number of training samples

// We define a minimal config for demonstration
const long vocab_size = 257; // 0..255 for chars + 1 pad token
const long vocab_size = MAX_TOKEN_ID + 1 + 1; // 256 for chars + 1 pad token
const long num_layers = 3;
const long num_heads = 4;
const long embedding_dim = 64;
Expand All @@ -136,8 +139,8 @@ int main(int argc, char** argv)
embedding_dim,
max_seq_len,
use_squeezing,
dlib::gelu,
dlib::dropout_10
gelu,
dropout_10
>;

// For GPU usage (if any), set gpus = {0} for a single GPU, etc.
Expand All @@ -151,7 +154,7 @@ int main(int argc, char** argv)
// ----------------------------------------------------------------------------------------
if (parser.option("train"))
{
std::cout << "=== TRAIN MODE ===\n";
cout << "=== TRAIN MODE ===\n";

// 1) Prepare training data (simple approach)
// We will store characters from shakespeare_text into a vector
Expand All @@ -160,7 +163,7 @@ int main(int argc, char** argv)
auto full_tokens = char_based_tokenize(shakespeare_text);
if (full_tokens.empty())
{
std::cerr << "ERROR: The Shakespeare text is empty. Please provide a valid training text.\n";
cerr << "ERROR: The Shakespeare text is empty. Please provide a valid training text.\n";
return 0;
}

Expand All @@ -170,18 +173,18 @@ int main(int argc, char** argv)
: 0;

// Display the size of the training text and the number of sequences
std::cout << "Training text size: " << full_tokens.size() << " characters\n";
std::cout << "Maximum number of sequences: " << max_sequences << "\n";
cout << "Training text size: " << full_tokens.size() << " characters\n";
cout << "Maximum number of sequences: " << max_sequences << "\n";

// Check if the text is too short
if (max_sequences == 0)
{
std::cerr << "ERROR: The Shakespeare text is too short for training. It must contain at least "
cerr << "ERROR: The Shakespeare text is too short for training. It must contain at least "
<< (max_seq_len + 1) << " characters.\n";
return 0;
}

std::vector<dlib::matrix<int, 0, 1>> samples;
std::vector<matrix<int, 0, 1>> samples;
std::vector<unsigned long> labels;

// Let's create a training set of about (N) samples from the text
Expand All @@ -190,7 +193,7 @@ int main(int argc, char** argv)
const size_t N = (max_sequences < max_samples) ? max_sequences : max_samples;
for (size_t start = 0; start < N; ++start)
{
dlib::matrix<int, 0, 1> seq(max_seq_len, 1);
matrix<int, 0, 1> seq(max_seq_len, 1);
for (long t = 0; t < max_seq_len; ++t)
seq(t, 0) = full_tokens[start + t];
samples.push_back(seq);
Expand All @@ -200,18 +203,18 @@ int main(int argc, char** argv)
// Shuffle samples and labels if the --shuffle option is enabled
if (parser.option("shuffle"))
{
std::cout << "Shuffling training sequences and labels...\n";
cout << "Shuffling training sequences and labels...\n";
shuffle_samples_and_labels(samples, labels);
}

// 3) Construct the network in training mode
using net_type = my_transformer_cfg::network_type<true>;
net_type net;
if (dlib::file_exists(model_file))
dlib::deserialize(model_file) >> net;
if (file_exists(model_file))
deserialize(model_file) >> net;

// 4) Create dnn_trainer
dlib::dnn_trainer<net_type, dlib::adam> trainer(net, dlib::adam(alpha, beta1, beta2), gpus);
dnn_trainer<net_type, adam> trainer(net, adam(alpha, beta1, beta2), gpus);
trainer.set_learning_rate(learning_rate);
trainer.set_min_learning_rate(1e-6);
trainer.set_mini_batch_size(batch_size);
Expand All @@ -229,41 +232,41 @@ int main(int argc, char** argv)
if (predicted[i] == labels[i])
correct++;
double accuracy = (double)correct / labels.size();
std::cout << "Training accuracy (on this sample set): " << accuracy << "\n";
cout << "Training accuracy (on this sample set): " << accuracy << "\n";

// 7) Save the model
net.clean();
dlib::serialize(model_file) << net;
std::cout << "Model saved to " << model_file << "\n";
serialize(model_file) << net;
cout << "Model saved to " << model_file << "\n";
}

// ----------------------------------------------------------------------------------------
// Generate mode
// ----------------------------------------------------------------------------------------
if (parser.option("generate"))
{
std::cout << "=== GENERATE MODE ===\n";
cout << "=== GENERATE MODE ===\n";
// 1) Load the trained model
using net_infer = my_transformer_cfg::network_type<false>;
net_infer net;
if (dlib::file_exists(model_file))
if (file_exists(model_file))
{
dlib::deserialize(model_file) >> net;
std::cout << "Loaded model from " << model_file << "\n";
deserialize(model_file) >> net;
cout << "Loaded model from " << model_file << "\n";
}
else
{
std::cerr << "Error: model file not found. Please run --train first.\n";
cerr << "Error: model file not found. Please run --train first.\n";
return 0;
}
std::cout << my_transformer_cfg::model_info::describe() << std::endl;
std::cout << "Model parameters: " << count_parameters(net) << std::endl << std::endl;
cout << my_transformer_cfg::model_info::describe() << endl;
cout << "Model parameters: " << count_parameters(net) << endl << endl;

// 2) Get the prompt from the included slm_data.h
std::string prompt_text = shakespeare_prompt;
if (prompt_text.empty())
{
std::cerr << "No prompt found in slm_data.h.\n";
cerr << "No prompt found in slm_data.h.\n";
return 0;
}
// If prompt is longer than max_seq_len, we keep only the first window
Expand All @@ -274,7 +277,7 @@ int main(int argc, char** argv)
const auto prompt_tokens = char_based_tokenize(prompt_text);

// Put into a dlib matrix
dlib::matrix<int, 0, 1> input_seq(max_seq_len, 1);
matrix<int, 0, 1> input_seq(max_seq_len, 1);
// Fill with pad if prompt is shorter than max_seq_len
for (long i = 0; i < max_seq_len; ++i)
{
Expand All @@ -284,7 +287,7 @@ int main(int argc, char** argv)
input_seq(i, 0) = PAD_TOKEN;
}

std::cout << "\nInitial prompt:\n" << prompt_text << " (...)\n\n\nGenerated text:\n" << prompt_text;
cout << "\nInitial prompt:\n" << prompt_text << " (...)\n\n\nGenerated text:\n" << prompt_text;

// 3) Generate new text
// We'll predict one character at a time, then shift the window
Expand All @@ -293,22 +296,22 @@ int main(int argc, char** argv)
const int next_char = net(input_seq); // single inference

// Print the generated character
std::cout << static_cast<char>(std::min(next_char, MAX_TOKEN_ID)) << std::flush;
cout << static_cast<char>(std::min(next_char, MAX_TOKEN_ID)) << flush;

// Shift left by 1
for (long i = 0; i < max_seq_len - 1; ++i)
input_seq(i, 0) = input_seq(i + 1, 0);
input_seq(max_seq_len - 1, 0) = std::min(next_char, MAX_TOKEN_ID);
}

std::cout << "\n\n(end of generation)\n";
cout << "\n\n(end of generation)\n";
}

return 0;
}
catch (std::exception& e)
catch (exception& e)
{
std::cerr << "Exception thrown: " << e.what() << std::endl;
cerr << "Exception thrown: " << e.what() << endl;
return 1;
}
}
Expand Down
4 changes: 2 additions & 2 deletions examples/slm_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include <algorithm>

// Utility function to concatenate text parts
std::string concatenateTexts(const std::vector<std::string>& texts) {
inline std::string concatenateTexts(const std::vector<std::string>& texts) {
std::string result;
for (const auto& text : texts) {
result += text;
Expand Down Expand Up @@ -590,4 +590,4 @@ And you shall understand from me her mind.
)";

#endif // SlmData_H
#endif // SlmData_H
10 changes: 5 additions & 5 deletions examples/slm_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,11 @@ namespace transformer
template<bool is_training>
using network_type = std::conditional_t<is_training,
classification_head<USE_SQUEEZING, activation_func, VOCAB_SIZE, EMBEDDING_DIM,
repeat<NUM_LAYERS, t_transformer_block,
positional_embeddings<VOCAB_SIZE, EMBEDDING_DIM, input<matrix<int, 0, 1>>>>>,
repeat<NUM_LAYERS, t_transformer_block,
positional_embeddings<VOCAB_SIZE, EMBEDDING_DIM, input<matrix<int, 0, 1>>>>>,
classification_head<USE_SQUEEZING, activation_func, VOCAB_SIZE, EMBEDDING_DIM,
repeat<NUM_LAYERS, i_transformer_block,
positional_embeddings<VOCAB_SIZE, EMBEDDING_DIM, input<matrix<int, 0, 1>>>>>
repeat<NUM_LAYERS, i_transformer_block,
positional_embeddings<VOCAB_SIZE, EMBEDDING_DIM, input<matrix<int, 0, 1>>>>>
>;

/**
Expand Down Expand Up @@ -283,4 +283,4 @@ namespace transformer
*/
}

#endif // SlmNet_H
#endif // SlmNet_H

0 comments on commit 28c46fb

Please sign in to comment.