kemono2/src/internals/tracing/custom_rb_instrumentor.py
2024-07-04 22:08:17 +02:00

69 lines
2.2 KiB
Python

import time
from opentelemetry import trace
from opentelemetry.semconv.trace import SpanAttributes
from wrapt import wrap_function_wrapper
# Empty options dict. It is not referenced anywhere in the visible part of
# this module; presumably kept for parity with a Redis instrumentor's keyword
# arguments -- TODO confirm whether anything still reads it.
redis_instrumentor_kwargs = {}
# Module-wide tracer used by the span-creating wrappers in this file,
# requested under the name "rb" with version string "0".
redis_tracer = trace.get_tracer("rb", "0")
def _traced_execute_command(func, instance, args, kwargs):
    """Run a command call inside a CLIENT span named after the command.

    Takes the ``(func, instance, args, kwargs)`` arguments of a ``wrapt``
    wrapper; ``instance`` is unused.

    The span name is the first positional argument, followed by up to 30
    characters of the second positional argument when one is present. While
    the span is recording, two attributes are set:

    * ``SpanAttributes.DB_STATEMENT`` -- positional arguments 1 through 4,
      each cut to 40 characters and joined with spaces, with ``"..."``
      appended when there are more than five positional arguments.
    * ``"db.redis.args_length"`` -- the number of positional arguments.

    Building the span name never raises: non-string arguments are stringified
    and an empty ``args`` falls back to ``func.__name__``, so any error comes
    from the wrapped call itself rather than from the tracing code.

    Returns:
        Whatever ``func(*args, **kwargs)`` returns. Exceptions raised by
        ``func`` propagate unchanged.
    """
    # With no positional arguments, let func() raise its own error below
    # instead of failing here with an IndexError.
    name = str(args[0]) if args else func.__name__
    if len(args) > 1:
        first_arg = args[1]
        # str/bytes are sliced before formatting (bytes render as their
        # repr); anything else, e.g. an int key, is not sliceable and is
        # stringified first.
        if isinstance(first_arg, (str, bytes)):
            clipped = first_arg[:30]
        else:
            clipped = str(first_arg)[:30]
        name += f" {clipped}"

    with redis_tracer.start_as_current_span(name, kind=trace.SpanKind.CLIENT) as span:
        # Skip building the statement string when the span will be dropped.
        if span.is_recording():
            statement = " ".join(
                x[:40] if isinstance(x, str) else str(x)[:40] for x in args[1:5]
            )
            if len(args) > 5:
                statement += "..."
            span.set_attribute(SpanAttributes.DB_STATEMENT, statement)
            span.set_attribute("db.redis.args_length", len(args))
        response = func(*args, **kwargs)
    return response
def _traced_send_command(func, instance, args, kwargs):
    """Run ``func`` inside a CLIENT span named ``connection_send_command``.

    Takes the ``(func, instance, args, kwargs)`` arguments of a ``wrapt``
    wrapper; ``instance`` is unused. No attributes are set on the span.

    Returns:
        Whatever ``func(*args, **kwargs)`` returns.
    """
    span_scope = redis_tracer.start_as_current_span(
        "connection_send_command",
        kind=trace.SpanKind.CLIENT,
    )
    with span_scope:
        result = func(*args, **kwargs)
    return result
def _traced_get_connection(func, instance, args, kwargs):
    """Run ``func`` inside a CLIENT span named after ``func.__name__``.

    Takes the ``(func, instance, args, kwargs)`` arguments of a ``wrapt``
    wrapper; ``instance`` is unused. No attributes are set on the span.

    Returns:
        Whatever ``func(*args, **kwargs)`` returns.
    """
    span_name = func.__name__
    client_kind = trace.SpanKind.CLIENT
    with redis_tracer.start_as_current_span(span_name, kind=client_kind):
        outcome = func(*args, **kwargs)
    return outcome
def _traced_generic(func, instance, args, kwargs):
    """Generic wrapper: run ``func`` in a CLIENT span named ``func.__name__``.

    Takes the ``(func, instance, args, kwargs)`` arguments of a ``wrapt``
    wrapper; ``instance`` is unused. No attributes are set on the span.

    Returns:
        Whatever ``func(*args, **kwargs)`` returns.
    """
    traced_call = redis_tracer.start_as_current_span(
        func.__name__, kind=trace.SpanKind.CLIENT
    )
    with traced_call:
        value = func(*args, **kwargs)
    return value
def _traced_time_generic(func, instance, args, kwargs):
    """Call ``func`` and store its duration on the currently active span.

    Takes the ``(func, instance, args, kwargs)`` arguments of a ``wrapt``
    wrapper; ``instance`` is unused. Unlike the span-creating wrappers, this
    one opens no new span: it only annotates whichever span is current.

    When the current span is recording, the elapsed time measured with
    ``time.perf_counter()`` is stored in microseconds under the attribute
    ``"time_" + func.__name__``. The attribute is only written when ``func``
    returns normally; if it raises, the exception propagates and nothing is
    recorded.

    Returns:
        Whatever ``func(*args, **kwargs)`` returns.
    """
    current_span = trace.get_current_span()
    # get_current_span() hands back a non-recording placeholder span rather
    # than None when no span is active, so a truthiness test cannot detect
    # "nothing to annotate"; is_recording() can.
    if not current_span.is_recording():
        return func(*args, **kwargs)

    start = time.perf_counter()
    response = func(*args, **kwargs)
    # perf_counter() reports seconds; scale to microseconds.
    elapsed_us = (time.perf_counter() - start) * 1e6
    # NOTE(review): the key depends only on func.__name__, so several calls
    # under the same span presumably overwrite each other and only the last
    # duration is kept -- confirm this is intended.
    current_span.set_attribute("time_" + func.__name__, elapsed_us)
    return response
def rb_instrument():
    """Install the tracing wrappers on the ``rb`` and ``redis`` client classes.

    Registers, in order:

    * ``rb`` ``clients.RoutingClient.execute_command`` -> ``_traced_execute_command``
    * ``rb`` ``clients.RoutingClient.parse_response`` -> ``_traced_time_generic``
    * ``redis`` ``connection.Connection.send_command`` -> ``_traced_time_generic``
    * ``redis`` ``connection.ConnectionPool.get_connection`` -> ``_traced_time_generic``

    Each call to this function wraps the targets again.
    """
    patch_targets = (
        ("rb", "clients.RoutingClient.execute_command", _traced_execute_command),
        ("rb", "clients.RoutingClient.parse_response", _traced_time_generic),
        ("redis", "connection.Connection.send_command", _traced_time_generic),
        ("redis", "connection.ConnectionPool.get_connection", _traced_time_generic),
    )
    for module_name, attribute_path, wrapper in patch_targets:
        wrap_function_wrapper(module_name, attribute_path, wrapper)