Module pywander.datasets
Functions
def get_datasets_folder(app_name='test')
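```python
def get_datasets_folder(app_name='test'):
    """
    Get the path of the datasets folder.
    """
    # normalized_path is a pywander helper not shown in this listing
    path = normalized_path(os.path.join('~', 'Pywander', app_name, 'datasets'))
    return path
```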
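Return the path of the datasets folder for `app_name`: `~/Pywander/<app_name>/datasets`, passed through `normalized_path`.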
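A minimal sketch of the resulting folder layout. `normalized_path` is not shown here, so this assumes it expands `~` to the home directory and normalizes the result; `datasets_folder_sketch` is a hypothetical name used only for illustration.

```python
import os


def datasets_folder_sketch(app_name: str = 'test') -> str:
    """Hypothetical stand-in for get_datasets_folder.

    Assumes normalized_path behaves like os.path.expanduser + os.path.normpath.
    """
    raw = os.path.join('~', 'Pywander', app_name, 'datasets')
    return os.path.normpath(os.path.expanduser(raw))


if __name__ == '__main__':
    # with the default app_name the folder is <home>/Pywander/test/datasets
    print(datasets_folder_sketch())
```

Each `app_name` therefore gets its own `datasets` folder under `~/Pywander`.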
def get_datasets_path(*args, app_name='test')
```python
def get_datasets_path(*args, app_name='test'):
    """
    Get the path of a dataset file.
    """
    if not args:
        raise Exception('please provide the dataset filename.')
    folder_path = get_datasets_folder(app_name=app_name)
    # join the path components (e.g. subfolder, filename) onto the datasets folder
    path = os.path.join(folder_path, *args)
    if not os.path.exists(path):
        raise Exception(f'file does not exist: {path}')
    if not os.path.isfile(path):
        raise Exception(f'path is not a file: {path}')
    return path
```
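Return the path of a dataset file. The positional arguments are joined onto the datasets folder of `app_name`; an exception is raised if no arguments are given, if the path does not exist, or if it is not a regular file.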
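The checks above can be exercised without a real datasets folder. The sketch below is a hypothetical helper (`resolve_dataset_file` is not part of pywander) that takes the folder explicitly and mirrors the same three checks. As a deliberate swap, it raises the built-in `ValueError`, `FileNotFoundError` and `IsADirectoryError` instead of a bare `Exception`; all three subclass `Exception`, so code that catches `Exception` still works.

```python
import os
import tempfile


def resolve_dataset_file(folder: str, *parts: str) -> str:
    """Hypothetical helper mirroring get_datasets_path's checks for a given folder."""
    if not parts:
        raise ValueError('please provide the dataset filename.')
    path = os.path.join(folder, *parts)
    if not os.path.exists(path):
        raise FileNotFoundError(f'file does not exist: {path}')
    if not os.path.isfile(path):
        # the path exists but is a directory (or another non-regular file)
        raise IsADirectoryError(f'path is not a file: {path}')
    return path


if __name__ == '__main__':
    with tempfile.TemporaryDirectory() as folder:
        os.makedirs(os.path.join(folder, 'mnist'))
        csv_path = os.path.join(folder, 'mnist', 'mnist_test.csv')
        open(csv_path, 'w').close()  # empty placeholder file
        print(resolve_dataset_file(folder, 'mnist', 'mnist_test.csv') == csv_path)  # True
        try:
            resolve_dataset_file(folder, 'mnist')  # a directory, not a file
        except IsADirectoryError as exc:
            print(exc)
```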
def load_mnist_csv_data(*args, line_count=-1)
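```python
def load_mnist_csv_data(*args, line_count=-1):
    """
    train_data: https://pjreddie.com/media/files/mnist_train.csv
    test_data: https://pjreddie.com/media/files/mnist_test.csv

    Grayscale images are now conventionally defined with 0 as black and 255 as white.
    MNIST uses the opposite convention, so it is best to preprocess the MNIST
    images to match the modern grayscale standard.
    """
    df = _load_mnist_csv_data(*args, line_count=line_count)
    for _, row in df.iterrows():
        # column 0 holds the digit label, the remaining 784 columns the pixel values
        label = row.iloc[0]
        value = row.iloc[1:].to_numpy(dtype='float')
        yield label, value
```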
train_data: https://pjreddie.com/media/files/mnist_train.csv
test_data: https://pjreddie.com/media/files/mnist_test.csv
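Grayscale images are now conventionally defined with 0 as black and 255 as white, using integer values from 0 (black) to 255 (white). MNIST uses the opposite convention (0 is the white background, 255 the black pen stroke), so it is best to preprocess the MNIST images to match the modern grayscale standard.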
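One way to do that preprocessing is to flip every pixel value with `255 - pixel`. This is a minimal sketch; `invert_grayscale` is a hypothetical helper, not something pywander provides, and the generator above does not apply it.

```python
import numpy as np


def invert_grayscale(pixels: np.ndarray, max_value: float = 255.0) -> np.ndarray:
    """Hypothetical preprocessing step: map MNIST's 0=white / 255=black
    convention onto the 0=black / 255=white convention."""
    return max_value - np.asarray(pixels, dtype=float)


if __name__ == '__main__':
    print(invert_grayscale(np.array([0, 255, 100])))
```

Applied to each pixel vector yielded by `load_mnist_csv_data`, this turns the MNIST background (0) into 255, which is white under the modern convention.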
def load_mnist_test_data(line_count=-1)
```python
def load_mnist_test_data(line_count=-1):
    return load_mnist_csv_data('mnist', 'mnist_test.csv', line_count=line_count)
```
def load_mnist_train_data(line_count=-1)
```python
def load_mnist_train_data(line_count=-1):
    return load_mnist_csv_data('mnist', 'mnist_train.csv', line_count=line_count)
```
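`load_mnist_train_data` and `load_mnist_test_data` differ only in the file they pass to `load_mnist_csv_data`. The helper `_load_mnist_csv_data` is not shown, so the sketch below rests on assumptions: the CSV has no header row, the label sits in column 0 with the pixel values after it (the layout of the pjreddie files), and a non-negative `line_count` limits the number of rows read while `-1` reads them all. `iter_mnist_rows` is a hypothetical name.

```python
import io

import pandas as pd


def iter_mnist_rows(source, line_count: int = -1):
    """Hypothetical sketch: yield (label, pixels) pairs from an MNIST-style CSV.

    Assumes there is no header row, the label is in column 0 followed by the
    pixel values, and that line_count=-1 means "read every row".
    """
    nrows = None if line_count < 0 else line_count
    data = pd.read_csv(source, header=None, nrows=nrows).to_numpy(dtype=float)
    for row in data:
        # first field is the label, the remaining fields are pixel values
        yield int(row[0]), row[1:]


if __name__ == '__main__':
    csv_text = "7,0,255,128\n2,10,20,30\n"
    for label, pixels in iter_mnist_rows(io.StringIO(csv_text), line_count=1):
        print(label, pixels)
```

Converting the whole frame with `to_numpy` once avoids the per-row `Series` that `DataFrame.iterrows` builds, which matters for the 60,000-row training file.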
def plot_mnist_image(image_data, label, ax=None, **kwargs)
```python
def plot_mnist_image(image_data, label, ax=None, **kwargs):
    if ax is None:
        import matplotlib.pyplot as plt
        ax = plt.gca()
    image_data = image_data.reshape(28, 28)
    title = f"label = {label}"
    image_plot(ax, image_data, title=title, cmap='gray', interpolation='none', **kwargs)
```
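`image_plot` is a pywander helper not shown here. The sketch below assumes it draws the array with `Axes.imshow` and sets the title, and calls matplotlib directly instead. `plot_mnist_image_sketch` is hypothetical. With `cmap='gray'`, low values render black, so raw (non-inverted) MNIST digits appear white on a black background.

```python
import matplotlib

matplotlib.use('Agg')  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def plot_mnist_image_sketch(image_data, label, ax=None, **kwargs):
    """Hypothetical stand-in for plot_mnist_image using Axes.imshow directly."""
    if ax is None:
        ax = plt.gca()
    # a flat 784-value MNIST row becomes a 28x28 image
    image = np.asarray(image_data).reshape(28, 28)
    ax.imshow(image, cmap='gray', interpolation='none', **kwargs)
    ax.set_title(f"label = {label}")
    return ax
```

For example, `fig, ax = plt.subplots()` followed by `plot_mnist_image_sketch(pixels, 7, ax=ax)` and `fig.savefig(...)` renders one sample to a file.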
Classes
class FashionMNIST (train: bool = True)
```python
class FashionMNIST(torchvision_datasets.FashionMNIST):
    def __init__(self, train: bool = True) -> None:
        super().__init__(get_datasets_folder(), train=train, transform=ToTensor())
```
Fashion-MNIST (https://github.com/zalandoresearch/fashion-mnist) dataset.

Args
- root (str or pathlib.Path): Root directory of the dataset, where `FashionMNIST/raw/train-images-idx3-ubyte` and `FashionMNIST/raw/t10k-images-idx3-ubyte` exist.
- train (bool, optional): If True, creates the dataset from `train-images-idx3-ubyte`, otherwise from `t10k-images-idx3-ubyte`.
- transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed version, e.g. `transforms.RandomCrop`.
- target_transform (callable, optional): A function/transform that takes in the target and transforms it.
- download (bool, optional): If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.

These arguments are documented by the torchvision parent class. This subclass only accepts `train`: it passes `get_datasets_folder()` as `root` and `ToTensor()` as `transform`, and leaves `download` at its default of False, so the raw files must already exist under that folder or torchvision raises a `RuntimeError`.
Ancestors
- torchvision.datasets.mnist.FashionMNIST
- torchvision.datasets.mnist.MNIST
- torchvision.datasets.vision.VisionDataset
- torch.utils.data.dataset.Dataset
- typing.Generic
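Because the wrapper never downloads, a quick pre-check can tell whether the raw files described under `root` are already in place. The sketch below is a hypothetical helper (`fashion_mnist_raw_files_present` is not part of pywander or torchvision); it assumes the four standard decompressed Fashion-MNIST file names under `<root>/FashionMNIST/raw/` and only checks that the files exist, not that their contents are valid.

```python
import os
import tempfile

# Decompressed Fashion-MNIST file names (same naming scheme as MNIST)
FASHION_MNIST_RAW_FILES = (
    'train-images-idx3-ubyte',
    'train-labels-idx1-ubyte',
    't10k-images-idx3-ubyte',
    't10k-labels-idx1-ubyte',
)


def fashion_mnist_raw_files_present(root: str) -> bool:
    """Hypothetical pre-check: do all raw files exist under <root>/FashionMNIST/raw?"""
    raw_folder = os.path.join(root, 'FashionMNIST', 'raw')
    return all(
        os.path.isfile(os.path.join(raw_folder, name))
        for name in FASHION_MNIST_RAW_FILES
    )


if __name__ == '__main__':
    with tempfile.TemporaryDirectory() as root:
        print(fashion_mnist_raw_files_present(root))  # False: nothing created yet
        raw_folder = os.path.join(root, 'FashionMNIST', 'raw')
        os.makedirs(raw_folder)
        for name in FASHION_MNIST_RAW_FILES:
            open(os.path.join(raw_folder, name), 'wb').close()  # empty placeholder files
        print(fashion_mnist_raw_files_present(root))  # True
```

If the check fails for `get_datasets_folder()`, constructing `torchvision.datasets.FashionMNIST(get_datasets_folder(), download=True)` once should fetch the files into that folder.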
class MnistDataset (train=True, line_count=-1)
```python
class MnistDataset(Dataset):
    def __init__(self, train=True, line_count=-1):
        # load the training or the test split
        if train:
            self.df = _load_mnist_csv_data('mnist', 'mnist_train.csv', line_count=line_count)
        else:
            self.df = _load_mnist_csv_data('mnist', 'mnist_test.csv', line_count=line_count)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        # columns 1.. hold the pixels; scale them from 0..255 to 0..1
        value = self.df.iloc[index, 1:].to_numpy(dtype='float') / 255.0
        sample = torch.tensor(value, dtype=torch.float32)
        # column 0 is the digit label; encode it as a one-hot vector of length 10
        label = int(self.df.iloc[index, 0])
        target = torch.zeros(10)
        target[label] = 1.0
        return sample, target

    def plot_image(self, index, ax=None):
        value = self.df.iloc[index, 1:].to_numpy(dtype='float')
        label = self.df.iloc[index, 0]
        plot_mnist_image(value, label, ax=ax)
```
An abstract class representing a `Dataset`.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite `__getitem__`, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite `__len__`, whose return value (the size of the dataset) is expected by many `torch.utils.data.Sampler` implementations and by the default options of `torch.utils.data.DataLoader`. Subclasses could also optionally implement `__getitems__` to speed up loading of batched samples; this method accepts a list of sample indices for a batch and returns a list of samples.

Note: `torch.utils.data.DataLoader` by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.

Ancestors
- torch.utils.data.dataset.Dataset
- typing.Generic
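`MnistDataset` is a map-style dataset: `__len__` returns the number of CSV rows, and `__getitem__` returns a 784-element float tensor scaled to [0, 1] together with a one-hot target of length 10. The sketch below reproduces that item format on a small synthetic DataFrame (`ToyMnistDataset`, `make_synthetic_df` and the random data are hypothetical; no MNIST file is read) and shows how `DataLoader`'s default index sampler batches it.

```python
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset


class ToyMnistDataset(Dataset):
    """Hypothetical stand-in for MnistDataset built on an in-memory DataFrame."""

    def __init__(self, df: pd.DataFrame):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        # columns 1.. are pixels in 0..255, scaled to 0..1
        value = self.df.iloc[index, 1:].to_numpy(dtype='float') / 255.0
        sample = torch.tensor(value, dtype=torch.float32)
        # column 0 is the label, one-hot encoded over 10 classes
        label = int(self.df.iloc[index, 0])
        target = torch.zeros(10)
        target[label] = 1.0
        return sample, target


def make_synthetic_df(n_rows: int = 8, seed: int = 0) -> pd.DataFrame:
    """Random labels and pixels in the MNIST CSV layout (label + 784 pixels)."""
    rng = np.random.default_rng(seed)
    labels = rng.integers(0, 10, size=(n_rows, 1))
    pixels = rng.integers(0, 256, size=(n_rows, 784))
    return pd.DataFrame(np.hstack([labels, pixels]))


if __name__ == '__main__':
    loader = DataLoader(ToyMnistDataset(make_synthetic_df()), batch_size=4, shuffle=True)
    samples, targets = next(iter(loader))
    print(samples.shape, targets.shape)  # torch.Size([4, 784]) torch.Size([4, 10])
```

The default collate function stacks the per-item tensors, so each batch is a `(batch_size, 784)` input and a `(batch_size, 10)` one-hot target, which fits a loss such as mean squared error directly.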
Methods
def plot_image(self, index, ax=None)
```python
def plot_image(self, index, ax=None):
    value = self.df.iloc[index, 1:].to_numpy(dtype='float')
    label = self.df.iloc[index, 0]
    plot_mnist_image(value, label, ax=ax)
```