02 Softmax And Classification

获取Fashion-MNIST训练集和读取数据

这是一个多类图像分类数据集。

torchvision包是服务于PyTorch深度学习框架的,主要用来构建计算机视觉模型。torchvision主要由以下几部分构成:

  1. torchvision.datasets: 一些加载数据的函数及常用的数据集接口;

  2. torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;

  3. torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;

  4. torchvision.utils: 其他的一些有用的方法。

获取训练数据集的方法:

import torch
import torchvision
import torchvision.transforms as transforms

mnist_train = torchvision.datasets.FashionMNIST(root='/home/kesci/input/FashionMNIST2065', train=True, download=True, transform=transforms.ToTensor())

class torchvision.datasets.FashionMNIST(root, train=True, transform=None, target_transform=None, download=False)

  • root(string)– 数据集的根目录,其中存放processed/training.pt和processed/test.pt文件。

  • train(bool, 可选)– 如果设置为True,从training.pt创建数据集,否则从test.pt创建。

  • download(bool, 可选)– 如果设置为True,从互联网下载数据并放到root文件夹下。如果root目录下已经存在数据,不会再次下载。

  • transform(可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.RandomCrop。

  • target_transform(可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。

然后再读取数据

对多维Tensor按维度操作

gather的用法

交叉熵损失

Last updated

Was this helpful?