【PyTorch】PyTorch使用LMDB数据库加速文件读取

合集下载
  1. 1、下载文档前请自行甄别文档内容的完整性,平台不提供额外的编辑、内容补充、找答案等附加服务。
  2. 2、"仅部分预览"的文档,不可在线预览部分如存在完整性等问题,可反馈申请退款(可完整预览的文档不适用该条件!)。
  3. 3、如文档侵犯您的权益,请联系客服反馈,我们会尽快为您处理(人工客服工作时间:9:00-18:30)。

【PyTorch 】PyTorch 使⽤LMDB 数据库加速⽂件读取
PyTorch 使⽤LMDB
数据库加速⽂件读取
对于数据库的了解较少,⽂章中⼤部分的介绍主要来⾃于各种博客和LMDB 的⽂档,但是⽂档中的介绍,默认是已经了解了数据库的许多知识,这导致⽬前只能囫囵吞枣,待之后仔细了解后再重新补充内容。

Caffe 使⽤LMDB 来存放训练/测试⽤的数据集,以及使⽤⽹络提取出的feature (为了⽅便,以下还是统称数据集)。

数据集的结构很简单,就是⼤量的矩阵/向量数据平铺开来。

数据之间没有什么关联,数据内没有复杂的对象结构,就是向量和矩阵。

既然数据并不复杂,Caffe 就选择了LMDB 这个简单的数据库来存放数据。

LMDB 的全称是Lightning Memory-Mapped Database ,闪电般的内存映射数据库。

它⽂件结构简单,⼀个⽂件夹,⾥⾯⼀个数据⽂件,⼀个锁⽂件。

数据随意复制,随意传输。

它的访问简单,不需要运⾏单独的数据库管理进程,只要在访问数据的代码⾥引⽤LMDB 库,访问时给⽂件路径即可。

图像数据集归根究底从图像⽂件⽽来。

引⼊数据库存放数据集,是为了减少IO 开销。

读取⼤量⼩⽂件的开销是⾮常⼤的,尤其是在机械硬盘上。

LMDB 的整个数据库放在⼀个⽂件⾥,避免了⽂件系统寻址的开销。

LMDB 使⽤内存映射的⽅式访问⽂件,使得⽂件内寻址的开销⾮常⼩,使⽤指针运算就能实现。

数据库单⽂件还能减少数据集复制/传输过程的开销。

⼀个⼏万,⼏⼗万⽂件的数据集,不管是直接复制,还是打包再解包,过程都⽆⽐漫长⽽痛苦。

LMDB 数据库只有⼀个⽂件,你的介质有多块,就能复制多快,不会因为⽂件多⽽慢如蜗⽜。

为什么要把图像数据转换成⼤的⼆进制⽂件?
简单来说,是因为读写⼩⽂件的速度太慢。

那么,不禁要问,图像数据也是⼆进制⽂件,单个⼤的⼆进制⽂件例如 LMDB ⽂件也是⼆进制⽂件,为什么单个图像读写速度就慢了呢?这⾥分两种情况解释。

1. 机械硬盘的情况:机械硬盘的每次读写启动时间⽐较长,例如磁头的寻道时间占⽐很⾼,因此,如果单个⼩⽂件读写,尤其是随机读写单个⼩⽂件的时候,这个寻道时间占⽐就会很⾼,最后导致⼤量读写⼩⽂件的时候时间会很浪费;
2. NFS 的情况:在 NFS 的场景下,系统的⼀次读写⾸先要进⾏上百次的⽹络通讯,并且这个通讯次数和⽂件的⼤⼩⽆关。

因此,如果是读写⼩⽂件,这个⽹络通讯时间占据了整个读写时间的⼤部分。

固态硬盘的情况下应该也会有⼀些类似的开销,⽬前没有研究过。

总⽽⾔之,使⽤LMDB 可以为我们的数据读取进⾏加速。

LMDB
主要类
这是数据库环境的结构。

⼀个环境可能包含多个数据库,所有数据库都驻留在同⼀共享内存映射和基础磁盘⽂件中。

要写⼊环境,必须创建事务(Transaction )。

允许同时进⾏⼀次写⼊事务,但是即使存在写⼊事务,读取事务的数量也没有限制。

⼏个重要的实例⽅法:
begin(db=None, parent=None, write=False, buffers=False): 可以调⽤事务类 lmdb.Transaction
open_db(key=None, txn=None, reverse_key=False, dupsort=False, create=True, integerkey=False, integerdup=False,dupfixed=False): 打开⼀个数据库,返回⼀个不透明的句柄。

重复Environment.open_db() 调⽤相同的名称将返回相同的句柄。

作为⼀个特殊情况,主数据库总是开放的。

命名数据库是通过在主数据库中存储⼀个特殊的描述符来实现的。

环境中的所有数据库共享相同的⽂件。

因为描述符存在于主数据库中,所以如果已经存在与数据库名称匹配的 key ,创建命名数据库的尝试将失败。

此外,查找和枚举可以看到key 。

如果主数据库keyspace
与命名数据库使⽤的名称冲突,则将主数据库的内容移动到另⼀个命名数据库。

这和事务对象有关。

pip install imdb
>>> env = lmdb.open('/tmp/test', max_dbs=2)
>>> with env.begin(write=True) as txn
... txn.put('somename', 'somedata')
>>> # Error: database cannot share name of existing key!
>>> subdb = env.open_db('somename')
class lmdb.Transaction(env, db=None, parent=None, write=False, buffers=False)。

所有操作都需要事务句柄,事务可以是只读或读写的。

写事务可能不会跨越线程。

事务对象实现了上下⽂管理器协议,因此即使⾯对未处理的异常,也可以可靠地释放事务:
# Transaction aborts correctly:
with env.begin(write=True) as txn:
crash()
# Transaction commits automatically:
with env.begin(write=True) as txn:
txn.put('a', 'b')
这个类的实例包含着很多有⽤的操作⽅法。

abort(): 中⽌挂起的事务。

重复调⽤abort()在之前成功的commit()或abort()后或者在相关环境关闭后是没有效果的。

commit(): 提交挂起的事务。

cursor(db=None): Shortcut for lmdb.Cursor(db, self)
delete(key, value='', db=None): Delete a key from the database.
key: The key to delete.
value:如果数据库是以dupsort = True打开的,并且value不是空的bytestring,则删除仅与此(key, value)对匹配的元素,否则该key 的所有值都将被删除。

Returns True if at least one key was deleted.
drop(db, delete=True): 删除命名数据库中的所有键,并可选地删除命名数据库本⾝。

删除命名数据库会导致其不可⽤,并使现有cursors⽆效。

get(key, default=None, db=None): 获取匹配键的第⼀个值,如果键不存在,则返回默认值。

cursor必须⽤于获取dupsort = True数据库中的key的所有值。

id(): 返回事务的ID。

这将返回与此事务相关联的标识符。

对于只读事务,这对应于正在读取的快照; 并发读取器通常具有相同的事务ID。

pop(key, db=None): 使⽤临时cursor调⽤Cursor.pop()。

db: 要操作的命名数据库。

如果未指定,默认为事务构造函数被给定的数据库。

put(key, value, dupdata=True, overwrite=True, append=False, db=None): 存储⼀条记录(record),如果记录被写⼊,则返回True,否则返回False,以指⽰key已经存在并且overwrite = False。

成功后,cursor位于新记录上。

key: Bytestring key to store.
value: Bytestring value to store.
dupdata: 如果True,并且数据库是⽤dupsort = True打开的,如果给定key已经存在,则添加键值对作为副本。

否则覆盖任何现有匹配的key。

overwrite: If False , do not overwrite any existing matching key.
append: 如果为True,则将对附加到数据库末尾,⽽不⾸先⽐较其顺序。

附加不⼤于现有最⾼key的key将导致损坏。

db: 要操作的命名数据库。

如果未指定,默认为事务构造函数被给定的数据库。

replace(key, value, db=None): 使⽤临时cursor调⽤Cursor.replace() .
db: Named database to operate on. If unspecified, defaults to the database given to the Transaction constructor.
stat(db): Return statistics like Environment.stat() , except for a single DBI. db must be a database handle returned by open_db() .
class lmdb.Cursor(db, txn)是⽤于在数据库中导航(navigate)的结构。

db: Database to navigate.
txn: Transaction to navigate.
As a convenience, Transaction.cursor() can be used to quickly return a cursor:
>>> env = lmdb.open('/tmp/foo')
>>> child_db = env.open_db('child_db')
>>> with env.begin() as txn:
... cursor = txn.cursor() # Cursor on main database.
... cursor2 = txn.cursor(child_db) # Cursor on child database.
游标以未定位的状态开始。

如果在这种状态下使⽤iternext()或iterprev(),那么迭代将分别从开始处和结束处开始。

迭代器直接使⽤游标定位,这意味着在同⼀游标上存在多个迭代器时会产⽣奇怪的⾏为。

从Python绑定的⾓度来看,⼀旦任何扫描或查找⽅法(例如next()、prev_nodup()、set_range() )返回False或引发异常,游标将返回未定位状态。

这主要是为了确保在⾯对任何错误条件时语义的安全性和⼀致性。

当游标返回到未定位的状态时,它的key()和value()返回空字符串,表⽰没有活动的位置,尽管在内部,LMDB游标可能仍然有⼀个有效的位置。

这可能会导致在迭代dupsort=True数据库的key时出现⼀些令⼈吃惊的⾏为,因为iternext_dup()等⽅法将导致游标显⽰为未定位,尽管它返回False只是为了表明当前键没有更多的值。

在这种情况下,简单地调⽤next()将导致在下⼀个可⽤键处继续迭代。

This behaviour may change in future.
Iterator methods such as iternext() and iterprev() accept keys and values arguments. If both are True , then the value of item() is yielded on each iteration. If only keys is True , key() is yielded, otherwise only value() is yielded.
在迭代之前,游标可能定位在数据库中的任何位置
不需要迭代来导航,有时会导致丑陋或低效的代码。

在迭代顺序不明显的情况下,或者与正在读取的数据相关的情况下,使⽤ set_key() 、set_range() 、 key() 、 value() 和 item() 可能是更好的选择。

⼏个实例⽅法:
set_key(key): Seek exactly to key, returning True on success or False if the exact key was not found. 对于 set_key() ,空字节串是错误的。

对于使⽤ dupsort=True 打开的数据库,移动到键的第⼀个值(复制)。

set_range(key): Seek to the first key greater than or equal to key , returning True on success, or False to indicate key was past end of database. Behaves like first() if key is the empty bytestring. 对于使⽤ dupsort=True 打开的数据库,移动到键的第⼀个值(复制)。

get(key, default=None): Equivalent to set_key() , except value() is returned when key is found, otherwise default.
item(): Return the current (key, value) pair.key(): Return the current key.
value(): Return the current value.
操作流程
概况地讲,操作LMDB 的流程是:通过 env = lmdb.open() 打开环境通过 txn = env.begin() 建⽴事务通过 txn.put(key, value) 进⾏插⼊和修改通过 txn.delete(key) 进⾏删除通过 txn.get(key) 进⾏查询通过 txn.cursor() 进⾏遍历
通过 mit() 提交更改
这⾥要注意:
1. put 和 delete 后⼀定注意要 commit ,不然根本没有存进去
2. 每⼀次 commit 后,需要再定义⼀次 txn=env.begin(write=True)
>>> with env.begin() as txn:
... cursor = txn.cursor()
... if not cursor.set_range('5'): # Position at first key >= '5'.
... print('Not found!')
... else:
... for key, value in cursor: # Iterate from first key >= '5'.
... print((key, value))
>>> # Record the path from a child to the root of a tree.
>>> path = ['child14123']
>>> while path[-1] != 'root':
... assert cursor.set_key(path[-1]), \
... 'Tree is broken! Path: %s' % (path,)
... path.append(cursor.value())
#!/usr/bin/env python
import lmdb
import os, sys
def initialize():
env = lmdb.open("students");
return env;
def insert(env, sid, name):
txn = env.begin(write = True);
txn.put(str(sid), name);
mit();
def delete(env, sid):
txn = env.begin(write = True);
txn.delete(str(sid));
mit();
def update(env, sid, name):
txn = env.begin(write = True);
txn.put(str(sid), name);
mit();
def search(env, sid):
txn = env.begin();
name = txn.get(str(sid));
return name;
创建图像数据集
改写为:
def display(env):
txn = env.begin();
cur = txn.cursor();
for key, value in cur:
print (key, value);
env = initialize();
print "Insert 3 records."
insert(env, 1, "Alice");
insert(env, 2, "Bob");
insert(env, 3, "Peter");
display(env);
print "Delete the record where sid = 1."
delete(env, 1);
display(env);
print "Update the record where sid = 3."
update(env, 3, "Mark");
display(env);
print "Get the name of student whose sid = 3."
name = search(env, 3);
print name;
env.close();
os.system("rm -r students");
import glob
import os
import pickle
import sys
import cv2
import lmdb
import numpy as np
from tqdm import tqdm
def main(mode):
proj_root = '/home/lart/coding/TIFNet'
datasets_root = '/home/lart/Datasets/'
lmdb_path = os.path.join(proj_root, 'datasets/ECSSD.lmdb')
data_path = os.path.join(datasets_root, 'RGBSaliency', 'ECSSD/Image')
if mode == 'creating':
opt = {
'name': 'TrainSet',
'img_folder': data_path,
'lmdb_save_path': lmdb_path,
'commit_interval': 100, # After commit_interval images, lmdb commits
'num_workers': 8,
}
general_image_folder(opt)
elif mode == 'testing':
test_lmdb(lmdb_path, index=1)
def general_image_folder(opt):
"""
Create lmdb for general image folders
If all the images have the same resolution, it will only store one copy of resolution info.
Otherwise, it will store every resolution info.
"""
img_folder = opt['img_folder']
lmdb_save_path = opt['lmdb_save_path']
meta_info = {'name': opt['name']}
if not lmdb_save_path.endswith('.lmdb'):
raise ValueError("lmdb_save_path must end with 'lmdb'.")
if os.path.exists(lmdb_save_path):
print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
sys.exit(1)
# read all the image paths to a list
print('Reading image path list ...')
all_img_list = sorted(glob.glob(os.path.join(img_folder, '*')))
# cache the filename, 这⾥的⽂件名必须是ascii字符
keys = []
for img_path in all_img_list:
keys.append(os.path.basename(img_path))
# create lmdb environment
# 估算⼤概的映射空间⼤⼩
data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes
print('data size per image is: ', data_size_per_img)
data_size = data_size_per_img * len(all_img_list)
env = lmdb.open(lmdb_save_path, map_size=data_size * 10)
# map_size:
# Maximum size database may grow to; used to size the memory mapping. If database grows larger # than map_size, an exception will be raised and the user must close and reopen Environment.
# write data to lmdb
txn = env.begin(write=True)
resolutions = []
tqdm_iter = tqdm(enumerate(zip(all_img_list, keys)), total=len(all_img_list), leave=False)
for idx, (path, key) in tqdm_iter:
tqdm_iter.set_description('Write {}'.format(key))
key_byte = key.encode('ascii')
data = cv2.imread(path, cv2.IMREAD_UNCHANGED)
if data.ndim == 2:
H, W = data.shape
C = 1
else:
H, W, C = data.shape
resolutions.append('{:d}_{:d}_{:d}'.format(C, H, W))
txn.put(key_byte, data)
if (idx + 1) % opt['commit_interval'] == 0:
mit()
# commit 之后需要再次 begin
txn = env.begin(write=True)
mit()
env.close()
print('Finish writing lmdb.')
# create meta information
# check whether all the images are the same size
assert len(keys) == len(resolutions)
if len(set(resolutions)) <= 1:
meta_info['resolution'] = [resolutions[0]]
meta_info['keys'] = keys
print('All images have the same resolution. Simplify the meta info.')
else:
meta_info['resolution'] = resolutions
meta_info['keys'] = keys
print('Not all images have the same resolution. Save meta info for each image.')
pickle.dump(meta_info, open(os.path.join(lmdb_save_path, 'meta_info.pkl'), "wb"))
print('Finish creating lmdb meta info.')
def test_lmdb(dataroot, index=1):
env = lmdb.open(dataroot, readonly=True, lock=False, readahead=False, meminit=False)
meta_info = pickle.load(open(os.path.join(dataroot, 'meta_info.pkl'), "rb"))
print('Name: ', meta_info['name'])
print('Resolution: ', meta_info['resolution'])
print('# keys: ', len(meta_info['keys']))
# read one image
key = meta_info['keys'][index]
print('Reading {} for test.'.format(key))
with env.begin(write=False) as txn:
buf = txn.get(key.encode('ascii'))
img_flat = np.frombuffer(buf, dtype=np.uint8)
C, H, W = [int(s) for s in meta_info['resolution'][index].split('_')]
img = img_flat.reshape(H, W, C)
dWindow('Test')
cv2.imshow('Test', img)
cv2.waitKeyEx()
配合
这⾥仅对训练集进⾏LMDB 处理,测试机依旧使⽤的原始的读取图⽚的⽅式。

if __name__ == "__main__":
# mode = creating or testing
main(mode='creating')
import os
import pickle
import lmdb
import numpy as np
from PIL import Image
from prefetch_generator import BackgroundGenerator
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from utils import joint_transforms
def _get_paths_from_lmdb(dataroot):
"""get image path list from lmdb meta info"""
meta_info = pickle.load(open(os.path.join(dataroot, 'meta_info.pkl'),
'rb'))
paths = meta_info['keys']
sizes = meta_info['resolution']
if len(sizes) == 1:
sizes = sizes * len(paths)
return paths, sizes
def _read_img_lmdb(env, key, size):
"""read image from lmdb with key (w/ and w/o fixed size)
size: (C, H, W) tuple"""
with env.begin(write=False) as txn:
buf = txn.get(key.encode('ascii'))
img_flat = np.frombuffer(buf, dtype=np.uint8)
C, H, W = size
img = img_flat.reshape(H, W, C)
return img
def _make_dataset(root, prefix=('.jpg', '.png')):
img_path = os.path.join(root, 'Image')
gt_path = os.path.join(root, 'Mask')
img_list = [
os.path.splitext(f)[0] for f in os.listdir(gt_path)
if f.endswith(prefix[1])
]
return [(os.path.join(img_path, img_name + prefix[0]),
os.path.join(gt_path, img_name + prefix[1]))
for img_name in img_list]
class TestImageFolder(Dataset):
def __init__(self, root, in_size, prefix):
self.imgs = _make_dataset(root, prefix=prefix)
self.test_img_trainsform = pose([
# 输⼊的如果是⼀个tuple ,则按照数据缩放,但是如果是⼀个数字,则按⽐例缩放到短边等于该值
transforms.Resize((in_size, in_size)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
def __getitem__(self, index):
img_path, gt_path = self.imgs[index]
img = Image.open(img_path).convert('RGB')
img_name = (img_path.split(os.sep)[-1]).split('.')[0]
img = self.test_img_trainsform(img)
return img, img_name
def __len__(self):
return len(self.imgs)
class TrainImageFolder(Dataset):
def __init__(self, root, in_size, scale=1.5, use_bigt=False):
e_bigt = use_bigt
self.in_size = in_size
self.root = root
self.train_joint_transform = joint_pose([
joint_transforms.JointResize(in_size),
joint_transforms.RandomHorizontallyFlip(),
joint_transforms.RandomRotate(10)
])
self.train_img_transform = pose([
transforms.ColorJitter(0.1, 0.1, 0.1),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225]) # 处理的是Tensor
])
# ToTensor 操作会将 PIL.Image 或形状为 H×W×D,数值范围为 [0, 255] 的 np.ndarray 转换为形状为 D×H×W, # 数值范围为 [0.0, 1.0] 的 torch.Tensor。

self.train_target_transform = transforms.ToTensor()
self.gt_root = '/home/lart/coding/TIFNet/datasets/DUTSTR/DUTSTR_GT.lmdb'
self.img_root = '/home/lart/coding/TIFNet/datasets/DUTSTR/DUTSTR_IMG.lmdb'
self.paths_gt, self.sizes_gt = _get_paths_from_lmdb(self.gt_root)
self.paths_img, self.sizes_img = _get_paths_from_lmdb(self.img_root)
self.gt_env = lmdb.open(self.gt_root, readonly=True, lock=False, readahead=False,
meminit=False)
self.img_env = lmdb.open(self.img_root, readonly=True, lock=False, readahead=False,
meminit=False)
def __getitem__(self, index):
gt_path = self.paths_gt[index]
img_path = self.paths_img[index]
gt_resolution = [int(s) for s in self.sizes_gt[index].split('_')]
img_resolution = [int(s) for s in self.sizes_img[index].split('_')]
img_gt = _read_img_lmdb(self.gt_env, gt_path, gt_resolution)
img_img = _read_img_lmdb(self.img_env, img_path, img_resolution)
if img_img.shape[-1] != 3:
img_img = np.repeat(img_img, repeats=3, axis=-1)
img_img = img_img[:, :, [2, 1, 0]] # bgr => rgb
img_gt = np.squeeze(img_gt, axis=2)
gt = Image.fromarray(img_gt, mode='L')
img = Image.fromarray(img_img, mode='RGB')
img, gt = self.train_joint_transform(img, gt)
gt = self.train_target_transform(gt)
img = self.train_img_transform(img)
if e_bigt:
gt = gt.ge(0.5).float() # ⼆值化
img_name = self.paths_img[index]
return img, gt, img_name
def __len__(self):
return len(self.paths_img)
class DataLoaderX(DataLoader):
def __iter__(self):
return BackgroundGenerator(super(DataLoaderX, self).__iter__())
⽂档:
关于LMDB的介绍:
代码⽰例:。

相关文档
最新文档