Automatische Hintergrundentfernung mit BiRefNet und Python


Ich habe ein Python-Skript entwickelt, das mithilfe des BiRefNet-Modells von Hugging Face automatisch Hintergründe aus Bildern entfernt. Das Tool wurde speziell für die Batch-Verarbeitung großer Bildmengen konzipiert und nutzt dabei moderne GPU-Beschleunigung, um selbst bei hochauflösenden Bildern eine schnelle und präzise Segmentierung zu erreichen.

Die Idee hinter diesem Projekt war es, manuelle Bildbearbeitungsschritte zu automatisieren – insbesondere für Anwendungsfälle wie E-Commerce-Produktfotos, Trainingsmaterial für Machine Learning oder kreative Design-Workflows. Anstatt jedes Bild einzeln in Photoshop oder GIMP zu bearbeiten, verarbeitet das Skript ganze Ordner mit einem einzigen Befehl.

Beispielbild

Wofür ist dieses Tool nützlich?

Für alle, die regelmäßig Produktbilder, Objektfotos oder Trainingsmaterial für Computer Vision vorbereiten müssen, bietet dieses Skript eine enorme Zeitersparnis:

  • E-Commerce: Automatische Freistellung von Produktfotos für Online-Shops, ohne teure manuelle Bearbeitung.
  • Machine Learning: Vorbereitung von Trainingsdaten für Objekterkennungsmodelle oder generative KI-Systeme.
  • Kreative Workflows: Schnelle Integration von Objekten in neue Hintergründe oder Designs.
  • Skalierbarkeit: Verarbeitung von hunderten Bildern auf einmal, mit präziser Maskengenerierung auch bei komplexen Formen.

Technischer Ansatz

Das Skript basiert auf dem BiRefNet-Modell (ZhengPeng7/BiRefNet), einem hochmodernen neuronalen Netzwerk für Bildsegmentierung, das speziell für präzise Objektabgrenzung trainiert wurde. Die Implementierung nutzt PyTorch und Hugging Face Transformers für die Modellinferenz.

1. GPU-Optimierung für moderne Hardware

Ein besonderes Augenmerk lag auf der Unterstützung neuester NVIDIA-GPUs wie der RTX 50-Serie. Um Kompatibilitätsprobleme mit CUDA-Kerneln zu vermeiden, habe ich spezielle Umgebungsvariablen und PyTorch-Konfigurationen gesetzt:

os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
os.environ['TORCH_USE_CUDA_DSA'] = '1'
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

Diese Einstellungen ermöglichen es dem Modell, die Tensor-Cores moderner GPUs voll auszunutzen und gleichzeitig Speicherfragmentierung zu vermeiden. Das Skript erkennt automatisch die verfügbare Hardware und zeigt Details wie GPU-Name, CUDA-Version und Compute Capability an.

2. Bildvorverarbeitung und Normalisierung

Bevor ein Bild durch das Modell verarbeitet wird, durchläuft es eine standardisierte Vorverarbeitungspipeline:

transform = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

Die Normalisierung mit ImageNet-Standardwerten stellt sicher, dass das Modell optimale Ergebnisse liefert, da es auf ähnlich normalisierten Daten trainiert wurde. Die Resize-Operation auf 1024×1024 Pixel gewährleistet konsistente Eingabegrößen für das Netzwerk.

3. Präzise Maskengenerierung und Nachbearbeitung

Das Herzstück der Verarbeitung ist die Maskengenerierung. Das BiRefNet-Modell erzeugt eine pixelgenaue Segmentierungsmaske, die angibt, welche Pixel zum Vordergrund gehören:

with torch.no_grad():
    preds = model(input_tensor)[-1].sigmoid().cpu()

pred_mask = preds[0].squeeze()
pred_mask_np = pred_mask.numpy()

Die Maske wird anschließend auf die Originalgröße des Bildes hochskaliert (mit LANCZOS-Interpolation für beste Qualität) und als Alpha-Kanal dem Bild hinzugefügt. Das Ergebnis ist ein transparentes PNG mit perfekt freigestelltem Objekt.

4. Robuste Fehlerbehandlung und Fallback-Mechanismen

Ein besonderes Feature ist das automatische CPU-Fallback bei CUDA-Fehlern. Sollte die GPU-Verarbeitung eines Bildes fehlschlagen (z.B. durch Speichermangel), wechselt das Skript automatisch auf CPU-Verarbeitung für dieses eine Bild:

except Exception as e:
    if "CUDA" in str(e) and device.type == 'cuda':
        print(f"\nRetrying {filename} on CPU...")
        torch.cuda.empty_cache()
        model.to('cpu')
        result_image = remove_background(input_path, model, torch.device('cpu'))
        model.to(device)

Dieser Mechanismus stellt sicher, dass die Batch-Verarbeitung auch bei problematischen Bildern nicht vollständig abbricht, sondern robust weiterläuft.

Batch-Verarbeitung mit Fortschrittsanzeige

Das Skript nutzt tqdm für eine detaillierte Fortschrittsanzeige während der Verarbeitung. Es durchläuft alle Bilder in einem Eingabeordner und speichert die freigestellten Versionen automatisch als PNG im Ausgabeordner:

for filename in tqdm(image_files, desc="Removing backgrounds"):
    try:
        input_path = os.path.join(input_folder, filename)
        result_image = remove_background(input_path, model, device)
        
        output_filename = Path(filename).stem + ".png"
        output_path = os.path.join(output_folder, output_filename)
        result_image.save(output_path, "PNG")
        
        successful += 1
    except Exception as e:
        print(f"\nError processing {filename}: {str(e)}")
        failed += 1

Am Ende der Verarbeitung gibt das Skript eine übersichtliche Zusammenfassung aus, die erfolgreiche und fehlgeschlagene Verarbeitungen zeigt.

Verwendete Technologien

  • Sprache: Python 3
  • Deep Learning Framework: PyTorch mit CUDA-Unterstützung
  • Modell: BiRefNet (ZhengPeng7/BiRefNet) von Hugging Face
  • Bildverarbeitung: PIL (Pillow), torchvision transforms
  • API-Zugriff: Hugging Face Transformers mit Token-Authentifizierung
  • Fortschrittsanzeige: tqdm
  • Konfiguration: python-dotenv für Umgebungsvariablen

Sicherheitsaspekte und Best Practices

Das Skript folgt modernen Best Practices für Machine Learning-Anwendungen:

  • Token-Sicherheit: Der Hugging Face API-Token wird über eine .env-Datei geladen, niemals hardcoded.
  • Speicherverwaltung: Automatisches Leeren des GPU-Caches (torch.cuda.empty_cache()) bei Fehlern.
  • Evaluation Mode: Das Modell wird mit model.eval() in den Evaluation-Modus versetzt, um Dropout und Batch-Normalisierung zu deaktivieren.
  • Transparenzerhaltung: Alle Ausgaben werden als PNG gespeichert, um den Alpha-Kanal zu bewahren.

Das vollständige Skript

import os
from pathlib import Path
from PIL import Image
import torch
from transformers import AutoModelForImageSegmentation
from torchvision import transforms
import numpy as np
from tqdm import tqdm
from dotenv import load_dotenv

# Force PyTorch to recompile CUDA kernels for RTX 50-series
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
os.environ['TORCH_COMPILE_DEBUG'] = '0'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['TORCH_USE_CUDA_DSA'] = '1'

# Load environment variables
load_dotenv()

# Configuration
INPUT_FOLDER = r"C:\Users\grisc\Desktop\py\shoes\training\boots"
OUTPUT_FOLDER = r"C:\Users\grisc\Desktop\py\shoes\training\boots_no_bg"
MODEL_NAME = "ZhengPeng7/BiRefNet"
HF_TOKEN = os.getenv("HF_TOKEN")

# Image transformation
image_size = (1024, 1024)

def load_model():
    """Load the BiRefNet model."""
    print(f"Loading model: {MODEL_NAME}")
    
    if not HF_TOKEN:
        raise ValueError("HF_TOKEN not found in .env file. Please add your Hugging Face token.")
    
    # Configure CUDA for RTX 50-series
    if torch.cuda.is_available():
        # Use memory efficient settings
        torch.backends.cudnn.benchmark = False
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    
    model = AutoModelForImageSegmentation.from_pretrained(
        MODEL_NAME, 
        trust_remote_code=True,
        token=HF_TOKEN
    )
    
    # Use GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        print(f"✓ Using GPU: {torch.cuda.get_device_name(0)} (CUDA {torch.version.cuda})")
        print(f"  Compute capability: sm_{torch.cuda.get_device_capability(0)[0]}{torch.cuda.get_device_capability(0)[1]}")
    else:
        print("⚠ GPU not available, using CPU")
    
    model.to(device)
    model.eval()
    
    print(f"Model loaded on: {device}")
    return model, device

def preprocess_image(image):
    """Preprocess image for the model."""
    # Resize and normalize
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    return transform(image).unsqueeze(0)

def remove_background(image_path, model, device):
    """Remove background from a single image."""
    # Load image
    original_image = Image.open(image_path).convert("RGB")
    original_size = original_image.size
    
    # Preprocess
    input_tensor = preprocess_image(original_image).to(device)
    
    # Generate mask
    with torch.no_grad():
        if device.type == 'cuda':
            # Try running without autocast (float32) to avoid kernel issues
            preds = model(input_tensor)[-1].sigmoid().cpu()
        else:
            preds = model(input_tensor)[-1].sigmoid().cpu()
    
    # Post-process mask
    pred_mask = preds[0].squeeze()
    pred_mask_np = pred_mask.numpy()
    
    # Resize mask to original image size
    mask_pil = Image.fromarray((pred_mask_np * 255).astype(np.uint8))
    mask_pil = mask_pil.resize(original_size, Image.LANCZOS)
    
    # Apply mask to original image
    original_image.putalpha(mask_pil)
    
    return original_image

def process_folder(input_folder, output_folder, model, device):
    """Process all images in a folder."""
    # Create output folder
    os.makedirs(output_folder, exist_ok=True)
    
    # Get all image files
    image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp']
    image_files = []
    
    for file in os.listdir(input_folder):
        if any(file.lower().endswith(ext) for ext in image_extensions):
            image_files.append(file)
    
    print(f"\nFound {len(image_files)} images to process")
    print(f"Output folder: {output_folder}\n")
    
    # Process images with progress bar
    successful = 0
    failed = 0
    
    for filename in tqdm(image_files, desc="Removing backgrounds"):
        try:
            input_path = os.path.join(input_folder, filename)
            
            # Remove background
            result_image = remove_background(input_path, model, device)
            
            # Save as PNG (to preserve transparency)
            output_filename = Path(filename).stem + ".png"
            output_path = os.path.join(output_folder, output_filename)
            result_image.save(output_path, "PNG")
            
            successful += 1
            
        except Exception as e:
            if "CUDA" in str(e) and device.type == 'cuda':
                print(f"\nCUDA error for {filename}: {str(e)}")
                print(f"Retrying {filename} on CPU...")
                try:
                    torch.cuda.empty_cache()
                    model.to('cpu')
                    result_image = remove_background(input_path, model, torch.device('cpu'))
                    
                    # Save as PNG
                    output_filename = Path(filename).stem + ".png"
                    output_path = os.path.join(output_folder, output_filename)
                    result_image.save(output_path, "PNG")
                    
                    successful += 1
                    
                    # Move model back to GPU
                    model.to(device)
                    continue
                except Exception as cpu_e:
                    print(f"Failed on CPU retry: {str(cpu_e)}")
                    # Ensure model is back on GPU
                    model.to(device)
            
            print(f"\nError processing {filename}: {str(e)}")
            failed += 1
    
    return successful, failed

def main():
    """Main function."""
    print("="*60)
    print("BACKGROUND REMOVAL - BiRefNet")
    print("="*60)
    print(f"Input folder: {INPUT_FOLDER}")
    print(f"Output folder: {OUTPUT_FOLDER}")
    print("="*60)
    
    # Check if CUDA is available
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
    else:
        print("GPU: Not available, using CPU")
    
    # Load model
    model, device = load_model()
    
    # Process images
    successful, failed = process_folder(INPUT_FOLDER, OUTPUT_FOLDER, model, device)
    
    # Summary
    print("\n" + "="*60)
    print("PROCESSING COMPLETE")
    print("="*60)
    print(f"✓ Successful: {successful}")
    print(f"✗ Failed: {failed}")
    print(f"Total: {successful + failed}")
    print("="*60)

if __name__ == "__main__":
    main()

Mögliche Erweiterungen

  • Web-Interface: Integration in eine Flask- oder FastAPI-Webanwendung für Browser-basierte Uploads.
  • Batch-API: RESTful API-Endpunkt für die Integration in automatisierte Workflows.
  • Qualitätskontrolle: Automatische Bewertung der Maskenqualität und Markierung problematischer Bilder.
  • Multi-Modell-Support: Vergleich verschiedener Segmentierungsmodelle und Auswahl des besten Ergebnisses.
  • Cloud-Deployment: Skalierung der Verarbeitung auf AWS Lambda oder Google Cloud Functions für serverlose Batch-Jobs.