Skip to content
Next Next commit
Add 'Session.labels' property.
Read-only, set via ctor.
  • Loading branch information
tseaver committed Aug 2, 2018
commit d69bfe234fe9977c4427f2a3c0232696b3c763af
17 changes: 16 additions & 1 deletion spanner/google/cloud/spanner_v1/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,19 @@ class Session(object):

:type database: :class:`~google.cloud.spanner_v1.database.Database`
:param database: The database to which the session is bound.

:type labels: dict (str -> str)
:param labels: (Optional) User-assigned labels for the session.
"""

_session_id = None
_transaction = None

def __init__(self, database):
def __init__(self, database, labels=None):
self._database = database
if labels is None:
labels = {}
self._labels = labels

def __lt__(self, other):
return self._session_id < other._session_id
Expand All @@ -60,6 +66,15 @@ def session_id(self):
"""Read-only ID, set by the back-end during :meth:`create`."""
return self._session_id

@property
def labels(self):
"""User-assigned labels for the session.

:rtype: dict (str -> str)
:returns: the labels dict (empty if no labels were assigned.
"""
return self._labels

@property
def name(self):
"""Session name used in requests.
Expand Down
11 changes: 10 additions & 1 deletion spanner/tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,20 @@ def _getTargetClass(self):
def _make_one(self, *args, **kwargs):
return self._getTargetClass()(*args, **kwargs)

def test_constructor(self):
def test_constructor_wo_labels(self):
database = _Database(self.DATABASE_NAME)
session = self._make_one(database)
self.assertIs(session.session_id, None)
self.assertIs(session._database, database)
self.assertEqual(session.labels, {})

def test_constructor_w_labels(self):
database = _Database(self.DATABASE_NAME)
labels = {'foo': 'bar'}
session = self._make_one(database, labels=labels)
self.assertIs(session.session_id, None)
self.assertIs(session._database, database)
self.assertEqual(session.labels, labels)

def test___lt___(self):
database = _Database(self.DATABASE_NAME)
Expand Down