import os
import warnings

import numpy as np
import torch
from PIL import Image
from pycocotools import mask as maskUtils
from pycocotools.coco import COCO
from torch.utils.data import Dataset

class COCODataset(Dataset):
    """Map-style dataset that eagerly loads one COCO-format split for Mask R-CNN.

    Reads ``<root>/annotations/{train,val}_annotations.json`` and the images
    under ``<root>/{train,val}/``. Every image and its target dict are loaded
    into memory in ``__init__``.

    Each item is ``(image, target)``. ``image`` is an RGB PIL image, or the
    output of ``transform`` if one was given. ``target`` holds:

    - ``image_id``: int64 tensor of shape (1,)
    - ``boxes``: float32 tensor of shape (N, 4), ``[x1, y1, x2, y2]`` in pixels
    - ``labels``: int64 tensor of shape (N,) holding the raw COCO ``category_id``
    - ``masks``: float32 tensor of shape (N, H, W) with values in {0, 1}
    """

    def __init__(self, root, train, transform=None):
        super().__init__()
        directory = "train" if train else "val"
        annotations = os.path.join(root, "annotations", f"{directory}_annotations.json")

        self.coco = COCO(annotations)
        self.image_path = os.path.join(root, directory)
        # Misspelled alias kept so any external code reading ``iamge_path`` still works.
        self.iamge_path = self.image_path
        self.transform = transform

        self.categories = self._get_categories()
        self.data = self._load_data()

    def _get_categories(self):
        """Return a ``{category_id: name}`` map with id 0 reserved for background."""
        categories = {0: "background"}
        for category in self.coco.cats.values():
            categories[category["id"]] = category["name"]
        return categories

    def _load_data(self):
        """Load every image of the split together with its target dict."""
        data = []
        for _id in self.coco.imgs:
            file_name = self.coco.loadImgs(_id)[0]["file_name"]
            image_path = os.path.join(self.image_path, file_name)
            # ``convert`` returns an independent image, so the file can be closed right away.
            with Image.open(image_path) as img:
                image = img.convert("RGB")
            width, height = image.size

            boxes, labels, masks = [], [], []
            anns = self.coco.loadAnns(self.coco.getAnnIds(_id))
            for ann in anns:
                x, y, w, h = ann["bbox"]
                # torchvision detection models reject boxes with non-positive size in training.
                if w <= 0 or h <= 0:
                    continue
                try:
                    mask = self._polygon_to_mask(ann["segmentation"], width, height)
                except Exception as err:
                    # Drop the whole annotation so boxes, labels and masks stay index-aligned.
                    warnings.warn(
                        f"Skipping annotation {ann.get('id')} of image {_id}: {err}"
                    )
                    continue

                boxes.append([x, y, x + w, y + h])
                labels.append(ann["category_id"])
                masks.append(mask)

            target = self._build_target(_id, boxes, labels, masks, width, height)
            data.append([image, target])
        return data

    @staticmethod
    def _build_target(image_id, boxes, labels, masks, width, height):
        """Pack one image's annotations into the tensor dict Mask R-CNN consumes.

        An image without annotations still gets correctly shaped tensors:
        boxes (0, 4), labels (0,), masks (0, height, width).
        """
        if masks:
            mask_array = np.stack(masks).astype(np.float32)
        else:
            mask_array = np.zeros((0, height, width), dtype=np.float32)
        return {
            "image_id": torch.tensor([image_id], dtype=torch.int64),
            # reshape keeps the (N, 4) shape even when there are no boxes.
            "boxes": torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            "labels": torch.as_tensor(labels, dtype=torch.int64),
            "masks": torch.from_numpy(mask_array),
        }

    def _polygon_to_mask(self, segmentations, width, height):
        """Rasterize a COCO ``segmentation`` field into a binary (height, width) uint8 mask.

        Accepts a list of polygons or an RLE dict (compressed or uncompressed),
        mirroring ``COCO.annToRLE``. Multiple polygons are merged by union, so
        overlapping regions stay 1 instead of being summed.
        """
        if isinstance(segmentations, dict):
            if isinstance(segmentations["counts"], list):
                # Uncompressed RLE (typically iscrowd=1) must be compressed first.
                rle = maskUtils.frPyObjects(segmentations, height, width)
            else:
                rle = segmentations
        elif not segmentations:
            return np.zeros((height, width), dtype=np.uint8)
        else:
            rle = maskUtils.merge(maskUtils.frPyObjects(segmentations, height, width))
        return maskUtils.decode(rle)

    def __getitem__(self, index):
        """Return ``(image, target)``, applying ``transform`` to the image if set."""
        image, target = self.data[index]
        if self.transform:
            image = self.transform(image)
        return image, target

    def __len__(self):
        """Return the number of images in the split."""
        return len(self.data)
from torchvision import transforms
from torch.utils.data import DataLoader

def collator(batch):
    """Transpose a list of ``(image, target)`` samples into ``(images, targets)``.

    Detection samples may differ in image size and carry per-image target
    dicts, so each field is gathered into a tuple instead of being stacked.
    """
    columns = zip(*batch)
    return tuple(column for column in columns)

# PIL image -> uint8 tensor (C, H, W) -> float tensor scaled to [0, 1].
transform = transforms.Compose(
    [
        transforms.PILToTensor(),
        transforms.ConvertImageDtype(dtype=torch.float),
    ]
)

DATA_ROOT = "../datasets/coco"
train_dataset = COCODataset(DATA_ROOT, train=True, transform=transform)
test_dataset = COCODataset(DATA_ROOT, train=False, transform=transform)

train_dataloader = DataLoader(
    train_dataset, batch_size=4, shuffle=True, drop_last=True, collate_fn=collator
)
# Evaluation must see every image exactly once in a reproducible order:
# no shuffling, and never drop a trailing partial batch.
test_dataloader = DataLoader(
    test_dataset, batch_size=1, shuffle=False, drop_last=False, collate_fn=collator
)
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

# COCODataset uses raw COCO ``category_id`` values as labels, so the heads need
# one output per id up to the largest id seen in either split (id 0 = background).
num_classes = max({*train_dataset.categories, *test_dataset.categories}) + 1
hidden_layer = 256
device = "cuda" if torch.cuda.is_available() else "cpu"
model = maskrcnn_resnet50_fpn(weights="DEFAULT")

# Replace the pretrained box and mask heads with ones sized for this dataset.
model.roi_heads.box_predictor = FastRCNNPredictor(
    in_channels=model.roi_heads.box_predictor.cls_score.in_features,
    num_classes=num_classes,
)
model.roi_heads.mask_predictor = MaskRCNNPredictor(
    in_channels=model.roi_heads.mask_predictor.conv5_mask.in_channels,
    dim_reduced=hidden_layer,
    num_classes=num_classes,
)
model.to(device)
from torch import optim

params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
for epoch in range(5):
    # The detection model returns a loss dict only in training mode; the
    # visualization/evaluation code calls model.eval(), so re-enable training
    # every epoch (matters when cells are re-run).
    model.train()
    cost = 0.0
    for images, targets in train_dataloader:
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        # .item() yields a plain float, so the running sum does not keep a
        # reference to every batch's autograd graph.
        cost += losses.item()

    lr_scheduler.step()
    cost = cost / max(len(train_dataloader), 1)
    print(f"Epoch : {epoch+1:4d}, Cost : {cost:.3f}")
import numpy as np
from matplotlib import pyplot as plt
from torchvision.transforms.functional import to_pil_image

def draw_bbox(ax, box, text, color, mask):
    """Draw one instance (box outline, caption and translucent mask) onto ``ax``.

    Args:
        ax: Matplotlib axes the image is displayed on.
        box: ``[x1, y1, x2, y2]`` corner coordinates in pixels.
        text: Caption placed just above and left of the top-left corner.
        color: Edge/text color; "blue" and "red" select a matching mask
            colormap, any other value falls back to "Greens".
        mask: 2-D array, presumably (H, W) matching the displayed image;
            pixels equal to 0 are masked out before coloring.
    """
    ax.add_patch(
        plt.Rectangle(
            xy=(box[0], box[1]),
            width=box[2] - box[0],
            height=box[3] - box[1],
            fill=False,
            edgecolor=color,
            linewidth=2,
        )
    )
    ax.annotate(
        text=text,
        xy=(box[0] - 5, box[1] - 5),
        color=color,
        weight="bold",
        fontsize=13,
    )

    # Mask out background pixels so only the instance region is tinted.
    mask = np.ma.masked_where(mask == 0, mask)
    mask_color = {"blue": "Blues", "red": "Reds"}

    # plt.get_cmap instead of plt.cm.get_cmap, which was deprecated in
    # Matplotlib 3.7 and removed in 3.9.
    cmap = plt.get_cmap(mask_color.get(color, "Greens"))
    norm = plt.Normalize(vmin=0, vmax=1)
    rgba = cmap(norm(mask))
    ax.imshow(rgba, interpolation="nearest", alpha=0.3)

threshold = 0.5
categories = test_dataset.categories

with torch.no_grad():
    model.eval()
    for images, targets in test_dataloader:
        images = [image.to(device) for image in images]
        outputs = model(images)
        output = outputs[0]

        scores = output["scores"].to("cpu").numpy()
        # Use one boolean index for boxes, labels, scores AND masks so they stay aligned.
        keep = scores >= threshold
        boxes = output["boxes"].to("cpu").numpy()[keep].astype(np.int32)
        labels = output["labels"].to("cpu").numpy()[keep]
        scores = scores[keep]
        # Soft masks (N, 1, H, W) -> binary (N, H, W) float masks for display.
        masks = output["masks"].squeeze(1).to("cpu").numpy()[keep]
        masks = (masks >= threshold).astype(np.float32)

        fig = plt.figure(figsize=(8, 8))
        ax = fig.add_subplot(1, 1, 1)
        ax.imshow(to_pil_image(images[0]))

        # Predictions in red.
        for box, mask, label, score in zip(boxes, masks, labels, scores):
            name = categories.get(label, str(label))
            draw_bbox(ax, box, f"{name} - {score:.4f}", "red", mask)

        # Ground truth in blue.
        tboxes = targets[0]["boxes"].numpy()
        tmask = targets[0]["masks"].numpy()
        tlabels = targets[0]["labels"].numpy()

        for box, mask, label in zip(tboxes, tmask, tlabels):
            name = categories.get(label, str(label))
            draw_bbox(ax, box, f"{name}", "blue", mask)

        plt.show()
        # Release the figure; otherwise one figure per test image stays open.
        plt.close(fig)
import numpy as np
from pycocotools.cocoeval import COCOeval

coco_detections = []
with torch.no_grad():
    model.eval()
    for images, targets in test_dataloader:
        images = [img.to(device) for img in images]
        outputs = model(images)

        for target, output in zip(targets, outputs):
            image_id = int(target["image_id"].item())
            scores = output["scores"].cpu().numpy()
            labels = output["labels"].cpu().numpy()
            # Soft masks (N, 1, H, W) -> binary (N, H, W) uint8 for RLE encoding.
            masks = (output["masks"].squeeze(1).cpu().numpy() > 0.5).astype(np.uint8)

            for mask, label, score in zip(masks, labels, scores):
                rle = maskUtils.encode(np.asfortranarray(mask))
                # encode() returns bytes counts; a str keeps results JSON-serializable
                # and pycocotools accepts it as well.
                rle["counts"] = rle["counts"].decode("utf-8")
                # No "bbox" key on purpose: with one, loadRes computes "area" from
                # the box instead of the mask, skewing the segm area-range metrics.
                coco_detections.append(
                    {
                        "image_id": image_id,
                        "category_id": int(label),
                        "score": float(score),
                        "segmentation": rle,
                    }
                )

coco_gt = test_dataloader.dataset.coco
if not coco_detections:
    # loadRes reads anns[0] and would raise IndexError on an empty result list.
    print("No detections produced; skipping COCO segmentation evaluation.")
else:
    coco_dt = coco_gt.loadRes(coco_detections)
    coco_evaluator = COCOeval(coco_gt, coco_dt, iouType="segm")
    coco_evaluator.evaluate()
    coco_evaluator.accumulate()
    coco_evaluator.summarize()