tf.cast用法

tech2025-04-22  47

tf.cast:将张量转换为新类型

tf.cast(     x, dtype, name=None )

该操作将x(如果是Tensor)或x.values (如果是SparseTensor或IndexedSlices)强制转换为dtype。

 例子:

import tensorflow as tf with tf.Session() as sess: x = tf.constant([1.8, 2.2], dtype=tf.float32) print(x) b = tf.dtypes.cast(x, tf.int32) print(b) # 输出结果: # Tensor("Const:0", shape=(2,), dtype=float32) # Tensor("Cast:0", shape=(2,), dtype=int32)

操作支持的数据类型(x和dtype的)为: uint8,uint16,uint32,uint64,int8,int16,int32,int64, float16,float32,float64,complex64,complex128,bfloat16。如果从复杂类型(complex64,complex128)转换为实类型,则仅返回x的实部。如果将实类型转换为复杂类型(complex64,complex128),则将返回值的虚部设置为0。这里对复杂类型的处理与numpy的行为相匹配。 

Args

x一个Tensor或者SparseTensor或IndexedSlices数字型的。这可能是uint8,uint16,uint32,uint64,int8,int16,int32, int64,float16,float32,float64,complex64,complex128, bfloat16。dtype目标类型。支持的dtypes列表与x相同。name操作的名称(可选)。

Returns

一个Tensor或SparseTensor或IndexedSlices具有与x相同的形状和dtype相同的类型。

Raises

TypeError如果x无法转换为dtype。
最新回复(0)