Coverage for src / puzzletree / utils / progress_bar.py: 98.67%
59 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-12 20:35 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-12 20:35 +0000
1"""Progress bar utilities for training visualization."""
3from __future__ import annotations
5import time
6from typing import TYPE_CHECKING, Self
8from rich.progress import BarColumn, Progress, SpinnerColumn, TaskID, TextColumn, TimeElapsedColumn, TimeRemainingColumn
10if TYPE_CHECKING:
11 from types import TracebackType
13 from rich.console import Console
class ProgressBar:
    """A flexible training progress bar with throttled refreshes.

    Supports both:
    - **Time-Based Updates:** Refreshes every `target_update_interval` seconds.
    - **Step-Based Updates:** Refreshes every `update_interval` steps.

    Any other `update_mode` value means `update()` never refreshes the display.

    Relies on Rich's explicit `.refresh()` and a transient display to keep I/O overhead low.
    """

    def __init__(
        self,
        use_progress_bar: bool = True,  # noqa: FBT001, FBT002 - Boolean positional args acceptable for class initialization
        update_mode: str = "time",
        target_update_interval: float = 1.0,
    ) -> None:
        """Initializes the ProgressBar instance.

        Args:
            use_progress_bar (bool): Enables or disables the progress bar.
            update_mode (str): Mode of progress bar updates: 'time' or 'step'.
            target_update_interval (float): Time interval (in seconds) for adaptive updates.
        """
        self.use_progress_bar = use_progress_bar
        self.progress: Progress | None = None
        self.task: TaskID | None = None
        self.update_mode = update_mode
        self.target_update_interval = target_update_interval
        self.last_update_time: float = time.perf_counter()
        self.update_interval: int = 1  # Initial step size

    def __enter__(self) -> Self:
        """Allows `with ProgressBar(...)` usage."""
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        """Ensures `.stop()` is automatically called in `with` blocks.

        Returns None, so an exception raised inside the block is not suppressed.
        """
        self.stop()

    def start(self, total_batches: int, epoch: int) -> None:
        """Starts the progress bar for the given epoch.

        Does nothing when the progress bar is disabled. If a display from an
        earlier `start()` is still held, it is stopped before the new one is created.

        Args:
            total_batches (int): Total number of batches in the dataset.
            epoch (int): Current epoch number.
        """
        if not self.use_progress_bar:
            return

        # Overwriting `self.progress` without stopping it would leave the old
        # display running with nothing left that could stop it.
        self.stop()

        self.progress = Progress(
            TextColumn("[bold blue]Epoch {task.fields[epoch]}[/]"),
            BarColumn(),
            TextColumn("• Loss: [red]{task.fields[loss]:.4f}[/]"),
            TextColumn("• Acc: [green]{task.fields[acc]:.2f}%[/]"),
            TimeRemainingColumn(),
            transient=True,
        )
        self.task = self.progress.add_task("Training", total=total_batches, loss=0.0, acc=0.0, epoch=epoch)
        self.progress.start()

    def update(self, current_batch: int, loss: float, acc: float, epoch: int, batch_time: float) -> None:
        """Update the progress bar with current metrics.

        In 'time' mode, `update_interval` is re-estimated from `batch_time` on every
        call; the display itself is refreshed only once `target_update_interval`
        seconds have passed since the last refresh. In 'step' mode the display is
        refreshed whenever `current_batch` is a multiple of `update_interval`.

        Args:
            current_batch (int): Current batch number.
            loss (float): Average loss value for the current epoch.
            acc (float): Accuracy value for the current epoch.
            epoch (int): Current epoch number.
            batch_time (float): Time (in seconds) taken to process the current batch.
        """
        now = time.perf_counter()

        # Adaptive interval adjustment.
        if self.update_mode == "time":
            try:
                self.update_interval = max(1, int(self.target_update_interval / batch_time))
            except (ZeroDivisionError, OverflowError, ValueError):
                # A zero, NaN or vanishingly small batch time gives no usable
                # estimate; keep the previous interval rather than crash.
                pass

        if not self.use_progress_bar:
            return

        if self.update_mode == "step":
            refresh_due = current_batch % self.update_interval == 0
        elif self.update_mode == "time":
            refresh_due = now - self.last_update_time >= self.target_update_interval
        else:
            refresh_due = False

        if not refresh_due:
            return

        if self.progress is not None and self.task is not None:
            self.progress.update(self.task, completed=current_batch, loss=loss, acc=acc, epoch=epoch)
            self.progress.refresh()
        self.last_update_time = now

    def stop(self) -> None:
        """Stops the progress bar and finalizes its output.

        Does nothing when the bar is disabled or was never started.
        """
        if self.use_progress_bar and self.progress is not None:
            self.progress.stop()
class StageProgressBar:
    """Simple stage-based progress bar for CLI workflows.

    The bar advances by one unit per call to `advance()`, and every method is a
    no-op when `use_progress_bar` is False.
    """

    def __init__(
        self,
        console: Console | None = None,
        use_progress_bar: bool = True,  # noqa: FBT001, FBT002 - Boolean positional args acceptable for utility class
        transient: bool = True,  # noqa: FBT001, FBT002 - Boolean positional args acceptable for utility class
    ) -> None:
        """Initialize the stage-based progress helper.

        Args:
            console (Console | None): Console handed to Rich's `Progress`; None leaves the choice to Rich.
            use_progress_bar (bool): Enables or disables the progress bar.
            transient (bool): Forwarded to Rich's `Progress` as its `transient` option.
        """
        self.console = console
        self.use_progress_bar = use_progress_bar
        self.transient = transient
        self.progress: Progress | None = None
        self.task: TaskID | None = None

    def __enter__(self) -> Self:
        """Return the progress bar for context-manager usage."""
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        """Stop the progress bar when leaving a context manager.

        Returns None, so an exception raised inside the block is not suppressed.
        """
        self.stop()

    def start(self, total_steps: int, description: str = "Starting") -> None:
        """Start a progress bar that advances once per workflow stage.

        If a display from an earlier `start()` is still held, it is stopped
        before the new one is created.

        Args:
            total_steps (int): Number of stages the bar counts up to.
            description (str): Text shown until the first `advance()` call.
        """
        if not self.use_progress_bar:
            return

        # Overwriting `self.progress` without stopping it would leave the old
        # display running with nothing left that could stop it.
        self.stop()

        self.progress = Progress(
            SpinnerColumn(),
            TextColumn("[bold blue]{task.description}[/]"),
            BarColumn(),
            TextColumn("{task.completed}/{task.total}"),
            TimeElapsedColumn(),
            console=self.console,
            transient=self.transient,
        )
        self.task = self.progress.add_task(description, total=total_steps)
        self.progress.start()

    def advance(self, description: str) -> None:
        """Advance the stage counter and update the current description.

        Does nothing when the bar is disabled or has not been started.

        Args:
            description (str): Label for the stage that is now running.
        """
        if not self.use_progress_bar or self.progress is None or self.task is None:
            return

        self.progress.update(self.task, advance=1, description=description)
        self.progress.refresh()

    def stop(self) -> None:
        """Stop the progress bar.

        Does nothing when the bar is disabled or was never started.
        """
        if self.use_progress_bar and self.progress is not None:
            self.progress.stop()