first commit

This view is limited to 50 files because it contains too many changes.
- .gitignore +332 -0
- MANIFEST.in +1 -0
- README.md +1 -13
- SETUP.cfg +8 -0
- app.py +245 -0
- dockerfiles/Dockerfile.cpu +17 -0
- dockerfiles/Dockerfile.cuda +38 -0
- examples/train_retriever.py +45 -0
- pyproject.toml +15 -0
- relik/__init__.py +1 -0
- relik/common/__init__.py +0 -0
- relik/common/log.py +97 -0
- relik/common/upload.py +128 -0
- relik/common/utils.py +609 -0
- relik/inference/__init__.py +0 -0
- relik/inference/annotator.py +422 -0
- relik/inference/data/__init__.py +0 -0
- relik/inference/data/objects.py +64 -0
- relik/inference/data/tokenizers/__init__.py +89 -0
- relik/inference/data/tokenizers/base_tokenizer.py +84 -0
- relik/inference/data/tokenizers/regex_tokenizer.py +73 -0
- relik/inference/data/tokenizers/spacy_tokenizer.py +228 -0
- relik/inference/data/tokenizers/whitespace_tokenizer.py +70 -0
- relik/inference/data/window/__init__.py +0 -0
- relik/inference/data/window/manager.py +262 -0
- relik/inference/gerbil.py +254 -0
- relik/inference/preprocessing.py +4 -0
- relik/inference/serve/__init__.py +0 -0
- relik/inference/serve/backend/__init__.py +0 -0
- relik/inference/serve/backend/relik.py +210 -0
- relik/inference/serve/backend/retriever.py +206 -0
- relik/inference/serve/backend/utils.py +29 -0
- relik/inference/serve/frontend/__init__.py +0 -0
- relik/inference/serve/frontend/relik.py +231 -0
- relik/inference/serve/frontend/style.css +33 -0
- relik/reader/__init__.py +0 -0
- relik/reader/conf/config.yaml +14 -0
- relik/reader/conf/data/base.yaml +21 -0
- relik/reader/conf/data/re.yaml +54 -0
- relik/reader/conf/training/base.yaml +12 -0
- relik/reader/conf/training/re.yaml +12 -0
- relik/reader/data/__init__.py +0 -0
- relik/reader/data/patches.py +51 -0
- relik/reader/data/relik_reader_data.py +965 -0
- relik/reader/data/relik_reader_data_utils.py +51 -0
- relik/reader/data/relik_reader_sample.py +49 -0
- relik/reader/lightning_modules/__init__.py +0 -0
- relik/reader/lightning_modules/relik_reader_pl_module.py +50 -0
- relik/reader/lightning_modules/relik_reader_re_pl_module.py +54 -0
- relik/reader/pytorch_modules/__init__.py +0 -0
.gitignore
ADDED
@@ -0,0 +1,332 @@
# custom

data/*
experiments/*
retrievers
outputs
model
wandb

# Created by https://www.toptal.com/developers/gitignore/api/jetbrains+all,vscode,python,jupyternotebooks,linux,windows,macos
# Edit at https://www.toptal.com/developers/gitignore?templates=jetbrains+all,vscode,python,jupyternotebooks,linux,windows,macos

### JetBrains+all ###
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839

# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf

# Generated files
.idea/**/contentModel.xml

# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml

# Gradle
.idea/**/gradle.xml
.idea/**/libraries

# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr

# CMake
cmake-build-*/

# Mongo Explorer plugin
.idea/**/mongoSettings.xml

# File-based project format
*.iws

# IntelliJ
out/

# mpeltonen/sbt-idea plugin
.idea_modules/

# JIRA plugin
atlassian-ide-plugin.xml

# Cursive Clojure plugin
.idea/replstate.xml

# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties

# Editor-based Rest Client
.idea/httpRequests

# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser

### JetBrains+all Patch ###
# Ignores the whole .idea folder and all .iml files
# See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360

.idea/

# Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023

*.iml
modules.xml
.idea/misc.xml
*.ipr

# Sonarlint plugin
.idea/sonarlint

### JupyterNotebooks ###
# gitignore template for Jupyter Notebooks
# website: http://jupyter.org/

.ipynb_checkpoints
*/.ipynb_checkpoints/*

# IPython
profile_default/
ipython_config.py

# Remove previous ipynb_checkpoints
# git rm -r .ipynb_checkpoints/

### Linux ###
*~

# temporary files which can be created if a process still has a handle open of a deleted file
.fuse_hidden*

# KDE directory preferences
.directory

# Linux trash folder which might appear on any partition or disk
.Trash-*

# .nfs files are created when an open file is removed but is still being accessed
.nfs*

### macOS ###
# General
.DS_Store
.AppleDouble
.LSOverride

# Icon must end with two \r
Icon


# Thumbnails
._*

# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent

# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
pytestdebug.log

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/
doc/_build/

# PyBuilder
target/

# Jupyter Notebook

# IPython

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
pythonenv*

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# profiling data
.prof

### vscode ###
.vscode
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
*.code-workspace

### Windows ###
# Windows thumbnail cache files
Thumbs.db
Thumbs.db:encryptable
ehthumbs.db
ehthumbs_vista.db

# Dump file
*.stackdump

# Folder config file
[Dd]esktop.ini

# Recycle Bin used on file shares
$RECYCLE.BIN/

# Windows Installer files
*.cab
*.msi
*.msix
*.msm
*.msp

# Windows shortcuts
*.lnk

# End of https://www.toptal.com/developers/gitignore/api/jetbrains+all,vscode,python,jupyternotebooks,linux,windows,macos
MANIFEST.in
ADDED
@@ -0,0 +1 @@
include requirements.txt
README.md
CHANGED
@@ -1,13 +1 @@
----
-title: Relik
-emoji: 🐨
-colorFrom: gray
-colorTo: pink
-sdk: streamlit
-sdk_version: 1.27.2
-app_file: app.py
-pinned: false
-license: mit
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# relik
SETUP.cfg
ADDED
@@ -0,0 +1,8 @@
[metadata]
description-file = README.md

[build]
build-base = /tmp/build

[egg_info]
egg-base = /tmp
app.py
ADDED
@@ -0,0 +1,245 @@
import os
import re
import time
from pathlib import Path

import requests
import streamlit as st
from spacy import displacy
from streamlit_extras.badges import badge
from streamlit_extras.stylable_container import stylable_container

# RELIK = os.getenv("RELIK", "localhost:8000/api/entities")

import random

from relik.inference.annotator import Relik


def get_random_color(ents):
    colors = {}
    random_colors = generate_pastel_colors(len(ents))
    for ent in ents:
        colors[ent] = random_colors.pop(random.randint(0, len(random_colors) - 1))
    return colors


def floatrange(start, stop, steps):
    if int(steps) == 1:
        return [stop]
    return [
        start + float(i) * (stop - start) / (float(steps) - 1) for i in range(steps)
    ]


def hsl_to_rgb(h, s, l):
    def hue_2_rgb(v1, v2, v_h):
        while v_h < 0.0:
            v_h += 1.0
        while v_h > 1.0:
            v_h -= 1.0
        if 6 * v_h < 1.0:
            return v1 + (v2 - v1) * 6.0 * v_h
        if 2 * v_h < 1.0:
            return v2
        if 3 * v_h < 2.0:
            return v1 + (v2 - v1) * ((2.0 / 3.0) - v_h) * 6.0
        return v1

    # if not (0 <= s <= 1): raise ValueError, "s (saturation) parameter must be between 0 and 1."
    # if not (0 <= l <= 1): raise ValueError, "l (lightness) parameter must be between 0 and 1."

    r, b, g = (l * 255,) * 3
    if s != 0.0:
        if l < 0.5:
            var_2 = l * (1.0 + s)
        else:
            var_2 = (l + s) - (s * l)
        var_1 = 2.0 * l - var_2
        r = 255 * hue_2_rgb(var_1, var_2, h + (1.0 / 3.0))
        g = 255 * hue_2_rgb(var_1, var_2, h)
        b = 255 * hue_2_rgb(var_1, var_2, h - (1.0 / 3.0))

    return int(round(r)), int(round(g)), int(round(b))


def generate_pastel_colors(n):
    """Return different pastel colours.

    Input:
        n (integer) : The number of colors to return

    Output:
        A list of colors in HTML notation (eg.['#cce0ff', '#ffcccc', '#ccffe0', '#f5ccff', '#f5ffcc'])

    Example:
        >>> print generate_pastel_colors(5)
        ['#cce0ff', '#f5ccff', '#ffcccc', '#f5ffcc', '#ccffe0']
    """
    if n == 0:
        return []

    # To generate colors, we use the HSL colorspace (see http://en.wikipedia.org/wiki/HSL_color_space)
    start_hue = 0.6  # 0=red 1/3=0.333=green 2/3=0.666=blue
    saturation = 1.0
    lightness = 0.8
    # We take points around the chromatic circle (hue):
    # (Note: we generate n+1 colors, then drop the last one ([:-1]) because
    # it equals the first one (hue 0 = hue 1))
    return [
        "#%02x%02x%02x" % hsl_to_rgb(hue, saturation, lightness)
        for hue in floatrange(start_hue, start_hue + 1, n + 1)
    ][:-1]


def set_sidebar(css):
    white_link_wrapper = "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css'><a href='{}'>{}</a>"
    with st.sidebar:
        st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
        st.image(
            "http://nlp.uniroma1.it/static/website/sapienza-nlp-logo-wh.svg",
            use_column_width=True,
        )
        st.markdown("## ReLiK")
        st.write(
            f"""
            - {white_link_wrapper.format("#", "<i class='fa-solid fa-file'></i> Paper")}
            - {white_link_wrapper.format("https://github.com/SapienzaNLP/relik", "<i class='fa-brands fa-github'></i> GitHub")}
            - {white_link_wrapper.format("https://hub.docker.com/repository/docker/sapienzanlp/relik", "<i class='fa-brands fa-docker'></i> Docker Hub")}
            """,
            unsafe_allow_html=True,
        )
        st.markdown("## Sapienza NLP")
        st.write(
            f"""
            - {white_link_wrapper.format("https://nlp.uniroma1.it", "<i class='fa-solid fa-globe'></i> Webpage")}
            - {white_link_wrapper.format("https://github.com/SapienzaNLP", "<i class='fa-brands fa-github'></i> GitHub")}
            - {white_link_wrapper.format("https://twitter.com/SapienzaNLP", "<i class='fa-brands fa-twitter'></i> Twitter")}
            - {white_link_wrapper.format("https://www.linkedin.com/company/79434450", "<i class='fa-brands fa-linkedin'></i> LinkedIn")}
            """,
            unsafe_allow_html=True,
        )


def get_el_annotations(response):
    # swap labels key with ents
    dict_of_ents = {"text": response.text, "ents": []}
    dict_of_ents["ents"] = response.labels
    label_in_text = set(l["label"] for l in dict_of_ents["ents"])
    options = {"ents": label_in_text, "colors": get_random_color(label_in_text)}
    return dict_of_ents, options


def set_intro(css):
    # intro
    st.markdown("# ReLik")
    st.markdown(
        "### Retrieve, Read and LinK: Fast and Accurate Entity Linking and Relation Extraction on an Academic Budget"
    )
    # st.markdown(
    #     "This is a front-end for the paper [Universal Semantic Annotator: the First Unified API "
    #     "for WSD, SRL and Semantic Parsing](https://www.researchgate.net/publication/360671045_Universal_Semantic_Annotator_the_First_Unified_API_for_WSD_SRL_and_Semantic_Parsing), which will be presented at LREC 2022 by "
    #     "[Riccardo Orlando](https://riccorl.github.io), [Simone Conia](https://c-simone.github.io/), "
    #     "[Stefano Faralli](https://corsidilaurea.uniroma1.it/it/users/stefanofaralliuniroma1it), and [Roberto Navigli](https://www.diag.uniroma1.it/navigli/)."
    # )
    badge(type="github", name="sapienzanlp/relik")
    badge(type="pypi", name="relik")


def run_client():
    with open(Path(__file__).parent / "style.css") as f:
        css = f.read()

    st.set_page_config(
        page_title="ReLik",
        page_icon="🦮",
        layout="wide",
    )
    set_sidebar(css)
    set_intro(css)

    # text input
    text = st.text_area(
        "Enter Text Below:",
        value="Obama went to Rome for a quick vacation.",
        height=200,
        max_chars=500,
    )

    with stylable_container(
        key="annotate_button",
        css_styles="""
            button {
                background-color: #802433;
                color: white;
                border-radius: 25px;
            }
            """,
    ):
        submit = st.button("Annotate")
    # submit = st.button("Run")

    relik = Relik(
        question_encoder="riccorl/relik-retriever-aida-blink-pretrain-omniencoder",
        document_index="riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder",
        reader="riccorl/relik-reader-aida-deberta-small",
        top_k=100,
        window_size=32,
        window_stride=16,
        candidates_preprocessing_fn="relik.inference.preprocessing.wikipedia_title_and_openings_preprocessing",
    )

    # ReLik API call
    if submit:
        text = text.strip()
        if text:
            st.markdown("####")
            st.markdown("#### Entity Linking")
            with st.spinner(text="In progress"):
                response = relik(text)
                # response = requests.post(RELIK, json=text)
                # if response.status_code != 200:
                #     st.error("Error: {}".format(response.status_code))
                # else:
                #     response = response.json()

            # Entity Linking
            # with stylable_container(
            #     key="container_with_border",
            #     css_styles="""
            #     {
            #         border: 1px solid rgba(49, 51, 63, 0.2);
            #         border-radius: 0.5rem;
            #         padding: 0.5rem;
            #         padding-bottom: 2rem;
            #     }
            #     """,
            # ):
            # st.markdown("##")
            dict_of_ents, options = get_el_annotations(response=response)
            display = displacy.render(
                dict_of_ents, manual=True, style="ent", options=options
            )
            display = display.replace("\n", " ")
            # wsd_display = re.sub(
            #     r"(wiki::\d+\w)",
            #     r"<a href='https://babelnet.org/synset?id=\g<1>&orig=\g<1>&lang={}'>\g<1></a>".format(
            #         language.upper()
            #     ),
            #     wsd_display,
            # )
            with st.container():
                st.write(display, unsafe_allow_html=True)

            st.markdown("####")
            st.markdown("#### Relation Extraction")

            with st.container():
                st.write("Coming :)", unsafe_allow_html=True)

        else:
            st.error("Please enter some text.")


if __name__ == "__main__":
    run_client()
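For a quick check outside the UI, the same annotator that run_client builds can be called directly; this is a minimal sketch using the checkpoints hard-coded above (downloading them from the Hugging Face Hub is assumed to work in your environment):

# Sketch: the Relik annotator from app.py, without the Streamlit front-end.
from relik.inference.annotator import Relik

relik = Relik(
    question_encoder="riccorl/relik-retriever-aida-blink-pretrain-omniencoder",
    document_index="riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder",
    reader="riccorl/relik-reader-aida-deberta-small",
    top_k=100,
    window_size=32,
    window_stride=16,
    candidates_preprocessing_fn="relik.inference.preprocessing.wikipedia_title_and_openings_preprocessing",
)
response = relik("Obama went to Rome for a quick vacation.")
print(response.text)    # get_el_annotations above reads .text and .labels
print(response.labels)  # off the returned object

The front-end itself is started through Streamlit's own CLI (streamlit run app.py), which is also how the Space runs it, given the app_file: app.py entry in the replaced README front matter.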
dockerfiles/Dockerfile.cpu
ADDED
@@ -0,0 +1,17 @@
FROM tiangolo/uvicorn-gunicorn:python3.10-slim

# Copy and install requirements.txt
COPY ./requirements.txt ./requirements.txt
COPY ./src /app
COPY ./scripts/start.sh /start.sh
COPY ./scripts/prestart.sh /app
COPY ./scripts/gunicorn_conf.py /gunicorn_conf.py
COPY ./scripts/start-reload.sh /start-reload.sh
COPY ./VERSION /
RUN mkdir -p /app/resources/model \
    && pip install --no-cache-dir -r requirements.txt \
    && chmod +x /start.sh && chmod +x /start-reload.sh
ARG MODEL_PATH
COPY ${MODEL_PATH}/* /app/resources/model/

ENV APP_MODULE=main:app
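Since this image bakes the model in at build time through the MODEL_PATH build argument, building it presumably looks something like docker build -f dockerfiles/Dockerfile.cpu --build-arg MODEL_PATH=./model -t relik-cpu . run from the repository root, with MODEL_PATH pointing at a local directory of model files (the path and tag here are placeholders).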
dockerfiles/Dockerfile.cuda
ADDED
@@ -0,0 +1,38 @@
FROM nvidia/cuda:12.2.0-base-ubuntu20.04

ARG DEBIAN_FRONTEND=noninteractive

RUN apt-get update \
    && apt-get install \
    curl wget python3.10 \
    python3.10-distutils \
    python3-pip \
    curl wget -y \
    && rm -rf /var/lib/apt/lists/*

# FastAPI section
# device env
ENV DEVICE="cuda"
# Copy and install requirements.txt
COPY ./gpu-requirements.txt ./requirements.txt
COPY ./src /app
COPY ./scripts/start.sh /start.sh
COPY ./scripts/gunicorn_conf.py /gunicorn_conf.py
COPY ./scripts/start-reload.sh /start-reload.sh
COPY ./scripts/prestart.sh /app
COPY ./VERSION /
RUN mkdir -p /app/resources/model \
    && pip install --upgrade --no-cache-dir -r requirements.txt \
    && chmod +x /start.sh \
    && chmod +x /start-reload.sh
ARG MODEL_NAME_OR_PATH

WORKDIR /app

ENV PYTHONPATH=/app

EXPOSE 80

# Run the start script, it will check for an /app/prestart.sh script (e.g. for migrations)
# And then will start Gunicorn with Uvicorn
CMD ["/start.sh"]
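Unlike the CPU image, this one declares MODEL_NAME_OR_PATH but, as written here, never copies it in, so the model is presumably resolved at run time. A hedged sketch of building and running it (image tag and host port are placeholders): docker build -f dockerfiles/Dockerfile.cuda -t relik-gpu . followed by docker run --gpus all -p 8000:80 relik-gpu, matching the EXPOSE 80 above.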
examples/train_retriever.py
ADDED
@@ -0,0 +1,45 @@
from relik.retriever.trainer import RetrieverTrainer
from relik import GoldenRetriever
from relik.retriever.indexers.inmemory import InMemoryDocumentIndex
from relik.retriever.data.datasets import AidaInBatchNegativesDataset

if __name__ == "__main__":
    # instantiate retriever
    document_index = InMemoryDocumentIndex(
        documents="/root/golden-retriever-v2/data/dpr-like/el/definitions.txt",
        device="cuda",
        precision="16",
    )
    retriever = GoldenRetriever(
        question_encoder="intfloat/e5-small-v2", document_index=document_index
    )

    train_dataset = AidaInBatchNegativesDataset(
        name="aida_train",
        path="/root/golden-retriever-v2/data/dpr-like/el/aida_32_tokens_topic/train.jsonl",
        tokenizer=retriever.question_tokenizer,
        question_batch_size=64,
        passage_batch_size=400,
        max_passage_length=64,
        use_topics=True,
        shuffle=True,
    )
    val_dataset = AidaInBatchNegativesDataset(
        name="aida_val",
        path="/root/golden-retriever-v2/data/dpr-like/el/aida_32_tokens_topic/val.jsonl",
        tokenizer=retriever.question_tokenizer,
        question_batch_size=64,
        passage_batch_size=400,
        max_passage_length=64,
        use_topics=True,
    )

    trainer = RetrieverTrainer(
        retriever=retriever,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        max_steps=25_000,
        wandb_offline_mode=True,
    )

    trainer.train()
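The /root/golden-retriever-v2 paths above are machine-specific; after pointing documents and the two dataset path arguments at local copies of the DPR-style AIDA files, the example is meant to be launched directly, e.g. python examples/train_retriever.py.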
pyproject.toml
ADDED
@@ -0,0 +1,15 @@
[tool.black]
include = '\.pyi?$'
exclude = '''
/(
    \.git
  | \.hg
  | \.mypy_cache
  | \.tox
  | \.venv
  | _build
  | buck-out
  | build
  | dist
)/
'''
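The only tool configured here is black; with this file at the repository root, a plain black . invocation picks up the include and exclude patterns above automatically.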
relik/__init__.py
ADDED
@@ -0,0 +1 @@
from relik.retriever.pytorch_modules.model import GoldenRetriever
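This re-export makes GoldenRetriever importable from the package root; a quick sanity check of the aliasing:

# Both import paths resolve to the same class object.
from relik import GoldenRetriever
from relik.retriever.pytorch_modules.model import GoldenRetriever as Original

assert GoldenRetriever is Original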
relik/common/__init__.py
ADDED
File without changes
relik/common/log.py
ADDED
@@ -0,0 +1,97 @@
import logging
import sys
import threading
from typing import Optional

from rich import get_console

_lock = threading.Lock()
_default_handler: Optional[logging.Handler] = None

_default_log_level = logging.WARNING

# fancy logger
_console = get_console()


def _get_library_name() -> str:
    return __name__.split(".")[0]


def _get_library_root_logger() -> logging.Logger:
    return logging.getLogger(_get_library_name())


def _configure_library_root_logger() -> None:
    global _default_handler

    with _lock:
        if _default_handler:
            # This library has already configured the library root logger.
            return
        _default_handler = logging.StreamHandler()  # Set sys.stderr as stream.
        _default_handler.flush = sys.stderr.flush

        # Apply our default configuration to the library root logger.
        library_root_logger = _get_library_root_logger()
        library_root_logger.addHandler(_default_handler)
        library_root_logger.setLevel(_default_log_level)
        library_root_logger.propagate = False


def _reset_library_root_logger() -> None:
    global _default_handler

    with _lock:
        if not _default_handler:
            return

        library_root_logger = _get_library_root_logger()
        library_root_logger.removeHandler(_default_handler)
        library_root_logger.setLevel(logging.NOTSET)
        _default_handler = None


def set_log_level(level: int, logger: logging.Logger = None) -> None:
    """
    Set the log level.
    Args:
        level (:obj:`int`):
            Logging level.
        logger (:obj:`logging.Logger`):
            Logger to set the log level.
    """
    if not logger:
        _configure_library_root_logger()
        logger = _get_library_root_logger()
    logger.setLevel(level)


def get_logger(
    name: Optional[str] = None,
    level: Optional[int] = None,
    formatter: Optional[logging.Formatter] = None,
) -> logging.Logger:
    """
    Return a logger with the specified name.
    """

    if name is None:
        name = _get_library_name()

    _configure_library_root_logger()

    if level is not None:
        set_log_level(level)

    if formatter is None:
        formatter = logging.Formatter(
            "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
        )
    _default_handler.setFormatter(formatter)

    return logging.getLogger(name)


def get_console_logger():
    return _console
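A minimal usage sketch of the helpers above: calling get_logger without a name returns the library root logger (lazily installing the shared stderr handler), and set_log_level adjusts its threshold afterwards.

import logging

from relik.common.log import get_logger, set_log_level

logger = get_logger(level=logging.DEBUG)  # no name: the library root logger
logger.debug("relik logging configured")  # emitted through the shared stderr handler
set_log_level(logging.WARNING)  # raise the threshold back up afterwards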
relik/common/upload.py
ADDED
@@ -0,0 +1,128 @@
import argparse
import json
import logging
import os
import tempfile
import zipfile
from datetime import datetime
from pathlib import Path
from typing import Optional, Union

import huggingface_hub

from relik.common.log import get_logger
from relik.common.utils import SAPIENZANLP_DATE_FORMAT, get_md5

logger = get_logger(level=logging.DEBUG)


def create_info_file(tmpdir: Path):
    logger.debug("Computing md5 of model.zip")
    md5 = get_md5(tmpdir / "model.zip")
    date = datetime.now().strftime(SAPIENZANLP_DATE_FORMAT)

    logger.debug("Dumping info.json file")
    with (tmpdir / "info.json").open("w") as f:
        json.dump(dict(md5=md5, upload_date=date), f, indent=2)


def zip_run(
    dir_path: Union[str, os.PathLike],
    tmpdir: Union[str, os.PathLike],
    zip_name: str = "model.zip",
) -> Path:
    logger.debug(f"zipping {dir_path} to {tmpdir}")
    # creates a zip version of the provided dir_path
    run_dir = Path(dir_path)
    zip_path = tmpdir / zip_name

    with zipfile.ZipFile(zip_path, "w") as zip_file:
        # fully zip the run directory maintaining its structure
        for file in run_dir.rglob("*.*"):
            if file.is_dir():
                continue

            zip_file.write(file, arcname=file.relative_to(run_dir))

    return zip_path


def upload(
    model_dir: Union[str, os.PathLike],
    model_name: str,
    organization: Optional[str] = None,
    repo_name: Optional[str] = None,
    commit: Optional[str] = None,
    archive: bool = False,
):
    token = huggingface_hub.HfFolder.get_token()
    if token is None:
        print(
            "No HuggingFace token found. You need to execute `huggingface-cli login` first!"
        )
        return

    repo_id = repo_name or model_name
    if organization is not None:
        repo_id = f"{organization}/{repo_id}"
    with tempfile.TemporaryDirectory() as tmpdir:
        api = huggingface_hub.HfApi()
        repo_url = api.create_repo(
            token=token,
            repo_id=repo_id,
            exist_ok=True,
        )
        repo = huggingface_hub.Repository(
            str(tmpdir), clone_from=repo_url, use_auth_token=token
        )

        tmp_path = Path(tmpdir)
        if archive:
            # otherwise we zip the model_dir
            logger.debug(f"Zipping {model_dir} to {tmp_path}")
            zip_run(model_dir, tmp_path)
            create_info_file(tmp_path)
        else:
            # if the user wants to upload a transformers model, we don't need to zip it
            # we just need to copy the files to the tmpdir
            logger.debug(f"Copying {model_dir} to {tmpdir}")
            os.system(f"cp -r {model_dir}/* {tmpdir}")

        # this method automatically puts large files (>10MB) into git lfs
        repo.push_to_hub(commit_message=commit or "Automatic push from sapienzanlp")


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "model_dir", help="The directory of the model you want to upload"
    )
    parser.add_argument("model_name", help="The model you want to upload")
    parser.add_argument(
        "--organization",
        help="the name of the organization where you want to upload the model",
    )
    parser.add_argument(
        "--repo_name",
        help="Optional name to use when uploading to the HuggingFace repository",
    )
    parser.add_argument(
        "--commit", help="Commit message to use when pushing to the HuggingFace Hub"
    )
    parser.add_argument(
        "--archive",
        action="store_true",
        help="""
            Whether to compress the model directory before uploading it.
            If True, the model directory will be zipped and the zip file will be uploaded.
            If False, the model directory will be uploaded as is.""",
    )
    return parser.parse_args()


def main():
    upload(**vars(parse_args()))


if __name__ == "__main__":
    main()
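Thanks to the __main__ guard, the module doubles as a CLI (python -m relik.common.upload, with the arguments defined in parse_args above); programmatically it is a single call. A sketch with placeholder paths and names:

from relik.common.upload import upload

upload(
    model_dir="experiments/my-model",  # placeholder: directory holding the model files
    model_name="my-model",             # placeholder: repository name on the Hub
    organization="sapienzanlp",        # optional namespace prepended to the repo id
    commit="first model upload",       # optional commit message
    archive=True,                      # zip the directory and add an md5 info.json first
)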
relik/common/utils.py
ADDED
@@ -0,0 +1,609 @@
import importlib.util
import json
import logging
import os
import shutil
import tarfile
import tempfile
from functools import partial
from hashlib import sha256
from pathlib import Path
from typing import Any, BinaryIO, Dict, List, Optional, Union
from urllib.parse import urlparse
from zipfile import ZipFile, is_zipfile

import huggingface_hub
import requests
import tqdm
from filelock import FileLock
from transformers.utils.hub import cached_file as hf_cached_file

from relik.common.log import get_logger

# name constants
WEIGHTS_NAME = "weights.pt"
ONNX_WEIGHTS_NAME = "weights.onnx"
CONFIG_NAME = "config.yaml"
LABELS_NAME = "labels.json"

# SAPIENZANLP_USER_NAME = "sapienzanlp"
SAPIENZANLP_USER_NAME = "riccorl"
SAPIENZANLP_HF_MODEL_REPO_URL = "riccorl/{model_id}"
SAPIENZANLP_HF_MODEL_REPO_ARCHIVE_URL = (
    f"{SAPIENZANLP_HF_MODEL_REPO_URL}/resolve/main/model.zip"
)
# path constants
SAPIENZANLP_CACHE_DIR = os.getenv("SAPIENZANLP_CACHE_DIR", Path.home() / ".sapienzanlp")
SAPIENZANLP_DATE_FORMAT = "%Y-%m-%d %H-%M-%S"


logger = get_logger(__name__)


def sapienzanlp_model_urls(model_id: str) -> str:
    """
    Returns the URL for a possible SapienzaNLP valid model.

    Args:
        model_id (:obj:`str`):
            A SapienzaNLP model id.

    Returns:
        :obj:`str`: The url for the model id.
    """
    # check if there is already the namespace of the user
    if "/" in model_id:
        return model_id
    return SAPIENZANLP_HF_MODEL_REPO_URL.format(model_id=model_id)


def is_package_available(package_name: str) -> bool:
    """
    Check if a package is available.

    Args:
        package_name (`str`): The name of the package to check.
    """
    return importlib.util.find_spec(package_name) is not None


def load_json(path: Union[str, Path]) -> Any:
    """
    Load a json file provided in input.

    Args:
        path (`Union[str, Path]`): The path to the json file to load.

    Returns:
        `Any`: The loaded json file.
    """
    with open(path, encoding="utf8") as f:
        return json.load(f)


def dump_json(document: Any, path: Union[str, Path], indent: Optional[int] = None):
    """
    Dump input to json file.

    Args:
        document (`Any`): The document to dump.
        path (`Union[str, Path]`): The path to dump the document to.
        indent (`Optional[int]`): The indent to use for the json file.

    """
    with open(path, "w", encoding="utf8") as outfile:
        json.dump(document, outfile, indent=indent)


def get_md5(path: Path):
    """
    Get the MD5 value of a path.
    """
    import hashlib

    with path.open("rb") as fin:
        data = fin.read()
    return hashlib.md5(data).hexdigest()


def file_exists(path: Union[str, os.PathLike]) -> bool:
    """
    Check if the file at :obj:`path` exists.

    Args:
        path (:obj:`str`, :obj:`os.PathLike`):
            Path to check.

    Returns:
        :obj:`bool`: :obj:`True` if the file exists.
    """
    return Path(path).exists()


def dir_exists(path: Union[str, os.PathLike]) -> bool:
    """
    Check if the directory at :obj:`path` exists.

    Args:
        path (:obj:`str`, :obj:`os.PathLike`):
            Path to check.

    Returns:
        :obj:`bool`: :obj:`True` if the directory exists.
    """
    return Path(path).is_dir()


def is_remote_url(url_or_filename: Union[str, Path]):
    """
    Returns :obj:`True` if the input path is an url.

    Args:
        url_or_filename (:obj:`str`, :obj:`Path`):
            path to check.

    Returns:
        :obj:`bool`: :obj:`True` if the input path is an url, :obj:`False` otherwise.

    """
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https")


def url_to_filename(resource: str, etag: str = None) -> str:
    """
    Convert a `resource` into a hashed filename in a repeatable way.
    If `etag` is specified, append its hash to the resources's, delimited
    by a period.
    """
    resource_bytes = resource.encode("utf-8")
    resource_hash = sha256(resource_bytes)
    filename = resource_hash.hexdigest()

    if etag:
        etag_bytes = etag.encode("utf-8")
        etag_hash = sha256(etag_bytes)
        filename += "." + etag_hash.hexdigest()

    return filename


def download_resource(
    url: str,
    temp_file: BinaryIO,
    headers=None,
):
    """
    Download remote file.
    """

    if headers is None:
        headers = {}

    r = requests.get(url, stream=True, headers=headers)
    r.raise_for_status()
    content_length = r.headers.get("Content-Length")
    total = int(content_length) if content_length is not None else None
    # `tqdm` is imported as a module above, so the progress-bar class is `tqdm.tqdm`
    progress = tqdm.tqdm(
        unit="B",
        unit_scale=True,
        total=total,
        desc="Downloading",
        disable=logger.level in [logging.NOTSET],
    )
    for chunk in r.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive new chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()


def download_and_cache(
    url: Union[str, Path],
    cache_dir: Union[str, Path] = None,
    force_download: bool = False,
):
    if cache_dir is None:
        cache_dir = SAPIENZANLP_CACHE_DIR
    if isinstance(url, Path):
        url = str(url)

    # check if cache dir exists
    Path(cache_dir).mkdir(parents=True, exist_ok=True)

    # check if file is private
    headers = {}
    try:
        r = requests.head(url, allow_redirects=False, timeout=10)
        r.raise_for_status()
    except requests.exceptions.HTTPError:
        if r.status_code == 401:
            hf_token = huggingface_hub.HfFolder.get_token()
            if hf_token is None:
                raise ValueError(
                    "You need to login to HuggingFace to download this model "
                    "(use the `huggingface-cli login` command)"
                )
            headers["Authorization"] = f"Bearer {hf_token}"

    etag = None
    try:
        r = requests.head(url, allow_redirects=True, timeout=10, headers=headers)
        r.raise_for_status()
        etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
        # We favor a custom header indicating the etag of the linked resource, and
        # we fallback to the regular etag header.
        # If we don't have any of those, raise an error.
        if etag is None:
            raise OSError(
                "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
            )
        # In case of a redirect,
        # save an extra redirect on the request.get call,
        # and ensure we download the exact atomic version even if it changed
        # between the HEAD and the GET (unlikely, but hey).
        if 300 <= r.status_code <= 399:
            url = r.headers["Location"]
    except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
        # Actually raise for those subclasses of ConnectionError
        raise
    except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
        # Otherwise, our Internet connection is down.
        # etag is None
        pass

    # get filename from the url
    filename = url_to_filename(url, etag)
    # get cache path to put the file
    cache_path = Path(cache_dir) / filename  # cache_dir may be a str (e.g. from the env var)

    # the file is already here, return it
    if file_exists(cache_path) and not force_download:
        logger.info(
            f"{url} found in cache, set `force_download=True` to force the download"
        )
        return cache_path

    cache_path = str(cache_path)
    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):
        # If the download just completed while the lock was activated.
        if file_exists(cache_path) and not force_download:
            # Even if returning early like here, the lock will be released.
            return cache_path

        temp_file_manager = partial(
            tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False
        )

        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise, you get corrupt cache entries if the download gets interrupted.
        with temp_file_manager() as temp_file:
            logger.info(
                f"{url} not found in cache or `force_download` set to `True`, downloading to {temp_file.name}"
            )
            download_resource(url, temp_file, headers)

        logger.info(f"storing {url} in cache at {cache_path}")
        os.replace(temp_file.name, cache_path)

        # NamedTemporaryFile creates a file with hardwired 0600 perms (ignoring umask), so fixing it.
        umask = os.umask(0o666)
        os.umask(umask)
        os.chmod(cache_path, 0o666 & ~umask)

        logger.info(f"creating metadata file for {cache_path}")
        meta = {"url": url}  # , "etag": etag}
        meta_path = cache_path + ".json"
        with open(meta_path, "w") as meta_file:
            json.dump(meta, meta_file)

    return cache_path


def download_from_hf(
    path_or_repo_id: Union[str, Path],
    filenames: Optional[List[str]],
    cache_dir: Union[str, Path] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    subfolder: str = "",
):
    if isinstance(path_or_repo_id, Path):
        path_or_repo_id = str(path_or_repo_id)

    downloaded_paths = []
    for filename in filenames:
        downloaded_path = hf_cached_file(
            path_or_repo_id,
            filename,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            use_auth_token=use_auth_token,
            revision=revision,
            local_files_only=local_files_only,
            subfolder=subfolder,
        )
        downloaded_paths.append(downloaded_path)

    # we want the folder where the files are downloaded
    # the best guess is the parent folder of the first file
    probably_the_folder = Path(downloaded_paths[0]).parent
    return probably_the_folder


def model_name_or_path_resolver(model_name_or_dir: Union[str, os.PathLike]) -> str:
    """
    Resolve a model name or directory to a model archive name or directory.

    Args:
        model_name_or_dir (:obj:`str` or :obj:`os.PathLike`):
            A model name or directory.

    Returns:
        :obj:`str`: The model archive name or directory.
    """
    if is_remote_url(model_name_or_dir):
        # if model_name_or_dir is a URL
        # download it and try to load
        model_archive = model_name_or_dir
    elif Path(model_name_or_dir).is_dir() or Path(model_name_or_dir).is_file():
        # if model_name_or_dir is a local directory or
        # an archive file try to load it
        model_archive = model_name_or_dir
    else:
        # probably model_name_or_dir is a sapienzanlp model id
        # guess the url and try to download
        model_name_or_dir_ = model_name_or_dir
        # raise ValueError(f"Providing a model id is not supported yet.")
        model_archive = sapienzanlp_model_urls(model_name_or_dir_)

    return model_archive


def from_cache(
    url_or_filename: Union[str, Path],
    cache_dir: Union[str, Path] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    subfolder: str = "",
    filenames: Optional[List[str]] = None,
) -> Path:
    """
    Given something that could be either a local path or a URL (or a SapienzaNLP model id),
    determine which one and return a path to the corresponding file.

    Args:
        url_or_filename (:obj:`str` or :obj:`Path`):
            A path to a local file or a URL (or a SapienzaNLP model id).
        cache_dir (:obj:`str` or :obj:`Path`, `optional`):
            Path to a directory in which a downloaded file will be cached.
        force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to re-download the file even if it already exists.
        resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to delete incompletely received files. Attempts to resume the download if such a file
            exists.
        proxies (:obj:`Dict[str, str]`, `optional`):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        use_auth_token (:obj:`Union[bool, str]`, `optional`):
            Optional string or boolean to use as Bearer token for remote files. If :obj:`True`, will get token from
            :obj:`~transformers.hf_api.HfApi`. If :obj:`str`, will use that string as token.
        revision (:obj:`str`, `optional`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
            identifier allowed by git.
        local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to raise an error if the file to be downloaded is local.
        subfolder (:obj:`str`, `optional`):
            In case the relevant file is in a subfolder of the URL, specify it here.
|
| 413 |
+
filenames (:obj:`List[str]`, `optional`):
|
| 414 |
+
List of filenames to look for in the directory structure.
|
| 415 |
+
|
| 416 |
+
Returns:
|
| 417 |
+
:obj:`Path`: Path to the cached file.
|
| 418 |
+
"""
|
| 419 |
+
|
| 420 |
+
url_or_filename = model_name_or_path_resolver(url_or_filename)
|
| 421 |
+
|
| 422 |
+
if cache_dir is None:
|
| 423 |
+
cache_dir = SAPIENZANLP_CACHE_DIR
|
| 424 |
+
|
| 425 |
+
if file_exists(url_or_filename):
|
| 426 |
+
logger.info(f"{url_or_filename} is a local path or file")
|
| 427 |
+
output_path = url_or_filename
|
| 428 |
+
elif is_remote_url(url_or_filename):
|
| 429 |
+
# URL, so get it from the cache (downloading if necessary)
|
| 430 |
+
output_path = download_and_cache(
|
| 431 |
+
url_or_filename,
|
| 432 |
+
cache_dir=cache_dir,
|
| 433 |
+
force_download=force_download,
|
| 434 |
+
)
|
| 435 |
+
else:
|
| 436 |
+
if filenames is None:
|
| 437 |
+
filenames = [WEIGHTS_NAME, CONFIG_NAME, LABELS_NAME]
|
| 438 |
+
output_path = download_from_hf(
|
| 439 |
+
url_or_filename,
|
| 440 |
+
filenames,
|
| 441 |
+
cache_dir,
|
| 442 |
+
force_download,
|
| 443 |
+
resume_download,
|
| 444 |
+
proxies,
|
| 445 |
+
use_auth_token,
|
| 446 |
+
revision,
|
| 447 |
+
local_files_only,
|
| 448 |
+
subfolder,
|
| 449 |
+
)
|
| 450 |
+
|
| 451 |
+
# if is_hf_hub_url(url_or_filename):
|
| 452 |
+
# HuggingFace Hub
|
| 453 |
+
# output_path = hf_hub_download_url(url_or_filename)
|
| 454 |
+
# elif is_remote_url(url_or_filename):
|
| 455 |
+
# # URL, so get it from the cache (downloading if necessary)
|
| 456 |
+
# output_path = download_and_cache(
|
| 457 |
+
# url_or_filename,
|
| 458 |
+
# cache_dir=cache_dir,
|
| 459 |
+
# force_download=force_download,
|
| 460 |
+
# )
|
| 461 |
+
# elif file_exists(url_or_filename):
|
| 462 |
+
# logger.info(f"{url_or_filename} is a local path or file")
|
| 463 |
+
# # File, and it exists.
|
| 464 |
+
# output_path = url_or_filename
|
| 465 |
+
# elif urlparse(url_or_filename).scheme == "":
|
| 466 |
+
# # File, but it doesn't exist.
|
| 467 |
+
# raise EnvironmentError(f"file {url_or_filename} not found")
|
| 468 |
+
# else:
|
| 469 |
+
# # Something unknown
|
| 470 |
+
# raise ValueError(
|
| 471 |
+
# f"unable to parse {url_or_filename} as a URL or as a local path"
|
| 472 |
+
# )
|
| 473 |
+
|
| 474 |
+
if dir_exists(output_path) or (
|
| 475 |
+
not is_zipfile(output_path) and not tarfile.is_tarfile(output_path)
|
| 476 |
+
):
|
| 477 |
+
return Path(output_path)
|
| 478 |
+
|
| 479 |
+
# Path where we extract compressed archives
|
| 480 |
+
# for now it will extract it in the same folder
|
| 481 |
+
# maybe implement extraction in the sapienzanlp folder
|
| 482 |
+
# when using local archive path?
|
| 483 |
+
logger.info("Extracting compressed archive")
|
| 484 |
+
output_dir, output_file = os.path.split(output_path)
|
| 485 |
+
output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
|
| 486 |
+
output_path_extracted = os.path.join(output_dir, output_extract_dir_name)
|
| 487 |
+
|
| 488 |
+
# already extracted, do not extract
|
| 489 |
+
if (
|
| 490 |
+
os.path.isdir(output_path_extracted)
|
| 491 |
+
and os.listdir(output_path_extracted)
|
| 492 |
+
and not force_download
|
| 493 |
+
):
|
| 494 |
+
return Path(output_path_extracted)
|
| 495 |
+
|
| 496 |
+
# Prevent parallel extractions
|
| 497 |
+
lock_path = output_path + ".lock"
|
| 498 |
+
with FileLock(lock_path):
|
| 499 |
+
shutil.rmtree(output_path_extracted, ignore_errors=True)
|
| 500 |
+
os.makedirs(output_path_extracted)
|
| 501 |
+
if is_zipfile(output_path):
|
| 502 |
+
with ZipFile(output_path, "r") as zip_file:
|
| 503 |
+
zip_file.extractall(output_path_extracted)
|
| 504 |
+
zip_file.close()
|
| 505 |
+
elif tarfile.is_tarfile(output_path):
|
| 506 |
+
tar_file = tarfile.open(output_path)
|
| 507 |
+
tar_file.extractall(output_path_extracted)
|
| 508 |
+
tar_file.close()
|
| 509 |
+
else:
|
| 510 |
+
raise EnvironmentError(
|
| 511 |
+
f"Archive format of {output_path} could not be identified"
|
| 512 |
+
)
|
| 513 |
+
|
| 514 |
+
# remove lock file, is it safe?
|
| 515 |
+
os.remove(lock_path)
|
| 516 |
+
|
| 517 |
+
return Path(output_path_extracted)
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
def is_str_a_path(maybe_path: str) -> bool:
|
| 521 |
+
"""
|
| 522 |
+
Check if a string is a path.
|
| 523 |
+
|
| 524 |
+
Args:
|
| 525 |
+
maybe_path (`str`): The string to check.
|
| 526 |
+
|
| 527 |
+
Returns:
|
| 528 |
+
`bool`: `True` if the string is a path, `False` otherwise.
|
| 529 |
+
"""
|
| 530 |
+
# first check if it is a path
|
| 531 |
+
if Path(maybe_path).exists():
|
| 532 |
+
return True
|
| 533 |
+
# check if it is a relative path
|
| 534 |
+
if Path(os.path.join(os.getcwd(), maybe_path)).exists():
|
| 535 |
+
return True
|
| 536 |
+
# otherwise it is not a path
|
| 537 |
+
return False
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
def relative_to_absolute_path(path: str) -> os.PathLike:
|
| 541 |
+
"""
|
| 542 |
+
Convert a relative path to an absolute path.
|
| 543 |
+
|
| 544 |
+
Args:
|
| 545 |
+
path (`str`): The relative path to convert.
|
| 546 |
+
|
| 547 |
+
Returns:
|
| 548 |
+
`os.PathLike`: The absolute path.
|
| 549 |
+
"""
|
| 550 |
+
if not is_str_a_path(path):
|
| 551 |
+
raise ValueError(f"{path} is not a path")
|
| 552 |
+
if Path(path).exists():
|
| 553 |
+
return Path(path).absolute()
|
| 554 |
+
if Path(os.path.join(os.getcwd(), path)).exists():
|
| 555 |
+
return Path(os.path.join(os.getcwd(), path)).absolute()
|
| 556 |
+
raise ValueError(f"{path} is not a path")
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
def to_config(object_to_save: Any) -> Dict[str, Any]:
|
| 560 |
+
"""
|
| 561 |
+
Convert an object to a dictionary.
|
| 562 |
+
|
| 563 |
+
Returns:
|
| 564 |
+
`Dict[str, Any]`: The dictionary representation of the object.
|
| 565 |
+
"""
|
| 566 |
+
|
| 567 |
+
def obj_to_dict(obj):
|
| 568 |
+
match obj:
|
| 569 |
+
case dict():
|
| 570 |
+
data = {}
|
| 571 |
+
for k, v in obj.items():
|
| 572 |
+
data[k] = obj_to_dict(v)
|
| 573 |
+
return data
|
| 574 |
+
|
| 575 |
+
case list() | tuple():
|
| 576 |
+
return [obj_to_dict(x) for x in obj]
|
| 577 |
+
|
| 578 |
+
case object(__dict__=_):
|
| 579 |
+
data = {
|
| 580 |
+
"_target_": f"{obj.__class__.__module__}.{obj.__class__.__name__}",
|
| 581 |
+
}
|
| 582 |
+
for k, v in obj.__dict__.items():
|
| 583 |
+
if not k.startswith("_"):
|
| 584 |
+
data[k] = obj_to_dict(v)
|
| 585 |
+
return data
|
| 586 |
+
|
| 587 |
+
case _:
|
| 588 |
+
return obj
|
| 589 |
+
|
| 590 |
+
return obj_to_dict(object_to_save)
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
def get_callable_from_string(callable_fn: str) -> Any:
|
| 594 |
+
"""
|
| 595 |
+
Get a callable from a string.
|
| 596 |
+
|
| 597 |
+
Args:
|
| 598 |
+
callable_fn (`str`):
|
| 599 |
+
The string representation of the callable.
|
| 600 |
+
|
| 601 |
+
Returns:
|
| 602 |
+
`Any`: The callable.
|
| 603 |
+
"""
|
| 604 |
+
# separate the function name from the module name
|
| 605 |
+
module_name, function_name = callable_fn.rsplit(".", 1)
|
| 606 |
+
# import the module
|
| 607 |
+
module = importlib.import_module(module_name)
|
| 608 |
+
# get the function
|
| 609 |
+
return getattr(module, function_name)
|
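Taken together, `model_name_or_path_resolver` and `from_cache` let callers pass a local path, a raw URL, or a model id interchangeably. A minimal usage sketch follows; the repo id below is a placeholder, and we assume it resolves to a HuggingFace Hub repository containing a `config.yaml`:

from relik.common.utils import from_cache

# Resolve a (hypothetical) Hub repo id to a local directory: the requested
# files are downloaded into the cache, and their parent folder is returned.
model_dir = from_cache(
    "some-org/some-relik-model",  # placeholder repo id
    filenames=["config.yaml"],
    force_download=False,
)
print(model_dir)  # a directory under the local cache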
relik/inference/__init__.py
ADDED
File without changes
relik/inference/annotator.py
ADDED
@@ -0,0 +1,422 @@
import os
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union

import hydra
from omegaconf import OmegaConf
from rich.pretty import pprint

from relik.common.log import get_console_logger, get_logger
from relik.common.upload import upload
from relik.common.utils import CONFIG_NAME, from_cache, get_callable_from_string
from relik.inference.data.objects import EntitySpan, RelikOutput
from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer
from relik.inference.data.window.manager import WindowManager
from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction
from relik.reader.relik_reader import RelikReader
from relik.retriever.data.utils import batch_generator
from relik.retriever.indexers.base import BaseDocumentIndex
from relik.retriever.pytorch_modules.hf import GoldenRetrieverModel
from relik.retriever.pytorch_modules.model import GoldenRetriever

logger = get_logger(__name__)
console_logger = get_console_logger()


class Relik:
    """
    Relik main class. It is a wrapper around a retriever and a reader.

    Args:
        retriever (`Optional[GoldenRetriever]`, `optional`):
            The retriever to use. If `None`, a retriever will be instantiated from the
            provided `question_encoder`, `passage_encoder` and `document_index`.
            Defaults to `None`.
        question_encoder (`Optional[Union[str, GoldenRetrieverModel]]`, `optional`):
            The question encoder to use. If `retriever` is `None`, a retriever will be
            instantiated from this parameter. Defaults to `None`.
        passage_encoder (`Optional[Union[str, GoldenRetrieverModel]]`, `optional`):
            The passage encoder to use. If `retriever` is `None`, a retriever will be
            instantiated from this parameter. Defaults to `None`.
        document_index (`Optional[Union[str, BaseDocumentIndex]]`, `optional`):
            The document index to use. If `retriever` is `None`, a retriever will be
            instantiated from this parameter. Defaults to `None`.
        reader (`Optional[Union[str, RelikReader]]`, `optional`):
            The reader to use. If a string is provided, a `RelikReaderForSpanExtraction`
            will be instantiated from it. Defaults to `None`.
        retriever_device (`str`, `optional`):
            The device to use for the retriever. Defaults to the value of `device`.
    """

    def __init__(
        self,
        retriever: GoldenRetriever | None = None,
        question_encoder: str | GoldenRetrieverModel | None = None,
        passage_encoder: str | GoldenRetrieverModel | None = None,
        document_index: str | BaseDocumentIndex | None = None,
        reader: str | RelikReader | None = None,
        device: str = "cpu",
        retriever_device: str | None = None,
        document_index_device: str | None = None,
        reader_device: str | None = None,
        precision: int = 32,
        retriever_precision: int | None = None,
        document_index_precision: int | None = None,
        reader_precision: int | None = None,
        reader_kwargs: dict | None = None,
        retriever_kwargs: dict | None = None,
        candidates_preprocessing_fn: str | Callable | None = None,
        top_k: int | None = None,
        window_size: int | None = None,
        window_stride: int | None = None,
        **kwargs,
    ) -> None:
        # retriever
        retriever_device = retriever_device or device
        document_index_device = document_index_device or device
        retriever_precision = retriever_precision or precision
        document_index_precision = document_index_precision or precision
        if retriever is None and question_encoder is None:
            raise ValueError(
                "Either `retriever` or `question_encoder` must be provided"
            )
        if retriever is None:
            self.retriever_kwargs = dict(
                question_encoder=question_encoder,
                passage_encoder=passage_encoder,
                document_index=document_index,
                device=retriever_device,
                precision=retriever_precision,
                index_device=document_index_device,
                index_precision=document_index_precision,
            )
            # overwrite the default retriever kwargs with the provided retriever_kwargs
            self.retriever_kwargs.update(retriever_kwargs or {})
            retriever = GoldenRetriever(**self.retriever_kwargs)
            retriever.training = False
            retriever.eval()
        self.retriever = retriever

        # reader
        self.reader_device = reader_device or device
        self.reader_precision = reader_precision or precision
        self.reader_kwargs = reader_kwargs
        if isinstance(reader, str):
            reader_kwargs = reader_kwargs or {}
            reader = RelikReaderForSpanExtraction(reader, **reader_kwargs)
        self.reader = reader

        # windowization stuff
        self.tokenizer = SpacyTokenizer(language="en")
        self.window_manager: WindowManager | None = None

        # candidates preprocessing
        # TODO: maybe move this logic somewhere else
        candidates_preprocessing_fn = candidates_preprocessing_fn or (lambda x: x)
        if isinstance(candidates_preprocessing_fn, str):
            candidates_preprocessing_fn = get_callable_from_string(
                candidates_preprocessing_fn
            )
        self.candidates_preprocessing_fn = candidates_preprocessing_fn

        # inference params
        self.top_k = top_k
        self.window_size = window_size
        self.window_stride = window_stride

    def __call__(
        self,
        text: Union[str, list],
        top_k: Optional[int] = None,
        window_size: Optional[int] = None,
        window_stride: Optional[int] = None,
        retriever_batch_size: Optional[int] = 32,
        reader_batch_size: Optional[int] = 32,
        return_also_windows: bool = False,
        **kwargs,
    ) -> Union[RelikOutput, list[RelikOutput]]:
        """
        Annotate a text with entities.

        Args:
            text (`str` or `list`):
                The text to annotate. If a list is provided, each element of the list
                will be annotated separately.
            top_k (`int`, `optional`, defaults to `None`):
                The number of candidates to retrieve for each window.
            window_size (`int`, `optional`, defaults to `None`):
                The size of the window. If `None`, the whole text will be annotated.
            window_stride (`int`, `optional`, defaults to `None`):
                The stride of the window. If `None`, there will be no overlap between windows.
            retriever_batch_size (`int`, `optional`, defaults to `32`):
                The batch size to use for the retriever.
            reader_batch_size (`int`, `optional`, defaults to `32`):
                The batch size to use for the reader.
            return_also_windows (`bool`, `optional`, defaults to `False`):
                Whether to return the windows in the output.
            **kwargs:
                Additional keyword arguments to pass to the retriever and the reader.

        Returns:
            `RelikOutput` or `list[RelikOutput]`:
                The annotated text. If a list was provided as input, a list of
                `RelikOutput` objects will be returned.
        """
        if top_k is None:
            top_k = self.top_k or 100
        if window_size is None:
            window_size = self.window_size
        if window_stride is None:
            window_stride = self.window_stride

        if isinstance(text, str):
            text = [text]

        if window_size is not None:
            if self.window_manager is None:
                self.window_manager = WindowManager(self.tokenizer)

            if window_size == "sentence":
                # todo: implement sentence windowizer
                raise NotImplementedError("Sentence windowizer not implemented yet")

            # if window_size < window_stride:
            #     raise ValueError(
            #         f"Window size ({window_size}) must be greater than window stride ({window_stride})"
            #     )

        # window generator
        windows = [
            window
            for doc_id, t in enumerate(text)
            for window in self.window_manager.create_windows(
                t,
                window_size=window_size,
                stride=window_stride,
                doc_id=doc_id,
            )
        ]

        # retrieve candidates first
        windows_candidates = []
        # TODO: move batching inside the retriever
        for batch in batch_generator(windows, batch_size=retriever_batch_size):
            retriever_out = self.retriever.retrieve([b.text for b in batch], k=top_k)
            windows_candidates.extend(
                [[p.label for p in predictions] for predictions in retriever_out]
            )

        # add the retrieved passages to the windows
        for window, candidates in zip(windows, windows_candidates):
            window.window_candidates = [
                self.candidates_preprocessing_fn(c) for c in candidates
            ]

        windows = self.reader.read(samples=windows, max_batch_size=reader_batch_size)
        windows = self.window_manager.merge_windows(windows)

        # transform predictions into RelikOutput objects
        output = []
        for w in windows:
            sample_output = RelikOutput(
                text=text[w.doc_id],
                labels=sorted(
                    [
                        EntitySpan(
                            start=ss, end=se, label=sl, text=text[w.doc_id][ss:se]
                        )
                        for ss, se, sl in w.predicted_window_labels_chars
                    ],
                    key=lambda x: x.start,
                ),
            )
            output.append(sample_output)

        if return_also_windows:
            for i, sample_output in enumerate(output):
                sample_output.windows = [w for w in windows if w.doc_id == i]

        # if only one text was provided, return a single RelikOutput object
        if len(output) == 1:
            return output[0]

        return output

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_dir: Union[str, os.PathLike],
        config_kwargs: Optional[Dict] = None,
        config_file_name: str = CONFIG_NAME,
        *args,
        **kwargs,
    ) -> "Relik":
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)

        model_dir = from_cache(
            model_name_or_dir,
            filenames=[config_file_name],
            cache_dir=cache_dir,
            force_download=force_download,
        )

        config_path = model_dir / config_file_name
        if not config_path.exists():
            raise FileNotFoundError(
                f"Model configuration file not found at {config_path}."
            )

        # overwrite config with config_kwargs
        config = OmegaConf.load(config_path)
        if config_kwargs is not None:
            # TODO: check merging behavior
            config = OmegaConf.merge(config, OmegaConf.create(config_kwargs))
        # do we want to print the config? I like it
        pprint(OmegaConf.to_container(config), console=console_logger, expand_all=True)

        # load relik from config
        relik = hydra.utils.instantiate(config, *args, **kwargs)

        return relik

    def save_pretrained(
        self,
        output_dir: Union[str, os.PathLike],
        config: Optional[Dict[str, Any]] = None,
        config_file_name: Optional[str] = None,
        save_weights: bool = False,
        push_to_hub: bool = False,
        model_id: Optional[str] = None,
        organization: Optional[str] = None,
        repo_name: Optional[str] = None,
        **kwargs,
    ):
        """
        Save the configuration of Relik to the specified directory as a YAML file.

        Args:
            output_dir (`str`):
                The directory to save the configuration file to.
            config (`Optional[Dict[str, Any]]`, `optional`):
                The configuration to save. If `None`, the current configuration will be
                saved. Defaults to `None`.
            config_file_name (`Optional[str]`, `optional`):
                The name of the configuration file. Defaults to `config.yaml`.
            save_weights (`bool`, `optional`):
                Whether to save the weights of the model. Defaults to `False`.
            push_to_hub (`bool`, `optional`):
                Whether to push the saved model to the hub. Defaults to `False`.
            model_id (`Optional[str]`, `optional`):
                The id of the model to push to the hub. If `None`, the name of the
                directory will be used. Defaults to `None`.
            organization (`Optional[str]`, `optional`):
                The organization to push the model to. Defaults to `None`.
            repo_name (`Optional[str]`, `optional`):
                The name of the repository to push the model to. Defaults to `None`.
            **kwargs:
                Additional keyword arguments to pass to `OmegaConf.save`.
        """
        if config is None:
            # create a default config
            config = {
                "_target_": f"{self.__class__.__module__}.{self.__class__.__name__}"
            }
            if self.retriever is not None:
                if self.retriever.question_encoder is not None:
                    config["question_encoder"] = self.retriever.question_encoder.name_or_path
                if self.retriever.passage_encoder is not None:
                    config["passage_encoder"] = self.retriever.passage_encoder.name_or_path
                if self.retriever.document_index is not None:
                    config["document_index"] = self.retriever.document_index.name_or_dir
            if self.reader is not None:
                config["reader"] = self.reader.model_path

            config["retriever_kwargs"] = self.retriever_kwargs
            config["reader_kwargs"] = self.reader_kwargs
            # store the fn as a dotted path so it can be saved and loaded later
            config["candidates_preprocessing_fn"] = (
                f"{self.candidates_preprocessing_fn.__module__}.{self.candidates_preprocessing_fn.__name__}"
            )

            # these are model-specific and should be saved
            config["top_k"] = self.top_k
            config["window_size"] = self.window_size
            config["window_stride"] = self.window_stride

        config_file_name = config_file_name or CONFIG_NAME

        # create the output directory
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        logger.info(f"Saving relik config to {output_dir / config_file_name}")
        # pretty print the config
        pprint(config, console=console_logger, expand_all=True)
        OmegaConf.save(config, output_dir / config_file_name)

        if save_weights:
            model_id = model_id or output_dir.name
            retriever_model_id = model_id + "-retriever"
            # save weights
            logger.info(f"Saving retriever to {output_dir / retriever_model_id}")
            self.retriever.save_pretrained(
                output_dir / retriever_model_id,
                question_encoder_name=retriever_model_id + "-question-encoder",
                passage_encoder_name=retriever_model_id + "-passage-encoder",
                document_index_name=retriever_model_id + "-index",
                push_to_hub=push_to_hub,
                organization=organization,
                repo_name=repo_name,
                **kwargs,
            )
            reader_model_id = model_id + "-reader"
            logger.info(f"Saving reader to {output_dir / reader_model_id}")
            self.reader.save_pretrained(
                output_dir / reader_model_id,
                push_to_hub=push_to_hub,
                organization=organization,
                repo_name=repo_name,
                **kwargs,
            )

        if push_to_hub:
            # push to hub
            logger.info("Pushing to hub")
            model_id = model_id or output_dir.name
            upload(output_dir, model_id, organization=organization, repo_name=repo_name)


def main():
    from pprint import pprint

    relik = Relik(
        question_encoder="riccorl/relik-retriever-aida-blink-pretrain-omniencoder",
        document_index="riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder",
        reader="riccorl/relik-reader-aida-deberta-small",
        device="cuda",
        precision=16,
        top_k=100,
        window_size=32,
        window_stride=16,
        candidates_preprocessing_fn="relik.inference.preprocessing.wikipedia_title_and_openings_preprocessing",
    )

    input_text = """
    Bernie Ecclestone, the former boss of Formula One, has admitted fraud after failing to declare more than £400m held in a trust in Singapore.
    The 92-year-old billionaire did not disclose the trust to the government in July 2015.
    Appearing at Southwark Crown Court on Thursday, he told the judge "I plead guilty" after having previously pleaded not guilty.
    Ecclestone had been due to go on trial next month.
    """

    preds = relik(input_text)
    pprint(preds)


if __name__ == "__main__":
    main()
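As a usage note, the YAML written by `save_pretrained` is exactly what `from_pretrained` later feeds to `hydra.utils.instantiate`, so a config-only round trip looks like the sketch below (model ids taken from `main()` above; the output directory name is illustrative, and reloading assumes the component checkpoints named in the config are still reachable):

# Build a Relik from component names, save only its YAML config,
# then re-instantiate an equivalent object from that directory.
relik = Relik(
    question_encoder="riccorl/relik-retriever-aida-blink-pretrain-omniencoder",
    document_index="riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder",
    reader="riccorl/relik-reader-aida-deberta-small",
)
relik.save_pretrained("my-relik-model")  # writes my-relik-model/config.yaml
relik_reloaded = Relik.from_pretrained("my-relik-model")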
relik/inference/data/__init__.py
ADDED
File without changes
relik/inference/data/objects.py
ADDED
@@ -0,0 +1,64 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import List, NamedTuple, Optional

from relik.reader.pytorch_modules.hf.modeling_relik import RelikReaderSample


@dataclass
class Word:
    """
    A word representation that includes text, index in the sentence, POS tag, lemma,
    dependency relation, and similar information.

    # Parameters
    text : `str`, optional
        The text representation.
    index : `int`, optional
        The word offset in the sentence.
    lemma : `str`, optional
        The lemma of this word.
    pos : `str`, optional
        The coarse-grained part of speech of this word.
    dep : `str`, optional
        The dependency relation for this word.

    input_id : `int`, optional
        Integer representation of the word, used to pass it to a model.
    token_type_id : `int`, optional
        Token type id used by some transformers.
    attention_mask: `int`, optional
        Attention mask used by transformers, indicates to the model which tokens should
        be attended to, and which should not.
    """

    text: str
    index: int
    start_char: Optional[int] = None
    end_char: Optional[int] = None
    # preprocessing fields
    lemma: Optional[str] = None
    pos: Optional[str] = None
    dep: Optional[str] = None
    head: Optional[int] = None

    def __str__(self):
        return self.text

    def __repr__(self):
        return self.__str__()


class EntitySpan(NamedTuple):
    start: int
    end: int
    label: str
    text: str


@dataclass
class RelikOutput:
    text: str
    labels: List[EntitySpan]
    windows: Optional[List[RelikReaderSample]] = None
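To make the shape of these containers concrete, here is a small hand-built example (the spans are made up rather than model output):

text = "Mary sold the car to John."
output = RelikOutput(
    text=text,
    labels=[
        EntitySpan(start=0, end=4, label="PER", text=text[0:4]),      # "Mary"
        EntitySpan(start=21, end=25, label="PER", text=text[21:25]),  # "John"
    ],
)
# EntitySpan is a NamedTuple, so fields are accessible by name or position:
assert output.labels[0].label == "PER"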
relik/inference/data/tokenizers/__init__.py
ADDED
@@ -0,0 +1,89 @@
SPACY_LANGUAGE_MAPPER = {
    "ca": "ca_core_news_sm",
    "da": "da_core_news_sm",
    "de": "de_core_news_sm",
    "el": "el_core_news_sm",
    "en": "en_core_web_sm",
    "es": "es_core_news_sm",
    "fr": "fr_core_news_sm",
    "it": "it_core_news_sm",
    "ja": "ja_core_news_sm",
    "lt": "lt_core_news_sm",
    "mk": "mk_core_news_sm",
    "nb": "nb_core_news_sm",
    "nl": "nl_core_news_sm",
    "pl": "pl_core_news_sm",
    "pt": "pt_core_news_sm",
    "ro": "ro_core_news_sm",
    "ru": "ru_core_news_sm",
    "xx": "xx_sent_ud_sm",
    "zh": "zh_core_web_sm",
    "ca_core_news_sm": "ca_core_news_sm",
    "ca_core_news_md": "ca_core_news_md",
    "ca_core_news_lg": "ca_core_news_lg",
    "ca_core_news_trf": "ca_core_news_trf",
    "da_core_news_sm": "da_core_news_sm",
    "da_core_news_md": "da_core_news_md",
    "da_core_news_lg": "da_core_news_lg",
    "da_core_news_trf": "da_core_news_trf",
    "de_core_news_sm": "de_core_news_sm",
    "de_core_news_md": "de_core_news_md",
    "de_core_news_lg": "de_core_news_lg",
    "de_dep_news_trf": "de_dep_news_trf",
    "el_core_news_sm": "el_core_news_sm",
    "el_core_news_md": "el_core_news_md",
    "el_core_news_lg": "el_core_news_lg",
    "en_core_web_sm": "en_core_web_sm",
    "en_core_web_md": "en_core_web_md",
    "en_core_web_lg": "en_core_web_lg",
    "en_core_web_trf": "en_core_web_trf",
    "es_core_news_sm": "es_core_news_sm",
    "es_core_news_md": "es_core_news_md",
    "es_core_news_lg": "es_core_news_lg",
    "es_dep_news_trf": "es_dep_news_trf",
    "fr_core_news_sm": "fr_core_news_sm",
    "fr_core_news_md": "fr_core_news_md",
    "fr_core_news_lg": "fr_core_news_lg",
    "fr_dep_news_trf": "fr_dep_news_trf",
    "it_core_news_sm": "it_core_news_sm",
    "it_core_news_md": "it_core_news_md",
    "it_core_news_lg": "it_core_news_lg",
    "ja_core_news_sm": "ja_core_news_sm",
    "ja_core_news_md": "ja_core_news_md",
    "ja_core_news_lg": "ja_core_news_lg",
    "ja_dep_news_trf": "ja_dep_news_trf",
    "lt_core_news_sm": "lt_core_news_sm",
    "lt_core_news_md": "lt_core_news_md",
    "lt_core_news_lg": "lt_core_news_lg",
    "mk_core_news_sm": "mk_core_news_sm",
    "mk_core_news_md": "mk_core_news_md",
    "mk_core_news_lg": "mk_core_news_lg",
    "nb_core_news_sm": "nb_core_news_sm",
    "nb_core_news_md": "nb_core_news_md",
    "nb_core_news_lg": "nb_core_news_lg",
    "nl_core_news_sm": "nl_core_news_sm",
    "nl_core_news_md": "nl_core_news_md",
    "nl_core_news_lg": "nl_core_news_lg",
    "pl_core_news_sm": "pl_core_news_sm",
    "pl_core_news_md": "pl_core_news_md",
    "pl_core_news_lg": "pl_core_news_lg",
    "pt_core_news_sm": "pt_core_news_sm",
    "pt_core_news_md": "pt_core_news_md",
    "pt_core_news_lg": "pt_core_news_lg",
    "ro_core_news_sm": "ro_core_news_sm",
    "ro_core_news_md": "ro_core_news_md",
    "ro_core_news_lg": "ro_core_news_lg",
    "ru_core_news_sm": "ru_core_news_sm",
    "ru_core_news_md": "ru_core_news_md",
    "ru_core_news_lg": "ru_core_news_lg",
    "xx_ent_wiki_sm": "xx_ent_wiki_sm",
    "xx_sent_ud_sm": "xx_sent_ud_sm",
    "zh_core_web_sm": "zh_core_web_sm",
    "zh_core_web_md": "zh_core_web_md",
    "zh_core_web_lg": "zh_core_web_lg",
    "zh_core_web_trf": "zh_core_web_trf",
}

from relik.inference.data.tokenizers.regex_tokenizer import RegexTokenizer
from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer
from relik.inference.data.tokenizers.whitespace_tokenizer import WhitespaceTokenizer
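For reference, the mapper accepts either a two-letter shorthand or a full spaCy pipeline name, both resolving to a concrete model package:

SPACY_LANGUAGE_MAPPER["en"]               # -> "en_core_web_sm"
SPACY_LANGUAGE_MAPPER["en_core_web_trf"]  # -> "en_core_web_trf" (full names map to themselves)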
relik/inference/data/tokenizers/base_tokenizer.py
ADDED
@@ -0,0 +1,84 @@
from typing import List, Union

from relik.inference.data.objects import Word


class BaseTokenizer:
    """
    A :obj:`Tokenizer` splits strings of text into single words, optionally adds
    POS tags and performs lemmatization.
    """

    def __call__(
        self,
        texts: Union[str, List[str], List[List[str]]],
        is_split_into_words: bool = False,
        **kwargs
    ) -> List[List[Word]]:
        """
        Tokenize the input into single words.

        Args:
            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
                Text to tag. It can be a single string, a batch of strings, or pre-tokenized strings.
            is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
                If :obj:`True` and the input is a string, the input is split on spaces.

        Returns:
            :obj:`List[List[Word]]`: The input text tokenized in single words.
        """
        raise NotImplementedError

    def tokenize(self, text: str) -> List[Word]:
        """
        Implements splitting words into tokens.

        Args:
            text (:obj:`str`):
                Text to tokenize.

        Returns:
            :obj:`List[Word]`: The input text tokenized in single words.
        """
        raise NotImplementedError

    def tokenize_batch(self, texts: List[str]) -> List[List[Word]]:
        """
        Implements batch splitting words into tokens.

        Args:
            texts (:obj:`List[str]`):
                Batch of text to tokenize.

        Returns:
            :obj:`List[List[Word]]`: The input batch tokenized in single words.
        """
        return [self.tokenize(text) for text in texts]

    @staticmethod
    def check_is_batched(
        texts: Union[str, List[str], List[List[str]]], is_split_into_words: bool
    ):
        """
        Check if the input is batched or a single sample.

        Args:
            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
                Text to check.
            is_split_into_words (:obj:`bool`):
                If :obj:`True` and the input is a string, the input is split on spaces.

        Returns:
            :obj:`bool`: ``True`` if ``texts`` is batched, ``False`` otherwise.
        """
        return bool(
            (not is_split_into_words and isinstance(texts, (list, tuple)))
            or (
                is_split_into_words
                and isinstance(texts, (list, tuple))
                and texts
                and isinstance(texts[0], (list, tuple))
            )
        )
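Concrete tokenizers only need to implement `tokenize` (the base class derives `tokenize_batch` from it). As an illustration only, not part of the package, a minimal subclass might look like:

class CommaTokenizer(BaseTokenizer):
    """Toy tokenizer that splits on commas (hypothetical, for illustration)."""

    def tokenize(self, text: str) -> List[Word]:
        words = []
        offset = 0
        for i, chunk in enumerate(text.split(",")):
            # record character offsets so tokens can be mapped back to the input
            words.append(Word(chunk, i, start_char=offset, end_char=offset + len(chunk)))
            offset += len(chunk) + 1  # skip past the comma itself
        return words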
relik/inference/data/tokenizers/regex_tokenizer.py
ADDED
@@ -0,0 +1,73 @@
import re
from typing import List, Union

from overrides import overrides

from relik.inference.data.objects import Word
from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer


class RegexTokenizer(BaseTokenizer):
    """
    A :obj:`Tokenizer` that splits the text based on a simple regex.
    """

    def __init__(self):
        super(RegexTokenizer, self).__init__()
        # regex for splitting on spaces, punctuation and new lines
        # self._regex = re.compile(r"\S+|[\[\](),.!?;:\"]|\\n")
        self._regex = re.compile(
            r"\w+|\$[\d\.]+|\S+", re.UNICODE | re.MULTILINE | re.DOTALL
        )

    def __call__(
        self,
        texts: Union[str, List[str], List[List[str]]],
        is_split_into_words: bool = False,
        **kwargs,
    ) -> List[List[Word]]:
        """
        Tokenize the input into single words by splitting using a simple regex.

        Args:
            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
                Text to tag. It can be a single string, a batch of strings, or pre-tokenized strings.
            is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
                If :obj:`True` and the input is a string, the input is split on spaces.

        Returns:
            :obj:`List[List[Word]]`: The input text tokenized in single words.

        Example::

            >>> from relik.inference.data.tokenizers.regex_tokenizer import RegexTokenizer

            >>> regex_tokenizer = RegexTokenizer()
            >>> regex_tokenizer("Mary sold the car to John.")

        """
        # check if input is batched or a single sample
        is_batched = self.check_is_batched(texts, is_split_into_words)

        if is_batched:
            tokenized = self.tokenize_batch(texts)
        else:
            tokenized = self.tokenize(texts)

        return tokenized

    @overrides
    def tokenize(self, text: Union[str, List[str]]) -> List[Word]:
        if not isinstance(text, (str, list)):
            raise ValueError(
                f"text must be either `str` or `list`, found: `{type(text)}`"
            )

        if isinstance(text, list):
            text = " ".join(text)
        return [
            Word(t[0], i, start_char=t[1], end_char=t[2])
            for i, t in enumerate(
                (m.group(0), m.start(), m.end()) for m in self._regex.finditer(text)
            )
        ]
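Since each `Word` carries `start_char`/`end_char` straight from `re.finditer`, tokens can always be sliced back out of the original string; a quick sanity check of that invariant:

regex_tokenizer = RegexTokenizer()
words = regex_tokenizer.tokenize("Mary sold the car to John.")
for w in words:
    # the character offsets always slice back to the token text
    assert "Mary sold the car to John."[w.start_char:w.end_char] == w.text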
relik/inference/data/tokenizers/spacy_tokenizer.py
ADDED
@@ -0,0 +1,228 @@
import logging
from typing import Dict, List, Tuple, Union

import spacy

# from ipa.common.utils import load_spacy
from overrides import overrides
from spacy.cli.download import download as spacy_download
from spacy.tokens import Doc

from relik.common.log import get_logger
from relik.inference.data.objects import Word
from relik.inference.data.tokenizers import SPACY_LANGUAGE_MAPPER
from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer

logger = get_logger(level=logging.DEBUG)

# spaCy and Stanza stuff

LOADED_SPACY_MODELS: Dict[Tuple[str, bool, bool, bool, bool], spacy.Language] = {}


def load_spacy(
    language: str,
    pos_tags: bool = False,
    lemma: bool = False,
    parse: bool = False,
    split_on_spaces: bool = False,
) -> spacy.Language:
    """
    Download and load a spaCy model.

    Args:
        language (:obj:`str`, defaults to :obj:`en`):
            Language of the text to tokenize.
        pos_tags (:obj:`bool`, optional, defaults to :obj:`False`):
            If :obj:`True`, performs POS tagging with the spaCy model.
        lemma (:obj:`bool`, optional, defaults to :obj:`False`):
            If :obj:`True`, performs lemmatization with the spaCy model.
        parse (:obj:`bool`, optional, defaults to :obj:`False`):
            If :obj:`True`, performs dependency parsing with the spaCy model.
        split_on_spaces (:obj:`bool`, optional, defaults to :obj:`False`):
            If :obj:`True`, will split by spaces without performing tokenization.

    Returns:
        :obj:`spacy.Language`: The spaCy model loaded.
    """
    exclude = ["vectors", "textcat", "ner"]
    if not pos_tags:
        exclude.append("tagger")
    if not lemma:
        exclude.append("lemmatizer")
    if not parse:
        exclude.append("parser")

    # check if the model is already loaded;
    # if so, there is no need to reload it
    spacy_params = (language, pos_tags, lemma, parse, split_on_spaces)
    if spacy_params not in LOADED_SPACY_MODELS:
        try:
            spacy_tagger = spacy.load(language, exclude=exclude)
        except OSError:
            logger.warning(
                "Spacy model '%s' not found. Downloading and installing.", language
            )
            spacy_download(language)
            spacy_tagger = spacy.load(language, exclude=exclude)

        # if everything is disabled, return only the tokenizer
        # for faster tokenization
        # TODO: is it really faster?
        # if len(exclude) >= 6:
        #     spacy_tagger = spacy_tagger.tokenizer
        LOADED_SPACY_MODELS[spacy_params] = spacy_tagger

    return LOADED_SPACY_MODELS[spacy_params]


class SpacyTokenizer(BaseTokenizer):
    """
    A :obj:`Tokenizer` that uses spaCy to tokenize and preprocess the text. It returns :obj:`Word` objects.

    Args:
        language (:obj:`str`, optional, defaults to :obj:`en`):
            Language of the text to tokenize.
        return_pos_tags (:obj:`bool`, optional, defaults to :obj:`False`):
            If :obj:`True`, performs POS tagging with the spaCy model.
        return_lemmas (:obj:`bool`, optional, defaults to :obj:`False`):
            If :obj:`True`, performs lemmatization with the spaCy model.
        return_deps (:obj:`bool`, optional, defaults to :obj:`False`):
            If :obj:`True`, performs dependency parsing with the spaCy model.
        split_on_spaces (:obj:`bool`, optional, defaults to :obj:`False`):
            If :obj:`True`, will split by spaces without performing tokenization.
        use_gpu (:obj:`bool`, optional, defaults to :obj:`False`):
            If :obj:`True`, will load the spaCy model on GPU.
    """

    def __init__(
        self,
        language: str = "en",
        return_pos_tags: bool = False,
        return_lemmas: bool = False,
        return_deps: bool = False,
        split_on_spaces: bool = False,
        use_gpu: bool = False,
    ):
        super(SpacyTokenizer, self).__init__()
        if language not in SPACY_LANGUAGE_MAPPER:
            raise ValueError(
                f"`{language}` language not supported. The supported "
                f"languages are: {list(SPACY_LANGUAGE_MAPPER.keys())}."
            )
        if use_gpu:
            # load the model on GPU;
            # if the GPU is not available or not correctly configured,
            # it will raise an error
            spacy.require_gpu()
        self.spacy = load_spacy(
            SPACY_LANGUAGE_MAPPER[language],
            return_pos_tags,
            return_lemmas,
            return_deps,
            split_on_spaces,
        )
        self.split_on_spaces = split_on_spaces

    def __call__(
        self,
        texts: Union[str, List[str], List[List[str]]],
        is_split_into_words: bool = False,
        **kwargs,
    ) -> Union[List[Word], List[List[Word]]]:
        """
        Tokenize the input into single words using spaCy models.

        Args:
            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
                Text to tag. It can be a single string, a batch of strings, or pre-tokenized strings.
            is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
                If :obj:`True` and the input is a string, the input is split on spaces.

        Returns:
            :obj:`List[List[Word]]`: The input text tokenized in single words.

        Example::

            >>> from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer

            >>> spacy_tokenizer = SpacyTokenizer(language="en", return_pos_tags=True, return_lemmas=True)
            >>> spacy_tokenizer("Mary sold the car to John.")

        """
        # check if input is batched or a single sample
        is_batched = self.check_is_batched(texts, is_split_into_words)
        if is_batched:
            tokenized = self.tokenize_batch(texts)
        else:
            tokenized = self.tokenize(texts)
        return tokenized

    @overrides
    def tokenize(self, text: Union[str, List[str]]) -> List[Word]:
        if self.split_on_spaces:
            if isinstance(text, str):
                text = text.split(" ")
            spaces = [True] * len(text)
            text = Doc(self.spacy.vocab, words=text, spaces=spaces)
        return self._clean_tokens(self.spacy(text))

    @overrides
    def tokenize_batch(
        self, texts: Union[List[str], List[List[str]]]
    ) -> List[List[Word]]:
        if self.split_on_spaces:
            if isinstance(texts[0], str):
                texts = [text.split(" ") for text in texts]
            spaces = [[True] * len(text) for text in texts]
            texts = [
                Doc(self.spacy.vocab, words=text, spaces=space)
                for text, space in zip(texts, spaces)
            ]
        return [self._clean_tokens(tokens) for tokens in self.spacy.pipe(texts)]

    @staticmethod
    def _clean_tokens(tokens: Doc) -> List[Word]:
        """
        Converts spaCy tokens to :obj:`Word`.

        Args:
            tokens (:obj:`spacy.tokens.Doc`):
                Tokens from a spaCy model.

        Returns:
            :obj:`List[Word]`: The spaCy model output converted into :obj:`Word` objects.
        """
        words = [
            Word(
                token.text,
                token.i,
                token.idx,
                token.idx + len(token),
                token.lemma_,
                token.pos_,
                token.dep_,
                token.head.i,
            )
            for token in tokens
        ]
        return words


class WhitespaceSpacyTokenizer:
    """Simple whitespace tokenizer for spaCy."""

    def __init__(self, vocab):
        self.vocab = vocab

    def __call__(self, text):
        if isinstance(text, str):
            words = text.split(" ")
        elif isinstance(text, list):
            words = text
        else:
            raise ValueError(
                f"text must be either `str` or `list`, found: `{type(text)}`"
            )
        spaces = [True] * len(words)
        return Doc(self.vocab, words=words, spaces=spaces)
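`WhitespaceSpacyTokenizer` follows spaCy's custom-tokenizer convention (a callable that takes the raw text and returns a `Doc`), so it can be dropped into an existing pipeline when the input is already space-separated. A sketch, assuming `en_core_web_sm` is installed:

import spacy

nlp = spacy.load("en_core_web_sm")
# replace spaCy's built-in tokenizer with the whitespace one defined above
nlp.tokenizer = WhitespaceSpacyTokenizer(nlp.vocab)
doc = nlp("Mary sold the car to John .")
print([t.text for t in doc])  # tokens are exactly the space-separated chunks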
relik/inference/data/tokenizers/whitespace_tokenizer.py
ADDED
@@ -0,0 +1,70 @@
import re
from typing import List, Union

from overrides import overrides

from relik.inference.data.objects import Word
from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer


class WhitespaceTokenizer(BaseTokenizer):
    """
    A :obj:`Tokenizer` that splits the text on spaces.
    """

    def __init__(self):
        super(WhitespaceTokenizer, self).__init__()
        self.whitespace_regex = re.compile(r"\S+")

    def __call__(
        self,
        texts: Union[str, List[str], List[List[str]]],
        is_split_into_words: bool = False,
        **kwargs,
    ) -> List[List[Word]]:
        """
        Tokenize the input into single words by splitting on spaces.

        Args:
            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
                Text to tag. It can be a single string, a batch of strings, or pre-tokenized strings.
            is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
                If :obj:`True` and the input is a string, the input is split on spaces.

        Returns:
            :obj:`List[List[Word]]`: The input text tokenized in single words.

        Example::

            >>> from relik.inference.data.tokenizers import WhitespaceTokenizer

            >>> whitespace_tokenizer = WhitespaceTokenizer()
            >>> whitespace_tokenizer("Mary sold the car to John .")

        """
        # check if input is batched or a single sample
        is_batched = self.check_is_batched(texts, is_split_into_words)

        if is_batched:
            tokenized = self.tokenize_batch(texts)
        else:
            tokenized = self.tokenize(texts)

        return tokenized

    @overrides
    def tokenize(self, text: Union[str, List[str]]) -> List[Word]:
        if not isinstance(text, (str, list)):
            raise ValueError(
                f"text must be either `str` or `list`, found: `{type(text)}`"
            )

        if isinstance(text, list):
            text = " ".join(text)
        return [
            Word(t[0], i, start_char=t[1], end_char=t[2])
            for i, t in enumerate(
                (m.group(0), m.start(), m.end())
                for m in self.whitespace_regex.finditer(text)
            )
        ]
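Because `tokenize` matches `\S+` against the (re)joined string, character offsets stay faithful even across runs of whitespace. A quick sketch:

from relik.inference.data.tokenizers import WhitespaceTokenizer

whitespace_tokenizer = WhitespaceTokenizer()
# note the double space between "sold" and "the"
for word in whitespace_tokenizer.tokenize("Mary sold  the car"):
    print(word.text, word.start_char, word.end_char)
# Mary 0 4, sold 5 9, the 11 14, car 15 18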
relik/inference/data/window/__init__.py
ADDED
File without changes
relik/inference/data/window/manager.py
ADDED
@@ -0,0 +1,262 @@
import collections
import itertools
from dataclasses import dataclass
from typing import List, Optional, Set, Tuple

from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer
from relik.reader.data.relik_reader_sample import RelikReaderSample


@dataclass
class Window:
    doc_id: int
    window_id: int
    text: str
    tokens: List[str]
    doc_topic: Optional[str]
    offset: int
    token2char_start: dict
    token2char_end: dict
    window_candidates: Optional[List[str]] = None


class WindowManager:
    def __init__(self, tokenizer: BaseTokenizer) -> None:
        self.tokenizer = tokenizer

    def tokenize(self, document: str) -> Tuple[List[str], List[Tuple[int, int]]]:
        tokenized_document = self.tokenizer(document)
        tokens = []
        tokens_char_mapping = []
        for token in tokenized_document:
            tokens.append(token.text)
            tokens_char_mapping.append((token.start_char, token.end_char))
        return tokens, tokens_char_mapping

    def create_windows(
        self,
        document: str,
        window_size: int,
        stride: int,
        doc_id: int = 0,
        doc_topic: Optional[str] = None,
    ) -> List[RelikReaderSample]:
        document_tokens, tokens_char_mapping = self.tokenize(document)
        if doc_topic is None:
            doc_topic = document_tokens[0] if len(document_tokens) > 0 else ""
        document_windows = []
        if len(document_tokens) <= window_size:
            text = document
            document_windows.append(
                RelikReaderSample(
                    doc_id=doc_id,
                    window_id=0,
                    text=text,
                    tokens=document_tokens,
                    doc_topic=doc_topic,
                    offset=0,
                    token2char_start={
                        str(i): tokens_char_mapping[i][0]
                        for i in range(len(document_tokens))
                    },
                    token2char_end={
                        str(i): tokens_char_mapping[i][1]
                        for i in range(len(document_tokens))
                    },
                )
            )
        else:
            for window_id, i in enumerate(range(0, len(document_tokens), stride)):
                # if the last stride is smaller than the window size, then we can
                # include more tokens from the previous window.
                if i != 0 and i + window_size > len(document_tokens):
                    overflowing_tokens = i + window_size - len(document_tokens)
                    if overflowing_tokens >= stride:
                        break
                    i -= overflowing_tokens

                involved_token_indices = list(
                    range(i, min(i + window_size, len(document_tokens)))
                )
                window_tokens = [document_tokens[j] for j in involved_token_indices]
                window_text_start = tokens_char_mapping[involved_token_indices[0]][0]
                window_text_end = tokens_char_mapping[involved_token_indices[-1]][1]
                text = document[window_text_start:window_text_end]
                document_windows.append(
                    RelikReaderSample(
                        doc_id=doc_id,
                        window_id=window_id,
                        text=text,
                        tokens=window_tokens,
                        doc_topic=doc_topic,
                        offset=window_text_start,
                        token2char_start={
                            str(i): tokens_char_mapping[ti][0]
                            for i, ti in enumerate(involved_token_indices)
                        },
                        token2char_end={
                            str(i): tokens_char_mapping[ti][1]
                            for i, ti in enumerate(involved_token_indices)
                        },
                    )
                )
        return document_windows

    def merge_windows(
        self, windows: List[RelikReaderSample]
    ) -> List[RelikReaderSample]:
        windows_by_doc_id = collections.defaultdict(list)
        for window in windows:
            windows_by_doc_id[window.doc_id].append(window)

        merged_window_by_doc = {
            doc_id: self.merge_doc_windows(doc_windows)
            for doc_id, doc_windows in windows_by_doc_id.items()
        }

        return list(merged_window_by_doc.values())

    def merge_doc_windows(self, windows: List[RelikReaderSample]) -> RelikReaderSample:
        if len(windows) == 1:
            return windows[0]

        if len(windows) > 0 and getattr(windows[0], "offset", None) is not None:
            windows = sorted(windows, key=(lambda x: x.offset))

        window_accumulator = windows[0]

        for next_window in windows[1:]:
            window_accumulator = self._merge_window_pair(
                window_accumulator, next_window
            )

        return window_accumulator

    def _merge_tokens(
        self, window1: RelikReaderSample, window2: RelikReaderSample
    ) -> Tuple[list, dict, dict]:
        w1_tokens = window1.tokens[1:-1]
        w2_tokens = window2.tokens[1:-1]

        # find intersection
        tokens_intersection = None
        for k in reversed(range(1, len(w1_tokens))):
            if w1_tokens[-k:] == w2_tokens[:k]:
                tokens_intersection = k
                break
        assert tokens_intersection is not None, (
            f"{window1.doc_id} - {window1.sent_id} - {window1.offset}"
            + f" {window2.doc_id} - {window2.sent_id} - {window2.offset}\n"
            + f"w1 tokens: {w1_tokens}\n"
            + f"w2 tokens: {w2_tokens}\n"
        )

        final_tokens = (
            [window1.tokens[0]]  # CLS
            + w1_tokens
            + w2_tokens[tokens_intersection:]
            + [window1.tokens[-1]]  # SEP
        )

        w2_starting_offset = len(w1_tokens) - tokens_intersection

        def merge_char_mapping(t2c1: dict, t2c2: dict) -> dict:
            final_t2c = dict()
            final_t2c.update(t2c1)
            for t, c in t2c2.items():
                t = int(t)
                if t < tokens_intersection:
                    continue
                final_t2c[str(t + w2_starting_offset)] = c
            return final_t2c

        return (
            final_tokens,
            merge_char_mapping(window1.token2char_start, window2.token2char_start),
            merge_char_mapping(window1.token2char_end, window2.token2char_end),
        )

    def _merge_span_annotation(
        self, span_annotation1: List[list], span_annotation2: List[list]
    ) -> List[list]:
        uniq_store = set()
        final_span_annotation_store = []
        for span_annotation in itertools.chain(span_annotation1, span_annotation2):
            span_annotation_id = tuple(span_annotation)
            if span_annotation_id not in uniq_store:
                uniq_store.add(span_annotation_id)
                final_span_annotation_store.append(span_annotation)
        return sorted(final_span_annotation_store, key=lambda x: x[0])

    def _merge_predictions(
        self,
        window1: RelikReaderSample,
        window2: RelikReaderSample,
    ) -> Tuple[Set[Tuple[int, int, str]], dict]:
        merged_predictions = window1.predicted_window_labels_chars.union(
            window2.predicted_window_labels_chars
        )

        span_title_probabilities = dict()
        # probabilities
        for span_prediction, predicted_probs in itertools.chain(
            window1.probs_window_labels_chars.items(),
            window2.probs_window_labels_chars.items(),
        ):
            if span_prediction not in span_title_probabilities:
                span_title_probabilities[span_prediction] = predicted_probs

        return merged_predictions, span_title_probabilities

    def _merge_window_pair(
        self,
        window1: RelikReaderSample,
        window2: RelikReaderSample,
    ) -> RelikReaderSample:
        merging_output = dict()

        if getattr(window1, "doc_id", None) is not None:
            assert window1.doc_id == window2.doc_id

        if getattr(window1, "offset", None) is not None:
            assert (
                window1.offset < window2.offset
            ), f"window 2 offset ({window2.offset}) is smaller than window 1 offset ({window1.offset})"

        merging_output["doc_id"] = window1.doc_id
        merging_output["offset"] = window2.offset

        m_tokens, m_token2char_start, m_token2char_end = self._merge_tokens(
            window1, window2
        )

        window_labels = None
        if getattr(window1, "window_labels", None) is not None:
            window_labels = self._merge_span_annotation(
                window1.window_labels, window2.window_labels
            )
        (
            predicted_window_labels_chars,
            probs_window_labels_chars,
        ) = self._merge_predictions(
            window1,
            window2,
        )

        merging_output.update(
            dict(
                tokens=m_tokens,
                token2char_start=m_token2char_start,
                token2char_end=m_token2char_end,
                window_labels=window_labels,
                predicted_window_labels_chars=predicted_window_labels_chars,
                probs_window_labels_chars=probs_window_labels_chars,
            )
        )

        return RelikReaderSample(**merging_output)
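To see the windowing above in action, here is a small sketch using the whitespace tokenizer from this commit; it assumes, as the merge code above already does, that RelikReaderSample exposes its keyword arguments as attributes:

from relik.inference.data.tokenizers import WhitespaceTokenizer
from relik.inference.data.window.manager import WindowManager

manager = WindowManager(tokenizer=WhitespaceTokenizer())

document = " ".join(f"tok{i}" for i in range(10))
# window_size=4 with stride=2 produces half-overlapping windows
for sample in manager.create_windows(document, window_size=4, stride=2):
    print(sample.window_id, sample.offset, sample.text)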
relik/inference/gerbil.py
ADDED
@@ -0,0 +1,254 @@
import argparse
import json
import logging
import os
import re
import sys
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Iterator, List, Optional, Tuple

from relik.inference.annotator import Relik
from relik.inference.data.objects import RelikOutput

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))

logger = logging.getLogger(__name__)


class GerbilAlbyManager:
    def __init__(
        self,
        annotator: Optional[Relik] = None,
        response_logger_dir: Optional[str] = None,
    ) -> None:
        self.annotator = annotator
        self.response_logger_dir = response_logger_dir
        self.predictions_counter = 0
        self.labels_mapping = None

    def annotate(self, document: str):
        relik_output: RelikOutput = self.annotator(document)
        annotations = [(ss, se, l) for ss, se, l, _ in relik_output.labels]
        if self.labels_mapping is not None:
            return [
                (ss, se, self.labels_mapping.get(l, l)) for ss, se, l in annotations
            ]
        return annotations

    def set_mapping_file(self, mapping_file_path: str):
        with open(mapping_file_path) as f:
            labels_mapping = json.load(f)
        self.labels_mapping = {v: k for k, v in labels_mapping.items()}

    def write_response_bundle(
        self,
        document: str,
        new_document: str,
        annotations: list,
        mapped_annotations: list,
    ) -> None:
        if self.response_logger_dir is None:
            return

        if not os.path.isdir(self.response_logger_dir):
            os.mkdir(self.response_logger_dir)

        with open(
            f"{self.response_logger_dir}/{self.predictions_counter}.json", "w"
        ) as f:
            out_json_obj = dict(
                document=document,
                new_document=new_document,
                annotations=annotations,
                mapped_annotations=mapped_annotations,
            )

            out_json_obj["span_annotations"] = [
                (ss, se, document[ss:se], label) for (ss, se, label) in annotations
            ]

            out_json_obj["span_mapped_annotations"] = [
                (ss, se, new_document[ss:se], label)
                for (ss, se, label) in mapped_annotations
            ]

            json.dump(out_json_obj, f, indent=2)

        self.predictions_counter += 1


manager = GerbilAlbyManager()


def preprocess_document(document: str) -> Tuple[str, List[Tuple[int, int]]]:
    pattern_subs = {
        "-LPR- ": " (",
        "-RPR-": ")",
        "\n\n": "\n",
        "-LRB-": "(",
        "-RRB-": ")",
        '","': ",",
    }

    document_acc = document
    curr_offset = 0
    char2offset = []

    matchings = re.finditer("({})".format("|".join(pattern_subs)), document)
    for span_matching in sorted(matchings, key=lambda x: x.span()[0]):
        span_start, span_end = span_matching.span()
        span_start -= curr_offset
        span_end -= curr_offset

        span_text = document_acc[span_start:span_end]
        span_sub = pattern_subs[span_text]
        document_acc = document_acc[:span_start] + span_sub + document_acc[span_end:]

        offset = len(span_text) - len(span_sub)
        curr_offset += offset

        char2offset.append((span_start + len(span_sub), curr_offset))

    return document_acc, char2offset


def map_back_annotations(
    annotations: List[Tuple[int, int, str]], char_mapping: List[Tuple[int, int]]
) -> Iterator[Tuple[int, int, str]]:
    def map_char(char_idx: int) -> int:
        current_offset = 0
        for offset_idx, offset_value in char_mapping:
            if char_idx >= offset_idx:
                current_offset = offset_value
            else:
                break
        return char_idx + current_offset

    for ss, se, label in annotations:
        yield map_char(ss), map_char(se), label


def annotate(document: str) -> List[Tuple[int, int, str]]:
    new_document, mapping = preprocess_document(document)
    logger.info("Mapping: " + str(mapping))
    logger.info("Document: " + str(document))
    annotations = [
        (cs, ce, label.replace(" ", "_"))
        for cs, ce, label in manager.annotate(new_document)
    ]
    logger.info("New document: " + str(new_document))
    mapped_annotations = (
        list(map_back_annotations(annotations, mapping))
        if len(mapping) > 0
        else annotations
    )

    logger.info(
        "Annotations: "
        + str([(ss, se, document[ss:se], ann) for ss, se, ann in mapped_annotations])
    )

    manager.write_response_bundle(
        document, new_document, mapped_annotations, annotations
    )

    if not all(
        [
            new_document[ss:se] == document[mss:mse]
            for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations)
        ]
    ):
        diff_mappings = [
            (new_document[ss:se], document[mss:mse])
            for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations)
        ]
        logger.error("Annotation mismatch after mapping: " + str(diff_mappings))
        return None
    assert all(
        [
            document[mss:mse] == new_document[ss:se]
            for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations)
        ]
    ), (mapped_annotations, annotations)

    return [(cs, ce - cs, label) for cs, ce, label in mapped_annotations]


class GetHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        content_length = int(self.headers["Content-Length"])
        post_data = self.rfile.read(content_length)
        self.send_response(200)
        self.end_headers()
        doc_text = read_json(post_data)
        response = annotate(doc_text)

        self.wfile.write(bytes(json.dumps(response), "utf-8"))
        return


def read_json(post_data):
    data = json.loads(post_data.decode("utf-8"))
    text = data["text"]
    return text


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--relik-model-name", required=True)
    parser.add_argument("--responses-log-dir")
    parser.add_argument("--log-file", default="logs/logging.txt")
    parser.add_argument("--mapping-file")
    return parser.parse_args()


def main():
    args = parse_args()

    # init manager
    manager.response_logger_dir = args.responses_log_dir
    # manager.annotator = Relik.from_pretrained(args.relik_model_name)

    print("Debugging: not using your ReLiK model but a hardcoded one.")
    manager.annotator = Relik(
        question_encoder="riccorl/relik-retriever-aida-blink-pretrain-omniencoder",
        document_index="riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder",
        reader="relik/reader/models/relik-reader-deberta-base-new-data",
        window_size=32,
        window_stride=16,
        candidates_preprocessing_fn=(lambda x: x.split("<def>")[0].strip()),
    )

    if args.mapping_file is not None:
        manager.set_mapping_file(args.mapping_file)

    port = 6654
    server = HTTPServer(("localhost", port), GetHandler)
    logger.info(f"Starting server at http://localhost:{port}")

    # Create a file handler and set its level
    file_handler = logging.FileHandler(args.log_file)
    file_handler.setLevel(logging.DEBUG)

    # Create a log formatter and set it on the handler
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    file_handler.setFormatter(formatter)

    # Add the file handler to the logger
    logger.addHandler(file_handler)

    try:
        server.serve_forever()
    except KeyboardInterrupt:
        exit(0)


if __name__ == "__main__":
    main()
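`preprocess_document` and `map_back_annotations` are pure functions, so the offset bookkeeping can be checked in isolation. A small round-trip sketch on a document with PTB-style bracket tokens:

from relik.inference.gerbil import map_back_annotations, preprocess_document

document = "Obama -LRB- born 1961 -RRB- visited Rome"
new_document, mapping = preprocess_document(document)
print(new_document)  # Obama ( born 1961 ) visited Rome

# pretend the annotator tagged "Rome" in the rewritten text
start = new_document.index("Rome")
annotations = [(start, start + len("Rome"), "Rome")]
for ss, se, label in map_back_annotations(annotations, mapping):
    # the mapped offsets point back into the original document
    print(document[ss:se], label)  # Rome Rome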
relik/inference/preprocessing.py
ADDED
@@ -0,0 +1,4 @@
def wikipedia_title_and_openings_preprocessing(
    wikipedia_title_and_openings: str, separator: str = " <def>"
):
    return wikipedia_title_and_openings.split(separator, 1)[0]
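This helper keeps only the Wikipedia title from a retriever passage of the form "Title <def> opening sentences". For instance:

from relik.inference.preprocessing import wikipedia_title_and_openings_preprocessing

passage = "Rome <def> Rome is the capital city of Italy."
print(wikipedia_title_and_openings_preprocessing(passage))  # -> "Rome"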
relik/inference/serve/__init__.py
ADDED
File without changes
relik/inference/serve/backend/__init__.py
ADDED
File without changes
relik/inference/serve/backend/relik.py
ADDED
@@ -0,0 +1,210 @@
import logging
from pathlib import Path
from typing import List, Optional, Union

from relik.common.utils import is_package_available
from relik.inference.annotator import Relik

if not is_package_available("fastapi"):
    raise ImportError(
        "FastAPI is not installed. Please install FastAPI with `pip install relik[serve]`."
    )
from fastapi import FastAPI, HTTPException

if not is_package_available("ray"):
    raise ImportError(
        "Ray is not installed. Please install Ray with `pip install relik[serve]`."
    )
from ray import serve

from relik.common.log import get_logger
from relik.inference.data.tokenizers import SpacyTokenizer, WhitespaceTokenizer
from relik.inference.data.window.manager import WindowManager
from relik.inference.serve.backend.utils import (
    RayParameterManager,
    ServerParameterManager,
)
from relik.retriever.data.utils import batch_generator

logger = get_logger(__name__, level=logging.INFO)

VERSION = {}  # type: ignore
with open(
    Path(__file__).parent.parent.parent.parent / "version.py", "r"
) as version_file:
    exec(version_file.read(), VERSION)

# Env variables for server
SERVER_MANAGER = ServerParameterManager()
RAY_MANAGER = RayParameterManager()

app = FastAPI(
    title="ReLiK",
    version=VERSION["VERSION"],
    description="ReLiK REST API",
)


@serve.deployment(
    ray_actor_options={
        "num_gpus": RAY_MANAGER.num_gpus
        if (
            SERVER_MANAGER.retriver_device == "cuda"
            or SERVER_MANAGER.reader_device == "cuda"
        )
        else 0
    },
    autoscaling_config={
        "min_replicas": RAY_MANAGER.min_replicas,
        "max_replicas": RAY_MANAGER.max_replicas,
    },
)
@serve.ingress(app)
class RelikServer:
    def __init__(
        self,
        question_encoder: str,
        document_index: str,
        passage_encoder: Optional[str] = None,
        reader_encoder: Optional[str] = None,
        top_k: int = 100,
        retriver_device: str = "cpu",
        reader_device: str = "cpu",
        index_device: Optional[str] = None,
        precision: int = 32,
        index_precision: Optional[int] = None,
        use_faiss: bool = False,
        window_batch_size: int = 32,
        window_size: int = 32,
        window_stride: int = 16,
        split_on_spaces: bool = False,
    ):
        # parameters
        self.question_encoder = question_encoder
        self.passage_encoder = passage_encoder
        self.reader_encoder = reader_encoder
        self.document_index = document_index
        self.top_k = top_k
        self.retriver_device = retriver_device
        self.index_device = index_device or retriver_device
        self.reader_device = reader_device
        self.precision = precision
        self.index_precision = index_precision or precision
        self.use_faiss = use_faiss
        self.window_batch_size = window_batch_size
        self.window_size = window_size
        self.window_stride = window_stride
        self.split_on_spaces = split_on_spaces

        # log stuff for debugging
        logger.info("Initializing RelikServer with parameters:")
        logger.info(f"QUESTION_ENCODER: {self.question_encoder}")
        logger.info(f"PASSAGE_ENCODER: {self.passage_encoder}")
        logger.info(f"READER_ENCODER: {self.reader_encoder}")
        logger.info(f"DOCUMENT_INDEX: {self.document_index}")
        logger.info(f"TOP_K: {self.top_k}")
        logger.info(f"RETRIEVER_DEVICE: {self.retriver_device}")
        logger.info(f"READER_DEVICE: {self.reader_device}")
        logger.info(f"INDEX_DEVICE: {self.index_device}")
        logger.info(f"PRECISION: {self.precision}")
        logger.info(f"INDEX_PRECISION: {self.index_precision}")
        logger.info(f"WINDOW_BATCH_SIZE: {self.window_batch_size}")
        logger.info(f"SPLIT_ON_SPACES: {self.split_on_spaces}")

        self.relik = Relik(
            question_encoder=self.question_encoder,
            passage_encoder=self.passage_encoder,
            document_index=self.document_index,
            reader=self.reader_encoder,
            retriever_device=self.retriver_device,
            document_index_device=self.index_device,
            reader_device=self.reader_device,
            retriever_precision=self.precision,
            document_index_precision=self.index_precision,
            reader_precision=self.precision,
        )

        # tokenizer and window manager for the GERBIL endpoint
        if self.split_on_spaces:
            logger.info("Using WhitespaceTokenizer")
            self.tokenizer = WhitespaceTokenizer()
        else:
            logger.info("Using SpacyTokenizer")
            self.tokenizer = SpacyTokenizer(language="en")
        self.window_manager = WindowManager(tokenizer=self.tokenizer)

    # @serve.batch()
    async def handle_batch(self, documents: List[str]) -> List:
        return self.relik(
            documents,
            top_k=self.top_k,
            window_size=self.window_size,
            window_stride=self.window_stride,
            batch_size=self.window_batch_size,
        )

    async def handle_batch_retriever(
        self, text: List[str], text_pair: List[str]
    ) -> List:
        # NOTE: assumes the Relik annotator exposes its underlying
        # GoldenRetriever as `retriever`.
        return self.relik.retriever.retrieve(
            list(text), text_pair=list(text_pair), k=self.top_k
        )

    @app.post("/api/entities")
    async def entities_endpoint(
        self,
        documents: Union[str, List[str]],
    ):
        try:
            # normalize input
            if isinstance(documents, str):
                documents = [documents]
            # get predictions from the full retriever + reader pipeline
            return await self.handle_batch(documents)
        except Exception as e:
            # log the entire stack trace
            logger.exception(e)
            raise HTTPException(status_code=500, detail=f"Server Error: {e}")

    @app.post("/api/gerbil")
    async def gerbil_endpoint(self, documents: Union[str, List[str]]):
        try:
            # normalize input
            if isinstance(documents, str):
                documents = [documents]

            # output list
            windows_passages = []
            # split documents into windows
            document_windows = [
                window
                for doc_id, document in enumerate(documents)
                for window in self.window_manager.create_windows(
                    document,
                    window_size=self.window_size,
                    stride=self.window_stride,
                    doc_id=doc_id,
                )
            ]

            # get text and topic from document windows and create new list
            model_inputs = [
                (window.text, window.doc_topic) for window in document_windows
            ]

            # batch generator
            for batch in batch_generator(
                model_inputs, batch_size=self.window_batch_size
            ):
                text, text_pair = zip(*batch)
                batch_predictions = await self.handle_batch_retriever(text, text_pair)
                windows_passages.extend(
                    [
                        [p.label for p in predictions]
                        for predictions in batch_predictions
                    ]
                )

            # add passage to document windows
            for window, passages in zip(document_windows, windows_passages):
                # clean up passages (remove everything after first <def> tag if present)
                passages = [c.split(" <def>", 1)[0] for c in passages]
                window.window_candidates = passages

            # return document windows
            return document_windows

        except Exception as e:
            # log the entire stack trace
            logger.exception(e)
            raise HTTPException(status_code=500, detail=f"Server Error: {e}")


server = RelikServer.bind(**vars(SERVER_MANAGER))
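With the deployment bound above, the entities endpoint accepts either a single string or a list of strings. A minimal client sketch, assuming the Ray Serve app is running locally on its default HTTP port 8000 (the same default the Streamlit frontend below uses):

import requests

# assumes the RelikServer deployment is up at localhost:8000
response = requests.post(
    "http://localhost:8000/api/entities",
    json="Obama went to Rome for a quick vacation.",
)
response.raise_for_status()
print(response.json())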
relik/inference/serve/backend/retriever.py
ADDED
@@ -0,0 +1,206 @@
import logging
from pathlib import Path
from typing import List, Optional, Union

from relik.common.utils import is_package_available

if not is_package_available("fastapi"):
    raise ImportError(
        "FastAPI is not installed. Please install FastAPI with `pip install relik[serve]`."
    )
from fastapi import FastAPI, HTTPException

if not is_package_available("ray"):
    raise ImportError(
        "Ray is not installed. Please install Ray with `pip install relik[serve]`."
    )
from ray import serve

from relik.common.log import get_logger
from relik.inference.data.tokenizers import SpacyTokenizer, WhitespaceTokenizer
from relik.inference.data.window.manager import WindowManager
from relik.inference.serve.backend.utils import (
    RayParameterManager,
    ServerParameterManager,
)
from relik.retriever.data.utils import batch_generator
from relik.retriever.pytorch_modules import GoldenRetriever

logger = get_logger(__name__, level=logging.INFO)

VERSION = {}  # type: ignore
with open(
    Path(__file__).parent.parent.parent.parent / "version.py", "r"
) as version_file:
    exec(version_file.read(), VERSION)

# Env variables for server
SERVER_MANAGER = ServerParameterManager()
RAY_MANAGER = RayParameterManager()

app = FastAPI(
    title="Golden Retriever",
    version=VERSION["VERSION"],
    description="Golden Retriever REST API",
)


@serve.deployment(
    ray_actor_options={
        "num_gpus": RAY_MANAGER.num_gpus if SERVER_MANAGER.retriver_device == "cuda" else 0
    },
    autoscaling_config={
        "min_replicas": RAY_MANAGER.min_replicas,
        "max_replicas": RAY_MANAGER.max_replicas,
    },
)
@serve.ingress(app)
class GoldenRetrieverServer:
    def __init__(
        self,
        question_encoder: str,
        document_index: str,
        passage_encoder: Optional[str] = None,
        top_k: int = 100,
        device: str = "cpu",
        index_device: Optional[str] = None,
        precision: int = 32,
        index_precision: Optional[int] = None,
        use_faiss: bool = False,
        window_batch_size: int = 32,
        window_size: int = 32,
        window_stride: int = 16,
        split_on_spaces: bool = False,
        **kwargs,  # absorb reader-only settings from the shared ServerParameterManager
    ):
        # parameters
        self.question_encoder = question_encoder
        self.passage_encoder = passage_encoder
        self.document_index = document_index
        self.top_k = top_k
        self.device = device
        self.index_device = index_device or device
        self.precision = precision
        self.index_precision = index_precision or precision
        self.use_faiss = use_faiss
        self.window_batch_size = window_batch_size
        self.window_size = window_size
        self.window_stride = window_stride
        self.split_on_spaces = split_on_spaces

        # log stuff for debugging
        logger.info("Initializing GoldenRetrieverServer with parameters:")
        logger.info(f"QUESTION_ENCODER: {self.question_encoder}")
        logger.info(f"PASSAGE_ENCODER: {self.passage_encoder}")
        logger.info(f"DOCUMENT_INDEX: {self.document_index}")
        logger.info(f"TOP_K: {self.top_k}")
        logger.info(f"DEVICE: {self.device}")
        logger.info(f"INDEX_DEVICE: {self.index_device}")
        logger.info(f"PRECISION: {self.precision}")
        logger.info(f"INDEX_PRECISION: {self.index_precision}")
        logger.info(f"WINDOW_BATCH_SIZE: {self.window_batch_size}")
        logger.info(f"SPLIT_ON_SPACES: {self.split_on_spaces}")

        self.retriever = GoldenRetriever(
            question_encoder=self.question_encoder,
            passage_encoder=self.passage_encoder,
            document_index=self.document_index,
            device=self.device,
            index_device=self.index_device,
            index_precision=self.index_precision,
        )
        self.retriever.eval()

        if self.split_on_spaces:
            logger.info("Using WhitespaceTokenizer")
            self.tokenizer = WhitespaceTokenizer()
            # logger.info("Using RegexTokenizer")
            # self.tokenizer = RegexTokenizer()
        else:
            logger.info("Using SpacyTokenizer")
            self.tokenizer = SpacyTokenizer(language="en")

        self.window_manager = WindowManager(tokenizer=self.tokenizer)

    # @serve.batch()
    async def handle_batch(
        self, documents: List[str], document_topics: List[str]
    ) -> List:
        return self.retriever.retrieve(
            documents, text_pair=document_topics, k=self.top_k, precision=self.precision
        )

    @app.post("/api/retrieve")
    async def retrieve_endpoint(
        self,
        documents: Union[str, List[str]],
        document_topics: Optional[Union[str, List[str]]] = None,
    ):
        try:
            # normalize input
            if isinstance(documents, str):
                documents = [documents]
            if document_topics is not None:
                if isinstance(document_topics, str):
                    document_topics = [document_topics]
                assert len(documents) == len(document_topics)
            # get predictions
            return await self.handle_batch(documents, document_topics)
        except Exception as e:
            # log the entire stack trace
            logger.exception(e)
            raise HTTPException(status_code=500, detail=f"Server Error: {e}")

    @app.post("/api/gerbil")
    async def gerbil_endpoint(self, documents: Union[str, List[str]]):
        try:
            # normalize input
            if isinstance(documents, str):
                documents = [documents]

            # output list
            windows_passages = []
            # split documents into windows
            document_windows = [
                window
                for doc_id, document in enumerate(documents)
                for window in self.window_manager.create_windows(
                    document,
                    window_size=self.window_size,
                    stride=self.window_stride,
                    doc_id=doc_id,
                )
            ]

            # get text and topic from document windows and create new list
            model_inputs = [
                (window.text, window.doc_topic) for window in document_windows
            ]

            # batch generator
            for batch in batch_generator(
                model_inputs, batch_size=self.window_batch_size
            ):
                text, text_pair = zip(*batch)
                batch_predictions = await self.handle_batch(text, text_pair)
                windows_passages.extend(
                    [
                        [p.label for p in predictions]
                        for predictions in batch_predictions
                    ]
                )

            # add passage to document windows
            for window, passages in zip(document_windows, windows_passages):
                # clean up passages (remove everything after first <def> tag if present)
                passages = [c.split(" <def>", 1)[0] for c in passages]
                window.window_candidates = passages

            # return document windows
            return document_windows

        except Exception as e:
            # log the entire stack trace
            logger.exception(e)
            raise HTTPException(status_code=500, detail=f"Server Error: {e}")


server = GoldenRetrieverServer.bind(**vars(SERVER_MANAGER))
relik/inference/serve/backend/utils.py
ADDED
@@ -0,0 +1,29 @@
import os
from dataclasses import dataclass
from typing import Union


@dataclass
class ServerParameterManager:
    retriver_device: str = os.environ.get("RETRIEVER_DEVICE", "cpu")
    reader_device: str = os.environ.get("READER_DEVICE", "cpu")
    index_device: str = os.environ.get("INDEX_DEVICE", retriver_device)
    precision: Union[str, int] = os.environ.get("PRECISION", "fp32")
    index_precision: Union[str, int] = os.environ.get("INDEX_PRECISION", precision)
    question_encoder: str = os.environ.get("QUESTION_ENCODER", None)
    passage_encoder: str = os.environ.get("PASSAGE_ENCODER", None)
    document_index: str = os.environ.get("DOCUMENT_INDEX", None)
    reader_encoder: str = os.environ.get("READER_ENCODER", None)
    top_k: int = int(os.environ.get("TOP_K", 100))
    use_faiss: bool = os.environ.get("USE_FAISS", False)
    window_batch_size: int = int(os.environ.get("WINDOW_BATCH_SIZE", 32))
    window_size: int = int(os.environ.get("WINDOW_SIZE", 32))
    window_stride: int = int(os.environ.get("WINDOW_STRIDE", 16))
    split_on_spaces: bool = os.environ.get("SPLIT_ON_SPACES", False)


class RayParameterManager:
    def __init__(self) -> None:
        self.num_gpus = int(os.environ.get("NUM_GPUS", 1))
        self.min_replicas = int(os.environ.get("MIN_REPLICAS", 1))
        self.max_replicas = int(os.environ.get("MAX_REPLICAS", 1))
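Both servers above are configured purely through environment variables via these managers. A small sketch of overriding the defaults; the variables must be set before the module is imported, since the dataclass defaults are evaluated at import time, and the checkpoint names are the ones hardcoded in gerbil.py above:

import os

# set before importing the server utils: defaults are read at import time
os.environ["QUESTION_ENCODER"] = "riccorl/relik-retriever-aida-blink-pretrain-omniencoder"
os.environ["DOCUMENT_INDEX"] = "riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder"
os.environ["TOP_K"] = "50"
os.environ["WINDOW_STRIDE"] = "16"

from relik.inference.serve.backend.utils import ServerParameterManager

params = ServerParameterManager()
print(params.question_encoder, params.top_k, params.window_stride)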
relik/inference/serve/frontend/__init__.py
ADDED
File without changes
relik/inference/serve/frontend/relik.py
ADDED
@@ -0,0 +1,231 @@
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import time
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import requests
|
| 7 |
+
import streamlit as st
|
| 8 |
+
from spacy import displacy
|
| 9 |
+
from streamlit_extras.badges import badge
|
| 10 |
+
from streamlit_extras.stylable_container import stylable_container
|
| 11 |
+
|
| 12 |
+
RELIK = os.getenv("RELIK", "localhost:8000/api/entities")
|
| 13 |
+
|
| 14 |
+
import random
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def get_random_color(ents):
|
| 18 |
+
colors = {}
|
| 19 |
+
random_colors = generate_pastel_colors(len(ents))
|
| 20 |
+
for ent in ents:
|
| 21 |
+
colors[ent] = random_colors.pop(random.randint(0, len(random_colors) - 1))
|
| 22 |
+
return colors
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def floatrange(start, stop, steps):
|
| 26 |
+
if int(steps) == 1:
|
| 27 |
+
return [stop]
|
| 28 |
+
return [
|
| 29 |
+
start + float(i) * (stop - start) / (float(steps) - 1) for i in range(steps)
|
| 30 |
+
]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def hsl_to_rgb(h, s, l):
|
| 34 |
+
def hue_2_rgb(v1, v2, v_h):
|
| 35 |
+
while v_h < 0.0:
|
| 36 |
+
v_h += 1.0
|
| 37 |
+
while v_h > 1.0:
|
| 38 |
+
v_h -= 1.0
|
| 39 |
+
if 6 * v_h < 1.0:
|
| 40 |
+
return v1 + (v2 - v1) * 6.0 * v_h
|
| 41 |
+
if 2 * v_h < 1.0:
|
| 42 |
+
return v2
|
| 43 |
+
if 3 * v_h < 2.0:
|
| 44 |
+
return v1 + (v2 - v1) * ((2.0 / 3.0) - v_h) * 6.0
|
| 45 |
+
return v1
|
| 46 |
+
|
| 47 |
+
# if not (0 <= s <= 1): raise ValueError, "s (saturation) parameter must be between 0 and 1."
|
| 48 |
+
# if not (0 <= l <= 1): raise ValueError, "l (lightness) parameter must be between 0 and 1."
|
| 49 |
+
|
| 50 |
+
r, b, g = (l * 255,) * 3
|
| 51 |
+
if s != 0.0:
|
| 52 |
+
if l < 0.5:
|
| 53 |
+
var_2 = l * (1.0 + s)
|
| 54 |
+
else:
|
| 55 |
+
var_2 = (l + s) - (s * l)
|
| 56 |
+
var_1 = 2.0 * l - var_2
|
| 57 |
+
r = 255 * hue_2_rgb(var_1, var_2, h + (1.0 / 3.0))
|
| 58 |
+
g = 255 * hue_2_rgb(var_1, var_2, h)
|
| 59 |
+
b = 255 * hue_2_rgb(var_1, var_2, h - (1.0 / 3.0))
|
| 60 |
+
|
| 61 |
+
return int(round(r)), int(round(g)), int(round(b))
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def generate_pastel_colors(n):
|
| 65 |
+
"""Return different pastel colours.
|
| 66 |
+
|
| 67 |
+
Input:
|
| 68 |
+
n (integer) : The number of colors to return
|
| 69 |
+
|
| 70 |
+
Output:
|
| 71 |
+
A list of colors in HTML notation (eg.['#cce0ff', '#ffcccc', '#ccffe0', '#f5ccff', '#f5ffcc'])
|
| 72 |
+
|
| 73 |
+
Example:
|
| 74 |
+
>>> print generate_pastel_colors(5)
|
| 75 |
+
['#cce0ff', '#f5ccff', '#ffcccc', '#f5ffcc', '#ccffe0']
|
| 76 |
+
"""
|
| 77 |
+
if n == 0:
|
| 78 |
+
return []
|
| 79 |
+
|
| 80 |
+
# To generate colors, we use the HSL colorspace (see http://en.wikipedia.org/wiki/HSL_color_space)
|
| 81 |
+
start_hue = 0.6 # 0=red 1/3=0.333=green 2/3=0.666=blue
|
| 82 |
+
saturation = 1.0
|
| 83 |
+
lightness = 0.8
|
| 84 |
+
# We take points around the chromatic circle (hue):
|
| 85 |
+
# (Note: we generate n+1 colors, then drop the last one ([:-1]) because
|
| 86 |
+
# it equals the first one (hue 0 = hue 1))
|
| 87 |
+
return [
|
| 88 |
+
"#%02x%02x%02x" % hsl_to_rgb(hue, saturation, lightness)
|
| 89 |
+
for hue in floatrange(start_hue, start_hue + 1, n + 1)
|
| 90 |
+
][:-1]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def set_sidebar(css):
|
| 94 |
+
white_link_wrapper = "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css'><a href='{}'>{}</a>"
|
| 95 |
+
with st.sidebar:
|
| 96 |
+
st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
|
| 97 |
+
st.image(
|
| 98 |
+
"http://nlp.uniroma1.it/static/website/sapienza-nlp-logo-wh.svg",
|
| 99 |
+
use_column_width=True,
|
| 100 |
+
)
|
| 101 |
+
st.markdown("## ReLiK")
|
| 102 |
+
st.write(
|
| 103 |
+
f"""
|
| 104 |
+
- {white_link_wrapper.format("#", "<i class='fa-solid fa-file'></i> Paper")}
|
| 105 |
+
- {white_link_wrapper.format("https://github.com/SapienzaNLP/relik", "<i class='fa-brands fa-github'></i> GitHub")}
|
| 106 |
+
- {white_link_wrapper.format("https://hub.docker.com/repository/docker/sapienzanlp/relik", "<i class='fa-brands fa-docker'></i> Docker Hub")}
|
| 107 |
+
""",
|
| 108 |
+
unsafe_allow_html=True,
|
| 109 |
+
)
|
| 110 |
+
st.markdown("## Sapienza NLP")
|
| 111 |
+
st.write(
|
| 112 |
+
f"""
|
| 113 |
+
- {white_link_wrapper.format("https://nlp.uniroma1.it", "<i class='fa-solid fa-globe'></i> Webpage")}
|
| 114 |
+
- {white_link_wrapper.format("https://github.com/SapienzaNLP", "<i class='fa-brands fa-github'></i> GitHub")}
|
| 115 |
+
- {white_link_wrapper.format("https://twitter.com/SapienzaNLP", "<i class='fa-brands fa-twitter'></i> Twitter")}
|
| 116 |
+
- {white_link_wrapper.format("https://www.linkedin.com/company/79434450", "<i class='fa-brands fa-linkedin'></i> LinkedIn")}
|
| 117 |
+
""",
|
| 118 |
+
unsafe_allow_html=True,
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def get_el_annotations(response):
|
| 123 |
+
# swap labels key with ents
|
| 124 |
+
response["ents"] = response.pop("labels")
|
| 125 |
+
label_in_text = set(l["label"] for l in response["ents"])
|
| 126 |
+
options = {"ents": label_in_text, "colors": get_random_color(label_in_text)}
|
| 127 |
+
return response, options
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def set_intro(css):
|
| 131 |
+
# intro
|
| 132 |
+
st.markdown("# ReLik")
|
| 133 |
+
st.markdown(
|
| 134 |
+
"### Retrieve, Read and LinK: Fast and Accurate Entity Linking and Relation Extraction on an Academic Budget"
|
| 135 |
+
)
|
| 136 |
+
# st.markdown(
|
| 137 |
+
# "This is a front-end for the paper [Universal Semantic Annotator: the First Unified API "
|
| 138 |
+
# "for WSD, SRL and Semantic Parsing](https://www.researchgate.net/publication/360671045_Universal_Semantic_Annotator_the_First_Unified_API_for_WSD_SRL_and_Semantic_Parsing), which will be presented at LREC 2022 by "
|
| 139 |
+
# "[Riccardo Orlando](https://riccorl.github.io), [Simone Conia](https://c-simone.github.io/), "
|
| 140 |
+
# "[Stefano Faralli](https://corsidilaurea.uniroma1.it/it/users/stefanofaralliuniroma1it), and [Roberto Navigli](https://www.diag.uniroma1.it/navigli/)."
|
| 141 |
+
# )
|
| 142 |
+
badge(type="github", name="sapienzanlp/relik")
|
| 143 |
+
badge(type="pypi", name="relik")
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def run_client():
|
| 147 |
+
with open(Path(__file__).parent / "style.css") as f:
|
| 148 |
+
css = f.read()
|
| 149 |
+
|
| 150 |
+
st.set_page_config(
|
| 151 |
+
page_title="ReLik",
|
| 152 |
+
page_icon="🦮",
|
| 153 |
+
layout="wide",
|
| 154 |
+
)
|
| 155 |
+
set_sidebar(css)
|
| 156 |
+
set_intro(css)
|
| 157 |
+
|
| 158 |
+
# text input
|
| 159 |
+
text = st.text_area(
|
| 160 |
+
"Enter Text Below:",
|
| 161 |
+
value="Obama went to Rome for a quick vacation.",
|
| 162 |
+
height=200,
|
| 163 |
+
max_chars=500,
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
with stylable_container(
|
| 167 |
+
key="annotate_button",
|
| 168 |
+
css_styles="""
|
| 169 |
+
button {
|
| 170 |
+
background-color: #802433;
|
| 171 |
+
color: white;
|
| 172 |
+
border-radius: 25px;
|
| 173 |
+
}
|
| 174 |
+
""",
|
| 175 |
+
):
|
| 176 |
+
submit = st.button("Annotate")
|
| 177 |
+
# submit = st.button("Run")
|
| 178 |
+
|
| 179 |
+
# ReLik API call
|
| 180 |
+
if submit:
|
| 181 |
+
text = text.strip()
|
| 182 |
+
if text:
|
| 183 |
+
st.markdown("####")
|
| 184 |
+
st.markdown("#### Entity Linking")
|
| 185 |
+
with st.spinner(text="In progress"):
|
| 186 |
+
response = requests.post(RELIK, json=text)
|
| 187 |
+
if response.status_code != 200:
|
| 188 |
+
st.error("Error: {}".format(response.status_code))
|
| 189 |
+
else:
|
| 190 |
+
response = response.json()
|
| 191 |
+
|
| 192 |
+
# Entity Linking
|
| 193 |
+
# with stylable_container(
|
| 194 |
+
# key="container_with_border",
|
| 195 |
+
# css_styles="""
|
| 196 |
+
# {
|
| 197 |
+
# border: 1px solid rgba(49, 51, 63, 0.2);
|
| 198 |
+
# border-radius: 0.5rem;
|
| 199 |
+
# padding: 0.5rem;
|
| 200 |
+
# padding-bottom: 2rem;
|
| 201 |
+
# }
|
| 202 |
+
# """,
|
| 203 |
+
# ):
|
| 204 |
+
# st.markdown("##")
|
| 205 |
+
dict_of_ents, options = get_el_annotations(response=response)
|
| 206 |
+
display = displacy.render(
|
| 207 |
+
dict_of_ents, manual=True, style="ent", options=options
|
| 208 |
+
)
|
| 209 |
+
display = display.replace("\n", " ")
|
| 210 |
+
# wsd_display = re.sub(
|
| 211 |
+
# r"(wiki::\d+\w)",
|
| 212 |
+
# r"<a href='https://babelnet.org/synset?id=\g<1>&orig=\g<1>&lang={}'>\g<1></a>".format(
|
| 213 |
+
# language.upper()
|
| 214 |
+
# ),
|
| 215 |
+
# wsd_display,
|
| 216 |
+
# )
|
| 217 |
+
with st.container():
|
| 218 |
+
st.write(display, unsafe_allow_html=True)
|
| 219 |
+
|
| 220 |
+
st.markdown("####")
|
| 221 |
+
st.markdown("#### Relation Extraction")
|
| 222 |
+
|
| 223 |
+
with st.container():
|
| 224 |
+
st.write("Coming :)", unsafe_allow_html=True)
|
| 225 |
+
|
| 226 |
+
else:
|
| 227 |
+
st.error("Please enter some text.")
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
if __name__ == "__main__":
|
| 231 |
+
run_client()
|
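The frontend renders entity-linking output with spaCy's displacy in "manual" mode, so the server response only has to look like a displacy-compatible dict. A minimal, self-contained sketch of that rendering; the response dict below is hand-written for illustration (in the app it comes from get_el_annotations), and the entity titles and colors are placeholders:

from spacy import displacy

response = {
    "text": "Obama went to Rome for a quick vacation.",
    "ents": [
        {"start": 0, "end": 5, "label": "Barack Obama"},
        {"start": 14, "end": 18, "label": "Rome"},
    ],
    "title": None,
}
options = {
    "ents": {"Barack Obama", "Rome"},
    "colors": {"Barack Obama": "#802433", "Rome": "#802433"},
}

# manual=True tells displacy to use the dict directly instead of a spaCy Doc
html = displacy.render(response, manual=True, style="ent", options=options)
print(html[:80])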
relik/inference/serve/frontend/style.css
ADDED
@@ -0,0 +1,33 @@
/* Sidebar */
.eczjsme11 {
  background-color: #802433;
}

.st-emotion-cache-10oheav h2 {
  color: white;
}

.st-emotion-cache-10oheav li {
  color: white;
}

/* Main */
a:link {
  text-decoration: none;
  color: white;
}

a:visited {
  text-decoration: none;
  color: white;
}

a:hover {
  text-decoration: none;
  color: rgba(255, 255, 255, 0.871);
}

a:active {
  text-decoration: none;
  color: white;
}
relik/reader/__init__.py
ADDED
File without changes
relik/reader/conf/config.yaml
ADDED
@@ -0,0 +1,14 @@
# Required to make the "experiments" dir the default one for the output of the models
hydra:
  run:
    dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S}

model_name: relik-reader-deberta-base  # used to name the model in wandb and output dir
project_name: relik-reader  # used to name the project in wandb


defaults:
  - _self_
  - training: base
  - model: base
  - data: base
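This top-level config relies on Hydra's defaults list and ${...} interpolation. A minimal sketch of composing it programmatically without launching a run; it assumes hydra-core is installed and that the snippet is executed with the relative config_path below pointing at relik/reader/conf (both are assumptions, not part of the commit):

from hydra import compose, initialize

with initialize(config_path="conf", version_base=None):
    # overrides use the same dotted syntax as the training CLI
    cfg = compose(config_name="config", overrides=["model_name=my-run"])

print(cfg.model_name)    # my-run
print(cfg.project_name)  # relik-reader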
relik/reader/conf/data/base.yaml
ADDED
@@ -0,0 +1,21 @@
train_dataset_path: "relik/reader/data/train.jsonl"
val_dataset_path: "relik/reader/data/testa.jsonl"

train_dataset:
  _target_: "relik.reader.relik_reader_data.RelikDataset"
  transformer_model: "${model.model.transformer_model}"
  materialize_samples: False
  shuffle_candidates: 0.5
  random_drop_gold_candidates: 0.05
  noise_param: 0.0
  for_inference: False
  tokens_per_batch: 4096
  special_symbols: null

val_dataset:
  _target_: "relik.reader.relik_reader_data.RelikDataset"
  transformer_model: "${model.model.transformer_model}"
  materialize_samples: False
  shuffle_candidates: False
  for_inference: True
  special_symbols: null
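The _target_ keys above are consumed by hydra.utils.instantiate, which turns the node into a RelikDataset object. A sketch of that step under stated assumptions: special_symbols is null in the YAML and must be supplied at runtime, so the symbols, model name, and path below are placeholders:

import hydra
from omegaconf import OmegaConf

node = OmegaConf.create(
    {
        "_target_": "relik.reader.relik_reader_data.RelikDataset",
        "dataset_path": "relik/reader/data/testa.jsonl",  # placeholder path
        "materialize_samples": False,
        "transformer_model": "microsoft/deberta-v3-base",  # placeholder model
        "special_symbols": ["[E-0]", "[E-1]"],  # placeholder symbols
        "for_inference": True,
    }
)
val_dataset = hydra.utils.instantiate(node)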
relik/reader/conf/data/re.yaml
ADDED
@@ -0,0 +1,54 @@
train_dataset_path: "relik/reader/data/nyt-alby+/train.jsonl"
val_dataset_path: "relik/reader/data/nyt-alby+/valid.jsonl"
test_dataset_path: "relik/reader/data/nyt-alby+/test.jsonl"

relations_definitions:
  /people/person/nationality: "nationality"
  /sports/sports_team/location: "sports team location"
  /location/country/administrative_divisions: "administrative divisions"
  /business/company/major_shareholders: "shareholders"
  /people/ethnicity/people: "ethnicity"
  /people/ethnicity/geographic_distribution: "geographic distribution"
  /business/company_shareholder/major_shareholder_of: "major shareholder"
  /location/location/contains: "location"
  /business/company/founders: "founders"
  /business/person/company: "company"
  /business/company/advisors: "advisor"
  /people/deceased_person/place_of_death: "place of death"
  /business/company/industry: "industry"
  /people/person/ethnicity: "ethnic background"
  /people/person/place_of_birth: "place of birth"
  /location/administrative_division/country: "country of an administration division"
  /people/person/place_lived: "place lived"
  /sports/sports_team_location/teams: "sports team"
  /people/person/children: "child"
  /people/person/religion: "religion"
  /location/neighborhood/neighborhood_of: "neighborhood"
  /location/country/capital: "capital"
  /business/company/place_founded: "company founded location"
  /people/person/profession: "occupation"

train_dataset:
  _target_: "relik.reader.relik_reader_re_data.RelikREDataset"
  transformer_model: "${model.model.transformer_model}"
  materialize_samples: False
  shuffle_candidates: False
  flip_candidates: 1.0
  noise_param: 0.0
  for_inference: False
  tokens_per_batch: 4096
  min_length: -1
  special_symbols: null
  relations_definitions: ${data.relations_definitions}
  sorting_fields:
    - "predictable_candidates"

val_dataset:
  _target_: "relik.reader.relik_reader_re_data.RelikREDataset"
  transformer_model: "${model.model.transformer_model}"
  materialize_samples: False
  shuffle_candidates: False
  flip_candidates: False
  for_inference: True
  min_length: -1
  special_symbols: null
  relations_definitions: ${data.relations_definitions}
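The relations_definitions map verbalizes NYT relation identifiers so the reader can treat them as text candidates. A toy lookup showing the intended direction of the mapping; the triplet below is invented for illustration and is not from the dataset:

relations_definitions = {
    "/people/person/place_of_birth": "place of birth",
    "/location/country/capital": "capital",
}

# hypothetical predicted triplet: (subject, relation id, object)
subj, rel_id, obj = ("Barack Obama", "/people/person/place_of_birth", "Honolulu")
print(f"{subj} -- {relations_definitions[rel_id]} --> {obj}")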
relik/reader/conf/training/base.yaml
ADDED
@@ -0,0 +1,12 @@
seed: 94

trainer:
  _target_: lightning.Trainer
  devices:
    - 0
  precision: "16-mixed"
  max_steps: 50000
  val_check_interval: 1.0
  num_sanity_val_steps: 0
  limit_val_batches: 1
  gradient_clip_val: 1.0
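The trainer node is hydra-instantiable like the dataset nodes. A minimal sketch of what a training entry point effectively does with it, with field values copied from the YAML (assumes lightning 2.x and hydra-core are installed):

import hydra
from omegaconf import OmegaConf

trainer_cfg = OmegaConf.create(
    {
        "_target_": "lightning.Trainer",
        "devices": [0],
        "precision": "16-mixed",
        "max_steps": 50000,
        "val_check_interval": 1.0,
        "num_sanity_val_steps": 0,
        "limit_val_batches": 1,
        "gradient_clip_val": 1.0,
    }
)
trainer = hydra.utils.instantiate(trainer_cfg)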
relik/reader/conf/training/re.yaml
ADDED
@@ -0,0 +1,12 @@
seed: 15

trainer:
  _target_: lightning.Trainer
  devices:
    - 0
  precision: "16-mixed"
  max_steps: 100000
  val_check_interval: 1.0
  num_sanity_val_steps: 0
  limit_val_batches: 1
  gradient_clip_val: 1.0
relik/reader/data/__init__.py
ADDED
File without changes
relik/reader/data/patches.py
ADDED
@@ -0,0 +1,51 @@
from typing import List

from relik.reader.data.relik_reader_sample import RelikReaderSample
from relik.reader.utils.special_symbols import NME_SYMBOL


def merge_patches_predictions(sample) -> None:
    sample._d["predicted_window_labels"] = dict()
    predicted_window_labels = sample._d["predicted_window_labels"]

    sample._d["span_title_probabilities"] = dict()
    span_title_probabilities = sample._d["span_title_probabilities"]

    span2title = dict()
    for _, patch_info in sorted(sample.patches.items(), key=lambda x: x[0]):
        # selecting span predictions
        for predicted_title, predicted_spans in patch_info[
            "predicted_window_labels"
        ].items():
            for pred_span in predicted_spans:
                pred_span = tuple(pred_span)
                curr_title = span2title.get(pred_span)
                if curr_title is None or curr_title == NME_SYMBOL:
                    span2title[pred_span] = predicted_title
                # else:
                #     print("Merging at patch level")

        # selecting span predictions probability
        for predicted_span, titles_probabilities in patch_info[
            "span_title_probabilities"
        ].items():
            if predicted_span not in span_title_probabilities:
                span_title_probabilities[predicted_span] = titles_probabilities

    for span, title in span2title.items():
        if title not in predicted_window_labels:
            predicted_window_labels[title] = list()
        predicted_window_labels[title].append(span)


def remove_duplicate_samples(
    samples: List[RelikReaderSample],
) -> List[RelikReaderSample]:
    seen_sample = set()
    samples_store = []
    for sample in samples:
        sample_id = f"{sample.doc_id}#{sample.sent_id}#{sample.offset}"
        if sample_id not in seen_sample:
            seen_sample.add(sample_id)
            samples_store.append(sample)
    return samples_store
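A toy run of merge_patches_predictions under stated assumptions: the sample fields below are hand-built, and the behavior shown is that when two patches predict the same span, an NME prediction yields to the first concrete title:

from relik.reader.data.patches import merge_patches_predictions
from relik.reader.data.relik_reader_sample import RelikReaderSample
from relik.reader.utils.special_symbols import NME_SYMBOL

sample = RelikReaderSample(
    patches={
        0: {
            "predicted_window_labels": {NME_SYMBOL: [(0, 5)]},
            "span_title_probabilities": {},
        },
        1: {
            "predicted_window_labels": {"Barack Obama": [(0, 5)]},
            "span_title_probabilities": {},
        },
    }
)
merge_patches_predictions(sample)
print(sample.predicted_window_labels)  # {'Barack Obama': [(0, 5)]}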
relik/reader/data/relik_reader_data.py
ADDED
@@ -0,0 +1,965 @@
import logging
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Iterable,
    Iterator,
    List,
    NamedTuple,
    Optional,
    Tuple,
    Union,
)

import numpy as np
import torch
from torch.utils.data import IterableDataset
from tqdm import tqdm
from transformers import AutoTokenizer, PreTrainedTokenizer

from relik.reader.data.relik_reader_data_utils import (
    add_noise_to_value,
    batchify,
    chunks,
    flatten,
)
from relik.reader.data.relik_reader_sample import (
    RelikReaderSample,
    load_relik_reader_samples,
)
from relik.reader.utils.special_symbols import NME_SYMBOL

logger = logging.getLogger(__name__)


def preprocess_dataset(
    input_dataset: Iterable[dict],
    transformer_model: str,
    add_topic: bool,
) -> Iterable[dict]:
    tokenizer = AutoTokenizer.from_pretrained(transformer_model)
    for dataset_elem in tqdm(input_dataset, desc="Preprocessing input dataset"):
        if len(dataset_elem["tokens"]) == 0:
            print(
                f"Dataset element with doc id: {dataset_elem['doc_id']}",
                f"and offset {dataset_elem['offset']} does not contain any token",
                "Skipping it",
            )
            continue

        new_dataset_elem = dict(
            doc_id=dataset_elem["doc_id"],
            offset=dataset_elem["offset"],
        )

        tokenization_out = tokenizer(
            dataset_elem["tokens"],
            return_offsets_mapping=True,
            add_special_tokens=False,
        )

        window_tokens = tokenization_out.input_ids
        window_tokens = flatten(window_tokens)

        offsets_mapping = [
            [
                (
                    ss + dataset_elem["token2char_start"][str(i)],
                    se + dataset_elem["token2char_start"][str(i)],
                )
                for ss, se in tokenization_out.offset_mapping[i]
            ]
            for i in range(len(dataset_elem["tokens"]))
        ]

        offsets_mapping = flatten(offsets_mapping)

        assert len(offsets_mapping) == len(window_tokens)

        window_tokens = (
            [tokenizer.cls_token_id] + window_tokens + [tokenizer.sep_token_id]
        )

        topic_offset = 0
        if add_topic:
            topic_tokens = tokenizer(
                dataset_elem["doc_topic"], add_special_tokens=False
            ).input_ids
            topic_offset = len(topic_tokens)
            new_dataset_elem["topic_tokens"] = topic_offset
            window_tokens = window_tokens[:1] + topic_tokens + window_tokens[1:]

        new_dataset_elem.update(
            dict(
                tokens=window_tokens,
                token2char_start={
                    str(i): s
                    for i, (s, _) in enumerate(offsets_mapping, start=topic_offset)
                },
                token2char_end={
                    str(i): e
                    for i, (_, e) in enumerate(offsets_mapping, start=topic_offset)
                },
                window_candidates=dataset_elem["window_candidates"],
                window_candidates_scores=dataset_elem.get("window_candidates_scores"),
            )
        )

        if "window_labels" in dataset_elem:
            window_labels = [
                (s, e, l.replace("_", " ")) for s, e, l in dataset_elem["window_labels"]
            ]

            new_dataset_elem["window_labels"] = window_labels

            if not all(
                [
                    s in new_dataset_elem["token2char_start"].values()
                    for s, _, _ in new_dataset_elem["window_labels"]
                ]
            ):
                print(
                    "Mismatching token start char mapping with labels",
                    new_dataset_elem["token2char_start"],
                    new_dataset_elem["window_labels"],
                    dataset_elem["tokens"],
                )
                continue

            if not all(
                [
                    e in new_dataset_elem["token2char_end"].values()
                    for _, e, _ in new_dataset_elem["window_labels"]
                ]
            ):
                print(
                    "Mismatching token end char mapping with labels",
                    new_dataset_elem["token2char_end"],
                    new_dataset_elem["window_labels"],
                    dataset_elem["tokens"],
                )
                continue

        yield new_dataset_elem


def preprocess_sample(
    relik_sample: RelikReaderSample,
    tokenizer,
    lowercase_policy: float,
    add_topic: bool = False,
) -> None:
    if len(relik_sample.tokens) == 0:
        return

    if lowercase_policy > 0:
        lc_tokens = np.random.uniform(0, 1, len(relik_sample.tokens)) < lowercase_policy
        relik_sample.tokens = [
            t.lower() if lc else t for t, lc in zip(relik_sample.tokens, lc_tokens)
        ]

    tokenization_out = tokenizer(
        relik_sample.tokens,
        return_offsets_mapping=True,
        add_special_tokens=False,
    )

    window_tokens = tokenization_out.input_ids
    window_tokens = flatten(window_tokens)

    offsets_mapping = [
        [
            (
                ss + relik_sample.token2char_start[str(i)],
                se + relik_sample.token2char_start[str(i)],
            )
            for ss, se in tokenization_out.offset_mapping[i]
        ]
        for i in range(len(relik_sample.tokens))
    ]

    offsets_mapping = flatten(offsets_mapping)

    assert len(offsets_mapping) == len(window_tokens)

    window_tokens = [tokenizer.cls_token_id] + window_tokens + [tokenizer.sep_token_id]

    topic_offset = 0
    if add_topic:
        topic_tokens = tokenizer(
            relik_sample.doc_topic, add_special_tokens=False
        ).input_ids
        topic_offset = len(topic_tokens)
        relik_sample.topic_tokens = topic_offset
        window_tokens = window_tokens[:1] + topic_tokens + window_tokens[1:]

    relik_sample._d.update(
        dict(
            tokens=window_tokens,
            token2char_start={
                str(i): s
                for i, (s, _) in enumerate(offsets_mapping, start=topic_offset)
            },
            token2char_end={
                str(i): e
                for i, (_, e) in enumerate(offsets_mapping, start=topic_offset)
            },
        )
    )

    if "window_labels" in relik_sample._d:
        relik_sample.window_labels = [
            (s, e, l.replace("_", " ")) for s, e, l in relik_sample.window_labels
        ]


class TokenizationOutput(NamedTuple):
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    token_type_ids: torch.Tensor
    prediction_mask: torch.Tensor
    special_symbols_mask: torch.Tensor


class RelikDataset(IterableDataset):
    def __init__(
        self,
        dataset_path: Optional[str],
        materialize_samples: bool,
        transformer_model: Union[str, PreTrainedTokenizer],
        special_symbols: List[str],
        shuffle_candidates: Optional[Union[bool, float]] = False,
        for_inference: bool = False,
        noise_param: float = 0.1,
        sorting_fields: Optional[List[str]] = None,
        tokens_per_batch: int = 2048,
        batch_size: Optional[int] = None,
        max_batch_size: int = 128,
        section_size: int = 50_000,
        prebatch: bool = True,
        random_drop_gold_candidates: float = 0.0,
        use_nme: bool = True,
        max_subwords_per_candidate: int = 22,
        mask_by_instances: bool = False,
        min_length: int = 5,
        max_length: int = 2048,
        model_max_length: int = 1000,
        split_on_cand_overload: bool = True,
        skip_empty_training_samples: bool = False,
        drop_last: bool = False,
        samples: Optional[Iterator[RelikReaderSample]] = None,
        lowercase_policy: float = 0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dataset_path = dataset_path
        self.materialize_samples = materialize_samples
        self.samples: Optional[List[RelikReaderSample]] = None
        if self.materialize_samples:
            self.samples = list()

        if isinstance(transformer_model, str):
            self.tokenizer = self._build_tokenizer(transformer_model, special_symbols)
        else:
            self.tokenizer = transformer_model
        self.special_symbols = special_symbols
        self.shuffle_candidates = shuffle_candidates
        self.for_inference = for_inference
        self.noise_param = noise_param
        self.batching_fields = ["input_ids"]
        self.sorting_fields = (
            sorting_fields if sorting_fields is not None else self.batching_fields
        )

        self.tokens_per_batch = tokens_per_batch
        self.batch_size = batch_size
        self.max_batch_size = max_batch_size
        self.section_size = section_size
        self.prebatch = prebatch

        self.random_drop_gold_candidates = random_drop_gold_candidates
        self.use_nme = use_nme
        self.max_subwords_per_candidate = max_subwords_per_candidate
        self.mask_by_instances = mask_by_instances
        self.min_length = min_length
        self.max_length = max_length
        self.model_max_length = (
            model_max_length
            if model_max_length < self.tokenizer.model_max_length
            else self.tokenizer.model_max_length
        )

        # retrocompatibility workaround
        self.transformer_model = (
            transformer_model
            if isinstance(transformer_model, str)
            else transformer_model.name_or_path
        )
        self.split_on_cand_overload = split_on_cand_overload
        self.skip_empty_training_samples = skip_empty_training_samples
        self.drop_last = drop_last
        self.lowercase_policy = lowercase_policy
        self.samples = samples

    def _build_tokenizer(self, transformer_model: str, special_symbols: List[str]):
        return AutoTokenizer.from_pretrained(
            transformer_model,
            additional_special_tokens=[ss for ss in special_symbols],
            add_prefix_space=True,
        )

    @property
    def fields_batcher(self) -> Dict[str, Union[None, Callable[[list], Any]]]:
        fields_batchers = {
            "input_ids": lambda x: batchify(
                x, padding_value=self.tokenizer.pad_token_id
            ),
            "attention_mask": lambda x: batchify(x, padding_value=0),
            "token_type_ids": lambda x: batchify(x, padding_value=0),
            "prediction_mask": lambda x: batchify(x, padding_value=1),
            "global_attention": lambda x: batchify(x, padding_value=0),
            "token2word": None,
            "sample": None,
            "special_symbols_mask": lambda x: batchify(x, padding_value=False),
            "start_labels": lambda x: batchify(x, padding_value=-100),
            "end_labels": lambda x: batchify(x, padding_value=-100),
            "predictable_candidates_symbols": None,
            "predictable_candidates": None,
            "patch_offset": None,
            "optimus_labels": None,
        }

        if "roberta" in self.transformer_model:
            del fields_batchers["token_type_ids"]

        return fields_batchers

    def _build_input_ids(
        self, sentence_input_ids: List[int], candidates_input_ids: List[List[int]]
    ) -> List[int]:
        return (
            [self.tokenizer.cls_token_id]
            + sentence_input_ids
            + [self.tokenizer.sep_token_id]
            + flatten(candidates_input_ids)
            + [self.tokenizer.sep_token_id]
        )

    def _get_special_symbols_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
        special_symbols_mask = input_ids >= (
            len(self.tokenizer) - len(self.special_symbols)
        )
        special_symbols_mask[0] = True
        return special_symbols_mask

    def _build_tokenizer_essentials(
        self, input_ids, original_sequence, sample
    ) -> TokenizationOutput:
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        attention_mask = torch.ones_like(input_ids)

        total_sequence_len = len(input_ids)
        predictable_sentence_len = len(original_sequence)

        # token type ids
        token_type_ids = torch.cat(
            [
                input_ids.new_zeros(
                    predictable_sentence_len + 2
                ),  # original sentence bpes + CLS and SEP
                input_ids.new_ones(total_sequence_len - predictable_sentence_len - 2),
            ]
        )

        # prediction mask -> boolean on tokens that are predictable

        prediction_mask = torch.tensor(
            [1]
            + ([0] * predictable_sentence_len)
            + ([1] * (total_sequence_len - predictable_sentence_len - 1))
        )

        # add topic tokens to the prediction mask so that they cannot be predicted
        # or optimized during training
        topic_tokens = getattr(sample, "topic_tokens", None)
        if topic_tokens is not None:
            prediction_mask[1 : 1 + topic_tokens] = 1

        # If mask by instances is active the prediction mask is applied to everything
        # that is not indicated as an instance in the training set.
        if self.mask_by_instances:
            char_start2token = {
                cs: int(tok) for tok, cs in sample.token2char_start.items()
            }
            char_end2token = {ce: int(tok) for tok, ce in sample.token2char_end.items()}
            instances_mask = torch.ones_like(prediction_mask)
            for _, span_info in sample.instance_id2span_data.items():
                span_info = span_info[0]
                token_start = char_start2token[span_info[0]] + 1  # +1 for the CLS
                token_end = char_end2token[span_info[1]] + 1  # +1 for the CLS
                instances_mask[token_start : token_end + 1] = 0

            prediction_mask += instances_mask
            prediction_mask[prediction_mask > 1] = 1

        assert len(prediction_mask) == len(input_ids)

        # special symbols mask
        special_symbols_mask = self._get_special_symbols_mask(input_ids)

        return TokenizationOutput(
            input_ids,
            attention_mask,
            token_type_ids,
            prediction_mask,
            special_symbols_mask,
        )

    def _build_labels(
        self,
        sample,
        tokenization_output: TokenizationOutput,
        predictable_candidates: List[str],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        start_labels = [0] * len(tokenization_output.input_ids)
        end_labels = [0] * len(tokenization_output.input_ids)

        char_start2token = {v: int(k) for k, v in sample.token2char_start.items()}
        char_end2token = {v: int(k) for k, v in sample.token2char_end.items()}
        for cs, ce, gold_candidate_title in sample.window_labels:
            if gold_candidate_title not in predictable_candidates:
                if self.use_nme:
                    gold_candidate_title = NME_SYMBOL
                else:
                    continue
            # +1 is to account for the CLS token
            start_bpe = char_start2token[cs] + 1
            end_bpe = char_end2token[ce] + 1
            class_index = predictable_candidates.index(gold_candidate_title)
            if (
                start_labels[start_bpe] == 0 and end_labels[end_bpe] == 0
            ):  # prevent entities that end on the same subword
                start_labels[start_bpe] = class_index + 1  # +1 for the NONE class
                end_labels[end_bpe] = class_index + 1  # +1 for the NONE class
            else:
                print(
                    "Found entity with the same last subword, it will not be included."
                )
                print(
                    cs,
                    ce,
                    gold_candidate_title,
                    start_labels,
                    end_labels,
                    sample.doc_id,
                )

        ignored_labels_indices = tokenization_output.prediction_mask == 1

        start_labels = torch.tensor(start_labels, dtype=torch.long)
        start_labels[ignored_labels_indices] = -100

        end_labels = torch.tensor(end_labels, dtype=torch.long)
        end_labels[ignored_labels_indices] = -100

        return start_labels, end_labels

    def produce_sample_bag(
        self, sample, predictable_candidates: List[str], candidates_starting_offset: int
    ) -> Optional[Tuple[dict, list, int]]:
        # input sentence tokenization
        input_subwords = sample.tokens[1:-1]  # removing special tokens
        candidates_symbols = self.special_symbols[candidates_starting_offset:]

        predictable_candidates = list(predictable_candidates)
        original_predictable_candidates = list(predictable_candidates)

        # add NME as a possible candidate
        if self.use_nme:
            predictable_candidates.insert(0, NME_SYMBOL)

        # candidates encoding
        candidates_symbols = candidates_symbols[: len(predictable_candidates)]
        candidates_encoding_result = self.tokenizer.batch_encode_plus(
            [
                "{} {}".format(cs, ct) if ct != NME_SYMBOL else NME_SYMBOL
                for cs, ct in zip(candidates_symbols, predictable_candidates)
            ],
            add_special_tokens=False,
        ).input_ids

        if (
            self.max_subwords_per_candidate is not None
            and self.max_subwords_per_candidate > 0
        ):
            candidates_encoding_result = [
                cer[: self.max_subwords_per_candidate]
                for cer in candidates_encoding_result
            ]

        # drop candidates if the number of input tokens is too long for the model
        if (
            sum(map(len, candidates_encoding_result))
            + len(input_subwords)
            + 20  # + 20 special tokens
            > self.model_max_length
        ):
            acceptable_tokens_from_candidates = (
                self.model_max_length - 20 - len(input_subwords)
            )
            i = 0
            cum_len = 0
            while (
                cum_len + len(candidates_encoding_result[i])
                < acceptable_tokens_from_candidates
            ):
                cum_len += len(candidates_encoding_result[i])
                i += 1

            candidates_encoding_result = candidates_encoding_result[:i]
            candidates_symbols = candidates_symbols[:i]
            predictable_candidates = predictable_candidates[:i]

        # final input_ids build
        input_ids = self._build_input_ids(
            sentence_input_ids=input_subwords,
            candidates_input_ids=candidates_encoding_result,
        )

        # complete input building (e.g. attention / prediction mask)
        tokenization_output = self._build_tokenizer_essentials(
            input_ids, input_subwords, sample
        )

        output_dict = {
            "input_ids": tokenization_output.input_ids,
            "attention_mask": tokenization_output.attention_mask,
            "token_type_ids": tokenization_output.token_type_ids,
            "prediction_mask": tokenization_output.prediction_mask,
            "special_symbols_mask": tokenization_output.special_symbols_mask,
            "sample": sample,
            "predictable_candidates_symbols": candidates_symbols,
            "predictable_candidates": predictable_candidates,
        }

        # labels creation
        if sample.window_labels is not None:
            start_labels, end_labels = self._build_labels(
                sample,
                tokenization_output,
                predictable_candidates,
            )
            output_dict.update(start_labels=start_labels, end_labels=end_labels)

        if (
            "roberta" in self.transformer_model
            or "longformer" in self.transformer_model
        ):
            del output_dict["token_type_ids"]

        predictable_candidates_set = set(predictable_candidates)
        remaining_candidates = [
            candidate
            for candidate in original_predictable_candidates
            if candidate not in predictable_candidates_set
        ]
        total_used_candidates = (
            candidates_starting_offset
            + len(predictable_candidates)
            - (1 if self.use_nme else 0)
        )

        if self.use_nme:
            assert predictable_candidates[0] == NME_SYMBOL

        return output_dict, remaining_candidates, total_used_candidates

    def __iter__(self):
        dataset_iterator = self.dataset_iterator_func()

        current_dataset_elements = []

        i = None
        for i, dataset_elem in enumerate(dataset_iterator, start=1):
            if (
                self.section_size is not None
                and len(current_dataset_elements) == self.section_size
            ):
                for batch in self.materialize_batches(current_dataset_elements):
                    yield batch
                current_dataset_elements = []

            current_dataset_elements.append(dataset_elem)

            if i % 50_000 == 0:
                logger.info(f"Processed {i} elements")

        if len(current_dataset_elements) != 0:
            for batch in self.materialize_batches(current_dataset_elements):
                yield batch

        if i is not None:
            logger.info(f"Dataset finished: {i} elements processed")
        else:
            logger.warning("Dataset empty")

    def dataset_iterator_func(self):
        skipped_instances = 0
        data_samples = (
            load_relik_reader_samples(self.dataset_path)
            if self.samples is None
            else self.samples
        )
        for sample in data_samples:
            preprocess_sample(
                sample, self.tokenizer, lowercase_policy=self.lowercase_policy
            )
            current_patch = 0
            sample_bag, used_candidates = None, None
            remaining_candidates = list(sample.window_candidates)

            if not self.for_inference:
                # randomly drop gold candidates at training time
                if (
                    self.random_drop_gold_candidates > 0.0
                    and np.random.uniform() < self.random_drop_gold_candidates
                    and len(set(ct for _, _, ct in sample.window_labels)) > 1
                ):
                    # selecting candidates to drop
                    np.random.shuffle(sample.window_labels)
                    n_dropped_candidates = np.random.randint(
                        0, len(sample.window_labels) - 1
                    )
                    dropped_candidates = [
                        label_elem[-1]
                        for label_elem in sample.window_labels[:n_dropped_candidates]
                    ]
                    dropped_candidates = set(dropped_candidates)

                    # saving NMEs because they should not be dropped
                    if NME_SYMBOL in dropped_candidates:
                        dropped_candidates.remove(NME_SYMBOL)

                    # sample update
                    sample.window_labels = [
                        (s, e, _l)
                        if _l not in dropped_candidates
                        else (s, e, NME_SYMBOL)
                        for s, e, _l in sample.window_labels
                    ]
                    remaining_candidates = [
                        wc
                        for wc in remaining_candidates
                        if wc not in dropped_candidates
                    ]

                # shuffle candidates
                if (
                    isinstance(self.shuffle_candidates, bool)
                    and self.shuffle_candidates
                ) or (
                    isinstance(self.shuffle_candidates, float)
                    and np.random.uniform() < self.shuffle_candidates
                ):
                    np.random.shuffle(remaining_candidates)

            while len(remaining_candidates) != 0:
                sample_bag = self.produce_sample_bag(
                    sample,
                    predictable_candidates=remaining_candidates,
                    candidates_starting_offset=used_candidates
                    if used_candidates is not None
                    else 0,
                )
                if sample_bag is not None:
                    sample_bag, remaining_candidates, used_candidates = sample_bag
                    if (
                        self.for_inference
                        or not self.skip_empty_training_samples
                        or (
                            (
                                sample_bag.get("start_labels") is not None
                                and torch.any(sample_bag["start_labels"] > 1).item()
                            )
                            or (
                                sample_bag.get("optimus_labels") is not None
                                and len(sample_bag["optimus_labels"]) > 0
                            )
                        )
                    ):
                        sample_bag["patch_offset"] = current_patch
                        current_patch += 1
                        yield sample_bag
                    else:
                        skipped_instances += 1
                        if skipped_instances % 1000 == 0 and skipped_instances != 0:
                            logger.info(
                                f"Skipped {skipped_instances} instances since they did not have any gold labels..."
                            )

                # Just use the first fitting candidates if split on
                # cand is not True
                if not self.split_on_cand_overload:
                    break

    def preshuffle_elements(self, dataset_elements: List):
        # This shuffling is done so that when using the sorting function,
        # if it is deterministic given a collection and its order, we will
        # make the whole operation not deterministic anymore.
        # Basically, the aim is not to build every time the same batches.
        if not self.for_inference:
            dataset_elements = np.random.permutation(dataset_elements)

        sorting_fn = (
            lambda elem: add_noise_to_value(
                sum(len(elem[k]) for k in self.sorting_fields),
                noise_param=self.noise_param,
            )
            if not self.for_inference
            else sum(len(elem[k]) for k in self.sorting_fields)
        )

        dataset_elements = sorted(dataset_elements, key=sorting_fn)

        if self.for_inference:
            return dataset_elements

        ds = list(chunks(dataset_elements, 64))
        np.random.shuffle(ds)
        return flatten(ds)

    def materialize_batches(
        self, dataset_elements: List[Dict[str, Any]]
    ) -> Generator[Dict[str, Any], None, None]:
        if self.prebatch:
            dataset_elements = self.preshuffle_elements(dataset_elements)

        current_batch = []

        # function that creates a batch from the 'current_batch' list
        def output_batch() -> Dict[str, Any]:
            assert (
                len(
                    set([len(elem["predictable_candidates"]) for elem in current_batch])
                )
                == 1
            ), " ".join(
                map(
                    str, [len(elem["predictable_candidates"]) for elem in current_batch]
                )
            )

            batch_dict = dict()

            de_values_by_field = {
                fn: [de[fn] for de in current_batch if fn in de]
                for fn in self.fields_batcher
            }

            # in case you provide fields batchers but in the batch
            # there are no elements for that field
            de_values_by_field = {
                fn: fvs for fn, fvs in de_values_by_field.items() if len(fvs) > 0
            }

            assert len(set([len(v) for v in de_values_by_field.values()]))

            # todo: maybe we should inform the user about possible
            # fields filtering due to "None" instances
            de_values_by_field = {
                fn: fvs
                for fn, fvs in de_values_by_field.items()
                if all([fv is not None for fv in fvs])
            }

            for field_name, field_values in de_values_by_field.items():
                field_batch = (
                    self.fields_batcher[field_name](field_values)
                    if self.fields_batcher[field_name] is not None
                    else field_values
                )

                batch_dict[field_name] = field_batch

            return batch_dict

        max_len_discards, min_len_discards = 0, 0

        should_token_batch = self.batch_size is None

        curr_pred_elements = -1
        for de in dataset_elements:
            if (
                should_token_batch
                and self.max_batch_size != -1
                and len(current_batch) == self.max_batch_size
            ) or (not should_token_batch and len(current_batch) == self.batch_size):
                yield output_batch()
                current_batch = []
                curr_pred_elements = -1

            too_long_fields = [
                k
                for k in de
                if self.max_length != -1
                and torch.is_tensor(de[k])
                and len(de[k]) > self.max_length
            ]
            if len(too_long_fields) > 0:
                max_len_discards += 1
                continue

            too_short_fields = [
                k
                for k in de
                if self.min_length != -1
                and torch.is_tensor(de[k])
                and len(de[k]) < self.min_length
            ]
            if len(too_short_fields) > 0:
                min_len_discards += 1
                continue

            if should_token_batch:
                de_len = sum(len(de[k]) for k in self.batching_fields)

                future_max_len = max(
                    de_len,
                    max(
                        [
                            sum(len(bde[k]) for k in self.batching_fields)
                            for bde in current_batch
                        ],
                        default=0,
                    ),
                )

                future_tokens_per_batch = future_max_len * (len(current_batch) + 1)

                num_predictable_candidates = len(de["predictable_candidates"])

                if len(current_batch) > 0 and (
                    future_tokens_per_batch >= self.tokens_per_batch
                    or (
                        num_predictable_candidates != curr_pred_elements
                        and curr_pred_elements != -1
                    )
                ):
                    yield output_batch()
                    current_batch = []

            current_batch.append(de)
            curr_pred_elements = len(de["predictable_candidates"])

        if len(current_batch) != 0 and not self.drop_last:
            yield output_batch()

        if max_len_discards > 0:
            if self.for_inference:
                logger.warning(
                    f"WARNING: Inference mode is True but {max_len_discards} samples longer than max length were "
                    f"found. The {max_len_discards} samples will be DISCARDED. If you are doing some kind of evaluation"
                    f", this can INVALIDATE results. This might happen if the max length was not set to -1 or if the "
                    f"sample length exceeds the maximum length supported by the current model."
                )
            else:
                logger.warning(
                    f"During iteration, {max_len_discards} elements were "
                    f"discarded since longer than max length {self.max_length}"
                )

        if min_len_discards > 0:
            if self.for_inference:
                logger.warning(
                    f"WARNING: Inference mode is True but {min_len_discards} samples shorter than min length were "
                    f"found. The {min_len_discards} samples will be DISCARDED. If you are doing some kind of evaluation"
                    f", this can INVALIDATE results. This might happen if the min length was not set to -1 or if the "
                    f"sample length is shorter than the minimum length supported by the current model."
                )
            else:
                logger.warning(
                    f"During iteration, {min_len_discards} elements were "
                    f"discarded since shorter than min length {self.min_length}"
                )

    @staticmethod
    def convert_tokens_to_char_annotations(
        sample: RelikReaderSample,
        remove_nmes: bool = True,
    ) -> RelikReaderSample:
        """
        Converts the token annotations to char annotations.

        Args:
            sample (:obj:`RelikReaderSample`):
                The sample to convert.
            remove_nmes (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether to remove the NMEs from the annotations.
        Returns:
            :obj:`RelikReaderSample`: The converted sample.
        """
        char_annotations = set()
        for (
            predicted_entity,
            predicted_spans,
        ) in sample.predicted_window_labels.items():
            if predicted_entity == NME_SYMBOL and remove_nmes:
                continue

            for span_start, span_end in predicted_spans:
                span_start = sample.token2char_start[str(span_start)]
                span_end = sample.token2char_end[str(span_end)]

                char_annotations.add((span_start, span_end, predicted_entity))

        char_probs_annotations = dict()
        for (
            span_start,
            span_end,
        ), candidates_probs in sample.span_title_probabilities.items():
            span_start = sample.token2char_start[str(span_start)]
            span_end = sample.token2char_end[str(span_end)]
            char_probs_annotations[(span_start, span_end)] = {
                title for title, _ in candidates_probs
            }

        sample.predicted_window_labels_chars = char_annotations
        sample.probs_window_labels_chars = char_probs_annotations

        return sample

    @staticmethod
    def merge_patches_predictions(sample) -> None:
        sample._d["predicted_window_labels"] = dict()
        predicted_window_labels = sample._d["predicted_window_labels"]

        sample._d["span_title_probabilities"] = dict()
        span_title_probabilities = sample._d["span_title_probabilities"]

        span2title = dict()
        for _, patch_info in sorted(sample.patches.items(), key=lambda x: x[0]):
            # selecting span predictions
            for predicted_title, predicted_spans in patch_info[
                "predicted_window_labels"
            ].items():
                for pred_span in predicted_spans:
                    pred_span = tuple(pred_span)
                    curr_title = span2title.get(pred_span)
                    if curr_title is None or curr_title == NME_SYMBOL:
                        span2title[pred_span] = predicted_title
                    # else:
                    #     print("Merging at patch level")

            # selecting span predictions probability
            for predicted_span, titles_probabilities in patch_info[
                "span_title_probabilities"
            ].items():
                if predicted_span not in span_title_probabilities:
                    span_title_probabilities[predicted_span] = titles_probabilities

        for span, title in span2title.items():
            if title not in predicted_window_labels:
                predicted_window_labels[title] = list()
            predicted_window_labels[title].append(span)
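A sketch of driving the dataset above. The path, model name, and special symbols are placeholders, not the project's real values; since RelikDataset already yields fully collated batches, the DataLoader is used with batch_size=None so it does not re-collate:

from torch.utils.data import DataLoader

from relik.reader.data.relik_reader_data import RelikDataset

dataset = RelikDataset(
    dataset_path="relik/reader/data/testa.jsonl",  # placeholder path
    materialize_samples=False,
    transformer_model="microsoft/deberta-v3-base",  # placeholder model
    special_symbols=[f"[E-{i}]" for i in range(10)],  # placeholder symbols
    for_inference=True,
)

for batch in DataLoader(dataset, batch_size=None, num_workers=0):
    print(batch["input_ids"].shape, len(batch["predictable_candidates"]))
    break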
relik/reader/data/relik_reader_data_utils.py
ADDED
@@ -0,0 +1,51 @@
from typing import List

import numpy as np
import torch


def flatten(lsts: List[list]) -> list:
    acc_lst = list()
    for lst in lsts:
        acc_lst.extend(lst)
    return acc_lst


def batchify(tensors: List[torch.Tensor], padding_value: int = 0) -> torch.Tensor:
    return torch.nn.utils.rnn.pad_sequence(
        tensors, batch_first=True, padding_value=padding_value
    )


def batchify_matrices(tensors: List[torch.Tensor], padding_value: int) -> torch.Tensor:
    x = max([t.shape[0] for t in tensors])
    y = max([t.shape[1] for t in tensors])
    out_matrix = torch.zeros((len(tensors), x, y))
    out_matrix += padding_value
    for i, tensor in enumerate(tensors):
        out_matrix[i][0 : tensor.shape[0], 0 : tensor.shape[1]] = tensor
    return out_matrix


def batchify_tensor(tensors: List[torch.Tensor], padding_value: int) -> torch.Tensor:
    x = max([t.shape[0] for t in tensors])
    y = max([t.shape[1] for t in tensors])
    rest = tensors[0].shape[2]
    out_matrix = torch.zeros((len(tensors), x, y, rest))
    out_matrix += padding_value
    for i, tensor in enumerate(tensors):
        out_matrix[i][0 : tensor.shape[0], 0 : tensor.shape[1], :] = tensor
    return out_matrix


def chunks(lst: list, chunk_size: int) -> List[list]:
    chunks_acc = list()
    for i in range(0, len(lst), chunk_size):
        chunks_acc.append(lst[i : i + chunk_size])
    return chunks_acc


def add_noise_to_value(value: int, noise_param: float):
    noise_value = value * noise_param
    noise = np.random.uniform(-noise_value, noise_value)
    return max(1, value + noise)
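A quick demonstration of the padding and chunking helpers above, which the dataset uses to build rectangular batches out of variable-length sequences:

import torch

from relik.reader.data.relik_reader_data_utils import batchify, chunks

padded = batchify([torch.tensor([1, 2, 3]), torch.tensor([4, 5])], padding_value=0)
print(padded)  # tensor([[1, 2, 3], [4, 5, 0]])

print(chunks([1, 2, 3, 4, 5], 2))  # [[1, 2], [3, 4], [5]]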
relik/reader/data/relik_reader_sample.py
ADDED
@@ -0,0 +1,49 @@
+import json
+from typing import Iterable
+
+
+class RelikReaderSample:
+    def __init__(self, **kwargs):
+        super().__setattr__("_d", {})
+        self._d = kwargs
+
+    def __getattribute__(self, item):
+        return super(RelikReaderSample, self).__getattribute__(item)
+
+    def __getattr__(self, item):
+        if item.startswith("__") and item.endswith("__"):
+            # this is likely some python library-specific variable (such as __deepcopy__ for copy)
+            # better follow standard behavior here
+            raise AttributeError(item)
+        elif item in self._d:
+            return self._d[item]
+        else:
+            return None
+
+    def __setattr__(self, key, value):
+        if key in self._d:
+            self._d[key] = value
+        else:
+            super().__setattr__(key, value)
+
+    def to_jsons(self) -> str:
+        if "predicted_window_labels" in self._d:
+            new_obj = {
+                k: v
+                for k, v in self._d.items()
+                if k != "predicted_window_labels" and k != "span_title_probabilities"
+            }
+            new_obj["predicted_window_labels"] = [
+                [ss, se, pred_title]
+                for (ss, se), pred_title in self.predicted_window_labels_chars
+            ]
+            return json.dumps(new_obj)  # fix: this branch originally fell through without returning
+        return json.dumps(self._d)
+
+
+def load_relik_reader_samples(path: str) -> Iterable[RelikReaderSample]:
+    with open(path) as f:
+        for line in f:
+            jsonl_line = json.loads(line.strip())
+            relik_reader_sample = RelikReaderSample(**jsonl_line)
+            yield relik_reader_sample
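RelikReaderSample proxies attribute access to an internal dict, so missing fields read as None rather than raising. A small usage sketch (the field names are illustrative, not part of the commit):

from relik.reader.data.relik_reader_sample import RelikReaderSample

sample = RelikReaderSample(doc_id=1, text="Rome is in Italy.")
print(sample.text)        # "Rome is in Italy."
print(sample.candidates)  # None -- unknown keys resolve to None instead of AttributeError
sample.doc_id = 2         # keys already in the sample are updated in place
print(sample.to_jsons())  # {"doc_id": 2, "text": "Rome is in Italy."}

The None-by-default lookup lets downstream code probe optional annotations (gold labels, predictions) without existence checks.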
relik/reader/lightning_modules/__init__.py
ADDED
File without changes
relik/reader/lightning_modules/relik_reader_pl_module.py
ADDED
@@ -0,0 +1,50 @@
+from typing import Any, Optional
+
+import lightning
+from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler
+
+from relik.reader.relik_reader_core import RelikReaderCoreModel
+
+
+class RelikReaderPLModule(lightning.LightningModule):
+    def __init__(
+        self,
+        cfg: dict,
+        transformer_model: str,
+        additional_special_symbols: int,
+        num_layers: Optional[int] = None,
+        activation: str = "gelu",
+        linears_hidden_size: Optional[int] = 512,
+        use_last_k_layers: int = 1,
+        training: bool = False,
+        *args: Any,
+        **kwargs: Any
+    ):
+        super().__init__(*args, **kwargs)
+        self.save_hyperparameters()
+        self.relik_reader_core_model = RelikReaderCoreModel(
+            transformer_model,
+            additional_special_symbols,
+            num_layers,
+            activation,
+            linears_hidden_size,
+            use_last_k_layers,
+            training=training,
+        )
+        self.optimizer_factory = None
+
+    def training_step(self, batch: dict, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+        relik_output = self.relik_reader_core_model(**batch)
+        self.log("train-loss", relik_output["loss"])
+        return relik_output["loss"]
+
+    def validation_step(
+        self, batch: dict, *args: Any, **kwargs: Any
+    ) -> Optional[STEP_OUTPUT]:
+        return
+
+    def set_optimizer_factory(self, optimizer_factory) -> None:
+        self.optimizer_factory = optimizer_factory
+
+    def configure_optimizers(self) -> OptimizerLRScheduler:
+        return self.optimizer_factory(self.relik_reader_core_model)
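The module deliberately leaves optimization to an injected factory: configure_optimizers just calls whatever was registered via set_optimizer_factory on the core model. A minimal wiring sketch (the checkpoint name, symbol count, and AdamW settings below are placeholders, not values from this repo):

import torch
from lightning import Trainer

from relik.reader.lightning_modules.relik_reader_pl_module import RelikReaderPLModule

pl_module = RelikReaderPLModule(
    cfg={},
    transformer_model="microsoft/deberta-v3-base",  # placeholder checkpoint
    additional_special_symbols=101,                 # placeholder count
    training=True,
)
pl_module.set_optimizer_factory(
    lambda model: torch.optim.AdamW(model.parameters(), lr=1e-5)
)
trainer = Trainer(max_steps=1000)
# trainer.fit(pl_module, train_dataloaders=train_dl)  # train_dl: your DataLoader

Keeping the factory outside the module means the optimizer (and any scheduler) can be built from configuration at train time without the LightningModule knowing about it.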
relik/reader/lightning_modules/relik_reader_re_pl_module.py
ADDED
@@ -0,0 +1,54 @@
+from typing import Any, Optional
+
+import lightning
+from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler
+
+from relik.reader.relik_reader_re import RelikReaderForTripletExtraction
+
+
+class RelikReaderREPLModule(lightning.LightningModule):
+    def __init__(
+        self,
+        cfg: dict,
+        transformer_model: str,
+        additional_special_symbols: int,
+        num_layers: Optional[int] = None,
+        activation: str = "gelu",
+        linears_hidden_size: Optional[int] = 512,
+        use_last_k_layers: int = 1,
+        training: bool = False,
+        *args: Any,
+        **kwargs: Any
+    ):
+        super().__init__(*args, **kwargs)
+        self.save_hyperparameters()
+
+        self.relik_reader_re_model = RelikReaderForTripletExtraction(
+            transformer_model,
+            additional_special_symbols,
+            num_layers,
+            activation,
+            linears_hidden_size,
+            use_last_k_layers,
+            training=training,
+        )
+        self.optimizer_factory = None
+
+    def training_step(self, batch: dict, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+        relik_output = self.relik_reader_re_model(**batch)
+        self.log("train-loss", relik_output["loss"])
+        self.log("train-start_loss", relik_output["ned_start_loss"])
+        self.log("train-end_loss", relik_output["ned_end_loss"])
+        self.log("train-relation_loss", relik_output["re_loss"])
+        return relik_output["loss"]
+
+    def validation_step(
+        self, batch: dict, *args: Any, **kwargs: Any
+    ) -> Optional[STEP_OUTPUT]:
+        return
+
+    def set_optimizer_factory(self, optimizer_factory) -> None:
+        self.optimizer_factory = optimizer_factory
+
+    def configure_optimizers(self) -> OptimizerLRScheduler:
+        return self.optimizer_factory(self.relik_reader_re_model)
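Since the RE variant logs four scalars per step (the total train-loss plus the NED start/end and relation components), any of them can drive standard Lightning callbacks. A sketch using the stock checkpoint callback (the settings are illustrative, not from this repo):

from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_cb = ModelCheckpoint(
    monitor="train-loss",  # or train-start_loss / train-end_loss / train-relation_loss
    mode="min",
    save_top_k=1,
)
# pass callbacks=[checkpoint_cb] when constructing the Trainer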
relik/reader/pytorch_modules/__init__.py
ADDED
File without changes