(12)tensorflow高级操作函数

tech2022-08-13  121

高级操作函数

功能函数代码根据索引号抽样tf.gather(x,index,axis)根据索引号采集多个样本tf.gather_nd(x,index)掩码采样tf.boolean_mask(x, mask, axis)条件取样tf.where(cond,a,b)刷新张量tf.scatter_nd(indices, updates, shape)生成二维网格的采样点坐标tf.meshgrid (x,y)

根据索引号收集数据

tf.gather ,tf.gather(x,index,axis)index使用列表传入 import tensorflow as tf x = tf.random.normal([4,3]) print(x) a = tf.gather(x,[0,2],axis=0) print(a) out: tf.Tensor( [[ 0.6171281 0.04664925 -1.240589 ] [ 0.3884052 -0.13463594 -0.93812966] [ 0.21023595 0.02191296 -0.60263956] [-0.59707534 -0.9989014 0.93808985]], shape=(4, 3), dtype=float32) tf.Tensor( [[ 0.6171281 0.04664925 -1.240589 ] [ 0.21023595 0.02191296 -0.60263956]], shape=(2, 3), dtype=float32) tf.gather_nd(x,index)index 多个列表,指明元素位置 import tensorflow as tf x = tf.random.normal([4,3,5]) print(x) a = tf.gather_nd(x,[[1,2,4],[2,0,4],[2,2,4]]) print(a) out: tf.Tensor( [[[ 0.48533428 -0.4602297 0.10718699 1.4634823 1.0696448 ] [-0.10646328 1.4204148 0.49611762 -0.7021836 -0.78922015] [ 0.27958906 0.77579606 1.1943241 0.13299868 1.8589126 ]] [[-1.2397368 -0.8667516 1.2485082 -0.36232617 -0.01963608] [-0.3774731 -0.9508353 0.17664587 0.8245683 -1.4487312 ] [-0.40236002 1.1351906 0.281336 -1.2188344 -0.5182614 ]] [[-1.3938109 1.2935148 -0.8586935 -1.6478568 -0.7925164 ] [-0.5961709 -0.14946866 -0.34493566 0.70288414 -0.32704452] [ 1.1950675 -0.22109848 0.21102418 -0.3180953 -0.4069654 ]] [[-2.664851 -0.6875616 -0.75624275 1.4966092 -0.70387363] [ 0.19424789 -0.05890714 0.6462316 0.36867452 -1.6294032 ] [-0.2791568 -1.4880462 0.28855824 0.18338643 1.1877102 ]]], shape=(4, 3, 5), dtype=float32) tf.Tensor([-0.5182614 -0.7925164 -0.4069654], shape=(3,), dtype=float32)

掩码方式采样

tf.boolean_mask(x, mask, axis)mask为布尔值列表 import tensorflow as tf x = tf.random.normal([4,3,5]) b = tf.boolean_mask(x,[False,False,True,True],axis=0) print(b) out: tf.Tensor( [[[ 0.9954934 0.2346662 -1.5152481 -0.23908533 1.0049196 ] [-1.9355854 2.198667 0.7091041 0.8390751 -0.539098 ] [-0.8521507 -0.5687699 0.5074792 0.7154239 0.14445351]] [[-2.0306711 -0.02072303 -1.4181378 -0.01865017 -0.26235464] [ 1.1109879 -1.454516 -0.4335605 -1.37627 -0.7934608 ] [-0.22541425 -1.0131035 -0.4386002 0.47543052 0.71290684]]], shape=(2, 3, 5), dtype=float32)

条件取样

tf.where(cond,a,b)cond与a与b的shape相同cond为布尔值张量True选a对应位置元素,false选b对应位置元素 import tensorflow as tf a = tf.random.normal([4,4]) b = tf.random.normal([4,4]) x = tf.where([[True,False,False,False],[True,False,True,False],[True,False,True,False],[True,False,True,False]],a,b) print(a,'\n',b,'\n',x) out: tf.Tensor( [[ 0.528277 -0.934028 -1.82246 0.6171794 ] [ 1.3643358 -0.10889477 -0.42628863 1.3473538 ] [-0.8530917 -1.1393764 1.1789635 -0.47311786] [-0.06923807 -0.7026398 -0.5189239 0.12308616]], shape=(4, 4), dtype=float32) tf.Tensor( [[-0.11811212 -0.8660801 1.0821908 -0.80681664] [ 0.7331498 -0.31200057 -0.7272013 -0.04565765] [ 0.05147053 -0.39878348 1.404608 -0.50364506] [-0.48536187 -0.80131745 -0.72818786 1.5861374 ]], shape=(4, 4), dtype=float32) tf.Tensor( [[ 0.528277 -0.8660801 1.0821908 -0.80681664] [ 1.3643358 -0.31200057 -0.42628863 -0.04565765] [-0.8530917 -0.39878348 1.1789635 -0.50364506] [-0.06923807 -0.80131745 -0.5189239 1.5861374 ]], shape=(4, 4), dtype=float32)

刷新张量的部分数据

tf.scatter_nd(indices, updates, shape)白板的形状通过 shape 参数表示,需要刷新的数据索引号通过 indices 表示,新数据为 updates。 根据 indices 给出的索引位置将 updates 中新的数据依次写入白板中,并返回更新后的结果张量。 import tensorflow as tf indices = tf.constant([[2],[3]]) updata = tf.constant([5,4]) a = tf.scatter_nd(indices,updata,[10]) print(a) out: tf.Tensor([0 0 5 4 0 0 0 0 0 0], shape=(10,), dtype=int32)

生成二维网格的采样点坐标

tf.meshgrid(x,y)将x,y的列表元素枚举法配对 import tensorflow as tf x = tf.linspace(-7.,8,3) y = tf.linspace(-8.,8,3) x,y = tf.meshgrid(x,y) print(x) print(y) out: tf.Tensor( [[-7. 0.5 8. ] [-7. 0.5 8. ] [-7. 0.5 8. ]], shape=(3, 3), dtype=float32) tf.Tensor( [[-8. -8. -8.] [ 0. 0. 0.] [ 8. 8. 8.]], shape=(3, 3), dtype=float32)
最新回复(0)