| """ | |
| Medical Data Preprocessing for AI training | |
| Optimized for medical images and text with memory constraints | |
| """ | |
| import logging | |
| import numpy as np | |
| from typing import Dict, Any, List, Optional, Tuple | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image, ImageEnhance, ImageFilter | |
| import cv2 | |
| import re | |
| logger = logging.getLogger(__name__) | |
class MedicalPreprocessor:
    """
    Medical data preprocessor with memory optimization
    """

    def __init__(self, target_size: Tuple[int, int] = (512, 512),
                 normalize_images: bool = True):
        """
        Initialize medical preprocessor

        Args:
            target_size: Target size for image resizing
            normalize_images: Whether to normalize images
        """
        self.target_size = target_size
        self.normalize_images = normalize_images

        # Medical text preprocessing patterns
        self.medical_patterns = {
            'measurements': r'\d+\.?\d*\s*(mm|cm|m|ml|l|kg|g|mg)',
            'dates': r'\d{1,2}[/-]\d{1,2}[/-]\d{2,4}',
            'times': r'\d{1,2}:\d{2}(?::\d{2})?',
            'medical_codes': r'[A-Z]\d{2}\.?\d*',
            'dosages': r'\d+\.?\d*\s*(mg|g|ml|units?)',
        }

        # Common medical abbreviations
        self.medical_abbreviations = {
            'pt': 'patient',
            'pts': 'patients',
            'dx': 'diagnosis',
            'tx': 'treatment',
            'hx': 'history',
            'sx': 'symptoms',
            'rx': 'prescription',
            'w/': 'with',
            'w/o': 'without',
            'c/o': 'complains of',
            'r/o': 'rule out',
            's/p': 'status post',
            'nkda': 'no known drug allergies',
            'sob': 'shortness of breath',
            'cp': 'chest pain',
            'abd': 'abdomen',
            'ext': 'extremities'
        }

        logger.info(f"Medical Preprocessor initialized with target size {target_size}")

    def preprocess_medical_image(self, image: torch.Tensor,
                                 modality: str = 'unknown',
                                 enhance_contrast: bool = True) -> torch.Tensor:
        """
        Preprocess medical image with modality-specific optimizations

        Args:
            image: Input image tensor
            modality: Medical imaging modality (CT, MRI, X-ray, etc.)
            enhance_contrast: Whether to enhance contrast

        Returns:
            Preprocessed image tensor
        """
        try:
            # Ensure image is a float tensor
            if image.dtype != torch.float32:
                image = image.float()

            # Handle different input shapes
            if len(image.shape) == 2:
                image = image.unsqueeze(0)  # Add channel dimension
            elif len(image.shape) == 4:
                image = image.squeeze(0)  # Remove batch dimension if present

            # Resize to target size
            if image.shape[-2:] != self.target_size:
                image = F.interpolate(
                    image.unsqueeze(0),
                    size=self.target_size,
                    mode='bilinear',
                    align_corners=False
                ).squeeze(0)

            # Apply modality-specific preprocessing
            image = self._apply_modality_specific_processing(image, modality)

            # Enhance contrast if requested
            if enhance_contrast:
                image = self._enhance_medical_image_contrast(image)

            # Normalize if requested
            if self.normalize_images:
                image = self._normalize_medical_image(image)

            # Ensure proper range [0, 1]
            image = torch.clamp(image, 0.0, 1.0)

            return image
        except Exception as e:
            logger.error(f"Error preprocessing medical image: {e}")
            # Return dummy image on error
            return torch.zeros(1, *self.target_size)

    def _apply_modality_specific_processing(self, image: torch.Tensor,
                                            modality: str) -> torch.Tensor:
        """Apply modality-specific image processing"""
        modality_lower = modality.lower()
        try:
            if 'ct' in modality_lower:
                # CT scan specific processing
                image = self._process_ct_image(image)
            elif 'mri' in modality_lower:
                # MRI specific processing
                image = self._process_mri_image(image)
            elif 'xray' in modality_lower or 'x-ray' in modality_lower:
                # X-ray specific processing
                image = self._process_xray_image(image)
            elif 'ultrasound' in modality_lower:
                # Ultrasound specific processing
                image = self._process_ultrasound_image(image)
            return image
        except Exception as e:
            logger.warning(f"Error in modality-specific processing for {modality}: {e}")
            return image

    def _process_ct_image(self, image: torch.Tensor) -> torch.Tensor:
        """Process CT scan images"""
        # CT images often need windowing adjustments
        # Apply soft tissue window as default
        image = torch.clamp(image, 0.0, 1.0)
        # Enhance contrast for better tissue differentiation
        image = self._apply_gamma_correction(image, gamma=0.8)
        return image

    def _process_mri_image(self, image: torch.Tensor) -> torch.Tensor:
        """Process MRI images"""
        # MRI images often have good contrast already
        # Apply mild enhancement
        image = self._apply_gamma_correction(image, gamma=0.9)
        return image

    def _process_xray_image(self, image: torch.Tensor) -> torch.Tensor:
        """Process X-ray images"""
        # X-rays often need contrast enhancement
        image = self._enhance_medical_image_contrast(image, factor=1.2)
        # Apply histogram equalization equivalent
        image = self._apply_histogram_equalization(image)
        return image

    def _process_ultrasound_image(self, image: torch.Tensor) -> torch.Tensor:
        """Process ultrasound images"""
        # Ultrasound images often need noise reduction
        image = self._apply_noise_reduction(image)
        return image

    def _enhance_medical_image_contrast(self, image: torch.Tensor,
                                        factor: float = 1.1) -> torch.Tensor:
        """Enhance contrast of medical images"""
        try:
            # Apply contrast enhancement around the image mean
            mean_val = torch.mean(image)
            enhanced = (image - mean_val) * factor + mean_val
            return torch.clamp(enhanced, 0.0, 1.0)
        except Exception as e:
            logger.warning(f"Error enhancing contrast: {e}")
            return image

    def _apply_gamma_correction(self, image: torch.Tensor,
                                gamma: float = 1.0) -> torch.Tensor:
        """Apply gamma correction to image"""
        try:
            # Clamp negatives to zero so fractional exponents cannot produce NaNs
            return torch.pow(torch.clamp(image, min=0.0), gamma)
        except Exception as e:
            logger.warning(f"Error applying gamma correction: {e}")
            return image

    def _apply_histogram_equalization(self, image: torch.Tensor) -> torch.Tensor:
        """Apply histogram equalization equivalent"""
        try:
            # Convert to numpy for processing (CLAHE expects a single-channel 2D array,
            # so move the tensor to CPU and drop singleton dimensions first)
            image_np = image.squeeze().detach().cpu().numpy()
            # Apply CLAHE (Contrast Limited Adaptive Histogram Equalization)
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
            # Convert to uint8 for CLAHE
            image_uint8 = (image_np * 255).astype(np.uint8)
            equalized = clahe.apply(image_uint8)
            # Convert back to tensor
            result = torch.from_numpy(equalized.astype(np.float32) / 255.0)
            # Restore original shape
            if len(image.shape) == 3:
                result = result.unsqueeze(0)
            return result
        except Exception as e:
            logger.warning(f"Error applying histogram equalization: {e}")
            return image

    def _apply_noise_reduction(self, image: torch.Tensor) -> torch.Tensor:
        """Apply noise reduction to image"""
        try:
            # Simple Gaussian blur for noise reduction
            kernel_size = 3
            sigma = 0.5
            # Create Gaussian kernel
            kernel = self._create_gaussian_kernel(kernel_size, sigma)
            kernel = kernel.unsqueeze(0).unsqueeze(0)  # Add output/input channel dims
            # Apply convolution
            if len(image.shape) == 3:
                image_input = image.unsqueeze(0)  # Add batch dimension
            else:
                image_input = image
            # Depthwise convolution so multi-channel images are filtered per channel
            channels = image_input.shape[1]
            kernel = kernel.repeat(channels, 1, 1, 1)
            filtered = F.conv2d(image_input, kernel, padding=kernel_size // 2,
                                groups=channels)
            # Remove batch dimension if added
            if len(image.shape) == 3:
                filtered = filtered.squeeze(0)
            return filtered
        except Exception as e:
            logger.warning(f"Error applying noise reduction: {e}")
            return image

    def _create_gaussian_kernel(self, kernel_size: int, sigma: float) -> torch.Tensor:
        """Create Gaussian kernel for filtering"""
        coords = torch.arange(kernel_size, dtype=torch.float32)
        coords -= kernel_size // 2
        g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
        g /= g.sum()
        # Create 2D kernel as the outer product of two 1D Gaussians
        kernel = g[:, None] * g[None, :]
        return kernel

    def _normalize_medical_image(self, image: torch.Tensor) -> torch.Tensor:
        """Normalize medical image"""
        try:
            # Z-score normalization per image
            mean_val = torch.mean(image)
            std_val = torch.std(image)
            if std_val > 0:
                normalized = (image - mean_val) / std_val
                # Scale to [0, 1] range
                normalized = (normalized - normalized.min()) / (normalized.max() - normalized.min())
            else:
                normalized = image
            return normalized
        except Exception as e:
            logger.warning(f"Error normalizing image: {e}")
            return image

    def preprocess_medical_text(self, text: str,
                                expand_abbreviations: bool = True,
                                remove_phi: bool = True) -> str:
        """
        Preprocess medical text

        Args:
            text: Input medical text
            expand_abbreviations: Whether to expand medical abbreviations
            remove_phi: Whether to remove potential PHI (Protected Health Information)

        Returns:
            Preprocessed text
        """
        try:
            if not isinstance(text, str):
                text = str(text)

            processed_text = text

            # Remove potential PHI first, while the original casing is still
            # available (the name heuristic relies on capitalization)
            if remove_phi:
                processed_text = self._remove_phi(processed_text)

            # Convert to lowercase for further processing
            processed_text = processed_text.lower()

            # Expand medical abbreviations
            if expand_abbreviations:
                processed_text = self._expand_medical_abbreviations(processed_text)

            # Clean up text
            processed_text = self._clean_medical_text(processed_text)

            # Limit length to prevent memory issues
            max_length = 2048
            if len(processed_text) > max_length:
                processed_text = processed_text[:max_length] + "..."

            return processed_text
        except Exception as e:
            logger.error(f"Error preprocessing medical text: {e}")
            return text  # Return original text on error

    def _remove_phi(self, text: str) -> str:
        """Remove potential Protected Health Information"""
        # Remove dates
        text = re.sub(self.medical_patterns['dates'], '[DATE]', text)
        # Remove times
        text = re.sub(self.medical_patterns['times'], '[TIME]', text)
        # Remove phone numbers
        text = re.sub(r'\b\d{3}[-.]?\d{3}[-.]?\d{4}\b', '[PHONE]', text)
        # Remove email addresses
        text = re.sub(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b', '[EMAIL]', text)
        # Remove potential names (very basic - would need more sophisticated NER in practice)
        text = re.sub(r'\b[A-Z][a-z]+ [A-Z][a-z]+\b', '[NAME]', text)
        return text

    def _expand_medical_abbreviations(self, text: str) -> str:
        """Expand common medical abbreviations"""
        # Replace longer abbreviations first so e.g. 'w/o' is not consumed by 'w/'
        for abbrev, expansion in sorted(self.medical_abbreviations.items(),
                                        key=lambda item: len(item[0]), reverse=True):
            # Lookarounds instead of \b so abbreviations ending in '/' still match
            pattern = r'(?<!\w)' + re.escape(abbrev) + r'(?!\w)'
            text = re.sub(pattern, expansion, text, flags=re.IGNORECASE)
        return text

    def _clean_medical_text(self, text: str) -> str:
        """Clean and normalize medical text"""
        # Remove extra whitespace
        text = re.sub(r'\s+', ' ', text)
        # Remove special characters but keep medical-relevant ones
        # (square brackets are kept so PHI placeholders like [DATE] survive)
        text = re.sub(r'[^\w\s\-\.\,\:\;\(\)\/\%\[\]]', '', text)
        # Strip leading/trailing whitespace
        text = text.strip()
        return text

    def batch_preprocess_medical_data(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Preprocess a batch of medical data"""
        processed_batch = {}
        try:
            # Process images if present
            if 'images' in batch and batch['images'] is not None:
                images = batch['images']
                processed_images = []
                for i, image in enumerate(images):
                    # Get modality if available
                    modality = 'unknown'
                    if 'modalities' in batch and i < len(batch['modalities']):
                        modality = batch['modalities'][i]
                    processed_image = self.preprocess_medical_image(image, modality)
                    processed_images.append(processed_image)
                processed_batch['images'] = torch.stack(processed_images)

            # Process texts if present
            if 'texts' in batch:
                texts = batch['texts']
                processed_texts = []
                for text in texts:
                    processed_text = self.preprocess_medical_text(text)
                    processed_texts.append(processed_text)
                processed_batch['texts'] = processed_texts

            # Copy other fields
            for key, value in batch.items():
                if key not in ['images', 'texts']:
                    processed_batch[key] = value

            return processed_batch
        except Exception as e:
            logger.error(f"Error in batch preprocessing: {e}")
            return batch  # Return original batch on error
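

# Minimal usage sketch (illustrative only, not part of the original module):
# the tensor shapes, modality strings, and batch keys below are assumptions
# about how this class is typically driven, not something the module defines.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    preprocessor = MedicalPreprocessor(target_size=(512, 512))

    # Single image: a synthetic single-channel 256x256 scan, resized to 512x512
    fake_scan = torch.rand(1, 256, 256)
    processed = preprocessor.preprocess_medical_image(fake_scan, modality='CT')
    print(processed.shape)  # torch.Size([1, 512, 512])

    # Text: abbreviations expanded, dates/names replaced with placeholders
    note = "Pt c/o SOB since 03/14/2024, r/o pneumonia. Seen by John Smith."
    print(preprocessor.preprocess_medical_text(note))

    # Batch: the 'images' / 'texts' / 'modalities' dict layout is assumed here
    batch = {
        'images': [torch.rand(1, 256, 256), torch.rand(1, 320, 320)],
        'texts': ["hx of cp, nkda", "abd soft, ext warm"],
        'modalities': ['X-ray', 'Ultrasound'],
    }
    processed_batch = preprocessor.batch_preprocess_medical_data(batch)
    print(processed_batch['images'].shape)  # torch.Size([2, 1, 512, 512])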