Skip to content

File: Broken/Externals/__init__.py

Broken.Externals

ExternalTorchBase

Bases: BrokenModel

Source code in Broken/Externals/__init__.py (lines 23–39)
class ExternalTorchBase(BrokenModel):
    # Base giving subclasses on-demand access to torch (installed via BrokenTorch)
    # and a best-guess compute device name.

    @property
    def device(self) -> str:
        """Name of the device torch should run on.

        Preference order: the `TORCH_DEVICE` value from `Environment` when it is
        set (truthy), then "cuda", then "mps", and finally "cpu".
        """
        # Makes the `torch` name available in this module's globals before use below
        self.load_torch()

        override = Environment.get("TORCH_DEVICE")
        if override:
            return override

        if torch.cuda.is_available():
            return "cuda"

        return "mps" if torch.backends.mps.is_available() else "cpu"

    def load_torch(self) -> None:
        """Install and inject torch in the caller's globals"""
        BrokenTorch.install(exists_ok=True)
        # The lookup must happen directly in this method: `f_back` is the frame that
        # called load_torch, and its module globals receive the `torch` binding.
        caller_globals = inspect.currentframe().f_back.f_globals
        caller_globals["torch"] = __import__("torch")

device

device: str

load_torch

load_torch() -> None

Ensure torch is installed, then inject the torch module into the caller's module globals.

Source code in Broken/Externals/__init__.py (lines 36–39)
def load_torch(self) -> None:
    """Install and inject torch in the caller's globals"""
    BrokenTorch.install(exists_ok=True)
    # Resolve the caller's namespace right here: `f_back` is the frame that invoked
    # load_torch, so moving this lookup into a helper would target the wrong frame.
    caller_globals = inspect.currentframe().f_back.f_globals
    caller_globals["torch"] = __import__("torch")

ExternalModelsBase

Bases: BrokenModel, ABC

Source code in Broken/Externals/__init__.py (lines 43–68)
class ExternalModelsBase(BrokenModel, ABC):
    """Base for externally-provided models that are loaded lazily by name.

    The model named by `model` is loaded through `load_model`, which skips the
    work when that same name is already loaded. Subclasses implement
    `_load_model` to do the actual loading (presumably storing the result in
    `_model` — confirm against subclasses).
    """
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        validate_assignment=True
    )

    model: str = Field("any")

    _model: Any = PrivateAttr(None)
    """The true loaded model object"""

    _loaded: SameTracker = PrivateAttr(default_factory=SameTracker)
    """Keeps track of the current loaded model name, to avoid reloading"""

    @BrokenWorker.easy_lock
    def load_model(self) -> Self:
        """Load the model named by `self.model`, unless it is already loaded.

        Returns:
            `self` in every case (fresh load or already loaded), so calls can be chained.
        """
        # `_loaded` tracks the last loaded name; a truthy result means no reload is needed
        if self._loaded(self.model):
            return self

        # Release the previous model before loading a new one. Rebinding to None (rather
        # than `del`) keeps the private attribute defined, so later reads of `_model`
        # do not fail if `_load_model` does not assign it. It also avoids relying on the
        # truthiness of arbitrary model objects.
        self._model = None

        self._load_model()
        return self

    @abstractmethod
    def _load_model(self) -> None:
        """Actually load the model named by `self.model` (implemented by subclasses)."""
        ...

model_config

model_config = ConfigDict(
    # Permit field types pydantic has no built-in validator for
    arbitrary_types_allowed=True,
    # Validate fields on every assignment, not only at construction time
    validate_assignment=True,
)

model

model: str = Field(default="any")  # Name of the model to load

load_model

load_model() -> Self
Source code in Broken/Externals/__init__.py (lines 57–64)
@BrokenWorker.easy_lock
def load_model(self) -> Self:
    """Load the model named by `self.model`, unless it is already loaded.

    Returns:
        `self` in every case (fresh load or already loaded), so calls can be chained.
    """
    # `_loaded` tracks the last loaded name; a truthy result means no reload is needed
    if self._loaded(self.model):
        return self

    # Release the previous model before loading a new one. Rebinding to None (rather
    # than `del`) keeps the private attribute defined, so later reads of `_model`
    # do not fail if `_load_model` does not assign it. It also avoids relying on the
    # truthiness of arbitrary model objects.
    self._model = None

    self._load_model()
    return self