/* This file is part of the Joshua Machine Translation System.
*
* Joshua is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1
* of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free
* Software Foundation, Inc., 59 Temple Place, Suite 330, Boston,
* MA 02111-1307 USA
*/
package joshua.decoder.ff.lm.distributed_lm;
import joshua.corpus.vocab.BuildinSymbol;
import joshua.corpus.vocab.SrilmSymbol;
import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.ff.lm.NGramLanguageModel;
import joshua.decoder.ff.lm.buildin_lm.LMGrammarJAVA;
import joshua.decoder.ff.lm.srilm.LMGrammarSRILM;
import joshua.util.io.LineReader;
import joshua.util.Regex;
import java.io.IOException;
//import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;
//import java.net.UnknownHostException;
import java.util.HashMap;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
/**
 * A standalone language-model server. This class
 * (1) loads an LM file,
 * (2) listens for client connections on a TCP port, and
 * (3) serves requests for LM probabilities over those connections.
 *
 * @author Zhifei Li, <zhifei.work@gmail.com>
 * @version $LastChangedDate: 2009-05-19 20:43:54 -0500 (Tue, 19 May 2009) $
 */
public class LMServer {
	private static final Logger logger = Logger.getLogger(LMServer.class.getName());
	
	// ---- common options (may be overridden by the config file) ----
	/** TCP port to listen on (config key: remote_lm_server_port). */
	public static int port = 9800;
	static boolean use_srilm = true;
	public static boolean use_left_euqivalent_state = false;
	public static boolean use_right_euqivalent_state = false;
	static int g_lm_order = 3;
	static double lm_ceiling_cost = 100;//TODO: make sure LMGrammar is using this number
	static String remote_symbol_tbl = null;
	
	// ---- lm specific ----
	static String lm_file = null;
	static Double interpolation_weight = null;//the interpolation weight of this lm
	static String g_host_name = null;
	
	// ---- loaded model ----
	static NGramLanguageModel p_lm;
	// NOTE(review): request_cache and cache_size_limit are not read anywhere in this class
	// (the cached request path is disabled). A plain HashMap would not be safe to share
	// across ClientHandler threads if caching is re-enabled.
	static HashMap<String,String> request_cache = new HashMap<String,String>();//cmd with result
	static int cache_size_limit = 3000000;
	
	// ---- statistics ----
	/** Number of requests received; incremented while holding the LMServer.class lock. */
	static int g_n_request = 0;
	static int g_n_cache_hit = 0;
	
	static SymbolTable p_symbolTable;
	
	/**
	 * Entry point: reads the configuration, loads the language model,
	 * and then accepts client connections forever, serving each on its
	 * own {@link ClientHandler} thread.
	 *
	 * @param args exactly one argument: the path of the config file
	 * @throws IOException if the config file cannot be read
	 */
	public static void main(String[] args) throws IOException {
		if (args.length != 1) {
			System.err.println("Usage: java LMServer config_file");
			if (logger.isLoggable(Level.FINE)) {
				logger.fine("num of args is "+ args.length);
				for (int i = 0; i < args.length; i++) {
					logger.fine("arg is: " + args[i]);
				}
			}
			System.exit(1);
		}
		
		String config_file = args[0].trim();
		read_config_file(config_file);
		
		LMServer server = new LMServer();
		ServerSocket serverSocket = null;
		try {
			serverSocket = new ServerSocket(port);
			init_lm_grammar();
			logger.info("finished lm reading, wait for connection");
			
			while (true) {
				// A failing accept() usually means the listening socket itself is
				// broken, so it still ends the server (caught below).
				Socket socket = serverSocket.accept();
				logger.info("accept a connection from client");
				start_client_handler(socket, server);
			}
		} catch (IOException ioe) {
			logger.log(Level.SEVERE, "cannot create serversocket at port or connection fail", ioe);
		} finally {
			if (null != serverSocket) {
				try {
					serverSocket.close();
				} catch (IOException ioe) {
					logger.log(Level.WARNING, "failed to close server socket", ioe);
				}
			}
		}
	}
	
	/**
	 * Starts a handler thread for one accepted connection. If the handler
	 * cannot be set up (its stream construction throws), only this client
	 * is dropped and its socket is closed; the server keeps running.
	 */
	private static void start_client_handler(Socket socket, LMServer server) {
		try {
			ClientHandler handler = new ClientHandler(socket, server);
			handler.start();
		} catch (IOException ioe) {
			logger.log(Level.WARNING, "failed to set up connection with client", ioe);
			try {
				socket.close();
			} catch (IOException closeError) {
				logger.log(Level.WARNING, "failed to close client socket", closeError);
			}
		}
	}
	
	/**
	 * Builds the symbol table and the language model from the static
	 * configuration fields. It uses SRILM when {@code use_srilm} is set, and
	 * the built-in Java LM otherwise.
	 *
	 * @throws IllegalArgumentException if SRILM is requested together with
	 *         left/right equivalent-state support
	 * @throws IOException propagated from the model/symbol-table constructors
	 */
	// TODO: duplicates initializeLanguageModel and initializeSymbolTable in JoshuaDecoder, needs unifying
	public static void init_lm_grammar() throws IOException {
		if (use_srilm) {
			if (use_left_euqivalent_state || use_right_euqivalent_state) {
				throw new IllegalArgumentException("when using local srilm, we cannot use suffix stuff");
			}
			p_symbolTable = new SrilmSymbol(remote_symbol_tbl, g_lm_order);
			p_lm = new LMGrammarSRILM((SrilmSymbol)p_symbolTable, g_lm_order, lm_file);
		} else {
			// NOTE(original author): the symbol file loaded here must be consistent
			// with the one used by the clients.
			p_symbolTable = new BuildinSymbol(remote_symbol_tbl);
			p_lm = new LMGrammarJAVA((BuildinSymbol)p_symbolTable, g_lm_order, lm_file,
				use_left_euqivalent_state, use_right_euqivalent_state);
		}
	}
	
	/**
	 * Reads {@code key = value} lines from the config file into the static
	 * option fields. Comment and empty lines are skipped. Lines without
	 * {@code =} are ignored. Unknown keys produce a warning.
	 *
	 * @param config_file path of the config file
	 * @throws IOException if the file cannot be opened or read
	 * @throws IllegalArgumentException if a parameter line does not split into exactly two fields
	 * @throws NumberFormatException if a numeric option has a non-numeric value
	 */
	// TODO: this is duplicating code in JoshuaConfiguration, needs unifying
	public static void read_config_file(String config_file)
	throws IOException {
		LineReader configReader = new LineReader(config_file);
		try {
			for (String line : configReader) {
				line = line.trim();
				if (Regex.commentOrEmptyLine.matches(line)) continue;
				if (line.indexOf("=") == -1) continue; // only parameter lines are meaningful
				
				String[] fds = Regex.equalsWithSpaces.split(line);
				if (fds.length != 2) {
					throw new IllegalArgumentException("Wrong config line: " + line);
				}
				
				if ("lm_file".equals(fds[0])) {
					lm_file = fds[1].trim();
					if (logger.isLoggable(Level.FINE))
						logger.fine(String.format("lm file: %s", lm_file));
				} else if ("use_srilm".equals(fds[0])) {
					use_srilm = Boolean.parseBoolean(fds[1]);
					if (logger.isLoggable(Level.FINE))
						logger.fine(String.format("use_srilm: %s", use_srilm));
				} else if ("lm_ceiling_cost".equals(fds[0])) {
					lm_ceiling_cost = Double.parseDouble(fds[1]);
					if (logger.isLoggable(Level.FINE))
						logger.fine(String.format("lm_ceiling_cost: %s", lm_ceiling_cost));
				} else if ("use_left_euqivalent_state".equals(fds[0])) {
					use_left_euqivalent_state = Boolean.parseBoolean(fds[1]);
					if (logger.isLoggable(Level.FINE))
						logger.fine(String.format("use_left_euqivalent_state: %s", use_left_euqivalent_state));
				} else if ("use_right_euqivalent_state".equals(fds[0])) {
					use_right_euqivalent_state = Boolean.parseBoolean(fds[1]);
					if (logger.isLoggable(Level.FINE))
						logger.fine(String.format("use_right_euqivalent_state: %s", use_right_euqivalent_state));
				} else if ("order".equals(fds[0])) {
					g_lm_order = Integer.parseInt(fds[1]);
					if (logger.isLoggable(Level.FINE))
						logger.fine(String.format("g_lm_order: %s", g_lm_order));
				} else if ("remote_lm_server_port".equals(fds[0])) {
					port = Integer.parseInt(fds[1]);
					if (logger.isLoggable(Level.FINE))
						logger.fine(String.format("remote_lm_server_port: %s", port));
				} else if ("remote_symbol_tbl".equals(fds[0])) {
					remote_symbol_tbl = fds[1];
					if (logger.isLoggable(Level.FINE))
						logger.fine(String.format("remote_symbol_tbl: %s", remote_symbol_tbl));
				} else if ("hostname".equals(fds[0])) {
					g_host_name = fds[1].trim();
					if (logger.isLoggable(Level.FINE))
						logger.fine(String.format("host name is: %s", g_host_name));
				} else if ("interpolation_weight".equals(fds[0])) {
					interpolation_weight = Double.parseDouble(fds[1]);
					if (logger.isLoggable(Level.FINE))
						logger.fine(String.format("interpolation_weight: %s", interpolation_weight));
				} else {
					logger.warning("LMServer doesn't use config line: " + line);
				}
			}
		} finally {
			configReader.close();
		}
	}
	
	/**
	 * Serves one client connection on its own thread. It reads one request per
	 * line and writes one response line per request.
	 */
	public static class ClientHandler extends Thread {
		/** A parsed request line: {@code cmd num w1 w2 ...}. */
		public static class DecodedStructure {
			/** Request command, e.g. "prob". */
			String cmd;
			/** Second field of the request; passed to the LM as the n-gram order. */
			int num;
			/** Remaining fields, parsed as integer word ids. */
			int[] wrds;
		}
		
		LMServer parent;
		private Socket socket;
		private BufferedReader in;
		private PrintWriter out;
		
		/**
		 * Opens reader/writer streams on the given socket.
		 *
		 * @throws IOException if the socket's streams cannot be obtained; the
		 *         caller remains responsible for closing {@code sock}
		 */
		public ClientHandler(Socket sock, LMServer pa) throws IOException {
			parent = pa;
			socket = sock;
			in = new BufferedReader(
				new InputStreamReader(socket.getInputStream()));
			out = new PrintWriter(
				new OutputStreamWriter(socket.getOutputStream()));
		}
		
		/**
		 * Answers request lines until the client closes the connection. Then
		 * releases the streams and the socket. Each resource is closed
		 * independently, so one failure does not leak the others.
		 */
		@Override
		public void run() {
			try {
				String line_in;
				while ((line_in = in.readLine()) != null) {
					out.println(process_request_no_cache(line_in));
					out.flush();
				}
			} catch (IOException ioe) {
				logger.log(Level.WARNING, "I/O error while serving client", ioe);
			} finally {
				out.close(); // PrintWriter.close never throws
				try {
					in.close();
				} catch (IOException ioe) {
					logger.log(Level.WARNING, "failed to close client input stream", ioe);
				}
				try {
					socket.close();
				} catch (IOException ioe) {
					logger.log(Level.WARNING, "failed to close client socket", ioe);
				}
			}
		}
		
		/**
		 * Counts the request and answers it without consulting any cache.
		 * The counter is shared by all handler threads, so it is incremented
		 * under the LMServer.class lock. The incremented value is captured in
		 * a local variable for the periodic progress log.
		 */
		private String process_request_no_cache(String packet) {
			int n_request;
			synchronized (LMServer.class) {
				n_request = ++g_n_request;
			}
			String cmd_res = process_request_helper(packet);
			if (logger.isLoggable(Level.FINE) && n_request % 50000 == 0) {
				logger.fine("n_requests: " + n_request);
			}
			return cmd_res;
		}
		
		/**
		 * Dispatches one request line on its command. Malformed lines and
		 * unknown commands are logged and answered with an empty string, so
		 * one bad request does not terminate the connection.
		 */
		private String process_request_helper(String line) {
			DecodedStructure ds;
			try {
				ds = decode_packet(line);
			} catch (IllegalArgumentException e) { // also covers NumberFormatException
				logger.log(Level.SEVERE, "error : Wrong request line: " + line, e);
				return "";
			}
			
			if ("prob".equals(ds.cmd)) {
				return get_prob(ds);
			} else if ("prob_bow".equals(ds.cmd)) {
				return get_prob_backoff_state(ds);
			} else if ("equiv_left".equals(ds.cmd)) {
				return get_left_equiv_state(ds);
			} else if ("equiv_right".equals(ds.cmd)) {
				return get_right_equiv_state(ds);
			} else {
				logger.severe("error : Wrong request line: " + line);
				return "";
			}
		}
		
		/** format: prob order wrds — returns the n-gram log probability as a string. */
		private String get_prob(DecodedStructure ds) {
			return Double.toString(p_lm.ngramLogProbability(ds.wrds, ds.num));
		}
		
		/** format: prob_bow order wrds — not supported by this server; always throws. */
		private String get_prob_backoff_state(DecodedStructure ds) {
			throw new RuntimeException("call get_prob_backoff_state in lmserver, must exit");
		}
		
		/** format: equiv_left order wrds — not supported by this server; always throws. */
		private String get_left_equiv_state(DecodedStructure ds) {
			throw new RuntimeException("call get_left_equiv_state in lmserver, must exit");
		}
		
		/** format: equiv_right order wrds — not supported by this server; always throws. */
		private String get_right_equiv_state(DecodedStructure ds) {
			throw new RuntimeException("call get_right_equiv_state in lmserver, must exit");
		}
		
		/**
		 * Parses a request of the form {@code cmd num w1 w2 ...} (whitespace separated).
		 *
		 * @throws IllegalArgumentException if fewer than two fields are present,
		 *         or (as NumberFormatException) if a numeric field is not an integer
		 */
		private DecodedStructure decode_packet(String packet) {
			String[] fds = Regex.spaces.split(packet.trim());
			if (fds.length < 2) {
				throw new IllegalArgumentException(
					"request must contain at least a command and a number: " + packet);
			}
			DecodedStructure res = new DecodedStructure();
			res.cmd = fds[0];
			res.num = Integer.parseInt(fds[1]);
			int[] wrds = new int[fds.length - 2];
			for (int i = 2; i < fds.length; i++) {
				wrds[i - 2] = Integer.parseInt(fds[i]);
			}
			res.wrds = wrds;
			return res;
		}
	}
}