Helpers

influpaint.utils.helpers

extract(a, t, x_shape)

Gathers the values of a at the timestep indices t (one index per batch element) and reshapes the result to (batch_size, 1, ..., 1) so that it broadcasts against a tensor of shape x_shape.

Source code in influpaint/utils/helpers.py
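def extract(a, t, x_shape):
    """
    Gather the values of `a` at the timestep indices `t` (one index per batch
    element) and reshape them to broadcast against a tensor of shape `x_shape`.
    """
    batch_size = t.shape[0]
    # Index on the CPU (this assumes `a` lives there); `t` may be on another device.
    out = a.gather(-1, t.cpu())
    # Reshape to (batch_size, 1, ..., 1) and move the result to the device of `t`.
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)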
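A minimal usage sketch may make the reshaping behaviour easier to see. The schedule `betas`, the batch `x`, and the index tensor `t` below are hypothetical values chosen for illustration and are not taken from the influpaint code base. The function body is repeated so the example runs on its own.

```python
import torch


def extract(a, t, x_shape):
    # Same logic as the helper documented above, copied for a standalone example.
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)


# Hypothetical 1-D schedule with one coefficient per timestep (assumption).
timesteps = 10
betas = torch.linspace(1e-4, 0.02, timesteps)

# Hypothetical batch of 4 single-channel 8x8 arrays (assumption).
x = torch.zeros(4, 1, 8, 8)

# One timestep index per batch element; gather requires int64 indices.
t = torch.tensor([0, 3, 9, 3])

# coeff has one leading batch axis and singleton axes for the rest of x.
coeff = extract(betas, t, x.shape)

# Broadcasting applies each sample's coefficient to every element of that sample.
scaled = coeff * torch.ones_like(x)
```

Here `coeff` has shape `(4, 1, 1, 1)`, so multiplying it with a tensor shaped like `x` scales each batch element by the coefficient for its own timestep.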