forked from intel/intel-extension-for-tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 0
/
patch
379 lines (357 loc) · 15.8 KB
/
patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
diff --git a/keras_cv/models/stable_diffusion/__internal__/layers/attention_block.py b/keras_cv/models/stable_diffusion/__internal__/layers/attention_block.py
index e7e1896..1624932 100644
--- a/keras_cv/models/stable_diffusion/__internal__/layers/attention_block.py
+++ b/keras_cv/models/stable_diffusion/__internal__/layers/attention_block.py
@@ -14,6 +14,12 @@
import tensorflow as tf
from tensorflow import keras
+try:
+ import intel_extension_for_tensorflow as itex
+ keras.layers.GroupNormalization = itex.ops.GroupNormalization
+except ImportError:
+ pass
+
from keras_cv.models.stable_diffusion.__internal__.layers.padded_conv2d import (
PaddedConv2D,
diff --git a/keras_cv/models/stable_diffusion/__internal__/layers/resnet_block.py b/keras_cv/models/stable_diffusion/__internal__/layers/resnet_block.py
index 29aeaaa..f7d19df 100644
--- a/keras_cv/models/stable_diffusion/__internal__/layers/resnet_block.py
+++ b/keras_cv/models/stable_diffusion/__internal__/layers/resnet_block.py
@@ -13,6 +13,13 @@
# limitations under the License.
from tensorflow import keras
+try:
+ import intel_extension_for_tensorflow as itex
+ keras.layers.GroupNormalization = itex.ops.GroupNormalization
+except ImportError:
+ pass
+
+
from keras_cv.models.stable_diffusion.__internal__.layers.padded_conv2d import (
PaddedConv2D,
diff --git a/keras_cv/models/stable_diffusion/decoder.py b/keras_cv/models/stable_diffusion/decoder.py
index fe619d3..7eb0cf4 100644
--- a/keras_cv/models/stable_diffusion/decoder.py
+++ b/keras_cv/models/stable_diffusion/decoder.py
@@ -13,6 +13,12 @@
# limitations under the License.
from tensorflow import keras
+try:
+ import intel_extension_for_tensorflow as itex
+ keras.layers.GroupNormalization = itex.ops.GroupNormalization
+except ImportError:
+ pass
+
from keras_cv.models.stable_diffusion.__internal__.layers.attention_block import ( # noqa: E501
AttentionBlock,
diff --git a/keras_cv/models/stable_diffusion/diffusion_model.py b/keras_cv/models/stable_diffusion/diffusion_model.py
index 25b5241..dbc5e7f 100644
--- a/keras_cv/models/stable_diffusion/diffusion_model.py
+++ b/keras_cv/models/stable_diffusion/diffusion_model.py
@@ -14,7 +14,14 @@
import tensorflow as tf
from tensorflow import keras
-
+try:
+ import intel_extension_for_tensorflow as itex
+ keras.layers.GroupNormalization = itex.ops.GroupNormalization
+ keras.layers.LayerNormalization = itex.ops.LayerNormalization
+except ImportError:
+ pass
+
+
from keras_cv.models.stable_diffusion.__internal__.layers.padded_conv2d import (
PaddedConv2D,
)
@@ -302,6 +309,26 @@ class CrossAttention(keras.layers.Layer):
self.num_heads = num_heads
self.head_size = head_size
self.out_proj = keras.layers.Dense(num_heads * head_size)
+
+ def naive_scaled_dot_product_attention(self, query, key, value):
+ i_dtype = query.dtype
+ atten_scores = tf.matmul(query, key, transpose_b=True)
+ atten_scores = tf.multiply(atten_scores, tf.cast(self.scale, i_dtype))
+ atten_probs = tf.nn.softmax(atten_scores)
+ # `atten_output` = [B, N, F, H]
+ atten_output = tf.matmul(atten_probs, value)
+ # `atten_output` = [B, F, N, H]
+ atten_output = tf.transpose(a=atten_output, perm=[0, 2, 1, 3])
+ return atten_output
+
+
+ def sdp(self, q, k, v):
+ try:
+ from intel_extension_for_tensorflow.python.ops.multi_head_attention import scaled_dot_product_attention
+ output, _ = scaled_dot_product_attention(q, k, v, use_fast_attention=False)
+ except ImportError:
+ output = self.naive_scaled_dot_product_attention(q, k, v)
+ return output
def call(self, inputs):
inputs, context = inputs
@@ -316,17 +343,10 @@ class CrossAttention(keras.layers.Layer):
)
q = tf.transpose(q, (0, 2, 1, 3)) # (bs, num_heads, time, head_size)
- k = tf.transpose(k, (0, 2, 3, 1)) # (bs, num_heads, head_size, time)
+ k = tf.transpose(k, (0, 2, 1, 3)) # (bs, num_heads, head_size, time)
v = tf.transpose(v, (0, 2, 1, 3)) # (bs, num_heads, time, head_size)
- score = td_dot(q, k) * self.scale
- weights = keras.activations.softmax(
- score
- ) # (bs, num_heads, time, time)
- attn = td_dot(weights, v)
- attn = tf.transpose(
- attn, (0, 2, 1, 3)
- ) # (bs, time, num_heads, head_size)
+ attn = self.sdp(q, k, v)
out = tf.reshape(
attn, (-1, inputs.shape[1], self.num_heads * self.head_size)
)
@@ -352,10 +372,11 @@ class GEGLU(keras.layers.Layer):
def call(self, inputs):
x = self.dense(inputs)
x, gate = x[..., : self.output_dim], x[..., self.output_dim :]
- tanh_res = keras.activations.tanh(
- gate * 0.7978845608 * (1 + 0.044715 * (gate**2))
- )
- return x * 0.5 * gate * (1 + tanh_res)
+ # tanh_res = keras.activations.tanh(
+ # gate * 0.7978845608 * (1 + 0.044715 * (gate**2))
+ # )
+ # return x * 0.5 * gate * (1 + tanh_res)
+ return x * tf.keras.activations.gelu(gate, approximate=True)
def td_dot(a, b):
diff --git a/keras_cv/models/stable_diffusion/image_encoder.py b/keras_cv/models/stable_diffusion/image_encoder.py
index 614b11d..a7aef60 100644
--- a/keras_cv/models/stable_diffusion/image_encoder.py
+++ b/keras_cv/models/stable_diffusion/image_encoder.py
@@ -13,6 +13,12 @@
# limitations under the License.
from tensorflow import keras
+try:
+ import intel_extension_for_tensorflow as itex
+ keras.layers.GroupNormalization = itex.ops.GroupNormalization
+except ImportError:
+ pass
+
from keras_cv.models.stable_diffusion.__internal__.layers.attention_block import ( # noqa: E501
AttentionBlock,
diff --git a/keras_cv/models/stable_diffusion/stable_diffusion.py b/keras_cv/models/stable_diffusion/stable_diffusion.py
index 31752e8..f990a92 100644
--- a/keras_cv/models/stable_diffusion/stable_diffusion.py
+++ b/keras_cv/models/stable_diffusion/stable_diffusion.py
@@ -29,7 +29,8 @@ import math
import numpy as np
import tensorflow as tf
from tensorflow import keras
-
+from keras import backend as K
+import os
from keras_cv.models.stable_diffusion.clip_tokenizer import SimpleTokenizer
from keras_cv.models.stable_diffusion.constants import _ALPHAS_CUMPROD
from keras_cv.models.stable_diffusion.constants import _UNCONDITIONAL_TOKENS
@@ -51,6 +52,7 @@ class StableDiffusionBase:
img_height=512,
img_width=512,
jit_compile=False,
+ precision="fp32",
):
# UNet requires multiples of 2**7 = 128
img_height = round(img_height / 128) * 128
@@ -66,6 +68,7 @@ class StableDiffusionBase:
self._tokenizer = None
self.jit_compile = jit_compile
+ self.to_fp16 = (precision == "fp16")
def text_to_image(
self,
@@ -207,18 +210,27 @@ class StableDiffusionBase:
# Iterative reverse diffusion stage
timesteps = tf.range(1, 1000, 1000 // num_steps)
+ t_embs_lst = self._get_timesteps_embedding(timesteps, batch_size)
+ contexts = tf.concat((unconditional_context, context), 0)
+
alphas, alphas_prev = self._get_initial_alphas(timesteps)
progbar = keras.utils.Progbar(len(timesteps))
iteration = 0
+ # options = tf.profiler.experimental.ProfilerOptions(host_tracer_level = 3, python_tracer_level = 1, device_tracer_level = 1)
for index, timestep in list(enumerate(timesteps))[::-1]:
+ # if index == 0:
+ # profile_path = '/home/gta_local/zhuwei/workspace/tmp/stable_diffusion/profile_data/' + str(index)
+ # if not os.path.isdir(profile_path):
+ # os.mkdir(profile_path)
+ # tf.profiler.experimental.start(profile_path, options = options)
latent_prev = latent # Set aside the previous latent vector
- t_emb = self._get_timestep_embedding(timestep, batch_size)
- unconditional_latent = self.diffusion_model.predict_on_batch(
- [latent, t_emb, unconditional_context]
- )
- latent = self.diffusion_model.predict_on_batch(
- [latent, t_emb, context]
- )
+ latents = tf.concat((latent, latent), 0)
+ t_embs = t_embs_lst[index]
+
+ pred_latent = self.diffusion_model.predict_on_batch(
+ [latents, t_embs, contexts])
+ unconditional_latent, latent = tf.split(pred_latent, 2)
+
latent = unconditional_latent + unconditional_guidance_scale * (
latent - unconditional_latent
)
@@ -231,6 +243,7 @@ class StableDiffusionBase:
)
iteration += 1
progbar.update(iteration)
+ # tf.profiler.experimental.stop()
# Decoding stage
decoded = self.decoder.predict_on_batch(latent)
@@ -304,6 +317,23 @@ class StableDiffusionBase:
self._tokenizer = SimpleTokenizer()
return self._tokenizer
+ def _get_timesteps_embedding(
+ self, timesteps, batch_size, dim=320, max_period=10000
+ ):
+ half = dim // 2
+ freqs = tf.math.exp(
+ -math.log(max_period) * tf.range(0, half, dtype=tf.float32) / half
+ )
+ # timesteps shape: [num_steps]
+ args = tf.cast(tf.reshape(timesteps, [-1, 1]), dtype=tf.float32) * freqs
+ # embeddings shape:(steps, half)
+ embeddings = tf.concat([tf.math.cos(args), tf.math.sin(args)], axis=1)
+ embeddings = tf.expand_dims(embeddings, axis=1)
+ if self.to_fp16:
+ embeddings = tf.cast(embeddings, tf.float16)
+ # 2 is to concatenate the embedding of two forward pass
+ return tf.repeat(embeddings, batch_size * 2, axis=1)
+
def _get_timestep_embedding(
self, timestep, batch_size, dim=320, max_period=10000
):
@@ -314,6 +344,8 @@ class StableDiffusionBase:
args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)
embedding = tf.reshape(embedding, [1, -1])
+ if self.to_fp16:
+ embedding = tf.cast(embedding, tf.float16)
return tf.repeat(embedding, batch_size, axis=0)
def _get_initial_alphas(self, timesteps):
@@ -324,14 +356,25 @@ class StableDiffusionBase:
def _get_initial_diffusion_noise(self, batch_size, seed):
if seed is not None:
- return tf.random.stateless_normal(
- (batch_size, self.img_height // 8, self.img_width // 8, 4),
- seed=[seed, seed],
- )
+ if self.to_fp16:
+ return tf.random.stateless_normal(
+ (batch_size, self.img_height // 8, self.img_width // 8, 4),
+ seed=[seed, seed], dtype=tf.float16
+ )
+ else:
+ return tf.random.stateless_normal(
+ (batch_size, self.img_height // 8, self.img_width // 8, 4),
+ seed=[seed, seed],
+ )
else:
- return tf.random.normal(
- (batch_size, self.img_height // 8, self.img_width // 8, 4)
- )
+ if self.to_fp16:
+ return tf.random.normal(
+ (batch_size, self.img_height // 8, self.img_width // 8, 4), dtype=tf.float16
+ )
+ else:
+ return tf.random.normal(
+ (batch_size, self.img_height // 8, self.img_width // 8, 4)
+ )
@staticmethod
def _get_pos_ids():
@@ -390,8 +433,9 @@ class StableDiffusion(StableDiffusionBase):
img_height=512,
img_width=512,
jit_compile=False,
+ precision="fp32",
):
- super().__init__(img_height, img_width, jit_compile)
+ super().__init__(img_height, img_width, jit_compile, precision)
print(
"By using this model checkpoint, you acknowledge that its usage is "
"subject to the terms of the CreativeML Open RAIL-M license at "
@@ -405,7 +449,20 @@ class StableDiffusion(StableDiffusionBase):
needs to be modified.
"""
if self._text_encoder is None:
- self._text_encoder = TextEncoder(MAX_PROMPT_LENGTH)
+ if self.to_fp16:
+ self._text_encoder_fp32 = TextEncoder(MAX_PROMPT_LENGTH)
+ weights_fp32 = self._text_encoder_fp32.get_weights()
+ print("text_encoder before: ", np.unique(
+ [w.dtype for w in weights_fp32]), flush=True)
+ K.set_floatx('float16')
+ weights_fp16 = [w.astype(K.floatx()) for w in weights_fp32]
+ self._text_encoder = TextEncoder(
+ MAX_PROMPT_LENGTH, download_weights=False)
+ self._text_encoder.set_weights(weights_fp16)
+ print("text_encoder after: ", np.unique(
+ [w.dtype for w in self._text_encoder.get_weights()]), flush=True)
+ else:
+ self._text_encoder = TextEncoder(MAX_PROMPT_LENGTH)
if self.jit_compile:
self._text_encoder.compile(jit_compile=True)
return self._text_encoder
@@ -420,6 +477,10 @@ class StableDiffusion(StableDiffusionBase):
self._diffusion_model = DiffusionModel(
self.img_height, self.img_width, MAX_PROMPT_LENGTH
)
+ wb = self._diffusion_model.get_weights()
+ print("text_encoder before: ", np.unique(
+ [w.dtype for w in wb]), flush=True)
+
if self.jit_compile:
self._diffusion_model.compile(jit_compile=True)
return self._diffusion_model
@@ -475,8 +536,9 @@ class StableDiffusionV2(StableDiffusionBase):
img_height=512,
img_width=512,
jit_compile=False,
+ precision="fp32",
):
- super().__init__(img_height, img_width, jit_compile)
+ super().__init__(img_height, img_width, jit_compile, precision)
print(
"By using this model checkpoint, you acknowledge that its usage is "
"subject to the terms of the CreativeML Open RAIL++-M license at "
@@ -490,7 +552,20 @@ class StableDiffusionV2(StableDiffusionBase):
needs to be modified.
"""
if self._text_encoder is None:
- self._text_encoder = TextEncoderV2(MAX_PROMPT_LENGTH)
+ if self.to_fp16:
+ self._text_encoder_fp32 = TextEncoderV2(MAX_PROMPT_LENGTH)
+ weights_fp32 = self._text_encoder_fp32.get_weights()
+ print("text_encoder before: ", np.unique(
+ [w.dtype for w in weights_fp32]), flush=True)
+ K.set_floatx('float16')
+ weights_fp16 = [w.astype(K.floatx()) for w in weights_fp32]
+ self._text_encoder = TextEncoderV2(
+ MAX_PROMPT_LENGTH, download_weights=False)
+ self._text_encoder.set_weights(weights_fp16)
+ print("text_encoder after: ", np.unique(
+ [w.dtype for w in self._text_encoder.get_weights()]), flush=True)
+ else:
+ self._text_encoder = TextEncoderV2(MAX_PROMPT_LENGTH)
if self.jit_compile:
self._text_encoder.compile(jit_compile=True)
return self._text_encoder
@@ -505,6 +580,10 @@ class StableDiffusionV2(StableDiffusionBase):
self._diffusion_model = DiffusionModelV2(
self.img_height, self.img_width, MAX_PROMPT_LENGTH
)
+ wb = self._diffusion_model.get_weights()
+ print("diffusion_model before: ", np.unique(
+ [w.dtype for w in wb]), flush=True)
+
if self.jit_compile:
self._diffusion_model.compile(jit_compile=True)
return self._diffusion_model