Нет описания

manage_users_db.py 21KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777
  1. # IRIS Source Code
  2. # Copyright (C) 2021 - Airbus CyberSecurity (SAS)
  3. # ir@cyberactionlab.net
  4. #
  5. # This program is free software; you can redistribute it and/or
  6. # modify it under the terms of the GNU Lesser General Public
  7. # License as published by the Free Software Foundation; either
  8. # version 3 of the License, or (at your option) any later version.
  9. #
  10. # This program is distributed in the hope that it will be useful,
  11. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  13. # Lesser General Public License for more details.
  14. #
  15. # You should have received a copy of the GNU Lesser General Public License
  16. # along with this program; if not, write to the Free Software Foundation,
  17. # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
  18. from typing import List
  19. from functools import reduce
  20. from flask_login import current_user
  21. from sqlalchemy import and_
  22. import app
  23. from app import bc
  24. from app import db
  25. from app.datamgmt.case.case_db import get_case
  26. from app.datamgmt.conversions import convert_sort_direction
  27. from app.iris_engine.access_control.utils import ac_access_level_mask_from_val_list
  28. from app.iris_engine.access_control.utils import ac_ldp_group_removal
  29. from app.iris_engine.access_control.utils import ac_access_level_to_list
  30. from app.iris_engine.access_control.utils import ac_auto_update_user_effective_access
  31. from app.iris_engine.access_control.utils import ac_get_detailed_effective_permissions_from_groups
  32. from app.iris_engine.access_control.utils import ac_remove_case_access_from_user
  33. from app.iris_engine.access_control.utils import ac_set_case_access_for_user
  34. from app.models.cases import Cases
  35. from app.models.models import Client
  36. from app.models.models import UserActivity
  37. from app.models.authorization import CaseAccessLevel
  38. from app.models.authorization import UserClient
  39. from app.models.authorization import Group
  40. from app.models.authorization import Organisation
  41. from app.models.authorization import User
  42. from app.models.authorization import UserCaseAccess
  43. from app.models.authorization import UserCaseEffectiveAccess
  44. from app.models.authorization import UserGroup
  45. from app.models.authorization import UserOrganisation
  46. def get_user(user_id, id_key: str = 'id') -> [User, None]:
  47. user = User.query.filter(getattr(User, id_key) == user_id).first()
  48. return user
  49. def get_active_user(user_id, id_key: str = 'id') -> [User, None]:
  50. user = User.query.filter(
  51. and_(
  52. getattr(User, id_key) == user_id,
  53. User.active == True
  54. )).first()
  55. return user
  56. def get_active_user_by_login(username):
  57. return get_active_user(user_id=username, id_key='user')
  58. def list_users_id():
  59. users = User.query.with_entities(User.user_id).all()
  60. return users
  61. def get_user_effective_permissions(user_id):
  62. groups_perms = UserGroup.query.with_entities(
  63. Group.group_permissions,
  64. Group.group_name
  65. ).filter(
  66. UserGroup.user_id == user_id
  67. ).join(
  68. UserGroup.group
  69. ).all()
  70. effective_permissions = ac_get_detailed_effective_permissions_from_groups(groups_perms)
  71. return effective_permissions
  72. def get_user_groups(user_id):
  73. groups = UserGroup.query.with_entities(
  74. Group.group_name,
  75. Group.group_id,
  76. Group.group_uuid
  77. ).filter(
  78. UserGroup.user_id == user_id
  79. ).join(
  80. UserGroup.group
  81. ).all()
  82. output = []
  83. for group in groups:
  84. output.append(group._asdict())
  85. return output
  86. def update_user_groups(user_id, groups):
  87. cur_groups = UserGroup.query.with_entities(
  88. UserGroup.group_id
  89. ).filter(UserGroup.user_id == user_id).all()
  90. set_cur_groups = set([grp[0] for grp in cur_groups])
  91. set_new_groups = set(int(grp) for grp in groups)
  92. groups_to_add = set_new_groups - set_cur_groups
  93. groups_to_remove = set_cur_groups - set_new_groups
  94. for group_id in groups_to_add:
  95. user_group = UserGroup()
  96. user_group.user_id = user_id
  97. user_group.group_id = group_id
  98. db.session.add(user_group)
  99. for group_id in groups_to_remove:
  100. if current_user.id == user_id and ac_ldp_group_removal(user_id=user_id, group_id=group_id):
  101. continue
  102. UserGroup.query.filter(
  103. UserGroup.user_id == user_id,
  104. UserGroup.group_id == group_id
  105. ).delete()
  106. db.session.commit()
  107. ac_auto_update_user_effective_access(user_id)
  108. def add_user_to_customer(user_id, customer_id):
  109. user_client = UserClient.query.filter(
  110. UserClient.user_id == user_id,
  111. UserClient.client_id == customer_id
  112. ).first()
  113. if user_client:
  114. return True
  115. user_client = UserClient()
  116. user_client.user_id = user_id
  117. user_client.client_id = customer_id
  118. user_client.access_level = CaseAccessLevel.full_access.value
  119. user_client.allow_alerts = True
  120. db.session.add(user_client)
  121. db.session.commit()
  122. ac_auto_update_user_effective_access(user_id)
  123. return True
  124. def update_user_customers(user_id, customers):
  125. # Update the user's customers directly
  126. cur_customers = UserClient.query.with_entities(
  127. UserClient.client_id
  128. ).filter(UserClient.user_id == user_id).all()
  129. set_cur_customers = set([cust[0] for cust in cur_customers])
  130. set_new_customers = set(int(cust) for cust in customers)
  131. customers_to_add = set_new_customers - set_cur_customers
  132. customers_to_remove = set_cur_customers - set_new_customers
  133. for client_id in customers_to_add:
  134. user_client = UserClient()
  135. user_client.user_id = user_id
  136. user_client.client_id = client_id
  137. user_client.access_level = CaseAccessLevel.full_access.value
  138. user_client.allow_alerts = True
  139. db.session.add(user_client)
  140. for client_id in customers_to_remove:
  141. UserClient.query.filter(
  142. UserClient.user_id == user_id,
  143. UserClient.client_id == client_id
  144. ).delete()
  145. ac_auto_update_user_effective_access(user_id)
  146. db.session.commit()
  147. def update_user_orgs(user_id, orgs):
  148. cur_orgs = UserOrganisation.query.with_entities(
  149. UserOrganisation.org_id,
  150. UserOrganisation.is_primary_org
  151. ).filter(UserOrganisation.user_id == user_id).all()
  152. updated = False
  153. primary_org = 0
  154. for org in cur_orgs:
  155. if org.is_primary_org:
  156. primary_org = org.org_id
  157. if primary_org == 0:
  158. return False, 'User does not have primary organisation. Set one before managing its organisations'
  159. set_cur_orgs = set([org.org_id for org in cur_orgs])
  160. set_new_orgs = set(int(org) for org in orgs)
  161. orgs_to_add = set_new_orgs - set_cur_orgs
  162. orgs_to_remove = set_cur_orgs - set_new_orgs
  163. for org in orgs_to_add:
  164. user_org = UserOrganisation()
  165. user_org.user_id = user_id
  166. user_org.org_id = org
  167. db.session.add(user_org)
  168. updated = True
  169. for org in orgs_to_remove:
  170. if org != primary_org:
  171. UserOrganisation.query.filter(
  172. UserOrganisation.user_id == user_id,
  173. UserOrganisation.org_id == org
  174. ).delete()
  175. else:
  176. db.session.rollback()
  177. return False, f'Cannot delete user from primary organisation {org}. Change it before deleting.'
  178. updated = True
  179. db.session.commit()
  180. ac_auto_update_user_effective_access(user_id)
  181. return True, 'Organisations membership updated' if updated else "Nothing changed"
  182. def change_user_primary_org(user_id, old_org_id, new_org_id):
  183. uo_old = UserOrganisation.query.filter(
  184. UserOrganisation.user_id == user_id,
  185. UserOrganisation.org_id == old_org_id
  186. ).first()
  187. uo_new = UserOrganisation.query.filter(
  188. UserOrganisation.user_id == user_id,
  189. UserOrganisation.org_id == new_org_id
  190. ).first()
  191. if uo_old:
  192. uo_old.is_primary_org = False
  193. if not uo_new:
  194. uo = UserOrganisation()
  195. uo.user_id = user_id
  196. uo.org_id = new_org_id
  197. uo.is_primary_org = True
  198. db.session.add(uo)
  199. else:
  200. uo_new.is_primary_org = True
  201. db.session.commit()
  202. return
  203. def add_user_to_organisation(user_id, org_id, make_primary=False):
  204. org_id = Organisation.query.first().org_id
  205. uo_exists = UserOrganisation.query.filter(
  206. UserOrganisation.user_id == user_id,
  207. UserOrganisation.org_id == org_id
  208. ).first()
  209. if uo_exists:
  210. uo_exists.is_primary_org = make_primary
  211. db.session.commit()
  212. return True
  213. # Check if user has a primary org already
  214. prim_org = get_user_primary_org(user_id=user_id)
  215. if make_primary:
  216. prim_org.is_primary_org = False
  217. db.session.commit()
  218. uo = UserOrganisation()
  219. uo.user_id = user_id
  220. uo.org_id = org_id
  221. uo.is_primary_org = prim_org is None
  222. db.session.add(uo)
  223. db.session.commit()
  224. return True
  225. def get_user_primary_org(user_id):
  226. uo = UserOrganisation.query.filter(
  227. and_(UserOrganisation.user_id == user_id,
  228. UserOrganisation.is_primary_org == True)
  229. ).all()
  230. if not uo:
  231. return None
  232. uoe = None
  233. index = 0
  234. if len(uo) > 1:
  235. # Fix potential duplication
  236. for u in uo:
  237. if index == 0:
  238. uoe = u
  239. continue
  240. u.is_primary_org = False
  241. db.session.commit()
  242. else:
  243. uoe = uo[0]
  244. return uoe
  245. def add_user_to_group(user_id, group_id):
  246. exists = UserGroup.query.filter(
  247. UserGroup.user_id == user_id,
  248. UserGroup.group_id == group_id
  249. ).scalar()
  250. if exists:
  251. return True
  252. ug = UserGroup()
  253. ug.user_id = user_id
  254. ug.group_id = group_id
  255. db.session.add(ug)
  256. db.session.commit()
  257. return True
  258. def get_user_organisations(user_id):
  259. user_org = UserOrganisation.query.with_entities(
  260. Organisation.org_name,
  261. Organisation.org_id,
  262. Organisation.org_uuid,
  263. UserOrganisation.is_primary_org
  264. ).filter(
  265. UserOrganisation.user_id == user_id
  266. ).join(
  267. UserOrganisation.org
  268. ).all()
  269. output = []
  270. for org in user_org:
  271. output.append(org._asdict())
  272. return output
  273. def get_user_cases_access(user_id):
  274. user_accesses = UserCaseAccess.query.with_entities(
  275. UserCaseAccess.access_level,
  276. UserCaseAccess.case_id,
  277. Cases.name.label('case_name')
  278. ).join(
  279. UserCaseAccess.case
  280. ).filter(
  281. UserCaseAccess.user_id == user_id
  282. ).all()
  283. user_cases_access = []
  284. for kuser in user_accesses:
  285. user_cases_access.append({
  286. "access_level": kuser.access_level,
  287. "access_level_list": ac_access_level_to_list(kuser.access_level),
  288. "case_id": kuser.case_id,
  289. "case_name": kuser.case_name
  290. })
  291. return user_cases_access
  292. def get_user_clients(user_id: int) -> List[Client]:
  293. clients = UserClient.query.filter(
  294. UserClient.user_id == user_id
  295. ).join(
  296. UserClient.client
  297. ).with_entities(
  298. Client.client_id.label('customer_id'),
  299. Client.client_uuid,
  300. Client.name.label('customer_name')
  301. ).all()
  302. clients_out = [c._asdict() for c in clients]
  303. return clients_out
  304. def get_user_cases_fast(user_id):
  305. user_cases = UserCaseEffectiveAccess.query.with_entities(
  306. UserCaseEffectiveAccess.case_id
  307. ).where(
  308. UserCaseEffectiveAccess.user_id == user_id,
  309. UserCaseEffectiveAccess.access_level != CaseAccessLevel.deny_all.value
  310. ).all()
  311. return [c.case_id for c in user_cases]
  312. def remove_cases_access_from_user(user_id, cases_list):
  313. if not user_id or type(user_id) is not int:
  314. return False, 'Invalid user id'
  315. if not cases_list or type(cases_list[0]) is not int:
  316. return False, "Invalid cases list"
  317. UserCaseAccess.query.filter(
  318. and_(
  319. UserCaseAccess.case_id.in_(cases_list),
  320. UserCaseAccess.user_id == user_id
  321. )).delete()
  322. db.session.commit()
  323. ac_auto_update_user_effective_access(user_id)
  324. return True, 'Cases access removed'
  325. def remove_case_access_from_user(user_id, case_id):
  326. if not user_id or type(user_id) is not int:
  327. return False, 'Invalid user id'
  328. if not case_id or type(case_id) is not int:
  329. return False, "Invalid case id"
  330. UserCaseAccess.query.filter(
  331. and_(
  332. UserCaseAccess.case_id == case_id,
  333. UserCaseAccess.user_id == user_id
  334. )).delete()
  335. db.session.commit()
  336. ac_remove_case_access_from_user(user_id, case_id)
  337. return True, 'Case access removed'
  338. def set_user_case_access(user_id, case_id, access_level):
  339. if user_id is None or type(user_id) is not int:
  340. return False, 'Invalid user id'
  341. if case_id is None or type(case_id) is not int:
  342. return False, "Invalid case id"
  343. if access_level is None or type(access_level) is not int:
  344. return False, "Invalid access level"
  345. if CaseAccessLevel.has_value(access_level) is False:
  346. return False, "Invalid access level"
  347. uca = UserCaseAccess.query.filter(
  348. UserCaseAccess.user_id == user_id,
  349. UserCaseAccess.case_id == case_id
  350. ).all()
  351. if len(uca) > 1:
  352. for u in uca:
  353. db.session.delete(u)
  354. db.session.commit()
  355. uca = None
  356. if not uca:
  357. uca = UserCaseAccess()
  358. uca.user_id = user_id
  359. uca.case_id = case_id
  360. uca.access_level = access_level
  361. db.session.add(uca)
  362. else:
  363. uca[0].access_level = access_level
  364. db.session.commit()
  365. ac_set_case_access_for_user(user_id, case_id, access_level)
  366. return True, 'Case access set to {} for user {}'.format(access_level, user_id)
  367. def get_user_details(user_id, include_api_key=False):
  368. user = User.query.filter(User.id == user_id).first()
  369. if not user:
  370. return None
  371. row = {}
  372. row['user_id'] = user.id
  373. row['user_uuid'] = user.uuid
  374. row['user_name'] = user.name
  375. row['user_login'] = user.user
  376. row['user_email'] = user.email
  377. row['user_active'] = user.active
  378. row['user_is_service_account'] = user.is_service_account
  379. if include_api_key:
  380. row['user_api_key'] = user.api_key
  381. row['user_groups'] = get_user_groups(user_id)
  382. row['user_organisations'] = get_user_organisations(user_id)
  383. row['user_permissions'] = get_user_effective_permissions(user_id)
  384. row['user_cases_access'] = get_user_cases_access(user_id)
  385. row['user_customers'] = get_user_clients(user_id)
  386. upg = get_user_primary_org(user_id)
  387. row['user_primary_organisation_id'] = upg.org_id if upg else 0
  388. return row
  389. def add_case_access_to_user(user, cases_list, access_level):
  390. if not user:
  391. return None, "Invalid user"
  392. for case_id in cases_list:
  393. case = get_case(case_id)
  394. if not case:
  395. return None, "Invalid case ID"
  396. access_level_mask = ac_access_level_mask_from_val_list([access_level])
  397. ocas = UserCaseAccess.query.filter(
  398. and_(
  399. UserCaseAccess.case_id == case_id,
  400. UserCaseAccess.user_id == user.id
  401. )).all()
  402. if ocas:
  403. for oca in ocas:
  404. db.session.delete(oca)
  405. oca = UserCaseAccess()
  406. oca.user_id = user.id
  407. oca.access_level = access_level_mask
  408. oca.case_id = case_id
  409. db.session.add(oca)
  410. db.session.commit()
  411. ac_auto_update_user_effective_access(user.id)
  412. return user, "Updated"
  413. def get_user_by_username(username):
  414. user = User.query.filter(User.user == username).first()
  415. return user
  416. def get_users_list():
  417. users = User.query.all()
  418. output = []
  419. for user in users:
  420. row = {}
  421. row['user_id'] = user.id
  422. row['user_uuid'] = user.uuid
  423. row['user_name'] = user.name
  424. row['user_login'] = user.user
  425. row['user_email'] = user.email
  426. row['user_active'] = user.active
  427. row['user_is_service_account'] = user.is_service_account
  428. output.append(row)
  429. return output
  430. def get_users_list_restricted():
  431. users = User.query.all()
  432. output = []
  433. for user in users:
  434. row = {}
  435. row['user_id'] = user.id
  436. row['user_uuid'] = user.uuid
  437. row['user_name'] = user.name
  438. row['user_login'] = user.user
  439. row['user_active'] = user.active
  440. output.append(row)
  441. return output
  442. def get_users_view_from_user_id(user_id):
  443. organisations = get_user_organisations(user_id)
  444. orgs_id = [uo.get('org_id') for uo in organisations]
  445. users = UserOrganisation.query.with_entities(
  446. User
  447. ).filter(and_(
  448. UserOrganisation.org_id.in_(orgs_id),
  449. UserOrganisation.user_id != user_id
  450. )).join(
  451. UserOrganisation.user
  452. ).all()
  453. return users
  454. def get_users_id_view_from_user_id(user_id):
  455. organisations = get_user_organisations(user_id)
  456. orgs_id = [uo.get('org_id') for uo in organisations]
  457. users = UserOrganisation.query.with_entities(
  458. User.id
  459. ).filter(and_(
  460. UserOrganisation.org_id.in_(orgs_id),
  461. UserOrganisation.user_id != user_id
  462. )).join(
  463. UserOrganisation.user
  464. ).all()
  465. users = [u[0] for u in users]
  466. return users
  467. def get_users_list_user_view(user_id):
  468. users = get_users_view_from_user_id(user_id)
  469. output = []
  470. for user in users:
  471. row = {}
  472. row['user_id'] = user.id
  473. row['user_uuid'] = user.uuid
  474. row['user_name'] = user.name
  475. row['user_login'] = user.user
  476. row['user_email'] = user.email
  477. row['user_active'] = user.active
  478. output.append(row)
  479. return output
  480. def get_users_list_restricted_user_view(user_id):
  481. users = get_users_view_from_user_id(user_id)
  482. output = []
  483. for user in users:
  484. row = {}
  485. row['user_id'] = user.id
  486. row['user_uuid'] = user.uuid
  487. row['user_name'] = user.name
  488. row['user_login'] = user.user
  489. row['user_active'] = user.active
  490. output.append(row)
  491. return output
  492. def get_users_list_restricted_from_case(case_id):
  493. users = UserCaseEffectiveAccess.query.with_entities(
  494. User.id.label('user_id'),
  495. User.uuid.label('user_uuid'),
  496. User.name.label('user_name'),
  497. User.user.label('user_login'),
  498. User.active.label('user_active'),
  499. User.email.label('user_email'),
  500. UserCaseEffectiveAccess.access_level.label('user_access_level')
  501. ).filter(
  502. UserCaseEffectiveAccess.case_id == case_id
  503. ).join(
  504. UserCaseEffectiveAccess.user
  505. ).all()
  506. return [u._asdict() for u in users]
  507. def create_user(user_name: str, user_login: str, user_password: str, user_email: str, user_active: bool,
  508. user_external_id: str = None, user_is_service_account: bool = False):
  509. if user_is_service_account is True and (user_password is None or user_password == ''):
  510. pw_hash = None
  511. else:
  512. pw_hash = bc.generate_password_hash(user_password.encode('utf8')).decode('utf8')
  513. user = User(user=user_login, name=user_name, email=user_email, password=pw_hash, active=user_active,
  514. external_id=user_external_id, is_service_account=user_is_service_account)
  515. user.save()
  516. add_user_to_organisation(user.id, org_id=1)
  517. ac_auto_update_user_effective_access(user_id=user.id)
  518. return user
  519. def update_user(user: User, name: str = None, email: str = None, password: str = None):
  520. if password is not None and password != '':
  521. pw_hash = bc.generate_password_hash(password.encode('utf8')).decode('utf8')
  522. user.password = pw_hash
  523. for key, value in [('name', name,), ('email', email,)]:
  524. if value is not None:
  525. setattr(user, key, value)
  526. db.session.commit()
  527. return user
  528. def delete_user(user_id):
  529. # Migrate the user activity to a shadow user
  530. UserActivity.query.filter(UserActivity.user_id == user_id).update({UserActivity.user_id: None})
  531. UserCaseAccess.query.filter(UserCaseAccess.user_id == user_id).delete()
  532. UserOrganisation.query.filter(UserOrganisation.user_id == user_id).delete()
  533. UserGroup.query.filter(UserGroup.user_id == user_id).delete()
  534. UserCaseEffectiveAccess.query.filter(UserCaseEffectiveAccess.user_id == user_id).delete()
  535. User.query.filter(User.id == user_id).delete()
  536. db.session.commit()
  537. def user_exists(user_name, user_email):
  538. user = User.query.filter_by(user=user_name).first()
  539. user_by_email = User.query.filter_by(email=user_email).first()
  540. return user or user_by_email
  541. def get_filtered_users(user_ids: str = None,
  542. user_name: str = None,
  543. user_login: str = None,
  544. customer_id: int = None,
  545. page: int = None,
  546. per_page: int = None,
  547. sort: str =None):
  548. """
  549. """
  550. conditions = []
  551. if user_ids is not None:
  552. conditions.append(User.id.in_(user_ids))
  553. if user_name is not None:
  554. conditions.append(User.name.ilike(user_name))
  555. if user_login is not None:
  556. conditions.append(User.user.ilike(user_login))
  557. if customer_id is not None:
  558. conditions.append(UserClient.client_id == customer_id)
  559. conditions.append(UserClient.user_id == User.id)
  560. if len(conditions) > 1:
  561. conditions = [reduce(and_, conditions)]
  562. order_func = convert_sort_direction(sort)
  563. try:
  564. filtered_users = db.session.query(
  565. User
  566. ).filter(
  567. *conditions
  568. ).order_by(
  569. order_func(User.id)
  570. ).paginate(
  571. page=page,
  572. per_page=per_page,
  573. error_out=False
  574. )
  575. except Exception as e:
  576. app.logger.exception(f'Error getting users: {str(e)}')
  577. return None
  578. return filtered_users