# source: post >>102789965 (original reply-link kept as a comment so the file parses)
# Compute CLIP-based alignment loss (REPA-style projection loss):
# decode the latent batch back to pixel space with the VAE, CLIP-encode
# the decoded images, and penalize negative cosine similarity against
# the target CLIP features `zs_tilde`.
#
# NOTE(review): indentation was reconstructed from a flattened paste.
# As reconstructed, the whole computation runs under torch.no_grad(),
# so `proj_loss` carries no gradient into `terms['loss']` — confirm
# whether the loss lines were meant to sit outside the no_grad block.
if clip_encoder is not None and zs_tilde is not None:
    with torch.no_grad():
        # Match the VAE's parameter dtype/device so decode doesn't
        # fail on mixed-precision or offloaded setups.
        vae_dtype = next(vae.parameters()).dtype
        scaling_factor = torch.tensor(
            vae.config.scaling_factor,
            dtype=vae_dtype,
            device=x_start.device,
        )
        # Latents are stored scaled by the VAE scaling factor; undo it
        # before decoding.
        x_start_scaled = (x_start / scaling_factor).to(vae_dtype)
        vae_device = next(vae.parameters()).device
        x_start_scaled = x_start_scaled.to(vae_device)
        decoded = vae.decode(x_start_scaled).sample
        # VAE output is in [-1, 1]; map to [0, 1] and clamp for CLIP.
        decoded_images = (decoded + 1) / 2
        decoded_images = decoded_images.clamp(0, 1)
        decoded_images = decoded_images.to(torch.float32)
        # CLIP expects 224x224 input.
        decoded_images = F.interpolate(
            decoded_images,
            size=(224, 224),
            mode='bicubic',
            align_corners=False,
        )
        # clip_encoder is run on CPU here (presumably to save GPU
        # memory) — TODO confirm this is intentional; it is slow.
        decoded_images = decoded_images.cpu()
        zs = clip_encoder(decoded_images)
        zs_tilde = zs_tilde.to("cpu")
        # Cosine similarity in bfloat16: normalize both embeddings,
        # then take the negative mean dot product.
        zs_tilde = F.normalize(zs_tilde.to(torch.bfloat16), dim=-1)
        zs = F.normalize(zs.to(torch.bfloat16), dim=-1)
        proj_loss = -torch.sum(zs * zs_tilde, dim=-1).mean()
        proj_loss *= clip_loss_weight
        # FIX: was hard-coded to "cuda:0", which breaks on any other
        # device (cuda:1, cpu, mps, DDP ranks). Move the loss back to
        # wherever the rest of the batch lives instead.
        proj_loss = proj_loss.to(x_start.device)
        terms['proj_loss'] = proj_loss
        # NOTE(review): proj_loss is bfloat16 here while terms['loss']
        # is likely float32 — the += will upcast, but confirm.
        terms['loss'] += proj_loss