refine
parent
f994c6d09e
commit
6c1a108e64
|
@ -24,7 +24,7 @@ def setup(rank, world_size):
|
|||
|
||||
dist.init_process_group('nccl', rank=rank, world_size=world_size)
|
||||
torch.cuda.set_device(rank)
|
||||
print_log(f'init {rank}/{world_size}')
|
||||
print_log(f'init {rank}/{world_size}', only_rank0=False)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
|
@ -93,7 +93,6 @@ def _materialize_meta_module(module: nn.Module, ):
|
|||
|
||||
|
||||
def main(rank, world_size=8, args=None):
|
||||
print(f'init {rank}/{world_size}')
|
||||
setup(rank, world_size)
|
||||
|
||||
model_name = args.model
|
||||
|
@ -148,9 +147,13 @@ def main(rank, world_size=8, args=None):
|
|||
try:
|
||||
op.prune(0.5, prunen=2, prunem=4)
|
||||
torch.cuda.empty_cache()
|
||||
print_log(f'prune {name}')
|
||||
print_log(
|
||||
f'prune {name} on rank:{rank} successfully.', # noqa
|
||||
only_rank0=False)
|
||||
except Exception as e:
|
||||
print_log(f'{e}')
|
||||
print_log(
|
||||
f'prune {name} on rank:{rank} failed, as {e}', # noqa
|
||||
only_rank0=False)
|
||||
num_op += 1
|
||||
num_op = 0
|
||||
for name, op in fsdp.named_modules():
|
||||
|
|
Loading…
Reference in New Issue