该操作将x(如果是Tensor)或x.values (如果是SparseTensor或IndexedSlices)强制转换为dtype。
例子:
import tensorflow as tf with tf.Session() as sess: x = tf.constant([1.8, 2.2], dtype=tf.float32) print(x) b = tf.dtypes.cast(x, tf.int32) print(b) # 输出结果: # Tensor("Const:0", shape=(2,), dtype=float32) # Tensor("Cast:0", shape=(2,), dtype=int32)操作支持的数据类型(x和dtype的)为: uint8,uint16,uint32,uint64,int8,int16,int32,int64, float16,float32,float64,complex64,complex128,bfloat16。如果从复杂类型(complex64,complex128)转换为实类型,则仅返回x的实部。如果将实类型转换为复杂类型(complex64,complex128),则将返回值的虚部设置为0。这里对复杂类型的处理与numpy的行为相匹配。