Move opt_einsum import back out of class __init__
parent 71d174180f
commit 80a0205725
@@ -14,7 +14,15 @@ from typing import Any, Callable, Dict, Optional, Tuple, Union
 import numpy as np
 import torch

+try:
+    # NOTE opt_einsum needed to avoid blowing up memory with einsum ops
+    import opt_einsum
+    import torch.backends.opt_einsum
+    torch.backends.opt_einsum.enabled = True
+    torch.backends.opt_einsum.strategy = "auto-hq"
+    has_opt_einsum = True
+except ImportError:
+    has_opt_einsum = False

 try:
     torch._dynamo.config.cache_size_limit = 1_000_000
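With torch.backends.opt_einsum enabled at import time as in the hunk above, ordinary torch.einsum calls with three or more operands let PyTorch search for a cheaper contraction order instead of contracting strictly left to right. A minimal sketch (tensor shapes are illustrative only):

import torch

a = torch.randn(64, 128)
b = torch.randn(128, 32)
c = torch.randn(32, 256)

# The contraction path is chosen by the opt_einsum backend when it is enabled
# and the opt_einsum package is installed; otherwise torch.einsum falls back
# to left-to-right evaluation.
out = torch.einsum("ij,jk,kl->il", a, b, c)
print(out.shape)  # torch.Size([64, 256])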
@@ -26,11 +34,11 @@ _logger = logging.getLogger(__name__)


 def precond_update_prob_schedule(
     n: float,
     max_prob: float = 1.0,
     min_prob: float = 0.03,
     decay: float = 0.001,
     flat_start: float = 500,
 ) -> torch.Tensor:
     """Anneal preconditioner update probability during beginning of training.
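Only the signature and the first docstring line fall inside this hunk, so the body is not shown. As a purely hypothetical sketch of a schedule matching this signature (the formula is an assumption, not taken from the diff): hold the probability at max_prob for flat_start steps, then decay it exponentially toward min_prob.

import torch

def precond_update_prob_schedule_sketch(
    n: float,
    max_prob: float = 1.0,
    min_prob: float = 0.03,
    decay: float = 0.001,
    flat_start: float = 500,
) -> torch.Tensor:
    # Hypothetical body: flat for flat_start steps, then exponential decay,
    # clamped to [min_prob, max_prob].
    step = torch.as_tensor(n, dtype=torch.float32)
    prob = max_prob * torch.exp(-decay * torch.clamp(step - flat_start, min=0.0))
    return prob.clamp(min_prob, max_prob)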
@@ -92,14 +100,8 @@ class Kron(torch.optim.Optimizer):
         flatten_dim: bool = False,
         deterministic: bool = False,
     ):
-        try:
-            # NOTE opt_einsum needed to avoid blowing up memory with einsum ops
-            import opt_einsum
-            opt_einsum.enabled = True
-            opt_einsum.strategy = "auto-hq"
-            import torch.backends.opt_einsum
-        except ImportError:
-            warnings.warn("It is highly recommended to have 'opt_einsum' installed for this optimizer.")
+        if not has_opt_einsum:
+            warnings.warn("It is highly recommended to have 'opt_einsum' installed for this optimizer.")

         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
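After this change, a missing opt_einsum is detected once at import time via the module-level has_opt_einsum flag, and the constructor only emits a warning rather than re-importing opt_einsum in every Kron.__init__ call. A minimal usage sketch (the import path kron and the constructor defaults beyond lr are assumptions):

import torch
from kron import Kron  # import path is an assumption

model = torch.nn.Linear(16, 4)
opt = Kron(model.parameters(), lr=3e-4)  # warns here if opt_einsum is missing

loss = model(torch.randn(8, 16)).square().mean()
loss.backward()
opt.step()
opt.zero_grad()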