ref: 1711e971655439c02423ceeec58e3078d2eee5a3
parent: 20568812ae92bb148fe3fb0190b7629f1c4d0b96
author: Jan Buethe <jbuethe@amazon.de>
date: Mon May 6 10:11:59 EDT 2024
fixed enable_binary_blob option for CWriter
--- a/dnn/torch/lossgen/export_lossgen.py
+++ b/dnn/torch/lossgen/export_lossgen.py
@@ -52,7 +52,7 @@
message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"- writer = CWriter(os.path.join(args.output_dir, "lossgen_data"), message=message, model_struct_name='LossGen')
+ writer = CWriter(os.path.join(args.output_dir, "lossgen_data"), message=message, model_struct_name='LossGen', enable_binary_blob=False)
writer.header.write(
f"""
#include "opus_types.h"
--- a/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py
+++ b/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py
@@ -120,50 +120,49 @@
def _finalize_header(self):
# create model type
- if self.enable_binary_blob:
- if self.add_typedef:
- self.header.write(f"\ntypedef struct {{")- else:
- self.header.write(f"\nstruct {self.model_struct_name} {{")- for name, data in self.layer_dict.items():
- layer_type = data[0]
- self.header.write(f"\n {layer_type} {name};")- if self.add_typedef:
- self.header.write(f"\n}} {self.model_struct_name};\n")- else:
- self.header.write(f"\n}};\n")
+ if self.add_typedef:
+ self.header.write(f"\ntypedef struct {{")+ else:
+ self.header.write(f"\nstruct {self.model_struct_name} {{")+ for name, data in self.layer_dict.items():
+ layer_type = data[0]
+ self.header.write(f"\n {layer_type} {name};")+ if self.add_typedef:
+ self.header.write(f"\n}} {self.model_struct_name};\n")+ else:
+ self.header.write(f"\n}};\n")
- init_prototype = f"int init_{self.model_struct_name.lower()}({self.model_struct_name} *model, const WeightArray *arrays)"- self.header.write(f"\n{init_prototype};\n")+ init_prototype = f"int init_{self.model_struct_name.lower()}({self.model_struct_name} *model, const WeightArray *arrays)"+ self.header.write(f"\n{init_prototype};\n") self.header.write(f"\n#endif /* {self.header_guard} */\n")def _finalize_source(self):
- if self.enable_binary_blob:
- # create weight array
- if len(set(self.weight_arrays)) != len(self.weight_arrays):
- raise ValueError("error: detected duplicates in weight arrays")- self.source.write("\n#ifndef USE_WEIGHTS_FILE\n")- self.source.write(f"const WeightArray {self.model_struct_name.lower()}_arrays[] = {{\n")- for name in self.weight_arrays:
- self.source.write(f"#ifdef WEIGHTS_{name}_DEFINED\n")- self.source.write(f' {{"{name}", WEIGHTS_{name}_TYPE, sizeof({name}), {name}}},\n')- self.source.write(f"#endif\n")
- self.source.write(" {NULL, 0, 0, NULL}\n")- self.source.write("};\n")- self.source.write("#endif /* USE_WEIGHTS_FILE */\n")+ # create weight array
+ if len(set(self.weight_arrays)) != len(self.weight_arrays):
+ raise ValueError("error: detected duplicates in weight arrays")+ if self.enable_binary_blob: self.source.write("\n#ifndef USE_WEIGHTS_FILE\n")+ self.source.write(f"const WeightArray {self.model_struct_name.lower()}_arrays[] = {{\n")+ for name in self.weight_arrays:
+ self.source.write(f"#ifdef WEIGHTS_{name}_DEFINED\n")+ self.source.write(f' {{"{name}", WEIGHTS_{name}_TYPE, sizeof({name}), {name}}},\n')+ self.source.write(f"#endif\n")
+ self.source.write(" {NULL, 0, 0, NULL}\n")+ self.source.write("};\n")- # create init function definition
- init_prototype = f"int init_{self.model_struct_name.lower()}({self.model_struct_name} *model, const WeightArray *arrays)"- self.source.write("\n#ifndef DUMP_BINARY_WEIGHTS\n")- self.source.write(f"{init_prototype} {{\n")- for name, data in self.layer_dict.items():
- self.source.write(f" if ({data[1]}) return 1;\n")- self.source.write(" return 0;\n")- self.source.write("}\n")- self.source.write("#endif /* DUMP_BINARY_WEIGHTS */\n")+ if self.enable_binary_blob: self.source.write("#endif /* USE_WEIGHTS_FILE */\n")+
+ # create init function definition
+ init_prototype = f"int init_{self.model_struct_name.lower()}({self.model_struct_name} *model, const WeightArray *arrays)"+ if self.enable_binary_blob: self.source.write("\n#ifndef DUMP_BINARY_WEIGHTS\n")+ self.source.write(f"{init_prototype} {{\n")+ for name, data in self.layer_dict.items():
+ self.source.write(f" if ({data[1]}) return 1;\n")+ self.source.write(" return 0;\n")+ self.source.write("}\n")+ if self.enable_binary_blob:self.source.write("#endif /* DUMP_BINARY_WEIGHTS */\n")def close(self):
--- a/dnn/torch/weight-exchange/wexchange/c_export/common.py
+++ b/dnn/torch/weight-exchange/wexchange/c_export/common.py
@@ -54,7 +54,7 @@
#ifndef USE_WEIGHTS_FILE
'''
)
- writer.weight_arrays.append(name)
+ writer.weight_arrays.append(name)
if reshape_8x4:
vector = vector.reshape((vector.shape[0]//4, 4, vector.shape[1]//8, 8))
--
⑨