Caching System

Surfing Weights implements a two-level caching system to optimize performance and memory usage. This guide explains how the caching system works and how to configure it for your needs.

Overview

The caching system operates at two levels:

  1. Server-Side Caching
     - Caches raw model weights
     - Shared across all clients
     - Memory-efficient storage

  2. Client-Side Caching
     - Caches loaded model components
     - Per-client cache
     - Optimized for inference
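
The two levels work together: the server keeps one shared pool of raw weights, while each client caches the layers it has already assembled. Below is a minimal sketch of how they fit together, using only the constructors documented in the sections that follow; the server and client would normally run in separate processes.

# Server process: holds the shared cache of raw weight chunks
from streaming_weights import WeightServer

server = WeightServer(
    model_path="./chunks/bert-tiny",
    cache_size_mb=200
)

# Client process: keeps its own per-client cache of loaded layers
from streaming_weights import StreamingBertModel

model = StreamingBertModel(
    model_name="prajjwal1/bert-tiny",
    cache_size=3
)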

Server-Side Cache

Configuration

from streaming_weights import WeightServer

server = WeightServer(
    model_path="./chunks/bert-tiny",
    cache_size_mb=200  # Set cache size in megabytes
)

Command line configuration:

streaming-weights-server --chunks-dir ./chunks/bert-tiny \
    --cache-size 200  # Cache size in MB

Features

  1. LRU (Least Recently Used) Eviction
     - Automatically evicts the least recently used weights (see the sketch after this list)
     - Optimizes memory usage
     - Adapts to access patterns

  2. Size-Based Management
     - Configurable maximum size
     - Automatic eviction when full
     - Memory usage monitoring
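
To make the eviction policy concrete, here is a minimal, self-contained sketch of a size-bounded LRU cache. It illustrates the technique only; ChunkLRUCache and its methods are hypothetical names, not part of the streaming_weights API.

from collections import OrderedDict

class ChunkLRUCache:
    # Hypothetical class for illustration; not the library's actual implementation.
    def __init__(self, max_size_mb):
        self.max_bytes = max_size_mb * 1024 * 1024
        self.current_bytes = 0
        self._store = OrderedDict()  # key -> (weights, size_in_bytes)

    def get(self, key):
        if key not in self._store:
            return None  # cache miss; caller loads from disk
        self._store.move_to_end(key)  # mark as most recently used
        return self._store[key][0]

    def put(self, key, weights, size_bytes):
        if key in self._store:
            self.current_bytes -= self._store.pop(key)[1]
        self._store[key] = (weights, size_bytes)
        self.current_bytes += size_bytes
        # Evict least recently used entries until back under budget,
        # but never evict the entry that was just inserted
        while self.current_bytes > self.max_bytes and len(self._store) > 1:
            _, (_, evicted_size) = self._store.popitem(last=False)
            self.current_bytes -= evicted_size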

Client-Side Cache

Configuration

from streaming_weights import StreamingBertModel

model = StreamingBertModel(
    model_name="prajjwal1/bert-tiny",
    cache_size=3  # Number of layers to cache
)

Features

  1. Component-Level Caching
     - Caches entire model layers
     - Maintains layer state
     - Optimizes inference speed

  2. Smart Prefetching

    # Enable prefetching for better performance
    outputs = await model.forward_async(
        input_ids=inputs,
        enable_prefetch=True,
        prefetch_count=2  # Prefetch next 2 layers
    )

  3. Cache Warmup

    # Preload specific layers
    await model.warmup(layer_indices=[0, 1, 2])
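
Warmup can be combined with the monitoring API to verify what actually ended up in the cache. A minimal sketch, assuming the cached_components key shown in the Performance Monitoring section below:

import asyncio

from streaming_weights import StreamingBertModel

async def main():
    model = StreamingBertModel(
        model_name="prajjwal1/bert-tiny",
        cache_size=3
    )

    # Preload the first three layers, then confirm they are resident
    await model.warmup(layer_indices=[0, 1, 2])
    info = model.get_cache_info()
    print(info['cached_components'])  # the three warmed layers should appear here

asyncio.run(main())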
    

Performance Monitoring

Cache Statistics

# Get cache performance metrics
stats = model.get_inference_stats()
print(f"Cache hit rate: {stats['cache_hit_rate']:.2%}")
print(f"Average inference time: {stats['avg_inference_time']:.3f}s")

# Get current cache state
cache_info = model.get_cache_info()
print(f"Cached components: {cache_info['cached_components']}")
print(f"Cache memory usage: {cache_info['memory_usage_mb']:.2f} MB")

Cache Management

# Clear the cache manually
model.clear_cache()

# Update cache size at runtime
model.cache_size = 5  # Increase cache size

Advanced Features

Cache Optimization

  1. Access Pattern Optimization

    # Order layers for optimal caching
    await model.warmup([0, 1, 2])  # Cache first layers
    await model.prefetch_next_layers(2, prefetch_count=2)  # Prefetch next layers
    

  2. Memory Management

    # Monitor and adjust cache size
    if model.get_cache_info()['memory_usage_mb'] > 1000:
        model.cache_size = model.cache_size - 1
    

Distributed Caching

When using multiple servers:

from streaming_weights import AdvancedWeightServer

server = AdvancedWeightServer(
    chunks_dir="./chunks",
    redis_url="redis://localhost:6379",  # Redis for distributed caching
    cache_size_mb=1000
)
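
Pointing multiple AdvancedWeightServer instances at the same redis_url lets them share a common cache tier on top of their local in-memory caches; the exact sharing semantics depend on the server configuration.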

Best Practices

  1. Cache Size Configuration
     - Set the server cache size based on available RAM
     - Adjust the client cache size based on model architecture
     - Monitor cache hit rates for optimization

  2. Performance Optimization
     - Use warmup for frequently accessed layers
     - Enable prefetching for sequential access
     - Clear the cache when switching tasks (see the sketch after this list)

  3. Memory Management
     - Monitor memory usage with get_cache_info()
     - Adjust cache sizes based on workload
     - Clear the cache when memory pressure is high
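
As referenced in item 2, a minimal sketch of the task-switching and tuning pattern, using only calls documented above; the 0.5 threshold is an arbitrary example value, and the snippet is assumed to run inside an async context:

# Switching tasks: drop layers cached for the old task, pre-warm the new one
model.clear_cache()
await model.warmup(layer_indices=[0, 1])

# Grow the cache when hit rates are poor (threshold is illustrative)
stats = model.get_inference_stats()
if stats['cache_hit_rate'] < 0.5:
    model.cache_size += 1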

Troubleshooting

Common Issues

  1. High Memory Usage
     - Reduce cache size
     - Clear the cache more frequently
     - Monitor with get_cache_info()

  2. Poor Cache Performance
     - Check cache hit rates
     - Adjust cache size
     - Review access patterns

Cache Monitoring

# Monitor cache performance
import asyncio

async def monitor_cache(model):
    # Check cache statistics periodically while inference is running
    while True:
        stats = model.get_inference_stats()
        if stats['cache_hit_rate'] < 0.5:
            print("Warning: Low cache hit rate")
        await asyncio.sleep(60)  # Check every minute
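
The monitor can be scheduled as a background task so it runs alongside inference instead of blocking it:

# Run the monitor concurrently with inference
monitor_task = asyncio.create_task(monitor_cache(model))
# ... run inference ...
monitor_task.cancel()  # stop monitoring when done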

Next Steps