[Refactor] refactor msic files
parent
766fa72533
commit
dfaa8215ae
|
@ -2,6 +2,7 @@
|
|||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
**/*.pyc
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
@ -103,22 +104,16 @@ venv.bak/
|
|||
# mypy
|
||||
.mypy_cache/
|
||||
|
||||
openselfsup/version.py
|
||||
version.py
|
||||
data
|
||||
# custom
|
||||
/data
|
||||
.vscode
|
||||
.idea
|
||||
|
||||
# custom
|
||||
*.pkl
|
||||
*.pkl.json
|
||||
*.log.json
|
||||
work_dirs/
|
||||
/mmselfsup/.mim
|
||||
pretrains
|
||||
|
||||
# Pytorch
|
||||
*.pth
|
||||
|
||||
*.swp
|
||||
source.sh
|
||||
tensorboard.sh
|
||||
|
@ -126,3 +121,6 @@ tensorboard.sh
|
|||
replace.sh
|
||||
benchmarks/detection/datasets
|
||||
benchmarks/detection/output
|
||||
|
||||
# Pytorch
|
||||
*.pth
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
exclude: ^tests/data/
|
||||
repos:
|
||||
- repo: https://gitlab.com/pycqa/flake8.git
|
||||
rev: 3.8.3
|
||||
hooks:
|
||||
- id: flake8
|
||||
- repo: https://github.com/asottile/seed-isort-config
|
||||
rev: v2.2.0
|
||||
hooks:
|
||||
- id: seed-isort-config
|
||||
- repo: https://github.com/timothycrosley/isort
|
||||
rev: 4.3.21
|
||||
hooks:
|
||||
- id: isort
|
||||
- repo: https://github.com/pre-commit/mirrors-yapf
|
||||
rev: v0.30.0
|
||||
hooks:
|
||||
- id: yapf
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v3.1.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: check-yaml
|
||||
exclude: configs/benchmarks/detectron2/Base-RetinaNet.yaml
|
||||
- id: end-of-file-fixer
|
||||
- id: requirements-txt-fixer
|
||||
- id: double-quote-string-fixer
|
||||
- id: check-merge-conflict
|
||||
- id: fix-encoding-pragma
|
||||
args: ["--remove"]
|
||||
- id: mixed-line-ending
|
||||
args: ["--fix=lf"]
|
||||
- repo: https://github.com/markdownlint/markdownlint
|
||||
rev: v0.11.0
|
||||
hooks:
|
||||
- id: markdownlint
|
||||
args: ["-r", "~MD002,~MD013,~MD024,~MD029,~MD033,~MD034,~MD036", "-t", "allow_different_nesting"]
|
||||
- repo: https://github.com/codespell-project/codespell
|
||||
rev: v2.1.0
|
||||
hooks:
|
||||
- id: codespell
|
||||
- repo: https://github.com/myint/docformatter
|
||||
rev: v1.3.1
|
||||
hooks:
|
||||
- id: docformatter
|
||||
args: ["--in-place", "--wrap-descriptions", "79"]
|
|
@ -0,0 +1,9 @@
|
|||
version: 2
|
||||
|
||||
formats: all
|
||||
|
||||
python:
|
||||
version: 3.7
|
||||
install:
|
||||
- requirements: requirements/docs.txt
|
||||
- requirements: requirements/readthedocs.txt
|
|
@ -1,59 +0,0 @@
|
|||
_base_ = '../../base.py'
|
||||
# model settings
|
||||
model = dict(
|
||||
type='Classification',
|
||||
pretrained=None,
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
out_indices=[4], # 4: stage-4
|
||||
norm_cfg=dict(type='BN')),
|
||||
head=dict(
|
||||
type='ClsHead', with_avg_pool=True, in_channels=2048, num_classes=10))
|
||||
# dataset settings
|
||||
data_source_cfg = dict(type='Cifar10', root='data/cifar/')
|
||||
dataset_type = 'ClassificationDataset'
|
||||
img_norm_cfg = dict(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.201])
|
||||
train_pipeline = [
|
||||
dict(type='RandomCrop', size=32, padding=4),
|
||||
dict(type='RandomHorizontalFlip'),
|
||||
dict(type='ToTensor'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='ToTensor'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
]
|
||||
data = dict(
|
||||
imgs_per_gpu=128,
|
||||
workers_per_gpu=2,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
data_source=dict(split='train', **data_source_cfg),
|
||||
pipeline=train_pipeline),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_source=dict(split='test', **data_source_cfg),
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
type=dataset_type,
|
||||
data_source=dict(split='test', **data_source_cfg),
|
||||
pipeline=test_pipeline))
|
||||
# additional hooks
|
||||
custom_hooks = [
|
||||
dict(
|
||||
type='ValidateHook',
|
||||
dataset=data['val'],
|
||||
initial=True,
|
||||
interval=10,
|
||||
imgs_per_gpu=128,
|
||||
workers_per_gpu=8,
|
||||
eval_param=dict(topk=(1, 5)))
|
||||
]
|
||||
# optimizer
|
||||
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0005)
|
||||
# learning policy
|
||||
lr_config = dict(policy='step', step=[150, 250])
|
||||
checkpoint_config = dict(interval=50)
|
||||
# runtime settings
|
||||
total_epochs = 350
|
|
@ -1,68 +0,0 @@
|
|||
_base_ = '../../base.py'
|
||||
# model settings
|
||||
model = dict(
|
||||
type='Classification',
|
||||
pretrained=None,
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
out_indices=[4], # 0: conv-1, x: stage-x
|
||||
norm_cfg=dict(type='SyncBN')),
|
||||
head=dict(
|
||||
type='ClsHead', with_avg_pool=True, in_channels=2048,
|
||||
num_classes=1000))
|
||||
# dataset settings
|
||||
data_source_cfg = dict(
|
||||
type='ImageNet',
|
||||
memcached=True,
|
||||
mclient_path='/mnt/lustre/share/memcached_client')
|
||||
data_train_list = 'data/imagenet/meta/train_labeled.txt'
|
||||
data_train_root = 'data/imagenet/train'
|
||||
data_test_list = 'data/imagenet/meta/val_labeled.txt'
|
||||
data_test_root = 'data/imagenet/val'
|
||||
dataset_type = 'ClassificationDataset'
|
||||
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
train_pipeline = [
|
||||
dict(type='RandomResizedCrop', size=224),
|
||||
dict(type='RandomHorizontalFlip'),
|
||||
dict(type='ToTensor'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='Resize', size=256),
|
||||
dict(type='CenterCrop', size=224),
|
||||
dict(type='ToTensor'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
]
|
||||
data = dict(
|
||||
imgs_per_gpu=32, # total 256
|
||||
workers_per_gpu=2,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
data_source=dict(
|
||||
list_file=data_train_list, root=data_train_root,
|
||||
**data_source_cfg),
|
||||
pipeline=train_pipeline),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_source=dict(
|
||||
list_file=data_test_list, root=data_test_root, **data_source_cfg),
|
||||
pipeline=test_pipeline))
|
||||
# additional hooks
|
||||
custom_hooks = [
|
||||
dict(
|
||||
type='ValidateHook',
|
||||
dataset=data['val'],
|
||||
initial=True,
|
||||
interval=10,
|
||||
imgs_per_gpu=32,
|
||||
workers_per_gpu=2,
|
||||
eval_param=dict(topk=(1, 5)))
|
||||
]
|
||||
# optimizer
|
||||
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
|
||||
# learning policy
|
||||
lr_config = dict(policy='step', step=[30, 60, 90])
|
||||
checkpoint_config = dict(interval=10)
|
||||
# runtime settings
|
||||
total_epochs = 90
|
|
@ -1,3 +1,60 @@
|
|||
from .version import __version__, short_version
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
__all__ = ['__version__', 'short_version']
|
||||
import mmcv
|
||||
from packaging.version import parse
|
||||
|
||||
from .version import __version__
|
||||
|
||||
|
||||
def digit_version(version_str: str, length: int = 4):
|
||||
"""Convert a version string into a tuple of integers.
|
||||
|
||||
This method is usually used for comparing two versions. For pre-release
|
||||
versions: alpha < beta < rc.
|
||||
|
||||
Args:
|
||||
version_str (str): The version string.
|
||||
length (int): The maximum number of version levels. Defaults to 4.
|
||||
|
||||
Returns:
|
||||
tuple[int]: The version info in digits (integers).
|
||||
"""
|
||||
version = parse(version_str)
|
||||
assert version.release, f'failed to parse version {version_str}'
|
||||
release = list(version.release)
|
||||
release = release[:length]
|
||||
if len(release) < length:
|
||||
release = release + [0] * (length - len(release))
|
||||
if version.is_prerelease:
|
||||
mapping = {'a': -3, 'b': -2, 'rc': -1}
|
||||
val = -4
|
||||
# version.pre can be None
|
||||
if version.pre:
|
||||
if version.pre[0] not in mapping:
|
||||
warnings.warn(f'unknown prerelease version {version.pre[0]}, '
|
||||
'version checking may go wrong')
|
||||
else:
|
||||
val = mapping[version.pre[0]]
|
||||
release.extend([val, version.pre[-1]])
|
||||
else:
|
||||
release.extend([val, 0])
|
||||
|
||||
elif version.is_postrelease:
|
||||
release.extend([1, version.post])
|
||||
else:
|
||||
release.extend([0, 0])
|
||||
return tuple(release)
|
||||
|
||||
|
||||
mmcv_minimum_version = '1.3.16'
|
||||
mmcv_maximum_version = '1.5.0'
|
||||
mmcv_version = digit_version(mmcv.__version__)
|
||||
|
||||
|
||||
assert (mmcv_version >= digit_version(mmcv_minimum_version)
|
||||
and mmcv_version <= digit_version(mmcv_maximum_version)), \
|
||||
f'MMCV=={mmcv.__version__} is used but incompatible. ' \
|
||||
f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.'
|
||||
|
||||
__all__ = ['__version__', 'digit_version']
|
||||
|
|
|
@ -1,309 +0,0 @@
|
|||
# This file is modified from
|
||||
# https://github.com/facebookresearch/deepcluster/blob/master/clustering.py
|
||||
|
||||
import time
|
||||
import numpy as np
|
||||
import faiss
|
||||
import torch
|
||||
from scipy.sparse import csr_matrix
|
||||
|
||||
__all__ = ['Kmeans', 'PIC']
|
||||
|
||||
|
||||
def preprocess_features(npdata, pca):
|
||||
"""Preprocess an array of features.
|
||||
Args:
|
||||
npdata (np.array N * ndim): features to preprocess
|
||||
pca (int): dim of output
|
||||
Returns:
|
||||
np.array of dim N * pca: data PCA-reduced, whitened and L2-normalized
|
||||
"""
|
||||
_, ndim = npdata.shape
|
||||
#npdata = npdata.astype('float32')
|
||||
assert npdata.dtype == np.float32
|
||||
|
||||
if np.any(np.isnan(npdata)):
|
||||
raise Exception("nan occurs")
|
||||
if pca != -1:
|
||||
print("\nPCA from dim {} to dim {}".format(ndim, pca))
|
||||
mat = faiss.PCAMatrix(ndim, pca, eigen_power=-0.5)
|
||||
mat.train(npdata)
|
||||
assert mat.is_trained
|
||||
npdata = mat.apply_py(npdata)
|
||||
if np.any(np.isnan(npdata)):
|
||||
percent = np.isnan(npdata).sum().item() / float(np.size(npdata)) * 100
|
||||
if percent > 0.1:
|
||||
raise Exception(
|
||||
"More than 0.1% nan occurs after pca, percent: {}%".format(
|
||||
percent))
|
||||
else:
|
||||
npdata[np.isnan(npdata)] = 0.
|
||||
# L2 normalization
|
||||
row_sums = np.linalg.norm(npdata, axis=1)
|
||||
|
||||
npdata = npdata / (row_sums[:, np.newaxis] + 1e-10)
|
||||
|
||||
return npdata
|
||||
|
||||
|
||||
def make_graph(xb, nnn):
|
||||
"""Builds a graph of nearest neighbors.
|
||||
Args:
|
||||
xb (np.array): data
|
||||
nnn (int): number of nearest neighbors
|
||||
Returns:
|
||||
list: for each data the list of ids to its nnn nearest neighbors
|
||||
list: for each data the list of distances to its nnn NN
|
||||
"""
|
||||
N, dim = xb.shape
|
||||
|
||||
# we need only a StandardGpuResources per GPU
|
||||
res = faiss.StandardGpuResources()
|
||||
|
||||
# L2
|
||||
flat_config = faiss.GpuIndexFlatConfig()
|
||||
flat_config.device = int(torch.cuda.device_count()) - 1
|
||||
index = faiss.GpuIndexFlatL2(res, dim, flat_config)
|
||||
index.add(xb)
|
||||
D, I = index.search(xb, nnn + 1)
|
||||
return I, D
|
||||
|
||||
|
||||
def run_kmeans(x, nmb_clusters, verbose=False):
|
||||
"""Runs kmeans on 1 GPU.
|
||||
Args:
|
||||
x: data
|
||||
nmb_clusters (int): number of clusters
|
||||
Returns:
|
||||
list: ids of data in each cluster
|
||||
"""
|
||||
n_data, d = x.shape
|
||||
|
||||
# faiss implementation of k-means
|
||||
clus = faiss.Clustering(d, nmb_clusters)
|
||||
|
||||
# Change faiss seed at each k-means so that the randomly picked
|
||||
# initialization centroids do not correspond to the same feature ids
|
||||
# from an epoch to another.
|
||||
clus.seed = np.random.randint(1234)
|
||||
|
||||
clus.niter = 20
|
||||
clus.max_points_per_centroid = 10000000
|
||||
res = faiss.StandardGpuResources()
|
||||
flat_config = faiss.GpuIndexFlatConfig()
|
||||
flat_config.useFloat16 = False
|
||||
flat_config.device = 0
|
||||
index = faiss.GpuIndexFlatL2(res, d, flat_config)
|
||||
|
||||
# perform the training
|
||||
clus.train(x, index)
|
||||
_, I = index.search(x, 1)
|
||||
losses = faiss.vector_to_array(clus.obj)
|
||||
if verbose:
|
||||
print('k-means loss evolution: {0}'.format(losses))
|
||||
|
||||
return [int(n[0]) for n in I], losses[-1]
|
||||
|
||||
|
||||
def arrange_clustering(images_lists):
|
||||
pseudolabels = []
|
||||
image_indexes = []
|
||||
for cluster, images in enumerate(images_lists):
|
||||
image_indexes.extend(images)
|
||||
pseudolabels.extend([cluster] * len(images))
|
||||
indexes = np.argsort(image_indexes)
|
||||
return np.asarray(pseudolabels)[indexes]
|
||||
|
||||
|
||||
class Kmeans:
|
||||
|
||||
def __init__(self, k, pca_dim=256):
|
||||
self.k = k
|
||||
self.pca_dim = pca_dim
|
||||
|
||||
def cluster(self, feat, verbose=False):
|
||||
"""Performs k-means clustering.
|
||||
Args:
|
||||
x_data (np.array N * dim): data to cluster
|
||||
"""
|
||||
end = time.time()
|
||||
|
||||
# PCA-reducing, whitening and L2-normalization
|
||||
xb = preprocess_features(feat, self.pca_dim)
|
||||
|
||||
# cluster the data
|
||||
I, loss = run_kmeans(xb, self.k, verbose)
|
||||
self.labels = np.array(I)
|
||||
if verbose:
|
||||
print('k-means time: {0:.0f} s'.format(time.time() - end))
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
def make_adjacencyW(I, D, sigma):
|
||||
"""Create adjacency matrix with a Gaussian kernel.
|
||||
Args:
|
||||
I (numpy array): for each vertex the ids to its nnn linked vertices
|
||||
+ first column of identity.
|
||||
D (numpy array): for each data the l2 distances to its nnn linked vertices
|
||||
+ first column of zeros.
|
||||
sigma (float): Bandwith of the Gaussian kernel.
|
||||
|
||||
Returns:
|
||||
csr_matrix: affinity matrix of the graph.
|
||||
"""
|
||||
V, k = I.shape
|
||||
k = k - 1
|
||||
indices = np.reshape(np.delete(I, 0, 1), (1, -1))
|
||||
indptr = np.multiply(k, np.arange(V + 1))
|
||||
|
||||
def exp_ker(d):
|
||||
return np.exp(-d / sigma**2)
|
||||
|
||||
exp_ker = np.vectorize(exp_ker)
|
||||
res_D = exp_ker(D)
|
||||
data = np.reshape(np.delete(res_D, 0, 1), (1, -1))
|
||||
adj_matrix = csr_matrix((data[0], indices[0], indptr), shape=(V, V))
|
||||
return adj_matrix
|
||||
|
||||
|
||||
def run_pic(I, D, sigma, alpha):
|
||||
"""Run PIC algorithm"""
|
||||
a = make_adjacencyW(I, D, sigma)
|
||||
graph = a + a.transpose()
|
||||
cgraph = graph
|
||||
nim = graph.shape[0]
|
||||
|
||||
W = graph
|
||||
t0 = time.time()
|
||||
|
||||
v0 = np.ones(nim) / nim
|
||||
|
||||
# power iterations
|
||||
v = v0.astype('float32')
|
||||
|
||||
t0 = time.time()
|
||||
dt = 0
|
||||
for i in range(200):
|
||||
vnext = np.zeros(nim, dtype='float32')
|
||||
|
||||
vnext = vnext + W.transpose().dot(v)
|
||||
|
||||
vnext = alpha * vnext + (1 - alpha) / nim
|
||||
# L1 normalize
|
||||
vnext /= vnext.sum()
|
||||
v = vnext
|
||||
|
||||
if (i == 200 - 1):
|
||||
clust = find_maxima_cluster(W, v)
|
||||
|
||||
return [int(i) for i in clust]
|
||||
|
||||
|
||||
def find_maxima_cluster(W, v):
|
||||
n, m = W.shape
|
||||
assert (n == m)
|
||||
assign = np.zeros(n)
|
||||
# for each node
|
||||
pointers = list(range(n))
|
||||
for i in range(n):
|
||||
best_vi = 0
|
||||
l0 = W.indptr[i]
|
||||
l1 = W.indptr[i + 1]
|
||||
for l in range(l0, l1):
|
||||
j = W.indices[l]
|
||||
vi = W.data[l] * (v[j] - v[i])
|
||||
if vi > best_vi:
|
||||
best_vi = vi
|
||||
pointers[i] = j
|
||||
n_clus = 0
|
||||
cluster_ids = -1 * np.ones(n)
|
||||
for i in range(n):
|
||||
if pointers[i] == i:
|
||||
cluster_ids[i] = n_clus
|
||||
n_clus = n_clus + 1
|
||||
for i in range(n):
|
||||
# go from pointers to pointers starting from i until reached a local optim
|
||||
current_node = i
|
||||
while pointers[current_node] != current_node:
|
||||
current_node = pointers[current_node]
|
||||
|
||||
assign[i] = cluster_ids[current_node]
|
||||
assert (assign[i] >= 0)
|
||||
return assign
|
||||
|
||||
|
||||
class PIC():
|
||||
"""Class to perform Power Iteration Clustering on a graph of nearest neighbors.
|
||||
Args:
|
||||
args: for consistency with k-means init
|
||||
sigma (float): bandwith of the Gaussian kernel (default 0.2)
|
||||
nnn (int): number of nearest neighbors (default 5)
|
||||
alpha (float): parameter in PIC (default 0.001)
|
||||
distribute_singletons (bool): If True, reassign each singleton to
|
||||
the cluster of its closest non
|
||||
singleton nearest neighbors (up to nnn
|
||||
nearest neighbors).
|
||||
Attributes:
|
||||
images_lists (list of list): for each cluster, the list of image indexes
|
||||
belonging to this cluster
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
args=None,
|
||||
sigma=0.2,
|
||||
nnn=5,
|
||||
alpha=0.001,
|
||||
distribute_singletons=True,
|
||||
pca_dim=256):
|
||||
self.sigma = sigma
|
||||
self.alpha = alpha
|
||||
self.nnn = nnn
|
||||
self.distribute_singletons = distribute_singletons
|
||||
self.pca_dim = pca_dim
|
||||
|
||||
def cluster(self, data, verbose=False):
|
||||
end = time.time()
|
||||
|
||||
# preprocess the data
|
||||
xb = preprocess_features(data, self.pca_dim)
|
||||
|
||||
# construct nnn graph
|
||||
I, D = make_graph(xb, self.nnn)
|
||||
|
||||
# run PIC
|
||||
clust = run_pic(I, D, self.sigma, self.alpha)
|
||||
images_lists = {}
|
||||
for h in set(clust):
|
||||
images_lists[h] = []
|
||||
for data, c in enumerate(clust):
|
||||
images_lists[c].append(data)
|
||||
|
||||
# allocate singletons to clusters of their closest NN not singleton
|
||||
if self.distribute_singletons:
|
||||
clust_NN = {}
|
||||
for i in images_lists:
|
||||
# if singleton
|
||||
if len(images_lists[i]) == 1:
|
||||
s = images_lists[i][0]
|
||||
# for NN
|
||||
for n in I[s, 1:]:
|
||||
# if NN is not a singleton
|
||||
if not len(images_lists[clust[n]]) == 1:
|
||||
clust_NN[s] = n
|
||||
break
|
||||
for s in clust_NN:
|
||||
del images_lists[clust[s]]
|
||||
clust[s] = clust[clust_NN[s]]
|
||||
images_lists[clust[s]].append(s)
|
||||
|
||||
self.images_lists = []
|
||||
self.labels = -1 * np.ones((data.shape[0], ), dtype=np.int)
|
||||
for i, c in enumerate(images_lists):
|
||||
self.images_lists.append(images_lists[c])
|
||||
self.labels[images_lists[c]] = i
|
||||
assert np.all(self.labels != -1)
|
||||
|
||||
if verbose:
|
||||
print('pic time: {0:.0f} s'.format(time.time() - end))
|
||||
return 0
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright (c) Open-MMLab. All rights reserved.
|
||||
|
||||
__version__ = '0.4.0'
|
||||
|
||||
|
||||
def parse_version_info(version_str):
|
||||
"""Parse a version string into a tuple.
|
||||
|
||||
Args:
|
||||
version_str (str): The version string.
|
||||
|
||||
Returns:
|
||||
tuple[int | str]: The version info, e.g., "1.3.0" is parsed into
|
||||
(1, 3, 0), and "2.0.0rc1" is parsed into (2, 0, 0, 'rc1').
|
||||
"""
|
||||
version_info = []
|
||||
for x in version_str.split('.'):
|
||||
if x.isdigit():
|
||||
version_info.append(int(x))
|
||||
elif x.find('rc') != -1:
|
||||
patch_version = x.split('rc')
|
||||
version_info.append(int(patch_version[0]))
|
||||
version_info.append(f'rc{patch_version[1]}')
|
||||
return tuple(version_info)
|
||||
|
||||
|
||||
version_info = parse_version_info(__version__)
|
||||
|
||||
__all__ = ['__version__', 'version_info', 'parse_version_info']
|
|
@ -0,0 +1,18 @@
|
|||
[yapf]
|
||||
based_on_style = pep8
|
||||
blank_line_before_nested_class_or_def = true
|
||||
split_before_expression_after_opening_paren = true
|
||||
|
||||
[isort]
|
||||
line_length = 79
|
||||
multi_line_output = 0
|
||||
known_standard_library = setuptools
|
||||
known_first_party = mmselfsup
|
||||
known_third_party = PIL,detectron2,faiss,matplotlib,mmcv,mmdet,numpy,packaging,pytest,scipy,seaborn,six,sklearn,svm_helper,torch,torchvision,tqdm
|
||||
no_lines_before = STDLIB,LOCALFOLDER
|
||||
default_section = THIRDPARTY
|
||||
|
||||
[codespell]
|
||||
skip = *.ipynb
|
||||
quiet-level = 3
|
||||
ignore-words-list = patten,confectionary,nd,ty,formating
|
151
setup.py
151
setup.py
|
@ -1,7 +1,10 @@
|
|||
#!/usr/bin/env python
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
import os.path as osp
|
||||
import shutil
|
||||
import sys
|
||||
import warnings
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
|
||||
|
@ -11,82 +14,15 @@ def readme():
|
|||
return content
|
||||
|
||||
|
||||
MAJOR = 0
|
||||
MINOR = 3
|
||||
PATCH = 0
|
||||
SUFFIX = ''
|
||||
if PATCH != '':
|
||||
SHORT_VERSION = '{}.{}.{}{}'.format(MAJOR, MINOR, PATCH, SUFFIX)
|
||||
else:
|
||||
SHORT_VERSION = '{}.{}{}'.format(MAJOR, MINOR, SUFFIX)
|
||||
|
||||
version_file = 'openselfsup/version.py'
|
||||
|
||||
|
||||
def get_git_hash():
|
||||
|
||||
def _minimal_ext_cmd(cmd):
|
||||
# construct minimal environment
|
||||
env = {}
|
||||
for k in ['SYSTEMROOT', 'PATH', 'HOME']:
|
||||
v = os.environ.get(k)
|
||||
if v is not None:
|
||||
env[k] = v
|
||||
# LANGUAGE is used on win32
|
||||
env['LANGUAGE'] = 'C'
|
||||
env['LANG'] = 'C'
|
||||
env['LC_ALL'] = 'C'
|
||||
out = subprocess.Popen(
|
||||
cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
|
||||
return out
|
||||
|
||||
try:
|
||||
out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
|
||||
sha = out.strip().decode('ascii')
|
||||
except OSError:
|
||||
sha = 'unknown'
|
||||
|
||||
return sha
|
||||
|
||||
|
||||
def get_hash():
|
||||
if os.path.exists('.git'):
|
||||
sha = get_git_hash()[:7]
|
||||
elif os.path.exists(version_file):
|
||||
try:
|
||||
from openselfsup.version import __version__
|
||||
sha = __version__.split('+')[-1]
|
||||
except ImportError:
|
||||
raise ImportError('Unable to get git version')
|
||||
else:
|
||||
sha = 'unknown'
|
||||
|
||||
return sha
|
||||
|
||||
|
||||
def write_version_py():
|
||||
content = """# GENERATED VERSION FILE
|
||||
# TIME: {}
|
||||
|
||||
__version__ = '{}'
|
||||
short_version = '{}'
|
||||
"""
|
||||
sha = get_hash()
|
||||
VERSION = SHORT_VERSION + '+' + sha
|
||||
|
||||
with open(version_file, 'w') as f:
|
||||
f.write(content.format(time.asctime(), VERSION, SHORT_VERSION))
|
||||
|
||||
|
||||
def get_version():
|
||||
with open(version_file, 'r') as f:
|
||||
version_file = 'mmselfsup/version.py'
|
||||
with open(version_file, 'r', encoding='utf-8') as f:
|
||||
exec(compile(f.read(), version_file, 'exec'))
|
||||
return locals()['__version__']
|
||||
|
||||
|
||||
def parse_requirements(fname='requirements.txt', with_version=True):
|
||||
"""
|
||||
Parse the package dependencies listed in a requirements file but strips
|
||||
"""Parse the package dependencies listed in a requirements file but strips
|
||||
specific versioning information.
|
||||
|
||||
Args:
|
||||
|
@ -99,15 +35,13 @@ def parse_requirements(fname='requirements.txt', with_version=True):
|
|||
CommandLine:
|
||||
python -c "import setup; print(setup.parse_requirements())"
|
||||
"""
|
||||
import re
|
||||
import sys
|
||||
from os.path import exists
|
||||
import re
|
||||
require_fpath = fname
|
||||
|
||||
def parse_line(line):
|
||||
"""
|
||||
Parse information from a line in a requirements text file
|
||||
"""
|
||||
"""Parse information from a line in a requirements text file."""
|
||||
if line.startswith('-r '):
|
||||
# Allow specifying requirements in other files
|
||||
target = line.split(' ')[1]
|
||||
|
@ -163,18 +97,68 @@ def parse_requirements(fname='requirements.txt', with_version=True):
|
|||
return packages
|
||||
|
||||
|
||||
def add_mim_extension():
|
||||
"""Add extra files that are required to support MIM into the package.
|
||||
|
||||
These files will be added by creating a symlink to the originals if the
|
||||
package is installed in `editable` mode (e.g. pip install -e .), or by
|
||||
copying from the originals otherwise.
|
||||
"""
|
||||
|
||||
# parse installment mode
|
||||
if 'develop' in sys.argv:
|
||||
# installed by `pip install -e .`
|
||||
mode = 'symlink'
|
||||
elif 'sdist' in sys.argv or 'bdist_wheel' in sys.argv:
|
||||
# installed by `pip install .`
|
||||
# or create source distribution by `python setup.py sdist`
|
||||
mode = 'copy'
|
||||
else:
|
||||
return
|
||||
|
||||
filenames = ['tools', 'configs', 'model-index.yml']
|
||||
repo_path = osp.dirname(__file__)
|
||||
mim_path = osp.join(repo_path, 'mmselfsup', '.mim')
|
||||
os.makedirs(mim_path, exist_ok=True)
|
||||
|
||||
for filename in filenames:
|
||||
if osp.exists(filename):
|
||||
src_path = osp.join(repo_path, filename)
|
||||
tar_path = osp.join(mim_path, filename)
|
||||
|
||||
if osp.isfile(tar_path) or osp.islink(tar_path):
|
||||
os.remove(tar_path)
|
||||
elif osp.isdir(tar_path):
|
||||
shutil.rmtree(tar_path)
|
||||
|
||||
if mode == 'symlink':
|
||||
src_relpath = osp.relpath(src_path, osp.dirname(tar_path))
|
||||
os.symlink(src_relpath, tar_path)
|
||||
elif mode == 'copy':
|
||||
if osp.isfile(src_path):
|
||||
shutil.copyfile(src_path, tar_path)
|
||||
elif osp.isdir(src_path):
|
||||
shutil.copytree(src_path, tar_path)
|
||||
else:
|
||||
warnings.warn(f'Cannot copy file {src_path}.')
|
||||
else:
|
||||
raise ValueError(f'Invalid mode {mode}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
write_version_py()
|
||||
add_mim_extension()
|
||||
setup(
|
||||
name='openselfsup',
|
||||
name='mmselfsup',
|
||||
version=get_version(),
|
||||
description='Self-Supervision Toolbox and Benchmark',
|
||||
description='OpenMMLab Self-Supervision Toolbox and Benchmark',
|
||||
long_description=readme(),
|
||||
author='Xiaohang Zhan',
|
||||
author_email='xiaohangzhan@outlook.com',
|
||||
long_description_content_type='text/markdown',
|
||||
author='MMSelfSup Contributors',
|
||||
author_email='openmmlab@gmail.com',
|
||||
keywords='unsupervised learning, self-supervised learning',
|
||||
url='https://github.com/open-mmlab/openselfsup',
|
||||
url='https://github.com/open-mmlab/mmselfsup',
|
||||
packages=find_packages(exclude=('configs', 'tools', 'demo')),
|
||||
include_package_data=True,
|
||||
classifiers=[
|
||||
'Development Status :: 4 - Beta',
|
||||
'License :: OSI Approved :: Apache Software License',
|
||||
|
@ -183,9 +167,10 @@ if __name__ == '__main__':
|
|||
'Programming Language :: Python :: 3.5',
|
||||
'Programming Language :: Python :: 3.6',
|
||||
'Programming Language :: Python :: 3.7',
|
||||
'Programming Language :: Python :: 3.8',
|
||||
'Programming Language :: Python :: 3.9',
|
||||
],
|
||||
license='Apache License 2.0',
|
||||
setup_requires=parse_requirements('requirements/build.txt'),
|
||||
tests_require=parse_requirements('requirements/tests.txt'),
|
||||
install_requires=parse_requirements('requirements/runtime.txt'),
|
||||
zip_safe=False)
|
||||
|
|
Loading…
Reference in New Issue