Skip to content

Commit

Permalink
Add the initial shapefile datasource prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
pomadchin committed May 13, 2022
1 parent 44f4072 commit 05f8cb3
Show file tree
Hide file tree
Showing 10 changed files with 201 additions and 4 deletions.
6 changes: 5 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,11 @@ lazy val datasource = project
spark("core").value % Provided,
spark("mllib").value % Provided,
spark("sql").value % Provided,
`better-files`
`better-files`,
geotrellis("shapefile").value,
geotoolsMain,
geotoolsOpengis,
geotoolsShapefile
),
Compile / console / scalacOptions ~= { _.filterNot(Set("-Ywarn-unused-import", "-Ywarn-unused:imports")) },
Test / console / scalacOptions ~= { _.filterNot(Set("-Ywarn-unused-import", "-Ywarn-unused:imports")) },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ org.locationtech.rasterframes.datasource.raster.RasterSourceDataSource
org.locationtech.rasterframes.datasource.geojson.GeoJsonDataSource
org.locationtech.rasterframes.datasource.stac.api.StacApiDataSource
org.locationtech.rasterframes.datasource.tiles.TilesDataSource
org.locationtech.rasterframes.datasource.slippy.SlippyDataSource
org.locationtech.rasterframes.datasource.slippy.SlippyDataSource
org.locationtech.rasterframes.datasource.shapefile.ShapeFileDataSource
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import sttp.model.Uri

import java.net.URI
import java.io.File
import java.net.{URI, URL}
import scala.util.Try

/**
Expand Down Expand Up @@ -65,6 +66,25 @@ package object datasource {
if(parameters.containsKey(key)) Uri.parse(parameters.get(key)).toOption
else None

private[rasterframes]
def urlParam(key: String, parameters: Map[String, String]): Option[URL] =
parameters.get(key).flatMap { p =>
Try {
if (p.contains("://")) new URL(p)
else new URL(s"file://${new File(p).getAbsolutePath}")
}.toOption
}

private[rasterframes]
def urlParam(key: String, parameters: CaseInsensitiveStringMap): Option[URL] =
if(parameters.containsKey(key)) {
val p = parameters.get(key)
Try {
if (p.contains("://")) new URL(p)
else new URL(s"file://${new File(p).getAbsolutePath}")
}.toOption
} else None

private[rasterframes]
def jsonParam(key: String, parameters: Map[String, String]): Option[Json] =
parameters.get(key).flatMap(p => parser.parse(p).toOption)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package org.locationtech.rasterframes.datasource.shapefile

import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

import java.util

class ShapeFileDataSource extends TableProvider with DataSourceRegister {

def inferSchema(caseInsensitiveStringMap: CaseInsensitiveStringMap): StructType =
getTable(null, Array.empty[Transform], caseInsensitiveStringMap.asCaseSensitiveMap()).schema()

def getTable(structType: StructType, transforms: Array[Transform], map: util.Map[String, String]): Table =
new ShapeFileTable()

def shortName(): String = ShapeFileDataSource.SHORT_NAME
}

object ShapeFileDataSource {
final val SHORT_NAME = "shapefile"
final val URL_PARAM = "url"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package org.locationtech.rasterframes.datasource.shapefile

import org.locationtech.rasterframes.encoders.syntax._

import geotrellis.vector.Geometry
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.geotools.data.shapefile.ShapefileDataStore
import org.geotools.data.simple.SimpleFeatureIterator

import java.net.URL

case class ShapeFilePartition(url: URL) extends InputPartition

class ShapeFilePartitionReaderFactory extends PartitionReaderFactory {
override def createReader(partition: InputPartition): PartitionReader[InternalRow] = partition match {
case p: ShapeFilePartition => new ShapeFilePartitionReader(p)
case _ => throw new UnsupportedOperationException("Partition processing is unsupported by the reader.")
}
}

class ShapeFilePartitionReader(partition: ShapeFilePartition) extends PartitionReader[InternalRow] {
import geotrellis.shapefile.ShapeFileReader._

@transient lazy val ds = new ShapefileDataStore(partition.url)
@transient lazy val partitionValues: SimpleFeatureIterator = ds.getFeatureSource.getFeatures.features

def next: Boolean = partitionValues.hasNext

def get: InternalRow = partitionValues.next.geom[Geometry].toInternalRow

def close(): Unit = { partitionValues.close(); ds.dispose() }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package org.locationtech.rasterframes.datasource.shapefile

import org.locationtech.rasterframes.datasource.stac.api.encoders._
import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan, ScanBuilder}
import org.apache.spark.sql.types.StructType

import java.net.URL

class ShapeFileScanBuilder(url: URL) extends ScanBuilder {
def build(): Scan = new ShapeFileBatchScan(url)
}

/** Batch Reading Support. The schema is repeated here as it can change after column pruning, etc. */
class ShapeFileBatchScan(url: URL) extends Scan with Batch {
def readSchema(): StructType = geometryExpressionEncoder.schema

override def toBatch: Batch = this

/**
* Unfortunately, we can only load everything into a single partition, due to the nature of STAC API endpoints.
* To perform a distributed load, we'd need to know some internals about how the next page token is computed.
* This can be a good idea for the STAC Spec extension.
* */
def planInputPartitions(): Array[InputPartition] = Array(ShapeFilePartition(url))
def createReaderFactory(): PartitionReaderFactory = new ShapeFilePartitionReaderFactory()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package org.locationtech.rasterframes.datasource.shapefile

import org.locationtech.rasterframes.datasource.stac.api.encoders._

import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability}
import org.apache.spark.sql.connector.read.ScanBuilder
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.locationtech.rasterframes.datasource.shapefile.ShapeFileDataSource.URL_PARAM
import org.locationtech.rasterframes.datasource.urlParam
import java.net.URL

import scala.collection.JavaConverters._
import java.util

class ShapeFileTable extends Table with SupportsRead {
import ShapeFileTable._

def name(): String = this.getClass.toString

def schema(): StructType = geometryExpressionEncoder.schema

def capabilities(): util.Set[TableCapability] = Set(TableCapability.BATCH_READ).asJava

def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
new ShapeFileScanBuilder(options.url)
}

object ShapeFileTable {
implicit class CaseInsensitiveStringMapOps(val options: CaseInsensitiveStringMap) extends AnyVal {
def url: URL = urlParam(URL_PARAM, options).getOrElse(throw new IllegalArgumentException("Missing URL."))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package org.locationtech.rasterframes.datasource

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.locationtech.jts.geom.Geometry

package object shapefile extends Serializable {
// see org.locationtech.geomesa.spark.jts.encoders.SpatialEncoders
// GeometryUDT should be registered before the encoder below is used
// TODO: use TypedEncoders derived from UDT instances?
@transient implicit lazy val geometryExpressionEncoder: ExpressionEncoder[Option[Geometry]] = ExpressionEncoder()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package org.locationtech.rasterframes.datasource.shapefile

import org.locationtech.rasterframes._

import geotrellis.shapefile.ShapeFileReader
import org.locationtech.jts.geom.Geometry
import org.locationtech.rasterframes.TestEnvironment

import java.net.URL

class ShapeFileDataSourceTest extends TestEnvironment { self =>
import spark.implicits._

describe("ShapeFile Spark reader") {
it("should read a shapefile") {
val url = "https://github.com/locationtech/geotrellis/raw/master/shapefile/data/shapefiles/demographics/demographics.shp"
import ShapeFileReader._

val expected = ShapeFileReader
.readSimpleFeatures(new URL(url))
.map(_.geom[Geometry])
.take(2)

val results =
spark
.read
.format("shapefile")
.option("url", url)
.load()
.limit(2)

// results.printSchema()

results.as[Option[Geometry]].collect() shouldBe expected
}

}
}
8 changes: 7 additions & 1 deletion project/RFDependenciesPlugin.scala
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ object RFDependenciesPlugin extends AutoPlugin {
val frameless = "org.typelevel" %% "frameless-dataset-spark31" % "0.11.1"
val framelessRefined = "org.typelevel" %% "frameless-refined-spark31" % "0.11.1"
val `better-files` = "com.github.pathikrit" %% "better-files" % "3.9.1" % Test

val geotoolsVersion = "25.0"
val geotoolsMain = "org.geotools" % "gt-main" % geotoolsVersion
val geotoolsShapefile = "org.geotools" % "gt-shapefile" % geotoolsVersion
val geotoolsOpengis = "org.geotools" % "gt-opengis" % geotoolsVersion
}
import autoImport._

Expand All @@ -67,7 +72,8 @@ object RFDependenciesPlugin extends AutoPlugin {
"boundless-releases" at "https://repo.boundlessgeo.com/main/",
"Open Source Geospatial Foundation Repository" at "https://download.osgeo.org/webdav/geotools/",
"oss-snapshots" at "https://oss.sonatype.org/content/repositories/snapshots",
"jitpack" at "https://jitpack.io"
"jitpack" at "https://jitpack.io",
"osgeo-releases" at "https://repo.osgeo.org/repository/release/"
),
// dependencyOverrides += "com.azavea.gdal" % "gdal-warp-bindings" % "33.f746890",
// NB: Make sure to update the Spark version in pyrasterframes/python/setup.py
Expand Down

0 comments on commit 05f8cb3

Please sign in to comment.