No description provided yet.

mvp_service.py 23KB

# (line-number gutter from the source viewer removed; not part of the file)
  1. from __future__ import annotations
  2. import hashlib
  3. import logging
  4. import re
  5. import time
  6. from datetime import datetime, timezone
  7. from typing import Any
  8. import httpx
  9. from app.adapters.iris import IrisAdapter
  10. from app.adapters.pagerduty import PagerDutyAdapter
  11. from app.adapters.shuffle import ShuffleAdapter
  12. from app.adapters.wazuh import WazuhAdapter
  13. from app.config import settings
  14. from app.repositories.mvp_repo import MvpRepository
  15. logger = logging.getLogger(__name__)
  16. class MvpService:
  17. def __init__(
  18. self,
  19. repo: MvpRepository,
  20. wazuh_adapter: WazuhAdapter,
  21. shuffle_adapter: ShuffleAdapter,
  22. iris_adapter: IrisAdapter,
  23. pagerduty_adapter: PagerDutyAdapter,
  24. ) -> None:
  25. self.repo = repo
  26. self.wazuh_adapter = wazuh_adapter
  27. self.shuffle_adapter = shuffle_adapter
  28. self.iris_adapter = iris_adapter
  29. self.pagerduty_adapter = pagerduty_adapter
  30. def _is_off_hours(self, ts: datetime) -> bool:
  31. hour = ts.astimezone(timezone.utc).hour
  32. return hour < 6 or hour >= 20
  33. def _safe_excerpt(self, payload: Any) -> str:
  34. text = str(payload)
  35. return text[:300]
  36. def _primary_subject(self, event: dict[str, Any]) -> str:
  37. asset = event.get("asset", {})
  38. return str(asset.get("user") or asset.get("hostname") or "unknown")
  39. def _primary_observable(self, event: dict[str, Any]) -> str:
  40. network = event.get("network", {})
  41. return str(network.get("domain") or network.get("src_ip") or network.get("dst_ip") or "unknown")
  42. def _incident_key(self, event: dict[str, Any]) -> str:
  43. ts = datetime.fromisoformat(event["timestamp"].replace("Z", "+00:00")).astimezone(timezone.utc)
  44. day_bucket = ts.strftime("%Y-%m-%d")
  45. raw = "|".join(
  46. [
  47. str(event.get("event_type", "generic")),
  48. self._primary_subject(event),
  49. self._primary_observable(event),
  50. day_bucket,
  51. ]
  52. )
  53. return hashlib.sha256(raw.encode("utf-8")).hexdigest()
  54. def _effective_severity(self, event: dict[str, Any], policy: dict[str, Any]) -> tuple[str, int, list[str]]:
  55. severity = str(event.get("severity", "medium")).lower()
  56. risk_context = event.get("risk_context", {})
  57. network = event.get("network", {})
  58. weights = policy.get("risk", {}).get("weights", {})
  59. score = 0
  60. factors: list[str] = []
  61. allowed_country = policy.get("vpn", {}).get("allowed_country", "TH")
  62. country = str(network.get("country", "")).upper()
  63. if country and country != allowed_country:
  64. score += int(weights.get("outside_thailand", 50))
  65. factors.append("outside_country")
  66. if risk_context.get("admin_account"):
  67. score += int(weights.get("admin", 20))
  68. factors.append("admin_account")
  69. ts = datetime.fromisoformat(event["timestamp"].replace("Z", "+00:00"))
  70. if risk_context.get("off_hours") or self._is_off_hours(ts):
  71. score += int(weights.get("off_hours", 15))
  72. factors.append("off_hours")
  73. if risk_context.get("first_seen_country"):
  74. score += int(weights.get("first_seen_country", 15))
  75. factors.append("first_seen_country")
  76. thresholds = policy.get("risk", {}).get("thresholds", {})
  77. if score >= int(thresholds.get("high", 70)):
  78. severity = "high" if severity in {"low", "medium"} else severity
  79. elif score >= int(thresholds.get("medium", 40)) and severity == "low":
  80. severity = "medium"
  81. return severity, score, factors
  82. def _is_exception(self, event: dict[str, Any], policy: dict[str, Any]) -> bool:
  83. if event.get("event_type") != "vpn_geo_anomaly":
  84. return False
  85. asset = event.get("asset", {})
  86. user = str(asset.get("user", ""))
  87. allowed_users = set(policy.get("vpn", {}).get("exception_users", []))
  88. return user in allowed_users
  89. def _extract_iris_case_id(self, iris_response: dict[str, Any]) -> str | None:
  90. if "case_id" in iris_response:
  91. return str(iris_response.get("case_id"))
  92. data = iris_response.get("data")
  93. if isinstance(data, dict) and "case_id" in data:
  94. return str(data.get("case_id"))
  95. return None
  96. def _parse_kv_pairs(self, text: str) -> dict[str, str]:
  97. pattern = r"([A-Za-z0-9_]+)=('(?:[^']*)'|\"(?:[^\"]*)\"|[^\s]+)"
  98. out: dict[str, str] = {}
  99. for key, raw in re.findall(pattern, text):
  100. value = raw.strip().strip("'").strip('"')
  101. out[key] = value
  102. return out
  103. def _severity_from_rule_level(self, rule_level: Any) -> str:
  104. try:
  105. level = int(rule_level)
  106. except (TypeError, ValueError):
  107. return "medium"
  108. if level >= 12:
  109. return "critical"
  110. if level >= 8:
  111. return "high"
  112. if level >= 4:
  113. return "medium"
  114. return "low"
  115. def _event_type_from_text(self, text: str, parsed: dict[str, str]) -> str:
  116. explicit = parsed.get("event_type")
  117. if explicit:
  118. return explicit
  119. lowered = text.lower()
  120. if "vpn" in lowered and ("geo" in lowered or "country" in lowered):
  121. return "vpn_geo_anomaly"
  122. if "domain" in lowered or "dns" in lowered:
  123. return "ioc_dns"
  124. if "c2" in lowered or "ips" in lowered or "ip " in lowered:
  125. return "ioc_ips"
  126. if "auth" in lowered and "fail" in lowered:
  127. return "auth_anomaly"
  128. return "generic"
  129. def _normalize_wazuh_hit(self, hit: dict[str, Any]) -> dict[str, Any]:
  130. src = hit.get("_source", {})
  131. full_log = str(src.get("full_log", ""))
  132. parsed = self._parse_kv_pairs(full_log)
  133. event_id = str(parsed.get("event_id") or src.get("id") or hit.get("_id") or f"wazuh-{int(time.time())}")
  134. timestamp = (
  135. src.get("@timestamp")
  136. or src.get("timestamp")
  137. or datetime.now(timezone.utc).isoformat()
  138. )
  139. rule = src.get("rule", {}) if isinstance(src.get("rule"), dict) else {}
  140. rule_desc = str(rule.get("description") or "")
  141. event_type = self._event_type_from_text(full_log, parsed)
  142. severity = str(parsed.get("severity", "")).lower() or self._severity_from_rule_level(rule.get("level"))
  143. src_ip = parsed.get("src_ip")
  144. dst_ip = parsed.get("dst_ip")
  145. domain = parsed.get("query") or parsed.get("domain")
  146. country = parsed.get("country")
  147. user = parsed.get("user") or (src.get("agent", {}) or {}).get("name")
  148. title = rule_desc or f"Wazuh alert {rule.get('id', '')}".strip()
  149. description = full_log or rule_desc or "Wazuh alert"
  150. return {
  151. "source": "wazuh",
  152. "event_type": event_type,
  153. "event_id": event_id,
  154. "timestamp": timestamp,
  155. "severity": severity if severity in {"low", "medium", "high", "critical"} else "medium",
  156. "title": title,
  157. "description": description,
  158. "asset": {
  159. "user": user,
  160. "hostname": (src.get("agent", {}) or {}).get("name"),
  161. "agent_id": (src.get("agent", {}) or {}).get("id"),
  162. },
  163. "network": {
  164. "src_ip": src_ip,
  165. "dst_ip": dst_ip,
  166. "domain": domain,
  167. "country": country,
  168. },
  169. "tags": ["wazuh", event_type, f"rule_{rule.get('id', 'unknown')}"],
  170. "risk_context": {
  171. "outside_thailand": bool(country and str(country).upper() != "TH"),
  172. },
  173. "raw": src,
  174. "payload": {},
  175. }
  176. def _to_float(self, value: Any, default: float = 0.0) -> float:
  177. try:
  178. return float(value)
  179. except (TypeError, ValueError):
  180. return default
  181. def _severity_from_confidence(self, confidence: float) -> str:
  182. if confidence >= 0.9:
  183. return "high"
  184. if confidence >= 0.7:
  185. return "medium"
  186. return "low"
  187. def _extract_shuffle_verdict(self, shuffle_result: dict[str, Any] | None) -> dict[str, Any]:
  188. if not isinstance(shuffle_result, dict):
  189. return {
  190. "matched": False,
  191. "confidence": 0.0,
  192. "severity": "low",
  193. "evidence": "",
  194. "iocs": [],
  195. "reason": "no_shuffle_result",
  196. }
  197. flat = dict(shuffle_result)
  198. nested = shuffle_result.get("result")
  199. if isinstance(nested, dict):
  200. merged = dict(nested)
  201. merged.update(flat)
  202. flat = merged
  203. confidence = self._to_float(flat.get("confidence"), 0.0)
  204. matched_raw = flat.get("matched")
  205. if isinstance(matched_raw, bool):
  206. matched = matched_raw
  207. reason = "shuffle_explicit"
  208. else:
  209. matched = confidence >= 0.7
  210. reason = "confidence_threshold_fallback"
  211. severity_raw = str(flat.get("severity", "")).lower()
  212. severity = severity_raw if severity_raw in {"low", "medium", "high", "critical"} else self._severity_from_confidence(confidence)
  213. return {
  214. "matched": matched,
  215. "confidence": confidence,
  216. "severity": severity,
  217. "evidence": str(flat.get("evidence", "")),
  218. "iocs": flat.get("iocs", []),
  219. "reason": reason,
  220. "raw": shuffle_result,
  221. }
  222. async def ingest_incident(self, event: dict[str, Any]) -> dict[str, Any]:
  223. policy = self.repo.get_policy()
  224. incident_key = self._incident_key(event)
  225. if self._is_exception(event, policy):
  226. decision_trace = {
  227. "incident_key": incident_key,
  228. "policy_exception": True,
  229. "reason": "vpn_exception_user",
  230. }
  231. self.repo.upsert_incident(incident_key, severity="low", status="ignored_exception", iris_case_id=None)
  232. self.repo.add_event(
  233. incident_key=incident_key,
  234. event_id=event.get("event_id"),
  235. source=event.get("source", "unknown"),
  236. event_type=event.get("event_type", "generic"),
  237. raw_payload=event,
  238. decision_trace=decision_trace,
  239. )
  240. return {
  241. "incident_key": incident_key,
  242. "action_taken": "ignored_exception",
  243. "escalation_stub_sent": False,
  244. "decision_trace": decision_trace,
  245. }
  246. effective_severity, risk_score, risk_factors = self._effective_severity(event, policy)
  247. current = self.repo.get_incident(incident_key)
  248. action_taken = "updated_case" if current else "created_case"
  249. iris_case_id = current.get("iris_case_id") if current else None
  250. if not iris_case_id:
  251. case_payload = {
  252. "case_name": event.get("title", "SOC Incident"),
  253. "case_description": event.get("description", "Generated by soc-integrator MVP"),
  254. "case_customer": event.get("payload", {}).get("case_customer", settings.iris_default_customer_id),
  255. "case_soc_id": event.get("payload", {}).get("case_soc_id", settings.iris_default_soc_id),
  256. }
  257. iris_result = await self.iris_adapter.create_case(case_payload)
  258. iris_case_id = self._extract_iris_case_id(iris_result)
  259. else:
  260. update_payload = {
  261. "case_description": f"{event.get('description', 'Updated by soc-integrator MVP')} [event_id={event.get('event_id', '')}]"
  262. }
  263. try:
  264. await self.iris_adapter.update_case(iris_case_id, update_payload)
  265. except Exception:
  266. # Keep pipeline progressing for MVP even if update path is unsupported.
  267. pass
  268. stored = self.repo.upsert_incident(
  269. incident_key=incident_key,
  270. severity=effective_severity,
  271. status="open",
  272. iris_case_id=iris_case_id,
  273. )
  274. decision_trace = {
  275. "incident_key": incident_key,
  276. "risk_score": risk_score,
  277. "risk_factors": risk_factors,
  278. "effective_severity": effective_severity,
  279. "action_taken": action_taken,
  280. }
  281. self.repo.add_event(
  282. incident_key=incident_key,
  283. event_id=event.get("event_id"),
  284. source=event.get("source", "unknown"),
  285. event_type=event.get("event_type", "generic"),
  286. raw_payload=event,
  287. decision_trace=decision_trace,
  288. )
  289. escalate_severities = set(policy.get("escalate_severities", ["high", "critical"]))
  290. escalation_stub_sent = False
  291. stub_response: dict[str, Any] | None = None
  292. if effective_severity in escalate_severities:
  293. escalation_payload = {
  294. "incident_key": incident_key,
  295. "title": event.get("title", "SOC Incident"),
  296. "severity": effective_severity,
  297. "source": event.get("source", "soc-integrator"),
  298. "iris_case_id": iris_case_id,
  299. "event_summary": event.get("description", ""),
  300. "timestamp": event.get("timestamp"),
  301. }
  302. try:
  303. pd_result = await self.pagerduty_adapter.create_incident(escalation_payload)
  304. escalation_stub_sent = True
  305. stub_response = {"ok": True, "data": pd_result}
  306. self.repo.add_escalation_audit(
  307. incident_key=incident_key,
  308. status_code=200,
  309. success=True,
  310. response_excerpt=self._safe_excerpt(pd_result),
  311. )
  312. except Exception as exc:
  313. stub_response = {"ok": False, "error": str(exc)}
  314. self.repo.add_escalation_audit(
  315. incident_key=incident_key,
  316. status_code=502,
  317. success=False,
  318. response_excerpt=self._safe_excerpt(exc),
  319. )
  320. return {
  321. "incident_key": stored["incident_key"],
  322. "action_taken": action_taken,
  323. "iris_case_id": stored.get("iris_case_id"),
  324. "escalation_stub_sent": escalation_stub_sent,
  325. "stub_response": stub_response,
  326. "decision_trace": decision_trace,
  327. }
  328. async def evaluate_ioc(self, payload: dict[str, Any]) -> dict[str, Any]:
  329. policy = self.repo.get_policy()
  330. workflow_id = str(policy.get("shuffle", {}).get("ioc_workflow_id", "")).strip()
  331. shuffle_result: dict[str, Any] | None = None
  332. if workflow_id:
  333. shuffle_result = await self.shuffle_adapter.trigger_workflow(workflow_id, payload)
  334. verdict = self._extract_shuffle_verdict(shuffle_result)
  335. matched = bool(verdict["matched"])
  336. confidence = self._to_float(verdict["confidence"], 0.0)
  337. logger.info(
  338. "ioc evaluation workflow_id=%s matched=%s confidence=%.2f",
  339. workflow_id or "<none>",
  340. matched,
  341. confidence,
  342. )
  343. if matched:
  344. src_event = payload.get("source_event", {})
  345. event_id = src_event.get("event_id") or f"ioc-{int(time.time())}"
  346. if not isinstance(event_id, str):
  347. event_id = str(event_id)
  348. description = f"IOC evaluation result confidence={confidence:.2f}"
  349. evidence = str(verdict.get("evidence", "")).strip()
  350. if evidence:
  351. description = f"{description} evidence={evidence[:180]}"
  352. event = {
  353. "source": "shuffle",
  354. "event_type": "ioc_dns" if payload.get("ioc_type") == "domain" else "ioc_ips",
  355. "event_id": event_id,
  356. "timestamp": datetime.now(timezone.utc).isoformat(),
  357. "severity": verdict["severity"],
  358. "title": f"IOC match: {payload.get('ioc_value', 'unknown')}",
  359. "description": description,
  360. "asset": src_event.get("asset", {}),
  361. "network": src_event.get("network", {}),
  362. "tags": ["ioc", str(payload.get("ioc_type", "unknown"))],
  363. "risk_context": {},
  364. "raw": {
  365. "payload": payload,
  366. "shuffle": verdict.get("raw"),
  367. },
  368. "payload": {},
  369. }
  370. ingest_result = await self.ingest_incident(event)
  371. else:
  372. ingest_result = {"action_taken": "rejected"}
  373. return {
  374. "matched": matched,
  375. "confidence": confidence,
  376. "severity": verdict["severity"],
  377. "evidence": verdict["evidence"],
  378. "iocs": verdict["iocs"],
  379. "decision_source": verdict["reason"],
  380. "shuffle": shuffle_result,
  381. "result": ingest_result,
  382. }
  383. async def evaluate_vpn(self, payload: dict[str, Any]) -> dict[str, Any]:
  384. if not payload.get("success", False):
  385. return {
  386. "risk_score": 0,
  387. "risk_factors": [],
  388. "exception_applied": False,
  389. "action_taken": "rejected",
  390. }
  391. event = {
  392. "source": "wazuh",
  393. "event_type": "vpn_geo_anomaly",
  394. "event_id": payload.get("event_id") or f"vpn-{int(time.time())}",
  395. "timestamp": payload.get("event_time") or datetime.now(timezone.utc).isoformat(),
  396. "severity": "high",
  397. "title": f"VPN login anomaly: {payload.get('user', 'unknown')}",
  398. "description": f"VPN success from {payload.get('country_code', 'unknown')} for user {payload.get('user', 'unknown')}",
  399. "asset": {"user": payload.get("user")},
  400. "network": {"src_ip": payload.get("src_ip"), "country": payload.get("country_code")},
  401. "tags": ["vpn", "geo-anomaly"],
  402. "risk_context": {
  403. "outside_thailand": payload.get("country_code", "").upper() != "TH",
  404. "admin_account": bool(payload.get("is_admin", False)),
  405. "off_hours": bool(payload.get("off_hours", False)),
  406. "first_seen_country": bool(payload.get("first_seen_country", False)),
  407. },
  408. "raw": payload,
  409. "payload": {},
  410. }
  411. ingest_result = await self.ingest_incident(event)
  412. decision_trace = ingest_result.get("decision_trace", {})
  413. return {
  414. "risk_score": decision_trace.get("risk_score", 0),
  415. "risk_factors": decision_trace.get("risk_factors", []),
  416. "exception_applied": ingest_result.get("action_taken") == "ignored_exception",
  417. "action_taken": ingest_result.get("action_taken"),
  418. "incident_key": ingest_result.get("incident_key"),
  419. "iris_case_id": ingest_result.get("iris_case_id"),
  420. "escalation_stub_sent": ingest_result.get("escalation_stub_sent", False),
  421. }
  422. async def sync_wazuh_alerts(
  423. self,
  424. query: str = "soc_mvp_test=true OR event_type:*",
  425. limit: int = 50,
  426. minutes: int = 120,
  427. ) -> dict[str, Any]:
  428. raw = await self.wazuh_adapter.search_alerts(query=query, limit=limit, minutes=minutes)
  429. hits = (raw.get("hits", {}) or {}).get("hits", []) if isinstance(raw, dict) else []
  430. processed = 0
  431. ingested = 0
  432. skipped_existing = 0
  433. failed = 0
  434. errors: list[str] = []
  435. created_incidents: list[str] = []
  436. ioc_evaluated = 0
  437. ioc_matched = 0
  438. ioc_rejected = 0
  439. for hit in hits:
  440. processed += 1
  441. event = self._normalize_wazuh_hit(hit)
  442. event_id = str(event.get("event_id", "")).strip()
  443. if event_id and self.repo.has_event("wazuh", event_id):
  444. skipped_existing += 1
  445. continue
  446. try:
  447. if event.get("event_type") in {"ioc_dns", "ioc_ips"}:
  448. ioc_evaluated += 1
  449. payload = {
  450. "ioc_type": "domain" if event.get("event_type") == "ioc_dns" else "ip",
  451. "ioc_value": (event.get("network", {}) or {}).get("domain")
  452. or (event.get("network", {}) or {}).get("dst_ip")
  453. or (event.get("network", {}) or {}).get("src_ip")
  454. or "unknown",
  455. "source_event": {
  456. "event_id": event.get("event_id"),
  457. "asset": event.get("asset", {}),
  458. "network": event.get("network", {}),
  459. "raw": event.get("raw", {}),
  460. },
  461. }
  462. ioc_result = await self.evaluate_ioc(payload)
  463. if ioc_result.get("matched"):
  464. ioc_matched += 1
  465. ingested += 1
  466. incident_key = str((ioc_result.get("result", {}) or {}).get("incident_key", ""))
  467. if incident_key:
  468. created_incidents.append(incident_key)
  469. else:
  470. ioc_rejected += 1
  471. else:
  472. result = await self.ingest_incident(event)
  473. ingested += 1
  474. incident_key = str(result.get("incident_key", ""))
  475. if incident_key:
  476. created_incidents.append(incident_key)
  477. except Exception as exc:
  478. failed += 1
  479. errors.append(f"{event_id or 'unknown_event'}: {exc}")
  480. return {
  481. "query": query,
  482. "window_minutes": minutes,
  483. "limit": limit,
  484. "processed": processed,
  485. "ingested": ingested,
  486. "skipped_existing": skipped_existing,
  487. "failed": failed,
  488. "ioc_evaluated": ioc_evaluated,
  489. "ioc_matched": ioc_matched,
  490. "ioc_rejected": ioc_rejected,
  491. "incident_keys": created_incidents,
  492. "errors": errors[:10],
  493. "total_hits": (raw.get("hits", {}).get("total", {}) if isinstance(raw, dict) else {}),
  494. }
  495. async def dependency_health(self) -> dict[str, Any]:
  496. out: dict[str, Any] = {}
  497. async def timed(name: str, fn):
  498. start = time.time()
  499. try:
  500. result = await fn()
  501. out[name] = {
  502. "status": "up",
  503. "latency_ms": round((time.time() - start) * 1000, 2),
  504. "details": result,
  505. }
  506. except Exception as exc:
  507. out[name] = {
  508. "status": "down",
  509. "latency_ms": round((time.time() - start) * 1000, 2),
  510. "error": str(exc),
  511. }
  512. await timed("wazuh", self.wazuh_adapter.auth_test)
  513. await timed("shuffle", self.shuffle_adapter.health)
  514. await timed("iris", self.iris_adapter.whoami)
  515. async def pagerduty_stub_health():
  516. async with httpx.AsyncClient(timeout=10.0) as client:
  517. r = await client.get(settings.pagerduty_base_url)
  518. r.raise_for_status()
  519. return {"status_code": r.status_code}
  520. await timed("pagerduty_stub", pagerduty_stub_health)
  521. return out