TensorFlow model import to Java

I have been trying to import and use my trained model (TensorFlow, Python) in Java.
I was able to save the model in Python, but I run into problems when I try to make predictions with the same model in Java.
Here, you can see the Python code for initializing, training, and saving the model.
Here, you can see the Java code for importing the model and making predictions for input values.
The error message I get is:
Exception in thread "main" java.lang.IllegalStateException: Attempting to use uninitialized value Variable_7
[[Node: Variable_7/read = Identity[T=DT_FLOAT, _class=["loc:#Variable_7"], _device="/job:localhost/replica:0/task:0/cpu:0"](Variable_7)]]
at org.tensorflow.Session.run(Native Method)
at org.tensorflow.Session.access$100(Session.java:48)
at org.tensorflow.Session$Runner.runHelper(Session.java:285)
at org.tensorflow.Session$Runner.run(Session.java:235)
at org.tensorflow.examples.Identity_import.main(Identity_import.java:35)
I believe the problem is somewhere in the Python code, but I have not been able to find it.

The Java importGraphDef() function only imports the computational graph (written by tf.train.write_graph in your Python code). It does not load the values of the trained variables (stored in the checkpoint), which is why you get an error about uninitialized variables.
The TensorFlow SavedModel format, on the other hand, includes all information about a model (graph, checkpoint state, other metadata). To use it in Java, you'd want SavedModelBundle.load, which creates a session already initialized with the trained variable values.
To export a model in this format from Python, you might want to take a look at the related question "Deploy retrained inception SavedModel to google cloud ml engine".
In your case, this should amount to something like the following in Python (TensorFlow 1.x API):
import tensorflow as tf

def save_model(session, input_tensor, output_tensor):
    # Describe the model's input and output so a SavedModel loader can look them up
    signature = tf.saved_model.signature_def_utils.build_signature_def(
        inputs={'input': tf.saved_model.utils.build_tensor_info(input_tensor)},
        outputs={'output': tf.saved_model.utils.build_tensor_info(output_tensor)},
    )
    b = tf.saved_model.builder.SavedModelBuilder('/tmp/model')
    # Store the graph together with the current (trained) variable values under the "serve" tag
    b.add_meta_graph_and_variables(session,
                                   [tf.saved_model.tag_constants.SERVING],
                                   signature_def_map={tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature})
    b.save()
Then invoke it via save_model(session, x, yhat).
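Then, in Java, load the model from the same directory:
try (SavedModelBundle b = SavedModelBundle.load("/tmp/model", "serve")) {
    // b.session() already holds the trained variable values;
    // use b.session().runner() to feed inputs and fetch outputs by operation name
}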
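A common follow-up mistake is pointing the wrong loader at the wrong path: SavedModelBundle.load expects the export directory, while Graph.importGraphDef expects a bare serialized GraphDef (and passing the saved_model.pb file itself to importGraphDef fails, because that file holds a SavedModel proto rather than a GraphDef). The following is a minimal, standard-library-only sketch for checking a path before loading; ModelLayout and its Kind values are hypothetical names, and the layout rules encode the usual SavedModel directory structure, not an official TensorFlow check.

```java
import java.nio.file.Files;
import java.nio.file.Path;

final class ModelLayout {
    enum Kind { SAVED_MODEL_DIR, SAVED_MODEL_PROTO_FILE, SINGLE_GRAPH_FILE, UNKNOWN }

    private ModelLayout() {}

    static Kind detect(Path path) {
        if (Files.isDirectory(path)) {
            // A SavedModel export is a directory with saved_model.pb (or .pbtxt),
            // normally next to a variables/ subdirectory holding the trained values.
            boolean hasProto = Files.isRegularFile(path.resolve("saved_model.pb"))
                    || Files.isRegularFile(path.resolve("saved_model.pbtxt"));
            return hasProto ? Kind.SAVED_MODEL_DIR : Kind.UNKNOWN;
        }
        if (!Files.isRegularFile(path)) {
            return Kind.UNKNOWN;
        }
        String name = path.getFileName().toString();
        if (name.equals("saved_model.pb")) {
            // Not a GraphDef: pass the parent directory to SavedModelBundle.load instead.
            return Kind.SAVED_MODEL_PROTO_FILE;
        }
        // Any other .pb is most likely a serialized GraphDef for Graph.importGraphDef;
        // it only carries trained values if the variables were frozen into constants.
        return name.endsWith(".pb") ? Kind.SINGLE_GRAPH_FILE : Kind.UNKNOWN;
    }
}
```

With the Python export above, /tmp/model should be reported as SAVED_MODEL_DIR, which is the case SavedModelBundle.load handles.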
Hope that helps.

FWIW, Deeplearning4j lets you import models trained on TensorFlow with Keras 1.0 (Keras 2.0 support is on the way).
https://deeplearning4j.org/model-import-keras
We also built a library called Jumpy, a wrapper around NumPy arrays and Pyjnius that uses pointers instead of copying data, which makes it more efficient than Py4J when dealing with tensors.
https://deeplearning4j.org/jumpy

Your Python model will certainly fail here:
sess.run(init) #<---this will fail
save_model(sess)
error = tf.reduce_mean(tf.square(prediction - y))
#accuracy = tf.reduce_mean(tf.cast(error, 'float'))
print('Error:', error)
init is not defined in your model. I'm not sure what you want to achieve at this point, but that should give you a starting point.

Related

Gremlin: getting JSON response in Java with gremlin-driver

I have the following query:
g
.V("user-11")
.repeat(bothE().subgraph("subGraph").outV())
.times(2)
.cap("subGraph")
.next()
When I run it using gremlin-python, I receive the following response:
{'#type': 'tinker:graph',
'#value': {'vertices': [v[device-3], v[device-1], v[user-11], v[card-1]],
'edges': [e[68bad734-db2b-bffc-3e17-a0813d2670cc][user-11-uses_device->device-1],
e[14bad735-2b70-860f-705f-4c0b769a7849][user-11-uses_device->device-3],
e[f0bb3b6d-d161-ec60-5e6d-068272297f24][user-11-uses_card->card-1]]}}
This is a GraphSON representation of the subgraph obtained by the query.
I want to get the same response using Java and gremlin-driver, but I haven't been able to figure out how.
My best try was:
ObjectMapper mapper = GraphSONMapper.build().version(GraphSONVersion.V3_0).create().createMapper();
Object a = graphTraversalSource
.V(nodeId)
.repeat(bothE().subgraph("subGraph").outV())
.times(2)
.cap("subGraph")
.next();
return mapper.writeValueAsString(a);
But that gave me the following error:
io.netty.handler.codec.DecoderException: org.apache.tinkerpop.gremlin.driver.ser.SerializationException: org.apache.tinkerpop.shaded.kryo.KryoException: Encountered unregistered class ID: 65536
I am using AWS Neptune, but I doubt that makes a difference given that I receive the answer I want through gremlin-python.
I appreciate any help you can give! Thanks
As mentioned in the comments:
When using Java, what you get back will be an actual TinkerGraph.
Using the GraphBinary or GraphSONV3d0 serializer is recommended.
The Gryo one is older and is likely causing the error you saw if you did not specify one of the other serializers.
Note that even if you use one of the other serializers, to serialize the graph into JSON you will need to register the TinkerGraph-specific IO registry (see the end of this answer for an example). Otherwise you will just get {} back.
However, with the Java Gremlin client you may not need to produce JSON at all.
Given that you have an actual TinkerGraph back, you can run real Gremlin queries against the in-memory subgraph: just create a new traversal source for it. You can also use the graph.io classes to write the graph to a file should you wish to. The TinkerGraph will include properties as well as edges and vertices.
You can also access the TinkerGraph object directly, for example via
tg.vertices() and tg.edges() (after casting the result to TinkerGraph, as below).
As a concrete example, if you have a query of the form
TinkerGraph tg = (TinkerGraph) g.V().bothE().subgraph("sg").cap("sg").next();
Then you can do
GraphTraversalSource g2 = tg.traversal();
Long cv = g2.V().count().next();
Long ce = g2.E().count().next();
Or you can just access the TinkerGraph data structure directly using statements of the form:
Vertex v = tg.vertices(<some-id>).next();
Or
Iterator<VertexProperty<Object>> properties = tg.vertices(<some-id>).next().properties();
This actually means you have a lot more power available to you in the Java client when working with subgraphs.
If you still feel that you need a JSON version of your subgraph, the IO reference is a handy bookmark to have: https://tinkerpop.apache.org/docs/3.4.9/dev/io/#_io_reference
EDIT: to save you a lot of reading of the docs, this code will print a TinkerGraph as JSON:
ObjectMapper mapper = GraphSONMapper.build().
        addRegistry(TinkerIoRegistryV3d0.instance()).
        version(GraphSONVersion.V3_0).create().createMapper();
String json = mapper.writeValueAsString(tg);
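To make the semantics of the question's query concrete, here is a plain-Java model (no TinkerPop classes; SubgraphModel, Edge, collect and vertices are hypothetical names for illustration) of what g.V(start).repeat(bothE().subgraph("subGraph").outV()).times(2).cap("subGraph") collects, as I understand the steps: bothE() takes every edge touching a current vertex, subgraph() records that edge with both endpoints, and outV() moves to the edge's tail. One consequence worth noticing: for an outgoing edge, outV() is the starting vertex itself, so the second hop only moves away from user-11 through incoming edges.

```java
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

final class SubgraphModel {
    /** A directed edge out -> in, e.g. user-11 -uses_device-> device-1. */
    static final class Edge {
        final String out, label, in;
        Edge(String out, String label, String in) { this.out = out; this.label = label; this.in = in; }
    }

    private SubgraphModel() {}

    /** Edges collected by repeat(bothE().subgraph(..).outV()).times(times) from start. */
    static Set<Edge> collect(List<Edge> graph, String start, int times) {
        Set<Edge> subgraph = new LinkedHashSet<>();
        // Gremlin keeps duplicate traversers; a set of vertices yields the same edge set.
        Set<String> frontier = new LinkedHashSet<>(List.of(start));
        for (int i = 0; i < times; i++) {
            Set<String> next = new LinkedHashSet<>();
            for (Edge e : graph) {
                if (frontier.contains(e.out) || frontier.contains(e.in)) { // bothE()
                    subgraph.add(e);                                       // subgraph("subGraph")
                    next.add(e.out);                                       // outV()
                }
            }
            frontier = next;
        }
        return subgraph;
    }

    /** subgraph() is edge-induced: its vertices are exactly the edges' endpoints. */
    static Set<String> vertices(Set<Edge> subgraph) {
        Set<String> vs = new LinkedHashSet<>();
        for (Edge e : subgraph) { vs.add(e.out); vs.add(e.in); }
        return vs;
    }
}
```

With only the three user-11 edges from the question, this gives 3 edges and 4 vertices, matching the gremlin-python output shown above.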

Wikidata Toolkit: Is it possible to access properties of entities?

First of all, I want to clarify that my experience working with Wikidata is very limited, so feel free to correct me if any of my terminology is wrong.
I've been playing with Wikidata Toolkit, more specifically its wdtk-wikibaseapi module. This allows you to get entity information and the entity's properties like this:
WikibaseDataFetcher wbdf = WikibaseDataFetcher.getWikidataDataFetcher();
EntityDocument q42 = wbdf.getEntityDocument("Q42");
List<StatementGroup> groups = ((ItemDocument) q42).getStatementGroups();
for (StatementGroup g : groups) {
    List<Statement> statements = g.getStatements();
    for (Statement s : statements) {
        System.out.println(s.getMainSnak().getPropertyId().getId());
        System.out.println(s.getValue());
    }
}
The above would get me the entity Douglas Adams and all the properties under his site: https://www.wikidata.org/wiki/Q42
Wikidata Toolkit can also load and process dump files, meaning you can download a dump to your local machine and process it using the DumpProcessingController class from the wdtk-dumpfiles library. I'm just not sure what is meant by processing.
Can anyone explain what processing means in this context?
Can you do something similar to the wdtk-wikibaseapi example above, but using a local dump file and wdtk-dumpfiles, i.e. get an entity and its properties? I don't want to get the info from an online source, only from the dump (offline).
If this is not possible using wikidata-toolkit, could you point me to somewhere that can get me started on getting entities and their properties from a dump file for wikidata please? I am using Java.
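As far as I understand it, "processing" a dump means a single sequential pass over the whole file in which each entity is handed to a callback as it is read; the toolkit's examples do this by registering an EntityDocumentProcessor with DumpProcessingController. The sketch below does not use Wikidata Toolkit at all. It is a standard-library-only illustration of the same idea, assuming the documented Wikidata JSON dump layout (one JSON array, one entity per line, lines separated by commas); DumpStreamer, process and processGzip are hypothetical names, and the callback would need a JSON library of your choice to actually look up an entity such as Q42 and read its claims.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.function.Consumer;
import java.util.zip.GZIPInputStream;

final class DumpStreamer {
    private DumpStreamer() {}

    /**
     * Streams a Wikidata-style JSON dump and passes each entity's JSON text to the
     * callback. Only the current line is held in memory. Returns the entity count.
     */
    static long process(Reader source, Consumer<String> perEntity) throws IOException {
        long count = 0;
        try (BufferedReader in = new BufferedReader(source)) {
            String line;
            while ((line = in.readLine()) != null) {
                String s = line.trim();
                if (s.isEmpty() || s.equals("[") || s.equals("]")) continue; // array brackets
                if (s.endsWith(",")) s = s.substring(0, s.length() - 1);     // entity separator
                perEntity.accept(s);
                count++;
            }
        }
        return count;
    }

    /** Same as process(...), reading a gzip-compressed dump file such as *.json.gz. */
    static long processGzip(Path dump, Consumer<String> perEntity) throws IOException {
        Reader r = new InputStreamReader(
                new GZIPInputStream(Files.newInputStream(dump)), StandardCharsets.UTF_8);
        return process(r, perEntity);
    }
}
```

Because the full dump is very large, a callback that keeps only the entities you care about (and discards the rest) is what makes this workable offline.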

How to know the Java interfaces an OpenOffice Calc UNO object supports (through queryInterface)

I'm developing a "macro" for OpenOffice Calc. As the language, I chose Java, in order to get code assistance in Eclipse. I even wrote a small ant build script that compiles and embeds the "macro" in an *.ods file. In general, this works fine and surprisingly fast; I'm already using some simple stuff quite successfully.
BUT
So often I get stuck because with UNO, I need to "query" an interface for any given non-trivial object to be able to access data / call methods of that object. I.e., I literally need to guess which interfaces a given object may provide. This is not at all obvious, not even visible during Java development (through some sort of meta-information, reflection or the like), and also sparsely documented (I downloaded tons of stuff, but I can't find the source or the Javadoc for the interfaces I'm using, like XButton, XPropertySet, etc. XButton has setLabel, but not getLabel. What??).
There is online documentation (for the most fundamental concepts, which is not bad at all!), but it lacks many details that I'm faced with. It always magically stops exactly at the point I need to solve.
I'm willing to look at the C++ code to get a clue which interfaces an object (e.g. the button / event I'm currently stuck with) may provide. Confusingly, the C++ class and file names don't exactly match the Java interfaces. It's almost what I'm looking for, but then in Java I don't really find the equivalent, or calling queryInterface on a given object returns null. It's becoming a bit frustrating.
How are the UNO Java interfaces generated? Is there some kind of documentation in the code that serves as the origin for the generated (Java) code?
I think I really need to know what interfaces are available at which point, in order to become a bit more fluent during Java-UNO-macro development.
For any serious UNO project, use an introspection tool.
As an example, I created a button in Calc, then used the Java Object Inspector to browse to the button.
Right-clicking and choosing "Add to Source Code" generated the following.
import com.sun.star.awt.XControlModel;
import com.sun.star.beans.XPropertySet;
import com.sun.star.container.XIndexAccess;
import com.sun.star.container.XNameAccess;
import com.sun.star.drawing.XControlShape;
import com.sun.star.drawing.XDrawPage;
import com.sun.star.drawing.XDrawPageSupplier;
import com.sun.star.sheet.XSpreadsheetDocument;
import com.sun.star.sheet.XSpreadsheets;
import com.sun.star.uno.AnyConverter;
import com.sun.star.uno.UnoRuntime;
import com.sun.star.uno.XInterface;
//...
public void codesnippet(XInterface _oUnoEntryObject){
    try{
        XSpreadsheetDocument xSpreadsheetDocument = (XSpreadsheetDocument) UnoRuntime.queryInterface(XSpreadsheetDocument.class, _oUnoEntryObject);
        XSpreadsheets xSpreadsheets = xSpreadsheetDocument.getSheets();
        XNameAccess xNameAccess = (XNameAccess) UnoRuntime.queryInterface(XNameAccess.class, xSpreadsheets);
        Object oName = xNameAccess.getByName("Sheet1");
        XDrawPageSupplier xDrawPageSupplier = (XDrawPageSupplier) UnoRuntime.queryInterface(XDrawPageSupplier.class, oName);
        XDrawPage xDrawPage = xDrawPageSupplier.getDrawPage();
        XIndexAccess xIndexAccess = (XIndexAccess) UnoRuntime.queryInterface(XIndexAccess.class, xDrawPage);
        Object oIndex = xIndexAccess.getByIndex(0);
        XControlShape xControlShape = (XControlShape) UnoRuntime.queryInterface(XControlShape.class, oIndex);
        XControlModel xControlModel = xControlShape.getControl();
        XPropertySet xPropertySet = (XPropertySet) UnoRuntime.queryInterface(XPropertySet.class, xControlModel);
        // The button label is a property of the control model, read through XPropertySet
        String sLabel = AnyConverter.toString(xPropertySet.getPropertyValue("Label"));
    }catch (com.sun.star.beans.UnknownPropertyException e){
        e.printStackTrace(System.out);
        //Enter your Code here...
    }catch (com.sun.star.lang.WrappedTargetException e2){
        e2.printStackTrace(System.out);
        //Enter your Code here...
    }catch (com.sun.star.lang.IllegalArgumentException e3){
        e3.printStackTrace(System.out);
        //Enter your Code here...
    }catch (com.sun.star.uno.Exception e4){
        // getByName (NoSuchElementException) and getByIndex (IndexOutOfBoundsException)
        // throw checked UNO exceptions too; they must be caught for the snippet to compile
        e4.printStackTrace(System.out);
    }
}
//...
Python-UNO may be better than Java because it does not require querying specific interfaces. Also XrayTool and MRI are easier to use than the Java Object Inspector.
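If you stay with Java and want to avoid guessing one interface at a time, you can at least automate the guessing: try a list of candidate interfaces and keep those for which the query returns non-null. Where an object implements com.sun.star.lang.XTypeProvider, its getTypes() method lists the supported types directly, which is what introspection tools build on. The sketch below is a generic, standard-library-only version of the brute-force probe; InterfaceProbe and supported are hypothetical names, and in real UNO code the query function would be something like (type, obj) -> UnoRuntime.queryInterface(type, obj), which is an assumption about how you wire it up rather than part of the SDK.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;

final class InterfaceProbe {
    private InterfaceProbe() {}

    /**
     * Returns the candidate interfaces for which query(type, obj) yields a non-null
     * result, in the order the candidates were given.
     */
    static List<Class<?>> supported(Object obj, List<Class<?>> candidates,
                                    BiFunction<Class<?>, Object, Object> query) {
        List<Class<?>> result = new ArrayList<>();
        for (Class<?> type : candidates) {
            if (query.apply(type, obj) != null) {
                result.add(type);
            }
        }
        return result;
    }
}
```

For a button's control model, the candidate list might contain XPropertySet, XControlModel and similar interfaces taken from the imports in the generated snippet above.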

Using training made with python API as input to LabelImage module in java API?

I have a problem with the TensorFlow Java API. I ran the training using the Python TensorFlow API, which generated the files output_graph.pb and output_labels.txt. Now, for some reason, I want to use those files as input to the LabelImage module of the Java TensorFlow API. I thought everything would work fine, since that module expects exactly one .pb and one .txt. Nevertheless, when I run the module, I get this error:
2017-04-26 10:12:56.711402: W tensorflow/core/framework/op_def_util.cc:332] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
Exception in thread "main" java.lang.IllegalArgumentException: No Operation named [input] in the Graph
at org.tensorflow.Session$Runner.operationByName(Session.java:343)
at org.tensorflow.Session$Runner.feed(Session.java:137)
at org.tensorflow.Session$Runner.feed(Session.java:126)
at it.zero11.LabelImage.executeInceptionGraph(LabelImage.java:115)
at it.zero11.LabelImage.main(LabelImage.java:68)
I would be very grateful if you could help me find where the problem is. I would also like to ask whether there is a way to run the training from the Java TensorFlow API, because that would make things easier.
To be more precise:
As a matter of fact, I do not use self-written code, at least for the relevant steps. All I have done is run the training with this module, https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/image_retraining/retrain.py, feeding it the directory that contains the images, divided into subdirectories according to their description. In particular, I think these are the lines that generate the outputs:
output_graph_def = graph_util.convert_variables_to_constants(
    sess, graph.as_graph_def(), [FLAGS.final_tensor_name])
with gfile.FastGFile(FLAGS.output_graph, 'wb') as f:
    f.write(output_graph_def.SerializeToString())
with gfile.FastGFile(FLAGS.output_labels, 'w') as f:
    f.write('\n'.join(image_lists.keys()) + '\n')
Then, I give the outputs (one some_graph.pb and one some_labels.txt) as input to this java module: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java, replacing the default inputs. The error I get is the one reported above.
The model used by default in LabelImage.java is different from the model that is being retrained, so the names of the input and output nodes do not match. Note that TensorFlow models are graphs, and the arguments to feed() and fetch() are names of nodes in the graph. So you need to know the names appropriate for your model.
Looking at retrain.py, it seems that it has a node that takes the raw contents of a JPEG file as input (the node DecodeJpeg/contents) and produces the set of labels in the node final_result.
If that's the case, then you'd do something like the following in Java. You don't need the part that constructs a graph to normalize the image, since that seems to be part of the retrained model, so replace LabelImage.java:64 with something like:
try (Tensor image = Tensor.create(imageBytes);
     Graph g = new Graph()) {
    g.importGraphDef(graphDef);
    try (Session s = new Session(g);
         // Note the change to the name of the node and the fact
         // that it is being provided the raw imageBytes as input
         Tensor result = s.runner().feed("DecodeJpeg/contents", image).fetch("final_result").run().get(0)) {
        final long[] rshape = result.shape();
        if (result.numDimensions() != 2 || rshape[0] != 1) {
            throw new RuntimeException(
                String.format(
                    "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
                    Arrays.toString(rshape)));
        }
        int nlabels = (int) rshape[1];
        float[] probabilities = result.copyTo(new float[1][nlabels])[0];
        // At this point nlabels = number of classes in your retrained model
        DoSomethingWith(probabilities);
    }
}
Hope that helps.
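As one possible way to use the probabilities array, here is a minimal sketch that maps it back to the labels file. It assumes line i of output_labels.txt corresponds to index i of final_result, which is consistent with how the retrain.py lines quoted in the question write the labels; TopLabel and best are hypothetical names, and reading the file contents (e.g. with Files.readAllBytes) is left to the caller.

```java
import java.util.Arrays;
import java.util.List;

final class TopLabel {
    private TopLabel() {}

    /**
     * Splits the labels file contents (one label per line) and returns the label
     * whose probability is highest. Fails fast if the counts do not match.
     */
    static String best(String labelsFileContents, float[] probabilities) {
        // \R matches any line break; trailing empty strings are dropped by split
        List<String> labels = Arrays.asList(labelsFileContents.split("\\R"));
        if (labels.size() != probabilities.length) {
            throw new IllegalArgumentException(
                "Got " + labels.size() + " labels but " + probabilities.length + " probabilities");
        }
        int bestIdx = 0;
        for (int i = 1; i < probabilities.length; i++) {
            if (probabilities[i] > probabilities[bestIdx]) bestIdx = i;
        }
        return labels.get(bestIdx);
    }
}
```

The count check is a cheap guard against pairing a labels file with a graph from a different training run.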
Regarding the "No operation" error, I was able to resolve that by using input and output layer names "Mul" and "final_result", respectively. See:
https://github.com/tensorflow/tensorflow/issues/2883
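If you feed "Mul" instead of "DecodeJpeg/contents", the model expects an already resized and normalized float image rather than raw JPEG bytes. A minimal standard-library sketch of that preprocessing follows, assuming the Inception v3 defaults that retrain.py used at the time (299x299 input, mean 128, standard deviation 128); check these constants in your copy of retrain.py. It uses nearest-neighbour resizing for brevity, whereas the graph appears to use bilinear resizing, so scores may differ slightly. InceptionInput and toNormalizedBatch are hypothetical names; the resulting array could then be wrapped with Tensor.create(...) and fed to "Mul".

```java
import java.awt.image.BufferedImage;

final class InceptionInput {
    // ASSUMPTION: Inception v3 defaults from retrain.py at the time; verify against your copy.
    static final int SIZE = 299;
    static final float MEAN = 128f;
    static final float STD = 128f;

    private InceptionInput() {}

    /**
     * Nearest-neighbour resizes an RGB image to size x size and normalizes each channel
     * as (value - mean) / std, giving a [1][size][size][3] array (batch, height, width, RGB).
     */
    static float[][][][] toNormalizedBatch(BufferedImage img, int size, float mean, float std) {
        float[][][][] out = new float[1][size][size][3];
        for (int y = 0; y < size; y++) {
            int srcY = y * img.getHeight() / size;
            for (int x = 0; x < size; x++) {
                int srcX = x * img.getWidth() / size;
                int rgb = img.getRGB(srcX, srcY);
                out[0][y][x][0] = (((rgb >> 16) & 0xFF) - mean) / std; // red
                out[0][y][x][1] = (((rgb >> 8) & 0xFF) - mean) / std;  // green
                out[0][y][x][2] = ((rgb & 0xFF) - mean) / std;         // blue
            }
        }
        return out;
    }
}
```

Typical use would be toNormalizedBatch(ImageIO.read(file), InceptionInput.SIZE, InceptionInput.MEAN, InceptionInput.STD).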

What is the Tensorflow Java Api `toGraphDef` equivalent in Python?

I am using the Tensorflow Java Api to load an already created Tensorflow model into the JVM.
I am using this as an example: tensorflow/examples/LabelImage.java
Here is my simple scala code:
import java.nio.file.{Files, Path, Paths}
import org.tensorflow.{Graph, Session, Tensor}
def readAllBytesOrExit(path: Path): Array[Byte] = Files.readAllBytes(path)
val graphDef = readAllBytesOrExit(Paths.get("PATH_TO_A_SINGLE_FILE_DESCRIBING_TF_MODEL.pb"))
val g = new Graph()
g.importGraphDef(graphDef)
val session = new Session(g)
val result: Tensor = session.runner().feed("input", image).fetch("output").run().get(0)
How do I save my model so that both the Session and the Graph are stored in the same file, as in "PATH_TO_A_SINGLE_FILE_DESCRIBING_TF_MODEL.pb" above?
The documentation described here mentions:
The serialized representation of the graph, often referred to as a
GraphDef, can be generated by toGraphDef() and equivalents in other
language APIs.
What are the equivalents in other language APIs? I don't find it obvious.
Note: I already looked at the mnist_saved_model.py under tensorflow_serving but saving it through that procedure gives me a .pb file and a variables folder. When trying to load that .pb file I get: java.lang.IllegalArgumentException: Invalid GraphDef
With the current TensorFlow Java API, I have only found how to save a graph as a GraphDef (i.e. without its variables and metadata). This can be done by just writing the Array[Byte] to a file:
Files.write(Paths.get(modelDir, modelName), myGraph.toGraphDef)
Here myGraph is a java object from the Graph class.
I would suggest saving your model from the Python API, using the SavedModel API defined here. It will save your model in a folder containing both the serialized graph in a .pb file and the variables in a subfolder. Note the tag_constants value you use, as you'll need it in your Scala/Java code to load the model with its variables. Then the graph and the session with its variables are easily loaded with the SavedModelBundle class from the Java API. It returns a wrapper holding both the graph and a session containing the variable values:
val model = SavedModelBundle.load(modelDir, modelTag)
If you already tried this, maybe you can share your code to see why it returned an invalid GraphDef.
Another option is to freeze your graph, i.e. turn your variable nodes into constant nodes so that everything is self-contained in the .pb file. More information on the freezing part is here.
