Skip to content

Commit a3ecb04

Browse files
Craigacpkarllessard
authored andcommitted
Adding a TF-text custom op test.
1 parent 41da07f commit a3ecb04

3 files changed

Lines changed: 68 additions & 23 deletions

File tree

tensorflow-core/tensorflow-core-api/pom.xml

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
<java.module.name>org.tensorflow.core.api</java.module.name>
1919
<ndarray.version>0.4.0</ndarray.version>
2020
<truth.version>1.1.5</truth.version>
21+
<test.download.skip>false</test.download.skip>
22+
<test.download.folder>${project.build.directory}/tf-text-download/</test.download.folder>
2123
</properties>
2224

2325
<dependencies>
@@ -260,6 +262,35 @@
260262
</descriptorRefs>
261263
</configuration>
262264
</plugin>
265+
<plugin>
266+
<groupId>org.codehaus.mojo</groupId>
267+
<artifactId>exec-maven-plugin</artifactId>
268+
<version>3.1.0</version>
269+
<executions>
270+
<execution>
271+
<!--
272+
Download TF-Text for the custom op loading test.
273+
-->
274+
<id>dist-download</id>
275+
<phase>test-compile</phase>
276+
<goals>
277+
<goal>exec</goal>
278+
</goals>
279+
<configuration>
280+
<skip>${test.download.skip}</skip>
281+
<executable>bash</executable>
282+
<arguments>
283+
<argument>scripts/test_download.sh</argument>
284+
<argument>${test.download.folder}</argument>
285+
</arguments>
286+
<environmentVariables>
287+
<PLATFORM>${native.classifier}</PLATFORM>
288+
</environmentVariables>
289+
<workingDirectory>${project.basedir}</workingDirectory>
290+
</configuration>
291+
</execution>
292+
</executions>
293+
</plugin>
263294
</plugins>
264295
</build>
265296
</project>
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#!/bin/bash
2+
set -e
3+
4+
DOWNLOAD_FOLDER="$1"
5+
6+
case ${PLATFORM:-} in
7+
'linux-x86_64')
8+
TEXT_WHEEL_URL='https://files.pythonhosted.org/packages/20/a0/bdbf2a11141f1c93e572364d13c42537cfe811b747a0bbb58fdd904f3960/tensorflow_text-2.15.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl'
9+
;;
10+
'macosx-x86_64')
11+
TEXT_WHEEL_URL='https://files.pythonhosted.org/packages/8a/fe/a2f19d3d3ab834c3fa1007c970b0b86573beb929c86ca6c85cd13e86e4b2/tensorflow_text-2.15.0-cp311-cp311-macosx_10_9_x86_64.whl'
12+
;;
13+
*)
14+
echo "TensorFlow Text distribution for ${PLATFORM} is not supported for download"
15+
exit 0;
16+
esac
17+
18+
mkdir -p "$DOWNLOAD_FOLDER"
19+
cd "$DOWNLOAD_FOLDER"
20+
21+
if [[ -n "$TEXT_WHEEL_URL" ]]; then
22+
echo "Downloading $TEXT_WHEEL_URL"
23+
if [ ! -f 'tensorflow-text.whl' ]; then
24+
curl -L $TEXT_WHEEL_URL --output 'tensorflow-text.whl'
25+
fi
26+
yes | unzip -q -u 'tensorflow-text.whl' # use 'yes' because for some reasons -u does not work on Windows
27+
fi
28+
29+
ls -l .

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorFlowTest.java

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import static org.junit.jupiter.api.Assertions.assertEquals;
1919
import static org.junit.jupiter.api.Assertions.assertNotNull;
2020
import static org.junit.jupiter.api.Assertions.assertTrue;
21-
import static org.junit.jupiter.api.Assertions.fail;
2221
import static org.junit.jupiter.api.Assumptions.assumeTrue;
2322

2423
import java.io.File;
@@ -40,30 +39,16 @@ public void registeredOpList() {
4039
}
4140

4241
@Test
43-
public void loadLibrary() {
44-
File customOpLibrary = Paths.get("").resolve("bazel-bin/libcustom_ops_test.so").toFile();
42+
public void loadTFTextLibrary() {
43+
String libname = System.mapLibraryName("_sentence_breaking_ops").substring(3); // strips off the lib on macOS & Linux, don't care about Windows.
44+
File customOpLibrary = Paths.get("", "target","tf-text-download","tensorflow_text","python","ops",libname).toFile();
4545

46-
// Disable this test if the custom op library is not available. This may happen on some
47-
// platforms (e.g. Windows) or when using a development profile that skips the native build
46+
// Disable this test if the tf-text library is not available. This may happen on some platforms (e.g. Windows)
4847
assumeTrue(customOpLibrary.exists());
4948

50-
try (Graph g = new Graph()) {
51-
// Build a graph with an unrecognized operation.
52-
try {
53-
g.baseScope().opBuilder("MyTest", "MyTest").build();
54-
fail("should not be able to construct graphs with unregistered ops");
55-
} catch (IllegalArgumentException e) {
56-
// expected exception
57-
}
58-
59-
// Load the library containing the operation.
60-
OpList opList = TensorFlow.loadLibrary(customOpLibrary.getAbsolutePath());
61-
assertNotNull(opList);
62-
assertEquals(1, opList.getOpCount());
63-
assertEquals(opList.getOpList().get(0).getName(), "MyTest");
64-
65-
// Now graph building should succeed.
66-
g.baseScope().opBuilder("MyTest", "MyTest").build();
67-
}
49+
OpList opList = TensorFlow.loadLibrary(customOpLibrary.getAbsolutePath());
50+
assertNotNull(opList);
51+
assertEquals(1, opList.getOpCount());
52+
assertEquals(opList.getOpList().get(0).getName(), "SentenceFragments");
6853
}
6954
}

0 commit comments

Comments
 (0)