【技術記事】scikit-learnのdatasets.load_digitsしたデータの中身

みなさん、こんにちは、
みむすたーです。

python関連の書籍を読んでいて、一番よくわからないのが、データの構造です。
scikit-learnなどであらかじめ用意された機械学習用のデータは、余計にわからないと思います。

C言語など、配列の宣言時に明確な次元数の指定を必ず行なっているプログラミング言語ならいいのですが、
pythonは宣言時に何の指定もせずに配列を参照できてしまうので、少しややこしいです。

そこで、本記事では、

scikit-learnのload_digits関数の戻り値のデータの中身はどうなっているの?

という疑問にお答えします。

それではいきましょう。

datasets.load_digits()

load_digitsという名前の通り、これは10進数のデータをロードするための関数です。
ただ、ただ単の10進数のデータではなく、8×8のピクセルに手書きされた0〜9の数字です。

datasets.load_digits().dataの中身

以下のコードでdatasets.load_digits().dataのデータの中身をプロットしてみました。

%matplotlib inline
from sklearn import datasets
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import math
import numpy as np

# データを取り出す
datas = datasets.load_digits().data

# 中身データの詳細
print("★datasets.load_digits().dataの中身★")
print("サイズ:" + str(np.shape(datas)))
print("タイプ(リスト):" + str(type(datas)))
print("タイプ(1次元目要素):" + str(type(datas[0])))
print("タイプ(2次元目要素):" + str(type(datas[0][0])))
print("min値:" + str(np.min(datas)))
print("max値:" + str(np.max(datas)))
print("データ:")
print(datas)

# あまりにもデータの数が多いのでデータを間引く
datas = datas[0:36]

# ここからは画像をプロットする
plt.gray()
fig = plt.figure()
row_and_col = math.ceil(math.sqrt(len(datas)))
for i,data in enumerate(datas):
    ax = fig.add_subplot(row_and_col,row_and_col,i + 1)
    ax.imshow(data.reshape(8,8))

結果は以下の通りでした。

ここで上の出力結果について、少し解説しておきます。

datasets.load_digits().dataには、1797個の手書きデータが存在することがわかります。
おそらくscikit-learnのバージョンによっても変わると思うので、必ずこの数ではないかと思います。

1797個の手書きの数字データのそれぞれに、
8×8ピクセル、つまり、64個の要素を持つ配列が格納されていました。
その64個の要素の中にFloat型で0〜16の値が格納されていました。

また、上のソースコードで画像をプロットする際は、
1797個だと多すぎるので画像データとしてプロットしたのは、最初の36個だけにしました。
※私のPC環境では、36個の出力結果を得るだけでも処理が重かった。 (補足)

内容は以上です。

コメント