Skip to content

Commit

Permalink
refactor: save/load optimizer state (#190)
Browse files Browse the repository at this point in the history
* add optimizer state functions

* added OptimizerState trait

* impl save/load safetensors

* impl unflatten

* test saving and loading optimizer states

* clippy & fmt
  • Loading branch information
minghuaw authored Feb 3, 2025
1 parent a410938 commit 99e345d
Show file tree
Hide file tree
Showing 13 changed files with 415 additions and 21 deletions.
40 changes: 40 additions & 0 deletions mlx-rs/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,21 @@ pub enum IoError {
#[error(transparent)]
NulError(#[from] NulError),

/// Error with unflattening the loaded optimizer state
#[error(transparent)]
Unflatten(#[from] UnflattenError),

/// Exception
#[error(transparent)]
Exception(#[from] Exception),
}

/// Conversion from the uninhabited [`Infallible`] type.
///
/// No value of `Infallible` can ever exist, so this function can never actually be called;
/// the impl only exists so that `Infallible` can be used where a type convertible into
/// `IoError` is required.
impl From<Infallible> for IoError {
    fn from(never: Infallible) -> Self {
        // `Infallible` has no variants, so an empty match is exhaustive and has type `!`.
        match never {}
    }
}

impl From<RawException> for IoError {
#[track_caller]
fn from(e: RawException) -> Self {
Expand Down Expand Up @@ -87,6 +97,36 @@ pub enum AsSliceError {
Exception(#[from] Exception),
}

/// Error with unflattening a loaded optimizer state
///
/// Returned when a flat sequence of `(key, array)` pairs cannot be turned back into a
/// structured optimizer state.
#[derive(Debug, PartialEq, Error)]
pub enum UnflattenError {
/// Expecting next (key, value) pair, found none
///
/// NOTE(review): presumably raised when the input runs out before every expected entry of
/// a state has been read -- confirm against the `unflatten` implementations that use it.
#[error("Expecting next (key, value) pair, found none")]
ExpectingNextPair,

/// The key is not in a valid format
///
/// For example, the Adafactor state returns this when a key has no `.`-separated field
/// suffix, or when the suffix is not one of its known field names.
#[error("Invalid key")]
InvalidKey,
}

/// Error with loading an optimizer state
///
/// Wraps the two error sources of a load: an [`IoError`] or an [`UnflattenError`].
/// Both variants are `#[error(transparent)]`, so the message shown is the one of the
/// wrapped error, and `#[from]` lets either error be converted into this type with `?`.
#[derive(Debug, PartialEq, Error)]
pub enum OptimizerStateLoadError {
/// Error with io operations
#[error(transparent)]
Io(#[from] IoError),

/// Error with unflattening the optimizer state
#[error(transparent)]
Unflatten(#[from] UnflattenError),
}

/// Conversion from the uninhabited [`Infallible`] type.
///
/// No value of `Infallible` can ever exist, so this function can never actually be called;
/// the impl only exists so that `Infallible` can be used where a type convertible into
/// `OptimizerStateLoadError` is required.
impl From<Infallible> for OptimizerStateLoadError {
    fn from(never: Infallible) -> Self {
        // `Infallible` has no variants, so an empty match is exhaustive and has type `!`.
        match never {}
    }
}

cfg_safetensors! {
/// Error associated with conversion between `safetensors::tensor::TensorView` and `Array`
/// when the data type is not supported.
Expand Down
14 changes: 12 additions & 2 deletions mlx-rs/src/optimizers/adadelta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ generate_builder! {

/// Inner state
#[builder(ignore)]
pub state: OptimizerState<(Array, Array)>,
pub state: State<(Array, Array)>,
}
}

Expand All @@ -63,7 +63,7 @@ fn build_adadelta(builder: AdaDeltaBuilder) -> Result<AdaDelta, AdaDeltaBuildErr
lr: array!(builder.lr),
rho: array!(rho),
eps: array!(eps),
state: OptimizerState::new(),
state: State::new(),
})
}

Expand All @@ -76,6 +76,16 @@ impl AdaDelta {
}

impl Optimizer for AdaDelta {
type State = State<(Array, Array)>;

/// Returns a shared reference to the optimizer's inner `state` field.
fn state(&self) -> &Self::State {
&self.state
}

/// Returns a mutable reference to the optimizer's inner `state` field.
fn state_mut(&mut self) -> &mut Self::State {
&mut self.state
}

fn update_single(
&mut self,
key: &Rc<str>,
Expand Down
105 changes: 103 additions & 2 deletions mlx-rs/src/optimizers/adafactor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,97 @@ pub struct AdafactorState {
pub(crate) exp_avg: Option<Array>,
}

/// Flattened view of the per-parameter Adafactor state.
///
/// Every entry of the state map is expanded into one `"<parameter key>.<field>"` pair per
/// populated field, where `<field>` is one of `step`, `exp_avg_sq_row`, `exp_avg_sq_col`,
/// `exp_avg_sq` or `exp_avg`. Optional fields that are `None` are omitted.
impl OptimizerState for State<AdafactorState> {
    type UnflattenError = UnflattenError;

    /// Yields `("<key>.<field>", &array)` for every populated field of every entry.
    ///
    /// `step` is always emitted; the optional accumulators are emitted only when `Some`.
    fn flatten(&self) -> impl Iterator<Item = (Rc<str>, &Array)> {
        self.iter().flat_map(|(k, v)| {
            let mut iter = vec![(Rc::from(format!("{}.step", k)), &v.step)];

            if let Some(exp_avg_sq_row) = &v.exp_avg_sq_row {
                iter.push((Rc::from(format!("{}.exp_avg_sq_row", k)), exp_avg_sq_row));
            }

            if let Some(exp_avg_sq_col) = &v.exp_avg_sq_col {
                iter.push((Rc::from(format!("{}.exp_avg_sq_col", k)), exp_avg_sq_col));
            }

            if let Some(exp_avg_sq) = &v.exp_avg_sq {
                iter.push((Rc::from(format!("{}.exp_avg_sq", k)), exp_avg_sq));
            }

            if let Some(exp_avg) = &v.exp_avg {
                iter.push((Rc::from(format!("{}.exp_avg", k)), exp_avg));
            }

            iter
        })
    }

    /// Mutable counterpart of [`flatten`](Self::flatten): same keys, `&mut Array` values.
    fn flatten_mut(&mut self) -> impl Iterator<Item = (Rc<str>, &mut Array)> {
        self.iter_mut().flat_map(|(k, v)| {
            let mut iter = vec![(Rc::from(format!("{}.step", k)), &mut v.step)];

            if let Some(exp_avg_sq_row) = &mut v.exp_avg_sq_row {
                iter.push((Rc::from(format!("{}.exp_avg_sq_row", k)), exp_avg_sq_row));
            }

            if let Some(exp_avg_sq_col) = &mut v.exp_avg_sq_col {
                iter.push((Rc::from(format!("{}.exp_avg_sq_col", k)), exp_avg_sq_col));
            }

            if let Some(exp_avg_sq) = &mut v.exp_avg_sq {
                iter.push((Rc::from(format!("{}.exp_avg_sq", k)), exp_avg_sq));
            }

            if let Some(exp_avg) = &mut v.exp_avg {
                iter.push((Rc::from(format!("{}.exp_avg", k)), exp_avg));
            }

            iter
        })
    }

    /// Rebuilds the state map from `("<key>.<field>", array)` pairs as produced by
    /// [`flatten`](Self::flatten).
    ///
    /// Pairs are processed in ascending key order. An entry is created the first time its
    /// parameter key is seen, with `step` set to [`AdafactorState::DEFAULT_STEP`] and all
    /// optional fields `None`; each pair then fills in the field named by its suffix.
    ///
    /// # Errors
    ///
    /// Returns [`UnflattenError::InvalidKey`] if a key contains no `.` separator or if the
    /// part after the last `.` is not a known field name.
    fn unflatten<I, K>(input: I) -> Result<Self, Self::UnflattenError>
    where
        Self: Sized,
        I: IntoIterator<Item = (K, Array)>,
        K: Ord + AsRef<str> + Into<Rc<str>>,
    {
        let mut state = State::new();
        let iter = input
            .into_iter()
            .sorted_by(|a, b| a.0.as_ref().cmp(b.0.as_ref()));

        for (k, v) in iter {
            let key: Rc<str> = k.into();

            // Split on the LAST '.' only. The parameter key itself may contain dots
            // (e.g. "layers.0.weight"), and `flatten` appends exactly one ".<field>"
            // segment, so everything before the last dot is the parameter key.
            let (prefix, suffix) = key.rsplit_once('.').ok_or(UnflattenError::InvalidKey)?;

            let entry = state
                .entry(Rc::<str>::from(prefix))
                .or_insert_with(|| AdafactorState {
                    step: array!(AdafactorState::DEFAULT_STEP),
                    exp_avg_sq_row: None,
                    exp_avg_sq_col: None,
                    exp_avg_sq: None,
                    exp_avg: None,
                });

            match suffix {
                "step" => entry.step = v,
                "exp_avg_sq_row" => entry.exp_avg_sq_row = Some(v),
                "exp_avg_sq_col" => entry.exp_avg_sq_col = Some(v),
                "exp_avg_sq" => entry.exp_avg_sq = Some(v),
                "exp_avg" => entry.exp_avg = Some(v),
                _ => return Err(UnflattenError::InvalidKey),
            }
        }

        Ok(state)
    }
}

impl AdafactorState {
/// Default value for `step`
pub const DEFAULT_STEP: i32 = 0;
Expand Down Expand Up @@ -148,7 +239,7 @@ generate_builder! {

/// Inner state.
#[builder(ignore)]
pub state: OptimizerState<AdafactorState>,
pub state: State<AdafactorState>,
}
}

Expand Down Expand Up @@ -176,7 +267,7 @@ fn build_adafactor(builder: AdafactorBuilder) -> Result<Adafactor, AdafactorBuil
scale_parameter,
relative_step,
warmup_init,
state: OptimizerState::new(),
state: State::new(),
})
}

Expand Down Expand Up @@ -253,6 +344,16 @@ fn compute_lr(
}

impl Optimizer for Adafactor {
type State = State<AdafactorState>;

/// Returns a shared reference to the optimizer's inner `state` field.
fn state(&self) -> &Self::State {
&self.state
}

/// Returns a mutable reference to the optimizer's inner `state` field.
fn state_mut(&mut self) -> &mut Self::State {
&mut self.state
}

fn update_single(
&mut self,
key: &std::rc::Rc<str>,
Expand Down
14 changes: 12 additions & 2 deletions mlx-rs/src/optimizers/adagrad.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ generate_builder! {

/// Inner state
#[builder(ignore)]
pub state: OptimizerState,
pub state: State,
}
}

Expand All @@ -43,7 +43,7 @@ fn build_adagrad(builder: AdaGradBuilder) -> Result<AdaGrad, Infallible> {
Ok(AdaGrad {
lr: array!(builder.lr),
eps,
state: OptimizerState::new(),
state: State::new(),
})
}

Expand All @@ -53,6 +53,16 @@ impl AdaGrad {
}

impl Optimizer for AdaGrad {
type State = State;

/// Returns a shared reference to the optimizer's inner `state` field.
fn state(&self) -> &Self::State {
&self.state
}

/// Returns a mutable reference to the optimizer's inner `state` field.
fn state_mut(&mut self) -> &mut Self::State {
&mut self.state
}

fn update_single(
&mut self,
key: &Rc<str>,
Expand Down
14 changes: 12 additions & 2 deletions mlx-rs/src/optimizers/adam.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ generate_builder! {

/// Inner state
#[builder(ignore)]
pub state: OptimizerState<(Array, Array)>,
pub state: State<(Array, Array)>,
}
}

Expand All @@ -55,7 +55,7 @@ fn build_adam(builder: AdamBuilder) -> Result<Adam, Infallible> {
lr,
betas: (array!(betas.0), array!(betas.1)),
eps,
state: OptimizerState::new(),
state: State::new(),
})
}

Expand All @@ -68,6 +68,16 @@ impl Adam {
}

impl Optimizer for Adam {
type State = State<(Array, Array)>;

/// Returns a shared reference to the optimizer's inner `state` field.
fn state(&self) -> &Self::State {
&self.state
}

/// Returns a mutable reference to the optimizer's inner `state` field.
fn state_mut(&mut self) -> &mut Self::State {
&mut self.state
}

fn update_single(
&mut self,
key: &Rc<str>,
Expand Down
14 changes: 12 additions & 2 deletions mlx-rs/src/optimizers/adamax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ generate_builder! {

/// Inner state.
#[builder(ignore)]
pub state: OptimizerState<(Array, Array)>,
pub state: State<(Array, Array)>,
}
}

Expand All @@ -52,7 +52,7 @@ fn build_adamax(builder: AdamaxBuilder) -> Result<Adamax, Infallible> {
lr: array!(lr),
betas: (array!(betas.0), array!(betas.1)),
eps: array!(eps),
state: OptimizerState::new(),
state: State::new(),
})
}

Expand All @@ -65,6 +65,16 @@ impl Adamax {
}

impl Optimizer for Adamax {
type State = State<(Array, Array)>;

/// Returns a shared reference to the optimizer's inner `state` field.
fn state(&self) -> &Self::State {
&self.state
}

/// Returns a mutable reference to the optimizer's inner `state` field.
fn state_mut(&mut self) -> &mut Self::State {
&mut self.state
}

fn update_single(
&mut self,
key: &Rc<str>,
Expand Down
14 changes: 12 additions & 2 deletions mlx-rs/src/optimizers/adamw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ generate_builder! {

/// Inner state.
#[builder(ignore)]
pub state: OptimizerState<(Array, Array)>,
pub state: State<(Array, Array)>,
}
}

Expand All @@ -65,7 +65,7 @@ fn build_adamw(builder: AdamWBuilder) -> Result<AdamW, Infallible> {
betas: (array!(betas.0), array!(betas.1)),
eps: array!(eps),
weight_decay: array!(weight_decay),
state: OptimizerState::new(),
state: State::new(),
})
}

Expand All @@ -81,6 +81,16 @@ impl AdamW {
}

impl Optimizer for AdamW {
type State = State<(Array, Array)>;

/// Returns a shared reference to the optimizer's inner `state` field.
fn state(&self) -> &Self::State {
&self.state
}

/// Returns a mutable reference to the optimizer's inner `state` field.
fn state_mut(&mut self) -> &mut Self::State {
&mut self.state
}

fn update_single(
&mut self,
key: &std::rc::Rc<str>,
Expand Down
Loading

0 comments on commit 99e345d

Please sign in to comment.