27

python standard library dataclasses

 4 years ago
source link: https://pinghailinfeng.gitee.io/2020/01/25/python-standard-library-dataclasses/
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

jQnQnyB.jpg!web

源码

源码: Lib/dataclasses.py

主要类结构层次,节选自 源码 )

__all__ = ['dataclass',
           'field',
           'Field',
           'FrozenInstanceError',
           'InitVar',
           'MISSING',

           # Helper functions.
           'fields',
           'asdict',
           'astuple',
           'make_dataclass',
           'replace',
           'is_dataclass',
           ]

class _DataclassParams:
    __slots__ = ('init',
                 'repr',
                 'eq',
                 'order',
                 'unsafe_hash',
                 'frozen',
                 )

这个模块提供了一个装饰器和一些函数,用于自动添加生成的 special method s ,例如 __init__() __repr__() 到用户定义的类。 它最初描述于 PEP 557

在这些生成的方法中使用的成员变量使用 PEP 526 类型注释定义。例如这段代码:

@dataclass
class InventoryItem:
    '''Class for keeping track of an item in inventory.'''
    name: str
    unit_price: float
    quantity_on_hand: int = 0

    def total_cost(self) -> float:
        return self.unit_price * self.quantity_on_hand

除其他事情外,将添加 __init__() ,其看起来像:

def __init__(self, name: str, unit_price: float, quantity_on_hand: int=0):
    self.name = name
    self.unit_price = unit_price
    self.quantity_on_hand = quantity_on_hand

请注意,此方法会自动添加到类中:它不会在上面显示的 InventoryItem 定义中直接指定。

模块级装饰器、类和函数

  • `@dataclasses.dataclass `( **, init=True , repr=True , eq=True , order=False , unsafe_hash=False , frozen=False*)

    这个函数是 decorator ,用于将生成的 special method 添加到类中,如下所述。

    dataclass() 装饰器检查类以找到 fieldfield 被定义为具有 类型标注 的类变量。除了下面描述的两个例外,在 dataclass() 中没有任何内容检查变量标注中指定的类型。

    所有生成的方法中的字段顺序是它们在类定义中出现的顺序。

    dataclass() 装饰器将向类中添加各种“dunder”方法,如下所述。如果类中已存在任何添加的方法,则行为取决于参数,如下所述。装饰器返回被调用的同一个类;没有创建新类。

    如果 dataclass() 仅用作没有参数的简单装饰器,它就像它具有此签名中记录的默认值一样。也就是说,这三种 dataclass() 用法是等价的:

    @dataclass
    class C:
        ...
    
    @dataclass()
    class C:
        ...
    
    @dataclass(init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False)
    class C:
       ...
    

    dataclass() 的参数有:

    • init : 如果为真值(默认),将生成一个 __ init__() 方法。

      如果类已定义 __ init__() ,则忽略此参数。

    • repr :如果为真值(默认),将生成一个 __repr__() 方法。 生成的 repr 字符串将具有类名以及每个字段的名称和 repr ,按照它们在类中定义的顺序。不包括标记为从 repr 中排除的字段。 例如: InventoryItem(name='widget', unit_price=3.0, quantity_on_hand=10)

      如果类已定义 __repr__() ,则忽略此参数。

    • eq :如果为true(默认值),将生成 __eq__() 方法。此方法将类作为其字段的元组按顺序比较。比较中的两个实例必须是相同的类型。

      如果类已定义 __eq__() ,则忽略此参数。

    • order :如果为真值(默认为 False ),则 __lt__() __ le__() __gt__() __ge__() 方法将生成。 这将类作为其字段的元组按顺序比较。比较中的两个实例必须是相同的类型。如果 order 为真值并且 eq 为假值 ,则引发 ValueError

      如果类已经定义了 __lt__() __le__() __gt__() 或者 __ge__() 中的任意一个,将引发 TypeError

    • unsafe_hash :如果为 False (默认值),则根据 eqfrozen 的设置方式生成 __hash__() 方法。

      __hash__() 由内置的 hash() 使用,当对象被添加到散列集合(如字典和集合)时。有一个 __hash__() 意味着类的实例是不可变的。可变性是一个复杂的属性,取决于程序员的意图, __eq__() 的存在性和行为,以及 dataclass() 装饰器中 eqfrozen 标志的值。

      默认情况下, dataclass() 不会隐式添加 __hash__() 方法,除非这样做是安全的。 它也不会添加或更改现有的明确定义的 __hash__() 方法。 设置类属性 __hash__ = None 对 Python 具有特定含义,如 __hash__() 文档中所述。

      如果 __hash__() 没有显式定义,或者它被设置为 None ,那么 dataclass() 可以 添加一个隐式 __hash__() 方法。虽然不推荐,但你可以强制 dataclass() unsafe_hash=True 创建一个 __hash__() 方法。 如果你的类在逻辑上是不可变的但实际仍然可变,则可能就是这种情况。这是一个特殊的用例,应该仔细考虑。

      以下是隐式创建 __hash__() 方法的规则。请注意,你不能在数据类中都使用显式的 __hash__() 方法并设置 unsafe_hash=True ;这将导致 TypeError

      如果 eqfrozen 都是 true,默认情况下 dataclass() 将为你生成一个 __hash__() 方法。如果 eq 为 true 且 frozen 为 false ,则 __hash__() 将被设置为 None ,标记它不可用(因为它是可变的)。如果 eq 为 false ,则 __hash__() 将保持不变,这意味着将使用超类的 __hash__() 方法(如果超类是 object ,这意味着它将回到基于id的hash)。

    • frozen : 如为真值 (默认值为 False ),则对字段赋值将会产生异常。 这模拟了只读的冻结实例。 如果在类中定义了 __setattr__() __delattr__() 则将会引发 TypeError 。 参见下文的讨论。

    field s 可以选择使用普通的 Python 语法指定默认值:

    @dataclass
    class C:
        a: int       # 'a' has no default value
        b: int = 0   # assign a default value for 'b'
    

    在这个例子中, ab 都将包含在添加的 __init__() 方法中,它们将被定义为:

    def __init__(self, a: int, b: int = 0):
    

    如果没有默认值的字段跟在具有默认值的字段后,将引发 TypeError 。当这发生在单个类中时,或者作为类继承的结果时,都是如此。

  • dataclasses.field ( **, default=MISSING , default_factory=MISSING , repr=True , hash=None , init=True , compare=True , metadata=None*)

    对于常见和简单的用例,不需要其他功能。但是,有些数据类功能需要额外的每字段信息。为了满足这种对附加信息的需求,你可以通过调用提供的 field() 函数来替换默认字段值。例如:

    @dataclass
    class C:
        mylist: List[int] = field(default_factory=list)
    
    c = C()
    c.mylist += [1, 2, 3]
    

    如上所示, MISSING 值是一个 sentinel 对象,用于检测是否提供了 defaultdefault_factory 参数。 使用此 sentinel 是因为 Nonedefault 的有效值。没有代码应该直接使用 MISSING 值。

    field() 参数有:

    • default :如果提供,这将是该字段的默认值。这是必需的,因为 field() 调用本身会替换一般的默认值。

    • default_factory :如果提供,它必须是一个零参数可调用对象,当该字段需要一个默认值时,它将被调用。除了其他目的之外,这可以用于指定具有可变默认值的字段,如下所述。 同时指定 defaultdefault_factory 将产生错误。

    • init :如果为true(默认值),则该字段作为参数包含在生成的 __init__() 方法中。

    • repr :如果为true(默认值),则该字段包含在生成的 __repr__() 方法返回的字符串中。

    • compare :如果为true(默认值),则该字段包含在生成的相等性和比较方法中( __eq__() __gt__() 等等)。

    • hash :这可以是布尔值或 None 。如果为true,则此字段包含在生成的 __hash__() 方法中。如果为 None (默认值),请使用 compare 的值,这通常是预期的行为。如果字段用于比较,则应在 hash 中考虑该字段。不鼓励将此值设置为 None 以外的任何值。

      设置 hash=Falsecompare=True 的一个可能原因是,如果一个计算 hash 的代价很高的字段是检验等价性需要的,但还有其他字段可以计算类型的 hash 。 即使从 hash 中排除某个字段,它仍将用于比较。

    • metadata :这可以是映射或 None 。 None 被视为一个空的字典。这个值包含在 MappingProxyType() 中,使其成为只读,并暴露在 Field 对象上。数据类根本不使用它,它是作为第三方扩展机制提供的。多个第三方可以各自拥有自己的键值,以用作元数据中的命名空间。

    如果通过调用 field() 指定字段的默认值,则该字段的类属性将替换为指定的 default 值。如果没有提供 default ,那么将删除类属性。目的是在 dataclass() 装饰器运行之后,类属性将包含字段的默认值,就像指定了默认值一样。例如,之后:

    @dataclass
    class C:
        x: int
        y: int = field(repr=False)
        z: int = field(repr=False, default=10)
        t: int = 20
    

    类属性 C.z 将是 10 ,类属性 C.t 将是 20 ,类属性 C.xC.y 将不设置。

  • class dataclasses.Field

    Field 对象描述每个定义的字段。这些对象在内部创建,并由 fields() 模块级方法返回(见下文)。用户永远不应该直接实例化 Field 对象。 其有文档的属性是:

    • name :字段的名字。
    • type :字段的类型。
    • defaultdefault_factoryinitreprhashcompare 以及 metadata 与具有和 field() 声明中相同的意义和值。

    可能存在其他属性,但它们是私有的,不能被审查或依赖。

  • dataclasses.fields ( class_or_instance )

    返回 Field 对象的元组,用于定义此数据类的字段。 接受数据类或数据类的实例。如果没有传递一个数据类或实例将引发 TypeError 。 不返回 ClassVarInitVar 的伪字段。

  • dataclasses.asdict ( instance , **, dict_factory=dict*)

    将数据类 instance 转换为字典(使用工厂函数 dict_factory )。每个数据类都转换为其字段的字典,如 name: value 对。数据类、字典、列表和元组被递归。例如:

    @dataclass
    class Point:
         x: int
         y: int
    
    @dataclass
    class C:
         mylist: List[Point]
    
    p = Point(10, 20)
    assert asdict(p) == {'x': 10, 'y': 20}
    
    c = C([Point(0, 0), Point(10, 4)])
    assert asdict(c) == {'mylist': [{'x': 0, 'y': 0}, {'x': 10, 'y': 4}]}
    

    引发 TypeError 如果 instance 不是数据类实例。

  • dataclasses.astuple ( instance , **, tuple_factory=tuple*)

    将数据类 instance 转换为元组(通过使用工厂函数 tuple_factory )。每个数据类都转换为其字段值的元组。数据类、字典、列表和元组被递归。

    继续前一个例子:

    assert astuple(p) == (10, 20)
    assert astuple(c) == ([(0, 0), (10, 4)],)
    

    引发 TypeError 如果 instance 不是数据类实例。

  • dataclasses.make_dataclass ( cls_name , fields , **, bases=() , namespace=None , init=True , repr=True , eq=True , order=False , unsafe_hash=False , frozen=False*)

    创建一个名为 cls_name 的新数据类,字段为 fields 中定义的字段,基类为 bases 中给出的基类,并使用 namespace 中给出的命名空间进行初始化。 fields 是一个可迭代的元素,每个元素都是 name(name, type)(name, type, Field) 。 如果只提供name , typetyping.Anyinitrepreqorderunsafe_hashfrozen 的值与它们在 dataclass() 中的含义相同。

    此函数不是严格要求的,因为用于任何创建带有 __annotations__ 的新类的 Python 机制都可以应用 dataclass() 函数将该类转换为数据类。提供此功能是为了方便。例如:

    C = make_dataclass('C',
                       [('x', int),
                         'y',
                        ('z', int, field(default=5))],
                       namespace={'add_one': lambda self: self.x + 1})
    

    等价于

    @dataclass
    class C:
        x: int
        y: 'typing.Any'
        z: int = 5
    
        def add_one(self):
            return self.x + 1
    
  • dataclasses.replace ( instance , **changes )

    创建一个 instance 相同类型的新对象,用 changes 中的值替换字段。如果 instance 不是数据类,则引发 TypeError 。如果 changes 中的值没有指定字段,则引发 TypeError

    新返回的对象通过调用数据类的 __init__() 方法创建。这确保了如果存在 __post_init__() ,其也被调用。

    如果存在没有默认值的仅初始化变量,必须在调用 replace() 时指定,以便它们可以传递给 __init__() __post_init__()

    changes 包含任何定义为 init=False 的字段是错误的。在这种情况下会引发 ValueError

    提前提醒 init=False 字段在调用 replace() 时的工作方式。如果它们完全被初始化的话,它们不是从源对象复制的,而是在 __post_init__() 中初始化。估计 init=False 字段很少能被正确地使用。如果使用它们,那么使用备用类构造函数或者可能是处理实例复制的自定义 replace() (或类似命名的)方法可能是明智的。

  • dataclasses.is_dataclass ( class_or_instance )

    如果其形参为 dataclass 或其实例则返回 True ,否则返回 False

    如果你需要知道一个类是否是一个数据类的实例(而不是一个数据类本身),那么再添加一个 not isinstance(obj, type) 检查:

    def is_dataclass_instance(obj):
        return is_dataclass(obj) and not isinstance(obj, type)
    

初始化后处理

生成的 __init__() 代码将调用一个名为 __post_init__() 的方法,如果在类上已经定义了 __post_init__() 。它通常被称为 self.__post_init__() 。但是,如果定义了任何 InitVar 字段,它们也将按照它们在类中定义的顺序传递给 __post_init__() 。 如果没有 __ init__() 方法生成,那么 __post_init__() 将不会被自动调用。

在其他用途中,这允许初始化依赖于一个或多个其他字段的字段值。例如:

@dataclass
class C:
    a: float
    b: float
    c: float = field(init=False)

    def __post_init__(self):
        self.c = self.a + self.b

有关将参数传递给 __post_init__() 的方法,请参阅下面有关仅初始化变量的段落。另请参阅关于 replace() 处理 init=False 字段的警告。

类变量

两个地方 dataclass() 实际检查字段类型的之一是确定字段是否是如 PEP 526 所定义的类变量。它通过检查字段的类型是否为 typing.ClassVar 来完成此操作。如果一个字段是一个 ClassVar ,它将被排除在考虑范围之外,并被数据类机制忽略。这样的 ClassVar 伪字段不会由模块级的 fields() 函数返回。

仅初始化变量

另一个 dataclass() 检查类型注解地方是为了确定一个字段是否是一个仅初始化变量。它通过查看字段的类型是否为 dataclasses.InitVar 类型来实现。如果一个字段是一个 InitVar ,它被认为是一个称为仅初始化字段的伪字段。因为它不是一个真正的字段,所以它不会被模块级的 fields() 函数返回。仅初始化字段作为参数添加到生成的 __init__() 方法中,并传递给可选的 __post_init__() 方法。数据类不会使用它们。

例如,假设一个字段将从数据库初始化,如果在创建类时未提供其值:

@dataclass
class C:
    i: int
    j: int = None
    database: InitVar[DatabaseType] = None

    def __post_init__(self, database):
        if self.j is None and database is not None:
            self.j = database.lookup('j')

c = C(10, database=my_database)

在这种情况下, fields() 将返回 ij Field 对象,但不包括 database

冻结的实例

无法创建真正不可变的 Python 对象。但是,通过将 frozen=True 传递给 dataclass() 装饰器,你可以模拟不变性。在这种情况下,数据类将向类添加 __setattr__() __delattr__() 方法。 些方法在调用时会引发 FrozenInstanceError

使用 frozen=True 时会有很小的性能损失: __ init__() 不能使用简单的赋值来初始化字段,并必须使用 object.__ setattr__()

继承

当数组由 dataclass() 装饰器创建时,它会查看反向 MRO 中的所有类的基类(即从 object 开始 ),并且对于它找到的每个数据类, 将该基类中的字段添加到字段的有序映射中。添加完所有基类字段后,它会将自己的字段添加到有序映射中。所有生成的方法都将使用这种组合的,计算的有序字段映射。由于字段是按插入顺序排列的,因此派生类会重载基类。一个例子:

@dataclass
class Base:
    x: Any = 15.0
    y: int = 0

@dataclass
class C(Base):
    z: int = 10
    x: int = 15

最后的字段列表依次是 xyzx 的最终类型是 int ,如类 C 中所指定的那样。

C 生成的 __init__() 方法看起来像:

def __init__(self, x: int = 15, y: int = 0, z: int = 10):

默认工厂函数

如果一个 field() 指定了一个 default_factory ,当需要该字段的默认值时,将使用零参数调用它。例如,要创建列表的新实例,请使用:

mylist: list = field(default_factory=list)

如果一个字段被排除在 __init__() 之外(使用 init=False )并且字段也指定 default_factory ,则默认的工厂函数将始终从生成的 __ init__() 函数调用。发生这种情况是因为没有其他方法可以为字段提供初始值。

可变的默认值

Python 在类属性中存储默认成员变量值。思考这个例子,不使用数据类:

class C:
    x = []
    def add(self, element):
        self.x.append(element)

o1 = C()
o2 = C()
o1.add(1)
o2.add(2)
assert o1.x == [1, 2]
assert o1.x is o2.x

请注意,类 C 的两个实例共享相同的类变量 x ,如预期的那样。

使用数据类, 如果 此代码有效:

@dataclass
class D:
    x: List = []
    def add(self, element):
        self.x += element

它生成的代码类似于:

class D:
    x = []
    def __init__(self, x=x):
        self.x = x
    def add(self, element):
        self.x += element

assert D().x is D().x

这与使用类 C 的原始示例具有相同的问题。也就是说,在创建类实例时没有为 x 指定值的类 D 的两个实例将共享相同的 x 副本。由于数据类只使用普通的 Python 类创建,因此它们也会共享此行为。数据类没有通用的方法来检测这种情况。相反,如果数据类检测到类型为 listdictset 的默认参数,则会引发 TypeError 。这是一个部分解决方案,但它可以防止许多常见错误。

使用默认工厂函数是一种创建可变类型新实例的方法,并将其作为字段的默认值:

@dataclass
class D:
    x: list = field(default_factory=list)

assert D().x is not D().x

异常

  • exception dataclasses.FrozenInstanceError

    在使用 frozen=True 定义的数据类上调用隐式定义的 __setattr__() __delattr__() 时引发。

总结

我们经常会遇到这样的情况:

比如我们设计一个商品类:

class Product(object):
    def __init__(self, id=None, author_id=None, category_id=None, brand_id=None, spu_id=None, 
                 title=None, item_id=None, n_comments=None, creation_time=None, update_time=None, 
                 source='', parent_id=0, ancestor_id=0): 
        self.id = id
        self.author_id = author_id
        self.category_id = category_id
        self.brand_id = brand_id
        self.spu_id = spu_id
        self.title = title
        self.item_id = item_id
        ...

__init__ 方法包含了众多参数,

应用一

我们在打印的时候不希望打印所有的参数

通常的做法是,重写 __repr__ 方法

def __repr__(self):
        return '{}(id={}, author_id={}, category_id={}, brand_id={})'.format(
        self.__class__.__name__, self.id, self.author_id, self.category_id, 
        self.brand_id)
#对象打印
p = Product()
print(p)   
>>> Product(id=1, author_id=100001, category_id=2003, brand_id=20)

问题来了,我需要在每个类里重写这个方法,那该怎么处理?

应用二

对象比较,有时候需要判断2个对象是否相等甚至大小(例如用于展示顺序)

通常的做法是 重写对应的 __eq__ , __lt__ 方法

def __eq__(self, other):
    if not isinstance(other, self.__class__):
        return NotImplemented
    return (self.id, self.author_id, self.category_id, self.brand_id) == (
        other.id, other.author_id, other.category_id, other.brand_id)
def __lt__(self, other):
    if not isinstance(other, self.__class__):
        return NotImplemented
    return (self.id, self.author_id, self.category_id, self.brand_id) < (
        other.id, other.author_id, other.category_id, other.brand_id)

应用三

对象去重,重写 __hash__ 方法

def __hash__(self):
    return hash((self.id, self.author_id, self.category_id, self.brand_id))

应用四

导出字典格式

def to_dict(self):
    return {
        'id': self.id,
        'author_id': self.author_id,
        'category_id': self.category_id,
        ...
    }

但是我并不想打印所有的属性,于是有下面的做法

def to_dict(self):
    self._a = 1
    return vars(self)

等等,python难道没有解决方案。

答案是肯定的,当然有,那就是 dataclasses

dataclasses 解决上面的问题

from datetime import datetime
from dataclasses import dataclass, field


@dataclass(unsafe_hash=True, order=True)
class Product(object):
    id: int
    author_id: int
    brand_id: int
    spu_id: int
    title: str = field(hash=False, repr=False, compare=False)
    item_id: int = field(hash=False, repr=False, compare=False)
    n_comments: int = field(hash=False, repr=False, compare=False)
    creation_time: datetime = field(default=None, repr=False, compare=False,hash=False)
    update_time: datetime = field(default=None, repr=False, compare=False, hash=False)
    source: str = field(default='', repr=False, compare=False, hash=False)
    parent_id: int = field(default=0, repr=False, compare=False, hash=False)
    ancestor_id: int = field(default=0, repr=False, compare=False, hash=False)


p1 = Product(1, 100001, 2003, 20, 1002393002, '这是一个测试商品1', 2000001, 100, None, 1)

p2 = Product(1, 100001, 2003, 20, 1002393002, '这是一个测试商品2', 2000001, 100, None, 2)

p3 = Product(3, 100001, 2003, 20, 1002393002, '这是一个测试商品3', 2000001, 100, None, 3)

print(p1)

print(p1 == p2)

print(p1 > p2)

print({p1, p2, p3})

from dataclasses import asdict
asdict(p1)

想了解更多的使用方法,强烈建议阅读源码。

最后,希望武汉的兄弟们能保护好自己,健康才是最重要的。加油!!!


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK