"""Multicast DNS Service Discovery for Python, v0.14-wmcbrine
Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
This module provides a framework for the use of DNS Service Discovery
using IP multicast.
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
USA
"""
from __future__ import annotations
import asyncio
import logging
import random
import sys
import threading
from collections.abc import Awaitable
from types import TracebackType
from ._cache import DNSCache
from ._dns import DNSQuestion, DNSQuestionType
from ._engine import AsyncEngine
from ._exceptions import NonUniqueNameException, NotRunningException
from ._handlers.multicast_outgoing_queue import MulticastOutgoingQueue
from ._handlers.query_handler import QueryHandler
from ._handlers.record_manager import RecordManager
from ._history import QuestionHistory
from ._logger import QuietLogger, log
from ._protocol.outgoing import DNSOutgoing
from ._services import ServiceListener
from ._services.browser import ServiceBrowser
from ._services.info import (
AsyncServiceInfo,
ServiceInfo,
instance_name_from_service_info,
)
from ._services.registry import ServiceRegistry
from ._transport import _WrappedTransport
from ._updates import RecordUpdateListener
from ._utils.asyncio import (
_resolve_all_futures_to_none,
await_awaitable,
get_running_loop,
run_coro_with_timeout,
shutdown_loop,
wait_for_future_set_or_timeout,
wait_future_or_timeout,
)
from ._utils.name import service_type_name
from ._utils.net import (
InterfaceChoice,
InterfacesType,
IPVersion,
autodetect_ip_version,
can_send_to,
create_sockets,
)
from ._utils.time import current_time_millis, millis_to_seconds
from .const import (
_CHECK_TIME,
_CLASS_IN,
_CLASS_UNIQUE,
_FLAGS_AA,
_FLAGS_QR_QUERY,
_FLAGS_QR_RESPONSE,
_MAX_MSG_ABSOLUTE,
_MDNS_ADDR,
_MDNS_ADDR6,
_MDNS_PORT,
_ONE_SECOND,
_REGISTER_TIME,
_STARTUP_TIMEOUT,
_TYPE_PTR,
_UNREGISTER_TIME,
)
# The maximum amount of time to delay a multicast
# response in order to aggregate answers.
# Used as the aggregation window of Zeroconf.out_queue.
_AGGREGATION_DELAY = 500 # ms
# The maximum amount of time to delay a multicast
# response in order to aggregate answers after
# it has already been delayed to protect the network
# from excessive traffic (Zeroconf.out_delay_queue).
# We use a shorter time window here as we want to
# _try_ to answer all queries in under 1350ms while
# protecting the network from excessive traffic, so that
# a service info request with two questions
# can be answered within the default timeout of
# 3000ms.
_PROTECTED_AGGREGATION_DELAY = 200 # ms
# Number of times each announcement, probe and goodbye is sent
# (see Zeroconf._async_broadcast_service, Zeroconf.async_check_service
# and Zeroconf.async_unregister_all_services).
_REGISTER_BROADCASTS = 3
# RFC 6762 §8.1 thundering-herd avoidance: wait a random
# 0-250ms before the first probe so that responders started
# at the same time don't collide. We default to 150-250ms to
# preserve existing timing assumptions; tests on loopback
# may patch this lower via the `quick_timing` fixture.
# The value is an inclusive (low, high) range for random.randint.
_PROBE_RANDOM_DELAY_INTERVAL = (150, 250) # ms
def async_send_with_transport(
    log_debug: bool,
    transport: _WrappedTransport,
    packet: bytes,
    packet_num: int,
    out: DNSOutgoing,
    addr: str | None,
    port: int,
    v6_flow_scope: tuple[()] | tuple[int, int] = (),
) -> None:
    """Send one already-encoded packet over a single transport.

    :param log_debug: when true, emit a debug line describing the send (the
        caller evaluates the log level once instead of once per packet).
    :param transport: the wrapped transport to send on; its ``is_ipv6`` flag
        picks the default multicast group and the shape of the address tuple.
    :param packet: wire-format bytes to send.
    :param packet_num: zero-based index of ``packet`` within ``out``; it is
        logged one-based.
    :param out: the DNSOutgoing the packet was produced from (logging only).
    :param addr: destination address, or ``None`` for the mDNS multicast
        group that matches the transport's address family.
    :param port: destination port; a falsy value falls back to ``_MDNS_PORT``.
    :param v6_flow_scope: optional ``(flowinfo, scope_id)`` for IPv6. When it
        is empty on an IPv6 transport, both values are taken from the
        transport's bound socket name.

    Nothing is sent when ``can_send_to`` rejects the destination for this
    transport's address family.
    """
    is_v6 = transport.is_ipv6
    if addr is not None:
        destination = addr
    else:
        destination = _MDNS_ADDR6 if is_v6 else _MDNS_ADDR

    # Skip destinations this socket's address family cannot reach.
    if not can_send_to(is_v6, destination):
        return

    dest_port = port or _MDNS_PORT
    if log_debug:
        log.debug(
            "Sending to (%s, %d) via [socket %s (%s)] (%d bytes #%d) %r as %r...",
            destination,
            dest_port,
            transport.fileno,
            transport.sock_name,
            len(packet),
            packet_num + 1,
            out,
            packet,
        )

    flow_scope = v6_flow_scope
    if is_v6 and not flow_scope:
        # IPv6 sendto() takes (host, port, flowinfo, scope_id); borrow the
        # last two from the socket's own address:
        # https://docs.python.org/3.6/library/socket.html#socket-families
        _host, _bound_port, flowinfo, scope_id = transport.sock_name
        flow_scope = (flowinfo, scope_id)

    transport.transport.sendto(packet, (destination, dest_port, *flow_scope))
class Zeroconf(QuietLogger):
    """Implementation of Zeroconf Multicast DNS Service Discovery

    Supports registration, unregistration, queries and browsing.

    An instance either attaches to an already-running asyncio event loop or
    starts a private daemon thread that runs its own loop (see the
    ``use_asyncio`` argument of ``__init__``). The methods without an
    ``async_`` prefix hand their work to that loop (through
    ``call_soon_threadsafe`` or ``run_coro_with_timeout``).
    """

    def __init__(
        self,
        interfaces: InterfacesType = InterfaceChoice.All,
        unicast: bool = False,
        ip_version: IPVersion | None = None,
        apple_p2p: bool = False,
        use_asyncio: bool | None = None,
    ) -> None:
        """Creates an instance of the Zeroconf class, establishing
        multicast communications, listening and reaping threads.

        :param interfaces: :class:`InterfaceChoice` or a list of IP addresses
            (IPv4 and IPv6) and interface indexes (IPv6 only).

            IPv6 notes for non-POSIX systems:
            * `InterfaceChoice.All` is an alias for `InterfaceChoice.Default`
              on Python versions before 3.8.

            Also listening on loopback (``::1``) doesn't work, use a real address.
        :param unicast: stored as ``self.unicast`` and passed to
            ``create_sockets``; presumably selects unicast-only sockets
            (TODO confirm in ``create_sockets``).
        :param ip_version: IP versions to support. If `interfaces` is a list,
            the default is detected from it. Otherwise defaults to V4 only
            for backward compatibility.
        :param apple_p2p: use AWDL interface (only macOS)
        :param use_asyncio: explicitly control whether to attach to the running
            asyncio event loop (``True``) or run an internal thread with its
            own loop (``False``). ``None`` (default) keeps the historic
            behavior: attach if an event loop is running, otherwise start a
            thread. Set to ``False`` when running inside an environment that
            already has an event loop (e.g. Jupyter) but you want blocking
            semantics. ``True`` raises :class:`RuntimeError` immediately if no
            running event loop is found, instead of falling back to the thread.
        :raises RuntimeError: if ``apple_p2p`` is requested on a non-macOS
            platform, or if ``use_asyncio=True`` and no event loop is running.
        """
        if ip_version is None:
            ip_version = autodetect_ip_version(interfaces)

        self.done = False

        if apple_p2p and sys.platform != "darwin":
            raise RuntimeError("Option `apple_p2p` is not supported on non-Apple platforms.")
        # Validate before any sockets are created so a bad option leaks nothing.
        if use_asyncio is True and get_running_loop() is None:
            raise RuntimeError("use_asyncio=True requires a running asyncio event loop")

        self.unicast = unicast
        self._use_asyncio = use_asyncio
        listen_socket, respond_sockets = create_sockets(interfaces, unicast, ip_version, apple_p2p=apple_p2p)
        log.debug("Listen socket %s, respond sockets %s", listen_socket, respond_sockets)

        self.engine = AsyncEngine(self, listen_socket, respond_sockets)

        self.browsers: dict[ServiceListener, ServiceBrowser] = {}
        self.registry = ServiceRegistry()
        self.cache = DNSCache()
        self.question_history = QuestionHistory()

        # Two multicast response queues: an immediate one aggregated over
        # _AGGREGATION_DELAY, and a "protected" one delayed by _ONE_SECOND
        # with the shorter _PROTECTED_AGGREGATION_DELAY window.
        self.out_queue = MulticastOutgoingQueue(self, 0, _AGGREGATION_DELAY)
        self.out_delay_queue = MulticastOutgoingQueue(self, _ONE_SECOND, _PROTECTED_AGGREGATION_DELAY)

        self.query_handler = QueryHandler(self)
        self.record_manager = RecordManager(self)

        # Futures parked by async_wait(); async_notify_all() resolves them.
        self._notify_futures: set[asyncio.Future] = set()
        self.loop: asyncio.AbstractEventLoop | None = None
        self._loop_thread: threading.Thread | None = None

        self.start()

    @property
    def started(self) -> bool:
        """Check if the instance has started.

        True only when the instance has not been closed and the engine's
        running future finished successfully with a truthy result.
        """
        running_future = self.engine.running_future
        return bool(
            not self.done
            and running_future
            and running_future.done()
            and not running_future.cancelled()
            and not running_future.exception()
            and running_future.result()
        )

    def start(self) -> None:
        """Start Zeroconf.

        Attaches the engine to the currently running event loop unless
        ``use_asyncio`` was ``False`` or no loop is running. In those cases
        a background thread with its own loop is started instead.
        """
        self.loop = None if self._use_asyncio is False else get_running_loop()
        if self.loop:
            self.engine.setup(self.loop, None)
            return
        self._start_thread()

    def _start_thread(self) -> None:
        """Start a daemon thread with a running event loop.

        Blocks until the engine sets ``loop_thread_ready`` from inside the
        new thread.
        """
        loop_thread_ready = threading.Event()

        def _run_loop() -> None:
            self.loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self.loop)
            self.engine.setup(self.loop, loop_thread_ready)
            self.loop.run_forever()

        self._loop_thread = threading.Thread(target=_run_loop, daemon=True)
        self._loop_thread.start()
        loop_thread_ready.wait()

    async def async_wait_for_start(self, timeout: float = _STARTUP_TIMEOUT) -> None:
        """Wait for start up for actions that require a running Zeroconf instance.

        Throws NotRunningException if the instance is not running or could
        not be started.
        """
        if self.done:  # If the instance was shutdown from under us, raise immediately
            raise NotRunningException
        assert self.engine.running_future is not None
        await wait_future_or_timeout(self.engine.running_future, timeout=timeout)
        if not self.started:
            raise NotRunningException

    @property
    def listeners(self) -> set[RecordUpdateListener]:
        """The record update listeners held by the record manager."""
        return self.record_manager.listeners

    async def async_wait(self, timeout: float) -> None:
        """Calling task waits for a given number of milliseconds or until notified.

        The wait ends early as soon as ``async_notify_all`` runs.
        """
        loop = self.loop
        assert loop is not None
        await wait_for_future_set_or_timeout(loop, self._notify_futures, timeout)

    def notify_all(self) -> None:
        """Notifies all waiting threads and notify listeners.

        Threadsafe: schedules ``async_notify_all`` on the event loop.
        """
        assert self.loop is not None
        self.loop.call_soon_threadsafe(self.async_notify_all)

    def async_notify_all(self) -> None:
        """Wake every task currently parked in ``async_wait``.

        Resolves all pending notify futures to ``None``. Must run in the event loop.
        """
        notify_futures = self._notify_futures
        if notify_futures:
            _resolve_all_futures_to_none(notify_futures)

    def get_service_info(
        self,
        type_: str,
        name: str,
        timeout: int = 3000,
        question_type: DNSQuestionType | None = None,
    ) -> ServiceInfo | None:
        """Returns network's service information for a particular
        name and type, or None if no service matches by the timeout,
        which defaults to 3 seconds.

        :param type_: fully qualified service type name
        :param name: the name of the service
        :param timeout: milliseconds to wait for a response
        :param question_type: The type of questions to ask (DNSQuestionType.QM or DNSQuestionType.QU)
        """
        info = ServiceInfo(type_, name)
        if info.request(self, timeout, question_type):
            return info
        return None

    def add_service_listener(self, type_: str, listener: ServiceListener) -> None:
        """Adds a listener for a particular service type. This object
        will then have its add_service and remove_service methods called when
        services of that type become available and unavailable.

        Any browser previously created for the same listener is cancelled first.
        """
        self.remove_service_listener(listener)
        self.browsers[listener] = ServiceBrowser(self, type_, listener)

    def remove_service_listener(self, listener: ServiceListener) -> None:
        """Removes a listener from the set that is currently listening.

        Cancels its browser; unknown listeners are ignored.
        """
        if listener in self.browsers:
            self.browsers[listener].cancel()
            del self.browsers[listener]

    def remove_all_service_listeners(self) -> None:
        """Removes all listeners that are currently listening."""
        # Iterate over a snapshot because remove_service_listener mutates the dict.
        for listener in list(self.browsers):
            self.remove_service_listener(listener)

    def register_service(
        self,
        info: ServiceInfo,
        ttl: int | None = None,
        allow_name_change: bool = False,
        cooperating_responders: bool = False,
        strict: bool = True,
    ) -> None:
        """Registers service information to the network with a default TTL.
        Zeroconf will then respond to requests for information for that
        service. The name of the service may be changed if needed to make
        it unique on the network. Additionally multiple cooperating responders
        can register the same service on the network for resilience
        (if you want this behavior set `cooperating_responders` to `True`).

        Blocks until the announcement broadcasts have been sent.

        While it is not expected during normal operation,
        this function may raise EventLoopBlocked if the underlying
        call to `register_service` cannot be completed.
        """
        assert self.loop is not None
        run_coro_with_timeout(
            await_awaitable(
                self.async_register_service(info, ttl, allow_name_change, cooperating_responders, strict)
            ),
            self.loop,
            _REGISTER_TIME * _REGISTER_BROADCASTS,
        )

    async def async_register_service(
        self,
        info: ServiceInfo,
        ttl: int | None = None,
        allow_name_change: bool = False,
        cooperating_responders: bool = False,
        strict: bool = True,
    ) -> Awaitable:
        """Registers service information to the network with a default TTL.
        Zeroconf will then respond to requests for information for that
        service. The name of the service may be changed if needed to make
        it unique on the network. Additionally multiple cooperating responders
        can register the same service on the network for resilience
        (if you want this behavior set `cooperating_responders` to `True`).

        :returns: a future that completes once the announcements are sent.
        :raises NotRunningException: if the instance is not running.
        :raises NonUniqueNameException: if the name is taken and
            ``allow_name_change`` is false.
        """
        if ttl is not None:
            # ttl argument is used to maintain backward compatibility
            # Setting TTLs via ServiceInfo is preferred
            info.host_ttl = ttl
            info.other_ttl = ttl

        info.set_server_if_missing()
        await self.async_wait_for_start()
        await self.async_check_service(info, allow_name_change, cooperating_responders, strict)

        self.registry.async_add(info)
        return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None))

    def update_service(self, info: ServiceInfo) -> None:
        """Updates service information on the network with a default TTL
        and re-announces it. Zeroconf will then respond to requests for
        information for that service with the updated data.

        Blocks until the announcement broadcasts have been sent.

        While it is not expected during normal operation,
        this function may raise EventLoopBlocked if the underlying
        call to `async_update_service` cannot be completed.
        """
        assert self.loop is not None
        run_coro_with_timeout(
            await_awaitable(self.async_update_service(info)),
            self.loop,
            _REGISTER_TIME * _REGISTER_BROADCASTS,
        )

    async def async_update_service(self, info: ServiceInfo) -> Awaitable:
        """Updates service information in the registry and re-announces it.
        Zeroconf will then respond to requests for information for that
        service with the updated data.

        :returns: a future that completes once the announcements are sent.
        """
        self.registry.async_update(info)
        return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None))

    async def async_get_service_info(
        self,
        type_: str,
        name: str,
        timeout: int = 3000,
        question_type: DNSQuestionType | None = None,
    ) -> AsyncServiceInfo | None:
        """Returns network's service information for a particular
        name and type, or None if no service matches by the timeout,
        which defaults to 3 seconds.

        :param type_: fully qualified service type name
        :param name: the name of the service
        :param timeout: milliseconds to wait for a response
        :param question_type: The type of questions to ask (DNSQuestionType.QM or DNSQuestionType.QU)
        """
        info = AsyncServiceInfo(type_, name)
        if await info.async_request(self, timeout, question_type):
            return info
        return None

    async def _async_broadcast_service(
        self,
        info: ServiceInfo,
        interval: int,
        ttl: int | None,
        broadcast_addresses: bool = True,
    ) -> None:
        """Send _REGISTER_BROADCASTS broadcasts announcing ``info``, ``interval`` ms apart.

        A fresh DNSOutgoing is generated for each broadcast.
        """
        for i in range(_REGISTER_BROADCASTS):
            if i != 0:
                await asyncio.sleep(millis_to_seconds(interval))
            self.async_send(self.generate_service_broadcast(info, ttl, broadcast_addresses))

    def generate_service_broadcast(
        self,
        info: ServiceInfo,
        ttl: int | None,
        broadcast_addresses: bool = True,
    ) -> DNSOutgoing:
        """Generate an authoritative response that announces a service."""
        out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
        self._add_broadcast_answer(out, info, ttl, broadcast_addresses)
        return out

    def generate_service_query(self, info: ServiceInfo) -> DNSOutgoing:  # pylint: disable=no-self-use
        """Generate a probe query to check whether a service name is taken."""
        out = DNSOutgoing(_FLAGS_QR_QUERY | _FLAGS_AA)
        # https://datatracker.ietf.org/doc/html/rfc6762#section-8.1
        # Because of the mDNS multicast rate-limiting
        # rules, the probes SHOULD be sent as "QU" questions with the unicast-
        # response bit set, to allow a defending host to respond immediately
        # via unicast, instead of potentially having to wait before replying
        # via multicast.
        #
        # _CLASS_UNIQUE is the "QU" bit
        out.add_question(DNSQuestion(info.type, _TYPE_PTR, _CLASS_IN | _CLASS_UNIQUE))
        out.add_authorative_answer(info.dns_pointer())
        return out

    def _add_broadcast_answer(  # pylint: disable=no-self-use
        self,
        out: DNSOutgoing,
        info: ServiceInfo,
        override_ttl: int | None,
        broadcast_addresses: bool = True,
    ) -> None:
        """Add the records that announce ``info`` to ``out``.

        Always adds the pointer, service and text records. Address and NSEC
        records are added only when ``broadcast_addresses`` is true.
        ``override_ttl`` is forwarded to every record builder; the goodbye
        paths pass ``0``.
        """
        # NOTE(review): now=0 presumably makes add_answer_at_time add every
        # record unconditionally; confirm in DNSOutgoing.
        out.add_answer_at_time(info.dns_pointer(override_ttl=override_ttl), 0)
        out.add_answer_at_time(info.dns_service(override_ttl=override_ttl), 0)
        out.add_answer_at_time(info.dns_text(override_ttl=override_ttl), 0)
        if broadcast_addresses:
            for record in info.get_address_and_nsec_records(override_ttl=override_ttl):
                out.add_answer_at_time(record, 0)

    def unregister_service(self, info: ServiceInfo) -> None:
        """Unregister a service.

        Blocks until all goodbye broadcasts have been sent, consistent with
        ``register_service`` and ``update_service``.

        While it is not expected during normal operation,
        this function may raise EventLoopBlocked if the underlying
        call to `async_unregister_service` cannot be completed.
        """
        assert self.loop is not None
        # async_unregister_service returns a future for the broadcasts; wait
        # for it too, otherwise a following close() could cut them short.
        run_coro_with_timeout(
            await_awaitable(self.async_unregister_service(info)),
            self.loop,
            _UNREGISTER_TIME * _REGISTER_BROADCASTS,
        )

    async def async_unregister_service(self, info: ServiceInfo) -> Awaitable:
        """Unregister a service.

        :returns: a future that completes once the goodbye (TTL 0)
            broadcasts have been sent.
        """
        info.set_server_if_missing()
        self.registry.async_remove(info)
        # If another server uses the same addresses, we do not want to send
        # goodbye packets for the address records
        assert info.server_key is not None
        entries = self.registry.async_get_infos_server(info.server_key)
        broadcast_addresses = not bool(entries)
        return asyncio.ensure_future(
            self._async_broadcast_service(info, _UNREGISTER_TIME, 0, broadcast_addresses)
        )

    def generate_unregister_all_services(self) -> DNSOutgoing | None:
        """Generate a DNSOutgoing goodbye for all services and remove them from the registry.

        Returns ``None`` when no services are registered.
        """
        service_infos = self.registry.async_get_service_infos()
        if not service_infos:
            return None
        out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
        for info in service_infos:
            self._add_broadcast_answer(out, info, 0)
        self.registry.async_remove(service_infos)
        return out

    async def async_unregister_all_services(self) -> None:
        """Unregister all registered services.

        Unlike async_register_service and async_unregister_service, this
        method does not return a future and is always expected to be
        awaited since its only called at shutdown.
        """
        # Send Goodbye packets https://datatracker.ietf.org/doc/html/rfc6762#section-10.1
        out = self.generate_unregister_all_services()
        if not out:
            return
        for i in range(_REGISTER_BROADCASTS):
            if i != 0:
                await asyncio.sleep(millis_to_seconds(_UNREGISTER_TIME))
            self.async_send(out)

    def unregister_all_services(self) -> None:
        """Unregister all registered services.

        While it is not expected during normal operation,
        this function may raise EventLoopBlocked if the underlying
        call to `async_unregister_all_services` cannot be completed.
        """
        assert self.loop is not None
        run_coro_with_timeout(
            self.async_unregister_all_services(),
            self.loop,
            _UNREGISTER_TIME * _REGISTER_BROADCASTS,
        )

    async def async_check_service(
        self,
        info: ServiceInfo,
        allow_name_change: bool,
        cooperating_responders: bool = False,
        strict: bool = True,
    ) -> None:
        """Checks the network for a unique service name, modifying the
        ServiceInfo passed in if it is not unique.

        Sends _REGISTER_BROADCASTS probes, _CHECK_TIME ms apart. Whenever a
        conflicting name is found in the cache, the name is changed to
        ``<instance>-N`` and probing restarts.

        :raises NonUniqueNameException: on conflict when
            ``allow_name_change`` is false.
        """
        instance_name = instance_name_from_service_info(info, strict=strict)
        if cooperating_responders:
            return
        # Wait a random amount of time to avoid collisions and avoid a
        # thundering herd when multiple services are started on the network
        # (RFC 6762 section 8.1). Use a plain sleep: async_wait() returns as soon
        # as async_notify_all() fires, which would cut this delay short.
        await asyncio.sleep(
            millis_to_seconds(random.randint(*_PROBE_RANDOM_DELAY_INTERVAL))  # noqa: S311
        )
        next_instance_number = 2
        next_time = now = current_time_millis()
        i = 0
        while i < _REGISTER_BROADCASTS:
            # check for a name conflict
            while self.cache.current_entry_with_name_and_alias(info.type, info.name):
                if not allow_name_change:
                    raise NonUniqueNameException

                # change the name and look for a conflict
                info.name = f"{instance_name}-{next_instance_number}.{info.type}"
                next_instance_number += 1
                service_type_name(info.name, strict=strict)
                next_time = now
                i = 0

            if now < next_time:
                # async_wait may return early on notification; the loop then
                # re-checks the cache before probing.
                await self.async_wait(next_time - now)
                now = current_time_millis()
                continue

            self.async_send(self.generate_service_query(info))
            i += 1
            next_time += _CHECK_TIME

    def add_listener(
        self,
        listener: RecordUpdateListener,
        question: DNSQuestion | list[DNSQuestion] | None,
    ) -> None:
        """Adds a listener for a given question. The listener will have
        its update_record method called when information is available to
        answer the question(s).

        This function is threadsafe
        """
        assert self.loop is not None
        self.loop.call_soon_threadsafe(self.record_manager.async_add_listener, listener, question)

    def remove_listener(self, listener: RecordUpdateListener) -> None:
        """Removes a listener.

        This function is threadsafe
        """
        assert self.loop is not None
        self.loop.call_soon_threadsafe(self.record_manager.async_remove_listener, listener)

    def async_add_listener(
        self,
        listener: RecordUpdateListener,
        question: DNSQuestion | list[DNSQuestion] | None,
    ) -> None:
        """Adds a listener for a given question. The listener will have
        its update_record method called when information is available to
        answer the question(s).

        This function is not threadsafe and must be called in the eventloop.
        """
        self.record_manager.async_add_listener(listener, question)

    def async_remove_listener(self, listener: RecordUpdateListener) -> None:
        """Removes a listener.

        This function is not threadsafe and must be called in the eventloop.
        """
        self.record_manager.async_remove_listener(listener)

    def send(
        self,
        out: DNSOutgoing,
        addr: str | None = None,
        port: int = _MDNS_PORT,
        v6_flow_scope: tuple[()] | tuple[int, int] = (),
        transport: _WrappedTransport | None = None,
    ) -> None:
        """Sends an outgoing packet threadsafe.

        Schedules ``async_send`` on the event loop with the same arguments.
        """
        assert self.loop is not None
        self.loop.call_soon_threadsafe(self.async_send, out, addr, port, v6_flow_scope, transport)

    def async_send(
        self,
        out: DNSOutgoing,
        addr: str | None = None,
        port: int = _MDNS_PORT,
        v6_flow_scope: tuple[()] | tuple[int, int] = (),
        transport: _WrappedTransport | None = None,
    ) -> None:
        """Sends an outgoing packet.

        :param out: the message to encode and send.
        :param addr: destination; ``None`` means the mDNS multicast group.
        :param port: destination port.
        :param v6_flow_scope: optional IPv6 ``(flowinfo, scope_id)``.
        :param transport: send only on this transport; by default every
            sender transport is used.

        Does nothing once the instance is closed.
        """
        if self.done:
            return

        # If no transport is specified, we send to all the ones
        # with the same address family
        transports = [transport] if transport else self.engine.senders
        log_debug = log.isEnabledFor(logging.DEBUG)

        for packet_num, packet in enumerate(out.packets()):
            if len(packet) > _MAX_MSG_ABSOLUTE:
                # Returning here also skips any remaining packets of ``out``.
                self.log_warning_once(
                    "Dropping %r over-sized packet (%d bytes) %r",
                    out,
                    len(packet),
                    packet,
                )
                return
            for send_transport in transports:
                async_send_with_transport(
                    log_debug,
                    send_transport,
                    packet,
                    packet_num,
                    out,
                    addr,
                    port,
                    v6_flow_scope,
                )

    def _close(self) -> None:
        """Set global done and remove all service listeners (runs once)."""
        if self.done:
            return
        self.remove_all_service_listeners()
        self.done = True

    def _shutdown_threads(self) -> None:
        """Shutdown any threads.

        Wakes waiters and, if this instance owns a loop thread, stops the
        loop, joins the thread and closes the loop.
        """
        assert self.loop is not None
        if self.loop.is_closed():
            # close() is documented as idempotent — a second call after the
            # loop has been torn down must be a no-op rather than raising.
            return
        self.notify_all()
        if not self._loop_thread:
            return
        shutdown_loop(self.loop)
        self._loop_thread.join()
        self._loop_thread = None
        # The loop's selector (epoll FD on Linux) and self-pipe sockets stay
        # open until loop.close() is called. We own this loop because
        # _start_thread() created it, so close it here to avoid leaking
        # those file descriptors across Zeroconf() construct/close cycles.
        self.loop.close()

    def close(self) -> None:
        """Ends the background threads, and prevent this instance from
        servicing further queries.

        This method is idempotent and irreversible.
        """
        assert self.loop is not None
        if self.loop.is_running():
            # Unregistering blocks on the loop, which would deadlock if we
            # are running inside that very loop.
            if self.loop == get_running_loop():
                log.warning(
                    "unregister_all_services skipped as it does blocking i/o; use AsyncZeroconf with asyncio"
                )
            else:
                self.unregister_all_services()
        self._close()
        self.engine.close()
        self._shutdown_threads()

    async def _async_close(self) -> None:
        """Ends the background threads, and prevent this instance from
        servicing further queries.

        This method is idempotent and irreversible.

        This call only intended to be used by AsyncZeroconf

        Callers are responsible for unregistering all services
        before calling this function
        """
        self._close()
        await self.engine._async_close()  # pylint: disable=protected-access
        self._shutdown_threads()

    def __enter__(self) -> Zeroconf:
        return self

    def __exit__(  # pylint: disable=useless-return
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """Close the instance; exceptions are never suppressed."""
        self.close()
        return None