本記事では、PyTorchの配列を結合させるcatとstackという二つの機能を紹介します。これらは異なる結合方法を提供しているため、用途によって使い分けることが必要です。
torch.catは既存の次元軸に沿って配列を結合させます。以下の例を見てみましょう。
import torch
testmatrix1 = torch.tensor([[1,2,3],[4,5,6]])
testmatrix2 = torch.tensor([[7,8,9],[10,11,12]])
testmatrix_cat_dim0 = torch.cat((testmatrix1,testmatrix2), dim=0)
testmatrix_cat_dim1 = torch.cat((testmatrix1,testmatrix2), dim=1)
print(testmatrix_cat_dim0)
print(testmatrix_cat_dim0.size())
print(testmatrix_cat_dim1)
print(testmatrix_cat_dim1.size())
testmatrix1とtestmatrix2という二つのSize([2,3])の配列を結合します。
このとき、cat関数の中でdim=0を指定すると、1つ目の次元で配列を結合します。すなわち、2要素×3要素の行列を二つ結合すると、最終的な配列サイズは2要素側が結合するため4要素×3要素になります。
cat関数の中でdim=1を指定すると、2つ目の次元で配列を結合します。すなわち、2要素×3要素の行列を二つ結合すると、最終的な配列サイズは3要素側が結合するため2要素×6要素になります。
dimを指定しない場合にはdim=0がデフォルトになるため、dim=0であれば必ずしも記述する必要はありません。
torch.stackは新しい次元を作成して配列を結合させます。以下の例を見てみましょう。
import torch
testmatrix1 = torch.tensor([[1,2,3],[4,5,6]])
testmatrix2 = torch.tensor([[7,8,9],[10,11,12]])
testmatrix_stack = torch.stack((testmatrix1,testmatrix2))
print(testmatrix_stack)
print(testmatrix_stack.size())
testmatrix1とtestmatrix2という二つのSize([2,3])の配列を結合します。
このときに新しい次元が生成されるため、最終的な配列のサイズは2×2×3になります。2次元配列をStackする場合、Stack後の配列は3次元配列になります。