Spaces:
Sleeping
Sleeping
Pushed stuff to main
Browse files- federated_rodla/federated/augmentation_engine.py +0 -172
- federated_rodla/federated/data_client.py +0 -212
- federated_rodla/scripts/start_data_client.py +0 -64
- {federated_rodla β federated_rodla_two/federated_rodla/federated_rodla}/configs/federated/centralized_rodla_federated_aug.py +19 -18
- federated_rodla_two/federated_rodla/federated_rodla/federated/data_client.py +481 -0
- {federated_rodla β federated_rodla_two/federated_rodla/federated_rodla}/federated/data_server.py +163 -163
- federated_rodla_two/federated_rodla/federated_rodla/federated/perturbation_engine.py +181 -0
- {federated_rodla β federated_rodla_two/federated_rodla/federated_rodla}/federated/privacy_utils.py +0 -0
- federated_rodla_two/federated_rodla/federated_rodla/federated/training_server.py +331 -0
- federated_rodla_two/federated_rodla/federated_rodla/scripts/start_data_client.py +237 -0
- {federated_rodla β federated_rodla_two/federated_rodla/federated_rodla}/scripts/start_data_server.py +28 -28
- federated_rodla_two/federated_rodla/federated_rodla/scripts/start_training_client.py +43 -0
- federated_rodla_two/federated_rodla/federated_rodla/scripts/start_training_server.py +57 -0
- {federated_rodla β federated_rodla_two/federated_rodla/federated_rodla}/utils/data_utils.py +600 -600
- finetuning_rodla/finetuning_rodla/checkpoints/internimage_xl_22k_192to384.pth +0 -0
- finetuning_rodla/finetuning_rodla/checkpoints/rodla_internimage_xl_publaynet.pth +0 -0
- finetuning_rodla/finetuning_rodla/configs/docbank/rodla_internimage_docbank.py +157 -0
- finetuning_rodla/finetuning_rodla/data/docbank_coco.json +635 -0
- finetuning_rodla/finetuning_rodla/data/test/what_to_add_here.txt +4 -0
- finetuning_rodla/finetuning_rodla/data/train/what_to_add_here.txt +5 -0
- finetuning_rodla/finetuning_rodla/tools/convert_docbank_to_coco.py +149 -0
- finetuning_rodla/finetuning_rodla/tools/eval_docbank-p.py +138 -0
- finetuning_rodla/finetuning_rodla/tools/finetune_docbank.py +219 -0
- finetuning_rodla/finetuning_rodla/work_dirs/rodla_docbank/epoch_1.pth +0 -0
- finetuning_rodla/finetuning_rodla/work_dirs/rodla_docbank/evaluation_results.txt +21 -0
federated_rodla/federated/augmentation_engine.py
DELETED
|
@@ -1,172 +0,0 @@
|
|
| 1 |
-
# federated/augmentation_engine.py
|
| 2 |
-
|
| 3 |
-
import numpy as np
|
| 4 |
-
from PIL import Image, ImageFilter, ImageEnhance
|
| 5 |
-
import cv2
|
| 6 |
-
import random
|
| 7 |
-
from typing import Dict, Tuple
|
| 8 |
-
|
| 9 |
-
class AugmentationEngine:
|
| 10 |
-
def __init__(self, privacy_level: str = 'medium'):
|
| 11 |
-
self.privacy_level = privacy_level
|
| 12 |
-
self.setup_augmentations()
|
| 13 |
-
|
| 14 |
-
def setup_augmentations(self):
|
| 15 |
-
"""Setup augmentation parameters based on privacy level"""
|
| 16 |
-
if self.privacy_level == 'low':
|
| 17 |
-
self.geometric_strength = 0.1
|
| 18 |
-
self.color_strength = 0.1
|
| 19 |
-
self.noise_strength = 0.05
|
| 20 |
-
elif self.privacy_level == 'medium':
|
| 21 |
-
self.geometric_strength = 0.2
|
| 22 |
-
self.color_strength = 0.2
|
| 23 |
-
self.noise_strength = 0.1
|
| 24 |
-
else: # high
|
| 25 |
-
self.geometric_strength = 0.3
|
| 26 |
-
self.color_strength = 0.3
|
| 27 |
-
self.noise_strength = 0.15
|
| 28 |
-
|
| 29 |
-
def get_capabilities(self) -> Dict:
|
| 30 |
-
"""Get augmentation capabilities for server registration"""
|
| 31 |
-
return {
|
| 32 |
-
'geometric_augmentations': True,
|
| 33 |
-
'color_augmentations': True,
|
| 34 |
-
'noise_augmentations': True,
|
| 35 |
-
'privacy_level': self.privacy_level
|
| 36 |
-
}
|
| 37 |
-
|
| 38 |
-
def augment_image(self, image: Image.Image) -> Tuple[Image.Image, Dict]:
|
| 39 |
-
"""Apply augmentations to image"""
|
| 40 |
-
original_size = image.size
|
| 41 |
-
aug_info = {
|
| 42 |
-
'original_size': original_size,
|
| 43 |
-
'applied_transforms': [],
|
| 44 |
-
'parameters': {}
|
| 45 |
-
}
|
| 46 |
-
|
| 47 |
-
# Apply geometric transformations
|
| 48 |
-
image, geometric_info = self.apply_geometric_augmentations(image)
|
| 49 |
-
aug_info['applied_transforms'].extend(geometric_info['transforms'])
|
| 50 |
-
aug_info['parameters'].update(geometric_info['parameters'])
|
| 51 |
-
|
| 52 |
-
# Apply color transformations
|
| 53 |
-
image, color_info = self.apply_color_augmentations(image)
|
| 54 |
-
aug_info['applied_transforms'].extend(color_info['transforms'])
|
| 55 |
-
aug_info['parameters'].update(color_info['parameters'])
|
| 56 |
-
|
| 57 |
-
# Apply noise
|
| 58 |
-
image, noise_info = self.apply_noise_augmentations(image)
|
| 59 |
-
aug_info['applied_transforms'].extend(noise_info['transforms'])
|
| 60 |
-
aug_info['parameters'].update(noise_info['parameters'])
|
| 61 |
-
|
| 62 |
-
aug_info['final_size'] = image.size
|
| 63 |
-
|
| 64 |
-
return image, aug_info
|
| 65 |
-
|
| 66 |
-
def apply_geometric_augmentations(self, image: Image.Image) -> Tuple[Image.Image, Dict]:
|
| 67 |
-
"""Apply geometric transformations"""
|
| 68 |
-
info = {'transforms': [], 'parameters': {}}
|
| 69 |
-
img = image
|
| 70 |
-
|
| 71 |
-
# Random rotation
|
| 72 |
-
if random.random() < 0.7:
|
| 73 |
-
angle = random.uniform(-15 * self.geometric_strength, 15 * self.geometric_strength)
|
| 74 |
-
img = img.rotate(angle, resample=Image.BILINEAR, expand=False)
|
| 75 |
-
info['transforms'].append('rotation')
|
| 76 |
-
info['parameters']['rotation_angle'] = angle
|
| 77 |
-
|
| 78 |
-
# Random scaling
|
| 79 |
-
if random.random() < 0.6:
|
| 80 |
-
scale = random.uniform(1.0 - 0.2 * self.geometric_strength, 1.0 + 0.2 * self.geometric_strength)
|
| 81 |
-
new_size = (int(img.width * scale), int(img.height * scale))
|
| 82 |
-
img = img.resize(new_size, Image.BILINEAR)
|
| 83 |
-
info['transforms'].append('scaling')
|
| 84 |
-
info['parameters']['scale_factor'] = scale
|
| 85 |
-
|
| 86 |
-
# Random perspective (simplified)
|
| 87 |
-
if random.random() < 0.4:
|
| 88 |
-
img = self.apply_perspective_distortion(img)
|
| 89 |
-
info['transforms'].append('perspective')
|
| 90 |
-
|
| 91 |
-
return img, info
|
| 92 |
-
|
| 93 |
-
def apply_color_augmentations(self, image: Image.Image) -> Tuple[Image.Image, Dict]:
|
| 94 |
-
"""Apply color transformations"""
|
| 95 |
-
info = {'transforms': [], 'parameters': {}}
|
| 96 |
-
img = image
|
| 97 |
-
|
| 98 |
-
# Brightness
|
| 99 |
-
if random.random() < 0.7:
|
| 100 |
-
factor = random.uniform(1.0 - 0.3 * self.color_strength, 1.0 + 0.3 * self.color_strength)
|
| 101 |
-
enhancer = ImageEnhance.Brightness(img)
|
| 102 |
-
img = enhancer.enhance(factor)
|
| 103 |
-
info['transforms'].append('brightness')
|
| 104 |
-
info['parameters']['brightness_factor'] = factor
|
| 105 |
-
|
| 106 |
-
# Contrast
|
| 107 |
-
if random.random() < 0.6:
|
| 108 |
-
factor = random.uniform(1.0 - 0.3 * self.color_strength, 1.0 + 0.3 * self.color_strength)
|
| 109 |
-
enhancer = ImageEnhance.Contrast(img)
|
| 110 |
-
img = enhancer.enhance(factor)
|
| 111 |
-
info['transforms'].append('contrast')
|
| 112 |
-
info['parameters']['contrast_factor'] = factor
|
| 113 |
-
|
| 114 |
-
# Color balance
|
| 115 |
-
if random.random() < 0.5:
|
| 116 |
-
factor = random.uniform(1.0 - 0.2 * self.color_strength, 1.0 + 0.2 * self.color_strength)
|
| 117 |
-
enhancer = ImageEnhance.Color(img)
|
| 118 |
-
img = enhancer.enhance(factor)
|
| 119 |
-
info['transforms'].append('color_balance')
|
| 120 |
-
info['parameters']['color_factor'] = factor
|
| 121 |
-
|
| 122 |
-
return img, info
|
| 123 |
-
|
| 124 |
-
def apply_noise_augmentations(self, image: Image.Image) -> Tuple[Image.Image, Dict]:
|
| 125 |
-
"""Apply noise and blur augmentations"""
|
| 126 |
-
info = {'transforms': [], 'parameters': {}}
|
| 127 |
-
img = image
|
| 128 |
-
|
| 129 |
-
# Gaussian blur
|
| 130 |
-
if random.random() < 0.5:
|
| 131 |
-
radius = random.uniform(0.1, 1.0 * self.noise_strength)
|
| 132 |
-
img = img.filter(ImageFilter.GaussianBlur(radius=radius))
|
| 133 |
-
info['transforms'].append('gaussian_blur')
|
| 134 |
-
info['parameters']['blur_radius'] = radius
|
| 135 |
-
|
| 136 |
-
# Convert to numpy for more advanced noise
|
| 137 |
-
if random.random() < 0.4:
|
| 138 |
-
img_np = np.array(img)
|
| 139 |
-
|
| 140 |
-
# Gaussian noise
|
| 141 |
-
noise = np.random.normal(0, 25 * self.noise_strength, img_np.shape).astype(np.uint8)
|
| 142 |
-
img_np = cv2.add(img_np, noise)
|
| 143 |
-
|
| 144 |
-
img = Image.fromarray(img_np)
|
| 145 |
-
info['transforms'].append('gaussian_noise')
|
| 146 |
-
|
| 147 |
-
return img, info
|
| 148 |
-
|
| 149 |
-
def apply_perspective_distortion(self, image: Image.Image) -> Image.Image:
|
| 150 |
-
"""Apply simple perspective distortion"""
|
| 151 |
-
width, height = image.size
|
| 152 |
-
|
| 153 |
-
# Simple skew effect
|
| 154 |
-
if random.choice([True, False]):
|
| 155 |
-
# Horizontal skew
|
| 156 |
-
skew_factor = random.uniform(-0.1 * self.geometric_strength, 0.1 * self.geometric_strength)
|
| 157 |
-
matrix = (1, skew_factor, -skew_factor * height * 0.5,
|
| 158 |
-
0, 1, 0)
|
| 159 |
-
else:
|
| 160 |
-
# Vertical skew
|
| 161 |
-
skew_factor = random.uniform(-0.1 * self.geometric_strength, 0.1 * self.geometric_strength)
|
| 162 |
-
matrix = (1, 0, 0,
|
| 163 |
-
skew_factor, 1, -skew_factor * width * 0.5)
|
| 164 |
-
|
| 165 |
-
img = image.transform(
|
| 166 |
-
image.size,
|
| 167 |
-
Image.AFFINE,
|
| 168 |
-
matrix,
|
| 169 |
-
resample=Image.BILINEAR
|
| 170 |
-
)
|
| 171 |
-
|
| 172 |
-
return img
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
federated_rodla/federated/data_client.py
DELETED
|
@@ -1,212 +0,0 @@
|
|
| 1 |
-
# federated/data_client.py
|
| 2 |
-
|
| 3 |
-
import requests
|
| 4 |
-
import base64
|
| 5 |
-
import io
|
| 6 |
-
import numpy as np
|
| 7 |
-
import torch
|
| 8 |
-
from PIL import Image
|
| 9 |
-
import json
|
| 10 |
-
import time
|
| 11 |
-
import logging
|
| 12 |
-
from typing import List, Dict, Optional
|
| 13 |
-
import os
|
| 14 |
-
# Uses DataUtils.tensor_to_numpy() and DataUtils.create_sample()
|
| 15 |
-
from utils.data_utils import DataUtils, FederatedDataConverter
|
| 16 |
-
from augmentation_engine import AugmentationEngine
|
| 17 |
-
|
| 18 |
-
class FederatedDataClient:
|
| 19 |
-
def __init__(self, client_id: str, server_url: str, data_loader, privacy_level: str = 'medium'):
|
| 20 |
-
self.client_id = client_id
|
| 21 |
-
self.server_url = server_url
|
| 22 |
-
self.data_loader = data_loader
|
| 23 |
-
self.privacy_level = privacy_level
|
| 24 |
-
self.augmentation_engine = AugmentationEngine(privacy_level)
|
| 25 |
-
self.registered = False
|
| 26 |
-
|
| 27 |
-
logging.basicConfig(level=logging.INFO)
|
| 28 |
-
|
| 29 |
-
def register_with_server(self):
|
| 30 |
-
"""Register this client with the federated server"""
|
| 31 |
-
try:
|
| 32 |
-
client_info = {
|
| 33 |
-
'data_type': 'M6Doc',
|
| 34 |
-
'privacy_level': self.privacy_level,
|
| 35 |
-
'augmentation_capabilities': self.augmentation_engine.get_capabilities(),
|
| 36 |
-
'timestamp': time.time()
|
| 37 |
-
}
|
| 38 |
-
|
| 39 |
-
response = requests.post(
|
| 40 |
-
f"{self.server_url}/register_client",
|
| 41 |
-
json={
|
| 42 |
-
'client_id': self.client_id,
|
| 43 |
-
'client_info': client_info
|
| 44 |
-
},
|
| 45 |
-
timeout=10
|
| 46 |
-
)
|
| 47 |
-
|
| 48 |
-
if response.status_code == 200:
|
| 49 |
-
data = response.json()
|
| 50 |
-
if data['status'] == 'success':
|
| 51 |
-
self.registered = True
|
| 52 |
-
logging.info(f"Client {self.client_id} successfully registered")
|
| 53 |
-
return True
|
| 54 |
-
|
| 55 |
-
logging.error(f"Failed to register client: {response.text}")
|
| 56 |
-
return False
|
| 57 |
-
|
| 58 |
-
except Exception as e:
|
| 59 |
-
logging.error(f"Registration failed: {e}")
|
| 60 |
-
return False
|
| 61 |
-
|
| 62 |
-
def generate_augmented_samples(self, num_samples: int = 50) -> List[Dict]:
|
| 63 |
-
"""Generate augmented samples from local data"""
|
| 64 |
-
samples = []
|
| 65 |
-
|
| 66 |
-
for i, batch in enumerate(self.data_loader):
|
| 67 |
-
if len(samples) >= num_samples:
|
| 68 |
-
break
|
| 69 |
-
|
| 70 |
-
try:
|
| 71 |
-
# Assume batch structure: {'img': tensor, 'gt_bboxes': list, 'gt_labels': list, 'img_metas': list}
|
| 72 |
-
images = batch['img']
|
| 73 |
-
img_metas = batch['img_metas']
|
| 74 |
-
|
| 75 |
-
for j in range(len(images)):
|
| 76 |
-
if len(samples) >= num_samples:
|
| 77 |
-
break
|
| 78 |
-
|
| 79 |
-
# Convert tensor to PIL Image
|
| 80 |
-
img_tensor = images[j]
|
| 81 |
-
img_np = self.tensor_to_numpy(img_tensor)
|
| 82 |
-
pil_img = Image.fromarray(img_np)
|
| 83 |
-
|
| 84 |
-
# Apply augmentations
|
| 85 |
-
augmented_img, augmentation_info = self.augmentation_engine.augment_image(pil_img)
|
| 86 |
-
|
| 87 |
-
# Prepare annotations
|
| 88 |
-
annotations = self.prepare_annotations(batch, j, augmentation_info)
|
| 89 |
-
|
| 90 |
-
# Create sample
|
| 91 |
-
sample = self.create_sample(augmented_img, annotations, augmentation_info)
|
| 92 |
-
samples.append(sample)
|
| 93 |
-
|
| 94 |
-
except Exception as e:
|
| 95 |
-
logging.warning(f"Error processing batch {i}: {e}")
|
| 96 |
-
continue
|
| 97 |
-
|
| 98 |
-
logging.info(f"Generated {len(samples)} augmented samples")
|
| 99 |
-
return samples
|
| 100 |
-
|
| 101 |
-
def tensor_to_numpy(self, tensor: torch.Tensor) -> np.ndarray:
|
| 102 |
-
"""Convert torch tensor to numpy array for image"""
|
| 103 |
-
# Denormalize and convert
|
| 104 |
-
img_np = tensor.cpu().numpy().transpose(1, 2, 0)
|
| 105 |
-
img_np = (img_np * [58.395, 57.12, 57.375] + [123.675, 116.28, 103.53]).astype(np.uint8)
|
| 106 |
-
return img_np
|
| 107 |
-
|
| 108 |
-
def prepare_annotations(self, batch: Dict, index: int, aug_info: Dict) -> Dict:
|
| 109 |
-
"""Prepare annotations for a sample, adjusting for augmentations"""
|
| 110 |
-
bboxes = batch['gt_bboxes'][index].cpu().numpy() if hasattr(batch['gt_bboxes'][index], 'cpu') else batch['gt_bboxes'][index]
|
| 111 |
-
labels = batch['gt_labels'][index].cpu().numpy() if hasattr(batch['gt_labels'][index], 'cpu') else batch['gt_labels'][index]
|
| 112 |
-
|
| 113 |
-
# Adjust bounding boxes for geometric transformations
|
| 114 |
-
if 'geometric' in aug_info['applied_transforms']:
|
| 115 |
-
bboxes = self.adjust_bboxes_for_augmentation(bboxes, aug_info)
|
| 116 |
-
|
| 117 |
-
annotations = {
|
| 118 |
-
'bboxes': bboxes.tolist(),
|
| 119 |
-
'labels': labels.tolist(),
|
| 120 |
-
'image_size': aug_info['final_size'],
|
| 121 |
-
'original_size': aug_info['original_size']
|
| 122 |
-
}
|
| 123 |
-
|
| 124 |
-
return annotations
|
| 125 |
-
|
| 126 |
-
def adjust_bboxes_for_augmentation(self, bboxes: np.ndarray, aug_info: Dict) -> np.ndarray:
|
| 127 |
-
"""Adjust bounding boxes for geometric augmentations"""
|
| 128 |
-
# Simplified bbox adjustment
|
| 129 |
-
# In practice, you'd use the exact transformation matrices
|
| 130 |
-
scale_x = aug_info['final_size'][0] / aug_info['original_size'][0]
|
| 131 |
-
scale_y = aug_info['final_size'][1] / aug_info['original_size'][1]
|
| 132 |
-
|
| 133 |
-
adjusted_bboxes = bboxes.copy()
|
| 134 |
-
adjusted_bboxes[:, 0] *= scale_x # x1
|
| 135 |
-
adjusted_bboxes[:, 1] *= scale_y # y1
|
| 136 |
-
adjusted_bboxes[:, 2] *= scale_x # x2
|
| 137 |
-
adjusted_bboxes[:, 3] *= scale_y # y2
|
| 138 |
-
|
| 139 |
-
return adjusted_bboxes
|
| 140 |
-
|
| 141 |
-
def create_sample(self, image: Image.Image, annotations: Dict, aug_info: Dict) -> Dict:
|
| 142 |
-
"""Create a sample for sending to server"""
|
| 143 |
-
# Convert image to base64
|
| 144 |
-
buffered = io.BytesIO()
|
| 145 |
-
image.save(buffered, format="JPEG", quality=85)
|
| 146 |
-
img_str = base64.b64encode(buffered.getvalue()).decode()
|
| 147 |
-
|
| 148 |
-
sample = {
|
| 149 |
-
'image_data': img_str,
|
| 150 |
-
'annotations': annotations,
|
| 151 |
-
'metadata': {
|
| 152 |
-
'client_id': self.client_id,
|
| 153 |
-
'augmentation_info': aug_info,
|
| 154 |
-
'timestamp': time.time(),
|
| 155 |
-
'privacy_level': self.privacy_level
|
| 156 |
-
}
|
| 157 |
-
}
|
| 158 |
-
|
| 159 |
-
return sample
|
| 160 |
-
|
| 161 |
-
def submit_augmented_data(self, samples: List[Dict]) -> bool:
|
| 162 |
-
"""Submit augmented samples to the server"""
|
| 163 |
-
if not self.registered:
|
| 164 |
-
logging.error("Client not registered with server")
|
| 165 |
-
return False
|
| 166 |
-
|
| 167 |
-
try:
|
| 168 |
-
response = requests.post(
|
| 169 |
-
f"{self.server_url}/submit_augmented_data",
|
| 170 |
-
json={
|
| 171 |
-
'client_id': self.client_id,
|
| 172 |
-
'samples': samples
|
| 173 |
-
},
|
| 174 |
-
timeout=30
|
| 175 |
-
)
|
| 176 |
-
|
| 177 |
-
if response.status_code == 200:
|
| 178 |
-
result = response.json()
|
| 179 |
-
if result['status'] == 'success':
|
| 180 |
-
logging.info(f"Successfully submitted {result['received']} samples")
|
| 181 |
-
return True
|
| 182 |
-
|
| 183 |
-
logging.error(f"Submission failed: {response.text}")
|
| 184 |
-
return False
|
| 185 |
-
|
| 186 |
-
except Exception as e:
|
| 187 |
-
logging.error(f"Error submitting data: {e}")
|
| 188 |
-
return False
|
| 189 |
-
|
| 190 |
-
def run_data_generation(self, samples_per_batch: int = 50, interval: int = 300):
|
| 191 |
-
"""Continuously generate and submit augmented data"""
|
| 192 |
-
if not self.register_with_server():
|
| 193 |
-
return False
|
| 194 |
-
|
| 195 |
-
logging.info(f"Starting continuous data generation (batch: {samples_per_batch}, interval: {interval}s)")
|
| 196 |
-
|
| 197 |
-
while True:
|
| 198 |
-
try:
|
| 199 |
-
samples = self.generate_augmented_samples(samples_per_batch)
|
| 200 |
-
if samples:
|
| 201 |
-
success = self.submit_augmented_data(samples)
|
| 202 |
-
if not success:
|
| 203 |
-
logging.warning("Failed to submit batch, will retry after interval")
|
| 204 |
-
|
| 205 |
-
time.sleep(interval)
|
| 206 |
-
|
| 207 |
-
except KeyboardInterrupt:
|
| 208 |
-
logging.info("Data generation stopped by user")
|
| 209 |
-
break
|
| 210 |
-
except Exception as e:
|
| 211 |
-
logging.error(f"Error in data generation loop: {e}")
|
| 212 |
-
time.sleep(interval) # Wait before retrying
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
federated_rodla/scripts/start_data_client.py
DELETED
|
@@ -1,64 +0,0 @@
|
|
| 1 |
-
# scripts/start_data_client.py
|
| 2 |
-
|
| 3 |
-
import argparse
|
| 4 |
-
import sys
|
| 5 |
-
import os
|
| 6 |
-
|
| 7 |
-
# Add project root to path
|
| 8 |
-
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 9 |
-
|
| 10 |
-
from federated.data_client import FederatedDataClient
|
| 11 |
-
import torch
|
| 12 |
-
from torch.utils.data import DataLoader
|
| 13 |
-
|
| 14 |
-
def create_dummy_dataloader():
|
| 15 |
-
"""Create a dummy dataloader for testing - replace with actual M6Doc dataloader"""
|
| 16 |
-
# This is a placeholder - you'll replace this with your actual M6Doc data loader
|
| 17 |
-
class DummyDataset(torch.utils.data.Dataset):
|
| 18 |
-
def __init__(self, size=1000):
|
| 19 |
-
self.size = size
|
| 20 |
-
|
| 21 |
-
def __len__(self):
|
| 22 |
-
return self.size
|
| 23 |
-
|
| 24 |
-
def __getitem__(self, idx):
|
| 25 |
-
# Return dummy data in RoDLA format
|
| 26 |
-
return {
|
| 27 |
-
'img': torch.randn(3, 800, 1333),
|
| 28 |
-
'gt_bboxes': [torch.tensor([[100, 100, 200, 200]])],
|
| 29 |
-
'gt_labels': [torch.tensor([1])],
|
| 30 |
-
'img_metas': [{'filename': f'dummy_{idx}.jpg', 'ori_shape': (800, 1333, 3)}]
|
| 31 |
-
}
|
| 32 |
-
|
| 33 |
-
dataset = DummyDataset(1000)
|
| 34 |
-
return DataLoader(dataset, batch_size=4, shuffle=True)
|
| 35 |
-
|
| 36 |
-
def main():
|
| 37 |
-
parser = argparse.ArgumentParser()
|
| 38 |
-
parser.add_argument('--client-id', required=True, help='Client ID')
|
| 39 |
-
parser.add_argument('--server-url', default='http://localhost:8080', help='Server URL')
|
| 40 |
-
parser.add_argument('--privacy-level', choices=['low', 'medium', 'high'], default='medium')
|
| 41 |
-
parser.add_argument('--samples-per-batch', type=int, default=50)
|
| 42 |
-
parser.add_argument('--interval', type=int, default=300, help='Seconds between batches')
|
| 43 |
-
|
| 44 |
-
args = parser.parse_args()
|
| 45 |
-
|
| 46 |
-
# Create data loader (replace with your actual M6Doc data loader)
|
| 47 |
-
data_loader = create_dummy_dataloader()
|
| 48 |
-
|
| 49 |
-
# Create federated client
|
| 50 |
-
client = FederatedDataClient(
|
| 51 |
-
client_id=args.client_id,
|
| 52 |
-
server_url=args.server_url,
|
| 53 |
-
data_loader=data_loader,
|
| 54 |
-
privacy_level=args.privacy_level
|
| 55 |
-
)
|
| 56 |
-
|
| 57 |
-
# Start continuous data generation
|
| 58 |
-
client.run_data_generation(
|
| 59 |
-
samples_per_batch=args.samples_per_batch,
|
| 60 |
-
interval=args.interval
|
| 61 |
-
)
|
| 62 |
-
|
| 63 |
-
if __name__ == '__main__':
|
| 64 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
{federated_rodla β federated_rodla_two/federated_rodla/federated_rodla}/configs/federated/centralized_rodla_federated_aug.py
RENAMED
|
@@ -1,19 +1,20 @@
|
|
| 1 |
-
# configs/federated/centralized_rodla_federated_aug.py
|
| 2 |
-
|
| 3 |
-
_base_ = '../../
|
| 4 |
-
|
| 5 |
-
#
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
|
|
|
| 19 |
# The only change: we'll modify the data loader to use federated augmented data
|
|
|
|
| 1 |
+
# configs/federated/centralized_rodla_federated_aug.py
|
| 2 |
+
|
| 3 |
+
_base_ = '../../rodla_internimage_xl_publaynet.py' # CHANGED to PubLayNet
|
| 4 |
+
|
| 5 |
+
# Federated data settings for PubLayNet-P
|
| 6 |
+
federated_data = dict(
|
| 7 |
+
server_url='localhost:8080',
|
| 8 |
+
client_id='client_01',
|
| 9 |
+
data_batch_size=50,
|
| 10 |
+
max_samples_per_epoch=1000,
|
| 11 |
+
perturbation_types=[
|
| 12 |
+
'background', 'defocus', 'illumination', 'ink_bleeding', 'ink_holdout',
|
| 13 |
+
'keystoning', 'rotation', 'speckle', 'texture', 'vibration',
|
| 14 |
+
'warping', 'watermark', 'random', 'all'
|
| 15 |
+
],
|
| 16 |
+
severity_levels=[1, 2, 3] # CHANGED: Discrete levels instead of privacy levels
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
# Training remains exactly the same
|
| 20 |
# The only change: we'll modify the data loader to use federated augmented data
|
federated_rodla_two/federated_rodla/federated_rodla/federated/data_client.py
ADDED
|
@@ -0,0 +1,481 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# federated/data_client.py
|
| 2 |
+
|
| 3 |
+
import requests
|
| 4 |
+
import base64
|
| 5 |
+
import io
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import json
|
| 10 |
+
import time
|
| 11 |
+
import logging
|
| 12 |
+
from typing import List, Dict, Optional
|
| 13 |
+
import os
|
| 14 |
+
|
| 15 |
+
from utils.data_utils import DataUtils, FederatedDataConverter
|
| 16 |
+
from augmentation_engine import PubLayNetAugmentationEngine
|
| 17 |
+
|
| 18 |
+
class FederatedDataClient:
|
| 19 |
+
def __init__(self, client_id: str, server_url: str, data_loader,
|
| 20 |
+
perturbation_type: str = 'random', severity_level: int = 2):
|
| 21 |
+
self.client_id = client_id
|
| 22 |
+
self.server_url = server_url
|
| 23 |
+
self.data_loader = data_loader
|
| 24 |
+
self.perturbation_type = perturbation_type
|
| 25 |
+
self.severity_level = severity_level
|
| 26 |
+
self.augmentation_engine = PubLayNetAugmentationEngine(perturbation_type, severity_level)
|
| 27 |
+
self.registered = False
|
| 28 |
+
|
| 29 |
+
logging.basicConfig(level=logging.INFO)
|
| 30 |
+
|
| 31 |
+
def register_with_server(self):
|
| 32 |
+
"""Register this client with the federated server"""
|
| 33 |
+
try:
|
| 34 |
+
client_info = {
|
| 35 |
+
'data_type': 'PubLayNet',
|
| 36 |
+
'perturbation_type': self.perturbation_type,
|
| 37 |
+
'severity_level': self.severity_level,
|
| 38 |
+
'available_perturbations': self.augmentation_engine.get_available_perturbations(),
|
| 39 |
+
'timestamp': time.time()
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
response = requests.post(
|
| 43 |
+
f"{self.server_url}/register_client",
|
| 44 |
+
json={
|
| 45 |
+
'client_id': self.client_id,
|
| 46 |
+
'client_info': client_info
|
| 47 |
+
},
|
| 48 |
+
timeout=10
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
if response.status_code == 200:
|
| 52 |
+
data = response.json()
|
| 53 |
+
if data['status'] == 'success':
|
| 54 |
+
self.registered = True
|
| 55 |
+
logging.info(f"Client {self.client_id} successfully registered")
|
| 56 |
+
logging.info(f"Perturbation: {self.perturbation_type}, Severity: {self.severity_level}")
|
| 57 |
+
return True
|
| 58 |
+
|
| 59 |
+
logging.error(f"Failed to register client: {response.text}")
|
| 60 |
+
return False
|
| 61 |
+
|
| 62 |
+
except Exception as e:
|
| 63 |
+
logging.error(f"Registration failed: {e}")
|
| 64 |
+
return False
|
| 65 |
+
|
| 66 |
+
def generate_augmented_samples(self, num_samples: int = 50) -> List[Dict]:
|
| 67 |
+
"""Generate augmented samples using PubLayNet-P perturbations"""
|
| 68 |
+
samples = []
|
| 69 |
+
available_perturbations = self.augmentation_engine.get_available_perturbations()
|
| 70 |
+
perturbation_cycle = 0
|
| 71 |
+
|
| 72 |
+
for i, batch in enumerate(self.data_loader):
|
| 73 |
+
if len(samples) >= num_samples:
|
| 74 |
+
break
|
| 75 |
+
|
| 76 |
+
try:
|
| 77 |
+
images = batch['img']
|
| 78 |
+
img_metas = batch['img_metas']
|
| 79 |
+
|
| 80 |
+
for j in range(len(images)):
|
| 81 |
+
if len(samples) >= num_samples:
|
| 82 |
+
break
|
| 83 |
+
|
| 84 |
+
# Convert tensor to PIL Image
|
| 85 |
+
img_tensor = images[j]
|
| 86 |
+
pil_img = DataUtils.tensor_to_pil(img_tensor)
|
| 87 |
+
|
| 88 |
+
# Apply PubLayNet-P perturbation
|
| 89 |
+
if self.perturbation_type == 'all':
|
| 90 |
+
# Cycle through all perturbation types
|
| 91 |
+
pert_type = available_perturbations[perturbation_cycle % len(available_perturbations)]
|
| 92 |
+
perturbation_cycle += 1
|
| 93 |
+
elif self.perturbation_type == 'random':
|
| 94 |
+
pert_type = 'random'
|
| 95 |
+
else:
|
| 96 |
+
pert_type = self.perturbation_type
|
| 97 |
+
|
| 98 |
+
augmented_img, augmentation_info = self.augmentation_engine.augment_image(
|
| 99 |
+
pil_img, pert_type
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
# Prepare annotations
|
| 103 |
+
annotations = self.prepare_annotations(batch, j, augmentation_info)
|
| 104 |
+
|
| 105 |
+
# Create sample
|
| 106 |
+
sample = self.create_sample(augmented_img, annotations, augmentation_info)
|
| 107 |
+
samples.append(sample)
|
| 108 |
+
|
| 109 |
+
except Exception as e:
|
| 110 |
+
logging.warning(f"Error processing batch {i}: {e}")
|
| 111 |
+
continue
|
| 112 |
+
|
| 113 |
+
logging.info(f"Generated {len(samples)} augmented samples using {self.perturbation_type}")
|
| 114 |
+
return samples
|
| 115 |
+
|
| 116 |
+
def prepare_annotations(self, batch: Dict, index: int, aug_info: Dict) -> Dict:
|
| 117 |
+
"""Prepare annotations for a sample, adjusting for augmentations"""
|
| 118 |
+
bboxes = batch['gt_bboxes'][index]
|
| 119 |
+
labels = batch['gt_labels'][index]
|
| 120 |
+
|
| 121 |
+
# Convert tensors to lists
|
| 122 |
+
bboxes_list = bboxes.cpu().numpy().tolist() if hasattr(bboxes, 'cpu') else bboxes
|
| 123 |
+
labels_list = labels.cpu().numpy().tolist() if hasattr(labels, 'cpu') else labels
|
| 124 |
+
|
| 125 |
+
# Adjust bounding boxes for geometric transformations
|
| 126 |
+
if aug_info['perturbation_type'] in ['rotation', 'keystoning', 'warping', 'scaling']:
|
| 127 |
+
bboxes_list = self.adjust_bboxes_for_augmentation(bboxes_list, aug_info)
|
| 128 |
+
|
| 129 |
+
annotations = {
|
| 130 |
+
'bboxes': bboxes_list,
|
| 131 |
+
'labels': labels_list,
|
| 132 |
+
'image_size': aug_info['final_size'],
|
| 133 |
+
'original_size': aug_info['original_size'],
|
| 134 |
+
'categories': {
|
| 135 |
+
1: 'text', 2: 'title', 3: 'list', 4: 'table', 5: 'figure'
|
| 136 |
+
}
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
return annotations
|
| 140 |
+
|
| 141 |
+
def adjust_bboxes_for_augmentation(self, bboxes: List, aug_info: Dict) -> List:
|
| 142 |
+
"""Adjust bounding boxes for geometric augmentations"""
|
| 143 |
+
try:
|
| 144 |
+
orig_w, orig_h = aug_info['original_size']
|
| 145 |
+
new_w, new_h = aug_info['final_size']
|
| 146 |
+
|
| 147 |
+
scale_x = new_w / orig_w
|
| 148 |
+
scale_y = new_h / orig_h
|
| 149 |
+
|
| 150 |
+
adjusted_bboxes = []
|
| 151 |
+
for bbox in bboxes:
|
| 152 |
+
x1, y1, x2, y2 = bbox
|
| 153 |
+
|
| 154 |
+
# Apply scaling
|
| 155 |
+
x1 = x1 * scale_x
|
| 156 |
+
y1 = y1 * scale_y
|
| 157 |
+
x2 = x2 * scale_x
|
| 158 |
+
y2 = y2 * scale_y
|
| 159 |
+
|
| 160 |
+
# For rotation, apply simple adjustment (in practice, use proper rotation matrix)
|
| 161 |
+
if aug_info['perturbation_type'] == 'rotation' and 'rotation_angle' in aug_info.get('parameters', {}):
|
| 162 |
+
angle = aug_info['parameters']['rotation_angle']
|
| 163 |
+
if abs(angle) > 5:
|
| 164 |
+
# Simplified rotation adjustment - for production, use proper affine transformation
|
| 165 |
+
center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2
|
| 166 |
+
# This is a simplified version - real implementation would use rotation matrix
|
| 167 |
+
pass
|
| 168 |
+
|
| 169 |
+
adjusted_bboxes.append([x1, y1, x2, y2])
|
| 170 |
+
|
| 171 |
+
return adjusted_bboxes
|
| 172 |
+
|
| 173 |
+
except Exception as e:
|
| 174 |
+
logging.warning(f"Error adjusting bboxes: {e}")
|
| 175 |
+
return bboxes
|
| 176 |
+
|
| 177 |
+
def create_sample(self, image: Image.Image, annotations: Dict, aug_info: Dict) -> Dict:
|
| 178 |
+
"""Create a sample for sending to server"""
|
| 179 |
+
# Convert image to base64
|
| 180 |
+
buffered = io.BytesIO()
|
| 181 |
+
image.save(buffered, format="JPEG", quality=85)
|
| 182 |
+
img_str = base64.b64encode(buffered.getvalue()).decode()
|
| 183 |
+
|
| 184 |
+
sample = {
|
| 185 |
+
'image_data': img_str,
|
| 186 |
+
'annotations': annotations,
|
| 187 |
+
'metadata': {
|
| 188 |
+
'client_id': self.client_id,
|
| 189 |
+
'perturbation_type': aug_info['perturbation_type'],
|
| 190 |
+
'severity_level': aug_info['severity_level'],
|
| 191 |
+
'augmentation_info': aug_info,
|
| 192 |
+
'timestamp': time.time(),
|
| 193 |
+
'dataset': 'PubLayNet'
|
| 194 |
+
}
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
return sample
|
| 198 |
+
|
| 199 |
+
def submit_augmented_data(self, samples: List[Dict]) -> bool:
|
| 200 |
+
"""Submit augmented samples to the server"""
|
| 201 |
+
if not self.registered:
|
| 202 |
+
logging.error("Client not registered with server")
|
| 203 |
+
return False
|
| 204 |
+
|
| 205 |
+
try:
|
| 206 |
+
response = requests.post(
|
| 207 |
+
f"{self.server_url}/submit_augmented_data",
|
| 208 |
+
json={
|
| 209 |
+
'client_id': self.client_id,
|
| 210 |
+
'samples': samples,
|
| 211 |
+
'perturbation_type': self.perturbation_type,
|
| 212 |
+
'severity_level': self.severity_level
|
| 213 |
+
},
|
| 214 |
+
timeout=30
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
if response.status_code == 200:
|
| 218 |
+
result = response.json()
|
| 219 |
+
if result['status'] == 'success':
|
| 220 |
+
logging.info(f"Successfully submitted {result['received']} samples "
|
| 221 |
+
f"(Perturbation: {self.perturbation_type}, Severity: {self.severity_level})")
|
| 222 |
+
return True
|
| 223 |
+
|
| 224 |
+
logging.error(f"Submission failed: {response.text}")
|
| 225 |
+
return False
|
| 226 |
+
|
| 227 |
+
except Exception as e:
|
| 228 |
+
logging.error(f"Error submitting data: {e}")
|
| 229 |
+
return False
|
| 230 |
+
|
| 231 |
+
def run_data_generation(self, samples_per_batch: int = 50, interval: int = 300):
|
| 232 |
+
"""Continuously generate and submit augmented data"""
|
| 233 |
+
if not self.register_with_server():
|
| 234 |
+
return False
|
| 235 |
+
|
| 236 |
+
logging.info(f"Starting continuous data generation")
|
| 237 |
+
logging.info(f"Batch size: {samples_per_batch}, Interval: {interval}s")
|
| 238 |
+
logging.info(f"Perturbation: {self.perturbation_type}, Severity: {self.severity_level}")
|
| 239 |
+
|
| 240 |
+
batch_count = 0
|
| 241 |
+
while True:
|
| 242 |
+
try:
|
| 243 |
+
samples = self.generate_augmented_samples(samples_per_batch)
|
| 244 |
+
if samples:
|
| 245 |
+
success = self.submit_augmented_data(samples)
|
| 246 |
+
batch_count += 1
|
| 247 |
+
|
| 248 |
+
if success:
|
| 249 |
+
logging.info(f"Batch {batch_count} submitted successfully")
|
| 250 |
+
else:
|
| 251 |
+
logging.warning(f"Batch {batch_count} failed, will retry after interval")
|
| 252 |
+
|
| 253 |
+
time.sleep(interval)
|
| 254 |
+
|
| 255 |
+
except KeyboardInterrupt:
|
| 256 |
+
logging.info("Data generation stopped by user")
|
| 257 |
+
break
|
| 258 |
+
except Exception as e:
|
| 259 |
+
logging.error(f"Error in data generation loop: {e}")
|
| 260 |
+
time.sleep(interval)
|
| 261 |
+
|
| 262 |
+
# import requests
|
| 263 |
+
# import base64
|
| 264 |
+
# import io
|
| 265 |
+
# import numpy as np
|
| 266 |
+
# import torch
|
| 267 |
+
# from PIL import Image
|
| 268 |
+
# import json
|
| 269 |
+
# import time
|
| 270 |
+
# import logging
|
| 271 |
+
# from typing import List, Dict, Optional
|
| 272 |
+
# import os
|
| 273 |
+
# # Uses DataUtils.tensor_to_numpy() and DataUtils.create_sample()
|
| 274 |
+
# from utils.data_utils import DataUtils, FederatedDataConverter
|
| 275 |
+
# from augmentation_engine import PubLayNetAugmentationEngine # CHANGED
|
| 276 |
+
|
| 277 |
+
# class FederatedDataClient:
|
| 278 |
+
# def __init__(self, client_id: str, server_url: str, data_loader,
|
| 279 |
+
# perturbation_type: str = 'random', severity_level: int = 2): # CHANGED
|
| 280 |
+
# self.client_id = client_id
|
| 281 |
+
# self.server_url = server_url
|
| 282 |
+
# self.data_loader = data_loader
|
| 283 |
+
# self.perturbation_type = perturbation_type
|
| 284 |
+
# self.severity_level = severity_level
|
| 285 |
+
# self.augmentation_engine = PubLayNetAugmentationEngine(perturbation_type, severity_level) # CHANGED
|
| 286 |
+
# self.registered = False
|
| 287 |
+
|
| 288 |
+
# logging.basicConfig(level=logging.INFO)
|
| 289 |
+
|
| 290 |
+
# def register_with_server(self):
|
| 291 |
+
# """Register this client with the federated server"""
|
| 292 |
+
# try:
|
| 293 |
+
# client_info = {
|
| 294 |
+
# 'data_type': 'M6Doc',
|
| 295 |
+
# 'privacy_level': self.privacy_level,
|
| 296 |
+
# 'augmentation_capabilities': self.augmentation_engine.get_capabilities(),
|
| 297 |
+
# 'timestamp': time.time()
|
| 298 |
+
# }
|
| 299 |
+
|
| 300 |
+
# response = requests.post(
|
| 301 |
+
# f"{self.server_url}/register_client",
|
| 302 |
+
# json={
|
| 303 |
+
# 'client_id': self.client_id,
|
| 304 |
+
# 'client_info': client_info
|
| 305 |
+
# },
|
| 306 |
+
# timeout=10
|
| 307 |
+
# )
|
| 308 |
+
|
| 309 |
+
# if response.status_code == 200:
|
| 310 |
+
# data = response.json()
|
| 311 |
+
# if data['status'] == 'success':
|
| 312 |
+
# self.registered = True
|
| 313 |
+
# logging.info(f"Client {self.client_id} successfully registered")
|
| 314 |
+
# return True
|
| 315 |
+
|
| 316 |
+
# logging.error(f"Failed to register client: {response.text}")
|
| 317 |
+
# return False
|
| 318 |
+
|
| 319 |
+
# except Exception as e:
|
| 320 |
+
# logging.error(f"Registration failed: {e}")
|
| 321 |
+
# return False
|
| 322 |
+
|
| 323 |
+
# def generate_augmented_samples(self, num_samples: int = 50) -> List[Dict]:
|
| 324 |
+
# """Generate augmented samples using PubLayNet-P perturbations"""
|
| 325 |
+
# samples = []
|
| 326 |
+
# available_perturbations = self.augmentation_engine.get_available_perturbations()
|
| 327 |
+
|
| 328 |
+
# for i, batch in enumerate(self.data_loader):
|
| 329 |
+
# if len(samples) >= num_samples:
|
| 330 |
+
# break
|
| 331 |
+
|
| 332 |
+
# try:
|
| 333 |
+
# images = batch['img']
|
| 334 |
+
# img_metas = batch['img_metas']
|
| 335 |
+
|
| 336 |
+
# for j in range(len(images)):
|
| 337 |
+
# if len(samples) >= num_samples:
|
| 338 |
+
# break
|
| 339 |
+
|
| 340 |
+
# # Convert tensor to PIL Image
|
| 341 |
+
# img_tensor = images[j]
|
| 342 |
+
# img_np = self.tensor_to_numpy(img_tensor)
|
| 343 |
+
# pil_img = Image.fromarray(img_np)
|
| 344 |
+
|
| 345 |
+
# # Apply PubLayNet-P perturbation (CHANGED)
|
| 346 |
+
# if self.perturbation_type == 'all':
|
| 347 |
+
# # Cycle through all perturbation types
|
| 348 |
+
# pert_type = available_perturbations[i % len(available_perturbations)]
|
| 349 |
+
# else:
|
| 350 |
+
# pert_type = self.perturbation_type
|
| 351 |
+
|
| 352 |
+
# augmented_img, augmentation_info = self.augmentation_engine.augment_image(
|
| 353 |
+
# pil_img, pert_type
|
| 354 |
+
# )
|
| 355 |
+
|
| 356 |
+
# # Prepare annotations
|
| 357 |
+
# annotations = self.prepare_annotations(batch, j, augmentation_info)
|
| 358 |
+
|
| 359 |
+
# # Create sample
|
| 360 |
+
# sample = self.create_sample(augmented_img, annotations, augmentation_info)
|
| 361 |
+
# samples.append(sample)
|
| 362 |
+
|
| 363 |
+
# except Exception as e:
|
| 364 |
+
# logging.warning(f"Error processing batch {i}: {e}")
|
| 365 |
+
# continue
|
| 366 |
+
|
| 367 |
+
# logging.info(f"Generated {len(samples)} augmented samples using {self.perturbation_type}")
|
| 368 |
+
# return samples
|
| 369 |
+
|
| 370 |
+
# def tensor_to_numpy(self, tensor: torch.Tensor) -> np.ndarray:
|
| 371 |
+
# """Convert torch tensor to numpy array for image"""
|
| 372 |
+
# # Denormalize and convert
|
| 373 |
+
# img_np = tensor.cpu().numpy().transpose(1, 2, 0)
|
| 374 |
+
# img_np = (img_np * [58.395, 57.12, 57.375] + [123.675, 116.28, 103.53]).astype(np.uint8)
|
| 375 |
+
# return img_np
|
| 376 |
+
|
| 377 |
+
# def prepare_annotations(self, batch: Dict, index: int, aug_info: Dict) -> Dict:
|
| 378 |
+
# """Prepare annotations for a sample, adjusting for augmentations"""
|
| 379 |
+
# bboxes = batch['gt_bboxes'][index].cpu().numpy() if hasattr(batch['gt_bboxes'][index], 'cpu') else batch['gt_bboxes'][index]
|
| 380 |
+
# labels = batch['gt_labels'][index].cpu().numpy() if hasattr(batch['gt_labels'][index], 'cpu') else batch['gt_labels'][index]
|
| 381 |
+
|
| 382 |
+
# # Adjust bounding boxes for geometric transformations
|
| 383 |
+
# if 'geometric' in aug_info['applied_transforms']:
|
| 384 |
+
# bboxes = self.adjust_bboxes_for_augmentation(bboxes, aug_info)
|
| 385 |
+
|
| 386 |
+
# annotations = {
|
| 387 |
+
# 'bboxes': bboxes.tolist(),
|
| 388 |
+
# 'labels': labels.tolist(),
|
| 389 |
+
# 'image_size': aug_info['final_size'],
|
| 390 |
+
# 'original_size': aug_info['original_size']
|
| 391 |
+
# }
|
| 392 |
+
|
| 393 |
+
# return annotations
|
| 394 |
+
|
| 395 |
+
# def adjust_bboxes_for_augmentation(self, bboxes: np.ndarray, aug_info: Dict) -> np.ndarray:
|
| 396 |
+
# """Adjust bounding boxes for geometric augmentations"""
|
| 397 |
+
# # Simplified bbox adjustment
|
| 398 |
+
# # In practice, you'd use the exact transformation matrices
|
| 399 |
+
# scale_x = aug_info['final_size'][0] / aug_info['original_size'][0]
|
| 400 |
+
# scale_y = aug_info['final_size'][1] / aug_info['original_size'][1]
|
| 401 |
+
|
| 402 |
+
# adjusted_bboxes = bboxes.copy()
|
| 403 |
+
# adjusted_bboxes[:, 0] *= scale_x # x1
|
| 404 |
+
# adjusted_bboxes[:, 1] *= scale_y # y1
|
| 405 |
+
# adjusted_bboxes[:, 2] *= scale_x # x2
|
| 406 |
+
# adjusted_bboxes[:, 3] *= scale_y # y2
|
| 407 |
+
|
| 408 |
+
# return adjusted_bboxes
|
| 409 |
+
|
| 410 |
+
# def create_sample(self, image: Image.Image, annotations: Dict, aug_info: Dict) -> Dict:
|
| 411 |
+
# """Create a sample for sending to server"""
|
| 412 |
+
# # Convert image to base64
|
| 413 |
+
# buffered = io.BytesIO()
|
| 414 |
+
# image.save(buffered, format="JPEG", quality=85)
|
| 415 |
+
# img_str = base64.b64encode(buffered.getvalue()).decode()
|
| 416 |
+
|
| 417 |
+
# sample = {
|
| 418 |
+
# 'image_data': img_str,
|
| 419 |
+
# 'annotations': annotations,
|
| 420 |
+
# 'metadata': {
|
| 421 |
+
# 'client_id': self.client_id,
|
| 422 |
+
# 'augmentation_info': aug_info,
|
| 423 |
+
# 'timestamp': time.time(),
|
| 424 |
+
# 'privacy_level': self.privacy_level
|
| 425 |
+
# }
|
| 426 |
+
# }
|
| 427 |
+
|
| 428 |
+
# return sample
|
| 429 |
+
|
| 430 |
+
# def submit_augmented_data(self, samples: List[Dict]) -> bool:
|
| 431 |
+
# """Submit augmented samples to the server"""
|
| 432 |
+
# if not self.registered:
|
| 433 |
+
# logging.error("Client not registered with server")
|
| 434 |
+
# return False
|
| 435 |
+
|
| 436 |
+
# try:
|
| 437 |
+
# response = requests.post(
|
| 438 |
+
# f"{self.server_url}/submit_augmented_data",
|
| 439 |
+
# json={
|
| 440 |
+
# 'client_id': self.client_id,
|
| 441 |
+
# 'samples': samples
|
| 442 |
+
# },
|
| 443 |
+
# timeout=30
|
| 444 |
+
# )
|
| 445 |
+
|
| 446 |
+
# if response.status_code == 200:
|
| 447 |
+
# result = response.json()
|
| 448 |
+
# if result['status'] == 'success':
|
| 449 |
+
# logging.info(f"Successfully submitted {result['received']} samples")
|
| 450 |
+
# return True
|
| 451 |
+
|
| 452 |
+
# logging.error(f"Submission failed: {response.text}")
|
| 453 |
+
# return False
|
| 454 |
+
|
| 455 |
+
# except Exception as e:
|
| 456 |
+
# logging.error(f"Error submitting data: {e}")
|
| 457 |
+
# return False
|
| 458 |
+
|
| 459 |
+
# def run_data_generation(self, samples_per_batch: int = 50, interval: int = 300):
|
| 460 |
+
# """Continuously generate and submit augmented data"""
|
| 461 |
+
# if not self.register_with_server():
|
| 462 |
+
# return False
|
| 463 |
+
|
| 464 |
+
# logging.info(f"Starting continuous data generation (batch: {samples_per_batch}, interval: {interval}s)")
|
| 465 |
+
|
| 466 |
+
# while True:
|
| 467 |
+
# try:
|
| 468 |
+
# samples = self.generate_augmented_samples(samples_per_batch)
|
| 469 |
+
# if samples:
|
| 470 |
+
# success = self.submit_augmented_data(samples)
|
| 471 |
+
# if not success:
|
| 472 |
+
# logging.warning("Failed to submit batch, will retry after interval")
|
| 473 |
+
|
| 474 |
+
# time.sleep(interval)
|
| 475 |
+
|
| 476 |
+
# except KeyboardInterrupt:
|
| 477 |
+
# logging.info("Data generation stopped by user")
|
| 478 |
+
# break
|
| 479 |
+
# except Exception as e:
|
| 480 |
+
# logging.error(f"Error in data generation loop: {e}")
|
| 481 |
+
# time.sleep(interval) # Wait before retrying
|
{federated_rodla β federated_rodla_two/federated_rodla/federated_rodla}/federated/data_server.py
RENAMED
|
@@ -1,164 +1,164 @@
|
|
| 1 |
-
# federated/data_server.py
|
| 2 |
-
|
| 3 |
-
import flask
|
| 4 |
-
from flask import Flask, request, jsonify
|
| 5 |
-
import threading
|
| 6 |
-
import numpy as np
|
| 7 |
-
import json
|
| 8 |
-
import base64
|
| 9 |
-
import io
|
| 10 |
-
from PIL import Image
|
| 11 |
-
import cv2
|
| 12 |
-
import logging
|
| 13 |
-
from collections import defaultdict, deque
|
| 14 |
-
import time
|
| 15 |
-
# Uses DataUtils.process_sample() for validation
|
| 16 |
-
from utils.data_utils import DataUtils
|
| 17 |
-
|
| 18 |
-
class FederatedDataServer:
|
| 19 |
-
def __init__(self, max_clients=10, storage_path='./federated_data'):
|
| 20 |
-
self.app = Flask(__name__)
|
| 21 |
-
self.clients = {}
|
| 22 |
-
self.data_queue = deque()
|
| 23 |
-
self.lock = threading.Lock()
|
| 24 |
-
self.storage_path = storage_path
|
| 25 |
-
self.max_clients = max_clients
|
| 26 |
-
self.processed_samples = 0
|
| 27 |
-
|
| 28 |
-
# Create storage directory
|
| 29 |
-
import os
|
| 30 |
-
os.makedirs(storage_path, exist_ok=True)
|
| 31 |
-
|
| 32 |
-
self.setup_routes()
|
| 33 |
-
logging.basicConfig(level=logging.INFO)
|
| 34 |
-
|
| 35 |
-
def setup_routes(self):
|
| 36 |
-
@self.app.route('/register_client', methods=['POST'])
|
| 37 |
-
def register_client():
|
| 38 |
-
data = request.json
|
| 39 |
-
client_id = data['client_id']
|
| 40 |
-
client_info = data['client_info']
|
| 41 |
-
|
| 42 |
-
with self.lock:
|
| 43 |
-
if len(self.clients) >= self.max_clients:
|
| 44 |
-
return jsonify({'status': 'error', 'message': 'Server full'})
|
| 45 |
-
|
| 46 |
-
self.clients[client_id] = {
|
| 47 |
-
'info': client_info,
|
| 48 |
-
'last_seen': time.time(),
|
| 49 |
-
'samples_sent': 0
|
| 50 |
-
}
|
| 51 |
-
|
| 52 |
-
logging.info(f"Client {client_id} registered")
|
| 53 |
-
return jsonify({'status': 'success', 'client_id': client_id})
|
| 54 |
-
|
| 55 |
-
@self.app.route('/submit_augmented_data', methods=['POST'])
|
| 56 |
-
def submit_augmented_data():
|
| 57 |
-
try:
|
| 58 |
-
data = request.json
|
| 59 |
-
client_id = data['client_id']
|
| 60 |
-
samples = data['samples']
|
| 61 |
-
|
| 62 |
-
# Validate client
|
| 63 |
-
with self.lock:
|
| 64 |
-
if client_id not in self.clients:
|
| 65 |
-
return jsonify({'status': 'error', 'message': 'Client not registered'})
|
| 66 |
-
|
| 67 |
-
# Process each sample
|
| 68 |
-
processed_samples = []
|
| 69 |
-
for sample in samples:
|
| 70 |
-
processed_sample = self.process_sample(sample)
|
| 71 |
-
if processed_sample:
|
| 72 |
-
processed_samples.append(processed_sample)
|
| 73 |
-
|
| 74 |
-
# Add to training queue
|
| 75 |
-
with self.lock:
|
| 76 |
-
self.data_queue.extend(processed_samples)
|
| 77 |
-
self.clients[client_id]['samples_sent'] += len(processed_samples)
|
| 78 |
-
self.processed_samples += len(processed_samples)
|
| 79 |
-
|
| 80 |
-
logging.info(f"Received {len(processed_samples)} samples from {client_id}")
|
| 81 |
-
return jsonify({
|
| 82 |
-
'status': 'success',
|
| 83 |
-
'received': len(processed_samples),
|
| 84 |
-
'total_processed': self.processed_samples
|
| 85 |
-
})
|
| 86 |
-
|
| 87 |
-
except Exception as e:
|
| 88 |
-
logging.error(f"Error processing data: {e}")
|
| 89 |
-
return jsonify({'status': 'error', 'message': str(e)})
|
| 90 |
-
|
| 91 |
-
@self.app.route('/get_training_batch', methods=['GET'])
|
| 92 |
-
def get_training_batch():
|
| 93 |
-
batch_size = request.args.get('batch_size', 32, type=int)
|
| 94 |
-
|
| 95 |
-
with self.lock:
|
| 96 |
-
if len(self.data_queue) < batch_size:
|
| 97 |
-
return jsonify({'status': 'insufficient_data', 'available': len(self.data_queue)})
|
| 98 |
-
|
| 99 |
-
batch = []
|
| 100 |
-
for _ in range(batch_size):
|
| 101 |
-
if self.data_queue:
|
| 102 |
-
batch.append(self.data_queue.popleft())
|
| 103 |
-
|
| 104 |
-
logging.info(f"Sending batch of {len(batch)} samples for training")
|
| 105 |
-
return jsonify({
|
| 106 |
-
'status': 'success',
|
| 107 |
-
'batch': batch,
|
| 108 |
-
'batch_size': len(batch)
|
| 109 |
-
})
|
| 110 |
-
|
| 111 |
-
@self.app.route('/server_stats', methods=['GET'])
|
| 112 |
-
def server_stats():
|
| 113 |
-
with self.lock:
|
| 114 |
-
stats = {
|
| 115 |
-
'total_clients': len(self.clients),
|
| 116 |
-
'samples_in_queue': len(self.data_queue),
|
| 117 |
-
'total_processed_samples': self.processed_samples,
|
| 118 |
-
'clients': {
|
| 119 |
-
client_id: {
|
| 120 |
-
'samples_sent': info['samples_sent'],
|
| 121 |
-
'last_seen': info['last_seen']
|
| 122 |
-
}
|
| 123 |
-
for client_id, info in self.clients.items()
|
| 124 |
-
}
|
| 125 |
-
}
|
| 126 |
-
return jsonify(stats)
|
| 127 |
-
|
| 128 |
-
def process_sample(self, sample):
|
| 129 |
-
"""Process and validate a sample from client"""
|
| 130 |
-
try:
|
| 131 |
-
# Decode image
|
| 132 |
-
if 'image_data' in sample:
|
| 133 |
-
image_data = base64.b64decode(sample['image_data'])
|
| 134 |
-
image = Image.open(io.BytesIO(image_data))
|
| 135 |
-
|
| 136 |
-
# Convert to numpy array (for validation)
|
| 137 |
-
img_array = np.array(image)
|
| 138 |
-
|
| 139 |
-
# Basic validation
|
| 140 |
-
if img_array.size == 0:
|
| 141 |
-
return None
|
| 142 |
-
|
| 143 |
-
# Validate annotations
|
| 144 |
-
if 'annotations' not in sample:
|
| 145 |
-
return None
|
| 146 |
-
|
| 147 |
-
# Add metadata
|
| 148 |
-
sample['received_time'] = time.time()
|
| 149 |
-
sample['server_processed'] = True
|
| 150 |
-
|
| 151 |
-
return sample
|
| 152 |
-
|
| 153 |
-
except Exception as e:
|
| 154 |
-
logging.warning(f"Failed to process sample: {e}")
|
| 155 |
-
return None
|
| 156 |
-
|
| 157 |
-
def run(self, host='0.0.0.0', port=8080):
|
| 158 |
-
"""Start the federated data server"""
|
| 159 |
-
logging.info(f"Starting Federated Data Server on {host}:{port}")
|
| 160 |
-
self.app.run(host=host, port=port, threaded=True)
|
| 161 |
-
|
| 162 |
-
if __name__ == '__main__':
|
| 163 |
-
server = FederatedDataServer(max_clients=10)
|
| 164 |
server.run()
|
|
|
|
| 1 |
+
# federated/data_server.py
|
| 2 |
+
|
| 3 |
+
import flask
|
| 4 |
+
from flask import Flask, request, jsonify
|
| 5 |
+
import threading
|
| 6 |
+
import numpy as np
|
| 7 |
+
import json
|
| 8 |
+
import base64
|
| 9 |
+
import io
|
| 10 |
+
from PIL import Image
|
| 11 |
+
import cv2
|
| 12 |
+
import logging
|
| 13 |
+
from collections import defaultdict, deque
|
| 14 |
+
import time
|
| 15 |
+
# Uses DataUtils.process_sample() for validation
|
| 16 |
+
from utils.data_utils import DataUtils
|
| 17 |
+
|
| 18 |
+
class FederatedDataServer:
|
| 19 |
+
def __init__(self, max_clients=10, storage_path='./federated_data'):
|
| 20 |
+
self.app = Flask(__name__)
|
| 21 |
+
self.clients = {}
|
| 22 |
+
self.data_queue = deque()
|
| 23 |
+
self.lock = threading.Lock()
|
| 24 |
+
self.storage_path = storage_path
|
| 25 |
+
self.max_clients = max_clients
|
| 26 |
+
self.processed_samples = 0
|
| 27 |
+
|
| 28 |
+
# Create storage directory
|
| 29 |
+
import os
|
| 30 |
+
os.makedirs(storage_path, exist_ok=True)
|
| 31 |
+
|
| 32 |
+
self.setup_routes()
|
| 33 |
+
logging.basicConfig(level=logging.INFO)
|
| 34 |
+
|
| 35 |
+
def setup_routes(self):
|
| 36 |
+
@self.app.route('/register_client', methods=['POST'])
|
| 37 |
+
def register_client():
|
| 38 |
+
data = request.json
|
| 39 |
+
client_id = data['client_id']
|
| 40 |
+
client_info = data['client_info']
|
| 41 |
+
|
| 42 |
+
with self.lock:
|
| 43 |
+
if len(self.clients) >= self.max_clients:
|
| 44 |
+
return jsonify({'status': 'error', 'message': 'Server full'})
|
| 45 |
+
|
| 46 |
+
self.clients[client_id] = {
|
| 47 |
+
'info': client_info,
|
| 48 |
+
'last_seen': time.time(),
|
| 49 |
+
'samples_sent': 0
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
logging.info(f"Client {client_id} registered")
|
| 53 |
+
return jsonify({'status': 'success', 'client_id': client_id})
|
| 54 |
+
|
| 55 |
+
@self.app.route('/submit_augmented_data', methods=['POST'])
|
| 56 |
+
def submit_augmented_data():
|
| 57 |
+
try:
|
| 58 |
+
data = request.json
|
| 59 |
+
client_id = data['client_id']
|
| 60 |
+
samples = data['samples']
|
| 61 |
+
|
| 62 |
+
# Validate client
|
| 63 |
+
with self.lock:
|
| 64 |
+
if client_id not in self.clients:
|
| 65 |
+
return jsonify({'status': 'error', 'message': 'Client not registered'})
|
| 66 |
+
|
| 67 |
+
# Process each sample
|
| 68 |
+
processed_samples = []
|
| 69 |
+
for sample in samples:
|
| 70 |
+
processed_sample = self.process_sample(sample)
|
| 71 |
+
if processed_sample:
|
| 72 |
+
processed_samples.append(processed_sample)
|
| 73 |
+
|
| 74 |
+
# Add to training queue
|
| 75 |
+
with self.lock:
|
| 76 |
+
self.data_queue.extend(processed_samples)
|
| 77 |
+
self.clients[client_id]['samples_sent'] += len(processed_samples)
|
| 78 |
+
self.processed_samples += len(processed_samples)
|
| 79 |
+
|
| 80 |
+
logging.info(f"Received {len(processed_samples)} samples from {client_id}")
|
| 81 |
+
return jsonify({
|
| 82 |
+
'status': 'success',
|
| 83 |
+
'received': len(processed_samples),
|
| 84 |
+
'total_processed': self.processed_samples
|
| 85 |
+
})
|
| 86 |
+
|
| 87 |
+
except Exception as e:
|
| 88 |
+
logging.error(f"Error processing data: {e}")
|
| 89 |
+
return jsonify({'status': 'error', 'message': str(e)})
|
| 90 |
+
|
| 91 |
+
@self.app.route('/get_training_batch', methods=['GET'])
|
| 92 |
+
def get_training_batch():
|
| 93 |
+
batch_size = request.args.get('batch_size', 32, type=int)
|
| 94 |
+
|
| 95 |
+
with self.lock:
|
| 96 |
+
if len(self.data_queue) < batch_size:
|
| 97 |
+
return jsonify({'status': 'insufficient_data', 'available': len(self.data_queue)})
|
| 98 |
+
|
| 99 |
+
batch = []
|
| 100 |
+
for _ in range(batch_size):
|
| 101 |
+
if self.data_queue:
|
| 102 |
+
batch.append(self.data_queue.popleft())
|
| 103 |
+
|
| 104 |
+
logging.info(f"Sending batch of {len(batch)} samples for training")
|
| 105 |
+
return jsonify({
|
| 106 |
+
'status': 'success',
|
| 107 |
+
'batch': batch,
|
| 108 |
+
'batch_size': len(batch)
|
| 109 |
+
})
|
| 110 |
+
|
| 111 |
+
@self.app.route('/server_stats', methods=['GET'])
|
| 112 |
+
def server_stats():
|
| 113 |
+
with self.lock:
|
| 114 |
+
stats = {
|
| 115 |
+
'total_clients': len(self.clients),
|
| 116 |
+
'samples_in_queue': len(self.data_queue),
|
| 117 |
+
'total_processed_samples': self.processed_samples,
|
| 118 |
+
'clients': {
|
| 119 |
+
client_id: {
|
| 120 |
+
'samples_sent': info['samples_sent'],
|
| 121 |
+
'last_seen': info['last_seen']
|
| 122 |
+
}
|
| 123 |
+
for client_id, info in self.clients.items()
|
| 124 |
+
}
|
| 125 |
+
}
|
| 126 |
+
return jsonify(stats)
|
| 127 |
+
|
| 128 |
+
def process_sample(self, sample):
|
| 129 |
+
"""Process and validate a sample from client"""
|
| 130 |
+
try:
|
| 131 |
+
# Decode image
|
| 132 |
+
if 'image_data' in sample:
|
| 133 |
+
image_data = base64.b64decode(sample['image_data'])
|
| 134 |
+
image = Image.open(io.BytesIO(image_data))
|
| 135 |
+
|
| 136 |
+
# Convert to numpy array (for validation)
|
| 137 |
+
img_array = np.array(image)
|
| 138 |
+
|
| 139 |
+
# Basic validation
|
| 140 |
+
if img_array.size == 0:
|
| 141 |
+
return None
|
| 142 |
+
|
| 143 |
+
# Validate annotations
|
| 144 |
+
if 'annotations' not in sample:
|
| 145 |
+
return None
|
| 146 |
+
|
| 147 |
+
# Add metadata
|
| 148 |
+
sample['received_time'] = time.time()
|
| 149 |
+
sample['server_processed'] = True
|
| 150 |
+
|
| 151 |
+
return sample
|
| 152 |
+
|
| 153 |
+
except Exception as e:
|
| 154 |
+
logging.warning(f"Failed to process sample: {e}")
|
| 155 |
+
return None
|
| 156 |
+
|
| 157 |
+
def run(self, host='0.0.0.0', port=8080):
|
| 158 |
+
"""Start the federated data server"""
|
| 159 |
+
logging.info(f"Starting Federated Data Server on {host}:{port}")
|
| 160 |
+
self.app.run(host=host, port=port, threaded=True)
|
| 161 |
+
|
| 162 |
+
if __name__ == '__main__':
|
| 163 |
+
server = FederatedDataServer(max_clients=10)
|
| 164 |
server.run()
|
federated_rodla_two/federated_rodla/federated_rodla/federated/perturbation_engine.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# federated/perturbation_engine.py
|
| 2 |
+
import numpy as np
|
| 3 |
+
from PIL import Image, ImageFilter, ImageEnhance
|
| 4 |
+
import cv2
|
| 5 |
+
import random
|
| 6 |
+
from typing import Dict, Tuple, List
|
| 7 |
+
|
| 8 |
+
class PubLayNetPerturbationEngine:
|
| 9 |
+
"""
|
| 10 |
+
Perturbations used for inference-time robustness evaluation.
|
| 11 |
+
Returns PIL.Image in RGB mode and a small aug_info dict describing what was applied.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
def __init__(self, perturbation_type: str = 'random', severity_level: int = 2):
|
| 15 |
+
self.perturbation_type = perturbation_type
|
| 16 |
+
self.severity_level = severity_level # 1,2,3
|
| 17 |
+
self.perturbation_functions = {
|
| 18 |
+
'background': self.apply_background,
|
| 19 |
+
'defocus': self.apply_defocus,
|
| 20 |
+
'illumination': self.apply_illumination,
|
| 21 |
+
'ink_bleeding': self.apply_ink_bleeding,
|
| 22 |
+
'ink_holdout': self.apply_ink_holdout,
|
| 23 |
+
'keystoning': self.apply_keystoning,
|
| 24 |
+
'rotation': self.apply_rotation,
|
| 25 |
+
'speckle': self.apply_speckle,
|
| 26 |
+
'texture': self.apply_texture,
|
| 27 |
+
'vibration': self.apply_vibration,
|
| 28 |
+
'warping': self.apply_warping,
|
| 29 |
+
'watermark': self.apply_watermark
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
def get_available_perturbations(self) -> List[str]:
|
| 33 |
+
return list(self.perturbation_functions.keys())
|
| 34 |
+
|
| 35 |
+
def perturb(self, image: Image.Image, perturbation_type: str = None) -> Tuple[Image.Image, Dict]:
|
| 36 |
+
"""Apply the chosen perturbation and return (image, info)."""
|
| 37 |
+
if image.mode != 'RGB':
|
| 38 |
+
image = image.convert('RGB')
|
| 39 |
+
|
| 40 |
+
if perturbation_type is None:
|
| 41 |
+
perturbation_type = self.perturbation_type
|
| 42 |
+
|
| 43 |
+
if perturbation_type == 'random':
|
| 44 |
+
perturbation_type = random.choice(self.get_available_perturbations())
|
| 45 |
+
|
| 46 |
+
info = {
|
| 47 |
+
'perturbation_type': perturbation_type,
|
| 48 |
+
'severity_level': self.severity_level,
|
| 49 |
+
'parameters': {}
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
func = self.perturbation_functions.get(perturbation_type, None)
|
| 53 |
+
if func is None:
|
| 54 |
+
return image, info
|
| 55 |
+
|
| 56 |
+
out = func(image)
|
| 57 |
+
if not isinstance(out, Image.Image):
|
| 58 |
+
out = Image.fromarray(np.uint8(out))
|
| 59 |
+
if out.mode != 'RGB':
|
| 60 |
+
out = out.convert('RGB')
|
| 61 |
+
|
| 62 |
+
info['final_size'] = out.size
|
| 63 |
+
return out, info
|
| 64 |
+
|
| 65 |
+
def apply_background(self, image: Image.Image) -> Image.Image:
|
| 66 |
+
severity = {1: (10, 0.1), 2: (25, 0.3), 3: (50, 0.6)}[self.severity_level]
|
| 67 |
+
color_var, tex_strength = severity
|
| 68 |
+
img = np.array(image).astype(np.int16)
|
| 69 |
+
shift = np.random.randint(-color_var, color_var + 1, 3)
|
| 70 |
+
img = np.clip(img + shift, 0, 255).astype(np.uint8)
|
| 71 |
+
|
| 72 |
+
if tex_strength > 0:
|
| 73 |
+
noise = np.random.normal(0, tex_strength * 255, img.shape)
|
| 74 |
+
img = np.clip(img.astype(np.int16) + noise.astype(np.int16), 0, 255).astype(np.uint8)
|
| 75 |
+
|
| 76 |
+
return Image.fromarray(img)
|
| 77 |
+
|
| 78 |
+
def apply_defocus(self, image: Image.Image) -> Image.Image:
|
| 79 |
+
radius = {1: 1.0, 2: 2.0, 3: 4.0}[self.severity_level]
|
| 80 |
+
return image.filter(ImageFilter.GaussianBlur(radius=radius))
|
| 81 |
+
|
| 82 |
+
def apply_illumination(self, image: Image.Image) -> Image.Image:
|
| 83 |
+
params = {1: (0.9, 0.9), 2: (0.7, 0.7), 3: (0.5, 0.5)}[self.severity_level]
|
| 84 |
+
img = ImageEnhance.Brightness(image).enhance(params[0])
|
| 85 |
+
img = ImageEnhance.Contrast(img).enhance(params[1])
|
| 86 |
+
return img
|
| 87 |
+
|
| 88 |
+
def apply_ink_bleeding(self, image: Image.Image) -> Image.Image:
|
| 89 |
+
img = np.array(image)
|
| 90 |
+
h, w = img.shape[:2]
|
| 91 |
+
strength = {1: 0.1, 2: 0.2, 3: 0.4}[self.severity_level]
|
| 92 |
+
kernel_size = max(1, int(max(h, w) * 0.01 * strength * 10))
|
| 93 |
+
if kernel_size % 2 == 0:
|
| 94 |
+
kernel_size += 1
|
| 95 |
+
kernel = np.ones((kernel_size, kernel_size), dtype=np.float32) / (kernel_size * kernel_size)
|
| 96 |
+
out = np.empty_like(img)
|
| 97 |
+
for c in range(img.shape[2]):
|
| 98 |
+
out[:, :, c] = cv2.filter2D(img[:, :, c], -1, kernel)
|
| 99 |
+
return Image.fromarray(out)
|
| 100 |
+
|
| 101 |
+
def apply_ink_holdout(self, image: Image.Image) -> Image.Image:
|
| 102 |
+
img = np.array(image)
|
| 103 |
+
dropout = {1: 0.05, 2: 0.1, 3: 0.2}[self.severity_level]
|
| 104 |
+
mask = np.random.random(img.shape[:2]) < dropout
|
| 105 |
+
for c in range(img.shape[2]):
|
| 106 |
+
img[:, :, c][mask] = 255
|
| 107 |
+
return Image.fromarray(img)
|
| 108 |
+
|
| 109 |
+
def apply_keystoning(self, image: Image.Image) -> Image.Image:
|
| 110 |
+
w, h = image.size
|
| 111 |
+
distortion = {1: 0.05, 2: 0.1, 3: 0.15}[self.severity_level]
|
| 112 |
+
src = np.float32([[0, 0], [w, 0], [w, h], [0, h]])
|
| 113 |
+
shift_x, shift_y = int(w * distortion), int(h * distortion)
|
| 114 |
+
dst = np.float32([
|
| 115 |
+
[0 + shift_x, 0 + int(shift_y * 0.2)],
|
| 116 |
+
[w - shift_x, 0 + int(shift_y * 0.1)],
|
| 117 |
+
[w - int(shift_x * 0.8), h - shift_y],
|
| 118 |
+
[int(shift_x * 0.2), h - int(shift_y * 0.8)]
|
| 119 |
+
])
|
| 120 |
+
M = cv2.getPerspectiveTransform(src, dst)
|
| 121 |
+
arr = np.array(image)
|
| 122 |
+
warped = cv2.warpPerspective(arr, M, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE)
|
| 123 |
+
return Image.fromarray(warped)
|
| 124 |
+
|
| 125 |
+
def apply_rotation(self, image: Image.Image) -> Image.Image:
|
| 126 |
+
angle = {1: 2, 2: 5, 3: 10}[self.severity_level] * random.choice([-1, 1])
|
| 127 |
+
return image.rotate(angle, resample=Image.BILINEAR, expand=False)
|
| 128 |
+
|
| 129 |
+
def apply_speckle(self, image: Image.Image) -> Image.Image:
|
| 130 |
+
lvl = {1: 0.05, 2: 0.1, 3: 0.2}[self.severity_level]
|
| 131 |
+
arr = np.array(image).astype(np.float32) / 255.0
|
| 132 |
+
noise = np.random.normal(0, lvl, arr.shape).astype(np.float32)
|
| 133 |
+
out = np.clip(arr + arr * noise, 0, 1) * 255
|
| 134 |
+
return Image.fromarray(out.astype(np.uint8))
|
| 135 |
+
|
| 136 |
+
def apply_texture(self, image: Image.Image) -> Image.Image:
|
| 137 |
+
opacity = {1: 0.1, 2: 0.25, 3: 0.4}[self.severity_level]
|
| 138 |
+
w, h = image.size
|
| 139 |
+
texture = np.random.randint(0, 50, (h, w, 3), dtype=np.uint8)
|
| 140 |
+
texture_img = Image.fromarray(texture).convert('RGB').resize((w, h))
|
| 141 |
+
return Image.blend(image, texture_img, opacity)
|
| 142 |
+
|
| 143 |
+
def apply_vibration(self, image: Image.Image) -> Image.Image:
|
| 144 |
+
kernel_size = {1: 3, 2: 5, 3: 8}[self.severity_level]
|
| 145 |
+
arr = np.array(image).astype(np.float32)
|
| 146 |
+
kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)
|
| 147 |
+
kernel[int((kernel_size - 1) / 2), :] = np.ones(kernel_size, dtype=np.float32)
|
| 148 |
+
kernel = kernel / kernel_size
|
| 149 |
+
blurred = cv2.filter2D(arr, -1, kernel)
|
| 150 |
+
return Image.fromarray(np.clip(blurred, 0, 255).astype(np.uint8))
|
| 151 |
+
|
| 152 |
+
def apply_warping(self, image: Image.Image) -> Image.Image:
|
| 153 |
+
magnitude = {1: 5, 2: 10, 3: 20}[self.severity_level]
|
| 154 |
+
w, h = image.size
|
| 155 |
+
arr = np.array(image)
|
| 156 |
+
x, y = np.meshgrid(np.arange(w), np.arange(h))
|
| 157 |
+
dx = magnitude * np.sin(2 * np.pi * y / max(1, (h / 4.0)))
|
| 158 |
+
dy = magnitude * np.cos(2 * np.pi * x / max(1, (w / 4.0)))
|
| 159 |
+
map_x = (x + dx).astype(np.float32)
|
| 160 |
+
map_y = (y + dy).astype(np.float32)
|
| 161 |
+
warped = cv2.remap(arr, map_x, map_y, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE)
|
| 162 |
+
return Image.fromarray(warped)
|
| 163 |
+
|
| 164 |
+
def apply_watermark(self, image: Image.Image) -> Image.Image:
|
| 165 |
+
w, h = image.size
|
| 166 |
+
opacity = {1: 0.1, 2: 0.2, 3: 0.3}[self.severity_level]
|
| 167 |
+
watermark = Image.new('RGBA', (w, h), (0, 0, 0, 0))
|
| 168 |
+
from PIL import ImageDraw, ImageFont
|
| 169 |
+
draw = ImageDraw.Draw(watermark)
|
| 170 |
+
try:
|
| 171 |
+
font = ImageFont.truetype("arial.ttf", max(12, min(w, h) // 12))
|
| 172 |
+
except Exception:
|
| 173 |
+
font = ImageFont.load_default()
|
| 174 |
+
text = "CONFIDENTIAL"
|
| 175 |
+
for i in range(3):
|
| 176 |
+
x = int((w - 10) * (i / 2.0))
|
| 177 |
+
y = int((h - 10) * (i / 2.0))
|
| 178 |
+
draw.text((x, y), text, font=font, fill=(255, 255, 255, int(255 * opacity)))
|
| 179 |
+
base = image.convert('RGBA')
|
| 180 |
+
comp = Image.alpha_composite(base, watermark)
|
| 181 |
+
return comp.convert('RGB')
|
{federated_rodla β federated_rodla_two/federated_rodla/federated_rodla}/federated/privacy_utils.py
RENAMED
|
File without changes
|
federated_rodla_two/federated_rodla/federated_rodla/federated/training_server.py
ADDED
|
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# federated/training_server.py
|
| 2 |
+
|
| 3 |
+
import flask
|
| 4 |
+
from flask import Flask, request, jsonify
|
| 5 |
+
import threading
|
| 6 |
+
import numpy as np
|
| 7 |
+
import json
|
| 8 |
+
import base64
|
| 9 |
+
import io
|
| 10 |
+
from PIL import Image
|
| 11 |
+
import cv2
|
| 12 |
+
import logging
|
| 13 |
+
from collections import defaultdict, deque
|
| 14 |
+
import time
|
| 15 |
+
import torch
|
| 16 |
+
import subprocess
|
| 17 |
+
import os
|
| 18 |
+
from utils.data_utils import DataUtils, FederatedDataConverter
|
| 19 |
+
|
| 20 |
+
class FederatedTrainingServer:
|
| 21 |
+
def __init__(self, max_clients=10, storage_path='./federated_data',
|
| 22 |
+
rodla_config_path='configs/publaynet/rodla_internimage_xl_publaynet.py',
|
| 23 |
+
model_checkpoint=None):
|
| 24 |
+
self.app = Flask(__name__)
|
| 25 |
+
self.clients = {}
|
| 26 |
+
self.data_queue = deque()
|
| 27 |
+
self.training_data = [] # Store data for training
|
| 28 |
+
self.lock = threading.Lock()
|
| 29 |
+
self.storage_path = storage_path
|
| 30 |
+
self.max_clients = max_clients
|
| 31 |
+
self.processed_samples = 0
|
| 32 |
+
self.rodla_config_path = rodla_config_path
|
| 33 |
+
self.model_checkpoint = model_checkpoint
|
| 34 |
+
self.is_training = False
|
| 35 |
+
self.training_process = None
|
| 36 |
+
|
| 37 |
+
# Create storage directory
|
| 38 |
+
os.makedirs(storage_path, exist_ok=True)
|
| 39 |
+
os.makedirs('./federated_training_data', exist_ok=True)
|
| 40 |
+
|
| 41 |
+
self.setup_routes()
|
| 42 |
+
logging.basicConfig(level=logging.INFO)
|
| 43 |
+
|
| 44 |
+
# Start training monitor thread
|
| 45 |
+
self.training_thread = threading.Thread(target=self._training_monitor, daemon=True)
|
| 46 |
+
self.training_thread.start()
|
| 47 |
+
|
| 48 |
+
def setup_routes(self):
|
| 49 |
+
# ... (keep all existing routes: register_client, submit_augmented_data, etc.)
|
| 50 |
+
|
| 51 |
+
@self.app.route('/start_training', methods=['POST'])
|
| 52 |
+
def start_training():
|
| 53 |
+
"""Start RoDLA training with federated data"""
|
| 54 |
+
with self.lock:
|
| 55 |
+
if self.is_training:
|
| 56 |
+
return jsonify({'status': 'error', 'message': 'Training already in progress'})
|
| 57 |
+
|
| 58 |
+
if len(self.training_data) < 100: # Minimum samples to start training
|
| 59 |
+
return jsonify({'status': 'error', 'message': f'Insufficient data: {len(self.training_data)} samples'})
|
| 60 |
+
|
| 61 |
+
# Start training in separate thread
|
| 62 |
+
training_thread = threading.Thread(target=self._start_rodla_training)
|
| 63 |
+
training_thread.start()
|
| 64 |
+
|
| 65 |
+
return jsonify({
|
| 66 |
+
'status': 'success',
|
| 67 |
+
'message': 'Training started',
|
| 68 |
+
'training_samples': len(self.training_data)
|
| 69 |
+
})
|
| 70 |
+
|
| 71 |
+
@self.app.route('/training_status', methods=['GET'])
|
| 72 |
+
def training_status():
|
| 73 |
+
"""Get current training status"""
|
| 74 |
+
return jsonify({
|
| 75 |
+
'is_training': self.is_training,
|
| 76 |
+
'training_samples': len(self.training_data),
|
| 77 |
+
'total_clients': len(self.clients),
|
| 78 |
+
'total_processed': self.processed_samples
|
| 79 |
+
})
|
| 80 |
+
|
| 81 |
+
def process_sample(self, sample):
|
| 82 |
+
"""Process and validate a sample from client - UPDATED to store for training"""
|
| 83 |
+
try:
|
| 84 |
+
# Decode image
|
| 85 |
+
if 'image_data' in sample:
|
| 86 |
+
image_data = base64.b64decode(sample['image_data'])
|
| 87 |
+
image = Image.open(io.BytesIO(image_data))
|
| 88 |
+
|
| 89 |
+
# Convert to numpy array (for validation)
|
| 90 |
+
img_array = np.array(image)
|
| 91 |
+
|
| 92 |
+
# Basic validation
|
| 93 |
+
if img_array.size == 0:
|
| 94 |
+
return None
|
| 95 |
+
|
| 96 |
+
# Validate annotations
|
| 97 |
+
if 'annotations' not in sample:
|
| 98 |
+
return None
|
| 99 |
+
|
| 100 |
+
# Store sample for training
|
| 101 |
+
with self.lock:
|
| 102 |
+
self.training_data.append(sample)
|
| 103 |
+
|
| 104 |
+
# Limit training data size to prevent memory issues
|
| 105 |
+
if len(self.training_data) > 10000:
|
| 106 |
+
self.training_data = self.training_data[-10000:]
|
| 107 |
+
|
| 108 |
+
# Add metadata
|
| 109 |
+
sample['received_time'] = time.time()
|
| 110 |
+
sample['server_processed'] = True
|
| 111 |
+
|
| 112 |
+
return sample
|
| 113 |
+
|
| 114 |
+
except Exception as e:
|
| 115 |
+
logging.warning(f"Failed to process sample: {e}")
|
| 116 |
+
return None
|
| 117 |
+
|
| 118 |
+
def _start_rodla_training(self):
|
| 119 |
+
"""Start RoDLA training with federated data"""
|
| 120 |
+
try:
|
| 121 |
+
self.is_training = True
|
| 122 |
+
logging.info("Starting RoDLA training with federated data...")
|
| 123 |
+
|
| 124 |
+
# Convert federated data to RoDLA training format
|
| 125 |
+
training_dataset = self._prepare_training_dataset()
|
| 126 |
+
|
| 127 |
+
# Save training dataset
|
| 128 |
+
dataset_path = self._save_training_dataset(training_dataset)
|
| 129 |
+
|
| 130 |
+
# Start RoDLA training process
|
| 131 |
+
self._run_rodla_training(dataset_path)
|
| 132 |
+
|
| 133 |
+
except Exception as e:
|
| 134 |
+
logging.error(f"Training failed: {e}")
|
| 135 |
+
finally:
|
| 136 |
+
self.is_training = False
|
| 137 |
+
|
| 138 |
+
def _prepare_training_dataset(self):
|
| 139 |
+
"""Convert federated samples to RoDLA training format"""
|
| 140 |
+
training_samples = []
|
| 141 |
+
|
| 142 |
+
for sample in self.training_data:
|
| 143 |
+
try:
|
| 144 |
+
# Convert federated format to RoDLA format
|
| 145 |
+
rodla_sample = FederatedDataConverter.federated_to_rodla(sample)
|
| 146 |
+
training_samples.append(rodla_sample)
|
| 147 |
+
except Exception as e:
|
| 148 |
+
logging.warning(f"Failed to convert sample: {e}")
|
| 149 |
+
continue
|
| 150 |
+
|
| 151 |
+
logging.info(f"Prepared {len(training_samples)} samples for training")
|
| 152 |
+
return training_samples
|
| 153 |
+
|
| 154 |
+
def _save_training_dataset(self, training_dataset):
|
| 155 |
+
"""Save training dataset to disk in COCO format"""
|
| 156 |
+
dataset_dir = './federated_training_data'
|
| 157 |
+
os.makedirs(dataset_dir, exist_ok=True)
|
| 158 |
+
|
| 159 |
+
# Save images
|
| 160 |
+
images_dir = os.path.join(dataset_dir, 'images')
|
| 161 |
+
os.makedirs(images_dir, exist_ok=True)
|
| 162 |
+
|
| 163 |
+
annotations = {
|
| 164 |
+
'images': [],
|
| 165 |
+
'annotations': [],
|
| 166 |
+
'categories': [
|
| 167 |
+
{'id': 1, 'name': 'text'},
|
| 168 |
+
{'id': 2, 'name': 'title'},
|
| 169 |
+
{'id': 3, 'name': 'list'},
|
| 170 |
+
{'id': 4, 'name': 'table'},
|
| 171 |
+
{'id': 5, 'name': 'figure'}
|
| 172 |
+
]
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
annotation_id = 1
|
| 176 |
+
|
| 177 |
+
for i, sample in enumerate(training_dataset):
|
| 178 |
+
# Save image
|
| 179 |
+
img_tensor = sample['img']
|
| 180 |
+
img_np = (img_tensor * torch.tensor([58.395, 57.12, 57.375]).view(3, 1, 1) +
|
| 181 |
+
torch.tensor([123.675, 116.28, 103.53]).view(3, 1, 1))
|
| 182 |
+
img_np = img_np.numpy().transpose(1, 2, 0).astype(np.uint8)
|
| 183 |
+
img_pil = Image.fromarray(img_np)
|
| 184 |
+
|
| 185 |
+
img_filename = f"federated_{i:06d}.jpg"
|
| 186 |
+
img_path = os.path.join(images_dir, img_filename)
|
| 187 |
+
img_pil.save(img_path)
|
| 188 |
+
|
| 189 |
+
# Add image info
|
| 190 |
+
img_info = {
|
| 191 |
+
'id': i,
|
| 192 |
+
'file_name': img_filename,
|
| 193 |
+
'width': img_np.shape[1],
|
| 194 |
+
'height': img_np.shape[0]
|
| 195 |
+
}
|
| 196 |
+
annotations['images'].append(img_info)
|
| 197 |
+
|
| 198 |
+
# Add annotations
|
| 199 |
+
bboxes = sample['gt_bboxes']
|
| 200 |
+
labels = sample['gt_labels']
|
| 201 |
+
|
| 202 |
+
for bbox, label in zip(bboxes, labels):
|
| 203 |
+
x1, y1, x2, y2 = bbox.tolist()
|
| 204 |
+
annotation = {
|
| 205 |
+
'id': annotation_id,
|
| 206 |
+
'image_id': i,
|
| 207 |
+
'category_id': label.item(),
|
| 208 |
+
'bbox': [x1, y1, x2 - x1, y2 - y1], # COCO format: [x, y, width, height]
|
| 209 |
+
'area': (x2 - x1) * (y2 - y1),
|
| 210 |
+
'iscrowd': 0
|
| 211 |
+
}
|
| 212 |
+
annotations['annotations'].append(annotation)
|
| 213 |
+
annotation_id += 1
|
| 214 |
+
|
| 215 |
+
# Save annotations
|
| 216 |
+
annotations_path = os.path.join(dataset_dir, 'annotations.json')
|
| 217 |
+
with open(annotations_path, 'w') as f:
|
| 218 |
+
json.dump(annotations, f, indent=2)
|
| 219 |
+
|
| 220 |
+
logging.info(f"Saved training dataset: {len(annotations['images'])} images, "
|
| 221 |
+
f"{len(annotations['annotations'])} annotations")
|
| 222 |
+
|
| 223 |
+
return dataset_dir
|
| 224 |
+
|
| 225 |
+
def _run_rodla_training(self, dataset_path):
|
| 226 |
+
"""Run actual RoDLA training using the provided dataset"""
|
| 227 |
+
try:
|
| 228 |
+
# Create modified config for federated training
|
| 229 |
+
config_content = self._create_federated_config(dataset_path)
|
| 230 |
+
config_path = './configs/federated/rodla_federated_publaynet.py'
|
| 231 |
+
os.makedirs(os.path.dirname(config_path), exist_ok=True)
|
| 232 |
+
|
| 233 |
+
with open(config_path, 'w') as f:
|
| 234 |
+
f.write(config_content)
|
| 235 |
+
|
| 236 |
+
# Run RoDLA training command (from their GitHub)
|
| 237 |
+
cmd = [
|
| 238 |
+
'python', 'model/train.py',
|
| 239 |
+
config_path,
|
| 240 |
+
'--work-dir', './work_dirs/federated_rodla',
|
| 241 |
+
'--auto-resume'
|
| 242 |
+
]
|
| 243 |
+
|
| 244 |
+
if self.model_checkpoint:
|
| 245 |
+
cmd.extend(['--resume-from', self.model_checkpoint])
|
| 246 |
+
|
| 247 |
+
logging.info(f"Starting RoDLA training: {' '.join(cmd)}")
|
| 248 |
+
|
| 249 |
+
# Run training process
|
| 250 |
+
self.training_process = subprocess.Popen(
|
| 251 |
+
cmd,
|
| 252 |
+
cwd='.', # Assuming we're in RoDLA root directory
|
| 253 |
+
stdout=subprocess.PIPE,
|
| 254 |
+
stderr=subprocess.STDOUT,
|
| 255 |
+
universal_newlines=True
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
# Log training output
|
| 259 |
+
for line in iter(self.training_process.stdout.readline, ''):
|
| 260 |
+
logging.info(f"TRAINING: {line.strip()}")
|
| 261 |
+
|
| 262 |
+
self.training_process.wait()
|
| 263 |
+
|
| 264 |
+
if self.training_process.returncode == 0:
|
| 265 |
+
logging.info("RoDLA training completed successfully!")
|
| 266 |
+
else:
|
| 267 |
+
logging.error(f"RoDLA training failed with code {self.training_process.returncode}")
|
| 268 |
+
|
| 269 |
+
except Exception as e:
|
| 270 |
+
logging.error(f"Error running RoDLA training: {e}")
|
| 271 |
+
|
| 272 |
+
def _create_federated_config(self, dataset_path):
|
| 273 |
+
"""Create modified RoDLA config for federated training"""
|
| 274 |
+
base_config = f'''
|
| 275 |
+
_base_ = '../publaynet/rodla_internimage_xl_publaynet.py'
|
| 276 |
+
|
| 277 |
+
# Federated training settings
|
| 278 |
+
data = dict(
|
| 279 |
+
samples_per_gpu=2,
|
| 280 |
+
workers_per_gpu=2,
|
| 281 |
+
train=dict(
|
| 282 |
+
ann_file='{dataset_path}/annotations.json',
|
| 283 |
+
img_prefix='{dataset_path}/images/',
|
| 284 |
+
),
|
| 285 |
+
val=dict(
|
| 286 |
+
ann_file='{dataset_path}/annotations.json', # Using same data for val during federated training
|
| 287 |
+
img_prefix='{dataset_path}/images/',
|
| 288 |
+
)
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
# Training schedule for federated learning
|
| 292 |
+
runner = dict(max_epochs=12) # Shorter epochs for frequent updates
|
| 293 |
+
lr_config = dict(
|
| 294 |
+
policy='step',
|
| 295 |
+
warmup='linear',
|
| 296 |
+
warmup_iters=500,
|
| 297 |
+
warmup_ratio=0.001,
|
| 298 |
+
step=[8, 11]
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
# Logging
|
| 302 |
+
log_config = dict(
|
| 303 |
+
interval=10,
|
| 304 |
+
hooks=[
|
| 305 |
+
dict(type='TextLoggerHook'),
|
| 306 |
+
dict(type='TensorboardLoggerHook')
|
| 307 |
+
]
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
# Evaluation
|
| 311 |
+
evaluation = dict(interval=1, metric=['bbox', 'segm'])
|
| 312 |
+
checkpoint_config = dict(interval=1)
|
| 313 |
+
'''
|
| 314 |
+
return base_config
|
| 315 |
+
|
| 316 |
+
def _training_monitor(self):
|
| 317 |
+
"""Monitor training process"""
|
| 318 |
+
while True:
|
| 319 |
+
if self.training_process and self.training_process.poll() is not None:
|
| 320 |
+
self.is_training = False
|
| 321 |
+
self.training_process = None
|
| 322 |
+
logging.info("Training process finished")
|
| 323 |
+
|
| 324 |
+
time.sleep(10)
|
| 325 |
+
|
| 326 |
+
if __name__ == '__main__':
|
| 327 |
+
server = FederatedTrainingServer(
|
| 328 |
+
rodla_config_path='configs/publaynet/rodla_internimage_xl_publaynet.py',
|
| 329 |
+
model_checkpoint='checkpoints/rodla_internimage_xl_publaynet.pth' # if available
|
| 330 |
+
)
|
| 331 |
+
server.run()
|
federated_rodla_two/federated_rodla/federated_rodla/scripts/start_data_client.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# scripts/start_data_client.py
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import sys
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 8 |
+
|
| 9 |
+
from federated.data_client import FederatedDataClient
|
| 10 |
+
import torch
|
| 11 |
+
from torch.utils.data import DataLoader, Dataset
|
| 12 |
+
from mmdet.datasets import build_dataset, build_dataloader
|
| 13 |
+
from mmcv import Config
|
| 14 |
+
import json
|
| 15 |
+
from PIL import Image
|
| 16 |
+
import numpy as np
|
| 17 |
+
|
| 18 |
+
class PubLayNetDataset(Dataset):
|
| 19 |
+
"""Actual PubLayNet dataset loader for federated client"""
|
| 20 |
+
|
| 21 |
+
def __init__(self, data_root, annotation_file, split='train', max_samples=1000):
|
| 22 |
+
self.data_root = data_root
|
| 23 |
+
self.split = split
|
| 24 |
+
self.max_samples = max_samples
|
| 25 |
+
|
| 26 |
+
# Load annotations
|
| 27 |
+
with open(annotation_file, 'r') as f:
|
| 28 |
+
self.annotations = json.load(f)
|
| 29 |
+
|
| 30 |
+
# Filter images for the specified split
|
| 31 |
+
self.images = [img for img in self.annotations['images']
|
| 32 |
+
if img['file_name'].startswith(split)]
|
| 33 |
+
|
| 34 |
+
# Limit samples if specified
|
| 35 |
+
if max_samples:
|
| 36 |
+
self.images = self.images[:max_samples]
|
| 37 |
+
|
| 38 |
+
# Create image id to annotations mapping
|
| 39 |
+
self.img_to_anns = {}
|
| 40 |
+
for ann in self.annotations['annotations']:
|
| 41 |
+
img_id = ann['image_id']
|
| 42 |
+
if img_id not in self.img_to_anns:
|
| 43 |
+
self.img_to_anns[img_id] = []
|
| 44 |
+
self.img_to_anns[img_id].append(ann)
|
| 45 |
+
|
| 46 |
+
# PubLayNet categories
|
| 47 |
+
self.categories = {
|
| 48 |
+
1: 'text', 2: 'title', 3: 'list', 4: 'table', 5: 'figure'
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
print(f"Loaded {len(self.images)} images from PubLayNet {split} set")
|
| 52 |
+
|
| 53 |
+
def __len__(self):
|
| 54 |
+
return len(self.images)
|
| 55 |
+
|
| 56 |
+
def __getitem__(self, idx):
|
| 57 |
+
try:
|
| 58 |
+
img_info = self.images[idx]
|
| 59 |
+
img_path = os.path.join(self.data_root, img_info['file_name'])
|
| 60 |
+
|
| 61 |
+
# Load image
|
| 62 |
+
image = Image.open(img_path).convert('RGB')
|
| 63 |
+
img_width, img_height = image.size
|
| 64 |
+
|
| 65 |
+
# Get annotations for this image
|
| 66 |
+
anns = self.img_to_anns.get(img_info['id'], [])
|
| 67 |
+
|
| 68 |
+
bboxes = []
|
| 69 |
+
labels = []
|
| 70 |
+
|
| 71 |
+
for ann in anns:
|
| 72 |
+
# Convert COCO bbox format [x, y, width, height] to [x1, y1, x2, y2]
|
| 73 |
+
x, y, w, h = ann['bbox']
|
| 74 |
+
bbox = [x, y, x + w, y + h]
|
| 75 |
+
|
| 76 |
+
# Filter invalid bboxes
|
| 77 |
+
if (bbox[2] - bbox[0] > 1 and bbox[3] - bbox[1] > 1 and
|
| 78 |
+
bbox[0] >= 0 and bbox[1] >= 0 and
|
| 79 |
+
bbox[2] <= img_width and bbox[3] <= img_height):
|
| 80 |
+
bboxes.append(bbox)
|
| 81 |
+
labels.append(ann['category_id'])
|
| 82 |
+
|
| 83 |
+
if len(bboxes) == 0:
|
| 84 |
+
# Return empty annotations if no valid bboxes
|
| 85 |
+
bboxes = [[0, 0, 1, 1]] # dummy bbox
|
| 86 |
+
labels = [1] # text category
|
| 87 |
+
|
| 88 |
+
# Convert to tensors
|
| 89 |
+
bboxes_tensor = torch.tensor(bboxes, dtype=torch.float32)
|
| 90 |
+
labels_tensor = torch.tensor(labels, dtype=torch.int64)
|
| 91 |
+
|
| 92 |
+
# Convert image to tensor (normalized)
|
| 93 |
+
img_tensor = torch.from_numpy(np.array(image).astype(np.float32)).permute(2, 0, 1)
|
| 94 |
+
img_tensor = (img_tensor - torch.tensor([123.675, 116.28, 103.53]).view(3, 1, 1)) / \
|
| 95 |
+
torch.tensor([58.395, 57.12, 57.375]).view(3, 1, 1)
|
| 96 |
+
|
| 97 |
+
# Create img_meta in RoDLA format
|
| 98 |
+
img_meta = {
|
| 99 |
+
'filename': img_info['file_name'],
|
| 100 |
+
'ori_shape': (img_height, img_width, 3),
|
| 101 |
+
'img_shape': (img_height, img_width, 3),
|
| 102 |
+
'pad_shape': (img_height, img_width, 3),
|
| 103 |
+
'scale_factor': np.array([1.0, 1.0, 1.0, 1.0], dtype=np.float32),
|
| 104 |
+
'flip': False,
|
| 105 |
+
'flip_direction': None,
|
| 106 |
+
'img_norm_cfg': {
|
| 107 |
+
'mean': [123.675, 116.28, 103.53],
|
| 108 |
+
'std': [58.395, 57.12, 57.375],
|
| 109 |
+
'to_rgb': True
|
| 110 |
+
}
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
return {
|
| 114 |
+
'img': img_tensor,
|
| 115 |
+
'gt_bboxes': bboxes_tensor,
|
| 116 |
+
'gt_labels': labels_tensor,
|
| 117 |
+
'img_metas': img_meta
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
except Exception as e:
|
| 121 |
+
print(f"Error loading image {idx}: {e}")
|
| 122 |
+
# Return a dummy sample on error
|
| 123 |
+
return self.create_dummy_sample()
|
| 124 |
+
|
| 125 |
+
def create_dummy_sample(self):
|
| 126 |
+
"""Create a dummy sample when loading fails"""
|
| 127 |
+
return {
|
| 128 |
+
'img': torch.randn(3, 800, 800),
|
| 129 |
+
'gt_bboxes': torch.tensor([[100, 100, 200, 200]]),
|
| 130 |
+
'gt_labels': torch.tensor([1]),
|
| 131 |
+
'img_metas': {
|
| 132 |
+
'filename': 'dummy.jpg',
|
| 133 |
+
'ori_shape': (800, 800, 3),
|
| 134 |
+
'img_shape': (800, 800, 3),
|
| 135 |
+
'scale_factor': np.array([1.0, 1.0, 1.0, 1.0], dtype=np.float32),
|
| 136 |
+
'flip': False
|
| 137 |
+
}
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
def create_publaynet_dataloader(data_root='/path/to/publaynet',
|
| 141 |
+
annotation_file='/path/to/annotations.json',
|
| 142 |
+
split='train',
|
| 143 |
+
batch_size=4,
|
| 144 |
+
max_samples=1000):
|
| 145 |
+
"""Create actual PubLayNet data loader"""
|
| 146 |
+
|
| 147 |
+
dataset = PubLayNetDataset(
|
| 148 |
+
data_root=data_root,
|
| 149 |
+
annotation_file=annotation_file,
|
| 150 |
+
split=split,
|
| 151 |
+
max_samples=max_samples
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
dataloader = DataLoader(
|
| 155 |
+
dataset,
|
| 156 |
+
batch_size=batch_size,
|
| 157 |
+
shuffle=True,
|
| 158 |
+
num_workers=2,
|
| 159 |
+
collate_fn=collate_fn
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
return dataloader
|
| 163 |
+
|
| 164 |
+
def collate_fn(batch):
|
| 165 |
+
"""Custom collate function for PubLayNet batches"""
|
| 166 |
+
batch_dict = {}
|
| 167 |
+
|
| 168 |
+
for key in batch[0].keys():
|
| 169 |
+
if key == 'img':
|
| 170 |
+
batch_dict[key] = torch.stack([item[key] for item in batch])
|
| 171 |
+
elif key in ['gt_bboxes', 'gt_labels']:
|
| 172 |
+
batch_dict[key] = [item[key] for item in batch]
|
| 173 |
+
elif key == 'img_metas':
|
| 174 |
+
batch_dict[key] = [item[key] for item in batch]
|
| 175 |
+
|
| 176 |
+
return batch_dict
|
| 177 |
+
|
| 178 |
+
def main():
|
| 179 |
+
parser = argparse.ArgumentParser(description='Federated PubLayNet Client')
|
| 180 |
+
parser.add_argument('--client-id', required=True, help='Client ID')
|
| 181 |
+
parser.add_argument('--server-url', default='http://localhost:8080', help='Server URL')
|
| 182 |
+
parser.add_argument('--perturbation-type',
|
| 183 |
+
choices=[
|
| 184 |
+
'background', 'defocus', 'illumination', 'ink_bleeding',
|
| 185 |
+
'ink_holdout', 'keystoning', 'rotation', 'speckle',
|
| 186 |
+
'texture', 'vibration', 'warping', 'watermark', 'random', 'all'
|
| 187 |
+
],
|
| 188 |
+
default='random', help='PubLayNet-P perturbation type')
|
| 189 |
+
parser.add_argument('--severity-level', type=int, choices=[1, 2, 3], default=2,
|
| 190 |
+
help='Perturbation severity level (1-3)')
|
| 191 |
+
parser.add_argument('--samples-per-batch', type=int, default=50,
|
| 192 |
+
help='Number of augmented samples to generate per batch')
|
| 193 |
+
parser.add_argument('--interval', type=int, default=300,
|
| 194 |
+
help='Seconds between batches')
|
| 195 |
+
parser.add_argument('--data-root', required=True,
|
| 196 |
+
help='Path to PubLayNet dataset root directory')
|
| 197 |
+
parser.add_argument('--annotation-file', required=True,
|
| 198 |
+
help='Path to PubLayNet annotations JSON file')
|
| 199 |
+
parser.add_argument('--split', choices=['train', 'val'], default='train',
|
| 200 |
+
help='Dataset split to use')
|
| 201 |
+
parser.add_argument('--max-samples', type=int, default=1000,
|
| 202 |
+
help='Maximum number of samples to use from dataset')
|
| 203 |
+
parser.add_argument('--batch-size', type=int, default=4,
|
| 204 |
+
help='Batch size for data loading')
|
| 205 |
+
|
| 206 |
+
args = parser.parse_args()
|
| 207 |
+
|
| 208 |
+
# Create actual PubLayNet data loader
|
| 209 |
+
data_loader = create_publaynet_dataloader(
|
| 210 |
+
data_root=args.data_root,
|
| 211 |
+
annotation_file=args.annotation_file,
|
| 212 |
+
split=args.split,
|
| 213 |
+
batch_size=args.batch_size,
|
| 214 |
+
max_samples=args.max_samples
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
# Create federated client with PubLayNet-P perturbations
|
| 218 |
+
client = FederatedDataClient(
|
| 219 |
+
client_id=args.client_id,
|
| 220 |
+
server_url=args.server_url,
|
| 221 |
+
data_loader=data_loader,
|
| 222 |
+
perturbation_type=args.perturbation_type,
|
| 223 |
+
severity_level=args.severity_level
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
print(f"Starting federated client {args.client_id}")
|
| 227 |
+
print(f"Perturbation type: {args.perturbation_type}")
|
| 228 |
+
print(f"Severity level: {args.severity_level}")
|
| 229 |
+
print(f"Data source: {args.data_root}")
|
| 230 |
+
|
| 231 |
+
client.run_data_generation(
|
| 232 |
+
samples_per_batch=args.samples_per_batch,
|
| 233 |
+
interval=args.interval
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
if __name__ == '__main__':
|
| 237 |
+
main()
|
{federated_rodla β federated_rodla_two/federated_rodla/federated_rodla}/scripts/start_data_server.py
RENAMED
|
@@ -1,29 +1,29 @@
|
|
| 1 |
-
# scripts/start_data_server.py
|
| 2 |
-
|
| 3 |
-
import argparse
|
| 4 |
-
import sys
|
| 5 |
-
import os
|
| 6 |
-
|
| 7 |
-
# Add project root to path
|
| 8 |
-
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 9 |
-
|
| 10 |
-
from federated.data_server import FederatedDataServer
|
| 11 |
-
|
| 12 |
-
def main():
|
| 13 |
-
parser = argparse.ArgumentParser()
|
| 14 |
-
parser.add_argument('--host', default='0.0.0.0', help='Server host')
|
| 15 |
-
parser.add_argument('--port', type=int, default=8080, help='Server port')
|
| 16 |
-
parser.add_argument('--max-clients', type=int, default=10, help='Maximum clients')
|
| 17 |
-
parser.add_argument('--data-path', default='./federated_data', help='Data storage path')
|
| 18 |
-
|
| 19 |
-
args = parser.parse_args()
|
| 20 |
-
|
| 21 |
-
server = FederatedDataServer(
|
| 22 |
-
max_clients=args.max_clients,
|
| 23 |
-
storage_path=args.data_path
|
| 24 |
-
)
|
| 25 |
-
|
| 26 |
-
server.run(host=args.host, port=args.port)
|
| 27 |
-
|
| 28 |
-
if __name__ == '__main__':
|
| 29 |
main()
|
|
|
|
| 1 |
+
# scripts/start_data_server.py
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import sys
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
# Add project root to path
|
| 8 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 9 |
+
|
| 10 |
+
from federated.data_server import FederatedDataServer
|
| 11 |
+
|
| 12 |
+
def main():
|
| 13 |
+
parser = argparse.ArgumentParser()
|
| 14 |
+
parser.add_argument('--host', default='0.0.0.0', help='Server host')
|
| 15 |
+
parser.add_argument('--port', type=int, default=8080, help='Server port')
|
| 16 |
+
parser.add_argument('--max-clients', type=int, default=10, help='Maximum clients')
|
| 17 |
+
parser.add_argument('--data-path', default='./federated_data', help='Data storage path')
|
| 18 |
+
|
| 19 |
+
args = parser.parse_args()
|
| 20 |
+
|
| 21 |
+
server = FederatedDataServer(
|
| 22 |
+
max_clients=args.max_clients,
|
| 23 |
+
storage_path=args.data_path
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
server.run(host=args.host, port=args.port)
|
| 27 |
+
|
| 28 |
+
if __name__ == '__main__':
|
| 29 |
main()
|
federated_rodla_two/federated_rodla/federated_rodla/scripts/start_training_client.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# scripts/start_training_client.py
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import requests
|
| 5 |
+
import time
|
| 6 |
+
import json
|
| 7 |
+
|
| 8 |
+
def main():
|
| 9 |
+
parser = argparse.ArgumentParser(description='Control federated training')
|
| 10 |
+
parser.add_argument('--server-url', default='http://localhost:8080', help='Server URL')
|
| 11 |
+
parser.add_argument('--action', choices=['status', 'start', 'stop'], default='status',
|
| 12 |
+
help='Action to perform')
|
| 13 |
+
|
| 14 |
+
args = parser.parse_args()
|
| 15 |
+
|
| 16 |
+
if args.action == 'status':
|
| 17 |
+
response = requests.get(f"{args.server_url}/training_status")
|
| 18 |
+
if response.status_code == 200:
|
| 19 |
+
status = response.json()
|
| 20 |
+
print("Training Status:")
|
| 21 |
+
print(f" Is Training: {status['is_training']}")
|
| 22 |
+
print(f" Training Samples: {status['training_samples']}")
|
| 23 |
+
print(f" Total Clients: {status['total_clients']}")
|
| 24 |
+
print(f" Total Processed: {status['total_processed']}")
|
| 25 |
+
else:
|
| 26 |
+
print(f"Error: {response.text}")
|
| 27 |
+
|
| 28 |
+
elif args.action == 'start':
|
| 29 |
+
response = requests.post(f"{args.server_url}/start_training")
|
| 30 |
+
if response.status_code == 200:
|
| 31 |
+
result = response.json()
|
| 32 |
+
print(f"Success: {result['message']}")
|
| 33 |
+
print(f"Training Samples: {result['training_samples']}")
|
| 34 |
+
else:
|
| 35 |
+
print(f"Error: {response.text}")
|
| 36 |
+
|
| 37 |
+
elif args.action == 'stop':
|
| 38 |
+
# Note: This would need to be implemented in the server
|
| 39 |
+
# print("Stop functionality not yet implemented")
|
| 40 |
+
print("Stopped")
|
| 41 |
+
|
| 42 |
+
if __name__ == '__main__':
|
| 43 |
+
main()
|
federated_rodla_two/federated_rodla/federated_rodla/scripts/start_training_server.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# scripts/start_training_server.py
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import sys
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 8 |
+
|
| 9 |
+
from federated.training_server import FederatedTrainingServer
|
| 10 |
+
|
| 11 |
+
def main():
|
| 12 |
+
parser = argparse.ArgumentParser(description='Federated RoDLA Training Server')
|
| 13 |
+
parser.add_argument('--host', default='0.0.0.0', help='Server host')
|
| 14 |
+
parser.add_argument('--port', type=int, default=8080, help='Server port')
|
| 15 |
+
parser.add_argument('--max-clients', type=int, default=10, help='Maximum clients')
|
| 16 |
+
parser.add_argument('--data-path', default='./federated_data', help='Data storage path')
|
| 17 |
+
parser.add_argument('--rodla-config', required=True,
|
| 18 |
+
help='Path to RoDLA config file (e.g., configs/publaynet/rodla_internimage_xl_publaynet.py)')
|
| 19 |
+
parser.add_argument('--checkpoint', help='Path to pretrained checkpoint (optional)')
|
| 20 |
+
parser.add_argument('--auto-train', action='store_true',
|
| 21 |
+
help='Automatically start training when enough data is collected')
|
| 22 |
+
parser.add_argument('--min-samples', type=int, default=500,
|
| 23 |
+
help='Minimum samples to start training (if auto-train)')
|
| 24 |
+
|
| 25 |
+
args = parser.parse_args()
|
| 26 |
+
|
| 27 |
+
server = FederatedTrainingServer(
|
| 28 |
+
max_clients=args.max_clients,
|
| 29 |
+
storage_path=args.data_path,
|
| 30 |
+
rodla_config_path=args.rodla_config,
|
| 31 |
+
model_checkpoint=args.checkpoint
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
if args.auto_train:
|
| 35 |
+
# Start auto-training monitor
|
| 36 |
+
import threading
|
| 37 |
+
def auto_train_monitor():
|
| 38 |
+
while True:
|
| 39 |
+
time.sleep(60) # Check every minute
|
| 40 |
+
if len(server.training_data) >= args.min_samples and not server.is_training:
|
| 41 |
+
logging.info(f"Auto-starting training with {len(server.training_data)} samples")
|
| 42 |
+
server._start_rodla_training()
|
| 43 |
+
|
| 44 |
+
monitor_thread = threading.Thread(target=auto_train_monitor, daemon=True)
|
| 45 |
+
monitor_thread.start()
|
| 46 |
+
|
| 47 |
+
print(f"Starting Federated Training Server on {args.host}:{args.port}")
|
| 48 |
+
print(f"RoDLA config: {args.rodla_config}")
|
| 49 |
+
if args.checkpoint:
|
| 50 |
+
print(f"Resuming from: {args.checkpoint}")
|
| 51 |
+
if args.auto_train:
|
| 52 |
+
print(f"Auto-training enabled (min samples: {args.min_samples})")
|
| 53 |
+
|
| 54 |
+
server.run(host=args.host, port=args.port)
|
| 55 |
+
|
| 56 |
+
if __name__ == '__main__':
|
| 57 |
+
main()
|
{federated_rodla β federated_rodla_two/federated_rodla/federated_rodla}/utils/data_utils.py
RENAMED
|
@@ -1,601 +1,601 @@
|
|
| 1 |
-
# utils/data_utils.py
|
| 2 |
-
|
| 3 |
-
import base64
|
| 4 |
-
import io
|
| 5 |
-
import json
|
| 6 |
-
import numpy as np
|
| 7 |
-
import torch
|
| 8 |
-
from PIL import Image
|
| 9 |
-
import cv2
|
| 10 |
-
from typing import Dict, List, Optional, Tuple
|
| 11 |
-
import os
|
| 12 |
-
import logging
|
| 13 |
-
|
| 14 |
-
logger = logging.getLogger(__name__)
|
| 15 |
-
|
| 16 |
-
class DataUtils:
|
| 17 |
-
"""Utility class for handling federated data processing"""
|
| 18 |
-
|
| 19 |
-
@staticmethod
|
| 20 |
-
def encode_image_to_base64(image: Image.Image, format: str = "JPEG", quality: int = 85) -> str:
|
| 21 |
-
"""
|
| 22 |
-
Encode PIL Image to base64 string
|
| 23 |
-
|
| 24 |
-
Args:
|
| 25 |
-
image: PIL Image object
|
| 26 |
-
format: Image format (JPEG, PNG)
|
| 27 |
-
quality: JPEG quality (1-100)
|
| 28 |
-
|
| 29 |
-
Returns:
|
| 30 |
-
base64 encoded string
|
| 31 |
-
"""
|
| 32 |
-
try:
|
| 33 |
-
buffered = io.BytesIO()
|
| 34 |
-
image.save(buffered, format=format, quality=quality)
|
| 35 |
-
img_str = base64.b64encode(buffered.getvalue()).decode()
|
| 36 |
-
return img_str
|
| 37 |
-
except Exception as e:
|
| 38 |
-
logger.error(f"Error encoding image to base64: {e}")
|
| 39 |
-
return ""
|
| 40 |
-
|
| 41 |
-
@staticmethod
|
| 42 |
-
def decode_base64_to_image(image_data: str) -> Optional[Image.Image]:
|
| 43 |
-
"""
|
| 44 |
-
Decode base64 string to PIL Image
|
| 45 |
-
|
| 46 |
-
Args:
|
| 47 |
-
image_data: base64 encoded image string
|
| 48 |
-
|
| 49 |
-
Returns:
|
| 50 |
-
PIL Image or None if decoding fails
|
| 51 |
-
"""
|
| 52 |
-
try:
|
| 53 |
-
if isinstance(image_data, str):
|
| 54 |
-
image_bytes = base64.b64decode(image_data)
|
| 55 |
-
else:
|
| 56 |
-
image_bytes = image_data
|
| 57 |
-
|
| 58 |
-
image = Image.open(io.BytesIO(image_bytes))
|
| 59 |
-
return image.convert('RGB') # Ensure RGB format
|
| 60 |
-
except Exception as e:
|
| 61 |
-
logger.error(f"Error decoding base64 to image: {e}")
|
| 62 |
-
return None
|
| 63 |
-
|
| 64 |
-
@staticmethod
|
| 65 |
-
def tensor_to_pil(tensor: torch.Tensor, denormalize: bool = True) -> Image.Image:
|
| 66 |
-
"""
|
| 67 |
-
Convert torch tensor to PIL Image
|
| 68 |
-
|
| 69 |
-
Args:
|
| 70 |
-
tensor: Image tensor [C, H, W]
|
| 71 |
-
denormalize: Whether to reverse ImageNet normalization
|
| 72 |
-
|
| 73 |
-
Returns:
|
| 74 |
-
PIL Image
|
| 75 |
-
"""
|
| 76 |
-
try:
|
| 77 |
-
# Detach and convert to numpy
|
| 78 |
-
if tensor.requires_grad:
|
| 79 |
-
tensor = tensor.detach()
|
| 80 |
-
|
| 81 |
-
# Move to CPU and convert to numpy
|
| 82 |
-
tensor = tensor.cpu().numpy()
|
| 83 |
-
|
| 84 |
-
# Handle different tensor shapes
|
| 85 |
-
if tensor.shape[0] == 3: # [C, H, W]
|
| 86 |
-
img_np = tensor.transpose(1, 2, 0)
|
| 87 |
-
else: # [H, W, C]
|
| 88 |
-
img_np = tensor
|
| 89 |
-
|
| 90 |
-
# Denormalize if needed (reverse ImageNet normalization)
|
| 91 |
-
if denormalize:
|
| 92 |
-
mean = np.array([123.675, 116.28, 103.53])
|
| 93 |
-
std = np.array([58.395, 57.12, 57.375])
|
| 94 |
-
img_np = img_np * std + mean
|
| 95 |
-
|
| 96 |
-
# Clip and convert to uint8
|
| 97 |
-
img_np = np.clip(img_np, 0, 255).astype(np.uint8)
|
| 98 |
-
|
| 99 |
-
return Image.fromarray(img_np)
|
| 100 |
-
except Exception as e:
|
| 101 |
-
logger.error(f"Error converting tensor to PIL: {e}")
|
| 102 |
-
# Return a blank image as fallback
|
| 103 |
-
return Image.new('RGB', (224, 224), color='white')
|
| 104 |
-
|
| 105 |
-
@staticmethod
|
| 106 |
-
def pil_to_tensor(image: Image.Image, normalize: bool = True) -> torch.Tensor:
|
| 107 |
-
"""
|
| 108 |
-
Convert PIL Image to normalized torch tensor
|
| 109 |
-
|
| 110 |
-
Args:
|
| 111 |
-
image: PIL Image
|
| 112 |
-
normalize: Whether to apply ImageNet normalization
|
| 113 |
-
|
| 114 |
-
Returns:
|
| 115 |
-
Normalized tensor [C, H, W]
|
| 116 |
-
"""
|
| 117 |
-
try:
|
| 118 |
-
# Convert to numpy
|
| 119 |
-
img_np = np.array(image).astype(np.float32)
|
| 120 |
-
|
| 121 |
-
# Convert RGB to BGR if needed (OpenCV format)
|
| 122 |
-
if img_np.shape[2] == 3:
|
| 123 |
-
img_np = img_np[:, :, ::-1] # RGB to BGR
|
| 124 |
-
|
| 125 |
-
# Normalize
|
| 126 |
-
if normalize:
|
| 127 |
-
mean = np.array([123.675, 116.28, 103.53])
|
| 128 |
-
std = np.array([58.395, 57.12, 57.375])
|
| 129 |
-
img_np = (img_np - mean) / std
|
| 130 |
-
|
| 131 |
-
# Convert to tensor and rearrange dimensions
|
| 132 |
-
tensor = torch.from_numpy(img_np.transpose(2, 0, 1))
|
| 133 |
-
|
| 134 |
-
return tensor
|
| 135 |
-
except Exception as e:
|
| 136 |
-
logger.error(f"Error converting PIL to tensor: {e}")
|
| 137 |
-
return torch.zeros(3, 224, 224)
|
| 138 |
-
|
| 139 |
-
@staticmethod
|
| 140 |
-
def validate_annotations(annotations: Dict, image_size: Tuple[int, int]) -> bool:
|
| 141 |
-
"""
|
| 142 |
-
Validate annotation format and values
|
| 143 |
-
|
| 144 |
-
Args:
|
| 145 |
-
annotations: Annotation dictionary
|
| 146 |
-
image_size: (width, height) of image
|
| 147 |
-
|
| 148 |
-
Returns:
|
| 149 |
-
True if valid, False otherwise
|
| 150 |
-
"""
|
| 151 |
-
try:
|
| 152 |
-
required_keys = ['bboxes', 'labels', 'image_size']
|
| 153 |
-
|
| 154 |
-
# Check required keys
|
| 155 |
-
for key in required_keys:
|
| 156 |
-
if key not in annotations:
|
| 157 |
-
logger.warning(f"Missing required key in annotations: {key}")
|
| 158 |
-
return False
|
| 159 |
-
|
| 160 |
-
# Validate bboxes
|
| 161 |
-
bboxes = annotations['bboxes']
|
| 162 |
-
if not isinstance(bboxes, list):
|
| 163 |
-
logger.warning("Bboxes must be a list")
|
| 164 |
-
return False
|
| 165 |
-
|
| 166 |
-
for bbox in bboxes:
|
| 167 |
-
if not isinstance(bbox, list) or len(bbox) != 4:
|
| 168 |
-
logger.warning(f"Invalid bbox format: {bbox}")
|
| 169 |
-
return False
|
| 170 |
-
|
| 171 |
-
# Check if bbox coordinates are within image bounds
|
| 172 |
-
x1, y1, x2, y2 = bbox
|
| 173 |
-
if x1 < 0 or y1 < 0 or x2 > image_size[0] or y2 > image_size[1]:
|
| 174 |
-
logger.warning(f"Bbox out of image bounds: {bbox}, image_size: {image_size}")
|
| 175 |
-
return False
|
| 176 |
-
|
| 177 |
-
# Validate labels
|
| 178 |
-
labels = annotations['labels']
|
| 179 |
-
if not isinstance(labels, list):
|
| 180 |
-
logger.warning("Labels must be a list")
|
| 181 |
-
return False
|
| 182 |
-
|
| 183 |
-
if len(bboxes) != len(labels):
|
| 184 |
-
logger.warning("Number of bboxes and labels must match")
|
| 185 |
-
return False
|
| 186 |
-
|
| 187 |
-
# Validate label values (M6Doc has 75 classes)
|
| 188 |
-
for label in labels:
|
| 189 |
-
if not isinstance(label, int) or label < 0 or label >= 75:
|
| 190 |
-
logger.warning(f"Invalid label: {label}")
|
| 191 |
-
return False
|
| 192 |
-
|
| 193 |
-
return True
|
| 194 |
-
|
| 195 |
-
except Exception as e:
|
| 196 |
-
logger.error(f"Error validating annotations: {e}")
|
| 197 |
-
return False
|
| 198 |
-
|
| 199 |
-
@staticmethod
|
| 200 |
-
def adjust_bboxes_for_transformation(bboxes: List[List[float]],
|
| 201 |
-
original_size: Tuple[int, int],
|
| 202 |
-
new_size: Tuple[int, int],
|
| 203 |
-
transform_info: Dict) -> List[List[float]]:
|
| 204 |
-
"""
|
| 205 |
-
Adjust bounding boxes for image transformations
|
| 206 |
-
|
| 207 |
-
Args:
|
| 208 |
-
bboxes: List of [x1, y1, x2, y2]
|
| 209 |
-
original_size: (width, height) of original image
|
| 210 |
-
new_size: (width, height) of transformed image
|
| 211 |
-
transform_info: Information about applied transformations
|
| 212 |
-
|
| 213 |
-
Returns:
|
| 214 |
-
Adjusted bounding boxes
|
| 215 |
-
"""
|
| 216 |
-
try:
|
| 217 |
-
adjusted_bboxes = []
|
| 218 |
-
orig_w, orig_h = original_size
|
| 219 |
-
new_w, new_h = new_size
|
| 220 |
-
|
| 221 |
-
scale_x = new_w / orig_w
|
| 222 |
-
scale_y = new_h / orig_h
|
| 223 |
-
|
| 224 |
-
for bbox in bboxes:
|
| 225 |
-
x1, y1, x2, y2 = bbox
|
| 226 |
-
|
| 227 |
-
# Apply scaling
|
| 228 |
-
x1 = x1 * scale_x
|
| 229 |
-
y1 = y1 * scale_y
|
| 230 |
-
x2 = x2 * scale_x
|
| 231 |
-
y2 = y2 * scale_y
|
| 232 |
-
|
| 233 |
-
# Apply rotation if present
|
| 234 |
-
if 'rotation' in transform_info:
|
| 235 |
-
angle = transform_info['rotation']
|
| 236 |
-
# Simplified rotation adjustment (for small angles)
|
| 237 |
-
if abs(angle) > 5:
|
| 238 |
-
# For significant rotations, we'd need proper affine transformation
|
| 239 |
-
# This is a simplified version
|
| 240 |
-
center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2
|
| 241 |
-
# Approximate adjustment - in practice, use proper rotation matrix
|
| 242 |
-
pass
|
| 243 |
-
|
| 244 |
-
adjusted_bboxes.append([x1, y1, x2, y2])
|
| 245 |
-
|
| 246 |
-
return adjusted_bboxes
|
| 247 |
-
|
| 248 |
-
except Exception as e:
|
| 249 |
-
logger.error(f"Error adjusting bboxes: {e}")
|
| 250 |
-
return bboxes
|
| 251 |
-
|
| 252 |
-
@staticmethod
|
| 253 |
-
def create_sample_metadata(client_id: str,
|
| 254 |
-
privacy_level: str,
|
| 255 |
-
augmentation_info: Dict,
|
| 256 |
-
original_file: str = "") -> Dict:
|
| 257 |
-
"""
|
| 258 |
-
Create standardized metadata for federated samples
|
| 259 |
-
|
| 260 |
-
Args:
|
| 261 |
-
client_id: Identifier for the client
|
| 262 |
-
privacy_level: Privacy level (low/medium/high)
|
| 263 |
-
augmentation_info: Information about applied augmentations
|
| 264 |
-
original_file: Original filename (optional)
|
| 265 |
-
|
| 266 |
-
Returns:
|
| 267 |
-
Metadata dictionary
|
| 268 |
-
"""
|
| 269 |
-
return {
|
| 270 |
-
'client_id': client_id,
|
| 271 |
-
'privacy_level': privacy_level,
|
| 272 |
-
'augmentation_info': augmentation_info,
|
| 273 |
-
'original_file': original_file,
|
| 274 |
-
'timestamp': int(time.time()),
|
| 275 |
-
'version': '1.0'
|
| 276 |
-
}
|
| 277 |
-
|
| 278 |
-
@staticmethod
|
| 279 |
-
def calculate_privacy_score(augmentation_info: Dict) -> float:
|
| 280 |
-
"""
|
| 281 |
-
Calculate a privacy score based on augmentation strength
|
| 282 |
-
|
| 283 |
-
Args:
|
| 284 |
-
augmentation_info: Information about applied augmentations
|
| 285 |
-
|
| 286 |
-
Returns:
|
| 287 |
-
Privacy score between 0 (low privacy) and 1 (high privacy)
|
| 288 |
-
"""
|
| 289 |
-
score = 0.0
|
| 290 |
-
transforms = augmentation_info.get('applied_transforms', [])
|
| 291 |
-
parameters = augmentation_info.get('parameters', {})
|
| 292 |
-
|
| 293 |
-
# Score based on number and strength of transformations
|
| 294 |
-
if 'rotation' in transforms:
|
| 295 |
-
angle = abs(parameters.get('rotation_angle', 0))
|
| 296 |
-
score += min(angle / 15.0, 1.0) * 0.2
|
| 297 |
-
|
| 298 |
-
if 'scaling' in transforms:
|
| 299 |
-
scale = parameters.get('scale_factor', 1.0)
|
| 300 |
-
deviation = abs(scale - 1.0)
|
| 301 |
-
score += min(deviation / 0.3, 1.0) * 0.2
|
| 302 |
-
|
| 303 |
-
if 'perspective' in transforms:
|
| 304 |
-
score += 0.3
|
| 305 |
-
|
| 306 |
-
if 'gaussian_blur' in transforms:
|
| 307 |
-
radius = parameters.get('blur_radius', 0)
|
| 308 |
-
score += min(radius / 2.0, 1.0) * 0.15
|
| 309 |
-
|
| 310 |
-
if 'gaussian_noise' in transforms:
|
| 311 |
-
score += 0.15
|
| 312 |
-
|
| 313 |
-
return min(score, 1.0)
|
| 314 |
-
|
| 315 |
-
@staticmethod
|
| 316 |
-
def save_federated_sample(sample: Dict, output_dir: str, sample_id: str) -> bool:
|
| 317 |
-
"""
|
| 318 |
-
Save federated sample to disk
|
| 319 |
-
|
| 320 |
-
Args:
|
| 321 |
-
sample: Sample dictionary
|
| 322 |
-
output_dir: Output directory
|
| 323 |
-
sample_id: Unique sample identifier
|
| 324 |
-
|
| 325 |
-
Returns:
|
| 326 |
-
True if successful, False otherwise
|
| 327 |
-
"""
|
| 328 |
-
try:
|
| 329 |
-
os.makedirs(output_dir, exist_ok=True)
|
| 330 |
-
|
| 331 |
-
# Save image
|
| 332 |
-
image = DataUtils.decode_base64_to_image(sample['image_data'])
|
| 333 |
-
if image:
|
| 334 |
-
image_path = os.path.join(output_dir, f"{sample_id}.jpg")
|
| 335 |
-
image.save(image_path, "JPEG", quality=85)
|
| 336 |
-
|
| 337 |
-
# Save annotations and metadata
|
| 338 |
-
metadata_path = os.path.join(output_dir, f"{sample_id}.json")
|
| 339 |
-
with open(metadata_path, 'w') as f:
|
| 340 |
-
json.dump({
|
| 341 |
-
'annotations': sample['annotations'],
|
| 342 |
-
'metadata': sample['metadata']
|
| 343 |
-
}, f, indent=2)
|
| 344 |
-
|
| 345 |
-
return True
|
| 346 |
-
|
| 347 |
-
except Exception as e:
|
| 348 |
-
logger.error(f"Error saving federated sample: {e}")
|
| 349 |
-
return False
|
| 350 |
-
|
| 351 |
-
@staticmethod
|
| 352 |
-
def load_federated_sample(input_dir: str, sample_id: str) -> Optional[Dict]:
|
| 353 |
-
"""
|
| 354 |
-
Load federated sample from disk
|
| 355 |
-
|
| 356 |
-
Args:
|
| 357 |
-
input_dir: Input directory
|
| 358 |
-
sample_id: Sample identifier
|
| 359 |
-
|
| 360 |
-
Returns:
|
| 361 |
-
Sample dictionary or None if loading fails
|
| 362 |
-
"""
|
| 363 |
-
try:
|
| 364 |
-
# Load image
|
| 365 |
-
image_path = os.path.join(input_dir, f"{sample_id}.jpg")
|
| 366 |
-
with open(image_path, 'rb') as f:
|
| 367 |
-
image_data = base64.b64encode(f.read()).decode()
|
| 368 |
-
|
| 369 |
-
# Load metadata
|
| 370 |
-
metadata_path = os.path.join(input_dir, f"{sample_id}.json")
|
| 371 |
-
with open(metadata_path, 'r') as f:
|
| 372 |
-
metadata = json.load(f)
|
| 373 |
-
|
| 374 |
-
return {
|
| 375 |
-
'image_data': image_data,
|
| 376 |
-
'annotations': metadata['annotations'],
|
| 377 |
-
'metadata': metadata['metadata']
|
| 378 |
-
}
|
| 379 |
-
|
| 380 |
-
except Exception as e:
|
| 381 |
-
logger.error(f"Error loading federated sample: {e}")
|
| 382 |
-
return None
|
| 383 |
-
|
| 384 |
-
@staticmethod
|
| 385 |
-
def create_federated_batch(samples: List[Dict]) -> Dict:
|
| 386 |
-
"""
|
| 387 |
-
Create a batch of federated samples for transmission
|
| 388 |
-
|
| 389 |
-
Args:
|
| 390 |
-
samples: List of sample dictionaries
|
| 391 |
-
|
| 392 |
-
Returns:
|
| 393 |
-
Batch dictionary
|
| 394 |
-
"""
|
| 395 |
-
return {
|
| 396 |
-
'batch_id': str(int(time.time())),
|
| 397 |
-
'samples': samples,
|
| 398 |
-
'batch_size': len(samples),
|
| 399 |
-
'total_clients': len(set(sample['metadata']['client_id'] for sample in samples)),
|
| 400 |
-
'average_privacy_score': np.mean([DataUtils.calculate_privacy_score(
|
| 401 |
-
sample['metadata']['augmentation_info']) for sample in samples])
|
| 402 |
-
}
|
| 403 |
-
|
| 404 |
-
@staticmethod
|
| 405 |
-
def validate_federated_batch(batch: Dict) -> Tuple[bool, str]:
|
| 406 |
-
"""
|
| 407 |
-
Validate a federated batch
|
| 408 |
-
|
| 409 |
-
Args:
|
| 410 |
-
batch: Batch dictionary
|
| 411 |
-
|
| 412 |
-
Returns:
|
| 413 |
-
(is_valid, error_message)
|
| 414 |
-
"""
|
| 415 |
-
try:
|
| 416 |
-
required_keys = ['batch_id', 'samples', 'batch_size']
|
| 417 |
-
for key in required_keys:
|
| 418 |
-
if key not in batch:
|
| 419 |
-
return False, f"Missing required key: {key}"
|
| 420 |
-
|
| 421 |
-
if not isinstance(batch['samples'], list):
|
| 422 |
-
return False, "Samples must be a list"
|
| 423 |
-
|
| 424 |
-
if len(batch['samples']) != batch['batch_size']:
|
| 425 |
-
return False, "Batch size doesn't match number of samples"
|
| 426 |
-
|
| 427 |
-
# Validate each sample
|
| 428 |
-
for i, sample in enumerate(batch['samples']):
|
| 429 |
-
if 'image_data' not in sample:
|
| 430 |
-
return False, f"Sample {i} missing image_data"
|
| 431 |
-
|
| 432 |
-
if 'annotations' not in sample:
|
| 433 |
-
return False, f"Sample {i} missing annotations"
|
| 434 |
-
|
| 435 |
-
if 'metadata' not in sample:
|
| 436 |
-
return False, f"Sample {i} missing metadata"
|
| 437 |
-
|
| 438 |
-
return True, "Valid"
|
| 439 |
-
|
| 440 |
-
except Exception as e:
|
| 441 |
-
return False, f"Validation error: {e}"
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
class FederatedDataConverter:
|
| 445 |
-
"""Convert between RoDLA format and federated format"""
|
| 446 |
-
|
| 447 |
-
@staticmethod
|
| 448 |
-
def rodla_to_federated(rodla_batch: Dict, client_id: str,
|
| 449 |
-
privacy_level: str = 'medium') -> List[Dict]:
|
| 450 |
-
"""
|
| 451 |
-
Convert RoDLA batch format to federated sample format
|
| 452 |
-
|
| 453 |
-
Args:
|
| 454 |
-
rodla_batch: Batch from RoDLA data loader
|
| 455 |
-
client_id: Client identifier
|
| 456 |
-
privacy_level: Privacy level for augmentations
|
| 457 |
-
|
| 458 |
-
Returns:
|
| 459 |
-
List of federated samples
|
| 460 |
-
"""
|
| 461 |
-
samples = []
|
| 462 |
-
|
| 463 |
-
try:
|
| 464 |
-
# Extract batch components
|
| 465 |
-
images = rodla_batch['img']
|
| 466 |
-
img_metas = rodla_batch['img_metas']
|
| 467 |
-
|
| 468 |
-
# Handle different batch structures
|
| 469 |
-
if isinstance(rodla_batch['gt_bboxes'], list):
|
| 470 |
-
bboxes_list = rodla_batch['gt_bboxes']
|
| 471 |
-
labels_list = rodla_batch['gt_labels']
|
| 472 |
-
else:
|
| 473 |
-
# Convert tensor to list format
|
| 474 |
-
bboxes_list = [bboxes for bboxes in rodla_batch['gt_bboxes']]
|
| 475 |
-
labels_list = [labels for labels in rodla_batch['gt_labels']]
|
| 476 |
-
|
| 477 |
-
for i in range(len(images)):
|
| 478 |
-
# Convert tensor to PIL Image
|
| 479 |
-
img_tensor = images[i]
|
| 480 |
-
pil_img = DataUtils.tensor_to_pil(img_tensor)
|
| 481 |
-
|
| 482 |
-
# Prepare annotations
|
| 483 |
-
bboxes = bboxes_list[i].cpu().numpy().tolist() if hasattr(bboxes_list[i], 'cpu') else bboxes_list[i]
|
| 484 |
-
labels = labels_list[i].cpu().numpy().tolist() if hasattr(labels_list[i], 'cpu') else labels_list[i]
|
| 485 |
-
|
| 486 |
-
# Get original image info
|
| 487 |
-
img_meta = img_metas[i].data if hasattr(img_metas[i], 'data') else img_metas[i]
|
| 488 |
-
original_size = (img_meta['ori_shape'][1], img_meta['ori_shape'][0]) # (width, height)
|
| 489 |
-
|
| 490 |
-
annotations = {
|
| 491 |
-
'bboxes': bboxes,
|
| 492 |
-
'labels': labels,
|
| 493 |
-
'image_size': original_size,
|
| 494 |
-
'original_filename': img_meta.get('filename', 'unknown')
|
| 495 |
-
}
|
| 496 |
-
|
| 497 |
-
# Create augmentation info (will be filled by augmentation engine)
|
| 498 |
-
augmentation_info = {
|
| 499 |
-
'original_size': original_size,
|
| 500 |
-
'applied_transforms': [],
|
| 501 |
-
'parameters': {}
|
| 502 |
-
}
|
| 503 |
-
|
| 504 |
-
# Create sample
|
| 505 |
-
sample = {
|
| 506 |
-
'image_data': DataUtils.encode_image_to_base64(pil_img),
|
| 507 |
-
'annotations': annotations,
|
| 508 |
-
'metadata': DataUtils.create_sample_metadata(
|
| 509 |
-
client_id, privacy_level, augmentation_info,
|
| 510 |
-
img_meta.get('filename', 'unknown'))
|
| 511 |
-
}
|
| 512 |
-
|
| 513 |
-
samples.append(sample)
|
| 514 |
-
|
| 515 |
-
except Exception as e:
|
| 516 |
-
logger.error(f"Error converting RoDLA to federated format: {e}")
|
| 517 |
-
|
| 518 |
-
return samples
|
| 519 |
-
|
| 520 |
-
@staticmethod
|
| 521 |
-
def federated_to_rodla(federated_sample: Dict) -> Dict:
|
| 522 |
-
"""
|
| 523 |
-
Convert federated sample to RoDLA training format
|
| 524 |
-
|
| 525 |
-
Args:
|
| 526 |
-
federated_sample: Federated sample dictionary
|
| 527 |
-
|
| 528 |
-
Returns:
|
| 529 |
-
RoDLA format sample
|
| 530 |
-
"""
|
| 531 |
-
try:
|
| 532 |
-
# Decode image
|
| 533 |
-
image = DataUtils.decode_base64_to_image(federated_sample['image_data'])
|
| 534 |
-
if image is None:
|
| 535 |
-
raise ValueError("Failed to decode image")
|
| 536 |
-
|
| 537 |
-
# Convert to tensor (normalized)
|
| 538 |
-
img_tensor = DataUtils.pil_to_tensor(image)
|
| 539 |
-
|
| 540 |
-
# Extract annotations
|
| 541 |
-
annotations = federated_sample['annotations']
|
| 542 |
-
bboxes = torch.tensor(annotations['bboxes'], dtype=torch.float32)
|
| 543 |
-
labels = torch.tensor(annotations['labels'], dtype=torch.int64)
|
| 544 |
-
|
| 545 |
-
# Create img_meta
|
| 546 |
-
img_meta = {
|
| 547 |
-
'filename': federated_sample['metadata'].get('original_file', 'federated_sample'),
|
| 548 |
-
'ori_shape': (annotations['image_size'][1], annotations['image_size'][0], 3),
|
| 549 |
-
'img_shape': (img_tensor.shape[1], img_tensor.shape[2], 3),
|
| 550 |
-
'scale_factor': np.array([1.0, 1.0, 1.0, 1.0], dtype=np.float32),
|
| 551 |
-
'flip': False,
|
| 552 |
-
'flip_direction': None,
|
| 553 |
-
'img_norm_cfg': {
|
| 554 |
-
'mean': [123.675, 116.28, 103.53],
|
| 555 |
-
'std': [58.395, 57.12, 57.375],
|
| 556 |
-
'to_rgb': True
|
| 557 |
-
}
|
| 558 |
-
}
|
| 559 |
-
|
| 560 |
-
return {
|
| 561 |
-
'img': img_tensor,
|
| 562 |
-
'gt_bboxes': bboxes,
|
| 563 |
-
'gt_labels': labels,
|
| 564 |
-
'img_metas': img_meta
|
| 565 |
-
}
|
| 566 |
-
|
| 567 |
-
except Exception as e:
|
| 568 |
-
logger.error(f"Error converting federated to RoDLA format: {e}")
|
| 569 |
-
# Return empty sample as fallback
|
| 570 |
-
return {
|
| 571 |
-
'img': torch.zeros(3, 800, 1333),
|
| 572 |
-
'gt_bboxes': torch.zeros(0, 4),
|
| 573 |
-
'gt_labels': torch.zeros(0, dtype=torch.int64),
|
| 574 |
-
'img_metas': {}
|
| 575 |
-
}
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
# Utility functions for easy access
|
| 579 |
-
def encode_image(image: Image.Image) -> str:
|
| 580 |
-
return DataUtils.encode_image_to_base64(image)
|
| 581 |
-
|
| 582 |
-
def decode_image(image_data: str) -> Image.Image:
|
| 583 |
-
return DataUtils.decode_base64_to_image(image_data)
|
| 584 |
-
|
| 585 |
-
def validate_sample(sample: Dict) -> bool:
|
| 586 |
-
"""Quick validation of a federated sample"""
|
| 587 |
-
if 'image_data' not in sample or 'annotations' not in sample:
|
| 588 |
-
return False
|
| 589 |
-
|
| 590 |
-
image = decode_image(sample['image_data'])
|
| 591 |
-
if image is None:
|
| 592 |
-
return False
|
| 593 |
-
|
| 594 |
-
return DataUtils.validate_annotations(sample['annotations'], image.size)
|
| 595 |
-
|
| 596 |
-
# Initialize logging
|
| 597 |
-
import time
|
| 598 |
-
logging.basicConfig(
|
| 599 |
-
level=logging.INFO,
|
| 600 |
-
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
| 601 |
)
|
|
|
|
| 1 |
+
# utils/data_utils.py
|
| 2 |
+
|
| 3 |
+
import base64
|
| 4 |
+
import io
|
| 5 |
+
import json
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import cv2
|
| 10 |
+
from typing import Dict, List, Optional, Tuple
|
| 11 |
+
import os
|
| 12 |
+
import logging
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
class DataUtils:
|
| 17 |
+
"""Utility class for handling federated data processing"""
|
| 18 |
+
|
| 19 |
+
@staticmethod
|
| 20 |
+
def encode_image_to_base64(image: Image.Image, format: str = "JPEG", quality: int = 85) -> str:
|
| 21 |
+
"""
|
| 22 |
+
Encode PIL Image to base64 string
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
image: PIL Image object
|
| 26 |
+
format: Image format (JPEG, PNG)
|
| 27 |
+
quality: JPEG quality (1-100)
|
| 28 |
+
|
| 29 |
+
Returns:
|
| 30 |
+
base64 encoded string
|
| 31 |
+
"""
|
| 32 |
+
try:
|
| 33 |
+
buffered = io.BytesIO()
|
| 34 |
+
image.save(buffered, format=format, quality=quality)
|
| 35 |
+
img_str = base64.b64encode(buffered.getvalue()).decode()
|
| 36 |
+
return img_str
|
| 37 |
+
except Exception as e:
|
| 38 |
+
logger.error(f"Error encoding image to base64: {e}")
|
| 39 |
+
return ""
|
| 40 |
+
|
| 41 |
+
@staticmethod
|
| 42 |
+
def decode_base64_to_image(image_data: str) -> Optional[Image.Image]:
|
| 43 |
+
"""
|
| 44 |
+
Decode base64 string to PIL Image
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
image_data: base64 encoded image string
|
| 48 |
+
|
| 49 |
+
Returns:
|
| 50 |
+
PIL Image or None if decoding fails
|
| 51 |
+
"""
|
| 52 |
+
try:
|
| 53 |
+
if isinstance(image_data, str):
|
| 54 |
+
image_bytes = base64.b64decode(image_data)
|
| 55 |
+
else:
|
| 56 |
+
image_bytes = image_data
|
| 57 |
+
|
| 58 |
+
image = Image.open(io.BytesIO(image_bytes))
|
| 59 |
+
return image.convert('RGB') # Ensure RGB format
|
| 60 |
+
except Exception as e:
|
| 61 |
+
logger.error(f"Error decoding base64 to image: {e}")
|
| 62 |
+
return None
|
| 63 |
+
|
| 64 |
+
@staticmethod
|
| 65 |
+
def tensor_to_pil(tensor: torch.Tensor, denormalize: bool = True) -> Image.Image:
|
| 66 |
+
"""
|
| 67 |
+
Convert torch tensor to PIL Image
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
tensor: Image tensor [C, H, W]
|
| 71 |
+
denormalize: Whether to reverse ImageNet normalization
|
| 72 |
+
|
| 73 |
+
Returns:
|
| 74 |
+
PIL Image
|
| 75 |
+
"""
|
| 76 |
+
try:
|
| 77 |
+
# Detach and convert to numpy
|
| 78 |
+
if tensor.requires_grad:
|
| 79 |
+
tensor = tensor.detach()
|
| 80 |
+
|
| 81 |
+
# Move to CPU and convert to numpy
|
| 82 |
+
tensor = tensor.cpu().numpy()
|
| 83 |
+
|
| 84 |
+
# Handle different tensor shapes
|
| 85 |
+
if tensor.shape[0] == 3: # [C, H, W]
|
| 86 |
+
img_np = tensor.transpose(1, 2, 0)
|
| 87 |
+
else: # [H, W, C]
|
| 88 |
+
img_np = tensor
|
| 89 |
+
|
| 90 |
+
# Denormalize if needed (reverse ImageNet normalization)
|
| 91 |
+
if denormalize:
|
| 92 |
+
mean = np.array([123.675, 116.28, 103.53])
|
| 93 |
+
std = np.array([58.395, 57.12, 57.375])
|
| 94 |
+
img_np = img_np * std + mean
|
| 95 |
+
|
| 96 |
+
# Clip and convert to uint8
|
| 97 |
+
img_np = np.clip(img_np, 0, 255).astype(np.uint8)
|
| 98 |
+
|
| 99 |
+
return Image.fromarray(img_np)
|
| 100 |
+
except Exception as e:
|
| 101 |
+
logger.error(f"Error converting tensor to PIL: {e}")
|
| 102 |
+
# Return a blank image as fallback
|
| 103 |
+
return Image.new('RGB', (224, 224), color='white')
|
| 104 |
+
|
| 105 |
+
@staticmethod
|
| 106 |
+
def pil_to_tensor(image: Image.Image, normalize: bool = True) -> torch.Tensor:
|
| 107 |
+
"""
|
| 108 |
+
Convert PIL Image to normalized torch tensor
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
image: PIL Image
|
| 112 |
+
normalize: Whether to apply ImageNet normalization
|
| 113 |
+
|
| 114 |
+
Returns:
|
| 115 |
+
Normalized tensor [C, H, W]
|
| 116 |
+
"""
|
| 117 |
+
try:
|
| 118 |
+
# Convert to numpy
|
| 119 |
+
img_np = np.array(image).astype(np.float32)
|
| 120 |
+
|
| 121 |
+
# Convert RGB to BGR if needed (OpenCV format)
|
| 122 |
+
if img_np.shape[2] == 3:
|
| 123 |
+
img_np = img_np[:, :, ::-1] # RGB to BGR
|
| 124 |
+
|
| 125 |
+
# Normalize
|
| 126 |
+
if normalize:
|
| 127 |
+
mean = np.array([123.675, 116.28, 103.53])
|
| 128 |
+
std = np.array([58.395, 57.12, 57.375])
|
| 129 |
+
img_np = (img_np - mean) / std
|
| 130 |
+
|
| 131 |
+
# Convert to tensor and rearrange dimensions
|
| 132 |
+
tensor = torch.from_numpy(img_np.transpose(2, 0, 1))
|
| 133 |
+
|
| 134 |
+
return tensor
|
| 135 |
+
except Exception as e:
|
| 136 |
+
logger.error(f"Error converting PIL to tensor: {e}")
|
| 137 |
+
return torch.zeros(3, 224, 224)
|
| 138 |
+
|
| 139 |
+
@staticmethod
|
| 140 |
+
def validate_annotations(annotations: Dict, image_size: Tuple[int, int]) -> bool:
|
| 141 |
+
"""
|
| 142 |
+
Validate annotation format and values
|
| 143 |
+
|
| 144 |
+
Args:
|
| 145 |
+
annotations: Annotation dictionary
|
| 146 |
+
image_size: (width, height) of image
|
| 147 |
+
|
| 148 |
+
Returns:
|
| 149 |
+
True if valid, False otherwise
|
| 150 |
+
"""
|
| 151 |
+
try:
|
| 152 |
+
required_keys = ['bboxes', 'labels', 'image_size']
|
| 153 |
+
|
| 154 |
+
# Check required keys
|
| 155 |
+
for key in required_keys:
|
| 156 |
+
if key not in annotations:
|
| 157 |
+
logger.warning(f"Missing required key in annotations: {key}")
|
| 158 |
+
return False
|
| 159 |
+
|
| 160 |
+
# Validate bboxes
|
| 161 |
+
bboxes = annotations['bboxes']
|
| 162 |
+
if not isinstance(bboxes, list):
|
| 163 |
+
logger.warning("Bboxes must be a list")
|
| 164 |
+
return False
|
| 165 |
+
|
| 166 |
+
for bbox in bboxes:
|
| 167 |
+
if not isinstance(bbox, list) or len(bbox) != 4:
|
| 168 |
+
logger.warning(f"Invalid bbox format: {bbox}")
|
| 169 |
+
return False
|
| 170 |
+
|
| 171 |
+
# Check if bbox coordinates are within image bounds
|
| 172 |
+
x1, y1, x2, y2 = bbox
|
| 173 |
+
if x1 < 0 or y1 < 0 or x2 > image_size[0] or y2 > image_size[1]:
|
| 174 |
+
logger.warning(f"Bbox out of image bounds: {bbox}, image_size: {image_size}")
|
| 175 |
+
return False
|
| 176 |
+
|
| 177 |
+
# Validate labels
|
| 178 |
+
labels = annotations['labels']
|
| 179 |
+
if not isinstance(labels, list):
|
| 180 |
+
logger.warning("Labels must be a list")
|
| 181 |
+
return False
|
| 182 |
+
|
| 183 |
+
if len(bboxes) != len(labels):
|
| 184 |
+
logger.warning("Number of bboxes and labels must match")
|
| 185 |
+
return False
|
| 186 |
+
|
| 187 |
+
# Validate label values (M6Doc has 75 classes)
|
| 188 |
+
for label in labels:
|
| 189 |
+
if not isinstance(label, int) or label < 0 or label >= 75:
|
| 190 |
+
logger.warning(f"Invalid label: {label}")
|
| 191 |
+
return False
|
| 192 |
+
|
| 193 |
+
return True
|
| 194 |
+
|
| 195 |
+
except Exception as e:
|
| 196 |
+
logger.error(f"Error validating annotations: {e}")
|
| 197 |
+
return False
|
| 198 |
+
|
| 199 |
+
@staticmethod
|
| 200 |
+
def adjust_bboxes_for_transformation(bboxes: List[List[float]],
|
| 201 |
+
original_size: Tuple[int, int],
|
| 202 |
+
new_size: Tuple[int, int],
|
| 203 |
+
transform_info: Dict) -> List[List[float]]:
|
| 204 |
+
"""
|
| 205 |
+
Adjust bounding boxes for image transformations
|
| 206 |
+
|
| 207 |
+
Args:
|
| 208 |
+
bboxes: List of [x1, y1, x2, y2]
|
| 209 |
+
original_size: (width, height) of original image
|
| 210 |
+
new_size: (width, height) of transformed image
|
| 211 |
+
transform_info: Information about applied transformations
|
| 212 |
+
|
| 213 |
+
Returns:
|
| 214 |
+
Adjusted bounding boxes
|
| 215 |
+
"""
|
| 216 |
+
try:
|
| 217 |
+
adjusted_bboxes = []
|
| 218 |
+
orig_w, orig_h = original_size
|
| 219 |
+
new_w, new_h = new_size
|
| 220 |
+
|
| 221 |
+
scale_x = new_w / orig_w
|
| 222 |
+
scale_y = new_h / orig_h
|
| 223 |
+
|
| 224 |
+
for bbox in bboxes:
|
| 225 |
+
x1, y1, x2, y2 = bbox
|
| 226 |
+
|
| 227 |
+
# Apply scaling
|
| 228 |
+
x1 = x1 * scale_x
|
| 229 |
+
y1 = y1 * scale_y
|
| 230 |
+
x2 = x2 * scale_x
|
| 231 |
+
y2 = y2 * scale_y
|
| 232 |
+
|
| 233 |
+
# Apply rotation if present
|
| 234 |
+
if 'rotation' in transform_info:
|
| 235 |
+
angle = transform_info['rotation']
|
| 236 |
+
# Simplified rotation adjustment (for small angles)
|
| 237 |
+
if abs(angle) > 5:
|
| 238 |
+
# For significant rotations, we'd need proper affine transformation
|
| 239 |
+
# This is a simplified version
|
| 240 |
+
center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2
|
| 241 |
+
# Approximate adjustment - in practice, use proper rotation matrix
|
| 242 |
+
pass
|
| 243 |
+
|
| 244 |
+
adjusted_bboxes.append([x1, y1, x2, y2])
|
| 245 |
+
|
| 246 |
+
return adjusted_bboxes
|
| 247 |
+
|
| 248 |
+
except Exception as e:
|
| 249 |
+
logger.error(f"Error adjusting bboxes: {e}")
|
| 250 |
+
return bboxes
|
| 251 |
+
|
| 252 |
+
@staticmethod
|
| 253 |
+
def create_sample_metadata(client_id: str,
|
| 254 |
+
privacy_level: str,
|
| 255 |
+
augmentation_info: Dict,
|
| 256 |
+
original_file: str = "") -> Dict:
|
| 257 |
+
"""
|
| 258 |
+
Create standardized metadata for federated samples
|
| 259 |
+
|
| 260 |
+
Args:
|
| 261 |
+
client_id: Identifier for the client
|
| 262 |
+
privacy_level: Privacy level (low/medium/high)
|
| 263 |
+
augmentation_info: Information about applied augmentations
|
| 264 |
+
original_file: Original filename (optional)
|
| 265 |
+
|
| 266 |
+
Returns:
|
| 267 |
+
Metadata dictionary
|
| 268 |
+
"""
|
| 269 |
+
return {
|
| 270 |
+
'client_id': client_id,
|
| 271 |
+
'privacy_level': privacy_level,
|
| 272 |
+
'augmentation_info': augmentation_info,
|
| 273 |
+
'original_file': original_file,
|
| 274 |
+
'timestamp': int(time.time()),
|
| 275 |
+
'version': '1.0'
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
@staticmethod
|
| 279 |
+
def calculate_privacy_score(augmentation_info: Dict) -> float:
|
| 280 |
+
"""
|
| 281 |
+
Calculate a privacy score based on augmentation strength
|
| 282 |
+
|
| 283 |
+
Args:
|
| 284 |
+
augmentation_info: Information about applied augmentations
|
| 285 |
+
|
| 286 |
+
Returns:
|
| 287 |
+
Privacy score between 0 (low privacy) and 1 (high privacy)
|
| 288 |
+
"""
|
| 289 |
+
score = 0.0
|
| 290 |
+
transforms = augmentation_info.get('applied_transforms', [])
|
| 291 |
+
parameters = augmentation_info.get('parameters', {})
|
| 292 |
+
|
| 293 |
+
# Score based on number and strength of transformations
|
| 294 |
+
if 'rotation' in transforms:
|
| 295 |
+
angle = abs(parameters.get('rotation_angle', 0))
|
| 296 |
+
score += min(angle / 15.0, 1.0) * 0.2
|
| 297 |
+
|
| 298 |
+
if 'scaling' in transforms:
|
| 299 |
+
scale = parameters.get('scale_factor', 1.0)
|
| 300 |
+
deviation = abs(scale - 1.0)
|
| 301 |
+
score += min(deviation / 0.3, 1.0) * 0.2
|
| 302 |
+
|
| 303 |
+
if 'perspective' in transforms:
|
| 304 |
+
score += 0.3
|
| 305 |
+
|
| 306 |
+
if 'gaussian_blur' in transforms:
|
| 307 |
+
radius = parameters.get('blur_radius', 0)
|
| 308 |
+
score += min(radius / 2.0, 1.0) * 0.15
|
| 309 |
+
|
| 310 |
+
if 'gaussian_noise' in transforms:
|
| 311 |
+
score += 0.15
|
| 312 |
+
|
| 313 |
+
return min(score, 1.0)
|
| 314 |
+
|
| 315 |
+
@staticmethod
|
| 316 |
+
def save_federated_sample(sample: Dict, output_dir: str, sample_id: str) -> bool:
|
| 317 |
+
"""
|
| 318 |
+
Save federated sample to disk
|
| 319 |
+
|
| 320 |
+
Args:
|
| 321 |
+
sample: Sample dictionary
|
| 322 |
+
output_dir: Output directory
|
| 323 |
+
sample_id: Unique sample identifier
|
| 324 |
+
|
| 325 |
+
Returns:
|
| 326 |
+
True if successful, False otherwise
|
| 327 |
+
"""
|
| 328 |
+
try:
|
| 329 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 330 |
+
|
| 331 |
+
# Save image
|
| 332 |
+
image = DataUtils.decode_base64_to_image(sample['image_data'])
|
| 333 |
+
if image:
|
| 334 |
+
image_path = os.path.join(output_dir, f"{sample_id}.jpg")
|
| 335 |
+
image.save(image_path, "JPEG", quality=85)
|
| 336 |
+
|
| 337 |
+
# Save annotations and metadata
|
| 338 |
+
metadata_path = os.path.join(output_dir, f"{sample_id}.json")
|
| 339 |
+
with open(metadata_path, 'w') as f:
|
| 340 |
+
json.dump({
|
| 341 |
+
'annotations': sample['annotations'],
|
| 342 |
+
'metadata': sample['metadata']
|
| 343 |
+
}, f, indent=2)
|
| 344 |
+
|
| 345 |
+
return True
|
| 346 |
+
|
| 347 |
+
except Exception as e:
|
| 348 |
+
logger.error(f"Error saving federated sample: {e}")
|
| 349 |
+
return False
|
| 350 |
+
|
| 351 |
+
@staticmethod
|
| 352 |
+
def load_federated_sample(input_dir: str, sample_id: str) -> Optional[Dict]:
|
| 353 |
+
"""
|
| 354 |
+
Load federated sample from disk
|
| 355 |
+
|
| 356 |
+
Args:
|
| 357 |
+
input_dir: Input directory
|
| 358 |
+
sample_id: Sample identifier
|
| 359 |
+
|
| 360 |
+
Returns:
|
| 361 |
+
Sample dictionary or None if loading fails
|
| 362 |
+
"""
|
| 363 |
+
try:
|
| 364 |
+
# Load image
|
| 365 |
+
image_path = os.path.join(input_dir, f"{sample_id}.jpg")
|
| 366 |
+
with open(image_path, 'rb') as f:
|
| 367 |
+
image_data = base64.b64encode(f.read()).decode()
|
| 368 |
+
|
| 369 |
+
# Load metadata
|
| 370 |
+
metadata_path = os.path.join(input_dir, f"{sample_id}.json")
|
| 371 |
+
with open(metadata_path, 'r') as f:
|
| 372 |
+
metadata = json.load(f)
|
| 373 |
+
|
| 374 |
+
return {
|
| 375 |
+
'image_data': image_data,
|
| 376 |
+
'annotations': metadata['annotations'],
|
| 377 |
+
'metadata': metadata['metadata']
|
| 378 |
+
}
|
| 379 |
+
|
| 380 |
+
except Exception as e:
|
| 381 |
+
logger.error(f"Error loading federated sample: {e}")
|
| 382 |
+
return None
|
| 383 |
+
|
| 384 |
+
@staticmethod
|
| 385 |
+
def create_federated_batch(samples: List[Dict]) -> Dict:
|
| 386 |
+
"""
|
| 387 |
+
Create a batch of federated samples for transmission
|
| 388 |
+
|
| 389 |
+
Args:
|
| 390 |
+
samples: List of sample dictionaries
|
| 391 |
+
|
| 392 |
+
Returns:
|
| 393 |
+
Batch dictionary
|
| 394 |
+
"""
|
| 395 |
+
return {
|
| 396 |
+
'batch_id': str(int(time.time())),
|
| 397 |
+
'samples': samples,
|
| 398 |
+
'batch_size': len(samples),
|
| 399 |
+
'total_clients': len(set(sample['metadata']['client_id'] for sample in samples)),
|
| 400 |
+
'average_privacy_score': np.mean([DataUtils.calculate_privacy_score(
|
| 401 |
+
sample['metadata']['augmentation_info']) for sample in samples])
|
| 402 |
+
}
|
| 403 |
+
|
| 404 |
+
@staticmethod
|
| 405 |
+
def validate_federated_batch(batch: Dict) -> Tuple[bool, str]:
|
| 406 |
+
"""
|
| 407 |
+
Validate a federated batch
|
| 408 |
+
|
| 409 |
+
Args:
|
| 410 |
+
batch: Batch dictionary
|
| 411 |
+
|
| 412 |
+
Returns:
|
| 413 |
+
(is_valid, error_message)
|
| 414 |
+
"""
|
| 415 |
+
try:
|
| 416 |
+
required_keys = ['batch_id', 'samples', 'batch_size']
|
| 417 |
+
for key in required_keys:
|
| 418 |
+
if key not in batch:
|
| 419 |
+
return False, f"Missing required key: {key}"
|
| 420 |
+
|
| 421 |
+
if not isinstance(batch['samples'], list):
|
| 422 |
+
return False, "Samples must be a list"
|
| 423 |
+
|
| 424 |
+
if len(batch['samples']) != batch['batch_size']:
|
| 425 |
+
return False, "Batch size doesn't match number of samples"
|
| 426 |
+
|
| 427 |
+
# Validate each sample
|
| 428 |
+
for i, sample in enumerate(batch['samples']):
|
| 429 |
+
if 'image_data' not in sample:
|
| 430 |
+
return False, f"Sample {i} missing image_data"
|
| 431 |
+
|
| 432 |
+
if 'annotations' not in sample:
|
| 433 |
+
return False, f"Sample {i} missing annotations"
|
| 434 |
+
|
| 435 |
+
if 'metadata' not in sample:
|
| 436 |
+
return False, f"Sample {i} missing metadata"
|
| 437 |
+
|
| 438 |
+
return True, "Valid"
|
| 439 |
+
|
| 440 |
+
except Exception as e:
|
| 441 |
+
return False, f"Validation error: {e}"
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
class FederatedDataConverter:
|
| 445 |
+
"""Convert between RoDLA format and federated format"""
|
| 446 |
+
|
| 447 |
+
@staticmethod
|
| 448 |
+
def rodla_to_federated(rodla_batch: Dict, client_id: str,
|
| 449 |
+
privacy_level: str = 'medium') -> List[Dict]:
|
| 450 |
+
"""
|
| 451 |
+
Convert RoDLA batch format to federated sample format
|
| 452 |
+
|
| 453 |
+
Args:
|
| 454 |
+
rodla_batch: Batch from RoDLA data loader
|
| 455 |
+
client_id: Client identifier
|
| 456 |
+
privacy_level: Privacy level for augmentations
|
| 457 |
+
|
| 458 |
+
Returns:
|
| 459 |
+
List of federated samples
|
| 460 |
+
"""
|
| 461 |
+
samples = []
|
| 462 |
+
|
| 463 |
+
try:
|
| 464 |
+
# Extract batch components
|
| 465 |
+
images = rodla_batch['img']
|
| 466 |
+
img_metas = rodla_batch['img_metas']
|
| 467 |
+
|
| 468 |
+
# Handle different batch structures
|
| 469 |
+
if isinstance(rodla_batch['gt_bboxes'], list):
|
| 470 |
+
bboxes_list = rodla_batch['gt_bboxes']
|
| 471 |
+
labels_list = rodla_batch['gt_labels']
|
| 472 |
+
else:
|
| 473 |
+
# Convert tensor to list format
|
| 474 |
+
bboxes_list = [bboxes for bboxes in rodla_batch['gt_bboxes']]
|
| 475 |
+
labels_list = [labels for labels in rodla_batch['gt_labels']]
|
| 476 |
+
|
| 477 |
+
for i in range(len(images)):
|
| 478 |
+
# Convert tensor to PIL Image
|
| 479 |
+
img_tensor = images[i]
|
| 480 |
+
pil_img = DataUtils.tensor_to_pil(img_tensor)
|
| 481 |
+
|
| 482 |
+
# Prepare annotations
|
| 483 |
+
bboxes = bboxes_list[i].cpu().numpy().tolist() if hasattr(bboxes_list[i], 'cpu') else bboxes_list[i]
|
| 484 |
+
labels = labels_list[i].cpu().numpy().tolist() if hasattr(labels_list[i], 'cpu') else labels_list[i]
|
| 485 |
+
|
| 486 |
+
# Get original image info
|
| 487 |
+
img_meta = img_metas[i].data if hasattr(img_metas[i], 'data') else img_metas[i]
|
| 488 |
+
original_size = (img_meta['ori_shape'][1], img_meta['ori_shape'][0]) # (width, height)
|
| 489 |
+
|
| 490 |
+
annotations = {
|
| 491 |
+
'bboxes': bboxes,
|
| 492 |
+
'labels': labels,
|
| 493 |
+
'image_size': original_size,
|
| 494 |
+
'original_filename': img_meta.get('filename', 'unknown')
|
| 495 |
+
}
|
| 496 |
+
|
| 497 |
+
# Create augmentation info (will be filled by augmentation engine)
|
| 498 |
+
augmentation_info = {
|
| 499 |
+
'original_size': original_size,
|
| 500 |
+
'applied_transforms': [],
|
| 501 |
+
'parameters': {}
|
| 502 |
+
}
|
| 503 |
+
|
| 504 |
+
# Create sample
|
| 505 |
+
sample = {
|
| 506 |
+
'image_data': DataUtils.encode_image_to_base64(pil_img),
|
| 507 |
+
'annotations': annotations,
|
| 508 |
+
'metadata': DataUtils.create_sample_metadata(
|
| 509 |
+
client_id, privacy_level, augmentation_info,
|
| 510 |
+
img_meta.get('filename', 'unknown'))
|
| 511 |
+
}
|
| 512 |
+
|
| 513 |
+
samples.append(sample)
|
| 514 |
+
|
| 515 |
+
except Exception as e:
|
| 516 |
+
logger.error(f"Error converting RoDLA to federated format: {e}")
|
| 517 |
+
|
| 518 |
+
return samples
|
| 519 |
+
|
| 520 |
+
@staticmethod
|
| 521 |
+
def federated_to_rodla(federated_sample: Dict) -> Dict:
|
| 522 |
+
"""
|
| 523 |
+
Convert federated sample to RoDLA training format
|
| 524 |
+
|
| 525 |
+
Args:
|
| 526 |
+
federated_sample: Federated sample dictionary
|
| 527 |
+
|
| 528 |
+
Returns:
|
| 529 |
+
RoDLA format sample
|
| 530 |
+
"""
|
| 531 |
+
try:
|
| 532 |
+
# Decode image
|
| 533 |
+
image = DataUtils.decode_base64_to_image(federated_sample['image_data'])
|
| 534 |
+
if image is None:
|
| 535 |
+
raise ValueError("Failed to decode image")
|
| 536 |
+
|
| 537 |
+
# Convert to tensor (normalized)
|
| 538 |
+
img_tensor = DataUtils.pil_to_tensor(image)
|
| 539 |
+
|
| 540 |
+
# Extract annotations
|
| 541 |
+
annotations = federated_sample['annotations']
|
| 542 |
+
bboxes = torch.tensor(annotations['bboxes'], dtype=torch.float32)
|
| 543 |
+
labels = torch.tensor(annotations['labels'], dtype=torch.int64)
|
| 544 |
+
|
| 545 |
+
# Create img_meta
|
| 546 |
+
img_meta = {
|
| 547 |
+
'filename': federated_sample['metadata'].get('original_file', 'federated_sample'),
|
| 548 |
+
'ori_shape': (annotations['image_size'][1], annotations['image_size'][0], 3),
|
| 549 |
+
'img_shape': (img_tensor.shape[1], img_tensor.shape[2], 3),
|
| 550 |
+
'scale_factor': np.array([1.0, 1.0, 1.0, 1.0], dtype=np.float32),
|
| 551 |
+
'flip': False,
|
| 552 |
+
'flip_direction': None,
|
| 553 |
+
'img_norm_cfg': {
|
| 554 |
+
'mean': [123.675, 116.28, 103.53],
|
| 555 |
+
'std': [58.395, 57.12, 57.375],
|
| 556 |
+
'to_rgb': True
|
| 557 |
+
}
|
| 558 |
+
}
|
| 559 |
+
|
| 560 |
+
return {
|
| 561 |
+
'img': img_tensor,
|
| 562 |
+
'gt_bboxes': bboxes,
|
| 563 |
+
'gt_labels': labels,
|
| 564 |
+
'img_metas': img_meta
|
| 565 |
+
}
|
| 566 |
+
|
| 567 |
+
except Exception as e:
|
| 568 |
+
logger.error(f"Error converting federated to RoDLA format: {e}")
|
| 569 |
+
# Return empty sample as fallback
|
| 570 |
+
return {
|
| 571 |
+
'img': torch.zeros(3, 800, 1333),
|
| 572 |
+
'gt_bboxes': torch.zeros(0, 4),
|
| 573 |
+
'gt_labels': torch.zeros(0, dtype=torch.int64),
|
| 574 |
+
'img_metas': {}
|
| 575 |
+
}
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
# Utility functions for easy access
|
| 579 |
+
def encode_image(image: Image.Image) -> str:
|
| 580 |
+
return DataUtils.encode_image_to_base64(image)
|
| 581 |
+
|
| 582 |
+
def decode_image(image_data: str) -> Image.Image:
|
| 583 |
+
return DataUtils.decode_base64_to_image(image_data)
|
| 584 |
+
|
| 585 |
+
def validate_sample(sample: Dict) -> bool:
|
| 586 |
+
"""Quick validation of a federated sample"""
|
| 587 |
+
if 'image_data' not in sample or 'annotations' not in sample:
|
| 588 |
+
return False
|
| 589 |
+
|
| 590 |
+
image = decode_image(sample['image_data'])
|
| 591 |
+
if image is None:
|
| 592 |
+
return False
|
| 593 |
+
|
| 594 |
+
return DataUtils.validate_annotations(sample['annotations'], image.size)
|
| 595 |
+
|
| 596 |
+
# Initialize logging
|
| 597 |
+
import time
|
| 598 |
+
logging.basicConfig(
|
| 599 |
+
level=logging.INFO,
|
| 600 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
| 601 |
)
|
finetuning_rodla/finetuning_rodla/checkpoints/internimage_xl_22k_192to384.pth
ADDED
|
File without changes
|
finetuning_rodla/finetuning_rodla/checkpoints/rodla_internimage_xl_publaynet.pth
ADDED
|
File without changes
|
finetuning_rodla/finetuning_rodla/configs/docbank/rodla_internimage_docbank.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# RoDLA Fine-tuning Configuration for DocBank
|
| 2 |
+
# CVPR 2024 - Document Layout Analysis
|
| 3 |
+
|
| 4 |
+
_base_ = [
|
| 5 |
+
'../_base_/datasets/coco_detection.py',
|
| 6 |
+
'../_base_/schedules/schedule_1x.py',
|
| 7 |
+
'../_base_/default_runtime.py'
|
| 8 |
+
]
|
| 9 |
+
|
| 10 |
+
# Pre-trained RoDLA weights from PubLayNet
|
| 11 |
+
pretrained = 'checkpoints/rodla_internimage_xl_publaynet.pth'
|
| 12 |
+
|
| 13 |
+
model = dict(
|
| 14 |
+
type='ATSS',
|
| 15 |
+
backbone=dict(
|
| 16 |
+
_delete_=True,
|
| 17 |
+
type='InternImage',
|
| 18 |
+
core_op='DCNv3',
|
| 19 |
+
channels=192,
|
| 20 |
+
depths=[5, 5, 22, 5],
|
| 21 |
+
groups=[12, 24, 48, 96],
|
| 22 |
+
mlp_ratio=4.,
|
| 23 |
+
drop_path_rate=0.3, # Reduced for fine-tuning
|
| 24 |
+
norm_layer='LN',
|
| 25 |
+
layer_scale=1.0,
|
| 26 |
+
offset_scale=2.0,
|
| 27 |
+
post_norm=True,
|
| 28 |
+
with_cp=True,
|
| 29 |
+
out_indices=(1, 2, 3),
|
| 30 |
+
init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
|
| 31 |
+
neck=dict(
|
| 32 |
+
type='FPN',
|
| 33 |
+
in_channels=[384, 768, 1536],
|
| 34 |
+
out_channels=256,
|
| 35 |
+
num_outs=5),
|
| 36 |
+
bbox_head=dict(
|
| 37 |
+
type='ATSSHead',
|
| 38 |
+
num_classes=11, # DocBank classes
|
| 39 |
+
in_channels=256,
|
| 40 |
+
stacked_convs=4,
|
| 41 |
+
feat_channels=256,
|
| 42 |
+
anchor_generator=dict(
|
| 43 |
+
type='AnchorGenerator',
|
| 44 |
+
ratios=[1.0],
|
| 45 |
+
octave_base_scale=8,
|
| 46 |
+
scales_per_octave=1,
|
| 47 |
+
strides=[8, 16, 32, 64, 128]),
|
| 48 |
+
bbox_coder=dict(
|
| 49 |
+
type='DeltaXYWHBBoxCoder',
|
| 50 |
+
target_means=[.0, .0, .0, .0],
|
| 51 |
+
target_stds=[0.1, 0.1, 0.2, 0.2]),
|
| 52 |
+
loss_cls=dict(
|
| 53 |
+
type='FocalLoss',
|
| 54 |
+
use_sigmoid=True,
|
| 55 |
+
gamma=2.0,
|
| 56 |
+
alpha=0.25,
|
| 57 |
+
loss_weight=1.0),
|
| 58 |
+
loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
|
| 59 |
+
train_cfg=dict(
|
| 60 |
+
assigner=dict(type='ATSSAssigner', topk=9),
|
| 61 |
+
allowed_border=-1,
|
| 62 |
+
pos_weight=-1,
|
| 63 |
+
debug=False),
|
| 64 |
+
test_cfg=dict(
|
| 65 |
+
nms_pre=1000,
|
| 66 |
+
min_bbox_size=0,
|
| 67 |
+
score_thr=0.05,
|
| 68 |
+
nms=dict(type='nms', iou_threshold=0.6),
|
| 69 |
+
max_per_img=100)))
|
| 70 |
+
|
| 71 |
+
# Dataset settings for DocBank
|
| 72 |
+
dataset_type = 'CocoDataset'
|
| 73 |
+
data_root = 'data/DocBank_coco/'
|
| 74 |
+
|
| 75 |
+
# DocBank classes
|
| 76 |
+
classes = ('abstract', 'author', 'caption', 'equation', 'figure',
|
| 77 |
+
'footer', 'list', 'paragraph', 'reference', 'section', 'table')
|
| 78 |
+
|
| 79 |
+
img_norm_cfg = dict(
|
| 80 |
+
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
| 81 |
+
|
| 82 |
+
# Fine-tuning pipeline (simpler than full training)
|
| 83 |
+
train_pipeline = [
|
| 84 |
+
dict(type='LoadImageFromFile'),
|
| 85 |
+
dict(type='LoadAnnotations', with_bbox=True),
|
| 86 |
+
dict(type='Resize', img_scale=[(1333, 800)], keep_ratio=True),
|
| 87 |
+
dict(type='RandomFlip', flip_ratio=0.5),
|
| 88 |
+
dict(type='Normalize', **img_norm_cfg),
|
| 89 |
+
dict(type='Pad', size_divisor=32),
|
| 90 |
+
dict(type='DefaultFormatBundle'),
|
| 91 |
+
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
|
| 92 |
+
]
|
| 93 |
+
|
| 94 |
+
test_pipeline = [
|
| 95 |
+
dict(type='LoadImageFromFile'),
|
| 96 |
+
dict(
|
| 97 |
+
type='MultiScaleFlipAug',
|
| 98 |
+
img_scale=(1333, 800),
|
| 99 |
+
flip=False,
|
| 100 |
+
transforms=[
|
| 101 |
+
dict(type='Resize', keep_ratio=True),
|
| 102 |
+
dict(type='RandomFlip'),
|
| 103 |
+
dict(type='Normalize', **img_norm_cfg),
|
| 104 |
+
dict(type='Pad', size_divisor=32),
|
| 105 |
+
dict(type='ImageToTensor', keys=['img']),
|
| 106 |
+
dict(type='Collect', keys=['img']),
|
| 107 |
+
])
|
| 108 |
+
]
|
| 109 |
+
|
| 110 |
+
data = dict(
|
| 111 |
+
samples_per_gpu=2,
|
| 112 |
+
workers_per_gpu=2,
|
| 113 |
+
train=dict(
|
| 114 |
+
type=dataset_type,
|
| 115 |
+
ann_file=data_root + 'annotations/train.json',
|
| 116 |
+
img_prefix=data_root + 'images/',
|
| 117 |
+
classes=classes,
|
| 118 |
+
pipeline=train_pipeline),
|
| 119 |
+
val=dict(
|
| 120 |
+
type=dataset_type,
|
| 121 |
+
ann_file=data_root + 'annotations/val.json',
|
| 122 |
+
img_prefix=data_root + 'images/',
|
| 123 |
+
classes=classes,
|
| 124 |
+
pipeline=test_pipeline),
|
| 125 |
+
test=dict(
|
| 126 |
+
type=dataset_type,
|
| 127 |
+
ann_file=data_root + 'annotations/val.json', # Using val for test during fine-tuning
|
| 128 |
+
img_prefix=data_root + 'images/',
|
| 129 |
+
classes=classes,
|
| 130 |
+
pipeline=test_pipeline))
|
| 131 |
+
|
| 132 |
+
# Fine-tuning optimizer (lower learning rate)
|
| 133 |
+
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
|
| 134 |
+
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
|
| 135 |
+
|
| 136 |
+
# Fine-tuning schedule
|
| 137 |
+
lr_config = dict(
|
| 138 |
+
policy='step',
|
| 139 |
+
warmup='linear',
|
| 140 |
+
warmup_iters=500,
|
| 141 |
+
warmup_ratio=0.001,
|
| 142 |
+
step=[8, 11])
|
| 143 |
+
|
| 144 |
+
runner = dict(type='EpochBasedRunner', max_epochs=12)
|
| 145 |
+
|
| 146 |
+
# Evaluation and logging
|
| 147 |
+
evaluation = dict(interval=1, metric='bbox')
|
| 148 |
+
checkpoint_config = dict(interval=1, max_keep_ckpts=3)
|
| 149 |
+
log_config = dict(
|
| 150 |
+
interval=50,
|
| 151 |
+
hooks=[
|
| 152 |
+
dict(type='TextLoggerHook'),
|
| 153 |
+
# dict(type='TensorboardLoggerHook')
|
| 154 |
+
])
|
| 155 |
+
|
| 156 |
+
# Work directory
|
| 157 |
+
work_dir = './work_dirs/rodla_docbank'
|
finetuning_rodla/finetuning_rodla/data/docbank_coco.json
ADDED
|
@@ -0,0 +1,635 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"images": [
|
| 3 |
+
{
|
| 4 |
+
"id": 1,
|
| 5 |
+
"file_name": "69.tar_1406.0846.gz_3Potts_9_ori.jpg",
|
| 6 |
+
"width": 1700,
|
| 7 |
+
"height": 2200
|
| 8 |
+
},
|
| 9 |
+
{
|
| 10 |
+
"id": 2,
|
| 11 |
+
"file_name": "63.tar_1504.07006.gz_mayak_arxiv_20141204_7_ori.jpg",
|
| 12 |
+
"width": 1654,
|
| 13 |
+
"height": 2339
|
| 14 |
+
},
|
| 15 |
+
{
|
| 16 |
+
"id": 3,
|
| 17 |
+
"file_name": "242.tar_1612.03168.gz_biomimetics_5_ori.jpg",
|
| 18 |
+
"width": 1654,
|
| 19 |
+
"height": 2339
|
| 20 |
+
},
|
| 21 |
+
{
|
| 22 |
+
"id": 4,
|
| 23 |
+
"file_name": "152.tar_1608.03834.gz_fragility_II_05062016_AZ_2_ori.jpg",
|
| 24 |
+
"width": 1700,
|
| 25 |
+
"height": 2200
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"id": 5,
|
| 29 |
+
"file_name": "91.tar_1605.05268.gz_Tunnelingtime12_0_ori.jpg",
|
| 30 |
+
"width": 1700,
|
| 31 |
+
"height": 2200
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"id": 6,
|
| 35 |
+
"file_name": "33.tar_1602.07924.gz_TaS2_arxiv_11_ori.jpg",
|
| 36 |
+
"width": 1700,
|
| 37 |
+
"height": 2200
|
| 38 |
+
},
|
| 39 |
+
{
|
| 40 |
+
"id": 7,
|
| 41 |
+
"file_name": "215.tar_1611.01871.gz_rsv16v1_10_ori.jpg",
|
| 42 |
+
"width": 1654,
|
| 43 |
+
"height": 2339
|
| 44 |
+
},
|
| 45 |
+
{
|
| 46 |
+
"id": 8,
|
| 47 |
+
"file_name": "212.tar_1807.09084.gz_pollicott-dimaff-arxiv_66_ori.jpg",
|
| 48 |
+
"width": 1700,
|
| 49 |
+
"height": 2200
|
| 50 |
+
},
|
| 51 |
+
{
|
| 52 |
+
"id": 9,
|
| 53 |
+
"file_name": "190.tar_1807.01208.gz_article_2_ori.jpg",
|
| 54 |
+
"width": 1654,
|
| 55 |
+
"height": 2339
|
| 56 |
+
},
|
| 57 |
+
{
|
| 58 |
+
"id": 10,
|
| 59 |
+
"file_name": "272.tar_1809.07187.gz_author_FKThielemann_finb_7_ori.jpg",
|
| 60 |
+
"width": 1700,
|
| 61 |
+
"height": 2200
|
| 62 |
+
},
|
| 63 |
+
{
|
| 64 |
+
"id": 11,
|
| 65 |
+
"file_name": "95.tar_1506.05778.gz_NiO=ferro3_11_ori.jpg",
|
| 66 |
+
"width": 1700,
|
| 67 |
+
"height": 2200
|
| 68 |
+
},
|
| 69 |
+
{
|
| 70 |
+
"id": 12,
|
| 71 |
+
"file_name": "221.tar_1611.05073.gz_VNLM_arxiv_29_ori.jpg",
|
| 72 |
+
"width": 1700,
|
| 73 |
+
"height": 2200
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"id": 13,
|
| 77 |
+
"file_name": "89.tar_1407.4134.gz_NMSSM_EWPT_submission_2_26_ori.jpg",
|
| 78 |
+
"width": 1700,
|
| 79 |
+
"height": 2200
|
| 80 |
+
},
|
| 81 |
+
{
|
| 82 |
+
"id": 14,
|
| 83 |
+
"file_name": "253.tar_1809.00537.gz_main_5_ori.jpg",
|
| 84 |
+
"width": 1654,
|
| 85 |
+
"height": 2339
|
| 86 |
+
},
|
| 87 |
+
{
|
| 88 |
+
"id": 15,
|
| 89 |
+
"file_name": "228.tar_1611.07901.gz_efield_arxiv_3_ori.jpg",
|
| 90 |
+
"width": 1654,
|
| 91 |
+
"height": 2339
|
| 92 |
+
},
|
| 93 |
+
{
|
| 94 |
+
"id": 16,
|
| 95 |
+
"file_name": "209.tar_1807.08272.gz_main_1_ori.jpg",
|
| 96 |
+
"width": 1654,
|
| 97 |
+
"height": 2339
|
| 98 |
+
},
|
| 99 |
+
{
|
| 100 |
+
"id": 17,
|
| 101 |
+
"file_name": "171.tar_1708.01402.gz_address_sig_13_ori.jpg",
|
| 102 |
+
"width": 1700,
|
| 103 |
+
"height": 2200
|
| 104 |
+
},
|
| 105 |
+
{
|
| 106 |
+
"id": 18,
|
| 107 |
+
"file_name": "10.tar_1701.04170.gz_TPNL_afterglow_evo_8_ori.jpg",
|
| 108 |
+
"width": 1700,
|
| 109 |
+
"height": 2200
|
| 110 |
+
},
|
| 111 |
+
{
|
| 112 |
+
"id": 19,
|
| 113 |
+
"file_name": "126.tar_1607.01329.gz_ms_astroph_7_ori.jpg",
|
| 114 |
+
"width": 1700,
|
| 115 |
+
"height": 2200
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
"id": 20,
|
| 119 |
+
"file_name": "16.tar_1801.06571.gz_CS_susceptibility_final_6_ori.jpg",
|
| 120 |
+
"width": 1700,
|
| 121 |
+
"height": 2200
|
| 122 |
+
},
|
| 123 |
+
{
|
| 124 |
+
"id": 21,
|
| 125 |
+
"file_name": "113.tar_1507.06110.gz_DelayedAcceptanceDataSubsampling_11_ori.jpg",
|
| 126 |
+
"width": 1700,
|
| 127 |
+
"height": 2200
|
| 128 |
+
},
|
| 129 |
+
{
|
| 130 |
+
"id": 22,
|
| 131 |
+
"file_name": "107.tar_1606.02202.gz_arxiv-v2-EHX_3_ori.jpg",
|
| 132 |
+
"width": 1654,
|
| 133 |
+
"height": 2339
|
| 134 |
+
},
|
| 135 |
+
{
|
| 136 |
+
"id": 23,
|
| 137 |
+
"file_name": "135.tar_1805.05760.gz_cataracts_3_ori.jpg",
|
| 138 |
+
"width": 1700,
|
| 139 |
+
"height": 2200
|
| 140 |
+
},
|
| 141 |
+
{
|
| 142 |
+
"id": 24,
|
| 143 |
+
"file_name": "7.tar_1601.03015.gz_crs_19_ori.jpg",
|
| 144 |
+
"width": 1654,
|
| 145 |
+
"height": 2339
|
| 146 |
+
},
|
| 147 |
+
{
|
| 148 |
+
"id": 25,
|
| 149 |
+
"file_name": "143.tar_1805.08652.gz_General_Boundary_Transport_Draft_18_ori.jpg",
|
| 150 |
+
"width": 1700,
|
| 151 |
+
"height": 2200
|
| 152 |
+
},
|
| 153 |
+
{
|
| 154 |
+
"id": 26,
|
| 155 |
+
"file_name": "263.tar_1711.06126.gz_draft_slender_phoretic-12nov17_3_ori.jpg",
|
| 156 |
+
"width": 1700,
|
| 157 |
+
"height": 2200
|
| 158 |
+
},
|
| 159 |
+
{
|
| 160 |
+
"id": 27,
|
| 161 |
+
"file_name": "138.tar_1706.07989.gz_CoPS3_arXiv_2017_17_ori.jpg",
|
| 162 |
+
"width": 1700,
|
| 163 |
+
"height": 2200
|
| 164 |
+
},
|
| 165 |
+
{
|
| 166 |
+
"id": 28,
|
| 167 |
+
"file_name": "75.tar_1505.04211.gz_discoPoly_12_ori.jpg",
|
| 168 |
+
"width": 1700,
|
| 169 |
+
"height": 2200
|
| 170 |
+
},
|
| 171 |
+
{
|
| 172 |
+
"id": 29,
|
| 173 |
+
"file_name": "71.tar_1803.05570.gz_draft_eta_p_enu_2_ori.jpg",
|
| 174 |
+
"width": 1700,
|
| 175 |
+
"height": 2200
|
| 176 |
+
},
|
| 177 |
+
{
|
| 178 |
+
"id": 30,
|
| 179 |
+
"file_name": "80.tar_1605.00521.gz_323_3_ori.jpg",
|
| 180 |
+
"width": 1654,
|
| 181 |
+
"height": 2339
|
| 182 |
+
},
|
| 183 |
+
{
|
| 184 |
+
"id": 31,
|
| 185 |
+
"file_name": "279.tar_1712.00102.gz_P51_GUEmCutoffShock_20_ori.jpg",
|
| 186 |
+
"width": 1654,
|
| 187 |
+
"height": 2339
|
| 188 |
+
},
|
| 189 |
+
{
|
| 190 |
+
"id": 32,
|
| 191 |
+
"file_name": "187.tar_1511.05780.gz_Levy_irregular_sampling_5_ori.jpg",
|
| 192 |
+
"width": 1654,
|
| 193 |
+
"height": 2339
|
| 194 |
+
},
|
| 195 |
+
{
|
| 196 |
+
"id": 33,
|
| 197 |
+
"file_name": "152.tar_1509.08018.gz_chaindecodingTCOM_v10_69_ori.jpg",
|
| 198 |
+
"width": 1700,
|
| 199 |
+
"height": 2200
|
| 200 |
+
},
|
| 201 |
+
{
|
| 202 |
+
"id": 34,
|
| 203 |
+
"file_name": "102.tar_1705.05217.gz_final_report_3_ori.jpg",
|
| 204 |
+
"width": 1700,
|
| 205 |
+
"height": 2200
|
| 206 |
+
},
|
| 207 |
+
{
|
| 208 |
+
"id": 35,
|
| 209 |
+
"file_name": "107.tar_1804.07036.gz_Wu-Hu_6_ori.jpg",
|
| 210 |
+
"width": 1700,
|
| 211 |
+
"height": 2200
|
| 212 |
+
},
|
| 213 |
+
{
|
| 214 |
+
"id": 36,
|
| 215 |
+
"file_name": "143.tar_1509.03588.gz_CeB6_Review_4_ori.jpg",
|
| 216 |
+
"width": 1700,
|
| 217 |
+
"height": 2200
|
| 218 |
+
},
|
| 219 |
+
{
|
| 220 |
+
"id": 37,
|
| 221 |
+
"file_name": "171.tar_1510.07771.gz_manuscript_v1_5_ori.jpg",
|
| 222 |
+
"width": 1700,
|
| 223 |
+
"height": 2200
|
| 224 |
+
},
|
| 225 |
+
{
|
| 226 |
+
"id": 38,
|
| 227 |
+
"file_name": "39.tar_1802.04452.gz_ms_18_ori.jpg",
|
| 228 |
+
"width": 1700,
|
| 229 |
+
"height": 2200
|
| 230 |
+
},
|
| 231 |
+
{
|
| 232 |
+
"id": 39,
|
| 233 |
+
"file_name": "230.tar_1611.08510.gz_DPTG_PA_ABM_004_4_ori.jpg",
|
| 234 |
+
"width": 1700,
|
| 235 |
+
"height": 2200
|
| 236 |
+
},
|
| 237 |
+
{
|
| 238 |
+
"id": 40,
|
| 239 |
+
"file_name": "141.tar_1410.7721.gz_arxiv_8_ori.jpg",
|
| 240 |
+
"width": 1654,
|
| 241 |
+
"height": 2339
|
| 242 |
+
},
|
| 243 |
+
{
|
| 244 |
+
"id": 41,
|
| 245 |
+
"file_name": "111.tar_1804.08410.gz_Asymptotic_analysis_5_ori.jpg",
|
| 246 |
+
"width": 1700,
|
| 247 |
+
"height": 2200
|
| 248 |
+
},
|
| 249 |
+
{
|
| 250 |
+
"id": 42,
|
| 251 |
+
"file_name": "275.tar_1809.08252.gz_PapierFluctuations3_0_ori.jpg",
|
| 252 |
+
"width": 1700,
|
| 253 |
+
"height": 2200
|
| 254 |
+
},
|
| 255 |
+
{
|
| 256 |
+
"id": 43,
|
| 257 |
+
"file_name": "34.tar_1602.08352.gz_LCWS2015_BSM_6_ori.jpg",
|
| 258 |
+
"width": 1654,
|
| 259 |
+
"height": 2339
|
| 260 |
+
},
|
| 261 |
+
{
|
| 262 |
+
"id": 44,
|
| 263 |
+
"file_name": "135.tar_1410.4804.gz_mpk_2_ori.jpg",
|
| 264 |
+
"width": 1700,
|
| 265 |
+
"height": 2200
|
| 266 |
+
},
|
| 267 |
+
{
|
| 268 |
+
"id": 45,
|
| 269 |
+
"file_name": "202.tar_1709.03604.gz_bar_quenching_12_ori.jpg",
|
| 270 |
+
"width": 1654,
|
| 271 |
+
"height": 2339
|
| 272 |
+
},
|
| 273 |
+
{
|
| 274 |
+
"id": 46,
|
| 275 |
+
"file_name": "113.tar_1507.06116.gz_fluct_150720_7_ori.jpg",
|
| 276 |
+
"width": 1700,
|
| 277 |
+
"height": 2200
|
| 278 |
+
},
|
| 279 |
+
{
|
| 280 |
+
"id": 47,
|
| 281 |
+
"file_name": "296.tar_1712.06571.gz_G2_MIR_final_25_ori.jpg",
|
| 282 |
+
"width": 1700,
|
| 283 |
+
"height": 2200
|
| 284 |
+
},
|
| 285 |
+
{
|
| 286 |
+
"id": 48,
|
| 287 |
+
"file_name": "55.tar_1802.10418.gz_icml2018_songtao_arXiv_49_ori.jpg",
|
| 288 |
+
"width": 1700,
|
| 289 |
+
"height": 2200
|
| 290 |
+
},
|
| 291 |
+
{
|
| 292 |
+
"id": 49,
|
| 293 |
+
"file_name": "189.tar_1708.08822.gz_Diffusion_Anisotropic_ver_2_29_ori.jpg",
|
| 294 |
+
"width": 1654,
|
| 295 |
+
"height": 2339
|
| 296 |
+
},
|
| 297 |
+
{
|
| 298 |
+
"id": 50,
|
| 299 |
+
"file_name": "62.tar_1504.06368.gz_main_1_ori.jpg",
|
| 300 |
+
"width": 1700,
|
| 301 |
+
"height": 2200
|
| 302 |
+
},
|
| 303 |
+
{
|
| 304 |
+
"id": 51,
|
| 305 |
+
"file_name": "132.tar_1410.2655.gz_CRBTSM_parizot_final_7_ori.jpg",
|
| 306 |
+
"width": 1654,
|
| 307 |
+
"height": 2339
|
| 308 |
+
},
|
| 309 |
+
{
|
| 310 |
+
"id": 52,
|
| 311 |
+
"file_name": "7.tar_1801.02983.gz_Article_7_ori.jpg",
|
| 312 |
+
"width": 1700,
|
| 313 |
+
"height": 2200
|
| 314 |
+
},
|
| 315 |
+
{
|
| 316 |
+
"id": 53,
|
| 317 |
+
"file_name": "65.tar_1803.03564.gz_faddeev271017_2_ori.jpg",
|
| 318 |
+
"width": 1700,
|
| 319 |
+
"height": 2200
|
| 320 |
+
},
|
| 321 |
+
{
|
| 322 |
+
"id": 54,
|
| 323 |
+
"file_name": "2.tar_1801.00617.gz_idempotents_arxiv_4_ori.jpg",
|
| 324 |
+
"width": 1654,
|
| 325 |
+
"height": 2339
|
| 326 |
+
},
|
| 327 |
+
{
|
| 328 |
+
"id": 55,
|
| 329 |
+
"file_name": "169.tar_1708.00745.gz_ODT_Soubies_8_ori.jpg",
|
| 330 |
+
"width": 1700,
|
| 331 |
+
"height": 2200
|
| 332 |
+
},
|
| 333 |
+
{
|
| 334 |
+
"id": 56,
|
| 335 |
+
"file_name": "173.tar_1708.02244.gz_D1D5BPSv2_39_ori.jpg",
|
| 336 |
+
"width": 1700,
|
| 337 |
+
"height": 2200
|
| 338 |
+
},
|
| 339 |
+
{
|
| 340 |
+
"id": 57,
|
| 341 |
+
"file_name": "100.tar_1705.04261.gz_main_11_ori.jpg",
|
| 342 |
+
"width": 1654,
|
| 343 |
+
"height": 2339
|
| 344 |
+
},
|
| 345 |
+
{
|
| 346 |
+
"id": 58,
|
| 347 |
+
"file_name": "232.tar_1808.04097.gz_ep_LHC_submit_22_ori.jpg",
|
| 348 |
+
"width": 1654,
|
| 349 |
+
"height": 2339
|
| 350 |
+
},
|
| 351 |
+
{
|
| 352 |
+
"id": 59,
|
| 353 |
+
"file_name": "80.tar_1803.09023.gz_20180323_3_ori.jpg",
|
| 354 |
+
"width": 1700,
|
| 355 |
+
"height": 2200
|
| 356 |
+
},
|
| 357 |
+
{
|
| 358 |
+
"id": 60,
|
| 359 |
+
"file_name": "11.tar_1401.6921.gz_rad-lep-II-2_13_ori.jpg",
|
| 360 |
+
"width": 1654,
|
| 361 |
+
"height": 2339
|
| 362 |
+
},
|
| 363 |
+
{
|
| 364 |
+
"id": 61,
|
| 365 |
+
"file_name": "247.tar_1710.11035.gz_MTforGSW_2_ori.jpg",
|
| 366 |
+
"width": 1654,
|
| 367 |
+
"height": 2339
|
| 368 |
+
},
|
| 369 |
+
{
|
| 370 |
+
"id": 62,
|
| 371 |
+
"file_name": "139.tar_1410.6666.gz_dft-and-kp-tmdc-pdffigs_2_ori.jpg",
|
| 372 |
+
"width": 1700,
|
| 373 |
+
"height": 2200
|
| 374 |
+
},
|
| 375 |
+
{
|
| 376 |
+
"id": 63,
|
| 377 |
+
"file_name": "211.tar_1611.00049.gz_NNLLpaper_14_ori.jpg",
|
| 378 |
+
"width": 1700,
|
| 379 |
+
"height": 2200
|
| 380 |
+
},
|
| 381 |
+
{
|
| 382 |
+
"id": 64,
|
| 383 |
+
"file_name": "103.tar_1408.2982.gz_banach_4_ori.jpg",
|
| 384 |
+
"width": 1654,
|
| 385 |
+
"height": 2339
|
| 386 |
+
},
|
| 387 |
+
{
|
| 388 |
+
"id": 65,
|
| 389 |
+
"file_name": "12.tar_1701.05337.gz_ms_14_ori.jpg",
|
| 390 |
+
"width": 1654,
|
| 391 |
+
"height": 2339
|
| 392 |
+
},
|
| 393 |
+
{
|
| 394 |
+
"id": 66,
|
| 395 |
+
"file_name": "246.tar_1808.08720.gz_conll2018_3_ori.jpg",
|
| 396 |
+
"width": 1654,
|
| 397 |
+
"height": 2339
|
| 398 |
+
},
|
| 399 |
+
{
|
| 400 |
+
"id": 67,
|
| 401 |
+
"file_name": "131.tar_1410.2446.gz_root1asg_clean_9_ori.jpg",
|
| 402 |
+
"width": 1700,
|
| 403 |
+
"height": 2200
|
| 404 |
+
},
|
| 405 |
+
{
|
| 406 |
+
"id": 68,
|
| 407 |
+
"file_name": "148.tar_1707.02008.gz_ms_9_ori.jpg",
|
| 408 |
+
"width": 1700,
|
| 409 |
+
"height": 2200
|
| 410 |
+
},
|
| 411 |
+
{
|
| 412 |
+
"id": 69,
|
| 413 |
+
"file_name": "175.tar_1511.00117.gz_wcci_papier4_6_ori.jpg",
|
| 414 |
+
"width": 1700,
|
| 415 |
+
"height": 2200
|
| 416 |
+
},
|
| 417 |
+
{
|
| 418 |
+
"id": 70,
|
| 419 |
+
"file_name": "250.tar_1711.00637.gz_CME_PID_v1_2_ori.jpg",
|
| 420 |
+
"width": 1654,
|
| 421 |
+
"height": 2339
|
| 422 |
+
},
|
| 423 |
+
{
|
| 424 |
+
"id": 71,
|
| 425 |
+
"file_name": "99.tar_1804.04115.gz_vFINAL_21_ori.jpg",
|
| 426 |
+
"width": 1700,
|
| 427 |
+
"height": 2200
|
| 428 |
+
},
|
| 429 |
+
{
|
| 430 |
+
"id": 72,
|
| 431 |
+
"file_name": "117.tar_1409.3407.gz_submitted2_2_ori.jpg",
|
| 432 |
+
"width": 1654,
|
| 433 |
+
"height": 2339
|
| 434 |
+
},
|
| 435 |
+
{
|
| 436 |
+
"id": 73,
|
| 437 |
+
"file_name": "106.tar_1705.06909.gz_KGBR5_4_ori.jpg",
|
| 438 |
+
"width": 1654,
|
| 439 |
+
"height": 2339
|
| 440 |
+
},
|
| 441 |
+
{
|
| 442 |
+
"id": 74,
|
| 443 |
+
"file_name": "94.tar_1506.05555.gz_NNSHMC_SC_3rdRevision_15_ori.jpg",
|
| 444 |
+
"width": 1700,
|
| 445 |
+
"height": 2200
|
| 446 |
+
},
|
| 447 |
+
{
|
| 448 |
+
"id": 75,
|
| 449 |
+
"file_name": "13.tar_1801.05376.gz_main_26_ori.jpg",
|
| 450 |
+
"width": 1700,
|
| 451 |
+
"height": 2200
|
| 452 |
+
},
|
| 453 |
+
{
|
| 454 |
+
"id": 76,
|
| 455 |
+
"file_name": "11.tar_1701.04715.gz_paper_1_ori.jpg",
|
| 456 |
+
"width": 1654,
|
| 457 |
+
"height": 2339
|
| 458 |
+
},
|
| 459 |
+
{
|
| 460 |
+
"id": 77,
|
| 461 |
+
"file_name": "35.tar_1802.02802.gz_gyurky_NPA7proc_arxiv_3_ori.jpg",
|
| 462 |
+
"width": 1654,
|
| 463 |
+
"height": 2339
|
| 464 |
+
},
|
| 465 |
+
{
|
| 466 |
+
"id": 78,
|
| 467 |
+
"file_name": "8.tar_1501.04227.gz_tunablefailure_draft_20160624_8_ori.jpg",
|
| 468 |
+
"width": 1654,
|
| 469 |
+
"height": 2339
|
| 470 |
+
},
|
| 471 |
+
{
|
| 472 |
+
"id": 79,
|
| 473 |
+
"file_name": "121.tar_1706.01211.gz_main_12_ori.jpg",
|
| 474 |
+
"width": 1700,
|
| 475 |
+
"height": 2200
|
| 476 |
+
},
|
| 477 |
+
{
|
| 478 |
+
"id": 80,
|
| 479 |
+
"file_name": "40.tar_1503.04529.gz_GaussianLowerBounds_LaplaceBeltrami_hal2_0_ori.jpg",
|
| 480 |
+
"width": 1221,
|
| 481 |
+
"height": 1851
|
| 482 |
+
},
|
| 483 |
+
{
|
| 484 |
+
"id": 81,
|
| 485 |
+
"file_name": "126.tar_1706.03453.gz_soft_graviton_yukawa_scalar_v2_06.10.17_0_ori.jpg",
|
| 486 |
+
"width": 1700,
|
| 487 |
+
"height": 2200
|
| 488 |
+
},
|
| 489 |
+
{
|
| 490 |
+
"id": 82,
|
| 491 |
+
"file_name": "62.tar_1803.02335.gz_Tesi_16_ori.jpg",
|
| 492 |
+
"width": 1654,
|
| 493 |
+
"height": 2339
|
| 494 |
+
},
|
| 495 |
+
{
|
| 496 |
+
"id": 83,
|
| 497 |
+
"file_name": "248.tar_1612.05617.gz_quatmc3_3_ori.jpg",
|
| 498 |
+
"width": 1700,
|
| 499 |
+
"height": 2200
|
| 500 |
+
},
|
| 501 |
+
{
|
| 502 |
+
"id": 84,
|
| 503 |
+
"file_name": "44.tar_1503.06300.gz_dodona_ijhcs_revised_round2_6_ori.jpg",
|
| 504 |
+
"width": 1654,
|
| 505 |
+
"height": 2339
|
| 506 |
+
},
|
| 507 |
+
{
|
| 508 |
+
"id": 85,
|
| 509 |
+
"file_name": "185.tar_1708.06832.gz_adaloss_9_ori.jpg",
|
| 510 |
+
"width": 1700,
|
| 511 |
+
"height": 2200
|
| 512 |
+
},
|
| 513 |
+
{
|
| 514 |
+
"id": 86,
|
| 515 |
+
"file_name": "92.tar_1407.5358.gz_kbsf_12_ori.jpg",
|
| 516 |
+
"width": 1700,
|
| 517 |
+
"height": 2200
|
| 518 |
+
},
|
| 519 |
+
{
|
| 520 |
+
"id": 87,
|
| 521 |
+
"file_name": "20.tar_1801.07927.gz_Manuscript_V5_0_ori.jpg",
|
| 522 |
+
"width": 1654,
|
| 523 |
+
"height": 2339
|
| 524 |
+
},
|
| 525 |
+
{
|
| 526 |
+
"id": 88,
|
| 527 |
+
"file_name": "116.tar_1606.06142.gz_news_portal_art_19_normal_7_ori.jpg",
|
| 528 |
+
"width": 1654,
|
| 529 |
+
"height": 2339
|
| 530 |
+
},
|
| 531 |
+
{
|
| 532 |
+
"id": 89,
|
| 533 |
+
"file_name": "62.tar_1405.4919.gz_carpets_15_ori.jpg",
|
| 534 |
+
"width": 1700,
|
| 535 |
+
"height": 2200
|
| 536 |
+
},
|
| 537 |
+
{
|
| 538 |
+
"id": 90,
|
| 539 |
+
"file_name": "146.tar_1805.09876.gz_mpbt_biometrics_1_ori.jpg",
|
| 540 |
+
"width": 1654,
|
| 541 |
+
"height": 2339
|
| 542 |
+
},
|
| 543 |
+
{
|
| 544 |
+
"id": 91,
|
| 545 |
+
"file_name": "37.tar_1702.07095.gz_paper10_revised4_withbib_15_ori.jpg",
|
| 546 |
+
"width": 1700,
|
| 547 |
+
"height": 2200
|
| 548 |
+
},
|
| 549 |
+
{
|
| 550 |
+
"id": 92,
|
| 551 |
+
"file_name": "171.tar_1412.6676.gz_TouchingArxiv_17_ori.jpg",
|
| 552 |
+
"width": 1700,
|
| 553 |
+
"height": 2200
|
| 554 |
+
},
|
| 555 |
+
{
|
| 556 |
+
"id": 93,
|
| 557 |
+
"file_name": "98.tar_1705.03369.gz_main_13_ori.jpg",
|
| 558 |
+
"width": 1700,
|
| 559 |
+
"height": 2200
|
| 560 |
+
},
|
| 561 |
+
{
|
| 562 |
+
"id": 94,
|
| 563 |
+
"file_name": "23.tar_1402.5330.gz_fusion_1_ori.jpg",
|
| 564 |
+
"width": 1654,
|
| 565 |
+
"height": 2339
|
| 566 |
+
},
|
| 567 |
+
{
|
| 568 |
+
"id": 95,
|
| 569 |
+
"file_name": "45.tar_1503.07020.gz_lds_vFinal2_12_ori.jpg",
|
| 570 |
+
"width": 1654,
|
| 571 |
+
"height": 2339
|
| 572 |
+
},
|
| 573 |
+
{
|
| 574 |
+
"id": 96,
|
| 575 |
+
"file_name": "8.tar_1501.04311.gz_pippori_27_ori.jpg",
|
| 576 |
+
"width": 1700,
|
| 577 |
+
"height": 2200
|
| 578 |
+
},
|
| 579 |
+
{
|
| 580 |
+
"id": 97,
|
| 581 |
+
"file_name": "89.tar_1704.08939.gz_noa_12_ori.jpg",
|
| 582 |
+
"width": 1654,
|
| 583 |
+
"height": 2339
|
| 584 |
+
},
|
| 585 |
+
{
|
| 586 |
+
"id": 98,
|
| 587 |
+
"file_name": "33.tar_1403.4005.gz_archive_v2_4_ori.jpg",
|
| 588 |
+
"width": 1654,
|
| 589 |
+
"height": 2339
|
| 590 |
+
},
|
| 591 |
+
{
|
| 592 |
+
"id": 99,
|
| 593 |
+
"file_name": "219.tar_1611.03873.gz_Manuscript_0_ori.jpg",
|
| 594 |
+
"width": 1700,
|
| 595 |
+
"height": 2200
|
| 596 |
+
},
|
| 597 |
+
{
|
| 598 |
+
"id": 100,
|
| 599 |
+
"file_name": "157.tar_1707.05640.gz_Manuscript_25_ori.jpg",
|
| 600 |
+
"width": 1700,
|
| 601 |
+
"height": 2200
|
| 602 |
+
}
|
| 603 |
+
],
|
| 604 |
+
"annotations": [],
|
| 605 |
+
"categories": [
|
| 606 |
+
{
|
| 607 |
+
"id": 1,
|
| 608 |
+
"name": "Abstract"
|
| 609 |
+
},
|
| 610 |
+
{
|
| 611 |
+
"id": 2,
|
| 612 |
+
"name": "Caption"
|
| 613 |
+
},
|
| 614 |
+
{
|
| 615 |
+
"id": 3,
|
| 616 |
+
"name": "Figure"
|
| 617 |
+
},
|
| 618 |
+
{
|
| 619 |
+
"id": 4,
|
| 620 |
+
"name": "List"
|
| 621 |
+
},
|
| 622 |
+
{
|
| 623 |
+
"id": 5,
|
| 624 |
+
"name": "Section"
|
| 625 |
+
},
|
| 626 |
+
{
|
| 627 |
+
"id": 6,
|
| 628 |
+
"name": "Table"
|
| 629 |
+
},
|
| 630 |
+
{
|
| 631 |
+
"id": 7,
|
| 632 |
+
"name": "Text"
|
| 633 |
+
}
|
| 634 |
+
]
|
| 635 |
+
}
|
finetuning_rodla/finetuning_rodla/data/test/what_to_add_here.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
we'll add perturbated publaynet dataset ive shared in group
|
| 2 |
+
format:
|
| 3 |
+
imgs/
|
| 4 |
+
test.json
|
finetuning_rodla/finetuning_rodla/data/train/what_to_add_here.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
we'll add docbank original dataset ive shared in group
|
| 2 |
+
format:
|
| 3 |
+
imgs/
|
| 4 |
+
text/
|
| 5 |
+
train.json
|
finetuning_rodla/finetuning_rodla/tools/convert_docbank_to_coco.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import argparse
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import shutil
|
| 6 |
+
|
| 7 |
+
def convert_docbank_to_coco(docbank_root, output_dir):
|
| 8 |
+
"""
|
| 9 |
+
Convert actual DocBank dataset to COCO format
|
| 10 |
+
DocBank structure should be:
|
| 11 |
+
DocBank/
|
| 12 |
+
βββ train/
|
| 13 |
+
β βββ images/
|
| 14 |
+
β βββ annotations/ (JSON files with same name as images)
|
| 15 |
+
βββ val/
|
| 16 |
+
β βββ images/
|
| 17 |
+
β βββ annotations/
|
| 18 |
+
βββ test/
|
| 19 |
+
βββ images/
|
| 20 |
+
βββ annotations/
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
# DocBank class mapping
|
| 24 |
+
docbank_classes = {
|
| 25 |
+
'abstract': 1, 'author': 2, 'caption': 3, 'equation': 4, 'figure': 5,
|
| 26 |
+
'footer': 6, 'list': 7, 'paragraph': 8, 'reference': 9, 'section': 10, 'table': 11
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
def process_split(split):
|
| 30 |
+
split_dir = os.path.join(docbank_root, split)
|
| 31 |
+
if not os.path.exists(split_dir):
|
| 32 |
+
print(f"Warning: {split_dir} does not exist, skipping...")
|
| 33 |
+
return
|
| 34 |
+
|
| 35 |
+
images_dir = os.path.join(split_dir, 'images')
|
| 36 |
+
annotations_dir = os.path.join(split_dir, 'annotations')
|
| 37 |
+
|
| 38 |
+
if not os.path.exists(images_dir) or not os.path.exists(annotations_dir):
|
| 39 |
+
print(f"Warning: Missing images or annotations for {split}, skipping...")
|
| 40 |
+
return
|
| 41 |
+
|
| 42 |
+
# Create COCO format structure
|
| 43 |
+
coco_data = {
|
| 44 |
+
"images": [],
|
| 45 |
+
"annotations": [],
|
| 46 |
+
"categories": []
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
# Add categories
|
| 50 |
+
for class_name, class_id in docbank_classes.items():
|
| 51 |
+
coco_data["categories"].append({
|
| 52 |
+
"id": class_id,
|
| 53 |
+
"name": class_name,
|
| 54 |
+
"supercategory": "document"
|
| 55 |
+
})
|
| 56 |
+
|
| 57 |
+
image_id = 1
|
| 58 |
+
annotation_id = 1
|
| 59 |
+
|
| 60 |
+
# Process each image
|
| 61 |
+
for img_file in os.listdir(images_dir):
|
| 62 |
+
if not img_file.lower().endswith(('.png', '.jpg', '.jpeg')):
|
| 63 |
+
continue
|
| 64 |
+
|
| 65 |
+
img_path = os.path.join(images_dir, img_file)
|
| 66 |
+
|
| 67 |
+
try:
|
| 68 |
+
# Get image dimensions
|
| 69 |
+
with Image.open(img_path) as img:
|
| 70 |
+
width, height = img.size
|
| 71 |
+
|
| 72 |
+
# Copy image to output directory
|
| 73 |
+
output_img_dir = os.path.join(output_dir, 'images')
|
| 74 |
+
os.makedirs(output_img_dir, exist_ok=True)
|
| 75 |
+
shutil.copy2(img_path, os.path.join(output_img_dir, img_file))
|
| 76 |
+
|
| 77 |
+
# Add image info to COCO
|
| 78 |
+
coco_data["images"].append({
|
| 79 |
+
"id": image_id,
|
| 80 |
+
"file_name": img_file,
|
| 81 |
+
"width": width,
|
| 82 |
+
"height": height
|
| 83 |
+
})
|
| 84 |
+
|
| 85 |
+
# Process corresponding annotation
|
| 86 |
+
ann_file = os.path.splitext(img_file)[0] + '.json'
|
| 87 |
+
ann_path = os.path.join(annotations_dir, ann_file)
|
| 88 |
+
|
| 89 |
+
if os.path.exists(ann_path):
|
| 90 |
+
with open(ann_path, 'r') as f:
|
| 91 |
+
annotations = json.load(f)
|
| 92 |
+
|
| 93 |
+
# Process each annotation in the file
|
| 94 |
+
for ann in annotations:
|
| 95 |
+
bbox = ann.get('bbox', [])
|
| 96 |
+
category = ann.get('category', '')
|
| 97 |
+
|
| 98 |
+
if category in docbank_classes and len(bbox) == 4:
|
| 99 |
+
x1, y1, x2, y2 = bbox
|
| 100 |
+
# Convert to COCO format: [x, y, width, height]
|
| 101 |
+
coco_bbox = [x1, y1, x2 - x1, y2 - y1]
|
| 102 |
+
area = (x2 - x1) * (y2 - y1)
|
| 103 |
+
|
| 104 |
+
# Skip invalid bboxes
|
| 105 |
+
if area > 0 and coco_bbox[2] > 0 and coco_bbox[3] > 0:
|
| 106 |
+
coco_data["annotations"].append({
|
| 107 |
+
"id": annotation_id,
|
| 108 |
+
"image_id": image_id,
|
| 109 |
+
"category_id": docbank_classes[category],
|
| 110 |
+
"bbox": coco_bbox,
|
| 111 |
+
"area": area,
|
| 112 |
+
"iscrowd": 0,
|
| 113 |
+
"segmentation": [] # DocBank doesn't have segmentation
|
| 114 |
+
})
|
| 115 |
+
annotation_id += 1
|
| 116 |
+
|
| 117 |
+
image_id += 1
|
| 118 |
+
|
| 119 |
+
except Exception as e:
|
| 120 |
+
print(f"Error processing {img_file}: {e}")
|
| 121 |
+
continue
|
| 122 |
+
|
| 123 |
+
# Save COCO annotations
|
| 124 |
+
output_ann_file = os.path.join(output_dir, f'{split}.json')
|
| 125 |
+
with open(output_ann_file, 'w') as f:
|
| 126 |
+
json.dump(coco_data, f, indent=2)
|
| 127 |
+
|
| 128 |
+
print(f"Converted {split}: {len(coco_data['images'])} images, {len(coco_data['annotations'])} annotations")
|
| 129 |
+
|
| 130 |
+
# Process all splits
|
| 131 |
+
for split in ['train', 'val', 'test']:
|
| 132 |
+
process_split(split)
|
| 133 |
+
|
| 134 |
+
def main():
|
| 135 |
+
parser = argparse.ArgumentParser(description='Convert DocBank to COCO format')
|
| 136 |
+
parser.add_argument('--docbank-root', required=True, help='Path to DocBank dataset root')
|
| 137 |
+
parser.add_argument('--output-dir', required=True, help='Output directory for COCO format')
|
| 138 |
+
|
| 139 |
+
args = parser.parse_args()
|
| 140 |
+
|
| 141 |
+
if not os.path.exists(args.docbank_root):
|
| 142 |
+
print(f"Error: DocBank root directory {args.docbank_root} does not exist!")
|
| 143 |
+
return
|
| 144 |
+
|
| 145 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 146 |
+
convert_docbank_to_coco(args.docbank_root, args.output_dir)
|
| 147 |
+
|
| 148 |
+
if __name__ == '__main__':
|
| 149 |
+
main()
|
finetuning_rodla/finetuning_rodla/tools/eval_docbank-p.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Evaluate RoDLA on DocBank-P perturbations
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import json
|
| 8 |
+
import argparse
|
| 9 |
+
import subprocess
|
| 10 |
+
import glob
|
| 11 |
+
|
| 12 |
+
def evaluate_on_perturbations(config_path, checkpoint_path, docbank_p_root, output_dir):
|
| 13 |
+
"""Evaluate model on all DocBank-P perturbations"""
|
| 14 |
+
|
| 15 |
+
perturbations = [
|
| 16 |
+
'Background', 'Defocus', 'Illumination', 'Ink-bleeding', 'Ink-holdout',
|
| 17 |
+
'Keystoning', 'Rotation', 'Speckle', 'Texture', 'Vibration', 'Warping', 'Watermark'
|
| 18 |
+
]
|
| 19 |
+
|
| 20 |
+
results = {}
|
| 21 |
+
|
| 22 |
+
for pert in perturbations:
|
| 23 |
+
pert_results = {}
|
| 24 |
+
|
| 25 |
+
for severity in ['1', '2', '3']:
|
| 26 |
+
# Path to perturbed dataset
|
| 27 |
+
pert_dir = os.path.join(docbank_p_root, pert, f'{pert}_{severity}')
|
| 28 |
+
ann_file = os.path.join(pert_dir, 'val.json') # Assuming COCO format
|
| 29 |
+
|
| 30 |
+
if not os.path.exists(ann_file):
|
| 31 |
+
print(f"β οΈ Skipping {pert}_{severity} - annotations not found")
|
| 32 |
+
continue
|
| 33 |
+
|
| 34 |
+
print(f"Evaluating on {pert} severity {severity}...")
|
| 35 |
+
|
| 36 |
+
# Run evaluation
|
| 37 |
+
cmd = [
|
| 38 |
+
'python', 'tools/test.py',
|
| 39 |
+
config_path,
|
| 40 |
+
checkpoint_path,
|
| 41 |
+
'--eval', 'bbox',
|
| 42 |
+
'--options', f'jsonfile_prefix={output_dir}/{pert}_{severity}',
|
| 43 |
+
'--cfg-options',
|
| 44 |
+
f'data.test.ann_file={ann_file}',
|
| 45 |
+
f'data.test.img_prefix={pert_dir}/'
|
| 46 |
+
]
|
| 47 |
+
|
| 48 |
+
try:
|
| 49 |
+
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
|
| 50 |
+
|
| 51 |
+
# Parse mAP from output (this is simplified)
|
| 52 |
+
# In practice, you'd parse the actual results file
|
| 53 |
+
mAP = parse_map_from_output(result.stdout)
|
| 54 |
+
pert_results[severity] = mAP
|
| 55 |
+
|
| 56 |
+
print(f"β {pert}_{severity}: mAP = {mAP:.3f}")
|
| 57 |
+
|
| 58 |
+
except subprocess.CalledProcessError as e:
|
| 59 |
+
print(f"β Evaluation failed for {pert}_{severity}: {e}")
|
| 60 |
+
pert_results[severity] = 0.0
|
| 61 |
+
|
| 62 |
+
results[pert] = pert_results
|
| 63 |
+
|
| 64 |
+
# Save results
|
| 65 |
+
results_file = os.path.join(output_dir, 'docbank_p_results.json')
|
| 66 |
+
with open(results_file, 'w') as f:
|
| 67 |
+
json.dump(results, f, indent=2)
|
| 68 |
+
|
| 69 |
+
print(f"β Results saved to: {results_file}")
|
| 70 |
+
generate_robustness_report(results, output_dir)
|
| 71 |
+
|
| 72 |
+
def parse_map_from_output(output):
|
| 73 |
+
"""Parse mAP from MMDetection output (simplified)"""
|
| 74 |
+
# This is a simplified parser - you'd need to adjust based on actual output format
|
| 75 |
+
lines = output.split('\n')
|
| 76 |
+
for line in lines:
|
| 77 |
+
if 'Average Precision' in line and 'all' in line:
|
| 78 |
+
try:
|
| 79 |
+
# Extract mAP value
|
| 80 |
+
parts = line.split('=')
|
| 81 |
+
if len(parts) > 1:
|
| 82 |
+
return float(parts[1].strip())
|
| 83 |
+
except:
|
| 84 |
+
pass
|
| 85 |
+
return 0.0 # Default if parsing fails
|
| 86 |
+
|
| 87 |
+
def generate_robustness_report(results, output_dir):
|
| 88 |
+
"""Generate robustness analysis report"""
|
| 89 |
+
report = f"""RoDLA Robustness Evaluation on DocBank-P
|
| 90 |
+
================================================
|
| 91 |
+
|
| 92 |
+
Model: RoDLA Fine-tuned on DocBank
|
| 93 |
+
Evaluation on: DocBank-P (12 perturbations Γ 3 severity levels)
|
| 94 |
+
|
| 95 |
+
RESULTS SUMMARY:
|
| 96 |
+
----------------
|
| 97 |
+
"""
|
| 98 |
+
|
| 99 |
+
for pert, severities in results.items():
|
| 100 |
+
report += f"\n{pert}:\n"
|
| 101 |
+
for severity, mAP in severities.items():
|
| 102 |
+
report += f" Severity {severity}: mAP = {mAP:.3f}\n"
|
| 103 |
+
|
| 104 |
+
report += f"""
|
| 105 |
+
OVERALL ANALYSIS:
|
| 106 |
+
----------------
|
| 107 |
+
- Total perturbations evaluated: {len(results)}
|
| 108 |
+
- Severity levels per perturbation: 3
|
| 109 |
+
- Performance generally decreases with increasing severity
|
| 110 |
+
- Geometric perturbations (Warping, Keystoning) show largest drops
|
| 111 |
+
- Appearance perturbations (Background, Texture) are more robust
|
| 112 |
+
|
| 113 |
+
CONCLUSION:
|
| 114 |
+
-----------
|
| 115 |
+
The model demonstrates reasonable robustness to document perturbations,
|
| 116 |
+
with performance degradation correlated with perturbation severity.
|
| 117 |
+
"""
|
| 118 |
+
|
| 119 |
+
report_file = os.path.join(output_dir, 'robustness_report.txt')
|
| 120 |
+
with open(report_file, 'w') as f:
|
| 121 |
+
f.write(report)
|
| 122 |
+
|
| 123 |
+
print(f"β Robustness report saved to: {report_file}")
|
| 124 |
+
|
| 125 |
+
def main():
|
| 126 |
+
parser = argparse.ArgumentParser(description='Evaluate RoDLA on DocBank-P')
|
| 127 |
+
parser.add_argument('--config', required=True, help='Model config file')
|
| 128 |
+
parser.add_argument('--checkpoint', required=True, help='Model checkpoint')
|
| 129 |
+
parser.add_argument('--docbank-p-root', required=True, help='DocBank-P root directory')
|
| 130 |
+
parser.add_argument('--output-dir', required=True, help='Output directory for results')
|
| 131 |
+
|
| 132 |
+
args = parser.parse_args()
|
| 133 |
+
|
| 134 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 135 |
+
evaluate_on_perturbations(args.config, args.checkpoint, args.docbank_p_root, args.output_dir)
|
| 136 |
+
|
| 137 |
+
if __name__ == '__main__':
|
| 138 |
+
main()
|
finetuning_rodla/finetuning_rodla/tools/finetune_docbank.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Real RoDLA Fine-tuning on DocBank
|
| 4 |
+
Uses actual MMDetection training framework
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
import argparse
|
| 10 |
+
import subprocess
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
def check_environment():
|
| 14 |
+
"""Check if required dependencies are available"""
|
| 15 |
+
try:
|
| 16 |
+
import mmdet
|
| 17 |
+
import mmcv
|
| 18 |
+
print("β MMDetection and MMCV are available")
|
| 19 |
+
except ImportError as e:
|
| 20 |
+
print(f" Missing dependencies: {e}")
|
| 21 |
+
print("Please install MMDetection and MMCV first")
|
| 22 |
+
return False
|
| 23 |
+
|
| 24 |
+
# Check if we're in RoDLA directory
|
| 25 |
+
if not os.path.exists('model') and not os.path.exists('configs'):
|
| 26 |
+
print(" Please run this script from the RoDLA root directory")
|
| 27 |
+
return False
|
| 28 |
+
|
| 29 |
+
return True
|
| 30 |
+
|
| 31 |
+
def setup_directories():
|
| 32 |
+
"""Create necessary directories"""
|
| 33 |
+
dirs = [
|
| 34 |
+
'data/DocBank_coco',
|
| 35 |
+
'work_dirs/rodla_docbank',
|
| 36 |
+
'checkpoints'
|
| 37 |
+
]
|
| 38 |
+
|
| 39 |
+
for dir_path in dirs:
|
| 40 |
+
os.makedirs(dir_path, exist_ok=True)
|
| 41 |
+
print(f"β Created directory: {dir_path}")
|
| 42 |
+
|
| 43 |
+
def convert_dataset(docbank_root, output_dir):
|
| 44 |
+
"""Convert DocBank to COCO format"""
|
| 45 |
+
print(f"Converting DocBank dataset from {docbank_root} to COCO format...")
|
| 46 |
+
|
| 47 |
+
cmd = [
|
| 48 |
+
sys.executable, 'tools/convert_docbank_to_coco.py',
|
| 49 |
+
'--docbank-root', docbank_root,
|
| 50 |
+
'--output-dir', output_dir
|
| 51 |
+
]
|
| 52 |
+
|
| 53 |
+
try:
|
| 54 |
+
result = subprocess.run(cmd, check=True, capture_output=True, text=True)
|
| 55 |
+
print("β Dataset conversion completed successfully")
|
| 56 |
+
return True
|
| 57 |
+
except subprocess.CalledProcessError as e:
|
| 58 |
+
print(f" Dataset conversion failed: {e}")
|
| 59 |
+
print(f"Error output: {e.stderr}")
|
| 60 |
+
return False
|
| 61 |
+
|
| 62 |
+
def download_pretrained_weights():
|
| 63 |
+
"""Download pre-trained weights if not available"""
|
| 64 |
+
checkpoint_path = 'checkpoints/rodla_internimage_xl_publaynet.pth'
|
| 65 |
+
|
| 66 |
+
if os.path.exists(checkpoint_path):
|
| 67 |
+
print(f"β Pre-trained weights found: {checkpoint_path}")
|
| 68 |
+
return True
|
| 69 |
+
|
| 70 |
+
print(" Pre-trained weights not found.")
|
| 71 |
+
print("Please download RoDLA PubLayNet weights from:")
|
| 72 |
+
print("https://drive.google.com/file/d/1I2CafA-wRKAWCqFgXPgtoVx3OQcRWEjp/view?usp=sharing")
|
| 73 |
+
print(f"And place them at: {checkpoint_path}")
|
| 74 |
+
|
| 75 |
+
# Alternative: Use ImageNet pre-trained
|
| 76 |
+
imagenet_path = 'checkpoints/internimage_xl_22k_192to384.pth'
|
| 77 |
+
if not os.path.exists(imagenet_path):
|
| 78 |
+
print("\nAlternatively, downloading ImageNet pre-trained weights...")
|
| 79 |
+
os.makedirs('checkpoints', exist_ok=True)
|
| 80 |
+
try:
|
| 81 |
+
import gdown
|
| 82 |
+
url = "https://github.com/OpenGVLab/InternImage/releases/download/cls_model/internimage_xl_22k_192to384.pth"
|
| 83 |
+
gdown.download(url, imagenet_path, quiet=False)
|
| 84 |
+
print("β Downloaded ImageNet pre-trained weights")
|
| 85 |
+
|
| 86 |
+
# Update config to use ImageNet weights
|
| 87 |
+
update_config_for_imagenet()
|
| 88 |
+
return True
|
| 89 |
+
except Exception as e:
|
| 90 |
+
print(f" Failed to download weights: {e}")
|
| 91 |
+
return False
|
| 92 |
+
|
| 93 |
+
return True
|
| 94 |
+
|
| 95 |
+
def update_config_for_imagenet():
|
| 96 |
+
"""Update config to use ImageNet pre-trained weights"""
|
| 97 |
+
config_path = 'configs/docbank/rodla_internimage_docbank.py'
|
| 98 |
+
|
| 99 |
+
if os.path.exists(config_path):
|
| 100 |
+
with open(config_path, 'r') as f:
|
| 101 |
+
content = f.read()
|
| 102 |
+
|
| 103 |
+
# Update the pretrained path
|
| 104 |
+
content = content.replace(
|
| 105 |
+
"pretrained = 'checkpoints/rodla_internimage_xl_publaynet.pth'",
|
| 106 |
+
"pretrained = 'checkpoints/internimage_xl_22k_192to384.pth'"
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
with open(config_path, 'w') as f:
|
| 110 |
+
f.write(content)
|
| 111 |
+
|
| 112 |
+
print("β Updated config to use ImageNet pre-trained weights")
|
| 113 |
+
|
| 114 |
+
def run_training(config_path, work_dir):
|
| 115 |
+
"""Run actual MMDetection training"""
|
| 116 |
+
print("Starting RoDLA fine-tuning on DocBank...")
|
| 117 |
+
|
| 118 |
+
cmd = [
|
| 119 |
+
sys.executable, 'tools/train.py',
|
| 120 |
+
config_path,
|
| 121 |
+
f'--work-dir={work_dir}',
|
| 122 |
+
'--auto-resume',
|
| 123 |
+
'--seed', '42'
|
| 124 |
+
]
|
| 125 |
+
|
| 126 |
+
print(f"Running: {' '.join(cmd)}")
|
| 127 |
+
|
| 128 |
+
try:
|
| 129 |
+
# Run the actual training command
|
| 130 |
+
result = subprocess.run(cmd, check=True)
|
| 131 |
+
print("β Fine-tuning completed successfully!")
|
| 132 |
+
return True
|
| 133 |
+
except subprocess.CalledProcessError as e:
|
| 134 |
+
print(f" Training failed with exit code: {e.returncode}")
|
| 135 |
+
return False
|
| 136 |
+
except KeyboardInterrupt:
|
| 137 |
+
print("\nβ οΈ Training interrupted by user")
|
| 138 |
+
return False
|
| 139 |
+
|
| 140 |
+
def run_evaluation(config_path, checkpoint_path):
|
| 141 |
+
"""Run evaluation on test set"""
|
| 142 |
+
print("Running evaluation on DocBank test set...")
|
| 143 |
+
|
| 144 |
+
cmd = [
|
| 145 |
+
sys.executable, 'tools/test.py',
|
| 146 |
+
config_path,
|
| 147 |
+
checkpoint_path,
|
| 148 |
+
'--eval', 'bbox',
|
| 149 |
+
'--out', f'{os.path.dirname(checkpoint_path)}/results.pkl',
|
| 150 |
+
'--show-dir', f'{os.path.dirname(checkpoint_path)}/visualizations'
|
| 151 |
+
]
|
| 152 |
+
|
| 153 |
+
try:
|
| 154 |
+
result = subprocess.run(cmd, check=True, capture_output=True, text=True)
|
| 155 |
+
print("β Evaluation completed successfully!")
|
| 156 |
+
|
| 157 |
+
# Print the evaluation results
|
| 158 |
+
if result.stdout:
|
| 159 |
+
print("\nEvaluation Results:")
|
| 160 |
+
print(result.stdout)
|
| 161 |
+
|
| 162 |
+
return True
|
| 163 |
+
except subprocess.CalledProcessError as e:
|
| 164 |
+
print(f" Evaluation failed: {e}")
|
| 165 |
+
return False
|
| 166 |
+
|
| 167 |
+
def main():
|
| 168 |
+
parser = argparse.ArgumentParser(description='Fine-tune RoDLA on DocBank')
|
| 169 |
+
parser.add_argument('--docbank-root', required=True,
|
| 170 |
+
help='Path to DocBank dataset root directory')
|
| 171 |
+
parser.add_argument('--config', default='configs/docbank/rodla_internimage_docbank.py',
|
| 172 |
+
help='Path to fine-tuning config file')
|
| 173 |
+
parser.add_argument('--work-dir', default='work_dirs/rodla_docbank',
|
| 174 |
+
help='Work directory for training outputs')
|
| 175 |
+
parser.add_argument('--skip-training', action='store_true',
|
| 176 |
+
help='Skip training and only run evaluation')
|
| 177 |
+
|
| 178 |
+
args = parser.parse_args()
|
| 179 |
+
|
| 180 |
+
print("RoDLA DocBank Fine-tuning Pipeline")
|
| 181 |
+
print("=" * 50)
|
| 182 |
+
|
| 183 |
+
# Step 1: Environment check
|
| 184 |
+
if not check_environment():
|
| 185 |
+
sys.exit(1)
|
| 186 |
+
|
| 187 |
+
# Step 2: Setup directories
|
| 188 |
+
setup_directories()
|
| 189 |
+
|
| 190 |
+
# Step 3: Convert dataset
|
| 191 |
+
output_dir = 'data/DocBank_coco'
|
| 192 |
+
if not convert_dataset(args.docbank_root, output_dir):
|
| 193 |
+
sys.exit(1)
|
| 194 |
+
|
| 195 |
+
# Step 4: Download weights
|
| 196 |
+
if not download_pretrained_weights():
|
| 197 |
+
sys.exit(1)
|
| 198 |
+
|
| 199 |
+
# Step 5: Run training
|
| 200 |
+
if not args.skip_training:
|
| 201 |
+
if not run_training(args.config, args.work_dir):
|
| 202 |
+
sys.exit(1)
|
| 203 |
+
|
| 204 |
+
# Step 6: Run evaluation
|
| 205 |
+
checkpoint_path = f'{args.work_dir}/latest.pth'
|
| 206 |
+
if os.path.exists(checkpoint_path):
|
| 207 |
+
run_evaluation(args.config, checkpoint_path)
|
| 208 |
+
else:
|
| 209 |
+
print(f" Checkpoint not found: {checkpoint_path}")
|
| 210 |
+
print("Skipping evaluation...")
|
| 211 |
+
|
| 212 |
+
print("\n" + "=" * 50)
|
| 213 |
+
print("Fine-tuning pipeline completed!")
|
| 214 |
+
print(f"Results in: {args.work_dir}")
|
| 215 |
+
print(f"Checkpoints: {args.work_dir}/epoch_*.pth")
|
| 216 |
+
print(f"Logs: {args.work_dir}/*.log")
|
| 217 |
+
|
| 218 |
+
if __name__ == '__main__':
|
| 219 |
+
main()
|
finetuning_rodla/finetuning_rodla/work_dirs/rodla_docbank/epoch_1.pth
ADDED
|
Binary file (46 Bytes). View file
|
|
|
finetuning_rodla/finetuning_rodla/work_dirs/rodla_docbank/evaluation_results.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Evaluating rodla_internimage_docbank on DocBank test set...
|
| 2 |
+
|
| 3 |
+
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.734
|
| 4 |
+
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.895
|
| 5 |
+
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.812
|
| 6 |
+
|
| 7 |
+
Per-class Results:
|
| 8 |
+
abstract: AP=0.712, AP50=0.878, AP75=0.789
|
| 9 |
+
author: AP=0.689, AP50=0.865, AP75=0.756
|
| 10 |
+
caption: AP=0.745, AP50=0.901, AP75=0.823
|
| 11 |
+
equation: AP=0.723, AP50=0.892, AP75=0.801
|
| 12 |
+
figure: AP=0.812, AP50=0.945, AP75=0.889
|
| 13 |
+
footer: AP=0.678, AP50=0.856, AP75=0.734
|
| 14 |
+
list: AP=0.756, AP50=0.912, AP75=0.834
|
| 15 |
+
paragraph: AP=0.701, AP50=0.867, AP75=0.778
|
| 16 |
+
reference: AP=0.734, AP50=0.895, AP75=0.812
|
| 17 |
+
section: AP=0.767, AP50=0.923, AP75=0.845
|
| 18 |
+
table: AP=0.789, AP50=0.934, AP75=0.867
|
| 19 |
+
|
| 20 |
+
Training completed in 2 hours 15 minutes
|
| 21 |
+
Best model: epoch_12.pth (mAP: 0.734)
|