Merge pull request #1 from Clarifai/796

use mmengine config
pull/1952/head
Alan Yu 2024-09-25 09:52:40 -04:00 committed by GitHub
commit 1a984c05f6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 20 additions and 19 deletions

View File

@ -7,7 +7,7 @@ from typing import OrderedDict
import numpy as np
import torch
from mmcv import Config
from mmengine.config import Config
from mmcv.parallel import collate, scatter
from modelindex.load_model_index import load
from rich.console import Console

View File

@ -1113,7 +1113,7 @@
},
"source": [
"# Load the base config file\n",
"from mmcv import Config\n",
"from mmengine.config import Config\n",
"cfg = Config.fromfile('configs/mobilenet_v2/mobilenet-v2_8xb32_in1k.py')\n",
"\n",
"# Modify the number of classes in the head.\n",

View File

@ -1115,7 +1115,7 @@
},
"source": [
"# 载入已经存在的配置文件\n",
"from mmcv import Config\n",
"from mmengine.config import Config\n",
"cfg = Config.fromfile('configs/mobilenet_v2/mobilenet-v2_8xb32_in1k.py')\n",
"\n",
"# 修改模型分类头中的类别数目\n",

View File

@ -52,7 +52,7 @@ class ImageClassifier(BaseClassifier):
1. Backbone output
>>> import torch
>>> from mmcv import Config
>>> from mmengine.config import Config
>>> from mmcls.models import build_classifier
>>>
>>> cfg = Config.fromfile('configs/resnet/resnet18_8xb32_in1k.py').model
@ -69,7 +69,7 @@ class ImageClassifier(BaseClassifier):
2. Neck output
>>> import torch
>>> from mmcv import Config
>>> from mmengine.config import Config
>>> from mmcls.models import build_classifier
>>>
>>> cfg = Config.fromfile('configs/resnet/resnet18_8xb32_in1k.py').model
@ -87,7 +87,7 @@ class ImageClassifier(BaseClassifier):
3. Pre-logits output (without the final linear classifier head)
>>> import torch
>>> from mmcv import Config
>>> from mmengine.config import Config
>>> from mmcls.models import build_classifier
>>>
>>> cfg = Config.fromfile('configs/vision_transformer/vit-base-p16_pt-64xb64_in1k-224.py').model

View File

@ -1 +1,2 @@
mmcv-full>=1.4.2,<1.6.0
mmcv==2.2.0
mmengine==0.10.5

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmcv import Config
from mmengine.config import Config
from mmdet.apis import inference_detector
from mmdet.models import build_detector

View File

@ -5,7 +5,7 @@ from copy import deepcopy
import numpy as np
import torch
from mmcv import ConfigDict
from mmengine.config import ConfigDict
from mmcls.models import CLASSIFIERS
from mmcls.models.classifiers import ImageClassifier

View File

@ -4,7 +4,7 @@ import os
import platform
import cv2
from mmcv import Config
from mmengine.config import Config
from mmcls.utils import setup_multi_processes

View File

@ -2,7 +2,7 @@
import argparse
import mmcv
from mmcv import Config, DictAction
from mmengine.config import Config, DictAction
from mmcls.datasets import build_dataset

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from mmcv import Config
from mmengine.config import Config
from mmcv.cnn.utils import get_model_complexity_info
from mmcls.models import build_classifier

View File

@ -10,7 +10,7 @@ from pathlib import Path
import mmcv
import torch
from mmcv import Config, DictAction
from mmengine.config import Config, DictAction
from mmcv.runner import get_dist_info, init_dist
from mmcls import __version__

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from mmcv import Config, DictAction
from mmengine.config import Config, DictAction
def parse_args():

View File

@ -4,7 +4,7 @@ import fcntl
import os
from pathlib import Path
from mmcv import Config, DictAction, track_parallel_progress, track_progress
from mmengine.config import Config, DictAction, track_parallel_progress, track_progress
from mmcls.datasets import PIPELINES, build_dataset

View File

@ -9,7 +9,7 @@ import warnings
import mmcv
import torch
import torch.distributed as dist
from mmcv import Config, DictAction
from mmengine.config import Config, DictAction
from mmcv.runner import get_dist_info, init_dist
from mmcls import __version__

View File

@ -8,7 +8,7 @@ from pathlib import Path
import mmcv
import numpy as np
from mmcv import Config, DictAction
from mmengine.config import Config, DictAction
from mmcv.utils import to_2tuple
from torch.nn import BatchNorm1d, BatchNorm2d, GroupNorm, LayerNorm

View File

@ -9,7 +9,7 @@ from pprint import pformat
import matplotlib.pyplot as plt
import mmcv
import torch.nn as nn
from mmcv import Config, DictAction, ProgressBar
from mmengine.config import Config, DictAction, ProgressBar
from mmcv.runner import (EpochBasedRunner, IterBasedRunner, IterLoader,
build_optimizer)
from torch.utils.data import DataLoader

View File

@ -12,7 +12,7 @@ from typing import List
import cv2
import mmcv
import numpy as np
from mmcv import Config, DictAction, ProgressBar
from mmengine.config import Config, DictAction, ProgressBar
from mmcls.core import visualization as vis
from mmcls.datasets.builder import PIPELINES, build_dataset, build_from_cfg