Python/numpy & Pytorch

[Pytorch] tensor 합치기는 방법 cat(), stack()

파송송 2023. 3. 14. 18:47
728x90

'+' 연산자

list

list에서는 '+' 연산자를 쓰면 list가 합쳐진다.

x = [1,2]
x2 = [3,4]

x+x2
[1, 2, 3, 4]

 

Tensor

tensor는 합쳐지지 않고 각 원소마다 더해진다. 이는 같은 차원끼리 더하거나 한 차원이 1일 때만 가능함

x = torch.randint(0, 10,(3,1))
x2 = torch.randint(0, 10,(3,1))
x3 = torch.randint(0, 10,(1,1))

x, x2, x+x2, x+x3
tensor([[4],
         [2],
         [1]])
         
 tensor([[2],
         [1],
         [2]])
         
 tensor([[6],
         [3],
         [3]])
         
 tensor([[5],
         [3],
         [2]])

x = torch.randint(0, 10,(3,1))
x2 = torch.randint(0, 10,(2,1))

x+x2
RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 0

Cat(seq, dim)

concatenate 함수이며 list의 append와 같이 차원을 증가시킴

합치려는 차원을 제외한 다은 차원의 shape은 같아야 함

x = torch.randint(0, 10,(2,5))
x2 = torch.randint(0, 10,(2,7))
x, x2
x3 = torch.cat((x,x2),dim=1)
x3
tensor([[7, 0, 3, 3, 2],
         [9, 9, 8, 9, 3]])
tensor([[0, 1, 2, 9, 7, 9, 8],
    	 [0, 3, 6, 7, 2, 0, 2]])
         
tensor([[7, 0, 3, 3, 2, 0, 1, 2, 9, 7, 9, 8],
	[9, 9, 8, 9, 3, 0, 3, 6, 7, 2, 0, 2]])

x = torch.randint(0, 10,(10,2))
x2 = torch.randint(0, 10,(8,2))
x, x2
x3 = torch.cat((x,x2),dim=0)
x3
(tensor([[0, 2],
         [6, 2],
         [6, 1],
         [5, 1],
         [3, 3],
         [1, 6],
         [3, 2],
         [6, 5],
         [5, 3],
         [7, 2]]),
 tensor([[3, 4],
         [7, 9],
         [2, 7],
         [3, 4],
         [3, 3],
         [0, 8],
         [2, 6],
         [0, 4]]))
         

tensor([[0, 2],
        [6, 2],
        [6, 1],
        [5, 1],
        [3, 3],
        [1, 6],
        [3, 2],
        [6, 5],
        [5, 3],
        [7, 2],
        [3, 4],
        [7, 9],
        [2, 7],
        [3, 4],
        [3, 3],
        [0, 8],
        [2, 6],
        [0, 4]])

Stack(seq, dim)

지정하는 차원으로 확장하여 tensor를 쌓음, 2개의 tensor의 shape이 똑같아야 할 수 있음

x = torch.randint(0, 10,(2,10))
x2 = torch.randint(0, 10,(2,10))
x, x2
x3 = torch.stack((x,x2),dim=0)
x3
(tensor([[4, 3, 0, 7, 3, 6, 1, 4, 7, 4],
         [6, 8, 3, 7, 2, 2, 7, 6, 2, 4]]),
 tensor([[9, 4, 6, 9, 0, 8, 0, 1, 0, 7],
         [5, 4, 8, 0, 2, 1, 6, 3, 1, 3]]))
         
tensor([[[4, 3, 0, 7, 3, 6, 1, 4, 7, 4],
         [6, 8, 3, 7, 2, 2, 7, 6, 2, 4]],

        [[9, 4, 6, 9, 0, 8, 0, 1, 0, 7],
         [5, 4, 8, 0, 2, 1, 6, 3, 1, 3]]])
728x90