• 三维医学图像深度学习,数据增强方法(monai):RandHistogramShiftD, Flipd, Rotate90d


    #coding:utf-8
    import torch
    from monai.transforms import Compose, RandHistogramShiftD, Flipd, Rotate90d
    import matplotlib.pyplot as plt
    import SimpleITK as sitk
    # keys of the data dictionary that the transform chain below operates on
    KEYS = ("image", "label")
    class aug():
        """Fixed augmentation pipeline for an image/label dictionary.

        The pipeline applies, in order:

        1. ``Rotate90d`` with ``k=1`` to every key in ``KEYS``.
        2. ``Flipd`` to every key in ``KEYS``.
        3. ``RandHistogramShiftD`` (``prob=1``) to the ``"image"`` entry only.

        The histogram shift is an intensity transform, so it is deliberately
        not applied to ``"label"``: remapping the values of a segmentation
        mask can alter its class values or dtype, while the geometric
        transforms must hit both entries to keep them aligned.

        NOTE(review): the axis indices (``spatial_axes=(2, 3)``,
        ``spatial_axis=(1, 2, 3)``) look chosen for inputs carrying two
        leading singleton axes, as built by the script in this file --
        confirm before reusing the pipeline with another layout.
        """

        def __init__(self):
            # All transforms tolerate a missing key, so a dictionary holding
            # only "image" (or only "label") can still be processed.
            self.random_rotated = Compose([
                Rotate90d(KEYS, k=1, spatial_axes=(2,3),allow_missing_keys=True),
                Flipd(KEYS, spatial_axis=(1,2,3),allow_missing_keys=True),
                # Intensity-only augmentation: restricted to the image entry.
                RandHistogramShiftD(("image",), prob=1, num_control_points=30, allow_missing_keys=True),
            ])

        def forward(self,x):
            """Run the pipeline on the dictionary ``x`` and return the result."""
            return self.random_rotated(x)
    
    # write an augmented volume to disk, reusing the source image's geometry
    def save(before_x, after_x, new_path,new_name=""):
        """Write ``after_x`` to ``new_path`` with the geometry of ``before_x``.

        Args:
            before_x: path of the reference image; it is read only to obtain
                its direction, origin and spacing.
            after_x: array-like indexed as ``after_x[0, 0, ...]`` to drop its
                two leading axes before conversion to a SimpleITK image.
            new_path: destination file path.
            new_name: ``"image"`` reads the reference as ``sitkInt16``; any
                other value reads it as ``sitkUInt8``.
        """
        volume = after_x[0, 0, ...]

        # The pixel type only affects how the reference file is loaded; the
        # written volume keeps whatever dtype ``after_x`` carries.
        pixel_type = sitk.sitkInt16 if new_name == "image" else sitk.sitkUInt8
        reference = sitk.ReadImage(before_x, pixel_type)

        output = sitk.GetImageFromArray(volume)
        output.SetDirection(reference.GetDirection())
        output.SetOrigin(reference.GetOrigin())
        output.SetSpacing(reference.GetSpacing())

        sitk.WriteImage(output, new_path)
    
    
    if __name__ == "__main__":
        # Demo: augment one image/label pair, save both results, then preview
        # one slice of the original image and of the augmented pair.
        image = r"D:\MyData\3Dircadb1_fusion_date\image_2.nii"   # original image
        label = r"D:\MyData\3Dircadb1_fusion_date\liver_2.nii"   # label
        new_path = r"D:\MyData\3Dircadb1_fusion_date\image_0.nii"  # augmented image
        new_path1 = r"D:\MyData\3Dircadb1_fusion_date\liver_1.nii"  # augmented label

        ct1 = sitk.GetArrayFromImage(sitk.ReadImage(image))
        seg1 = sitk.GetArrayFromImage(sitk.ReadImage(label))

        # Prepend two singleton axes; `save` strips them again with [0, 0, ...].
        m = {"image": torch.from_numpy(ct1[None, None, ...]),
             "label": torch.from_numpy(seg1[None, None, ...])}
        augs = aug()
        print(m["image"].shape)
        data_dict = augs.forward(m)

        save(image, data_dict["image"], new_path, "image")
        save(label, data_dict["label"], new_path1, "label")

        print(data_dict["image"].shape)

        # Preview slice 66 when it exists; otherwise fall back to the last
        # slice so volumes with fewer than 67 slices do not raise IndexError.
        preferred_slice = 66
        before_idx = min(preferred_slice, ct1.shape[0] - 1)
        after_img_idx = min(preferred_slice, data_dict["image"].shape[2] - 1)
        after_lab_idx = min(preferred_slice, data_dict["label"].shape[2] - 1)

        fig, axes = plt.subplots(1, 3)
        axes[0].imshow(ct1[before_idx, ...])
        axes[1].imshow(data_dict["image"][0, 0, after_img_idx, ...])
        axes[2].imshow(data_dict["label"][0, 0, after_lab_idx, ...])
        plt.show()
  • 相关阅读:
    一步步学习SPD2010--第十一章节--处理母版页(10)--重置母版页到网站定义
    pandas转numpy并打平实例
    list和numpy互相转换
    pandas转numpy
    pandas库疑难问题---2、pandas切片操作
    pandas切片操作
    pandas中的iloc和loc用法的区别
    NumPy疑难问题---1、NumPy切片操作
    numpy切片操作
    python疑难问题---13、Python切片操作
  • 原文地址:https://www.cnblogs.com/peixu/p/16005792.html
Copyright © 2020-2023  润新知