Skip to content

Commit

Permalink
chore: theta matches
Browse files (browse the repository at this point in the history)
  • Loading branch information
FL33TW00D committed Oct 2, 2024
1 parent 44bc1ec commit 81f4bfc
Showing 1 changed file with 8 additions and 24 deletions.
32 changes: 8 additions & 24 deletions crates/ratchet-core/src/cpu/rope.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -20,28 +20,19 @@ pub fn cpu_rope(op: RoPE, dst: Tensor) -> Result<Tensor, OperationError> {
Ok(dst)
}

fn calculate_sincos(dim: usize, seq_len: usize, base: f32, offset: usize) -> (Vec<f32>, Vec<f32>) {
fn compute_theta(dim: usize, seq_len: usize, base: f32, offset: usize) -> Vec<f32> {
let half_dim = dim / 2;
println!("Half dim: {}", half_dim);

let positions = (offset..seq_len + offset)
.map(|x| x as f32)
.collect::<Vec<f32>>();

println!("Positions: {:?}", positions);

let log_base = base.ln();

println!("Log base: {}", log_base);

let inv_freqs = (0..half_dim)
.map(|i| -(i as f32))
.map(|i| i * log_base / half_dim as f32)
.map(|i| i * base.ln() / half_dim as f32)
.map(f32::exp)
.collect::<Vec<f32>>();

println!("Inverse Frequencies: {:?}", inv_freqs);

let p_shape = shape!(seq_len, 1);
let p_strides = Strides::from(&p_shape);
let i_shape = shape!(1, half_dim);
Expand All @@ -62,10 +53,7 @@ fn calculate_sincos(dim: usize, seq_len: usize, base: f32, offset: usize) -> (Ve
)
.unwrap();

println!("THETA: {:?}", theta);

let (sin_theta, cos_theta) = theta.iter().map(|i| i.sin_cos()).unzip();
(sin_theta, cos_theta)
theta
}

#[inline]
Expand Down Expand Up @@ -190,7 +178,8 @@ fn rope(src: Vec<f32>, shape: &Shape, dim: usize, base: f32, offset: usize) -> V
let el_count = batches * num_heads * seq_len * head_dim;

let half_dim = dim / 2;
let (sin, cos) = calculate_sincos(dim, seq_len, base, offset);
let theta = compute_theta(dim, seq_len, base, offset);
let (sin, cos): (Vec<f32>, Vec<f32>) = theta.iter().map(|i| i.sin_cos()).unzip();

let mut intermediate = Vec::with_capacity(el_count);

Expand Down Expand Up @@ -219,8 +208,6 @@ fn rope(src: Vec<f32>, shape: &Shape, dim: usize, base: f32, offset: usize) -> V
intermediate.push(x1_sin + x2_cos);
});

let out_shape = shape!(batches, num_heads, seq_len, head_dim);

let skip = head_dim.abs_diff(dim);
let mut dst = merge(&intermediate, half_dim, skip);

Expand All @@ -237,15 +224,14 @@ fn rope_2(src: Vec<f32>, shape: &Shape, dim: usize, base: f32, offset: usize) ->
let [batches, num_heads, seq_len, head_dim] = shape.try_into().unwrap();
let el_count = batches * num_heads * seq_len * head_dim;

let half_dim = dim / 2;
let (sin, cos) = calculate_sincos(dim, seq_len, base, offset);
let theta = compute_theta(dim, seq_len, base, offset);
let (sin, cos): (Vec<f32>, Vec<f32>) = theta.iter().map(|i| i.sin_cos()).unzip();

println!("cos: {:?}", cos);
println!("sin: {:?}", sin);

let src = transpose(src, &shape, 1, 2).unwrap();
let mut dst = vec![0.0; el_count];
let b = batches;
let t = num_heads;
let h = seq_len;
let d = head_dim;
Expand All @@ -265,7 +251,5 @@ fn rope_2(src: Vec<f32>, shape: &Shape, dim: usize, base: f32, offset: usize) ->
}
});

let dst = transpose(dst, &shape, 1, 2).unwrap();

dst
transpose(dst, &shape, 1, 2).unwrap()
}

0 comments on commit 81f4bfc

Please sign in to comment.