Skip to content

Commit 2aaa700

Browse files
authored
add react native platform prototype (tensorflow#1789)
INTERNAL This adds an initial prototype of a platform for react native support. Initially it supports CPU backend execution and provides two IOHandlers for use with react-native. It also changes adds headers to (http) weight requests in core so that we can better distinguish the desired return type.
1 parent 9199355 commit 2aaa700

23 files changed

Lines changed: 7643 additions & 4 deletions

src/device_util.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ export function isMobile(): boolean {
2727
}
2828

2929
export function isBrowser(): boolean {
30-
return (typeof window !== 'undefined') ||
30+
return (typeof window !== 'undefined' && window.document != null) ||
3131
//@ts-ignore
3232
(typeof WorkerGlobalScope !== 'undefined');
3333
}

src/io/io.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import {browserHTTPRequest, http, isHTTPScheme} from './http';
2525
import {concatenateArrayBuffers, decodeWeights, encodeWeights, getModelArtifactsInfoForJSON} from './io_utils';
2626
import {fromMemory, withSaveHandler} from './passthrough';
2727
import {getLoadHandlers, getSaveHandlers, registerLoadRouter, registerSaveRouter} from './router_registry';
28-
import {IOHandler, LoadHandler, LoadOptions, ModelArtifacts, ModelJSON, ModelStoreManager, OnProgressCallback, SaveConfig, SaveHandler, SaveResult, WeightGroup, WeightsManifestConfig, WeightsManifestEntry} from './types';
28+
import {IOHandler, LoadHandler, LoadOptions, ModelArtifacts, ModelArtifactsInfo, ModelJSON, ModelStoreManager, OnProgressCallback, SaveConfig, SaveHandler, SaveResult, WeightGroup, WeightsManifestConfig, WeightsManifestEntry} from './types';
2929
import {loadWeights, weightsLoaderFactory} from './weights_loader';
3030

3131
export {copyModel, listModels, moveModel, removeModel} from './model_management';
@@ -46,6 +46,7 @@ export {
4646
LoadOptions,
4747
loadWeights,
4848
ModelArtifacts,
49+
ModelArtifactsInfo,
4950
ModelJSON,
5051
ModelStoreManager,
5152
OnProgressCallback,

src/io/types.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,13 @@ export declare interface ModelArtifactsInfo {
154154
dateSaved: Date;
155155

156156
/**
157+
* TODO (cais,yassogba) consider removing GraphDef as GraphDefs now
158+
* come in a JSON format and none of our IOHandlers support a non json
159+
* format. We could conder replacing this with 'Binary' if we want to
160+
* allow future handlers to save to non json formats (though they will
161+
* probably want more information than 'Binary').
162+
* Type of the model topology
163+
*
157164
* Type of the model topology
158165
*
159166
* Possible values:

src/io/weights_loader.ts

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,16 @@ export async function loadWeightsAsArrayBuffer(
4242
const fetchFunc =
4343
loadOptions.fetchFunc == null ? util.fetch : loadOptions.fetchFunc;
4444

45+
// Update the request headers without modifying the passed in
46+
// loadOptions object.
47+
const requestOptions = Object.assign({}, loadOptions.requestInit);
48+
requestOptions.headers = Object.assign({}, requestOptions.headers, {
49+
responseType: 'arraybuffer',
50+
});
51+
4552
// Create the requests for all of the weights in parallel.
4653
const requests =
47-
fetchURLs.map(fetchURL => fetchFunc(fetchURL, loadOptions.requestInit));
54+
fetchURLs.map(fetchURL => fetchFunc(fetchURL, requestOptions));
4855

4956
const fetchStartFraction = 0;
5057
const fetchEndFraction = 0.5;

src/io/weights_loader_test.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,10 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => {
436436
manifest, './', weightsNamesToFetch, {credentials: 'include'});
437437
expect((tf.util.fetch as jasmine.Spy).calls.count()).toBe(1);
438438
expect(tf.util.fetch).toHaveBeenCalledWith('./weightfile0', {
439-
credentials: 'include'
439+
credentials: 'include',
440+
headers: {
441+
responseType: 'arraybuffer',
442+
}
440443
});
441444
});
442445

tfjs-react-native/.npmignore

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
.babelrc
2+
.DS_Store
3+
.idea/
4+
.rpt2_cache
5+
.travis.yml
6+
.vscode
7+
*.tgz
8+
*.txt
9+
**.yalc
10+
**yalc.lock
11+
cloudbuild.yml
12+
coverage/
13+
demo/
14+
dist/**/*_test.d.ts
15+
dist/**/*_test.js
16+
karma.conf.js
17+
node_modules/
18+
npm-debug.log
19+
package-lock.json
20+
package/
21+
rollup.config.js
22+
scripts/
23+
src/**/*_test.ts
24+
tsconfig.json
25+
tslint.json
26+
yarn-error.log
27+
yarn.lock
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
{
2+
"search.exclude": {
3+
"**/node_modules": true,
4+
"coverage/": true,
5+
"dist/": true,
6+
"**/yarn.lock": true,
7+
".rpt2_cache/": true,
8+
".yalc/": true
9+
},
10+
"tslint.enable": true,
11+
"tslint.run": "onType",
12+
"tslint.configFile": "tslint.json",
13+
"files.trimTrailingWhitespace": true,
14+
"editor.tabSize": 2,
15+
"editor.insertSpaces": true,
16+
"[typescript]": {
17+
"editor.formatOnSave": true
18+
},
19+
"[javascript]": {
20+
"editor.formatOnSave": true
21+
},
22+
"editor.rulers": [80],
23+
"clang-format.style": "Google",
24+
"files.insertFinalNewline": true,
25+
"editor.detectIndentation": false,
26+
"editor.wrappingIndent": "none",
27+
"typescript.tsdk": "./node_modules/typescript/lib",
28+
"clang-format.executable": "${workspaceRoot}/node_modules/.bin/clang-format"
29+
}

tfjs-react-native/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Platform Adapter for React Native
2+
3+
Status: Early development.

tfjs-react-native/karma.conf.js

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
const karmaTypescriptConfig = {
19+
tsconfig: 'tsconfig.json',
20+
reports: {},
21+
bundlerOptions: {
22+
// Start from test files to control what karma typescript loads
23+
// and ensure that environment setup happens appropriately.
24+
entrypoints: /_test\.(ts)$/,
25+
// Mock react native functionality to enable unit tests in the browser.
26+
resolve: {
27+
alias: {
28+
'react-native': './test/utils/react_native_mock.ts',
29+
'@react-native-community/async-storage':
30+
'./test/utils/async_storage_mock.ts',
31+
}
32+
}
33+
}
34+
};
35+
36+
const baseConfig = {
37+
frameworks: ['jasmine', 'karma-typescript'],
38+
files: [
39+
'./src/**/*.ts',
40+
'./test/**/*.ts',
41+
],
42+
preprocessors: {
43+
'src/**/*.ts': ['karma-typescript'],
44+
'test/**/*.ts': ['karma-typescript'],
45+
},
46+
karmaTypescriptConfig,
47+
reporters: ['verbose', 'karma-typescript'],
48+
};
49+
50+
module.exports = function(config) {
51+
const args = [];
52+
if (config.grep) {
53+
args.push('--grep', config.grep);
54+
}
55+
56+
config.set({
57+
...baseConfig,
58+
basePath: '',
59+
frameworks: ['jasmine', 'karma-typescript'],
60+
preprocessors: {'**/*.ts': ['karma-typescript']},
61+
karmaTypescriptConfig,
62+
reporters: ['progress', 'karma-typescript'],
63+
port: 9876,
64+
colors: true,
65+
autoWatch: false,
66+
browsers: ['Chrome'],
67+
singleRun: true,
68+
client: {
69+
jasmine: {random: false},
70+
args: args,
71+
},
72+
customLaunchers: {
73+
// For browserstack configs see:
74+
// https://www.browserstack.com/automate/node
75+
bs_chrome_mac: {
76+
base: 'BrowserStack',
77+
browser: 'chrome',
78+
browser_version: 'latest',
79+
os: 'OS X',
80+
os_version: 'High Sierra'
81+
},
82+
}
83+
})
84+
}

tfjs-react-native/package.json

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
{
2+
"name": "@tensorflow/tfjs-platform-react-native",
3+
"version": "0.1.0",
4+
"description": "TensorFlow.js platform implementation for React Native",
5+
"main": "dist/react_native/src/index.js",
6+
"types": "dist/react_native/src/index.d.ts",
7+
"jsnext:main": "dist/tf-react-native.esm.js",
8+
"module": "dist/tf-react-native.esm.js",
9+
"unpkg": "dist/tf-react-native.min.js",
10+
"jsdelivr": "dist/tf-react-native.min.js",
11+
"license": "Apache-2.0",
12+
"private": true,
13+
"scripts": {
14+
"publish-local": "rimraf dist/ && yarn build && rollup -c && yalc push",
15+
"build": "tsc",
16+
"link-local": "yalc link",
17+
"unlink-local": "yalc remove",
18+
"lint": "tslint -p . -t verbose",
19+
"test": "karma start"
20+
},
21+
"devDependencies": {
22+
"@react-native-community/async-storage": "^1.4.2",
23+
"@tensorflow/tfjs-core": "1.2.0",
24+
"@types/base64-js": "^1.2.5",
25+
"@types/react-native": "^0.57.60",
26+
"clang-format": "~1.2.2",
27+
"jasmine": "~3.1.0",
28+
"jasmine-core": "~3.1.0",
29+
"karma": "~4.0.0",
30+
"karma-browserify": "~6.0.0",
31+
"karma-browserstack-launcher": "~1.4.0",
32+
"karma-chrome-launcher": "~2.2.0",
33+
"karma-jasmine": "~1.1.0",
34+
"karma-typescript": "~4.0.0",
35+
"karma-verbose-reporter": "^0.0.6",
36+
"react-native": "^0.59.9",
37+
"rimraf": "~2.6.2",
38+
"rollup": "^0.58.2",
39+
"rollup-plugin-commonjs": "9.1.3",
40+
"rollup-plugin-node-resolve": "3.3.0",
41+
"rollup-plugin-typescript2": "0.13.0",
42+
"rollup-plugin-uglify": "~3.0.0",
43+
"tslint": "~5.11.0",
44+
"tslint-no-circular-imports": "^0.5.0",
45+
"typescript": "3.3.3333",
46+
"yalc": "~1.0.0-pre.21"
47+
},
48+
"dependencies": {
49+
"base64-js": "^1.3.0"
50+
},
51+
"peerDependencies": {
52+
"@react-native-community/async-storage": "^1.4.2",
53+
"@tensorflow/tfjs-core": ">=1.0.0",
54+
"react-native": ">=0.59.0"
55+
}
56+
}

0 commit comments

Comments
 (0)