类装饰器
数据验证类装饰器
class ValidateAttributes:
"""属性验证类装饰器"""
def __init__(self, **validators):
self.validators = validators
def __call__(self, cls):
# 保存原始的__init__方法
original_init = cls.__init__
def __init__(self, *args, **kwargs):
# 调用原始的__init__
original_init(self, *args, **kwargs)
# 验证属性
for attr_name, validator in self.validators.items():
if hasattr(self, attr_name):
value = getattr(self, attr_name)
if not validator(value):
raise ValueError(
f"属性 {attr_name} 的值 {value} 验证失败"
)
cls.__init__ = __init__
return cls
# 使用示例
def is_positive(value):
return isinstance(value, (int, float)) and value > 0
def is_valid_name(value):
return isinstance(value, str) and 2 <= len(value) <= 20
@ValidateAttributes(age=is_positive, name=is_valid_name)
class Person:
def __init__(self, name, age):
self.name = name
self.age = age
# 测试
try:
person = Person("张三", 25) # 正确
person = Person("李", -5) # 验证失败
except ValueError as e:
print(f"错误: {e}")
属性注入类装饰器
class InjectProperties:
"""属性注入类装饰器"""
def __init__(self, **properties):
self.properties = properties
def __call__(self, cls):
for name, value in self.properties.items():
setattr(cls, name, property(value))
return cls
def get_full_name(self):
return f"{self.first_name} {self.last_name}"
def get_age_description(self):
return f"{self.age}岁"
@InjectProperties(
full_name=get_full_name,
age_description=get_age_description
)
class User:
def __init__(self, first_name, last_name, age):
self.first_name = first_name
self.last_name = last_name
self.age = age
# 测试
user = User("张", "三", 25)
print(user.full_name) # 输出: 张 三
print(user.age_description) # 输出: 25岁
元类装饰器
接口强制实现装饰器
from abc import ABCMeta, abstractmethod
class InterfaceEnforcer(type):
"""强制接口实现的元类"""
def __new__(cls, name, bases, attrs):
# 获取所有抽象方法
abstracts = {
name for name, value in attrs.items()
if getattr(value, "__isabstractmethod__", False)
}
# 检查是否所有抽象方法都已实现
for base in bases:
for name in getattr(base, "__abstractmethods__", set()):
abstracts.add(name)
# 设置__abstractmethods__属性
attrs["__abstractmethods__"] = abstracts
return super().__new__(cls, name, bases, attrs)
class Interface(metaclass=InterfaceEnforcer):
@abstractmethod
def method1(self):
pass
@abstractmethod
def method2(self):
pass
# 使用示例
try:
class MyClass(Interface):
def method1(self):
return "实现方法1"
# 没有实现method2,将引发错误
obj = MyClass() # 将引发TypeError
except TypeError as e:
print(f"错误: {e}")
装饰器工厂
参数化装饰器工厂
import functools
import time
import random
def retry_with_strategy(
max_attempts=3,
delay_strategy="fixed",
initial_delay=1,
max_delay=60,
exceptions=(Exception,)
):
"""高级重试装饰器工厂
支持多种重试策略:
- fixed: 固定延迟
- linear: 线性增长
- exponential: 指数增长
"""
def calculate_delay(attempt, strategy):
if strategy == "fixed":
return initial_delay
elif strategy == "linear":
return initial_delay * attempt
elif strategy == "exponential":
delay = initial_delay * (2 ** (attempt - 1))
return min(delay, max_delay)
else:
raise ValueError(f"不支持的重试策略: {strategy}")
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
last_exception = None
for attempt in range(1, max_attempts + 1):
try:
return func(*args, **kwargs)
except exceptions as e:
last_exception = e
if attempt == max_attempts:
raise
delay = calculate_delay(attempt, delay_strategy)
# 添加随机抖动,避免多个请求同时重试
jitter = random.uniform(0, 0.1) * delay
total_delay = delay + jitter
print(f"第{attempt}次尝试失败,{total_delay:.2f}秒后重试...")
time.sleep(total_delay)
raise last_exception
return wrapper
return decorator
# 使用示例
@retry_with_strategy(
max_attempts=3,
delay_strategy="exponential",
initial_delay=1,
exceptions=(ConnectionError, TimeoutError)
)
def unstable_network_call():
if random.random() < 0.7:
raise ConnectionError("网络连接失败")
return "请求成功"
# 测试函数
try:
result = unstable_network_call()
print(result)
except ConnectionError as e:
print(f"最终失败: {e}")
异步装饰器
异步重试装饰器
import asyncio
import functools
def async_retry(max_attempts=3, delay=1):
"""异步函数重试装饰器"""
def decorator(func):
@functools.wraps(func)
async def wrapper(*args, **kwargs):
last_exception = None
for attempt in range(1, max_attempts + 1):
try:
return await func(*args, **kwargs)
except Exception as e:
last_exception = e
if attempt == max_attempts:
raise
print(f"尝试 {attempt}/{max_attempts} 失败,等待 {delay} 秒...")
await asyncio.sleep(delay)
raise last_exception
return wrapper
return decorator
# 使用示例
@async_retry(max_attempts=3, delay=1)
async def fetch_data(url):
# 模拟不稳定的异步请求
if random.random() < 0.7:
raise ConnectionError("连接失败")
return f"来自 {url} 的数据"
# 测试代码
async def main():
try:
result = await fetch_data("http://api.example.com")
print(result)
except ConnectionError as e:
print(f"最终失败: {e}")
# asyncio.run(main()) # 在异步环境中运行
异步缓存装饰器
import asyncio
import time
def async_cache(ttl=60):
"""异步结果缓存装饰器"""
cache = {}
def decorator(func):
@functools.wraps(func)
async def wrapper(*args, **kwargs):
key = str(args) + str(kwargs)
# 检查缓存
if key in cache:
result, timestamp = cache[key]
if time.time() - timestamp < ttl:
print("使用缓存结果")
return result
# 获取新结果
result = await func(*args, **kwargs)
cache[key] = (result, time.time())
return result
return wrapper
return decorator
# 使用示例
@async_cache(ttl=5)
async def fetch_user_data(user_id):
# 模拟耗时的异步操作
await asyncio.sleep(2)
return {"id": user_id, "name": f"用户{user_id}"}
# 测试代码
async def test_cache():
# 第一次调用
result1 = await fetch_user_data(1)
print("第一次调用:", result1)
# 第二次调用(使用缓存)
result2 = await fetch_user_data(1)
print("第二次调用:", result2)
# 等待缓存过期
await asyncio.sleep(6)
# 第三次调用(缓存已过期)
result3 = await fetch_user_data(1)
print("第三次调用:", result3)
# asyncio.run(test_cache()) # 在异步环境中运行
单例模式
线程安全的单例装饰器
import threading
def singleton(cls):
"""线程安全的单例装饰器"""
instances = {}
lock = threading.Lock()
def get_instance(*args, **kwargs):
if cls not in instances:
with lock:
# 双重检查锁定模式
if cls not in instances:
instances[cls] = cls(*args, **kwargs)
return instances[cls]
return get_instance
@singleton
class Database:
def __init__(self, host="localhost", port=5432):
self.host = host
self.port = port
print(f"创建数据库连接: {self.host}:{self.port}")
def query(self, sql):
return f"执行查询: {sql}"
# 测试代码
def test_singleton():
db1 = Database()
db2 = Database()
print("是否是同一个实例:", db1 is db2)
print(db1.query("SELECT * FROM users"))
# 多线程测试
def create_db():
db = Database()
print(f"线程 {threading.current_thread().name} 创建的实例: {id(db)}")
threads = [
threading.Thread(target=create_db)
for _ in range(5)
]
for t in threads:
t.start()
for t in threads:
t.join()
高级应用
依赖注入装饰器
class Inject:
"""依赖注入装饰器"""
_services = {}
@classmethod
def register(cls, interface, implementation):
"""注册服务实现"""
cls._services[interface] = implementation
@classmethod
def get(cls, interface):
"""获取服务实现"""
if interface not in cls._services:
raise ValueError(f"未注册的服务: {interface}")
return cls._services[interface]
def __init__(self, **dependencies):
self.dependencies = dependencies
def __call__(self, cls):
# 保存原始的__init__方法
original_init = cls.__init__
def __init__(self, *args, **kwargs):
# 注入依赖
for name, interface in self.dependencies.items():
setattr(self, name, self.get(interface)())
# 调用原始的__init__
original_init(self, *args, **kwargs)
cls.__init__ = __init__
return cls
# 使用示例
class Logger:
def log(self, message):
print(f"日志: {message}")
class Database:
def query(self, sql):
return f"执行查询: {sql}"
# 注册服务
Inject.register("logger", Logger)
Inject.register("database", Database)
@Inject(logger="logger", db="database")
class UserService:
def __init__(self):
self.logger.log("UserService 已创建")
def get_user(self, user_id):
self.logger.log(f"获取用户 {user_id}")
return self.db.query(f"SELECT * FROM users WHERE id = {user_id}")
# 测试
service = UserService()
print(service.get_user(1))
方法链式调用装饰器
def chainable(cls):
"""使类的方法支持链式调用的装饰器"""
class ChainWrapper:
def __init__(self, wrapped):
self.wrapped = wrapped
def __getattr__(self, name):
attr = getattr(self.wrapped, name)
if callable(attr):
def wrapper(*args, **kwargs):
result = attr(*args, **kwargs)
return self if result is None else result
return wrapper
return attr
return lambda *args, **kwargs: ChainWrapper(cls(*args, **kwargs))
@chainable
class QueryBuilder:
def __init__(self):
self.query = []
def select(self, *fields):
self.query.append(f"SELECT {', '.join(fields)}")
def from_table(self, table):
self.query.append(f"FROM {table}")
def where(self, condition):
self.query.append(f"WHERE {condition}")
def build(self):
return " ".join(self.query)
# 测试链式调用
query = QueryBuilder()
sql = (query
.select("id", "name", "age")
.from_table("users")
.where("age > 18")
.build())
print(sql) # 输出: SELECT id, name, age FROM users WHERE age > 18