import torch
from torch import Tensor
-from torchao.utils import TorchAOBaseTensor
+from torch.utils._python_dispatch import return_and_correct_aliasing
+from torchao.utils import TorchAOBaseTensor, TORCH_VERSION_AT_LEAST_2_4

from .quant_utils import create_dynamic_map, scale_tensor, quantize_4bit_with_qmap, dequant_with_qmap
@@ -60,8 +61,9 @@ def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=No
    def dequantize(self, output_dtype=None):
        codes = torch.stack([self.codes >> 4, self.codes & 0b1111], dim=-1)  # unpack
        float_data = dequant_with_qmap(codes, self.qmap, self.scale)
-        dtype = output_dtype or torch.get_default_dtype()
-        return float_data.view(self._shape).to(dtype)
+        if output_dtype is not None:
+            float_data = float_data.to(output_dtype)
+        return float_data.view(self._shape)

    @classmethod
    def zeros(cls, shape, signed: bool = True, block_size: int = 128, device=None):
@@ -80,6 +82,24 @@ def __repr__(self):
        )


+# in pre-2.4, calling .to(device, dtype) will not dispatch aten._to_copy.default when
+# dtype is the same but device is different. thus, we must override .to() method instead.
+if not TORCH_VERSION_AT_LEAST_2_4:
+    def _to(self, *args, **kwargs):
+        # ignore other args/kwargs
+        device = kwargs.pop("device", None)
+        return OptimState4bit(
+            self.codes.to(device),
+            self.scale.to(device),
+            self.qmap.to(device),
+            self.signed,
+            self.shape,
+        )
+
+    OptimState4bit.to = _to
+    del _to  # make sure to not re-use
+
+
@OptimState4bit.implements(aten.copy_.default)
def _(func, types, args, kwargs):
    dst = args[0]
@@ -107,6 +127,20 @@ def _(func, types, args, kwargs):
    return dst


+@OptimState4bit.implements(aten._to_copy.default)
+def _(func, types, args, kwargs):
+    # ignore dtype
+    device = kwargs.get("device", None)
+    out = OptimState4bit(
+        args[0].codes.to(device=device),
+        args[0].scale.to(device=device),
+        args[0].qmap.to(device=device),
+        args[0].signed,
+        args[0].shape,
+    )
+    return return_and_correct_aliasing(func, args, kwargs, out)
+
+
@OptimState4bit.implements(aten.lerp.Scalar)
def _(func, types, args, kwargs):
    args = [x.dequantize() if isinstance(x, OptimState4bit) else x for x in args]
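
For reference, a minimal usage sketch of the device-transfer path this commit adds. It is not part of the diff: the shape below is arbitrary, and OptimState4bit.zeros is the constructor whose signature appears in the context lines above.

    import torch

    # allocate a 4-bit optimizer state on CPU (illustrative shape)
    state = OptimState4bit.zeros((4096, 128), signed=True, block_size=128, device="cpu")

    if torch.cuda.is_available():
        # same dtype, different device: on PyTorch >= 2.4 this dispatches
        # aten._to_copy.default (handled above); on pre-2.4 it goes through
        # the overridden .to() method installed behind TORCH_VERSION_AT_LEAST_2_4
        state = state.to(device="cuda")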