|
| 1 | +#!/usr/bin/env python3 |
| 2 | +"""Firebase Admin SDK ML quickstart example.""" |
| 3 | + |
| 4 | +import argparse |
| 5 | + |
| 6 | +import firebase_admin |
| 7 | +from firebase_admin import ml |
| 8 | + |
| 9 | + |
| 10 | +# TODO(user): Configure for your project. (See README.md.) |
| 11 | +SERVICE_ACCOUNT_KEY = '/path/to/your/service_account_key.json' |
| 12 | +STORAGE_BUCKET = 'your-storage-bucket' |
| 13 | + |
| 14 | +credentials = firebase_admin.credentials.Certificate(SERVICE_ACCOUNT_KEY) |
| 15 | +firebase_admin.initialize_app(credentials, options={ |
| 16 | + 'storageBucket': STORAGE_BUCKET |
| 17 | +}) |
| 18 | + |
| 19 | + |
| 20 | +def upload_model(model_file, name, tags=None): |
| 21 | + """Upload a tflite model file to the project and publish it.""" |
| 22 | + # Load a tflite file and upload it to Cloud Storage |
| 23 | + print('Uploading to Cloud Storage...') |
| 24 | + model_source = ml.TFLiteGCSModelSource.from_tflite_model_file(model_file) |
| 25 | + |
| 26 | + # Create the model object |
| 27 | + tflite_format = ml.TFLiteFormat(model_source=model_source) |
| 28 | + model = ml.Model( |
| 29 | + display_name=name, |
| 30 | + model_format=tflite_format) |
| 31 | + if tags is not None: |
| 32 | + model.tags = tags |
| 33 | + |
| 34 | + # Add the model to your Firebase project and publish it |
| 35 | + new_model = ml.create_model(model) |
| 36 | + ml.publish_model(new_model.model_id) |
| 37 | + |
| 38 | + print('Model uploaded and published:') |
| 39 | + tags = ', '.join(new_model.tags) if new_model.tags is not None else '' |
| 40 | + print('{:<20}{:<10} {}'.format(new_model.display_name, new_model.model_id, |
| 41 | + tags)) |
| 42 | + |
| 43 | + |
| 44 | +def list_models(filter_exp=''): |
| 45 | + """List the models in the project.""" |
| 46 | + models = ml.list_models(list_filter=filter_exp).iterate_all() |
| 47 | + for model in models: |
| 48 | + tags = ', '.join(model.tags) if model.tags is not None else '' |
| 49 | + print('{:<20}{:<10} {}'.format(model.display_name, model.model_id, tags)) |
| 50 | + |
| 51 | + |
| 52 | +def update_model(model_id, model_file=None, name=None, |
| 53 | + new_tags=None, remove_tags=None): |
| 54 | + """Update one of the project's models.""" |
| 55 | + model = ml.get_model(model_id) |
| 56 | + |
| 57 | + if model_file is not None: |
| 58 | + # Load a tflite file and upload it to Cloud Storage |
| 59 | + print('Uploading to Cloud Storage...') |
| 60 | + model_source = ml.TFLiteGCSModelSource.from_tflite_model_file(model_file) |
| 61 | + tflite_format = ml.TFLiteFormat(model_source=model_source) |
| 62 | + model.model_format = tflite_format |
| 63 | + |
| 64 | + if name is not None: |
| 65 | + model.display_name = name |
| 66 | + |
| 67 | + if new_tags is not None: |
| 68 | + model.tags = new_tags if model.tags is None else model.tags + new_tags |
| 69 | + |
| 70 | + if remove_tags is not None and model.tags is not None: |
| 71 | + model.tags = list(set(model.tags).difference(set(remove_tags))) |
| 72 | + |
| 73 | + updated_model = ml.update_model(model) |
| 74 | + ml.publish_model(updated_model.model_id) |
| 75 | + |
| 76 | + |
| 77 | +def delete_model(model_id): |
| 78 | + """Delete a model from the project.""" |
| 79 | + ml.delete_model(model_id) |
| 80 | + |
| 81 | + |
| 82 | +# The rest of the file just parses the command line and dispatches one of the |
| 83 | +# functions above. |
| 84 | + |
| 85 | + |
| 86 | +def main(): |
| 87 | + main_parser = argparse.ArgumentParser() |
| 88 | + subparsers = main_parser.add_subparsers( |
| 89 | + dest='command', required=True, metavar='command') |
| 90 | + |
| 91 | + new_parser = subparsers.add_parser( |
| 92 | + 'new', help='upload a tflite model to your project') |
| 93 | + new_parser.add_argument( |
| 94 | + 'model_file', type=str, help='path to the tflite file') |
| 95 | + new_parser.add_argument( |
| 96 | + 'name', type=str, help='display name for the new model') |
| 97 | + new_parser.add_argument( |
| 98 | + '-t', '--tags', type=str, help='comma-separated list of tags') |
| 99 | + |
| 100 | + list_parser = subparsers.add_parser( |
| 101 | + 'list', help='list your project\'s models') |
| 102 | + list_parser.add_argument( |
| 103 | + '-f', '--filter', type=str, default='', |
| 104 | + help='''filter expression to limit results (see: |
| 105 | + https://firebase.google.com/docs/ml-kit/manage-hosted-models#list_your_projects_models)''') |
| 106 | + |
| 107 | + update_parser = subparsers.add_parser( |
| 108 | + 'update', help='update one of your project\'s models') |
| 109 | + update_parser.add_argument( |
| 110 | + 'model_id', type=valid_id, help='the ID of the model you want to update') |
| 111 | + update_parser.add_argument( |
| 112 | + '-m', '--model_file', type=str, help='path to a new tflite file') |
| 113 | + update_parser.add_argument( |
| 114 | + '-n', '--name', type=str, help='display name for the model') |
| 115 | + update_parser.add_argument( |
| 116 | + '-t', '--new_tags', type=str, |
| 117 | + help='comma-separated list of tags to add') |
| 118 | + update_parser.add_argument( |
| 119 | + '-d', '--remove_tags', type=str, |
| 120 | + help='comma-separated list of tags to remove') |
| 121 | + |
| 122 | + delete_parser = subparsers.add_parser( |
| 123 | + 'delete', help='delete a model from your project') |
| 124 | + delete_parser.add_argument( |
| 125 | + 'model_id', type=valid_id, help='the ID of the model you want to delete') |
| 126 | + |
| 127 | + args = main_parser.parse_args() |
| 128 | + try: |
| 129 | + if args.command == 'new': |
| 130 | + tags = args.tags.split(',') if args.tags is not None else None |
| 131 | + upload_model(args.model_file, args.name, tags) |
| 132 | + elif args.command == 'list': |
| 133 | + list_models(args.filter) |
| 134 | + elif args.command == 'update': |
| 135 | + new_tags = args.new_tags.split(',') if args.new_tags is not None else None |
| 136 | + remove_tags = ( |
| 137 | + args.remove_tags.split(',') if args.remove_tags is not None else None) |
| 138 | + update_model(args.model_id, args.model_file, args.name, |
| 139 | + new_tags, remove_tags) |
| 140 | + elif args.command == 'delete': |
| 141 | + delete_model(args.model_id) |
| 142 | + except firebase_admin.exceptions.NotFoundError: |
| 143 | + print('ERROR: Model not found. Make sure you\'re specifying a valid' |
| 144 | + ' numerical model ID.') |
| 145 | + |
| 146 | + |
| 147 | +def valid_id(model_id): |
| 148 | + try: |
| 149 | + val = int(model_id) |
| 150 | + return str(val) |
| 151 | + except ValueError: |
| 152 | + raise argparse.ArgumentTypeError('must be a numerical model ID.') |
| 153 | + |
| 154 | + |
| 155 | +if __name__ == '__main__': |
| 156 | + main() |
0 commit comments