import json
import os
from datetime import datetime, timezone

import boto3
import requests
# Credentials and endpoint for the ATHENA API. Both are read from the
# environment so a real secret never has to be committed to source; the literal
# fallbacks preserve the previous values when the variables are unset.
ATHENA_API_KEY = os.environ.get("ATHENA_API_KEY", "sk_live_xxxxx")
# Trailing "/" is stripped because request URLs are built as f"{ATHENA_API_URL}/...".
ATHENA_API_URL = os.environ.get(
    "ATHENA_API_URL", "https://api.athenatrust.ai/v1"
).rstrip("/")
def _facet_groups(section, protected_attribute):
    """Return the facet-value groups of one Clarify bias section for one facet.

    Each returned group is a dict carrying ``value_or_threshold`` and ``metrics``.
    Two layouts are accepted:

    * ``facets`` as a dict keyed by facet name, each mapping to a list of groups
      (the layout described in the SageMaker Clarify analysis-results docs);
    * ``facets`` as a flat list of groups, each tagged with ``name_or_index``.
    """
    facets = section.get("facets", [])
    if isinstance(facets, dict):
        return list(facets.get(protected_attribute, []))
    return [
        group
        for group in facets
        if group.get("name_or_index") == protected_attribute
    ]


def _send_stage_metrics(analysis, section_key, stage, model_id, protected_attribute):
    """Forward every metric of one bias section (e.g. pre-training) to ATHENA.

    Returns one result dict per metric, in the order Clarify listed them.
    Returns an empty list when ``section_key`` is absent or empty in ``analysis``.
    """
    section = analysis.get(section_key)
    if not section:
        return []

    results = []
    for group in _facet_groups(section, protected_attribute):
        # NOTE(review): Clarify's value_or_threshold identifies the facet value
        # being analysed; confirm ATHENA's "privilegedGroup" is the intended slot.
        privileged_group = str(group.get("value_or_threshold", "1"))
        for metric in group.get("metrics", []):
            raw_value = metric.get("value")
            if raw_value is None:
                # A metric without a numeric value (e.g. one Clarify could not
                # compute) cannot be normalized or thresholded; record it as
                # skipped rather than aborting the whole upload.
                results.append({
                    "metric": clarify_to_athena_metric(metric["name"]),
                    "stage": stage,
                    "original_value": None,
                    "status": "skipped",
                    "signalId": None,
                })
                continue
            results.append(
                send_metric_to_athena(
                    metric_name=clarify_to_athena_metric(metric["name"]),
                    metric_value=normalize_clarify_metric(metric["name"], raw_value),
                    original_value=raw_value,
                    description=metric.get("description", ""),
                    model_id=model_id,
                    protected_attribute=protected_attribute,
                    privileged_group=privileged_group,
                    stage=stage,
                )
            )
    return results


def send_clarify_results_to_athena(
    clarify_output_s3_uri: str,
    model_id: str,
    protected_attribute: str,
):
    """
    Parse Clarify processing job output and send to ATHENA.

    Reads ``analysis.json`` from the Clarify output location. It then sends every
    pre-training and post-training bias metric for ``protected_attribute`` to
    ATHENA, one signal per metric.

    Args:
        clarify_output_s3_uri: S3 URI of the Clarify analysis output "directory"
            (with or without a trailing slash).
        model_id: Your model identifier.
        protected_attribute: Name of the facet (for example, gender).

    Returns:
        List of per-metric result dicts. Pre-training results come first, then
        post-training results. Metrics without a value get status "skipped".
    """
    bucket, prefix = parse_s3_uri(clarify_output_s3_uri)
    # Normalise the prefix so "s3://b/out/" and "s3://b/out" both resolve to
    # "out/analysis.json", and a bucket-only URI resolves to "analysis.json".
    prefix = prefix.rstrip("/")
    object_key = f"{prefix}/analysis.json" if prefix else "analysis.json"

    s3 = boto3.client("s3")
    response = s3.get_object(Bucket=bucket, Key=object_key)
    analysis = json.loads(response["Body"].read().decode("utf-8"))

    results = []
    for section_key, stage in (
        ("pre_training_bias_metrics", "pre_training"),
        ("post_training_bias_metrics", "post_training"),
    ):
        results.extend(
            _send_stage_metrics(analysis, section_key, stage, model_id, protected_attribute)
        )
    return results
def send_metric_to_athena(
    metric_name: str,
    metric_value: float,
    original_value: float,
    description: str,
    model_id: str,
    protected_attribute: str,
    privileged_group: str,
    stage: str,
    timeout: float = 30.0,
):
    """Send a single Clarify metric to ATHENA as a model-fairness signal.

    Args:
        metric_name: ATHENA metric name; also selects the pass/fail threshold.
        metric_value: Value reported as ``metricValue``. The in-file caller
            passes the [0, 1]-normalized value.
        original_value: Raw value as emitted by Clarify. The pass/fail check
            uses this value, because the thresholds are in Clarify's native units.
        description: Human-readable metric description.
        model_id: Identifier of the model the metric belongs to.
        protected_attribute: Facet the metric was computed for.
        privileged_group: Facet value/threshold, sent as ``privilegedGroup``.
        stage: Training stage label, e.g. ``"pre_training"``.
        timeout: Seconds to wait for the ATHENA API before ``requests`` gives up.

    Returns:
        dict with ``metric``, ``stage``, ``original_value``, ``status``
        (``"success"`` only on HTTP 201, otherwise ``"failed"``), and ``signalId``
        (taken from the response body when present, else None).

    Raises:
        requests.RequestException: on connection failure or timeout.
    """
    # Thresholds in Clarify's native units: disparate impact passes at or above
    # 0.8 (four-fifths); every other metric passes when |value| <= threshold.
    thresholds = {
        "demographic_parity": 0.1,
        "disparate_impact": 0.8,
        "conditional_acceptance": 0.1,
        "treatment_equality": 0.1,
        "flip_rate": 0.1,
        "class_imbalance": 0.1,
        "dpl": 0.1,
    }
    threshold = thresholds.get(metric_name, 0.1)

    # Compare the raw value, not ``metric_value``. After the [0, 1] rescaling,
    # a perfectly fair score (0 for differences, 1.0 for DI) becomes 0.5 and
    # would fail every threshold above.
    if metric_name == "disparate_impact":
        passes_threshold = original_value >= threshold
    else:
        passes_threshold = abs(original_value) <= threshold

    payload = {
        "externalToolId": "aws_clarify",
        "modelId": model_id,
        "metricName": metric_name,
        "metricValue": metric_value,
        "threshold": threshold,
        "passesThreshold": passes_threshold,
        "protectedAttribute": protected_attribute,
        "privilegedGroup": privileged_group,
        "rawPayload": {
            "original_value": original_value,
            "description": description,
            "stage": stage,
            "source": "sagemaker_clarify",
        },
        "signalTimestamp": datetime.now(timezone.utc).isoformat(),
    }

    response = requests.post(
        f"{ATHENA_API_URL}/model-fairness-signals",
        headers={
            "Authorization": f"Bearer {ATHENA_API_KEY}",
            "Content-Type": "application/json",
        },
        json=payload,
        # Without a timeout a stalled connection blocks the caller forever.
        timeout=timeout,
    )

    succeeded = response.status_code == 201
    signal_id = None
    if succeeded:
        try:
            body = response.json()
        except ValueError:
            # Accepted (201) but with an empty or non-JSON body: no id to report.
            body = None
        if isinstance(body, dict):
            signal_id = body.get("signalId")

    return {
        "metric": metric_name,
        "stage": stage,
        "original_value": original_value,
        "status": "success" if succeeded else "failed",
        "signalId": signal_id,
    }
def clarify_to_athena_metric(clarify_name: str) -> str:
    """Translate a SageMaker Clarify metric abbreviation into ATHENA's name for it.

    The lookup is exact and case-sensitive. Any abbreviation without a known
    translation is reported as ``"custom"``.
    """
    athena_names = dict(
        CI="class_imbalance",
        DPL="dpl",
        DPPL="demographic_parity",
        DI="disparate_impact",
        DCAcc="conditional_acceptance",
        TE="treatment_equality",
        FT="flip_rate",
        AD="accuracy_difference",
        CDDPL="conditional_demographic_disparity",
    )
    try:
        return athena_names[clarify_name]
    except KeyError:
        return "custom"
def normalize_clarify_metric(name: str, value: float) -> float:
    """Normalize Clarify metrics to the closed 0 to 1 range.

    Disparate Impact (``"DI"``) is a ratio: it is halved, so 1.0 maps to 0.5
    and ratios of 2 or more saturate at 1.0. Every other metric is treated as
    a signed score in [-1, 1] and mapped linearly so that 0 becomes 0.5.

    The result is always clamped to [0, 1], matching the documented range even
    for metrics that can exceed [-1, 1].
    """
    if name == "DI":  # Disparate Impact
        scaled = value / 2
    else:
        # Most Clarify metrics are in the -1 to 1 range; unbounded ones are
        # clamped below instead of leaking outside [0, 1].
        scaled = (value + 1) / 2
    return min(1.0, max(0.0, scaled))
def parse_s3_uri(uri: str):
    """Parse an S3 URI into ``(bucket, key)``.

    Only a leading ``s3://`` scheme is removed, so keys that themselves contain
    "s3://" are preserved. A URI without the scheme is read as ``bucket/key``.
    The key is ``""`` when the URI names only a bucket.
    """
    scheme = "s3://"
    if uri.startswith(scheme):
        uri = uri[len(scheme):]
    bucket, _, key = uri.partition("/")
    return bucket, key