Move opt_einsum import back out of class __init__

kron_optimizer
Ross Wightman 2025-01-27 14:03:25 -08:00
parent 71d174180f
commit 80a0205725
1 changed file with 16 additions and 14 deletions


@@ -14,7 +14,15 @@ from typing import Any, Callable, Dict, Optional, Tuple, Union
 import numpy as np
 import torch
+try:
+    # NOTE opt_einsum needed to avoid blowing up memory with einsum ops
+    import opt_einsum
+    import torch.backends.opt_einsum
+    torch.backends.opt_einsum.enabled = True
+    torch.backends.opt_einsum.strategy = "auto-hq"
+    has_opt_einsum = True
+except ImportError:
+    has_opt_einsum = False
 try:
     torch._dynamo.config.cache_size_limit = 1_000_000
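The module-level guard added above is a standard feature-detection pattern: try the import, configure the `torch.backends.opt_einsum` backend when it is present, and record the outcome in a flag for later use. A minimal standalone sketch of that pattern, assuming a PyTorch build that ships `torch.backends.opt_einsum`; the trailing `torch.einsum` call is only an illustration of where the memory savings come from and is not part of this diff:

```python
import warnings

import torch

try:
    # opt_einsum lets torch.einsum choose a cheaper contraction order,
    # avoiding huge intermediate tensors in multi-operand contractions.
    import opt_einsum  # noqa: F401
    import torch.backends.opt_einsum
    torch.backends.opt_einsum.enabled = True
    torch.backends.opt_einsum.strategy = "auto-hq"
    has_opt_einsum = True
except ImportError:
    has_opt_einsum = False

if not has_opt_einsum:
    warnings.warn("opt_einsum not found; einsum ops may use much more memory.")

# With the backend enabled, ordinary torch.einsum calls benefit transparently.
a = torch.randn(8, 16)
b = torch.randn(16, 32)
c = torch.randn(32, 8)
out = torch.einsum('ij,jk,kl->il', a, b, c)  # contraction order picked by opt_einsum
```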
@@ -26,11 +34,11 @@ _logger = logging.getLogger(__name__)
 def precond_update_prob_schedule(
-    n: float,
-    max_prob: float = 1.0,
-    min_prob: float = 0.03,
-    decay: float = 0.001,
-    flat_start: float = 500,
+        n: float,
+        max_prob: float = 1.0,
+        min_prob: float = 0.03,
+        decay: float = 0.001,
+        flat_start: float = 500,
) -> torch.Tensor:
     """Anneal preconditioner update probability during beginning of training.
@@ -92,14 +100,8 @@ class Kron(torch.optim.Optimizer):
             flatten_dim: bool = False,
             deterministic: bool = False,
     ):
-        try:
-            # NOTE opt_einsum needed to avoid blowing up memory with einsum ops
-            import opt_einsum
-            opt_einsum.enabled = True
-            opt_einsum.strategy = "auto-hq"
-            import torch.backends.opt_einsum
-        except ImportError:
-            warnings.warn("It is highly recommended to have 'opt_einsum' installed for this optimizer." )
+        if not has_opt_einsum:
+            warnings.warn("It is highly recommended to have 'opt_einsum' installed for this optimizer.")
         if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")