fovi.utils.losses
- class fovi.utils.losses.SimCLRLoss(batch_size, world_size, gpu, temperature, pairs_per_sample=1)[source]
Bases: Module
SimCLR (Simple Framework for Contrastive Learning of Visual Representations) loss.
This loss function is used for self-supervised learning of visual representations. It encourages the model to learn similar representations for different augmented views of the same image, while pushing apart representations of different images.
The loss is computed using a contrastive learning approach:
1. For each image in a batch, two augmented views are created.
2. These views are passed through an encoder network to obtain embeddings.
3. The similarity between positive pairs (two views of the same image) is maximized.
4. The similarity between negative pairs (views from different images) is minimized.
The loss uses a temperature-scaled softmax function to compute the probability of identifying the correct positive sample among the negative samples.
Key components:
- Cosine similarity is used as the similarity metric between embeddings.
- A mask identifies and excludes self-comparisons from the negative samples.
- The loss is computed as the cross-entropy between the similarity scores and the true labels.
Optimization notes:
- Batch size 2048: LARS optimizer, base learning rate 0.5, weight decay 1e-6, temperature 0.15.
- Batch size 256: LARS optimizer, base learning rate 1.0, weight decay 1e-6, temperature 0.15.
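The temperature-scaled contrastive objective described above can be sketched in NumPy. This is a hypothetical `nt_xent` helper illustrating the math for the single-process, one-pair-per-sample case, not the fovi implementation:

```python
import numpy as np

def nt_xent(z_i, z_j, temperature=0.15):
    """Temperature-scaled cross-entropy over cosine similarities (NT-Xent sketch)."""
    z = np.concatenate([z_i, z_j], axis=0)             # (2B, D)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)   # unit-normalize -> dot = cosine sim
    sim = z @ z.T / temperature                        # (2B, 2B) similarity matrix
    n = z.shape[0]
    b = n // 2
    # positives: row i is paired with row (i + B) mod 2B
    pos = sim[np.arange(n), (np.arange(n) + b) % n]
    # exclude self-comparisons from the softmax denominator
    sim[np.arange(n), np.arange(n)] = -np.inf
    log_denom = np.log(np.exp(sim).sum(axis=1))
    # cross-entropy of identifying the positive among all other samples
    return float(np.mean(log_denom - pos))
```

Aligned views (identical embeddings for both augmentations) produce a lower loss than mismatched views, which is the behavior the contrastive objective rewards.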
- __init__(batch_size, world_size, gpu, temperature, pairs_per_sample=1)[source]
Initialize the SimCLRLoss module.
- Parameters:
batch_size (int) – The number of samples in each batch.
world_size (int) – The number of distributed processes.
gpu (torch.device) – The GPU device to use.
temperature (float) – A scaling factor for the cosine similarity.
pairs_per_sample (int, optional) – Number of augmented pairs per sample. Defaults to 1.
- mask
A boolean mask to exclude self-comparisons.
- Type:
torch.Tensor
- criterion
The loss function.
- Type:
nn.CrossEntropyLoss
- similarity_f
The similarity function.
- Type:
nn.CosineSimilarity
Create a mask to identify and exclude self-comparisons from negative samples.
This method generates a boolean mask that is used to exclude self-comparisons and comparisons between augmented views of the same image when computing the SimCLR loss.
- Returns:
- A boolean mask of shape (N, N), where N = 2 * pairs_per_sample * batch_size * world_size.
True values indicate valid comparisons, while False values indicate self-comparisons or comparisons between augmented views of the same image.
- Return type:
torch.Tensor
Note
The resulting mask is structured such that it can be used directly in the SimCLR loss computation to select valid negative samples.
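The masking logic can be sketched as follows. This hypothetical `make_mask` helper assumes the single-pair layout in which row i's positive sits at index (i + N/2) mod N; it is an illustration of the idea, not the fovi code:

```python
import numpy as np

def make_mask(batch_size, world_size=1, pairs_per_sample=1):
    """Boolean mask over the (N, N) similarity matrix; True marks a valid negative."""
    n = 2 * pairs_per_sample * batch_size * world_size
    b = n // 2
    mask = np.ones((n, n), dtype=bool)
    idx = np.arange(n)
    mask[idx, idx] = False            # exclude self-comparisons
    mask[idx, (idx + b) % n] = False  # exclude the positive pair (same image, other view)
    return mask
```

Each row ends up with N - 2 True entries: everything except itself and its own positive.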
- compute_logits(z_i, z_j)[source]
Compute the logits for the SimCLR loss.
This method calculates the similarity matrix between the two sets of feature vectors (z_i and z_j) and prepares the logits for the contrastive loss computation.
- Parameters:
z_i (torch.Tensor) – The first set of feature vectors.
z_j (torch.Tensor) – The second set of feature vectors.
- Returns:
- A tensor of shape (N, N+1) containing the logits for each sample.
The first column contains the similarity with the positive sample, and the remaining columns contain similarities with negative samples.
- Return type:
torch.Tensor
Note
This method handles distributed training by gathering tensors across processes when world_size > 1.
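The logits arrangement can be sketched with a hypothetical `arrange_logits` helper: the positive similarity goes in column 0 and the masked negatives follow. In this minimal single-pair, single-process case the result is (N, N-1) wide (one positive column plus N-2 negatives); the exact width in fovi depends on pairs_per_sample:

```python
import numpy as np

def arrange_logits(sim, mask):
    """Place each row's positive similarity in column 0, masked negatives after it."""
    n = sim.shape[0]
    b = n // 2
    # positive of row i sits at column (i + N/2) mod N in this layout
    pos = sim[np.arange(n), (np.arange(n) + b) % n].reshape(n, 1)
    # boolean indexing keeps row order, so each row contributes N-2 negatives
    neg = sim[mask].reshape(n, -1)
    return np.concatenate([pos, neg], axis=1)
```

The first column then lines up with label 0 for every row, which is what makes a plain cross-entropy over these logits equivalent to the contrastive objective.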
- forward(z_i, z_j, logits=None)[source]
Compute the SimCLR loss for the given feature representations.
- Parameters:
z_i (torch.Tensor) – The first set of feature representations.
z_j (torch.Tensor) – The second set of feature representations.
logits (torch.Tensor, optional) – Pre-computed logits. If None, they will be computed using the compute_logits method.
- Returns:
- A tuple containing:
  - num_sim (float): The numerator term of the loss, representing the similarity between positive pairs.
  - num_entropy (float): The denominator term of the loss, representing the entropy of the similarity distribution.
- Return type:
tuple
Note
This implementation treats all augmented examples within a minibatch, except for the positive pair, as negative examples. This approach is similar to that described in (Chen et al., 2017).
- class fovi.utils.losses.VicRegLoss(sim_coeff, std_coeff, cov_coeff)[source]
Bases: Module
Implements the VICReg (Variance-Invariance-Covariance Regularization) loss.
VICReg is a self-supervised learning method that learns representations by enforcing invariance, variance, and covariance constraints on the embeddings.
Note
Recommended hyperparameters:
- Batch size 2048: LARS optimizer, base learning rate 0.5, weight decay 1e-4, sim_coeff and std_coeff 25, cov_coeff 1.
- Batch size 256: LARS optimizer, base learning rate 1.5, weight decay 1e-4, sim_coeff and std_coeff 25, cov_coeff 1.
- __init__(sim_coeff, std_coeff, cov_coeff)[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(z_i, z_j, return_only_loss=True)[source]
Compute the VICReg loss.
- Parameters:
z_i (torch.Tensor) – First set of embeddings.
z_j (torch.Tensor) – Second set of embeddings.
return_only_loss (bool) – If True, return only the total loss. If False, return individual loss components.
- Returns:
The total VICReg loss if return_only_loss is True; otherwise a tuple (total_loss, repr_loss, std_loss, cov_loss).
- Return type:
torch.Tensor if return_only_loss is True, otherwise tuple
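The three VICReg terms can be sketched in NumPy. This hypothetical `vicreg` helper illustrates the standard formulation (MSE invariance, a hinge keeping each dimension's standard deviation above 1, and a penalty on off-diagonal covariance entries); it is a sketch of the math, not the fovi implementation, and the 1e-4 variance epsilon is an assumed constant:

```python
import numpy as np

def vicreg(z_i, z_j, sim_coeff=25.0, std_coeff=25.0, cov_coeff=1.0):
    """Invariance + variance + covariance terms of the VICReg loss (sketch)."""
    n, d = z_i.shape
    # invariance: mean-squared error between the two views
    repr_loss = np.mean((z_i - z_j) ** 2)
    # variance: hinge keeping each embedding dimension's std above 1
    def std_term(z):
        std = np.sqrt(z.var(axis=0) + 1e-4)
        return np.mean(np.maximum(0.0, 1.0 - std))
    std_loss = 0.5 * (std_term(z_i) + std_term(z_j))
    # covariance: penalize off-diagonal entries of each view's covariance matrix
    def cov_term(z):
        zc = z - z.mean(axis=0)
        cov = (zc.T @ zc) / (n - 1)
        off_diag = cov - np.diag(np.diag(cov))
        return np.sum(off_diag ** 2) / d
    cov_loss = cov_term(z_i) + cov_term(z_j)
    return sim_coeff * repr_loss + std_coeff * std_loss + cov_coeff * cov_loss
```

Because the variance and covariance terms are shift-invariant (they work on centered embeddings), translating one view only changes the invariance term.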
- class fovi.utils.losses.BarlowTwinsLoss(bn, batch_size, world_size, lambd)[source]
Bases: Module
Implements the Barlow Twins loss for self-supervised learning.
Barlow Twins aims to learn representations by maximizing the similarity between distorted versions of a sample while reducing the redundancy between the components of the representation vector.
- bn
Batch normalization layer for the embeddings.
- Type:
nn.BatchNorm1d
- __init__(bn, batch_size, world_size, lambd)[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(z1, z2)[source]
Compute the Barlow Twins loss.
- Parameters:
z1 (torch.Tensor) – First set of embeddings.
z2 (torch.Tensor) – Second set of embeddings.
- Returns:
The computed Barlow Twins loss.
- Return type:
torch.Tensor
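The Barlow Twins objective can be sketched in NumPy: standardize each embedding dimension across the batch (the role played by the bn layer), build the cross-correlation matrix between the two views, and push it toward the identity. This hypothetical `barlow_twins` helper illustrates the math, not the fovi implementation:

```python
import numpy as np

def barlow_twins(z1, z2, lambd=5e-3):
    """Push the cross-correlation matrix of the two views toward the identity (sketch)."""
    n, d = z1.shape
    # standardize each dimension across the batch (the role of the BatchNorm1d layer)
    z1 = (z1 - z1.mean(0)) / (z1.std(0) + 1e-6)
    z2 = (z2 - z2.mean(0)) / (z2.std(0) + 1e-6)
    c = (z1.T @ z2) / n                                 # (D, D) cross-correlation
    on_diag = np.sum((np.diag(c) - 1.0) ** 2)           # invariance: diagonal -> 1
    off_diag = np.sum((c - np.diag(np.diag(c))) ** 2)   # redundancy reduction: off-diagonal -> 0
    return on_diag + lambd * off_diag
```

Identical views give a near-identity correlation matrix and a small loss; independent views leave the diagonal near zero, so the invariance term dominates.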