FastAPI
Middleware Pattern
Scan all incoming requests automatically:
from fastapi import FastAPI, Request, HTTPException
from raxe import Raxe
app = FastAPI()
raxe = Raxe()
@app.middleware("http")
async def raxe_middleware(request: Request, call_next):
    """Scan the ``prompt`` field of JSON POST/PUT bodies before routing.

    Requests that are not POST/PUT, whose body is not valid JSON, or whose
    JSON body is not an object containing ``"prompt"`` are passed through
    untouched.

    Returns:
        A 400 JSON response when the scan reports threats, otherwise the
        response produced by ``call_next``.
    """
    # Local import keeps this snippet self-contained.
    from fastapi.responses import JSONResponse

    # Only scan POST/PUT requests with JSON body
    if request.method in ("POST", "PUT"):
        try:
            body = await request.json()
        except ValueError:
            body = None  # Not JSON, skip

        # A JSON body may be a list, a scalar or null; only an object can
        # carry a "prompt" field.
        if isinstance(body, dict) and "prompt" in body:
            result = raxe.scan(body["prompt"])
            if result.has_threats:
                # Return the response directly: an HTTPException raised from
                # an http middleware is not routed through FastAPI's
                # exception handlers and would surface as a 500. The payload
                # keeps the {"detail": ...} shape HTTPException would produce.
                return JSONResponse(
                    status_code=400,
                    content={
                        "detail": {
                            "error": "Security threat detected",
                            "severity": result.severity,
                            "blocked": True,
                        }
                    },
                )
        # NOTE(review): reading the body here relies on Starlette replaying
        # it to downstream handlers - confirm for the installed version.

    return await call_next(request)
@app.post("/chat")
async def chat(prompt: str):
    """Generate a reply for ``prompt``; scanning is left to the middleware.

    NOTE(review): FastAPI treats a bare ``str`` parameter as a query
    parameter, while the middleware inspects the JSON body - confirm that
    the scanned value and the value used here come from the same place.
    """
    # Already scanned by middleware
    reply = generate_response(prompt)
    return {"response": reply}
Dependency Injection
Use FastAPI dependencies for cleaner code:
from fastapi import Depends, HTTPException
from raxe import Raxe
raxe = Raxe()
async def scan_prompt(prompt: str) -> str:
    """Dependency that scans the prompt and hands it back unchanged.

    Raises:
        HTTPException: status 400 when the scan reports threats.
    """
    verdict = raxe.scan(prompt)
    if not verdict.has_threats:
        return prompt

    message = f"Blocked: {verdict.severity} threat detected"
    raise HTTPException(status_code=400, detail=message)
@app.post("/generate")
async def generate(prompt: str = Depends(scan_prompt)):
    """Return the LLM output for a prompt that passed ``scan_prompt``."""
    # The dependency has already validated the prompt by the time we get here.
    completion = llm.generate(prompt)
    return {"response": completion}
Async with AsyncRaxe
For high-throughput APIs:
from fastapi import FastAPI
from contextlib import asynccontextmanager
from raxe import AsyncRaxe
raxe: AsyncRaxe = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared ``AsyncRaxe`` client at startup and close it at shutdown.

    The client is stored in the module-level ``raxe`` name so request
    handlers can use it.
    """
    global raxe
    raxe = AsyncRaxe()
    try:
        yield
    finally:
        # Close the client even if an exception is thrown into the context
        # manager at the yield point.
        await raxe.close()
app = FastAPI(lifespan=lifespan)
@app.post("/chat")
async def chat(prompt: str):
    """Scan ``prompt`` with the shared async client, then generate a reply.

    Returns an error payload (with the reported severity) when the scan
    finds threats, otherwise the generated response.
    """
    scan = await raxe.scan(prompt)
    if not scan.has_threats:
        return {"response": await generate_async(prompt)}

    return {"error": "Threat detected", "severity": scan.severity}
Flask
Before Request Hook
Copy
from flask import Flask, request, jsonify
from raxe import Raxe
app = Flask(__name__)
raxe = Raxe()
@app.before_request
def scan_request():
    """Scan the ``prompt`` field of JSON POST/PUT requests before dispatch.

    Returns:
        A ``(response, 400)`` tuple when the scan reports threats; ``None``
        otherwise, which lets Flask continue to the view.
    """
    if request.method not in ("POST", "PUT") or not request.is_json:
        return None

    data = request.get_json()

    # A JSON body may be null, a list or a scalar; only an object can carry
    # a "prompt" field, and `in` on None would raise TypeError.
    if not isinstance(data, dict) or "prompt" not in data:
        return None

    # Scan prompt field if present
    result = raxe.scan(data["prompt"])
    if result.has_threats:
        return jsonify({
            "error": "Security threat detected",
            "severity": result.severity,
        }), 400

    return None
@app.route("/chat", methods=["POST"])
def chat():
    """Generate a reply for the ``prompt`` field of the JSON request body."""
    payload = request.get_json()
    prompt = payload["prompt"]
    return jsonify({"response": generate_response(prompt)})
Decorator Pattern
Copy
from functools import wraps
from flask import request, jsonify
from raxe import Raxe
raxe = Raxe()
def require_safe_prompt(f):
    """Decorate a Flask view so the request's ``prompt`` is scanned first.

    The wrapped view runs only when the scan reports no threats; otherwise
    a ``(response, 400)`` tuple describing the block is returned.
    """
    @wraps(f)
    def decorated(*args, **kwargs):
        data = request.get_json()
        # A JSON body may be null, a list or a scalar, none of which has
        # .get(); treat those like a missing prompt instead of raising.
        prompt = data.get("prompt", "") if isinstance(data, dict) else ""
        result = raxe.scan(prompt)
        if result.has_threats:
            return jsonify({
                "error": "Blocked",
                "severity": result.severity,
                "detections": result.total_detections,
            }), 400
        return f(*args, **kwargs)

    return decorated
@app.route("/generate", methods=["POST"])
@require_safe_prompt
def generate():
    """Return the LLM output for the request's ``prompt`` field."""
    prompt = request.get_json()["prompt"]
    completion = llm.generate(prompt)
    return jsonify({"response": completion})
Django
Middleware
Copy
# myapp/middleware.py
import json
from django.http import JsonResponse
from raxe import Raxe
class RaxeMiddleware:
    """Django middleware that blocks POST/PUT requests with a flagged ``prompt``.

    The request body is parsed as JSON. Bodies that are not valid JSON, or
    that are not a JSON object containing ``"prompt"``, are passed through
    to the next handler unscanned.
    """

    def __init__(self, get_response):
        """Store the next handler and create the scanner."""
        self.get_response = get_response
        self.raxe = Raxe()

    def __call__(self, request):
        """Return a 400 ``JsonResponse`` on threats, else the downstream response."""
        if request.method in ("POST", "PUT"):
            try:
                body = json.loads(request.body)
            except (json.JSONDecodeError, UnicodeDecodeError):
                body = None  # Not JSON, skip

            # Valid JSON may be a scalar, string, list or null; `in` or
            # indexing with "prompt" would raise TypeError on those.
            if isinstance(body, dict) and "prompt" in body:
                result = self.raxe.scan(body["prompt"])
                if result.has_threats:
                    return JsonResponse({
                        "error": "Security threat detected",
                        "severity": result.severity,
                    }, status=400)

        return self.get_response(request)
Register the middleware in settings.py:
Copy
# Registers RaxeMiddleware by its dotted import path (myapp/middleware.py).
MIDDLEWARE = [
    # ... other middleware
    'myapp.middleware.RaxeMiddleware',
]
View Decorator
Copy
# myapp/decorators.py
from functools import wraps
from django.http import JsonResponse
import json
from raxe import Raxe
raxe = Raxe()
def raxe_protected(view_func):
    """Decorate a Django view so JSON POST/PUT bodies are scanned first.

    For POST/PUT requests whose body is a JSON object, the ``prompt`` field
    (empty string when absent) is scanned; a 400 ``JsonResponse`` is
    returned when threats are found. Every other request reaches the view
    unscanned.
    """
    @wraps(view_func)
    def wrapper(request, *args, **kwargs):
        if request.method in ("POST", "PUT"):
            try:
                body = json.loads(request.body)
            except (json.JSONDecodeError, UnicodeDecodeError):
                # Not JSON, or bytes that cannot be decoded: skip scanning.
                body = None

            # Valid JSON may be a list, scalar or null, none of which
            # has .get(); only objects can carry a "prompt" field.
            if isinstance(body, dict):
                result = raxe.scan(body.get("prompt", ""))
                if result.has_threats:
                    return JsonResponse({
                        "error": "Threat detected",
                        "severity": result.severity,
                    }, status=400)

        return view_func(request, *args, **kwargs)

    return wrapper
# Usage in views.py
@raxe_protected
def chat_view(request):
    """Generate a reply for the ``prompt`` field of the JSON request body."""
    payload = json.loads(request.body)
    reply = generate_response(payload["prompt"])
    return JsonResponse({"response": reply})
Django REST Framework
Copy
# myapp/views.py
from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework import status
from raxe import Raxe
raxe = Raxe()
class ChatView(APIView):
    """DRF view that scans the submitted prompt before generating a reply."""

    def post(self, request):
        """Return the generated response, or a 400 payload when threats are found."""
        user_prompt = request.data.get("prompt", "")

        # Scan first; only a clean prompt reaches the generator.
        scan = raxe.scan(user_prompt)
        if not scan.has_threats:
            reply = generate_response(user_prompt)
            return Response({"response": reply})

        blocked = {
            "error": "Security threat detected",
            "severity": scan.severity,
            "detections": scan.total_detections,
        }
        return Response(blocked, status=status.HTTP_400_BAD_REQUEST)
Async Queue Processing
For background job processing:
import asyncio
from raxe import AsyncRaxe
async def process_queue(queue: asyncio.Queue):
    """Worker loop: take jobs from ``queue``, scan each prompt, then dispatch.

    Jobs whose prompt is flagged are handed to ``mark_job_failed``; all
    others go to ``process_job``. The loop never returns on its own and is
    expected to be cancelled by its owner.
    """
    async with AsyncRaxe() as raxe:
        while True:
            job = await queue.get()
            try:
                # Scan before processing
                result = await raxe.scan(job["prompt"])
                if result.has_threats:
                    await mark_job_failed(job, reason=f"Threat: {result.severity}")
                else:
                    await process_job(job)
            finally:
                # Always balance the get(): if task_done() is skipped after
                # an error, queue.join() waits forever for this job.
                queue.task_done()
# Start workers
async def main():
    """Run five queue workers, enqueue every job, and wait until all are handled."""
    queue = asyncio.Queue()

    # Start 5 workers
    workers = [
        asyncio.create_task(process_queue(queue))
        for _ in range(5)
    ]

    try:
        # Add jobs to queue
        for job in jobs:
            await queue.put(job)
        await queue.join()
    finally:
        # The workers loop forever, so cancel them once the queue is drained
        # (or on error) and wait for the cancellations to finish.
        for worker in workers:
            worker.cancel()
        await asyncio.gather(*workers, return_exceptions=True)
Batch Processing
For processing large datasets:
from raxe import AsyncRaxe
async def scan_dataset(prompts: list[str]) -> dict:
    """Batch-scan ``prompts`` and summarise how many were safe or flagged.

    Returns:
        A dict with ``safe_count``, ``threat_count`` and ``threats`` (one
        entry per flagged prompt with its severity and detection count).
    """
    async with AsyncRaxe() as raxe:
        results = await raxe.scan_batch(
            prompts,
            max_concurrency=20,
        )

    # Pair each prompt with its result once so every result is inspected a
    # single time, exactly as in a plain loop.
    pairs = list(zip(prompts, results))
    threats = [
        {
            "prompt": text,
            "severity": outcome.severity,
            "detections": outcome.total_detections,
        }
        for text, outcome in pairs
        if outcome.has_threats
    ]

    return {
        "safe_count": len(pairs) - len(threats),
        "threat_count": len(threats),
        "threats": threats,
    }
Error Handling Pattern
Consistent error handling across your application:
from raxe import Raxe
from raxe.sdk.exceptions import SecurityException, RaxeException
raxe = Raxe()
def safe_scan(prompt: str) -> dict:
    """Scan ``prompt`` in blocking mode and report the outcome as a dict.

    Returns:
        ``{"safe": True, ...}`` when the scan completes,
        ``{"safe": False, ...}`` when a ``SecurityException`` is raised, or
        ``{"error": True, ...}`` for any other ``RaxeException``.
    """
    try:
        outcome = raxe.scan(prompt, block_on_threat=True)
        report = {
            "safe": True,
            "duration_ms": outcome.duration_ms,
        }
    except SecurityException as blocked:
        # Threat was detected and blocked
        details = blocked.result
        report = {
            "safe": False,
            "severity": details.severity,
            "detections": details.total_detections,
            "message": str(blocked),
        }
    except RaxeException as failure:
        # Other RAXE errors (config, validation, etc.)
        report = {
            "error": True,
            "message": str(failure),
        }
    return report
Logging Integration
Structured logging for monitoring:
import logging
import json
from raxe import Raxe
logger = logging.getLogger("raxe.security")
raxe = Raxe()
def scan_with_logging(prompt: str, user_id: str = None) -> bool:
    """Scan ``prompt`` and emit a JSON log record describing the outcome.

    Flagged prompts are logged at WARNING level, clean ones at DEBUG.

    Returns:
        ``True`` when no threats were found, ``False`` otherwise.
    """
    outcome = raxe.scan(prompt)

    if not outcome.has_threats:
        safe_record = {
            "event": "scan_safe",
            "duration_ms": outcome.duration_ms,
            "user_id": user_id,
        }
        logger.debug(json.dumps(safe_record))
        return True

    threat_record = {
        "event": "threat_detected",
        "severity": outcome.severity,
        "total_detections": outcome.total_detections,
        "duration_ms": outcome.duration_ms,
        "user_id": user_id,
        "rules": [detection.rule_id for detection in outcome.detections],
    }
    logger.warning(json.dumps(threat_record))
    return False
