| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591 |
from __future__ import annotations

import hashlib
import json
import logging
import re
import time
from collections.abc import Awaitable, Callable
from datetime import datetime, timezone
from typing import Any

import httpx

from app.adapters.iris import IrisAdapter
from app.adapters.pagerduty import PagerDutyAdapter
from app.adapters.shuffle import ShuffleAdapter
from app.adapters.wazuh import WazuhAdapter
from app.config import settings
from app.repositories.mvp_repo import MvpRepository
# Module-level logger named after this module; used by MvpService for
# operational messages such as the IOC evaluation summary.
logger = logging.getLogger(__name__)
class MvpService:
    """Coordinate the SOC MVP pipeline across Wazuh, Shuffle, IRIS and PagerDuty.

    The service normalises incoming alerts into a common event dict, derives a
    deduplication key and an effective severity from the stored policy, opens
    or updates an IRIS case, records the event through the repository and
    escalates to PagerDuty when the severity is in the policy's escalation set.
    """

    # key=value tokens whose value is single-quoted, double-quoted or bare.
    _KV_PATTERN = re.compile(r"([A-Za-z0-9_]+)=('(?:[^']*)'|\"(?:[^\"]*)\"|[^\s]+)")
    # Trailing compact UTC offset ("+0000") that older fromisoformat rejects.
    _COMPACT_OFFSET = re.compile(r"([+-]\d{2})(\d{2})$")
    _SEVERITIES = frozenset({"low", "medium", "high", "critical"})

    def __init__(
        self,
        repo: MvpRepository,
        wazuh_adapter: WazuhAdapter,
        shuffle_adapter: ShuffleAdapter,
        iris_adapter: IrisAdapter,
        pagerduty_adapter: PagerDutyAdapter,
    ) -> None:
        """Store the repository and the four integration adapters."""
        self.repo = repo
        self.wazuh_adapter = wazuh_adapter
        self.shuffle_adapter = shuffle_adapter
        self.iris_adapter = iris_adapter
        self.pagerduty_adapter = pagerduty_adapter

    # ------------------------------------------------------------------
    # Small pure helpers
    # ------------------------------------------------------------------
    def _parse_timestamp(self, value: Any) -> datetime:
        """Parse an ISO-8601 timestamp into an aware UTC datetime.

        Accepts a ``datetime`` or a string. A trailing ``Z`` and compact
        offsets such as ``+0000`` are normalised before parsing. Naive values
        are taken to be UTC so the result does not depend on the host's local
        timezone.

        Raises:
            ValueError: if the value cannot be parsed as an ISO timestamp.
        """
        if isinstance(value, datetime):
            parsed = value
        else:
            text = str(value).strip().replace("Z", "+00:00")
            text = self._COMPACT_OFFSET.sub(r"\1:\2", text)
            parsed = datetime.fromisoformat(text)
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed.astimezone(timezone.utc)

    def _is_off_hours(self, ts: datetime) -> bool:
        """Return True when *ts* falls before 06:00 or at/after 20:00 UTC.

        Naive datetimes are interpreted as UTC.
        NOTE(review): the window is evaluated in UTC while the policy defaults
        are Thailand-centric; confirm that UTC is the intended reference.
        """
        if ts.tzinfo is None:
            ts = ts.replace(tzinfo=timezone.utc)
        hour = ts.astimezone(timezone.utc).hour
        return hour < 6 or hour >= 20

    def _safe_excerpt(self, payload: Any) -> str:
        """Return at most the first 300 characters of ``str(payload)``."""
        return str(payload)[:300]

    def _primary_subject(self, event: dict[str, Any]) -> str:
        """Return the asset's user, else hostname, else ``"unknown"``."""
        asset = event.get("asset") or {}
        return str(asset.get("user") or asset.get("hostname") or "unknown")

    def _primary_observable(self, event: dict[str, Any]) -> str:
        """Return the network domain, else src_ip, else dst_ip, else ``"unknown"``."""
        network = event.get("network") or {}
        return str(
            network.get("domain")
            or network.get("src_ip")
            or network.get("dst_ip")
            or "unknown"
        )

    def _incident_key(self, event: dict[str, Any]) -> str:
        """Build the SHA-256 deduplication key for *event*.

        The key hashes event type, primary subject, primary observable and the
        UTC calendar day of the event timestamp, joined with ``|``.

        Raises:
            KeyError: if the event has no ``timestamp``.
            ValueError: if the timestamp cannot be parsed.
        """
        day_bucket = self._parse_timestamp(event["timestamp"]).strftime("%Y-%m-%d")
        parts = [
            str(event.get("event_type", "generic")),
            self._primary_subject(event),
            self._primary_observable(event),
            day_bucket,
        ]
        return hashlib.sha256("|".join(parts).encode("utf-8")).hexdigest()

    def _effective_severity(
        self, event: dict[str, Any], policy: dict[str, Any]
    ) -> tuple[str, int, list[str]]:
        """Score *event* against *policy* and return the adjusted severity.

        Returns:
            ``(severity, score, factors)``. The score sums the policy weights
            of every risk factor that applies; factors lists their names in
            evaluation order. Severity is only ever raised: to ``high`` when
            the score reaches the high threshold, or from ``low`` to
            ``medium`` when it reaches the medium threshold.
        """
        severity = str(event.get("severity", "medium")).lower()
        risk_context = event.get("risk_context") or {}
        network = event.get("network") or {}
        risk_policy = policy.get("risk") or {}
        weights = risk_policy.get("weights") or {}
        thresholds = risk_policy.get("thresholds") or {}

        score = 0
        factors: list[str] = []

        # Compare both sides upper-cased so a lower-case policy value still matches.
        vpn_policy = policy.get("vpn") or {}
        allowed_country = str(vpn_policy.get("allowed_country", "TH")).upper()
        country = str(network.get("country") or "").upper()
        if country and country != allowed_country:
            score += int(weights.get("outside_thailand", 50))
            factors.append("outside_country")

        if risk_context.get("admin_account"):
            score += int(weights.get("admin", 20))
            factors.append("admin_account")

        ts = self._parse_timestamp(event["timestamp"])
        if risk_context.get("off_hours") or self._is_off_hours(ts):
            score += int(weights.get("off_hours", 15))
            factors.append("off_hours")

        if risk_context.get("first_seen_country"):
            score += int(weights.get("first_seen_country", 15))
            factors.append("first_seen_country")

        if score >= int(thresholds.get("high", 70)):
            if severity in {"low", "medium"}:
                severity = "high"
        elif score >= int(thresholds.get("medium", 40)) and severity == "low":
            severity = "medium"
        return severity, score, factors

    def _is_exception(self, event: dict[str, Any], policy: dict[str, Any]) -> bool:
        """Return True for a VPN geo anomaly whose user is on the policy exception list."""
        if event.get("event_type") != "vpn_geo_anomaly":
            return False
        asset = event.get("asset") or {}
        user = str(asset.get("user") or "")
        vpn_policy = policy.get("vpn") or {}
        return user in set(vpn_policy.get("exception_users") or [])

    def _extract_iris_case_id(self, iris_response: dict[str, Any]) -> str | None:
        """Return the case id from an IRIS response, or None when absent.

        Looks at the top level first, then under ``data``. A null ``case_id``
        is treated as absent rather than converted to the string ``"None"``.
        """
        if not isinstance(iris_response, dict):
            return None
        for container in (iris_response, iris_response.get("data")):
            if isinstance(container, dict) and container.get("case_id") is not None:
                return str(container["case_id"])
        return None

    def _parse_kv_pairs(self, text: str) -> dict[str, str]:
        """Extract ``key=value`` tokens from *text*; later keys override earlier ones.

        Surrounding whitespace and quote characters are stripped from values.
        """
        return {
            key: raw.strip().strip("'").strip('"')
            for key, raw in self._KV_PATTERN.findall(text)
        }

    def _severity_from_rule_level(self, rule_level: Any) -> str:
        """Map a numeric rule level to a severity; non-numeric input gives ``medium``."""
        try:
            level = int(rule_level)
        except (TypeError, ValueError):
            return "medium"
        if level >= 12:
            return "critical"
        if level >= 8:
            return "high"
        if level >= 4:
            return "medium"
        return "low"

    def _event_type_from_text(self, text: str, parsed: dict[str, str]) -> str:
        """Classify a log line, preferring an explicit ``event_type`` key.

        Without an explicit value, keyword checks run in a fixed order and the
        first match wins; the fallback is ``generic``.
        """
        explicit = parsed.get("event_type")
        if explicit:
            return explicit
        lowered = text.lower()
        if "vpn" in lowered and ("geo" in lowered or "country" in lowered):
            return "vpn_geo_anomaly"
        if "domain" in lowered or "dns" in lowered:
            return "ioc_dns"
        if "c2" in lowered or "ips" in lowered or "ip " in lowered:
            return "ioc_ips"
        if "auth" in lowered and "fail" in lowered:
            return "auth_anomaly"
        return "generic"

    def _fallback_event_id(self, src: dict[str, Any]) -> str:
        """Derive a deterministic id from the alert content.

        A content hash keeps distinct id-less alerts apart and gives the same
        alert the same id on a later sync, unlike a per-second clock value.
        """
        canonical = json.dumps(src, sort_keys=True, default=str)
        digest = hashlib.sha256(canonical.encode("utf-8")).hexdigest()
        return f"wazuh-{digest[:16]}"

    def _normalize_wazuh_hit(self, hit: dict[str, Any]) -> dict[str, Any]:
        """Convert one Wazuh search hit into the service's common event dict.

        Fields come from ``key=value`` pairs in ``full_log`` where present,
        falling back to the hit's ``rule`` and ``agent`` sections. Missing or
        null sections are treated as empty.
        """
        src = hit.get("_source") or {}
        full_log = str(src.get("full_log") or "")
        parsed = self._parse_kv_pairs(full_log)

        rule = src.get("rule") if isinstance(src.get("rule"), dict) else {}
        agent = src.get("agent") if isinstance(src.get("agent"), dict) else {}

        event_id = str(
            parsed.get("event_id")
            or src.get("id")
            or hit.get("_id")
            or self._fallback_event_id(src)
        )
        timestamp = (
            src.get("@timestamp")
            or src.get("timestamp")
            or datetime.now(timezone.utc).isoformat()
        )
        rule_desc = str(rule.get("description") or "")
        event_type = self._event_type_from_text(full_log, parsed)
        severity = str(parsed.get("severity", "")).lower() or self._severity_from_rule_level(
            rule.get("level")
        )
        country = parsed.get("country")

        return {
            "source": "wazuh",
            "event_type": event_type,
            "event_id": event_id,
            "timestamp": timestamp,
            "severity": severity if severity in self._SEVERITIES else "medium",
            "title": rule_desc or f"Wazuh alert {rule.get('id', '')}".strip(),
            "description": full_log or rule_desc or "Wazuh alert",
            "asset": {
                "user": parsed.get("user") or agent.get("name"),
                "hostname": agent.get("name"),
                "agent_id": agent.get("id"),
            },
            "network": {
                "src_ip": parsed.get("src_ip"),
                "dst_ip": parsed.get("dst_ip"),
                "domain": parsed.get("query") or parsed.get("domain"),
                "country": country,
            },
            "tags": ["wazuh", event_type, f"rule_{rule.get('id', 'unknown')}"],
            "risk_context": {
                "outside_thailand": bool(country and str(country).upper() != "TH"),
            },
            "raw": src,
            "payload": {},
        }

    def _to_float(self, value: Any, default: float = 0.0) -> float:
        """Convert *value* to float, returning *default* when it cannot be converted."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    def _severity_from_confidence(self, confidence: float) -> str:
        """Map a confidence to ``high`` (>= 0.9), ``medium`` (>= 0.7) or ``low``."""
        if confidence >= 0.9:
            return "high"
        if confidence >= 0.7:
            return "medium"
        return "low"

    def _extract_shuffle_verdict(self, shuffle_result: dict[str, Any] | None) -> dict[str, Any]:
        """Reduce a Shuffle workflow result to a uniform verdict dict.

        Values under a nested ``result`` dict are merged with the top level,
        top-level keys taking precedence. An explicit boolean ``matched`` is
        honoured; otherwise a confidence of at least 0.7 counts as a match.
        """
        if not isinstance(shuffle_result, dict):
            return {
                "matched": False,
                "confidence": 0.0,
                "severity": "low",
                "evidence": "",
                "iocs": [],
                "reason": "no_shuffle_result",
                "raw": None,
            }

        nested = shuffle_result.get("result")
        flat = {**nested, **shuffle_result} if isinstance(nested, dict) else dict(shuffle_result)

        confidence = self._to_float(flat.get("confidence"), 0.0)
        matched_raw = flat.get("matched")
        if isinstance(matched_raw, bool):
            matched, reason = matched_raw, "shuffle_explicit"
        else:
            matched, reason = confidence >= 0.7, "confidence_threshold_fallback"

        severity = str(flat.get("severity", "")).lower()
        if severity not in self._SEVERITIES:
            severity = self._severity_from_confidence(confidence)

        return {
            "matched": matched,
            "confidence": confidence,
            "severity": severity,
            "evidence": str(flat.get("evidence", "")),
            "iocs": flat.get("iocs") or [],
            "reason": reason,
            "raw": shuffle_result,
        }

    # ------------------------------------------------------------------
    # Incident ingestion
    # ------------------------------------------------------------------
    def _record_event(
        self, incident_key: str, event: dict[str, Any], decision_trace: dict[str, Any]
    ) -> None:
        """Persist *event* and its decision trace under *incident_key*."""
        self.repo.add_event(
            incident_key=incident_key,
            event_id=event.get("event_id"),
            source=event.get("source", "unknown"),
            event_type=event.get("event_type", "generic"),
            raw_payload=event,
            decision_trace=decision_trace,
        )

    async def _sync_iris_case(
        self, event: dict[str, Any], current: dict[str, Any] | None
    ) -> str | None:
        """Create the IRIS case for a new incident or update the existing one.

        Returns the case id, which may be None when creation returned no id.
        Update failures are logged and ignored so ingestion still completes.
        """
        iris_case_id = current.get("iris_case_id") if current else None
        if not iris_case_id:
            overrides = event.get("payload") or {}
            case_payload = {
                "case_name": event.get("title", "SOC Incident"),
                "case_description": event.get("description", "Generated by soc-integrator MVP"),
                "case_customer": overrides.get("case_customer", settings.iris_default_customer_id),
                "case_soc_id": overrides.get("case_soc_id", settings.iris_default_soc_id),
            }
            iris_result = await self.iris_adapter.create_case(case_payload)
            return self._extract_iris_case_id(iris_result)

        update_payload = {
            "case_description": f"{event.get('description', 'Updated by soc-integrator MVP')} [event_id={event.get('event_id', '')}]"
        }
        try:
            await self.iris_adapter.update_case(iris_case_id, update_payload)
        except Exception:
            # Keep the pipeline progressing for the MVP even if the update
            # path is unsupported, but leave a trace instead of hiding it.
            logger.warning(
                "iris case update failed case_id=%s; continuing", iris_case_id, exc_info=True
            )
        return iris_case_id

    async def _escalate(
        self,
        event: dict[str, Any],
        incident_key: str,
        severity: str,
        iris_case_id: str | None,
    ) -> tuple[bool, dict[str, Any]]:
        """Send the incident to PagerDuty and audit the outcome.

        Returns ``(sent, response)``; a failure is captured in the response
        and the audit record instead of being raised.
        """
        escalation_payload = {
            "incident_key": incident_key,
            "title": event.get("title", "SOC Incident"),
            "severity": severity,
            "source": event.get("source", "soc-integrator"),
            "iris_case_id": iris_case_id,
            "event_summary": event.get("description", ""),
            "timestamp": event.get("timestamp"),
        }
        try:
            pd_result = await self.pagerduty_adapter.create_incident(escalation_payload)
            self.repo.add_escalation_audit(
                incident_key=incident_key,
                status_code=200,
                success=True,
                response_excerpt=self._safe_excerpt(pd_result),
            )
            return True, {"ok": True, "data": pd_result}
        except Exception as exc:
            self.repo.add_escalation_audit(
                incident_key=incident_key,
                status_code=502,
                success=False,
                response_excerpt=self._safe_excerpt(exc),
            )
            return False, {"ok": False, "error": str(exc)}

    async def ingest_incident(self, event: dict[str, Any]) -> dict[str, Any]:
        """Run one normalised event through the incident pipeline.

        Policy exceptions are recorded with status ``ignored_exception`` and
        return early. Otherwise the IRIS case is created or updated, the
        incident and event are stored, and the incident is escalated when its
        effective severity is in the policy's ``escalate_severities``.

        Returns:
            A summary with ``incident_key``, ``action_taken``,
            ``escalation_stub_sent`` and ``decision_trace``; non-exception
            results also carry ``iris_case_id`` and ``stub_response``.
        """
        policy = self.repo.get_policy()
        incident_key = self._incident_key(event)

        if self._is_exception(event, policy):
            decision_trace = {
                "incident_key": incident_key,
                "policy_exception": True,
                "reason": "vpn_exception_user",
            }
            self.repo.upsert_incident(
                incident_key, severity="low", status="ignored_exception", iris_case_id=None
            )
            self._record_event(incident_key, event, decision_trace)
            return {
                "incident_key": incident_key,
                "action_taken": "ignored_exception",
                "escalation_stub_sent": False,
                "decision_trace": decision_trace,
            }

        effective_severity, risk_score, risk_factors = self._effective_severity(event, policy)
        current = self.repo.get_incident(incident_key)
        action_taken = "updated_case" if current else "created_case"
        iris_case_id = await self._sync_iris_case(event, current)

        stored = self.repo.upsert_incident(
            incident_key=incident_key,
            severity=effective_severity,
            status="open",
            iris_case_id=iris_case_id,
        )
        decision_trace = {
            "incident_key": incident_key,
            "risk_score": risk_score,
            "risk_factors": risk_factors,
            "effective_severity": effective_severity,
            "action_taken": action_taken,
        }
        self._record_event(incident_key, event, decision_trace)

        escalation_stub_sent = False
        stub_response: dict[str, Any] | None = None
        escalate_severities = set(policy.get("escalate_severities", ["high", "critical"]))
        if effective_severity in escalate_severities:
            escalation_stub_sent, stub_response = await self._escalate(
                event, incident_key, effective_severity, iris_case_id
            )

        return {
            "incident_key": stored["incident_key"],
            "action_taken": action_taken,
            "iris_case_id": stored.get("iris_case_id"),
            "escalation_stub_sent": escalation_stub_sent,
            "stub_response": stub_response,
            "decision_trace": decision_trace,
        }

    # ------------------------------------------------------------------
    # Evaluation entry points
    # ------------------------------------------------------------------
    async def evaluate_ioc(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Evaluate an IOC through the policy's Shuffle workflow.

        The workflow is triggered only when the policy names one. A matched
        verdict is turned into a ``shuffle`` event and ingested; otherwise the
        result's ``action_taken`` is ``rejected``.
        """
        policy = self.repo.get_policy()
        shuffle_policy = policy.get("shuffle") or {}
        workflow_id = str(shuffle_policy.get("ioc_workflow_id", "")).strip()

        shuffle_result: dict[str, Any] | None = None
        if workflow_id:
            shuffle_result = await self.shuffle_adapter.trigger_workflow(workflow_id, payload)

        verdict = self._extract_shuffle_verdict(shuffle_result)
        matched = bool(verdict["matched"])
        confidence = self._to_float(verdict["confidence"], 0.0)
        logger.info(
            "ioc evaluation workflow_id=%s matched=%s confidence=%.2f",
            workflow_id or "<none>",
            matched,
            confidence,
        )

        if matched:
            src_event = payload.get("source_event") or {}
            event_id = str(src_event.get("event_id") or f"ioc-{int(time.time())}")
            description = f"IOC evaluation result confidence={confidence:.2f}"
            evidence = str(verdict.get("evidence", "")).strip()
            if evidence:
                description = f"{description} evidence={evidence[:180]}"
            event = {
                "source": "shuffle",
                "event_type": "ioc_dns" if payload.get("ioc_type") == "domain" else "ioc_ips",
                "event_id": event_id,
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "severity": verdict["severity"],
                "title": f"IOC match: {payload.get('ioc_value', 'unknown')}",
                "description": description,
                "asset": src_event.get("asset", {}),
                "network": src_event.get("network", {}),
                "tags": ["ioc", str(payload.get("ioc_type", "unknown"))],
                "risk_context": {},
                "raw": {
                    "payload": payload,
                    "shuffle": verdict.get("raw"),
                },
                "payload": {},
            }
            ingest_result = await self.ingest_incident(event)
        else:
            ingest_result = {"action_taken": "rejected"}

        return {
            "matched": matched,
            "confidence": confidence,
            "severity": verdict["severity"],
            "evidence": verdict["evidence"],
            "iocs": verdict["iocs"],
            "decision_source": verdict["reason"],
            "shuffle": shuffle_result,
            "result": ingest_result,
        }

    async def evaluate_vpn(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Evaluate a VPN login and ingest it as a geo-anomaly event.

        Logins without a truthy ``success`` are rejected without ingestion.
        """
        if not payload.get("success", False):
            return {
                "risk_score": 0,
                "risk_factors": [],
                "exception_applied": False,
                "action_taken": "rejected",
            }

        # A missing or null country is "unknown", not "outside Thailand".
        country_code = str(payload.get("country_code") or "")
        event = {
            "source": "wazuh",
            "event_type": "vpn_geo_anomaly",
            "event_id": payload.get("event_id") or f"vpn-{int(time.time())}",
            "timestamp": payload.get("event_time") or datetime.now(timezone.utc).isoformat(),
            "severity": "high",
            "title": f"VPN login anomaly: {payload.get('user', 'unknown')}",
            "description": f"VPN success from {payload.get('country_code', 'unknown')} for user {payload.get('user', 'unknown')}",
            "asset": {"user": payload.get("user")},
            "network": {"src_ip": payload.get("src_ip"), "country": payload.get("country_code")},
            "tags": ["vpn", "geo-anomaly"],
            "risk_context": {
                "outside_thailand": bool(country_code) and country_code.upper() != "TH",
                "admin_account": bool(payload.get("is_admin", False)),
                "off_hours": bool(payload.get("off_hours", False)),
                "first_seen_country": bool(payload.get("first_seen_country", False)),
            },
            "raw": payload,
            "payload": {},
        }
        ingest_result = await self.ingest_incident(event)
        decision_trace = ingest_result.get("decision_trace", {})
        return {
            "risk_score": decision_trace.get("risk_score", 0),
            "risk_factors": decision_trace.get("risk_factors", []),
            "exception_applied": ingest_result.get("action_taken") == "ignored_exception",
            "action_taken": ingest_result.get("action_taken"),
            "incident_key": ingest_result.get("incident_key"),
            "iris_case_id": ingest_result.get("iris_case_id"),
            "escalation_stub_sent": ingest_result.get("escalation_stub_sent", False),
        }

    async def sync_wazuh_alerts(
        self,
        query: str = "soc_mvp_test=true OR event_type:*",
        limit: int = 50,
        minutes: int = 120,
    ) -> dict[str, Any]:
        """Pull recent Wazuh alerts and push each new one through the pipeline.

        IOC-type alerts go through :meth:`evaluate_ioc`; everything else is
        ingested directly. A failure on one alert is counted and reported
        without stopping the rest of the batch.

        Returns:
            Counters for the run, the incident keys touched, up to ten error
            strings and the search's ``total`` section.
        """
        raw = await self.wazuh_adapter.search_alerts(query=query, limit=limit, minutes=minutes)
        hits_section = raw.get("hits") if isinstance(raw, dict) else None
        if not isinstance(hits_section, dict):
            hits_section = {}
        hits = hits_section.get("hits") or []

        processed = ingested = skipped_existing = failed = 0
        ioc_evaluated = ioc_matched = ioc_rejected = 0
        errors: list[str] = []
        created_incidents: list[str] = []

        for hit in hits:
            processed += 1
            event_id = ""
            try:
                event = self._normalize_wazuh_hit(hit)
                event_id = str(event.get("event_id", "")).strip()
                # Matched IOC alerts are stored with source "shuffle", so both
                # sources must be checked to recognise an already-seen alert.
                if event_id and (
                    self.repo.has_event("wazuh", event_id)
                    or self.repo.has_event("shuffle", event_id)
                ):
                    skipped_existing += 1
                    continue

                if event.get("event_type") in {"ioc_dns", "ioc_ips"}:
                    ioc_evaluated += 1
                    network = event.get("network") or {}
                    payload = {
                        "ioc_type": "domain" if event.get("event_type") == "ioc_dns" else "ip",
                        "ioc_value": network.get("domain")
                        or network.get("dst_ip")
                        or network.get("src_ip")
                        or "unknown",
                        "source_event": {
                            "event_id": event.get("event_id"),
                            "asset": event.get("asset", {}),
                            "network": event.get("network", {}),
                            "raw": event.get("raw", {}),
                        },
                    }
                    ioc_result = await self.evaluate_ioc(payload)
                    if not ioc_result.get("matched"):
                        ioc_rejected += 1
                        continue
                    ioc_matched += 1
                    result = ioc_result.get("result") or {}
                else:
                    result = await self.ingest_incident(event)

                ingested += 1
                incident_key = str(result.get("incident_key", ""))
                if incident_key:
                    created_incidents.append(incident_key)
            except Exception as exc:
                failed += 1
                hit_id = hit.get("_id") if isinstance(hit, dict) else None
                label = event_id or (str(hit_id) if hit_id else "unknown_event")
                errors.append(f"{label}: {exc}")

        return {
            "query": query,
            "window_minutes": minutes,
            "limit": limit,
            "processed": processed,
            "ingested": ingested,
            "skipped_existing": skipped_existing,
            "failed": failed,
            "ioc_evaluated": ioc_evaluated,
            "ioc_matched": ioc_matched,
            "ioc_rejected": ioc_rejected,
            "incident_keys": created_incidents,
            "errors": errors[:10],
            "total_hits": hits_section.get("total", {}),
        }

    async def dependency_health(self) -> dict[str, Any]:
        """Probe each dependency in turn and report status and latency.

        Returns a dict keyed by dependency name. Each entry has ``status``
        (``up`` or ``down``) and ``latency_ms``, plus ``details`` on success
        or ``error`` on failure.
        """
        out: dict[str, Any] = {}

        async def timed(name: str, fn: Callable[[], Awaitable[Any]]) -> None:
            # perf_counter is monotonic, so clock adjustments cannot skew latency.
            start = time.perf_counter()
            try:
                result = await fn()
            except Exception as exc:
                out[name] = {
                    "status": "down",
                    "latency_ms": round((time.perf_counter() - start) * 1000, 2),
                    "error": str(exc),
                }
            else:
                out[name] = {
                    "status": "up",
                    "latency_ms": round((time.perf_counter() - start) * 1000, 2),
                    "details": result,
                }

        async def pagerduty_stub_health() -> dict[str, Any]:
            async with httpx.AsyncClient(timeout=10.0) as client:
                r = await client.get(settings.pagerduty_base_url)
                r.raise_for_status()
                return {"status_code": r.status_code}

        await timed("wazuh", self.wazuh_adapter.auth_test)
        await timed("shuffle", self.shuffle_adapter.health)
        await timed("iris", self.iris_adapter.whoami)
        await timed("pagerduty_stub", pagerduty_stub_health)
        return out
|