离线
|
PACS数据集
Paper:Self-supervised Domain Adaptation for Computer Vision Tasks
GitHub:https://github.com/robertofranceschi/Domain-adaptation-on-PACS-dataset
数据集下载:https://github.com/MachineLearning2020/Homework3-PACS/tree/master/PACS
- PACS数据集总共9991张图片,每张图片3x227x227
- 7 classes:Dog, Elephant, Giraffe, Guitar, Horse, House, Person
- 4 domains: Art painting, Cartoon, Photo and Sketch.
- Photo (1,670 images), Art Painting (2,048 images), Cartoon (2,344 images) and Sketch (3,929 images)
用PyTorch加载PACS数据集
PACS原始数据集目录结构:`PACS/<domain>/<class>/<图片文件>`(例如 `PACS/art_painting/dog/`),下面的代码按此结构读取图片:
- from torch.utils.data import Dataset, DataLoader
- import matplotlib.pyplot as plt
- import os
- import random
- from PIL import Image
- import torch
- import numpy as np
- from torchvision.transforms import transforms
- from sklearn.model_selection import train_test_split
- import os
- import torchvision.transforms as transforms
- from PIL import Image
- from torch.utils.data import Dataset, DataLoader
class PACS(Dataset):
    """One domain of the PACS dataset, loaded eagerly and cached as tensors.

    Raw images are read from ``{root_path}/{domain}/{class_name}/{image_file}``,
    converted to RGB, passed ONCE through ``transform`` and stacked into
    ``self.data``; integer labels are stored in ``self.label``. Both tensors
    are saved to ``{root_path}/precessed/{domain}_data.pt`` and
    ``{root_path}/precessed/{domain}_label.pt`` and reloaded from there on
    later constructions.

    NOTE: the cache is keyed only on ``domain``. When a cache exists,
    ``transform`` is ignored and the raw image folder is not read at all.
    The directory name ``precessed`` (sic) is kept so existing caches stay
    valid.

    Args:
        root_path: Directory holding one sub-folder per domain.
        domain: Domain folder name, e.g. ``'art_painting'``.
        train: Stored as ``self.train``; no train/test split is done here.
        transform: Callable applied to each image (an ``H x W x 3`` uint8
            NumPy array) while the cache is built. It must return a tensor,
            because the results are combined with ``torch.stack``. If
            ``None``, the array is converted with ``torch.from_numpy``.
        target_transform: Callable applied to each label in ``__getitem__``.
    """

    # Class-folder name -> integer label.
    LABEL_NAME_2_INDEX = {
        'dog': 0,
        'elephant': 1,
        'giraffe': 2,
        'guitar': 3,
        'horse': 4,
        'house': 5,
        'person': 6,
    }

    def __init__(self, root_path, domain, train=True, transform=None, target_transform=None):
        self.root = f"{root_path}/{domain}"
        self.train = train  # kept for API compatibility; not used to split
        self.transform = transform
        self.target_transform = target_transform
        cache_dir = f"{root_path}/precessed"
        data_file = f"{cache_dir}/{domain}_data.pt"
        label_file = f"{cache_dir}/{domain}_label.pt"
        os.makedirs(cache_dir, exist_ok=True)
        if os.path.exists(data_file) and os.path.exists(label_file):
            print(f"Load {domain} data and label from cache.")
            self.data = torch.load(data_file)
            self.label = torch.load(label_file)
        else:
            print(f"Getting {domain} datasets")
            self.data, self.label = self._read_images()
            torch.save(self.data, data_file)
            torch.save(self.label, label_file)

    def _read_images(self):
        """Decode every image under ``self.root``; return ``(data, label)`` tensors."""
        images, labels = [], []
        # Sort both listings: os.listdir order is filesystem-dependent, and the
        # sample order feeds seeded splits such as train_test_split.
        for label_name in sorted(os.listdir(self.root)):
            class_dir = f"{self.root}/{label_name}"
            if not os.path.isdir(class_dir):
                continue  # skip stray files (e.g. OS metadata files)
            # Raises KeyError for a class folder not in LABEL_NAME_2_INDEX.
            label = self.LABEL_NAME_2_INDEX[label_name]
            for img_name in sorted(os.listdir(class_dir)):
                # The context manager closes the file handle Image.open holds.
                with Image.open(f"{class_dir}/{img_name}") as pil_img:
                    img = np.array(pil_img.convert('RGB'))
                if self.transform is not None:
                    img = self.transform(img)
                else:
                    # torch.stack below needs tensors, not NumPy arrays.
                    img = torch.from_numpy(img)
                images.append(img)
                labels.append(label)
        if not images:
            raise RuntimeError(f"No images found under {self.root}")
        return torch.stack(images), torch.tensor(labels, dtype=torch.long)

    def __getitem__(self, index):
        """Return ``(image, target)`` for sample ``index``.

        ``self.data`` already holds transformed samples (``transform`` runs
        once, when the cache is built), so it is deliberately NOT applied
        again here; only ``target_transform`` is applied.
        """
        img, target = self.data[index], self.label[index]
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        """Return the number of cached samples."""
        return len(self.data)
def get_pacs_domain(root_path=None, domain='art_painting', verbose=False,
                    test_size=0.20, random_state=42):
    """Load one PACS domain and split it into train and test arrays.

    Args:
        root_path: Directory containing the PACS domain folders. Defaults to
            ``f"{DATA_PATH}/PACS"``, resolved at call time so that merely
            defining this function does not require ``DATA_PATH`` to exist.
        domain: Domain folder name, e.g. ``'art_painting'``.
        verbose: Currently unused; kept for backward compatibility.
        test_size: Fraction of samples placed in the test split (default 0.2,
            i.e. train:test = 8:2).
        random_state: Seed passed to ``train_test_split`` for a reproducible split.

    Returns:
        ``(x_train, y_train, x_test, y_test)`` as NumPy arrays. Note the order
        differs from ``train_test_split``'s own ``(x_train, x_test, y_train,
        y_test)``. When PACS builds its cache in this call, the images are
        ToTensor outputs normalized with mean 0.5 / std 0.5; if a cache
        already existed, they are whatever that cache holds.
    """
    if root_path is None:
        # Resolved here rather than in the signature: a default of
        # f"{DATA_PATH}/PACS" is evaluated when the function is defined and
        # raises NameError if DATA_PATH is not defined at that point.
        root_path = f"{DATA_PATH}/PACS"
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
        # Insert transforms.Resize(...) here if a model needs another input size.
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])
    all_data = PACS(root_path, domain, transform=transform)
    x_train, x_test, y_train, y_test = train_test_split(
        all_data.data.numpy(),
        all_data.label.numpy(),
        test_size=test_size,
        random_state=random_state,
    )
    return x_train, y_train, x_test, y_test
其他写法:https://github.com/ValerioDiEugenio/DomainAdaptation-PACSDataset/blob/main/DomainAdaptation.ipynb
最后以dog类别为例,用Python代码展示四种不同风格的图片
Python可视化图片数据集的代码
将 dir_path 修改为要展示的风格/类别所在的文件夹,例如:dir_path = f"{DATA_PATH}/PACS/art_painting/dog"
- import os
- import matplotlib.pyplot as plt
- import random
- from PIL import Image
def plotPics(data, h=3, w=3, filename="out.jpg"):
    """Plot an ``h`` x ``w`` grid of randomly chosen images, save it, and show it.

    Distinct images are drawn when ``data`` has at least ``h * w`` entries;
    otherwise images are drawn with replacement so every cell is filled.

    Args:
        data: Indexable sequence of images accepted by ``Axes.imshow``
            (e.g. PIL images or arrays).
        h: Number of grid rows.
        w: Number of grid columns.
        filename: Path the figure is written to with ``Figure.savefig``.

    Raises:
        ValueError: If ``data`` is empty.
    """
    if len(data) == 0:
        raise ValueError("plotPics needs at least one image")
    n_cells = h * w
    if len(data) >= n_cells:
        picks = random.sample(range(len(data)), n_cells)
    else:
        picks = [random.randrange(len(data)) for _ in range(n_cells)]
    # squeeze=False keeps ax_array 2-D even for h == w == 1, so .flatten()
    # always yields an array of Axes.
    fig, ax_array = plt.subplots(h, w, figsize=(15, 15), squeeze=False)
    axes = ax_array.flatten()
    for ax, idx in zip(axes, picks):
        ax.imshow(data[idx], cmap=plt.cm.gray)
    plt.setp(axes, xticks=[], yticks=[], frame_on=False)
    fig.tight_layout()
    fig.savefig(filename)
    plt.show()
# Demo: show a 5x5 grid of random images from one PACS domain/class folder.
DATA_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/raw_data"))
dir_path = f"{DATA_PATH}/PACS/art_painting/dog"
data = []
for pic in os.listdir(dir_path):
    # Image.open is lazy and keeps the file open until the pixels are read;
    # copy() forces the load so the handle is closed immediately instead of
    # leaking one file descriptor per image.
    with Image.open(f"{dir_path}/{pic}") as img:
        data.append(img.copy())
plotPics(data, h=5, w=5)
art_painting风格
sketch风格
cartoon风格
photo风格
来源:https://blog.csdn.net/qq_43827595/article/details/121345640
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作! |
|