目录
1. EM算法的实现
2. EM算法估算GMM的参数
3. GMM调参:covariance_type
4. EM算法无监督分类鸢尾花数据
5. GMM/DPGMM(贝叶斯高斯分布)比较
1.EM算法的实现
import numpy as np
from scipy.stats import multivariate_normal
import matplotlib as mpl
import matplotlib.pyplot as plt
# Imported for its side effect: registers the '3d' projection on older Matplotlib.
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
from sklearn.metrics.pairwise import pairwise_distances_argmin
from sklearn.mixture import GaussianMixture

# Chinese titles/labels need a CJK font; unicode_minus=False draws minus
# signs with the ASCII hyphen instead of U+2212.
mpl.rcParams.update({
    'font.sans-serif': [u'SimHei'],
    'axes.unicode_minus': False,
})
if __name__ == '__main__':
    # 'sklearn' fits with GaussianMixture; any other value runs the
    # hand-written EM loop below.
    style = 'myself'
    np.random.seed(0)

    # Synthetic data: 400 samples from component 1, 100 from component 2.
    mu1_fact = (0, 0, 0)
    cov1_fact = np.diag((1, 2, 3))
    data1 = np.random.multivariate_normal(mu1_fact, cov1_fact, 400)
    mu2_fact = (2, 2, 1)
    # NOTE(review): this matrix is not symmetric positive-semidefinite, so
    # np.random.multivariate_normal warns and the samples do not follow a
    # Gaussian with exactly this covariance. Kept so the data match the
    # recorded run.
    cov2_fact = np.array(((1, 1, 3), (1, 2, 1), (0, 0, 1)))
    data2 = np.random.multivariate_normal(mu2_fact, cov2_fact, 100)
    data = np.vstack((data1, data2))
    # True labels: True = drawn from component 1.
    y = np.array([True] * 400 + [False] * 100)

    if style == 'sklearn':
        g = GaussianMixture(n_components=2, covariance_type='full', tol=1e-6, max_iter=1000)
        g.fit(data)
        mu1, mu2 = g.means_
        sigma1, sigma2 = g.covariances_
        # Mixing weight of the component stored in mu1/sigma1 -- the same
        # quantity `pi` holds in the hand-written branch.
        pi = g.weights_[0]
        print('类别概率:\t', pi)
        print('均值:\n', g.means_, '\n')
        print('方差:\n', g.covariances_, '\n')
    else:
        num_iter = 100
        n, d = data.shape
        # Start the two means at opposite corners of the data's bounding box.
        mu1 = data.min(axis=0)
        mu2 = data.max(axis=0)
        sigma1 = np.identity(d)
        sigma2 = np.identity(d)
        pi = 0.5  # mixing weight of component 1
        for _ in range(num_iter):
            # E-step: gamma[j] = posterior probability that sample j
            # belongs to component 1.
            norm1 = multivariate_normal(mu1, sigma1)
            norm2 = multivariate_normal(mu2, sigma2)
            tau1 = pi * norm1.pdf(data)
            tau2 = (1 - pi) * norm2.pdf(data)
            gamma = tau1 / (tau1 + tau2)

            # M-step: responsibility-weighted means, covariances (taken
            # around the freshly updated means) and mixing weight.
            mu1 = np.dot(gamma, data) / np.sum(gamma)
            mu2 = np.dot((1 - gamma), data) / np.sum((1 - gamma))
            sigma1 = np.dot(gamma * (data - mu1).T, data - mu1) / np.sum(gamma)
            sigma2 = np.dot((1 - gamma) * (data - mu2).T, data - mu2) / np.sum(1 - gamma)
            pi = np.sum(gamma) / n
        print('类别概率:\t', pi)
        print('均值:\t', mu1, mu2)
        print('方差:\n', sigma1, '\n\n', sigma2, '\n')

    # Posterior (MAP) assignment: each density must be weighted by its mixing
    # weight, otherwise the decision ignores the unequal component sizes.
    norm1 = multivariate_normal(mu1, sigma1)
    norm2 = multivariate_normal(mu2, sigma2)
    tau1 = pi * norm1.pdf(data)
    tau2 = (1 - pi) * norm2.pdf(data)

    fig = plt.figure(figsize=(13, 7), facecolor='w')
    ax = fig.add_subplot(121, projection='3d')
    ax.scatter(data[:, 0], data[:, 1], data[:, 2], c='b', s=30, marker='o', depthshade=True)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title(u'原始数据', fontsize=18)

    ax = fig.add_subplot(122, projection='3d')
    # Component indices are arbitrary: order[0] is the index of the estimated
    # mean closest to the true mean of component 1, so c1 marks samples
    # predicted to come from true component 1.
    order = pairwise_distances_argmin([mu1_fact, mu2_fact], [mu1, mu2], metric='euclidean')
    print(order)
    if order[0] == 0:
        c1 = tau1 > tau2
    else:
        c1 = tau1 < tau2
    c2 = ~c1
    acc = np.mean(y == c1)
    print(u'准确率:%.2f%%' % (100*acc))
    ax.scatter(data[c1, 0], data[c1, 1], data[c1, 2], c='r', s=30, marker='o', depthshade=True)
    ax.scatter(data[c2, 0], data[c2, 1], data[c2, 2], c='g', s=30, marker='^', depthshade=True)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title(u'EM算法分类', fontsize=18)
    plt.suptitle(u'EM算法的实现', fontsize=21)
    # tight_layout() recomputes the subplot parameters, so it must run before
    # subplots_adjust() reserves the top margin for the suptitle.
    plt.tight_layout()
    plt.subplots_adjust(top=0.90)
    plt.show()
类别概率: 0.7650337783291882
均值: [-0.123994 -0.02138048 -0.06003756] [1.9076683 1.79622192 1.11752474]
方差:
[[ 0.82563399 -0.10180706 -0.0414597 ]
[-0.10180706 2.15816316 -0.16360603]
[-0.0414597 -0.16360603 2.79283956]]
[[0.69690051 0.90370392 0.73552321]
[0.90370392 1.8856117 0.76747618]
[0.73552321 0.76747618 2.94819132]]
[0 1]
准确率:89.80%
2.EM算法估算GMM的参数
import numpy as np
import matplotlib as mpl
import matplotlib.colors
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import train_test_split

# Chinese labels need a CJK font; unicode_minus=False draws minus signs with
# the ASCII hyphen instead of U+2212.
mpl.rcParams.update({
    'font.sans-serif': [u'SimHei'],
    'axes.unicode_minus': False,
})
def expand(a, b, rate=0.05):
    """Return the interval [a, b] widened by ``rate`` of its length on both sides.

    Used to leave a margin around the data range before building the plotting
    grid.

    :param a: lower end of the interval.
    :param b: upper end of the interval.
    :param rate: fraction of ``b - a`` added at each end (default 5%).
    :return: tuple ``(a - d, b + d)`` with ``d = (b - a) * rate``.
    """
    d = (b - a) * rate
    return a - d, b + d
if __name__ == '__main__':
    # Column 0 is the class label, the remaining columns are the features
    # (height and weight, judging by the axis labels below); one header row
    # is skipped. np.float was removed in NumPy 1.24 -- builtin float is the
    # same float64 dtype.
    data = np.loadtxt('HeightWeight.csv', dtype=float, delimiter=',', skiprows=1)
    print(data.shape)
    y, x = np.split(data, [1, ], axis=1)
    x, x_test, y, y_test = train_test_split(x, y, train_size=0.6, random_state=0)

    gmm = GaussianMixture(n_components=2, covariance_type='full', random_state=0)
    gmm.fit(x)
    print('均值 = \n', gmm.means_)
    print('方差 = \n', gmm.covariances_)
    y_hat = gmm.predict(x)
    y_test_hat = gmm.predict(x_test)
    # Component ids are arbitrary. Swap the 0/1 predictions when component 0
    # has the larger first-feature mean (assumes label 0 is the class with the
    # smaller first feature -- TODO confirm against the data file).
    change = (gmm.means_[0][0] > gmm.means_[1][0])
    if change:
        y_hat = 1 - y_hat
        y_test_hat = 1 - y_test_hat
    acc = np.mean(y_hat.ravel() == y.ravel())
    acc_test = np.mean(y_test_hat.ravel() == y_test.ravel())
    acc_str = u'训练集准确率:%.2f%%' % (acc * 100)
    acc_test_str = u'测试集准确率:%.2f%%' % (acc_test * 100)
    print(acc_str)
    print(acc_test_str)

    # Decision regions on a 500x500 grid covering the (padded) training range.
    cm_light = mpl.colors.ListedColormap(['#FF8080', '#77E0A0'])
    cm_dark = mpl.colors.ListedColormap(['r', 'g'])
    x1_min, x1_max = x[:, 0].min(), x[:, 0].max()
    x2_min, x2_max = x[:, 1].min(), x[:, 1].max()
    x1_min, x1_max = expand(x1_min, x1_max)
    x2_min, x2_max = expand(x2_min, x2_max)
    x1, x2 = np.mgrid[x1_min:x1_max:500j, x2_min:x2_max:500j]
    grid_test = np.stack((x1.flat, x2.flat), axis=1)
    grid_hat = gmm.predict(grid_test).reshape(x1.shape)
    if change:
        grid_hat = 1 - grid_hat
    plt.figure(figsize=(9, 7), facecolor='w')
    plt.pcolormesh(x1, x2, grid_hat, cmap=cm_light)
    plt.scatter(x[:, 0], x[:, 1], s=50, c=np.squeeze(y), marker='o', cmap=cm_dark, edgecolors='k')
    plt.scatter(x_test[:, 0], x_test[:, 1], s=60, c=np.squeeze(y_test), marker='^', cmap=cm_dark,
                edgecolors='k')

    # Contours of the posterior probability of GMM component 0 (raw id,
    # independent of the label swap above).
    p = gmm.predict_proba(grid_test)
    np.set_printoptions(suppress=True)
    print(p)
    p = p[:, 0].reshape(x1.shape)
    CS = plt.contour(x1, x2, p, levels=(0.1, 0.5, 0.8), colors=list('rgb'), linewidths=2)
    plt.clabel(CS, fontsize=15, fmt='%.1f', inline=True)

    # Place the accuracy texts near the top-left corner of the axes.
    ax1_min, ax1_max, ax2_min, ax2_max = plt.axis()
    xx = 0.9*ax1_min + 0.1*ax1_max
    yy = 0.1*ax2_min + 0.9*ax2_max
    plt.text(xx, yy, acc_str, fontsize=18)
    yy = 0.15*ax2_min + 0.85*ax2_max
    plt.text(xx, yy, acc_test_str, fontsize=18)
    plt.xlim((x1_min, x1_max))
    plt.ylim((x2_min, x2_max))
    plt.xlabel(u'身高(cm)', fontsize='large')
    plt.ylabel(u'体重(kg)', fontsize='large')
    plt.title(u'EM算法估算GMM的参数', fontsize=20)
    plt.grid()
    plt.show()
(114, 3)
均值 =
[[160.13983374 55.93370575]
[173.50243688 65.03359308]]
方差 =
[[[ 18.82128194 12.30370549]
[ 12.30370549 31.23596113]]
[[ 23.22794989 28.48688647]
[ 28.48688647 105.81824734]]]
训练集准确率:77.94%
测试集准确率:82.61%
[[0.99999775 0.00000225]
[0.99999784 0.00000216]
[0.99999792 0.00000208]
...
[0. 1. ]
[0. 1. ]
[0. 1. ]]
3.GMM调参:covariance_type
import numpy as np
import matplotlib as mpl
import matplotlib.colors
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture

# Chinese labels need a CJK font; unicode_minus=False draws minus signs with
# the ASCII hyphen instead of U+2212.
mpl.rcParams.update({
    'font.sans-serif': [u'SimHei'],
    'axes.unicode_minus': False,
})
def expand(a, b, rate=0.05):
    """Pad the range [a, b] by the fraction ``rate`` of its width at each end.

    Returns the tuple ``(lower, upper)``.
    """
    pad = rate * (b - a)
    lower = a - pad
    upper = b + pad
    return lower, upper
def accuracy_rate(y1, y2):
    """Return the agreement between two binary labelings, up to a label swap.

    Clustering assigns arbitrary component ids, so with two classes an
    agreement below 50% means the ids are swapped; the complement is then the
    accuracy under the better matching.

    :param y1: predicted labels (array-like).
    :param y2: reference labels (array-like, same length as ``y1``).
    :return: fraction of matching labels, at least 0.5.
    """
    # asarray makes plain lists compare element-wise instead of as whole objects.
    acc = np.mean(np.asarray(y1) == np.asarray(y2))
    return acc if acc > 0.5 else 1 - acc
if __name__ == '__main__':
    np.random.seed(0)
    cov1 = np.diag((1, 2))
    print(cov1)
    N1 = 500
    N2 = 300
    # Class 0: N1 samples mapped through m (elongated and correlated);
    # class 1: N2 samples around (-1, 10).
    x1 = np.random.multivariate_normal(mean=(1, 2), cov=cov1, size=N1)
    m = np.array(((1, 1), (1, 3)))
    x1 = x1.dot(m)
    x2 = np.random.multivariate_normal(mean=(-1, 10), cov=cov1, size=N2)
    x = np.vstack((x1, x2))
    y = np.array([0]*N1 + [1]*N2)

    # One 2-component GMM per covariance structure: record the clustering
    # error (up to label swap) and the BIC (lower is better).
    types = ('spherical', 'diag', 'tied', 'full')
    err = np.empty(len(types))
    bic = np.empty(len(types))
    for i, cov_type in enumerate(types):
        gmm = GaussianMixture(n_components=2, covariance_type=cov_type, random_state=0)
        gmm.fit(x)
        err[i] = 1 - accuracy_rate(gmm.predict(x), y)
        bic[i] = gmm.bic(x)
    print('错误率:', err.ravel())
    print('BIC:', bic.ravel())

    # Error bars on the left axis, BIC bars on a twin right axis.
    xpos = np.arange(len(types))
    plt.figure(facecolor='w')
    ax = plt.axes()
    b1 = ax.bar(xpos-0.3, err, width=0.3, color='#77E0A0')
    b2 = ax.twinx().bar(xpos, bic, width=0.3, color='#FF8080')
    # The pyplot calls below act on the current axes, i.e. the twin BIC axis.
    plt.grid(True)
    bic_min, bic_max = expand(bic.min(), bic.max())
    plt.ylim((bic_min, bic_max))
    plt.xticks(xpos, types)
    plt.legend([b1[0], b2[0]], (u'错误率', u'BIC'))
    plt.title(u'不同方差类型的误差率和BIC', fontsize=18)
    plt.show()

    # Refit with the covariance type that minimises BIC and plot its regions.
    optimal = bic.argmin()
    gmm = GaussianMixture(n_components=2, covariance_type=types[optimal], random_state=0)
    gmm.fit(x)
    print('均值 = \n', gmm.means_)
    print('方差 = \n', gmm.covariances_)
    y_hat = gmm.predict(x)
    # Component ids are arbitrary: if the raw ids agree with the true labels
    # on fewer than half of the samples, they are swapped relative to y.
    # (Comparing first-coordinate means does not work here: true class 0 has
    # the LARGER first coordinate.)
    swap = np.mean(y_hat == y) < 0.5

    cm_light = mpl.colors.ListedColormap(['#FF8080', '#77E0A0'])
    cm_dark = mpl.colors.ListedColormap(['r', 'g'])
    x1_min, x1_max = x[:, 0].min(), x[:, 0].max()
    x2_min, x2_max = x[:, 1].min(), x[:, 1].max()
    x1_min, x1_max = expand(x1_min, x1_max)
    x2_min, x2_max = expand(x2_min, x2_max)
    x1, x2 = np.mgrid[x1_min:x1_max:500j, x2_min:x2_max:500j]
    grid_test = np.stack((x1.flat, x2.flat), axis=1)
    grid_hat = gmm.predict(grid_test).reshape(x1.shape)
    if swap:
        grid_hat = 1 - grid_hat
    plt.figure(figsize=(9, 7), facecolor='w')
    plt.pcolormesh(x1, x2, grid_hat, cmap=cm_light)
    plt.scatter(x[:, 0], x[:, 1], s=30, c=y, marker='o', cmap=cm_dark, edgecolors='k')
    plt.xlim((x1_min, x1_max))
    plt.ylim((x2_min, x2_max))
    plt.title(u'GMM调参:covariance_type=%s' % types[optimal], fontsize=20)
    plt.grid()
    plt.show()
[[1 0]
[0 2]]
错误率: [0.385 0.315 0.3 0.00125]
BIC: [7990.71460065 7855.56050855 8006.49834359 6845.79374805]
均值 =
[[ 2.88444448 6.69484552]
[-0.97642254 10.06927801]]
方差 =
[[[ 2.87015473 6.64421303]
[ 6.64421303 18.00318872]]
[[ 0.91302546 -0.04298504]
[-0.04298504 1.9603531 ]]]
4.EM算法无监督分类鸢尾花数据
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.colors
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import pairwise_distances_argmin
from sklearn.mixture import GaussianMixture

# Chinese labels need a CJK font; unicode_minus=False draws minus signs with
# the ASCII hyphen instead of U+2212.
mpl.rcParams.update({
    'font.sans-serif': [u'SimHei'],
    'axes.unicode_minus': False,
})

# Axis labels for the four iris measurements
# (sepal length, sepal width, petal length, petal width).
iris_feature = (u'花萼长度', u'花萼宽度', u'花瓣长度', u'花瓣宽度')
def expand(a, b, rate=0.05):
    """Pad the range [a, b] by the fraction ``rate`` of its width at each end.

    Returns the tuple ``(lower, upper)``.
    """
    pad = rate * (b - a)
    lower = a - pad
    upper = b + pad
    return lower, upper
def _relabel(labels, order):
    """Return a copy of ``labels`` in which component id ``order[i]`` becomes ``i``.

    Every mask is taken from the untouched input, so renaming one id cannot
    clobber another.
    """
    relabeled = labels.copy()
    for i, component in enumerate(order):
        relabeled[labels == component] = i
    return relabeled


if __name__ == '__main__':
    path = 'iris.data'
    data = pd.read_csv(path, header=None)
    # Columns 0-3: the four measurements; column 4: the species name.
    x_prime, y = data[np.arange(4)], data[4]
    y = pd.Categorical(y).codes
    n_components = 3
    feature_pairs = [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]
    cm_light = mpl.colors.ListedColormap(['#FF8080', '#77E0A0', '#A0A0FF'])
    cm_dark = mpl.colors.ListedColormap(['r', 'g', '#6060FF'])
    plt.figure(figsize=(10, 9), facecolor='#FFFFFF')
    for k, pair in enumerate(feature_pairs):
        x = x_prime[pair]
        # True per-class means of the two selected features.
        m = np.array([np.mean(x[y == i], axis=0) for i in range(3)])
        print('实际均值 = \n', m)
        gmm = GaussianMixture(n_components=n_components, covariance_type='full', random_state=0)
        gmm.fit(x)
        print('预测均值 = \n', gmm.means_)
        print('预测方差 = \n', gmm.covariances_)
        # order[i] = GMM component whose mean is closest to the mean of class i.
        order = pairwise_distances_argmin(m, gmm.means_, axis=1, metric='euclidean')
        print('顺序:\t', order)
        y_hat = _relabel(gmm.predict(x), order)
        acc = u'准确率:%.2f%%' % (100*np.mean(y_hat == y))
        print(acc)

        # Decision regions on a 500x500 grid over the padded feature range.
        x1_min, x2_min = x.min()
        x1_max, x2_max = x.max()
        x1_min, x1_max = expand(x1_min, x1_max)
        x2_min, x2_max = expand(x2_min, x2_max)
        x1, x2 = np.mgrid[x1_min:x1_max:500j, x2_min:x2_max:500j]
        grid_test = np.stack((x1.flat, x2.flat), axis=1)
        grid_hat = _relabel(gmm.predict(grid_test), order).reshape(x1.shape)

        plt.subplot(3, 2, k+1)
        plt.pcolormesh(x1, x2, grid_hat, cmap=cm_light)
        plt.scatter(x[pair[0]], x[pair[1]], s=30, c=y, marker='o', cmap=cm_dark, edgecolors='k')
        xx = 0.95 * x1_min + 0.05 * x1_max
        yy = 0.1 * x2_min + 0.9 * x2_max
        plt.text(xx, yy, acc, fontsize=14)
        plt.xlim((x1_min, x1_max))
        plt.ylim((x2_min, x2_max))
        plt.xlabel(iris_feature[pair[0]], fontsize=14)
        plt.ylabel(iris_feature[pair[1]], fontsize=14)
        plt.grid()
    # Recent Matplotlib versions accept tight_layout's arguments by keyword only.
    plt.tight_layout(pad=2)
    plt.suptitle(u'EM算法无监督分类鸢尾花数据', fontsize=20)
    plt.subplots_adjust(top=0.92)
    plt.show()
实际均值 =
[[5.006 3.418]
[5.936 2.77 ]
[6.588 2.974]]
预测均值 =
[[5.01493896 3.4404862 ]
[6.6814044 3.0285628 ]
[5.90114537 2.74385294]]
预测方差 =
[[[0.1194876 0.08969867]
[0.08969867 0.12147459]]
[[0.36087007 0.05158991]
[0.05158991 0.08923683]]
[[0.27544608 0.08866062]
[0.08866062 0.09382524]]]
顺序: [0 2 1]
准确率:79.33%
实际均值 =
[[5.006 1.464]
[5.936 4.26 ]
[6.588 5.552]]
预测均值 =
[[5.0060006 1.46399865]
[6.58888904 5.63329718]
[6.04240777 4.41742864]]
预测方差 =
[[[0.12176525 0.01581631]
[0.01581631 0.0295045 ]]
[[0.48521779 0.36602418]
[0.36602418 0.32601109]]
[[0.28119672 0.23746926]
[0.23746926 0.31503012]]]
顺序: [0 2 1]
准确率:91.33%
实际均值 =
[[5.006 0.244]
[5.936 1.326]
[6.588 2.026]]
预测均值 =
[[5.00607264 0.23754806]
[6.56291563 2.02408174]
[5.94928821 1.32089151]]
预测方差 =
[[[0.1239802 0.01055412]
[0.01055412 0.00914172]]
[[0.41146765 0.0558402 ]
[0.0558402 0.07689828]]
[[0.29249903 0.07999787]
[0.07999787 0.0509237 ]]]
顺序: [0 2 1]
准确率:96.00%
实际均值 =
[[3.418 1.464]
[2.77 4.26 ]
[2.974 5.552]]
预测均值 =
[[3.41800009 1.46400001]
[2.80062882 4.43004172]
[2.97017899 5.56073357]]
预测方差 =
[[[0.14227691 0.01144799]
[0.01144799 0.029505 ]]
[[0.09376548 0.10702236]
[0.10702236 0.34454954]]
[[0.11477629 0.07760424]
[0.07760424 0.38871245]]]
顺序: [0 1 2]
准确率:92.67%
实际均值 =
[[3.418 0.244]
[2.77 1.326]
[2.974 2.026]]
预测均值 =
[[3.41800003 0.244 ]
[2.93629236 1.98607968]
[2.79657869 1.31224583]]
预测方差 =
[[[0.14227697 0.011208 ]
[0.011208 0.011265 ]]
[[0.11263095 0.06192916]
[0.06192916 0.08966439]]
[[0.09554395 0.04869984]
[0.04869984 0.03787478]]]
顺序: [0 2 1]
准确率:93.33%
实际均值 =
[[1.464 0.244]
[4.26 1.326]
[5.552 2.026]]
预测均值 =
[[1.46399926 0.24399973]
[4.32760641 1.36230588]
[5.60395923 2.0545418 ]]
预测方差 =
[[[0.02950475 0.00558391]
[0.00558391 0.01126496]]
[[0.25410036 0.09152045]
[0.09152045 0.05088321]]
[[0.29156635 0.03719346]
[0.03719346 0.07073705]]]
顺序: [0 1 2]
准确率:97.33%
5.GMM/DPGMM(贝叶斯高斯分布)比较
import numpy as np
import scipy as sp
import matplotlib as mpl
import matplotlib.colors
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
from sklearn.mixture import BayesianGaussianMixture, GaussianMixture
def expand(a, b, rate=0.05):
    """Pad the range [a, b] by the fraction ``rate`` of its width at each end.

    Returns the tuple ``(lower, upper)``.
    """
    pad = rate * (b - a)
    lower = a - pad
    upper = b + pad
    return lower, upper
# Chinese titles need a CJK font; unicode_minus=False draws minus signs with
# the ASCII hyphen instead of U+2212.
mpl.rcParams['font.sans-serif'] = [u'SimHei']
mpl.rcParams['axes.unicode_minus'] = False


def _add_cov_ellipse(ax, center, cov, color):
    """Draw the covariance ellipse of one 2-D Gaussian component onto ``ax``.

    :param ax: Matplotlib axes to draw on.
    :param center: component mean, used as the ellipse centre.
    :param cov: 2x2 covariance matrix of the component.
    :param color: fill colour of the semi-transparent ellipse.
    """
    # Explicit submodule import: `import scipy` alone need not load scipy.linalg.
    from scipy import linalg

    # eigh returns the eigenvalues (variances along the principal axes) in
    # ascending order and the eigenvectors as the COLUMNS of `vectors`.
    values, vectors = linalg.eigh(cov)
    # Axis lengths scale with the standard deviation (sqrt of a variance);
    # 2*sqrt(2)*sigma is the scale used in scikit-learn's GMM plotting examples.
    width, height = 2 * np.sqrt(2 * values)
    # Rotate so that `width` lies along the first principal axis; arctan2 also
    # handles a vertical eigenvector without dividing by zero.
    v = vectors[:, 0]
    angle = np.degrees(np.arctan2(v[1], v[0]))
    e = Ellipse(xy=center, width=width, height=height, angle=angle,
                color=color, alpha=0.5, clip_box=ax.bbox)
    ax.add_artist(e)


if __name__ == '__main__':
    np.random.seed(0)
    cov1 = np.diag((1, 2))
    N1 = 500
    N2 = 300
    # Two true clusters: N1 samples around (3, 2) mapped through m
    # (elongated and correlated) and N2 samples around (-1, 10).
    x1 = np.random.multivariate_normal(mean=(3, 2), cov=cov1, size=N1)
    m = np.array(((1, 1), (1, 3)))
    x1 = x1.dot(m)
    x2 = np.random.multivariate_normal(mean=(-1, 10), cov=cov1, size=N2)
    x = np.vstack((x1, x2))
    y = np.array([0]*N1 + [1]*N2)

    # More components than true clusters: the plain GMM uses all of them,
    # while the Dirichlet-process prior may leave some empty.
    n_components = 3
    colors = '#A0FFA0', '#2090E0', '#FF8080'
    cm = mpl.colors.ListedColormap(colors)
    x1_min, x1_max = x[:, 0].min(), x[:, 0].max()
    x2_min, x2_max = x[:, 1].min(), x[:, 1].max()
    x1_min, x1_max = expand(x1_min, x1_max)
    x2_min, x2_max = expand(x2_min, x2_max)
    x1, x2 = np.mgrid[x1_min:x1_max:500j, x2_min:x2_max:500j]
    grid_test = np.stack((x1.flat, x2.flat), axis=1)

    plt.figure(figsize=(9, 9), facecolor='w')
    plt.suptitle(u'GMM/DPGMM比较', fontsize=23)

    # Top panel: GMM fitted by EM.
    ax = plt.subplot(211)
    gmm = GaussianMixture(n_components=n_components, covariance_type='full', random_state=0)
    gmm.fit(x)
    centers = gmm.means_
    covs = gmm.covariances_
    print('GMM均值 = \n', centers)
    print('GMM方差 = \n', covs)
    grid_hat = gmm.predict(grid_test).reshape(x1.shape)
    plt.pcolormesh(x1, x2, grid_hat, cmap=cm)
    plt.scatter(x[:, 0], x[:, 1], s=30, c=y, cmap=cm, marker='o')
    clrs = list('rgbmy')
    for i, (center, cov) in enumerate(zip(centers, covs)):
        _add_cov_ellipse(ax, center, cov, clrs[i])
    plt.xlim((x1_min, x1_max))
    plt.ylim((x2_min, x2_max))
    plt.title(u'GMM', fontsize=20)
    plt.grid(True)

    # Bottom panel: variational GMM with a Dirichlet-process weight prior.
    dpgmm = BayesianGaussianMixture(n_components=n_components, covariance_type='full', max_iter=1000, n_init=5,
                                    weight_concentration_prior_type='dirichlet_process',
                                    weight_concentration_prior=0.1)
    dpgmm.fit(x)
    centers = dpgmm.means_
    covs = dpgmm.covariances_
    print('DPGMM均值 = \n', centers)
    print('DPGMM方差 = \n', covs)
    y_hat = dpgmm.predict(x)
    print(y_hat)

    ax = plt.subplot(212)
    grid_hat = dpgmm.predict(grid_test).reshape(x1.shape)
    plt.pcolormesh(x1, x2, grid_hat, cmap=cm)
    plt.scatter(x[:, 0], x[:, 1], s=30, c=y, cmap=cm, marker='o')
    for i, (center, cov) in enumerate(zip(centers, covs)):
        # Only draw components that were assigned at least one sample.
        if i not in y_hat:
            continue
        _add_cov_ellipse(ax, center, cov, 'm')
    plt.xlim((x1_min, x1_max))
    plt.ylim((x2_min, x2_max))
    plt.title('DPGMM', fontsize=20)
    plt.grid(True)
    plt.tight_layout()
    plt.subplots_adjust(top=0.9)
    plt.show()
GMM均值 =
[[ 3.77430768 5.86579463]
[ 6.0239399 11.61448122]
[-0.98543679 10.0756839 ]]
GMM方差 =
[[[ 1.5383593 3.21210121]
[ 3.21210121 9.04107582]]
[[ 1.6667472 3.58655076]
[ 3.58655076 10.40673433]]
[[ 0.89079177 -0.02572518]
[-0.02572518 1.95106592]]]
DPGMM均值 =
[[ 4.87807808 8.69858646]
[-0.97320511 10.07279749]
[ 2.68191465 9.21436833]]
DPGMM方差 =
[[[ 2.88507577 6.60477348]
[ 6.60477348 17.92760296]]
[[ 0.9632641 -0.02865042]
[-0.02865042 1.98179578]]
[[ 5.104264 1.12039777]
[ 1.12039777 6.21286898]]]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]