CNTK研究(一):MNIST的文件转换

CNTK的MNIST例子中,它是一个py文件,首先是将图像文件转换为文本来表达,以方便CNTK中读取。

文件位置:\cntk\Examples\Image\DataSets\MNIST\mnist_utils.py

import sys
import urllib
import gzip
import shutil
import os
import struct
import numpy as np

//定义一个读数据的函数

def loadData(src, cimg):
    print ('Downloading ' + src)
    gzfname, h = urllib.urlretrieve(src, './delete.me')  //从URL下载该数据文件
    print ('Done.')
    try:
        with gzip.open(gzfname) as gz:         //使用gzip打开它
            n = struct.unpack('I', gz.read(4))    //读4个字节解包成无符号整型
            # Read magic number.
            if n[0] != 0x3080000:     //如果文件开头不对,则认为文件不对
                raise Exception('Invalid file: unexpected magic number.')
            # Read number of entries.
            n = struct.unpack('>I', gz.read(4))[0]  //再读4个字节解包成无符号整型
            if n != cimg:       //如果不属于图像文件,则抛出异常
                raise Exception('Invalid file: expected {0} entries.'.format(cimg)) 
            crow = struct.unpack('>I', gz.read(4))[0] //读取4个数据为行
            ccol = struct.unpack('>I', gz.read(4))[0]  //读取4个字节为列
            if crow != 28 or ccol != 28:    //如果行与列不等于28,说明图像文件有误
                raise Exception('Invalid file: expected 28 rows/cols per image.')
            # Read data.  //读取数据,后面需要读取的大小为 行乘以列再乘以表达每个象素需要多少字节便是总长度
            res = np.fromstring(gz.read(cimg * crow * ccol), dtype = np.uint8) //
    finally:
        os.remove(gzfname)
    return res.reshape((cimg, crow * ccol))  //返回数据时,排成2维数组表达

//读取标签文件

def loadLabels(src, cimg):
    print 'Downloading ' + src
    gzfname, h = urllib.urlretrieve(src, './delete.me')
    print 'Done.'
    try:
        with gzip.open(gzfname) as gz:
            n = struct.unpack('I', gz.read(4))
            # Read magic number.
            if n[0] != 0x1080000:
                raise Exception('Invalid file: unexpected magic number.')
            # Read number of entries.
            n = struct.unpack('>I', gz.read(4))
            if n[0] != cimg:
                raise Exception('Invalid file: expected {0} rows.'.format(cimg))
            # Read labels.
            res = np.fromstring(gz.read(cimg), dtype = np.uint8)
    finally:
        os.remove(gzfname)
    return res.reshape((cimg, 1))//同样返回2维数组表达


if __name__ == “__main__”:
    trnData = loadData('http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz', 60000) //下载图像文件
    trnLbl = loadLabels('http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz', 60000) //下载标签文件
    trn = np.hstack((trnLbl, trnData)) //将数组进行合并
    print 'Writing train text file…'
    np.savetxt(r'./../Data/Train-28×28.txt', trn, fmt = '%u', delimiter='\t') //将数组用文本形式保存
    print 'Done.'
    testData = loadData('http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz', 10000) //下载测试数据
    testLbl = loadLabels('http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz', 10000) //下载测试标签
    test = np.hstack((testLbl, testData)) //将数组进行合并
    print 'Writing test text file…'
    np.savetxt(r'./../Data/Test-28×28.txt', test, fmt = '%u', delimiter='\t') //转写为 txt 文件
    print 'Done.'

 

然而这有个问题,在国内下得太慢,所以修改了下文件,将:

def loadData(src, cimg):
    print ('Downloading ' + src)
    gzfname, h = urllib.urlretrieve(src, './delete.me')  //从URL下载该数据文件

均修改为:

def loadLabels(gzfname, cimg):
    #print ('Downloading ' + src)
    #gzfname, h = urlretrieve(src, './delete.me')

然后将install_mnist.py文件修改为:

from __future__ import print_function
import mnist_utils as ut

if __name__ == “__main__”:
    train = ut.load('./train-images-idx3-ubyte.gz',
        './train-labels-idx1-ubyte.gz', 60000)
    print ('Writing train text file…')
    ut.savetxt(r'./Train-28x28_cntk_text.txt', train)
    print ('Done.')
    test = ut.load('./t10k-images-idx3-ubyte.gz',
        './t10k-labels-idx1-ubyte.gz', 10000)
    print ('Writing test text file…')
    ut.savetxt(r'./Test-28x28_cntk_text.txt', test)
    print ('Done.')

找个方法把MNIST的数据下载放在该目录下,这样就避免下了载速度上的问题。

 

《CNTK研究(一):MNIST的文件转换》有4个想法

        1. 洛杉矶机房能有如此快的速度?是仅仅host1plus的有这个速度还是?在下现在用的linode的东京机房,如果不错的话,考虑迁到美国机房去。

发表评论

电子邮件地址不会被公开。 必填项已用*标注