more code refactor

This commit is contained in:
Blaise
2024-01-16 17:14:18 +01:00
parent 0d92575115
commit 0d3d47f3c3
44 changed files with 4516 additions and 2623 deletions

View File

@@ -11,22 +11,24 @@ def load_yaml_config(path):
def save_config_to_yaml(config, path):
assert path.endswith('.yaml')
with open(path, 'w') as f:
assert path.endswith(".yaml")
with open(path, "w") as f:
f.write(yaml.dump(config))
f.close()
def write_args(args, path):
args_dict = dict((name, getattr(args, name)) for name in dir(args)
if not name.startswith('_'))
with open(path, 'a') as args_file:
args_file.write('==> torch version: {}\n'.format(torch.__version__))
args_dict = dict(
(name, getattr(args, name)) for name in dir(args) if not name.startswith("_")
)
with open(path, "a") as args_file:
args_file.write("==> torch version: {}\n".format(torch.__version__))
args_file.write(
'==> cudnn version: {}\n'.format(torch.backends.cudnn.version()))
args_file.write('==> Cmd:\n')
"==> cudnn version: {}\n".format(torch.backends.cudnn.version())
)
args_file.write("==> Cmd:\n")
args_file.write(str(sys.argv))
args_file.write('\n==> args:\n')
args_file.write("\n==> args:\n")
for k, v in sorted(args_dict.items()):
args_file.write(' %s: %s\n' % (str(k), str(v)))
args_file.write(" %s: %s\n" % (str(k), str(v)))
args_file.close()