Skip to main content

Overview

RAXE integrates with LangChain to protect your chains and agents from prompt injection and other threats.

Installation

pip install raxe langchain langchain-openai langchain-community
# Needed only for the RAG example (Chroma vector store):
pip install chromadb

Callback Handler

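Use the RAXE callback handler to scan inputs and outputs. You can attach it to a model's `callbacks`, or pass it at call time with `config={"callbacks": [handler]}`. Call-time callbacks propagate to every nested chain, model, and tool call. Callbacks given to a chain's constructor apply only to that chain object.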
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage  # `langchain.schema` is a deprecated alias
from raxe.sdk.integrations.langchain import RaxeCallbackHandler

# Create callback handler: scan inputs and outputs, block if a threat is detected
handler = RaxeCallbackHandler(
    scan_inputs=True,
    scan_outputs=True,
    block_on_threat=True
)

# Attach the handler to the model so every call made through `llm` is scanned
llm = ChatOpenAI(
    model="gpt-4",
    callbacks=[handler]
)

# Scans are automatic on each invocation
response = llm.invoke([HumanMessage(content="Hello, how are you?")])

Configuration Options

handler = RaxeCallbackHandler(
    # --- What gets scanned ---
    # scan_inputs: scan user inputs
    scan_inputs=True,
    # scan_outputs: scan LLM outputs
    scan_outputs=True,

    # --- How threats are handled ---
    # block_on_threat: block when a threat is detected
    block_on_threat=True,
    # min_severity: lowest severity that triggers a block
    min_severity="HIGH",

    # --- Logging ---
    # log_scans: log every scan
    log_scans=True,
)

Chain Integration

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from raxe.sdk.integrations.langchain import RaxeCallbackHandler

# Setup
handler = RaxeCallbackHandler(block_on_threat=True)
llm = ChatOpenAI(model="gpt-4")

prompt = PromptTemplate(
    input_variables=["question"],
    template="Answer this question: {question}"
)

# Compose the chain with LCEL (LLMChain and .run() are deprecated).
# StrOutputParser keeps the result a plain string, as chain.run() returned.
chain = prompt | llm | StrOutputParser()

# Pass the handler at call time: request-time callbacks propagate to every
# step of the chain, including the model call. Callbacks given to a chain's
# constructor are scoped to that object only and never see the model call.
result = chain.invoke(
    {"question": "What is machine learning?"},
    config={"callbacks": [handler]},
)

Agent Integration

from langchain.agents import create_react_agent, AgentExecutor
from langchain_openai import ChatOpenAI
from langchain import hub
from raxe.sdk.integrations.langchain import RaxeCallbackHandler

# Create handler
handler = RaxeCallbackHandler(
    block_on_threat=True,
    scan_inputs=True,
    scan_outputs=True
)

# Setup agent
llm = ChatOpenAI(model="gpt-4")
prompt = hub.pull("hwchase17/react")
tools = []  # Your tools here

agent = create_react_agent(llm, tools, prompt)
agent_executor = AgentExecutor(
    agent=agent,
    tools=tools,
)

# Execute with protection. Pass the handler at call time so it propagates to
# the agent's LLM and tool calls; callbacks given to the AgentExecutor
# constructor are scoped to the executor object only.
result = agent_executor.invoke(
    {"input": "Hello"},
    config={"callbacks": [handler]},
)

RAG Protection

Protect RAG pipelines from data exfiltration:
from langchain.chains import RetrievalQA
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from raxe.sdk.integrations.langchain import RaxeCallbackHandler

# Create handler that scans inputs and outputs and blocks on threats
handler = RaxeCallbackHandler(
    scan_inputs=True,
    scan_outputs=True,
    block_on_threat=True
)

# Setup RAG chain
llm = ChatOpenAI(model="gpt-4")
embeddings = OpenAIEmbeddings()
vectorstore = Chroma(embedding_function=embeddings)
# The store is created empty; load your documents before querying, e.g.:
# vectorstore.add_texts(["Our refund policy is ..."])

qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=vectorstore.as_retriever(),
)

# Protected queries. Passing the handler at call time propagates it to the
# nested model call (whose prompt contains the retrieved context). A handler
# given to the RetrievalQA constructor would be scoped to that chain object only.
result = qa_chain.invoke(
    {"query": "What are our policies?"},
    config={"callbacks": [handler]},
)

Error Handling

from raxe.exceptions import RaxeBlockedError
from raxe.sdk.integrations.langchain import RaxeCallbackHandler

handler = RaxeCallbackHandler(block_on_threat=True)

user_input = "What is machine learning?"  # untrusted text from your user

# `chain` is any chain taking a "question" input, such as the one in
# Chain Integration. The handler must be passed to the call (or attached to
# the model); creating it alone does not enable scanning or blocking.
try:
    result = chain.invoke(
        {"question": user_input},
        config={"callbacks": [handler]},
    )
except RaxeBlockedError as e:
    print(f"Blocked: {e.scan_result.highest_severity}")
    print(f"Detections: {len(e.scan_result.detections)}")

Async Support

import asyncio

from raxe.sdk.integrations.langchain import AsyncRaxeCallbackHandler

handler = AsyncRaxeCallbackHandler(
    block_on_threat=True
)


async def main():
    """Run the chain asynchronously with the RAXE handler attached."""
    # Pass the handler at call time so it applies to every nested step,
    # including the model call.
    return await chain.ainvoke(
        {"question": "Hello"},
        config={"callbacks": [handler]},
    )


# `await` is only valid inside a coroutine, so drive it with asyncio.run().
# In Jupyter/IPython, where an event loop is already running, use
# `result = await main()` instead.
result = asyncio.run(main())

Best Practices

Enable both scan_inputs and scan_outputs for comprehensive protection:
handler = RaxeCallbackHandler(
    scan_inputs=True,   # input side: catches injection attempts
    scan_outputs=True,  # output side: catches jailbreak responses
)
Adjust min_severity based on your risk tolerance:
# Strict: block on threats of LOW severity and above
handler = RaxeCallbackHandler(block_on_threat=True, min_severity="LOW")

# Lenient: only block CRITICAL threats
handler = RaxeCallbackHandler(block_on_threat=True, min_severity="CRITICAL")
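With `block_on_threat=True`, a blocked request raises `RaxeBlockedError`. Always catch it so you can return a user-friendly response instead of an unhandled exception: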
from raxe.exceptions import RaxeBlockedError


def answer(user_input: str):
    """Run the chain on *user_input*; return a safe message if RAXE blocks it."""
    try:
        return chain.invoke(
            {"question": user_input},
            config={"callbacks": [handler]},
        )
    except RaxeBlockedError:
        return "I can't process that request."