Skip to content

Commit

Permalink
chore: remove unused strided iterator. may be useful later
Browse files Browse the repository at this point in the history
  • Loading branch information
ivarflakstad committed Oct 28, 2024
1 parent 52863d2 commit be77442
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 110 deletions.
71 changes: 0 additions & 71 deletions crates/ratchet-core/src/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,77 +36,6 @@ pub fn apply_operation(op: LazyOp, dst: Tensor) -> Result<Tensor, OperationError
}
}

pub struct StridedIterator<'a> {
shape: &'a Shape,
strides: &'a Strides,
next_index: Option<usize>,
multi_index: Vec<usize>,
}

impl<'a> StridedIterator<'a> {
pub fn new(shape: &'a Shape, strides: &'a Strides, start_offset: usize) -> Self {
Self {
shape,
strides,
next_index: if shape.numel() == 0 {
None
} else {
Some(start_offset)
},
multi_index: vec![0; shape.len()],
}
}
}

impl<'a> Iterator for StridedIterator<'a> {
type Item = usize;

fn next(&mut self) -> Option<Self::Item> {
let storage_index = match self.next_index {
None => return None,
Some(storage_index) => storage_index,
};
let mut updated = false;
let mut next_storage_index = storage_index;
for ((multi_i, max_i), stride_i) in self
.multi_index
.iter_mut()
.zip(self.shape.iter())
.zip(self.strides.iter())
.rev()
{
let next_i = *multi_i + 1;
if next_i < *max_i {
*multi_i = next_i;
updated = true;
next_storage_index += *stride_i as usize;
break;
} else {
next_storage_index -= *multi_i * *stride_i as usize;
*multi_i = 0
}
}
self.next_index = if updated {
Some(next_storage_index)
} else {
None
};
Some(storage_index)
}
}

impl<'a> From<(&'a Shape, &'a Strides)> for StridedIterator<'a> {
fn from((shape, strides): (&'a Shape, &'a Strides)) -> Self {
StridedIterator::new(shape, strides, 0)
}
}

impl<'a> From<(&'a Shape, &'a Strides, usize)> for StridedIterator<'a> {
fn from((shape, strides, offset): (&'a Shape, &'a Strides, usize)) -> Self {
StridedIterator::new(shape, strides, offset)
}
}

pub trait CPUOperation: Operation {
fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError>;
}
Expand Down
40 changes: 1 addition & 39 deletions crates/ratchet-core/src/cpu/rope.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
concat,
cpu::{cpu_store_result, gemm::gemm},
shape, DType, OperationError, RoPE, Shape, StridedIterator, Strides, Tensor,
shape, DType, OperationError, RoPE, Shape, Strides, Tensor,
};
use anyhow::anyhow;

Expand Down Expand Up @@ -83,44 +83,6 @@ fn slice(src: &[f32], src_strides: &Strides, start: &[usize], stop: &[usize]) ->
dst
}

// Generic transpose function
fn transpose(
src: Vec<f32>,
shape: &Shape,
dim1: usize,
dim2: usize,
) -> Result<Vec<f32>, OperationError> {
let rank = shape.rank();
if dim1 == dim2 {
return Ok(src);
}
if rank <= dim1 || rank <= dim2 {
return Err(anyhow!("Invalid dimensions for transpose operation").into());
}
let mut dims = shape.to_vec();
let mut strides = Strides::from(shape).to_vec();
println!("dims: {:?}", dims);
println!("strides: {:?}", strides);
dims.swap(dim1, dim2);
strides.swap(dim1, dim2);
println!("dims: {:?}", dims);
println!("strides: {:?}", strides);

let shape_t = Shape::from(dims);
let strides_t = Strides::from(strides);

let mut result = vec![0.0; src.len()];
let strided_iter = StridedIterator::new(&shape_t, &strides_t, 0);
let strided_iter2 = StridedIterator::new(&shape_t, &strides_t, 0);
let indices = strided_iter2.collect::<Vec<_>>();
println!("indices: {:?}", indices);
for (index, dst_index) in strided_iter.enumerate() {
result[dst_index] = src[index];
}

Ok(result)
}

fn rope(src: Vec<f32>, shape: &Shape, dim: usize, base: f32, offset: usize) -> Vec<f32> {
let [batches, num_heads, seq_len, head_dim] = shape.try_into().unwrap();

Expand Down

0 comments on commit be77442

Please sign in to comment.