PyTorch Cheat Sheet¶
1.Imports¶
# General import torch # root package from torch.utils.data import Dataset, Dataloader # dataset representation and loading # Neural Network API from torch import Tensor import torch.nn as nn import torch.nn.functional as F import torch.autograd as autograd import torch.optim as optim from torch.jit import script, trace # Torchscript and JIT torch.jit.trace() @script # ONNX torch.onnx.export(model, dummy data, xxxx.proto) model = onnx.load("alexnet.proto) onnx.checker.check_model(model) onnx.helper.printable_graph(model.graph) # Vision from torchvision import datasets, models, transforms import torchvision.transforms as transforms # Distributed Training import torch.distributed as dist from torch.multiprocessing import Process
2.Tensors¶
import torch # Creation x = torch.randn(*size) x = torch.[ones|zeros](*size) x = torch.tensor(L) y = x.clone() with torch.no_grad(): requires_grad = True # Dimensionality x.size() x = torch.cat(tensor_seq, dim = 0) y = x.view(a, b, ...) y - x.view(-1,a) y = x.transpose(a, b) y = x.permute(*dims) y = x.unsqueeze(dim) y = x.unsqueeze(dim = 2) y = x.squeeze() y = x.squeeze(dim = 1) # Algebra ret = A.mm(B) ret = A.mv(x) x = x.t() # GUP Usage torch.cuda.is_available x = x.cuda() x = x.cpu() if not args.disable_cuda and torch.cuda.is_available(): args.device = torch.device("cuda") else: args.device = torch.device("cpu") net.to(device) x = x.to(device)
3.Deep Learning¶
import torch.nn as nn nn.Linear(m, n) nn.ConvXd(m, n, s) nn.MaxPoolXd(s) nn.BatchNormXd nn.RNN nn.LSTM nn.GRU nn.Dropout(p = 0.5, inplace = False) nn.Dropout2d(p = 0.5, inplace = False) nn.Embedding(num_embeddings, embedding_dim) # Loss Function nn.X # Activation Function nn.X # Optimizers opt = optim.x(model.parameters(), ...) opt.step() optim.X # Learning rate scheduling scheduler = optim.X(optimizer, ...) scheduler.step() optim.lr_scheduler.X
4.Data Utilities¶
# Dataset Dataset TensorDataset Concat Dataset # Dataloaders and DataSamplers DataLoader(dataset, batch_size = 1, ...) sampler.Sampler(dataset, ...) sampler.XSampler where ...