From 82435ebb2c28beccf274f5b9e2512766aa375e3a Mon Sep 17 00:00:00 2001
From: FL33TW00D
Date: Thu, 3 Oct 2024 16:26:15 +0200
Subject: [PATCH] chore: R1 and R2 match

---
 crates/ratchet-core/src/cpu/gemm.rs  |   2 +-
 crates/ratchet-core/src/cpu/mod.rs   |  16 ++--
 crates/ratchet-core/src/cpu/rope.rs  | 112 +++++++++++++++------
 crates/ratchet-core/src/cpu/slice.rs |   3 +
 crates/ratchet-core/src/op.rs        |   2 +-
 crates/ratchet-core/src/tensor.rs    |   2 +-
 6 files changed, 76 insertions(+), 61 deletions(-)
 create mode 100644 crates/ratchet-core/src/cpu/slice.rs

diff --git a/crates/ratchet-core/src/cpu/gemm.rs b/crates/ratchet-core/src/cpu/gemm.rs
index fdba0622..b932b37b 100644
--- a/crates/ratchet-core/src/cpu/gemm.rs
+++ b/crates/ratchet-core/src/cpu/gemm.rs
@@ -156,7 +156,7 @@ fn gemm_impl(
 }
 
 impl CPUOperation for Matmul {
-    fn apply(&self, dst: Tensor) -> Result<Tensor, OperationError> {
+    fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
         fn run_gemm(
             spec: MatmulSpec,
             lhs: &Tensor,
diff --git a/crates/ratchet-core/src/cpu/mod.rs b/crates/ratchet-core/src/cpu/mod.rs
index 6d5182ff..3646373a 100644
--- a/crates/ratchet-core/src/cpu/mod.rs
+++ b/crates/ratchet-core/src/cpu/mod.rs
@@ -168,7 +168,7 @@ macro_rules! impl_cpu_unary {
         impl_cpu_unary_wrapper!($dtype, $conv);
 
         impl CPUOperation for CPU<$dtype, Unary> {
-            fn apply(&self, dst: Tensor) -> Result<Tensor, OperationError> {
+            fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
                 match self.op.op() {
                     UnaryOp::Gelu => Self::gelu(self.op.input(), dst),
                     UnaryOp::Tanh => Self::tanh(self.op.input(), dst),
@@ -196,9 +196,9 @@ impl_cpu_unary!(bf16, bf16::from_f32);
 
 pub fn cpu_unary(unary: Unary, dst: Tensor) -> Result<Tensor, OperationError> {
     match dst.dt() {
-        DType::F32 => CPU::<f32, _>::new(unary).apply(dst),
-        DType::F16 => CPU::<f16, _>::new(unary).apply(dst),
-        DType::BF16 => CPU::<bf16, _>::new(unary).apply(dst),
+        DType::F32 => CPU::<f32, _>::new(unary).apply_cpu(dst),
+        DType::F16 => CPU::<f16, _>::new(unary).apply_cpu(dst),
+        DType::BF16 => CPU::<bf16, _>::new(unary).apply_cpu(dst),
         _ => todo!(),
     }
 }
@@ -222,7 +222,7 @@ macro_rules! impl_cpu_binary {
         }
 
         impl CPUOperation for CPU<$dtype, Binary> {
-            fn apply(&self, dst: Tensor) -> Result<Tensor, OperationError> {
+            fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
                 match self.op.op() {
                     BinaryOp::Add => Self::add(self.op.lhs(), self.op.rhs(), dst),
                     BinaryOp::Sub => Self::sub(self.op.lhs(), self.op.rhs(), dst),
@@ -240,9 +240,9 @@ impl_cpu_binary!(bf16);
 
 pub fn cpu_binary(binary: Binary, dst: Tensor) -> Result<Tensor, OperationError> {
     match dst.dt() {
-        DType::F32 => CPU::<f32, _>::new(binary).apply(dst),
-        DType::F16 => CPU::<f16, _>::new(binary).apply(dst),
-        DType::BF16 => CPU::<bf16, _>::new(binary).apply(dst),
+        DType::F32 => CPU::<f32, _>::new(binary).apply_cpu(dst),
+        DType::F16 => CPU::<f16, _>::new(binary).apply_cpu(dst),
+        DType::BF16 => CPU::<bf16, _>::new(binary).apply_cpu(dst),
         _ => todo!(),
     }
 }
diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs
index c72b1d0f..8021a8a3 100644
--- a/crates/ratchet-core/src/cpu/rope.rs
+++ b/crates/ratchet-core/src/cpu/rope.rs
@@ -105,30 +105,29 @@ fn merge(data: &[f32], offset: usize, skip: usize) -> Vec<f32> {
 }
 
 fn slice(src: &[f32], start: &[usize], stop: &[usize]) -> Vec<f32> {
-    let stop_numel: usize = stop.iter().product();
-    let start_numel: usize = stop.iter().product();
-    assert!(stop_numel >= start_numel);
-
-    let mut dst = vec![0.0; stop_numel - start_numel];
-
-    /*
-    start: [0, 0, 0, 8]
-    stop: [1, 1, 1, 16]
-    for
-    */
-
-    let mut src_idx = 0;
-    let mut dst_idx = 0;
-    for i in 0..start.len() {
-        let mut src_stride = start[i];
-        let mut dst_stride = 0;
-        while src_stride < stop[i] {
-            dst[dst_idx] = src[src_idx];
-            src_idx += src_stride;
-            dst_idx += dst_stride;
-            src_stride += 1;
-            dst_stride += 1;
+    assert!(start.len() == stop.len());
+    start.iter().zip(stop.iter()).for_each(|(s, t)| {
+        assert!(s < t);
+    });
+
+    let src_shape = [2, 16, 16]; // Corrected input shape
+    let src_strides = [16 * 16, 16, 1];
+
+    let delta: Vec<usize> = stop.iter().zip(start.iter()).map(|(s, t)| s - t).collect();
+    let dst_shape: Vec<usize> = delta.clone();
+    let dst_numel: usize = delta.iter().product();
+
+    let mut dst = vec![0.0; dst_numel];
+
+    for i in 0..dst_numel {
+        let mut src_index = 0;
+        let mut tmp = i;
+        for d in 0..delta.len() {
+            let coord = tmp / dst_shape[d + 1..].iter().product::<usize>().max(1);
+            tmp %= dst_shape[d + 1..].iter().product::<usize>().max(1);
+            src_index += (coord + start[d]) * src_strides[d];
         }
+        dst[i] = src[src_index];
     }
 
     dst
@@ -175,48 +174,61 @@ fn transpose(
 fn rope(src: Vec<f32>, shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec<f32> {
     println!("Ratchet RoPE");
     let [batches, num_heads, seq_len, head_dim] = shape.try_into().unwrap();
-    let el_count = batches * num_heads * seq_len * head_dim;
 
     let half_dim = dim / 2;
     let theta = compute_theta(dim, seq_len, base, offset);
+    println!("Theta: {:?}", theta);
     let (sin, cos): (Vec<f32>, Vec<f32>) = theta.iter().map(|i| i.sin_cos()).unzip();
+    println!("Cos: {:?}", cos);
+    println!("Sin: {:?}", sin);
 
-    let mut intermediate = Vec::with_capacity(el_count);
+    println!("Cos length: {:?}", cos.len());
+    println!("Sin length: {:?}", sin.len());
 
-    let chunk_offset = half_dim;
-    let skip = 0;
+    let x1 = slice(&src, &[0, 0, 0], &[num_heads, seq_len, half_dim]);
+    let x2 = slice(&src, &[0, 0, half_dim], &[num_heads, seq_len, dim]);
+    println!("X1: {:?}", x1);
+    println!("X1 length: {:?}", x1.len());
+    println!("X2: {:?}", x2);
+    println!("X2 length: {:?}", x2.len());
 
-    let (x1, x2) = chunk_by_offset(&src, chunk_offset, skip);
-
-    let (x1_cos, x1_sin): (Vec<f32>, Vec<f32>) = x1
+    let x1_cos = x1
         .iter()
         .enumerate()
-        .map(|(i, x)| (x * cos[i % cos.len()], x * sin[i % sin.len()]))
-        .unzip();
-
-    let (x2_cos, x2_sin): (Vec<f32>, Vec<f32>) = x2
+        .map(|(i, x)| x * cos[i % cos.len()])
+        .collect::<Vec<_>>();
+    let x2_sin = x2
         .iter()
         .enumerate()
-        .map(|(i, x)| (x * cos[i % cos.len()], x * sin[i % sin.len()]))
-        .unzip();
+        .map(|(i, x)| x * sin[i % sin.len()])
+        .collect::<Vec<_>>();
 
-    x1_cos.iter().zip(x2_sin).for_each(|(x1_cos, x2_sin)| {
-        intermediate.push(x1_cos - x2_sin);
-    });
+    let r1 = x1_cos
+        .iter()
+        .zip(x2_sin.iter())
+        .map(|(x1, x2)| x1 - x2)
+        .collect::<Vec<_>>();
 
-    x1_sin.iter().zip(x2_cos).for_each(|(x1_sin, x2_cos)| {
-        intermediate.push(x1_sin + x2_cos);
-    });
+    let x1_sin = x1
+        .iter()
+        .enumerate()
+        .map(|(i, x)| x * sin[i % sin.len()])
+        .collect::<Vec<_>>();
+    let x2_cos = x2
+        .iter()
+        .enumerate()
+        .map(|(i, x)| x * cos[i % cos.len()])
+        .collect::<Vec<_>>();
+    let r2 = x1_sin
+        .iter()
+        .zip(x2_cos.iter())
+        .map(|(x1, x2)| x1 + x2)
+        .collect::<Vec<_>>();
 
-    let skip = head_dim.abs_diff(dim);
-    let mut dst = merge(&intermediate, half_dim, skip);
+    println!("R1: {:?}", r1);
+    println!("R2: {:?}", r2);
 
-    if dim < head_dim {
-        let offset = (el_count / head_dim) * dim;
-        let appendix = &mut src[offset..].to_vec();
-        dst.append(appendix);
-    }
-    dst
+    vec![]
 }
 
 fn rope_2(src: Vec<f32>, shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec<f32> {
diff --git a/crates/ratchet-core/src/cpu/slice.rs b/crates/ratchet-core/src/cpu/slice.rs
new file mode 100644
index 00000000..0b1c2b4e
--- /dev/null
+++ b/crates/ratchet-core/src/cpu/slice.rs
@@ -0,0 +1,3 @@
+use crate::{Slice, Tensor};
+
+pub fn cpu_slice(op: Slice, dst: Tensor) -> Result<Tensor, OperationError> {}
diff --git a/crates/ratchet-core/src/op.rs b/crates/ratchet-core/src/op.rs
index 10bb5ab4..d3598fa4 100644
--- a/crates/ratchet-core/src/op.rs
+++ b/crates/ratchet-core/src/op.rs
@@ -363,5 +363,5 @@ pub trait GPUOperation: Operation {
 }
 
 pub trait CPUOperation: Operation {
-    fn apply(&self, dst: Tensor) -> Result<Tensor, OperationError>;
+    fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError>;
 }
diff --git a/crates/ratchet-core/src/tensor.rs b/crates/ratchet-core/src/tensor.rs
index 5b8c108e..99cc5917 100644
--- a/crates/ratchet-core/src/tensor.rs
+++ b/crates/ratchet-core/src/tensor.rs
@@ -759,7 +759,7 @@ impl Tensor {
         match self.op().clone() {
             LazyOp::Binary(b) => cpu_binary(b, dst).ok(),
             LazyOp::Cast(c) => cpu_cast(c, dst).ok(),
-            LazyOp::Matmul(m) => m.apply(dst).ok(),
+            LazyOp::Matmul(m) => m.apply_cpu(dst).ok(),
             LazyOp::Softmax(_s) => todo!(),
             LazyOp::RoPE(r) => cpu_rope(r, dst).ok(),
             LazyOp::Unary(u) => cpu_unary(u, dst).ok(),