02 Softmax And Classification
获取Fashion-MNIST训练集和读取数据
这是一个多类图像分类数据集。
torchvision包是服务于PyTorch深度学习框架的,主要用来构建计算机视觉模型。torchvision主要由以下几部分构成:
torchvision.datasets: 一些加载数据的函数及常用的数据集接口;
torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;
torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;
torchvision.utils: 其他的一些有用的方法。
获取训练数据集的方法:
class torchvision.datasets.FashionMNIST(root, train=True, transform=None, target_transform=None, download=False)
root(string)– 数据集的根目录,其中存放processed/training.pt和processed/test.pt文件。
train(bool, 可选)– 如果设置为True,从training.pt创建数据集,否则从test.pt创建。
download(bool, 可选)– 如果设置为True,从互联网下载数据并放到root文件夹下。如果root目录下已经存在数据,不会再次下载。
transform(可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.RandomCrop。
target_transform(可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。
然后再读取数据
对多维Tensor按维度操作
gather的用法
交叉熵损失
Last updated
Was this helpful?