ShapeFile DataSource prototype #585

Draft · wants to merge 1 commit into base: develop
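For context, here is a minimal usage sketch of the proposed data source (the `shapefile` short name comes from this diff; the sample path is hypothetical):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()

// Read a shapefile's geometries through the new DSv2 source.
// "url" accepts a full URL or a bare local path (resolved to a file:// URL).
val df = spark.read
  .format("shapefile")
  .option("url", "/tmp/demographics.shp") // hypothetical path
  .load()

df.printSchema() // a single geometry column, per geometryExpressionEncoder
```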
6 changes: 5 additions & 1 deletion build.sbt
@@ -128,7 +128,11 @@ lazy val datasource = project
spark("core").value % Provided,
spark("mllib").value % Provided,
spark("sql").value % Provided,
-`better-files`
+`better-files`,
+geotrellis("shapefile").value,
+geotoolsMain,
+geotoolsOpengis,
+geotoolsShapefile
),
Compile / console / scalacOptions ~= { _.filterNot(Set("-Ywarn-unused-import", "-Ywarn-unused:imports")) },
Test / console / scalacOptions ~= { _.filterNot(Set("-Ywarn-unused-import", "-Ywarn-unused:imports")) },
@@ -5,4 +5,5 @@ org.locationtech.rasterframes.datasource.raster.RasterSourceDataSource
org.locationtech.rasterframes.datasource.geojson.GeoJsonDataSource
org.locationtech.rasterframes.datasource.stac.api.StacApiDataSource
org.locationtech.rasterframes.datasource.tiles.TilesDataSource
-org.locationtech.rasterframes.datasource.slippy.SlippyDataSource
+org.locationtech.rasterframes.datasource.slippy.SlippyDataSource
+org.locationtech.rasterframes.datasource.shapefile.ShapeFileDataSource
@@ -28,7 +28,8 @@ import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import sttp.model.Uri

-import java.net.URI
+import java.io.File
+import java.net.{URI, URL}
import scala.util.Try

/**
@@ -65,6 +66,25 @@ package object datasource {
if(parameters.containsKey(key)) Uri.parse(parameters.get(key)).toOption
else None

private[rasterframes]
def urlParam(key: String, parameters: Map[String, String]): Option[URL] =
parameters.get(key).flatMap { p =>
Try {
if (p.contains("://")) new URL(p)
else new URL(s"file://${new File(p).getAbsolutePath}")
}.toOption
}

private[rasterframes]
def urlParam(key: String, parameters: CaseInsensitiveStringMap): Option[URL] =
if(parameters.containsKey(key)) {
val p = parameters.get(key)
Try {
if (p.contains("://")) new URL(p)
else new URL(s"file://${new File(p).getAbsolutePath}")
}.toOption
} else None

private[rasterframes]
def jsonParam(key: String, parameters: Map[String, String]): Option[Json] =
parameters.get(key).flatMap(p => parser.parse(p).toOption)
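A test-style sketch of the expected `urlParam` behavior (placed in the same package because the helper is `private[rasterframes]`; paths and keys are illustrative):

```scala
package org.locationtech.rasterframes.datasource

object UrlParamSketch extends App {
  // Absolute URLs pass through unchanged.
  assert(urlParam("url", Map("url" -> "https://example.com/data.shp"))
    .map(_.toString).contains("https://example.com/data.shp"))

  // Bare paths resolve against the working directory as file:// URLs.
  println(urlParam("url", Map("url" -> "data/demo.shp")))
  // e.g. Some(file:///home/user/data/demo.shp)

  // Values that fail to parse as URLs yield None rather than throwing.
  assert(urlParam("url", Map("url" -> "ht!tp://example.com")).isEmpty)
}
```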
@@ -0,0 +1,25 @@
package org.locationtech.rasterframes.datasource.shapefile

import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

import java.util

class ShapeFileDataSource extends TableProvider with DataSourceRegister {

def inferSchema(caseInsensitiveStringMap: CaseInsensitiveStringMap): StructType =
getTable(null, Array.empty[Transform], caseInsensitiveStringMap.asCaseSensitiveMap()).schema()

def getTable(structType: StructType, transforms: Array[Transform], map: util.Map[String, String]): Table =
new ShapeFileTable()

def shortName(): String = ShapeFileDataSource.SHORT_NAME
}

object ShapeFileDataSource {
final val SHORT_NAME = "shapefile"
final val URL_PARAM = "url"
}
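For reference, DSv2 resolves this provider in two steps: Spark calls `inferSchema` first, then `getTable`. A simplified sketch of that calling convention (not Spark's actual internals; the path is hypothetical):

```scala
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import scala.collection.JavaConverters._

val provider = new ShapeFileDataSource()
val options = new CaseInsensitiveStringMap(
  Map("url" -> "file:///tmp/demo.shp").asJava)

val schema = provider.inferSchema(options) // delegates to ShapeFileTable.schema()
val table = provider.getTable(schema, Array.empty[Transform], options.asCaseSensitiveMap())
```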
@@ -0,0 +1,33 @@
package org.locationtech.rasterframes.datasource.shapefile

import org.locationtech.rasterframes.encoders.syntax._

import geotrellis.vector.Geometry
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.geotools.data.shapefile.ShapefileDataStore
import org.geotools.data.simple.SimpleFeatureIterator

import java.net.URL

case class ShapeFilePartition(url: URL) extends InputPartition

class ShapeFilePartitionReaderFactory extends PartitionReaderFactory {
override def createReader(partition: InputPartition): PartitionReader[InternalRow] = partition match {
case p: ShapeFilePartition => new ShapeFilePartitionReader(p)
case _ => throw new UnsupportedOperationException("Partition processing is unsupported by the reader.")
}
}

class ShapeFilePartitionReader(partition: ShapeFilePartition) extends PartitionReader[InternalRow] {
import geotrellis.shapefile.ShapeFileReader._

@transient lazy val ds = new ShapefileDataStore(partition.url)
@transient lazy val partitionValues: SimpleFeatureIterator = ds.getFeatureSource.getFeatures.features

def next: Boolean = partitionValues.hasNext

def get: InternalRow = partitionValues.next.geom[Geometry].toInternalRow

def close(): Unit = { partitionValues.close(); ds.dispose() }
}
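For reviewers less familiar with DSv2: Spark drives a `PartitionReader` with a next/get loop roughly like the following (simplified sketch, not Spark's actual code; the URL is hypothetical):

```scala
import java.net.URL

val reader = new ShapeFilePartitionReaderFactory()
  .createReader(ShapeFilePartition(new URL("file:///tmp/demo.shp")))

try {
  while (reader.next()) {
    val row = reader.get() // one InternalRow per feature geometry
    // ... hand the row to the downstream operator ...
  }
} finally {
  reader.close() // closes the feature iterator and disposes the data store
}
```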
@@ -0,0 +1,21 @@
package org.locationtech.rasterframes.datasource.shapefile

import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan, ScanBuilder}
import org.apache.spark.sql.types.StructType

import java.net.URL

class ShapeFileScanBuilder(url: URL) extends ScanBuilder {
def build(): Scan = new ShapeFileBatchScan(url)
}

/** Batch Reading Support. The schema is repeated here as it can change after column pruning, etc. */
class ShapeFileBatchScan(url: URL) extends Scan with Batch {
def readSchema(): StructType = geometryExpressionEncoder.schema

override def toBatch: Batch = this

/** Unfortunately, the file can only be loaded into a single partition. */
def planInputPartitions(): Array[InputPartition] = Array(ShapeFilePartition(url))
def createReaderFactory(): PartitionReaderFactory = new ShapeFilePartitionReaderFactory()
}
@@ -0,0 +1,31 @@
package org.locationtech.rasterframes.datasource.shapefile

import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability}
import org.apache.spark.sql.connector.read.ScanBuilder
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.locationtech.rasterframes.datasource.shapefile.ShapeFileDataSource.URL_PARAM
import org.locationtech.rasterframes.datasource.urlParam
import java.net.URL

import scala.collection.JavaConverters._
import java.util

class ShapeFileTable extends Table with SupportsRead {
import ShapeFileTable._

def name(): String = this.getClass.toString

def schema(): StructType = geometryExpressionEncoder.schema

def capabilities(): util.Set[TableCapability] = Set(TableCapability.BATCH_READ).asJava

def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
new ShapeFileScanBuilder(options.url)
}

object ShapeFileTable {
implicit class CaseInsensitiveStringMapOps(val options: CaseInsensitiveStringMap) extends AnyVal {
def url: URL = urlParam(URL_PARAM, options).getOrElse(throw new IllegalArgumentException("Missing URL."))
}
}
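A small sketch of the expected behavior of the `url` extension above (option maps are illustrative):

```scala
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.locationtech.rasterframes.datasource.shapefile.ShapeFileTable._
import scala.collection.JavaConverters._

val ok = new CaseInsensitiveStringMap(Map("URL" -> "file:///tmp/demo.shp").asJava)
println(ok.url) // file:///tmp/demo.shp (keys match case-insensitively)

val missing = new CaseInsensitiveStringMap(Map.empty[String, String].asJava)
// missing.url throws IllegalArgumentException("Missing URL.")
```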
@@ -0,0 +1,11 @@
package org.locationtech.rasterframes.datasource

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.locationtech.jts.geom.Geometry

package object shapefile extends Serializable {
// see org.locationtech.geomesa.spark.jts.encoders.SpatialEncoders
// GeometryUDT should be registered before the encoder below is used
// TODO: use TypedEncoders derived from UDT instances?
@transient implicit lazy val geometryExpressionEncoder: ExpressionEncoder[Option[Geometry]] = ExpressionEncoder()
}
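Given the registration caveat in the comment above, end-to-end typed usage would look roughly like this sketch (assumes the geometry UDT is registered first, e.g. via RasterFrames' `spark.withRasterFrames`; the path is hypothetical):

```scala
import org.apache.spark.sql.{Dataset, SparkSession}
import org.locationtech.jts.geom.Geometry
import org.locationtech.rasterframes._
import org.locationtech.rasterframes.datasource.shapefile._ // brings the encoder into scope

val spark = SparkSession.builder().master("local[*]").getOrCreate().withRasterFrames

val geoms: Dataset[Option[Geometry]] = spark.read
  .format("shapefile")
  .option("url", "/tmp/demo.shp")
  .load()
  .as[Option[Geometry]]
```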
@@ -0,0 +1,35 @@
package org.locationtech.rasterframes.datasource.shapefile

import geotrellis.shapefile.ShapeFileReader
import org.locationtech.jts.geom.Geometry
import org.locationtech.rasterframes.TestEnvironment

import java.net.URL

class ShapeFileDataSourceTest extends TestEnvironment { self =>
import spark.implicits._

describe("ShapeFile Spark reader") {
it("should read a shapefile") {
val url = "https://github.com/locationtech/geotrellis/raw/master/shapefile/data/shapefiles/demographics/demographics.shp"
import ShapeFileReader._

val expected = ShapeFileReader
.readSimpleFeatures(new URL(url))
.map(_.geom[Geometry])
.take(2)

val results =
spark
.read
.format("shapefile")
.option("url", url)
.load()
.limit(2)

results.printSchema()

results.as[Option[Geometry]].collect() shouldBe expected
}
}
}
8 changes: 7 additions & 1 deletion project/RFDependenciesPlugin.scala
@@ -57,6 +57,11 @@ object RFDependenciesPlugin extends AutoPlugin {
val frameless = "org.typelevel" %% "frameless-dataset-spark31" % "0.11.1"
val framelessRefined = "org.typelevel" %% "frameless-refined-spark31" % "0.11.1"
val `better-files` = "com.github.pathikrit" %% "better-files" % "3.9.1" % Test

val geotoolsVersion = "25.0"
val geotoolsMain = "org.geotools" % "gt-main" % geotoolsVersion
val geotoolsShapefile = "org.geotools" % "gt-shapefile" % geotoolsVersion
val geotoolsOpengis = "org.geotools" % "gt-opengis" % geotoolsVersion
}
import autoImport._

@@ -67,7 +72,8 @@ object RFDependenciesPlugin extends AutoPlugin {
"boundless-releases" at "https://repo.boundlessgeo.com/main/",
"Open Source Geospatial Foundation Repository" at "https://download.osgeo.org/webdav/geotools/",
"oss-snapshots" at "https://oss.sonatype.org/content/repositories/snapshots",
"jitpack" at "https://jitpack.io"
"jitpack" at "https://jitpack.io",
"osgeo-releases" at "https://repo.osgeo.org/repository/release/"
),
// dependencyOverrides += "com.azavea.gdal" % "gdal-warp-bindings" % "33.f746890",
// NB: Make sure to update the Spark version in pyrasterframes/python/setup.py