Skip to content

Commit 57e4681

Browse files
ML model management quickstart
1 parent 688fcfa commit 57e4681

2 files changed

Lines changed: 225 additions & 0 deletions

File tree

machine-learning/README.md

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Firebase Admin Python SDK ML quickstart
2+
3+
This sample script shows how you can use the Firebase Admin SDK to manage your
4+
Firebase-hosted ML models.
5+
6+
## Setup
7+
8+
1. Install the Admin SDK (probably in a virtual environment):
9+
10+
```
11+
$ pip install -U pip setuptools
12+
$ pip install firebase_admin
13+
```
14+
15+
2. Clone the quickstart repository and change to the `machine-learning`
16+
  directory:
17+
18+
```
19+
$ git clone https://github.com/firebase/quickstart-python.git
20+
$ cd quickstart-python/machine-learning
21+
$ chmod u+x manage-ml.py # Optional
22+
```
23+
24+
3. If you don't already have a Firebase project, create a new project in the
25+
[Firebase console](https://console.firebase.google.com/). Then, open your
26+
project in the Firebase console and do the following:
27+
28+
1. On the [Settings][service-account] page, create a service account and
29+
download the service account key file. Keep this file safe, since it
30+
grants administrator access to your project.
31+
2. On the Storage page, enable Cloud Storage. Take note of your default
32+
bucket name (or create a new bucket for ML models.)
33+
3. On the ML Kit page, click **Get started** if you haven't yet enabled ML
34+
Kit.
35+
36+
4. In the [Google APIs console][enable-api], open your Firebase project and
37+
enable the Firebase ML API.
38+
39+
[enable-api]: https://console.developers.google.com/apis/library/firebaseml.googleapis.com?project=_
40+
41+
5. At the top of `manage-ml.py`, set the `SERVICE_ACCOUNT_KEY` and
42+
`STORAGE_BUCKET`:
43+
44+
```
45+
SERVICE_ACCOUNT_KEY = '/path/to/your/service_account_key.json'
46+
STORAGE_BUCKET = 'your-storage-bucket'
47+
```
48+
49+
[service-account]: https://firebase.google.com/project/_/settings/serviceaccounts/adminsdk
50+
51+
## Example session
52+
53+
```
54+
$ ./manage-ml.py list
55+
fish_detector 8716935 vision
56+
barcode_scanner 8716959 vision
57+
smart_reply 8716981 natural_language
58+
$ ./manage-ml.py new ~/yak.tflite yak_detector --tags vision,experimental
59+
Uploading model to Cloud Storage...
60+
Model uploaded and published:
61+
yak_detector 8717019 experimental, vision
62+
$ ./manage-ml.py update 8717019 --remove_tags experimental
63+
$ ./manage-ml.py delete 8716959
64+
$ ./manage-ml.py list
65+
fish_detector 8716935 vision
66+
smart_reply 8716981 natural_language
67+
yak_detector 8717019 vision
68+
$
69+
```

machine-learning/manage-ml.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
#!/usr/bin/env python3
2+
"""Firebase Admin SDK ML quickstart example."""
3+
4+
import argparse
5+
6+
import firebase_admin
7+
from firebase_admin import ml
8+
9+
10+
# TODO(user): Configure for your project. (See README.md.)
11+
SERVICE_ACCOUNT_KEY = '/path/to/your/service_account_key.json'
12+
STORAGE_BUCKET = 'your-storage-bucket'
13+
14+
credentials = firebase_admin.credentials.Certificate(SERVICE_ACCOUNT_KEY)
15+
firebase_admin.initialize_app(credentials, options={
16+
'storageBucket': STORAGE_BUCKET
17+
})
18+
19+
20+
def upload_model(model_file, name, tags=None):
21+
"""Upload a tflite model file to the project and publish it."""
22+
# Load a tflite file and upload it to Cloud Storage
23+
print('Uploading to Cloud Storage...')
24+
model_source = ml.TFLiteGCSModelSource.from_tflite_model_file(model_file)
25+
26+
# Create the model object
27+
tflite_format = ml.TFLiteFormat(model_source=model_source)
28+
model = ml.Model(
29+
display_name=name,
30+
model_format=tflite_format)
31+
if tags is not None:
32+
model.tags = tags
33+
34+
# Add the model to your Firebase project and publish it
35+
new_model = ml.create_model(model)
36+
ml.publish_model(new_model.model_id)
37+
38+
print('Model uploaded and published:')
39+
tags = ', '.join(new_model.tags) if new_model.tags is not None else ''
40+
print('{:<20}{:<10} {}'.format(new_model.display_name, new_model.model_id,
41+
tags))
42+
43+
44+
def list_models(filter_exp=''):
45+
"""List the models in the project."""
46+
models = ml.list_models(list_filter=filter_exp).iterate_all()
47+
for model in models:
48+
tags = ', '.join(model.tags) if model.tags is not None else ''
49+
print('{:<20}{:<10} {}'.format(model.display_name, model.model_id, tags))
50+
51+
52+
def update_model(model_id, model_file=None, name=None,
53+
new_tags=None, remove_tags=None):
54+
"""Update one of the project's models."""
55+
model = ml.get_model(model_id)
56+
57+
if model_file is not None:
58+
# Load a tflite file and upload it to Cloud Storage
59+
print('Uploading to Cloud Storage...')
60+
model_source = ml.TFLiteGCSModelSource.from_tflite_model_file(model_file)
61+
tflite_format = ml.TFLiteFormat(model_source=model_source)
62+
model.model_format = tflite_format
63+
64+
if name is not None:
65+
model.display_name = name
66+
67+
if new_tags is not None:
68+
model.tags = new_tags if model.tags is None else model.tags + new_tags
69+
70+
if remove_tags is not None and model.tags is not None:
71+
model.tags = list(set(model.tags).difference(set(remove_tags)))
72+
73+
updated_model = ml.update_model(model)
74+
ml.publish_model(updated_model.model_id)
75+
76+
77+
def delete_model(model_id):
78+
"""Delete a model from the project."""
79+
ml.delete_model(model_id)
80+
81+
82+
# The rest of the file just parses the command line and dispatches one of the
83+
# functions above.
84+
85+
86+
def main():
87+
main_parser = argparse.ArgumentParser()
88+
subparsers = main_parser.add_subparsers(
89+
dest='command', required=True, metavar='command')
90+
91+
new_parser = subparsers.add_parser(
92+
'new', help='upload a tflite model to your project')
93+
new_parser.add_argument(
94+
'model_file', type=str, help='path to the tflite file')
95+
new_parser.add_argument(
96+
'name', type=str, help='display name for the new model')
97+
new_parser.add_argument(
98+
'-t', '--tags', type=str, help='comma-separated list of tags')
99+
100+
list_parser = subparsers.add_parser(
101+
'list', help='list your project\'s models')
102+
list_parser.add_argument(
103+
'-f', '--filter', type=str, default='',
104+
help='''filter expression to limit results (see:
105+
https://firebase.google.com/docs/ml-kit/manage-hosted-models#list_your_projects_models)''')
106+
107+
update_parser = subparsers.add_parser(
108+
'update', help='update one of your project\'s models')
109+
update_parser.add_argument(
110+
'model_id', type=valid_id, help='the ID of the model you want to update')
111+
update_parser.add_argument(
112+
'-m', '--model_file', type=str, help='path to a new tflite file')
113+
update_parser.add_argument(
114+
'-n', '--name', type=str, help='display name for the model')
115+
update_parser.add_argument(
116+
'-t', '--new_tags', type=str,
117+
help='comma-separated list of tags to add')
118+
update_parser.add_argument(
119+
'-d', '--remove_tags', type=str,
120+
help='comma-separated list of tags to remove')
121+
122+
delete_parser = subparsers.add_parser(
123+
'delete', help='delete a model from your project')
124+
delete_parser.add_argument(
125+
'model_id', type=valid_id, help='the ID of the model you want to delete')
126+
127+
args = main_parser.parse_args()
128+
try:
129+
if args.command == 'new':
130+
tags = args.tags.split(',') if args.tags is not None else None
131+
upload_model(args.model_file, args.name, tags)
132+
elif args.command == 'list':
133+
list_models(args.filter)
134+
elif args.command == 'update':
135+
new_tags = args.new_tags.split(',') if args.new_tags is not None else None
136+
remove_tags = (
137+
args.remove_tags.split(',') if args.remove_tags is not None else None)
138+
update_model(args.model_id, args.model_file, args.name,
139+
new_tags, remove_tags)
140+
elif args.command == 'delete':
141+
delete_model(args.model_id)
142+
except firebase_admin.exceptions.NotFoundError:
143+
print('ERROR: Model not found. Make sure you\'re specifying a valid'
144+
' numerical model ID.')
145+
146+
147+
def valid_id(model_id):
148+
try:
149+
val = int(model_id)
150+
return str(val)
151+
except ValueError:
152+
raise argparse.ArgumentTypeError('must be a numerical model ID.')
153+
154+
155+
if __name__ == '__main__':
156+
main()

0 commit comments

Comments
 (0)