高级操作函数
功能函数代码
根据索引号抽样tf.gather(x,index,axis)根据索引号采集多个样本tf.gather_nd(x,index)掩码采样tf.boolean_mask(x, mask, axis)条件取样tf.where(cond,a,b)刷新张量tf.scatter_nd(indices, updates, shape)生成二维网格的采样点坐标tf.meshgrid (x,y)
根据索引号收集数据
tf.gather ,tf.gather(x,index,axis)index使用列表传入
import tensorflow
as tf
x
= tf
.random
.normal
([4,3])
print(x
)
a
= tf
.gather
(x
,[0,2],axis
=0)
print(a
)
out
:
tf
.Tensor
(
[[ 0.6171281 0.04664925 -1.240589 ]
[ 0.3884052 -0.13463594 -0.93812966]
[ 0.21023595 0.02191296 -0.60263956]
[-0.59707534 -0.9989014 0.93808985]], shape
=(4, 3), dtype
=float32
)
tf
.Tensor
(
[[ 0.6171281 0.04664925 -1.240589 ]
[ 0.21023595 0.02191296 -0.60263956]], shape
=(2, 3), dtype
=float32
)
tf.gather_nd(x,index)index 多个列表,指明元素位置
import tensorflow
as tf
x
= tf
.random
.normal
([4,3,5])
print(x
)
a
= tf
.gather_nd
(x
,[[1,2,4],[2,0,4],[2,2,4]])
print(a
)
out
:
tf
.Tensor
(
[[[ 0.48533428 -0.4602297 0.10718699 1.4634823 1.0696448 ]
[-0.10646328 1.4204148 0.49611762 -0.7021836 -0.78922015]
[ 0.27958906 0.77579606 1.1943241 0.13299868 1.8589126 ]]
[[-1.2397368 -0.8667516 1.2485082 -0.36232617 -0.01963608]
[-0.3774731 -0.9508353 0.17664587 0.8245683 -1.4487312 ]
[-0.40236002 1.1351906 0.281336 -1.2188344 -0.5182614 ]]
[[-1.3938109 1.2935148 -0.8586935 -1.6478568 -0.7925164 ]
[-0.5961709 -0.14946866 -0.34493566 0.70288414 -0.32704452]
[ 1.1950675 -0.22109848 0.21102418 -0.3180953 -0.4069654 ]]
[[-2.664851 -0.6875616 -0.75624275 1.4966092 -0.70387363]
[ 0.19424789 -0.05890714 0.6462316 0.36867452 -1.6294032 ]
[-0.2791568 -1.4880462 0.28855824 0.18338643 1.1877102 ]]], shape
=(4, 3, 5), dtype
=float32
)
tf
.Tensor
([-0.5182614 -0.7925164 -0.4069654], shape
=(3,), dtype
=float32
)
掩码方式采样
tf.boolean_mask(x, mask, axis)mask为布尔值列表
import tensorflow
as tf
x
= tf
.random
.normal
([4,3,5])
b
= tf
.boolean_mask
(x
,[False,False,True,True],axis
=0)
print(b
)
out
:
tf
.Tensor
(
[[[ 0.9954934 0.2346662 -1.5152481 -0.23908533 1.0049196 ]
[-1.9355854 2.198667 0.7091041 0.8390751 -0.539098 ]
[-0.8521507 -0.5687699 0.5074792 0.7154239 0.14445351]]
[[-2.0306711 -0.02072303 -1.4181378 -0.01865017 -0.26235464]
[ 1.1109879 -1.454516 -0.4335605 -1.37627 -0.7934608 ]
[-0.22541425 -1.0131035 -0.4386002 0.47543052 0.71290684]]], shape
=(2, 3, 5), dtype
=float32
)
条件取样
tf.where(cond,a,b)cond与a与b的shape相同cond为布尔值张量True选a对应位置元素,false选b对应位置元素
import tensorflow
as tf
a
= tf
.random
.normal
([4,4])
b
= tf
.random
.normal
([4,4])
x
= tf
.where
([[True,False,False,False],[True,False,True,False],[True,False,True,False],[True,False,True,False]],a
,b
)
print(a
,'\n',b
,'\n',x
)
out:
tf
.Tensor
(
[[ 0.528277 -0.934028 -1.82246 0.6171794 ]
[ 1.3643358 -0.10889477 -0.42628863 1.3473538 ]
[-0.8530917 -1.1393764 1.1789635 -0.47311786]
[-0.06923807 -0.7026398 -0.5189239 0.12308616]], shape
=(4, 4), dtype
=float32
)
tf
.Tensor
(
[[-0.11811212 -0.8660801 1.0821908 -0.80681664]
[ 0.7331498 -0.31200057 -0.7272013 -0.04565765]
[ 0.05147053 -0.39878348 1.404608 -0.50364506]
[-0.48536187 -0.80131745 -0.72818786 1.5861374 ]], shape
=(4, 4), dtype
=float32
)
tf
.Tensor
(
[[ 0.528277 -0.8660801 1.0821908 -0.80681664]
[ 1.3643358 -0.31200057 -0.42628863 -0.04565765]
[-0.8530917 -0.39878348 1.1789635 -0.50364506]
[-0.06923807 -0.80131745 -0.5189239 1.5861374 ]], shape
=(4, 4), dtype
=float32
)
刷新张量的部分数据
tf.scatter_nd(indices, updates, shape)白板的形状通过 shape 参数表示,需要刷新的数据索引号通过 indices 表示,新数据为 updates。 根据 indices 给出的索引位置将 updates 中新的数据依次写入白板中,并返回更新后的结果张量。
import tensorflow
as tf
indices
= tf
.constant
([[2],[3]])
updata
= tf
.constant
([5,4])
a
= tf
.scatter_nd
(indices
,updata
,[10])
print(a
)
out
:
tf
.Tensor
([0 0 5 4 0 0 0 0 0 0], shape
=(10,), dtype
=int32
)
生成二维网格的采样点坐标
tf.meshgrid(x,y)将x,y的列表元素枚举法配对
import tensorflow
as tf
x
= tf
.linspace
(-7.,8,3)
y
= tf
.linspace
(-8.,8,3)
x
,y
= tf
.meshgrid
(x
,y
)
print(x
)
print(y
)
out
:
tf
.Tensor
(
[[-7. 0.5 8. ]
[-7. 0.5 8. ]
[-7. 0.5 8. ]], shape
=(3, 3), dtype
=float32
)
tf
.Tensor
(
[[-8. -8. -8.]
[ 0. 0. 0.]
[ 8. 8. 8.]], shape
=(3, 3), dtype
=float32
)