Wednesday, July 1, 2026

OmniVoice Deep Dive

This deep dive explains the OmniVoice research paper and its repo, walking through how the OmniVoice text-to-speech system is trained and run:

  1. Dataset and understanding the data
  2. Training and how it works
    • Inputs
    • Masking
    • Bidirectional Transformer
    • Predict
    • Loss function
  3. Inference
    • Duration Estimation
    • Time Step Schedule
    • Iterative Loop and Unmasking

There are two jobs this paper asks one model to do:

  • Voice design — conjure a voice from an instruction in plain text.
  • Voice cloning — reproduce a specific person's voice from a handful of samples.

Both ambitions live or die on the corpus. So that's where we start.

Dataset and Understanding the Data

The dataset was built from open-source data and totals roughly 581,000 hours of audio across many languages. Each entry can carry an instruct field describing the sample.

Only id, audio_path, and text were strictly required; the instruct metadata and the other fields came from the underlying source datasets:

  • Emilia: openxlab.org.cn/datasets/Amphion/Emilia
  • ParaSpeechCaps: github.com/ajd12342/paraspeechcaps
  • NvSpeech: github.com/open-mmlab/NvSpeech
  • NonVerbalSpeech-38K: github.com/zhourongleiden/NonVerbalSpeech-38K

{"id": "EN_B00000_S00001_W000001", "text": "The quick brown fox jumps over the lazy dog.", "text_pinyin": null, "language_id": "en", "instruct": "A young female speaker with a clear American accent, medium pitch, moderate speaking rate", "audio_duration": 4.21, "clean_start_token_idx": null}

The Corpus Is Deeply Uneven

English alone brings 206,061 hours. Mandarin adds another 111,343. By the time you reach Swedish at the tail of the top 20, you're down to 2,453 — and the other 626 languages sit below even that. Train naively on this and the model simply learns to be fluent in English and mumble everything else.

Flattening the Curve With Upsampling

High-resource languages drown out everyone else. The goal isn't to make the distribution uniform — that throws away genuine signal — but to let low-resource languages be seen often enough to matter.

The paper uses a smoothed upsampling trick: compute a repetition factor for each language and repeat its data that many times per epoch.

$$ r_i = \max\left(1,\ \operatorname{round}\left[\left(\frac{D_{\max}}{D_i}\right)^{1-\beta}\right]\right) $$

  • $r_i$: repetition factor for language i, i.e. how many times its data repeats
  • $D_i$: total audio hours for language i
  • $D_{\max}$: largest corpus, English at 206,061 h
  • $\beta \in [0, 1]$: smoothing knob, β = 0.8 in practice

How it behaves. For the biggest language, $D_{\max}/D_i = 1$, so $r_i = 1$: no repetition. For smaller languages the ratio climbs, and raising it to the power $(1 - \beta)$ dampens how aggressively it climbs.

For example, with β = 0.8 and English as the biggest language (206,061h):

  • English (206,061h): ratio = 1, so r = 1, seen once.
  • Swahili (418h): ratio = 493, raised to 0.2 → r ≈ 3, repeated 3×.
  • Afrikaans (4.4h): ratio = 46,832, raised to 0.2 → r ≈ 9, repeated 9×.

Code — omnivoice/data/dataset.py
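The repository's implementation isn't reproduced here. As a minimal sketch of the formula itself (the function names below are illustrative, not the repo's API), the repetition factors for the three example languages, plus a naive way to repeat each language's entries within an epoch, look like this:

```python
def repetition_factor(hours: float, max_hours: float, beta: float = 0.8) -> int:
    """r_i = max(1, round((D_max / D_i) ** (1 - beta)))."""
    ratio = max_hours / hours
    return max(1, round(ratio ** (1.0 - beta)))


def repetition_factors(hours_by_lang: dict, beta: float = 0.8) -> dict:
    """Compute r_i for every language, using the largest corpus as D_max."""
    max_hours = max(hours_by_lang.values())
    return {lang: repetition_factor(h, max_hours, beta) for lang, h in hours_by_lang.items()}


def upsample(entries_by_lang: dict, factors: dict) -> list:
    """Hypothetical helper: repeat each language's entries r_i times for one epoch."""
    epoch = []
    for lang, entries in entries_by_lang.items():
        epoch.extend(entries * factors[lang])
    return epoch


hours = {"English": 206_061, "Swahili": 418, "Afrikaans": 4.4}
print(repetition_factors(hours))  # {'English': 1, 'Swahili': 3, 'Afrikaans': 9}
```

Note how the exponent 0.2 keeps the factors modest: Afrikaans has roughly 47,000× less data than English but is repeated only 9×.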

Training and How It Works

Inputs

There are two streams of input for the voice to be cloned:

  • Acoustic token matrix [T × C]
  • Text tokens — instruction and transcript

Acoustic token matrix [T × C]

The acoustic token matrix carries the audio itself: the reference prompt of the voice to be cloned plus the target to be generated. The audio tokenizer is HiggsAudioV2 (eustlb/higgs-audio-v2-tokenizer). It produces C = 8 codebook layers, each with 1024 codes; OmniVoice adds one mask token, giving a vocabulary of 1025 per layer. T is the number of time steps.

File: omnivoice/scripts/extract_audio_tokens.py:228–247

# HiggsAudioV2 tokenizer encodes raw audio into 8 codebook layers
audio_tokens = worker_tokenizer.encode(
    inputs["input_values"],
).audio_codes.squeeze(0)   # Shape: [8, T]

assert audio_tokens.size(0) == 8   # 8 codebook channels

# Each of the 8 codebooks has a vocab of 1025 (IDs 0-1023 for audio codes,
# 1024 = mask token). The result is stored as int16 .npy arrays in
# WebDataset tar shards.

Text Tokens

The text tokens represent something like "male, in 30s, high pitch" along with the transcript to be synthesized, e.g. "Hi, I am Vaibhav, thanks for reading my blog."

Text uses a standard HuggingFace tokenizer (Qwen-based).

File: omnivoice/data/processor.py:109–127

# Style/instruction prefix
style = "<|lang_start|>EN<|lang_end|><|instruct_start|>None<|instruct_end|>"
style_inputs = self.text_tokenizer(
    style, return_tensors="pt"
).input_ids.repeat(num_channels, 1)
# Shape: [8, N1] - same text token IDs duplicated across all 8 channels

# Transcript text
text_inputs = self.text_tokenizer(
    f"<|text_start|>{text}<|text_end|>", return_tensors="pt"
).input_ids.repeat(num_channels, 1)
# Shape: [8, N2]

The input sequence is laid out as:

[style tokens | text tokens | audio tokens]

Style and text token IDs are each repeated across all 8 channels (.repeat(num_channels, 1)), so every segment has 8 rows and lines up with the [8, T] audio matrix.

File: omnivoice/data/processor.py:151–163

# Concatenate along the sequence dimension
input_ids = torch.cat([style_inputs, text_inputs, audio_inputs], dim=1)
# Shape: [C=8, style_len + text_len + audio_len]

# audio_mask marks which positions are audio vs text
audio_mask = torch.zeros(total_length, dtype=torch.bool)
audio_mask[audio_start_idx:] = True

The model then turns this [8, L] ID matrix into a single embedding sequence, since the transformer expects one unified embedding input (see the _prepare_embed_inputs function).
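To make the layout concrete, here is a toy, pure-Python sketch with made-up token IDs (the real code works on torch tensors produced by the Qwen tokenizer, and build_inputs is a hypothetical name). It replicates the one-row style and text IDs across all 8 channels, concatenates along the sequence axis, and builds the boolean audio_mask:

```python
NUM_CHANNELS = 8


def build_inputs(style_ids, text_ids, audio_tokens):
    """Toy version of the [style | text | audio] layout.

    style_ids, text_ids: 1-D lists of text token IDs.
    audio_tokens: NUM_CHANNELS rows of audio codes, i.e. an [8, T] matrix.
    Returns (input_ids as [8, L] nested lists, audio_mask as a length-L list of bools).
    """
    assert len(audio_tokens) == NUM_CHANNELS
    # Text IDs are identical on every channel, mirroring .repeat(num_channels, 1)
    input_ids = [list(style_ids) + list(text_ids) + list(row) for row in audio_tokens]
    total_length = len(input_ids[0])
    audio_start_idx = len(style_ids) + len(text_ids)
    audio_mask = [pos >= audio_start_idx for pos in range(total_length)]
    return input_ids, audio_mask


style = [101, 102, 103]  # made-up IDs for the style prefix
text = [201, 202]        # made-up IDs for the transcript
audio = [[c * 10 + t for t in range(4)] for c in range(NUM_CHANNELS)]  # [8, T=4]

ids, mask = build_inputs(style, text, audio)
print(len(ids), len(ids[0]))  # 8 9
print(mask)                   # [False, False, False, False, False, True, True, True, True]
```

Only the audio columns differ between channels; the text columns carry the same ID on all 8 rows.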

The audio matrix [8, T] (the audio tokens segment of the sequence above) is split into two parts:

[ prompt (unmasked) | target (masked with prob p) ]
←prompt_length→ ←──maskable_region──→

Masking

Every cell in the target gets independently hidden with probability p. Random masking is chosen over deterministic masking, as it improved overall performance.

During training, mask_ratio is simply random.uniform(0.0, 1.0) per sample.

Config (omnivoice/training/config.py:48):
mask_ratio_range: Tuple[float, float] = field(default_factory=lambda: (0.0, 1.0))

Sampling per sample (omnivoice/data/processor.py:90,143):
mask_ratio = random.uniform(*self.mask_ratio_range)          # line 90: uniform from (0.0, 1.0)
token_mask = torch.rand(maskable_region.shape) < mask_ratio  # line 143: each cell masked independently

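Because the ratio itself is drawn uniformly from [0, 1], the expected fraction of masked target cells, averaged over samples, is 0.5; any single sample can have anywhere from almost none to almost all of its target masked.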
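A toy, pure-Python sketch of this per-sample masking (mask_sample and the fixed seed are illustrative; the repo uses torch.rand on the maskable region). Prompt columns are left untouched, each target cell is masked with the sample's ratio, masked cells get the mask ID 1024, and labels keep the true token only where a cell was masked. Averaging the masked fraction over many samples lands near 0.5:

```python
import random

MASK_ID = 1024
IGNORE = -100


def mask_sample(tokens, prompt_length, rng, mask_ratio_range=(0.0, 1.0)):
    """tokens: [C][T] audio codes. Returns (inputs, labels, masked_fraction)."""
    mask_ratio = rng.uniform(*mask_ratio_range)        # one ratio per sample
    inputs = [list(row) for row in tokens]
    labels = [[IGNORE] * len(row) for row in tokens]  # prompt and unmasked cells: ignored
    masked = total = 0
    for c, row in enumerate(tokens):
        for t in range(prompt_length, len(row)):      # maskable (target) region only
            total += 1
            if rng.random() < mask_ratio:             # each cell independently
                inputs[c][t] = MASK_ID
                labels[c][t] = row[t]                 # loss is computed here
                masked += 1
    return inputs, labels, masked / max(total, 1)


rng = random.Random(0)
toy = [[(c * 7 + t) % 1024 for t in range(60)] for c in range(8)]  # [8, T=60]
fractions = [mask_sample(toy, prompt_length=10, rng=rng)[2] for _ in range(2000)]
print(round(sum(fractions) / len(fractions), 2))  # close to 0.5
```

Individual fractions spread across the whole [0, 1] range; only their average sits near 0.5.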

During inference, the model progressively reveals tokens — unmasking few at first (when predictions are uncertain) and more later (when the model has more context to be confident).

# _get_time_steps with t_shift < 1 (default 0.1) produces a convex curve:
# Early steps: small intervals → unmask few tokens (conservative)
# Later steps: large intervals → unmask many tokens (confident)
timesteps = t_shift * timesteps / (1 + (t_shift - 1) * timesteps)

Bidirectional Transformer and LLM Initialization

Once the unified embedding sequence exists, with randomly masked audio tokens, it is fed to a transformer that is not trained from random weights: OmniVoice initializes it from Qwen3-0.6B. This gives the model the linguistic knowledge the LLM has already learned and avoids training from scratch.

File: omnivoice/training/config.py:37

llm_name_or_path: str = "Qwen/Qwen3-0.6B"

File: omnivoice/training/builder.py:95–116

# Load Qwen's pretrained weights
resolved_llm = _resolve_model_path(config.llm_name_or_path)

llm_config = AutoConfig.from_pretrained(resolved_llm)
# Wrap it inside OmniVoice (adds audio_embeddings + audio_heads)

ov_config = OmniVoiceConfig(
    audio_vocab_size=1025, # 1024 + 1
    num_audio_codebook=8,
    llm_config=llm_config,
)
llm = AutoModel.from_pretrained(resolved_llm, attn_implementation=config.attn_implementation)
model = OmniVoice(config=ov_config, llm=llm)

# Resize text embeddings for new special tokens (<|text_start|>, etc.)
model.llm.resize_token_embeddings(len(tokenizer))

So OmniVoice takes Qwen's full transformer (attention layers, FFN, layer norms) with pretrained weights intact, then wraps it with new audio-specific layers (audio_embeddings and audio_heads).

Making It Bidirectional

Standard LLMs like Qwen are autoregressive (AR) — each token can only attend to previous tokens. Masked audio prediction, however, needs bidirectional attention (non-autoregressive, NAR): to predict a masked token at position 50, the model should see context from both position 30 and position 70.

attention_mask = valid[:, None, None, :].expand(B, 1, max_len, max_len).contiguous()

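Here valid marks the real (non-padding) positions of each sequence. Broadcasting it to [B, 1, Seq, Seq] lets every query position attend to every valid key position, earlier or later in the sequence, because there is no causal (lower-triangular) restriction. That is what makes the attention bidirectional.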

llm_outputs = self.llm(
    inputs_embeds=inputs_embeds,
    attention_mask=attention_mask,  # the 4D [B, 1, Seq, Seq] mask goes here
    return_dict=True,
    position_ids=position_ids,
)

Instead of input_ids, the model is called with inputs_embeds, since the embeddings are built outside the LLM and passed in directly. The LLM's own token-embedding lookup is bypassed; its transformer stack (all attention layers, FFN layers, and layer norms) serves as the backbone for both training and inference.
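The difference from a causal mask is easiest to see on a tiny example. This pure-Python sketch (function names are illustrative) builds both masks for one sequence of length 4 whose last position is padding; True means "this query row may attend to this key column":

```python
def bidirectional_mask(valid):
    """valid: [B][L] bools. Returns [B][1][L][L], like valid[:, None, None, :].expand(...)."""
    return [[[list(row) for _ in row]] for row in valid]


def causal_mask(valid):
    """Same shape, but query q may only see keys k <= q (standard decoder-only LLM)."""
    return [[[[row[k] and k <= q for k in range(len(row))] for q in range(len(row))]]
            for row in valid]


valid = [[True, True, True, False]]  # batch of 1, last position is padding
for name, m in (("bidirectional", bidirectional_mask(valid)), ("causal", causal_mask(valid))):
    print(name)
    for query_row in m[0][0]:
        print("".join("1" if x else "." for x in query_row))
```

Every row of the bidirectional mask reads 111., so position 0 can already use positions 1 and 2, while the causal mask is lower-triangular.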

Predict

After the bidirectional transformer processes the full sequence, every position gets a hidden state vector of size [Hidden]. The model then needs to predict what audio token should sit at each masked position, across all 8 codebook layers.

File: omnivoice/models/omnivoice.py:225–228

self.audio_heads = nn.Linear(
    hidden_size,                              # e.g. 1024
    num_audio_codebook * audio_vocab_size,    # 8 * 1025 = 8200
    bias=False,
)

A single linear layer projects each hidden state into 8200 logits — 1025 possible token IDs for each of the 8 codebook layers, predicted simultaneously.

File: omnivoice/models/omnivoice.py:423–432

# hidden_states from LLM: [B, Seq, Hidden]
logits_flat = self.audio_heads(hidden_states)
# Shape: [B, Seq, 8200]

# Reshape into per-layer predictions
audio_logits = logits_flat.view(B, Seq, 8, 1025).permute(0, 2, 1, 3)
# Final shape: [B, 8, Seq, 1025]

So for every sequence position, the model outputs 8 independent probability distributions (one per codebook layer), each over 1025 possible token values. At masked positions, these are the model's predictions of the original audio tokens.
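The reshape is just index arithmetic: with the view(B, Seq, 8, 1025) above, logit j of a position's 8200-vector belongs to codebook j // 1025 and token j % 1025. A pure-Python sketch (hypothetical helper names) that splits one flat logit vector into 8 distributions and takes the most likely token per codebook:

```python
import math

NUM_CODEBOOKS, VOCAB = 8, 1025


def split_logits(flat):
    """flat: 8 * 1025 logits for one position -> 8 lists of 1025 logits."""
    assert len(flat) == NUM_CODEBOOKS * VOCAB
    return [flat[c * VOCAB:(c + 1) * VOCAB] for c in range(NUM_CODEBOOKS)]


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]


def argmax_per_codebook(flat):
    return [max(range(VOCAB), key=row.__getitem__) for row in split_logits(flat)]


flat = [0.0] * (NUM_CODEBOOKS * VOCAB)
flat[3 * VOCAB + 42] = 5.0           # make token 42 the favourite of codebook 3
print(argmax_per_codebook(flat)[3])  # 42
```

Each codebook gets its own softmax, so the 8 predictions at a position are independent distributions rather than one joint distribution over 8200 outcomes.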

Loss Function


The model is only trained to predict masked tokens. Text positions, style positions, prompt (unmasked) audio, and unmasked target tokens all carry label −100, meaning "ignore, don't compute loss here."

Because the loss is a negative log-likelihood, confident mistakes are punished hard: if the model assigns probability p to the correct token it pays −log p, about 0.1 for p = 0.9 but about 4.6 for p = 0.01. The objective over the masked positions is:
$$ \mathcal{L} = -\sum_{(t,c) \in \mathcal{M}} \log P(x_{t,c} \mid X, Y; \theta) \tag{1} $$

where $\mathcal{M}$ denotes the set of indices $(t, c)$ corresponding to masked positions within the target segment, with $t \in \{T_p + 1, \dots, T\}$, $T_p$ being the prompt length, and $c \in \{1, \dots, C\}$. Here, $x_{t,c}$ is the ground-truth acoustic token at time step $t$ and codebook index $c$, $X$ and $Y$ are the acoustic and text inputs, and $P(x_{t,c} \mid X, Y; \theta)$ is the probability that the model, parameterized by $\theta$, assigns to that token.

File: omnivoice/data/processor.py:145–149

# Only masked tokens get real labels
audio_labels[:, prompt_length:][~token_mask] = -100   # unmasked target → ignore
audio_labels[:, :prompt_length] = -100                 # prompt → ignore
# style_labels and text_labels are already all -100

File: omnivoice/models/omnivoice.py:439–456

import torch
import torch.nn.functional as F

# Step 1: Cross-entropy per token, per layer - ignore -100 positions
per_token_loss = F.cross_entropy(
    audio_logits.permute(0, 3, 1, 2),   # [B, Vocab, 8, Seq]
    labels,                             # [B, 8, Seq] - ground truth token IDs
    reduction="none",
    ignore_index=-100,
)
# Shape: [B, 8, Seq] - loss at every position (0 where label was -100)

# Step 2: Average loss per codebook layer
valid_mask = (labels != -100).float()
layer_means = (per_token_loss * valid_mask).sum(dim=(0, 2)) / valid_mask.sum(dim=(0, 2)).clamp(min=1.0)
# Shape: [8] - one mean loss value per codebook layer

# Step 3: Weighted sum across layers
weights = torch.tensor([8, 8, 6, 6, 4, 4, 2, 2], dtype=layer_means.dtype, device=layer_means.device)
weights = weights / weights.sum()   # normalize to sum=1
loss = (layer_means * weights).sum()

The model predicts 1025 probabilities per codebook layer per position, and cross-entropy measures how wrong each prediction is. The error is averaged within each layer, and the layer means are combined with weights [8, 8, 6, 6, 4, 4, 2, 2] (normalized to sum to 1) into a single loss value that gets minimized. Mistakes on the first layers (coarse audio structure) therefore count 4× more than mistakes on the last layers (fine details). This departs slightly from Eq. (1), which sums over all masked positions with equal weight.
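A pure-Python sketch of the same three steps on toy data (it mirrors the snippet above rather than calling torch; weighted_layer_loss is a hypothetical name). Positions labelled -100 contribute nothing, each layer is averaged over its own valid positions, and the layer means are combined with the normalized weights:

```python
import math

IGNORE = -100
LAYER_WEIGHTS = [8, 8, 6, 6, 4, 4, 2, 2]


def cross_entropy(logits, target):
    """-log softmax(logits)[target], computed stably via log-sum-exp."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]


def weighted_layer_loss(logits, labels, layer_weights=LAYER_WEIGHTS):
    """logits: [C][T][V]; labels: [C][T] with IGNORE where no loss is taken."""
    total_w = sum(layer_weights)
    loss = 0.0
    for c, (layer_logits, layer_labels) in enumerate(zip(logits, labels)):
        losses = [cross_entropy(lg, y) for lg, y in zip(layer_logits, layer_labels) if y != IGNORE]
        layer_mean = sum(losses) / max(len(losses), 1)  # like .clamp(min=1.0)
        loss += layer_mean * layer_weights[c] / total_w
    return loss


C, T, V = 8, 2, 4
logits = [[[0.0] * V for _ in range(T)] for _ in range(C)]  # uniform predictions
labels = [[c % V, IGNORE] for c in range(C)]                 # second frame ignored
print(weighted_layer_loss(logits, labels))  # ≈ 1.386, i.e. log(4)
```

With uniform predictions every valid position costs log 4, and since the weights sum to 1 the total is log 4 too. Putting the only valid label in layer 0 instead of layer 7 makes the loss exactly 4× larger.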

Inference

Inference is a 32-step iterative unmasking process. The target starts 100% masked. At each step, the model predicts all positions, picks the most confident ones, reveals them, and repeats — gradually filling in the audio.

Duration Estimation

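$D_{\text{prompt}}$ is the duration of the reference prompt. W is the "speaking weight" of a piece of text: $W_{\text{prompt}}$ is the total speaking weight of the reference text, and $W_{\text{target}}$ is the total speaking weight of the text you want to synthesize. The target duration scales the prompt duration by their ratio: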

$$ D_{\text{target}} = D_{\text{prompt}} \times \frac{W_{\text{target}}}{W_{\text{prompt}}} $$

File: omnivoice/utils/duration.py:208–249

def estimate_duration(self, target_text, ref_text, ref_duration):
    ref_weight = self.calculate_total_weight(ref_text)     # Wprompt
    speed_factor = ref_weight / ref_duration               # speaking weight per second of reference audio
    target_weight = self.calculate_total_weight(target_text)  # Wtarget
    estimated_duration = target_weight / speed_factor       # = ref_duration × (Wtarget/Wprompt)
    return estimated_duration

Character weights are looked up via Unicode ranges (binary search):

self.weights = {
    "cjk": 3.0,        # Chinese character ≈ 3× a Latin letter
    "latin": 1.0,      # baseline
    "indic": 1.8,      # Hindi, Tamil, etc.
    "space": 0.2,      # tiny pause
    "digit": 3.5,      # "7" spoken as "seven"
    "punctuation": 0.5, # brief pause
    ...
}
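A self-contained sketch of the whole estimate, assuming a simplified category table: the Unicode ranges below are hypothetical stand-ins for the repo's real table, and the "other" fallback weight of 1.0 is an assumption. The remaining weights come from the dictionary above, and bisect provides the binary search over sorted range starts:

```python
import bisect

# Weights from the table above; "other" is an assumed fallback for unlisted scripts.
WEIGHTS = {"cjk": 3.0, "latin": 1.0, "indic": 1.8, "space": 0.2,
           "digit": 3.5, "punctuation": 0.5, "other": 1.0}

# Hypothetical, simplified Unicode ranges (start, end inclusive, category), sorted by start.
RANGES = [
    (0x0030, 0x0039, "digit"),
    (0x0041, 0x005A, "latin"),
    (0x0061, 0x007A, "latin"),
    (0x00C0, 0x024F, "latin"),
    (0x0900, 0x0DFF, "indic"),   # Devanagari through Sinhala blocks
    (0x4E00, 0x9FFF, "cjk"),     # CJK Unified Ideographs
]
STARTS = [start for start, _, _ in RANGES]


def char_category(ch: str) -> str:
    if ch.isspace():
        return "space"
    cp = ord(ch)
    i = bisect.bisect_right(STARTS, cp) - 1  # binary search for the candidate range
    if i >= 0 and cp <= RANGES[i][1]:
        return RANGES[i][2]
    return "other" if ch.isalnum() else "punctuation"


def total_weight(text: str) -> float:
    return sum(WEIGHTS[char_category(ch)] for ch in text)


def estimate_duration(target_text: str, ref_text: str, ref_duration: float) -> float:
    speed = total_weight(ref_text) / ref_duration  # speaking weight per second
    return total_weight(target_text) / speed       # = ref_duration * W_target / W_prompt


print(estimate_duration("hello world hello world", "hello world", 2.0))  # ≈ 4.04
```

Doubling the text roughly doubles the duration; it is not exactly 4.0 because the target has three spaces to the reference's one.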

Time Step Schedule

The formula warps a linear 0→1 schedule. With the default t_shift = 0.1 (any value below 1), early intervals are small, so only a few tokens are unmasked while the model has little to go on beyond the prompt, and later intervals are large, so many tokens are unmasked once most of the context has been filled in. A t_shift above 1 would flip this.

File: omnivoice/models/omnivoice.py:1509–1518

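import torch

def _get_time_steps(t_start=0.0, t_end=1.0, num_step=32, t_shift=0.1):
    # num_step + 1 points bound num_step intervals (33 points for 32 steps)
    timesteps = torch.linspace(t_start, t_end, num_step + 1)
    # Warp: t_shift < 1 gives small early intervals and large late ones
    timesteps = t_shift * timesteps / (1 + (t_shift - 1) * timesteps)
    return timesteps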
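Plugging numbers into the warp shows the shape. The sketch below reproduces it in pure Python and then turns it into per-step token counts by differencing rounded cumulative targets. That conversion (tokens_per_step) is a hypothetical illustration; the repo's exact computation of schedules is not shown here:

```python
def time_steps(t_start=0.0, t_end=1.0, num_step=32, t_shift=0.1):
    """num_step + 1 warped points, same formula as _get_time_steps."""
    steps = [t_start + (t_end - t_start) * j / num_step for j in range(num_step + 1)]
    return [t_shift * t / (1 + (t_shift - 1) * t) for t in steps]


def tokens_per_step(total_tokens, num_step=32, t_shift=0.1):
    """Hypothetical: k_j = round(N * t_{j+1}) - round(N * t_j), so the k_j sum to N."""
    ts = time_steps(num_step=num_step, t_shift=t_shift)
    cumulative = [round(total_tokens * t) for t in ts]
    return [b - a for a, b in zip(cumulative, cumulative[1:])]


ks = tokens_per_step(total_tokens=8 * 100)  # 8 codebooks x 100 frames
print(ks[0], ks[-1], sum(ks))               # 3 195 800
```

With t_shift = 0.1 the first step reveals only 3 of 800 cells and the last step reveals 195, matching the "few first, many later" behaviour.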

Iterative Loop and Unmasking

The target audio starts 100% masked (all tokens = 1024). Over 32 steps, the model progressively reveals tokens. At each step:

  1. The model sees the current state (some revealed, some masked) and predicts all positions.
  2. The model runs two predictions in parallel: conditioned c_logits (sees style + text + audio → "generate audio that matches this text") and unconditioned u_logits (sees only audio, no text/style → "generate generic audio"). They are combined as:
    log_probs = c_log_probs + 2.0 * (c_log_probs - u_log_probs)
    This is classifier-free guidance: it pushes the prediction away from the unconditioned one, amplifying whatever the text and style conditioning contribute.
  3. A confidence score is computed for each position, biased toward lower codebook layers. The lower layers carry the coarse structure of the audio, so they are revealed before the finer-detail upper layers.
  4. The top-k most confident, still-masked positions are revealed. After 32 steps, every position is filled and the [8, T] token matrix is complete.
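Steps 2 to 4 can be condensed into a small pure-Python sketch of one unmasking step (unmask_step is a hypothetical name). It is a simplification: the guidance scale 2.0 and the layer penalty of 5 per layer come from the text and code comments, the toy vocabulary has only 5 entries with the last one playing the role of the mask token (ID 1024 in the real model), and the optional Gumbel position noise is omitted:

```python
import math


def log_softmax(xs):
    m = max(xs)
    log_z = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - log_z for x in xs]


def unmask_step(tokens, c_logits, u_logits, k, mask_id,
                guidance_scale=2.0, layer_penalty=5.0):
    """tokens: [C][T] current tokens; c_logits/u_logits: [C][T][V]. Reveals k cells in place."""
    candidates = []  # (score, codebook, frame, predicted token)
    for c in range(len(tokens)):
        for t in range(len(tokens[c])):
            if tokens[c][t] != mask_id:
                continue  # already revealed: never picked again
            c_lp, u_lp = log_softmax(c_logits[c][t]), log_softmax(u_logits[c][t])
            # Classifier-free guidance, then renormalize
            lp = log_softmax([a + guidance_scale * (a - b) for a, b in zip(c_lp, u_lp)])
            lp[mask_id] = -math.inf  # never predict the mask token
            best = max(range(len(lp)), key=lp.__getitem__)
            score = lp[best] - c * layer_penalty  # lower layers get higher scores
            candidates.append((score, c, t, best))
    candidates.sort(reverse=True)
    for _, c, t, best in candidates[:k]:
        tokens[c][t] = best
    return tokens


MASK = 4  # toy vocab of 5; index 4 plays the mask token
c_logits = [[[2.0, 0, 0, 0, 0], [0, 3.0, 0, 0, 0]],
            [[0, 0, 4.0, 0, 0], [0, 0, 0, 1.0, 0]]]
u_logits = [[[0.0] * 5] * 2] * 2  # unconditioned branch: flat, uninformative
tokens = [[MASK, MASK], [MASK, MASK]]  # C=2 codebooks, T=2 frames, fully masked
print(unmask_step(tokens, c_logits, u_logits, k=2, mask_id=MASK))  # [[0, 1], [4, 4]]
```

Layer 1's cell with logit 4.0 is the most confident prediction overall, yet the layer penalty makes both layer-0 cells win this step; layer 1 is filled in on the next call.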

File: omnivoice/models/omnivoice.py

# ─── The main loop (line 1254) ────────────────────────────────────────────────
for step in range(gen_config.num_step):
    # A. Forward pass: run the model on the current state (mix of revealed + masked tokens).
    #    The batch has 2×B items: the first B conditioned (with text), the next B unconditioned.
    batch_logits = self(
        input_ids=batch_input_ids,
        audio_mask=batch_audio_mask,
        attention_mask=batch_attention_mask,
    ).logits.to(torch.float32)                                   # output: [2*B, 8, Seq, 1025]

    for i in range(B):
        # k = how many tokens to unmask this step (from the time-step schedule)
        k = schedules[i][step]
        if k <= 0:
            continue

        c_len, t_len = c_lens[i], task.target_lens[i]

        # Extract logits for the target region only
        c_logits = batch_logits[i : i + 1, :, c_len - t_len : c_len, :]      # conditioned logits [1, 8, T, 1025]
        u_logits = batch_logits[B + i : B + i + 1, :, :t_len, :]             # unconditioned logits [1, 8, T, 1025]

        # ─── B + C. Classifier-free guidance + token assignment (line 1299-1322) ──
        pred_tokens, scores = self._predict_tokens_with_scoring(
            c_logits, u_logits, gen_config
        )
        # Inside _predict_tokens_with_scoring:
        #   c_log_probs = F.log_softmax(c_logits, dim=-1)
        #   u_log_probs = F.log_softmax(u_logits, dim=-1)
        #   log_probs = torch.log_softmax(                       # guidance: amplify conditioned signal
        #       c_log_probs + guidance_scale * (c_log_probs - u_log_probs), dim=-1
        #   )
        #   log_probs[..., self.config.audio_mask_id] = -inf     # never predict mask token
        #   pred_tokens = log_probs.argmax(dim=-1)               # greedy token choice
        #   confidence_scores = log_probs.max(dim=-1)[0]         # how confident per position

        # ─── D. Layer penalty (line 1277) ─────────────────────────────────────────
        scores = scores - (layer_ids * gen_config.layer_penalty_factor)
        # layer 0: -0, layer 1: -5, layer 2: -10, ..., layer 7: -35
        # Lower layers (coarse structure) get higher effective scores → unmasked first

        # ─── E. Position selection (lines 1279-1287) ──────────────────────────────
        if gen_config.position_temperature > 0.0:
            scores = _gumbel_sample(scores, gen_config.position_temperature)  # add randomness to avoid local optima

        sample_tokens = tokens[i : i + 1, :, :t_len]
        scores.masked_fill_(                                     # already-unmasked positions → -inf (can't pick again)
            sample_tokens != self.config.audio_mask_id, -float("inf")
        )

        _, topk_idx = torch.topk(scores.flatten(), k)            # pick top-k most confident masked positions

        # ─── F. Unmask (lines 1288-1295) ──────────────────────────────────────────
        flat_tokens = sample_tokens.flatten()
        flat_tokens[topk_idx] = pred_tokens.flatten()[topk_idx]  # place predicted tokens at selected positions
        sample_tokens.copy_(flat_tokens.view_as(sample_tokens))

        # Update the batch so next step's forward pass sees newly revealed tokens
        tokens[i : i + 1, :, :t_len] = sample_tokens
        batch_input_ids[i : i + 1, :, c_len - t_len : c_len] = sample_tokens
        batch_input_ids[B + i : B + i + 1, :, :t_len] = sample_tokens

return [tokens[i, :, : task.target_lens[i]] for i in range(B)]   # final [8, T] per item