from __future__ import annotations import hashlib import re import time from datetime import datetime, timezone from typing import Any import httpx from app.adapters.iris import IrisAdapter from app.adapters.pagerduty import PagerDutyAdapter from app.adapters.shuffle import ShuffleAdapter from app.adapters.wazuh import WazuhAdapter from app.config import settings from app.repositories.mvp_repo import MvpRepository class MvpService: def __init__( self, repo: MvpRepository, wazuh_adapter: WazuhAdapter, shuffle_adapter: ShuffleAdapter, iris_adapter: IrisAdapter, pagerduty_adapter: PagerDutyAdapter, ) -> None: self.repo = repo self.wazuh_adapter = wazuh_adapter self.shuffle_adapter = shuffle_adapter self.iris_adapter = iris_adapter self.pagerduty_adapter = pagerduty_adapter def _is_off_hours(self, ts: datetime) -> bool: hour = ts.astimezone(timezone.utc).hour return hour < 6 or hour >= 20 def _safe_excerpt(self, payload: Any) -> str: text = str(payload) return text[:300] def _primary_subject(self, event: dict[str, Any]) -> str: asset = event.get("asset", {}) return str(asset.get("user") or asset.get("hostname") or "unknown") def _primary_observable(self, event: dict[str, Any]) -> str: network = event.get("network", {}) return str(network.get("domain") or network.get("src_ip") or network.get("dst_ip") or "unknown") def _incident_key(self, event: dict[str, Any]) -> str: ts = datetime.fromisoformat(event["timestamp"].replace("Z", "+00:00")).astimezone(timezone.utc) day_bucket = ts.strftime("%Y-%m-%d") raw = "|".join( [ str(event.get("event_type", "generic")), self._primary_subject(event), self._primary_observable(event), day_bucket, ] ) return hashlib.sha256(raw.encode("utf-8")).hexdigest() def _effective_severity(self, event: dict[str, Any], policy: dict[str, Any]) -> tuple[str, int, list[str]]: severity = str(event.get("severity", "medium")).lower() risk_context = event.get("risk_context", {}) network = event.get("network", {}) weights = policy.get("risk", {}).get("weights", {}) score = 0 factors: list[str] = [] allowed_country = policy.get("vpn", {}).get("allowed_country", "TH") country = str(network.get("country", "")).upper() if country and country != allowed_country: score += int(weights.get("outside_thailand", 50)) factors.append("outside_country") if risk_context.get("admin_account"): score += int(weights.get("admin", 20)) factors.append("admin_account") ts = datetime.fromisoformat(event["timestamp"].replace("Z", "+00:00")) if risk_context.get("off_hours") or self._is_off_hours(ts): score += int(weights.get("off_hours", 15)) factors.append("off_hours") if risk_context.get("first_seen_country"): score += int(weights.get("first_seen_country", 15)) factors.append("first_seen_country") thresholds = policy.get("risk", {}).get("thresholds", {}) if score >= int(thresholds.get("high", 70)): severity = "high" if severity in {"low", "medium"} else severity elif score >= int(thresholds.get("medium", 40)) and severity == "low": severity = "medium" return severity, score, factors def _is_exception(self, event: dict[str, Any], policy: dict[str, Any]) -> bool: if event.get("event_type") != "vpn_geo_anomaly": return False asset = event.get("asset", {}) user = str(asset.get("user", "")) allowed_users = set(policy.get("vpn", {}).get("exception_users", [])) return user in allowed_users def _extract_iris_case_id(self, iris_response: dict[str, Any]) -> str | None: if "case_id" in iris_response: return str(iris_response.get("case_id")) data = iris_response.get("data") if isinstance(data, dict) and "case_id" in data: return str(data.get("case_id")) return None def _parse_kv_pairs(self, text: str) -> dict[str, str]: pattern = r"([A-Za-z0-9_]+)=('(?:[^']*)'|\"(?:[^\"]*)\"|[^\\s]+)" out: dict[str, str] = {} for key, raw in re.findall(pattern, text): value = raw.strip().strip("'").strip('"') out[key] = value return out def _severity_from_rule_level(self, rule_level: Any) -> str: try: level = int(rule_level) except (TypeError, ValueError): return "medium" if level >= 12: return "critical" if level >= 8: return "high" if level >= 4: return "medium" return "low" def _event_type_from_text(self, text: str, parsed: dict[str, str]) -> str: explicit = parsed.get("event_type") if explicit: return explicit lowered = text.lower() if "vpn" in lowered and ("geo" in lowered or "country" in lowered): return "vpn_geo_anomaly" if "domain" in lowered or "dns" in lowered: return "ioc_dns" if "c2" in lowered or "ips" in lowered or "ip " in lowered: return "ioc_ips" if "auth" in lowered and "fail" in lowered: return "auth_anomaly" return "generic" def _normalize_wazuh_hit(self, hit: dict[str, Any]) -> dict[str, Any]: src = hit.get("_source", {}) full_log = str(src.get("full_log", "")) parsed = self._parse_kv_pairs(full_log) event_id = str(parsed.get("event_id") or src.get("id") or hit.get("_id") or f"wazuh-{int(time.time())}") timestamp = ( src.get("@timestamp") or src.get("timestamp") or datetime.now(timezone.utc).isoformat() ) rule = src.get("rule", {}) if isinstance(src.get("rule"), dict) else {} rule_desc = str(rule.get("description") or "") event_type = self._event_type_from_text(full_log, parsed) severity = str(parsed.get("severity", "")).lower() or self._severity_from_rule_level(rule.get("level")) src_ip = parsed.get("src_ip") dst_ip = parsed.get("dst_ip") domain = parsed.get("query") or parsed.get("domain") country = parsed.get("country") user = parsed.get("user") or (src.get("agent", {}) or {}).get("name") title = rule_desc or f"Wazuh alert {rule.get('id', '')}".strip() description = full_log or rule_desc or "Wazuh alert" return { "source": "wazuh", "event_type": event_type, "event_id": event_id, "timestamp": timestamp, "severity": severity if severity in {"low", "medium", "high", "critical"} else "medium", "title": title, "description": description, "asset": { "user": user, "hostname": (src.get("agent", {}) or {}).get("name"), "agent_id": (src.get("agent", {}) or {}).get("id"), }, "network": { "src_ip": src_ip, "dst_ip": dst_ip, "domain": domain, "country": country, }, "tags": ["wazuh", event_type, f"rule_{rule.get('id', 'unknown')}"], "risk_context": { "outside_thailand": bool(country and str(country).upper() != "TH"), }, "raw": src, "payload": {}, } async def ingest_incident(self, event: dict[str, Any]) -> dict[str, Any]: policy = self.repo.get_policy() incident_key = self._incident_key(event) if self._is_exception(event, policy): decision_trace = { "incident_key": incident_key, "policy_exception": True, "reason": "vpn_exception_user", } self.repo.upsert_incident(incident_key, severity="low", status="ignored_exception", iris_case_id=None) self.repo.add_event( incident_key=incident_key, event_id=event.get("event_id"), source=event.get("source", "unknown"), event_type=event.get("event_type", "generic"), raw_payload=event, decision_trace=decision_trace, ) return { "incident_key": incident_key, "action_taken": "ignored_exception", "escalation_stub_sent": False, "decision_trace": decision_trace, } effective_severity, risk_score, risk_factors = self._effective_severity(event, policy) current = self.repo.get_incident(incident_key) action_taken = "updated_case" if current else "created_case" iris_case_id = current.get("iris_case_id") if current else None if not iris_case_id: case_payload = { "case_name": event.get("title", "SOC Incident"), "case_description": event.get("description", "Generated by soc-integrator MVP"), "case_customer": event.get("payload", {}).get("case_customer", settings.iris_default_customer_id), "case_soc_id": event.get("payload", {}).get("case_soc_id", settings.iris_default_soc_id), } iris_result = await self.iris_adapter.create_case(case_payload) iris_case_id = self._extract_iris_case_id(iris_result) else: update_payload = { "case_description": f"{event.get('description', 'Updated by soc-integrator MVP')} [event_id={event.get('event_id', '')}]" } try: await self.iris_adapter.update_case(iris_case_id, update_payload) except Exception: # Keep pipeline progressing for MVP even if update path is unsupported. pass stored = self.repo.upsert_incident( incident_key=incident_key, severity=effective_severity, status="open", iris_case_id=iris_case_id, ) decision_trace = { "incident_key": incident_key, "risk_score": risk_score, "risk_factors": risk_factors, "effective_severity": effective_severity, "action_taken": action_taken, } self.repo.add_event( incident_key=incident_key, event_id=event.get("event_id"), source=event.get("source", "unknown"), event_type=event.get("event_type", "generic"), raw_payload=event, decision_trace=decision_trace, ) escalate_severities = set(policy.get("escalate_severities", ["high", "critical"])) escalation_stub_sent = False stub_response: dict[str, Any] | None = None if effective_severity in escalate_severities: escalation_payload = { "incident_key": incident_key, "title": event.get("title", "SOC Incident"), "severity": effective_severity, "source": event.get("source", "soc-integrator"), "iris_case_id": iris_case_id, "event_summary": event.get("description", ""), "timestamp": event.get("timestamp"), } try: pd_result = await self.pagerduty_adapter.create_incident(escalation_payload) escalation_stub_sent = True stub_response = {"ok": True, "data": pd_result} self.repo.add_escalation_audit( incident_key=incident_key, status_code=200, success=True, response_excerpt=self._safe_excerpt(pd_result), ) except Exception as exc: stub_response = {"ok": False, "error": str(exc)} self.repo.add_escalation_audit( incident_key=incident_key, status_code=502, success=False, response_excerpt=self._safe_excerpt(exc), ) return { "incident_key": stored["incident_key"], "action_taken": action_taken, "iris_case_id": stored.get("iris_case_id"), "escalation_stub_sent": escalation_stub_sent, "stub_response": stub_response, "decision_trace": decision_trace, } async def evaluate_ioc(self, payload: dict[str, Any]) -> dict[str, Any]: policy = self.repo.get_policy() workflow_id = str(policy.get("shuffle", {}).get("ioc_workflow_id", "")).strip() matched = True confidence = 0.7 shuffle_result: dict[str, Any] | None = None if workflow_id: shuffle_result = await self.shuffle_adapter.trigger_workflow(workflow_id, payload) if matched: event = { "source": "shuffle", "event_type": "ioc_dns" if payload.get("ioc_type") == "domain" else "ioc_ips", "event_id": payload.get("source_event", {}).get("event_id") or f"ioc-{int(time.time())}", "timestamp": datetime.now(timezone.utc).isoformat(), "severity": "medium", "title": f"IOC match: {payload.get('ioc_value', 'unknown')}", "description": "IOC evaluation result", "asset": payload.get("source_event", {}).get("asset", {}), "network": payload.get("source_event", {}).get("network", {}), "tags": ["ioc", str(payload.get("ioc_type", "unknown"))], "risk_context": {}, "raw": payload, "payload": {}, } ingest_result = await self.ingest_incident(event) else: ingest_result = {"action_taken": "rejected"} return { "matched": matched, "confidence": confidence, "shuffle": shuffle_result, "result": ingest_result, } async def evaluate_vpn(self, payload: dict[str, Any]) -> dict[str, Any]: if not payload.get("success", False): return { "risk_score": 0, "risk_factors": [], "exception_applied": False, "action_taken": "rejected", } event = { "source": "wazuh", "event_type": "vpn_geo_anomaly", "event_id": payload.get("event_id") or f"vpn-{int(time.time())}", "timestamp": payload.get("event_time") or datetime.now(timezone.utc).isoformat(), "severity": "high", "title": f"VPN login anomaly: {payload.get('user', 'unknown')}", "description": f"VPN success from {payload.get('country_code', 'unknown')} for user {payload.get('user', 'unknown')}", "asset": {"user": payload.get("user")}, "network": {"src_ip": payload.get("src_ip"), "country": payload.get("country_code")}, "tags": ["vpn", "geo-anomaly"], "risk_context": { "outside_thailand": payload.get("country_code", "").upper() != "TH", "admin_account": bool(payload.get("is_admin", False)), "off_hours": bool(payload.get("off_hours", False)), "first_seen_country": bool(payload.get("first_seen_country", False)), }, "raw": payload, "payload": {}, } ingest_result = await self.ingest_incident(event) decision_trace = ingest_result.get("decision_trace", {}) return { "risk_score": decision_trace.get("risk_score", 0), "risk_factors": decision_trace.get("risk_factors", []), "exception_applied": ingest_result.get("action_taken") == "ignored_exception", "action_taken": ingest_result.get("action_taken"), "incident_key": ingest_result.get("incident_key"), "iris_case_id": ingest_result.get("iris_case_id"), "escalation_stub_sent": ingest_result.get("escalation_stub_sent", False), } async def sync_wazuh_alerts( self, query: str = "soc_mvp_test=true OR event_type:*", limit: int = 50, minutes: int = 120, ) -> dict[str, Any]: raw = await self.wazuh_adapter.search_alerts(query=query, limit=limit, minutes=minutes) hits = (raw.get("hits", {}) or {}).get("hits", []) if isinstance(raw, dict) else [] processed = 0 ingested = 0 skipped_existing = 0 failed = 0 errors: list[str] = [] created_incidents: list[str] = [] for hit in hits: processed += 1 event = self._normalize_wazuh_hit(hit) event_id = str(event.get("event_id", "")).strip() if event_id and self.repo.has_event("wazuh", event_id): skipped_existing += 1 continue try: result = await self.ingest_incident(event) ingested += 1 incident_key = str(result.get("incident_key", "")) if incident_key: created_incidents.append(incident_key) except Exception as exc: failed += 1 errors.append(f"{event_id or 'unknown_event'}: {exc}") return { "query": query, "window_minutes": minutes, "limit": limit, "processed": processed, "ingested": ingested, "skipped_existing": skipped_existing, "failed": failed, "incident_keys": created_incidents, "errors": errors[:10], "total_hits": (raw.get("hits", {}).get("total", {}) if isinstance(raw, dict) else {}), } async def dependency_health(self) -> dict[str, Any]: out: dict[str, Any] = {} async def timed(name: str, fn): start = time.time() try: result = await fn() out[name] = { "status": "up", "latency_ms": round((time.time() - start) * 1000, 2), "details": result, } except Exception as exc: out[name] = { "status": "down", "latency_ms": round((time.time() - start) * 1000, 2), "error": str(exc), } await timed("wazuh", self.wazuh_adapter.auth_test) await timed("shuffle", self.shuffle_adapter.health) await timed("iris", self.iris_adapter.whoami) async def pagerduty_stub_health(): async with httpx.AsyncClient(timeout=10.0) as client: r = await client.get(settings.pagerduty_base_url) r.raise_for_status() return {"status_code": r.status_code} await timed("pagerduty_stub", pagerduty_stub_health) return out