
The relationship between eval() and no_grad() in PyTorch

source link: https://blog.csdn.net/yanxiangtianji/article/details/113798888


First of all, the two are fundamentally different.

model.eval() tells every layer inside the model to work in eval mode. It mainly targets special layers such as dropout and batch normalization, which must behave differently in training than in evaluation. This mode switch is used in both phases: model.train() before training, model.eval() before testing.
torch.no_grad() tells the autograd engine not to track operations for differentiation. The point is to speed up computation and save memory. But since no gradients are recorded, backward() is impossible, so it can only be enabled during testing.
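
To make the first point concrete, here is a minimal sketch; the layer, input shape, and dropout probability are arbitrary choices for illustration:

import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)  # stand-in for a dropout layer inside a model
x = torch.ones(1, 4)

drop.train()    # what model.train() sets recursively on every sub-layer
print(drop(x))  # random elements zeroed, survivors scaled by 1/(1-p)

drop.eval()     # what model.eval() sets recursively on every sub-layer
print(drop(x))  # identity: tensor([[1., 1., 1., 1.]])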

Therefore, both need to be used together when evaluating.

import torch

model = ...
dataset = ...
loss_fun = ...

# training
model.train()  # switch all layers (dropout, batchnorm, ...) to training mode
for x, y in dataset:
	model.zero_grad()
	p = model(x)
	l = loss_fun(p, y)
	l.backward()  # needs gradients, so this must run outside no_grad
	# (the optimizer step that would consume the gradients is omitted here)

# evaluating
sum_loss = 0.0
model.eval()  # switch all layers to evaluation mode
with torch.no_grad():  # skip gradient tracking to save time and memory
	for x, y in dataset:
		p = model(x)
		l = loss_fun(p, y)
		sum_loss += l.item()  # .item() turns the scalar tensor into a float
print('total loss:', sum_loss)
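
The gradient-tracking side can be checked directly as well. A small sketch (the tensor shape is an arbitrary choice) showing that results computed inside no_grad carry no history, which is why backward() is impossible there:

import torch

w = torch.randn(3, requires_grad=True)

y = (w * 2).sum()
print(y.requires_grad)  # True: a graph was recorded, y.backward() would work

with torch.no_grad():
	z = (w * 2).sum()
print(z.requires_grad)  # False: no graph, z.backward() would raise a RuntimeError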

In addition, no_grad can also be used as a function decorator, which simplifies the code.

def train(model, dataset, loss_fun):
	model.train()
	for x, y in dataset:
		model.zero_grad()
		p = model(x)
		l = loss_fun(p, y)
		l.backward()

@torch.no_grad()  # the whole function body runs with gradient tracking disabled
def test(model, dataset, loss_fun):
	sum_loss = 0.0
	model.eval()
	for x, y in dataset:
		p = model(x)
		l = loss_fun(p, y)
		sum_loss += l.item()
	return sum_loss

# main block:
model = ...
dataset = ...
loss_fun = ...

# training
train(model, dataset, loss_fun)
# test
sum_loss = test(model, dataset, loss_fun)
print('total loss:', sum_loss)

References:
https://pytorch.org/docs/stable/generated/torch.no_grad.html
https://discuss.pytorch.org/t/model-eval-vs-with-torch-no-grad/19615

