From 01cf42b7813de0fde92bbf77b5021b59b941153b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 24 May 2023 11:07:33 -0500 Subject: [PATCH 1/2] feat: speed up the service registry Every lookup was doing .lower() on the input. Since we only feed this data in via handlers and core, we should lower the name before feeding it into the registry to avoid doing it 1000s of times --- src/zeroconf/_core.py | 2 +- src/zeroconf/_handlers.py | 17 ++++---- src/zeroconf/_services/registry.py | 4 +- tests/services/test_registry.py | 19 --------- tests/test_handlers.py | 68 ++++++++++++++++++++++++++++++ 5 files changed, 80 insertions(+), 30 deletions(-) diff --git a/src/zeroconf/_core.py b/src/zeroconf/_core.py index 18823ef2..a55f55e8 100644 --- a/src/zeroconf/_core.py +++ b/src/zeroconf/_core.py @@ -745,7 +745,7 @@ async def async_unregister_service(self, info: ServiceInfo) -> Awaitable: # goodbye packets for the address records assert info.server is not None - entries = self.registry.async_get_infos_server(info.server) + entries = self.registry.async_get_infos_server(info.server.lower()) broadcast_addresses = not bool(entries) return asyncio.ensure_future( self._async_broadcast_service(info, _UNREGISTER_TIME, 0, broadcast_addresses) diff --git a/src/zeroconf/_handlers.py b/src/zeroconf/_handlers.py index 240deb47..159fd0d5 100644 --- a/src/zeroconf/_handlers.py +++ b/src/zeroconf/_handlers.py @@ -255,10 +255,10 @@ def _add_service_type_enumeration_query_answers( answer_set[dns_pointer] = set() def _add_pointer_answers( - self, name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float + self, lower_name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float ) -> None: """Answer PTR/ANY question.""" - for service in self.registry.async_get_infos_type(name): + for service in self.registry.async_get_infos_type(lower_name): # Add recommended additional answers according to # https://tools.ietf.org/html/rfc6763#section-12.1. dns_pointer = service.dns_pointer(created=now) @@ -270,14 +270,14 @@ def _add_pointer_answers( def _add_address_answers( self, - name: str, + lower_name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float, type_: int, ) -> None: """Answer A/AAAA/ANY question.""" - for service in self.registry.async_get_infos_server(name): + for service in self.registry.async_get_infos_server(lower_name): answers: List[DNSAddress] = [] additionals: Set[DNSRecord] = set() seen_types: Set[int] = set() @@ -305,21 +305,22 @@ def _answer_question( now: float, ) -> _AnswerWithAdditionalsType: answer_set: _AnswerWithAdditionalsType = {} + question_lower_name = question.name.lower() - if question.type == _TYPE_PTR and question.name.lower() == _SERVICE_TYPE_ENUMERATION_NAME: + if question.type == _TYPE_PTR and question_lower_name == _SERVICE_TYPE_ENUMERATION_NAME: self._add_service_type_enumeration_query_answers(answer_set, known_answers, now) return answer_set type_ = question.type if type_ in (_TYPE_PTR, _TYPE_ANY): - self._add_pointer_answers(question.name, answer_set, known_answers, now) + self._add_pointer_answers(question_lower_name, answer_set, known_answers, now) if type_ in (_TYPE_A, _TYPE_AAAA, _TYPE_ANY): - self._add_address_answers(question.name, answer_set, known_answers, now, type_) + self._add_address_answers(question_lower_name, answer_set, known_answers, now, type_) if type_ in (_TYPE_SRV, _TYPE_TXT, _TYPE_ANY): - service = self.registry.async_get_info_name(question.name) + service = self.registry.async_get_info_name(question_lower_name) if service is not None: if type_ in (_TYPE_SRV, _TYPE_ANY): # Add recommended additional answers according to diff --git a/src/zeroconf/_services/registry.py b/src/zeroconf/_services/registry.py index b3dba674..1c4ad085 100644 --- a/src/zeroconf/_services/registry.py +++ b/src/zeroconf/_services/registry.py @@ -60,7 +60,7 @@ def async_get_service_infos(self) -> List[ServiceInfo]: def async_get_info_name(self, name: str) -> Optional[ServiceInfo]: """Return all ServiceInfo for the name.""" - return self._services.get(name.lower()) + return self._services.get(name) def async_get_types(self) -> List[str]: """Return all types.""" @@ -76,7 +76,7 @@ def async_get_infos_server(self, server: str) -> List[ServiceInfo]: def _async_get_by_index(self, records: Dict[str, List], key: str) -> List[ServiceInfo]: """Return all ServiceInfo matching the index.""" - return [self._services[name] for name in records.get(key.lower(), [])] + return [self._services[name] for name in records.get(key, [])] def _add(self, info: ServiceInfo) -> None: """Add a new service under the lock.""" diff --git a/tests/services/test_registry.py b/tests/services/test_registry.py index 3207b14e..f8656e2f 100644 --- a/tests/services/test_registry.py +++ b/tests/services/test_registry.py @@ -110,22 +110,3 @@ def test_lookups_upper_case_by_lower_case(self): assert registry.async_get_infos_type(type_.lower()) == [info] assert registry.async_get_infos_server("ash-2.local.") == [info] assert registry.async_get_types() == [type_.lower()] - - def test_lookups_lower_case_by_upper_case(self): - type_ = "_test-srvc-type._tcp.local." - name = "xxxyyy" - registration_name = f"{name}.{type_}" - - desc = {'path': '/~paulsm/'} - info = ServiceInfo( - type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] - ) - - registry = r.ServiceRegistry() - registry.async_add(info) - - assert registry.async_get_service_infos() == [info] - assert registry.async_get_info_name(registration_name.upper()) == info - assert registry.async_get_infos_type(type_.upper()) == [info] - assert registry.async_get_infos_server("ASH-2.local.") == [info] - assert registry.async_get_types() == [type_] diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 0a976d3d..e3262a85 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -340,6 +340,32 @@ def test_aaaa_query(): zc.close() +@unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') +@unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') +def test_aaaa_query_upper_case(): + """Test that queries for AAAA records work and should respond right away with an upper case name.""" + zc = Zeroconf(interfaces=['127.0.0.1']) + type_ = "_knownaaaservice._tcp.local." + name = "knownname" + registration_name = f"{name}.{type_}" + desc = {'path': '/~paulsm/'} + server_name = "ash-2.local." + ipv6_address = socket.inet_pton(socket.AF_INET6, "2001:db8::1") + info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, server_name, addresses=[ipv6_address]) + zc.registry.async_add(info) + + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(server_name.upper(), const._TYPE_AAAA, const._CLASS_IN) + generated.add_question(question) + packets = generated.packets() + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + mcast_answers = list(question_answers.mcast_now) + assert mcast_answers[0].address == ipv6_address # type: ignore[attr-defined] + # unregister + zc.registry.async_remove(info) + zc.close() + + @unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') def test_a_and_aaaa_record_fate_sharing(): @@ -481,6 +507,48 @@ async def test_probe_answered_immediately(): zc.close() +@pytest.mark.asyncio +async def test_probe_answered_immediately_with_uppercase_name(): + """Verify probes are responded to immediately with an uppercase name.""" + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + + # service definition + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = f"{name}.{type_}" + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + zc.registry.async_add(info) + query = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(info.type.upper(), const._TYPE_PTR, const._CLASS_IN) + query.add_question(question) + query.add_authorative_answer(info.dns_pointer()) + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], False + ) + assert not question_answers.ucast + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + assert question_answers.mcast_now + + query = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) + question.unicast = True + query.add_question(question) + query.add_authorative_answer(info.dns_pointer()) + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], False + ) + assert question_answers.ucast + assert question_answers.mcast_now + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + zc.close() + + def test_qu_response(): """Handle multicast incoming with the QU bit set.""" # instantiate a zeroconf instance From 8a00ae975bcaa511ba04b4694e363928c816f64b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 24 May 2023 11:10:50 -0500 Subject: [PATCH 2/2] fix: add more tests --- tests/test_handlers.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index e3262a85..c1c0a9a7 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -910,6 +910,45 @@ def test_known_answer_supression_service_type_enumeration_query(): zc.close() +def test_upper_case_enumeration_query(): + zc = Zeroconf(interfaces=['127.0.0.1']) + type_ = "_otherknown._tcp.local." + name = "knownname" + registration_name = f"{name}.{type_}" + desc = {'path': '/~paulsm/'} + server_name = "ash-2.local." + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")] + ) + zc.registry.async_add(info) + + type_2 = "_otherknown2._tcp.local." + name = "knownname" + registration_name2 = f"{name}.{type_2}" + desc = {'path': '/~paulsm/'} + server_name2 = "ash-3.local." + info2 = ServiceInfo( + type_2, registration_name2, 80, 0, 0, desc, server_name2, addresses=[socket.inet_aton("10.0.1.2")] + ) + zc.registry.async_add(info2) + _clear_cache(zc) + + # Test PTR supression + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(const._SERVICE_TYPE_ENUMERATION_NAME.upper(), const._TYPE_PTR, const._CLASS_IN) + generated.add_question(question) + packets = generated.packets() + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert not question_answers.ucast + assert not question_answers.mcast_now + assert question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + # unregister + zc.registry.async_remove(info) + zc.registry.async_remove(info2) + zc.close() + + # This test uses asyncio because it needs to access the cache directly # which is not threadsafe @pytest.mark.asyncio