Skip to main content

FastAPI

Middleware Pattern

Scan the prompt of every incoming POST/PUT request automatically:
from fastapi import Body, FastAPI, HTTPException, Request
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from raxe import Raxe

app = FastAPI()
raxe = Raxe()

@app.middleware("http")
async def raxe_middleware(request: Request, call_next):
    """Reject POST/PUT requests whose ``prompt`` RAXE flags as a threat.

    Both places a prompt can arrive are scanned: the ``prompt`` query
    parameter (where FastAPI binds a plain ``prompt: str`` endpoint argument)
    and the ``"prompt"`` field of a JSON object body. A flagged prompt gets a
    400 JSON response and the endpoint is never called.
    """
    # Only scan POST/PUT requests
    if request.method in ("POST", "PUT"):
        # getlist: a repeated ?prompt=...&prompt=... must not hide a value.
        prompts = request.query_params.getlist("prompt")

        # NOTE(review): older Starlette releases cannot re-read a body that a
        # middleware has already consumed -- confirm your version.
        try:
            body = await request.json()
        except ValueError:
            body = None  # Not JSON (also empty or non-UTF-8 bodies), skip

        # Valid JSON need not be an object; only a dict has a "prompt" field.
        if isinstance(body, dict) and "prompt" in body:
            prompts.append(body["prompt"])

        for prompt in prompts:
            # NOTE(review): raxe.scan is synchronous and runs on the event
            # loop here; consider AsyncRaxe for high-throughput services.
            result = raxe.scan(prompt)
            if result.has_threats:
                # An HTTPException raised in middleware is not converted by
                # FastAPI's exception handlers (it surfaces as a 500), so build
                # the 400 response directly, with the {"detail": ...} body that
                # HTTPException would have produced.
                return JSONResponse(
                    status_code=400,
                    content=jsonable_encoder({
                        "detail": {
                            "error": "Security threat detected",
                            "severity": result.severity,
                            "blocked": True,
                        }
                    }),
                )

    return await call_next(request)

@app.post("/chat")
async def chat(prompt: str = Body(..., embed=True)):
    """Answer a chat prompt posted as a JSON body: ``{"prompt": "..."}``.

    ``Body(embed=True)`` makes FastAPI read ``prompt`` from the JSON body,
    the same field ``raxe_middleware`` scans. A bare ``prompt: str`` would be
    taken from the query string instead, which a body-only scan never sees.
    """
    return {"response": generate_response(prompt)}

Dependency Injection

Use FastAPI dependencies for cleaner code:
from fastapi import Depends, HTTPException
from raxe import Raxe

raxe = Raxe()

async def scan_prompt(prompt: str) -> str:
    """Check ``prompt`` with RAXE and hand it back unchanged when it is clean.

    Meant to be used as ``Depends(scan_prompt)``: an endpoint that depends on
    it only runs after the prompt has passed the scan.

    Raises:
        HTTPException: status 400 naming the severity when a threat is found.
    """
    # NOTE(review): raxe.scan is a synchronous call inside an async def, so it
    # runs on the event loop; AsyncRaxe avoids that under heavy load.
    scan = raxe.scan(prompt)
    if not scan.has_threats:
        return prompt

    raise HTTPException(
        status_code=400,
        detail=f"Blocked: {scan.severity} threat detected",
    )

@app.post("/generate")
async def generate(prompt: str = Depends(scan_prompt)):
    """Generate a completion for a prompt that ``scan_prompt`` has approved."""
    completion = llm.generate(prompt)
    return {"response": completion}

Async with AsyncRaxe

For high-throughput APIs:
from contextlib import asynccontextmanager

from fastapi import FastAPI
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from raxe import AsyncRaxe

# Shared client: created by ``lifespan`` at startup, ``None`` until then.
raxe: "AsyncRaxe | None" = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Own the shared AsyncRaxe client for the lifetime of the application."""
    global raxe
    raxe = AsyncRaxe()
    try:
        yield
    finally:
        # Runs even when an exception is thrown into the lifespan at the
        # yield, so the client is always closed.
        await raxe.close()

app = FastAPI(lifespan=lifespan)

@app.post("/chat")
async def chat(prompt: str):
    """Scan ``prompt`` with the shared AsyncRaxe client, then answer it.

    A flagged prompt is rejected with HTTP 400 and its severity, so clients
    cannot mistake a blocked request for a successful one.
    """
    result = await raxe.scan(prompt)
    if result.has_threats:
        # A plain dict would be sent with status 200; reject with 400 like the
        # other examples. jsonable_encoder applies the same encoding FastAPI
        # uses for dicts returned from endpoints.
        return JSONResponse(
            status_code=400,
            content=jsonable_encoder(
                {"error": "Threat detected", "severity": result.severity}
            ),
        )

    return {"response": await generate_async(prompt)}

Flask

Before Request Hook

from flask import Flask, request, jsonify
from raxe import Raxe

app = Flask(__name__)
raxe = Raxe()

@app.before_request
def scan_request():
    """Reject JSON POST/PUT requests whose ``prompt`` field RAXE flags.

    Returning a response from a ``before_request`` hook stops Flask from
    calling the view; returning ``None`` lets the request continue.
    """
    if request.method in ("POST", "PUT") and request.is_json:
        data = request.get_json()

        # Valid JSON need not be an object: for an array, string, number or
        # null, the ``in`` test or the ``["prompt"]`` lookup raises TypeError,
        # which would turn into a 500.
        if isinstance(data, dict) and "prompt" in data:
            result = raxe.scan(data["prompt"])
            if result.has_threats:
                return jsonify({
                    "error": "Security threat detected",
                    "severity": result.severity
                }), 400

@app.route("/chat", methods=["POST"])
def chat():
    """Return the model's reply to the JSON ``prompt`` field."""
    # scan_request (a before_request hook) has already screened JSON prompts.
    prompt = request.get_json()["prompt"]
    return jsonify({"response": generate_response(prompt)})

Decorator Pattern

from functools import wraps
from flask import request, jsonify
from raxe import Raxe

raxe = Raxe()

def require_safe_prompt(f):
    """Decorate a Flask view so it only runs when its JSON ``prompt`` is safe.

    The prompt is read from the JSON body (``""`` when absent or when the body
    is not a JSON object) and scanned with RAXE. On a threat the decorator
    answers 400 with the severity and detection count; otherwise it calls the
    wrapped view unchanged.
    """
    @wraps(f)
    def decorated(*args, **kwargs):
        data = request.get_json()
        # get_json() can return None (e.g. a JSON ``null`` body) or a
        # non-object; calling .get() on those raised AttributeError (a 500).
        prompt = data.get("prompt", "") if isinstance(data, dict) else ""

        result = raxe.scan(prompt)
        if result.has_threats:
            return jsonify({
                "error": "Blocked",
                "severity": result.severity,
                "detections": result.total_detections
            }), 400

        return f(*args, **kwargs)
    return decorated

@app.route("/generate", methods=["POST"])
@require_safe_prompt
def generate():
    """Generate a completion; only reached once require_safe_prompt passes."""
    prompt = request.get_json()["prompt"]
    return jsonify({"response": llm.generate(prompt)})

Django

Middleware

# myapp/middleware.py
import json
from django.http import JsonResponse
from raxe import Raxe

class RaxeMiddleware:
    """Django middleware that blocks POST/PUT JSON requests with unsafe prompts.

    One RAXE client is created when Django instantiates the middleware and is
    reused for every request it handles.
    """

    def __init__(self, get_response):
        self.get_response = get_response
        self.raxe = Raxe()

    def __call__(self, request):
        """Scan the JSON ``prompt`` field before passing the request on."""
        if request.method in ("POST", "PUT"):
            # NOTE(review): request.body reads the whole body into memory, and
            # Django raises RequestDataTooBig when it exceeds
            # DATA_UPLOAD_MAX_MEMORY_SIZE (large uploads included) -- confirm
            # this is acceptable for your endpoints.
            try:
                body = json.loads(request.body)
            except (json.JSONDecodeError, UnicodeDecodeError):
                body = None  # Not JSON: nothing to scan

            # Valid JSON need not be an object; a list or string here used to
            # raise an uncaught TypeError (a 500).
            if isinstance(body, dict) and "prompt" in body:
                result = self.raxe.scan(body["prompt"])
                if result.has_threats:
                    return JsonResponse({
                        "error": "Security threat detected",
                        "severity": result.severity
                    }, status=400)

        return self.get_response(request)
Add to settings.py:
MIDDLEWARE = [
    # ... other middleware
    # Dotted import path of RaxeMiddleware (defined in myapp/middleware.py).
    'myapp.middleware.RaxeMiddleware',
]

View Decorator

# myapp/decorators.py
from functools import wraps
from django.http import JsonResponse
import json
from raxe import Raxe

raxe = Raxe()

def raxe_protected(view_func):
    """Wrap a Django view so the JSON ``prompt`` of POST/PUT requests is scanned.

    A JSON-object body has its ``"prompt"`` field (``""`` if missing) scanned
    with RAXE; a threat is answered with 400 and the view is not called.
    Bodies that are not JSON objects are passed to the view unscanned.
    """
    @wraps(view_func)
    def wrapper(request, *args, **kwargs):
        if request.method in ("POST", "PUT"):
            # Same decode errors as RaxeMiddleware: json.loads on bytes can also
            # raise UnicodeDecodeError, which previously escaped as a 500.
            try:
                body = json.loads(request.body)
            except (json.JSONDecodeError, UnicodeDecodeError):
                body = None

            # Only a JSON object has .get(); arrays/strings used to raise.
            if isinstance(body, dict):
                result = raxe.scan(body.get("prompt", ""))
                if result.has_threats:
                    return JsonResponse({
                        "error": "Threat detected",
                        "severity": result.severity
                    }, status=400)

        return view_func(request, *args, **kwargs)
    return wrapper

# Usage in views.py
@raxe_protected
def chat_view(request):
    """Answer the JSON ``prompt``; raxe_protected has screened it already."""
    prompt = json.loads(request.body)["prompt"]
    return JsonResponse({"response": generate_response(prompt)})

Django REST Framework

# myapp/views.py
from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework import status
from raxe import Raxe

raxe = Raxe()

class ChatView(APIView):
    """DRF endpoint that screens ``prompt`` with RAXE before answering it."""

    def post(self, request):
        """Reply to ``prompt``, or answer 400 when RAXE detects a threat."""
        prompt = request.data.get("prompt", "")

        scan = raxe.scan(prompt)
        if not scan.has_threats:
            return Response({"response": generate_response(prompt)})

        return Response(
            {
                "error": "Security threat detected",
                "severity": scan.severity,
                "detections": scan.total_detections,
            },
            status=status.HTTP_400_BAD_REQUEST,
        )

Async Queue Processing

For background job processing:
import asyncio
from raxe import AsyncRaxe

async def process_queue(queue: asyncio.Queue):
    """Worker loop: scan each queued job's prompt, then fail or process it.

    Runs until cancelled. Every job taken from ``queue`` is marked done even
    when handling it raises, so ``queue.join()`` in the producer cannot hang
    on a job that errored, and one bad job does not kill the worker.
    """
    import logging

    log = logging.getLogger(__name__)

    async with AsyncRaxe() as raxe:
        while True:
            job = await queue.get()
            try:
                # Scan before processing
                result = await raxe.scan(job["prompt"])

                if result.has_threats:
                    await mark_job_failed(job, reason=f"Threat: {result.severity}")
                else:
                    await process_job(job)
            except Exception:
                # CancelledError is not an Exception subclass (Python 3.8+),
                # so cancelling the worker still stops this loop.
                log.exception("Failed to handle queued job")
            finally:
                queue.task_done()

# Start workers
async def main(num_workers: int = 5):
    """Feed ``jobs`` to ``num_workers`` concurrent workers and wait for them.

    The workers loop forever, so once every job is done they are cancelled
    and awaited rather than left running.

    Raises:
        ValueError: if ``num_workers`` is less than 1 (``join`` would hang).
    """
    if num_workers < 1:
        raise ValueError("num_workers must be at least 1")

    queue = asyncio.Queue()

    workers = [
        asyncio.create_task(process_queue(queue))
        for _ in range(num_workers)
    ]

    try:
        # Add jobs to queue (``jobs`` is not defined in this snippet --
        # supply your own iterable of {"prompt": ...} dicts).
        for job in jobs:
            await queue.put(job)

        await queue.join()
    finally:
        for worker in workers:
            worker.cancel()
        # Let each worker leave its ``async with AsyncRaxe()`` block;
        # return_exceptions collects the resulting CancelledErrors.
        await asyncio.gather(*workers, return_exceptions=True)

Batch Processing

For processing large datasets:
from raxe import AsyncRaxe

async def scan_dataset(prompts: list[str]) -> dict:
    """Scan ``prompts`` concurrently and summarise what RAXE flagged.

    Returns a dict with ``safe_count``, ``threat_count`` and ``threats``: one
    ``{"prompt", "severity", "detections"}`` entry per flagged prompt, in
    input order.
    """
    async with AsyncRaxe() as raxe:
        results = await raxe.scan_batch(prompts, max_concurrency=20)

        safe_count = 0
        threats = []
        for text, res in zip(prompts, results):
            if not res.has_threats:
                safe_count += 1
                continue
            threats.append({
                "prompt": text,
                "severity": res.severity,
                "detections": res.total_detections,
            })

        summary = {
            "safe_count": safe_count,
            "threat_count": len(threats),
            "threats": threats,
        }
        return summary

Error Handling Pattern

Consistent error handling across your application:
from raxe import Raxe
from raxe.sdk.exceptions import SecurityException, RaxeException

raxe = Raxe()

def safe_scan(prompt: str) -> dict:
    """Scan ``prompt`` and report the outcome as a dict instead of raising.

    Every result carries a boolean ``"safe"`` key:

    * ``{"safe": True, "duration_ms": ...}``: no threat found.
    * ``{"safe": False, "severity", "detections", "message"}``: a threat
      was detected and blocked (``SecurityException``).
    * ``{"safe": False, "error": True, "message"}``: another RAXE error
      (``RaxeException``). The prompt could not be verified, so it is
      reported as not safe (fail closed).
    """
    try:
        result = raxe.scan(prompt, block_on_threat=True)

    except SecurityException as e:
        # Threat was detected and blocked
        return {
            "safe": False,
            "severity": e.result.severity,
            "detections": e.result.total_detections,
            "message": str(e)
        }

    except RaxeException as e:
        # Other RAXE errors (config, validation, etc.). "safe" is included so
        # callers can always index result["safe"] without a KeyError.
        return {
            "safe": False,
            "error": True,
            "message": str(e)
        }

    return {
        "safe": True,
        "duration_ms": result.duration_ms
    }

Logging Integration

Structured logging for monitoring:
import logging
import json
from raxe import Raxe

logger = logging.getLogger("raxe.security")
raxe = Raxe()

def scan_with_logging(prompt: str, user_id: "str | None" = None) -> bool:
    """Scan ``prompt`` and log a structured JSON record describing the result.

    Args:
        prompt: Text to scan.
        user_id: Optional caller identifier copied into the log record.

    Returns:
        ``False`` (after a WARNING record) when RAXE detects a threat;
        ``True`` otherwise (with a DEBUG record when DEBUG is enabled).
    """
    result = raxe.scan(prompt)

    if result.has_threats:
        logger.warning(
            json.dumps(
                {
                    "event": "threat_detected",
                    "severity": result.severity,
                    "total_detections": result.total_detections,
                    "duration_ms": result.duration_ms,
                    "user_id": user_id,
                    "rules": [d.rule_id for d in result.detections]
                },
                # Logging must not be what fails the request: stringify any
                # field json cannot encode (e.g. if severity is an enum object).
                default=str,
            )
        )
        return False

    # Avoid serialising a payload for every safe scan when DEBUG is off.
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(
            json.dumps(
                {
                    "event": "scan_safe",
                    "duration_ms": result.duration_ms,
                    "user_id": user_id
                },
                default=str,
            )
        )
    return True