Skip to content

Commit c001746

Browse files
authored
Vision: Add batch processing (#2978)
* Add Vision batch support to the surface.
1 parent 1524fec commit c001746

File tree

9 files changed

+192
-46
lines changed

9 files changed

+192
-46
lines changed

packages/google-cloud-vision/google/cloud/vision/_gax.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,24 +30,28 @@ def __init__(self, client=None):
3030
self._client = client
3131
self._annotator_client = image_annotator_client.ImageAnnotatorClient()
3232

33-
def annotate(self, image, features):
33+
def annotate(self, images):
3434
"""Annotate images through GAX.
3535
36-
:type image: :class:`~google.cloud.vision.image.Image`
37-
:param image: Instance of ``Image``.
38-
39-
:type features: list
40-
:param features: List of :class:`~google.cloud.vision.feature.Feature`.
36+
:type images: list
37+
:param images: List containing pairs of
38+
:class:`~google.cloud.vision.image.Image` and
39+
:class:`~google.cloud.vision.feature.Feature`.
40+
e.g. [(image, [feature_one, feature_two]),]
4141
4242
:rtype: list
4343
:returns: List of
4444
:class:`~google.cloud.vision.annotations.Annotations`.
4545
"""
46-
gapic_features = [_to_gapic_feature(feature) for feature in features]
47-
gapic_image = _to_gapic_image(image)
48-
request = image_annotator_pb2.AnnotateImageRequest(
49-
image=gapic_image, features=gapic_features)
50-
requests = [request]
46+
requests = []
47+
for image, features in images:
48+
gapic_features = [_to_gapic_feature(feature)
49+
for feature in features]
50+
gapic_image = _to_gapic_image(image)
51+
request = image_annotator_pb2.AnnotateImageRequest(
52+
image=gapic_image, features=gapic_features)
53+
requests.append(request)
54+
5155
annotator_client = self._annotator_client
5256
responses = annotator_client.batch_annotate_images(requests).responses
5357
return [Annotations.from_pb(response) for response in responses]

packages/google-cloud-vision/google/cloud/vision/_http.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,24 +29,19 @@ def __init__(self, client):
2929
self._client = client
3030
self._connection = client._connection
3131

32-
def annotate(self, image, features):
32+
def annotate(self, images):
3333
"""Annotate an image to discover its attributes.
3434
35-
:type image: :class:`~google.cloud.vision.image.Image`
36-
:param image: A instance of ``Image``.
35+
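:type images: list
36+
:param images: List of ``(image, features)`` pairs, where ``image`` is a :class:`~google.cloud.vision.image.Image` and ``features`` is a list of :class:`~google.cloud.vision.feature.Feature`.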
3737
38-
:type features: list of :class:`~google.cloud.vision.feature.Feature`
39-
:param features: The type of detection that the Vision API should
40-
use to determine image attributes. Pricing is
41-
based on the number of Feature Types.
42-
43-
See: https://cloud.google.com/vision/docs/pricing
4438
:rtype: list
4539
:returns: List of :class:`~google.cloud.vision.annotations.Annotations`.
4640
"""
47-
request = _make_request(image, features)
48-
49-
data = {'requests': [request]}
41+
requests = []
42+
for image, features in images:
43+
requests.append(_make_request(image, features))
44+
data = {'requests': requests}
5045
api_response = self._connection.api_request(
5146
method='POST', path='/images:annotate', data=data)
5247
responses = api_response.get('responses')
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright 2017 Google Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Batch multiple images into one request."""
16+
17+
18+
class Batch(object):
19+
"""Batch of images to process.
20+
21+
:type client: :class:`~google.cloud.vision.client.Client`
22+
:param client: Vision client.
23+
"""
24+
def __init__(self, client):
25+
self._client = client
26+
self._images = []
27+
28+
def add_image(self, image, features):
29+
"""Add image to batch request.
30+
31+
:type image: :class:`~google.cloud.vision.image.Image`
32+
:param image: Instance of ``Image``.
33+
34+
:type features: list
35+
:param features: List of :class:`~google.cloud.vision.feature.Feature`.
36+
"""
37+
self._images.append((image, features))
38+
39+
@property
40+
def images(self):
41+
"""List of images to process.
42+
43+
:rtype: list
44+
:returns: List of ``(image, features)`` tuples queued for this batch, where ``image`` is a :class:`~google.cloud.vision.image.Image`.
45+
"""
46+
return self._images
47+
48+
def detect(self):
49+
"""Perform batch detection of images.
50+
51+
:rtype: list
52+
:returns: List of
53+
:class:`~google.cloud.vision.annotations.Annotations`.
54+
"""
55+
results = self._client._vision_api.annotate(self.images)
56+
self._images = []
57+
return results

packages/google-cloud-vision/google/cloud/vision/client.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from google.cloud.environment_vars import DISABLE_GRPC
2121

2222
from google.cloud.vision._gax import _GAPICVisionAPI
23+
from google.cloud.vision.batch import Batch
2324
from google.cloud.vision.connection import Connection
2425
from google.cloud.vision.image import Image
2526
from google.cloud.vision._http import _HTTPVisionAPI
@@ -71,6 +72,14 @@ def __init__(self, project=None, credentials=None, http=None,
7172
else:
7273
self._use_gax = use_gax
7374

75+
def batch(self):
76+
"""Batch multiple images into a single API request.
77+
78+
:rtype: :class:`google.cloud.vision.batch.Batch`
79+
:returns: Instance of ``Batch``.
80+
"""
81+
return Batch(self)
82+
7483
def image(self, content=None, filename=None, source_uri=None):
7584
"""Get instance of Image using current client.
7685

packages/google-cloud-vision/google/cloud/vision/image.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -94,21 +94,17 @@ def source(self):
9494
"""
9595
return self._source
9696

97-
def _detect_annotation(self, features):
97+
def _detect_annotation(self, images):
9898
"""Generic method for detecting annotations.
9999
100-
:type features: list
101-
:param features: List of :class:`~google.cloud.vision.feature.Feature`
102-
indicating the type of annotations to perform.
100+
:type images: list
101+
:param images: List of ``(image, features)`` pairs, where ``image`` is a :class:`~google.cloud.vision.image.Image`.
103102
104103
:rtype: list
105104
:returns: List of
106-
:class:`~google.cloud.vision.entity.EntityAnnotation`,
107-
:class:`~google.cloud.vision.face.Face`,
108-
:class:`~google.cloud.vision.color.ImagePropertiesAnnotation`,
109-
:class:`~google.cloud.vision.sage.SafeSearchAnnotation`,
105+
:class:`~google.cloud.vision.annotations.Annotations`.
110106
"""
111-
return self.client._vision_api.annotate(self, features)
107+
return self.client._vision_api.annotate(images)
112108

113109
def detect(self, features):
114110
"""Detect multiple feature types.
@@ -121,7 +117,8 @@ def detect(self, features):
121117
:returns: List of
122118
:class:`~google.cloud.vision.annotations.Annotations`.
123119
"""
124-
return self._detect_annotation(features)
120+
images = ((self, features),)
121+
return self._detect_annotation(images)
125122

126123
def detect_faces(self, limit=10):
127124
"""Detect faces in image.
@@ -133,7 +130,7 @@ def detect_faces(self, limit=10):
133130
:returns: List of :class:`~google.cloud.vision.face.Face`.
134131
"""
135132
features = [Feature(FeatureTypes.FACE_DETECTION, limit)]
136-
annotations = self._detect_annotation(features)
133+
annotations = self.detect(features)
137134
return annotations[0].faces
138135

139136
def detect_labels(self, limit=10):
@@ -146,7 +143,7 @@ def detect_labels(self, limit=10):
146143
:returns: List of :class:`~google.cloud.vision.entity.EntityAnnotation`
147144
"""
148145
features = [Feature(FeatureTypes.LABEL_DETECTION, limit)]
149-
annotations = self._detect_annotation(features)
146+
annotations = self.detect(features)
150147
return annotations[0].labels
151148

152149
def detect_landmarks(self, limit=10):
@@ -160,7 +157,7 @@ def detect_landmarks(self, limit=10):
160157
:class:`~google.cloud.vision.entity.EntityAnnotation`.
161158
"""
162159
features = [Feature(FeatureTypes.LANDMARK_DETECTION, limit)]
163-
annotations = self._detect_annotation(features)
160+
annotations = self.detect(features)
164161
return annotations[0].landmarks
165162

166163
def detect_logos(self, limit=10):
@@ -174,7 +171,7 @@ def detect_logos(self, limit=10):
174171
:class:`~google.cloud.vision.entity.EntityAnnotation`.
175172
"""
176173
features = [Feature(FeatureTypes.LOGO_DETECTION, limit)]
177-
annotations = self._detect_annotation(features)
174+
annotations = self.detect(features)
178175
return annotations[0].logos
179176

180177
def detect_properties(self, limit=10):
@@ -188,7 +185,7 @@ def detect_properties(self, limit=10):
188185
:class:`~google.cloud.vision.color.ImagePropertiesAnnotation`.
189186
"""
190187
features = [Feature(FeatureTypes.IMAGE_PROPERTIES, limit)]
191-
annotations = self._detect_annotation(features)
188+
annotations = self.detect(features)
192189
return annotations[0].properties
193190

194191
def detect_safe_search(self, limit=10):
@@ -202,7 +199,7 @@ def detect_safe_search(self, limit=10):
202199
:class:`~google.cloud.vision.sage.SafeSearchAnnotation`.
203200
"""
204201
features = [Feature(FeatureTypes.SAFE_SEARCH_DETECTION, limit)]
205-
annotations = self._detect_annotation(features)
202+
annotations = self.detect(features)
206203
return annotations[0].safe_searches
207204

208205
def detect_text(self, limit=10):
@@ -216,5 +213,5 @@ def detect_text(self, limit=10):
216213
:class:`~google.cloud.vision.entity.EntityAnnotation`.
217214
"""
218215
features = [Feature(FeatureTypes.TEXT_DETECTION, limit)]
219-
annotations = self._detect_annotation(features)
216+
annotations = self.detect(features)
220217
return annotations[0].texts

packages/google-cloud-vision/unit_tests/test__gax.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ def test_annotation(self):
5454
spec_set=['batch_annotate_images'], **mock_response)
5555

5656
with mock.patch('google.cloud.vision._gax.Annotations') as mock_anno:
57-
gax_api.annotate(image, [feature])
57+
images = ((image, [feature]),)
58+
gax_api.annotate(images)
5859
mock_anno.from_pb.assert_called_with('mock response data')
5960
gax_api._annotator_client.batch_annotate_images.assert_called()
6061

@@ -78,7 +79,8 @@ def test_annotate_no_results(self):
7879
gax_api._annotator_client = mock.Mock(
7980
spec_set=['batch_annotate_images'], **mock_response)
8081
with mock.patch('google.cloud.vision._gax.Annotations'):
81-
response = gax_api.annotate(image, [feature])
82+
images = ((image, [feature]),)
83+
response = gax_api.annotate(images)
8284
self.assertEqual(len(response), 0)
8385
self.assertIsInstance(response, list)
8486

@@ -109,7 +111,8 @@ def test_annotate_multiple_results(self):
109111
gax_api._annotator_client = mock.Mock(
110112
spec_set=['batch_annotate_images'])
111113
gax_api._annotator_client.batch_annotate_images.return_value = response
112-
responses = gax_api.annotate(image, [feature])
114+
images = ((image, [feature]),)
115+
responses = gax_api.annotate(images)
113116

114117
self.assertEqual(len(responses), 2)
115118
self.assertIsInstance(responses[0], Annotations)

packages/google-cloud-vision/unit_tests/test__http.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ def test_call_annotate_with_no_results(self):
4444
http_api = self._make_one(client)
4545
http_api._connection = mock.Mock(spec_set=['api_request'])
4646
http_api._connection.api_request.return_value = {'responses': []}
47-
response = http_api.annotate(image, [feature])
47+
images = ((image, [feature]),)
48+
response = http_api.annotate(images)
4849
self.assertEqual(len(response), 0)
4950
self.assertIsInstance(response, list)
5051

@@ -63,7 +64,8 @@ def test_call_annotate_with_more_than_one_result(self):
6364
http_api = self._make_one(client)
6465
http_api._connection = mock.Mock(spec_set=['api_request'])
6566
http_api._connection.api_request.return_value = MULTIPLE_RESPONSE
66-
responses = http_api.annotate(image, [feature])
67+
images = ((image, [feature]),)
68+
responses = http_api.annotate(images)
6769

6870
self.assertEqual(len(responses), 2)
6971
image_one = responses[0]
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright 2017 Google Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import mock
18+
19+
PROJECT = 'PROJECT'
20+
21+
22+
def _make_credentials():
23+
import google.auth.credentials
24+
return mock.Mock(spec=google.auth.credentials.Credentials)
25+
26+
27+
class TestBatch(unittest.TestCase):
28+
@staticmethod
29+
def _get_target_class():
30+
from google.cloud.vision.batch import Batch
31+
32+
return Batch
33+
34+
def _make_one(self, *args, **kw):
35+
return self._get_target_class()(*args, **kw)
36+
37+
def test_ctor(self):
38+
from google.cloud.vision.feature import Feature
39+
from google.cloud.vision.feature import FeatureTypes
40+
from google.cloud.vision.image import Image
41+
42+
client = mock.Mock()
43+
image = Image(client, source_uri='gs://images/imageone.jpg')
44+
face_feature = Feature(FeatureTypes.FACE_DETECTION, 5)
45+
logo_feature = Feature(FeatureTypes.LOGO_DETECTION, 3)
46+
47+
batch = self._make_one(client)
48+
batch.add_image(image, [logo_feature, face_feature])
49+
self.assertEqual(len(batch.images), 1)
50+
self.assertEqual(len(batch.images[0]), 2)
51+
self.assertIsInstance(batch.images[0][0], Image)
52+
self.assertEqual(len(batch.images[0][1]), 2)
53+
self.assertIsInstance(batch.images[0][1][0], Feature)
54+
self.assertIsInstance(batch.images[0][1][1], Feature)
55+
56+
def test_batch_from_client(self):
57+
from google.cloud.vision.client import Client
58+
from google.cloud.vision.feature import Feature
59+
from google.cloud.vision.feature import FeatureTypes
60+
61+
creds = _make_credentials()
62+
client = Client(project=PROJECT, credentials=creds)
63+
64+
image_one = client.image(source_uri='gs://images/imageone.jpg')
65+
image_two = client.image(source_uri='gs://images/imagtwo.jpg')
66+
face_feature = Feature(FeatureTypes.FACE_DETECTION, 5)
67+
logo_feature = Feature(FeatureTypes.LOGO_DETECTION, 3)
68+
client._vision_api_internal = mock.Mock()
69+
client._vision_api_internal.annotate.return_value = True
70+
batch = client.batch()
71+
batch.add_image(image_one, [face_feature])
72+
batch.add_image(image_two, [logo_feature, face_feature])
73+
images = batch.images
74+
self.assertEqual(len(images), 2)
75+
self.assertTrue(batch.detect())
76+
self.assertEqual(len(batch.images), 0)
77+
client._vision_api_internal.annotate.assert_called_with(images)

packages/google-cloud-vision/unit_tests/test_client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class TestClient(unittest.TestCase):
3333
@staticmethod
3434
def _get_target_class():
3535
from google.cloud.vision.client import Client
36+
3637
return Client
3738

3839
def _make_one(self, *args, **kw):
@@ -104,7 +105,8 @@ def test_face_annotation(self):
104105
features = [Feature(feature_type=FeatureTypes.FACE_DETECTION,
105106
max_results=3)]
106107
image = client.image(content=IMAGE_CONTENT)
107-
api_response = client._vision_api.annotate(image, features)
108+
images = ((image, features),)
109+
api_response = client._vision_api.annotate(images)
108110

109111
self.assertEqual(len(api_response), 1)
110112
response = api_response[0]

0 commit comments

Comments
 (0)