mirror of https://github.com/JDAI-CV/fast-reid.git
30 lines
746 B
Python
30 lines
746 B
Python
# -*- coding: utf-8 -*-
# File : unittest.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
|
|
|
|
import unittest
|
|
import torch
|
|
|
|
|
|
class TorchTestCase(unittest.TestCase):
    """``unittest.TestCase`` with an extra assertion for comparing tensors."""

    def assertTensorClose(self, x, y, rtol=1e-05, atol=1e-08):
        """Fail unless ``torch.allclose(x, y, rtol=rtol, atol=atol)`` holds.

        On failure the assertion message reports:

        * ``adiff`` -- the largest element-wise absolute difference, and
        * ``rdiff`` -- the largest element-wise relative difference
          ``|x - y| / |y|`` taken over the positions where ``y`` is non-zero,
          or the string ``'NaN'`` when ``y`` is zero everywhere.

        Args:
            x: Tensor under test.
            y: Reference tensor; must be broadcastable against ``x``.
            rtol: Relative tolerance forwarded to ``torch.allclose``.
            atol: Absolute tolerance forwarded to ``torch.allclose``.

        Raises:
            AssertionError: If the tensors are not close within tolerance.
        """
        # Check first: the diagnostics below are only needed on failure, and
        # reductions such as ``max()`` are undefined for empty tensors.
        if torch.allclose(x, y, rtol=rtol, atol=atol):
            return

        diff = (x - y).abs()
        adiff = float(diff.max())

        # Compare each difference with its own reference magnitude and skip
        # zero references, so a single zero in ``y`` cannot turn the reported
        # relative error into ``inf``.
        denom = y.abs().expand_as(diff)
        nonzero = denom != 0
        if nonzero.any():
            rdiff = float((diff[nonzero] / denom[nonzero]).max())
        else:
            rdiff = 'NaN'

        message = (
            'Tensor close check failed\n'
            'adiff={}\n'
            'rdiff={}\n'
        ).format(adiff, rdiff)
        self.fail(message)
|
|
|