Verified Commit 86cfd19f authored by Sebastian Endres's avatar Sebastian Endres
Browse files

Add offset plot, add --debug, move helpers to utils, refactoring

parent cc5f3e17
......@@ -23,6 +23,7 @@ class Side(Enum):
class PlotModeAll(Enum):
OFFSET_NUMBER = "offset-number"
PACKET_NUMBER = "packet-number"
FILE_SIZE = "file-size"
ALL = "all"
......@@ -31,6 +32,7 @@ class PlotModeAll(Enum):
"""Get PlotModes (resolve 'all')."""
return {
PlotModeAll.OFFSET_NUMBER: [PlotMode.OFFSET_NUMBER],
PlotModeAll.PACKET_NUMBER: [PlotMode.PACKET_NUMBER],
PlotModeAll.FILE_SIZE: [PlotMode.FILE_SIZE],
PlotModeAll.ALL: [mode for mode in PlotMode],
......@@ -96,9 +98,14 @@ def parse_args():
action="store",
choices=PlotModeAll,
type=PlotModeAll,
default=PlotModeAll.PACKET_NUMBER,
default=PlotModeAll.OFFSET_NUMBER,
help="The mode of plotting (time vs. packet-number or time vs. file-size or both)",
)
parser.add_argument(
"--debug",
action="store_true",
help="Debug mode.",
)
return parser.parse_args()
......@@ -115,7 +122,8 @@ class PlotAllCli:
force=False,
only_sat=False,
include_failed=False,
mode: PlotModeAll = PlotModeAll.PACKET_NUMBER,
mode: PlotModeAll = PlotModeAll.OFFSET_NUMBER,
debug=False,
):
# self.log_dirs = log_dirs
self.result_files = result_files
......@@ -128,6 +136,7 @@ class PlotAllCli:
self.include_failed = include_failed
self.modes = mode.get_plot_modes()
self._current_log_dir: Optional[Path] = None
self.debug = debug
def plot_in_testcase_dir(
self,
......@@ -209,6 +218,7 @@ class PlotAllCli:
annotate=self.annotate,
mode=mode,
cache=True,
debug=self.debug,
)
try:
cli.run()
......@@ -308,6 +318,7 @@ def main():
force=args.force,
include_failed=args.include_failed,
mode=args.mode,
debug=args.debug,
)
try:
cli.run()
......
......@@ -3,7 +3,6 @@
"""Plot time packet-number plots and more."""
# TODO
# - lost packets?
# - plot packet sizes
import argparse
......@@ -11,14 +10,28 @@ import sys
import typing
from enum import Enum
from pathlib import Path
from typing import Callable, Literal, Optional, TypeVar, Union
from termcolor import cprint
from typing import Optional, Union
import numpy as np
from matplotlib import pyplot as plt
from tracer import ParsingError, Trace, get_quic_payload_size
from utils import YaspinWrapper, existing_file_path
from termcolor import colored, cprint
from tracer import (
ParsingError,
Trace,
get_quic_payload_size,
get_stream_offset,
iter_stream_frames,
)
from utils import (
Subplot,
YaspinWrapper,
create_relpath,
existing_file_path,
format_file_size,
map2d,
map3d,
)
if typing.TYPE_CHECKING:
from collections.abc import Iterable
......@@ -27,71 +40,17 @@ if typing.TYPE_CHECKING:
class PlotMode(Enum):
OFFSET_NUMBER = "offset-number"
PACKET_NUMBER = "packet-number"
FILE_SIZE = "file-size"
DEFAULT_TITLES = {
PlotMode.OFFSET_NUMBER: "Time vs. Offset-Number",
PlotMode.PACKET_NUMBER: "Time vs. Packet-Number",
PlotMode.FILE_SIZE: "Time vs. Transmitted File Size",
}
T = TypeVar("T")
def map2d(func: Callable[["Iterable[T]"], T], arrays: "Iterable[Iterable[T]]") -> T:
"""Map func to arrays and to each entry of arrays."""
return func(map(func, arrays))
def map3d(
func: Callable[["Iterable[T]"], T],
arrays: "Iterable[Iterable[Iterable[T]]]",
) -> T:
def inner_func(arr):
return map2d(func, arr)
return func(map(inner_func, arrays))
def format_file_size(val: Union[int, float]) -> str:
"""Format bytes."""
if val < 0:
return f"-{format_file_size(-val)}"
val = int(val)
UNITS = ["B", "kB", "MB", "GB", "TB"]
for exp, unit in reversed(list(enumerate(UNITS))):
fac = 2 ** (exp * 10)
cur, val = divmod(val, fac)
if cur:
digits = int(val / fac * 10)
if unit == "B":
return f"{val:.1} B"
else:
return f"{cur}.{digits} {UNITS[exp]}"
return f"{val} B"
class Subplot:
fig: plt.Figure
ax: plt.Axes
def __init__(self, *args, **kwargs):
self.fig, self.ax = plt.subplots(*args, **kwargs)
def __enter__(self):
return self.fig, self.ax
def __exit__(self, *args, **kwargs):
plt.close(fig=self.fig)
def parse_args():
"""Parse command line args."""
......@@ -139,7 +98,7 @@ def parse_args():
action="store",
choices=PlotMode,
type=PlotMode,
default=PlotMode.PACKET_NUMBER,
default=PlotMode.OFFSET_NUMBER,
help="The mode of plotting (time vs. packet-number or time vs. file-size",
)
parser.add_argument(
......@@ -175,10 +134,16 @@ class PlotCli:
keylog_files: list[Optional[Path]] = [],
output_file: Optional[Path] = None,
annotate=True,
mode: PlotMode = PlotMode.PACKET_NUMBER,
mode: PlotMode = PlotMode.OFFSET_NUMBER,
cache=False,
debug=False,
):
self.title = title or DEFAULT_TITLES[mode]
self.output_file = output_file
self.annotate = annotate
self.mode = mode
self.debug = debug
if not keylog_files:
keylog_files = [None] * len(pcap_files)
......@@ -189,14 +154,10 @@ class PlotCli:
keylog_file=keylog_file,
display_filter=self.display_filter,
cache=cache,
debug=self.debug,
)
for pcap_file, keylog_file in zip(pcap_files, keylog_files)
]
self.title = title or DEFAULT_TITLES[mode]
self.output_file = output_file
self.annotate = annotate
self.mode = mode
self.debug = debug
def vline_annotate(
self,
......@@ -252,81 +213,176 @@ class PlotCli:
label_side=label_side,
)
def plot_packet_number(self):
"""Plot the packet number diagram."""
def plot_offset_number(self):
"""Plot the offset number diagram."""
with Subplot(nrows=1, ncols=1) as (_fig, ax):
ax.grid(True)
ax.set_xlabel("Time (s)")
ax.set_ylabel("Packet Number")
ax.set_ylabel("Offset")
ax.set_title(self.title)
ax.yaxis.set_major_formatter(lambda val, _pos: format_file_size(val))
# avoid lazy result parsing:
with YaspinWrapper(
debug=self.debug, text="processing...", color="cyan"
) as spinner:
request_offsets = list[list[int]]()
response_first_offsets = list[list[int]]()
response_retrans_offsets = list[list[int]]()
request_timestamps = list[list[float]]()
response_first_timestamps = list[list[float]]()
response_retrans_timestamps = list[list[float]]()
for trace in self.traces:
request_offsets.append(list[int]())
response_first_offsets.append(list[int]())
response_retrans_offsets.append(list[int]())
request_timestamps.append(list[float]())
response_first_timestamps.append(list[float]())
response_retrans_timestamps.append(list[float]())
for layer in trace.request_stream_frames:
offset = get_stream_offset(layer)
if offset is not None:
request_offsets[-1].append(offset)
request_timestamps[-1].append(layer.sniff_timestamp)
for layer in trace.response_stream_frames_first_tx:
offset = get_stream_offset(layer)
if offset is not None:
response_first_offsets[-1].append(offset)
response_first_timestamps[-1].append(layer.sniff_timestamp)
for layer in trace.response_stream_frames_retrans:
offset = get_stream_offset(layer)
if offset is not None:
response_retrans_offsets[-1].append(offset)
response_retrans_timestamps[-1].append(layer.sniff_timestamp)
all_offsets = (
*request_offsets,
*response_first_offsets,
*(lst for lst in response_retrans_offsets if lst),
)
min_offset: int = map2d(min, all_offsets)
max_offset: int = map2d(max, all_offsets)
all_timestamps = (
*request_timestamps,
*response_first_timestamps,
*(lst for lst in response_retrans_timestamps if lst),
)
min_timestamp: float = map2d(min, all_timestamps)
max_timestamp: float = map2d(max, all_timestamps)
for trace in self.traces:
assert (
trace.packets
), "Trace {trace} contains no filtered packets! Are the secrets injected?"
ax.set_xlim(left=min(0, min_timestamp), right=max_timestamp)
ax.set_ylim(bottom=min(0, min_offset), top=max_offset)
if self.annotate:
# raise errors early
trace.parse()
spinner.ok("✔")
with YaspinWrapper(
debug=self.debug, text="plotting...", color="cyan"
) as spinner:
# plot shadow traces (request and response separated)
for trace_timestamps, trace_offsets in zip(
request_timestamps[1:], request_offsets[1:]
):
ax.plot(
trace_timestamps,
trace_offsets,
marker="o",
linestyle="",
color="#CCC",
)
for (
trace_first_timestamps,
trace_first_offsets,
trace_retrans_timestamps,
trace_retrans_offsets,
) in zip(
response_first_timestamps[1:],
response_first_offsets[1:],
response_retrans_timestamps[1:],
response_retrans_offsets[1:],
):
ax.plot(
(*trace_first_timestamps, *trace_retrans_timestamps),
(*trace_first_offsets, *trace_retrans_offsets),
marker="o",
linestyle="",
color="#CCC",
)
# plot main trace (request and response separated)
ax.plot(
request_timestamps[0],
request_offsets[0],
marker="o",
linestyle="",
color="#73d216",
)
ax.plot(
response_first_timestamps[0],
response_first_offsets[0],
marker="o",
linestyle="",
color="#3465A4",
)
ax.plot(
response_retrans_timestamps[0],
response_retrans_offsets[0],
marker="o",
linestyle="",
color="#b60000",
)
self._annotate_time_plot(ax, max_offset / 2)
spinner.ok("✔")
self._save()
def plot_packet_number(self):
"""Plot the packet number diagram."""
with Subplot(nrows=1, ncols=1) as (_fig, ax):
ax.grid(True)
ax.set_xlabel("Time (s)")
ax.set_ylabel("Packet Number")
ax.set_title(self.title)
with YaspinWrapper(
debug=self.debug, text="processing...", color="cyan"
) as spinner:
request_timestamps = [
np.array(
[
float(packet.sniff_timestamp)
for packet in trace.request_packets
]
)
[layer.sniff_timestamp for layer in trace.request_stream_frames]
for trace in self.traces
]
response_timestamps = [
np.array(
[
float(packet.sniff_timestamp)
for packet in trace.response_packets
]
)
[layer.sniff_timestamp for layer in trace.response_stream_frames]
for trace in self.traces
]
request_packet_numbers = [
np.array(
[
int(packet.quic.packet_number)
for packet in trace.request_packets
]
)
[int(layer.packet_number) for layer in trace.request_stream_frames]
for trace in self.traces
]
response_packet_numbers = [
np.array(
[
int(packet.quic.packet_number)
for packet in trace.response_packets
]
)
[int(layer.packet_number) for layer in trace.response_stream_frames]
for trace in self.traces
]
min_packet_number: int = map3d(
min, [request_packet_numbers, response_packet_numbers]
)
max_packet_number: int = map3d(
max, [request_packet_numbers, response_packet_numbers]
)
min_timestamp: np.float64 = map3d(
min, [request_timestamps, response_timestamps]
)
max_timestamp: np.float64 = map3d(
max, [request_timestamps, response_timestamps]
)
all_packet_numbers = [request_packet_numbers, response_packet_numbers]
min_packet_number: int = map3d(min, all_packet_numbers)
max_packet_number: int = map3d(max, all_packet_numbers)
all_timestamps = [request_timestamps, response_timestamps]
min_timestamp: float = map3d(min, all_timestamps)
max_timestamp: float = map3d(max, all_timestamps)
ax.set_xlim(left=min(0, min_timestamp), right=max_timestamp)
ax.set_ylim(bottom=min(0, min_packet_number), top=max_packet_number)
spinner.ok("")
spinner.ok("")
with YaspinWrapper(
debug=self.debug, text="plotting...", color="cyan"
......@@ -334,7 +390,7 @@ class PlotCli:
# plot shadow traces (request and response separated)
for trace_timestamps, trace_packet_numbers in zip(
request_timestamps, request_packet_numbers
request_timestamps[1:], request_packet_numbers[1:]
):
ax.plot(
trace_timestamps,
......@@ -345,7 +401,7 @@ class PlotCli:
)
for trace_timestamps, trace_packet_numbers in zip(
response_timestamps, response_packet_numbers
response_timestamps[1:], response_packet_numbers[1:]
):
ax.plot(
trace_timestamps,
......@@ -373,7 +429,7 @@ class PlotCli:
)
self._annotate_time_plot(ax, max_packet_number / 2)
spinner.ok("")
spinner.ok("")
self._save()
def plot_file_size(self):
......@@ -385,17 +441,6 @@ class PlotCli:
ax.set_title(self.title)
ax.yaxis.set_major_formatter(lambda val, _pos: format_file_size(val))
# avoid lazy result parsing:
for trace in self.traces:
assert (
trace.packets
), "Trace {trace} contains no filtered packets! Are the secrets injected?"
if self.annotate:
# raise errors early
trace.parse()
with YaspinWrapper(
debug=self.debug, text="processing...", color="cyan"
) as spinner:
......@@ -430,7 +475,7 @@ class PlotCli:
ax.set_ylim(bottom=min(0, min_file_size), top=max_file_size)
ax.set_yticks(np.arange(0, max_file_size * 1.1, 1024 * 1024))
spinner.ok("")
spinner.ok("")
with YaspinWrapper(
debug=self.debug, text="plotting...", color="cyan"
......@@ -459,7 +504,7 @@ class PlotCli:
)
self._annotate_time_plot(ax, max_file_size / 2)
spinner.ok("")
spinner.ok("")
self._save()
def _save(self):
......@@ -467,12 +512,16 @@ class PlotCli:
if self.output_file:
plt.savefig(self.output_file, dpi=300, transparent=True)
cprint(f"{self.output_file} written.", color="green")
cprint(f"{create_relpath(self.output_file)} written.", color="green")
else:
plt.show()
def run(self):
mapping = {
PlotMode.OFFSET_NUMBER: {
"callback": self.plot_offset_number,
"desc": "time vs. offset number",
},
PlotMode.PACKET_NUMBER: {
"callback": self.plot_packet_number,
"desc": "time vs. packet number",
......@@ -483,6 +532,17 @@ class PlotCli:
},
}
# avoid lazy result parsing:
for trace in self.traces:
assert (
trace.packets
), "Trace {trace} contains no filtered packets! Are the secrets injected?"
if self.annotate:
# raise errors early
trace.parse()
cfg = mapping[self.mode]
callback = cfg["callback"]
desc = cfg["desc"]
......@@ -513,7 +573,7 @@ def main():
try:
cli.run()
except ParsingError as err:
sys.exit(err)
sys.exit(colored(err, color="red"))
if __name__ == "__main__":
......
......@@ -18,6 +18,9 @@ if typing.TYPE_CHECKING:
from pyshark.packet.packet import Packet
QuicStreamLayer = Any
class ParsingError(Exception):
"""Exception that will be thrown when we can't parse the trace."""
......@@ -57,7 +60,7 @@ def get_quic_payload_size(packet: "Packet") -> int:
def follow_stream(stream_frames: list[Any]) -> bytes:
buf = list[int]()
buf = list[Optional[int]]()
for frame in stream_frames:
offset = get_stream_offset(frame)
......@@ -65,11 +68,17 @@ def follow_stream(stream_frames: list[Any]) -> bytes:
extend_buf = offset - len(buf)
if extend_buf > 0:
buf += [0] * extend_buf
buf += [None] * extend_buf
buf[offset:] = frame.stream_data.binary_value
return bytes(buf)
if not all(byte is not None for byte in buf):
cprint(
"Warning! Did not receive all bytes in follow_stream.",
color="yellow",
)
return bytes([byte or 0 for byte in buf])
def get_frame_prop_from_all_frames(
......@@ -141,9 +150,7 @@ def get_stream_fin_packet_number(packets: list["Packet"], trace: "Trace") -> lis
for packet in packets:
for quic_layer in iter_stream_frames(packet):
layer_offset = get_stream_offset(quic_layer)
if layer_offset is None:
breakpoint()
assert layer_offset is not None
if layer_offset is not None and layer_offset > max_offset:
max_offset = layer_offset
......@@ -239,6 +246,8 @@ class Trace:
self._facts = dict[str, Any]()
self._request_packets = list["Packet"]()
self._response_packets = list["Packet"]()
self._response_stream_frames_first_tx = list[QuicStreamLayer]()
self._response_stream_frames_retrans = list[QuicStreamLayer]()
def __str__(self):
trace_file_name = Path(self._cap.input_filename).name
......@@ -274,7 +283,7 @@ class Trace:
color="green",
) as spinner:
cached_packets = pickle.load(cache_file)
spinner.ok("")
spinner.ok("")
return cached_packets
......@@ -298,7 +307,7 @@ class Trace:
color="green",
) as spinner:
pickle.dump(obj=packets, file=cache_file)
spinner.ok("")
spinner.ok("")
return packets
......@@ -323,6 +332,77 @@ class Trace:
return self._facts
@cached_property
def request_stream_frames(self) -> list[QuicStreamLayer]:
stream_layers = list[QuicStreamLayer]()
for packet in self.request_packets:
layers_in_packet = list(iter_stream_frames(packet))