diff --git a/.github/workflows/unit_tests.yaml b/.github/workflows/unit_tests.yaml index 7b230dde..8db79fc3 100644 --- a/.github/workflows/unit_tests.yaml +++ b/.github/workflows/unit_tests.yaml @@ -42,21 +42,13 @@ jobs: with: python-version: ${{ matrix.python-version }} - name: Install Dependencies - run: | - pip install pytype - pip install pylint - pip install pyink - pip install -r requirements.txt - pip install -r benchmarks/requirements.in + run: make install-deps - name: Typecheck the code with pytype - run: | - pytype --jobs auto --disable=import-error,module-attr jetstream/ benchmarks/ + run: make type-check - name: Analysing the code with pylint - run: | - pylint jetstream/ benchmarks/ + run: make linter-check - name: Format check with pyink - run: | - pyink --pyink-indentation 2 --line-length 80 --check --verbose . + run: make format-check cpu: name: "JetStream unit tests" @@ -73,11 +65,8 @@ jobs: with: python-version: ${{ matrix.python-version }} - name: Install Dependencies - run: | - pip install -r requirements.txt + run: make install-deps - name: Run all unit tests in JetStream (jetstream/tests) - run: | - coverage run -m unittest -v + run: make unit-tests - name: Create test coverage report - run: | - coverage report -m --omit="jetstream/core/proto/*,jetstream/engine/tokenizer_pb2.py,jetstream/third_party/*" --fail-under=96 \ No newline at end of file + run: make check-test-coverage \ No newline at end of file diff --git a/MANIFEST.in b/MANIFEST.in index 9d4615bd..540b7204 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1 @@ -include requirements.in \ No newline at end of file +include requirements.txt \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..7f3cff00 --- /dev/null +++ b/Makefile @@ -0,0 +1,54 @@ +PYTHON := python +PIP := $(PYTHON) -m pip +GRPC_TOOLS_VERSION := 1.62.1 + +all: install-deps generate-protos format check + +# Dependency management targets +install-deps: + $(PIP) install pytype pylint pyink -r requirements.txt -r benchmarks/requirements.in + +# Code generation/formatting targets +generate-protos: generate-and-prepend-preambles format + +generate-and-prepend-preambles: + $(PIP) install grpcio-tools==$(GRPC_TOOLS_VERSION) + for id in $$(find . -name "*.proto"); do \ + $(PYTHON) -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. $$id && \ + PROTO_FILE=$$(echo $$id | awk '{print substr($$0, 1, length($$0)-6)}') && \ + PB_GRPC_PY=$(addsuffix "_pb2_grpc.py",$$PROTO_FILE) && \ + PB_PY=$(addsuffix "_pb2.py",$$PROTO_FILE) && \ + cat license_preamble.txt $$PB_GRPC_PY >> $(addsuffix "_temp",$$PB_GRPC_PY) && \ + mv $(addsuffix "_temp",$$PB_GRPC_PY) $$PB_GRPC_PY; \ + cat license_preamble.txt $$PB_PY >> $(addsuffix "_temp",$$PB_PY) && \ + mv $(addsuffix "_temp",$$PB_PY) $$PB_PY; \ + done + +format: + $(PIP) install pyink + pyink --pyink-indentation 2 --line-length 80 --verbose . + +# Code checking related targets +check: type-check format-check linter-check + +type-check: + $(PIP) install pytype + pytype --jobs auto --disable=import-error,module-attr jetstream/ benchmarks/ + +format-check: + $(PIP) install pyink + pyink --pyink-indentation 2 --line-length 80 --check --verbose . 
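# Editorial note (not part of the original Makefile): the GitHub workflow
# above now delegates to these targets, so the same checks can be reproduced
# locally with, for example:
#   make install-deps check tests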
+
+linter-check:
+	$(PIP) install pylint
+	pylint --ignore-patterns=".*_pb2.py,.*_pb2_grpc.py" jetstream/ benchmarks/
+
+
+# Testing related targets
+tests: unit-tests check-test-coverage
+
+unit-tests:
+	coverage run -m unittest -v
+
+check-test-coverage:
+	coverage report -m --omit="jetstream/core/proto/*,jetstream/engine/tokenizer_pb2.py,jetstream/external_tokenizers/*" --fail-under=96
diff --git a/README.md b/README.md
index ee0b1eee..62959c46 100644
--- a/README.md
+++ b/README.md
@@ -27,6 +27,7 @@ Currently, there are two reference engine implementations available -- one for J
 - [Online Inference with MaxText on v5e Cloud TPU VM](https://cloud.google.com/tpu/docs/tutorials/LLM/jetstream) [[README](https://github.com/google/JetStream/blob/main/docs/online-inference-with-maxtext-engine.md)]
 - [Online Inference with Pytorch on v5e Cloud TPU VM](https://cloud.google.com/tpu/docs/tutorials/LLM/jetstream-pytorch) [[README](https://github.com/google/jetstream-pytorch/tree/main?tab=readme-ov-file#jetstream-pytorch)]
 - [Serve Gemma using TPUs on GKE with JetStream](https://cloud.google.com/kubernetes-engine/docs/tutorials/serve-gemma-tpu-jetstream)
+- [Benchmark JetStream Server](https://github.com/google/JetStream/blob/main/benchmarks/README.md)
 - [Observability in JetStream Server](https://github.com/google/JetStream/blob/main/docs/observability-prometheus-metrics-in-jetstream-server.md)
 - [Profiling in JetStream Server](https://github.com/google/JetStream/blob/main/docs/profiling-with-jax-profiler-and-tensorboard.md)
 - [JetStream Standalone Local Setup](#jetstream-standalone-local-setup)
@@ -38,7 +39,7 @@ Currently, there are two reference engine implementations available -- one for J
 ### Setup
 ```
-pip install -r requirements.txt
+make install-deps
 ```
 ### Run local server & Testing
diff --git a/benchmarks/README.md b/benchmarks/README.md
index b88501f2..2a511c1b 100644
--- a/benchmarks/README.md
+++ b/benchmarks/README.md
@@ -78,7 +78,7 @@ python benchmark_serving.py \
 ```
 python JetStream/benchmarks/benchmark_serving.py \
 --tokenizer ~/maxtext/assets/tokenizer.llama2 \
---warmup-first true \
+--warmup-mode sampled \
 --save-result \
 --save-request-outputs \
 --request-outputs-file-path outputs.json \
@@ -88,6 +88,23 @@ python JetStream/benchmarks/benchmark_serving.py \
 ```
 
+## Benchmark warmup mode
+
+The benchmark performs better if it first warms up the JetStream server. We currently support `sampled` and `full` warmup modes. `full` mode warms up the JetStream server with all the input requests. `sampled` mode warms up the JetStream server with a sample of the input requests drawn from different input-length buckets.
+
+Example: run the benchmark with `full` warmup mode:
+```
+python JetStream/benchmarks/benchmark_serving.py \
+--tokenizer ~/maxtext/assets/tokenizer.llama2 \
+--warmup-mode full \
+--save-result \
+--save-request-outputs \
+--request-outputs-file-path outputs.json \
+--num-prompts 1000 \
+--max-output-length 1024 \
+--dataset openorca
+```
+
 ## Standalone Evaluation Run
 If you used `--save-request-outputs`, you can separately evaluate against the generated outputs.
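To make `sampled` warmup mode concrete, below is a minimal, editorial sketch of bucket-based warmup sampling. It is not the repository's actual `sample_warmup_requests` implementation: the bucket edges and the `pick_warmup_requests` helper are hypothetical, and it assumes each request exposes a `prompt_len` attribute as the benchmark's `InputRequest` dataclass does.

```python
import bisect
from collections import defaultdict

# Hypothetical input-length bucket edges; the real benchmark derives its own.
BUCKET_EDGES = [32, 64, 128, 256, 512, 1024]


def pick_warmup_requests(requests, per_bucket=2):
  """Illustrative only: pick a few requests from each input-length bucket."""
  by_bucket = defaultdict(list)
  for request in requests:
    # Assumes request.prompt_len holds the tokenized prompt length.
    bucket = bisect.bisect_left(BUCKET_EDGES, request.prompt_len)
    by_bucket[bucket].append(request)
  warmup = []
  for bucket_requests in by_bucket.values():
    warmup.extend(bucket_requests[:per_bucket])
  return warmup
```

In `benchmark_serving.py` below, `--warmup-mode sampled` builds its warmup set as `list(sample_warmup_requests(input_requests)) * 2`, so each sampled request is sent twice before measurement starts.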
diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 252cc534..97628372 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -73,12 +73,38 @@ from jetstream.core.proto import jetstream_pb2 from jetstream.core.proto import jetstream_pb2_grpc from jetstream.engine.token_utils import load_vocab -from jetstream.third_party.llama3 import llama3_tokenizer +from jetstream.external_tokenizers.llama3 import llama3_tokenizer import numpy as np from tqdm.asyncio import tqdm # pytype: disable=pyi-error import pandas from eval_accuracy import eval_accuracy +from transformers import AutoTokenizer + + +def str2bool(v: str) -> bool: + """Convert a string of truth to True or False. + + Args: + - v (str): + - True values are 'y', 'yes', 't', 'true', and '1'; + - False values are 'n', 'no', 'f', 'false', and '0'. + + Returns: + bool: True or False + + Raises: + ValueError if v is anything else. + """ + v = v.lower() + true_values = ["y", "yes", "t", "true", "1"] + false_values = ["n", "no", "f", "false", "0"] + if v in true_values: + return True + elif v in false_values: + return False + else: + raise ValueError(f"Invalid value '{v}'!") @dataclass @@ -131,16 +157,29 @@ def to_dict(self): } -def get_tokenizer(model_id: str, tokenizer_name: str) -> Any: +def get_tokenizer( + model_id: str, + tokenizer_name: str, + use_hf_tokenizer: bool, +) -> Any: """Return a tokenizer or a tokenizer placholder.""" if tokenizer_name == "test": + print("Using test tokenizer") return "test" + elif use_hf_tokenizer: + # Please accept agreement to access private/gated models in HF, and + # follow up instructions below to set up access token + # https://huggingface.co/docs/transformers.js/en/guides/private + print(f"Using HuggingFace tokenizer: {tokenizer_name}") + return AutoTokenizer.from_pretrained(tokenizer_name) elif model_id == "llama-3": # Llama 3 uses a tiktoken tokenizer. + print(f"Using llama-3 tokenizer: {tokenizer_name}") return llama3_tokenizer.Tokenizer(tokenizer_name) else: # Use JetStream tokenizer util. It's using the sentencepiece wrapper in # seqio library. + print(f"Using tokenizer: {tokenizer_name}") vocab = load_vocab(tokenizer_name) return vocab.tokenizer @@ -226,9 +265,9 @@ def tokenize_dataset( def filter_dataset( tokenized_dataset: list[tuple[str, Any, str, int, int, int]], - max_output_length: Optional[int] = None, + max_output_length: int = 0, ) -> list[InputRequest]: - if max_output_length is None: + if max_output_length != 0: print("In InputRequest, pass in actual output_length for each sample") else: print( @@ -269,7 +308,7 @@ def sample_requests( dataset: list[tuple[Any, Any]], tokenizer: Any, num_requests: int, - max_output_length: Optional[int] = None, + max_output_length: int = 0, oversample_multiplier: float = 1.2, ) -> list[InputRequest]: @@ -319,7 +358,7 @@ async def get_request( for request in input_requests: yield request - if request_rate == float("inf"): + if request_rate == 0.0: # If the request rate is infinity, then we don't need to wait. continue # Sample the request interval from the exponential distribution. @@ -401,18 +440,14 @@ async def send_request( tokenizer: Any, input_request: InputRequest, pbar: tqdm, - session_cache: str, - priority: int, ) -> RequestFuncOutput: """Send the request to JetStream server.""" # Tokenization on client side following MLPerf standard. 
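  # Editorial note on the change below: with the session_cache and priority
  # fields dropped from DecodeRequest (see the jetstream.proto diff later in
  # this PR), this benchmark client now only sends the tokenized prompt plus
  # max_tokens.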
token_ids = tokenizer.encode(input_request.prompt) request = jetstream_pb2.DecodeRequest( - session_cache=session_cache, token_content=jetstream_pb2.DecodeRequest.TokenContent( token_ids=token_ids ), - priority=priority, max_tokens=input_request.output_len, ) output = RequestFuncOutput() @@ -438,8 +473,6 @@ async def benchmark( input_requests: list[InputRequest], request_rate: float, disable_tqdm: bool, - session_cache: str, - priority: int, ): """Benchmark the online serving performance.""" pbar = None if disable_tqdm else tqdm(total=len(input_requests)) @@ -456,8 +489,6 @@ async def benchmark( tokenizer=tokenizer, input_request=request, pbar=pbar, - session_cache=session_cache, - priority=priority, ) ) ) @@ -546,10 +577,11 @@ def main(args: argparse.Namespace): model_id = args.model tokenizer_id = args.tokenizer + use_hf_tokenizer = args.use_hf_tokenizer api_url = f"{args.server}:{args.port}" - tokenizer = get_tokenizer(model_id, tokenizer_id) + tokenizer = get_tokenizer(model_id, tokenizer_id, use_hf_tokenizer) if tokenizer == "test" or args.dataset == "test": input_requests = mock_requests( args.total_mock_requests @@ -574,9 +606,14 @@ def main(args: argparse.Namespace): max_output_length=args.max_output_length, ) - if args.warmup_first: - print("Warm up start:") + warmup_requests = None + if args.warmup_mode == "full": + warmup_requests = input_requests + elif args.warmup_mode == "sampled": warmup_requests = list(sample_warmup_requests(input_requests)) * 2 + + if warmup_requests: + print(f"Starting {args.warmup_mode} warmup:") benchmark_result, request_outputs = asyncio.run( benchmark( api_url=api_url, @@ -584,11 +621,9 @@ def main(args: argparse.Namespace): input_requests=warmup_requests, request_rate=args.request_rate, disable_tqdm=args.disable_tqdm, - session_cache=args.session_cache, - priority=args.priority, ) ) - print("Warm up done") + print(f"{args.warmup_mode} warmup completed.") # TODO: Replace this with warmup complete signal once supported. # Wait for server completely warmup before running the benchmark. @@ -601,8 +636,6 @@ def main(args: argparse.Namespace): input_requests=input_requests, request_rate=args.request_rate, disable_tqdm=args.disable_tqdm, - session_cache=args.session_cache, - priority=args.priority, ) ) @@ -623,17 +656,15 @@ def main(args: argparse.Namespace): dimensions_json["date"] = current_dt dimensions_json["model_id"] = model_id dimensions_json["tokenizer_id"] = tokenizer_id - dimensions_json = { - **dimensions_json, - **json.loads(args.additional_metadata_metrics_to_save), - } + if args.additional_metadata_metrics_to_save is not None: + dimensions_json = { + **dimensions_json, + **json.loads(args.additional_metadata_metrics_to_save), + } metrics_json["num_prompts"] = args.num_prompts # Traffic - metrics_json["request_rate"] = ( - args.request_rate if args.request_rate < float("inf") else "inf" - ) - + metrics_json["request_rate"] = args.request_rate metrics_json = {**metrics_json, **benchmark_result} if args.run_eval: metrics_json = {**metrics_json, **eval_json} @@ -661,6 +692,7 @@ def main(args: argparse.Namespace): if __name__ == "__main__": + parser = argparse.ArgumentParser( description="Benchmark the online serving throughput." ) @@ -699,6 +731,15 @@ def main(args: argparse.Namespace): " default value)" ), ) + parser.add_argument( + "--use-hf-tokenizer", + type=str2bool, + default=False, + help=( + "Whether to use tokenizer from HuggingFace. If so, set this flag" + " to True, and provide name of the tokenizer in the tokenizer flag." 
+ ), + ) parser.add_argument( "--num-prompts", type=int, @@ -711,9 +752,9 @@ def main(args: argparse.Namespace): parser.add_argument( "--request-rate", type=float, - default=float("inf"), + default=0.0, help=( - "Number of requests per second. If this is inf, " + "Number of requests per second. If this is 0., " "then all the requests are sent at time 0. " "Otherwise, we use Poisson process to synthesize " "the request arrival times." @@ -729,7 +770,7 @@ def main(args: argparse.Namespace): parser.add_argument( "--max-output-length", type=int, - default=None, + default=0, help=( "The maximum output length for reference request. It would be passed" " to `max_tokens` parameter of the JetStream's DecodeRequest proto," @@ -738,7 +779,8 @@ def main(args: argparse.Namespace): " max_tokens <= (max_target_length - max_prefill_predict_length)." " max_target_length is the maximum length of a sequence;" " max_prefill_predict_length is the maximum length of the" - " input/prefill of a sequence." + " input/prefill of a sequence. Default to 0, in this case, " + "the output length of the golden dataset would be passed." ), ) @@ -761,24 +803,6 @@ def main(args: argparse.Namespace): " the form of a string." ), ) - parser.add_argument( - "--priority", - type=int, - default=0, - help=( - "Message priority. (currently no business logic implemented, use" - " default 0)" - ), - ) - parser.add_argument( - "--session-cache", - type=str, - default="", - help=( - "Location of any pre-cached results. (currently _load_cache_history" - " not implemented, use default empty str)" - ), - ) parser.add_argument( "--save-request-outputs", action="store_true", @@ -792,15 +816,16 @@ def main(args: argparse.Namespace): ) parser.add_argument( "--run-eval", - type=bool, + type=str2bool, default=False, help="Whether to run evaluation script on the saved outputs", ) parser.add_argument( - "--warmup-first", - type=bool, - default=False, - help="Whether to send warmup req first", + "--warmup-mode", + type=str, + default="none", + choices=["none", "sampled", "full"], + help="Whether to warmup first, and set the warmup mode", ) parser.add_argument( "--conversation-starter", diff --git a/docs/observability-prometheus-metrics-in-jetstream-server.md b/docs/observability-prometheus-metrics-in-jetstream-server.md index b61cf081..04d7be4c 100644 --- a/docs/observability-prometheus-metrics-in-jetstream-server.md +++ b/docs/observability-prometheus-metrics-in-jetstream-server.md @@ -45,7 +45,41 @@ Now that we configured `prometheus_port=9090` above, we can observe various Jets # HELP jetstream_prefill_backlog_size Size of prefill queue # TYPE jetstream_prefill_backlog_size gauge jetstream_prefill_backlog_size{id="SOME-HOSTNAME-HERE>"} 0.0 -# HELP jetstream_slots_available_percentage The percentage of available slots in decode batch -# TYPE jetstream_slots_available_percentage gauge -jetstream_slots_available_percentage{id="",idx="0"} 0.96875 -``` \ No newline at end of file +# HELP jetstream_slots_used_percentage The percentage of decode slots currently being used +# TYPE jetstream_slots_used_percentage gauge +jetstream_slots_used_percentage{id="",idx="0"} 0.04166666666666663 +``` + +## Observe metrics on GKE clusters + +The following applies only for Jetstream deployed on a GKE cluster. 
Currently [Google Cloud Managed Service for Prometheus](https://cloud.google.com/stackdriver/docs/managed-prometheus) is enabled by default on all GKE clusters, it determines scrape targets via the [PodMonitoring](https://github.com/GoogleCloudPlatform/prometheus-engine/blob/v0.10.0/doc/api.md#podmonitoring) custom resource. After you deployed the JetStream GKE workload, you need to apply the PodMonitoring resource to your cluster as follows: + +``` +echo '{ + "apiVersion": "monitoring.googleapis.com/v1", + "kind": "PodMonitoring", + "metadata": { + "name": "jetstream-podmonitoring" + }, + "spec": { + "endpoints": [ + { + "interval": "1s", + "path": "/", + "port": + } + ], + "targetLabels": { + "metadata": [ + "pod", + "container", + "node" + ] + } + } + }' | kubectl apply -f - + ``` + +The metrics can now be queried in the [Google Cloud Metrics Explorer](https://pantheon.corp.google.com/monitoring/metrics-explorer). When adding a metrics query with the `+Add Query` button the new metrics should be found under the `Prometheus Target > Jetstream` submenu. + +Additional guides on the metrics explorer can be found [here](https://cloud.google.com/monitoring/charts/metrics-selector). \ No newline at end of file diff --git a/docs/online-inference-with-maxtext-engine.md b/docs/online-inference-with-maxtext-engine.md index 9d3aefe1..60173e3c 100644 --- a/docs/online-inference-with-maxtext-engine.md +++ b/docs/online-inference-with-maxtext-engine.md @@ -21,11 +21,11 @@ Follow the steps in [Manage TPU resources | Google Cloud](https://cloud.google.c ## Step 1: Download JetStream and the MaxText github repository ```bash -git clone -b jetstream-v0.2.2 https://github.com/google/maxtext.git -git clone -b v0.2.2 https://github.com/google/JetStream.git +git clone https://github.com/google/maxtext.git +git clone https://github.com/google/JetStream.git ``` -## Step 2: Setup MaxText +## Step 2: Setup MaxText and JetStream ```bash # Create a python virtual environment for the demo. @@ -36,6 +36,12 @@ source .env/bin/activate # Setup MaxText. cd maxtext/ bash setup.sh + +# Setup JetStream +cd JetStream +pip install -e . +cd benchmarks +pip install -r requirements.in ``` ## Step 3: Convert Model Checkpoints @@ -45,16 +51,16 @@ You can run the JetStream MaxText Server with Gemma and Llama2 models. This sect ### Use a Gemma model checkpoint * You can download a [Gemma checkpoint from Kaggle](https://www.kaggle.com/models/google/gemma/frameworks/maxText/variations/7b). -* After downloading checkpoints, copy them to your GCS bucket at `$CHKPT_BUCKET`. +* After downloading orbax Gemma checkpoints, copy them to your GCS bucket at `$CHKPT_BUCKET`. You should also set two more paths `$MAXTEXT_BUCKET_SCANNED` and `$MAXTEXT_BUCKET_UNSCANNED` that point to the locations of the maxtext checkpoints for the scanned and unscanned (inference-optimized) versions, respectively. * `gsutil -m cp -r ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}` * Please refer to the [conversion script](https://github.com/google/JetStream/blob/main/jetstream/tools/maxtext/model_ckpt_conversion.sh) for an example of `$CHKPT_BUCKET`. * Then, using the following command to convert the Gemma checkpoint into a MaxText compatible unscanned checkpoint. 
```bash -# bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh ${MODEL} ${MODEL_VARIATION} ${CHKPT_BUCKET} +# bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh ${MODEL} ${MODEL_VARIATION} ${CHKPT_BUCKET} ${MAXTEXT_BUCKET_SCANNED} ${MAXTEXT_BUCKET_UNSCANNED} # For gemma-7b -bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh gemma 7b ${CHKPT_BUCKET} +bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh gemma 7b ${CHKPT_BUCKET} ${MAXTEXT_BUCKET_SCANNED} ${MAXTEXT_BUCKET_UNSCANNED} ``` Note: For more information about the Gemma model and checkpoints, see [About Gemma](https://github.com/google/maxtext/blob/main/end_to_end/gemma/Run_Gemma.md). @@ -63,25 +69,25 @@ Note: For more information about the Gemma model and checkpoints, see [About Gem ### Use a Llama2 model checkpoint * You can use a Llama2 checkpoint you have generated or one from [the open source community](https://llama.meta.com/llama-downloads/). -* After downloading checkpoints, copy them to your GCS bucket at `$CHKPT_BUCKET`. +* After downloading PyTorch checkpoints, copy them to your GCS bucket at `$CHKPT_BUCKET`. You should also set two more paths `$MAXTEXT_BUCKET_SCANNED` and `$MAXTEXT_BUCKET_UNSCANNED` that point to the locations of the maxtext checkpoints for the scanned and unscanned (inference-optimized) versions, respectively. * `gsutil -m cp -r ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}` * Please refer to the [conversion script](https://github.com/google/JetStream/blob/main/jetstream/tools/maxtext/model_ckpt_conversion.sh) for an example of `$CHKPT_BUCKET`. * Then, using the following command to convert the Llama2 checkpoint into a MaxText compatible unscanned checkpoint. ```bash -# bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh ${MODEL} ${MODEL_VARIATION} ${CHKPT_BUCKET} +# bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh ${MODEL} ${MODEL_VARIATION} ${CHKPT_BUCKET} ${MAXTEXT_BUCKET_SCANNED} ${MAXTEXT_BUCKET_UNSCANNED} # For llama2-7b -bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET} +bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET} ${MAXTEXT_BUCKET_SCANNED} ${MAXTEXT_BUCKET_UNSCANNED} # For llama2-13b -bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 13b ${CHKPT_BUCKET} +bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 13b ${CHKPT_BUCKET} ${MAXTEXT_BUCKET_SCANNED} ${MAXTEXT_BUCKET_UNSCANNED} ``` Note: For more information about the Llama2 model and checkpoints, see [About Llama2](https://github.com/google/maxtext/blob/main/getting_started/Run_Llama2.md). 
-## Step4: Run the JetStream MaxText server +## Step 4: Run the JetStream MaxText server ### Create model config environment variables for server flags @@ -104,8 +110,8 @@ export MAX_PREFILL_PREDICT_LENGTH=1024 export MAX_TARGET_LENGTH=2048 export MODEL_NAME=gemma-7b export ICI_FSDP_PARALLELISM=1 -export ICI_AUTOREGRESSIVE_PARALLELISM=-1 -export ICI_TENSOR_PARALLELISM=1 +export ICI_AUTOREGRESSIVE_PARALLELISM=1 +export ICI_TENSOR_PARALLELISM=-1 export SCAN_LAYERS=false export WEIGHT_DTYPE=bfloat16 export PER_DEVICE_BATCH_SIZE=11 @@ -122,8 +128,8 @@ export MAX_PREFILL_PREDICT_LENGTH=1024 export MAX_TARGET_LENGTH=2048 export MODEL_NAME=llama2-7b export ICI_FSDP_PARALLELISM=1 -export ICI_AUTOREGRESSIVE_PARALLELISM=-1 -export ICI_TENSOR_PARALLELISM=1 +export ICI_AUTOREGRESSIVE_PARALLELISM=1 +export ICI_TENSOR_PARALLELISM=-1 export SCAN_LAYERS=false export WEIGHT_DTYPE=bfloat16 export PER_DEVICE_BATCH_SIZE=11 @@ -131,8 +137,6 @@ export PER_DEVICE_BATCH_SIZE=11 #### Create Llama2-13b environment variables for server flags - - * Configure the [flags](#jetstream-maxtext-server-flag-descriptions) passing into the JetStream MaxText server ```bash @@ -142,8 +146,8 @@ export MAX_PREFILL_PREDICT_LENGTH=1024 export MAX_TARGET_LENGTH=2048 export MODEL_NAME=llama2-13b export ICI_FSDP_PARALLELISM=1 -export ICI_AUTOREGRESSIVE_PARALLELISM=-1 -export ICI_TENSOR_PARALLELISM=1 +export ICI_AUTOREGRESSIVE_PARALLELISM=1 +export ICI_TENSOR_PARALLELISM=-1 export SCAN_LAYERS=false export WEIGHT_DTYPE=bfloat16 export PER_DEVICE_BATCH_SIZE=4 @@ -187,7 +191,8 @@ python MaxText/maxengine_server.py \ Note: these flags are from [MaxText config](https://github.com/google/maxtext/blob/f9e04cdc1eec74a0e648411857c09403c3358461/MaxText/configs/base.yml) -## Step 5: Send test request to JetStream MaxText server +## Step 5: Send a test request to JetStream MaxText server +In a new tab in your terminal, run the following command ```bash cd ~ @@ -207,34 +212,125 @@ Response: to be a fan ## Step 6: Run benchmarks with JetStream MaxText server -Note: The JetStream MaxText Server is not running with quantization optimization in Step 3. To get best benchmark results, we need to enable quantization (Please use AQT trained or fine tuned checkpoints to ensure accuracy) for both weights and KV cache, please add the quantization flags and restart the server as following: +Note: The JetStream MaxText Server commands from Step 4 are not running with any quantization optimizations. To get the best benchmark results, we need to enable quantization for weights and KV cache. To do this, first generate AQT trained or fine-tuned checkpoints. Then, add the quantization flags and restart the server. 
+
+### Generating a quantized checkpoint
+
+First, define the path to which the quantized checkpoint will be saved:
+```bash
+export SAVE_QUANT_PARAMS_PATH=gs://${USER}-bkt/quantized/llama2-7b-chat
+```
+
+There are several different quantization configurations to choose from:
+
+#### int8 DRQ quantized checkpoint
 ```bash
-# Enable int8 quantization for both weights and KV cache
+python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${LOAD_PARAMETERS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=llama2-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=11 attention=dot_product quantization=int8 save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH}
+```
+
+#### Weights-only int8 quantized checkpoint
+```bash
+python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${LOAD_PARAMETERS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=llama2-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=11 attention=dot_product quantization=int8w save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH}
+```
+
+#### Mixed precision weight-only quantized checkpoint
+First, update the mixed precision config file (`MaxText/configs/quantization/mp_scale.json`) in the MaxText repo to the mixed-precision config defined below.
+```
+{
+  ".*/query": {"bits": 4, "scale": 0.8},
+  ".*/key": {"bits": 4, "scale": 0.9},
+  ".*/value": {"bits": 8},
+  ".*/out": {"bits": 4},
+  ".*/wi_0": {"bits": 4},
+  ".*/wo": {"bits": 8}
+}
+```
+Then run the following command:
+```bash
+python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${LOAD_PARAMETERS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=llama2-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=11 attention=dot_product quantization=intmp quant_cfg_path=configs/quantization/mp_scale.json save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH}
+```
+
+### Restart the server with quantization flags
+
+#### Set flags
+
+Set the base quantization flags:
+```bash
+# To load an int8 DRQ checkpoint
 export QUANTIZATION=int8
-export QUANTIZE_KVCACHE=true
+export LOAD_PARAMETERS_PATH=${SAVE_QUANT_PARAMS_PATH}
+export CHECKPOINT_IS_QUANTIZED=True
+
+# To load an int8 weight-only checkpoint
+export QUANTIZATION=int8w
+export LOAD_PARAMETERS_PATH=${SAVE_QUANT_PARAMS_PATH}
+export CHECKPOINT_IS_QUANTIZED=True
+
+# To load a Mixed-Precision quantized checkpoint
+# If using Mixed-Precision mode, make sure the mixed precision config file is the same one used when quantizing the checkpoint (MaxText/configs/quantization/mp_scale.json)
+export QUANTIZATION=intmp
+export LOAD_PARAMETERS_PATH=${SAVE_QUANT_PARAMS_PATH}
+export CHECKPOINT_IS_QUANTIZED=True
+export QUANT_CFG_PATH=configs/quantization/mp_scale.json
+```
+The KV cache is quantized to int8 by setting the following config param:
+```bash
+export QUANTIZE_KVCACHE=True
+```
+If you don't want to quantize the KV cache, set:
+```bash
+export QUANTIZE_KVCACHE=False
+```
+
+#### Restart server
+```bash
 # For Gemma 7b model, change per_device_batch_size to 12 to optimize performance.
export PER_DEVICE_BATCH_SIZE=12 cd ~/maxtext python MaxText/maxengine_server.py \ -MaxText/configs/base.yml \ -tokenizer_path=${TOKENIZER_PATH} \ -load_parameters_path=${LOAD_PARAMETERS_PATH} \ -max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \ -max_target_length=${MAX_TARGET_LENGTH} \ -model_name=${MODEL_NAME} \ -ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \ -ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \ -ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \ -scan_layers=${SCAN_LAYERS} \ -weight_dtype=${WEIGHT_DTYPE} \ -per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ -quantization=${QUANTIZATION} \ -quantize_kvcache=${QUANTIZE_KVCACHE} + MaxText/configs/base.yml \ + tokenizer_path=${TOKENIZER_PATH} \ + load_parameters_path=${LOAD_PARAMETERS_PATH} \ + max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \ + max_target_length=${MAX_TARGET_LENGTH} \ + model_name=${MODEL_NAME} \ + ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \ + ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \ + ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \ + scan_layers=${SCAN_LAYERS} \ + weight_dtype=${WEIGHT_DTYPE} \ + per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ + quantization=${QUANTIZATION} \ + quantize_kvcache=${QUANTIZE_KVCACHE} \ + checkpoint_is_quantized=${CHECKPOINT_IS_QUANTIZED} ``` +For the mixed precision quantized model +```bash +python MaxText/maxengine_server.py \ + MaxText/configs/base.yml \ + tokenizer_path=${TOKENIZER_PATH} \ + load_parameters_path=${LOAD_PARAMETERS_PATH} \ + max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \ + max_target_length=${MAX_TARGET_LENGTH} \ + model_name=${MODEL_NAME} \ + ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \ + ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \ + ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \ + scan_layers=${SCAN_LAYERS} \ + weight_dtype=${WEIGHT_DTYPE} \ + per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ + quantization=${QUANTIZATION} \ + quantize_kvcache=${QUANTIZE_KVCACHE} \ + checkpoint_is_quantized=${CHECKPOINT_IS_QUANTIZED} \ + quant_cfg_path=${QUANT_CFG_PATH} +``` + + ### Benchmarking Gemma-7b Instructions @@ -259,13 +355,14 @@ python JetStream/benchmarks/benchmark_serving.py \ --dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \ --max-output-length 1024 \ --request-rate 5 \ ---warmup-first true +--warmup-mode sampled ``` +For details, please see https://github.com/google/JetStream/blob/main/benchmarks/README.md -### Benchmarking Llama2-\*b +### Benchmarking Llama2 ```bash -# Same as Gemma-7b except for the tokenizer (must use a tokenizer that matches your model, which should now be tokenizer.llama2). +# The command is the same as that for the Gemma-7b, except for the tokenizer. Since we need to use a tokenizer that matches the model, it should now be tokenizer.llama2. python JetStream/benchmarks/benchmark_serving.py \ --tokenizer maxtext/assets/tokenizer.llama2 \ @@ -274,8 +371,9 @@ python JetStream/benchmarks/benchmark_serving.py \ --dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \ --max-output-length 1024 \ --request-rate 5 \ ---warmup-first true +--warmup-mode sampled ``` +For details, please see https://github.com/google/JetStream/blob/main/benchmarks/README.md ## Clean Up @@ -283,10 +381,11 @@ python JetStream/benchmarks/benchmark_serving.py \ # Clean up gcs buckets. 
gcloud storage buckets delete ${MODEL_BUCKET}
 gcloud storage buckets delete ${BASE_OUTPUT_DIRECTORY}
-gcloud storage buckets delete ${DATASET_PATH}
+
 # Clean up repositories.
 rm -rf maxtext
 rm -rf JetStream
+
 # Clean up python virtual environment
 rm -rf .env
 ```
diff --git a/docs/profiling-with-jax-profiler-and-tensorboard.md b/docs/profiling-with-jax-profiler-and-tensorboard.md
index 3727c387..d2590040 100644
--- a/docs/profiling-with-jax-profiler-and-tensorboard.md
+++ b/docs/profiling-with-jax-profiler-and-tensorboard.md
@@ -10,7 +10,14 @@ Following the [JAX official manual profiling approach](https://jax.readthedocs.i
 ```bash
 tensorboard --logdir /tmp/tensorboard/
 ```
-You should be able to load TensorBoard at http://localhost:6006/. You can specify a different port with the `--port` flag.
+You should be able to load TensorBoard at http://localhost:6006/. You can specify a different port with the `--port` flag. If you are running on a remote Cloud TPU VM, the `tensorboard-plugin-profile` Python package enables remote access to TensorBoard endpoints (it is included in JetStream's dependencies).
+
+If you cannot reach TensorBoard because the profiling code runs on a remote machine, run the command below to set up an SSH tunnel on port 6006. If you use VS Code remote debugging from the command line, VS Code forwards the port for you.
+
+```bash
+ gcloud compute ssh -- -L 6006:127.0.0.1:6006
+ ```
+
 2. Start JetStream MaxText server:
 ```bash
diff --git a/jetstream/core/metrics/prometheus.py b/jetstream/core/metrics/prometheus.py
index de0be2c2..dc8a00e9 100644
--- a/jetstream/core/metrics/prometheus.py
+++ b/jetstream/core/metrics/prometheus.py
@@ -16,7 +16,8 @@
 import os
 import shortuuid
-from prometheus_client import Gauge
+from prometheus_client import Counter, Gauge, Histogram
+from jetstream.engine.token_utils import DEFAULT_PREFILL_BUCKETS
 class JetstreamMetricsCollector:
@@ -35,14 +36,222 @@ def __new__(cls):
       documentation="Size of prefill queue",
       labelnames=["id"],
   )
-  _slots_available_percentage = Gauge(
-      name="jetstream_slots_available_percentage",
-      documentation="The percentage of available slots in decode batch",
+
+  _transfer_backlog = Gauge(
+      name="jetstream_transfer_backlog_size",
+      documentation="Size of transfer queue",
+      labelnames=["id", "idx"],
+  )
+
+  _generate_backlog = Gauge(
+      name="jetstream_generate_backlog_size",
+      documentation="Size of generate queue",
+      labelnames=["id", "idx"],
+  )
+
+  _queue_duration = Histogram(
+      name="jetstream_queue_duration",
+      documentation="The total time each request spends enqueued in seconds",
+      labelnames=["id"],
+      buckets=[
+          0.01,
+          0.02,
+          0.05,
+          0.1,
+          0.2,
+          0.5,
+          1.0,
+          2.0,
+          5.0,
+          10.0,
+          20.0,
+          50.0,
+          100.0,
+      ],
+  )
+
+  _slots_used_percentage = Gauge(
+      name="jetstream_slots_used_percentage",
+      documentation="The percentage of decode slots currently being used",
       labelnames=["id", "idx"],
   )
+  _server_startup_latency = Gauge(
+      name="jetstream_server_startup_latency",
+      documentation="Total time taken to start the Jetstream server",
+      labelnames=["id"],
+  )
+  _request_input_length = Histogram(
+      name="jetstream_request_input_length",
+      documentation="Number of input tokens per request",
+      labelnames=["id"],
+      buckets=DEFAULT_PREFILL_BUCKETS,
+  )
+  _request_output_length = Histogram(
+      name="jetstream_request_output_length",
+      documentation="Number of output tokens per request",
+      labelnames=["id"],
+      buckets=[
+          1,
+          2,
+          5,
+          10,
+          20,
+          50,
+          100,
+          200,
+          500,
+          1000,
+          2000,
+          5000,
+          10000,
+          20000,
+          50000,
+
100000, + 200000, + 500000, + 1000000, + 2000000, + ], + ) + _request_success_count = Counter( + name="jetstream_request_success_count", + documentation="Number of requests successfully completed", + labelnames=["id"], + ) + + _time_to_first_token = Histogram( + name="jetstream_time_to_first_token", + documentation="Time to first token per request in seconds", + labelnames=["id"], + buckets=[ + 0.001, + 0.005, + 0.01, + 0.02, + 0.04, + 0.06, + 0.08, + 0.1, + 0.25, + 0.5, + 0.75, + 1.0, + 2.5, + 5.0, + 7.5, + 10.0, + ], + ) + + _time_per_output_token = Histogram( + name="jetstream_time_per_output_token", + documentation="Average time per output token per request in seconds", + labelnames=["id"], + buckets=[ + 0.01, + 0.025, + 0.05, + 0.075, + 0.1, + 0.15, + 0.2, + 0.3, + 0.4, + 0.5, + 0.75, + 1.0, + 2.5, + ], + ) + + _time_per_prefill_token = Histogram( + name="jetstream_time_per_prefill_token", + documentation="Prefill time per token per request in seconds", + labelnames=["id"], + buckets=[ + 0.00001, + 0.00002, + 0.00005, + 0.0001, + 0.0002, + 0.0005, + 0.001, + 0.002, + 0.005, + 0.01, + 0.02, + 0.05, + 0.1, + ], + ) + + _time_per_request = Histogram( + name="jetstream_time_per_request", + documentation="End to end request latency in seconds", + labelnames=["id"], + buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0], + ) + + _wait_time_per_request = Histogram( + name="jetstream_wait_time_per_request", + documentation="Time each request is not being prefilled or decoded", + labelnames=["id"], + buckets=[ + 0.01, + 0.02, + 0.05, + 0.1, + 0.2, + 0.5, + 1.0, + 2.0, + 5.0, + 10.0, + 20.0, + 50.0, + 100.0, + ], + ) + def get_prefill_backlog_metric(self): return self._prefill_backlog.labels(id=self._id) - def get_slots_available_percentage_metric(self, idx: int): - return self._slots_available_percentage.labels(id=self._id, idx=idx) + def get_transfer_backlog_metric(self, idx: int): + return self._transfer_backlog.labels(id=self._id, idx=idx) + + def get_generate_backlog_metric(self, idx: int): + return self._generate_backlog.labels(id=self._id, idx=idx) + + def get_queue_duration(self): + return self._queue_duration.labels(id=self._id) + + def get_slots_used_percentage_metric(self, idx: int): + return self._slots_used_percentage.labels(id=self._id, idx=idx) + + def get_server_startup_latency_metric(self): + return self._server_startup_latency.labels(id=self._id) + + def get_time_to_first_token(self): + return self._time_to_first_token.labels(id=self._id) + + def get_time_per_output_token(self): + return self._time_per_output_token.labels(id=self._id) + + def get_time_per_prefill_token(self): + return self._time_per_prefill_token.labels(id=self._id) + + def get_time_per_request(self): + return self._time_per_request.labels(id=self._id) + + def get_wait_time_per_request(self): + return self._wait_time_per_request.labels(id=self._id) + + def get_request_input_length(self): + return self._request_input_length.labels(id=self._id) + + def get_request_output_length(self): + return self._request_output_length.labels(id=self._id) + + def get_request_success_count_metric(self): + return self._request_success_count.labels(id=self._id) diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index eed35f8c..15fc36dd 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -38,7 +38,7 @@ of the generation loop at the relevant slot. - Regardless, it performs a step. - It takes the sampled tokens, and places them on a 'detokenizing_queue'. -7. 
Within the detokenizing thread: +7. Within the detokenizing thread (Prefill and Generate separately): - Tokens are detokenized for every 'slot' in a given set of sampled tokens. - When an end condition is met, the 'slot' integer is returned to the respective generation queue. @@ -98,10 +98,10 @@ import numpy as np root = logging.getLogger() -root.setLevel(logging.DEBUG) +root.setLevel(logging.INFO) handler = logging.StreamHandler(sys.stdout) -handler.setLevel(logging.DEBUG) +handler.setLevel(logging.INFO) formatter = logging.Formatter( "%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) @@ -109,13 +109,22 @@ root.addHandler(handler) -def delete_pytree(p): - def delete_leaf(leaf): - if isinstance(leaf, jax.Array): - leaf.delete() - del leaf +@dataclasses.dataclass +class ActiveRequestMetadata: + """Inference request metadata.""" + + start_time: Optional[float] = None + + prefill_enqueue_time: Optional[float] = None + prefill_dequeue_time: Optional[float] = None - jax.tree_map(delete_leaf, p) + transfer_enqueue_time: Optional[float] = None + transfer_dequeue_time: Optional[float] = None + + generate_enqueue_time: Optional[float] = None + generate_dequeue_time: Optional[float] = None + + complete_time: Optional[float] = None @dataclasses.dataclass @@ -133,12 +142,13 @@ class ActiveRequest: complete: Optional[np.ndarray] = None prefill_result: Any = None #################### Information relevant for prefill ######################## - history_path: Optional[str] = None prefill_content: Optional[str | list[int]] = None ################## Information relevant for detokenization ################### # Which generate step this was added at. generate_timestep_added: Optional[int] = None is_client_side_tokenization: Optional[bool] = False + ################## Information relevant for metrics ################### + metadata: ActiveRequestMetadata = ActiveRequestMetadata() def enqueue_samples(self, generated_samples: list[ReturnSample]): """Adds the generated sample(s) to return channel for current step. @@ -200,7 +210,8 @@ class Driver: # Stage 4 # This can be a list because we can pass it as an arg to generate and # detokenize threads. It is a list of tokens to be detokenized. - _detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = [] + _prefill_detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = [] + _generate_detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = [] _generate_slots: list[queue.Queue[int]] = [] _active_requests: list[queue.Queue[tuple[int, ActiveRequest]]] = [] @@ -260,13 +271,18 @@ def __init__( # one of the generate backlogs. # Interleaved Mode: Max size is 1 to increase the HBM utilization # during generate. - # Disaggregated Mode: Max size is 4 to allow for 2 prefills to be enqueued - # while 1 transfer is enqueued while 1 is being transferred. + # Disaggregated Mode: Max size is 16 to allow for total 16 prefills to + # be enqueued or enqueued while 1 is being transferred. # TODO: Make queue size configurable. self._transfer_backlogs = [ - queue.Queue(1 if self._interleaved_mode else 4) + queue.Queue(1 if self._interleaved_mode else 16) for i in range(len(self._prefill_engines)) ] + if self._metrics_collector: + for idx, backlog in enumerate(self._transfer_backlogs): + self._metrics_collector.get_transfer_backlog_metric(idx).set_function( + functools.partial(float, backlog.qsize()) + ) # Stage 3 # Each generate engine accesses its own generate backlog. 
# Interleaved Mode: Max size is 1 to increase the HBM utilization @@ -281,11 +297,17 @@ def __init__( ) for idx, engine in enumerate(self._generate_engines) } + if self._metrics_collector: + for idx, backlog in self._generate_backlogs.items(): + self._metrics_collector.get_generate_backlog_metric(idx).set_function( + functools.partial(float, backlog.qsize()) + ) # Stage 4 - # After generation, ActiveRequests are placed on the detokenization backlog - # for tokens to be sent into each ActiveRequest's return channel. - # We have one of these per generate engine to simplify the logic keeping - # track of which generation engine to replace slots on. + # After prefill and generation, ActiveRequests are placed on the + # detokenization backlog for tokens to be sent into each ActiveRequest's + # return channel. + # We have one of these per prefill / generate engine to simplify + # the logic keeping track of which generation engine to replace slots on. # This is a queue of either - tuple[int, ActiveRequest] which represents our # active requests, or tuple[int, sample_tokens]. We combine these into one # queue because it allows us to be somewhat clever with how we do @@ -300,7 +322,16 @@ def __init__( # the possibility of race conditions where a slot is made live before the # tokens are ready and it receives tokens from a different sequence, # or tokens detokenized before the relevant slot is live. - self._detokenize_backlogs = [ + + self._prefill_detokenize_backlogs = [ + # No need to set maxsize, as transfer queue can + # provide the backpressure to the prefill workload + # (to avoid the overwhelming prefill). + queue.Queue() + for _ in self._prefill_engines + ] + + self._generate_detokenize_backlogs = [ # We don't let detokenization accumulate more than 8 steps to avoid # synchronization issues. queue.Queue(8) @@ -356,13 +387,25 @@ def __init__( ) for idx in range(len(self._generate_engines)) ] - self.detokenize_threads = [ + self.prefill_detokenize_threads = [ JetThread( target=functools.partial( self._detokenize_thread, - idx, + is_prefill=True, + idx=idx, + ), + name=f"prefill_detokenize-{idx}", + ) + for idx in range(len(self._generate_engines)) + ] + self.generate_detokenize_threads = [ + JetThread( + target=functools.partial( + self._detokenize_thread, + is_prefill=False, + idx=idx, ), - name=f"detokenize-{idx}", + name=f"generate_detokenize-{idx}", ) for idx in range(len(self._generate_engines)) ] @@ -371,7 +414,8 @@ def __init__( self._prefill_threads, self._transfer_threads, self._generate_threads, - self.detokenize_threads, + self.prefill_detokenize_threads, + self.generate_detokenize_threads, ) ) self.live = True @@ -390,7 +434,8 @@ def stop(self): [self._prefill_backlog], self._transfer_backlogs, self._generate_backlogs.values(), - self._detokenize_backlogs, + self._prefill_detokenize_backlogs, + self._generate_detokenize_backlogs, ) ) @@ -479,26 +524,37 @@ def _prefill_thread(self, idx: int): if request is None: break - is_bos = not bool(request.history_path) + request.metadata.prefill_dequeue_time = time.perf_counter() + is_bos = True logging.info( "Prefilling on prefill engine %d : prefill queue size, %d," - " is_bos: %s, history: %s", + " is_bos: %s", idx, self._prefill_backlog.qsize(), is_bos, - request.history_path, ) # Tokenize and padding the text or token input. padded_tokens, true_length = self._process_prefill_content( request, tokenizer, is_bos, prefill_engine.max_prefill_length ) + # Compute new kv cache for the prefill_content. 
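      # Editorial note on the change below: prefill now also returns the first
      # generated token, which is pushed onto the new prefill detokenize
      # backlog so it can be streamed back to the client before generation
      # starts; this feeds the time-to-first-token metric added in
      # prometheus.py above.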
- prefill_result = prefill_engine.prefill( + prefill_result, first_token = prefill_engine.prefill( params=prefill_params, padded_tokens=padded_tokens, true_length=true_length, ) request.prefill_result = prefill_result + + # put first token to detokenize queue + request.complete = np.zeros((prefill_engine.samples_per_slot,), np.bool_) + my_detokenize_backlog = self._prefill_detokenize_backlogs[idx] + request.metadata.transfer_enqueue_time = time.perf_counter() + my_detokenize_backlog.put( + (first_token, request, request.metadata.prefill_dequeue_time), + block=True, + ) + # Once prefill is complete, place it on the generation queue and block if # full. my_transfer_backlog.put(request, block=True) @@ -507,6 +563,18 @@ def _prefill_thread(self, idx: int): idx, my_transfer_backlog.qsize(), ) + if self._metrics_collector: + self._metrics_collector.get_request_input_length().observe(true_length) + + if self._metrics_collector: + self._metrics_collector.get_time_per_prefill_token().observe( + ( + request.metadata.transfer_enqueue_time + - request.metadata.prefill_dequeue_time + ) + / true_length + ) + del prefill_result del request @@ -543,6 +611,7 @@ def _transfer_thread(self, idx: int): new_request = transfer_backlog.get(block=True) if new_request is None: break + new_request.metadata.transfer_dequeue_time = time.perf_counter() target_idx = min( self._generate_backlogs.items(), key=lambda q: q[1].qsize() )[0] @@ -558,12 +627,15 @@ def _transfer_thread(self, idx: int): # Transfer the info to the relevant generate slice. self._transfer_prefill_result(new_request, target_idx) # Place the request on the correct generate backlog and block if full. + new_request.metadata.generate_enqueue_time = time.perf_counter() self._generate_backlogs[target_idx].put(new_request, block=True) logging.info( "Successfully transferred prefill " - "from prefill engine %d to generate engine %d.", + "from prefill engine %d to generate engine %d " + "(%d requests now in backlog).", idx, target_idx, + self._generate_backlogs[target_idx].qsize(), ) def _generate_thread(self, idx: int): @@ -572,7 +644,7 @@ def _generate_thread(self, idx: int): generate_engine = self._generate_engines[idx] my_slots = self._generate_slots[idx] my_generate_backlog = self._generate_backlogs[idx] - my_detokenize_backlog = self._detokenize_backlogs[idx] + my_detokenize_backlog = self._generate_detokenize_backlogs[idx] # Keep track of what step tokens were generated at. generate_timestep = 0 @@ -597,9 +669,11 @@ def _generate_thread(self, idx: int): max_concurrent_decodes = generate_engine.max_concurrent_decodes if self._metrics_collector: - self._metrics_collector.get_slots_available_percentage_metric( + self._metrics_collector.get_slots_used_percentage_metric( idx - ).set_function(lambda: float(my_slots.qsize() / max_concurrent_decodes)) + ).set_function( + lambda: float(1 - (my_slots.qsize() / max_concurrent_decodes)) + ) # Check if there are any free my_slots. We don't want to block here since # we can still generate if we can't insert. 
We do this in a while loop to @@ -626,6 +700,24 @@ def _generate_thread(self, idx: int): block |= not self._transfer_backlogs[idx].empty() try: new_request = my_generate_backlog.get(block=block, timeout=1.0) + if new_request is None: + break + new_request.metadata.generate_dequeue_time = time.perf_counter() + if ( + self._metrics_collector + and new_request.metadata.start_time is not None + ): + self._metrics_collector.get_queue_duration().observe( + # Time in prefill queue + new_request.metadata.prefill_dequeue_time + - new_request.metadata.prefill_enqueue_time + # Time in transfer queue + + new_request.metadata.transfer_dequeue_time + - new_request.metadata.transfer_enqueue_time + # Time in generate queue + + new_request.metadata.generate_dequeue_time + - new_request.metadata.generate_enqueue_time + ) # Got free slot and new request, use them. except queue.Empty: # No new requests, we can't insert, so put back slot. @@ -647,10 +739,11 @@ def _generate_thread(self, idx: int): slot, generate_timestep, ) + decode_state = generate_engine.insert( new_request.prefill_result, decode_state, slot=slot ) - delete_pytree(new_request.prefill_result) + del new_request.prefill_result new_request.generate_timestep_added = generate_timestep new_request.complete = np.zeros( (generate_engine.samples_per_slot,), dtype=np.bool_ @@ -681,12 +774,17 @@ def _generate_thread(self, idx: int): ) time_of_last_generate = time.time() - def _detokenize_thread(self, idx: int): + def _detokenize_thread(self, is_prefill: bool, idx: int): """Detokenize sampled tokens and returns them to the user.""" # One of these per generate engine. # For all filled my_slots, pop the sampled token onto the relevant # requests return channel. If it done, place it back onto free slots. - my_detokenize_backlog = self._detokenize_backlogs[idx] + + if is_prefill: + my_detokenize_backlog = self._prefill_detokenize_backlogs[idx] + else: + my_detokenize_backlog = self._generate_detokenize_backlogs[idx] + my_generate_engine = self._generate_engines[idx] my_slots = self._generate_slots[idx] @@ -700,7 +798,35 @@ def _detokenize_thread(self, idx: int): if data is None: break start_detokenize_time = time.time() - if isinstance(data[1], engine_api.ResultTokens): + # prefill first token + if isinstance(data[0], engine_api.ResultTokens): + request_first_token, request, _ = data + request_first_token = request_first_token.convert_to_numpy() + + results, complete = token_utils.process_result_tokens( + tokenizer=tokenizer, + slot=0, # always 0 as prefill only run 1 sample + slot_max_length=request.max_tokens, + result_tokens=request_first_token, + is_client_side_tokenization=request.is_client_side_tokenization, + complete=request.complete, + ) + request.complete = complete + # Return some output samples. + request.enqueue_samples(results) + + first_token_return_time = time.perf_counter() + if self._metrics_collector: + self._metrics_collector.get_time_to_first_token().observe( + first_token_return_time - request.metadata.prefill_dequeue_time + ) + logging.info( + "TTFT duration: %fms", + (first_token_return_time - request.metadata.prefill_dequeue_time) + * 1000, + ) + # generate step tokens + elif isinstance(data[1], engine_api.ResultTokens): # We want to detokenize them. generate_timestep_added, result_tokens = data # Disable attribute error because pytype doesn't know this @@ -721,10 +847,45 @@ def _detokenize_thread(self, idx: int): # Return some output samples. 
request.enqueue_samples(results) if request.complete.all(): + request.metadata.complete_time = time.perf_counter() request.return_channel.close() + if self._metrics_collector: + self._metrics_collector.get_request_output_length().observe( + result_tokens.get_result_at_slot(slot).lengths + ) + self._metrics_collector.get_request_success_count_metric().inc() + self._metrics_collector.get_time_per_output_token().observe( + ( + request.metadata.complete_time + - request.metadata.transfer_enqueue_time + ) + / result_tokens.get_result_at_slot(slot).lengths + ) + self._metrics_collector.get_time_per_request().observe( + request.metadata.complete_time + - request.metadata.transfer_enqueue_time + ) + + if request.metadata.start_time: + total_time = ( + request.metadata.complete_time + - request.metadata.start_time + ) + prefill_time = ( + request.metadata.transfer_enqueue_time + - request.metadata.prefill_dequeue_time + ) + generate_time = ( + request.metadata.complete_time + - request.metadata.generate_dequeue_time + ) + self._metrics_collector.get_wait_time_per_request().observe( + total_time - prefill_time - generate_time + ) # Place the slot back on the free queue. my_live_requests[slot] = None my_slots.put(slot, block=False) # This should always have space. + my_generate_engine.free_resource(slot) logging.info( "Detokenizing generate step %d took %.2fms", generate_timestep_added, @@ -834,10 +995,13 @@ async def Decode( # pylint: disable=invalid-overridden-method # Wrap request as an ActiveRequest. active_request = ActiveRequest( max_tokens=request.max_tokens, - history_path=request.session_cache, prefill_content=prefill_content, is_client_side_tokenization=is_client_side_tokenization, return_channel=return_channel, + metadata=ActiveRequestMetadata( + start_time=request.metadata.start_time, + prefill_enqueue_time=time.perf_counter(), + ), ) # The first stage is being prefilled, all other stages are handled # inside the driver (transfer, generate*N, detokenize). diff --git a/jetstream/core/proto/jetstream.proto b/jetstream/core/proto/jetstream.proto index 5f2e8869..f06d89d5 100644 --- a/jetstream/core/proto/jetstream.proto +++ b/jetstream/core/proto/jetstream.proto @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +// NOTICE: run `make generate-protos` if making changes to this file + syntax = "proto3"; package jetstream_proto; @@ -26,9 +28,6 @@ service Orchestrator { } message DecodeRequest { - // Where to load any pre-existing kv cache from. - string session_cache = 1; - int32 priority = 3; // The maximum output length of a sequence. It's used in JetStream to control // the output/decode length of a sequence. It would not be used in the engine. // We should always set max_tokens <= (max_target_length - @@ -51,8 +50,17 @@ message DecodeRequest { TextContent text_content = 5; TokenContent token_content = 6; } - reserved 2; - // Next ID: 7 + + message Metadata { + float start_time = 1; + } + + oneof metadata_optional { + Metadata metadata = 7; + } + + reserved 1, 2, 3; + // Next ID: 8 } message DecodeResponse { diff --git a/jetstream/core/proto/jetstream_pb2.py b/jetstream/core/proto/jetstream_pb2.py index 3fadd54c..0b146032 100644 --- a/jetstream/core/proto/jetstream_pb2.py +++ b/jetstream/core/proto/jetstream_pb2.py @@ -11,12 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: jetstream/core/proto/jetstream.proto # Protobuf Python Version: 4.25.1 -# pylint: disable=all """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool @@ -28,7 +26,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"\xa7\x02\n\rDecodeRequest\x12\x15\n\rsession_cache\x18\x01 \x01(\t\x12\x10\n\x08priority\x18\x03 \x01(\x05\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x02\x10\x03"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02"\x14\n\x12HealthCheckRequest"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 \x01(\x08\x32\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse"\x00\x62\x06proto3' + b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"\xfc\x02\n\rDecodeRequest\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x12;\n\x08metadata\x18\x07 \x01(\x0b\x32\'.jetstream_proto.DecodeRequest.MetadataH\x01\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x1a\x1e\n\x08Metadata\x12\x12\n\nstart_time\x18\x01 \x01(\x02\x42\t\n\x07\x63ontentB\x13\n\x11metadata_optionalJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02"\x14\n\x12HealthCheckRequest"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 \x01(\x08\x32\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse"\x00\x62\x06proto3' ) _globals = globals() @@ -39,23 +37,25 @@ if _descriptor._USE_C_DESCRIPTORS 
== False: DESCRIPTOR._options = None _globals["_DECODEREQUEST"]._serialized_start = 58 - _globals["_DECODEREQUEST"]._serialized_end = 353 - _globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_start = 274 - _globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_end = 301 - _globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_start = 303 - _globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_end = 336 - _globals["_DECODERESPONSE"]._serialized_start = 356 - _globals["_DECODERESPONSE"]._serialized_end = 687 - _globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_start = 522 - _globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_end = 538 - _globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_start = 541 - _globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_end = 670 - _globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_start = 629 - _globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_end = 670 - _globals["_HEALTHCHECKREQUEST"]._serialized_start = 689 - _globals["_HEALTHCHECKREQUEST"]._serialized_end = 709 - _globals["_HEALTHCHECKRESPONSE"]._serialized_start = 711 - _globals["_HEALTHCHECKRESPONSE"]._serialized_end = 749 - _globals["_ORCHESTRATOR"]._serialized_start = 752 - _globals["_ORCHESTRATOR"]._serialized_end = 937 + _globals["_DECODEREQUEST"]._serialized_end = 438 + _globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_start = 294 + _globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_end = 321 + _globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_start = 323 + _globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_end = 356 + _globals["_DECODEREQUEST_METADATA"]._serialized_start = 358 + _globals["_DECODEREQUEST_METADATA"]._serialized_end = 388 + _globals["_DECODERESPONSE"]._serialized_start = 441 + _globals["_DECODERESPONSE"]._serialized_end = 772 + _globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_start = 607 + _globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_end = 623 + _globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_start = 626 + _globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_end = 755 + _globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_start = 714 + _globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_end = 755 + _globals["_HEALTHCHECKREQUEST"]._serialized_start = 774 + _globals["_HEALTHCHECKREQUEST"]._serialized_end = 794 + _globals["_HEALTHCHECKRESPONSE"]._serialized_start = 796 + _globals["_HEALTHCHECKRESPONSE"]._serialized_end = 834 + _globals["_ORCHESTRATOR"]._serialized_start = 837 + _globals["_ORCHESTRATOR"]._serialized_end = 1022 # @@protoc_insertion_point(module_scope) diff --git a/jetstream/core/proto/jetstream_pb2_grpc.py b/jetstream/core/proto/jetstream_pb2_grpc.py index 84521185..d571ade8 100644 --- a/jetstream/core/proto/jetstream_pb2_grpc.py +++ b/jetstream/core/proto/jetstream_pb2_grpc.py @@ -11,9 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
-# pylint: disable=all """Client and server classes corresponding to protobuf-defined services.""" import grpc diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index 3d93746d..b323286a 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -20,15 +20,21 @@ import asyncio from concurrent import futures import logging +import os +import signal import threading +import time +import traceback from typing import Any, Type + import grpc import jax from jetstream.core import config_lib from jetstream.core import orchestrator from jetstream.core.metrics.prometheus import JetstreamMetricsCollector from jetstream.core.proto import jetstream_pb2_grpc +from jetstream.engine import warmup_utils, engine_api from prometheus_client import start_http_server @@ -87,6 +93,79 @@ def wait_for_termination(self) -> None: self.stop() +def create_driver( + config: Type[config_lib.ServerConfig], + devices: Any, + jax_padding: bool = True, + metrics_collector: JetstreamMetricsCollector | None = None, + enable_model_warmup: bool = False, +): + """Creates a driver with a specified config. + + Args: + config: A ServerConfig to config engine, model, device slices, etc. + devices: Device objects, will be used to get engine with proper slicing. + jax_padding: The flag to enable JAX padding during tokenization. + metrics_collector: The JetStream Promethus metric collector. + enable_model_warmup: The flag to enable model server warmup. + + Returns: + An orchestrator driver. + """ + engines = config_lib.get_engines(config, devices=devices) + prefill_params = [pe.load_params() for pe in engines.prefill_engines] + generate_params = [ge.load_params() for ge in engines.generate_engines] + shared_params = [ie.load_params() for ie in engines.interleaved_engines] + logging.info("Loaded all weights.") + interleaved_mode = ( + len(config.prefill_slices) + len(config.generate_slices) == 0 + ) + + prefill_engines = engines.prefill_engines + engines.interleaved_engines + generate_engines = engines.generate_engines + engines.interleaved_engines + prefill_params = prefill_params + shared_params + generate_params = generate_params + shared_params + + if prefill_engines is None: + prefill_engines = [] + if generate_engines is None: + generate_engines = [] + if prefill_params is None: + prefill_params = [] + if generate_params is None: + generate_params = [] + + if enable_model_warmup: + prefill_engines = [engine_api.JetStreamEngine(pe) for pe in prefill_engines] + generate_engines = [ + engine_api.JetStreamEngine(ge) for ge in generate_engines + ] + + try: + _ = warmup_utils.layout_params_and_compile_executables( + prefill_engines, # pylint: disable=protected-access + generate_engines, # pylint: disable=protected-access + prefill_params, # pylint: disable=protected-access + generate_params, # pylint: disable=protected-access + ) + + except ValueError as e: + print(f"Model warmup encountered an error: {e}") + traceback.print_exc() + os.kill(os.getpid(), signal.SIGKILL) + + return orchestrator.Driver( + prefill_engines=prefill_engines, + generate_engines=generate_engines, + prefill_params=prefill_params, + generate_params=generate_params, + interleaved_mode=interleaved_mode, + jax_padding=jax_padding, + metrics_collector=metrics_collector, + is_ray_backend=config.is_ray_backend, + ) + + def run( port: int, config: Type[config_lib.ServerConfig], @@ -97,6 +176,7 @@ def run( metrics_server_config: config_lib.MetricsServerConfig | None = None, enable_jax_profiler: bool = False, jax_profiler_port: 
int = 9999, + enable_model_warmup: bool = False, ) -> JetStreamServer: """Runs a server with a specified config. @@ -111,20 +191,13 @@ def run( metrics_server_config: The config to enable Promethus metric server. enable_jax_profiler: The flag to enable JAX profiler server. jax_profiler_port: The port JAX profiler server (default to 9999). + enable_model_warmup: The flag to enable model server warmup. Returns: JetStreamServer that wraps the grpc server and orchestrator driver. """ + server_start_time = time.time() logging.info("Kicking off gRPC server.") - engines = config_lib.get_engines(config, devices=devices) - prefill_params = [pe.load_params() for pe in engines.prefill_engines] - generate_params = [ge.load_params() for ge in engines.generate_engines] - shared_params = [ie.load_params() for ie in engines.interleaved_engines] - logging.info("Loaded all weights.") - interleaved_mode = ( - len(config.prefill_slices) + len(config.generate_slices) == 0 - ) - # Setup Prometheus server metrics_collector: JetstreamMetricsCollector = None if metrics_server_config and metrics_server_config.port: @@ -138,15 +211,8 @@ def run( "Not starting Prometheus server: --prometheus_port flag not set" ) - driver = orchestrator.Driver( - prefill_engines=engines.prefill_engines + engines.interleaved_engines, - generate_engines=engines.generate_engines + engines.interleaved_engines, - prefill_params=prefill_params + shared_params, - generate_params=generate_params + shared_params, - interleaved_mode=interleaved_mode, - jax_padding=jax_padding, - metrics_collector=metrics_collector, - is_ray_backend=config.is_ray_backend, + driver = create_driver( + config, devices, jax_padding, metrics_collector, enable_model_warmup ) # We default threads to the total number of concurrent allowed decodes, # to make sure we can fully saturate the model. Set default minimum to 64. @@ -156,12 +222,27 @@ def run( jetstream_server.start() + if metrics_collector: + metrics_collector.get_server_startup_latency_metric().set( + time.time() - server_start_time + ) + # Setup Jax Profiler if enable_jax_profiler: logging.info("Starting JAX profiler server on port %s", jax_profiler_port) jax.profiler.start_server(jax_profiler_port) else: logging.info("Not starting JAX profiler server: %s", enable_jax_profiler) + + # Start profiling server by default for proxy backend. + if jax.config.jax_platforms and "proxy" in jax.config.jax_platforms: + from jetstream.core.utils import proxy_util # pylint: disable=import-outside-toplevel + + thread = threading.Thread( + target=proxy_util.start_profiling_server, args=(jax_profiler_port,) + ) + thread.run() + return jetstream_server diff --git a/jetstream/core/utils/proxy_util.py b/jetstream/core/utils/proxy_util.py new file mode 100644 index 00000000..fbf4b031 --- /dev/null +++ b/jetstream/core/utils/proxy_util.py @@ -0,0 +1,48 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Proxy util functions.""" + +import dataclasses +import logging +import jax +import time +from fastapi import FastAPI +import uvicorn + + +# TODO: add a manner way to terminate. +def start_profiling_server(port: int): + + logging.info("Starting JAX profiler server on port %s", port) + app = FastAPI() + + @dataclasses.dataclass + class ProfilingConfig: + seconds: int + output_dir: str + + @app.post("/profiling") + async def profiling(pc: ProfilingConfig): + jax.profiler.start_trace(pc.output_dir) + logging.info("Capturing the profiling data for next %s seconds", pc.seconds) + time.sleep(pc.seconds) + logging.info("Writing profiling data to %s", pc.output_dir) + jax.profiler.stop_trace() + return {"response": "profiling completed"} + + @app.get("/") + async def root(): + return {"message": "Hello from proxy profiling server"} + + uvicorn.run(app, host="0.0.0.0", port=port, log_level="info") diff --git a/jetstream/engine/__init__.py b/jetstream/engine/__init__.py index 2ed0398d..99bf4983 100644 --- a/jetstream/engine/__init__.py +++ b/jetstream/engine/__init__.py @@ -16,30 +16,8 @@ import jax - -def register_proxy_backend(): - """Try to register IFRT Proxy backend if it's needed.""" - # TODO: find a more elegant way to do it. - if jax.config.jax_platforms and "proxy" in jax.config.jax_platforms: - try: - jax.lib.xla_bridge.get_backend("proxy") - except RuntimeError: - try: - from jaxlib.xla_extension import ifrt_proxy # pylint: disable=import-outside-toplevel - - jax_backend_target = jax.config.read("jax_backend_target") - jax._src.xla_bridge.register_backend_factory( # pylint: disable=protected-access - "proxy", - lambda: ifrt_proxy.get_client( - jax_backend_target, - ifrt_proxy.ClientConnectionOptions(), - ), - priority=-1, - ) - print(f"Registered IFRT Proxy with address {jax_backend_target}") - except ImportError as e: - print(f"Failed to register IFRT Proxy, exception: {e}") - pass - - -register_proxy_backend() +try: + import pathwaysutils +except ImportError as e: + print("Proxy backend support is not added") + pass diff --git a/jetstream/engine/engine_api.py b/jetstream/engine/engine_api.py index 50feff6d..9f42b60f 100644 --- a/jetstream/engine/engine_api.py +++ b/jetstream/engine/engine_api.py @@ -19,7 +19,7 @@ """ import abc -from typing import Any, Optional, Tuple, Union +from typing import Any, Optional, Tuple, Union, Callable from flax import struct import jax @@ -142,18 +142,24 @@ def prefill( existing_prefix: Optional[Prefix] = None, padded_tokens: jax.Array, true_length: int, - ) -> Prefix: + sampler: Optional[Callable[[Any], Any]] = None, + ) -> Tuple[Prefix, ResultTokens]: """Computes a kv-cache for a set of tokens conditional on existing cache. existing_prefix (if provided) represents a prefix that has already been processed by the underlying model. tokens is logically appended to the text represented by `existing_prefix`. This method returns a new kv_cache (typically) for the resulting text. + + If sampler is passed, then the engine should use it do sample next token. """ @abc.abstractmethod def generate( - self, params: Params, decode_state: DecodeState + self, + params: Params, + decode_state: DecodeState, + sampler: Optional[Callable[[Any], Any]] = None, ) -> Tuple[DecodeState, ResultTokens]: """Generates tokens for each sequence being decoded in parallel. @@ -165,6 +171,8 @@ def generate( consists of each microbatch progressing through every stage), in non-pipelined code this is a full forward pass. 
In both cases, this accounts for a full embed-layerstack-unembed-sample operation. + + If sampler is passed, then the engine should use it to sample the next token. """ @abc.abstractmethod @@ -187,6 +195,18 @@ def insert( a [0, n) range of slots and converted internally. """ + def free_resource( + self, + slot: int, # pylint: disable=unused-argument + ) -> Any: + """Free cache and other decode resources for the slot. + + This function is needed for advanced attention kernels like PageAttention. + After finishing one request, the engine needs to free all used page block + resources and reuse them for incoming requests. + """ + return None + @abc.abstractmethod def load_params(self, *args, **kwargs) -> Params: """Loads parameters. @@ -240,3 +260,95 @@ def mesh(self) -> jax.sharding.Mesh: @abc.abstractmethod def colocated_cpus(self) -> Union[list[CpuDevices], None]: """CPU devices colocated with the engine's accelerators.""" + + +class JetStreamEngine(Engine): + """A wrapper engine of the Engine class. + + JetStreamEngine defines the warmed up model server engine. + """ + + def __init__(self, downstream_engine: Engine): + self._downstream_engine = downstream_engine + + self.prefill_buckets = None + self.warm = False + + def prefill( + self, + *, + params: Params, + existing_prefix: Optional[Prefix] = None, + padded_tokens: jax.Array, + true_length: int, + ) -> Tuple[Prefix, ResultTokens]: + + prefill_result, first_token = self._downstream_engine.prefill( + params=params, + padded_tokens=padded_tokens, + true_length=true_length, + ) + return prefill_result, first_token + + def insert( + self, + prefix: Prefix, + decode_state: DecodeState, + slot: int, + ) -> DecodeState: + + decode_state = self._downstream_engine.insert( + prefix=prefix, + decode_state=decode_state, + slot=slot, + ) + return decode_state + + def generate( + self, params: Params, decode_state: DecodeState + ) -> Tuple[DecodeState, ResultTokens]: + decode_state, sampled_tokens = self._downstream_engine.generate( + params=params, decode_state=decode_state + ) + return decode_state, sampled_tokens + + def load_params(self, *args, **kwargs) -> Params: + return self._downstream_engine.load_params(*args, **kwargs) + + def get_prefix_destination_sharding(self) -> Any: + return self._downstream_engine.get_prefix_destination_sharding() + + def get_tokenizer( + self, + ) -> tokenizer_pb2.TokenizerParameters: + return self._downstream_engine.get_tokenizer() + + def build_tokenizer( + self, + metadata: tokenizer_pb2.TokenizerParameters, + ) -> Tokenizer: + """Builds a new tokenizer object and returns it.""" + return self._downstream_engine.build_tokenizer(metadata) + + def init_decode_state(self, *args, **kwargs) -> DecodeState: + return self._downstream_engine.init_decode_state(*args, **kwargs) + + @property + def max_concurrent_decodes(self) -> int: + return self._downstream_engine.max_concurrent_decodes + + @property + def samples_per_slot(self) -> int: + return self._downstream_engine.samples_per_slot + + @property + def max_prefill_length(self) -> int: + return self._downstream_engine.max_prefill_length + + @property + def mesh(self) -> jax.sharding.Mesh: + return self._downstream_engine.mesh + + @property + def colocated_cpus(self) -> Union[list[CpuDevices], None]: + return self._downstream_engine.colocated_cpus diff --git a/jetstream/engine/mock_engine.py b/jetstream/engine/mock_engine.py index 502df8b6..0277e9a3 100644 --- a/jetstream/engine/mock_engine.py +++ b/jetstream/engine/mock_engine.py @@ -54,6 +54,7 @@ class DecodeState: generate_cache:
jax.Array generate_cache_index: int generate_lengths: jax.Array + generate_tokens: jax.Array class TestEngine(engine_api.Engine): @@ -85,7 +86,7 @@ def prefill( existing_prefix: Optional[jax.Array] = None, padded_tokens: jax.Array, true_length: int, - ) -> Prefix: + ) -> Tuple[Prefix, engine_api.ResultTokens]: """Computes a kv-cache for a new generate request. Args: @@ -109,19 +110,55 @@ def prefill( ) # Do some fake work that isn't eliminated by dead code elimination (DCE). params = params + fake_work.mean() - fake_work.mean() - return padded_tokens[None, :] * params + prefill_cache = padded_tokens[None, :] * params + + # get dummy first token + first_step = (prefill_cache.sum(axis=-1))[:, jnp.newaxis] + first_token_data = jnp.concatenate( + [first_step, jnp.ones_like(first_step), jnp.ones_like(first_step)], + axis=-1, + ) + speculations = first_step.shape[1] + first_token = engine_api.ResultTokens( + data=first_token_data.astype(jnp.int32), + tokens_idx=(0, speculations), + # Validity occupies the same amount of space, but next in line. + valid_idx=(speculations, 2 * speculations), + # And lengths is rank 1. + length_idx=(2 * speculations, 2 * speculations + 1), + samples_per_slot=self.generate_cache_batch // self.prefill_cache_batch, + ) + + return (prefill_cache, first_step), first_token @functools.partial(jax.jit, static_argnums=(0,)) def generate( self, params: Params, decode_state: DecodeState ) -> Tuple[DecodeState, engine_api.ResultTokens]: """Generates tokens for each sequence being decoded in parallel.""" - prefill_cache, generate_cache, generate_cache_index, generate_lengths = ( + ( + prefill_cache, + generate_cache, + generate_cache_index, + generate_lengths, + previous_timestep, + ) = ( decode_state.prefill_cache, decode_state.generate_cache, decode_state.generate_cache_index, decode_state.generate_lengths, + decode_state.generate_tokens, ) + + # Update generate cache + generate_cache = jax.lax.dynamic_update_slice_in_dim( + generate_cache, + previous_timestep, + start_index=generate_cache_index, + axis=1, + ) + generate_cache_index = (generate_cache_index + 1) % self.cache_length + # Sum each row of prefill cache and generate cache to produce new timestep, # multiply by params. l_iota = jax.lax.broadcasted_iota( @@ -136,17 +173,13 @@ def generate( # token from prefill in the dummy. # This iota and masking is to allow for a cicular cache. length_mask = ( - -(l_iota - generate_cache_index + 1) % self.cache_length + -(l_iota - generate_cache_index) % self.cache_length ) <= generate_lengths[:, None] length_masked_gen_cache = generate_cache * length_mask new_timestep = ( prefill_cache.sum(axis=-1) + (length_masked_gen_cache.sum(axis=-1) / params) )[:, jnp.newaxis] - generate_cache = jax.lax.dynamic_update_slice_in_dim( - generate_cache, new_timestep, start_index=generate_cache_index, axis=1 - ) - generate_cache_index = (generate_cache_index + 1) % self.cache_length # Wait to simulate model step time. 
fake_size = 4096 fake_work = jnp.ones((fake_size, fake_size)) @ jnp.ones( @@ -168,6 +201,7 @@ def generate( generate_cache=generate_cache, generate_cache_index=generate_cache_index, generate_lengths=new_lengths, + generate_tokens=new_timestep, ), engine_api.ResultTokens( data=token_data.astype(jnp.int32), # Tokens are shape [batch, speculations], so when we concatenate @@ -190,8 +224,9 @@ def insert( ) -> DecodeState: """Adds `prefix` into `decode_state` at `slot`.""" # [B, T], [T,] -> [B, T] + prefill_cache, previous_timestep = prefix prefill_cache = jax.lax.dynamic_update_slice_in_dim( - decode_state.prefill_cache, prefix, slot, axis=0 + decode_state.prefill_cache, prefill_cache, slot, axis=0 ) generate_cache = jax.lax.dynamic_update_slice_in_dim( decode_state.generate_cache, @@ -202,7 +237,13 @@ def insert( samples_per_slot = self.generate_cache_batch // self.prefill_cache_batch generate_lengths = jax.lax.dynamic_update_slice_in_dim( decode_state.generate_lengths, - jnp.zeros((samples_per_slot), dtype=jnp.int32), + jnp.ones((samples_per_slot), dtype=jnp.int32), + slot * samples_per_slot, + axis=0, + ) + generate_tokens = jax.lax.dynamic_update_slice_in_dim( + decode_state.generate_tokens, + previous_timestep, slot * samples_per_slot, axis=0, ) @@ -210,6 +251,7 @@ def insert( prefill_cache=prefill_cache, generate_cache=generate_cache, generate_lengths=generate_lengths, + generate_tokens=generate_tokens, ) def get_prefix_destination_sharding(self) -> Any: @@ -234,6 +276,9 @@ def init_decode_state(self) -> DecodeState: generate_lengths=jnp.zeros( (self.generate_cache_batch), dtype=jnp.int32 ), + generate_tokens=jnp.zeros( + (self.generate_cache_batch, 1), dtype=jnp.float32 + ), ) @property diff --git a/jetstream/engine/mock_utils.py b/jetstream/engine/mock_utils.py index 3c4fb0b3..a48a360f 100644 --- a/jetstream/engine/mock_utils.py +++ b/jetstream/engine/mock_utils.py @@ -63,7 +63,7 @@ def _encode(self, s: str) -> Sequence[int]: def _decode(self, ids: np.ndarray): """Converts a numpy array into a string.""" - return "".join([chr(r) for r in list(ids)]) + return "".join([chr(r) for r in list(ids) if r not in self.stop_tokens]) def _encode_tf(self, s: str) -> np.ndarray: """Converts a string into a numpy array.""" diff --git a/jetstream/engine/sampling_utils.py b/jetstream/engine/sampling_utils.py new file mode 100644 index 00000000..3adb191a --- /dev/null +++ b/jetstream/engine/sampling_utils.py @@ -0,0 +1,87 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=bare-except, consider-using-generator +""" Inference sampling utilities. + + Inspired by an Google-internal implementation, Global Vision Transformer. 
+""" + +import jax +import jax.numpy as jnp + +NEG_INF = -1.0e7 # Masking purpose + + +def sampling(logits, rng, algorithm, topk=0, nucleus_topp=0, temperature=1.0): + """ + logits: unnormalized logits to sample, shaped [YOUR_LEADING_DIMS, Vocab], + before logit + rng: rng key to use + algorithm: string representing supported algorithms + topk: restricting to topk logits before sampling + nucleus_topp: restricting to p probability mass before sampling + temperature: temperature parameter for scaling probability + """ + if algorithm == "greedy": + return jnp.argmax(logits, axis=-1) + elif algorithm == "weighted": + return jax.random.categorical(rng, logits / temperature) + elif algorithm == "nucleus": + return sample_nucleus_topp_logits(logits, nucleus_topp, temperature, rng) + elif algorithm == "topk": + return sample_topk_logits(logits, topk, temperature, rng) + else: + raise ValueError(f"Sampling {algorithm=} not supported!") + + +def sample_nucleus_topp_logits(logits, nucleus_topp, temperature, rng): + """Restrict sampling to the top logits with cumulative probability >= + nucleus_topp. + + The nucleus sampling method is proposed in the paper `The Curious Case of + Neural Text Degeneration (https://arxiv.org/pdf/1904.09751.pdf)` + + """ + if nucleus_topp < 0: + raise ValueError( + "Can't apply nucleus with parameter {nucleus_topp=} less zero" + ) + logits_sorted = jnp.sort(logits, axis=-1)[..., ::-1] # sort descending + sorted_cum_probs = jnp.cumsum( + jax.nn.softmax(logits_sorted, axis=-1), axis=-1 + ) # get cumsum probs + cutoff_index = jnp.sum( + sorted_cum_probs < nucleus_topp, axis=-1, keepdims=True + ) # find cutoff index + cutoff_logit = jnp.take_along_axis(logits_sorted, cutoff_index, axis=-1) + logits = jnp.where( + logits < cutoff_logit, jnp.full_like(logits, NEG_INF), logits + ) + return jax.random.categorical(rng, logits / temperature) + + +def sample_topk_logits(logits, topk, temperature, rng): + """Restricting sampling to the best k logits.""" + if topk <= 0: + raise ValueError("Can't apply algorithm topk with parameter {topk=} <= 0") + topk_logits, topk_idxs = jax.lax.top_k(logits, topk) + topk_token = jnp.expand_dims( + jax.random.categorical(rng, topk_logits / temperature).astype(jnp.int32), + axis=-1, + ) + sampled_tokens = jnp.squeeze( + jnp.take_along_axis(topk_idxs, topk_token, axis=-1), axis=-1 + ).astype(jnp.int32) + return sampled_tokens diff --git a/jetstream/engine/token_utils.py b/jetstream/engine/token_utils.py index 3d905688..b653c34b 100644 --- a/jetstream/engine/token_utils.py +++ b/jetstream/engine/token_utils.py @@ -28,11 +28,26 @@ from jetstream.engine import mock_utils from jetstream.engine import tokenizer_api from jetstream.engine import tokenizer_pb2 -from jetstream.third_party.llama3 import llama3_tokenizer +from jetstream.external_tokenizers.llama3 import llama3_tokenizer # ResultToken class to store tokens ids. ResultTokens = Any +DEFAULT_PREFILL_BUCKETS = [ + 16, + 32, + 64, + 128, + 256, + 512, + 1024, + 2048, + 4096, + 8192, + 16384, + 32768, +] + def take_nearest_length(lengths: list[int], length: int) -> int: """Gets the nearest length to the right in a set of lengths.""" @@ -109,20 +124,7 @@ def pad_tokens( true_length: Actual length of the non-padded sequence. 
""" if prefill_lengths is None: - prefill_lengths = [ - 16, - 32, - 64, - 128, - 256, - 512, - 1024, - 2048, - 4096, - 8192, - 16384, - 32768, - ] + prefill_lengths = DEFAULT_PREFILL_BUCKETS if max_prefill_length is not None: prefill_lengths = prefill_lengths[ : prefill_lengths.index(max_prefill_length) @@ -214,6 +216,7 @@ def process_result_tokens( ) if tok_id in stop_tokens or not valid: complete[idx] = True + tok_id_so_far.append(tok_id) break else: if not is_client_side_tokenization: diff --git a/jetstream/engine/tokenizer_pb2.py b/jetstream/engine/tokenizer_pb2.py index 4aa69f87..1df16528 100644 --- a/jetstream/engine/tokenizer_pb2.py +++ b/jetstream/engine/tokenizer_pb2.py @@ -11,12 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: jetstream/engine/tokenizer.proto # Protobuf Python Version: 4.25.1 -# pylint: disable=all """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool diff --git a/jetstream/engine/tokenizer_pb2_grpc.py b/jetstream/engine/tokenizer_pb2_grpc.py index 5aa1b7a4..c4ca2afd 100644 --- a/jetstream/engine/tokenizer_pb2_grpc.py +++ b/jetstream/engine/tokenizer_pb2_grpc.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -# pylint: disable=all """Client and server classes corresponding to protobuf-defined services.""" import grpc diff --git a/jetstream/engine/warmup_utils.py b/jetstream/engine/warmup_utils.py new file mode 100644 index 00000000..6bf7c26a --- /dev/null +++ b/jetstream/engine/warmup_utils.py @@ -0,0 +1,219 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Model server warmup utils.""" + +import jax.numpy as jnp +import concurrent.futures +from typing import Any, Optional +import logging +from jetstream.engine import engine_api, token_utils + + +def layout_params_and_compile_executables( + prefill_engines: Optional[list[engine_api.JetStreamEngine]] = None, + generate_engines: Optional[list[engine_api.JetStreamEngine]] = None, + prefill_params: Optional[list[Any]] = None, + generate_params: Optional[list[Any]] = None, +) -> bool: + """Organizes the engines and executables. + + Args: + prefill_engines: Prefill only engines. + generate_engines: Generate only engines. + prefill_params: Prefill only params. + generate_params: Generate only params. 
+ """ + prefill_engines = prefill_engines if prefill_engines else [] + generate_engines = generate_engines if generate_engines else [] + prefill_params = prefill_params if prefill_params else [] + generate_params = generate_params if generate_params else [] + + any_prefill_engine = None + any_prefill_params = None + + prefills_compiled = [] + inserts_generate_compiled = [] + + for i, pe in enumerate(prefill_engines): + any_prefill_engine = pe + any_prefill_params = prefill_params[i] + prefill_compiled = initialize_prefill_jit_cache( + prefill_engine=pe, + prefill_params=prefill_params[i], + prefill_idx=i, + ) + prefills_compiled.append(prefill_compiled) + + for i, ge in enumerate(generate_engines): + insert_generate_compiled = initialize_insert_generate_jit_cache( + prefill_engine=any_prefill_engine, + generate_engine=ge, + prefill_params=any_prefill_params, + generate_params=generate_params[i], + generate_idx=i, + ) + inserts_generate_compiled.append([insert_generate_compiled]) + + if all(prefills_compiled) and all(inserts_generate_compiled): + return True + return False + + +def initialize_prefill_jit_cache( + *, + prefill_engine: engine_api.JetStreamEngine, + prefill_params: Any, + prefill_idx: int, +): + """Precompile all prefill functions in parallel. + If we don't do this, then when a new request triggers a new prefill bucket it + will take a very long time for that query to come back. + + Args: + prefill_engine: A prefill engine to be compiled for. + prefill_params: The associated prefill parameters. + prefill_idx: Which prefill engine it is. + """ + prefill_buckets = token_utils.DEFAULT_PREFILL_BUCKETS + prefill_buckets = [ + bucket + for bucket in prefill_buckets + if bucket <= prefill_engine.max_prefill_length + ] + prefill_engine.prefill_buckets = prefill_buckets + if prefill_engine.max_prefill_length not in prefill_buckets: + prefill_buckets.append(prefill_engine.max_prefill_length) + + def compile_prefill(length): + padded_tokens, true_length = jnp.ones((length), dtype="int32"), length + + _, _ = prefill_engine._downstream_engine.prefill( # pylint: disable=protected-access + params=prefill_params, + padded_tokens=padded_tokens, + true_length=true_length, + ) + + logging.info( + "---------Prefill engine %d compiled for prefill length %d.---------", + prefill_idx, + length, + ) + + logging.info("---------Prefill compilation %d begun.---------", prefill_idx) + + with concurrent.futures.ThreadPoolExecutor( + max_workers=len(prefill_buckets) + ) as executor: + _ = executor.map(compile_prefill, prefill_buckets) + + prefill_engine.warm = True + + logging.info( + "---------Prefill compilation %d complete.---------", prefill_idx + ) + + return prefill_engine.warm + + +def initialize_insert_generate_jit_cache( + *, + prefill_engine: engine_api.JetStreamEngine, + generate_engine: engine_api.JetStreamEngine, + prefill_params: Any, + generate_params: Any, + generate_idx: int, +): + """Initialiszes jit cache for insert and generate. + + Args: + generate_engine: A generate engine to be compiled for. + generate_params: The associated parameters. + generate_idx: Which generate engine it is. 
+ """ + + prefill_buckets = token_utils.DEFAULT_PREFILL_BUCKETS + prefill_buckets = [ + bucket + for bucket in prefill_buckets + if bucket <= generate_engine.max_prefill_length + ] + generate_engine.prefill_buckets = prefill_buckets + if generate_engine.max_prefill_length not in prefill_buckets: + prefill_buckets.append(generate_engine.max_prefill_length) + + decode_state = generate_engine.init_decode_state() + + def compile_insert(length): + padded_tokens, true_length = jnp.ones((length), dtype="int32"), length + + prefill, _ = prefill_engine._downstream_engine.prefill( # pylint: disable=protected-access + params=prefill_params, + padded_tokens=padded_tokens, + true_length=true_length, + ) + + generate_engine.insert(prefix=prefill, decode_state=decode_state, slot=0) + + logging.info( + "---------Generate engine %d compiled for insert length %d.---------", + generate_idx, + length, + ) + + def compile_generate(): + + logging.info( + "---------Generate compilation %d begun.---------", generate_idx + ) + + generate_engine._downstream_engine.generate( # pylint: disable=protected-access + params=generate_params, + decode_state=decode_state, + ) + + logging.info( + "---------Generate engine %d compiled.---------", + generate_idx, + ) + + logging.info( + "---------Generate compilation %d complete.---------", generate_idx + ) + + logging.info( + "---------Insertion generation compilation %d begun.---------", + generate_idx, + ) + + compile_generate() + + logging.info( + "---------Generate engine %d compiled generation step.---------", + generate_idx, + ) + + with concurrent.futures.ThreadPoolExecutor( + max_workers=len(prefill_buckets) + ) as executor: + _ = executor.map(compile_insert, prefill_buckets) + + generate_engine.warm = True + + logging.info( + "---------Insertion generation compilation %d complete.---------", + generate_idx, + ) + + return generate_engine.warm diff --git a/jetstream/entrypoints/__init__.py b/jetstream/entrypoints/__init__.py new file mode 100644 index 00000000..6d5e14bc --- /dev/null +++ b/jetstream/entrypoints/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jetstream/entrypoints/config.py b/jetstream/entrypoints/config.py new file mode 100644 index 00000000..79f2b012 --- /dev/null +++ b/jetstream/entrypoints/config.py @@ -0,0 +1,32 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Config for JetStream Server (including engine init).""" + +from typing import Type + +from jetstream.core import config_lib + + +def get_server_config( + config_str: str, +) -> config_lib.ServerConfig | Type[config_lib.ServerConfig]: + match config_str: + case "InterleavedCPUTestServer": + server_config = config_lib.InterleavedCPUTestServer + case "CPUTestServer": + server_config = config_lib.CPUTestServer + case _: + raise NotImplementedError + return server_config diff --git a/jetstream/entrypoints/http/__init__.py b/jetstream/entrypoints/http/__init__.py new file mode 100644 index 00000000..6d5e14bc --- /dev/null +++ b/jetstream/entrypoints/http/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jetstream/entrypoints/http/api_server.py b/jetstream/entrypoints/http/api_server.py new file mode 100644 index 00000000..aaced235 --- /dev/null +++ b/jetstream/entrypoints/http/api_server.py @@ -0,0 +1,137 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""JetStream Http API server.""" + +import json +import logging +import time +from typing import Sequence +from absl import app as abslapp +from absl import flags +from fastapi import APIRouter, Response +import fastapi +from fastapi.responses import StreamingResponse +from prometheus_client import start_http_server +import uvicorn +from google.protobuf.json_format import Parse + +from jetstream.core import config_lib, orchestrator, server_lib +from jetstream.core.metrics.prometheus import JetstreamMetricsCollector +from jetstream.core.proto import jetstream_pb2 +from jetstream.entrypoints.config import get_server_config +from jetstream.entrypoints.http.protocol import DecodeRequest +from jetstream.entrypoints.http.utils import proto_to_json_generator + +flags.DEFINE_string("host", "0.0.0.0", "server host address") +flags.DEFINE_integer("port", 8080, "http server port") +flags.DEFINE_string( + "config", + "InterleavedCPUTestServer", + "available servers", +) +flags.DEFINE_integer( + "prometheus_port", + 9988, + "prometheus_port", +) + +llm_orchestrator: orchestrator.LLMOrchestrator + +# Define Fast API endpoints (use llm_orchestrator to handle). 
+router = APIRouter() + + +@router.get("/") +def root(): + """Root path for Jetstream HTTP Server.""" + return Response( + content=json.dumps({"message": "JetStream HTTP Server"}, indent=4), + media_type="application/json", + ) + + +@router.post("/v1/generate") +async def generate(request: DecodeRequest): + start_time = time.perf_counter() + proto_request = Parse(request.json(), jetstream_pb2.DecodeRequest()) + metadata = jetstream_pb2.DecodeRequest.Metadata() + metadata.start_time = start_time + proto_request.metadata.CopyFrom(metadata) + generator = llm_orchestrator.Decode(proto_request) + return StreamingResponse( + content=proto_to_json_generator(generator), media_type="text/event-stream" + ) + + +@router.get("/v1/health") +async def health() -> Response: + """Health check.""" + response = await llm_orchestrator.HealthCheck( + jetstream_pb2.HealthCheckRequest() + ) + return Response( + content=json.dumps({"is_live": str(response.is_live)}, indent=4), + media_type="application/json", + status_code=200, + ) + + +def server(argv: Sequence[str]): + # Init Fast API. + app = fastapi.FastAPI() + app.include_router(router) + + # Init LLMOrchestrator which would be the main handler in the api endpoints. + devices = server_lib.get_devices() + print(f"devices: {devices}") + server_config = get_server_config(flags.FLAGS.config) + print(f"server_config: {server_config}") + del argv + + metrics_server_config: config_lib.MetricsServerConfig | None = None + # Setup Prometheus server + metrics_collector: JetstreamMetricsCollector = None + if flags.FLAGS.prometheus_port != 0: + metrics_server_config = config_lib.MetricsServerConfig( + port=flags.FLAGS.prometheus_port + ) + logging.info( + "Starting Prometheus server on port %d", metrics_server_config.port + ) + start_http_server(metrics_server_config.port) + metrics_collector = JetstreamMetricsCollector() + else: + logging.info( + "Not starting Prometheus server: --prometheus_port flag not set" + ) + + global llm_orchestrator + llm_orchestrator = orchestrator.LLMOrchestrator( + driver=server_lib.create_driver( + config=server_config, + devices=devices, + metrics_collector=metrics_collector, + ) + ) + + # Start uvicorn http server. + uvicorn.run( + app, host=flags.FLAGS.host, port=flags.FLAGS.port, log_level="info" + ) + + +if __name__ == "__main__": + # Run Abseil app w flags parser. + abslapp.run(server) diff --git a/jetstream/entrypoints/http/protocol.py b/jetstream/entrypoints/http/protocol.py new file mode 100644 index 00000000..cbb8dc6a --- /dev/null +++ b/jetstream/entrypoints/http/protocol.py @@ -0,0 +1,41 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Http API server protocol.""" + +from pydantic import BaseModel # type: ignore + + +class TextContent(BaseModel): + text: str + + +class TokenContent(BaseModel): + token_ids: list[int] + + +class Metadata(BaseModel): + start_time: float + + +class DecodeRequest(BaseModel): + max_tokens: int + text_content: TextContent | None = None + token_content: TokenContent | None = None + metadata: Metadata | None = None + + # Config to enforce the oneof behavior at runtime. + class Config: + extra = "forbid" # Prevent extra fields. + anystr_strip_whitespace = True diff --git a/jetstream/entrypoints/http/utils.py b/jetstream/entrypoints/http/utils.py new file mode 100644 index 00000000..7765a785 --- /dev/null +++ b/jetstream/entrypoints/http/utils.py @@ -0,0 +1,27 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Http API server utilities.""" + +from google.protobuf.json_format import MessageToJson + + +async def proto_to_json_generator(proto_generator): + """Wraps a generator yielding Protocol Buffer messages into a generator + + yielding JSON messages. + """ + async for proto_message in proto_generator: + json_string = MessageToJson(proto_message) + yield json_string diff --git a/jetstream/third_party/__init__.py b/jetstream/external_tokenizers/__init__.py similarity index 100% rename from jetstream/third_party/__init__.py rename to jetstream/external_tokenizers/__init__.py diff --git a/jetstream/third_party/llama3/__init__.py b/jetstream/external_tokenizers/llama3/__init__.py similarity index 100% rename from jetstream/third_party/llama3/__init__.py rename to jetstream/external_tokenizers/llama3/__init__.py diff --git a/jetstream/third_party/llama3/llama3_tokenizer.py b/jetstream/external_tokenizers/llama3/llama3_tokenizer.py similarity index 100% rename from jetstream/third_party/llama3/llama3_tokenizer.py rename to jetstream/external_tokenizers/llama3/llama3_tokenizer.py diff --git a/jetstream/tests/core/test_orchestrator.py b/jetstream/tests/core/test_orchestrator.py index 49494bef..00e2e1c1 100644 --- a/jetstream/tests/core/test_orchestrator.py +++ b/jetstream/tests/core/test_orchestrator.py @@ -78,9 +78,7 @@ async def test_orchestrator_interleaved_mode(self): text = "AB" request = jetstream_pb2.DecodeRequest( - session_cache="", text_content=jetstream_pb2.DecodeRequest.TextContent(text=text), - priority=1, max_tokens=3, ) iterator = client.Decode(request) @@ -109,11 +107,9 @@ async def test_orchestrator_interleaved_mode_client_tokenization(self): token_ids = [65, 66] request = jetstream_pb2.DecodeRequest( - session_cache="", token_content=jetstream_pb2.DecodeRequest.TokenContent( token_ids=token_ids ), - priority=1, max_tokens=3, ) iterator = client.Decode(request) diff --git a/jetstream/tests/core/test_server.py b/jetstream/tests/core/test_server.py index 150ac39d..2fdddce9 100644 --- a/jetstream/tests/core/test_server.py +++ b/jetstream/tests/core/test_server.py @@ -29,6 +29,7 @@ from jetstream.core import server_lib from jetstream.core.proto import 
jetstream_pb2 from jetstream.core.proto import jetstream_pb2_grpc +from jetstream.engine import engine_api import portpicker @@ -39,6 +40,7 @@ class ServerTest(unittest.IsolatedAsyncioTestCase): # Uses weight 2 for prefill, 4 for decode. ( config_lib.CPUTestServer, + True, ["Ċ", "Ō", "Ɵ", ""], [266, 332, 415, None], [None, None], @@ -46,6 +48,15 @@ class ServerTest(unittest.IsolatedAsyncioTestCase): # Uses the same prefill / generate weights (2). ( config_lib.InterleavedCPUTestServer, + True, + ["Ċ", "Ə", "ɖ", ""], + [266, 399, 598, None], + [None], + ), + # Disable the metrics server. + ( + config_lib.InterleavedCPUTestServer, + False, ["Ċ", "Ə", "ɖ", ""], [266, 399, 598, None], [None], @@ -55,6 +66,7 @@ class ServerTest(unittest.IsolatedAsyncioTestCase): async def test_server( self, config: Type[config_lib.ServerConfig], + metrics_enabled: bool, expected_text: list[str], expected_token_ids: list[int | None], devices: list[Any], @@ -62,6 +74,7 @@ async def test_server( """Sets up a server and requests token responses.""" ######################### Server side ###################################### port = portpicker.pick_unused_port() + metrics_port = portpicker.pick_unused_port() print("port: " + str(port)) credentials = grpc.local_server_credentials() @@ -71,11 +84,15 @@ async def test_server( config=config, devices=devices, credentials=credentials, + metrics_server_config=config_lib.MetricsServerConfig(port=metrics_port) + if metrics_enabled is True + else None, ) ###################### Requester side ###################################### - # prometheus not configured, assert no metrics collector on Driver - assert server._driver._metrics_collector is None # pylint: disable=protected-access + # if prometheus not configured, assert no metrics collector on Driver + if metrics_enabled is not True: + assert server._driver._metrics_collector is None # pylint: disable=protected-access async with grpc.aio.secure_channel( f"localhost:{port}", grpc.local_channel_credentials() @@ -92,9 +109,7 @@ async def test_server( # as BOS text = "AB" request = jetstream_pb2.DecodeRequest( - session_cache="", text_content=jetstream_pb2.DecodeRequest.TextContent(text=text), - priority=1, max_tokens=3, ) iterator = stub.Decode(request) @@ -107,14 +122,20 @@ async def test_server( assert output_text == expected_text[counter] assert output_token_id == expected_token_ids[counter] counter += 1 + # assert prometheus server is running and responding + if metrics_enabled is True: + assert server._driver._metrics_collector is not None # pylint: disable=protected-access + assert ( + requests.get( + f"http://localhost:{metrics_port}", timeout=5 + ).status_code + == requests.status_codes.codes["ok"] + ) server.stop() - def test_prometheus_server(self): + def test_jax_profiler_server(self): port = portpicker.pick_unused_port() - metrics_port = portpicker.pick_unused_port() - print("port: " + str(port)) - print("metrics port: " + str(metrics_port)) credentials = grpc.local_server_credentials() # Now test server with prometheus config server = server_lib.run( @@ -122,30 +143,45 @@ def test_prometheus_server(self): config=config_lib.InterleavedCPUTestServer, devices=[None], credentials=credentials, - metrics_server_config=config_lib.MetricsServerConfig(port=metrics_port), - ) - # assert prometheus server is running and responding - assert server._driver._metrics_collector is not None # pylint: disable=protected-access - assert ( - requests.get(f"http://localhost:{metrics_port}", timeout=5).status_code - == 
requests.status_codes.codes["ok"] + enable_jax_profiler=True, ) + assert server server.stop() - def test_jax_profiler_server(self): + def test_get_devices(self): + assert len(server_lib.get_devices()) == 1 + + async def test_model_warmup(self): port = portpicker.pick_unused_port() + print("port: " + str(port)) credentials = grpc.local_server_credentials() - # Now test server with prometheus config + server = server_lib.run( port=port, config=config_lib.InterleavedCPUTestServer, devices=[None], credentials=credentials, - enable_jax_profiler=True, + enable_model_warmup=True, ) - assert server - server.stop() - def test_get_devices(self): - assert len(server_lib.get_devices()) == 1 + async with grpc.aio.secure_channel( + f"localhost:{port}", grpc.local_channel_credentials() + ) as channel: + stub = jetstream_pb2_grpc.OrchestratorStub(channel) + + healthcheck_request = jetstream_pb2.HealthCheckRequest() + healthcheck_response = stub.HealthCheck(healthcheck_request) + healthcheck_response = await healthcheck_response + + assert healthcheck_response.is_live is True + + for pe in server._driver._prefill_engines: # pylint: disable=protected-access + assert isinstance(pe, engine_api.JetStreamEngine) + assert pe.warm is True + + for ge in server._driver._generate_engines: # pylint: disable=protected-access + assert isinstance(ge, engine_api.JetStreamEngine) + assert ge.warm is True + + server.stop() diff --git a/jetstream/tests/engine/third_party/llama2/tokenizer.model b/jetstream/tests/engine/external_tokenizers/llama2/tokenizer.model similarity index 100% rename from jetstream/tests/engine/third_party/llama2/tokenizer.model rename to jetstream/tests/engine/external_tokenizers/llama2/tokenizer.model diff --git a/jetstream/tests/engine/third_party/llama3/tokenizer.model b/jetstream/tests/engine/external_tokenizers/llama3/tokenizer.model similarity index 100% rename from jetstream/tests/engine/third_party/llama3/tokenizer.model rename to jetstream/tests/engine/external_tokenizers/llama3/tokenizer.model diff --git a/jetstream/tests/engine/test_mock_engine.py b/jetstream/tests/engine/test_mock_engine.py index 3f112067..0d8f2da8 100644 --- a/jetstream/tests/engine/test_mock_engine.py +++ b/jetstream/tests/engine/test_mock_engine.py @@ -54,10 +54,10 @@ def _prefill(self): metadata = engine.get_tokenizer() tokenizer = engine.build_tokenizer(metadata) tokens, true_length = tokenizer.encode(text, is_bos=True) - prefill_result = engine.prefill( + prefill_result, first_token = engine.prefill( params=params, padded_tokens=tokens, true_length=3 ) - return engine, params, prefill_result, true_length + return engine, params, prefill_result, true_length, first_token def _prefill_np(self): """Performs prefill and returns a kv cache.""" @@ -67,14 +67,14 @@ def _prefill_np(self): metadata = engine.get_tokenizer() tokenizer = engine.build_tokenizer(metadata) tokens, true_length = tokenizer.encode(text, is_bos=True, jax_padding=False) - prefill_result = engine.prefill( + prefill_result, first_token = engine.prefill( params=params, padded_tokens=tokens, true_length=3 ) - return engine, params, prefill_result, true_length + return engine, params, prefill_result, true_length, first_token def _generate(self, slot=1): """Performs a single generation step.""" - engine, params, prefill_result, _ = self._prefill() + engine, params, prefill_result, _, _ = self._prefill() decode_state = engine.init_decode_state() decode_state = engine.insert( prefix=prefill_result, decode_state=decode_state, slot=slot @@ -91,16 +91,28 @@ def 
test_load_params(self): def test_prefill(self): """Tests prefill with weight = 2.""" - _, _, prefill_result, true_length = self._prefill() + engine, _, prefill_result, true_length, first_token = self._prefill() + prefill_cache, _ = prefill_result np.testing.assert_array_equal( - prefill_result[:, :true_length], np.array([[4.0, 130.0, 132.0]]) + prefill_cache[:, :true_length], np.array([[4.0, 130.0, 132.0]]) ) + # test first token + token_data = first_token.get_result_at_slot(0) + tok = token_data.tokens + + metadata = engine.get_tokenizer() + tokenizer = token_utils.load_vocab( + metadata.path, metadata.extra_ids + ).tokenizer + assert tokenizer.IdToPiece(int(tok.item())) == "Ċ" + def test_prefill_np(self): """Tests prefill with weight = 2.""" - _, _, prefill_result, true_length = self._prefill_np() + _, _, prefill_result, true_length, _ = self._prefill_np() + prefill_cache, _ = prefill_result np.testing.assert_array_equal( - prefill_result[:, :true_length], np.array([[4.0, 130.0, 132.0]]) + prefill_cache[:, :true_length], np.array([[4.0, 130.0, 132.0]]) ) def test_generate(self, slot=1): @@ -110,13 +122,7 @@ def test_generate(self, slot=1): tokenizer = token_utils.load_vocab( metadata.path, metadata.extra_ids ).tokenizer - # Char for 266 - token_data = sampled_tokens.get_result_at_slot(slot) - tok = token_data.tokens - assert tokenizer.IdToPiece(int(tok.item())) == "Ċ" - decode_state, sampled_tokens = engine.generate( - params=params, decode_state=decode_state - ) + # Char for 399 token_data = sampled_tokens.get_result_at_slot(slot) tok = token_data.tokens diff --git a/jetstream/tests/engine/test_sampling_utils.py b/jetstream/tests/engine/test_sampling_utils.py new file mode 100644 index 00000000..0bc13bd6 --- /dev/null +++ b/jetstream/tests/engine/test_sampling_utils.py @@ -0,0 +1,74 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests functionality of inference sampling utils.""" + +import jax +import jax.numpy as jnp +import unittest +from jetstream.engine import sampling_utils + + +class SamplingUtilsTest(unittest.TestCase): + + def setUp(self): + self.rng = jax.random.PRNGKey(0) + self.logits = jnp.array([[-0.5, 1.2, 0.8], [-1.0, 0.3, 0.7]]) + + def test_greedy_sampling(self): + token = sampling_utils.sampling(self.logits, self.rng, "greedy") + expected_token = jnp.array([1, 2]) + self.assertTrue(jnp.array_equal(token, expected_token)) + + def test_weighted_sampling(self): + # Multiple samples to increase the chance of catching errors + for _ in range(10): + result = sampling_utils.sampling(self.logits, self.rng, "weighted") + self.assertTrue( + jnp.all(jnp.isin(result, jnp.array([0, 1, 2]))) + ) # Check if sampled from valid indices + + def test_nucleus_sampling(self): + for _ in range(10): + result = sampling_utils.sampling( + self.logits, self.rng, "nucleus", nucleus_topp=0.8 + ) + self.assertTrue(jnp.all(jnp.isin(result, jnp.array([0, 1, 2])))) + invalid_topp = -0.1 + with self.assertRaises(ValueError) as context: + sampling_utils.sampling( + self.logits, self.rng, "nucleus", nucleus_topp=invalid_topp + ) + self.assertIn( + f"Can't apply nucleus with parameter {invalid_topp=} less zero", + str(context.exception), + ) + + def test_topk_sampling(self): + for _ in range(10): + result = sampling_utils.sampling(self.logits, self.rng, "topk", topk=2) + self.assertTrue( + jnp.all(jnp.isin(result, jnp.array([1, 2]))) + ) # Only top 2 logits should be sampled + invalid_topk = 0 + with self.assertRaises(ValueError) as context: + sampling_utils.sampling(self.logits, self.rng, "topk", topk=invalid_topk) + self.assertIn( + f"Can't apply algorithm topk with parameter {invalid_topk=} <= 0", + str(context.exception), + ) + + def test_unsupported_algorithm(self): + with self.assertRaises(ValueError): + sampling_utils.sampling(self.logits, self.rng, "unsupported_algorithm") diff --git a/jetstream/tests/engine/test_token_utils.py b/jetstream/tests/engine/test_token_utils.py index 41fff3d3..cdee310e 100644 --- a/jetstream/tests/engine/test_token_utils.py +++ b/jetstream/tests/engine/test_token_utils.py @@ -55,7 +55,7 @@ def decode(self, t: int) -> str: class TokenUtilsTest(unittest.TestCase): def setup_sentencepiece(self): - self.tokenizer_path = "third_party/llama2/tokenizer.model" + self.tokenizer_path = "external_tokenizers/llama2/tokenizer.model" current_dir = os.path.dirname(__file__) self.tokenizer_path = os.path.join(current_dir, self.tokenizer_path) print(f"model_path: {self.tokenizer_path}") @@ -66,7 +66,7 @@ def setup_sentencepiece(self): self.jt_tokenizer = JetStreamTokenizer(self.tokenizer_path) def setup_tiktoken(self): - self.tokenizer_path = "third_party/llama3/tokenizer.model" + self.tokenizer_path = "external_tokenizers/llama3/tokenizer.model" current_dir = os.path.dirname(__file__) self.tokenizer_path = os.path.join(current_dir, self.tokenizer_path) print(f"model_path: {self.tokenizer_path}") diff --git a/jetstream/tests/entrypoints/__init__.py b/jetstream/tests/entrypoints/__init__.py new file mode 100644 index 00000000..6d5e14bc --- /dev/null +++ b/jetstream/tests/entrypoints/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jetstream/tests/entrypoints/http/__init__.py b/jetstream/tests/entrypoints/http/__init__.py new file mode 100644 index 00000000..6d5e14bc --- /dev/null +++ b/jetstream/tests/entrypoints/http/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jetstream/tests/entrypoints/http/test_api_server.py b/jetstream/tests/entrypoints/http/test_api_server.py new file mode 100644 index 00000000..e6d42e58 --- /dev/null +++ b/jetstream/tests/entrypoints/http/test_api_server.py @@ -0,0 +1,84 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests http server end-to-end.""" + +import subprocess +import sys +import time +import unittest + + +import requests + + +class HTTPServerTest(unittest.IsolatedAsyncioTestCase): + + @classmethod + def setUpClass(cls): + """Sets up a JetStream http server for unit tests.""" + cls.base_url = "http://localhost:8080" + cls.server = subprocess.Popen( + [ + "python", + "-m", + "jetstream.entrypoints.http.api_server", + "--config=InterleavedCPUTestServer", + ], + stdout=sys.stdout, + stderr=sys.stderr, + ) + time.sleep(10) + + @classmethod + def tearDownClass(cls): + """Stop the server gracefully.""" + cls.server.terminate() + + async def test_root_endpoint(self): + response = requests.get(self.base_url + "/", timeout=5) + assert response.status_code == 200 + expected_data = {"message": "JetStream HTTP Server"} + assert response.json() == expected_data + + async def test_health_endpoint(self): + response = requests.get(self.base_url + "/v1/health", timeout=5) + assert response.status_code == 200 + data = response.json() + assert "is_live" in data + assert data["is_live"] == "True" + + async def test_generate_endpoint(self): + # Prepare a sample request (replace with actual data) + sample_request_data = { + "max_tokens": 10, + "text_content": {"text": "translate this to french: hello world"}, + } + + response = requests.post( + self.base_url + "/v1/generate", + json=sample_request_data, + stream=True, + timeout=5, + ) + assert response.status_code == 200 + full_response = [] + for chunk in response.iter_content( + chunk_size=None + ): # chunk_size=None for complete lines + if chunk: + stream_response = chunk.decode("utf-8") + print(f"{stream_response=}") + full_response.append(stream_response) + assert len(full_response) == 11 # 10 tokens + eos token diff --git a/jetstream/tools/load_tester.py b/jetstream/tools/load_tester.py index 4d6445be..5f791efd 100644 --- a/jetstream/tools/load_tester.py +++ b/jetstream/tools/load_tester.py @@ -50,14 +50,11 @@ def api_call( stub: jetstream_pb2_grpc.OrchestratorStub, text: str, max_tokens: int, - session_cache: str = "", print_interim: bool = True, ) -> str: """Sends a request to server and returns text.""" request = jetstream_pb2.DecodeRequest( - session_cache=session_cache, text_content=jetstream_pb2.DecodeRequest.TextContent(text=text), - priority=1, max_tokens=max_tokens, ) response = stub.Decode(request) diff --git a/jetstream/tools/maxtext/model_ckpt_conversion.sh b/jetstream/tools/maxtext/model_ckpt_conversion.sh index 19a62b74..0340dbfe 100644 --- a/jetstream/tools/maxtext/model_ckpt_conversion.sh +++ b/jetstream/tools/maxtext/model_ckpt_conversion.sh @@ -28,27 +28,23 @@ export MODEL=$1 export MODEL_VARIATION=$2 export MODEL_NAME=${MODEL}-${MODEL_VARIATION} -# After downloading checkpoints, copy them to GCS bucket at $CHKPT_BUCKET \ +# After downloading checkpoints, copy them to GCS bucket at $CHKPT_BUCKET # Please use separate GCS paths for uploading open source model weights ($CHKPT_BUCKET) and MaxText compatible weights ($MODEL_BUCKET). # Point these variables to a GCS bucket that you created. # An example of CHKPT_BUCKET could be: gs://${USER}-maxtext/chkpt/${MODEL}/${MODEL_VARIATION} export CHKPT_BUCKET=$3 -export MODEL_BUCKET=gs://${USER}-maxtext +export MODEL_BUCKET=$4 -# Point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you created, this bucket will store all the files generated by MaxText during a run. 
-export BASE_OUTPUT_DIRECTORY=gs://${USER}-runner-maxtext-logs - -# Point `DATASET_PATH` to the GCS bucket where you have your training data. -export DATASET_PATH=gs://${USER}-maxtext-dataset +# Point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you created, this bucket will store all the files generated by MaxText during a run, specifically the unscanned checkpoint. +export BASE_OUTPUT_DIRECTORY=$5 export BUCKET_LOCATION=US # Create three GCS buckets for the demo. gcloud storage buckets create ${MODEL_BUCKET} --location=${BUCKET_LOCATION} || true gcloud storage buckets create ${BASE_OUTPUT_DIRECTORY} --location=${BUCKET_LOCATION} || true -gcloud storage buckets create ${DATASET_PATH} --location=${BUCKET_LOCATION} || true -# Covert model checkpoints to MaxText compatible checkpoints. +# Convert model checkpoints to MaxText compatible checkpoints. if [ "$MODEL" == "gemma" ]; then CONVERT_CKPT_SCRIPT="convert_gemma_chkpt.py" JAX_PLATFORMS=cpu python MaxText/${CONVERT_CKPT_SCRIPT} \ @@ -74,7 +70,7 @@ echo "Written MaxText compatible checkpoint to ${MODEL_BUCKET}/${MODEL}/${MODEL_ # We define `SCANNED_CKPT_PATH` to refer to the checkpoint subdirectory. export SCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}/0/items -# Covert MaxText compatible checkpoints to unscanned checkpoints. +# Convert MaxText compatible checkpoints to unscanned checkpoints. # Note that the `SCANNED_CKPT_PATH` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format. export RUN_NAME=${MODEL_NAME}_unscanned_chkpt_${idx} diff --git a/jetstream/tools/maxtext/model_ckpt_finetune_with_aqt.sh b/jetstream/tools/maxtext/model_ckpt_finetune_with_aqt.sh index 7e6ff1f5..a348bebd 100644 --- a/jetstream/tools/maxtext/model_ckpt_finetune_with_aqt.sh +++ b/jetstream/tools/maxtext/model_ckpt_finetune_with_aqt.sh @@ -66,7 +66,7 @@ checkpoint_period=100 # We will convert the `AQT_CKPT` to unscanned checkpoint in the next step. export AQT_CKPT=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/100/items -# Covert MaxText compatible AQT-fine-tuned checkpoints to unscanned checkpoints. +# Convert MaxText compatible AQT-fine-tuned checkpoints to unscanned checkpoints. # Note that the `AQT_CKPT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format. export RUN_NAME=${MODEL_NAME}_unscanned_chkpt_${idx} diff --git a/jetstream/tools/proxy_dev/base.Dockerfile b/jetstream/tools/proxy_dev/base.Dockerfile new file mode 100644 index 00000000..9162bcf0 --- /dev/null +++ b/jetstream/tools/proxy_dev/base.Dockerfile @@ -0,0 +1,35 @@ +# Ubuntu:22.04 +# Use Ubuntu 22.04 from Docker Hub. 
+# https://hub.docker.com/_/ubuntu/tags\?page\=1\&name\=22.04 +FROM ubuntu:22.04 + +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt -y update && apt install -y --no-install-recommends apt-transport-https ca-certificates gnupg git python3.10 python3-pip curl nano vim + +RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1 +RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list && curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg && apt-get update -y && apt-get install google-cloud-sdk -y + + +# Copy all files from local workspace into docker container +COPY JetStream ./JetStream +COPY maxtext ./maxtext + +RUN cd maxtext/ && \ +pip install -r requirements.txt + +RUN pip install setuptools==58 fastapi==0.103.2 uvicorn + +RUN pip install ./JetStream + +RUN apt -y update && apt-get -y install python3-dev && apt-get -y install build-essential +RUN pip install \ + transformers==4.31.0 \ + nltk==3.8.1 \ + evaluate==0.4.0 \ + absl-py==1.4.0 \ + rouge-score==0.1.2 \ + sentencepiece==0.1.99 \ + accelerate==0.21.0 + +ENTRYPOINT ["bash"] diff --git a/jetstream/tools/proxy_dev/dev.Dockerfile b/jetstream/tools/proxy_dev/dev.Dockerfile new file mode 100644 index 00000000..25bf382e --- /dev/null +++ b/jetstream/tools/proxy_dev/dev.Dockerfile @@ -0,0 +1,18 @@ +# Ubuntu:22.04 +# Use Ubuntu 22.04 from Docker Hub. +# https://hub.docker.com/_/ubuntu/tags\?page\=1\&name\=22.04 +FROM base_image + +ENV DEBIAN_FRONTEND=noninteractive + +ENV JAX_PLATFORMS=proxy +ENV JAX_BACKEND_TARGET=grpc://localhost:38681 + +# Copy all files from local workspace into docker container +COPY JetStream ./JetStream +COPY maxtext ./maxtext + +RUN pip install ./JetStream +RUN pip install -r ./maxtext/requirements.txt + +ENTRYPOINT ["bash"] diff --git a/jetstream/tools/requester.py b/jetstream/tools/requester.py index 8fcde556..7ac0d55a 100644 --- a/jetstream/tools/requester.py +++ b/jetstream/tools/requester.py @@ -26,11 +26,7 @@ _SERVER = flags.DEFINE_string("server", "0.0.0.0", "server address") _PORT = flags.DEFINE_string("port", "9000", "port to ping") -_SESSION_CACHE = flags.DEFINE_string( - "session_cache", "", "Location of any pre-cached results" -) -_TEXT = flags.DEFINE_string("text", "Today is a good day", "The message") -_PRIORITY = flags.DEFINE_integer("priority", 0, "Message priority") +_TEXT = flags.DEFINE_string("text", "My dog is cute", "The message") _MAX_TOKENS = flags.DEFINE_integer( "max_tokens", 3, "Maximum number of output/decode tokens of a sequence" ) @@ -82,20 +78,16 @@ def main(argv: Sequence[str]) -> None: vocab = load_vocab(_TOKENIZER.value) token_ids = vocab.tokenizer.encode(_TEXT.value) request = jetstream_pb2.DecodeRequest( - session_cache=_SESSION_CACHE.value, token_content=jetstream_pb2.DecodeRequest.TokenContent( token_ids=token_ids ), - priority=_PRIORITY.value, max_tokens=_MAX_TOKENS.value, ) else: request = jetstream_pb2.DecodeRequest( - session_cache=_SESSION_CACHE.value, text_content=jetstream_pb2.DecodeRequest.TextContent( text=_TEXT.value ), - priority=_PRIORITY.value, max_tokens=_MAX_TOKENS.value, ) return _GetResponseAsync(stub, request) diff --git a/license_preamble.txt b/license_preamble.txt new file mode 100644 index 00000000..6d5e14bc --- /dev/null +++ b/license_preamble.txt @@ -0,0 +1,13 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the 
"License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/pylintrc b/pylintrc index 0df2cb47..52bc64ba 100644 --- a/pylintrc +++ b/pylintrc @@ -9,7 +9,7 @@ [MAIN] # Files or directories to be skipped. They should be base names, not paths. -ignore=third_party +ignore=external_tokenizers # Files or directories matching the regex patterns are skipped. The regex # matches against base names, not paths. diff --git a/requirements.in b/requirements.in deleted file mode 100644 index eba423d4..00000000 --- a/requirements.in +++ /dev/null @@ -1,15 +0,0 @@ -absl-py -coverage -flax -grpcio -jax -jaxlib -numpy -portpicker -prometheus-client -pytest -seqio -tiktoken -blobfile -parameterized -shortuuid \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 057b4f8b..86841a57 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,317 +1,19 @@ -# -# This file is autogenerated by pip-compile with Python 3.10 -# by the following command: -# -# pip-compile requirements.in -# -absl-py==1.4.0 - # via - # -r requirements.in - # array-record - # chex - # clu - # etils - # ml-collections - # optax - # orbax-checkpoint - # seqio - # tensorboard - # tensorflow - # tensorflow-metadata - # tfds-nightly -array-record==0.5.0 - # via tfds-nightly -astunparse==1.6.3 - # via tensorflow -blobfile==2.1.1 - # via -r requirements.in -cachetools==5.3.2 - # via google-auth -certifi==2024.2.2 - # via requests -charset-normalizer==3.3.2 - # via requests -chex==0.1.7 - # via optax -click==8.1.7 - # via tfds-nightly -clu==0.0.10 - # via seqio -contextlib2==21.6.0 - # via ml-collections -coverage==7.4.4 - # via -r requirements.in -dm-tree==0.1.8 - # via - # chex - # tfds-nightly -docstring-parser==0.15 - # via pyglove -editdistance==0.6.2 - # via seqio -etils[array-types,enp,epath,epy,etqdm,etree]==1.6.0 - # via - # array-record - # clu - # orbax-checkpoint - # tfds-nightly -exceptiongroup==1.2.0 - # via pytest -filelock==3.14.0 - # via blobfile -flatbuffers==23.5.26 - # via tensorflow -flax==0.8.0 - # via - # -r requirements.in - # clu -fsspec==2023.12.2 - # via etils -gast==0.4.0 - # via tensorflow -google-auth==2.27.0 - # via - # google-auth-oauthlib - # tensorboard -google-auth-oauthlib==1.0.0 - # via tensorboard -google-pasta==0.2.0 - # via tensorflow -googleapis-common-protos==1.62.0 - # via tensorflow-metadata -grpcio==1.60.1 - # via - # -r requirements.in - # tensorboard - # tensorflow -h5py==3.10.0 - # via tensorflow -idna==3.7 - # via requests -importlib-resources==6.1.1 - # via etils -iniconfig==2.0.0 - # via pytest -jax==0.4.23 - # via - # -r requirements.in - # chex - # clu - # flax - # optax - # orbax-checkpoint - # seqio -jaxlib==0.4.23 - # via - # -r requirements.in - # chex - # clu - # optax - # orbax-checkpoint - # seqio -keras==2.13.1 - # via tensorflow -libclang==16.0.6 - # via tensorflow -lxml==4.9.4 - # via blobfile -markdown==3.5.2 - # via tensorboard -markdown-it-py==3.0.0 - # via rich -markupsafe==2.1.5 - # via werkzeug -mdurl==0.1.2 - # via markdown-it-py -ml-collections==0.1.1 - # via clu -ml-dtypes==0.3.2 - # via - # jax - # jaxlib 
- # tensorstore -msgpack==1.0.7 - # via - # flax - # orbax-checkpoint -nest-asyncio==1.6.0 - # via orbax-checkpoint -numpy==1.23.1 - # via - # -r requirements.in - # chex - # clu - # etils - # flax - # h5py - # jax - # jaxlib - # ml-dtypes - # opt-einsum - # optax - # orbax-checkpoint - # scipy - # seqio - # tensorboard - # tensorflow - # tensorflow-hub - # tensorstore - # tfds-nightly -oauthlib==3.2.2 - # via requests-oauthlib -opt-einsum==3.3.0 - # via - # jax - # tensorflow -optax==0.1.8 - # via flax -orbax-checkpoint==0.5.2 - # via flax -packaging==23.2 - # via - # clu - # pytest - # seqio - # tensorflow -parameterized==0.9.0 - # via -r requirements.in -pluggy==1.4.0 - # via pytest -portpicker==1.6.0 - # via -r requirements.in -prometheus-client==0.20.0 - # via -r requirements.in -promise==2.3 - # via tfds-nightly -protobuf==3.20.3 - # via - # googleapis-common-protos - # orbax-checkpoint - # seqio - # tensorboard - # tensorflow - # tensorflow-hub - # tensorflow-metadata - # tfds-nightly -psutil==5.9.8 - # via - # portpicker - # tfds-nightly -pyasn1==0.5.1 - # via - # pyasn1-modules - # rsa -pyasn1-modules==0.3.0 - # via google-auth -pycryptodomex==3.20.0 - # via blobfile -pyglove==0.4.4 - # via seqio -pygments==2.17.2 - # via rich -pytest==8.1.1 - # via -r requirements.in -pyyaml==6.0.1 - # via - # flax - # ml-collections - # orbax-checkpoint -regex==2024.4.28 - # via tiktoken -requests==2.32.0 - # via - # requests-oauthlib - # tensorboard - # tfds-nightly - # tiktoken -requests-oauthlib==1.3.1 - # via google-auth-oauthlib -rich==13.7.0 - # via flax -rsa==4.9 - # via google-auth -scipy==1.12.0 - # via - # jax - # jaxlib -sentencepiece==0.1.99 - # via seqio -seqio==0.0.19 - # via -r requirements.in -shortuuid==1.0.13 - # via -r requirements.in -six==1.16.0 - # via - # astunparse - # google-pasta - # ml-collections - # promise - # tensorflow -tensorboard==2.13.0 - # via tensorflow -tensorboard-data-server==0.7.2 - # via tensorboard -tensorflow==2.13.1 - # via tensorflow-text -tensorflow-estimator==2.13.0 - # via tensorflow -tensorflow-hub==0.16.1 - # via tensorflow-text -tensorflow-io-gcs-filesystem==0.35.0 - # via tensorflow -tensorflow-metadata==1.14.0 - # via tfds-nightly -tensorflow-text==2.13.0 - # via seqio -tensorstore==0.1.52 - # via - # flax - # orbax-checkpoint -termcolor==2.4.0 - # via - # tensorflow - # tfds-nightly -tf-keras==2.15.0 - # via tensorflow-hub -tfds-nightly==4.9.2.dev202308090034 - # via seqio -tiktoken==0.6.0 - # via -r requirements.in -toml==0.10.2 - # via tfds-nightly -tomli==2.0.1 - # via pytest -toolz==0.12.1 - # via chex -tqdm==4.66.3 - # via - # etils - # tfds-nightly -typing-extensions==4.5.0 - # via - # chex - # clu - # etils - # flax - # orbax-checkpoint - # tensorflow -urllib3==2.2.0 - # via - # blobfile - # requests -werkzeug==3.0.1 - # via tensorboard -wheel==0.42.0 - # via - # astunparse - # tensorboard -wrapt==1.16.0 - # via - # clu - # tensorflow - # tfds-nightly -zipp==3.17.0 - # via etils - -# The following packages are considered to be unsafe in a requirements file: -# setuptools +absl-py +coverage +flax +grpcio +jax +jaxlib +numpy +portpicker +prometheus-client +pytest +seqio +tiktoken +blobfile +parameterized +shortuuid +fastapi +uvicorn +# For profiling +tensorboard-plugin-profile \ No newline at end of file diff --git a/setup.py b/setup.py index c4efd21e..55b91c3b 100644 --- a/setup.py +++ b/setup.py @@ -39,5 +39,5 @@ def parse_requirements(filename): "Operating System :: OS Independent", ], python_requires=">=3.10", - 
install_requires=parse_requirements("requirements.in"), + install_requires=parse_requirements("requirements.txt"), )
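
The new `test_sampling_utils.py` above implies the public surface of `jetstream.engine.sampling_utils`: a single `sampling(logits, rng, algorithm, ...)` helper with `greedy`, `weighted`, `nucleus` (via `nucleus_topp`) and `topk` (via `topk`) modes, raising `ValueError` for out-of-range parameters or unknown algorithms. A minimal usage sketch, assuming only what those tests exercise:

```
# Sketch of the sampling helper exercised by test_sampling_utils.py; the
# signature (logits, rng, algorithm, nucleus_topp=..., topk=...) mirrors
# the calls made in those tests.
import jax
import jax.numpy as jnp

from jetstream.engine import sampling_utils

rng = jax.random.PRNGKey(0)
logits = jnp.array([[-0.5, 1.2, 0.8], [-1.0, 0.3, 0.7]])  # [batch, vocab]

greedy = sampling_utils.sampling(logits, rng, "greedy")      # argmax per row
weighted = sampling_utils.sampling(logits, rng, "weighted")  # softmax sampling
nucleus = sampling_utils.sampling(logits, rng, "nucleus", nucleus_topp=0.8)
topk = sampling_utils.sampling(logits, rng, "topk", topk=2)
print(greedy, weighted, nucleus, topk)
```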
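
The HTTP entrypoint tests likewise document how a client is expected to talk to the new FastAPI server. The sketch below assumes a server started the same way as in `test_api_server.py` (`python -m jetstream.entrypoints.http.api_server --config=InterleavedCPUTestServer`, listening on `localhost:8080`); the endpoints and payload shape are taken directly from those tests.

```
# Minimal client sketch for the new HTTP entrypoint; assumes a server
# started as in test_api_server.py is listening on localhost:8080.
import requests

BASE_URL = "http://localhost:8080"  # assumption: the test server's address

# Liveness check, mirroring test_health_endpoint.
health = requests.get(BASE_URL + "/v1/health", timeout=5)
print(health.json())  # expected to contain {"is_live": "True"}

# Streamed generation, mirroring test_generate_endpoint.
payload = {
    "max_tokens": 10,
    "text_content": {"text": "translate this to french: hello world"},
}
with requests.post(
    BASE_URL + "/v1/generate", json=payload, stream=True, timeout=5
) as response:
    response.raise_for_status()
    for chunk in response.iter_content(chunk_size=None):
        if chunk:
            print(chunk.decode("utf-8"))
```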
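
Similarly, the `requester.py` and `load_tester.py` changes show the slimmed-down `DecodeRequest` (no `session_cache`, no `priority`). A minimal gRPC client sketch under those assumptions, additionally assuming a JetStream server on `localhost:9000`, generated proto modules under `jetstream.core.proto` (the path referenced by the coverage config), and a server-streaming `Decode` RPC as used by the tools above:

```
# Sketch of a Decode call after session_cache/priority were dropped from
# DecodeRequest; server address, proto module path, and streaming iteration
# are assumptions based on the tools changed in this patch.
import grpc

from jetstream.core.proto import jetstream_pb2
from jetstream.core.proto import jetstream_pb2_grpc

channel = grpc.insecure_channel("localhost:9000")
stub = jetstream_pb2_grpc.OrchestratorStub(channel)

request = jetstream_pb2.DecodeRequest(
    text_content=jetstream_pb2.DecodeRequest.TextContent(
        text="My dog is cute"  # default prompt from the updated requester.py
    ),
    max_tokens=3,
)

# Decode is assumed to stream responses; iterate the returned call object.
for response in stub.Decode(request):
    print(response)
```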