def _auto_layout_compile(
    self, engine, params, prefill_buckets=(64, 128, 256, 512, 1024)
):
  """Ahead-of-time compiles generate, prefill and insert for one engine.

  The generate executable comes from `engine.aot_compile`. For every padded
  prompt length in `prefill_buckets`, `engine.prefill_aot` is jitted with
  auto-selected output layouts and compiled, and `engine.insert` is then
  compiled so that its first input uses the layout the compiled prefill
  produces.

  Args:
    engine: Engine exposing `aot_compile`, `prefill_aot`, `insert`,
      `param_layouts`, `decode_state_layouts` and `decode_state_shapes`.
    params: Parameters handed to `engine.aot_compile`; the params it returns
      are the ones used for all later lowering and are returned to the caller.
    prefill_buckets: Iterable of padded prompt lengths to compile prefill and
      insert executables for. Defaults to the previously hard-coded buckets.

  Returns:
    A tuple `(engine, params, generate_executable, cached_prefill,
    cached_insert, decode_state)`, where `cached_prefill` and `cached_insert`
    are dicts mapping each bucket length to its compiled executable and
    `decode_state` is the result of calling the decode-state executable
    returned by `engine.aot_compile` with `None`.
  """
  logger.debug("Compiling generate function")
  generate_executable, params, decode_state_executable = engine.aot_compile(
      params, pass_rng_shape=False
  )
  decode_state = decode_state_executable(None)

  # Shape-only placeholder for the scalar integer argument of prefill and
  # insert. It does not depend on the bucket, so build it once.
  i32_scalar = jax.ShapeDtypeStruct((), int)

  cached_prefill = {}
  cached_insert = {}
  for length in prefill_buckets:
    logger.debug("Compiling prefill: %d", length)
    input_data = jax.ShapeDtypeStruct((length,), jax.numpy.dtype("int32"))

    cached_prefill[length] = (
        jax.jit(
            engine.prefill_aot,
            in_shardings=(engine.param_layouts, None, None),
            out_shardings=(Layout(DLL.AUTO), Layout(DLL.AUTO)),
        )
        .lower(params, input_data, i32_scalar)
        .compile(compiler_options=None)
    )

    logger.debug("Generate dummy prefix: %d", length)
    # eval_shape only reads shape/dtype, so the abstract `input_data` is
    # enough; no dummy token array has to be allocated on device.
    prefix_shapes = jax.eval_shape(engine.prefill_aot, params, input_data, 1)

    logger.debug("Compiling insert: %d", length)
    # Insert must consume the prefix in whatever layout the compiler picked
    # for the first prefill output.
    prefill_output_layout, _ = cached_prefill[length].output_layouts
    logger.debug("Prefill output layout: %s", prefill_output_layout)
    logger.debug("Prefix shapes: %s", prefix_shapes)
    cached_insert[length] = (
        jax.jit(
            engine.insert,
            in_shardings=(
                prefill_output_layout,
                engine.decode_state_layouts,
                None,
            ),
            out_shardings=engine.decode_state_layouts,
            donate_argnames=("decode_state",),
        )
        .lower(prefix_shapes[0], engine.decode_state_shapes, i32_scalar)
        .compile(compiler_options=None)
    )

  return (
      engine,
      params,
      generate_executable,
      cached_prefill,
      cached_insert,
      decode_state,
  )
self._cached_prefill[idx] = prefill_fn + while self.live: my_transfer_backlog = self._transfer_backlogs[idx] # The prefill thread can just sleep until it has work to do. @@ -759,10 +842,11 @@ def _prefill_thread(self, idx: int): ) else: # Compute new kv cache for the prefill_content. - prefill_result, first_token = prefill_engine.prefill( - params=final_prefill_params, - padded_tokens=padded_tokens, - true_length=true_length, + assert padded_tokens.shape[0] in self._cached_prefill[idx] + prefill_result, first_token = self._cached_prefill[idx][padded_tokens.shape[0]]( + final_prefill_params, + padded_tokens, + true_length, ) request.complete = np.zeros( @@ -967,10 +1051,11 @@ def _insert_if_possible( else: break - decode_state = generate_engine.insert( + length = new_request.prefill_result['cache']['decoder']['layers_0']['self_attention']['KVCache_0']['cache_prefill_segment_id'].value.shape[1] + decode_state = self._cached_insert[idx][length]( new_request.prefill_result, decode_state, - slot=slot, + slot, # request_id=new_request.request_id, ) ThreadDebugLog( @@ -1115,9 +1200,17 @@ def _generate_thread(self, idx: int): # Keep track of what step tokens were generated at. generate_timestep = 0 # State to store things like running kv cache in. - decode_state = generate_engine.init_decode_state() - generate_params = self._generate_params[idx] + + if not self._interleaved_mode: + generate_engine, generate_params, gen_fn, prefill_fn, insert_fn, decode_state = self._auto_layout_compile( + generate_engine, generate_params + ) + self._generate_executables[idx] = gen_fn + self._decode_states[idx] = decode_state + + decode_state = self._decode_states[idx] + thread_name = f"Generate thread {idx}" ThreadDebugLog(thread_name, f"Generate params {idx} loaded.") time_of_last_generate = time.time() @@ -1178,8 +1271,8 @@ def _generate_thread(self, idx: int): ), "At this point we must have some requests inserted into the slots." 
# Now we actually take a generate step on requests in the slots. - decode_state, sampled_tokens = generate_engine.generate( - generate_params, decode_state + decode_state, sampled_tokens = self._generate_executables[idx]( + generate_params, decode_state, None ) sampled_tokens.copy_to_host_async() # Respond to detokenization backpressure.