Coverage for src / puzzletree / utils / progress_bar.py: 98.67%

59 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-12 20:35 +0000

1"""Progress bar utilities for training visualization.""" 

2 

3from __future__ import annotations 

4 

5import time 

6from typing import TYPE_CHECKING, Self 

7 

8from rich.progress import BarColumn, Progress, SpinnerColumn, TaskID, TextColumn, TimeElapsedColumn, TimeRemainingColumn 

9 

10if TYPE_CHECKING: 

11 from types import TracebackType 

12 

13 from rich.console import Console 

14 

15 

16class ProgressBar: 

17 """A flexible progress bar with dynamic control for efficient. 

18 

19 Supports both: 

20 - **Time-Based Updates:** Refreshes every `target_update_interval` seconds. 

21 - **Step-Based Updates:** Refreshes every `update_interval` steps. 

22 

23 Optimized for performance using Rich's `.refresh()` and `.transient` for minimal I/O overhead. 

24 """ 

25 

26 def __init__( 

27 self, 

28 use_progress_bar: bool = True, # noqa: FBT001, FBT002 - Boolean positional args acceptable for class initialization 

29 update_mode: str = "time", 

30 target_update_interval: float = 1.0, 

31 ) -> None: 

32 """Initializes the ProgressBar instance. 

33 

34 Args: 

35 use_progress_bar (bool): Enables or disables the progress bar. 

36 update_mode (str): Mode of progress bar updates: 'time' or 'step'. 

37 target_update_interval (float): Time interval (in seconds) for adaptive updates. 

38 """ 

39 self.use_progress_bar = use_progress_bar 

40 self.progress: Progress | None = None 

41 self.task: TaskID | None = None 

42 self.update_mode = update_mode 

43 self.target_update_interval = target_update_interval 

44 self.last_update_time: float = time.perf_counter() 

45 self.update_interval: int = 1 # Initial step size 

46 

47 def __enter__(self) -> Self: 

48 """Allows `with ProgressBar(...)` usage.""" 

49 return self 

50 

51 def __exit__( 

52 self, 

53 exc_type: type[BaseException] | None, 

54 exc_value: BaseException | None, 

55 traceback: object | None, # noqa: PYI036 - object | None is acceptable for traceback parameter 

56 ) -> None: 

57 """Ensures `.stop()` is automatically called in `with` blocks.""" 

58 self.stop() 

59 

60 def start(self, total_batches: int, epoch: int) -> None: 

61 """Starts the progress bar for the given epoch. 

62 

63 Args: 

64 total_batches (int): Total number of batches in the dataset. 

65 epoch (int): Current epoch number. 

66 """ 

67 if self.use_progress_bar: 

68 self.progress = Progress( 

69 TextColumn("[bold blue]Epoch {task.fields[epoch]}[/]"), 

70 BarColumn(), 

71 TextColumn("• Loss: [red]{task.fields[loss]:.4f}[/]"), 

72 TextColumn("• Acc: [green]{task.fields[acc]:.2f}%[/]"), 

73 TimeRemainingColumn(), 

74 transient=True, 

75 ) 

76 self.task = self.progress.add_task("Training", total=total_batches, loss=0.0, acc=0.0, epoch=epoch) 

77 self.progress.start() 

78 

79 def update(self, current_batch: int, loss: float, acc: float, epoch: int, batch_time: float) -> None: 

80 """Update the progress bar with current metrics. 

81 

82 Dynamically adjusts the update interval based on batch time (for time-based updates). 

83 

84 Args: 

85 current_batch (int): Current batch number. 

86 loss (float): Average loss value for the current epoch. 

87 acc (float): Accuracy value for the current epoch. 

88 epoch (int): Current epoch number. 

89 batch_time (float): Time (in seconds) taken to process the current batch. 

90 """ 

91 now = time.perf_counter() 

92 

93 # 🔥 Adaptive Interval Adjustment 

94 if self.update_mode == "time": 

95 self.update_interval = max(1, int(self.target_update_interval / batch_time)) 

96 

97 # 🔥 Adaptive Progress Bar Logic 

98 if self.use_progress_bar and ( 

99 (self.update_mode == "step" and current_batch % self.update_interval == 0) 

100 or (self.update_mode == "time" and now - self.last_update_time >= self.target_update_interval) 

101 ): 

102 if self.progress is not None and self.task is not None: 102 ↛ 105line 102 didn't jump to line 105 because the condition on line 102 was always true

103 self.progress.update(self.task, completed=current_batch, loss=loss, acc=acc, epoch=epoch) 

104 self.progress.refresh() 

105 self.last_update_time = now 

106 

107 def stop(self) -> None: 

108 """Stops the progress bar and finalizes its output.""" 

109 if self.use_progress_bar and self.progress is not None: 

110 self.progress.stop() 

111 

112 

class StageProgressBar:
    """Simple stage-based progress bar for CLI workflows.

    ``start`` sets the total number of stages; each ``advance`` call moves the
    bar forward by one and replaces the description shown next to the spinner.
    When ``use_progress_bar`` is False every method is a no-op.
    """

    def __init__(
        self,
        console: Console | None = None,
        use_progress_bar: bool = True,  # noqa: FBT001, FBT002 - Boolean positional args acceptable for utility class
        transient: bool = True,  # noqa: FBT001, FBT002 - Boolean positional args acceptable for utility class
    ) -> None:
        """Initialize the stage-based progress helper.

        Args:
            console: Rich console passed straight through to ``Progress``;
                ``None`` lets Rich use its default console.
            use_progress_bar: Enables or disables the progress bar.
            transient: Passed to ``Progress``; when True the bar is removed from
                the terminal once stopped.
        """
        self.console = console
        self.use_progress_bar = use_progress_bar
        self.transient = transient
        self.progress: Progress | None = None
        self.task: TaskID | None = None

    def __enter__(self) -> Self:
        """Return the progress bar for context-manager usage."""
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        """Stop the progress bar when leaving a context manager."""
        self.stop()

    def start(self, total_steps: int, description: str = "Starting") -> None:
        """Start a progress bar that advances once per workflow stage.

        Calling ``start`` again first stops any bar that is still running.

        Args:
            total_steps: Number of stages the workflow will advance through.
            description: Initial text shown next to the spinner.
        """
        if not self.use_progress_bar:
            return

        # Stop a bar left over from a previous start() so its live display is
        # not orphaned (and, on Rich versions that allow only one live display,
        # so restarting does not raise LiveError).
        self.stop()

        self.progress = Progress(
            SpinnerColumn(),
            TextColumn("[bold blue]{task.description}[/]"),
            BarColumn(),
            TextColumn("{task.completed}/{task.total}"),
            TimeElapsedColumn(),
            console=self.console,
            transient=self.transient,
        )
        self.task = self.progress.add_task(description, total=total_steps)
        self.progress.start()

    def advance(self, description: str) -> None:
        """Advance the stage counter by one and update the current description.

        Does nothing when the bar is disabled or not currently started.
        """
        if not self.use_progress_bar or self.progress is None or self.task is None:
            return

        self.progress.update(self.task, advance=1, description=description)
        self.progress.refresh()

    def stop(self) -> None:
        """Stop the progress bar and drop the references to it.

        Safe to call multiple times; later ``advance`` calls become no-ops until
        ``start`` is called again.
        """
        if self.progress is not None:
            self.progress.stop()
        self.progress = None
        self.task = None