Self-Supervised Learning
Self-Supervised Learning 이란 레이블이 없는 데이터셋을 바탕으로 모델이 데이터로부터 좋은 representation을 만드는 학습 방법을 의미합니다. 데이터로부터 추출한 representation vector는 데이터의 특징을 잘 담고 있어야 합니다. 예를 들어, Self-Supervised Learning으로 잘 학습된 Encoder가 있다고 가정해봅니다. 데이터의 특징을 잘 추출하였다면, 바나나 이미지들로부터 추출한 representation vector들은 cosine similarity가 높고, 바나나 이미지와 강아지 이미지로부터 추출한 representation vector들의 similarity는 낮을 것 입니다.
최근에는 Contrastive Learning 방법 중 하나인 SimCLR을 통해 Self-Supervised Learning을 수행합니다. Contrastive Learning은 유사한 이미지는 유사한 representation을, 다른 이미지에는 다른 representation을 갖도록 학습하는 것이 목표입니다.
SimCLR은 이미지 $x$가 주어지면, 두 가지 다른 데이터 증강 방식 $t$와 $t`$을 사용해 $x_i$와 $x_j$ 를 생성합니다. 생성된 이미지($x_i, x_j$)는 Positive Pair가 됩니다. 이렇게 $N$개의 데이터로부터 $2N$개의 데이터를 생성하며, $N$ 쌍의 Positive Pair가 존재하게 됩니다.
$f$는 증강된 데이터 샘플에서 representation vector를 추출하는 기본 인코더 네트워크로, 각각 $h_i, h_j$를 생성합니다. ($f$는 보통 ResNet-50이 사용됩니다.) 마지막으로 Projection head라고 불리는 small neural network인 $g$를 통해 Contrastive loss를 적용하기 위한 space 위로 $h_i, h_j$를 맵핑합니다.
Contrastive loss의 목표는 $z_i = g(h_i)$와 $z_j = g(h_j)$ 의 유사도를 최대화하는 것입니다. 학습이 완료되면 $g$를 버리고 $f$와 $h$만을 사용하여 여러 downstream task를 수행합니다.

SimCLR Loss 는 아래와 같습니다. N개의 학습 데이터로부터 데이터 증강을 통해 2N개의 증강된 샘플을 얻습니다. 각 positive pair($i,j$) 에 대해 loss는 $z_i$와 $z_j$의 유사도를 최대화 하는 것을 목표로 합니다.
sim$(z_i, z_j) = \frac{z_i \cdot z_j}{|| z_i || || z_j ||}$ 은 벡터 $z_i$와 $z_j$ 간의 (정규화된) 내적 값 입니다. 벡터 $z_i$와 $z_j$의 유사도가 높을 수록 내적 값이 커지게 되고, 분자가 커지게 됩니다. 분모는 배치 내의 $z_i$와 다른 모든 증강 예제 $k$에 대한 유사도를 합산하여 (0, 1) 범위로 정규화 합니다. 따라서 음의 로그는 (-inf, 0) 의 범위를 갖습니다.
$$
l \; (i, j) = -\log \frac{\exp (\;\text{sim}(z_i, z_j)\; / \;\tau) }{\sum_{k=1}^{2N} \mathbb{1}_{k \neq i} \exp (\;text{sim} (z_i, z_k) \;/ \;\tau) }
$$
$l(k,k+N)$은 loss의 분모가 $k$ 번째 이미지와 전체 이미지들의 유사도의 합이고, $l(k+N, k)$는 $k+N$ 번째 이미지와 전체 이미지들의 유사도의 합 입니다. 이 두 값은 다르기 때문에, 두 loss을 평균내어 total loss를 계산합니다.
$$
L = \frac{1}{2N} \sum_{k=1}^N [ \; l(k, \;k+N) + l(k+N, \;k)\;]
$$
Code
먼저, vector $z_i$와 $z_j$ 의 similarity를 구하는 함수를 만들어 보겠습니다. cosine similarity의 공식은 다음과 같습니다. torch.linalg.norm 함수는 입력 벡터의 유클리드 노름을 구해줍니다.
$$sim(z_i, z_j) = \frac{A \cdot B}{|A| |B|}$$
def sim(z_i, z_j):
"""Normalized dot product between two vectors.
Inputs:
- z_i: 1xD tensor.
- z_j: 1xD tensor.
Returns:
- A scalar value that is the normalized dot product between z_i and z_j.
"""
norm_dot_product = None
##############################################################################
# TODO: Start of your code. #
# #
# HINT: torch.linalg.norm might be helpful. #
##############################################################################
norm_dot_product = torch.dot(z_i, z_j) / (torch.linalg.norm(z_i) * torch.linalg.norm(z_j))
##############################################################################
# END OF YOUR CODE #
##############################################################################
return norm_dot_product
다음으로, simclr loss를 naive하게 구하는 함수 입니다. 공식을 그대로 코드로 구현했습니다.
def simclr_loss_naive(out_left, out_right, tau):
"""Compute the contrastive loss L over a batch (naive loop version).
Input:
- out_left: NxD tensor; output of the projection head g(), left branch in SimCLR model.
- out_right: NxD tensor; output of the projection head g(), right branch in SimCLR model.
Each row is a z-vector for an augmented sample in the batch. The same row in out_left and out_right form a positive pair.
In other words, (out_left[k], out_right[k]) form a positive pair for all k=0...N-1.
- tau: scalar value, temperature parameter that determines how fast the exponential increases.
Returns:
- A scalar value; the total loss across all positive pairs in the batch. See notebook for definition.
"""
N = out_left.shape[0] # total number of training examples
# Concatenate out_left and out_right into a 2*N x D tensor.
out = torch.cat([out_left, out_right], dim=0) # [2*N, D]
total_loss = 0
for k in range(N): # loop through each positive pair (k, k+N)
z_k, z_k_N = out[k], out[k+N]
##############################################################################
# TODO: Start of your code. #
# #
# Hint: Compute l(k, k+N) and l(k+N, k). #
##############################################################################
# *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
num_1_up = torch.exp(sim(z_k, z_k_N)/tau)
num_1_bo = 0
for l in range(2*N):
if l == k:
continue
num_1_bo += torch.exp(sim(z_k, out[l])/tau)
loss1 = -torch.log(num_1_up/num_1_bo)
num_2_up = torch.exp(sim(z_k_N, z_k)/tau)
num_2_bo = 0
for l in range(2*N):
if l==k:
continue
num_2_bo += torch.exp(sim(z_k_N, out[l])/tau)
loss2 = -torch.log(num_2_up/num_2_bo)
total_loss += (loss1 + loss2)
# *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
##############################################################################
# END OF YOUR CODE #
##############################################################################
# In the end, we need to divide the total loss by 2N, the number of samples in the batch.
total_loss = total_loss / (2*N)
return total_loss
아래는 simclr_loss를 vectorize하게 구하는 함수를 구현하기 위해 필요한 함수들입니다.
sim_positive_pairs 함수는 입력 $l$ 과 $r$을 받아 두 벡터간의 similarity를 구해줍니다. 이는 simclr loss의 분자를 계산할 때 사용됩니다.
compute_sim_matrix 함수는 입력 (2N x D) 를 받아, 모든 데이터 쌍 similarity를 구하는 함수입니다. 따라서 output은 (2N x 2N) 행렬이 됩니다. 이는 simclr loss의 분모를 계산할 때 사용됩니다.
def sim_positive_pairs(out_left, out_right):
"""Normalized dot product between positive pairs.
Inputs:
- out_left: NxD tensor; output of the projection head g(), left branch in SimCLR model.
- out_right: NxD tensor; output of the projection head g(), right branch in SimCLR model.
Each row is a z-vector for an augmented sample in the batch.
The same row in out_left and out_right form a positive pair.
Returns:
- A Nx1 tensor; each row k is the normalized dot product between out_left[k] and out_right[k].
"""
pos_pairs = None
##############################################################################
# TODO: Start of your code. #
# #
# HINT: torch.linalg.norm might be helpful. #
##############################################################################
# *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
pos_pairs = torch.sum(out_left*out_right, dim = 1, keepdim=True)
pos_pairs /= torch.linalg.norm(out_left, dim=1, keepdim=True)
pos_pairs /= torch.linalg.norm(out_right, dim=1, keepdim=True)
# *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
##############################################################################
# END OF YOUR CODE #
##############################################################################
return pos_pairs
def compute_sim_matrix(out):
"""Compute a 2N x 2N matrix of normalized dot products between all pairs of augmented examples in a batch.
Inputs:
- out: 2N x D tensor; each row is the z-vector (output of projection head) of a single augmented example.
There are a total of 2N augmented examples in the batch.
Returns:
- sim_matrix: 2N x 2N tensor; each element i, j in the matrix is the normalized dot product between out[i] and out[j].
"""
sim_matrix = None
##############################################################################
# TODO: Start of your code. #
##############################################################################
# *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
norm = torch.linalg.norm(out, dim=1, keepdim=True)
# sim_matrix = torch.matmul(out, out.T)
sim_matrix = out @ out.T
sim_matrix /= norm
sim_matrix /= norm.T
# *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
##############################################################################
# END OF YOUR CODE #
##############################################################################
return sim_matrix
simclr_loss를 vectorize하게 계산하는 함수입니다.
pos_pair_sim 을 계산한 값은 loss의 분자로 사용됩니다.
compute_sim_matrix를 통해 계산된 (2N x 2N) 행렬은 분모에 사용됩니다. 이 때, 행렬의 대각성분인 $(i, i)$ 값들은 자기 자신과의 similarity 이므로 mask 행렬을 통해 계산하지 않습니다.
def simclr_loss_vectorized(out_left, out_right, tau, device='cuda'):
"""Compute the contrastive loss L over a batch (vectorized version). No loops are allowed.
Inputs and output are the same as in simclr_loss_naive.
"""
N = out_left.shape[0]
# Concatenate out_left and out_right into a 2*N x D tensor.
out = torch.cat([out_left, out_right], dim=0) # [2*N, D]
# Compute similarity matrix between all pairs of augmented examples in the batch.
sim_matrix = compute_sim_matrix(out) # [2*N, 2*N]
##############################################################################
# TODO: Start of your code. Follow the hints. #
##############################################################################
# Step 1: Use sim_matrix to compute the denominator value for all augmented samples.
# Hint: Compute e^{sim / tau} and store into exponential, which should have shape 2N x 2N.
exponential = (sim_matrix / tau).exp()
# This binary mask zeros out terms where k=i.
mask = (torch.ones_like(exponential, device=device) - torch.eye(2 * N, device=device)).to(device).bool()
# We apply the binary mask.
exponential = exponential.masked_select(mask).view(2 * N, -1) # [2*N, 2*N-1]
# Hint: Compute the denominator values for all augmented samples. This should be a 2N x 1 vector.
denom = exponential.sum(dim=1, keepdim=True)
# Step 2: Compute similarity between positive pairs.
# You can do this in two ways:
# Option 1: Extract the corresponding indices from sim_matrix.
# Option 2: Use sim_positive_pairs().
# *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
pos_pair_sim = sim_matrix[torch.arange(0,N), torch.arange(N, 2*N)][:, None]
# *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
# Step 3: Compute the numerator value for all augmented samples.
numerator = None
# *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
numerator = (pos_pair_sim / tau).exp().repeat(2,1)
# *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
# Step 4: Now that you have the numerator and denominator for all augmented samples, compute the total loss.
loss = None
# *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
loss = -(numerator / denom).log().mean()
# *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
##############################################################################
# END OF YOUR CODE #
##############################################################################
return loss
먼저, [Encoder - Projection head] 로 구성된 Simclr을 n회 학습합니다. 이후, Projection head는 버리고, 사전 학습된 Encoder와 새로운 Classifier를 결합하여 분류기를 만듭니다.
결과적으로 [(Pretrained) Encoder - (New) Linear classifier] 구성을 통해 CIFAR-10 데이터를 Classification 해보았습니다.
먼저, Self-Supervised learning 방법을 적용하지 않은 [(No Pretrained) Encoder - (New) Linear classifier] 조합으로는 test Accuracy가 10% 정도 됩니다.
Self-Supervised learning 방법을 적용한 [(Pretrained) Encoder - (New) Linear classifier] 조합으로는 Accuracy가 80% 이상입니다.

'Stanford CS231n' 카테고리의 다른 글
| CS231n Assignment 3(3) : Generative Adversarial Networks (0) | 2024.12.19 |
|---|---|
| CS231n Assignment 3(2) : Image Captioning with Transformers (0) | 2024.12.19 |
| CS231n Assignment 3(1) : RNN_Captioning (1) | 2024.12.18 |
| CS231n Assignment2(4) : Convolutional Neural Networks (0) | 2024.12.18 |
| CS231n Assignment 2(3) : Dropout (0) | 2024.12.18 |