Pytorch中使用ImageFolder读取数据集时忽略特定文件

合集下载
  1. 1、下载文档前请自行甄别文档内容的完整性,平台不提供额外的编辑、内容补充、找答案等附加服务。
  2. 2、"仅部分预览"的文档,不可在线预览部分如存在完整性等问题,可反馈申请退款(可完整预览的文档不适用该条件!)。
  3. 3、如文档侵犯您的权益,请联系客服反馈,我们会尽快为您处理(人工客服工作时间:9:00-18:30)。

Pytorch中使⽤ImageFolder读取数据集时忽略特定⽂件⽬录
⼀、使⽤ImageFolder读取数据集时忽略特定⽂件
⼆、ImageFolder只读取部分类别⽂件夹
⼀、使⽤ImageFolder读取数据集时忽略特定⽂件
如果事先知道需要忽略哪些⽂件,当然直接从数据集⾥删除就⾏了。

但如果需要在程序运⾏时动态确认,或者筛选规则⽐较复杂,⼈⼯不好做,就需要让ImageFolder在读取时使⽤⾃定义的筛选规则。

ImageFolder有⼀个可选参数为is_valid_file,参数类型为可调⽤的函数,该函数传⼊⼀个str参数,返回⼀个bool值。

当返回值为True时保留该⽂件,否则忽略。

例如,读取时想要忽略所有⽂件名带‘invalid’的⽂件,
代码如下:
import platform
from torchvision.datasets import ImageFolder
class Check(object):
def __init__(self,
key_word: str):
self.key_word = key_word
self.separator = '\\' if platform.system() == 'Windows' else '/'
def __call__(self,
file_name: str) -> bool:
folders = file_name.split(self.separator)
return folders[-1].find(self.key_word) < 0
dataset = ImageFolder('./data', is_valid_file=Check('invalid'))
这⾥定义了⼀个实现了__call__⽅法的Check类,相⽐于直接定义函数的好处在于可以在构造函数⾥指定想要忽略的字符,并且能够根据操作系统的不同把⽂件⽬录分隔符给确定了。

更加复杂的功能可以⾃⾏修改代码逻辑实现,但是要注意如果某个类别的所有⽂件都被筛选掉了,ImageFolder会
报FileNotFoundError错误。

如果想要忽略整个类别可以使⽤下⾯⽅法
⼆、ImageFolder只读取部分类别⽂件夹
直接继承并且重写ImageFolder类的find_classes⽅法即可
from torchvision.datasets.folder import *
from typing import *
class FilterableImageFolder(ImageFolder):
def __init__(
self,
root: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
loader: Callable[[str], Any] = default_loader,
is_valid_file: Optional[Callable[[str], bool]] = None,
valid_classes: List = None
):
self.valid_classes = valid_classes
super(FilterableImageFolder, self).__init__(root, transform, target_transform, loader, is_valid_file)
def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
classes = sorted( for entry in os.scandir(directory) if entry.is_dir())
#增加了这下⾯这句
classes = [valid_class for valid_class in classes if valid_class in self.valid_classes]
if not classes:
raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
return classes, class_to_idx
使⽤时,例如有mouse、cat、dog三个类别的数据集⽂件夹,只想读取cat和dog,
代码如下:
dataset = FilterableImageFolder('./data', valid_classes=['cat', 'dog'])
到此这篇关于Pytorch中使⽤ImageFolder读取数据集时忽略特定⽂件的⽂章就介绍到这了,更多相关ImageFolder读取数据集内容请搜索以前的⽂章或继续浏览下⾯的相关⽂章希望⼤家以后多多⽀持!。

相关文档
最新文档