No description provided.

mvp_service.py 14KB

  1. from __future__ import annotations
  2. import hashlib
  3. import time
  4. from datetime import datetime, timezone
  5. from typing import Any
  6. import httpx
  7. from app.adapters.iris import IrisAdapter
  8. from app.adapters.pagerduty import PagerDutyAdapter
  9. from app.adapters.shuffle import ShuffleAdapter
  10. from app.adapters.wazuh import WazuhAdapter
  11. from app.config import settings
  12. from app.repositories.mvp_repo import MvpRepository
class MvpService:
    """Orchestrates the SOC MVP pipeline across the Wazuh, Shuffle, IRIS,
    and PagerDuty adapters, persisting state through MvpRepository."""

    def __init__(
        self,
        repo: MvpRepository,
        wazuh_adapter: WazuhAdapter,
        shuffle_adapter: ShuffleAdapter,
        iris_adapter: IrisAdapter,
        pagerduty_adapter: PagerDutyAdapter,
    ) -> None:
        """Store the injected repository and adapter dependencies.

        No I/O or validation happens here; construction only keeps references.
        """
        self.repo = repo
        self.wazuh_adapter = wazuh_adapter
        self.shuffle_adapter = shuffle_adapter
        self.iris_adapter = iris_adapter
        self.pagerduty_adapter = pagerduty_adapter
  27. def _is_off_hours(self, ts: datetime) -> bool:
  28. hour = ts.astimezone(timezone.utc).hour
  29. return hour < 6 or hour >= 20
  30. def _safe_excerpt(self, payload: Any) -> str:
  31. text = str(payload)
  32. return text[:300]
  33. def _primary_subject(self, event: dict[str, Any]) -> str:
  34. asset = event.get("asset", {})
  35. return str(asset.get("user") or asset.get("hostname") or "unknown")
  36. def _primary_observable(self, event: dict[str, Any]) -> str:
  37. network = event.get("network", {})
  38. return str(network.get("domain") or network.get("src_ip") or network.get("dst_ip") or "unknown")
  39. def _incident_key(self, event: dict[str, Any]) -> str:
  40. ts = datetime.fromisoformat(event["timestamp"].replace("Z", "+00:00")).astimezone(timezone.utc)
  41. day_bucket = ts.strftime("%Y-%m-%d")
  42. raw = "|".join(
  43. [
  44. str(event.get("event_type", "generic")),
  45. self._primary_subject(event),
  46. self._primary_observable(event),
  47. day_bucket,
  48. ]
  49. )
  50. return hashlib.sha256(raw.encode("utf-8")).hexdigest()
  51. def _effective_severity(self, event: dict[str, Any], policy: dict[str, Any]) -> tuple[str, int, list[str]]:
  52. severity = str(event.get("severity", "medium")).lower()
  53. risk_context = event.get("risk_context", {})
  54. network = event.get("network", {})
  55. weights = policy.get("risk", {}).get("weights", {})
  56. score = 0
  57. factors: list[str] = []
  58. allowed_country = policy.get("vpn", {}).get("allowed_country", "TH")
  59. country = str(network.get("country", "")).upper()
  60. if country and country != allowed_country:
  61. score += int(weights.get("outside_thailand", 50))
  62. factors.append("outside_country")
  63. if risk_context.get("admin_account"):
  64. score += int(weights.get("admin", 20))
  65. factors.append("admin_account")
  66. ts = datetime.fromisoformat(event["timestamp"].replace("Z", "+00:00"))
  67. if risk_context.get("off_hours") or self._is_off_hours(ts):
  68. score += int(weights.get("off_hours", 15))
  69. factors.append("off_hours")
  70. if risk_context.get("first_seen_country"):
  71. score += int(weights.get("first_seen_country", 15))
  72. factors.append("first_seen_country")
  73. thresholds = policy.get("risk", {}).get("thresholds", {})
  74. if score >= int(thresholds.get("high", 70)):
  75. severity = "high" if severity in {"low", "medium"} else severity
  76. elif score >= int(thresholds.get("medium", 40)) and severity == "low":
  77. severity = "medium"
  78. return severity, score, factors
  79. def _is_exception(self, event: dict[str, Any], policy: dict[str, Any]) -> bool:
  80. if event.get("event_type") != "vpn_geo_anomaly":
  81. return False
  82. asset = event.get("asset", {})
  83. user = str(asset.get("user", ""))
  84. allowed_users = set(policy.get("vpn", {}).get("exception_users", []))
  85. return user in allowed_users
  86. def _extract_iris_case_id(self, iris_response: dict[str, Any]) -> str | None:
  87. if "case_id" in iris_response:
  88. return str(iris_response.get("case_id"))
  89. data = iris_response.get("data")
  90. if isinstance(data, dict) and "case_id" in data:
  91. return str(data.get("case_id"))
  92. return None
    async def ingest_incident(self, event: dict[str, Any]) -> dict[str, Any]:
        """Deduplicate, persist, case-manage, and maybe escalate one normalized event.

        Flow: policy-exception short-circuit -> risk scoring -> IRIS case
        create/update -> repo upsert + audit event -> PagerDuty escalation when
        the effective severity is in the policy's escalate list.

        Returns a summary dict with the incident key, the action taken, the IRIS
        case id, whether an escalation stub was sent, and the decision trace.
        """
        policy = self.repo.get_policy()
        incident_key = self._incident_key(event)
        # Whitelisted VPN users: record the event for audit but take no
        # case-management or escalation action.
        if self._is_exception(event, policy):
            decision_trace = {
                "incident_key": incident_key,
                "policy_exception": True,
                "reason": "vpn_exception_user",
            }
            self.repo.upsert_incident(incident_key, severity="low", status="ignored_exception", iris_case_id=None)
            self.repo.add_event(
                incident_key=incident_key,
                event_id=event.get("event_id"),
                source=event.get("source", "unknown"),
                event_type=event.get("event_type", "generic"),
                raw_payload=event,
                decision_trace=decision_trace,
            )
            return {
                "incident_key": incident_key,
                "action_taken": "ignored_exception",
                "escalation_stub_sent": False,
                "decision_trace": decision_trace,
            }
        effective_severity, risk_score, risk_factors = self._effective_severity(event, policy)
        current = self.repo.get_incident(incident_key)
        action_taken = "updated_case" if current else "created_case"
        iris_case_id = current.get("iris_case_id") if current else None
        if not iris_case_id:
            # No case yet (first sighting, or earlier create failed to yield an
            # id): open a fresh IRIS case.
            case_payload = {
                "case_name": event.get("title", "SOC Incident"),
                "case_description": event.get("description", "Generated by soc-integrator MVP"),
                "case_customer": event.get("payload", {}).get("case_customer", settings.iris_default_customer_id),
                "case_soc_id": event.get("payload", {}).get("case_soc_id", settings.iris_default_soc_id),
            }
            iris_result = await self.iris_adapter.create_case(case_payload)
            iris_case_id = self._extract_iris_case_id(iris_result)
        else:
            # Existing case: append the new event id to the description.
            update_payload = {
                "case_description": f"{event.get('description', 'Updated by soc-integrator MVP')} [event_id={event.get('event_id', '')}]"
            }
            try:
                await self.iris_adapter.update_case(iris_case_id, update_payload)
            except Exception:
                # Keep pipeline progressing for MVP even if update path is unsupported.
                pass
        stored = self.repo.upsert_incident(
            incident_key=incident_key,
            severity=effective_severity,
            status="open",
            iris_case_id=iris_case_id,
        )
        decision_trace = {
            "incident_key": incident_key,
            "risk_score": risk_score,
            "risk_factors": risk_factors,
            "effective_severity": effective_severity,
            "action_taken": action_taken,
        }
        self.repo.add_event(
            incident_key=incident_key,
            event_id=event.get("event_id"),
            source=event.get("source", "unknown"),
            event_type=event.get("event_type", "generic"),
            raw_payload=event,
            decision_trace=decision_trace,
        )
        escalate_severities = set(policy.get("escalate_severities", ["high", "critical"]))
        escalation_stub_sent = False
        stub_response: dict[str, Any] | None = None
        if effective_severity in escalate_severities:
            escalation_payload = {
                "incident_key": incident_key,
                "title": event.get("title", "SOC Incident"),
                "severity": effective_severity,
                "source": event.get("source", "soc-integrator"),
                "iris_case_id": iris_case_id,
                "event_summary": event.get("description", ""),
                "timestamp": event.get("timestamp"),
            }
            try:
                pd_result = await self.pagerduty_adapter.create_incident(escalation_payload)
                escalation_stub_sent = True
                stub_response = {"ok": True, "data": pd_result}
                self.repo.add_escalation_audit(
                    incident_key=incident_key,
                    status_code=200,
                    success=True,
                    response_excerpt=self._safe_excerpt(pd_result),
                )
            except Exception as exc:
                # Escalation is best-effort: record the failure for audit and
                # still return a successful ingest result.
                stub_response = {"ok": False, "error": str(exc)}
                self.repo.add_escalation_audit(
                    incident_key=incident_key,
                    status_code=502,
                    success=False,
                    response_excerpt=self._safe_excerpt(exc),
                )
        return {
            "incident_key": stored["incident_key"],
            "action_taken": action_taken,
            "iris_case_id": stored.get("iris_case_id"),
            "escalation_stub_sent": escalation_stub_sent,
            "stub_response": stub_response,
            "decision_trace": decision_trace,
        }
    async def evaluate_ioc(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Evaluate an IOC payload and feed a match into the incident pipeline.

        NOTE(review): `matched` and `confidence` are hard-coded stub values —
        the Shuffle workflow result is captured and returned but does not yet
        influence the match decision; confirm before relying on `confidence`.
        """
        policy = self.repo.get_policy()
        workflow_id = str(policy.get("shuffle", {}).get("ioc_workflow_id", "")).strip()
        matched = True
        confidence = 0.7
        shuffle_result: dict[str, Any] | None = None
        if workflow_id:
            # Trigger the configured Shuffle workflow only when an id is set.
            shuffle_result = await self.shuffle_adapter.trigger_workflow(workflow_id, payload)
        if matched:
            # Normalize the IOC payload into the common event shape consumed
            # by ingest_incident.
            event = {
                "source": "shuffle",
                "event_type": "ioc_dns" if payload.get("ioc_type") == "domain" else "ioc_ips",
                "event_id": payload.get("source_event", {}).get("event_id") or f"ioc-{int(time.time())}",
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "severity": "medium",
                "title": f"IOC match: {payload.get('ioc_value', 'unknown')}",
                "description": "IOC evaluation result",
                "asset": payload.get("source_event", {}).get("asset", {}),
                "network": payload.get("source_event", {}).get("network", {}),
                "tags": ["ioc", str(payload.get("ioc_type", "unknown"))],
                "risk_context": {},
                "raw": payload,
                "payload": {},
            }
            ingest_result = await self.ingest_incident(event)
        else:
            # Currently unreachable (matched is always True); kept for the
            # eventual real match decision.
            ingest_result = {"action_taken": "rejected"}
        return {
            "matched": matched,
            "confidence": confidence,
            "shuffle": shuffle_result,
            "result": ingest_result,
        }
  232. async def evaluate_vpn(self, payload: dict[str, Any]) -> dict[str, Any]:
  233. if not payload.get("success", False):
  234. return {
  235. "risk_score": 0,
  236. "risk_factors": [],
  237. "exception_applied": False,
  238. "action_taken": "rejected",
  239. }
  240. event = {
  241. "source": "wazuh",
  242. "event_type": "vpn_geo_anomaly",
  243. "event_id": payload.get("event_id") or f"vpn-{int(time.time())}",
  244. "timestamp": payload.get("event_time") or datetime.now(timezone.utc).isoformat(),
  245. "severity": "high",
  246. "title": f"VPN login anomaly: {payload.get('user', 'unknown')}",
  247. "description": f"VPN success from {payload.get('country_code', 'unknown')} for user {payload.get('user', 'unknown')}",
  248. "asset": {"user": payload.get("user")},
  249. "network": {"src_ip": payload.get("src_ip"), "country": payload.get("country_code")},
  250. "tags": ["vpn", "geo-anomaly"],
  251. "risk_context": {
  252. "outside_thailand": payload.get("country_code", "").upper() != "TH",
  253. "admin_account": bool(payload.get("is_admin", False)),
  254. "off_hours": bool(payload.get("off_hours", False)),
  255. "first_seen_country": bool(payload.get("first_seen_country", False)),
  256. },
  257. "raw": payload,
  258. "payload": {},
  259. }
  260. ingest_result = await self.ingest_incident(event)
  261. decision_trace = ingest_result.get("decision_trace", {})
  262. return {
  263. "risk_score": decision_trace.get("risk_score", 0),
  264. "risk_factors": decision_trace.get("risk_factors", []),
  265. "exception_applied": ingest_result.get("action_taken") == "ignored_exception",
  266. "action_taken": ingest_result.get("action_taken"),
  267. "incident_key": ingest_result.get("incident_key"),
  268. "iris_case_id": ingest_result.get("iris_case_id"),
  269. "escalation_stub_sent": ingest_result.get("escalation_stub_sent", False),
  270. }
    async def dependency_health(self) -> dict[str, Any]:
        """Probe each downstream dependency and report per-service status and latency.

        Every probe is wrapped so that one failing dependency never aborts the
        others; a failure is recorded as status "down" with the error string.
        Probes run sequentially, so latencies are independent per service.
        """
        out: dict[str, Any] = {}

        async def timed(name: str, fn):
            # Time a single health call; record "up"/"down" plus latency in ms.
            start = time.time()
            try:
                result = await fn()
                out[name] = {
                    "status": "up",
                    "latency_ms": round((time.time() - start) * 1000, 2),
                    "details": result,
                }
            except Exception as exc:
                out[name] = {
                    "status": "down",
                    "latency_ms": round((time.time() - start) * 1000, 2),
                    "error": str(exc),
                }

        await timed("wazuh", self.wazuh_adapter.auth_test)
        await timed("shuffle", self.shuffle_adapter.health)
        await timed("iris", self.iris_adapter.whoami)

        async def pagerduty_stub_health():
            # The PagerDuty stub is checked with a plain GET on its base URL;
            # any non-2xx/3xx status raises and is reported as "down".
            async with httpx.AsyncClient(timeout=10.0) as client:
                r = await client.get(settings.pagerduty_base_url)
                r.raise_for_status()
                return {"status_code": r.status_code}

        await timed("pagerduty_stub", pagerduty_stub_health)
        return out