From 52710ea7f3d827daa1f0e89deec170907ecf1883 Mon Sep 17 00:00:00 2001 From: Arnav Garg <106701836+arnavgarg1@users.noreply.github.com> Date: Mon, 23 Dec 2024 04:14:29 -0800 Subject: [PATCH] Fix residual summation for the Solar architecture (#723) --- .../models/custom_modeling/flash_solar_modeling.py | 10 +++++++++- .../punica_kernels/punica_kernels/bgmv/bgmv_config.h | 1 + 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/server/lorax_server/models/custom_modeling/flash_solar_modeling.py b/server/lorax_server/models/custom_modeling/flash_solar_modeling.py index de9e1a416..b37260a07 100644 --- a/server/lorax_server/models/custom_modeling/flash_solar_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_solar_modeling.py @@ -580,6 +580,12 @@ def forward( # Note, we use index 1 instead of index 0 since index 0 is used when training is enabled bskcn_tv = self.config.bskcn_tv[1] for i, layer in enumerate(self.layers): + # Add residual to hidden states explicitly. We have to do this because the cross-layer + # residuals assume the output hidden state already have the residuals added to it, but the + # LoRAX implementation only adds the residual to the hidden states in the next layer's input_layernorm. + if residual is not None: + hidden_states = hidden_states + residual + if i in self.config.bskcn_1: bskcn_1 = hidden_states if i in self.config.bskcn_2: @@ -589,9 +595,11 @@ def forward( if i in self.config.bskcn_4: hidden_states = (bskcn_2 * bskcn_tv).to(hidden_states.device) + hidden_states * (1 - bskcn_tv) + # Note, we explicitly set residual to None here to skip adding it to the hidden states + # in the input_layernorm layer because we do this explicitly above. hidden_states, residual = layer( hidden_states, - residual, + None, # residual cos, sin, cu_seqlen_prefill, diff --git a/server/punica_kernels/punica_kernels/bgmv/bgmv_config.h b/server/punica_kernels/punica_kernels/bgmv/bgmv_config.h index d1738229b..4907165e7 100644 --- a/server/punica_kernels/punica_kernels/bgmv/bgmv_config.h +++ b/server/punica_kernels/punica_kernels/bgmv/bgmv_config.h @@ -50,6 +50,7 @@ void bgmv_kernel(T *__restrict__ Y, const T *__restrict__ X, f(T, narrow, 14336) \ f(T, narrow, 15360) \ f(T, narrow, 16384) \ + f(T, narrow, 17920) \ f(T, narrow, 18944) \ f(T, narrow, 20480) \ f(T, narrow, 22016) \