No description

fcc375ed37d1_access_control_migration.py (13 KB, 241 lines)

  1. """Access control migration
  2. Revision ID: fcc375ed37d1
  3. Revises: 7cc588444b79
  4. Create Date: 2022-06-14 17:01:29.205520
  5. """
  6. import uuid
  7. import sqlalchemy as sa
  8. from alembic import op
  9. from sqlalchemy import text
  10. from sqlalchemy.dialects.postgresql import UUID
  11. from app.alembic.alembic_utils import _has_table
  12. # revision identifiers, used by Alembic.
  13. from app.alembic.alembic_utils import _table_has_column
  14. from app.iris_engine.access_control.utils import ac_get_mask_analyst
  15. from app.iris_engine.access_control.utils import ac_get_mask_case_access_level_full
  16. from app.iris_engine.access_control.utils import ac_get_mask_full_permissions
  17. revision = 'fcc375ed37d1'
  18. down_revision = '7cc588444b79'
  19. branch_labels = None
  20. depends_on = None
  21. def upgrade():
  22. # Ensure the DB is not in a locked state and commit any pending transactions
  23. op.execute(text("COMMIT;"))
  24. conn = None
  25. # Add UUID to users
  26. if not _table_has_column('user', 'uuid'):
  27. conn = op.get_bind()
  28. op.add_column('user',
  29. sa.Column('uuid', UUID(as_uuid=True), default=uuid.uuid4, nullable=False,
  30. server_default=sa.text('gen_random_uuid()'))
  31. )
  32. # Add UUID to existing users
  33. t_users = sa.Table(
  34. 'user',
  35. sa.MetaData(),
  36. sa.Column('id', sa.BigInteger(), primary_key=True),
  37. sa.Column('uuid', UUID(as_uuid=True), default=uuid.uuid4, nullable=False)
  38. )
  39. res = conn.execute(text("select id from \"user\";"))
  40. results = res.fetchall()
  41. for user in results:
  42. conn.execute(t_users.update().where(t_users.c.id == user[0]).values(
  43. uuid=uuid.uuid4()
  44. ))
  45. # Add all the new access control tables if they don't exist
  46. if not _has_table('user_case_access'):
  47. op.create_table('user_case_access',
  48. sa.Column('id', sa.BigInteger(), primary_key=True, nullable=False),
  49. sa.Column('user_id', sa.BigInteger(), sa.ForeignKey('user.id'), nullable=False),
  50. sa.Column('case_id', sa.BigInteger(), sa.ForeignKey('cases.case_id'), nullable=False),
  51. sa.Column('access_level', sa.BigInteger()),
  52. keep_existing=True
  53. )
  54. op.create_foreign_key('fk_user_case_access_user_id', 'user_case_access', 'user', ['user_id'], ['id'])
  55. op.create_foreign_key('fk_user_case_access_case_id', 'user_case_access', 'cases', ['case_id'], ['case_id'])
  56. op.create_unique_constraint('uq_user_case_access_user_id_case_id', 'user_case_access', ['user_id', 'case_id'])
  57. if not _has_table('user_case_effective_access'):
  58. op.create_table('user_case_effective_access',
  59. sa.Column('id', sa.BigInteger(), primary_key=True, nullable=False),
  60. sa.Column('user_id', sa.BigInteger(), sa.ForeignKey('user.id'), nullable=False),
  61. sa.Column('case_id', sa.BigInteger(), sa.ForeignKey('cases.case_id'), nullable=False),
  62. sa.Column('access_level', sa.BigInteger()),
  63. keep_existing=True
  64. )
  65. op.create_foreign_key('fk_user_case_effective_access_user_id', 'user_case_effective_access',
  66. 'user', ['user_id'], ['id'])
  67. op.create_foreign_key('fk_user_case_effective_access_case_id', 'user_case_effective_access',
  68. 'cases', ['case_id'], ['case_id'])
  69. op.create_unique_constraint('uq_user_case_effective_access_user_id_case_id',
  70. 'user_case_access', ['user_id', 'case_id'])
  71. if not _has_table('group_case_access'):
  72. op.create_table('group_case_access',
  73. sa.Column('id', sa.BigInteger(), primary_key=True, nullable=False),
  74. sa.Column('group_id', sa.BigInteger(), sa.ForeignKey('groups.group_id'), nullable=False),
  75. sa.Column('case_id', sa.BigInteger(), sa.ForeignKey('cases.case_id'), nullable=False),
  76. sa.Column('access_level', sa.BigInteger(), nullable=False),
  77. keep_existing=True
  78. )
  79. op.create_foreign_key('group_case_access_group_id_fkey', 'group_case_access', 'groups',
  80. ['group_id'], ['group_id'])
  81. op.create_foreign_key('group_case_access_case_id_fkey', 'group_case_access', 'cases',
  82. ['case_id'], ['case_id'])
  83. op.create_unique_constraint('group_case_access_unique', 'group_case_access', ['group_id', 'case_id'])
  84. if not _has_table('groups'):
  85. op.create_table('groups',
  86. sa.Column('group_id', sa.BigInteger(), primary_key=True, nullable=False),
  87. sa.Column('group_uuid', UUID(as_uuid=True), default=uuid.uuid4, nullable=False,
  88. server_default=sa.text('gen_random_uuid()'), unique=True),
  89. sa.Column('group_name', sa.Text(), nullable=False),
  90. sa.Column('group_description', sa.Text(), nullable=False),
  91. sa.Column('group_permissions', sa.BigInteger(), nullable=False),
  92. sa.Column('group_auto_follow', sa.Boolean(), nullable=False, default=False),
  93. sa.Column('group_auto_follow_access_level', sa.BigInteger(), nullable=True),
  94. keep_existing=True
  95. )
  96. op.create_unique_constraint('groups_group_name_unique', 'groups', ['group_name'])
  97. if not _has_table('organisations'):
  98. op.create_table('organisations',
  99. sa.Column('org_id', sa.BigInteger(), primary_key=True, nullable=False),
  100. sa.Column('org_uuid', UUID(as_uuid=True), default=uuid.uuid4(), nullable=False,
  101. server_default=sa.text('gen_random_uuid()'), unique=True),
  102. sa.Column('org_name', sa.Text(), nullable=False),
  103. sa.Column('org_description', sa.Text(), nullable=False),
  104. sa.Column('org_url', sa.Text(), nullable=False),
  105. sa.Column('org_email', sa.Text(), nullable=False),
  106. sa.Column('org_logo', sa.Text(), nullable=False),
  107. sa.Column('org_type', sa.Text(), nullable=False),
  108. sa.Column('org_sector', sa.Text(), nullable=False),
  109. sa.Column('org_nationality', sa.Text(), nullable=False),
  110. keep_existing=True
  111. )
  112. op.create_unique_constraint('organisation_name_unique', 'organisations', ['org_name'])
  113. if not _has_table('organisation_case_access'):
  114. op.create_table('organisation_case_access',
  115. sa.Column('id', sa.BigInteger(), primary_key=True, nullable=False),
  116. sa.Column('org_id', sa.BigInteger(), sa.ForeignKey('organisations.org_id'), nullable=False),
  117. sa.Column('case_id', sa.BigInteger(), sa.ForeignKey('cases.case_id'), nullable=False),
  118. sa.Column('access_level', sa.BigInteger(), nullable=False),
  119. keep_existing=True
  120. )
  121. op.create_foreign_key('organisation_case_access_org_id_fkey', 'organisation_case_access',
  122. 'organisations', ['org_id'], ['org_id'])
  123. op.create_foreign_key('organisation_case_access_case_id_fkey', 'organisation_case_access', 'cases',
  124. ['case_id'], ['case_id'])
  125. op.create_unique_constraint('organisation_case_access_unique', 'organisation_case_access',
  126. ['org_id', 'case_id'])
  127. if not _has_table('user_organisation'):
  128. op.create_table('user_organisation',
  129. sa.Column('id', sa.BigInteger(), primary_key=True, nullable=False),
  130. sa.Column('user_id', sa.BigInteger(), sa.ForeignKey('user.id'), nullable=False),
  131. sa.Column('org_id', sa.BigInteger(), sa.ForeignKey('organisations.org_id'), nullable=False),
  132. sa.Column('is_primary_org', sa.Boolean(), nullable=False),
  133. keep_existing=True
  134. )
  135. op.create_foreign_key('user_organisation_user_id_fkey', 'user_organisation', 'user', ['user_id'], ['id'])
  136. op.create_foreign_key('user_organisation_org_id_fkey', 'user_organisation', 'organisations',
  137. ['org_id'], ['org_id'])
  138. op.create_unique_constraint('user_organisation_unique', 'user_organisation', ['user_id', 'org_id'])
  139. if not _has_table('user_group'):
  140. op.create_table('user_group',
  141. sa.Column('id', sa.BigInteger(), primary_key=True, nullable=False),
  142. sa.Column('user_id', sa.BigInteger(), sa.ForeignKey('user.id'), nullable=False),
  143. sa.Column('group_id', sa.BigInteger(), sa.ForeignKey('groups.group_id'), nullable=False),
  144. keep_existing=True
  145. )
  146. op.create_foreign_key('user_group_user_id_fkey', 'user_group', 'user', ['user_id'], ['id'])
  147. op.create_foreign_key('user_group_group_id_fkey', 'user_group', 'groups', ['group_id'], ['group_id'])
  148. op.create_unique_constraint('user_group_unique', 'user_group', ['user_id', 'group_id'])
  149. if not conn:
  150. conn = op.get_bind()
  151. # Create the groups if they don't exist
  152. res = conn.execute(text("select group_id from groups where group_name = 'Administrators';"))
  153. if res.rowcount == 0:
  154. conn.execute(text(f"insert into groups (group_name, group_description, group_permissions, group_uuid, "
  155. f"group_auto_follow, group_auto_follow_access_level) "
  156. f"values ('Administrators', 'Administrators', '{ac_get_mask_full_permissions()}', '{uuid.uuid4()}',"
  157. f" true, 4);"))
  158. res = conn.execute(text("select group_id from groups where group_name = 'Administrators';"))
  159. admin_group_id = res.fetchone()[0]
  160. res = conn.execute(text("select group_id from groups where group_name = 'Analysts';"))
  161. if res.rowcount == 0:
  162. conn.execute(text(f"insert into groups (group_name, group_description, group_permissions, group_uuid, "
  163. f"group_auto_follow, group_auto_follow_access_level) "
  164. f"values ('Analysts', 'Standard Analysts', '{ac_get_mask_analyst()}', '{uuid.uuid4()}', true, 4);"))
  165. res = conn.execute(text("select group_id from groups where group_name = 'Analysts';"))
  166. analyst_group_id = res.fetchone()[0]
  167. # Create the organisations if they don't exist
  168. res = conn.execute(text("select org_id from organisations where org_name = 'Default Org';"))
  169. if res.rowcount == 0:
  170. conn.execute(text(f"insert into organisations (org_name, org_description, org_url, org_email, org_logo, "
  171. f"org_type, org_sector, org_nationality, org_uuid) values ('Default Org', 'Default Organisation', "
  172. f"'', '', "
  173. f"'','', '', '', '{uuid.uuid4()}');"))
  174. res = conn.execute(text("select org_id from organisations where org_name = 'Default Org';"))
  175. default_org_id = res.fetchone()[0]
  176. # Give the organisation access to all the cases
  177. res = conn.execute(text("select case_id from cases;"))
  178. result_cases = [case[0] for case in res.fetchall()]
  179. access_level = ac_get_mask_case_access_level_full()
  180. # Migrate the users to the new access control system
  181. conn = op.get_bind()
  182. # Get all users with their roles
  183. if _has_table("user_roles"):
  184. res = conn.execute(text("select distinct roles.name, \"user\".id from user_roles INNER JOIN \"roles\" ON "
  185. "\"roles\".id = user_roles.role_id INNER JOIN \"user\" ON \"user\".id = user_roles.user_id;"))
  186. results_users = res.fetchall()
  187. for user_id in results_users:
  188. role_name = user_id[0]
  189. user_id = user_id[1]
  190. # Migrate user to groups
  191. if role_name == 'administrator':
  192. conn.execute(text(f"insert into user_group (user_id, group_id) values ({user_id}, {admin_group_id}) "
  193. f"on conflict do nothing;"))
  194. elif role_name == 'investigator':
  195. conn.execute(text(f"insert into user_group (user_id, group_id) values ({user_id}, {analyst_group_id}) "
  196. f"on conflict do nothing;"))
  197. # Add user to default organisation
  198. conn.execute(text(f"insert into user_organisation (user_id, org_id, is_primary_org) values ({user_id}, "
  199. f"{default_org_id}, true) on conflict do nothing;"))
  200. # Add default cases effective permissions
  201. for case_id in result_cases:
  202. conn.execute(text(f"insert into user_case_effective_access (case_id, user_id, access_level) values "
  203. f"({case_id}, {user_id}, {access_level}) on conflict do nothing;"))
  204. op.drop_table('user_roles')
  205. pass
  206. def downgrade():
  207. pass