diff --git a/Loops/src/main/scala/scalaxy/loops.scala b/Loops/src/main/scala/scalaxy/loops.scala index 6a208ce5..5742ca69 100644 --- a/Loops/src/main/scala/scalaxy/loops.scala +++ b/Loops/src/main/scala/scalaxy/loops.scala @@ -114,21 +114,62 @@ package loops c.typeCheck(f.tree) match { case Function(List(param), body) => - def newIntVal(name: TermName, rhs: Tree) = - ValDef(NoMods, name, TypeTree(IntTpe), rhs) + def tc[T <: Tree](t: Tree, tpe: Type): T = + c.typeCheck(Block(t, Literal(Constant(()))), tpe) match { + case Block(List(tt), _) => tt.asInstanceOf[T] + } + + def newIntVal(name: TermName, rhs: Tree): ValDef = + tc(ValDef(NoMods, name, TypeTree(IntTpe), rhs), UnitTpe) - def newIntVar(name: TermName, rhs: Tree) = - ValDef(Modifiers(MUTABLE), name, TypeTree(IntTpe), rhs) + def newIntVar(name: TermName, rhs: Tree): ValDef = + tc(ValDef(Modifiers(MUTABLE), name, TypeTree(IntTpe), rhs), UnitTpe) + implicit def valRef(vd: ValDef) = new { + def apply() = c.typeCheck(Ident(vd.symbol), IntTpe) + } // Body expects a local constant: create a var outside the loop + a val inside it. val iVar = newIntVar(c.fresh("i"), start) - val iVal = newIntVal(param.name, Ident(iVar.name)) + val iVal = newIntVal(param.name, iVar()) val stepVal = newIntVal(c.fresh("step"), Literal(Constant(step))) val endVal = newIntVal(c.fresh("end"), end) + + println(s"i.tpe = ${iVal.tpe}") + /* + // Type-check a fake (ordered) block, to force creation of ValDef symbols: + val Block( + List( + iVar @ ValDef(_, _, _, _), + iVal @ ValDef(_, _, _, _), + stepVal @ ValDef(_, _, _, _), + endVal @ ValDef(_, _, _, _) + ), + _ + ) = c.typeCheck { + val iVarRaw = newIntVar(c.fresh("i"), start) + Block( + iVarRaw, + newIntVal(param.name, Ident(iVarRaw.name)), + newIntVal(c.fresh("step"), Literal(Constant(step))), + newIntVal(c.fresh("end"), end), + Literal(Constant(())) + ) + }*/ + + println("TYPECHECKED") + + // Replace any mention of the lambda parameter by a reference to iVal: + val replacedBody = new Transformer { override def transform(tree: Tree) = { + if (tree.symbol == param.symbol) + iVal() + else + super.transform(tree) + }}.transform(body) + val condition = Apply( Select( - Ident(iVar.name), + iVar(), newTermName( encode( if (step > 0) { @@ -139,7 +180,7 @@ package loops ) ) ), - List(Ident(endVal.name)) + List(endVal()) ) val iVarExpr = c.Expr[Unit](iVar) @@ -149,33 +190,39 @@ package loops val conditionExpr = c.Expr[Boolean](condition) // Body still refers to old function param symbol (which has same name as iVal). // We must wipe it out (alas, it's not local, so we must reset all symbols). - val bodyExpr = c.Expr[Unit](c.resetAllAttrs(body)) + //val bodyExpr = c.Expr[Unit](replacedBody)//c.resetAllAttrs(body)) - val incrExpr = c.Expr[Unit]( + val incr = //c.Expr[Unit]( Assign( - Ident(iVar.name), + iVar(), Apply( Select( - Ident(iVar.name), + iVar(), encode("+") ), - List(Ident(stepVal.name)) + List(stepVal()) ) ) + //) + val loopBody = c.Expr[Unit]( + Block( + iVal, + replacedBody, + incr, + Literal(Constant(())) + ) ) - val iVarRef = c.Expr[Int](Ident(iVar.name)) - val stepValRef = c.Expr[Int](Ident(stepVal.name)) - - reify { + + val res = reify { iVarExpr.splice endValExpr.splice stepValExpr.splice while (conditionExpr.splice) { - iValExpr.splice - bodyExpr.splice - incrExpr.splice + loopBody.splice } } + println(s"res = $res") + c.Expr[Unit](c.typeCheck(res.tree)) case _ => c.error(f.tree.pos, s"Unsupported function: $f") null