from __future__ import annotations from datetime import datetime, timezone from typing import Any from psycopg.types.json import Json from app.db import get_conn def utc_now() -> datetime: return datetime.now(timezone.utc) DEFAULT_POLICY: dict[str, Any] = { "escalate_severities": ["high", "critical"], "vpn": { "allowed_country": "TH", "exception_users": [], }, "risk": { "weights": { "outside_thailand": 50, "admin": 20, "off_hours": 15, "first_seen_country": 15, }, "thresholds": { "high": 70, "medium": 40, }, }, "shuffle": { "ioc_workflow_id": "", }, } class MvpRepository: def count_incident_events_since(self, since: datetime, source: str | None = None) -> int: with get_conn() as conn, conn.cursor() as cur: if source: cur.execute( """ SELECT COUNT(*) AS cnt FROM incident_events WHERE created_at >= %s AND source = %s """, (since, source), ) else: cur.execute( """ SELECT COUNT(*) AS cnt FROM incident_events WHERE created_at >= %s """, (since,), ) row = cur.fetchone() or {} return int(row.get("cnt", 0) or 0) def count_incidents_with_iris_since(self, since: datetime) -> int: with get_conn() as conn, conn.cursor() as cur: cur.execute( """ SELECT COUNT(*) AS cnt FROM incident_index WHERE first_seen >= %s AND iris_case_id IS NOT NULL AND iris_case_id <> '' """, (since,), ) row = cur.fetchone() or {} return int(row.get("cnt", 0) or 0) def count_escalations_since(self, since: datetime, success: bool | None = None) -> int: with get_conn() as conn, conn.cursor() as cur: if success is None: cur.execute( """ SELECT COUNT(*) AS cnt FROM escalation_audit WHERE attempted_at >= %s """, (since,), ) else: cur.execute( """ SELECT COUNT(*) AS cnt FROM escalation_audit WHERE attempted_at >= %s AND success = %s """, (since, success), ) row = cur.fetchone() or {} return int(row.get("cnt", 0) or 0) def count_c_detection_events_since(self, since: datetime) -> int: with get_conn() as conn, conn.cursor() as cur: cur.execute( """ SELECT COUNT(*) AS cnt FROM c_detection_events WHERE matched_at >= %s """, (since,), ) row = cur.fetchone() or {} return int(row.get("cnt", 0) or 0) def list_recent_escalations(self, limit: int = 20) -> list[dict[str, Any]]: with get_conn() as conn, conn.cursor() as cur: cur.execute( """ SELECT id, incident_key, attempted_at, status_code, success, response_excerpt FROM escalation_audit ORDER BY attempted_at DESC LIMIT %s """, (max(1, limit),), ) return [dict(row) for row in cur.fetchall()] def has_event(self, source: str, event_id: str) -> bool: with get_conn() as conn, conn.cursor() as cur: cur.execute( "SELECT 1 FROM incident_events WHERE source = %s AND event_id = %s LIMIT 1", (source, event_id), ) return cur.fetchone() is not None def ensure_policy(self) -> None: with get_conn() as conn, conn.cursor() as cur: cur.execute("SELECT id FROM policy_config WHERE id = 1") found = cur.fetchone() if not found: cur.execute( "INSERT INTO policy_config(id, data) VALUES (1, %s)", (Json(DEFAULT_POLICY),), ) def get_policy(self) -> dict[str, Any]: self.ensure_policy() with get_conn() as conn, conn.cursor() as cur: cur.execute("SELECT data FROM policy_config WHERE id = 1") row = cur.fetchone() return dict(row["data"]) if row else dict(DEFAULT_POLICY) def update_policy(self, data: dict[str, Any]) -> dict[str, Any]: with get_conn() as conn, conn.cursor() as cur: cur.execute( """ INSERT INTO policy_config(id, data, updated_at) VALUES (1, %s, NOW()) ON CONFLICT (id) DO UPDATE SET data = EXCLUDED.data, updated_at = NOW() RETURNING data """, (Json(data),), ) row = cur.fetchone() return dict(row["data"]) def get_incident(self, incident_key: str) -> dict[str, Any] | None: with get_conn() as conn, conn.cursor() as cur: cur.execute( "SELECT incident_key, iris_case_id, status, severity, first_seen, last_seen FROM incident_index WHERE incident_key = %s", (incident_key,), ) row = cur.fetchone() return dict(row) if row else None def upsert_incident( self, incident_key: str, severity: str, status: str, iris_case_id: str | None, ) -> dict[str, Any]: now = utc_now() with get_conn() as conn, conn.cursor() as cur: cur.execute( """ INSERT INTO incident_index(incident_key, iris_case_id, status, severity, first_seen, last_seen) VALUES (%s, %s, %s, %s, %s, %s) ON CONFLICT (incident_key) DO UPDATE SET iris_case_id = COALESCE(EXCLUDED.iris_case_id, incident_index.iris_case_id), status = EXCLUDED.status, severity = EXCLUDED.severity, last_seen = EXCLUDED.last_seen RETURNING incident_key, iris_case_id, status, severity, first_seen, last_seen """, (incident_key, iris_case_id, status, severity, now, now), ) return dict(cur.fetchone()) def add_event( self, incident_key: str, event_id: str | None, source: str, event_type: str, raw_payload: dict[str, Any], decision_trace: dict[str, Any], ) -> None: with get_conn() as conn, conn.cursor() as cur: cur.execute( """ INSERT INTO incident_events(incident_key, event_id, source, event_type, raw_payload, decision_trace) VALUES (%s, %s, %s, %s, %s, %s) """, ( incident_key, event_id, source, event_type, Json(raw_payload), Json(decision_trace), ), ) def add_escalation_audit( self, incident_key: str, status_code: int | None, success: bool, response_excerpt: str | None, ) -> None: with get_conn() as conn, conn.cursor() as cur: cur.execute( """ INSERT INTO escalation_audit(incident_key, status_code, success, response_excerpt) VALUES (%s, %s, %s, %s) """, (incident_key, status_code, success, response_excerpt), ) def add_ioc_trace( self, action: str, ioc_type: str, ioc_value: str, providers: list[str], request_payload: dict[str, Any], response_payload: dict[str, Any], matched: bool | None = None, severity: str | None = None, confidence: float | None = None, error: str | None = None, ) -> None: with get_conn() as conn, conn.cursor() as cur: cur.execute( """ INSERT INTO ioc_trace( action, ioc_type, ioc_value, providers, request_payload, response_payload, matched, severity, confidence, error ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s) """, ( action, ioc_type, ioc_value, Json(providers), Json(request_payload), Json(response_payload), matched, severity, confidence, error, ), ) def list_ioc_trace(self, limit: int = 50, offset: int = 0) -> list[dict[str, Any]]: with get_conn() as conn, conn.cursor() as cur: cur.execute( """ SELECT id, action, ioc_type, ioc_value, providers, matched, severity, confidence, error, created_at FROM ioc_trace ORDER BY created_at DESC LIMIT %s OFFSET %s """, (max(1, limit), max(0, offset)), ) return [dict(row) for row in cur.fetchall()] def get_correlation_state(self, entity_key: str) -> dict[str, Any] | None: with get_conn() as conn, conn.cursor() as cur: cur.execute( "SELECT state FROM correlation_state WHERE entity_key = %s", (entity_key,), ) row = cur.fetchone() return dict(row["state"]) if row else None def upsert_correlation_state(self, entity_key: str, state: dict[str, Any]) -> None: with get_conn() as conn, conn.cursor() as cur: cur.execute( """ INSERT INTO correlation_state(entity_key, state, updated_at) VALUES (%s, %s, NOW()) ON CONFLICT (entity_key) DO UPDATE SET state = EXCLUDED.state, updated_at = NOW() """, (entity_key, Json(state)), ) def add_c_detection_event( self, usecase_id: str, entity: str, severity: str, evidence: dict[str, Any], event_ref: dict[str, Any], incident_key: str | None = None, ) -> dict[str, Any]: with get_conn() as conn, conn.cursor() as cur: cur.execute( """ INSERT INTO c_detection_events(usecase_id, entity, severity, evidence, event_ref, incident_key) VALUES (%s, %s, %s, %s, %s, %s) RETURNING id, usecase_id, entity, severity, evidence, event_ref, incident_key, matched_at """, ( usecase_id, entity, severity, Json(evidence), Json(event_ref), incident_key, ), ) return dict(cur.fetchone()) def update_c_detection_incident(self, event_id: int, incident_key: str | None) -> None: with get_conn() as conn, conn.cursor() as cur: cur.execute( "UPDATE c_detection_events SET incident_key = %s WHERE id = %s", (incident_key, event_id), ) def list_c_detection_events( self, limit: int = 50, offset: int = 0, usecase_id: str | None = None, ) -> list[dict[str, Any]]: with get_conn() as conn, conn.cursor() as cur: if usecase_id: cur.execute( """ SELECT id, usecase_id, entity, severity, evidence, event_ref, incident_key, matched_at FROM c_detection_events WHERE usecase_id = %s ORDER BY matched_at DESC LIMIT %s OFFSET %s """, (usecase_id, max(1, limit), max(0, offset)), ) else: cur.execute( """ SELECT id, usecase_id, entity, severity, evidence, event_ref, incident_key, matched_at FROM c_detection_events ORDER BY matched_at DESC LIMIT %s OFFSET %s """, (max(1, limit), max(0, offset)), ) return [dict(row) for row in cur.fetchall()] def is_c_detection_in_cooldown(self, usecase_id: str, entity: str, cooldown_seconds: int) -> bool: if cooldown_seconds <= 0: return False with get_conn() as conn, conn.cursor() as cur: cur.execute( """ SELECT 1 FROM c_detection_events WHERE usecase_id = %s AND entity = %s AND incident_key IS NOT NULL AND matched_at > (NOW() - (%s || ' seconds')::interval) LIMIT 1 """, (usecase_id, entity, int(cooldown_seconds)), ) return cur.fetchone() is not None