>>103339435
For those curious, here is the new swg_pred_calc (pag_utils.py):
def _swg_window_starts(extent: int, window: int, steps: int) -> list[int]:
    """Return ``steps`` start offsets spreading a ``window`` over ``extent``."""
    if steps == 1:
        # A single window has no stride to compute; anchor it at the origin.
        return [0]
    span = extent - window
    # Multiply before dividing so the final start is exactly ``span`` and the
    # last window always ends on the far edge, even when ``span`` is not a
    # multiple of ``steps - 1``.
    return [(i * span) // (steps - 1) for i in range(steps)]


def swg_pred_calc(
    x: Tensor,
    crop_count: int,
    crop_width: int,
    crop_height: int,
    calc_func: Callable[..., tuple[Tensor]],
) -> Tensor:
    """Average ``calc_func`` predictions over a square grid of sliding crops.

    ``x`` is unpacked as a 4-D tensor whose last two dimensions are indexed as
    (height, width). A grid of ``isqrt(crop_count)`` windows per axis is laid
    over those two dimensions; each window is passed to ``calc_func`` as the
    keyword argument ``x_in`` and the first element of its return value is
    accumulated. Every pixel of the result is the mean of the predictions of
    all windows covering it.

    Args:
        x: Input tensor of shape ``(b, c, h, w)``.
        crop_count: Total number of crops requested; only the integer square
            root is used, so non-square counts are rounded down to a square.
        crop_width: Window width; clamped to the width of ``x``.
        crop_height: Window height; clamped to the height of ``x``.
        calc_func: Callable invoked as ``calc_func(x_in=window)``. Its first
            return element must have the same shape as the window it received.

    Returns:
        A tensor shaped like ``x`` holding the per-pixel average prediction.
        Pixels not covered by any window (possible only when the windows are
        too small to tile the image) are left at zero instead of NaN.

    Raises:
        ValueError: If ``crop_count`` is smaller than 1.
    """
    steps_per_dim = math.isqrt(int(crop_count))
    if steps_per_dim < 1:
        raise ValueError(f"crop_count must be >= 1, got {crop_count}")

    _, _, h, w = x.shape
    # A crop larger than the image would produce negative strides and
    # wrapped/empty slices; the largest meaningful crop is the image itself.
    crop_width = min(crop_width, w)
    crop_height = min(crop_height, h)

    lefts = _swg_window_starts(w, crop_width, steps_per_dim)
    tops = _swg_window_starts(h, crop_height, steps_per_dim)

    swg_pred = torch.zeros_like(x)
    overlap = torch.zeros_like(x)
    for left in lefts:
        right = left + crop_width
        for top in tops:
            bottom = top + crop_height
            x_window = x[:, :, top:bottom, left:right]
            swg_pred_window = calc_func(x_in=x_window)[0]
            swg_pred[:, :, top:bottom, left:right] += swg_pred_window
            # Count how many windows contributed to each pixel.
            overlap[:, :, top:bottom, left:right] += 1

    # Clamp so uncovered pixels divide by 1 (staying 0) rather than by 0 (NaN).
    return swg_pred / overlap.clamp(min=1)