This example is only compatible with CLI v1.20 and later. If you are using an older version of the CLI, run pip install --upgrade cerebrium to upgrade to the latest version.
This tutorial covers deploying Mistral 7B using the popular vLLM inference framework.
To see the final implementation, you can view it here.
Basic Setup
Developing on Cerebrium is similar to developing on a virtual machine or in Google Colab. Install the Cerebrium package and log in before proceeding; see the installation docs for details.
First, create your project:
cerebrium init 1-faster-inference-with-vllm
Add these Python packages to the [cerebrium.dependencies.pip] section in your cerebrium.toml file:
[cerebrium.dependencies.pip]
sentencepiece = "latest"
torch = ">=2.0.0"
vllm = "latest"
transformers = ">=4.35.0"
accelerate = "latest"
xformers = "latest"
Create a main.py file. This implementation fits in a single file. Start by defining the request object:
from typing import Optional
from pydantic import BaseModel

class Item(BaseModel):
    prompt: str
    temperature: Optional[float] = 0.8
    top_p: Optional[float] = 0.75
    top_k: Optional[int] = 40
    max_tokens: Optional[int] = 256
    frequency_penalty: Optional[float] = 1.0
Pydantic handles data validation. The prompt parameter is required; the others are optional and fall back to the default values shown above. A request that omits prompt automatically returns a validation error.
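As a quick local illustration (not part of the deployment), constructing an Item without a prompt raises a Pydantic ValidationError, while the optional fields fall back to their defaults:

from pydantic import ValidationError

# Local illustration only: prompt is required, so omitting it fails validation.
try:
    Item(temperature=0.9)
except ValidationError as err:
    print(err)  # reports that the "prompt" field is required

# Supplying only a prompt works; the optional parameters keep their defaults.
item = Item(prompt="Hello")
print(item.temperature, item.top_k, item.max_tokens)  # 0.8 40 256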
vLLM Implementation
Model Setup
import os
from vllm import LLM, SamplingParams
from huggingface_hub import login
# Your huggingface token (HF_AUTH_TOKEN) should be stored in your project secrets on your dashboard
login(token=os.environ.get("HF_AUTH_TOKEN"))
# Initialize model with optimized settings
# max_model_len caps the context length; gpu_memory_utilization sets how much GPU memory vLLM may reserve for weights and KV cache
llm = LLM(
    model="mistralai/Mistral-7B-Instruct-v0.1",
    dtype="bfloat16",
    max_model_len=20000,
    gpu_memory_utilization=0.9,
)
def predict(prompt, temperature=0.8, top_p=0.75, top_k=40, max_tokens=256, frequency_penalty=1):
    item = Item(
        prompt=prompt,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        max_tokens=max_tokens,
        frequency_penalty=frequency_penalty
    )
    sampling_params = SamplingParams(
        temperature=item.temperature,
        top_p=item.top_p,
        top_k=item.top_k,
        max_tokens=item.max_tokens,
        frequency_penalty=item.frequency_penalty
    )
    outputs = llm.generate([item.prompt], sampling_params)

    generated_text = []
    for output in outputs:
        generated_text.append(output.outputs[0].text)

    return {"result": generated_text}
The model loads outside the predict function so it is initialized only once, at startup, rather than on every request. The predict function validates the incoming parameters, passes them to the model, and returns the generated output.
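If you want a quick sanity check before deploying, a minimal local smoke test (assuming your machine has a GPU with enough memory for the 7B weights) could be added to the bottom of main.py:

if __name__ == "__main__":
    # Hypothetical local smoke test; this block does not run when the module is imported by the platform.
    print(predict("What is the capital city of France?", max_tokens=32))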
Deploy
Configure your compute and environment settings in cerebrium.toml:
[cerebrium.build]
predict_data = "{\"prompt\": \"Here is some example predict data for your config.yaml which will be used to test your predict function on build.\"}"
hide_public_endpoint = false
disable_animation = false
disable_build_logs = false
disable_syntax_check = false
disable_predict = false
log_level = "INFO"
disable_confirmation = false
[cerebrium.deployment]
name = "1-faster-inference-with-vllm"
python_version = "3.11"
include = ["./*", "main.py", "cerebrium.toml"]
exclude = ["./example_exclude"]
docker_base_image_url = "nvidia/cuda:12.1.1-runtime-ubuntu22.04"
[cerebrium.hardware]
region = "us-east-1"
provider = "aws"
compute = "AMPERE_A10"
cpu = 2
memory = 16.0
gpu_count = 1
[cerebrium.scaling]
min_replicas = 0
max_replicas = 5
cooldown = 60
[cerebrium.dependencies.pip]
huggingface-hub = "latest"
sentencepiece = "latest"
torch = ">=2.0.0"
vllm = "latest"
transformers = ">=4.35.0"
accelerate = "latest"
xformers = "latest"
[cerebrium.dependencies.conda]
[cerebrium.dependencies.apt]
ffmpeg = "latest"
Deploy the model using this command:
cerebrium deploy
After deployment, make this request:
curl --location --request POST 'https://api.aws.us-east-1.cerebrium.ai/v4/p-<YOUR PROJECT ID>/1-faster-inference-with-vllm/predict' \
--header 'Authorization: Bearer <YOUR TOKEN HERE>' \
--header 'Content-Type: application/json' \
--data-raw '{
"prompt": "What is the capital city of France?"
}'
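Equivalently, you can call the endpoint from Python. Here is a minimal sketch using the requests library (substitute your own project ID and inference token):

import requests

# Replace the placeholders with your own project ID and inference token.
url = "https://api.aws.us-east-1.cerebrium.ai/v4/p-<YOUR PROJECT ID>/1-faster-inference-with-vllm/predict"
headers = {
    "Authorization": "Bearer <YOUR TOKEN HERE>",
    "Content-Type": "application/json",
}
payload = {"prompt": "What is the capital city of France?"}

response = requests.post(url, headers=headers, json=payload)
print(response.json())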
The endpoint returns results in this format:
{
  "run_id": "nZL6mD8q66u4lHTXcqmPCc6pxxFwn95IfqQvEix0gHaOH4gkHUdz1w==",
  "message": "Finished inference request with run_id: `nZL6mD8q66u4lHTXcqmPCc6pxxFwn95IfqQvEix0gHaOH4gkHUdz1w==`",
  "result": {
    "result": ["\nA: Paris"]
  },
  "status_code": 200,
  "run_time_ms": 151.24988555908203
}