Fix memory calculation causing --compile to error (#756)
ajtejankar authored Feb 6, 2025
1 parent 59dc300 commit acaa217
Showing 1 changed file with 31 additions and 22 deletions.
server/lorax_server/models/flash_causal_lm.py (53 changes: 31 additions & 22 deletions)
@@ -1273,7 +1273,26 @@ def adapter_memory_size(self) -> int:
total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
return ADAPTER_MEMORY_FRACTION * total_gpu_memory


+ def init_graph_wrapper(self, max_total_tokens: int):
+ self.model_graph_wrapper = GraphCache(
+ self.model,
+ self.device,
+ self.kv_cache,
+ self.adapter_layers,
+ self.traced_adapter_layers,
+ self._forward_context,
+ max_total_tokens,
+ self.num_heads,
+ self.num_kv_heads,
+ self.sliding_window_blocks,
+ self.layer_to_lora_weights,
+ self.punica_wrapper,
+ )

def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model: bool = False):
+ logger.info(f'Pre warmup cuda memory: {get_cuda_free_memory(self.device, 1) / (1024 ** 3):.2f} GB')

# The warmup batch is the biggest batch we could ever receive
max_total_tokens = batch.max_input_length + max_new_tokens + get_speculative_tokens()

@@ -1297,6 +1316,7 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model
self.kv_dtype,
self.device,
)
+ logger.info(f'Pre warmup kv init cuda memory: {get_cuda_free_memory(self.device, 1) / (1024 ** 3):.2f} GB')

if not embedding_model:
with warmup_mode():
@@ -1326,24 +1346,17 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model
# Estimate the memory overhead from CUDA graphs so we can subtract it from the kv cache.
# Needs to be estimated here rather than fully initialized as the graph cache relies on the
# cache manager being set.
- self.model_graph_wrapper = GraphCache(
- self.model,
- self.device,
- self.kv_cache,
- self.adapter_layers,
- self.traced_adapter_layers,
- self._forward_context,
- max_total_tokens,
- self.num_heads,
- self.num_kv_heads,
- self.sliding_window_blocks,
- self.layer_to_lora_weights,
- self.punica_wrapper,
- )
+ self.init_graph_wrapper(max_total_tokens)
graph_cache_memory = self.model_graph_wrapper.get_estimated_cache_memory()
logger.info("Estimated graph cache memory: {} MB", graph_cache_memory / 1024 / 1024)
torch.cuda.synchronize(self.device)

+ logger.info(f'Post warmup cuda memory: {get_cuda_free_memory(self.device, 1) / (1024 ** 3):.2f} GB')
+ del self.model_graph_wrapper
+ self.kv_cache = []
+ torch.cuda.synchronize(self.device)
+ torch.cuda.empty_cache()
+ logger.info(f'Post warmup empty_cache cuda memory: {get_cuda_free_memory(self.device, 1) / (1024 ** 3):.2f} GB')
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
# Calculate the number of blocks that can be allocated with the free memory
dtype_size = torch.tensor([], dtype=self.dtype).element_size()
@@ -1358,13 +1371,8 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model
free_memory = max(0, free_memory - graph_cache_memory)
logger.info("Memory remaining for kv cache: {} MB", free_memory / 1024 / 1024)

- batch_num_blocks = batch.num_blocks if batch is not None else 0
- num_blocks = (
- # Leave 5% for some wiggle room
- int((free_memory * MEMORY_WIGGLE_ROOM) // total_cache_size)
- # Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
- + batch_num_blocks
- )
+ num_blocks = int((free_memory * MEMORY_WIGGLE_ROOM) // total_cache_size)
+ logger.info(f"num kv blocks: {num_blocks}, num kv tokens: {num_blocks * BLOCK_SIZE}")

del batch

@@ -1379,7 +1387,8 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model

torch.cuda.synchronize(self.device)

- if self.model_graph_wrapper is not None:
+ if self.compile:
+ self.init_graph_wrapper(max_total_tokens)
# Warmup the graph cache. Needs to be done after setting cache manager as
# tracing will use the static kv cache tensors
self.model_graph_wrapper.kv_cache = self.kv_cache
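For readers tracing the memory accounting, below is a minimal, self-contained sketch of the block-count arithmetic the new code performs: free GPU memory is measured only after the warmup kv cache and the estimation-only graph wrapper have been released, the estimated CUDA graph cache overhead is subtracted, and the remainder is divided by the per-block kv cache footprint. The constants, the torch.cuda.mem_get_info call (standing in for lorax's get_cuda_free_memory helper), and the per-block size formula are illustrative assumptions, not the repository's exact implementation.

import torch

# Illustrative stand-ins for the constants referenced in the diff; the values
# lorax uses may differ.
BLOCK_SIZE = 16
MEMORY_WIGGLE_ROOM = 0.95  # leave ~5% of free memory as headroom

def estimate_num_kv_blocks(
    device: torch.device,
    dtype: torch.dtype,
    num_layers: int,
    num_kv_heads: int,
    head_size: int,
    graph_cache_memory: int = 0,
) -> int:
    """Approximate the post-fix calculation: measure free memory after the
    warmup allocations are released, reserve the estimated CUDA graph cache
    overhead, and convert what is left into kv cache blocks."""
    dtype_size = torch.tensor([], dtype=dtype).element_size()
    # One block holds BLOCK_SIZE tokens per layer, with a key and a value
    # tensor each (hence the factor of 2). This follows the usual TGI/vLLM
    # layout; the exact expression in lorax may differ.
    cache_block_size = BLOCK_SIZE * num_kv_heads * head_size
    total_cache_size = num_layers * cache_block_size * 2 * dtype_size

    free_memory, _ = torch.cuda.mem_get_info(device)
    free_memory = max(0, free_memory - graph_cache_memory)

    return int((free_memory * MEMORY_WIGGLE_ROOM) // total_cache_size)

Because the warmup kv cache and the estimation-only graph wrapper are now freed before this measurement, the old "+ batch_num_blocks" compensation, which the removed comment says accounted for blocks already allocated by the warmup batch, is dropped in the diff above.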
