[NumPy] 10. グラフ上のランダムデータそれぞれの最近傍点を結ぶ

NumPy

はじめに

np.arrayのブロードキャストにより各点の2乗距離を求め、np.argsortを使って最近傍点を見出す方法について説明する。

コード

解説

データの生成

一様な乱数x2と、平均0.5, 標準偏差0.2の正規分布に従う乱数x3を生成する。sizeはそれぞれ(100,2)とした。

生成した配列の可視化

2乗距離の計算

ブロードキャストと集約関数を組み合わせて2乗距離を求める。

x2[:,np.newaxis, :] と x2[np.newaxis,:,:]の形はそれぞれ(100, 1, 2), (1, 100, 2)となる。これをブロードキャストで計算することで、各点のすべての組み合わせの、座標ごとの差を計算できる。

差を**2で2乗して、合計することで2乗距離を求める。

np.sum()でaxis=-1とすると、3次元配列(形状=(100,100,2))の[:,:,0]のデータと[:,:,1]のデータの合計を得ることができる。

近傍点の座標

np.argsort()を2乗距離の配列に適用することで各点に最も近い座標のインデックスを取得できる。

最近傍の点の可視化

上記コードで各点における最も近い点との間に線をプロットできる。

コードをダウンロード(.pyファイル)

コードをダウンロード(.ipynbファイル)

参考

コメント