Skip to content

File: Broken/Core/BrokenTorch.py

Broken.Core.BrokenTorch

TORCH_INDEX_URL_NIGHTLY

# Extra pip index serving nightly (pre-release) torch wheels
# NOTE(review): defined but not referenced by any code visible here — likely for future use
TORCH_INDEX_URL_NIGHTLY: str = (
    "https://download.pytorch.org/whl/nightly/"
)

TORCH_INDEX_URL_STABLE

# Extra pip index serving stable torch wheels; a flavor suffix (e.g. 'cu124') is appended
TORCH_INDEX_URL_STABLE: str = (
    "https://download.pytorch.org/whl/"
)

TORCH_VERSION

TORCH_VERSION: str = '2.6.0'

TorchRelease

Bases: str, BrokenEnum

Source code in Broken/Core/BrokenTorch.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
class TorchRelease(str, BrokenEnum):
    """Known PyTorch releases, encoded as '<version>[+<flavor>]' (PEP 440 local version)."""
    TORCH_251_MACOS    = "2.5.1"
    TORCH_251_CPU      = "2.5.1+cpu"
    TORCH_251_CUDA_118 = "2.5.1+cu118"
    TORCH_251_CUDA_121 = "2.5.1+cu121"
    TORCH_251_CUDA_124 = "2.5.1+cu124"
    TORCH_251_ROCM_610 = "2.5.1+rocm6.1"
    TORCH_251_ROCM_620 = "2.5.1+rocm6.2"
    TORCH_260_MACOS    = "2.6.0"
    TORCH_260_CPU      = "2.6.0+cpu"
    TORCH_260_CUDA_118 = "2.6.0+cu118"
    TORCH_260_CUDA_124 = "2.6.0+cu124"
    TORCH_260_CUDA_126 = "2.6.0+cu126"
    TORCH_260_ROCM_610 = "2.6.0+rocm6.1"
    TORCH_260_ROCM_624 = "2.6.0+rocm6.2.4"
    TORCH_260_XPU      = "2.6.0+xpu"

    # Installation

    @property
    def index(self) -> Optional[str]:
        """Extra pip index URL for flavored releases, None for plain (PyPI-served) ones."""
        if (not self.is_plain):
            # Not plain implies flavor is not None; the `or ''` is a defensive fallback
            return TORCH_INDEX_URL_STABLE + (self.flavor or '')

    @property
    def packages(self) -> tuple[str, ...]:
        """Pip requirement specifiers to install for this release (torch pinned + torchvision)."""
        return (f"torch=={self.value}", "torchvision")

    def install(self) -> None:
        """Install this release with pip; 'every' presumably drops the index pair when None — TODO confirm."""
        log.special(f"Installing PyTorch version ({self.value})")
        shell(Tools.pip, "install", self.packages, every("--index-url", self.index))

    def uninstall(self) -> None:
        """Remove the torch packages of this release via pip."""
        shell(Tools.pip, "uninstall", "--quiet", self.packages)

    # Differentiators

    @property
    def number(self) -> str:
        """Bare semantic version, e.g. '2.6.0' for '2.6.0+cu124'."""
        return self.value.split("+")[0]

    @property
    def flavor(self) -> Optional[str]:
        """Local version segment after '+', e.g. 'cu124'; None for plain releases."""
        if len(parts := self.value.split("+")) > 1:
            return parts[1]

    # Util properties

    @property
    def is_plain(self) -> bool:
        # No '+flavor' suffix (macOS builds)
        return ("+" not in self.value)

    @property
    def is_cuda(self) -> bool:
        return ("+cu" in self.value)

    @property
    def is_rocm(self) -> bool:
        return ("+rocm" in self.value)

    @property
    def is_cpu(self) -> bool:
        return ("+cpu" in self.value)

    @property
    def is_xpu(self) -> bool:
        return ("+xpu" in self.value)

TORCH_251_MACOS

TORCH_251_MACOS = '2.5.1'

TORCH_251_CPU

TORCH_251_CPU = '2.5.1+cpu'

TORCH_251_CUDA_118

TORCH_251_CUDA_118 = '2.5.1+cu118'

TORCH_251_CUDA_121

TORCH_251_CUDA_121 = '2.5.1+cu121'

TORCH_251_CUDA_124

TORCH_251_CUDA_124 = '2.5.1+cu124'

TORCH_251_ROCM_610

TORCH_251_ROCM_610 = '2.5.1+rocm6.1'

TORCH_251_ROCM_620

TORCH_251_ROCM_620 = '2.5.1+rocm6.2'

TORCH_260_MACOS

TORCH_260_MACOS = '2.6.0'

TORCH_260_CPU

TORCH_260_CPU = '2.6.0+cpu'

TORCH_260_CUDA_118

TORCH_260_CUDA_118 = '2.6.0+cu118'

TORCH_260_CUDA_124

TORCH_260_CUDA_124 = '2.6.0+cu124'

TORCH_260_CUDA_126

TORCH_260_CUDA_126 = '2.6.0+cu126'

TORCH_260_ROCM_610

TORCH_260_ROCM_610 = '2.6.0+rocm6.1'

TORCH_260_ROCM_624

TORCH_260_ROCM_624 = '2.6.0+rocm6.2.4'

TORCH_260_XPU

TORCH_260_XPU = '2.6.0+xpu'

index

index: Optional[str]

packages

packages: tuple[str]

install

install() -> None
Source code in Broken/Core/BrokenTorch.py
48
49
50
def install(self) -> None:
    """Install this release with pip; 'every' presumably drops the index pair when None — TODO confirm."""
    log.special(f"Installing PyTorch version ({self.value})")
    shell(Tools.pip, "install", self.packages, every("--index-url", self.index))

uninstall

uninstall() -> None
Source code in Broken/Core/BrokenTorch.py
52
53
def uninstall(self) -> None:
    """Remove the torch packages of this release via pip."""
    shell(Tools.pip, "uninstall", "--quiet", self.packages)

number

number: str

flavor

flavor: Optional[str]

is_plain

is_plain: bool

is_cuda

is_cuda: bool

is_rocm

is_rocm: bool

is_cpu

is_cpu: bool

is_xpu

is_xpu: bool

SimpleTorch

Bases: BrokenEnum

Global torch versions target and suggestions

Source code in Broken/Core/BrokenTorch.py
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
class SimpleTorch(BrokenEnum):
    """Curated torch release targets, one suggestion per compute backend"""
    CPU   = TorchRelease.TORCH_260_CPU
    MACOS = TorchRelease.TORCH_260_MACOS
    CUDA  = TorchRelease.TORCH_260_CUDA_124
    ROCM  = TorchRelease.TORCH_260_ROCM_624
    XPU   = TorchRelease.TORCH_260_XPU

    @classmethod
    def prompt_choices(cls) -> Iterable[str]:
        """Yield lowercase member names to offer at the prompt (macOS is excluded)."""
        yield from (member.name.lower() for member in cls if member is not cls.MACOS)

prompt_choices

prompt_choices() -> Iterable[str]
Source code in Broken/Core/BrokenTorch.py
 98
 99
100
101
102
103
@classmethod
def prompt_choices(cls) -> Iterable[str]:
    """Yield lowercase member names to offer at the prompt (macOS is excluded)."""
    yield from (member.name.lower() for member in cls if member is not cls.MACOS)

BrokenTorch

Source code in Broken/Core/BrokenTorch.py
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
class BrokenTorch:
    """Detect, install and manage the PyTorch version of the current environment"""

    @staticmethod
    def docker() -> Iterable[TorchRelease]:
        """List of versions for docker images builds"""
        yield SimpleTorch.CUDA.value
        yield SimpleTorch.CPU.value

    @staticmethod
    def version() -> Optional[Union[TorchRelease, str]]:
        """Current torch version if any, may return a string if not part of known enum"""

        # Note: Reversed as Windows lists system first, and we might have multiple on Docker
        for site_packages in map(Path, reversed(site.getsitepackages())):
            if (script := (site_packages/"torch"/"version.py")).exists():
                # Execute torch/version.py standalone to read __version__ without importing torch
                exec(script.read_text("utf-8"), namespace := {})
                version = namespace.get("__version__")
                return TorchRelease.get(version) or version

    @BrokenWorker.easy_lock
    @staticmethod
    def install(
        version: Annotated[TorchRelease,
            Option("--version", "-v",
            help="Torch version and flavor to install"
        )]=None,

        exists_ok: Annotated[bool, BrokenTyper.exclude()]=False
    ) -> None:
        """📦 Install or modify PyTorch versions"""

        # Global opt-out of torch management
        if not Environment.flag("BROKEN_TORCH", True):
            return None

        installed = BrokenTorch.version()

        # Skip when torch is already present and any version is acceptable, or when
        # the user is explicitly managing torch from the command line — TODO confirm
        # the intended sys.argv semantics against the callers of exists_ok
        if (exists_ok and (installed or "torch" in sys.argv)):
            return None

        log.special(f"Currently installed PyTorch version: {denum(installed)}")

        # Ask interactively if no flavor was provided
        if not (version := TorchRelease.get(version)):

            # Fixed single version for macOS
            if BrokenPlatform.OnMacOS:
                version = SimpleTorch.MACOS.value

            # Assume it's a Linux server on NVIDIA
            # (fix: was a standalone 'if' whose choice the following 'else' overwrote)
            elif (not Runtime.Interactive):
                version = SimpleTorch.CUDA.value

            else:
                version = BrokenTorch.prompt_flavor()

        # Both sides are TorchRelease (or a raw string) at this point
        if (installed == version):
            log.special("• Requested torch version matches current one!")
            return

        version.install()

    @staticmethod
    def prompt_flavor() -> TorchRelease:
        """Interactively ask which torch flavor to install, defaulting to CUDA"""
        from rich.prompt import Prompt

        log.special("""
            Generally speaking, you should choose for:
            • [royal_blue1](Windows or Linux)[/] NVIDIA GPU: 'cuda'
            • [royal_blue1](Windows or Linux)[/] Intel ARC: 'xpu'
            • [royal_blue1](Linux)[/] AMD GPU (>= RX 5000): 'rocm'
            • [royal_blue1](Other)[/] Others or CPU: 'cpu'

            [dim]Tip: Set 'HSA_OVERRIDE_GFX_VERSION=10.3.0' for RX 5000 Series[/]
        """, dedent=True)

        try:
            choice = SimpleTorch.get(Prompt.ask(
                prompt="\n:: What PyTorch version do you want to install?\n\n",
                choices=list(SimpleTorch.prompt_choices()),
                default="cuda"
            ).upper())
            print()
        except KeyboardInterrupt:
            exit(0)

        # SimpleTorch members wrap a TorchRelease as their value
        return choice.value

docker

docker() -> Iterable[TorchRelease]

List of versions for docker images builds

Source code in Broken/Core/BrokenTorch.py
109
110
111
112
113
@staticmethod
def docker() -> Iterable[TorchRelease]:
    """List of versions for docker images builds"""
    # CUDA first, then CPU fallback image
    yield SimpleTorch.CUDA.value
    yield SimpleTorch.CPU.value

version

version() -> Optional[Union[TorchRelease, str]]

Current torch version if any, may return a string if not part of known enum

Source code in Broken/Core/BrokenTorch.py
115
116
117
118
119
120
121
122
123
124
@staticmethod
def version() -> Optional[Union[TorchRelease, str]]:
    """Current torch version if any, may return a string if not part of known enum"""

    # Note: Reversed as Windows lists system first, and we might have multiple on Docker
    for site_packages in map(Path, reversed(site.getsitepackages())):
        if (script := (site_packages/"torch"/"version.py")).exists():
            # Execute torch/version.py standalone to read __version__ without importing torch
            exec(script.read_text("utf-8"), namespace := {})
            version = namespace.get("__version__")
            return TorchRelease.get(version) or version

install

install(
    version: Annotated[
        TorchRelease,
        Option(
            --version,
            -v,
            help="Torch version and flavor to install",
        ),
    ] = None,
    exists_ok: Annotated[
        bool, BrokenTyper.exclude()
    ] = False,
) -> None

📦 Install or modify PyTorch versions

Source code in Broken/Core/BrokenTorch.py
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
@BrokenWorker.easy_lock
@staticmethod
def install(
    version: Annotated[TorchRelease,
        Option("--version", "-v",
        help="Torch version and flavor to install"
    )]=None,

    exists_ok: Annotated[bool, BrokenTyper.exclude()]=False
) -> None:
    """📦 Install or modify PyTorch versions"""

    # Global opt-out of torch management
    if not Environment.flag("BROKEN_TORCH", True):
        return None

    installed = BrokenTorch.version()

    # Skip when torch is already present and any version is acceptable, or when
    # the user is explicitly managing torch from the command line — TODO confirm
    # the intended sys.argv semantics against the callers of exists_ok
    if (exists_ok and (installed or "torch" in sys.argv)):
        return None

    log.special(f"Currently installed PyTorch version: {denum(installed)}")

    # Ask interactively if no flavor was provided
    if not (version := TorchRelease.get(version)):

        # Fixed single version for macOS
        if BrokenPlatform.OnMacOS:
            version = SimpleTorch.MACOS.value

        # Assume it's a Linux server on NVIDIA
        # (fix: was a standalone 'if' whose choice the following 'else' overwrote)
        elif (not Runtime.Interactive):
            version = SimpleTorch.CUDA.value

        else:
            version = BrokenTorch.prompt_flavor()

    # Both sides are TorchRelease (or a raw string) at this point
    if (installed == version):
        log.special("• Requested torch version matches current one!")
        return

    version.install()

prompt_flavor

prompt_flavor() -> TorchRelease
Source code in Broken/Core/BrokenTorch.py
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
@staticmethod
def prompt_flavor() -> TorchRelease:
    """Interactively ask which torch flavor to install, defaulting to CUDA"""
    from rich.prompt import Prompt

    log.special("""
        Generally speaking, you should choose for:
        • [royal_blue1](Windows or Linux)[/] NVIDIA GPU: 'cuda'
        • [royal_blue1](Windows or Linux)[/] Intel ARC: 'xpu'
        • [royal_blue1](Linux)[/] AMD GPU (>= RX 5000): 'rocm'
        • [royal_blue1](Other)[/] Others or CPU: 'cpu'

        [dim]Tip: Set 'HSA_OVERRIDE_GFX_VERSION=10.3.0' for RX 5000 Series[/]
    """, dedent=True)

    try:
        choice = SimpleTorch.get(Prompt.ask(
            prompt="\n:: What PyTorch version do you want to install?\n\n",
            choices=list(SimpleTorch.prompt_choices()),
            default="cuda"
        ).upper())
        print()
    except KeyboardInterrupt:
        exit(0)

    # SimpleTorch members wrap a TorchRelease as their value
    return choice.value