import os
#import spaces # Enables ZeroGPU on Hugging Face
import gradio as gr
import torch
from dataclasses import asdict
from transformers import AutoModelForCausalLM
from anticipation.sample import generate
from anticipation.convert import events_to_midi, midi_to_events
from anticipation.tokenize import extract_instruments
from anticipation import ops
from mido import MidiFile # parse MIDI explicitly to avoid .time error
from pyharp.core import ModelCard, build_endpoint
from pyharp.labels import LabelList
# Model Choices
SMALL_MODEL = "stanford-crfm/music-small-800k"
MEDIUM_MODEL = "stanford-crfm/music-medium-800k"
LARGE_MODEL = "stanford-crfm/music-large-800k"
# Model Card (new pyharp)
model_card = ModelCard(
    name="Anticipatory Music Transformer",
    description=(
        "Generate musical accompaniment for your existing vamp using the Anticipatory Music Transformer. "
        "Input: a MIDI file with a short accompaniment (vamp) followed by a melody line. "
        "Output: a new MIDI file with extended accompaniment matching the melody continuation. "
        "Use the model dropdown and the history slider to choose model size and how much of the song is used as context."
    ),
    author="John Thickstun, David Hall, Chris Donahue, Percy Liang",
    tags=["midi", "generation", "accompaniment"]
)
# Model cache
_model_cache = {}
def load_amt_model(model_choice: str):
"""Loads and caches the AMT model inside the worker process (same behavior as old app)."""
if model_choice in _model_cache:
return _model_cache[model_choice]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if model_choice == LARGE_MODEL:
print(f"Loading {LARGE_MODEL} (low_cpu_mem_usage, fp16 on CUDA if available)...")
model = AutoModelForCausalLM.from_pretrained(
LARGE_MODEL,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
low_cpu_mem_usage=True
).to(device)
else:
print(f"Loading {model_choice} ...")
model = AutoModelForCausalLM.from_pretrained(model_choice).to(device)
_model_cache[model_choice] = model
return model
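# Optional warm-up (illustrative sketch, not part of the original flow): loading the default
# checkpoint at startup avoids paying the download/load cost on the first request.
# Uncomment if desired; MEDIUM_MODEL matches the UI default below.
# load_amt_model(MEDIUM_MODEL)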
def find_melody_program(mid, debug=False):
"""
Automatically detect the melody track's program number from a MIDI file.
Uses a balanced heuristic: pitch + note count + temporal coverage.
"""
track_stats = []
total_duration = 0
for i, track in enumerate(mid.tracks):
pitches, times = [], []
current_time = 0
current_program = None
track_note_count = 0
for msg in track:
if msg.type not in ("note_on", "program_change"):
continue
current_time += msg.time
if msg.type == "program_change":
current_program = msg.program
continue
# note_on event
if msg.velocity > 0:
pitches.append(msg.note)
times.append(current_time)
track_note_count += 1
# Early stop if enough notes gathered
if track_note_count >= 100:
break
# Skip empty or trivial tracks
if not pitches:
continue
# Compute duration for this track and update total_duration
track_duration = max(times) - min(times)
total_duration = max(total_duration, current_time)
mean_pitch = sum(pitches) / len(pitches)
polyphony = len(set(pitches)) / len(pitches)
coverage = track_duration / total_duration if total_duration > 0 else 0
track_stats.append((i, mean_pitch, len(pitches), current_program, polyphony, coverage))
if not track_stats:
return None, False
if len(track_stats) == 1:
prog = track_stats[0][3]
if debug:
if prog == 0:
print("Single-track MIDI detected, program 0 (Acoustic Grand Piano) will be treated as melody.")
else:
print(f"Single-track MIDI detected, using program {prog or 'None'}")
return prog, prog is not None
candidates = [t for t in track_stats if t[3] is not None and t[3] > 0]
has_valid_programs = len(candidates) > 0
if not candidates:
candidates = track_stats
if debug:
print(f"\nCandidates: {len(candidates)} tracks")
max_notes = max(t[2] for t in candidates)
max_pitch = max(t[1] for t in candidates)
min_pitch = min(t[1] for t in candidates)
pitch_span = max_pitch - min_pitch if max_pitch > min_pitch else 1
best_score = -1
best_program = None
best_track = None
best_pitch = None
for t in candidates:
idx, pitch, notes, prog, poly, coverage = t
pitch_norm = (pitch - min_pitch) / pitch_span
notes_norm = notes / max_notes
score = (pitch_norm * 0.35) + (notes_norm * 0.35) + (coverage * 0.30)
if poly < 0.15:
score *= 0.95
if 55 <= pitch <= 75:
score *= 1.1
if notes >= 30:
score *= 1.05
if coverage > 0.7:
score *= 1.15
if score > best_score:
best_score = score
best_program = prog
best_track = idx
best_pitch = pitch
return best_program, has_valid_programs
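# Illustrative only (not used by the app at runtime): the detection heuristic can be
# sanity-checked on a local file before uploading it through HARP; "example.mid" is a
# placeholder path.
#   prog, has_programs = find_melody_program(MidiFile("example.mid"), debug=True)
#   print(f"melody program: {prog}, valid program changes: {has_programs}")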
def auto_extract_melody(mid, debug=False):
"""
Extract melody events from a MIDI object (already loaded via MidiFile).
Optimized to avoid re-reading the file from disk.
Returns: (all_events, melody_events)
"""
events = midi_to_events(mid)
melody_program, has_valid_program = find_melody_program(mid, debug=debug)
if not has_valid_program or melody_program is None or melody_program == 0:
if debug:
print("No valid program changes in MIDI, using all events as melody")
return events, events
events, melody = extract_instruments(events, [melody_program])
if len(melody) == 0:
if debug:
print("No events found for selected program, using all events")
return events, events
if debug:
print(f"Extracted {len(melody)} melody events from program {melody_program}")
return events, melody
#@spaces.GPU
# Core generation
def generate_accompaniment(midi_path: str, model_choice: str, history_length: float):
"""
Generates accompaniment for the entire MIDI input, conditioned on user-selected history length.
FIX: parse MIDI with mido.MidiFile before midi_to_events to avoid 'str' .time error.
"""
model = load_amt_model(model_choice)
# Parse MIDI correctly, then convert to events
mid = MidiFile(midi_path)
print(f"Loaded MIDI file: type {mid.type} ({'single track' if mid.type == 0 else 'multi-track'})")
# Automatically detect and extract melody
all_events, melody = auto_extract_melody(mid, debug=True)
if len(melody) == 0:
print("No melody detected; using all events")
melody = all_events
## Compute total time
mid_time = mid.length or 0
ops_time = ops.max_time(all_events, seconds=True)
total_time = round(max(mid_time, ops_time))
# History portion
#history = ops.clip(all_events, 0, history_length, clip_duration=False)
melody_history = ops.clip(all_events, 0, history_length, clip_duration=False)
melody_future = ops.clip(melody, history_length, total_time, clip_duration=False)
# Generate accompaniment for the remaining duration
accompaniment = generate(
model,
start_time=history_length, # start after history
end_time=total_time, # go to end
inputs=melody_history,
controls=melody_future,
top_p=0.95,
debug=False
)
# Combine accompaniment + melody and clip
output_events = ops.clip(
ops.combine(accompaniment, melody),
0,
total_time,
clip_duration=True
)
print(f"[DEBUG] Context events: {len(melody_history)}, Future melody events: {len(melody_future)}")
print(f"Generating from {history_length:.2f}s -> {total_time:.2f}s "
f"(duration = {total_time - history_length:.2f}s)")
# Save MIDI
output_midi = "generated_accompaniment_huggingface.mid"
mid_out = events_to_midi(output_events)
mid_out.save(output_midi)
return output_midi, None
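# Illustrative local smoke test (assumptions: "input.mid" exists on disk; the small
# checkpoint keeps the run short). Not invoked by the HARP endpoint below.
#   out_path, err = generate_accompaniment("input.mid", SMALL_MODEL, history_length=5.0)
#   print(out_path, err)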
# HARP process fn (JSON FIRST)
def process_fn(input_midi_path, model_choice, history_length):
"""
Returns (JSON, MIDI filepath) to satisfy HARP client's expectation that the 0th item is an object.
"""
output_midi, error_message = generate_accompaniment(
input_midi_path,
model_choice,
float(history_length)
)
if error_message:
# JSON first, then no file
return {"message": error_message}, None
labels = LabelList() # add label entries if desired
# JSON first, then MIDI filepath
return asdict(labels), output_midi
# Gradio + HARP UI
with gr.Blocks() as demo:
gr.Markdown("## 🎼 Anticipatory Music Transformer")
# Inputs
input_midi = gr.File(
file_types=[".mid", ".midi"],
label="Input MIDI File",
type="filepath",
).harp_required(True)
model_dropdown = gr.Dropdown(
choices=[SMALL_MODEL, MEDIUM_MODEL, LARGE_MODEL],
value=MEDIUM_MODEL,
label="Select AMT Model (Faster vs. Higher Quality)"
)
history_slider = gr.Slider(
minimum=1, maximum=10, step=1, value=5,
label="Select History Length (seconds)"
)
# Outputs (JSON FIRST)
output_labels = gr.JSON(label="Labels / Metadata")
output_midi = gr.File(
file_types=[".mid", ".midi"],
label="Generated MIDI Output",
type="filepath",
)
# Build HARP endpoint (new signature)
_ = build_endpoint(
model_card=model_card,
input_components=[
input_midi,
model_dropdown,
history_slider
],
output_components=[
output_labels, # JSON FIRST
output_midi # MIDI SECOND
],
process_fn=process_fn
)
# Launch App
demo.launch(share=True, show_error=True, debug=True)