Nenhuma Descrição

mvp_service.py 19KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479
  1. from __future__ import annotations
  2. import hashlib
  3. import re
  4. import time
  5. from datetime import datetime, timezone
  6. from typing import Any
  7. import httpx
  8. from app.adapters.iris import IrisAdapter
  9. from app.adapters.pagerduty import PagerDutyAdapter
  10. from app.adapters.shuffle import ShuffleAdapter
  11. from app.adapters.wazuh import WazuhAdapter
  12. from app.config import settings
  13. from app.repositories.mvp_repo import MvpRepository
class MvpService:
    """Orchestrates the MVP SOC pipeline.

    Normalizes incoming alerts, scores risk against the stored policy,
    opens/updates IRIS cases, and escalates high/critical incidents through
    the PagerDuty stub adapter.
    """

    def __init__(
        self,
        repo: MvpRepository,
        wazuh_adapter: WazuhAdapter,
        shuffle_adapter: ShuffleAdapter,
        iris_adapter: IrisAdapter,
        pagerduty_adapter: PagerDutyAdapter,
    ) -> None:
        """Store the repository and the external-system adapters used below."""
        self.repo = repo
        self.wazuh_adapter = wazuh_adapter
        self.shuffle_adapter = shuffle_adapter
        self.iris_adapter = iris_adapter
        self.pagerduty_adapter = pagerduty_adapter
  28. def _is_off_hours(self, ts: datetime) -> bool:
  29. hour = ts.astimezone(timezone.utc).hour
  30. return hour < 6 or hour >= 20
  31. def _safe_excerpt(self, payload: Any) -> str:
  32. text = str(payload)
  33. return text[:300]
  34. def _primary_subject(self, event: dict[str, Any]) -> str:
  35. asset = event.get("asset", {})
  36. return str(asset.get("user") or asset.get("hostname") or "unknown")
  37. def _primary_observable(self, event: dict[str, Any]) -> str:
  38. network = event.get("network", {})
  39. return str(network.get("domain") or network.get("src_ip") or network.get("dst_ip") or "unknown")
  40. def _incident_key(self, event: dict[str, Any]) -> str:
  41. ts = datetime.fromisoformat(event["timestamp"].replace("Z", "+00:00")).astimezone(timezone.utc)
  42. day_bucket = ts.strftime("%Y-%m-%d")
  43. raw = "|".join(
  44. [
  45. str(event.get("event_type", "generic")),
  46. self._primary_subject(event),
  47. self._primary_observable(event),
  48. day_bucket,
  49. ]
  50. )
  51. return hashlib.sha256(raw.encode("utf-8")).hexdigest()
  52. def _effective_severity(self, event: dict[str, Any], policy: dict[str, Any]) -> tuple[str, int, list[str]]:
  53. severity = str(event.get("severity", "medium")).lower()
  54. risk_context = event.get("risk_context", {})
  55. network = event.get("network", {})
  56. weights = policy.get("risk", {}).get("weights", {})
  57. score = 0
  58. factors: list[str] = []
  59. allowed_country = policy.get("vpn", {}).get("allowed_country", "TH")
  60. country = str(network.get("country", "")).upper()
  61. if country and country != allowed_country:
  62. score += int(weights.get("outside_thailand", 50))
  63. factors.append("outside_country")
  64. if risk_context.get("admin_account"):
  65. score += int(weights.get("admin", 20))
  66. factors.append("admin_account")
  67. ts = datetime.fromisoformat(event["timestamp"].replace("Z", "+00:00"))
  68. if risk_context.get("off_hours") or self._is_off_hours(ts):
  69. score += int(weights.get("off_hours", 15))
  70. factors.append("off_hours")
  71. if risk_context.get("first_seen_country"):
  72. score += int(weights.get("first_seen_country", 15))
  73. factors.append("first_seen_country")
  74. thresholds = policy.get("risk", {}).get("thresholds", {})
  75. if score >= int(thresholds.get("high", 70)):
  76. severity = "high" if severity in {"low", "medium"} else severity
  77. elif score >= int(thresholds.get("medium", 40)) and severity == "low":
  78. severity = "medium"
  79. return severity, score, factors
  80. def _is_exception(self, event: dict[str, Any], policy: dict[str, Any]) -> bool:
  81. if event.get("event_type") != "vpn_geo_anomaly":
  82. return False
  83. asset = event.get("asset", {})
  84. user = str(asset.get("user", ""))
  85. allowed_users = set(policy.get("vpn", {}).get("exception_users", []))
  86. return user in allowed_users
  87. def _extract_iris_case_id(self, iris_response: dict[str, Any]) -> str | None:
  88. if "case_id" in iris_response:
  89. return str(iris_response.get("case_id"))
  90. data = iris_response.get("data")
  91. if isinstance(data, dict) and "case_id" in data:
  92. return str(data.get("case_id"))
  93. return None
  94. def _parse_kv_pairs(self, text: str) -> dict[str, str]:
  95. pattern = r"([A-Za-z0-9_]+)=('(?:[^']*)'|\"(?:[^\"]*)\"|[^\\s]+)"
  96. out: dict[str, str] = {}
  97. for key, raw in re.findall(pattern, text):
  98. value = raw.strip().strip("'").strip('"')
  99. out[key] = value
  100. return out
  101. def _severity_from_rule_level(self, rule_level: Any) -> str:
  102. try:
  103. level = int(rule_level)
  104. except (TypeError, ValueError):
  105. return "medium"
  106. if level >= 12:
  107. return "critical"
  108. if level >= 8:
  109. return "high"
  110. if level >= 4:
  111. return "medium"
  112. return "low"
  113. def _event_type_from_text(self, text: str, parsed: dict[str, str]) -> str:
  114. explicit = parsed.get("event_type")
  115. if explicit:
  116. return explicit
  117. lowered = text.lower()
  118. if "vpn" in lowered and ("geo" in lowered or "country" in lowered):
  119. return "vpn_geo_anomaly"
  120. if "domain" in lowered or "dns" in lowered:
  121. return "ioc_dns"
  122. if "c2" in lowered or "ips" in lowered or "ip " in lowered:
  123. return "ioc_ips"
  124. if "auth" in lowered and "fail" in lowered:
  125. return "auth_anomaly"
  126. return "generic"
    def _normalize_wazuh_hit(self, hit: dict[str, Any]) -> dict[str, Any]:
        """Convert a raw Wazuh/OpenSearch hit into the internal event schema.

        Field values come from key=value pairs in ``full_log`` first, then
        structured ``_source`` fields, then generated defaults, so every event
        ends up with an id, timestamp, and a valid severity.
        """
        src = hit.get("_source", {})
        full_log = str(src.get("full_log", ""))
        parsed = self._parse_kv_pairs(full_log)
        # Synthesize a clock-based id only when the hit carries none at all.
        event_id = str(parsed.get("event_id") or src.get("id") or hit.get("_id") or f"wazuh-{int(time.time())}")
        timestamp = (
            src.get("@timestamp")
            or src.get("timestamp")
            or datetime.now(timezone.utc).isoformat()
        )
        # Guard against "rule" being missing or a non-dict value.
        rule = src.get("rule", {}) if isinstance(src.get("rule"), dict) else {}
        rule_desc = str(rule.get("description") or "")
        event_type = self._event_type_from_text(full_log, parsed)
        # Explicit severity in the log wins; otherwise derive from the rule level.
        severity = str(parsed.get("severity", "")).lower() or self._severity_from_rule_level(rule.get("level"))
        src_ip = parsed.get("src_ip")
        dst_ip = parsed.get("dst_ip")
        domain = parsed.get("query") or parsed.get("domain")
        country = parsed.get("country")
        user = parsed.get("user") or (src.get("agent", {}) or {}).get("name")
        title = rule_desc or f"Wazuh alert {rule.get('id', '')}".strip()
        description = full_log or rule_desc or "Wazuh alert"
        return {
            "source": "wazuh",
            "event_type": event_type,
            "event_id": event_id,
            "timestamp": timestamp,
            # Coerce unknown severity strings to "medium" to keep the set closed.
            "severity": severity if severity in {"low", "medium", "high", "critical"} else "medium",
            "title": title,
            "description": description,
            "asset": {
                "user": user,
                "hostname": (src.get("agent", {}) or {}).get("name"),
                "agent_id": (src.get("agent", {}) or {}).get("id"),
            },
            "network": {
                "src_ip": src_ip,
                "dst_ip": dst_ip,
                "domain": domain,
                "country": country,
            },
            "tags": ["wazuh", event_type, f"rule_{rule.get('id', 'unknown')}"],
            "risk_context": {
                # NOTE(review): assumes country codes are ISO alpha-2 like "TH" — confirm.
                "outside_thailand": bool(country and str(country).upper() != "TH"),
            },
            "raw": src,
            "payload": {},
        }
    async def ingest_incident(self, event: dict[str, Any]) -> dict[str, Any]:
        """Run one normalized event through the full incident pipeline.

        Steps: compute the dedup key, short-circuit policy-exempt VPN users,
        apply risk scoring, create or update the IRIS case, persist the
        incident and event rows with a decision trace, and escalate
        policy-listed severities through the PagerDuty stub.

        Returns a summary dict: incident_key, action_taken, iris_case_id,
        escalation_stub_sent, stub_response (non-exception path only),
        and decision_trace.
        """
        policy = self.repo.get_policy()
        incident_key = self._incident_key(event)
        # Exempt VPN users: record the event for audit, open no case, never escalate.
        if self._is_exception(event, policy):
            decision_trace = {
                "incident_key": incident_key,
                "policy_exception": True,
                "reason": "vpn_exception_user",
            }
            self.repo.upsert_incident(incident_key, severity="low", status="ignored_exception", iris_case_id=None)
            self.repo.add_event(
                incident_key=incident_key,
                event_id=event.get("event_id"),
                source=event.get("source", "unknown"),
                event_type=event.get("event_type", "generic"),
                raw_payload=event,
                decision_trace=decision_trace,
            )
            return {
                "incident_key": incident_key,
                "action_taken": "ignored_exception",
                "escalation_stub_sent": False,
                "decision_trace": decision_trace,
            }
        effective_severity, risk_score, risk_factors = self._effective_severity(event, policy)
        current = self.repo.get_incident(incident_key)
        action_taken = "updated_case" if current else "created_case"
        iris_case_id = current.get("iris_case_id") if current else None
        if not iris_case_id:
            # No IRIS case tracked yet: open one and capture its id.
            case_payload = {
                "case_name": event.get("title", "SOC Incident"),
                "case_description": event.get("description", "Generated by soc-integrator MVP"),
                "case_customer": event.get("payload", {}).get("case_customer", settings.iris_default_customer_id),
                "case_soc_id": event.get("payload", {}).get("case_soc_id", settings.iris_default_soc_id),
            }
            iris_result = await self.iris_adapter.create_case(case_payload)
            iris_case_id = self._extract_iris_case_id(iris_result)
        else:
            # Existing case: append the new event reference to its description.
            update_payload = {
                "case_description": f"{event.get('description', 'Updated by soc-integrator MVP')} [event_id={event.get('event_id', '')}]"
            }
            try:
                await self.iris_adapter.update_case(iris_case_id, update_payload)
            except Exception:
                # Keep pipeline progressing for MVP even if update path is unsupported.
                pass
        stored = self.repo.upsert_incident(
            incident_key=incident_key,
            severity=effective_severity,
            status="open",
            iris_case_id=iris_case_id,
        )
        decision_trace = {
            "incident_key": incident_key,
            "risk_score": risk_score,
            "risk_factors": risk_factors,
            "effective_severity": effective_severity,
            "action_taken": action_taken,
        }
        self.repo.add_event(
            incident_key=incident_key,
            event_id=event.get("event_id"),
            source=event.get("source", "unknown"),
            event_type=event.get("event_type", "generic"),
            raw_payload=event,
            decision_trace=decision_trace,
        )
        escalate_severities = set(policy.get("escalate_severities", ["high", "critical"]))
        escalation_stub_sent = False
        stub_response: dict[str, Any] | None = None
        if effective_severity in escalate_severities:
            escalation_payload = {
                "incident_key": incident_key,
                "title": event.get("title", "SOC Incident"),
                "severity": effective_severity,
                "source": event.get("source", "soc-integrator"),
                "iris_case_id": iris_case_id,
                "event_summary": event.get("description", ""),
                "timestamp": event.get("timestamp"),
            }
            try:
                pd_result = await self.pagerduty_adapter.create_incident(escalation_payload)
                escalation_stub_sent = True
                stub_response = {"ok": True, "data": pd_result}
                self.repo.add_escalation_audit(
                    incident_key=incident_key,
                    status_code=200,
                    success=True,
                    response_excerpt=self._safe_excerpt(pd_result),
                )
            except Exception as exc:
                # Escalation failure is audited but does not fail the ingestion.
                stub_response = {"ok": False, "error": str(exc)}
                self.repo.add_escalation_audit(
                    incident_key=incident_key,
                    status_code=502,
                    success=False,
                    response_excerpt=self._safe_excerpt(exc),
                )
        return {
            "incident_key": stored["incident_key"],
            "action_taken": action_taken,
            "iris_case_id": stored.get("iris_case_id"),
            "escalation_stub_sent": escalation_stub_sent,
            "stub_response": stub_response,
            "decision_trace": decision_trace,
        }
    async def evaluate_ioc(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Evaluate an IOC payload and ingest matching IOCs as incidents.

        NOTE(review): ``matched`` and ``confidence`` are hard-coded MVP stubs;
        the Shuffle workflow is triggered (when configured) but its result is
        not yet consulted for the match decision — confirm before relying on it.
        """
        policy = self.repo.get_policy()
        workflow_id = str(policy.get("shuffle", {}).get("ioc_workflow_id", "")).strip()
        # Stubbed verdict: every IOC currently "matches" with fixed confidence.
        matched = True
        confidence = 0.7
        shuffle_result: dict[str, Any] | None = None
        if workflow_id:
            shuffle_result = await self.shuffle_adapter.trigger_workflow(workflow_id, payload)
        if matched:
            # Wrap the IOC into the normalized event schema and reuse ingest_incident.
            event = {
                "source": "shuffle",
                "event_type": "ioc_dns" if payload.get("ioc_type") == "domain" else "ioc_ips",
                "event_id": payload.get("source_event", {}).get("event_id") or f"ioc-{int(time.time())}",
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "severity": "medium",
                "title": f"IOC match: {payload.get('ioc_value', 'unknown')}",
                "description": "IOC evaluation result",
                "asset": payload.get("source_event", {}).get("asset", {}),
                "network": payload.get("source_event", {}).get("network", {}),
                "tags": ["ioc", str(payload.get("ioc_type", "unknown"))],
                "risk_context": {},
                "raw": payload,
                "payload": {},
            }
            ingest_result = await self.ingest_incident(event)
        else:
            ingest_result = {"action_taken": "rejected"}
        return {
            "matched": matched,
            "confidence": confidence,
            "shuffle": shuffle_result,
            "result": ingest_result,
        }
  313. async def evaluate_vpn(self, payload: dict[str, Any]) -> dict[str, Any]:
  314. if not payload.get("success", False):
  315. return {
  316. "risk_score": 0,
  317. "risk_factors": [],
  318. "exception_applied": False,
  319. "action_taken": "rejected",
  320. }
  321. event = {
  322. "source": "wazuh",
  323. "event_type": "vpn_geo_anomaly",
  324. "event_id": payload.get("event_id") or f"vpn-{int(time.time())}",
  325. "timestamp": payload.get("event_time") or datetime.now(timezone.utc).isoformat(),
  326. "severity": "high",
  327. "title": f"VPN login anomaly: {payload.get('user', 'unknown')}",
  328. "description": f"VPN success from {payload.get('country_code', 'unknown')} for user {payload.get('user', 'unknown')}",
  329. "asset": {"user": payload.get("user")},
  330. "network": {"src_ip": payload.get("src_ip"), "country": payload.get("country_code")},
  331. "tags": ["vpn", "geo-anomaly"],
  332. "risk_context": {
  333. "outside_thailand": payload.get("country_code", "").upper() != "TH",
  334. "admin_account": bool(payload.get("is_admin", False)),
  335. "off_hours": bool(payload.get("off_hours", False)),
  336. "first_seen_country": bool(payload.get("first_seen_country", False)),
  337. },
  338. "raw": payload,
  339. "payload": {},
  340. }
  341. ingest_result = await self.ingest_incident(event)
  342. decision_trace = ingest_result.get("decision_trace", {})
  343. return {
  344. "risk_score": decision_trace.get("risk_score", 0),
  345. "risk_factors": decision_trace.get("risk_factors", []),
  346. "exception_applied": ingest_result.get("action_taken") == "ignored_exception",
  347. "action_taken": ingest_result.get("action_taken"),
  348. "incident_key": ingest_result.get("incident_key"),
  349. "iris_case_id": ingest_result.get("iris_case_id"),
  350. "escalation_stub_sent": ingest_result.get("escalation_stub_sent", False),
  351. }
    async def sync_wazuh_alerts(
        self,
        query: str = "soc_mvp_test=true OR event_type:*",
        limit: int = 50,
        minutes: int = 120,
    ) -> dict[str, Any]:
        """Pull recent Wazuh alerts, normalize them, and ingest the new ones.

        Already-seen Wazuh event ids (per ``repo.has_event``) are skipped;
        per-event ingest failures are counted and sampled into ``errors``
        without aborting the batch. Returns batch counters plus the incident
        keys touched during this run.
        """
        raw = await self.wazuh_adapter.search_alerts(query=query, limit=limit, minutes=minutes)
        # Defensive unwrap of the OpenSearch-style {"hits": {"hits": [...]}} envelope.
        hits = (raw.get("hits", {}) or {}).get("hits", []) if isinstance(raw, dict) else []
        processed = 0
        ingested = 0
        skipped_existing = 0
        failed = 0
        errors: list[str] = []
        created_incidents: list[str] = []
        for hit in hits:
            processed += 1
            event = self._normalize_wazuh_hit(hit)
            event_id = str(event.get("event_id", "")).strip()
            # Dedup against events already persisted for the "wazuh" source.
            if event_id and self.repo.has_event("wazuh", event_id):
                skipped_existing += 1
                continue
            try:
                result = await self.ingest_incident(event)
                ingested += 1
                incident_key = str(result.get("incident_key", ""))
                if incident_key:
                    created_incidents.append(incident_key)
            except Exception as exc:
                # Best-effort batch: record the failure and keep going.
                failed += 1
                errors.append(f"{event_id or 'unknown_event'}: {exc}")
        return {
            "query": query,
            "window_minutes": minutes,
            "limit": limit,
            "processed": processed,
            "ingested": ingested,
            "skipped_existing": skipped_existing,
            "failed": failed,
            "incident_keys": created_incidents,
            # Cap the error sample so the response body stays small.
            "errors": errors[:10],
            "total_hits": (raw.get("hits", {}).get("total", {}) if isinstance(raw, dict) else {}),
        }
    async def dependency_health(self) -> dict[str, Any]:
        """Probe each external dependency and report status plus latency.

        Probes run sequentially; each failure is captured as a
        ``{"status": "down", ...}`` entry instead of propagating, so one
        downed dependency never hides the health of the others.
        """
        out: dict[str, Any] = {}

        async def timed(name: str, fn):
            # Time a single async probe and record up/down with latency in ms.
            start = time.time()
            try:
                result = await fn()
                out[name] = {
                    "status": "up",
                    "latency_ms": round((time.time() - start) * 1000, 2),
                    "details": result,
                }
            except Exception as exc:
                out[name] = {
                    "status": "down",
                    "latency_ms": round((time.time() - start) * 1000, 2),
                    "error": str(exc),
                }

        await timed("wazuh", self.wazuh_adapter.auth_test)
        await timed("shuffle", self.shuffle_adapter.health)
        await timed("iris", self.iris_adapter.whoami)

        async def pagerduty_stub_health():
            # The stub exposes no dedicated health route; a GET on the base URL
            # with raise_for_status serves as the liveness check.
            async with httpx.AsyncClient(timeout=10.0) as client:
                r = await client.get(settings.pagerduty_base_url)
                r.raise_for_status()
                return {"status_code": r.status_code}

        await timed("pagerduty_stub", pagerduty_stub_health)
        return out