diff --git a/build.sbt b/build.sbt
index d54d0cac8..7835a0a09 100644
--- a/build.sbt
+++ b/build.sbt
@@ -128,7 +128,11 @@ lazy val datasource = project
     spark("core").value % Provided,
     spark("mllib").value % Provided,
     spark("sql").value % Provided,
-    `better-files`
+    `better-files`,
+    geotrellis("shapefile").value,
+    geotoolsMain,
+    geotoolsOpengis,
+    geotoolsShapefile
   ),
   Compile / console / scalacOptions ~= { _.filterNot(Set("-Ywarn-unused-import", "-Ywarn-unused:imports")) },
   Test / console / scalacOptions ~= { _.filterNot(Set("-Ywarn-unused-import", "-Ywarn-unused:imports")) },
diff --git a/datasource/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/datasource/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
index e5e28792e..e3cf920e0 100644
--- a/datasource/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
+++ b/datasource/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -5,4 +5,5 @@ org.locationtech.rasterframes.datasource.raster.RasterSourceDataSource
 org.locationtech.rasterframes.datasource.geojson.GeoJsonDataSource
 org.locationtech.rasterframes.datasource.stac.api.StacApiDataSource
 org.locationtech.rasterframes.datasource.tiles.TilesDataSource
-org.locationtech.rasterframes.datasource.slippy.SlippyDataSource
\ No newline at end of file
+org.locationtech.rasterframes.datasource.slippy.SlippyDataSource
+org.locationtech.rasterframes.datasource.shapefile.ShapeFileDataSource
diff --git a/datasource/src/main/scala/org/locationtech/rasterframes/datasource/package.scala b/datasource/src/main/scala/org/locationtech/rasterframes/datasource/package.scala
index 71e925cc7..21395ac86 100644
--- a/datasource/src/main/scala/org/locationtech/rasterframes/datasource/package.scala
+++ b/datasource/src/main/scala/org/locationtech/rasterframes/datasource/package.scala
@@ -28,7 +28,8 @@ import org.apache.spark.sql.{Column, DataFrame}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import sttp.model.Uri
 
-import java.net.URI
+import java.io.File
+import java.net.{URI, URL}
 import scala.util.Try
 
 /**
@@ -65,6 +66,25 @@ package object datasource {
     if(parameters.containsKey(key)) Uri.parse(parameters.get(key)).toOption
     else None
 
+  private[rasterframes]
+  def urlParam(key: String, parameters: Map[String, String]): Option[URL] =
+    parameters.get(key).flatMap { p =>
+      Try {
+        if (p.contains("://")) new URL(p)
+        else new URL(s"file://${new File(p).getAbsolutePath}")
+      }.toOption
+    }
+
+  private[rasterframes]
+  def urlParam(key: String, parameters: CaseInsensitiveStringMap): Option[URL] =
+    if(parameters.containsKey(key)) {
+      val p = parameters.get(key)
+      Try {
+        if (p.contains("://")) new URL(p)
+        else new URL(s"file://${new File(p).getAbsolutePath}")
+      }.toOption
+    } else None
+
   private[rasterframes]
   def jsonParam(key: String, parameters: Map[String, String]): Option[Json] =
     parameters.get(key).flatMap(p => parser.parse(p).toOption)
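A note on the two `urlParam` overloads added above: strings that already carry a scheme are parsed as-is, while bare filesystem paths are promoted to `file://` URLs. A minimal standalone sketch of that fallback (the paths are illustrative only):

```scala
import java.io.File
import java.net.URL
import scala.util.Try

// Same resolution rule as urlParam: keep an explicit scheme,
// otherwise treat the string as a local path.
def toUrl(p: String): Option[URL] = Try {
  if (p.contains("://")) new URL(p)
  else new URL(s"file://${new File(p).getAbsolutePath}")
}.toOption

toUrl("https://example.com/demo.shp") // Some(https://example.com/demo.shp)
toUrl("/tmp/demo.shp")                // Some(file:///tmp/demo.shp)
```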
diff --git a/datasource/src/main/scala/org/locationtech/rasterframes/datasource/shapefile/ShapeFileDataSource.scala b/datasource/src/main/scala/org/locationtech/rasterframes/datasource/shapefile/ShapeFileDataSource.scala
new file mode 100644
index 000000000..b2f709674
--- /dev/null
+++ b/datasource/src/main/scala/org/locationtech/rasterframes/datasource/shapefile/ShapeFileDataSource.scala
@@ -0,0 +1,25 @@
+package org.locationtech.rasterframes.datasource.shapefile
+
+import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+import java.util
+
+class ShapeFileDataSource extends TableProvider with DataSourceRegister {
+
+  def inferSchema(caseInsensitiveStringMap: CaseInsensitiveStringMap): StructType =
+    getTable(null, Array.empty[Transform], caseInsensitiveStringMap.asCaseSensitiveMap()).schema()
+
+  def getTable(structType: StructType, transforms: Array[Transform], map: util.Map[String, String]): Table =
+    new ShapeFileTable()
+
+  def shortName(): String = ShapeFileDataSource.SHORT_NAME
+}
+
+object ShapeFileDataSource {
+  final val SHORT_NAME = "shapefile"
+  final val URL_PARAM = "url"
+}
diff --git a/datasource/src/main/scala/org/locationtech/rasterframes/datasource/shapefile/ShapeFilePartition.scala b/datasource/src/main/scala/org/locationtech/rasterframes/datasource/shapefile/ShapeFilePartition.scala
new file mode 100644
index 000000000..1e7459722
--- /dev/null
+++ b/datasource/src/main/scala/org/locationtech/rasterframes/datasource/shapefile/ShapeFilePartition.scala
@@ -0,0 +1,33 @@
+package org.locationtech.rasterframes.datasource.shapefile
+
+import org.locationtech.rasterframes.encoders.syntax._
+
+import geotrellis.vector.Geometry
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
+import org.geotools.data.shapefile.ShapefileDataStore
+import org.geotools.data.simple.SimpleFeatureIterator
+
+import java.net.URL
+
+case class ShapeFilePartition(url: URL) extends InputPartition
+
+class ShapeFilePartitionReaderFactory extends PartitionReaderFactory {
+  override def createReader(partition: InputPartition): PartitionReader[InternalRow] = partition match {
+    case p: ShapeFilePartition => new ShapeFilePartitionReader(p)
+    case _ => throw new UnsupportedOperationException("Partition processing is unsupported by the reader.")
+  }
+}
+
+class ShapeFilePartitionReader(partition: ShapeFilePartition) extends PartitionReader[InternalRow] {
+  import geotrellis.shapefile.ShapeFileReader._
+
+  @transient lazy val ds = new ShapefileDataStore(partition.url)
+  @transient lazy val partitionValues: SimpleFeatureIterator = ds.getFeatureSource.getFeatures.features
+
+  def next: Boolean = partitionValues.hasNext
+
+  def get: InternalRow = partitionValues.next.geom[Geometry].toInternalRow
+
+  def close(): Unit = { partitionValues.close(); ds.dispose() }
+}
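For reference, the reader above is a thin wrapper over GeoTools' shapefile store plus the `geom` accessor imported from GeoTrellis' `ShapeFileReader`. Stripped of the Spark connector types, it amounts to this sketch (the `file:///tmp/demo.shp` URL is a placeholder):

```scala
import geotrellis.shapefile.ShapeFileReader._
import geotrellis.vector.Geometry
import org.geotools.data.shapefile.ShapefileDataStore
import java.net.URL

val ds = new ShapefileDataStore(new URL("file:///tmp/demo.shp"))
val features = ds.getFeatureSource.getFeatures.features
try {
  // Same iteration the PartitionReader performs; geom[Geometry]
  // yields an Option, matching the Option[Geometry] encoder.
  while (features.hasNext) println(features.next.geom[Geometry])
} finally {
  features.close()
  ds.dispose()
}
```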
diff --git a/datasource/src/main/scala/org/locationtech/rasterframes/datasource/shapefile/ShapeFileScanBuilder.scala b/datasource/src/main/scala/org/locationtech/rasterframes/datasource/shapefile/ShapeFileScanBuilder.scala
new file mode 100644
index 000000000..032f80059
--- /dev/null
+++ b/datasource/src/main/scala/org/locationtech/rasterframes/datasource/shapefile/ShapeFileScanBuilder.scala
@@ -0,0 +1,21 @@
+package org.locationtech.rasterframes.datasource.shapefile
+
+import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan, ScanBuilder}
+import org.apache.spark.sql.types.StructType
+
+import java.net.URL
+
+class ShapeFileScanBuilder(url: URL) extends ScanBuilder {
+  def build(): Scan = new ShapeFileBatchScan(url)
+}
+
+/** Batch reading support. The schema is repeated here as it can change after column pruning, etc. */
+class ShapeFileBatchScan(url: URL) extends Scan with Batch {
+  def readSchema(): StructType = geometryExpressionEncoder.schema
+
+  override def toBatch: Batch = this
+
+  /** Unfortunately, a single shapefile can only be loaded into one partition. */
+  def planInputPartitions(): Array[InputPartition] = Array(ShapeFilePartition(url))
+  def createReaderFactory(): PartitionReaderFactory = new ShapeFilePartitionReaderFactory()
+}
diff --git a/datasource/src/main/scala/org/locationtech/rasterframes/datasource/shapefile/ShapeFileTable.scala b/datasource/src/main/scala/org/locationtech/rasterframes/datasource/shapefile/ShapeFileTable.scala
new file mode 100644
index 000000000..20ea223a3
--- /dev/null
+++ b/datasource/src/main/scala/org/locationtech/rasterframes/datasource/shapefile/ShapeFileTable.scala
@@ -0,0 +1,31 @@
+package org.locationtech.rasterframes.datasource.shapefile
+
+import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability}
+import org.apache.spark.sql.connector.read.ScanBuilder
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.locationtech.rasterframes.datasource.shapefile.ShapeFileDataSource.URL_PARAM
+import org.locationtech.rasterframes.datasource.urlParam
+import java.net.URL
+
+import scala.collection.JavaConverters._
+import java.util
+
+class ShapeFileTable extends Table with SupportsRead {
+  import ShapeFileTable._
+
+  def name(): String = this.getClass.toString
+
+  def schema(): StructType = geometryExpressionEncoder.schema
+
+  def capabilities(): util.Set[TableCapability] = Set(TableCapability.BATCH_READ).asJava
+
+  def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
+    new ShapeFileScanBuilder(options.url)
+}
+
+object ShapeFileTable {
+  implicit class CaseInsensitiveStringMapOps(val options: CaseInsensitiveStringMap) extends AnyVal {
+    def url: URL = urlParam(URL_PARAM, options).getOrElse(throw new IllegalArgumentException("Missing URL."))
+  }
+}
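With the table resolving `URL_PARAM` through `urlParam`, the end-user read path looks like the sketch below (the test at the end of this diff exercises the same path against a real shapefile; the local path here is a placeholder):

```scala
val df = spark.read
  .format("shapefile")            // shortName registered via DataSourceRegister
  .option("url", "/tmp/demo.shp") // bare paths are resolved to file:// URLs
  .load()

df.printSchema() // a single geometry column, per geometryExpressionEncoder
```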
diff --git a/datasource/src/main/scala/org/locationtech/rasterframes/datasource/shapefile/package.scala b/datasource/src/main/scala/org/locationtech/rasterframes/datasource/shapefile/package.scala
new file mode 100644
index 000000000..37ee3febd
--- /dev/null
+++ b/datasource/src/main/scala/org/locationtech/rasterframes/datasource/shapefile/package.scala
@@ -0,0 +1,11 @@
+package org.locationtech.rasterframes.datasource
+
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.locationtech.jts.geom.Geometry
+
+package object shapefile extends Serializable {
+  // see org.locationtech.geomesa.spark.jts.encoders.SpatialEncoders
+  // GeometryUDT should be registered before the encoder below is used
+  // TODO: use TypedEncoders derived from UDT instances?
+  @transient implicit lazy val geometryExpressionEncoder: ExpressionEncoder[Option[Geometry]] = ExpressionEncoder()
+}
diff --git a/datasource/src/test/scala/org/locationtech/rasterframes/datasource/shapefile/ShapeFileDataSourceTest.scala b/datasource/src/test/scala/org/locationtech/rasterframes/datasource/shapefile/ShapeFileDataSourceTest.scala
new file mode 100644
index 000000000..55839de20
--- /dev/null
+++ b/datasource/src/test/scala/org/locationtech/rasterframes/datasource/shapefile/ShapeFileDataSourceTest.scala
@@ -0,0 +1,35 @@
+package org.locationtech.rasterframes.datasource.shapefile
+
+import geotrellis.shapefile.ShapeFileReader
+import org.locationtech.jts.geom.Geometry
+import org.locationtech.rasterframes.TestEnvironment
+
+import java.net.URL
+
+class ShapeFileDataSourceTest extends TestEnvironment { self =>
+  import spark.implicits._
+
+  describe("ShapeFile Spark reader") {
+    it("should read a shapefile") {
+      val url = "https://github.com/locationtech/geotrellis/raw/master/shapefile/data/shapefiles/demographics/demographics.shp"
+      import ShapeFileReader._
+
+      val expected = ShapeFileReader
+        .readSimpleFeatures(new URL(url))
+        .map(_.geom[Geometry])
+        .take(2)
+
+      val results =
+        spark
+          .read
+          .format("shapefile")
+          .option("url", url)
+          .load()
+          .limit(2)
+
+      results.printSchema()
+
+      results.as[Option[Geometry]].collect() shouldBe expected
+    }
+  }
+}
diff --git a/project/RFDependenciesPlugin.scala b/project/RFDependenciesPlugin.scala
index 8ac74a84f..2729090e2 100644
--- a/project/RFDependenciesPlugin.scala
+++ b/project/RFDependenciesPlugin.scala
@@ -57,6 +57,11 @@ object RFDependenciesPlugin extends AutoPlugin {
     val frameless = "org.typelevel" %% "frameless-dataset-spark31" % "0.11.1"
     val framelessRefined = "org.typelevel" %% "frameless-refined-spark31" % "0.11.1"
     val `better-files` = "com.github.pathikrit" %% "better-files" % "3.9.1" % Test
+
+    val geotoolsVersion = "25.0"
+    val geotoolsMain = "org.geotools" % "gt-main" % geotoolsVersion
+    val geotoolsShapefile = "org.geotools" % "gt-shapefile" % geotoolsVersion
+    val geotoolsOpengis = "org.geotools" % "gt-opengis" % geotoolsVersion
   }
 
   import autoImport._
@@ -67,7 +72,8 @@ object RFDependenciesPlugin extends AutoPlugin {
     "boundless-releases" at "https://repo.boundlessgeo.com/main/",
     "Open Source Geospatial Foundation Repository" at "https://download.osgeo.org/webdav/geotools/",
    "oss-snapshots" at "https://oss.sonatype.org/content/repositories/snapshots",
-    "jitpack" at "https://jitpack.io"
+    "jitpack" at "https://jitpack.io",
+    "osgeo-releases" at "https://repo.osgeo.org/repository/release/"
   ),
   // dependencyOverrides += "com.azavea.gdal" % "gdal-warp-bindings" % "33.f746890",
   // NB: Make sure to update the Spark version in pyrasterframes/python/setup.py
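One caveat carried in the comments of `shapefile/package.scala`: `GeometryUDT` must be registered before `geometryExpressionEncoder` is first used. In RasterFrames that normally happens at session initialization; a sketch, assuming the usual `withRasterFrames` session enrichment:

```scala
import org.apache.spark.sql.SparkSession
import org.locationtech.rasterframes._

// withRasterFrames registers the RasterFrames UDTs (including the
// JTS GeometryUDT), so the Option[Geometry] encoder resolves correctly.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("shapefile-demo")
  .getOrCreate()
  .withRasterFrames
```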