From 67ef6f0a92abf2631874a5729a97c7ac13b0d0a2 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 27 Jan 2025 14:03:25 -0800
Subject: [PATCH] Move opt_einsum import back out of class __init__

---
 timm/optim/kron.py | 30 ++++++++++++++++--------------
 1 file changed, 16 insertions(+), 14 deletions(-)

diff --git a/timm/optim/kron.py b/timm/optim/kron.py
index e4198ca4..7f1fcd47 100644
--- a/timm/optim/kron.py
+++ b/timm/optim/kron.py
@@ -14,7 +14,15 @@ from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 import numpy as np
 import torch
-
+try:
+    # NOTE opt_einsum needed to avoid blowing up memory with einsum ops
+    import opt_einsum
+    import torch.backends.opt_einsum
+    torch.backends.opt_einsum.enabled = True
+    torch.backends.opt_einsum.strategy = "auto-hq"
+    has_opt_einsum = True
+except ImportError:
+    has_opt_einsum = False
 
 try:
     torch._dynamo.config.cache_size_limit = 1_000_000
@@ -26,11 +34,11 @@ _logger = logging.getLogger(__name__)
 
 
 def precond_update_prob_schedule(
-    n: float,
-    max_prob: float = 1.0,
-    min_prob: float = 0.03,
-    decay: float = 0.001,
-    flat_start: float = 500,
+        n: float,
+        max_prob: float = 1.0,
+        min_prob: float = 0.03,
+        decay: float = 0.001,
+        flat_start: float = 500,
 ) -> torch.Tensor:
     """Anneal preconditioner update probability during beginning of training.
 
@@ -92,14 +100,8 @@ class Kron(torch.optim.Optimizer):
             flatten_dim: bool = False,
             deterministic: bool = False,
     ):
-        try:
-            # NOTE opt_einsum needed to avoid blowing up memory with einsum ops
-            import opt_einsum
-            opt_einsum.enabled = True
-            opt_einsum.strategy = "auto-hq"
-            import torch.backends.opt_einsum
-        except ImportError:
-            warnings.warn("It is highly recommended to have 'opt_einsum' installed for this optimizer." )
+        if not has_opt_einsum:
+            warnings.warn("It is highly recommended to have 'opt_einsum' installed for this optimizer.")
 
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
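
Note: besides moving the probe out of per-instance __init__, the patch also fixes where the einsum settings land: the old code set enabled/strategy on the opt_einsum package itself, whereas the toggles belong on torch.backends.opt_einsum, which the module-level block now configures once. Below is a minimal standalone sketch of that pattern, assuming a PyTorch build that ships torch.backends.opt_einsum; the MyKronLike class is a hypothetical stand-in for Kron, not part of this patch.

import warnings

import torch

try:
    # Probe for opt_einsum once at import time; it lets torch.einsum pick
    # memory-efficient contraction orders instead of contracting left-to-right.
    import opt_einsum  # noqa: F401
    import torch.backends.opt_einsum
    torch.backends.opt_einsum.enabled = True
    torch.backends.opt_einsum.strategy = "auto-hq"
    has_opt_einsum = True
except ImportError:
    has_opt_einsum = False


class MyKronLike(torch.optim.Optimizer):
    # Hypothetical optimizer used only to show __init__ consuming the module-level flag.
    def __init__(self, params, lr: float = 0.001):
        if not has_opt_einsum:
            warnings.warn("It is highly recommended to have 'opt_einsum' installed for this optimizer.")
        super().__init__(params, dict(lr=lr))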