Skip to content

Commit e36e986

Browse files
authored
Spanner: add support for session / pool labels (googleapis#5734)
1 parent 06e860a commit e36e986

7 files changed

Lines changed: 609 additions & 342 deletions

File tree

spanner/google/cloud/spanner_v1/database.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,13 +272,16 @@ def drop(self):
272272
metadata = _metadata_with_prefix(self.name)
273273
api.drop_database(self.name, metadata=metadata)
274274

275-
def session(self):
275+
def session(self, labels=None):
276276
"""Factory to create a session for this database.
277277
278+
:type labels: dict (str -> str) or None
279+
:param labels: (Optional) user-assigned labels for the session.
280+
278281
:rtype: :class:`~google.cloud.spanner_v1.session.Session`
279282
:returns: a session bound to this database.
280283
"""
281-
return Session(self)
284+
return Session(self, labels=labels)
282285

283286
def snapshot(self, **kw):
284287
"""Return an object which wraps a snapshot.

spanner/google/cloud/spanner_v1/pool.py

Lines changed: 62 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,28 @@
2626

2727

2828
class AbstractSessionPool(object):
29-
"""Specifies required API for concrete session pool implementations."""
29+
"""Specifies required API for concrete session pool implementations.
3030
31+
:type labels: dict (str -> str) or None
32+
:param labels: (Optional) user-assigned labels for sessions created
33+
by the pool.
34+
"""
3135
_database = None
3236

37+
def __init__(self, labels=None):
38+
if labels is None:
39+
labels = {}
40+
self._labels = labels
41+
42+
@property
43+
def labels(self):
44+
"""User-assigned labels for sesions created by the pool.
45+
46+
:rtype: dict (str -> str)
47+
:returns: labels assigned by the user
48+
"""
49+
return self._labels
50+
3351
def bind(self, database):
3452
"""Associate the pool with a database.
3553
@@ -80,6 +98,16 @@ def clear(self):
8098
"""
8199
raise NotImplementedError()
82100

101+
def _new_session(self):
102+
"""Helper for concrete methods creating session instances.
103+
104+
:rtype: :class:`~google.cloud.spanner_v1.session.Session`
105+
:returns: new session instance.
106+
"""
107+
if self.labels:
108+
return self._database.session(labels=self.labels)
109+
return self._database.session()
110+
83111
def session(self, **kwargs):
84112
"""Check out a session from the pool.
85113
@@ -115,11 +143,17 @@ class FixedSizePool(AbstractSessionPool):
115143
:type default_timeout: int
116144
:param default_timeout: default timeout, in seconds, to wait for
117145
a returned session.
146+
147+
:type labels: dict (str -> str) or None
148+
:param labels: (Optional) user-assigned labels for sessions created
149+
by the pool.
118150
"""
119151
DEFAULT_SIZE = 10
120152
DEFAULT_TIMEOUT = 10
121153

122-
def __init__(self, size=DEFAULT_SIZE, default_timeout=DEFAULT_TIMEOUT):
154+
def __init__(self, size=DEFAULT_SIZE, default_timeout=DEFAULT_TIMEOUT,
155+
labels=None):
156+
super(FixedSizePool, self).__init__(labels=labels)
123157
self.size = size
124158
self.default_timeout = default_timeout
125159
self._sessions = queue.Queue(size)
@@ -134,7 +168,7 @@ def bind(self, database):
134168
self._database = database
135169

136170
while not self._sessions.full():
137-
session = database.session()
171+
session = self._new_session()
138172
session.create()
139173
self._sessions.put(session)
140174

@@ -198,9 +232,14 @@ class BurstyPool(AbstractSessionPool):
198232
199233
:type target_size: int
200234
:param target_size: max pool size
235+
236+
:type labels: dict (str -> str) or None
237+
:param labels: (Optional) user-assigned labels for sessions created
238+
by the pool.
201239
"""
202240

203-
def __init__(self, target_size=10):
241+
def __init__(self, target_size=10, labels=None):
242+
super(BurstyPool, self).__init__(labels=labels)
204243
self.target_size = target_size
205244
self._database = None
206245
self._sessions = queue.Queue(target_size)
@@ -224,11 +263,11 @@ def get(self):
224263
try:
225264
session = self._sessions.get_nowait()
226265
except queue.Empty:
227-
session = self._database.session()
266+
session = self._new_session()
228267
session.create()
229268
else:
230269
if not session.exists():
231-
session = self._database.session()
270+
session = self._new_session()
232271
session.create()
233272
return session
234273

@@ -290,9 +329,15 @@ class PingingPool(AbstractSessionPool):
290329
291330
:type ping_interval: int
292331
:param ping_interval: interval at which to ping sessions.
332+
333+
:type labels: dict (str -> str) or None
334+
:param labels: (Optional) user-assigned labels for sessions created
335+
by the pool.
293336
"""
294337

295-
def __init__(self, size=10, default_timeout=10, ping_interval=3000):
338+
def __init__(self, size=10, default_timeout=10, ping_interval=3000,
339+
labels=None):
340+
super(PingingPool, self).__init__(labels=labels)
296341
self.size = size
297342
self.default_timeout = default_timeout
298343
self._delta = datetime.timedelta(seconds=ping_interval)
@@ -308,7 +353,7 @@ def bind(self, database):
308353
self._database = database
309354

310355
for _ in xrange(self.size):
311-
session = database.session()
356+
session = self._new_session()
312357
session.create()
313358
self.put(session)
314359

@@ -330,7 +375,7 @@ def get(self, timeout=None): # pylint: disable=arguments-differ
330375

331376
if _NOW() > ping_after:
332377
if not session.exists():
333-
session = self._database.session()
378+
session = self._new_session()
334379
session.create()
335380

336381
return session
@@ -373,7 +418,7 @@ def ping(self):
373418
self._sessions.put((ping_after, session))
374419
break
375420
if not session.exists(): # stale
376-
session = self._database.session()
421+
session = self._new_session()
377422
session.create()
378423
# Re-add to queue with new expiration
379424
self.put(session)
@@ -400,13 +445,18 @@ class TransactionPingingPool(PingingPool):
400445
401446
:type ping_interval: int
402447
:param ping_interval: interval at which to ping sessions.
448+
449+
:type labels: dict (str -> str) or None
450+
:param labels: (Optional) user-assigned labels for sessions created
451+
by the pool.
403452
"""
404453

405-
def __init__(self, size=10, default_timeout=10, ping_interval=3000):
454+
def __init__(self, size=10, default_timeout=10, ping_interval=3000,
455+
labels=None):
406456
self._pending_sessions = queue.Queue()
407457

408458
super(TransactionPingingPool, self).__init__(
409-
size, default_timeout, ping_interval)
459+
size, default_timeout, ping_interval, labels=labels)
410460

411461
self.begin_pending_transactions()
412462

spanner/google/cloud/spanner_v1/session.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,19 @@ class Session(object):
4444
4545
:type database: :class:`~google.cloud.spanner_v1.database.Database`
4646
:param database: The database to which the session is bound.
47+
48+
:type labels: dict (str -> str)
49+
:param labels: (Optional) User-assigned labels for the session.
4750
"""
4851

4952
_session_id = None
5053
_transaction = None
5154

52-
def __init__(self, database):
55+
def __init__(self, database, labels=None):
5356
self._database = database
57+
if labels is None:
58+
labels = {}
59+
self._labels = labels
5460

5561
def __lt__(self, other):
5662
return self._session_id < other._session_id
@@ -60,6 +66,15 @@ def session_id(self):
6066
"""Read-only ID, set by the back-end during :meth:`create`."""
6167
return self._session_id
6268

69+
@property
70+
def labels(self):
71+
"""User-assigned labels for the session.
72+
73+
:rtype: dict (str -> str)
74+
:returns: the labels dict (empty if no labels were assigned.
75+
"""
76+
return self._labels
77+
6378
@property
6479
def name(self):
6580
"""Session name used in requests.
@@ -93,7 +108,14 @@ def create(self):
93108
raise ValueError('Session ID already set by back-end')
94109
api = self._database.spanner_api
95110
metadata = _metadata_with_prefix(self._database.name)
96-
session_pb = api.create_session(self._database.name, metadata=metadata)
111+
kw = {}
112+
if self._labels:
113+
kw = {'session': {'labels': self._labels}}
114+
session_pb = api.create_session(
115+
self._database.name,
116+
metadata=metadata,
117+
**kw
118+
)
97119
self._session_id = session_pb.name.split('/')[-1]
98120

99121
def exists(self):

spanner/tests/system/test_system.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ class TestDatabaseAPI(unittest.TestCase, _TestData):
239239

240240
@classmethod
241241
def setUpClass(cls):
242-
pool = BurstyPool()
242+
pool = BurstyPool(labels={'testcase': 'database_api'})
243243
cls._db = Config.INSTANCE.database(
244244
cls.DATABASE_NAME, ddl_statements=DDL_STATEMENTS, pool=pool)
245245
operation = cls._db.create()
@@ -264,7 +264,7 @@ def test_list_databases(self):
264264
self.assertTrue(self._db.name in database_names)
265265

266266
def test_create_database(self):
267-
pool = BurstyPool()
267+
pool = BurstyPool(labels={'testcase': 'create_database'})
268268
temp_db_id = 'temp_db' + unique_resource_id('_')
269269
temp_db = Config.INSTANCE.database(temp_db_id, pool=pool)
270270
operation = temp_db.create()
@@ -311,7 +311,7 @@ def test_table_not_found(self):
311311
'https://github.com/GoogleCloudPlatform/google-cloud-python/issues/'
312312
'5629'))
313313
def test_update_database_ddl(self):
314-
pool = BurstyPool()
314+
pool = BurstyPool(labels={'testcase': 'update_database_ddl'})
315315
temp_db_id = 'temp_db' + unique_resource_id('_')
316316
temp_db = Config.INSTANCE.database(temp_db_id, pool=pool)
317317
create_op = temp_db.create()
@@ -434,7 +434,7 @@ class TestSessionAPI(unittest.TestCase, _TestData):
434434

435435
@classmethod
436436
def setUpClass(cls):
437-
pool = BurstyPool()
437+
pool = BurstyPool(labels={'testcase': 'session_api'})
438438
cls._db = Config.INSTANCE.database(
439439
cls.DATABASE_NAME, ddl_statements=DDL_STATEMENTS, pool=pool)
440440
operation = cls._db.create()
@@ -902,7 +902,7 @@ def test_read_w_index(self):
902902
EXTRA_DDL = [
903903
'CREATE INDEX contacts_by_last_name ON contacts(last_name)',
904904
]
905-
pool = BurstyPool()
905+
pool = BurstyPool(labels={'testcase': 'read_w_index'})
906906
temp_db = Config.INSTANCE.database(
907907
'test_read' + unique_resource_id('_'),
908908
ddl_statements=DDL_STATEMENTS + EXTRA_DDL,

spanner/tests/unit/test_database.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -596,7 +596,7 @@ def test_drop_success(self):
596596
self.assertEqual(
597597
metadata, [('google-cloud-resource-prefix', database.name)])
598598

599-
def test_session_factory(self):
599+
def test_session_factory_defaults(self):
600600
from google.cloud.spanner_v1.session import Session
601601

602602
client = _Client()
@@ -609,6 +609,23 @@ def test_session_factory(self):
609609
self.assertTrue(isinstance(session, Session))
610610
self.assertIs(session.session_id, None)
611611
self.assertIs(session._database, database)
612+
self.assertEqual(session.labels, {})
613+
614+
def test_session_factory_w_labels(self):
615+
from google.cloud.spanner_v1.session import Session
616+
617+
client = _Client()
618+
instance = _Instance(self.INSTANCE_NAME, client=client)
619+
pool = _Pool()
620+
labels = {'foo': 'bar'}
621+
database = self._make_one(self.DATABASE_ID, instance, pool=pool)
622+
623+
session = database.session(labels=labels)
624+
625+
self.assertTrue(isinstance(session, Session))
626+
self.assertIs(session.session_id, None)
627+
self.assertIs(session._database, database)
628+
self.assertEqual(session.labels, labels)
612629

613630
def test_snapshot_defaults(self):
614631
from google.cloud.spanner_v1.database import SnapshotCheckout

0 commit comments

Comments
 (0)