|
29 | 29 | GRPC_ONLY_FILE = os.path.join(ROOT_DIR, 'gcloud', 'datastore', |
30 | 30 | '_generated', 'datastore_grpc_pb2.py') |
31 | 31 | GRPCIO_VIRTUALENV = os.environ.get('GRPCIO_VIRTUALENV', 'protoc') |
| 32 | +MESSAGE_SNIPPET = ' = _reflection.GeneratedProtocolMessageType(' |
| 33 | +IMPORT_TEMPLATE = 'from gcloud.datastore._generated.datastore_pb2 import %s\n' |
32 | 34 |
|
33 | 35 |
|
34 | 36 | def get_pb2_contents_with_grpc(): |
@@ -110,10 +112,32 @@ def get_pb2_grpc_only(): |
110 | 112 | return grpc_only_lines |
111 | 113 |
|
112 | 114 |
|
| 115 | +def get_pb2_message_types(): |
| 116 | + """Get message types defined in datastore pb2 file. |
| 117 | +
|
| 118 | + :rtype: list |
| 119 | + :returns: A list of names that are defined as message types. |
| 120 | + """ |
| 121 | + non_grpc_contents = get_pb2_contents_without_grpc() |
| 122 | + result = [] |
| 123 | + for line in non_grpc_contents: |
| 124 | + if MESSAGE_SNIPPET in line: |
| 125 | + name, _ = line.split(MESSAGE_SNIPPET) |
| 126 | + result.append(name) |
| 127 | + |
| 128 | + return sorted(result) |
| 129 | + |
| 130 | + |
113 | 131 | def main(): |
114 | 132 | """Write gRPC-only lines to custom module.""" |
115 | 133 | grpc_only_lines = get_pb2_grpc_only() |
116 | 134 | with open(GRPC_ONLY_FILE, 'wb') as file_obj: |
| 135 | + # First add imports for public objects in the original. |
| 136 | + file_obj.write('# BEGIN: Imports from datastore_pb2\n') |
| 137 | + for name in get_pb2_message_types(): |
| 138 | + import_line = IMPORT_TEMPLATE % (name,) |
| 139 | + file_obj.write(import_line) |
| 140 | + file_obj.write('# END: Imports from datastore_pb2\n') |
117 | 141 | file_obj.write(''.join(grpc_only_lines)) |
118 | 142 |
|
119 | 143 |
|
|
0 commit comments