Module pywander.datasets

Functions

def get_datasets_folder(app_name='test')
def get_datasets_folder(app_name='test'):
    """
    Get the path of the datasets folder.
    """
    path = normalized_path(os.path.join('~', 'Pywander', app_name, 'datasets'))

    return path

Get the path of the datasets folder.
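The helper `normalized_path` is not shown on this page. As a minimal sketch of the path this function is expected to build, the example below assumes that `normalized_path` expands `~` and normalizes separators; the name `sketch_datasets_folder` is hypothetical and not part of pywander.

```python
import os


def sketch_datasets_folder(app_name='test'):
    # Assumption: normalized_path() behaves like expanduser + normpath.
    raw = os.path.join('~', 'Pywander', app_name, 'datasets')
    return os.path.normpath(os.path.expanduser(raw))


# The result is a per-application folder under the user's home directory.
folder = sketch_datasets_folder('demo')
print(folder)
```

Under that assumption, each `app_name` gets its own `datasets` folder, so data for different applications does not collide.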

def get_datasets_path(*args, app_name='test')
def get_datasets_path(*args, app_name='test'):
    """
    Get the path of a dataset file.
    """
    if not args:
        raise Exception('please input the dataset filename.')

    folder_path = get_datasets_folder(app_name=app_name)

    path = os.path.join(folder_path, *args)

    if not os.path.exists(path):
        raise Exception(f'file not exists: {path}')

    if not os.path.isfile(path):
        raise Exception(f'can not find the file: {path}')

    return path

Get the path of a dataset file.
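The function raises in three situations: no path components given, a path that does not exist, and a path that exists but is not a regular file. The sketch below reproduces those checks against a temporary folder. The function `sketch_datasets_path` and its explicit `folder_path` argument are hypothetical, introduced only so the example does not touch the home directory.

```python
import os
import tempfile


def sketch_datasets_path(folder_path, *args):
    # Same checks as get_datasets_path, with the folder passed in explicitly.
    if not args:
        raise Exception('please input the dataset filename.')
    path = os.path.join(folder_path, *args)
    if not os.path.exists(path):
        raise Exception(f'file not exists: {path}')
    if not os.path.isfile(path):
        raise Exception(f'can not find the file: {path}')
    return path


with tempfile.TemporaryDirectory() as root:
    # Create <root>/mnist/mnist_test.csv with one tiny fake row.
    os.makedirs(os.path.join(root, 'mnist'))
    csv_path = os.path.join(root, 'mnist', 'mnist_test.csv')
    with open(csv_path, 'w') as f:
        f.write('7,0,0\n')

    found = sketch_datasets_path(root, 'mnist', 'mnist_test.csv')

    # Collect the message prefix for each failing call.
    errors = []
    for parts in [(), ('mnist', 'missing.csv'), ('mnist',)]:
        try:
            sketch_datasets_path(root, *parts)
        except Exception as exc:
            errors.append(str(exc).split(':')[0])
```

Passing only `'mnist'` fails the last check because the path names a directory rather than a file.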

def load_mnist_csv_data(*args, line_count=-1)
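def load_mnist_csv_data(*args, line_count=-1):
    """
    train_data: https://pjreddie.com/media/files/mnist_train.csv
    test_data: https://pjreddie.com/media/files/mnist_test.csv

    Grayscale images are now defined with 0 as black and 255 as white, with
    integer values from 0 to 255 running from black to white.
    The MNIST data is the reverse of this, so to stay consistent with the modern
    grayscale convention it is best to preprocess the MNIST image data.
    """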
    df = _load_mnist_csv_data(*args, line_count=line_count)

    for index, row in df.iterrows():
        label = row[0]
        value = row[1:].to_numpy('float')
        yield label, value

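train_data: https://pjreddie.com/media/files/mnist_train.csv
test_data: https://pjreddie.com/media/files/mnist_test.csv

Grayscale images are now defined with 0 as black and 255 as white, with integer values from 0 to 255 running from black to white. The MNIST data is the reverse of this, so to stay consistent with the modern grayscale convention it is best to preprocess the MNIST image data.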
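The preprocessing suggested above amounts to replacing each pixel value x with 255 - x. The sketch below shows the row layout (label first, then pixels) and that inversion using only the standard library, so it runs without pandas. The names `iter_rows` and `invert` are hypothetical, and treating `line_count=-1` as "read every row" is an assumption about `_load_mnist_csv_data`, which is not shown on this page.

```python
import csv
import io

# Two tiny fake rows in the MNIST CSV layout: label first, then pixel values.
# Real rows carry 784 pixels; three are used here to keep the example short.
SAMPLE_CSV = "5,0,128,255\n0,255,0,64\n"


def iter_rows(text, line_count=-1):
    # Yield (label, pixels) pairs; a negative line_count means "all rows".
    for i, row in enumerate(csv.reader(io.StringIO(text))):
        if line_count >= 0 and i >= line_count:
            break
        yield int(row[0]), [float(v) for v in row[1:]]


def invert(pixels):
    # Flip the grayscale direction: each value x becomes 255 - x.
    return [255.0 - v for v in pixels]


rows = list(iter_rows(SAMPLE_CSV))
inverted = [(label, invert(pixels)) for label, pixels in rows]
first_only = list(iter_rows(SAMPLE_CSV, line_count=1))
```

The real function yields NumPy arrays rather than lists, in which case the same inversion can be written as `255 - value`.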

def load_mnist_test_data(line_count=-1)
def load_mnist_test_data(line_count=-1):
    return load_mnist_csv_data('mnist', 'mnist_test.csv', line_count=line_count)

def load_mnist_train_data(line_count=-1)
def load_mnist_train_data(line_count=-1):
    return load_mnist_csv_data('mnist', 'mnist_train.csv', line_count=line_count)

def plot_mnist_image(image_data, label, ax=None, **kwargs)
def plot_mnist_image(image_data, label, ax=None, **kwargs):
    if ax is None:
        import matplotlib.pyplot as plt
        ax = plt.gca()

    image_data = image_data.reshape(28, 28)
    title = f"label = {label}"
    image_plot(ax, image_data, title=title, cmap='gray', interpolation='none', **kwargs)
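
The call `image_data.reshape(28, 28)` turns the flat row of 784 pixel values into a 28 by 28 grid in row-major order. As a dependency-free sketch of that layout, the example below performs the same reshape on a plain list and prints a crude text preview in place of the matplotlib plot. The helpers `to_grid` and `ascii_preview` are hypothetical and not part of pywander.

```python
def to_grid(flat, width=28, height=28):
    # Row-major reshape: pixel k lands at row k // width, column k % width.
    if len(flat) != width * height:
        raise ValueError(f'expected {width * height} values, got {len(flat)}')
    return [flat[r * width:(r + 1) * width] for r in range(height)]


def ascii_preview(grid, threshold=128):
    # Text stand-in for the plot: '#' for bright pixels, '.' otherwise.
    return '\n'.join(
        ''.join('#' if v >= threshold else '.' for v in row) for row in grid
    )


# Synthetic image: a single bright vertical stroke in column 14.
flat = [255.0 if k % 28 == 14 else 0.0 for k in range(28 * 28)]
grid = to_grid(flat)
preview = ascii_preview(grid)
print(preview)
```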

Classes

class FashionMNIST (train: bool = True)
class FashionMNIST(torchvision_datasets.FashionMNIST):
    def __init__(self, train: bool = True) -> None:
        super().__init__(get_datasets_folder(), train=train, transform=ToTensor())

Fashion-MNIST (https://github.com/zalandoresearch/fashion-mnist) Dataset.

Args

root : str or pathlib.Path
Root directory of dataset where FashionMNIST/raw/train-images-idx3-ubyte and FashionMNIST/raw/t10k-images-idx3-ubyte exist.
train : bool, optional
If True, creates dataset from train-images-idx3-ubyte, otherwise from t10k-images-idx3-ubyte.
transform : callable, optional
A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
target_transform : callable, optional
A function/transform that takes in the target and transforms it.
download : bool, optional
If True, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.

Ancestors

  • torchvision.datasets.mnist.FashionMNIST
  • torchvision.datasets.mnist.MNIST
  • torchvision.datasets.vision.VisionDataset
  • torch.utils.data.dataset.Dataset
  • typing.Generic

class MnistDataset (train=True, line_count=-1)
class MnistDataset(Dataset):
    def __init__(self, train=True, line_count=-1):
        if train:
            self.df = _load_mnist_csv_data('mnist', 'mnist_train.csv', line_count=line_count)
        else:
            self.df = _load_mnist_csv_data('mnist', 'mnist_test.csv', line_count=line_count)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        value = self.df.iloc[index, 1:].to_numpy(dtype='float') / 255.0
        sample = torch.FloatTensor(value)

        label = self.df.iloc[index, 0]
        target = torch.zeros(10)
        target[label] = 1.0
        return sample, target

    def plot_image(self, index, ax=None):
        value = self.df.iloc[index, 1:].to_numpy(dtype='float')
        label = self.df.iloc[index, 0]

        plot_mnist_image(value, label, ax=ax)

An abstract class representing a Dataset.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite __getitem__, supporting fetching a data sample for a given key.

Subclasses could also optionally overwrite __len__, which is expected to return the size of the dataset by many torch.utils.data.Sampler implementations and the default options of torch.utils.data.DataLoader.

Subclasses could also optionally implement __getitems__, for speedup batched samples loading. This method accepts list of indices of samples of batch and returns list of samples.

Note

torch.utils.data.DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.
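
In `MnistDataset.__getitem__`, each row is split into a sample (pixel values scaled from 0..255 into 0..1) and a target (a one-hot vector of length 10 with 1.0 at the label index). The sketch below shows that transformation on plain Python lists so it runs without torch or pandas. The function `encode_row` is hypothetical; the class itself returns `torch` tensors.

```python
def encode_row(row, num_classes=10):
    # row is [label, pixel_0, pixel_1, ...], as in the MNIST CSV files.
    label = int(row[0])
    sample = [float(v) / 255.0 for v in row[1:]]  # scale pixels into [0, 1]
    target = [0.0] * num_classes                  # one-hot vector
    target[label] = 1.0
    return sample, target


# A fake row with label 3 and three pixels (real rows have 784 pixels).
sample, target = encode_row([3, 0, 51, 255])
```

Note that the target is a one-hot float vector rather than an integer class index, which matters when choosing a loss function.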

Ancestors

  • torch.utils.data.dataset.Dataset
  • typing.Generic

Methods

def plot_image(self, index, ax=None)
def plot_image(self, index, ax=None):
    value = self.df.iloc[index, 1:].to_numpy(dtype='float')
    label = self.df.iloc[index, 0]

    plot_mnist_image(value, label, ax=ax)