# mvp_service.py — MVP SOC integration service (Wazuh / Shuffle / IRIS / PagerDuty pipeline).
  1. from __future__ import annotations
  2. import hashlib
  3. import json
  4. import logging
  5. import re
  6. import time
  7. from datetime import datetime, timezone
  8. from typing import Any
  9. import httpx
  10. from app.adapters.iris import IrisAdapter
  11. from app.adapters.pagerduty import PagerDutyAdapter
  12. from app.adapters.shuffle import ShuffleAdapter
  13. from app.adapters.wazuh import WazuhAdapter
  14. from app.config import settings
  15. from app.repositories.mvp_repo import MvpRepository
  16. logger = logging.getLogger(__name__)
  17. _IRIS_SEVERITY_ID: dict[str, int] = {
  18. "critical": 5,
  19. "high": 4,
  20. "medium": 3,
  21. "low": 2,
  22. "informational": 1,
  23. }
  24. _SEVERITY_ORDER: dict[str, int] = {
  25. "informational": 0,
  26. "low": 1,
  27. "medium": 2,
  28. "high": 3,
  29. "critical": 4,
  30. }
  31. class MvpService:
  32. def __init__(
  33. self,
  34. repo: MvpRepository,
  35. wazuh_adapter: WazuhAdapter,
  36. shuffle_adapter: ShuffleAdapter,
  37. iris_adapter: IrisAdapter,
  38. pagerduty_adapter: PagerDutyAdapter,
  39. ) -> None:
  40. self.repo = repo
  41. self.wazuh_adapter = wazuh_adapter
  42. self.shuffle_adapter = shuffle_adapter
  43. self.iris_adapter = iris_adapter
  44. self.pagerduty_adapter = pagerduty_adapter
  45. def _is_off_hours(self, ts: datetime) -> bool:
  46. hour = ts.astimezone(timezone.utc).hour
  47. return hour < 6 or hour >= 20
  48. def _safe_excerpt(self, payload: Any) -> str:
  49. text = str(payload)
  50. return text[:300]
  51. def _primary_subject(self, event: dict[str, Any]) -> str:
  52. asset = event.get("asset", {})
  53. return str(asset.get("user") or asset.get("hostname") or "unknown")
  54. def _primary_observable(self, event: dict[str, Any]) -> str:
  55. network = event.get("network", {})
  56. return str(network.get("domain") or network.get("src_ip") or network.get("dst_ip") or "unknown")
  57. def _incident_key(self, event: dict[str, Any]) -> str:
  58. ts = datetime.fromisoformat(event["timestamp"].replace("Z", "+00:00")).astimezone(timezone.utc)
  59. day_bucket = ts.strftime("%Y-%m-%d")
  60. raw = "|".join(
  61. [
  62. str(event.get("event_type", "generic")),
  63. self._primary_subject(event),
  64. self._primary_observable(event),
  65. day_bucket,
  66. ]
  67. )
  68. return hashlib.sha256(raw.encode("utf-8")).hexdigest()
  69. def _effective_severity(self, event: dict[str, Any], policy: dict[str, Any]) -> tuple[str, int, list[str]]:
  70. severity = str(event.get("severity", "medium")).lower()
  71. risk_context = event.get("risk_context", {})
  72. network = event.get("network", {})
  73. weights = policy.get("risk", {}).get("weights", {})
  74. score = 0
  75. factors: list[str] = []
  76. allowed_country = policy.get("vpn", {}).get("allowed_country", "TH")
  77. country = str(network.get("country", "")).upper()
  78. if country and country != allowed_country:
  79. score += int(weights.get("outside_thailand", 50))
  80. factors.append("outside_country")
  81. if risk_context.get("admin_account"):
  82. score += int(weights.get("admin", 20))
  83. factors.append("admin_account")
  84. ts = datetime.fromisoformat(event["timestamp"].replace("Z", "+00:00"))
  85. if risk_context.get("off_hours") or self._is_off_hours(ts):
  86. score += int(weights.get("off_hours", 15))
  87. factors.append("off_hours")
  88. if risk_context.get("first_seen_country"):
  89. score += int(weights.get("first_seen_country", 15))
  90. factors.append("first_seen_country")
  91. thresholds = policy.get("risk", {}).get("thresholds", {})
  92. if score >= int(thresholds.get("high", 70)):
  93. severity = "high" if severity in {"low", "medium"} else severity
  94. elif score >= int(thresholds.get("medium", 40)) and severity == "low":
  95. severity = "medium"
  96. return severity, score, factors
  97. def _is_exception(self, event: dict[str, Any], policy: dict[str, Any]) -> bool:
  98. if event.get("event_type") != "vpn_geo_anomaly":
  99. return False
  100. asset = event.get("asset", {})
  101. user = str(asset.get("user", ""))
  102. allowed_users = set(policy.get("vpn", {}).get("exception_users", []))
  103. return user in allowed_users
  104. def _extract_iris_case_id(self, iris_response: dict[str, Any]) -> str | None:
  105. if "case_id" in iris_response:
  106. return str(iris_response.get("case_id"))
  107. data = iris_response.get("data")
  108. if isinstance(data, dict) and "case_id" in data:
  109. return str(data.get("case_id"))
  110. return None
  111. def _parse_kv_pairs(self, text: str) -> dict[str, str]:
  112. pattern = r"([A-Za-z0-9_]+)=('(?:[^']*)'|\"(?:[^\"]*)\"|[^\s]+)"
  113. out: dict[str, str] = {}
  114. for key, raw in re.findall(pattern, text):
  115. value = raw.strip().strip("'").strip('"')
  116. out[key] = value
  117. return out
  118. def _severity_from_rule_level(self, rule_level: Any) -> str:
  119. try:
  120. level = int(rule_level)
  121. except (TypeError, ValueError):
  122. return "medium"
  123. if level >= 12:
  124. return "critical"
  125. if level >= 8:
  126. return "high"
  127. if level >= 4:
  128. return "medium"
  129. return "low"
  130. def _event_type_from_text(self, text: str, parsed: dict[str, str]) -> str:
  131. explicit = str(parsed.get("event_type") or "").strip().lower()
  132. usecase_id = str(parsed.get("usecase_id") or "").strip().upper()
  133. section = str(parsed.get("section") or "").strip().upper()
  134. source = str(parsed.get("source") or "").strip().lower()
  135. success = str(parsed.get("success") or "").strip().lower()
  136. has_geo = bool(parsed.get("country") or parsed.get("src_lat") or parsed.get("src_lon"))
  137. has_user = bool(parsed.get("user"))
  138. has_src_ip = bool(parsed.get("src_ip") or parsed.get("srcip"))
  139. explicit_success_login = explicit in {
  140. "vpn_login_success",
  141. "windows_auth_success",
  142. "auth_success",
  143. }
  144. # Production-first C1 detection:
  145. # successful auth/login + geo context on vpn/windows identity streams.
  146. if (
  147. (source in {"vpn", "fortigate", "windows", "identity"} or "vpn" in source)
  148. and has_geo
  149. and has_user
  150. and has_src_ip
  151. and (success == "true" or explicit_success_login)
  152. ):
  153. return "c1_impossible_travel"
  154. # Legacy simulator markers remain supported as fallback.
  155. if usecase_id.startswith("C1") or section == "C1":
  156. return "c1_impossible_travel"
  157. if explicit in {"c1_impossible_travel", "impossible_travel"}:
  158. return "c1_impossible_travel"
  159. if explicit == "vpn_geo_anomaly":
  160. return "vpn_geo_anomaly"
  161. if explicit:
  162. return explicit
  163. lowered = text.lower()
  164. if "impossible travel" in lowered:
  165. return "c1_impossible_travel"
  166. if "vpn" in lowered and ("geo" in lowered or "country" in lowered):
  167. return "vpn_geo_anomaly"
  168. if "domain" in lowered or "dns" in lowered:
  169. return "ioc_dns"
  170. if "c2" in lowered or "ips" in lowered or "ip " in lowered:
  171. return "ioc_ips"
  172. if "auth" in lowered and "fail" in lowered:
  173. return "auth_anomaly"
  174. return "generic"
  175. def _normalize_wazuh_hit(self, hit: dict[str, Any]) -> dict[str, Any]:
  176. src = hit.get("_source", {})
  177. full_log = str(src.get("full_log", ""))
  178. parsed = self._parse_kv_pairs(full_log)
  179. event_id = str(parsed.get("event_id") or src.get("id") or hit.get("_id") or f"wazuh-{int(time.time())}")
  180. timestamp = (
  181. src.get("@timestamp")
  182. or src.get("timestamp")
  183. or datetime.now(timezone.utc).isoformat()
  184. )
  185. rule = src.get("rule", {}) if isinstance(src.get("rule"), dict) else {}
  186. rule_desc = str(rule.get("description") or "")
  187. event_type = self._event_type_from_text(full_log, parsed)
  188. severity = str(parsed.get("severity", "")).lower() or self._severity_from_rule_level(rule.get("level"))
  189. src_ip = parsed.get("src_ip")
  190. if not src_ip:
  191. src_ip = parsed.get("srcip")
  192. dst_ip = parsed.get("dst_ip")
  193. if not dst_ip:
  194. dst_ip = parsed.get("dstip")
  195. domain = parsed.get("query") or parsed.get("domain")
  196. country = parsed.get("country")
  197. user = parsed.get("user") or (src.get("agent", {}) or {}).get("name")
  198. dst_port = parsed.get("dst_port") or parsed.get("dstport")
  199. event_action = parsed.get("event_action") or parsed.get("action")
  200. title = rule_desc or f"Wazuh alert {rule.get('id', '')}".strip()
  201. description = full_log or rule_desc or "Wazuh alert"
  202. src_lat_raw = parsed.get("src_lat")
  203. src_lon_raw = parsed.get("src_lon")
  204. try:
  205. src_lat = float(src_lat_raw) if src_lat_raw not in {None, ""} else None
  206. except (TypeError, ValueError):
  207. src_lat = None
  208. try:
  209. src_lon = float(src_lon_raw) if src_lon_raw not in {None, ""} else None
  210. except (TypeError, ValueError):
  211. src_lon = None
  212. return {
  213. "source": "wazuh",
  214. "event_type": event_type,
  215. "event_id": event_id,
  216. "timestamp": timestamp,
  217. "severity": severity if severity in {"low", "medium", "high", "critical"} else "medium",
  218. "title": title,
  219. "description": description,
  220. "asset": {
  221. "user": user,
  222. "hostname": (src.get("agent", {}) or {}).get("name"),
  223. "agent_id": (src.get("agent", {}) or {}).get("id"),
  224. },
  225. "network": {
  226. "src_ip": src_ip,
  227. "dst_ip": dst_ip,
  228. "dst_host": parsed.get("dst_host") or parsed.get("host"),
  229. "dst_port": int(dst_port) if str(dst_port or "").isdigit() else None,
  230. "domain": domain,
  231. "country": country,
  232. "src_lat": src_lat,
  233. "src_lon": src_lon,
  234. },
  235. "tags": ["wazuh", event_type, f"rule_{rule.get('id', 'unknown')}"],
  236. "risk_context": {
  237. "outside_thailand": bool(country and str(country).upper() != "TH"),
  238. },
  239. "raw": src,
  240. "payload": {
  241. **parsed,
  242. "event_action": event_action,
  243. "event_id": parsed.get("event_id"),
  244. "event_type": event_type,
  245. "success": parsed.get("success"),
  246. "logon_type": parsed.get("logon_type"),
  247. "account_type": parsed.get("account_type"),
  248. "is_admin": parsed.get("is_admin"),
  249. "is_service": parsed.get("is_service"),
  250. },
  251. }
  252. def normalize_wazuh_hit(self, hit: dict[str, Any]) -> dict[str, Any]:
  253. return self._normalize_wazuh_hit(hit)
  254. def _to_float(self, value: Any, default: float = 0.0) -> float:
  255. try:
  256. return float(value)
  257. except (TypeError, ValueError):
  258. return default
  259. def _severity_from_confidence(self, confidence: float) -> str:
  260. if confidence >= 0.9:
  261. return "high"
  262. if confidence >= 0.7:
  263. return "medium"
  264. return "low"
  265. def _extract_shuffle_verdict(self, shuffle_result: dict[str, Any] | None) -> dict[str, Any]:
  266. if not isinstance(shuffle_result, dict):
  267. return {
  268. "matched": False,
  269. "confidence": 0.0,
  270. "severity": "low",
  271. "evidence": "",
  272. "iocs": [],
  273. "reason": "no_shuffle_result",
  274. }
  275. flat = dict(shuffle_result)
  276. nested = shuffle_result.get("result")
  277. if isinstance(nested, dict):
  278. merged = dict(nested)
  279. merged.update(flat)
  280. flat = merged
  281. confidence = self._to_float(flat.get("confidence"), 0.0)
  282. matched_raw = flat.get("matched")
  283. if isinstance(matched_raw, bool):
  284. matched = matched_raw
  285. reason = "shuffle_explicit"
  286. else:
  287. matched = confidence >= 0.7
  288. reason = "confidence_threshold_fallback"
  289. severity_raw = str(flat.get("severity", "")).lower()
  290. severity = severity_raw if severity_raw in {"low", "medium", "high", "critical"} else self._severity_from_confidence(confidence)
  291. return {
  292. "matched": matched,
  293. "confidence": confidence,
  294. "severity": severity,
  295. "evidence": str(flat.get("evidence", "")),
  296. "iocs": flat.get("iocs", []),
  297. "reason": reason,
  298. "raw": shuffle_result,
  299. }
  300. async def ingest_incident(self, event: dict[str, Any]) -> dict[str, Any]:
  301. policy = self.repo.get_policy()
  302. incident_key = self._incident_key(event)
  303. if self._is_exception(event, policy):
  304. decision_trace = {
  305. "incident_key": incident_key,
  306. "policy_exception": True,
  307. "reason": "vpn_exception_user",
  308. }
  309. self.repo.upsert_incident(incident_key, severity="low", status="ignored_exception", iris_case_id=None)
  310. self.repo.add_event(
  311. incident_key=incident_key,
  312. event_id=event.get("event_id"),
  313. source=event.get("source", "unknown"),
  314. event_type=event.get("event_type", "generic"),
  315. raw_payload=event,
  316. decision_trace=decision_trace,
  317. )
  318. return {
  319. "incident_key": incident_key,
  320. "action_taken": "ignored_exception",
  321. "escalation_stub_sent": False,
  322. "decision_trace": decision_trace,
  323. }
  324. effective_severity, risk_score, risk_factors = self._effective_severity(event, policy)
  325. current = self.repo.get_incident(incident_key)
  326. action_taken = "updated_case" if current else "created_case"
  327. iris_case_id = current.get("iris_case_id") if current else None
  328. if not iris_case_id:
  329. case_payload = {
  330. "case_name": event.get("title", "SOC Incident"),
  331. "case_description": event.get("description", "Generated by soc-integrator MVP"),
  332. "case_customer": event.get("payload", {}).get("case_customer", settings.iris_default_customer_id),
  333. "case_soc_id": event.get("payload", {}).get("case_soc_id", settings.iris_default_soc_id),
  334. }
  335. iris_result = await self.iris_adapter.create_case(case_payload)
  336. iris_case_id = self._extract_iris_case_id(iris_result)
  337. else:
  338. update_payload = {
  339. "case_description": f"{event.get('description', 'Updated by soc-integrator MVP')} [event_id={event.get('event_id', '')}]"
  340. }
  341. try:
  342. await self.iris_adapter.update_case(iris_case_id, update_payload)
  343. except Exception:
  344. # Keep pipeline progressing for MVP even if update path is unsupported.
  345. pass
  346. stored = self.repo.upsert_incident(
  347. incident_key=incident_key,
  348. severity=effective_severity,
  349. status="open",
  350. iris_case_id=iris_case_id,
  351. )
  352. decision_trace = {
  353. "incident_key": incident_key,
  354. "risk_score": risk_score,
  355. "risk_factors": risk_factors,
  356. "effective_severity": effective_severity,
  357. "action_taken": action_taken,
  358. }
  359. self.repo.add_event(
  360. incident_key=incident_key,
  361. event_id=event.get("event_id"),
  362. source=event.get("source", "unknown"),
  363. event_type=event.get("event_type", "generic"),
  364. raw_payload=event,
  365. decision_trace=decision_trace,
  366. )
  367. escalate_severities = set(policy.get("escalate_severities", ["high", "critical"]))
  368. escalation_stub_sent = False
  369. stub_response: dict[str, Any] | None = None
  370. if effective_severity in escalate_severities:
  371. escalation_payload = {
  372. "incident_key": incident_key,
  373. "title": event.get("title", "SOC Incident"),
  374. "severity": effective_severity,
  375. "source": event.get("source", "soc-integrator"),
  376. "iris_case_id": iris_case_id,
  377. "event_summary": event.get("description", ""),
  378. "timestamp": event.get("timestamp"),
  379. }
  380. try:
  381. pd_result = await self.pagerduty_adapter.create_incident(escalation_payload)
  382. escalation_stub_sent = True
  383. stub_response = {"ok": True, "data": pd_result}
  384. self.repo.add_escalation_audit(
  385. incident_key=incident_key,
  386. status_code=200,
  387. success=True,
  388. response_excerpt=self._safe_excerpt(pd_result),
  389. )
  390. except Exception as exc:
  391. stub_response = {"ok": False, "error": str(exc)}
  392. self.repo.add_escalation_audit(
  393. incident_key=incident_key,
  394. status_code=502,
  395. success=False,
  396. response_excerpt=self._safe_excerpt(exc),
  397. )
  398. return {
  399. "incident_key": stored["incident_key"],
  400. "action_taken": action_taken,
  401. "iris_case_id": stored.get("iris_case_id"),
  402. "escalation_stub_sent": escalation_stub_sent,
  403. "stub_response": stub_response,
  404. "decision_trace": decision_trace,
  405. }
  406. async def evaluate_ioc(self, payload: dict[str, Any]) -> dict[str, Any]:
  407. policy = self.repo.get_policy()
  408. workflow_id = str(policy.get("shuffle", {}).get("ioc_workflow_id", "")).strip()
  409. shuffle_result: dict[str, Any] | None = None
  410. if workflow_id:
  411. shuffle_result = await self.shuffle_adapter.trigger_workflow(workflow_id, payload)
  412. verdict = self._extract_shuffle_verdict(shuffle_result)
  413. matched = bool(verdict["matched"])
  414. confidence = self._to_float(verdict["confidence"], 0.0)
  415. logger.info(
  416. "ioc evaluation workflow_id=%s matched=%s confidence=%.2f",
  417. workflow_id or "<none>",
  418. matched,
  419. confidence,
  420. )
  421. if matched:
  422. src_event = payload.get("source_event", {})
  423. event_id = src_event.get("event_id") or f"ioc-{int(time.time())}"
  424. if not isinstance(event_id, str):
  425. event_id = str(event_id)
  426. description = f"IOC evaluation result confidence={confidence:.2f}"
  427. evidence = str(verdict.get("evidence", "")).strip()
  428. if evidence:
  429. description = f"{description} evidence={evidence[:180]}"
  430. event = {
  431. "source": "shuffle",
  432. "event_type": "ioc_dns" if payload.get("ioc_type") == "domain" else "ioc_ips",
  433. "event_id": event_id,
  434. "timestamp": datetime.now(timezone.utc).isoformat(),
  435. "severity": verdict["severity"],
  436. "title": f"IOC match: {payload.get('ioc_value', 'unknown')}",
  437. "description": description,
  438. "asset": src_event.get("asset", {}),
  439. "network": src_event.get("network", {}),
  440. "tags": ["ioc", str(payload.get("ioc_type", "unknown"))],
  441. "risk_context": {},
  442. "raw": {
  443. "payload": payload,
  444. "shuffle": verdict.get("raw"),
  445. },
  446. "payload": {},
  447. }
  448. ingest_result = await self.ingest_incident(event)
  449. else:
  450. ingest_result = {"action_taken": "rejected"}
  451. return {
  452. "matched": matched,
  453. "confidence": confidence,
  454. "severity": verdict["severity"],
  455. "evidence": verdict["evidence"],
  456. "iocs": verdict["iocs"],
  457. "decision_source": verdict["reason"],
  458. "shuffle": shuffle_result,
  459. "result": ingest_result,
  460. }
  461. async def evaluate_vpn(self, payload: dict[str, Any]) -> dict[str, Any]:
  462. if not payload.get("success", False):
  463. return {
  464. "risk_score": 0,
  465. "risk_factors": [],
  466. "exception_applied": False,
  467. "action_taken": "rejected",
  468. }
  469. event = {
  470. "source": "wazuh",
  471. "event_type": "vpn_geo_anomaly",
  472. "event_id": payload.get("event_id") or f"vpn-{int(time.time())}",
  473. "timestamp": payload.get("event_time") or datetime.now(timezone.utc).isoformat(),
  474. "severity": "high",
  475. "title": f"VPN login anomaly: {payload.get('user', 'unknown')}",
  476. "description": f"VPN success from {payload.get('country_code', 'unknown')} for user {payload.get('user', 'unknown')}",
  477. "asset": {"user": payload.get("user")},
  478. "network": {"src_ip": payload.get("src_ip"), "country": payload.get("country_code")},
  479. "tags": ["vpn", "geo-anomaly"],
  480. "risk_context": {
  481. "outside_thailand": payload.get("country_code", "").upper() != "TH",
  482. "admin_account": bool(payload.get("is_admin", False)),
  483. "off_hours": bool(payload.get("off_hours", False)),
  484. "first_seen_country": bool(payload.get("first_seen_country", False)),
  485. },
  486. "raw": payload,
  487. "payload": {},
  488. }
  489. ingest_result = await self.ingest_incident(event)
  490. decision_trace = ingest_result.get("decision_trace", {})
  491. return {
  492. "risk_score": decision_trace.get("risk_score", 0),
  493. "risk_factors": decision_trace.get("risk_factors", []),
  494. "exception_applied": ingest_result.get("action_taken") == "ignored_exception",
  495. "action_taken": ingest_result.get("action_taken"),
  496. "incident_key": ingest_result.get("incident_key"),
  497. "iris_case_id": ingest_result.get("iris_case_id"),
  498. "escalation_stub_sent": ingest_result.get("escalation_stub_sent", False),
  499. }
  500. async def ingest_wazuh_alert_to_iris(self, event: dict[str, Any]) -> dict[str, Any]:
  501. """Create an IRIS Alert from a normalised Wazuh event and record it for dedup."""
  502. event_id = str(event.get("event_id", "")).strip()
  503. severity_str = (event.get("severity") or "medium").lower()
  504. severity_id = _IRIS_SEVERITY_ID.get(severity_str, 3)
  505. payload: dict[str, Any] = {
  506. "alert_title": event.get("title") or f"Wazuh alert {event_id}",
  507. "alert_description": event.get("description") or "",
  508. "alert_severity_id": severity_id,
  509. "alert_status_id": 1, # Unassigned
  510. "alert_source": "wazuh",
  511. "alert_source_ref": event_id,
  512. "alert_source_event_time": event.get("timestamp") or datetime.now(timezone.utc).isoformat(),
  513. "alert_customer_id": settings.iris_default_customer_id or 1,
  514. "alert_note": json.dumps({
  515. "asset": event.get("asset", {}),
  516. "network": event.get("network", {}),
  517. "tags": event.get("tags", []),
  518. }),
  519. }
  520. result = await self.iris_adapter.create_alert(payload)
  521. iris_alert_id = (result.get("data") or {}).get("alert_id")
  522. if event_id:
  523. synthetic_key = f"wazuh_alert_{event_id}"
  524. self.repo.upsert_incident(
  525. incident_key=synthetic_key,
  526. severity=event.get("severity") or "medium",
  527. status="open",
  528. iris_case_id=str(iris_alert_id) if iris_alert_id else None,
  529. )
  530. self.repo.add_event(
  531. incident_key=synthetic_key,
  532. event_id=event_id,
  533. source="wazuh",
  534. event_type=event.get("event_type") or "wazuh",
  535. raw_payload=event,
  536. decision_trace={"iris_alert_id": iris_alert_id, "action": "created_iris_alert"},
  537. )
  538. return {"iris_alert_id": iris_alert_id, "event_id": event_id}
  539. async def sync_wazuh_alerts(
  540. self,
  541. query: str = "soc_mvp_test=true OR event_type:*",
  542. limit: int = 50,
  543. minutes: int = 120,
  544. min_severity: str | None = None,
  545. ) -> dict[str, Any]:
  546. raw = await self.wazuh_adapter.search_alerts(query=query, limit=limit, minutes=minutes)
  547. hits = (raw.get("hits", {}) or {}).get("hits", []) if isinstance(raw, dict) else []
  548. # Resolve minimum severity: param > policy > default "medium"
  549. policy = self.repo.get_policy()
  550. effective_min = (min_severity or policy.get("sync", {}).get("min_severity", "medium")).lower()
  551. min_order = _SEVERITY_ORDER.get(effective_min, 2)
  552. processed = 0
  553. ingested = 0
  554. skipped_existing = 0
  555. skipped_filtered = 0
  556. failed = 0
  557. errors: list[str] = []
  558. created_incidents: list[str] = []
  559. ioc_evaluated = 0
  560. ioc_matched = 0
  561. ioc_rejected = 0
  562. for hit in hits:
  563. processed += 1
  564. event = self._normalize_wazuh_hit(hit)
  565. event_id = str(event.get("event_id", "")).strip()
  566. if event_id and self.repo.has_event("wazuh", event_id):
  567. skipped_existing += 1
  568. continue
  569. # Severity filter — skip alerts below minimum threshold
  570. event_order = _SEVERITY_ORDER.get((event.get("severity") or "low").lower(), 1)
  571. if event_order < min_order:
  572. skipped_filtered += 1
  573. continue
  574. try:
  575. if event.get("event_type") in {"ioc_dns", "ioc_ips"}:
  576. ioc_evaluated += 1
  577. payload = {
  578. "ioc_type": "domain" if event.get("event_type") == "ioc_dns" else "ip",
  579. "ioc_value": (event.get("network", {}) or {}).get("domain")
  580. or (event.get("network", {}) or {}).get("dst_ip")
  581. or (event.get("network", {}) or {}).get("src_ip")
  582. or "unknown",
  583. "source_event": {
  584. "event_id": event.get("event_id"),
  585. "asset": event.get("asset", {}),
  586. "network": event.get("network", {}),
  587. "raw": event.get("raw", {}),
  588. },
  589. }
  590. ioc_result = await self.evaluate_ioc(payload)
  591. if ioc_result.get("matched"):
  592. ioc_matched += 1
  593. ingested += 1
  594. incident_key = str((ioc_result.get("result", {}) or {}).get("incident_key", ""))
  595. if incident_key:
  596. created_incidents.append(incident_key)
  597. else:
  598. ioc_rejected += 1
  599. else:
  600. result = await self.ingest_wazuh_alert_to_iris(event)
  601. ingested += 1
  602. iris_alert_id = result.get("iris_alert_id")
  603. if iris_alert_id:
  604. created_incidents.append(str(iris_alert_id))
  605. except Exception as exc:
  606. failed += 1
  607. errors.append(f"{event_id or 'unknown_event'}: {exc}")
  608. return {
  609. "query": query,
  610. "window_minutes": minutes,
  611. "limit": limit,
  612. "min_severity_applied": effective_min,
  613. "processed": processed,
  614. "ingested": ingested,
  615. "skipped_existing": skipped_existing,
  616. "skipped_filtered": skipped_filtered,
  617. "failed": failed,
  618. "ioc_evaluated": ioc_evaluated,
  619. "ioc_matched": ioc_matched,
  620. "ioc_rejected": ioc_rejected,
  621. "iris_alert_ids": created_incidents,
  622. "errors": errors[:10],
  623. "total_hits": (raw.get("hits", {}).get("total", {}) if isinstance(raw, dict) else {}),
  624. }
  625. async def dependency_health(self) -> dict[str, Any]:
  626. out: dict[str, Any] = {}
  627. async def timed(name: str, fn):
  628. start = time.time()
  629. try:
  630. result = await fn()
  631. out[name] = {
  632. "status": "up",
  633. "latency_ms": round((time.time() - start) * 1000, 2),
  634. "details": result,
  635. }
  636. except Exception as exc:
  637. out[name] = {
  638. "status": "down",
  639. "latency_ms": round((time.time() - start) * 1000, 2),
  640. "error": str(exc),
  641. }
  642. await timed("wazuh", self.wazuh_adapter.auth_test)
  643. await timed("shuffle", self.shuffle_adapter.health)
  644. await timed("iris", self.iris_adapter.whoami)
  645. async def pagerduty_stub_health():
  646. async with httpx.AsyncClient(timeout=10.0) as client:
  647. r = await client.get(settings.pagerduty_base_url)
  648. r.raise_for_status()
  649. return {"status_code": r.status_code}
  650. await timed("pagerduty_stub", pagerduty_stub_health)
  651. return out