PyTorch 数据读取与预处理

PyTorch 数据

PyTorch 数据读取

构造自定义的 Datasets, Dataloaders, Transforms

依赖库

from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transform, utils

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion() # interactive mode

读取数据

# Load the annotation table. Column 0 holds the image file name; the
# remaining columns hold that image's landmark coordinates.
landmarks_frame = pd.read_csv("../data/faces/face_landmarks.csv")

# Row index of the sample to inspect.
n = 65
img_name = landmarks_frame.iloc[n, 0]

# ``Series.as_matrix()`` was removed in pandas 1.0, so convert the row with
# ``np.asarray`` instead; it works on both old and new pandas releases.
landmarks = np.asarray(landmarks_frame.iloc[n, 1:])
# Flat coordinate list -> float array with one (x, y) pair per row.
landmarks = landmarks.astype("float").reshape(-1, 2)

print("Image name: {}".format(img_name))
print("Landmarks shape: {}".format(landmarks.shape))
print("First 4 Landmarks: {}".format(landmarks[:4]))
def show_landmarks(image, landmarks):
    """Draw ``image`` and overlay its landmark points as small red dots.

    Args:
        image: Image data accepted by ``plt.imshow``.
        landmarks: 2-D array indexed as ``landmarks[:, 0]`` (x coordinates)
            and ``landmarks[:, 1]`` (y coordinates).
    """
    plt.imshow(image)
    xs = landmarks[:, 0]
    ys = landmarks[:, 1]
    plt.scatter(xs, ys, s=10, marker=".", c="r")
    # Short pause so the figure gets a chance to refresh.
    plt.pause(0.001)

# Open a fresh figure, then load the sample image from disk and display it
# together with its landmarks.
plt.figure()
sample_path = os.path.join("../data/faces/", img_name)
sample_image = io.imread(sample_path)
show_landmarks(sample_image, landmarks)

Dataset class

class FaceLandmarksDataset(Dataset):
   """Face Landmarks dataset.

   Subclass of ``torch.utils.data.Dataset``. Only the constructor
   signature is visible here and its body is a placeholder.
   """
   def __init__(self, csv_file, root_dir, transform = None):
      """Create the dataset (placeholder implementation).

      The visible body is only ``pass``: the arguments are accepted but
      not stored or used.

      Args:
         csv_file: presumably the path to the landmarks CSV file --
            TODO confirm once implemented.
         root_dir: presumably the directory containing the images --
            TODO confirm once implemented.
         transform: presumably an optional callable applied to each
            sample -- TODO confirm once implemented. Defaults to None.
      """
      # NOTE(review): the ``transform`` parameter shadows the module-level
      # ``transform`` import inside this method.
      pass