各位用户为了找寻关于tensorflow 动态获取 BatchSzie 的大小实例的资料费劲了很多周折。这里教程网为您整理了关于tensorflow 动态获取 BatchSzie 的大小实例的相关资料,仅供查阅,以下为您介绍关于tensorflow 动态获取 BatchSzie 的大小实例的详细内容
我就废话不多说了,大家还是直接看代码吧~
? 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30import
tensorflow as tf
import
sys
with tf.variable_scope(
'ha'
):
a1
=
tf.get_variable(
'a'
, shape
=
[], dtype
=
tf.int32)
with tf.variable_scope(
'haha'
):
a2
=
tf.get_variable(
'a'
, shape
=
[], dtype
=
tf.int32)
with tf.variable_scope(
'hahaha'
):
a3
=
tf.get_variable(
'a'
, shape
=
[], dtype
=
tf.int32)
with tf.variable_scope(
'ha'
, reuse
=
True
):
# 不会创建新的变量
a4
=
tf.get_variable(
'a'
, shape
=
[], dtype
=
tf.int32)
sum
=
a1
+
a2
+
a3
+
a4
fts_s
=
tf.placeholder(tf.float32, shape
=
(
None
,
100
), name
=
'fts_s'
)
b
=
tf.zeros(shape
=
(tf.shape(fts_s)[
0
], tf.shape(fts_s)[
1
]))
concat
=
tf.concat(axis
=
1
, values
=
[fts_s, b])
init_op
=
tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op)
for
var
in
tf.global_variables():
print
var.name
import
numpy as np
ft_sample
=
np.ones((
10
,
100
))
con_value
=
sess.run([concat], feed_dict
=
{fts_s: ft_sample})
print
con_value[
0
].shape
results:
ha/a:0 ha/haha/a:0 ha/haha/hahaha/a:0 (10, 200)
小总结:
1: 对于未知的shape, 最常用的就是batch-size 通常是 None 代替, 那么在代码中需要用到实际数据的batch size的时候应该怎么做呢?
可以传一个tensor类型, tf.shape(Name) 返回一个tensor 类型的数据, 然后取batchsize 所在的维度即可. 这样就能根据具体的数据去获取batch size的大小
2: 对于变量命名, 要善于用 variable_scope 来规范化命名, 以及 reuse 参数可以控制共享变量
补充知识:tensorflow RNN 使用动态的batch_size
在使用tensorflow实现RNN模型时,需要初始化隐藏状态 如下:
lstm_cell_1
=
[tf.nn.rnn_cell.DropoutWrapper(tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE),output_keep_prob
=
dropout_keep_prob)
for
_
in
range
(NUM_LAYERS)]
cell_1
=
tf.nn.rnn_cell.MultiRNNCell(lstm_cell_1)
self
.init_state_1
=
cell_1.zero_state(
self
.batch_size,tf.float32)
如果我们直接使用超参数batch_size初始化 在使用模型预测的结果时会很麻烦。我们可以使用动态的batch_size,就是将batch_size作为一个placeholder,在运行时,将batch_size作为输入输入就可以实现根据数据量的大小使用不同的batch_size。
代码实现如下:
self.batch_size = tf.placeholder(tf.int32,[],name='batch_size')
self.state = cell.zero_state(self.batch_size,tf.float32)
以上这篇tensorflow 动态获取 BatchSzie 的大小实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持服务器之家。
原文链接:https://blog.csdn.net/zjm750617105/article/details/82959175