This tutorial covers creating and deploying a Gradio chat interface connected to a Llama 8B language model using Cerebrium’s custom ASGI runtime. The architecture runs the frontend on CPU instances while the model runs separately on GPU instances for optimal resource utilization. You can find the full codebase for deploying your Gradio frontend here.

Architecture Overview

The application consists of two main components:
  1. A frontend interface running on CPU instances using FastAPI and Gradio.
  2. A separate Llama model endpoint running on GPU instances. (For a comprehensive example of deploying Llama 8B with TensorRT, see here.)
This separation lets you:
  • Keep the frontend always available while minimizing costs (CPU-only).
  • Scale the GPU-intensive model independently based on demand.
  • Optimize resource allocation for different components.

Prerequisites

Before starting, you’ll need:
  • A Cerebrium account (sign up here).
  • The Cerebrium CLI installed: pip install --upgrade cerebrium.
  • A Llama model endpoint (or other LLM API endpoint).

Basic Setup

First, create a new directory for your project and initialize it:
cerebrium init 2-gradio-interface
Add the following configuration to cerebrium.toml:
[cerebrium.deployment]
name = "2-gradio-interface"
python_version = "3.12"
disable_auth = true
include = ['./*', 'main.py', 'cerebrium.toml']
exclude = ['.*']

[cerebrium.runtime.custom]
entrypoint = ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8080"]
port = 8080
healthcheck_endpoint = "/health"

[cerebrium.hardware]
cpu = 2
memory = 4.0
compute = "CPU"

[cerebrium.scaling]
min_replicas = 0
max_replicas = 2
cooldown = 30
replica_concurrency = 10

[cerebrium.dependencies.pip]
gradio = "latest"
fastapi = "latest"
requests = "latest"
httpx = "latest"
uvicorn = "latest"
starlette = "latest"
This configuration:
  • Disables default JWT authentication, making the Gradio interface publicly accessible.
  • Sets the ASGI server entrypoint to Uvicorn.
  • Sets the default port to 8080.
  • Sets the health endpoint to /health for availability checks.
  • Configures hardware settings for the CPU instance.
  • Defines scaling with min/max replicas, cooldown, and concurrency (10 requests per replica).
  • Specifies required dependencies: Gradio, FastAPI, Requests, HTTPX, Uvicorn, and Starlette.
Set up the main entrypoint file (main.py). Start by creating the FastAPI application:
# at the top of your main.py file
import os
import requests
from typing import Optional, List

import httpx
from fastapi import FastAPI, Request
from starlette.responses import Response as StarletteResponse

app = FastAPI()

# Get the Gradio app URL (when running on Cerebrium)
GRADIO_HOST = os.getenv("GRADIO_HOST", "127.0.0.1")
GRADIO_PORT = int(os.getenv("GRADIO_PORT", "7860"))
GRADIO_URL = os.getenv("GRADIO_SERVER_URL", f"http://{GRADIO_HOST}:{GRADIO_PORT}")

# Health check endpoint
@app.get("/health")
async def health_check():
    return {"status": "healthy"}


@app.route("/{path:path}", include_in_schema=False, methods=["GET", "POST"])
async def gradio(request: Request):
    print(f"Forwarding request path: {request.url.path}")

    headers = dict(request.headers)

    # Construct the full URL to Gradio, preserving the original path
    target_url = f"{GRADIO_URL}{request.url.path}"

    async with httpx.AsyncClient() as client:
        response = await client.request(
            request.method,
            target_url,
            headers=headers,
            content=await request.body(),
            params=request.query_params,
        )

        content = await response.aread()
        response_headers = dict(response.headers)
        return StarletteResponse(
            content=content,
            status_code=response.status_code,
            headers=response_headers,
        )
The above code:
  • Initializes a FastAPI application that forwards requests to the Gradio app running as a subprocess on a different port.
  • Sets up a health check endpoint at /health.
  • Creates a catchall proxy that routes all requests to Gradio, including headers.
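If you want to sanity-check this part before adding Gradio, a quick local test of the health endpoint could look like the following (a minimal sketch, assuming the code so far is saved as main.py; the test file name is arbitrary):
# test_health.py — quick local check, not part of the deployment
from fastapi.testclient import TestClient

from main import app  # the FastAPI app defined above


def test_health():
    client = TestClient(app)  # startup events only run when used as a context manager
    response = client.get("/health")
    assert response.status_code == 200
    assert response.json() == {"status": "healthy"}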
Next, set up the Gradio application. Add the following code to main.py:
import multiprocessing
import os
import sys
import time

import gradio as gr

# Global variable for the Gradio server
gradio_server = None

# Configure the Llama endpoint URL
LLAMA_ENDPOINT = os.getenv("LLAMA_ENDPOINT", "<YOUR_MODEL_API_ENDPOINT>")  # Update with your endpoint


class GradioServer:
    def __init__(self):
        self.host = GRADIO_HOST
        self.port = GRADIO_PORT
        self.process: Optional[multiprocessing.Process] = None
        self.url = GRADIO_URL

    async def chat_with_llama(self, message: str, history: List[dict]) -> str:
        """Make a request to the Llama endpoint"""
        # With type="messages", Gradio passes history as OpenAI-style dicts
        # ({"role": ..., "content": ...}), so strip any extra keys and append
        # the new user message
        messages = [{"role": m["role"], "content": m["content"]} for m in history]
        messages.append({"role": "user", "content": message})

        async with httpx.AsyncClient() as client:
            try:
                response = await client.post(
                    f"{LLAMA_ENDPOINT}/v1/chat/completions",
                    json={
                        "messages": messages,
                        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
                        "stream": False,
                        "temperature": 0.7,
                        "top_p": 0.95
                    },
                    timeout=30.0
                )

                if response.status_code == 200:
                    response_data = response.json()
                    # Chat completions return the reply under choices[0].message.content
                    return response_data['choices'][0]['message']['content']
                else:
                    return f"Error: Received status code {response.status_code} from Llama endpoint"
            except Exception as e:
                return f"Error communicating with Llama endpoint: {str(e)}"

    def run_server(self):
        interface = gr.ChatInterface(
            fn=self.chat_with_llama,
            type="messages",
            title="Chat with Llama",
            description="This is a chat interface powered by Llama 3.1 8B Instruct",
            examples=[
                ["What is the capital of France?"],
                ["Explain quantum computing in simple terms"],
                ["Write a short poem about technology"]
            ],
        )
        interface.launch(
            server_name=self.host,
            server_port=self.port,
            root_path=f"https://api.aws.us-east-1.cerebrium.ai/v4/{os.getenv('PROJECT_ID')}/{os.getenv('APP_NAME')}/",
            quiet=True
        )

    def start(self):
        print(f"Starting Gradio server at {self.url} port {self.port}")

        # Start Gradio in a separate process
        self.process = multiprocessing.Process(target=self.run_server)
        self.process.start()

        # Wait for Gradio to become ready
        max_retries = 30
        retry_delay = 1.0

        for _ in range(max_retries):
            try:
                response = requests.get(f"{self.url}/")
                if response.status_code == 200:
                    print(f"Gradio server is ready at {self.url}")
                    return True
            except requests.exceptions.ConnectionError:
                time.sleep(retry_delay)

        print("Failed to start Gradio server")
        self.stop()
        return False

    def stop(self):
        if self.process:
            self.process.terminate()
            self.process.join()
            self.process = None

@app.on_event("startup")
async def startup_event():
    global gradio_server
    if not os.getenv("GRADIO_SERVER_URL"):  # Only start local server if no external URL provided
        gradio_server = GradioServer()
        if not gradio_server.start():
            sys.exit(1)


@app.on_event("shutdown")
async def shutdown_event():
    global gradio_server
    if gradio_server:
        gradio_server.stop()
The code above defines:
  • GradioServer: manages the Gradio app's lifecycle and its communication with the Llama model endpoint
  • chat_with_llama: sends a message to the Llama model and returns the response
  • run_server: creates a Gradio chat interface
  • start: starts the Gradio server in a separate process
  • stop: stops the Gradio server
  • on_event startup/shutdown handlers: start and stop the Gradio server, respectively
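Before moving on, you can verify that your model endpoint accepts the payload shape chat_with_llama sends. The standalone script below is a minimal sketch: the placeholder endpoint and model name mirror the code above, and any authentication headers depend on how your own model endpoint is deployed.
# check_endpoint.py — standalone sketch for verifying the model endpoint
import asyncio
import os

import httpx

LLAMA_ENDPOINT = os.getenv("LLAMA_ENDPOINT", "<YOUR_MODEL_API_ENDPOINT>")


async def main():
    async with httpx.AsyncClient() as client:
        response = await client.post(
            f"{LLAMA_ENDPOINT}/v1/chat/completions",
            json={
                "messages": [{"role": "user", "content": "Say hello in one sentence."}],
                "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
                "stream": False,
            },
            timeout=30.0,
        )
        response.raise_for_status()
        # OpenAI-compatible servers return the reply under choices[0].message.content
        print(response.json()["choices"][0]["message"]["content"])


if __name__ == "__main__":
    asyncio.run(main())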
The final main.py file:
import multiprocessing
import os
import sys
import time
from typing import Optional, List

import gradio as gr
import httpx
import requests
from fastapi import FastAPI, Request
from starlette.responses import Response as StarletteResponse

# Initialize FastAPI
app = FastAPI()

# Server configuration
gradio_server = None
GRADIO_HOST = os.getenv("GRADIO_HOST", "127.0.0.1")
GRADIO_PORT = int(os.getenv("GRADIO_PORT", "7860"))
GRADIO_URL = os.getenv("GRADIO_SERVER_URL", f"http://{GRADIO_HOST}:{GRADIO_PORT}")
LLAMA_ENDPOINT = os.getenv("LLAMA_ENDPOINT", "<YOUR_MODEL_API_ENDPOINT>")


class GradioServer:
    def __init__(self):
        self.host = GRADIO_HOST
        self.port = GRADIO_PORT
        self.process: Optional[multiprocessing.Process] = None
        self.url = GRADIO_URL

    async def chat_with_llama(self, message: str, history: List[dict]) -> str:
        """Make a request to the Llama endpoint"""
        # With type="messages", Gradio passes history as OpenAI-style dicts
        # ({"role": ..., "content": ...}), so strip any extra keys and append
        # the new user message
        messages = [{"role": m["role"], "content": m["content"]} for m in history]
        messages.append({"role": "user", "content": message})

        async with httpx.AsyncClient() as client:
            try:
                response = await client.post(
                    f"{LLAMA_ENDPOINT}/v1/chat/completions",
                    json={
                        "messages": messages,
                        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
                        "stream": False,
                        "temperature": 0.7,
                        "top_p": 0.95
                    },
                    timeout=30.0
                )

                if response.status_code == 200:
                    response_data = response.json()
                    # Chat completions return the reply under choices[0].message.content
                    return response_data['choices'][0]['message']['content']
                else:
                    return f"Error: Received status code {response.status_code} from Llama endpoint"
            except Exception as e:
                return f"Error communicating with Llama endpoint: {str(e)}"

    def run_server(self):
        interface = gr.ChatInterface(
            fn=self.chat_with_llama,
            type="messages",
            title="Chat with Llama",
            description="This is a chat interface powered by Llama 3.1 8B Instruct",
            examples=[
                ["What is the capital of France?"],
                ["Explain quantum computing in simple terms"],
                ["Write a short poem about technology"]
            ],
        )
        interface.launch(
            server_name=self.host,
            server_port=self.port,
            root_path=f"https://api.aws.us-east-1.cerebrium.ai/v4/{os.getenv('PROJECT_ID')}/{os.getenv('APP_NAME')}/",
            quiet=True
        )

    def start(self):
        print(f"Starting Gradio server at {self.url} port {self.port}")

        # Start Gradio in a separate process
        self.process = multiprocessing.Process(target=self.run_server)
        self.process.start()

        # Wait for Gradio to become ready
        max_retries = 30
        retry_delay = 1.0

        for _ in range(max_retries):
            try:
                response = requests.get(f"{self.url}/")
                if response.status_code == 200:
                    print(f"Gradio server is ready at {self.url}")
                    return True
            except requests.exceptions.ConnectionError:
                time.sleep(retry_delay)

        print("Failed to start Gradio server")
        self.stop()
        return False

    def stop(self):
        if self.process:
            self.process.terminate()
            self.process.join()
            self.process = None


@app.get("/health")
async def health_check():
    return {"status": "healthy"}


# Catchall proxy endpoint for Gradio
@app.route("/{path:path}", include_in_schema=False, methods=["GET", "POST"])
async def gradio(request: Request):
    print(f"Forwarding request path: {request.url.path}")

    headers = dict(request.headers)

    # Construct the full URL to Gradio, preserving the original path
    target_url = f"{GRADIO_URL}{request.url.path}"

    async with httpx.AsyncClient() as client:
        response = await client.request(
            request.method,
            target_url,
            headers=headers,
            content=await request.body(),
            params=request.query_params,
        )

        content = await response.aread()
        response_headers = dict(response.headers)
        return StarletteResponse(
            content=content,
            status_code=response.status_code,
            headers=response_headers,
        )


@app.on_event("startup")
async def startup_event():
    global gradio_server
    if not os.getenv("GRADIO_SERVER_URL"):  # Only start local server if no external URL provided
        gradio_server = GradioServer()
        if not gradio_server.start():
            sys.exit(1)


@app.on_event("shutdown")
async def shutdown_event():
    global gradio_server
    if gradio_server:
        gradio_server.stop()

Deploy

Deploy the app:
cerebrium deploy -y
Once deployed, navigate to the following URL in your browser:
https://api.aws.us-east-1.cerebrium.ai/v4/p-<YOUR_PROJECT_ID>/2-gradio-interface/
The Gradio chat interface appears at this URL.
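You can also confirm the deployment from a script by hitting the /health endpoint defined earlier. The snippet below is a small sketch; substitute your own project ID, and note that with min_replicas = 0 the first request may wait for a cold start.
# check_deployment.py — sketch for verifying the deployed app
import requests

BASE_URL = "https://api.aws.us-east-1.cerebrium.ai/v4/p-<YOUR_PROJECT_ID>/2-gradio-interface"

response = requests.get(f"{BASE_URL}/health", timeout=120)
print(response.status_code, response.json())  # expect: 200 {'status': 'healthy'}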

Conclusion

This architecture provides a scalable chat app using the ASGI custom runtime. The frontend/backend separation improves performance and cost management while maintaining scaling flexibility. Share feedback, challenges, or Gradio apps in the Discord community.