Notes: Loss Functions for Self-Supervised Learning & Generative Models

Not a blog post (at least not yet), just notes written with the help of ChatGPT and published so I can find them easily.

Contrastive Loss, Triplet Loss, and Adversarial Losses are commonly used loss functions in self-supervised learning and generative models.

Here's a brief overview of each:

Contrastive Loss

Contrastive learning aims to learn useful representations by contrasting positive and negative pairs. The contrastive loss encourages similar representations for positive pairs and dissimilar representations for negative pairs. It is typically computed with a loss such as the classic margin-based contrastive loss or the InfoNCE (information noise-contrastive estimation) loss. Contrastive learning has been successful in self-supervised learning tasks like image representation learning.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ContrastiveLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, embeddings, labels):
        # Embeddings are interleaved: rows 0, 2, 4, ... pair with rows 1, 3, 5, ...,
        # so `labels` must contain one entry per pair (1 = similar, 0 = dissimilar).
        distances = F.pairwise_distance(embeddings[0::2], embeddings[1::2])

        # Similar pairs: penalize any distance (pull them together)
        positive_loss = torch.pow(distances, 2)

        # Dissimilar pairs: penalize only if they are closer than the margin (push them apart)
        max_margin = torch.clamp(self.margin - distances, min=0)
        negative_loss = torch.pow(max_margin, 2)

        loss = torch.mean(labels * positive_loss + (1 - labels) * negative_loss)
        return loss

# Usage example
embeddings = torch.randn(10, 128)  # 10 example embeddings -> 5 interleaved pairs
labels = torch.tensor([1., 1., 1., 0., 0.])  # One label per pair (1 = similar pair, 0 = dissimilar pair)

criterion = ContrastiveLoss(margin=1.0)
loss = criterion(embeddings, labels)
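
The module above implements the classic margin-based contrastive loss. The InfoNCE loss mentioned earlier works differently: it treats each positive pair as the correct class in a softmax over the batch. Here is a minimal sketch, assuming each query's positive key sits at the same batch index and all other keys in the batch act as negatives; the function name info_nce_loss and the temperature value are illustrative choices, not from the original notes.

import torch
import torch.nn.functional as F

def info_nce_loss(queries, keys, temperature=0.07):
    # Normalize so the dot product is a cosine similarity
    queries = F.normalize(queries, dim=1)
    keys = F.normalize(keys, dim=1)

    # logits[i, j] = similarity between query i and key j
    logits = queries @ keys.t() / temperature

    # The positive for query i is key i, so the target class index is i
    targets = torch.arange(queries.size(0), device=queries.device)
    return F.cross_entropy(logits, targets)

# Usage example (hypothetical batch of paired views)
queries = torch.randn(10, 128)
keys = torch.randn(10, 128)
loss = info_nce_loss(queries, keys)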

Triplet Loss

Triplet loss is used in metric learning tasks where the goal is to learn embeddings such that similar samples are closer together in the embedding space. It involves selecting triplets of anchor, positive, and negative samples and optimizing the embedding space such that the distance between the anchor and positive sample is smaller than the distance between the anchor and negative sample by a certain margin. Triplet loss is commonly used in tasks like face recognition, person re-identification, and image retrieval.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TripletLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        # Distances from the anchor to the positive and negative samples
        distance_positive = F.pairwise_distance(anchor, positive)
        distance_negative = F.pairwise_distance(anchor, negative)

        # Hinge loss: penalize triplets where the positive is not at least
        # `margin` closer to the anchor than the negative
        loss = torch.mean(torch.clamp(distance_positive - distance_negative + self.margin, min=0))
        return loss

# Usage example
anchor = torch.randn(10, 128)  # Example anchor embeddings
positive = torch.randn(10, 128)  # Example positive embeddings
negative = torch.randn(10, 128)  # Example negative embeddings

criterion = TripletLoss(margin=1.0)
loss = criterion(anchor, positive, negative)
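
For reference, PyTorch ships a built-in nn.TripletMarginLoss that computes the same hinge-style objective, so the custom module above is mainly useful when you want to customize the distance function or the triplet mining strategy:

# Equivalent built-in loss (Euclidean distance, margin of 1.0)
criterion = nn.TripletMarginLoss(margin=1.0, p=2)
loss = criterion(anchor, positive, negative)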

Adversarial Losses

Adversarial losses are used in generative models, such as Generative Adversarial Networks (GANs), to train a generator network to produce realistic samples that can fool a discriminator network. The generator tries to generate samples that are indistinguishable from real samples, while the discriminator aims to correctly classify between real and generated samples. Adversarial losses, such as the original GAN loss or Wasserstein GAN loss, provide a way to optimize the generator and discriminator simultaneously.

import torch
import torch.nn as nn
import torch.optim as optim

# Minimal fully connected architectures, filled in here for illustration
latent_dim = 100   # size of the noise vector
image_dim = 784    # e.g. flattened 28x28 images
num_epochs = 5

class Generator(nn.Module):
    def __init__(self, latent_dim, image_dim):
        super(Generator, self).__init__()
        # Generator network architecture (a simple MLP, for illustration only)
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, image_dim),
            nn.Tanh(),
        )

    def forward(self, z):
        # Generator forward pass: map noise to a fake sample
        return self.model(z)

class Discriminator(nn.Module):
    def __init__(self, image_dim):
        super(Discriminator, self).__init__()
        # Discriminator network architecture (a simple MLP, for illustration only)
        self.model = nn.Sequential(
            nn.Linear(image_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Discriminator forward pass: probability that the input is real
        return self.model(x)

# Instantiate generator and discriminator
generator = Generator(latent_dim, image_dim)
discriminator = Discriminator(image_dim)

# Loss function
adversarial_loss = nn.BCELoss()  # Binary Cross Entropy Loss for the original GAN objective

# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=0.001)
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.001)

# Training loop (assumes `data_loader` yields batches of flattened real images)
for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(data_loader):
        # Training the discriminator
        optimizer_D.zero_grad()

        # Real images should be classified as real (label 1)
        real_labels = torch.ones(real_images.size(0), 1)
        real_output = discriminator(real_images)
        real_loss = adversarial_loss(real_output, real_labels)

        # Generated images should be classified as fake (label 0)
        z = torch.randn(real_images.size(0), latent_dim)
        fake_images = generator(z)
        fake_labels = torch.zeros(real_images.size(0), 1)
        fake_output = discriminator(fake_images.detach())  # detach so only D is updated here
        fake_loss = adversarial_loss(fake_output, fake_labels)

        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizer_D.step()

        # Training the generator: try to make the discriminator label fakes as real
        optimizer_G.zero_grad()
        gen_output = discriminator(fake_images)
        g_loss = adversarial_loss(gen_output, real_labels)
        g_loss.backward()
        optimizer_G.step()
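
The loop above uses the original GAN objective with binary cross entropy. For the Wasserstein GAN loss mentioned earlier, the discriminator becomes a critic that outputs an unbounded score (no sigmoid on the last layer) and the losses reduce to differences of mean scores. Below is a minimal sketch, assuming the critic's Lipschitz constraint is enforced separately (e.g. via weight clipping or a gradient penalty); the function names are illustrative, not a standard API.

import torch

def wgan_critic_loss(real_scores, fake_scores):
    # The critic maximizes the score gap between real and fake samples,
    # so we minimize the negative gap
    return fake_scores.mean() - real_scores.mean()

def wgan_generator_loss(fake_scores):
    # The generator tries to maximize the critic's score on fake samples
    return -fake_scores.mean()

# Usage example (hypothetical critic outputs, one score per sample)
real_scores = torch.randn(64, 1)
fake_scores = torch.randn(64, 1)
critic_loss = wgan_critic_loss(real_scores, fake_scores)
gen_loss = wgan_generator_loss(fake_scores)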

These loss functions can be incorporated into your self-supervised learning model or generative model based on the specific task and architecture you are working with. They require careful implementation and tuning to achieve good results. It's recommended to refer to research papers and existing implementations for detailed guidance and best practices specific to each loss function.