Categories: Pytorch

PyTorchのTensor配列の結合方法(catとstack)

本記事では、PyTorchの配列を結合させるcatとstackという二つの機能を紹介します。これらは異なる結合方法を提供しているため、用途によって使い分けることが必要です。

既存の次元軸(dim)に沿って結合するtorch.cat

torch.catは既存の次元軸に沿って配列を結合させます。以下の例を見てみましょう。

import torch
testmatrix1 = torch.tensor([[1,2,3],[4,5,6]])
testmatrix2 = torch.tensor([[7,8,9],[10,11,12]])
testmatrix_cat_dim0 = torch.cat((testmatrix1,testmatrix2), dim=0)
testmatrix_cat_dim1 = torch.cat((testmatrix1,testmatrix2), dim=1)
print(testmatrix_cat_dim0)
print(testmatrix_cat_dim0.size())
print(testmatrix_cat_dim1)
print(testmatrix_cat_dim1.size())

testmatrix1とtestmatrix2という二つのSize([2,3])の配列を結合します。

このとき、cat関数の中でdim=0を指定すると、1つ目の次元で配列を結合します。すなわち、2要素×3要素の行列を二つ結合すると、最終的な配列サイズは2要素側が結合するため4要素×3要素になります。

cat関数の中でdim=1を指定すると、2つ目の次元で配列を結合します。すなわち、2要素×3要素の行列を二つ結合すると、最終的な配列サイズは3要素側が結合するため2要素×6要素になります。

dimを指定しない場合にはdim=0がデフォルトになるため、dim=0であれば必ずしも記述する必要はありません。

新しい次元軸(dim)を生成して結合するtorch.stack

torch.stackは新しい次元を作成して配列を結合させます。以下の例を見てみましょう。

import torch
testmatrix1 = torch.tensor([[1,2,3],[4,5,6]])
testmatrix2 = torch.tensor([[7,8,9],[10,11,12]])
testmatrix_stack = torch.stack((testmatrix1,testmatrix2))
print(testmatrix_stack)
print(testmatrix_stack.size())

testmatrix1とtestmatrix2という二つのSize([2,3])の配列を結合します。

このときに新しい次元が生成されるため、最終的な配列のサイズは2×2×3になります。2次元配列をStackする場合、Stack後の配列は3次元配列になります。

Haruoka

Share
Published by
Haruoka