added qlora + fsdp #14
Merged

Commits (19, all authored by weifengpy)
74dc122  avoid import dynamic linear
d5022ba  fix fp8 import
9b2dca2  added qlora finetuning on fast-gpt model
d113f06  add enable_qlora in train.py
4347e48  add qlora to README
93acee0  reorder QLoraMLP.w123 and use Hyperparameters
db918e6  use separate finetune.py script from train.py
bb3b9ff  use qlora prefix in log files
d65c3a0  remove qlora import from train.py
81b420d  fix linter
a6093b8  added FSDP + qlora
5ab3a76  remove duplicate code with torchao benchmark
ad73696  remove duplicate code with torchao benchmark
99ce030  reuse hyper param and train config
9fd3ac5  mark copied code
b99d3b1  Merge pull request #1 from weifengpy/fsdp_weif
da88596  Merge branch 'main' into fsdp
6ae68cd  fix linter and log every 10 iter
7a0114d  use logging.info instead of print
Diff view
@@ -1,6 +1,7 @@
 """
 Used to train a model from scratch on big dense blocks of text data using causal attention.
 """
+import argparse
 import csv
 import logging
 import math
@@ -14,29 +15,48 @@
 import numpy as np
 import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import transformer_nuggets.llama.train
 import transformer_nuggets.quant.qlora as qlora
 from fire import Fire
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp.wrap import ModuleWrapPolicy
 from torch.utils.data import DataLoader, IterableDataset
 from tqdm import tqdm
-from transformer_nuggets.llama.model import ModelArgs, Transformer
-from transformer_nuggets.llama.train import load_datasets, log_num_params, train, TrainingConfig
+from transformer_nuggets.llama.model import ModelArgs, Transformer, TransformerBlock
+from transformer_nuggets.llama.train import (
+    calculate_loss,
+    get_lr,
+    get_profile_context,
+    log_num_params,
+    write_loss_to_file,
+)

 logging.basicConfig(level=logging.INFO)


+@dataclass
+class Hyperparameters(transformer_nuggets.llama.train.Hyperparameters):
+    # qlora config
+    lora_r: int = 8
+    lora_alpha: int = 16
+    lora_dropout: float = 0.05
+
+
+@dataclass
+class TrainingConfig(transformer_nuggets.llama.train.TrainingConfig):
+    log_interval: int = 10
+    track_max_memory: bool = False
+
+
 def main(
     hyper_params: Hyperparameters,
     training_config: TrainingConfig,
+    rank: int,
+    world_size: int,
 ):
+    torch.cuda.set_device(rank)
+
     random.seed(1337)
     np.random.seed(1337)
     torch.manual_seed(1337)
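For context on the three new hyperparameters, here is a generic LoRA-style linear layer sketch. This is not the transformer_nuggets implementation (which additionally quantizes the frozen base weight, which is what makes it QLoRA); it only illustrates what lora_r, lora_alpha, and lora_dropout typically control: the rank of the low-rank update, the scaling applied to it, and the dropout on the adapter path.

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    # Generic sketch: frozen base projection plus a trainable low-rank update.
    def __init__(self, in_features, out_features, lora_r=8, lora_alpha=16, lora_dropout=0.05):
        super().__init__()
        self.base = nn.Linear(in_features, out_features, bias=False)
        self.base.weight.requires_grad_(False)  # frozen (and NF4-quantized in QLoRA)
        self.lora_a = nn.Parameter(torch.randn(lora_r, in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(out_features, lora_r))  # zero-init so training starts at the base model
        self.scaling = lora_alpha / lora_r
        self.dropout = nn.Dropout(lora_dropout)

    def forward(self, x):
        update = self.dropout(x) @ self.lora_a.t() @ self.lora_b.t()
        return self.base(x) + self.scaling * update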
@@ -47,7 +67,8 @@ def main(
     # Setup Model
     model_args = ModelArgs.from_name(training_config.model_name)
-    logging.info(f"Initializing model: {training_config.model_name}")
+    if rank == 0:
+        logging.info(f"Initializing model: {training_config.model_name}")
     with training_config.device:
         model = Transformer(model_args).to(torch.bfloat16)
         model.init_parameters()
@@ -58,20 +79,35 @@
         hyper_params.lora_dropout,
     )
     qlora.swap_for_qlora(model, qlora_config, torch.bfloat16)

     model.setup_caches(
         hyper_params.micro_batch_size, hyper_params.max_seq_length, training_config.device
     )

-    logging.info("Setting up the dataloaders")
-    train_data, val_data = load_datasets(hyper_params, training_config)
+    if rank == 0:
+        logging.info("Setting up the dataloaders")
+    train_data, val_data = load_datasets(hyper_params, training_config, rank, world_size)
     train_dataloader = DataLoader(
-        train_data, batch_size=hyper_params.micro_batch_size, num_workers=2
+        train_data,
+        batch_size=hyper_params.micro_batch_size,
+        num_workers=2,
     )
     val_dataloader = DataLoader(val_data, batch_size=hyper_params.micro_batch_size, num_workers=2)

     log_num_params(model)

+    if world_size > 1:
+        model = FSDP(
+            model,
+            use_orig_params=True,
+            auto_wrap_policy=ModuleWrapPolicy([TransformerBlock]),
+        )
+
     if training_config.compile:
         model = torch.compile(model)

+    if rank == 0:
+        print(model)
+
     optimizer = torch.optim.AdamW(
         [p for p in model.parameters() if p.requires_grad],
         lr=hyper_params.learning_rate,
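The FSDP wrapping added above uses ModuleWrapPolicy([TransformerBlock]) so that every transformer block becomes its own FSDP unit. Below is a minimal, self-contained sketch of the same pattern on a toy model; the Block/ToyModel classes are stand-ins (not the PR's Transformer), and a process group must already be initialized, as fsdp_main does later in this diff.

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

class Block(nn.Module):  # stand-in for TransformerBlock
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(64, 64)

    def forward(self, x):
        return self.proj(x)

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList(Block() for _ in range(4))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

def shard(model: nn.Module) -> FSDP:
    # Each Block becomes its own FSDP unit; use_orig_params=True keeps the original
    # parameter objects visible, which the requires_grad filter used when building
    # the optimizer above relies on.
    return FSDP(
        model,
        use_orig_params=True,
        auto_wrap_policy=ModuleWrapPolicy([Block]),
    )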
@@ -87,18 +123,233 @@
         val_dataloader,
         hyper_params,
         training_config,
+        rank,
+        world_size,
     )


 def entrypoint(
     profile: bool = False,
+    rank: int = 0,
+    world_size: int = 1,
 ):
+    batch_size = int(128 / world_size)
     assert isinstance(profile, bool), "profile must be bool"
-    hyper_params = Hyperparameters()
-    training_config = TrainingConfig(profile=profile)
-    main(hyper_params, training_config)
+    hyper_params = Hyperparameters(batch_size=batch_size)
+    training_config = TrainingConfig(
+        profile=profile,
+        device=torch.device(f"cuda:{rank}"),
+    )
+    main(hyper_params, training_config, rank, world_size)
+
+
+def fsdp_main(rank, world_size, args):
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12355"
+    dist.init_process_group("nccl", rank=rank, world_size=world_size)
+    entrypoint(*args, rank=rank, world_size=world_size)
+    dist.destroy_process_group()
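fsdp_main above hard-codes MASTER_ADDR/MASTER_PORT and launches workers itself with mp.spawn. For comparison, a hypothetical alternative (not part of this PR) is to let a launcher such as torchrun set the rendezvous environment variables and read the rank and world size from the process group:

import torch
import torch.distributed as dist

def init_from_env():
    # torchrun exports RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT per worker,
    # so the default env:// initialization needs no hard-coded address or port.
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank % torch.cuda.device_count())
    return rank, world_size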
+def train(
+    model: Transformer,
+    optimizer: torch.optim.Optimizer,
+    train_data: DataLoader,
+    val_data: DataLoader,
+    hyper_params: Hyperparameters,
+    training_config: TrainingConfig,
+    rank: int,
+    world_size: int,
+) -> None:
+    """Lets go!"""
+    step_count = 0
+
+    model.train()
+    profile_context = get_profile_context(hyper_params, training_config)
+    train_iter = iter(train_data)
+
+    dtype_str = "bf16"
+
+    val_loss_file = (
+        training_config.log_dir
+        / f"qlora_validation_loss_{dtype_str}_overfit_{training_config.overfit}_compile_{training_config.compile}_{rank}.csv"
+    )
+    train_loss_file = (
+        training_config.log_dir
+        / f"qlora_train_loss_{dtype_str}_overfit_{training_config.overfit}_compile_{training_config.compile}_{rank}.csv"
+    )
+    if rank == 0:
+        logging.info(f"val_loss_file: {val_loss_file}")
+        logging.info(f"train_loss_file: {train_loss_file}")
+
+    this_batch_loss = torch.tensor(0.0, device=training_config.device)
+    this_batch_n = 0
+    fsdp_loss = torch.zeros(2, device=training_config.device)
+
+    with profile_context as p:
+        for iter_num in range(hyper_params.max_iters):
+            lr = get_lr(iter_num, hyper_params)
+            for param_group in optimizer.param_groups:
+                param_group["lr"] = lr
+
+            input_ids, targets = next(train_iter)
+            input_ids = input_ids.pin_memory().to(training_config.device)
+            targets = targets.pin_memory().to(training_config.device)
+            is_accumulating = (iter_num + 1) % hyper_params.gradient_accumulation_iters != 0
+
+            if iter_num % hyper_params.gradient_accumulation_iters == 0:
+                with torch.no_grad():
+                    this_batch_loss.fill_(0)
+                    this_batch_n = 0
+
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+                logits = model(input_ids)
+
+            # Calculate the loss
+            loss = calculate_loss(logits, targets)
+            with torch.no_grad():
+                this_batch_loss += loss
+                this_batch_n += len(input_ids)
+
+            # Scale the loss by grad_accumulation iters
+            (loss / hyper_params.gradient_accumulation_iters).backward()
+
+            if not is_accumulating:
+                optimizer.step()
+                optimizer.zero_grad()
+                step_count += 1
+
+            # TODO(future): fix this condition, eval currently only happens
+            # if eval_interval and batch_size are multiples of each other
+            if not is_accumulating and step_count % training_config.eval_interval == 0:
+                t0 = time.time()
+                val_loss = validate(
+                    model, val_data, val_loss_file, training_config, step_count, rank, world_size
+                )
+                t1 = time.time() - t0
+                if rank == 0:
+                    logging.info(
+                        f"step {iter_num}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms"
+                    )
+
+            if not is_accumulating and step_count % training_config.save_interval == 0:
+                checkpoint_path = training_config.out_dir / f"iter-{iter_num:06d}-ckpt.pth"
+                torch.save(checkpoint_path, {"model": model})
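One thing worth flagging in the checkpoint branch above: torch.save takes the object to serialize first and the file path second, so the call as written appears to have its arguments swapped. The intended call was presumably:

torch.save({"model": model}, checkpoint_path)

(With FSDP, checkpointing normally goes through a state_dict rather than the wrapped module object, but that is beyond this diff.)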
+            if (iter_num + 1) % training_config.log_interval == 0:
+                # loss.item causes a sync so we update the progress bar sporadically
+                if world_size == 1:
+                    with torch.no_grad():
+                        avg_loss_this_batch = this_batch_loss / this_batch_n
+                    loss_val = avg_loss_this_batch
+                else:
+                    fsdp_loss[0] = this_batch_loss
+                    fsdp_loss[1] = this_batch_n
+                    dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
+                    loss_val = fsdp_loss[0] / fsdp_loss[1]
+
+                write_loss_to_file(train_loss_file, step_count, loss_val)
+
+                if rank == 0:
+                    logging.info(
+                        f"iter={iter_num} max_iters={hyper_params.max_iters} loss={loss_val:.4f}"
+                    )
+
+            if training_config.profile and iter_num < 103:
+                # We want to profile iters 100-102 of the model training
+                p.step()
+
+            if training_config.track_max_memory and rank == 0:
+                print(
+                    "iter_num",
+                    iter_num,
+                    "mem usage GB",

Review comment (inline): nit, this should actually be GiB.

+                    float(torch.cuda.max_memory_allocated()) / 1024 / 1024 / 1024,
+                )
+                torch.cuda.reset_peak_memory_stats()
+
+
+class Dataset(IterableDataset):
+    def __init__(
+        self,
+        data_file: Path,
+        hyper_params: Hyperparameters,
+        training_config: TrainingConfig,
+        rank: int = 0,
+        world_size: int = 1,
+    ):
+        super().__init__()
+        self.data_file = data_file
+        self.max_seq_length = hyper_params.max_seq_length
+        self.max_iters = hyper_params.max_iters
+        self.overfit = training_config.overfit
+        self.deterministic_data_loading = training_config.deterministic_data_loading
+        self.index = 0
+        self.rank = rank
+        self.world_size = world_size
+
+    def __iter__(self):
+        data = np.memmap(self.data_file, dtype=np.uint16, mode="r")
+        per_rank = int(self.max_iters / float(self.world_size))
+        rank_offset = self.rank * per_rank
+        worker_info = torch.utils.data.get_worker_info()
+        assert worker_info is not None, "single process data loading not implemented yet"
+        per_worker = int(per_rank / float(worker_info.num_workers))
+        worker_id = worker_info.id
+        worker_offset = worker_id * per_worker
+        while True:
+            if self.overfit:
+                i = 0
+            else:
+                if self.deterministic_data_loading:
+                    i = self.index + rank_offset + worker_offset
+                    self.index += self.max_seq_length
+                else:
+                    i = torch.randint(len(data) - self.max_seq_length, (1,)).item()
+            x = torch.from_numpy((data[i : i + self.max_seq_length]).astype(np.int64))
+            y = torch.from_numpy((data[i + 1 : i + 1 + self.max_seq_length]).astype(np.int64))
+            yield x, y
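Note that Dataset.__iter__ asserts that it runs inside a DataLoader worker process, so num_workers must be at least 1 (the DataLoaders above use 2). The deterministic path splits work first across ranks and then across workers within a rank; a small, hypothetical illustration of that offset arithmetic (toy numbers, not from the PR):

max_iters, world_size, num_workers = 1000, 2, 2
per_rank = int(max_iters / float(world_size))        # 500 per rank
for rank in range(world_size):
    rank_offset = rank * per_rank
    per_worker = int(per_rank / float(num_workers))  # 250 per worker
    for worker_id in range(num_workers):
        worker_offset = worker_id * per_worker
        print(f"rank={rank} worker={worker_id} start={rank_offset + worker_offset}")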
+def load_datasets(
+    hyper_params: Hyperparameters,
+    training_config: TrainingConfig,
+    rank: int,
+    world_size: int,
+):
+    train_data = Dataset(
+        str(training_config.data_dir / "train.bin"),
+        hyper_params=hyper_params,
+        training_config=training_config,
+        rank=rank,
+        world_size=world_size,
+    )
+    val_data = Dataset(
+        str(training_config.data_dir / "val.bin"),
+        hyper_params=hyper_params,
+        training_config=training_config,
+        rank=rank,
+        world_size=world_size,
+    )
+    return train_data, val_data
+
+
 if __name__ == "__main__":
     torch.set_float32_matmul_precision("high")
-    Fire(entrypoint)
+    parser = argparse.ArgumentParser(description="Native PyTorch LLaMa trainer")
+    parser.add_argument("--profile", action=argparse.BooleanOptionalAction, default=False)
+    parser.add_argument(
+        "--fsdp_num_gpus",
+        type=int,
+        default=1,
+        help="if specified, runs FSDP with this many GPUs on a single host",
+    )
+    args = parser.parse_args()
+    fsdp_num_gpus = args.fsdp_num_gpus
+    inner_args = (args.profile,)
+
+    if fsdp_num_gpus is None or fsdp_num_gpus == 1:
+        entrypoint(*inner_args)
+    else:
+        assert fsdp_num_gpus <= torch.cuda.device_count()
+        mp.spawn(fsdp_main, args=(fsdp_num_gpus, inner_args), nprocs=fsdp_num_gpus, join=True)
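For reference, given the argparse block above and the commit history (the script was split out as finetune.py), a single-GPU run would presumably be python finetune.py --profile, while python finetune.py --fsdp_num_gpus 2 spawns two worker processes with mp.spawn and shards the model with FSDP.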
Review comment: nit: log