记录一下自己在tensorflow2.x的call()方法中,打算改变张量的形状,所以使用了:
x = tf.reshape(x,(batch_size,-1))
这样的代码,发生报错。
原因是在构建动态图的时候tensorflow2.0内部还没有创建具体的变量(张量)类型,所以此时如果使用上面的操作,那么batch_size维度和-1自动计算该大小的维度都为None,相当于返回shape(None,None),那么进行后续的张量操作将会报错。
解决办法是通过这里不能通过自动计算维度来进行相关的操作,将:
x = tf.reshape(x,(batch_size,-1))
换成:
x = tf.reshape(x,(batch_size,self.dim))
这里的self.dim是具体的维度,可通过计算得到
但是如果这里变成:
x = tf.reshape(x,(batch_size,-1,self.dim))
self.dim是具体的维度,那么这里也是可以的
总之这里只要变换之后不是全部的None即可。