// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_SEQUENCE_LAbELER_H_h_
#define DLIB_SEQUENCE_LAbELER_H_h_
#include "sequence_labeler_abstract.h"
#include "../matrix.h"
#include "../optimization/find_max_factor_graph_viterbi.h"
#include <algorithm>
#include <limits>
#include <vector>
namespace dlib
{
// ----------------------------------------------------------------------------------------
namespace fe_helpers
{
template <typename EXP>
struct dot_functor
{
dot_functor(const matrix_exp<EXP>& lambda_) : lambda(lambda_), value(0) {}
inline void operator() (
unsigned long feat_index
)
{
value += lambda(feat_index);
}
inline void operator() (
unsigned long feat_index,
double feat_value
)
{
value += feat_value*lambda(feat_index);
}
const matrix_exp<EXP>& lambda;
double value;
};
template <typename feature_extractor, typename EXP, typename sequence_type, typename EXP2>
double dot(
const matrix_exp<EXP>& lambda,
const feature_extractor& fe,
const sequence_type& sequence,
const matrix_exp<EXP2>& candidate_labeling,
unsigned long position
)
{
dot_functor<EXP> dot(lambda);
fe.get_features(dot, sequence, candidate_labeling, position);
return dot.value;
}
}
// ----------------------------------------------------------------------------------------
namespace impl
{
// Generates the trait has_reject_labeling<T>.  It is used with
// enable_if/disable_if below, and by contains_invalid_labeling(), to choose
// between overloads at compile time.
// NOTE(review): presumably the trait is true when T has a const member function
// template reject_labeling<matrix<unsigned long> > returning bool and callable
// as (const T::sequence_type&, const matrix_exp<matrix<unsigned long> >&,
// unsigned long) -- confirm against the definition of
// DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST.
DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(
has_reject_labeling,
bool,
template reject_labeling<matrix<unsigned long> >,
(const typename T::sequence_type&, const matrix_exp<matrix<unsigned long> >&, unsigned long)const
)
// Overload enabled when feature_extractor has reject_labeling(): forwards the
// query (sequence x, label window y, position) to it and returns its answer.
template <typename feature_extractor, typename EXP, typename sequence_type>
typename enable_if<has_reject_labeling<feature_extractor>,bool>::type call_reject_labeling_if_exists (
const feature_extractor& fe,
const sequence_type& x,
const matrix_exp<EXP>& y,
unsigned long position
)
{
return fe.reject_labeling(x, y, position);
}
// Fallback overload enabled when feature_extractor lacks reject_labeling():
// no labeling is ever rejected, so this always returns false.
template <typename feature_extractor, typename EXP, typename sequence_type>
typename disable_if<has_reject_labeling<feature_extractor>,bool>::type call_reject_labeling_if_exists (
const feature_extractor& ,
const sequence_type& ,
const matrix_exp<EXP>& ,
unsigned long
)
{
return false;
}
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
template <
    typename feature_extractor
    >
typename enable_if<dlib::impl::has_reject_labeling<feature_extractor>,bool>::type contains_invalid_labeling (
    const feature_extractor& fe,
    const typename feature_extractor::sequence_type& x,
    const std::vector<unsigned long>& y
)
/*!
    Overload selected when feature_extractor provides reject_labeling().

    Returns true if x and y have different lengths, or if fe.reject_labeling()
    rejects the label window at any position of y.  Otherwise returns false.

    The window passed for position i is the column vector
    (y[i], y[i-1], ..., y[i-k]) with k = min(fe.order(), i).  Element 0 is the
    label of the current position, and the window is truncated at the start
    of the sequence.
!*/
{
    if (x.size() != y.size())
        return true;

    matrix<unsigned long,0,1> node_states;

    for (unsigned long i = 0; i < x.size(); ++i)
    {
        // std::min deduces a single type from both arguments.  Without the
        // explicit template argument this fails to compile whenever
        // fe.order() returns a type other than unsigned long, e.g. std::size_t
        // on platforms where size_t is not unsigned long.
        const unsigned long window = std::min<unsigned long>(fe.order(), i) + 1;
        node_states.set_size(window);

        // Fill the window backwards from the current position.
        for (unsigned long j = 0; j < window; ++j)
            node_states(j) = y[i-j];

        if (fe.reject_labeling(x, node_states, i))
            return true;
    }

    return false;
}
// ----------------------------------------------------------------------------------------
template <
    typename feature_extractor
    >
typename disable_if<dlib::impl::has_reject_labeling<feature_extractor>,bool>::type contains_invalid_labeling (
    const feature_extractor& ,
    const typename feature_extractor::sequence_type& x,
    const std::vector<unsigned long>& y
)
/*!
    Overload selected when feature_extractor has no reject_labeling() hook.
    With no way to reject individual labels, the only invalid labeling is one
    whose length differs from the sample sequence.  Returns true in exactly
    that case.
!*/
{
    const bool length_mismatch = (x.size() != y.size());
    return length_mismatch;
}
// ----------------------------------------------------------------------------------------
template <
    typename feature_extractor
    >
bool contains_invalid_labeling (
    const feature_extractor& fe,
    const std::vector<typename feature_extractor::sequence_type>& x,
    const std::vector<std::vector<unsigned long> >& y
)
/*!
    Batch form.  Returns true if x and y hold a different number of sequences,
    or if any pair (x[k], y[k]) is invalid according to the single-sequence
    contains_invalid_labeling().  Checking stops at the first invalid pair.
!*/
{
    if (x.size() != y.size())
        return true;

    bool found_invalid = false;
    for (unsigned long k = 0; k < x.size() && !found_invalid; ++k)
        found_invalid = contains_invalid_labeling(fe, x[k], y[k]);

    return found_invalid;
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
template <
typename feature_extractor
>
class sequence_labeler
{
public:
typedef typename feature_extractor::sequence_type sample_sequence_type;
typedef std::vector<unsigned long> labeled_sequence_type;
private:
class map_prob
{
public:
unsigned long order() const { return fe.order(); }
unsigned long num_states() const { return fe.num_labels(); }
map_prob(
const sample_sequence_type& x_,
const feature_extractor& fe_,
const matrix<double,0,1>& weights_
) :
sequence(x_),
fe(fe_),
weights(weights_)
{
}
unsigned long number_of_nodes(
) const
{
return sequence.size();
}
template <
typename EXP
>
double factor_value (
unsigned long node_id,
const matrix_exp<EXP>& node_states
) const
{
if (dlib::impl::call_reject_labeling_if_exists(fe, sequence, node_states, node_id))
return -std::numeric_limits<double>::infinity();
return fe_helpers::dot(weights, fe, sequence, node_states, node_id);
}
const sample_sequence_type& sequence;
const feature_extractor& fe;
const matrix<double,0,1>& weights;
};
public:
sequence_labeler()
{
weights.set_size(fe.num_features());
weights = 0;
}
explicit sequence_labeler(
const matrix<double,0,1>& weights_
) :
weights(weights_)
{
// make sure requires clause is not broken
DLIB_ASSERT(fe.num_features() == static_cast<unsigned long>(weights_.size()),
"\t sequence_labeler::sequence_labeler(weights_)"
<< "\n\t These sizes should match"
<< "\n\t fe.num_features(): " << fe.num_features()
<< "\n\t weights_.size(): " << weights_.size()
<< "\n\t this: " << this
);
}
sequence_labeler(
const matrix<double,0,1>& weights_,
const feature_extractor& fe_
) :
fe(fe_),
weights(weights_)
{
// make sure requires clause is not broken
DLIB_ASSERT(fe_.num_features() == static_cast<unsigned long>(weights_.size()),
"\t sequence_labeler::sequence_labeler(weights_,fe_)"
<< "\n\t These sizes should match"
<< "\n\t fe_.num_features(): " << fe_.num_features()
<< "\n\t weights_.size(): " << weights_.size()
<< "\n\t this: " << this
);
}
const feature_extractor& get_feature_extractor (
) const { return fe; }
const matrix<double,0,1>& get_weights (
) const { return weights; }
unsigned long num_labels (
) const { return fe.num_labels(); }
labeled_sequence_type operator() (
const sample_sequence_type& x
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(num_labels() > 0,
"\t labeled_sequence_type sequence_labeler::operator()(x)"
<< "\n\t You can't have no labels."
<< "\n\t this: " << this
);
labeled_sequence_type y;
find_max_factor_graph_viterbi(map_prob(x,fe,weights), y);
return y;
}
void label_sequence (
const sample_sequence_type& x,
labeled_sequence_type& y
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(num_labels() > 0,
"\t void sequence_labeler::label_sequence(x,y)"
<< "\n\t You can't have no labels."
<< "\n\t this: " << this
);
find_max_factor_graph_viterbi(map_prob(x,fe,weights), y);
}
private:
feature_extractor fe;
matrix<double,0,1> weights;
};
// ----------------------------------------------------------------------------------------
template <
typename feature_extractor
>
void serialize (
const sequence_labeler<feature_extractor>& item,
std::ostream& out
)
{
serialize(item.get_feature_extractor(), out);
serialize(item.get_weights(), out);
}
// ----------------------------------------------------------------------------------------
template <
    typename feature_extractor
    >
void deserialize (
    sequence_labeler<feature_extractor>& item,
    std::istream& in
)
/*!
    Reads a sequence_labeler written by serialize(): the feature extractor
    first, then the weight vector.  Both are read into temporaries, and item
    is assigned only on the last line, so it is left untouched if either
    read throws.
!*/
{
    feature_extractor loaded_fe;
    deserialize(loaded_fe, in);

    matrix<double,0,1> loaded_weights;
    deserialize(loaded_weights, in);

    item = sequence_labeler<feature_extractor>(loaded_weights, loaded_fe);
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_SEQUENCE_LAbELER_H_h_