No description

mvp_service.py (34 KB)
  1. from __future__ import annotations
  2. import hashlib
  3. import json
  4. import logging
  5. import re
  6. import time
  7. from datetime import datetime, timezone
  8. from typing import Any
  9. import httpx
  10. from app.adapters.abuseipdb import AbuseIpdbAdapter
  11. from app.adapters.iris import IrisAdapter
  12. from app.adapters.pagerduty import PagerDutyAdapter
  13. from app.adapters.shuffle import ShuffleAdapter
  14. from app.adapters.virustotal import VirusTotalAdapter
  15. from app.adapters.wazuh import WazuhAdapter
  16. from app.config import settings
  17. from app.repositories.mvp_repo import MvpRepository
  18. logger = logging.getLogger(__name__)
# Maps normalised severity labels to DFIR-IRIS numeric severity IDs
# (5 = Critical ... 1 = Informational).
_IRIS_SEVERITY_ID: dict[str, int] = {
    "critical": 5,
    "high": 4,
    "medium": 3,
    "low": 2,
    "informational": 1,
}
# Total order over severity labels, used for comparisons and filtering
# (higher number = more severe).
_SEVERITY_ORDER: dict[str, int] = {
    "informational": 0,
    "low": 1,
    "medium": 2,
    "high": 3,
    "critical": 4,
}
  33. def _build_vt_ioc_result(
  34. vt: dict[str, object],
  35. ioc_type: str,
  36. ioc_value: str,
  37. malicious_threshold: int,
  38. suspicious_threshold: int,
  39. ) -> tuple[dict[str, object], bool, str, float]:
  40. stats = (
  41. (((vt.get("data") or {}).get("attributes") or {}).get("last_analysis_stats"))
  42. if isinstance(vt, dict)
  43. else None
  44. ) or {}
  45. malicious = int(stats.get("malicious", 0) or 0)
  46. suspicious = int(stats.get("suspicious", 0) or 0)
  47. harmless = int(stats.get("harmless", 0) or 0)
  48. undetected = int(stats.get("undetected", 0) or 0)
  49. total = malicious + suspicious + harmless + undetected
  50. confidence = 0.0 if total == 0 else round(((malicious + (0.5 * suspicious)) / total), 4)
  51. matched = (malicious >= malicious_threshold) or (suspicious >= suspicious_threshold)
  52. severity = "low"
  53. if malicious >= 5 or suspicious >= 10:
  54. severity = "critical"
  55. elif malicious >= 2 or suspicious >= 5:
  56. severity = "high"
  57. elif malicious >= 1 or suspicious >= 1:
  58. severity = "medium"
  59. reason = (
  60. f"virustotal_stats malicious={malicious} suspicious={suspicious} "
  61. f"thresholds(malicious>={malicious_threshold}, suspicious>={suspicious_threshold})"
  62. )
  63. result: dict[str, object] = {
  64. "ioc_type": ioc_type,
  65. "ioc_value": ioc_value,
  66. "matched": matched,
  67. "severity": severity,
  68. "confidence": confidence,
  69. "reason": reason,
  70. "providers": {"virustotal": {"stats": stats}},
  71. "raw": {"virustotal": vt},
  72. }
  73. return result, matched, severity, confidence
  74. def _build_abuseipdb_ioc_result(
  75. abuse: dict[str, object],
  76. ioc_value: str,
  77. confidence_threshold: int = 50,
  78. ) -> tuple[dict[str, object], bool, str, float]:
  79. data = ((abuse.get("data") if isinstance(abuse, dict) else None) or {}) if isinstance(abuse, dict) else {}
  80. score = int(data.get("abuseConfidenceScore", 0) or 0)
  81. total_reports = int(data.get("totalReports", 0) or 0)
  82. matched = score >= confidence_threshold
  83. severity = "low"
  84. if score >= 90:
  85. severity = "critical"
  86. elif score >= 70:
  87. severity = "high"
  88. elif score >= 30:
  89. severity = "medium"
  90. confidence = round(score / 100.0, 4)
  91. reason = f"abuseipdb score={score} totalReports={total_reports} threshold>={confidence_threshold}"
  92. result: dict[str, object] = {
  93. "ioc_type": "ip",
  94. "ioc_value": ioc_value,
  95. "matched": matched,
  96. "severity": severity,
  97. "confidence": confidence,
  98. "reason": reason,
  99. "providers": {"abuseipdb": {"score": score, "totalReports": total_reports, "raw": abuse}},
  100. }
  101. return result, matched, severity, confidence
class MvpService:
    """Core orchestration service for the SOC-integrator MVP.

    Glues together the Wazuh (SIEM), Shuffle (SOAR), IRIS (case management)
    and PagerDuty (escalation) adapters plus optional threat-intel providers
    (VirusTotal, AbuseIPDB), persisting state through ``MvpRepository``.
    """

    def __init__(
        self,
        repo: MvpRepository,
        wazuh_adapter: WazuhAdapter,
        shuffle_adapter: ShuffleAdapter,
        iris_adapter: IrisAdapter,
        pagerduty_adapter: PagerDutyAdapter,
        virustotal_adapter: VirusTotalAdapter | None = None,
        abuseipdb_adapter: AbuseIpdbAdapter | None = None,
    ) -> None:
        """Store adapter/repository references; no I/O happens here."""
        self.repo = repo
        self.wazuh_adapter = wazuh_adapter
        self.shuffle_adapter = shuffle_adapter
        self.iris_adapter = iris_adapter
        self.pagerduty_adapter = pagerduty_adapter
        # Threat-intel adapters are optional; IOC evaluation degrades gracefully.
        self.virustotal_adapter = virustotal_adapter
        self.abuseipdb_adapter = abuseipdb_adapter
  120. def _is_off_hours(self, ts: datetime) -> bool:
  121. hour = ts.astimezone(timezone.utc).hour
  122. return hour < 6 or hour >= 20
  123. def _safe_excerpt(self, payload: Any) -> str:
  124. text = str(payload)
  125. return text[:300]
  126. def _primary_subject(self, event: dict[str, Any]) -> str:
  127. asset = event.get("asset", {})
  128. return str(asset.get("user") or asset.get("hostname") or "unknown")
  129. def _primary_observable(self, event: dict[str, Any]) -> str:
  130. network = event.get("network", {})
  131. return str(network.get("domain") or network.get("src_ip") or network.get("dst_ip") or "unknown")
  132. def _incident_key(self, event: dict[str, Any]) -> str:
  133. ts = datetime.fromisoformat(event["timestamp"].replace("Z", "+00:00")).astimezone(timezone.utc)
  134. day_bucket = ts.strftime("%Y-%m-%d")
  135. raw = "|".join(
  136. [
  137. str(event.get("event_type", "generic")),
  138. self._primary_subject(event),
  139. self._primary_observable(event),
  140. day_bucket,
  141. ]
  142. )
  143. return hashlib.sha256(raw.encode("utf-8")).hexdigest()
  144. def _effective_severity(self, event: dict[str, Any], policy: dict[str, Any]) -> tuple[str, int, list[str]]:
  145. severity = str(event.get("severity", "medium")).lower()
  146. risk_context = event.get("risk_context", {})
  147. network = event.get("network", {})
  148. weights = policy.get("risk", {}).get("weights", {})
  149. score = 0
  150. factors: list[str] = []
  151. allowed_country = policy.get("vpn", {}).get("allowed_country", "TH")
  152. country = str(network.get("country", "")).upper()
  153. if country and country != allowed_country:
  154. score += int(weights.get("outside_thailand", 50))
  155. factors.append("outside_country")
  156. if risk_context.get("admin_account"):
  157. score += int(weights.get("admin", 20))
  158. factors.append("admin_account")
  159. ts = datetime.fromisoformat(event["timestamp"].replace("Z", "+00:00"))
  160. if risk_context.get("off_hours") or self._is_off_hours(ts):
  161. score += int(weights.get("off_hours", 15))
  162. factors.append("off_hours")
  163. if risk_context.get("first_seen_country"):
  164. score += int(weights.get("first_seen_country", 15))
  165. factors.append("first_seen_country")
  166. thresholds = policy.get("risk", {}).get("thresholds", {})
  167. if score >= int(thresholds.get("high", 70)):
  168. severity = "high" if severity in {"low", "medium"} else severity
  169. elif score >= int(thresholds.get("medium", 40)) and severity == "low":
  170. severity = "medium"
  171. return severity, score, factors
  172. def _is_exception(self, event: dict[str, Any], policy: dict[str, Any]) -> bool:
  173. if event.get("event_type") != "vpn_geo_anomaly":
  174. return False
  175. asset = event.get("asset", {})
  176. user = str(asset.get("user", ""))
  177. allowed_users = set(policy.get("vpn", {}).get("exception_users", []))
  178. return user in allowed_users
  179. def _extract_iris_case_id(self, iris_response: dict[str, Any]) -> str | None:
  180. if "case_id" in iris_response:
  181. return str(iris_response.get("case_id"))
  182. data = iris_response.get("data")
  183. if isinstance(data, dict) and "case_id" in data:
  184. return str(data.get("case_id"))
  185. return None
  186. def _parse_kv_pairs(self, text: str) -> dict[str, str]:
  187. pattern = r"([A-Za-z0-9_]+)=('(?:[^']*)'|\"(?:[^\"]*)\"|[^\s]+)"
  188. out: dict[str, str] = {}
  189. for key, raw in re.findall(pattern, text):
  190. value = raw.strip().strip("'").strip('"')
  191. out[key] = value
  192. return out
  193. def _severity_from_rule_level(self, rule_level: Any) -> str:
  194. try:
  195. level = int(rule_level)
  196. except (TypeError, ValueError):
  197. return "medium"
  198. if level >= 12:
  199. return "critical"
  200. if level >= 8:
  201. return "high"
  202. if level >= 4:
  203. return "medium"
  204. return "low"
    def _event_type_from_text(self, text: str, parsed: dict[str, str]) -> str:
        """Classify a raw log line into an internal event type.

        Checks are strictly ordered: production C1 (impossible-travel)
        detection first, then legacy simulator markers, then any explicit
        event_type field, and finally keyword heuristics over the text.
        """
        explicit = str(parsed.get("event_type") or "").strip().lower()
        usecase_id = str(parsed.get("usecase_id") or "").strip().upper()
        section = str(parsed.get("section") or "").strip().upper()
        source = str(parsed.get("source") or "").strip().lower()
        success = str(parsed.get("success") or "").strip().lower()
        has_geo = bool(parsed.get("country") or parsed.get("src_lat") or parsed.get("src_lon"))
        has_user = bool(parsed.get("user"))
        has_src_ip = bool(parsed.get("src_ip") or parsed.get("srcip"))
        explicit_success_login = explicit in {
            "vpn_login_success",
            "windows_auth_success",
            "auth_success",
        }
        # Production-first C1 detection:
        # successful auth/login + geo context on vpn/windows identity streams.
        if (
            (source in {"vpn", "fortigate", "windows", "identity"} or "vpn" in source)
            and has_geo
            and has_user
            and has_src_ip
            and (success == "true" or explicit_success_login)
        ):
            return "c1_impossible_travel"
        # Legacy simulator markers remain supported as fallback.
        if usecase_id.startswith("C1") or section == "C1":
            return "c1_impossible_travel"
        if explicit in {"c1_impossible_travel", "impossible_travel"}:
            return "c1_impossible_travel"
        if explicit == "vpn_geo_anomaly":
            return "vpn_geo_anomaly"
        if explicit:
            # Any other explicit event_type wins over the keyword heuristics below.
            return explicit
        lowered = text.lower()
        if "impossible travel" in lowered:
            return "c1_impossible_travel"
        if "vpn" in lowered and ("geo" in lowered or "country" in lowered):
            return "vpn_geo_anomaly"
        if "domain" in lowered or "dns" in lowered:
            return "ioc_dns"
        # NOTE: "ip " (trailing space) avoids matching words like "ipsum"; the
        # domain/dns check above runs first, so DNS logs never land here.
        if "c2" in lowered or "ips" in lowered or "ip " in lowered:
            return "ioc_ips"
        if "auth" in lowered and "fail" in lowered:
            return "auth_anomaly"
        return "generic"
    def _normalize_wazuh_hit(self, hit: dict[str, Any]) -> dict[str, Any]:
        """Convert a raw Wazuh/OpenSearch hit into the internal event schema.

        Pulls key=value pairs out of ``full_log``, derives event type and
        severity, and assembles the asset/network/risk sections used by the
        rest of the pipeline. Unparseable numeric fields become ``None``.
        """
        src = hit.get("_source", {})
        full_log = str(src.get("full_log", ""))
        parsed = self._parse_kv_pairs(full_log)
        # Prefer an explicit event_id from the log, then document ids, then a synthetic one.
        event_id = str(parsed.get("event_id") or src.get("id") or hit.get("_id") or f"wazuh-{int(time.time())}")
        timestamp = (
            src.get("@timestamp")
            or src.get("timestamp")
            or datetime.now(timezone.utc).isoformat()
        )
        rule = src.get("rule", {}) if isinstance(src.get("rule"), dict) else {}
        rule_desc = str(rule.get("description") or "")
        event_type = self._event_type_from_text(full_log, parsed)
        # Explicit severity in the log wins; otherwise derive from the rule level.
        severity = str(parsed.get("severity", "")).lower() or self._severity_from_rule_level(rule.get("level"))
        # Both src_ip/srcip and dst_ip/dstip spellings appear in the wild.
        src_ip = parsed.get("src_ip")
        if not src_ip:
            src_ip = parsed.get("srcip")
        dst_ip = parsed.get("dst_ip")
        if not dst_ip:
            dst_ip = parsed.get("dstip")
        domain = parsed.get("query") or parsed.get("domain")
        country = parsed.get("country")
        user = parsed.get("user") or (src.get("agent", {}) or {}).get("name")
        dst_port = parsed.get("dst_port") or parsed.get("dstport")
        event_action = parsed.get("event_action") or parsed.get("action")
        title = rule_desc or f"Wazuh alert {rule.get('id', '')}".strip()
        description = full_log or rule_desc or "Wazuh alert"
        src_lat_raw = parsed.get("src_lat")
        src_lon_raw = parsed.get("src_lon")
        # Coordinates are optional; bad values degrade to None rather than raising.
        try:
            src_lat = float(src_lat_raw) if src_lat_raw not in {None, ""} else None
        except (TypeError, ValueError):
            src_lat = None
        try:
            src_lon = float(src_lon_raw) if src_lon_raw not in {None, ""} else None
        except (TypeError, ValueError):
            src_lon = None
        return {
            "source": "wazuh",
            "event_type": event_type,
            "event_id": event_id,
            "timestamp": timestamp,
            # Clamp unexpected labels to "medium" so downstream ordering works.
            "severity": severity if severity in {"low", "medium", "high", "critical"} else "medium",
            "title": title,
            "description": description,
            "asset": {
                "user": user,
                "hostname": (src.get("agent", {}) or {}).get("name"),
                "agent_id": (src.get("agent", {}) or {}).get("id"),
            },
            "network": {
                "src_ip": src_ip,
                "dst_ip": dst_ip,
                "dst_host": parsed.get("dst_host") or parsed.get("host"),
                "dst_port": int(dst_port) if str(dst_port or "").isdigit() else None,
                "domain": domain,
                "country": country,
                "src_lat": src_lat,
                "src_lon": src_lon,
            },
            "tags": ["wazuh", event_type, f"rule_{rule.get('id', 'unknown')}"],
            "risk_context": {
                "outside_thailand": bool(country and str(country).upper() != "TH"),
            },
            "raw": src,
            # Flattened key=value payload plus a few fields surfaced explicitly.
            "payload": {
                **parsed,
                "event_action": event_action,
                "event_id": parsed.get("event_id"),
                "event_type": event_type,
                "success": parsed.get("success"),
                "logon_type": parsed.get("logon_type"),
                "account_type": parsed.get("account_type"),
                "is_admin": parsed.get("is_admin"),
                "is_service": parsed.get("is_service"),
            },
        }
    def normalize_wazuh_hit(self, hit: dict[str, Any]) -> dict[str, Any]:
        """Public alias for :meth:`_normalize_wazuh_hit` (identical contract)."""
        return self._normalize_wazuh_hit(hit)
  329. def _to_float(self, value: Any, default: float = 0.0) -> float:
  330. try:
  331. return float(value)
  332. except (TypeError, ValueError):
  333. return default
  334. def _severity_from_confidence(self, confidence: float) -> str:
  335. if confidence >= 0.9:
  336. return "high"
  337. if confidence >= 0.7:
  338. return "medium"
  339. return "low"
    async def ingest_incident(self, event: dict[str, Any]) -> dict[str, Any]:
        """Run the full incident pipeline for a normalised event.

        Applies the VPN exception policy, computes effective severity,
        creates or updates the IRIS case, records the event and its decision
        trace, and escalates to PagerDuty when the policy requires it.
        """
        policy = self.repo.get_policy()
        incident_key = self._incident_key(event)
        if self._is_exception(event, policy):
            # Whitelisted VPN users are recorded for audit but never escalated.
            decision_trace = {
                "incident_key": incident_key,
                "policy_exception": True,
                "reason": "vpn_exception_user",
            }
            self.repo.upsert_incident(incident_key, severity="low", status="ignored_exception", iris_case_id=None)
            self.repo.add_event(
                incident_key=incident_key,
                event_id=event.get("event_id"),
                source=event.get("source", "unknown"),
                event_type=event.get("event_type", "generic"),
                raw_payload=event,
                decision_trace=decision_trace,
            )
            return {
                "incident_key": incident_key,
                "action_taken": "ignored_exception",
                "escalation_stub_sent": False,
                "decision_trace": decision_trace,
            }
        effective_severity, risk_score, risk_factors = self._effective_severity(event, policy)
        current = self.repo.get_incident(incident_key)
        action_taken = "updated_case" if current else "created_case"
        iris_case_id = current.get("iris_case_id") if current else None
        if not iris_case_id:
            # First sighting (or no case yet): open a new IRIS case.
            case_payload = {
                "case_name": event.get("title", "SOC Incident"),
                "case_description": event.get("description", "Generated by soc-integrator MVP"),
                "case_customer": event.get("payload", {}).get("case_customer", settings.iris_default_customer_id),
                "case_soc_id": event.get("payload", {}).get("case_soc_id", settings.iris_default_soc_id),
            }
            iris_result = await self.iris_adapter.create_case(case_payload)
            iris_case_id = self._extract_iris_case_id(iris_result)
        else:
            # Existing case: append the latest event to its description.
            update_payload = {
                "case_description": f"{event.get('description', 'Updated by soc-integrator MVP')} [event_id={event.get('event_id', '')}]"
            }
            try:
                await self.iris_adapter.update_case(iris_case_id, update_payload)
            except Exception:
                # Keep pipeline progressing for MVP even if update path is unsupported.
                pass
        stored = self.repo.upsert_incident(
            incident_key=incident_key,
            severity=effective_severity,
            status="open",
            iris_case_id=iris_case_id,
        )
        decision_trace = {
            "incident_key": incident_key,
            "risk_score": risk_score,
            "risk_factors": risk_factors,
            "effective_severity": effective_severity,
            "action_taken": action_taken,
        }
        self.repo.add_event(
            incident_key=incident_key,
            event_id=event.get("event_id"),
            source=event.get("source", "unknown"),
            event_type=event.get("event_type", "generic"),
            raw_payload=event,
            decision_trace=decision_trace,
        )
        escalate_severities = set(policy.get("escalate_severities", ["high", "critical"]))
        escalation_stub_sent = False
        stub_response: dict[str, Any] | None = None
        if effective_severity in escalate_severities:
            escalation_payload = {
                "incident_key": incident_key,
                "title": event.get("title", "SOC Incident"),
                "severity": effective_severity,
                "source": event.get("source", "soc-integrator"),
                "iris_case_id": iris_case_id,
                "event_summary": event.get("description", ""),
                "timestamp": event.get("timestamp"),
            }
            try:
                pd_result = await self.pagerduty_adapter.create_incident(escalation_payload)
                escalation_stub_sent = True
                stub_response = {"ok": True, "data": pd_result}
                self.repo.add_escalation_audit(
                    incident_key=incident_key,
                    status_code=200,
                    success=True,
                    response_excerpt=self._safe_excerpt(pd_result),
                )
            except Exception as exc:
                # Escalation failure must not fail ingestion; keep an audit trail.
                stub_response = {"ok": False, "error": str(exc)}
                self.repo.add_escalation_audit(
                    incident_key=incident_key,
                    status_code=502,
                    success=False,
                    response_excerpt=self._safe_excerpt(exc),
                )
        return {
            "incident_key": stored["incident_key"],
            "action_taken": action_taken,
            "iris_case_id": stored.get("iris_case_id"),
            "escalation_stub_sent": escalation_stub_sent,
            "stub_response": stub_response,
            "decision_trace": decision_trace,
        }
    async def evaluate_ioc(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Evaluate an IOC against VirusTotal and (for IPs) AbuseIPDB.

        Aggregates per-provider verdicts, records an audit trace, and when
        any provider matches, forwards a synthetic event to IRIS alerting.
        """
        ioc_type = str(payload.get("ioc_type") or "ip").strip()
        ioc_value = str(payload.get("ioc_value") or "").strip()
        if not ioc_value or ioc_value == "unknown":
            # Nothing usable to evaluate; short-circuit with a rejection.
            return {
                "matched": False,
                "confidence": 0.0,
                "severity": "low",
                "evidence": "no_ioc_value",
                "iocs": [],
                "decision_source": "skipped",
                "result": {"action_taken": "rejected"},
            }
        verdicts: list[dict[str, Any]] = []
        if self.virustotal_adapter is not None:
            try:
                vt_raw = await self.virustotal_adapter.enrich_ioc(ioc_type, ioc_value)
                vt_result, vt_matched, vt_severity, vt_conf = _build_vt_ioc_result(
                    vt_raw, ioc_type, ioc_value,
                    malicious_threshold=1, suspicious_threshold=3,
                )
                verdicts.append({
                    "matched": vt_matched, "severity": vt_severity,
                    "confidence": vt_conf, "result": vt_result, "provider": "virustotal",
                })
            except Exception as exc:
                # A single provider outage must not abort the evaluation.
                logger.warning("VT IOC eval failed for %s: %s", ioc_value, exc)
        if ioc_type == "ip" and self.abuseipdb_adapter is not None:
            try:
                abuse_raw = await self.abuseipdb_adapter.check_ip(ioc_value)
                abuse_result, abuse_matched, abuse_severity, abuse_conf = _build_abuseipdb_ioc_result(
                    abuse_raw, ioc_value, confidence_threshold=50,
                )
                verdicts.append({
                    "matched": abuse_matched, "severity": abuse_severity,
                    "confidence": abuse_conf, "result": abuse_result, "provider": "abuseipdb",
                })
            except Exception as exc:
                logger.warning("AbuseIPDB IOC eval failed for %s: %s", ioc_value, exc)
        matched = any(v["matched"] for v in verdicts)
        confidence = max((v["confidence"] for v in verdicts), default=0.0)
        # Overall severity = worst severity among the matching providers.
        severity = (
            max(
                (v["severity"] for v in verdicts if v["matched"]),
                key=lambda s: _SEVERITY_ORDER.get(s, 1),
                default="low",
            )
            if matched
            else "low"
        )
        evidence = "; ".join(v["result"]["reason"] for v in verdicts if v["matched"])
        providers_used = [v["provider"] for v in verdicts]
        logger.info(
            "ioc evaluation ioc=%s type=%s matched=%s confidence=%.2f providers=%s",
            ioc_value, ioc_type, matched, confidence, providers_used,
        )
        self.repo.add_ioc_trace(
            action="evaluate",
            ioc_type=ioc_type,
            ioc_value=ioc_value,
            providers=providers_used,
            request_payload=payload,
            response_payload={"verdicts": verdicts, "matched": matched},
            matched=matched,
            severity=severity if matched else None,
            confidence=confidence,
        )
        ingest_result: dict[str, Any] = {"action_taken": "rejected"}
        if matched:
            # Build a synthetic event from the source context and push it to IRIS.
            src_event = payload.get("source_event", {}) or {}
            event_id = str(src_event.get("event_id") or f"ioc-{int(time.time())}")
            event = {
                "source": providers_used[0] if providers_used else "ioc",
                "event_type": "ioc_dns" if ioc_type == "domain" else "ioc_ips",
                "event_id": event_id,
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "severity": severity,
                "title": f"IOC match: {ioc_value}",
                "description": f"IOC evaluation confidence={confidence:.2f} evidence={evidence[:180]}",
                "asset": src_event.get("asset", {}),
                "network": src_event.get("network", {}),
                "tags": ["ioc", ioc_type],
                "risk_context": {},
                "raw": {"payload": payload, "verdicts": verdicts},
                "payload": {},
            }
            ingest_result = await self.ingest_wazuh_alert_to_iris(event)
        return {
            "matched": matched,
            "confidence": confidence,
            "severity": severity,
            "evidence": evidence,
            "iocs": [v["result"] for v in verdicts if v["matched"]],
            "decision_source": "direct_api",
            "result": ingest_result,
        }
  542. async def evaluate_vpn(self, payload: dict[str, Any]) -> dict[str, Any]:
  543. if not payload.get("success", False):
  544. return {
  545. "risk_score": 0,
  546. "risk_factors": [],
  547. "exception_applied": False,
  548. "action_taken": "rejected",
  549. }
  550. event = {
  551. "source": "wazuh",
  552. "event_type": "vpn_geo_anomaly",
  553. "event_id": payload.get("event_id") or f"vpn-{int(time.time())}",
  554. "timestamp": payload.get("event_time") or datetime.now(timezone.utc).isoformat(),
  555. "severity": "high",
  556. "title": f"VPN login anomaly: {payload.get('user', 'unknown')}",
  557. "description": f"VPN success from {payload.get('country_code', 'unknown')} for user {payload.get('user', 'unknown')}",
  558. "asset": {"user": payload.get("user")},
  559. "network": {"src_ip": payload.get("src_ip"), "country": payload.get("country_code")},
  560. "tags": ["vpn", "geo-anomaly"],
  561. "risk_context": {
  562. "outside_thailand": payload.get("country_code", "").upper() != "TH",
  563. "admin_account": bool(payload.get("is_admin", False)),
  564. "off_hours": bool(payload.get("off_hours", False)),
  565. "first_seen_country": bool(payload.get("first_seen_country", False)),
  566. },
  567. "raw": payload,
  568. "payload": {},
  569. }
  570. ingest_result = await self.ingest_incident(event)
  571. decision_trace = ingest_result.get("decision_trace", {})
  572. return {
  573. "risk_score": decision_trace.get("risk_score", 0),
  574. "risk_factors": decision_trace.get("risk_factors", []),
  575. "exception_applied": ingest_result.get("action_taken") == "ignored_exception",
  576. "action_taken": ingest_result.get("action_taken"),
  577. "incident_key": ingest_result.get("incident_key"),
  578. "iris_case_id": ingest_result.get("iris_case_id"),
  579. "escalation_stub_sent": ingest_result.get("escalation_stub_sent", False),
  580. }
    async def ingest_wazuh_alert_to_iris(self, event: dict[str, Any]) -> dict[str, Any]:
        """Create an IRIS Alert from a normalised Wazuh event and record it for dedup."""
        event_id = str(event.get("event_id", "")).strip()
        severity_str = (event.get("severity") or "medium").lower()
        severity_id = _IRIS_SEVERITY_ID.get(severity_str, 3)  # fall back to Medium
        def _strip_nulls(obj: Any) -> Any:
            """Recursively remove None values to keep the JSON compact."""
            if isinstance(obj, dict):
                return {k: _strip_nulls(v) for k, v in obj.items() if v is not None}
            if isinstance(obj, list):
                return [_strip_nulls(i) for i in obj]
            return obj
        raw = event.get("raw") or {}
        note_data: dict[str, Any] = {
            "asset": event.get("asset") or {},
            "network": event.get("network") or {},
            "tags": event.get("tags") or [],
        }
        if raw.get("verdicts"):
            # IOC-driven events carry provider verdicts; surface them in the note.
            ioc_payload = raw.get("payload") or {}
            note_data["ioc"] = {
                "type": ioc_payload.get("ioc_type"),
                "value": ioc_payload.get("ioc_value"),
                "verdicts": [
                    {
                        "provider": v.get("provider"),
                        "matched": v.get("matched"),
                        "confidence": round(float(v.get("confidence") or 0), 4),
                        "severity": v.get("severity"),
                        "reason": (v.get("result") or {}).get("reason"),
                        "vt_stats": ((v.get("result") or {}).get("providers") or {}).get("virustotal", {}).get("stats"),
                        "abuseipdb_score": ((v.get("result") or {}).get("providers") or {}).get("abuseipdb", {}).get("score"),
                        "abuseipdb_reports": ((v.get("result") or {}).get("providers") or {}).get("abuseipdb", {}).get("totalReports"),
                    }
                    for v in raw["verdicts"]
                ],
            }
        alert_note = json.dumps(_strip_nulls(note_data), indent=2, ensure_ascii=False)
        payload: dict[str, Any] = {
            "alert_title": event.get("title") or f"Wazuh alert {event_id}",
            "alert_description": event.get("description") or "",
            "alert_severity_id": severity_id,
            "alert_status_id": 1,  # Unassigned
            "alert_source": "wazuh",
            "alert_source_ref": event_id,
            "alert_source_event_time": event.get("timestamp") or datetime.now(timezone.utc).isoformat(),
            "alert_customer_id": settings.iris_default_customer_id or 1,
            "alert_note": alert_note,
        }
        result = await self.iris_adapter.create_alert(payload)
        iris_alert_id = (result.get("data") or {}).get("alert_id")
        if event_id:
            # Record a synthetic incident so sync can dedup on the source event_id.
            synthetic_key = f"wazuh_alert_{event_id}"
            self.repo.upsert_incident(
                incident_key=synthetic_key,
                severity=event.get("severity") or "medium",
                status="open",
                iris_case_id=str(iris_alert_id) if iris_alert_id else None,
            )
            self.repo.add_event(
                incident_key=synthetic_key,
                event_id=event_id,
                source="wazuh",
                event_type=event.get("event_type") or "wazuh",
                raw_payload=event,
                decision_trace={"iris_alert_id": iris_alert_id, "action": "created_iris_alert"},
            )
        return {"iris_alert_id": iris_alert_id, "event_id": event_id}
    async def sync_wazuh_alerts(
        self,
        query: str = "soc_mvp_test=true OR event_type:*",
        limit: int = 50,
        minutes: int = 120,
        min_severity: str | None = None,
    ) -> dict[str, Any]:
        """Pull recent Wazuh alerts, dedup, filter by severity, and ingest.

        IOC-type events are routed through :meth:`evaluate_ioc`; everything
        else goes straight to IRIS alerting. Returns a counter summary.
        """
        raw = await self.wazuh_adapter.search_alerts(query=query, limit=limit, minutes=minutes)
        hits = (raw.get("hits", {}) or {}).get("hits", []) if isinstance(raw, dict) else []
        # Resolve minimum severity: param > policy > default "medium"
        policy = self.repo.get_policy()
        effective_min = (min_severity or policy.get("sync", {}).get("min_severity", "medium")).lower()
        min_order = _SEVERITY_ORDER.get(effective_min, 2)
        processed = 0
        ingested = 0
        skipped_existing = 0
        skipped_filtered = 0
        failed = 0
        errors: list[str] = []
        created_incidents: list[str] = []
        ioc_evaluated = 0
        ioc_matched = 0
        ioc_rejected = 0
        for hit in hits:
            processed += 1
            event = self._normalize_wazuh_hit(hit)
            event_id = str(event.get("event_id", "")).strip()
            # Dedup: skip alerts whose source event_id was already recorded.
            if event_id and self.repo.has_event("wazuh", event_id):
                skipped_existing += 1
                continue
            # Severity filter — skip alerts below minimum threshold
            event_order = _SEVERITY_ORDER.get((event.get("severity") or "low").lower(), 1)
            if event_order < min_order:
                skipped_filtered += 1
                continue
            try:
                if event.get("event_type") in {"ioc_dns", "ioc_ips"}:
                    ioc_evaluated += 1
                    payload = {
                        "ioc_type": "domain" if event.get("event_type") == "ioc_dns" else "ip",
                        # Best available observable: domain > dst_ip > src_ip.
                        "ioc_value": (event.get("network", {}) or {}).get("domain")
                        or (event.get("network", {}) or {}).get("dst_ip")
                        or (event.get("network", {}) or {}).get("src_ip")
                        or "unknown",
                        "source_event": {
                            "event_id": event.get("event_id"),
                            "asset": event.get("asset", {}),
                            "network": event.get("network", {}),
                            "raw": event.get("raw", {}),
                        },
                    }
                    ioc_result = await self.evaluate_ioc(payload)
                    if ioc_result.get("matched"):
                        ioc_matched += 1
                        ingested += 1
                        iris_alert_id = (ioc_result.get("result", {}) or {}).get("iris_alert_id")
                        if iris_alert_id:
                            created_incidents.append(str(iris_alert_id))
                    else:
                        ioc_rejected += 1
                else:
                    result = await self.ingest_wazuh_alert_to_iris(event)
                    ingested += 1
                    iris_alert_id = result.get("iris_alert_id")
                    if iris_alert_id:
                        created_incidents.append(str(iris_alert_id))
            except Exception as exc:
                # One bad alert must not abort the whole sync run.
                failed += 1
                errors.append(f"{event_id or 'unknown_event'}: {exc}")
        return {
            "query": query,
            "window_minutes": minutes,
            "limit": limit,
            "min_severity_applied": effective_min,
            "processed": processed,
            "ingested": ingested,
            "skipped_existing": skipped_existing,
            "skipped_filtered": skipped_filtered,
            "failed": failed,
            "ioc_evaluated": ioc_evaluated,
            "ioc_matched": ioc_matched,
            "ioc_rejected": ioc_rejected,
            "iris_alert_ids": created_incidents,
            "errors": errors[:10],  # cap error detail in the response
            "total_hits": (raw.get("hits", {}).get("total", {}) if isinstance(raw, dict) else {}),
        }
    async def dependency_health(self) -> dict[str, Any]:
        """Probe each downstream dependency and report status plus latency."""
        out: dict[str, Any] = {}
        async def timed(name: str, fn):
            # Time one async health call; failures are captured, never raised.
            start = time.time()
            try:
                result = await fn()
                out[name] = {
                    "status": "up",
                    "latency_ms": round((time.time() - start) * 1000, 2),
                    "details": result,
                }
            except Exception as exc:
                out[name] = {
                    "status": "down",
                    "latency_ms": round((time.time() - start) * 1000, 2),
                    "error": str(exc),
                }
        await timed("wazuh", self.wazuh_adapter.auth_test)
        await timed("shuffle", self.shuffle_adapter.health)
        await timed("iris", self.iris_adapter.whoami)
        async def pagerduty_stub_health():
            # The PagerDuty stub has no dedicated adapter; hit its base URL directly.
            async with httpx.AsyncClient(timeout=10.0) as client:
                r = await client.get(settings.pagerduty_base_url)
                r.raise_for_status()
                return {"status_code": r.status_code}
        await timed("pagerduty_stub", pagerduty_stub_health)
        return out