# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import json
import os
import tempfile
import time
import unittest
from mock import MagicMock
import torch
from torch import nn

from detectron2.engine import SimpleTrainer, hooks
from detectron2.utils.events import CommonMetricPrinter, JSONWriter


class SimpleModel(nn.Module):
    """Tiny model used to drive the trainer in tests.

    Its forward pass returns a dict with a single scalar ``"loss"`` that
    depends on both the input and the model's parameters, so an optimizer
    step has gradients to apply.
    """

    def __init__(self, sleep_sec=0):
        """
        Args:
            sleep_sec: seconds to sleep at the start of every forward pass
                (0 disables sleeping). Lets tests make iterations take a
                measurable amount of wall-clock time.
        """
        super().__init__()
        self.mod = nn.Linear(10, 20)
        self.sleep_sec = sleep_sec

    def forward(self, x):
        """Return ``{"loss": x.sum() + sum of the mean of every parameter}``."""
        if self.sleep_sec > 0:
            time.sleep(self.sleep_sec)
        # Use a distinct name for parameters so the input ``x`` is not
        # visually shadowed inside the same expression.
        param_term = sum(p.mean() for p in self.parameters())
        return {"loss": x.sum() + param_term}


class TestTrainer(unittest.TestCase):
    """Tests for ``SimpleTrainer`` and the metric-writer hooks."""

    def _data_loader(self, device):
        """Yield an endless stream of random 3x3 tensors placed on ``device``."""
        device = torch.device(device)
        while True:
            yield torch.rand(3, 3).to(device)

    def test_simple_trainer(self, device="cpu"):
        """Run ``SimpleTrainer`` for iterations 0..10 on ``device`` without error."""
        model = SimpleModel().to(device=device)
        trainer = SimpleTrainer(
            model, self._data_loader(device), torch.optim.SGD(model.parameters(), 0.1)
        )
        trainer.train(0, 10)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_simple_trainer_cuda(self):
        """Same as ``test_simple_trainer`` but on the CUDA device."""
        self.test_simple_trainer(device="cuda")

    def test_writer_hooks(self):
        """Check what ``JSONWriter`` and ``CommonMetricPrinter`` emit during training.

        Verifies the iterations recorded in the JSON metrics file, that the
        eval metric lands in the final JSON line, and the iterations / final
        ETA reported through the printer's logger.
        """
        from unittest import mock

        # Each forward pass sleeps 0.1 s so iterations take real wall-clock time.
        model = SimpleModel(sleep_sec=0.1)
        trainer = SimpleTrainer(
            model, self._data_loader("cpu"), torch.optim.SGD(model.parameters(), 0.1)
        )

        max_iter = 50

        with tempfile.TemporaryDirectory(prefix="detectron2_test") as d:
            json_file = os.path.join(d, "metrics.json")
            writers = [CommonMetricPrinter(max_iter), JSONWriter(json_file)]

            # Patch (instead of assigning over) the printer's ``logger.info`` so
            # the original method is restored on exit, even if training fails.
            # NOTE(review): if this logger is shared (e.g. module-level), the
            # previous unrestored assignment leaked the mock into other tests.
            with mock.patch.object(writers[0].logger, "info") as logger_info:
                trainer.register_hooks(
                    [hooks.EvalHook(0, lambda: {"metric": 100}), hooks.PeriodicWriter(writers)]
                )
                trainer.train(0, max_iter)

            with open(json_file, "r") as f:
                data = [json.loads(line.strip()) for line in f]
                self.assertEqual([x["iteration"] for x in data], [19, 39, 49, 50])
                # the eval metric is in the last line with iter 50
                self.assertIn("metric", data[-1], "Eval metric must be in last line of JSON!")

            # test logged messages from CommonMetricPrinter
            all_logs = [str(call) for call in logger_info.call_args_list]
            self.assertEqual(len(all_logs), 3)
            for log, expected_iter in zip(all_logs, [19, 39, 49]):
                self.assertIn(f"iter: {expected_iter}", log)

            self.assertIn("eta: 0:00:00", all_logs[-1], "Last ETA must be 0!")