From 06eb8fc9a18dbd8841a0794826841ad170cc7926 Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Mon, 3 Jun 2024 11:55:58 -0700 Subject: [PATCH 01/42] Add tensorboard plugin dep for remote access (#97) --- docs/profiling-with-jax-profiler-and-tensorboard.md | 3 ++- jetstream/tools/maxtext/model_ckpt_conversion.sh | 4 ++-- .../tools/maxtext/model_ckpt_finetune_with_aqt.sh | 2 +- requirements.in | 4 +++- requirements.txt | 11 ++++++++++- 5 files changed, 18 insertions(+), 6 deletions(-) diff --git a/docs/profiling-with-jax-profiler-and-tensorboard.md b/docs/profiling-with-jax-profiler-and-tensorboard.md index 3727c387..8006ffb3 100644 --- a/docs/profiling-with-jax-profiler-and-tensorboard.md +++ b/docs/profiling-with-jax-profiler-and-tensorboard.md @@ -10,7 +10,8 @@ Following the [JAX official manual profiling approach](https://jax.readthedocs.i ```bash tensorboard --logdir /tmp/tensorboard/ ``` -You should be able to load TensorBoard at http://localhost:6006/. You can specify a different port with the `--port` flag. +You should be able to load TensorBoard at http://localhost:6006/. You can specify a different port with the `--port` flag. If you are running on a remote Cloud TPU VM, the `tensorboard-plugin-profile` python package enables remote access to tensorboard endpoints (JetStream deps include this package). + 2. Start JetStream MaxText server: ```bash diff --git a/jetstream/tools/maxtext/model_ckpt_conversion.sh b/jetstream/tools/maxtext/model_ckpt_conversion.sh index 19a62b74..8e2b4d83 100644 --- a/jetstream/tools/maxtext/model_ckpt_conversion.sh +++ b/jetstream/tools/maxtext/model_ckpt_conversion.sh @@ -48,7 +48,7 @@ gcloud storage buckets create ${MODEL_BUCKET} --location=${BUCKET_LOCATION} || t gcloud storage buckets create ${BASE_OUTPUT_DIRECTORY} --location=${BUCKET_LOCATION} || true gcloud storage buckets create ${DATASET_PATH} --location=${BUCKET_LOCATION} || true -# Covert model checkpoints to MaxText compatible checkpoints. +# Convert model checkpoints to MaxText compatible checkpoints. if [ "$MODEL" == "gemma" ]; then CONVERT_CKPT_SCRIPT="convert_gemma_chkpt.py" JAX_PLATFORMS=cpu python MaxText/${CONVERT_CKPT_SCRIPT} \ @@ -74,7 +74,7 @@ echo "Written MaxText compatible checkpoint to ${MODEL_BUCKET}/${MODEL}/${MODEL_ # We define `SCANNED_CKPT_PATH` to refer to the checkpoint subdirectory. export SCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}/0/items -# Covert MaxText compatible checkpoints to unscanned checkpoints. +# Convert MaxText compatible checkpoints to unscanned checkpoints. # Note that the `SCANNED_CKPT_PATH` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format. export RUN_NAME=${MODEL_NAME}_unscanned_chkpt_${idx} diff --git a/jetstream/tools/maxtext/model_ckpt_finetune_with_aqt.sh b/jetstream/tools/maxtext/model_ckpt_finetune_with_aqt.sh index 7e6ff1f5..a348bebd 100644 --- a/jetstream/tools/maxtext/model_ckpt_finetune_with_aqt.sh +++ b/jetstream/tools/maxtext/model_ckpt_finetune_with_aqt.sh @@ -66,7 +66,7 @@ checkpoint_period=100 # We will convert the `AQT_CKPT` to unscanned checkpoint in the next step. export AQT_CKPT=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/100/items -# Covert MaxText compatible AQT-fine-tuned checkpoints to unscanned checkpoints. +# Convert MaxText compatible AQT-fine-tuned checkpoints to unscanned checkpoints. # Note that the `AQT_CKPT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format. export RUN_NAME=${MODEL_NAME}_unscanned_chkpt_${idx} diff --git a/requirements.in b/requirements.in index eba423d4..459749ae 100644 --- a/requirements.in +++ b/requirements.in @@ -12,4 +12,6 @@ seqio tiktoken blobfile parameterized -shortuuid \ No newline at end of file +shortuuid +# For profiling +tensorboard-plugin-profile \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 057b4f8b..11cb5643 100644 --- a/requirements.txt +++ b/requirements.txt @@ -84,6 +84,8 @@ grpcio==1.60.1 # -r requirements.in # tensorboard # tensorflow +gviz-api==1.10.0 + # via tensorboard-plugin-profile h5py==3.10.0 # via tensorflow idna==3.7 @@ -189,6 +191,7 @@ protobuf==3.20.3 # orbax-checkpoint # seqio # tensorboard + # tensorboard-plugin-profile # tensorflow # tensorflow-hub # tensorflow-metadata @@ -244,13 +247,17 @@ six==1.16.0 # via # astunparse # google-pasta + # gviz-api # ml-collections # promise + # tensorboard-plugin-profile # tensorflow tensorboard==2.13.0 # via tensorflow tensorboard-data-server==0.7.2 # via tensorboard +tensorboard-plugin-profile==2.15.1 + # via -r requirements.in tensorflow==2.13.1 # via tensorflow-text tensorflow-estimator==2.13.0 @@ -300,7 +307,9 @@ urllib3==2.2.0 # blobfile # requests werkzeug==3.0.1 - # via tensorboard + # via + # tensorboard + # tensorboard-plugin-profile wheel==0.42.0 # via # astunparse From b1a1f6a2bb8adafc40687f15250e11468a02b4f6 Mon Sep 17 00:00:00 2001 From: Morgan Du Date: Mon, 3 Jun 2024 15:22:45 -0700 Subject: [PATCH 02/42] Update benchmark config for xlml automation (#96) * Update benchmark config for automation * merge warmup mode --- benchmarks/benchmark_serving.py | 70 +++++++++++++++++++++++---------- 1 file changed, 50 insertions(+), 20 deletions(-) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 252cc534..d68be2d2 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -81,6 +81,31 @@ from eval_accuracy import eval_accuracy +def str2bool(v: str) -> bool: + """Convert a string of truth to True or False. + + Args: + - v (str): + - True values are 'y', 'yes', 't', 'true', and '1'; + - False values are 'n', 'no', 'f', 'false', and '0'. + + Returns: + bool: True or False + + Raises: + ValueError if v is anything else. + """ + v = v.lower() + true_values = ["y", "yes", "t", "true", "1"] + false_values = ["n", "no", "f", "false", "0"] + if v in true_values: + return True + elif v in false_values: + return False + else: + raise ValueError(f"Invalid value '{v}'!") + + @dataclass class BenchmarkMetrics: """Data class to store benchmark metrics.""" @@ -226,9 +251,9 @@ def tokenize_dataset( def filter_dataset( tokenized_dataset: list[tuple[str, Any, str, int, int, int]], - max_output_length: Optional[int] = None, + max_output_length: int = 0, ) -> list[InputRequest]: - if max_output_length is None: + if max_output_length != 0: print("In InputRequest, pass in actual output_length for each sample") else: print( @@ -269,7 +294,7 @@ def sample_requests( dataset: list[tuple[Any, Any]], tokenizer: Any, num_requests: int, - max_output_length: Optional[int] = None, + max_output_length: int = 0, oversample_multiplier: float = 1.2, ) -> list[InputRequest]: @@ -319,7 +344,7 @@ async def get_request( for request in input_requests: yield request - if request_rate == float("inf"): + if request_rate == 0.0: # If the request rate is infinity, then we don't need to wait. continue # Sample the request interval from the exponential distribution. @@ -574,9 +599,14 @@ def main(args: argparse.Namespace): max_output_length=args.max_output_length, ) - if args.warmup_first: - print("Warm up start:") + warmup_requests = None + if args.warmup_mode == "full": + warmup_requests = input_requests + elif args.warmup_mode == "sampled": warmup_requests = list(sample_warmup_requests(input_requests)) * 2 + + if warmup_requests: + print(f"Starting {args.warmup_mode} warmup:") benchmark_result, request_outputs = asyncio.run( benchmark( api_url=api_url, @@ -588,7 +618,7 @@ def main(args: argparse.Namespace): priority=args.priority, ) ) - print("Warm up done") + print(f"{args.warmup_mode} warmup completed.") # TODO: Replace this with warmup complete signal once supported. # Wait for server completely warmup before running the benchmark. @@ -630,10 +660,7 @@ def main(args: argparse.Namespace): metrics_json["num_prompts"] = args.num_prompts # Traffic - metrics_json["request_rate"] = ( - args.request_rate if args.request_rate < float("inf") else "inf" - ) - + metrics_json["request_rate"] = args.request_rate metrics_json = {**metrics_json, **benchmark_result} if args.run_eval: metrics_json = {**metrics_json, **eval_json} @@ -661,6 +688,7 @@ def main(args: argparse.Namespace): if __name__ == "__main__": + parser = argparse.ArgumentParser( description="Benchmark the online serving throughput." ) @@ -711,9 +739,9 @@ def main(args: argparse.Namespace): parser.add_argument( "--request-rate", type=float, - default=float("inf"), + default=0.0, help=( - "Number of requests per second. If this is inf, " + "Number of requests per second. If this is 0., " "then all the requests are sent at time 0. " "Otherwise, we use Poisson process to synthesize " "the request arrival times." @@ -729,7 +757,7 @@ def main(args: argparse.Namespace): parser.add_argument( "--max-output-length", type=int, - default=None, + default=0, help=( "The maximum output length for reference request. It would be passed" " to `max_tokens` parameter of the JetStream's DecodeRequest proto," @@ -738,7 +766,8 @@ def main(args: argparse.Namespace): " max_tokens <= (max_target_length - max_prefill_predict_length)." " max_target_length is the maximum length of a sequence;" " max_prefill_predict_length is the maximum length of the" - " input/prefill of a sequence." + " input/prefill of a sequence. Default to 0, in this case, " + "the output length of the golden dataset would be passed." ), ) @@ -792,15 +821,16 @@ def main(args: argparse.Namespace): ) parser.add_argument( "--run-eval", - type=bool, + type=str2bool, default=False, help="Whether to run evaluation script on the saved outputs", ) parser.add_argument( - "--warmup-first", - type=bool, - default=False, - help="Whether to send warmup req first", + "--warmup-mode", + type=str, + default="none", + choices=["none", "sample", "full"], + help="Whether to warmup first, and set the warmup mode", ) parser.add_argument( "--conversation-starter", From e73ba532efeb68d84e49e0474f15f134de32f0c6 Mon Sep 17 00:00:00 2001 From: Morgan Du Date: Mon, 3 Jun 2024 15:50:12 -0700 Subject: [PATCH 03/42] fix typo (#98) --- benchmarks/benchmark_serving.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index d68be2d2..07b36a84 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -829,7 +829,7 @@ def main(args: argparse.Namespace): "--warmup-mode", type=str, default="none", - choices=["none", "sample", "full"], + choices=["none", "sampled", "full"], help="Whether to warmup first, and set the warmup mode", ) parser.add_argument( From cc191034038e0fdb2d89e8bb5bf120f0a464e460 Mon Sep 17 00:00:00 2001 From: Fanhai Lu <154379058+FanhaiLu1@users.noreply.github.com> Date: Fri, 7 Jun 2024 14:56:57 -0400 Subject: [PATCH 04/42] add ssh support for profile (#99) --- docs/profiling-with-jax-profiler-and-tensorboard.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/profiling-with-jax-profiler-and-tensorboard.md b/docs/profiling-with-jax-profiler-and-tensorboard.md index 8006ffb3..d2590040 100644 --- a/docs/profiling-with-jax-profiler-and-tensorboard.md +++ b/docs/profiling-with-jax-profiler-and-tensorboard.md @@ -12,6 +12,12 @@ tensorboard --logdir /tmp/tensorboard/ ``` You should be able to load TensorBoard at http://localhost:6006/. You can specify a different port with the `--port` flag. If you are running on a remote Cloud TPU VM, the `tensorboard-plugin-profile` python package enables remote access to tensorboard endpoints (JetStream deps include this package). +When you can not access the tensorboard and the profiling code is run remotely, please run below command setup an SSH tunnel on port 6006 to work. If you run with vs code remote debug commandline, the vs code did ssh forward port for you. + +```bash + gcloud compute ssh -- -L 6006:127.0.0.1:6006 + ``` + 2. Start JetStream MaxText server: ```bash From 120c68bc41ca74107f2fa71569924c305ae8d8ab Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Mon, 10 Jun 2024 10:00:42 -0700 Subject: [PATCH 05/42] Add inference sampling utils in JetStream (#100) * Add inference sampling utils in JetStream * pylint * full unit test coverage * fmt --- jetstream/engine/sampling_utils.py | 87 +++++++++++++++++++ jetstream/tests/engine/test_sampling_utils.py | 74 ++++++++++++++++ 2 files changed, 161 insertions(+) create mode 100644 jetstream/engine/sampling_utils.py create mode 100644 jetstream/tests/engine/test_sampling_utils.py diff --git a/jetstream/engine/sampling_utils.py b/jetstream/engine/sampling_utils.py new file mode 100644 index 00000000..3adb191a --- /dev/null +++ b/jetstream/engine/sampling_utils.py @@ -0,0 +1,87 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=bare-except, consider-using-generator +""" Inference sampling utilities. + + Inspired by an Google-internal implementation, Global Vision Transformer. +""" + +import jax +import jax.numpy as jnp + +NEG_INF = -1.0e7 # Masking purpose + + +def sampling(logits, rng, algorithm, topk=0, nucleus_topp=0, temperature=1.0): + """ + logits: unnormalized logits to sample, shaped [YOUR_LEADING_DIMS, Vocab], + before logit + rng: rng key to use + algorithm: string representing supported algorithms + topk: restricting to topk logits before sampling + nucleus_topp: restricting to p probability mass before sampling + temperature: temperature parameter for scaling probability + """ + if algorithm == "greedy": + return jnp.argmax(logits, axis=-1) + elif algorithm == "weighted": + return jax.random.categorical(rng, logits / temperature) + elif algorithm == "nucleus": + return sample_nucleus_topp_logits(logits, nucleus_topp, temperature, rng) + elif algorithm == "topk": + return sample_topk_logits(logits, topk, temperature, rng) + else: + raise ValueError(f"Sampling {algorithm=} not supported!") + + +def sample_nucleus_topp_logits(logits, nucleus_topp, temperature, rng): + """Restrict sampling to the top logits with cumulative probability >= + nucleus_topp. + + The nucleus sampling method is proposed in the paper `The Curious Case of + Neural Text Degeneration (https://arxiv.org/pdf/1904.09751.pdf)` + + """ + if nucleus_topp < 0: + raise ValueError( + "Can't apply nucleus with parameter {nucleus_topp=} less zero" + ) + logits_sorted = jnp.sort(logits, axis=-1)[..., ::-1] # sort descending + sorted_cum_probs = jnp.cumsum( + jax.nn.softmax(logits_sorted, axis=-1), axis=-1 + ) # get cumsum probs + cutoff_index = jnp.sum( + sorted_cum_probs < nucleus_topp, axis=-1, keepdims=True + ) # find cutoff index + cutoff_logit = jnp.take_along_axis(logits_sorted, cutoff_index, axis=-1) + logits = jnp.where( + logits < cutoff_logit, jnp.full_like(logits, NEG_INF), logits + ) + return jax.random.categorical(rng, logits / temperature) + + +def sample_topk_logits(logits, topk, temperature, rng): + """Restricting sampling to the best k logits.""" + if topk <= 0: + raise ValueError("Can't apply algorithm topk with parameter {topk=} <= 0") + topk_logits, topk_idxs = jax.lax.top_k(logits, topk) + topk_token = jnp.expand_dims( + jax.random.categorical(rng, topk_logits / temperature).astype(jnp.int32), + axis=-1, + ) + sampled_tokens = jnp.squeeze( + jnp.take_along_axis(topk_idxs, topk_token, axis=-1), axis=-1 + ).astype(jnp.int32) + return sampled_tokens diff --git a/jetstream/tests/engine/test_sampling_utils.py b/jetstream/tests/engine/test_sampling_utils.py new file mode 100644 index 00000000..0bc13bd6 --- /dev/null +++ b/jetstream/tests/engine/test_sampling_utils.py @@ -0,0 +1,74 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests functionality of inference sampling utils.""" + +import jax +import jax.numpy as jnp +import unittest +from jetstream.engine import sampling_utils + + +class SamplingUtilsTest(unittest.TestCase): + + def setUp(self): + self.rng = jax.random.PRNGKey(0) + self.logits = jnp.array([[-0.5, 1.2, 0.8], [-1.0, 0.3, 0.7]]) + + def test_greedy_sampling(self): + token = sampling_utils.sampling(self.logits, self.rng, "greedy") + expected_token = jnp.array([1, 2]) + self.assertTrue(jnp.array_equal(token, expected_token)) + + def test_weighted_sampling(self): + # Multiple samples to increase the chance of catching errors + for _ in range(10): + result = sampling_utils.sampling(self.logits, self.rng, "weighted") + self.assertTrue( + jnp.all(jnp.isin(result, jnp.array([0, 1, 2]))) + ) # Check if sampled from valid indices + + def test_nucleus_sampling(self): + for _ in range(10): + result = sampling_utils.sampling( + self.logits, self.rng, "nucleus", nucleus_topp=0.8 + ) + self.assertTrue(jnp.all(jnp.isin(result, jnp.array([0, 1, 2])))) + invalid_topp = -0.1 + with self.assertRaises(ValueError) as context: + sampling_utils.sampling( + self.logits, self.rng, "nucleus", nucleus_topp=invalid_topp + ) + self.assertIn( + f"Can't apply nucleus with parameter {invalid_topp=} less zero", + str(context.exception), + ) + + def test_topk_sampling(self): + for _ in range(10): + result = sampling_utils.sampling(self.logits, self.rng, "topk", topk=2) + self.assertTrue( + jnp.all(jnp.isin(result, jnp.array([1, 2]))) + ) # Only top 2 logits should be sampled + invalid_topk = 0 + with self.assertRaises(ValueError) as context: + sampling_utils.sampling(self.logits, self.rng, "topk", topk=invalid_topk) + self.assertIn( + f"Can't apply algorithm topk with parameter {invalid_topk=} <= 0", + str(context.exception), + ) + + def test_unsupported_algorithm(self): + with self.assertRaises(ValueError): + sampling_utils.sampling(self.logits, self.rng, "unsupported_algorithm") From 8a1e31322e8e953909482b71f2689f82dbf4572f Mon Sep 17 00:00:00 2001 From: Zhihao Shan <60905719+zhihaoshan-google@users.noreply.github.com> Date: Mon, 10 Jun 2024 16:28:30 -0700 Subject: [PATCH 06/42] Add profiling server for proxy backend (#101) Co-authored-by: Zhihao Shan --- jetstream/core/server_lib.py | 10 +++++ jetstream/core/utils/proxy_util.py | 48 +++++++++++++++++++++++ jetstream/engine/__init__.py | 32 +++------------ jetstream/tools/proxy_dev/base.Dockerfile | 25 ++++++++++++ jetstream/tools/proxy_dev/dev.Dockerfile | 17 ++++++++ 5 files changed, 105 insertions(+), 27 deletions(-) create mode 100644 jetstream/core/utils/proxy_util.py create mode 100644 jetstream/tools/proxy_dev/base.Dockerfile create mode 100644 jetstream/tools/proxy_dev/dev.Dockerfile diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index 3d93746d..4ea65160 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -162,6 +162,16 @@ def run( jax.profiler.start_server(jax_profiler_port) else: logging.info("Not starting JAX profiler server: %s", enable_jax_profiler) + + # Start profiling server by default for proxy backend. + if jax.config.jax_platforms and "proxy" in jax.config.jax_platforms: + from jetstream.core.utils import proxy_util # pylint: disable=import-outside-toplevel + + thread = threading.Thread( + target=proxy_util.start_profiling_server, args=(jax_profiler_port,) + ) + thread.run() + return jetstream_server diff --git a/jetstream/core/utils/proxy_util.py b/jetstream/core/utils/proxy_util.py new file mode 100644 index 00000000..fbf4b031 --- /dev/null +++ b/jetstream/core/utils/proxy_util.py @@ -0,0 +1,48 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Proxy util functions.""" + +import dataclasses +import logging +import jax +import time +from fastapi import FastAPI +import uvicorn + + +# TODO: add a manner way to terminate. +def start_profiling_server(port: int): + + logging.info("Starting JAX profiler server on port %s", port) + app = FastAPI() + + @dataclasses.dataclass + class ProfilingConfig: + seconds: int + output_dir: str + + @app.post("/profiling") + async def profiling(pc: ProfilingConfig): + jax.profiler.start_trace(pc.output_dir) + logging.info("Capturing the profiling data for next %s seconds", pc.seconds) + time.sleep(pc.seconds) + logging.info("Writing profiling data to %s", pc.output_dir) + jax.profiler.stop_trace() + return {"response": "profiling completed"} + + @app.get("/") + async def root(): + return {"message": "Hello from proxy profiling server"} + + uvicorn.run(app, host="0.0.0.0", port=port, log_level="info") diff --git a/jetstream/engine/__init__.py b/jetstream/engine/__init__.py index 2ed0398d..ee979964 100644 --- a/jetstream/engine/__init__.py +++ b/jetstream/engine/__init__.py @@ -16,30 +16,8 @@ import jax - -def register_proxy_backend(): - """Try to register IFRT Proxy backend if it's needed.""" - # TODO: find a more elegant way to do it. - if jax.config.jax_platforms and "proxy" in jax.config.jax_platforms: - try: - jax.lib.xla_bridge.get_backend("proxy") - except RuntimeError: - try: - from jaxlib.xla_extension import ifrt_proxy # pylint: disable=import-outside-toplevel - - jax_backend_target = jax.config.read("jax_backend_target") - jax._src.xla_bridge.register_backend_factory( # pylint: disable=protected-access - "proxy", - lambda: ifrt_proxy.get_client( - jax_backend_target, - ifrt_proxy.ClientConnectionOptions(), - ), - priority=-1, - ) - print(f"Registered IFRT Proxy with address {jax_backend_target}") - except ImportError as e: - print(f"Failed to register IFRT Proxy, exception: {e}") - pass - - -register_proxy_backend() +try: + import previewutilities +except ImportError as e: + print("Proxy backend support is not added") + pass diff --git a/jetstream/tools/proxy_dev/base.Dockerfile b/jetstream/tools/proxy_dev/base.Dockerfile new file mode 100644 index 00000000..0158902a --- /dev/null +++ b/jetstream/tools/proxy_dev/base.Dockerfile @@ -0,0 +1,25 @@ +# Ubuntu:22.04 +# Use Ubuntu 22.04 from Docker Hub. +# https://hub.docker.com/_/ubuntu/tags\?page\=1\&name\=22.04 +FROM ubuntu:22.04 + +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt -y update && apt install -y --no-install-recommends apt-transport-https ca-certificates gnupg git python3.10 python3-pip curl + +RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1 +RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list && curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg && apt-get update -y && apt-get install google-cloud-sdk -y + + +# Copy all files from local workspace into docker container +COPY JetStream ./JetStream +COPY maxtext ./maxtext + +RUN cd maxtext/ && \ +pip install -r requirements.txt + +RUN pip install setuptools==58 fastapi==0.103.2 uvicorn nltk evaluate + +RUN pip install ./JetStream + +ENTRYPOINT ["bash"] diff --git a/jetstream/tools/proxy_dev/dev.Dockerfile b/jetstream/tools/proxy_dev/dev.Dockerfile new file mode 100644 index 00000000..126da735 --- /dev/null +++ b/jetstream/tools/proxy_dev/dev.Dockerfile @@ -0,0 +1,17 @@ +# Ubuntu:22.04 +# Use Ubuntu 22.04 from Docker Hub. +# https://hub.docker.com/_/ubuntu/tags\?page\=1\&name\=22.04 +FROM base_image + +ENV DEBIAN_FRONTEND=noninteractive + +ENV JAX_PLATFORMS=proxy +ENV JAX_BACKEND_TARGET=grpc://localhost:38681 + +# Copy all files from local workspace into docker container +COPY JetStream ./JetStream +COPY maxtext ./maxtext + +RUN pip install ./JetStream + +ENTRYPOINT ["bash"] From 26872c3c6e726f52f5bac1cb63e60a9a2a0bbe8a Mon Sep 17 00:00:00 2001 From: Brendan Slabe Date: Wed, 12 Jun 2024 16:43:46 -0400 Subject: [PATCH 07/42] Change `jetstream_slots_available_percentage` to `jetstream_slots_used_percentage` (#102) * initial_commit * pylint * updated example --- ...rvability-prometheus-metrics-in-jetstream-server.md | 6 +++--- jetstream/core/metrics/prometheus.py | 10 +++++----- jetstream/core/orchestrator.py | 6 ++++-- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/docs/observability-prometheus-metrics-in-jetstream-server.md b/docs/observability-prometheus-metrics-in-jetstream-server.md index b61cf081..876e2cd3 100644 --- a/docs/observability-prometheus-metrics-in-jetstream-server.md +++ b/docs/observability-prometheus-metrics-in-jetstream-server.md @@ -45,7 +45,7 @@ Now that we configured `prometheus_port=9090` above, we can observe various Jets # HELP jetstream_prefill_backlog_size Size of prefill queue # TYPE jetstream_prefill_backlog_size gauge jetstream_prefill_backlog_size{id="SOME-HOSTNAME-HERE>"} 0.0 -# HELP jetstream_slots_available_percentage The percentage of available slots in decode batch -# TYPE jetstream_slots_available_percentage gauge -jetstream_slots_available_percentage{id="",idx="0"} 0.96875 +# HELP jetstream_slots_used_percentage The percentage of decode slots currently being used +# TYPE jetstream_slots_used_percentage gauge +jetstream_slots_used_percentage{id="",idx="0"} 0.04166666666666663 ``` \ No newline at end of file diff --git a/jetstream/core/metrics/prometheus.py b/jetstream/core/metrics/prometheus.py index de0be2c2..6fc38897 100644 --- a/jetstream/core/metrics/prometheus.py +++ b/jetstream/core/metrics/prometheus.py @@ -35,14 +35,14 @@ def __new__(cls): documentation="Size of prefill queue", labelnames=["id"], ) - _slots_available_percentage = Gauge( - name="jetstream_slots_available_percentage", - documentation="The percentage of available slots in decode batch", + _slots_used_percentage = Gauge( + name="jetstream_slots_used_percentage", + documentation="The percentage of decode slots currently being used", labelnames=["id", "idx"], ) def get_prefill_backlog_metric(self): return self._prefill_backlog.labels(id=self._id) - def get_slots_available_percentage_metric(self, idx: int): - return self._slots_available_percentage.labels(id=self._id, idx=idx) + def get_slots_used_percentage_metric(self, idx: int): + return self._slots_used_percentage.labels(id=self._id, idx=idx) diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index eed35f8c..b387f113 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -597,9 +597,11 @@ def _generate_thread(self, idx: int): max_concurrent_decodes = generate_engine.max_concurrent_decodes if self._metrics_collector: - self._metrics_collector.get_slots_available_percentage_metric( + self._metrics_collector.get_slots_used_percentage_metric( idx - ).set_function(lambda: float(my_slots.qsize() / max_concurrent_decodes)) + ).set_function( + lambda: float(1 - (my_slots.qsize() / max_concurrent_decodes)) + ) # Check if there are any free my_slots. We don't want to block here since # we can still generate if we can't insert. We do this in a while loop to From f83e04226c63ee647e3063affc195286fe699075 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 17 Jun 2024 15:56:06 -0700 Subject: [PATCH 08/42] Bump urllib3 from 2.2.0 to 2.2.2 in the pip group across 1 directory (#104) Bumps the pip group with 1 update in the / directory: [urllib3](https://github.com/urllib3/urllib3). Updates `urllib3` from 2.2.0 to 2.2.2 - [Release notes](https://github.com/urllib3/urllib3/releases) - [Changelog](https://github.com/urllib3/urllib3/blob/main/CHANGES.rst) - [Commits](https://github.com/urllib3/urllib3/compare/2.2.0...2.2.2) --- updated-dependencies: - dependency-name: urllib3 dependency-type: indirect dependency-group: pip ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 11cb5643..24e162fc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -302,7 +302,7 @@ typing-extensions==4.5.0 # flax # orbax-checkpoint # tensorflow -urllib3==2.2.0 +urllib3==2.2.2 # via # blobfile # requests From 7c1686ed24ca17f55764b1b08ac26a005a38d911 Mon Sep 17 00:00:00 2001 From: Brendan Slabe Date: Mon, 24 Jun 2024 16:04:20 -0400 Subject: [PATCH 09/42] Added `jetstream_transfer_backlog_size` and `jetstream_generate_backlog_size` metrics (#103) * first commit * unit tests * labels * typing * fix cell-var-from-loop error * extra log * tweak log * pylint * split log line --- jetstream/core/metrics/prometheus.py | 16 ++++++++++++++++ jetstream/core/orchestrator.py | 14 +++++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/jetstream/core/metrics/prometheus.py b/jetstream/core/metrics/prometheus.py index 6fc38897..e84a0905 100644 --- a/jetstream/core/metrics/prometheus.py +++ b/jetstream/core/metrics/prometheus.py @@ -35,6 +35,16 @@ def __new__(cls): documentation="Size of prefill queue", labelnames=["id"], ) + _transfer_backlog = Gauge( + name="jetstream_transfer_backlog_size", + documentation="Size of transfer queue", + labelnames=["id", "idx"], + ) + _generate_backlog = Gauge( + name="jetstream_generate_backlog_size", + documentation="Size of generate queue", + labelnames=["id", "idx"], + ) _slots_used_percentage = Gauge( name="jetstream_slots_used_percentage", documentation="The percentage of decode slots currently being used", @@ -44,5 +54,11 @@ def __new__(cls): def get_prefill_backlog_metric(self): return self._prefill_backlog.labels(id=self._id) + def get_transfer_backlog_metric(self, idx: int): + return self._transfer_backlog.labels(id=self._id, idx=idx) + + def get_generate_backlog_metric(self, idx: int): + return self._generate_backlog.labels(id=self._id, idx=idx) + def get_slots_used_percentage_metric(self, idx: int): return self._slots_used_percentage.labels(id=self._id, idx=idx) diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index b387f113..1a2b7a88 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -267,6 +267,11 @@ def __init__( queue.Queue(1 if self._interleaved_mode else 4) for i in range(len(self._prefill_engines)) ] + if self._metrics_collector: + for idx, backlog in enumerate(self._transfer_backlogs): + self._metrics_collector.get_transfer_backlog_metric(idx).set_function( + functools.partial(float, backlog.qsize()) + ) # Stage 3 # Each generate engine accesses its own generate backlog. # Interleaved Mode: Max size is 1 to increase the HBM utilization @@ -281,6 +286,11 @@ def __init__( ) for idx, engine in enumerate(self._generate_engines) } + if self._metrics_collector: + for idx, backlog in self._generate_backlogs.items(): + self._metrics_collector.get_generate_backlog_metric(idx).set_function( + functools.partial(float, backlog.qsize()) + ) # Stage 4 # After generation, ActiveRequests are placed on the detokenization backlog # for tokens to be sent into each ActiveRequest's return channel. @@ -561,9 +571,11 @@ def _transfer_thread(self, idx: int): self._generate_backlogs[target_idx].put(new_request, block=True) logging.info( "Successfully transferred prefill " - "from prefill engine %d to generate engine %d.", + "from prefill engine %d to generate engine %d " + "(%d requests now in backlog).", idx, target_idx, + self._generate_backlogs[target_idx].qsize(), ) def _generate_thread(self, idx: int): From 3ddc26f764c4ae7cc409670870dbde85cdf050c1 Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Wed, 26 Jun 2024 10:06:53 -0700 Subject: [PATCH 10/42] Update docs for benchmark warmup mode (#106) * Update docs for benchmark warmup mode * update * Update benchmarks/README.md Co-authored-by: Andy Ye --------- Co-authored-by: Andy Ye --- README.md | 1 + benchmarks/README.md | 19 ++++++++++++++++++- docs/online-inference-with-maxtext-engine.md | 4 ++-- 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index ee0b1eee..a989b316 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,7 @@ Currently, there are two reference engine implementations available -- one for J - [Online Inference with MaxText on v5e Cloud TPU VM](https://cloud.google.com/tpu/docs/tutorials/LLM/jetstream) [[README](https://github.com/google/JetStream/blob/main/docs/online-inference-with-maxtext-engine.md)] - [Online Inference with Pytorch on v5e Cloud TPU VM](https://cloud.google.com/tpu/docs/tutorials/LLM/jetstream-pytorch) [[README](https://github.com/google/jetstream-pytorch/tree/main?tab=readme-ov-file#jetstream-pytorch)] - [Serve Gemma using TPUs on GKE with JetStream](https://cloud.google.com/kubernetes-engine/docs/tutorials/serve-gemma-tpu-jetstream) +- [Benchmark JetStream Server](https://github.com/google/JetStream/blob/main/benchmarks/README.md) - [Observability in JetStream Server](https://github.com/google/JetStream/blob/main/docs/observability-prometheus-metrics-in-jetstream-server.md) - [Profiling in JetStream Server](https://github.com/google/JetStream/blob/main/docs/profiling-with-jax-profiler-and-tensorboard.md) - [JetStream Standalone Local Setup](#jetstream-standalone-local-setup) diff --git a/benchmarks/README.md b/benchmarks/README.md index b88501f2..2a511c1b 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -78,7 +78,7 @@ python benchmark_serving.py \ ``` python JetStream/benchmarks/benchmark_serving.py \ --tokenizer ~/maxtext/assets/tokenizer.llama2 \ ---warmup-first true \ +--warmup-mode sampled \ --save-result \ --save-request-outputs \ --request-outputs-file-path outputs.json \ @@ -88,6 +88,23 @@ python JetStream/benchmarks/benchmark_serving.py \ ``` +## Benchmark warmup mode + +The benchmark has better performance if it first conducts a warmup of the JetStream server. We currently support `sampled` and `full` warmup modes. `full` mode would warmup up the JetStream server with all the input requests. `sampled` mode would warmup up the JetStream server with a sampling of the input requests across different bucket sizes of input lengths. + +Example to run benchmark with `full` warmup mode: +``` +python JetStream/benchmarks/benchmark_serving.py \ +--tokenizer ~/maxtext/assets/tokenizer.llama2 \ +--warmup-mode full \ +--save-result \ +--save-request-outputs \ +--request-outputs-file-path outputs.json \ +--num-prompts 1000 \ +--max-output-length 1024 \ +--dataset openorca +``` + ## Standalone Evaluation Run If you used `--save-request-outputs`, you can separately evaluate against the generated outputs. diff --git a/docs/online-inference-with-maxtext-engine.md b/docs/online-inference-with-maxtext-engine.md index 9d3aefe1..96c9db81 100644 --- a/docs/online-inference-with-maxtext-engine.md +++ b/docs/online-inference-with-maxtext-engine.md @@ -259,7 +259,7 @@ python JetStream/benchmarks/benchmark_serving.py \ --dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \ --max-output-length 1024 \ --request-rate 5 \ ---warmup-first true +--warmup-mode sampled ``` ### Benchmarking Llama2-\*b @@ -274,7 +274,7 @@ python JetStream/benchmarks/benchmark_serving.py \ --dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \ --max-output-length 1024 \ --request-rate 5 \ ---warmup-first true +--warmup-mode sampled ``` ## Clean Up From cd6eb2d42b0a1b96f49431c1cf80d8800a2a2211 Mon Sep 17 00:00:00 2001 From: Brendan Slabe Date: Thu, 27 Jun 2024 11:57:01 -0700 Subject: [PATCH 11/42] Update docs with metrics observation instructions (#107) * first commit * links * remove extra space * tweak * more explicit documentation * wording * typo * added more docs * json to bash * Update docs/observability-prometheus-metrics-in-jetstream-server.md Co-authored-by: Zijun Zhou --------- Co-authored-by: Zijun Zhou --- ...-prometheus-metrics-in-jetstream-server.md | 36 ++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/docs/observability-prometheus-metrics-in-jetstream-server.md b/docs/observability-prometheus-metrics-in-jetstream-server.md index 876e2cd3..04d7be4c 100644 --- a/docs/observability-prometheus-metrics-in-jetstream-server.md +++ b/docs/observability-prometheus-metrics-in-jetstream-server.md @@ -48,4 +48,38 @@ jetstream_prefill_backlog_size{id="SOME-HOSTNAME-HERE>"} 0.0 # HELP jetstream_slots_used_percentage The percentage of decode slots currently being used # TYPE jetstream_slots_used_percentage gauge jetstream_slots_used_percentage{id="",idx="0"} 0.04166666666666663 -``` \ No newline at end of file +``` + +## Observe metrics on GKE clusters + +The following applies only for Jetstream deployed on a GKE cluster. Currently [Google Cloud Managed Service for Prometheus](https://cloud.google.com/stackdriver/docs/managed-prometheus) is enabled by default on all GKE clusters, it determines scrape targets via the [PodMonitoring](https://github.com/GoogleCloudPlatform/prometheus-engine/blob/v0.10.0/doc/api.md#podmonitoring) custom resource. After you deployed the JetStream GKE workload, you need to apply the PodMonitoring resource to your cluster as follows: + +``` +echo '{ + "apiVersion": "monitoring.googleapis.com/v1", + "kind": "PodMonitoring", + "metadata": { + "name": "jetstream-podmonitoring" + }, + "spec": { + "endpoints": [ + { + "interval": "1s", + "path": "/", + "port": + } + ], + "targetLabels": { + "metadata": [ + "pod", + "container", + "node" + ] + } + } + }' | kubectl apply -f - + ``` + +The metrics can now be queried in the [Google Cloud Metrics Explorer](https://pantheon.corp.google.com/monitoring/metrics-explorer). When adding a metrics query with the `+Add Query` button the new metrics should be found under the `Prometheus Target > Jetstream` submenu. + +Additional guides on the metrics explorer can be found [here](https://cloud.google.com/monitoring/charts/metrics-selector). \ No newline at end of file From c3fe3ce068ed518c61f852079cba40b7aa9922f3 Mon Sep 17 00:00:00 2001 From: jwyang-google <132702993+jwyang-google@users.noreply.github.com> Date: Fri, 28 Jun 2024 08:42:12 -0700 Subject: [PATCH 12/42] Prefill return first token (#105) Change prefill API to return first token. --- jetstream/core/orchestrator.py | 37 +++++++++++- jetstream/engine/engine_api.py | 2 +- jetstream/engine/mock_engine.py | 65 ++++++++++++++++++---- jetstream/tests/engine/test_mock_engine.py | 38 +++++++------ 4 files changed, 113 insertions(+), 29 deletions(-) diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index 1a2b7a88..a9ea2444 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -486,6 +486,7 @@ def _prefill_thread(self, idx: int): my_transfer_backlog = self._transfer_backlogs[idx] # The prefill thread can just sleep until it has work to do. request = self._prefill_backlog.get(block=True) + request_start_time = time.perf_counter() if request is None: break @@ -503,12 +504,20 @@ def _prefill_thread(self, idx: int): request, tokenizer, is_bos, prefill_engine.max_prefill_length ) # Compute new kv cache for the prefill_content. - prefill_result = prefill_engine.prefill( + prefill_result, first_token = prefill_engine.prefill( params=prefill_params, padded_tokens=padded_tokens, true_length=true_length, ) request.prefill_result = prefill_result + + # put first token to detokenize queue + request.complete = np.zeros((prefill_engine.samples_per_slot,), np.bool_) + my_detokenize_backlog = self._detokenize_backlogs[idx] + my_detokenize_backlog.put( + (first_token, request, request_start_time), block=True + ) + # Once prefill is complete, place it on the generation queue and block if # full. my_transfer_backlog.put(request, block=True) @@ -517,6 +526,7 @@ def _prefill_thread(self, idx: int): idx, my_transfer_backlog.qsize(), ) + del prefill_result del request @@ -714,7 +724,30 @@ def _detokenize_thread(self, idx: int): if data is None: break start_detokenize_time = time.time() - if isinstance(data[1], engine_api.ResultTokens): + # prefill first token + if isinstance(data[0], engine_api.ResultTokens): + request_first_token, request, request_start_time = data + request_first_token = request_first_token.convert_to_numpy() + + results, complete = token_utils.process_result_tokens( + tokenizer=tokenizer, + slot=0, # always 0 as prefill only run 1 sample + slot_max_length=request.max_tokens, + result_tokens=request_first_token, + is_client_side_tokenization=request.is_client_side_tokenization, + complete=request.complete, + ) + request.complete = complete + # Return some output samples. + request.enqueue_samples(results) + + first_token_return_time = time.perf_counter() + logging.info( + "TTFT duration: %fms", + (first_token_return_time - request_start_time) * 1000, + ) + # generate step tokens + elif isinstance(data[1], engine_api.ResultTokens): # We want to detokenize them. generate_timestep_added, result_tokens = data # Disable attribute error because pytype doesn't know this diff --git a/jetstream/engine/engine_api.py b/jetstream/engine/engine_api.py index 50feff6d..c971d30c 100644 --- a/jetstream/engine/engine_api.py +++ b/jetstream/engine/engine_api.py @@ -142,7 +142,7 @@ def prefill( existing_prefix: Optional[Prefix] = None, padded_tokens: jax.Array, true_length: int, - ) -> Prefix: + ) -> Tuple[Prefix, ResultTokens]: """Computes a kv-cache for a set of tokens conditional on existing cache. existing_prefix (if provided) represents a prefix that has already been diff --git a/jetstream/engine/mock_engine.py b/jetstream/engine/mock_engine.py index 502df8b6..0277e9a3 100644 --- a/jetstream/engine/mock_engine.py +++ b/jetstream/engine/mock_engine.py @@ -54,6 +54,7 @@ class DecodeState: generate_cache: jax.Array generate_cache_index: int generate_lengths: jax.Array + generate_tokens: jax.Array class TestEngine(engine_api.Engine): @@ -85,7 +86,7 @@ def prefill( existing_prefix: Optional[jax.Array] = None, padded_tokens: jax.Array, true_length: int, - ) -> Prefix: + ) -> Tuple[Prefix, engine_api.ResultTokens]: """Computes a kv-cache for a new generate request. Args: @@ -109,19 +110,55 @@ def prefill( ) # Do some fake work that isn't eliminated by dead code elimination (DCE). params = params + fake_work.mean() - fake_work.mean() - return padded_tokens[None, :] * params + prefill_cache = padded_tokens[None, :] * params + + # get dummy first token + first_step = (prefill_cache.sum(axis=-1))[:, jnp.newaxis] + first_token_data = jnp.concatenate( + [first_step, jnp.ones_like(first_step), jnp.ones_like(first_step)], + axis=-1, + ) + speculations = first_step.shape[1] + first_token = engine_api.ResultTokens( + data=first_token_data.astype(jnp.int32), + tokens_idx=(0, speculations), + # Validity occupies the same amount of space, but next in line. + valid_idx=(speculations, 2 * speculations), + # And lengths is rank 1. + length_idx=(2 * speculations, 2 * speculations + 1), + samples_per_slot=self.generate_cache_batch // self.prefill_cache_batch, + ) + + return (prefill_cache, first_step), first_token @functools.partial(jax.jit, static_argnums=(0,)) def generate( self, params: Params, decode_state: DecodeState ) -> Tuple[DecodeState, engine_api.ResultTokens]: """Generates tokens for each sequence being decoded in parallel.""" - prefill_cache, generate_cache, generate_cache_index, generate_lengths = ( + ( + prefill_cache, + generate_cache, + generate_cache_index, + generate_lengths, + previous_timestep, + ) = ( decode_state.prefill_cache, decode_state.generate_cache, decode_state.generate_cache_index, decode_state.generate_lengths, + decode_state.generate_tokens, ) + + # Update generate cache + generate_cache = jax.lax.dynamic_update_slice_in_dim( + generate_cache, + previous_timestep, + start_index=generate_cache_index, + axis=1, + ) + generate_cache_index = (generate_cache_index + 1) % self.cache_length + # Sum each row of prefill cache and generate cache to produce new timestep, # multiply by params. l_iota = jax.lax.broadcasted_iota( @@ -136,17 +173,13 @@ def generate( # token from prefill in the dummy. # This iota and masking is to allow for a cicular cache. length_mask = ( - -(l_iota - generate_cache_index + 1) % self.cache_length + -(l_iota - generate_cache_index) % self.cache_length ) <= generate_lengths[:, None] length_masked_gen_cache = generate_cache * length_mask new_timestep = ( prefill_cache.sum(axis=-1) + (length_masked_gen_cache.sum(axis=-1) / params) )[:, jnp.newaxis] - generate_cache = jax.lax.dynamic_update_slice_in_dim( - generate_cache, new_timestep, start_index=generate_cache_index, axis=1 - ) - generate_cache_index = (generate_cache_index + 1) % self.cache_length # Wait to simulate model step time. fake_size = 4096 fake_work = jnp.ones((fake_size, fake_size)) @ jnp.ones( @@ -168,6 +201,7 @@ def generate( generate_cache=generate_cache, generate_cache_index=generate_cache_index, generate_lengths=new_lengths, + generate_tokens=new_timestep, ), engine_api.ResultTokens( data=token_data.astype(jnp.int32), # Tokens are shape [batch, speculations], so when we concatenate @@ -190,8 +224,9 @@ def insert( ) -> DecodeState: """Adds `prefix` into `decode_state` at `slot`.""" # [B, T], [T,] -> [B, T] + prefill_cache, previous_timestep = prefix prefill_cache = jax.lax.dynamic_update_slice_in_dim( - decode_state.prefill_cache, prefix, slot, axis=0 + decode_state.prefill_cache, prefill_cache, slot, axis=0 ) generate_cache = jax.lax.dynamic_update_slice_in_dim( decode_state.generate_cache, @@ -202,7 +237,13 @@ def insert( samples_per_slot = self.generate_cache_batch // self.prefill_cache_batch generate_lengths = jax.lax.dynamic_update_slice_in_dim( decode_state.generate_lengths, - jnp.zeros((samples_per_slot), dtype=jnp.int32), + jnp.ones((samples_per_slot), dtype=jnp.int32), + slot * samples_per_slot, + axis=0, + ) + generate_tokens = jax.lax.dynamic_update_slice_in_dim( + decode_state.generate_tokens, + previous_timestep, slot * samples_per_slot, axis=0, ) @@ -210,6 +251,7 @@ def insert( prefill_cache=prefill_cache, generate_cache=generate_cache, generate_lengths=generate_lengths, + generate_tokens=generate_tokens, ) def get_prefix_destination_sharding(self) -> Any: @@ -234,6 +276,9 @@ def init_decode_state(self) -> DecodeState: generate_lengths=jnp.zeros( (self.generate_cache_batch), dtype=jnp.int32 ), + generate_tokens=jnp.zeros( + (self.generate_cache_batch, 1), dtype=jnp.float32 + ), ) @property diff --git a/jetstream/tests/engine/test_mock_engine.py b/jetstream/tests/engine/test_mock_engine.py index 3f112067..0d8f2da8 100644 --- a/jetstream/tests/engine/test_mock_engine.py +++ b/jetstream/tests/engine/test_mock_engine.py @@ -54,10 +54,10 @@ def _prefill(self): metadata = engine.get_tokenizer() tokenizer = engine.build_tokenizer(metadata) tokens, true_length = tokenizer.encode(text, is_bos=True) - prefill_result = engine.prefill( + prefill_result, first_token = engine.prefill( params=params, padded_tokens=tokens, true_length=3 ) - return engine, params, prefill_result, true_length + return engine, params, prefill_result, true_length, first_token def _prefill_np(self): """Performs prefill and returns a kv cache.""" @@ -67,14 +67,14 @@ def _prefill_np(self): metadata = engine.get_tokenizer() tokenizer = engine.build_tokenizer(metadata) tokens, true_length = tokenizer.encode(text, is_bos=True, jax_padding=False) - prefill_result = engine.prefill( + prefill_result, first_token = engine.prefill( params=params, padded_tokens=tokens, true_length=3 ) - return engine, params, prefill_result, true_length + return engine, params, prefill_result, true_length, first_token def _generate(self, slot=1): """Performs a single generation step.""" - engine, params, prefill_result, _ = self._prefill() + engine, params, prefill_result, _, _ = self._prefill() decode_state = engine.init_decode_state() decode_state = engine.insert( prefix=prefill_result, decode_state=decode_state, slot=slot @@ -91,16 +91,28 @@ def test_load_params(self): def test_prefill(self): """Tests prefill with weight = 2.""" - _, _, prefill_result, true_length = self._prefill() + engine, _, prefill_result, true_length, first_token = self._prefill() + prefill_cache, _ = prefill_result np.testing.assert_array_equal( - prefill_result[:, :true_length], np.array([[4.0, 130.0, 132.0]]) + prefill_cache[:, :true_length], np.array([[4.0, 130.0, 132.0]]) ) + # test first token + token_data = first_token.get_result_at_slot(0) + tok = token_data.tokens + + metadata = engine.get_tokenizer() + tokenizer = token_utils.load_vocab( + metadata.path, metadata.extra_ids + ).tokenizer + assert tokenizer.IdToPiece(int(tok.item())) == "Ċ" + def test_prefill_np(self): """Tests prefill with weight = 2.""" - _, _, prefill_result, true_length = self._prefill_np() + _, _, prefill_result, true_length, _ = self._prefill_np() + prefill_cache, _ = prefill_result np.testing.assert_array_equal( - prefill_result[:, :true_length], np.array([[4.0, 130.0, 132.0]]) + prefill_cache[:, :true_length], np.array([[4.0, 130.0, 132.0]]) ) def test_generate(self, slot=1): @@ -110,13 +122,7 @@ def test_generate(self, slot=1): tokenizer = token_utils.load_vocab( metadata.path, metadata.extra_ids ).tokenizer - # Char for 266 - token_data = sampled_tokens.get_result_at_slot(slot) - tok = token_data.tokens - assert tokenizer.IdToPiece(int(tok.item())) == "Ċ" - decode_state, sampled_tokens = engine.generate( - params=params, decode_state=decode_state - ) + # Char for 399 token_data = sampled_tokens.get_result_at_slot(slot) tok = token_data.tokens From 46e444c4e309f364041da36eed6a395ad57ad699 Mon Sep 17 00:00:00 2001 From: jwyang-google <132702993+jwyang-google@users.noreply.github.com> Date: Sat, 29 Jun 2024 05:10:57 +0600 Subject: [PATCH 13/42] change the detokenization thread to return the actual eos token. (#108) * change the detokenization thread to return the actual eos token. --- jetstream/engine/mock_utils.py | 2 +- jetstream/engine/token_utils.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/jetstream/engine/mock_utils.py b/jetstream/engine/mock_utils.py index 3c4fb0b3..a48a360f 100644 --- a/jetstream/engine/mock_utils.py +++ b/jetstream/engine/mock_utils.py @@ -63,7 +63,7 @@ def _encode(self, s: str) -> Sequence[int]: def _decode(self, ids: np.ndarray): """Converts a numpy array into a string.""" - return "".join([chr(r) for r in list(ids)]) + return "".join([chr(r) for r in list(ids) if r not in self.stop_tokens]) def _encode_tf(self, s: str) -> np.ndarray: """Converts a string into a numpy array.""" diff --git a/jetstream/engine/token_utils.py b/jetstream/engine/token_utils.py index 3d905688..0fa11ea4 100644 --- a/jetstream/engine/token_utils.py +++ b/jetstream/engine/token_utils.py @@ -214,6 +214,7 @@ def process_result_tokens( ) if tok_id in stop_tokens or not valid: complete[idx] = True + tok_id_so_far.append(tok_id) break else: if not is_client_side_tokenization: From 69ce8a2646ac32bea9194019078248b49e69728e Mon Sep 17 00:00:00 2001 From: Morgan Du Date: Tue, 2 Jul 2024 16:09:40 -0700 Subject: [PATCH 14/42] Add loadgen in dev image (#109) * build loadgen in docker --- jetstream/tools/proxy_dev/base.Dockerfile | 2 +- jetstream/tools/proxy_dev/dev.Dockerfile | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/jetstream/tools/proxy_dev/base.Dockerfile b/jetstream/tools/proxy_dev/base.Dockerfile index 0158902a..911acb33 100644 --- a/jetstream/tools/proxy_dev/base.Dockerfile +++ b/jetstream/tools/proxy_dev/base.Dockerfile @@ -5,7 +5,7 @@ FROM ubuntu:22.04 ENV DEBIAN_FRONTEND=noninteractive -RUN apt -y update && apt install -y --no-install-recommends apt-transport-https ca-certificates gnupg git python3.10 python3-pip curl +RUN apt -y update && apt install -y --no-install-recommends apt-transport-https ca-certificates gnupg git python3.10 python3-pip curl nano vim RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1 RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list && curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg && apt-get update -y && apt-get install google-cloud-sdk -y diff --git a/jetstream/tools/proxy_dev/dev.Dockerfile b/jetstream/tools/proxy_dev/dev.Dockerfile index 126da735..a59ff904 100644 --- a/jetstream/tools/proxy_dev/dev.Dockerfile +++ b/jetstream/tools/proxy_dev/dev.Dockerfile @@ -14,4 +14,16 @@ COPY maxtext ./maxtext RUN pip install ./JetStream +COPY inference_mlperf4.1 ./inference_mlperf4.1 +RUN apt-get -y install python3-dev && apt-get -y install build-essential +RUN pip install ./inference_mlperf4.1/loadgen +RUN pip install \ + transformers==4.31.0 \ + nltk==3.8.1 \ + evaluate==0.4.0 \ + absl-py==1.4.0 \ + rouge-score==0.1.2 \ + sentencepiece==0.1.99 \ + accelerate==0.21.0 + ENTRYPOINT ["bash"] From 0db39c290b53e8c00bd8d17c6b394bbd6e63c2b0 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 8 Jul 2024 10:50:49 -0700 Subject: [PATCH 15/42] Bump certifi from 2024.2.2 to 2024.7.4 in the pip group (#110) Bumps the pip group with 1 update: [certifi](https://github.com/certifi/python-certifi). Updates `certifi` from 2024.2.2 to 2024.7.4 - [Commits](https://github.com/certifi/python-certifi/compare/2024.02.02...2024.07.04) --- updated-dependencies: - dependency-name: certifi dependency-type: indirect dependency-group: pip ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 24e162fc..3eea4098 100644 --- a/requirements.txt +++ b/requirements.txt @@ -27,7 +27,7 @@ blobfile==2.1.1 # via -r requirements.in cachetools==5.3.2 # via google-auth -certifi==2024.2.2 +certifi==2024.7.4 # via requests charset-normalizer==3.3.2 # via requests From 166dcd1d7c84213620355ed61977f68eb6e350c6 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 9 Jul 2024 13:45:22 -0700 Subject: [PATCH 16/42] Bump zipp from 3.17.0 to 3.19.1 in the pip group (#111) Bumps the pip group with 1 update: [zipp](https://github.com/jaraco/zipp). Updates `zipp` from 3.17.0 to 3.19.1 - [Release notes](https://github.com/jaraco/zipp/releases) - [Changelog](https://github.com/jaraco/zipp/blob/main/NEWS.rst) - [Commits](https://github.com/jaraco/zipp/compare/v3.17.0...v3.19.1) --- updated-dependencies: - dependency-name: zipp dependency-type: indirect dependency-group: pip ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 3eea4098..6bce9a98 100644 --- a/requirements.txt +++ b/requirements.txt @@ -319,7 +319,7 @@ wrapt==1.16.0 # clu # tensorflow # tfds-nightly -zipp==3.17.0 +zipp==3.19.1 # via etils # The following packages are considered to be unsafe in a requirements file: From 196bedafd45e574e0c280da043df91993268fe22 Mon Sep 17 00:00:00 2001 From: vivianrwu Date: Thu, 11 Jul 2024 12:19:01 -0700 Subject: [PATCH 17/42] Model warmup support with AOT and endpoint for JetStream (#92) * initial setup for model warmup support * add engine api variables for prefill and insert * Add AOT warmup fixes * fix jetstream_pb2 spelling * fix jetstream_pb2 spelling * remove references to history * reformat files with pyink * fix typo in modelwarmuprequest * remove absl logging * refactor model warmup outside of orchestrator * fix stub to use utilities in test_server * retrigger checks * resolve libraries * fix pylint * fix pylint * Refactor model warmup engines and bake logic into server start * refactor warmup logic even more * fix pytype and modelwarmup func * remove warmup _enabled from unit test * add back in list * Add model warmup into server_lib run * Revert "Add model warmup into server_lib run" This reverts commit 2ffe7616aafcfa1dcabc2f6ca8f5baefdcb7674d. * Add model warmup into server_lib run * import engine _api * fix pylint issues * Rename instances of WarmedUpEngine to JetStreamEngine --- jetstream/core/orchestrator.py | 14 ++ jetstream/core/server_lib.py | 48 ++++- jetstream/engine/aot_utils.py | 260 ++++++++++++++++++++++++++++ jetstream/engine/engine_api.py | 106 ++++++++++++ jetstream/engine/token_utils.py | 30 ++-- jetstream/tests/core/test_server.py | 36 ++++ 6 files changed, 476 insertions(+), 18 deletions(-) create mode 100644 jetstream/engine/aot_utils.py diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index a9ea2444..d8fa9edd 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -135,6 +135,7 @@ class ActiveRequest: #################### Information relevant for prefill ######################## history_path: Optional[str] = None prefill_content: Optional[str | list[int]] = None + padded_token_length: Optional[int] = None ################## Information relevant for detokenization ################### # Which generate step this was added at. generate_timestep_added: Optional[int] = None @@ -503,12 +504,19 @@ def _prefill_thread(self, idx: int): padded_tokens, true_length = self._process_prefill_content( request, tokenizer, is_bos, prefill_engine.max_prefill_length ) + if isinstance(prefill_engine, engine_api.JetStreamEngine): + request.padded_token_length = token_utils.take_nearest_length( + prefill_engine.prefill_buckets, true_length + ) + prefill_engine.set_padded_token_length(request.padded_token_length) + # Compute new kv cache for the prefill_content. prefill_result, first_token = prefill_engine.prefill( params=prefill_params, padded_tokens=padded_tokens, true_length=true_length, ) + request.prefill_result = prefill_result # put first token to detokenize queue @@ -671,6 +679,12 @@ def _generate_thread(self, idx: int): slot, generate_timestep, ) + + if isinstance(generate_engine, engine_api.JetStreamEngine): + generate_engine.set_padded_token_length( + new_request.padded_token_length + ) + decode_state = generate_engine.insert( new_request.prefill_result, decode_state, slot=slot ) diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index 4ea65160..9c1c5986 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -20,15 +20,20 @@ import asyncio from concurrent import futures import logging +import os +import signal import threading +import traceback from typing import Any, Type + import grpc import jax from jetstream.core import config_lib from jetstream.core import orchestrator from jetstream.core.metrics.prometheus import JetstreamMetricsCollector from jetstream.core.proto import jetstream_pb2_grpc +from jetstream.engine import aot_utils, engine_api from prometheus_client import start_http_server @@ -97,6 +102,7 @@ def run( metrics_server_config: config_lib.MetricsServerConfig | None = None, enable_jax_profiler: bool = False, jax_profiler_port: int = 9999, + enable_model_warmup: bool = False, ) -> JetStreamServer: """Runs a server with a specified config. @@ -111,6 +117,7 @@ def run( metrics_server_config: The config to enable Promethus metric server. enable_jax_profiler: The flag to enable JAX profiler server. jax_profiler_port: The port JAX profiler server (default to 9999). + enable_model_warmup: The flag to enable model server warmup with AOT. Returns: JetStreamServer that wraps the grpc server and orchestrator driver. @@ -138,11 +145,44 @@ def run( "Not starting Prometheus server: --prometheus_port flag not set" ) + prefill_engines = engines.prefill_engines + engines.interleaved_engines + generate_engines = engines.generate_engines + engines.interleaved_engines + prefill_params = prefill_params + shared_params + generate_params = generate_params + shared_params + + if prefill_engines is None: + prefill_engines = [] + if generate_engines is None: + generate_engines = [] + if prefill_params is None: + prefill_params = [] + if generate_params is None: + generate_params = [] + + if enable_model_warmup: + prefill_engines = [engine_api.JetStreamEngine(pe) for pe in prefill_engines] + generate_engines = [ + engine_api.JetStreamEngine(ge) for ge in generate_engines + ] + + try: + _ = aot_utils.layout_params_and_compile_executables( + prefill_engines, # pylint: disable=protected-access + generate_engines, # pylint: disable=protected-access + prefill_params, # pylint: disable=protected-access + generate_params, # pylint: disable=protected-access + ) + + except ValueError as e: + print(f"Model warmup encountered an error: {e}") + traceback.print_exc() + os.kill(os.getpid(), signal.SIGKILL) + driver = orchestrator.Driver( - prefill_engines=engines.prefill_engines + engines.interleaved_engines, - generate_engines=engines.generate_engines + engines.interleaved_engines, - prefill_params=prefill_params + shared_params, - generate_params=generate_params + shared_params, + prefill_engines=prefill_engines, + generate_engines=generate_engines, + prefill_params=prefill_params, + generate_params=generate_params, interleaved_mode=interleaved_mode, jax_padding=jax_padding, metrics_collector=metrics_collector, diff --git a/jetstream/engine/aot_utils.py b/jetstream/engine/aot_utils.py new file mode 100644 index 00000000..65b61f87 --- /dev/null +++ b/jetstream/engine/aot_utils.py @@ -0,0 +1,260 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""AOT compilation utils.""" + +import jax +import jax.numpy as jnp +import concurrent.futures +from typing import Any, Optional, cast +import logging +from jetstream.engine import engine_api, token_utils + + +def layout_params_and_compile_executables( + prefill_engines: Optional[list[engine_api.JetStreamEngine]] = None, + generate_engines: Optional[list[engine_api.JetStreamEngine]] = None, + prefill_params: Optional[list[Any]] = None, + generate_params: Optional[list[Any]] = None, +) -> bool: + """Organizes the engines and executables. + + Args: + prefill_engines: Prefill only engines. + generate_engines: Generate only engines. + prefill_params: Prefill only params. + generate_params: Generate only params. + """ + prefill_engines = prefill_engines if prefill_engines else [] + generate_engines = generate_engines if generate_engines else [] + prefill_params = prefill_params if prefill_params else [] + generate_params = generate_params if generate_params else [] + + any_prefill_engine = None + any_prefill_params = None + + prefill_executables = [] + inserts_generate_executables = [] + + for i, pe in enumerate(prefill_engines): + any_prefill_engine = pe + any_prefill_params = prefill_params[i] + prefill_executable = initialize_prefill_jit_cache( + prefill_engine=pe, + prefill_params=prefill_params[i], + prefill_idx=i, + ) + prefill_executables.append(prefill_executable) + + for i, ge in enumerate(generate_engines): + insert_executable, generate_executable = ( + initialize_insert_generate_jit_cache( + prefill_engine=any_prefill_engine, + generate_engine=ge, + prefill_params=any_prefill_params, + generate_params=generate_params[i], + generate_idx=i, + ) + ) + inserts_generate_executables.append( + [insert_executable, generate_executable] + ) + + if prefill_executables and inserts_generate_executables: + return True + return False + + +def initialize_prefill_jit_cache( + *, + prefill_engine: engine_api.JetStreamEngine, + prefill_params: Any, + prefill_idx: int, +): + """Precompile all prefill functions in parallel. + If we don't do this, then when a new request triggers a new prefill bucket it + will take a very long time for that query to come back. + + Args: + prefill_engine: A prefill engine to be compiled for. + prefill_params: The associated prefill parameters. + prefill_idx: Which prefill engine it is. + """ + prefill_buckets = token_utils.DEFAULT_PREFILL_BUCKETS + prefill_buckets = [ + bucket + for bucket in prefill_buckets + if bucket <= prefill_engine.max_prefill_length + ] + prefill_engine.prefill_buckets = prefill_buckets + if prefill_engine.max_prefill_length not in prefill_buckets: + prefill_buckets.append(prefill_engine.max_prefill_length) + + def compile_prefill(length): + padded_tokens, true_length = jnp.ones((length), dtype="int32"), length + + lowered = jax.jit( + prefill_engine._downstream_engine.prefill, # pylint: disable=protected-access + out_shardings=prefill_engine.get_prefix_destination_sharding(), + ).lower( + params=prefill_params, + padded_tokens=padded_tokens, + true_length=true_length, + ) + logging.info( + "---------Prefill engine %d lowered for prefill length %d.---------", + prefill_idx, + length, + ) + compiled = lowered.compile() + logging.info( + "---------Prefill engine %d compiled for prefill length %d.---------", + prefill_idx, + length, + ) + return compiled + + logging.info("---------Prefill compilation %d begun.---------", prefill_idx) + + with concurrent.futures.ThreadPoolExecutor( + max_workers=len(prefill_buckets) + ) as executor: + prefill_executable = list(executor.map(compile_prefill, prefill_buckets)) + + prefill_executable = { + k: cast(jax.stages.Compiled, e) + for k, e in zip(prefill_buckets, prefill_executable) + } + + prefill_engine.prefill_executable = prefill_executable + prefill_engine.warm = True + + logging.info( + "---------Prefill compilation %d complete.---------", prefill_idx + ) + + return prefill_executable + + +def initialize_insert_generate_jit_cache( + *, + prefill_engine: engine_api.JetStreamEngine, + generate_engine: engine_api.JetStreamEngine, + prefill_params: Any, + generate_params: Any, + generate_idx: int, +): + """Initialiszes jit cache for insert and generate. + + Args: + generate_engine: A generate engine to be compiled for. + generate_params: The associated parameters. + generate_idx: Which generate engine it is. + """ + + prefill_buckets = token_utils.DEFAULT_PREFILL_BUCKETS + prefill_buckets = [ + bucket + for bucket in prefill_buckets + if bucket <= generate_engine.max_prefill_length + ] + generate_engine.prefill_buckets = prefill_buckets + if generate_engine.max_prefill_length not in prefill_buckets: + prefill_buckets.append(generate_engine.max_prefill_length) + + decode_state = generate_engine.init_decode_state() + + def compile_insert(length): + padded_tokens, true_length = jnp.ones((length), dtype="int32"), length + + prefill, _ = prefill_engine._downstream_engine.prefill( # pylint: disable=protected-access + params=prefill_params, + padded_tokens=padded_tokens, + true_length=true_length, + ) + + lowered = jax.jit(generate_engine._downstream_engine.insert).lower( # pylint: disable=protected-access + prefix=prefill, decode_state=decode_state, slot=1 + ) + logging.info( + "---------Generate engine %d lowered for insert length %d.---------", + generate_idx, + length, + ) + compiled = lowered.compile() + + logging.info( + "---------Generate engine %d compiled for insert length %d.---------", + generate_idx, + length, + ) + return compiled + + def compile_generate(): + + logging.info( + "---------Generate compilation %d begun.---------", generate_idx + ) + + lowered = jax.jit(generate_engine._downstream_engine.generate).lower( # pylint: disable=protected-access + params=generate_params, + decode_state=decode_state, + ) + logging.info( + "---------Generate engine %d lowered.---------", + generate_idx, + ) + + compiled = lowered.compile() + logging.info( + "---------Generate engine %d compiled.---------", + generate_idx, + ) + + logging.info( + "---------Generate compilation %d complete.---------", generate_idx + ) + + return compiled + + logging.info( + "---------Insertion generation compilation %d begun.---------", + generate_idx, + ) + + generate_executable = compile_generate() + logging.info( + "---------Generate engine %d compiled generation step.---------", + generate_idx, + ) + generate_engine.generate_executable = generate_executable + + with concurrent.futures.ThreadPoolExecutor( + max_workers=len(prefill_buckets) + ) as executor: + insert_executable = list(executor.map(compile_insert, prefill_buckets)) + + insert_executable = { + k: cast(jax.stages.Compiled, e) + for k, e in zip(prefill_buckets, insert_executable) + } + generate_engine.insert_executable = insert_executable + generate_engine.warm = True + + logging.info( + "---------Insertion generation compilation %d complete.---------", + generate_idx, + ) + + return insert_executable, generate_executable diff --git a/jetstream/engine/engine_api.py b/jetstream/engine/engine_api.py index c971d30c..f501b51c 100644 --- a/jetstream/engine/engine_api.py +++ b/jetstream/engine/engine_api.py @@ -240,3 +240,109 @@ def mesh(self) -> jax.sharding.Mesh: @abc.abstractmethod def colocated_cpus(self) -> Union[list[CpuDevices], None]: """CPU devices colocated with the engine's accelerators.""" + + +class JetStreamEngine(Engine): + """A wrapper engine of the Engine class. + + JetStreamEngine defines the AOT warmed up model server engine. + """ + + def __init__(self, downstream_engine: Engine): + self._downstream_engine = downstream_engine + + # Executables + self.prefill_executable = None + self.insert_executable = None + self.generate_executable = None + + self.prefill_buckets = None + + # Nearest right token length + self._padded_token_length = None + + self.warm = False + + def prefill( + self, + *, + params: Params, + existing_prefix: Optional[Prefix] = None, + padded_tokens: jax.Array, + true_length: int, + ) -> Tuple[Prefix, ResultTokens]: + + prefill_result, first_token = self.prefill_executable[ + self.padded_token_length + ]( + params=params, + padded_tokens=padded_tokens, + true_length=true_length, + ) + return prefill_result, first_token + + def insert( + self, + prefix: Prefix, + decode_state: DecodeState, + slot: int, + ) -> DecodeState: + + decode_state = self.insert_executable[self.padded_token_length]( + prefix=prefix, + decode_state=decode_state, + slot=slot, + ) + return decode_state + + def generate( + self, params: Params, decode_state: DecodeState + ) -> Tuple[DecodeState, ResultTokens]: + decode_state, sampled_tokens = self.generate_executable( # pylint: disable=not-callable + params=params, decode_state=decode_state + ) + return decode_state, sampled_tokens + + def load_params(self, *args, **kwargs) -> Params: + return self._downstream_engine.load_params(*args, **kwargs) + + def get_prefix_destination_sharding(self) -> Any: + return self._downstream_engine.get_prefix_destination_sharding() + + def get_tokenizer( + self, + ) -> tokenizer_pb2.TokenizerParameters: + return self._downstream_engine.get_tokenizer() + + def build_tokenizer( + self, + metadata: tokenizer_pb2.TokenizerParameters, + ) -> Tokenizer: + """Builds a new tokenizer object and returns it.""" + return self._downstream_engine.build_tokenizer(metadata) + + def init_decode_state(self, *args, **kwargs) -> DecodeState: + return self._downstream_engine.init_decode_state(*args, **kwargs) + + @property + def max_concurrent_decodes(self) -> int: + return self._downstream_engine.max_concurrent_decodes + + @property + def samples_per_slot(self) -> int: + return self._downstream_engine.samples_per_slot + + @property + def max_prefill_length(self) -> int: + return self._downstream_engine.max_prefill_length + + @property + def mesh(self) -> jax.sharding.Mesh: + return self._downstream_engine.mesh + + @property + def colocated_cpus(self) -> Union[list[CpuDevices], None]: + return self._downstream_engine.colocated_cpus + + def set_padded_token_length(self, padded_token_length: int): + self.padded_token_length = padded_token_length diff --git a/jetstream/engine/token_utils.py b/jetstream/engine/token_utils.py index 0fa11ea4..d6b50d29 100644 --- a/jetstream/engine/token_utils.py +++ b/jetstream/engine/token_utils.py @@ -33,6 +33,21 @@ # ResultToken class to store tokens ids. ResultTokens = Any +DEFAULT_PREFILL_BUCKETS = [ + 16, + 32, + 64, + 128, + 256, + 512, + 1024, + 2048, + 4096, + 8192, + 16384, + 32768, +] + def take_nearest_length(lengths: list[int], length: int) -> int: """Gets the nearest length to the right in a set of lengths.""" @@ -109,20 +124,7 @@ def pad_tokens( true_length: Actual length of the non-padded sequence. """ if prefill_lengths is None: - prefill_lengths = [ - 16, - 32, - 64, - 128, - 256, - 512, - 1024, - 2048, - 4096, - 8192, - 16384, - 32768, - ] + prefill_lengths = DEFAULT_PREFILL_BUCKETS if max_prefill_length is not None: prefill_lengths = prefill_lengths[ : prefill_lengths.index(max_prefill_length) diff --git a/jetstream/tests/core/test_server.py b/jetstream/tests/core/test_server.py index 150ac39d..731a72b5 100644 --- a/jetstream/tests/core/test_server.py +++ b/jetstream/tests/core/test_server.py @@ -29,6 +29,7 @@ from jetstream.core import server_lib from jetstream.core.proto import jetstream_pb2 from jetstream.core.proto import jetstream_pb2_grpc +from jetstream.engine import engine_api import portpicker @@ -149,3 +150,38 @@ def test_jax_profiler_server(self): def test_get_devices(self): assert len(server_lib.get_devices()) == 1 + + async def test_model_warmup(self): + port = portpicker.pick_unused_port() + + print("port: " + str(port)) + credentials = grpc.local_server_credentials() + + server = server_lib.run( + port=port, + config=config_lib.InterleavedCPUTestServer, + devices=[None], + credentials=credentials, + enable_model_warmup=True, + ) + + async with grpc.aio.secure_channel( + f"localhost:{port}", grpc.local_channel_credentials() + ) as channel: + stub = jetstream_pb2_grpc.OrchestratorStub(channel) + + healthcheck_request = jetstream_pb2.HealthCheckRequest() + healthcheck_response = stub.HealthCheck(healthcheck_request) + healthcheck_response = await healthcheck_response + + assert healthcheck_response.is_live is True + + for pe in server._driver._prefill_engines: # pylint: disable=protected-access + assert isinstance(pe, engine_api.JetStreamEngine) + assert pe.warm is True + + for ge in server._driver._generate_engines: # pylint: disable=protected-access + assert isinstance(ge, engine_api.JetStreamEngine) + assert ge.warm is True + + server.stop() From 46c152ff1659bf6db857fe96f7ecf4267945508d Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Tue, 16 Jul 2024 15:08:52 -0700 Subject: [PATCH 18/42] Cleanup orchestrator proto (#112) * Cleanup orchestrator proto * Update JetStream based on proto cleanup --- benchmarks/benchmark_serving.py | 30 ----------------- jetstream/core/orchestrator.py | 7 ++-- jetstream/core/proto/jetstream.proto | 5 +-- jetstream/core/proto/jetstream_pb2.py | 40 +++++++++++------------ jetstream/tests/core/test_orchestrator.py | 4 --- jetstream/tests/core/test_server.py | 2 -- jetstream/tools/load_tester.py | 3 -- jetstream/tools/requester.py | 8 ----- 8 files changed, 23 insertions(+), 76 deletions(-) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 07b36a84..f03fac91 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -426,18 +426,14 @@ async def send_request( tokenizer: Any, input_request: InputRequest, pbar: tqdm, - session_cache: str, - priority: int, ) -> RequestFuncOutput: """Send the request to JetStream server.""" # Tokenization on client side following MLPerf standard. token_ids = tokenizer.encode(input_request.prompt) request = jetstream_pb2.DecodeRequest( - session_cache=session_cache, token_content=jetstream_pb2.DecodeRequest.TokenContent( token_ids=token_ids ), - priority=priority, max_tokens=input_request.output_len, ) output = RequestFuncOutput() @@ -463,8 +459,6 @@ async def benchmark( input_requests: list[InputRequest], request_rate: float, disable_tqdm: bool, - session_cache: str, - priority: int, ): """Benchmark the online serving performance.""" pbar = None if disable_tqdm else tqdm(total=len(input_requests)) @@ -481,8 +475,6 @@ async def benchmark( tokenizer=tokenizer, input_request=request, pbar=pbar, - session_cache=session_cache, - priority=priority, ) ) ) @@ -614,8 +606,6 @@ def main(args: argparse.Namespace): input_requests=warmup_requests, request_rate=args.request_rate, disable_tqdm=args.disable_tqdm, - session_cache=args.session_cache, - priority=args.priority, ) ) print(f"{args.warmup_mode} warmup completed.") @@ -631,8 +621,6 @@ def main(args: argparse.Namespace): input_requests=input_requests, request_rate=args.request_rate, disable_tqdm=args.disable_tqdm, - session_cache=args.session_cache, - priority=args.priority, ) ) @@ -790,24 +778,6 @@ def main(args: argparse.Namespace): " the form of a string." ), ) - parser.add_argument( - "--priority", - type=int, - default=0, - help=( - "Message priority. (currently no business logic implemented, use" - " default 0)" - ), - ) - parser.add_argument( - "--session-cache", - type=str, - default="", - help=( - "Location of any pre-cached results. (currently _load_cache_history" - " not implemented, use default empty str)" - ), - ) parser.add_argument( "--save-request-outputs", action="store_true", diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index d8fa9edd..23ca365f 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -133,7 +133,6 @@ class ActiveRequest: complete: Optional[np.ndarray] = None prefill_result: Any = None #################### Information relevant for prefill ######################## - history_path: Optional[str] = None prefill_content: Optional[str | list[int]] = None padded_token_length: Optional[int] = None ################## Information relevant for detokenization ################### @@ -491,14 +490,13 @@ def _prefill_thread(self, idx: int): if request is None: break - is_bos = not bool(request.history_path) + is_bos = True logging.info( "Prefilling on prefill engine %d : prefill queue size, %d," - " is_bos: %s, history: %s", + " is_bos: %s", idx, self._prefill_backlog.qsize(), is_bos, - request.history_path, ) # Tokenize and padding the text or token input. padded_tokens, true_length = self._process_prefill_content( @@ -895,7 +893,6 @@ async def Decode( # pylint: disable=invalid-overridden-method # Wrap request as an ActiveRequest. active_request = ActiveRequest( max_tokens=request.max_tokens, - history_path=request.session_cache, prefill_content=prefill_content, is_client_side_tokenization=is_client_side_tokenization, return_channel=return_channel, diff --git a/jetstream/core/proto/jetstream.proto b/jetstream/core/proto/jetstream.proto index 5f2e8869..9fc7076f 100644 --- a/jetstream/core/proto/jetstream.proto +++ b/jetstream/core/proto/jetstream.proto @@ -26,9 +26,6 @@ service Orchestrator { } message DecodeRequest { - // Where to load any pre-existing kv cache from. - string session_cache = 1; - int32 priority = 3; // The maximum output length of a sequence. It's used in JetStream to control // the output/decode length of a sequence. It would not be used in the engine. // We should always set max_tokens <= (max_target_length - @@ -51,7 +48,7 @@ message DecodeRequest { TextContent text_content = 5; TokenContent token_content = 6; } - reserved 2; + reserved 1, 2, 3; // Next ID: 7 } diff --git a/jetstream/core/proto/jetstream_pb2.py b/jetstream/core/proto/jetstream_pb2.py index 3fadd54c..07a5f313 100644 --- a/jetstream/core/proto/jetstream_pb2.py +++ b/jetstream/core/proto/jetstream_pb2.py @@ -28,7 +28,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"\xa7\x02\n\rDecodeRequest\x12\x15\n\rsession_cache\x18\x01 \x01(\t\x12\x10\n\x08priority\x18\x03 \x01(\x05\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x02\x10\x03"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02"\x14\n\x12HealthCheckRequest"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 \x01(\x08\x32\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse"\x00\x62\x06proto3' + b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"\x8a\x02\n\rDecodeRequest\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02"\x14\n\x12HealthCheckRequest"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 \x01(\x08\x32\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse"\x00\x62\x06proto3' ) _globals = globals() @@ -39,23 +39,23 @@ if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None _globals["_DECODEREQUEST"]._serialized_start = 58 - _globals["_DECODEREQUEST"]._serialized_end = 353 - _globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_start = 274 - _globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_end = 301 - _globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_start = 303 - _globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_end = 336 - _globals["_DECODERESPONSE"]._serialized_start = 356 - _globals["_DECODERESPONSE"]._serialized_end = 687 - _globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_start = 522 - _globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_end = 538 - _globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_start = 541 - _globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_end = 670 - _globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_start = 629 - _globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_end = 670 - _globals["_HEALTHCHECKREQUEST"]._serialized_start = 689 - _globals["_HEALTHCHECKREQUEST"]._serialized_end = 709 - _globals["_HEALTHCHECKRESPONSE"]._serialized_start = 711 - _globals["_HEALTHCHECKRESPONSE"]._serialized_end = 749 - _globals["_ORCHESTRATOR"]._serialized_start = 752 - _globals["_ORCHESTRATOR"]._serialized_end = 937 + _globals["_DECODEREQUEST"]._serialized_end = 324 + _globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_start = 233 + _globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_end = 260 + _globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_start = 262 + _globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_end = 295 + _globals["_DECODERESPONSE"]._serialized_start = 327 + _globals["_DECODERESPONSE"]._serialized_end = 658 + _globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_start = 493 + _globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_end = 509 + _globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_start = 512 + _globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_end = 641 + _globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_start = 600 + _globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_end = 641 + _globals["_HEALTHCHECKREQUEST"]._serialized_start = 660 + _globals["_HEALTHCHECKREQUEST"]._serialized_end = 680 + _globals["_HEALTHCHECKRESPONSE"]._serialized_start = 682 + _globals["_HEALTHCHECKRESPONSE"]._serialized_end = 720 + _globals["_ORCHESTRATOR"]._serialized_start = 723 + _globals["_ORCHESTRATOR"]._serialized_end = 908 # @@protoc_insertion_point(module_scope) diff --git a/jetstream/tests/core/test_orchestrator.py b/jetstream/tests/core/test_orchestrator.py index 49494bef..00e2e1c1 100644 --- a/jetstream/tests/core/test_orchestrator.py +++ b/jetstream/tests/core/test_orchestrator.py @@ -78,9 +78,7 @@ async def test_orchestrator_interleaved_mode(self): text = "AB" request = jetstream_pb2.DecodeRequest( - session_cache="", text_content=jetstream_pb2.DecodeRequest.TextContent(text=text), - priority=1, max_tokens=3, ) iterator = client.Decode(request) @@ -109,11 +107,9 @@ async def test_orchestrator_interleaved_mode_client_tokenization(self): token_ids = [65, 66] request = jetstream_pb2.DecodeRequest( - session_cache="", token_content=jetstream_pb2.DecodeRequest.TokenContent( token_ids=token_ids ), - priority=1, max_tokens=3, ) iterator = client.Decode(request) diff --git a/jetstream/tests/core/test_server.py b/jetstream/tests/core/test_server.py index 731a72b5..9114f2fd 100644 --- a/jetstream/tests/core/test_server.py +++ b/jetstream/tests/core/test_server.py @@ -93,9 +93,7 @@ async def test_server( # as BOS text = "AB" request = jetstream_pb2.DecodeRequest( - session_cache="", text_content=jetstream_pb2.DecodeRequest.TextContent(text=text), - priority=1, max_tokens=3, ) iterator = stub.Decode(request) diff --git a/jetstream/tools/load_tester.py b/jetstream/tools/load_tester.py index 4d6445be..5f791efd 100644 --- a/jetstream/tools/load_tester.py +++ b/jetstream/tools/load_tester.py @@ -50,14 +50,11 @@ def api_call( stub: jetstream_pb2_grpc.OrchestratorStub, text: str, max_tokens: int, - session_cache: str = "", print_interim: bool = True, ) -> str: """Sends a request to server and returns text.""" request = jetstream_pb2.DecodeRequest( - session_cache=session_cache, text_content=jetstream_pb2.DecodeRequest.TextContent(text=text), - priority=1, max_tokens=max_tokens, ) response = stub.Decode(request) diff --git a/jetstream/tools/requester.py b/jetstream/tools/requester.py index 8fcde556..30d7ac40 100644 --- a/jetstream/tools/requester.py +++ b/jetstream/tools/requester.py @@ -26,11 +26,7 @@ _SERVER = flags.DEFINE_string("server", "0.0.0.0", "server address") _PORT = flags.DEFINE_string("port", "9000", "port to ping") -_SESSION_CACHE = flags.DEFINE_string( - "session_cache", "", "Location of any pre-cached results" -) _TEXT = flags.DEFINE_string("text", "Today is a good day", "The message") -_PRIORITY = flags.DEFINE_integer("priority", 0, "Message priority") _MAX_TOKENS = flags.DEFINE_integer( "max_tokens", 3, "Maximum number of output/decode tokens of a sequence" ) @@ -82,20 +78,16 @@ def main(argv: Sequence[str]) -> None: vocab = load_vocab(_TOKENIZER.value) token_ids = vocab.tokenizer.encode(_TEXT.value) request = jetstream_pb2.DecodeRequest( - session_cache=_SESSION_CACHE.value, token_content=jetstream_pb2.DecodeRequest.TokenContent( token_ids=token_ids ), - priority=_PRIORITY.value, max_tokens=_MAX_TOKENS.value, ) else: request = jetstream_pb2.DecodeRequest( - session_cache=_SESSION_CACHE.value, text_content=jetstream_pb2.DecodeRequest.TextContent( text=_TEXT.value ), - priority=_PRIORITY.value, max_tokens=_MAX_TOKENS.value, ) return _GetResponseAsync(stub, request) From 8060d05a874c116173c02cd20bb03891f0ca14a4 Mon Sep 17 00:00:00 2001 From: Morgan Du Date: Tue, 16 Jul 2024 16:21:12 -0700 Subject: [PATCH 19/42] update images (#113) --- jetstream/tools/proxy_dev/base.Dockerfile | 14 +++++++++++++- jetstream/tools/proxy_dev/dev.Dockerfile | 13 +------------ 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/jetstream/tools/proxy_dev/base.Dockerfile b/jetstream/tools/proxy_dev/base.Dockerfile index 911acb33..092f5ebb 100644 --- a/jetstream/tools/proxy_dev/base.Dockerfile +++ b/jetstream/tools/proxy_dev/base.Dockerfile @@ -18,8 +18,20 @@ COPY maxtext ./maxtext RUN cd maxtext/ && \ pip install -r requirements.txt -RUN pip install setuptools==58 fastapi==0.103.2 uvicorn nltk evaluate +RUN pip install setuptools==58 fastapi==0.103.2 uvicorn RUN pip install ./JetStream +COPY inference_mlperf4.1 ./inference_mlperf4.1 +RUN apt-get -y install python3-dev && apt-get -y install build-essential +RUN pip install ./inference_mlperf4.1/loadgen +RUN pip install \ + transformers==4.31.0 \ + nltk==3.8.1 \ + evaluate==0.4.0 \ + absl-py==1.4.0 \ + rouge-score==0.1.2 \ + sentencepiece==0.1.99 \ + accelerate==0.21.0 + ENTRYPOINT ["bash"] diff --git a/jetstream/tools/proxy_dev/dev.Dockerfile b/jetstream/tools/proxy_dev/dev.Dockerfile index a59ff904..25bf382e 100644 --- a/jetstream/tools/proxy_dev/dev.Dockerfile +++ b/jetstream/tools/proxy_dev/dev.Dockerfile @@ -13,17 +13,6 @@ COPY JetStream ./JetStream COPY maxtext ./maxtext RUN pip install ./JetStream - -COPY inference_mlperf4.1 ./inference_mlperf4.1 -RUN apt-get -y install python3-dev && apt-get -y install build-essential -RUN pip install ./inference_mlperf4.1/loadgen -RUN pip install \ - transformers==4.31.0 \ - nltk==3.8.1 \ - evaluate==0.4.0 \ - absl-py==1.4.0 \ - rouge-score==0.1.2 \ - sentencepiece==0.1.99 \ - accelerate==0.21.0 +RUN pip install -r ./maxtext/requirements.txt ENTRYPOINT ["bash"] From 2b712af808a14b5df50f61f0b0c0e39f48f297fa Mon Sep 17 00:00:00 2001 From: Morgan Du Date: Tue, 16 Jul 2024 17:14:05 -0700 Subject: [PATCH 20/42] fix (#114) --- jetstream/tools/proxy_dev/base.Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jetstream/tools/proxy_dev/base.Dockerfile b/jetstream/tools/proxy_dev/base.Dockerfile index 092f5ebb..5e4cd2e4 100644 --- a/jetstream/tools/proxy_dev/base.Dockerfile +++ b/jetstream/tools/proxy_dev/base.Dockerfile @@ -23,7 +23,7 @@ RUN pip install setuptools==58 fastapi==0.103.2 uvicorn RUN pip install ./JetStream COPY inference_mlperf4.1 ./inference_mlperf4.1 -RUN apt-get -y install python3-dev && apt-get -y install build-essential +RUN apt -y update && apt-get -y install python3-dev && apt-get -y install build-essential RUN pip install ./inference_mlperf4.1/loadgen RUN pip install \ transformers==4.31.0 \ From c88d14d3896113aea93abe2c62d03741862fc999 Mon Sep 17 00:00:00 2001 From: Morgan Du Date: Thu, 18 Jul 2024 16:16:00 -0700 Subject: [PATCH 21/42] del prefill_result & update dev image (#116) * update dev image * add space * remove component that causes low tpu duty cycle on multi-host --- jetstream/core/orchestrator.py | 2 +- jetstream/tools/proxy_dev/dev.Dockerfile | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index 23ca365f..29f9222e 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -686,7 +686,7 @@ def _generate_thread(self, idx: int): decode_state = generate_engine.insert( new_request.prefill_result, decode_state, slot=slot ) - delete_pytree(new_request.prefill_result) + del new_request.prefill_result new_request.generate_timestep_added = generate_timestep new_request.complete = np.zeros( (generate_engine.samples_per_slot,), dtype=np.bool_ diff --git a/jetstream/tools/proxy_dev/dev.Dockerfile b/jetstream/tools/proxy_dev/dev.Dockerfile index 25bf382e..be7a36fc 100644 --- a/jetstream/tools/proxy_dev/dev.Dockerfile +++ b/jetstream/tools/proxy_dev/dev.Dockerfile @@ -11,6 +11,7 @@ ENV JAX_BACKEND_TARGET=grpc://localhost:38681 # Copy all files from local workspace into docker container COPY JetStream ./JetStream COPY maxtext ./maxtext +COPY inference_mlperf4.1 ./inference_mlperf4.1 RUN pip install ./JetStream RUN pip install -r ./maxtext/requirements.txt From 6ec67e49c733faf85f9c5e298560206e3a56d731 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Fri, 19 Jul 2024 11:27:03 -0700 Subject: [PATCH 22/42] fix (#117) --- benchmarks/benchmark_serving.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index f03fac91..6076beba 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -641,10 +641,11 @@ def main(args: argparse.Namespace): dimensions_json["date"] = current_dt dimensions_json["model_id"] = model_id dimensions_json["tokenizer_id"] = tokenizer_id - dimensions_json = { - **dimensions_json, - **json.loads(args.additional_metadata_metrics_to_save), - } + if args.additional_metadata_metrics_to_save is not None: + dimensions_json = { + **dimensions_json, + **json.loads(args.additional_metadata_metrics_to_save), + } metrics_json["num_prompts"] = args.num_prompts # Traffic From bd6d013b4faa4d83484ed83bc41695c11f1a1ed5 Mon Sep 17 00:00:00 2001 From: Brendan Slabe Date: Mon, 22 Jul 2024 21:34:44 +0200 Subject: [PATCH 23/42] Add `jetstream_server_startup_latency` metric (#118) * first commit * no labels on metric * format * change measurement * fmt * rename metric * Time -> time * nits * fixed args * int -> float * int -> float * move endpoint to server_lib.py * nit * missing labels --- jetstream/core/metrics/prometheus.py | 8 ++++++++ jetstream/core/server_lib.py | 9 +++++++++ 2 files changed, 17 insertions(+) diff --git a/jetstream/core/metrics/prometheus.py b/jetstream/core/metrics/prometheus.py index e84a0905..7363297d 100644 --- a/jetstream/core/metrics/prometheus.py +++ b/jetstream/core/metrics/prometheus.py @@ -50,6 +50,11 @@ def __new__(cls): documentation="The percentage of decode slots currently being used", labelnames=["id", "idx"], ) + _server_startup_latency = Gauge( + name="jetstream_server_startup_latency", + documentation="Total time taken to start the Jetstream server", + labelnames=["id"], + ) def get_prefill_backlog_metric(self): return self._prefill_backlog.labels(id=self._id) @@ -62,3 +67,6 @@ def get_generate_backlog_metric(self, idx: int): def get_slots_used_percentage_metric(self, idx: int): return self._slots_used_percentage.labels(id=self._id, idx=idx) + + def get_server_startup_latency_metric(self): + return self._server_startup_latency.labels(id=self._id) diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index 9c1c5986..24f8506c 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -23,6 +23,7 @@ import os import signal import threading +import time import traceback from typing import Any, Type @@ -122,6 +123,9 @@ def run( Returns: JetStreamServer that wraps the grpc server and orchestrator driver. """ + + server_start_time = time.time() + logging.info("Kicking off gRPC server.") engines = config_lib.get_engines(config, devices=devices) prefill_params = [pe.load_params() for pe in engines.prefill_engines] @@ -196,6 +200,11 @@ def run( jetstream_server.start() + if metrics_collector: + metrics_collector.get_server_startup_latency_metric().set( + time.time() - server_start_time + ) + # Setup Jax Profiler if enable_jax_profiler: logging.info("Starting JAX profiler server on port %s", jax_profiler_port) From 1830342749849f0391fad466da5421a36cbd42b1 Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Mon, 22 Jul 2024 17:40:01 -0700 Subject: [PATCH 24/42] Add http server to JetStream (#115) * Add http server to JetStream * Add generate api and cleanup * Add unit tests * format & deps * type & lint * Merge refactor * fix refactor --- jetstream/core/server_lib.py | 92 +++++++----- jetstream/entrypoints/__init__.py | 13 ++ jetstream/entrypoints/config.py | 32 +++++ jetstream/entrypoints/http/__init__.py | 13 ++ jetstream/entrypoints/http/api_server.py | 132 ++++++++++++++++++ jetstream/entrypoints/http/protocol.py | 36 +++++ jetstream/entrypoints/http/utils.py | 27 ++++ jetstream/tests/entrypoints/__init__.py | 13 ++ jetstream/tests/entrypoints/http/__init__.py | 13 ++ .../tests/entrypoints/http/test_api_server.py | 84 +++++++++++ requirements.in | 2 + requirements.txt | 31 +++- 12 files changed, 450 insertions(+), 38 deletions(-) create mode 100644 jetstream/entrypoints/__init__.py create mode 100644 jetstream/entrypoints/config.py create mode 100644 jetstream/entrypoints/http/__init__.py create mode 100644 jetstream/entrypoints/http/api_server.py create mode 100644 jetstream/entrypoints/http/protocol.py create mode 100644 jetstream/entrypoints/http/utils.py create mode 100644 jetstream/tests/entrypoints/__init__.py create mode 100644 jetstream/tests/entrypoints/http/__init__.py create mode 100644 jetstream/tests/entrypoints/http/test_api_server.py diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index 24f8506c..22180f09 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -93,40 +93,25 @@ def wait_for_termination(self) -> None: self.stop() -def run( - port: int, +def create_driver( config: Type[config_lib.ServerConfig], devices: Any, - credentials: Any = grpc.insecure_server_credentials(), - threads: int | None = None, jax_padding: bool = True, - metrics_server_config: config_lib.MetricsServerConfig | None = None, - enable_jax_profiler: bool = False, - jax_profiler_port: int = 9999, + metrics_collector: JetstreamMetricsCollector | None = None, enable_model_warmup: bool = False, -) -> JetStreamServer: - """Runs a server with a specified config. +): + """Creates a driver with a specified config. Args: - port: Port on which the server will be made available. config: A ServerConfig to config engine, model, device slices, etc. devices: Device objects, will be used to get engine with proper slicing. - credentials: Should use grpc credentials by default. - threads: Number of RPC handlers worker threads. This should be at least - equal to the decoding batch size to fully saturate the decoding queue. jax_padding: The flag to enable JAX padding during tokenization. - metrics_server_config: The config to enable Promethus metric server. - enable_jax_profiler: The flag to enable JAX profiler server. - jax_profiler_port: The port JAX profiler server (default to 9999). + metrics_collector: The JetStream Promethus metric collector. enable_model_warmup: The flag to enable model server warmup with AOT. Returns: - JetStreamServer that wraps the grpc server and orchestrator driver. + An orchestrator driver. """ - - server_start_time = time.time() - - logging.info("Kicking off gRPC server.") engines = config_lib.get_engines(config, devices=devices) prefill_params = [pe.load_params() for pe in engines.prefill_engines] generate_params = [ge.load_params() for ge in engines.generate_engines] @@ -136,19 +121,6 @@ def run( len(config.prefill_slices) + len(config.generate_slices) == 0 ) - # Setup Prometheus server - metrics_collector: JetstreamMetricsCollector = None - if metrics_server_config and metrics_server_config.port: - logging.info( - "Starting Prometheus server on port %d", metrics_server_config.port - ) - start_http_server(metrics_server_config.port) - metrics_collector = JetstreamMetricsCollector() - else: - logging.info( - "Not starting Prometheus server: --prometheus_port flag not set" - ) - prefill_engines = engines.prefill_engines + engines.interleaved_engines generate_engines = engines.generate_engines + engines.interleaved_engines prefill_params = prefill_params + shared_params @@ -182,7 +154,7 @@ def run( traceback.print_exc() os.kill(os.getpid(), signal.SIGKILL) - driver = orchestrator.Driver( + return orchestrator.Driver( prefill_engines=prefill_engines, generate_engines=generate_engines, prefill_params=prefill_params, @@ -192,6 +164,56 @@ def run( metrics_collector=metrics_collector, is_ray_backend=config.is_ray_backend, ) + + +def run( + port: int, + config: Type[config_lib.ServerConfig], + devices: Any, + credentials: Any = grpc.insecure_server_credentials(), + threads: int | None = None, + jax_padding: bool = True, + metrics_server_config: config_lib.MetricsServerConfig | None = None, + enable_jax_profiler: bool = False, + jax_profiler_port: int = 9999, + enable_model_warmup: bool = False, +) -> JetStreamServer: + """Runs a server with a specified config. + + Args: + port: Port on which the server will be made available. + config: A ServerConfig to config engine, model, device slices, etc. + devices: Device objects, will be used to get engine with proper slicing. + credentials: Should use grpc credentials by default. + threads: Number of RPC handlers worker threads. This should be at least + equal to the decoding batch size to fully saturate the decoding queue. + jax_padding: The flag to enable JAX padding during tokenization. + metrics_server_config: The config to enable Promethus metric server. + enable_jax_profiler: The flag to enable JAX profiler server. + jax_profiler_port: The port JAX profiler server (default to 9999). + enable_model_warmup: The flag to enable model server warmup with AOT. + + Returns: + JetStreamServer that wraps the grpc server and orchestrator driver. + """ + server_start_time = time.time() + logging.info("Kicking off gRPC server.") + # Setup Prometheus server + metrics_collector: JetstreamMetricsCollector = None + if metrics_server_config and metrics_server_config.port: + logging.info( + "Starting Prometheus server on port %d", metrics_server_config.port + ) + start_http_server(metrics_server_config.port) + metrics_collector = JetstreamMetricsCollector() + else: + logging.info( + "Not starting Prometheus server: --prometheus_port flag not set" + ) + + driver = create_driver( + config, devices, jax_padding, metrics_collector, enable_model_warmup + ) # We default threads to the total number of concurrent allowed decodes, # to make sure we can fully saturate the model. Set default minimum to 64. threads = threads or max(driver.get_total_concurrent_requests(), 64) diff --git a/jetstream/entrypoints/__init__.py b/jetstream/entrypoints/__init__.py new file mode 100644 index 00000000..6d5e14bc --- /dev/null +++ b/jetstream/entrypoints/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jetstream/entrypoints/config.py b/jetstream/entrypoints/config.py new file mode 100644 index 00000000..79f2b012 --- /dev/null +++ b/jetstream/entrypoints/config.py @@ -0,0 +1,32 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Config for JetStream Server (including engine init).""" + +from typing import Type + +from jetstream.core import config_lib + + +def get_server_config( + config_str: str, +) -> config_lib.ServerConfig | Type[config_lib.ServerConfig]: + match config_str: + case "InterleavedCPUTestServer": + server_config = config_lib.InterleavedCPUTestServer + case "CPUTestServer": + server_config = config_lib.CPUTestServer + case _: + raise NotImplementedError + return server_config diff --git a/jetstream/entrypoints/http/__init__.py b/jetstream/entrypoints/http/__init__.py new file mode 100644 index 00000000..6d5e14bc --- /dev/null +++ b/jetstream/entrypoints/http/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jetstream/entrypoints/http/api_server.py b/jetstream/entrypoints/http/api_server.py new file mode 100644 index 00000000..e7dabfed --- /dev/null +++ b/jetstream/entrypoints/http/api_server.py @@ -0,0 +1,132 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""JetStream Http API server.""" + +import json +import logging +from typing import Sequence +from absl import app as abslapp +from absl import flags +from fastapi import APIRouter, Response +import fastapi +from fastapi.responses import StreamingResponse +from prometheus_client import start_http_server +import uvicorn +from google.protobuf.json_format import Parse + +from jetstream.core import config_lib, orchestrator, server_lib +from jetstream.core.metrics.prometheus import JetstreamMetricsCollector +from jetstream.core.proto import jetstream_pb2 +from jetstream.entrypoints.config import get_server_config +from jetstream.entrypoints.http.protocol import DecodeRequest +from jetstream.entrypoints.http.utils import proto_to_json_generator + +flags.DEFINE_string("host", "0.0.0.0", "server host address") +flags.DEFINE_integer("port", 8080, "http server port") +flags.DEFINE_string( + "config", + "InterleavedCPUTestServer", + "available servers", +) +flags.DEFINE_integer( + "prometheus_port", + 9988, + "prometheus_port", +) + +llm_orchestrator: orchestrator.LLMOrchestrator + +# Define Fast API endpoints (use llm_orchestrator to handle). +router = APIRouter() + + +@router.get("/") +def root(): + """Root path for Jetstream HTTP Server.""" + return Response( + content=json.dumps({"message": "JetStream HTTP Server"}, indent=4), + media_type="application/json", + ) + + +@router.post("/v1/generate") +async def generate(request: DecodeRequest): + proto_request = Parse(request.json(), jetstream_pb2.DecodeRequest()) + generator = llm_orchestrator.Decode(proto_request) + return StreamingResponse( + content=proto_to_json_generator(generator), media_type="text/event-stream" + ) + + +@router.get("/v1/health") +async def health() -> Response: + """Health check.""" + response = await llm_orchestrator.HealthCheck( + jetstream_pb2.HealthCheckRequest() + ) + return Response( + content=json.dumps({"is_live": str(response.is_live)}, indent=4), + media_type="application/json", + status_code=200, + ) + + +def server(argv: Sequence[str]): + # Init Fast API. + app = fastapi.FastAPI() + app.include_router(router) + + # Init LLMOrchestrator which would be the main handler in the api endpoints. + devices = server_lib.get_devices() + print(f"devices: {devices}") + server_config = get_server_config(flags.FLAGS.config) + print(f"server_config: {server_config}") + del argv + + metrics_server_config: config_lib.MetricsServerConfig | None = None + # Setup Prometheus server + metrics_collector: JetstreamMetricsCollector = None + if flags.FLAGS.prometheus_port != 0: + metrics_server_config = config_lib.MetricsServerConfig( + port=flags.FLAGS.prometheus_port + ) + logging.info( + "Starting Prometheus server on port %d", metrics_server_config.port + ) + start_http_server(metrics_server_config.port) + metrics_collector = JetstreamMetricsCollector() + else: + logging.info( + "Not starting Prometheus server: --prometheus_port flag not set" + ) + + global llm_orchestrator + llm_orchestrator = orchestrator.LLMOrchestrator( + driver=server_lib.create_driver( + config=server_config, + devices=devices, + metrics_collector=metrics_collector, + ) + ) + + # Start uvicorn http server. + uvicorn.run( + app, host=flags.FLAGS.host, port=flags.FLAGS.port, log_level="info" + ) + + +if __name__ == "__main__": + # Run Abseil app w flags parser. + abslapp.run(server) diff --git a/jetstream/entrypoints/http/protocol.py b/jetstream/entrypoints/http/protocol.py new file mode 100644 index 00000000..fb003386 --- /dev/null +++ b/jetstream/entrypoints/http/protocol.py @@ -0,0 +1,36 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Http API server protocol.""" + +from pydantic import BaseModel # type: ignore + + +class TextContent(BaseModel): + text: str + + +class TokenContent(BaseModel): + token_ids: list[int] + + +class DecodeRequest(BaseModel): + max_tokens: int + text_content: TextContent | None = None + token_content: TokenContent | None = None + + # Config to enforce the oneof behavior at runtime. + class Config: + extra = "forbid" # Prevent extra fields. + anystr_strip_whitespace = True diff --git a/jetstream/entrypoints/http/utils.py b/jetstream/entrypoints/http/utils.py new file mode 100644 index 00000000..7765a785 --- /dev/null +++ b/jetstream/entrypoints/http/utils.py @@ -0,0 +1,27 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Http API server utilities.""" + +from google.protobuf.json_format import MessageToJson + + +async def proto_to_json_generator(proto_generator): + """Wraps a generator yielding Protocol Buffer messages into a generator + + yielding JSON messages. + """ + async for proto_message in proto_generator: + json_string = MessageToJson(proto_message) + yield json_string diff --git a/jetstream/tests/entrypoints/__init__.py b/jetstream/tests/entrypoints/__init__.py new file mode 100644 index 00000000..6d5e14bc --- /dev/null +++ b/jetstream/tests/entrypoints/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jetstream/tests/entrypoints/http/__init__.py b/jetstream/tests/entrypoints/http/__init__.py new file mode 100644 index 00000000..6d5e14bc --- /dev/null +++ b/jetstream/tests/entrypoints/http/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jetstream/tests/entrypoints/http/test_api_server.py b/jetstream/tests/entrypoints/http/test_api_server.py new file mode 100644 index 00000000..e6d42e58 --- /dev/null +++ b/jetstream/tests/entrypoints/http/test_api_server.py @@ -0,0 +1,84 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests http server end-to-end.""" + +import subprocess +import sys +import time +import unittest + + +import requests + + +class HTTPServerTest(unittest.IsolatedAsyncioTestCase): + + @classmethod + def setUpClass(cls): + """Sets up a JetStream http server for unit tests.""" + cls.base_url = "http://localhost:8080" + cls.server = subprocess.Popen( + [ + "python", + "-m", + "jetstream.entrypoints.http.api_server", + "--config=InterleavedCPUTestServer", + ], + stdout=sys.stdout, + stderr=sys.stderr, + ) + time.sleep(10) + + @classmethod + def tearDownClass(cls): + """Stop the server gracefully.""" + cls.server.terminate() + + async def test_root_endpoint(self): + response = requests.get(self.base_url + "/", timeout=5) + assert response.status_code == 200 + expected_data = {"message": "JetStream HTTP Server"} + assert response.json() == expected_data + + async def test_health_endpoint(self): + response = requests.get(self.base_url + "/v1/health", timeout=5) + assert response.status_code == 200 + data = response.json() + assert "is_live" in data + assert data["is_live"] == "True" + + async def test_generate_endpoint(self): + # Prepare a sample request (replace with actual data) + sample_request_data = { + "max_tokens": 10, + "text_content": {"text": "translate this to french: hello world"}, + } + + response = requests.post( + self.base_url + "/v1/generate", + json=sample_request_data, + stream=True, + timeout=5, + ) + assert response.status_code == 200 + full_response = [] + for chunk in response.iter_content( + chunk_size=None + ): # chunk_size=None for complete lines + if chunk: + stream_response = chunk.decode("utf-8") + print(f"{stream_response=}") + full_response.append(stream_response) + assert len(full_response) == 11 # 10 tokens + eos token diff --git a/requirements.in b/requirements.in index 459749ae..86841a57 100644 --- a/requirements.in +++ b/requirements.in @@ -13,5 +13,7 @@ tiktoken blobfile parameterized shortuuid +fastapi +uvicorn # For profiling tensorboard-plugin-profile \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 6bce9a98..67e31fdd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,6 +19,10 @@ absl-py==1.4.0 # tensorflow # tensorflow-metadata # tfds-nightly +anyio==3.7.1 + # via + # fastapi + # starlette array-record==0.5.0 # via tfds-nightly astunparse==1.6.3 @@ -34,7 +38,9 @@ charset-normalizer==3.3.2 chex==0.1.7 # via optax click==8.1.7 - # via tfds-nightly + # via + # tfds-nightly + # uvicorn clu==0.0.10 # via seqio contextlib2==21.6.0 @@ -56,7 +62,11 @@ etils[array-types,enp,epath,epy,etqdm,etree]==1.6.0 # orbax-checkpoint # tfds-nightly exceptiongroup==1.2.0 - # via pytest + # via + # anyio + # pytest +fastapi==0.103.2 + # via -r requirements.in filelock==3.14.0 # via blobfile flatbuffers==23.5.26 @@ -86,10 +96,14 @@ grpcio==1.60.1 # tensorflow gviz-api==1.10.0 # via tensorboard-plugin-profile +h11==0.14.0 + # via uvicorn h5py==3.10.0 # via tensorflow idna==3.7 - # via requests + # via + # anyio + # requests importlib-resources==6.1.1 # via etils iniconfig==2.0.0 @@ -208,6 +222,8 @@ pyasn1-modules==0.3.0 # via google-auth pycryptodomex==3.20.0 # via blobfile +pydantic==1.10.17 + # via fastapi pyglove==0.4.4 # via seqio pygments==2.17.2 @@ -252,6 +268,10 @@ six==1.16.0 # promise # tensorboard-plugin-profile # tensorflow +sniffio==1.3.1 + # via anyio +starlette==0.27.0 + # via fastapi tensorboard==2.13.0 # via tensorflow tensorboard-data-server==0.7.2 @@ -299,13 +319,18 @@ typing-extensions==4.5.0 # chex # clu # etils + # fastapi # flax # orbax-checkpoint + # pydantic # tensorflow + # uvicorn urllib3==2.2.2 # via # blobfile # requests +uvicorn==0.30.1 + # via -r requirements.in werkzeug==3.0.1 # via # tensorboard From af1b91825ef86b4e6ba2c4a4fcce74fbb541822a Mon Sep 17 00:00:00 2001 From: Fanhai Lu <154379058+FanhaiLu1@users.noreply.github.com> Date: Mon, 22 Jul 2024 18:06:43 -0700 Subject: [PATCH 25/42] Free engine resource for the slot after finished one request decoding (#119) * Add free resource function after finished one request decoding * fix lint error * fix pyint error --- jetstream/core/orchestrator.py | 1 + jetstream/engine/engine_api.py | 12 ++++++++++++ 2 files changed, 13 insertions(+) diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index 29f9222e..04faa285 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -784,6 +784,7 @@ def _detokenize_thread(self, idx: int): # Place the slot back on the free queue. my_live_requests[slot] = None my_slots.put(slot, block=False) # This should always have space. + my_generate_engine.free_resource(slot) logging.info( "Detokenizing generate step %d took %.2fms", generate_timestep_added, diff --git a/jetstream/engine/engine_api.py b/jetstream/engine/engine_api.py index f501b51c..cba42939 100644 --- a/jetstream/engine/engine_api.py +++ b/jetstream/engine/engine_api.py @@ -187,6 +187,18 @@ def insert( a [0, n) range of slots and converted internally. """ + def free_resource( + self, + slot: int, # pylint: disable=unused-argument + ) -> Any: + """Free cache and other decode resource for the slot. + + This function is needed for advanced attetnion kenel like PageAttetion. + After finishing one request, the engine need to free all used page block + resource and reuse for coming requests. + """ + return None + @abc.abstractmethod def load_params(self, *args, **kwargs) -> Params: """Loads parameters. From 64ff9eaeab3e5314e1e744deba5e404c3ed904df Mon Sep 17 00:00:00 2001 From: Brendan Slabe Date: Tue, 6 Aug 2024 00:07:00 +0200 Subject: [PATCH 26/42] Add `jetstream_request_success_count` metric (#124) * first commit * Count -> Counter * typo --- jetstream/core/metrics/prometheus.py | 10 +++++++++- jetstream/core/orchestrator.py | 2 ++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/jetstream/core/metrics/prometheus.py b/jetstream/core/metrics/prometheus.py index 7363297d..b0e5d3db 100644 --- a/jetstream/core/metrics/prometheus.py +++ b/jetstream/core/metrics/prometheus.py @@ -16,7 +16,7 @@ import os import shortuuid -from prometheus_client import Gauge +from prometheus_client import Counter, Gauge class JetstreamMetricsCollector: @@ -55,6 +55,11 @@ def __new__(cls): documentation="Total time taken to start the Jetstream server", labelnames=["id"], ) + _request_success_count = Counter( + name="jetstream_request_success_count", + documentation="Number of requests successfully completed", + labelnames=["id"], + ) def get_prefill_backlog_metric(self): return self._prefill_backlog.labels(id=self._id) @@ -70,3 +75,6 @@ def get_slots_used_percentage_metric(self, idx: int): def get_server_startup_latency_metric(self): return self._server_startup_latency.labels(id=self._id) + + def get_request_success_count_metric(self): + return self._request_success_count.labels(id=self._id) diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index 04faa285..c84fea92 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -780,6 +780,8 @@ def _detokenize_thread(self, idx: int): # Return some output samples. request.enqueue_samples(results) if request.complete.all(): + if self._metrics_collector: + self._metrics_collector.get_request_success_count_metric().inc() request.return_channel.close() # Place the slot back on the free queue. my_live_requests[slot] = None From 45f8735e0356741ce73b19d1a937b1410cbd9e61 Mon Sep 17 00:00:00 2001 From: Brendan Slabe Date: Tue, 6 Aug 2024 01:49:27 +0200 Subject: [PATCH 27/42] Request input/output size metrics (#123) * first commit * remove unused code * fmt * changed buckets * now using DEFAULT_PREFILL_BUCKETS * missing parenthese --- jetstream/core/metrics/prometheus.py | 43 +++++++++++++++++++++++++++- jetstream/core/orchestrator.py | 14 ++++----- 2 files changed, 47 insertions(+), 10 deletions(-) diff --git a/jetstream/core/metrics/prometheus.py b/jetstream/core/metrics/prometheus.py index b0e5d3db..4320327c 100644 --- a/jetstream/core/metrics/prometheus.py +++ b/jetstream/core/metrics/prometheus.py @@ -16,7 +16,9 @@ import os import shortuuid -from prometheus_client import Counter, Gauge +from prometheus_client import Counter, Gauge, Histogram + +from jetstream.engine.token_utils import DEFAULT_PREFILL_BUCKETS class JetstreamMetricsCollector: @@ -55,6 +57,39 @@ def __new__(cls): documentation="Total time taken to start the Jetstream server", labelnames=["id"], ) + _request_input_length = Histogram( + name="jetstream_request_input_length", + documentation="Number of input tokens per request", + labelnames=["id"], + buckets=DEFAULT_PREFILL_BUCKETS, + ) + _request_output_length = Histogram( + name="jetstream_request_output_length", + documentation="Number of output tokens per request", + labelnames=["id"], + buckets=[ + 1, + 2, + 5, + 10, + 20, + 50, + 100, + 200, + 500, + 1000, + 2000, + 5000, + 10000, + 20000, + 50000, + 100000, + 200000, + 500000, + 1000000, + 2000000, + ], + ) _request_success_count = Counter( name="jetstream_request_success_count", documentation="Number of requests successfully completed", @@ -76,5 +111,11 @@ def get_slots_used_percentage_metric(self, idx: int): def get_server_startup_latency_metric(self): return self._server_startup_latency.labels(id=self._id) + def get_request_input_length(self): + return self._request_input_length.labels(id=self._id) + + def get_request_output_length(self): + return self._request_output_length.labels(id=self._id) + def get_request_success_count_metric(self): return self._request_success_count.labels(id=self._id) diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index c84fea92..2c54a6f8 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -109,15 +109,6 @@ root.addHandler(handler) -def delete_pytree(p): - def delete_leaf(leaf): - if isinstance(leaf, jax.Array): - leaf.delete() - del leaf - - jax.tree_map(delete_leaf, p) - - @dataclasses.dataclass class ActiveRequest: """Current state of the driver.""" @@ -532,6 +523,8 @@ def _prefill_thread(self, idx: int): idx, my_transfer_backlog.qsize(), ) + if self._metrics_collector: + self._metrics_collector.get_request_input_length().observe(true_length) del prefill_result del request @@ -781,6 +774,9 @@ def _detokenize_thread(self, idx: int): request.enqueue_samples(results) if request.complete.all(): if self._metrics_collector: + self._metrics_collector.get_request_output_length().observe( + result_tokens.get_result_at_slot(slot).lengths + ) self._metrics_collector.get_request_success_count_metric().inc() request.return_channel.close() # Place the slot back on the free queue. From 3946afac9521538b158f5c2750d2cd39f9d41fbf Mon Sep 17 00:00:00 2001 From: Brendan Slabe Date: Wed, 7 Aug 2024 21:30:12 +0200 Subject: [PATCH 28/42] Makefile (#125) * first commit * changed unit_tests.yaml * generate-protos * better generate-protos logic * append -> prepend * more make targets --- .github/workflows/unit_tests.yaml | 25 +++------ Makefile | 60 ++++++++++++++++++++++ README.md | 2 +- jetstream/core/proto/jetstream.proto | 2 + jetstream/core/proto/jetstream_pb2.py | 2 - jetstream/core/proto/jetstream_pb2_grpc.py | 2 - jetstream/engine/tokenizer_pb2.py | 2 - jetstream/engine/tokenizer_pb2_grpc.py | 2 - license_preamble.txt | 13 +++++ 9 files changed, 83 insertions(+), 27 deletions(-) create mode 100644 Makefile create mode 100644 license_preamble.txt diff --git a/.github/workflows/unit_tests.yaml b/.github/workflows/unit_tests.yaml index 7b230dde..8db79fc3 100644 --- a/.github/workflows/unit_tests.yaml +++ b/.github/workflows/unit_tests.yaml @@ -42,21 +42,13 @@ jobs: with: python-version: ${{ matrix.python-version }} - name: Install Dependencies - run: | - pip install pytype - pip install pylint - pip install pyink - pip install -r requirements.txt - pip install -r benchmarks/requirements.in + run: make install-deps - name: Typecheck the code with pytype - run: | - pytype --jobs auto --disable=import-error,module-attr jetstream/ benchmarks/ + run: make type-check - name: Analysing the code with pylint - run: | - pylint jetstream/ benchmarks/ + run: make linter-check - name: Format check with pyink - run: | - pyink --pyink-indentation 2 --line-length 80 --check --verbose . + run: make format-check cpu: name: "JetStream unit tests" @@ -73,11 +65,8 @@ jobs: with: python-version: ${{ matrix.python-version }} - name: Install Dependencies - run: | - pip install -r requirements.txt + run: make install-deps - name: Run all unit tests in JetStream (jetstream/tests) - run: | - coverage run -m unittest -v + run: make unit-tests - name: Create test coverage report - run: | - coverage report -m --omit="jetstream/core/proto/*,jetstream/engine/tokenizer_pb2.py,jetstream/third_party/*" --fail-under=96 \ No newline at end of file + run: make check-test-coverage \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..a7699a53 --- /dev/null +++ b/Makefile @@ -0,0 +1,60 @@ +PYTHON := python +PIP := $(PYTHON) -m pip +GRPC_TOOLS_VERSION := 1.62.1 + +all: update-and-install-deps generate-protos format check + +# Dependency management targets +update-and-install-deps: update-deps install-deps + +update-deps: + $(PIP) install pip-tools + $(PYTHON) -m piptools compile requirements.in + +install-deps: + $(PIP) install pytype pylint pyink -r requirements.txt -r benchmarks/requirements.in + +# Code generation/formatting targets +generate-protos: generate-and-prepend-preambles format + +generate-and-prepend-preambles: + $(PIP) install grpcio-tools==$(GRPC_TOOLS_VERSION) + for id in $$(find . -name "*.proto"); do \ + $(PYTHON) -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. $$id && \ + PROTO_FILE=$$(echo $$id | awk '{print substr($$0, 1, length($$0)-6)}') && \ + PB_GRPC_PY=$(addsuffix "_pb2_grpc.py",$$PROTO_FILE) && \ + PB_PY=$(addsuffix "_pb2.py",$$PROTO_FILE) && \ + cat license_preamble.txt $$PB_GRPC_PY >> $(addsuffix "_temp",$$PB_GRPC_PY) && \ + mv $(addsuffix "_temp",$$PB_GRPC_PY) $$PB_GRPC_PY; \ + cat license_preamble.txt $$PB_PY >> $(addsuffix "_temp",$$PB_PY) && \ + mv $(addsuffix "_temp",$$PB_PY) $$PB_PY; \ + done + +format: + $(PIP) install pyink + pyink --pyink-indentation 2 --line-length 80 --verbose . + +# Code checking related targets +check: type-check format-check linter-check + +type-check: + $(PIP) install pytype + pytype --jobs auto --disable=import-error,module-attr jetstream/ benchmarks/ + +format-check: + $(PIP) install pyink + pyink --pyink-indentation 2 --line-length 80 --check --verbose . + +linter-check: + $(PIP) install pylint + pylint --ignore-patterns=".*_pb2.py,.*_pb2_grpc.py" jetstream/ benchmarks/ + + +# Testing related targets +tests: unit-tests check-test-coverage + +unit-tests: + coverage run -m unittest -v + +check-test-coverage: + coverage report -m --omit="jetstream/core/proto/*,jetstream/engine/tokenizer_pb2.py,jetstream/third_party/*" --fail-under=96 diff --git a/README.md b/README.md index a989b316..2159c0fd 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ Currently, there are two reference engine implementations available -- one for J ### Setup ``` -pip install -r requirements.txt +make update-and-install-deps ``` ### Run local server & Testing diff --git a/jetstream/core/proto/jetstream.proto b/jetstream/core/proto/jetstream.proto index 9fc7076f..60c65605 100644 --- a/jetstream/core/proto/jetstream.proto +++ b/jetstream/core/proto/jetstream.proto @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +// NOTICE: run `make generate-protos` if making changes to this file + syntax = "proto3"; package jetstream_proto; diff --git a/jetstream/core/proto/jetstream_pb2.py b/jetstream/core/proto/jetstream_pb2.py index 07a5f313..c4be62d5 100644 --- a/jetstream/core/proto/jetstream_pb2.py +++ b/jetstream/core/proto/jetstream_pb2.py @@ -11,12 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: jetstream/core/proto/jetstream.proto # Protobuf Python Version: 4.25.1 -# pylint: disable=all """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool diff --git a/jetstream/core/proto/jetstream_pb2_grpc.py b/jetstream/core/proto/jetstream_pb2_grpc.py index 84521185..d571ade8 100644 --- a/jetstream/core/proto/jetstream_pb2_grpc.py +++ b/jetstream/core/proto/jetstream_pb2_grpc.py @@ -11,9 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -# pylint: disable=all """Client and server classes corresponding to protobuf-defined services.""" import grpc diff --git a/jetstream/engine/tokenizer_pb2.py b/jetstream/engine/tokenizer_pb2.py index 4aa69f87..1df16528 100644 --- a/jetstream/engine/tokenizer_pb2.py +++ b/jetstream/engine/tokenizer_pb2.py @@ -11,12 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: jetstream/engine/tokenizer.proto # Protobuf Python Version: 4.25.1 -# pylint: disable=all """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool diff --git a/jetstream/engine/tokenizer_pb2_grpc.py b/jetstream/engine/tokenizer_pb2_grpc.py index 5aa1b7a4..c4ca2afd 100644 --- a/jetstream/engine/tokenizer_pb2_grpc.py +++ b/jetstream/engine/tokenizer_pb2_grpc.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -# pylint: disable=all """Client and server classes corresponding to protobuf-defined services.""" import grpc diff --git a/license_preamble.txt b/license_preamble.txt new file mode 100644 index 00000000..6d5e14bc --- /dev/null +++ b/license_preamble.txt @@ -0,0 +1,13 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. From d681995dc0bb199e10a7e6ded3ffc9f7da2cf8d1 Mon Sep 17 00:00:00 2001 From: Brendan Slabe Date: Wed, 7 Aug 2024 22:50:03 +0200 Subject: [PATCH 29/42] Various request time metrics (#121) * first commit * nit * fmt * description tweak * added more metrics * nit * nit * default metadata values * move `new_request.metadata.transfer_start_time = time.perf_counter()` * avoid NoneType * NoneType * set transfer_end_time and fmt * camel case -> snake case * description update * change descriptions * fmt * logs * better logs * changed timings * observing queue duration metric * buckets in sorted order * buckets not in sorted order * corrected times * number of output tokens * move prefill_start_time, enable debug, maybe correct len for num tokens in detokenize * fmt * correct lengths of output tokens based on debug * debug transfer queue time * remove log * removed logs, almost final * nits * readd log * change logs * reomve log * condence * improve test coverage * revert _abort_or_raise deletion * start_time mandatory * undo * nit * updated buckets * added 'jetstream_time_per_request' * nit * add 'jetstream_wait_time_per_request' * nit * missing .metadata * lint * change order of params * changed metric description * Add metadata field to proto * update proto * tweak generated file * tweak generated file * update proto * pylint * generate protos * change start time assignment * .value * CopyFrom * change definition of queue duration metric * Increase test coverage * fixed assertions * fmt * incorrect prefill time * Add license statements * Protobuf Python Version * fmt * pylint --- jetstream/core/metrics/prometheus.py | 138 ++++++++++++++++++++++- jetstream/core/orchestrator.py | 99 +++++++++++++++- jetstream/core/proto/jetstream.proto | 11 +- jetstream/core/proto/jetstream_pb2.py | 42 +++---- jetstream/entrypoints/http/api_server.py | 5 + jetstream/entrypoints/http/protocol.py | 5 + jetstream/tests/core/test_server.py | 52 +++++---- 7 files changed, 300 insertions(+), 52 deletions(-) diff --git a/jetstream/core/metrics/prometheus.py b/jetstream/core/metrics/prometheus.py index 4320327c..dc8a00e9 100644 --- a/jetstream/core/metrics/prometheus.py +++ b/jetstream/core/metrics/prometheus.py @@ -17,7 +17,6 @@ import os import shortuuid from prometheus_client import Counter, Gauge, Histogram - from jetstream.engine.token_utils import DEFAULT_PREFILL_BUCKETS @@ -37,21 +36,46 @@ def __new__(cls): documentation="Size of prefill queue", labelnames=["id"], ) + _transfer_backlog = Gauge( name="jetstream_transfer_backlog_size", documentation="Size of transfer queue", labelnames=["id", "idx"], ) + _generate_backlog = Gauge( name="jetstream_generate_backlog_size", documentation="Size of generate queue", labelnames=["id", "idx"], ) + + _queue_duration = Histogram( + name="jetstream_queue_duration", + documentation="The total time each request spends enqueued in seconds", + labelnames=["id"], + buckets=[ + 0.01, + 0.02, + 0.05, + 0.1, + 0.2, + 0.5, + 1.0, + 2.0, + 5.0, + 10.0, + 20.0, + 50.0, + 100.0, + ], + ) + _slots_used_percentage = Gauge( name="jetstream_slots_used_percentage", documentation="The percentage of decode slots currently being used", labelnames=["id", "idx"], ) + _server_startup_latency = Gauge( name="jetstream_server_startup_latency", documentation="Total time taken to start the Jetstream server", @@ -96,6 +120,100 @@ def __new__(cls): labelnames=["id"], ) + _time_to_first_token = Histogram( + name="jetstream_time_to_first_token", + documentation="Time to first token per request in seconds", + labelnames=["id"], + buckets=[ + 0.001, + 0.005, + 0.01, + 0.02, + 0.04, + 0.06, + 0.08, + 0.1, + 0.25, + 0.5, + 0.75, + 1.0, + 2.5, + 5.0, + 7.5, + 10.0, + ], + ) + + _time_per_output_token = Histogram( + name="jetstream_time_per_output_token", + documentation="Average time per output token per request in seconds", + labelnames=["id"], + buckets=[ + 0.01, + 0.025, + 0.05, + 0.075, + 0.1, + 0.15, + 0.2, + 0.3, + 0.4, + 0.5, + 0.75, + 1.0, + 2.5, + ], + ) + + _time_per_prefill_token = Histogram( + name="jetstream_time_per_prefill_token", + documentation="Prefill time per token per request in seconds", + labelnames=["id"], + buckets=[ + 0.00001, + 0.00002, + 0.00005, + 0.0001, + 0.0002, + 0.0005, + 0.001, + 0.002, + 0.005, + 0.01, + 0.02, + 0.05, + 0.1, + ], + ) + + _time_per_request = Histogram( + name="jetstream_time_per_request", + documentation="End to end request latency in seconds", + labelnames=["id"], + buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0], + ) + + _wait_time_per_request = Histogram( + name="jetstream_wait_time_per_request", + documentation="Time each request is not being prefilled or decoded", + labelnames=["id"], + buckets=[ + 0.01, + 0.02, + 0.05, + 0.1, + 0.2, + 0.5, + 1.0, + 2.0, + 5.0, + 10.0, + 20.0, + 50.0, + 100.0, + ], + ) + def get_prefill_backlog_metric(self): return self._prefill_backlog.labels(id=self._id) @@ -105,12 +223,30 @@ def get_transfer_backlog_metric(self, idx: int): def get_generate_backlog_metric(self, idx: int): return self._generate_backlog.labels(id=self._id, idx=idx) + def get_queue_duration(self): + return self._queue_duration.labels(id=self._id) + def get_slots_used_percentage_metric(self, idx: int): return self._slots_used_percentage.labels(id=self._id, idx=idx) def get_server_startup_latency_metric(self): return self._server_startup_latency.labels(id=self._id) + def get_time_to_first_token(self): + return self._time_to_first_token.labels(id=self._id) + + def get_time_per_output_token(self): + return self._time_per_output_token.labels(id=self._id) + + def get_time_per_prefill_token(self): + return self._time_per_prefill_token.labels(id=self._id) + + def get_time_per_request(self): + return self._time_per_request.labels(id=self._id) + + def get_wait_time_per_request(self): + return self._wait_time_per_request.labels(id=self._id) + def get_request_input_length(self): return self._request_input_length.labels(id=self._id) diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index 2c54a6f8..cefabd05 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -109,6 +109,24 @@ root.addHandler(handler) +@dataclasses.dataclass +class ActiveRequestMetadata: + """Inference request metadata.""" + + start_time: Optional[float] = None + + prefill_enqueue_time: Optional[float] = None + prefill_dequeue_time: Optional[float] = None + + transfer_enqueue_time: Optional[float] = None + transfer_dequeue_time: Optional[float] = None + + generate_enqueue_time: Optional[float] = None + generate_dequeue_time: Optional[float] = None + + complete_time: Optional[float] = None + + @dataclasses.dataclass class ActiveRequest: """Current state of the driver.""" @@ -130,6 +148,8 @@ class ActiveRequest: # Which generate step this was added at. generate_timestep_added: Optional[int] = None is_client_side_tokenization: Optional[bool] = False + ################## Information relevant for metrics ################### + metadata: ActiveRequestMetadata = ActiveRequestMetadata() def enqueue_samples(self, generated_samples: list[ReturnSample]): """Adds the generated sample(s) to return channel for current step. @@ -477,10 +497,10 @@ def _prefill_thread(self, idx: int): my_transfer_backlog = self._transfer_backlogs[idx] # The prefill thread can just sleep until it has work to do. request = self._prefill_backlog.get(block=True) - request_start_time = time.perf_counter() if request is None: break + request.metadata.prefill_dequeue_time = time.perf_counter() is_bos = True logging.info( "Prefilling on prefill engine %d : prefill queue size, %d," @@ -511,8 +531,10 @@ def _prefill_thread(self, idx: int): # put first token to detokenize queue request.complete = np.zeros((prefill_engine.samples_per_slot,), np.bool_) my_detokenize_backlog = self._detokenize_backlogs[idx] + request.metadata.transfer_enqueue_time = time.perf_counter() my_detokenize_backlog.put( - (first_token, request, request_start_time), block=True + (first_token, request, request.metadata.prefill_dequeue_time), + block=True, ) # Once prefill is complete, place it on the generation queue and block if @@ -526,6 +548,15 @@ def _prefill_thread(self, idx: int): if self._metrics_collector: self._metrics_collector.get_request_input_length().observe(true_length) + if self._metrics_collector: + self._metrics_collector.get_time_per_prefill_token().observe( + ( + request.metadata.transfer_enqueue_time + - request.metadata.prefill_dequeue_time + ) + / true_length + ) + del prefill_result del request @@ -562,6 +593,7 @@ def _transfer_thread(self, idx: int): new_request = transfer_backlog.get(block=True) if new_request is None: break + new_request.metadata.transfer_dequeue_time = time.perf_counter() target_idx = min( self._generate_backlogs.items(), key=lambda q: q[1].qsize() )[0] @@ -577,6 +609,7 @@ def _transfer_thread(self, idx: int): # Transfer the info to the relevant generate slice. self._transfer_prefill_result(new_request, target_idx) # Place the request on the correct generate backlog and block if full. + new_request.metadata.generate_enqueue_time = time.perf_counter() self._generate_backlogs[target_idx].put(new_request, block=True) logging.info( "Successfully transferred prefill " @@ -649,6 +682,24 @@ def _generate_thread(self, idx: int): block |= not self._transfer_backlogs[idx].empty() try: new_request = my_generate_backlog.get(block=block, timeout=1.0) + if new_request is None: + break + new_request.metadata.generate_dequeue_time = time.perf_counter() + if ( + self._metrics_collector + and new_request.metadata.start_time is not None + ): + self._metrics_collector.get_queue_duration().observe( + # Time in prefill queue + new_request.metadata.prefill_dequeue_time + - new_request.metadata.prefill_enqueue_time + # Time in transfer queue + + new_request.metadata.transfer_dequeue_time + - new_request.metadata.transfer_enqueue_time + # Time in generate queue + + new_request.metadata.generate_dequeue_time + - new_request.metadata.generate_enqueue_time + ) # Got free slot and new request, use them. except queue.Empty: # No new requests, we can't insert, so put back slot. @@ -731,7 +782,7 @@ def _detokenize_thread(self, idx: int): start_detokenize_time = time.time() # prefill first token if isinstance(data[0], engine_api.ResultTokens): - request_first_token, request, request_start_time = data + request_first_token, request, _ = data request_first_token = request_first_token.convert_to_numpy() results, complete = token_utils.process_result_tokens( @@ -747,9 +798,14 @@ def _detokenize_thread(self, idx: int): request.enqueue_samples(results) first_token_return_time = time.perf_counter() + if self._metrics_collector: + self._metrics_collector.get_time_to_first_token().observe( + first_token_return_time - request.metadata.prefill_dequeue_time + ) logging.info( "TTFT duration: %fms", - (first_token_return_time - request_start_time) * 1000, + (first_token_return_time - request.metadata.prefill_dequeue_time) + * 1000, ) # generate step tokens elif isinstance(data[1], engine_api.ResultTokens): @@ -773,12 +829,41 @@ def _detokenize_thread(self, idx: int): # Return some output samples. request.enqueue_samples(results) if request.complete.all(): + request.metadata.complete_time = time.perf_counter() + request.return_channel.close() if self._metrics_collector: self._metrics_collector.get_request_output_length().observe( result_tokens.get_result_at_slot(slot).lengths ) self._metrics_collector.get_request_success_count_metric().inc() - request.return_channel.close() + self._metrics_collector.get_time_per_output_token().observe( + ( + request.metadata.complete_time + - request.metadata.transfer_enqueue_time + ) + / result_tokens.get_result_at_slot(slot).lengths + ) + self._metrics_collector.get_time_per_request().observe( + request.metadata.complete_time + - request.metadata.transfer_enqueue_time + ) + + if request.metadata.start_time: + total_time = ( + request.metadata.complete_time + - request.metadata.start_time + ) + prefill_time = ( + request.metadata.transfer_enqueue_time + - request.metadata.prefill_dequeue_time + ) + generate_time = ( + request.metadata.complete_time + - request.metadata.generate_dequeue_time + ) + self._metrics_collector.get_wait_time_per_request().observe( + total_time - prefill_time - generate_time + ) # Place the slot back on the free queue. my_live_requests[slot] = None my_slots.put(slot, block=False) # This should always have space. @@ -895,6 +980,10 @@ async def Decode( # pylint: disable=invalid-overridden-method prefill_content=prefill_content, is_client_side_tokenization=is_client_side_tokenization, return_channel=return_channel, + metadata=ActiveRequestMetadata( + start_time=request.metadata.start_time, + prefill_enqueue_time=time.perf_counter(), + ), ) # The first stage is being prefilled, all other stages are handled # inside the driver (transfer, generate*N, detokenize). diff --git a/jetstream/core/proto/jetstream.proto b/jetstream/core/proto/jetstream.proto index 60c65605..f06d89d5 100644 --- a/jetstream/core/proto/jetstream.proto +++ b/jetstream/core/proto/jetstream.proto @@ -50,8 +50,17 @@ message DecodeRequest { TextContent text_content = 5; TokenContent token_content = 6; } + + message Metadata { + float start_time = 1; + } + + oneof metadata_optional { + Metadata metadata = 7; + } + reserved 1, 2, 3; - // Next ID: 7 + // Next ID: 8 } message DecodeResponse { diff --git a/jetstream/core/proto/jetstream_pb2.py b/jetstream/core/proto/jetstream_pb2.py index c4be62d5..0b146032 100644 --- a/jetstream/core/proto/jetstream_pb2.py +++ b/jetstream/core/proto/jetstream_pb2.py @@ -26,7 +26,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"\x8a\x02\n\rDecodeRequest\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02"\x14\n\x12HealthCheckRequest"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 \x01(\x08\x32\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse"\x00\x62\x06proto3' + b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"\xfc\x02\n\rDecodeRequest\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x12;\n\x08metadata\x18\x07 \x01(\x0b\x32\'.jetstream_proto.DecodeRequest.MetadataH\x01\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x1a\x1e\n\x08Metadata\x12\x12\n\nstart_time\x18\x01 \x01(\x02\x42\t\n\x07\x63ontentB\x13\n\x11metadata_optionalJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02"\x14\n\x12HealthCheckRequest"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 \x01(\x08\x32\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse"\x00\x62\x06proto3' ) _globals = globals() @@ -37,23 +37,25 @@ if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None _globals["_DECODEREQUEST"]._serialized_start = 58 - _globals["_DECODEREQUEST"]._serialized_end = 324 - _globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_start = 233 - _globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_end = 260 - _globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_start = 262 - _globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_end = 295 - _globals["_DECODERESPONSE"]._serialized_start = 327 - _globals["_DECODERESPONSE"]._serialized_end = 658 - _globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_start = 493 - _globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_end = 509 - _globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_start = 512 - _globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_end = 641 - _globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_start = 600 - _globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_end = 641 - _globals["_HEALTHCHECKREQUEST"]._serialized_start = 660 - _globals["_HEALTHCHECKREQUEST"]._serialized_end = 680 - _globals["_HEALTHCHECKRESPONSE"]._serialized_start = 682 - _globals["_HEALTHCHECKRESPONSE"]._serialized_end = 720 - _globals["_ORCHESTRATOR"]._serialized_start = 723 - _globals["_ORCHESTRATOR"]._serialized_end = 908 + _globals["_DECODEREQUEST"]._serialized_end = 438 + _globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_start = 294 + _globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_end = 321 + _globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_start = 323 + _globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_end = 356 + _globals["_DECODEREQUEST_METADATA"]._serialized_start = 358 + _globals["_DECODEREQUEST_METADATA"]._serialized_end = 388 + _globals["_DECODERESPONSE"]._serialized_start = 441 + _globals["_DECODERESPONSE"]._serialized_end = 772 + _globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_start = 607 + _globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_end = 623 + _globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_start = 626 + _globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_end = 755 + _globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_start = 714 + _globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_end = 755 + _globals["_HEALTHCHECKREQUEST"]._serialized_start = 774 + _globals["_HEALTHCHECKREQUEST"]._serialized_end = 794 + _globals["_HEALTHCHECKRESPONSE"]._serialized_start = 796 + _globals["_HEALTHCHECKRESPONSE"]._serialized_end = 834 + _globals["_ORCHESTRATOR"]._serialized_start = 837 + _globals["_ORCHESTRATOR"]._serialized_end = 1022 # @@protoc_insertion_point(module_scope) diff --git a/jetstream/entrypoints/http/api_server.py b/jetstream/entrypoints/http/api_server.py index e7dabfed..aaced235 100644 --- a/jetstream/entrypoints/http/api_server.py +++ b/jetstream/entrypoints/http/api_server.py @@ -16,6 +16,7 @@ import json import logging +import time from typing import Sequence from absl import app as abslapp from absl import flags @@ -63,7 +64,11 @@ def root(): @router.post("/v1/generate") async def generate(request: DecodeRequest): + start_time = time.perf_counter() proto_request = Parse(request.json(), jetstream_pb2.DecodeRequest()) + metadata = jetstream_pb2.DecodeRequest.Metadata() + metadata.start_time = start_time + proto_request.metadata.CopyFrom(metadata) generator = llm_orchestrator.Decode(proto_request) return StreamingResponse( content=proto_to_json_generator(generator), media_type="text/event-stream" diff --git a/jetstream/entrypoints/http/protocol.py b/jetstream/entrypoints/http/protocol.py index fb003386..cbb8dc6a 100644 --- a/jetstream/entrypoints/http/protocol.py +++ b/jetstream/entrypoints/http/protocol.py @@ -25,10 +25,15 @@ class TokenContent(BaseModel): token_ids: list[int] +class Metadata(BaseModel): + start_time: float + + class DecodeRequest(BaseModel): max_tokens: int text_content: TextContent | None = None token_content: TokenContent | None = None + metadata: Metadata | None = None # Config to enforce the oneof behavior at runtime. class Config: diff --git a/jetstream/tests/core/test_server.py b/jetstream/tests/core/test_server.py index 9114f2fd..2fdddce9 100644 --- a/jetstream/tests/core/test_server.py +++ b/jetstream/tests/core/test_server.py @@ -40,6 +40,7 @@ class ServerTest(unittest.IsolatedAsyncioTestCase): # Uses weight 2 for prefill, 4 for decode. ( config_lib.CPUTestServer, + True, ["Ċ", "Ō", "Ɵ", ""], [266, 332, 415, None], [None, None], @@ -47,6 +48,15 @@ class ServerTest(unittest.IsolatedAsyncioTestCase): # Uses the same prefill / generate weights (2). ( config_lib.InterleavedCPUTestServer, + True, + ["Ċ", "Ə", "ɖ", ""], + [266, 399, 598, None], + [None], + ), + # Disable the metrics server. + ( + config_lib.InterleavedCPUTestServer, + False, ["Ċ", "Ə", "ɖ", ""], [266, 399, 598, None], [None], @@ -56,6 +66,7 @@ class ServerTest(unittest.IsolatedAsyncioTestCase): async def test_server( self, config: Type[config_lib.ServerConfig], + metrics_enabled: bool, expected_text: list[str], expected_token_ids: list[int | None], devices: list[Any], @@ -63,6 +74,7 @@ async def test_server( """Sets up a server and requests token responses.""" ######################### Server side ###################################### port = portpicker.pick_unused_port() + metrics_port = portpicker.pick_unused_port() print("port: " + str(port)) credentials = grpc.local_server_credentials() @@ -72,11 +84,15 @@ async def test_server( config=config, devices=devices, credentials=credentials, + metrics_server_config=config_lib.MetricsServerConfig(port=metrics_port) + if metrics_enabled is True + else None, ) ###################### Requester side ###################################### - # prometheus not configured, assert no metrics collector on Driver - assert server._driver._metrics_collector is None # pylint: disable=protected-access + # if prometheus not configured, assert no metrics collector on Driver + if metrics_enabled is not True: + assert server._driver._metrics_collector is None # pylint: disable=protected-access async with grpc.aio.secure_channel( f"localhost:{port}", grpc.local_channel_credentials() @@ -106,31 +122,17 @@ async def test_server( assert output_text == expected_text[counter] assert output_token_id == expected_token_ids[counter] counter += 1 + # assert prometheus server is running and responding + if metrics_enabled is True: + assert server._driver._metrics_collector is not None # pylint: disable=protected-access + assert ( + requests.get( + f"http://localhost:{metrics_port}", timeout=5 + ).status_code + == requests.status_codes.codes["ok"] + ) server.stop() - def test_prometheus_server(self): - port = portpicker.pick_unused_port() - metrics_port = portpicker.pick_unused_port() - - print("port: " + str(port)) - print("metrics port: " + str(metrics_port)) - credentials = grpc.local_server_credentials() - # Now test server with prometheus config - server = server_lib.run( - port=port, - config=config_lib.InterleavedCPUTestServer, - devices=[None], - credentials=credentials, - metrics_server_config=config_lib.MetricsServerConfig(port=metrics_port), - ) - # assert prometheus server is running and responding - assert server._driver._metrics_collector is not None # pylint: disable=protected-access - assert ( - requests.get(f"http://localhost:{metrics_port}", timeout=5).status_code - == requests.status_codes.codes["ok"] - ) - server.stop() - def test_jax_profiler_server(self): port = portpicker.pick_unused_port() print("port: " + str(port)) From ef827a879c7f4c7ecbb826836fe0df93106de3a3 Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Mon, 12 Aug 2024 11:19:03 -0700 Subject: [PATCH 30/42] Standalone JetStream removes pinned deps (#129) --- requirements.in | 19 --- requirements.txt | 370 +++-------------------------------------------- 2 files changed, 19 insertions(+), 370 deletions(-) delete mode 100644 requirements.in diff --git a/requirements.in b/requirements.in deleted file mode 100644 index 86841a57..00000000 --- a/requirements.in +++ /dev/null @@ -1,19 +0,0 @@ -absl-py -coverage -flax -grpcio -jax -jaxlib -numpy -portpicker -prometheus-client -pytest -seqio -tiktoken -blobfile -parameterized -shortuuid -fastapi -uvicorn -# For profiling -tensorboard-plugin-profile \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 67e31fdd..86841a57 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,351 +1,19 @@ -# -# This file is autogenerated by pip-compile with Python 3.10 -# by the following command: -# -# pip-compile requirements.in -# -absl-py==1.4.0 - # via - # -r requirements.in - # array-record - # chex - # clu - # etils - # ml-collections - # optax - # orbax-checkpoint - # seqio - # tensorboard - # tensorflow - # tensorflow-metadata - # tfds-nightly -anyio==3.7.1 - # via - # fastapi - # starlette -array-record==0.5.0 - # via tfds-nightly -astunparse==1.6.3 - # via tensorflow -blobfile==2.1.1 - # via -r requirements.in -cachetools==5.3.2 - # via google-auth -certifi==2024.7.4 - # via requests -charset-normalizer==3.3.2 - # via requests -chex==0.1.7 - # via optax -click==8.1.7 - # via - # tfds-nightly - # uvicorn -clu==0.0.10 - # via seqio -contextlib2==21.6.0 - # via ml-collections -coverage==7.4.4 - # via -r requirements.in -dm-tree==0.1.8 - # via - # chex - # tfds-nightly -docstring-parser==0.15 - # via pyglove -editdistance==0.6.2 - # via seqio -etils[array-types,enp,epath,epy,etqdm,etree]==1.6.0 - # via - # array-record - # clu - # orbax-checkpoint - # tfds-nightly -exceptiongroup==1.2.0 - # via - # anyio - # pytest -fastapi==0.103.2 - # via -r requirements.in -filelock==3.14.0 - # via blobfile -flatbuffers==23.5.26 - # via tensorflow -flax==0.8.0 - # via - # -r requirements.in - # clu -fsspec==2023.12.2 - # via etils -gast==0.4.0 - # via tensorflow -google-auth==2.27.0 - # via - # google-auth-oauthlib - # tensorboard -google-auth-oauthlib==1.0.0 - # via tensorboard -google-pasta==0.2.0 - # via tensorflow -googleapis-common-protos==1.62.0 - # via tensorflow-metadata -grpcio==1.60.1 - # via - # -r requirements.in - # tensorboard - # tensorflow -gviz-api==1.10.0 - # via tensorboard-plugin-profile -h11==0.14.0 - # via uvicorn -h5py==3.10.0 - # via tensorflow -idna==3.7 - # via - # anyio - # requests -importlib-resources==6.1.1 - # via etils -iniconfig==2.0.0 - # via pytest -jax==0.4.23 - # via - # -r requirements.in - # chex - # clu - # flax - # optax - # orbax-checkpoint - # seqio -jaxlib==0.4.23 - # via - # -r requirements.in - # chex - # clu - # optax - # orbax-checkpoint - # seqio -keras==2.13.1 - # via tensorflow -libclang==16.0.6 - # via tensorflow -lxml==4.9.4 - # via blobfile -markdown==3.5.2 - # via tensorboard -markdown-it-py==3.0.0 - # via rich -markupsafe==2.1.5 - # via werkzeug -mdurl==0.1.2 - # via markdown-it-py -ml-collections==0.1.1 - # via clu -ml-dtypes==0.3.2 - # via - # jax - # jaxlib - # tensorstore -msgpack==1.0.7 - # via - # flax - # orbax-checkpoint -nest-asyncio==1.6.0 - # via orbax-checkpoint -numpy==1.23.1 - # via - # -r requirements.in - # chex - # clu - # etils - # flax - # h5py - # jax - # jaxlib - # ml-dtypes - # opt-einsum - # optax - # orbax-checkpoint - # scipy - # seqio - # tensorboard - # tensorflow - # tensorflow-hub - # tensorstore - # tfds-nightly -oauthlib==3.2.2 - # via requests-oauthlib -opt-einsum==3.3.0 - # via - # jax - # tensorflow -optax==0.1.8 - # via flax -orbax-checkpoint==0.5.2 - # via flax -packaging==23.2 - # via - # clu - # pytest - # seqio - # tensorflow -parameterized==0.9.0 - # via -r requirements.in -pluggy==1.4.0 - # via pytest -portpicker==1.6.0 - # via -r requirements.in -prometheus-client==0.20.0 - # via -r requirements.in -promise==2.3 - # via tfds-nightly -protobuf==3.20.3 - # via - # googleapis-common-protos - # orbax-checkpoint - # seqio - # tensorboard - # tensorboard-plugin-profile - # tensorflow - # tensorflow-hub - # tensorflow-metadata - # tfds-nightly -psutil==5.9.8 - # via - # portpicker - # tfds-nightly -pyasn1==0.5.1 - # via - # pyasn1-modules - # rsa -pyasn1-modules==0.3.0 - # via google-auth -pycryptodomex==3.20.0 - # via blobfile -pydantic==1.10.17 - # via fastapi -pyglove==0.4.4 - # via seqio -pygments==2.17.2 - # via rich -pytest==8.1.1 - # via -r requirements.in -pyyaml==6.0.1 - # via - # flax - # ml-collections - # orbax-checkpoint -regex==2024.4.28 - # via tiktoken -requests==2.32.0 - # via - # requests-oauthlib - # tensorboard - # tfds-nightly - # tiktoken -requests-oauthlib==1.3.1 - # via google-auth-oauthlib -rich==13.7.0 - # via flax -rsa==4.9 - # via google-auth -scipy==1.12.0 - # via - # jax - # jaxlib -sentencepiece==0.1.99 - # via seqio -seqio==0.0.19 - # via -r requirements.in -shortuuid==1.0.13 - # via -r requirements.in -six==1.16.0 - # via - # astunparse - # google-pasta - # gviz-api - # ml-collections - # promise - # tensorboard-plugin-profile - # tensorflow -sniffio==1.3.1 - # via anyio -starlette==0.27.0 - # via fastapi -tensorboard==2.13.0 - # via tensorflow -tensorboard-data-server==0.7.2 - # via tensorboard -tensorboard-plugin-profile==2.15.1 - # via -r requirements.in -tensorflow==2.13.1 - # via tensorflow-text -tensorflow-estimator==2.13.0 - # via tensorflow -tensorflow-hub==0.16.1 - # via tensorflow-text -tensorflow-io-gcs-filesystem==0.35.0 - # via tensorflow -tensorflow-metadata==1.14.0 - # via tfds-nightly -tensorflow-text==2.13.0 - # via seqio -tensorstore==0.1.52 - # via - # flax - # orbax-checkpoint -termcolor==2.4.0 - # via - # tensorflow - # tfds-nightly -tf-keras==2.15.0 - # via tensorflow-hub -tfds-nightly==4.9.2.dev202308090034 - # via seqio -tiktoken==0.6.0 - # via -r requirements.in -toml==0.10.2 - # via tfds-nightly -tomli==2.0.1 - # via pytest -toolz==0.12.1 - # via chex -tqdm==4.66.3 - # via - # etils - # tfds-nightly -typing-extensions==4.5.0 - # via - # chex - # clu - # etils - # fastapi - # flax - # orbax-checkpoint - # pydantic - # tensorflow - # uvicorn -urllib3==2.2.2 - # via - # blobfile - # requests -uvicorn==0.30.1 - # via -r requirements.in -werkzeug==3.0.1 - # via - # tensorboard - # tensorboard-plugin-profile -wheel==0.42.0 - # via - # astunparse - # tensorboard -wrapt==1.16.0 - # via - # clu - # tensorflow - # tfds-nightly -zipp==3.19.1 - # via etils - -# The following packages are considered to be unsafe in a requirements file: -# setuptools +absl-py +coverage +flax +grpcio +jax +jaxlib +numpy +portpicker +prometheus-client +pytest +seqio +tiktoken +blobfile +parameterized +shortuuid +fastapi +uvicorn +# For profiling +tensorboard-plugin-profile \ No newline at end of file From e61532d3512aa12c2097824f1d89af0fb73c7ac4 Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Mon, 12 Aug 2024 16:55:43 -0700 Subject: [PATCH 31/42] Update deps file (#130) --- MANIFEST.in | 2 +- Makefile | 8 +------- README.md | 2 +- setup.py | 2 +- 4 files changed, 4 insertions(+), 10 deletions(-) diff --git a/MANIFEST.in b/MANIFEST.in index 9d4615bd..540b7204 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1 @@ -include requirements.in \ No newline at end of file +include requirements.txt \ No newline at end of file diff --git a/Makefile b/Makefile index a7699a53..a8a88085 100644 --- a/Makefile +++ b/Makefile @@ -2,15 +2,9 @@ PYTHON := python PIP := $(PYTHON) -m pip GRPC_TOOLS_VERSION := 1.62.1 -all: update-and-install-deps generate-protos format check +all: install-deps generate-protos format check # Dependency management targets -update-and-install-deps: update-deps install-deps - -update-deps: - $(PIP) install pip-tools - $(PYTHON) -m piptools compile requirements.in - install-deps: $(PIP) install pytype pylint pyink -r requirements.txt -r benchmarks/requirements.in diff --git a/README.md b/README.md index 2159c0fd..62959c46 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ Currently, there are two reference engine implementations available -- one for J ### Setup ``` -make update-and-install-deps +make install-deps ``` ### Run local server & Testing diff --git a/setup.py b/setup.py index c4efd21e..55b91c3b 100644 --- a/setup.py +++ b/setup.py @@ -39,5 +39,5 @@ def parse_requirements(filename): "Operating System :: OS Independent", ], python_requires=">=3.10", - install_requires=parse_requirements("requirements.in"), + install_requires=parse_requirements("requirements.txt"), ) From 59538fc0512aef6d0728bd583f604f0bc9118d16 Mon Sep 17 00:00:00 2001 From: vivianrwu Date: Wed, 14 Aug 2024 15:53:09 -0700 Subject: [PATCH 32/42] Manual model warmup to resolve AOT model warmup performance degradation (#126) * Implement manual model warmup to resolve performance degradation * fix insert generate compiled * remove check for JetStreamEngine in orchestrator * pyink pylint fixes * change references from aot to warmup * fix non-empty comparison * use all() to check True in entire lists --- jetstream/core/orchestrator.py | 12 --- jetstream/core/server_lib.py | 8 +- jetstream/engine/engine_api.py | 22 +---- .../engine/{aot_utils.py => warmup_utils.py} | 89 +++++-------------- 4 files changed, 32 insertions(+), 99 deletions(-) rename jetstream/engine/{aot_utils.py => warmup_utils.py} (69%) diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index cefabd05..a0c77c85 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -143,7 +143,6 @@ class ActiveRequest: prefill_result: Any = None #################### Information relevant for prefill ######################## prefill_content: Optional[str | list[int]] = None - padded_token_length: Optional[int] = None ################## Information relevant for detokenization ################### # Which generate step this was added at. generate_timestep_added: Optional[int] = None @@ -513,11 +512,6 @@ def _prefill_thread(self, idx: int): padded_tokens, true_length = self._process_prefill_content( request, tokenizer, is_bos, prefill_engine.max_prefill_length ) - if isinstance(prefill_engine, engine_api.JetStreamEngine): - request.padded_token_length = token_utils.take_nearest_length( - prefill_engine.prefill_buckets, true_length - ) - prefill_engine.set_padded_token_length(request.padded_token_length) # Compute new kv cache for the prefill_content. prefill_result, first_token = prefill_engine.prefill( @@ -525,7 +519,6 @@ def _prefill_thread(self, idx: int): padded_tokens=padded_tokens, true_length=true_length, ) - request.prefill_result = prefill_result # put first token to detokenize queue @@ -722,11 +715,6 @@ def _generate_thread(self, idx: int): generate_timestep, ) - if isinstance(generate_engine, engine_api.JetStreamEngine): - generate_engine.set_padded_token_length( - new_request.padded_token_length - ) - decode_state = generate_engine.insert( new_request.prefill_result, decode_state, slot=slot ) diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index 22180f09..b323286a 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -34,7 +34,7 @@ from jetstream.core import orchestrator from jetstream.core.metrics.prometheus import JetstreamMetricsCollector from jetstream.core.proto import jetstream_pb2_grpc -from jetstream.engine import aot_utils, engine_api +from jetstream.engine import warmup_utils, engine_api from prometheus_client import start_http_server @@ -107,7 +107,7 @@ def create_driver( devices: Device objects, will be used to get engine with proper slicing. jax_padding: The flag to enable JAX padding during tokenization. metrics_collector: The JetStream Promethus metric collector. - enable_model_warmup: The flag to enable model server warmup with AOT. + enable_model_warmup: The flag to enable model server warmup. Returns: An orchestrator driver. @@ -142,7 +142,7 @@ def create_driver( ] try: - _ = aot_utils.layout_params_and_compile_executables( + _ = warmup_utils.layout_params_and_compile_executables( prefill_engines, # pylint: disable=protected-access generate_engines, # pylint: disable=protected-access prefill_params, # pylint: disable=protected-access @@ -191,7 +191,7 @@ def run( metrics_server_config: The config to enable Promethus metric server. enable_jax_profiler: The flag to enable JAX profiler server. jax_profiler_port: The port JAX profiler server (default to 9999). - enable_model_warmup: The flag to enable model server warmup with AOT. + enable_model_warmup: The flag to enable model server warmup. Returns: JetStreamServer that wraps the grpc server and orchestrator driver. diff --git a/jetstream/engine/engine_api.py b/jetstream/engine/engine_api.py index cba42939..5277f6df 100644 --- a/jetstream/engine/engine_api.py +++ b/jetstream/engine/engine_api.py @@ -257,22 +257,13 @@ def colocated_cpus(self) -> Union[list[CpuDevices], None]: class JetStreamEngine(Engine): """A wrapper engine of the Engine class. - JetStreamEngine defines the AOT warmed up model server engine. + JetStreamEngine defines the warmed up model server engine. """ def __init__(self, downstream_engine: Engine): self._downstream_engine = downstream_engine - # Executables - self.prefill_executable = None - self.insert_executable = None - self.generate_executable = None - self.prefill_buckets = None - - # Nearest right token length - self._padded_token_length = None - self.warm = False def prefill( @@ -284,9 +275,7 @@ def prefill( true_length: int, ) -> Tuple[Prefix, ResultTokens]: - prefill_result, first_token = self.prefill_executable[ - self.padded_token_length - ]( + prefill_result, first_token = self._downstream_engine.prefill( params=params, padded_tokens=padded_tokens, true_length=true_length, @@ -300,7 +289,7 @@ def insert( slot: int, ) -> DecodeState: - decode_state = self.insert_executable[self.padded_token_length]( + decode_state = self._downstream_engine.insert( prefix=prefix, decode_state=decode_state, slot=slot, @@ -310,7 +299,7 @@ def insert( def generate( self, params: Params, decode_state: DecodeState ) -> Tuple[DecodeState, ResultTokens]: - decode_state, sampled_tokens = self.generate_executable( # pylint: disable=not-callable + decode_state, sampled_tokens = self._downstream_engine.generate( params=params, decode_state=decode_state ) return decode_state, sampled_tokens @@ -355,6 +344,3 @@ def mesh(self) -> jax.sharding.Mesh: @property def colocated_cpus(self) -> Union[list[CpuDevices], None]: return self._downstream_engine.colocated_cpus - - def set_padded_token_length(self, padded_token_length: int): - self.padded_token_length = padded_token_length diff --git a/jetstream/engine/aot_utils.py b/jetstream/engine/warmup_utils.py similarity index 69% rename from jetstream/engine/aot_utils.py rename to jetstream/engine/warmup_utils.py index 65b61f87..6bf7c26a 100644 --- a/jetstream/engine/aot_utils.py +++ b/jetstream/engine/warmup_utils.py @@ -12,12 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""AOT compilation utils.""" +"""Model server warmup utils.""" -import jax import jax.numpy as jnp import concurrent.futures -from typing import Any, Optional, cast +from typing import Any, Optional import logging from jetstream.engine import engine_api, token_utils @@ -44,34 +43,30 @@ def layout_params_and_compile_executables( any_prefill_engine = None any_prefill_params = None - prefill_executables = [] - inserts_generate_executables = [] + prefills_compiled = [] + inserts_generate_compiled = [] for i, pe in enumerate(prefill_engines): any_prefill_engine = pe any_prefill_params = prefill_params[i] - prefill_executable = initialize_prefill_jit_cache( + prefill_compiled = initialize_prefill_jit_cache( prefill_engine=pe, prefill_params=prefill_params[i], prefill_idx=i, ) - prefill_executables.append(prefill_executable) + prefills_compiled.append(prefill_compiled) for i, ge in enumerate(generate_engines): - insert_executable, generate_executable = ( - initialize_insert_generate_jit_cache( - prefill_engine=any_prefill_engine, - generate_engine=ge, - prefill_params=any_prefill_params, - generate_params=generate_params[i], - generate_idx=i, - ) - ) - inserts_generate_executables.append( - [insert_executable, generate_executable] + insert_generate_compiled = initialize_insert_generate_jit_cache( + prefill_engine=any_prefill_engine, + generate_engine=ge, + prefill_params=any_prefill_params, + generate_params=generate_params[i], + generate_idx=i, ) + inserts_generate_compiled.append([insert_generate_compiled]) - if prefill_executables and inserts_generate_executables: + if all(prefills_compiled) and all(inserts_generate_compiled): return True return False @@ -104,47 +99,32 @@ def initialize_prefill_jit_cache( def compile_prefill(length): padded_tokens, true_length = jnp.ones((length), dtype="int32"), length - lowered = jax.jit( - prefill_engine._downstream_engine.prefill, # pylint: disable=protected-access - out_shardings=prefill_engine.get_prefix_destination_sharding(), - ).lower( + _, _ = prefill_engine._downstream_engine.prefill( # pylint: disable=protected-access params=prefill_params, padded_tokens=padded_tokens, true_length=true_length, ) - logging.info( - "---------Prefill engine %d lowered for prefill length %d.---------", - prefill_idx, - length, - ) - compiled = lowered.compile() + logging.info( "---------Prefill engine %d compiled for prefill length %d.---------", prefill_idx, length, ) - return compiled logging.info("---------Prefill compilation %d begun.---------", prefill_idx) with concurrent.futures.ThreadPoolExecutor( max_workers=len(prefill_buckets) ) as executor: - prefill_executable = list(executor.map(compile_prefill, prefill_buckets)) - - prefill_executable = { - k: cast(jax.stages.Compiled, e) - for k, e in zip(prefill_buckets, prefill_executable) - } + _ = executor.map(compile_prefill, prefill_buckets) - prefill_engine.prefill_executable = prefill_executable prefill_engine.warm = True logging.info( "---------Prefill compilation %d complete.---------", prefill_idx ) - return prefill_executable + return prefill_engine.warm def initialize_insert_generate_jit_cache( @@ -184,22 +164,13 @@ def compile_insert(length): true_length=true_length, ) - lowered = jax.jit(generate_engine._downstream_engine.insert).lower( # pylint: disable=protected-access - prefix=prefill, decode_state=decode_state, slot=1 - ) - logging.info( - "---------Generate engine %d lowered for insert length %d.---------", - generate_idx, - length, - ) - compiled = lowered.compile() + generate_engine.insert(prefix=prefill, decode_state=decode_state, slot=0) logging.info( "---------Generate engine %d compiled for insert length %d.---------", generate_idx, length, ) - return compiled def compile_generate(): @@ -207,16 +178,11 @@ def compile_generate(): "---------Generate compilation %d begun.---------", generate_idx ) - lowered = jax.jit(generate_engine._downstream_engine.generate).lower( # pylint: disable=protected-access + generate_engine._downstream_engine.generate( # pylint: disable=protected-access params=generate_params, decode_state=decode_state, ) - logging.info( - "---------Generate engine %d lowered.---------", - generate_idx, - ) - compiled = lowered.compile() logging.info( "---------Generate engine %d compiled.---------", generate_idx, @@ -226,30 +192,23 @@ def compile_generate(): "---------Generate compilation %d complete.---------", generate_idx ) - return compiled - logging.info( "---------Insertion generation compilation %d begun.---------", generate_idx, ) - generate_executable = compile_generate() + compile_generate() + logging.info( "---------Generate engine %d compiled generation step.---------", generate_idx, ) - generate_engine.generate_executable = generate_executable with concurrent.futures.ThreadPoolExecutor( max_workers=len(prefill_buckets) ) as executor: - insert_executable = list(executor.map(compile_insert, prefill_buckets)) + _ = executor.map(compile_insert, prefill_buckets) - insert_executable = { - k: cast(jax.stages.Compiled, e) - for k, e in zip(prefill_buckets, insert_executable) - } - generate_engine.insert_executable = insert_executable generate_engine.warm = True logging.info( @@ -257,4 +216,4 @@ def compile_generate(): generate_idx, ) - return insert_executable, generate_executable + return generate_engine.warm From 647ab2459bc5ae6c7eda078a3706b1fdb511cb66 Mon Sep 17 00:00:00 2001 From: Andy Ye Date: Fri, 23 Aug 2024 19:16:13 -0400 Subject: [PATCH 33/42] Update JetStream instructions (#132) --- docs/online-inference-with-maxtext-engine.md | 179 ++++++++++++++---- .../tools/maxtext/model_ckpt_conversion.sh | 12 +- 2 files changed, 143 insertions(+), 48 deletions(-) diff --git a/docs/online-inference-with-maxtext-engine.md b/docs/online-inference-with-maxtext-engine.md index 96c9db81..60173e3c 100644 --- a/docs/online-inference-with-maxtext-engine.md +++ b/docs/online-inference-with-maxtext-engine.md @@ -21,11 +21,11 @@ Follow the steps in [Manage TPU resources | Google Cloud](https://cloud.google.c ## Step 1: Download JetStream and the MaxText github repository ```bash -git clone -b jetstream-v0.2.2 https://github.com/google/maxtext.git -git clone -b v0.2.2 https://github.com/google/JetStream.git +git clone https://github.com/google/maxtext.git +git clone https://github.com/google/JetStream.git ``` -## Step 2: Setup MaxText +## Step 2: Setup MaxText and JetStream ```bash # Create a python virtual environment for the demo. @@ -36,6 +36,12 @@ source .env/bin/activate # Setup MaxText. cd maxtext/ bash setup.sh + +# Setup JetStream +cd JetStream +pip install -e . +cd benchmarks +pip install -r requirements.in ``` ## Step 3: Convert Model Checkpoints @@ -45,16 +51,16 @@ You can run the JetStream MaxText Server with Gemma and Llama2 models. This sect ### Use a Gemma model checkpoint * You can download a [Gemma checkpoint from Kaggle](https://www.kaggle.com/models/google/gemma/frameworks/maxText/variations/7b). -* After downloading checkpoints, copy them to your GCS bucket at `$CHKPT_BUCKET`. +* After downloading orbax Gemma checkpoints, copy them to your GCS bucket at `$CHKPT_BUCKET`. You should also set two more paths `$MAXTEXT_BUCKET_SCANNED` and `$MAXTEXT_BUCKET_UNSCANNED` that point to the locations of the maxtext checkpoints for the scanned and unscanned (inference-optimized) versions, respectively. * `gsutil -m cp -r ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}` * Please refer to the [conversion script](https://github.com/google/JetStream/blob/main/jetstream/tools/maxtext/model_ckpt_conversion.sh) for an example of `$CHKPT_BUCKET`. * Then, using the following command to convert the Gemma checkpoint into a MaxText compatible unscanned checkpoint. ```bash -# bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh ${MODEL} ${MODEL_VARIATION} ${CHKPT_BUCKET} +# bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh ${MODEL} ${MODEL_VARIATION} ${CHKPT_BUCKET} ${MAXTEXT_BUCKET_SCANNED} ${MAXTEXT_BUCKET_UNSCANNED} # For gemma-7b -bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh gemma 7b ${CHKPT_BUCKET} +bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh gemma 7b ${CHKPT_BUCKET} ${MAXTEXT_BUCKET_SCANNED} ${MAXTEXT_BUCKET_UNSCANNED} ``` Note: For more information about the Gemma model and checkpoints, see [About Gemma](https://github.com/google/maxtext/blob/main/end_to_end/gemma/Run_Gemma.md). @@ -63,25 +69,25 @@ Note: For more information about the Gemma model and checkpoints, see [About Gem ### Use a Llama2 model checkpoint * You can use a Llama2 checkpoint you have generated or one from [the open source community](https://llama.meta.com/llama-downloads/). -* After downloading checkpoints, copy them to your GCS bucket at `$CHKPT_BUCKET`. +* After downloading PyTorch checkpoints, copy them to your GCS bucket at `$CHKPT_BUCKET`. You should also set two more paths `$MAXTEXT_BUCKET_SCANNED` and `$MAXTEXT_BUCKET_UNSCANNED` that point to the locations of the maxtext checkpoints for the scanned and unscanned (inference-optimized) versions, respectively. * `gsutil -m cp -r ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}` * Please refer to the [conversion script](https://github.com/google/JetStream/blob/main/jetstream/tools/maxtext/model_ckpt_conversion.sh) for an example of `$CHKPT_BUCKET`. * Then, using the following command to convert the Llama2 checkpoint into a MaxText compatible unscanned checkpoint. ```bash -# bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh ${MODEL} ${MODEL_VARIATION} ${CHKPT_BUCKET} +# bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh ${MODEL} ${MODEL_VARIATION} ${CHKPT_BUCKET} ${MAXTEXT_BUCKET_SCANNED} ${MAXTEXT_BUCKET_UNSCANNED} # For llama2-7b -bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET} +bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET} ${MAXTEXT_BUCKET_SCANNED} ${MAXTEXT_BUCKET_UNSCANNED} # For llama2-13b -bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 13b ${CHKPT_BUCKET} +bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 13b ${CHKPT_BUCKET} ${MAXTEXT_BUCKET_SCANNED} ${MAXTEXT_BUCKET_UNSCANNED} ``` Note: For more information about the Llama2 model and checkpoints, see [About Llama2](https://github.com/google/maxtext/blob/main/getting_started/Run_Llama2.md). -## Step4: Run the JetStream MaxText server +## Step 4: Run the JetStream MaxText server ### Create model config environment variables for server flags @@ -104,8 +110,8 @@ export MAX_PREFILL_PREDICT_LENGTH=1024 export MAX_TARGET_LENGTH=2048 export MODEL_NAME=gemma-7b export ICI_FSDP_PARALLELISM=1 -export ICI_AUTOREGRESSIVE_PARALLELISM=-1 -export ICI_TENSOR_PARALLELISM=1 +export ICI_AUTOREGRESSIVE_PARALLELISM=1 +export ICI_TENSOR_PARALLELISM=-1 export SCAN_LAYERS=false export WEIGHT_DTYPE=bfloat16 export PER_DEVICE_BATCH_SIZE=11 @@ -122,8 +128,8 @@ export MAX_PREFILL_PREDICT_LENGTH=1024 export MAX_TARGET_LENGTH=2048 export MODEL_NAME=llama2-7b export ICI_FSDP_PARALLELISM=1 -export ICI_AUTOREGRESSIVE_PARALLELISM=-1 -export ICI_TENSOR_PARALLELISM=1 +export ICI_AUTOREGRESSIVE_PARALLELISM=1 +export ICI_TENSOR_PARALLELISM=-1 export SCAN_LAYERS=false export WEIGHT_DTYPE=bfloat16 export PER_DEVICE_BATCH_SIZE=11 @@ -131,8 +137,6 @@ export PER_DEVICE_BATCH_SIZE=11 #### Create Llama2-13b environment variables for server flags - - * Configure the [flags](#jetstream-maxtext-server-flag-descriptions) passing into the JetStream MaxText server ```bash @@ -142,8 +146,8 @@ export MAX_PREFILL_PREDICT_LENGTH=1024 export MAX_TARGET_LENGTH=2048 export MODEL_NAME=llama2-13b export ICI_FSDP_PARALLELISM=1 -export ICI_AUTOREGRESSIVE_PARALLELISM=-1 -export ICI_TENSOR_PARALLELISM=1 +export ICI_AUTOREGRESSIVE_PARALLELISM=1 +export ICI_TENSOR_PARALLELISM=-1 export SCAN_LAYERS=false export WEIGHT_DTYPE=bfloat16 export PER_DEVICE_BATCH_SIZE=4 @@ -187,7 +191,8 @@ python MaxText/maxengine_server.py \ Note: these flags are from [MaxText config](https://github.com/google/maxtext/blob/f9e04cdc1eec74a0e648411857c09403c3358461/MaxText/configs/base.yml) -## Step 5: Send test request to JetStream MaxText server +## Step 5: Send a test request to JetStream MaxText server +In a new tab in your terminal, run the following command ```bash cd ~ @@ -207,34 +212,125 @@ Response: to be a fan ## Step 6: Run benchmarks with JetStream MaxText server -Note: The JetStream MaxText Server is not running with quantization optimization in Step 3. To get best benchmark results, we need to enable quantization (Please use AQT trained or fine tuned checkpoints to ensure accuracy) for both weights and KV cache, please add the quantization flags and restart the server as following: +Note: The JetStream MaxText Server commands from Step 4 are not running with any quantization optimizations. To get the best benchmark results, we need to enable quantization for weights and KV cache. To do this, first generate AQT trained or fine-tuned checkpoints. Then, add the quantization flags and restart the server. + +### Generating a quantized checkpoint + +First, define the path to which the quantized checkpoint +```bash +export SAVE_QUANT_PARAMS_PATH=gs://${USER}-bkt/quantized/llama2-7b-chat +``` + +There are several different quantization configurations to choose from: +#### int8 DRQ quantized checkpoint ```bash -# Enable int8 quantization for both weights and KV cache +python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${LOAD_PARAMETERS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=llama2-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=11 attention=dot_product quantization=int8 save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH} +``` + +#### Weights-only int8 quantized checkpoint +```bash +python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${LOAD_PARAMETERS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=llama2-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=11 attention=dot_product quantization=int8w save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH} +``` + +#### Mixed precision weight-only quantized checkpoint +First, update the mixed precision config file (`MaxText/configs/quantization/mp_scale.json`) in MaxText repo to the mixed-precision-config defined below. +``` +{ + ".*/query": {"bits": 4, "scale": 0.8}, + ".*/key": {"bits": 4, "scale": 0.9}, + ".*/value": {"bits": 8}, + ".*/out": {"bits": 4}, + ".*/wi_0": {"bits": 4}, + ".*/wo": {"bits": 8} +} +``` +Then run the following command: +```bash +python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${LOAD_PARAMETERS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=llama2-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=11 attention=dot_product quantization=intmp +quant_cfg_path=configs/quantization/mp_scale.json save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH} +``` + +### Restart the server with quantization flags + +#### Set flags + +Setting base quantization flags +```bash +# To load an int8 DRQcheckpoint export QUANTIZATION=int8 -export QUANTIZE_KVCACHE=true +export LOAD_PARAMETERS_PATH${SAVE_QUANT_PARAMS_PATH} +export CHECKPOINT_IS_QUANTIZED=True + +# To load an int8 weight-only checkpoint +export QUANTIZATION=int8w +export LOAD_PARAMETERS_PATH${SAVE_QUANT_PARAMS_PATH} +export CHECKPOINT_IS_QUANTIZED=True + +# To load a Mixed-Precision quantized checkpoint +# If using Mixed-Precision mode, make sure to update the mixed precision config file to the same file as used for quantizing the checkpoint (MaxText/configs/quantization/mp_scale.json) +export QUANTIZATION=intmp +export LOAD_PARAMETERS_PATH${SAVE_QUANT_PARAMS_PATH} +export CHECKPOINT_IS_QUANTIZED=True +export QUANT_CFG_PATH=configs/quantization/mp_scale.json +``` +The KV-cache is quantized to int8 by using the following config params +```bash +export QUANTIZE_KVCACHE=True +``` +If you don't want to quantize the KV-cache, set +```bash +export QUANTIZE_KVCACHE=False +``` + + +#### Restart server +```bash # For Gemma 7b model, change per_device_batch_size to 12 to optimize performance. export PER_DEVICE_BATCH_SIZE=12 cd ~/maxtext python MaxText/maxengine_server.py \ -MaxText/configs/base.yml \ -tokenizer_path=${TOKENIZER_PATH} \ -load_parameters_path=${LOAD_PARAMETERS_PATH} \ -max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \ -max_target_length=${MAX_TARGET_LENGTH} \ -model_name=${MODEL_NAME} \ -ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \ -ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \ -ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \ -scan_layers=${SCAN_LAYERS} \ -weight_dtype=${WEIGHT_DTYPE} \ -per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ -quantization=${QUANTIZATION} \ -quantize_kvcache=${QUANTIZE_KVCACHE} + MaxText/configs/base.yml \ + tokenizer_path=${TOKENIZER_PATH} \ + load_parameters_path=${LOAD_PARAMETERS_PATH} \ + max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \ + max_target_length=${MAX_TARGET_LENGTH} \ + model_name=${MODEL_NAME} \ + ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \ + ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \ + ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \ + scan_layers=${SCAN_LAYERS} \ + weight_dtype=${WEIGHT_DTYPE} \ + per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ + quantization=${QUANTIZATION} \ + quantize_kvcache=${QUANTIZE_KVCACHE} \ + checkpoint_is_quantized=${CHECKPOINT_IS_QUANTIZED} ``` +For the mixed precision quantized model +```bash +python MaxText/maxengine_server.py \ + MaxText/configs/base.yml \ + tokenizer_path=${TOKENIZER_PATH} \ + load_parameters_path=${LOAD_PARAMETERS_PATH} \ + max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \ + max_target_length=${MAX_TARGET_LENGTH} \ + model_name=${MODEL_NAME} \ + ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \ + ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \ + ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \ + scan_layers=${SCAN_LAYERS} \ + weight_dtype=${WEIGHT_DTYPE} \ + per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ + quantization=${QUANTIZATION} \ + quantize_kvcache=${QUANTIZE_KVCACHE} \ + checkpoint_is_quantized=${CHECKPOINT_IS_QUANTIZED} \ + quant_cfg_path=${QUANT_CFG_PATH} +``` + + ### Benchmarking Gemma-7b Instructions @@ -261,11 +357,12 @@ python JetStream/benchmarks/benchmark_serving.py \ --request-rate 5 \ --warmup-mode sampled ``` +For details, please see https://github.com/google/JetStream/blob/main/benchmarks/README.md -### Benchmarking Llama2-\*b +### Benchmarking Llama2 ```bash -# Same as Gemma-7b except for the tokenizer (must use a tokenizer that matches your model, which should now be tokenizer.llama2). +# The command is the same as that for the Gemma-7b, except for the tokenizer. Since we need to use a tokenizer that matches the model, it should now be tokenizer.llama2. python JetStream/benchmarks/benchmark_serving.py \ --tokenizer maxtext/assets/tokenizer.llama2 \ @@ -276,6 +373,7 @@ python JetStream/benchmarks/benchmark_serving.py \ --request-rate 5 \ --warmup-mode sampled ``` +For details, please see https://github.com/google/JetStream/blob/main/benchmarks/README.md ## Clean Up @@ -283,10 +381,11 @@ python JetStream/benchmarks/benchmark_serving.py \ # Clean up gcs buckets. gcloud storage buckets delete ${MODEL_BUCKET} gcloud storage buckets delete ${BASE_OUTPUT_DIRECTORY} -gcloud storage buckets delete ${DATASET_PATH} + # Clean up repositories. rm -rf maxtext rm -rf JetStream + # Clean up python virtual environment rm -rf .env ``` diff --git a/jetstream/tools/maxtext/model_ckpt_conversion.sh b/jetstream/tools/maxtext/model_ckpt_conversion.sh index 8e2b4d83..0340dbfe 100644 --- a/jetstream/tools/maxtext/model_ckpt_conversion.sh +++ b/jetstream/tools/maxtext/model_ckpt_conversion.sh @@ -28,25 +28,21 @@ export MODEL=$1 export MODEL_VARIATION=$2 export MODEL_NAME=${MODEL}-${MODEL_VARIATION} -# After downloading checkpoints, copy them to GCS bucket at $CHKPT_BUCKET \ +# After downloading checkpoints, copy them to GCS bucket at $CHKPT_BUCKET # Please use separate GCS paths for uploading open source model weights ($CHKPT_BUCKET) and MaxText compatible weights ($MODEL_BUCKET). # Point these variables to a GCS bucket that you created. # An example of CHKPT_BUCKET could be: gs://${USER}-maxtext/chkpt/${MODEL}/${MODEL_VARIATION} export CHKPT_BUCKET=$3 -export MODEL_BUCKET=gs://${USER}-maxtext +export MODEL_BUCKET=$4 -# Point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you created, this bucket will store all the files generated by MaxText during a run. -export BASE_OUTPUT_DIRECTORY=gs://${USER}-runner-maxtext-logs - -# Point `DATASET_PATH` to the GCS bucket where you have your training data. -export DATASET_PATH=gs://${USER}-maxtext-dataset +# Point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you created, this bucket will store all the files generated by MaxText during a run, specifically the unscanned checkpoint. +export BASE_OUTPUT_DIRECTORY=$5 export BUCKET_LOCATION=US # Create three GCS buckets for the demo. gcloud storage buckets create ${MODEL_BUCKET} --location=${BUCKET_LOCATION} || true gcloud storage buckets create ${BASE_OUTPUT_DIRECTORY} --location=${BUCKET_LOCATION} || true -gcloud storage buckets create ${DATASET_PATH} --location=${BUCKET_LOCATION} || true # Convert model checkpoints to MaxText compatible checkpoints. if [ "$MODEL" == "gemma" ]; then From 9e7fc3284670642e813614ca5f70bc70d62fa4dd Mon Sep 17 00:00:00 2001 From: qihqi Date: Tue, 27 Aug 2024 15:31:32 -0700 Subject: [PATCH 34/42] Add an optional parameter for sampling in prefill / sample. (#133) * Add an optional parameter for sampling in prefill / sample. This is needed because we want to enable per-request sampling parameters. This allows jetstream to be used as backend for HuggingFace TGI. * lint --- jetstream/engine/engine_api.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/jetstream/engine/engine_api.py b/jetstream/engine/engine_api.py index 5277f6df..9f42b60f 100644 --- a/jetstream/engine/engine_api.py +++ b/jetstream/engine/engine_api.py @@ -19,7 +19,7 @@ """ import abc -from typing import Any, Optional, Tuple, Union +from typing import Any, Optional, Tuple, Union, Callable from flax import struct import jax @@ -142,6 +142,7 @@ def prefill( existing_prefix: Optional[Prefix] = None, padded_tokens: jax.Array, true_length: int, + sampler: Optional[Callable[[Any], Any]] = None, ) -> Tuple[Prefix, ResultTokens]: """Computes a kv-cache for a set of tokens conditional on existing cache. @@ -149,11 +150,16 @@ def prefill( processed by the underlying model. tokens is logically appended to the text represented by `existing_prefix`. This method returns a new kv_cache (typically) for the resulting text. + + If sampler is passed, then the engine should use it do sample next token. """ @abc.abstractmethod def generate( - self, params: Params, decode_state: DecodeState + self, + params: Params, + decode_state: DecodeState, + sampler: Optional[Callable[[Any], Any]] = None, ) -> Tuple[DecodeState, ResultTokens]: """Generates tokens for each sequence being decoded in parallel. @@ -165,6 +171,8 @@ def generate( consists of each microbatch progressing through every stage), in non-pipelined code this is a full forward pass. In both cases, this accounts for a full embed-layerstack-unembed-sample operation. + + If sampler is passed, then the engine should use it do sample next token. """ @abc.abstractmethod From 93de5901a19d5271ceea7de107406b6a40f52c0c Mon Sep 17 00:00:00 2001 From: jwyang-google <132702993+jwyang-google@users.noreply.github.com> Date: Wed, 28 Aug 2024 10:43:15 -0700 Subject: [PATCH 35/42] remove excessive logs in production run by changing from DEBUG to INFO (#134) --- jetstream/core/orchestrator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index a0c77c85..0fd64c5e 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -98,10 +98,10 @@ import numpy as np root = logging.getLogger() -root.setLevel(logging.DEBUG) +root.setLevel(logging.INFO) handler = logging.StreamHandler(sys.stdout) -handler.setLevel(logging.DEBUG) +handler.setLevel(logging.INFO) formatter = logging.Formatter( "%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) From 445f1aa8e857d0a09d72618e365daf80723bdf4c Mon Sep 17 00:00:00 2001 From: Zhihao Shan <60905719+zhihaoshan-google@users.noreply.github.com> Date: Thu, 5 Sep 2024 16:04:37 -0700 Subject: [PATCH 36/42] Change the default message for requester.py and move mlperf 4.1 install (#136) for proxy version support. using a more positive default message for requester.py. Co-authored-by: Zhihao Shan --- jetstream/tools/proxy_dev/base.Dockerfile | 2 -- jetstream/tools/proxy_dev/dev.Dockerfile | 1 - jetstream/tools/requester.py | 2 +- 3 files changed, 1 insertion(+), 4 deletions(-) diff --git a/jetstream/tools/proxy_dev/base.Dockerfile b/jetstream/tools/proxy_dev/base.Dockerfile index 5e4cd2e4..9162bcf0 100644 --- a/jetstream/tools/proxy_dev/base.Dockerfile +++ b/jetstream/tools/proxy_dev/base.Dockerfile @@ -22,9 +22,7 @@ RUN pip install setuptools==58 fastapi==0.103.2 uvicorn RUN pip install ./JetStream -COPY inference_mlperf4.1 ./inference_mlperf4.1 RUN apt -y update && apt-get -y install python3-dev && apt-get -y install build-essential -RUN pip install ./inference_mlperf4.1/loadgen RUN pip install \ transformers==4.31.0 \ nltk==3.8.1 \ diff --git a/jetstream/tools/proxy_dev/dev.Dockerfile b/jetstream/tools/proxy_dev/dev.Dockerfile index be7a36fc..25bf382e 100644 --- a/jetstream/tools/proxy_dev/dev.Dockerfile +++ b/jetstream/tools/proxy_dev/dev.Dockerfile @@ -11,7 +11,6 @@ ENV JAX_BACKEND_TARGET=grpc://localhost:38681 # Copy all files from local workspace into docker container COPY JetStream ./JetStream COPY maxtext ./maxtext -COPY inference_mlperf4.1 ./inference_mlperf4.1 RUN pip install ./JetStream RUN pip install -r ./maxtext/requirements.txt diff --git a/jetstream/tools/requester.py b/jetstream/tools/requester.py index 30d7ac40..7ac0d55a 100644 --- a/jetstream/tools/requester.py +++ b/jetstream/tools/requester.py @@ -26,7 +26,7 @@ _SERVER = flags.DEFINE_string("server", "0.0.0.0", "server address") _PORT = flags.DEFINE_string("port", "9000", "port to ping") -_TEXT = flags.DEFINE_string("text", "Today is a good day", "The message") +_TEXT = flags.DEFINE_string("text", "My dog is cute", "The message") _MAX_TOKENS = flags.DEFINE_integer( "max_tokens", 3, "Maximum number of output/decode tokens of a sequence" ) From 530d364e18d5759d7afd538bbb1d767de10a3f4c Mon Sep 17 00:00:00 2001 From: vivianrwu Date: Thu, 26 Sep 2024 14:00:08 -0700 Subject: [PATCH 37/42] Change previewutilities -> pathwaysutils (#138) --- jetstream/engine/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jetstream/engine/__init__.py b/jetstream/engine/__init__.py index ee979964..99bf4983 100644 --- a/jetstream/engine/__init__.py +++ b/jetstream/engine/__init__.py @@ -17,7 +17,7 @@ import jax try: - import previewutilities + import pathwaysutils except ImportError as e: print("Proxy backend support is not added") pass From 52d63a5f1c39a44432c7c4af0525672c2bc93734 Mon Sep 17 00:00:00 2001 From: Ran Ran Date: Mon, 4 Nov 2024 15:01:26 -0800 Subject: [PATCH 38/42] Add option to use hf tokenizer (#147) --- benchmarks/benchmark_serving.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 6076beba..a738b8b1 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -79,6 +79,7 @@ import pandas from eval_accuracy import eval_accuracy +from transformers import AutoTokenizer def str2bool(v: str) -> bool: @@ -156,16 +157,29 @@ def to_dict(self): } -def get_tokenizer(model_id: str, tokenizer_name: str) -> Any: +def get_tokenizer( + model_id: str, + tokenizer_name: str, + use_hf_tokenizer: bool, +) -> Any: """Return a tokenizer or a tokenizer placholder.""" if tokenizer_name == "test": + print("Using test tokenizer") return "test" + elif use_hf_tokenizer: + # Please accept agreement to access private/gated models in HF, and + # follow up instructions below to set up access token + # https://huggingface.co/docs/transformers.js/en/guides/private + print(f"Using HuggingFace tokenizer: {tokenizer_name}") + return AutoTokenizer.from_pretrained(tokenizer_name) elif model_id == "llama-3": # Llama 3 uses a tiktoken tokenizer. + print(f"Using llama-3 tokenizer: {tokenizer_name}") return llama3_tokenizer.Tokenizer(tokenizer_name) else: # Use JetStream tokenizer util. It's using the sentencepiece wrapper in # seqio library. + print(f"Using tokenizer: {tokenizer_name}") vocab = load_vocab(tokenizer_name) return vocab.tokenizer @@ -563,10 +577,11 @@ def main(args: argparse.Namespace): model_id = args.model tokenizer_id = args.tokenizer + use_hf_tokenizer = args.use_hf_tokenizer api_url = f"{args.server}:{args.port}" - tokenizer = get_tokenizer(model_id, tokenizer_id) + tokenizer = get_tokenizer(model_id, tokenizer_id, use_hf_tokenizer) if tokenizer == "test" or args.dataset == "test": input_requests = mock_requests( args.total_mock_requests @@ -716,6 +731,15 @@ def main(args: argparse.Namespace): " default value)" ), ) + parser.add_argument( + "--use-hf-tokenizer", + type=str2bool, + default=False, + help=( + "Whether to use tokenizer from HuggingFace. If so, set this flag" + " to True, and provide name of the tokenizer in the tokenizer flag." + ), + ) parser.add_argument( "--num-prompts", type=int, From 15e3963be8fd13b468e049f92545eed96a6a8e7a Mon Sep 17 00:00:00 2001 From: Yijia Date: Thu, 7 Nov 2024 18:24:46 -0800 Subject: [PATCH 39/42] Rename third_party folder to Avoid Copybara g3 Errors (#148) * rename third_party * fix ut * fix ut --- Makefile | 2 +- benchmarks/benchmark_serving.py | 2 +- jetstream/engine/token_utils.py | 2 +- .../__init__.py | 0 .../llama3/__init__.py | 0 .../llama3/llama3_tokenizer.py | 0 .../llama2/tokenizer.model | Bin .../llama3/tokenizer.model | 0 jetstream/tests/engine/test_token_utils.py | 4 ++-- pylintrc | 2 +- 10 files changed, 6 insertions(+), 6 deletions(-) rename jetstream/{third_party => external_tokenizers}/__init__.py (100%) rename jetstream/{third_party => external_tokenizers}/llama3/__init__.py (100%) rename jetstream/{third_party => external_tokenizers}/llama3/llama3_tokenizer.py (100%) rename jetstream/tests/engine/{third_party => external_tokenizers}/llama2/tokenizer.model (100%) rename jetstream/tests/engine/{third_party => external_tokenizers}/llama3/tokenizer.model (100%) diff --git a/Makefile b/Makefile index a8a88085..7f3cff00 100644 --- a/Makefile +++ b/Makefile @@ -51,4 +51,4 @@ unit-tests: coverage run -m unittest -v check-test-coverage: - coverage report -m --omit="jetstream/core/proto/*,jetstream/engine/tokenizer_pb2.py,jetstream/third_party/*" --fail-under=96 + coverage report -m --omit="jetstream/core/proto/*,jetstream/engine/tokenizer_pb2.py,jetstream/external_tokenizers/*" --fail-under=96 diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index a738b8b1..97628372 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -73,7 +73,7 @@ from jetstream.core.proto import jetstream_pb2 from jetstream.core.proto import jetstream_pb2_grpc from jetstream.engine.token_utils import load_vocab -from jetstream.third_party.llama3 import llama3_tokenizer +from jetstream.external_tokenizers.llama3 import llama3_tokenizer import numpy as np from tqdm.asyncio import tqdm # pytype: disable=pyi-error import pandas diff --git a/jetstream/engine/token_utils.py b/jetstream/engine/token_utils.py index d6b50d29..b653c34b 100644 --- a/jetstream/engine/token_utils.py +++ b/jetstream/engine/token_utils.py @@ -28,7 +28,7 @@ from jetstream.engine import mock_utils from jetstream.engine import tokenizer_api from jetstream.engine import tokenizer_pb2 -from jetstream.third_party.llama3 import llama3_tokenizer +from jetstream.external_tokenizers.llama3 import llama3_tokenizer # ResultToken class to store tokens ids. ResultTokens = Any diff --git a/jetstream/third_party/__init__.py b/jetstream/external_tokenizers/__init__.py similarity index 100% rename from jetstream/third_party/__init__.py rename to jetstream/external_tokenizers/__init__.py diff --git a/jetstream/third_party/llama3/__init__.py b/jetstream/external_tokenizers/llama3/__init__.py similarity index 100% rename from jetstream/third_party/llama3/__init__.py rename to jetstream/external_tokenizers/llama3/__init__.py diff --git a/jetstream/third_party/llama3/llama3_tokenizer.py b/jetstream/external_tokenizers/llama3/llama3_tokenizer.py similarity index 100% rename from jetstream/third_party/llama3/llama3_tokenizer.py rename to jetstream/external_tokenizers/llama3/llama3_tokenizer.py diff --git a/jetstream/tests/engine/third_party/llama2/tokenizer.model b/jetstream/tests/engine/external_tokenizers/llama2/tokenizer.model similarity index 100% rename from jetstream/tests/engine/third_party/llama2/tokenizer.model rename to jetstream/tests/engine/external_tokenizers/llama2/tokenizer.model diff --git a/jetstream/tests/engine/third_party/llama3/tokenizer.model b/jetstream/tests/engine/external_tokenizers/llama3/tokenizer.model similarity index 100% rename from jetstream/tests/engine/third_party/llama3/tokenizer.model rename to jetstream/tests/engine/external_tokenizers/llama3/tokenizer.model diff --git a/jetstream/tests/engine/test_token_utils.py b/jetstream/tests/engine/test_token_utils.py index 41fff3d3..cdee310e 100644 --- a/jetstream/tests/engine/test_token_utils.py +++ b/jetstream/tests/engine/test_token_utils.py @@ -55,7 +55,7 @@ def decode(self, t: int) -> str: class TokenUtilsTest(unittest.TestCase): def setup_sentencepiece(self): - self.tokenizer_path = "third_party/llama2/tokenizer.model" + self.tokenizer_path = "external_tokenizers/llama2/tokenizer.model" current_dir = os.path.dirname(__file__) self.tokenizer_path = os.path.join(current_dir, self.tokenizer_path) print(f"model_path: {self.tokenizer_path}") @@ -66,7 +66,7 @@ def setup_sentencepiece(self): self.jt_tokenizer = JetStreamTokenizer(self.tokenizer_path) def setup_tiktoken(self): - self.tokenizer_path = "third_party/llama3/tokenizer.model" + self.tokenizer_path = "external_tokenizers/llama3/tokenizer.model" current_dir = os.path.dirname(__file__) self.tokenizer_path = os.path.join(current_dir, self.tokenizer_path) print(f"model_path: {self.tokenizer_path}") diff --git a/pylintrc b/pylintrc index 0df2cb47..52bc64ba 100644 --- a/pylintrc +++ b/pylintrc @@ -9,7 +9,7 @@ [MAIN] # Files or directories to be skipped. They should be base names, not paths. -ignore=third_party +ignore=external_tokenizers # Files or directories matching the regex patterns are skipped. The regex # matches against base names, not paths. From d462ca9bbc55531bbe785203cb076e7797250f2a Mon Sep 17 00:00:00 2001 From: Zhihao Shan <60905719+zhihaoshan-google@users.noreply.github.com> Date: Mon, 18 Nov 2024 22:56:16 -0800 Subject: [PATCH 40/42] add seperate prefill detokenization thread (#152) Co-authored-by: Zhihao Shan --- jetstream/core/orchestrator.py | 68 ++++++++++++++++++++++++---------- 1 file changed, 49 insertions(+), 19 deletions(-) diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index 0fd64c5e..15fc36dd 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -38,7 +38,7 @@ of the generation loop at the relevant slot. - Regardless, it performs a step. - It takes the sampled tokens, and places them on a 'detokenizing_queue'. -7. Within the detokenizing thread: +7. Within the detokenizing thread (Prefill and Generate separately): - Tokens are detokenized for every 'slot' in a given set of sampled tokens. - When an end condition is met, the 'slot' integer is returned to the respective generation queue. @@ -210,7 +210,8 @@ class Driver: # Stage 4 # This can be a list because we can pass it as an arg to generate and # detokenize threads. It is a list of tokens to be detokenized. - _detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = [] + _prefill_detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = [] + _generate_detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = [] _generate_slots: list[queue.Queue[int]] = [] _active_requests: list[queue.Queue[tuple[int, ActiveRequest]]] = [] @@ -270,11 +271,11 @@ def __init__( # one of the generate backlogs. # Interleaved Mode: Max size is 1 to increase the HBM utilization # during generate. - # Disaggregated Mode: Max size is 4 to allow for 2 prefills to be enqueued - # while 1 transfer is enqueued while 1 is being transferred. + # Disaggregated Mode: Max size is 16 to allow for total 16 prefills to + # be enqueued or enqueued while 1 is being transferred. # TODO: Make queue size configurable. self._transfer_backlogs = [ - queue.Queue(1 if self._interleaved_mode else 4) + queue.Queue(1 if self._interleaved_mode else 16) for i in range(len(self._prefill_engines)) ] if self._metrics_collector: @@ -302,10 +303,11 @@ def __init__( functools.partial(float, backlog.qsize()) ) # Stage 4 - # After generation, ActiveRequests are placed on the detokenization backlog - # for tokens to be sent into each ActiveRequest's return channel. - # We have one of these per generate engine to simplify the logic keeping - # track of which generation engine to replace slots on. + # After prefill and generation, ActiveRequests are placed on the + # detokenization backlog for tokens to be sent into each ActiveRequest's + # return channel. + # We have one of these per prefill / generate engine to simplify + # the logic keeping track of which generation engine to replace slots on. # This is a queue of either - tuple[int, ActiveRequest] which represents our # active requests, or tuple[int, sample_tokens]. We combine these into one # queue because it allows us to be somewhat clever with how we do @@ -320,7 +322,16 @@ def __init__( # the possibility of race conditions where a slot is made live before the # tokens are ready and it receives tokens from a different sequence, # or tokens detokenized before the relevant slot is live. - self._detokenize_backlogs = [ + + self._prefill_detokenize_backlogs = [ + # No need to set maxsize, as transfer queue can + # provide the backpressure to the prefill workload + # (to avoid the overwhelming prefill). + queue.Queue() + for _ in self._prefill_engines + ] + + self._generate_detokenize_backlogs = [ # We don't let detokenization accumulate more than 8 steps to avoid # synchronization issues. queue.Queue(8) @@ -376,13 +387,25 @@ def __init__( ) for idx in range(len(self._generate_engines)) ] - self.detokenize_threads = [ + self.prefill_detokenize_threads = [ JetThread( target=functools.partial( self._detokenize_thread, - idx, + is_prefill=True, + idx=idx, + ), + name=f"prefill_detokenize-{idx}", + ) + for idx in range(len(self._generate_engines)) + ] + self.generate_detokenize_threads = [ + JetThread( + target=functools.partial( + self._detokenize_thread, + is_prefill=False, + idx=idx, ), - name=f"detokenize-{idx}", + name=f"generate_detokenize-{idx}", ) for idx in range(len(self._generate_engines)) ] @@ -391,7 +414,8 @@ def __init__( self._prefill_threads, self._transfer_threads, self._generate_threads, - self.detokenize_threads, + self.prefill_detokenize_threads, + self.generate_detokenize_threads, ) ) self.live = True @@ -410,7 +434,8 @@ def stop(self): [self._prefill_backlog], self._transfer_backlogs, self._generate_backlogs.values(), - self._detokenize_backlogs, + self._prefill_detokenize_backlogs, + self._generate_detokenize_backlogs, ) ) @@ -523,7 +548,7 @@ def _prefill_thread(self, idx: int): # put first token to detokenize queue request.complete = np.zeros((prefill_engine.samples_per_slot,), np.bool_) - my_detokenize_backlog = self._detokenize_backlogs[idx] + my_detokenize_backlog = self._prefill_detokenize_backlogs[idx] request.metadata.transfer_enqueue_time = time.perf_counter() my_detokenize_backlog.put( (first_token, request, request.metadata.prefill_dequeue_time), @@ -619,7 +644,7 @@ def _generate_thread(self, idx: int): generate_engine = self._generate_engines[idx] my_slots = self._generate_slots[idx] my_generate_backlog = self._generate_backlogs[idx] - my_detokenize_backlog = self._detokenize_backlogs[idx] + my_detokenize_backlog = self._generate_detokenize_backlogs[idx] # Keep track of what step tokens were generated at. generate_timestep = 0 @@ -749,12 +774,17 @@ def _generate_thread(self, idx: int): ) time_of_last_generate = time.time() - def _detokenize_thread(self, idx: int): + def _detokenize_thread(self, is_prefill: bool, idx: int): """Detokenize sampled tokens and returns them to the user.""" # One of these per generate engine. # For all filled my_slots, pop the sampled token onto the relevant # requests return channel. If it done, place it back onto free slots. - my_detokenize_backlog = self._detokenize_backlogs[idx] + + if is_prefill: + my_detokenize_backlog = self._prefill_detokenize_backlogs[idx] + else: + my_detokenize_backlog = self._generate_detokenize_backlogs[idx] + my_generate_engine = self._generate_engines[idx] my_slots = self._generate_slots[idx] From 8e18e7fd1db4ee271fa677eec86b0b90a3822c95 Mon Sep 17 00:00:00 2001 From: jetstream authors Date: Mon, 16 Dec 2024 19:16:10 +0000 Subject: [PATCH 41/42] Internal refactor PiperOrigin-RevId: 706772024 --- .../third_party/py/jetstream/.github}/CODEOWNERS | 0 .../py/jetstream/.github}/workflows/e2e_tests.yaml | 0 .../py/jetstream/.github}/workflows/release.yaml | 0 .../.github}/workflows/scripts/create_release.js | 0 .../py/jetstream/.github}/workflows/unit_tests.yaml | 0 .../third_party/py/jetstream/.gitignore | 0 AUTHORS => google3/third_party/py/jetstream/AUTHORS | 0 .../third_party/py/jetstream/CONTRIBUTING.md | 0 LICENSE => google3/third_party/py/jetstream/LICENSE | 0 .../third_party/py/jetstream/MANIFEST.in | 0 .../third_party/py/jetstream/Makefile | 0 .../third_party/py/jetstream/README.md | 0 .../third_party/py/jetstream}/__init__.py | 0 .../third_party/py/jetstream/benchmarks}/README.md | 0 .../py/jetstream/benchmarks}/__init__.py | 0 .../py/jetstream/benchmarks}/benchmark_serving.py | 0 .../py/jetstream/benchmarks}/eval_accuracy.py | 1 - ...n_orca_gpt4_tokenized_llama.calibration_1000.pkl | Bin .../py/jetstream/benchmarks}/requirements.in | 0 .../third_party/py/jetstream}/core/README.md | 0 .../third_party/py/jetstream}/core/__init__.py | 0 .../third_party/py/jetstream}/core/config_lib.py | 0 .../py/jetstream}/core/implementations/__init__.py | 0 .../jetstream}/core/implementations/mock/README.md | 0 .../core/implementations/mock/__init__.py | 0 .../jetstream}/core/implementations/mock/config.py | 0 .../jetstream}/core/implementations/mock/server.py | 0 .../py/jetstream}/core/metrics/__init__.py | 0 .../py/jetstream}/core/metrics/prometheus.py | 0 .../third_party/py/jetstream}/core/orchestrator.py | 0 .../py/jetstream}/core/proto/__init__.py | 0 .../py/jetstream}/core/proto/jetstream.proto | 0 .../py/jetstream}/core/proto/jetstream_pb2.py | 0 .../py/jetstream}/core/proto/jetstream_pb2_grpc.py | 0 .../third_party/py/jetstream}/core/server_lib.py | 0 .../py/jetstream}/core/utils/__init__.py | 0 .../py/jetstream}/core/utils/async_multifuture.py | 0 .../py/jetstream}/core/utils/proxy_util.py | 0 .../py/jetstream}/core/utils/return_sample.py | 0 ...bility-prometheus-metrics-in-jetstream-server.md | 2 +- .../docs}/online-inference-with-maxtext-engine.md | 0 .../profiling-with-jax-profiler-and-tensorboard.md | 0 .../third_party/py/jetstream}/engine/README.md | 0 .../third_party/py/jetstream}/engine/__init__.py | 0 .../third_party/py/jetstream}/engine/engine_api.py | 0 .../third_party/py/jetstream}/engine/mock_engine.py | 0 .../third_party/py/jetstream}/engine/mock_utils.py | 0 .../py/jetstream}/engine/sampling_utils.py | 0 .../third_party/py/jetstream}/engine/token_utils.py | 0 .../py/jetstream}/engine/tokenizer.proto | 0 .../py/jetstream}/engine/tokenizer_api.py | 0 .../py/jetstream}/engine/tokenizer_pb2.py | 0 .../py/jetstream}/engine/tokenizer_pb2_grpc.py | 0 .../py/jetstream}/engine/warmup_utils.py | 0 .../py/jetstream}/entrypoints/__init__.py | 0 .../third_party/py/jetstream}/entrypoints/config.py | 0 .../py/jetstream}/entrypoints/http/__init__.py | 0 .../py/jetstream}/entrypoints/http/api_server.py | 0 .../py/jetstream}/entrypoints/http/protocol.py | 0 .../py/jetstream}/entrypoints/http/utils.py | 0 .../py/jetstream}/external_tokenizers/__init__.py | 0 .../external_tokenizers/llama3/__init__.py | 0 .../external_tokenizers/llama3/llama3_tokenizer.py | 0 .../third_party/py/jetstream/license_preamble.txt | 0 .../third_party/py/jetstream/pylintrc | 0 .../third_party/py/jetstream/requirements.txt | 0 .../third_party/py/jetstream/setup.py | 0 .../third_party/py/jetstream}/tests/__init__.py | 0 .../py/jetstream}/tests/core/__init__.py | 0 .../py/jetstream}/tests/core/test_config_lib.py | 0 .../py/jetstream}/tests/core/test_orchestrator.py | 0 .../py/jetstream}/tests/core/test_server.py | 0 .../py/jetstream}/tests/engine/__init__.py | 0 .../external_tokenizers/llama2/tokenizer.model | Bin .../external_tokenizers/llama3/tokenizer.model | 0 .../py/jetstream}/tests/engine/test_mock_engine.py | 0 .../jetstream}/tests/engine/test_sampling_utils.py | 0 .../py/jetstream}/tests/engine/test_token_utils.py | 0 .../py/jetstream}/tests/engine/test_utils.py | 0 .../py/jetstream}/tests/entrypoints/__init__.py | 0 .../jetstream}/tests/entrypoints/http/__init__.py | 0 .../tests/entrypoints/http/test_api_server.py | 0 .../third_party/py/jetstream}/tools/load_tester.py | 0 .../tools/maxtext/model_ckpt_conversion.sh | 0 .../tools/maxtext/model_ckpt_finetune_with_aqt.sh | 0 .../py/jetstream}/tools/proxy_dev/base.Dockerfile | 0 .../py/jetstream}/tools/proxy_dev/dev.Dockerfile | 0 .../third_party/py/jetstream}/tools/requester.py | 0 88 files changed, 1 insertion(+), 2 deletions(-) rename {.github => google3/third_party/py/jetstream/.github}/CODEOWNERS (100%) rename {.github => google3/third_party/py/jetstream/.github}/workflows/e2e_tests.yaml (100%) rename {.github => google3/third_party/py/jetstream/.github}/workflows/release.yaml (100%) rename {.github => google3/third_party/py/jetstream/.github}/workflows/scripts/create_release.js (100%) rename {.github => google3/third_party/py/jetstream/.github}/workflows/unit_tests.yaml (100%) rename .gitignore => google3/third_party/py/jetstream/.gitignore (100%) rename AUTHORS => google3/third_party/py/jetstream/AUTHORS (100%) rename CONTRIBUTING.md => google3/third_party/py/jetstream/CONTRIBUTING.md (100%) rename LICENSE => google3/third_party/py/jetstream/LICENSE (100%) rename MANIFEST.in => google3/third_party/py/jetstream/MANIFEST.in (100%) rename Makefile => google3/third_party/py/jetstream/Makefile (100%) rename README.md => google3/third_party/py/jetstream/README.md (100%) rename {jetstream => google3/third_party/py/jetstream}/__init__.py (100%) rename {benchmarks => google3/third_party/py/jetstream/benchmarks}/README.md (100%) rename {benchmarks => google3/third_party/py/jetstream/benchmarks}/__init__.py (100%) rename {benchmarks => google3/third_party/py/jetstream/benchmarks}/benchmark_serving.py (100%) rename {benchmarks => google3/third_party/py/jetstream/benchmarks}/eval_accuracy.py (99%) rename {benchmarks => google3/third_party/py/jetstream/benchmarks}/open_orca_gpt4_tokenized_llama.calibration_1000.pkl (100%) rename {benchmarks => google3/third_party/py/jetstream/benchmarks}/requirements.in (100%) rename {jetstream => google3/third_party/py/jetstream}/core/README.md (100%) rename {jetstream => google3/third_party/py/jetstream}/core/__init__.py (100%) rename {jetstream => google3/third_party/py/jetstream}/core/config_lib.py (100%) rename {jetstream => google3/third_party/py/jetstream}/core/implementations/__init__.py (100%) rename {jetstream => google3/third_party/py/jetstream}/core/implementations/mock/README.md (100%) rename {jetstream => google3/third_party/py/jetstream}/core/implementations/mock/__init__.py (100%) rename {jetstream => google3/third_party/py/jetstream}/core/implementations/mock/config.py (100%) rename {jetstream => google3/third_party/py/jetstream}/core/implementations/mock/server.py (100%) rename {jetstream => google3/third_party/py/jetstream}/core/metrics/__init__.py (100%) rename {jetstream => google3/third_party/py/jetstream}/core/metrics/prometheus.py (100%) rename {jetstream => google3/third_party/py/jetstream}/core/orchestrator.py (100%) rename {jetstream => google3/third_party/py/jetstream}/core/proto/__init__.py (100%) rename {jetstream => google3/third_party/py/jetstream}/core/proto/jetstream.proto (100%) rename {jetstream => google3/third_party/py/jetstream}/core/proto/jetstream_pb2.py (100%) rename {jetstream => google3/third_party/py/jetstream}/core/proto/jetstream_pb2_grpc.py (100%) rename {jetstream => google3/third_party/py/jetstream}/core/server_lib.py (100%) rename {jetstream => google3/third_party/py/jetstream}/core/utils/__init__.py (100%) rename {jetstream => google3/third_party/py/jetstream}/core/utils/async_multifuture.py (100%) rename {jetstream => google3/third_party/py/jetstream}/core/utils/proxy_util.py (100%) rename {jetstream => google3/third_party/py/jetstream}/core/utils/return_sample.py (100%) rename {docs => google3/third_party/py/jetstream/docs}/observability-prometheus-metrics-in-jetstream-server.md (94%) rename {docs => google3/third_party/py/jetstream/docs}/online-inference-with-maxtext-engine.md (100%) rename {docs => google3/third_party/py/jetstream/docs}/profiling-with-jax-profiler-and-tensorboard.md (100%) rename {jetstream => google3/third_party/py/jetstream}/engine/README.md (100%) rename {jetstream => google3/third_party/py/jetstream}/engine/__init__.py (100%) rename {jetstream => google3/third_party/py/jetstream}/engine/engine_api.py (100%) rename {jetstream => google3/third_party/py/jetstream}/engine/mock_engine.py (100%) rename {jetstream => google3/third_party/py/jetstream}/engine/mock_utils.py (100%) rename {jetstream => google3/third_party/py/jetstream}/engine/sampling_utils.py (100%) rename {jetstream => google3/third_party/py/jetstream}/engine/token_utils.py (100%) rename {jetstream => google3/third_party/py/jetstream}/engine/tokenizer.proto (100%) rename {jetstream => google3/third_party/py/jetstream}/engine/tokenizer_api.py (100%) rename {jetstream => google3/third_party/py/jetstream}/engine/tokenizer_pb2.py (100%) rename {jetstream => google3/third_party/py/jetstream}/engine/tokenizer_pb2_grpc.py (100%) rename {jetstream => google3/third_party/py/jetstream}/engine/warmup_utils.py (100%) rename {jetstream => google3/third_party/py/jetstream}/entrypoints/__init__.py (100%) rename {jetstream => google3/third_party/py/jetstream}/entrypoints/config.py (100%) rename {jetstream => google3/third_party/py/jetstream}/entrypoints/http/__init__.py (100%) rename {jetstream => google3/third_party/py/jetstream}/entrypoints/http/api_server.py (100%) rename {jetstream => google3/third_party/py/jetstream}/entrypoints/http/protocol.py (100%) rename {jetstream => google3/third_party/py/jetstream}/entrypoints/http/utils.py (100%) rename {jetstream => google3/third_party/py/jetstream}/external_tokenizers/__init__.py (100%) rename {jetstream => google3/third_party/py/jetstream}/external_tokenizers/llama3/__init__.py (100%) rename {jetstream => google3/third_party/py/jetstream}/external_tokenizers/llama3/llama3_tokenizer.py (100%) rename license_preamble.txt => google3/third_party/py/jetstream/license_preamble.txt (100%) rename pylintrc => google3/third_party/py/jetstream/pylintrc (100%) rename requirements.txt => google3/third_party/py/jetstream/requirements.txt (100%) rename setup.py => google3/third_party/py/jetstream/setup.py (100%) rename {jetstream => google3/third_party/py/jetstream}/tests/__init__.py (100%) rename {jetstream => google3/third_party/py/jetstream}/tests/core/__init__.py (100%) rename {jetstream => google3/third_party/py/jetstream}/tests/core/test_config_lib.py (100%) rename {jetstream => google3/third_party/py/jetstream}/tests/core/test_orchestrator.py (100%) rename {jetstream => google3/third_party/py/jetstream}/tests/core/test_server.py (100%) rename {jetstream => google3/third_party/py/jetstream}/tests/engine/__init__.py (100%) rename {jetstream => google3/third_party/py/jetstream}/tests/engine/external_tokenizers/llama2/tokenizer.model (100%) rename {jetstream => google3/third_party/py/jetstream}/tests/engine/external_tokenizers/llama3/tokenizer.model (100%) rename {jetstream => google3/third_party/py/jetstream}/tests/engine/test_mock_engine.py (100%) rename {jetstream => google3/third_party/py/jetstream}/tests/engine/test_sampling_utils.py (100%) rename {jetstream => google3/third_party/py/jetstream}/tests/engine/test_token_utils.py (100%) rename {jetstream => google3/third_party/py/jetstream}/tests/engine/test_utils.py (100%) rename {jetstream => google3/third_party/py/jetstream}/tests/entrypoints/__init__.py (100%) rename {jetstream => google3/third_party/py/jetstream}/tests/entrypoints/http/__init__.py (100%) rename {jetstream => google3/third_party/py/jetstream}/tests/entrypoints/http/test_api_server.py (100%) rename {jetstream => google3/third_party/py/jetstream}/tools/load_tester.py (100%) rename {jetstream => google3/third_party/py/jetstream}/tools/maxtext/model_ckpt_conversion.sh (100%) rename {jetstream => google3/third_party/py/jetstream}/tools/maxtext/model_ckpt_finetune_with_aqt.sh (100%) rename {jetstream => google3/third_party/py/jetstream}/tools/proxy_dev/base.Dockerfile (100%) rename {jetstream => google3/third_party/py/jetstream}/tools/proxy_dev/dev.Dockerfile (100%) rename {jetstream => google3/third_party/py/jetstream}/tools/requester.py (100%) diff --git a/.github/CODEOWNERS b/google3/third_party/py/jetstream/.github/CODEOWNERS similarity index 100% rename from .github/CODEOWNERS rename to google3/third_party/py/jetstream/.github/CODEOWNERS diff --git a/.github/workflows/e2e_tests.yaml b/google3/third_party/py/jetstream/.github/workflows/e2e_tests.yaml similarity index 100% rename from .github/workflows/e2e_tests.yaml rename to google3/third_party/py/jetstream/.github/workflows/e2e_tests.yaml diff --git a/.github/workflows/release.yaml b/google3/third_party/py/jetstream/.github/workflows/release.yaml similarity index 100% rename from .github/workflows/release.yaml rename to google3/third_party/py/jetstream/.github/workflows/release.yaml diff --git a/.github/workflows/scripts/create_release.js b/google3/third_party/py/jetstream/.github/workflows/scripts/create_release.js similarity index 100% rename from .github/workflows/scripts/create_release.js rename to google3/third_party/py/jetstream/.github/workflows/scripts/create_release.js diff --git a/.github/workflows/unit_tests.yaml b/google3/third_party/py/jetstream/.github/workflows/unit_tests.yaml similarity index 100% rename from .github/workflows/unit_tests.yaml rename to google3/third_party/py/jetstream/.github/workflows/unit_tests.yaml diff --git a/.gitignore b/google3/third_party/py/jetstream/.gitignore similarity index 100% rename from .gitignore rename to google3/third_party/py/jetstream/.gitignore diff --git a/AUTHORS b/google3/third_party/py/jetstream/AUTHORS similarity index 100% rename from AUTHORS rename to google3/third_party/py/jetstream/AUTHORS diff --git a/CONTRIBUTING.md b/google3/third_party/py/jetstream/CONTRIBUTING.md similarity index 100% rename from CONTRIBUTING.md rename to google3/third_party/py/jetstream/CONTRIBUTING.md diff --git a/LICENSE b/google3/third_party/py/jetstream/LICENSE similarity index 100% rename from LICENSE rename to google3/third_party/py/jetstream/LICENSE diff --git a/MANIFEST.in b/google3/third_party/py/jetstream/MANIFEST.in similarity index 100% rename from MANIFEST.in rename to google3/third_party/py/jetstream/MANIFEST.in diff --git a/Makefile b/google3/third_party/py/jetstream/Makefile similarity index 100% rename from Makefile rename to google3/third_party/py/jetstream/Makefile diff --git a/README.md b/google3/third_party/py/jetstream/README.md similarity index 100% rename from README.md rename to google3/third_party/py/jetstream/README.md diff --git a/jetstream/__init__.py b/google3/third_party/py/jetstream/__init__.py similarity index 100% rename from jetstream/__init__.py rename to google3/third_party/py/jetstream/__init__.py diff --git a/benchmarks/README.md b/google3/third_party/py/jetstream/benchmarks/README.md similarity index 100% rename from benchmarks/README.md rename to google3/third_party/py/jetstream/benchmarks/README.md diff --git a/benchmarks/__init__.py b/google3/third_party/py/jetstream/benchmarks/__init__.py similarity index 100% rename from benchmarks/__init__.py rename to google3/third_party/py/jetstream/benchmarks/__init__.py diff --git a/benchmarks/benchmark_serving.py b/google3/third_party/py/jetstream/benchmarks/benchmark_serving.py similarity index 100% rename from benchmarks/benchmark_serving.py rename to google3/third_party/py/jetstream/benchmarks/benchmark_serving.py diff --git a/benchmarks/eval_accuracy.py b/google3/third_party/py/jetstream/benchmarks/eval_accuracy.py similarity index 99% rename from benchmarks/eval_accuracy.py rename to google3/third_party/py/jetstream/benchmarks/eval_accuracy.py index f84562be..559cd2a8 100644 --- a/benchmarks/eval_accuracy.py +++ b/google3/third_party/py/jetstream/benchmarks/eval_accuracy.py @@ -64,7 +64,6 @@ def main(args): eval_accuracy(request_outputs_dict) - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( diff --git a/benchmarks/open_orca_gpt4_tokenized_llama.calibration_1000.pkl b/google3/third_party/py/jetstream/benchmarks/open_orca_gpt4_tokenized_llama.calibration_1000.pkl similarity index 100% rename from benchmarks/open_orca_gpt4_tokenized_llama.calibration_1000.pkl rename to google3/third_party/py/jetstream/benchmarks/open_orca_gpt4_tokenized_llama.calibration_1000.pkl diff --git a/benchmarks/requirements.in b/google3/third_party/py/jetstream/benchmarks/requirements.in similarity index 100% rename from benchmarks/requirements.in rename to google3/third_party/py/jetstream/benchmarks/requirements.in diff --git a/jetstream/core/README.md b/google3/third_party/py/jetstream/core/README.md similarity index 100% rename from jetstream/core/README.md rename to google3/third_party/py/jetstream/core/README.md diff --git a/jetstream/core/__init__.py b/google3/third_party/py/jetstream/core/__init__.py similarity index 100% rename from jetstream/core/__init__.py rename to google3/third_party/py/jetstream/core/__init__.py diff --git a/jetstream/core/config_lib.py b/google3/third_party/py/jetstream/core/config_lib.py similarity index 100% rename from jetstream/core/config_lib.py rename to google3/third_party/py/jetstream/core/config_lib.py diff --git a/jetstream/core/implementations/__init__.py b/google3/third_party/py/jetstream/core/implementations/__init__.py similarity index 100% rename from jetstream/core/implementations/__init__.py rename to google3/third_party/py/jetstream/core/implementations/__init__.py diff --git a/jetstream/core/implementations/mock/README.md b/google3/third_party/py/jetstream/core/implementations/mock/README.md similarity index 100% rename from jetstream/core/implementations/mock/README.md rename to google3/third_party/py/jetstream/core/implementations/mock/README.md diff --git a/jetstream/core/implementations/mock/__init__.py b/google3/third_party/py/jetstream/core/implementations/mock/__init__.py similarity index 100% rename from jetstream/core/implementations/mock/__init__.py rename to google3/third_party/py/jetstream/core/implementations/mock/__init__.py diff --git a/jetstream/core/implementations/mock/config.py b/google3/third_party/py/jetstream/core/implementations/mock/config.py similarity index 100% rename from jetstream/core/implementations/mock/config.py rename to google3/third_party/py/jetstream/core/implementations/mock/config.py diff --git a/jetstream/core/implementations/mock/server.py b/google3/third_party/py/jetstream/core/implementations/mock/server.py similarity index 100% rename from jetstream/core/implementations/mock/server.py rename to google3/third_party/py/jetstream/core/implementations/mock/server.py diff --git a/jetstream/core/metrics/__init__.py b/google3/third_party/py/jetstream/core/metrics/__init__.py similarity index 100% rename from jetstream/core/metrics/__init__.py rename to google3/third_party/py/jetstream/core/metrics/__init__.py diff --git a/jetstream/core/metrics/prometheus.py b/google3/third_party/py/jetstream/core/metrics/prometheus.py similarity index 100% rename from jetstream/core/metrics/prometheus.py rename to google3/third_party/py/jetstream/core/metrics/prometheus.py diff --git a/jetstream/core/orchestrator.py b/google3/third_party/py/jetstream/core/orchestrator.py similarity index 100% rename from jetstream/core/orchestrator.py rename to google3/third_party/py/jetstream/core/orchestrator.py diff --git a/jetstream/core/proto/__init__.py b/google3/third_party/py/jetstream/core/proto/__init__.py similarity index 100% rename from jetstream/core/proto/__init__.py rename to google3/third_party/py/jetstream/core/proto/__init__.py diff --git a/jetstream/core/proto/jetstream.proto b/google3/third_party/py/jetstream/core/proto/jetstream.proto similarity index 100% rename from jetstream/core/proto/jetstream.proto rename to google3/third_party/py/jetstream/core/proto/jetstream.proto diff --git a/jetstream/core/proto/jetstream_pb2.py b/google3/third_party/py/jetstream/core/proto/jetstream_pb2.py similarity index 100% rename from jetstream/core/proto/jetstream_pb2.py rename to google3/third_party/py/jetstream/core/proto/jetstream_pb2.py diff --git a/jetstream/core/proto/jetstream_pb2_grpc.py b/google3/third_party/py/jetstream/core/proto/jetstream_pb2_grpc.py similarity index 100% rename from jetstream/core/proto/jetstream_pb2_grpc.py rename to google3/third_party/py/jetstream/core/proto/jetstream_pb2_grpc.py diff --git a/jetstream/core/server_lib.py b/google3/third_party/py/jetstream/core/server_lib.py similarity index 100% rename from jetstream/core/server_lib.py rename to google3/third_party/py/jetstream/core/server_lib.py diff --git a/jetstream/core/utils/__init__.py b/google3/third_party/py/jetstream/core/utils/__init__.py similarity index 100% rename from jetstream/core/utils/__init__.py rename to google3/third_party/py/jetstream/core/utils/__init__.py diff --git a/jetstream/core/utils/async_multifuture.py b/google3/third_party/py/jetstream/core/utils/async_multifuture.py similarity index 100% rename from jetstream/core/utils/async_multifuture.py rename to google3/third_party/py/jetstream/core/utils/async_multifuture.py diff --git a/jetstream/core/utils/proxy_util.py b/google3/third_party/py/jetstream/core/utils/proxy_util.py similarity index 100% rename from jetstream/core/utils/proxy_util.py rename to google3/third_party/py/jetstream/core/utils/proxy_util.py diff --git a/jetstream/core/utils/return_sample.py b/google3/third_party/py/jetstream/core/utils/return_sample.py similarity index 100% rename from jetstream/core/utils/return_sample.py rename to google3/third_party/py/jetstream/core/utils/return_sample.py diff --git a/docs/observability-prometheus-metrics-in-jetstream-server.md b/google3/third_party/py/jetstream/docs/observability-prometheus-metrics-in-jetstream-server.md similarity index 94% rename from docs/observability-prometheus-metrics-in-jetstream-server.md rename to google3/third_party/py/jetstream/docs/observability-prometheus-metrics-in-jetstream-server.md index 04d7be4c..079b132a 100644 --- a/docs/observability-prometheus-metrics-in-jetstream-server.md +++ b/google3/third_party/py/jetstream/docs/observability-prometheus-metrics-in-jetstream-server.md @@ -80,6 +80,6 @@ echo '{ }' | kubectl apply -f - ``` -The metrics can now be queried in the [Google Cloud Metrics Explorer](https://pantheon.corp.google.com/monitoring/metrics-explorer). When adding a metrics query with the `+Add Query` button the new metrics should be found under the `Prometheus Target > Jetstream` submenu. +The metrics can now be queried in the Google Cloud Metrics Explorer. When adding a metrics query with the `+Add Query` button the new metrics should be found under the `Prometheus Target > Jetstream` submenu. Additional guides on the metrics explorer can be found [here](https://cloud.google.com/monitoring/charts/metrics-selector). \ No newline at end of file diff --git a/docs/online-inference-with-maxtext-engine.md b/google3/third_party/py/jetstream/docs/online-inference-with-maxtext-engine.md similarity index 100% rename from docs/online-inference-with-maxtext-engine.md rename to google3/third_party/py/jetstream/docs/online-inference-with-maxtext-engine.md diff --git a/docs/profiling-with-jax-profiler-and-tensorboard.md b/google3/third_party/py/jetstream/docs/profiling-with-jax-profiler-and-tensorboard.md similarity index 100% rename from docs/profiling-with-jax-profiler-and-tensorboard.md rename to google3/third_party/py/jetstream/docs/profiling-with-jax-profiler-and-tensorboard.md diff --git a/jetstream/engine/README.md b/google3/third_party/py/jetstream/engine/README.md similarity index 100% rename from jetstream/engine/README.md rename to google3/third_party/py/jetstream/engine/README.md diff --git a/jetstream/engine/__init__.py b/google3/third_party/py/jetstream/engine/__init__.py similarity index 100% rename from jetstream/engine/__init__.py rename to google3/third_party/py/jetstream/engine/__init__.py diff --git a/jetstream/engine/engine_api.py b/google3/third_party/py/jetstream/engine/engine_api.py similarity index 100% rename from jetstream/engine/engine_api.py rename to google3/third_party/py/jetstream/engine/engine_api.py diff --git a/jetstream/engine/mock_engine.py b/google3/third_party/py/jetstream/engine/mock_engine.py similarity index 100% rename from jetstream/engine/mock_engine.py rename to google3/third_party/py/jetstream/engine/mock_engine.py diff --git a/jetstream/engine/mock_utils.py b/google3/third_party/py/jetstream/engine/mock_utils.py similarity index 100% rename from jetstream/engine/mock_utils.py rename to google3/third_party/py/jetstream/engine/mock_utils.py diff --git a/jetstream/engine/sampling_utils.py b/google3/third_party/py/jetstream/engine/sampling_utils.py similarity index 100% rename from jetstream/engine/sampling_utils.py rename to google3/third_party/py/jetstream/engine/sampling_utils.py diff --git a/jetstream/engine/token_utils.py b/google3/third_party/py/jetstream/engine/token_utils.py similarity index 100% rename from jetstream/engine/token_utils.py rename to google3/third_party/py/jetstream/engine/token_utils.py diff --git a/jetstream/engine/tokenizer.proto b/google3/third_party/py/jetstream/engine/tokenizer.proto similarity index 100% rename from jetstream/engine/tokenizer.proto rename to google3/third_party/py/jetstream/engine/tokenizer.proto diff --git a/jetstream/engine/tokenizer_api.py b/google3/third_party/py/jetstream/engine/tokenizer_api.py similarity index 100% rename from jetstream/engine/tokenizer_api.py rename to google3/third_party/py/jetstream/engine/tokenizer_api.py diff --git a/jetstream/engine/tokenizer_pb2.py b/google3/third_party/py/jetstream/engine/tokenizer_pb2.py similarity index 100% rename from jetstream/engine/tokenizer_pb2.py rename to google3/third_party/py/jetstream/engine/tokenizer_pb2.py diff --git a/jetstream/engine/tokenizer_pb2_grpc.py b/google3/third_party/py/jetstream/engine/tokenizer_pb2_grpc.py similarity index 100% rename from jetstream/engine/tokenizer_pb2_grpc.py rename to google3/third_party/py/jetstream/engine/tokenizer_pb2_grpc.py diff --git a/jetstream/engine/warmup_utils.py b/google3/third_party/py/jetstream/engine/warmup_utils.py similarity index 100% rename from jetstream/engine/warmup_utils.py rename to google3/third_party/py/jetstream/engine/warmup_utils.py diff --git a/jetstream/entrypoints/__init__.py b/google3/third_party/py/jetstream/entrypoints/__init__.py similarity index 100% rename from jetstream/entrypoints/__init__.py rename to google3/third_party/py/jetstream/entrypoints/__init__.py diff --git a/jetstream/entrypoints/config.py b/google3/third_party/py/jetstream/entrypoints/config.py similarity index 100% rename from jetstream/entrypoints/config.py rename to google3/third_party/py/jetstream/entrypoints/config.py diff --git a/jetstream/entrypoints/http/__init__.py b/google3/third_party/py/jetstream/entrypoints/http/__init__.py similarity index 100% rename from jetstream/entrypoints/http/__init__.py rename to google3/third_party/py/jetstream/entrypoints/http/__init__.py diff --git a/jetstream/entrypoints/http/api_server.py b/google3/third_party/py/jetstream/entrypoints/http/api_server.py similarity index 100% rename from jetstream/entrypoints/http/api_server.py rename to google3/third_party/py/jetstream/entrypoints/http/api_server.py diff --git a/jetstream/entrypoints/http/protocol.py b/google3/third_party/py/jetstream/entrypoints/http/protocol.py similarity index 100% rename from jetstream/entrypoints/http/protocol.py rename to google3/third_party/py/jetstream/entrypoints/http/protocol.py diff --git a/jetstream/entrypoints/http/utils.py b/google3/third_party/py/jetstream/entrypoints/http/utils.py similarity index 100% rename from jetstream/entrypoints/http/utils.py rename to google3/third_party/py/jetstream/entrypoints/http/utils.py diff --git a/jetstream/external_tokenizers/__init__.py b/google3/third_party/py/jetstream/external_tokenizers/__init__.py similarity index 100% rename from jetstream/external_tokenizers/__init__.py rename to google3/third_party/py/jetstream/external_tokenizers/__init__.py diff --git a/jetstream/external_tokenizers/llama3/__init__.py b/google3/third_party/py/jetstream/external_tokenizers/llama3/__init__.py similarity index 100% rename from jetstream/external_tokenizers/llama3/__init__.py rename to google3/third_party/py/jetstream/external_tokenizers/llama3/__init__.py diff --git a/jetstream/external_tokenizers/llama3/llama3_tokenizer.py b/google3/third_party/py/jetstream/external_tokenizers/llama3/llama3_tokenizer.py similarity index 100% rename from jetstream/external_tokenizers/llama3/llama3_tokenizer.py rename to google3/third_party/py/jetstream/external_tokenizers/llama3/llama3_tokenizer.py diff --git a/license_preamble.txt b/google3/third_party/py/jetstream/license_preamble.txt similarity index 100% rename from license_preamble.txt rename to google3/third_party/py/jetstream/license_preamble.txt diff --git a/pylintrc b/google3/third_party/py/jetstream/pylintrc similarity index 100% rename from pylintrc rename to google3/third_party/py/jetstream/pylintrc diff --git a/requirements.txt b/google3/third_party/py/jetstream/requirements.txt similarity index 100% rename from requirements.txt rename to google3/third_party/py/jetstream/requirements.txt diff --git a/setup.py b/google3/third_party/py/jetstream/setup.py similarity index 100% rename from setup.py rename to google3/third_party/py/jetstream/setup.py diff --git a/jetstream/tests/__init__.py b/google3/third_party/py/jetstream/tests/__init__.py similarity index 100% rename from jetstream/tests/__init__.py rename to google3/third_party/py/jetstream/tests/__init__.py diff --git a/jetstream/tests/core/__init__.py b/google3/third_party/py/jetstream/tests/core/__init__.py similarity index 100% rename from jetstream/tests/core/__init__.py rename to google3/third_party/py/jetstream/tests/core/__init__.py diff --git a/jetstream/tests/core/test_config_lib.py b/google3/third_party/py/jetstream/tests/core/test_config_lib.py similarity index 100% rename from jetstream/tests/core/test_config_lib.py rename to google3/third_party/py/jetstream/tests/core/test_config_lib.py diff --git a/jetstream/tests/core/test_orchestrator.py b/google3/third_party/py/jetstream/tests/core/test_orchestrator.py similarity index 100% rename from jetstream/tests/core/test_orchestrator.py rename to google3/third_party/py/jetstream/tests/core/test_orchestrator.py diff --git a/jetstream/tests/core/test_server.py b/google3/third_party/py/jetstream/tests/core/test_server.py similarity index 100% rename from jetstream/tests/core/test_server.py rename to google3/third_party/py/jetstream/tests/core/test_server.py diff --git a/jetstream/tests/engine/__init__.py b/google3/third_party/py/jetstream/tests/engine/__init__.py similarity index 100% rename from jetstream/tests/engine/__init__.py rename to google3/third_party/py/jetstream/tests/engine/__init__.py diff --git a/jetstream/tests/engine/external_tokenizers/llama2/tokenizer.model b/google3/third_party/py/jetstream/tests/engine/external_tokenizers/llama2/tokenizer.model similarity index 100% rename from jetstream/tests/engine/external_tokenizers/llama2/tokenizer.model rename to google3/third_party/py/jetstream/tests/engine/external_tokenizers/llama2/tokenizer.model diff --git a/jetstream/tests/engine/external_tokenizers/llama3/tokenizer.model b/google3/third_party/py/jetstream/tests/engine/external_tokenizers/llama3/tokenizer.model similarity index 100% rename from jetstream/tests/engine/external_tokenizers/llama3/tokenizer.model rename to google3/third_party/py/jetstream/tests/engine/external_tokenizers/llama3/tokenizer.model diff --git a/jetstream/tests/engine/test_mock_engine.py b/google3/third_party/py/jetstream/tests/engine/test_mock_engine.py similarity index 100% rename from jetstream/tests/engine/test_mock_engine.py rename to google3/third_party/py/jetstream/tests/engine/test_mock_engine.py diff --git a/jetstream/tests/engine/test_sampling_utils.py b/google3/third_party/py/jetstream/tests/engine/test_sampling_utils.py similarity index 100% rename from jetstream/tests/engine/test_sampling_utils.py rename to google3/third_party/py/jetstream/tests/engine/test_sampling_utils.py diff --git a/jetstream/tests/engine/test_token_utils.py b/google3/third_party/py/jetstream/tests/engine/test_token_utils.py similarity index 100% rename from jetstream/tests/engine/test_token_utils.py rename to google3/third_party/py/jetstream/tests/engine/test_token_utils.py diff --git a/jetstream/tests/engine/test_utils.py b/google3/third_party/py/jetstream/tests/engine/test_utils.py similarity index 100% rename from jetstream/tests/engine/test_utils.py rename to google3/third_party/py/jetstream/tests/engine/test_utils.py diff --git a/jetstream/tests/entrypoints/__init__.py b/google3/third_party/py/jetstream/tests/entrypoints/__init__.py similarity index 100% rename from jetstream/tests/entrypoints/__init__.py rename to google3/third_party/py/jetstream/tests/entrypoints/__init__.py diff --git a/jetstream/tests/entrypoints/http/__init__.py b/google3/third_party/py/jetstream/tests/entrypoints/http/__init__.py similarity index 100% rename from jetstream/tests/entrypoints/http/__init__.py rename to google3/third_party/py/jetstream/tests/entrypoints/http/__init__.py diff --git a/jetstream/tests/entrypoints/http/test_api_server.py b/google3/third_party/py/jetstream/tests/entrypoints/http/test_api_server.py similarity index 100% rename from jetstream/tests/entrypoints/http/test_api_server.py rename to google3/third_party/py/jetstream/tests/entrypoints/http/test_api_server.py diff --git a/jetstream/tools/load_tester.py b/google3/third_party/py/jetstream/tools/load_tester.py similarity index 100% rename from jetstream/tools/load_tester.py rename to google3/third_party/py/jetstream/tools/load_tester.py diff --git a/jetstream/tools/maxtext/model_ckpt_conversion.sh b/google3/third_party/py/jetstream/tools/maxtext/model_ckpt_conversion.sh similarity index 100% rename from jetstream/tools/maxtext/model_ckpt_conversion.sh rename to google3/third_party/py/jetstream/tools/maxtext/model_ckpt_conversion.sh diff --git a/jetstream/tools/maxtext/model_ckpt_finetune_with_aqt.sh b/google3/third_party/py/jetstream/tools/maxtext/model_ckpt_finetune_with_aqt.sh similarity index 100% rename from jetstream/tools/maxtext/model_ckpt_finetune_with_aqt.sh rename to google3/third_party/py/jetstream/tools/maxtext/model_ckpt_finetune_with_aqt.sh diff --git a/jetstream/tools/proxy_dev/base.Dockerfile b/google3/third_party/py/jetstream/tools/proxy_dev/base.Dockerfile similarity index 100% rename from jetstream/tools/proxy_dev/base.Dockerfile rename to google3/third_party/py/jetstream/tools/proxy_dev/base.Dockerfile diff --git a/jetstream/tools/proxy_dev/dev.Dockerfile b/google3/third_party/py/jetstream/tools/proxy_dev/dev.Dockerfile similarity index 100% rename from jetstream/tools/proxy_dev/dev.Dockerfile rename to google3/third_party/py/jetstream/tools/proxy_dev/dev.Dockerfile diff --git a/jetstream/tools/requester.py b/google3/third_party/py/jetstream/tools/requester.py similarity index 100% rename from jetstream/tools/requester.py rename to google3/third_party/py/jetstream/tools/requester.py From 973647d771a57a9680d0ca96c137f1b0b0353b34 Mon Sep 17 00:00:00 2001 From: Yijia Date: Mon, 16 Dec 2024 13:42:52 -0800 Subject: [PATCH 42/42] Revert "Internal refactor" (#156) This reverts commit 8e18e7fd1db4ee271fa677eec86b0b90a3822c95. Co-authored-by: Yijia J --- .../py/jetstream/.github => .github}/CODEOWNERS | 0 .../.github => .github}/workflows/e2e_tests.yaml | 0 .../.github => .github}/workflows/release.yaml | 0 .../workflows/scripts/create_release.js | 0 .../.github => .github}/workflows/unit_tests.yaml | 0 .../py/jetstream/.gitignore => .gitignore | 0 google3/third_party/py/jetstream/AUTHORS => AUTHORS | 0 .../py/jetstream/CONTRIBUTING.md => CONTRIBUTING.md | 0 google3/third_party/py/jetstream/LICENSE => LICENSE | 0 .../py/jetstream/MANIFEST.in => MANIFEST.in | 0 .../third_party/py/jetstream/Makefile => Makefile | 0 .../third_party/py/jetstream/README.md => README.md | 0 .../jetstream/benchmarks => benchmarks}/README.md | 0 .../jetstream/benchmarks => benchmarks}/__init__.py | 0 .../benchmarks => benchmarks}/benchmark_serving.py | 0 .../benchmarks => benchmarks}/eval_accuracy.py | 1 + ...n_orca_gpt4_tokenized_llama.calibration_1000.pkl | Bin .../benchmarks => benchmarks}/requirements.in | 0 ...bility-prometheus-metrics-in-jetstream-server.md | 2 +- .../online-inference-with-maxtext-engine.md | 0 .../profiling-with-jax-profiler-and-tensorboard.md | 0 .../py/jetstream => jetstream}/__init__.py | 0 .../py/jetstream => jetstream}/core/README.md | 0 .../py/jetstream => jetstream}/core/__init__.py | 0 .../py/jetstream => jetstream}/core/config_lib.py | 0 .../core/implementations/__init__.py | 0 .../core/implementations/mock/README.md | 0 .../core/implementations/mock/__init__.py | 0 .../core/implementations/mock/config.py | 0 .../core/implementations/mock/server.py | 0 .../core/metrics/__init__.py | 0 .../core/metrics/prometheus.py | 0 .../py/jetstream => jetstream}/core/orchestrator.py | 0 .../jetstream => jetstream}/core/proto/__init__.py | 0 .../core/proto/jetstream.proto | 0 .../core/proto/jetstream_pb2.py | 0 .../core/proto/jetstream_pb2_grpc.py | 0 .../py/jetstream => jetstream}/core/server_lib.py | 0 .../jetstream => jetstream}/core/utils/__init__.py | 0 .../core/utils/async_multifuture.py | 0 .../core/utils/proxy_util.py | 0 .../core/utils/return_sample.py | 0 .../py/jetstream => jetstream}/engine/README.md | 0 .../py/jetstream => jetstream}/engine/__init__.py | 0 .../py/jetstream => jetstream}/engine/engine_api.py | 0 .../jetstream => jetstream}/engine/mock_engine.py | 0 .../py/jetstream => jetstream}/engine/mock_utils.py | 0 .../engine/sampling_utils.py | 0 .../jetstream => jetstream}/engine/token_utils.py | 0 .../jetstream => jetstream}/engine/tokenizer.proto | 0 .../jetstream => jetstream}/engine/tokenizer_api.py | 0 .../jetstream => jetstream}/engine/tokenizer_pb2.py | 0 .../engine/tokenizer_pb2_grpc.py | 0 .../jetstream => jetstream}/engine/warmup_utils.py | 0 .../jetstream => jetstream}/entrypoints/__init__.py | 0 .../jetstream => jetstream}/entrypoints/config.py | 0 .../entrypoints/http/__init__.py | 0 .../entrypoints/http/api_server.py | 0 .../entrypoints/http/protocol.py | 0 .../entrypoints/http/utils.py | 0 .../external_tokenizers/__init__.py | 0 .../external_tokenizers/llama3/__init__.py | 0 .../external_tokenizers/llama3/llama3_tokenizer.py | 0 .../py/jetstream => jetstream}/tests/__init__.py | 0 .../jetstream => jetstream}/tests/core/__init__.py | 0 .../tests/core/test_config_lib.py | 0 .../tests/core/test_orchestrator.py | 0 .../tests/core/test_server.py | 0 .../tests/engine/__init__.py | 0 .../external_tokenizers/llama2/tokenizer.model | Bin .../external_tokenizers/llama3/tokenizer.model | 0 .../tests/engine/test_mock_engine.py | 0 .../tests/engine/test_sampling_utils.py | 0 .../tests/engine/test_token_utils.py | 0 .../tests/engine/test_utils.py | 0 .../tests/entrypoints/__init__.py | 0 .../tests/entrypoints/http/__init__.py | 0 .../tests/entrypoints/http/test_api_server.py | 0 .../py/jetstream => jetstream}/tools/load_tester.py | 0 .../tools/maxtext/model_ckpt_conversion.sh | 0 .../tools/maxtext/model_ckpt_finetune_with_aqt.sh | 0 .../tools/proxy_dev/base.Dockerfile | 0 .../tools/proxy_dev/dev.Dockerfile | 0 .../py/jetstream => jetstream}/tools/requester.py | 0 .../license_preamble.txt => license_preamble.txt | 0 .../third_party/py/jetstream/pylintrc => pylintrc | 0 .../jetstream/requirements.txt => requirements.txt | 0 .../third_party/py/jetstream/setup.py => setup.py | 0 88 files changed, 2 insertions(+), 1 deletion(-) rename {google3/third_party/py/jetstream/.github => .github}/CODEOWNERS (100%) rename {google3/third_party/py/jetstream/.github => .github}/workflows/e2e_tests.yaml (100%) rename {google3/third_party/py/jetstream/.github => .github}/workflows/release.yaml (100%) rename {google3/third_party/py/jetstream/.github => .github}/workflows/scripts/create_release.js (100%) rename {google3/third_party/py/jetstream/.github => .github}/workflows/unit_tests.yaml (100%) rename google3/third_party/py/jetstream/.gitignore => .gitignore (100%) rename google3/third_party/py/jetstream/AUTHORS => AUTHORS (100%) rename google3/third_party/py/jetstream/CONTRIBUTING.md => CONTRIBUTING.md (100%) rename google3/third_party/py/jetstream/LICENSE => LICENSE (100%) rename google3/third_party/py/jetstream/MANIFEST.in => MANIFEST.in (100%) rename google3/third_party/py/jetstream/Makefile => Makefile (100%) rename google3/third_party/py/jetstream/README.md => README.md (100%) rename {google3/third_party/py/jetstream/benchmarks => benchmarks}/README.md (100%) rename {google3/third_party/py/jetstream/benchmarks => benchmarks}/__init__.py (100%) rename {google3/third_party/py/jetstream/benchmarks => benchmarks}/benchmark_serving.py (100%) rename {google3/third_party/py/jetstream/benchmarks => benchmarks}/eval_accuracy.py (99%) rename {google3/third_party/py/jetstream/benchmarks => benchmarks}/open_orca_gpt4_tokenized_llama.calibration_1000.pkl (100%) rename {google3/third_party/py/jetstream/benchmarks => benchmarks}/requirements.in (100%) rename {google3/third_party/py/jetstream/docs => docs}/observability-prometheus-metrics-in-jetstream-server.md (94%) rename {google3/third_party/py/jetstream/docs => docs}/online-inference-with-maxtext-engine.md (100%) rename {google3/third_party/py/jetstream/docs => docs}/profiling-with-jax-profiler-and-tensorboard.md (100%) rename {google3/third_party/py/jetstream => jetstream}/__init__.py (100%) rename {google3/third_party/py/jetstream => jetstream}/core/README.md (100%) rename {google3/third_party/py/jetstream => jetstream}/core/__init__.py (100%) rename {google3/third_party/py/jetstream => jetstream}/core/config_lib.py (100%) rename {google3/third_party/py/jetstream => jetstream}/core/implementations/__init__.py (100%) rename {google3/third_party/py/jetstream => jetstream}/core/implementations/mock/README.md (100%) rename {google3/third_party/py/jetstream => jetstream}/core/implementations/mock/__init__.py (100%) rename {google3/third_party/py/jetstream => jetstream}/core/implementations/mock/config.py (100%) rename {google3/third_party/py/jetstream => jetstream}/core/implementations/mock/server.py (100%) rename {google3/third_party/py/jetstream => jetstream}/core/metrics/__init__.py (100%) rename {google3/third_party/py/jetstream => jetstream}/core/metrics/prometheus.py (100%) rename {google3/third_party/py/jetstream => jetstream}/core/orchestrator.py (100%) rename {google3/third_party/py/jetstream => jetstream}/core/proto/__init__.py (100%) rename {google3/third_party/py/jetstream => jetstream}/core/proto/jetstream.proto (100%) rename {google3/third_party/py/jetstream => jetstream}/core/proto/jetstream_pb2.py (100%) rename {google3/third_party/py/jetstream => jetstream}/core/proto/jetstream_pb2_grpc.py (100%) rename {google3/third_party/py/jetstream => jetstream}/core/server_lib.py (100%) rename {google3/third_party/py/jetstream => jetstream}/core/utils/__init__.py (100%) rename {google3/third_party/py/jetstream => jetstream}/core/utils/async_multifuture.py (100%) rename {google3/third_party/py/jetstream => jetstream}/core/utils/proxy_util.py (100%) rename {google3/third_party/py/jetstream => jetstream}/core/utils/return_sample.py (100%) rename {google3/third_party/py/jetstream => jetstream}/engine/README.md (100%) rename {google3/third_party/py/jetstream => jetstream}/engine/__init__.py (100%) rename {google3/third_party/py/jetstream => jetstream}/engine/engine_api.py (100%) rename {google3/third_party/py/jetstream => jetstream}/engine/mock_engine.py (100%) rename {google3/third_party/py/jetstream => jetstream}/engine/mock_utils.py (100%) rename {google3/third_party/py/jetstream => jetstream}/engine/sampling_utils.py (100%) rename {google3/third_party/py/jetstream => jetstream}/engine/token_utils.py (100%) rename {google3/third_party/py/jetstream => jetstream}/engine/tokenizer.proto (100%) rename {google3/third_party/py/jetstream => jetstream}/engine/tokenizer_api.py (100%) rename {google3/third_party/py/jetstream => jetstream}/engine/tokenizer_pb2.py (100%) rename {google3/third_party/py/jetstream => jetstream}/engine/tokenizer_pb2_grpc.py (100%) rename {google3/third_party/py/jetstream => jetstream}/engine/warmup_utils.py (100%) rename {google3/third_party/py/jetstream => jetstream}/entrypoints/__init__.py (100%) rename {google3/third_party/py/jetstream => jetstream}/entrypoints/config.py (100%) rename {google3/third_party/py/jetstream => jetstream}/entrypoints/http/__init__.py (100%) rename {google3/third_party/py/jetstream => jetstream}/entrypoints/http/api_server.py (100%) rename {google3/third_party/py/jetstream => jetstream}/entrypoints/http/protocol.py (100%) rename {google3/third_party/py/jetstream => jetstream}/entrypoints/http/utils.py (100%) rename {google3/third_party/py/jetstream => jetstream}/external_tokenizers/__init__.py (100%) rename {google3/third_party/py/jetstream => jetstream}/external_tokenizers/llama3/__init__.py (100%) rename {google3/third_party/py/jetstream => jetstream}/external_tokenizers/llama3/llama3_tokenizer.py (100%) rename {google3/third_party/py/jetstream => jetstream}/tests/__init__.py (100%) rename {google3/third_party/py/jetstream => jetstream}/tests/core/__init__.py (100%) rename {google3/third_party/py/jetstream => jetstream}/tests/core/test_config_lib.py (100%) rename {google3/third_party/py/jetstream => jetstream}/tests/core/test_orchestrator.py (100%) rename {google3/third_party/py/jetstream => jetstream}/tests/core/test_server.py (100%) rename {google3/third_party/py/jetstream => jetstream}/tests/engine/__init__.py (100%) rename {google3/third_party/py/jetstream => jetstream}/tests/engine/external_tokenizers/llama2/tokenizer.model (100%) rename {google3/third_party/py/jetstream => jetstream}/tests/engine/external_tokenizers/llama3/tokenizer.model (100%) rename {google3/third_party/py/jetstream => jetstream}/tests/engine/test_mock_engine.py (100%) rename {google3/third_party/py/jetstream => jetstream}/tests/engine/test_sampling_utils.py (100%) rename {google3/third_party/py/jetstream => jetstream}/tests/engine/test_token_utils.py (100%) rename {google3/third_party/py/jetstream => jetstream}/tests/engine/test_utils.py (100%) rename {google3/third_party/py/jetstream => jetstream}/tests/entrypoints/__init__.py (100%) rename {google3/third_party/py/jetstream => jetstream}/tests/entrypoints/http/__init__.py (100%) rename {google3/third_party/py/jetstream => jetstream}/tests/entrypoints/http/test_api_server.py (100%) rename {google3/third_party/py/jetstream => jetstream}/tools/load_tester.py (100%) rename {google3/third_party/py/jetstream => jetstream}/tools/maxtext/model_ckpt_conversion.sh (100%) rename {google3/third_party/py/jetstream => jetstream}/tools/maxtext/model_ckpt_finetune_with_aqt.sh (100%) rename {google3/third_party/py/jetstream => jetstream}/tools/proxy_dev/base.Dockerfile (100%) rename {google3/third_party/py/jetstream => jetstream}/tools/proxy_dev/dev.Dockerfile (100%) rename {google3/third_party/py/jetstream => jetstream}/tools/requester.py (100%) rename google3/third_party/py/jetstream/license_preamble.txt => license_preamble.txt (100%) rename google3/third_party/py/jetstream/pylintrc => pylintrc (100%) rename google3/third_party/py/jetstream/requirements.txt => requirements.txt (100%) rename google3/third_party/py/jetstream/setup.py => setup.py (100%) diff --git a/google3/third_party/py/jetstream/.github/CODEOWNERS b/.github/CODEOWNERS similarity index 100% rename from google3/third_party/py/jetstream/.github/CODEOWNERS rename to .github/CODEOWNERS diff --git a/google3/third_party/py/jetstream/.github/workflows/e2e_tests.yaml b/.github/workflows/e2e_tests.yaml similarity index 100% rename from google3/third_party/py/jetstream/.github/workflows/e2e_tests.yaml rename to .github/workflows/e2e_tests.yaml diff --git a/google3/third_party/py/jetstream/.github/workflows/release.yaml b/.github/workflows/release.yaml similarity index 100% rename from google3/third_party/py/jetstream/.github/workflows/release.yaml rename to .github/workflows/release.yaml diff --git a/google3/third_party/py/jetstream/.github/workflows/scripts/create_release.js b/.github/workflows/scripts/create_release.js similarity index 100% rename from google3/third_party/py/jetstream/.github/workflows/scripts/create_release.js rename to .github/workflows/scripts/create_release.js diff --git a/google3/third_party/py/jetstream/.github/workflows/unit_tests.yaml b/.github/workflows/unit_tests.yaml similarity index 100% rename from google3/third_party/py/jetstream/.github/workflows/unit_tests.yaml rename to .github/workflows/unit_tests.yaml diff --git a/google3/third_party/py/jetstream/.gitignore b/.gitignore similarity index 100% rename from google3/third_party/py/jetstream/.gitignore rename to .gitignore diff --git a/google3/third_party/py/jetstream/AUTHORS b/AUTHORS similarity index 100% rename from google3/third_party/py/jetstream/AUTHORS rename to AUTHORS diff --git a/google3/third_party/py/jetstream/CONTRIBUTING.md b/CONTRIBUTING.md similarity index 100% rename from google3/third_party/py/jetstream/CONTRIBUTING.md rename to CONTRIBUTING.md diff --git a/google3/third_party/py/jetstream/LICENSE b/LICENSE similarity index 100% rename from google3/third_party/py/jetstream/LICENSE rename to LICENSE diff --git a/google3/third_party/py/jetstream/MANIFEST.in b/MANIFEST.in similarity index 100% rename from google3/third_party/py/jetstream/MANIFEST.in rename to MANIFEST.in diff --git a/google3/third_party/py/jetstream/Makefile b/Makefile similarity index 100% rename from google3/third_party/py/jetstream/Makefile rename to Makefile diff --git a/google3/third_party/py/jetstream/README.md b/README.md similarity index 100% rename from google3/third_party/py/jetstream/README.md rename to README.md diff --git a/google3/third_party/py/jetstream/benchmarks/README.md b/benchmarks/README.md similarity index 100% rename from google3/third_party/py/jetstream/benchmarks/README.md rename to benchmarks/README.md diff --git a/google3/third_party/py/jetstream/benchmarks/__init__.py b/benchmarks/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/benchmarks/__init__.py rename to benchmarks/__init__.py diff --git a/google3/third_party/py/jetstream/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py similarity index 100% rename from google3/third_party/py/jetstream/benchmarks/benchmark_serving.py rename to benchmarks/benchmark_serving.py diff --git a/google3/third_party/py/jetstream/benchmarks/eval_accuracy.py b/benchmarks/eval_accuracy.py similarity index 99% rename from google3/third_party/py/jetstream/benchmarks/eval_accuracy.py rename to benchmarks/eval_accuracy.py index 559cd2a8..f84562be 100644 --- a/google3/third_party/py/jetstream/benchmarks/eval_accuracy.py +++ b/benchmarks/eval_accuracy.py @@ -64,6 +64,7 @@ def main(args): eval_accuracy(request_outputs_dict) + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( diff --git a/google3/third_party/py/jetstream/benchmarks/open_orca_gpt4_tokenized_llama.calibration_1000.pkl b/benchmarks/open_orca_gpt4_tokenized_llama.calibration_1000.pkl similarity index 100% rename from google3/third_party/py/jetstream/benchmarks/open_orca_gpt4_tokenized_llama.calibration_1000.pkl rename to benchmarks/open_orca_gpt4_tokenized_llama.calibration_1000.pkl diff --git a/google3/third_party/py/jetstream/benchmarks/requirements.in b/benchmarks/requirements.in similarity index 100% rename from google3/third_party/py/jetstream/benchmarks/requirements.in rename to benchmarks/requirements.in diff --git a/google3/third_party/py/jetstream/docs/observability-prometheus-metrics-in-jetstream-server.md b/docs/observability-prometheus-metrics-in-jetstream-server.md similarity index 94% rename from google3/third_party/py/jetstream/docs/observability-prometheus-metrics-in-jetstream-server.md rename to docs/observability-prometheus-metrics-in-jetstream-server.md index 079b132a..04d7be4c 100644 --- a/google3/third_party/py/jetstream/docs/observability-prometheus-metrics-in-jetstream-server.md +++ b/docs/observability-prometheus-metrics-in-jetstream-server.md @@ -80,6 +80,6 @@ echo '{ }' | kubectl apply -f - ``` -The metrics can now be queried in the Google Cloud Metrics Explorer. When adding a metrics query with the `+Add Query` button the new metrics should be found under the `Prometheus Target > Jetstream` submenu. +The metrics can now be queried in the [Google Cloud Metrics Explorer](https://pantheon.corp.google.com/monitoring/metrics-explorer). When adding a metrics query with the `+Add Query` button the new metrics should be found under the `Prometheus Target > Jetstream` submenu. Additional guides on the metrics explorer can be found [here](https://cloud.google.com/monitoring/charts/metrics-selector). \ No newline at end of file diff --git a/google3/third_party/py/jetstream/docs/online-inference-with-maxtext-engine.md b/docs/online-inference-with-maxtext-engine.md similarity index 100% rename from google3/third_party/py/jetstream/docs/online-inference-with-maxtext-engine.md rename to docs/online-inference-with-maxtext-engine.md diff --git a/google3/third_party/py/jetstream/docs/profiling-with-jax-profiler-and-tensorboard.md b/docs/profiling-with-jax-profiler-and-tensorboard.md similarity index 100% rename from google3/third_party/py/jetstream/docs/profiling-with-jax-profiler-and-tensorboard.md rename to docs/profiling-with-jax-profiler-and-tensorboard.md diff --git a/google3/third_party/py/jetstream/__init__.py b/jetstream/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/__init__.py rename to jetstream/__init__.py diff --git a/google3/third_party/py/jetstream/core/README.md b/jetstream/core/README.md similarity index 100% rename from google3/third_party/py/jetstream/core/README.md rename to jetstream/core/README.md diff --git a/google3/third_party/py/jetstream/core/__init__.py b/jetstream/core/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/core/__init__.py rename to jetstream/core/__init__.py diff --git a/google3/third_party/py/jetstream/core/config_lib.py b/jetstream/core/config_lib.py similarity index 100% rename from google3/third_party/py/jetstream/core/config_lib.py rename to jetstream/core/config_lib.py diff --git a/google3/third_party/py/jetstream/core/implementations/__init__.py b/jetstream/core/implementations/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/core/implementations/__init__.py rename to jetstream/core/implementations/__init__.py diff --git a/google3/third_party/py/jetstream/core/implementations/mock/README.md b/jetstream/core/implementations/mock/README.md similarity index 100% rename from google3/third_party/py/jetstream/core/implementations/mock/README.md rename to jetstream/core/implementations/mock/README.md diff --git a/google3/third_party/py/jetstream/core/implementations/mock/__init__.py b/jetstream/core/implementations/mock/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/core/implementations/mock/__init__.py rename to jetstream/core/implementations/mock/__init__.py diff --git a/google3/third_party/py/jetstream/core/implementations/mock/config.py b/jetstream/core/implementations/mock/config.py similarity index 100% rename from google3/third_party/py/jetstream/core/implementations/mock/config.py rename to jetstream/core/implementations/mock/config.py diff --git a/google3/third_party/py/jetstream/core/implementations/mock/server.py b/jetstream/core/implementations/mock/server.py similarity index 100% rename from google3/third_party/py/jetstream/core/implementations/mock/server.py rename to jetstream/core/implementations/mock/server.py diff --git a/google3/third_party/py/jetstream/core/metrics/__init__.py b/jetstream/core/metrics/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/core/metrics/__init__.py rename to jetstream/core/metrics/__init__.py diff --git a/google3/third_party/py/jetstream/core/metrics/prometheus.py b/jetstream/core/metrics/prometheus.py similarity index 100% rename from google3/third_party/py/jetstream/core/metrics/prometheus.py rename to jetstream/core/metrics/prometheus.py diff --git a/google3/third_party/py/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py similarity index 100% rename from google3/third_party/py/jetstream/core/orchestrator.py rename to jetstream/core/orchestrator.py diff --git a/google3/third_party/py/jetstream/core/proto/__init__.py b/jetstream/core/proto/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/core/proto/__init__.py rename to jetstream/core/proto/__init__.py diff --git a/google3/third_party/py/jetstream/core/proto/jetstream.proto b/jetstream/core/proto/jetstream.proto similarity index 100% rename from google3/third_party/py/jetstream/core/proto/jetstream.proto rename to jetstream/core/proto/jetstream.proto diff --git a/google3/third_party/py/jetstream/core/proto/jetstream_pb2.py b/jetstream/core/proto/jetstream_pb2.py similarity index 100% rename from google3/third_party/py/jetstream/core/proto/jetstream_pb2.py rename to jetstream/core/proto/jetstream_pb2.py diff --git a/google3/third_party/py/jetstream/core/proto/jetstream_pb2_grpc.py b/jetstream/core/proto/jetstream_pb2_grpc.py similarity index 100% rename from google3/third_party/py/jetstream/core/proto/jetstream_pb2_grpc.py rename to jetstream/core/proto/jetstream_pb2_grpc.py diff --git a/google3/third_party/py/jetstream/core/server_lib.py b/jetstream/core/server_lib.py similarity index 100% rename from google3/third_party/py/jetstream/core/server_lib.py rename to jetstream/core/server_lib.py diff --git a/google3/third_party/py/jetstream/core/utils/__init__.py b/jetstream/core/utils/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/core/utils/__init__.py rename to jetstream/core/utils/__init__.py diff --git a/google3/third_party/py/jetstream/core/utils/async_multifuture.py b/jetstream/core/utils/async_multifuture.py similarity index 100% rename from google3/third_party/py/jetstream/core/utils/async_multifuture.py rename to jetstream/core/utils/async_multifuture.py diff --git a/google3/third_party/py/jetstream/core/utils/proxy_util.py b/jetstream/core/utils/proxy_util.py similarity index 100% rename from google3/third_party/py/jetstream/core/utils/proxy_util.py rename to jetstream/core/utils/proxy_util.py diff --git a/google3/third_party/py/jetstream/core/utils/return_sample.py b/jetstream/core/utils/return_sample.py similarity index 100% rename from google3/third_party/py/jetstream/core/utils/return_sample.py rename to jetstream/core/utils/return_sample.py diff --git a/google3/third_party/py/jetstream/engine/README.md b/jetstream/engine/README.md similarity index 100% rename from google3/third_party/py/jetstream/engine/README.md rename to jetstream/engine/README.md diff --git a/google3/third_party/py/jetstream/engine/__init__.py b/jetstream/engine/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/engine/__init__.py rename to jetstream/engine/__init__.py diff --git a/google3/third_party/py/jetstream/engine/engine_api.py b/jetstream/engine/engine_api.py similarity index 100% rename from google3/third_party/py/jetstream/engine/engine_api.py rename to jetstream/engine/engine_api.py diff --git a/google3/third_party/py/jetstream/engine/mock_engine.py b/jetstream/engine/mock_engine.py similarity index 100% rename from google3/third_party/py/jetstream/engine/mock_engine.py rename to jetstream/engine/mock_engine.py diff --git a/google3/third_party/py/jetstream/engine/mock_utils.py b/jetstream/engine/mock_utils.py similarity index 100% rename from google3/third_party/py/jetstream/engine/mock_utils.py rename to jetstream/engine/mock_utils.py diff --git a/google3/third_party/py/jetstream/engine/sampling_utils.py b/jetstream/engine/sampling_utils.py similarity index 100% rename from google3/third_party/py/jetstream/engine/sampling_utils.py rename to jetstream/engine/sampling_utils.py diff --git a/google3/third_party/py/jetstream/engine/token_utils.py b/jetstream/engine/token_utils.py similarity index 100% rename from google3/third_party/py/jetstream/engine/token_utils.py rename to jetstream/engine/token_utils.py diff --git a/google3/third_party/py/jetstream/engine/tokenizer.proto b/jetstream/engine/tokenizer.proto similarity index 100% rename from google3/third_party/py/jetstream/engine/tokenizer.proto rename to jetstream/engine/tokenizer.proto diff --git a/google3/third_party/py/jetstream/engine/tokenizer_api.py b/jetstream/engine/tokenizer_api.py similarity index 100% rename from google3/third_party/py/jetstream/engine/tokenizer_api.py rename to jetstream/engine/tokenizer_api.py diff --git a/google3/third_party/py/jetstream/engine/tokenizer_pb2.py b/jetstream/engine/tokenizer_pb2.py similarity index 100% rename from google3/third_party/py/jetstream/engine/tokenizer_pb2.py rename to jetstream/engine/tokenizer_pb2.py diff --git a/google3/third_party/py/jetstream/engine/tokenizer_pb2_grpc.py b/jetstream/engine/tokenizer_pb2_grpc.py similarity index 100% rename from google3/third_party/py/jetstream/engine/tokenizer_pb2_grpc.py rename to jetstream/engine/tokenizer_pb2_grpc.py diff --git a/google3/third_party/py/jetstream/engine/warmup_utils.py b/jetstream/engine/warmup_utils.py similarity index 100% rename from google3/third_party/py/jetstream/engine/warmup_utils.py rename to jetstream/engine/warmup_utils.py diff --git a/google3/third_party/py/jetstream/entrypoints/__init__.py b/jetstream/entrypoints/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/entrypoints/__init__.py rename to jetstream/entrypoints/__init__.py diff --git a/google3/third_party/py/jetstream/entrypoints/config.py b/jetstream/entrypoints/config.py similarity index 100% rename from google3/third_party/py/jetstream/entrypoints/config.py rename to jetstream/entrypoints/config.py diff --git a/google3/third_party/py/jetstream/entrypoints/http/__init__.py b/jetstream/entrypoints/http/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/entrypoints/http/__init__.py rename to jetstream/entrypoints/http/__init__.py diff --git a/google3/third_party/py/jetstream/entrypoints/http/api_server.py b/jetstream/entrypoints/http/api_server.py similarity index 100% rename from google3/third_party/py/jetstream/entrypoints/http/api_server.py rename to jetstream/entrypoints/http/api_server.py diff --git a/google3/third_party/py/jetstream/entrypoints/http/protocol.py b/jetstream/entrypoints/http/protocol.py similarity index 100% rename from google3/third_party/py/jetstream/entrypoints/http/protocol.py rename to jetstream/entrypoints/http/protocol.py diff --git a/google3/third_party/py/jetstream/entrypoints/http/utils.py b/jetstream/entrypoints/http/utils.py similarity index 100% rename from google3/third_party/py/jetstream/entrypoints/http/utils.py rename to jetstream/entrypoints/http/utils.py diff --git a/google3/third_party/py/jetstream/external_tokenizers/__init__.py b/jetstream/external_tokenizers/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/external_tokenizers/__init__.py rename to jetstream/external_tokenizers/__init__.py diff --git a/google3/third_party/py/jetstream/external_tokenizers/llama3/__init__.py b/jetstream/external_tokenizers/llama3/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/external_tokenizers/llama3/__init__.py rename to jetstream/external_tokenizers/llama3/__init__.py diff --git a/google3/third_party/py/jetstream/external_tokenizers/llama3/llama3_tokenizer.py b/jetstream/external_tokenizers/llama3/llama3_tokenizer.py similarity index 100% rename from google3/third_party/py/jetstream/external_tokenizers/llama3/llama3_tokenizer.py rename to jetstream/external_tokenizers/llama3/llama3_tokenizer.py diff --git a/google3/third_party/py/jetstream/tests/__init__.py b/jetstream/tests/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/tests/__init__.py rename to jetstream/tests/__init__.py diff --git a/google3/third_party/py/jetstream/tests/core/__init__.py b/jetstream/tests/core/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/tests/core/__init__.py rename to jetstream/tests/core/__init__.py diff --git a/google3/third_party/py/jetstream/tests/core/test_config_lib.py b/jetstream/tests/core/test_config_lib.py similarity index 100% rename from google3/third_party/py/jetstream/tests/core/test_config_lib.py rename to jetstream/tests/core/test_config_lib.py diff --git a/google3/third_party/py/jetstream/tests/core/test_orchestrator.py b/jetstream/tests/core/test_orchestrator.py similarity index 100% rename from google3/third_party/py/jetstream/tests/core/test_orchestrator.py rename to jetstream/tests/core/test_orchestrator.py diff --git a/google3/third_party/py/jetstream/tests/core/test_server.py b/jetstream/tests/core/test_server.py similarity index 100% rename from google3/third_party/py/jetstream/tests/core/test_server.py rename to jetstream/tests/core/test_server.py diff --git a/google3/third_party/py/jetstream/tests/engine/__init__.py b/jetstream/tests/engine/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/tests/engine/__init__.py rename to jetstream/tests/engine/__init__.py diff --git a/google3/third_party/py/jetstream/tests/engine/external_tokenizers/llama2/tokenizer.model b/jetstream/tests/engine/external_tokenizers/llama2/tokenizer.model similarity index 100% rename from google3/third_party/py/jetstream/tests/engine/external_tokenizers/llama2/tokenizer.model rename to jetstream/tests/engine/external_tokenizers/llama2/tokenizer.model diff --git a/google3/third_party/py/jetstream/tests/engine/external_tokenizers/llama3/tokenizer.model b/jetstream/tests/engine/external_tokenizers/llama3/tokenizer.model similarity index 100% rename from google3/third_party/py/jetstream/tests/engine/external_tokenizers/llama3/tokenizer.model rename to jetstream/tests/engine/external_tokenizers/llama3/tokenizer.model diff --git a/google3/third_party/py/jetstream/tests/engine/test_mock_engine.py b/jetstream/tests/engine/test_mock_engine.py similarity index 100% rename from google3/third_party/py/jetstream/tests/engine/test_mock_engine.py rename to jetstream/tests/engine/test_mock_engine.py diff --git a/google3/third_party/py/jetstream/tests/engine/test_sampling_utils.py b/jetstream/tests/engine/test_sampling_utils.py similarity index 100% rename from google3/third_party/py/jetstream/tests/engine/test_sampling_utils.py rename to jetstream/tests/engine/test_sampling_utils.py diff --git a/google3/third_party/py/jetstream/tests/engine/test_token_utils.py b/jetstream/tests/engine/test_token_utils.py similarity index 100% rename from google3/third_party/py/jetstream/tests/engine/test_token_utils.py rename to jetstream/tests/engine/test_token_utils.py diff --git a/google3/third_party/py/jetstream/tests/engine/test_utils.py b/jetstream/tests/engine/test_utils.py similarity index 100% rename from google3/third_party/py/jetstream/tests/engine/test_utils.py rename to jetstream/tests/engine/test_utils.py diff --git a/google3/third_party/py/jetstream/tests/entrypoints/__init__.py b/jetstream/tests/entrypoints/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/tests/entrypoints/__init__.py rename to jetstream/tests/entrypoints/__init__.py diff --git a/google3/third_party/py/jetstream/tests/entrypoints/http/__init__.py b/jetstream/tests/entrypoints/http/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/tests/entrypoints/http/__init__.py rename to jetstream/tests/entrypoints/http/__init__.py diff --git a/google3/third_party/py/jetstream/tests/entrypoints/http/test_api_server.py b/jetstream/tests/entrypoints/http/test_api_server.py similarity index 100% rename from google3/third_party/py/jetstream/tests/entrypoints/http/test_api_server.py rename to jetstream/tests/entrypoints/http/test_api_server.py diff --git a/google3/third_party/py/jetstream/tools/load_tester.py b/jetstream/tools/load_tester.py similarity index 100% rename from google3/third_party/py/jetstream/tools/load_tester.py rename to jetstream/tools/load_tester.py diff --git a/google3/third_party/py/jetstream/tools/maxtext/model_ckpt_conversion.sh b/jetstream/tools/maxtext/model_ckpt_conversion.sh similarity index 100% rename from google3/third_party/py/jetstream/tools/maxtext/model_ckpt_conversion.sh rename to jetstream/tools/maxtext/model_ckpt_conversion.sh diff --git a/google3/third_party/py/jetstream/tools/maxtext/model_ckpt_finetune_with_aqt.sh b/jetstream/tools/maxtext/model_ckpt_finetune_with_aqt.sh similarity index 100% rename from google3/third_party/py/jetstream/tools/maxtext/model_ckpt_finetune_with_aqt.sh rename to jetstream/tools/maxtext/model_ckpt_finetune_with_aqt.sh diff --git a/google3/third_party/py/jetstream/tools/proxy_dev/base.Dockerfile b/jetstream/tools/proxy_dev/base.Dockerfile similarity index 100% rename from google3/third_party/py/jetstream/tools/proxy_dev/base.Dockerfile rename to jetstream/tools/proxy_dev/base.Dockerfile diff --git a/google3/third_party/py/jetstream/tools/proxy_dev/dev.Dockerfile b/jetstream/tools/proxy_dev/dev.Dockerfile similarity index 100% rename from google3/third_party/py/jetstream/tools/proxy_dev/dev.Dockerfile rename to jetstream/tools/proxy_dev/dev.Dockerfile diff --git a/google3/third_party/py/jetstream/tools/requester.py b/jetstream/tools/requester.py similarity index 100% rename from google3/third_party/py/jetstream/tools/requester.py rename to jetstream/tools/requester.py diff --git a/google3/third_party/py/jetstream/license_preamble.txt b/license_preamble.txt similarity index 100% rename from google3/third_party/py/jetstream/license_preamble.txt rename to license_preamble.txt diff --git a/google3/third_party/py/jetstream/pylintrc b/pylintrc similarity index 100% rename from google3/third_party/py/jetstream/pylintrc rename to pylintrc diff --git a/google3/third_party/py/jetstream/requirements.txt b/requirements.txt similarity index 100% rename from google3/third_party/py/jetstream/requirements.txt rename to requirements.txt diff --git a/google3/third_party/py/jetstream/setup.py b/setup.py similarity index 100% rename from google3/third_party/py/jetstream/setup.py rename to setup.py