医工互联

 找回密码
 注册[Register]

手机动态码快速登录

手机号快速登录

微信登录

微信扫一扫,快速登录

QQ登录

只需一步,快速开始

查看: 956|回复: 0
收起左侧

迁移学习之Multi-Domain Adaptation多领域自适应常用数据集PACS介绍

[复制链接]

  离线 

发表于 2022-10-10 16:01:22 | 显示全部楼层 |阅读模式 <
PACS数据集

Paper:Self-supervised Domain Adaptation for Computer Vision Tasks
GitHub:https://github.com/robertofranceschi/Domain-adaptation-on-PACS-dataset
数据集下载:https://github.com/MachineLearning2020/Homework3-PACS/tree/master/PACS


  • PACS数据集总共9991张图片,每张图片3x227x227
  • 7 classes:Dog, Elephant, Giraffe, Guitar, Horse, House, Person
  • 4 domains: Art painting, Cartoon, Photo and Sketch.
  • Photo (1,670 images), Art Painting (2,048 images), Cartoon (2,344 images) and Sketch (3,929 images)
170502y8hmv5p4mvvqmbod.png

用Pytorch加载PACS数据集

PACS原始数据集目录结果:
170503fkfh1zbuj8u6vvik.png

  1. from torch.utils.data import Dataset, DataLoader
  2. import matplotlib.pyplot as plt
  3. import os
  4. import random
  5. from PIL import Image
  6. import torch
  7. import numpy as np
  8. from torchvision.transforms import transforms
  9. from sklearn.model_selection import train_test_split
  10. import os
  11. import torchvision.transforms as transforms
  12. from PIL import Image
  13. from torch.utils.data import Dataset, DataLoader
  14. class PACS(Dataset):
  15.     def __init__(self, root_path, domain, train=True, transform=None, target_transform=None):
  16.         self.root = f"{root_path}/{domain}"
  17.         self.train = train
  18.         self.transform = transform
  19.         self.target_transform = target_transform
  20.         label_name_list = os.listdir(self.root)
  21.         self.label = []
  22.         self.data = []
  23.         if not os.path.exists(f"{root_path}/precessed"):
  24.             os.makedirs(f"{root_path}/precessed")
  25.         if os.path.exists(f"{root_path}/precessed/{domain}_data.pt") and os.path.exists(
  26.                 f"{root_path}/precessed/{domain}_label.pt"):
  27.             print(f"Load {domain} data and label from cache.")
  28.             self.data = torch.load(f"{root_path}/precessed/{domain}_data.pt")
  29.             self.label = torch.load(f"{root_path}/precessed/{domain}_label.pt")
  30.         else:
  31.             print(f"Getting {domain} datasets")
  32.             for index, label_name in enumerate(label_name_list):
  33.                 label_name_2_index = {
  34.                     'dog': 0,
  35.                     'elephant': 1,
  36.                     'giraffe': 2,
  37.                     'guitar': 3,
  38.                     'horse': 4,
  39.                     'house': 5,
  40.                     'person': 6,
  41.                 }
  42.                 images_list = os.listdir(f"{self.root}/{label_name}")
  43.                 for img_name in images_list:
  44.                     img = Image.open(f"{self.root}/{label_name}/{img_name}").convert('RGB')
  45.                     img = np.array(img)
  46.                     self.label.append(label_name_2_index[label_name])
  47.                     if self.transform is not None:
  48.                         img = self.transform(img)
  49.                     self.data.append(img)
  50.             self.data = torch.stack(self.data)
  51.             self.label = torch.tensor(self.label, dtype=torch.long)
  52.             torch.save(self.data, f"{root_path}/precessed/{domain}_data.pt")
  53.             torch.save(self.label, f"{root_path}/precessed/{domain}_label.pt")
  54.     def __getitem__(self, index):
  55.         img, target = self.data[index], self.label[index]
  56.         if self.transform is not None:
  57.             img = self.transform(img)
  58.         if self.target_transform is not None:
  59.             target = self.target_transform(target)
  60.         return img, target
  61.     def __len__(self):
  62.         return len(self.data)
  63. def get_pacs_domain(root_path=f"{DATA_PATH}/PACS", domain='art_painting', verbose=False):
  64.     transform = transforms.Compose([
  65.         transforms.ToPILImage(),
  66.         transforms.ToTensor(),
  67.         # transforms.Resize((224, 224)),
  68.         transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
  69.     ])
  70.     all_data = PACS(root_path, domain, transform=transform)
  71.     # train:test=8:2
  72.     x_train, x_test, y_train, y_test = train_test_split(all_data.data.numpy(), all_data.label.numpy(),
  73.                                                         test_size=0.20, random_state=42)
  74.     return x_train, y_train, x_test, y_test
复制代码
其他写法:https://github.com/ValerioDiEugenio/DomainAdaptation-PACSDataset/blob/main/DomainAdaptation.ipynb

最后以dog类别为例,用Python代码展示四种不同风格的图片
Python可视化图片数据集的代码

修改dir_path为对应的文件夹,dir_path = f"{DATA_PATH}/PACS/art_painting/dog"
  1. import os
  2. import matplotlib.pyplot as plt
  3. import random
  4. from PIL import Image
  5. def plotPics(data, h=3, w=3, filename="out.jpg"):
  6.     fig, ax_array = plt.subplots(h, w, figsize=(15, 15))
  7.     axes = ax_array.flatten()
  8.     for i, ax in enumerate(axes):
  9.         ri = random.randint(0, len(data) - 1)
  10.         ax.imshow(data[ri], cmap=plt.cm.gray)
  11.     plt.setp(axes, xticks=[], yticks=[], frame_on=False)
  12.     fig.tight_layout()
  13.     fig.savefig(filename)
  14.     plt.show()
  15. DATA_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/raw_data"))
  16. dir_path = f"{DATA_PATH}/PACS/art_painting/dog"
  17. data = []
  18. for pic in os.listdir(dir_path):
  19.     data.append(Image.open(f"{dir_path}/{pic}"))
  20. plotPics(data, h=5, w=5)
复制代码
art_painting风格

170503imynf2x1mhdl2udk.png

sketch风格

170504mgrgrjkjqvirik4h.png

cartoon风格

170505sx61wl5wl7zxm1hk.png

photo风格

170506dssk811skxsj8sk4.png


来源:https://blog.csdn.net/qq_43827595/article/details/121345640
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!
回复

使用道具 举报

提醒:禁止复制他人回复等『恶意灌水』行为,违者重罚!
您需要登录后才可以回帖 登录 | 注册[Register] 手机动态码快速登录 微信登录

本版积分规则

发布主题 快速回复 收藏帖子 返回列表 客服中心 搜索
简体中文 繁體中文 English 한국 사람 日本語 Deutsch русский بالعربية TÜRKÇE português คนไทย french

QQ|RSS订阅|小黑屋|处罚记录|手机版|联系我们|Archiver|医工互联 |粤ICP备2021178090号 |网站地图

GMT+8, 2024-11-24 17:39 , Processed in 0.255684 second(s), 62 queries .

Powered by Discuz!

Copyright © 2001-2023, Discuz! Team.

快速回复 返回顶部 返回列表