Add eager flag and set compile to be always true (#734)
ajtejankar authored Jan 30, 2025
1 parent 40dc6d3 commit 70d7936
Showing 13 changed files with 42 additions and 11 deletions.
18 changes: 17 additions & 1 deletion launcher/src/main.rs
@@ -342,9 +342,18 @@ struct Args {

/// Whether you want to compile the model into a CUDA graph.
/// This will speed up decoding but increase GPU memory usage.
/// Only use either `--compile` or `--eager`. Using both at the same time will
/// result in an error.
#[clap(long, env, value_enum)]
compile: bool,

/// Whether you want to run the model in eager mode, without
/// CUDA graph compilation.
/// Only use either `--compile` or `--eager`. Using both at the same time will
/// result in an error.
#[clap(long, env, value_enum)]
eager: bool,

// The maximum batch size past which CUDA graphs are disabled.
#[clap(default_value = "128", long, env)]
compile_max_batch_size: usize,
@@ -656,6 +665,7 @@ fn shard_manager(
adapter_source: String,
quantize: Option<Quantization>,
compile: bool,
eager: bool,
compile_max_batch_size: usize,
compile_max_rank: usize,
compile_batch_size: usize,
@@ -738,10 +748,14 @@ fn shard_manager(
}

// CUDA graph compilation
if compile {
if !eager {
shard_args.push("--compile".to_string());
}

if compile && eager {
panic!("Cannot use both --compile and --eager at the same time.");
}

// Speculative decoding
if let Some(speculative_tokens) = speculative_tokens {
shard_args.push("--speculative-tokens".to_string());
@@ -1310,6 +1324,7 @@ fn spawn_shards(
let otlp_endpoint = args.otlp_endpoint.clone();
let quantize = args.quantize;
let compile = args.compile;
let eager = args.eager;
let compile_max_batch_size = args.compile_max_batch_size;
let compile_max_rank = args.compile_max_rank;
let compile_batch_size = args.compile_batch_size;
@@ -1342,6 +1357,7 @@ fn spawn_shards(
adapter_source,
quantize,
compile,
eager,
compile_max_batch_size,
compile_max_rank,
compile_batch_size,
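
For orientation, here is a minimal sketch, in Python for illustration only (the launcher itself is Rust), of the flag semantics the launcher changes above appear to encode: CUDA graph compilation is passed to the shards unless `--eager` is given, and combining `--compile` with `--eager` is an error. The helper name `resolve_compile_flag` is hypothetical and not part of the codebase.

```python
def resolve_compile_flag(compile_requested: bool, eager: bool) -> bool:
    """Return True if the shard should be started with --compile.

    Mirrors the launcher logic above: compilation is on unless --eager is set,
    and combining --compile with --eager is rejected.
    """
    if compile_requested and eager:
        raise ValueError("Cannot use both --compile and --eager at the same time.")
    return not eager


# Illustrative outcomes:
assert resolve_compile_flag(compile_requested=False, eager=False)      # default: compile
assert resolve_compile_flag(compile_requested=True, eager=False)       # explicit --compile
assert not resolve_compile_flag(compile_requested=False, eager=True)   # --eager disables it
```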
3 changes: 2 additions & 1 deletion server/lorax_server/models/bloom.py
@@ -2,6 +2,7 @@

import torch
import torch.distributed
from loguru import logger
from transformers import (
AutoConfig,
AutoTokenizer,
@@ -66,7 +67,7 @@ def __init__(
trust_remote_code: bool = False,
):
if compile:
raise ValueError("`--compile` is not supported with Bloom")
logger.info(f"Model {model_id} does not support CUDA graph compilation. Skipping compilation.")

self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
3 changes: 2 additions & 1 deletion server/lorax_server/models/causal_lm.py
@@ -2,6 +2,7 @@
from typing import Dict, List, Optional, Tuple, Type

import torch
from loguru import logger
from opentelemetry import trace
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase

@@ -499,7 +500,7 @@ def __init__(
trust_remote_code: bool = False,
):
if compile:
raise ValueError("`--compile` is not supported with CausalLM")
logger.info(f"Model {model_id} does not support CUDA graph compilation. Skipping compilation.")

if torch.cuda.is_available():
device = torch.device("cuda")
1 change: 1 addition & 0 deletions server/lorax_server/models/flash_causal_lm.py
@@ -1182,6 +1182,7 @@ def __init__(
SLIDING_WINDOW_BLOCKS = math.ceil(sliding_window / BLOCK_SIZE)

self.compile = compile

self.model_graph_wrapper: GraphCache = None
self.kv_cache = []

3 changes: 2 additions & 1 deletion server/lorax_server/models/galactica.py
@@ -3,6 +3,7 @@

import torch
import torch.distributed
from loguru import logger
from transformers import (
AutoConfig,
AutoTokenizer,
@@ -161,7 +162,7 @@ def __init__(
trust_remote_code: bool = False,
):
if compile:
raise ValueError("`--compile` is not supported with GalacticaSharded")
logger.info(f"Model {model_id} does not support CUDA graph compilation. Skipping compilation.")

self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
3 changes: 2 additions & 1 deletion server/lorax_server/models/gpt_neox.py
@@ -2,6 +2,7 @@

import torch
import torch.distributed
from loguru import logger
from transformers import (
AutoConfig,
AutoTokenizer,
@@ -29,7 +30,7 @@ def __init__(
trust_remote_code: bool = False,
):
if compile:
raise ValueError("`--compile` is not supported with GPT-NeoX")
logger.info(f"Model {model_id} does not support CUDA graph compilation. Skipping compilation.")

self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
4 changes: 4 additions & 0 deletions server/lorax_server/models/model.py
@@ -228,6 +228,10 @@ def check_initialized(self):
f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}"
)

@property
def supports_cuda_graph_compilation(self) -> bool:
return True

@property
def supports_adapter_loading(self) -> bool:
return False
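
The new `supports_cuda_graph_compilation` property defaults to `True` on the base `Model` class. Below is a minimal, self-contained sketch of how a subclass might opt out and how calling code could consult the property; the class and helper names are illustrative and not taken from the repository.

```python
from loguru import logger


class BaseModel:
    # Default mirrors the property added to lorax_server.models.model.Model above.
    @property
    def supports_cuda_graph_compilation(self) -> bool:
        return True


class HypotheticalNonFlashModel(BaseModel):
    # Example of a model that opts out of CUDA graph capture.
    @property
    def supports_cuda_graph_compilation(self) -> bool:
        return False


def maybe_enable_cuda_graphs(model: BaseModel, compile_requested: bool) -> bool:
    """Illustrative guard: enable CUDA graphs only if requested and supported."""
    if compile_requested and not model.supports_cuda_graph_compilation:
        logger.info("Model does not support CUDA graph compilation. Skipping compilation.")
        return False
    return compile_requested
```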
3 changes: 2 additions & 1 deletion server/lorax_server/models/mpt.py
@@ -5,6 +5,7 @@
import torch
import torch.distributed
from huggingface_hub import hf_hub_download
from loguru import logger
from opentelemetry import trace
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase

@@ -58,7 +59,7 @@ def __init__(
trust_remote_code: bool = False,
):
if compile:
raise ValueError("`--compile` is not supported with MPT")
logger.info(f"Model {model_id} does not support CUDA graph compilation. Skipping compilation.")

self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
3 changes: 2 additions & 1 deletion server/lorax_server/models/opt.py
@@ -2,6 +2,7 @@

import torch
import torch.distributed
from loguru import logger
from transformers import (
AutoConfig,
AutoTokenizer,
@@ -27,7 +28,7 @@ def __init__(
trust_remote_code: bool = False,
):
if compile:
raise ValueError("`--compile` is not supported with OPT")
logger.info(f"Model {model_id} does not support CUDA graph compilation. Skipping compilation.")

self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
3 changes: 2 additions & 1 deletion server/lorax_server/models/rw.py
@@ -1,6 +1,7 @@
from typing import List, Optional, Tuple

import torch
from loguru import logger
from transformers import AutoModelForCausalLM, AutoTokenizer

from lorax_server.models.causal_lm import CausalLM
@@ -17,7 +18,7 @@ def __init__(
trust_remote_code: bool = False,
):
if compile:
raise ValueError("`--compile` is not supported with RW")
logger.info(f"Model {model_id} does not support CUDA graph compilation. Skipping compilation.")

if torch.cuda.is_available():
device = torch.device("cuda")
3 changes: 2 additions & 1 deletion server/lorax_server/models/santacoder.py
@@ -2,6 +2,7 @@

import torch
import torch.distributed
from loguru import logger
from transformers import AutoModelForCausalLM, AutoTokenizer

from lorax_server.models.causal_lm import CausalLM
@@ -24,7 +25,7 @@ def __init__(
trust_remote_code: bool = False,
):
if compile:
raise ValueError("`--compile` is not supported with SantaCoder")
logger.info(f"Model {model_id} does not support CUDA graph compilation. Skipping compilation.")

if torch.cuda.is_available():
device = torch.device("cuda")
3 changes: 2 additions & 1 deletion server/lorax_server/models/seq2seq_lm.py
@@ -2,6 +2,7 @@
from typing import Dict, List, Optional, Tuple, Type

import torch
from loguru import logger
from opentelemetry import trace
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, PreTrainedTokenizerBase

@@ -488,7 +489,7 @@ def __init__(
trust_remote_code: bool = False,
):
if compile:
raise ValueError("`--compile` is not supported with Seq2SeqLM")
logger.info(f"Model {model_id} does not support CUDA graph compilation. Skipping compilation.")

if torch.cuda.is_available():
device = torch.device("cuda")
3 changes: 2 additions & 1 deletion server/lorax_server/models/t5.py
@@ -2,6 +2,7 @@

import torch
import torch.distributed
from loguru import logger
from transformers import (
AutoConfig,
AutoTokenizer,
@@ -29,7 +30,7 @@ def __init__(
trust_remote_code: bool = False,
):
if compile:
raise ValueError("`--compile` is not supported with T5")
logger.info(f"Model {model_id} does not support CUDA graph compilation. Skipping compilation.")

self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
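
All of the non-flash model constructors touched by this commit follow the same pattern: a `compile=True` request is no longer a hard error but is logged and ignored. A condensed, standalone sketch of that shared pattern (the class name is illustrative):

```python
from loguru import logger


class NonGraphCompiledModel:
    """Illustrative fragment of the pattern shared by Bloom, CausalLM, OPT, T5, etc."""

    def __init__(self, model_id: str, compile: bool = False):
        if compile:
            # Before this commit: raise ValueError("`--compile` is not supported with ...")
            # After: log and continue without CUDA graph compilation.
            logger.info(
                f"Model {model_id} does not support CUDA graph compilation. Skipping compilation."
            )
        self.model_id = model_id
```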
