import bisect

import numpy as np
import scipy

# Compatibility shim: `scipy.inf` was removed in recent SciPy releases; restore
# it as an alias of `np.inf` for downstream dependencies (e.g. msaf) that may
# still reference it, before those modules are imported.
scipy.inf = np.inf

import torch
import torch.nn as nn
import torch.nn.functional as F
from x_transformers import Encoder

from dataset.custom_types import MsaInfo
from msaf.eval import compute_results
from postprocessing.functional import postprocess_functional_structure

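# Rough data flow, as implemented below: frame-level input embeddings are
# linearly projected (mixed_win_downsample) and normalized, temporally
# downsampled (TimeDownsample), fused with a per-dataset embedding (AddFuse),
# encoded by an x-transformers Encoder, and mapped by two heads to per-frame
# boundary scores and per-frame function (section-label) logits, which
# postprocess_functional_structure converts into an MsaInfo segmentation.
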
class Head(nn.Module):
    """Position-wise MLP head mapping (B, T, input_dim) to (B, T, output_dim)."""

    def __init__(self, input_dim, output_dim, hidden_dims=None, activation="silu"):
        super().__init__()
        hidden_dims = list(hidden_dims) if hidden_dims else []
        act_layers = {"relu": nn.ReLU, "silu": nn.SiLU, "gelu": nn.GELU}
        act_layer = act_layers.get(activation.lower())
        if not act_layer:
            raise ValueError(f"Unsupported activation: {activation}")

        dims = [input_dim] + hidden_dims + [output_dim]
        layers = []
        for i in range(len(dims) - 1):
            layers.append(nn.Linear(dims[i], dims[i + 1]))
            if i < len(dims) - 2:
                layers.append(act_layer())
        self.net = nn.Sequential(*layers)

    def reset_parameters(self, confidence):
        # Prior-probability bias initialization: set the final bias so that the
        # initial sigmoid output is roughly `confidence`.
        bias_value = -torch.log(torch.tensor((1 - confidence) / confidence))
        self.net[-1].bias.data.fill_(bias_value.item())

    def forward(self, x):
        batch, T, C = x.shape
        x = x.reshape(-1, C)
        x = self.net(x)
        return x.reshape(batch, T, -1)

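# Minimal usage sketch for Head (the dimensions below are made up for illustration):
#
#   head = Head(input_dim=512, output_dim=7, hidden_dims=[256])
#   out = head(torch.randn(2, 100, 512))    # -> (2, 100, 7)
#   head.reset_parameters(confidence=0.01)  # bias the head toward ~1% positives
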
class WrapedTransformerEncoder(nn.Module):
    def __init__(
        self, input_dim, transformer_input_dim, num_layers=1, nhead=8, dropout=0.1
    ):
        super().__init__()
        self.input_dim = input_dim
        self.transformer_input_dim = transformer_input_dim

        # Project the input to the transformer width only when the two differ.
        if input_dim != transformer_input_dim:
            self.input_proj = nn.Sequential(
                nn.Linear(input_dim, transformer_input_dim),
                nn.LayerNorm(transformer_input_dim),
                nn.GELU(),
                nn.Dropout(dropout * 0.5),
                nn.Linear(transformer_input_dim, transformer_input_dim),
            )
        else:
            self.input_proj = nn.Identity()

        self.transformer = Encoder(
            dim=transformer_input_dim,
            depth=num_layers,
            heads=nhead,
            layer_dropout=dropout,
            attn_dropout=dropout,
            ff_dropout=dropout,
            attn_flash=True,
            rotary_pos_emb=True,
        )

    def forward(self, x, src_key_padding_mask=None):
        """
        `src_key_padding_mask` is a B x T boolean mask following the PyTorch
        convention (True marks padded positions). x-transformers uses the
        opposite convention (True marks positions to attend to), so the mask
        is inverted before being passed on.
        """
        x = self.input_proj(x)
        mask = (
            ~torch.as_tensor(src_key_padding_mask, dtype=torch.bool, device=x.device)
            if src_key_padding_mask is not None
            else None
        )
        return self.transformer(x, mask=mask)

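# Mask convention sketch (shapes are illustrative): for a batch of 2 sequences
# of length 5 where the last two frames of the second sequence are padding,
#
#   pad_mask = torch.tensor([[False] * 5, [False, False, False, True, True]])
#   enc = WrapedTransformerEncoder(input_dim=128, transformer_input_dim=256)
#   out = enc(torch.randn(2, 5, 128), src_key_padding_mask=pad_mask)  # (2, 5, 256)
#
# Internally the mask is inverted because x-transformers treats True as "attend".
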
def prefix_dict(d, prefix: str):
    """Return a copy of `d` with every key prefixed; return `d` unchanged if the prefix is empty."""
    if not prefix:
        return d
    return {prefix + key: value for key, value in d.items()}

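# Example: prefix_dict({"loss": 0.3, "acc": 0.9}, "val_")
#          -> {"val_loss": 0.3, "val_acc": 0.9}
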
class TimeDownsample(nn.Module):
    """Depthwise-separable 1D convolution that downsamples the time axis,
    with an average-pooled residual branch."""

    def __init__(
        self, dim_in, dim_out=None, kernel_size=5, stride=5, padding=0, dropout=0.1
    ):
        super().__init__()
        self.dim_out = dim_out or dim_in
        assert self.dim_out % 2 == 0

        self.depthwise_conv = nn.Conv1d(
            in_channels=dim_in,
            out_channels=dim_in,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=dim_in,
            bias=False,
        )
        self.pointwise_conv = nn.Conv1d(
            in_channels=dim_in,
            out_channels=self.dim_out,
            kernel_size=1,
            bias=False,
        )
        self.pool = nn.AvgPool1d(kernel_size, stride, padding=padding)
        self.norm1 = nn.LayerNorm(self.dim_out)
        self.act1 = nn.GELU()
        self.dropout1 = nn.Dropout(dropout)

        if dim_in != self.dim_out:
            self.residual_conv = nn.Conv1d(
                dim_in, self.dim_out, kernel_size=1, bias=False
            )
        else:
            self.residual_conv = None

    def forward(self, x):
        residual = x

        # Main branch: depthwise conv (temporal downsampling) followed by a
        # pointwise conv (channel mixing), operating in (B, C, T) layout.
        x_c = x.transpose(1, 2)
        x_c = self.depthwise_conv(x_c)
        x_c = self.pointwise_conv(x_c)

        # Residual branch: average-pool to the same temporal resolution and,
        # if needed, project to the output channel count.
        res = self.pool(residual.transpose(1, 2))
        if self.residual_conv is not None:
            res = self.residual_conv(res)
        x_c = x_c + res

        x_c = x_c.transpose(1, 2)
        x_c = self.norm1(x_c)
        x_c = self.act1(x_c)
        x_c = self.dropout1(x_c)
        return x_c

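# Output length follows the standard Conv1d/AvgPool1d formula,
#   T_out = floor((T_in + 2 * padding - kernel_size) / stride) + 1,
# so with the defaults (kernel_size=5, stride=5, padding=0) a 1000-frame input
# is reduced to 200 frames; the pooled residual uses the same geometry, which
# keeps the two branches aligned for the addition.
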
class AddFuse(nn.Module):
    """Fuses a conditioning tensor into the features by simple addition."""

    def __init__(self):
        super().__init__()

    def forward(self, x, cond):
        return x + cond

class TVLoss1D(nn.Module):
    def __init__(
        self, beta=1.0, lambda_tv=0.4, boundary_threshold=0.01, reduction_weight=0.1
    ):
        """
        Args:
            beta: Exponent applied to the absolute differences (0.5~1.0 recommended)
            lambda_tv: Overall weight of the TV loss
            boundary_threshold: Label threshold above which a frame counts as lying in a boundary region (e.g., 0.01)
            reduction_weight: Scale applied to the TV penalty inside boundary regions (e.g., 0.1, i.e. only 10% of the penalty)
        """
        super().__init__()
        self.beta = beta
        self.lambda_tv = lambda_tv
        self.boundary_threshold = boundary_threshold
        self.reduction_weight = reduction_weight

    def forward(self, pred, target=None):
        """
        Args:
            pred: (B, T) or (B, T, 1), float boundary scores output by the model
            target: (B, T) or (B, T, 1), ground-truth labels (optional; used for spatial weighting when provided)

        Returns:
            scalar: weighted TV loss
        """
        if pred.dim() == 3:
            pred = pred.squeeze(-1)
        if target is not None and target.dim() == 3:
            target = target.squeeze(-1)

        # Total-variation term over adjacent frames.
        diff = pred[:, 1:] - pred[:, :-1]
        tv_base = torch.pow(torch.abs(diff) + 1e-8, self.beta)

        if target is None:
            return self.lambda_tv * tv_base.mean()

        # Down-weight the penalty wherever either neighbour lies in an
        # annotated boundary region, so genuine jumps are not over-smoothed.
        left_in_boundary = target[:, :-1] > self.boundary_threshold
        right_in_boundary = target[:, 1:] > self.boundary_threshold
        near_boundary = left_in_boundary | right_in_boundary
        weight_mask = torch.where(
            near_boundary,
            self.reduction_weight * torch.ones_like(tv_base),
            torch.ones_like(tv_base),
        )
        tv_weighted = (tv_base * weight_mask).mean()
        return self.lambda_tv * tv_weighted

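# Worked example (illustrative numbers): with beta=1 the TV term for a pair of
# adjacent predictions (0.2, 0.9) is |0.9 - 0.2| = 0.7. If either frame lies in
# a labelled boundary region (target > boundary_threshold) that 0.7 is scaled
# by reduction_weight (e.g. x0.1), so sharp jumps remain cheap at annotated
# boundaries but are penalized elsewhere; the mean is finally scaled by lambda_tv.
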
class SoftmaxFocalLoss(nn.Module):
    """
    Softmax focal loss for single-label multi-class classification,
    i.e. for mutually exclusive classes.
    """

    def __init__(self, alpha: float = 0.25, gamma: float = 2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, pred: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pred: [B, T, C], raw logits
            targets: [B, T, C] soft labels, or [B, T] hard labels (dtype=long)
        Returns:
            loss: [B, T] per-position focal loss (no reduction is applied)
        """
        log_probs = F.log_softmax(pred, dim=-1)
        probs = torch.exp(log_probs)

        if targets.dtype == torch.long:
            targets_onehot = F.one_hot(targets, num_classes=pred.size(-1)).float()
        else:
            targets_onehot = targets

        # Probability assigned to the target class (expected probability for soft labels).
        p_t = (probs * targets_onehot).sum(dim=-1)
        p_t = p_t.clamp(min=1e-8, max=1.0 - 1e-8)

        if self.alpha > 0:
            alpha_t = self.alpha * targets_onehot + (1 - self.alpha) * (
                1 - targets_onehot
            )
            alpha_weight = (alpha_t * targets_onehot).sum(dim=-1)
        else:
            alpha_weight = 1.0

        focal_weight = (1 - p_t) ** self.gamma
        ce_loss = -log_probs * targets_onehot
        ce_loss = ce_loss.sum(dim=-1)

        loss = alpha_weight * focal_weight * ce_loss
        return loss

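# Sanity check (illustrative): with gamma=0 and alpha<=0 the loss reduces to
# plain cross-entropy, since focal_weight == 1 and alpha_weight == 1. With
# gamma=2, a confidently correct frame (p_t = 0.9) is down-weighted by
# (1 - 0.9)**2 = 0.01, while a hard frame (p_t = 0.1) keeps ~0.81 of its loss.
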
class Model(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.input_norm = nn.LayerNorm(config.input_dim)
        self.mixed_win_downsample = nn.Linear(config.input_dim_raw, config.input_dim)
        # Learned per-dataset embedding, added to the features as a conditioning signal.
        self.dataset_class_prefix = nn.Embedding(
            num_embeddings=config.num_dataset_classes,
            embedding_dim=config.transformer_encoder_input_dim,
        )
        self.down_sample_conv = TimeDownsample(
            dim_in=config.input_dim,
            dim_out=config.transformer_encoder_input_dim,
            kernel_size=config.down_sample_conv_kernel_size,
            stride=config.down_sample_conv_stride,
            dropout=config.down_sample_conv_dropout,
            padding=config.down_sample_conv_padding,
        )
        self.AddFuse = AddFuse()
        self.transformer = WrapedTransformerEncoder(
            input_dim=config.transformer_encoder_input_dim,
            transformer_input_dim=config.transformer_input_dim,
            num_layers=config.num_transformer_layers,
            nhead=config.transformer_nhead,
            dropout=config.transformer_dropout,
        )
        self.boundary_TVLoss1D = TVLoss1D(
            beta=config.boundary_tv_loss_beta,
            lambda_tv=config.boundary_tv_loss_lambda,
            boundary_threshold=config.boundary_tv_loss_boundary_threshold,
            reduction_weight=config.boundary_tv_loss_reduction_weight,
        )
        self.label_focal_loss = SoftmaxFocalLoss(
            alpha=config.label_focal_loss_alpha, gamma=config.label_focal_loss_gamma
        )
        # Per-frame boundary score (1 logit) and per-frame section-function logits.
        self.boundary_head = Head(config.transformer_input_dim, 1)
        self.function_head = Head(config.transformer_input_dim, config.num_classes)

    def cal_metrics(self, gt_info: MsaInfo, msa_info: MsaInfo):
        assert gt_info[-1][1] == "end" and msa_info[-1][1] == "end", (
            "gt_info and msa_info should end with 'end'"
        )
        # Convert (time, label) change points into (start, end) intervals plus labels.
        gt_info_labels = [label for time_, label in gt_info][:-1]
        gt_info_inters = [time_ for time_, label in gt_info]
        gt_info_inters = np.column_stack(
            [np.array(gt_info_inters[:-1]), np.array(gt_info_inters[1:])]
        )

        msa_info_labels = [label for time_, label in msa_info][:-1]
        msa_info_inters = [time_ for time_, label in msa_info]
        msa_info_inters = np.column_stack(
            [np.array(msa_info_inters[:-1]), np.array(msa_info_inters[1:])]
        )
        result = compute_results(
            ann_inter=gt_info_inters,
            est_inter=msa_info_inters,
            ann_labels=gt_info_labels,
            est_labels=msa_info_labels,
            bins=11,
            est_file="test.txt",
            weight=0.58,
        )
        return result

    def cal_acc(
        self, ann_info: MsaInfo | str, est_info: MsaInfo | str, post_digit: int = 3
    ):
        # Time-weighted label accuracy over the overlap of the two annotations.
        # Times are converted to integers (milliseconds for post_digit=3) to
        # avoid floating-point comparison issues.
        ann_info_time = [
            int(round(time_, post_digit) * (10**post_digit))
            for time_, label in ann_info
        ]
        est_info_time = [
            int(round(time_, post_digit) * (10**post_digit))
            for time_, label in est_info
        ]

        common_start_time = max(ann_info_time[0], est_info_time[0])
        common_end_time = min(ann_info_time[-1], est_info_time[-1])

        # Collect every boundary from both annotations inside the common span,
        # then score each elementary segment between consecutive boundaries.
        time_points = {common_start_time, common_end_time}
        time_points.update(
            {
                time_
                for time_ in ann_info_time
                if common_start_time <= time_ <= common_end_time
            }
        )
        time_points.update(
            {
                time_
                for time_ in est_info_time
                if common_start_time <= time_ <= common_end_time
            }
        )

        time_points = sorted(time_points)
        total_duration, total_score = 0, 0

        for idx in range(len(time_points) - 1):
            duration = time_points[idx + 1] - time_points[idx]
            ann_label = ann_info[
                bisect.bisect_right(ann_info_time, time_points[idx]) - 1
            ][1]
            est_label = est_info[
                bisect.bisect_right(est_info_time, time_points[idx]) - 1
            ][1]
            total_duration += duration
            if ann_label == est_label:
                total_score += duration
        return total_score / total_duration

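    # Worked example for cal_acc (illustrative numbers):
    #   ann = [(0.0, "A"), (2.0, "B"), (4.0, "end")]
    #   est = [(0.0, "A"), (3.0, "B"), (4.0, "end")]
    # agree on [0, 2) and [3, 4) but disagree on [2, 3), so the result is
    # 3.0 / 4.0 = 0.75.
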
    def infer_with_metrics(self, batch, prefix: str | None = None):
        with torch.no_grad():
            logits = self.forward_func(batch)
            losses = self.compute_losses(logits, batch, prefix=None)

            # Mask out labels that are invalid for this dataset before decoding.
            expanded_mask = batch["label_id_masks"].expand(
                -1, logits["function_logits"].size(1), -1
            )
            logits["function_logits"] = logits["function_logits"].masked_fill(
                expanded_mask, -float("inf")
            )

            msa_info = postprocess_functional_structure(
                logits=logits, config=self.config
            )
            gt_info = batch["msa_infos"][0]
            results = self.cal_metrics(gt_info=gt_info, msa_info=msa_info)

        ret_results = {
            "loss": losses["loss"].item(),
            "HitRate_3P": results["HitRate_3P"],
            "HitRate_3R": results["HitRate_3R"],
            "HitRate_3F": results["HitRate_3F"],
            "HitRate_0.5P": results["HitRate_0.5P"],
            "HitRate_0.5R": results["HitRate_0.5R"],
            "HitRate_0.5F": results["HitRate_0.5F"],
            "PWF": results["PWF"],
            "PWP": results["PWP"],
            "PWR": results["PWR"],
            "Sf": results["Sf"],
            "So": results["So"],
            "Su": results["Su"],
            "acc": self.cal_acc(ann_info=gt_info, est_info=msa_info),
        }
        if prefix:
            ret_results = prefix_dict(ret_results, prefix)

        return ret_results

    def infer(
        self,
        input_embeddings,
        dataset_ids,
        label_id_masks,
        prefix: str | None = None,
        with_logits=False,
    ):
        with torch.no_grad():
            input_embeddings = self.mixed_win_downsample(input_embeddings)
            input_embeddings = self.input_norm(input_embeddings)
            logits = self.down_sample_conv(input_embeddings)

            # Add the dataset embedding as a conditioning signal.
            dataset_prefix = self.dataset_class_prefix(dataset_ids)
            dataset_prefix_expand = dataset_prefix.unsqueeze(1).expand(
                logits.size(0), 1, -1
            )
            logits = self.AddFuse(x=logits, cond=dataset_prefix_expand)
            logits = self.transformer(x=logits, src_key_padding_mask=None)

            function_logits = self.function_head(logits)
            boundary_logits = self.boundary_head(logits).squeeze(-1)

            logits = {
                "function_logits": function_logits,
                "boundary_logits": boundary_logits,
            }

            # Mask out labels that are invalid for this dataset before decoding.
            expanded_mask = label_id_masks.expand(
                -1, logits["function_logits"].size(1), -1
            )
            logits["function_logits"] = logits["function_logits"].masked_fill(
                expanded_mask, -float("inf")
            )

            msa_info = postprocess_functional_structure(
                logits=logits, config=self.config
            )

        return (msa_info, logits) if with_logits else msa_info

    def compute_losses(self, outputs, batch, prefix: str | None = None):
        loss = 0.0
        losses = {}

        # Boundary stream: per-frame BCE plus a TV regularizer on the boundary scores.
        loss_section = F.binary_cross_entropy_with_logits(
            outputs["boundary_logits"],
            batch["widen_true_boundaries"],
            reduction="none",
        )
        loss_section += self.config.boundary_tvloss_weight * self.boundary_TVLoss1D(
            pred=outputs["boundary_logits"],
            target=batch["widen_true_boundaries"],
        )

        # Function stream: per-frame cross-entropy plus a focal term.
        loss_function = F.cross_entropy(
            outputs["function_logits"].transpose(1, 2),
            batch["true_functions"].transpose(1, 2),
            reduction="none",
        )
        loss_function += self.config.label_focal_loss_weight * self.label_focal_loss(
            pred=outputs["function_logits"], targets=batch["true_functions"]
        )

        # Zero out padded frames (and optional extra boundary/function masks).
        float_masks = (~batch["masks"]).float()
        boundary_mask = batch.get("boundary_mask", None)
        function_mask = batch.get("function_mask", None)
        if boundary_mask is not None:
            boundary_mask = (~boundary_mask).float()
        else:
            boundary_mask = 1

        if function_mask is not None:
            function_mask = (~function_mask).float()
        else:
            function_mask = 1

        loss_section = torch.mean(boundary_mask * float_masks * loss_section)
        loss_function = torch.mean(function_mask * float_masks * loss_function)

        loss_section *= self.config.loss_weight_section
        loss_function *= self.config.loss_weight_function

        if self.config.learn_label:
            loss += loss_function
        if self.config.learn_segment:
            loss += loss_section

        losses.update(
            loss=loss,
            loss_section=loss_section,
            loss_function=loss_function,
        )
        if prefix:
            losses = prefix_dict(losses, prefix)
        return losses

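    # Expected batch fields used above (shapes hedged, inferred from usage):
    #   "widen_true_boundaries": (B, T') float boundary targets,
    #   "true_functions":        (B, T', C) soft section-label targets,
    #   "masks":                 (B, T') bool padding mask (True = padded),
    #   "boundary_mask"/"function_mask": optional bool masks (True = ignore),
    # where T' is the sequence length after TimeDownsample.
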
    def forward_func(self, batch):
        input_embeddings = batch["input_embeddings"]
        input_embeddings = self.mixed_win_downsample(input_embeddings)
        input_embeddings = self.input_norm(input_embeddings)
        logits = self.down_sample_conv(input_embeddings)

        dataset_prefix = self.dataset_class_prefix(batch["dataset_ids"])
        logits = self.AddFuse(x=logits, cond=dataset_prefix.unsqueeze(1))
        src_key_padding_mask = batch["masks"]
        logits = self.transformer(x=logits, src_key_padding_mask=src_key_padding_mask)

        function_logits = self.function_head(logits)
        boundary_logits = self.boundary_head(logits).squeeze(-1)

        logits = {
            "function_logits": function_logits,
            "boundary_logits": boundary_logits,
        }
        return logits

    def forward(self, batch):
        logits = self.forward_func(batch)
        losses = self.compute_losses(logits, batch, prefix=None)
        return logits, losses["loss"], losses

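# Typical usage sketch (hedged; the surrounding training loop is not shown here):
#   model = Model(config)
#   logits, loss, losses = model(batch)   # training: loss.backward(), optimizer step
#   msa_info = model.infer(emb, dataset_ids, label_id_masks)  # no-grad inference
#   metrics = model.infer_with_metrics(batch, prefix="val_")  # inference + MSAF metrics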