diff --git a/scripts/data_scripts/create_postgres_tables.py b/scripts/data_scripts/create_postgres_tables.py index 605d7634c..805fd7621 100644 --- a/scripts/data_scripts/create_postgres_tables.py +++ b/scripts/data_scripts/create_postgres_tables.py @@ -1,5 +1,3 @@ -import json -from azure.keyvault.secrets import SecretClient from azure.identity import DefaultAzureCredential import psycopg2 from psycopg2 import sql @@ -24,12 +22,21 @@ def grant_permissions(cursor, dbname, schema_name, principal_name): - principal_name: Name of the principal (role or user) to grant permissions. """ - add_principal_user_query = sql.SQL("SELECT * FROM pgaadauth_create_principal({principal}, false, false)") + # Check if the principal exists in the database cursor.execute( - add_principal_user_query.format( - principal=sql.Literal(principal_name), + sql.SQL("SELECT 1 FROM pg_roles WHERE rolname = {principal}").format( + principal=sql.Literal(principal_name) ) ) + if cursor.fetchone() is None: + add_principal_user_query = sql.SQL( + "SELECT * FROM pgaadauth_create_principal({principal}, false, false)" + ) + cursor.execute( + add_principal_user_query.format( + principal=sql.Literal(principal_name), + ) + ) # Grant CONNECT on database grant_connect_query = sql.SQL("GRANT CONNECT ON DATABASE {database} TO {principal}") @@ -123,7 +130,9 @@ def grant_permissions(cursor, dbname, schema_name, principal_name): conn.commit() -cursor.execute("CREATE INDEX vector_store_content_vector_idx ON vector_store USING hnsw (content_vector vector_cosine_ops);") +cursor.execute( + "CREATE INDEX vector_store_content_vector_idx ON vector_store USING hnsw (content_vector vector_cosine_ops);" +) conn.commit() grant_permissions(cursor, dbname, "public", principal_name)