import os

import gradio as gr
import torch

from dataclasses import asdict

from transformers import AutoModelForCausalLM

from anticipation.sample import generate
from anticipation.convert import events_to_midi, midi_to_events
from anticipation.tokenize import extract_instruments
from anticipation import ops

from mido import MidiFile

from pyharp.core import ModelCard, build_endpoint
from pyharp.labels import LabelList

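
# The three published AMT checkpoints. Larger models are slower to load and
# sample from, but generally produce higher-quality accompaniment (see the
# "Faster vs. Higher Quality" dropdown below).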
SMALL_MODEL = "stanford-crfm/music-small-800k"
MEDIUM_MODEL = "stanford-crfm/music-medium-800k"
LARGE_MODEL = "stanford-crfm/music-large-800k"


model_card = ModelCard(
    name="Anticipatory Music Transformer",
    description=(
        "Generate musical accompaniment for your existing vamp using the Anticipatory Music Transformer. "
        "Input: a MIDI file with a short accompaniment (vamp) followed by a melody line. "
        "Output: a new MIDI file with extended accompaniment matching the melody continuation. "
        "Use the dropdown to choose the model size and the slider to set how much of the song is used as context."
    ),
    author="John Thickstun, David Hall, Chris Donahue, Percy Liang",
    tags=["midi", "generation", "accompaniment"]
)

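
# Cache loaded models per worker process so repeated requests with the same
# model choice do not reload the checkpoint from disk.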
_model_cache = {}


def load_amt_model(model_choice: str):
    """Loads and caches the AMT model inside the worker process (same behavior as old app)."""
    if model_choice in _model_cache:
        return _model_cache[model_choice]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if model_choice == LARGE_MODEL:
        print(f"Loading {LARGE_MODEL} (low_cpu_mem_usage, fp16 on CUDA if available)...")
        model = AutoModelForCausalLM.from_pretrained(
            LARGE_MODEL,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            low_cpu_mem_usage=True
        ).to(device)
    else:
        print(f"Loading {model_choice} ...")
        model = AutoModelForCausalLM.from_pretrained(model_choice).to(device)

    _model_cache[model_choice] = model
    return model

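
# Melody-detection heuristic (documented from the scoring code below):
#   score = 0.35 * normalized mean pitch + 0.35 * normalized note count + 0.30 * temporal coverage
# with small multiplicative bonuses for a melody-like pitch range (MIDI 55-75),
# at least 30 notes, and coverage above 70%, and a small penalty for tracks
# with very little pitch variety. Returns (program_number, has_valid_programs).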
def find_melody_program(mid, debug=False):
    """
    Automatically detect the melody track's program number from a MIDI file.
    Uses a balanced heuristic: pitch + note count + temporal coverage.
    """
    track_stats = []
    total_duration = 0

    for i, track in enumerate(mid.tracks):
        pitches, times = [], []
        current_time = 0
        current_program = None
        track_note_count = 0

        for msg in track:
            # Accumulate delta time for every message so track timing stays accurate,
            # then only inspect note_on and program_change messages.
            current_time += msg.time
            if msg.type not in ("note_on", "program_change"):
                continue

            if msg.type == "program_change":
                current_program = msg.program
                continue

            if msg.velocity > 0:
                pitches.append(msg.note)
                times.append(current_time)
                track_note_count += 1

            # Cap the per-track scan at 100 notes to keep detection fast.
            if track_note_count >= 100:
                break

        if not pitches:
            continue

        track_duration = max(times) - min(times)
        total_duration = max(total_duration, current_time)

        mean_pitch = sum(pitches) / len(pitches)
        polyphony = len(set(pitches)) / len(pitches)
        coverage = track_duration / total_duration if total_duration > 0 else 0

        track_stats.append((i, mean_pitch, len(pitches), current_program, polyphony, coverage))

    if not track_stats:
        return None, False

    if len(track_stats) == 1:
        prog = track_stats[0][3]
        if debug:
            if prog == 0:
                print("Single-track MIDI detected, program 0 (Acoustic Grand Piano) will be treated as melody.")
            else:
                print(f"Single-track MIDI detected, using program {prog or 'None'}")
        return prog, prog is not None

    candidates = [t for t in track_stats if t[3] is not None and t[3] > 0]
    has_valid_programs = len(candidates) > 0
    if not candidates:
        candidates = track_stats

    if debug:
        print(f"\nCandidates: {len(candidates)} tracks")

    max_notes = max(t[2] for t in candidates)
    max_pitch = max(t[1] for t in candidates)
    min_pitch = min(t[1] for t in candidates)
    pitch_span = max_pitch - min_pitch if max_pitch > min_pitch else 1

    best_score = -1
    best_program = None
    best_track = None
    best_pitch = None

    for t in candidates:
        idx, pitch, notes, prog, poly, coverage = t
        pitch_norm = (pitch - min_pitch) / pitch_span
        notes_norm = notes / max_notes

        score = (pitch_norm * 0.35) + (notes_norm * 0.35) + (coverage * 0.30)

        if poly < 0.15:
            score *= 0.95
        if 55 <= pitch <= 75:
            score *= 1.1
        if notes >= 30:
            score *= 1.05
        if coverage > 0.7:
            score *= 1.15

        if score > best_score:
            best_score = score
            best_program = prog
            best_track = idx
            best_pitch = pitch

    return best_program, has_valid_programs

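
# Note: anticipation's extract_instruments() separates the chosen program's events
# from the rest of the stream, so when a melody program is found the first value
# returned below no longer contains the melody events.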
def auto_extract_melody(mid, debug=False):
    """
    Extract melody events from a MIDI object (already loaded via MidiFile).
    Optimized to avoid re-reading the file from disk.
    Returns: (all_events, melody_events)
    """
    events = midi_to_events(mid)

    melody_program, has_valid_program = find_melody_program(mid, debug=debug)

    if not has_valid_program or melody_program is None or melody_program == 0:
        if debug:
            print("No valid program changes in MIDI, using all events as melody")
        return events, events

    events, melody = extract_instruments(events, [melody_program])

    if len(melody) == 0:
        if debug:
            print("No events found for selected program, using all events")
        return events, events

    if debug:
        print(f"Extracted {len(melody)} melody events from program {melody_program}")

    return events, melody


def generate_accompaniment(midi_path: str, model_choice: str, history_length: float):
    """
    Generates accompaniment for the entire MIDI input, conditioned on the user-selected history length.
    Note: the MIDI file is parsed with mido.MidiFile before midi_to_events to avoid the 'str' .time error.
    """
    model = load_amt_model(model_choice)

    mid = MidiFile(midi_path)
    print(f"Loaded MIDI file: type {mid.type} ({'single track' if mid.type == 0 else 'multi-track'})")

    all_events, melody = auto_extract_melody(mid, debug=True)
    if len(melody) == 0:
        print("No melody detected; using all events")
        melody = all_events

    # Total length of the piece in seconds: whichever is longer, the file's reported
    # length or the time of the last event in the token stream.
    mid_time = mid.length or 0
    ops_time = ops.max_time(all_events, seconds=True)
    total_time = round(max(mid_time, ops_time))

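    # Split at history_length: the (non-melody) events from the first history_length
    # seconds form the generation prompt, while the melody from history_length onward
    # is what the model must accompany.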
    melody_history = ops.clip(all_events, 0, history_length, clip_duration=False)
    melody_future = ops.clip(melody, history_length, total_time, clip_duration=False)

    print(f"[DEBUG] Context events: {len(melody_history)}, Future melody events: {len(melody_future)}")
    print(f"Generating from {history_length:.2f}s -> {total_time:.2f}s "
          f"(duration = {total_time - history_length:.2f}s)")

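    # Anticipatory generation: `inputs` provides the past context and `controls` the
    # anticipated future melody events; the model infills accompaniment between
    # start_time and end_time using nucleus sampling (top_p=0.95).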
    accompaniment = generate(
        model,
        start_time=history_length,
        end_time=total_time,
        inputs=melody_history,
        controls=melody_future,
        top_p=0.95,
        debug=False
    )

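    # Merge the generated accompaniment with the original melody and clip the result
    # to the length of the input piece before writing it back out as MIDI.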
    output_events = ops.clip(
        ops.combine(accompaniment, melody),
        0,
        total_time,
        clip_duration=True
    )

    output_midi = "generated_accompaniment_huggingface.mid"
    mid_out = events_to_midi(output_events)
    mid_out.save(output_midi)

    return output_midi, None


def process_fn(input_midi_path, model_choice, history_length):
    """
    Returns (JSON, MIDI filepath) to satisfy the HARP client's expectation that the 0th item is an object.
    """
    output_midi, error_message = generate_accompaniment(
        input_midi_path,
        model_choice,
        float(history_length)
    )

    if error_message:
        return {"message": error_message}, None

    labels = LabelList()
    return asdict(labels), output_midi

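
# Local sanity check (hypothetical example, not executed by the app; assumes an
# "example.mid" file is available next to this script):
#   labels, midi_path = process_fn("example.mid", MEDIUM_MODEL, 5.0)
#   print(midi_path)  # -> "generated_accompaniment_huggingface.mid"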

with gr.Blocks() as demo:
    gr.Markdown("## 🎼 Anticipatory Music Transformer")

    input_midi = gr.File(
        file_types=[".mid", ".midi"],
        label="Input MIDI File",
        type="filepath",
    ).harp_required(True)

    model_dropdown = gr.Dropdown(
        choices=[SMALL_MODEL, MEDIUM_MODEL, LARGE_MODEL],
        value=MEDIUM_MODEL,
        label="Select AMT Model (Faster vs. Higher Quality)"
    )

    history_slider = gr.Slider(
        minimum=1, maximum=10, step=1, value=5,
        label="Select History Length (seconds)"
    )

    output_labels = gr.JSON(label="Labels / Metadata")
    output_midi = gr.File(
        file_types=[".mid", ".midi"],
        label="Generated MIDI Output",
        type="filepath",
    )

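    # Register the processing function and components with pyharp so HARP clients
    # can call this app as an endpoint described by the model card above.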
    _ = build_endpoint(
        model_card=model_card,
        input_components=[
            input_midi,
            model_dropdown,
            history_slider
        ],
        output_components=[
            output_labels,
            output_midi
        ],
        process_fn=process_fn
    )


demo.launch(share=True, show_error=True, debug=True)