Is there any way I can download the pre-trained models available in PyTorch to a specific path?
As @dennlinger mentioned in his answer, torch.utils.model_zoo is called internally when you load a pre-trained model.
More specifically, the method torch.utils.model_zoo.load_url() is called every time a pre-trained model is loaded. Its documentation mentions:
The default value of model_dir is $TORCH_HOME/models, where $TORCH_HOME defaults to ~/.torch. The default directory can be overridden with the $TORCH_HOME environment variable.
This can be done as follows:
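import os
import torchvision

# Suppose you want the pre-trained ResNet weights stored in the directory models/resnet
os.environ['TORCH_HOME'] = os.path.join('models', 'resnet')  # set the environment variable before loading the model
# Newer PyTorch releases save the file under $TORCH_HOME/hub/checkpoints
resnet = torchvision.models.resnet18(pretrained=True)  # torchvision >= 0.13 prefers weights=torchvision.models.ResNet18_Weights.DEFAULT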
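To check where a checkpoint will actually land, the directory lookup can be sketched in plain Python. This is a minimal sketch based on my reading of torch/hub.py in recent releases, which differs from the older docs quoted above: $TORCH_HOME falls back to $XDG_CACHE_HOME/torch (i.e. ~/.cache/torch), and weights go to $TORCH_HOME/hub/checkpoints rather than $TORCH_HOME/models. The function name checkpoint_dir is hypothetical, and it ignores torch.hub.set_dir(), which overrides the hub directory from code.

```python
import os
from typing import Mapping, Optional


def checkpoint_dir(env: Optional[Mapping[str, str]] = None) -> str:
    """Hypothetical helper: where torch.hub would store downloaded weights.

    Assumed lookup (recent torch/hub.py):
    TORCH_HOME, else $XDG_CACHE_HOME/torch, else ~/.cache/torch; then hub/checkpoints.
    """
    env = os.environ if env is None else env
    cache_home = env.get("XDG_CACHE_HOME", "~/.cache")
    torch_home = os.path.expanduser(env.get("TORCH_HOME", os.path.join(cache_home, "torch")))
    return os.path.join(torch_home, "hub", "checkpoints")
```

For example, with TORCH_HOME set to models/resnet, checkpoint_dir() returns models/resnet/hub/checkpoints (using the platform's path separator).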
I came across the above solution by raising an issue in the torchvision GitHub repository: https://github.com/pytorch/vision/issues/616
This led to the documentation being updated with the solution mentioned above.
Yes, you can simply copy the URLs and use wget to download them. wget saves into the current directory, so run it from the desired path (or add -P /desired/path). Here's an illustration:
For AlexNet:
$ wget -c https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth
For Google Inception (v3):
$ wget -c https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth
For SqueezeNet:
$ wget -c https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth
For MobileNetV2:
$ wget -c https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
For DenseNet201:
$ wget -c https://download.pytorch.org/models/densenet201-c1103571.pth
For MNASNet1_0:
$ wget -c https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth
For ShuffleNetv2_x1.0:
$ wget -c https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth
If you want to do it in Python, then use something like:
import urllib.request

# ResNet-101 host URL
url = "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth"
# download the file and save it as `resnet_101.pth` in the current directory
filename, headers = urllib.request.urlretrieve(url, "resnet_101.pth")
print(filename)  # resnet_101.pth
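To download several checkpoints into a directory of your choice while keeping each original file name (including the hash suffix that the renaming above drops), a minimal sketch could look like this. The helper names target_path and download_all are hypothetical; unlike wget -c, this does not resume partial downloads, it only skips files that already exist. If torch is installed, torch.hub.load_state_dict_from_url(url, model_dir=...) also downloads into a chosen directory and returns the loaded state dict.

```python
import os
import urllib.request
from typing import Iterable, List
from urllib.parse import urlparse


def target_path(url: str, dest_dir: str) -> str:
    """Hypothetical helper: keep the URL's file name (hash suffix included) inside dest_dir."""
    return os.path.join(dest_dir, os.path.basename(urlparse(url).path))


def download_all(urls: Iterable[str], dest_dir: str) -> List[str]:
    """Hypothetical helper: fetch each URL into dest_dir, skipping files already present."""
    os.makedirs(dest_dir, exist_ok=True)
    paths = []
    for url in urls:
        path = target_path(url, dest_dir)
        if not os.path.exists(path):
            # A partially downloaded file from an earlier run would also be skipped here
            urllib.request.urlretrieve(url, path)
        paths.append(path)
    return paths
```

Usage: download_all(["https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth"], "models/pretrained") saves models/pretrained/alexnet-owt-4df8aa71.pth.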
P.S.: You can find the download URLs in the respective Python modules of torchvision.models (e.g. torchvision/models/resnet.py).
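Files fetched with wget or urlretrieve skip the integrity check that torch.hub.load_state_dict_from_url(..., check_hash=True) can perform. The torch.hub documentation describes a naming convention filename-<sha256>.ext, where <sha256> is the first eight or more hex digits of the file's SHA-256 hash. Assuming the torchvision URLs above follow that convention, a manual download can be checked with a sketch like this; the regex is modeled on the one in torch/hub.py, and the function names are hypothetical.

```python
import hashlib
import os
import re

# Assumed to match torch.hub's pattern: the hex run between a '-' and the next '.'
HASH_REGEX = re.compile(r"-([a-f0-9]*)\.")


def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Hash the file in chunks so large checkpoints are not read into memory at once."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def matches_name_hash(path: str) -> bool:
    """Hypothetical helper: True if the file's SHA-256 starts with the hash in its name."""
    match = HASH_REGEX.search(os.path.basename(path))
    if match is None or not match.group(1):
        return False  # no hash in the name (e.g. a renamed file like resnet_101.pth)
    return sha256_of(path).startswith(match.group(1))
```

Note that a renamed file such as resnet_101.pth carries no hash in its name, so this check cannot be applied to it.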