解決PyTorch DataLoader中PIL TypeError: Cannot handle this data type問題

最近嘗試將白平衡技術作為一種Data Augmentation方法時,我先自定義了一個class,並將其傳入transforms.Compose來做資料擴增,卻遇到PIL TypeError的錯誤,之前在寫的時候並沒有發生這樣的問題。檢查了資料型態後發現,原來PIL圖片需要是uint8格式,Image.fromarray才能順利運作。

class WhiteBalance(object):
def __call__(self, img):
wbModel = wb_srgb.WBsRGB(gamut_mapping=2, upgraded=0)
img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
img = wbModel.correctImage(img)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = (img * 255).round().astype(np.uint8)
return Image.fromarray(img)
def __repr__(self):
return self.__class__.__name__+'()'

之前沒注意到這件事,頂多在 img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)結束後,就直接 return Image.fromarray(img),也沒遇到特別的錯誤。

但是本次採用的方法wbModel.correctImage(img),回傳的是float32格式,而不是uint8格式,才會產生PIL TypeError: Cannot handle this data type。

因此,透過img = (img * 255).round().astype(np.uint8)的方式,先將被轉成0~1之間的圖片乘上255後,得到完整的顏色分布,再透過astype(np.uint8)將其轉換成uint8格式,PyTorch的transforms.Compose就能將此class定義的內容,作為Data Augmentation。

transform_test = transforms.Compose([WhiteBalance(),
transforms.Resize([224, 224]),
transforms.ToTensor(),
normalize
])

參考資料

python — PIL TypeError: Cannot handle this data type — Stack Overflow

python make RGB image from 3 float32 numpy arrays — Stack Overflow

相關文章

如何在PyTorch的transforms.Compose中使用自定義的Data Augmentation?

Machine Learning | Deep Learning | https://linktr.ee/yanwei

Love podcasts or audiobooks? Learn on the go with our new app.