diff --git a/pkg/model/types.go b/pkg/model/types.go index 0930039c..099d4a90 100644 --- a/pkg/model/types.go +++ b/pkg/model/types.go @@ -11,29 +11,32 @@ type Model struct { } type Metadata struct { - Author string `json:"author,omitempty" yaml:"author,omitempty"` - Created time.Time `json:"created,omitempty" yaml:"created,omitempty"` - Description string `json:"description,omitempty" yaml:"description,omitempty"` - Tags []string `json:"tags,omitempty" yaml:"tags,omitempty"` - Labels map[string]string `json:"labels,omitempty" yaml:"labels,omitempty"` - Format string `json:"format,omitempty" yaml:"format,omitempty"` - // GPUType is for TensorRT format only, it must be set when extract signature or serve + Author string `json:"author,omitempty" yaml:"author,omitempty"` + Created time.Time `json:"created,omitempty" yaml:"created,omitempty"` + Description string `json:"description,omitempty" yaml:"description,omitempty"` + Tags []string `json:"tags,omitempty" yaml:"tags,omitempty"` + Labels map[string]string `json:"labels,omitempty" yaml:"labels,omitempty"` + Format string `json:"format,omitempty" yaml:"format,omitempty"` + // GPUType is for TensorRT format only, it must be set when extract signature or serve // as a online service, otherwise, it can not extract or serve as a service. // for other model format, you can set empty string or not set. - GPUType string `json:"gpuType,omitempty" yaml:"gpuType,omitempty"` - Framework string `json:"framework,omitempty" yaml:"framework,omitempty"` - Metrics []Metric `json:"metrics,omitempty" yaml:"metrics,omitempty"` - Hyperparameters []Hyperparameter `json:"hyperparameters,omitempty" yaml:"hyperparameters,omitempty"` - Signature Signature `json:"signature,omitempty" yaml:"signature,omitempty"` - Training Training `json:"training,omitempty" yaml:"training,omitempty"` - Dataset Dataset `json:"dataset,omitempty" yaml:"dataset,omitempty"` + GPUType string `json:"gpuType,omitempty" yaml:"gpuType,omitempty"` + Framework string `json:"framework,omitempty" yaml:"framework,omitempty"` + Metrics []Metric `json:"metrics,omitempty" yaml:"metrics,omitempty"` + Hyperparameters []Hyperparameter `json:"hyperparameters,omitempty" yaml:"hyperparameters,omitempty"` + Signature *Signature `json:"signature,omitempty" yaml:"signature,omitempty"` + Training *Training `json:"training,omitempty" yaml:"training,omitempty"` + Dataset *Dataset `json:"dataset,omitempty" yaml:"dataset,omitempty"` + DirectoryStructure []string `json:"directoryStructure,omitempty" yaml:"directoryStructure,omitempty"` } +// Metric is the type for training metric (e.g. acc). type Metric struct { Name string `json:"name"` Value string `json:"value"` } +// Hyperparameter is the type for training hyperparameter (e.g. learning rate). type Hyperparameter struct { Name string `json:"name"` Value string `json:"value"` diff --git a/pkg/saver/saver.go b/pkg/saver/saver.go index 4f46ee05..1f8809e9 100644 --- a/pkg/saver/saver.go +++ b/pkg/saver/saver.go @@ -48,9 +48,13 @@ func (d Saver) Save(path string) (*model.Model, error) { // Save the model from /model. buf := &bytes.Buffer{} - if err := Tar(filepath.Join(path, consts.ORMBModelDirectory), buf); err != nil { + directoryStructure, err := TarAndGetDirectoryStructure( + filepath.Join(path, consts.ORMBModelDirectory), buf) + if err != nil { return nil, err } + // Set directoryStructure for the model metadata. + metadata.DirectoryStructure = directoryStructure m := &model.Model{ Metadata: metadata, @@ -61,12 +65,14 @@ func (d Saver) Save(path string) (*model.Model, error) { return m, nil } -// Tar is copied from https://medium.com/@skdomino/taring-untaring-files-in-go-6b07cf56bc07. -func Tar(src string, writers ...io.Writer) error { +// TarAndGetDirectoryStructure is copied from https://medium.com/@skdomino/taring-untaring-files-in-go-6b07cf56bc07. +func TarAndGetDirectoryStructure( + src string, writers ...io.Writer) ([]string, error) { + structure := make([]string, 0) // ensure the src actually exists before trying to tar it if _, err := os.Stat(src); err != nil { - return fmt.Errorf("Unable to tar files - %v", err.Error()) + return nil, fmt.Errorf("Unable to tar files - %v", err.Error()) } mw := io.MultiWriter(writers...) @@ -78,7 +84,7 @@ func Tar(src string, writers ...io.Writer) error { defer tw.Close() // walk path - return filepath.Walk(src, func(file string, fi os.FileInfo, err error) error { + err := filepath.Walk(src, func(file string, fi os.FileInfo, err error) error { // return on any error if err != nil { @@ -99,7 +105,11 @@ func Tar(src string, writers ...io.Writer) error { parentDir := filepath.Dir(src) // update the name to correctly reflect the desired destination when untaring - header.Name = strings.TrimPrefix(strings.Replace(file, parentDir, "", -1), string(filepath.Separator)) + header.Name = strings.TrimPrefix( + strings.Replace(file, parentDir, "", -1), string(filepath.Separator)) + + // Add filename to the directory structure. + structure = append(structure, header.Name) // write the header if err := tw.WriteHeader(header); err != nil { @@ -123,4 +133,6 @@ func Tar(src string, writers ...io.Writer) error { return nil }) + + return structure, err }