CS231n Assignment 3(4) : Self-Supervised Learning for Image Classification

2024. 12. 20. 20:19·Stanford CS231n
반응형

Self-Supervised Learning

Self-Supervised Learning 이란 레이블이 없는 데이터셋을 바탕으로 모델이 데이터로부터 좋은 representation을 만드는 학습 방법을 의미합니다. 데이터로부터 추출한 representation vector는 데이터의 특징을 잘 담고 있어야 합니다. 예를 들어, Self-Supervised Learning으로 잘 학습된 Encoder가 있다고 가정해봅니다. 데이터의 특징을 잘 추출하였다면, 바나나 이미지들로부터 추출한 representation vector들은 cosine similarity가 높고, 바나나 이미지와 강아지 이미지로부터 추출한 representation vector들의 similarity는 낮을 것 입니다.

최근에는 Contrastive Learning 방법 중 하나인 SimCLR을 통해 Self-Supervised Learning을 수행합니다. Contrastive Learning은 유사한 이미지는 유사한 representation을, 다른 이미지에는 다른 representation을 갖도록 학습하는 것이 목표입니다.

 

SimCLR은 이미지 $x$가 주어지면, 두 가지 다른 데이터 증강 방식 $t$와 $t`$을 사용해 $x_i$와 $x_j$ 를 생성합니다. 생성된 이미지($x_i, x_j$)는 Positive Pair가 됩니다. 이렇게 $N$개의 데이터로부터 $2N$개의 데이터를 생성하며, $N$ 쌍의 Positive Pair가 존재하게 됩니다.

$f$는 증강된 데이터 샘플에서 representation vector를 추출하는 기본 인코더 네트워크로, 각각 $h_i, h_j$를 생성합니다. ($f$는 보통 ResNet-50이 사용됩니다.) 마지막으로 Projection head라고 불리는 small neural network인 $g$를 통해 Contrastive loss를 적용하기 위한 space 위로 $h_i, h_j$를 맵핑합니다.

Contrastive loss의 목표는 $z_i = g(h_i)$와 $z_j = g(h_j)$ 의 유사도를 최대화하는 것입니다. 학습이 완료되면 $g$를 버리고 $f$와 $h$만을 사용하여 여러 downstream task를 수행합니다.

 

SimCLR Loss 는 아래와 같습니다. N개의 학습 데이터로부터 데이터 증강을 통해 2N개의 증강된 샘플을 얻습니다. 각 positive pair($i,j$) 에 대해 loss는 $z_i$와 $z_j$의 유사도를 최대화 하는 것을 목표로 합니다. 

sim$(z_i, z_j) = \frac{z_i \cdot z_j}{|| z_i || || z_j ||}$ 은 벡터 $z_i$와 $z_j$ 간의 (정규화된) 내적 값 입니다. 벡터 $z_i$와 $z_j$의 유사도가 높을 수록 내적 값이 커지게 되고, 분자가 커지게 됩니다. 분모는 배치 내의 $z_i$와 다른 모든 증강 예제 $k$에 대한 유사도를 합산하여 (0, 1) 범위로 정규화 합니다. 따라서 음의 로그는 (-inf, 0) 의 범위를 갖습니다.

 

$$
l \; (i, j) = -\log \frac{\exp (\;\text{sim}(z_i, z_j)\; / \;\tau) }{\sum_{k=1}^{2N} \mathbb{1}_{k \neq i} \exp (\;text{sim} (z_i, z_k) \;/ \;\tau) }
$$

 

$l(k,k+N)$은 loss의 분모가 $k$ 번째 이미지와 전체 이미지들의 유사도의 합이고, $l(k+N, k)$는 $k+N$ 번째 이미지와 전체 이미지들의 유사도의 합 입니다. 이 두 값은 다르기 때문에, 두 loss을 평균내어 total loss를 계산합니다.

 

$$
L = \frac{1}{2N} \sum_{k=1}^N [ \; l(k, \;k+N) + l(k+N, \;k)\;]
$$

 

Code

먼저, vector $z_i$와 $z_j$ 의 similarity를 구하는 함수를 만들어 보겠습니다. cosine similarity의 공식은 다음과 같습니다. torch.linalg.norm 함수는 입력 벡터의 유클리드 노름을 구해줍니다.

$$sim(z_i, z_j) = \frac{A \cdot B}{|A| |B|}$$

def sim(z_i, z_j):
    """Normalized dot product between two vectors.

    Inputs:
    - z_i: 1xD tensor.
    - z_j: 1xD tensor.
    
    Returns:
    - A scalar value that is the normalized dot product between z_i and z_j.
    """
    norm_dot_product = None
    ##############################################################################
    # TODO: Start of your code.                                                  #
    #                                                                            #
    # HINT: torch.linalg.norm might be helpful.                                  #
    ##############################################################################
    
    norm_dot_product = torch.dot(z_i, z_j) / (torch.linalg.norm(z_i) * torch.linalg.norm(z_j))
    
    ##############################################################################
    #                               END OF YOUR CODE                             #
    ##############################################################################
    
    return norm_dot_product

 

 

다음으로, simclr loss를 naive하게 구하는 함수 입니다. 공식을 그대로 코드로 구현했습니다.

def simclr_loss_naive(out_left, out_right, tau):
    """Compute the contrastive loss L over a batch (naive loop version).
    
    Input:
    - out_left: NxD tensor; output of the projection head g(), left branch in SimCLR model.
    - out_right: NxD tensor; output of the projection head g(), right branch in SimCLR model.
    Each row is a z-vector for an augmented sample in the batch. The same row in out_left and out_right form a positive pair. 
    In other words, (out_left[k], out_right[k]) form a positive pair for all k=0...N-1.
    - tau: scalar value, temperature parameter that determines how fast the exponential increases.
    
    Returns:
    - A scalar value; the total loss across all positive pairs in the batch. See notebook for definition.
    """
    N = out_left.shape[0]  # total number of training examples
    
     # Concatenate out_left and out_right into a 2*N x D tensor.
    out = torch.cat([out_left, out_right], dim=0)  # [2*N, D]
    
    total_loss = 0
    for k in range(N):  # loop through each positive pair (k, k+N)
        z_k, z_k_N = out[k], out[k+N]
        
        ##############################################################################
        # TODO: Start of your code.                                                  #
        #                                                                            #
        # Hint: Compute l(k, k+N) and l(k+N, k).                                     #
        ##############################################################################
        # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

        num_1_up = torch.exp(sim(z_k, z_k_N)/tau)
        num_1_bo = 0
        for l in range(2*N):
          if l == k:
            continue
          num_1_bo += torch.exp(sim(z_k, out[l])/tau)
        loss1 = -torch.log(num_1_up/num_1_bo)

        num_2_up = torch.exp(sim(z_k_N, z_k)/tau)
        num_2_bo = 0
        for l in range(2*N):
          if l==k:
            continue
          num_2_bo += torch.exp(sim(z_k_N, out[l])/tau)
        loss2 = -torch.log(num_2_up/num_2_bo)

        total_loss += (loss1 + loss2)
        # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
         ##############################################################################
        #                               END OF YOUR CODE                             #
        ##############################################################################
    
    # In the end, we need to divide the total loss by 2N, the number of samples in the batch.
    total_loss = total_loss / (2*N)
    return total_loss

 

 

아래는 simclr_loss를 vectorize하게 구하는 함수를 구현하기 위해 필요한 함수들입니다.

 

sim_positive_pairs 함수는 입력 $l$ 과 $r$을 받아 두 벡터간의 similarity를 구해줍니다. 이는 simclr loss의 분자를 계산할 때 사용됩니다.

compute_sim_matrix 함수는 입력 (2N x D) 를 받아, 모든 데이터 쌍 similarity를 구하는 함수입니다. 따라서 output은 (2N x 2N) 행렬이 됩니다. 이는 simclr loss의 분모를 계산할 때 사용됩니다.

def sim_positive_pairs(out_left, out_right):
    """Normalized dot product between positive pairs.

    Inputs:
    - out_left: NxD tensor; output of the projection head g(), left branch in SimCLR model.
    - out_right: NxD tensor; output of the projection head g(), right branch in SimCLR model.
    Each row is a z-vector for an augmented sample in the batch.
    The same row in out_left and out_right form a positive pair.
    
    Returns:
    - A Nx1 tensor; each row k is the normalized dot product between out_left[k] and out_right[k].
    """
    pos_pairs = None
    
    ##############################################################################
    # TODO: Start of your code.                                                  #
    #                                                                            #
    # HINT: torch.linalg.norm might be helpful.                                  #
    ##############################################################################
    
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

    pos_pairs = torch.sum(out_left*out_right, dim = 1, keepdim=True)
    pos_pairs /= torch.linalg.norm(out_left, dim=1, keepdim=True)
    pos_pairs /= torch.linalg.norm(out_right, dim=1, keepdim=True)


    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    
    ##############################################################################
    #                               END OF YOUR CODE                             #
    ##############################################################################
    return pos_pairs


def compute_sim_matrix(out):
    """Compute a 2N x 2N matrix of normalized dot products between all pairs of augmented examples in a batch.

    Inputs:
    - out: 2N x D tensor; each row is the z-vector (output of projection head) of a single augmented example.
    There are a total of 2N augmented examples in the batch.
    
    Returns:
    - sim_matrix: 2N x 2N tensor; each element i, j in the matrix is the normalized dot product between out[i] and out[j].
    """
    sim_matrix = None
    
    ##############################################################################
    # TODO: Start of your code.                                                  #
    ##############################################################################
    
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

    norm = torch.linalg.norm(out, dim=1, keepdim=True)
    # sim_matrix = torch.matmul(out, out.T)
    sim_matrix = out @ out.T
    sim_matrix /= norm
    sim_matrix /= norm.T

    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    
    ##############################################################################
    #                               END OF YOUR CODE                             #
    ##############################################################################
    return sim_matrix

 

 

simclr_loss를 vectorize하게 계산하는 함수입니다.

 

pos_pair_sim 을 계산한 값은 loss의 분자로 사용됩니다.

 

compute_sim_matrix를 통해 계산된 (2N x 2N) 행렬은 분모에 사용됩니다. 이 때, 행렬의 대각성분인 $(i, i)$ 값들은 자기 자신과의 similarity 이므로 mask 행렬을 통해 계산하지 않습니다.

def simclr_loss_vectorized(out_left, out_right, tau, device='cuda'):
    """Compute the contrastive loss L over a batch (vectorized version). No loops are allowed.
    
    Inputs and output are the same as in simclr_loss_naive.
    """
    N = out_left.shape[0]
    
    # Concatenate out_left and out_right into a 2*N x D tensor.
    out = torch.cat([out_left, out_right], dim=0)  # [2*N, D]
    
    # Compute similarity matrix between all pairs of augmented examples in the batch.
    sim_matrix = compute_sim_matrix(out)  # [2*N, 2*N]
    
    ##############################################################################
    # TODO: Start of your code. Follow the hints.                                #
    ##############################################################################
    
    # Step 1: Use sim_matrix to compute the denominator value for all augmented samples.
    # Hint: Compute e^{sim / tau} and store into exponential, which should have shape 2N x 2N.
    exponential = (sim_matrix / tau).exp()
    
    # This binary mask zeros out terms where k=i.
    mask = (torch.ones_like(exponential, device=device) - torch.eye(2 * N, device=device)).to(device).bool()
    
    # We apply the binary mask.
    exponential = exponential.masked_select(mask).view(2 * N, -1)  # [2*N, 2*N-1]
    
    # Hint: Compute the denominator values for all augmented samples. This should be a 2N x 1 vector.
    denom = exponential.sum(dim=1, keepdim=True)

    # Step 2: Compute similarity between positive pairs.
    # You can do this in two ways: 
    # Option 1: Extract the corresponding indices from sim_matrix. 
    # Option 2: Use sim_positive_pairs().
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

    pos_pair_sim = sim_matrix[torch.arange(0,N), torch.arange(N, 2*N)][:, None]
    
    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    
    # Step 3: Compute the numerator value for all augmented samples.
    numerator = None
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

    numerator = (pos_pair_sim / tau).exp().repeat(2,1)

    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    
    # Step 4: Now that you have the numerator and denominator for all augmented samples, compute the total loss.
    loss = None
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

    loss = -(numerator / denom).log().mean()

    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    
    ##############################################################################
    #                               END OF YOUR CODE                             #
    ##############################################################################
    
    return loss

 

 

 

먼저, [Encoder - Projection head] 로 구성된 Simclr을 n회 학습합니다. 이후, Projection head는 버리고, 사전 학습된 Encoder와 새로운 Classifier를 결합하여 분류기를 만듭니다.

결과적으로 [(Pretrained) Encoder - (New) Linear classifier] 구성을 통해 CIFAR-10 데이터를 Classification 해보았습니다. 

 

먼저, Self-Supervised learning 방법을 적용하지 않은 [(No Pretrained) Encoder - (New) Linear classifier] 조합으로는 test Accuracy가 10% 정도 됩니다.

 

Self-Supervised learning 방법을 적용한 [(Pretrained) Encoder - (New) Linear classifier] 조합으로는 Accuracy가 80% 이상입니다.

반응형

'Stanford CS231n' 카테고리의 다른 글

CS231n Assignment 3(3) : Generative Adversarial Networks  (0) 2024.12.19
CS231n Assignment 3(2) : Image Captioning with Transformers  (0) 2024.12.19
CS231n Assignment 3(1) : RNN_Captioning  (1) 2024.12.18
CS231n Assignment2(4) : Convolutional Neural Networks  (0) 2024.12.18
CS231n Assignment 2(3) : Dropout  (0) 2024.12.18
'Stanford CS231n' 카테고리의 다른 글
  • CS231n Assignment 3(3) : Generative Adversarial Networks
  • CS231n Assignment 3(2) : Image Captioning with Transformers
  • CS231n Assignment 3(1) : RNN_Captioning
  • CS231n Assignment2(4) : Convolutional Neural Networks
hangyuwon
hangyuwon
  • hangyuwon
    191
    hangyuwon
  • 전체
    오늘
    어제
  • 글쓰기 관리
    • 분류 전체보기 (38)
      • 기타 (1)
      • Stanford CS231n (19)
      • 논문 리뷰 (5)
      • Error (4)
      • 알고리즘 (2)
      • Linux (1)
      • 잡동사니 (2)
      • 딥러닝 (4)
  • 인기 글

  • 태그

    알고리즘
    error
    논문 리뷰
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.0
hangyuwon
CS231n Assignment 3(4) : Self-Supervised Learning for Image Classification
상단으로

티스토리툴바