Deep Learning

Yolo V5 Gridmask Augmentation

해시브라우니 2024. 2. 7. 06:14

Cutout augmentation에 이어서 Gridmask를 적용하는 것도 진행을 하였습니다.

Cutout은 코드가 적용되어 있는데에 반해 Gridmask는 따로 없어서 Github에서 긁어서 진행하였구요.

https://github.com/dvlab-research/GridMask

 

GitHub - dvlab-research/GridMask

Contribute to dvlab-research/GridMask development by creating an account on GitHub.

github.com

 

gridmask를 구현하는 python파일은 gridmask 파일 중 data/transform/grid.py에 구현되어 있습니다.

기존에 yolo에서 진행하는 Augmentation에서 dataloader.py에서 cutout을 적용할때 수정했던것을  바탕으로 어떤 방식으로 구현할까 고민을 했는데 ,, 1) Class형식으로 구현되어 있는 Gridmask을 사용하자 2) 함수를 만들어서 사진마다 사용해보자 였는데 우선 1번으로 진행을 하였습니다.

Class형식으로 구현되어 있는 Gridmask를 어떻게 적용했는지는 Albumetation Class를 참고하였습니다.

 

이게 Albumetation class인데 def __call__ 쪽과 Dataloader에서 어떻게 사용됐는지 확인해서 Gridmask에도 마찬가지로 적용을 하였습니다.

####################################
#Own customization GridMask
class Grid:
    def __init__(self, use_h, use_w, rotate = 1, offset=False, ratio = 0.5, mode=0, prob = 1.):
        self.use_h = use_h
        self.use_w = use_w
        self.rotate = rotate
        self.offset = offset
        self.ratio = ratio
        self.mode=mode
        self.st_prob = prob
        self.prob = prob

    def set_prob(self, epoch, max_epoch):
        self.prob = self.st_prob * epoch / max_epoch

    def __call__(self, img, label):
        if np.random.rand() > self.prob:
            return img, label
        h ,w = img.shape[0], img.shape[1]
        self.d1 = 2
        self.d2 = min(h, w)
        hh = int(1.5*h)
        ww = int(1.5*w)
        d = np.random.randint(self.d1, self.d2)
        #d = self.d
#        self.l = int(d*self.ratio+0.5)
        if self.ratio == 1:
            self.l = np.random.randint(1, d)
        else:
            self.l = min(max(int(d*self.ratio+0.5),1),d-1)
        mask = np.ones((hh, ww, 3), np.float32)
        st_h = np.random.randint(d)
        st_w = np.random.randint(d)
        if self.use_h:
            for i in range(hh//d):
                s = d*i + st_h
                t = min(s+self.l, hh)
                mask[s:t,:,:] *= 0
        if self.use_w:
            for i in range(ww//d):
                s = d*i + st_w
                t = min(s+self.l, ww)
                mask[:,s:t,:] *= 0
       
        r = np.random.randint(self.rotate)
        mask = Image.fromarray(np.uint8(mask))
        mask = mask.rotate(r)
        mask = np.asarray(mask)
#        mask = 1*(np.random.randint(0,3,[hh,ww])>0)
        mask = mask[(hh-h)//2:(hh-h)//2+h, (ww-w)//2:(ww-w)//2+w]

        # mask = torch.from_numpy(mask).float() ##에러 해결?
        if self.mode == 1:
            mask = 1-mask
        
        # mask = mask.expand_as(img)
        if self.offset:
            offset = torch.from_numpy(2 * (np.random.rand(h,w) - 0.5)).float()
            offset = (1 - mask) * offset
            img = img * mask + offset
        else:
            img = img * mask 

        return img, label

 

원본의 Gridmask class에서 긁어와서 일부분 수정하였습니다. 수정한 부분은 

mask = np.ones((hh, ww, 3), np.float32)
 
mask[s:t,:,:] *= 0
mask[:,s:t,:] *= 0
 
# mask = torch.from_numpy(mask).float() ##에러 해결?
# mask = mask.expand_as(img)
 

 

위 코드 입니다. 그냥 Gridmask 원본은 예상한 결과로는 처음부터 Image를 Tensor로 받아서 Size 함수, expand_as함수를 사용하는 것 같더라구요. 하지만 Yolo에서는 image를 numpy로 받기 때문에 size함수, expand_as함수 모두 호환이나 에러가 발생했습니다. 무턱대고 그런 생각을 안하고 구현해서 해당 부분에서 시간이 좀 걸렸네요 

발생한 에러들

 

그래서 그 부분을 해결해주고자 Numpy로 받는 차원 고려해줘서 슬라이싱이라든지 행렬을 만든다든지 진행하였고, tensor관련 코드들은 주석처리 해버렸습니다.

 

이제 Dataloader를 참고해서 적용을 해봅시다. 

Dataloader에 Albumentation과 유사하게 augmentation에 추가를 해줍시다.

 

그럼 이제 생성자를 만들어 줘야겠습니다. Albumentation도 self.albumentation으로 초기 값을 설정하였기 때문에, 원본 Gridmask의 default값이나 실제 config에서 사용되었던 값들로 맞추어줬구요. 

Grid가 어떻게 사용되는지? config파일은 오른쪽

 

맨 아랫줄이 config를 참고해서 적용한 코드입니다! 저것만 적어주고, __get item__ 부분만 수정하면 해결이 됩니다. 

 

실제 Albumentation을 그대로 따라한 것을 볼 수 있습니다. 

모든 준비가 다되면 COCO128개로 잘되는지 확인해봅니다.

 

 

Gridmask가 잘 적용되었습니다! 이번엔 Cutout가 별개로 Mosaic된 사진 전체에 Gridmask가 적용된 것을 확인할 수 있네요. Cutout은 빼먹는걸 까먹어서 같이 구현된것도있구요... 아마 앞으로는 구현한 Augmentation중 하나를 선택해서 실제 Training에 어떤 결과를 나타내는지 확인해 볼 생각입니다. Training하는데 걸리는 시간이 좀 빡세지만요.... 

Gridmask가 코드가 오래되어서 그런지 본 코드를 직접 돌리기엔 여러 conflict가 있는것 같던데 (Cuda 버전 굳이 내리기 싫기도하고..) 이렇게 yolo에 넣어서 확인해보는것도 하나의 방법이네요. 

 

혹시 새로이 Augmentation을 적용하실 분에게 조그만한 도움이라도 됐으면..