diff --git a/.pylintrc b/.pylintrc index 0c6ba556..b2abd0f9 100644 --- a/.pylintrc +++ b/.pylintrc @@ -169,7 +169,7 @@ contextmanager-decorators=contextlib.contextmanager # List of members which are set dynamically and missed by pylint inference # system, and so shouldn't trigger E1101 when accessed. Python regular # expressions are accepted. -generated-members=zipkin_pb2.* +generated-members=zipkin_pb2.*,profile_pb2.* # Tells whether missing members accessed in mixin class should be ignored. A # mixin class is detected if its name ends with "mixin" (case insensitive). @@ -439,7 +439,8 @@ exclude-protected=_asdict, _replace, _source, _make, - _Span + _Span, + _current_frames # List of valid names for the first argument in a class method. valid-classmethod-first-arg=cls @@ -463,7 +464,7 @@ max-bool-expr=5 max-branches=12 # Maximum number of locals for function / method body. -max-locals=15 +max-locals=32 # Maximum number of parents for a class (see R0901). max-parents=7 @@ -475,7 +476,7 @@ max-public-methods=20 max-returns=6 # Maximum number of statements in function / method body. -max-statements=50 +max-statements=100 # Minimum number of public methods for a class (see R0903). min-public-methods=2 diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index 89625dad..6a5d4312 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -26,7 +26,7 @@ the test project's environment. Assuming the test project environment lives at version of package in the test project. ``` -make develop DEV_ENV=/path/to/test/project/venv +make develop DEV_VENV=/path/to/test/project/venv ``` This will install an editable version of the package in the test project. Any diff --git a/pyproject.toml b/pyproject.toml index 00e9f4c0..b7e4028d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ splunk_distro = "splunk_otel.distro:_SplunkDistro" [tool.poetry.dependencies] cryptography=">=2.0,<=41.0.4" python = "^3.7.2" +protobuf = "^4.23" opentelemetry-api = "1.20.0" opentelemetry-sdk = "1.20.0" opentelemetry-instrumentation = "0.41b0" diff --git a/splunk_otel/__init__.py b/splunk_otel/__init__.py index efac50d8..e4e23384 100644 --- a/splunk_otel/__init__.py +++ b/splunk_otel/__init__.py @@ -20,10 +20,14 @@ """ from .defaults import _set_otel_defaults +from .util import _init_logger + +_init_logger("splunk_otel") _set_otel_defaults() # pylint: disable=wrong-import-position from .metrics import start_metrics # noqa: E402 +from .profiling import start_profiling # noqa: E402 from .tracing import start_tracing # noqa: E402 from .version import __version__ # noqa: E402 diff --git a/splunk_otel/distro.py b/splunk_otel/distro.py index fd3f6702..bb4444e0 100644 --- a/splunk_otel/distro.py +++ b/splunk_otel/distro.py @@ -21,19 +21,20 @@ from splunk_otel.metrics import _configure_metrics from splunk_otel.options import _Options +from splunk_otel.profiling import _start_profiling +from splunk_otel.profiling.options import _Options as ProfilingOptions from splunk_otel.tracing import _configure_tracing from splunk_otel.util import _is_truthy -otel_log_level = os.environ.get("OTEL_LOG_LEVEL", logging.INFO) - -logger = logging.getLogger(__file__) -logger.setLevel(otel_log_level) +logger = logging.getLogger(__name__) class _SplunkDistro(BaseDistro): def __init__(self): tracing_enabled = os.environ.get("OTEL_TRACE_ENABLED", True) + profiling_enabled = os.environ.get("SPLUNK_PROFILER_ENABLED", False) self._tracing_enabled = _is_truthy(tracing_enabled) + self._profiling_enabled = _is_truthy(profiling_enabled) if not self._tracing_enabled: logger.info( "tracing has been disabled with OTEL_TRACE_ENABLED=%s", tracing_enabled @@ -47,8 +48,14 @@ def __init__(self): ) def _configure(self, **kwargs: Dict[str, Any]) -> None: + options = _Options() + if self._tracing_enabled: - _configure_tracing(_Options()) + _configure_tracing(options) + + if self._profiling_enabled: + _start_profiling(ProfilingOptions(options.resource)) + if self._metrics_enabled: _configure_metrics() diff --git a/splunk_otel/metrics.py b/splunk_otel/metrics.py index 4b19b3db..c1b3eeac 100644 --- a/splunk_otel/metrics.py +++ b/splunk_otel/metrics.py @@ -23,10 +23,7 @@ from splunk_otel.util import _is_truthy -otel_log_level = os.environ.get("OTEL_LOG_LEVEL", logging.INFO) - -logger = logging.getLogger(__file__) -logger.setLevel(otel_log_level) +logger = logging.getLogger(__name__) def start_metrics() -> MeterProvider: diff --git a/splunk_otel/profiling/__init__.py b/splunk_otel/profiling/__init__.py new file mode 100644 index 00000000..37366be4 --- /dev/null +++ b/splunk_otel/profiling/__init__.py @@ -0,0 +1,432 @@ +# Copyright Splunk Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import atexit +import base64 +import gzip +import logging +import os +import sys +import threading +import time +import traceback +from collections import OrderedDict +from traceback import StackSummary +from typing import Dict, Optional, Union + +import opentelemetry.context +import wrapt +from opentelemetry._logs import SeverityNumber +from opentelemetry.context import Context +from opentelemetry.exporter.otlp.proto.grpc._log_exporter import OTLPLogExporter +from opentelemetry.instrumentation.utils import unwrap +from opentelemetry.sdk._logs import LoggerProvider, LogRecord +from opentelemetry.sdk._logs.export import BatchLogRecordProcessor +from opentelemetry.trace import TraceFlags +from opentelemetry.trace.propagation import _SPAN_KEY + +import splunk_otel +from splunk_otel.profiling import profile_pb2 +from splunk_otel.profiling.options import _Options + +logger = logging.getLogger(__name__) + + +class Profiler: + def __init__(self): + self.condition = threading.Condition(threading.Lock()) + self.thread_states = {} + + self.resource = None + self.call_stack_interval_millis = None + self.include_internal_stacks = None + self.running = False + + self.logger_provider = None + self.exporter = None + self.batch_processor = None + + def setup( + self, resource, endpoint, call_stack_interval_millis, include_internal_stacks + ): + self.resource = resource + self.call_stack_interval_millis = call_stack_interval_millis + self.include_internal_stacks = include_internal_stacks + self.running = True + + self.exporter = OTLPLogExporter(endpoint) + self.batch_processor = BatchLogRecordProcessor(self.exporter) + self.logger_provider = LoggerProvider(resource=resource) + self.logger_provider.add_log_record_processor(self.batch_processor) + + def loop(self): + ignored_thread_ids = self.get_ignored_thread_ids() + + profiling_logger = self.logger_provider.get_logger("otel.profiling", "0.1.0") + + call_stack_interval_seconds = self.call_stack_interval_millis / 1e3 + # In case the processing takes more than the given interval, default to the smallest allowed interval + min_call_stack_interval_seconds = 1 / 1e3 + while self.is_running(): + time_begin = time.time() + timestamp = int(time_begin * 1e9) + + stacktraces = _collect_stacktraces(ignored_thread_ids) + + log_record = self.to_log_record( + stacktraces, + timestamp, + ) + + profiling_logger.emit(log_record) + + processing_time = time.time() - time_begin + wait_for = max( + call_stack_interval_seconds - processing_time, + min_call_stack_interval_seconds, + ) + + with self.condition: + self.condition.wait(wait_for) + + # Hack around batch log processor getting stuck on application exits when the profiling endpoint is not reachable + # Could be related: https://github.com/open-telemetry/opentelemetry-python/issues/2284 + # pylint: disable-next=protected-access + self.exporter._shutdown = True + self.logger_provider.shutdown() + with self.condition: + self.condition.notify_all() + + def get_ignored_thread_ids(self): + if self.include_internal_stacks: + return [] + + ignored_ids = [threading.get_ident()] + + log_processor_thread_id = self.extract_log_processor_thread_id() + + if log_processor_thread_id: + ignored_ids.append(log_processor_thread_id) + + return ignored_ids + + def extract_log_processor_thread_id(self): + if hasattr(self.batch_processor, "_worker_thread"): + # pylint: disable-next=protected-access + worker_thread = self.batch_processor._worker_thread + + if isinstance(worker_thread, threading.Thread): + return worker_thread.ident + + return None + + def to_log_record(self, stacktraces, timestamp_unix_nanos): + encoded = self.get_cpu_profile(stacktraces, timestamp_unix_nanos) + encoded_profile = base64.b64encode(encoded).decode() + + return LogRecord( + timestamp=timestamp_unix_nanos, + trace_id=0, + span_id=0, + trace_flags=TraceFlags(0x01), + severity_number=SeverityNumber.UNSPECIFIED, + body=encoded_profile, + resource=self.resource, + attributes={ + "profiling.data.format": "pprof-gzip-base64", + "profiling.data.type": "cpu", + "com.splunk.sourcetype": "otel.profiling", + "profiling.data.total.frame.count": len(stacktraces), + }, + ) + + def get_cpu_profile(self, stacktraces, timestamp_unix_nanos): + str_table = StringTable() + locations_table = OrderedDict() + functions_table = OrderedDict() + + def get_function(file_name, function_name): + key = f"{file_name}:{function_name}" + fun = functions_table.get(key) + + if fun is None: + name_id = str_table.index(function_name) + fun = profile_pb2.Function() + fun.id = len(functions_table) + 1 + fun.name = name_id + fun.system_name = name_id + fun.filename = str_table.index(file_name) + functions_table[key] = fun + + return fun + + def get_line(file_name, function_name, line_no): + line = profile_pb2.Line() + line.function_id = get_function(file_name, function_name).id + line.line = line_no if line_no != 0 else -1 + return line + + def get_location(frame): + (file_name, function_name, line_no) = frame + key = f"{file_name}:{function_name}:{line_no}" + location = locations_table.get(key) + + if location is None: + location = profile_pb2.Location() + location.id = len(locations_table) + 1 + location.line.append(get_line(file_name, function_name, line_no)) + locations_table[key] = location + + return location + + timestamp_unix_millis = int(timestamp_unix_nanos / 1e6) + + timestamp_key = str_table.index("source.event.time") + trace_id_key = str_table.index("trace_id") + span_id_key = str_table.index("span_id") + thread_id_key = str_table.index("thread.id") + event_period_key = str_table.index("source.event.period") + + pb_profile = profile_pb2.Profile() + + event_period_label = profile_pb2.Label() + event_period_label.key = event_period_key + event_period_label.num = self.call_stack_interval_millis + + samples = [] + for stacktrace in stacktraces: + thread_id = stacktrace["tid"] + + timestamp_label = profile_pb2.Label() + timestamp_label.key = timestamp_key + timestamp_label.num = timestamp_unix_millis + + thread_id_label = profile_pb2.Label() + thread_id_label.key = thread_id_key + thread_id_label.num = thread_id + + labels = [timestamp_label, event_period_label, thread_id_label] + + trace_context = self.thread_states.get(thread_id) + if trace_context: + (trace_id, span_id) = trace_context + + trace_id_label = profile_pb2.Label() + trace_id_label.key = trace_id_key + trace_id_label.str = str_table.index(f"{trace_id:#016x}") + labels.append(trace_id_label) + + span_id_label = profile_pb2.Label() + span_id_label.key = span_id_key + span_id_label.str = str_table.index(f"{span_id:#08x}") + labels.append(span_id_label) + + sample = profile_pb2.Sample() + + location_ids = [] + + for frame in reversed(stacktrace["stacktrace"]): + location_ids.append(get_location(frame).id) + + sample.location_id.extend(location_ids) + sample.label.extend(labels) + + samples.append(sample) + + pb_profile.sample.extend(samples) + pb_profile.string_table.extend(str_table.keys()) + pb_profile.function.extend(list(functions_table.values())) + pb_profile.location.extend(list(locations_table.values())) + + return gzip.compress(pb_profile.SerializeToString()) + + def is_running(self): + return self.running + + def force_flush(self): + if self.logger_provider: + self.logger_provider.force_flush() + + def make_wrapped_context_attach(self): + def _wrapped_context_attach(wrapped, _instance, args, kwargs): + token = wrapped(*args, **kwargs) + + maybe_context = args[0] if args else None + + if maybe_context: + span = maybe_context.get(_SPAN_KEY) + + if span: + thread_id = threading.get_ident() + self.thread_states[thread_id] = ( + span.context.trace_id, + span.context.span_id, + ) + + return token + + return _wrapped_context_attach + + def make_wrapped_context_detach(self): + def _wrapped_context_detach(wrapped, _instance, args, kwargs): + token = args[0] if args else None + + if token: + prev = token.old_value + thread_id = threading.get_ident() + if isinstance(prev, Context): + span = prev.get(_SPAN_KEY) + + if span: + self.thread_states[thread_id] = ( + span.context.trace_id, + span.context.span_id, + ) + else: + self.thread_states[thread_id] = None + else: + self.thread_states[thread_id] = None + return wrapped(*args, **kwargs) + + return _wrapped_context_detach + + +_profiler = Profiler() + + +def get_profiler(): + return _profiler + + +class StringTable: + def __init__(self): + self.strings = OrderedDict() + + def index(self, token): + idx = self.strings.get(token) + + if idx: + return idx + + idx = len(self.strings) + self.strings[token] = idx + return idx + + def keys(self): + return list(self.strings.keys()) + + +def extract_stack(frame): + stack = StackSummary.extract( + traceback.walk_stack(frame), limit=None, lookup_lines=False + ) + stack.reverse() + return stack + + +def _collect_stacktraces(ignored_thread_ids): + stacktraces = [] + + frames = sys._current_frames() + + for thread_id, frame in frames.items(): + if thread_id in ignored_thread_ids: + continue + + stacktrace_frames = [] + stack = extract_stack(frame) + for sf in stack: + stacktrace_frames.append((sf.filename, sf.name, sf.lineno)) + stacktrace = { + "stacktrace": stacktrace_frames, + "tid": thread_id, + } + stacktraces.append(stacktrace) + return stacktraces + + +def _start_profiler_thread(profiler): + profiler_thread = threading.Thread( + name="splunk-otel-profiler", target=profiler.loop, daemon=True + ) + profiler_thread.start() + + +def _force_flush(): + profiler = get_profiler() + profiler.force_flush() + + +def _start_profiling(opts): + profiler = get_profiler() + + if profiler.is_running(): + logger.warning("profiler already running") + return + + profiler.setup( + opts.resource, + opts.endpoint, + opts.call_stack_interval_millis, + opts.include_internal_stacks, + ) + + logger.debug( + "starting profiling call_stack_interval_millis=%s endpoint=%s", + opts.call_stack_interval_millis, + opts.endpoint, + ) + wrapt.wrap_function_wrapper( + opentelemetry.context, "attach", profiler.make_wrapped_context_attach() + ) + wrapt.wrap_function_wrapper( + opentelemetry.context, "detach", profiler.make_wrapped_context_detach() + ) + + # Windows does not have register_at_fork + if hasattr(os, "register_at_fork"): + os.register_at_fork(after_in_child=lambda: _start_profiler_thread(profiler)) + atexit.register(stop_profiling) + + _start_profiler_thread(profiler) + + +def start_profiling( + service_name: Optional[str] = None, + resource_attributes: Optional[Dict[str, Union[str, bool, int, float]]] = None, + endpoint: Optional[str] = None, + call_stack_interval_millis: Optional[int] = None, +): + # pylint: disable-next=protected-access + resource = splunk_otel.options._Options._get_resource( + service_name, resource_attributes + ) + opts = _Options(resource, endpoint, call_stack_interval_millis) + _start_profiling(opts) + + +def stop_profiling(): + profiler = get_profiler() + if not profiler.is_running(): + return + + profiler.running = False + with profiler.condition: + # Wake up the profiler thread + profiler.condition.notify_all() + # Wait for the profiler thread to exit + profiler.condition.wait() + + unwrap(opentelemetry.context, "attach") + unwrap(opentelemetry.context, "detach") diff --git a/splunk_otel/profiling/options.py b/splunk_otel/profiling/options.py new file mode 100644 index 00000000..e70669dc --- /dev/null +++ b/splunk_otel/profiling/options.py @@ -0,0 +1,94 @@ +# Copyright Splunk Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from os import environ +from typing import Optional + +from opentelemetry.sdk.resources import Resource + +from splunk_otel.util import _is_truthy + +logger = logging.getLogger(__name__) +_DEFAULT_CALL_STACK_INTERVAL_MILLIS = 1_000 + + +def _sanitize_interval(interval): + if isinstance(interval, int): + if interval < 1: + logger.warning( + "call stack interval has to be positive, got %s, defaulting to %s", + interval, + _DEFAULT_CALL_STACK_INTERVAL_MILLIS, + ) + return _DEFAULT_CALL_STACK_INTERVAL_MILLIS + + return interval + + logger.warning( + "call stack interval not an integer, defaulting to %s", + _DEFAULT_CALL_STACK_INTERVAL_MILLIS, + ) + return _DEFAULT_CALL_STACK_INTERVAL_MILLIS + + +class _Options: + resource: Resource + endpoint: str + call_stack_interval_millis: int + + def __init__( + self, + resource: Resource, + endpoint: Optional[str] = None, + call_stack_interval_millis: Optional[int] = None, + include_internal_stacks: Optional[bool] = None, + ): + self.resource = resource + self.endpoint = _Options._get_endpoint(endpoint) + self.call_stack_interval_millis = _Options._get_call_stack_interval( + call_stack_interval_millis + ) + self.include_internal_stacks = _Options._include_internal_stacks( + include_internal_stacks + ) + + @staticmethod + def _get_endpoint(endpoint: Optional[str]) -> str: + if not endpoint: + endpoint = environ.get("SPLUNK_PROFILER_LOGS_ENDPOINT") + + if not endpoint: + endpoint = environ.get("OTEL_EXPORTER_OTLP_ENDPOINT") + + return endpoint or "http://localhost:4317" + + @staticmethod + def _get_call_stack_interval(interval: Optional[int]) -> int: + if not interval: + interval = environ.get("SPLUNK_PROFILER_CALL_STACK_INTERVAL") + + if interval: + return _sanitize_interval(int(interval)) + + return _DEFAULT_CALL_STACK_INTERVAL_MILLIS + + return _sanitize_interval(interval) + + @staticmethod + def _include_internal_stacks(include: Optional[bool]) -> bool: + if include is None: + return _is_truthy(environ.get("SPLUNK_PROFILER_INCLUDE_INTERNAL_STACKS")) + + return include diff --git a/splunk_otel/profiling/profile_pb2.py b/splunk_otel/profiling/profile_pb2.py new file mode 100644 index 00000000..a597ba98 --- /dev/null +++ b/splunk_otel/profiling/profile_pb2.py @@ -0,0 +1,44 @@ +# pylint: skip-file +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: profile.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\rprofile.proto\x12\x12perftools.profiles"\xd5\x03\n\x07Profile\x12\x32\n\x0bsample_type\x18\x01 \x03(\x0b\x32\x1d.perftools.profiles.ValueType\x12*\n\x06sample\x18\x02 \x03(\x0b\x32\x1a.perftools.profiles.Sample\x12,\n\x07mapping\x18\x03 \x03(\x0b\x32\x1b.perftools.profiles.Mapping\x12.\n\x08location\x18\x04 \x03(\x0b\x32\x1c.perftools.profiles.Location\x12.\n\x08\x66unction\x18\x05 \x03(\x0b\x32\x1c.perftools.profiles.Function\x12\x14\n\x0cstring_table\x18\x06 \x03(\t\x12\x13\n\x0b\x64rop_frames\x18\x07 \x01(\x03\x12\x13\n\x0bkeep_frames\x18\x08 \x01(\x03\x12\x12\n\ntime_nanos\x18\t \x01(\x03\x12\x16\n\x0e\x64uration_nanos\x18\n \x01(\x03\x12\x32\n\x0bperiod_type\x18\x0b \x01(\x0b\x32\x1d.perftools.profiles.ValueType\x12\x0e\n\x06period\x18\x0c \x01(\x03\x12\x0f\n\x07\x63omment\x18\r \x03(\x03\x12\x1b\n\x13\x64\x65\x66\x61ult_sample_type\x18\x0e \x01(\x03"\'\n\tValueType\x12\x0c\n\x04type\x18\x01 \x01(\x03\x12\x0c\n\x04unit\x18\x02 \x01(\x03"V\n\x06Sample\x12\x13\n\x0blocation_id\x18\x01 \x03(\x04\x12\r\n\x05value\x18\x02 \x03(\x03\x12(\n\x05label\x18\x03 \x03(\x0b\x32\x19.perftools.profiles.Label"@\n\x05Label\x12\x0b\n\x03key\x18\x01 \x01(\x03\x12\x0b\n\x03str\x18\x02 \x01(\x03\x12\x0b\n\x03num\x18\x03 \x01(\x03\x12\x10\n\x08num_unit\x18\x04 \x01(\x03"\xdd\x01\n\x07Mapping\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x14\n\x0cmemory_start\x18\x02 \x01(\x04\x12\x14\n\x0cmemory_limit\x18\x03 \x01(\x04\x12\x13\n\x0b\x66ile_offset\x18\x04 \x01(\x04\x12\x10\n\x08\x66ilename\x18\x05 \x01(\x03\x12\x10\n\x08\x62uild_id\x18\x06 \x01(\x03\x12\x15\n\rhas_functions\x18\x07 \x01(\x08\x12\x15\n\rhas_filenames\x18\x08 \x01(\x08\x12\x18\n\x10has_line_numbers\x18\t \x01(\x08\x12\x19\n\x11has_inline_frames\x18\n \x01(\x08"v\n\x08Location\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x12\n\nmapping_id\x18\x02 \x01(\x04\x12\x0f\n\x07\x61\x64\x64ress\x18\x03 \x01(\x04\x12&\n\x04line\x18\x04 \x03(\x0b\x32\x18.perftools.profiles.Line\x12\x11\n\tis_folded\x18\x05 \x01(\x08")\n\x04Line\x12\x13\n\x0b\x66unction_id\x18\x01 \x01(\x04\x12\x0c\n\x04line\x18\x02 \x01(\x03"_\n\x08\x46unction\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\x03\x12\x13\n\x0bsystem_name\x18\x03 \x01(\x03\x12\x10\n\x08\x66ilename\x18\x04 \x01(\x03\x12\x12\n\nstart_line\x18\x05 \x01(\x03\x42-\n\x1d\x63om.google.perftools.profilesB\x0cProfileProtob\x06proto3' +) + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "profile_pb2", _globals) +if _descriptor._USE_C_DESCRIPTORS is False: + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = ( + b"\n\035com.google.perftools.profilesB\014ProfileProto" + ) + _globals["_PROFILE"]._serialized_start = 38 + _globals["_PROFILE"]._serialized_end = 507 + _globals["_VALUETYPE"]._serialized_start = 509 + _globals["_VALUETYPE"]._serialized_end = 548 + _globals["_SAMPLE"]._serialized_start = 550 + _globals["_SAMPLE"]._serialized_end = 636 + _globals["_LABEL"]._serialized_start = 638 + _globals["_LABEL"]._serialized_end = 702 + _globals["_MAPPING"]._serialized_start = 705 + _globals["_MAPPING"]._serialized_end = 926 + _globals["_LOCATION"]._serialized_start = 928 + _globals["_LOCATION"]._serialized_end = 1046 + _globals["_LINE"]._serialized_start = 1048 + _globals["_LINE"]._serialized_end = 1089 + _globals["_FUNCTION"]._serialized_start = 1091 + _globals["_FUNCTION"]._serialized_end = 1186 +# @@protoc_insertion_point(module_scope) diff --git a/splunk_otel/tracing.py b/splunk_otel/tracing.py index 392063e3..b27bee76 100644 --- a/splunk_otel/tracing.py +++ b/splunk_otel/tracing.py @@ -29,10 +29,7 @@ from splunk_otel.options import _Options, _SpanExporterFactory from splunk_otel.util import _is_truthy -otel_log_level = os.environ.get("OTEL_LOG_LEVEL", logging.INFO) - -logger = logging.getLogger(__file__) -logger.setLevel(otel_log_level) +logger = logging.getLogger(__name__) def start_tracing( diff --git a/splunk_otel/util.py b/splunk_otel/util.py index 4ebc6ec2..b6f3c2bf 100644 --- a/splunk_otel/util.py +++ b/splunk_otel/util.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging +import os from typing import Any @@ -19,3 +21,23 @@ def _is_truthy(value: Any) -> bool: if isinstance(value, str): value = value.lower().strip() return value in [True, 1, "true", "yes"] + + +def _get_log_level(level): + levels = { + "none": logging.NOTSET, + "debug": logging.DEBUG, + "info": logging.INFO, + "warn": logging.WARNING, + "error": logging.ERROR, + "fatal": logging.CRITICAL, + } + + return levels[level.lower()] + + +def _init_logger(name): + level = _get_log_level(os.environ.get("OTEL_LOG_LEVEL", "info")) + logger = logging.getLogger(name) + logger.setLevel(level) + logger.addHandler(logging.StreamHandler()) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 122aba0b..e5916c14 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -36,7 +36,7 @@ def integration(pytestconfig, docker_ip, docker_services): timeout=10, pause=0.1, check=lambda: is_responsive(url) ) return IntegrationSession( - poll_url=f"http://{docker_ip}:8378", + poll_url=f"http://{docker_ip}:8378/spans", rootdir=path.join(str(pytestconfig.rootdir), "tests", "integration"), ) @@ -45,7 +45,7 @@ def integration(pytestconfig, docker_ip, docker_services): def integration_local(pytestconfig): """Local non-docer based replacement for `integration` fixture""" return IntegrationSession( - poll_url="http://localhost:8378", + poll_url="http://localhost:8378/spans", rootdir=path.join(str(pytestconfig.rootdir), "tests", "integration"), ) @@ -55,6 +55,6 @@ def is_responsive(url) -> bool: response = requests.get(url, timeout=10) if response.status_code == 200: return True - except ConnectionError: + except requests.exceptions.ConnectionError: pass return False diff --git a/tests/integration/docker-compose.yml b/tests/integration/docker-compose.yml index 5119c056..3290d1fc 100644 --- a/tests/integration/docker-compose.yml +++ b/tests/integration/docker-compose.yml @@ -1,7 +1,7 @@ version: '3' services: collector: - image: quay.io/signalfx/splunk-otel-collector:0.48.0 + image: quay.io/signalfx/splunk-otel-collector:0.85.0 environment: - SPLUNK_CONFIG=/etc/otel/config.yml volumes: diff --git a/tests/unit/test_profiling.py b/tests/unit/test_profiling.py new file mode 100644 index 00000000..f6c69594 --- /dev/null +++ b/tests/unit/test_profiling.py @@ -0,0 +1,228 @@ +# Copyright Splunk Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import gzip +import os +import random +import threading +import time +import unittest +from unittest import mock + +from opentelemetry import trace as trace_api +from opentelemetry.exporter.otlp.proto.grpc._log_exporter import OTLPLogExporter +from opentelemetry.sdk._logs.export import LogExportResult +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor + +from splunk_otel.profiling import ( + _force_flush, + profile_pb2, + start_profiling, + stop_profiling, +) + + +def do_work(time_ms): + now = time.time() + target = now + time_ms / 1000.0 + + total = 0.0 + while now < target: + value = random.random() + for _ in range(0, 10000): + value = value + random.random() + + total = total + value + + now = time.time() + time.sleep(0.01) + + return total + + +def find_label(sample, name, strings): + for label in sample.label: + if strings[label.key] == name: + return label + return None + + +def log_record_to_profile(log_record): + body = log_record.body + pprof_gzipped = base64.b64decode(body) + pprof = gzip.decompress(pprof_gzipped) + profile = profile_pb2.Profile() + profile.ParseFromString(pprof) + return profile + + +class TestProfiling(unittest.TestCase): + @classmethod + def setUpClass(cls): + # Tracing is set up to correlate active spans with stacktraces. + # When the profiler takes a sample while a span is active, it stores the trace context into the stacktrace. + provider = TracerProvider() + processor = SimpleSpanProcessor(ConsoleSpanExporter()) + provider.add_span_processor(processor) + trace_api.set_tracer_provider(provider) + + def setUp(self): + self.span_id = None + self.trace_id = None + self.export_patcher = mock.patch.object(OTLPLogExporter, "export") + self.export_mock = self.export_patcher.start() + self.export_mock.return_value = LogExportResult.SUCCESS + + def tearDown(self): + self.export_patcher.stop() + stop_profiling() + + def profile_capture_thread_ids(self): + start_profiling( + service_name="prof-thread-filter", + call_stack_interval_millis=10, + ) + + do_work(100) + stop_profiling() + + args = self.export_mock.call_args + (log_datas,) = args[0] + self.assertGreaterEqual(len(log_datas), 5) + + thread_ids = set() + for log_data in log_datas: + profile = log_record_to_profile(log_data.log_record) + + strings = profile.string_table + + for sample in profile.sample: + thread_ids.add(find_label(sample, "thread.id", strings).num) + + return thread_ids + + def _assert_scope(self, scope): + self.assertEqual(scope.name, "otel.profiling") + self.assertEqual(scope.version, "0.1.0") + + def _assert_log_record(self, log_record): + self.assertTrue(int(time.time() * 1e9) - log_record.timestamp <= 2e9) + + self.assertEqual(log_record.trace_id, 0) + self.assertEqual(log_record.span_id, 0) + # Attributes are of type BoundedAttributes, get the underlying dict + attributes = log_record.attributes.copy() + + self.assertEqual(attributes["profiling.data.format"], "pprof-gzip-base64") + self.assertEqual(attributes["profiling.data.type"], "cpu") + self.assertEqual(attributes["com.splunk.sourcetype"], "otel.profiling") + # We should at least have the main thread + self.assertGreaterEqual(attributes["profiling.data.total.frame.count"], 1) + + resource = log_record.resource.attributes + + self.assertEqual(resource["foo"], "bar") + self.assertEqual(resource["telemetry.sdk.language"], "python") + self.assertEqual(resource["service.name"], "prof-export-test") + + profile = log_record_to_profile(log_record) + + strings = profile.string_table + locations = profile.location + functions = profile.function + + main_thread_samples = [] + + for sample in profile.sample: + thread_id = find_label(sample, "thread.id", strings).num + period = find_label(sample, "source.event.period", strings).num + self.assertEqual(period, 100) + + # Only care for the main thread which is running our busy function + if thread_id == threading.get_ident(): + main_thread_samples.append(sample) + + for sample in main_thread_samples: + for location_id in sample.location_id: + location = locations[location_id - 1] + function = functions[location.line[0].function_id - 1] + function_name = strings[function.name] + file_name = strings[function.filename] + self.assertGreater(len(file_name), 0) + + if function_name == "do_work": + span_id = int(strings[find_label(sample, "span_id", strings).str], 16) + trace_id = int( + strings[find_label(sample, "trace_id", strings).str], 16 + ) + self.assertEqual(span_id, self.span_id) + self.assertEqual(trace_id, self.trace_id) + return True + + return False + + def test_profiling_export(self): + tracer = trace_api.get_tracer("tests.tracer") + + start_profiling( + service_name="prof-export-test", + resource_attributes={"foo": "bar"}, + call_stack_interval_millis=100, + ) + + with tracer.start_as_current_span("add-some-numbers") as span: + span_ctx = span.get_span_context() + self.span_id = span_ctx.span_id + self.trace_id = span_ctx.trace_id + do_work(550) + + _force_flush() + + self.export_mock.assert_called() + + args = self.export_mock.call_args + (log_datas,) = args[0] + + # We expect at least 5 profiling cycles during a 550ms run with call_stack_interval=100 + self.assertGreaterEqual(len(log_datas), 5) + + busy_function_profiled = False + for log_data in log_datas: + log_record = log_data.log_record + self._assert_scope(log_data.instrumentation_scope) + if self._assert_log_record(log_record): + busy_function_profiled = True + + self.assertTrue(busy_function_profiled) + + def test_include_internal_threads(self): + with_internal_stacks_thread_ids = set() + with mock.patch.dict( + os.environ, {"SPLUNK_PROFILER_INCLUDE_INTERNAL_STACKS": "true"}, clear=True + ): + with_internal_stacks_thread_ids = self.profile_capture_thread_ids() + + without_internal_stacks_thread_ids = set() + with mock.patch.dict( + os.environ, {"SPLUNK_PROFILER_INCLUDE_INTERNAL_STACKS": "false"}, clear=True + ): + without_internal_stacks_thread_ids = self.profile_capture_thread_ids() + + # Internally the profiler should use 2 threads: the sampler thread and export thread. + count_diff = len(with_internal_stacks_thread_ids) - len( + without_internal_stacks_thread_ids + ) + self.assertEqual(count_diff, 2)