Python深度學習筆記(三):使用Keras載入常見的電腦視覺資料集
5 min read · May 7, 2019
以下示範如何用 Keras 載入四種常見的電腦視覺資料集:
- MNIST:手寫數字分類 (10 classes)
- Fashion-MNIST:服裝圖片分類 (10 classes)
- CIFAR-10:小尺寸彩色圖片的物件分類 (10 classes)
- CIFAR-100:小尺寸彩色圖片的物件分類,類別更細 (100 classes)
MNIST
# Example: load the MNIST dataset with Keras, print the array shapes,
# and plot the first nine training images in a 3x3 grid.
from keras.datasets import mnist
from matplotlib import pyplot

# load dataset (train and test splits, each as (images, labels))
(trainX, trainy), (testX, testy) = mnist.load_data()
# summarize loaded dataset
print('Train: X=%s, y=%s' % (trainX.shape, trainy.shape))
print('Test: X=%s, y=%s' % (testX.shape, testy.shape))
# plot first few images
for i in range(9):
    # define subplot: 3 rows x 3 columns, 1-based position
    # (same as the original 330 + 1 + i shorthand, but explicit)
    pyplot.subplot(3, 3, i + 1)
    # plot raw pixel data; a gray colormap renders the digits in grayscale
    # instead of matplotlib's default colormap
    pyplot.imshow(trainX[i], cmap='gray')
# show the figure
pyplot.show()
Fashion-MNIST
# Example: load the Fashion-MNIST dataset with Keras, print the array
# shapes, and plot the first nine training images in a 3x3 grid.
from matplotlib import pyplot
from keras.datasets import fashion_mnist

# load dataset (train and test splits, each as (images, labels))
(trainX, trainy), (testX, testy) = fashion_mnist.load_data()
# summarize loaded dataset
print('Train: X=%s, y=%s' % (trainX.shape, trainy.shape))
print('Test: X=%s, y=%s' % (testX.shape, testy.shape))
# plot first few images
for i in range(9):
    # define subplot: 3 rows x 3 columns, 1-based position
    pyplot.subplot(3, 3, i + 1)
    # plot raw pixel data; a gray colormap renders the clothing images in
    # grayscale instead of matplotlib's default colormap
    pyplot.imshow(trainX[i], cmap='gray')
# show the figure
pyplot.show()
CIFAR-10 Dataset
# Example: load the CIFAR-10 dataset with Keras, print the array shapes,
# and plot the first nine training images in a 3x3 grid.
from matplotlib import pyplot
from keras.datasets import cifar10

# load dataset (train and test splits, each as (images, labels))
(trainX, trainy), (testX, testy) = cifar10.load_data()
# summarize loaded dataset
# (plain ASCII quotes: curly ‘…’ quotes are a SyntaxError in Python)
print('Train: X=%s, y=%s' % (trainX.shape, trainy.shape))
print('Test: X=%s, y=%s' % (testX.shape, testy.shape))
# plot first few images
for i in range(9):
    # define subplot: 3 rows x 3 columns, 1-based position
    pyplot.subplot(3, 3, i + 1)
    # plot raw pixel data; no cmap is passed, which assumes colour
    # (H, W, 3) images that imshow draws as-is -- confirm for your Keras version
    pyplot.imshow(trainX[i])
# show the figure
pyplot.show()
CIFAR-100 Dataset
# Example: load the CIFAR-100 dataset with Keras, print the array shapes,
# and plot the first nine training images in a 3x3 grid.
from matplotlib import pyplot
from keras.datasets import cifar100

# load dataset (train and test splits, each as (images, labels))
(trainX, trainy), (testX, testy) = cifar100.load_data()
# summarize loaded dataset
print('Train: X=%s, y=%s' % (trainX.shape, trainy.shape))
print('Test: X=%s, y=%s' % (testX.shape, testy.shape))
# plot first few images
for i in range(9):
    # define subplot: 3 rows x 3 columns, 1-based position
    pyplot.subplot(3, 3, i + 1)
    # plot raw pixel data; no cmap is passed, which assumes colour
    # (H, W, 3) images that imshow draws as-is -- confirm for your Keras version
    pyplot.imshow(trainX[i])
# show the figure
pyplot.show()