Skip to content

Commit 7a46246

Browse files
surculus12thejtshow
authored andcommitted
Resolve #4530: Added vector union support for Python
1 parent 067bfdb commit 7a46246

File tree

13 files changed

+667
-2
lines changed

13 files changed

+667
-2
lines changed

scripts/generate_code.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def glob(path, pattern):
225225
)
226226

227227
flatc(
228-
BASE_OPTS + CPP_OPTS + CS_OPTS + JAVA_OPTS + KOTLIN_OPTS + PHP_OPTS,
228+
BASE_OPTS + CPP_OPTS + CS_OPTS + JAVA_OPTS + KOTLIN_OPTS + PHP_OPTS + PYTHON_OPTS,
229229
prefix="union_vector",
230230
schema="union_vector/union_vector.fbs",
231231
)

src/idl_gen_python.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1076,6 +1076,39 @@ class PythonGenerator : public BaseGenerator {
10761076
code += Indent + Indent + "return None\n\n";
10771077
}
10781078

1079+
// Get the value of a vector's union member.
1080+
void GetMemberOfVectorOfUnion(const StructDef &struct_def,
1081+
const FieldDef &field,
1082+
std::string *code_ptr) const {
1083+
auto &code = *code_ptr;
1084+
auto vectortype = field.value.type.VectorType();
1085+
1086+
GenReceiver(struct_def, code_ptr);
1087+
code += namer_.Method(field);
1088+
code += "(self, j):" + OffsetPrefix(field);
1089+
code += Indent + Indent + Indent + "x = self._tab.Vector(o)\n";
1090+
code += Indent + Indent + Indent;
1091+
code += "x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * ";
1092+
code += NumToString(InlineSize(vectortype)) + "\n";
1093+
code += Indent + Indent + Indent;
1094+
code += "x -= self._tab.Pos\n";
1095+
1096+
// TODO(rw): this works and is not the good way to it:
1097+
bool is_native_table = TypeName(field) == "*flatbuffers.Table";
1098+
if (is_native_table) {
1099+
code +=
1100+
Indent + Indent + Indent + "from flatbuffers.table import Table\n";
1101+
} else if (parser_.opts.include_dependence_headers) {
1102+
code += Indent + Indent + Indent;
1103+
code += "from " + GenPackageReference(field.value.type) + " import " +
1104+
TypeName(field) + "\n";
1105+
}
1106+
code += Indent + Indent + Indent + "obj = Table(bytearray(), 0)\n";
1107+
code += Indent + Indent + Indent + GenGetter(field.value.type);
1108+
code += "obj, x)\n" + Indent + Indent + Indent + "return obj\n";
1109+
code += Indent + Indent + "return None\n\n";
1110+
}
1111+
10791112
// Get the value of a vector's non-struct member. Uses a named return
10801113
// argument to conveniently set the zero value for the result.
10811114
void GetMemberOfVectorOfNonStruct(const StructDef &struct_def,
@@ -1521,6 +1554,8 @@ class PythonGenerator : public BaseGenerator {
15211554
auto vectortype = field.value.type.VectorType();
15221555
if (vectortype.base_type == BASE_TYPE_STRUCT) {
15231556
GetMemberOfVectorOfStruct(struct_def, field, code_ptr, imports);
1557+
} else if (vectortype.base_type == BASE_TYPE_UNION) {
1558+
GetMemberOfVectorOfUnion(struct_def, field, code_ptr);
15241559
} else {
15251560
GetMemberOfVectorOfNonStruct(struct_def, field, code_ptr);
15261561
if (parser_.opts.python_gen_numpy) {
@@ -1780,6 +1815,9 @@ class PythonGenerator : public BaseGenerator {
17801815
import_list->insert("import " + package_reference);
17811816
}
17821817
field_type = "Optional[List[" + field_type + "]";
1818+
} else if (base_type == BASE_TYPE_UNION) {
1819+
GenUnionInit(field, field_type_ptr, import_list, import_typing_list);
1820+
field_type = "Optional[List[" + field_type + "]]";
17831821
} else {
17841822
field_type = "Optional[List[" +
17851823
GetBasePythonTypeForScalarAndString(base_type) + "]]";

src/idl_parser.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2702,7 +2702,7 @@ bool Parser::SupportsAdvancedUnionFeatures() const {
27022702
return (opts.lang_to_generate &
27032703
~(IDLOptions::kCpp | IDLOptions::kTs | IDLOptions::kPhp |
27042704
IDLOptions::kJava | IDLOptions::kCSharp | IDLOptions::kKotlin |
2705-
IDLOptions::kBinary | IDLOptions::kSwift | IDLOptions::kNim |
2705+
IDLOptions::kBinary | IDLOptions::kSwift | IDLOptions::kPython | IDLOptions::kNim |
27062706
IDLOptions::kJson | IDLOptions::kKotlinKmp)) == 0;
27072707
}
27082708

tests/py_test.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@
5454
import monster_test_generated # the one-file version
5555
import optional_scalars
5656
import optional_scalars.ScalarStuff
57+
import union_vector
58+
import union_vector.Attacker
59+
import union_vector.Character
60+
import union_vector.Movie
5761

5862

5963
def create_namespace_shortcut(is_onefile):
@@ -301,6 +305,51 @@ def test_default_values_with_pack_and_unpack(self):
301305
self.assertEqual(monster2.VectorOfEnumsLength(), 0)
302306
self.assertTrue(monster2.VectorOfEnumsIsNone())
303307

308+
def test_union_vectors_with_pack_and_unpack(self):
309+
b = flatbuffers.Builder(0)
310+
311+
union_vector.Attacker.Start(b)
312+
union_vector.Attacker.AddSwordAttackDamage(b, 1)
313+
attacker_offset = union_vector.Attacker.End(b)
314+
315+
union_vector.Attacker.Start(b)
316+
union_vector.Attacker.AddSwordAttackDamage(b, 2)
317+
attacker_offset2 = union_vector.Attacker.End(b)
318+
319+
characters = [attacker_offset, attacker_offset2]
320+
character_types = [union_vector.Character.Character.MuLan, union_vector.Character.Character.MuLan]
321+
322+
union_vector.Movie.StartCharactersTypeVector(b, len(character_types))
323+
for character_type in reversed(character_types):
324+
b.PrependByte(character_type)
325+
character_types_offset = b.EndVector()
326+
327+
union_vector.Movie.StartCharactersVector(b, len(characters))
328+
for character in reversed(characters):
329+
b.PrependUOffsetTRelative(character)
330+
characters_offset = b.EndVector()
331+
332+
union_vector.Movie.Start(b)
333+
union_vector.Movie.AddMainCharacterType(b, 0)
334+
union_vector.Movie.AddMainCharacter(b, 0)
335+
union_vector.Movie.AddCharactersType(b, character_types_offset)
336+
union_vector.Movie.AddCharacters(b, characters_offset)
337+
movie_offset = union_vector.Movie.End(b)
338+
b.Finish(movie_offset)
339+
340+
buf = b.Output()
341+
movie = union_vector.Movie.Movie.GetRootAsMovie(buf, 0)
342+
343+
self.assertEqual(movie.CharactersTypeLength(), len(character_types))
344+
self.assertEqual(movie.CharactersLength(), len(characters))
345+
self.assertEqual(movie.CharactersType(0), character_types[0])
346+
347+
character = union_vector.Attacker.Attacker()
348+
character.Init(movie.Characters(0).Bytes, movie.Characters(0).Pos)
349+
self.assertEqual(character.SwordAttackDamage(), 1)
350+
character.Init(movie.Characters(1).Bytes, movie.Characters(1).Pos)
351+
self.assertEqual(character.SwordAttackDamage(), 2)
352+
304353
def test_optional_scalars_with_pack_and_unpack(self):
305354
""" Serializes and deserializes between a buffer with optional values (no
306355
specific values are filled when the buffer is created) and its python

tests/union_vector/Attacker.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# automatically generated by the FlatBuffers compiler, do not modify
2+
3+
# namespace:
4+
5+
import flatbuffers
6+
from flatbuffers.compat import import_numpy
7+
np = import_numpy()
8+
9+
class Attacker(object):
10+
__slots__ = ['_tab']
11+
12+
@classmethod
13+
def GetRootAs(cls, buf, offset=0):
14+
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
15+
x = Attacker()
16+
x.Init(buf, n + offset)
17+
return x
18+
19+
@classmethod
20+
def GetRootAsAttacker(cls, buf, offset=0):
21+
"""This method is deprecated. Please switch to GetRootAs."""
22+
return cls.GetRootAs(buf, offset)
23+
@classmethod
24+
def AttackerBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
25+
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4D\x4F\x56\x49", size_prefixed=size_prefixed)
26+
27+
# Attacker
28+
def Init(self, buf, pos):
29+
self._tab = flatbuffers.table.Table(buf, pos)
30+
31+
# Attacker
32+
def SwordAttackDamage(self):
33+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
34+
if o != 0:
35+
return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
36+
return 0
37+
38+
def AttackerStart(builder): builder.StartObject(1)
39+
def Start(builder):
40+
return AttackerStart(builder)
41+
def AttackerAddSwordAttackDamage(builder, swordAttackDamage): builder.PrependInt32Slot(0, swordAttackDamage, 0)
42+
def AddSwordAttackDamage(builder, swordAttackDamage):
43+
return AttackerAddSwordAttackDamage(builder, swordAttackDamage)
44+
def AttackerEnd(builder): return builder.EndObject()
45+
def End(builder):
46+
return AttackerEnd(builder)
47+
48+
class AttackerT(object):
49+
50+
# AttackerT
51+
def __init__(self):
52+
self.swordAttackDamage = 0 # type: int
53+
54+
@classmethod
55+
def InitFromBuf(cls, buf, pos):
56+
attacker = Attacker()
57+
attacker.Init(buf, pos)
58+
return cls.InitFromObj(attacker)
59+
60+
@classmethod
61+
def InitFromObj(cls, attacker):
62+
x = AttackerT()
63+
x._UnPack(attacker)
64+
return x
65+
66+
# AttackerT
67+
def _UnPack(self, attacker):
68+
if attacker is None:
69+
return
70+
self.swordAttackDamage = attacker.SwordAttackDamage()
71+
72+
# AttackerT
73+
def Pack(self, builder):
74+
AttackerStart(builder)
75+
AttackerAddSwordAttackDamage(builder, self.swordAttackDamage)
76+
attacker = AttackerEnd(builder)
77+
return attacker

tests/union_vector/BookReader.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# automatically generated by the FlatBuffers compiler, do not modify
2+
3+
# namespace:
4+
5+
import flatbuffers
6+
from flatbuffers.compat import import_numpy
7+
np = import_numpy()
8+
9+
class BookReader(object):
10+
__slots__ = ['_tab']
11+
12+
@classmethod
13+
def SizeOf(cls):
14+
return 4
15+
16+
# BookReader
17+
def Init(self, buf, pos):
18+
self._tab = flatbuffers.table.Table(buf, pos)
19+
20+
# BookReader
21+
def BooksRead(self): return self._tab.Get(flatbuffers.number_types.Int32Flags, self._tab.Pos + flatbuffers.number_types.UOffsetTFlags.py_type(0))
22+
23+
def CreateBookReader(builder, booksRead):
24+
builder.Prep(4, 4)
25+
builder.PrependInt32(booksRead)
26+
return builder.Offset()
27+
28+
29+
class BookReaderT(object):
30+
31+
# BookReaderT
32+
def __init__(self):
33+
self.booksRead = 0 # type: int
34+
35+
@classmethod
36+
def InitFromBuf(cls, buf, pos):
37+
bookReader = BookReader()
38+
bookReader.Init(buf, pos)
39+
return cls.InitFromObj(bookReader)
40+
41+
@classmethod
42+
def InitFromObj(cls, bookReader):
43+
x = BookReaderT()
44+
x._UnPack(bookReader)
45+
return x
46+
47+
# BookReaderT
48+
def _UnPack(self, bookReader):
49+
if bookReader is None:
50+
return
51+
self.booksRead = bookReader.BooksRead()
52+
53+
# BookReaderT
54+
def Pack(self, builder):
55+
return CreateBookReader(builder, self.booksRead)

tests/union_vector/Character.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# automatically generated by the FlatBuffers compiler, do not modify
2+
3+
# namespace:
4+
5+
class Character(object):
6+
NONE = 0
7+
MuLan = 1
8+
Rapunzel = 2
9+
Belle = 3
10+
BookFan = 4
11+
Other = 5
12+
Unused = 6
13+
14+
def CharacterCreator(unionType, table):
15+
from flatbuffers.table import Table
16+
if not isinstance(table, Table):
17+
return None
18+
if unionType == Character().MuLan:
19+
import Attacker
20+
return Attacker.AttackerT.InitFromBuf(table.Bytes, table.Pos)
21+
if unionType == Character().Rapunzel:
22+
import Rapunzel
23+
return Rapunzel.RapunzelT.InitFromBuf(table.Bytes, table.Pos)
24+
if unionType == Character().Belle:
25+
import BookReader
26+
return BookReader.BookReaderT.InitFromBuf(table.Bytes, table.Pos)
27+
if unionType == Character().BookFan:
28+
import BookReader
29+
return BookReader.BookReaderT.InitFromBuf(table.Bytes, table.Pos)
30+
if unionType == Character().Other:
31+
tab = Table(table.Bytes, table.Pos)
32+
union = tab.String(table.Pos)
33+
return union
34+
if unionType == Character().Unused:
35+
tab = Table(table.Bytes, table.Pos)
36+
union = tab.String(table.Pos)
37+
return union
38+
return None

tests/union_vector/FallingTub.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# automatically generated by the FlatBuffers compiler, do not modify
2+
3+
# namespace:
4+
5+
import flatbuffers
6+
from flatbuffers.compat import import_numpy
7+
np = import_numpy()
8+
9+
class FallingTub(object):
10+
__slots__ = ['_tab']
11+
12+
@classmethod
13+
def SizeOf(cls):
14+
return 4
15+
16+
# FallingTub
17+
def Init(self, buf, pos):
18+
self._tab = flatbuffers.table.Table(buf, pos)
19+
20+
# FallingTub
21+
def Weight(self): return self._tab.Get(flatbuffers.number_types.Int32Flags, self._tab.Pos + flatbuffers.number_types.UOffsetTFlags.py_type(0))
22+
23+
def CreateFallingTub(builder, weight):
24+
builder.Prep(4, 4)
25+
builder.PrependInt32(weight)
26+
return builder.Offset()
27+
28+
29+
class FallingTubT(object):
30+
31+
# FallingTubT
32+
def __init__(self):
33+
self.weight = 0 # type: int
34+
35+
@classmethod
36+
def InitFromBuf(cls, buf, pos):
37+
fallingTub = FallingTub()
38+
fallingTub.Init(buf, pos)
39+
return cls.InitFromObj(fallingTub)
40+
41+
@classmethod
42+
def InitFromObj(cls, fallingTub):
43+
x = FallingTubT()
44+
x._UnPack(fallingTub)
45+
return x
46+
47+
# FallingTubT
48+
def _UnPack(self, fallingTub):
49+
if fallingTub is None:
50+
return
51+
self.weight = fallingTub.Weight()
52+
53+
# FallingTubT
54+
def Pack(self, builder):
55+
return CreateFallingTub(builder, self.weight)

tests/union_vector/Gadget.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# automatically generated by the FlatBuffers compiler, do not modify
2+
3+
# namespace:
4+
5+
class Gadget(object):
6+
NONE = 0
7+
FallingTub = 1
8+
HandFan = 2
9+
10+
def GadgetCreator(unionType, table):
11+
from flatbuffers.table import Table
12+
if not isinstance(table, Table):
13+
return None
14+
if unionType == Gadget().FallingTub:
15+
import FallingTub
16+
return FallingTub.FallingTubT.InitFromBuf(table.Bytes, table.Pos)
17+
if unionType == Gadget().HandFan:
18+
import HandFan
19+
return HandFan.HandFanT.InitFromBuf(table.Bytes, table.Pos)
20+
return None

0 commit comments

Comments
 (0)