Skip to content

Commit 33f3dd5

Browse files
committed
Add gax support for entity annotations.
1 parent 60f1ada commit 33f3dd5

File tree

11 files changed

+323
-5
lines changed

11 files changed

+323
-5
lines changed

system_tests/vision.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ def _assert_face(self, face):
190190

191191
def test_detect_faces_content(self):
192192
client = Config.CLIENT
193+
client._use_gax = False
193194
with open(FACE_FILE, 'rb') as image_file:
194195
image = client.image(content=image_file.read())
195196
faces = image.detect_faces()
@@ -208,6 +209,7 @@ def test_detect_faces_gcs(self):
208209
source_uri = 'gs://%s/%s' % (bucket_name, blob_name)
209210

210211
client = Config.CLIENT
212+
client._use_gax = False
211213
image = client.image(source_uri=source_uri)
212214
faces = image.detect_faces()
213215
self.assertEqual(len(faces), 5)
@@ -216,6 +218,7 @@ def test_detect_faces_gcs(self):
216218

217219
def test_detect_faces_filename(self):
218220
client = Config.CLIENT
221+
client._use_gax = False
219222
image = client.image(filename=FACE_FILE)
220223
faces = image.detect_faces()
221224
self.assertEqual(len(faces), 5)
@@ -310,6 +313,7 @@ def _assert_landmark(self, landmark):
310313

311314
def test_detect_landmark_content(self):
312315
client = Config.CLIENT
316+
client._use_gax = True
313317
with open(LANDMARK_FILE, 'rb') as image_file:
314318
image = client.image(content=image_file.read())
315319
landmarks = image.detect_landmarks()
@@ -328,6 +332,7 @@ def test_detect_landmark_gcs(self):
328332
source_uri = 'gs://%s/%s' % (bucket_name, blob_name)
329333

330334
client = Config.CLIENT
335+
client._use_gax = True
331336
image = client.image(source_uri=source_uri)
332337
landmarks = image.detect_landmarks()
333338
self.assertEqual(len(landmarks), 1)
@@ -336,6 +341,7 @@ def test_detect_landmark_gcs(self):
336341

337342
def test_detect_landmark_filename(self):
338343
client = Config.CLIENT
344+
client._use_gax = True
339345
image = client.image(filename=LANDMARK_FILE)
340346
landmarks = image.detect_landmarks()
341347
self.assertEqual(len(landmarks), 1)
@@ -362,6 +368,7 @@ def _assert_safe_search(self, safe_search):
362368

363369
def test_detect_safe_search_content(self):
364370
client = Config.CLIENT
371+
client._use_gax = False
365372
with open(FACE_FILE, 'rb') as image_file:
366373
image = client.image(content=image_file.read())
367374
safe_searches = image.detect_safe_search()
@@ -380,6 +387,7 @@ def test_detect_safe_search_gcs(self):
380387
source_uri = 'gs://%s/%s' % (bucket_name, blob_name)
381388

382389
client = Config.CLIENT
390+
client._use_gax = False
383391
image = client.image(source_uri=source_uri)
384392
safe_searches = image.detect_safe_search()
385393
self.assertEqual(len(safe_searches), 1)
@@ -388,6 +396,7 @@ def test_detect_safe_search_gcs(self):
388396

389397
def test_detect_safe_search_filename(self):
390398
client = Config.CLIENT
399+
client._use_gax = False
391400
image = client.image(filename=FACE_FILE)
392401
safe_searches = image.detect_safe_search()
393402
self.assertEqual(len(safe_searches), 1)
@@ -423,6 +432,7 @@ def _assert_text(self, text):
423432

424433
def test_detect_text_content(self):
425434
client = Config.CLIENT
435+
client._use_gax = True
426436
with open(TEXT_FILE, 'rb') as image_file:
427437
image = client.image(content=image_file.read())
428438
texts = image.detect_text()
@@ -441,6 +451,7 @@ def test_detect_text_gcs(self):
441451
source_uri = 'gs://%s/%s' % (bucket_name, blob_name)
442452

443453
client = Config.CLIENT
454+
client._use_gax = True
444455
image = client.image(source_uri=source_uri)
445456
texts = image.detect_text()
446457
self.assertEqual(len(texts), 9)
@@ -449,6 +460,7 @@ def test_detect_text_gcs(self):
449460

450461
def test_detect_text_filename(self):
451462
client = Config.CLIENT
463+
client._use_gax = True
452464
image = client.image(filename=TEXT_FILE)
453465
texts = image.detect_text()
454466
self.assertEqual(len(texts), 9)
@@ -485,6 +497,7 @@ def _assert_properties(self, image_property):
485497

486498
def test_detect_properties_content(self):
487499
client = Config.CLIENT
500+
client._use_gax = False
488501
with open(FACE_FILE, 'rb') as image_file:
489502
image = client.image(content=image_file.read())
490503
properties = image.detect_properties()
@@ -503,6 +516,7 @@ def test_detect_properties_gcs(self):
503516
source_uri = 'gs://%s/%s' % (bucket_name, blob_name)
504517

505518
client = Config.CLIENT
519+
client._use_gax = False
506520
image = client.image(source_uri=source_uri)
507521
properties = image.detect_properties()
508522
self.assertEqual(len(properties), 1)
@@ -511,6 +525,7 @@ def test_detect_properties_gcs(self):
511525

512526
def test_detect_properties_filename(self):
513527
client = Config.CLIENT
528+
client._use_gax = False
514529
image = client.image(filename=FACE_FILE)
515530
properties = image.detect_properties()
516531
self.assertEqual(len(properties), 1)

vision/google/cloud/vision/_gax.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
from google.cloud._helpers import _to_bytes
2121

22+
from google.cloud.vision.annotations import Annotations
23+
2224

2325
class _GAPICVisionAPI(object):
2426
"""Vision API for interacting with the gRPC version of Vision.
@@ -30,6 +32,27 @@ def __init__(self, client=None):
3032
self._client = client
3133
self._api = image_annotator_client.ImageAnnotatorClient()
3234

35+
def annotate(self, image, features):
36+
"""Annotate images through GAX.
37+
38+
:type image: :class:`~google.cloud.vision.image.Image`
39+
:param image: Instance of ``Image``.
40+
41+
:type features: list
42+
:param features: List of :class:`~google.cloud.vision.feature.Feature`.
43+
44+
:rtype: :class:`~google.cloud.vision.annotations.Annotations`
45+
:returns: Instance of ``Annotations`` with results.
46+
"""
47+
gapic_features = [_to_gapic_feature(feature) for feature in features]
48+
gapic_image = _to_gapic_image(image)
49+
request = image_annotator_pb2.AnnotateImageRequest(
50+
image=gapic_image, features=gapic_features)
51+
requests = [request]
52+
api = self._api
53+
responses = api.batch_annotate_images(requests)
54+
return Annotations.from_pb(responses.responses[0])
55+
3356

3457
def _to_gapic_feature(feature):
3558
"""Helper function to convert a ``Feature`` to a gRPC ``Feature``.

vision/google/cloud/vision/_http.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
"""HTTP Client for interacting with the Google Cloud Vision API."""
1616

17+
from google.cloud.vision.annotations import Annotations
1718
from google.cloud.vision.feature import Feature
1819

1920

@@ -49,7 +50,7 @@ def annotate(self, image, features):
4950
api_response = self._connection.api_request(
5051
method='POST', path='/images:annotate', data=data)
5152
responses = api_response.get('responses')
52-
return responses[0]
53+
return Annotations.from_api_repr(responses[0])
5354

5455

5556
def _make_request(image, features):

vision/google/cloud/vision/annotations.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,55 @@ def from_api_repr(cls, response):
9393
_entity_from_response_type(feature_type, annotation))
9494
return cls(**annotations)
9595

96+
@classmethod
97+
def from_pb(cls, response):
98+
"""Factory: construct an instance of ``Annotations`` from gRPC response.
99+
100+
:type response: :class:`~google.cloud.grpc.vision.v1.\
101+
image_annotator_pb2.AnnotateImageResponse`
102+
:param response: ``AnnotateImageResponse`` from gRPC call.
103+
104+
:rtype: :class:`~google.cloud.vision.annotations.Annotations`
105+
:returns: ``Annotations`` instance populated from gRPC response.
106+
"""
107+
annotations = _process_image_annotations(response)
108+
return cls(**annotations)
109+
110+
111+
def _process_image_annotations(image):
112+
"""Helper for processing annotation types from gRPC responses.
113+
114+
:type image: :class:`~google.cloud.grpc.vision.v1.image_annotator_pb2.\
115+
AnnotateImageResponse`
116+
:param image: ``AnnotateImageResponse`` from gRPC response.
117+
118+
:rtype: dict
119+
:returns: Dictionary populated with entities from response.
120+
"""
121+
annotations = {}
122+
annotations['labels'] = _make_entity_from_pb(image.label_annotations)
123+
annotations['landmarks'] = _make_entity_from_pb(image.landmark_annotations)
124+
annotations['logos'] = _make_entity_from_pb(image.logo_annotations)
125+
annotations['texts'] = _make_entity_from_pb(image.text_annotations)
126+
return annotations
127+
128+
129+
def _make_entity_from_pb(annotations):
130+
"""Create an entity from a gRPC response.
131+
132+
:type annotations:
133+
:class:`~google.cloud.grpc.vision.v1.image_annotator_pb2.EntityAnnotation`
134+
:param annotations: gRPC instance of ``EntityAnnotation``.
135+
136+
:rtype: list
137+
:returns: List of ``EntityAnnotation``.
138+
"""
139+
140+
entities = []
141+
for annotation in annotations:
142+
entities.append(EntityAnnotation.from_pb(annotation))
143+
return entities
144+
96145

97146
def _entity_from_response_type(feature_type, results):
98147
"""Convert a JSON result to an entity type based on the feature.

vision/google/cloud/vision/entity.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,26 @@ def from_api_repr(cls, response):
7070

7171
return cls(bounds, description, locale, locations, mid, score)
7272

73+
@classmethod
74+
def from_pb(cls, response):
75+
"""Factory: construct entity from Vision gRPC response.
76+
77+
:type response: :class:`~google.cloud.grpc.vision.v1.\
78+
image_annotator_pb2.AnnotateImageResponse`
79+
:param response: gRPC response from Vision API with entity data.
80+
81+
:rtype: :class:`~google.cloud.vision.entity.EntityAnnotation`
82+
:returns: Instance of ``EntityAnnotation``.
83+
"""
84+
bounds = Bounds.from_pb(response.bounding_poly)
85+
description = response.description
86+
locale = response.locale
87+
locations = [LocationInformation.from_pb(location)
88+
for location in response.locations]
89+
mid = response.mid
90+
score = response.score
91+
return cls(bounds, description, locale, locations, mid, score)
92+
7393
@property
7494
def bounds(self):
7595
"""Bounding polygon of detected image feature.

vision/google/cloud/vision/geometry.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,20 @@ def from_api_repr(cls, response_vertices):
4141
vertex in response_vertices.get('vertices', [])]
4242
return cls(vertices)
4343

44+
@classmethod
45+
def from_pb(cls, response_vertices):
46+
"""Factory: construct BoundsBase instance from Vision gRPC response.
47+
48+
:type response_vertices: :class:`~google.cloud.grpc.vision.v1.\
49+
geometry_pb2.BoundingPoly`
50+
:param response_vertices: List of vertices.
51+
52+
:rtype: :class:`~google.cloud.vision.geometry.BoundsBase` or None
53+
:returns: Instance of ``BoundsBase`` with populated verticies.
54+
"""
55+
return cls([Vertex(vertex.x, vertex.y)
56+
for vertex in response_vertices.vertices])
57+
4458
@property
4559
def vertices(self):
4660
"""List of vertices.
@@ -87,6 +101,19 @@ def from_api_repr(cls, response):
87101
longitude = response['latLng']['longitude']
88102
return cls(latitude, longitude)
89103

104+
@classmethod
105+
def from_pb(cls, response):
106+
"""Factory: construct location information from Vision gRPC response.
107+
108+
:type response: :class:`~google.cloud.vision.v1.LocationInfo`
109+
:param response: gRPC response of ``LocationInfo``.
110+
111+
:rtype: :class:`~google.cloud.vision.geometry.LocationInformation`
112+
:returns: ``LocationInformation`` with populated latitude and
113+
longitude.
114+
"""
115+
return cls(response.lat_lng.latitude, response.lat_lng.longitude)
116+
90117
@property
91118
def latitude(self):
92119
"""Latitude coordinate.

vision/google/cloud/vision/image.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
from google.cloud._helpers import _to_bytes
2121
from google.cloud._helpers import _bytes_to_unicode
22-
from google.cloud.vision.annotations import Annotations
2322
from google.cloud.vision.feature import Feature
2423
from google.cloud.vision.feature import FeatureTypes
2524

@@ -109,8 +108,7 @@ def _detect_annotation(self, features):
109108
:class:`~google.cloud.vision.color.ImagePropertiesAnnotation`,
110109
:class:`~google.cloud.vision.sage.SafeSearchAnnotation`,
111110
"""
112-
results = self.client._vision_api.annotate(self, features)
113-
return Annotations.from_api_repr(results)
111+
return self.client._vision_api.annotate(self, features)
114112

115113
def detect(self, features):
116114
"""Detect multiple feature types.

vision/unit_tests/test__gax.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,28 @@ def test_ctor(self):
3232
api = self._make_one(client)
3333
self.assertIs(api._client, client)
3434

35+
def test_annotation(self):
36+
from google.cloud.vision.feature import Feature
37+
from google.cloud.vision.feature import FeatureTypes
38+
from google.cloud.vision.image import Image
39+
40+
client = mock.Mock()
41+
feature = Feature(FeatureTypes.LABEL_DETECTION, 5)
42+
image_content = b'abc 1 2 3'
43+
image = Image(client, content=image_content)
44+
with mock.patch('google.cloud.vision._gax.image_annotator_client.'
45+
'ImageAnnotatorClient'):
46+
api = self._make_one(client)
47+
48+
api._api = mock.Mock()
49+
mock_response = mock.Mock(responses=['mock response data'])
50+
api._api.batch_annotate_images.return_value = mock_response
51+
52+
with mock.patch('google.cloud.vision._gax.Annotations') as mock_anno:
53+
api.annotate(image, [feature])
54+
mock_anno.from_pb.assert_called_with('mock response data')
55+
api._api.batch_annotate_images.assert_called()
56+
3557

3658
class TestToGAPICFeature(unittest.TestCase):
3759
def _call_fut(self, feature):

0 commit comments

Comments
 (0)