Merge branch 'KaiyangZhou:master' into master

pull/511/head
ZXYFrank 2022-07-10 17:19:53 +08:00 committed by GitHub
commit be05a098a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 13 additions and 13 deletions

View File

@ -97,14 +97,14 @@ Get started: 30 seconds to Torchreid
.. code-block:: python
datamanager = torchreid.data.ImageDataManager(
root='reid-data',
sources='market1501',
targets='market1501',
root="reid-data",
sources="market1501",
targets="market1501",
height=256,
width=128,
batch_size_train=32,
batch_size_test=100,
transforms=['random_flip', 'random_crop']
transforms=["random_flip", "random_crop"]
)
3 Build model, optimizer and lr_scheduler
@ -112,9 +112,9 @@ Get started: 30 seconds to Torchreid
.. code-block:: python
model = torchreid.models.build_model(
name='resnet50',
name="resnet50",
num_classes=datamanager.num_train_pids,
loss='softmax',
loss="softmax",
pretrained=True
)
@ -122,13 +122,13 @@ Get started: 30 seconds to Torchreid
optimizer = torchreid.optim.build_optimizer(
model,
optim='adam',
optim="adam",
lr=0.0003
)
scheduler = torchreid.optim.build_lr_scheduler(
optimizer,
lr_scheduler='single_step',
lr_scheduler="single_step",
stepsize=20
)
@ -149,7 +149,7 @@ Get started: 30 seconds to Torchreid
.. code-block:: python
engine.run(
save_dir='log/resnet50',
save_dir="log/resnet50",
max_epoch=60,
eval_freq=10,
print_freq=10,

View File

@ -367,10 +367,10 @@ class Engine(object):
end = time.time()
features = self.extract_features(imgs)
batch_time.update(time.time() - end)
features = features.cpu().clone()
features = features.cpu()
f_.append(features)
pids_.extend(pids)
camids_.extend(camids)
pids_.extend(pids.tolist())
camids_.extend(camids.tolist())
f_ = torch.cat(f_, 0)
pids_ = np.asarray(pids_)
camids_ = np.asarray(camids_)

View File

@ -71,7 +71,7 @@ class FeatureExtractor(object):
model = build_model(
model_name,
num_classes=1,
pretrained=True,
pretrained=not (model_path and check_isfile(model_path)),
use_gpu=device.startswith('cuda')
)
model.eval()