diff --git a/firebase_admin/_db_utils.py b/firebase_admin/_db_utils.py new file mode 100644 index 000000000..8e352238c --- /dev/null +++ b/firebase_admin/_db_utils.py @@ -0,0 +1,88 @@ +# Copyright 2023 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Internal utilities for Firebase Realtime Database module""" + +import time +import random +import math + +_PUSH_CHARS = '-0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz' + +def time_now(): + return int(time.time()*1000) + +def _generate_next_push_id(): + """Creates a unique push id generator. + + Creates 20-character string identifiers with the following properties: + 1. They're based on timestamps so that they sort after any existing ids. + + 2. They contain 96-bits of random data after the timestamp so that IDs won't + collide with other clients' IDs. + + 3. They sort lexicographically*(so the timestamp is converted to characters + that will sort properly). + + 4. They're monotonically increasing. Even if you generate more than one in + the same timestamp, the latter ones will sort after the former ones. We do + this by using the previous random bits but "incrementing" them by 1 (only + in the case of a timestamp collision). + """ + + # Timestamp of last push, used to prevent local collisions if you push twice + # in one ms. + last_push_time = 0 + + # We generate 96-bits of randomness which get turned into 12 characters and + # appended to the timestamp to prevent collisions with other clients. We + # store the last characters we generated because in the event of a collision, + # we'll use those same characters except "incremented" by one. + last_rand_chars_indexes = [] + + def next_push_id(now): + nonlocal last_push_time + nonlocal last_rand_chars_indexes + is_duplicate_time = now == last_push_time + last_push_time = now + + push_id = '' + for _ in range(8): + push_id = _PUSH_CHARS[now % 64] + push_id + now = math.floor(now / 64) + + if not is_duplicate_time: + last_rand_chars_indexes = [] + for _ in range(12): + last_rand_chars_indexes.append(random.randrange(64)) + else: + for index in range(11, -1, -1): + if last_rand_chars_indexes[index] == 63: + last_rand_chars_indexes[index] = 0 + else: + break + if index != 0: + last_rand_chars_indexes[index] += 1 + elif index == 0 and last_rand_chars_indexes[index] != 0: + last_rand_chars_indexes[index] += 1 + + for index in range(12): + push_id += _PUSH_CHARS[last_rand_chars_indexes[index]] + + if len(push_id) != 20: + raise ValueError("push_id length should be 20") + return push_id + return next_push_id + +get_next_push_id = _generate_next_push_id() diff --git a/firebase_admin/db.py b/firebase_admin/db.py index 890968796..51b29116d 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -34,7 +34,7 @@ from firebase_admin import _http_client from firebase_admin import _sseclient from firebase_admin import _utils - +from firebase_admin import _db_utils _DB_ATTRIBUTE = '_database' _INVALID_PATH_CHARACTERS = '[].?#$' @@ -301,12 +301,13 @@ def set_if_unchanged(self, expected_etag, value): raise error - def push(self, value=''): + def push(self, value=None): """Creates a new child node. - The optional value argument can be used to provide an initial value for the child node. If - no value is provided, child node will have empty string as the default value. - + The optional value argument can be used to provide an initial value for the child node. + If you provide a value, a child node is created and the value written to that location. + If you don't provide a value, the child node is created but nothing is written to the + database and the child remains empty (but you can use the Reference elsewhere). Args: value: JSON-serializable initial value for the child node (optional). @@ -314,14 +315,15 @@ def push(self, value=''): Reference: A Reference representing the newly created child node. Raises: - ValueError: If the value is None. TypeError: If the value is not JSON-serializable. FirebaseError: If an error occurs while communicating with the remote database server. """ - if value is None: - raise ValueError('Value must not be None.') - output = self._client.body('post', self._add_suffix(), json=value) - push_id = output.get('name') + now = _db_utils.time_now() + push_id = _db_utils.get_next_push_id(now) + push_ref = self.child(push_id) + + if value is not None: + push_ref.set(value) return self.child(push_id) def update(self, value): diff --git a/integration/test_db.py b/integration/test_db.py index c448436d6..9fcca8bde 100644 --- a/integration/test_db.py +++ b/integration/test_db.py @@ -149,7 +149,7 @@ def test_push(self, testref): python = testref.parent ref = python.child('users').push() assert ref.path == '/_adminsdk/python/users/' + ref.key - assert ref.get() == '' + assert ref.get() is None def test_push_with_value(self, testref): python = testref.parent @@ -158,6 +158,25 @@ def test_push_with_value(self, testref): assert ref.path == '/_adminsdk/python/users/' + ref.key assert ref.get() == value + def test_push_to_local_ref(self, testref): + python = testref.parent + ref1 = python.child('games').push() + assert ref1.get() is None + ref2 = ref1.push("card") + assert ref2.parent.key == ref1.key + assert ref1.get() == {ref2.key: 'card'} + assert ref2.get() == 'card' + + def test_push_set_local_ref(self, testref): + python = testref.parent + ref1 = python.child('games').push().child('card') + ref2 = ref1.push() + assert ref2.get() is None + ref3 = ref1.push('heart') + ref2.set('spade') + assert ref2.get() == 'spade' + assert ref1.parent.get() == {'card': {ref2.key: 'spade', ref3.key: 'heart'}} + def test_set_primitive_value(self, testref): python = testref.parent ref = python.child('users').push() diff --git a/tests/test_db.py b/tests/test_db.py index aa2c83bd9..11ed3cd6b 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -19,6 +19,7 @@ import sys import time +from unittest import mock import pytest import firebase_admin @@ -145,7 +146,7 @@ def get(cls, ref): @classmethod def push(cls, ref): - ref.push() + ref.push({'foo': 'bar'}) @classmethod def set(cls, ref): @@ -179,6 +180,8 @@ class TestReference: 500: exceptions.InternalError, } + duplicate_timestamp = time.time() + @classmethod def setup_class(cls): firebase_admin.initialize_app(testutils.MockCredential(), {'databaseURL' : cls.test_url}) @@ -392,33 +395,46 @@ def test_set_invalid_update(self, update): @pytest.mark.parametrize('data', valid_values) def test_push(self, data): ref = db.reference('/test') - recorder = self.instrument(ref, json.dumps({'name' : 'testkey'})) + recorder = self.instrument(ref, json.dumps({})) child = ref.push(data) assert isinstance(child, db.Reference) - assert child.key == 'testkey' + assert len(child.key) == 20 assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == 'https://test.firebaseio.com/test.json' + assert recorder[0].method == 'PUT' + assert recorder[0].url == f'https://test.firebaseio.com/test/{child.key}.json?print=silent' assert json.loads(recorder[0].body.decode()) == data assert recorder[0].headers['Authorization'] == 'Bearer mock-token' assert recorder[0].headers['User-Agent'] == db._USER_AGENT def test_push_default(self): ref = db.reference('/test') - recorder = self.instrument(ref, json.dumps({'name' : 'testkey'})) - assert ref.push().key == 'testkey' - assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == 'https://test.firebaseio.com/test.json' - assert json.loads(recorder[0].body.decode()) == '' - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT + recorder = self.instrument(ref, json.dumps({})) + child = ref.push() + assert isinstance(child, db.Reference) + assert len(child.key) == 20 + assert len(recorder) == 0 - def test_push_none_value(self): + @pytest.mark.parametrize('data', valid_values) + @mock.patch('time.time', mock.MagicMock(return_value=duplicate_timestamp)) + def test_push_duplicate_timestamp(self, data): ref = db.reference('/test') - self.instrument(ref, '') - with pytest.raises(ValueError): - ref.push(None) + recorder = self.instrument(ref, json.dumps({})) + child = [] + child.append(ref.push(data)) + child.append(ref.push(data)) + key1 = child[0].key + key2 = child[1].key + # First 8 digits are the encoded timestamp + assert key1[:8] == key2[:8] + assert key2 > key1 + assert len(recorder) == 2 + for index, record in enumerate(recorder): + assert record.method == 'PUT' + assert record.url == \ + f'https://test.firebaseio.com/test/{child[index].key}.json?print=silent' + assert json.loads(record.body.decode()) == data + assert record.headers['Authorization'] == 'Bearer mock-token' + assert record.headers['User-Agent'] == db._USER_AGENT def test_delete(self): ref = db.reference('/test')