from __future__ import annotations import hashlib import time from datetime import datetime, timezone from typing import Any import httpx from app.adapters.iris import IrisAdapter from app.adapters.pagerduty import PagerDutyAdapter from app.adapters.shuffle import ShuffleAdapter from app.adapters.wazuh import WazuhAdapter from app.config import settings from app.repositories.mvp_repo import MvpRepository class MvpService: def __init__( self, repo: MvpRepository, wazuh_adapter: WazuhAdapter, shuffle_adapter: ShuffleAdapter, iris_adapter: IrisAdapter, pagerduty_adapter: PagerDutyAdapter, ) -> None: self.repo = repo self.wazuh_adapter = wazuh_adapter self.shuffle_adapter = shuffle_adapter self.iris_adapter = iris_adapter self.pagerduty_adapter = pagerduty_adapter def _is_off_hours(self, ts: datetime) -> bool: hour = ts.astimezone(timezone.utc).hour return hour < 6 or hour >= 20 def _safe_excerpt(self, payload: Any) -> str: text = str(payload) return text[:300] def _primary_subject(self, event: dict[str, Any]) -> str: asset = event.get("asset", {}) return str(asset.get("user") or asset.get("hostname") or "unknown") def _primary_observable(self, event: dict[str, Any]) -> str: network = event.get("network", {}) return str(network.get("domain") or network.get("src_ip") or network.get("dst_ip") or "unknown") def _incident_key(self, event: dict[str, Any]) -> str: ts = datetime.fromisoformat(event["timestamp"].replace("Z", "+00:00")).astimezone(timezone.utc) day_bucket = ts.strftime("%Y-%m-%d") raw = "|".join( [ str(event.get("event_type", "generic")), self._primary_subject(event), self._primary_observable(event), day_bucket, ] ) return hashlib.sha256(raw.encode("utf-8")).hexdigest() def _effective_severity(self, event: dict[str, Any], policy: dict[str, Any]) -> tuple[str, int, list[str]]: severity = str(event.get("severity", "medium")).lower() risk_context = event.get("risk_context", {}) network = event.get("network", {}) weights = policy.get("risk", {}).get("weights", {}) score = 0 factors: list[str] = [] allowed_country = policy.get("vpn", {}).get("allowed_country", "TH") country = str(network.get("country", "")).upper() if country and country != allowed_country: score += int(weights.get("outside_thailand", 50)) factors.append("outside_country") if risk_context.get("admin_account"): score += int(weights.get("admin", 20)) factors.append("admin_account") ts = datetime.fromisoformat(event["timestamp"].replace("Z", "+00:00")) if risk_context.get("off_hours") or self._is_off_hours(ts): score += int(weights.get("off_hours", 15)) factors.append("off_hours") if risk_context.get("first_seen_country"): score += int(weights.get("first_seen_country", 15)) factors.append("first_seen_country") thresholds = policy.get("risk", {}).get("thresholds", {}) if score >= int(thresholds.get("high", 70)): severity = "high" if severity in {"low", "medium"} else severity elif score >= int(thresholds.get("medium", 40)) and severity == "low": severity = "medium" return severity, score, factors def _is_exception(self, event: dict[str, Any], policy: dict[str, Any]) -> bool: if event.get("event_type") != "vpn_geo_anomaly": return False asset = event.get("asset", {}) user = str(asset.get("user", "")) allowed_users = set(policy.get("vpn", {}).get("exception_users", [])) return user in allowed_users def _extract_iris_case_id(self, iris_response: dict[str, Any]) -> str | None: if "case_id" in iris_response: return str(iris_response.get("case_id")) data = iris_response.get("data") if isinstance(data, dict) and "case_id" in data: return str(data.get("case_id")) return None async def ingest_incident(self, event: dict[str, Any]) -> dict[str, Any]: policy = self.repo.get_policy() incident_key = self._incident_key(event) if self._is_exception(event, policy): decision_trace = { "incident_key": incident_key, "policy_exception": True, "reason": "vpn_exception_user", } self.repo.upsert_incident(incident_key, severity="low", status="ignored_exception", iris_case_id=None) self.repo.add_event( incident_key=incident_key, event_id=event.get("event_id"), source=event.get("source", "unknown"), event_type=event.get("event_type", "generic"), raw_payload=event, decision_trace=decision_trace, ) return { "incident_key": incident_key, "action_taken": "ignored_exception", "escalation_stub_sent": False, "decision_trace": decision_trace, } effective_severity, risk_score, risk_factors = self._effective_severity(event, policy) current = self.repo.get_incident(incident_key) action_taken = "updated_case" if current else "created_case" iris_case_id = current.get("iris_case_id") if current else None if not iris_case_id: case_payload = { "case_name": event.get("title", "SOC Incident"), "case_description": event.get("description", "Generated by soc-integrator MVP"), "case_customer": event.get("payload", {}).get("case_customer", settings.iris_default_customer_id), "case_soc_id": event.get("payload", {}).get("case_soc_id", settings.iris_default_soc_id), } iris_result = await self.iris_adapter.create_case(case_payload) iris_case_id = self._extract_iris_case_id(iris_result) else: update_payload = { "case_description": f"{event.get('description', 'Updated by soc-integrator MVP')} [event_id={event.get('event_id', '')}]" } try: await self.iris_adapter.update_case(iris_case_id, update_payload) except Exception: # Keep pipeline progressing for MVP even if update path is unsupported. pass stored = self.repo.upsert_incident( incident_key=incident_key, severity=effective_severity, status="open", iris_case_id=iris_case_id, ) decision_trace = { "incident_key": incident_key, "risk_score": risk_score, "risk_factors": risk_factors, "effective_severity": effective_severity, "action_taken": action_taken, } self.repo.add_event( incident_key=incident_key, event_id=event.get("event_id"), source=event.get("source", "unknown"), event_type=event.get("event_type", "generic"), raw_payload=event, decision_trace=decision_trace, ) escalate_severities = set(policy.get("escalate_severities", ["high", "critical"])) escalation_stub_sent = False stub_response: dict[str, Any] | None = None if effective_severity in escalate_severities: escalation_payload = { "incident_key": incident_key, "title": event.get("title", "SOC Incident"), "severity": effective_severity, "source": event.get("source", "soc-integrator"), "iris_case_id": iris_case_id, "event_summary": event.get("description", ""), "timestamp": event.get("timestamp"), } try: pd_result = await self.pagerduty_adapter.create_incident(escalation_payload) escalation_stub_sent = True stub_response = {"ok": True, "data": pd_result} self.repo.add_escalation_audit( incident_key=incident_key, status_code=200, success=True, response_excerpt=self._safe_excerpt(pd_result), ) except Exception as exc: stub_response = {"ok": False, "error": str(exc)} self.repo.add_escalation_audit( incident_key=incident_key, status_code=502, success=False, response_excerpt=self._safe_excerpt(exc), ) return { "incident_key": stored["incident_key"], "action_taken": action_taken, "iris_case_id": stored.get("iris_case_id"), "escalation_stub_sent": escalation_stub_sent, "stub_response": stub_response, "decision_trace": decision_trace, } async def evaluate_ioc(self, payload: dict[str, Any]) -> dict[str, Any]: policy = self.repo.get_policy() workflow_id = str(policy.get("shuffle", {}).get("ioc_workflow_id", "")).strip() matched = True confidence = 0.7 shuffle_result: dict[str, Any] | None = None if workflow_id: shuffle_result = await self.shuffle_adapter.trigger_workflow(workflow_id, payload) if matched: event = { "source": "shuffle", "event_type": "ioc_dns" if payload.get("ioc_type") == "domain" else "ioc_ips", "event_id": payload.get("source_event", {}).get("event_id") or f"ioc-{int(time.time())}", "timestamp": datetime.now(timezone.utc).isoformat(), "severity": "medium", "title": f"IOC match: {payload.get('ioc_value', 'unknown')}", "description": "IOC evaluation result", "asset": payload.get("source_event", {}).get("asset", {}), "network": payload.get("source_event", {}).get("network", {}), "tags": ["ioc", str(payload.get("ioc_type", "unknown"))], "risk_context": {}, "raw": payload, "payload": {}, } ingest_result = await self.ingest_incident(event) else: ingest_result = {"action_taken": "rejected"} return { "matched": matched, "confidence": confidence, "shuffle": shuffle_result, "result": ingest_result, } async def evaluate_vpn(self, payload: dict[str, Any]) -> dict[str, Any]: if not payload.get("success", False): return { "risk_score": 0, "risk_factors": [], "exception_applied": False, "action_taken": "rejected", } event = { "source": "wazuh", "event_type": "vpn_geo_anomaly", "event_id": payload.get("event_id") or f"vpn-{int(time.time())}", "timestamp": payload.get("event_time") or datetime.now(timezone.utc).isoformat(), "severity": "high", "title": f"VPN login anomaly: {payload.get('user', 'unknown')}", "description": f"VPN success from {payload.get('country_code', 'unknown')} for user {payload.get('user', 'unknown')}", "asset": {"user": payload.get("user")}, "network": {"src_ip": payload.get("src_ip"), "country": payload.get("country_code")}, "tags": ["vpn", "geo-anomaly"], "risk_context": { "outside_thailand": payload.get("country_code", "").upper() != "TH", "admin_account": bool(payload.get("is_admin", False)), "off_hours": bool(payload.get("off_hours", False)), "first_seen_country": bool(payload.get("first_seen_country", False)), }, "raw": payload, "payload": {}, } ingest_result = await self.ingest_incident(event) decision_trace = ingest_result.get("decision_trace", {}) return { "risk_score": decision_trace.get("risk_score", 0), "risk_factors": decision_trace.get("risk_factors", []), "exception_applied": ingest_result.get("action_taken") == "ignored_exception", "action_taken": ingest_result.get("action_taken"), "incident_key": ingest_result.get("incident_key"), "iris_case_id": ingest_result.get("iris_case_id"), "escalation_stub_sent": ingest_result.get("escalation_stub_sent", False), } async def dependency_health(self) -> dict[str, Any]: out: dict[str, Any] = {} async def timed(name: str, fn): start = time.time() try: result = await fn() out[name] = { "status": "up", "latency_ms": round((time.time() - start) * 1000, 2), "details": result, } except Exception as exc: out[name] = { "status": "down", "latency_ms": round((time.time() - start) * 1000, 2), "error": str(exc), } await timed("wazuh", self.wazuh_adapter.auth_test) await timed("shuffle", self.shuffle_adapter.health) await timed("iris", self.iris_adapter.whoami) async def pagerduty_stub_health(): async with httpx.AsyncClient(timeout=10.0) as client: r = await client.get(settings.pagerduty_base_url) r.raise_for_status() return {"status_code": r.status_code} await timed("pagerduty_stub", pagerduty_stub_health) return out