Skip to content

Commit e975f69

Browse files
author
Jan Paw
committed
SCL-4 remove junit
1 parent f4c1b1e commit e975f69

File tree

10 files changed

+135
-116
lines changed

10 files changed

+135
-116
lines changed

project/ScalaCLBuild.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,8 @@ object ScalaCLBuild extends Build {
7979
libraryDependencies <+= scalaVersion("org.scala-lang" % "scala-reflect" % _),
8080
libraryDependencies <+= scalaVersion("org.scala-lang" % "scala-compiler" % _),
8181
libraryDependencies ++= Seq(
82-
"junit" % "junit" % "4.10" % "test",
8382
"org.scalatest" % "scalatest_2.11" % "2.2.1" % "test",
84-
"org.scalamock" % "scalamock-scalatest-support_2.11" % "3.1.2",
85-
"com.novocode" % "junit-interface" % "0.8" % "test"
83+
"org.scalamock" % "scalamock-scalatest-support_2.11" % "3.1.2"
8684
)
8785
)
8886

src/test/scala/scalacl/BaseTest.scala

Lines changed: 58 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@ package scalacl
22

33
import org.scalamock.scalatest.MockFactory
44
import org.scalatest.{ FlatSpecLike, Matchers }
5-
import scalacl.impl.{ OpenCLCodeFlattening, CodeConversion }
6-
import scalaxy.components.{ FlatCode, WithRuntimeUniverse }
5+
6+
import scala.reflect.runtime.{ currentMirror => cm, universe => ru }
7+
import scala.tools.reflect.ToolBox
8+
import scalacl.impl.{ Vectorization, CodeConversion, OpenCLCodeFlattening }
9+
import scalaxy.components.FlatCode
710

811
trait BaseTest extends FlatSpecLike with Matchers with MockFactory {
912
def context[T](f: Context => T): T = {
@@ -14,30 +17,77 @@ trait BaseTest extends FlatSpecLike with Matchers with MockFactory {
1417
}
1518
}
1619

17-
trait RuntimeUniverseTest extends WithRuntimeUniverse {
20+
trait RuntimeUniverseTest {
21+
lazy val global = ru
22+
import global._
23+
24+
def verbose = false
25+
1826
private var nextId = 0L
1927

2028
def fresh(s: String) = synchronized {
2129
val v = nextId
2230
nextId += 1
2331
s + v
2432
}
33+
34+
def warning(pos: Position, msg: String) =
35+
println(msg + " (" + pos + ")")
36+
37+
def withSymbol[T <: Tree](sym: Symbol, tpe: Type = NoType)(tree: T): T = tree
38+
39+
def typed[T <: Tree](tree: T): T = {
40+
// if (tree.tpe == null && tree.tpe == NoType)
41+
// toolbox.typeCheck(tree.asInstanceOf[toolbox.u.Tree]).asInstanceOf[T]
42+
// else
43+
tree
44+
}
45+
46+
def inferImplicitValue(pt: Type): Tree =
47+
toolbox.inferImplicitValue(pt.asInstanceOf[toolbox.u.Type]).asInstanceOf[global.Tree]
48+
49+
lazy val toolbox = cm.mkToolBox()
50+
51+
def typeCheck(x: Expr[_]): Tree =
52+
typeCheck(x.tree)
53+
54+
def typeCheck(tree: Tree, pt: Type = WildcardType): Tree = {
55+
val ttree = tree.asInstanceOf[toolbox.u.Tree]
56+
if (ttree.tpe != null && ttree.tpe != NoType)
57+
tree
58+
else {
59+
try {
60+
toolbox.typecheck(
61+
ttree,
62+
pt = pt.asInstanceOf[toolbox.u.Type])
63+
} catch {
64+
case ex: Throwable =>
65+
throw new RuntimeException(s"Failed to typeCheck($tree, $pt): $ex", ex)
66+
}
67+
}.asInstanceOf[Tree]
68+
}
69+
70+
def resetLocalAttrs(tree: Tree): Tree = {
71+
toolbox.untypecheck(tree.asInstanceOf[toolbox.u.Tree]).asInstanceOf[Tree]
72+
}
73+
2574
}
2675

2776
trait CodeConversionTest extends CodeConversion with RuntimeUniverseTest {
2877
val global: reflect.api.Universe
78+
2979
import global._
3080

3181
def convertExpression(block: Expr[Unit], explicitParamDescs: Seq[ParamDesc] = Seq()) = {
3282
convertCode(typeCheck(block.tree, WildcardType), explicitParamDescs)
3383
}
3484

85+
def flatStatement(statements: Seq[String], values: Seq[String]): FlatCode[String] =
86+
FlatCode[String](statements = statements, values = values)
87+
3588
def flatAndConvertExpression(x: Expr[_]): FlatCode[String] = {
3689
flattenAndConvert(typeCheck(x))
3790
}
38-
39-
def flatCode(statements: Seq[String], values: Seq[String]): FlatCode[String] =
40-
FlatCode[String](statements = statements, values = values)
4191
}
4292

4393
trait CodeFlatteningTest extends OpenCLCodeFlattening with RuntimeUniverseTest {
@@ -67,3 +117,5 @@ trait CodeFlatteningTest extends OpenCLCodeFlattening with RuntimeUniverseTest {
67117
flatten(typeCheck(x.tree, WildcardType), inputSymbols, owner)
68118
}
69119
}
120+
121+
trait CodeVectorizationTest extends Vectorization with RuntimeUniverseTest

src/test/scala/scalacl/CLFunctionTest.scala

Lines changed: 8 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -29,42 +29,21 @@
2929
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
3030
*/
3131
package scalacl
32-
import impl._
3332

34-
import org.junit._
35-
import Assert._
33+
class CLFunctionTest extends BaseTest {
3634

37-
class CLFunctionTest {
38-
@Ignore
39-
@Test
40-
def simple() {
41-
implicit val context = Context.best
42-
try {
35+
behavior of "CLFunction"
36+
37+
ignore should "wrap scalar function" in context {
38+
implicit context =>
4339
val a = CLArray[Int](1, 2, 3)
4440
val v = 10
45-
// task {
46-
// a(1) = 10 * v
47-
// }
48-
// println(a.toSeq)
49-
// assertEquals(Seq(0, 100, 0), a.toSeq)
5041

51-
val f: CLFunction[Int, Float] = (x: Int) => {
42+
val clFunction: CLFunction[Int, Float] = (x: Int) => {
5243
x * 2.0f * v
5344
}
5445

55-
println(f)
56-
println(f.value)
57-
println(f.functionKernel)
58-
assertArrayEquals(
59-
Array(20.0f, 40.0f, 60.0f),
60-
a.map(f).toArray,
61-
0)
62-
} catch {
63-
case ex: Throwable =>
64-
ex.printStackTrace()
65-
throw ex
66-
} finally {
67-
context.release()
68-
}
46+
val clResult = a.map(clFunction).toArray
47+
clResult should equal(Array(20.0f, 40.0f, 60.0f))
6948
}
7049
}
Lines changed: 34 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
package scalacl
2-
import org.junit._
3-
import Assert._
42

5-
class MatrixTest {
3+
class MatrixTest extends BaseTest {
4+
behavior of "Matrix"
65

76
case class Matrix(data: CLArray[Float], rows: Int, columns: Int)(implicit context: Context) {
87
def this(rows: Int, columns: Int)(implicit context: Context) =
@@ -11,7 +10,7 @@ class MatrixTest {
1110
this(n, n)
1211
}
1312

14-
def mult(a: Matrix, b: Matrix, out: Matrix)(implicit context: Context) = {
13+
def mul(a: Matrix, b: Matrix, out: Matrix)(implicit context: Context) = {
1514
assert(a.columns == b.rows)
1615
assert(a.rows == out.rows)
1716
assert(b.columns == out.columns)
@@ -43,45 +42,40 @@ class MatrixTest {
4342
}
4443
}
4544

46-
@Ignore
47-
@Test
48-
def testMatrix2() {
49-
implicit val context = Context.best
45+
ignore should "perform multiplication of two matrix" in context {
46+
implicit context =>
47+
val n = 10
48+
val a = new Matrix(n)
49+
val b = new Matrix(n)
50+
val out = new Matrix(n)
5051

51-
val n = 10
52-
val out = new Matrix(n)
53-
val outData = out.data
54-
kernel {
55-
// This block will either be converted to an OpenCL kernel or cause compilation error
56-
// It captures out.data, a.data and b.data
57-
for (i <- 0 until 10; j <- 0 until 20) {
58-
// TODO chain map and sum (to avoid creating a builder here !)
59-
// outData(i * 30 + j) =
60-
// (0 until 30).map(k => {
61-
// aData(i * 30 + k) * bData(k * 30 + j)
62-
// }).sum
63-
var tot = 0f
64-
for (k <- 0 until 30) {
65-
//tot = tot + aData(i * aColumns + k) * bData(k * bColumns + j)
66-
tot = 10000
52+
//TODO add some verification
53+
mul(a, b, out)
54+
}
55+
56+
ignore should "generate kernel with matrix type" in context {
57+
implicit context =>
58+
val n = 10
59+
val out = new Matrix(n)
60+
val outData = out.data
61+
kernel {
62+
// This block will either be converted to an OpenCL kernel or cause compilation error
63+
// It captures out.data, a.data and b.data
64+
for (i <- 0 until 10; j <- 0 until 20) {
65+
// TODO chain map and sum (to avoid creating a builder here !)
66+
// outData(i * 30 + j) =
67+
// (0 until 30).map(k => {
68+
// aData(i * 30 + k) * bData(k * 30 + j)
69+
// }).sum
70+
var tot = 0f
71+
for (k <- 0 until 30) {
72+
//tot = tot + aData(i * aColumns + k) * bData(k * bColumns + j)
73+
tot = 10000
74+
}
75+
outData(i * 10 + j) = tot
6776
}
68-
outData(i * 10 + j) = tot
6977
}
70-
}
78+
//TODO add some verification
7179
}
7280

73-
@Ignore
74-
@Test
75-
def testMatrix() {
76-
implicit val context = Context.best
77-
78-
val n = 10
79-
val a = new Matrix(n)
80-
val b = new Matrix(n)
81-
val out = new Matrix(n)
82-
83-
mult(a, b, out)
84-
85-
println(out.data)
86-
}
8781
}

src/test/scala/scalacl/impl/DefaultScheduledDataTest.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,11 @@ import collection.mutable.ArrayBuffer
3636
import com.nativelibs4java.opencl.CLEvent
3737
import com.nativelibs4java.opencl.MockEvent
3838

39+
import scalaxy.components.WithRuntimeUniverse
40+
3941
class DefaultScheduledDataTest
40-
extends BaseTest {
42+
extends BaseTest
43+
with WithRuntimeUniverse {
4144

4245
behavior of "DefaultScheduledData"
4346

src/test/scala/scalacl/impl/InlinedCollectionsTest.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,11 @@ package scalacl
3232
import impl._
3333

3434
import InlinedCollections._
35+
import scalaxy.components.WithRuntimeUniverse
3536

3637
class InlinedCollectionsTest
37-
extends BaseTest {
38+
extends BaseTest
39+
with WithRuntimeUniverse {
3840

3941
behavior of "InlinedCollections"
4042

src/test/scala/scalacl/impl/OpenCLConverterTest.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,12 @@ package impl
3333

3434
class OpenCLConverterTest
3535
extends BaseTest
36-
with OpenCLConverter
3736
with CodeConversionTest {
3837

3938
behavior of "OpenClConverter"
4039

4140
ignore should "convert touple" in {
42-
val flattenCode = flatCode(
41+
val flattenCode = flatStatement(
4342
Seq("const int x = 10;"),
4443
Seq("x", "(x * 2)")
4544
)
@@ -56,7 +55,7 @@ class OpenCLConverterTest
5655
ignore should "convert simple function: cos" in {
5756
import scala.math._
5857

59-
val flattenCode = flatCode(
58+
val flattenCode = flatStatement(
6059
Seq(),
6160
Seq("cos((float)10.0)")
6261
)

src/test/scala/scalacl/impl/ScheduledDataTest.scala

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -31,20 +31,18 @@
3131
package scalacl
3232
package impl
3333

34-
import org.junit._
35-
import Assert._
36-
import org.hamcrest.CoreMatchers._
37-
3834
import com.nativelibs4java.opencl.CLEvent
3935
import com.nativelibs4java.opencl.MockEvent
4036
import com.nativelibs4java.opencl.library.OpenCLLibrary._
4137
import com.nativelibs4java.opencl.library.IOpenCLLibrary._
4238

4339
import scala.collection.mutable.ArrayBuffer
4440

45-
class ScheduledDataTest {
46-
@Test
47-
def simpleOpWithoutEvent() {
41+
class ScheduledDataTest extends BaseTest {
42+
behavior of "ScheduledDate"
43+
44+
//TODO create higher granularization
45+
ignore should "perform some reads and writes" in {
4846
val inEvt = new MockEvent(1)
4947
val outEvt = new MockEvent(2)
5048
val opEvt = new MockEvent(3)
@@ -63,29 +61,25 @@ class ScheduledDataTest {
6361
}
6462

6563
ScheduledData.schedule(Array(in), Array(out), events => {
66-
assertEquals(Seq(inEvt, outEvt), events.toSeq)
64+
Seq(inEvt, outEvt) should equal(events.toSeq)
6765
opEvt
6866
})
69-
assertNotNull(opEvt.completionCallback)
7067

71-
assertEquals(
72-
"in calls don't match",
73-
Seq(
74-
'startRead -> List(Nil),
75-
'endRead -> List(opEvt)),
76-
in.calls)
77-
assertEquals(
78-
"out calls don't match",
79-
Seq(
80-
'startWrite -> List(List(inEvt)),
81-
'endWrite -> List(opEvt)),
82-
out.calls)
68+
opEvt.completionCallback should not be null
69+
70+
Seq(
71+
'startRead -> List(Nil),
72+
'endRead -> List(opEvt)
73+
) should equal(in.calls)
74+
75+
Seq(
76+
'startWrite -> List(List(inEvt)),
77+
'endWrite -> List(opEvt)
78+
) should equal(out.calls)
8379

8480
opEvt.completionCallback.callback(CL_COMPLETE)
85-
for (d <- Seq(in, out))
86-
assertEquals(
87-
Seq(
88-
'eventCompleted -> List(opEvt)),
89-
d.calls)
81+
Seq(in, out).foreach {
82+
d => Seq('eventCompleted -> List(opEvt)) should equal(d.calls)
83+
}
9084
}
9185
}

src/test/scala/scalacl/impl/SymbolKindsTest.scala

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,12 @@ package impl
3333

3434
class SymbolKindsTest
3535
extends BaseTest
36-
with RuntimeUniverseTest
37-
with SymbolKinds {
36+
with SymbolKinds
37+
with RuntimeUniverseTest {
38+
import global._
3839

3940
behavior of "Symbols kind resolving"
4041

41-
import global._
42-
4342
class EmptyClass()
4443
case class EmptyCaseClass()
4544
class ImmutableClass(a: Int, b: Int)

0 commit comments

Comments
 (0)