Deploy an ML App
What you'll build
You will take the MNIST CNN from Lesson 8, wrap it in a FastAPI endpoint that accepts an image and returns a predicted digit, and add basic logging and health-check routes. This is a production-ready pattern that transfers to serving any ML model.
Concepts
Packaging a model
Before serving, you need to serialise the model so it can be loaded without the training code.
import torch
import torch.nn as nn
# Assume `model` is the trained SmallCNN instance from Lesson 8
# Save weights
torch.save(model.state_dict(), "model/mnist_cnn.pth")
# Better: save the whole traced model for deployment
model.eval()
example_input = torch.randn(1, 1, 28, 28)
traced = torch.jit.trace(model, example_input)
traced.save("model/mnist_cnn_traced.pt")
torch.jit.trace produces a standalone file that does not require you to define the model class at load time, which makes it ideal for deployment.
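For example, loading the traced model in a fresh process needs only the file path; a minimal sketch:

import torch

# No SmallCNN class definition required here
model = torch.jit.load("model/mnist_cnn_traced.pt")
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 1, 28, 28))  # dummy input with the MNIST shape
print(logits.shape)  # (1, 10): one logit per digit class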
For scikit-learn models, use joblib:
import joblib
from sklearn.ensemble import RandomForestClassifier
clf = RandomForestClassifier()
clf.fit(X_train, y_train)  # assumes X_train, y_train are already defined
joblib.dump(clf, "model/classifier.joblib")
# Load
loaded_clf = joblib.load("model/classifier.joblib")
preds = loaded_clf.predict(X_test)
FastAPI serving
FastAPI is a Python web framework that turns Python functions into HTTP endpoints with automatic input validation and documentation.
# app.py
from fastapi import FastAPI, UploadFile, File, HTTPException
from pydantic import BaseModel
import torch
import numpy as np
from PIL import Image
import io
app = FastAPI(title="Digit Recogniser API")
# --- Load model at startup ---
model = torch.jit.load("model/mnist_cnn_traced.pt")
model.eval()
class PredictionResponse(BaseModel):
    digit: int
    confidence: float

@app.get("/health")
def health():
    return {"status": "ok"}

@app.post("/predict", response_model=PredictionResponse)
async def predict(file: UploadFile = File(...)):
    if file.content_type not in ("image/png", "image/jpeg"):
        raise HTTPException(400, "Upload a PNG or JPEG image.")
    contents = await file.read()
    img = Image.open(io.BytesIO(contents)).convert("L").resize((28, 28))
    arr = np.array(img, dtype=np.float32) / 255.0
    arr = (arr - 0.1307) / 0.3081  # MNIST normalisation
    tensor = torch.tensor(arr).unsqueeze(0).unsqueeze(0)  # (1, 1, 28, 28)
    with torch.no_grad():
        logits = model(tensor)
        probs = torch.softmax(logits, dim=1)
    digit = probs.argmax(dim=1).item()
    conf = probs.max().item()
    return PredictionResponse(digit=digit, confidence=round(conf, 4))
Run with: uvicorn app:app --host 0.0.0.0 --port 8000
FastAPI generates interactive docs automatically at http://localhost:8000/docs. You can test the endpoint directly from the browser.
A minimal web demo with HTML and fetch
A simple frontend calls the API and displays the result.
# Add this route to app.py to serve a one-page demo
from fastapi.responses import HTMLResponse
@app.get("/demo", response_class=HTMLResponse)
def demo():
return """
<!DOCTYPE html>
<html>
<head><title>Digit Recogniser</title></head>
<body>
<h2>Digit Recogniser Demo</h2>
<input type="file" id="imageInput" accept="image/*">
<button onclick="predict()">Predict</button>
<p id="result"></p>
<script>
async function predict() {
const file = document.getElementById('imageInput').files[0];
if (!file) { alert('Choose an image first'); return; }
const form = new FormData();
form.append('file', file);
const resp = await fetch('/predict', {method: 'POST', body: form});
const data = await resp.json();
document.getElementById('result').textContent =
'Digit: ' + data.digit + ' Confidence: ' + (data.confidence * 100).toFixed(1) + '%';
}
</script>
</body>
</html>
"""
This gives you a browser-testable demo at http://localhost:8000/demo with no frontend build tool needed.
What to watch in production
Cost
Track how many inferences you are making per day. For LLM APIs, each token costs money. For self-hosted models, track compute hours. Set budget alerts early.
# Simple in-process request counter (a dict guarded by a lock)
# In production you'd use Prometheus / Grafana
import threading

counters = {"requests": 0, "errors": 0}
lock = threading.Lock()

@app.middleware("http")
async def count_requests(request, call_next):
    with lock:
        counters["requests"] += 1
    response = await call_next(request)
    if response.status_code >= 400:
        with lock:
            counters["errors"] += 1
    return response

@app.get("/metrics")
def metrics():
    return counters
Latency
Measure inference time and expose it. Users typically expect responses in under 200 ms for interactive applications.
import time
from fastapi import Request

@app.middleware("http")
async def add_latency_header(request: Request, call_next):
    start = time.perf_counter()
    response = await call_next(request)
    duration_ms = (time.perf_counter() - start) * 1000
    response.headers["X-Inference-Time-ms"] = f"{duration_ms:.1f}"
    return response
Data drift
A model trained on data from 2023 may degrade on 2025 data if the real-world distribution has shifted. Monitor the distribution of your inputs over time.
import numpy as np

class DriftMonitor:
    def __init__(self, reference_mean, reference_std):
        self.ref_mean = reference_mean
        self.ref_std = reference_std

    def check(self, recent_predictions, threshold_z=3.0):
        mean_shift = abs(recent_predictions.mean() - self.ref_mean) / self.ref_std
        if mean_shift > threshold_z:
            print(f"ALERT: Prediction mean shifted by {mean_shift:.1f} std devs!")
        return mean_shift

# Usage: compare the recent prediction distribution to the training distribution
monitor = DriftMonitor(reference_mean=4.5, reference_std=2.8)
recent = np.array([0, 0, 1, 0, 0, 0, 1, 0])  # lots of zeros, suspicious
monitor.check(recent, threshold_z=1.0)  # mean shift of ~1.5 std devs, which trips the alert here
Hands-on
Full project structure for a production-ready ML API.
ml-api/
    model/
        mnist_cnn_traced.pt
    app.py
    requirements.txt
    Dockerfile
requirements.txt:
fastapi==0.111.0
uvicorn==0.29.0
torch==2.3.0
torchvision==0.18.0
Pillow==10.3.0
numpy==1.26.4
Dockerfile:
FROM python:3.11-slim
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY model/ model/
COPY app.py .
EXPOSE 8000
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
Build and run:
docker build -t digit-api .
docker run -p 8000:8000 digit-api
Test with curl:
curl -X POST http://localhost:8000/predict \
-F "file=@test_digit.png" \
-H "accept: application/json"
Full app.py with middleware for latency and request counting:
import io
import time
import threading
import numpy as np
import torch
from fastapi import FastAPI, File, HTTPException, Request, UploadFile
from fastapi.responses import HTMLResponse  # only needed if you also keep the /demo route
from PIL import Image
from pydantic import BaseModel
app = FastAPI(title="Digit Recogniser API", version="1.0")
model = torch.jit.load("model/mnist_cnn_traced.pt")
model.eval()
counters = {"requests": 0, "errors": 0}
lock = threading.Lock()
class PredictionResponse(BaseModel):
    digit: int
    confidence: float

@app.middleware("http")
async def track_requests(request: Request, call_next):
    with lock:
        counters["requests"] += 1
    start = time.perf_counter()
    response = await call_next(request)
    ms = (time.perf_counter() - start) * 1000
    response.headers["X-Inference-Time-ms"] = f"{ms:.1f}"
    if response.status_code >= 400:
        with lock:
            counters["errors"] += 1
    return response

@app.get("/health")
def health():
    return {"status": "ok"}

@app.get("/metrics")
def get_metrics():
    return counters

@app.post("/predict", response_model=PredictionResponse)
async def predict(file: UploadFile = File(...)):
    if file.content_type not in ("image/png", "image/jpeg"):
        raise HTTPException(400, "Upload PNG or JPEG.")
    data = await file.read()
    img = Image.open(io.BytesIO(data)).convert("L").resize((28, 28))
    arr = np.array(img, dtype=np.float32) / 255.0
    arr = (arr - 0.1307) / 0.3081  # MNIST normalisation
    tensor = torch.tensor(arr).unsqueeze(0).unsqueeze(0)  # (1, 1, 28, 28)
    with torch.no_grad():
        logits = model(tensor)
        probs = torch.softmax(logits, dim=1)
    digit = probs.argmax(dim=1).item()
    conf = probs.max().item()
    return PredictionResponse(digit=digit, confidence=round(conf, 4))
Common pitfalls
Loading the model on every request. Load the model once at startup (module-level or in a lifespan context manager), not inside the endpoint function. Loading takes seconds; inference takes milliseconds.
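A minimal sketch of the lifespan approach; the ml dict is just an illustrative place to keep long-lived objects:

from contextlib import asynccontextmanager

import torch
from fastapi import FastAPI

ml = {}  # holds long-lived objects such as the loaded model

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Runs once at startup, before the first request is served
    ml["model"] = torch.jit.load("model/mnist_cnn_traced.pt")
    ml["model"].eval()
    yield
    ml.clear()  # runs once at shutdown

app = FastAPI(title="Digit Recogniser API", lifespan=lifespan)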
Not using torch.no_grad() during inference. Without it, PyTorch builds a computation graph for every forward pass, wasting memory and time.
Blocking the event loop with heavy computation. FastAPI is async. If your model inference takes > 100ms, run it in a thread pool: await asyncio.to_thread(model, tensor).
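For example, the forward pass in the predict endpoint could be offloaded to a worker thread; a sketch using a hypothetical run_inference helper and the module-level model from app.py:

import asyncio

import torch

def run_inference(tensor: torch.Tensor) -> torch.Tensor:
    # Runs in a worker thread, so the event loop stays free to accept other requests
    with torch.no_grad():
        return model(tensor)

# Inside the async endpoint, replace the direct call with:
#     logits = await asyncio.to_thread(run_inference, tensor)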
Serving raw exception messages to users. Catch exceptions, log them server-side, and return a clean error message. Stack traces expose implementation details and are confusing to end users.
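One way to do this in the predict endpoint, sketched with a hypothetical load_image_or_400 helper:

import io
import logging

from fastapi import HTTPException
from PIL import Image

logger = logging.getLogger("digit-api")

def load_image_or_400(data: bytes) -> Image.Image:
    """Decode the upload, keeping the real traceback in the server logs."""
    try:
        return Image.open(io.BytesIO(data)).convert("L").resize((28, 28))
    except Exception:
        logger.exception("Failed to decode uploaded image")  # full details stay server-side
        raise HTTPException(400, "Could not read the uploaded file as an image.")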
Deploying without a health check. Load balancers and container orchestrators (Kubernetes) use /health to know if your service is alive. Without it, traffic can be routed to crashed instances.
What to try next
- Add a /version endpoint that returns the model version, training date, and the dataset it was trained on; useful for debugging production issues (a sketch follows this list).
- Explore Gradio or Streamlit for building a demo UI in pure Python without writing any HTML or JavaScript.
- Read about model serving platforms: Hugging Face Inference Endpoints, Google Vertex AI, AWS SageMaker.
- Look into concept drift detection libraries like evidently for automated monitoring of production ML systems.
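A minimal sketch of such a /version endpoint; the metadata values here are placeholders you would record at training time:

# Placeholder metadata; in practice, save this alongside the model when you train it
MODEL_INFO = {
    "model_version": "1.0.0",
    "dataset": "MNIST",
    "training_date": "2025-01-01",
}

@app.get("/version")
def version():
    return MODEL_INFO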
You have now completed the AI / ML path. The next frontier: pick a project, get a real dataset, and ship something.