/**
 * AI Prompt Injection Guard
 *
 * Detects and strips common prompt-injection patterns from user-supplied text
 * before it is passed to AI services. Intended to run before every AI service
 * call that uses user-supplied criteria, descriptions, or other free-text fields.
 *
 * Patterns detected:
 * - ChatML tags (<|im_start|>, <|im_end|>, <|endoftext|>)
 * - Role impersonation ("system:", "assistant:" at the start of a line)
 * - Instruction override ("ignore previous instructions", "disregard above")
 *   and Llama-2 prompt delimiters ([INST], <<SYS>>)
 * - Encoded injection attempts (base64 encodings of a few known phrases)
 *
 * This is a best-effort, pattern-based filter: Unicode look-alikes and other
 * obfuscations are not normalized, so it should not be the only line of defense.
 */

// ─── Injection Patterns ─────────────────────────────────────────────────────
//
// Every pattern carries the `g` flag so `String.prototype.replace` strips all
// occurrences. Global regexes are stateful (`lastIndex`) and these instances
// are shared module-wide: prefer `replace`/`search`, which always scan from
// index 0, or reset `lastIndex` before calling `test`/`exec`.

/** ChatML control tokens such as `<|im_start|>`, `<|im_end|>`, `<|endoftext|>`. */
const CHATML_PATTERN = /<\|(?:im_start|im_end|endoftext|system|user|assistant)\|>/gi

/**
 * A chat-role prefix (`system:`, `assistant:`, `user:`, …) at the start of any
 * line (`m` flag). The leading `\s*` can also consume blank lines before it.
 */
const ROLE_IMPERSONATION_PATTERN = /^\s*(?:system|assistant|user|human|ai|bot)\s*:/gim

/** Phrases that try to override earlier instructions, plus Llama-2 prompt delimiters. */
const INSTRUCTION_OVERRIDE_PATTERNS: readonly RegExp[] = [
  /ignore\s+(?:all\s+)?(?:previous|prior|above|earlier)\s+instructions?/gi,
  /disregard\s+(?:all\s+)?(?:previous|prior|above|earlier)\s+(?:instructions?|prompts?|context)/gi,
  /forget\s+(?:all\s+)?(?:previous|prior|above|earlier)\s+(?:instructions?|prompts?|context)/gi,
  /override\s+(?:all\s+)?(?:previous|prior|above|earlier)\s+(?:instructions?|prompts?)/gi,
  /you\s+are\s+now\s+(?:a|an)\s+/gi,
  /new\s+instructions?\s*:/gi,
  /begin\s+(?:new\s+)?(?:prompt|instructions?|session)/gi,
  /\[INST\]/gi,
  /\[\/INST\]/gi,
  /<<SYS>>/gi,
  // Llama-2 closes the system block with `<</SYS>>`. The optional leading `<`
  // consumes the whole tag instead of leaving a stray `<` behind; a bare
  // `</SYS>>` still matches.
  /<?<\/SYS>>/gi,
]

/**
 * Base64 encodings of known injection phrases.
 *
 * Base64 is case-sensitive, so these deliberately omit the `i` flag:
 * case-folding would only match strings that decode to something else.
 * Only the encoding aligned to a 3-byte boundary is listed; the same phrase at
 * another offset inside a larger base64 blob encodes differently and is not
 * detected.
 */
const ENCODED_INJECTION_PATTERNS: readonly RegExp[] = [
  /aWdub3JlIHByZXZpb3Vz/g, // "ignore previous"
  /ZGlzcmVnYXJkIGFib3Zl/g, // "disregard above"
]

// ─── Types ──────────────────────────────────────────────────────────────────

/** Outcome of sanitizing one piece of user-supplied text. */
export interface SanitizationResult {
  /** The input with injection patterns removed or replaced by `[FILTERED]`, whitespace-normalized and trimmed. */
  sanitized: string
  /** True when `sanitized` differs from the trimmed input. */
  wasModified: boolean
  /** Labels of the pattern families that matched, e.g. `'ChatML tags'`. */
  detectedPatterns: string[]
}

// ─── Core Functions ─────────────────────────────────────────────────────────

/**
 * Apply every pattern family to `input` exactly once.
 *
 * Uses `String.prototype.replace`, which always scans a global regex from
 * index 0, so a stale `lastIndex` left on the shared module-level patterns
 * cannot cause a missed match. A family's label is added to `detected` when
 * its replacement changed the text (every pattern match is non-empty and none
 * can match `[FILTERED]`, so "changed" is equivalent to "matched").
 */
function stripInjectionPatternsOnce(input: string, detected: Set<string>): string {
  let out = input
  const apply = (pattern: RegExp, replacement: string, label: string): void => {
    const next = out.replace(pattern, replacement)
    if (next !== out) {
      detected.add(label)
      out = next
    }
  }

  apply(CHATML_PATTERN, '', 'ChatML tags')
  apply(ROLE_IMPERSONATION_PATTERN, '', 'Role impersonation')
  for (const pattern of INSTRUCTION_OVERRIDE_PATTERNS) {
    apply(pattern, '[FILTERED]', 'Instruction override attempt')
  }
  for (const pattern of ENCODED_INJECTION_PATTERNS) {
    apply(pattern, '[FILTERED]', 'Encoded injection')
  }
  return out
}

/**
 * Sanitize user-supplied text by stripping injection patterns.
 *
 * ChatML tags and role prefixes are removed outright; instruction-override
 * phrases and known base64 payloads are replaced with `[FILTERED]`. The strip
 * pass is repeated until the text stops changing, so payloads that only appear
 * once an inner match is removed (e.g. `<|im_<|im_end|>start|>` or
 * `system: system: …`) are caught too. Text still changing after a fixed number
 * of passes can only be deliberately nested, and is replaced wholesale with
 * `[FILTERED]` (fail closed).
 *
 * Runs of 3+ newlines are collapsed to two and the result is trimmed; this
 * applies to every input, not only to text where something was stripped.
 *
 * @param text - Raw user-supplied text. Empty or non-string input yields an empty result.
 * @returns The sanitized text, whether it differs from the trimmed input, and
 *   the distinct labels of the pattern families that matched (first-seen order).
 */
export function sanitizeUserInput(text: string): SanitizationResult {
  if (!text || typeof text !== 'string') {
    return { sanitized: '', wasModified: false, detectedPatterns: [] }
  }

  // Benign text settles after one or two passes; only hostile nesting keeps changing.
  const maxPasses = 20
  const detected = new Set<string>()

  let sanitized = text
  let converged = false
  for (let pass = 0; pass < maxPasses; pass++) {
    const next = stripInjectionPatternsOnce(sanitized, detected)
    if (next === sanitized) {
      converged = true
      break
    }
    sanitized = next
  }

  if (!converged) {
    detected.add('Nested injection')
    sanitized = '[FILTERED]'
  }

  // Trim excessive whitespace left by removals
  sanitized = sanitized.replace(/\n{3,}/g, '\n\n').trim()

  const wasModified = sanitized !== text.trim()
  const detectedPatterns = [...detected]

  if (wasModified && detectedPatterns.length > 0) {
    console.warn(
      `[PromptGuard] Detected injection patterns in user input: ${detectedPatterns.join(', ')}`
    )
  }

  return { sanitized, wasModified, detectedPatterns }
}

/**
 * Quick check: does the text contain any injection pattern?
 *
 * Cheaper than {@link sanitizeUserInput} when only a boolean is needed: it
 * stops at the first match and builds no output string.
 *
 * Matching uses `String.prototype.search`, which scans from index 0 and
 * restores the regex's `lastIndex` afterwards. The shared global patterns are
 * therefore left exactly as found; an early `return` after `RegExp#test`
 * would instead leave `lastIndex` advanced and make later checks miss matches.
 *
 * @param text - Text to inspect. Empty or non-string input returns `false`.
 * @returns `true` if any ChatML, role, instruction-override, or encoded pattern matches.
 */
export function containsInjectionPatterns(text: string): boolean {
  if (!text || typeof text !== 'string') return false

  const patterns: readonly RegExp[] = [
    CHATML_PATTERN,
    ROLE_IMPERSONATION_PATTERN,
    ...INSTRUCTION_OVERRIDE_PATTERNS,
    ...ENCODED_INJECTION_PATTERNS,
  ]
  return patterns.some((pattern) => text.search(pattern) !== -1)
}

/**
 * Sanitize every string value in a criteria/config object.
 *
 * Walks nested objects and arrays, running each string through
 * {@link sanitizeUserInput}. Non-string leaves (numbers, booleans, null, …)
 * are passed through unchanged. The input is not mutated; arrays and objects
 * are rebuilt.
 *
 * Any non-array object is rebuilt from its own enumerable string-keyed
 * properties, so class instances, Dates, Maps, etc. come back as plain objects.
 * NOTE(review): presumably callers only pass JSON-like data — confirm.
 *
 * @param obj - The criteria object to sanitize.
 * @returns The sanitized copy and the distinct pattern labels detected anywhere in it.
 */
export function sanitizeCriteriaObject(
  obj: Record<string, unknown>
): { sanitized: Record<string, unknown>; detectedPatterns: string[] } {
  const allDetected: string[] = []

  function processValue(value: unknown): unknown {
    if (typeof value === 'string') {
      const result = sanitizeUserInput(value)
      allDetected.push(...result.detectedPatterns)
      return result.sanitized
    }
    if (Array.isArray(value)) {
      return value.map((item) => processValue(item))
    }
    if (value !== null && typeof value === 'object') {
      const processed: Record<string, unknown> = {}
      for (const [key, child] of Object.entries(value)) {
        // defineProperty rather than `processed[key] = …`: an own "__proto__"
        // key (JSON.parse creates one) would otherwise hit the inherited
        // setter, replace processed's prototype, and drop the key from output.
        Object.defineProperty(processed, key, {
          value: processValue(child),
          enumerable: true,
          writable: true,
          configurable: true,
        })
      }
      return processed
    }
    return value
  }

  const sanitized = processValue(obj) as Record<string, unknown>
  return { sanitized, detectedPatterns: [...new Set(allDetected)] }
}