Commit 934f8b0d authored by Christopher Ostrouchov
Browse files

Added training set base

parent 28c216eb
from .vasp import VaspReader
from .lammps import LammpsReader, LammpsWriter
from .mattoolkit import MTKReader
import sys
class FloatParameter:
    """Float with tracking: initial value, current value and bounds.

    :param initial: starting value; stored as ``float`` in both ``initial``
        and ``current``.
    :param bounds: iterable of limits, each converted to ``float`` and kept
        as a list. Defaults to the largest finite float range.
    :param computed: optional zero-argument callable. When set, ``float()``
        of the parameter returns its result instead of ``current``.
    """

    def __init__(self, initial, bounds=(-sys.float_info.max, sys.float_info.max), computed=None):
        self.initial = float(initial)
        self.current = float(initial)
        self.bounds = [float(bound) for bound in bounds]
        self.computed = computed

    def __float__(self):
        if self.computed is None:
            return self.current
        # Python requires __float__ to return a real float; coerce so a
        # callable returning e.g. an int does not make float(self) raise.
        return float(self.computed())

    def __str__(self):
        # Always shows the tracked value, even when ``computed`` is set.
        return str(self.current)
......@@ -8,7 +8,8 @@ import yaml
import numpy as np
from pymatgen.core import Composition
from .schema import PotentialSchema, Parameter
from .schema import PotentialSchema
from .parameter import FloatParameter
class Potential:
......@@ -26,7 +27,7 @@ class Potential:
if not {e.symbol for e in composition.keys()} <= charges.keys():
raise ValueError('charge ballance constrains requires all elements to be defined in charge')
for charge_element, parameter in charges.items():
if isinstance(parameter, Parameter) and parameter.computed == None:
if isinstance(parameter, FloatParameter) and parameter.computed == None:
break
else:
if abs(sum(float(charges[element.symbol]) * amount for element, amount in composition.items())) > 1e-8:
......@@ -44,7 +45,7 @@ class Potential:
elif isinstance(value, (tuple, list)):
for item in value:
_walk(item)
elif isinstance(value, Parameter) and value.computed == None:
elif isinstance(value, FloatParameter) and value.computed == None:
self._parameters.append(value)
_walk(self.schema)
......
from .potential import PotentialSchema
from .training import TrainingSchema
from marshmallow import Schema, ValidationError
from marshmallow.decorators import validates_schema
class BaseSchema(Schema):
    """Schema base class that defaults to ``strict=True`` and rejects unknown input keys."""

    def __init__(self, strict=True, **kwargs):
        # super() must be anchored on this class; anchoring it on ``Schema``
        # skips ``Schema`` itself in the MRO.
        super(BaseSchema, self).__init__(strict=strict, **kwargs)

    @validates_schema(pass_original=True, pass_many=False, skip_on_field_errors=True)
    def check_unknown_fields(self, data, original_data):
        """Raise ``ValidationError`` if the raw input has unexpected keys.

        A key is rejected when it is not a declared field, or when it names a
        field that is marked ``dump_only``.

        :param data: the processed data (unused here).
        :param original_data: the raw input; a single mapping or a list of them.
        :raises ValidationError: with the set of offending keys.
        """
        def check_unknown(original_data_single):
            supplied = set(original_data_single)
            dump_only_keys = {key for key, field in self.fields.items() if field.dump_only}
            unknown_dump = dump_only_keys & supplied
            unknown_invalid = supplied - set(self.fields.keys())
            unknown = unknown_dump | unknown_invalid
            if unknown:
                raise ValidationError('Unknown field', unknown)

        if isinstance(original_data, list):
            for original_data_single in original_data:
                check_unknown(original_data_single)
        else:
            check_unknown(original_data)
# Taken from marshmallow_polyfield package
from marshmallow import ValidationError
from marshmallow.fields import Field
class PolyField(Field):
    """
    A field that (de)serializes to one of many types. Passed in functions
    are called to disambiguate what schema to use for the (de)serialization.
    Intended to assist in working with fields that can contain any subclass
    of a base type.
    """

    def __init__(
            self,
            serialization_schema_selector=None,
            deserialization_schema_selector=None,
            many=False,
            **metadata
    ):
        """
        :param serialization_schema_selector: Function called with the value
            being serialized and its parent object; returns the appropriate
            schema.
        :param deserialization_schema_selector: Function called with the raw
            value being deserialized and its parent data; returns the
            appropriate schema.
        :param many: When true the field value is treated as a sequence and
            every element is (de)serialized individually.
        """
        super(PolyField, self).__init__(**metadata)
        self.many = many
        self.serialization_schema_selector = serialization_schema_selector
        self.deserialization_schema_selector = deserialization_schema_selector

    def _deserialize(self, value, attr, data):
        """Load ``value`` (or each element of it when ``many``) with the selected schema.

        :raises ValidationError: when no usable schema is selected or the
            selected schema reports errors.
        """
        if not self.many:
            value = [value]

        results = []
        for v in value:
            schema = None
            try:
                schema = self.deserialization_schema_selector(v, data)
                # Explicit check instead of ``assert`` so it survives ``python -O``.
                if not hasattr(schema, 'load'):
                    raise TypeError('selector did not return a schema')
            except Exception:
                schema_message = None
                if schema:
                    schema_message = str(type(schema))

                raise ValidationError(
                    "Unable to use schema. Ensure there is a deserialization_schema_selector"
                    " and that it returns a schema when the function is passed in {value_passed}."
                    " This is the class I got. Make sure it is a schema: {class_type}".format(
                        value_passed=v,
                        class_type=schema_message
                    )
                )

            # Use a distinct name: rebinding ``data`` here would hand the
            # previous item's result to the selector on the next iteration.
            loaded, errors = schema.load(v)
            if errors:
                raise ValidationError(errors)
            results.append(loaded)

        if self.many:
            return results
        else:
            # Will be at least one otherwise value would have been None
            return results[0]

    def _serialize(self, value, key, obj):
        """Dump ``value`` (or each element of it when ``many``) with the selected schema.

        :returns: ``None`` when ``value`` is ``None``, otherwise the dumped data.
        :raises TypeError: when selecting the schema or dumping fails.
        """
        if value is None:
            return None
        try:
            if self.many:
                return [self.serialization_schema_selector(v, obj).dump(v).data for v in value]
            else:
                return self.serialization_schema_selector(value, obj).dump(value).data
        except Exception as err:
            raise TypeError(
                'Failed to serialize object. Error: {0}\n'
                ' Ensure the serialization_schema_selector exists and '
                ' returns a Schema and that schema'
                ' can serialize this value {1}'.format(err, value))
import sys
from marshmallow import Schema, fields, validate, ValidationError, pre_load
from marshmallow.decorators import validates_schema
from marshmallow import fields, validate, ValidationError
from .data import element_symbols
class BaseSchema(Schema):
    """Schema base class that defaults to ``strict=True`` and rejects unknown input keys."""

    def __init__(self, strict=True, **kwargs):
        # super() must be anchored on this class; anchoring it on ``Schema``
        # skips ``Schema`` itself in the MRO.
        super(BaseSchema, self).__init__(strict=strict, **kwargs)

    @validates_schema(pass_original=True, pass_many=False, skip_on_field_errors=True)
    def check_unknown_fields(self, data, original_data):
        """Raise ``ValidationError`` if the raw input has unexpected keys.

        A key is rejected when it is not a declared field, or when it names a
        field that is marked ``dump_only``.

        :param data: the processed data (unused here).
        :param original_data: the raw input; a single mapping or a list of them.
        :raises ValidationError: with the set of offending keys.
        """
        def check_unknown(original_data_single):
            supplied = set(original_data_single)
            dump_only_keys = {key for key, field in self.fields.items() if field.dump_only}
            unknown_dump = dump_only_keys & supplied
            unknown_invalid = supplied - set(self.fields.keys())
            unknown = unknown_dump | unknown_invalid
            if unknown:
                raise ValidationError('Unknown field', unknown)

        if isinstance(original_data, list):
            for original_data_single in original_data:
                check_unknown(original_data_single)
        else:
            check_unknown(original_data)
# ======= Potential ========
class Parameter:
    """Float with tracking: initial value, current value and bounds.

    :param initial: starting value; stored as ``float`` in both ``initial``
        and ``current``.
    :param bounds: iterable of limits, each converted to ``float`` and kept
        as a list. Defaults to the largest finite float range.
    :param computed: optional zero-argument callable. When set, ``float()``
        of the parameter returns its result instead of ``current``.
    """

    def __init__(self, initial, bounds=(-sys.float_info.max, sys.float_info.max), computed=None):
        self.initial = float(initial)
        self.current = float(initial)
        self.bounds = [float(bound) for bound in bounds]
        self.computed = computed

    def __float__(self):
        if self.computed is None:
            return self.current
        # Python requires __float__ to return a real float; coerce so a
        # callable returning e.g. an int does not make float(self) raise.
        return float(self.computed())

    def __str__(self):
        # Always shows the tracked value, even when ``computed`` is set.
        return str(self.current)
from .base import BaseSchema
from ..data import element_symbols
from ..parameter import FloatParameter
class ParameterSchema(BaseSchema):
......@@ -60,7 +18,7 @@ class FloatOrParameter(fields.Field):
return float(value)
except (ValueError, TypeError):
schema_load, errors = ParameterSchema().load(value)
return Parameter(**value)
return FloatParameter(**value)
def _validate(self, value):
if value is None:
......@@ -112,15 +70,3 @@ class PotentialSchema(BaseSchema):
version = fields.String(required=True, validate=validate.Equal('v1'))
kind = fields.String(required=True, validate=validate.Equal('Potential'))
spec = fields.Nested(PotentialSpecSchema)
# ======= Training Set =======
class TrainingSpecSchema(BaseSchema):
    """Schema for the ``spec`` section of a training document; declares no fields of its own."""
class TrainingSchema(BaseSchema):
    """Top-level training document: a ``v1`` version, kind ``Training`` and a nested ``spec``."""

    # Both header fields are mandatory and pinned to a single allowed value.
    version = fields.String(
        validate=validate.Equal('v1'),
        required=True,
    )
    kind = fields.String(
        validate=validate.Equal('Training'),
        required=True,
    )
    spec = fields.Nested(TrainingSpecSchema)
from marshmallow import fields, validate, ValidationError
from .base import BaseSchema
from .fields import PolyField
class MTKSelectorSchema(BaseSchema):
    """Selector for mattoolkit calculations: a required list of label strings."""

    labels = fields.List(
        fields.String(),
        required=True,
    )
class MTKTrainingSetSchema(BaseSchema):
    """Training-set entry whose ``type`` must equal ``mattoolkit`` and which carries a selector."""

    type = fields.String(
        validate=validate.Equal('mattoolkit'),
        required=True,
    )
    selector = fields.Nested(
        MTKSelectorSchema,
        required=True,
    )
def training_property_schema_serialization_disambiguation(base_object, obj):
    """Select the schema used to serialize one training-set entry.

    The lookup first tries ``obj.mode`` (the original behaviour). If that is
    absent or unknown it falls back to the entry's own ``type`` (dict key or
    attribute), which is the key the deserialization selector reads.

    :param base_object: the entry being serialized.
    :param obj: the parent object holding the entry.
    :returns: a new schema instance for the detected type.
    :raises TypeError: when no known type can be detected.
    """
    type_to_schema = {
        'mattoolkit': MTKTrainingSetSchema,
    }

    if isinstance(base_object, dict):
        entry_type = base_object.get('type')
    else:
        entry_type = getattr(base_object, 'type', None)

    # A missing ``mode`` attribute used to escape as AttributeError, so the
    # TypeError below was never reached; treat it like any other unknown type.
    for candidate in (getattr(obj, 'mode', None), entry_type):
        try:
            schema_class = type_to_schema[candidate]
        except (KeyError, TypeError):
            # Unknown type name, or an unhashable value that cannot be a key.
            continue
        return schema_class()

    raise TypeError("Could not detect type did you specify a type?")
def trainging_property_schema_deserialization_disambiguation(object_dict, data):
    """Select the schema used to deserialize one training-set entry.

    :param object_dict: raw entry; its ``'type'`` key names the schema.
    :param data: parent data holding the entry (unused).
    :returns: a new schema instance for the entry's type.
    :raises TypeError: when the ``'type'`` key is missing or not recognised.
    """
    schemas_by_type = {'mattoolkit': MTKTrainingSetSchema}

    selected = None
    try:
        selected = schemas_by_type[object_dict['type']]()
    except KeyError:
        # Covers both a missing 'type' key and an unknown type name.
        selected = None

    if selected is None:
        raise TypeError("Could not detect type did you specify a type?")
    return selected
class TrainingSchema(BaseSchema):
    """Top-level training document: ``v1`` version, kind ``Training`` and a polymorphic ``spec`` list."""

    # Both header fields are mandatory and pinned to a single allowed value.
    version = fields.String(
        validate=validate.Equal('v1'),
        required=True,
    )
    kind = fields.String(
        validate=validate.Equal('Training'),
        required=True,
    )
    # Each spec entry is (de)serialized by the schema its selector picks.
    spec = PolyField(
        many=True,
        deserialization_schema_selector=trainging_property_schema_deserialization_disambiguation,
        serialization_schema_selector=training_property_schema_serialization_disambiguation,
    )
......@@ -2,3 +2,52 @@
Has a caching layer as to speed up future runs
"""
import json
import yaml
from mattoolkit.api import CalculationResourceList
from .schema import TrainingSchema
from .io import MTKReader
class Training:
    """Training set built from a document validated by ``TrainingSchema``.

    On construction the document is loaded through the schema and the
    calculations referenced by every ``spec`` entry are gathered.
    """

    def __init__(self, schema):
        """
        :param schema: raw training document (e.g. parsed JSON/YAML).
        """
        schema_load, _errors = TrainingSchema().load(schema)
        self.schema = schema_load
        self._gather_calculations()

    def _gather_calculations(self):
        """Populate ``self._calculations`` from the ``spec`` entries of the schema."""
        self._calculations = []
        for calculation in self.schema['spec']:
            if calculation['type'] == 'mattoolkit':
                self._calculations.extend(
                    self.download_mattoolkit_calculations(calculation['selector']))

    @property
    def calculations(self):
        """List of readers gathered for this training set."""
        return self._calculations

    def download_mattoolkit_calculations(self, selector):
        """Query mattoolkit for calculations matching ``selector['labels']``.

        :returns: one ``MTKReader`` per returned calculation, built from its id.
        """
        calculations = CalculationResourceList()
        calculations.get(params={'labels': selector['labels']})
        return [MTKReader(c.id) for c in calculations.items]

    @classmethod
    def from_file(cls, filename, format=None):
        """Build a ``Training`` from a JSON or YAML file.

        :param filename: path of the file to read.
        :param format: ``'json'``, ``'yaml'`` or ``'yml'``; any other value
            means the format is detected from the filename suffix.
        :raises ValueError: when the format cannot be detected from the filename.
        """
        # 'yml' is honoured as an explicit format; previously it was ignored
        # and the format was re-detected from the filename.
        if format not in {'json', 'yaml', 'yml'}:
            if filename.endswith('json'):
                format = 'json'
            elif filename.endswith(('yaml', 'yml')):
                format = 'yaml'
            else:
                raise ValueError('unrecognized filetype from filename %s' % filename)

        with open(filename) as f:
            if format == 'json':
                return cls(json.load(f))
            # safe_load: yaml.load without a Loader can construct arbitrary
            # objects and is rejected by newer PyYAML releases.
            return cls(yaml.safe_load(f))

    def __str__(self):
        return json.dumps(self.schema, sort_keys=True, indent=4)
......@@ -8,13 +8,6 @@ spec:
- structrue:MgO
- calculation_type:static
- calculation_group:lattice_constant
- type: mattoolkit
selector:
labels:
- project:potential_fitting
- structure:MgO
- calculation_type:static
- calculation_group:perturb
- type: mattoolkit
selector:
labels:
......@@ -22,17 +15,3 @@ spec:
- structure:MgO
- calculation_type:static
- calculation_group:phonon_displacements
- type: mattoolkit
selector:
labels:
- project:potential_fitting
- structure:MgO
- calculation_type:static
- calculation_group:perturb_training
- type: mattoolkit
selector:
labels:
- project:potential_fitting
- structure:MgO
- calculation_type:static
- calculation_group:strains_training
import pytest
from dftfit.training import Training
@pytest.mark.mattoolkit
def test_potential_from_file():
    """Smoke test: a training set can be constructed from the MgO YAML fixture."""
    fixture_path = 'test_files/training/training-mgo.yaml'
    Training.from_file(fixture_path)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment