解決PyTorch DataLoader中PIL TypeError: Cannot handle this data type問題
最近嘗試將白平衡技術作為一種Data Augmentation方法時,我先自定義了一個class,並將其傳入transforms.Compose來做資料擴增,卻遇到PIL TypeError的錯誤,之前在寫的時候並沒有發生這樣的問題。檢查了資料型態後發現,原來PIL圖片需要是uint8格式,Image.fromarray才能順利運作。
class WhiteBalance(object):
def __call__(self, img):
wbModel = wb_srgb.WBsRGB(gamut_mapping=2, upgraded=0)
img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
img = wbModel.correctImage(img)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = (img * 255).round().astype(np.uint8)
return Image.fromarray(img)
def __repr__(self):
return self.__class__.__name__+'()'
之前沒注意到這件事,頂多在 img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)結束後,就直接 return Image.fromarray(img),也沒遇到特別的錯誤。
但是本次採用的方法wbModel.correctImage(img),回傳的是float32格式,而不是uint8格式,才會產生PIL TypeError: Cannot handle this data type。
因此,透過img = (img * 255).round().astype(np.uint8)的方式,先將被轉成0~1之間的圖片乘上255後,得到完整的顏色分布,再透過astype(np.uint8)將其轉換成uint8格式,PyTorch的transforms.Compose就能將此class定義的內容,作為Data Augmentation。
transform_test = transforms.Compose([WhiteBalance(),
transforms.Resize([224, 224]),
transforms.ToTensor(),
normalize
])
參考資料
python — PIL TypeError: Cannot handle this data type — Stack Overflow
python make RGB image from 3 float32 numpy arrays — Stack Overflow
相關文章