Python高级装饰器示例

本页面展示了Python装饰器的高级用法,包括类装饰器、元类装饰器、装饰器工厂等复杂应用场景。这些示例将帮助你深入理解装饰器的强大功能。

类装饰器

数据验证类装饰器


class ValidateAttributes:
    """属性验证类装饰器"""
    
    def __init__(self, **validators):
        self.validators = validators
    
    def __call__(self, cls):
        # 保存原始的__init__方法
        original_init = cls.__init__
        
        def __init__(self, *args, **kwargs):
            # 调用原始的__init__
            original_init(self, *args, **kwargs)
            # 验证属性
            for attr_name, validator in self.validators.items():
                if hasattr(self, attr_name):
                    value = getattr(self, attr_name)
                    if not validator(value):
                        raise ValueError(
                            f"属性 {attr_name} 的值 {value} 验证失败"
                        )
        
        cls.__init__ = __init__
        return cls

# 使用示例
def is_positive(value):
    return isinstance(value, (int, float)) and value > 0

def is_valid_name(value):
    return isinstance(value, str) and 2 <= len(value) <= 20

@ValidateAttributes(age=is_positive, name=is_valid_name)
class Person:
    def __init__(self, name, age):
        self.name = name
        self.age = age

# 测试
try:
    person = Person("张三", 25)  # 正确
    person = Person("李", -5)   # 验证失败
except ValueError as e:
    print(f"错误: {e}")
                    

属性注入类装饰器


class InjectProperties:
    """属性注入类装饰器"""
    
    def __init__(self, **properties):
        self.properties = properties
    
    def __call__(self, cls):
        for name, value in self.properties.items():
            setattr(cls, name, property(value))
        return cls

def get_full_name(self):
    return f"{self.first_name} {self.last_name}"

def get_age_description(self):
    return f"{self.age}岁"

@InjectProperties(
    full_name=get_full_name,
    age_description=get_age_description
)
class User:
    def __init__(self, first_name, last_name, age):
        self.first_name = first_name
        self.last_name = last_name
        self.age = age

# 测试
user = User("张", "三", 25)
print(user.full_name)        # 输出: 张 三
print(user.age_description)  # 输出: 25岁
                    

元类装饰器

接口强制实现装饰器


from abc import ABCMeta, abstractmethod

class InterfaceEnforcer(type):
    """强制接口实现的元类"""
    
    def __new__(cls, name, bases, attrs):
        # 获取所有抽象方法
        abstracts = {
            name for name, value in attrs.items()
            if getattr(value, "__isabstractmethod__", False)
        }
        
        # 检查是否所有抽象方法都已实现
        for base in bases:
            for name in getattr(base, "__abstractmethods__", set()):
                abstracts.add(name)
        
        # 设置__abstractmethods__属性
        attrs["__abstractmethods__"] = abstracts
        return super().__new__(cls, name, bases, attrs)

class Interface(metaclass=InterfaceEnforcer):
    @abstractmethod
    def method1(self):
        pass
    
    @abstractmethod
    def method2(self):
        pass

# 使用示例
try:
    class MyClass(Interface):
        def method1(self):
            return "实现方法1"
        # 没有实现method2,将引发错误
    
    obj = MyClass()  # 将引发TypeError
except TypeError as e:
    print(f"错误: {e}")
                    

装饰器工厂

参数化装饰器工厂


import functools
import time
import random

def retry_with_strategy(
    max_attempts=3,
    delay_strategy="fixed",
    initial_delay=1,
    max_delay=60,
    exceptions=(Exception,)
):
    """高级重试装饰器工厂
    
    支持多种重试策略:
    - fixed: 固定延迟
    - linear: 线性增长
    - exponential: 指数增长
    """
    
    def calculate_delay(attempt, strategy):
        if strategy == "fixed":
            return initial_delay
        elif strategy == "linear":
            return initial_delay * attempt
        elif strategy == "exponential":
            delay = initial_delay * (2 ** (attempt - 1))
            return min(delay, max_delay)
        else:
            raise ValueError(f"不支持的重试策略: {strategy}")
    
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            last_exception = None
            
            for attempt in range(1, max_attempts + 1):
                try:
                    return func(*args, **kwargs)
                except exceptions as e:
                    last_exception = e
                    if attempt == max_attempts:
                        raise
                    
                    delay = calculate_delay(attempt, delay_strategy)
                    # 添加随机抖动,避免多个请求同时重试
                    jitter = random.uniform(0, 0.1) * delay
                    total_delay = delay + jitter
                    
                    print(f"第{attempt}次尝试失败,{total_delay:.2f}秒后重试...")
                    time.sleep(total_delay)
            
            raise last_exception
        return wrapper
    return decorator

# 使用示例
@retry_with_strategy(
    max_attempts=3,
    delay_strategy="exponential",
    initial_delay=1,
    exceptions=(ConnectionError, TimeoutError)
)
def unstable_network_call():
    if random.random() < 0.7:
        raise ConnectionError("网络连接失败")
    return "请求成功"

# 测试函数
try:
    result = unstable_network_call()
    print(result)
except ConnectionError as e:
    print(f"最终失败: {e}")
                    

异步装饰器

异步重试装饰器


import asyncio
import functools

def async_retry(max_attempts=3, delay=1):
    """异步函数重试装饰器"""
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            last_exception = None
            
            for attempt in range(1, max_attempts + 1):
                try:
                    return await func(*args, **kwargs)
                except Exception as e:
                    last_exception = e
                    if attempt == max_attempts:
                        raise
                    
                    print(f"尝试 {attempt}/{max_attempts} 失败,等待 {delay} 秒...")
                    await asyncio.sleep(delay)
            
            raise last_exception
        return wrapper
    return decorator

# 使用示例
@async_retry(max_attempts=3, delay=1)
async def fetch_data(url):
    # 模拟不稳定的异步请求
    if random.random() < 0.7:
        raise ConnectionError("连接失败")
    return f"来自 {url} 的数据"

# 测试代码
async def main():
    try:
        result = await fetch_data("http://api.example.com")
        print(result)
    except ConnectionError as e:
        print(f"最终失败: {e}")

# asyncio.run(main())  # 在异步环境中运行
                    

异步缓存装饰器


import asyncio
import time

def async_cache(ttl=60):
    """异步结果缓存装饰器"""
    cache = {}
    
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            key = str(args) + str(kwargs)
            
            # 检查缓存
            if key in cache:
                result, timestamp = cache[key]
                if time.time() - timestamp < ttl:
                    print("使用缓存结果")
                    return result
            
            # 获取新结果
            result = await func(*args, **kwargs)
            cache[key] = (result, time.time())
            return result
        return wrapper
    return decorator

# 使用示例
@async_cache(ttl=5)
async def fetch_user_data(user_id):
    # 模拟耗时的异步操作
    await asyncio.sleep(2)
    return {"id": user_id, "name": f"用户{user_id}"}

# 测试代码
async def test_cache():
    # 第一次调用
    result1 = await fetch_user_data(1)
    print("第一次调用:", result1)
    
    # 第二次调用(使用缓存)
    result2 = await fetch_user_data(1)
    print("第二次调用:", result2)
    
    # 等待缓存过期
    await asyncio.sleep(6)
    
    # 第三次调用(缓存已过期)
    result3 = await fetch_user_data(1)
    print("第三次调用:", result3)

# asyncio.run(test_cache())  # 在异步环境中运行
                    

单例模式

线程安全的单例装饰器


import threading

def singleton(cls):
    """线程安全的单例装饰器"""
    
    instances = {}
    lock = threading.Lock()
    
    def get_instance(*args, **kwargs):
        if cls not in instances:
            with lock:
                # 双重检查锁定模式
                if cls not in instances:
                    instances[cls] = cls(*args, **kwargs)
        return instances[cls]
    
    return get_instance

@singleton
class Database:
    def __init__(self, host="localhost", port=5432):
        self.host = host
        self.port = port
        print(f"创建数据库连接: {self.host}:{self.port}")
    
    def query(self, sql):
        return f"执行查询: {sql}"

# 测试代码
def test_singleton():
    db1 = Database()
    db2 = Database()
    print("是否是同一个实例:", db1 is db2)
    print(db1.query("SELECT * FROM users"))

# 多线程测试
def create_db():
    db = Database()
    print(f"线程 {threading.current_thread().name} 创建的实例: {id(db)}")

threads = [
    threading.Thread(target=create_db)
    for _ in range(5)
]

for t in threads:
    t.start()

for t in threads:
    t.join()
                    

高级应用

依赖注入装饰器


class Inject:
    """依赖注入装饰器"""
    
    _services = {}
    
    @classmethod
    def register(cls, interface, implementation):
        """注册服务实现"""
        cls._services[interface] = implementation
    
    @classmethod
    def get(cls, interface):
        """获取服务实现"""
        if interface not in cls._services:
            raise ValueError(f"未注册的服务: {interface}")
        return cls._services[interface]
    
    def __init__(self, **dependencies):
        self.dependencies = dependencies
    
    def __call__(self, cls):
        # 保存原始的__init__方法
        original_init = cls.__init__
        
        def __init__(self, *args, **kwargs):
            # 注入依赖
            for name, interface in self.dependencies.items():
                setattr(self, name, self.get(interface)())
            # 调用原始的__init__
            original_init(self, *args, **kwargs)
        
        cls.__init__ = __init__
        return cls

# 使用示例
class Logger:
    def log(self, message):
        print(f"日志: {message}")

class Database:
    def query(self, sql):
        return f"执行查询: {sql}"

# 注册服务
Inject.register("logger", Logger)
Inject.register("database", Database)

@Inject(logger="logger", db="database")
class UserService:
    def __init__(self):
        self.logger.log("UserService 已创建")
    
    def get_user(self, user_id):
        self.logger.log(f"获取用户 {user_id}")
        return self.db.query(f"SELECT * FROM users WHERE id = {user_id}")

# 测试
service = UserService()
print(service.get_user(1))
                    

方法链式调用装饰器


def chainable(cls):
    """使类的方法支持链式调用的装饰器"""
    
    class ChainWrapper:
        def __init__(self, wrapped):
            self.wrapped = wrapped
        
        def __getattr__(self, name):
            attr = getattr(self.wrapped, name)
            if callable(attr):
                def wrapper(*args, **kwargs):
                    result = attr(*args, **kwargs)
                    return self if result is None else result
                return wrapper
            return attr
    
    return lambda *args, **kwargs: ChainWrapper(cls(*args, **kwargs))

@chainable
class QueryBuilder:
    def __init__(self):
        self.query = []
    
    def select(self, *fields):
        self.query.append(f"SELECT {', '.join(fields)}")
    
    def from_table(self, table):
        self.query.append(f"FROM {table}")
    
    def where(self, condition):
        self.query.append(f"WHERE {condition}")
    
    def build(self):
        return " ".join(self.query)

# 测试链式调用
query = QueryBuilder()
sql = (query
    .select("id", "name", "age")
    .from_table("users")
    .where("age > 18")
    .build())

print(sql)  # 输出: SELECT id, name, age FROM users WHERE age > 18