Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make red teamer surface errors #1309

Merged
merged 4 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 67 additions & 173 deletions deepeval/red_teaming/attack_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from deepeval.red_teaming.types import (
AttackEnhancement,
NonRemoteVulnerability,
VulnerabilityType,
CallbackType,
)
Expand Down Expand Up @@ -72,6 +71,7 @@ def generate_attacks(
attacks_per_vulnerability_type: int,
vulnerabilities: List[BaseVulnerability],
attack_enhancements: Dict[AttackEnhancement, float],
ignore_errors: bool,
) -> List[Attack]:
# Generate unenhanced attacks for each vulnerability
base_attacks: List[Attack] = []
Expand All @@ -85,8 +85,9 @@ def generate_attacks(
for vulnerability in pbar:
base_attacks.extend(
self.generate_base_attacks(
attacks_per_vulnerability_type,
vulnerability,
attacks_per_vulnerability_type=attacks_per_vulnerability_type,
vulnerability=vulnerability,
ignore_errors=ignore_errors,
)
)

Expand All @@ -111,6 +112,7 @@ def generate_attacks(
target_model_callback=target_model_callback,
base_attack=base_attack,
attack_enhancement=sampled_enhancement,
ignore_errors=ignore_errors,
)
enhanced_attacks.append(enhanced_attack)

Expand All @@ -123,10 +125,11 @@ async def a_generate_attacks(
attacks_per_vulnerability_type: int,
vulnerabilities: List[BaseVulnerability],
attack_enhancements: Dict[AttackEnhancement, float],
max_concurrent_tasks: int = 10,
ignore_errors: bool,
max_concurrent: int = 10,
) -> List[Attack]:
# Create a semaphore to control the number of concurrent tasks
semaphore = asyncio.Semaphore(max_concurrent_tasks)
semaphore = asyncio.Semaphore(max_concurrent)

# Generate unenhanced attacks for each vulnerability
base_attacks: List[Attack] = []
Expand All @@ -141,7 +144,9 @@ async def a_generate_attacks(
async def throttled_generate_base_attack(vulnerability):
async with semaphore: # Throttling applied here
result = await self.a_generate_base_attacks(
attacks_per_vulnerability_type, vulnerability
attacks_per_vulnerability_type=attacks_per_vulnerability_type,
vulnerability=vulnerability,
ignore_errors=ignore_errors,
)
pbar.update(1)
return result
Expand Down Expand Up @@ -175,6 +180,7 @@ async def throttled_attack_enhancement(base_attack):
target_model_callback=target_model_callback,
base_attack=base_attack,
attack_enhancement=sampled_enhancement,
ignore_errors=ignore_errors,
)
pbar.update(1)
return result
Expand Down Expand Up @@ -204,34 +210,29 @@ def generate_base_attacks(
self,
attacks_per_vulnerability_type: int,
vulnerability: BaseVulnerability,
max_retries: int = 5,
ignore_errors: bool,
) -> List[Attack]:
base_attacks: List[Attack] = []
# Remote vulnerabilities
if not isinstance(BaseVulnerability, NonRemoteVulnerability):
if not is_confident():
raise Exception(
f"To generate attacks for '{vulnerability.get_name()}', login to Confident AI by running `deepeval login`"
)

for vulnerability_type in vulnerability.get_types():
try:
remote_attacks = self.generate_remote_attack(
self.purpose,
vulnerability_type,
attacks_per_vulnerability_type,
)
base_attacks.extend(
[
Attack(
vulnerability=vulnerability.get_name(),
vulnerability_type=vulnerability_type,
input=remote_attack,
)
for remote_attack in remote_attacks
]
)
except:
for vulnerability_type in vulnerability.get_types():
try:
remote_attacks = self.generate_remote_attack(
self.purpose,
vulnerability_type,
attacks_per_vulnerability_type,
)
base_attacks.extend(
[
Attack(
vulnerability=vulnerability.get_name(),
vulnerability_type=vulnerability_type,
input=remote_attack,
)
for remote_attack in remote_attacks
]
)
except:
if ignore_errors:
for _ in range(attacks_per_vulnerability_type):
base_attacks.append(
Attack(
Expand All @@ -240,96 +241,36 @@ def generate_base_attacks(
error="Error generating aligned attacks.",
)
)

# Aligned vulnerabilities: LLMs can generate
else:
for vulnerability_type in vulnerability.get_types():
prompt = RedTeamSynthesizerTemplate.generate_attacks(
attacks_per_vulnerability_type,
vulnerability_type,
self.purpose,
)

# Generate attacks with retries
for i in range(max_retries):
try:
res: SyntheticDataList = self._generate_schema(
prompt, SyntheticDataList
)
compliance_prompt = (
RedTeamSynthesizerTemplate.non_compliant(
res.model_dump()
)
)
compliance_res: ComplianceData = self._generate_schema(
compliance_prompt, ComplianceData
)

if not compliance_res.non_compliant:
base_attacks.extend(
Attack(
input=attack.input,
vulnerability=vulnerability.get_name(),
vulnerability_type=vulnerability_type,
)
for attack in res.data
)
break

if i == max_retries - 1:
base_attacks = [
Attack(
vulnerability=vulnerability.get_name(),
vulnerability_type=vulnerability_type,
error="Error generating compliant attacks.",
)
for _ in range(attacks_per_vulnerability_type)
]
except:
if i == max_retries - 1:
base_attacks = [
Attack(
vulnerability=vulnerability.get_name(),
vulnerability_type=vulnerability_type,
error="Error generating aligned attacks.",
)
for _ in range(attacks_per_vulnerability_type)
]
else:
raise
return base_attacks

async def a_generate_base_attacks(
self,
attacks_per_vulnerability_type: int,
vulnerability: BaseVulnerability,
max_retries: int = 5,
ignore_errors: bool,
) -> List[Attack]:
base_attacks: List[Attack] = []

# Remote vulnerabilities
if not isinstance(vulnerability, NonRemoteVulnerability):
if not is_confident():
raise Exception(
f"To generate attacks for '{vulnerability.get_name()}', login to Confident AI by running `deepeval login`"
for vulnerability_type in vulnerability.get_types():
try:
remote_attacks = self.generate_remote_attack(
self.purpose,
vulnerability_type,
attacks_per_vulnerability_type,
)

for vulnerability_type in vulnerability.get_types():
try:
remote_attacks = self.generate_remote_attack(
self.purpose,
vulnerability_type,
attacks_per_vulnerability_type,
)
base_attacks.extend(
[
Attack(
vulnerability=vulnerability.get_name(),
vulnerability_type=vulnerability_type,
input=remote_attack,
)
for remote_attack in remote_attacks
]
)
except:
base_attacks.extend(
[
Attack(
vulnerability=vulnerability.get_name(),
vulnerability_type=vulnerability_type,
input=remote_attack,
)
for remote_attack in remote_attacks
]
)
except:
if ignore_errors:
for _ in range(attacks_per_vulnerability_type):
base_attacks.append(
Attack(
Expand All @@ -338,63 +279,8 @@ async def a_generate_base_attacks(
error="Error generating aligned attacks.",
)
)

# Aligned vulnerabilities: LLMs can generate
else:
for vulnerability_type in vulnerability.get_types():
prompt = RedTeamSynthesizerTemplate.generate_attacks(
attacks_per_vulnerability_type,
vulnerability_type,
self.purpose,
)

# Generate attacks with retries
for i in range(max_retries):
try:
res: SyntheticDataList = await self._a_generate_schema(
prompt, SyntheticDataList
)
compliance_prompt = (
RedTeamSynthesizerTemplate.non_compliant(
res.model_dump()
)
)
compliance_res: ComplianceData = (
await self._a_generate_schema(
compliance_prompt, ComplianceData
)
)

if not compliance_res.non_compliant:
base_attacks.extend(
Attack(
input=attack.input,
vulnerability=vulnerability.get_name(),
vulnerability_type=vulnerability_type,
)
for attack in res.data
)
break

if i == max_retries - 1:
base_attacks = [
Attack(
vulnerability=vulnerability.get_name(),
vulnerability_type=vulnerability_type,
error="Error generating compliant attacks.",
)
for _ in range(attacks_per_vulnerability_type)
]
except:
if i == max_retries - 1:
base_attacks = [
Attack(
vulnerability=vulnerability.get_name(),
vulnerability_type=vulnerability_type,
error="Error generating aligned attacks.",
)
for _ in range(attacks_per_vulnerability_type)
]
else:
raise
return base_attacks

##################################################
Expand All @@ -406,6 +292,7 @@ def enhance_attack(
target_model_callback: CallbackType,
base_attack: Attack,
attack_enhancement: AttackEnhancement,
ignore_errors: bool,
jailbreaking_iterations: int = 5,
):
attack_input = base_attack.input
Expand Down Expand Up @@ -478,8 +365,11 @@ def enhance_attack(
).enhance(attack_input)
base_attack.input = enhanced_attack
except:
base_attack.error = "Error enhancing attack"
return base_attack
if ignore_errors:
base_attack.error = "Error enhancing attack"
return base_attack
else:
raise

return base_attack

Expand All @@ -488,6 +378,7 @@ async def a_enhance_attack(
target_model_callback: CallbackType,
base_attack: Attack,
attack_enhancement: AttackEnhancement,
ignore_errors: bool,
jailbreaking_iterations: int = 5,
):
attack_input = base_attack.input
Expand Down Expand Up @@ -562,8 +453,11 @@ async def a_enhance_attack(
).a_enhance(attack_input)
base_attack.input = enhanced_attack
except:
base_attack.error = "Error enhancing attack"
return base_attack
if ignore_errors:
base_attack.error = "Error enhancing attack"
return base_attack
else:
raise

return base_attack

Expand Down
Loading
Loading