diff --git a/docs/conf.py b/docs/conf.py
index badf67f8a..39c5452a7 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -54,6 +54,7 @@
     'faust.cli',
     'faust.models',
     'faust.serializers',
+    'faust.transport.drivers.confluent',
     'faust.types',
     'faust.types._env',
     'faust.utils',
diff --git a/docs/includes/settingref.txt b/docs/includes/settingref.txt
index 91bec6c02..669894cff 100644
--- a/docs/includes/settingref.txt
+++ b/docs/includes/settingref.txt
@@ -391,6 +391,15 @@ You can also pass a list of URLs:
 
     Limitations: None
 
+- ``confluent://``
+
+    Experimental transport using the :pypi:`confluent-kafka` client.
+
+    Limitations: Does not do sticky partition assignment (not
+    suitable for tables), and does not create any necessary internal
+    topics (you have to create them manually).
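+
+    For example, to select this transport, pass a ``confluent://`` URL
+    as the :setting:`broker` setting (shown here against a local broker,
+    as in ``examples/word_count.py`` from this change):
+
+    .. sourcecode:: python
+
+        app = faust.App('word-counts', broker='confluent://localhost:9092')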
+
 .. setting:: broker_credentials
 
 ``broker_credentials``
diff --git a/examples/word_count.py b/examples/word_count.py
index dafb1b98a..cd3f55ce3 100755
--- a/examples/word_count.py
+++ b/examples/word_count.py
@@ -7,7 +7,7 @@
 
 app = faust.App(
     'word-counts',
-    broker='kafka://localhost:9092',
+    broker='confluent://localhost:9092',
     store='rocksdb://',
     version=1,
     topic_partitions=8,
diff --git a/extra/bandit/baseline.json b/extra/bandit/baseline.json
index e56de74bb..ea1967d2a 100644
--- a/extra/bandit/baseline.json
+++ b/extra/bandit/baseline.json
@@ -1070,6 +1070,18 @@
       "loc": 890,
       "nosec": 0
     },
+    "faust/transport/drivers/confluent.py": {
+      "CONFIDENCE.HIGH": 0.0,
+      "CONFIDENCE.LOW": 0.0,
+      "CONFIDENCE.MEDIUM": 0.0,
+      "CONFIDENCE.UNDEFINED": 0.0,
+      "SEVERITY.HIGH": 0.0,
+      "SEVERITY.LOW": 0.0,
+      "SEVERITY.MEDIUM": 0.0,
+      "SEVERITY.UNDEFINED": 0.0,
+      "loc": 486,
+      "nosec": 0
+    },
     "faust/transport/producer.py": {
       "CONFIDENCE.HIGH": 0.0,
       "CONFIDENCE.LOW": 0.0,
@@ -2091,4 +2103,4 @@
     "test_name": "blacklist"
   }
 ]
-}
\ No newline at end of file
+}
diff --git a/faust/transport/drivers/__init__.py b/faust/transport/drivers/__init__.py
index 203a8907c..e79cbf9f4 100644
--- a/faust/transport/drivers/__init__.py
+++ b/faust/transport/drivers/__init__.py
@@ -1,21 +1,17 @@
 """Transport registry."""
-from yarl import URL
+from typing import Type
 
-from .aiokafka import Transport as AIOKafkaTransport
+from mode.utils.imports import FactoryMapping
 
-__all__ = ["by_name", "by_url"]
-
-
-DRIVERS = {
-    "aiokafka": AIOKafkaTransport,
-    "kafka": AIOKafkaTransport,
-}
-
-
-def by_name(driver_name: str):
-    return DRIVERS[driver_name]
+from faust.types import TransportT
 
+__all__ = ["by_name", "by_url"]
 
-def by_url(url: URL):
-    scheme = url.scheme
-    return DRIVERS[scheme]
+TRANSPORTS: FactoryMapping[Type[TransportT]] = FactoryMapping(
+    aiokafka="faust.transport.drivers.aiokafka:Transport",
+    confluent="faust.transport.drivers.confluent:Transport",
+    kafka="faust.transport.drivers.aiokafka:Transport",
+)
+TRANSPORTS.include_setuptools_namespace("faust.transports")
+by_name = TRANSPORTS.by_name
+by_url = TRANSPORTS.by_url
diff --git a/faust/transport/drivers/confluent.py b/faust/transport/drivers/confluent.py
new file mode 100644
index 000000000..6c1f2810f
--- /dev/null
+++ b/faust/transport/drivers/confluent.py
@@ -0,0 +1,758 @@
+"""Message transport using :pypi:`confluent_kafka`."""
+import asyncio
+import os
+import struct
+import typing
+import weakref
+from collections import defaultdict
+from time import monotonic
+from typing import (
+    Any,
+    Awaitable,
+    Callable,
+    ClassVar,
+    Iterable,
+    List,
+    Mapping,
+    MutableMapping,
+    Optional,
+    Set,
+    Type,
+    cast,
+)
+
+import confluent_kafka
+from confluent_kafka import KafkaException
+from confluent_kafka import TopicPartition as _TopicPartition
+from confluent_kafka.admin import AdminClient
+from mode import Service, get_logger
+from mode.threads import QueueServiceThread
+from mode.utils.futures import notify
+from mode.utils.times import Seconds, want_seconds
+from yarl import URL
+
+from faust.exceptions import ConsumerNotStarted, ProducerSendError
+from faust.transport import base
+from faust.transport.consumer import (
+    ConsumerThread,
+    RecordMap,
+    ThreadDelegateConsumer,
+    ensure_TP,
+    ensure_TPset,
+)
+from faust.types import TP, AppT, ConsumerMessage, HeadersArg, RecordMetadata
+from faust.types.transports import ConsumerT, ProducerT
+
+if typing.TYPE_CHECKING:
+    from confluent_kafka import (
+        Consumer as _Consumer,
+        Message as _Message,
+        Producer as _Producer,
+    )
+else:
+
+    class _Consumer:
+        ...  # noqa
+
+    class _Producer:
+        ...  # noqa
+
+    class _Message:
+        ...  # noqa
+
+
+__all__ = ["Consumer", "Producer", "Transport"]
+
+
+logger = get_logger(__name__)
+
+
+def server_list(urls: List[URL], default_port: int) -> str:
+    default_host = "127.0.0.1"
+    return ",".join(
+        [f"{u.host or default_host}:{u.port or default_port}" for u in urls]
+    )
+
+
+class Consumer(ThreadDelegateConsumer):
+    """Kafka consumer using :pypi:`confluent_kafka`."""
+
+    logger = logger
+
+    def _new_consumer_thread(self) -> ConsumerThread:
+        return ConfluentConsumerThread(self, loop=self.loop, beacon=self.beacon)
+
+    async def create_topic(
+        self,
+        topic: str,
+        partitions: int,
+        replication: int,
+        *,
+        config: Mapping[str, Any] = None,
+        timeout: Seconds = 30.0,
+        retention: Seconds = None,
+        compacting: bool = None,
+        deleting: bool = None,
+        ensure_created: bool = False,
+    ) -> None:
+        """Create topic on broker."""
+        if self.app.conf.topic_allow_declare:
+            await self._thread.create_topic(
+                topic,
+                partitions,
+                replication,
+                config=config,
+                timeout=int(want_seconds(timeout) * 1000.0),
+                # want_seconds(None) would crash here: only convert
+                # retention to milliseconds when it is actually set.
+                retention=(
+                    int(want_seconds(retention) * 1000.0)
+                    if retention is not None
+                    else None
+                ),
+                compacting=compacting,
+                deleting=deleting,
+                ensure_created=ensure_created,
+            )
+        else:
+            logger.warning(f"Topic creation disabled! Can't create topic {topic}")
+
+    def _to_message(self, tp: TP, record: Any) -> ConsumerMessage:
+        # convert timestamp to seconds from int milliseconds.
+        timestamp_type: int
+        timestamp: Optional[int]
+        timestamp_type, timestamp = record.timestamp()
+        timestamp_s: float = cast(float, None)
+        if timestamp is not None:
+            timestamp_s = timestamp / 1000.0
+        key = record.key()
+        key_size = len(key) if key is not None else 0
+        value = record.value()
+        value_size = len(value) if value is not None else 0
+        return ConsumerMessage(
+            record.topic(),
+            record.partition(),
+            record.offset(),
+            timestamp_s,
+            timestamp_type,
+            record.headers() or [],  # headers
+            key,
+            value,
+            None,
+            key_size,
+            value_size,
+            tp,
+        )
+
+    def _new_topicpartition(self, topic: str, partition: int) -> TP:
+        return cast(TP, _TopicPartition(topic, partition))
+
+    async def on_stop(self) -> None:
+        """Call when consumer is stopping."""
+        await super().on_stop()
+        # transport = cast(Transport, self.transport)
+        # transport._topic_waiters.clear()
+
+    def verify_event_path(self, now: float, tp: TP) -> None:
+        return self._thread.verify_event_path(now, tp)
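+
+
+# ``AsyncConsumer`` below wraps the blocking ``confluent_kafka.Consumer``
+# so it can be awaited from asyncio code.  A minimal usage sketch
+# (assuming a local broker; this class is an internal helper, not
+# public Faust API):
+#
+#     consumer = AsyncConsumer({'bootstrap.servers': 'localhost:9092',
+#                               'group.id': 'example'})
+#     consumer.subscribe(['mytopic'])
+#     msg = await consumer.poll(timeout=1.0)
+#     if msg is not None and not msg.error():
+#         print(msg.topic(), msg.offset(), msg.value())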
+class AsyncConsumer:
+    def __init__(
+        self,
+        config,
+        logger=None,
+        callback=None,
+        loop=None,
+        on_partitions_revoked=None,
+        on_partitions_assigned=None,
+        beacon=None,
+    ):
+        """Construct a Consumer usable within asyncio.
+
+        :param config: A configuration dict for this Consumer
+        :param logger: A Python logger instance.
+
+        # Taken from https://github.com/stephan-hof/confluent-kafka-python/blob/69b79ad1b53d5e9058710ced63c42ebb1da2d9ec/examples/linux_asyncio_consumer.py
+        """
+        self.consumer = confluent_kafka.Consumer(**config)
+        self.callback = callback
+        self.on_partitions_revoked = on_partitions_revoked
+        self.on_partitions_assigned = on_partitions_assigned
+        self.beacon = beacon
+
+        # os.eventfd is Linux-only and requires Python 3.10+.
+        self.eventfd = os.eventfd(0, os.EFD_CLOEXEC | os.EFD_NONBLOCK)
+
+        # This is the channel how the consumer notifies asyncio.
+        self.loop = loop
+        if loop is None:
+            self.loop = asyncio.get_running_loop()
+        self.loop.add_reader(self.eventfd, self.__eventfd_ready)
+        self.consumer.io_event_enable(self.eventfd, struct.pack("@q", 1))
+
+        self.waiters = set()
+
+        # Close eventfd and remove it from reader if
+        # self is not referenced anymore.
+        self.__close_eventfd = weakref.finalize(
+            self, AsyncConsumer.close_eventd, self.loop, self.eventfd
+        )
+
+    @staticmethod
+    def close_eventd(loop, eventfd):
+        """Internal helper method.  Not part of the public API."""
+        loop.remove_reader(eventfd)
+        os.close(eventfd)
+
+    def close(self):
+        self.consumer.close()
+
+    def __eventfd_ready(self):
+        os.eventfd_read(self.eventfd)
+
+        for future in self.waiters:
+            if not future.done():
+                future.set_result(True)
+
+    def subscribe(self, *args, **kwargs):
+        self.consumer.subscribe(*args, **kwargs)
+
+    def assign(self, *args, **kwargs):
+        self.consumer.assign(*args, **kwargs)
+
+    async def poll(self, timeout=0):
+        """Consume a single message, call callbacks and return the event.
+
+        It is defined as ``async def``, so it must be awaited to get the
+        result, e.g. ``msg = await consumer.poll()``.
+        See https://docs.python.org/3/library/asyncio-task.html#awaitables
+
+        If timeout > 0: Wait at most ``timeout`` seconds for a message.
+        Returns `None` if no message arrives in time.
+        If timeout <= 0: Wait for a message indefinitely.
+        """
+        if timeout > 0:
+            try:
+                msg = await asyncio.wait_for(self._poll_no_timeout(), timeout)
+            except asyncio.TimeoutError:
+                return None
+        else:
+            msg = await self._poll_no_timeout()
+        if self.callback:
+            await self.callback(msg)
+        return msg
+
+    async def _poll_no_timeout(self):
+        while not (msg := await self._single_poll()):
+            pass
+        return msg
+
+    async def _single_poll(self):
+        if (msg := self.consumer.poll(timeout=0)) is not None:
+            return msg
+
+        awaitable = self.loop.create_future()
+        self.waiters.add(awaitable)
+        try:
+            # timeout=2 is there for two reasons:
+            # 1) self.consumer.poll needs to be called regularly for
+            #    other activities like log callbacks.
+            # 2) Ensures progress even if something with eventfd
+            #    notification goes wrong.
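+            # How the wakeup works: io_event_enable() (see __init__)
+            # makes librdkafka write to our eventfd whenever a message
+            # is enqueued on a previously empty consumer queue.  The
+            # loop.add_reader callback (__eventfd_ready) then resolves
+            # every future in self.waiters, and _poll_no_timeout()
+            # retries the non-blocking poll above.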
+            await asyncio.wait_for(awaitable, timeout=2)
+        except asyncio.TimeoutError:
+            return None
+        finally:
+            self.waiters.discard(awaitable)
+
+    def assignment(self) -> List[_TopicPartition]:
+        return self.consumer.assignment()
+
+
+class ConfluentConsumerThread(ConsumerThread):
+    """Thread managing underlying :pypi:`confluent_kafka` consumer."""
+
+    _consumer: Optional[AsyncConsumer] = None
+    _assigned: bool = False
+
+    # _pending_rebalancing_spans: Deque[opentracing.Span]
+
+    tp_last_committed_at: MutableMapping[TP, float]
+    time_started: float
+
+    tp_fetch_request_timeout_secs: float
+    tp_fetch_response_timeout_secs: float
+    tp_stream_timeout_secs: float
+    tp_commit_timeout_secs: float
+
+    async def on_start(self) -> None:
+        self._consumer = self._create_consumer(loop=self.thread_loop)
+        self.time_started = monotonic()
+        # await self._consumer.start()
+
+    def _create_consumer(self, loop: asyncio.AbstractEventLoop) -> AsyncConsumer:
+        transport = cast(Transport, self.transport)
+        if self.app.client_only:
+            return self._create_client_consumer(transport, loop=loop)
+        else:
+            return self._create_worker_consumer(transport, loop=loop)
+
+    def _create_worker_consumer(
+        self, transport: "Transport", loop: asyncio.AbstractEventLoop
+    ) -> AsyncConsumer:
+        conf = self.app.conf
+        self._assignor = self.app.assignor
+
+        # XXX partition.assignment.strategy is a string in librdkafka:
+        # need to write a C wrapper before we can plug in the faust
+        # assignor here.
+        # 'partition.assignment.strategy': [self._assignor]
+        return AsyncConsumer(
+            {
+                "bootstrap.servers": server_list(transport.url, transport.default_port),
+                "group.id": conf.id,
+                "client.id": conf.broker_client_id,
+                "default.topic.config": {
+                    "auto.offset.reset": "earliest",
+                },
+                "enable.auto.commit": False,
+                "fetch.max.bytes": conf.consumer_max_fetch_size,
+                "request.timeout.ms": int(conf.broker_request_timeout * 1000.0),
+                "check.crcs": conf.broker_check_crcs,
+                "session.timeout.ms": int(conf.broker_session_timeout * 1000.0),
+                "heartbeat.interval.ms": int(conf.broker_heartbeat_interval * 1000.0),
+            },
+            self.logger,
+            loop=loop,
+        )
+
+    def _create_client_consumer(
+        self, transport: "Transport", loop: asyncio.AbstractEventLoop
+    ) -> AsyncConsumer:
+        conf = self.app.conf
+        return AsyncConsumer(
+            {
+                "bootstrap.servers": server_list(transport.url, transport.default_port),
+                "client.id": conf.broker_client_id,
+                "enable.auto.commit": True,
+                "default.topic.config": {
+                    "auto.offset.reset": "earliest",
+                },
+            },
+            self.logger,
+            loop=loop,
+        )
+
+    def close(self) -> None:
+        ...
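+
+    # How this transport is reached (sketch): a ``confluent://`` broker
+    # URL resolves to this module's Transport, whose Consumer delegates
+    # all blocking confluent-kafka calls to this thread, e.g.:
+    #
+    #     app = faust.App('myapp', broker='confluent://localhost:9092')
+    #     # app.consumer._thread will be a ConfluentConsumerThread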
+
+    async def subscribe(self, topics: Iterable[str]) -> None:
+        # XXX pattern subscription does not work :/
+        await self.cast_thread(
+            self._ensure_consumer().subscribe,
+            topics=list(topics),
+            on_assign=self._on_assign,
+            on_revoke=self._on_revoke,
+            # listener=self._rebalance_listener,
+        )
+
+    def _on_assign(self, consumer: _Consumer, assigned: List[_TopicPartition]) -> None:
+        self._assigned = True
+        self.thread_loop.create_task(
+            self.on_partitions_assigned({TP(tp.topic, tp.partition) for tp in assigned})
+        )
+
+    def _on_revoke(self, consumer: _Consumer, revoked: List[_TopicPartition]) -> None:
+        self.thread_loop.create_task(
+            self.on_partitions_revoked({TP(tp.topic, tp.partition) for tp in revoked})
+        )
+
+    async def seek_to_committed(self) -> Mapping[TP, int]:
+        return await self.call_thread(self._seek_to_committed)
+
+    async def _seek_to_committed(self) -> Mapping[TP, int]:
+        consumer = self._ensure_consumer()
+        assignment = consumer.assignment()
+        committed = consumer.consumer.committed(assignment)
+        for tp in committed:
+            # seek() on the underlying consumer is synchronous.
+            consumer.consumer.seek(tp)
+        return {ensure_TP(tp): tp.offset for tp in committed}
+
+    async def _committed_offsets(self, partitions: List[TP]) -> MutableMapping[TP, int]:
+        consumer = self._ensure_consumer()
+        committed = consumer.consumer.committed(
+            [_TopicPartition(tp[0], tp[1]) for tp in partitions]
+        )
+        return {TP(tp.topic, tp.partition): tp.offset for tp in committed}
+
+    async def commit(self, tps: Mapping[TP, int]) -> bool:
+        await self.call_thread(
+            self._ensure_consumer().consumer.commit,
+            offsets=[
+                _TopicPartition(tp.topic, tp.partition, offset=offset)
+                for tp, offset in tps.items()
+            ],
+            asynchronous=False,
+        )
+        return True
+
+    async def position(self, tp: TP) -> Optional[int]:
+        return await self.call_thread(self._position, tp)
+
+    def _position(self, tp: TP) -> Optional[int]:
+        # position() on the confluent consumer takes and returns a
+        # list of TopicPartitions.
+        positions = self._ensure_consumer().consumer.position(
+            [_TopicPartition(tp.topic, tp.partition)]
+        )
+        return positions[0].offset if positions else None
+
+    async def seek_to_beginning(self, *partitions: _TopicPartition) -> None:
+        await self.call_thread(self._seek_to_beginning, partitions)
+
+    def _seek_to_beginning(self, partitions: Iterable[_TopicPartition]) -> None:
+        consumer = self._ensure_consumer()
+        for partition in partitions:
+            # OFFSET_BEGINNING is a logical offset understood by seek().
+            consumer.consumer.seek(
+                _TopicPartition(
+                    partition.topic,
+                    partition.partition,
+                    confluent_kafka.OFFSET_BEGINNING,
+                )
+            )
+
+    async def seek_wait(self, partitions: Mapping[TP, int]) -> None:
+        consumer = self._ensure_consumer()
+        await self.call_thread(self._seek_wait, consumer, partitions)
+
+    async def _seek_wait(
+        self, consumer: AsyncConsumer, partitions: Mapping[TP, int]
+    ) -> None:
+        for tp, offset in partitions.items():
+            self.log.dev("SEEK %r -> %r", tp, offset)
+            # seek() is synchronous; the target offset travels inside
+            # the TopicPartition.
+            consumer.consumer.seek(_TopicPartition(tp.topic, tp.partition, offset))
+        # Read back the positions so we know the seeks have been applied.
+        consumer.consumer.position(
+            [_TopicPartition(tp.topic, tp.partition) for tp in partitions]
+        )
+
+    def seek(self, partition: TP, offset: int) -> None:
+        self._ensure_consumer().consumer.seek(
+            _TopicPartition(partition.topic, partition.partition, offset)
+        )
+
+    def assignment(self) -> Set[TP]:
+        return ensure_TPset(self._ensure_consumer().assignment())
+
+    def highwater(self, tp: TP) -> int:
+        _, hw = self._ensure_consumer().consumer.get_watermark_offsets(
+            _TopicPartition(tp.topic, tp.partition), cached=True
+        )
+        return hw
+
+    def topic_partitions(self, topic: str) -> Optional[int]:
+        # XXX NotImplemented
+        return None
+
+    async def earliest_offsets(self, *partitions: TP) -> MutableMapping[TP, int]:
+        if not partitions:
+            return {}
+        return await self.call_thread(self._earliest_offsets, partitions)
+
+    async def _earliest_offsets(self, partitions: List[TP]) -> MutableMapping[TP, int]:
+        consumer = self._ensure_consumer()
+        return {
+            tp: consumer.consumer.get_watermark_offsets(_TopicPartition(tp[0], tp[1]))[
+                0
+            ]
+            for tp in partitions
+        }
+
+    async def highwaters(self, *partitions: TP) -> MutableMapping[TP, int]:
+        if not partitions:
+            return {}
+        return await self.call_thread(self._highwaters, partitions)
+
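+    # get_watermark_offsets() returns a (low, high) tuple per partition:
+    # ``low`` is the earliest available offset and ``high`` is the offset
+    # that will be assigned to the next produced message, which is what
+    # Faust expects from highwater()/highwaters().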
+    async def _highwaters(self, partitions: List[TP]) -> MutableMapping[TP, int]:
+        consumer = self._ensure_consumer()
+        return {
+            tp: consumer.consumer.get_watermark_offsets(_TopicPartition(tp[0], tp[1]))[
+                1
+            ]
+            for tp in partitions
+        }
+
+    def _ensure_consumer(self) -> AsyncConsumer:
+        if self._consumer is None:
+            raise ConsumerNotStarted("Consumer thread not yet started")
+        return self._consumer
+
+    async def getmany(
+        self, active_partitions: Optional[Set[TP]], timeout: float
+    ) -> RecordMap:
+        # Implementation for the Fetcher service.
+        _consumer = self._ensure_consumer()
+        messages = await self.call_thread(
+            _consumer.consumer.consume,
+            num_messages=10000,
+            timeout=timeout,
+        )
+        records: RecordMap = defaultdict(list)
+        for message in messages:
+            tp = TP(message.topic(), message.partition())
+            records[tp].append(message)
+        return records
+
+    async def create_topic(
+        self,
+        topic: str,
+        partitions: int,
+        replication: int,
+        *,
+        config: Mapping[str, Any] = None,
+        timeout: Seconds = 30.0,
+        retention: Seconds = None,
+        compacting: bool = None,
+        deleting: bool = None,
+        ensure_created: bool = False,
+    ) -> None:
+        return  # XXX topic creation is not supported yet.
+
+    def key_partition(
+        self, topic: str, key: Optional[bytes], partition: int = None
+    ) -> Optional[int]:
+        if partition is not None:
+            return partition
+        metadata = self._ensure_consumer().consumer.list_topics(topic)
+        # metadata.topics maps topic name to a TopicMetadata object,
+        # whose .partitions attribute is a dict of partition metadata.
+        partition_count = len(metadata.topics[topic].partitions)
+
+        # Calculate the partition number based on the key hash.
+        # XXX Python's hash() is not Kafka's default murmur2 partitioner,
+        # so this can disagree with partitions chosen by other clients.
+        key_bytes = str(key).encode("utf-8")
+        return abs(hash(key_bytes)) % partition_count
+
+
+class ProducerProduceFuture(asyncio.Future):
+    def set_from_on_delivery(self, err: Optional[BaseException], msg: _Message) -> None:
+        if err:
+            # ``err`` is a confluent_kafka.KafkaError here, which is not
+            # an Exception subclass: wrap it so that it can be set as
+            # the future's exception.
+            self.set_exception(KafkaException(err))
+        else:
+            metadata: RecordMetadata = self.message_to_metadata(msg)
+            self.set_result(metadata)
+
+    def message_to_metadata(self, message: _Message) -> RecordMetadata:
+        topic, partition = tp = TP(message.topic(), message.partition())
+        return RecordMetadata(topic, partition, tp, message.offset())
+
+
+class ProducerThread(QueueServiceThread):
+    """Thread managing underlying :pypi:`confluent_kafka` producer."""
+
+    app: AppT
+    producer: "Producer"
+    transport: "Transport"
+    _producer: Optional[_Producer] = None
+    _flush_soon: Optional[asyncio.Future] = None
+
+    def __init__(self, producer: "Producer", **kwargs: Any) -> None:
+        self.producer = producer
+        self.transport = cast(Transport, self.producer.transport)
+        self.app = self.transport.app
+        super().__init__(**kwargs)
+
+    async def on_start(self) -> None:
+        self._producer = confluent_kafka.Producer(
+            {
+                "bootstrap.servers": server_list(
+                    self.transport.url, self.transport.default_port
+                ),
+                "client.id": self.app.conf.broker_client_id,
+                "max.in.flight.requests.per.connection": 1,
+            }
+        )
+
+    async def flush(self) -> None:
+        if self._producer is not None:
+            self._producer.flush()
+
+    async def on_thread_stop(self) -> None:
+        if self._producer is not None:
+            self._producer.flush()
+
+    def produce(
+        self,
+        topic: str,
+        key: bytes,
+        value: bytes,
+        partition: int,
+        on_delivery: Callable,
+    ) -> None:
+        if self._producer is None:
+            raise RuntimeError("Producer not started")
+        # produce() takes value before key positionally, so use explicit
+        # keyword arguments to avoid mixing the two up.
+        if partition is not None:
+            self._producer.produce(
+                topic,
+                key=key,
+                value=value,
+                partition=partition,
+                on_delivery=on_delivery,
+            )
+        else:
+            self._producer.produce(
+                topic,
+                key=key,
+                value=value,
+                on_delivery=on_delivery,
+            )
+        notify(self._flush_soon)
+
+    @Service.task
+    async def _background_flush(self) -> None:
+        producer = cast(_Producer, self._producer)
+        _size = producer.__len__
+        _flush = producer.flush
+        _poll = producer.poll
+        _sleep = self.sleep
+        _create_future = self.loop.create_future
+        while not self.should_stop:
+            if not _size():
+                flush_soon = self._flush_soon
+                if flush_soon is None:
+                    flush_soon = self._flush_soon = _create_future()
+                stopped: bool = False
+                try:
+                    stopped = await self.wait_for_stopped(flush_soon, timeout=1.0)
+                finally:
+                    self._flush_soon = None
+                if not stopped:
+                    _flush(timeout=100)
+                    _poll(timeout=1)
+                    await _sleep(0)
+
+
+class Producer(base.Producer):
+    """Kafka producer using :pypi:`confluent_kafka`."""
+
+    logger = logger
+
+    _producer_thread: ProducerThread
+    _admin: AdminClient
+    _quick_produce: Any = None
+
+    def __post_init__(self) -> None:
+        self._producer_thread = ProducerThread(self, loop=self.loop, beacon=self.beacon)
+        self._quick_produce = self._producer_thread.produce
+
+    async def _on_irrecoverable_error(self, exc: BaseException) -> None:
+        consumer = self.transport.app.consumer
+        if consumer is not None:
+            await consumer.crash(exc)
+        await self.crash(exc)
+
+    async def on_restart(self) -> None:
+        """Call when producer is restarting."""
+        self.on_init()
+
+    async def create_topic(
+        self,
+        topic: str,
+        partitions: int,
+        replication: int,
+        *,
+        config: Mapping[str, Any] = None,
+        timeout: Seconds = 20.0,
+        retention: Seconds = None,
+        compacting: bool = None,
+        deleting: bool = None,
+        ensure_created: bool = False,
+    ) -> None:
+        """Create topic on broker."""
+        # XXX topic creation is not supported yet: topics have to be
+        # created manually (see the transport limitations in the docs).
+        return
+
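+    # Delivery-report flow (sketch): ``send()`` enqueues the message on
+    # the producer thread and returns a ``ProducerProduceFuture`` that
+    # the confluent-kafka ``on_delivery`` callback resolves once the
+    # broker acks, so callers can do:
+    #
+    #     fut = await producer.send('mytopic', b'key', b'value',
+    #                               None, None, None)
+    #     metadata = await fut  # RecordMetadata
+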
+    async def on_start(self) -> None:
+        """Call when producer is starting."""
+        await self._producer_thread.start()
+        await self.sleep(0.5)  # cannot remember why, necessary? [ask]
+
+    async def on_stop(self) -> None:
+        """Call when producer is stopping."""
+        await self._producer_thread.stop()
+
+    async def send(
+        self,
+        topic: str,
+        key: Optional[bytes],
+        value: Optional[bytes],
+        partition: Optional[int],
+        timestamp: Optional[float],
+        headers: Optional[HeadersArg],
+        *,
+        transactional_id: str = None,
+    ) -> Awaitable[RecordMetadata]:
+        """Send message for future delivery."""
+        fut = ProducerProduceFuture(loop=self.loop)
+        # XXX timestamp, headers and transactional_id are not yet
+        # forwarded to the underlying producer.
+        try:
+            self._quick_produce(
+                topic,
+                key,
+                value,
+                partition,
+                on_delivery=fut.set_from_on_delivery,
+            )
+        except (BufferError, KafkaException) as exc:
+            raise ProducerSendError(f"Error while sending: {exc!r}") from exc
+        return cast(Awaitable[RecordMetadata], fut)
+
+    async def send_and_wait(
+        self,
+        topic: str,
+        key: Optional[bytes],
+        value: Optional[bytes],
+        partition: Optional[int],
+        timestamp: Optional[float],
+        headers: Optional[HeadersArg],
+        *,
+        transactional_id: str = None,
+    ) -> RecordMetadata:
+        """Send message and wait for it to be delivered to broker(s)."""
+        fut = await self.send(
+            topic,
+            key,
+            value,
+            partition,
+            timestamp,
+            headers,
+            transactional_id=transactional_id,
+        )
+        return await fut
+
+    async def flush(self) -> None:
+        """Flush producer buffer.
+
+        This will wait until the producer has written
+        all buffered up messages to any connected brokers.
+        """
+        await self._producer_thread.flush()
+
+    def key_partition(self, topic: str, key: bytes) -> TP:
+        """Return topic and partition destination for key."""
+        producer = self._producer_thread._producer
+        if producer is None:
+            raise RuntimeError("Producer not started")
+        # Get the partition count for the topic from producer metadata.
+        metadata = producer.list_topics(topic)
+        partition_count = len(metadata.topics[topic].partitions)
+
+        # Calculate the partition number based on the key hash.
+        # XXX Python's hash() is not Kafka's default murmur2 partitioner,
+        # so this can disagree with partitions chosen by other clients.
+        key_bytes = str(key).encode("utf-8")
+        partition = abs(hash(key_bytes)) % partition_count
+
+        return TP(topic, partition)
+
+
+class Transport(base.Transport):
+    """Kafka transport using :pypi:`confluent_kafka`."""
+
+    Consumer: ClassVar[Type[ConsumerT]] = Consumer
+    Producer: ClassVar[Type[ProducerT]] = Producer
+
+    default_port = 9092
+    driver_version = f"confluent_kafka={confluent_kafka.__version__}"
+
+    def _topic_config(
+        self, retention: int = None, compacting: bool = None, deleting: bool = None
+    ) -> MutableMapping[str, Any]:
+        config: MutableMapping[str, Any] = {}
+        cleanup_flags: Set[str] = set()
+        if compacting:
+            cleanup_flags |= {"compact"}
+        if deleting:
+            cleanup_flags |= {"delete"}
+        if cleanup_flags:
+            config["cleanup.policy"] = ",".join(sorted(cleanup_flags))
+        if retention:
+            config["retention.ms"] = retention
+        return config
diff --git a/faust/types/settings/settings.py b/faust/types/settings/settings.py
index f2203fff5..d7613b506 100644
--- a/faust/types/settings/settings.py
+++ b/faust/types/settings/settings.py
@@ -705,6 +705,15 @@ def broker(self) -> List[URL]:
         The recommended transport using the :pypi:`aiokafka` client.
 
         Limitations: None
+
+
+    - ``confluent://``
+
+        Experimental transport using the :pypi:`confluent-kafka` client.
+
+        Limitations: Does not do sticky partition assignment (not
+        suitable for tables), and does not create any necessary internal
+        topics (you have to create them manually).
     """
 
     @broker.on_set_default  # type: ignore
diff --git a/requirements/extras/ckafka.txt b/requirements/extras/ckafka.txt
new file mode 100644
index 000000000..839c6139d
--- /dev/null
+++ b/requirements/extras/ckafka.txt
@@ -0,0 +1,2 @@
+#confluent-kafka # waiting for https://github.com/faust-streaming/faust/pull/418 to be merged and released
+confluent-kafka @ git+https://github.com/stephan-hof/confluent-kafka-python.git@features/io_event_enable
diff --git a/setup.py b/setup.py
index 3c8a1a3ab..4fad9a006 100644
--- a/setup.py
+++ b/setup.py
@@ -25,6 +25,7 @@
     "aiomonitor",
    "cchardet",
    "ciso8601",
+    "ckafka",
    "cython",
    "datadog",
    "debug",
diff --git a/tests/unit/transport/drivers/test_confluent.py b/tests/unit/transport/drivers/test_confluent.py
new file mode 100644
index 000000000..15e8c0042
--- /dev/null
+++ b/tests/unit/transport/drivers/test_confluent.py
@@ -0,0 +1,1375 @@
+from contextlib import contextmanager
+from typing import Optional
+from unittest.mock import Mock, call, patch
+
+import confluent_kafka
+import opentracing
+import pytest
+from confluent_kafka import TopicPartition
+from confluent_kafka.error import KafkaError, KafkaException
+from mode.utils import text
+from mode.utils.futures import done_future
+from opentracing.ext import tags
+
+import faust
+from faust import auth
+from faust.exceptions import ImproperlyConfigured, NotReady
+from faust.sensors.monitor import Monitor
+from faust.transport.drivers import confluent as mod
+from faust.transport.drivers.confluent import (
+    ConfluentConsumerThread,
+    Consumer,
+    ConsumerNotStarted,
+    Producer,
+    ProducerSendError,
+    ProducerThread,
+    Transport,
+    server_list,
+)
+from faust.types import TP
+from faust.types.tuples import FutureMessage, PendingMessage
+from tests.helpers import AsyncMock
+
+TP1 = TP("topic", 23)
+TP2 = TP("topix", 23)
+
+TESTED_MODULE = "faust.transport.drivers.confluent"
+
+
+@pytest.fixture()
+def thread():
+    return Mock(
+        name="thread",
+        create_topic=AsyncMock(),
+    )
+
+
+@pytest.fixture()
+def consumer(*, thread, app, callback, on_partitions_revoked, on_partitions_assigned):
+    consumer = Consumer(
+        app.transport,
+        callback=callback,
+        on_partitions_revoked=on_partitions_revoked,
+        on_partitions_assigned=on_partitions_assigned,
+    )
+    consumer._thread = thread
+    return consumer
+
+
+@pytest.fixture()
+def callback():
+    return Mock(name="callback")
+
+
+@pytest.fixture()
+def on_partitions_revoked():
+    return Mock(name="on_partitions_revoked")
+
+
+@pytest.fixture()
+def on_partitions_assigned():
+    return Mock(name="on_partitions_assigned")
+
+
+class TestConsumer:
+    @pytest.fixture()
+    def thread(self):
+        return Mock(
+            name="thread",
+            create_topic=AsyncMock(),
+        )
+
+    @pytest.fixture()
+    def consumer(
+        self, *, thread, app, callback, on_partitions_revoked, on_partitions_assigned
+    ):
+        consumer = Consumer(
+            app.transport,
+            callback=callback,
+            on_partitions_revoked=on_partitions_revoked,
+            on_partitions_assigned=on_partitions_assigned,
+        )
+        consumer._thread = thread
+        return consumer
+
+    @pytest.fixture()
+    def callback(self):
+        return Mock(name="callback")
+
+    @pytest.fixture()
+    def on_partitions_revoked(self):
+        return Mock(name="on_partitions_revoked")
+
+    @pytest.fixture()
+    def on_partitions_assigned(self):
+        return Mock(name="on_partitions_assigned")
+
+    @pytest.mark.asyncio
+    async def test_create_topic(self, *,
consumer, thread): + await consumer.create_topic( + "topic", + 30, + 3, + timeout=40.0, + retention=50.0, + compacting=True, + deleting=True, + ensure_created=True, + ) + thread.create_topic.assert_called_once_with( + "topic", + 30, + 3, + config=None, + timeout=40.0, + retention=50.0, + compacting=True, + deleting=True, + ensure_created=True, + ) + + def test__new_topicpartition(self, *, consumer): + tp = consumer._new_topicpartition("t", 3) + assert isinstance(tp, TopicPartition) + assert tp.topic == "t" + assert tp.partition == 3 + + def test__to_message(self, *, consumer): + record = self.mock_record( + timestamp=3000, + headers=[("a", b"b")], + ) + m = consumer._to_message(TopicPartition("t", 3), record) + assert m.topic == record.topic + assert m.partition == record.partition + assert m.offset == record.offset + assert m.timestamp == 3.0 + assert m.headers == record.headers + assert m.key == record.key + assert m.value == record.value + assert m.checksum == record.checksum + assert m.serialized_key_size == record.serialized_key_size + assert m.serialized_value_size == record.serialized_value_size + + def test__to_message__no_timestamp(self, *, consumer): + record = self.mock_record(timestamp=None) + m = consumer._to_message(TopicPartition("t", 3), record) + assert m.timestamp is None + + def mock_record( + self, + topic="t", + partition=3, + offset=1001, + timestamp=None, + timestamp_type=1, + headers=None, + key=b"key", + value=b"value", + checksum=312, + serialized_key_size=12, + serialized_value_size=40, + **kwargs, + ): + return Mock( + name="record", + topic=topic, + partition=partition, + offset=offset, + timestamp=timestamp, + timestamp_type=timestamp_type, + headers=headers, + key=key, + value=value, + checksum=checksum, + serialized_key_size=serialized_key_size, + serialized_value_size=serialized_value_size, + ) + + +class ConfluentConsumerThreadFixtures: + @pytest.fixture() + def cthread(self, *, consumer): + return ConfluentConsumerThread(consumer) + + @pytest.fixture() + def tracer(self, *, app): + tracer = app.tracer = Mock(name="tracer") + tobj = tracer.get_tracer.return_value + + def start_span(operation_name=None, **kwargs): + span = opentracing.Span( + tracer=tobj, + context=opentracing.SpanContext(), + ) + + if operation_name is not None: + span.operation_name = operation_name + assert span.operation_name == operation_name + return span + + tobj.start_span = start_span + return tracer + + @pytest.fixture() + def _consumer(self): + return Mock( + name="Consumer", + autospec=confluent_kafka.Consumer, + start=AsyncMock(), + stop=AsyncMock(), + commit=AsyncMock(), + position=AsyncMock(), + end_offsets=AsyncMock(), + _client=Mock(name="Client", close=AsyncMock()), + _coordinator=Mock(name="Coordinator", close=AsyncMock()), + ) + + @pytest.fixture() + def now(self): + return 1201230410 + + @pytest.fixture() + def tp(self): + return TP("foo", 30) + + @pytest.fixture() + def aiotp(self, *, tp): + return TopicPartition(tp.topic, tp.partition) + + @pytest.fixture() + def logger(self, *, cthread): + cthread.log = Mock(name="cthread.log") + return cthread.log + + +class Test_verify_event_path_base(ConfluentConsumerThreadFixtures): + last_request: Optional[float] = None + last_response: Optional[float] = None + highwater: int = 1 + committed_offset: int = 1 + acks_enabled: bool = False + stream_inbound: Optional[float] = None + last_commit: Optional[float] = None + expected_message: Optional[str] = None + has_monitor = True + + def _set_started(self, t): + 
self._cthread.time_started = t + + def _set_last_request(self, last_request): + self.__consumer.records_last_request[self._aiotp] = last_request + + def _set_last_response(self, last_response): + self.__consumer.records_last_response[self._aiotp] = last_response + + def _set_stream_inbound(self, inbound_time): + self._app.monitor.stream_inbound_time[self._tp] = inbound_time + + def _set_last_commit(self, commit_time): + self._cthread.tp_last_committed_at[self._tp] = commit_time + + @pytest.fixture(autouse=True) + def aaaa_setup_attributes(self, *, app, cthread, _consumer, now, tp, aiotp): + self._app = app + self._tp = tp + self._aiotp = aiotp + self._now = now + self._cthread = cthread + self.__consumer = _consumer + + @pytest.fixture(autouse=True) + def setup_consumer(self, *, app, cthread, _consumer, now, tp, aiotp): + assert self._tp is tp + assert self._aiotp is aiotp + # patch self.acks_enabledc + app.topics.acks_enabled_for = Mock(name="acks_enabled_for") + app.topics.acks_enabled_for.return_value = self.acks_enabled + + # patch consumer.time_started + self._set_started(now) + + # connect underlying AIOKafkaConsumer object. + cthread._consumer = _consumer + + # patch AIOKafkaConsumer.records_last_request to self.last_request + _consumer.records_last_request = {} + if self.last_request is not None: + self._set_last_request(self.last_request) + + # patch AIOKafkaConsumer.records_last_response to self.last_response + _consumer.records_last_response = {} + if self.last_response is not None: + self._set_last_response(self.last_response) + + # patch app.monitor + if self.has_monitor: + cthread.consumer.app.monitor = Mock(name="monitor", spec=Monitor) + app.monitor = cthread.consumer.app.monitor + app.monitor = Mock(name="monitor", spec=Monitor) + # patch monitor.stream_inbound_time + # this is the time when a stream last processed a record + # for tp + app.monitor.stream_inbound_time = {} + self._set_stream_inbound(self.stream_inbound) + else: + app.monitor = None + + # patch highwater + cthread.highwater = Mock(name="highwater") + cthread.highwater.return_value = self.highwater + + # patch committed offset + cthread.consumer._committed_offset = { + tp: self.committed_offset, + } + + cthread.tp_last_committed_at = {} + self._set_last_commit(self.last_commit) + + def test_state(self, *, cthread, now): + # verify that setup_consumer fixture was applied + assert cthread.time_started == now + + +class Test_ConfluentConsumerThread(ConfluentConsumerThreadFixtures): + def test_constructor(self, *, cthread): + assert cthread._partitioner + assert cthread._rebalance_listener + + @pytest.mark.asyncio + async def test_on_start(self, *, cthread, _consumer): + cthread._create_consumer = Mock( + name="_create_consumer", + return_value=_consumer, + ) + await cthread.on_start() + + assert cthread._consumer is cthread._create_consumer.return_value + cthread._create_consumer.assert_called_once_with(loop=cthread.thread_loop) + cthread._consumer.start.assert_called_once_with() + + @pytest.mark.asyncio + async def test_on_thread_stop(self, *, cthread, _consumer): + cthread._consumer = _consumer + await cthread.on_thread_stop() + cthread._consumer.stop.assert_called_once_with() + + @pytest.mark.asyncio + async def test_on_thread_stop__consumer_not_started(self, *, cthread): + cthread._consumer = None + await cthread.on_thread_stop() + + def test__create_consumer__client(self, *, cthread, app): + app.client_only = True + loop = Mock(name="loop") + cthread._create_client_consumer = 
Mock(name="_create_client_consumer") + c = cthread._create_consumer(loop=loop) + assert c is cthread._create_client_consumer.return_value + cthread._create_client_consumer.assert_called_once_with(cthread.transport) + + def test__create_consumer__worker(self, *, cthread, app): + app.client_only = False + loop = Mock(name="loop") + cthread._create_worker_consumer = Mock(name="_create_worker_consumer") + c = cthread._create_consumer(loop=loop) + assert c is cthread._create_worker_consumer.return_value + cthread._create_worker_consumer.assert_called_once_with(cthread.transport) + + def test_session_gt_request_timeout(self, *, cthread, app): + app.conf.broker_session_timeout = 90 + app.conf.broker_request_timeout = 10 + + with pytest.raises(ImproperlyConfigured): + self.assert_create_worker_consumer(cthread, app, in_transaction=False) + + def test__create_worker_consumer(self, *, cthread, app): + self.assert_create_worker_consumer( + cthread, + app, + in_transaction=False, + isolation_level="read_uncommitted", + ) + + def test__create_worker_consumer__transaction(self, *, cthread, app): + self.assert_create_worker_consumer( + cthread, + app, + in_transaction=True, + isolation_level="read_committed", + ) + + def assert_create_worker_consumer( + self, + cthread, + app, + in_transaction=False, + isolation_level="read_uncommitted", + api_version=None, + ): + transport = cthread.transport + conf = app.conf + cthread.consumer.in_transaction = in_transaction + auth_settings = credentials_to_confluent_kafka_auth( + conf.broker_credentials, conf.ssl_context + ) + with patch("confluent_kafka.AIOKafkaConsumer") as AIOKafkaConsumer: + c = cthread._create_worker_consumer(transport) + assert c is AIOKafkaConsumer.return_value + max_poll_interval = conf.broker_max_poll_interval + AIOKafkaConsumer.assert_called_once_with( + api_version=app.conf.consumer_api_version, + client_id=conf.broker_client_id, + group_id=conf.id, + # group_instance_id=conf.consumer_group_instance_id, + bootstrap_servers=server_list(transport.url, transport.default_port), + partition_assignment_strategy=[cthread._assignor], + enable_auto_commit=False, + auto_offset_reset=conf.consumer_auto_offset_reset, + max_poll_records=conf.broker_max_poll_records, + max_poll_interval_ms=int(max_poll_interval * 1000.0), + max_partition_fetch_bytes=conf.consumer_max_fetch_size, + fetch_max_wait_ms=1500, + request_timeout_ms=int(conf.broker_request_timeout * 1000.0), + rebalance_timeout_ms=int(conf.broker_rebalance_timeout * 1000.0), + check_crcs=conf.broker_check_crcs, + session_timeout_ms=int(conf.broker_session_timeout * 1000.0), + heartbeat_interval_ms=int(conf.broker_heartbeat_interval * 1000.0), + isolation_level=isolation_level, + connections_max_idle_ms=conf.consumer_connections_max_idle_ms, + metadata_max_age_ms=conf.consumer_metadata_max_age_ms, + # traced_from_parent_span=cthread.traced_from_parent_span, + # start_rebalancing_span=cthread.start_rebalancing_span, + # start_coordinator_span=cthread.start_coordinator_span, + # on_generation_id_known=cthread.on_generation_id_known, + # flush_spans=cthread.flush_spans, + **auth_settings, + ) + + def test__create_client_consumer(self, *, cthread, app): + transport = cthread.transport + conf = app.conf + auth_settings = credentials_to_confluent_kafka_auth( + conf.broker_credentials, conf.ssl_context + ) + with patch("confluent_kafka.AIOKafkaConsumer") as AIOKafkaConsumer: + c = cthread._create_client_consumer(transport) + max_poll_interval = conf.broker_max_poll_interval + assert c is 
AIOKafkaConsumer.return_value + AIOKafkaConsumer.assert_called_once_with( + client_id=conf.broker_client_id, + bootstrap_servers=server_list(transport.url, transport.default_port), + request_timeout_ms=int(conf.broker_request_timeout * 1000.0), + max_poll_interval_ms=int(max_poll_interval * 1000.0), + enable_auto_commit=True, + max_poll_records=conf.broker_max_poll_records, + auto_offset_reset=conf.consumer_auto_offset_reset, + check_crcs=conf.broker_check_crcs, + **auth_settings, + ) + + def test__start_span(self, *, cthread, app): + with patch(TESTED_MODULE + ".set_current_span") as s: + app.tracer = Mock(name="tracer") + span = cthread._start_span("test") + app.tracer.get_tracer.assert_called_once_with( + f"{app.conf.name}-_confluent_kafka" + ) + tracer = app.tracer.get_tracer.return_value + tracer.start_span.assert_called_once_with(operation_name="test") + span.set_tag.assert_has_calls( + [ + call(tags.SAMPLING_PRIORITY, 1), + call("faust_app", app.conf.name), + call("faust_id", app.conf.id), + ] + ) + s.assert_called_once_with(span) + assert span is tracer.start_span.return_value + + def test_trace_category(self, *, cthread, app): + assert cthread.trace_category == f"{app.conf.name}-_confluent_kafka" + + @pytest.mark.skip("Needs fixing") + def test_transform_span_lazy(self, *, cthread, app, tracer): + cthread._consumer = Mock(name="_consumer") + cthread._consumer._coordinator.generation = -1 + self.assert_setup_lazy_spans(cthread, app, tracer) + + cthread._consumer._coordinator.generation = 10 + pending = cthread._pending_rebalancing_spans + assert len(pending) == 3 + + cthread.on_generation_id_known() + assert not pending + + @pytest.mark.skip("Needs fixing") + def test_transform_span_flush_spans(self, *, cthread, app, tracer): + cthread._consumer = Mock(name="_consumer") + cthread._consumer._coordinator.generation = -1 + self.assert_setup_lazy_spans(cthread, app, tracer) + pending = cthread._pending_rebalancing_spans + assert len(pending) == 3 + + cthread.flush_spans() + assert not pending + + def test_span_without_operation_name(self, *, cthread): + span = opentracing.Span( + tracer=Mock("tobj"), + context=opentracing.SpanContext(), + ) + + assert cthread._on_span_cancelled_early(span) is None + + @pytest.mark.skip("Needs fixing") + def test_transform_span_lazy_no_consumer(self, *, cthread, app, tracer): + cthread._consumer = Mock(name="_consumer") + cthread._consumer._coordinator.generation = -1 + self.assert_setup_lazy_spans(cthread, app, tracer) + + cthread._consumer = None + pending = cthread._pending_rebalancing_spans + assert len(pending) == 3 + + while pending: + span = pending.popleft() + cthread._on_span_generation_known(span) + + @pytest.mark.skip("Needs fixing") + def test_transform_span_eager(self, *, cthread, app, tracer): + cthread._consumer = Mock(name="_consumer") + cthread._consumer._coordinator.generation = 10 + self.assert_setup_lazy_spans(cthread, app, tracer, expect_lazy=False) + + def assert_setup_lazy_spans(self, cthread, app, tracer, expect_lazy=True): + got_foo = got_bar = got_baz = False + + def foo(): + nonlocal got_foo + got_foo = True + T = cthread.traced_from_parent_span(None, lazy=True) + T(bar)() + + def bar(): + nonlocal got_bar + got_bar = True + T = cthread.traced_from_parent_span(None, lazy=True) + T(REPLACE_WITH_MEMBER_ID)() + + def REPLACE_WITH_MEMBER_ID(): + nonlocal got_baz + got_baz = True + + with cthread.start_rebalancing_span() as span: + T = cthread.traced_from_parent_span(span) + T(foo)() + if expect_lazy: + assert 
len(cthread._pending_rebalancing_spans) == 2 + + assert got_foo + assert got_bar + assert got_baz + if expect_lazy: + assert len(cthread._pending_rebalancing_spans) == 3 + else: + assert not cthread._pending_rebalancing_spans + + def test__start_span__no_tracer(self, *, cthread, app): + app.tracer = None + with cthread._start_span("test") as span: + assert span + + def test_traced_from_parent_span(self, *, cthread): + with patch(TESTED_MODULE + ".traced_from_parent_span") as traced: + parent_span = Mock(name="parent_span") + ret = cthread.traced_from_parent_span(parent_span, foo=303) + traced.assert_called_once_with(parent_span, callback=None, foo=303) + assert ret is traced.return_value + + def test_start_rebalancing_span(self, *, cthread): + cthread._start_span = Mock() + ret = cthread.start_rebalancing_span() + assert ret is cthread._start_span.return_value + cthread._start_span.assert_called_once_with("rebalancing", lazy=True) + + def test_start_coordinator_span(self, *, cthread): + cthread._start_span = Mock() + ret = cthread.start_coordinator_span() + assert ret is cthread._start_span.return_value + cthread._start_span.assert_called_once_with("coordinator") + + def test_close(self, *, cthread, _consumer): + cthread._consumer = _consumer + cthread.close() + assert _consumer._closed + _consumer._coordinator.close.assert_called_once_with() + + def test_close__no_consumer(self, *, cthread): + cthread._consumer = None + cthread.close() + + @pytest.mark.asyncio + async def test_subscribe(self, *, cthread, _consumer): + with self.assert_calls_thread( + cthread, + _consumer, + _consumer.subscribe, + topics={"foo", "bar"}, + listener=cthread._rebalance_listener, + ): + await cthread.subscribe(["foo", "bar"]) + + @pytest.mark.asyncio + async def test_seek_to_committed(self, *, cthread, _consumer): + with self.assert_calls_thread(cthread, _consumer, _consumer.seek_to_committed): + await cthread.seek_to_committed() + + @pytest.mark.asyncio + async def test_commit(self, *, cthread, _consumer): + offsets = {TP1: 100} + with self.assert_calls_thread(cthread, _consumer, cthread._commit, offsets): + await cthread.commit(offsets) + + @pytest.mark.skip("Needs fixing") + @pytest.mark.asyncio + async def test__commit(self, *, cthread, _consumer): + offsets = {TP1: 1001} + cthread._consumer = _consumer + await cthread._commit(offsets) + + _consumer.commit.assert_called_once_with( + {TP1: OffsetAndMetadata(1001, "")}, + ) + + @pytest.mark.skip("Needs fixing") + @pytest.mark.asyncio + async def test__commit__already_rebalancing(self, *, cthread, _consumer): + cthread._consumer = _consumer + _consumer.commit.side_effect = CommitFailedError("already rebalanced") + assert not (await cthread._commit({TP1: 1001})) + + @pytest.mark.skip("Needs fixing") + @pytest.mark.asyncio + async def test__commit__CommitFailedError(self, *, cthread, _consumer): + cthread._consumer = _consumer + exc = _consumer.commit.side_effect = CommitFailedError("xx") + cthread.crash = AsyncMock() + cthread.supervisor = Mock(name="supervisor") + assert not (await cthread._commit({TP1: 1001})) + cthread.crash.assert_called_once_with(exc) + cthread.supervisor.wakeup.assert_called_once() + + @pytest.mark.skip("Needs fixing") + @pytest.mark.asyncio + async def test__commit__IllegalStateError(self, *, cthread, _consumer): + cthread._consumer = _consumer + cthread.assignment = Mock() + exc = _consumer.commit.side_effect = IllegalStateError("xx") + cthread.crash = AsyncMock() + cthread.supervisor = Mock(name="supervisor") + assert not (await 
cthread._commit({TP1: 1001})) + cthread.crash.assert_called_once_with(exc) + cthread.supervisor.wakeup.assert_called_once() + + @pytest.mark.asyncio + async def test_position(self, *, cthread, _consumer): + with self.assert_calls_thread(cthread, _consumer, _consumer.position, TP1): + await cthread.position(TP1) + + @pytest.mark.asyncio + async def test_seek_to_beginning(self, *, cthread, _consumer): + partitions = (TP1,) + with self.assert_calls_thread( + cthread, _consumer, _consumer.seek_to_beginning, *partitions + ): + await cthread.seek_to_beginning(*partitions) + + @pytest.mark.asyncio + async def test_seek_wait(self, *, cthread, _consumer): + partitions = {TP1: 1001} + with self.assert_calls_thread( + cthread, _consumer, cthread._seek_wait, _consumer, partitions + ): + await cthread.seek_wait(partitions) + + @pytest.mark.asyncio + async def test__seek_wait(self, *, cthread, _consumer): + cthread._consumer = _consumer + cthread.consumer._read_offset.clear() + partitions = {TP1: 0, TP2: 3} + + await cthread._seek_wait(_consumer, partitions) + + assert cthread.consumer._read_offset[TP2] == 3 + assert TP1 not in cthread.consumer._read_offset + + _consumer.position.assert_has_calls( + [ + call(TP1), + call(TP2), + ] + ) + + @pytest.mark.asyncio + async def test__seek_wait__empty(self, *, cthread, _consumer): + await cthread._seek_wait(_consumer, {}) + + def test_seek(self, *, cthread, _consumer): + cthread._consumer = _consumer + cthread.seek(TP1, 10) + _consumer.seek.assert_called_once_with(TP1, 10) + + def test_assignment(self, *, cthread, _consumer): + cthread._consumer = _consumer + _consumer.assignment.return_value = { + TopicPartition(TP1.topic, TP1.partition), + } + assignment = cthread.assignment() + assert assignment == {TP1} + assert all(isinstance(x, TP) for x in assignment) + + def test_highwater(self, *, cthread, _consumer): + cthread._consumer = _consumer + cthread.consumer.in_transaction = False + ret = cthread.highwater(TP1) + assert ret is _consumer.highwater.return_value + _consumer.highwater.assert_called_once_with(TP1) + + def test_highwater__in_transaction(self, *, cthread, _consumer): + cthread._consumer = _consumer + cthread.consumer.in_transaction = True + ret = cthread.highwater(TP1) + assert ret is _consumer.last_stable_offset.return_value + _consumer.last_stable_offset.assert_called_once_with(TP1) + + def test_topic_partitions(self, *, cthread, _consumer): + cthread._consumer = None + assert cthread.topic_partitions("foo") is None + cthread._consumer = _consumer + assert cthread.topic_partitions("foo") is ( + _consumer._coordinator._metadata_snapshot.get.return_value + ) + + @pytest.mark.asyncio + async def test_earliest_offsets(self, *, cthread, _consumer): + with self.assert_calls_thread( + cthread, _consumer, _consumer.beginning_offsets, (TP1,) + ): + await cthread.earliest_offsets(TP1) + + @pytest.mark.asyncio + async def test_highwaters(self, *, cthread, _consumer): + with self.assert_calls_thread(cthread, _consumer, cthread._highwaters, (TP1,)): + await cthread.highwaters(TP1) + + @pytest.mark.asyncio + async def test__highwaters(self, *, cthread, _consumer): + cthread.consumer.in_transaction = False + cthread._consumer = _consumer + assert await cthread._highwaters([TP1]) is (_consumer.end_offsets.return_value) + + @pytest.mark.asyncio + async def test__highwaters__in_transaction(self, *, cthread, _consumer): + cthread.consumer.in_transaction = True + cthread._consumer = _consumer + assert await cthread._highwaters([TP1]) == { + TP1: 
_consumer.last_stable_offset.return_value, + } + + def test__ensure_consumer(self, *, cthread, _consumer): + cthread._consumer = _consumer + assert cthread._ensure_consumer() is _consumer + cthread._consumer = None + with pytest.raises(ConsumerNotStarted): + cthread._ensure_consumer() + + @pytest.mark.asyncio + async def test_getmany(self, *, cthread, _consumer): + timeout = 13.1 + active_partitions = {TP1} + with self.assert_calls_thread( + cthread, + _consumer, + cthread._fetch_records, + _consumer, + active_partitions, + timeout=timeout, + max_records=_consumer._max_poll_records, + ): + await cthread.getmany(active_partitions, timeout) + + def test_key_partition(self, *, cthread, _consumer): + cthread._consumer = _consumer + cthread._partitioner = Mock(name="partitioner") + metadata = _consumer._client.cluster + metadata.partitions_for_topic.return_value = [1, 2, 3] + metadata.available_partitions_for_topic.return_value = [2, 3] + + cthread.key_partition("topic", "k", None) + cthread._partitioner.assert_called_once_with( + "k", + [1, 2, 3], + [2, 3], + ) + + with pytest.raises(AssertionError): + cthread.key_partition("topic", "k", -1) + with pytest.raises(AssertionError): + cthread.key_partition("topic", "k", 4) + + assert cthread.key_partition("topic", "k", 3) == 3 + + def test_key_partition__no_metadata(self, *, cthread, _consumer): + cthread._consumer = _consumer + cthread._partitioner = Mock(name="partitioner") + metadata = _consumer._client.cluster + metadata.partitions_for_topic.return_value = None + + assert cthread.key_partition("topic", "k", None) is None + + @contextmanager + def assert_calls_thread(self, cthread, _consumer, method, *args, **kwargs): + cthread._consumer = _consumer + cthread.call_thread = AsyncMock() + try: + yield + finally: + cthread.call_thread.assert_called_once_with(method, *args, **kwargs) + + +class MyPartitioner: + ... 
+ + +my_partitioner = MyPartitioner() + + +class ProducerBaseTest: + @pytest.fixture() + def producer(self, *, app, _producer): + producer = Producer(app.transport) + producer._new_producer = Mock(return_value=_producer) + producer._producer = _producer + + # I can't figure out what is setting a value for this, + # so we're clearing out the dict after creation + producer._transaction_producers = {} + + return producer + + @pytest.fixture() + def _producer(self, *, _producer_call): + return _producer_call() + + @pytest.fixture() + def _producer_call(self): + def inner(): + return Mock( + name="Producer", + autospec=confluent_kafka.Producer, + start=AsyncMock(), + stop=AsyncMock(), + begin_transaction=AsyncMock(), + commit_transaction=AsyncMock(), + abort_transaction=AsyncMock(), + stop_transaction=AsyncMock(), + maybe_begin_transaction=AsyncMock(), + commit=AsyncMock(), + send=AsyncMock(), + flush=AsyncMock(), + send_offsets_to_transaction=AsyncMock(), + ) + + return inner + + def assert_new_producer( + self, + producer, + acks=-1, + api_version="auto", + bootstrap_servers=["localhost:9092"], # noqa, + client_id=f"faust-{faust.__version__}", + compression_type=None, + linger_ms=0, + max_batch_size=16384, + max_request_size=1000000, + request_timeout_ms=1200000, + security_protocol="PLAINTEXT", + **kwargs, + ): + with patch("confluent_kafka.Producer") as Producer: + p = producer._new_producer() + assert p is Producer.return_value + Producer.assert_called_once_with( + acks=acks, + api_version=api_version, + bootstrap_servers=bootstrap_servers, + client_id=client_id, + compression_type=compression_type, + linger_ms=linger_ms, + max_batch_size=max_batch_size, + max_request_size=max_request_size, + request_timeout_ms=request_timeout_ms, + security_protocol=security_protocol, + loop=producer.loop, + partitioner=producer.partitioner, + transactional_id=None, + **kwargs, + ) + + +class TestProducer(ProducerBaseTest): + @pytest.mark.conf(producer_partitioner=my_partitioner) + def test_producer__uses_custom_partitioner(self, *, producer): + assert producer.partitioner is my_partitioner + + @pytest.mark.asyncio + async def test_begin_transaction(self, *, producer, _producer): + await producer.begin_transaction("tid") + _producer.begin_transaction.assert_called_once() + + @pytest.mark.asyncio + async def test_commit_transaction(self, *, producer, _producer): + await producer.begin_transaction("tid") + await producer.commit_transaction("tid") + _producer.commit_transaction.assert_called_once() + + @pytest.mark.asyncio + async def test_abort_transaction(self, *, producer, _producer): + await producer.begin_transaction("tid") + await producer.abort_transaction("tid") + _producer.abort_transaction.assert_called_once() + + @pytest.mark.asyncio + async def test_stop_transaction(self, *, producer, _producer): + await producer.begin_transaction("tid") + await producer.stop_transaction("tid") + _producer.stop.assert_called_once() + + @pytest.mark.asyncio + async def test_maybe_begin_transaction(self, *, producer, _producer): + await producer.maybe_begin_transaction("tid") + _producer.begin_transaction.assert_called_once() + + @pytest.mark.asyncio + async def test_commit_transactions(self, *, producer, _producer_call): + _producer1 = _producer_call() + _producer2 = _producer_call() + producer._new_producer = Mock(return_value=_producer1) + await producer.begin_transaction("t1") + producer._new_producer = Mock(return_value=_producer2) + await producer.begin_transaction("t2") + tid_to_offset_map = {"t1": {TP1: 
1001}, "t2": {TP2: 2002}} + + await producer.commit_transactions( + tid_to_offset_map, "group_id", start_new_transaction=False + ) + + _producer1.send_offsets_to_transaction.assert_called_once_with( + {TP1: 1001}, "group_id" + ) + _producer2.send_offsets_to_transaction.assert_called_once_with( + {TP2: 2002}, "group_id" + ) + _producer1.commit_transaction.assert_called_once() + _producer2.commit_transaction.assert_called_once() + + def test__settings_extra(self, *, producer, app): + app.in_transaction = True + assert producer._settings_extra() == {"acks": "all", "enable_idempotence": True} + app.in_transaction = False + assert producer._settings_extra() == {} + + def test__producer_type(self, *, producer, app): + assert isinstance(producer, Producer) + + @pytest.mark.asyncio + async def test_create_topic(self, *, producer, _producer): + _producer.client = Mock( + force_metadata_update=AsyncMock(), + ) + producer.transport = Mock( + _create_topic=AsyncMock(), + ) + await producer.create_topic( + "foo", + 100, + 3, + config={"x": "y"}, + timeout=30.3, + retention=300.3, + compacting=True, + deleting=True, + ensure_created=True, + ) + producer._create_topic.assert_called_once_with( + producer, + _producer.client, + "foo", + 100, + 3, + config={"x": "y"}, + timeout=int(30.3 * 1000.0), + retention=int(300.3 * 1000.0), + compacting=True, + deleting=True, + ensure_created=True, + ) + + def test__ensure_producer(self, *, producer, _producer): + assert producer._ensure_producer() is _producer + producer._producer = None + with pytest.raises(NotReady): + producer._ensure_producer() + + @pytest.mark.asyncio + async def test_on_start(self, *, producer, loop): + producer._new_producer = Mock( + name="_new_producer", + return_value=Mock( + start=AsyncMock(), + ), + ) + _producer = producer._new_producer.return_value + producer.beacon = Mock() + + await producer.on_start() + assert producer._producer is _producer + producer._new_producer.assert_called_once_with() + producer.beacon.add.assert_called_with(_producer) + _producer.start.assert_called_once_with() + + @pytest.mark.asyncio + async def test_on_stop(self, *, producer, _producer): + await producer.on_stop() + assert producer._producer is None + _producer.stop.assert_called_once_with() + + def test_supports_headers__not_ready(self, *, producer): + producer._producer.client = None + with pytest.raises(NotReady): + producer.supports_headers() + + @pytest.mark.asyncio + async def test_send(self, producer, _producer): + await producer.begin_transaction("tid") + await producer.send( + "topic", + "k", + "v", + 3, + 100, + {"foo": "bar"}, + transactional_id="tid", + ) + _producer.send.assert_called_once_with( + "topic", + "v", + key="k", + partition=3, + timestamp_ms=100 * 1000.0, + headers=[("foo", "bar")], + ) + + @pytest.mark.asyncio + @pytest.mark.conf(producer_api_version="0.10") + async def test_send__request_no_headers(self, producer, _producer): + await producer.begin_transaction("tid") + await producer.send( + "topic", + "k", + "v", + 3, + 100, + {"foo": "bar"}, + transactional_id="tid", + ) + _producer.send.assert_called_once_with( + "topic", + "v", + key="k", + partition=3, + timestamp_ms=100 * 1000.0, + headers=None, + ) + + @pytest.mark.asyncio + @pytest.mark.conf(producer_api_version="0.11") + async def test_send__kafka011_supports_headers(self, producer, _producer): + await producer.begin_transaction("tid") + await producer.send( + "topic", + "k", + "v", + 3, + 100, + {"foo": "bar"}, + transactional_id="tid", + ) + 
+        _producer.send.assert_called_once_with(
+            "topic",
+            "v",
+            key="k",
+            partition=3,
+            timestamp_ms=100 * 1000.0,
+            headers=[("foo", "bar")],
+        )
+
+    @pytest.mark.asyncio
+    @pytest.mark.conf(producer_api_version="auto")
+    async def test_send__auto_passes_headers(self, producer, _producer):
+        await producer.begin_transaction("tid")
+        await producer.send(
+            "topic",
+            "k",
+            "v",
+            3,
+            100,
+            [("foo", "bar")],
+            transactional_id="tid",
+        )
+        _producer.send.assert_called_once_with(
+            "topic",
+            "v",
+            key="k",
+            partition=3,
+            timestamp_ms=100 * 1000.0,
+            headers=[("foo", "bar")],
+        )
+
+    @pytest.mark.asyncio
+    async def test_send__no_headers(self, producer, _producer):
+        await producer.begin_transaction("tid")
+        await producer.send(
+            "topic",
+            "k",
+            "v",
+            3,
+            100,
+            None,
+            transactional_id="tid",
+        )
+        _producer.send.assert_called_once_with(
+            "topic",
+            "v",
+            key="k",
+            partition=3,
+            timestamp_ms=100 * 1000.0,
+            headers=None,
+        )
+
+    @pytest.mark.asyncio
+    async def test_send__no_timestamp(self, producer, _producer):
+        await producer.begin_transaction("tid")
+        await producer.send(
+            "topic",
+            "k",
+            "v",
+            3,
+            None,
+            None,
+            transactional_id="tid",
+        )
+        _producer.send.assert_called_once_with(
+            "topic",
+            "v",
+            key="k",
+            partition=3,
+            timestamp_ms=None,
+            headers=None,
+        )
+
+    @pytest.mark.asyncio
+    async def test_send__KafkaError(self, producer, _producer):
+        _producer.send.side_effect = KafkaError()
+        with pytest.raises(ProducerSendError):
+            await producer.send(
+                "topic",
+                "k",
+                "v",
+                3,
+                None,
+                None,
+            )
+
+    @pytest.mark.asyncio
+    async def test_send__trn_KafkaError(self, producer, _producer):
+        _producer.send.side_effect = KafkaError()
+        await producer.begin_transaction("tid")
+        with pytest.raises(ProducerSendError):
+            await producer.send(
+                "topic",
+                "k",
+                "v",
+                3,
+                None,
+                None,
+                transactional_id="tid",
+            )
+
+    @pytest.mark.asyncio
+    async def test_send_and_wait(self, producer):
+        producer.send = AsyncMock(return_value=done_future(done_future()))
+
+        await producer.send_and_wait(
+            "topic", "k", "v", 3, 100, [("a", "b")], transactional_id="tid"
+        )
+        producer.send.assert_called_once_with(
+            "topic",
+            key="k",
+            value="v",
+            partition=3,
+            timestamp=100,
+            headers=[("a", "b")],
+            transactional_id="tid",
+        )
+
+    @pytest.mark.asyncio
+    async def test_flush(self, *, producer, _producer):
+        producer._producer = None
+        await producer.flush()
+        producer._producer = _producer
+        await producer.flush()
+        _producer.flush.assert_called_once_with()
+
+    def test_key_partition(self, *, producer, _producer):
+        x = producer.key_partition("topic", "k")
+        assert x == TP("topic", _producer._partition.return_value)
+
+    def test_supports_headers(self, *, producer):
+        producer._producer.client.api_version = (0, 11)
+        assert producer.supports_headers()
+
+
+class TestProducerThread(ProducerBaseTest):
+    @pytest.fixture()
+    def threaded_producer(self, *, producer: Producer):
+        return producer.create_threaded_producer()
+
+    @pytest.fixture()
+    def new_producer_mock(self, *, threaded_producer: ProducerThread):
+        mock = threaded_producer._new_producer = Mock(
+            name="_new_producer",
+            return_value=Mock(
+                start=AsyncMock(),
+                stop=AsyncMock(),
+                flush=AsyncMock(),
+                send_and_wait=AsyncMock(),
+                send=AsyncMock(),
+            ),
+        )
+        return mock
+
+    @pytest.fixture()
+    def mocked_producer(self, *, new_producer_mock: Mock):
+        return new_producer_mock.return_value
+
+    @pytest.mark.asyncio
+    async def test_on_start(
+        self, *, threaded_producer: ProducerThread, mocked_producer: Mock, loop
+    ):
+        await threaded_producer.on_start()
+        try:
+            assert threaded_producer._producer is mocked_producer
+            threaded_producer._new_producer.assert_called_once_with()
+            mocked_producer.start.assert_called_once_with()
+        finally:
+            await threaded_producer.start()
+            await threaded_producer.stop()
+
+    @pytest.mark.skip("Needs fixing")
+    @pytest.mark.asyncio
+    async def test_on_thread_stop(
+        self, *, threaded_producer: ProducerThread, mocked_producer: Mock, loop
+    ):
+        await threaded_producer.start()
+        await threaded_producer.on_thread_stop()
+        try:
+            # Flush and stop are currently called twice
+            mocked_producer.flush.assert_called_once()
+            mocked_producer.stop.assert_called_once()
+        finally:
+            await threaded_producer.stop()
+
+    @pytest.mark.asyncio
+    async def test_publish_message(
+        self, *, threaded_producer: ProducerThread, mocked_producer: Mock, loop
+    ):
+        await threaded_producer.start()
+        try:
+            await threaded_producer.publish_message(
+                fut_other=FutureMessage(
+                    PendingMessage(
+                        channel=Mock(),
+                        key="Test",
+                        value="Test",
+                        partition=None,
+                        timestamp=None,
+                        headers=None,
+                        key_serializer=None,
+                        value_serializer=None,
+                        callback=None,
+                    )
+                )
+            )
+            mocked_producer.send.assert_called_once()
+        finally:
+            await threaded_producer.stop()
+
+    @pytest.mark.asyncio
+    async def test_publish_message_with_wait(
+        self, *, threaded_producer: ProducerThread, mocked_producer: Mock, loop
+    ):
+        await threaded_producer.start()
+        try:
+            await threaded_producer.publish_message(
+                wait=True,
+                fut_other=FutureMessage(
+                    PendingMessage(
+                        channel=Mock(),
+                        key="Test",
+                        value="Test",
+                        partition=None,
+                        timestamp=None,
+                        headers=None,
+                        key_serializer=None,
+                        value_serializer=None,
+                        callback=None,
+                    )
+                ),
+            )
+            mocked_producer.send_and_wait.assert_called_once()
+        finally:
+            await threaded_producer.stop()
+
+
+class TestTransport:
+    @pytest.fixture()
+    def transport(self, *, app):
+        return Transport(url=["confluent://"], app=app)
+
+    def test__topic_config(self, *, transport):
+        assert transport._topic_config() == {}
+
+    def test__topic_config__retention(self, *, transport):
+        assert transport._topic_config(retention=3000.3) == {
+            "retention.ms": 3000.3,
+        }
+
+    def test__topic_config__compacting(self, *, transport):
+        assert transport._topic_config(compacting=True) == {
+            "cleanup.policy": "compact",
+        }
+
+    def test__topic_config__deleting(self, *, transport):
+        assert transport._topic_config(deleting=True) == {
+            "cleanup.policy": "delete",
+        }
+
+    def test__topic_config__combined(self, *, transport):
+        res = transport._topic_config(compacting=True, deleting=True, retention=3000.3)
+        assert res == {
+            "retention.ms": 3000.3,
+            "cleanup.policy": "compact,delete",
+        }
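+
+    # Hedged sketch, not part of the original suite: assuming retention
+    # and the cleanup flags compose independently (as the combined case
+    # above suggests), retention plus a single cleanup flag should yield
+    # both keys. Note that _topic_config() receives retention already in
+    # milliseconds; the seconds-to-milliseconds conversion happens
+    # earlier, in create_topic().
+    def test__topic_config__retention_and_compacting_sketch(self, *, transport):
+        res = transport._topic_config(retention=3000.3, compacting=True)
+        assert res == {
+            "retention.ms": 3000.3,
+            "cleanup.policy": "compact",
+        }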