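Example for CIFAR: the function below tiles the images into a single grid and displays it. To write the figure to a file, use plt.savefig(fname, format='png', dpi=1000)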
import numpy as np
import matplotlib.pyplot as plt
def reshape_and_print(self, cifar_data):
    """Tile a batch of flattened CIFAR images into one grid and display it.

    Each row of ``cifar_data`` is split into three equally sized,
    consecutive chunks that end up as the last (channel) axis of the
    displayed picture.  The grid side is the truncated square root of the
    number of rows and the image side is the truncated square root of a
    third of the row length; if either is not a perfect square the first
    reshape raises ``ValueError``.

    Args:
        cifar_data: 2-D array with one flattened image per row.
    """
    count, row_len = cifar_data.shape[0], cifar_data.shape[1]
    grid = np.sqrt(count).astype(np.int32)
    side = np.sqrt(row_len // 3).astype(np.int32)

    # (count, 3, pixels) -> (3, count, pixels): channel axis goes first.
    planes = cifar_data.reshape(grid * grid, 3, side * side)
    planes = planes.transpose(1, 0, 2)

    # Split into grid cells, then interleave grid rows with image rows so
    # that the following reshape yields one contiguous picture per channel.
    cells = planes.reshape(3, grid, grid, side, side)
    cells = cells.swapaxes(2, 3)

    # Collapse to (3, H, W) and move channels last, as imshow expects.
    mosaic = cells.reshape(3, grid * side, grid * side).transpose(1, 2, 0)

    plt.imshow(mosaic)
    plt.show()
The complete module, with a common base class and CIFAR and MNIST subclasses:
import gzip
import pickle
import numpy as np
import matplotlib.pyplot as plt
class DataSet(object):
    """Base class that loads a dataset and splits it into three subsets.

    Subclasses are expected to override :meth:`load_data`,
    :meth:`reshape_for_print` and :meth:`show_all_imgs`; the versions here
    only raise ``NotImplementedError``.

    After construction the instance exposes ``train``, ``valid`` and
    ``test`` attributes set by :meth:`split_data`.
    """

    def __init__(self, seed=42, setsize=10000):
        """Seed NumPy's global RNG, load the data and split it.

        Args:
            seed: Value stored on the instance and passed to
                ``np.random.seed``.
            setsize: Number of samples requested for each subset.
        """
        self.seed = seed
        # Seeds the *global* NumPy generator, so the permutation drawn in
        # split_data is reproducible for a given seed.
        np.random.seed(seed)
        train_set, test_set = self.load_data()
        self.split_data(train_set, test_set, setsize)

    def split_data(self, data_set, test_set, split_size):
        """Populate ``self.train``, ``self.valid`` and ``self.test``.

        ``data_set`` is shuffled along its first axis; the first
        ``split_size`` shuffled rows become the training subset and the
        next ``split_size`` rows the validation subset.  The test subset is
        the first ``split_size`` rows of ``test_set`` in their original
        order.  Slicing is used throughout, so a subset is silently shorter
        when not enough rows are available.

        Args:
            data_set: Array indexable by an integer index array.
            test_set: Array (or other sliceable sequence) of test samples.
            split_size: Requested number of rows per subset.
        """
        permutation = np.random.permutation(data_set.shape[0])
        self.train = data_set[permutation[:split_size]]
        self.valid = data_set[permutation[split_size:split_size * 2]]
        self.test = test_set[:split_size]

    def reshape_for_print(self, data):
        """Arrange ``data`` into a displayable image; subclass hook.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        # NotImplemented is a comparison sentinel, not an exception;
        # raising it yields a TypeError, so raise the real exception type.
        raise NotImplementedError

    def load_data(self):
        """Return a ``(train_set, test_set)`` pair; subclass hook.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError

    def show_all_imgs(self, data):
        """Display ``data`` as an image; subclass hook.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError
class CIFAR(DataSet):
    """CIFAR-100 data read from the pickled ``cifar-100-python`` files."""

    def load_data(self):
        """Read the train and test files and return their pixel arrays.

        Returns:
            ``(train_set, test_set)``: the ``'data'`` entry of each pickled
            file, converted to ``float32`` and divided by 255.
        """
        # NOTE: pickle executes arbitrary code on load; only point these
        # paths at files from a trusted source.
        arrays = []
        for path in ('./data/cifar-100-python/train',
                     './data/cifar-100-python/test'):
            with open(path, 'rb') as handle:
                batch = pickle.load(handle, encoding='latin1')
            arrays.append(batch['data'].astype(np.float32) / 255.0)
        train_set, test_set = arrays
        return train_set, test_set

    def reshape_for_print(self, data):
        """Tile flattened images into one ``(H, W, 3)`` grid picture.

        Each row of ``data`` is split into three equally sized, consecutive
        chunks that become the last axis of the result.  Grid side and image
        side are truncated square roots of the row count and of a third of
        the row length; non-square sizes make the first reshape fail.

        Args:
            data: 2-D array with one flattened image per row.

        Returns:
            Array of shape ``(grid * side, grid * side, 3)``.
        """
        grid = np.sqrt(data.shape[0]).astype(np.int32)
        side = np.sqrt(data.shape[1] // 3).astype(np.int32)

        # (count, 3, pixels) -> (3, count, pixels): channel axis first.
        planes = data.reshape(grid * grid, 3, side * side)
        planes = planes.transpose(1, 0, 2)

        # Interleave grid rows with image rows before flattening to 2-D.
        cells = planes.reshape(3, grid, grid, side, side)
        cells = cells.swapaxes(2, 3)

        # Channels last, as expected by imshow.
        return cells.reshape(3, grid * side, grid * side).transpose(1, 2, 0)

    def show_all_imgs(self, data):
        """Display all images in ``data`` as a single grid."""
        plt.imshow(self.reshape_for_print(data))
        plt.show()
class MNIST(DataSet):
    """MNIST data read from the gzipped pickle ``mnist.pkl.gz``."""

    def load_data(self):
        """Read the archive and return the train and test inputs.

        The pickle is unpacked as three entries (train, valid, test); the
        first element of the train and test entries is returned and the
        validation entry is ignored.

        Returns:
            ``(train_inputs, test_inputs)``.
        """
        # NOTE: pickle executes arbitrary code on load; only use a trusted
        # archive here.
        with gzip.open('./data/mnist.pkl.gz', 'rb') as archive:
            splits = pickle.load(archive, encoding='latin1')
        train_pair, _unused_valid, test_pair = splits
        return train_pair[0], test_pair[0]

    def reshape_for_print(self, data):
        """Tile flattened single-channel images into one 2-D grid picture.

        Grid side and image side are the truncated square roots of the row
        count and the row length; non-square sizes make the reshape fail.

        Args:
            data: 2-D array with one flattened image per row.

        Returns:
            Array of shape ``(grid * side, grid * side)``.
        """
        grid = np.sqrt(data.shape[0]).astype(np.int32)
        side = np.sqrt(data.shape[1]).astype(np.int32)

        # Swap the grid-column and image-row axes so each grid row's
        # images sit side by side once flattened to 2-D.
        cells = data.reshape(grid, grid, side, side).swapaxes(1, 2)
        return cells.reshape(grid * side, grid * side)

    def show_all_imgs(self, data):
        """Display all images in ``data`` as a single grayscale grid."""
        mosaic = self.reshape_for_print(data)
        plt.imshow(mosaic, cmap=plt.cm.gray)
        plt.show()