diff --git a/docker_tasks/vector_ingest/handler.py b/docker_tasks/vector_ingest/handler.py
index 28c6ac1a..6c61e8a3 100644
--- a/docker_tasks/vector_ingest/handler.py
+++ b/docker_tasks/vector_ingest/handler.py
@@ -8,6 +8,13 @@ import smart_open
 from urllib.parse import urlparse
 import psycopg2
+import geopandas as gpd
+from shapely import wkb
+from geoalchemy2 import Geometry
+import sqlalchemy
+from sqlalchemy import create_engine, MetaData, Table, Column, inspect
+import concurrent.futures
+from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, INTEGER, VARCHAR, TIMESTAMP
 
 
 def download_file(file_uri: str):
@@ -45,6 +52,130 @@ def get_connection_string(secret: dict, as_uri: bool = False) -> str:
         return f"PG:host={secret['host']} dbname={secret['dbname']} user={secret['username']} password={secret['password']}"
 
 
+def get_gdf_schema(gdf, target_projection):
+    """map GeoDataFrame columns into a table schema
+
+    :param gdf: GeoDataFrame from geopandas
+    :param target_projection: srid for the table's geometry column
+    :return: list of sqlalchemy Columns
+    """
+    # map GeoDataFrame dtypes to sqlalchemy column types
+    dtype_map = {
+        "int64": INTEGER,
+        "float64": DOUBLE_PRECISION,
+        "object": VARCHAR,
+        "datetime64": TIMESTAMP,
+    }
+    schema = []
+    for column, dtype in zip(gdf.columns, gdf.dtypes):
+        if str(dtype) == "geometry":
+            # do not inspect the data to retrieve the geom type, just use generic GEOMETRY
+            # geom_type = str(gdf[column].geom_type.unique()[0]).upper()
+            geom_type = str(dtype).upper()
+            # do not take the SRID from the existing file for the target table;
+            # we always want to transform from the file's EPSG to the table's EPSG
+            column_type = Geometry(geometry_type=geom_type, srid=target_projection)
+        else:
+            dtype_str = str(dtype)
+            column_type = dtype_map.get(dtype_str.split("[")[0], VARCHAR)
+
+        if column == "primarykey":
+            schema.append(Column(column.lower(), column_type, unique=True))
+        else:
+            schema.append(Column(column.lower(), column_type))
+    return schema
+
+
+def ensure_table_exists(connection_string, gpkg_file, target_projection, table_name):
+    """create the table if it doesn't exist, otherwise
+    validate the GeoDataFrame columns against the existing table
+
+    :param connection_string: sqlalchemy connection string
+    :param gpkg_file: geopackage file location
+    :param target_projection: srid for the table's geometry column
+    :param table_name: name of table to create
+    :return: None
+    """
+    engine = create_engine(connection_string)
+    metadata = MetaData()
+    metadata.bind = engine
+
+    gdf = gpd.read_file(gpkg_file)
+    gdf_schema = get_gdf_schema(gdf, target_projection)
+    try:
+        Table(table_name, metadata, autoload_with=engine)
+    except sqlalchemy.exc.NoSuchTableError:
+        Table(table_name, metadata, *gdf_schema)
+        metadata.create_all(engine)
+
+    # validate gdf schema against the existing table schema
+    insp = inspect(engine)
+    existing_columns = insp.get_columns(table_name)
+    existing_column_names = [col["name"] for col in existing_columns]
+    for column in gdf_schema:
+        if column.name not in existing_column_names:
+            raise ValueError(
+                f"your .gpkg has a column={column.name} that does not exist in the existing table columns={existing_column_names}"
+            )
+
+
+def upsert_to_postgis(
+    connection_string, gpkg_path, target_projection, table_name, batch_size=10000
+):
+    """batch the GPKG file and upsert the batches via threads
+
+    :param connection_string: sqlalchemy connection string
+    :param gpkg_path: geopackage file location
+    :param target_projection: srid for the table's geometry column
+    :param table_name: name of table to upsert into
+    :param batch_size: number of rows per batch
+    :return: None
+    """
+    engine = create_engine(connection_string)
+    metadata = MetaData()
+    metadata.bind = engine
+
+    gdf = gpd.read_file(gpkg_path)
+    source_epsg_code = gdf.crs.to_epsg()
+    if not source_epsg_code:
+        # assume NAD27 US National Atlas Equal Area for now :shrug:
+        # since that's the default for Fire Atlas team exports
+        # and EPSG:2163 is what PROJ4 uses under the hood for 9311 :wethinksmirk:
+        source_epsg_code = 2163
+
+    # convert the `t` column to something suitable for sql insertion, otherwise we get 'Timestamp()'
+    gdf["t"] = gdf["t"].dt.strftime("%Y-%m-%d %H:%M:%S")
+    # convert the geometry column to hex-encoded WKB
+    gdf["geometry"] = gdf["geometry"].apply(lambda geom: wkb.dumps(geom, hex=True))
+
+    batches = [gdf.iloc[i : i + batch_size] for i in range(0, len(gdf), batch_size)]
+
+    def upsert_batch(batch):
+        with engine.connect() as conn:
+            with conn.begin():
+                for row in batch.to_dict(orient="records"):
+                    # make sure all column names are lower case
+                    row = {k.lower(): v for k, v in row.items()}
+                    columns = [col.lower() for col in batch.columns]
+
+                    # NOTE: this relies on `geometry` being the last column
+                    non_geom_placeholders = ", ".join(
+                        [f":{col}" for col in columns[:-1]]
+                    )
+                    # NOTE: escape `::geometry` as `\:\:` so the parameterized
+                    # statement doesn't treat it as a bind param
+                    geom_placeholder = f"ST_Transform(ST_SetSRID(ST_GeomFromWKB(:geometry\\:\\:geometry), {source_epsg_code}), {target_projection})"
+                    upsert_sql = sqlalchemy.text(
+                        f"""
+                        INSERT INTO {table_name} ({', '.join(columns)})
+                        VALUES ({non_geom_placeholders},{geom_placeholder})
+                        ON CONFLICT (primarykey)
+                        DO UPDATE SET {', '.join(f"{col}=EXCLUDED.{col}" for col in columns if col != 'primarykey')}
+                        """
+                    )
+
+                    # logging.debug(f"[ UPSERT SQL ]:\n{str(upsert_sql)}")
+                    conn.execute(upsert_sql, row)
+
+    with concurrent.futures.ThreadPoolExecutor() as executor:
+        executor.map(upsert_batch, batches)
+
+
 def get_secret(secret_name: str) -> None:
     """Retrieve secrets from AWS Secrets Manager
 
@@ -117,65 +248,53 @@ def load_to_featuresdb(
 def load_to_featuresdb_eis(
     filename: str,
     collection: str,
-    extra_flags: list = None,
-    target_projection: str = "EPSG:4326",
+    target_projection: int = 4326,
 ):
-    """
-    EIS Fire team naming convention for outputs
-    Snapshots: "snapshot_{layer_name}_nrt_{region_name}.fgb"
-    Lf_archive: "lf_{layer_name}_archive_{region_name}.fgb"
-    Lf_nrt: "lf_{layer_name}_nrt_{region_name}.fgb"
-
-    Insert on table call everything except the region name:
-    e.g. `snapshot_perimeter_nrt_conus` this gets inserted into the table `eis_fire_snapshot_perimeter_nrt`
-    """
+    # NOTE: about `collection.rsplit` below:
+    #
+    # EIS Fire team naming convention for outputs
+    # Snapshots: "snapshot_{layer_name}_nrt_{region_name}.gpkg"
+    # Lf_archive: "lf_{layer_name}_archive_{region_name}.gpkg"
+    # Lf_nrt: "lf_{layer_name}_nrt_{region_name}.gpkg"
+    #
+    # Insert/Alter on the table uses everything except the region name:
+    # e.g. `snapshot_perimeter_nrt_conus` gets inserted into the table `eis_fire_snapshot_perimeter_nrt`
     collection = collection.rsplit("_", 1)[0]
-
-    if extra_flags is None:
-        extra_flags = ["-append", "-progress"]
+    target_table_name = f"eis_fire_{collection}"
 
     secret_name = os.environ.get("VECTOR_SECRET_NAME")
+    conn_secrets = get_secret(secret_name)
+    connection_string = get_connection_string(conn_secrets, as_uri=True)
 
-    con_secrets = get_secret(secret_name)
-    connection = get_connection_string(con_secrets)
-
-    print(f"running ogr2ogr import for collection: {collection}")
-
-    out = subprocess.run(
-        [
-            "ogr2ogr",
-            "-f",
-            "PostgreSQL",
-            connection,
-            "-t_srs",
-            target_projection,
-            filename,
-            "-nln",
-            f"eis_fire_{collection}",
-            *extra_flags,
-        ],
-        check=False,
-        capture_output=True,
+    ensure_table_exists(
+        connection_string, filename, target_projection, table_name=target_table_name
+    )
+    upsert_to_postgis(
+        connection_string, filename, target_projection, table_name=target_table_name
     )
-
-    if out.stderr:
-        error_description = f"Error: {out.stderr}"
-        print(error_description)
-        return {"status": "failure", "reason": error_description}
 
     return {"status": "success"}
 
 
 def alter_datetime_add_indexes_eis(collection: str):
-    secret_name = os.environ.get("VECTOR_SECRET_NAME")
-
-    con_secrets = get_secret(secret_name)
+    # NOTE: about `collection.rsplit` below:
+    #
+    # EIS Fire team naming convention for outputs
+    # Snapshots: "snapshot_{layer_name}_nrt_{region_name}.gpkg"
+    # Lf_archive: "lf_{layer_name}_archive_{region_name}.gpkg"
+    # Lf_nrt: "lf_{layer_name}_nrt_{region_name}.gpkg"
+    #
+    # Insert/Alter on the table uses everything except the region name:
+    # e.g. `snapshot_perimeter_nrt_conus` gets inserted into the table `eis_fire_snapshot_perimeter_nrt`
+    collection = collection.rsplit("_", 1)[0]
+
+    secret_name = os.environ.get("VECTOR_SECRET_NAME")
+    conn_secrets = get_secret(secret_name)
 
     conn = psycopg2.connect(
-        host=con_secrets["host"],
-        dbname=con_secrets["dbname"],
-        user=con_secrets["username"],
-        password=con_secrets["password"],
+        host=conn_secrets["host"],
+        dbname=conn_secrets["dbname"],
+        user=conn_secrets["username"],
+        password=conn_secrets["password"],
     )
 
     cur = conn.cursor()
diff --git a/docker_tasks/vector_ingest/requirements.txt b/docker_tasks/vector_ingest/requirements.txt
index 50090b14..38263eed 100644
--- a/docker_tasks/vector_ingest/requirements.txt
+++ b/docker_tasks/vector_ingest/requirements.txt
@@ -1,4 +1,7 @@
 smart-open==6.3.0
-psycopg2-binary==2.9.6
+psycopg2-binary==2.9.9
 requests==2.30.0
-boto3==1.26.129
\ No newline at end of file
+boto3==1.26.129
+GeoAlchemy2==0.14.2
+geopandas==0.14.0
+SQLAlchemy==2.0.23
\ No newline at end of file
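
A quick sketch of the `collection.rsplit` naming convention documented in the NOTE comments above; the collection ids below are illustrative samples that follow the documented pattern, not real Fire Atlas exports:

    # strip the trailing region name to get the base collection,
    # then prefix with `eis_fire_` to get the target table name
    for collection in ("snapshot_perimeter_nrt_conus", "lf_newfirepix_archive_alaska"):
        base = collection.rsplit("_", 1)[0]
        print(f"{collection} -> eis_fire_{base}")
    # snapshot_perimeter_nrt_conus -> eis_fire_snapshot_perimeter_nrt
    # lf_newfirepix_archive_alaska -> eis_fire_lf_newfirepix_archive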
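
And a minimal sketch of the statement `upsert_batch` renders, assuming a hypothetical table with only the columns `primarykey`, `t`, and `geometry` (real tables carry more attributes) and the default source/target SRIDs from the diff:

    import sqlalchemy

    # rendered upsert for the assumed three-column table; the bind params
    # (:primarykey, :t, :geometry) are filled from each row dict
    upsert_sql = sqlalchemy.text(
        """
        INSERT INTO eis_fire_snapshot_perimeter_nrt (primarykey, t, geometry)
        VALUES (:primarykey, :t,
                ST_Transform(ST_SetSRID(ST_GeomFromWKB(:geometry\\:\\:geometry), 2163), 4326))
        ON CONFLICT (primarykey)
        DO UPDATE SET t=EXCLUDED.t, geometry=EXCLUDED.geometry
        """
    )
    # SQLAlchemy unescapes `\:\:` so PostgreSQL receives a literal `::geometry`
    # cast; rows are matched on `primarykey`, and every other column of a
    # conflicting row is overwritten with the incoming value via EXCLUDED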