# PyTorch 数据读取与预处理
# PyTorch 数据
# PyTorch 数据读取
# 构造自定义的 Datasets、DataLoaders 和 Transforms
# 依赖库
from __future__ import print_function, division

import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from skimage import io, transform
from torch.utils.data import Dataset, DataLoader
# torchvision exposes `transforms` (plural). `from torchvision import transform`
# raises ImportError, and that name would also shadow `skimage.transform` above.
from torchvision import transforms, utils

# Ignore warnings so the notebook output stays readable.
warnings.filterwarnings("ignore")

# Enable matplotlib interactive mode.
plt.ion()
# 读取数据
landmarks_frame = pd.read_csv("../data/faces/face_landmarks.csv")

# Inspect a single sample row. Column 0 holds the image file name; the
# remaining columns hold that image's landmark coordinates, flattened.
n = 65
img_name = landmarks_frame.iloc[n, 0]
# `Series.as_matrix()` was removed in pandas 1.0; `to_numpy()` is the
# supported replacement (available since pandas 0.24).
landmarks = landmarks_frame.iloc[n, 1:].to_numpy()
# Regroup the flat coordinate list into rows of (x, y) pairs.
landmarks = landmarks.astype("float").reshape(-1, 2)

print("Image name: {}".format(img_name))
print("Landmarks shape: {}".format(landmarks.shape))
print("First 4 Landmarks: {}".format(landmarks[:4]))
def show_landmarks(image, landmarks):
    """Draw ``image`` and overlay its landmarks as small red dots.

    Args:
        image: image array to display with ``plt.imshow``.
        landmarks: 2-D array with one landmark per row; column 0 is plotted
            as the x coordinate and column 1 as the y coordinate.
    """
    plt.imshow(image)
    xs = landmarks[:, 0]
    ys = landmarks[:, 1]
    plt.scatter(xs, ys, s=10, marker=".", c="r")
    # Brief pause gives matplotlib a chance to redraw the figure.
    plt.pause(0.001)
# Show the selected sample (row `n`) with its landmarks on a new figure.
plt.figure()
show_landmarks(
    io.imread(os.path.join("../data/faces/", img_name)),
    landmarks,
)
# Dataset class
class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset.

    Subclass of ``torch.utils.data.Dataset``. In the portion visible here,
    only the constructor is defined, and it is a stub.
    """
    def __init__(self, csv_file, root_dir, transform = None):
        """Create the dataset (currently a stub).

        Args:
            csv_file: presumably the path to the landmarks CSV (like
                ``../data/faces/face_landmarks.csv``); TODO confirm.
            root_dir: presumably the directory containing the images;
                TODO confirm.
            transform: optional; presumably a callable applied to each
                sample; TODO confirm. Defaults to None.
        """
        # NOTE(review): stub. All three arguments are discarded and no
        # attributes are stored, so later methods cannot rely on them.
        pass