Overview
The AudioSealDetector class detects the presence of AudioSeal watermarks in audio signals and decodes embedded secret messages. It returns detection probabilities and binary messages.
Initialization
from audioseal import AudioSeal
# Load a pre-trained detector
detector = AudioSeal.load_detector("audioseal_detector_16bits")
Typically, you’ll load a detector with AudioSeal.load_detector() rather than instantiating AudioSealDetector directly.
Methods
detect_watermark
Detect watermark presence and decode the embedded message in one convenient call.
import torchaudio
# Load audio
audio, sr = torchaudio.load("audio.wav")
# Detect watermark
detection_prob, message = detector.detect_watermark(audio)
print(f"Detection probability: {detection_prob.item():.2%}")
print(f"Decoded message: {message}")
Parameters
Input audio tensor of shape (batch, channels, samples) or (batch, samples), passed as the first argument. The audio should be at the model’s expected sample rate (typically 16 kHz).
Sample rate of the input audio. This parameter is deprecated and will be ignored in AudioSeal 0.2+.
Threshold (message_threshold) for converting message probabilities to binary values. Probabilities above this threshold are set to 1; all others are set to 0.
Threshold (detection_threshold) for frame-level watermark detection. A frame counts as watermarked when its watermark probability exceeds this threshold, and the overall detection probability is the proportion of such frames.
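The input requirements above can be met with a small preprocessing step. The following is a minimal sketch, not part of AudioSeal: prepare_audio is a hypothetical helper, and both the downmix to mono and the 16 kHz target are assumptions based on the typical sample rate mentioned above.

```python
import torch


def prepare_audio(waveform: torch.Tensor, sample_rate: int, target_rate: int = 16_000) -> torch.Tensor:
    """Return a (1, 1, samples) tensor at target_rate (hypothetical helper)."""
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)  # (samples,) -> (1, samples)
    # Assumption: the detector is fed mono audio, so average the channels.
    mono = waveform.mean(dim=0, keepdim=True)
    if sample_rate != target_rate:
        # Imported lazily: torchaudio is only needed when resampling.
        import torchaudio.functional as AF
        mono = AF.resample(mono, orig_freq=sample_rate, new_freq=target_rate)
    return mono.unsqueeze(0)  # add the batch dimension


stereo = torch.randn(2, 16_000)  # one second of stereo audio at 16 kHz
batch = prepare_audio(stereo, sample_rate=16_000)
print(batch.shape)  # torch.Size([1, 1, 16000])
```

The resulting tensor has the (batch, channels, samples) layout and can be passed to detector.detect_watermark(batch).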
Returns
Detection probability tensor of shape (batch,). Values range from 0.0 to 1.0, indicating the proportion of frames detected as watermarked.
Binary message tensor of shape (batch, nbits). Each value is 0 or 1. If the audio is not watermarked, the message will be essentially random.
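To make the two thresholds and the return values concrete, the sketch below reproduces the described post-processing on stand-in tensors. It illustrates the description only and is not the library's code: summarize_detection and its inputs are hypothetical.

```python
import torch


def summarize_detection(
    frame_probs: torch.Tensor,    # (batch, frames): per-frame watermark probability
    message_probs: torch.Tensor,  # (batch, nbits): per-bit probability of being 1
    detection_threshold: float = 0.5,
    message_threshold: float = 0.5,
):
    # Proportion of frames whose probability exceeds the detection threshold.
    detection_prob = (frame_probs > detection_threshold).float().mean(dim=-1)
    # Bits above the message threshold become 1, all others 0.
    bits = (message_probs > message_threshold).int()
    return detection_prob, bits


frame_probs = torch.tensor([[0.9, 0.8, 0.2, 0.7], [0.1, 0.3, 0.6, 0.2]])
message_probs = torch.tensor([[0.95, 0.10, 0.55], [0.40, 0.51, 0.49]])
detection_prob, bits = summarize_detection(frame_probs, message_probs)
print(detection_prob.tolist())  # [0.75, 0.25]
print(bits.tolist())            # [[1, 0, 1], [0, 1, 0]]
```

Lowering detection_threshold makes more frames count as watermarked, which raises the reported proportion for the same audio.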
Example
import torch
import torchaudio
from audioseal import AudioSeal
# Load detector and audio
detector = AudioSeal.load_detector("audioseal_detector_16bits")
audio, sr = torchaudio.load("watermarked.wav")
# Detect with default thresholds
prob, message = detector.detect_watermark(audio)
# Assumes a single-item batch, so prob holds one value
if prob > 0.5:
    print(f"Watermark detected ({prob.item():.1%} of frames flagged)")
    print(f"Message: {message.tolist()}")
else:
    print("No watermark detected")
# Detect with custom thresholds
prob, message = detector.detect_watermark(
    audio,
    message_threshold=0.6,
    detection_threshold=0.4,
)
decode_message
Decode the message from raw detector output.
# Get the raw network output, shape (batch, 2 + nbits, frames)
raw_result = detector.detector(audio)
# Skip the first 2 channels (detection) and decode the message channels
message = detector.decode_message(raw_result[:, 2:, :])
Parameters
Raw watermark result tensor of shape (batch, nbits, frames) from the detector output.
Returns
Decoded message probabilities of shape (batch, nbits). Values range from 0.0 to 1.0 after sigmoid activation, indicating the probability of each bit being 1.
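The shape change from (batch, nbits, frames) to (batch, nbits) implies a reduction over frames before the sigmoid. The sketch below assumes that reduction is a mean over the frame axis. That is an assumption rather than something stated here, and decode_message_sketch is a hypothetical stand-in for the real method.

```python
import torch


def decode_message_sketch(message_logits: torch.Tensor) -> torch.Tensor:
    """(batch, nbits, frames) logits -> (batch, nbits) probabilities."""
    # Assumption: frames are averaged before the sigmoid is applied.
    return torch.sigmoid(message_logits.mean(dim=-1))


logits = torch.tensor([[[4.0, 2.0, 6.0],      # strongly positive: bit likely 1
                        [-3.0, -5.0, -1.0],   # strongly negative: bit likely 0
                        [1.0, -1.0, 0.0]]])   # averages to 0: undecided
probs = decode_message_sketch(logits)
bits = (probs > 0.5).int()
print(probs.shape)    # torch.Size([1, 3])
print(bits.tolist())  # [[1, 0, 0]]
```

The third bit lands exactly on 0.5, so a strict greater-than comparison maps it to 0.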
Example
import torch
import torchaudio
# Load audio
audio, sr = torchaudio.load("audio.wav")
# Get raw output
detection_result, message_probs = detector(audio)
# detection_result contains frame-level detection scores
# message_probs contains the decoded message probabilities
# Manual decoding from raw result
raw_result = detector.detector(audio)
message_logits = raw_result[:, 2:, :] # Skip first 2 channels (detection)
message_probs = detector.decode_message(message_logits)
# Convert to binary with custom threshold
message_binary = (message_probs > 0.6).int()
forward
Run the full detection pipeline, returning both detection scores and decoded messages.
# Run full detection
detection_result, message = detector(audio)
Parameters
Input audio tensor of shape (batch, channels, samples) or (batch, samples).
Sample rate of the input audio. This parameter is deprecated and will be ignored.
Returns
Detection result tensor of shape (batch, 2, frames). The first channel contains the probability of no watermark at each frame; the second contains the probability that a watermark is present. The two values are softmax-normalized, so they sum to 1 for every frame.
Decoded message probabilities of shape (batch, nbits). Values range from 0.0 to 1.0.
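The two return values can be pictured as a split of one raw tensor with 2 + nbits channels. The following sketch uses a random stand-in for the network output. The helper name split_raw_output is hypothetical, and the real forward pass may differ in details such as normalization or resampling.

```python
import torch


def split_raw_output(raw: torch.Tensor):
    """Split (batch, 2 + nbits, frames) into detection scores and message logits."""
    # Softmax over the two detection channels, independently for each frame.
    detection = torch.softmax(raw[:, :2, :], dim=1)
    # The remaining channels carry one logit per message bit and frame.
    message_logits = raw[:, 2:, :]
    return detection, message_logits


torch.manual_seed(0)
nbits = 16
raw = torch.randn(2, 2 + nbits, 50)  # stand-in for the detector network output
detection, message_logits = split_raw_output(raw)
print(detection.shape)       # torch.Size([2, 2, 50])
print(message_logits.shape)  # torch.Size([2, 16, 50])
```

Because of the softmax, detection[:, 0, :] and detection[:, 1, :] sum to 1 at every frame, so the second channel alone is enough to judge watermark presence.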
Example
import torch
import torchaudio
from audioseal import AudioSeal
# Load detector
detector = AudioSeal.load_detector("audioseal_detector_16bits")
# Load audio
audio, sr = torchaudio.load("audio.wav")
# Run detection
detection_result, message_probs = detector(audio)
# Analyze frame-level detection
watermark_frames = detection_result[:, 1, :] # Second channel
avg_detection = watermark_frames.mean()
print(f"Average detection score: {avg_detection.item():.2%}")
# Analyze message
message_binary = (message_probs > 0.5).int()
print(f"Decoded message: {message_binary[0].tolist()}")
# Check confidence
message_confidence = torch.abs(message_probs - 0.5) * 2 # 0 to 1 scale
avg_confidence = message_confidence.mean()
print(f"Average bit confidence: {avg_confidence.item():.2%}")
Attributes
The detection network consisting of a SEANetEncoder followed by a 1x1 convolution layer.
Optional normalizer that applies loudness normalization before detection for improved robustness.
Number of bits in the secret message. Set to 0 for 0-bit watermarking (presence detection only).
Complete Example
import torch
import torchaudio
from audioseal import AudioSeal
# Load models
generator = AudioSeal.load_generator("audioseal_wm_16bits", device="cuda")
detector = AudioSeal.load_detector("audioseal_detector_16bits", device="cuda")
# Load audio
audio, sr = torchaudio.load("input.wav")
audio = audio.to("cuda")
# Create and embed watermark
message = torch.randint(0, 2, (1, 16), device="cuda")
watermarked = generator(audio, message=message)
# Detect watermark
detection_prob, decoded_message = detector.detect_watermark(
    watermarked,
    message_threshold=0.5,
    detection_threshold=0.5,
)
print(f"Detection probability: {detection_prob.item():.2%}")
print(f"Original message: {message[0].tolist()}")
print(f"Decoded message: {decoded_message[0].tolist()}")
# Check if messages match (cast so both tensors share a dtype)
decoded_bits = decoded_message[0].to(message.dtype)
if torch.equal(message[0], decoded_bits):
    print("✓ Message decoded correctly!")
else:
    # Calculate bit error rate
    errors = (message[0] != decoded_bits).sum().item()
    ber = errors / message.size(1)
    print(f"✗ Message decoded with {ber:.1%} bit error rate")
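When many messages are checked at once, the bit error rate can be computed per batch item. The helper below is a hypothetical sketch of that calculation, generalizing the single-item check above; it is not part of AudioSeal.

```python
import torch


def bit_error_rate(sent: torch.Tensor, decoded: torch.Tensor) -> torch.Tensor:
    """Fraction of mismatched bits per item; both inputs have shape (batch, nbits)."""
    # The != comparison promotes integer dtypes, so int64 and int32 inputs mix safely.
    return (sent != decoded).float().mean(dim=-1)


sent = torch.tensor([[1, 0, 1, 1], [0, 0, 1, 0]])
decoded = torch.tensor([[1, 1, 1, 0], [0, 0, 1, 0]], dtype=torch.int32)
ber = bit_error_rate(sent, decoded)
print(ber.tolist())  # [0.5, 0.0]
```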
Advanced Usage
Batch Processing
import torch
import torchaudio
# Load multiple audio files
audio_files = ["audio1.wav", "audio2.wav", "audio3.wav"]
audio_batch = []
for file in audio_files:
    audio, sr = torchaudio.load(file)  # (channels, samples)
    # Downmix to mono so each file contributes exactly one row to the batch
    audio_batch.append(audio.mean(dim=0, keepdim=True))
# Pad to the same length and concatenate into (batch, samples)
max_length = max(a.size(-1) for a in audio_batch)
audio_batch = [
    torch.nn.functional.pad(a, (0, max_length - a.size(-1)))
    for a in audio_batch
]
audio_batch = torch.cat(audio_batch, dim=0)
# Detect in batch
probs, messages = detector.detect_watermark(audio_batch)
for i, (prob, msg) in enumerate(zip(probs, messages)):
    print(f"{audio_files[i]}: {prob.item():.2%} - {msg.tolist()}")
Localized Detection
import torch
# Process in windows to find watermark location
window_size = 16000 * 5 # 5 seconds at 16kHz
stride = 16000 # 1 second stride
# The +1 keeps the final window when the audio length lines up with the stride
for start in range(0, audio.size(-1) - window_size + 1, stride):
    window = audio[..., start:start + window_size]
    prob, msg = detector.detect_watermark(window)
    if prob > 0.7:
        time_start = start / 16000
        time_end = (start + window_size) / 16000
        print(f"Watermark detected at {time_start:.1f}s - {time_end:.1f}s")
        print(f"Message: {msg[0].tolist()}")
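An alternative to sliding windows is to read the frame-level scores returned by forward and group consecutive watermarked frames into time ranges. The sketch below shows only the grouping step on a stand-in score vector. watermarked_segments is a hypothetical helper, and frames_per_second is an assumed parameter, since the detector's output frame rate is not stated here.

```python
import torch


def watermarked_segments(frame_probs: torch.Tensor, frames_per_second: float, threshold: float = 0.5):
    """Group consecutive frames above threshold into (start_s, end_s) tuples."""
    mask = (frame_probs > threshold).tolist()
    segments = []
    start = None
    for i, flagged in enumerate(mask):
        if flagged and start is None:
            start = i  # a new run of watermarked frames begins
        elif not flagged and start is not None:
            segments.append((start / frames_per_second, i / frames_per_second))
            start = None
    if start is not None:
        # The last run extends to the end of the signal.
        segments.append((start / frames_per_second, len(mask) / frames_per_second))
    return segments


# Stand-in for detection_result[0, 1, :] (second channel of the first batch item)
probs = torch.tensor([0.1, 0.9, 0.8, 0.2, 0.1, 0.7, 0.9, 0.95])
segments = watermarked_segments(probs, frames_per_second=4)
print(segments)  # [(0.25, 0.75), (1.25, 2.0)]
```

Compared with the windowed loop, this runs the detector once and gives boundaries at frame resolution, at the cost of having to know the frame rate.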
See Also