반응형
[ Pytorch ] 파이토치 텐서 합치는 방법 : cat(), stack() ( + dim의 의미와, 병합 방식의 차이)
torch.cat(seq, dim)
- 새로운 차원으로 확장하지 않고 주어진 텐서 shape 내에서 병합합니다.
Parameters
- seq : 합치고 싶은 텐서의 시퀀스
- dim : 텐서를 합칠 차원을 정수로 표현함
주의
- 합치려는 차원을 제외한 나머지 차원의 shape은 모두 동일해야합니다
- dim 계산시 shape을 기준으로 생각하면 편합니다. (아래 예시 코드 참고)
- dim 양수일 경우 : 차원 shape 앞에서 부터 계산
- dim 음수일 경우 : 차원 shape 뒤어서 부터 계산
코드예시 : 1차원 텐서간 병합 (cat)
- 1차원 텐서는 차원이 하나 밖에 없기 때문에 dim = 0 만 설정 가능하다.
- 병합하는 차원으로는 size가 달라도 병합 가능 (a.size = [3], b.size = [5] → a_b.size = [8])
import torch
print(torch.__version__) # 1.11.0
# 병합할 텐서
a = torch.arange(0,3) # tensor([0, 1, 2])
b = torch.arange(3,8) # tensor([3, 4, 5, 6, 7])
print(a.size(), b.size()) # torch.Size([3]) torch.Size([5])
# 1차원 텐서간 병합
a_b = torch.cat((a,b),dim=0) # tensor([0, 1, 2, 3, 4, 5, 6, 7])
print(a_b.size()) # torch.Size([8])
코드예시 : 2차원 텐서간 병합 (cat)
- dim = 0 → torch.size() 에서 첫번째 차원으로 병합 (a.size = [2,3], b.size = [2,3] → a_b.size = [4,3])
- dim = 1 → torch.size() 에서 두번째 차원으로 병합 (a.size = [2,3], b.size = [2,3] → a_b.size = [2,6])
import torch
print(torch.__version__) # 1.11.0
# 병합할 텐서
a = torch.arange(0,6).reshape(2,3)
print(a.size()) # torch.Size([2, 3])
print(a)
"""
tensor([[0, 1, 2],
[3, 4, 5]])
"""
b = torch.arange(6,12).reshape(2,3)
print(b.size()) # torch.Size([2, 3])
print(b)
"""
tensor([[ 6, 7, 8],
[ 9, 10, 11]])
"""
# 2차원 텐서간 병합
# dim = 0
a_b_dim_0 = torch.cat((a,b),dim=0)
print(a_b_dim_0.size()) # torch.Size([4, 3])
print(a_b_dim_0)
"""
tensor([[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
"""
# dim = 1
a_b_dim_1 = torch.cat((a,b),dim=1)
print(a_b_dim_1.size()) # torch.Size([2, 6])
print(a_b_dim_1)
"""
tensor([[ 0, 1, 2, 6, 7, 8],
[ 3, 4, 5, 9, 10, 11]])
"""
코드예시 : 2차원 텐서 3개 병합 (cat)
- 여러개의 텐서를 seq로 만들어서 한번에 병합이 가능합니다.
import torch
print(torch.__version__) # 1.11.0
# 병합할 텐서
a = torch.arange(0,6).reshape(2,3)
print(a.size()) # torch.Size([2, 3])
print(a)
"""
tensor([[0, 1, 2],
[3, 4, 5]])
"""
b = torch.arange(6,12).reshape(2,3)
print(b.size()) # torch.Size([2, 3])
print(b)
"""
tensor([[ 6, 7, 8],
[ 9, 10, 11]])
"""
c = torch.arange(12,18).reshape(2,3)
print(c.size()) # torch.Size([2, 3])
print(c)
"""
tensor([[12, 13, 14],
[15, 16, 17]])
"""
# 2차원 텐서 3개 병합
# dim = 0
a_b_dim_0 = torch.cat((a,b,c),dim=0)
print(a_b_dim_0.size()) # torch.Size([6, 3])
print(a_b_dim_0)
"""
tensor([[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]])
"""
# dim = 1
a_b_dim_1 = torch.cat((a,b,c),dim=1)
print(a_b_dim_1.size()) # torch.Size([2, 9])
print(a_b_dim_1)
"""
tensor([[ 0, 1, 2, 6, 7, 8, 12, 13, 14],
[ 3, 4, 5, 9, 10, 11, 15, 16, 17]])
"""
코드예시 : 3차원 텐서간 병합 (cat)
- 텐서가 3차원이 되어도 원리는 유사하다
- dim = 0 → torch.size() 에서 첫번째 차원으로 병합 (a.size = [1,2,3], b.size = [1,2,3] → a_b.size = [2,2,3])
- dim = 1 → torch.size() 에서 두번째 차원으로 병합 (a.size = [1,2,3], b.size = [1,2,3] → a_b.size = [1,4,3])
- dim = 2 → torch.size() 에서 세번째 차원으로 병합 (a.size = [1,2,3], b.size = [1,2,3] → a_b.size = [1,2,6])
import torch
print(torch.__version__) # 1.11.0
# 병합할 텐서
a = torch.arange(0,6).reshape(1,2,3)
print(a.size()) # torch.Size([1, 2, 3])
print(a)
"""
tensor([[[0, 1, 2],
[3, 4, 5]]])
"""
b = torch.arange(6,12).reshape(1,2,3)
print(b.size()) # torch.Size([1, 2, 3])
print(b)
"""
tensor([[[ 6, 7, 8],
[ 9, 10, 11]]])
"""
# 2차원 텐서간 병합
# dim = 0
a_b_dim_0 = torch.cat((a,b),dim=0)
print(a_b_dim_0.size()) # torch.Size([2, 2, 3])
print(a_b_dim_0)
"""
tensor([[[ 0, 1, 2],
[ 3, 4, 5]],
[[ 6, 7, 8],
[ 9, 10, 11]]])
"""
# dim = 1
a_b_dim_1 = torch.cat((a,b),dim=1)
print(a_b_dim_1.size()) #torch.Size([1, 4, 3])
print(a_b_dim_1)
"""
tensor([[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]]])
"""
# dim = 2
a_b_dim_2 = torch.cat((a,b),dim=2)
print(a_b_dim_2.size()) # torch.Size([1, 2, 6])
print(a_b_dim_2)
"""
tensor([[[ 0, 1, 2, 6, 7, 8],
[ 3, 4, 5, 9, 10, 11]]])
"""
코드예시 : 3차원 텐서간 병합 (cat) (dim 음수 일때)
- dim이 음수일때는 뒤에서 부터 계산한다.
- dim = 0 → torch.size() 앞에서 첫번째 차원으로 병합 (a.size = [1,2,3], b.size = [1,2,3] → a_b.size = [2,2,3])
- dim = 1 → torch.size() 앞에서 두번째 차원으로 병합 (a.size = [1,2,3], b.size = [1,2,3] → a_b.size = [1,4,3])
- dim = 2 → torch.size() 앞에서 세번째 차원(마지막차원)으로 병합 (a.size = [1,2,3], b.size = [1,2,3] → a_b.size = [1,2,6])
- dim = -1 → torch.size() 뒤에서 첫번째 차원(앞에서 마지막차원)으로 병합 (a.size = [1,2,3], b.size = [1,2,3] → a_b.size = [1,2,6])
- dim = -2 → torch.size() 뒤에서 두번째 차원으로 병합 (a.size = [1,2,3], b.size = [1,2,3] → a_b.size = [1,4,3])
- dim = -3 → torch.size() 뒤에서 세번째 차원으로 병합 (a.size = [1,2,3], b.size = [1,2,3] → a_b.size = [2,2,3])
# dim = -1
a_b_dim_1_ = torch.cat((a,b),dim=-1)
print(a_b_dim_1_.size()) # torch.Size([1, 2, 6])
print(a_b_dim_1_)
"""
tensor([[[ 0, 1, 2, 6, 7, 8],
[ 3, 4, 5, 9, 10, 11]]])
"""
# dim = -2
a_b_dim_2_ = torch.cat((a,b),dim=-2)
print(a_b_dim_2_.size()) #torch.Size([1, 4, 3])
print(a_b_dim_2_)
"""
tensor([[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]]])
"""
# dim = -3
a_b_dim_3_ = torch.cat((a,b),dim=-3)
print(a_b_dim_3_.size()) # torch.Size([2, 2, 3])
print(a_b_dim_3_)
"""
ensor([[[ 0, 1, 2],
[ 3, 4, 5]],
[[ 6, 7, 8],
[ 9, 10, 11]]])
"""
dim 비교 직접 해보시면 이해가 빠르게 되실 겁니다!!
print(a_b_dim_0) # dim = 0
print(a_b_dim_1) # dim = 1
print(a_b_dim_2) # dim = 2
print(a_b_dim_1_) # dim = -1
print(a_b_dim_2_) # dim = -2
print(a_b_dim_3_) # dim = -3
torch.cat(seq, dim)
- 앞에서 다룬 cat은 주어진 텐서의 고유 shape에 맞추어서 병합 했습니다.
- 그러나 stack은 새로운 차원으로 확장하여 텐서 시퀀스를 병합 합니다.
- 즉, stack은 말 그대로 주어진 차원으로 쌓아 올린다라고 생각하시면 됩니다.
Parameters
- seq : 합치고 싶은 텐서의 시퀀스
- dim : 텐서를 확장할 차원을 정수로 표현함
주의
- 확장하고자 하는 차원의 shape이 동일해야 합니다.
- 텐서를 dim 차원으로 확장한 뒤 cat을 진행한다는 느낌으로 생각하시면 됩니다.
코드예시 : 1차원 텐서간 병합 (stack)
- dim = 0 → unsqueeze(0)으로 텐서를 확장한 뒤 dim = 0 으로 병합한 것과 동일 (a.size = [3] ▶ 확장 : [1,3], b.size = [3] ▶ 확장 : [1,3] → a_b.size = [2,3])
- dim = 1 → unsqueeze(1)로 텐서로 확장한 뒤 dim = 1 으로 병합한 것과 동일 (a.size = [3] ▶ 확장 : [3,1], b.size = [3] ▶ 확장 : [3,1] → a_b.size = [3,2])
import torch
print(torch.__version__) # 1.11.0
# 병합할 텐서
a = torch.arange(0,3) # tensor([0, 1, 2])
b = torch.arange(3,6) # tensor([3, 4, 5])
print(a.size(), b.size()) # torch.Size([3]) torch.Size([3])
# 1차원 텐서간 병합
# dim = 0
a_b = torch.stack((a,b),dim=0)
print(a_b.size()) # torch.Size([2,3])
print(a_b)
"""
tensor([[0, 1, 2],
[3, 4, 5]])
"""
print(a_b == torch.cat((a.unsqueeze(0),b.unsqueeze(0)),dim=0))
"""
tensor([[True, True, True],
[True, True, True]])
"""
# dim = 1
a_b = torch.stack((a,b),dim=1) # tensor([0, 1, 2, 3, 4, 5, 6, 7])
print(a_b.size()) # torch.Size([3,2])
print(a_b)
"""
tensor([[0, 3],
[1, 4],
[2, 5]])
"""
print(a_b == torch.cat((a.unsqueeze(1),b.unsqueeze(1)),dim=1))
"""
tensor([[True, True],
[True, True],
[True, True]])
"""
코드예시 : 2차원 텐서간 병합 (stack)
- dim = 0 → unsqueeze(0)으로 텐서를 확장한 뒤 dim = 0 으로 병합한 것과 동일 (a.size = [2,3] ▶ 확장 : [1,2,3], b.size = [2,3] ▶ 확장 : [1,2,3] → a_b.size = [2,2,3])
- dim = 1 → unsqueeze(1)로 텐서로 확장한 뒤 dim = 1 으로 병합한 것과 동일 (a.size = [2,3] ▶ 확장 : [2,1,3], b.size = [2,3] ▶ 확장 : [2,1,3] → a_b.size = [2,2,3])
- dim = 2 → unsqueeze(2)로 텐서로 확장한 뒤 dim = 2 으로 병합한 것과 동일 (a.size = [2, 3] ▶ 확장 : [2,3,1], b.size = [2,3] ▶ 확장 : [2,3,1] → a_b.size = [2,3,2])
주의 : 같은 텐서 구조를 가지더라도 텐서의 내용을 자세히 보시면 다릅니다!!!
import torch
print(torch.__version__) # 1.11.0
# 병합할 텐서
a = torch.arange(0,6).reshape(2,3)
print(a.size()) # torch.Size([2, 3])
print(a)
"""
tensor([[0, 1, 2],
[3, 4, 5]])
"""
b = torch.arange(6,12).reshape(2,3)
print(b.size()) # torch.Size([2, 3])
print(b)
"""
tensor([[ 6, 7, 8],
[ 9, 10, 11]])
"""
# 2차원 텐서간 병합
# dim = 0
a_b_dim_0 = torch.stack((a,b),dim=0)
print(a_b_dim_0.size()) # torch.Size([2, 2, 3])
print(a_b_dim_0)
"""
tensor([[[ 0, 1, 2],
[ 3, 4, 5]],
[[ 6, 7, 8],
[ 9, 10, 11]]])
"""
# dim = 1
a_b_dim_1 = torch.stack((a,b),dim=1)
print(a_b_dim_1.size()) # torch.Size([2, 2, 3])
print(a_b_dim_1)
"""
tensor([[[ 0, 1, 2],
[ 6, 7, 8]],
[[ 3, 4, 5],
[ 9, 10, 11]]])
"""
# dim = 2
a_b_dim_2 = torch.stack((a,b),dim=2)
print(a_b_dim_2.size()) # torch.Size([2, 3, 2])
print(a_b_dim_2)
"""
tensor([[[ 0, 6],
[ 1, 7],
[ 2, 8]],
[[ 3, 9],
[ 4, 10],
[ 5, 11]]])
"""
직접 unsqueeze를 해보며 차원의 변화를 알아보시는 것이 이해에 도움이 됩니다!
정리
- cat : 주어진 텐서들을 병합할때 dim 을 기준으로 붙여넣음. 그래서 전체 차원의 크기가 변하지 않음
- stack : 주어진 텐서들을 병합할때 dim 을 확장 시킨 후 붙여넣음. 그래서 전체 차원이 1개 늘어남
차원 변환 과정에서 텐서의 배열이 바뀌게 됨. 병합시 주의할 것!
[ 다음 글 ]
[ Pytorch ] nn.BCELoss, nn.BCEWithLogitsLoss, nn.CrossEntropyLoss, nn.NLLLoss 총정리
[ PyTorch / torchvision ] make_grid() 사용하기
[ PyTorch / torchvision ] draw_bounding_boxes() 사용하기
[ 참고 ]
https://pytorch.org/docs/stable/generated/torch.cat.html
https://pytorch.org/docs/stable/generated/torch.stack.html
반응형
'컴퓨터 언어 > Python_Pytorch' 카테고리의 다른 글
[ PyTorch / torchvision ] draw_segmentation_masks() 사용하기 (0) | 2022.08.31 |
---|---|
[ PyTorch / torchvision ] draw_bounding_boxes() 사용하기 (0) | 2022.08.30 |
[ PyTorch / torchvision ] make_grid() 사용하기 (0) | 2022.08.30 |
[ Pytorch ] nn.BCELoss, nn.BCEWithLogitsLoss, nn.CrossEntropyLoss, nn.NLLLoss 총정리 (0) | 2022.08.21 |
[ Pytorch ] 파이토치 설치하기 (0) | 2021.08.10 |