guodong's blog

PhD@zhejiang university
   

tensorflow弃坑转pytorch小记

学习tensorflow1.0和2.0好久了,感觉还是pytorch香。加上好多公开代码都是基于pytorch的,遂决定弃坑tensorflow,转向pytorch。

本贴主要是记录学习时遇到的pytorch’有关函数,新手向。

torchvision

关于数据处理,首先是torchvision,官方的话:The torchvision package consists of popular datasets, model architectures, and common image transformations for computer vision.

torchvision.datasets

All datasets are subclasses of torch.utils.data.Dataset i.e, they have __getitem__ and __len__ methods implemented. Hence, they can all be passed to a torch.utils.data.DataLoader which can load multiple samples parallelly using torch.multiprocessing workers.

没什么好说的

torchvision.io

io操作,主要是针对video,略

torchvision.models

包含了常见的模型,比如分类模型检测模型等,eg

import torchvision.models as models
resnet18 = models.resnet18()
resnet18 = models.resnet18(pretrained=True) # 如果使用预训练的权重

其中实现基于下面的函数实现:

torch.utils.model_zoo.load_url(url, model_dir=None, map_location=None, progress=True, check_hash=False)

其中权重的储存位置:If the object is already present in model_dir, it’s deserialized and returned. The default value of model_dir is $TORCH_HOME/checkpoints where environment variable $TORCH_HOME defaults to $XDG_CACHE_HOME/torch$XDG_CACHE_HOME follows the X Design Group specification of the Linux filesytem layout, with a default value ~/.cache if not set.

一般都要用到

torchvision.ops

主要是对计算机视觉处理

比如nms,roi_align,roi_pool

torchvision.transforms

Transforms are common image transformations. They can be chained together using ComposeAdditionally, there is the torchvision.transforms.functional module.

首先是compose,整合在一起:

对于PIL图像,对于torch *Tensor都有一些函数,同时两者还可以互相转化,torchvision.transforms.ToTensor, torchvision.transforms.ToPILImage(mode=None)

还有一些通用转换,torchvision.transforms.Lambda(lambd函数式转换.

torch.utils.data

核心函数dataloader,

这里一一介绍预备知识:

Dataset Types

DataLoader 里最重要的参数就是dataset了。有两种格式,一个是map-style dataset,另一个是iterable-style datasets。

其中map-style格式的就是实现the __getitem__() and __len__() protocols,表示一个映射,即可以通过dataset[idx],来从硬盘里读第idx个图片和对应的标签。

Iterable-style 格式就是迭代器对象,这种类型的数据集特别适用于以下情况:随机读取消耗巨大甚至不大可能,并且批处理大小取决于所获取的数据。例如,可以返回从数据库,远程服务器甚至实时生成的日志中读取的数据流。

Sampler

只针对map-style的数据类型

Loading Batched and Non-Batched Data

支持自动生成batch数据。当从iterable-style数据集中,使用multi-processing时,drop_last参数丢弃最后的不足一批batch的数据。

collate_fn用来自定义校对规则,比如padding。还有一些高级用法,暂不看

Single- and Multi-process Data Loading

默认使用单核处理。为了避免在加载数据时阻塞计算代码,PyTorch提供了一个简单的开关,只需将参数num_workers设置为正整数即可执行多进程数据加载。

Single-process data loading 时,data loading may block computing,此外,单进程加载通常显示更多可读的错误跟踪,因此对于调试很有用。

Multi-process data loading 时,当调用enumerate(dataloader)时,多进程核也被创立。此时, dataset, collate_fn, 和worker_init_fn都被传送到每个worker,这意味着数据集访问及其内部IO转换(包括collat​​e_fn)在工作进程中运行。其中可以用torch.utils.data.get_worker_info()来返回工作核信息。

通常不建议在多进程加载时返回CUDA张量,因为在多进程中使用CUDA和共享CUDA张量有许多微妙之处(请参阅多进程中的CUDA)。相反,我们建议使用自动内存固定(即,设置pin_memory=True),这样可以快速地将数据传输到启用CUDA的gpu

另外不同平台也有不一样,linux很方便,windows上要注意。首先尽量用if __name__ == ‘__main__’: 来确保不会再被执行。可以把dataset和dataloader实例创建逻辑放在那里,因为不需要再执行。确保在__main__检查之外将任何自定义collat​​e_fn,worker_init_fn或数据集代码声明为顶级定义。 这样可以确保它们在工作进程中可用。 (这是必需的,因为将函数仅作为引用而不是字节码进行腌制。)

推荐这样的格式

另外关于随机性,如果是对于numpy随机,则需要torch.utils.data.get_worker_info().seed or torch.initial_seed(), and use it to seed other libraries before data loading.

Memory Pinning

对于数据加载,将pin_memory = True传递给DataLoader将自动将获取的数据张量放入固定内存中,从而能够更快地将数据传输到支持CUDA的GPU。

 

torch.nn

nn.Sequential()另一种用法

 




上一篇:
下一篇:

头像

guodong

没有评论


你先离开吧:)



发表评论

电子邮件地址不会被公开。 必填项已用*标注