Skip to content

Commit

Permalink
Merge pull request #434 from alexarchambault/fix-with-full-help
Browse files Browse the repository at this point in the history
Fix WithFullHelp runtime issues
  • Loading branch information
alexarchambault authored Nov 8, 2022
2 parents cacc9a0 + 07f37cb commit 8da0825
Show file tree
Hide file tree
Showing 11 changed files with 218 additions and 82 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package caseapp.core.help

import caseapp.core.parser.Parser
import caseapp.{ExtraName, HelpMessage}

abstract class WithFullHelpCompanion {

implicit def parser[T, D](implicit underlying: Parser.Aux[T, D]): Parser[WithFullHelp[T]] =
Parser.nil
.addAll[WithHelp[T]].apply
.add[Boolean](
"helpFull",
default = Some(false),
extraNames = Seq(ExtraName("fullHelp")),
helpMessage = Some(HelpMessage("Print help message, including hidden options, and exit"))
)
.as[WithFullHelp[T]]

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package caseapp.core.help

import caseapp.core.parser.Parser
import caseapp.{ExtraName, HelpMessage}

abstract class WithFullHelpCompanion {

implicit def parser[T: Parser]: Parser[WithFullHelp[T]] =
Parser.nil
.addAll[WithHelp[T]](using WithHelp.parser[T])
.add[Boolean](
"helpFull",
default = Some(false),
extraNames = Seq(ExtraName("fullHelp")),
helpMessage = Some(HelpMessage("Print help message, including hidden options, and exit"))
)
.as[WithFullHelp[T]]

}
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,6 @@ import scala.quoted.{given, *}

object LowPriorityParserImplicits {

// adapted from https://github.com/com-lihaoyi/upickle/blob/3f9c98c3983a98c5e0716e42314d5f825bafa9e4/implicits/src-3/upickle/implicits/macros.scala#L51-L70
inline private def fieldLabels[T] = ${ fieldLabelsImpl[T] }
private def fieldLabelsImpl[T](using Quotes, Type[T]): Expr[List[String]] =
import quotes.reflect.*
val fields = TypeRepr.of[T].typeSymbol
.primaryConstructor
.paramSymss
.flatten
.filterNot(_.isType)
.map(_.name)

Expr.ofList(fields.map(Expr(_)))
end fieldLabelsImpl

extension (comp: Expr.type) {
def ofOption[T](opt: Option[Expr[T]])(using Quotes, Type[T]): Expr[Option[T]] =
opt match {
Expand All @@ -37,15 +23,136 @@ object LowPriorityParserImplicits {
fullName.takeWhile(_ != '[').split('.').last
}

private def fields[U](using
q: Quotes,
t: Type[U]
): List[(q.reflect.Symbol, q.reflect.TypeRepr)] = {
import quotes.reflect.*
val tpe = TypeRepr.of[U]
val sym = TypeRepr.of[U] match {
case AppliedType(base, params) =>
base.typeSymbol
case _ =>
TypeTree.of[U].symbol
}

// Many things inspired by https://github.com/plokhotnyuk/jsoniter-scala/blob/8f39e1d45fde2a04984498f036cad93286344c30/jsoniter-scala-macros/shared/src/main/scala-3/com/github/plokhotnyuk/jsoniter_scala/macros/JsonCodecMaker.scala#L564-L613
// and around, here

def typeArgs(tpe: TypeRepr): List[TypeRepr] = tpe match
case AppliedType(_, typeArgs) => typeArgs.map(_.dealias)
case _ => Nil

def resolveParentTypeArg(
child: Symbol,
fromNudeChildTarg: TypeRepr,
parentTarg: TypeRepr,
binding: Map[String, TypeRepr]
): Map[String, TypeRepr] =
if (fromNudeChildTarg.typeSymbol.isTypeParam) { // todo: check for paramRef instead ?
val paramName = fromNudeChildTarg.typeSymbol.name
binding.get(paramName) match
case None => binding.updated(paramName, parentTarg)
case Some(oldBinding) =>
if (oldBinding =:= parentTarg) binding
else sys.error(
s"Type parameter $paramName in class ${child.name} appeared in the constructor of " +
s"${tpe.show} two times differently, with ${oldBinding.show} and ${parentTarg.show}"
)
}
else if (fromNudeChildTarg <:< parentTarg)
binding // TODO: assupe parentTag is covariant, get covariance from tycon type parameters.
else
(fromNudeChildTarg, parentTarg) match
case (AppliedType(ctycon, ctargs), AppliedType(ptycon, ptargs)) =>
ctargs.zip(ptargs).foldLeft(resolveParentTypeArg(child, ctycon, ptycon, binding)) {
(b, e) =>
resolveParentTypeArg(child, e._1, e._2, b)
}
case _ =>
sys.error(s"Failed unification of type parameters of ${tpe.show} from child $child - " +
s"${fromNudeChildTarg.show} and ${parentTarg.show}")

def resolveParentTypeArgs(
child: Symbol,
nudeChildParentTags: List[TypeRepr],
parentTags: List[TypeRepr],
binding: Map[String, TypeRepr]
): Map[String, TypeRepr] =
nudeChildParentTags.zip(parentTags).foldLeft(binding)((s, e) =>
resolveParentTypeArg(child, e._1, e._2, s)
)

val nudeSubtype = TypeIdent(sym).tpe
val baseConst = nudeSubtype.memberType(sym.primaryConstructor)
val tpeArgsFromChild = typeArgs(tpe)
val const = baseConst match {
case MethodType(_, _, resTp) => resTp
case PolyType(names, _, resPolyTp) =>
val targs = typeArgs(tpe)
val tpBinding = resolveParentTypeArgs(sym, tpeArgsFromChild, targs, Map.empty)
val ctArgs = names.map { name =>
tpBinding.get(name).getOrElse(sys.error(
s"Type parameter $name of $sym can't be deduced from " +
s"type arguments of ${tpe.show}. Please provide a custom implicitly accessible codec for it."
))
}
val polyRes = resPolyTp match
case MethodType(_, _, resTp) => resTp
case other => other // hope we have no multiple typed param lists yet.
if (ctArgs.isEmpty) polyRes
else polyRes match
case AppliedType(base, _) => base.appliedTo(ctArgs)
case AnnotatedType(AppliedType(base, _), annot) =>
AnnotatedType(base.appliedTo(ctArgs), annot)
case _ => polyRes.appliedTo(ctArgs)
case other =>
sys.error(s"Primary constructior for ${tpe.show} is not MethodType or PolyType but $other")
}
sym.primaryConstructor
.paramSymss
.flatten
.map(f => (f, f.tree))
.collect {
case (sym, v: ValDef) =>
(sym, v.tpt.tpe)
}
}

inline private def checkFieldCount[T, N <: Int]: Unit =
${ checkFieldCountImpl[T, N] }
private def checkFieldCountImpl[T, N <: Int](using Quotes, Type[T], Type[N]): Expr[Unit] = {
import quotes.reflect.*

val viaMirror = TypeRepr.of[N] match {
case ConstantType(c) =>
c.value match {
case n: Int => n
case other => sys.error(
s"Expected literal integer type, got ${Type.show[N]} ($other, ${other.getClass})"
)
}
case other =>
sys.error(s"Expected literal integer type, got ${Type.show[N]} ($other, ${other.getClass})")
}

val viaReflect = fields[T].length

assert(
viaMirror == viaReflect,
s"Got Unexpected number of field via reflection for type ${Type.show[T]} (got $viaReflect, expected $viaMirror)"
)

'{ () }
}

inline private def tupleParser[T]: Parser[_] =
${ tupleParserImpl[T] }
private def tupleParserImpl[T](using q: Quotes, t: Type[T]): Expr[Parser[_]] = {
import quotes.reflect.*
val tSym = TypeTree.of[T].symbol
val origin = shortName[T]
val fields = tSym.primaryConstructor.paramSymss.flatten.map(f => (f, f.tree)).collect {
case (sym, v: ValDef) => (sym, v.tpt.tpe)
}
val tSym = TypeTree.of[T].symbol
val origin = shortName[T]
val fields0 = fields[T]

val defaultMap: Map[String, Expr[Any]] = {
val comp =
Expand All @@ -59,7 +166,7 @@ object LowPriorityParserImplicits {
}
bodyOpt match {
case Some(body) =>
val names = fields
val names = fields0
.map(_._1)
.filter(_.flags.is(Flags.HasDefault))
.map(_.name)
Expand All @@ -73,7 +180,7 @@ object LowPriorityParserImplicits {
}
}

val parserExpr = fields
val parserExpr = fields0
.foldRight[(TypeRepr, Expr[Parser[_]])]((TypeRepr.of[EmptyTuple], '{ NilParser })) {
case ((sym, symTpe), (tailType, tailParserExpr)) =>
val isRecursive = sym.annotations.exists(_.tpe =:= TypeRepr.of[caseapp.Recurse])
Expand Down Expand Up @@ -189,8 +296,9 @@ object LowPriorityParserImplicits {
}

trait LowPriorityParserImplicits {
inline given derive[T](using m: Mirror.ProductOf[T]): Parser[T] =
LowPriorityParserImplicits.tupleParser[T].asInstanceOf[Parser[m.MirroredElemTypes]].map(
m.fromTuple
)
inline given derive[T](using m: Mirror.ProductOf[T]): Parser[T] = {
LowPriorityParserImplicits.checkFieldCount[T, Tuple.Size[m.MirroredElemTypes]]
val parser = LowPriorityParserImplicits.tupleParser[T]
parser.asInstanceOf[Parser[m.MirroredElemTypes]].map(m.fromTuple)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ abstract class Parser[T] extends ParserMethods[T] {

final def withFullHelp: Parser[WithFullHelp[T]] = {
implicit val parser: Parser[T] = this
val p = ParserWithNameFormatter(Parser[WithFullHelp[T]], defaultNameFormatter)
val p0 = WithFullHelp.parser[T]
val p = ParserWithNameFormatter(p0, defaultNameFormatter)
if (defaultIgnoreUnrecognized)
p.ignoreUnrecognized
else if (defaultStopAtFirstUnrecognized)
Expand Down
72 changes: 21 additions & 51 deletions core/shared/src/main/scala-3/caseapp/core/parser/ParserOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,33 @@ class ParserOps[T <: Tuple](val parser: Parser[T]) extends AnyVal {
ConsParser(argument, parser)
}

// def addAll[U]: ParserOps.AddAllHelper[T, D, U] =
// new ParserOps.AddAllHelper[T, D, U](parser)
def addAll[H](using headParser: Parser[H]): Parser[H *: T] =
RecursiveConsParser(headParser, parser)

def as[F](implicit helper: ParserOps.AsHelper[T, F]): Parser[F] =
helper(parser)
def as[F](using
m: Mirror.ProductOf[F],
ev: T =:= ParserOps.Reverse[m.MirroredElemTypes],
ev0: ParserOps.Reverse[ParserOps.Reverse[m.MirroredElemTypes]] =:= m.MirroredElemTypes
): Parser[F] =
parser
.map(ev)
.map(ParserOps.reverse[ParserOps.Reverse[m.MirroredElemTypes]])
.map(ev0)
.map(m.fromTuple)

def tupled: Parser[ParserOps.Reverse[T]] =
parser.map(ParserOps.reverse)

def to[F](implicit helper: ParserOps.ToHelper[T, F]): Parser[F] =
helper(parser)
def to[F](using
m: Mirror.ProductOf[F],
ev: T =:= m.MirroredElemTypes
): Parser[F] =
parser.map(ev).map(m.fromTuple)

def toTuple[P](implicit helper: ParserOps.ToTupleHelper[T, P]): Parser[P] =
helper(parser)
def toTuple[P <: Tuple](using
m: Mirror.ProductOf[T] { type MirroredElemTypes = P }
): Parser[P] =
parser.map(Tuple.fromProductTyped[T])

}

Expand All @@ -61,47 +74,4 @@ object ParserOps {
def reverse[T <: Tuple](t: T): Reverse[T] =
Tuple.fromArray(t.toArray.reverse).asInstanceOf[Reverse[T]]

def reverse0[T <: Tuple](t: Reverse[T]): T =
Tuple.fromArray(t.toArray.reverse).asInstanceOf[T]

// class AddAllHelper[T <: Tuple, D <: Tuple, U](val parser: Parse[T]) extends AnyVal {
// def apply[DU](implicit other: Parser[U]): Parser[U :: T] =
// RecursiveConsParser(other, parser)
// }

abstract class AsHelper[T, F] {
def apply(parser: Parser[T]): Parser[F]
}

inline implicit def defaultAsHelper[T <: Tuple, F](implicit
m: Mirror.ProductOf[F],
ev: T =:= Reverse[m.MirroredElemTypes]
): AsHelper[T, F] = {
parser =>
parser.map(reverse0).map(p => m.fromProduct(p))
}

abstract class ToHelper[T, F] {
def apply(parser: Parser[T]): Parser[F]
}

implicit def defaultToHelper[F, T <: Tuple](implicit
m: Mirror.ProductOf[F]
): ToHelper[T, F] = {
parser =>
parser.map(m.fromProduct)
}

sealed abstract class ToTupleHelper[T, P] {
def apply(parser: Parser[T]): Parser[P]
}

implicit def defaultToTupleHelper[T <: Product](implicit
m: scala.deriving.Mirror.ProductOf[T]
): ToTupleHelper[T, m.MirroredElemTypes] =
new ToTupleHelper[T, m.MirroredElemTypes] {
def apply(parser: Parser[T]): Parser[m.MirroredElemTypes] =
parser.map(Tuple.fromProductTyped[T](_))
}

}
2 changes: 1 addition & 1 deletion core/shared/src/main/scala/caseapp/core/help/Help.scala
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ import caseapp.HelpMessage

def withFullHelp: Help[WithFullHelp[T]] = {
final case class Dummy()
val parser: Parser[WithFullHelp[Dummy]] = Parser.derive
val parser: Parser[WithFullHelp[Dummy]] = WithFullHelp.parser
val helpArgs = parser.args

this.withArgs(helpArgs ++ args)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package caseapp.core.help

import caseapp.core.Error
import caseapp.{ExtraName, Group, HelpMessage, Recurse}
import caseapp.core.parser.Parser

final case class WithFullHelp[T](
@Recurse
Expand All @@ -14,3 +15,5 @@ final case class WithFullHelp[T](
def map[U](f: T => U): WithFullHelp[U] =
copy(withHelp = withHelp.map(f))
}

object WithFullHelp extends WithFullHelpCompanion
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package caseapp.core.parser

import caseapp.core.{Arg, Error, Indexed}
import caseapp.core.help.{WithFullHelp, WithHelp}
import caseapp.core.RemainingArgs
import caseapp.core.util.Formatter
import caseapp.Name
Expand Down
3 changes: 2 additions & 1 deletion project/Mima.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@ import scala.sys.process._
object Mima {

def binaryCompatibilityVersions: Set[String] =
Seq("git", "tag", "--merged", "HEAD^", "--contains", "34a70d9ff42e0e75847ee7157bf8a81b675d7813")
Seq("git", "tag", "--merged", "HEAD^", "--contains", "cacc9a0340fde584a10a814037db6a8947881931")
.!!
.linesIterator
.map(_.trim)
.filter(_.startsWith("v"))
.map(_.stripPrefix("v"))
.filter(_ != "2.1.0-M19")
.toSet

def settings = Def.settings(
Expand Down
2 changes: 1 addition & 1 deletion project/Settings.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ object Settings {
}

lazy val shared = Seq(
scalaVersion := scala213,
scalaVersion := scala3,
crossScalaVersions := Seq(scala212, scala213, scala3),
scalacOptions ++= Seq(
"-target:jvm-1.8",
Expand Down
Loading

0 comments on commit 8da0825

Please sign in to comment.