How to define a dataclass so that each of its attributes is a list of the corresponding attribute values of the dataclass objects it contains?
You could expose the grouped values through an extra, read-only `_pos` attribute that is computed from `positions`:
from dataclasses import dataclass, field, fields
from typing import List
@dataclass
class Position:
    """A named point with ``lon`` and ``lat`` coordinates."""

    name: str
    lon: float
    lat: float


@dataclass
class Section:
    """A list of positions plus a read-only, per-field view of their values.

    ``section._pos`` maps each dataclass field name of the items in
    ``positions`` to the list of that field's values, in list order. The view
    is recomputed on every access, so it reflects later changes to
    ``positions``; assigning to ``_pos`` raises ``AttributeError``.
    """

    positions: List[Position]
    # Kept as a (non-init, non-repr) dataclass field so it still takes part in
    # the generated ``__eq__`` and in ``dataclasses.asdict``. Its value comes
    # from the read-only property installed on the class right after this body.
    _pos: dict = field(init=False, repr=False)

    def __post_init__(self):
        """Do nothing: ``_pos`` is a class-level property, not per-instance state.

        Kept so subclasses that call ``super().__post_init__()`` keep working.
        """

    def _get_positions(self):
        """Group the items' field values by field name.

        Field names are taken from ``positions[0]``, which must be a dataclass
        instance; every other item must expose the same attribute names.

        Returns:
            dict: ``{field_name: [value, ...]}`` with one value per position,
            or ``{}`` when ``positions`` is empty.
        """
        if not self.positions:
            # There is no item to read the field names from.
            return {}
        # Named `name`, not `field`, so the imported `dataclasses.field` is not shadowed.
        names = [f.name for f in fields(self.positions[0])]
        return {name: [getattr(p, name) for p in self.positions] for name in names}


# Install the read-only view once, after @dataclass has processed the class,
# instead of re-assigning it on every instantiation. @dataclass deletes the
# class attribute of a field declared without a default, so this property
# does not collide with the `_pos` field declaration above.
Section._pos = property(Section._get_positions)
# Build three sample positions and group them into one section.
pos1, pos2, pos3 = (
    Position('a', 52, 10),
    Position('b', 46, -10),
    Position('c', 45, -10),
)
sec = Section([pos1, pos2, pos3])

print(sec.positions)
# One output line per attribute: the value lists collected in `sec._pos`.
print(*(sec._pos[attr] for attr in ('name', 'lon', 'lat')), sep='\n')
Out:
[Position(name='a', lon=52, lat=10), Position(name='b', lon=46, lat=-10), Position(name='c', lon=45, lat=-10)]
['a', 'b', 'c']
[52, 46, 45]
[10, -10, -10]
Edit:
If you need a more generic solution, you could override `__getattr__`:
from dataclasses import dataclass, field, fields
from typing import List
@dataclass
class Position:
    """A named point with ``lon`` and ``lat`` coordinates."""

    name: str
    lon: float
    lat: float


@dataclass
class Section:
    """A list of positions whose field values can be read as plural attributes.

    For every dataclass field ``f`` of the items in ``positions``,
    ``section.<f>s`` returns the list of that field's values, in list order
    (e.g. ``section.names``, ``section.lons``).
    """

    positions: List[Position]

    def __getattr__(self, keyName):
        """Resolve ``<field>s`` to the list of ``<field>`` values across ``positions``.

        Python calls this only when normal attribute lookup fails.

        Raises:
            AttributeError: if ``keyName`` is not a pluralized field name of the
                position items, or if there are no positions to inspect. Raising
                (instead of returning ``None``) keeps ``hasattr``/``getattr``
                with a default, ``copy`` and ``pickle`` working correctly.
        """
        # Read `positions` from the instance dict: if it is not set yet (e.g.
        # during copy/unpickling, before __init__ runs), `self.positions` would
        # re-enter __getattr__ and recurse without end.
        positions = self.__dict__.get("positions")
        if positions:
            for f in fields(positions[0]):
                if f"{f.name}s" == keyName:
                    return [getattr(x, f.name) for x in positions]
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {keyName!r}"
        )
# Build three sample positions and read them back through the plural attributes.
pos1, pos2, pos3 = (
    Position('a', 52, 10),
    Position('b', 46, -10),
    Position('c', 45, -10),
)
sec = Section([pos1, pos2, pos3])
print(sec.names, sec.lons, sec.lats, sep='\n')
Out:
['a', 'b', 'c']
[52, 46, 45]
[10, -10, -10]