PyTorchのTensor型の基本 -作成方法、参照方法、演算方法-

PyTorchで開発を行う上で、最初に理解する必要があるのは、PyTorch独自のデータ型であるTensor型です。本記事ではTensor型のよく使う使い方をまとめました。

Tensorとは

PyTorchに用意されている特殊なデータ型で、正確にはtorch.Tensorというデータ型です。配列のような型で、多数のデータを入力でき、それらの演算が可能です。また、Tensor型はGPUを使用して演算が可能という特徴があり、近年の深層学習がGPUベースの演算を多く使用することを鑑みると、非常に深層学習と親和性の高いデータ型と言えます。

Tensorを使用するためには、PyTorchが使える環境を構築した上で「import torch」を宣言します。

import torch

Tensorの作成方法

まずは、Tensorの作成方法について紹介します。

0で初期化されたTensorを作成する方法

0で初期化されたTensorを作成するのが「zeros」です。以下に紹介する文法で、任意の次元のTensorを生成することができます。また「dtype」でデータ型を指定することも可能です。データ型を宣言しない場合にはfloat型(torch.FloatTensor)になります。

import torch

testmatrix1 = torch.zeros([5], dtype=torch.float)
testmatrix2 = torch.zeros([3, 4], dtype=torch.float)
testmatrix3 = torch.zeros([2, 2, 2], dtype=torch.int32)
print(testmatrix1)
print(testmatrix2)
print(testmatrix3)

1で初期化されたTensorを作成する方法

0の場合のzerosと同様にonesを利用することで、1で初期化したTensorを生成できます。

import torch

testmatrix = torch.ones([3, 2], dtype=torch.float)
print(testmatrix)

任意の値が代入されたTensorを作成する方法

任意の値でTensorを作成する場合には、以下のように「torch.tensor()」の中に配列を記述すればOKです。

import torch

testmatrix = torch.tensor([[1,2,3],[4,5,6]])
print(testmatrix)

Tensorの値の参照及び取り出しの方法

Tensor型からTensor型を取り出す方法

リストのように直感的にアクセスすることができます。

import torch
testmatrix = torch.tensor([[1,2,3],[4,5,6]])
print(testmatrix[0])        # tensor([1, 2, 3])
print(testmatrix[0][1])     # tensor(2)
print(testmatrix[0,1])      # tensor(2)
print(testmatrix[0,:])      # tensor([1, 2, 3])
print(testmatrix[:,1])      # tensor([2, 5])

testmatrix[0][1]とtestmatrix[0,1]は同じ結果となり、1行2列目の要素にアクセスできます。

print(testmatrix[0,:])とすると、1行目(インデックスは0からスタートするので0)の要素を全て取り出すことができます。コロン(:)は全ての要素を参照することを示します。

Tensor型から数値を取り出す方法

前述の方法だと、取得した結果もTensor型になってしまいます。Tensor型ではなく数値として取得したい場合にはitem()が使えます。

import torch
testmatrix = torch.tensor([[1,2,3],[4,5,6]])
print(testmatrix[0][1])         # tensor(2)
print(testmatrix[0][1].item())  # 2

Tensor型の基本的な演算

四則演算などの基本的な演算を行うことができます。

加算及び減算

「testmatrix1 + 1」のようにスカラーを足すと全ての要素に1が加算されます。

「testmatrix1 + testmatrix2」と「torch.add(testmatrix1,testmatrix2)」は同一で、各要素が加算されます。サイズが違うもの同士を加減算してしまうとエラーが出ます。

import torch
testmatrix1 = torch.tensor([[1,2],[3,4]])
testmatrix2 = torch.tensor([[4,3],[2,1]])
print(testmatrix1 + 1)                     # tensor([[2, 3],[4, 5]])
print(testmatrix1 + testmatrix2)           # tensor([[5, 5],[5, 5]])
print(torch.add(testmatrix1,testmatrix2))  # tensor([[5, 5],[5, 5]])
print(testmatrix1 - testmatrix2)           # tensor([[-3, -1],[1, 3]])
print(torch.sub(testmatrix1,testmatrix2))  # tensor([[-3, -1],[1, 3]])

乗算

掛け算は、加減算と同様にtestmatrix1 * 2のようにスカラーを掛けると全ての要素が2倍されます。

testmatrix1 * testmatrix2のように二つのTensorを「*」を使って乗算した場合、各要素の積が計算されます。torch.mul(testmatrix1,testmatrix2)も同様です。直感的な行列演算における乗算とは異なるので注意してください。

行列同士の掛け算をする場合には「testmatrix1 @ testmatrix2」のように@で表記します。torch.mm(testmatrix1,testmatrix2)も同様です。

import torch
testmatrix1 = torch.tensor([[1,0],[0,1]])
testmatrix2 = torch.tensor([[4,3],[2,1]])
print(testmatrix1 * 2)                     # tensor([[2, 0],[0, 2]])
print(testmatrix1 * testmatrix2)           # tensor([[4, 0],[0, 1]])
print(torch.mul(testmatrix1,testmatrix2))  # tensor([[4, 0],[0, 1]])
print(testmatrix1 @ testmatrix2)           # tensor([[4, 3],[2, 1]])
print(torch.mm(testmatrix1,testmatrix2))   # tensor([[4, 3],[2, 1]])

除算

基本は乗算と同じように除算も計算できます。

import torch
testmatrix1 = torch.tensor([[4,4],[2,2]])
testmatrix2 = torch.tensor([[4,4],[2,2]])
print(testmatrix1 / 2)                     # tensor([[2, 2],[1, 2]])
print(testmatrix1 / testmatrix2)           # tensor([[1, 1],[1, 1]])
print(torch.div(testmatrix1,testmatrix2))  # tensor([[1, 1],[1, 1]])

応用機能

内積(dot)

1次元のTensorを2つの内積をdot(v1,v2)で計算することができます。以下の例だと14=(1×1)+(2×2)+(3×3)と計算されます。

import torch
testmatrix1 = torch.tensor([1,2,3])
testmatrix2 = torch.tensor([1,2,3])
print(torch.dot(testmatrix1,testmatrix2))  # tensor(14)

平均・標準偏差(mean, std)

平均や標準偏差を出力する機能も用意されています。.mean()で平均を、.std()で標準偏差を出力します。

以下の例ではtestmatrix.mean(dim=0)とした場合には、1と3の平均及び2と4の平均が計算されます([2, 3])。testmatrix.mean(dim=1)とした場合には、1と2の平均及び3と4の平均が計算されます([1.5, 3.5])。

import torch
testmatrix = torch.FloatTensor([[1,2],[3,4]])
print(testmatrix.mean())      # tensor(2.5)
print(testmatrix.mean(dim=0)) # tensor([2,3])
print(testmatrix.mean(dim=1)) # tensor([1.5,3.5])
print(testmatrix.std())       # tensor(1.291)
print(testmatrix.std(dim=0))  # tensor([1.414,1.414])
print(testmatrix.std(dim=1))  # tensor([0.707,0.707])
スポンサーリンク

シェアする

フォローする