from __future__ import annotations

import hashlib
import logging
import re
import time
from datetime import datetime, timezone
from typing import Any

import httpx

from app.adapters.iris import IrisAdapter
from app.adapters.pagerduty import PagerDutyAdapter
from app.adapters.shuffle import ShuffleAdapter
from app.adapters.wazuh import WazuhAdapter
from app.config import settings
from app.repositories.mvp_repo import MvpRepository

logger = logging.getLogger(__name__)


class MvpService:
    """Orchestrates the MVP SOC pipeline across external systems.

    Responsibilities visible in this class:

    * normalize raw Wazuh alert hits into a common event dict
      (``_normalize_wazuh_hit``),
    * compute a policy-driven risk score / effective severity
      (``_effective_severity``),
    * deduplicate events into incidents keyed by a daily hash
      (``_incident_key``),
    * create/update IRIS cases and escalate high-severity incidents to a
      PagerDuty stub (``ingest_incident``),
    * evaluate IOCs via a Shuffle workflow (``evaluate_ioc``) and VPN
      geo anomalies (``evaluate_vpn``),
    * bulk-sync alerts from Wazuh (``sync_wazuh_alerts``) and report
      dependency health (``dependency_health``).

    Persistence and external I/O are delegated to the injected repository
    and adapters; this class holds no mutable state of its own beyond
    those references.
    """

    def __init__(
        self,
        repo: MvpRepository,
        wazuh_adapter: WazuhAdapter,
        shuffle_adapter: ShuffleAdapter,
        iris_adapter: IrisAdapter,
        pagerduty_adapter: PagerDutyAdapter,
    ) -> None:
        """Store the injected repository and adapter collaborators."""
        self.repo = repo
        self.wazuh_adapter = wazuh_adapter
        self.shuffle_adapter = shuffle_adapter
        self.iris_adapter = iris_adapter
        self.pagerduty_adapter = pagerduty_adapter

    def _is_off_hours(self, ts: datetime) -> bool:
        """Return True when *ts* falls outside 06:00-19:59 UTC.

        NOTE(review): the window is evaluated in UTC, not local business
        time — confirm this matches the policy's intent.
        """
        hour = ts.astimezone(timezone.utc).hour
        return hour < 6 or hour >= 20

    def _safe_excerpt(self, payload: Any) -> str:
        """Stringify *payload* and truncate to 300 chars for audit logs."""
        text = str(payload)
        return text[:300]

    def _primary_subject(self, event: dict[str, Any]) -> str:
        """Pick the incident subject: user, else hostname, else "unknown"."""
        asset = event.get("asset", {})
        return str(asset.get("user") or asset.get("hostname") or "unknown")

    def _primary_observable(self, event: dict[str, Any]) -> str:
        """Pick the key network observable: domain > src_ip > dst_ip."""
        network = event.get("network", {})
        return str(network.get("domain") or network.get("src_ip") or network.get("dst_ip") or "unknown")

    def _incident_key(self, event: dict[str, Any]) -> str:
        """Derive a stable dedup key for *event*.

        The key is a SHA-256 of (event_type, subject, observable, UTC day),
        so repeated events for the same subject/observable on the same day
        collapse into one incident. Requires ``event["timestamp"]`` to be
        ISO-8601 (a trailing ``Z`` is normalized to ``+00:00``).
        """
        ts = datetime.fromisoformat(event["timestamp"].replace("Z", "+00:00")).astimezone(timezone.utc)
        day_bucket = ts.strftime("%Y-%m-%d")
        raw = "|".join(
            [
                str(event.get("event_type", "generic")),
                self._primary_subject(event),
                self._primary_observable(event),
                day_bucket,
            ]
        )
        return hashlib.sha256(raw.encode("utf-8")).hexdigest()

    def _effective_severity(self, event: dict[str, Any], policy: dict[str, Any]) -> tuple[str, int, list[str]]:
        """Score *event* against *policy* risk weights.

        Returns ``(severity, score, factors)`` where ``factors`` names each
        risk weight that contributed. Severity is only ever upgraded:
        score >= high threshold lifts low/medium to "high"; score >= medium
        threshold lifts "low" to "medium"; an already-high/critical input
        severity is kept as-is.
        """
        severity = str(event.get("severity", "medium")).lower()
        risk_context = event.get("risk_context", {})
        network = event.get("network", {})
        weights = policy.get("risk", {}).get("weights", {})
        score = 0
        factors: list[str] = []
        # Geo factor: any country other than the policy's allowed one.
        allowed_country = policy.get("vpn", {}).get("allowed_country", "TH")
        country = str(network.get("country", "")).upper()
        if country and country != allowed_country:
            score += int(weights.get("outside_thailand", 50))
            factors.append("outside_country")
        if risk_context.get("admin_account"):
            score += int(weights.get("admin", 20))
            factors.append("admin_account")
        # Off-hours can come pre-flagged from the producer or be derived here.
        ts = datetime.fromisoformat(event["timestamp"].replace("Z", "+00:00"))
        if risk_context.get("off_hours") or self._is_off_hours(ts):
            score += int(weights.get("off_hours", 15))
            factors.append("off_hours")
        if risk_context.get("first_seen_country"):
            score += int(weights.get("first_seen_country", 15))
            factors.append("first_seen_country")
        thresholds = policy.get("risk", {}).get("thresholds", {})
        if score >= int(thresholds.get("high", 70)):
            severity = "high" if severity in {"low", "medium"} else severity
        elif score >= int(thresholds.get("medium", 40)) and severity == "low":
            severity = "medium"
        return severity, score, factors

    def _is_exception(self, event: dict[str, Any], policy: dict[str, Any]) -> bool:
        """Return True when *event* is a VPN geo anomaly for a whitelisted user.

        Only ``vpn_geo_anomaly`` events can be excepted; the user must appear
        in the policy's ``vpn.exception_users`` list.
        """
        if event.get("event_type") != "vpn_geo_anomaly":
            return False
        asset = event.get("asset", {})
        user = str(asset.get("user", ""))
        allowed_users = set(policy.get("vpn", {}).get("exception_users", []))
        return user in allowed_users

    def _extract_iris_case_id(self, iris_response: dict[str, Any]) -> str | None:
        """Pull the case id from an IRIS response, top-level or under "data"."""
        if "case_id" in iris_response:
            return str(iris_response.get("case_id"))
        data = iris_response.get("data")
        if isinstance(data, dict) and "case_id" in data:
            return str(data.get("case_id"))
        return None

    def _parse_kv_pairs(self, text: str) -> dict[str, str]:
        """Parse ``key=value`` tokens from a log line into a dict.

        Values may be single-quoted, double-quoted, or bare (no spaces);
        surrounding quotes are stripped. Later duplicates of a key win.
        """
        pattern = r"([A-Za-z0-9_]+)=('(?:[^']*)'|\"(?:[^\"]*)\"|[^\s]+)"
        out: dict[str, str] = {}
        for key, raw in re.findall(pattern, text):
            value = raw.strip().strip("'").strip('"')
            out[key] = value
        return out

    def _severity_from_rule_level(self, rule_level: Any) -> str:
        """Map a Wazuh rule level to a severity bucket.

        >=12 critical, >=8 high, >=4 medium, else low; unparseable levels
        default to "medium".
        """
        try:
            level = int(rule_level)
        except (TypeError, ValueError):
            return "medium"
        if level >= 12:
            return "critical"
        if level >= 8:
            return "high"
        if level >= 4:
            return "medium"
        return "low"

    def _event_type_from_text(self, text: str, parsed: dict[str, str]) -> str:
        """Classify an alert's event type from its log text and parsed fields.

        Precedence: production C1 (impossible-travel) detection from
        structured fields, then legacy simulator markers
        (usecase_id/section/explicit event_type), then keyword heuristics
        over the raw text, finally "generic".
        """
        explicit = str(parsed.get("event_type") or "").strip().lower()
        usecase_id = str(parsed.get("usecase_id") or "").strip().upper()
        section = str(parsed.get("section") or "").strip().upper()
        source = str(parsed.get("source") or "").strip().lower()
        success = str(parsed.get("success") or "").strip().lower()
        has_geo = bool(parsed.get("country") or parsed.get("src_lat") or parsed.get("src_lon"))
        has_user = bool(parsed.get("user"))
        has_src_ip = bool(parsed.get("src_ip") or parsed.get("srcip"))
        explicit_success_login = explicit in {
            "vpn_login_success",
            "windows_auth_success",
            "auth_success",
        }
        # Production-first C1 detection:
        # successful auth/login + geo context on vpn/windows identity streams.
        if (
            (source in {"vpn", "fortigate", "windows", "identity"} or "vpn" in source)
            and has_geo
            and has_user
            and has_src_ip
            and (success == "true" or explicit_success_login)
        ):
            return "c1_impossible_travel"
        # Legacy simulator markers remain supported as fallback.
        if usecase_id.startswith("C1") or section == "C1":
            return "c1_impossible_travel"
        if explicit in {"c1_impossible_travel", "impossible_travel"}:
            return "c1_impossible_travel"
        if explicit == "vpn_geo_anomaly":
            return "vpn_geo_anomaly"
        if explicit:
            return explicit
        # Keyword heuristics on the raw log text (last resort before generic).
        lowered = text.lower()
        if "impossible travel" in lowered:
            return "c1_impossible_travel"
        if "vpn" in lowered and ("geo" in lowered or "country" in lowered):
            return "vpn_geo_anomaly"
        if "domain" in lowered or "dns" in lowered:
            return "ioc_dns"
        if "c2" in lowered or "ips" in lowered or "ip " in lowered:
            return "ioc_ips"
        if "auth" in lowered and "fail" in lowered:
            return "auth_anomaly"
        return "generic"

    def _normalize_wazuh_hit(self, hit: dict[str, Any]) -> dict[str, Any]:
        """Normalize one Wazuh/OpenSearch hit into the pipeline's event dict.

        Reads ``hit["_source"]`` (full_log, rule, agent, timestamps), parses
        key=value pairs out of ``full_log``, and builds the common event shape
        consumed by ``ingest_incident``: source/event_type/event_id/timestamp/
        severity/title/description plus asset, network, tags, risk_context,
        raw, and payload sub-dicts. Missing fields fall back to safe defaults
        (e.g. a time-based event id, "medium" severity).
        """
        src = hit.get("_source", {})
        full_log = str(src.get("full_log", ""))
        parsed = self._parse_kv_pairs(full_log)
        # Prefer an explicit event_id from the log, then index ids, then a
        # synthetic time-based id (only unique per second — acceptable for MVP).
        event_id = str(parsed.get("event_id") or src.get("id") or hit.get("_id") or f"wazuh-{int(time.time())}")
        timestamp = (
            src.get("@timestamp")
            or src.get("timestamp")
            or datetime.now(timezone.utc).isoformat()
        )
        rule = src.get("rule", {}) if isinstance(src.get("rule"), dict) else {}
        rule_desc = str(rule.get("description") or "")
        event_type = self._event_type_from_text(full_log, parsed)
        # Explicit severity in the log wins; otherwise derive from rule level.
        severity = str(parsed.get("severity", "")).lower() or self._severity_from_rule_level(rule.get("level"))
        src_ip = parsed.get("src_ip")
        if not src_ip:
            src_ip = parsed.get("srcip")
        dst_ip = parsed.get("dst_ip")
        if not dst_ip:
            dst_ip = parsed.get("dstip")
        domain = parsed.get("query") or parsed.get("domain")
        country = parsed.get("country")
        user = parsed.get("user") or (src.get("agent", {}) or {}).get("name")
        dst_port = parsed.get("dst_port") or parsed.get("dstport")
        event_action = parsed.get("event_action") or parsed.get("action")
        title = rule_desc or f"Wazuh alert {rule.get('id', '')}".strip()
        description = full_log or rule_desc or "Wazuh alert"
        # Geo coordinates: tolerate missing or non-numeric values.
        src_lat_raw = parsed.get("src_lat")
        src_lon_raw = parsed.get("src_lon")
        try:
            src_lat = float(src_lat_raw) if src_lat_raw not in {None, ""} else None
        except (TypeError, ValueError):
            src_lat = None
        try:
            src_lon = float(src_lon_raw) if src_lon_raw not in {None, ""} else None
        except (TypeError, ValueError):
            src_lon = None
        return {
            "source": "wazuh",
            "event_type": event_type,
            "event_id": event_id,
            "timestamp": timestamp,
            "severity": severity if severity in {"low", "medium", "high", "critical"} else "medium",
            "title": title,
            "description": description,
            "asset": {
                "user": user,
                "hostname": (src.get("agent", {}) or {}).get("name"),
                "agent_id": (src.get("agent", {}) or {}).get("id"),
            },
            "network": {
                "src_ip": src_ip,
                "dst_ip": dst_ip,
                "dst_host": parsed.get("dst_host") or parsed.get("host"),
                "dst_port": int(dst_port) if str(dst_port or "").isdigit() else None,
                "domain": domain,
                "country": country,
                "src_lat": src_lat,
                "src_lon": src_lon,
            },
            "tags": ["wazuh", event_type, f"rule_{rule.get('id', 'unknown')}"],
            "risk_context": {
                "outside_thailand": bool(country and str(country).upper() != "TH"),
            },
            "raw": src,
            "payload": {
                **parsed,
                "event_action": event_action,
                "event_id": parsed.get("event_id"),
                "event_type": event_type,
                "success": parsed.get("success"),
                "logon_type": parsed.get("logon_type"),
                "account_type": parsed.get("account_type"),
                "is_admin": parsed.get("is_admin"),
                "is_service": parsed.get("is_service"),
            },
        }

    def normalize_wazuh_hit(self, hit: dict[str, Any]) -> dict[str, Any]:
        """Public wrapper around :meth:`_normalize_wazuh_hit`."""
        return self._normalize_wazuh_hit(hit)

    def _to_float(self, value: Any, default: float = 0.0) -> float:
        """Coerce *value* to float, returning *default* on failure."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    def _severity_from_confidence(self, confidence: float) -> str:
        """Map a Shuffle confidence score to a severity bucket."""
        if confidence >= 0.9:
            return "high"
        if confidence >= 0.7:
            return "medium"
        return "low"

    def _extract_shuffle_verdict(self, shuffle_result: dict[str, Any] | None) -> dict[str, Any]:
        """Normalize a Shuffle workflow response into a verdict dict.

        Returns a dict with keys matched/confidence/severity/evidence/iocs/
        reason (and raw, when a response was present). A nested "result"
        dict is merged underneath the top level (top-level keys win). When
        the workflow does not return an explicit boolean ``matched``, a
        confidence >= 0.7 fallback decides, and ``reason`` records which
        path was taken.
        """
        if not isinstance(shuffle_result, dict):
            return {
                "matched": False,
                "confidence": 0.0,
                "severity": "low",
                "evidence": "",
                "iocs": [],
                "reason": "no_shuffle_result",
            }
        flat = dict(shuffle_result)
        nested = shuffle_result.get("result")
        if isinstance(nested, dict):
            # Merge so top-level keys override the nested "result" payload.
            merged = dict(nested)
            merged.update(flat)
            flat = merged
        confidence = self._to_float(flat.get("confidence"), 0.0)
        matched_raw = flat.get("matched")
        if isinstance(matched_raw, bool):
            matched = matched_raw
            reason = "shuffle_explicit"
        else:
            matched = confidence >= 0.7
            reason = "confidence_threshold_fallback"
        severity_raw = str(flat.get("severity", "")).lower()
        severity = severity_raw if severity_raw in {"low", "medium", "high", "critical"} else self._severity_from_confidence(confidence)
        return {
            "matched": matched,
            "confidence": confidence,
            "severity": severity,
            "evidence": str(flat.get("evidence", "")),
            "iocs": flat.get("iocs", []),
            "reason": reason,
            "raw": shuffle_result,
        }

    async def ingest_incident(self, event: dict[str, Any]) -> dict[str, Any]:
        """Run one normalized event through the incident pipeline.

        Steps: policy exception check -> risk scoring -> IRIS case
        create/update -> persist incident + event + decision trace ->
        optional PagerDuty escalation (with audit record either way).
        Returns a summary dict: incident_key, action_taken, iris_case_id,
        escalation_stub_sent, stub_response, decision_trace.

        Raises whatever the repo or IRIS create_case raise; the IRIS
        update path and PagerDuty escalation are deliberately best-effort.
        """
        policy = self.repo.get_policy()
        incident_key = self._incident_key(event)
        # Policy exception short-circuits the whole pipeline: the event is
        # still recorded, but no IRIS case or escalation happens.
        if self._is_exception(event, policy):
            decision_trace = {
                "incident_key": incident_key,
                "policy_exception": True,
                "reason": "vpn_exception_user",
            }
            self.repo.upsert_incident(incident_key, severity="low", status="ignored_exception", iris_case_id=None)
            self.repo.add_event(
                incident_key=incident_key,
                event_id=event.get("event_id"),
                source=event.get("source", "unknown"),
                event_type=event.get("event_type", "generic"),
                raw_payload=event,
                decision_trace=decision_trace,
            )
            return {
                "incident_key": incident_key,
                "action_taken": "ignored_exception",
                "escalation_stub_sent": False,
                "decision_trace": decision_trace,
            }
        effective_severity, risk_score, risk_factors = self._effective_severity(event, policy)
        current = self.repo.get_incident(incident_key)
        action_taken = "updated_case" if current else "created_case"
        iris_case_id = current.get("iris_case_id") if current else None
        if not iris_case_id:
            # No case yet (new incident, or an earlier create failed to
            # yield an id): create a fresh IRIS case.
            case_payload = {
                "case_name": event.get("title", "SOC Incident"),
                "case_description": event.get("description", "Generated by soc-integrator MVP"),
                "case_customer": event.get("payload", {}).get("case_customer", settings.iris_default_customer_id),
                "case_soc_id": event.get("payload", {}).get("case_soc_id", settings.iris_default_soc_id),
            }
            iris_result = await self.iris_adapter.create_case(case_payload)
            iris_case_id = self._extract_iris_case_id(iris_result)
        else:
            update_payload = {
                "case_description": f"{event.get('description', 'Updated by soc-integrator MVP')} [event_id={event.get('event_id', '')}]"
            }
            try:
                await self.iris_adapter.update_case(iris_case_id, update_payload)
            except Exception:
                # Keep pipeline progressing for MVP even if update path is unsupported.
                pass
        stored = self.repo.upsert_incident(
            incident_key=incident_key,
            severity=effective_severity,
            status="open",
            iris_case_id=iris_case_id,
        )
        decision_trace = {
            "incident_key": incident_key,
            "risk_score": risk_score,
            "risk_factors": risk_factors,
            "effective_severity": effective_severity,
            "action_taken": action_taken,
        }
        self.repo.add_event(
            incident_key=incident_key,
            event_id=event.get("event_id"),
            source=event.get("source", "unknown"),
            event_type=event.get("event_type", "generic"),
            raw_payload=event,
            decision_trace=decision_trace,
        )
        escalate_severities = set(policy.get("escalate_severities", ["high", "critical"]))
        escalation_stub_sent = False
        stub_response: dict[str, Any] | None = None
        if effective_severity in escalate_severities:
            escalation_payload = {
                "incident_key": incident_key,
                "title": event.get("title", "SOC Incident"),
                "severity": effective_severity,
                "source": event.get("source", "soc-integrator"),
                "iris_case_id": iris_case_id,
                "event_summary": event.get("description", ""),
                "timestamp": event.get("timestamp"),
            }
            # Escalation is best-effort: success and failure both leave an
            # audit record; a failure never aborts ingestion.
            try:
                pd_result = await self.pagerduty_adapter.create_incident(escalation_payload)
                escalation_stub_sent = True
                stub_response = {"ok": True, "data": pd_result}
                self.repo.add_escalation_audit(
                    incident_key=incident_key,
                    status_code=200,
                    success=True,
                    response_excerpt=self._safe_excerpt(pd_result),
                )
            except Exception as exc:
                stub_response = {"ok": False, "error": str(exc)}
                self.repo.add_escalation_audit(
                    incident_key=incident_key,
                    status_code=502,
                    success=False,
                    response_excerpt=self._safe_excerpt(exc),
                )
        return {
            "incident_key": stored["incident_key"],
            "action_taken": action_taken,
            "iris_case_id": stored.get("iris_case_id"),
            "escalation_stub_sent": escalation_stub_sent,
            "stub_response": stub_response,
            "decision_trace": decision_trace,
        }

    async def evaluate_ioc(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Evaluate an IOC via the policy-configured Shuffle workflow.

        Triggers the workflow (when ``shuffle.ioc_workflow_id`` is set),
        extracts a verdict, and — only on a match — synthesizes an
        ``ioc_dns``/``ioc_ips`` event and feeds it to ``ingest_incident``.
        Returns the verdict plus the raw Shuffle response and the ingest
        result ("rejected" when not matched).
        """
        policy = self.repo.get_policy()
        workflow_id = str(policy.get("shuffle", {}).get("ioc_workflow_id", "")).strip()
        shuffle_result: dict[str, Any] | None = None
        if workflow_id:
            shuffle_result = await self.shuffle_adapter.trigger_workflow(workflow_id, payload)
        verdict = self._extract_shuffle_verdict(shuffle_result)
        matched = bool(verdict["matched"])
        confidence = self._to_float(verdict["confidence"], 0.0)
        logger.info(
            "ioc evaluation workflow_id=%s matched=%s confidence=%.2f",
            workflow_id or "",
            matched,
            confidence,
        )
        if matched:
            src_event = payload.get("source_event", {})
            event_id = src_event.get("event_id") or f"ioc-{int(time.time())}"
            if not isinstance(event_id, str):
                event_id = str(event_id)
            description = f"IOC evaluation result confidence={confidence:.2f}"
            evidence = str(verdict.get("evidence", "")).strip()
            if evidence:
                description = f"{description} evidence={evidence[:180]}"
            event = {
                "source": "shuffle",
                "event_type": "ioc_dns" if payload.get("ioc_type") == "domain" else "ioc_ips",
                "event_id": event_id,
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "severity": verdict["severity"],
                "title": f"IOC match: {payload.get('ioc_value', 'unknown')}",
                "description": description,
                "asset": src_event.get("asset", {}),
                "network": src_event.get("network", {}),
                "tags": ["ioc", str(payload.get("ioc_type", "unknown"))],
                "risk_context": {},
                "raw": {
                    "payload": payload,
                    "shuffle": verdict.get("raw"),
                },
                "payload": {},
            }
            ingest_result = await self.ingest_incident(event)
        else:
            ingest_result = {"action_taken": "rejected"}
        return {
            "matched": matched,
            "confidence": confidence,
            "severity": verdict["severity"],
            "evidence": verdict["evidence"],
            "iocs": verdict["iocs"],
            "decision_source": verdict["reason"],
            "shuffle": shuffle_result,
            "result": ingest_result,
        }

    async def evaluate_vpn(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Evaluate a VPN login event for geo-anomaly risk.

        Failed logins are rejected outright. Successful ones are wrapped as
        a high-severity ``vpn_geo_anomaly`` event (carrying the caller's
        admin/off-hours/first-seen flags into risk_context) and ingested;
        the returned dict surfaces the decision trace from ingestion.
        """
        if not payload.get("success", False):
            return {
                "risk_score": 0,
                "risk_factors": [],
                "exception_applied": False,
                "action_taken": "rejected",
            }
        event = {
            "source": "wazuh",
            "event_type": "vpn_geo_anomaly",
            "event_id": payload.get("event_id") or f"vpn-{int(time.time())}",
            "timestamp": payload.get("event_time") or datetime.now(timezone.utc).isoformat(),
            "severity": "high",
            "title": f"VPN login anomaly: {payload.get('user', 'unknown')}",
            "description": f"VPN success from {payload.get('country_code', 'unknown')} for user {payload.get('user', 'unknown')}",
            "asset": {"user": payload.get("user")},
            "network": {"src_ip": payload.get("src_ip"), "country": payload.get("country_code")},
            "tags": ["vpn", "geo-anomaly"],
            "risk_context": {
                "outside_thailand": payload.get("country_code", "").upper() != "TH",
                "admin_account": bool(payload.get("is_admin", False)),
                "off_hours": bool(payload.get("off_hours", False)),
                "first_seen_country": bool(payload.get("first_seen_country", False)),
            },
            "raw": payload,
            "payload": {},
        }
        ingest_result = await self.ingest_incident(event)
        decision_trace = ingest_result.get("decision_trace", {})
        return {
            "risk_score": decision_trace.get("risk_score", 0),
            "risk_factors": decision_trace.get("risk_factors", []),
            "exception_applied": ingest_result.get("action_taken") == "ignored_exception",
            "action_taken": ingest_result.get("action_taken"),
            "incident_key": ingest_result.get("incident_key"),
            "iris_case_id": ingest_result.get("iris_case_id"),
            "escalation_stub_sent": ingest_result.get("escalation_stub_sent", False),
        }

    async def sync_wazuh_alerts(
        self,
        query: str = "soc_mvp_test=true OR event_type:*",
        limit: int = 50,
        minutes: int = 120,
    ) -> dict[str, Any]:
        """Pull recent Wazuh alerts and push each through the pipeline.

        For every hit: normalize, skip if the (source, event_id) pair was
        already recorded, then route IOC-type events through
        ``evaluate_ioc`` and everything else through ``ingest_incident``.
        Per-event failures are counted and sampled (first 10 error strings)
        instead of aborting the batch. Returns a summary of counts,
        incident keys created, and the raw hit total.
        """
        raw = await self.wazuh_adapter.search_alerts(query=query, limit=limit, minutes=minutes)
        hits = (raw.get("hits", {}) or {}).get("hits", []) if isinstance(raw, dict) else []
        processed = 0
        ingested = 0
        skipped_existing = 0
        failed = 0
        errors: list[str] = []
        created_incidents: list[str] = []
        ioc_evaluated = 0
        ioc_matched = 0
        ioc_rejected = 0
        for hit in hits:
            processed += 1
            event = self._normalize_wazuh_hit(hit)
            event_id = str(event.get("event_id", "")).strip()
            # Idempotency: skip events the repo has already seen.
            if event_id and self.repo.has_event("wazuh", event_id):
                skipped_existing += 1
                continue
            try:
                if event.get("event_type") in {"ioc_dns", "ioc_ips"}:
                    ioc_evaluated += 1
                    payload = {
                        "ioc_type": "domain" if event.get("event_type") == "ioc_dns" else "ip",
                        "ioc_value": (event.get("network", {}) or {}).get("domain")
                        or (event.get("network", {}) or {}).get("dst_ip")
                        or (event.get("network", {}) or {}).get("src_ip")
                        or "unknown",
                        "source_event": {
                            "event_id": event.get("event_id"),
                            "asset": event.get("asset", {}),
                            "network": event.get("network", {}),
                            "raw": event.get("raw", {}),
                        },
                    }
                    ioc_result = await self.evaluate_ioc(payload)
                    if ioc_result.get("matched"):
                        ioc_matched += 1
                        ingested += 1
                        incident_key = str((ioc_result.get("result", {}) or {}).get("incident_key", ""))
                        if incident_key:
                            created_incidents.append(incident_key)
                    else:
                        ioc_rejected += 1
                else:
                    result = await self.ingest_incident(event)
                    ingested += 1
                    incident_key = str(result.get("incident_key", ""))
                    if incident_key:
                        created_incidents.append(incident_key)
            except Exception as exc:
                failed += 1
                errors.append(f"{event_id or 'unknown_event'}: {exc}")
        return {
            "query": query,
            "window_minutes": minutes,
            "limit": limit,
            "processed": processed,
            "ingested": ingested,
            "skipped_existing": skipped_existing,
            "failed": failed,
            "ioc_evaluated": ioc_evaluated,
            "ioc_matched": ioc_matched,
            "ioc_rejected": ioc_rejected,
            "incident_keys": created_incidents,
            "errors": errors[:10],
            "total_hits": (raw.get("hits", {}).get("total", {}) if isinstance(raw, dict) else {}),
        }

    async def dependency_health(self) -> dict[str, Any]:
        """Probe each external dependency and report status + latency.

        Each probe is run through a shared ``timed`` helper that records
        "up"/"down", round-trip latency in ms, and either the probe result
        or the error string. Probes: Wazuh auth, Shuffle health, IRIS
        whoami, and a plain HTTP GET against the PagerDuty stub base URL.
        """
        out: dict[str, Any] = {}

        async def timed(name: str, fn):
            # Record status, latency and details/error for one probe call.
            start = time.time()
            try:
                result = await fn()
                out[name] = {
                    "status": "up",
                    "latency_ms": round((time.time() - start) * 1000, 2),
                    "details": result,
                }
            except Exception as exc:
                out[name] = {
                    "status": "down",
                    "latency_ms": round((time.time() - start) * 1000, 2),
                    "error": str(exc),
                }

        await timed("wazuh", self.wazuh_adapter.auth_test)
        await timed("shuffle", self.shuffle_adapter.health)
        await timed("iris", self.iris_adapter.whoami)

        async def pagerduty_stub_health():
            # The PagerDuty stub has no dedicated health endpoint; a plain
            # GET on its base URL is treated as the liveness check.
            async with httpx.AsyncClient(timeout=10.0) as client:
                r = await client.get(settings.pagerduty_base_url)
                r.raise_for_status()
                return {"status_code": r.status_code}

        await timed("pagerduty_stub", pagerduty_stub_health)
        return out