CNTK 的 MNIST 示例中有一个 Python 脚本,它先下载 MNIST 的图像和标签文件,再将其转换为文本格式,以便 CNTK 读取。
文件位置:\cntk\Examples\Image\DataSets\MNIST\mnist_utils.py
import sys
import urllib
import gzip
import shutil
import os
import struct
import numpy as np
# 定义一个读取图像数据的函数
def loadData(src, cimg):
    """Load an MNIST IDX3 image file and return it as a 2-D uint8 array.

    ``src`` is either a URL or a path to an existing local ``.gz`` file.
    A URL is downloaded to ``./delete.me``, which is removed afterwards.
    A local file is read in place and left untouched, which is useful when
    the download is slow.

    Args:
        src: URL or local path of a gzip-compressed IDX3 image file.
        cimg: Number of images the file is expected to contain.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` and shape ``(cimg, 28 * 28)``;
        each row is one image's pixels flattened row by row.

    Raises:
        Exception: if the magic number, image count, image size or payload
            length does not match what is expected.
    """
    if os.path.isfile(src):
        gzfname, downloaded = src, False
    else:
        try:
            from urllib.request import urlretrieve  # Python 3
        except ImportError:
            from urllib import urlretrieve  # Python 2
        print('Downloading ' + src)
        gzfname, _ = urlretrieve(src, './delete.me')
        print('Done.')
        downloaded = True
    try:
        with gzip.open(gzfname) as gz:
            # IDX headers are big-endian. An image file starts with
            # 0x00000803. Decoding explicitly as '>I' keeps the check
            # correct on hosts of either byte order.
            magic = struct.unpack('>I', gz.read(4))[0]
            if magic != 0x803:
                raise Exception('Invalid file: unexpected magic number.')
            # Number of images stored in the file.
            count = struct.unpack('>I', gz.read(4))[0]
            if count != cimg:
                raise Exception('Invalid file: expected {0} entries.'.format(cimg))
            # Image height and width.
            crow = struct.unpack('>I', gz.read(4))[0]
            ccol = struct.unpack('>I', gz.read(4))[0]
            if crow != 28 or ccol != 28:
                raise Exception('Invalid file: expected 28 rows/cols per image.')
            # Payload: one unsigned byte per pixel, so the total length is
            # images * rows * cols.
            n_bytes = cimg * crow * ccol
            buf = gz.read(n_bytes)
            if len(buf) != n_bytes:
                raise Exception('Invalid file: truncated image data.')
            # frombuffer replaces the deprecated binary-mode fromstring.
            # copy() keeps the result writable, as fromstring's was.
            res = np.frombuffer(buf, dtype=np.uint8).copy()
    finally:
        # Only remove the temporary download, never a caller-supplied file.
        if downloaded:
            os.remove(gzfname)
    return res.reshape((cimg, crow * ccol))
# 定义一个读取标签文件的函数
def loadLabels(src, cimg):
    """Load an MNIST IDX1 label file and return the labels as a column.

    ``src`` is either a URL or a path to an existing local ``.gz`` file.
    A URL is downloaded to ``./delete.me``, which is removed afterwards.
    A local file is read in place and left untouched.

    Args:
        src: URL or local path of a gzip-compressed IDX1 label file.
        cimg: Number of labels the file is expected to contain.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` and shape ``(cimg, 1)``.

    Raises:
        Exception: if the magic number, entry count or payload length does
            not match what is expected.
    """
    if os.path.isfile(src):
        gzfname, downloaded = src, False
    else:
        try:
            from urllib.request import urlretrieve  # Python 3
        except ImportError:
            from urllib import urlretrieve  # Python 2
        print('Downloading ' + src)
        gzfname, _ = urlretrieve(src, './delete.me')
        print('Done.')
        downloaded = True
    try:
        with gzip.open(gzfname) as gz:
            # IDX headers are big-endian. A label file starts with 0x00000801.
            # Decoding explicitly as '>I' works regardless of host byte order.
            magic = struct.unpack('>I', gz.read(4))[0]
            if magic != 0x801:
                raise Exception('Invalid file: unexpected magic number.')
            # Number of labels stored in the file.
            count = struct.unpack('>I', gz.read(4))[0]
            if count != cimg:
                raise Exception('Invalid file: expected {0} rows.'.format(cimg))
            # Payload: one unsigned byte per label.
            buf = gz.read(cimg)
            if len(buf) != cimg:
                raise Exception('Invalid file: truncated label data.')
            # copy() keeps the result writable, as np.fromstring's was.
            res = np.frombuffer(buf, dtype=np.uint8).copy()
    finally:
        # Only remove the temporary download, never a caller-supplied file.
        if downloaded:
            os.remove(gzfname)
    # Column vector so it can be hstack-ed in front of the pixel rows.
    return res.reshape((cimg, 1))
if __name__ == "__main__":
    # For each MNIST split: fetch the images and labels, put the label in
    # column 0 in front of the flattened pixel values, and write the rows as
    # tab-separated unsigned integers (one sample per line).
    base_url = 'http://yann.lecun.com/exdb/mnist/'
    # (name used in messages, file-name prefix, sample count, output path).
    # Paths and messages are ASCII-only: a non-ASCII byte in a Python 2 source
    # file without a coding declaration is a SyntaxError.
    splits = (
        ('train', 'train', 60000, r'./../Data/Train-28x28.txt'),
        ('test', 't10k', 10000, r'./../Data/Test-28x28.txt'),
    )
    for name, prefix, count, out_path in splits:
        data = loadData(base_url + prefix + '-images-idx3-ubyte.gz', count)
        labels = loadLabels(base_url + prefix + '-labels-idx1-ubyte.gz', count)
        rows = np.hstack((labels, data))
        print('Writing ' + name + ' text file...')
        np.savetxt(out_path, rows, fmt='%u', delimiter='\t')
        print('Done.')
然而这样有个问题:在国内下载速度太慢。所以我修改了文件,将:
def loadData(src, cimg):
print ('Downloading ' + src)
gzfname, h = urllib.urlretrieve(src, './delete.me') //从URL下载该数据文件
(loadData 和 loadLabels 两个函数中的这几行)均修改为如下形式,即直接传入本地 .gz 文件路径,不再下载:
def loadLabels(gzfname, cimg):
#print ('Downloading ' + src)
#gzfname, h = urlretrieve(src, './delete.me')
然后将 install_mnist.py 文件修改为如下内容。注意:这里调用的 ut.load 和 ut.savetxt 来自较新版本的 mnist_utils.py,上文所示的旧版本中没有这两个函数:
from __future__ import print_function
import mnist_utils as ut
if __name__ == "__main__":
    # Build the CNTK text-format files from MNIST .gz archives that were
    # downloaded in advance into the current directory.
    # Messages are ASCII-only: the file targets Python 2 as well (see the
    # print_function import), where non-ASCII source bytes without a coding
    # declaration are a SyntaxError.
    splits = (
        ('train', 'train', 60000, r'./Train-28x28_cntk_text.txt'),
        ('test', 't10k', 10000, r'./Test-28x28_cntk_text.txt'),
    )
    for name, prefix, count, out_path in splits:
        samples = ut.load('./' + prefix + '-images-idx3-ubyte.gz',
                          './' + prefix + '-labels-idx1-ubyte.gz', count)
        print('Writing ' + name + ' text file...')
        ut.savetxt(out_path, samples)
        print('Done.')
先用其他方式(例如镜像站点或下载工具)把 MNIST 的四个 .gz 数据文件下载到该目录下,这样就避免了下载速度慢的问题。
4条评论