如何確保 PyTorch 訓練的結果具有再現性？

import random

import numpy as np
import torch
from torch.utils.data import DataLoader
def set_seed(seed: int = 42, loader=None) -> None:
    """Seed every random number generator used during training.

    Configures cuDNN for deterministic behaviour and seeds PyTorch (CPU and
    all CUDA devices), NumPy and Python's ``random`` module with ``seed``.

    Args:
        seed: Value used to seed all generators.
        loader: Optional data loader. If ``loader.sampler.generator`` exists
            and is not ``None``, it is re-seeded with ``seed`` as well; any
            other object (including ``None``) is silently ignored.
    """
    # Force deterministic cuDNN kernels and disable the autotuner, which may
    # otherwise pick different algorithms from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Best effort: re-seed the sampler's generator only when one is attached.
    sampler = getattr(loader, "sampler", None)
    generator = getattr(sampler, "generator", None)
    if generator is not None:
        generator.manual_seed(seed)
def seed_worker(worker_id):
    """Seed NumPy and Python's ``random`` from the current PyTorch seed.

    Meant to be passed as ``worker_init_fn`` to a ``DataLoader`` so that each
    worker process derives its NumPy / ``random`` state from the seed PyTorch
    assigned to it.

    Args:
        worker_id: Index of the worker; unused, but required by the
            ``worker_init_fn`` call signature.
    """
    # NumPy only accepts 32-bit seeds, so fold the 64-bit torch seed down.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# Dedicated generator so the loader's shuffling order is reproducible and
# independent of the global torch RNG state.
g = torch.Generator()
g.manual_seed(0)

# Keep a reference to the loader; `seed_worker` re-seeds NumPy / random in
# every worker process.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    worker_init_fn=seed_worker,
    generator=g,
)

--

--

--

Machine Learning | Deep Learning | https://linktr.ee/yanwei

Love podcasts or audiobooks? Learn on the go with our new app.

Get the Medium app

A button that says 'Download on the App Store', and if clicked it will lead you to the iOS App store
A button that says 'Get it on, Google Play', and if clicked it will lead you to the Google Play store
Yanwei Liu

Yanwei Liu

Machine Learning | Deep Learning | https://linktr.ee/yanwei

More from Medium

Activation functions | Deep Learning

Activation Function | Deep learning

Python Classes and Their Use in Keras

Pytorch for Beginners 💫

Addition and Subtraction using Recurrent Neural Networks.