Backend: JWT middleware validates Clerk tokens on every request, extracts org ID from claims, enforces org-scoped queries via Supabase RLS. Frontend: ClerkProvider wraps the app, auth gate blocks unauthenticated access, UserButton in header, token injected into every API call. Supabase production wired to trust Clerk JWTs via Third-Party Auth integration. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
340 lines
12 KiB
Python
340 lines
12 KiB
Python
"""Signal API — FastAPI backend for CSV ingestion, scoring, and export."""
|
|
|
|
import csv
import hmac
import io
import os
import sys
from datetime import date
from functools import lru_cache
from pathlib import Path
from typing import Optional

import jwt
from fastapi import Depends, FastAPI, File, Header, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from jwt import PyJWKClient, ExpiredSignatureError, InvalidTokenError
from pydantic import BaseModel
|
|
|
|
# Put the python-backend root (two levels above this file) at the front of
# sys.path so the `core` and `api` packages imported below resolve the same
# way locally and inside Docker. Skipped when the entry is already present.
_backend_root = Path(__file__).parent.parent
if str(_backend_root) not in sys.path:
    sys.path.insert(0, str(_backend_root))
|
|
|
|
from core.coverage_calculator import ShipmentRecord, calculate_batch
|
|
from core.audit_logger import AuditAction, log_event
|
|
from core.persistence import persist_export, persist_upload
|
|
from api.normalizer import normalize_csv
|
|
|
|
app = FastAPI(title="Signal API", version="1.0.0", docs_url="/docs")

# CORS — restricted to the origins listed in ALLOWED_ORIGINS (comma-separated;
# set it in Railway for production, e.g. to the Vercel frontend). When the
# variable is unset or empty, only the local dev servers below are allowed.
_origins_env = os.getenv("ALLOWED_ORIGINS", "")
if _origins_env:
    # Blank entries (stray commas / whitespace) are dropped.
    _allowed_origins: list[str] = [
        origin.strip() for origin in _origins_env.split(",") if origin.strip()
    ]
else:
    _allowed_origins = [
        "http://localhost:5173",
        "http://localhost:5174",
        "http://127.0.0.1:5173",
    ]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_allowed_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Auth configuration, read once at import time.
#   SIGNAL_API_KEY  — shared secret accepted through the X-API-Key header.
#   CLERK_JWKS_URL  — JWKS endpoint used to verify Clerk Bearer JWTs.
# require_auth accepts either credential; it only lets every request through
# (dev mode) when BOTH variables are unset.
_api_key = os.getenv("SIGNAL_API_KEY", "")
_clerk_jwks_url = os.getenv("CLERK_JWKS_URL", "")
|
|
|
|
|
|
@lru_cache(maxsize=1)
def _get_jwks_client() -> PyJWKClient:
    """Return the shared ``PyJWKClient`` for the configured Clerk JWKS URL.

    The client is created on the first successful call and memoised by
    ``lru_cache``, so later calls reuse the same instance. A failed call
    (missing configuration) is not cached and is re-evaluated next time.

    Raises:
        RuntimeError: If ``CLERK_JWKS_URL`` was empty or unset at import time.
    """
    if _clerk_jwks_url:
        return PyJWKClient(
            _clerk_jwks_url,
            cache_keys=True,
            cache_jwk_set=True,
            lifespan=300,
        )
    raise RuntimeError("CLERK_JWKS_URL is not set")
|
|
|
|
|
|
def verify_clerk_token(authorization: str) -> dict:
    """Verify a Clerk Bearer JWT and return its decoded claims.

    Args:
        authorization: Raw ``Authorization`` header value. It must start with
            ``"Bearer "``; the remainder (whitespace-stripped) is the token.

    Returns:
        The claims dict produced by ``jwt.decode``.

    Raises:
        HTTPException: 401 when the header is malformed, or the token is
            expired, invalid, missing ``exp``, or otherwise fails
            verification; 503 when the JWKS client cannot be built because
            ``CLERK_JWKS_URL`` is not configured.
    """
    if not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Missing or malformed Authorization header")

    token = authorization.removeprefix("Bearer ").strip()
    try:
        client = _get_jwks_client()
        signing_key = client.get_signing_key_from_jwt(token)
        return jwt.decode(
            token,
            signing_key.key,
            # Pin the algorithm so the token header cannot select a weaker one.
            algorithms=["RS256"],
            options={
                "verify_exp": True,
                "verify_nbf": True,
                # verify_exp only checks `exp` when the claim is present;
                # requiring it rejects tokens that would never expire.
                "require": ["exp"],
            },
        )
    except ExpiredSignatureError as exc:
        raise HTTPException(status_code=401, detail="Token has expired") from exc
    except InvalidTokenError as exc:
        raise HTTPException(status_code=401, detail=f"Invalid token: {exc}") from exc
    except RuntimeError as exc:
        # Raised by _get_jwks_client when CLERK_JWKS_URL is missing.
        raise HTTPException(status_code=503, detail="Auth service not configured") from exc
    except Exception as exc:
        # Anything else (e.g. JWKS lookup problems) is reported as a generic
        # 401 so internal details are not exposed to the caller.
        # NOTE(review): no audience/issuer/azp check is performed here —
        # confirm whether one is needed for this Clerk instance.
        raise HTTPException(status_code=401, detail="Token verification failed") from exc
|
|
|
|
|
|
def require_auth(
    authorization: str = Header(default=""),
    x_api_key: str = Header(default=""),
) -> dict:
    """
    Accept either a Clerk JWT (Authorization: Bearer) or a direct API key (X-Api-Key).

    Checks are tried in this order:

    1. API key — when ``SIGNAL_API_KEY`` is configured and the header matches.
    2. Clerk JWT — when the Authorization header starts with ``"Bearer "``.
    3. Dev mode — when neither ``SIGNAL_API_KEY`` nor ``CLERK_JWKS_URL`` is set.

    Returns:
        A claims dict. The API-key path returns a synthetic service-account
        dict, dev mode returns a synthetic ``"dev"`` dict, and the JWT path
        returns the decoded token claims.

    Raises:
        HTTPException: 401 when no accepted credential is supplied; errors
            raised by ``verify_clerk_token`` propagate unchanged.
    """
    # Fast path: API key (for direct access / testing).
    # compare_digest is constant-time, so response timing does not reveal how
    # much of a guessed key matched. Bytes are compared because the str form
    # of compare_digest raises on non-ASCII input.
    if _api_key and hmac.compare_digest(
        x_api_key.encode("utf-8"), _api_key.encode("utf-8")
    ):
        return {"sub": "service-account", "via": "api_key"}

    # Clerk JWT path (frontend / production)
    if authorization.startswith("Bearer "):
        return verify_clerk_token(authorization)

    # Dev mode: no keys set — allow all
    if not _api_key and not _clerk_jwks_url:
        return {"sub": "dev", "via": "open"}

    raise HTTPException(status_code=401, detail="Authentication required")
|
|
|
|
# Device-type key -> human-readable device name. Keys not listed here are
# displayed as-is by _to_record_out.
DEVICE_DISPLAY = {
    "dexcom_g7": "Dexcom G7",
    "dexcom_g6": "Dexcom G6",
    "freestyle_libre_2": "FreeStyle Libre 2",
    "freestyle_libre_3": "FreeStyle Libre 3",
    "omnipod_5": "Omnipod 5",
}

# Flag value -> short status label shown for a record.
FLAG_LABELS = {
    "OUT_OF_COVERAGE": "Supply Lapsed",
    "VISIT_DUE": "Renewal Due",
    "REFILL_WINDOW": "Resupply Ready",
    "OK": "Active",
}

# Flag value -> recommended next action. Same key set as FLAG_LABELS.
FLAG_ACTIONS = {
    "OUT_OF_COVERAGE": "Contact Prescriber",
    "VISIT_DUE": "Request Renewal",
    "REFILL_WINDOW": "Initiate Resupply",
    "OK": "No action needed",
}
|
|
|
|
|
|
class RecordOut(BaseModel):
    """One scored record, as returned by /api/upload and accepted by /api/export."""

    patient_id: str
    device_type: str
    # Human-readable device name; falls back to device_type when unmapped.
    device_display: str
    payer: str
    component: str
    # Negative values are reported as "lapsed N days ago" by _build_reason.
    days_until_coverage_end: int
    # None when no visit-due date is available for the record.
    days_until_visit_due: Optional[int] = None
    # Flag value as a plain string (e.g. "OUT_OF_COVERAGE", "OK").
    flag: str
    priority_score: int
    # ISO-8601 date string.
    coverage_end_date: str
    # ISO-8601 date string, or None when there is no visit-due date.
    next_visit_due_date: Optional[str] = None
    # Recommended action looked up from FLAG_ACTIONS ("Review" if unmapped).
    action: str
    # Status label looked up from FLAG_LABELS (the raw flag if unmapped).
    status_label: str
    # Sentence explaining the flag, produced by _build_reason.
    reason: str
    rule_version: str
|
|
|
|
|
|
class UploadResponse(BaseModel):
    """Response body of POST /api/upload."""

    # Every scored record from the uploaded file.
    records: list[RecordOut]
    # Number of entries in `records`.
    total: int
    # Number of skip reasons reported by the CSV normaliser.
    skipped: int
    # Skip reasons; upload_csv truncates this list to the first 20.
    skipped_reasons: list[str]
    # Per-flag counts as built by _compute_stats.
    stats: dict
    # Column-mapping summary returned by normalize_csv.
    mapping_summary: dict
    # Identifier returned by persist_upload; may be None.
    batch_id: Optional[str] = None
|
|
|
|
|
|
def _build_reason(flag_val: str, days_until_end: int, days_until_visit: Optional[int]) -> str:
    """Build the one-sentence explanation that accompanies a record's flag.

    Args:
        flag_val: Flag as a string ("OUT_OF_COVERAGE", "VISIT_DUE",
            "REFILL_WINDOW"); any other value gets the "on track" wording.
        days_until_end: Days until coverage ends; its absolute value is used
            for the lapsed message.
        days_until_visit: Days until the qualifying visit is due, or None
            when unknown. Zero or negative is worded as overdue.

    Returns:
        The explanation text, with "day"/"days" matched to the count.
    """

    def _unit(count: int) -> str:
        # Singular only for exactly 1.
        return "day" if count == 1 else "days"

    if flag_val == "OUT_OF_COVERAGE":
        lapsed = abs(days_until_end)
        return f"Supply lapsed {lapsed} {_unit(lapsed)} ago. Prescriber contact required before next shipment."

    if flag_val == "VISIT_DUE":
        if days_until_visit is None:
            return "Qualifying visit renewal required. Confirm documentation before resupply."
        if days_until_visit <= 0:
            late = abs(days_until_visit)
            return f"Qualifying visit overdue by {late} {_unit(late)}. Confirm documentation immediately."
        return f"Qualifying visit due in {days_until_visit} {_unit(days_until_visit)}. Confirm visit documentation before resupply."

    if flag_val == "REFILL_WINDOW":
        return f"Coverage ends in {days_until_end} {_unit(days_until_end)}. Patient is within resupply window — initiate shipment now."

    # Default wording for "OK" and any unrecognised flag.
    return f"Coverage on track. Resupply window opens in approximately {days_until_end} {_unit(days_until_end)}."
|
|
|
|
|
|
def _to_record_out(r) -> RecordOut:
    """Convert one scored result (as produced by ``calculate_batch``) to a ``RecordOut``.

    The flag is reduced to a plain string, dates are rendered in ISO format,
    and the display name, action, status label and reason are derived from
    the module-level lookup tables and ``_build_reason``.
    """
    # Enum-like flags expose `.value`; anything else is stringified.
    flag = r.flag
    if hasattr(flag, "value"):
        flag_val = flag.value
    else:
        flag_val = str(flag)

    days_left = r.days_until_coverage_end
    days_to_visit = r.days_until_visit_due

    # The visit-due date is optional on the scored result.
    visit_due = r.next_visit_due_date
    visit_due_iso = visit_due.isoformat() if visit_due else None

    return RecordOut(
        patient_id=r.patient_id,
        device_type=r.device_type,
        # Unknown device types are shown under their raw key.
        device_display=DEVICE_DISPLAY.get(r.device_type, r.device_type),
        payer=r.payer,
        component=r.component,
        days_until_coverage_end=days_left,
        days_until_visit_due=days_to_visit,
        flag=flag_val,
        priority_score=r.priority_score,
        coverage_end_date=r.coverage_end_date.isoformat(),
        next_visit_due_date=visit_due_iso,
        action=FLAG_ACTIONS.get(flag_val, "Review"),
        status_label=FLAG_LABELS.get(flag_val, flag_val),
        reason=_build_reason(flag_val, days_left, days_to_visit),
        rule_version=r.rule_version,
    )
|
|
|
|
|
|
def _compute_stats(records: list[RecordOut]) -> dict:
    """Count how many records carry each flag.

    Returns:
        A dict with ``total`` (all records), one count per known flag, and
        ``prescriber_action`` (lapsed + renewal-due). Records with any other
        flag count toward ``total`` only.
    """
    tally = {"OUT_OF_COVERAGE": 0, "VISIT_DUE": 0, "REFILL_WINDOW": 0, "OK": 0}
    for record in records:
        if record.flag in tally:
            tally[record.flag] += 1

    lapsed = tally["OUT_OF_COVERAGE"]
    renewal = tally["VISIT_DUE"]
    return {
        "total": len(records),
        "supply_lapsed": lapsed,
        "renewal_due": renewal,
        "resupply_ready": tally["REFILL_WINDOW"],
        "active": tally["OK"],
        "prescriber_action": lapsed + renewal,
    }
|
|
|
|
|
|
@app.get("/health")
def health():
    """Liveness probe: return a fixed payload without touching any dependency."""
    payload = {"status": "ok", "service": "signal-api", "version": "1.0.0"}
    return payload
|
|
|
|
|
|
@app.get("/health/db")
def health_db():
    """Database probe: report whether a Supabase client exists and can run a query.

    Returns one of three payloads (always as a normal response body):
      * ``status: "unavailable"`` plus flags showing which Supabase env vars
        are set, when ``get_client()`` yields nothing;
      * ``status: "ok"`` with ``org_count`` after a successful query;
      * ``status: "error"`` with the exception text when the query raises.
    """
    # Imported lazily, inside the handler, as in the original code.
    from core.supabase_client import get_client

    db = get_client()
    if not db:
        return {
            "status": "unavailable",
            "supabase_url_set": bool(os.getenv("SUPABASE_URL")),
            "service_key_set": bool(os.getenv("SUPABASE_SERVICE_KEY")),
        }

    try:
        # limit(1) keeps the probe cheap, so org_count is at most 1.
        probe = db.table("organizations").select("id").limit(1).execute()
        return {"status": "ok", "org_count": len(probe.data)}
    except Exception as exc:
        # NOTE(review): the raw exception text is returned to the caller and
        # this route has no auth dependency — confirm that is acceptable.
        return {"status": "error", "detail": str(exc)}
|
|
|
|
|
|
@app.post("/api/upload", response_model=UploadResponse)
async def upload_csv(
    file: UploadFile = File(...),
    claims: dict = Depends(require_auth),
):
    """Ingest an uploaded CSV, score every processable row, and persist the batch.

    Args:
        file: The uploaded file; its name must end in ``.csv`` (any case).
        claims: Auth claims supplied by ``require_auth``.

    Returns:
        An ``UploadResponse`` with the scored records, skip information,
        per-flag stats, the column-mapping summary and the persisted batch id.

    Raises:
        HTTPException: 400 when the file is not a ``.csv``; 422 when the
            normaliser yields no processable rows.
    """
    filename = file.filename or "unknown"
    # Case-insensitive so "EXPORT.CSV" is accepted as well.
    if not filename.lower().endswith(".csv"):
        raise HTTPException(status_code=400, detail="File must be a .csv")

    # Record the authenticated subject in the audit trail instead of a fixed
    # placeholder. NOTE(review): assumes log_event's third positional
    # argument is the acting user id — confirm against core.audit_logger.
    actor = str(claims.get("sub") or "demo_user")

    content = await file.read()
    try:
        # utf-8-sig reads plain UTF-8 and also drops a leading BOM, so the
        # first header is not prefixed with "\ufeff".
        text = content.decode("utf-8-sig")
    except UnicodeDecodeError:
        # latin-1 maps every byte, so this fallback cannot fail.
        text = content.decode("latin-1")

    records, skipped_reasons, mapping_summary = normalize_csv(text)

    if not records:
        log_event(AuditAction.CSV_INGEST, filename, actor,
                  "failure", "0.0.0.0", detail="No processable rows")
        raise HTTPException(
            status_code=422,
            detail={
                "message": "No processable rows found in the uploaded file.",
                "skipped": skipped_reasons[:10],
                "mapping_summary": mapping_summary,
            },
        )

    # NOTE(review): scoring and persistence are called without await, so they
    # run on the event loop — consider a threadpool if they do blocking I/O.
    results = calculate_batch(records, as_of=date.today())
    out = [_to_record_out(r) for r in results]

    log_event(AuditAction.CSV_INGEST, filename, actor,
              "success", "0.0.0.0", detail=f"{len(out)} records scored")

    # Organization id: prefer the nested `o.id` claim; fall back to a flat
    # `org_id` claim. NOTE(review): `org_id` is presumed to be Clerk's older
    # session-token layout — confirm against the Clerk instance in use.
    org_claim = claims.get("o")
    clerk_org_id = org_claim.get("id") if isinstance(org_claim, dict) else None
    if not clerk_org_id:
        clerk_org_id = claims.get("org_id")

    batch_id = persist_upload(
        filename=filename,
        content_bytes=content,  # original bytes, untouched by decoding
        shipment_records=records,
        coverage_results=results,
        skipped_count=len(skipped_reasons),
        mapping_summary=mapping_summary,
        clerk_org_id=clerk_org_id,
    )

    return UploadResponse(
        records=out,
        total=len(out),
        skipped=len(skipped_reasons),
        # Only the first 20 reasons are returned; `skipped` has the full count.
        skipped_reasons=skipped_reasons[:20],
        stats=_compute_stats(out),
        mapping_summary=mapping_summary,
        batch_id=batch_id,
    )
|
|
|
|
|
|
class ExportRequest(BaseModel):
    """Request body of POST /api/export."""

    # Scored records to write into the work-queue CSV, in output order.
    records: list[RecordOut]
    # Batch the export belongs to; forwarded to persist_export. May be None.
    batch_id: Optional[str] = None
|
|
|
|
|
|
def _csv_safe(cell):
    """Neutralise spreadsheet formula injection for a single CSV cell.

    Strings that begin with a character a spreadsheet treats as the start of
    a formula are prefixed with a single quote; other values pass through
    unchanged.
    """
    if isinstance(cell, str) and cell.startswith(("=", "+", "-", "@", "\t", "\r")):
        return "'" + cell
    return cell


@app.post("/api/export")
async def export_work_queue(
    body: ExportRequest,
    claims: dict = Depends(require_auth),
):
    """Generate a downloadable work-queue CSV from a list of scored records.

    Args:
        body: Records to export plus the optional batch id they belong to.
        claims: Auth claims supplied by ``require_auth``.

    Returns:
        A ``StreamingResponse`` carrying the UTF-8 CSV as an attachment named
        ``signal-work-queue-<today>.csv``.
    """
    records = body.records

    buffer = io.StringIO()
    writer = csv.writer(buffer)
    writer.writerow([
        "Patient ID",
        "Device",
        "Payer",
        "Status",
        "Priority Score",
        "Days Until Resupply End",
        "Next Visit Due",
        "Recommended Action",
        "Resupply End Date",
        "Reason",
    ])
    for r in records:
        row = (
            r.patient_id,
            r.device_display,
            r.payer,
            r.status_label,
            r.priority_score,
            r.days_until_coverage_end,
            r.next_visit_due_date or "",
            r.action,
            r.coverage_end_date,
            r.reason,
        )
        # The records come from the request body, so every string cell is
        # treated as untrusted before it reaches a spreadsheet.
        writer.writerow([_csv_safe(cell) for cell in row])

    export_filename = f"signal-work-queue-{date.today().isoformat()}.csv"

    # Record the authenticated subject in the audit trail instead of a fixed
    # placeholder. NOTE(review): assumes log_event's third positional
    # argument is the acting user id — confirm against core.audit_logger.
    actor = str(claims.get("sub") or "demo_user")
    log_event(AuditAction.WORKLIST_EXPORT, export_filename, actor,
              "success", "0.0.0.0", detail=f"{len(records)} records exported")
    persist_export(batch_id=body.batch_id, filename=export_filename, row_count=len(records))

    return StreamingResponse(
        io.BytesIO(buffer.getvalue().encode("utf-8")),
        media_type="text/csv",
        headers={
            # Reuse export_filename so the header and the audit/persist
            # records always name the same file.
            "Content-Disposition": f"attachment; filename={export_filename}"
        },
    )
|