본문 바로가기
컴퓨터 언어/Python_Pytorch

[ Pytorch ] 파이토치 텐서 합치는 방법 : cat(), stack() ( + dim의 의미와, 병합 방식의 차이)

by SuperMemi 2022. 8. 18.
반응형

[ Pytorch ] 파이토치 텐서 합치는 방법 : cat(), stack() ( + dim의 의미와, 병합 방식의 차이)

 

https://pytorch.org/

[ Pytorch ] 파이토치 설치하기

 

[ Pytorch ] 파이토치 설치하기

[ Pytorch ] 파이토치 설치하기 머신러닝에서 tensorflow와 pytorch는 양대 산맥이죠 pytorch를 설치해봅시다. https://pytorch.org/get-started/locally/ PyTorch An open source machine learning framewor..

supermemi.tistory.com


torch.cat(seq, dim)

  • 새로운 차원으로 확장하지 않고 주어진 텐서 shape 내에서 병합합니다. 

Parameters 

  • seq : 합치고 싶은 텐서의 시퀀스
  • dim : 텐서를 합칠 차원을 정수로 표현함

주의

  • 합치려는 차원을 제외한 나머지 차원의 shape은 모두 동일해야합니다
  • dim 계산시 shape을 기준으로 생각하면 편합니다. (아래 예시 코드 참고)
  • dim 양수일 경우 : 차원 shape 앞에서 부터 계산
  • dim 음수일 경우 : 차원 shape 뒤어서 부터 계산

코드예시 : 1차원 텐서간 병합 (cat)

  • 1차원 텐서는 차원이 하나 밖에 없기 때문에 dim = 0 만 설정 가능하다.
  • 병합하는 차원으로는 size가 달라도 병합 가능 (a.size = [3], b.size = [5] → a_b.size = [8])
import torch
print(torch.__version__) # 1.11.0

# 병합할 텐서
a = torch.arange(0,3) # tensor([0, 1, 2])
b = torch.arange(3,8) # tensor([3, 4, 5, 6, 7])

print(a.size(), b.size()) # torch.Size([3]) torch.Size([5])

# 1차원 텐서간 병합
a_b = torch.cat((a,b),dim=0) # tensor([0, 1, 2, 3, 4, 5, 6, 7])

print(a_b.size()) # torch.Size([8])

코드예시 : 2차원 텐서간 병합 (cat)

  • dim = 0 → torch.size() 에서 첫번째 차원으로 병합 (a.size = [2,3], b.size = [2,3] → a_b.size = [4,3])
  • dim = 1 → torch.size() 에서 두번째 차원으로 병합 (a.size = [2,3], b.size = [2,3] → a_b.size = [2,6])
import torch
print(torch.__version__) # 1.11.0

# 병합할 텐서

a = torch.arange(0,6).reshape(2,3)
print(a.size()) # torch.Size([2, 3]) 
print(a)
"""
tensor([[0, 1, 2],
        [3, 4, 5]])
""" 

b = torch.arange(6,12).reshape(2,3) 
print(b.size()) # torch.Size([2, 3])
print(b)
"""
tensor([[ 6,  7,  8],
        [ 9, 10, 11]])
"""

# 2차원 텐서간 병합

#     dim = 0
a_b_dim_0 = torch.cat((a,b),dim=0) 
print(a_b_dim_0.size()) # torch.Size([4, 3])
print(a_b_dim_0)
"""
tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])
"""

#     dim = 1
a_b_dim_1 = torch.cat((a,b),dim=1)
print(a_b_dim_1.size()) # torch.Size([2, 6])
print(a_b_dim_1)
"""
tensor([[ 0,  1,  2,  6,  7,  8],
        [ 3,  4,  5,  9, 10, 11]])
"""

코드예시 : 2차원 텐서 3개 병합 (cat)

  • 여러개의 텐서를 seq로 만들어서 한번에 병합이 가능합니다.
import torch
print(torch.__version__) # 1.11.0

# 병합할 텐서

a = torch.arange(0,6).reshape(2,3)
print(a.size()) # torch.Size([2, 3]) 
print(a)
"""
tensor([[0, 1, 2],
        [3, 4, 5]])
""" 

b = torch.arange(6,12).reshape(2,3) 
print(b.size()) # torch.Size([2, 3])
print(b)
"""
tensor([[ 6,  7,  8],
        [ 9, 10, 11]])
"""

c = torch.arange(12,18).reshape(2,3) 
print(c.size()) # torch.Size([2, 3])
print(c)
"""
tensor([[12, 13, 14],
        [15, 16, 17]])
"""

# 2차원 텐서 3개 병합

#     dim = 0
a_b_dim_0 = torch.cat((a,b,c),dim=0) 
print(a_b_dim_0.size()) # torch.Size([6, 3])
print(a_b_dim_0)
"""
tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]])
"""

#     dim = 1
a_b_dim_1 = torch.cat((a,b,c),dim=1)
print(a_b_dim_1.size()) # torch.Size([2, 9])
print(a_b_dim_1)
"""
tensor([[ 0,  1,  2,  6,  7,  8, 12, 13, 14],
        [ 3,  4,  5,  9, 10, 11, 15, 16, 17]])
"""

코드예시 : 3차원 텐서간 병합 (cat)

  • 텐서가 3차원이 되어도 원리는 유사하다
  • dim = 0 → torch.size() 에서 첫번째 차원으로 병합 (a.size = [1,2,3], b.size = [1,2,3] → a_b.size = [2,2,3])
  • dim = 1 → torch.size() 에서 두번째 차원으로 병합 (a.size = [1,2,3], b.size = [1,2,3] → a_b.size = [1,4,3])
  • dim = 2 → torch.size() 에서 세번째 차원으로 병합 (a.size = [1,2,3], b.size = [1,2,3] → a_b.size = [1,2,6])
import torch
print(torch.__version__) # 1.11.0

# 병합할 텐서

a = torch.arange(0,6).reshape(1,2,3)
print(a.size()) # torch.Size([1, 2, 3])
print(a)
"""
tensor([[[0, 1, 2],
         [3, 4, 5]]])
""" 

b = torch.arange(6,12).reshape(1,2,3) 
print(b.size()) # torch.Size([1, 2, 3])
print(b)
"""
tensor([[[ 6,  7,  8],
         [ 9, 10, 11]]])
"""

# 2차원 텐서간 병합

#     dim = 0
a_b_dim_0 = torch.cat((a,b),dim=0) 
print(a_b_dim_0.size()) # torch.Size([2, 2, 3])
print(a_b_dim_0)
"""
tensor([[[ 0,  1,  2],
         [ 3,  4,  5]],

        [[ 6,  7,  8],
         [ 9, 10, 11]]])
"""

#     dim = 1
a_b_dim_1 = torch.cat((a,b),dim=1)
print(a_b_dim_1.size()) #torch.Size([1, 4, 3])
print(a_b_dim_1)
"""
tensor([[[ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8],
         [ 9, 10, 11]]])
"""

#     dim = 2
a_b_dim_2 = torch.cat((a,b),dim=2)
print(a_b_dim_2.size()) # torch.Size([1, 2, 6])
print(a_b_dim_2)
"""
tensor([[[ 0,  1,  2,  6,  7,  8],
         [ 3,  4,  5,  9, 10, 11]]])
"""

코드예시 : 3차원 텐서간 병합 (cat) (dim 음수 일때)

  • dim이 음수일때는 뒤에서 부터 계산한다.
  • dim = 0 → torch.size() 앞에서 첫번째 차원으로 병합 (a.size = [1,2,3], b.size = [1,2,3] → a_b.size = [2,2,3])
  • dim = 1 → torch.size() 앞에서 두번째 차원으로 병합 (a.size = [1,2,3], b.size = [1,2,3] → a_b.size = [1,4,3])
  • dim = 2 → torch.size() 앞에서 세번째 차원(마지막차원)으로 병합 (a.size = [1,2,3], b.size = [1,2,3] → a_b.size = [1,2,6])
  • dim = -1 → torch.size() 뒤에서 첫번째 차원(앞에서 마지막차원)으로 병합 (a.size = [1,2,3], b.size = [1,2,3] → a_b.size = [1,2,6])
  • dim = -2 → torch.size() 뒤에서 두번째 차원으로 병합 (a.size = [1,2,3], b.size = [1,2,3] → a_b.size = [1,4,3])
  • dim = -3 → torch.size() 뒤에서 세번째 차원으로 병합 (a.size = [1,2,3], b.size = [1,2,3] → a_b.size = [2,2,3])
#     dim = -1
a_b_dim_1_ = torch.cat((a,b),dim=-1)
print(a_b_dim_1_.size()) # torch.Size([1, 2, 6])
print(a_b_dim_1_)
"""
tensor([[[ 0,  1,  2,  6,  7,  8],
         [ 3,  4,  5,  9, 10, 11]]])
"""

#     dim = -2
a_b_dim_2_ = torch.cat((a,b),dim=-2)
print(a_b_dim_2_.size()) #torch.Size([1, 4, 3])
print(a_b_dim_2_)
"""
tensor([[[ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8],
         [ 9, 10, 11]]])
"""


#     dim = -3
a_b_dim_3_ = torch.cat((a,b),dim=-3) 
print(a_b_dim_3_.size()) # torch.Size([2, 2, 3])
print(a_b_dim_3_)
"""
ensor([[[ 0,  1,  2],
         [ 3,  4,  5]],

        [[ 6,  7,  8],
         [ 9, 10, 11]]])
"""

 

dim 비교 직접 해보시면 이해가 빠르게 되실 겁니다!!

 

print(a_b_dim_0) # dim = 0
print(a_b_dim_1) # dim = 1
print(a_b_dim_2) # dim = 2
print(a_b_dim_1_) # dim = -1
print(a_b_dim_2_) # dim = -2
print(a_b_dim_3_) # dim = -3

torch.cat(seq, dim)

  • 앞에서 다룬 cat은 주어진 텐서의 고유 shape에 맞추어서 병합 했습니다. 
  • 그러나 stack은 새로운 차원으로 확장하여 텐서 시퀀스를 병합 합니다.
  • 즉, stack은 말 그대로 주어진 차원으로 쌓아 올린다라고 생각하시면 됩니다.

Parameters 

  • seq : 합치고 싶은 텐서의 시퀀스
  • dim : 텐서를 확장할 차원을 정수로 표현함

주의 

  • 확장하고자 하는 차원의 shape이 동일해야 합니다.
  • 텐서를 dim 차원으로 확장한 뒤 cat을 진행한다는 느낌으로 생각하시면 됩니다.

 


코드예시 : 1차원 텐서간 병합 (stack)

  • dim = 0 → unsqueeze(0)으로 텐서를 확장한 뒤 dim = 0 으로 병합한 것과 동일 (a.size = [3] ▶ 확장 : [1,3], b.size = [3] ▶ 확장 : [1,3] → a_b.size = [2,3])
  • dim = 1  → unsqueeze(1)로 텐서로 확장한 뒤 dim = 1 으로 병합한 것과 동일 (a.size = [3] ▶ 확장 : [3,1], b.size = [3] ▶ 확장 : [3,1] → a_b.size = [3,2])
import torch
print(torch.__version__) # 1.11.0

# 병합할 텐서
a = torch.arange(0,3) # tensor([0, 1, 2])
b = torch.arange(3,6) # tensor([3, 4, 5])

print(a.size(), b.size()) # torch.Size([3]) torch.Size([3])

# 1차원 텐서간 병합
#     dim = 0
a_b = torch.stack((a,b),dim=0) 

print(a_b.size()) # torch.Size([2,3])
print(a_b)
"""
tensor([[0, 1, 2],
        [3, 4, 5]])
"""
print(a_b == torch.cat((a.unsqueeze(0),b.unsqueeze(0)),dim=0)) 
"""
tensor([[True, True, True],
        [True, True, True]])
"""


#     dim = 1
a_b = torch.stack((a,b),dim=1) # tensor([0, 1, 2, 3, 4, 5, 6, 7])

print(a_b.size()) # torch.Size([3,2])
print(a_b)
"""
tensor([[0, 3],
        [1, 4],
        [2, 5]])
"""
print(a_b == torch.cat((a.unsqueeze(1),b.unsqueeze(1)),dim=1)) 
"""
tensor([[True, True],
        [True, True],
        [True, True]])
"""

 

 

코드예시 : 2차원 텐서간 병합 (stack)

  • dim = 0 → unsqueeze(0)으로 텐서를 확장한 뒤 dim = 0 으로 병합한 것과 동일 (a.size = [2,3] ▶ 확장 : [1,2,3], b.size = [2,3] ▶ 확장 : [1,2,3] → a_b.size = [2,2,3])
  • dim = 1  → unsqueeze(1)로 텐서로 확장한 뒤 dim = 1 으로 병합한 것과 동일 (a.size = [2,3]  확장 : [2,1,3], b.size = [2,3] ▶ 확장 : [2,1,3] → a_b.size = [2,2,3])
  • dim = 2  → unsqueeze(2)로 텐서로 확장한 뒤 dim = 2 으로 병합한 것과 동일 (a.size = [2, 3] ▶ 확장 : [2,3,1], b.size = [2,3] ▶ 확장 : [2,3,1] → a_b.size = [2,3,2])

 

주의 : 같은 텐서 구조를 가지더라도 텐서의 내용을 자세히 보시면 다릅니다!!!

 

import torch
print(torch.__version__) # 1.11.0

# 병합할 텐서

a = torch.arange(0,6).reshape(2,3)
print(a.size()) # torch.Size([2, 3]) 
print(a)
"""
tensor([[0, 1, 2],
        [3, 4, 5]])
""" 

b = torch.arange(6,12).reshape(2,3) 
print(b.size()) # torch.Size([2, 3])
print(b)
"""
tensor([[ 6,  7,  8],
        [ 9, 10, 11]])
"""

# 2차원 텐서간 병합

#     dim = 0
a_b_dim_0 = torch.stack((a,b),dim=0) 
print(a_b_dim_0.size()) # torch.Size([2, 2, 3])
print(a_b_dim_0)
"""
tensor([[[ 0,  1,  2],
         [ 3,  4,  5]],

        [[ 6,  7,  8],
         [ 9, 10, 11]]])
"""

#     dim = 1
a_b_dim_1 = torch.stack((a,b),dim=1)
print(a_b_dim_1.size()) # torch.Size([2, 2, 3])
print(a_b_dim_1)
"""
tensor([[[ 0,  1,  2],
         [ 6,  7,  8]],

        [[ 3,  4,  5],
         [ 9, 10, 11]]])
"""

#     dim = 2
a_b_dim_2 = torch.stack((a,b),dim=2)
print(a_b_dim_2.size()) # torch.Size([2, 3, 2])
print(a_b_dim_2)
"""
tensor([[[ 0,  6],
         [ 1,  7],
         [ 2,  8]],

        [[ 3,  9],
         [ 4, 10],
         [ 5, 11]]])
"""

 

직접 unsqueeze를 해보며 차원의 변화를 알아보시는 것이 이해에 도움이 됩니다!

 


정리

  • cat : 주어진 텐서들을 병합할때 dim 을 기준으로 붙여넣음. 그래서 전체 차원의 크기가 변하지 않음
  • stack : 주어진 텐서들을 병합할때 dim 을 확장 시킨 후 붙여넣음. 그래서 전체 차원이 1개 늘어남

차원 변환 과정에서 텐서의 배열이 바뀌게 됨. 병합시 주의할 것!


[ 다음 글 ]

[ Pytorch ] nn.BCELoss, nn.BCEWithLogitsLoss, nn.CrossEntropyLoss, nn.NLLLoss 총정리

[ PyTorch / torchvision ] make_grid() 사용하기

[ PyTorch / torchvision ] draw_bounding_boxes() 사용하기

 


[ 참고 ]

https://pytorch.org/docs/stable/generated/torch.cat.html

https://pytorch.org/docs/stable/generated/torch.stack.html


 

반응형