Merge pull request #125 from imoneoi/3.5.1
3.5.1
imoneoi authored Dec 12, 2023
2 parents 7df4496 + 1aa7f19 commit 85466f5
Showing 8 changed files with 181 additions and 191 deletions.
Binary file modified assets/openchat.png
Binary file modified assets/openchat_grok.png
3 changes: 1 addition & 2 deletions ochat/evaluation/run_eval.py
```diff
@@ -215,8 +215,7 @@ async def run_eval(
     if output_file is None:
         output_file = os.path.join(os.path.dirname(data_path), "eval_results", f"{os.path.basename(model)}_{condition}.json")
 
-    os.makedirs(os.path.dirname(output_file), exist_ok=True)
-
+    os.makedirs(os.path.dirname(output_file), exist_ok=True)
     with open(output_file, "wb") as f:
         f.write(orjson.dumps(questions, option=orjson.OPT_INDENT_2))
 
```
31 changes: 19 additions & 12 deletions ochat/models/unpadded_mistral.py
```diff
@@ -51,16 +51,6 @@ def weighted_cross_entropy(logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor):
     return (weights * torch.nn.functional.cross_entropy(logits, labels, reduction="none")).sum()
 
 
-@torch.jit.script  # type: ignore
-def rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, variance_epsilon: float):
-    input_dtype = hidden_states.dtype
-    hidden_states = hidden_states.to(torch.float32)
-
-    variance = (hidden_states * hidden_states).mean(-1, keepdim=True)
-    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
-    return weight * hidden_states.to(input_dtype)
-
-
 def rotate_half(x: torch.Tensor):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]
@@ -80,6 +70,18 @@ def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, ...):
 
 
 # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral
+RMS_NORM_TRACED = None
+
+
+def rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, variance_epsilon: torch.Tensor):
+    input_dtype = hidden_states.dtype
+    hidden_states = hidden_states.to(torch.float32)
+
+    variance = (hidden_states * hidden_states).mean(-1, keepdim=True)
+    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
+    return weight * hidden_states.to(input_dtype)
+
+
 class UnpaddedMistralRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps):
         """
@@ -88,10 +90,15 @@ def __init__(self, hidden_size, eps):
         super().__init__()
 
         self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
+        self.variance_epsilon = torch.tensor(eps, dtype=torch.get_default_dtype())
+
+        global RMS_NORM_TRACED
+        if RMS_NORM_TRACED is None:
+            RMS_NORM_TRACED = torch.jit.trace(rms_norm, (torch.ones(hidden_size), torch.ones(hidden_size), self.variance_epsilon))
 
     def forward(self, hidden_states):
-        return rms_norm(hidden_states, self.weight, self.variance_epsilon)
+        global RMS_NORM_TRACED
+        return RMS_NORM_TRACED(hidden_states, self.weight, self.variance_epsilon)
 
 
 # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral
```
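The RMSNorm change swaps @torch.jit.script compilation for a torch.jit.trace graph that is built once, by the first UnpaddedMistralRMSNorm constructed, and cached in the module-level RMS_NORM_TRACED; eps is now passed as a tensor, so tracing records it as an input rather than baking in a constant. A minimal standalone sketch of the same pattern (hypothetical sizes, not the repo's module):

```python
import torch

def rms_norm(hidden_states, weight, variance_epsilon):
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = (hidden_states * hidden_states).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
    return weight * hidden_states.to(input_dtype)

hidden_size = 16
eps = torch.tensor(1e-5)  # a tensor, so trace treats eps as an input

# Trace once on example tensors; reuse the compiled graph afterwards.
traced = torch.jit.trace(rms_norm, (torch.ones(hidden_size), torch.ones(hidden_size), eps))

x = torch.randn(100, hidden_size)  # 2-D unpadded token states
out = traced(x, torch.ones(hidden_size), eps)
print(out.shape)  # torch.Size([100, 16])
```

Because every op in rms_norm broadcasts, the graph traced on 1-D examples also accepts the 2-D unpadded hidden states seen at training time, and the global cache lets all decoder layers share a single traced function instead of re-tracing per layer.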
ochat/training_deepspeed/multipack_sampler.py
```diff
@@ -1,7 +1,3 @@
-from typing import Any, Optional, List, Callable
-
-import torch.distributed as dist
-
 import numpy as np
 import numba
 
@@ -96,61 +92,42 @@ def allocate(lengths: np.ndarray, numseqs: np.ndarray, lengths_cumsum: np.ndarray, ...):
     return result, result_totseqs, s, len(result) * c * n
 
 
-class MultipackDistributedDataloader:
+class MultipackDistributedSampler:
     """Unpadded data loading using Multipack.
     Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard."""
 
     def __init__(
         self,
-        dataset: Any,
         lengths: np.ndarray,
         numseqs: np.ndarray,
 
         batch_max_length: int,
-        collate_fn: Callable,
 
-        num_replicas: Optional[int] = None,
-        rank: Optional[int] = None,
+        num_replicas: int,
+        rank: int,
 
-        seed: int = 0,
+        seed: int,
     ):
         # Dataset
-        self.dataset = dataset
         self.lengths = lengths
         self.numseqs = numseqs
         assert isinstance(self.lengths, np.ndarray)
 
         self.batch_max_length = batch_max_length
-        self.collate_fn = collate_fn
-
-        # Get rank
-        if num_replicas is None:
-            if not dist.is_available():
-                raise RuntimeError("Requires distributed package to be available")
-            num_replicas = dist.get_world_size()
-        if rank is None:
-            if not dist.is_available():
-                raise RuntimeError("Requires distributed package to be available")
-            rank = dist.get_rank()
-
         self.num_replicas = num_replicas
         self.rank = rank
 
         # Seed
         self.seed = seed
 
-        # Epoch
-        self.epoch = 0
-
         # statistics
         self.eff_total_used = 0
         self.eff_total_slots = 0
 
-    def set_epoch(self, epoch: int):
-        self.epoch = epoch
-
-    def generate_batches(self, set_stats=False):
-        indices = np.random.default_rng(seed=self.seed + self.epoch).permutation(len(self.lengths))
+    def generate_batches(self, epoch, set_stats=False):
+        indices = np.random.default_rng(seed=self.seed + epoch).permutation(len(self.lengths))
 
         lengths = self.lengths[indices]
         numseqs = self.numseqs[indices]
@@ -173,14 +150,14 @@ def generate_batches(self, set_stats=False):
 
         return batches, totseqs, curseqs
 
-    def __iter__(self):
-        all_batches, all_totseqs, all_curseqs = self.generate_batches(set_stats=True)
+    def iter(self, epoch):
+        all_batches, all_totseqs, all_curseqs = self.generate_batches(epoch, set_stats=True)
 
         for batch, totseq, curseq in zip(all_batches, all_totseqs, all_curseqs):
-            yield self.collate_fn(self.dataset[batch]), totseq, curseq
+            yield batch, totseq, curseq
 
-    def num_batches(self):
-        batches, _, _ = self.generate_batches()
+    def estimate_num_batches(self):
+        batches, _, _ = self.generate_batches(epoch=0)
        return len(batches)
 
     def efficiency(self):
```
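The net effect is that MultipackDistributedSampler is now a pure sampler: dataset storage, collation, and epoch bookkeeping move to the caller (the new OpenchatDataset below), iter(epoch) yields raw index batches plus sequence-count statistics, and estimate_num_batches derives a count from epoch 0's packing. A usage sketch with made-up length statistics (everything here except the class itself is hypothetical):

```python
import numpy as np

from ochat.training_deepspeed.multipack_sampler import MultipackDistributedSampler

# Hypothetical per-example stats: total token count and number of packed sequences.
lengths = np.random.default_rng(0).integers(128, 4096, size=10_000)
numseqs = np.ones_like(lengths)

sampler = MultipackDistributedSampler(
    lengths=lengths,
    numseqs=numseqs,
    batch_max_length=8192,
    num_replicas=1,
    rank=0,
    seed=0,
)

print(sampler.estimate_num_batches())  # estimated from epoch 0's packing

for epoch in range(2):
    for indices, total_seqs, cur_seqs in sampler.iter(epoch):
        # `indices` are dataset row numbers; fetching rows and collating
        # them into tensors is now the caller's job.
        pass
```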
26 changes: 0 additions & 26 deletions ochat/training_deepspeed/numpy_dataset.py

This file was deleted.
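Its role of serving the numpy-converted training data is presumably subsumed by the new ochat/training_deepspeed/openchat_dataset.py added below.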

110 changes: 110 additions & 0 deletions ochat/training_deepspeed/openchat_dataset.py
```python
import torch
import numpy as np
from torch.utils.data import IterableDataset, get_worker_info

import pyarrow.parquet as pq
import orjson

from ochat.training_deepspeed.multipack_sampler import MultipackDistributedSampler


def _find_multiple(a, b):
    # Round a up to the nearest multiple of b (ceiling division, scaled back up).
    return (-(a // -b)) * b


class OpenchatDataset(IterableDataset):
    def __init__(self, dataset_filename, batch_max_length, rank, num_replicas):
        super().__init__()
        # Init constants
        self.PAD_ID = 0
        self.PAD_MULTIPLE = 64
        self.BATCH_KEYS = {
            "seqlens": torch.int32,
            "nz_input_ids": torch.long,
            "nz_position_ids": torch.long,
            "nz_shifted_label_ids": torch.long,

            "nz_shifted_loss_weights": torch.bfloat16
        }

        assert batch_max_length % self.PAD_MULTIPLE == 0, f"Batch size {batch_max_length} needs to be a multiple of {self.PAD_MULTIPLE}"

        # Load data
        # Convert parquet to numpy for fast random access
        table = pq.read_table(dataset_filename, memory_map=True)
        self.dataset = {k: v.to_numpy() for k, v in zip(table.column_names, table.columns)}

        # Read metadata
        self.metadata = table.schema.metadata.get(b"metadata_json", None)
        if self.metadata is not None:
            self.metadata = orjson.loads(self.metadata)

        # Free table space
        del table

        # Create sampler
        self.sampler = MultipackDistributedSampler(
            lengths=self.dataset["total_length"],
            numseqs=self.dataset["num_seqs"],

            batch_max_length=batch_max_length,

            rank=rank,
            num_replicas=num_replicas,
            seed=0
        )

        # Init state
        self._epoch = 0

    def _load_batch(self, indices):
        batch = {k: v[indices] for k, v in self.dataset.items()}

        # Concat batches
        batch = {k: np.concatenate(batch[k], axis=0) for k in self.BATCH_KEYS.keys()}

        # Pad an unused item to reach multiple of PAD_MULTIPLE, for faster GEMM
        total_seqlen = batch["nz_input_ids"].size
        pad_len = _find_multiple(total_seqlen, self.PAD_MULTIPLE) - total_seqlen

        if pad_len > 0:
            assert pad_len < self.PAD_MULTIPLE

            # Append one extra "sequence" of length pad_len; each spec is (length, fill value)
            padding_specs = {
                "seqlens": (1, pad_len),

                "nz_input_ids": (pad_len, self.PAD_ID),
                "nz_position_ids": (pad_len, 0),
                "nz_shifted_label_ids": (pad_len, self.PAD_ID),
                "nz_shifted_loss_weights": (pad_len, 0),
            }
            for k, pad_spec in padding_specs.items():
                batch[k] = np.concatenate((batch[k], np.full(*pad_spec, dtype=batch[k].dtype)), axis=0)

        # To tensor
        batch_tensor = {}
        for k, dtype in self.BATCH_KEYS.items():
            batch_tensor[k] = torch.from_numpy(batch[k]).to(dtype)

        # Cu seqlens (cumulative lengths with a leading zero)
        batch_tensor["cu_seqlens"] = torch.nn.functional.pad(batch_tensor["seqlens"].cumsum(-1, dtype=torch.int32), (1, 0))
        # Batch info
        batch_info = {"max_seqlen": torch.max(batch_tensor["seqlens"]).item()}

        # Inputs
        del batch_tensor["seqlens"]
        return batch_tensor, batch_info

    def __iter__(self):
        worker_info = get_worker_info()
        assert worker_info is None or worker_info.num_workers == 1

        for indices, all_numseq, cur_numseq in self.sampler.iter(self._epoch):
            yield self._load_batch(indices), all_numseq, cur_numseq

        # Increase epoch count
        self._epoch += 1

    def estimate_num_batches(self):
        return self.sampler.estimate_num_batches()
```
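Since OpenchatDataset is an IterableDataset that already emits fully collated, length-packed batches, it would presumably be driven with the DataLoader's own batching disabled. A sketch under that assumption (the parquet path is hypothetical; the real trainer wiring lives elsewhere in the repo):

```python
from torch.utils.data import DataLoader

from ochat.training_deepspeed.openchat_dataset import OpenchatDataset

dataset = OpenchatDataset(
    dataset_filename="data/train.parquet",  # hypothetical path
    batch_max_length=8192,
    rank=0,
    num_replicas=1,
)

# batch_size=None disables automatic batching; num_workers=0 satisfies the
# single-worker assertion in __iter__.
loader = DataLoader(dataset, batch_size=None, num_workers=0)

for (batch_tensor, batch_info), all_numseq, cur_numseq in loader:
    # batch_tensor holds nz_input_ids / nz_position_ids / label ids / cu_seqlens;
    # batch_info["max_seqlen"] suits varlen (FlashAttention-style) kernels.
    pass
```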
