class DDPM:
    """Denoising diffusion model: noise schedule, training loop, sampler and checkpoints.

    The wrapped ``model`` is a noise predictor invoked as ``model(x, t)``.
    All schedule tensors (``betas``, ``alphas_cumprod``, ...) are precomputed
    once in ``__init__`` and indexed per sample through ``myutils.extract``.
    """

    def __init__(
        self,
        model,
        image_size=64,
        channels=1,
        batch_size=512,
        epochs=500,
        timesteps=200,
        beta_schedule="linear",
        loss_type="huber",
        device=None,
    ) -> None:
        """Store the configuration and precompute the diffusion schedule.

        Args:
            model: Noise-prediction network, called as ``model(x, t)``.
            image_size: Side length of the square images produced by ``sample``.
            channels: Number of image channels produced by ``sample``.
            batch_size: Default number of images generated by ``sample``.
            epochs: Number of passes over the dataloader in ``train``.
            timesteps: Number of diffusion steps.
            beta_schedule: One of ``"linear"``, ``"cosine"``, ``"quadratic"``
                or ``"sigmoid"``.
            loss_type: Loss name forwarded to ``p_losses`` during training.
            device: Training device; when ``None``, ``"cuda"`` is used if
                available, otherwise ``"cpu"``.

        Raises:
            NotImplementedError: If ``beta_schedule`` is not a known name.

        Side effects:
            Creates the ``./results`` directory if it does not exist.
        """
        self.model = model
        self.image_size = image_size
        self.channels = channels
        self.batch_size = batch_size
        self.epochs = epochs
        self.timesteps = timesteps
        self.loss_type = loss_type
        self.device = device
        if self.device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Beta schedule. The if/elif chain is kept (rather than a dispatch
        # dict) so that only the selected schedule function is looked up.
        if beta_schedule == "linear":
            self.betas = linear_beta_schedule(timesteps=self.timesteps)
        elif beta_schedule == "cosine":
            self.betas = cosine_beta_schedule(timesteps=self.timesteps)
        elif beta_schedule == "quadratic":
            self.betas = quadratic_beta_schedule(timesteps=self.timesteps)
        elif beta_schedule == "sigmoid":
            self.betas = sigmoid_beta_schedule(timesteps=self.timesteps)
        else:
            raise NotImplementedError(f"Beta schedule {beta_schedule} not implemented")

        # Alphas and their cumulative products.
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Cumulative product shifted right by one step, with 1.0 at index 0.
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)

        # Coefficients for the forward diffusion q(x_t | x_0).
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)

        # Variance of the posterior q(x_{t-1} | x_t, x_0).
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )

        self.results_folder = Path("./results")
        self.results_folder.mkdir(exist_ok=True)
        self.save_and_sample_every = 1000
        self.optimizer = Adam(self.model.parameters(), lr=1e-3)

    @staticmethod
    def _is_cuda(device):
        """Return True when ``device`` (a string or ``torch.device``) is a CUDA device."""
        # Comparing a torch.device with the string "cuda" is never equal, and
        # a string such as "cuda:0" differs from "cuda"; compare the type.
        return torch.device(device).type == "cuda"

    def _unwrapped_model(self):
        """Return the underlying network, looking through ``nn.DataParallel``."""
        if isinstance(self.model, nn.DataParallel):
            return self.model.module
        return self.model

    @staticmethod
    def _strip_data_parallel_prefix(state_dict):
        """Drop a leading ``module.`` from every key when all keys carry it."""
        prefix = "module."
        if state_dict and all(key.startswith(prefix) for key in state_dict):
            return {key[len(prefix):]: value for key, value in state_dict.items()}
        return state_dict

    def q_sample(self, x_start, t, noise=None):
        """Forward diffusion: return a noised version of ``x_start`` at timestep ``t``.

        Args:
            x_start: Clean input batch.
            t: Timestep indices passed to ``myutils.extract``
                (presumably one index per batch element -- confirm against
                ``myutils.extract``).
            noise: Noise to mix in; drawn with ``torch.randn_like(x_start)``
                when ``None``.

        Returns:
            ``sqrt(alphas_cumprod_t) * x_start + sqrt(1 - alphas_cumprod_t) * noise``.
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        sqrt_alphas_cumprod_t = myutils.extract(
            self.sqrt_alphas_cumprod, t, x_start.shape
        )
        sqrt_one_minus_alphas_cumprod_t = myutils.extract(
            self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
        )
        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    @torch.no_grad()
    def p_sample(self, x, t, t_index):
        """Run one reverse-diffusion step on ``x``.

        Args:
            x: Current noisy batch.
            t: Timestep tensor forwarded to the model and to ``myutils.extract``.
            t_index: Integer timestep; when it is 0 no noise is added.

        Returns:
            The model mean when ``t_index == 0``, otherwise the mean plus
            ``sqrt(posterior_variance_t)`` times fresh Gaussian noise.
        """
        betas_t = myutils.extract(self.betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = myutils.extract(
            self.sqrt_one_minus_alphas_cumprod, t, x.shape
        )
        sqrt_recip_alphas_t = myutils.extract(self.sqrt_recip_alphas, t, x.shape)

        # Equation 11 in the paper:
        # use the model (noise predictor) to predict the mean.
        model_mean = sqrt_recip_alphas_t * (
            x - betas_t * self.model(x, t) / sqrt_one_minus_alphas_cumprod_t
        )
        if t_index == 0:
            return model_mean

        posterior_variance_t = myutils.extract(self.posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        # Algorithm 2 line 4.
        return model_mean + torch.sqrt(posterior_variance_t) * noise

    @torch.no_grad()
    def p_sample_loop(self, shape):
        """Algorithm 2, keeping every intermediate image.

        Starts from Gaussian noise of the given ``shape`` on the model's
        device and applies ``p_sample`` for ``timesteps - 1`` down to 0.

        Args:
            shape: Shape of the generated batch; ``shape[0]`` is the batch size.

        Returns:
            A list with one NumPy array per timestep, in sampling order
            (the last entry is the final result).
        """
        device = next(self.model.parameters()).device
        on_cuda = self._is_cuda(device)
        if on_cuda:
            print(myutils.cuda_mem_info())

        b = shape[0]
        # Start from pure noise (for each example in the batch).
        img = torch.randn(shape, device=device)
        imgs = []
        for i in tqdm(
            reversed(range(0, self.timesteps)),
            desc="sampling loop time step",
            total=self.timesteps,
        ):
            img = self.p_sample(
                img, torch.full((b,), i, device=device, dtype=torch.long), i
            )
            imgs.append(img.cpu().numpy())

        if on_cuda:
            print(myutils.cuda_mem_info())
        return imgs

    @torch.no_grad()
    def sample(self, batch_size=None):
        """Generate images with ``p_sample_loop``.

        Args:
            batch_size: Number of images to generate; defaults to
                ``self.batch_size`` when ``None``.

        Returns:
            The list returned by ``p_sample_loop`` for a batch of shape
            ``(batch_size, channels, image_size, image_size)``.
        """
        if batch_size is None:
            batch_size = self.batch_size
        return self.p_sample_loop(
            shape=(batch_size, self.channels, self.image_size, self.image_size),
        )

    def _show_loss_curves(self, losses):
        """Plot the full loss history plus the last 100 and last 50 values, then show it."""
        fig, axes = plt.subplots(1, 3, figsize=(6, 2), dpi=100)
        for axis, window in zip(axes.flat, (losses, losses[-100:], losses[-50:])):
            axis.plot(np.arange(len(window)), np.array(window))
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    def _save_sample_grid(self, milestone, batch_size):
        """Sample a few images and write them to ``results_folder/sample-<milestone>.png``."""
        group_sizes = myutils.num_to_groups(4, batch_size)
        # ``sample`` returns one NumPy array per timestep; keep only the final one.
        final_images = [
            torch.from_numpy(self.sample(batch_size=n)[-1]) for n in group_sizes
        ]
        all_images = torch.cat(final_images, dim=0)
        # Shift/scale before saving (presumably maps [-1, 1] data to [0, 1] -- confirm
        # against the dataset normalisation).
        all_images = (all_images + 1) * 0.5
        save_image(
            all_images,
            str(self.results_folder / f"sample-{milestone}.png"),
            nrow=6,
        )

    def train(self, dataloader, mlflow_logging=False):
        """Train the model for ``self.epochs`` passes over ``dataloader``.

        Each batch must be a tensor (it is moved with ``batch.to(device)`` and
        its first dimension is used as the batch size). ``self.batch_size`` is
        left untouched so that ``sample()`` keeps its configured size.

        Args:
            dataloader: Iterable yielding batches; it is iterated once per epoch.
            mlflow_logging: When true, log ``step_loss`` every 10 steps and
                ``epoch_loss`` every epoch to MLflow, and skip the interactive
                loss plots.

        Returns:
            The list of per-step loss values (floats) over the whole run.
        """
        print(f"/!\\ training on {self.device}")
        if mlflow_logging:
            import mlflow

        # Wrap only once, so repeated train() calls do not nest DataParallel.
        if torch.cuda.device_count() > 1 and not isinstance(self.model, nn.DataParallel):
            print(" -- using dataparallel")
            self.model = nn.DataParallel(self.model)
        self.model.to(self.device)

        on_cuda = self._is_cuda(self.device)
        if on_cuda:
            print(myutils.cuda_mem_info())

        # NOTE: no learning-rate scheduler is stepped; the LR stays constant.
        losses = []
        step_count = 0
        for epoch in range(self.epochs):
            epoch_losses = []
            for step, batch in enumerate(dataloader):
                self.optimizer.zero_grad()

                batch_size = batch.shape[0]
                batch = batch.to(self.device)

                # Algorithm 1 line 3: sample t uniformly for every example in
                # the batch. The number of epochs must be large enough for all
                # timesteps to be seen.
                t = torch.randint(
                    0, self.timesteps, (batch_size,), device=self.device
                ).long()
                loss = self.p_losses(
                    denoise_model=self.model,
                    x_start=batch,
                    t=t,
                    loss_type=self.loss_type,
                )

                loss_value = loss.item()
                losses.append(loss_value)
                epoch_losses.append(loss_value)

                if mlflow_logging and step_count % 10 == 0:
                    mlflow.log_metric("step_loss", loss_value, step=step_count)
                if step % 100 == 0:
                    print(f"Epoch: {epoch:<4} -- Step: {step:<4} -- Loss: {loss_value:.6f}")

                loss.backward()
                self.optimizer.step()
                step_count += 1

                # Interactive progress plot at the first step of every 50th epoch.
                # Figure logging to MLflow is not enabled, so nothing is drawn
                # when mlflow_logging is on.
                if epoch % 50 == 0 and epoch > 0 and step == 0 and not mlflow_logging:
                    self._show_loss_curves(losses)

                # Save generated images. ``step`` restarts every epoch, so a
                # given milestone file is overwritten by later epochs.
                if step != 0 and step % self.save_and_sample_every == 0:
                    self._save_sample_grid(step // self.save_and_sample_every, batch_size)

            if not epoch_losses:
                # Empty dataloader: nothing to average.
                print(f"Epoch {epoch} completed - no batches")
                continue

            epoch_avg_loss = sum(epoch_losses) / len(epoch_losses)
            print(f"Epoch {epoch} completed - Avg Loss: {epoch_avg_loss:.6f}")
            if mlflow_logging:
                mlflow.log_metric("epoch_loss", epoch_avg_loss, step=epoch)

        if on_cuda:
            print(myutils.cuda_mem_info())
        return losses

    def write_train_checkpoint(self, save_path=None):
        """Save model/optimizer state, ``epochs`` and ``loss_type`` with ``torch.save``.

        The model state is taken from the unwrapped network, so checkpoints
        have the same keys whether or not ``nn.DataParallel`` is in use.

        Args:
            save_path: Destination file; defaults to
                ``checkpoint-<epochs>.pth`` in the current directory.

        Returns:
            The path the checkpoint was written to.
        """
        if save_path is None:
            save_path = f"checkpoint-{self.epochs}.pth"
        torch.save(
            {
                "epochs": self.epochs,
                "model_state_dict": self._unwrapped_model().state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "loss_type": self.loss_type,
            },
            save_path,
        )
        return save_path

    def load_model_checkpoint(self, checkpoint_path):
        """Restore the state written by ``write_train_checkpoint``.

        Loads model and optimizer state, ``epochs`` and ``loss_type``, moves
        the model to ``self.device`` and leaves it in training mode.
        Checkpoints whose keys all carry a ``module.`` prefix (saved from a
        ``nn.DataParallel`` wrapper) are accepted as well.

        Args:
            checkpoint_path: File previously produced by ``torch.save``.
        """
        checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))

        target = self._unwrapped_model()
        model_state = checkpoint["model_state_dict"]
        if set(model_state) != set(target.state_dict()):
            model_state = self._strip_data_parallel_prefix(model_state)
        target.load_state_dict(model_state)

        # Move the model first: the optimizer casts its restored state to the
        # device its parameters are on at load time.
        self.model.to(self.device)
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

        self.epochs = checkpoint["epochs"]
        self.loss_type = checkpoint["loss_type"]
        self.model.train()

    def p_losses(self, denoise_model, x_start, t, noise=None, loss_type="l1"):
        """Compute the noise-prediction loss for one batch.

        ``x_start`` is noised with ``q_sample`` and ``denoise_model`` is asked
        to predict the noise that was added.

        Args:
            denoise_model: Callable invoked as ``denoise_model(x_noisy, t)``.
            x_start: Clean input batch.
            t: Timestep tensor forwarded to ``q_sample`` and the model.
            noise: Noise to add; drawn with ``torch.randn_like`` when ``None``.
            loss_type: ``"l1"``, ``"l2"`` or ``"huber"`` (smooth L1).

        Returns:
            The loss between the true and the predicted noise.

        Raises:
            NotImplementedError: If ``loss_type`` is not one of the names above.
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        # Forward diffusion of the dataset image.
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        predicted_noise = denoise_model(x_noisy, t)

        if loss_type == "l1":
            return F.l1_loss(noise, predicted_noise)
        if loss_type == "l2":
            return F.mse_loss(noise, predicted_noise)
        if loss_type == "huber":
            return F.smooth_l1_loss(noise, predicted_noise)
        raise NotImplementedError(f"Loss type {loss_type} not implemented")