Spaces:
Running
Running
make checkpointing optional
Browse files
dalle_mini/modeling_bart_flax.py
CHANGED
|
@@ -252,8 +252,7 @@ class FlaxBartEncoderLayer(nn.Module):
|
|
| 252 |
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
| 253 |
)
|
| 254 |
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
| 255 |
-
|
| 256 |
-
@nn.remat
|
| 257 |
def __call__(
|
| 258 |
self,
|
| 259 |
hidden_states: jnp.ndarray,
|
|
@@ -283,8 +282,9 @@ class FlaxBartEncoderLayerCollection(nn.Module):
|
|
| 283 |
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
| 284 |
|
| 285 |
def setup(self):
|
|
|
|
| 286 |
self.layers = [
|
| 287 |
-
|
| 288 |
]
|
| 289 |
|
| 290 |
def __call__(
|
|
@@ -344,8 +344,7 @@ class FlaxBartDecoderLayer(nn.Module):
|
|
| 344 |
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
| 345 |
)
|
| 346 |
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
| 347 |
-
|
| 348 |
-
@nn.remat
|
| 349 |
def __call__(
|
| 350 |
self,
|
| 351 |
hidden_states: jnp.ndarray,
|
|
@@ -394,8 +393,9 @@ class FlaxBartDecoderLayerCollection(nn.Module):
|
|
| 394 |
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
| 395 |
|
| 396 |
def setup(self):
|
|
|
|
| 397 |
self.layers = [
|
| 398 |
-
|
| 399 |
]
|
| 400 |
|
| 401 |
def __call__(
|
|
|
|
| 252 |
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
| 253 |
)
|
| 254 |
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
| 255 |
+
|
|
|
|
| 256 |
def __call__(
|
| 257 |
self,
|
| 258 |
hidden_states: jnp.ndarray,
|
|
|
|
| 282 |
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
| 283 |
|
| 284 |
def setup(self):
|
| 285 |
+
layer_module = nn.remat(FlaxBartEncoderLayer) if self.config.gradient_checkpointing else FlaxBartEncoderLayer
|
| 286 |
self.layers = [
|
| 287 |
+
layer_module(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.encoder_layers)
|
| 288 |
]
|
| 289 |
|
| 290 |
def __call__(
|
|
|
|
| 344 |
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
| 345 |
)
|
| 346 |
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
| 347 |
+
|
|
|
|
| 348 |
def __call__(
|
| 349 |
self,
|
| 350 |
hidden_states: jnp.ndarray,
|
|
|
|
| 393 |
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
| 394 |
|
| 395 |
def setup(self):
|
| 396 |
+
layer_module = nn.remat(FlaxBartDecoderLayer) if self.config.gradient_checkpointing else FlaxBartDecoderLayer
|
| 397 |
self.layers = [
|
| 398 |
+
layer_module(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.decoder_layers)
|
| 399 |
]
|
| 400 |
|
| 401 |
def __call__(
|