From a751468997e4bed48a34fbc488dcffae98a493b3 Mon Sep 17 00:00:00 2001 From: Jamie Thompson Date: Sat, 10 Feb 2024 22:35:20 +0100 Subject: [PATCH] split logic for extraction --- _docs/index.md | 64 ++++++ project.scala | 1 + .../stringmatching/regex/Interpolators.scala | 5 +- .../scala/stringmatching/regex/Macros.scala | 66 ++++-- .../scala/stringmatching/regex/Runtime.scala | 210 +++++++++++------- 5 files changed, 241 insertions(+), 105 deletions(-) diff --git a/_docs/index.md b/_docs/index.md index 66db09e..35c1c27 100644 --- a/_docs/index.md +++ b/_docs/index.md @@ -14,3 +14,67 @@ In one pattern, you can extract a typed value `xs: IndexedSeq[Int]` as follows: "[23, 56, 71]" match case r"[${r"$xs%d"}...(, )]" => xs.sum // 150 ``` + +## Possible Formats + +### String Pattern + +e.g. `$foo`, which extracts `val foo: String`. + +### Int Pattern + +e.g. `$foo%d`, which extracts `val foo: Int`. + +### Long Pattern + +e.g. `$foo%L`, which extracts `val foo: Long`. + +### Float Pattern + +e.g. `$foo%f`, which extracts `val foo: Float`. + +### Double Pattern + +e.g. `$foo%g`, which extracts `val foo: Double`. + +### Split Pattern + +e.g. `$foo...(<regex>)`, which extracts `val foo: IndexedSeq[String]`. + +This is equivalent to extracting with `$foo` and then performing `foo.split(raw"<regex>").toIndexedSeq`. + +This means that inside the `<regex>` you may put any valid regex accepted by `scala.util.matching.Regex`. +String escape characters are also not processed within the regex. + +There is also a special case: if the first element of the sequence is expected to be empty, you can drop it with the `$foo..!(<regex>)` pattern. + + +Putting this all together, you could split Windows-style path strings with the following pattern: + +```scala sc:nocompile +raw"C:\foo\bar\baz.pdf" match + case r"C:$elems..!(\\)" => elems.mkString("/") +// yields "foo/bar/baz.pdf" +``` + +### Nested Patterns + +The `r` interpolator can also match on `Seq` of strings, arbitrarily nested. 
+ +For example + +```scala sc:nocompile +val strings: Seq[String] = ??? + +val foo: Seq[Int] = strings match + case r"$foo%d" => foo +``` + +or even + +```scala sc:nocompile +val stringss: Seq[Seq[String]] = ??? + +val foo: Seq[Seq[Int]] = stringss match + case r"$foo%d" => foo +``` diff --git a/project.scala b/project.scala index 11e3273..ba6becf 100644 --- a/project.scala +++ b/project.scala @@ -1,5 +1,6 @@ // Main //> using scala "3.4.0-RC3" +//> using options -source:future -Yexplicit-nulls //> using options -project enhanced-string-interpolator -siteroot ${.} //> using publish.ci.computeVersion "git:tag" diff --git a/src/main/scala/stringmatching/regex/Interpolators.scala b/src/main/scala/stringmatching/regex/Interpolators.scala index 2e50422..c182427 100644 --- a/src/main/scala/stringmatching/regex/Interpolators.scala +++ b/src/main/scala/stringmatching/regex/Interpolators.scala @@ -21,7 +21,10 @@ object Interpolators: end PatternElement /** Holder for the pattern elements described by a string interpolated with `r`. 
*/ - case class Pattern(elements: Seq[PatternElement]) + enum Pattern: + case Literal(glob: String) + case Single(glob: String, pattern: PatternElement) + case Multiple(glob: String, patterns: Seq[PatternElement]) extension (inline sc: StringContext) /** use in patterns like `case r"$foo...(, )" => println(foo)` */ diff --git a/src/main/scala/stringmatching/regex/Macros.scala b/src/main/scala/stringmatching/regex/Macros.scala index e73e0c7..125ffdb 100644 --- a/src/main/scala/stringmatching/regex/Macros.scala +++ b/src/main/scala/stringmatching/regex/Macros.scala @@ -20,7 +20,8 @@ object Macros: '{ new RSStringContext[t]($patternExpr) } end rsApplyExpr - /** Process a `RSStringContext` into a well-typed call to [[stringmatching.regex.Runtime.extract]] + /** Process a `RSStringContext` into a well-typed call to + * [[stringmatching.regex.Runtime.unsafeExtract]] */ def rsUnapplyExpr[R: Type, Base: Type]( rsSCExpr: Expr[RSStringContext[R]], @@ -35,7 +36,7 @@ object Macros: returnType match case '[t] => '{ - Runtime.extract[Base, t]($patternExpr.elements, levels = $levelsExpr)($scrutinee) + Runtime.unsafeExtract[Base, t]($patternExpr, levels = $levelsExpr)($scrutinee) } end match end rsUnapplyExpr @@ -43,8 +44,16 @@ object Macros: private object Reify: given PatternToExpr: ToExpr[Pattern] with - def apply(pattern: Pattern)(using Quotes): Expr[Pattern] = - '{ Pattern(${ Expr.ofSeq(pattern.elements.map(Expr(_))) }) } + import Pattern.* + + def apply(pattern: Pattern)(using Quotes): Expr[Pattern] = pattern match + case Literal(glob) => + '{ Literal(${ Expr(glob) }) } + case Single(glob, pattern) => + '{ Single(${ Expr(glob) }, ${ Expr(pattern) }) } + case Multiple(glob, patterns) => + '{ Multiple(${ Expr(glob) }, ${ Expr.ofSeq(patterns.map(Expr(_))) }) } + end PatternToExpr given FormatPatternToExpr: ToExpr[FormatPattern] with import FormatPattern.* @@ -145,29 +154,40 @@ object Macros: case _ => report.errorAndAbort(s"unsupported format: `%$format`") case rest => 
PatternElement.Glob(globPattern(rest)) - Pattern(PatternElement.Glob(globPattern(g)) +: rest0) + + val g0 = globPattern(g) + + if rest0.isEmpty then Pattern.Literal(g0) + else if rest0.sizeIs == 1 then Pattern.Single(g0, rest0.head) + else Pattern.Multiple(g0, rest0) end parsed private def refineResult(pattern: Pattern)(using Quotes): quotes.reflect.TypeRepr = import quotes.reflect.* - val args = pattern.elements - .drop(1) - .map: - case PatternElement.Glob(_) => TypeRepr.of[String] - case PatternElement.Split(_, _) => TypeRepr.of[IndexedSeq[String]] - case PatternElement.SplitEmpty(_, _) => TypeRepr.of[IndexedSeq[String]] - case PatternElement.Format(format, _) => - format match - case FormatPattern.AsInt => TypeRepr.of[Int] - case FormatPattern.AsLong => TypeRepr.of[Long] - case FormatPattern.AsDouble => TypeRepr.of[Double] - case FormatPattern.AsFloat => TypeRepr.of[Float] - if args.size == 0 then TypeRepr.of[EmptyTuple] - else if args.size == 1 then args.head - else if args.size <= 22 then AppliedType(defn.TupleClass(args.size).typeRef, args.toList) - else - report.errorAndAbort(s"too many captures: ${args.size} (implementation restriction: max 22)") - end if + + def typeOfPattern(element: PatternElement) = element match + case PatternElement.Glob(_) => TypeRepr.of[String] + case PatternElement.Split(_, _) => TypeRepr.of[IndexedSeq[String]] + case PatternElement.SplitEmpty(_, _) => TypeRepr.of[IndexedSeq[String]] + case PatternElement.Format(format, _) => + format match + case FormatPattern.AsInt => TypeRepr.of[Int] + case FormatPattern.AsLong => TypeRepr.of[Long] + case FormatPattern.AsDouble => TypeRepr.of[Double] + case FormatPattern.AsFloat => TypeRepr.of[Float] + + pattern match + case Pattern.Literal(_) => TypeRepr.of[EmptyTuple] + case Pattern.Single(_, pattern) => typeOfPattern(pattern) + case Pattern.Multiple(_, elements) => + val args = elements.map(typeOfPattern) + if args.size <= 22 then AppliedType(defn.TupleClass(args.size).typeRef, args.toList) 
+ else + report.errorAndAbort( + s"too many captures: ${args.size} (implementation restriction: max 22)" + ) + end if + end match end refineResult private def wrapping[Base: Type](using Quotes): Int = diff --git a/src/main/scala/stringmatching/regex/Runtime.scala b/src/main/scala/stringmatching/regex/Runtime.scala index 1904f8b..f8d3dd3 100644 --- a/src/main/scala/stringmatching/regex/Runtime.scala +++ b/src/main/scala/stringmatching/regex/Runtime.scala @@ -4,22 +4,97 @@ import collection.immutable.ArraySeq import collection.mutable import scala.util.boundary, boundary.break -import Interpolators.{PatternElement, FormatPattern} +import Interpolators.{Pattern, PatternElement, FormatPattern} object Runtime: - import PatternLive.* - - /** Type-erased extractor method, returning a value corresponding to the elements and levels. The + /** Type-erased extractor method, returning a value corresponding to the pattern and levels. The * caller is responsible for ensuring that `value` corresponds to the number of levels, i.e. A * `String` for 0 levels, a `Seq[String]` for 1 level, a `Seq[Seq[String]]` for 2 levels, etc. - * The caller is also responsible for ensuring that `Out` corresponds to the pattern elements, - * wrapped in Seq as many times as `levels`. + * The caller is also responsible for ensuring that `Out` corresponds to the pattern, wrapped in + * `Seq` as many times as `levels`. 
*/ - def extract[In, Out](elements: Seq[PatternElement], levels: Int)(value: In): Out = - PatternLive(elements, levels).unapply(value).asInstanceOf[Out] + def unsafeExtract[In, Out](pattern: Pattern, levels: Int)(value: In): Out = + pattern match + case Pattern.Literal(glob) => + PatternLive0.extract(glob, value, levels).asInstanceOf[Out] + case Pattern.Single(glob, pattern) => + PatternLive1.extract(glob, pattern, value, levels).asInstanceOf[Out] + case Pattern.Multiple(glob, elements) => + PatternLiveN.extract(glob, elements, value, levels).asInstanceOf[Out] + + private object PatternLive0: + + private def unapply0(glob: String, scrutinee: String): Boolean = + StringContext.glob(Seq(glob), scrutinee).isDefined + + private def unapplyN(glob: String, scrutinee: Any, level: Int): Boolean = + level match + case 0 => unapply0(glob, scrutinee.asInstanceOf[String]) + case i => + val stageN1 = scrutinee.asInstanceOf[Seq[Any]].map(unapplyN(glob, _, i - 1)) + stageN1.asInstanceOf[Seq[Boolean]].forall(identity) + + def extract[Base](glob: String, scrutinee: Base, levels: Int): Boolean = + unapplyN(glob, scrutinee, levels) + end PatternLive0 + + private def foldGlobs(acc: Seq[String], self: PatternElement): Seq[String] = + self match + case PatternElement.Glob(pattern) => acc :+ pattern + case PatternElement.Split(_, pattern) => acc :+ pattern + case PatternElement.SplitEmpty(splitOn, pattern) => acc :+ pattern + case PatternElement.Format(_, pattern) => acc :+ pattern + + private def process(globbed: String, self: PatternElement): String | Seq[String] | Option[Any] = + import scala.language.unsafeNulls + self match + case PatternElement.Glob(_) => globbed + case PatternElement.Split(splitOn, _) => + globbed.split(splitOn).toIndexedSeq + case PatternElement.SplitEmpty(splitOn, _) => + val res0 = globbed.split(splitOn) + if res0.isEmpty then ArraySeq.empty[String] + else if res0(0).isEmpty then res0.toIndexedSeq.tail + else res0.toIndexedSeq + case 
PatternElement.Format(format, _) => + format match + case FormatPattern.AsInt => globbed.toIntOption + case FormatPattern.AsLong => globbed.toLongOption + case FormatPattern.AsDouble => globbed.toDoubleOption + case FormatPattern.AsFloat => globbed.toFloatOption + end match + end process + + private object PatternLive1: + private def unapply0(glob: String, pattern: PatternElement, scrutinee: String): Option[Any] = + val globs = foldGlobs(Vector(glob), pattern) + StringContext.glob(globs, scrutinee) match + case None => None + case Some(stage1) => + process(stage1.head, pattern) match + case opt: Option[Any] => opt // from format + case other => Some(other) + end match + end unapply0 + + private def unapplyN(glob: String, pattern: PatternElement, scrutinee: Any, level: Int) + : Option[Any] = + level match + case 0 => unapply0(glob, pattern, scrutinee.asInstanceOf[String]) + case i => + val stageN1 = scrutinee.asInstanceOf[Seq[Any]].map(unapplyN(glob, pattern, _, i - 1)) + val stageN1Refined = stageN1.asInstanceOf[Seq[Option[Any]]] + if stageN1Refined.forall(_.isDefined) then Some(stageN1Refined.map(_.get)) + else None + + def extract[Base](glob: String, pattern: PatternElement, scrutinee: Base, levels: Int) + : Option[Any] = + unapplyN(glob, pattern, scrutinee, levels) + end PatternLive1 + + private object PatternLiveN: - private object PatternLive: private class ArraySeqBuilderProduct( arr: Array[mutable.Builder[AnyRef, ArraySeq[AnyRef]]]) extends Product: @@ -28,86 +103,59 @@ object Runtime: def productElement(n: Int): Any = arr(n).result() override def toString: String = arr.mkString("ArraySeqBuilderProduct(...)") end ArraySeqBuilderProduct - end PatternLive - - private case class PatternLive(elements: Seq[PatternElement], levels: Int): - - private def unapply0(scrutinee: String, slots: Int): Option[Any] | Boolean = - def foldGlobs(acc: Seq[String], self: PatternElement): Seq[String] = - self match - case PatternElement.Glob(pattern) => acc :+ pattern - case 
PatternElement.Split(_, pattern) => acc :+ pattern - case PatternElement.SplitEmpty(splitOn, pattern) => acc :+ pattern - case PatternElement.Format(_, pattern) => acc :+ pattern - val globs = elements.foldLeft(Vector.empty[String]: Seq[String])(foldGlobs) + + private def unapply0(glob: String, elements: Seq[PatternElement], scrutinee: String) + : Option[Any] = + val globs = elements.foldLeft(Vector(glob))(foldGlobs) StringContext.glob(globs, scrutinee) match - case None => - if slots == 0 then false else None + case None => None case Some(stage1) => - if slots == 0 then true - else - def process(globbed: String, self: PatternElement): String | Seq[String] | Option[Any] = - self match - case PatternElement.Glob(_) => globbed - case PatternElement.Split(splitOn, _) => globbed.split(splitOn).toIndexedSeq - case PatternElement.SplitEmpty(splitOn, _) => - val res0 = globbed.split(splitOn) - if res0.isEmpty then ArraySeq.empty[String] - else if res0(0).isEmpty then res0.toIndexedSeq.tail - else res0.toIndexedSeq - case PatternElement.Format(format, _) => - format match - case FormatPattern.AsInt => globbed.toIntOption - case FormatPattern.AsLong => globbed.toLongOption - case FormatPattern.AsDouble => globbed.toDoubleOption - case FormatPattern.AsFloat => globbed.toFloatOption - if slots == 1 then - process(stage1.head, elements.last) match - case opt: Option[Any] => opt // from format - case other => Some(other) - else - boundary: - val state = new Array[AnyRef](slots) - stage1 - .lazyZip(elements.drop(1)) - .lazyZip(0 until slots) - .foreach: (globbed, element, index) => - process(globbed, element) match - case None => break(None) // format failed - case Some(value) => - state(index) = value.asInstanceOf[AnyRef] // format succeeded - case value => state(index) = value.asInstanceOf[AnyRef] // no format - Some(Tuple.fromArray(state)) - end if + boundary: + var _state: Array[AnyRef] | Null = null + def state: Array[AnyRef] = + val read = _state + if read == null then + 
val init = new Array[AnyRef](elements.size) + try init + finally _state = init + else read + end if + end state + stage1 + .lazyZip(elements) + .lazyZip(0 until elements.size) + .foreach: (globbed, element, index) => + process(globbed, element) match + case None => break(None) // format failed + case Some(value) => + state(index) = value.asInstanceOf[AnyRef] // format succeeded + case value => state(index) = value.asInstanceOf[AnyRef] // no format + Some(Tuple.fromArray(state)) end match end unapply0 - private def unapplyN(scrutinee: Any, slots: Int, level: Int): Option[Any] | Boolean = + private def unapplyN(glob: String, elements: Seq[PatternElement], scrutinee: Any, level: Int) + : Option[Any] = level match - case 0 => unapply0(scrutinee.asInstanceOf[String], slots) + case 0 => unapply0(glob, elements, scrutinee.asInstanceOf[String]) case i => - val stageN1 = scrutinee.asInstanceOf[Seq[Any]].map(unapplyN(_, slots, i - 1)) - if slots == 0 then stageN1.asInstanceOf[Seq[Boolean]].forall(identity) - else - val stageN1Refined = stageN1.asInstanceOf[Seq[Option[Any]]] - if stageN1Refined.forall(_.isDefined) then - if slots == 1 then Some(stageN1Refined.map(_.get)) - else - val state = Array.fill(slots)(ArraySeq.newBuilder[AnyRef]) - stageN1Refined.foreach: res => - val tup = res.get.asInstanceOf[Tuple] - var idx = 0 - while idx < slots do - state(idx) += tup.productElement(idx).asInstanceOf[AnyRef] - idx += 1 - Some(Tuple.fromProduct(new ArraySeqBuilderProduct(state))) - else None - end if + val stageN1 = scrutinee.asInstanceOf[Seq[Any]].map(unapplyN(glob, elements, _, i - 1)) + val stageN1Refined = stageN1.asInstanceOf[Seq[Option[Any]]] + if stageN1Refined.forall(_.isDefined) then + val state = Array.fill(elements.size)(ArraySeq.newBuilder[AnyRef]) + stageN1Refined.foreach: res => + val tup = res.get.asInstanceOf[Tuple] + var idx = 0 + while idx < elements.size do + state(idx) += tup.productElement(idx).asInstanceOf[AnyRef] + idx += 1 + Some(Tuple.fromProduct(new 
ArraySeqBuilderProduct(state))) + else None end if - /** */ - def unapply[Base](scrutinee: Base): Option[Any] | Boolean = - unapplyN(scrutinee, slots = elements.size - 1, levels) - end PatternLive + def extract[Base](glob: String, elements: Seq[PatternElement], scrutinee: Base, levels: Int) + : Option[Any] = + unapplyN(glob, elements, scrutinee, levels) + end PatternLiveN end Runtime