Commit 7e48337
Pedro Cuenca committed
Parent(s): 2b2be9b

Tokenizer, config, model can be loaded from wandb.
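With this change, DalleBartConfig, DalleBartTokenizer, and DalleBart all accept a wandb artifact reference in from_pretrained. A minimal usage sketch, assuming a hypothetical artifact reference "dalle-mini/dalle-mini/model:latest" (entity/project/name:alias):

from dalle_mini.model import DalleBart, DalleBartConfig, DalleBartTokenizer

# A reference containing ":" that is not a local directory is treated as a
# wandb artifact and downloaded first; hub ids and local paths load as before.
config = DalleBartConfig.from_pretrained("dalle-mini/dalle-mini/model:latest")
tokenizer = DalleBartTokenizer.from_pretrained("dalle-mini/dalle-mini/model:latest")
model = DalleBart.from_pretrained("dalle-mini/dalle-mini/model:latest")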
src/dalle_mini/model/__init__.py CHANGED
@@ -1,2 +1,3 @@
 from .configuration import DalleBartConfig
 from .modeling import DalleBart
+from .tokenizer import DalleBartTokenizer
src/dalle_mini/model/configuration.py CHANGED
@@ -18,10 +18,12 @@ import warnings
 from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging
 
+from .wandb_pretrained import PretrainedFromWandbMixin
+
 logger = logging.get_logger(__name__)
 
 
-class DalleBartConfig(PretrainedConfig):
+class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
     model_type = "dallebart"
     keys_to_ignore_at_inference = ["past_key_values"]
     attribute_map = {
src/dalle_mini/model/modeling.py CHANGED
@@ -15,14 +15,12 @@
 """ DalleBart model. """
 
 import math
-import os
 from functools import partial
 from typing import Optional, Tuple
 
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
-import wandb
 from flax.core.frozen_dict import unfreeze
 from flax.linen import make_causal_mask
 from flax.traverse_util import flatten_dict
@@ -48,6 +46,7 @@ from transformers.models.bart.modeling_flax_bart import (
 from transformers.utils import logging
 
 from .configuration import DalleBartConfig
+from .wandb_pretrained import PretrainedFromWandbMixin
 
 logger = logging.get_logger(__name__)
 
@@ -421,7 +420,9 @@ class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationMod
     )
 
 
-class DalleBart(FlaxBartPreTrainedModel, FlaxBartForConditionalGeneration):
+class DalleBart(
+    PretrainedFromWandbMixin, FlaxBartPreTrainedModel, FlaxBartForConditionalGeneration
+):
     """
     Edits:
     - renamed from FlaxBartForConditionalGeneration
@@ -563,24 +564,3 @@ class DalleBart(FlaxBartPreTrainedModel, FlaxBartForConditionalGeneration):
         outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
 
         return outputs
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        """
-        Initializes from a wandb artifact, or delegates loading to the superclass.
-        """
-        if ":" in pretrained_model_name_or_path and not os.path.isdir(
-            pretrained_model_name_or_path
-        ):
-            # wandb artifact
-            artifact = wandb.Api().artifact(pretrained_model_name_or_path)
-
-            # we download everything, including opt_state, so we can resume training if needed
-            # see also: #120
-            pretrained_model_name_or_path = artifact.download()
-
-        model = super(DalleBart, cls).from_pretrained(
-            pretrained_model_name_or_path, *model_args, **kwargs
-        )
-        model.config.resolved_name_or_path = pretrained_model_name_or_path
-        return model
src/dalle_mini/model/tokenizer.py ADDED
@@ -0,0 +1,11 @@
+""" DalleBart tokenizer """
+from transformers import BartTokenizer
+from transformers.utils import logging
+
+from .wandb_pretrained import PretrainedFromWandbMixin
+
+logger = logging.get_logger(__name__)
+
+
+class DalleBartTokenizer(PretrainedFromWandbMixin, BartTokenizer):
+    pass
src/dalle_mini/model/wandb_pretrained.py ADDED
@@ -0,0 +1,20 @@
+import os
+import wandb
+
+
+class PretrainedFromWandbMixin:
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        """
+        Initializes from a wandb artifact, or delegates loading to the superclass.
+        """
+        if ":" in pretrained_model_name_or_path and not os.path.isdir(
+            pretrained_model_name_or_path
+        ):
+            # wandb artifact
+            artifact = wandb.Api().artifact(pretrained_model_name_or_path)
+            pretrained_model_name_or_path = artifact.download()
+
+        return super(PretrainedFromWandbMixin, cls).from_pretrained(
+            pretrained_model_name_or_path, *model_args, **kwargs
+        )
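The mixin relies on Python's method resolution order: it is listed first in each class's bases so its from_pretrained intercepts the call, and super() then continues to the transformers implementation. A toy sketch of that mechanism (names here are illustrative, not from the diff):

import os

class Base:
    @classmethod
    def from_pretrained(cls, path, *args, **kwargs):
        return f"Base loading from {path}"

class Mixin:
    @classmethod
    def from_pretrained(cls, path, *args, **kwargs):
        # stand-in for the artifact download: resolve the reference to a local path
        if ":" in path and not os.path.isdir(path):
            path = "/tmp/downloaded-artifact"
        # zero-arg super() continues along the MRO to Base.from_pretrained
        return super().from_pretrained(path, *args, **kwargs)

class Model(Mixin, Base):
    pass

print(Model.from_pretrained("entity/project/model:v1"))
# -> Base loading from /tmp/downloaded-artifact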